diff --git a/csrc/apis/gemm.hpp b/csrc/apis/gemm.hpp index 2f7ce10..a958e52 100644 --- a/csrc/apis/gemm.hpp +++ b/csrc/apis/gemm.hpp @@ -140,6 +140,92 @@ static void fp8_fp4_gemm_tt(const std::pair& a, d, c, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast); } +static void fp8_gemm_nt_mxfp8out(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const torch::Tensor& d_sf, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + // Shape must be `[M, K] @ [N, K].T` with MXFP8 output + const auto& major_a = get_major_type_ab(a.first); + const auto& major_b = get_major_type_ab(b.first); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); + DG_HOST_ASSERT(major_b == cute::UMMA::Major::K); + + // C/D checks + DG_HOST_ASSERT(d.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(d_sf.scalar_type() == torch::kUInt8 or d_sf.scalar_type() == torch::kFloat8_e4m3fn); + check_major_type_cd(d); + + // Type and shape checks + const auto arch_major = device_runtime->get_arch_major(); + const auto [m , k ] = check_ab_fp8_fp4(a.first, major_a, arch_major); + const auto [n , k_] = check_ab_fp8_fp4(b.first, major_b, arch_major); + const auto [m_, n_] = get_shape<2>(d); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(n % 32 == 0); + + // Scale shape check + const auto [sm, sn] = get_shape<2>(d_sf); + DG_HOST_ASSERT(sm == m and sn == n / 32); + + // Do nothing if empty + if (m == 0 or n == 0 or k == 0) + return; + + // Transform SFA and SFB into compute-required layout + std::optional> recipe = std::nullopt; + const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout( + a.second, b.second, m, n, k, recipe, std::nullopt, std::nullopt, std::nullopt, std::nullopt, disable_ue8m0_cast); + + // Dispatch + if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { + sm90_fp8_gemm_1d1d_mxfp8out(a.first, sfa, b.first, sfb, d, d_sf, m, n, k, major_a, major_b, compiled_dims); + } else if 
(arch_major == 10 and sfa.scalar_type() == torch::kInt) { + sm100_fp8_gemm_1d1d_mxfp8out(a.first, sfa, b.first, sfb, d, d_sf, m, n, k, gran_k_a, gran_k_b, major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("MXFP8 output requires SM90 or SM100"); + } +} + +static void m_grouped_fp8_gemm_nt_contiguous_mxfp8out(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const torch::Tensor& d_sf, + const torch::Tensor& grouped_layout, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + const auto& major_a = get_major_type_ab(a.first); + const auto& major_b = get_major_type_ab(b.first); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); + DG_HOST_ASSERT(grouped_layout.is_contiguous()); + + const auto arch_major = device_runtime->get_arch_major(); + const auto [m , k ] = check_ab_fp8_fp4(a.first, major_a, arch_major); + const auto [num_groups, n, k_] = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major); + const auto [m_, n_] = get_shape<2>(d); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(n % 32 == 0); + DG_HOST_ASSERT(d.scalar_type() == torch::kFloat8_e4m3fn); + + const auto [sm, sn] = get_shape<2>(d_sf); + DG_HOST_ASSERT(sm == m and sn == n / 32); + + check_major_type_cd(d); + if (m == 0) return; + + std::optional> recipe = std::nullopt; + const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout( + a.second, b.second, m, n, k, recipe, std::nullopt, std::nullopt, std::nullopt, num_groups, disable_ue8m0_cast); + + if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { + sm100_m_grouped_fp8_gemm_contiguous_1d1d_mxfp8out(a.first, sfa, b.first, sfb, d, d_sf, grouped_layout, + num_groups, m, n, k, gran_k_a, gran_k_b, major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("m_grouped MXFP8 output requires SM100"); + } +} + static void m_grouped_fp8_fp4_gemm_nt_contiguous(const std::pair& a, const std::pair& b, const torch::Tensor& d, @@ -652,6 +738,17 
@@ static void register_apis(pybind11::module_& m) { py::arg("recipe") = std::make_tuple(1, 1, 128), py::arg("compiled_dims") = "mn"); + // MXFP8 output GEMM + m.def("fp8_gemm_nt_mxfp8out", &fp8_gemm_nt_mxfp8out, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("d_sf"), + py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false); + + m.def("m_grouped_fp8_gemm_nt_contiguous_mxfp8out", &m_grouped_fp8_gemm_nt_contiguous_mxfp8out, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("d_sf"), py::arg("grouped_layout"), + py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false); + // FP8 GEMM alias names m.attr("fp8_gemm_nt") = m.attr("fp8_fp4_gemm_nt"); m.attr("fp8_gemm_nn") = m.attr("fp8_fp4_gemm_nn"); diff --git a/csrc/jit_kernels/heuristics/common.hpp b/csrc/jit_kernels/heuristics/common.hpp index a49584f..badfc82 100644 --- a/csrc/jit_kernels/heuristics/common.hpp +++ b/csrc/jit_kernels/heuristics/common.hpp @@ -162,7 +162,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k DG_HOST_ASSERT(a_dtype == torch::kFloat8_e4m3fn or a_dtype == kPackedFP4); DG_HOST_ASSERT(b_dtype == torch::kFloat8_e4m3fn or b_dtype == kPackedFP4); } - DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat); + DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat or cd_dtype == torch::kFloat8_e4m3fn); // Select M/N block sizes auto block_ms = ArchSpec::get_block_m_candidates(kernel_type, major_a, m); diff --git a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp index 07a977d..316b500 100644 --- a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp +++ b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp @@ -32,6 +32,11 @@ class SM100FP8FP4Gemm1D1DRuntime final: public LaunchRuntimebuild("sm100_fp8_fp4_gemm_1d1d", code); SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); } +static void sm100_fp8_gemm_1d1d_mxfp8out(const torch::Tensor& a, const 
torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, const torch::Tensor& d_sf, + const int& m, const int& n, const int& k, + const int& gran_k_a, const int& gran_k_b, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(d.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(d_sf.scalar_type() == torch::kUInt8 or d_sf.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(n % 32 == 0); + + // Use BF16 as proxy for config selection, then override swizzle for MXFP8 + auto config = get_best_config( + GemmType::Normal, KernelType::Kernel1D1D, + m, n, k, 1, major_a, major_b, + a.scalar_type(), b.scalar_type(), + torch::kBFloat16, false, + device_runtime->get_num_sms()); + + // Keep original BF16 config as-is (no swizzle override). + // Ensure block_n >= 32 for MXFP8 block size. + if (config.block_n < 32) config.block_n = 32; + + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, gran_k_a, 1, 0); + const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.block_n, gran_k_b, 1, 0); + + // TMA descriptor for FP8 output: box = STORE_BLOCK_M × 32, no swizzle + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + 32, + static_cast(d.stride(-2)), 1, 0); + + // Override cd_dtype + auto mxfp8_config = config; + 
mxfp8_config.cd_dtype = torch::kFloat8_e4m3fn; + + const SM100FP8FP4Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = 1, + .gran_k_a = gran_k_a, + .gran_k_b = gran_k_b, + .compiled_dims = compiled_dims, + .epilogue_type = std::nullopt, + .gemm_config = mxfp8_config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = tensor_map_cd, + .tensor_map_cd_sf = {}, + .gmem_cd_sf_ptr = d_sf.data_ptr(), + .cd_sf_stride = static_cast(n / 32), + .gmem_cd_fp8_ptr = d.data_ptr(), + .cd_fp8_stride = static_cast(n) + }; + const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d_mxfp8out", code); + SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); +} + static void sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, const torch::Tensor& b, const torch::Tensor& sfb, const torch::Tensor& d, @@ -206,13 +294,88 @@ static void sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(const torch::Tensor& a, .tensor_map_b = tensor_map_b, .tensor_map_sfa = tensor_map_sfa, .tensor_map_sfb = tensor_map_sfb, - .tensor_map_cd = tensor_map_cd + .tensor_map_cd = tensor_map_cd, + .tensor_map_cd_sf = {}, + .gmem_cd_sf_ptr = nullptr, + .cd_sf_stride = 0, + .gmem_cd_fp8_ptr = nullptr, + .cd_fp8_stride = 0 }; const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d", code); SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); } +static void sm100_m_grouped_fp8_gemm_contiguous_1d1d_mxfp8out( + const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& 
d, const torch::Tensor& d_sf, + const torch::Tensor& grouped_layout, + const int& num_groups, const int& m, const int& n, const int& k, + const int& gran_k_a, const int& gran_k_b, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(d.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(n % 32 == 0); + + auto config = get_best_config( + GemmType::MGroupedContiguous, KernelType::Kernel1D1D, + m, n, k, 1, major_a, major_b, + a.scalar_type(), b.scalar_type(), + torch::kBFloat16, false, + device_runtime->get_num_sms()); + + if (config.block_n < 32) config.block_n = 32; + + // Override cd_dtype + config.cd_dtype = torch::kFloat8_e4m3fn; + + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, gran_k_a, 1, 0); + const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.block_n, gran_k_b, num_groups, 0); + + const SM100FP8FP4Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = num_groups, + .gran_k_a = gran_k_a, + .gran_k_b = gran_k_b, + .compiled_dims = compiled_dims, + .epilogue_type = std::nullopt, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = grouped_layout.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, 
+ .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + 32, + static_cast(d.stride(-2)), 1, 0), + .tensor_map_cd_sf = {}, + .gmem_cd_sf_ptr = d_sf.data_ptr(), + .cd_sf_stride = static_cast(n / 32), + .gmem_cd_fp8_ptr = d.data_ptr(), + .cd_fp8_stride = static_cast(n) + }; + const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_m_grouped_fp8_gemm_contiguous_1d1d_mxfp8out", code); + SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); +} + static void sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, const torch::Tensor& b, const torch::Tensor& sfb, const torch::Tensor& d, @@ -267,7 +430,12 @@ static void sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(const torch::Tensor& a, con .tensor_map_b = tensor_map_b, .tensor_map_sfa = tensor_map_sfa, .tensor_map_sfb = tensor_map_sfb, - .tensor_map_cd = tensor_map_cd + .tensor_map_cd = tensor_map_cd, + .tensor_map_cd_sf = {}, + .gmem_cd_sf_ptr = nullptr, + .cd_sf_stride = 0, + .gmem_cd_fp8_ptr = nullptr, + .cd_fp8_stride = 0 }; const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_masked_1d1d", code); @@ -338,7 +506,12 @@ static void sm100_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::T .tensor_map_b = tensor_map_b, .tensor_map_sfa = tensor_map_sfa, .tensor_map_sfb = tensor_map_sfb, - .tensor_map_cd = tensor_map_cd + .tensor_map_cd = tensor_map_cd, + .tensor_map_cd_sf = {}, + .gmem_cd_sf_ptr = nullptr, + .cd_sf_stride = 0, + .gmem_cd_fp8_ptr = nullptr, + .cd_fp8_stride = 0 }; const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm100_k_grouped_fp8_gemm_1d1d", code); @@ -406,7 +579,12 @@ static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, .tensor_map_b = 
tensor_map_b, .tensor_map_sfa = tensor_map_sfa, .tensor_map_sfb = tensor_map_sfb, - .tensor_map_cd = tensor_map_cd + .tensor_map_cd = tensor_map_cd, + .tensor_map_cd_sf = {}, + .gmem_cd_sf_ptr = nullptr, + .cd_sf_stride = 0, + .gmem_cd_fp8_ptr = nullptr, + .cd_fp8_stride = 0 }; const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code); diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp index e61841b..4cdbcae 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp @@ -30,6 +30,7 @@ class SM90FP8Gemm1D1DRuntime final: public LaunchRuntime CUtensorMap tensor_map_sfa; CUtensorMap tensor_map_sfb; CUtensorMap tensor_map_cd; + CUtensorMap tensor_map_cd_sf; }; static std::string generate_impl(const Args& args) { @@ -71,7 +72,7 @@ static void __instantiate_kernel() {{ args.m, args.n, args.k, args.tensor_map_a_base, args.tensor_map_b_base, args.tensor_map_sfa, args.tensor_map_sfb, - args.tensor_map_cd)); + args.tensor_map_cd, args.tensor_map_cd_sf)); } }; @@ -132,6 +133,7 @@ static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, .tensor_map_sfa = tensor_map_sfa, .tensor_map_sfb = tensor_map_sfb, .tensor_map_cd = tensor_map_cd, + .tensor_map_cd_sf = {}, }; const auto& code = SM90FP8Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm90_fp8_gemm_1d1d", code); @@ -139,6 +141,88 @@ static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, SM90FP8Gemm1D1DRuntime::launch(runtime, args); } +static void sm90_fp8_gemm_1d1d_mxfp8out(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, const torch::Tensor& d_sf, + const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + 
DG_HOST_ASSERT(d.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(d_sf.scalar_type() == torch::kUInt8); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(n % 32 == 0); + + // Use FP8 cd_dtype for config selection — but heuristics expect BF16 or FP32 + // We use BF16 as proxy since FP8 SMEM is even smaller + const auto& config = get_best_config( + GemmType::Normal, KernelType::Kernel1D1D, + m, n, k, 1, major_a, major_b, + a.scalar_type(), b.scalar_type(), + torch::kBFloat16, false, + device_runtime->get_num_sms()); + + // Requires no TMA splits + DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); + DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k); + + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, k, 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, k, 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, 1, 0); + const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.block_n, config.block_k, 1, 0); + + // TMA descriptor for FP8 output data (no swizzle for now) + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, + SM90ArchSpec::get_cd_store_block_m(config.block_m, true), + config.block_n, + static_cast(d.stride(-2)), 1, + 0); + + // TMA descriptor for E8M0 scale output + const int sf_n = n / 32; + const auto& tensor_map_cd_sf = make_tma_cd_desc(d_sf, m, sf_n, + SM90ArchSpec::get_cd_store_block_m(config.block_m, true), + config.block_n / 32, + sf_n, 1, + 0); + + // Override cd_dtype to FP8 for kernel template instantiation + auto mxfp8_config = config; + mxfp8_config.cd_dtype 
= torch::kFloat8_e4m3fn; + + // Launch + const SM90FP8Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = 1, + .compiled_dims = compiled_dims, + .gemm_config = mxfp8_config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .gmem_a_ptr = nullptr, + .gmem_b_ptr = nullptr, + .grouped_layout = nullptr, + .tensor_map_buffer = nullptr, + .tensor_map_a_base = tensor_map_a, + .tensor_map_b_base = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = tensor_map_cd, + .tensor_map_cd_sf = tensor_map_cd_sf, + }; + const auto& code = SM90FP8Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm90_fp8_gemm_1d1d_mxfp8out", code); + + SM90FP8Gemm1D1DRuntime::launch(runtime, args); +} + static void sm90_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, const torch::Tensor& b, const torch::Tensor& sfb, const std::optional& c, @@ -208,6 +292,7 @@ static void sm90_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Te .tensor_map_sfa = tensor_map_sfa, .tensor_map_sfb = tensor_map_sfb, .tensor_map_cd = tensor_map_cd, + .tensor_map_cd_sf = {}, }; const auto& code = SM90FP8Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm90_fp8_gemm_1d1d", code); diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py index 1c07f5d..a922be6 100644 --- a/deep_gemm/__init__.py +++ b/deep_gemm/__init__.py @@ -45,6 +45,9 @@ m_grouped_fp8_gemm_nt_masked, k_grouped_fp8_gemm_nt_contiguous, k_grouped_fp8_gemm_tn_contiguous, + # MXFP8 output GEMMs + fp8_gemm_nt_mxfp8out, + m_grouped_fp8_gemm_nt_contiguous_mxfp8out, # BF16 GEMMs bf16_gemm_nt, bf16_gemm_nn, bf16_gemm_tn, bf16_gemm_tt, @@ -66,6 +69,8 @@ get_mk_alignment_for_contiguous_layout ) + # Some alias for legacy supports + # TODO: remove these later # Some alias for legacy supports # TODO: remove these 
// MXFP8 epilogue utilities

// Compute the biased E8M0 shared-scale exponent for a 32-element group from
// its maximum absolute value:
//   E = clamp(ceil(log2(max_abs / 448)) + 127, 0, 254)
// Rounding *up* (`__float2int_ru`) guarantees max_abs / 2^(E - 127) <= 448,
// the largest finite FP8 E4M3 magnitude, so scaled values never exceed the
// format's range.  (The previous comment said `floor`, which contradicted
// the round-up intrinsic actually used and would not give that guarantee.)
// A non-positive max yields exponent 0.
__device__ __forceinline__ uint8_t compute_e8m0_exponent(float max_abs) {
    if (max_abs <= 0.0f)
        return 0;
    const int e = __float2int_ru(log2f(max_abs * (1.0f / 448.0f))) + 127;
    return static_cast<uint8_t>(min(max(e, 0), 254));
}

// Butterfly max-reduction over XOR lane offsets 1 and 2: after the two
// shuffles, every aligned group of 4 lanes holds the same maximum.  Used
// when 4 threads each hold part of one 32-element N group.
__device__ __forceinline__ float warp_reduce_max_4(float val) {
    val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, 1));
    val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, 2));
    return val;
}

