Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
ad748da
GEMM reference HIP implementation
matthiasdiener Dec 9, 2025
11e090b
blockwise amax
matthiasdiener Dec 11, 2025
9006224
Merge branch 'dev' into compute-ref-offload
matthiasdiener Dec 18, 2025
3ecea7f
Change to use Tensor arguments, combine mxfp8/non-mxfp8 paths
matthiasdiener Jan 13, 2026
cafee59
Merge remote-tracking branch 'origin/dev' into compute-ref-offload
matthiasdiener Jan 14, 2026
86fbbac
skip on SwizzleScale limitation on gfx950
matthiasdiener Jan 14, 2026
54de3db
Revert "skip on SwizzleScale limitation on gfx950"
matthiasdiener Jan 14, 2026
311ddfe
MXFP8 fix
matthiasdiener Jan 14, 2026
306e432
Merge remote-tracking branch 'origin/dev' into compute-ref-offload
matthiasdiener Jan 15, 2026
445e64f
correct scale_inv packing and exp2(biased−127) conversion
matthiasdiener Jan 15, 2026
462945f
cleanups
matthiasdiener Jan 15, 2026
e32fb3d
Merge branch 'dev' into compute-ref-offload
matthiasdiener Jan 19, 2026
7bf8adb
Merge remote-tracking branch 'origin/dev' into compute-ref-offload
matthiasdiener Jan 22, 2026
e11e400
use Tensor class for more device objects
matthiasdiener Jan 22, 2026
325ece6
Pass D Tensor into run_reference and move RefD allocation into Perfor…
matthiasdiener Jan 23, 2026
fc64b8c
[WIP] proof-of-concept: grouped GEMM with ck_tile
matthiasdiener Jan 26, 2026
134b350
Merge branch 'dev' into ck-grouped-gemm
matthiasdiener Jan 28, 2026
9091e6c
restructure and enable tests
matthiasdiener Jan 29, 2026
7435062
Merge remote-tracking branch 'origin/dev' into ck-grouped-gemm
matthiasdiener Jan 29, 2026
a00a1c8
Merge remote-tracking branch 'origin/dev' into ck-grouped-gemm
matthiasdiener Jan 30, 2026
4e9ead9
grid improvements
matthiasdiener Jan 30, 2026
259645c
restructure
matthiasdiener Feb 3, 2026
9986bd4
reduce code duplication & simplify
matthiasdiener Feb 4, 2026
355ec2f
make the code more similar to nv, check empty gelu/bias
matthiasdiener Feb 4, 2026
df5e3ea
Merge branch 'dev' into ck-grouped-gemm
matthiasdiener Feb 4, 2026
a42f7ca
further simplify & make closer to nv
matthiasdiener Feb 4, 2026
fac7c11
add ck_tile reference
matthiasdiener Feb 4, 2026
71b97e0
rename in error messages
matthiasdiener Feb 4, 2026
dd3ed2f
allow flattened higher-D tensors
matthiasdiener Feb 4, 2026
7b0413e
Merge remote-tracking branch 'origin/dev' into ck-grouped-gemm
matthiasdiener Feb 5, 2026
ebc005f
relax tolerance on gfx942
matthiasdiener Feb 5, 2026
c0bf502
enable more tests
matthiasdiener Feb 5, 2026
0b16287
return early when num_gemms<=0
matthiasdiener Feb 5, 2026
58b34e7
simplify normalization
matthiasdiener Feb 5, 2026
74f229a
Merge remote-tracking branch 'origin/dev' into ck-grouped-gemm
matthiasdiener Feb 10, 2026
e28c801
run hipblaslt for num_gemms==1
matthiasdiener Feb 11, 2026
6151b96
Merge remote-tracking branch 'origin/dev' into ck-grouped-gemm
matthiasdiener Feb 12, 2026
5c57d47
disable ck_tile when accumulate=true
matthiasdiener Feb 17, 2026
29d6ab7
Merge remote-tracking branch 'origin/dev' into ck-grouped-gemm
matthiasdiener Feb 17, 2026
6e9aae4
Merge remote-tracking branch 'origin/dev' into ck-grouped-gemm
matthiasdiener Feb 17, 2026
2e844d9
remove test file
matthiasdiener Feb 17, 2026
4aa8229
Merge branch 'dev' into ck-grouped-gemm
matthiasdiener Feb 23, 2026
f680d6a
fix copyright header
matthiasdiener Feb 23, 2026
6d85088
simplify calls in dispatch_grouped
matthiasdiener Feb 23, 2026
7910038
remove is_mi3*0_class
matthiasdiener Feb 23, 2026
e8ebb0e
disable unused constants
matthiasdiener Feb 23, 2026
deb7474
Merge remote-tracking branch 'origin/dev' into ck-grouped-gemm
matthiasdiener Feb 24, 2026
e866bc6
add another fallback
matthiasdiener Feb 24, 2026
ee438fb
implement Primus-Turbo selection logic, persistent descs
matthiasdiener Feb 25, 2026
b65dbfa
Merge remote-tracking branch 'origin/dev' into ck-grouped-gemm
matthiasdiener Feb 25, 2026
0cbf1cd
tighten tolerances
matthiasdiener Feb 25, 2026
98e0c66
use namespace, various cleanups
matthiasdiener Feb 25, 2026
36bd68e
avoid creating vector with Tensors
matthiasdiener Feb 26, 2026
070c58d
Merge remote-tracking branch 'origin/dev' into ck-grouped-gemm
matthiasdiener Feb 26, 2026
c5d83a4
merge dispatch_grouped into ck_tile_grouped_gemm
matthiasdiener Feb 26, 2026
56afb04
Merge remote-tracking branch 'origin/dev' into ck-grouped-gemm
matthiasdiener Feb 27, 2026
26dfbb6
same tolerances for gfx950
matthiasdiener Feb 27, 2026
5a7eb69
feat(gemm): enable TensorQuant pipeline for FP8 on GFX942
matthiasdiener Feb 27, 2026
54da682
Include Float8 E4M3/E5M2 in is_supported_dtype and remove float8 from…
aris134 Mar 4, 2026
78a702f
forward pass ck_tile with matching FP8 data type inputs passing accur…
aris134 Mar 4, 2026
f198341
Support mixed FP8/BF8 grouped GEMM in CK backward path
aris134 Mar 6, 2026
6b24be2
include more descriptive comment regarding tensor normalization in ck…
aris134 Mar 6, 2026
a161c20
Refactor CK grouped GEMM: split FP8/FP16 implementations and introduc…
aris134 Mar 9, 2026
af95382
Add explicit template instantiations for CK grouped GEMM runners
aris134 Mar 9, 2026
9990db3
Split CK grouped GEMM implementation to reduce compile-time coupling
aris134 Mar 10, 2026
bff80fe
Split CK grouped GEMM explicit instantiations by operand dtype
aris134 Mar 10, 2026
e9cd6b8
Add runtime architecture dispatch for CK FP8 grouped GEMM (gfx942/gfx…
aris134 Mar 11, 2026
32f2ac3
Merge dev and fix FP8 logic/CMake
aris134 Mar 11, 2026
5275ac6
Fix dev merge conflicts in CMakeLists.txt
aris134 Mar 11, 2026
3940748
add copyright headers and add blank line at bottom of each new file. …
aris134 Mar 12, 2026
76f207a
Merge branch 'dev' into amartin/ck-grouped-gemm-fp8
aris134 Mar 19, 2026
670d2c4
Fix cudnn-frontend submodule pointer after dev merge
aris134 Mar 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions buildit.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#!/bin/bash
# Build and install Transformer Engine for ROCm (PyTorch framework, gfx950).
#
# Fail fast: without this, a failed dependency install would silently fall
# through to the main `pip install` and produce a confusing partial build.
set -euo pipefail

export NVTE_FRAMEWORK=pytorch
export NVTE_ROCM_ARCH=gfx950
export NVTE_USE_ROCM=1
export CU_NUM=256  # NOTE(review): consumer of CU_NUM not visible here -- presumably parallel compile jobs; confirm

# Build-time dependencies for the no-build-isolation install below.
pip install -U ninja psutil pybind11

# Use prebuilt AITER artifacts instead of building them from source.
export NVTE_AITER_PREBUILT_BASE_URL=https://compute-artifactory.amd.com:5000/artifactory/rocm-generic-local/te-ci/aiter-prebuilts

# Editable, verbose install of this repo using the already-installed deps.
pip install -ve . --no-build-isolation

10 changes: 8 additions & 2 deletions tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2167,17 +2167,23 @@ def test_grouped_linear_accuracy(
torch.cuda.get_device_capability() != (9, 0) and not IS_HIP_EXTENSION,
reason="Only enable CUTLASS grouped gemm on Hopper",
)
#@pytest.mark.parametrize("dtype", param_types, ids=str)

@pytest.mark.parametrize("dtype", param_types, ids=str)
@pytest.mark.parametrize("num_gemms", [3, 6])
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
@pytest.mark.parametrize("fp8_model_params", all_boolean)
@pytest.mark.parametrize("recipe", fp8_recipes + [None])
@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean)
@pytest.mark.parametrize("delay_wgrad_compute", all_boolean)
def test_grouped_linear_accuracy_cutlass(
dtype,
num_gemms,
bs,
model,
recipe,
fp8_model_params,
fuse_wgrad_accumulation,
delay_wgrad_compute,
):
Expand All @@ -2187,8 +2193,8 @@ def test_grouped_linear_accuracy_cutlass(
num_gemms,
bs,
model,
None,
False,
recipe,
fp8_model_params,
fuse_wgrad_accumulation,
False,
delay_wgrad_compute,
Expand Down
36 changes: 36 additions & 0 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,42 @@ else()
fused_attn_rocm/fused_attn.cpp
gemm/rocm_gemm.cu
gemm/ck_grouped_gemm.cpp
gemm/ck_grouped_gemm_fp8.cpp
gemm/ck_grouped_gemm_fp8_factory_common.cpp
gemm/ck_grouped_gemm_fp8_factory_gfx942.cpp
gemm/ck_grouped_gemm_fp8_factory_gfx950.cpp
gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_gfx942_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_gfx942_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_gfx942_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_gfx942_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_gfx942_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_gfx942_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_gfx942_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_gfx942_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_gfx942_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_gfx942_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_gfx942_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_gfx942_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp16_gfx950_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_fp8_fp8_fp32_gfx950_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_fp8_fp8_bf16_gfx950_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp16_gfx950_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_bf8_bf8_fp32_gfx950_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_bf8_bf8_bf16_gfx950_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp16_gfx950_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_bf8_fp8_fp32_gfx950_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_bf8_fp8_bf16_gfx950_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp16_gfx950_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_fp8_bf8_fp32_gfx950_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_fp8_bf8_bf16_gfx950_instantiations.cpp
gemm/ck_grouped_gemm_fp16.cpp
gemm/ck_grouped_gemm_fp16_factory.cpp
gemm/instantiations/ck_grouped_gemm_bf16_bf16_bf16_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_bf16_bf16_fp16_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_bf16_bf16_fp32_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_fp16_fp16_bf16_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_fp16_fp16_fp16_instantiations.cpp
gemm/instantiations/ck_grouped_gemm_fp16_fp16_fp32_instantiations.cpp
amd_detail/system.cpp)
list(APPEND transformer_engine_cuda_sources
fused_attn_rocm/fused_attn_aotriton.cpp
Expand Down
278 changes: 19 additions & 259 deletions transformer_engine/common/gemm/ck_grouped_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,224 +4,8 @@
* License for AMD contributions = MIT. See LICENSE for more information
************************************************************************/

#include <hip/hip_runtime.h>

#include <transformer_engine/transformer_engine.h>
#include "../common.h"

#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"

namespace transformer_engine {
namespace grouped_gemm {

using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColMajor = ck_tile::tensor_layout::gemm::ColumnMajor;

// Compile-time map from Transformer Engine scalar types to the equivalent
// CK_Tile scalar types. Only fp16 and bf16 are mapped; instantiating the
// primary template with any other TE type is a compile error by design.
template <typename TEScalar> struct TETypeToCKType;
template <> struct TETypeToCKType<transformer_engine::fp16> { using type = ck_tile::half_t; };
template <> struct TETypeToCKType<transformer_engine::bf16> { using type = ck_tile::bfloat16_t; };

// View a TE tensor as a generalized 2D matrix by collapsing all leading axes:
// (D1, D2, ..., Dn) maps to (D1*...*D(n-1), Dn), matching TE's
// Tensor::flat_first_dim / flat_last_dim convention.
// Returns false (leaving d0/d1 untouched) when the tensor is below rank 2.
static inline bool get_flat_2d_dims(const transformer_engine::Tensor& t,
                                    int64_t& d0, int64_t& d1) {
  // Anything with fewer than two axes cannot be interpreted as a matrix.
  const bool matrix_like = t.shape().size() >= 2;
  if (matrix_like) {
    d0 = static_cast<int64_t>(t.flat_first_dim());
    d1 = static_cast<int64_t>(t.flat_last_dim());
  }
  return matrix_like;
}

// Accessor for the tensor's row-wise data view; exists so call sites read as
// intent ("give me the GEMM operand data") rather than a bare member access.
static inline const transformer_engine::SimpleTensor& data_view(const transformer_engine::Tensor& t) {
  return t.data; // rowwise data view
}

// Primus-Turbo-like FP16/BF16 tile configs.
// Selection rule (applied by the caller):
//   if (N % 256 == 0)       use 256x256x64
//   else if (N % 128 == 0)  use 256x128x64
//   else                    use 256x128x64 with N padding enabled
struct TileCfg_256x256x64 {
  // Block-level tile: each workgroup computes an M_Tile x N_Tile output tile,
  // stepping K_Tile elements along the reduction axis per iteration.
  static constexpr ck_tile::index_t M_Tile = 256;
  static constexpr ck_tile::index_t N_Tile = 256;
  static constexpr ck_tile::index_t K_Tile = 64;

  // Warp layout within the block tile (2x2 warps, no K-split across warps).
  static constexpr ck_tile::index_t M_Warp = 2;
  static constexpr ck_tile::index_t N_Warp = 2;
  static constexpr ck_tile::index_t K_Warp = 1;

  // Per-warp MFMA tile shape.
  static constexpr ck_tile::index_t M_Warp_Tile = 32;
  static constexpr ck_tile::index_t N_Warp_Tile = 32;
  static constexpr ck_tile::index_t K_Warp_Tile = 16;

  // No padding: this config is only selected when N divides evenly.
  static constexpr bool kPadM = false;
  static constexpr bool kPadN = false;
  static constexpr bool kPadK = false;

  static constexpr bool DoubleSmemBuffer = false;

  // Spatially-local tile partitioner parameters (see
  // ck_tile::GemmSpatiallyLocalTilePartitioner in Runner below).
  static constexpr ck_tile::index_t TilePartitionerGroupNum = 8;
  static constexpr ck_tile::index_t TilePartitionerM01 = 4;
};

// Same as above but with a half-width N tile, for N % 128 == 0.
struct TileCfg_256x128x64 : TileCfg_256x256x64 {
  static constexpr ck_tile::index_t N_Tile = 128;
};

// Fallback for arbitrary N: the 128-wide tile with N padding enabled.
struct TileCfg_256x128x64_padding : TileCfg_256x128x64 {
  static constexpr bool kPadN = true;
};

// This class instantiates CK_Tile's grouped GEMM pipeline: it wires a TileCfg
// and the operand types/layouts into the full chain of CK_Tile policy types
// (shape -> partitioner -> traits -> problem -> pipeline -> epilogue) and
// exposes the resulting kernel as Runner<...>::Kernel.
// See e.g. https://github.com/ROCm/composable_kernel/blob/develop/example/ck_tile/03_gemm/universal_gemm_invoker.hpp for reference.
template <typename AType, typename BType, typename CType,
          typename ALayout, typename BLayout, typename CLayout,
          typename TileCfg, ck_tile::memory_operation_enum MemOp,
          typename AccType = float>
struct Runner{
  // Block / warp / warp-tile shapes taken verbatim from the TileCfg.
  using GemmShape = ck_tile::TileGemmShape<
      ck_tile::sequence<TileCfg::M_Tile, TileCfg::N_Tile, TileCfg::K_Tile>,
      ck_tile::sequence<TileCfg::M_Warp, TileCfg::N_Warp, TileCfg::K_Warp>,
      ck_tile::sequence<TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile>>;

  // Maps workgroups to output tiles with spatial locality.
  using Partitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
      GemmShape, TileCfg::TilePartitionerGroupNum, TileCfg::TilePartitionerM01>;

  // Padding flags, smem buffering, and operand layouts for the persistent
  // (grid-resident) universal GEMM kernel.
  using UniversalTraits = ck_tile::PersistentTileGemmUniversalTraits<
      TileCfg::kPadM, TileCfg::kPadN, TileCfg::kPadK,
      TileCfg::DoubleSmemBuffer, ALayout, BLayout, CLayout>;

  static constexpr ck_tile::GemmPipelineScheduler Scheduler =
      ck_tile::GemmPipelineScheduler::Intrawave;

  using Problem = ck_tile::UniversalGemmPipelineProblem<
      AType, BType, AccType, GemmShape, UniversalTraits, Scheduler>;

  using Pipeline = ck_tile::GemmPipelineAgBgCrCompV3<Problem>;

  // CShuffle epilogue with a pass-through element-wise op; MemOp selects
  // plain store vs. atomic-add (the latter used for accumulate=true).
  using Epilogue = ck_tile::CShuffleEpilogue<
      ck_tile::CShuffleEpilogueProblem<
          AType, BType, ck_tile::tuple<>, AccType,
          CType, ck_tile::tuple<>, CLayout,
          ck_tile::element_wise::PassThrough,
          Partitioner::MPerBlock, Partitioner::NPerBlock,
          TileCfg::M_Warp, TileCfg::N_Warp,
          TileCfg::M_Warp_Tile, TileCfg::N_Warp_Tile, TileCfg::K_Warp_Tile,
          Problem::TransposeC, MemOp>>;

  using Kernel = ck_tile::GroupedGemmKernel<Partitioner, Pipeline, Epilogue>;
};

// Validates all groups, builds the host-side descriptor array, and launches a
// single CK_Tile grouped-GEMM kernel covering every (A_i, B_i) -> D_i problem.
// T is the common element type of A/B/D; MemOp selects set vs. atomic_add
// (the latter implements accumulate). Returns true on successful launch;
// invalid inputs hit NVTE_ERROR first.
template <typename T, typename ALayout, typename BLayout, typename CLayout,
          ck_tile::memory_operation_enum MemOp, typename TileCfg>
static bool run_grouped_impl(const NVTETensor* A_use,
                             const NVTETensor* B_use,
                             NVTETensor* D,
                             int group_num,
                             bool transA_use,
                             bool transB_use,
                             void* workspace,
                             size_t workspace_bytes,
                             hipStream_t stream)
{
  using Kernel = typename Runner<T, T, T, ALayout, BLayout, CLayout, TileCfg, MemOp>::Kernel;

  // The kernel reads its per-group argument array from device workspace;
  // verify the caller gave us enough room before doing any work.
  const size_t needed = Kernel::GetWorkSpaceSize(group_num);
  if (!workspace || workspace_bytes < needed) {
    NVTE_ERROR("ck_tile_grouped_gemm: insufficient workspace. Needed bytes=", needed);
    return false;
  }

  // thread_local so repeated calls reuse the vector's capacity
  // (clear() does not release it) without cross-thread contention.
  thread_local std::vector<ck_tile::GroupedGemmHostArgs<0>> descs;
  descs.clear();
  descs.reserve(group_num);

  for (int i = 0; i < group_num; ++i) {
    const transformer_engine::Tensor* const A_te =
        transformer_engine::convertNVTETensorCheck(A_use[i]);
    const transformer_engine::Tensor* const B_te =
        transformer_engine::convertNVTETensorCheck(B_use[i]);
    transformer_engine::Tensor* D_te =
        transformer_engine::convertNVTETensorCheck(D[i]);

    const auto& a = data_view(*A_te);
    const auto& b = data_view(*B_te);
    const auto& d = data_view(*D_te);

    // Flatten every operand to 2D; higher-rank tensors collapse their
    // leading axes (see get_flat_2d_dims).
    int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0, Dd0 = 0, Dd1 = 0;
    if (!get_flat_2d_dims(*A_te, Ad0, Ad1) ||
        !get_flat_2d_dims(*B_te, Bd0, Bd1) ||
        !get_flat_2d_dims(*D_te, Dd0, Dd1)) {
      NVTE_ERROR("ck_tile_grouped_gemm: expected all groups to be rank>=2 (2D or higher).");
      return false;
    }

    // Logical GEMM sizes after applying the per-operand transpose flags.
    const int64_t M = transA_use ? Ad1 : Ad0;
    const int64_t K = transA_use ? Ad0 : Ad1;
    const int64_t N = transB_use ? Bd0 : Bd1;
    const int64_t Kb = transB_use ? Bd1 : Bd0;

    if (Kb != K) {
      NVTE_ERROR("ck_tile_grouped_gemm: K mismatch between A and B in group ", i);
      return false;
    }

    if (Dd0 != M || Dd1 != N) {
      NVTE_ERROR("ck_tile_grouped_gemm: D shape mismatch in group ", i);
      return false;
    }

    // Leading dimensions under the flattened-contiguous interpretation:
    // each matrix is assumed dense row-major over its flattened 2D shape.
    const ck_tile::index_t stride_A = Ad1;
    const ck_tile::index_t stride_B = Bd1;
    const ck_tile::index_t stride_E = Dd1;

    // Empty std::array / tuple slots are the (unused) D-style auxiliary
    // operands of the CK_Tile host-args struct; k_batch is fixed at 1.
    descs.emplace_back(
        a.dptr,
        b.dptr,
        std::array<const void*, 0>{},
        d.dptr,
        1,
        M,
        N,
        K,
        stride_A,
        stride_B,
        std::array<ck_tile::index_t, 0>{},
        stride_E);
  }

  const dim3 grids = Kernel::GridSize(descs);
  auto kargs = Kernel::MakeKargs(descs);
  if (!Kernel::IsSupportedArgument(kargs)) {
    NVTE_ERROR("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config.");
    return false;
  }

  // Stage the kernel arguments into device workspace on the same stream the
  // kernel will run on, so the launch below observes the completed copy.
  // NOTE(review): kargs is a pageable host-local buffer; hipMemcpyAsync from
  // pageable memory typically completes before returning, but the API does
  // not guarantee it -- confirm kargs cannot be freed before the copy lands.
  HIP_CHECK_ERROR(hipMemcpyAsync(workspace,
                                 kargs.data(),
                                 kargs.size() * sizeof(typename decltype(kargs)::value_type),
                                 hipMemcpyHostToDevice,
                                 stream));

  const ck_tile::stream_config s{stream};
  const dim3 blocks = Kernel::BlockSize();

  // The kernel dereferences its argument block via the constant address
  // space; group_num tells it how many descriptors to process.
  ck_tile::launch_kernel(
      s,
      ck_tile::make_kernel<1>(
          Kernel{}, grids, blocks, 0,
          ck_tile::cast_pointer_to_constant_address_space(workspace),
          group_num));
  return true;
}

} // namespace grouped_gemm
} // namespace transformer_engine

#include "ck_grouped_gemm_common.h"
#include <iostream>
bool ck_tile_grouped_gemm(const NVTETensor* A,
const NVTETensor* B,
NVTETensor* D,
Expand All @@ -230,10 +14,10 @@ bool ck_tile_grouped_gemm(const NVTETensor* A,
bool transB,
NVTETensor* workspace,
bool accumulate,
hipStream_t stream)
{
if (group_num <= 0)
hipStream_t stream) {
if (group_num <= 0) {
return true;
}

using namespace transformer_engine;
using namespace transformer_engine::grouped_gemm;
Expand All @@ -256,52 +40,28 @@ bool ck_tile_grouped_gemm(const NVTETensor* A,
const bool transB_use = transA;

const auto a_dtype = convertNVTETensorCheck(A_use[0])->dtype();
const auto b_dtype = convertNVTETensorCheck(B_use[0])->dtype();

// Get N from D[0] (assume uniform N across groups)
int64_t ref_d0 = 0, ref_d1 = 0;
Tensor* D0_te = convertNVTETensorCheck(D[0]);
const auto d_dtype = D0_te->dtype();

int64_t ref_d0 = 0, ref_d1 = 0;
if (!get_flat_2d_dims(*D0_te, ref_d0, ref_d1)) {
NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for D[0]");
return false;
}
const ck_tile::index_t N = static_cast<ck_tile::index_t>(ref_d1);

TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(a_dtype, te_type, {
using T = typename TETypeToCKType<te_type>::type;

auto run_with_tilecfg = [&](auto tile_tag) -> bool {
using TileCfgSel = decltype(tile_tag);
// construct run context
GroupedGemmRunContext ctx = {A_use, B_use, D, ref_d1, group_num, transA_use, transB_use, accumulate, ws_ptr, ws_bytes, stream};

TRANSFORMER_ENGINE_SWITCH_CONDITION(transA_use, kTransA, {
using ALayout = std::conditional_t<kTransA, ColMajor, RowMajor>;

TRANSFORMER_ENGINE_SWITCH_CONDITION(transB_use, kTransB, {
using BLayout = std::conditional_t<kTransB, ColMajor, RowMajor>;
if (ck_tile_grouped_gemm_fp16_dispatch(a_dtype, b_dtype, d_dtype, ctx)) {
return true;
}

if (accumulate) {
return run_grouped_impl<T, ALayout, BLayout, RowMajor,
ck_tile::memory_operation_enum::atomic_add, TileCfgSel>(
A_use, B_use, D, group_num, kTransA, kTransB, ws_ptr, ws_bytes, stream);
} else {
return run_grouped_impl<T, ALayout, BLayout, RowMajor,
ck_tile::memory_operation_enum::set, TileCfgSel>(
A_use, B_use, D, group_num, kTransA, kTransB, ws_ptr, ws_bytes, stream);
}
});
});
};
if (ck_tile_grouped_gemm_fp8_dispatch(a_dtype, b_dtype, d_dtype, ctx)) {
return true;
}

// Select tile config like Primus-Turbo for FP16/BF16:
// N%256 -> 256x256x64
// N%128 -> 256x128x64
// else -> 256x128x64 padding
// NOTE: We assume N is uniform across groups.
if ((N % 256) == 0) {
return run_with_tilecfg(TileCfg_256x256x64{});
} else if ((N % 128) == 0) {
return run_with_tilecfg(TileCfg_256x128x64{});
} else {
return run_with_tilecfg(TileCfg_256x128x64_padding{});
}
});
NVTE_ERROR("ck_tile_grouped_gemm: unsupported dtype pair for CK path.");
return false;
}
Loading