From ad748dadd66f6e0e9620d95dfa5b172ed67f28b0 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 9 Dec 2025 15:58:03 -0600 Subject: [PATCH 01/51] GEMM reference HIP implementation --- tests/cpp/operator/test_cublaslt_gemm.cu | 309 ++++++++++++++++++----- 1 file changed, 245 insertions(+), 64 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 071470bdf..e1e0b9316 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -51,11 +51,224 @@ using TShape = std::vector; } // namespace -float ref_gelu(float x){ +__device__ __host__ __forceinline__ float ref_gelu(float x){ float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x)))); return x * cdf; } +template +__global__ void compute_ref_kernel( + const A_Type* __restrict__ a_data, + const B_Type* __restrict__ b_data, + float a_scale_inv_scalar, // used when mxfp8 == false + float b_scale_inv_scalar, + const fp8e8m0* __restrict__ a_scale_inv_mxfp8, // used when mxfp8 == true + const fp8e8m0* __restrict__ b_scale_inv_mxfp8, + const Bias_Type* __restrict__ bias_data, + float d_scale, + size_t m, size_t k, size_t n, + D_Type* __restrict__ d_data, + float* __restrict__ d_amax, + Gelu_Type* __restrict__ gelu_data, + bool transa, + bool transb, + bool is_fp8_output) +{ + const size_t jj = blockIdx.x * blockDim.x + threadIdx.x; + const size_t ii = blockIdx.y * blockDim.y + threadIdx.y; + + if (ii >= m || jj >= n) + return; + + float val = 0.0f; + + for (size_t kk = 0; kk < k; ++kk) { + const size_t a_idx = transa ? (ii * k + kk) : (kk * m + ii); + const size_t b_idx = transb ? (kk * n + jj) : (jj * k + kk); + + float a_scale_inv_val = a_scale_inv_scalar; + float b_scale_inv_val = b_scale_inv_scalar; + + if (a_scale_inv_mxfp8) { + const size_t a_scale_idx = + transa ? (a_idx / 32) : ((kk / 32) * m + ii); + const size_t b_scale_idx = + transb ? 
((kk / 32) * n + jj) : (b_idx / 32); + + const float a_byte = static_cast(a_scale_inv_mxfp8[a_scale_idx]); + const float b_byte = static_cast(b_scale_inv_mxfp8[b_scale_idx]); + + a_scale_inv_val = exp2f(a_byte - 127.0f); + b_scale_inv_val = exp2f(b_byte - 127.0f); + } + + const float a_val = a_data[a_idx]; + const float b_val = b_data[b_idx]; + + val += a_scale_inv_val * a_val * b_scale_inv_val * b_val; + } + + if (bias_data) { + val += (float)bias_data[ii]; + } + + if (gelu_data) { + gelu_data[ii + jj * m] = val; + val = ref_gelu(val); + } + + const float scaled = val * d_scale; + d_data[ii + jj * m] = scaled; + + if (is_fp8_output && d_amax) { + atomicMax(d_amax, fabsf(val)); + } +} + +// Common implementation used by both tensor-wise and MXFP8 frontends +template +static void compute_ref_impl( + const A_Type* a_data, + const B_Type* b_data, + float a_scale_inv_scalar, // used when mxfp8 == false + float b_scale_inv_scalar, + const fp8e8m0* a_scale_inv_mxfp8, // used when mxfp8 == true + const fp8e8m0* b_scale_inv_mxfp8, + const Bias_Type* bias_data, + float d_scale, + size_t m, size_t k, size_t n, + D_Type* d_data, + float* d_amax_host, + Gelu_Type* gelu_data, + bool transa, + bool transb) +{ + using transformer_engine::DType; + using ::TypeInfo; + using ::isFp8Type; + + const bool use_mxfp8 = (a_scale_inv_mxfp8 != nullptr); + + const DType dtype = TypeInfo::dtype; + const bool is_fp8_output = isFp8Type(dtype); + + const size_t lenA = m * k; + const size_t lenB = k * n; + const size_t lenD = m * n; + const size_t lenBias = m; + const size_t lenGelu = m * n; + + const size_t lenA_scale = use_mxfp8 ? (lenA + 31) / 32 : 0; + const size_t lenB_scale = use_mxfp8 ? 
(lenB + 31) / 32 : 0; + + A_Type* dA = nullptr; + B_Type* dB = nullptr; + Bias_Type* dBias = nullptr; + D_Type* dD = nullptr; + Gelu_Type* dGelu = nullptr; + float* dAmax = nullptr; + fp8e8m0* dA_scale = nullptr; + fp8e8m0* dB_scale = nullptr; + + // Allocations and H2D transfers + NVTE_CHECK_CUDA(cudaMalloc(&dA, lenA * sizeof(A_Type))); + NVTE_CHECK_CUDA(cudaMalloc(&dB, lenB * sizeof(B_Type))); + NVTE_CHECK_CUDA(cudaMalloc(&dD, lenD * sizeof(D_Type))); + + NVTE_CHECK_CUDA(cudaMemcpy( + dA, a_data, lenA * sizeof(A_Type), cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy( + dB, b_data, lenB * sizeof(B_Type), cudaMemcpyHostToDevice)); + + if (bias_data) { + NVTE_CHECK_CUDA(cudaMalloc(&dBias, lenBias * sizeof(Bias_Type))); + NVTE_CHECK_CUDA(cudaMemcpy( + dBias, bias_data, lenBias * sizeof(Bias_Type), + cudaMemcpyHostToDevice)); + } + + if (gelu_data) { + NVTE_CHECK_CUDA(cudaMalloc(&dGelu, lenGelu * sizeof(Gelu_Type))); + NVTE_CHECK_CUDA(cudaMemset(dGelu, 0, lenGelu * sizeof(Gelu_Type))); + } + + if (use_mxfp8) { + NVTE_CHECK_CUDA(cudaMalloc(&dA_scale, lenA_scale * sizeof(fp8e8m0))); + NVTE_CHECK_CUDA(cudaMalloc(&dB_scale, lenB_scale * sizeof(fp8e8m0))); + NVTE_CHECK_CUDA(cudaMemcpy( + dA_scale, a_scale_inv_mxfp8, lenA_scale * sizeof(fp8e8m0), + cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy( + dB_scale, b_scale_inv_mxfp8, lenB_scale * sizeof(fp8e8m0), + cudaMemcpyHostToDevice)); + } + + if (is_fp8_output && d_amax_host) { + NVTE_CHECK_CUDA(cudaMalloc(&dAmax, sizeof(float))); + NVTE_CHECK_CUDA(cudaMemset(dAmax, 0, sizeof(float))); + } + + // Kernel launch + dim3 block(16, 16); + dim3 grid((n + block.x - 1) / block.x, (m + block.y - 1) / block.y); + + compute_ref_kernel + <<>>( + dA, + dB, + a_scale_inv_scalar, + b_scale_inv_scalar, + dA_scale, + dB_scale, + dBias, + d_scale, + m, k, n, + dD, + dAmax, + dGelu, + transa, + transb, + is_fp8_output); + + NVTE_CHECK_CUDA(cudaGetLastError()); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + + // D2H copies + 
NVTE_CHECK_CUDA(cudaMemcpy( + d_data, dD, lenD * sizeof(D_Type), cudaMemcpyDeviceToHost)); + + if (gelu_data) { + NVTE_CHECK_CUDA(cudaMemcpy( + gelu_data, dGelu, lenGelu * sizeof(Gelu_Type), + cudaMemcpyDeviceToHost)); + } + + if (is_fp8_output && d_amax_host) { + NVTE_CHECK_CUDA(cudaMemcpy( + d_amax_host, dAmax, sizeof(float), cudaMemcpyDeviceToHost)); + } else if (d_amax_host) { + *d_amax_host = 0.0f; + } + + // cleanup + NVTE_CHECK_CUDA(cudaFree(dA)); + NVTE_CHECK_CUDA(cudaFree(dB)); + NVTE_CHECK_CUDA(cudaFree(dD)); + if (dBias) + NVTE_CHECK_CUDA(cudaFree(dBias)); + if (dGelu) + NVTE_CHECK_CUDA(cudaFree(dGelu)); + if (dAmax) + NVTE_CHECK_CUDA(cudaFree(dAmax)); + if (dA_scale) + NVTE_CHECK_CUDA(cudaFree(dA_scale)); + if (dB_scale) + NVTE_CHECK_CUDA(cudaFree(dB_scale)); +} + + template void compute_ref( const A_Type* a_data, @@ -71,36 +284,21 @@ void compute_ref( bool transa, bool transb){ - float ref_d_amax = 0; - - #pragma omp parallel for schedule(static) collapse(2) reduction(max: ref_d_amax) proc_bind(spread) - for(size_t ii = 0; ii < m; ii++){ - for(size_t jj = 0; jj < n; jj++){ - float val = 0; - for(size_t kk = 0; kk < k; kk++){ - float a_val = transa ? a_data[kk + ii*k] : a_data[ii + kk*m]; - float b_val = transb ? 
b_data[jj + kk*n] : b_data[kk + jj*k]; - val += a_scale_inv*a_val*b_scale_inv*b_val; - } - if(bias_data){ - val += (float)bias_data[ii]; - } - if(ref_gelu_data){ - ref_gelu_data[ii + jj*m] = (Gelu_Type)(val); - val = ref_gelu(val); - } - ref_d_data[ii+jj*m] = (D_Type)(val*d_scale); - // update ref_d_amax if in fp8 - DType dtype = TypeInfo::dtype; - if(isFp8Type(dtype)){ - ref_d_amax = std::max(ref_d_amax, std::fabs(val)); - } - } - } - if (ref_d_amax_ptr) - { - *ref_d_amax_ptr = ref_d_amax; - } + compute_ref_impl( + a_data, + b_data, + /*a_scale_inv_scalar=*/a_scale_inv, + /*b_scale_inv_scalar=*/b_scale_inv, + /*a_scale_inv_mxfp8=*/nullptr, + /*b_scale_inv_mxfp8=*/nullptr, + bias_data, + d_scale, + m, k, n, + ref_d_data, + ref_d_amax_ptr, + ref_gelu_data, + transa, + transb); } template @@ -118,38 +316,21 @@ void compute_mxfp8_ref( bool transa, bool transb){ - float ref_d_amax = 0; - - #pragma omp parallel for schedule(static) collapse(2) reduction(max: ref_d_amax) proc_bind(spread) - for(size_t ii = 0; ii < m; ii++){ - for(size_t jj = 0; jj < n; jj++){ - float val = 0; - for(size_t kk = 0; kk < k; kk++){ - size_t a_idx = transa ? (ii*k + kk) : (kk*m + ii); - size_t b_idx = transb ? (kk*n + jj) : (jj*k + kk); - float a_scale_inv_val = std::exp2f(a_scale_inv_data[transa ? a_idx/32 : (kk/32 * m + ii)] - 127); - float b_scale_inv_val = std::exp2f(b_scale_inv_data[transb ? 
(kk/32 * n + jj) : b_idx/32] - 127); - val += a_scale_inv_val * (float)a_data[a_idx] * b_scale_inv_val * (float)b_data[b_idx]; - } - if(bias_data){ - val += (float)bias_data[ii]; - } - if(ref_gelu_data){ - ref_gelu_data[ii + jj*m] = (Gelu_Type)(val); - val = ref_gelu(val); - } - ref_d_data[ii+jj*m] = (D_Type)(val*d_scale); - // update ref_d_amax if in fp8 - DType dtype = TypeInfo::dtype; - if(isFp8Type(dtype)){ - ref_d_amax = std::max(ref_d_amax, std::fabs(val)); - } - } - } - if (ref_d_amax_ptr) - { - *ref_d_amax_ptr = ref_d_amax; - } + compute_ref_impl( + a_data, + b_data, + /*a_scale_inv_scalar=*/1.0f, + /*b_scale_inv_scalar=*/1.0f, + /*a_scale_inv_mxfp8=*/a_scale_inv_data, + /*b_scale_inv_mxfp8=*/b_scale_inv_data, + bias_data, + d_scale, + m, k, n, + ref_d_data, + ref_d_amax_ptr, + ref_gelu_data, + transa, + transb); } template @@ -371,7 +552,7 @@ void performTest(const TestParams& params) { pre_gelu_out.to_cpu(); } - //perform the gemm in CPU + //perform the reference gemm on GPU std::unique_ptr ref_D = std::make_unique(params.m*params.n); std::unique_ptr ref_pre_gelu_out; if(params.use_gelu){ From 11e090b9e34f0fc792122e232af4e2b863122ef6 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 11 Dec 2025 15:14:53 -0600 Subject: [PATCH 02/51] blockwise amax --- tests/cpp/operator/test_cublaslt_gemm.cu | 86 +++++++++++++++--------- 1 file changed, 55 insertions(+), 31 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index e1e0b9316..0c5f9a759 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -78,51 +78,72 @@ __global__ void compute_ref_kernel( const size_t jj = blockIdx.x * blockDim.x + threadIdx.x; const size_t ii = blockIdx.y * blockDim.y + threadIdx.y; - if (ii >= m || jj >= n) - return; + const bool in_range = (ii < m) && (jj < n); float val = 0.0f; - for (size_t kk = 0; kk < k; ++kk) { - const size_t a_idx = transa ? 
(ii * k + kk) : (kk * m + ii); - const size_t b_idx = transb ? (kk * n + jj) : (jj * k + kk); + if (in_range) { + for (size_t kk = 0; kk < k; ++kk) { + const size_t a_idx = transa ? (ii * k + kk) : (kk * m + ii); + const size_t b_idx = transb ? (kk * n + jj) : (jj * k + kk); - float a_scale_inv_val = a_scale_inv_scalar; - float b_scale_inv_val = b_scale_inv_scalar; + float a_scale_inv_val = a_scale_inv_scalar; + float b_scale_inv_val = b_scale_inv_scalar; - if (a_scale_inv_mxfp8) { - const size_t a_scale_idx = - transa ? (a_idx / 32) : ((kk / 32) * m + ii); - const size_t b_scale_idx = - transb ? ((kk / 32) * n + jj) : (b_idx / 32); + if (a_scale_inv_mxfp8) { + const size_t a_scale_idx = + transa ? (a_idx / 32) : ((kk / 32) * m + ii); + const size_t b_scale_idx = + transb ? ((kk / 32) * n + jj) : (b_idx / 32); - const float a_byte = static_cast(a_scale_inv_mxfp8[a_scale_idx]); - const float b_byte = static_cast(b_scale_inv_mxfp8[b_scale_idx]); + const float a_byte = static_cast(a_scale_inv_mxfp8[a_scale_idx]); + const float b_byte = static_cast(b_scale_inv_mxfp8[b_scale_idx]); - a_scale_inv_val = exp2f(a_byte - 127.0f); - b_scale_inv_val = exp2f(b_byte - 127.0f); + a_scale_inv_val = exp2f(a_byte - 127.0f); + b_scale_inv_val = exp2f(b_byte - 127.0f); + } + + const float a_val = static_cast(a_data[a_idx]); + const float b_val = static_cast(b_data[b_idx]); + + val += a_scale_inv_val * a_val * b_scale_inv_val * b_val; } - const float a_val = a_data[a_idx]; - const float b_val = b_data[b_idx]; + if (bias_data) { + val += static_cast(bias_data[ii]); + } - val += a_scale_inv_val * a_val * b_scale_inv_val * b_val; - } + if (gelu_data) { + gelu_data[ii + jj * m] = static_cast(val); + val = ref_gelu(val); + } - if (bias_data) { - val += (float)bias_data[ii]; + const float scaled = val * d_scale; + d_data[ii + jj * m] = static_cast(scaled); } - if (gelu_data) { - gelu_data[ii + jj * m] = val; - val = ref_gelu(val); - } + // Blockwise reduction for amax + if (is_fp8_output && 
d_amax) { + const int tid = threadIdx.y * blockDim.x + threadIdx.x; + const int nthreads = blockDim.x * blockDim.y; - const float scaled = val * d_scale; - d_data[ii + jj * m] = scaled; + extern __shared__ float s_amax[]; - if (is_fp8_output && d_amax) { - atomicMax(d_amax, fabsf(val)); + // Out-of-range threads contribute 0 + s_amax[tid] = in_range ? fabsf(val) : 0.0f; + __syncthreads(); + + for (int offset = nthreads / 2; offset > 0; offset /= 2) { + if (tid < offset) { + s_amax[tid] = fmaxf(s_amax[tid], s_amax[tid + offset]); + } + __syncthreads(); + } + + if (tid == 0) { + const float block_max = s_amax[0]; + atomicMax(d_amax, block_max); + } } } @@ -214,8 +235,11 @@ static void compute_ref_impl( dim3 block(16, 16); dim3 grid((n + block.x - 1) / block.x, (m + block.y - 1) / block.y); + const int nthreads = block.x * block.y; + size_t shmem_bytes = nthreads * sizeof(float); + compute_ref_kernel - <<>>( + <<>>( dA, dB, a_scale_inv_scalar, From 3ecea7fb11748bc6c99250e3de46ebec68dfc778 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 13 Jan 2026 17:13:48 -0600 Subject: [PATCH 03/51] Change to use Tensor arguments, combine mxfp8/non-mxfp8 paths --- tests/cpp/operator/test_cublaslt_gemm.cu | 343 +++++++++-------------- tests/cpp/test_common.h | 14 +- 2 files changed, 137 insertions(+), 220 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 3f5249a6a..631c06c51 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -73,7 +73,9 @@ __global__ void compute_ref_kernel( Gelu_Type* __restrict__ gelu_data, bool transa, bool transb, - bool is_fp8_output) + bool is_fp8_output, + bool a_is_colwise, + bool b_is_colwise) { const size_t jj = blockIdx.x * blockDim.x + threadIdx.x; const size_t ii = blockIdx.y * blockDim.y + threadIdx.y; @@ -84,17 +86,26 @@ __global__ void compute_ref_kernel( if (in_range) { for (size_t kk = 0; kk < k; ++kk) { - const size_t 
a_idx = transa ? (ii * k + kk) : (kk * m + ii); - const size_t b_idx = transb ? (kk * n + jj) : (jj * k + kk); + // Indexing depends on which backing buffer we passed in + const size_t a_idx = + a_is_colwise ? (ii * k + kk) + : (transa ? (ii * k + kk) : (kk * m + ii)); + + const size_t b_idx = + b_is_colwise ? (jj * k + kk) + : (transb ? (kk * n + jj) : (jj * k + kk)); float a_scale_inv_val = a_scale_inv_scalar; float b_scale_inv_val = b_scale_inv_scalar; if (a_scale_inv_mxfp8) { const size_t a_scale_idx = - transa ? (a_idx / 32) : ((kk / 32) * m + ii); + a_is_colwise ? (a_idx / 32) + : (transa ? (a_idx / 32) : ((kk / 32) * m + ii)); + const size_t b_scale_idx = - transb ? ((kk / 32) * n + jj) : (b_idx / 32); + b_is_colwise ? (b_idx / 32) + : (transb ? ((kk / 32) * n + jj) : (b_idx / 32)); const float a_byte = static_cast(a_scale_inv_mxfp8[a_scale_idx]); const float b_byte = static_cast(b_scale_inv_mxfp8[b_scale_idx]); @@ -147,216 +158,145 @@ __global__ void compute_ref_kernel( } } -// Common implementation used by both tensor-wise and MXFP8 frontends + +struct TestParams { + size_t m; + size_t k; + size_t n; + bool use_bias; + bool use_gelu; + bool transa; + bool transb; + NVTEScalingMode scaling_mode; +}; + + template -static void compute_ref_impl( - const A_Type* a_data, - const B_Type* b_data, - float a_scale_inv_scalar, // used when mxfp8 == false - float b_scale_inv_scalar, - const fp8e8m0* a_scale_inv_mxfp8, // used when mxfp8 == true - const fp8e8m0* b_scale_inv_mxfp8, - const Bias_Type* bias_data, - float d_scale, - size_t m, size_t k, size_t n, - D_Type* d_data, - float* d_amax_host, - Gelu_Type* gelu_data, - bool transa, - bool transb) +static void run_reference( + const TestParams& params, + const Tensor& A, + const Tensor& B, + const Tensor* Bias, // nullable + float d_scale, + std::unique_ptr& ref_D, // m*n + float* ref_amax_d, + std::unique_ptr& ref_pre_gelu_out) // nullable { - using transformer_engine::DType; - using ::TypeInfo; - using 
::isFp8Type; + const bool use_mxfp8 = (params.scaling_mode == NVTE_MXFP8_1D_SCALING); - const bool use_mxfp8 = (a_scale_inv_mxfp8 != nullptr); + Gelu_Type* ref_gelu_host = (params.use_gelu ? ref_pre_gelu_out.get() : nullptr); - const DType dtype = TypeInfo::dtype; - const bool is_fp8_output = isFp8Type(dtype); + const bool is_fp8_output = test::isFp8Type(test::TypeInfo::dtype); - const size_t lenA = m * k; - const size_t lenB = k * n; - const size_t lenD = m * n; - const size_t lenBias = m; - const size_t lenGelu = m * n; + const bool a_use_colwise = (!params.transa) && A.columnwise(); + const bool b_use_colwise = ( params.transb) && B.columnwise(); - const size_t lenA_scale = use_mxfp8 ? (lenA + 31) / 32 : 0; - const size_t lenB_scale = use_mxfp8 ? (lenB + 31) / 32 : 0; + const A_Type* a_dev = static_cast( + a_use_colwise ? A.columnwise_dptr() : A.rowwise_dptr()); - A_Type* dA = nullptr; - B_Type* dB = nullptr; - Bias_Type* dBias = nullptr; - D_Type* dD = nullptr; - Gelu_Type* dGelu = nullptr; - float* dAmax = nullptr; - fp8e8m0* dA_scale = nullptr; - fp8e8m0* dB_scale = nullptr; + const B_Type* b_dev = static_cast( + b_use_colwise ? 
B.columnwise_dptr() : B.rowwise_dptr()); - // Allocations and H2D transfers - NVTE_CHECK_CUDA(cudaMalloc(&dA, lenA * sizeof(A_Type))); - NVTE_CHECK_CUDA(cudaMalloc(&dB, lenB * sizeof(B_Type))); - NVTE_CHECK_CUDA(cudaMalloc(&dD, lenD * sizeof(D_Type))); + // scaling inputs + float a_scale_inv_scalar = 1.0f; + float b_scale_inv_scalar = 1.0f; - NVTE_CHECK_CUDA(cudaMemcpy( - dA, a_data, lenA * sizeof(A_Type), cudaMemcpyHostToDevice)); - NVTE_CHECK_CUDA(cudaMemcpy( - dB, b_data, lenB * sizeof(B_Type), cudaMemcpyHostToDevice)); + const fp8e8m0* a_scale_dev = nullptr; + const fp8e8m0* b_scale_dev = nullptr; - if (bias_data) { - NVTE_CHECK_CUDA(cudaMalloc(&dBias, lenBias * sizeof(Bias_Type))); - NVTE_CHECK_CUDA(cudaMemcpy( - dBias, bias_data, lenBias * sizeof(Bias_Type), - cudaMemcpyHostToDevice)); - } + if (use_mxfp8) { + a_scale_dev = params.transa + ? (const fp8e8m0*) A.rowwise_scale_inv_dptr() + : (const fp8e8m0*) A.columnwise_scale_inv_dptr(); - if (gelu_data) { - NVTE_CHECK_CUDA(cudaMalloc(&dGelu, lenGelu * sizeof(Gelu_Type))); - NVTE_CHECK_CUDA(cudaMemset(dGelu, 0, lenGelu * sizeof(Gelu_Type))); + b_scale_dev = params.transb + ? 
(const fp8e8m0*) B.columnwise_scale_inv_dptr() + : (const fp8e8m0*) B.rowwise_scale_inv_dptr(); + } else { + a_scale_inv_scalar = A.rowwise_scale_inv(); + b_scale_inv_scalar = B.rowwise_scale_inv(); } - if (use_mxfp8) { - NVTE_CHECK_CUDA(cudaMalloc(&dA_scale, lenA_scale * sizeof(fp8e8m0))); - NVTE_CHECK_CUDA(cudaMalloc(&dB_scale, lenB_scale * sizeof(fp8e8m0))); - NVTE_CHECK_CUDA(cudaMemcpy( - dA_scale, a_scale_inv_mxfp8, lenA_scale * sizeof(fp8e8m0), - cudaMemcpyHostToDevice)); - NVTE_CHECK_CUDA(cudaMemcpy( - dB_scale, b_scale_inv_mxfp8, lenB_scale * sizeof(fp8e8m0), - cudaMemcpyHostToDevice)); + // optional bias device pointer + const Bias_Type* bias_dev = nullptr; + if (Bias) { + bias_dev = static_cast(Bias->rowwise_dptr()); } - if (is_fp8_output && d_amax_host) { - NVTE_CHECK_CUDA(cudaMalloc(&dAmax, sizeof(float))); - NVTE_CHECK_CUDA(cudaMemset(dAmax, 0, sizeof(float))); + // allocate device outputs + const size_t lenD = params.m * params.n; + const size_t bytesD = lenD * sizeof(D_Type); + + D_Type* d_refD = nullptr; + Gelu_Type* d_refGelu = nullptr; + float* d_refAmax = nullptr; + + NVTE_CHECK_CUDA(cudaMalloc(&d_refD, bytesD)); + if (ref_gelu_host) { + NVTE_CHECK_CUDA(cudaMalloc(&d_refGelu, lenD * sizeof(Gelu_Type))); + } + if (is_fp8_output && ref_amax_d) { + NVTE_CHECK_CUDA(cudaMalloc(&d_refAmax, sizeof(float))); + NVTE_CHECK_CUDA(cudaMemset(d_refAmax, 0, sizeof(float))); } // Kernel launch dim3 block(16, 16); - dim3 grid((n + block.x - 1) / block.x, (m + block.y - 1) / block.y); + dim3 grid((unsigned)((params.n + block.x - 1) / block.x), + (unsigned)((params.m + block.y - 1) / block.y)); - const int nthreads = block.x * block.y; - size_t shmem_bytes = nthreads * sizeof(float); + const size_t shmem_bytes = size_t(block.x) * size_t(block.y) * sizeof(float); compute_ref_kernel <<>>( - dA, - dB, + a_dev, + b_dev, a_scale_inv_scalar, b_scale_inv_scalar, - dA_scale, - dB_scale, - dBias, + a_scale_dev, + b_scale_dev, + bias_dev, d_scale, - m, k, n, - dD, - dAmax, - 
dGelu, - transa, - transb, - is_fp8_output); + params.m, params.k, params.n, + d_refD, + d_refAmax, + d_refGelu, + params.transa, + params.transb, + is_fp8_output, + a_use_colwise, + b_use_colwise); NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaDeviceSynchronize()); - // D2H copies - NVTE_CHECK_CUDA(cudaMemcpy( - d_data, dD, lenD * sizeof(D_Type), cudaMemcpyDeviceToHost)); + // copy outputs back + NVTE_CHECK_CUDA(cudaMemcpy(ref_D.get(), d_refD, bytesD, cudaMemcpyDeviceToHost)); - if (gelu_data) { - NVTE_CHECK_CUDA(cudaMemcpy( - gelu_data, dGelu, lenGelu * sizeof(Gelu_Type), - cudaMemcpyDeviceToHost)); + if (ref_gelu_host) { + NVTE_CHECK_CUDA(cudaMemcpy(ref_gelu_host, d_refGelu, lenD * sizeof(Gelu_Type), + cudaMemcpyDeviceToHost)); } - if (is_fp8_output && d_amax_host) { - NVTE_CHECK_CUDA(cudaMemcpy( - d_amax_host, dAmax, sizeof(float), cudaMemcpyDeviceToHost)); - } else if (d_amax_host) { - *d_amax_host = 0.0f; + if (ref_amax_d) { + if (is_fp8_output) { + NVTE_CHECK_CUDA(cudaMemcpy(ref_amax_d, d_refAmax, sizeof(float), + cudaMemcpyDeviceToHost)); + } else { + *ref_amax_d = 0.0f; + } } // cleanup - NVTE_CHECK_CUDA(cudaFree(dA)); - NVTE_CHECK_CUDA(cudaFree(dB)); - NVTE_CHECK_CUDA(cudaFree(dD)); - if (dBias) - NVTE_CHECK_CUDA(cudaFree(dBias)); - if (dGelu) - NVTE_CHECK_CUDA(cudaFree(dGelu)); - if (dAmax) - NVTE_CHECK_CUDA(cudaFree(dAmax)); - if (dA_scale) - NVTE_CHECK_CUDA(cudaFree(dA_scale)); - if (dB_scale) - NVTE_CHECK_CUDA(cudaFree(dB_scale)); + NVTE_CHECK_CUDA(cudaFree(d_refD)); + if (d_refGelu) + NVTE_CHECK_CUDA(cudaFree(d_refGelu)); + if (d_refAmax) + NVTE_CHECK_CUDA(cudaFree(d_refAmax)); } -template -void compute_ref( - const A_Type* a_data, - const B_Type* b_data, - const float a_scale_inv, - const float b_scale_inv, - const Bias_Type* bias_data, //bias is of dim m - const float d_scale, - size_t m, size_t k, size_t n, - D_Type* ref_d_data, - float* ref_d_amax_ptr, - Gelu_Type* ref_gelu_data, - bool transa, - bool transb){ - - compute_ref_impl( - 
a_data, - b_data, - /*a_scale_inv_scalar=*/a_scale_inv, - /*b_scale_inv_scalar=*/b_scale_inv, - /*a_scale_inv_mxfp8=*/nullptr, - /*b_scale_inv_mxfp8=*/nullptr, - bias_data, - d_scale, - m, k, n, - ref_d_data, - ref_d_amax_ptr, - ref_gelu_data, - transa, - transb); -} - -template -void compute_mxfp8_ref( - const A_Type* a_data, - const B_Type* b_data, - const fp8e8m0* a_scale_inv_data, - const fp8e8m0* b_scale_inv_data, - const Bias_Type* bias_data, //bias is of dim m - const float d_scale, - size_t m, size_t k, size_t n, - D_Type* ref_d_data, - float* ref_d_amax_ptr, - Gelu_Type* ref_gelu_data, - bool transa, - bool transb){ - - compute_ref_impl( - a_data, - b_data, - /*a_scale_inv_scalar=*/1.0f, - /*b_scale_inv_scalar=*/1.0f, - /*a_scale_inv_mxfp8=*/a_scale_inv_data, - /*b_scale_inv_mxfp8=*/b_scale_inv_data, - bias_data, - d_scale, - m, k, n, - ref_d_data, - ref_d_amax_ptr, - ref_gelu_data, - transa, - transb); -} - template void cpu_rowwise_to_columnwise( size_t m, size_t n, @@ -396,16 +336,6 @@ std::pair getTestTolerances(const DType type, bool use_fp8, bool return {atol, rtol}; } -struct TestParams { - size_t m; - size_t k; - size_t n; - bool use_bias; - bool use_gelu; - bool transa; - bool transb; - NVTEScalingMode scaling_mode; -}; template void performTest(const TestParams& params) { @@ -588,40 +518,17 @@ void performTest(const TestParams& params) { } float ref_amax_d; - if (use_mxfp8) { - const A_Type *a_data; - const B_Type *b_data; - const fp8e8m0 *a_scale_inv_data, *b_scale_inv_data; - if (params.transa) { - a_data = A.rowwise_cpu_dptr(); - a_scale_inv_data = A.rowwise_cpu_scale_inv_ptr(); - } else { - a_data = A.columnwise_cpu_dptr(); - a_scale_inv_data = A.columnwise_cpu_scale_inv_ptr(); - } - if (params.transb) { - b_data = B.columnwise_cpu_dptr(); - b_scale_inv_data = B.columnwise_cpu_scale_inv_ptr(); - } else { - b_data = B.rowwise_cpu_dptr(); - b_scale_inv_data = B.rowwise_cpu_scale_inv_ptr(); - } - compute_mxfp8_ref( - a_data, b_data, 
a_scale_inv_data, b_scale_inv_data, - params.use_bias ? bias.rowwise_cpu_dptr() : nullptr, - D.scale(), params.m, params.k, params.n, ref_D.get(), &ref_amax_d, - params.use_gelu ? ref_pre_gelu_out.get() : nullptr, - params.transa, params.transb); - } else { - compute_ref( - A.rowwise_cpu_dptr(), B.rowwise_cpu_dptr(), - A.rowwise_scale_inv(), B.rowwise_scale_inv(), - params.use_bias ? bias.rowwise_cpu_dptr() : nullptr, - D.scale(), params.m, params.k, params.n, ref_D.get(), &ref_amax_d, - params.use_gelu ? ref_pre_gelu_out.get() : nullptr, - params.transa, params.transb); - } + run_reference( + params, + A, + B, + params.use_bias ? &bias : nullptr, + D.scale(), + ref_D, + &ref_amax_d, + ref_pre_gelu_out); + // check if error message happens in running (void)cudaDeviceSynchronize(); auto err = cudaGetLastError(); diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index bfb46f8a0..8892ff097 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -224,6 +224,16 @@ class Tensor { return reinterpret_cast(cpu_data_columnwise_.get()); } + void *rowwise_scale_inv_dptr() const { + NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); + return tensor_.get_rowwise_scale_inv().data_ptr; + } + + void *columnwise_scale_inv_dptr() const { + NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); + return tensor_.get_columnwise_scale_inv().data_ptr; + } + float amax() const { if(amax_cpu_data_) { to_cpu(); @@ -244,7 +254,7 @@ class Tensor { } template - T *rowwise_cpu_scale_inv_ptr(){ + T *rowwise_cpu_scale_inv_ptr() const { if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { @@ -269,7 +279,7 @@ class Tensor { return reinterpret_cast(columnwise_scale_inv_cpu_data_.get()); } - float rowwise_scale_inv(){ + float rowwise_scale_inv() const { 
if(rowwise_scale_inv_cpu_data_) { float scale_inv = rowwise_cpu_scale_inv_ptr()[0]; return scale_inv; From 86fbbac87113f00341062e4a9b150a855207acd6 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 14 Jan 2026 14:17:24 -0600 Subject: [PATCH 04/51] skip on SwizzleScale limitation on gfx950 --- tests/cpp/operator/test_cublaslt_gemm.cu | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 560218575..da59a8dee 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -427,6 +427,11 @@ void performTest(const TestParams& params) { GTEST_SKIP() << "FP8 GEMM with bias is not supported in current config"; } } + + if (use_mxfp8 && (isFp8Type(atype) || isFp8Type(btype)) && (params.transa != true || params.transb != false)) { + GTEST_SKIP() << "On gfx950, MXFP8 FP8/BF8 GEMM currently requires TN (SwizzleScale limitation)."; + } + } if (prop.major == 9 && prop.minor == 4) //gfx942 specific hipblasLt limitations { From 54de3dbd3891e0a0d0f0962fe3ccc4a9eaac759f Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 14 Jan 2026 21:44:57 +0000 Subject: [PATCH 05/51] Revert "skip on SwizzleScale limitation on gfx950" This reverts commit 86fbbac87113f00341062e4a9b150a855207acd6. 
--- tests/cpp/operator/test_cublaslt_gemm.cu | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index da59a8dee..560218575 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -427,11 +427,6 @@ void performTest(const TestParams& params) { GTEST_SKIP() << "FP8 GEMM with bias is not supported in current config"; } } - - if (use_mxfp8 && (isFp8Type(atype) || isFp8Type(btype)) && (params.transa != true || params.transb != false)) { - GTEST_SKIP() << "On gfx950, MXFP8 FP8/BF8 GEMM currently requires TN (SwizzleScale limitation)."; - } - } if (prop.major == 9 && prop.minor == 4) //gfx942 specific hipblasLt limitations { From 311ddfe66bbe738ab550b74dccaf5fb8d885438d Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 14 Jan 2026 17:21:58 -0600 Subject: [PATCH 06/51] MXFP8 fix --- tests/cpp/operator/test_cublaslt_gemm.cu | 8 +++----- tests/cpp/test_common.h | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 560218575..3d15ac3d4 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -107,11 +107,9 @@ __global__ void compute_ref_kernel( b_is_colwise ? (b_idx / 32) : (transb ? 
((kk / 32) * n + jj) : (b_idx / 32)); - const float a_byte = static_cast(a_scale_inv_mxfp8[a_scale_idx]); - const float b_byte = static_cast(b_scale_inv_mxfp8[b_scale_idx]); - - a_scale_inv_val = exp2f(a_byte - 127.0f); - b_scale_inv_val = exp2f(b_byte - 127.0f); + // scale_inv is stored as an e8m0 biased exponent; convert to 2^(127-exp) + a_scale_inv_val = exp2f_rcp(a_scale_inv_mxfp8[a_scale_idx]); + b_scale_inv_val = exp2f_rcp(b_scale_inv_mxfp8[b_scale_idx]); } const float a_val = static_cast(a_data[a_idx]); diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 8892ff097..2114feacc 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -446,7 +446,7 @@ inline fp8e8m0 float_to_e8m0(float val) { return exponent; } -inline float exp2f_rcp(fp8e8m0 biased_exp) { +__device__ __host__ __forceinline__ float exp2f_rcp(fp8e8m0 biased_exp) { return (biased_exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast(biased_exp)); } From 445e64fbce9060bfe5d0f23dedf5de209bcf353f Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 15 Jan 2026 14:14:53 -0600 Subject: [PATCH 07/51] =?UTF-8?q?correct=20scale=5Finv=20packing=20and=20e?= =?UTF-8?q?xp2(biased=E2=88=92127)=20conversion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/cpp/operator/test_cublaslt_gemm.cu | 99 ++++++++++++++++++------ tests/cpp/test_common.h | 2 +- 2 files changed, 75 insertions(+), 26 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 3d15ac3d4..376a5fc26 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -75,8 +75,10 @@ __global__ void compute_ref_kernel( bool transb, bool is_fp8_output, bool a_is_colwise, - bool b_is_colwise) + bool b_is_colwise, + bool use_mxfp8) { + const size_t k_chunks = k / 32; const size_t jj = blockIdx.x * blockDim.x + threadIdx.x; const size_t ii = blockIdx.y * blockDim.y + 
threadIdx.y; @@ -86,30 +88,33 @@ __global__ void compute_ref_kernel( if (in_range) { for (size_t kk = 0; kk < k; ++kk) { - // Indexing depends on which backing buffer we passed in - const size_t a_idx = - a_is_colwise ? (ii * k + kk) - : (transa ? (ii * k + kk) : (kk * m + ii)); - - const size_t b_idx = - b_is_colwise ? (jj * k + kk) - : (transb ? (kk * n + jj) : (jj * k + kk)); + size_t a_idx = 0; + size_t b_idx = 0; + + if (use_mxfp8) { + a_idx = transa ? (ii * k + kk) : (kk * m + ii); + b_idx = transb ? (kk * n + jj) : (jj * k + kk); + } else { + // Non-MXFP8 FP8 path may use explicit transpose buffers (cpu_rowwise_to_columnwise), + // so indexing depends on which backing buffer is passed in. + a_idx = a_is_colwise ? (ii * k + kk) + : (transa ? (ii * k + kk) : (kk * m + ii)); + + b_idx = b_is_colwise ? (jj * k + kk) + : (transb ? (kk * n + jj) : (jj * k + kk)); + } float a_scale_inv_val = a_scale_inv_scalar; float b_scale_inv_val = b_scale_inv_scalar; if (a_scale_inv_mxfp8) { - const size_t a_scale_idx = - a_is_colwise ? (a_idx / 32) - : (transa ? (a_idx / 32) : ((kk / 32) * m + ii)); + const size_t kc = kk / 32; - const size_t b_scale_idx = - b_is_colwise ? (b_idx / 32) - : (transb ? ((kk / 32) * n + jj) : (b_idx / 32)); + const size_t a_scale_idx = ii * k_chunks + kc; + const size_t b_scale_idx = jj * k_chunks + kc; - // scale_inv is stored as an e8m0 biased exponent; convert to 2^(127-exp) - a_scale_inv_val = exp2f_rcp(a_scale_inv_mxfp8[a_scale_idx]); - b_scale_inv_val = exp2f_rcp(b_scale_inv_mxfp8[b_scale_idx]); + a_scale_inv_val = exp2f(a_scale_inv_mxfp8[a_scale_idx] - 127.0f); + b_scale_inv_val = exp2f(b_scale_inv_mxfp8[b_scale_idx] - 127.0f); } const float a_val = static_cast(a_data[a_idx]); @@ -183,6 +188,8 @@ static void run_reference( { const bool use_mxfp8 = (params.scaling_mode == NVTE_MXFP8_1D_SCALING); + const size_t k_chunks = params.k / 32; + Gelu_Type* ref_gelu_host = (params.use_gelu ? 
ref_pre_gelu_out.get() : nullptr); const bool is_fp8_output = test::isFp8Type(test::TypeInfo::dtype); @@ -203,14 +210,51 @@ static void run_reference( const fp8e8m0* a_scale_dev = nullptr; const fp8e8m0* b_scale_dev = nullptr; + // If MXFP8, pack scale_inv into tight [row][kc] buffers on host, then transfer to device + std::vector a_scale_packed; + std::vector b_scale_packed; + fp8e8m0* d_a_scale_packed = nullptr; + fp8e8m0* d_b_scale_packed = nullptr; + if (use_mxfp8) { - a_scale_dev = params.transa - ? (const fp8e8m0*) A.rowwise_scale_inv_dptr() - : (const fp8e8m0*) A.columnwise_scale_inv_dptr(); + const fp8e8m0* a_scale_cpu = params.transa + ? A.rowwise_cpu_scale_inv_ptr() + : A.columnwise_cpu_scale_inv_ptr(); + const fp8e8m0* b_scale_cpu = params.transb + ? B.columnwise_cpu_scale_inv_ptr() + : B.rowwise_cpu_scale_inv_ptr(); + + // Pack into row-major [row][kc]: + // A_packed[ii, kc] and B_packed[jj, kc] + a_scale_packed.resize(params.m * k_chunks); + b_scale_packed.resize(params.n * k_chunks); + + for (size_t ii = 0; ii < params.m; ++ii) { + for (size_t kc = 0; kc < k_chunks; ++kc) { + const size_t src_idx = params.transa ? (ii * k_chunks + kc) : (kc * params.m + ii); + a_scale_packed[ii * k_chunks + kc] = a_scale_cpu[src_idx]; + } + } + + for (size_t jj = 0; jj < params.n; ++jj) { + for (size_t kc = 0; kc < k_chunks; ++kc) { + const size_t src_idx = params.transb ? 
(kc * params.n + jj) : (jj * k_chunks + kc); + b_scale_packed[jj * k_chunks + kc] = b_scale_cpu[src_idx]; + } + } + + NVTE_CHECK_CUDA(cudaMalloc(&d_a_scale_packed, a_scale_packed.size() * sizeof(fp8e8m0))); + NVTE_CHECK_CUDA(cudaMalloc(&d_b_scale_packed, b_scale_packed.size() * sizeof(fp8e8m0))); + + NVTE_CHECK_CUDA(cudaMemcpy(d_a_scale_packed, a_scale_packed.data(), + a_scale_packed.size() * sizeof(fp8e8m0), + cudaMemcpyHostToDevice)); + NVTE_CHECK_CUDA(cudaMemcpy(d_b_scale_packed, b_scale_packed.data(), + b_scale_packed.size() * sizeof(fp8e8m0), + cudaMemcpyHostToDevice)); - b_scale_dev = params.transb - ? (const fp8e8m0*) B.columnwise_scale_inv_dptr() - : (const fp8e8m0*) B.rowwise_scale_inv_dptr(); + a_scale_dev = d_a_scale_packed; + b_scale_dev = d_b_scale_packed; } else { a_scale_inv_scalar = A.rowwise_scale_inv(); b_scale_inv_scalar = B.rowwise_scale_inv(); @@ -264,7 +308,8 @@ static void run_reference( params.transb, is_fp8_output, a_use_colwise, - b_use_colwise); + b_use_colwise, + use_mxfp8); NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaDeviceSynchronize()); @@ -292,6 +337,10 @@ static void run_reference( NVTE_CHECK_CUDA(cudaFree(d_refGelu)); if (d_refAmax) NVTE_CHECK_CUDA(cudaFree(d_refAmax)); + if (d_a_scale_packed) + NVTE_CHECK_CUDA(cudaFree(d_a_scale_packed)); + if (d_b_scale_packed) + NVTE_CHECK_CUDA(cudaFree(d_b_scale_packed)); } diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 2114feacc..7596bcf06 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -267,7 +267,7 @@ class Tensor { } template - T *columnwise_cpu_scale_inv_ptr(){ + T *columnwise_cpu_scale_inv_ptr() const { if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { From 462945fc299deca92a99e783fb1f71f4ae034252 Mon Sep 17 00:00:00 2001 From: Matthias Diener 
Date: Thu, 15 Jan 2026 15:27:42 -0600 Subject: [PATCH 08/51] cleanups --- tests/cpp/operator/test_cublaslt_gemm.cu | 2 +- tests/cpp/test_common.h | 12 +----------- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 376a5fc26..21e4d4be6 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -203,7 +203,7 @@ static void run_reference( const B_Type* b_dev = static_cast( b_use_colwise ? B.columnwise_dptr() : B.rowwise_dptr()); - // scaling inputs + // scaling inputs float a_scale_inv_scalar = 1.0f; float b_scale_inv_scalar = 1.0f; diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 7596bcf06..07b4cd9bf 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -224,16 +224,6 @@ class Tensor { return reinterpret_cast(cpu_data_columnwise_.get()); } - void *rowwise_scale_inv_dptr() const { - NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); - return tensor_.get_rowwise_scale_inv().data_ptr; - } - - void *columnwise_scale_inv_dptr() const { - NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); - return tensor_.get_columnwise_scale_inv().data_ptr; - } - float amax() const { if(amax_cpu_data_) { to_cpu(); @@ -446,7 +436,7 @@ inline fp8e8m0 float_to_e8m0(float val) { return exponent; } -__device__ __host__ __forceinline__ float exp2f_rcp(fp8e8m0 biased_exp) { +inline float exp2f_rcp(fp8e8m0 biased_exp) { return (biased_exp == 0) ? 
1 : exp2f(FP32_EXPONENT_BIAS - static_cast(biased_exp)); } From e11e40034c7adc2a0845b4fd66529f9f0929669b Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 22 Jan 2026 16:30:38 -0600 Subject: [PATCH 09/51] use Tensor class for more device objects --- tests/cpp/operator/test_cublaslt_gemm.cu | 111 +++++++++-------------- tests/cpp/test_common.h | 10 ++ 2 files changed, 53 insertions(+), 68 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 21e4d4be6..33d3fd85a 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ @@ -65,6 +65,10 @@ __global__ void compute_ref_kernel( float b_scale_inv_scalar, const fp8e8m0* __restrict__ a_scale_inv_mxfp8, // used when mxfp8 == true const fp8e8m0* __restrict__ b_scale_inv_mxfp8, + size_t a_scale_ld, + size_t b_scale_ld, + bool a_scale_is_colwise, + bool b_scale_is_colwise, const Bias_Type* __restrict__ bias_data, float d_scale, size_t m, size_t k, size_t n, @@ -110,8 +114,10 @@ __global__ void compute_ref_kernel( if (a_scale_inv_mxfp8) { const size_t kc = kk / 32; - const size_t a_scale_idx = ii * k_chunks + kc; - const size_t b_scale_idx = jj * k_chunks + kc; + const size_t a_scale_idx = + a_scale_is_colwise ? (kc * a_scale_ld + ii) : (ii * a_scale_ld + kc); + const size_t b_scale_idx = + b_scale_is_colwise ? 
(kc * b_scale_ld + jj) : (jj * b_scale_ld + kc); a_scale_inv_val = exp2f(a_scale_inv_mxfp8[a_scale_idx] - 127.0f); b_scale_inv_val = exp2f(b_scale_inv_mxfp8[b_scale_idx] - 127.0f); @@ -209,52 +215,22 @@ static void run_reference( const fp8e8m0* a_scale_dev = nullptr; const fp8e8m0* b_scale_dev = nullptr; - - // If MXFP8, pack scale_inv into tight [row][kc] buffers on host, then transfer to device - std::vector a_scale_packed; - std::vector b_scale_packed; - fp8e8m0* d_a_scale_packed = nullptr; - fp8e8m0* d_b_scale_packed = nullptr; + size_t a_scale_ld = 0; + size_t b_scale_ld = 0; + bool a_scale_is_colwise = !params.transa; + bool b_scale_is_colwise = params.transb; if (use_mxfp8) { - const fp8e8m0* a_scale_cpu = params.transa - ? A.rowwise_cpu_scale_inv_ptr() - : A.columnwise_cpu_scale_inv_ptr(); - const fp8e8m0* b_scale_cpu = params.transb - ? B.columnwise_cpu_scale_inv_ptr() - : B.rowwise_cpu_scale_inv_ptr(); - - // Pack into row-major [row][kc]: - // A_packed[ii, kc] and B_packed[jj, kc] - a_scale_packed.resize(params.m * k_chunks); - b_scale_packed.resize(params.n * k_chunks); - - for (size_t ii = 0; ii < params.m; ++ii) { - for (size_t kc = 0; kc < k_chunks; ++kc) { - const size_t src_idx = params.transa ? (ii * k_chunks + kc) : (kc * params.m + ii); - a_scale_packed[ii * k_chunks + kc] = a_scale_cpu[src_idx]; - } - } - - for (size_t jj = 0; jj < params.n; ++jj) { - for (size_t kc = 0; kc < k_chunks; ++kc) { - const size_t src_idx = params.transb ? 
(kc * params.n + jj) : (jj * k_chunks + kc); - b_scale_packed[jj * k_chunks + kc] = b_scale_cpu[src_idx]; - } - } - - NVTE_CHECK_CUDA(cudaMalloc(&d_a_scale_packed, a_scale_packed.size() * sizeof(fp8e8m0))); - NVTE_CHECK_CUDA(cudaMalloc(&d_b_scale_packed, b_scale_packed.size() * sizeof(fp8e8m0))); - - NVTE_CHECK_CUDA(cudaMemcpy(d_a_scale_packed, a_scale_packed.data(), - a_scale_packed.size() * sizeof(fp8e8m0), - cudaMemcpyHostToDevice)); - NVTE_CHECK_CUDA(cudaMemcpy(d_b_scale_packed, b_scale_packed.data(), - b_scale_packed.size() * sizeof(fp8e8m0), - cudaMemcpyHostToDevice)); - - a_scale_dev = d_a_scale_packed; - b_scale_dev = d_b_scale_packed; + a_scale_dev = static_cast( + a_scale_is_colwise ? A.columnwise_scale_inv_dptr() : A.rowwise_scale_inv_dptr()); + b_scale_dev = static_cast( + b_scale_is_colwise ? B.columnwise_scale_inv_dptr() : B.rowwise_scale_inv_dptr()); + + const NVTEShape a_s = a_scale_is_colwise ? A.columnwise_scale_inv_shape() : A.rowwise_scale_inv_shape(); + const NVTEShape b_s = b_scale_is_colwise ? 
B.columnwise_scale_inv_shape() : B.rowwise_scale_inv_shape(); + NVTE_CHECK(a_s.ndim == 2 && b_s.ndim == 2, "Expected 2D MXFP8 scale_inv"); + a_scale_ld = a_s.data[1]; + b_scale_ld = b_s.data[1]; } else { a_scale_inv_scalar = A.rowwise_scale_inv(); b_scale_inv_scalar = B.rowwise_scale_inv(); @@ -266,20 +242,25 @@ static void run_reference( bias_dev = static_cast(Bias->rowwise_dptr()); } - // allocate device outputs + // allocate device outputs as test::Tensor objects const size_t lenD = params.m * params.n; const size_t bytesD = lenD * sizeof(D_Type); - D_Type* d_refD = nullptr; + Tensor RefD("RefD", TShape{params.n, params.m}, TypeInfo::dtype); + D_Type* d_refD = static_cast(RefD.rowwise_dptr()); + + Tensor RefGelu; + Tensor RefAmax; Gelu_Type* d_refGelu = nullptr; float* d_refAmax = nullptr; - NVTE_CHECK_CUDA(cudaMalloc(&d_refD, bytesD)); if (ref_gelu_host) { - NVTE_CHECK_CUDA(cudaMalloc(&d_refGelu, lenD * sizeof(Gelu_Type))); + RefGelu = Tensor("RefGelu", TShape{params.n, params.m}, TypeInfo::dtype); + d_refGelu = static_cast(RefGelu.rowwise_dptr()); } if (is_fp8_output && ref_amax_d) { - NVTE_CHECK_CUDA(cudaMalloc(&d_refAmax, sizeof(float))); + RefAmax = Tensor("RefAmax", TShape{1}, DType::kFloat32); + d_refAmax = static_cast(RefAmax.rowwise_dptr()); NVTE_CHECK_CUDA(cudaMemset(d_refAmax, 0, sizeof(float))); } @@ -298,6 +279,10 @@ static void run_reference( b_scale_inv_scalar, a_scale_dev, b_scale_dev, + a_scale_ld, + b_scale_ld, + a_scale_is_colwise, + b_scale_is_colwise, bias_dev, d_scale, params.m, params.k, params.n, @@ -315,32 +300,22 @@ static void run_reference( NVTE_CHECK_CUDA(cudaDeviceSynchronize()); // copy outputs back - NVTE_CHECK_CUDA(cudaMemcpy(ref_D.get(), d_refD, bytesD, cudaMemcpyDeviceToHost)); + RefD.to_cpu(); + memcpy(ref_D.get(), RefD.rowwise_cpu_dptr(), bytesD); if (ref_gelu_host) { - NVTE_CHECK_CUDA(cudaMemcpy(ref_gelu_host, d_refGelu, lenD * sizeof(Gelu_Type), - cudaMemcpyDeviceToHost)); + RefGelu.to_cpu(); + memcpy(ref_gelu_host, 
RefGelu.rowwise_cpu_dptr(), lenD * sizeof(Gelu_Type)); } if (ref_amax_d) { if (is_fp8_output) { - NVTE_CHECK_CUDA(cudaMemcpy(ref_amax_d, d_refAmax, sizeof(float), - cudaMemcpyDeviceToHost)); + RefAmax.to_cpu(); + *ref_amax_d = RefAmax.rowwise_cpu_dptr()[0]; } else { *ref_amax_d = 0.0f; } } - - // cleanup - NVTE_CHECK_CUDA(cudaFree(d_refD)); - if (d_refGelu) - NVTE_CHECK_CUDA(cudaFree(d_refGelu)); - if (d_refAmax) - NVTE_CHECK_CUDA(cudaFree(d_refAmax)); - if (d_a_scale_packed) - NVTE_CHECK_CUDA(cudaFree(d_a_scale_packed)); - if (d_b_scale_packed) - NVTE_CHECK_CUDA(cudaFree(d_b_scale_packed)); } diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index e181bce68..17db2a021 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -298,6 +298,16 @@ class Tensor { std::mt19937& gen() { return gen_; } + void *rowwise_scale_inv_dptr() const { + NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); + return tensor_.scale_inv(); // rowwise scale_inv backing storage + } + + void *columnwise_scale_inv_dptr() const { + NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); + return tensor_.get_columnwise_scale_inv().data_ptr; + } + private: TensorWrapper tensor_; std::unique_ptr cpu_data_rowwise_; From 325ece611769ceb0af3bb1af26d53838646871ca Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 23 Jan 2026 14:11:48 -0600 Subject: [PATCH 10/51] Pass D Tensor into run_reference and move RefD allocation into PerformTest --- tests/cpp/operator/test_cublaslt_gemm.cu | 77 ++++++++---------------- tests/cpp/test_common.h | 4 ++ 2 files changed, 28 insertions(+), 53 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 33d3fd85a..e1c963734 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -82,7 +82,6 @@ __global__ void compute_ref_kernel( bool b_is_colwise, bool use_mxfp8) { - const size_t k_chunks = k / 32; const size_t jj = 
blockIdx.x * blockDim.x + threadIdx.x; const size_t ii = blockIdx.y * blockDim.y + threadIdx.y; @@ -187,16 +186,13 @@ static void run_reference( const Tensor& A, const Tensor& B, const Tensor* Bias, // nullable - float d_scale, - std::unique_ptr& ref_D, // m*n - float* ref_amax_d, - std::unique_ptr& ref_pre_gelu_out) // nullable + const Tensor& D_for_scale, + Tensor& RefD, + Tensor* RefPreGeluOut) // nullable { const bool use_mxfp8 = (params.scaling_mode == NVTE_MXFP8_1D_SCALING); - const size_t k_chunks = params.k / 32; - - Gelu_Type* ref_gelu_host = (params.use_gelu ? ref_pre_gelu_out.get() : nullptr); + const float d_scale = D_for_scale.scale(); const bool is_fp8_output = test::isFp8Type(test::TypeInfo::dtype); @@ -242,26 +238,19 @@ static void run_reference( bias_dev = static_cast(Bias->rowwise_dptr()); } - // allocate device outputs as test::Tensor objects - const size_t lenD = params.m * params.n; - const size_t bytesD = lenD * sizeof(D_Type); - - Tensor RefD("RefD", TShape{params.n, params.m}, TypeInfo::dtype); D_Type* d_refD = static_cast(RefD.rowwise_dptr()); - Tensor RefGelu; - Tensor RefAmax; Gelu_Type* d_refGelu = nullptr; float* d_refAmax = nullptr; - if (ref_gelu_host) { - RefGelu = Tensor("RefGelu", TShape{params.n, params.m}, TypeInfo::dtype); - d_refGelu = static_cast(RefGelu.rowwise_dptr()); + if (RefPreGeluOut) { + d_refGelu = static_cast(RefPreGeluOut->rowwise_dptr()); } - if (is_fp8_output && ref_amax_d) { - RefAmax = Tensor("RefAmax", TShape{1}, DType::kFloat32); - d_refAmax = static_cast(RefAmax.rowwise_dptr()); - NVTE_CHECK_CUDA(cudaMemset(d_refAmax, 0, sizeof(float))); + + if (is_fp8_output) { + d_refAmax = static_cast(RefD.amax_dptr()); + if (d_refAmax) + NVTE_CHECK_CUDA(cudaMemset(d_refAmax, 0, sizeof(float))); } // Kernel launch @@ -297,25 +286,6 @@ static void run_reference( use_mxfp8); NVTE_CHECK_CUDA(cudaGetLastError()); - NVTE_CHECK_CUDA(cudaDeviceSynchronize()); - - // copy outputs back - RefD.to_cpu(); - memcpy(ref_D.get(), 
RefD.rowwise_cpu_dptr(), bytesD); - - if (ref_gelu_host) { - RefGelu.to_cpu(); - memcpy(ref_gelu_host, RefGelu.rowwise_cpu_dptr(), lenD * sizeof(Gelu_Type)); - } - - if (ref_amax_d) { - if (is_fp8_output) { - RefAmax.to_cpu(); - *ref_amax_d = RefAmax.rowwise_cpu_dptr()[0]; - } else { - *ref_amax_d = 0.0f; - } - } } @@ -541,23 +511,21 @@ void performTest(const TestParams& params) { } //perform the reference gemm on GPU - std::unique_ptr ref_D = std::make_unique(params.m*params.n); - std::unique_ptr ref_pre_gelu_out; - if(params.use_gelu){ - ref_pre_gelu_out = std::make_unique(params.m*params.n); - } + Tensor RefD("RefD", TShape{ params.n, params.m }, dtype); + Tensor RefPreGeluOut; - float ref_amax_d; + if (params.use_gelu) { + RefPreGeluOut = Tensor("RefPreGeluOut", TShape{ params.n, params.m }, gelu_type); + } run_reference( params, A, B, params.use_bias ? &bias : nullptr, - D.scale(), - ref_D, - &ref_amax_d, - ref_pre_gelu_out); + D, + RefD, + params.use_gelu ? &RefPreGeluOut : nullptr); // check if error message happens in running (void)cudaDeviceSynchronize(); @@ -567,15 +535,18 @@ void performTest(const TestParams& params) { //compare results auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); if (isFp8Type(dtype)) { + const float ref_amax_d = RefD.amax(); compareResults("D_amax", D.amax(), ref_amax_d, atol_amax, rtol_amax); } auto [atol, rtol] = getTestTolerances(dtype, has_fp8, use_mxfp8); - compareResults("D", D, ref_D.get(), true, atol, rtol); + RefD.to_cpu(); + compareResults("D", D, RefD.rowwise_cpu_dptr(), true, atol, rtol); if(params.use_gelu){ auto [atol, rtol] = getTestTolerances(gelu_type, false, false); - compareResults("gelu", pre_gelu_out, ref_pre_gelu_out.get(), true, atol, rtol); + RefPreGeluOut.to_cpu(); + compareResults("gelu", pre_gelu_out, RefPreGeluOut.rowwise_cpu_dptr(), true, atol, rtol); } } diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 17db2a021..b824f8d4d 100644 --- a/tests/cpp/test_common.h +++ 
b/tests/cpp/test_common.h @@ -233,6 +233,10 @@ class Tensor { } } + void *amax_dptr() const { + return tensor_.amax(); + } + float scale() const { if(scale_cpu_data_) { NVTE_CHECK(tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING, "Invalid scaling_mode!"); From fc64b8cec14026905fdaa43be96beb5ce407552c Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 26 Jan 2026 17:06:23 -0600 Subject: [PATCH 11/51] [WIP] proof-of-concept: grouped GEMM with ck_tile --- gmm2.py | 62 +++ transformer_engine/common/CMakeLists.txt | 8 + .../common/gemm/ck_grouped_gemm.cuh | 449 ++++++++++++++++++ .../common/gemm/cublaslt_gemm.cu | 38 +- 4 files changed, 555 insertions(+), 2 deletions(-) create mode 100644 gmm2.py create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm.cuh diff --git a/gmm2.py b/gmm2.py new file mode 100644 index 000000000..e5bceebbb --- /dev/null +++ b/gmm2.py @@ -0,0 +1,62 @@ +import os, torch +import transformer_engine.pytorch as te +from time import time + +torch.manual_seed(0) + +os.environ["NVTE_USE_CK_GROUPED_GEMM"] = "1" +os.environ["NVTE_CK_GROUPED_GEMM_WARN_FALLBACK"] = "1" + +device = "cuda" +dtype = torch.bfloat16 + +E = 4 +K = 1024 +N = 2048 +m_splits = [128, 64, 0, 256] +M_total = sum(m_splits) + +x = torch.randn(M_total, K, device=device, dtype=dtype) + +# TE +start = time() + +glinear = te.GroupedLinear(E, K, N, bias=False).to(device=device, dtype=dtype) +y_te = glinear(x, m_splits=m_splits) +print("TE time: ", time()-start) + + +Ws = [] +for e in range(E): + w = getattr(glinear, f"weight{e}") # expect [N, K] + Ws.append(w) +W = torch.stack(Ws, dim=0) # [E, N, K] +assert W.shape == (E, N, K), f"Unexpected weight shape: {W.shape}" + + +# Torch +start = time() + +ys = [] +offset = 0 +for e, m in enumerate(m_splits): + if m == 0: + continue + x_e = x[offset:offset+m] # [m, K] + y_e = x_e @ W[e].transpose(0, 1) # [m, N] + ys.append(y_e) + offset += m + +y_ref = torch.cat(ys, dim=0) +print("Torch time:", time()-start) + +# Compare +diff = 
(y_te.float() - y_ref.float()) +max_abs = diff.abs().max().item() +rel = (diff.abs() / (y_ref.float().abs() + 1e-6)).max().item() + +print(f"{y_te.shape=}, {y_ref.shape=}") +print("max_abs_err:", max_abs) +print("max_rel_err:", rel) + +torch.testing.assert_close(y_te.float(), y_ref.float(), rtol=3e-2, atol=3e-2) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index cefec6d06..56207f16d 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -241,6 +241,14 @@ endif() target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") +set(CK_ROOT ${CMAKE_SOURCE_DIR}/../../3rdparty/aiter/3rdparty/composable_kernel) + +target_include_directories(transformer_engine + BEFORE PRIVATE + ${CK_ROOT}/include +) + + if (USE_CUDA) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) set_source_files_properties( diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cuh b/transformer_engine/common/gemm/ck_grouped_gemm.cuh new file mode 100644 index 000000000..fa1f1cca1 --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cuh @@ -0,0 +1,449 @@ +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" + +static inline int get_num_cu_for_stream(hipStream_t stream) { + int device = -1; + hipError_t st = hipGetDevice(&device); + if (st != hipSuccess) + return 0; + + hipDeviceProp_t prop{}; + st = hipGetDeviceProperties(&prop, device); + if (st != hipSuccess) + return 0; + + return prop.multiProcessorCount; +} + +// Map TE DType to CK_Tile scalar type +template +struct TeDTypeToCk; + +template <> struct TeDTypeToCk { + using type = ck_tile::half_t; +}; +template <> struct TeDTypeToCk { + using type = ck_tile::bfloat16_t; +}; + +// TE Tensor -> SimpleTensor view +static inline const transformer_engine::SimpleTensor& data_view(const transformer_engine::Tensor& t) { + // For GEMM we want the 
"data" view (rowwise) + return t.data; +} + +// CK_Tile runner + +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; + +struct TileCfg_basic { + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool kPadM = true; + static constexpr bool kPadN = true; + static constexpr bool kPadK = true; + + static constexpr bool DoubleSmemBuffer = false; + + // Spatially-local partitioner parameters + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 1; +}; + +template +inline void launch_grouped_kernel(const ck_tile::stream_config& stream_cfg, + ck_tile::index_t group_num, + void* args_ptr, + uint32_t num_cu) { + constexpr int kBlockPerCu = 1; + const dim3 blocks = Kernel::BlockSize(); + dim3 grids = Kernel::MaxOccupancyGridSize(stream_cfg); + grids.x = std::min(grids.x, static_cast(num_cu)); + ck_tile::launch_kernel( + stream_cfg, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, + ck_tile::cast_pointer_to_constant_address_space(args_ptr), + group_num)); +} + +template +class Runner{ +public: + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, TileCfg::TileParitionerGroupNum, TileCfg::TileParitionerM01>; + + using UniversalTraits = ck_tile::PersistentTileGemmUniversalTraits< + TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, + TileCfg::DoubleSmemBuffer, ALayout, BLayout, CLayout>; + + 
static constexpr ck_tile::GemmPipelineScheduler Scheduler = + ck_tile::GemmPipelineScheduler::Intrawave; + + using Problem = ck_tile::UniversalGemmPipelineProblem< + AType, BType, AccType, GemmShape, UniversalTraits, Scheduler>; + + using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + static constexpr ck_tile::memory_operation_enum MemOp = ck_tile::memory_operation_enum::set; + + using Epilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem< + AType, BType, ck_tile::tuple<>, AccType, + CType, ck_tile::tuple<>, CLayout, + ck_tile::element_wise::PassThrough, + Partitioner::MPerBlock, Partitioner::NPerBlock, + TileCfg::M_Warp, TileCfg::N_Warp, + TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile, + Problem::TransposeC, MemOp>>; + + using Kernel = ck_tile::GroupedGemmKernel; + + void run(const ck_tile::stream_config& stream_cfg, + ck_tile::index_t group_num, + void* args_ptr, + uint32_t num_cu) { + launch_grouped_kernel(stream_cfg, group_num, args_ptr, num_cu); + } +}; + +// Arg builder kernel + +template +__global__ void build_args_kernel(ck_tile::GemmTransKernelArg<>* args, + const void* const* a_ptrs, + const void* const* b_ptrs, + void* const* d_ptrs, + const int64_t* ms, + const int64_t* ns, + const int64_t* ks, + ck_tile::index_t group_num, + ck_tile::index_t strideA, + ck_tile::index_t strideB, + ck_tile::index_t strideD, + ck_tile::index_t k_batch) { + const int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid >= group_num) + return; + + // CK_Tile's grouped arg uses arrays for As/Bs + const_cast&>(args[gid].group_karg.as_ptr)[0] = + static_cast(a_ptrs[gid]); + const_cast&>(args[gid].group_karg.bs_ptr)[0] = + static_cast(b_ptrs[gid]); + + args[gid].group_karg.e_ptr = d_ptrs[gid]; + + args[gid].group_karg.M = static_cast(ms[gid]); + args[gid].group_karg.N = static_cast(ns[gid]); + args[gid].group_karg.K = static_cast(ks[gid]); + + args[gid].group_karg.stride_As[0] = strideA; + args[gid].group_karg.stride_Bs[0] = strideB; + 
args[gid].group_karg.stride_E = strideD; + args[gid].group_karg.k_batch = k_batch; +} + +bool grouped_gemm_ck_tile(const transformer_engine::Tensor* const* A, + const transformer_engine::Tensor* const* B, + transformer_engine::Tensor* const* D, + int group_num, + bool transA, + bool transB, + void* workspace, + size_t workspace_bytes, + hipStream_t stream, + uint32_t num_cu_override = 0) { + // TE sometimes passes (A=weight, B=input, transA=1, transB=0) for y = x * W^T + // CK_Tile expects the left operand to be the activation matrix + // So for (transA && !transB), swap A/B and turn it into (!transA && transB) + const transformer_engine::Tensor* const* A_use = A; + const transformer_engine::Tensor* const* B_use = B; + bool transA_use = transA; + bool transB_use = transB; + if (transA && !transB) { + A_use = B; + B_use = A; + transA_use = false; + transB_use = true; + } + + if (!( (!transA_use && !transB_use) || (!transA_use && transB_use) )) { + NVTE_ERROR("grouped_gemm_ck_tile: only NN/NT/TN supported."); + return false; + } + + // DType routing: allow fp16/bf16 for now + const auto a_dtype = A_use[0]->dtype(); + const auto b_dtype = B_use[0]->dtype(); + const auto d_dtype = D[0]->dtype(); + if (a_dtype != b_dtype || a_dtype != d_dtype) { + NVTE_ERROR("grouped_gemm_ck_tile: dtype mismatch A/B/D."); + return false; + } + if (!(a_dtype == transformer_engine::DType::kFloat16 || + a_dtype == transformer_engine::DType::kBFloat16)) { + NVTE_ERROR("grouped_gemm_ck_tile: only fp16/bf16 supported."); + return false; + } + + // Workspace layout: + // [0] device arrays of pointers (A_ptrs, B_ptrs, D_ptrs) + // [1] device arrays of int64 (M, N, K) + // [2] ck_tile::GemmTransKernelArg<>[group_num] + const size_t ptr_arr_bytes = sizeof(void*) * static_cast(group_num); + const size_t i64_arr_bytes = sizeof(int64_t) * static_cast(group_num); + + const size_t off_a_ptrs = 0; + const size_t off_b_ptrs = off_a_ptrs + ptr_arr_bytes; + const size_t off_d_ptrs = off_b_ptrs + 
ptr_arr_bytes; + const size_t off_ms = off_d_ptrs + ptr_arr_bytes; + const size_t off_ns = off_ms + i64_arr_bytes; + const size_t off_ks = off_ns + i64_arr_bytes; + + const size_t off_args = ck_tile::integer_divide_ceil(off_ks + i64_arr_bytes, size_t(16)) * 16; + + const size_t args_bytes = sizeof(ck_tile::GemmTransKernelArg<>) * static_cast(group_num); + const size_t needed = off_args + args_bytes; + + if (workspace == nullptr || workspace_bytes < needed) { + NVTE_ERROR("grouped_gemm_ck_tile: insufficient workspace. Needed bytes=", needed); + return false; + } + + auto* base = static_cast(workspace); + + void** d_a_ptrs = reinterpret_cast(base + off_a_ptrs); + void** d_b_ptrs = reinterpret_cast(base + off_b_ptrs); + void** d_d_ptrs = reinterpret_cast(base + off_d_ptrs); + int64_t* d_ms = reinterpret_cast(base + off_ms); + int64_t* d_ns = reinterpret_cast(base + off_ns); + int64_t* d_ks = reinterpret_cast(base + off_ks); + + auto* d_args = reinterpret_cast*>(base + off_args); + + // Build host-side staging buffers and memcpy to device + std::vector h_a_ptrs(group_num); + std::vector h_b_ptrs(group_num); + std::vector h_d_ptrs(group_num); + std::vector h_ms(group_num); + std::vector h_ns(group_num); + std::vector h_ks(group_num); + + // Infer global N/K from group 0 + const auto& a0 = data_view(*A_use[0]); + const auto& b0 = data_view(*B_use[0]); + const auto& d0 = data_view(*D[0]); + if (a0.shape.size() != 2 || b0.shape.size() != 2 || d0.shape.size() != 2) { + NVTE_ERROR("grouped_gemm_ck_tile: expected 2D tensors."); + return false; + } + + printf("grouped_gemm_ck_tile gg0 A=[%zu,%zu] B=[%zu,%zu] D=[%zu,%zu] transA=%d transB=%d\n", + a0.shape[0], a0.shape[1], + b0.shape[0], b0.shape[1], + d0.shape[0], d0.shape[1], + (int)transA_use, (int)transB_use); + + // Infer logical M/K from A depending on transA + // - NN/NT: A stored [M,K] + // - TN: A stored [K,M] row-major, interpret as ColMajor [M,K] + const int64_t m0 = transA_use ? 
static_cast(a0.shape[1]) : static_cast(a0.shape[0]); + const int64_t k0 = transA_use ? static_cast(a0.shape[0]) : static_cast(a0.shape[1]); + + const int64_t n0 = transB_use ? static_cast(b0.shape[0]) + : static_cast(b0.shape[1]); + const int64_t kb = transB_use ? static_cast(b0.shape[1]) + : static_cast(b0.shape[0]); + if (kb != k0) { + NVTE_ERROR("grouped_gemm_ck_tile: K mismatch between A and B in group 0."); + return false; + } + if (static_cast(d0.shape[0]) != m0 || static_cast(d0.shape[1]) != n0) { + NVTE_ERROR("grouped_gemm_ck_tile: D shape mismatch in group 0."); + return false; + } + + for (int i = 0; i < group_num; ++i) { + const auto& ai = data_view(*A_use[i]); + const auto& bi = data_view(*B_use[i]); + const auto& di = data_view(*D[i]); + + if (ai.shape.size() != 2 || bi.shape.size() != 2 || di.shape.size() != 2) { + NVTE_ERROR("grouped_gemm_ck_tile: expected all groups to be 2D."); + return false; + } + + const int64_t mi = transA_use ? static_cast(ai.shape[1]) : static_cast(ai.shape[0]); + const int64_t ki = transA_use ? static_cast(ai.shape[0]) : static_cast(ai.shape[1]); + const int64_t ni = transB_use ? static_cast(bi.shape[0]) + : static_cast(bi.shape[1]); + const int64_t kbi = transB_use ? 
static_cast(bi.shape[1]) + : static_cast(bi.shape[0]); + + if (ki != k0 || ni != n0 || kbi != k0) { + NVTE_ERROR("grouped_gemm_ck_tile: N/K must be constant across groups."); + return false; + } + if (static_cast(di.shape[0]) != mi || static_cast(di.shape[1]) != n0) { + NVTE_ERROR("grouped_gemm_ck_tile: D shape mismatch in group ", i); + return false; + } + + h_a_ptrs[i] = ai.dptr; + h_b_ptrs[i] = bi.dptr; + h_d_ptrs[i] = di.dptr; + h_ms[i] = mi; + h_ns[i] = n0; + h_ks[i] = k0; + } + + HIP_CHECK_ERROR(hipMemcpyAsync(d_a_ptrs, h_a_ptrs.data(), ptr_arr_bytes, hipMemcpyHostToDevice, + reinterpret_cast(stream))); + HIP_CHECK_ERROR(hipMemcpyAsync(d_b_ptrs, h_b_ptrs.data(), ptr_arr_bytes, hipMemcpyHostToDevice, + reinterpret_cast(stream))); + HIP_CHECK_ERROR(hipMemcpyAsync(d_d_ptrs, h_d_ptrs.data(), ptr_arr_bytes, hipMemcpyHostToDevice, + reinterpret_cast(stream))); + HIP_CHECK_ERROR(hipMemcpyAsync(d_ms, h_ms.data(), i64_arr_bytes, hipMemcpyHostToDevice, + reinterpret_cast(stream))); + HIP_CHECK_ERROR(hipMemcpyAsync(d_ns, h_ns.data(), i64_arr_bytes, hipMemcpyHostToDevice, + reinterpret_cast(stream))); + HIP_CHECK_ERROR(hipMemcpyAsync(d_ks, h_ks.data(), i64_arr_bytes, hipMemcpyHostToDevice, + reinterpret_cast(stream))); + + // Leading dimensions for CK layouts: + // A is row-major [M,K] and we only support transA=false -> ALayout=RowMajor, strideA=K + // B is row-major [K,N] if NN -> BLayout=RowMajor, strideB=N + // B is row-major [N,K] if NT -> BLayout=ColMajor (logical [K,N]), strideB=K + const ck_tile::index_t strideA = static_cast(transA_use ? m0 : k0); + const ck_tile::index_t strideB = static_cast(transB_use ? 
k0 : n0); + const ck_tile::index_t strideD = static_cast(n0); + + // Build CK arg structs on device + { + const int threads = 256; + const int blocks = (group_num + threads - 1) / threads; + const ck_tile::index_t k_batch = 1; + if (a_dtype == transformer_engine::DType::kFloat16) { + using AType = TeDTypeToCk::type; + using BType = AType; + using CType = AType; + hipLaunchKernelGGL((build_args_kernel), + dim3(blocks), dim3(threads), 0, + reinterpret_cast(stream), + d_args, + const_cast(reinterpret_cast(d_a_ptrs)), + const_cast(reinterpret_cast(d_b_ptrs)), + reinterpret_cast(d_d_ptrs), + d_ms, d_ns, d_ks, + static_cast(group_num), + strideA, strideB, strideD, + k_batch); + } else { + using AType = TeDTypeToCk::type; + using BType = AType; + using CType = AType; + hipLaunchKernelGGL((build_args_kernel), + dim3(blocks), dim3(threads), 0, + reinterpret_cast(stream), + d_args, + const_cast(reinterpret_cast(d_a_ptrs)), + const_cast(reinterpret_cast(d_b_ptrs)), + reinterpret_cast(d_d_ptrs), + d_ms, d_ns, d_ks, + static_cast(group_num), + strideA, strideB, strideD, + k_batch); + } + } + + // Runner selection + const uint32_t num_cu = (num_cu_override != 0) ? 
num_cu_override + : static_cast(get_num_cu_for_stream(stream)); + const ck_tile::stream_config stream_cfg{reinterpret_cast(stream)}; + + // Choose layouts based on transB + if (a_dtype == transformer_engine::DType::kFloat16) { + using T = TeDTypeToCk::type; + + if (!transB_use) { + // NN: A RowMajor, B RowMajor, D RowMajor + Runner runner; + runner.run(stream_cfg, static_cast(group_num), d_args, num_cu); + } else { + // NT: B is stored as [N,K] row-major -> treat as ColMajor logical [K,N] + Runner runner; + runner.run(stream_cfg, static_cast(group_num), d_args, num_cu); + } + } else { + using T = TeDTypeToCk::type; + + if (!transB_use) { + Runner runner; + runner.run(stream_cfg, static_cast(group_num), d_args, num_cu); + } else { + Runner runner; + runner.run(stream_cfg, static_cast(group_num), d_args, num_cu); + } + } + + return true; +} + +bool grouped_gemm_ck_tile(const NVTETensor* A, + const NVTETensor* B, + NVTETensor* D, + int group_num, + bool transA, + bool transB, + NVTETensor* workspace, + hipStream_t stream) { + if (group_num <= 0) + return true; + + // Convert A/B/D arrays into TE Tensor* arrays + std::vector A_te(group_num); + std::vector B_te(group_num); + std::vector D_te(group_num); + + for (int i = 0; i < group_num; ++i) { + A_te[i] = transformer_engine::convertNVTETensorCheck(A[i]); + B_te[i] = transformer_engine::convertNVTETensorCheck(B[i]); + D_te[i] = transformer_engine::convertNVTETensorCheck(D[i]); + } + + // Workspace pointer + bytes + void* ws_ptr = nullptr; + size_t ws_bytes = 0; + if (workspace) { + auto* ws_te = transformer_engine::convertNVTETensorCheck(*workspace); + ws_ptr = ws_te->data.dptr; + ws_bytes = ws_te->data.numel() * transformer_engine::typeToSize(ws_te->data.dtype); + } + + return grouped_gemm_ck_tile(A_te.data(), B_te.data(), D_te.data(), + group_num, transA, transB, + ws_ptr, ws_bytes, + stream); +} diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 
9c2ca9b4c..e583bc14f 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -24,8 +24,11 @@ #include "../util/logging.h" #include "../util/multi_stream.h" #include "common/util/cuda_runtime.h" +#include "common/util/system.h" #ifndef __HIP_PLATFORM_AMD__ #include "cutlass_grouped_gemm.cuh" +#else +#include "ck_grouped_gemm.cuh" #endif #ifndef __HIP_PLATFORM_AMD__ @@ -788,7 +791,38 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor NVTE_API_CALL(nvte_multi_tensor_gemm); #ifdef __HIP_PLATFORM_AMD__ - multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad, + const bool use_ck = transformer_engine::getenv("NVTE_USE_CK_GROUPED_GEMM", false); + const bool warn_fallback = + transformer_engine::getenv("NVTE_CK_GROUPED_GEMM_WARN_FALLBACK", false); + + auto is_supported_dtype = [&]() -> bool { + auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]); + auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]); + auto *OutputD = transformer_engine::convertNVTETensorCheck(D[0]); + auto A_dt = inputA->data.dtype; + auto B_dt = inputB->data.dtype; + auto D_dt = OutputD->data.dtype; + + return (A_dt == B_dt) && (A_dt == D_dt) && + (A_dt == transformer_engine::DType::kFloat16 || + A_dt == transformer_engine::DType::kBFloat16); + }; + + if (use_ck && + is_supported_dtype() && + !accumulate) { + + if (grouped_gemm_ck_tile(A, B, D, num_gemms, transa, transb, workspace, stream)) { + printf("grouped_gemm_ck_tile done.\n"); + return; 
+ } else if (warn_fallback) { + NVTE_WARN("Fallback to hipBLASLt grouped GEMM (grouped_gemm_ck_tile returned false)."); + } + } + + NVTE_WARN("Fallback to hipBLASLt grouped GEMM (CK config unsupported).\n"); + + multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad, workspace, accumulate, use_split_accumulator, math_sm_count, stream); #else const int current_device = transformer_engine::cuda::current_device(); From 9091e6ce73ea47b0436edc0af6f394465c2bd1cd Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 29 Jan 2026 10:52:22 -0600 Subject: [PATCH 12/51] restructure and enable tests --- gmm2.py | 90 ++- tests/pytorch/test_numerics.py | 16 +- .../common/gemm/ck_grouped_gemm.cuh | 549 ++++++++---------- .../common/gemm/cublaslt_gemm.cu | 13 +- 4 files changed, 319 insertions(+), 349 deletions(-) diff --git a/gmm2.py b/gmm2.py index e5bceebbb..016304d9f 100644 --- a/gmm2.py +++ b/gmm2.py @@ -1,13 +1,14 @@ -import os, torch +import os +import time +import torch import transformer_engine.pytorch as te -from time import time torch.manual_seed(0) os.environ["NVTE_USE_CK_GROUPED_GEMM"] = "1" os.environ["NVTE_CK_GROUPED_GEMM_WARN_FALLBACK"] = "1" -device = "cuda" +device = "cuda" dtype = torch.bfloat16 E = 4 @@ -18,45 +19,74 @@ x = torch.randn(M_total, K, device=device, dtype=dtype) -# TE -start = time() - +# Timing helper +def bench_cuda(fn, warmup=20, iters=100, name=""): + # Warmup + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + # Timed + start = time.time() + for _ in range(iters): + fn() + torch.cuda.synchronize() + end = time.time() + + avg_ms = (end - start) * 1000.0 / iters + if name: + print(f"{name}: {avg_ms:.3f} ms (avg over {iters} runs, {warmup} warmup)") + return avg_ms + +# TE GroupedLinear glinear = te.GroupedLinear(E, K, N, bias=False).to(device=device, dtype=dtype) -y_te = glinear(x, m_splits=m_splits) -print("TE time: ", time()-start) +def te_run(): + return glinear(x, m_splits=m_splits) + +te_ms = 
bench_cuda(te_run, warmup=20, iters=100, name="TE GroupedLinear") -Ws = [] -for e in range(E): - w = getattr(glinear, f"weight{e}") # expect [N, K] - Ws.append(w) -W = torch.stack(Ws, dim=0) # [E, N, K] +# Grab weights for reference path +Ws = [getattr(glinear, f"weight{e}") for e in range(E)] # each [N, K] +W = torch.stack(Ws, dim=0) # [E, N, K] assert W.shape == (E, N, K), f"Unexpected weight shape: {W.shape}" +# Torch reference (group loop) +offsets = [] +off = 0 +for m in m_splits: + offsets.append(off) + off += m -# Torch -start = time() +y_ref_buf = torch.empty((M_total, N), device=device, dtype=dtype) -ys = [] -offset = 0 -for e, m in enumerate(m_splits): - if m == 0: - continue - x_e = x[offset:offset+m] # [m, K] - y_e = x_e @ W[e].transpose(0, 1) # [m, N] - ys.append(y_e) - offset += m +def torch_run(): + # Fill the preallocated buffer + for e, m in enumerate(m_splits): + if m == 0: + continue + o = offsets[e] + y_ref_buf[o:o+m].copy_(x[o:o+m] @ W[e].transpose(0, 1)) + return y_ref_buf -y_ref = torch.cat(ys, dim=0) -print("Torch time:", time()-start) +torch_ms = bench_cuda(torch_run, warmup=20, iters=100, name="Torch loop (prealloc out)") + +# Compare outputs +y_te = te_run() +y_ref = torch_run().clone() -# Compare diff = (y_te.float() - y_ref.float()) max_abs = diff.abs().max().item() rel = (diff.abs() / (y_ref.float().abs() + 1e-6)).max().item() -print(f"{y_te.shape=}, {y_ref.shape=}") -print("max_abs_err:", max_abs) -print("max_rel_err:", rel) +print(f"\nErrors:") +print(f" {y_te.shape=}, {y_ref.shape=}") +print(" max_abs_err:", max_abs) +print(" max_rel_err:", rel) torch.testing.assert_close(y_te.float(), y_ref.float(), rtol=3e-2, atol=3e-2) + +print(f"\nTiming:") +print(f" TE avg: {te_ms:.3f} ms") +print(f" Torch avg: {torch_ms:.3f} ms") +print(f" Speedup: {torch_ms/te_ms:.2f}x (Torch / TE)") diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index a4dfd64ba..5f1489f88 100644 --- a/tests/pytorch/test_numerics.py +++ 
b/tests/pytorch/test_numerics.py @@ -1385,7 +1385,7 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_ if IS_HIP_EXTENSION: if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias: - pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.") + pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.") te_linear_ref = Linear( config.hidden_size, @@ -1677,7 +1677,7 @@ def test_layernorm_linear_accuracy_delay_wgrad_compute( ): if IS_HIP_EXTENSION: if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias: - pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.") + pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.") config = model_configs[model] ln_linear_ref = LayerNormLinear( @@ -1891,7 +1891,7 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute( if IS_HIP_EXTENSION: if dtype not in (torch.float32,) and fuse_wgrad_accumulation and bias: - pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.") + pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.") ln_mlp = LayerNormMLP( hidden_size=config.hidden_size, @@ -2036,7 +2036,7 @@ def test_grouped_linear_accuracy( if IS_HIP_EXTENSION: if dtype not in (torch.float32,) and fuse_wgrad_accumulation and not fp8: - pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.") + pytest.skip(f"ROCm does not support fused wgrad accumulation for {dtype}.") if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") @@ -2115,7 +2115,7 @@ def test_grouped_linear_accuracy( @pytest.mark.skipif( - torch.cuda.get_device_capability() != (9, 0), + torch.cuda.get_device_capability() != (9, 0) and not IS_HIP_EXTENSION, reason="Only enable CUTLASS grouped gemm on Hopper", ) @pytest.mark.parametrize("dtype", param_types, ids=str) @@ -2133,6 +2133,9 @@ def 
test_grouped_linear_accuracy_cutlass( delay_wgrad_compute, ): os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" + if IS_HIP_EXTENSION: + os.environ["NVTE_USE_CK_GROUPED_GEMM"] = "1" + os.environ["NVTE_CK_GROUPED_GEMM_WARN_FALLBACK"] = "1" test_grouped_linear_accuracy( dtype, num_gemms, @@ -2147,6 +2150,9 @@ def test_grouped_linear_accuracy_cutlass( use_cutlass=True, ) os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) + if IS_HIP_EXTENSION: + os.environ.pop("NVTE_USE_CK_GROUPED_GEMM", None) + os.environ.pop("NVTE_CK_GROUPED_GEMM_WARN_FALLBACK", None) @pytest.mark.parametrize("dtype", param_types, ids=str) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cuh b/transformer_engine/common/gemm/ck_grouped_gemm.cuh index fa1f1cca1..1171e33f9 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm.cuh +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cuh @@ -1,45 +1,23 @@ +/* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. */ + #include #include "ck_tile/core.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" -static inline int get_num_cu_for_stream(hipStream_t stream) { - int device = -1; - hipError_t st = hipGetDevice(&device); - if (st != hipSuccess) - return 0; - - hipDeviceProp_t prop{}; - st = hipGetDeviceProperties(&prop, device); - if (st != hipSuccess) - return 0; - - return prop.multiProcessorCount; -} +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; -// Map TE DType to CK_Tile scalar type template struct TeDTypeToCk; +template <> struct TeDTypeToCk { using type = ck_tile::half_t; }; +template <> struct TeDTypeToCk{ using type = ck_tile::bfloat16_t; }; -template <> struct TeDTypeToCk { - using type = ck_tile::half_t; -}; -template <> struct TeDTypeToCk { - using type = ck_tile::bfloat16_t; -}; - -// TE Tensor -> SimpleTensor view static inline const transformer_engine::SimpleTensor& data_view(const transformer_engine::Tensor& t) 
{ - // For GEMM we want the "data" view (rowwise) - return t.data; + return t.data; // rowwise data view } -// CK_Tile runner - -using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; -using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; - struct TileCfg_basic { static constexpr ck_tile::index_t M_Tile = 256; static constexpr ck_tile::index_t N_Tile = 128; @@ -59,31 +37,14 @@ struct TileCfg_basic { static constexpr bool DoubleSmemBuffer = false; - // Spatially-local partitioner parameters - static constexpr ck_tile::index_t TileParitionerGroupNum = 8; - static constexpr ck_tile::index_t TileParitionerM01 = 1; + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 1; }; -template -inline void launch_grouped_kernel(const ck_tile::stream_config& stream_cfg, - ck_tile::index_t group_num, - void* args_ptr, - uint32_t num_cu) { - constexpr int kBlockPerCu = 1; - const dim3 blocks = Kernel::BlockSize(); - dim3 grids = Kernel::MaxOccupancyGridSize(stream_cfg); - grids.x = std::min(grids.x, static_cast(num_cu)); - ck_tile::launch_kernel( - stream_cfg, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, - ck_tile::cast_pointer_to_constant_address_space(args_ptr), - group_num)); -} - template + typename TileCfg, ck_tile::memory_operation_enum MemOp, + typename AccType = float> class Runner{ public: using GemmShape = ck_tile::TileGemmShape< @@ -92,7 +53,7 @@ public: ck_tile::sequence>; using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< - GemmShape, TileCfg::TileParitionerGroupNum, TileCfg::TileParitionerM01>; + GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>; using UniversalTraits = ck_tile::PersistentTileGemmUniversalTraits< TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, @@ -106,8 +67,6 @@ public: using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; - static constexpr ck_tile::memory_operation_enum MemOp = ck_tile::memory_operation_enum::set; - using Epilogue = 
ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem< AType, BType, ck_tile::tuple<>, AccType, @@ -119,296 +78,272 @@ public: Problem::TransposeC, MemOp>>; using Kernel = ck_tile::GroupedGemmKernel; - - void run(const ck_tile::stream_config& stream_cfg, - ck_tile::index_t group_num, - void* args_ptr, - uint32_t num_cu) { - launch_grouped_kernel(stream_cfg, group_num, args_ptr, num_cu); - } }; -// Arg builder kernel - -template -__global__ void build_args_kernel(ck_tile::GemmTransKernelArg<>* args, - const void* const* a_ptrs, - const void* const* b_ptrs, - void* const* d_ptrs, - const int64_t* ms, - const int64_t* ns, - const int64_t* ks, - ck_tile::index_t group_num, - ck_tile::index_t strideA, - ck_tile::index_t strideB, - ck_tile::index_t strideD, - ck_tile::index_t k_batch) { - const int gid = blockIdx.x * blockDim.x + threadIdx.x; - if (gid >= group_num) - return; - - // CK_Tile's grouped arg uses arrays for As/Bs - const_cast&>(args[gid].group_karg.as_ptr)[0] = - static_cast(a_ptrs[gid]); - const_cast&>(args[gid].group_karg.bs_ptr)[0] = - static_cast(b_ptrs[gid]); - - args[gid].group_karg.e_ptr = d_ptrs[gid]; - - args[gid].group_karg.M = static_cast(ms[gid]); - args[gid].group_karg.N = static_cast(ns[gid]); - args[gid].group_karg.K = static_cast(ks[gid]); - - args[gid].group_karg.stride_As[0] = strideA; - args[gid].group_karg.stride_Bs[0] = strideB; - args[gid].group_karg.stride_E = strideD; - args[gid].group_karg.k_batch = k_batch; -} +template +static inline void launch_tileloop_kernel(const ck_tile::stream_config& s, + ck_tile::index_t group_num, + void* kargs_dev) +{ + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); -bool grouped_gemm_ck_tile(const transformer_engine::Tensor* const* A, - const transformer_engine::Tensor* const* B, - transformer_engine::Tensor* const* D, - int group_num, - bool transA, - bool transB, - void* workspace, - size_t workspace_bytes, - hipStream_t stream, - uint32_t 
num_cu_override = 0) { - // TE sometimes passes (A=weight, B=input, transA=1, transB=0) for y = x * W^T - // CK_Tile expects the left operand to be the activation matrix - // So for (transA && !transB), swap A/B and turn it into (!transA && transB) - const transformer_engine::Tensor* const* A_use = A; - const transformer_engine::Tensor* const* B_use = B; - bool transA_use = transA; - bool transB_use = transB; - if (transA && !transB) { - A_use = B; - B_use = A; - transA_use = false; - transB_use = true; - } + ck_tile::launch_kernel( + s, + ck_tile::make_kernel<1>( + Kernel{}, grids, blocks, 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_dev), + group_num)); +} - if (!( (!transA_use && !transB_use) || (!transA_use && transB_use) )) { - NVTE_ERROR("grouped_gemm_ck_tile: only NN/NT/TN supported."); +template +static bool run_grouped_impl(const transformer_engine::Tensor* const* A_use, + const transformer_engine::Tensor* const* B_use, + transformer_engine::Tensor* const* D, + int group_num, + bool transA_use, + bool transB_use, + void* workspace, + size_t workspace_bytes, + hipStream_t stream) +{ + using R = Runner; + using Kernel = typename R::Kernel; + + const size_t needed = Kernel::GetWorkSpaceSize(group_num); + if (!workspace || workspace_bytes < needed) { + NVTE_ERROR("grouped_gemm_ck_tile: insufficient workspace. 
Needed bytes=", needed); return false; } - // DType routing: allow fp16/bf16 for now - const auto a_dtype = A_use[0]->dtype(); - const auto b_dtype = B_use[0]->dtype(); - const auto d_dtype = D[0]->dtype(); - if (a_dtype != b_dtype || a_dtype != d_dtype) { - NVTE_ERROR("grouped_gemm_ck_tile: dtype mismatch A/B/D."); - return false; - } - if (!(a_dtype == transformer_engine::DType::kFloat16 || - a_dtype == transformer_engine::DType::kBFloat16)) { - NVTE_ERROR("grouped_gemm_ck_tile: only fp16/bf16 supported."); - return false; - } + std::vector> descs; + descs.reserve(group_num); + + for (int i = 0; i < group_num; ++i) { + const auto& a = data_view(*A_use[i]); + const auto& b = data_view(*B_use[i]); + const auto& d = data_view(*D[i]); + + if (a.shape.size() != 2 || b.shape.size() != 2 || d.shape.size() != 2) { + NVTE_ERROR("grouped_gemm_ck_tile: expected all groups to be 2D."); + return false; + } - // Workspace layout: - // [0] device arrays of pointers (A_ptrs, B_ptrs, D_ptrs) - // [1] device arrays of int64 (M, N, K) - // [2] ck_tile::GemmTransKernelArg<>[group_num] - const size_t ptr_arr_bytes = sizeof(void*) * static_cast(group_num); - const size_t i64_arr_bytes = sizeof(int64_t) * static_cast(group_num); + const int64_t Ad0 = a.shape[0]; + const int64_t Ad1 = a.shape[1]; + const int64_t Bd0 = b.shape[0]; + const int64_t Bd1 = b.shape[1]; - const size_t off_a_ptrs = 0; - const size_t off_b_ptrs = off_a_ptrs + ptr_arr_bytes; - const size_t off_d_ptrs = off_b_ptrs + ptr_arr_bytes; - const size_t off_ms = off_d_ptrs + ptr_arr_bytes; - const size_t off_ns = off_ms + i64_arr_bytes; - const size_t off_ks = off_ns + i64_arr_bytes; + const int64_t M = transA_use ? Ad1 : Ad0; + const int64_t K = transA_use ? Ad0 : Ad1; + const int64_t N = transB_use ? Bd0 : Bd1; + const int64_t Kb = transB_use ? 
Bd1 : Bd0; - const size_t off_args = ck_tile::integer_divide_ceil(off_ks + i64_arr_bytes, size_t(16)) * 16; + if (Kb != K) { + NVTE_ERROR("grouped_gemm_ck_tile: K mismatch between A and B in group ", i); + return false; + } - const size_t args_bytes = sizeof(ck_tile::GemmTransKernelArg<>) * static_cast(group_num); - const size_t needed = off_args + args_bytes; + if (d.shape[0] != M || d.shape[1] != N) { + NVTE_ERROR("grouped_gemm_ck_tile: D shape mismatch in group ", i); + return false; + } - if (workspace == nullptr || workspace_bytes < needed) { - NVTE_ERROR("grouped_gemm_ck_tile: insufficient workspace. Needed bytes=", needed); + const ck_tile::index_t stride_A = a.shape[1]; + const ck_tile::index_t stride_B = b.shape[1]; + const ck_tile::index_t stride_E = d.shape[1]; + + descs.emplace_back( + a.dptr, + b.dptr, + std::array{}, + d.dptr, + 1, + M, + N, + K, + stride_A, + stride_B, + std::array{}, + stride_E); + } + + auto kargs = Kernel::MakeKargs(descs); + if (!Kernel::IsSupportedArgument(kargs)) { + NVTE_ERROR("grouped_gemm_ck_tile: CK-Tile kernel arguments not supported for this config."); return false; } - auto* base = static_cast(workspace); + HIP_CHECK_ERROR(hipMemcpyAsync(workspace, + kargs.data(), + kargs.size() * sizeof(typename decltype(kargs)::value_type), + hipMemcpyHostToDevice, + stream)); - void** d_a_ptrs = reinterpret_cast(base + off_a_ptrs); - void** d_b_ptrs = reinterpret_cast(base + off_b_ptrs); - void** d_d_ptrs = reinterpret_cast(base + off_d_ptrs); - int64_t* d_ms = reinterpret_cast(base + off_ms); - int64_t* d_ns = reinterpret_cast(base + off_ns); - int64_t* d_ks = reinterpret_cast(base + off_ks); + const ck_tile::stream_config s{stream}; + launch_tileloop_kernel(s, group_num, workspace); + return true; +} - auto* d_args = reinterpret_cast*>(base + off_args); +static inline bool infer_gemm_mode_group0(const transformer_engine::Tensor* const* A, + const transformer_engine::Tensor* const* B, + transformer_engine::Tensor* const* D, + int 
group_num, + const transformer_engine::Tensor* const*& A_use, + const transformer_engine::Tensor* const*& B_use, + bool& transA_use, + bool& transB_use) +{ + A_use = A; + B_use = B; + transA_use = false; + transB_use = false; - // Build host-side staging buffers and memcpy to device - std::vector h_a_ptrs(group_num); - std::vector h_b_ptrs(group_num); - std::vector h_d_ptrs(group_num); - std::vector h_ms(group_num); - std::vector h_ns(group_num); - std::vector h_ks(group_num); + if (group_num <= 0) + return true; - // Infer global N/K from group 0 - const auto& a0 = data_view(*A_use[0]); - const auto& b0 = data_view(*B_use[0]); + const auto& a0 = data_view(*A[0]); + const auto& b0 = data_view(*B[0]); const auto& d0 = data_view(*D[0]); - if (a0.shape.size() != 2 || b0.shape.size() != 2 || d0.shape.size() != 2) { - NVTE_ERROR("grouped_gemm_ck_tile: expected 2D tensors."); - return false; - } - printf("grouped_gemm_ck_tile gg0 A=[%zu,%zu] B=[%zu,%zu] D=[%zu,%zu] transA=%d transB=%d\n", - a0.shape[0], a0.shape[1], - b0.shape[0], b0.shape[1], - d0.shape[0], d0.shape[1], - (int)transA_use, (int)transB_use); - - // Infer logical M/K from A depending on transA - // - NN/NT: A stored [M,K] - // - TN: A stored [K,M] row-major, interpret as ColMajor [M,K] - const int64_t m0 = transA_use ? static_cast(a0.shape[1]) : static_cast(a0.shape[0]); - const int64_t k0 = transA_use ? static_cast(a0.shape[0]) : static_cast(a0.shape[1]); - - const int64_t n0 = transB_use ? static_cast(b0.shape[0]) - : static_cast(b0.shape[1]); - const int64_t kb = transB_use ? 
static_cast(b0.shape[1]) - : static_cast(b0.shape[0]); - if (kb != k0) { - NVTE_ERROR("grouped_gemm_ck_tile: K mismatch between A and B in group 0."); - return false; - } - if (static_cast(d0.shape[0]) != m0 || static_cast(d0.shape[1]) != n0) { - NVTE_ERROR("grouped_gemm_ck_tile: D shape mismatch in group 0."); + if (a0.shape.size() != 2 || b0.shape.size() != 2 || d0.shape.size() != 2) { return false; } - for (int i = 0; i < group_num; ++i) { - const auto& ai = data_view(*A_use[i]); - const auto& bi = data_view(*B_use[i]); - const auto& di = data_view(*D[i]); - - if (ai.shape.size() != 2 || bi.shape.size() != 2 || di.shape.size() != 2) { - NVTE_ERROR("grouped_gemm_ck_tile: expected all groups to be 2D."); - return false; + const int64_t Ad0 = a0.shape[0]; + const int64_t Ad1 = a0.shape[1]; + const int64_t Bd0 = b0.shape[0]; + const int64_t Bd1 = b0.shape[1]; + const int64_t Dm = d0.shape[0]; + const int64_t Dn = d0.shape[1]; + + auto check = [&](bool do_swap, bool ta, bool tb) -> bool { + const int64_t A0d0 = do_swap ? Bd0 : Ad0; + const int64_t A0d1 = do_swap ? Bd1 : Ad1; + const int64_t B0d0 = do_swap ? Ad0 : Bd0; + const int64_t B0d1 = do_swap ? Ad1 : Bd1; + + const int64_t M = ta ? A0d1 : A0d0; + const int64_t K = ta ? A0d0 : A0d1; + const int64_t N = tb ? B0d0 : B0d1; + const int64_t Kb = tb ? B0d1 : B0d0; + + return (M == Dm) && (N == Dn) && (K == Kb); + }; + + // Try all candidates; prefer "no swap" first, then swap. + for (bool do_swap : {false, true}) { + for (bool ta : {false, true}) { + for (bool tb : {false, true}) { + if (check(do_swap, ta, tb)) { + A_use = do_swap ? B : A; + B_use = do_swap ? A : B; + transA_use = ta; + transB_use = tb; + return true; + } + } } + } - const int64_t mi = transA_use ? static_cast(ai.shape[1]) : static_cast(ai.shape[0]); - const int64_t ki = transA_use ? static_cast(ai.shape[0]) : static_cast(ai.shape[1]); - const int64_t ni = transB_use ? 
static_cast(bi.shape[0]) - : static_cast(bi.shape[1]); - const int64_t kbi = transB_use ? static_cast(bi.shape[1]) - : static_cast(bi.shape[0]); + // Nothing matched D = op(A) * op(B) + return false; +} - if (ki != k0 || ni != n0 || kbi != k0) { - NVTE_ERROR("grouped_gemm_ck_tile: N/K must be constant across groups."); - return false; - } - if (static_cast(di.shape[0]) != mi || static_cast(di.shape[1]) != n0) { - NVTE_ERROR("grouped_gemm_ck_tile: D shape mismatch in group ", i); - return false; - } +bool grouped_gemm_ck_tile(const transformer_engine::Tensor* const* A, + const transformer_engine::Tensor* const* B, + transformer_engine::Tensor* const* D, + int group_num, + bool transA, + bool transB, + void* workspace, + size_t workspace_bytes, + bool accumulate, + hipStream_t stream) +{ + const transformer_engine::Tensor* const* A_use = A; + const transformer_engine::Tensor* const* B_use = B; + bool transA_use = transA; + bool transB_use = transB; - h_a_ptrs[i] = ai.dptr; - h_b_ptrs[i] = bi.dptr; - h_d_ptrs[i] = di.dptr; - h_ms[i] = mi; - h_ns[i] = n0; - h_ks[i] = k0; + // If TE's flags disagree with storage, infer the correct mode from shapes. + if (!infer_gemm_mode_group0(A, B, D, group_num, A_use, B_use, transA_use, transB_use)) { + const auto& a0 = data_view(*A[0]); + const auto& b0 = data_view(*B[0]); + const auto& d0 = data_view(*D[0]); + NVTE_ERROR("grouped_gemm_ck_tile: could not infer a consistent GEMM mode from shapes. 
", + "A0=[", a0.shape[0], ",", a0.shape[1], "] ", + "B0=[", b0.shape[0], ",", b0.shape[1], "] ", + "D0=[", d0.shape[0], ",", d0.shape[1], "] ", + "given flags transA=", transA, " transB=", transB); + return false; } - HIP_CHECK_ERROR(hipMemcpyAsync(d_a_ptrs, h_a_ptrs.data(), ptr_arr_bytes, hipMemcpyHostToDevice, - reinterpret_cast(stream))); - HIP_CHECK_ERROR(hipMemcpyAsync(d_b_ptrs, h_b_ptrs.data(), ptr_arr_bytes, hipMemcpyHostToDevice, - reinterpret_cast(stream))); - HIP_CHECK_ERROR(hipMemcpyAsync(d_d_ptrs, h_d_ptrs.data(), ptr_arr_bytes, hipMemcpyHostToDevice, - reinterpret_cast(stream))); - HIP_CHECK_ERROR(hipMemcpyAsync(d_ms, h_ms.data(), i64_arr_bytes, hipMemcpyHostToDevice, - reinterpret_cast(stream))); - HIP_CHECK_ERROR(hipMemcpyAsync(d_ns, h_ns.data(), i64_arr_bytes, hipMemcpyHostToDevice, - reinterpret_cast(stream))); - HIP_CHECK_ERROR(hipMemcpyAsync(d_ks, h_ks.data(), i64_arr_bytes, hipMemcpyHostToDevice, - reinterpret_cast(stream))); - - // Leading dimensions for CK layouts: - // A is row-major [M,K] and we only support transA=false -> ALayout=RowMajor, strideA=K - // B is row-major [K,N] if NN -> BLayout=RowMajor, strideB=N - // B is row-major [N,K] if NT -> BLayout=ColMajor (logical [K,N]), strideB=K - const ck_tile::index_t strideA = static_cast(transA_use ? m0 : k0); - const ck_tile::index_t strideB = static_cast(transB_use ? 
k0 : n0); - const ck_tile::index_t strideD = static_cast(n0); - - // Build CK arg structs on device - { - const int threads = 256; - const int blocks = (group_num + threads - 1) / threads; - const ck_tile::index_t k_batch = 1; - if (a_dtype == transformer_engine::DType::kFloat16) { - using AType = TeDTypeToCk::type; - using BType = AType; - using CType = AType; - hipLaunchKernelGGL((build_args_kernel), - dim3(blocks), dim3(threads), 0, - reinterpret_cast(stream), - d_args, - const_cast(reinterpret_cast(d_a_ptrs)), - const_cast(reinterpret_cast(d_b_ptrs)), - reinterpret_cast(d_d_ptrs), - d_ms, d_ns, d_ks, - static_cast(group_num), - strideA, strideB, strideD, - k_batch); - } else { - using AType = TeDTypeToCk::type; - using BType = AType; - using CType = AType; - hipLaunchKernelGGL((build_args_kernel), - dim3(blocks), dim3(threads), 0, - reinterpret_cast(stream), - d_args, - const_cast(reinterpret_cast(d_a_ptrs)), - const_cast(reinterpret_cast(d_b_ptrs)), - reinterpret_cast(d_d_ptrs), - d_ms, d_ns, d_ks, - static_cast(group_num), - strideA, strideB, strideD, - k_batch); - } - } + const auto a_dtype = A_use[0]->dtype(); - // Runner selection - const uint32_t num_cu = (num_cu_override != 0) ? num_cu_override - : static_cast(get_num_cu_for_stream(stream)); - const ck_tile::stream_config stream_cfg{reinterpret_cast(stream)}; + const auto memop = accumulate ? ck_tile::memory_operation_enum::atomic_add + : ck_tile::memory_operation_enum::set; - // Choose layouts based on transB if (a_dtype == transformer_engine::DType::kFloat16) { using T = TeDTypeToCk::type; - if (!transB_use) { - // NN: A RowMajor, B RowMajor, D RowMajor - Runner runner; - runner.run(stream_cfg, static_cast(group_num), d_args, num_cu); - } else { - // NT: B is stored as [N,K] row-major -> treat as ColMajor logical [K,N] - Runner runner; - runner.run(stream_cfg, static_cast(group_num), d_args, num_cu); - } + if (!transA_use && !transB_use) + return (memop == ck_tile::memory_operation_enum::set) + ? 
run_grouped_impl( + A_use, B_use, D, group_num, false, false, workspace, workspace_bytes, stream) + : run_grouped_impl( + A_use, B_use, D, group_num, false, false, workspace, workspace_bytes, stream); + + if (!transA_use && transB_use) + return (memop == ck_tile::memory_operation_enum::set) + ? run_grouped_impl( + A_use, B_use, D, group_num, false, true, workspace, workspace_bytes, stream) + : run_grouped_impl( + A_use, B_use, D, group_num, false, true, workspace, workspace_bytes, stream); + + if (transA_use && !transB_use) + return (memop == ck_tile::memory_operation_enum::set) + ? run_grouped_impl( + A_use, B_use, D, group_num, true, false, workspace, workspace_bytes, stream) + : run_grouped_impl( + A_use, B_use, D, group_num, true, false, workspace, workspace_bytes, stream); + + return (memop == ck_tile::memory_operation_enum::set) + ? run_grouped_impl( + A_use, B_use, D, group_num, true, true, workspace, workspace_bytes, stream) + : run_grouped_impl( + A_use, B_use, D, group_num, true, true, workspace, workspace_bytes, stream); } else { using T = TeDTypeToCk::type; - if (!transB_use) { - Runner runner; - runner.run(stream_cfg, static_cast(group_num), d_args, num_cu); - } else { - Runner runner; - runner.run(stream_cfg, static_cast(group_num), d_args, num_cu); - } + if (!transA_use && !transB_use) + return (memop == ck_tile::memory_operation_enum::set) + ? run_grouped_impl( + A_use, B_use, D, group_num, false, false, workspace, workspace_bytes, stream) + : run_grouped_impl( + A_use, B_use, D, group_num, false, false, workspace, workspace_bytes, stream); + + if (!transA_use && transB_use) + return (memop == ck_tile::memory_operation_enum::set) + ? run_grouped_impl( + A_use, B_use, D, group_num, false, true, workspace, workspace_bytes, stream) + : run_grouped_impl( + A_use, B_use, D, group_num, false, true, workspace, workspace_bytes, stream); + + if (transA_use && !transB_use) + return (memop == ck_tile::memory_operation_enum::set) + ? 
run_grouped_impl( + A_use, B_use, D, group_num, true, false, workspace, workspace_bytes, stream) + : run_grouped_impl( + A_use, B_use, D, group_num, true, false, workspace, workspace_bytes, stream); + + return (memop == ck_tile::memory_operation_enum::set) + ? run_grouped_impl( + A_use, B_use, D, group_num, true, true, workspace, workspace_bytes, stream) + : run_grouped_impl( + A_use, B_use, D, group_num, true, true, workspace, workspace_bytes, stream); } - - return true; } bool grouped_gemm_ck_tile(const NVTETensor* A, @@ -418,7 +353,9 @@ bool grouped_gemm_ck_tile(const NVTETensor* A, bool transA, bool transB, NVTETensor* workspace, - hipStream_t stream) { + bool accumulate, + hipStream_t stream) +{ if (group_num <= 0) return true; @@ -444,6 +381,6 @@ bool grouped_gemm_ck_tile(const NVTETensor* A, return grouped_gemm_ck_tile(A_te.data(), B_te.data(), D_te.data(), group_num, transA, transB, - ws_ptr, ws_bytes, + ws_ptr, ws_bytes, accumulate, stream); } diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index e583bc14f..fcbdac91c 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -808,20 +808,17 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor A_dt == transformer_engine::DType::kBFloat16); }; - if (use_ck && - is_supported_dtype() && - !accumulate) { - - if (grouped_gemm_ck_tile(A, B, D, num_gemms, transa, transb, workspace, stream)) { - printf("grouped_gemm_ck_tile done.\n"); + if (use_ck && is_supported_dtype()) { + if (grouped_gemm_ck_tile(A, B, D, num_gemms, transa, transb, workspace, accumulate, stream)) { + // NVTE_WARN("grouped_gemm_ck_tile done.\n"); return; } else if (warn_fallback) { NVTE_WARN("Fallback to hipBLASLt grouped GEMM (grouped_gemm_ck_tile returned false)."); } + } else if (warn_fallback) { + NVTE_WARN("Fallback to hipBLASLt grouped GEMM (CK config unsupported or CK disabled). 
use_ck=", use_ck, " is_supported_dtype=", is_supported_dtype()); } - NVTE_WARN("Fallback to hipBLASLt grouped GEMM (CK config unsupported).\n"); - multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad, workspace, accumulate, use_split_accumulator, math_sm_count, stream); #else From 4e9ead9a5a8de6266a44e873ef01d6bf6a147e61 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 30 Jan 2026 14:47:30 -0600 Subject: [PATCH 13/51] grid improvements --- gmm2.py | 10 ++++------ transformer_engine/common/gemm/ck_grouped_gemm.cuh | 5 +++-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/gmm2.py b/gmm2.py index 016304d9f..8966afa02 100644 --- a/gmm2.py +++ b/gmm2.py @@ -20,7 +20,7 @@ x = torch.randn(M_total, K, device=device, dtype=dtype) # Timing helper -def bench_cuda(fn, warmup=20, iters=100, name=""): +def bench_cuda(fn, warmup=20, iters=100): # Warmup for _ in range(warmup): fn() @@ -34,8 +34,6 @@ def bench_cuda(fn, warmup=20, iters=100, name=""): end = time.time() avg_ms = (end - start) * 1000.0 / iters - if name: - print(f"{name}: {avg_ms:.3f} ms (avg over {iters} runs, {warmup} warmup)") return avg_ms # TE GroupedLinear @@ -44,7 +42,7 @@ def bench_cuda(fn, warmup=20, iters=100, name=""): def te_run(): return glinear(x, m_splits=m_splits) -te_ms = bench_cuda(te_run, warmup=20, iters=100, name="TE GroupedLinear") +te_ms = bench_cuda(te_run, warmup=20, iters=100) # Grab weights for reference path Ws = [getattr(glinear, f"weight{e}") for e in range(E)] # each [N, K] @@ -69,7 +67,7 @@ def torch_run(): y_ref_buf[o:o+m].copy_(x[o:o+m] @ W[e].transpose(0, 1)) return y_ref_buf -torch_ms = bench_cuda(torch_run, warmup=20, iters=100, name="Torch loop (prealloc out)") +torch_ms = bench_cuda(torch_run, warmup=20, iters=100) # Compare outputs y_te = te_run() @@ -79,7 +77,7 @@ def torch_run(): max_abs = diff.abs().max().item() rel = (diff.abs() / (y_ref.float().abs() + 1e-6)).max().item() -print(f"\nErrors:") +print(f"Errors:") 
print(f" {y_te.shape=}, {y_ref.shape=}") print(" max_abs_err:", max_abs) print(" max_rel_err:", rel) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cuh b/transformer_engine/common/gemm/ck_grouped_gemm.cuh index 1171e33f9..2ae402c47 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm.cuh +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cuh @@ -82,11 +82,11 @@ public: template static inline void launch_tileloop_kernel(const ck_tile::stream_config& s, + dim3 grids, ck_tile::index_t group_num, void* kargs_dev) { const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); ck_tile::launch_kernel( s, @@ -169,6 +169,7 @@ static bool run_grouped_impl(const transformer_engine::Tensor* const* A_use, stride_E); } + const dim3 grids = Kernel::GridSize(descs); auto kargs = Kernel::MakeKargs(descs); if (!Kernel::IsSupportedArgument(kargs)) { NVTE_ERROR("grouped_gemm_ck_tile: CK-Tile kernel arguments not supported for this config."); @@ -182,7 +183,7 @@ static bool run_grouped_impl(const transformer_engine::Tensor* const* A_use, stream)); const ck_tile::stream_config s{stream}; - launch_tileloop_kernel(s, group_num, workspace); + launch_tileloop_kernel(s, grids, group_num, workspace); return true; } From 259645cdbee80c336acaa1b47420a4db2873c47a Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 3 Feb 2026 16:32:07 -0600 Subject: [PATCH 14/51] restructure --- gmm2.py | 4 +- tests/pytorch/test_numerics.py | 6 - transformer_engine/common/CMakeLists.txt | 12 +- ...k_grouped_gemm.cuh => ck_grouped_gemm.cpp} | 127 +++++++++--------- .../common/gemm/ck_grouped_gemm.h | 11 ++ .../common/gemm/cublaslt_gemm.cu | 6 +- 6 files changed, 80 insertions(+), 86 deletions(-) rename transformer_engine/common/gemm/{ck_grouped_gemm.cuh => ck_grouped_gemm.cpp} (83%) create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm.h diff --git a/gmm2.py b/gmm2.py index 8966afa02..938d69247 100644 --- a/gmm2.py +++ b/gmm2.py @@ -5,8 +5,8 
@@ torch.manual_seed(0) -os.environ["NVTE_USE_CK_GROUPED_GEMM"] = "1" -os.environ["NVTE_CK_GROUPED_GEMM_WARN_FALLBACK"] = "1" +os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" +os.environ["NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK"] = "1" device = "cuda" dtype = torch.bfloat16 diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 5f1489f88..8fb65f7f6 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2133,9 +2133,6 @@ def test_grouped_linear_accuracy_cutlass( delay_wgrad_compute, ): os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" - if IS_HIP_EXTENSION: - os.environ["NVTE_USE_CK_GROUPED_GEMM"] = "1" - os.environ["NVTE_CK_GROUPED_GEMM_WARN_FALLBACK"] = "1" test_grouped_linear_accuracy( dtype, num_gemms, @@ -2150,9 +2147,6 @@ def test_grouped_linear_accuracy_cutlass( use_cutlass=True, ) os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) - if IS_HIP_EXTENSION: - os.environ.pop("NVTE_USE_CK_GROUPED_GEMM", None) - os.environ.pop("NVTE_CK_GROUPED_GEMM_WARN_FALLBACK", None) @pytest.mark.parametrize("dtype", param_types, ids=str) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 56207f16d..4a04a630f 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -203,6 +203,7 @@ else() fused_attn_rocm/fused_attn_ck.cpp fused_attn_rocm/utils.cpp gemm/rocm_gemm.cu + gemm/ck_grouped_gemm.cpp amd_detail/system.cpp) # process source code files @@ -241,14 +242,6 @@ endif() target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") -set(CK_ROOT ${CMAKE_SOURCE_DIR}/../../3rdparty/aiter/3rdparty/composable_kernel) - -target_include_directories(transformer_engine - BEFORE PRIVATE - ${CK_ROOT}/include -) - - if (USE_CUDA) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) set_source_files_properties( @@ -259,6 +252,9 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) else() 
message(FATAL_ERROR "cutlass gemm/cutlass_grouped_gemm.cu kernel required sm 90a") endif() +else() + set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/aiter/3rdparty/composable_kernel) + target_include_directories(transformer_engine PRIVATE ${CK_ROOT}/include) endif() #USE_CUDA # Configure dependencies diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cuh b/transformer_engine/common/gemm/ck_grouped_gemm.cpp similarity index 83% rename from transformer_engine/common/gemm/ck_grouped_gemm.cuh rename to transformer_engine/common/gemm/ck_grouped_gemm.cpp index 2ae402c47..995812574 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm.cuh +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cpp @@ -2,6 +2,9 @@ #include +#include +#include "../common.h" + #include "ck_tile/core.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" @@ -233,7 +236,7 @@ static inline bool infer_gemm_mode_group0(const transformer_engine::Tensor* cons return (M == Dm) && (N == Dn) && (K == Kb); }; - // Try all candidates; prefer "no swap" first, then swap. 
+ // Try all candidates; prefer "no swap" first, then swap for (bool do_swap : {false, true}) { for (bool ta : {false, true}) { for (bool tb : {false, true}) { @@ -252,27 +255,51 @@ static inline bool infer_gemm_mode_group0(const transformer_engine::Tensor* cons return false; } -bool grouped_gemm_ck_tile(const transformer_engine::Tensor* const* A, - const transformer_engine::Tensor* const* B, - transformer_engine::Tensor* const* D, +bool grouped_gemm_ck_tile(const NVTETensor* A, + const NVTETensor* B, + NVTETensor* D, int group_num, bool transA, bool transB, - void* workspace, - size_t workspace_bytes, + NVTETensor* workspace, bool accumulate, hipStream_t stream) { - const transformer_engine::Tensor* const* A_use = A; - const transformer_engine::Tensor* const* B_use = B; + if (group_num <= 0) + return true; + + // Convert A/B/D arrays into TE Tensor* arrays + std::vector A_te(group_num); + std::vector B_te(group_num); + std::vector D_te(group_num); + + for (int i = 0; i < group_num; ++i) { + A_te[i] = transformer_engine::convertNVTETensorCheck(A[i]); + B_te[i] = transformer_engine::convertNVTETensorCheck(B[i]); + D_te[i] = transformer_engine::convertNVTETensorCheck(D[i]); + } + + // Workspace pointer + bytes + void* ws_ptr = nullptr; + size_t ws_bytes = 0; + if (workspace) { + auto* ws_te = transformer_engine::convertNVTETensorCheck(*workspace); + ws_ptr = ws_te->data.dptr; + ws_bytes = ws_te->data.numel() * + transformer_engine::typeToSize(ws_te->data.dtype); + } + + const transformer_engine::Tensor* const* A_use = A_te.data(); + const transformer_engine::Tensor* const* B_use = B_te.data(); bool transA_use = transA; bool transB_use = transB; // If TE's flags disagree with storage, infer the correct mode from shapes. 
- if (!infer_gemm_mode_group0(A, B, D, group_num, A_use, B_use, transA_use, transB_use)) { - const auto& a0 = data_view(*A[0]); - const auto& b0 = data_view(*B[0]); - const auto& d0 = data_view(*D[0]); + if (!infer_gemm_mode_group0(A_te.data(), B_te.data(), D_te.data(), + group_num, A_use, B_use, transA_use, transB_use)) { + const auto& a0 = data_view(*A_te[0]); + const auto& b0 = data_view(*B_te[0]); + const auto& d0 = data_view(*D_te[0]); NVTE_ERROR("grouped_gemm_ck_tile: could not infer a consistent GEMM mode from shapes. ", "A0=[", a0.shape[0], ",", a0.shape[1], "] ", "B0=[", b0.shape[0], ",", b0.shape[1], "] ", @@ -292,96 +319,62 @@ bool grouped_gemm_ck_tile(const transformer_engine::Tensor* const* A, if (!transA_use && !transB_use) return (memop == ck_tile::memory_operation_enum::set) ? run_grouped_impl( - A_use, B_use, D, group_num, false, false, workspace, workspace_bytes, stream) + A_use, B_use, D_te.data(), group_num, false, false, ws_ptr, ws_bytes, stream) : run_grouped_impl( - A_use, B_use, D, group_num, false, false, workspace, workspace_bytes, stream); + A_use, B_use, D_te.data(), group_num, false, false, ws_ptr, ws_bytes, stream); if (!transA_use && transB_use) return (memop == ck_tile::memory_operation_enum::set) ? run_grouped_impl( - A_use, B_use, D, group_num, false, true, workspace, workspace_bytes, stream) + A_use, B_use, D_te.data(), group_num, false, true, ws_ptr, ws_bytes, stream) : run_grouped_impl( - A_use, B_use, D, group_num, false, true, workspace, workspace_bytes, stream); + A_use, B_use, D_te.data(), group_num, false, true, ws_ptr, ws_bytes, stream); if (transA_use && !transB_use) return (memop == ck_tile::memory_operation_enum::set) ? 
run_grouped_impl( - A_use, B_use, D, group_num, true, false, workspace, workspace_bytes, stream) + A_use, B_use, D_te.data(), group_num, true, false, ws_ptr, ws_bytes, stream) : run_grouped_impl( - A_use, B_use, D, group_num, true, false, workspace, workspace_bytes, stream); + A_use, B_use, D_te.data(), group_num, true, false, ws_ptr, ws_bytes, stream); return (memop == ck_tile::memory_operation_enum::set) ? run_grouped_impl( - A_use, B_use, D, group_num, true, true, workspace, workspace_bytes, stream) + A_use, B_use, D_te.data(), group_num, true, true, ws_ptr, ws_bytes, stream) : run_grouped_impl( - A_use, B_use, D, group_num, true, true, workspace, workspace_bytes, stream); - } else { + A_use, B_use, D_te.data(), group_num, true, true, ws_ptr, ws_bytes, stream); + + } else if (a_dtype == transformer_engine::DType::kBFloat16) { using T = TeDTypeToCk::type; if (!transA_use && !transB_use) return (memop == ck_tile::memory_operation_enum::set) ? run_grouped_impl( - A_use, B_use, D, group_num, false, false, workspace, workspace_bytes, stream) + A_use, B_use, D_te.data(), group_num, false, false, ws_ptr, ws_bytes, stream) : run_grouped_impl( - A_use, B_use, D, group_num, false, false, workspace, workspace_bytes, stream); + A_use, B_use, D_te.data(), group_num, false, false, ws_ptr, ws_bytes, stream); if (!transA_use && transB_use) return (memop == ck_tile::memory_operation_enum::set) ? run_grouped_impl( - A_use, B_use, D, group_num, false, true, workspace, workspace_bytes, stream) + A_use, B_use, D_te.data(), group_num, false, true, ws_ptr, ws_bytes, stream) : run_grouped_impl( - A_use, B_use, D, group_num, false, true, workspace, workspace_bytes, stream); + A_use, B_use, D_te.data(), group_num, false, true, ws_ptr, ws_bytes, stream); if (transA_use && !transB_use) return (memop == ck_tile::memory_operation_enum::set) ? 
run_grouped_impl( - A_use, B_use, D, group_num, true, false, workspace, workspace_bytes, stream) + A_use, B_use, D_te.data(), group_num, true, false, ws_ptr, ws_bytes, stream) : run_grouped_impl( - A_use, B_use, D, group_num, true, false, workspace, workspace_bytes, stream); + A_use, B_use, D_te.data(), group_num, true, false, ws_ptr, ws_bytes, stream); return (memop == ck_tile::memory_operation_enum::set) ? run_grouped_impl( - A_use, B_use, D, group_num, true, true, workspace, workspace_bytes, stream) + A_use, B_use, D_te.data(), group_num, true, true, ws_ptr, ws_bytes, stream) : run_grouped_impl( - A_use, B_use, D, group_num, true, true, workspace, workspace_bytes, stream); - } -} - -bool grouped_gemm_ck_tile(const NVTETensor* A, - const NVTETensor* B, - NVTETensor* D, - int group_num, - bool transA, - bool transB, - NVTETensor* workspace, - bool accumulate, - hipStream_t stream) -{ - if (group_num <= 0) - return true; - - // Convert A/B/D arrays into TE Tensor* arrays - std::vector A_te(group_num); - std::vector B_te(group_num); - std::vector D_te(group_num); - - for (int i = 0; i < group_num; ++i) { - A_te[i] = transformer_engine::convertNVTETensorCheck(A[i]); - B_te[i] = transformer_engine::convertNVTETensorCheck(B[i]); - D_te[i] = transformer_engine::convertNVTETensorCheck(D[i]); - } + A_use, B_use, D_te.data(), group_num, true, true, ws_ptr, ws_bytes, stream); - // Workspace pointer + bytes - void* ws_ptr = nullptr; - size_t ws_bytes = 0; - if (workspace) { - auto* ws_te = transformer_engine::convertNVTETensorCheck(*workspace); - ws_ptr = ws_te->data.dptr; - ws_bytes = ws_te->data.numel() * transformer_engine::typeToSize(ws_te->data.dtype); + } else { + NVTE_ERROR("grouped_gemm_ck_tile: unsupported dtype (expected FP16/BF16)."); + return false; } - - return grouped_gemm_ck_tile(A_te.data(), B_te.data(), D_te.data(), - group_num, transA, transB, - ws_ptr, ws_bytes, accumulate, - stream); } diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.h 
b/transformer_engine/common/gemm/ck_grouped_gemm.h new file mode 100644 index 000000000..16755942e --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm.h @@ -0,0 +1,11 @@ +/* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. */ + +bool grouped_gemm_ck_tile(const NVTETensor* A, + const NVTETensor* B, + NVTETensor* D, + int group_num, + bool transA, + bool transB, + NVTETensor* workspace, + bool accumulate, + hipStream_t stream); diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index fcbdac91c..26a786215 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -28,7 +28,7 @@ #ifndef __HIP_PLATFORM_AMD__ #include "cutlass_grouped_gemm.cuh" #else -#include "ck_grouped_gemm.cuh" +#include "ck_grouped_gemm.h" #endif #ifndef __HIP_PLATFORM_AMD__ @@ -791,9 +791,9 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor NVTE_API_CALL(nvte_multi_tensor_gemm); #ifdef __HIP_PLATFORM_AMD__ - const bool use_ck = transformer_engine::getenv("NVTE_USE_CK_GROUPED_GEMM", false); + const bool use_ck = transformer_engine::getenv("NVTE_USE_CUTLASS_GROUPED_GEMM", false); const bool warn_fallback = - transformer_engine::getenv("NVTE_CK_GROUPED_GEMM_WARN_FALLBACK", false); + transformer_engine::getenv("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", false); auto is_supported_dtype = [&]() -> bool { auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]); From 9986bd4a684a2ab6a7e61e398323ed22e8070d6a Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 4 Feb 2026 15:41:40 -0600 Subject: [PATCH 15/51] reduce code duplication & simplify --- .../common/gemm/ck_grouped_gemm.cpp | 141 +++++++----------- 1 file changed, 50 insertions(+), 91 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm.cpp index 995812574..3bebdc117 100644 --- 
a/transformer_engine/common/gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cpp @@ -12,10 +12,10 @@ using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; -template -struct TeDTypeToCk; -template <> struct TeDTypeToCk { using type = ck_tile::half_t; }; -template <> struct TeDTypeToCk{ using type = ck_tile::bfloat16_t; }; +template struct TeTypeToCkType; +template <> struct TeTypeToCkType { using type = ck_tile::half_t; }; +template <> struct TeTypeToCkType { using type = ck_tile::bfloat16_t; }; + static inline const transformer_engine::SimpleTensor& data_view(const transformer_engine::Tensor& t) { return t.data; // rowwise data view @@ -48,8 +48,7 @@ template -class Runner{ -public: +struct Runner{ using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, ck_tile::sequence, @@ -83,22 +82,6 @@ class Runner{ using Kernel = ck_tile::GroupedGemmKernel; }; -template -static inline void launch_tileloop_kernel(const ck_tile::stream_config& s, - dim3 grids, - ck_tile::index_t group_num, - void* kargs_dev) -{ - const dim3 blocks = Kernel::BlockSize(); - - ck_tile::launch_kernel( - s, - ck_tile::make_kernel<1>( - Kernel{}, grids, blocks, 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_dev), - group_num)); -} - template static bool run_grouped_impl(const transformer_engine::Tensor* const* A_use, @@ -186,7 +169,14 @@ static bool run_grouped_impl(const transformer_engine::Tensor* const* A_use, stream)); const ck_tile::stream_config s{stream}; - launch_tileloop_kernel(s, grids, group_num, workspace); + const dim3 blocks = Kernel::BlockSize(); + + ck_tile::launch_kernel( + s, + ck_tile::make_kernel<1>( + Kernel{}, grids, blocks, 0, + ck_tile::cast_pointer_to_constant_address_space(workspace), + group_num)); return true; } @@ -255,6 +245,30 @@ static inline bool infer_gemm_mode_group0(const transformer_engine::Tensor* cons return false; } +template +static inline bool 
dispatch_grouped(bool transA_use, + bool transB_use, + const transformer_engine::Tensor* const* A_use, + const transformer_engine::Tensor* const* B_use, + transformer_engine::Tensor* const* D, + int group_num, + void* workspace, + size_t workspace_bytes, + hipStream_t stream) { + +// FIXME: This could be a templated lambda function in C++20. +#define CALL(ALayout_, BLayout_, ta_, tb_) \ + return run_grouped_impl( \ + A_use, B_use, D, group_num, (ta_), (tb_), workspace, workspace_bytes, stream) + + if (!transA_use && !transB_use) { CALL(RowMajor, RowMajor, false, false); } + if (!transA_use && transB_use) { CALL(RowMajor, ColMajor, false, true ); } + if ( transA_use && !transB_use) { CALL(ColMajor, RowMajor, true, false); } + /* transA_use && transB_use */ { CALL(ColMajor, ColMajor, true, true ); } + +#undef CALL +} + bool grouped_gemm_ck_tile(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, @@ -268,7 +282,7 @@ bool grouped_gemm_ck_tile(const NVTETensor* A, if (group_num <= 0) return true; - // Convert A/B/D arrays into TE Tensor* arrays + // Convert A/B/D arrays into TE Tensor arrays std::vector A_te(group_num); std::vector B_te(group_num); std::vector D_te(group_num); @@ -310,71 +324,16 @@ bool grouped_gemm_ck_tile(const NVTETensor* A, const auto a_dtype = A_use[0]->dtype(); - const auto memop = accumulate ? ck_tile::memory_operation_enum::atomic_add - : ck_tile::memory_operation_enum::set; - - if (a_dtype == transformer_engine::DType::kFloat16) { - using T = TeDTypeToCk::type; - - if (!transA_use && !transB_use) - return (memop == ck_tile::memory_operation_enum::set) - ? run_grouped_impl( - A_use, B_use, D_te.data(), group_num, false, false, ws_ptr, ws_bytes, stream) - : run_grouped_impl( - A_use, B_use, D_te.data(), group_num, false, false, ws_ptr, ws_bytes, stream); - - if (!transA_use && transB_use) - return (memop == ck_tile::memory_operation_enum::set) - ? 
run_grouped_impl( - A_use, B_use, D_te.data(), group_num, false, true, ws_ptr, ws_bytes, stream) - : run_grouped_impl( - A_use, B_use, D_te.data(), group_num, false, true, ws_ptr, ws_bytes, stream); - - if (transA_use && !transB_use) - return (memop == ck_tile::memory_operation_enum::set) - ? run_grouped_impl( - A_use, B_use, D_te.data(), group_num, true, false, ws_ptr, ws_bytes, stream) - : run_grouped_impl( - A_use, B_use, D_te.data(), group_num, true, false, ws_ptr, ws_bytes, stream); - - return (memop == ck_tile::memory_operation_enum::set) - ? run_grouped_impl( - A_use, B_use, D_te.data(), group_num, true, true, ws_ptr, ws_bytes, stream) - : run_grouped_impl( - A_use, B_use, D_te.data(), group_num, true, true, ws_ptr, ws_bytes, stream); - - } else if (a_dtype == transformer_engine::DType::kBFloat16) { - using T = TeDTypeToCk::type; - - if (!transA_use && !transB_use) - return (memop == ck_tile::memory_operation_enum::set) - ? run_grouped_impl( - A_use, B_use, D_te.data(), group_num, false, false, ws_ptr, ws_bytes, stream) - : run_grouped_impl( - A_use, B_use, D_te.data(), group_num, false, false, ws_ptr, ws_bytes, stream); - - if (!transA_use && transB_use) - return (memop == ck_tile::memory_operation_enum::set) - ? run_grouped_impl( - A_use, B_use, D_te.data(), group_num, false, true, ws_ptr, ws_bytes, stream) - : run_grouped_impl( - A_use, B_use, D_te.data(), group_num, false, true, ws_ptr, ws_bytes, stream); - - if (transA_use && !transB_use) - return (memop == ck_tile::memory_operation_enum::set) - ? run_grouped_impl( - A_use, B_use, D_te.data(), group_num, true, false, ws_ptr, ws_bytes, stream) - : run_grouped_impl( - A_use, B_use, D_te.data(), group_num, true, false, ws_ptr, ws_bytes, stream); - - return (memop == ck_tile::memory_operation_enum::set) - ? 
run_grouped_impl( - A_use, B_use, D_te.data(), group_num, true, true, ws_ptr, ws_bytes, stream) - : run_grouped_impl( - A_use, B_use, D_te.data(), group_num, true, true, ws_ptr, ws_bytes, stream); - - } else { - NVTE_ERROR("grouped_gemm_ck_tile: unsupported dtype (expected FP16/BF16)."); - return false; - } + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(a_dtype, te_type, { + using T = typename TeTypeToCkType::type; + + if (accumulate) + return dispatch_grouped(transA_use, transB_use, + A_use, B_use, D_te.data(), group_num, + ws_ptr, ws_bytes, stream); + else + return dispatch_grouped(transA_use, transB_use, + A_use, B_use, D_te.data(), group_num, + ws_ptr, ws_bytes, stream); + }); } From 355ec2f156025dfdf479ce5ff84a8c3b6829a4f9 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 4 Feb 2026 15:50:57 -0600 Subject: [PATCH 16/51] make the code more similar to nv, check emopty gelu/bias --- .../common/gemm/ck_grouped_gemm.cpp | 2 +- .../common/gemm/ck_grouped_gemm.h | 2 +- .../common/gemm/cublaslt_gemm.cu | 45 ++++++++++++++----- 3 files changed, 35 insertions(+), 14 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm.cpp index 3bebdc117..f4f865057 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cpp @@ -269,7 +269,7 @@ static inline bool dispatch_grouped(bool transA_use, #undef CALL } -bool grouped_gemm_ck_tile(const NVTETensor* A, +bool ck_tile_grouped_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, int group_num, diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.h b/transformer_engine/common/gemm/ck_grouped_gemm.h index 16755942e..d539b47f7 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm.h @@ -1,6 +1,6 @@ /* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. 
*/ -bool grouped_gemm_ck_tile(const NVTETensor* A, +bool ck_tile_grouped_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, int group_num, diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 26a786215..f367499e9 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -791,10 +791,28 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor NVTE_API_CALL(nvte_multi_tensor_gemm); #ifdef __HIP_PLATFORM_AMD__ - const bool use_ck = transformer_engine::getenv("NVTE_USE_CUTLASS_GROUPED_GEMM", false); + const bool use_cutlass = transformer_engine::getenv("NVTE_USE_CUTLASS_GROUPED_GEMM", false); const bool warn_fallback = transformer_engine::getenv("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", false); + auto cublas_path = [&]() { + multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad, + workspace, accumulate, use_split_accumulator, math_sm_count, stream); + }; + + if (!use_cutlass) { + cublas_path(); + return; + } + + auto is_empty_arr = [&](const NVTETensor *p) -> bool { + if (p == nullptr) return true; + for (int i = 0; i < num_gemms; ++i) { + if (transformer_engine::convertNVTETensor(p[i])->has_data()) return false; + } + return true; + }; + auto is_supported_dtype = [&]() -> bool { auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]); auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]); @@ -808,19 +826,22 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor A_dt == transformer_engine::DType::kBFloat16); }; - if (use_ck && is_supported_dtype()) { - if (grouped_gemm_ck_tile(A, B, D, num_gemms, transa, transb, workspace, accumulate, stream)) { - // NVTE_WARN("grouped_gemm_ck_tile done.\n"); - return; - } else if (warn_fallback) { - NVTE_WARN("Fallback to hipBLASLt grouped GEMM (grouped_gemm_ck_tile returned false)."); + // 
CK_Tile Grouped GEMM fast path + // Conditions: + // - No fused epilogue: both bias and pre_gelu_out are empty. + // - Supported dtypes only: FP16/BF16 (FP32 accumulate). + // - use_split_accumulator is ignored for FP16/BF16. + // - grad is irrelevant when bias/pre_gelu_out are empty. + // + // Otherwise, fall back to cuBLAS. + if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_supported_dtype()) { + ck_tile_grouped_gemm(A, B, D, num_gemms, transa, transb, workspace, accumulate, stream); + } else { + if (warn_fallback) { + NVTE_WARN("Fallback to cuBLAS grouped GEMM."); } - } else if (warn_fallback) { - NVTE_WARN("Fallback to hipBLASLt grouped GEMM (CK config unsupported or CK disabled). use_ck=", use_ck, " is_supported_dtype=", is_supported_dtype()); + cublas_path(); } - - multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad, - workspace, accumulate, use_split_accumulator, math_sm_count, stream); #else const int current_device = transformer_engine::cuda::current_device(); const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90); From a42f7ca9245d6ffc40e3408c239bbc4593f3d2a8 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 4 Feb 2026 16:18:57 -0600 Subject: [PATCH 17/51] further simplify & make closer to nv --- .../common/gemm/cublaslt_gemm.cu | 72 +++++-------------- 1 file changed, 18 insertions(+), 54 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index f367499e9..881ccb2a1 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -790,59 +790,6 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor cudaStream_t stream) { NVTE_API_CALL(nvte_multi_tensor_gemm); -#ifdef __HIP_PLATFORM_AMD__ - const bool use_cutlass = transformer_engine::getenv("NVTE_USE_CUTLASS_GROUPED_GEMM", false); - const bool warn_fallback = - 
transformer_engine::getenv("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", false); - - auto cublas_path = [&]() { - multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad, - workspace, accumulate, use_split_accumulator, math_sm_count, stream); - }; - - if (!use_cutlass) { - cublas_path(); - return; - } - - auto is_empty_arr = [&](const NVTETensor *p) -> bool { - if (p == nullptr) return true; - for (int i = 0; i < num_gemms; ++i) { - if (transformer_engine::convertNVTETensor(p[i])->has_data()) return false; - } - return true; - }; - - auto is_supported_dtype = [&]() -> bool { - auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]); - auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]); - auto *OutputD = transformer_engine::convertNVTETensorCheck(D[0]); - auto A_dt = inputA->data.dtype; - auto B_dt = inputB->data.dtype; - auto D_dt = OutputD->data.dtype; - - return (A_dt == B_dt) && (A_dt == D_dt) && - (A_dt == transformer_engine::DType::kFloat16 || - A_dt == transformer_engine::DType::kBFloat16); - }; - - // CK_Tile Grouped GEMM fast path - // Conditions: - // - No fused epilogue: both bias and pre_gelu_out are empty. - // - Supported dtypes only: FP16/BF16 (FP32 accumulate). - // - use_split_accumulator is ignored for FP16/BF16. - // - grad is irrelevant when bias/pre_gelu_out are empty. - // - // Otherwise, fall back to cuBLAS. 
- if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_supported_dtype()) { - ck_tile_grouped_gemm(A, B, D, num_gemms, transa, transb, workspace, accumulate, stream); - } else { - if (warn_fallback) { - NVTE_WARN("Fallback to cuBLAS grouped GEMM."); - } - cublas_path(); - } -#else const int current_device = transformer_engine::cuda::current_device(); const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90); const bool use_cutlass = transformer_engine::getenv("NVTE_USE_CUTLASS_GROUPED_GEMM", false); @@ -855,7 +802,11 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor }; // Currently only support cutlass group gemm on Hopper Arch +#ifdef __HIP_PLATFORM_AMD__ + if (!use_cutlass) { +#else if (!(is_hopper && use_cutlass)) { +#endif cublas_path(); return; } @@ -889,12 +840,21 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]); auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]); auto *OutputD = transformer_engine::convertNVTETensorCheck(D[0]); +#ifdef __HIP_PLATFORM_AMD__ + auto A_dt = inputA->data.dtype; + auto B_dt = inputB->data.dtype; + auto D_dt = OutputD->data.dtype; + return (A_dt == B_dt) && (A_dt == D_dt) && + (A_dt == transformer_engine::DType::kFloat16 || + A_dt == transformer_engine::DType::kBFloat16); +#else auto A_type = get_cuda_dtype(inputA->data.dtype); auto B_type = get_cuda_dtype(inputB->data.dtype); auto D_type = get_cuda_dtype(OutputD->data.dtype); return (A_type == B_type) && (A_type == D_type) && ((A_type == CUDA_R_16BF) || (A_type == CUDA_R_16F)); +#endif }; // CUTLASS Grouped GEMM fast path (SM90/TMA) @@ -907,14 +867,18 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor // // Otherwise, fall back to cuBLAS. 
if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_supported_dtype() && +#ifdef __HIP_PLATFORM_AMD__ + true) { + ck_tile_grouped_gemm(A, B, D, num_gemms, transa, transb, workspace, accumulate, stream); +#else all_groups_uniform_k128(B, transb)) { cutlass_grouped_gemm(A, B, D, num_gemms, transa, transb, grad, workspace, accumulate, current_device, math_sm_count, stream); +#endif } else { if (warn_fallback) { NVTE_WARN("Fallback to cuBLAS grouped GEMM."); } cublas_path(); } -#endif // __HIP_PLATFORM_AMD__ } From fac7c111f93a9f1645279bf51490bba485b74dcb Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 4 Feb 2026 17:07:06 -0600 Subject: [PATCH 18/51] add ck_tile reference --- transformer_engine/common/gemm/ck_grouped_gemm.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm.cpp index f4f865057..2d4445923 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cpp @@ -44,6 +44,8 @@ struct TileCfg_basic { static constexpr ck_tile::index_t TilePartitionerM01 = 1; }; +// This class instantiates CK_Tile's grouped GEMM pipeline. +// See e.g. https://github.com/ROCm/composable_kernel/blob/develop/example/ck_tile/03_gemm/universal_gemm_invoker.hpp for reference. 
template ; - using Kernel = typename R::Kernel; + using Kernel = typename Runner::Kernel; const size_t needed = Kernel::GetWorkSpaceSize(group_num); if (!workspace || workspace_bytes < needed) { @@ -158,7 +159,7 @@ static bool run_grouped_impl(const transformer_engine::Tensor* const* A_use, const dim3 grids = Kernel::GridSize(descs); auto kargs = Kernel::MakeKargs(descs); if (!Kernel::IsSupportedArgument(kargs)) { - NVTE_ERROR("grouped_gemm_ck_tile: CK-Tile kernel arguments not supported for this config."); + NVTE_ERROR("grouped_gemm_ck_tile: CK_Tile kernel arguments not supported for this config."); return false; } From 71b97e05acb9f4467d7780e52e70f297fd63e320 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 4 Feb 2026 17:10:59 -0600 Subject: [PATCH 19/51] rename in error messages --- transformer_engine/common/gemm/ck_grouped_gemm.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm.cpp index 2d4445923..1c5617cfb 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cpp @@ -100,7 +100,7 @@ static bool run_grouped_impl(const transformer_engine::Tensor* const* A_use, const size_t needed = Kernel::GetWorkSpaceSize(group_num); if (!workspace || workspace_bytes < needed) { - NVTE_ERROR("grouped_gemm_ck_tile: insufficient workspace. Needed bytes=", needed); + NVTE_ERROR("ck_tile_grouped_gemm: insufficient workspace. 
Needed bytes=", needed); return false; } @@ -113,7 +113,7 @@ static bool run_grouped_impl(const transformer_engine::Tensor* const* A_use, const auto& d = data_view(*D[i]); if (a.shape.size() != 2 || b.shape.size() != 2 || d.shape.size() != 2) { - NVTE_ERROR("grouped_gemm_ck_tile: expected all groups to be 2D."); + NVTE_ERROR("ck_tile_grouped_gemm: expected all groups to be 2D."); return false; } @@ -128,12 +128,12 @@ static bool run_grouped_impl(const transformer_engine::Tensor* const* A_use, const int64_t Kb = transB_use ? Bd1 : Bd0; if (Kb != K) { - NVTE_ERROR("grouped_gemm_ck_tile: K mismatch between A and B in group ", i); + NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i); return false; } if (d.shape[0] != M || d.shape[1] != N) { - NVTE_ERROR("grouped_gemm_ck_tile: D shape mismatch in group ", i); + NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i); return false; } @@ -159,7 +159,7 @@ static bool run_grouped_impl(const transformer_engine::Tensor* const* A_use, const dim3 grids = Kernel::GridSize(descs); auto kargs = Kernel::MakeKargs(descs); if (!Kernel::IsSupportedArgument(kargs)) { - NVTE_ERROR("grouped_gemm_ck_tile: CK_Tile kernel arguments not supported for this config."); + NVTE_ERROR("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config."); return false; } @@ -315,7 +315,7 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, const auto& a0 = data_view(*A_te[0]); const auto& b0 = data_view(*B_te[0]); const auto& d0 = data_view(*D_te[0]); - NVTE_ERROR("grouped_gemm_ck_tile: could not infer a consistent GEMM mode from shapes. ", + NVTE_ERROR("ck_tile_grouped_gemm: could not infer a consistent GEMM mode from shapes. 
", "A0=[", a0.shape[0], ",", a0.shape[1], "] ", "B0=[", b0.shape[0], ",", b0.shape[1], "] ", "D0=[", d0.shape[0], ",", d0.shape[1], "] ", From dd3ed2f3812fadbed9b154856b3da517cd07011b Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 4 Feb 2026 17:46:48 -0600 Subject: [PATCH 20/51] allow flattened higher-D tensors --- .../common/gemm/ck_grouped_gemm.cpp | 61 ++++++++++--------- 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm.cpp index 1c5617cfb..5f7520f9e 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cpp @@ -16,6 +16,17 @@ template struct TeTypeToCkType; template <> struct TeTypeToCkType { using type = ck_tile::half_t; }; template <> struct TeTypeToCkType { using type = ck_tile::bfloat16_t; }; +// Treat TE tensors as generalized 2D matrices by flattening: +// (D1, D2, ..., Dn) -> (D1*...*D(n-1), Dn), consistent with TE Tensor::flat_*_dim. +static inline bool get_flat_2d_dims(const transformer_engine::Tensor& t, + int64_t& d0, int64_t& d1) { + // Require at least a matrix (rank >= 2). Higher ranks are flattened. 
+ if (t.shape().size() < 2) + return false; + d0 = static_cast(t.flat_first_dim()); + d1 = static_cast(t.flat_last_dim()); + return true; +} static inline const transformer_engine::SimpleTensor& data_view(const transformer_engine::Tensor& t) { return t.data; // rowwise data view @@ -112,16 +123,14 @@ static bool run_grouped_impl(const transformer_engine::Tensor* const* A_use, const auto& b = data_view(*B_use[i]); const auto& d = data_view(*D[i]); - if (a.shape.size() != 2 || b.shape.size() != 2 || d.shape.size() != 2) { - NVTE_ERROR("ck_tile_grouped_gemm: expected all groups to be 2D."); + int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; + if (!get_flat_2d_dims(*A_use[i], Ad0, Ad1) || + !get_flat_2d_dims(*B_use[i], Bd0, Bd1) || + !get_flat_2d_dims(*D[i], Dd0, Dd1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected all groups to be rank>=2 (2D or higher)."); return false; } - const int64_t Ad0 = a.shape[0]; - const int64_t Ad1 = a.shape[1]; - const int64_t Bd0 = b.shape[0]; - const int64_t Bd1 = b.shape[1]; - const int64_t M = transA_use ? Ad1 : Ad0; const int64_t K = transA_use ? Ad0 : Ad1; const int64_t N = transB_use ? 
Bd0 : Bd1; @@ -132,14 +141,15 @@ static bool run_grouped_impl(const transformer_engine::Tensor* const* A_use, return false; } - if (d.shape[0] != M || d.shape[1] != N) { + if (Dd0 != M || Dd1 != N) { NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i); return false; } - const ck_tile::index_t stride_A = a.shape[1]; - const ck_tile::index_t stride_B = b.shape[1]; - const ck_tile::index_t stride_E = d.shape[1]; + // Leading dimensions under the flattened-contiguous interpretation + const ck_tile::index_t stride_A = Ad1; + const ck_tile::index_t stride_B = Bd1; + const ck_tile::index_t stride_E = Dd1; descs.emplace_back( a.dptr, @@ -198,21 +208,13 @@ static inline bool infer_gemm_mode_group0(const transformer_engine::Tensor* cons if (group_num <= 0) return true; - const auto& a0 = data_view(*A[0]); - const auto& b0 = data_view(*B[0]); - const auto& d0 = data_view(*D[0]); - - if (a0.shape.size() != 2 || b0.shape.size() != 2 || d0.shape.size() != 2) { + int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dm = 0, Dn = 0; + if (!get_flat_2d_dims(*A[0], Ad0, Ad1) || + !get_flat_2d_dims(*B[0], Bd0, Bd1) || + !get_flat_2d_dims(*D[0], Dm, Dn)) { return false; } - const int64_t Ad0 = a0.shape[0]; - const int64_t Ad1 = a0.shape[1]; - const int64_t Bd0 = b0.shape[0]; - const int64_t Bd1 = b0.shape[1]; - const int64_t Dm = d0.shape[0]; - const int64_t Dn = d0.shape[1]; - auto check = [&](bool do_swap, bool ta, bool tb) -> bool { const int64_t A0d0 = do_swap ? Bd0 : Ad0; const int64_t A0d1 = do_swap ? Bd1 : Ad1; @@ -312,13 +314,14 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, // If TE's flags disagree with storage, infer the correct mode from shapes. 
if (!infer_gemm_mode_group0(A_te.data(), B_te.data(), D_te.data(), group_num, A_use, B_use, transA_use, transB_use)) { - const auto& a0 = data_view(*A_te[0]); - const auto& b0 = data_view(*B_te[0]); - const auto& d0 = data_view(*D_te[0]); + int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; + (void)get_flat_2d_dims(*A_te[0], Ad0, Ad1); + (void)get_flat_2d_dims(*B_te[0], Bd0, Bd1); + (void)get_flat_2d_dims(*D_te[0], Dd0, Dd1); NVTE_ERROR("ck_tile_grouped_gemm: could not infer a consistent GEMM mode from shapes. ", - "A0=[", a0.shape[0], ",", a0.shape[1], "] ", - "B0=[", b0.shape[0], ",", b0.shape[1], "] ", - "D0=[", d0.shape[0], ",", d0.shape[1], "] ", + "A0(flat)=[", Ad0, ",", Ad1, "] ", + "B0(flat)=[", Bd0, ",", Bd1, "] ", + "D0(flat)=[", Dd0, ",", Dd1, "] ", "given flags transA=", transA, " transB=", transB); return false; } From ebc005f98af04ed43b7ebd2a7e2864e41746fc5d Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 5 Feb 2026 12:50:04 -0600 Subject: [PATCH 21/51] relax tolerance on gfx942 --- tests/pytorch/test_numerics.py | 4 +++- transformer_engine/pytorch/utils.py | 10 +++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 1b849ea78..5f832bd3f 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -28,7 +28,7 @@ is_bf16_compatible, ) if IS_HIP_EXTENSION: - from transformer_engine.pytorch.utils import is_mi200, is_mi308 + from transformer_engine.pytorch.utils import is_mi200, is_mi308, is_mi300_class from transformer_engine.pytorch import ( DotProductAttention, @@ -2121,6 +2121,8 @@ def test_grouped_linear_accuracy( atol, rtol = 0, 0 if use_cutlass: atol, rtol = 1e-3, 1e-3 + if IS_HIP_EXTENSION and is_mi300_class(): + atol, rtol = 3e-2, 3e-2 if use_triton: atol, rtol = get_tolerances(dtype) if dtype == torch.float32: diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 
d124fbeaf..92253ad9c 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -456,6 +456,14 @@ def is_mi308(): import re return (re.search('AMD Instinct MI308', torch.cuda.get_device_name(torch.cuda.current_device())) is not None) + def is_mi300_class(): + """check whether the current device is of the gfx942 class""" + return get_device_compute_capability() == (9, 4) + + def is_mi350_class(): + """check whether the current device is of the gfx950 class""" + return get_device_compute_capability() == (9, 5) + @functools.lru_cache(maxsize=None) def is_fp8_fnuz(): return IS_HIP_EXTENSION and get_device_compute_capability() == (9, 4) From c0bf502b00feb3de9d404c86ad394a56b27239a4 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 5 Feb 2026 14:53:56 -0600 Subject: [PATCH 22/51] enable more tests --- tests/pytorch/test_numerics.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 5f832bd3f..bd71e0a50 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -148,7 +148,7 @@ def rocm_attn_backend() -> tuple[bool, bool, bool]: use_cutlass_grouped_gemm = [False] # Only enable cutlass grouped gemm on Hopper -if torch.cuda.get_device_capability() == (9, 0): +if torch.cuda.get_device_capability() == (9, 0) or IS_HIP_EXTENSION: use_cutlass_grouped_gemm.append(True) @@ -2938,7 +2938,10 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): # cublas implementation should be bit-wise match torch.testing.assert_close(o, o_ref, rtol=0, atol=0) else: - 
torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2) + if IS_HIP_EXTENSION and is_mi300_class(): + torch.testing.assert_close(o, o_ref, rtol=2.0e-2, atol=3.0e-2) + else: + torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2) if use_cutlass: os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) From 0b162874aa00c01e70d13cc5872bcc20a647c8b0 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 5 Feb 2026 15:01:00 -0600 Subject: [PATCH 23/51] return early when num_gemms<=0 --- transformer_engine/common/gemm/cublaslt_gemm.cu | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 881ccb2a1..32a9c4b6f 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -790,6 +790,11 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor cudaStream_t stream) { NVTE_API_CALL(nvte_multi_tensor_gemm); +#ifdef __HIP_PLATFORM_AMD__ + if (num_gemms <= 0) + return; +#endif + const int current_device = transformer_engine::cuda::current_device(); const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90); const bool use_cutlass = transformer_engine::getenv("NVTE_USE_CUTLASS_GROUPED_GEMM", false); From 58b34e7ba39dfdfcae58bad689ecfd2850eea384 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 5 Feb 2026 17:22:19 -0600 Subject: [PATCH 24/51] simplify normalization --- .../common/gemm/ck_grouped_gemm.cpp | 82 ++----------------- 1 file changed, 6 insertions(+), 76 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm.cpp index 5f7520f9e..4c3e0264a 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cpp @@ -191,63 +191,6 @@ static bool run_grouped_impl(const transformer_engine::Tensor* const* A_use, return true; } 
-static inline bool infer_gemm_mode_group0(const transformer_engine::Tensor* const* A, - const transformer_engine::Tensor* const* B, - transformer_engine::Tensor* const* D, - int group_num, - const transformer_engine::Tensor* const*& A_use, - const transformer_engine::Tensor* const*& B_use, - bool& transA_use, - bool& transB_use) -{ - A_use = A; - B_use = B; - transA_use = false; - transB_use = false; - - if (group_num <= 0) - return true; - - int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dm = 0, Dn = 0; - if (!get_flat_2d_dims(*A[0], Ad0, Ad1) || - !get_flat_2d_dims(*B[0], Bd0, Bd1) || - !get_flat_2d_dims(*D[0], Dm, Dn)) { - return false; - } - - auto check = [&](bool do_swap, bool ta, bool tb) -> bool { - const int64_t A0d0 = do_swap ? Bd0 : Ad0; - const int64_t A0d1 = do_swap ? Bd1 : Ad1; - const int64_t B0d0 = do_swap ? Ad0 : Bd0; - const int64_t B0d1 = do_swap ? Ad1 : Bd1; - - const int64_t M = ta ? A0d1 : A0d0; - const int64_t K = ta ? A0d0 : A0d1; - const int64_t N = tb ? B0d0 : B0d1; - const int64_t Kb = tb ? B0d1 : B0d0; - - return (M == Dm) && (N == Dn) && (K == Kb); - }; - - // Try all candidates; prefer "no swap" first, then swap - for (bool do_swap : {false, true}) { - for (bool ta : {false, true}) { - for (bool tb : {false, true}) { - if (check(do_swap, ta, tb)) { - A_use = do_swap ? B : A; - B_use = do_swap ? A : B; - transA_use = ta; - transB_use = tb; - return true; - } - } - } - } - - // Nothing matched D = op(A) * op(B) - return false; -} - template static inline bool dispatch_grouped(bool transA_use, bool transB_use, @@ -306,25 +249,12 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, transformer_engine::typeToSize(ws_te->data.dtype); } - const transformer_engine::Tensor* const* A_use = A_te.data(); - const transformer_engine::Tensor* const* B_use = B_te.data(); - bool transA_use = transA; - bool transB_use = transB; - - // If TE's flags disagree with storage, infer the correct mode from shapes. 
- if (!infer_gemm_mode_group0(A_te.data(), B_te.data(), D_te.data(), - group_num, A_use, B_use, transA_use, transB_use)) { - int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; - (void)get_flat_2d_dims(*A_te[0], Ad0, Ad1); - (void)get_flat_2d_dims(*B_te[0], Bd0, Bd1); - (void)get_flat_2d_dims(*D_te[0], Dd0, Dd1); - NVTE_ERROR("ck_tile_grouped_gemm: could not infer a consistent GEMM mode from shapes. ", - "A0(flat)=[", Ad0, ",", Ad1, "] ", - "B0(flat)=[", Bd0, ",", Bd1, "] ", - "D0(flat)=[", Dd0, ",", Dd1, "] ", - "given flags transA=", transA, " transB=", transB); - return false; - } + // Normalize similar to upstream + // See https://github.com/NVIDIA/TransformerEngine/blob/59f6f3876767d07045152bfae07b5dd4c54e1725/transformer_engine/common/gemm/cutlass_grouped_gemm.cu#L54-L68 + const transformer_engine::Tensor* const* A_use = B_te.data(); + const transformer_engine::Tensor* const* B_use = A_te.data(); + const bool transA_use = transB; + const bool transB_use = transA; const auto a_dtype = A_use[0]->dtype(); From e28c80142957f87d0400f6e1e8a0007aa13c42a1 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 11 Feb 2026 13:04:07 -0600 Subject: [PATCH 25/51] run hipblaslt for num_gemms==1 --- transformer_engine/common/gemm/cublaslt_gemm.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 32a9c4b6f..bb322233b 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -808,7 +808,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor // Currently only support cutlass group gemm on Hopper Arch #ifdef __HIP_PLATFORM_AMD__ - if (!use_cutlass) { + if (!use_cutlass || num_gemms == 1) { #else if (!(is_hopper && use_cutlass)) { #endif From 5c57d47297fffd05aa87c3d33c7a441ce28f8ff0 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 17 Feb 2026 
20:26:16 +0000 Subject: [PATCH 26/51] disable ck_tile when accumulate=true --- transformer_engine/common/gemm/ck_grouped_gemm.cpp | 7 +++++-- transformer_engine/common/gemm/cublaslt_gemm.cu | 5 +++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm.cpp index 4c3e0264a..73c88de10 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cpp @@ -261,13 +261,16 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(a_dtype, te_type, { using T = typename TeTypeToCkType::type; - if (accumulate) + if (accumulate) { + // FIXME: The accumulate path is currently disabled in nvte_multi_tensor_gemm + // due to instability on MI325. return dispatch_grouped(transA_use, transB_use, A_use, B_use, D_te.data(), group_num, ws_ptr, ws_bytes, stream); - else + } else { return dispatch_grouped(transA_use, transB_use, A_use, B_use, D_te.data(), group_num, ws_ptr, ws_bytes, stream); + } }); } diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index bb322233b..b087145ab 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -806,10 +806,11 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor workspace, accumulate, use_split_accumulator, math_sm_count, stream); }; - // Currently only support cutlass group gemm on Hopper Arch #ifdef __HIP_PLATFORM_AMD__ - if (!use_cutlass || num_gemms == 1) { + // FIXME: The accumulate path is currently disabled due to instability on MI325. 
+ if (!use_cutlass || num_gemms == 1 || accumulate == true) { #else + // Currently only support cutlass group gemm on Hopper Arch if (!(is_hopper && use_cutlass)) { #endif cublas_path(); From 2e844d999f718432ea9f3e28a35fa7e86c663b34 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 17 Feb 2026 22:23:01 +0000 Subject: [PATCH 27/51] remove test file --- gmm2.py | 90 --------------------------------------------------------- 1 file changed, 90 deletions(-) delete mode 100644 gmm2.py diff --git a/gmm2.py b/gmm2.py deleted file mode 100644 index 938d69247..000000000 --- a/gmm2.py +++ /dev/null @@ -1,90 +0,0 @@ -import os -import time -import torch -import transformer_engine.pytorch as te - -torch.manual_seed(0) - -os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" -os.environ["NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK"] = "1" - -device = "cuda" -dtype = torch.bfloat16 - -E = 4 -K = 1024 -N = 2048 -m_splits = [128, 64, 0, 256] -M_total = sum(m_splits) - -x = torch.randn(M_total, K, device=device, dtype=dtype) - -# Timing helper -def bench_cuda(fn, warmup=20, iters=100): - # Warmup - for _ in range(warmup): - fn() - torch.cuda.synchronize() - - # Timed - start = time.time() - for _ in range(iters): - fn() - torch.cuda.synchronize() - end = time.time() - - avg_ms = (end - start) * 1000.0 / iters - return avg_ms - -# TE GroupedLinear -glinear = te.GroupedLinear(E, K, N, bias=False).to(device=device, dtype=dtype) - -def te_run(): - return glinear(x, m_splits=m_splits) - -te_ms = bench_cuda(te_run, warmup=20, iters=100) - -# Grab weights for reference path -Ws = [getattr(glinear, f"weight{e}") for e in range(E)] # each [N, K] -W = torch.stack(Ws, dim=0) # [E, N, K] -assert W.shape == (E, N, K), f"Unexpected weight shape: {W.shape}" - -# Torch reference (group loop) -offsets = [] -off = 0 -for m in m_splits: - offsets.append(off) - off += m - -y_ref_buf = torch.empty((M_total, N), device=device, dtype=dtype) - -def torch_run(): - # Fill the preallocated buffer - for e, m 
in enumerate(m_splits): - if m == 0: - continue - o = offsets[e] - y_ref_buf[o:o+m].copy_(x[o:o+m] @ W[e].transpose(0, 1)) - return y_ref_buf - -torch_ms = bench_cuda(torch_run, warmup=20, iters=100) - -# Compare outputs -y_te = te_run() -y_ref = torch_run().clone() - -diff = (y_te.float() - y_ref.float()) -max_abs = diff.abs().max().item() -rel = (diff.abs() / (y_ref.float().abs() + 1e-6)).max().item() - -print(f"Errors:") -print(f" {y_te.shape=}, {y_ref.shape=}") -print(" max_abs_err:", max_abs) -print(" max_rel_err:", rel) - -torch.testing.assert_close(y_te.float(), y_ref.float(), rtol=3e-2, atol=3e-2) - -print(f"\nTiming:") -print(f" TE avg: {te_ms:.3f} ms") -print(f" Torch avg: {torch_ms:.3f} ms") -print(f" Speedup: {torch_ms/te_ms:.2f}x (Torch / TE)") From f680d6a8e224c4912289038f18faa10e8126872f Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 23 Feb 2026 12:43:09 -0600 Subject: [PATCH 28/51] fix copyright header --- transformer_engine/common/gemm/ck_grouped_gemm.cpp | 6 +++++- transformer_engine/common/gemm/ck_grouped_gemm.h | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm.cpp index 73c88de10..317cd4c43 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cpp @@ -1,4 +1,8 @@ -/* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. */ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ #include diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.h b/transformer_engine/common/gemm/ck_grouped_gemm.h index d539b47f7..97b4cfd88 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm.h @@ -1,4 +1,8 @@ -/* Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. */ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ bool ck_tile_grouped_gemm(const NVTETensor* A, const NVTETensor* B, From 6d85088f337aff73cd787098532182e4829ac900 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 23 Feb 2026 13:10:46 -0600 Subject: [PATCH 29/51] simplify calls in dispatch_grouped --- .../common/gemm/ck_grouped_gemm.cpp | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm.cpp index 317cd4c43..78bcb6d0d 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cpp @@ -206,17 +206,16 @@ static inline bool dispatch_grouped(bool transA_use, size_t workspace_bytes, hipStream_t stream) { -// FIXME: This could be a templated lambda function in C++20. 
-#define CALL(ALayout_, BLayout_, ta_, tb_) \ - return run_grouped_impl( \ - A_use, B_use, D, group_num, (ta_), (tb_), workspace, workspace_bytes, stream) - - if (!transA_use && !transB_use) { CALL(RowMajor, RowMajor, false, false); } - if (!transA_use && transB_use) { CALL(RowMajor, ColMajor, false, true ); } - if ( transA_use && !transB_use) { CALL(ColMajor, RowMajor, true, false); } - /* transA_use && transB_use */ { CALL(ColMajor, ColMajor, true, true ); } - -#undef CALL + TRANSFORMER_ENGINE_SWITCH_CONDITION(transA_use, kTransA, { + using ALayout = std::conditional_t; + + TRANSFORMER_ENGINE_SWITCH_CONDITION(transB_use, kTransB, { + using BLayout = std::conditional_t; + + return run_grouped_impl( + A_use, B_use, D, group_num, kTransA, kTransB, workspace, workspace_bytes, stream); + }); + }); } bool ck_tile_grouped_gemm(const NVTETensor* A, From 791003844f59f715110c742414e5e497863014b4 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 23 Feb 2026 13:14:22 -0600 Subject: [PATCH 30/51] remove is_mi3*0_class --- tests/pytorch/test_numerics.py | 8 +++++--- transformer_engine/pytorch/utils.py | 10 +--------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index bd71e0a50..dae88f0df 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -28,7 +28,7 @@ is_bf16_compatible, ) if IS_HIP_EXTENSION: - from transformer_engine.pytorch.utils import is_mi200, is_mi308, is_mi300_class + from transformer_engine.pytorch.utils import is_mi200, is_mi308 from transformer_engine.pytorch import ( DotProductAttention, @@ -2121,7 +2121,8 @@ def test_grouped_linear_accuracy( atol, rtol = 0, 0 if use_cutlass: atol, rtol = 1e-3, 1e-3 - if IS_HIP_EXTENSION and is_mi300_class(): + if IS_HIP_EXTENSION and torch.cuda.get_device_capability() == (9, 4): + # gfx942 atol, rtol = 3e-2, 3e-2 if use_triton: atol, rtol = get_tolerances(dtype) @@ -2938,7 +2939,8 @@ def 
test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): # cublas implementation should be bit-wise match torch.testing.assert_close(o, o_ref, rtol=0, atol=0) else: - if IS_HIP_EXTENSION and is_mi300_class(): + if IS_HIP_EXTENSION and torch.cuda.get_device_capability() == (9, 4): + # gfx942 torch.testing.assert_close(o, o_ref, rtol=2.0e-2, atol=3.0e-2) else: torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2) diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 92253ad9c..d124fbeaf 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -456,14 +456,6 @@ def is_mi308(): import re return (re.search('AMD Instinct MI308', torch.cuda.get_device_name(torch.cuda.current_device())) is not None) - def is_mi300_class(): - """check whether the current device is of the gfx942 class""" - return get_device_compute_capability() == (9, 4) - - def is_mi350_class(): - """check whether the current device is of the gfx950 class""" - return get_device_compute_capability() == (9, 5) - @functools.lru_cache(maxsize=None) def is_fp8_fnuz(): return IS_HIP_EXTENSION and get_device_compute_capability() == (9, 4) From e8ebb0ea3ca4225fce9e9225dd861b1057fef23c Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 23 Feb 2026 13:26:25 -0600 Subject: [PATCH 31/51] disable unused constants --- transformer_engine/common/gemm/cublaslt_gemm.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index b087145ab..8cb80d704 100644 --- 
a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -793,10 +793,10 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor #ifdef __HIP_PLATFORM_AMD__ if (num_gemms <= 0) return; +#else + const int current_device = transformer_engine::cuda::current_device(); + const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90); #endif - - const int current_device = transformer_engine::cuda::current_device(); - const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90); const bool use_cutlass = transformer_engine::getenv("NVTE_USE_CUTLASS_GROUPED_GEMM", false); const bool warn_fallback = transformer_engine::getenv("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", false); From e866bc63707e8808b7b21c1662a72cbcbe8235e6 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 24 Feb 2026 10:18:10 -0600 Subject: [PATCH 32/51] add another fallback --- transformer_engine/common/gemm/cublaslt_gemm.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 8cb80d704..b435c57c7 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -875,7 +875,8 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_supported_dtype() && #ifdef __HIP_PLATFORM_AMD__ true) { - ck_tile_grouped_gemm(A, B, D, num_gemms, transa, transb, workspace, accumulate, stream); + if (!ck_tile_grouped_gemm(A, B, D, num_gemms, transa, transb, workspace, accumulate, stream)) + cublas_path(); #else all_groups_uniform_k128(B, transb)) { cutlass_grouped_gemm(A, B, D, num_gemms, transa, transb, grad, workspace, accumulate, From ee438fb00fbbfd7eae60d1226e5a3381babc096e Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 25 Feb 2026 
10:43:24 -0600 Subject: [PATCH 33/51] implement Primus-Turbo selection logic, persistent descs --- .../common/gemm/ck_grouped_gemm.cpp | 69 ++++++++++++++----- .../common/gemm/cublaslt_gemm.cu | 6 +- 2 files changed, 58 insertions(+), 17 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm.cpp index 78bcb6d0d..839b9ee6a 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cpp @@ -36,9 +36,14 @@ static inline const transformer_engine::SimpleTensor& data_view(const transforme return t.data; // rowwise data view } -struct TileCfg_basic { +// Primus-Turbo-like FP16/BF16 tile configs +// Selection rule: +// if (N % 256 == 0) use 256x256x64 +// else if (N % 128 == 0) use 256x128x64 +// else use 256x128x64 with N padding enabled +struct TileCfg_256x256x64 { static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 256; static constexpr ck_tile::index_t K_Tile = 64; static constexpr ck_tile::index_t M_Warp = 2; @@ -49,14 +54,22 @@ struct TileCfg_basic { static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool kPadM = true; - static constexpr bool kPadN = true; - static constexpr bool kPadK = true; + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; static constexpr bool DoubleSmemBuffer = false; static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; - static constexpr ck_tile::index_t TilePartitionerM01 = 1; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; +}; + +struct TileCfg_256x128x64 : TileCfg_256x256x64 { + static constexpr ck_tile::index_t N_Tile = 128; +}; + +struct TileCfg_256x128x64_padding : TileCfg_256x128x64 { + static constexpr bool kPadN = true; }; // This class instantiates CK_Tile's 
grouped GEMM pipeline. @@ -100,7 +113,7 @@ struct Runner{ }; template + ck_tile::memory_operation_enum MemOp, typename TileCfg> static bool run_grouped_impl(const transformer_engine::Tensor* const* A_use, const transformer_engine::Tensor* const* B_use, transformer_engine::Tensor* const* D, @@ -111,7 +124,7 @@ static bool run_grouped_impl(const transformer_engine::Tensor* const* A_use, size_t workspace_bytes, hipStream_t stream) { - using Kernel = typename Runner::Kernel; + using Kernel = typename Runner::Kernel; const size_t needed = Kernel::GetWorkSpaceSize(group_num); if (!workspace || workspace_bytes < needed) { @@ -119,7 +132,8 @@ static bool run_grouped_impl(const transformer_engine::Tensor* const* A_use, return false; } - std::vector> descs; + thread_local std::vector> descs; + descs.clear(); descs.reserve(group_num); for (int i = 0; i < group_num; ++i) { @@ -206,16 +220,39 @@ static inline bool dispatch_grouped(bool transA_use, size_t workspace_bytes, hipStream_t stream) { - TRANSFORMER_ENGINE_SWITCH_CONDITION(transA_use, kTransA, { - using ALayout = std::conditional_t; + // Select tile config like Primus-Turbo for FP16/BF16: + // N%256 -> 256x256x64 + // N%128 -> 256x128x64 + // else -> 256x128x64 padding + // NOTE: We assume N is uniform across groups. 
+ int64_t ref_d0 = 0, ref_d1 = 0; + if (!get_flat_2d_dims(*D[0], ref_d0, ref_d1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for D[0]"); + return false; + } + const ck_tile::index_t N = static_cast(ref_d1); + + auto run_with_tilecfg = [&](auto tile_tag) -> bool { + using TileCfgSel = decltype(tile_tag); + TRANSFORMER_ENGINE_SWITCH_CONDITION(transA_use, kTransA, { + using ALayout = std::conditional_t; - TRANSFORMER_ENGINE_SWITCH_CONDITION(transB_use, kTransB, { - using BLayout = std::conditional_t; + TRANSFORMER_ENGINE_SWITCH_CONDITION(transB_use, kTransB, { + using BLayout = std::conditional_t; - return run_grouped_impl( - A_use, B_use, D, group_num, kTransA, kTransB, workspace, workspace_bytes, stream); + return run_grouped_impl( + A_use, B_use, D, group_num, kTransA, kTransB, workspace, workspace_bytes, stream); + }); }); - }); + }; + + if ((N % 256) == 0) { + return run_with_tilecfg(TileCfg_256x256x64{}); + } else if ((N % 128) == 0) { + return run_with_tilecfg(TileCfg_256x128x64{}); + } else { + return run_with_tilecfg(TileCfg_256x128x64_padding{}); + } } bool ck_tile_grouped_gemm(const NVTETensor* A, diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index b435c57c7..fa2608dab 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -875,8 +875,12 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_supported_dtype() && #ifdef __HIP_PLATFORM_AMD__ true) { - if (!ck_tile_grouped_gemm(A, B, D, num_gemms, transa, transb, workspace, accumulate, stream)) + if (!ck_tile_grouped_gemm(A, B, D, num_gemms, transa, transb, workspace, accumulate, stream)) { + if (warn_fallback) { + NVTE_WARN("Fallback to cuBLAS grouped GEMM."); + } cublas_path(); + } #else all_groups_uniform_k128(B, transb)) { cutlass_grouped_gemm(A, B, D, num_gemms, transa, transb, 
grad, workspace, accumulate, From 0cbf1cddfc844cd7102c38e1256ffbca287807fc Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 25 Feb 2026 14:39:05 -0600 Subject: [PATCH 34/51] tighten tolerances --- tests/pytorch/test_numerics.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index dae88f0df..847d26ade 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2123,7 +2123,7 @@ def test_grouped_linear_accuracy( atol, rtol = 1e-3, 1e-3 if IS_HIP_EXTENSION and torch.cuda.get_device_capability() == (9, 4): # gfx942 - atol, rtol = 3e-2, 3e-2 + atol, rtol = 1e-3, 8e-3 if use_triton: atol, rtol = get_tolerances(dtype) if dtype == torch.float32: @@ -2939,11 +2939,7 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): # cublas implementation should be bit-wise match torch.testing.assert_close(o, o_ref, rtol=0, atol=0) else: - if IS_HIP_EXTENSION and torch.cuda.get_device_capability() == (9, 4): - # gfx942 - torch.testing.assert_close(o, o_ref, rtol=2.0e-2, atol=3.0e-2) - else: - torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2) + torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2) if use_cutlass: os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) From 98e0c66844bbd4a7c936e2e3831910ca15d09358 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 25 Feb 2026 17:13:19 -0600 Subject: [PATCH 35/51] use namespace, various cleanups --- transformer_engine/common/CMakeLists.txt | 2 +- .../common/gemm/ck_grouped_gemm.cpp | 19 +++++++++++++------ .../common/gemm/cublaslt_gemm.cu | 6 ++++-- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 03954b6a6..d5aab2cda 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -1,5 +1,5 @@ # This file was modified for 
portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm.cpp index 839b9ee6a..8f52383f7 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cpp @@ -13,12 +13,15 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" +namespace transformer_engine { +namespace grouped_gemm { + using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; -template struct TeTypeToCkType; -template <> struct TeTypeToCkType { using type = ck_tile::half_t; }; -template <> struct TeTypeToCkType { using type = ck_tile::bfloat16_t; }; +template struct TETypeToCKType; +template <> struct TETypeToCKType { using type = ck_tile::half_t; }; +template <> struct TETypeToCKType { using type = ck_tile::bfloat16_t; }; // Treat TE tensors as generalized 2D matrices by flattening: // (D1, D2, ..., Dn) -> (D1*...*D(n-1), Dn), consistent with TE Tensor::flat_*_dim. @@ -255,6 +258,9 @@ static inline bool dispatch_grouped(bool transA_use, } } +} // namespace grouped_gemm +} // namespace transformer_engine + bool ck_tile_grouped_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, @@ -291,6 +297,7 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, // Normalize similar to upstream // See https://github.com/NVIDIA/TransformerEngine/blob/59f6f3876767d07045152bfae07b5dd4c54e1725/transformer_engine/common/gemm/cutlass_grouped_gemm.cu#L54-L68 + // I.e., swap A and B, as well as transa and transb. 
const transformer_engine::Tensor* const* A_use = B_te.data(); const transformer_engine::Tensor* const* B_use = A_te.data(); const bool transA_use = transB; @@ -299,16 +306,16 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, const auto a_dtype = A_use[0]->dtype(); TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(a_dtype, te_type, { - using T = typename TeTypeToCkType::type; + using T = typename transformer_engine::grouped_gemm::TETypeToCKType::type; if (accumulate) { // FIXME: The accumulate path is currently disabled in nvte_multi_tensor_gemm // due to instability on MI325. - return dispatch_grouped(transA_use, transB_use, + return transformer_engine::grouped_gemm::dispatch_grouped(transA_use, transB_use, A_use, B_use, D_te.data(), group_num, ws_ptr, ws_bytes, stream); } else { - return dispatch_grouped(transA_use, transB_use, + return transformer_engine::grouped_gemm::dispatch_grouped(transA_use, transB_use, A_use, B_use, D_te.data(), group_num, ws_ptr, ws_bytes, stream); } diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index fa2608dab..cbed586ca 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -794,8 +794,8 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor if (num_gemms <= 0) return; #else - const int current_device = transformer_engine::cuda::current_device(); - const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90); + const int current_device = transformer_engine::cuda::current_device(); + const bool is_hopper = (transformer_engine::cuda::sm_arch(current_device) == 90); #endif const bool use_cutlass = transformer_engine::getenv("NVTE_USE_CUTLASS_GROUPED_GEMM", false); const bool warn_fallback = @@ -825,6 +825,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor return true; }; +#ifndef __HIP_PLATFORM_AMD__ auto all_groups_uniform_k128 = [&](const 
NVTETensor *p, bool trans) -> bool { int64_t ref_k = -1; for (size_t i = 0; i < num_gemms; i++) { @@ -841,6 +842,7 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor return true; }; +#endif auto is_supported_dtype = [&]() -> bool { auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]); From 36bd68e9d1f6eba03b9a6251d1bfe29c6b0bfb01 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 26 Feb 2026 15:22:33 -0600 Subject: [PATCH 36/51] avoid creating vector with Tensors --- .../common/gemm/ck_grouped_gemm.cpp | 65 +++++++++---------- 1 file changed, 31 insertions(+), 34 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm.cpp index 8f52383f7..64eafc333 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cpp @@ -117,9 +117,9 @@ struct Runner{ template -static bool run_grouped_impl(const transformer_engine::Tensor* const* A_use, - const transformer_engine::Tensor* const* B_use, - transformer_engine::Tensor* const* D, +static bool run_grouped_impl(const NVTETensor* A_use, + const NVTETensor* B_use, + NVTETensor* D, int group_num, bool transA_use, bool transB_use, @@ -140,14 +140,21 @@ static bool run_grouped_impl(const transformer_engine::Tensor* const* A_use, descs.reserve(group_num); for (int i = 0; i < group_num; ++i) { - const auto& a = data_view(*A_use[i]); - const auto& b = data_view(*B_use[i]); - const auto& d = data_view(*D[i]); + const transformer_engine::Tensor* const A_te = + transformer_engine::convertNVTETensorCheck(A_use[i]); + const transformer_engine::Tensor* const B_te = + transformer_engine::convertNVTETensorCheck(B_use[i]); + transformer_engine::Tensor* D_te = + transformer_engine::convertNVTETensorCheck(D[i]); + + const auto& a = data_view(*A_te); + const auto& b = data_view(*B_te); + const auto& d = data_view(*D_te); int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, 
Dd1 = 0; - if (!get_flat_2d_dims(*A_use[i], Ad0, Ad1) || - !get_flat_2d_dims(*B_use[i], Bd0, Bd1) || - !get_flat_2d_dims(*D[i], Dd0, Dd1)) { + if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || + !get_flat_2d_dims(*B_te, Bd0, Bd1) || + !get_flat_2d_dims(*D_te, Dd0, Dd1)) { NVTE_ERROR("ck_tile_grouped_gemm: expected all groups to be rank>=2 (2D or higher)."); return false; } @@ -215,21 +222,17 @@ static bool run_grouped_impl(const transformer_engine::Tensor* const* A_use, template static inline bool dispatch_grouped(bool transA_use, bool transB_use, - const transformer_engine::Tensor* const* A_use, - const transformer_engine::Tensor* const* B_use, - transformer_engine::Tensor* const* D, + const NVTETensor* A_use, + const NVTETensor* B_use, + NVTETensor* D, int group_num, void* workspace, size_t workspace_bytes, hipStream_t stream) { - // Select tile config like Primus-Turbo for FP16/BF16: - // N%256 -> 256x256x64 - // N%128 -> 256x128x64 - // else -> 256x128x64 padding - // NOTE: We assume N is uniform across groups. int64_t ref_d0 = 0, ref_d1 = 0; - if (!get_flat_2d_dims(*D[0], ref_d0, ref_d1)) { + transformer_engine::Tensor* D_te = transformer_engine::convertNVTETensorCheck(D[0]); + if (!get_flat_2d_dims(*D_te, ref_d0, ref_d1)) { NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for D[0]"); return false; } @@ -249,6 +252,11 @@ static inline bool dispatch_grouped(bool transA_use, }); }; + // Select tile config like Primus-Turbo for FP16/BF16: + // N%256 -> 256x256x64 + // N%128 -> 256x128x64 + // else -> 256x128x64 padding + // NOTE: We assume N is uniform across groups. 
if ((N % 256) == 0) { return run_with_tilecfg(TileCfg_256x256x64{}); } else if ((N % 128) == 0) { @@ -274,17 +282,6 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, if (group_num <= 0) return true; - // Convert A/B/D arrays into TE Tensor arrays - std::vector A_te(group_num); - std::vector B_te(group_num); - std::vector D_te(group_num); - - for (int i = 0; i < group_num; ++i) { - A_te[i] = transformer_engine::convertNVTETensorCheck(A[i]); - B_te[i] = transformer_engine::convertNVTETensorCheck(B[i]); - D_te[i] = transformer_engine::convertNVTETensorCheck(D[i]); - } - // Workspace pointer + bytes void* ws_ptr = nullptr; size_t ws_bytes = 0; @@ -298,12 +295,12 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, // Normalize similar to upstream // See https://github.com/NVIDIA/TransformerEngine/blob/59f6f3876767d07045152bfae07b5dd4c54e1725/transformer_engine/common/gemm/cutlass_grouped_gemm.cu#L54-L68 // I.e., swap A and B, as well as transa and transb. - const transformer_engine::Tensor* const* A_use = B_te.data(); - const transformer_engine::Tensor* const* B_use = A_te.data(); + const NVTETensor* A_use = B; + const NVTETensor* B_use = A; const bool transA_use = transB; const bool transB_use = transA; - const auto a_dtype = A_use[0]->dtype(); + const auto a_dtype = transformer_engine::convertNVTETensorCheck(A_use[0])->dtype(); TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(a_dtype, te_type, { using T = typename transformer_engine::grouped_gemm::TETypeToCKType::type; @@ -312,11 +309,11 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, // FIXME: The accumulate path is currently disabled in nvte_multi_tensor_gemm // due to instability on MI325. 
return transformer_engine::grouped_gemm::dispatch_grouped(transA_use, transB_use, - A_use, B_use, D_te.data(), group_num, + A_use, B_use, D, group_num, ws_ptr, ws_bytes, stream); } else { return transformer_engine::grouped_gemm::dispatch_grouped(transA_use, transB_use, - A_use, B_use, D_te.data(), group_num, + A_use, B_use, D, group_num, ws_ptr, ws_bytes, stream); } }); From c5d83a426d49c448aabc26def95b5cb49c2e172d Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 26 Feb 2026 16:37:36 -0600 Subject: [PATCH 37/51] merge dispatch_grouped into ck_tile_grouped_gemm --- .../common/gemm/ck_grouped_gemm.cpp | 111 ++++++++---------- 1 file changed, 49 insertions(+), 62 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm.cpp index 64eafc333..61c08f0de 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cpp @@ -219,53 +219,6 @@ static bool run_grouped_impl(const NVTETensor* A_use, return true; } -template -static inline bool dispatch_grouped(bool transA_use, - bool transB_use, - const NVTETensor* A_use, - const NVTETensor* B_use, - NVTETensor* D, - int group_num, - void* workspace, - size_t workspace_bytes, - hipStream_t stream) { - - int64_t ref_d0 = 0, ref_d1 = 0; - transformer_engine::Tensor* D_te = transformer_engine::convertNVTETensorCheck(D[0]); - if (!get_flat_2d_dims(*D_te, ref_d0, ref_d1)) { - NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for D[0]"); - return false; - } - const ck_tile::index_t N = static_cast(ref_d1); - - auto run_with_tilecfg = [&](auto tile_tag) -> bool { - using TileCfgSel = decltype(tile_tag); - TRANSFORMER_ENGINE_SWITCH_CONDITION(transA_use, kTransA, { - using ALayout = std::conditional_t; - - TRANSFORMER_ENGINE_SWITCH_CONDITION(transB_use, kTransB, { - using BLayout = std::conditional_t; - - return run_grouped_impl( - A_use, B_use, D, group_num, kTransA, kTransB, workspace, workspace_bytes, 
stream); - }); - }); - }; - - // Select tile config like Primus-Turbo for FP16/BF16: - // N%256 -> 256x256x64 - // N%128 -> 256x128x64 - // else -> 256x128x64 padding - // NOTE: We assume N is uniform across groups. - if ((N % 256) == 0) { - return run_with_tilecfg(TileCfg_256x256x64{}); - } else if ((N % 128) == 0) { - return run_with_tilecfg(TileCfg_256x128x64{}); - } else { - return run_with_tilecfg(TileCfg_256x128x64_padding{}); - } -} - } // namespace grouped_gemm } // namespace transformer_engine @@ -282,14 +235,16 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, if (group_num <= 0) return true; + using namespace transformer_engine; + using namespace transformer_engine::grouped_gemm; + // Workspace pointer + bytes void* ws_ptr = nullptr; size_t ws_bytes = 0; if (workspace) { - auto* ws_te = transformer_engine::convertNVTETensorCheck(*workspace); + auto* ws_te = convertNVTETensorCheck(*workspace); ws_ptr = ws_te->data.dptr; - ws_bytes = ws_te->data.numel() * - transformer_engine::typeToSize(ws_te->data.dtype); + ws_bytes = ws_te->data.numel() * typeToSize(ws_te->data.dtype); } // Normalize similar to upstream @@ -300,21 +255,53 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, const bool transA_use = transB; const bool transB_use = transA; - const auto a_dtype = transformer_engine::convertNVTETensorCheck(A_use[0])->dtype(); + const auto a_dtype = convertNVTETensorCheck(A_use[0])->dtype(); + + // Get N from D[0] (assume uniform N across groups) + int64_t ref_d0 = 0, ref_d1 = 0; + Tensor* D0_te = convertNVTETensorCheck(D[0]); + if (!get_flat_2d_dims(*D0_te, ref_d0, ref_d1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for D[0]"); + return false; + } + const ck_tile::index_t N = static_cast(ref_d1); TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(a_dtype, te_type, { - using T = typename transformer_engine::grouped_gemm::TETypeToCKType::type; - - if (accumulate) { - // FIXME: The accumulate path is currently disabled in nvte_multi_tensor_gemm - // due to 
instability on MI325. - return transformer_engine::grouped_gemm::dispatch_grouped(transA_use, transB_use, - A_use, B_use, D, group_num, - ws_ptr, ws_bytes, stream); + using T = typename TETypeToCKType::type; + + auto run_with_tilecfg = [&](auto tile_tag) -> bool { + using TileCfgSel = decltype(tile_tag); + + TRANSFORMER_ENGINE_SWITCH_CONDITION(transA_use, kTransA, { + using ALayout = std::conditional_t; + + TRANSFORMER_ENGINE_SWITCH_CONDITION(transB_use, kTransB, { + using BLayout = std::conditional_t; + + if (accumulate) { + return run_grouped_impl( + A_use, B_use, D, group_num, kTransA, kTransB, ws_ptr, ws_bytes, stream); + } else { + return run_grouped_impl( + A_use, B_use, D, group_num, kTransA, kTransB, ws_ptr, ws_bytes, stream); + } + }); + }); + }; + + // Select tile config like Primus-Turbo for FP16/BF16: + // N%256 -> 256x256x64 + // N%128 -> 256x128x64 + // else -> 256x128x64 padding + // NOTE: We assume N is uniform across groups. + if ((N % 256) == 0) { + return run_with_tilecfg(TileCfg_256x256x64{}); + } else if ((N % 128) == 0) { + return run_with_tilecfg(TileCfg_256x128x64{}); } else { - return transformer_engine::grouped_gemm::dispatch_grouped(transA_use, transB_use, - A_use, B_use, D, group_num, - ws_ptr, ws_bytes, stream); + return run_with_tilecfg(TileCfg_256x128x64_padding{}); } }); } From 26dfbb60ffd2b2ee8be66a1991129a400fcc19a5 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 27 Feb 2026 11:23:10 -0600 Subject: [PATCH 38/51] same tolerances for gfx950 --- tests/pytorch/test_numerics.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 847d26ade..4be6e69e7 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2121,8 +2121,7 @@ def test_grouped_linear_accuracy( atol, rtol = 0, 0 if use_cutlass: atol, rtol = 1e-3, 1e-3 - if IS_HIP_EXTENSION and torch.cuda.get_device_capability() == (9, 4): - # gfx942 + if 
IS_HIP_EXTENSION: atol, rtol = 1e-3, 8e-3 if use_triton: atol, rtol = get_tolerances(dtype) From 5a7eb694934cdc25b0d55953171ff555a71cccae Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 27 Feb 2026 14:45:46 -0600 Subject: [PATCH 39/51] feat(gemm): enable TensorQuant pipeline for FP8 on GFX942 Align GemmRowColTensorQuantPipelineProblem with ck_tile V3 requirements by using AccType for intermediate C results. Specific to TensorQuant (per-tensor scaling); limited to e4m3/e5m2 FNUZ formats. Updates test_numerics.py to exercise FP8 inputs in the grouped linear accuracy suite. --- README.rst | 13 + tests/pytorch/test_numerics.py | 10 +- transformer_engine/common/common.h | 27 ++ .../common/gemm/ck_grouped_gemm.cpp | 249 ++++++++++++++---- 4 files changed, 249 insertions(+), 50 deletions(-) diff --git a/README.rst b/README.rst index 66d1b0b3e..0f29d8eef 100644 --- a/README.rst +++ b/README.rst @@ -354,6 +354,19 @@ legacy single-stage atomic kernel by setting: NVTE_USE_ATOMIC_AMAX=1 +Grouped GEMM using CK_Tile +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Transformer Engine provides a CK_Tile–based implementation of grouped GEMM +as an alternative to the hipBlasLt-based default grouped GEMM implementation. +This will provide performance improvements in most supported cases. 
+ +You can enable the CK_Tile-based backend using the same environment variables as in the +upstream CUTLASS implementation: + + NVTE_USE_CUTLASS_GROUPED_GEMM=1 # Enable CK_Tile-based grouped GEMM + NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK=1 # Print a warning if falling back to hipBlasLt backend (e.g., due to an unsupported config) + Transformer Engine ****************** diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 4be6e69e7..697ddb84f 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -97,7 +97,7 @@ def rocm_attn_backend() -> tuple[bool, bool, bool]: module_inference = ["TransformerLayer", "MultiheadAttention"] input_formats_inference = ["sbhd", "bshd"] -param_types = [torch.float32, torch.float16] +param_types = [torch.float32, torch.float16, torch.float8_e4m3fnuz] if is_bf16_compatible(): # bf16 requires sm_80 or higher param_types.append(torch.bfloat16) @@ -2140,6 +2140,8 @@ def test_grouped_linear_accuracy( @pytest.mark.parametrize("num_gemms", [3, 6]) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) +@pytest.mark.parametrize("fp8_model_params", all_boolean) +@pytest.mark.parametrize("recipe", fp8_recipes + [None]) @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) @pytest.mark.parametrize("delay_wgrad_compute", all_boolean) def test_grouped_linear_accuracy_cutlass( @@ -2147,6 +2149,8 @@ def test_grouped_linear_accuracy_cutlass( num_gemms, bs, model, + recipe, + fp8_model_params, fuse_wgrad_accumulation, delay_wgrad_compute, ): @@ -2156,8 +2160,8 @@ def test_grouped_linear_accuracy_cutlass( num_gemms, bs, model, - None, - False, + recipe, + fp8_model_params, fuse_wgrad_accumulation, False, delay_wgrad_compute, diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 0015e9155..2e9672644 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -658,6 +658,33 @@ struct 
TypeInfo { NVTE_ERROR("Invalid type for 16 bit."); \ } +#define TRANSFORMER_ENGINE_TYPE_SWITCH_MIXED(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat16: { \ + using type = fp16; \ + __VA_ARGS__; \ + break; \ + } \ + case DType::kBFloat16: { \ + using type = bf16; \ + __VA_ARGS__; \ + break; \ + } \ + case DType::kFloat8E5M2: { \ + using type = fp8e5m2; \ + __VA_ARGS__; \ + break; \ + } \ + case DType::kFloat8E4M3: { \ + using type = fp8e4m3; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + NVTE_ERROR("Invalid type."); \ + } + #define TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(SCALE_DIM, DIM, ...) \ switch (SCALE_DIM) { \ case 1: { \ diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm.cpp index 61c08f0de..aa7f50f95 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cpp @@ -13,6 +13,11 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp" +#include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp" + namespace transformer_engine { namespace grouped_gemm { @@ -20,6 +25,8 @@ using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; template struct TETypeToCKType; +template <> struct TETypeToCKType { using type = ck_tile::fp8_t; }; +template <> struct TETypeToCKType { using type = ck_tile::bf8_t; }; template <> struct TETypeToCKType { using type = ck_tile::half_t; }; template <> struct TETypeToCKType { using type = ck_tile::bfloat16_t; }; @@ -39,6 +46,10 @@ static inline const transformer_engine::SimpleTensor& data_view(const transforme return t.data; // rowwise data view } +static inline const 
transformer_engine::SimpleTensor& inv_scale_view(const transformer_engine::Tensor& t) { + return t.scale_inv; // dequantization scaling factor +} + // Primus-Turbo-like FP16/BF16 tile configs // Selection rule: // if (N % 256 == 0) use 256x256x64 @@ -75,11 +86,81 @@ struct TileCfg_256x128x64_padding : TileCfg_256x128x64 { static constexpr bool kPadN = true; }; + +// Primus-Turbo-like FP8/BF8 tile configs +// Selection rule: +// if (N % 256 == 0) use 256x256x128 +// else if (N % 128 == 0) use 256x128x128 +// else use 256x128x128 with N padding enabled +struct TileCfg_256x256x128 { + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 32; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; +}; + +struct TileCfg_256x128x128 : TileCfg_256x256x128 { + static constexpr ck_tile::index_t N_Tile = 128; +}; + +struct TileCfg_256x128x128_padding : TileCfg_256x256x128 { + static constexpr bool kPadN = true; +}; + +template struct GemmTilePolicy; + +// FP16/BF16 – K=64 tiles +template <> +struct GemmTilePolicy { + using Tile256x256 = TileCfg_256x256x64; + using Tile256x128 = TileCfg_256x128x64; + using TilePadding = TileCfg_256x128x64_padding; +}; + +template <> +struct GemmTilePolicy { + using Tile256x256 = TileCfg_256x256x64; + using Tile256x128 = TileCfg_256x128x64; + using TilePadding = TileCfg_256x128x64_padding; +}; + +// FP8 – K=128 
tiles +template <> +struct GemmTilePolicy { + using Tile256x256 = TileCfg_256x256x128; + using Tile256x128 = TileCfg_256x128x128; + using TilePadding = TileCfg_256x128x128_padding; +}; + +template <> +struct GemmTilePolicy { + using Tile256x256 = TileCfg_256x256x128; + using Tile256x128 = TileCfg_256x128x128; + using TilePadding = TileCfg_256x128x128_padding; +}; + // This class instantiates CK_Tile's grouped GEMM pipeline. // See e.g. https://github.com/ROCm/composable_kernel/blob/develop/example/ck_tile/03_gemm/universal_gemm_invoker.hpp for reference. template struct Runner{ using GemmShape = ck_tile::TileGemmShape< @@ -90,15 +171,33 @@ struct Runner{ using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>; - using UniversalTraits = ck_tile::PersistentTileGemmUniversalTraits< + using AQLayout = RowMajor; + using BQLayout = RowMajor; + static constexpr ck_tile::QuantType QuantMode = ck_tile::QuantType::TensorQuant; + + using UniversalTraits = std::conditional_t< + useTensorQuant, + ck_tile::TileGemmQuantTraits< + TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, + false, false, ALayout, BLayout, CLayout, + QuantMode, AQLayout, BQLayout, + false, TileCfg::DoubleSmemBuffer, false>, + ck_tile::PersistentTileGemmUniversalTraits< TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, - TileCfg::DoubleSmemBuffer, ALayout, BLayout, CLayout>; + TileCfg::DoubleSmemBuffer, ALayout, BLayout, CLayout>>; static constexpr ck_tile::GemmPipelineScheduler Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - using Problem = ck_tile::UniversalGemmPipelineProblem< - AType, BType, AccType, GemmShape, UniversalTraits, Scheduler>; + using Problem = std::conditional_t< + useTensorQuant, + ck_tile::GemmRowColTensorQuantPipelineProblem< + AType, BType, AccType, + AccType, GemmShape, UniversalTraits, + false, AccType>, + ck_tile::UniversalGemmPipelineProblem< + AType, BType, AccType, + GemmShape, UniversalTraits, 
Scheduler>>; using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -112,11 +211,21 @@ struct Runner{ TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile, Problem::TransposeC, MemOp>>; - using Kernel = ck_tile::GroupedGemmKernel; + using Kernel = std::conditional_t< + useTensorQuant, + ck_tile::QuantGroupedGemmKernel< + Partitioner, Pipeline, + Epilogue, QuantMode>, + ck_tile::GroupedGemmKernel< + Partitioner, Pipeline, Epilogue>>; }; -template +template static bool run_grouped_impl(const NVTETensor* A_use, const NVTETensor* B_use, NVTETensor* D, @@ -127,7 +236,13 @@ static bool run_grouped_impl(const NVTETensor* A_use, size_t workspace_bytes, hipStream_t stream) { - using Kernel = typename Runner::Kernel; + using RunnerT = Runner; + using Kernel = typename RunnerT::Kernel; + + using HostArgs = std::conditional_t< + useTensorQuant, + ck_tile::QuantGroupedGemmHostArgs, + ck_tile::GroupedGemmHostArgs<0>>; const size_t needed = Kernel::GetWorkSpaceSize(group_num); if (!workspace || workspace_bytes < needed) { @@ -135,7 +250,7 @@ static bool run_grouped_impl(const NVTETensor* A_use, return false; } - thread_local std::vector> descs; + thread_local std::vector descs; descs.clear(); descs.reserve(group_num); @@ -179,19 +294,46 @@ static bool run_grouped_impl(const NVTETensor* A_use, const ck_tile::index_t stride_B = Bd1; const ck_tile::index_t stride_E = Dd1; - descs.emplace_back( + if constexpr (useTensorQuant) { + ck_tile::index_t AQK = 1; // Tensor quantization: tensor shape [1] + ck_tile::index_t BQK = 1; // Tensor quantization: tensor shape [1] + ck_tile::index_t stride_AQ = 1; // Tensor quantization: tensor shape [1] + ck_tile::index_t stride_BQ = 1; // Tensor quantization: tensor shape [1] + const auto& aq = inv_scale_view(*A_te); + const auto& bq = inv_scale_view(*B_te); + descs.emplace_back( a.dptr, b.dptr, - std::array{}, d.dptr, + aq.dptr, + bq.dptr, 1, M, N, K, + AQK, + BQK, stride_A, stride_B, - std::array{}, - stride_E); + stride_E, + 
stride_AQ, + stride_BQ + ); + } else { + descs.emplace_back( + a.dptr, + b.dptr, + std::array{}, + d.dptr, + 1, + M, + N, + K, + stride_A, + stride_B, + std::array{}, + stride_E); + } } const dim3 grids = Kernel::GridSize(descs); @@ -260,48 +402,61 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, // Get N from D[0] (assume uniform N across groups) int64_t ref_d0 = 0, ref_d1 = 0; Tensor* D0_te = convertNVTETensorCheck(D[0]); + const auto d_dtype = D0_te->dtype(); if (!get_flat_2d_dims(*D0_te, ref_d0, ref_d1)) { NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for D[0]"); return false; } const ck_tile::index_t N = static_cast(ref_d1); - TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(a_dtype, te_type, { - using T = typename TETypeToCKType::type; - auto run_with_tilecfg = [&](auto tile_tag) -> bool { - using TileCfgSel = decltype(tile_tag); - TRANSFORMER_ENGINE_SWITCH_CONDITION(transA_use, kTransA, { - using ALayout = std::conditional_t; + // Mixed type dispatch: fp16, bf16, fp8 e4m3/e5m2 + TRANSFORMER_ENGINE_TYPE_SWITCH_MIXED(a_dtype, te_type, { + using AType = typename TETypeToCKType::type; + using BType = AType; + using Policy = GemmTilePolicy; + + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(d_dtype, d_te_type, { + using CType = typename TETypeToCKType::type; + // Select quantization mode based on input data type + constexpr bool TensorQuantMode = + std::is_same_v || std::is_same_v; + + auto run_with_tilecfg = [&](auto tile_tag) -> bool { + using TileCfgSel = decltype(tile_tag); - TRANSFORMER_ENGINE_SWITCH_CONDITION(transB_use, kTransB, { - using BLayout = std::conditional_t; + TRANSFORMER_ENGINE_SWITCH_CONDITION(transA_use, kTransA, { + using ALayout = std::conditional_t; - if (accumulate) { - return run_grouped_impl( - A_use, B_use, D, group_num, kTransA, kTransB, ws_ptr, ws_bytes, stream); - } else { - return run_grouped_impl( - A_use, B_use, D, group_num, kTransA, kTransB, ws_ptr, ws_bytes, stream); - } + TRANSFORMER_ENGINE_SWITCH_CONDITION(transB_use, kTransB, { + using 
BLayout = std::conditional_t; + + if (accumulate) { + return run_grouped_impl( + A_use, B_use, D, group_num, kTransA, kTransB, ws_ptr, ws_bytes, stream); + } else { + return run_grouped_impl( + A_use, B_use, D, group_num, kTransA, kTransB, ws_ptr, ws_bytes, stream); + } + }); }); - }); - }; - - // Select tile config like Primus-Turbo for FP16/BF16: - // N%256 -> 256x256x64 - // N%128 -> 256x128x64 - // else -> 256x128x64 padding - // NOTE: We assume N is uniform across groups. - if ((N % 256) == 0) { - return run_with_tilecfg(TileCfg_256x256x64{}); - } else if ((N % 128) == 0) { - return run_with_tilecfg(TileCfg_256x128x64{}); - } else { - return run_with_tilecfg(TileCfg_256x128x64_padding{}); - } - }); + }; + + // Select tile config like Primus-Turbo for FP16/BF16: + // N%256 -> 256x256x64 + // N%128 -> 256x128x64 + // else -> 256x128x64 padding + // NOTE: We assume N is uniform across groups. + if ((N % 256) == 0) { + return run_with_tilecfg(typename Policy::Tile256x256{}); + } else if ((N % 128) == 0) { + return run_with_tilecfg(typename Policy::Tile256x128{}); + } else { + return run_with_tilecfg(typename Policy::TilePadding{}); + } + }); // TRANSFORMER_ENGINE_SWITCH_16BIT + }); // TRANSFORMER_ENGINE_TYPE_SWITCH_MIXED } From 54da682d2293ebd29a7675b6e631bfa9a0d8c8bf Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Wed, 4 Mar 2026 16:16:45 +0000 Subject: [PATCH 40/51] Include Float8 E4M3/E5M2 in is_supported_dtype and remove float8 from param_types in test_numerics --- tests/pytorch/test_numerics.py | 2 +- transformer_engine/common/gemm/cublaslt_gemm.cu | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 697ddb84f..28b669dc9 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -97,7 +97,7 @@ def rocm_attn_backend() -> tuple[bool, bool, bool]: module_inference = ["TransformerLayer", "MultiheadAttention"] input_formats_inference = ["sbhd", 
"bshd"] -param_types = [torch.float32, torch.float16, torch.float8_e4m3fnuz] +param_types = [torch.float32, torch.float16] if is_bf16_compatible(): # bf16 requires sm_80 or higher param_types.append(torch.bfloat16) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index cbed586ca..3b2633a36 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -852,9 +852,11 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor auto A_dt = inputA->data.dtype; auto B_dt = inputB->data.dtype; auto D_dt = OutputD->data.dtype; - return (A_dt == B_dt) && (A_dt == D_dt) && + return (A_dt == B_dt) && (A_dt == transformer_engine::DType::kFloat16 || - A_dt == transformer_engine::DType::kBFloat16); + A_dt == transformer_engine::DType::kBFloat16 || + A_dt == transformer_engine::DType::kFloat8E4M3 || + A_dt == transformer_engine::DType::kFloat8E5M2); #else auto A_type = get_cuda_dtype(inputA->data.dtype); auto B_type = get_cuda_dtype(inputB->data.dtype); From 78a702f1fe52353401603397aaaabc4fd1869f27 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Wed, 4 Mar 2026 19:43:30 +0000 Subject: [PATCH 41/51] forward pass ck_tile with matching FP8 data type inputs passing accuracy tests --- .../common/gemm/ck_grouped_gemm.cpp | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm.cpp index aa7f50f95..ac0ef786c 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cpp @@ -25,6 +25,7 @@ using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; template struct TETypeToCKType; +template <> struct TETypeToCKType { using type = float; }; template <> struct TETypeToCKType { using type = ck_tile::fp8_t; 
}; template <> struct TETypeToCKType { using type = ck_tile::bf8_t; }; template <> struct TETypeToCKType { using type = ck_tile::half_t; }; @@ -159,7 +160,7 @@ struct GemmTilePolicy { // See e.g. https://github.com/ROCm/composable_kernel/blob/develop/example/ck_tile/03_gemm/universal_gemm_invoker.hpp for reference. template struct Runner{ @@ -192,9 +193,9 @@ struct Runner{ using Problem = std::conditional_t< useTensorQuant, ck_tile::GemmRowColTensorQuantPipelineProblem< - AType, BType, AccType, - AccType, GemmShape, UniversalTraits, - false, AccType>, + AType, BType, AccType, + AccType, GemmShape, UniversalTraits, + false, AccType>, ck_tile::UniversalGemmPipelineProblem< AType, BType, AccType, GemmShape, UniversalTraits, Scheduler>>; @@ -217,7 +218,7 @@ struct Runner{ Partitioner, Pipeline, Epilogue, QuantMode>, ck_tile::GroupedGemmKernel< - Partitioner, Pipeline, Epilogue>>; + Partitioner, Pipeline, Epilogue>>; }; template (ref_d1); - - // Mixed type dispatch: fp16, bf16, fp8 e4m3/e5m2 TRANSFORMER_ENGINE_TYPE_SWITCH_MIXED(a_dtype, te_type, { using AType = typename TETypeToCKType::type; using BType = AType; using Policy = GemmTilePolicy; - TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(d_dtype, d_te_type, { + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(d_dtype, d_te_type, { using CType = typename TETypeToCKType::type; // Select quantization mode based on input data type constexpr bool TensorQuantMode = - std::is_same_v || std::is_same_v; + std::is_same_v || std::is_same_v; auto run_with_tilecfg = [&](auto tile_tag) -> bool { using TileCfgSel = decltype(tile_tag); @@ -459,4 +458,4 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, } }); // TRANSFORMER_ENGINE_SWITCH_16BIT }); // TRANSFORMER_ENGINE_TYPE_SWITCH_MIXED -} +} \ No newline at end of file From f198341c41843f4b58b5cfddb450e9b1ef2ac2d1 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Fri, 6 Mar 2026 12:39:35 +0000 Subject: [PATCH 42/51] Support mixed FP8/BF8 grouped GEMM in CK backward path Enable mixed FP8/BF8 grouped 
GEMM for the CK backend used by GroupedLinear backward. Certain mixed-type combinations normalize to (AType=bf8_t, BType=fp8_t), but CK currently lacks a corresponding warp GEMM specialization for WarpGemmMfma_f32_32x32x32_bf8_fp8. This prevents the default FP8 tile configuration (K_Warp_Tile=32) from compiling or dispatching correctly. To address this, a fallback tile policy is introduced that routes the (bf8_t, fp8_t) case to a supported kernel configuration using K_Warp_Tile=16. This preserves correct GEMM operand ordering and avoids unsafe operand-swapping workarounds. Notes: - Only tensor quantization mode is currently supported. - Implementation targets MI300X (CDNA3) FP8/BF8 kernels. - Additional kernel coverage may be required for MI350X (CDNA4). With this change, mixed FP8/BF8 backprop paths are supported and all parametrized unit tests in test_grouped_linear_accuracy_cutlass() pass successfully. --- tests/pytorch/test_numerics.py | 2 + transformer_engine/common/common.h | 27 -- .../common/gemm/ck_grouped_gemm.cpp | 290 ++++++++++++------ .../common/gemm/cublaslt_gemm.cu | 9 +- 4 files changed, 202 insertions(+), 126 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 28b669dc9..6c3cfbe2a 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2136,6 +2136,8 @@ def test_grouped_linear_accuracy( torch.cuda.get_device_capability() != (9, 0) and not IS_HIP_EXTENSION, reason="Only enable CUTLASS grouped gemm on Hopper", ) +#@pytest.mark.parametrize("dtype", param_types, ids=str) + @pytest.mark.parametrize("dtype", param_types, ids=str) @pytest.mark.parametrize("num_gemms", [3, 6]) @pytest.mark.parametrize("bs", batch_sizes) diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 2e9672644..0015e9155 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -658,33 +658,6 @@ struct TypeInfo { NVTE_ERROR("Invalid type for 
16 bit."); \ } -#define TRANSFORMER_ENGINE_TYPE_SWITCH_MIXED(dtype, type, ...) \ - switch (dtype) { \ - using namespace transformer_engine; \ - case DType::kFloat16: { \ - using type = fp16; \ - __VA_ARGS__; \ - break; \ - } \ - case DType::kBFloat16: { \ - using type = bf16; \ - __VA_ARGS__; \ - break; \ - } \ - case DType::kFloat8E5M2: { \ - using type = fp8e5m2; \ - __VA_ARGS__; \ - break; \ - } \ - case DType::kFloat8E4M3: { \ - using type = fp8e4m3; \ - __VA_ARGS__; \ - break; \ - } \ - default: \ - NVTE_ERROR("Invalid type."); \ - } - #define TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(SCALE_DIM, DIM, ...) \ switch (SCALE_DIM) { \ case 1: { \ diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm.cpp index ac0ef786c..8c9f9a894 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cpp @@ -6,6 +6,8 @@ #include +#include + #include #include "../common.h" @@ -25,16 +27,16 @@ using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; template struct TETypeToCKType; -template <> struct TETypeToCKType { using type = float; }; +template <> struct TETypeToCKType { using type = float; }; template <> struct TETypeToCKType { using type = ck_tile::fp8_t; }; template <> struct TETypeToCKType { using type = ck_tile::bf8_t; }; -template <> struct TETypeToCKType { using type = ck_tile::half_t; }; -template <> struct TETypeToCKType { using type = ck_tile::bfloat16_t; }; +template <> struct TETypeToCKType { using type = ck_tile::half_t; }; +template <> struct TETypeToCKType { using type = ck_tile::bfloat16_t; }; // Treat TE tensors as generalized 2D matrices by flattening: // (D1, D2, ..., Dn) -> (D1*...*D(n-1), Dn), consistent with TE Tensor::flat_*_dim. 
static inline bool get_flat_2d_dims(const transformer_engine::Tensor& t, - int64_t& d0, int64_t& d1) { + int64_t& d0, int64_t& d1) { // Require at least a matrix (rank >= 2). Higher ranks are flattened. if (t.shape().size() < 2) return false; @@ -44,11 +46,11 @@ static inline bool get_flat_2d_dims(const transformer_engine::Tensor& t, } static inline const transformer_engine::SimpleTensor& data_view(const transformer_engine::Tensor& t) { - return t.data; // rowwise data view + return t.data; // rowwise data view } -static inline const transformer_engine::SimpleTensor& inv_scale_view(const transformer_engine::Tensor& t) { - return t.scale_inv; // dequantization scaling factor +static inline const transformer_engine::SimpleTensor& scale_inv_view(const transformer_engine::Tensor& t) { + return t.scale_inv; // dequantization scaling factor } // Primus-Turbo-like FP16/BF16 tile configs @@ -87,7 +89,6 @@ struct TileCfg_256x128x64_padding : TileCfg_256x128x64 { static constexpr bool kPadN = true; }; - // Primus-Turbo-like FP8/BF8 tile configs // Selection rule: // if (N % 256 == 0) use 256x256x128 @@ -120,7 +121,43 @@ struct TileCfg_256x128x128 : TileCfg_256x256x128 { static constexpr ck_tile::index_t N_Tile = 128; }; -struct TileCfg_256x128x128_padding : TileCfg_256x256x128 { +struct TileCfg_256x128x128_padding : TileCfg_256x128x128 { + static constexpr bool kPadN = true; +}; + +// Fallback FP8/BF8 tile family for normalized (bf8_t, fp8_t) pair during backprop. 
+// That is, while there is a supported WarpGemmMfma_f32_32x32x32_fp8_bf8, +// there is no such thing as WarpGemmMfma_f32_32x32x32_bf8_fp8, +// so we need to fall back to WarpGemmMfma_f32_32x32x16_bf8_fp8 +// by selecting K_Warp_Tile = 16 +struct TileCfg_256x256x128_k16 { + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; +}; + +struct TileCfg_256x128x128_k16 : TileCfg_256x256x128_k16 { + static constexpr ck_tile::index_t N_Tile = 128; +}; + +struct TileCfg_256x128x128_k16_padding : TileCfg_256x128x128_k16 { static constexpr bool kPadN = true; }; @@ -141,7 +178,7 @@ struct GemmTilePolicy { using TilePadding = TileCfg_256x128x64_padding; }; -// FP8 – K=128 tiles +// FP8/BF8 – K=128 tiles template <> struct GemmTilePolicy { using Tile256x256 = TileCfg_256x256x128; @@ -156,14 +193,23 @@ struct GemmTilePolicy { using TilePadding = TileCfg_256x128x128_padding; }; +// Fallback policy for normalized mixed pair: +// AType = bf8_t, BType = fp8_t +struct GemmTilePolicyBF8FP8Fallback { + using Tile256x256 = TileCfg_256x256x128_k16; + using Tile256x128 = TileCfg_256x128x128_k16; + using TilePadding = TileCfg_256x128x128_k16_padding; +}; + // This class instantiates CK_Tile's grouped GEMM pipeline. // See e.g. 
https://github.com/ROCm/composable_kernel/blob/develop/example/ck_tile/03_gemm/universal_gemm_invoker.hpp for reference. +// Currently the only quantization mode supported is tensor quant template -struct Runner{ +struct Runner { using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, ck_tile::sequence, @@ -177,15 +223,15 @@ struct Runner{ static constexpr ck_tile::QuantType QuantMode = ck_tile::QuantType::TensorQuant; using UniversalTraits = std::conditional_t< - useTensorQuant, - ck_tile::TileGemmQuantTraits< - TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, - false, false, ALayout, BLayout, CLayout, - QuantMode, AQLayout, BQLayout, - false, TileCfg::DoubleSmemBuffer, false>, - ck_tile::PersistentTileGemmUniversalTraits< - TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, - TileCfg::DoubleSmemBuffer, ALayout, BLayout, CLayout>>; + useTensorQuant, + ck_tile::TileGemmQuantTraits< + TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, + false, false, ALayout, BLayout, CLayout, + QuantMode, AQLayout, BQLayout, + false, TileCfg::DoubleSmemBuffer, false>, + ck_tile::PersistentTileGemmUniversalTraits< + TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, + TileCfg::DoubleSmemBuffer, ALayout, BLayout, CLayout>>; static constexpr ck_tile::GemmPipelineScheduler Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; @@ -193,12 +239,12 @@ struct Runner{ using Problem = std::conditional_t< useTensorQuant, ck_tile::GemmRowColTensorQuantPipelineProblem< - AType, BType, AccType, - AccType, GemmShape, UniversalTraits, - false, AccType>, + AType, BType, AccType, + AccType, GemmShape, UniversalTraits, + false, AccType>, ck_tile::UniversalGemmPipelineProblem< - AType, BType, AccType, - GemmShape, UniversalTraits, Scheduler>>; + AType, BType, AccType, + GemmShape, UniversalTraits, Scheduler>>; using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -213,12 +259,11 @@ struct Runner{ Problem::TransposeC, MemOp>>; using Kernel = std::conditional_t< - useTensorQuant, - 
ck_tile::QuantGroupedGemmKernel< - Partitioner, Pipeline, - Epilogue, QuantMode>, - ck_tile::GroupedGemmKernel< - Partitioner, Pipeline, Epilogue>>; + useTensorQuant, + ck_tile::QuantGroupedGemmKernel< + Partitioner, Pipeline, Epilogue, QuantMode>, + ck_tile::GroupedGemmKernel< + Partitioner, Pipeline, Epilogue>>; }; template ; - using Kernel = typename RunnerT::Kernel; + using Kernel = typename RunnerT::Kernel; using HostArgs = std::conditional_t< useTensorQuant, @@ -290,36 +335,36 @@ static bool run_grouped_impl(const NVTETensor* A_use, return false; } - // Leading dimensions under the flattened-contiguous interpretation const ck_tile::index_t stride_A = Ad1; const ck_tile::index_t stride_B = Bd1; const ck_tile::index_t stride_E = Dd1; if constexpr (useTensorQuant) { - ck_tile::index_t AQK = 1; // Tensor quantization: tensor shape [1] - ck_tile::index_t BQK = 1; // Tensor quantization: tensor shape [1] - ck_tile::index_t stride_AQ = 1; // Tensor quantization: tensor shape [1] - ck_tile::index_t stride_BQ = 1; // Tensor quantization: tensor shape [1] - const auto& aq = inv_scale_view(*A_te); - const auto& bq = inv_scale_view(*B_te); + ck_tile::index_t AQK = 1; + ck_tile::index_t BQK = 1; + ck_tile::index_t stride_AQ = 1; + ck_tile::index_t stride_BQ = 1; + + const auto& aq = scale_inv_view(*A_te); + const auto& bq = scale_inv_view(*B_te); + descs.emplace_back( - a.dptr, - b.dptr, - d.dptr, - aq.dptr, - bq.dptr, - 1, - M, - N, - K, - AQK, - BQK, - stride_A, - stride_B, - stride_E, - stride_AQ, - stride_BQ - ); + a.dptr, + b.dptr, + d.dptr, + aq.dptr, + bq.dptr, + 1, + M, + N, + K, + AQK, + BQK, + stride_A, + stride_B, + stride_E, + stride_AQ, + stride_BQ); } else { descs.emplace_back( a.dptr, @@ -338,17 +383,17 @@ static bool run_grouped_impl(const NVTETensor* A_use, } const dim3 grids = Kernel::GridSize(descs); - auto kargs = Kernel::MakeKargs(descs); + auto kargs = Kernel::MakeKargs(descs); if (!Kernel::IsSupportedArgument(kargs)) { 
NVTE_ERROR("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config."); return false; } HIP_CHECK_ERROR(hipMemcpyAsync(workspace, - kargs.data(), - kargs.size() * sizeof(typename decltype(kargs)::value_type), - hipMemcpyHostToDevice, - stream)); + kargs.data(), + kargs.size() * sizeof(typename decltype(kargs)::value_type), + hipMemcpyHostToDevice, + stream)); const ck_tile::stream_config s{stream}; const dim3 blocks = Kernel::BlockSize(); @@ -359,6 +404,7 @@ static bool run_grouped_impl(const NVTETensor* A_use, Kernel{}, grids, blocks, 0, ck_tile::cast_pointer_to_constant_address_space(workspace), group_num)); + return true; } @@ -373,8 +419,7 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, bool transB, NVTETensor* workspace, bool accumulate, - hipStream_t stream) -{ + hipStream_t stream) { if (group_num <= 0) return true; @@ -391,7 +436,8 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, } // Normalize similar to upstream - // See https://github.com/NVIDIA/TransformerEngine/blob/59f6f3876767d07045152bfae07b5dd4c54e1725/transformer_engine/common/gemm/cutlass_grouped_gemm.cu#L54-L68 + // See: + // transformer_engine/common/gemm/cutlass_grouped_gemm.cu // I.e., swap A and B, as well as transa and transb. 
const NVTETensor* A_use = B; const NVTETensor* B_use = A; @@ -399,30 +445,55 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, const bool transB_use = transA; const auto a_dtype = convertNVTETensorCheck(A_use[0])->dtype(); + const auto b_dtype = convertNVTETensorCheck(B_use[0])->dtype(); - // Get N from D[0] (assume uniform N across groups) - int64_t ref_d0 = 0, ref_d1 = 0; + // D dtype + N (assume uniform N across groups) Tensor* D0_te = convertNVTETensorCheck(D[0]); const auto d_dtype = D0_te->dtype(); + + int64_t ref_d0 = 0, ref_d1 = 0; if (!get_flat_2d_dims(*D0_te, ref_d0, ref_d1)) { NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for D[0]"); return false; } const ck_tile::index_t N = static_cast(ref_d1); - // Mixed type dispatch: fp16, bf16, fp8 e4m3/e5m2 - TRANSFORMER_ENGINE_TYPE_SWITCH_MIXED(a_dtype, te_type, { - using AType = typename TETypeToCKType::type; - using BType = AType; - using Policy = GemmTilePolicy; + auto choose_tile = [&](auto policy_tag, auto&& run_tile) -> bool { + using Policy = decltype(policy_tag); + if ((N % 256) == 0) return run_tile(typename Policy::Tile256x256{}); + if ((N % 128) == 0) return run_tile(typename Policy::Tile256x128{}); + return run_tile(typename Policy::TilePadding{}); + }; + + auto dispatch_pair = [&](auto ATag, auto BTag) -> bool { + using teA = decltype(ATag); + using teB = decltype(BTag); + + using AType = typename TETypeToCKType::type; + using BType = typename TETypeToCKType::type; + + constexpr bool TensorQuantMode = + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v; - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(d_dtype, d_te_type, { + constexpr bool UseBF8FP8Fallback = + std::is_same_v && + std::is_same_v; + + using DefaultPolicy = std::conditional_t< + TensorQuantMode, + GemmTilePolicy, + GemmTilePolicy>; + + using Policy = std::conditional_t< + UseBF8FP8Fallback, + GemmTilePolicyBF8FP8Fallback, + DefaultPolicy>; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { using CType 
= typename TETypeToCKType::type; - // Select quantization mode based on input data type - constexpr bool TensorQuantMode = - std::is_same_v || std::is_same_v; - auto run_with_tilecfg = [&](auto tile_tag) -> bool { + auto run_tile = [&](auto tile_tag) -> bool { using TileCfgSel = decltype(tile_tag); TRANSFORMER_ENGINE_SWITCH_CONDITION(transA_use, kTransA, { @@ -432,30 +503,57 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, using BLayout = std::conditional_t; if (accumulate) { - return run_grouped_impl( - A_use, B_use, D, group_num, kTransA, kTransB, ws_ptr, ws_bytes, stream); + return run_grouped_impl< + AType, BType, CType, + ALayout, BLayout, RowMajor, + TileCfgSel, TensorQuantMode, + ck_tile::memory_operation_enum::atomic_add>( + A_use, B_use, D, group_num, + kTransA, kTransB, + ws_ptr, ws_bytes, stream); } else { - return run_grouped_impl( - A_use, B_use, D, group_num, kTransA, kTransB, ws_ptr, ws_bytes, stream); + return run_grouped_impl< + AType, BType, CType, + ALayout, BLayout, RowMajor, + TileCfgSel, TensorQuantMode, + ck_tile::memory_operation_enum::set>( + A_use, B_use, D, group_num, + kTransA, kTransB, + ws_ptr, ws_bytes, stream); } }); }); }; - // Select tile config like Primus-Turbo for FP16/BF16: - // N%256 -> 256x256x64 - // N%128 -> 256x128x64 - // else -> 256x128x64 padding - // NOTE: We assume N is uniform across groups. 
- if ((N % 256) == 0) { - return run_with_tilecfg(typename Policy::Tile256x256{}); - } else if ((N % 128) == 0) { - return run_with_tilecfg(typename Policy::Tile256x128{}); - } else { - return run_with_tilecfg(typename Policy::TilePadding{}); - } - }); // TRANSFORMER_ENGINE_SWITCH_16BIT - }); // TRANSFORMER_ENGINE_TYPE_SWITCH_MIXED + return choose_tile(Policy{}, run_tile); + }); + + return false; + }; + + switch (a_dtype) { + case DType::kFloat16: + if (b_dtype == DType::kFloat16) return dispatch_pair(fp16{}, fp16{}); + break; + + case DType::kBFloat16: + if (b_dtype == DType::kBFloat16) return dispatch_pair(bf16{}, bf16{}); + break; + + case DType::kFloat8E4M3: + if (b_dtype == DType::kFloat8E4M3) return dispatch_pair(fp8e4m3{}, fp8e4m3{}); + if (b_dtype == DType::kFloat8E5M2) return dispatch_pair(fp8e4m3{}, fp8e5m2{}); + break; + + case DType::kFloat8E5M2: + if (b_dtype == DType::kFloat8E5M2) return dispatch_pair(fp8e5m2{}, fp8e5m2{}); + if (b_dtype == DType::kFloat8E4M3) return dispatch_pair(fp8e5m2{}, fp8e4m3{}); + break; + + default: + break; + } + + NVTE_ERROR("ck_tile_grouped_gemm: unsupported dtype pair for CK path."); + return false; } \ No newline at end of file diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 3b2633a36..0dc425876 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -852,11 +852,14 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor auto A_dt = inputA->data.dtype; auto B_dt = inputB->data.dtype; auto D_dt = OutputD->data.dtype; - return (A_dt == B_dt) && - (A_dt == transformer_engine::DType::kFloat16 || + return (A_dt == transformer_engine::DType::kFloat16 || A_dt == transformer_engine::DType::kBFloat16 || A_dt == transformer_engine::DType::kFloat8E4M3 || - A_dt == transformer_engine::DType::kFloat8E5M2); + A_dt == transformer_engine::DType::kFloat8E5M2) && + (B_dt == 
transformer_engine::DType::kFloat16 || + B_dt == transformer_engine::DType::kBFloat16 || + B_dt == transformer_engine::DType::kFloat8E4M3 || + B_dt == transformer_engine::DType::kFloat8E5M2); #else auto A_type = get_cuda_dtype(inputA->data.dtype); auto B_type = get_cuda_dtype(inputB->data.dtype); From 6b24be2fe6be5ba42bcbab40380a4f8ed9d52af1 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Fri, 6 Mar 2026 15:31:56 +0000 Subject: [PATCH 43/51] include more descriptive comment regarding tensor normalization in ck grouped gemm --- transformer_engine/common/gemm/ck_grouped_gemm.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm.cpp index 8c9f9a894..2881151d7 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cpp @@ -435,10 +435,12 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, ws_bytes = ws_te->data.numel() * typeToSize(ws_te->data.dtype); } - // Normalize similar to upstream - // See: - // transformer_engine/common/gemm/cutlass_grouped_gemm.cu - // I.e., swap A and B, as well as transa and transb. + + // Normalize operand order to match upstream CUTLASS path. + // TE grouped GEMM frontend passes (A=weights, B=inputs) with layout "TN", + // i.e. effectively W^T * X. The backend kernels expect inputs first + // (X * W^T), so swap A/B and their transpose flags while preserving + // the same mathematical operation. 
const NVTETensor* A_use = B; const NVTETensor* B_use = A; const bool transA_use = transB; From a161c2037f3c2b5a82fb733e12ef1a2ac7d019a5 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 9 Mar 2026 17:15:44 +0000 Subject: [PATCH 44/51] Refactor CK grouped GEMM: split FP8/FP16 implementations and introduce shared runner abstraction MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit restructures the CK grouped GEMM implementation to improve maintainability and better separate datatype-specific logic. Key changes: • Split the original single-source implementation into separate files for FP16 and FP8 grouped GEMM kernels. • Introduced a shared header defining a common abstraction for grouped GEMM runners. The design is similar in spirit to the Primus Turbo dispatch. • Added an abstract parent class that encapsulates the common interface and provides an overloaded operator() / run() entrypoint for launching kernels. Concrete runners implement datatype-specific behavior while sharing the same invocation path. • Introduced a GroupedGemmRunContext structure that carries runtime configuration (layout, splits, pointers, etc.) through the dispatch pipeline. This removes large argument lists and centralizes execution state. • Refactored dispatch code to construct the appropriate runner and invoke it through the unified interface. • Added documentation comments explaining the new structure and the responsibilities of each component (context, runner base class, and datatype-specific implementations). Functional behavior is unchanged. The refactor preserves the previous execution paths and continues to pass all existing Transformer Engine tests that exercised the original implementation. 
--- transformer_engine/common/CMakeLists.txt | 2 + .../common/gemm/ck_grouped_gemm.cpp | 517 +----------------- .../common/gemm/ck_grouped_gemm_fp16.cpp | 145 +++++ .../common/gemm/ck_grouped_gemm_fp8.cpp | 201 +++++++ 4 files changed, 360 insertions(+), 505 deletions(-) create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm_fp16.cpp create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm_fp8.cpp diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index d5aab2cda..2d33215c3 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -203,6 +203,8 @@ else() fused_attn_rocm/utils.cpp gemm/rocm_gemm.cu gemm/ck_grouped_gemm.cpp + gemm/ck_grouped_gemm_fp16.cpp + gemm/ck_grouped_gemm_fp8.cpp amd_detail/system.cpp) # process source code files diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm.cpp index 2881151d7..f2c19df9d 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cpp @@ -4,412 +4,7 @@ * License for AMD contributions = MIT. 
See LICENSE for more information ************************************************************************/ -#include - -#include - -#include -#include "../common.h" - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/ops/gemm.hpp" - -#include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp" -#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" -#include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp" -#include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp" - -namespace transformer_engine { -namespace grouped_gemm { - -using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; -using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; - -template struct TETypeToCKType; -template <> struct TETypeToCKType { using type = float; }; -template <> struct TETypeToCKType { using type = ck_tile::fp8_t; }; -template <> struct TETypeToCKType { using type = ck_tile::bf8_t; }; -template <> struct TETypeToCKType { using type = ck_tile::half_t; }; -template <> struct TETypeToCKType { using type = ck_tile::bfloat16_t; }; - -// Treat TE tensors as generalized 2D matrices by flattening: -// (D1, D2, ..., Dn) -> (D1*...*D(n-1), Dn), consistent with TE Tensor::flat_*_dim. -static inline bool get_flat_2d_dims(const transformer_engine::Tensor& t, - int64_t& d0, int64_t& d1) { - // Require at least a matrix (rank >= 2). Higher ranks are flattened. 
- if (t.shape().size() < 2) - return false; - d0 = static_cast(t.flat_first_dim()); - d1 = static_cast(t.flat_last_dim()); - return true; -} - -static inline const transformer_engine::SimpleTensor& data_view(const transformer_engine::Tensor& t) { - return t.data; // rowwise data view -} - -static inline const transformer_engine::SimpleTensor& scale_inv_view(const transformer_engine::Tensor& t) { - return t.scale_inv; // dequantization scaling factor -} - -// Primus-Turbo-like FP16/BF16 tile configs -// Selection rule: -// if (N % 256 == 0) use 256x256x64 -// else if (N % 128 == 0) use 256x128x64 -// else use 256x128x64 with N padding enabled -struct TileCfg_256x256x64 { - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 64; - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 16; - - static constexpr bool kPadM = false; - static constexpr bool kPadN = false; - static constexpr bool kPadK = false; - - static constexpr bool DoubleSmemBuffer = false; - - static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; - static constexpr ck_tile::index_t TilePartitionerM01 = 4; -}; - -struct TileCfg_256x128x64 : TileCfg_256x256x64 { - static constexpr ck_tile::index_t N_Tile = 128; -}; - -struct TileCfg_256x128x64_padding : TileCfg_256x128x64 { - static constexpr bool kPadN = true; -}; - -// Primus-Turbo-like FP8/BF8 tile configs -// Selection rule: -// if (N % 256 == 0) use 256x256x128 -// else if (N % 128 == 0) use 256x128x128 -// else use 256x128x128 with N padding enabled -struct TileCfg_256x256x128 { - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr 
ck_tile::index_t K_Tile = 128; - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 32; - - static constexpr bool kPadM = false; - static constexpr bool kPadN = false; - static constexpr bool kPadK = false; - - static constexpr bool DoubleSmemBuffer = false; - - static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; - static constexpr ck_tile::index_t TilePartitionerM01 = 4; -}; - -struct TileCfg_256x128x128 : TileCfg_256x256x128 { - static constexpr ck_tile::index_t N_Tile = 128; -}; - -struct TileCfg_256x128x128_padding : TileCfg_256x128x128 { - static constexpr bool kPadN = true; -}; - -// Fallback FP8/BF8 tile family for normalized (bf8_t, fp8_t) pair during backprop. -// That is, while there is a supported WarpGemmMfma_f32_32x32x32_fp8_bf8, -// there is no such thing as WarpGemmMfma_f32_32x32x32_bf8_fp8, -// so we need to fall back to WarpGemmMfma_f32_32x32x16_bf8_fp8 -// by selecting K_Warp_Tile = 16 -struct TileCfg_256x256x128_k16 { - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 128; - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 16; - - static constexpr bool kPadM = false; - static constexpr bool kPadN = false; - static constexpr bool kPadK = false; - - static constexpr bool DoubleSmemBuffer = false; - - static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; - static constexpr ck_tile::index_t TilePartitionerM01 = 4; -}; - -struct 
TileCfg_256x128x128_k16 : TileCfg_256x256x128_k16 { - static constexpr ck_tile::index_t N_Tile = 128; -}; - -struct TileCfg_256x128x128_k16_padding : TileCfg_256x128x128_k16 { - static constexpr bool kPadN = true; -}; - -template struct GemmTilePolicy; - -// FP16/BF16 – K=64 tiles -template <> -struct GemmTilePolicy { - using Tile256x256 = TileCfg_256x256x64; - using Tile256x128 = TileCfg_256x128x64; - using TilePadding = TileCfg_256x128x64_padding; -}; - -template <> -struct GemmTilePolicy { - using Tile256x256 = TileCfg_256x256x64; - using Tile256x128 = TileCfg_256x128x64; - using TilePadding = TileCfg_256x128x64_padding; -}; - -// FP8/BF8 – K=128 tiles -template <> -struct GemmTilePolicy { - using Tile256x256 = TileCfg_256x256x128; - using Tile256x128 = TileCfg_256x128x128; - using TilePadding = TileCfg_256x128x128_padding; -}; - -template <> -struct GemmTilePolicy { - using Tile256x256 = TileCfg_256x256x128; - using Tile256x128 = TileCfg_256x128x128; - using TilePadding = TileCfg_256x128x128_padding; -}; - -// Fallback policy for normalized mixed pair: -// AType = bf8_t, BType = fp8_t -struct GemmTilePolicyBF8FP8Fallback { - using Tile256x256 = TileCfg_256x256x128_k16; - using Tile256x128 = TileCfg_256x128x128_k16; - using TilePadding = TileCfg_256x128x128_k16_padding; -}; - -// This class instantiates CK_Tile's grouped GEMM pipeline. -// See e.g. https://github.com/ROCm/composable_kernel/blob/develop/example/ck_tile/03_gemm/universal_gemm_invoker.hpp for reference. 
-// Currently the only quantization mode supported is tensor quant -template -struct Runner { - using GemmShape = ck_tile::TileGemmShape< - ck_tile::sequence, - ck_tile::sequence, - ck_tile::sequence>; - - using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< - GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>; - - using AQLayout = RowMajor; - using BQLayout = RowMajor; - static constexpr ck_tile::QuantType QuantMode = ck_tile::QuantType::TensorQuant; - - using UniversalTraits = std::conditional_t< - useTensorQuant, - ck_tile::TileGemmQuantTraits< - TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, - false, false, ALayout, BLayout, CLayout, - QuantMode, AQLayout, BQLayout, - false, TileCfg::DoubleSmemBuffer, false>, - ck_tile::PersistentTileGemmUniversalTraits< - TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, - TileCfg::DoubleSmemBuffer, ALayout, BLayout, CLayout>>; - - static constexpr ck_tile::GemmPipelineScheduler Scheduler = - ck_tile::GemmPipelineScheduler::Intrawave; - - using Problem = std::conditional_t< - useTensorQuant, - ck_tile::GemmRowColTensorQuantPipelineProblem< - AType, BType, AccType, - AccType, GemmShape, UniversalTraits, - false, AccType>, - ck_tile::UniversalGemmPipelineProblem< - AType, BType, AccType, - GemmShape, UniversalTraits, Scheduler>>; - - using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; - - using Epilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem< - AType, BType, ck_tile::tuple<>, AccType, - CType, ck_tile::tuple<>, CLayout, - ck_tile::element_wise::PassThrough, - Partitioner::MPerBlock, Partitioner::NPerBlock, - TileCfg::M_Warp, TileCfg::N_Warp, - TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile, - Problem::TransposeC, MemOp>>; - - using Kernel = std::conditional_t< - useTensorQuant, - ck_tile::QuantGroupedGemmKernel< - Partitioner, Pipeline, Epilogue, QuantMode>, - ck_tile::GroupedGemmKernel< - Partitioner, Pipeline, Epilogue>>; -}; - -template -static bool 
run_grouped_impl(const NVTETensor* A_use, - const NVTETensor* B_use, - NVTETensor* D, - int group_num, - bool transA_use, - bool transB_use, - void* workspace, - size_t workspace_bytes, - hipStream_t stream) -{ - using RunnerT = Runner; - using Kernel = typename RunnerT::Kernel; - - using HostArgs = std::conditional_t< - useTensorQuant, - ck_tile::QuantGroupedGemmHostArgs, - ck_tile::GroupedGemmHostArgs<0>>; - - const size_t needed = Kernel::GetWorkSpaceSize(group_num); - if (!workspace || workspace_bytes < needed) { - NVTE_ERROR("ck_tile_grouped_gemm: insufficient workspace. Needed bytes=", needed); - return false; - } - - thread_local std::vector descs; - descs.clear(); - descs.reserve(group_num); - - for (int i = 0; i < group_num; ++i) { - const transformer_engine::Tensor* const A_te = - transformer_engine::convertNVTETensorCheck(A_use[i]); - const transformer_engine::Tensor* const B_te = - transformer_engine::convertNVTETensorCheck(B_use[i]); - transformer_engine::Tensor* D_te = - transformer_engine::convertNVTETensorCheck(D[i]); - - const auto& a = data_view(*A_te); - const auto& b = data_view(*B_te); - const auto& d = data_view(*D_te); - - int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; - if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || - !get_flat_2d_dims(*B_te, Bd0, Bd1) || - !get_flat_2d_dims(*D_te, Dd0, Dd1)) { - NVTE_ERROR("ck_tile_grouped_gemm: expected all groups to be rank>=2 (2D or higher)."); - return false; - } - - const int64_t M = transA_use ? Ad1 : Ad0; - const int64_t K = transA_use ? Ad0 : Ad1; - const int64_t N = transB_use ? Bd0 : Bd1; - const int64_t Kb = transB_use ? 
Bd1 : Bd0; - - if (Kb != K) { - NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i); - return false; - } - - if (Dd0 != M || Dd1 != N) { - NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i); - return false; - } - - const ck_tile::index_t stride_A = Ad1; - const ck_tile::index_t stride_B = Bd1; - const ck_tile::index_t stride_E = Dd1; - - if constexpr (useTensorQuant) { - ck_tile::index_t AQK = 1; - ck_tile::index_t BQK = 1; - ck_tile::index_t stride_AQ = 1; - ck_tile::index_t stride_BQ = 1; - - const auto& aq = scale_inv_view(*A_te); - const auto& bq = scale_inv_view(*B_te); - - descs.emplace_back( - a.dptr, - b.dptr, - d.dptr, - aq.dptr, - bq.dptr, - 1, - M, - N, - K, - AQK, - BQK, - stride_A, - stride_B, - stride_E, - stride_AQ, - stride_BQ); - } else { - descs.emplace_back( - a.dptr, - b.dptr, - std::array{}, - d.dptr, - 1, - M, - N, - K, - stride_A, - stride_B, - std::array{}, - stride_E); - } - } - - const dim3 grids = Kernel::GridSize(descs); - auto kargs = Kernel::MakeKargs(descs); - if (!Kernel::IsSupportedArgument(kargs)) { - NVTE_ERROR("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config."); - return false; - } - - HIP_CHECK_ERROR(hipMemcpyAsync(workspace, - kargs.data(), - kargs.size() * sizeof(typename decltype(kargs)::value_type), - hipMemcpyHostToDevice, - stream)); - - const ck_tile::stream_config s{stream}; - const dim3 blocks = Kernel::BlockSize(); - - ck_tile::launch_kernel( - s, - ck_tile::make_kernel<1>( - Kernel{}, grids, blocks, 0, - ck_tile::cast_pointer_to_constant_address_space(workspace), - group_num)); - - return true; -} - -} // namespace grouped_gemm -} // namespace transformer_engine +#include "ck_grouped_gemm_common.h" bool ck_tile_grouped_gemm(const NVTETensor* A, const NVTETensor* B, @@ -420,22 +15,21 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, NVTETensor* workspace, bool accumulate, hipStream_t stream) { - if (group_num <= 0) + if (group_num <= 0) { return true; 
+ } using namespace transformer_engine; using namespace transformer_engine::grouped_gemm; - // Workspace pointer + bytes - void* ws_ptr = nullptr; + void* ws_ptr = nullptr; size_t ws_bytes = 0; if (workspace) { auto* ws_te = convertNVTETensorCheck(*workspace); - ws_ptr = ws_te->data.dptr; + ws_ptr = ws_te->data.dptr; ws_bytes = ws_te->data.numel() * typeToSize(ws_te->data.dtype); } - // Normalize operand order to match upstream CUTLASS path. // TE grouped GEMM frontend passes (A=weights, B=inputs) with layout "TN", // i.e. effectively W^T * X. The backend kernels expect inputs first @@ -449,7 +43,6 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, const auto a_dtype = convertNVTETensorCheck(A_use[0])->dtype(); const auto b_dtype = convertNVTETensorCheck(B_use[0])->dtype(); - // D dtype + N (assume uniform N across groups) Tensor* D0_te = convertNVTETensorCheck(D[0]); const auto d_dtype = D0_te->dtype(); @@ -458,102 +51,16 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for D[0]"); return false; } - const ck_tile::index_t N = static_cast(ref_d1); - - auto choose_tile = [&](auto policy_tag, auto&& run_tile) -> bool { - using Policy = decltype(policy_tag); - if ((N % 256) == 0) return run_tile(typename Policy::Tile256x256{}); - if ((N % 128) == 0) return run_tile(typename Policy::Tile256x128{}); - return run_tile(typename Policy::TilePadding{}); - }; - - auto dispatch_pair = [&](auto ATag, auto BTag) -> bool { - using teA = decltype(ATag); - using teB = decltype(BTag); - - using AType = typename TETypeToCKType::type; - using BType = typename TETypeToCKType::type; - - constexpr bool TensorQuantMode = - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v; - - constexpr bool UseBF8FP8Fallback = - std::is_same_v && - std::is_same_v; - - using DefaultPolicy = std::conditional_t< - TensorQuantMode, - GemmTilePolicy, - GemmTilePolicy>; - - using Policy = std::conditional_t< - UseBF8FP8Fallback, - 
GemmTilePolicyBF8FP8Fallback, - DefaultPolicy>; - - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { - using CType = typename TETypeToCKType::type; - auto run_tile = [&](auto tile_tag) -> bool { - using TileCfgSel = decltype(tile_tag); + // construct run context + GroupedGemmRunContext ctx = {A_use, B_use, D, ref_d1, group_num, transA_use, transB_use, accumulate, ws_ptr, ws_bytes, stream}; - TRANSFORMER_ENGINE_SWITCH_CONDITION(transA_use, kTransA, { - using ALayout = std::conditional_t; - - TRANSFORMER_ENGINE_SWITCH_CONDITION(transB_use, kTransB, { - using BLayout = std::conditional_t; - - if (accumulate) { - return run_grouped_impl< - AType, BType, CType, - ALayout, BLayout, RowMajor, - TileCfgSel, TensorQuantMode, - ck_tile::memory_operation_enum::atomic_add>( - A_use, B_use, D, group_num, - kTransA, kTransB, - ws_ptr, ws_bytes, stream); - } else { - return run_grouped_impl< - AType, BType, CType, - ALayout, BLayout, RowMajor, - TileCfgSel, TensorQuantMode, - ck_tile::memory_operation_enum::set>( - A_use, B_use, D, group_num, - kTransA, kTransB, - ws_ptr, ws_bytes, stream); - } - }); - }); - }; - - return choose_tile(Policy{}, run_tile); - }); - - return false; - }; - - switch (a_dtype) { - case DType::kFloat16: - if (b_dtype == DType::kFloat16) return dispatch_pair(fp16{}, fp16{}); - break; - - case DType::kBFloat16: - if (b_dtype == DType::kBFloat16) return dispatch_pair(bf16{}, bf16{}); - break; - - case DType::kFloat8E4M3: - if (b_dtype == DType::kFloat8E4M3) return dispatch_pair(fp8e4m3{}, fp8e4m3{}); - if (b_dtype == DType::kFloat8E5M2) return dispatch_pair(fp8e4m3{}, fp8e5m2{}); - break; - - case DType::kFloat8E5M2: - if (b_dtype == DType::kFloat8E5M2) return dispatch_pair(fp8e5m2{}, fp8e5m2{}); - if (b_dtype == DType::kFloat8E4M3) return dispatch_pair(fp8e5m2{}, fp8e4m3{}); - break; + if (ck_tile_grouped_gemm_fp16_dispatch(a_dtype, b_dtype, d_dtype, ctx)) { + return true; + } - default: - break; + if 
(ck_tile_grouped_gemm_fp8_dispatch(a_dtype, b_dtype, d_dtype, ctx)) { + return true; } NVTE_ERROR("ck_tile_grouped_gemm: unsupported dtype pair for CK path."); diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp16.cpp b/transformer_engine/common/gemm/ck_grouped_gemm_fp16.cpp new file mode 100644 index 000000000..4dcf35a45 --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp16.cpp @@ -0,0 +1,145 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include "ck_grouped_gemm_common.h" + +namespace transformer_engine { +namespace grouped_gemm { + +template +std::unique_ptr get_f16_gemm_instance(DType d_dtype, const GroupedGemmRunContext& ctx) { + std::unique_ptr runner = nullptr; + using AType = typename TETypeToCKType::type; + using BType = typename TETypeToCKType::type; + using CLayout = RowMajor; + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { + using CType = typename TETypeToCKType::type; + if (ctx.N % 256 == 0) { + using TileCfg = TileCfg_256x256x64; + if (ctx.accumulate) { + using Runner = GroupedGemmRunner; + runner = std::make_unique(); + } else { + using Runner = GroupedGemmRunner; + runner = std::make_unique(); + } + + } else if (ctx.N % 128 == 0) { + using TileCfg = TileCfg_256x128x64; + if (ctx.accumulate) { + using Runner = GroupedGemmRunner; + runner = std::make_unique(); + } else { + using Runner = GroupedGemmRunner; + runner = std::make_unique(); + } + } else { + using TileCfg = TileCfg_256x128x64_padding; + if (ctx.accumulate) { + using Runner = GroupedGemmRunner; + runner = std::make_unique(); + } else { + using Runner = GroupedGemmRunner; + runner = std::make_unique(); + } + } + }); + return runner; +} + +bool ck_tile_grouped_gemm_fp16_dispatch(DType 
a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + + const ck_tile::stream_config s{ctx.stream}; + std::unique_ptr runner; + + if (!ctx.transA && !ctx.transB) { + using ALayout = RowMajor; + using BLayout = RowMajor; + + switch (a_dtype) { + case DType::kFloat16: + if (b_dtype == DType::kFloat16) { + runner = get_f16_gemm_instance(d_dtype, ctx); + } + break; + + case DType::kBFloat16: + if (b_dtype == DType::kBFloat16) { + runner = get_f16_gemm_instance(d_dtype, ctx); + } + break; + + default: + break; + } + } else if (!ctx.transA && ctx.transB) { + using ALayout = RowMajor; + using BLayout = ColMajor; + + switch (a_dtype) { + case DType::kFloat16: + if (b_dtype == DType::kFloat16) { + runner = get_f16_gemm_instance(d_dtype, ctx); + } + break; + + case DType::kBFloat16: + if (b_dtype == DType::kBFloat16) { + runner = get_f16_gemm_instance(d_dtype, ctx); + } + break; + + default: + break; + } + } else if (ctx.transA && !ctx.transB) { + using ALayout = ColMajor; + using BLayout = RowMajor; + + switch (a_dtype) { + case DType::kFloat16: + if (b_dtype == DType::kFloat16) { + runner = get_f16_gemm_instance(d_dtype, ctx); + } + break; + + case DType::kBFloat16: + if (b_dtype == DType::kBFloat16) { + runner = get_f16_gemm_instance(d_dtype, ctx); + } + break; + + default: + break; + } + } else { + return false; + } + + if (runner != nullptr) { + return runner->run(s, ctx); + } else { + return false; + } +} + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp8.cpp b/transformer_engine/common/gemm/ck_grouped_gemm_fp8.cpp new file mode 100644 index 000000000..9facadd69 --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp8.cpp @@ -0,0 +1,201 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include "ck_grouped_gemm_common.h" + +namespace transformer_engine { +namespace grouped_gemm { + +template +std::unique_ptr get_f8_gemm_instance(DType d_dtype, const GroupedGemmRunContext& ctx) { + std::unique_ptr runner = nullptr; + using AType = typename TETypeToCKType::type; + using BType = typename TETypeToCKType::type; + using CLayout = RowMajor; + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { + using CType = typename TETypeToCKType::type; + if (ctx.N % 256 == 0) { + if constexpr (std::is_same_v && std::is_same_v) { + using TileCfg = TileCfg_256x256x128_k16; + if (ctx.accumulate) { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } else { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } + } else { + using TileCfg = TileCfg_256x256x128; + if (ctx.accumulate) { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } else { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } + } + } else if (ctx.N % 128 == 0) { + if constexpr (std::is_same_v && std::is_same_v) { + using TileCfg = TileCfg_256x128x128_k16; + if (ctx.accumulate) { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } else { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } + } else { + using TileCfg = TileCfg_256x128x128; + if (ctx.accumulate) { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } else { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } + } + } else { + if constexpr (std::is_same_v && std::is_same_v) { + using TileCfg = TileCfg_256x128x128_k16_padding; + if (ctx.accumulate) { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } else { + using Runner = QuantGroupedGemmRunner; + runner = 
std::make_unique(); + } + } else { + using TileCfg = TileCfg_256x128x128_padding; + if (ctx.accumulate) { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } else { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } + } + } + }); + return runner; +} + +bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + + const ck_tile::stream_config s{ctx.stream}; + std::unique_ptr runner; + + if (!ctx.transA && !ctx.transB) { + using ALayout = RowMajor; + using BLayout = RowMajor; + + switch (a_dtype) { + case DType::kFloat8E4M3: + if (b_dtype == DType::kFloat8E4M3) { + runner = get_f8_gemm_instance(d_dtype, ctx); + } else if (b_dtype == DType::kFloat8E5M2) { + runner = get_f8_gemm_instance(d_dtype, ctx); + } + break; + + case DType::kFloat8E5M2: + if (b_dtype == DType::kFloat8E4M3) { + runner = get_f8_gemm_instance(d_dtype, ctx); + } else if (b_dtype == DType::kFloat8E5M2) { + runner = get_f8_gemm_instance(d_dtype, ctx); + } + break; + + default: + break; + } + } else if (!ctx.transA && ctx.transB) { + using ALayout = RowMajor; + using BLayout = ColMajor; + + switch (a_dtype) { + case DType::kFloat8E4M3: + if (b_dtype == DType::kFloat8E4M3) { + runner = get_f8_gemm_instance(d_dtype, ctx); + } else if (b_dtype == DType::kFloat8E5M2) { + runner = get_f8_gemm_instance(d_dtype, ctx); + } + break; + + case DType::kFloat8E5M2: + if (b_dtype == DType::kFloat8E4M3) { + runner = get_f8_gemm_instance(d_dtype, ctx); + } else if (b_dtype == DType::kFloat8E5M2) { + runner = get_f8_gemm_instance(d_dtype, ctx); + } + break; + + default: + break; + } + } else if (ctx.transA && !ctx.transB) { + using ALayout = ColMajor; + using BLayout = RowMajor; + + switch (a_dtype) { + case DType::kFloat8E4M3: + if (b_dtype == DType::kFloat8E4M3) { + runner = get_f8_gemm_instance(d_dtype, ctx); + } else if (b_dtype == DType::kFloat8E5M2) { + runner = get_f8_gemm_instance(d_dtype, ctx); 
+ } + break; + + case DType::kFloat8E5M2: + if (b_dtype == DType::kFloat8E4M3) { + runner = get_f8_gemm_instance(d_dtype, ctx); + } else if (b_dtype == DType::kFloat8E5M2) { + runner = get_f8_gemm_instance(d_dtype, ctx); + } + break; + + default: + break; + } + } else { + return false; + } + + if (runner != nullptr) { + return runner->run(s, ctx); + } else { + return false; + } +} + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file From af953822eaeb7408d282716f31bc1d4c6dbd7561 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Mon, 9 Mar 2026 19:35:10 +0000 Subject: [PATCH 45/51] Add explicit template instantiations for CK grouped GEMM runners Introduce extern template declarations and dedicated instantiation translation units for GroupedGemmRunner and QuantGroupedGemmRunner. This moves template instantiation for the supported dtype/layout/tile combinations into separate compilation units to reduce duplicate template instantiation across translation units and better isolate kernel codegen. No functional changes; this is purely a build/compile-time refactor. 
--- transformer_engine/common/CMakeLists.txt | 2 + .../common/gemm/ck_grouped_gemm_common.h | 611 ++++++++++++++++++ .../ck_grouped_gemm_fp16_instantiations.cpp | 43 ++ .../ck_grouped_gemm_fp8_instantiations.cpp | 79 +++ 4 files changed, 735 insertions(+) create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm_common.h create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm_fp16_instantiations.cpp create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm_fp8_instantiations.cpp diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 2d33215c3..594b55c9f 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -205,6 +205,8 @@ else() gemm/ck_grouped_gemm.cpp gemm/ck_grouped_gemm_fp16.cpp gemm/ck_grouped_gemm_fp8.cpp + gemm/ck_grouped_gemm_fp16_instantiations.cpp + gemm/ck_grouped_gemm_fp8_instantiations.cpp amd_detail/system.cpp) # process source code files diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_common.h b/transformer_engine/common/gemm/ck_grouped_gemm_common.h new file mode 100644 index 000000000..f0c61639a --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm_common.h @@ -0,0 +1,611 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + +#pragma once + +#include + +#include +#include +#include +#include + +#include +#include "../common.h" + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp" +#include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp" + +namespace transformer_engine { +namespace grouped_gemm { + +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; + +template struct TETypeToCKType; +template <> struct TETypeToCKType { using type = float; }; +template <> struct TETypeToCKType { using type = ck_tile::fp8_t; }; +template <> struct TETypeToCKType { using type = ck_tile::bf8_t; }; +template <> struct TETypeToCKType { using type = ck_tile::half_t; }; +template <> struct TETypeToCKType { using type = ck_tile::bfloat16_t; }; + +// Treat TE tensors as generalized 2D matrices by flattening: +// (D1, D2, ..., Dn) -> (D1*...*D(n-1), Dn), consistent with TE Tensor::flat_*_dim. 
+static inline bool get_flat_2d_dims(const transformer_engine::Tensor& t, + int64_t& d0, int64_t& d1) { + if (t.shape().size() < 2) { + return false; + } + d0 = static_cast(t.flat_first_dim()); + d1 = static_cast(t.flat_last_dim()); + return true; +} + +static inline const transformer_engine::SimpleTensor& data_view(const transformer_engine::Tensor& t) { + return t.data; +} + +static inline const transformer_engine::SimpleTensor& scale_inv_view(const transformer_engine::Tensor& t) { + return t.scale_inv; +} + +struct GroupedGemmRunContext { + const NVTETensor* A = nullptr; + const NVTETensor* B = nullptr; + NVTETensor* D = nullptr; + int64_t N = 0; + + int group_num = 0; + bool transA = false; + bool transB = false; + bool accumulate = false; + + void* workspace = nullptr; + size_t workspace_bytes = 0; + hipStream_t stream = nullptr; +}; + +bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx); + +bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx); + +// ------------------------- +// Tile configs: FP16/BF16 +// ------------------------- + +struct TileCfg_256x256x64 { + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; +}; + +struct 
TileCfg_256x128x64 : TileCfg_256x256x64 { + static constexpr ck_tile::index_t N_Tile = 128; +}; + +struct TileCfg_256x128x64_padding : TileCfg_256x128x64 { + static constexpr bool kPadN = true; +}; + +// ------------------------- +// Tile configs: FP8/BF8 +// ------------------------- + +struct TileCfg_256x256x128 { + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 32; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; +}; + +struct TileCfg_256x128x128 : TileCfg_256x256x128 { + static constexpr ck_tile::index_t N_Tile = 128; +}; + +struct TileCfg_256x128x128_padding : TileCfg_256x128x128 { + static constexpr bool kPadN = true; +}; + +// ------------------------- +// Fallback FP8/BF8 tile family for normalized (bf8_t, fp8_t) pair. 
+// ------------------------- + +struct TileCfg_256x256x128_k16 { + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; +}; + +struct TileCfg_256x128x128_k16 : TileCfg_256x256x128_k16 { + static constexpr ck_tile::index_t N_Tile = 128; +}; + +struct TileCfg_256x128x128_k16_padding : TileCfg_256x128x128_k16 { + static constexpr bool kPadN = true; +}; + +// ------------------------- +// CK runner +// ------------------------- +class RunnerInterface { +public: + virtual ~RunnerInterface() = default; + virtual bool run(const ck_tile::stream_config& stream_cfg, + const GroupedGemmRunContext& ctx) = 0; +}; + +template +class GroupedGemmRunner : public RunnerInterface { +public: + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>; + + using UniversalTraits = + ck_tile::PersistentTileGemmUniversalTraits< + TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, + TileCfg::DoubleSmemBuffer, ALayout, BLayout, CLayout>; + + static constexpr ck_tile::GemmPipelineScheduler Scheduler = + ck_tile::GemmPipelineScheduler::Intrawave; + + using Problem = ck_tile::UniversalGemmPipelineProblem< + AType, 
BType, AccType, + GemmShape, UniversalTraits, Scheduler>; + + using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + using Epilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem< + AType, BType, ck_tile::tuple<>, AccType, + CType, ck_tile::tuple<>, CLayout, + ck_tile::element_wise::PassThrough, + Partitioner::MPerBlock, Partitioner::NPerBlock, + TileCfg::M_Warp, TileCfg::N_Warp, + TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile, + Problem::TransposeC, MemOp>>; + + using Kernel = ck_tile::GroupedGemmKernel; + + using HostArgs = ck_tile::GroupedGemmHostArgs<0>; + +public: + static std::vector build_descs(const GroupedGemmRunContext& ctx) { + const size_t needed = Kernel::GetWorkSpaceSize(ctx.group_num); + if (!ctx.workspace || ctx.workspace_bytes < needed) { + NVTE_ERROR("ck_tile_grouped_gemm: insufficient workspace. Needed bytes=", needed); + } + std::vector descs; + descs.reserve(ctx.group_num); + for (int i = 0; i < ctx.group_num; ++i) { + const transformer_engine::Tensor* const A_te = + transformer_engine::convertNVTETensorCheck(ctx.A[i]); + const transformer_engine::Tensor* const B_te = + transformer_engine::convertNVTETensorCheck(ctx.B[i]); + transformer_engine::Tensor* D_te = + transformer_engine::convertNVTETensorCheck(ctx.D[i]); + + const auto& a = data_view(*A_te); + const auto& b = data_view(*B_te); + const auto& d = data_view(*D_te); + + int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; + if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || + !get_flat_2d_dims(*B_te, Bd0, Bd1) || + !get_flat_2d_dims(*D_te, Dd0, Dd1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected all groups to be rank>=2."); + } + + const int64_t M = ctx.transA ? Ad1 : Ad0; + const int64_t K = ctx.transA ? Ad0 : Ad1; + const int64_t N = ctx.transB ? Bd0 : Bd1; + const int64_t Kb = ctx.transB ? 
Bd1 : Bd0; + + if (Kb != K) { + NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i); + } + + if (Dd0 != M || Dd1 != N) { + NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i); + } + + const ck_tile::index_t stride_A = Ad1; + const ck_tile::index_t stride_B = Bd1; + const ck_tile::index_t stride_E = Dd1; + + descs.emplace_back( + a.dptr, + b.dptr, + std::array{}, + d.dptr, + 1, + M, + N, + K, + stride_A, + stride_B, + std::array{}, + stride_E); + } + + return descs; + }; + + + bool run(const ck_tile::stream_config& stream_cfg, + const GroupedGemmRunContext& ctx) override { + auto descs = build_descs(ctx); + + constexpr int kBlockPerCu = 1; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(descs); + auto kargs = Kernel::MakeKargs(descs); + if (!Kernel::IsSupportedArgument(kargs)) { + NVTE_ERROR("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config."); + } + + HIP_CHECK_ERROR(hipMemcpyAsync(ctx.workspace, + kargs.data(), + kargs.size() * sizeof(typename decltype(kargs)::value_type), + hipMemcpyHostToDevice, + ctx.stream)); + + ck_tile::launch_kernel( + stream_cfg, ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, + ck_tile::cast_pointer_to_constant_address_space(ctx.workspace), + ctx.group_num)); + return true; + }; +}; + +template +class QuantGroupedGemmRunner : public RunnerInterface { +public: + // hard-coded for tensor quant for now + static constexpr ck_tile::QuantType QuantMode = ck_tile::QuantType::TensorQuant; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>; + + using AQLayout = RowMajor; + using BQLayout = RowMajor; + + using UniversalTraits = + ck_tile::TileGemmQuantTraits< + TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, + false, false, ALayout, BLayout, 
CLayout, + QuantMode, AQLayout, BQLayout, + false, TileCfg::DoubleSmemBuffer, false>; + + using Problem = ck_tile::GemmRowColTensorQuantPipelineProblem< + AType, BType, AccType, + AccType, GemmShape, UniversalTraits, + false, AccType>; + + using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + using Epilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem< + AType, BType, ck_tile::tuple<>, AccType, + CType, ck_tile::tuple<>, CLayout, + ck_tile::element_wise::PassThrough, + Partitioner::MPerBlock, Partitioner::NPerBlock, + TileCfg::M_Warp, TileCfg::N_Warp, + TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile, + Problem::TransposeC, MemOp>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + + using HostArgs = ck_tile::QuantGroupedGemmHostArgs; + +public: + static std::vector build_descs(const GroupedGemmRunContext& ctx) { + const size_t needed = Kernel::GetWorkSpaceSize(ctx.group_num); + if (!ctx.workspace || ctx.workspace_bytes < needed) { + NVTE_ERROR("ck_tile_grouped_gemm: insufficient workspace. Needed bytes=", needed); + } + std::vector descs; + descs.reserve(ctx.group_num); + for (int i = 0; i < ctx.group_num; ++i) { + const transformer_engine::Tensor* const A_te = + transformer_engine::convertNVTETensorCheck(ctx.A[i]); + const transformer_engine::Tensor* const B_te = + transformer_engine::convertNVTETensorCheck(ctx.B[i]); + transformer_engine::Tensor* D_te = + transformer_engine::convertNVTETensorCheck(ctx.D[i]); + + const auto& a = data_view(*A_te); + const auto& b = data_view(*B_te); + const auto& d = data_view(*D_te); + + int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; + if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || + !get_flat_2d_dims(*B_te, Bd0, Bd1) || + !get_flat_2d_dims(*D_te, Dd0, Dd1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected all groups to be rank>=2."); + } + + const int64_t M = ctx.transA ? Ad1 : Ad0; + const int64_t K = ctx.transA ? Ad0 : Ad1; + const int64_t N = ctx.transB ? 
Bd0 : Bd1; + const int64_t Kb = ctx.transB ? Bd1 : Bd0; + + if (Kb != K) { + NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i); + } + + if (Dd0 != M || Dd1 != N) { + NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i); + } + + const ck_tile::index_t stride_A = Ad1; + const ck_tile::index_t stride_B = Bd1; + const ck_tile::index_t stride_E = Dd1; + + // Hard-coded to tensor quant for the moment + ck_tile::index_t AQK = 1; + ck_tile::index_t BQK = 1; + ck_tile::index_t stride_AQ = 1; + ck_tile::index_t stride_BQ = 1; + + const auto& aq = scale_inv_view(*A_te); + const auto& bq = scale_inv_view(*B_te); + + descs.emplace_back( + a.dptr, + b.dptr, + d.dptr, + aq.dptr, + bq.dptr, + 1, + M, + N, + K, + AQK, + BQK, + stride_A, + stride_B, + stride_E, + stride_AQ, + stride_BQ); + } + + return descs; + }; + bool run(const ck_tile::stream_config& stream_cfg, + const GroupedGemmRunContext& ctx) override { + auto descs = build_descs(ctx); + + constexpr int kBlockPerCu = 1; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(descs); + auto kargs = Kernel::MakeKargs(descs); + if (!Kernel::IsSupportedArgument(kargs)) { + NVTE_ERROR("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config."); + } + + HIP_CHECK_ERROR(hipMemcpyAsync(ctx.workspace, + kargs.data(), + kargs.size() * sizeof(typename decltype(kargs)::value_type), + hipMemcpyHostToDevice, + ctx.stream)); + + ck_tile::launch_kernel( + stream_cfg, ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, + ck_tile::cast_pointer_to_constant_address_space(ctx.workspace), + ctx.group_num)); + return true; + }; +}; + +// Primus-Turbo-style extern template declarations +#define DECL_CK_GG_RUNNER_EXTERN(AType, BType, CType, ALayout, BLayout, CLayout, TileCfg, MemOp) \ + extern template class GroupedGemmRunner; + +#define DECL_CK_GG_RUNNER(AType, BType, CType, ALayout, BLayout, CLayout, TileCfg, MemOp) \ + template class GroupedGemmRunner; + 
+#define DECL_CK_QUANT_GG_RUNNER_EXTERN(AType, BType, CType, ALayout, BLayout, CLayout, TileCfg, MemOp) \ + extern template class QuantGroupedGemmRunner; + +#define DECL_CK_QUANT_GG_RUNNER(AType, BType, CType, ALayout, BLayout, CLayout, TileCfg, MemOp) \ + template class QuantGroupedGemmRunner; + +#define APPLY_CK_GG_ALL_LAYOUT(MACRO, AType, BType, CType, TileCfg, MemOp) \ + MACRO(AType, BType, CType, RowMajor, ColMajor, RowMajor, TileCfg, MemOp) \ + MACRO(AType, BType, CType, RowMajor, RowMajor, RowMajor, TileCfg, MemOp) \ + MACRO(AType, BType, CType, ColMajor, RowMajor, RowMajor, TileCfg, MemOp) + +// FP16 * FP16 = FP16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP16 * FP16 = FP32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, 
ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) + +// BF16 * BF16 = FP16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) + +// BF16 * BF16 = FP32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) 
+APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E4M3 * FP8_E4M3 = FP16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E4M3 * FP8_E4M3 = FP32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, 
TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E5M2 * FP8_E5M2 = FP16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E5M2 * FP8_E5M2 = FP32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) 
+APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E5M2 * FP8_E4M3 = F16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E5M2 * FP8_E4M3 = F32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, 
float, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E4M3 * FP8_E5M2 = F16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E4M3 * FP8_E5M2 = F32 
+APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp16_instantiations.cpp b/transformer_engine/common/gemm/ck_grouped_gemm_fp16_instantiations.cpp new file mode 100644 index 000000000..a143666ae --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp16_instantiations.cpp @@ -0,0 +1,43 @@ +#include "ck_grouped_gemm_common.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP16 * FP16 = FP16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, 
ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP16 * FP16 = FP32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) + +// BF16 * BF16 = FP16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x256x64, 
ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) + +// BF16 * BF16 = FP32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_instantiations.cpp b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_instantiations.cpp new file mode 100644 index 000000000..4c411e78c --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_instantiations.cpp @@ -0,0 +1,79 @@ +#include "ck_grouped_gemm_common.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP8_E4M3 * FP8_E4M3 = FP16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128, 
ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E4M3 * FP8_E4M3 = FP32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E5M2 * FP8_E5M2 = FP16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) 
+APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E5M2 * FP8_E5M2 = FP32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E5M2 * FP8_E4M3 = F16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::set) 
+APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E5M2 * FP8_E4M3 = F32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E4M3 * FP8_E5M2 = F16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) 
+APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E4M3 * FP8_E5M2 = F32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file From 9990db351caa96e8d9d24b20d6a0c4f7801a1256 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Tue, 10 Mar 2026 15:37:30 +0000 
Subject: [PATCH 46/51] Split CK grouped GEMM implementation to reduce compile-time coupling Break up the previous monolithic templated header into separate FP16 and FP8 implementation paths. Key changes: - Introduce a lightweight common header (ck_grouped_gemm_common.h) containing shared utilities, the run context, and the RunnerInterface abstraction. - Move heavy template definitions into dtype-specific implementation headers: ck_grouped_gemm_fp16_impl.h and ck_grouped_gemm_fp8_impl.h. - Add FP16/FP8 factory source files responsible for constructing the correct runner instances based on dtype/layout/tile configuration. - Keep dispatch entry points thin and dependent only on the lightweight header. This isolates the heavy CK template code to a smaller number of translation units and prevents unnecessary template parsing across the codebase, improving build scalability without changing runtime behavior. --- transformer_engine/common/CMakeLists.txt | 6 +- .../common/gemm/ck_grouped_gemm_common.h | 572 +----------------- .../common/gemm/ck_grouped_gemm_fp16.cpp | 132 +--- .../gemm/ck_grouped_gemm_fp16_factory.cpp | 138 +++++ .../common/gemm/ck_grouped_gemm_fp16_impl.h | 275 +++++++++ .../ck_grouped_gemm_fp16_instantiations.cpp | 20 +- .../common/gemm/ck_grouped_gemm_fp8.cpp | 188 +----- .../gemm/ck_grouped_gemm_fp8_factory.cpp | 194 ++++++ .../common/gemm/ck_grouped_gemm_fp8_impl.h | 383 ++++++++++++ .../ck_grouped_gemm_fp8_instantiations.cpp | 38 +- 10 files changed, 1081 insertions(+), 865 deletions(-) create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm_fp16_factory.cpp create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm_fp16_impl.h create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory.cpp create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm_fp8_impl.h diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 594b55c9f..add2ed455 100644 --- 
a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -203,10 +203,12 @@ else() fused_attn_rocm/utils.cpp gemm/rocm_gemm.cu gemm/ck_grouped_gemm.cpp - gemm/ck_grouped_gemm_fp16.cpp gemm/ck_grouped_gemm_fp8.cpp - gemm/ck_grouped_gemm_fp16_instantiations.cpp + gemm/ck_grouped_gemm_fp8_factory.cpp gemm/ck_grouped_gemm_fp8_instantiations.cpp + gemm/ck_grouped_gemm_fp16.cpp + gemm/ck_grouped_gemm_fp16_factory.cpp + gemm/ck_grouped_gemm_fp16_instantiations.cpp amd_detail/system.cpp) # process source code files diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_common.h b/transformer_engine/common/gemm/ck_grouped_gemm_common.h index f0c61639a..5a567da3d 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm_common.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm_common.h @@ -17,47 +17,11 @@ #include "../common.h" #include "ck_tile/core.hpp" -#include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" -#include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp" -#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" -#include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp" -#include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp" - namespace transformer_engine { namespace grouped_gemm { -using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; -using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; - -template struct TETypeToCKType; -template <> struct TETypeToCKType { using type = float; }; -template <> struct TETypeToCKType { using type = ck_tile::fp8_t; }; -template <> struct TETypeToCKType { using type = ck_tile::bf8_t; }; -template <> struct TETypeToCKType { using type = ck_tile::half_t; }; -template <> struct TETypeToCKType { using type = ck_tile::bfloat16_t; }; - -// Treat TE tensors as generalized 2D matrices by flattening: -// (D1, D2, ..., Dn) -> (D1*...*D(n-1), Dn), consistent with TE Tensor::flat_*_dim. 
-static inline bool get_flat_2d_dims(const transformer_engine::Tensor& t, - int64_t& d0, int64_t& d1) { - if (t.shape().size() < 2) { - return false; - } - d0 = static_cast(t.flat_first_dim()); - d1 = static_cast(t.flat_last_dim()); - return true; -} - -static inline const transformer_engine::SimpleTensor& data_view(const transformer_engine::Tensor& t) { - return t.data; -} - -static inline const transformer_engine::SimpleTensor& scale_inv_view(const transformer_engine::Tensor& t) { - return t.scale_inv; -} - struct GroupedGemmRunContext { const NVTETensor* A = nullptr; const NVTETensor* B = nullptr; @@ -74,6 +38,18 @@ struct GroupedGemmRunContext { hipStream_t stream = nullptr; }; +// Treat TE tensors as generalized 2D matrices by flattening: +// (D1, D2, ..., Dn) -> (D1*...*D(n-1), Dn), consistent with TE Tensor::flat_*_dim. +static inline bool get_flat_2d_dims(const transformer_engine::Tensor& t, + int64_t& d0, int64_t& d1) { + if (t.shape().size() < 2) { + return false; + } + d0 = static_cast(t.flat_first_dim()); + d1 = static_cast(t.flat_last_dim()); + return true; +} + bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, DType b_dtype, DType d_dtype, @@ -84,114 +60,6 @@ bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype, DType d_dtype, const GroupedGemmRunContext& ctx); -// ------------------------- -// Tile configs: FP16/BF16 -// ------------------------- - -struct TileCfg_256x256x64 { - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 64; - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 16; - - static constexpr bool kPadM = false; - static constexpr bool kPadN = false; - static constexpr bool kPadK = false; - 
- static constexpr bool DoubleSmemBuffer = false; - - static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; - static constexpr ck_tile::index_t TilePartitionerM01 = 4; -}; - -struct TileCfg_256x128x64 : TileCfg_256x256x64 { - static constexpr ck_tile::index_t N_Tile = 128; -}; - -struct TileCfg_256x128x64_padding : TileCfg_256x128x64 { - static constexpr bool kPadN = true; -}; - -// ------------------------- -// Tile configs: FP8/BF8 -// ------------------------- - -struct TileCfg_256x256x128 { - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 128; - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 32; - - static constexpr bool kPadM = false; - static constexpr bool kPadN = false; - static constexpr bool kPadK = false; - - static constexpr bool DoubleSmemBuffer = false; - - static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; - static constexpr ck_tile::index_t TilePartitionerM01 = 4; -}; - -struct TileCfg_256x128x128 : TileCfg_256x256x128 { - static constexpr ck_tile::index_t N_Tile = 128; -}; - -struct TileCfg_256x128x128_padding : TileCfg_256x128x128 { - static constexpr bool kPadN = true; -}; - -// ------------------------- -// Fallback FP8/BF8 tile family for normalized (bf8_t, fp8_t) pair. 
-// ------------------------- - -struct TileCfg_256x256x128_k16 { - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 128; - - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 16; - - static constexpr bool kPadM = false; - static constexpr bool kPadN = false; - static constexpr bool kPadK = false; - - static constexpr bool DoubleSmemBuffer = false; - - static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; - static constexpr ck_tile::index_t TilePartitionerM01 = 4; -}; - -struct TileCfg_256x128x128_k16 : TileCfg_256x256x128_k16 { - static constexpr ck_tile::index_t N_Tile = 128; -}; - -struct TileCfg_256x128x128_k16_padding : TileCfg_256x128x128_k16 { - static constexpr bool kPadN = true; -}; - -// ------------------------- -// CK runner -// ------------------------- class RunnerInterface { public: virtual ~RunnerInterface() = default; @@ -199,413 +67,15 @@ class RunnerInterface { const GroupedGemmRunContext& ctx) = 0; }; -template -class GroupedGemmRunner : public RunnerInterface { -public: - using GemmShape = ck_tile::TileGemmShape< - ck_tile::sequence, - ck_tile::sequence, - ck_tile::sequence>; - - using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< - GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>; - - using UniversalTraits = - ck_tile::PersistentTileGemmUniversalTraits< - TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, - TileCfg::DoubleSmemBuffer, ALayout, BLayout, CLayout>; - - static constexpr ck_tile::GemmPipelineScheduler Scheduler = - ck_tile::GemmPipelineScheduler::Intrawave; - - using Problem = ck_tile::UniversalGemmPipelineProblem< - AType, BType, AccType, - 
GemmShape, UniversalTraits, Scheduler>; - - using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; - - using Epilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem< - AType, BType, ck_tile::tuple<>, AccType, - CType, ck_tile::tuple<>, CLayout, - ck_tile::element_wise::PassThrough, - Partitioner::MPerBlock, Partitioner::NPerBlock, - TileCfg::M_Warp, TileCfg::N_Warp, - TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile, - Problem::TransposeC, MemOp>>; - - using Kernel = ck_tile::GroupedGemmKernel; - - using HostArgs = ck_tile::GroupedGemmHostArgs<0>; - -public: - static std::vector build_descs(const GroupedGemmRunContext& ctx) { - const size_t needed = Kernel::GetWorkSpaceSize(ctx.group_num); - if (!ctx.workspace || ctx.workspace_bytes < needed) { - NVTE_ERROR("ck_tile_grouped_gemm: insufficient workspace. Needed bytes=", needed); - } - std::vector descs; - descs.reserve(ctx.group_num); - for (int i = 0; i < ctx.group_num; ++i) { - const transformer_engine::Tensor* const A_te = - transformer_engine::convertNVTETensorCheck(ctx.A[i]); - const transformer_engine::Tensor* const B_te = - transformer_engine::convertNVTETensorCheck(ctx.B[i]); - transformer_engine::Tensor* D_te = - transformer_engine::convertNVTETensorCheck(ctx.D[i]); - - const auto& a = data_view(*A_te); - const auto& b = data_view(*B_te); - const auto& d = data_view(*D_te); - - int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; - if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || - !get_flat_2d_dims(*B_te, Bd0, Bd1) || - !get_flat_2d_dims(*D_te, Dd0, Dd1)) { - NVTE_ERROR("ck_tile_grouped_gemm: expected all groups to be rank>=2."); - } - - const int64_t M = ctx.transA ? Ad1 : Ad0; - const int64_t K = ctx.transA ? Ad0 : Ad1; - const int64_t N = ctx.transB ? Bd0 : Bd1; - const int64_t Kb = ctx.transB ? 
Bd1 : Bd0; - - if (Kb != K) { - NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i); - } - - if (Dd0 != M || Dd1 != N) { - NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i); - } - - const ck_tile::index_t stride_A = Ad1; - const ck_tile::index_t stride_B = Bd1; - const ck_tile::index_t stride_E = Dd1; - - descs.emplace_back( - a.dptr, - b.dptr, - std::array{}, - d.dptr, - 1, - M, - N, - K, - stride_A, - stride_B, - std::array{}, - stride_E); - } - - return descs; - }; - - - bool run(const ck_tile::stream_config& stream_cfg, - const GroupedGemmRunContext& ctx) override { - auto descs = build_descs(ctx); - - constexpr int kBlockPerCu = 1; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(descs); - auto kargs = Kernel::MakeKargs(descs); - if (!Kernel::IsSupportedArgument(kargs)) { - NVTE_ERROR("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config."); - } - - HIP_CHECK_ERROR(hipMemcpyAsync(ctx.workspace, - kargs.data(), - kargs.size() * sizeof(typename decltype(kargs)::value_type), - hipMemcpyHostToDevice, - ctx.stream)); - - ck_tile::launch_kernel( - stream_cfg, ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, - ck_tile::cast_pointer_to_constant_address_space(ctx.workspace), - ctx.group_num)); - return true; - }; -}; - -template -class QuantGroupedGemmRunner : public RunnerInterface { -public: - // hard-coded for tensor quant for now - static constexpr ck_tile::QuantType QuantMode = ck_tile::QuantType::TensorQuant; - - using GemmShape = ck_tile::TileGemmShape< - ck_tile::sequence, - ck_tile::sequence, - ck_tile::sequence>; - - using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< - GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>; - - using AQLayout = RowMajor; - using BQLayout = RowMajor; - - using UniversalTraits = - ck_tile::TileGemmQuantTraits< - TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, - false, false, ALayout, BLayout, 
CLayout, - QuantMode, AQLayout, BQLayout, - false, TileCfg::DoubleSmemBuffer, false>; - - using Problem = ck_tile::GemmRowColTensorQuantPipelineProblem< - AType, BType, AccType, - AccType, GemmShape, UniversalTraits, - false, AccType>; - - using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; - - using Epilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem< - AType, BType, ck_tile::tuple<>, AccType, - CType, ck_tile::tuple<>, CLayout, - ck_tile::element_wise::PassThrough, - Partitioner::MPerBlock, Partitioner::NPerBlock, - TileCfg::M_Warp, TileCfg::N_Warp, - TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile, - Problem::TransposeC, MemOp>>; - - using Kernel = ck_tile::QuantGroupedGemmKernel; - - using HostArgs = ck_tile::QuantGroupedGemmHostArgs; - -public: - static std::vector build_descs(const GroupedGemmRunContext& ctx) { - const size_t needed = Kernel::GetWorkSpaceSize(ctx.group_num); - if (!ctx.workspace || ctx.workspace_bytes < needed) { - NVTE_ERROR("ck_tile_grouped_gemm: insufficient workspace. Needed bytes=", needed); - } - std::vector descs; - descs.reserve(ctx.group_num); - for (int i = 0; i < ctx.group_num; ++i) { - const transformer_engine::Tensor* const A_te = - transformer_engine::convertNVTETensorCheck(ctx.A[i]); - const transformer_engine::Tensor* const B_te = - transformer_engine::convertNVTETensorCheck(ctx.B[i]); - transformer_engine::Tensor* D_te = - transformer_engine::convertNVTETensorCheck(ctx.D[i]); - - const auto& a = data_view(*A_te); - const auto& b = data_view(*B_te); - const auto& d = data_view(*D_te); - - int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; - if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || - !get_flat_2d_dims(*B_te, Bd0, Bd1) || - !get_flat_2d_dims(*D_te, Dd0, Dd1)) { - NVTE_ERROR("ck_tile_grouped_gemm: expected all groups to be rank>=2."); - } - - const int64_t M = ctx.transA ? Ad1 : Ad0; - const int64_t K = ctx.transA ? Ad0 : Ad1; - const int64_t N = ctx.transB ? 
Bd0 : Bd1; - const int64_t Kb = ctx.transB ? Bd1 : Bd0; - - if (Kb != K) { - NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i); - } - - if (Dd0 != M || Dd1 != N) { - NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i); - } - - const ck_tile::index_t stride_A = Ad1; - const ck_tile::index_t stride_B = Bd1; - const ck_tile::index_t stride_E = Dd1; - - // Hard-coded to tensor quant for the moment - ck_tile::index_t AQK = 1; - ck_tile::index_t BQK = 1; - ck_tile::index_t stride_AQ = 1; - ck_tile::index_t stride_BQ = 1; - - const auto& aq = scale_inv_view(*A_te); - const auto& bq = scale_inv_view(*B_te); - - descs.emplace_back( - a.dptr, - b.dptr, - d.dptr, - aq.dptr, - bq.dptr, - 1, - M, - N, - K, - AQK, - BQK, - stride_A, - stride_B, - stride_E, - stride_AQ, - stride_BQ); - } - - return descs; - }; - bool run(const ck_tile::stream_config& stream_cfg, - const GroupedGemmRunContext& ctx) override { - auto descs = build_descs(ctx); - - constexpr int kBlockPerCu = 1; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(descs); - auto kargs = Kernel::MakeKargs(descs); - if (!Kernel::IsSupportedArgument(kargs)) { - NVTE_ERROR("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config."); - } - - HIP_CHECK_ERROR(hipMemcpyAsync(ctx.workspace, - kargs.data(), - kargs.size() * sizeof(typename decltype(kargs)::value_type), - hipMemcpyHostToDevice, - ctx.stream)); - - ck_tile::launch_kernel( - stream_cfg, ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, - ck_tile::cast_pointer_to_constant_address_space(ctx.workspace), - ctx.group_num)); - return true; - }; -}; - -// Primus-Turbo-style extern template declarations -#define DECL_CK_GG_RUNNER_EXTERN(AType, BType, CType, ALayout, BLayout, CLayout, TileCfg, MemOp) \ - extern template class GroupedGemmRunner; - -#define DECL_CK_GG_RUNNER(AType, BType, CType, ALayout, BLayout, CLayout, TileCfg, MemOp) \ - template class GroupedGemmRunner; - 
-#define DECL_CK_QUANT_GG_RUNNER_EXTERN(AType, BType, CType, ALayout, BLayout, CLayout, TileCfg, MemOp) \ - extern template class QuantGroupedGemmRunner; - -#define DECL_CK_QUANT_GG_RUNNER(AType, BType, CType, ALayout, BLayout, CLayout, TileCfg, MemOp) \ - template class QuantGroupedGemmRunner; - -#define APPLY_CK_GG_ALL_LAYOUT(MACRO, AType, BType, CType, TileCfg, MemOp) \ - MACRO(AType, BType, CType, RowMajor, ColMajor, RowMajor, TileCfg, MemOp) \ - MACRO(AType, BType, CType, RowMajor, RowMajor, RowMajor, TileCfg, MemOp) \ - MACRO(AType, BType, CType, ColMajor, RowMajor, RowMajor, TileCfg, MemOp) - -// FP16 * FP16 = FP16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) - -// FP16 * FP16 = FP32 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, 
ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) - -// BF16 * BF16 = FP16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) - -// BF16 * BF16 = FP32 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) 
-APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) - -// FP8_E4M3 * FP8_E4M3 = FP16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) - -// FP8_E4M3 * FP8_E4M3 = FP32 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, 
TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) - -// FP8_E5M2 * FP8_E5M2 = FP16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) - -// FP8_E5M2 * FP8_E5M2 = FP32 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) 
-APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) - -// FP8_E5M2 * FP8_E4M3 = F16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::atomic_add) - -// FP8_E5M2 * FP8_E4M3 = F32 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, 
float, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::atomic_add) - -// FP8_E4M3 * FP8_E5M2 = F16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) - -// FP8_E4M3 * FP8_E5M2 = F32 
-APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) +std::unique_ptr make_fp8_runner(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx); +std::unique_ptr make_fp16_runner(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx); + } // namespace grouped_gemm } // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp16.cpp b/transformer_engine/common/gemm/ck_grouped_gemm_fp16.cpp index 4dcf35a45..582572680 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm_fp16.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp16.cpp @@ -1,144 +1,22 @@ -/************************************************************************* - * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. - * - * License for AMD contributions = MIT. 
See LICENSE for more information - ************************************************************************/ - #include "ck_grouped_gemm_common.h" namespace transformer_engine { namespace grouped_gemm { -template -std::unique_ptr get_f16_gemm_instance(DType d_dtype, const GroupedGemmRunContext& ctx) { - std::unique_ptr runner = nullptr; - using AType = typename TETypeToCKType::type; - using BType = typename TETypeToCKType::type; - using CLayout = RowMajor; - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { - using CType = typename TETypeToCKType::type; - if (ctx.N % 256 == 0) { - using TileCfg = TileCfg_256x256x64; - if (ctx.accumulate) { - using Runner = GroupedGemmRunner; - runner = std::make_unique(); - } else { - using Runner = GroupedGemmRunner; - runner = std::make_unique(); - } - - } else if (ctx.N % 128 == 0) { - using TileCfg = TileCfg_256x128x64; - if (ctx.accumulate) { - using Runner = GroupedGemmRunner; - runner = std::make_unique(); - } else { - using Runner = GroupedGemmRunner; - runner = std::make_unique(); - } - } else { - using TileCfg = TileCfg_256x128x64_padding; - if (ctx.accumulate) { - using Runner = GroupedGemmRunner; - runner = std::make_unique(); - } else { - using Runner = GroupedGemmRunner; - runner = std::make_unique(); - } - } - }); - return runner; -} - bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, DType b_dtype, DType d_dtype, const GroupedGemmRunContext& ctx) { + const ck_tile::stream_config s{ctx.stream}; - const ck_tile::stream_config s{ctx.stream}; - std::unique_ptr runner; - - if (!ctx.transA && !ctx.transB) { - using ALayout = RowMajor; - using BLayout = RowMajor; - - switch (a_dtype) { - case DType::kFloat16: - if (b_dtype == DType::kFloat16) { - runner = get_f16_gemm_instance(d_dtype, ctx); - } - break; - - case DType::kBFloat16: - if (b_dtype == DType::kBFloat16) { - runner = get_f16_gemm_instance(d_dtype, ctx); - } - break; - - default: - break; - } - } else if (!ctx.transA && ctx.transB) { - using 
ALayout = RowMajor; - using BLayout = ColMajor; - - switch (a_dtype) { - case DType::kFloat16: - if (b_dtype == DType::kFloat16) { - runner = get_f16_gemm_instance(d_dtype, ctx); - } - break; - - case DType::kBFloat16: - if (b_dtype == DType::kBFloat16) { - runner = get_f16_gemm_instance(d_dtype, ctx); - } - break; - - default: - break; - } - } else if (ctx.transA && !ctx.transB) { - using ALayout = ColMajor; - using BLayout = RowMajor; - - switch (a_dtype) { - case DType::kFloat16: - if (b_dtype == DType::kFloat16) { - runner = get_f16_gemm_instance(d_dtype, ctx); - } - break; - - case DType::kBFloat16: - if (b_dtype == DType::kBFloat16) { - runner = get_f16_gemm_instance(d_dtype, ctx); - } - break; + auto runner = make_fp16_runner( + a_dtype, b_dtype, d_dtype, ctx); - default: - break; + if (!runner) { + return false; } - } else { - return false; - } - if (runner != nullptr) { return runner->run(s, ctx); - } else { - return false; - } } } // namespace grouped_gemm diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp16_factory.cpp b/transformer_engine/common/gemm/ck_grouped_gemm_fp16_factory.cpp new file mode 100644 index 000000000..3ba637b09 --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp16_factory.cpp @@ -0,0 +1,138 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + + #include "ck_grouped_gemm_common.h" +#include "ck_grouped_gemm_fp16_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +template +std::unique_ptr make_fp16_runner_typed(DType d_dtype, const GroupedGemmRunContext& ctx) { + std::unique_ptr runner = nullptr; + using AType = typename TETypeToCKType::type; + using BType = typename TETypeToCKType::type; + using CLayout = RowMajor; + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { + using CType = typename TETypeToCKType::type; + if (ctx.N % 256 == 0) { + using TileCfg = TileCfg_256x256x64; + if (ctx.accumulate) { + using Runner = GroupedGemmRunner; + runner = std::make_unique(); + } else { + using Runner = GroupedGemmRunner; + runner = std::make_unique(); + } + + } else if (ctx.N % 128 == 0) { + using TileCfg = TileCfg_256x128x64; + if (ctx.accumulate) { + using Runner = GroupedGemmRunner; + runner = std::make_unique(); + } else { + using Runner = GroupedGemmRunner; + runner = std::make_unique(); + } + } else { + using TileCfg = TileCfg_256x128x64_padding; + if (ctx.accumulate) { + using Runner = GroupedGemmRunner; + runner = std::make_unique(); + } else { + using Runner = GroupedGemmRunner; + runner = std::make_unique(); + } + } + }); + return runner; +} + +std::unique_ptr make_fp16_runner(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + + if (!ctx.transA && !ctx.transB) { + using ALayout = RowMajor; + using BLayout = RowMajor; + + switch (a_dtype) { + case DType::kFloat16: + if (b_dtype == DType::kFloat16) { + return make_fp16_runner_typed(d_dtype, ctx); + } + break; + + case DType::kBFloat16: + if (b_dtype == DType::kBFloat16) { + return make_fp16_runner_typed(d_dtype, ctx); + } + break; + + default: + break; + } + } else if (!ctx.transA && ctx.transB) { + using ALayout = RowMajor; + using BLayout = ColMajor; + + switch (a_dtype) { + 
case DType::kFloat16: + if (b_dtype == DType::kFloat16) { + return make_fp16_runner_typed(d_dtype, ctx); + } + break; + + case DType::kBFloat16: + if (b_dtype == DType::kBFloat16) { + return make_fp16_runner_typed(d_dtype, ctx); + } + break; + + default: + break; + } + } else if (ctx.transA && !ctx.transB) { + using ALayout = ColMajor; + using BLayout = RowMajor; + + switch (a_dtype) { + case DType::kFloat16: + if (b_dtype == DType::kFloat16) { + return make_fp16_runner_typed(d_dtype, ctx); + } + break; + + case DType::kBFloat16: + if (b_dtype == DType::kBFloat16) { + return make_fp16_runner_typed(d_dtype, ctx); + } + break; + + default: + break; + } + } else { + return nullptr; + } + return nullptr; +} + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp16_impl.h b/transformer_engine/common/gemm/ck_grouped_gemm_fp16_impl.h new file mode 100644 index 000000000..b9bab836a --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp16_impl.h @@ -0,0 +1,275 @@ +#pragma once +#include "ck_grouped_gemm_common.h" + +#include +#include +#include +#include + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp" +#include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp" + +namespace transformer_engine { +namespace grouped_gemm { + +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; + +template struct TETypeToCKType; +template <> struct TETypeToCKType { using type = float; }; +template <> struct TETypeToCKType { using type = ck_tile::fp8_t; }; +template <> struct TETypeToCKType { using type = ck_tile::bf8_t; }; 
+template <> struct TETypeToCKType { using type = ck_tile::half_t; }; +template <> struct TETypeToCKType { using type = ck_tile::bfloat16_t; }; + +static inline const transformer_engine::SimpleTensor& data_view(const transformer_engine::Tensor& t) { + return t.data; +} + +static inline const transformer_engine::SimpleTensor& scale_inv_view(const transformer_engine::Tensor& t) { + return t.scale_inv; +} + +// ------------------------- +// Tile configs: FP16/BF16 +// ------------------------- + +struct TileCfg_256x256x64 { + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; +}; + +struct TileCfg_256x128x64 : TileCfg_256x256x64 { + static constexpr ck_tile::index_t N_Tile = 128; +}; + +struct TileCfg_256x128x64_padding : TileCfg_256x128x64 { + static constexpr bool kPadN = true; +}; + +template +class GroupedGemmRunner : public RunnerInterface { +public: + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>; + + using UniversalTraits = + ck_tile::PersistentTileGemmUniversalTraits< + TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, + TileCfg::DoubleSmemBuffer, ALayout, BLayout, 
CLayout>; + + static constexpr ck_tile::GemmPipelineScheduler Scheduler = + ck_tile::GemmPipelineScheduler::Intrawave; + + using Problem = ck_tile::UniversalGemmPipelineProblem< + AType, BType, AccType, + GemmShape, UniversalTraits, Scheduler>; + + using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + using Epilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem< + AType, BType, ck_tile::tuple<>, AccType, + CType, ck_tile::tuple<>, CLayout, + ck_tile::element_wise::PassThrough, + Partitioner::MPerBlock, Partitioner::NPerBlock, + TileCfg::M_Warp, TileCfg::N_Warp, + TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile, + Problem::TransposeC, MemOp>>; + + using Kernel = ck_tile::GroupedGemmKernel; + + using HostArgs = ck_tile::GroupedGemmHostArgs<0>; + +public: + static std::vector build_descs(const GroupedGemmRunContext& ctx) { + const size_t needed = Kernel::GetWorkSpaceSize(ctx.group_num); + if (!ctx.workspace || ctx.workspace_bytes < needed) { + NVTE_ERROR("ck_tile_grouped_gemm: insufficient workspace. Needed bytes=", needed); + } + std::vector descs; + descs.reserve(ctx.group_num); + for (int i = 0; i < ctx.group_num; ++i) { + const transformer_engine::Tensor* const A_te = + transformer_engine::convertNVTETensorCheck(ctx.A[i]); + const transformer_engine::Tensor* const B_te = + transformer_engine::convertNVTETensorCheck(ctx.B[i]); + transformer_engine::Tensor* D_te = + transformer_engine::convertNVTETensorCheck(ctx.D[i]); + + const auto& a = data_view(*A_te); + const auto& b = data_view(*B_te); + const auto& d = data_view(*D_te); + + int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; + if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || + !get_flat_2d_dims(*B_te, Bd0, Bd1) || + !get_flat_2d_dims(*D_te, Dd0, Dd1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected all groups to be rank>=2."); + } + + const int64_t M = ctx.transA ? Ad1 : Ad0; + const int64_t K = ctx.transA ? Ad0 : Ad1; + const int64_t N = ctx.transB ? 
Bd0 : Bd1; + const int64_t Kb = ctx.transB ? Bd1 : Bd0; + + if (Kb != K) { + NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i); + } + + if (Dd0 != M || Dd1 != N) { + NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i); + } + + const ck_tile::index_t stride_A = Ad1; + const ck_tile::index_t stride_B = Bd1; + const ck_tile::index_t stride_E = Dd1; + + descs.emplace_back( + a.dptr, + b.dptr, + std::array{}, + d.dptr, + 1, + M, + N, + K, + stride_A, + stride_B, + std::array{}, + stride_E); + } + + return descs; + }; + + + bool run(const ck_tile::stream_config& stream_cfg, + const GroupedGemmRunContext& ctx) override { + auto descs = build_descs(ctx); + + constexpr int kBlockPerCu = 1; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(descs); + auto kargs = Kernel::MakeKargs(descs); + if (!Kernel::IsSupportedArgument(kargs)) { + NVTE_ERROR("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config."); + } + + HIP_CHECK_ERROR(hipMemcpyAsync(ctx.workspace, + kargs.data(), + kargs.size() * sizeof(typename decltype(kargs)::value_type), + hipMemcpyHostToDevice, + ctx.stream)); + + ck_tile::launch_kernel( + stream_cfg, ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, + ck_tile::cast_pointer_to_constant_address_space(ctx.workspace), + ctx.group_num)); + return true; + }; +}; + +// Primus-Turbo-style extern template declarations +#define DECL_CK_GG_RUNNER_EXTERN(AType, BType, CType, ALayout, BLayout, CLayout, TileCfg, MemOp) \ + extern template class GroupedGemmRunner; + +#define DECL_CK_GG_RUNNER(AType, BType, CType, ALayout, BLayout, CLayout, TileCfg, MemOp) \ + template class GroupedGemmRunner; + +#define APPLY_CK_GG_ALL_LAYOUT(MACRO, AType, BType, CType, TileCfg, MemOp) \ + MACRO(AType, BType, CType, RowMajor, ColMajor, RowMajor, TileCfg, MemOp) \ + MACRO(AType, BType, CType, RowMajor, RowMajor, RowMajor, TileCfg, MemOp) \ + MACRO(AType, BType, CType, ColMajor, RowMajor, RowMajor, 
TileCfg, MemOp) + +// FP16 * FP16 = FP16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP16 * FP16 = FP32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP16 * FP16 = 
BF16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, ck_tile::bfloat16_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, ck_tile::bfloat16_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, ck_tile::bfloat16_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, ck_tile::bfloat16_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, ck_tile::bfloat16_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::half_t, ck_tile::half_t, ck_tile::bfloat16_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) + +// BF16 * BF16 = FP16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, 
ck_tile::half_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) + +// BF16 * BF16 = FP32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) + +// BF16 * BF16 = BF16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::bfloat16_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::bfloat16_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::bfloat16_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::bfloat16_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::bfloat16_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) 
+APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::bfloat16_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) + +} +} \ No newline at end of file diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp16_instantiations.cpp b/transformer_engine/common/gemm/ck_grouped_gemm_fp16_instantiations.cpp index a143666ae..2179e3ffe 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm_fp16_instantiations.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp16_instantiations.cpp @@ -1,4 +1,4 @@ -#include "ck_grouped_gemm_common.h" +#include "ck_grouped_gemm_fp16_impl.h" namespace transformer_engine { namespace grouped_gemm { @@ -21,6 +21,15 @@ APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, floa APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) +// FP16 * FP16 = BF16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::bfloat16_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::bfloat16_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::bfloat16_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::bfloat16_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::bfloat16_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, 
ck_tile::half_t, ck_tile::bfloat16_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) + // BF16 * BF16 = FP16 APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) @@ -39,5 +48,14 @@ APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16 APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) +// BF16 * BF16 = BF16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::bfloat16_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::bfloat16_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::bfloat16_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::bfloat16_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::bfloat16_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::bfloat16_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) + } // namespace grouped_gemm } // namespace transformer_engine \ No newline at end of file 
diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp8.cpp b/transformer_engine/common/gemm/ck_grouped_gemm_fp8.cpp index 9facadd69..427f2dd67 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm_fp8.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp8.cpp @@ -1,200 +1,22 @@ -/************************************************************************* - * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. - * - * License for AMD contributions = MIT. See LICENSE for more information - ************************************************************************/ - #include "ck_grouped_gemm_common.h" namespace transformer_engine { namespace grouped_gemm { -template -std::unique_ptr get_f8_gemm_instance(DType d_dtype, const GroupedGemmRunContext& ctx) { - std::unique_ptr runner = nullptr; - using AType = typename TETypeToCKType::type; - using BType = typename TETypeToCKType::type; - using CLayout = RowMajor; - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { - using CType = typename TETypeToCKType::type; - if (ctx.N % 256 == 0) { - if constexpr (std::is_same_v && std::is_same_v) { - using TileCfg = TileCfg_256x256x128_k16; - if (ctx.accumulate) { - using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); - } else { - using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); - } - } else { - using TileCfg = TileCfg_256x256x128; - if (ctx.accumulate) { - using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); - } else { - using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); - } - } - } else if (ctx.N % 128 == 0) { - if constexpr (std::is_same_v && std::is_same_v) { - using TileCfg = TileCfg_256x128x128_k16; - if (ctx.accumulate) { - using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); - } else { - using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); - } - } else { - using TileCfg = TileCfg_256x128x128; - if (ctx.accumulate) { - 
using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); - } else { - using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); - } - } - } else { - if constexpr (std::is_same_v && std::is_same_v) { - using TileCfg = TileCfg_256x128x128_k16_padding; - if (ctx.accumulate) { - using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); - } else { - using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); - } - } else { - using TileCfg = TileCfg_256x128x128_padding; - if (ctx.accumulate) { - using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); - } else { - using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); - } - } - } - }); - return runner; -} - bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype, DType b_dtype, DType d_dtype, const GroupedGemmRunContext& ctx) { + const ck_tile::stream_config s{ctx.stream}; - const ck_tile::stream_config s{ctx.stream}; - std::unique_ptr runner; - - if (!ctx.transA && !ctx.transB) { - using ALayout = RowMajor; - using BLayout = RowMajor; - - switch (a_dtype) { - case DType::kFloat8E4M3: - if (b_dtype == DType::kFloat8E4M3) { - runner = get_f8_gemm_instance(d_dtype, ctx); - } else if (b_dtype == DType::kFloat8E5M2) { - runner = get_f8_gemm_instance(d_dtype, ctx); - } - break; - - case DType::kFloat8E5M2: - if (b_dtype == DType::kFloat8E4M3) { - runner = get_f8_gemm_instance(d_dtype, ctx); - } else if (b_dtype == DType::kFloat8E5M2) { - runner = get_f8_gemm_instance(d_dtype, ctx); - } - break; - - default: - break; - } - } else if (!ctx.transA && ctx.transB) { - using ALayout = RowMajor; - using BLayout = ColMajor; - - switch (a_dtype) { - case DType::kFloat8E4M3: - if (b_dtype == DType::kFloat8E4M3) { - runner = get_f8_gemm_instance(d_dtype, ctx); - } else if (b_dtype == DType::kFloat8E5M2) { - runner = get_f8_gemm_instance(d_dtype, ctx); - } - break; - - case DType::kFloat8E5M2: - if (b_dtype == DType::kFloat8E4M3) { - runner = 
get_f8_gemm_instance(d_dtype, ctx); - } else if (b_dtype == DType::kFloat8E5M2) { - runner = get_f8_gemm_instance(d_dtype, ctx); - } - break; - - default: - break; - } - } else if (ctx.transA && !ctx.transB) { - using ALayout = ColMajor; - using BLayout = RowMajor; - - switch (a_dtype) { - case DType::kFloat8E4M3: - if (b_dtype == DType::kFloat8E4M3) { - runner = get_f8_gemm_instance(d_dtype, ctx); - } else if (b_dtype == DType::kFloat8E5M2) { - runner = get_f8_gemm_instance(d_dtype, ctx); - } - break; - - case DType::kFloat8E5M2: - if (b_dtype == DType::kFloat8E4M3) { - runner = get_f8_gemm_instance(d_dtype, ctx); - } else if (b_dtype == DType::kFloat8E5M2) { - runner = get_f8_gemm_instance(d_dtype, ctx); - } - break; + auto runner = make_fp8_runner( + a_dtype, b_dtype, d_dtype, ctx); - default: - break; + if (!runner) { + return false; } - } else { - return false; - } - if (runner != nullptr) { return runner->run(s, ctx); - } else { - return false; - } } } // namespace grouped_gemm diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory.cpp b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory.cpp new file mode 100644 index 000000000..717788d61 --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory.cpp @@ -0,0 +1,194 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + +#include "ck_grouped_gemm_common.h" +#include "ck_grouped_gemm_fp8_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +template +std::unique_ptr make_fp8_runner_typed(DType d_dtype, const GroupedGemmRunContext& ctx) { + std::unique_ptr runner = nullptr; + using AType = typename TETypeToCKType::type; + using BType = typename TETypeToCKType::type; + using CLayout = RowMajor; + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { + using CType = typename TETypeToCKType::type; + if (ctx.N % 256 == 0) { + if constexpr (std::is_same_v && std::is_same_v) { + using TileCfg = TileCfg_256x256x128_k16; + if (ctx.accumulate) { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } else { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } + } else { + using TileCfg = TileCfg_256x256x128; + if (ctx.accumulate) { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } else { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } + } + } else if (ctx.N % 128 == 0) { + if constexpr (std::is_same_v && std::is_same_v) { + using TileCfg = TileCfg_256x128x128_k16; + if (ctx.accumulate) { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } else { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } + } else { + using TileCfg = TileCfg_256x128x128; + if (ctx.accumulate) { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } else { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } + } + } else { + if constexpr (std::is_same_v && std::is_same_v) { + using TileCfg = TileCfg_256x128x128_k16_padding; + if (ctx.accumulate) { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } else { + using Runner = QuantGroupedGemmRunner; + runner = 
std::make_unique(); + } + } else { + using TileCfg = TileCfg_256x128x128_padding; + if (ctx.accumulate) { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } else { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } + } + } + }); + return runner; +} + +std::unique_ptr make_fp8_runner(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + + if (!ctx.transA && !ctx.transB) { + using ALayout = RowMajor; + using BLayout = RowMajor; + + switch (a_dtype) { + case DType::kFloat8E4M3: + if (b_dtype == DType::kFloat8E4M3) { + return make_fp8_runner_typed(d_dtype, ctx); + } else if (b_dtype == DType::kFloat8E5M2) { + return make_fp8_runner_typed(d_dtype, ctx); + } + break; + + case DType::kFloat8E5M2: + if (b_dtype == DType::kFloat8E4M3) { + return make_fp8_runner_typed(d_dtype, ctx); + } else if (b_dtype == DType::kFloat8E5M2) { + return make_fp8_runner_typed(d_dtype, ctx); + } + break; + + default: + break; + } + } else if (!ctx.transA && ctx.transB) { + using ALayout = RowMajor; + using BLayout = ColMajor; + + switch (a_dtype) { + case DType::kFloat8E4M3: + if (b_dtype == DType::kFloat8E4M3) { + return make_fp8_runner_typed(d_dtype, ctx); + } else if (b_dtype == DType::kFloat8E5M2) { + return make_fp8_runner_typed(d_dtype, ctx); + } + break; + + case DType::kFloat8E5M2: + if (b_dtype == DType::kFloat8E4M3) { + return make_fp8_runner_typed(d_dtype, ctx); + } else if (b_dtype == DType::kFloat8E5M2) { + return make_fp8_runner_typed(d_dtype, ctx); + } + break; + + default: + break; + } + } else if (ctx.transA && !ctx.transB) { + using ALayout = ColMajor; + using BLayout = RowMajor; + + switch (a_dtype) { + case DType::kFloat8E4M3: + if (b_dtype == DType::kFloat8E4M3) { + return make_fp8_runner_typed(d_dtype, ctx); + } else if (b_dtype == DType::kFloat8E5M2) { + return make_fp8_runner_typed(d_dtype, ctx); + } + break; + + case DType::kFloat8E5M2: + if (b_dtype == DType::kFloat8E4M3) { + return 
make_fp8_runner_typed(d_dtype, ctx); + } else if (b_dtype == DType::kFloat8E5M2) { + return make_fp8_runner_typed(d_dtype, ctx); + } + break; + + default: + break; + } + } else { + return nullptr; + } + return nullptr; +} + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_impl.h b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_impl.h new file mode 100644 index 000000000..9aa8f0347 --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_impl.h @@ -0,0 +1,383 @@ + +#pragma once +#include "ck_grouped_gemm_common.h" + +#include +#include +#include +#include + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp" +#include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp" + +namespace transformer_engine { +namespace grouped_gemm { + +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; + +template struct TETypeToCKType; +template <> struct TETypeToCKType { using type = float; }; +template <> struct TETypeToCKType { using type = ck_tile::fp8_t; }; +template <> struct TETypeToCKType { using type = ck_tile::bf8_t; }; +template <> struct TETypeToCKType { using type = ck_tile::half_t; }; +template <> struct TETypeToCKType { using type = ck_tile::bfloat16_t; }; + +static inline const transformer_engine::SimpleTensor& data_view(const transformer_engine::Tensor& t) { + return t.data; +} + +static inline const transformer_engine::SimpleTensor& scale_inv_view(const transformer_engine::Tensor& t) { + return t.scale_inv; +} + +// ------------------------- +// Tile configs: FP8/BF8 +// ------------------------- 
+ +struct TileCfg_256x256x128 { + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 32; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; +}; + +struct TileCfg_256x128x128 : TileCfg_256x256x128 { + static constexpr ck_tile::index_t N_Tile = 128; +}; + +struct TileCfg_256x128x128_padding : TileCfg_256x128x128 { + static constexpr bool kPadN = true; +}; + +// ------------------------- +// Fallback FP8/BF8 tile family for normalized (bf8_t, fp8_t) pair. 
+// ------------------------- + +struct TileCfg_256x256x128_k16 { + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; +}; + +struct TileCfg_256x128x128_k16 : TileCfg_256x256x128_k16 { + static constexpr ck_tile::index_t N_Tile = 128; +}; + +struct TileCfg_256x128x128_k16_padding : TileCfg_256x128x128_k16 { + static constexpr bool kPadN = true; +}; + +template +class QuantGroupedGemmRunner : public RunnerInterface { +public: + // hard-coded for tensor quant for now + static constexpr ck_tile::QuantType QuantMode = ck_tile::QuantType::TensorQuant; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>; + + using AQLayout = RowMajor; + using BQLayout = RowMajor; + + using UniversalTraits = + ck_tile::TileGemmQuantTraits< + TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, + false, false, ALayout, BLayout, CLayout, + QuantMode, AQLayout, BQLayout, + false, TileCfg::DoubleSmemBuffer, false>; + + using Problem = ck_tile::GemmRowColTensorQuantPipelineProblem< + AType, BType, AccType, + AccType, GemmShape, UniversalTraits, + false, AccType>; + + using Pipeline = 
ck_tile::GemmPipelineAgBgCrCompV3; + + using Epilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem< + AType, BType, ck_tile::tuple<>, AccType, + CType, ck_tile::tuple<>, CLayout, + ck_tile::element_wise::PassThrough, + Partitioner::MPerBlock, Partitioner::NPerBlock, + TileCfg::M_Warp, TileCfg::N_Warp, + TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile, + Problem::TransposeC, MemOp>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + + using HostArgs = ck_tile::QuantGroupedGemmHostArgs; + +public: + static std::vector build_descs(const GroupedGemmRunContext& ctx) { + const size_t needed = Kernel::GetWorkSpaceSize(ctx.group_num); + if (!ctx.workspace || ctx.workspace_bytes < needed) { + NVTE_ERROR("ck_tile_grouped_gemm: insufficient workspace. Needed bytes=", needed); + } + std::vector descs; + descs.reserve(ctx.group_num); + for (int i = 0; i < ctx.group_num; ++i) { + const transformer_engine::Tensor* const A_te = + transformer_engine::convertNVTETensorCheck(ctx.A[i]); + const transformer_engine::Tensor* const B_te = + transformer_engine::convertNVTETensorCheck(ctx.B[i]); + transformer_engine::Tensor* D_te = + transformer_engine::convertNVTETensorCheck(ctx.D[i]); + + const auto& a = data_view(*A_te); + const auto& b = data_view(*B_te); + const auto& d = data_view(*D_te); + + int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; + if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || + !get_flat_2d_dims(*B_te, Bd0, Bd1) || + !get_flat_2d_dims(*D_te, Dd0, Dd1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected all groups to be rank>=2."); + } + + const int64_t M = ctx.transA ? Ad1 : Ad0; + const int64_t K = ctx.transA ? Ad0 : Ad1; + const int64_t N = ctx.transB ? Bd0 : Bd1; + const int64_t Kb = ctx.transB ? 
Bd1 : Bd0; + + if (Kb != K) { + NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i); + } + + if (Dd0 != M || Dd1 != N) { + NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i); + } + + const ck_tile::index_t stride_A = Ad1; + const ck_tile::index_t stride_B = Bd1; + const ck_tile::index_t stride_E = Dd1; + + // Hard-coded to tensor quant for the moment + ck_tile::index_t AQK = 1; + ck_tile::index_t BQK = 1; + ck_tile::index_t stride_AQ = 1; + ck_tile::index_t stride_BQ = 1; + + const auto& aq = scale_inv_view(*A_te); + const auto& bq = scale_inv_view(*B_te); + + descs.emplace_back( + a.dptr, + b.dptr, + d.dptr, + aq.dptr, + bq.dptr, + 1, + M, + N, + K, + AQK, + BQK, + stride_A, + stride_B, + stride_E, + stride_AQ, + stride_BQ); + } + + return descs; + }; + bool run(const ck_tile::stream_config& stream_cfg, + const GroupedGemmRunContext& ctx) override { + auto descs = build_descs(ctx); + + constexpr int kBlockPerCu = 1; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(descs); + auto kargs = Kernel::MakeKargs(descs); + if (!Kernel::IsSupportedArgument(kargs)) { + NVTE_ERROR("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config."); + } + + HIP_CHECK_ERROR(hipMemcpyAsync(ctx.workspace, + kargs.data(), + kargs.size() * sizeof(typename decltype(kargs)::value_type), + hipMemcpyHostToDevice, + ctx.stream)); + + ck_tile::launch_kernel( + stream_cfg, ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, + ck_tile::cast_pointer_to_constant_address_space(ctx.workspace), + ctx.group_num)); + return true; + }; +}; + +// Primus-Turbo-style extern template declarations +#define DECL_CK_QUANT_GG_RUNNER_EXTERN(AType, BType, CType, ALayout, BLayout, CLayout, TileCfg, MemOp) \ + extern template class QuantGroupedGemmRunner; + +#define DECL_CK_QUANT_GG_RUNNER(AType, BType, CType, ALayout, BLayout, CLayout, TileCfg, MemOp) \ + template class QuantGroupedGemmRunner; + +#define 
APPLY_CK_GG_ALL_LAYOUT(MACRO, AType, BType, CType, TileCfg, MemOp) \ + MACRO(AType, BType, CType, RowMajor, ColMajor, RowMajor, TileCfg, MemOp) \ + MACRO(AType, BType, CType, RowMajor, RowMajor, RowMajor, TileCfg, MemOp) \ + MACRO(AType, BType, CType, ColMajor, RowMajor, RowMajor, TileCfg, MemOp) + +// FP8_E4M3 * FP8_E4M3 = FP16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E4M3 * FP8_E4M3 = FP32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x256x128, 
ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E4M3 * FP8_E4M3 = BF16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E5M2 * FP8_E5M2 = FP16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + 
+APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E5M2 * FP8_E5M2 = FP32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E5M2 * FP8_E5M2 = BF16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, 
TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E5M2 * FP8_E4M3 = F16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E5M2 * FP8_E4M3 = F32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16, 
ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E5M2 * FP8_E4M3 = BF16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E4M3 * FP8_E5M2 = F16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) 
+APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E4M3 * FP8_E5M2 = F32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E4M3 * FP8_E5M2 = BF16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, 
TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_instantiations.cpp b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_instantiations.cpp index 4c411e78c..df1ad72c5 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_instantiations.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_instantiations.cpp @@ -1,4 +1,4 @@ -#include "ck_grouped_gemm_common.h" +#include "ck_grouped_gemm_fp8_impl.h" namespace transformer_engine { namespace grouped_gemm { @@ -21,6 +21,15 @@ APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) +// FP8_E4M3 * FP8_E4M3 = BF16 
+APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + // FP8_E5M2 * FP8_E5M2 = FP16 APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) @@ -39,6 +48,15 @@ APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) +// FP8_E5M2 * FP8_E5M2 = BF16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, 
ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + // FP8_E5M2 * FP8_E4M3 = F16 APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::set) APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::set) @@ -57,6 +75,15 @@ APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::atomic_add) APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::atomic_add) +// FP8_E5M2 * FP8_E4M3 = BF16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, 
ck_tile::bfloat16_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::atomic_add) + // FP8_E4M3 * FP8_E5M2 = F16 APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) @@ -75,5 +102,14 @@ APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) +// FP8_E4M3 * FP8_E5M2 = BF16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, 
TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + } // namespace grouped_gemm } // namespace transformer_engine \ No newline at end of file From bff80fe8ddd481d90b52344f0bd9bd58c86fe13e Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Tue, 10 Mar 2026 17:49:06 +0000 Subject: [PATCH 47/51] Split CK grouped GEMM explicit instantiations by operand dtype Refactor the FP8 and FP16 explicit template instantiations into smaller, more manageable translation units. Key changes: - Move explicit instantiations into a new subdirectory: gemm/instantiations/. - Split FP8 and FP16 instantiations across multiple source files organized by operand data type combinations. - Reduce the size of individual translation units to improve build parallelism and avoid long single-TU compile bottlenecks. No functional changes; this is a build-structure refactor only. 
--- transformer_engine/common/CMakeLists.txt | 20 ++- .../ck_grouped_gemm_fp16_instantiations.cpp | 61 ---------- .../ck_grouped_gemm_fp8_instantiations.cpp | 115 ------------------ ...ped_gemm_bf16_bf16_bf16_instantiations.cpp | 16 +++ ...ped_gemm_bf16_bf16_fp16_instantiations.cpp | 16 +++ ...ped_gemm_bf16_bf16_fp32_instantiations.cpp | 16 +++ ...ouped_gemm_bf8_bf8_bf16_instantiations.cpp | 16 +++ ...ouped_gemm_bf8_bf8_fp16_instantiations.cpp | 16 +++ ...ouped_gemm_bf8_bf8_fp32_instantiations.cpp | 16 +++ ...ouped_gemm_bf8_fp8_bf16_instantiations.cpp | 16 +++ ...ouped_gemm_bf8_fp8_fp16_instantiations.cpp | 16 +++ ...ouped_gemm_bf8_fp8_fp32_instantiations.cpp | 16 +++ ...ped_gemm_fp16_fp16_bf16_instantiations.cpp | 16 +++ ...ped_gemm_fp16_fp16_fp16_instantiations.cpp | 16 +++ ...ped_gemm_fp16_fp16_fp32_instantiations.cpp | 16 +++ ...ouped_gemm_fp8_bf8_bf16_instantiations.cpp | 16 +++ ...ouped_gemm_fp8_bf8_fp16_instantiations.cpp | 16 +++ ...ouped_gemm_fp8_bf8_fp32_instantiations.cpp | 16 +++ ...ouped_gemm_fp8_fp8_bf16_instantiations.cpp | 16 +++ ...ouped_gemm_fp8_fp8_fp16_instantiations.cpp | 16 +++ ...ouped_gemm_fp8_fp8_fp32_instantiations.cpp | 16 +++ 21 files changed, 306 insertions(+), 178 deletions(-) delete mode 100644 transformer_engine/common/gemm/ck_grouped_gemm_fp16_instantiations.cpp delete mode 100644 transformer_engine/common/gemm/ck_grouped_gemm_fp8_instantiations.cpp create mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_bf16_instantiations.cpp create mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_fp16_instantiations.cpp create mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_fp32_instantiations.cpp create mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_instantiations.cpp create mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_instantiations.cpp create mode 100644 
transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_instantiations.cpp create mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_instantiations.cpp create mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_instantiations.cpp create mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_instantiations.cpp create mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp16_fp16_bf16_instantiations.cpp create mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp16_fp16_fp16_instantiations.cpp create mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp16_fp16_fp32_instantiations.cpp create mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_instantiations.cpp create mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_instantiations.cpp create mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_instantiations.cpp create mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_instantiations.cpp create mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_instantiations.cpp create mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_instantiations.cpp diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index add2ed455..2c98f59f8 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -205,10 +205,26 @@ else() gemm/ck_grouped_gemm.cpp gemm/ck_grouped_gemm_fp8.cpp gemm/ck_grouped_gemm_fp8_factory.cpp - gemm/ck_grouped_gemm_fp8_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_instantiations.cpp + 
gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_instantiations.cpp gemm/ck_grouped_gemm_fp16.cpp gemm/ck_grouped_gemm_fp16_factory.cpp - gemm/ck_grouped_gemm_fp16_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf16_bf16_bf16_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf16_bf16_fp16_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf16_bf16_fp32_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp16_fp16_bf16_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp16_fp16_fp16_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp16_fp16_fp32_instantiations.cpp amd_detail/system.cpp) # process source code files diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp16_instantiations.cpp b/transformer_engine/common/gemm/ck_grouped_gemm_fp16_instantiations.cpp deleted file mode 100644 index 2179e3ffe..000000000 --- a/transformer_engine/common/gemm/ck_grouped_gemm_fp16_instantiations.cpp +++ /dev/null @@ -1,61 +0,0 @@ -#include "ck_grouped_gemm_fp16_impl.h" - -namespace transformer_engine { -namespace grouped_gemm { - -// FP16 * FP16 = FP16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, 
TileCfg_256x128x64, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) - -// FP16 * FP16 = FP32 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) - -// FP16 * FP16 = BF16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::bfloat16_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::bfloat16_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, 
ck_tile::half_t, ck_tile::half_t, ck_tile::bfloat16_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::bfloat16_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::bfloat16_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::bfloat16_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) - -// BF16 * BF16 = FP16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) - -// BF16 * BF16 = FP32 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, 
ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) - -// BF16 * BF16 = BF16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::bfloat16_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::bfloat16_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::bfloat16_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::bfloat16_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::bfloat16_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::bfloat16_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) - -} // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_instantiations.cpp b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_instantiations.cpp deleted file mode 100644 index df1ad72c5..000000000 --- 
a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_instantiations.cpp +++ /dev/null @@ -1,115 +0,0 @@ -#include "ck_grouped_gemm_fp8_impl.h" - -namespace transformer_engine { -namespace grouped_gemm { - -// FP8_E4M3 * FP8_E4M3 = FP16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) - -// FP8_E4M3 * FP8_E4M3 = FP32 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) 
-APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) - -// FP8_E4M3 * FP8_E4M3 = BF16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) - -// FP8_E5M2 * FP8_E5M2 = FP16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, 
ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) - -// FP8_E5M2 * FP8_E5M2 = FP32 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) - -// FP8_E5M2 * FP8_E5M2 = BF16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, 
ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) - -// FP8_E5M2 * FP8_E4M3 = F16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::atomic_add) - -// FP8_E5M2 * FP8_E4M3 = F32 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, 
TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::atomic_add) - -// FP8_E5M2 * FP8_E4M3 = BF16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::atomic_add) - -// FP8_E4M3 * FP8_E5M2 = F16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) 
-APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) - -// FP8_E4M3 * FP8_E5M2 = F32 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) - -// FP8_E4M3 * FP8_E5M2 = BF16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) - -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) 
-APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) - -} // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_bf16_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_bf16_instantiations.cpp new file mode 100644 index 000000000..dad9f279b --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_bf16_instantiations.cpp @@ -0,0 +1,16 @@ +#include "../ck_grouped_gemm_fp16_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// BF16 * BF16 = BF16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::bfloat16_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::bfloat16_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::bfloat16_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::bfloat16_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::bfloat16_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::bfloat16_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm 
+} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_fp16_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_fp16_instantiations.cpp new file mode 100644 index 000000000..ae49a781b --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_fp16_instantiations.cpp @@ -0,0 +1,16 @@ +#include "../ck_grouped_gemm_fp16_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// BF16 * BF16 = FP16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_fp32_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_fp32_instantiations.cpp new file mode 100644 index 000000000..af5f7ac0c --- /dev/null +++ 
b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_fp32_instantiations.cpp @@ -0,0 +1,16 @@ +#include "../ck_grouped_gemm_fp16_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// BF16 * BF16 = FP32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_instantiations.cpp new file mode 100644 index 000000000..af888bc3a --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_instantiations.cpp @@ -0,0 +1,16 @@ +#include "../ck_grouped_gemm_fp8_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP8_E5M2 * FP8_E5M2 = BF16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, 
ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_instantiations.cpp new file mode 100644 index 000000000..2a5e1f913 --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_instantiations.cpp @@ -0,0 +1,16 @@ +#include "../ck_grouped_gemm_fp8_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP8_E5M2 * FP8_E5M2 = FP16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, 
ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_instantiations.cpp new file mode 100644 index 000000000..266ec89db --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_instantiations.cpp @@ -0,0 +1,16 @@ +#include "../ck_grouped_gemm_fp8_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP8_E5M2 * FP8_E5M2 = FP32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file 
diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_instantiations.cpp new file mode 100644 index 000000000..7fa5514b2 --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_instantiations.cpp @@ -0,0 +1,16 @@ +#include "../ck_grouped_gemm_fp8_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP8_E5M2 * FP8_E4M3 = BF16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_instantiations.cpp new file mode 100644 index 000000000..3d6a572b5 --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_instantiations.cpp 
@@ -0,0 +1,16 @@ +#include "../ck_grouped_gemm_fp8_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP8_E5M2 * FP8_E4M3 = F16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_instantiations.cpp new file mode 100644 index 000000000..5b47862b3 --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_instantiations.cpp @@ -0,0 +1,16 @@ +#include "../ck_grouped_gemm_fp8_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP8_E5M2 * FP8_E4M3 = F32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, 
TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp16_fp16_bf16_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp16_fp16_bf16_instantiations.cpp new file mode 100644 index 000000000..f40a62c38 --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp16_fp16_bf16_instantiations.cpp @@ -0,0 +1,16 @@ +#include "../ck_grouped_gemm_fp16_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP16 * FP16 = BF16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::bfloat16_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::bfloat16_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::bfloat16_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::bfloat16_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, 
ck_tile::half_t, ck_tile::half_t, ck_tile::bfloat16_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::bfloat16_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp16_fp16_fp16_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp16_fp16_fp16_instantiations.cpp new file mode 100644 index 000000000..0315b68ee --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp16_fp16_fp16_instantiations.cpp @@ -0,0 +1,16 @@ +#include "../ck_grouped_gemm_fp16_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP16 * FP16 = FP16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file diff --git 
a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp16_fp16_fp32_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp16_fp16_fp32_instantiations.cpp new file mode 100644 index 000000000..9aaf6e9f4 --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp16_fp16_fp32_instantiations.cpp @@ -0,0 +1,16 @@ +#include "../ck_grouped_gemm_fp16_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP16 * FP16 = FP32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x256x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x128x64, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x256x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_instantiations.cpp new file mode 100644 index 000000000..0897d3065 --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_instantiations.cpp @@ -0,0 +1,16 @@ +#include "../ck_grouped_gemm_fp8_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP8_E4M3 * FP8_E5M2 = 
BF16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_instantiations.cpp new file mode 100644 index 000000000..b57eaf4c5 --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_instantiations.cpp @@ -0,0 +1,16 @@ +#include "../ck_grouped_gemm_fp8_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP8_E4M3 * FP8_E5M2 = F16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, 
ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_instantiations.cpp new file mode 100644 index 000000000..7f4d43b9b --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_instantiations.cpp @@ -0,0 +1,16 @@ +#include "../ck_grouped_gemm_fp8_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP8_E4M3 * FP8_E5M2 = F32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) 
+APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_instantiations.cpp new file mode 100644 index 000000000..96df6b412 --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_instantiations.cpp @@ -0,0 +1,16 @@ +#include "../ck_grouped_gemm_fp8_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP8_E4M3 * FP8_E4M3 = BF16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_instantiations.cpp 
b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_instantiations.cpp new file mode 100644 index 000000000..9923a3467 --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_instantiations.cpp @@ -0,0 +1,16 @@ +#include "../ck_grouped_gemm_fp8_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP8_E4M3 * FP8_E4M3 = FP16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_instantiations.cpp new file mode 100644 index 000000000..255d3847c --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_instantiations.cpp @@ -0,0 +1,16 @@ +#include "../ck_grouped_gemm_fp8_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP8_E4M3 * FP8_E4M3 = FP32 
+APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file From e9cd6b8dac1be3f3934fe3323f00e9f4f860189b Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Wed, 11 Mar 2026 14:19:28 +0000 Subject: [PATCH 48/51] Add runtime architecture dispatch for CK FP8 grouped GEMM (gfx942/gfx950) Introduce runtime GPU architecture detection and split CK FP8 grouped GEMM factories into arch-specific implementations for gfx942 and gfx950. Key changes: - Added GPUArch detection helper using hipGetDeviceProperties. - Introduced common factory dispatcher that selects the correct arch-specific runner factory at runtime. - Split FP8 runner factories into gfx942 and gfx950 implementations. - Added arch-specific kernel instantiation translation units for each arch. - Updated build to compile both arch implementations into the same library while selecting the correct one at runtime. This enables a single TransformerEngine build to support both MI300 (gfx942) and MI350 (gfx950) while avoiding invalid tile configurations during kernel instantiation. 
--- transformer_engine/common/CMakeLists.txt | 40 ++-- .../common/gemm/ck_grouped_gemm.cpp | 2 +- .../ck_grouped_gemm_fp8_factory_common.cpp | 50 +++++ .../gemm/ck_grouped_gemm_fp8_factory_decl.h | 30 +++ ...=> ck_grouped_gemm_fp8_factory_gfx942.cpp} | 37 ++-- .../ck_grouped_gemm_fp8_factory_gfx950.cpp | 197 +++++++++++++++++ ...pl.h => ck_grouped_gemm_fp8_gfx942_impl.h} | 181 +--------------- .../gemm/ck_grouped_gemm_fp8_gfx950_impl.h | 188 ++++++++++++++++ .../gemm/ck_grouped_gemm_fp8_runner_common.h | 203 ++++++++++++++++++ ...ped_gemm_bf16_bf16_bf16_instantiations.cpp | 1 + ...mm_bf8_bf8_bf16_gfx942_instantiations.cpp} | 6 +- ...emm_bf8_bf8_bf16_gfx950_instantiations.cpp | 18 ++ ...mm_bf8_bf8_fp16_gfx942_instantiations.cpp} | 6 +- ...emm_bf8_bf8_fp16_gfx950_instantiations.cpp | 18 ++ ...mm_bf8_bf8_fp32_gfx942_instantiations.cpp} | 6 +- ...emm_bf8_bf8_fp32_gfx950_instantiations.cpp | 18 ++ ...mm_bf8_fp8_bf16_gfx942_instantiations.cpp} | 6 +- ...emm_bf8_fp8_bf16_gfx950_instantiations.cpp | 18 ++ ...mm_bf8_fp8_fp16_gfx942_instantiations.cpp} | 6 +- ...emm_bf8_fp8_fp16_gfx950_instantiations.cpp | 18 ++ ...mm_bf8_fp8_fp32_gfx942_instantiations.cpp} | 6 +- ...emm_bf8_fp8_fp32_gfx950_instantiations.cpp | 18 ++ ...mm_fp8_bf8_bf16_gfx942_instantiations.cpp} | 6 +- ...emm_fp8_bf8_bf16_gfx950_instantiations.cpp | 18 ++ ...mm_fp8_bf8_fp16_gfx942_instantiations.cpp} | 6 +- ...emm_fp8_bf8_fp16_gfx950_instantiations.cpp | 18 ++ ...mm_fp8_bf8_fp32_gfx942_instantiations.cpp} | 6 +- ...emm_fp8_bf8_fp32_gfx950_instantiations.cpp | 18 ++ ...mm_fp8_fp8_bf16_gfx942_instantiations.cpp} | 6 +- ...emm_fp8_fp8_bf16_gfx950_instantiations.cpp | 18 ++ ...mm_fp8_fp8_fp16_gfx942_instantiations.cpp} | 6 +- ...emm_fp8_fp8_fp16_gfx950_instantiations.cpp | 18 ++ ...mm_fp8_fp8_fp32_gfx942_instantiations.cpp} | 6 +- ...emm_fp8_fp8_fp32_gfx950_instantiations.cpp | 18 ++ 34 files changed, 982 insertions(+), 235 deletions(-) create mode 100644 
transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_common.cpp create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_decl.h rename transformer_engine/common/gemm/{ck_grouped_gemm_fp8_factory.cpp => ck_grouped_gemm_fp8_factory_gfx942.cpp} (83%) create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_gfx950.cpp rename transformer_engine/common/gemm/{ck_grouped_gemm_fp8_impl.h => ck_grouped_gemm_fp8_gfx942_impl.h} (66%) create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm_fp8_gfx950_impl.h create mode 100644 transformer_engine/common/gemm/ck_grouped_gemm_fp8_runner_common.h rename transformer_engine/common/gemm/instantiations/{ck_grouped_gemm_bf8_bf8_bf16_instantiations.cpp => ck_grouped_gemm_bf8_bf8_bf16_gfx942_instantiations.cpp} (88%) create mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_gfx950_instantiations.cpp rename transformer_engine/common/gemm/instantiations/{ck_grouped_gemm_bf8_bf8_fp16_instantiations.cpp => ck_grouped_gemm_bf8_bf8_fp16_gfx942_instantiations.cpp} (88%) create mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_gfx950_instantiations.cpp rename transformer_engine/common/gemm/instantiations/{ck_grouped_gemm_bf8_bf8_fp32_instantiations.cpp => ck_grouped_gemm_bf8_bf8_fp32_gfx942_instantiations.cpp} (88%) create mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_gfx950_instantiations.cpp rename transformer_engine/common/gemm/instantiations/{ck_grouped_gemm_bf8_fp8_bf16_instantiations.cpp => ck_grouped_gemm_bf8_fp8_bf16_gfx942_instantiations.cpp} (89%) create mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_gfx950_instantiations.cpp rename transformer_engine/common/gemm/instantiations/{ck_grouped_gemm_bf8_fp8_fp16_instantiations.cpp => ck_grouped_gemm_bf8_fp8_fp16_gfx942_instantiations.cpp} (88%) create mode 100644 
transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_gfx950_instantiations.cpp rename transformer_engine/common/gemm/instantiations/{ck_grouped_gemm_bf8_fp8_fp32_instantiations.cpp => ck_grouped_gemm_bf8_fp8_fp32_gfx942_instantiations.cpp} (88%) create mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_gfx950_instantiations.cpp rename transformer_engine/common/gemm/instantiations/{ck_grouped_gemm_fp8_bf8_bf16_instantiations.cpp => ck_grouped_gemm_fp8_bf8_bf16_gfx942_instantiations.cpp} (88%) create mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_gfx950_instantiations.cpp rename transformer_engine/common/gemm/instantiations/{ck_grouped_gemm_fp8_bf8_fp16_instantiations.cpp => ck_grouped_gemm_fp8_bf8_fp16_gfx942_instantiations.cpp} (88%) create mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_gfx950_instantiations.cpp rename transformer_engine/common/gemm/instantiations/{ck_grouped_gemm_fp8_bf8_fp32_instantiations.cpp => ck_grouped_gemm_fp8_bf8_fp32_gfx942_instantiations.cpp} (88%) create mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_gfx950_instantiations.cpp rename transformer_engine/common/gemm/instantiations/{ck_grouped_gemm_fp8_fp8_bf16_instantiations.cpp => ck_grouped_gemm_fp8_fp8_bf16_gfx942_instantiations.cpp} (88%) create mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_gfx950_instantiations.cpp rename transformer_engine/common/gemm/instantiations/{ck_grouped_gemm_fp8_fp8_fp16_instantiations.cpp => ck_grouped_gemm_fp8_fp8_fp16_gfx942_instantiations.cpp} (88%) create mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_gfx950_instantiations.cpp rename transformer_engine/common/gemm/instantiations/{ck_grouped_gemm_fp8_fp8_fp32_instantiations.cpp => ck_grouped_gemm_fp8_fp8_fp32_gfx942_instantiations.cpp} (88%) create 
mode 100644 transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_gfx950_instantiations.cpp diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 2c98f59f8..5974d7c89 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -204,19 +204,33 @@ else() gemm/rocm_gemm.cu gemm/ck_grouped_gemm.cpp gemm/ck_grouped_gemm_fp8.cpp - gemm/ck_grouped_gemm_fp8_factory.cpp - gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_instantiations.cpp + gemm/ck_grouped_gemm_fp8_factory_common.cpp + gemm/ck_grouped_gemm_fp8_factory_gfx942.cpp + gemm/ck_grouped_gemm_fp8_factory_gfx950.cpp + gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_gfx942_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_gfx942_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_gfx942_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_gfx942_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_gfx942_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_gfx942_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_gfx942_instantiations.cpp + 
gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_gfx942_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_gfx942_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_gfx942_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_gfx942_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_gfx942_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_gfx950_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_gfx950_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_gfx950_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_gfx950_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_gfx950_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_gfx950_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_gfx950_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_gfx950_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_gfx950_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_gfx950_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_gfx950_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_gfx950_instantiations.cpp gemm/ck_grouped_gemm_fp16.cpp gemm/ck_grouped_gemm_fp16_factory.cpp gemm/instantiations/ck_grouped_gemm_bf16_bf16_bf16_instantiations.cpp diff --git a/transformer_engine/common/gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm.cpp index f2c19df9d..59d2e8e4c 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm.cpp @@ -5,7 +5,7 @@ ************************************************************************/ #include "ck_grouped_gemm_common.h" - +#include bool ck_tile_grouped_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, diff --git 
a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_common.cpp b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_common.cpp new file mode 100644 index 000000000..39118c46e --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_common.cpp @@ -0,0 +1,50 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#include "ck_grouped_gemm_fp8_factory_decl.h" + +#include <hip/hip_runtime.h> +#include <string> + +namespace transformer_engine { +namespace grouped_gemm { + +GPUArch detect_gpu_arch() { + int device = 0; + HIP_CHECK_ERROR(hipGetDevice(&device)); + + hipDeviceProp_t props{}; + HIP_CHECK_ERROR(hipGetDeviceProperties(&props, device)); + + const std::string arch(props.gcnArchName); + + if (arch.find("gfx942") != std::string::npos) { + return GPUArch::GFX942; + } + if (arch.find("gfx950") != std::string::npos) { + return GPUArch::GFX950; + } + return GPUArch::UNKNOWN; +} + +std::unique_ptr make_fp8_runner( + DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + switch (detect_gpu_arch()) { + case GPUArch::GFX942: + return make_fp8_runner_gfx942(a_dtype, b_dtype, d_dtype, ctx); + case GPUArch::GFX950: + return make_fp8_runner_gfx950(a_dtype, b_dtype, d_dtype, ctx); + default: + NVTE_ERROR("ck_tile_grouped_gemm: available architectures = {gfx942, gfx950}"); + return nullptr; + } +} + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_decl.h b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_decl.h new file mode 100644 index 000000000..a9512370c --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_decl.h @@ -0,0 +1,30 @@
+#pragma once + +#include "ck_grouped_gemm_common.h" +#include <memory> + +namespace transformer_engine { +namespace grouped_gemm { + +enum class GPUArch { + GFX942, + GFX950, + UNKNOWN +}; + +GPUArch detect_gpu_arch(); + +std::unique_ptr make_fp8_runner_gfx942( + DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx); + +std::unique_ptr make_fp8_runner_gfx950( + DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx); + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory.cpp b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_gfx942.cpp similarity index 83% rename from transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory.cpp rename to transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_gfx942.cpp index 717788d61..1281af69a 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_gfx942.cpp @@ -3,15 +3,16 @@ * * License for AMD contributions = MIT.
See LICENSE for more information ************************************************************************/ +#if !__HIP_DEVICE_COMPILE__ || defined(__gfx942__) #include "ck_grouped_gemm_common.h" -#include "ck_grouped_gemm_fp8_impl.h" +#include "ck_grouped_gemm_fp8_gfx942_impl.h" namespace transformer_engine { namespace grouped_gemm { template -std::unique_ptr make_fp8_runner_typed(DType d_dtype, const GroupedGemmRunContext& ctx) { +std::unique_ptr make_fp8_runner_typed_gfx942(DType d_dtype, const GroupedGemmRunContext& ctx) { std::unique_ptr runner = nullptr; using AType = typename TETypeToCKType::type; using BType = typename TETypeToCKType::type; @@ -107,7 +108,7 @@ std::unique_ptr make_fp8_runner_typed(DType d_dtype, const Grou return runner; } -std::unique_ptr make_fp8_runner(DType a_dtype, +std::unique_ptr make_fp8_runner_gfx942(DType a_dtype, DType b_dtype, DType d_dtype, const GroupedGemmRunContext& ctx) { @@ -119,17 +120,17 @@ std::unique_ptr make_fp8_runner(DType a_dtype, switch (a_dtype) { case DType::kFloat8E4M3: if (b_dtype == DType::kFloat8E4M3) { - return make_fp8_runner_typed(d_dtype, ctx); + return make_fp8_runner_typed_gfx942(d_dtype, ctx); } else if (b_dtype == DType::kFloat8E5M2) { - return make_fp8_runner_typed(d_dtype, ctx); + return make_fp8_runner_typed_gfx942(d_dtype, ctx); } break; case DType::kFloat8E5M2: if (b_dtype == DType::kFloat8E4M3) { - return make_fp8_runner_typed(d_dtype, ctx); + return make_fp8_runner_typed_gfx942(d_dtype, ctx); } else if (b_dtype == DType::kFloat8E5M2) { - return make_fp8_runner_typed(d_dtype, ctx); + return make_fp8_runner_typed_gfx942(d_dtype, ctx); } break; @@ -143,17 +144,17 @@ std::unique_ptr make_fp8_runner(DType a_dtype, switch (a_dtype) { case DType::kFloat8E4M3: if (b_dtype == DType::kFloat8E4M3) { - return make_fp8_runner_typed(d_dtype, ctx); + return make_fp8_runner_typed_gfx942(d_dtype, ctx); } else if (b_dtype == DType::kFloat8E5M2) { - return make_fp8_runner_typed(d_dtype, ctx); + return 
make_fp8_runner_typed_gfx942(d_dtype, ctx); } break; case DType::kFloat8E5M2: if (b_dtype == DType::kFloat8E4M3) { - return make_fp8_runner_typed(d_dtype, ctx); + return make_fp8_runner_typed_gfx942(d_dtype, ctx); } else if (b_dtype == DType::kFloat8E5M2) { - return make_fp8_runner_typed(d_dtype, ctx); + return make_fp8_runner_typed_gfx942(d_dtype, ctx); } break; @@ -167,17 +168,17 @@ std::unique_ptr make_fp8_runner(DType a_dtype, switch (a_dtype) { case DType::kFloat8E4M3: if (b_dtype == DType::kFloat8E4M3) { - return make_fp8_runner_typed(d_dtype, ctx); + return make_fp8_runner_typed_gfx942(d_dtype, ctx); } else if (b_dtype == DType::kFloat8E5M2) { - return make_fp8_runner_typed(d_dtype, ctx); + return make_fp8_runner_typed_gfx942(d_dtype, ctx); } break; case DType::kFloat8E5M2: if (b_dtype == DType::kFloat8E4M3) { - return make_fp8_runner_typed(d_dtype, ctx); + return make_fp8_runner_typed_gfx942(d_dtype, ctx); } else if (b_dtype == DType::kFloat8E5M2) { - return make_fp8_runner_typed(d_dtype, ctx); + return make_fp8_runner_typed_gfx942(d_dtype, ctx); } break; @@ -190,5 +191,7 @@ std::unique_ptr make_fp8_runner(DType a_dtype, return nullptr; } -} // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} +} + +#endif \ No newline at end of file diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_gfx950.cpp b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_gfx950.cpp new file mode 100644 index 000000000..aba6d865c --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_gfx950.cpp @@ -0,0 +1,197 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ +#if !__HIP_DEVICE_COMPILE__ || defined(__gfx950__) + +#include "ck_grouped_gemm_common.h" +#include "ck_grouped_gemm_fp8_gfx950_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +template +std::unique_ptr make_fp8_runner_typed_gfx950(DType d_dtype, const GroupedGemmRunContext& ctx) { + std::unique_ptr runner = nullptr; + using AType = typename TETypeToCKType::type; + using BType = typename TETypeToCKType::type; + using CLayout = RowMajor; + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { + using CType = typename TETypeToCKType::type; + if (ctx.N % 256 == 0) { + if constexpr (std::is_same_v && std::is_same_v) { + using TileCfg = TileCfg_GFX950_128x128x128; + if (ctx.accumulate) { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } else { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } + } else { + using TileCfg = TileCfg_GFX950_128x128x128; + if (ctx.accumulate) { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } else { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } + } + } else if (ctx.N % 128 == 0) { + if constexpr (std::is_same_v && std::is_same_v) { + using TileCfg = TileCfg_GFX950_128x128x128; + if (ctx.accumulate) { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } else { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } + } else { + using TileCfg = TileCfg_GFX950_128x128x128; + if (ctx.accumulate) { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } else { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } + } + } else { + if constexpr (std::is_same_v && std::is_same_v) { + using TileCfg = TileCfg_GFX950_128x128x128; + if (ctx.accumulate) { + using Runner = QuantGroupedGemmRunner; + runner = 
std::make_unique(); + } else { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } + } else { + using TileCfg = TileCfg_GFX950_128x128x128; + if (ctx.accumulate) { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } else { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } + } + } + }); + return runner; +} + +std::unique_ptr make_fp8_runner_gfx950(DType a_dtype, + DType b_dtype, + DType d_dtype, + const GroupedGemmRunContext& ctx) { + + if (!ctx.transA && !ctx.transB) { + using ALayout = RowMajor; + using BLayout = RowMajor; + + switch (a_dtype) { + case DType::kFloat8E4M3: + if (b_dtype == DType::kFloat8E4M3) { + return make_fp8_runner_typed_gfx950(d_dtype, ctx); + } else if (b_dtype == DType::kFloat8E5M2) { + return make_fp8_runner_typed_gfx950(d_dtype, ctx); + } + break; + + case DType::kFloat8E5M2: + if (b_dtype == DType::kFloat8E4M3) { + return make_fp8_runner_typed_gfx950(d_dtype, ctx); + } else if (b_dtype == DType::kFloat8E5M2) { + return make_fp8_runner_typed_gfx950(d_dtype, ctx); + } + break; + + default: + break; + } + } else if (!ctx.transA && ctx.transB) { + using ALayout = RowMajor; + using BLayout = ColMajor; + + switch (a_dtype) { + case DType::kFloat8E4M3: + if (b_dtype == DType::kFloat8E4M3) { + return make_fp8_runner_typed_gfx950(d_dtype, ctx); + } else if (b_dtype == DType::kFloat8E5M2) { + return make_fp8_runner_typed_gfx950(d_dtype, ctx); + } + break; + + case DType::kFloat8E5M2: + if (b_dtype == DType::kFloat8E4M3) { + return make_fp8_runner_typed_gfx950(d_dtype, ctx); + } else if (b_dtype == DType::kFloat8E5M2) { + return make_fp8_runner_typed_gfx950(d_dtype, ctx); + } + break; + + default: + break; + } + } else if (ctx.transA && !ctx.transB) { + using ALayout = ColMajor; + using BLayout = RowMajor; + + switch (a_dtype) { + case DType::kFloat8E4M3: + if (b_dtype == DType::kFloat8E4M3) { + return make_fp8_runner_typed_gfx950(d_dtype, ctx); + } else if (b_dtype == 
DType::kFloat8E5M2) { + return make_fp8_runner_typed_gfx950(d_dtype, ctx); + } + break; + + case DType::kFloat8E5M2: + if (b_dtype == DType::kFloat8E4M3) { + return make_fp8_runner_typed_gfx950(d_dtype, ctx); + } else if (b_dtype == DType::kFloat8E5M2) { + return make_fp8_runner_typed_gfx950(d_dtype, ctx); + } + break; + + default: + break; + } + } else { + return nullptr; + } + return nullptr; +} + +} +} + +#endif \ No newline at end of file diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_impl.h b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_gfx942_impl.h similarity index 66% rename from transformer_engine/common/gemm/ck_grouped_gemm_fp8_impl.h rename to transformer_engine/common/gemm/ck_grouped_gemm_fp8_gfx942_impl.h index 9aa8f0347..53e46cf03 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_impl.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_gfx942_impl.h @@ -1,6 +1,7 @@ #pragma once #include "ck_grouped_gemm_common.h" +#include "ck_grouped_gemm_fp8_runner_common.h" #include #include @@ -21,24 +22,6 @@ namespace transformer_engine { namespace grouped_gemm { -using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; -using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; - -template struct TETypeToCKType; -template <> struct TETypeToCKType { using type = float; }; -template <> struct TETypeToCKType { using type = ck_tile::fp8_t; }; -template <> struct TETypeToCKType { using type = ck_tile::bf8_t; }; -template <> struct TETypeToCKType { using type = ck_tile::half_t; }; -template <> struct TETypeToCKType { using type = ck_tile::bfloat16_t; }; - -static inline const transformer_engine::SimpleTensor& data_view(const transformer_engine::Tensor& t) { - return t.data; -} - -static inline const transformer_engine::SimpleTensor& scale_inv_view(const transformer_engine::Tensor& t) { - return t.scale_inv; -} - // ------------------------- // Tile configs: FP8/BF8 // ------------------------- @@ -109,168 +92,6 @@ struct 
TileCfg_256x128x128_k16_padding : TileCfg_256x128x128_k16 { static constexpr bool kPadN = true; }; -template -class QuantGroupedGemmRunner : public RunnerInterface { -public: - // hard-coded for tensor quant for now - static constexpr ck_tile::QuantType QuantMode = ck_tile::QuantType::TensorQuant; - - using GemmShape = ck_tile::TileGemmShape< - ck_tile::sequence, - ck_tile::sequence, - ck_tile::sequence>; - - using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< - GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>; - - using AQLayout = RowMajor; - using BQLayout = RowMajor; - - using UniversalTraits = - ck_tile::TileGemmQuantTraits< - TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, - false, false, ALayout, BLayout, CLayout, - QuantMode, AQLayout, BQLayout, - false, TileCfg::DoubleSmemBuffer, false>; - - using Problem = ck_tile::GemmRowColTensorQuantPipelineProblem< - AType, BType, AccType, - AccType, GemmShape, UniversalTraits, - false, AccType>; - - using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; - - using Epilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem< - AType, BType, ck_tile::tuple<>, AccType, - CType, ck_tile::tuple<>, CLayout, - ck_tile::element_wise::PassThrough, - Partitioner::MPerBlock, Partitioner::NPerBlock, - TileCfg::M_Warp, TileCfg::N_Warp, - TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile, - Problem::TransposeC, MemOp>>; - - using Kernel = ck_tile::QuantGroupedGemmKernel; - - using HostArgs = ck_tile::QuantGroupedGemmHostArgs; - -public: - static std::vector build_descs(const GroupedGemmRunContext& ctx) { - const size_t needed = Kernel::GetWorkSpaceSize(ctx.group_num); - if (!ctx.workspace || ctx.workspace_bytes < needed) { - NVTE_ERROR("ck_tile_grouped_gemm: insufficient workspace. 
Needed bytes=", needed); - } - std::vector descs; - descs.reserve(ctx.group_num); - for (int i = 0; i < ctx.group_num; ++i) { - const transformer_engine::Tensor* const A_te = - transformer_engine::convertNVTETensorCheck(ctx.A[i]); - const transformer_engine::Tensor* const B_te = - transformer_engine::convertNVTETensorCheck(ctx.B[i]); - transformer_engine::Tensor* D_te = - transformer_engine::convertNVTETensorCheck(ctx.D[i]); - - const auto& a = data_view(*A_te); - const auto& b = data_view(*B_te); - const auto& d = data_view(*D_te); - - int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; - if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || - !get_flat_2d_dims(*B_te, Bd0, Bd1) || - !get_flat_2d_dims(*D_te, Dd0, Dd1)) { - NVTE_ERROR("ck_tile_grouped_gemm: expected all groups to be rank>=2."); - } - - const int64_t M = ctx.transA ? Ad1 : Ad0; - const int64_t K = ctx.transA ? Ad0 : Ad1; - const int64_t N = ctx.transB ? Bd0 : Bd1; - const int64_t Kb = ctx.transB ? Bd1 : Bd0; - - if (Kb != K) { - NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i); - } - - if (Dd0 != M || Dd1 != N) { - NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i); - } - - const ck_tile::index_t stride_A = Ad1; - const ck_tile::index_t stride_B = Bd1; - const ck_tile::index_t stride_E = Dd1; - - // Hard-coded to tensor quant for the moment - ck_tile::index_t AQK = 1; - ck_tile::index_t BQK = 1; - ck_tile::index_t stride_AQ = 1; - ck_tile::index_t stride_BQ = 1; - - const auto& aq = scale_inv_view(*A_te); - const auto& bq = scale_inv_view(*B_te); - - descs.emplace_back( - a.dptr, - b.dptr, - d.dptr, - aq.dptr, - bq.dptr, - 1, - M, - N, - K, - AQK, - BQK, - stride_A, - stride_B, - stride_E, - stride_AQ, - stride_BQ); - } - - return descs; - }; - bool run(const ck_tile::stream_config& stream_cfg, - const GroupedGemmRunContext& ctx) override { - auto descs = build_descs(ctx); - - constexpr int kBlockPerCu = 1; - const dim3 blocks = Kernel::BlockSize(); - const 
dim3 grids = Kernel::GridSize(descs); - auto kargs = Kernel::MakeKargs(descs); - if (!Kernel::IsSupportedArgument(kargs)) { - NVTE_ERROR("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config."); - } - - HIP_CHECK_ERROR(hipMemcpyAsync(ctx.workspace, - kargs.data(), - kargs.size() * sizeof(typename decltype(kargs)::value_type), - hipMemcpyHostToDevice, - ctx.stream)); - - ck_tile::launch_kernel( - stream_cfg, ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, - ck_tile::cast_pointer_to_constant_address_space(ctx.workspace), - ctx.group_num)); - return true; - }; -}; - -// Primus-Turbo-style extern template declarations -#define DECL_CK_QUANT_GG_RUNNER_EXTERN(AType, BType, CType, ALayout, BLayout, CLayout, TileCfg, MemOp) \ - extern template class QuantGroupedGemmRunner; - -#define DECL_CK_QUANT_GG_RUNNER(AType, BType, CType, ALayout, BLayout, CLayout, TileCfg, MemOp) \ - template class QuantGroupedGemmRunner; - -#define APPLY_CK_GG_ALL_LAYOUT(MACRO, AType, BType, CType, TileCfg, MemOp) \ - MACRO(AType, BType, CType, RowMajor, ColMajor, RowMajor, TileCfg, MemOp) \ - MACRO(AType, BType, CType, RowMajor, RowMajor, RowMajor, TileCfg, MemOp) \ - MACRO(AType, BType, CType, ColMajor, RowMajor, RowMajor, TileCfg, MemOp) - // FP8_E4M3 * FP8_E4M3 = FP16 APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_gfx950_impl.h b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_gfx950_impl.h new file mode 100644 index 000000000..e934231ae --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_gfx950_impl.h @@ -0,0 +1,188 @@ +#pragma once + +#include "ck_grouped_gemm_common.h" +#include 
"ck_grouped_gemm_fp8_runner_common.h" + +#include +#include +#include +#include + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp" +#include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp" + +namespace transformer_engine { +namespace grouped_gemm { + +// ------------------------------------- +// GFX950-specific tile configs: FP8/BF8 +// ------------------------------------- + +struct TileCfg_GFX950_256x256x128 { + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 128; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; +}; + +struct TileCfg_GFX950_256x256x128_padding : TileCfg_GFX950_256x256x128 { + static constexpr bool kPadN = true; +}; + +struct TileCfg_GFX950_128x128x128 { + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static 
constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 128; + + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool DoubleSmemBuffer = false; + + static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; + static constexpr ck_tile::index_t TilePartitionerM01 = 4; +}; + +// FP8_E4M3 * FP8_E4M3 = FP16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E4M3 * FP8_E4M3 = FP32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_128x128x128, 
ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E4M3 * FP8_E4M3 = BF16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E5M2 * FP8_E5M2 = FP16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, 
TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E5M2 * FP8_E5M2 = FP32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E5M2 * FP8_E5M2 = BF16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, 
ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E5M2 * FP8_E4M3 = F16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, 
ck_tile::memory_operation_enum::atomic_add) + +// FP8_E5M2 * FP8_E4M3 = F32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E5M2 * FP8_E4M3 = BF16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, 
ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E4M3 * FP8_E5M2 = F16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E4M3 * FP8_E5M2 = F32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128, 
ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) + +// FP8_E4M3 * FP8_E5M2 = BF16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_runner_common.h b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_runner_common.h new file mode 100644 index 000000000..1e54a61bc --- /dev/null +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_runner_common.h @@ -0,0 +1,203 @@ +#include "ck_grouped_gemm_common.h" + +#include +#include +#include +#include + +#include + 
+#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp" +#include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp" + +namespace transformer_engine { +namespace grouped_gemm { + +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor; + +template struct TETypeToCKType; +template <> struct TETypeToCKType { using type = float; }; +template <> struct TETypeToCKType { using type = ck_tile::fp8_t; }; +template <> struct TETypeToCKType { using type = ck_tile::bf8_t; }; +template <> struct TETypeToCKType { using type = ck_tile::half_t; }; +template <> struct TETypeToCKType { using type = ck_tile::bfloat16_t; }; + +static inline const transformer_engine::SimpleTensor& data_view(const transformer_engine::Tensor& t) { + return t.data; +} + +static inline const transformer_engine::SimpleTensor& scale_inv_view(const transformer_engine::Tensor& t) { + return t.scale_inv; +} + +template +class QuantGroupedGemmRunner : public RunnerInterface { +public: + // hard-coded for tensor quant for now + static constexpr ck_tile::QuantType QuantMode = ck_tile::QuantType::TensorQuant; + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + + using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner< + GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>; + + using AQLayout = RowMajor; + using BQLayout = RowMajor; + + using UniversalTraits = + ck_tile::TileGemmQuantTraits< + TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK, + false, false, ALayout, BLayout, CLayout, + QuantMode, AQLayout, BQLayout, + false, TileCfg::DoubleSmemBuffer, false>; + + using Problem = 
ck_tile::GemmRowColTensorQuantPipelineProblem< + AType, BType, AccType, + AccType, GemmShape, UniversalTraits, + false, AccType>; + + using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + using Epilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem< + AType, BType, ck_tile::tuple<>, AccType, + CType, ck_tile::tuple<>, CLayout, + ck_tile::element_wise::PassThrough, + Partitioner::MPerBlock, Partitioner::NPerBlock, + TileCfg::M_Warp, TileCfg::N_Warp, + TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile, + Problem::TransposeC, MemOp>>; + + using Kernel = ck_tile::QuantGroupedGemmKernel; + + using HostArgs = ck_tile::QuantGroupedGemmHostArgs; + +public: + static std::vector build_descs(const GroupedGemmRunContext& ctx) { + const size_t needed = Kernel::GetWorkSpaceSize(ctx.group_num); + if (!ctx.workspace || ctx.workspace_bytes < needed) { + NVTE_ERROR("ck_tile_grouped_gemm: insufficient workspace. Needed bytes=", needed); + } + std::vector descs; + descs.reserve(ctx.group_num); + for (int i = 0; i < ctx.group_num; ++i) { + const transformer_engine::Tensor* const A_te = + transformer_engine::convertNVTETensorCheck(ctx.A[i]); + const transformer_engine::Tensor* const B_te = + transformer_engine::convertNVTETensorCheck(ctx.B[i]); + transformer_engine::Tensor* D_te = + transformer_engine::convertNVTETensorCheck(ctx.D[i]); + + const auto& a = data_view(*A_te); + const auto& b = data_view(*B_te); + const auto& d = data_view(*D_te); + + int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0; + if (!get_flat_2d_dims(*A_te, Ad0, Ad1) || + !get_flat_2d_dims(*B_te, Bd0, Bd1) || + !get_flat_2d_dims(*D_te, Dd0, Dd1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected all groups to be rank>=2."); + } + + const int64_t M = ctx.transA ? Ad1 : Ad0; + const int64_t K = ctx.transA ? Ad0 : Ad1; + const int64_t N = ctx.transB ? Bd0 : Bd1; + const int64_t Kb = ctx.transB ? 
Bd1 : Bd0; + + if (Kb != K) { + NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i); + } + + if (Dd0 != M || Dd1 != N) { + NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i); + } + + const ck_tile::index_t stride_A = Ad1; + const ck_tile::index_t stride_B = Bd1; + const ck_tile::index_t stride_E = Dd1; + + // Hard-coded to tensor quant for the moment + ck_tile::index_t AQK = 1; + ck_tile::index_t BQK = 1; + ck_tile::index_t stride_AQ = 1; + ck_tile::index_t stride_BQ = 1; + + const auto& aq = scale_inv_view(*A_te); + const auto& bq = scale_inv_view(*B_te); + + descs.emplace_back( + a.dptr, + b.dptr, + d.dptr, + aq.dptr, + bq.dptr, + 1, + M, + N, + K, + AQK, + BQK, + stride_A, + stride_B, + stride_E, + stride_AQ, + stride_BQ); + } + + return descs; + }; + bool run(const ck_tile::stream_config& stream_cfg, + const GroupedGemmRunContext& ctx) override { + auto descs = build_descs(ctx); + + constexpr int kBlockPerCu = 1; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(descs); + auto kargs = Kernel::MakeKargs(descs); + if (!Kernel::IsSupportedArgument(kargs)) { + NVTE_ERROR("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config."); + } + + HIP_CHECK_ERROR(hipMemcpyAsync(ctx.workspace, + kargs.data(), + kargs.size() * sizeof(typename decltype(kargs)::value_type), + hipMemcpyHostToDevice, + ctx.stream)); + + ck_tile::launch_kernel( + stream_cfg, ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, + ck_tile::cast_pointer_to_constant_address_space(ctx.workspace), + ctx.group_num)); + return true; + }; +}; + +// Primus-Turbo-style extern template declarations +#define DECL_CK_QUANT_GG_RUNNER_EXTERN(AType, BType, CType, ALayout, BLayout, CLayout, TileCfg, MemOp) \ + extern template class QuantGroupedGemmRunner; + +#define DECL_CK_QUANT_GG_RUNNER(AType, BType, CType, ALayout, BLayout, CLayout, TileCfg, MemOp) \ + template class QuantGroupedGemmRunner; + +#define 
APPLY_CK_GG_ALL_LAYOUT(MACRO, AType, BType, CType, TileCfg, MemOp) \ + MACRO(AType, BType, CType, RowMajor, ColMajor, RowMajor, TileCfg, MemOp) \ + MACRO(AType, BType, CType, RowMajor, RowMajor, RowMajor, TileCfg, MemOp) \ + MACRO(AType, BType, CType, ColMajor, RowMajor, RowMajor, TileCfg, MemOp) + +} +} \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_bf16_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_bf16_instantiations.cpp index dad9f279b..4c60faead 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_bf16_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_bf16_instantiations.cpp @@ -12,5 +12,6 @@ APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16 APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::bfloat16_t, TileCfg_256x128x64, ck_tile::memory_operation_enum::atomic_add) APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::bfloat16_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) + } // namespace grouped_gemm } // namespace transformer_engine \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_gfx942_instantiations.cpp similarity index 88% rename from transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_instantiations.cpp rename to transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_gfx942_instantiations.cpp index af888bc3a..5a58e612d 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_instantiations.cpp +++ 
b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_gfx942_instantiations.cpp @@ -1,4 +1,5 @@ -#include "../ck_grouped_gemm_fp8_impl.h" +#if !__HIP_DEVICE_COMPILE__ || defined(__gfx942__) +#include "../ck_grouped_gemm_fp8_gfx942_impl.h" namespace transformer_engine { namespace grouped_gemm { @@ -13,4 +14,5 @@ APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine +#endif \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_gfx950_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_gfx950_instantiations.cpp new file mode 100644 index 000000000..d548848ab --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_gfx950_instantiations.cpp @@ -0,0 +1,18 @@ +#if !__HIP_DEVICE_COMPILE__ || defined(__gfx950__) +#include "../ck_grouped_gemm_fp8_gfx950_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP8_E5M2 * FP8_E5M2 = BF16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, 
ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine +#endif \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_gfx942_instantiations.cpp similarity index 88% rename from transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_instantiations.cpp rename to transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_gfx942_instantiations.cpp index 2a5e1f913..60486248f 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_gfx942_instantiations.cpp @@ -1,4 +1,5 @@ -#include "../ck_grouped_gemm_fp8_impl.h" +#if !__HIP_DEVICE_COMPILE__ || defined(__gfx942__) +#include "../ck_grouped_gemm_fp8_gfx942_impl.h" namespace transformer_engine { namespace grouped_gemm { @@ -13,4 +14,5 @@ APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine +#endif \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_gfx950_instantiations.cpp 
b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_gfx950_instantiations.cpp new file mode 100644 index 000000000..6f024e588 --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_gfx950_instantiations.cpp @@ -0,0 +1,18 @@ +#if !__HIP_DEVICE_COMPILE__ || defined(__gfx950__) +#include "../ck_grouped_gemm_fp8_gfx950_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP8_E5M2 * FP8_E5M2 = FP16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine +#endif \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_gfx942_instantiations.cpp similarity index 88% rename from transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_instantiations.cpp rename to 
transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_gfx942_instantiations.cpp index 266ec89db..6dac2443c 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_gfx942_instantiations.cpp @@ -1,4 +1,5 @@ -#include "../ck_grouped_gemm_fp8_impl.h" +#if !__HIP_DEVICE_COMPILE__ || defined(__gfx942__) +#include "../ck_grouped_gemm_fp8_gfx942_impl.h" namespace transformer_engine { namespace grouped_gemm { @@ -13,4 +14,5 @@ APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine +#endif \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_gfx950_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_gfx950_instantiations.cpp new file mode 100644 index 000000000..cf8ad7b35 --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_gfx950_instantiations.cpp @@ -0,0 +1,18 @@ +#if !__HIP_DEVICE_COMPILE__ || defined(__gfx950__) +#include "../ck_grouped_gemm_fp8_gfx950_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP8_E5M2 * FP8_E5M2 = FP32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, 
TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine +#endif \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_gfx942_instantiations.cpp similarity index 89% rename from transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_instantiations.cpp rename to transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_gfx942_instantiations.cpp index 7fa5514b2..c76bbdb47 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_gfx942_instantiations.cpp @@ -1,4 +1,5 @@ -#include "../ck_grouped_gemm_fp8_impl.h" +#if !__HIP_DEVICE_COMPILE__ || defined(__gfx942__) +#include "../ck_grouped_gemm_fp8_gfx942_impl.h" namespace transformer_engine { namespace grouped_gemm { @@ -13,4 +14,5 @@ APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine +#endif \ No newline at end of file 
diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_gfx950_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_gfx950_instantiations.cpp new file mode 100644 index 000000000..954620a06 --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_gfx950_instantiations.cpp @@ -0,0 +1,18 @@ +#if !__HIP_DEVICE_COMPILE__ || defined(__gfx950__) +#include "../ck_grouped_gemm_fp8_gfx950_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP8_E5M2 * FP8_E4M3 = BF16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine +#endif \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_gfx942_instantiations.cpp similarity index 88% rename from 
transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_instantiations.cpp rename to transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_gfx942_instantiations.cpp index 3d6a572b5..6a1e6811b 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_gfx942_instantiations.cpp @@ -1,4 +1,5 @@ -#include "../ck_grouped_gemm_fp8_impl.h" +#if !__HIP_DEVICE_COMPILE__ || defined(__gfx942__) +#include "../ck_grouped_gemm_fp8_gfx942_impl.h" namespace transformer_engine { namespace grouped_gemm { @@ -13,4 +14,5 @@ APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine +#endif \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_gfx950_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_gfx950_instantiations.cpp new file mode 100644 index 000000000..42d3d39ad --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_gfx950_instantiations.cpp @@ -0,0 +1,18 @@ +#if !__HIP_DEVICE_COMPILE__ || defined(__gfx950__) +#include "../ck_grouped_gemm_fp8_gfx950_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP8_E5M2 * FP8_E4M3 = F16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, 
ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine +#endif \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_gfx942_instantiations.cpp similarity index 88% rename from transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_instantiations.cpp rename to transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_gfx942_instantiations.cpp index 5b47862b3..66ec6cc80 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_gfx942_instantiations.cpp @@ -1,4 +1,5 @@ -#include "../ck_grouped_gemm_fp8_impl.h" +#if !__HIP_DEVICE_COMPILE__ || defined(__gfx942__) +#include "../ck_grouped_gemm_fp8_gfx942_impl.h" namespace transformer_engine { namespace grouped_gemm { @@ -13,4 +14,5 @@ APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::atomic_add) } // 
namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine +#endif \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_gfx950_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_gfx950_instantiations.cpp new file mode 100644 index 000000000..586e5d84c --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_gfx950_instantiations.cpp @@ -0,0 +1,18 @@ +#if !__HIP_DEVICE_COMPILE__ || defined(__gfx950__) +#include "../ck_grouped_gemm_fp8_gfx950_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP8_E5M2 * FP8_E4M3 = F32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine +#endif \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_instantiations.cpp 
b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_gfx942_instantiations.cpp similarity index 88% rename from transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_instantiations.cpp rename to transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_gfx942_instantiations.cpp index 0897d3065..9c81ba606 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_gfx942_instantiations.cpp @@ -1,4 +1,5 @@ -#include "../ck_grouped_gemm_fp8_impl.h" +#if !__HIP_DEVICE_COMPILE__ || defined(__gfx942__) +#include "../ck_grouped_gemm_fp8_gfx942_impl.h" namespace transformer_engine { namespace grouped_gemm { @@ -13,4 +14,5 @@ APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine +#endif \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_gfx950_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_gfx950_instantiations.cpp new file mode 100644 index 000000000..d85a7ca9a --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_gfx950_instantiations.cpp @@ -0,0 +1,18 @@ +#if !__HIP_DEVICE_COMPILE__ || defined(__gfx950__) +#include "../ck_grouped_gemm_fp8_gfx950_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP8_E4M3 * FP8_E5M2 = BF16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) 
+APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine +#endif \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_gfx942_instantiations.cpp similarity index 88% rename from transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_instantiations.cpp rename to transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_gfx942_instantiations.cpp index b57eaf4c5..4c8e77e27 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_gfx942_instantiations.cpp @@ -1,4 +1,5 @@ -#include "../ck_grouped_gemm_fp8_impl.h" +#if !__HIP_DEVICE_COMPILE__ || defined(__gfx942__) +#include "../ck_grouped_gemm_fp8_gfx942_impl.h" namespace transformer_engine { namespace grouped_gemm { @@ -13,4 +14,5 @@ APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, 
APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine +#endif \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_gfx950_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_gfx950_instantiations.cpp new file mode 100644 index 000000000..89ca90b4a --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_gfx950_instantiations.cpp @@ -0,0 +1,18 @@ +#if !__HIP_DEVICE_COMPILE__ || defined(__gfx950__) +#include "../ck_grouped_gemm_fp8_gfx950_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP8_E4M3 * FP8_E5M2 = F16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace 
transformer_engine +#endif \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_gfx942_instantiations.cpp similarity index 88% rename from transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_instantiations.cpp rename to transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_gfx942_instantiations.cpp index 7f4d43b9b..eea2ae44f 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_gfx942_instantiations.cpp @@ -1,4 +1,5 @@ -#include "../ck_grouped_gemm_fp8_impl.h" +#if !__HIP_DEVICE_COMPILE__ || defined(__gfx942__) +#include "../ck_grouped_gemm_fp8_gfx942_impl.h" namespace transformer_engine { namespace grouped_gemm { @@ -13,4 +14,5 @@ APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine +#endif \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_gfx950_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_gfx950_instantiations.cpp new file mode 100644 index 000000000..f85c99856 --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_gfx950_instantiations.cpp @@ -0,0 +1,18 @@ +#if !__HIP_DEVICE_COMPILE__ || defined(__gfx950__) +#include "../ck_grouped_gemm_fp8_gfx950_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP8_E4M3 * FP8_E5M2 = F32 
+APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine +#endif \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_gfx942_instantiations.cpp similarity index 88% rename from transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_instantiations.cpp rename to transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_gfx942_instantiations.cpp index 96df6b412..fecc93061 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_gfx942_instantiations.cpp @@ -1,4 +1,5 @@ -#include "../ck_grouped_gemm_fp8_impl.h" +#if !__HIP_DEVICE_COMPILE__ || defined(__gfx942__) +#include "../ck_grouped_gemm_fp8_gfx942_impl.h" namespace transformer_engine { namespace grouped_gemm { @@ -13,4 +14,5 @@ 
APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine +#endif \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_gfx950_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_gfx950_instantiations.cpp new file mode 100644 index 000000000..0a37f32d8 --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_gfx950_instantiations.cpp @@ -0,0 +1,18 @@ +#if !__HIP_DEVICE_COMPILE__ || defined(__gfx950__) +#include "../ck_grouped_gemm_fp8_gfx950_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP8_E4M3 * FP8_E4M3 = BF16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, 
TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine +#endif \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_gfx942_instantiations.cpp similarity index 88% rename from transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_instantiations.cpp rename to transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_gfx942_instantiations.cpp index 9923a3467..c626f4379 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_gfx942_instantiations.cpp @@ -1,4 +1,5 @@ -#include "../ck_grouped_gemm_fp8_impl.h" +#if !__HIP_DEVICE_COMPILE__ || defined(__gfx942__) +#include "../ck_grouped_gemm_fp8_gfx942_impl.h" namespace transformer_engine { namespace grouped_gemm { @@ -13,4 +14,5 @@ APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine +#endif \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_gfx950_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_gfx950_instantiations.cpp new file mode 100644 index 000000000..01b33b8de --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_gfx950_instantiations.cpp @@ -0,0 +1,18 @@ +#if !__HIP_DEVICE_COMPILE__ || defined(__gfx950__) +#include 
"../ck_grouped_gemm_fp8_gfx950_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP8_E4M3 * FP8_E4M3 = FP16 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine +#endif \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_gfx942_instantiations.cpp similarity index 88% rename from transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_instantiations.cpp rename to transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_gfx942_instantiations.cpp index 255d3847c..bbe64fd0e 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_gfx942_instantiations.cpp @@ -1,4 +1,5 @@ -#include "../ck_grouped_gemm_fp8_impl.h" 
+#if !__HIP_DEVICE_COMPILE__ || defined(__gfx942__) +#include "../ck_grouped_gemm_fp8_gfx942_impl.h" namespace transformer_engine { namespace grouped_gemm { @@ -13,4 +14,5 @@ APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine +#endif \ No newline at end of file diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_gfx950_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_gfx950_instantiations.cpp new file mode 100644 index 000000000..53a589524 --- /dev/null +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_gfx950_instantiations.cpp @@ -0,0 +1,18 @@ +#if !__HIP_DEVICE_COMPILE__ || defined(__gfx950__) +#include "../ck_grouped_gemm_fp8_gfx950_impl.h" + +namespace transformer_engine { +namespace grouped_gemm { + +// FP8_E4M3 * FP8_E4M3 = FP32 +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) + +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) 
+APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) + +} // namespace grouped_gemm +} // namespace transformer_engine +#endif \ No newline at end of file From 5275ac6359c0ed249fd0c489c9f34ceaa88b4589 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Wed, 11 Mar 2026 16:32:43 +0000 Subject: [PATCH 49/51] Fix dev merge conflicts in CMakeLists.txt --- transformer_engine/common/CMakeLists.txt | 83 ++++++++++-------------- 1 file changed, 36 insertions(+), 47 deletions(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 32822dfb9..05fe3b4a3 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -255,6 +255,42 @@ else() fused_attn_rocm/fused_attn.cpp gemm/rocm_gemm.cu gemm/ck_grouped_gemm.cpp + gemm/ck_grouped_gemm_fp8.cpp + gemm/ck_grouped_gemm_fp8_factory_common.cpp + gemm/ck_grouped_gemm_fp8_factory_gfx942.cpp + gemm/ck_grouped_gemm_fp8_factory_gfx950.cpp + gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_gfx942_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_gfx942_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_gfx942_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_gfx942_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_gfx942_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_gfx942_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_gfx942_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_gfx942_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_gfx942_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_gfx942_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_gfx942_instantiations.cpp + 
gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_gfx942_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_gfx950_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_gfx950_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_gfx950_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_gfx950_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_gfx950_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_gfx950_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_gfx950_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_gfx950_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_gfx950_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_gfx950_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_gfx950_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_gfx950_instantiations.cpp + gemm/ck_grouped_gemm_fp16.cpp + gemm/ck_grouped_gemm_fp16_factory.cpp + gemm/instantiations/ck_grouped_gemm_bf16_bf16_bf16_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf16_bf16_fp16_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_bf16_bf16_fp32_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp16_fp16_bf16_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp16_fp16_fp16_instantiations.cpp + gemm/instantiations/ck_grouped_gemm_fp16_fp16_fp32_instantiations.cpp amd_detail/system.cpp) list(APPEND transformer_engine_cuda_sources fused_attn_rocm/fused_attn_aotriton.cpp @@ -312,50 +348,6 @@ endif() add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) else() - list(APPEND transformer_engine_SOURCES - fused_attn_rocm/fused_attn.cpp - fused_attn_rocm/fused_attn_aotriton.cpp - fused_attn_rocm/fused_attn_ck.cpp - fused_attn_rocm/utils.cpp - gemm/rocm_gemm.cu - gemm/ck_grouped_gemm.cpp - gemm/ck_grouped_gemm_fp8.cpp - 
gemm/ck_grouped_gemm_fp8_factory_common.cpp - gemm/ck_grouped_gemm_fp8_factory_gfx942.cpp - gemm/ck_grouped_gemm_fp8_factory_gfx950.cpp - gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_gfx942_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_gfx942_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_gfx942_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_gfx942_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_gfx942_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_gfx942_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_gfx942_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_gfx942_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_gfx942_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_gfx942_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_gfx942_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_gfx942_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_gfx950_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_gfx950_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_gfx950_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_gfx950_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_gfx950_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_gfx950_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_gfx950_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_gfx950_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_gfx950_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_gfx950_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_gfx950_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_gfx950_instantiations.cpp - 
gemm/ck_grouped_gemm_fp16.cpp - gemm/ck_grouped_gemm_fp16_factory.cpp - gemm/instantiations/ck_grouped_gemm_bf16_bf16_bf16_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_bf16_bf16_fp16_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_bf16_bf16_fp32_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_fp16_fp16_bf16_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_fp16_fp16_fp16_instantiations.cpp - gemm/instantiations/ck_grouped_gemm_fp16_fp16_fp32_instantiations.cpp - amd_detail/system.cpp) # process source code files set(TE ${CMAKE_CURRENT_SOURCE_DIR}/../..) @@ -400,9 +392,6 @@ set_property( APPEND PROPERTY COMPILE_OPTIONS "--generate-code=arch=compute_90a,code=sm_90a;-g0") -else() - message(FATAL_ERROR "cutlass gemm/cutlass_grouped_gemm.cu kernel required sm 90a") -endif() else() set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/aiter/3rdparty/composable_kernel) target_include_directories(transformer_engine PRIVATE ${CK_ROOT}/include) From 39407482041b09ede59d0245fc60f9e529d40116 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Thu, 12 Mar 2026 19:24:58 +0000 Subject: [PATCH 50/51] add copyright headers and add blank line at bottom of each new file. 
Also tweaked tile configs slightly, and gave them clearer names for gfx942 --- .../common/gemm/ck_grouped_gemm_common.h | 2 +- .../common/gemm/ck_grouped_gemm_fp16.cpp | 8 +- .../gemm/ck_grouped_gemm_fp16_factory.cpp | 4 +- .../common/gemm/ck_grouped_gemm_fp16_impl.h | 8 +- .../common/gemm/ck_grouped_gemm_fp8.cpp | 8 +- .../ck_grouped_gemm_fp8_factory_common.cpp | 2 +- .../gemm/ck_grouped_gemm_fp8_factory_decl.h | 8 +- .../ck_grouped_gemm_fp8_factory_gfx942.cpp | 18 +- .../ck_grouped_gemm_fp8_factory_gfx950.cpp | 96 +++------- .../gemm/ck_grouped_gemm_fp8_gfx942_impl.h | 165 ++++++++--------- .../gemm/ck_grouped_gemm_fp8_gfx950_impl.h | 168 +++++++++--------- .../gemm/ck_grouped_gemm_fp8_runner_common.h | 8 +- ...ped_gemm_bf16_bf16_bf16_instantiations.cpp | 8 +- ...ped_gemm_bf16_bf16_fp16_instantiations.cpp | 8 +- ...ped_gemm_bf16_bf16_fp32_instantiations.cpp | 8 +- ...emm_bf8_bf8_bf16_gfx942_instantiations.cpp | 20 ++- ...emm_bf8_bf8_bf16_gfx950_instantiations.cpp | 20 ++- ...emm_bf8_bf8_fp16_gfx942_instantiations.cpp | 20 ++- ...emm_bf8_bf8_fp16_gfx950_instantiations.cpp | 20 ++- ...emm_bf8_bf8_fp32_gfx942_instantiations.cpp | 20 ++- ...emm_bf8_bf8_fp32_gfx950_instantiations.cpp | 20 ++- ...emm_bf8_fp8_bf16_gfx942_instantiations.cpp | 20 ++- ...emm_bf8_fp8_bf16_gfx950_instantiations.cpp | 20 ++- ...emm_bf8_fp8_fp16_gfx942_instantiations.cpp | 20 ++- ...emm_bf8_fp8_fp16_gfx950_instantiations.cpp | 20 ++- ...emm_bf8_fp8_fp32_gfx942_instantiations.cpp | 20 ++- ...emm_bf8_fp8_fp32_gfx950_instantiations.cpp | 20 ++- ...ped_gemm_fp16_fp16_bf16_instantiations.cpp | 8 +- ...ped_gemm_fp16_fp16_fp16_instantiations.cpp | 8 +- ...ped_gemm_fp16_fp16_fp32_instantiations.cpp | 8 +- ...emm_fp8_bf8_bf16_gfx942_instantiations.cpp | 20 ++- ...emm_fp8_bf8_bf16_gfx950_instantiations.cpp | 20 ++- ...emm_fp8_bf8_fp16_gfx942_instantiations.cpp | 20 ++- ...emm_fp8_bf8_fp16_gfx950_instantiations.cpp | 20 ++- ...emm_fp8_bf8_fp32_gfx942_instantiations.cpp | 20 ++- 
...emm_fp8_bf8_fp32_gfx950_instantiations.cpp | 20 ++- ...emm_fp8_fp8_bf16_gfx942_instantiations.cpp | 20 ++- ...emm_fp8_fp8_bf16_gfx950_instantiations.cpp | 20 ++- ...emm_fp8_fp8_fp16_gfx942_instantiations.cpp | 20 ++- ...emm_fp8_fp8_fp16_gfx950_instantiations.cpp | 20 ++- ...emm_fp8_fp8_fp32_gfx942_instantiations.cpp | 20 ++- ...emm_fp8_fp8_fp32_gfx950_instantiations.cpp | 20 ++- 42 files changed, 602 insertions(+), 421 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_common.h b/transformer_engine/common/gemm/ck_grouped_gemm_common.h index 5a567da3d..5e3f36386 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm_common.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm_common.h @@ -78,4 +78,4 @@ std::unique_ptr make_fp16_runner(DType a_dtype, const GroupedGemmRunContext& ctx); } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp16.cpp b/transformer_engine/common/gemm/ck_grouped_gemm_fp16.cpp index 582572680..3446b4d29 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm_fp16.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp16.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #include "ck_grouped_gemm_common.h" namespace transformer_engine { @@ -20,4 +26,4 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, } } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp16_factory.cpp b/transformer_engine/common/gemm/ck_grouped_gemm_fp16_factory.cpp index 3ba637b09..aa859c737 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm_fp16_factory.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp16_factory.cpp @@ -4,7 +4,7 @@ * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ - #include "ck_grouped_gemm_common.h" +#include "ck_grouped_gemm_common.h" #include "ck_grouped_gemm_fp16_impl.h" namespace transformer_engine { @@ -135,4 +135,4 @@ std::unique_ptr make_fp16_runner(DType a_dtype, } } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp16_impl.h b/transformer_engine/common/gemm/ck_grouped_gemm_fp16_impl.h index b9bab836a..c8e0c06f6 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm_fp16_impl.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp16_impl.h @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #pragma once #include "ck_grouped_gemm_common.h" @@ -272,4 +278,4 @@ APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::b APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER_EXTERN, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::bfloat16_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) } -} \ No newline at end of file +} diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp8.cpp b/transformer_engine/common/gemm/ck_grouped_gemm_fp8.cpp index 427f2dd67..cdf1fc546 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm_fp8.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp8.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + #include "ck_grouped_gemm_common.h" namespace transformer_engine { @@ -20,4 +26,4 @@ bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype, } } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_common.cpp b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_common.cpp index 39118c46e..01fa860e5 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_common.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_common.cpp @@ -47,4 +47,4 @@ std::unique_ptr make_fp8_runner( } } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_decl.h 
b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_decl.h index a9512370c..200cfecee 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_decl.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_decl.h @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + #pragma once #include "ck_grouped_gemm_common.h" @@ -27,4 +33,4 @@ std::unique_ptr make_fp8_runner_gfx950( const GroupedGemmRunContext& ctx); } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_gfx942.cpp b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_gfx942.cpp index 1281af69a..e8740047b 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_gfx942.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_gfx942.cpp @@ -3,6 +3,7 @@ * * License for AMD contributions = MIT. 
See LICENSE for more information ************************************************************************/ + #if !__HIP_DEVICE_COMPILE__ || defined(__gfx942__) #include "ck_grouped_gemm_common.h" @@ -20,8 +21,11 @@ std::unique_ptr make_fp8_runner_typed_gfx942(DType d_dtype, con TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { using CType = typename TETypeToCKType::type; if (ctx.N % 256 == 0) { + // we check whether the operand order is bf8/fp8 (as opposed to fp8/bf8) because there is no supported + // WarpGemmMfma_f32_32x32x32_bf8_fp8 in CK's warp gemm dispatcher + // See: ops/gemm/warp/warp_gemm_dispatcher.hpp if constexpr (std::is_same_v && std::is_same_v) { - using TileCfg = TileCfg_256x256x128_k16; + using TileCfg = TileCfg_GFX942_256x256x128_32x32x16_2x2x1; if (ctx.accumulate) { using Runner = QuantGroupedGemmRunner make_fp8_runner_typed_gfx942(DType d_dtype, con runner = std::make_unique(); } } else { - using TileCfg = TileCfg_256x256x128; + using TileCfg = TileCfg_GFX942_256x256x128_32x32x32_2x2x1; if (ctx.accumulate) { using Runner = QuantGroupedGemmRunner make_fp8_runner_typed_gfx942(DType d_dtype, con } } else if (ctx.N % 128 == 0) { if constexpr (std::is_same_v && std::is_same_v) { - using TileCfg = TileCfg_256x128x128_k16; + using TileCfg = TileCfg_GFX942_256x128x128_32x32x16_2x2x1; if (ctx.accumulate) { using Runner = QuantGroupedGemmRunner make_fp8_runner_typed_gfx942(DType d_dtype, con runner = std::make_unique(); } } else { - using TileCfg = TileCfg_256x128x128; + using TileCfg = TileCfg_GFX942_256x128x128_32x32x32_2x2x1; if (ctx.accumulate) { using Runner = QuantGroupedGemmRunner make_fp8_runner_typed_gfx942(DType d_dtype, con } } else { if constexpr (std::is_same_v && std::is_same_v) { - using TileCfg = TileCfg_256x128x128_k16_padding; + using TileCfg = TileCfg_GFX942_256x128x128_32x32x16_2x2x1_padding; if (ctx.accumulate) { using Runner = QuantGroupedGemmRunner make_fp8_runner_typed_gfx942(DType d_dtype, con runner = 
std::make_unique(); } } else { - using TileCfg = TileCfg_256x128x128_padding; + using TileCfg = TileCfg_GFX942_256x128x128_padding; if (ctx.accumulate) { using Runner = QuantGroupedGemmRunner make_fp8_runner_gfx942(DType a_dtype, } } -#endif \ No newline at end of file +#endif diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_gfx950.cpp b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_gfx950.cpp index aba6d865c..4b1215f8e 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_gfx950.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_factory_gfx950.cpp @@ -3,6 +3,7 @@ * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ + #if !__HIP_DEVICE_COMPILE__ || defined(__gfx950__) #include "ck_grouped_gemm_common.h" @@ -20,49 +21,33 @@ std::unique_ptr make_fp8_runner_typed_gfx950(DType d_dtype, con TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { using CType = typename TETypeToCKType::type; if (ctx.N % 256 == 0) { - if constexpr (std::is_same_v && std::is_same_v) { - using TileCfg = TileCfg_GFX950_128x128x128; - if (ctx.accumulate) { - using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); - } else { - using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); - } + using TileCfg = TileCfg_GFX950_256x256x128_16x16x128_2x2x1; + if (ctx.accumulate) { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); } else { - using TileCfg = TileCfg_GFX950_128x128x128; - if (ctx.accumulate) { - using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); - } else { - using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); - } + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); } } else if (ctx.N % 128 == 0) { - if constexpr (std::is_same_v && std::is_same_v) { - using TileCfg = TileCfg_GFX950_128x128x128; - if (ctx.accumulate) { - 
using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); - } else { - using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); - } - } else { - using TileCfg = TileCfg_GFX950_128x128x128; + using TileCfg = TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding; + if (ctx.accumulate) { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } else { + using Runner = QuantGroupedGemmRunner; + runner = std::make_unique(); + } + } else { + using TileCfg = TileCfg_GFX950_128x128x128_32x32x64_2x2x1; if (ctx.accumulate) { using Runner = QuantGroupedGemmRunner make_fp8_runner_typed_gfx950(DType d_dtype, con TileCfg, ck_tile::memory_operation_enum::set>; runner = std::make_unique(); } - } - } else { - if constexpr (std::is_same_v && std::is_same_v) { - using TileCfg = TileCfg_GFX950_128x128x128; - if (ctx.accumulate) { - using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); - } else { - using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); - } - } else { - using TileCfg = TileCfg_GFX950_128x128x128; - if (ctx.accumulate) { - using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); - } else { - using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); - } - } } }); return runner; @@ -194,4 +150,4 @@ std::unique_ptr make_fp8_runner_gfx950(DType a_dtype, } } -#endif \ No newline at end of file +#endif diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_gfx942_impl.h b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_gfx942_impl.h index 53e46cf03..e2a8dca5d 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_gfx942_impl.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_gfx942_impl.h @@ -1,3 +1,8 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ #pragma once #include "ck_grouped_gemm_common.h" @@ -23,10 +28,10 @@ namespace transformer_engine { namespace grouped_gemm { // ------------------------- -// Tile configs: FP8/BF8 +// GFX942-specific tile configs: FP8/BF8 // ------------------------- -struct TileCfg_256x256x128 { +struct TileCfg_GFX942_256x256x128_32x32x32_2x2x1 { static constexpr ck_tile::index_t M_Tile = 256; static constexpr ck_tile::index_t N_Tile = 256; static constexpr ck_tile::index_t K_Tile = 128; @@ -49,11 +54,11 @@ struct TileCfg_256x256x128 { static constexpr ck_tile::index_t TilePartitionerM01 = 4; }; -struct TileCfg_256x128x128 : TileCfg_256x256x128 { +struct TileCfg_GFX942_256x128x128_32x32x32_2x2x1 : TileCfg_GFX942_256x256x128_32x32x32_2x2x1 { static constexpr ck_tile::index_t N_Tile = 128; }; -struct TileCfg_256x128x128_padding : TileCfg_256x128x128 { +struct TileCfg_GFX942_256x128x128_padding : TileCfg_GFX942_256x128x128_32x32x32_2x2x1 { static constexpr bool kPadN = true; }; @@ -61,7 +66,7 @@ struct TileCfg_256x128x128_padding : TileCfg_256x128x128 { // Fallback FP8/BF8 tile family for normalized (bf8_t, fp8_t) pair. 
// ------------------------- -struct TileCfg_256x256x128_k16 { +struct TileCfg_GFX942_256x256x128_32x32x16_2x2x1 { static constexpr ck_tile::index_t M_Tile = 256; static constexpr ck_tile::index_t N_Tile = 256; static constexpr ck_tile::index_t K_Tile = 128; @@ -84,121 +89,121 @@ struct TileCfg_256x256x128_k16 { static constexpr ck_tile::index_t TilePartitionerM01 = 4; }; -struct TileCfg_256x128x128_k16 : TileCfg_256x256x128_k16 { +struct TileCfg_GFX942_256x128x128_32x32x16_2x2x1 : TileCfg_GFX942_256x256x128_32x32x16_2x2x1 { static constexpr ck_tile::index_t N_Tile = 128; }; -struct TileCfg_256x128x128_k16_padding : TileCfg_256x128x128_k16 { +struct TileCfg_GFX942_256x128x128_32x32x16_2x2x1_padding : TileCfg_GFX942_256x128x128_32x32x16_2x2x1 { static constexpr bool kPadN = true; }; // FP8_E4M3 * FP8_E4M3 = FP16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128, 
ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) // FP8_E4M3 * FP8_E4M3 = FP32 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::set) 
-APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) // FP8_E4M3 * FP8_E4M3 = BF16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, 
ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) // FP8_E5M2 * FP8_E5M2 = FP16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) 
+APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) // FP8_E5M2 * FP8_E5M2 = FP32 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) 
+APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) // FP8_E5M2 * FP8_E5M2 = BF16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, 
ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) // FP8_E5M2 * FP8_E4M3 = F16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::set) 
-APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX942_256x256x128_32x32x16_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_32x32x16_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_32x32x16_2x2x1_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX942_256x256x128_32x32x16_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_32x32x16_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_32x32x16_2x2x1_padding, 
ck_tile::memory_operation_enum::atomic_add) // FP8_E5M2 * FP8_E4M3 = F32 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX942_256x256x128_32x32x16_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX942_256x128x128_32x32x16_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX942_256x128x128_32x32x16_2x2x1_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX942_256x256x128_32x32x16_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX942_256x128x128_32x32x16_2x2x1, ck_tile::memory_operation_enum::atomic_add) 
+APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX942_256x128x128_32x32x16_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) // FP8_E5M2 * FP8_E4M3 = BF16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x256x128_32x32x16_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_32x32x16_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_32x32x16_2x2x1_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, 
TileCfg_GFX942_256x256x128_32x32x16_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_32x32x16_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_32x32x16_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) // FP8_E4M3 * FP8_E5M2 = F16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, 
ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) // FP8_E4M3 * FP8_E5M2 = F32 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) 
-APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) // FP8_E4M3 * FP8_E5M2 = BF16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) 
-APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_gfx950_impl.h b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_gfx950_impl.h index e934231ae..0af5de74b 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_gfx950_impl.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_gfx950_impl.h @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #pragma once #include "ck_grouped_gemm_common.h" @@ -26,7 +32,7 @@ namespace grouped_gemm { // GFX950-specific tile configs: FP8/BF8 // ------------------------------------- -struct TileCfg_GFX950_256x256x128 { +struct TileCfg_GFX950_256x256x128_16x16x128_2x2x1 { static constexpr ck_tile::index_t M_Tile = 256; static constexpr ck_tile::index_t N_Tile = 256; static constexpr ck_tile::index_t K_Tile = 128; @@ -45,21 +51,21 @@ struct TileCfg_GFX950_256x256x128 { static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t TilePartitionerGroupNum = 8; - static constexpr ck_tile::index_t TilePartitionerM01 = 4; + static constexpr ck_tile::index_t TilePartitionerGroupNum = 16; + static constexpr ck_tile::index_t TilePartitionerM01 = 8; }; -struct TileCfg_GFX950_256x256x128_padding : TileCfg_GFX950_256x256x128 { - static constexpr bool kPadN = true; +struct TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding : TileCfg_GFX950_256x256x128_16x16x128_2x2x1 { + static constexpr bool kPadN = true; }; -struct TileCfg_GFX950_128x128x128 { +struct TileCfg_GFX950_128x128x128_32x32x64_2x2x1 { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; static constexpr ck_tile::index_t K_Tile = 128; - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; static constexpr ck_tile::index_t K_Warp = 1; static constexpr ck_tile::index_t M_Warp_Tile = 16; @@ -77,112 +83,112 @@ struct TileCfg_GFX950_128x128x128 { }; // FP8_E4M3 * FP8_E4M3 = FP16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, 
ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::atomic_add) // FP8_E4M3 * FP8_E4M3 = FP32 
-APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, 
float, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::atomic_add) // FP8_E4M3 * FP8_E4M3 = BF16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::atomic_add) 
+APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::atomic_add) // FP8_E5M2 * FP8_E5M2 = FP16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, 
TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::atomic_add) // FP8_E5M2 * FP8_E5M2 = FP32 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, 
TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::atomic_add) // FP8_E5M2 * FP8_E5M2 = BF16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, 
ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::atomic_add) // FP8_E5M2 * FP8_E4M3 = F16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, 
ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::atomic_add) // FP8_E5M2 * FP8_E4M3 = F32 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, 
TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::atomic_add) // FP8_E5M2 * FP8_E4M3 = BF16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, 
TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::atomic_add) // FP8_E4M3 * FP8_E5M2 = F16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) 
-APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, 
ck_tile::memory_operation_enum::atomic_add) // FP8_E4M3 * FP8_E5M2 = F32 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) 
+APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::atomic_add) // FP8_E4M3 * FP8_E5M2 = BF16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, 
TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER_EXTERN, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_runner_common.h b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_runner_common.h index 1e54a61bc..68218951e 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm_fp8_runner_common.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm_fp8_runner_common.h @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #include "ck_grouped_gemm_common.h" #include @@ -200,4 +206,4 @@ class QuantGroupedGemmRunner : public RunnerInterface { MACRO(AType, BType, CType, ColMajor, RowMajor, RowMajor, TileCfg, MemOp) } -} \ No newline at end of file +} diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_bf16_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_bf16_instantiations.cpp index 4c60faead..d10b3010a 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_bf16_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_bf16_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + #include "../ck_grouped_gemm_fp16_impl.h" namespace transformer_engine { @@ -14,4 +20,4 @@ APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16 } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_fp16_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_fp16_instantiations.cpp index ae49a781b..b87d2af96 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_fp16_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_fp16_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. 
All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + #include "../ck_grouped_gemm_fp16_impl.h" namespace transformer_engine { @@ -13,4 +19,4 @@ APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16 APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, ck_tile::half_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_fp32_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_fp32_instantiations.cpp index af5f7ac0c..7d6ad491c 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_fp32_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf16_bf16_fp32_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #include "../ck_grouped_gemm_fp16_impl.h" namespace transformer_engine { @@ -13,4 +19,4 @@ APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16 APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::bfloat16_t, ck_tile::bfloat16_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_gfx942_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_gfx942_instantiations.cpp index 5a58e612d..77575b0fa 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_gfx942_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_gfx942_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #if !__HIP_DEVICE_COMPILE__ || defined(__gfx942__) #include "../ck_grouped_gemm_fp8_gfx942_impl.h" @@ -5,14 +11,14 @@ namespace transformer_engine { namespace grouped_gemm { // FP8_E5M2 * FP8_E5M2 = BF16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, 
ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm } // namespace transformer_engine -#endif \ No newline at end of file +#endif diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_gfx950_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_gfx950_instantiations.cpp index d548848ab..7fc368557 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_gfx950_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_gfx950_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #if !__HIP_DEVICE_COMPILE__ || defined(__gfx950__) #include "../ck_grouped_gemm_fp8_gfx950_impl.h" @@ -5,14 +11,14 @@ namespace transformer_engine { namespace grouped_gemm { // FP8_E5M2 * FP8_E5M2 = BF16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, 
ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm } // namespace transformer_engine -#endif \ No newline at end of file +#endif diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_gfx942_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_gfx942_instantiations.cpp index 60486248f..424499cdc 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_gfx942_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_gfx942_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #if !__HIP_DEVICE_COMPILE__ || defined(__gfx942__) #include "../ck_grouped_gemm_fp8_gfx942_impl.h" @@ -5,14 +11,14 @@ namespace transformer_engine { namespace grouped_gemm { // FP8_E5M2 * FP8_E5M2 = FP16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) 
+APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm } // namespace transformer_engine -#endif \ No newline at end of file +#endif diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_gfx950_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_gfx950_instantiations.cpp index 6f024e588..8536d943d 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_gfx950_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_gfx950_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #if !__HIP_DEVICE_COMPILE__ || defined(__gfx950__) #include "../ck_grouped_gemm_fp8_gfx950_impl.h" @@ -5,14 +11,14 @@ namespace transformer_engine { namespace grouped_gemm { // FP8_E5M2 * FP8_E5M2 = FP16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, 
TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm } // namespace transformer_engine -#endif \ No newline at end of file +#endif diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_gfx942_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_gfx942_instantiations.cpp index 6dac2443c..61613e906 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_gfx942_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_gfx942_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #if !__HIP_DEVICE_COMPILE__ || defined(__gfx942__) #include "../ck_grouped_gemm_fp8_gfx942_impl.h" @@ -5,14 +11,14 @@ namespace transformer_engine { namespace grouped_gemm { // FP8_E5M2 * FP8_E5M2 = FP32 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, 
TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm } // namespace transformer_engine -#endif \ No newline at end of file +#endif diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_gfx950_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_gfx950_instantiations.cpp index cf8ad7b35..36b67f37f 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_gfx950_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_gfx950_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #if !__HIP_DEVICE_COMPILE__ || defined(__gfx950__) #include "../ck_grouped_gemm_fp8_gfx950_impl.h" @@ -5,14 +11,14 @@ namespace transformer_engine { namespace grouped_gemm { // FP8_E5M2 * FP8_E5M2 = FP32 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::atomic_add) 
+APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::bf8_t, float, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm } // namespace transformer_engine -#endif \ No newline at end of file +#endif diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_gfx942_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_gfx942_instantiations.cpp index c76bbdb47..24d2cf6f9 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_gfx942_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_gfx942_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #if !__HIP_DEVICE_COMPILE__ || defined(__gfx942__) #include "../ck_grouped_gemm_fp8_gfx942_impl.h" @@ -5,14 +11,14 @@ namespace transformer_engine { namespace grouped_gemm { // FP8_E5M2 * FP8_E4M3 = BF16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x256x128_32x32x16_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_32x32x16_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_32x32x16_2x2x1_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, 
TileCfg_GFX942_256x256x128_32x32x16_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_32x32x16_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_32x32x16_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm } // namespace transformer_engine -#endif \ No newline at end of file +#endif diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_gfx950_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_gfx950_instantiations.cpp index 954620a06..85a11c0d4 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_gfx950_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_gfx950_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #if !__HIP_DEVICE_COMPILE__ || defined(__gfx950__) #include "../ck_grouped_gemm_fp8_gfx950_impl.h" @@ -5,14 +11,14 @@ namespace transformer_engine { namespace grouped_gemm { // FP8_E5M2 * FP8_E4M3 = BF16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, 
ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm } // namespace transformer_engine -#endif \ No newline at end of file +#endif diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_gfx942_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_gfx942_instantiations.cpp index 6a1e6811b..0b55d6d76 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_gfx942_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_gfx942_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #if !__HIP_DEVICE_COMPILE__ || defined(__gfx942__) #include "../ck_grouped_gemm_fp8_gfx942_impl.h" @@ -5,14 +11,14 @@ namespace transformer_engine { namespace grouped_gemm { // FP8_E5M2 * FP8_E4M3 = F16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX942_256x256x128_32x32x16_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_32x32x16_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_32x32x16_2x2x1_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX942_256x256x128_32x32x16_2x2x1, 
ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_32x32x16_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_32x32x16_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm } // namespace transformer_engine -#endif \ No newline at end of file +#endif diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_gfx950_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_gfx950_instantiations.cpp index 42d3d39ad..a7910de75 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_gfx950_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_gfx950_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #if !__HIP_DEVICE_COMPILE__ || defined(__gfx950__) #include "../ck_grouped_gemm_fp8_gfx950_impl.h" @@ -5,14 +11,14 @@ namespace transformer_engine { namespace grouped_gemm { // FP8_E5M2 * FP8_E4M3 = F16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, 
TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm } // namespace transformer_engine -#endif \ No newline at end of file +#endif diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_gfx942_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_gfx942_instantiations.cpp index 66ec6cc80..8bc77b78c 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_gfx942_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_gfx942_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #if !__HIP_DEVICE_COMPILE__ || defined(__gfx942__) #include "../ck_grouped_gemm_fp8_gfx942_impl.h" @@ -5,14 +11,14 @@ namespace transformer_engine { namespace grouped_gemm { // FP8_E5M2 * FP8_E4M3 = F32 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX942_256x256x128_32x32x16_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX942_256x128x128_32x32x16_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX942_256x128x128_32x32x16_2x2x1_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x256x128_k16, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_k16_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX942_256x256x128_32x32x16_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, 
ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX942_256x128x128_32x32x16_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX942_256x128x128_32x32x16_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm } // namespace transformer_engine -#endif \ No newline at end of file +#endif diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_gfx950_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_gfx950_instantiations.cpp index 586e5d84c..8a2c68be1 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_gfx950_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_gfx950_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #if !__HIP_DEVICE_COMPILE__ || defined(__gfx950__) #include "../ck_grouped_gemm_fp8_gfx950_impl.h" @@ -5,14 +11,14 @@ namespace transformer_engine { namespace grouped_gemm { // FP8_E5M2 * FP8_E4M3 = F32 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::atomic_add) 
+APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::bf8_t, ck_tile::fp8_t, float, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm } // namespace transformer_engine -#endif \ No newline at end of file +#endif diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp16_fp16_bf16_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp16_fp16_bf16_instantiations.cpp index f40a62c38..258b6525a 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp16_fp16_bf16_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp16_fp16_bf16_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #include "../ck_grouped_gemm_fp16_impl.h" namespace transformer_engine { @@ -13,4 +19,4 @@ APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_t APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::bfloat16_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp16_fp16_fp16_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp16_fp16_fp16_instantiations.cpp index 0315b68ee..6f7076eea 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp16_fp16_fp16_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp16_fp16_fp16_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #include "../ck_grouped_gemm_fp16_impl.h" namespace transformer_engine { @@ -13,4 +19,4 @@ APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_t APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp16_fp16_fp32_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp16_fp16_fp32_instantiations.cpp index 9aaf6e9f4..b3adc5223 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp16_fp16_fp32_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp16_fp16_fp32_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #include "../ck_grouped_gemm_fp16_impl.h" namespace transformer_engine { @@ -13,4 +19,4 @@ APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, floa APPLY_CK_GG_ALL_LAYOUT(DECL_CK_GG_RUNNER, ck_tile::half_t, ck_tile::half_t, float, TileCfg_256x128x64_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm -} // namespace transformer_engine \ No newline at end of file +} // namespace transformer_engine diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_gfx942_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_gfx942_instantiations.cpp index 9c81ba606..ac51d28dd 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_gfx942_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_gfx942_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #if !__HIP_DEVICE_COMPILE__ || defined(__gfx942__) #include "../ck_grouped_gemm_fp8_gfx942_impl.h" @@ -5,14 +11,14 @@ namespace transformer_engine { namespace grouped_gemm { // FP8_E4M3 * FP8_E5M2 = BF16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, 
ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm } // namespace transformer_engine -#endif \ No newline at end of file +#endif diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_gfx950_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_gfx950_instantiations.cpp index d85a7ca9a..52e08c866 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_gfx950_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_gfx950_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #if !__HIP_DEVICE_COMPILE__ || defined(__gfx950__) #include "../ck_grouped_gemm_fp8_gfx950_impl.h" @@ -5,14 +11,14 @@ namespace transformer_engine { namespace grouped_gemm { // FP8_E4M3 * FP8_E5M2 = BF16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, 
ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm } // namespace transformer_engine -#endif \ No newline at end of file +#endif diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_gfx942_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_gfx942_instantiations.cpp index 4c8e77e27..40887be68 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_gfx942_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_gfx942_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #if !__HIP_DEVICE_COMPILE__ || defined(__gfx942__) #include "../ck_grouped_gemm_fp8_gfx942_impl.h" @@ -5,14 +11,14 @@ namespace transformer_engine { namespace grouped_gemm { // FP8_E4M3 * FP8_E5M2 = F16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) 
+APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm } // namespace transformer_engine -#endif \ No newline at end of file +#endif diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_gfx950_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_gfx950_instantiations.cpp index 89ca90b4a..c09eae062 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_gfx950_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_gfx950_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #if !__HIP_DEVICE_COMPILE__ || defined(__gfx950__) #include "../ck_grouped_gemm_fp8_gfx950_impl.h" @@ -5,14 +11,14 @@ namespace transformer_engine { namespace grouped_gemm { // FP8_E4M3 * FP8_E5M2 = F16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, 
TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm } // namespace transformer_engine -#endif \ No newline at end of file +#endif diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_gfx942_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_gfx942_instantiations.cpp index eea2ae44f..d4d4e5cda 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_gfx942_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_gfx942_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #if !__HIP_DEVICE_COMPILE__ || defined(__gfx942__) #include "../ck_grouped_gemm_fp8_gfx942_impl.h" @@ -5,14 +11,14 @@ namespace transformer_engine { namespace grouped_gemm { // FP8_E4M3 * FP8_E5M2 = F32 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, 
TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm } // namespace transformer_engine -#endif \ No newline at end of file +#endif diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_gfx950_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_gfx950_instantiations.cpp index f85c99856..a2e20a482 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_gfx950_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_gfx950_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #if !__HIP_DEVICE_COMPILE__ || defined(__gfx950__) #include "../ck_grouped_gemm_fp8_gfx950_impl.h" @@ -5,14 +11,14 @@ namespace transformer_engine { namespace grouped_gemm { // FP8_E4M3 * FP8_E5M2 = F32 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::atomic_add) 
+APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::bf8_t, float, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm } // namespace transformer_engine -#endif \ No newline at end of file +#endif diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_gfx942_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_gfx942_instantiations.cpp index fecc93061..6e659eb98 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_gfx942_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_gfx942_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #if !__HIP_DEVICE_COMPILE__ || defined(__gfx942__) #include "../ck_grouped_gemm_fp8_gfx942_impl.h" @@ -5,14 +11,14 @@ namespace transformer_engine { namespace grouped_gemm { // FP8_E4M3 * FP8_E4M3 = BF16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, 
ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm } // namespace transformer_engine -#endif \ No newline at end of file +#endif diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_gfx950_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_gfx950_instantiations.cpp index 0a37f32d8..5c0a68ade 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_gfx950_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_gfx950_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #if !__HIP_DEVICE_COMPILE__ || defined(__gfx950__) #include "../ck_grouped_gemm_fp8_gfx950_impl.h" @@ -5,14 +11,14 @@ namespace transformer_engine { namespace grouped_gemm { // FP8_E4M3 * FP8_E4M3 = BF16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, 
ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::bfloat16_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm } // namespace transformer_engine -#endif \ No newline at end of file +#endif diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_gfx942_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_gfx942_instantiations.cpp index c626f4379..aa989d2f1 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_gfx942_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_gfx942_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #if !__HIP_DEVICE_COMPILE__ || defined(__gfx942__) #include "../ck_grouped_gemm_fp8_gfx942_impl.h" @@ -5,14 +11,14 @@ namespace transformer_engine { namespace grouped_gemm { // FP8_E4M3 * FP8_E4M3 = FP16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) 
+APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm } // namespace transformer_engine -#endif \ No newline at end of file +#endif diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_gfx950_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_gfx950_instantiations.cpp index 01b33b8de..4201a7964 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_gfx950_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_gfx950_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #if !__HIP_DEVICE_COMPILE__ || defined(__gfx950__) #include "../ck_grouped_gemm_fp8_gfx950_impl.h" @@ -5,14 +11,14 @@ namespace transformer_engine { namespace grouped_gemm { // FP8_E4M3 * FP8_E4M3 = FP16 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, 
TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm } // namespace transformer_engine -#endif \ No newline at end of file +#endif diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_gfx942_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_gfx942_instantiations.cpp index bbe64fd0e..9ab39d463 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_gfx942_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_gfx942_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #if !__HIP_DEVICE_COMPILE__ || defined(__gfx942__) #include "../ck_grouped_gemm_fp8_gfx942_impl.h" @@ -5,14 +11,14 @@ namespace transformer_engine { namespace grouped_gemm { // FP8_E4M3 * FP8_E4M3 = FP32 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX942_256x256x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, 
TileCfg_GFX942_256x128x128_32x32x32_2x2x1, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX942_256x128x128_padding, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm } // namespace transformer_engine -#endif \ No newline at end of file +#endif diff --git a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_gfx950_instantiations.cpp b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_gfx950_instantiations.cpp index 53a589524..4459e5202 100644 --- a/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_gfx950_instantiations.cpp +++ b/transformer_engine/common/gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_gfx950_instantiations.cpp @@ -1,3 +1,9 @@ +/************************************************************************* + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. 
See LICENSE for more information + ************************************************************************/ + #if !__HIP_DEVICE_COMPILE__ || defined(__gfx950__) #include "../ck_grouped_gemm_fp8_gfx950_impl.h" @@ -5,14 +11,14 @@ namespace transformer_engine { namespace grouped_gemm { // FP8_E4M3 * FP8_E4M3 = FP32 -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::set) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::set) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_padding, ck_tile::memory_operation_enum::atomic_add) -APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_128x128x128, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1, ck_tile::memory_operation_enum::atomic_add) 
+APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_256x256x128_16x16x128_2x2x1_padding, ck_tile::memory_operation_enum::atomic_add) +APPLY_CK_GG_ALL_LAYOUT(DECL_CK_QUANT_GG_RUNNER, ck_tile::fp8_t, ck_tile::fp8_t, float, TileCfg_GFX950_128x128x128_32x32x64_2x2x1, ck_tile::memory_operation_enum::atomic_add) } // namespace grouped_gemm } // namespace transformer_engine -#endif \ No newline at end of file +#endif From 670d2c402c8369ad22b4e7f95cb3b1891392d81d Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Thu, 19 Mar 2026 17:45:50 +0000 Subject: [PATCH 51/51] Fix cudnn-frontend submodule pointer after dev merge --- 3rdparty/cudnn-frontend | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index deda80e53..be6c079be 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit deda80e5372d50e925d7bf4f76c5db779be3fbd5 +Subproject commit be6c079be8aaffa0fc079fcf039887e637c289c7