// Pack 4 FP8 bytes into a `uint32_t`, `a` in the lowest byte (little-endian
// lane order).
__device__ __forceinline__ uint32_t pack_fp8x4(uint8_t a, uint8_t b, uint8_t c, uint8_t d) {
    return static_cast<uint32_t>(a) | (static_cast<uint32_t>(b) << 8) |
           (static_cast<uint32_t>(c) << 16) | (static_cast<uint32_t>(d) << 24);
}

// Convert `float` to FP8 E4M3 via `__nv_fp8_e4m3` (CUDA's saturating-finite
// conversion) and return the raw byte.
__device__ __forceinline__ uint8_t float_to_fp8_e4m3_sat(float x) {
    __nv_fp8_e4m3 fp8_val = __nv_fp8_e4m3(x);
    return *reinterpret_cast<uint8_t*>(&fp8_val);
}
b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh @@ -33,7 +33,10 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, const __grid_constant__ cute::TmaDescriptor tensor_map_b, const __grid_constant__ cute::TmaDescriptor tensor_map_sfa, const __grid_constant__ cute::TmaDescriptor tensor_map_sfb, - const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { + const __grid_constant__ cute::TmaDescriptor tensor_map_cd, + const __grid_constant__ cute::TmaDescriptor tensor_map_cd_sf, + uint8_t* gmem_cd_sf_ptr, uint32_t cd_sf_stride, + uint8_t* gmem_cd_fp8_ptr, uint32_t cd_fp8_stride) { #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) using Barrier = cutlass::arch::ClusterTransactionBarrier; using Allocator = cute::conditional_t; @@ -42,6 +45,9 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, if constexpr (kWithAccumulation) DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); + // MXFP8 output mode: use FP32 SMEM path internally, then convert to FP8 + E8M0 scales + constexpr bool kIsMXFP8Output = cute::is_same_v; + // Configs constexpr uint32_t LAYOUT_AD_M = 128; constexpr uint32_t WAVE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); @@ -75,7 +81,9 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 
1 : kNumMulticast); constexpr uint32_t STORE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); - constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t); + // For MXFP8 output, use FP32 TMEM load granularity internally + using internal_cd_dtype_t = cute::conditional_t; + constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(internal_cd_dtype_t); constexpr uint32_t kNumUMMAStoreThreads = STORE_BLOCK_M; DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast"); DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M, "Only support tensor memory layout A/D"); @@ -83,6 +91,9 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M"); // Share memory sizes + // For MXFP8: SMEM_CD holds FP8 output + E8M0 scales (separate from TMEM load staging) + constexpr uint32_t SMEM_CD_FP8_SIZE = kIsMXFP8Output ? constexpr_align(STORE_BLOCK_M * BLOCK_N, 1024u) : 0; + constexpr uint32_t SMEM_CD_SF_SIZE = kIsMXFP8Output ? constexpr_align(STORE_BLOCK_M * (BLOCK_N / 32), 128u) : 0; constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode; constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); @@ -118,6 +129,8 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, cute::prefetch_tma_descriptor(&tensor_map_sfa); cute::prefetch_tma_descriptor(&tensor_map_sfb); cute::prefetch_tma_descriptor(&tensor_map_cd); + if constexpr (kIsMXFP8Output) + cute::prefetch_tma_descriptor(&tensor_map_cd_sf); } // D/A/B shared memory @@ -434,7 +447,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // TMA checks constexpr uint32_t kNumBankGroupBytes = 16; - constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t); + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(internal_cd_dtype_t); DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); 
DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); @@ -457,91 +470,152 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough"); DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); - // Iterate over M waves - #pragma unroll - for (uint32_t w = 0; w < kNumMWaves; ++ w) { - // Issue every swizzled atom and pipeline STSM and TMA store - constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + if constexpr (not kIsMXFP8Output) { + // Original epilogue: TMEM → SMEM → TMA store #pragma unroll - for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) { - // Wait shared memory to be released - if (epilogue_warp_idx == 0) - cute::tma_store_wait(); - cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) { + if (epilogue_warp_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + + const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M; + const auto n_idx = epilogue_type_t::apply_index_n(n_block_idx * BLOCK_N + s * STORE_BLOCK_N); + + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? 
(i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + + w * BLOCK_N + + s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; + auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + + epilogue_warp_idx * 32 * kSwizzleCDMode + + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; + + uint32_t values[kNumElemsPerBankGroup]; + if constexpr (cute::is_same_v) { + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } else { + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, + values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, + cast_into_bf16_and_pack(values[0], values[1]), + cast_into_bf16_and_pack(values[2], values[3]), + cast_into_bf16_and_pack(values[4], values[5]), + cast_into_bf16_and_pack(values[6], values[7])); + } + } - // The pipeline stage - const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M; - const auto n_idx = epilogue_type_t::apply_index_n(n_block_idx * BLOCK_N + s * STORE_BLOCK_N); + if (w == kNumMWaves - 1 and s == kNumStores - 1) { + tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } - // Store into shared memory + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kGemmType == GemmType::Batched) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], + n_idx, 
m_idx, scheduler.current_group_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx); + } + cute::tma_store_arrive(); + } + } + } + } else { + // MXFP8 epilogue: TMEM → registers → quantize → SMEM → TMA store + // Independent s-loop with own TMA pipeline, same structure as BF16 path + constexpr uint32_t kNumGroups32 = BLOCK_N / 32; + const uint32_t local_m = epilogue_warp_idx * 32 + lane_idx; + const uint32_t tmem_base = accum_stage_idx * kNumMWaves * BLOCK_N; + + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { #pragma unroll - for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { - // Calculate the index of the bank group to be written in the atom - auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); - - // Reshape the atom in another view and swizzle - // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` - // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` - // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern - constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; - auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); - auto col = kHasShortcut ? 
(i) : (bank_group_index % 8); - col ^= row % (kSwizzleCDMode / 16); - - // Source and destination memory address - uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset - w * BLOCK_N + // Wave offset - s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset - auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer - epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset - row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset - - // Load from tensor memory, store into shared memory - uint32_t values[kNumElemsPerBankGroup]; - if constexpr (cute::is_same_v) { - // For FP32 output, read and store - DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); - cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, - values[0], values[1], values[2], values[3]); - cutlass::arch::fence_view_async_tmem_load(); - st_shared(smem_ptr, values[0], values[1], values[2], values[3]); - } else { - // For BF16 output, read, cast and store - DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); - cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, - values[0], values[1], values[2], values[3], - values[4], values[5], values[6], values[7]); + for (uint32_t g = 0; g < kNumGroups32; ++ g, advance_store_pipeline()) { + // Pipeline: wait for SMEM stage to be free + if (epilogue_warp_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + + const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M; + + // Load 32 consecutive N-elements from TMEM (8 loads × 4 values) + float group_vals[32]; + float group_max = 0.0f; + + #pragma unroll + for (uint32_t i = 0; i < 8; ++ i) { + uint32_t tmem_addr = tmem_base + w * BLOCK_N + g * 32 + i * 4; + uint32_t raw[4]; + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, raw[0], raw[1], raw[2], raw[3]); 
cutlass::arch::fence_view_async_tmem_load(); - st_shared(smem_ptr, - cast_into_bf16_and_pack(values[0], values[1]), - cast_into_bf16_and_pack(values[2], values[3]), - cast_into_bf16_and_pack(values[4], values[5]), - cast_into_bf16_and_pack(values[6], values[7])); + #pragma unroll + for (uint32_t j = 0; j < 4; ++ j) { + float val = *reinterpret_cast(&raw[j]); + group_vals[i * 4 + j] = val; + group_max = fmaxf(group_max, fabsf(val)); + } } - } - // Notify tensor memory empty (only at the leader CTA) arrival ASAP - // NOTES: only the last stage needs to do this - if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { - tcgen05_before_thread_sync(); - tmem_empty_barriers[accum_stage_idx]->arrive(0u); - } + // Quantize + uint8_t e8m0 = compute_e8m0_exponent(group_max); + float inv_scale = (e8m0 == 0 and group_max == 0.0f) ? 0.0f : exp2f(127.0f - static_cast(e8m0)); + + // Write FP8 to SMEM (32 bytes per row, contiguous, no swizzle) + auto smem_fp8_row = reinterpret_cast(smem_cd[tma_stage_idx]) + + (epilogue_warp_idx * 32 + lane_idx) * 32; + #pragma unroll + for (uint32_t j = 0; j < 32; j += 4) { + uint32_t packed = pack_fp8x4( + float_to_fp8_e4m3_sat(group_vals[j + 0] * inv_scale), + float_to_fp8_e4m3_sat(group_vals[j + 1] * inv_scale), + float_to_fp8_e4m3_sat(group_vals[j + 2] * inv_scale), + float_to_fp8_e4m3_sat(group_vals[j + 3] * inv_scale)); + st_shared(reinterpret_cast(smem_fp8_row + j), packed); + } + + // Write E8M0 scale directly to global memory (too small for TMA) + if (local_m < shape_m) { + const auto gmem_m_idx = m_idx + local_m; + gmem_cd_sf_ptr[gmem_m_idx * cd_sf_stride + n_block_idx * (BLOCK_N / 32) + g] = e8m0; + } + + // Notify tensor memory empty after last group of last wave + if (w == kNumMWaves - 1 and g == kNumGroups32 - 1) { + tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } - // Synchronize all threads and issue TMA - cute::tma_store_fence(); - 
cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); - if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { - if constexpr (kGemmType == GemmType::Batched) { - using cute_tma_t = cute::conditional_t; - cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], - n_idx, m_idx, scheduler.current_group_idx); - } else { - using cute_tma_t = cute::conditional_t; - cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx); + // TMA store: SMEM → HBM + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_cd, + smem_cd[tma_stage_idx], + n_block_idx * BLOCK_N + g * 32, m_idx); + cute::tma_store_arrive(); } - cute::tma_store_arrive(); } } } diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh index 2c24c5e..35e0f65 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh @@ -36,12 +36,14 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, const __grid_constant__ cute::TmaDescriptor tensor_map_b_base, const __grid_constant__ cute::TmaDescriptor tensor_map_sfa, const __grid_constant__ cute::TmaDescriptor tensor_map_sfb, - const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { + const __grid_constant__ cute::TmaDescriptor tensor_map_cd, + const __grid_constant__ cute::TmaDescriptor tensor_map_cd_sf) { #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) // Scaling checks DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid Threads"); DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); - DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); + DG_STATIC_ASSERT(cute::is_same_v or cute::is_same_v, "Invalid C/D data dtype"); + constexpr bool kIsMXFP8Output 
= cute::is_same_v; DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous, "Invalid GEMM type"); // Types @@ -56,7 +58,10 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, // Shared memory static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? sizeof(cute::TmaDescriptor) * 4 : 0); - static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float); + // For MXFP8 output: FP8 data + E8M0 scales; for FP32 output: FP32 accumulators + static constexpr uint32_t SMEM_D_FP8_SIZE = constexpr_align(BLOCK_M * BLOCK_N * sizeof(__nv_fp8_e4m3), 1024u); + static constexpr uint32_t SMEM_D_SF_SIZE = constexpr_align(BLOCK_M * (BLOCK_N / 32) * sizeof(uint8_t), 128u); + static constexpr uint32_t SMEM_D_SIZE = kIsMXFP8Output ? (SMEM_D_FP8_SIZE + SMEM_D_SF_SIZE) : (BLOCK_M * BLOCK_N * sizeof(float)); static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); @@ -75,6 +80,8 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, cute::prefetch_tma_descriptor(&tensor_map_sfa); cute::prefetch_tma_descriptor(&tensor_map_sfb); cute::prefetch_tma_descriptor(&tensor_map_cd); + if constexpr (kIsMXFP8Output) + cute::prefetch_tma_descriptor(&tensor_map_cd_sf); } __syncwarp(); @@ -347,25 +354,100 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, cute::tma_store_wait<0>(); cutlass::arch::NamedBarrier::sync(128, math_wg_idx); - // Store to D shared memory - const auto& smem_d_0 = reinterpret_cast(smem_d + r_0 * BLOCK_N + col_idx * 2); - const auto& smem_d_1 = reinterpret_cast(smem_d + r_1 * BLOCK_N + col_idx * 2); - #pragma unroll - for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { - st_shared(smem_d_0 + i * 4, {final_accum[i * 4 
+ 0], final_accum[i * 4 + 1]}); - st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]}); - } - cute::tma_store_fence(); - cutlass::arch::NamedBarrier::sync(128, math_wg_idx); + if constexpr (not kIsMXFP8Output) { + // FP32 output path: store to D shared memory + const auto& smem_d_0 = reinterpret_cast(smem_d + r_0 * BLOCK_N + col_idx * 2); + const auto& smem_d_1 = reinterpret_cast(smem_d + r_1 * BLOCK_N + col_idx * 2); + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]}); + st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]}); + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(128, math_wg_idx); + + // Use TMA store to write back to global memory + if (warp_idx % 4 == 0 and cute::elect_one_sync()) { + cute::SM90_TMA_REDUCE_ADD_2D::copy( + &tensor_map_cd, smem_d_0, n_block_idx * BLOCK_N, + current_group_idx * shape_m + m_block_idx * BLOCK_M + r_0); + cute::tma_store_arrive(); + } + __syncwarp(); + } else { + // MXFP8 output path: quantize FP32 accumulators to FP8 + E8M0 scales + auto smem_d_fp8 = reinterpret_cast(smem_buffer + SMEM_TENSOR_MAP_SIZE); + auto smem_d_sf = reinterpret_cast(smem_buffer + SMEM_TENSOR_MAP_SIZE + SMEM_D_FP8_SIZE); + + // Process each group of 32 N-elements + // In WGMMA layout: col_idx = lane_idx % 4, each thread holds pairs at N positions col_idx*2 + i*8, col_idx*2 + i*8 + 1 + // A group of 32 N-elements spans 4 consecutive i values (i_base, i_base+1, i_base+2, i_base+3) + constexpr uint32_t kNumGroups32 = BLOCK_N / 32; - // Use TMA store to write back to global memory - if (warp_idx % 4 == 0 and cute::elect_one_sync()) { - cute::SM90_TMA_REDUCE_ADD_2D::copy( - &tensor_map_cd, smem_d_0, n_block_idx * BLOCK_N, - current_group_idx * shape_m + m_block_idx * BLOCK_M + r_0); - cute::tma_store_arrive(); + #pragma unroll + for (uint32_t g = 0; g < kNumGroups32; ++ g) { + // Each group of 
32 N-elements maps to i_base = g * 4 in the accumulator indexing + const uint32_t i_base = g * 4; + + // Compute local max of this thread's 8 elements in the group (4 pairs for r_0 and r_1) + float local_max_r0 = 0.0f, local_max_r1 = 0.0f; + #pragma unroll + for (uint32_t di = 0; di < 4; ++ di) { + const uint32_t acc_idx = (i_base + di) * 4; + local_max_r0 = fmaxf(local_max_r0, fmaxf(fabsf(final_accum[acc_idx + 0]), fabsf(final_accum[acc_idx + 1]))); + local_max_r1 = fmaxf(local_max_r1, fmaxf(fabsf(final_accum[acc_idx + 2]), fabsf(final_accum[acc_idx + 3]))); + } + + // Warp-level reduction across col_idx (4 threads) for each row + float group_max_r0 = warp_reduce_max_4(local_max_r0); + float group_max_r1 = warp_reduce_max_4(local_max_r1); + + // Compute E8M0 exponents + uint8_t e8m0_r0 = compute_e8m0_exponent(group_max_r0); + uint8_t e8m0_r1 = compute_e8m0_exponent(group_max_r1); + float inv_scale_r0 = (e8m0_r0 == 0 and group_max_r0 == 0.0f) ? 0.0f : exp2f(127.0f - static_cast(e8m0_r0)); + float inv_scale_r1 = (e8m0_r1 == 0 and group_max_r1 == 0.0f) ? 
0.0f : exp2f(127.0f - static_cast(e8m0_r1)); + + // Convert to FP8 and store to SMEM + #pragma unroll + for (uint32_t di = 0; di < 4; ++ di) { + const uint32_t acc_idx = (i_base + di) * 4; + uint8_t fp8_r0_0 = float_to_fp8_e4m3_sat(final_accum[acc_idx + 0] * inv_scale_r0); + uint8_t fp8_r0_1 = float_to_fp8_e4m3_sat(final_accum[acc_idx + 1] * inv_scale_r0); + uint8_t fp8_r1_0 = float_to_fp8_e4m3_sat(final_accum[acc_idx + 2] * inv_scale_r1); + uint8_t fp8_r1_1 = float_to_fp8_e4m3_sat(final_accum[acc_idx + 3] * inv_scale_r1); + + // N position = col_idx * 2 + di * 8 + const uint32_t n_pos = col_idx * 2 + di * 8; + smem_d_fp8[r_0 * BLOCK_N + g * 32 + n_pos + 0] = fp8_r0_0; + smem_d_fp8[r_0 * BLOCK_N + g * 32 + n_pos + 1] = fp8_r0_1; + smem_d_fp8[r_1 * BLOCK_N + g * 32 + n_pos + 0] = fp8_r1_0; + smem_d_fp8[r_1 * BLOCK_N + g * 32 + n_pos + 1] = fp8_r1_1; + } + + // Store E8M0 scale (one per 32-element group per row, only col_idx == 0 writes) + if (col_idx == 0) { + smem_d_sf[r_0 * (BLOCK_N / 32) + g] = e8m0_r0; + smem_d_sf[r_1 * (BLOCK_N / 32) + g] = e8m0_r1; + } + } + + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(128, math_wg_idx); + + // TMA store FP8 data and scales to global memory + if (warp_idx % 4 == 0 and cute::elect_one_sync()) { + const uint32_t m_idx = current_group_idx * shape_m + m_block_idx * BLOCK_M + r_0; + cute::SM90_TMA_STORE_2D::copy( + &tensor_map_cd, smem_d_fp8 + r_0 * BLOCK_N, + n_block_idx * BLOCK_N, m_idx); + cute::SM90_TMA_STORE_2D::copy( + &tensor_map_cd_sf, smem_d_sf + r_0 * (BLOCK_N / 32), + n_block_idx * (BLOCK_N / 32), m_idx); + cute::tma_store_arrive(); + } + __syncwarp(); } - __syncwarp(); } } #else diff --git a/deep_gemm/utils/math.py b/deep_gemm/utils/math.py index c65026e..3fdc7b0 100644 --- a/deep_gemm/utils/math.py +++ b/deep_gemm/utils/math.py @@ -60,6 +60,19 @@ def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple, use_ue8m0: bool) - return x_scaled, sf.squeeze() +def mxfp8_quantize_output(x: torch.Tensor, 
block_size: int = 32) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + assert x.size(1) % block_size == 0, f"N ({x.size(1)}) must be divisible by block_size ({block_size})" + m, n = x.shape + x_float = x.float() + x_view = x_float.view(m, n // block_size, block_size) + x_amax = x_view.abs().amax(dim=2).clamp(1e-4) + e8m0_exp = torch.clamp(torch.ceil(torch.log2(x_amax / 448.0)) + 127, 0, 254).to(torch.uint8) + scale = torch.pow(2.0, e8m0_exp.float() - 127.0).unsqueeze(2) + fp8_data = (x_view / scale).to(torch.float8_e4m3fn).view(m, n) + return fp8_data, e8m0_exp + + def _quantize_to_fp4_e2m1(x: torch.Tensor) -> torch.Tensor: ax = x.abs().clamp_max(6.0) # {0, 0.5, 1, 1.5, 2, 3, 4, 6} diff --git a/tests/test-mxfp8.yaml b/tests/test-mxfp8.yaml new file mode 100644 index 0000000..f37b3a5 --- /dev/null +++ b/tests/test-mxfp8.yaml @@ -0,0 +1,103 @@ +apiVersion: trainer.kubeflow.org/v1alpha1 +kind: TrainJob +metadata: + name: jangwoong-dg-mxfp8v7 + namespace: kbm-g-np-motif +spec: + podTemplateOverrides: + - metadata: + annotations: + k8s.v1.cni.cncf.io/networks: |- + kbm-g-np-motif, + kbm-g-np-motif, + kbm-g-np-motif, + kbm-g-np-motif, + kbm-g-np-motif, + kbm-g-np-motif, + kbm-g-np-motif, + kbm-g-np-motif + spec: + containers: + - name: node + volumeMounts: + - mountPath: /dev/shm + name: shm + - mountPath: /mair + name: mair + volumes: + - emptyDir: + medium: Memory + sizeLimit: 64Gi + name: shm + - name: mair + persistentVolumeClaim: + claimName: mair + targetJobs: + - name: node + runtimeRef: + apiGroup: trainer.kubeflow.org + kind: ClusterTrainingRuntime + name: torch-distributed + suspend: false + trainer: + args: + - /bin/bash + - '-c' + - | + echo "=== MXFP8 Epilogue Test on $(hostname) ===" + nvidia-smi | head -5 + + # Install DeepGEMM (force rebuild) + cd /mair/team-sys/jangwoong/DeepGEMM + rm -f deep_gemm/_C.cpython-*.so + rm -rf build/ + DG_USE_LOCAL_VERSION=0 pip install -e . 
--no-build-isolation -q 2>&1 | tail -3 + DG_USE_LOCAL_VERSION=0 python setup.py build_ext --inplace --force 2>&1 | tail -5 + python -c "import deep_gemm._C; print('Loaded:', deep_gemm._C.__file__)" + + export DEEPGEMM_ROOT=/mair/team-sys/jangwoong/DeepGEMM + export PYTHONPATH=$PYTHONPATH:$DEEPGEMM_ROOT + export CPATH=$DEEPGEMM_ROOT/third-party/cutlass/include:${CPATH:-} + export CPLUS_INCLUDE_PATH=$DEEPGEMM_ROOT/third-party/cutlass/include:${CPLUS_INCLUDE_PATH:-} + export DG_JIT_CACHE_DIR=/mair/torch_cache/deepgemm/jangwoong-mxfp8-test + rm -rf $DG_JIT_CACHE_DIR + mkdir -p $DG_JIT_CACHE_DIR + + # Run breakdown + profiling + cd /mair/team-sys/jangwoong/DeepGEMM + find . -name '__pycache__' -exec rm -rf {} + 2>/dev/null || true + mkdir -p traces + cd tests + python -B test_mxfp8_profile.py 2>&1 + + echo "=== Test complete ===" + env: + - name: PYTHONUNBUFFERED + value: '1' + - name: CUDA_VISIBLE_DEVICES + value: '0' + - name: CUDA_LAUNCH_BLOCKING + value: '1' + - name: NODE_NAME + valueFrom: + fieldRef: + fieldPath: spec.nodeName + - name: JOB_NAME + valueFrom: + fieldRef: + fieldPath: metadata.labels['job-name'] + image: >- + ghcr.io/motiftechnologies/llm-training:v0.1.6 + numNodes: 1 + numProcPerNode: 8 + resourcesPerNode: + limits: + cpu: '96' + memory: 1024Gi + nvidia.com/gpu: '8' + nvidia.com/mlnxnics: '8' + requests: + cpu: '96' + memory: 1024Gi + nvidia.com/gpu: '8' + nvidia.com/mlnxnics: '8' diff --git a/tests/test_mxfp8_breakdown.py b/tests/test_mxfp8_breakdown.py new file mode 100644 index 0000000..6122a76 --- /dev/null +++ b/tests/test_mxfp8_breakdown.py @@ -0,0 +1,206 @@ +import sys +import torch +from typing import Tuple + +sys.path.insert(0, '/mair/team-sys/jangwoong/DeepGEMM/tests') + +import deep_gemm +from deep_gemm.testing import calc_diff, get_arch_major +from deep_gemm.utils import mxfp8_quantize_output + +from generators import ( + KernelType, QuantConfig, MajorTypeAB, get_ue8m0_usage, + generate_normal, generate_m_grouped_contiguous, +) + + +def 
bench(fn, num_iters=100): + for _ in range(5): + fn() + torch.cuda.synchronize() + s = torch.cuda.Event(enable_timing=True) + e = torch.cuda.Event(enable_timing=True) + s.record() + for _ in range(num_iters): + fn() + e.record() + torch.cuda.synchronize() + return s.elapsed_time(e) / num_iters + + +def dequantize_mxfp8(fp8_data, e8m0_scales_uint8): + m, n = fp8_data.shape + scale = torch.pow(2.0, e8m0_scales_uint8.float() - 127.0) + return (fp8_data.float().view(m, n // 32, 32) * scale.unsqueeze(2)).view(m, n) + + +# Register custom ops for torch.compile tracing +@torch.library.custom_op("deepgemm::fp8_gemm_nt_bf16out", mutates_args=("d",)) +def _fp8_gemm_nt_bf16out( + a_data: torch.Tensor, a_sf: torch.Tensor, + b_data: torch.Tensor, b_sf: torch.Tensor, + d: torch.Tensor, + disable_ue8m0_cast: bool, +) -> None: + deep_gemm.fp8_gemm_nt((a_data, a_sf), (b_data, b_sf), d, disable_ue8m0_cast=disable_ue8m0_cast) + +@_fp8_gemm_nt_bf16out.register_fake +def _fp8_gemm_nt_bf16out_fake(a_data, a_sf, b_data, b_sf, d, disable_ue8m0_cast): + pass + + +@torch.library.custom_op("deepgemm::m_grouped_fp8_gemm_nt_bf16out", mutates_args=("d",)) +def _m_grouped_fp8_gemm_nt_bf16out( + a_data: torch.Tensor, a_sf: torch.Tensor, + b_data: torch.Tensor, b_sf: torch.Tensor, + d: torch.Tensor, + grouped_layout: torch.Tensor, + disable_ue8m0_cast: bool, +) -> None: + deep_gemm.m_grouped_fp8_gemm_nt_contiguous((a_data, a_sf), (b_data, b_sf), d, grouped_layout, disable_ue8m0_cast=disable_ue8m0_cast) + +@_m_grouped_fp8_gemm_nt_bf16out.register_fake +def _m_grouped_fp8_gemm_nt_bf16out_fake(a_data, a_sf, b_data, b_sf, d, grouped_layout, disable_ue8m0_cast): + pass + + +def run_normal_gemm_test(): + print('='*100) + print('[Normal GEMM] Accuracy + Performance (baseline: compiled GEMM+quantize graph)') + print('='*100) + + arch = get_arch_major() + use_ue8m0 = get_ue8m0_usage(KernelType.Kernel1D1D) + quant_config = QuantConfig() + disable_cast = not use_ue8m0 + + shapes = [ + (1, 7168, 2048), + 
(128, 7168, 2048), + (256, 7168, 2048), + (4096, 7168, 2048), + (128, 4096, 7168), + (256, 4096, 7168), + ] + + print(f'\nArch: SM{arch}0') + print(f'{"M":>6} {"N":>6} {"K":>6} | {"Fused":>7} | {"BL compiled":>11} | {"Spdup":>5} | {"Diff":>8}') + print('-'*70) + + for m, n, k in shapes: + a, b, c, d_dummy, ref_d = generate_normal( + m, n, k, MajorTypeAB.KMajor, MajorTypeAB.KMajor, + False, torch.bfloat16, KernelType.Kernel1D1D, + use_ue8m0=use_ue8m0, quant_config=quant_config) + + # Fused + d_fp8 = torch.empty((m, n), device='cuda', dtype=torch.float8_e4m3fn) + d_sf = torch.empty((m, n // 32), device='cuda', dtype=torch.float8_e4m3fn) + t_fused = bench(lambda: deep_gemm.fp8_gemm_nt_mxfp8out( + a, b, d_fp8, d_sf, disable_ue8m0_cast=disable_cast)) + + # Baseline: compiled GEMM + quantize as one graph + d_bf16 = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) + + def baseline_fn(): + _fp8_gemm_nt_bf16out(a[0], a[1], b[0], b[1], d_bf16, disable_cast) + return mxfp8_quantize_output(d_bf16, block_size=32) + + try: + compiled_bl = torch.compile(baseline_fn, fullgraph=True) + fg = 'FG=T' + except Exception: + compiled_bl = torch.compile(baseline_fn, fullgraph=False) + fg = 'FG=F' + + # Warmup compiled + for _ in range(3): + compiled_bl() + torch.cuda.synchronize() + + t_bl = bench(compiled_bl) + + # Accuracy + deep_gemm.fp8_gemm_nt_mxfp8out(a, b, d_fp8, d_sf, disable_ue8m0_cast=disable_cast) + bl_fp8, bl_sf = compiled_bl() + deq_fused = dequantize_mxfp8(d_fp8, d_sf.view(torch.uint8)) + deq_bl = dequantize_mxfp8(bl_fp8, bl_sf) + diff = calc_diff(deq_fused, deq_bl) + + speedup = t_bl / t_fused if t_fused > 0 else 0 + print(f'{m:6} {n:6} {k:6} | {t_fused:5.3f}ms | {t_bl:9.3f}ms | {speedup:4.2f}x | {diff:8.5f} {fg}') + + print() + + +def run_grouped_gemm_test(): + print('='*100) + print('[Grouped GEMM] Accuracy + Performance (baseline: compiled GEMM+quantize graph)') + print('='*100) + + arch = get_arch_major() + use_ue8m0 = get_ue8m0_usage(KernelType.Kernel1D1D) + 
quant_config = QuantConfig() + disable_cast = not use_ue8m0 + + grouped_shapes = [ + (4, 8192, 7168, 2048), + (8, 4096, 7168, 2048), + (4, 8192, 4096, 7168), + (8, 4096, 4096, 7168), + (48, 1280, 1280, 4096), + ] + + print(f'\nArch: SM{arch}0') + print(f'{"G":>3} {"M":>6} {"N":>6} {"K":>6} | {"Fused":>7} | {"BL compiled":>11} | {"Spdup":>5} | {"Diff":>8}') + print('-'*75) + + for num_groups, expected_m, n, k in grouped_shapes: + m, a, b, grouped_layout, d_dummy, ref_d = generate_m_grouped_contiguous( + num_groups, expected_m, n, k, MajorTypeAB.KMajor, MajorTypeAB.KMajor, + use_ue8m0=use_ue8m0, quant_config=quant_config) + + # Fused + d_fp8 = torch.empty((m, n), device='cuda', dtype=torch.float8_e4m3fn) + d_sf = torch.empty((m, n // 32), device='cuda', dtype=torch.float8_e4m3fn) + t_fused = bench(lambda: deep_gemm.m_grouped_fp8_gemm_nt_contiguous_mxfp8out( + a, b, d_fp8, d_sf, grouped_layout, disable_ue8m0_cast=disable_cast)) + + # Baseline: compiled GEMM + quantize as one graph + d_bf16 = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) + + def baseline_fn(): + _m_grouped_fp8_gemm_nt_bf16out(a[0], a[1], b[0], b[1], d_bf16, grouped_layout, disable_cast) + return mxfp8_quantize_output(d_bf16, block_size=32) + + try: + compiled_bl = torch.compile(baseline_fn, fullgraph=True) + fg = 'FG=T' + except Exception: + compiled_bl = torch.compile(baseline_fn, fullgraph=False) + fg = 'FG=F' + + # Warmup + for _ in range(3): + compiled_bl() + torch.cuda.synchronize() + + t_bl = bench(compiled_bl) + + # Accuracy + deep_gemm.m_grouped_fp8_gemm_nt_contiguous_mxfp8out(a, b, d_fp8, d_sf, grouped_layout, disable_ue8m0_cast=disable_cast) + bl_fp8, bl_sf = compiled_bl() + deq_fused = dequantize_mxfp8(d_fp8, d_sf.view(torch.uint8)) + deq_bl = dequantize_mxfp8(bl_fp8, bl_sf) + diff = calc_diff(deq_fused, deq_bl) + + speedup = t_bl / t_fused if t_fused > 0 else 0 + print(f'{num_groups:3} {m:6} {n:6} {k:6} | {t_fused:5.3f}ms | {t_bl:9.3f}ms | {speedup:4.2f}x | {diff:8.5f} {fg}') + + 
print() + + +if __name__ == '__main__': + run_normal_gemm_test() + run_grouped_gemm_test() + print('All tests done!')