97 changes: 97 additions & 0 deletions csrc/apis/gemm.hpp
@@ -140,6 +140,92 @@ static void fp8_fp4_gemm_tt(const std::pair<torch::Tensor, torch::Tensor>& a,
d, c, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast);
}

static void fp8_gemm_nt_mxfp8out(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const torch::Tensor& d_sf,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
// Shape must be `[M, K] @ [N, K].T` with MXFP8 output
const auto& major_a = get_major_type_ab(a.first);
const auto& major_b = get_major_type_ab(b.first);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);

// C/D checks
DG_HOST_ASSERT(d.scalar_type() == torch::kFloat8_e4m3fn);
DG_HOST_ASSERT(d_sf.scalar_type() == torch::kUInt8 or d_sf.scalar_type() == torch::kFloat8_e4m3fn);
check_major_type_cd(d);

// Type and shape checks
const auto arch_major = device_runtime->get_arch_major();
const auto [m , k ] = check_ab_fp8_fp4(a.first, major_a, arch_major);
const auto [n , k_] = check_ab_fp8_fp4(b.first, major_b, arch_major);
const auto [m_, n_] = get_shape<2>(d);
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
DG_HOST_ASSERT(n % 32 == 0);

// Scale shape check
const auto [sm, sn] = get_shape<2>(d_sf);
DG_HOST_ASSERT(sm == m and sn == n / 32);
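// i.e. one output scale per 1x32 block of D along N, stored as an [M, N/32] matrix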

// Do nothing if empty
if (m == 0 or n == 0 or k == 0)
return;

// Transform SFA and SFB into compute-required layout
std::optional<std::tuple<int, int, int>> recipe = std::nullopt;
const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
a.second, b.second, m, n, k, recipe, std::nullopt, std::nullopt, std::nullopt, std::nullopt, disable_ue8m0_cast);
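// After this transform, SFA/SFB are float tensors on SM90 and packed-int tensors on SM100
// (presumably four UE8M0 bytes per int32); the dispatch below keys on that dtype.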

// Dispatch
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
sm90_fp8_gemm_1d1d_mxfp8out(a.first, sfa, b.first, sfb, d, d_sf, m, n, k, major_a, major_b, compiled_dims);
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
sm100_fp8_gemm_1d1d_mxfp8out(a.first, sfa, b.first, sfb, d, d_sf, m, n, k, gran_k_a, gran_k_b, major_a, major_b, compiled_dims);
} else {
DG_HOST_UNREACHABLE("MXFP8 output requires SM90 or SM100");
}
}
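// Usage sketch (hedged; not part of this diff). `a`/`sfa` and `b`/`sfb` are the FP8
// inputs with their per-block scales, whose shapes follow the library's existing FP8
// recipes (an assumption here); only the output buffers are specific to this path:
//   auto opts = torch::TensorOptions().device(torch::kCUDA);
//   auto d    = torch::empty({m, n},      opts.dtype(torch::kFloat8_e4m3fn));
//   auto d_sf = torch::empty({m, n / 32}, opts.dtype(torch::kUInt8));
//   fp8_gemm_nt_mxfp8out({a, sfa}, {b, sfb}, d, d_sf, /*compiled_dims=*/"nk",
//                        /*disable_ue8m0_cast=*/false);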

static void m_grouped_fp8_gemm_nt_contiguous_mxfp8out(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const torch::Tensor& d_sf,
const torch::Tensor& grouped_layout,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
const auto& major_a = get_major_type_ab(a.first);
const auto& major_b = get_major_type_ab(b.first);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
DG_HOST_ASSERT(grouped_layout.is_contiguous());
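// `grouped_layout` is assumed to follow the existing contiguous-grouped convention:
// an int32 [M] tensor mapping each row of A to its group index.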

const auto arch_major = device_runtime->get_arch_major();
const auto [m , k ] = check_ab_fp8_fp4(a.first, major_a, arch_major);
const auto [num_groups, n, k_] = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major);
const auto [m_, n_] = get_shape<2>(d);
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
DG_HOST_ASSERT(n % 32 == 0);
DG_HOST_ASSERT(d.scalar_type() == torch::kFloat8_e4m3fn);

const auto [sm, sn] = get_shape<2>(d_sf);
DG_HOST_ASSERT(sm == m and sn == n / 32);

check_major_type_cd(d);
if (m == 0) return;

std::optional<std::tuple<int, int, int>> recipe = std::nullopt;
const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
a.second, b.second, m, n, k, recipe, std::nullopt, std::nullopt, std::nullopt, num_groups, disable_ue8m0_cast);

if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
sm100_m_grouped_fp8_gemm_contiguous_1d1d_mxfp8out(a.first, sfa, b.first, sfb, d, d_sf, grouped_layout,
num_groups, m, n, k, gran_k_a, gran_k_b, major_a, major_b, compiled_dims);
} else {
DG_HOST_UNREACHABLE("m_grouped MXFP8 output requires SM100");
}
}

static void m_grouped_fp8_fp4_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
@@ -652,6 +738,17 @@ static void register_apis(pybind11::module_& m) {
py::arg("recipe") = std::make_tuple(1, 1, 128),
py::arg("compiled_dims") = "mn");

// MXFP8 output GEMM
m.def("fp8_gemm_nt_mxfp8out", &fp8_gemm_nt_mxfp8out,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("d_sf"),
py::arg("compiled_dims") = "nk",
py::arg("disable_ue8m0_cast") = false);

m.def("m_grouped_fp8_gemm_nt_contiguous_mxfp8out", &m_grouped_fp8_gemm_nt_contiguous_mxfp8out,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("d_sf"), py::arg("grouped_layout"),
py::arg("compiled_dims") = "nk",
py::arg("disable_ue8m0_cast") = false);

// FP8 GEMM alias names
m.attr("fp8_gemm_nt") = m.attr("fp8_fp4_gemm_nt");
m.attr("fp8_gemm_nn") = m.attr("fp8_fp4_gemm_nn");
2 changes: 1 addition & 1 deletion csrc/jit_kernels/heuristics/common.hpp
@@ -162,7 +162,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
DG_HOST_ASSERT(a_dtype == torch::kFloat8_e4m3fn or a_dtype == kPackedFP4);
DG_HOST_ASSERT(b_dtype == torch::kFloat8_e4m3fn or b_dtype == kPackedFP4);
}
- DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat);
+ DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat or cd_dtype == torch::kFloat8_e4m3fn);
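// (Rationale: FP8 E4M3 becomes a legal C/D dtype here, as required by the new MXFP8-output kernels.)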

// Select M/N block sizes
auto block_ms = ArchSpec::get_block_m_candidates(kernel_type, major_a, m);
190 changes: 184 additions & 6 deletions csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp
@@ -32,6 +32,11 @@ class SM100FP8FP4Gemm1D1DRuntime final: public LaunchRuntime<SM100FP8FP4Gemm1D1D
CUtensorMap tensor_map_sfa;
CUtensorMap tensor_map_sfb;
CUtensorMap tensor_map_cd;
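// MXFP8-output epilogue arguments: a TMA descriptor for the output scales, plus raw
// global-memory pointers and row strides (in elements) for the FP8 result and its scales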
CUtensorMap tensor_map_cd_sf;
void* gmem_cd_sf_ptr;
uint32_t cd_sf_stride;
void* gmem_cd_fp8_ptr;
uint32_t cd_fp8_stride;
};

static std::string generate_impl(const Args& args) {
@@ -79,7 +84,9 @@ static void __instantiate_kernel() {{
args.grouped_layout, args.m, args.n, args.k,
args.tensor_map_a, args.tensor_map_b,
args.tensor_map_sfa, args.tensor_map_sfb,
- args.tensor_map_cd));
+ args.tensor_map_cd, args.tensor_map_cd_sf,
+ args.gmem_cd_sf_ptr, args.cd_sf_stride,
+ args.gmem_cd_fp8_ptr, args.cd_fp8_stride));
}
};

@@ -137,13 +144,94 @@ static void sm100_fp8_fp4_gemm_1d1d(const torch::Tensor& a, const torch::Tensor&
.tensor_map_b = tensor_map_b,
.tensor_map_sfa = tensor_map_sfa,
.tensor_map_sfb = tensor_map_sfb,
- .tensor_map_cd = tensor_map_cd
+ .tensor_map_cd = tensor_map_cd,
+ .tensor_map_cd_sf = {},
+ .gmem_cd_sf_ptr = nullptr,
+ .cd_sf_stride = 0,
+ .gmem_cd_fp8_ptr = nullptr,
+ .cd_fp8_stride = 0
};
const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_fp8_fp4_gemm_1d1d", code);
SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
}

static void sm100_fp8_gemm_1d1d_mxfp8out(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const torch::Tensor& d, const torch::Tensor& d_sf,
const int& m, const int& n, const int& k,
const int& gran_k_a, const int& gran_k_b,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
DG_HOST_ASSERT(d.scalar_type() == torch::kFloat8_e4m3fn);
DG_HOST_ASSERT(d_sf.scalar_type() == torch::kUInt8 or d_sf.scalar_type() == torch::kFloat8_e4m3fn);
DG_HOST_ASSERT(n % 32 == 0);

// Select a config using BF16 output as a proxy; only `cd_dtype` is overridden for MXFP8 below
auto config = get_best_config<SM100ArchSpec>(
GemmType::Normal, KernelType::Kernel1D1D,
m, n, k, 1, major_a, major_b,
a.scalar_type(), b.scalar_type(),
torch::kBFloat16, false,
device_runtime->get_num_sms());

// Keep the proxy config's swizzle settings as-is (no swizzle override), but
// ensure block_n >= 32 so each tile covers at least one 32-element MXFP8 scaling block.
if (config.block_n < 32) config.block_n = 32;

const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.block_m, gran_k_a, 1, 0);
const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
config.block_n, gran_k_b, 1, 0);

// TMA descriptor for FP8 output: box = STORE_BLOCK_M × 32, no swizzle
const auto& tensor_map_cd = make_tma_cd_desc(d, m, n,
SM100ArchSpec::get_cd_store_block_m(config.block_m),
32,
static_cast<int>(d.stride(-2)), 1, 0);
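// (The 32-wide box presumably matches the MXFP8 scaling granularity, so each stored row slice shares one output scale per row.)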

// Override cd_dtype
auto mxfp8_config = config;
mxfp8_config.cd_dtype = torch::kFloat8_e4m3fn;

const SM100FP8FP4Gemm1D1DRuntime::Args& args = {
.m = m, .n = n, .k = k,
.num_groups = 1,
.gran_k_a = gran_k_a,
.gran_k_b = gran_k_b,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_config = mxfp8_config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.grouped_layout = nullptr,
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_sfa = tensor_map_sfa,
.tensor_map_sfb = tensor_map_sfb,
.tensor_map_cd = tensor_map_cd,
.tensor_map_cd_sf = {},
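// The epilogue stores D and its scales through raw pointers; strides are row strides
// in elements (n for the FP8 data, n/32 for the scales).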
.gmem_cd_sf_ptr = d_sf.data_ptr(),
.cd_sf_stride = static_cast<uint32_t>(n / 32),
.gmem_cd_fp8_ptr = d.data_ptr(),
.cd_fp8_stride = static_cast<uint32_t>(n)
};
const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d_mxfp8out", code);
SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
}

static void sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const torch::Tensor& d,
@@ -206,13 +294,88 @@ static void sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(const torch::Tensor& a,
.tensor_map_b = tensor_map_b,
.tensor_map_sfa = tensor_map_sfa,
.tensor_map_sfb = tensor_map_sfb,
- .tensor_map_cd = tensor_map_cd
+ .tensor_map_cd = tensor_map_cd,
+ .tensor_map_cd_sf = {},
+ .gmem_cd_sf_ptr = nullptr,
+ .cd_sf_stride = 0,
+ .gmem_cd_fp8_ptr = nullptr,
+ .cd_fp8_stride = 0
};
const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d", code);
SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
}

static void sm100_m_grouped_fp8_gemm_contiguous_1d1d_mxfp8out(
const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const torch::Tensor& d, const torch::Tensor& d_sf,
const torch::Tensor& grouped_layout,
const int& num_groups, const int& m, const int& n, const int& k,
const int& gran_k_a, const int& gran_k_b,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
DG_HOST_ASSERT(d.scalar_type() == torch::kFloat8_e4m3fn);
DG_HOST_ASSERT(n % 32 == 0);

auto config = get_best_config<SM100ArchSpec>(
GemmType::MGroupedContiguous, KernelType::Kernel1D1D,
m, n, k, 1, major_a, major_b,
a.scalar_type(), b.scalar_type(),
torch::kBFloat16, false,
device_runtime->get_num_sms());

if (config.block_n < 32) config.block_n = 32;

// Override cd_dtype
config.cd_dtype = torch::kFloat8_e4m3fn;

const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.block_m, gran_k_a, 1, 0);
const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
config.block_n, gran_k_b, num_groups, 0);

const SM100FP8FP4Gemm1D1DRuntime::Args& args = {
.m = m, .n = n, .k = k,
.num_groups = num_groups,
.gran_k_a = gran_k_a,
.gran_k_b = gran_k_b,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.grouped_layout = grouped_layout.data_ptr(),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_sfa = tensor_map_sfa,
.tensor_map_sfb = tensor_map_sfb,
.tensor_map_cd = make_tma_cd_desc(d, m, n,
SM100ArchSpec::get_cd_store_block_m(config.block_m),
32,
static_cast<int>(d.stride(-2)), 1, 0),
.tensor_map_cd_sf = {},
.gmem_cd_sf_ptr = d_sf.data_ptr(),
.cd_sf_stride = static_cast<uint32_t>(n / 32),
.gmem_cd_fp8_ptr = d.data_ptr(),
.cd_fp8_stride = static_cast<uint32_t>(n)
};
const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_m_grouped_fp8_gemm_contiguous_1d1d_mxfp8out", code);
SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
}

static void sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const torch::Tensor& d,
@@ -267,7 +430,12 @@ static void sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(const torch::Tensor& a, con
.tensor_map_b = tensor_map_b,
.tensor_map_sfa = tensor_map_sfa,
.tensor_map_sfb = tensor_map_sfb,
- .tensor_map_cd = tensor_map_cd
+ .tensor_map_cd = tensor_map_cd,
+ .tensor_map_cd_sf = {},
+ .gmem_cd_sf_ptr = nullptr,
+ .cd_sf_stride = 0,
+ .gmem_cd_fp8_ptr = nullptr,
+ .cd_fp8_stride = 0
};
const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_masked_1d1d", code);
@@ -338,7 +506,12 @@ static void sm100_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::T
.tensor_map_b = tensor_map_b,
.tensor_map_sfa = tensor_map_sfa,
.tensor_map_sfb = tensor_map_sfb,
- .tensor_map_cd = tensor_map_cd
+ .tensor_map_cd = tensor_map_cd,
+ .tensor_map_cd_sf = {},
+ .gmem_cd_sf_ptr = nullptr,
+ .cd_sf_stride = 0,
+ .gmem_cd_fp8_ptr = nullptr,
+ .cd_fp8_stride = 0
};
const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_k_grouped_fp8_gemm_1d1d", code);
@@ -406,7 +579,12 @@ static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
.tensor_map_b = tensor_map_b,
.tensor_map_sfa = tensor_map_sfa,
.tensor_map_sfb = tensor_map_sfb,
- .tensor_map_cd = tensor_map_cd
+ .tensor_map_cd = tensor_map_cd,
+ .tensor_map_cd_sf = {},
+ .gmem_cd_sf_ptr = nullptr,
+ .cd_sf_stride = 0,
+ .gmem_cd_fp8_ptr = nullptr,
+ .cd_fp8_stride = 0
};
const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code);