diff --git a/mlir/include/mlir/Dialect/Rock/IR/MfmaInsnGroup.h b/mlir/include/mlir/Dialect/Rock/IR/MfmaInsnGroup.h index b37907d05c8d..5377cb6c2280 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/MfmaInsnGroup.h +++ b/mlir/include/mlir/Dialect/Rock/IR/MfmaInsnGroup.h @@ -31,7 +31,13 @@ enum class MfmaTypeId : uint32_t { Fp8Fp8TyId, Fp8Bf8TyId, Bf8Fp8TyId, - Bf8Bf8TyId + Bf8Bf8TyId, + // FP8 via scaled MFMA (uses mfma_scale_f32_16x16x128_f8f6f4 with cbsz=0) + // These provide larger K dimension (128 for 16x16, 64 for 32x32) + Fp8Fp8ScaledTyId, + Fp8Bf8ScaledTyId, + Bf8Fp8ScaledTyId, + Bf8Bf8ScaledTyId }; struct MfmaInsnInfo { @@ -155,6 +161,10 @@ class MfmaInsnGroup { bool isCoherentWithK(int64_t kPack, int64_t kPerBlock, int64_t scheduleVersion); SmallString<16> getROCDLIntrinsicName() { return groupAttr.insn; } + + // Check if this is FP8 using scaled MFMA (mfma_scale with cbsz=0, blgp=0) + // These instructions have larger K dimension (128 for 16x16, 64 for 32x32) + bool isScaledFp8() const; }; } // namespace rock diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td b/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td index 796e6d6dae9c..5f24ff6605c5 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td @@ -578,7 +578,7 @@ def Rock_LDSTransposeConfigAttr : Rock_Attr<"LDSTransposeConfig", []> { and tiling parameters. 
- DDim: Matrix-multiply accelerator instruction D dimension (M or N, typically 16 or 32) - - KDim: Matrix-multiply accelerator instruction K dimension (typically 8, 16, or 32) + - KDim: Matrix-multiply accelerator instruction K dimension (8, 16, 32, 64, or 128) - mPerBlock: M dimension size per block - nPerBlock: N dimension size per block - kPerBlock: K dimension size per block diff --git a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td index e4b0ee7496d6..dbc0f666c528 100644 --- a/mlir/include/mlir/Dialect/Rock/IR/RockOps.td +++ b/mlir/include/mlir/Dialect/Rock/IR/RockOps.td @@ -1229,9 +1229,10 @@ defvar SameShapeVectorOfI1 = [{ def Rock_LDSTransposeLoadOp : Rock_Op<"lds_transpose_load", [DeclareOpInterfaceMethods< MemoryEffectsOpInterface>]>, - Arguments<(ins Arg, "LDS source buffer">:$source, + Arguments<(ins Arg, + "LDS source buffer">:$source, Variadic:$indices)>, - Results<(outs VectorOfLengthAndType<[4], [F16, BF16]>:$result)> { + Results<(outs AnyVectorOfNonZeroRank:$result)> { let summary = "Hardware-assisted LDS transpose load for matrix accelerator tile"; let description = [{ @@ -1240,9 +1241,15 @@ def Rock_LDSTransposeLoadOp The tile dimensions match the selected matrix-multiply accelerator instruction geometry (`dDim × instrK`), where: - `dDim` is the accelerator M/N dimension (e.g., 16 or 32) - - `instrK` is the accelerator K dimension (e.g., 8, 16, or 32) - The operation returns a vector of 4 elements per thread containing - transposed elements in a layout suitable for matrix accelerator instructions. 
+ - `instrK` is the accelerator K dimension (8, 16, 32, 64, or 128) + + For 16-bit types (f16, bf16): + - Uses ds_read_tr16_b64 instruction + - Returns vector<4xtype> (4 elements per thread) + + For 8-bit types (f8E4M3FN, f8E5M2 - OCP FP8 for gfx950): + - Uses ds_read_tr8_b64 instruction + - Returns vector<8xtype> (8 elements per thread) Benefits: - Reduces LDS bank conflicts through optimized access patterns diff --git a/mlir/include/mlir/Dialect/Rock/utility/LdsTransposeLoad.h b/mlir/include/mlir/Dialect/Rock/utility/LdsTransposeLoad.h index 07d87d6aca33..c1f9c6cd15bc 100644 --- a/mlir/include/mlir/Dialect/Rock/utility/LdsTransposeLoad.h +++ b/mlir/include/mlir/Dialect/Rock/utility/LdsTransposeLoad.h @@ -21,8 +21,17 @@ // to the LDS transpose load operation in an accelerator-friendly layout. // // It is intended to simplify the IR generation logic and ensure -// consistent handling of f16/bf16 matrix accelerator tile loads from LDS -// memory. +// consistent handling of f16/bf16/fp8/bf8 matrix accelerator tile loads from +// LDS memory. +// +// Supported element types: +// - f16, bf16: uses ds_read_tr16_b64 (returns 4 elements per thread) +// - f8E4M3FN, f8E5M2 (OCP FP8): uses ds_read_tr8_b64 (returns 8 elements) +// +// Supported MFMA geometries: +// - Standard: (16,16), (16,32), (32,8), (32,16) - single-rate or double-rate +// - Scaled FP8: (16,128) - quad-rate (4 ds_read_tr8 calls per K tile) +// - Scaled FP8: (32,64) - quad-rate (4 ds_read_tr8 calls per K tile) // //===----------------------------------------------------------------------===// @@ -43,7 +52,7 @@ enum class OperandKind { A, B }; // Build LDS transpose config attribute from already-computed MFMA params. // Used in BlockwiseLoadTileToThreadwise when decision was made upstream. // Requires mfmaDDim > 0 and mfmaKDim > 0 (asserted). 
-// Valid combinations: (16,16), (16,32), (32,8), (32,16) +// Valid combinations: (16,16), (16,32), (16,128), (32,8), (32,16), (32,64) LDSTransposeConfigAttr buildTransposeAttrFromParams( PatternRewriter &rewriter, int64_t mfmaDDim, int64_t mfmaKDim, int64_t mPerBlock, int64_t nPerBlock, int64_t kPerBlock, int64_t mPerWave, @@ -60,7 +69,7 @@ struct LDSTransposeDecision { bool enableA{false}; // Enable for operand A bool enableB{false}; // Enable for operand B int64_t mfmaDDim{0}; // MFMA D dimension (M or N, 16 or 32) - int64_t mfmaKDim{0}; // MFMA K dimension (8, 16, or 32) + int64_t mfmaKDim{0}; // MFMA K dimension (8, 16, 32, 64, or 128) }; // Decides whether to enable LDS transpose for operands A and B diff --git a/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp b/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp index f02ba3e9012f..19ebb4296bfe 100644 --- a/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp +++ b/mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp @@ -112,7 +112,13 @@ static auto getMfmaInsnInfoMap = []() -> const llvm::StringMap & { {ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName(), {MfmaTypeId::Bf8Bf8TyId, 16, 32, 1}}, - // fp4 + // Scaled MFMA instructions (FP4 and scaled FP8 types) + // Note: FP8 scaled types (Fp8Fp8ScaledTyId, Fp8Bf8ScaledTyId, etc.) + // use the same underlying instruction with identical (mfmaDDim, k, + // blocksMfma). + // Since deriveAttr only uses those fields (not MfmaTypeId), we only need + // one entry per instruction. The type differentiation happens elsewhere + // via cbsz/blgp parameters at code generation time. 
{ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(), {MfmaTypeId::Fp4TyId, 16, 128, 1}}, {ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(), @@ -448,6 +454,25 @@ static auto getMfmaInsnGroupAttrMapGfx950 = []() { {{MfmaTypeId::Fp4TyId, 32, 32}, {ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName()}}, + // FP8 via scaled MFMA (cbsz=0, blgp=0 mode) + // 16x16 with K=128, 32x32 with K=64 + {{MfmaTypeId::Fp8Fp8ScaledTyId, 16, 16}, + {ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName()}}, + {{MfmaTypeId::Fp8Fp8ScaledTyId, 32, 32}, + {ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName()}}, + {{MfmaTypeId::Fp8Bf8ScaledTyId, 16, 16}, + {ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName()}}, + {{MfmaTypeId::Fp8Bf8ScaledTyId, 32, 32}, + {ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName()}}, + {{MfmaTypeId::Bf8Fp8ScaledTyId, 16, 16}, + {ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName()}}, + {{MfmaTypeId::Bf8Fp8ScaledTyId, 32, 32}, + {ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName()}}, + {{MfmaTypeId::Bf8Bf8ScaledTyId, 16, 16}, + {ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName()}}, + {{MfmaTypeId::Bf8Bf8ScaledTyId, 32, 32}, + {ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName()}}, + // i8 double rate {{MfmaTypeId::I8TyId, 32, 32}, {ROCDL::mfma_i32_32x32x32_i8::getOperationName()}}, @@ -583,6 +608,28 @@ static MfmaTypeId convertTypesToId(Type dataTypeA, Type dataTypeB) { llvm_unreachable("Unsupported input argument type."); } +// Convert native FP8 TypeId to scaled FP8 TypeId for gfx950 scaled MFMA +static std::optional getScaledFp8TypeId(MfmaTypeId nativeTypeId) { + switch (nativeTypeId) { + case MfmaTypeId::Fp8Fp8TyId: + return MfmaTypeId::Fp8Fp8ScaledTyId; + case MfmaTypeId::Fp8Bf8TyId: + return MfmaTypeId::Fp8Bf8ScaledTyId; + case MfmaTypeId::Bf8Fp8TyId: + return MfmaTypeId::Bf8Fp8ScaledTyId; + case MfmaTypeId::Bf8Bf8TyId: + return MfmaTypeId::Bf8Bf8ScaledTyId; + default: + return std::nullopt; 
+ } +} + +// Check if this is a native FP8 type (not scaled) +static bool isNativeFp8TypeId(MfmaTypeId typeId) { + return typeId == MfmaTypeId::Fp8Fp8TyId || typeId == MfmaTypeId::Fp8Bf8TyId || + typeId == MfmaTypeId::Bf8Fp8TyId || typeId == MfmaTypeId::Bf8Bf8TyId; +} + FailureOr MfmaInsnGroup::select(Type elementTypeA, Type elementTypeB, StringRef arch, int64_t mnPerXdl, int64_t kPack, int64_t kPackPerBlock, @@ -624,6 +671,44 @@ MfmaInsnGroup::select(Type elementTypeA, Type elementTypeB, StringRef arch, }; auto selectForGfx950 = [&]() { + int64_t kPerBlock = kPack * kPackPerBlock; + + // For FP8 types, try scaled MFMA first if kPerBlock is large enough + // Scaled MFMA has K=128 for 16x16 and K=64 for 32x32 + if (isNativeFp8TypeId(key.type)) { + int64_t scaledK = (mPerMfmaGroup == 16) ? 128 : 64; + if (kPerBlock >= scaledK) { + LLVM_DEBUG(llvm::dbgs() + << ">>> Trying scaled FP8: kPerBlock=" << kPerBlock + << " >= scaledK=" << scaledK << "\n"); + auto scaledTypeId = getScaledFp8TypeId(key.type); + if (scaledTypeId) { + MfmaInsnGroupSelectKey scaledKey = {*scaledTypeId, mPerMfmaGroup, + nPerMfmaGroup}; + const auto &gfx950Map = getMfmaInsnGroupAttrMapGfx950(); + auto it = gfx950Map.find(scaledKey); + if (it != gfx950Map.end()) { + MfmaInsnGroupAttr groupAttr = (*it).second; + auto maybeInsn = MfmaInsn::select(groupAttr.insn); + if (succeeded(maybeInsn)) { + auto scaledResult = MfmaInsnGroup(elementTypeA, elementTypeB, + *maybeInsn, groupAttr); + if (scaledResult.isCoherentWithK(kPack, kPackPerBlock, + scheduleVersion)) { + LLVM_DEBUG(llvm::dbgs() << ">>> SELECTED SCALED FP8 MFMA: K=" + << maybeInsn->getAttr().k << "\n"); + result = scaledResult; + return; + } + } + } + } + LLVM_DEBUG( + llvm::dbgs() + << ">>> Scaled FP8 MFMA not suitable, falling back to native\n"); + } + } + // gfx950 has double rate instructions. Select from those first. 
selectFrom(getMfmaInsnGroupAttrMapGfx950()); if (succeeded(result)) { @@ -714,3 +799,27 @@ bool MfmaInsnGroup::isCoherentWithK(int64_t kpack, int64_t kPerBlock, int64_t scheduleVersion) { return insn.isCoherentWithK(kpack, kPerBlock, scheduleVersion); } + +bool MfmaInsnGroup::isScaledFp8() const { + // Check if the instruction is a scaled MFMA + // (rocdl.mfma.scale.f32.*x*x*.f8f6f4) + StringRef insnName = groupAttr.insn; + bool isScaledInsn = insnName.contains("mfma.scale.f32.16x16x128.f8f6f4") || + insnName.contains("mfma.scale.f32.32x32x64.f8f6f4"); + if (!isScaledInsn) + return false; + + // Check if the element type is FP8 (not FP4) + // FP8 types: Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ + // FP4 types: Float4E2M1FN + bool isFp8A = isa(elementTypeA) || + isa(elementTypeA) || + isa(elementTypeA) || + isa(elementTypeA); + bool isFp8B = isa(elementTypeB) || + isa(elementTypeB) || + isa(elementTypeB) || + isa(elementTypeB); + + return isFp8A && isFp8B; +} diff --git a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp index 52914cb74829..d1dabcb34af8 100644 --- a/mlir/lib/Dialect/Rock/IR/RockDialect.cpp +++ b/mlir/lib/Dialect/Rock/IR/RockDialect.cpp @@ -560,8 +560,11 @@ LogicalResult TransformMapAttr::verify( // Helper function to check valid MFMA geometry for LDS transpose static bool isValidLdsTransposeMfmaGeometry(int64_t dDim, int64_t kDim) { - return (dDim == 16 && (kDim == 16 || kDim == 32)) || - (dDim == 32 && (kDim == 8 || kDim == 16)); + // Supported geometries: + // Standard: (16,16), (16,32), (32,8), (32,16) + // Scaled FP8: (16,128), (32,64) + return (dDim == 16 && (kDim == 16 || kDim == 32 || kDim == 128)) || + (dDim == 32 && (kDim == 8 || kDim == 16 || kDim == 64)); } LogicalResult LDSTransposeConfigAttr::verify( @@ -571,9 +574,10 @@ LogicalResult LDSTransposeConfigAttr::verify( // Validate MFMA geometry if (!isValidLdsTransposeMfmaGeometry(dDim, kDim)) { - return emitError() << "invalid MFMA geometry 
(" << dDim << "x" << kDim - << ") for LDS transpose - valid combinations: " - "(16,16), (16,32), (32,8), (32,16)"; + return emitError() + << "invalid MFMA geometry (" << dDim << "x" << kDim + << ") for LDS transpose - valid combinations: " + "(16,16), (16,32), (16,128), (32,8), (32,16), (32,64)"; } // Validate positive dimensions @@ -2161,6 +2165,27 @@ LogicalResult LDSTransposeLoadOp::verify() { << srcElemType << ")"; } + // Verify result vector length based on element type: + // - 16-bit types (f16, bf16): ds_read_tr16_b64 returns 4 elements + // - 8-bit types (f8E4M3FN, f8E5M2 - OCP FP8 for gfx950): ds_read_tr8_b64 + // returns 8 elements + int64_t expectedVecLen; + if (srcElemType.isF16() || srcElemType.isBF16()) { + expectedVecLen = 4; + } else if (isa(srcElemType) || + isa(srcElemType)) { + expectedVecLen = 8; + } else { + return emitOpError("unsupported element type for LDS transpose load: ") + << srcElemType; + } + + if (resultType.getNumElements() != expectedVecLen) { + return emitOpError("expected result vector of ") + << expectedVecLen << " elements for " << srcElemType + << " type, but got " << resultType.getNumElements(); + } + // Check hardware support using AmdArchDb StringRef arch = rock::getArchValue(*this); AmdArchInfo archInfo = rock::lookupArchInfo(arch); diff --git a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp index 258b5369d7bc..2621fe5e048f 100644 --- a/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp +++ b/mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp @@ -191,6 +191,22 @@ void MfmaEmitter::emitThreadwiseLoop(OpBuilder &b, Location loc, Value argA, VectorType vectorType = mfmaGroup.getRetType(); auto outputOffset = llvm::to_vector(regCOffset); bool isScaled = scaleA && scaleB; + bool isScaledFp8 = mfmaGroup.isScaledFp8(); + + // For scaled FP8 MFMA without explicit scale buffers, create neutral scales. 
+ // NOTE(review): Float8E8M0FNU is a biased exponent-only format (OCP MX + // E8M0, bias 127): byte 0x7F encodes 2^0 = 1.0 (the neutral scale), while + // the all-zero bit pattern encodes 2^-127, NOT 1.0. Confirm that + // getFloatAttr(scaleType, 0.0) below yields 0x7F; otherwise pass 1.0 here. + Value neutralScaleA, neutralScaleB; + if (isScaledFp8 && !isScaled) { + Type scaleType = b.getType(); + auto neutralScaleAttr = b.getFloatAttr(scaleType, 0.0); + neutralScaleA = + arith::ConstantOp::create(b, loc, scaleType, neutralScaleAttr); + neutralScaleB = + arith::ConstantOp::create(b, loc, scaleType, neutralScaleAttr); + } for (int64_t i = 0; i < nResultVectors; ++i) { Value offset = b.createOrFold(loc, i); @@ -203,11 +219,21 @@ void MfmaEmitter::emitThreadwiseLoop(OpBuilder &b, Location loc, Value argA, Value vectorD; if (isScaled) { + // Explicit scale buffers provided (FP4 or scaled FP8 with explicit scales) auto mfma = amdgpu::ScaledMFMAOp::create( b, loc, vectorType, mfmaDDim, mfmaDDim, mfmaAttr.k, argA, argB, vectorC, scaleA, scaleB, /*scalesIdxA=*/0, /*scalesIdxB=*/0); vectorD = mfma.getDestD(); + } else if (isScaledFp8) { + // Scaled FP8 MFMA (K=128 for 16x16, K=64 for 32x32) without explicit scales + // Use the neutral scales built above (see NOTE on the E8M0 encoding) + auto mfma = amdgpu::ScaledMFMAOp::create( + b, loc, vectorType, mfmaDDim, mfmaDDim, mfmaAttr.k, argA, argB, + vectorC, neutralScaleA, neutralScaleB, + /*scalesIdxA=*/0, /*scalesIdxB=*/0); + vectorD = mfma.getDestD(); } else { + // Regular MFMA auto mfma = amdgpu::MFMAOp::create( b, loc, vectorType, mfmaDDim, mfmaDDim, mfmaAttr.k, mfmaAttr.blocksMfma, argA, argB, vectorC, /*cbsz=*/imms[i].cbsz, diff --git a/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp b/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp index f65e1487c847..2264ac06642f 100644 --- a/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp +++ b/mlir/lib/Dialect/Rock/utility/LdsTransposeLoad.cpp @@ -21,8 +21,13 @@ // to the LDS 
transpose load operation in an accelerator-friendly layout. // // It is intended to simplify the IR generation logic and ensure -// consistent handling of f16/bf16 matrix accelerator tile loads from LDS -// memory. +// consistent handling of f16/bf16/fp8/bf8 matrix accelerator tile loads from +// LDS memory. +// +// Supported element types: +// - f16, bf16: uses ds_read_tr16_b64 (returns 4 elements per thread) +// - f8E4M3FN, f8E5M2 (OCP FP8): uses ds_read_tr8_b64 (returns 8 elements per +// thread) // //===----------------------------------------------------------------------===// @@ -47,11 +52,39 @@ namespace { bool archSupported(StringRef arch) { return arch.contains("gfx950"); } +// Check if element type is supported for LDS transpose load +// - f16, bf16: ds_read_tr16_b64 (4 elements) +// - f8E4M3FN, f8E5M2 (OCP FP8 for gfx950): ds_read_tr8_b64 (8 elements) +static bool isSupportedElementType(Type t) { + return t.isF16() || t.isBF16() || isa(t) || + isa(t); +} + +// Check if element type is 8-bit float (FP8 E4M3 or BF8 E5M2) +// Used for: +// 1. Selecting ds_read_tr8_b64 vs ds_read_tr16_b64 +// 2. Checking mixed-type compatibility (fp8+bf8 combinations are valid) +static bool isFp8Type(Type t) { + return isa(t) || isa(t); +} + +// Returns the number of elements returned by LDS transpose load instruction +static int64_t getTransposeLoadVectorLength(Type elemType) { + if (elemType.isF16() || elemType.isBF16()) { + return 4; // ds_read_tr16_b64 + } else if (isFp8Type(elemType)) { + return 8; // ds_read_tr8_b64 + } + llvm_unreachable("Unsupported element type for LDS transpose load"); +} + // Validates MFMA geometry for LDS transpose support. 
-// Only specific combinations are supported: (16,16), (16,32), (32,8), (32,16) +// Supported combinations: +// Standard: (16,16), (16,32), (32,8), (32,16) +// Scaled FP8: (16,128) quad-rate, (32,64) quad-rate static bool isValidMfmaGeometry(int64_t dDim, int64_t kDim) { - return (dDim == 16 && (kDim == 16 || kDim == 32)) || - (dDim == 32 && (kDim == 8 || kDim == 16)); + return (dDim == 16 && (kDim == 16 || kDim == 32 || kDim == 128)) || + (dDim == 32 && (kDim == 8 || kDim == 16 || kDim == 64)); } // Shape of a single MFMA instruction (internal use only). @@ -119,7 +152,16 @@ static Decision makeDecision(StringRef arch, Type elemTypeA, Type elemTypeB, return dec; } - if (elemTypeA != elemTypeB || !(elemTypeA.isF16() || elemTypeA.isBF16())) + // Check type compatibility: + // - Same types are always allowed (f16==f16, fp8==fp8, etc.) + // - Mixed FP8/BF8 combinations are allowed (hardware supports mixed fp8/bf8 + // MFMA: + // mfma_f32_16x16x32_fp8_fp8, mfma_f32_16x16x32_fp8_bf8, etc.) + // - Other mixed types are NOT allowed (e.g., f16 with fp8) + bool typesCompatible = (elemTypeA == elemTypeB) || + (isFp8Type(elemTypeA) && isFp8Type(elemTypeB)); + if (!typesCompatible || !isSupportedElementType(elemTypeA) || + !isSupportedElementType(elemTypeB)) return dec; // Validate MFMA geometry @@ -306,7 +348,7 @@ LDSTransposeConfigAttr buildTransposeAttrFromParams( "MFMA geometry must be set when building transpose attributes"); assert(isValidMfmaGeometry(mfmaDDim, mfmaKDim) && "Invalid MFMA geometry for LDS transpose - valid: (16,16), (16,32), " - "(32,8), (32,16)"); + "(16,128), (32,8), (32,16), (32,64)"); // Create structured attribute with all parameters return LDSTransposeConfigAttr::get(rewriter.getContext(), mfmaDDim, mfmaKDim, @@ -419,23 +461,27 @@ static Value getDoubleRateKOffsetBase(PatternRewriter &b, Location loc, //===----------------------------------------------------------------------===// // computePanelFinalOffset - Compute final K offset for a specific K 
tile // -// This function centralizes the K offset computation logic for both single-rate -// and double-rate layouts. It handles the tile-based offset calculation and -// optional low/high half splitting for double-rate layouts. +// This function centralizes the K offset computation logic for single-rate, +// double-rate, and quad-rate layouts. It handles the tile-based offset +// calculation and optional low/high half splitting for double-rate layouts, +// as well as read index offset for quad-rate layouts. // // Formula: // Single-rate: k_final = k_base_local + (kTileIdx * kTileStride) // Double-rate: k_final = k_base_local + kOffsetBase + (kTileIdx * // kTileStride) + halfOffset // where halfOffset = 0 for low half, 4 for high half +// Quad-rate: k_final = k_base_local + (kTileIdx * kTileStride) + readIdx*8 +// where readIdx = 0, 1, 2, 3 for the 4 ds_read_tr8 calls per K tile // // Parameters: // isDoubleRate - Whether this is a double-rate layout (L32x16, L16x32) // kBaseLocal - Local K base offset from computeLDSBaseOffsets() // kOffsetBase - Double-rate K offset base (from getDoubleRateKOffsetBase) // kTileIdx - Current K tile index (0, 1, 2, ...) 
-// kTileStride - K stride per tile (instrK, e.g., 8 or 16) +// kTileStride - K stride per tile (instrK: 8, 16, 32, 64, or 128) // isHighHalf - For double-rate: true = high half (+4), false = low half +// readIdx - For quad-rate: 0-3 index for consecutive 8-K chunks // // Returns: // Final K offset value to use for emitPanelLoad() @@ -443,8 +489,8 @@ static Value getDoubleRateKOffsetBase(PatternRewriter &b, Location loc, static Value computePanelFinalOffset(PatternRewriter &b, Location loc, bool isDoubleRate, Value kBaseLocal, Value kOffsetBase, int64_t kTileIdx, - Value kTileStride, - bool isHighHalf = false) { + Value kTileStride, bool isHighHalf = false, + int64_t readIdx = 0) { Value kBase = kBaseLocal; if (isDoubleRate) { @@ -470,17 +516,27 @@ static Value computePanelFinalOffset(PatternRewriter &b, Location loc, kBase = arith::AddIOp::create(b, loc, kBaseLocal, k_offset); } else { - // Single-rate: k_base = k_base_local + kTileIdx * kTileStride + // Single-rate or Quad-rate: k_base = k_base_local + kTileIdx * kTileStride if (kTileIdx > 0) { Value kIdxVal = arith::ConstantIndexOp::create(b, loc, kTileIdx); Value kOffsetAdd = arith::MulIOp::create(b, loc, kTileStride, kIdxVal); kBase = arith::AddIOp::create(b, loc, kBase, kOffsetAdd); } + + // Quad-rate: add readIdx * 8 for consecutive 8-K chunks within k_base=32 + // readIdx=0: K+0..7, readIdx=1: K+8..15, readIdx=2: K+16..23, readIdx=3: + // K+24..31 + if (readIdx > 0) { + Value c8 = arith::ConstantIndexOp::create(b, loc, 8); + Value readIdxVal = arith::ConstantIndexOp::create(b, loc, readIdx); + Value readOffset = arith::MulIOp::create(b, loc, readIdxVal, c8); + kBase = arith::AddIOp::create(b, loc, kBase, readOffset); + } } LLVM_DEBUG(llvm::dbgs() << "[lds_transpose] Computed panel K offset for tile " << kTileIdx << (isHighHalf ? 
" (high)" : " (low)") - << "\n"); + << ", readIdx=" << readIdx << "\n"); return kBase; } @@ -489,8 +545,9 @@ static Value computePanelFinalOffset(PatternRewriter &b, Location loc, // emitPanelLoad - Emit an LDS transpose load operation // // Computes the final LDS offset and emits a hardware LDS transpose load -// instruction (ds_read_tr16_b64). This instruction always returns vector<4xf16> -// regardless of the layout. +// instruction: +// - ds_read_tr16_b64 for f16/bf16: returns vector<4> +// - ds_read_tr8_b64 for fp8/bf8: returns vector<8> // // The final offset is computed as: final_offset = k_base * ldsStride + m_base // where ldsStride depends on the operand (mPerBlock for A, nPerBlock for B). @@ -503,10 +560,10 @@ static Value computePanelFinalOffset(PatternRewriter &b, Location loc, // kBase - K dimension base offset for this panel // mBase - M/N dimension base offset for this panel // ldsStride - Stride between K rows in LDS (mPerBlock or nPerBlock) -// panelVecType - Result type (always vector<4xf16> or vector<4xbf16>) +// panelVecType - Result type (vector<4> for f16/bf16, vector<8> for fp8/bf8) // // Returns: -// The loaded panel vector (vector<4xf16/bf16>) +// The loaded panel vector (vector<4> for f16/bf16, vector<8> for fp8/bf8) //===----------------------------------------------------------------------===// static Value emitPanelLoad(PatternRewriter &b, Location loc, Value rawSrc, Value kBase, Value mBase, Value ldsStride, @@ -515,7 +572,7 @@ static Value emitPanelLoad(PatternRewriter &b, Location loc, Value rawSrc, Value kOffset = arith::MulIOp::create(b, loc, kBase, ldsStride); Value finalOffset = arith::AddIOp::create(b, loc, mBase, kOffset); - // Emit hardware LDS transpose load: ds_read_tr16_b64 + // Emit hardware LDS transpose load auto loadOp = rock::LDSTransposeLoadOp::create(b, loc, panelVecType, rawSrc, ValueRange{finalOffset}); @@ -530,9 +587,9 @@ static Value emitPanelLoad(PatternRewriter &b, Location loc, Value rawSrc, // elements 
(ds_read_tr16_b64 always returns vector<4xf16>). // // Parameters: -// panelVectors - Array of loaded panel vectors (each is vector<4xf16>) -// dest - Destination memref (rank-1, scalar layout) -// targetElems - Maximum number of elements to write +// panelVectors - Array of loaded panel vectors (vector<4> for f16/bf16, +// vector<8> for fp8/bf8); dest - Destination memref (rank-1, scalar layout) +// targetElems - Maximum number of elements to write // // Returns: // success() if all target elements were written @@ -546,14 +603,16 @@ writePanelVectorsToDestination(PatternRewriter &b, Location loc, int64_t produced = 0; // Extract elements per vector from the actual vector type - // Hardware instruction ds_read_tr16_b64 always returns vector<4xf16> + // Hardware instructions: + // - ds_read_tr16_b64 returns vector<4xf16/bf16> + // - ds_read_tr8_b64 returns vector<8xfp8> assert(!panelVectors.empty() && "Panel vectors array must not be empty"); auto panelVecType = cast(panelVectors[0].getType()); int64_t elementsPerVector = panelVecType.getShape()[0]; - // Verify hardware constraint: ds_read_tr16_b64 returns exactly 4 elements - assert(elementsPerVector == 4 && - "LDS transpose load must produce vector<4xf16> per panel"); + // Verify hardware constraint: 4 elements for 16-bit, 8 elements for 8-bit + assert((elementsPerVector == 4 || elementsPerVector == 8) && + "LDS transpose load must produce vector<4> or vector<8> per panel"); LLVM_DEBUG(llvm::dbgs() << "[lds_transpose] Writing " << panelVectors.size() << " panel vectors (" @@ -609,60 +668,129 @@ static SmallVector getBasePanelOffsets(PatternRewriter &b, Location loc, int64_t dDim, int64_t kDim, - Value lane) { - Value c16 = arith::ConstantIndexOp::create(b, loc, 16); - Value c4 = arith::ConstantIndexOp::create(b, loc, 4); + Value lane, Type elemType) { + // Common 
constants used by both FP8 and F16/BF16 Value c2 = arith::ConstantIndexOp::create(b, loc, 2); + Value c4 = arith::ConstantIndexOp::create(b, loc, 4); + Value c16 = arith::ConstantIndexOp::create(b, loc, 16); - Value blockId = arith::DivUIOp::create(b, loc, lane, c16); - Value laneInBlock = arith::RemUIOp::create(b, loc, lane, c16); + Value kOffsetBase, mOffsetBase; - // Base offset calculations - Value mOffsetBase = arith::MulIOp::create( - b, loc, arith::RemUIOp::create(b, loc, laneInBlock, c4), c4); - Value kOffsetBase = arith::DivUIOp::create(b, loc, laneInBlock, c4); + if (isFp8Type(elemType)) { + Value c8 = arith::ConstantIndexOp::create(b, loc, 8); - SmallVector panelOffsets; + Value blockId = arith::DivUIOp::create(b, loc, lane, c16); + Value laneInBlock = arith::RemUIOp::create(b, loc, lane, c16); + Value kLocal = arith::DivUIOp::create(b, loc, laneInBlock, c2); + Value mParity = arith::RemUIOp::create(b, loc, laneInBlock, c2); + + if (dDim == 16 && kDim == 32) { + // Block layout: Block 0-3 map to K=0..7, 8..15, 16..23, 24..31 + + // kOffsetBase = k_local + block_id * 8 + Value blockKOffset = arith::MulIOp::create(b, loc, blockId, c8); + kOffsetBase = arith::AddIOp::create(b, loc, kLocal, blockKOffset); + + // mOffsetBase = m_parity * 8 + mOffsetBase = arith::MulIOp::create(b, loc, mParity, c8); + + } else if (dDim == 32 && kDim == 16) { + // Block layout: m_block = block_id % 2, k_block = block_id / 2 + + Value mBlock = arith::RemUIOp::create(b, loc, blockId, c2); + Value kBlock = arith::DivUIOp::create(b, loc, blockId, c2); + + // kOffsetBase = k_local + k_block * 8 + Value kBlockOffset = arith::MulIOp::create(b, loc, kBlock, c8); + kOffsetBase = arith::AddIOp::create(b, loc, kLocal, kBlockOffset); + + // mOffsetBase = m_parity * 8 + m_block * 16 + Value mParityOffset = arith::MulIOp::create(b, loc, mParity, c8); + Value mBlockOffset = arith::MulIOp::create(b, loc, mBlock, c16); + mOffsetBase = arith::AddIOp::create(b, loc, mParityOffset, mBlockOffset); 
+ + } else if (dDim == 16 && kDim == 128) { + // FP8 Scaled 16x128: quad-rate (k_base=32) + // Block layout: Block 0-3 map to K=0..31, 32..63, 64..95, 96..127 + + Value c32 = arith::ConstantIndexOp::create(b, loc, 32); + + // kOffsetBase = k_local + block_id * 32 + Value blockKOffset = arith::MulIOp::create(b, loc, blockId, c32); + kOffsetBase = arith::AddIOp::create(b, loc, kLocal, blockKOffset); + + // mOffsetBase = m_parity * 8 + mOffsetBase = arith::MulIOp::create(b, loc, mParity, c8); + + } else if (dDim == 32 && kDim == 64) { + // FP8 Scaled 32x64: quad-rate (k_base=32) + // Block layout: m_block = block_id % 2, k_block = block_id / 2 + + Value c32 = arith::ConstantIndexOp::create(b, loc, 32); + + Value mBlock = arith::RemUIOp::create(b, loc, blockId, c2); + Value kBlock = arith::DivUIOp::create(b, loc, blockId, c2); + + // kOffsetBase = k_local + k_block * 32 + Value kBlockOffset = arith::MulIOp::create(b, loc, kBlock, c32); + kOffsetBase = arith::AddIOp::create(b, loc, kLocal, kBlockOffset); + + // mOffsetBase = m_parity * 8 + m_block * 16 + Value mParityOffset = arith::MulIOp::create(b, loc, mParity, c8); + Value mBlockOffset = arith::MulIOp::create(b, loc, mBlock, c16); + mOffsetBase = arith::AddIOp::create(b, loc, mParityOffset, mBlockOffset); + + } else { + llvm_unreachable("Unsupported FP8 MFMA geometry in getBasePanelOffsets"); + } - if (dDim == 16 && kDim == 32) { - // 16x32 layout - panelOffsets = {kOffsetBase, mOffsetBase}; - } else if (dDim == 16 && kDim == 16) { - // 16x16 layout - // kbase = kOffsetBase + (blockId * 4) - Value kBase = arith::AddIOp::create( - b, loc, arith::MulIOp::create(b, loc, blockId, c4), kOffsetBase); - panelOffsets = {kBase, mOffsetBase}; - } else if (dDim == 32 && kDim == 16) { - // 32x16 layout - // mbase = mOffsetBase + (blockId % 2) * 16 - Value mBase = arith::AddIOp::create( - b, loc, - arith::MulIOp::create(b, loc, - arith::RemUIOp::create(b, loc, blockId, c2), c16), - mOffsetBase); - panelOffsets = {kOffsetBase, 
mBase}; - } else if (dDim == 32 && kDim == 8) { - // 32x8 layout - // k_base_local = kOffsetBase + (blockId / 2) * 4 - Value kBase = arith::AddIOp::create( - b, loc, - arith::MulIOp::create(b, loc, - arith::DivUIOp::create(b, loc, blockId, c2), c4), - kOffsetBase); - - // m_offset_base = mOffsetBase + (blockId % 2) * 16 - Value mBase = arith::AddIOp::create( - b, loc, - arith::MulIOp::create(b, loc, - arith::RemUIOp::create(b, loc, blockId, c2), c16), - mOffsetBase); - panelOffsets = {kBase, mBase}; } else { - llvm_unreachable("Unsupported MFMA geometry in getBasePanelOffsets"); + // F16/BF16 uses block-based lane mapping + Value blockId = arith::DivUIOp::create(b, loc, lane, c16); + Value laneInBlock = arith::RemUIOp::create(b, loc, lane, c16); + + // Base calculations common to all F16/BF16 geometries + Value kLocal = arith::DivUIOp::create(b, loc, laneInBlock, c4); + Value mLocal = arith::MulIOp::create( + b, loc, arith::RemUIOp::create(b, loc, laneInBlock, c4), c4); + + if (dDim == 16 && kDim == 32) { + // 16x32: direct mapping + kOffsetBase = kLocal; + mOffsetBase = mLocal; + + } else if (dDim == 16 && kDim == 16) { + // 16x16: k += blockId * 4 + kOffsetBase = arith::AddIOp::create( + b, loc, kLocal, arith::MulIOp::create(b, loc, blockId, c4)); + mOffsetBase = mLocal; + + } else if (dDim == 32 && kDim == 16) { + // 32x16: m += (blockId % 2) * 16 + kOffsetBase = kLocal; + mOffsetBase = arith::AddIOp::create( + b, loc, mLocal, + arith::MulIOp::create( + b, loc, arith::RemUIOp::create(b, loc, blockId, c2), c16)); + + } else if (dDim == 32 && kDim == 8) { + // 32x8: k += (blockId / 2) * 4, m += (blockId % 2) * 16 + kOffsetBase = arith::AddIOp::create( + b, loc, kLocal, + arith::MulIOp::create( + b, loc, arith::DivUIOp::create(b, loc, blockId, c2), c4)); + mOffsetBase = arith::AddIOp::create( + b, loc, mLocal, + arith::MulIOp::create( + b, loc, arith::RemUIOp::create(b, loc, blockId, c2), c16)); + + } else { + llvm_unreachable( + "Unsupported F16/BF16 MFMA geometry 
in getBasePanelOffsets"); + } } - return panelOffsets; + return {kOffsetBase, mOffsetBase}; } //===----------------------------------------------------------------------===// @@ -683,6 +811,7 @@ static SmallVector getBasePanelOffsets(PatternRewriter &b, Location loc, // dDim - MFMA D dimension (M or N, 16 or 32) // kDim - MFMA K dimension (8, 16, or 32) // lane - Thread's lane ID within the workgroup +// elemType - Element type (f16, bf16, fp8, or bf8) for selecting lane mapping // // Returns: // std::pair: @@ -691,8 +820,10 @@ static SmallVector getBasePanelOffsets(PatternRewriter &b, Location loc, //===----------------------------------------------------------------------===// static std::pair computeLDSBaseOffsets(PatternRewriter &b, Location loc, int64_t dDim, - int64_t kDim, Value lane) { - SmallVector offsets = getBasePanelOffsets(b, loc, dDim, kDim, lane); + int64_t kDim, Value lane, + Type elemType) { + SmallVector offsets = + getBasePanelOffsets(b, loc, dDim, kDim, lane, elemType); LLVM_DEBUG(llvm::dbgs() << "[lds_transpose] Computed LDS base offsets for " << dDim << "x" << kDim << ": " @@ -1098,8 +1229,9 @@ static Value computeFinalMNOffset(PatternRewriter &b, Location loc, // emitThreadwiseHWTranspose - Lower threadwise_read_into to HW transpose loads //===----------------------------------------------------------------------===// // Lowers threadwise_read_into with LDS transpose config into hardware transpose -// load instructions (ds_read_tr16_b64) that read from LDS in MFMA-friendly -// order. +// load instructions that read from LDS in MFMA-friendly order: +// - ds_read_tr16_b64 for f16/bf16 (returns vector<4>) +// - ds_read_tr8_b64 for fp8/bf8 (returns vector<8>) // // Algorithm: // 1. 
Extract config: MFMA geometry (dDim, kDim), tiling params, operand kind @@ -1112,8 +1244,8 @@ static Value computeFinalMNOffset(PatternRewriter &b, Location loc, // - Outer: M/N tiles (all at once for double-buffering, one at a time // otherwise) // - Inner: K tiles (1 load for single-rate, 2 loads for double-rate layouts) -// 6. For each iteration: compute final LDS offset, emit ds_read_tr16_b64 -// instruction (returns vector<4xf16>) +// 6. For each iteration: compute final LDS offset, emit LDS transpose load +// instruction (ds_read_tr16_b64 for f16/bf16, ds_read_tr8_b64 for fp8/bf8) // 7. Extract elements from panel vectors and write sequentially to destination // // Example: 16x32 layout, 1 M-tile, 2 K-tiles → 2 ds_read_tr16_b64 calls → 8 @@ -1166,24 +1298,32 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b, // Use mPerBlock as stride for operand A, nPerBlock for operand B int64_t ldsStride = (operand == OperandKind::A) ? mPerBlock : nPerBlock; - // Determine if this is a double-rate instruction - // Double-rate ONLY for (32,16) and (16,32) MFMA - // (16,16) and (32,8) are SINGLE-RATE - bool isDoubleRate = - (dDim == 32 && instrK == 16) || (dDim == 16 && instrK == 32); - - // Each ds_read_tr16_b64 call ALWAYS returns vector<4xf16> - // For double-rate, we make 2 calls and store all 8 elements separately - VectorType panelVecType = VectorType::get({4}, elemType); + // Determine if this is a double-rate or quad-rate instruction + // Double-rate ONLY for (32,16) and (16,32) MFMA with F16/BF16 + // FP8/BF8 uses ds_read_tr8_b64 which returns 8 elements, so (16,32) and + // (32,16) are SINGLE-RATE for FP8/BF8 (16,16) and (32,8) are always + // SINGLE-RATE + // Quad-rate for FP8 scaled MFMA: 16x128 and 32x64 (k_base=32, 4 reads of 8) + bool isDoubleRate = !isFp8Type(elemType) && ((dDim == 32 && instrK == 16) || + (dDim == 16 && instrK == 32)); + bool isQuadRate = isFp8Type(elemType) && ((dDim == 16 && instrK == 128) || + (dDim == 32 && instrK == 64)); + 
+ // Determine vector length based on element type: + // - f16/bf16: ds_read_tr16_b64 returns vector<4> + // - fp8/bf8: ds_read_tr8_b64 returns vector<8> + int64_t vecLen = getTransposeLoadVectorLength(elemType); + VectorType panelVecType = VectorType::get({vecLen}, elemType); // panelVectors will contain: - // - Single-rate: 1 vector<4xf16> per K tile - // - Double-rate: 2 vector<4xf16> per K tile (low + high) + // - Single-rate: 1 vector per K tile + // - Double-rate (f16/bf16 only): 2 vectors per K tile (low + high) + // - Quad-rate (FP8 16x128 or 32x64): 4 vectors per K tile (readIdx 0-3) SmallVector panelVectors; // Get base offsets using computeLDSBaseOffsets helper auto [k_base_local, m_offset_base] = - computeLDSBaseOffsets(b, loc, dDim, instrK, lane); + computeLDSBaseOffsets(b, loc, dDim, instrK, lane, elemType); // K stride per tile: instrK (MFMA K dimension) int64_t kTileStride = instrK; @@ -1237,21 +1377,38 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b, b, loc, m_offset_base, operand, waveM, waveN, mnTileIndex, mnIdxLocal, useDynamicMnIndex, waveOffsetStrideVal, tileOffsetStrideVal); - if (!isDoubleRate) { - // SINGLE-RATE (L32x8, L16x16): One load per K tile + if (isQuadRate) { + // QUAD-RATE (FP8 scaled MFMA 16x128 or 32x64): FOUR loads per K tile + // Each load returns vector<8> for fp8, total 32 elements per K tile + // k_base=32, so 4 reads of 8 elements each give consecutive K + // For 16x128: all blocks in K dimension (block_id * 32) + // For 32x64: m_block/k_block split (k_block = block_id / 2) * 32 + for (int64_t readIdx = 0; readIdx < 4; ++readIdx) { + Value k_base = computePanelFinalOffset( + b, loc, /*isDoubleRate=*/false, k_base_local, kOffsetBase, kIdx, + kTileStrideVal, /*isHighHalf=*/false, /*readIdx=*/readIdx); + + Value panelVec = emitPanelLoad(b, loc, rawSrc, k_base, m_base, + ldsStrideVal, panelVecType); + panelVectors.push_back(panelVec); + } + + } else if (!isDoubleRate) { + // SINGLE-RATE (L32x8, L16x16, or 
FP8/BF8): One load per K tile Value k_base = computePanelFinalOffset(b, loc, isDoubleRate, k_base_local, kOffsetBase, kIdx, kTileStrideVal); - // Emit LDS transpose load for this K tile (single-rate: one per K tile) + // Emit LDS transpose load for this K tile Value panelVec = emitPanelLoad(b, loc, rawSrc, k_base, m_base, ldsStrideVal, panelVecType); panelVectors.push_back(panelVec); } else { - // DOUBLE-RATE (L32x16, L16x32): TWO loads per K tile - // Each load returns vector<4xf16>, total 8 elements per K tile - // Compute K offsets for low and high halves + // DOUBLE-RATE (L32x16, L16x32 with F16/BF16 only): TWO loads per K tile + // Each load returns vector<4> for f16/bf16, total 8 elements per K tile + // Note: FP8/BF8 is NEVER double-rate (ds_read_tr8_b64 returns 8 + // elements) Compute K offsets for low and high halves Value k_base_low = computePanelFinalOffset( b, loc, isDoubleRate, k_base_local, kOffsetBase, kIdx, kTileStrideVal, /*isHighHalf=*/false); @@ -1274,15 +1431,18 @@ LogicalResult emitThreadwiseHWTranspose(PatternRewriter &b, // Calculate expected number of loads // - For double buffering: we generate ALL M/N panels → endMnIdx panels × - // kPanels × (1 or 2 for rate) + // kPanels × (1, 2, or 4 for rate) // - Single-rate: 1 load per K tile → actualMnTiles × kPanels loads // - Double-rate: 2 loads per K tile → actualMnTiles × kPanels × 2 loads + // - Quad-rate: 4 loads per K tile → actualMnTiles × kPanels × 4 loads int64_t actualMnTiles = endMnIdx - startMnIdx; - int64_t loadsPerKTile = isDoubleRate ? 2 : 1; + int64_t loadsPerKTile = isQuadRate ? 4 : (isDoubleRate ? 
2 : 1); int64_t expectedLoads = actualMnTiles * kPanels * loadsPerKTile; - // Each load ALWAYS produces 4 elements (ds_read_tr16_b64 → vector<4xf16>) - int64_t sliceElems = expectedLoads * 4; + // Each load produces vecLen elements: + // - f16/bf16: 4 elements (ds_read_tr16_b64) + // - fp8/bf8: 8 elements (ds_read_tr8_b64) + int64_t sliceElems = expectedLoads * vecLen; // Verify we generated the expected number of loads if (panelVectors.size() != (size_t)expectedLoads) { diff --git a/mlir/test/Dialect/Rock/lds_transpose_error.mlir b/mlir/test/Dialect/Rock/lds_transpose_error.mlir index efc7cfc42fa4..d7ee8500cacf 100644 --- a/mlir/test/Dialect/Rock/lds_transpose_error.mlir +++ b/mlir/test/Dialect/Rock/lds_transpose_error.mlir @@ -2,13 +2,13 @@ // Error case: Invalid MFMA geometry (16x8 is not valid) // This tests that LDSTransposeConfigAttr::verify() catches invalid MFMA -// geometry combinations. Valid combinations are: (16,16), (16,32), (32,8), (32,16) +// geometry combinations. Valid combinations are: (16,16), (16,32), (16,128), (32,8), (32,16), (32,64) func.func @threadwise_read_into_invalid_mfma_geometry_16x8( %source: memref<128xf16, #gpu.address_space>, %dest: memref<8xf16, #gpu.address_space>) attributes {arch = "amdgcn-amd-amdhsa:gfx950"} { rock.threadwise_read_into { - // expected-error @+1 {{invalid MFMA geometry (16x8) for LDS transpose - valid combinations: (16,16), (16,32), (32,8), (32,16)}} + // expected-error @+1 {{invalid MFMA geometry (16x8) for LDS transpose - valid combinations: (16,16), (16,32), (16,128), (32,8), (32,16), (32,64)}} ldsTransposeConfig = #rock.lds_transpose_config< dDim = 16, kDim = 8, mPerBlock = 128, nPerBlock = 128, kPerBlock = 32, @@ -27,7 +27,7 @@ func.func @threadwise_read_into_invalid_mfma_geometry_32x32( %dest: memref<8xf16, #gpu.address_space>) attributes {arch = "amdgcn-amd-amdhsa:gfx950"} { rock.threadwise_read_into { - // expected-error @+1 {{invalid MFMA geometry (32x32) for LDS transpose - valid combinations: 
(16,16), (16,32), (32,8), (32,16)}} + // expected-error @+1 {{invalid MFMA geometry (32x32) for LDS transpose - valid combinations: (16,16), (16,32), (16,128), (32,8), (32,16), (32,64)}} ldsTransposeConfig = #rock.lds_transpose_config< dDim = 32, kDim = 32, mPerBlock = 128, nPerBlock = 128, kPerBlock = 32, @@ -46,7 +46,7 @@ func.func @threadwise_read_into_invalid_mfma_geometry_8x8( %dest: memref<8xf16, #gpu.address_space>) attributes {arch = "amdgcn-amd-amdhsa:gfx950"} { rock.threadwise_read_into { - // expected-error @+1 {{invalid MFMA geometry (8x8) for LDS transpose - valid combinations: (16,16), (16,32), (32,8), (32,16)}} + // expected-error @+1 {{invalid MFMA geometry (8x8) for LDS transpose - valid combinations: (16,16), (16,32), (16,128), (32,8), (32,16), (32,64)}} ldsTransposeConfig = #rock.lds_transpose_config< dDim = 8, kDim = 8, mPerBlock = 128, nPerBlock = 128, kPerBlock = 32, diff --git a/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir b/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir index 29a2f190a06b..b6a64a6985f7 100644 --- a/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir +++ b/mlir/test/Dialect/Rock/lowering_load_transpose_lds.mlir @@ -14,4 +14,18 @@ module attributes {mhal.arch = "amdgcn-amd-amdhsa:gfx950"} { %v = rock.lds_transpose_load %src[%i, %j] : memref<64x128xbf16, #gpu.address_space> -> vector<4xbf16> return %v : vector<4xbf16> } + +// CHECK-LABEL: func @test_load_transpose_fp8_e4m3 + func.func @test_load_transpose_fp8_e4m3(%src: memref<128x256xf8E4M3FN, #gpu.address_space>, %i: index, %j: index) -> vector<8xf8E4M3FN> { + // CHECK: amdgpu.transpose_load %arg0[%arg1, %arg2] : memref<128x256xf8E4M3FN, #gpu.address_space> -> vector<8xf8E4M3FN> + %v = rock.lds_transpose_load %src[%i, %j] : memref<128x256xf8E4M3FN, #gpu.address_space> -> vector<8xf8E4M3FN> + return %v : vector<8xf8E4M3FN> + } + +// CHECK-LABEL: func @test_load_transpose_fp8_e5m2 + func.func @test_load_transpose_fp8_e5m2(%src: memref<64x128xf8E5M2, 
#gpu.address_space>, %i: index, %j: index) -> vector<8xf8E5M2> { + // CHECK: amdgpu.transpose_load %arg0[%arg1, %arg2] : memref<64x128xf8E5M2, #gpu.address_space> -> vector<8xf8E5M2> + %v = rock.lds_transpose_load %src[%i, %j] : memref<64x128xf8E5M2, #gpu.address_space> -> vector<8xf8E5M2> + return %v : vector<8xf8E5M2> + } } diff --git a/mlir/test/Dialect/Rock/lowering_xdlops_gemm.mlir b/mlir/test/Dialect/Rock/lowering_xdlops_gemm.mlir index cca87c21f4b8..1a3baeb845c0 100644 --- a/mlir/test/Dialect/Rock/lowering_xdlops_gemm.mlir +++ b/mlir/test/Dialect/Rock/lowering_xdlops_gemm.mlir @@ -649,3 +649,255 @@ func.func @accel_gemm_gfx950_f32_16x16x512_fp4_scaled_multi(%matrixA : memref<1x } : memref<1x1xvector<4xf32>, 5> += memref<1x4xvector<32xf4E2M1FN>, 5> scaled by memref<1x4xvector<32xf8E8M0FNU>, 5> * memref<1x4xvector<32xf4E2M1FN>, 5> scaled by memref<1x4xvector<32xf8E8M0FNU>, 5> return } + +func.func @accel_gemm_gfx950_scaled_fp8_bf8_32x32x64_single_v1(%matrixA : memref<1x2xvector<32xf8E4M3FN>, 5>, + %matrixB : memref<1x2xvector<32xf8E5M2>, 5>, + %matrixC : memref<1x1xvector<16xf32>, 5>) { + // CHECK-LABEL: func.func @accel_gemm_gfx950_scaled_fp8_bf8_32x32x64_single_v1 + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: amdgpu.scaled_mfma 32x32x64 + // CHECK-SAME: vector<32xf8E4M3FN>{{.*}}vector<32xf8E5M2>{{.*}}vector<16xf32> + // CHECK-NOT: amdgpu.scaled_mfma + %c0 = arith.constant 0 : index + rock.threadwise_gemm_accel %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma { + arch = "amdgcn-amd-amdhsa:gfx950", + params = #rock.accel_gemm_params< + kpackPerBlock = 2, + mPerBlock = 32, + nPerBlock = 32, + kpack = 32, + mPerWave = 32, + nPerWave = 32, + mnPerXdl = 32, + splitKFactor = 1, + scheduleVersion = 1, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, + forceUnroll = true> + } : memref<1x1xvector<16xf32>, 5> += memref<1x2xvector<32xf8E4M3FN>, 5> * memref<1x2xvector<32xf8E5M2>, 5> + return +} + +func.func 
@accel_gemm_gfx950_scaled_fp8_fp8_16x16x128_single_v1(%matrixA : memref<1x4xvector<32xf8E4M3FN>, 5>, + %matrixB : memref<1x4xvector<32xf8E4M3FN>, 5>, + %matrixC : memref<1x1xvector<4xf32>, 5>) { + // CHECK-LABEL: func.func @accel_gemm_gfx950_scaled_fp8_fp8_16x16x128_single_v1 + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: amdgpu.scaled_mfma 16x16x128 + // CHECK-SAME: vector<32xf8E4M3FN>{{.*}}vector<32xf8E4M3FN>{{.*}}vector<4xf32> + // CHECK-NOT: amdgpu.scaled_mfma + %c0 = arith.constant 0 : index + rock.threadwise_gemm_accel %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma { + arch = "amdgcn-amd-amdhsa:gfx950", + params = #rock.accel_gemm_params< + kpackPerBlock = 4, + mPerBlock = 16, + nPerBlock = 16, + kpack = 32, + mPerWave = 16, + nPerWave = 16, + mnPerXdl = 16, + splitKFactor = 1, + scheduleVersion = 1, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, + forceUnroll = true> + } : memref<1x1xvector<4xf32>, 5> += memref<1x4xvector<32xf8E4M3FN>, 5> * memref<1x4xvector<32xf8E4M3FN>, 5> + return +} + +func.func @accel_gemm_gfx950_scaled_bf8_bf8_16x16x128_single_v3_kpack1(%matrixA : memref<1x1xvector<32xf8E5M2>, 5>, + %matrixB : memref<1x1xvector<32xf8E5M2>, 5>, + %matrixC : memref<1x1xvector<4xf32>, 5>) { + // CHECK-LABEL: func.func @accel_gemm_gfx950_scaled_bf8_bf8_16x16x128_single_v3_kpack1 + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: amdgpu.scaled_mfma 16x16x128 + // CHECK-SAME: vector<32xf8E5M2>{{.*}}vector<32xf8E5M2>{{.*}}vector<4xf32> + // CHECK-NOT: amdgpu.scaled_mfma + %c0 = arith.constant 0 : index + rock.threadwise_gemm_accel %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma { + arch = "amdgcn-amd-amdhsa:gfx950", + params = #rock.accel_gemm_params< + kpackPerBlock = 128, + mPerBlock = 16, + nPerBlock = 16, + kpack = 1, + mPerWave = 16, + nPerWave = 16, + mnPerXdl = 16, + splitKFactor = 1, + scheduleVersion = 3, + outputSwizzle = 2, wavesPerEU = 0, 
gridGroupSize = 0, + forceUnroll = true> + } : memref<1x1xvector<4xf32>, 5> += memref<1x1xvector<32xf8E5M2>, 5> * memref<1x1xvector<32xf8E5M2>, 5> + return +} + +func.func @accel_gemm_gfx950_scaled_bf8_fp8_32x32x64_single_v3_kpack1(%matrixA : memref<1x1xvector<32xf8E5M2>, 5>, + %matrixB : memref<1x1xvector<32xf8E4M3FN>, 5>, + %matrixC : memref<1x1xvector<16xf32>, 5>) { + // CHECK-LABEL: func.func @accel_gemm_gfx950_scaled_bf8_fp8_32x32x64_single_v3_kpack1 + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: amdgpu.scaled_mfma 32x32x64 + // CHECK-SAME: vector<32xf8E5M2>{{.*}}vector<32xf8E4M3FN>{{.*}}vector<16xf32> + // CHECK-NOT: amdgpu.scaled_mfma + %c0 = arith.constant 0 : index + rock.threadwise_gemm_accel %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma { + arch = "amdgcn-amd-amdhsa:gfx950", + params = #rock.accel_gemm_params< + kpackPerBlock = 64, + mPerBlock = 32, + nPerBlock = 32, + kpack = 1, + mPerWave = 32, + nPerWave = 32, + mnPerXdl = 32, + splitKFactor = 1, + scheduleVersion = 3, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, + forceUnroll = true> + } : memref<1x1xvector<16xf32>, 5> += memref<1x1xvector<32xf8E5M2>, 5> * memref<1x1xvector<32xf8E4M3FN>, 5> + return +} + +func.func @accel_gemm_gfx950_scaled_fp8_bf8_32x32x64_double_v2(%matrixA : memref<1x2xvector<32xf8E4M3FN>, 5>, + %matrixB : memref<1x2xvector<32xf8E5M2>, 5>, + %matrixC : memref<1x1xvector<16xf32>, 5>) { + // CHECK-LABEL: func.func @accel_gemm_gfx950_scaled_fp8_bf8_32x32x64_double_v2 + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: amdgpu.scaled_mfma 32x32x64 + // CHECK-SAME: vector<32xf8E4M3FN>{{.*}}vector<32xf8E5M2>{{.*}}vector<16xf32> + // CHECK-NOT: amdgpu.scaled_mfma + %c0 = arith.constant 0 : index + rock.threadwise_gemm_accel %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma { + arch = "amdgcn-amd-amdhsa:gfx950", + params = #rock.accel_gemm_params< + kpackPerBlock = 2, + mPerBlock = 
32, + nPerBlock = 32, + kpack = 32, + mPerWave = 32, + nPerWave = 32, + mnPerXdl = 32, + splitKFactor = 1, + scheduleVersion = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, + forceUnroll = true> + } : memref<1x1xvector<16xf32>, 5> += memref<1x2xvector<32xf8E4M3FN>, 5> * memref<1x2xvector<32xf8E5M2>, 5> + return +} + +func.func @accel_gemm_gfx950_scaled_bf8_bf8_16x16x128_double_v4(%matrixA : memref<1x4xvector<32xf8E5M2>, 5>, + %matrixB : memref<1x4xvector<32xf8E5M2>, 5>, + %matrixC : memref<1x1xvector<4xf32>, 5>) { + // CHECK-LABEL: func.func @accel_gemm_gfx950_scaled_bf8_bf8_16x16x128_double_v4 + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: amdgpu.scaled_mfma 16x16x128 + // CHECK-SAME: vector<32xf8E5M2>{{.*}}vector<32xf8E5M2>{{.*}}vector<4xf32> + // CHECK-NOT: amdgpu.scaled_mfma + %c0 = arith.constant 0 : index + rock.threadwise_gemm_accel %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma { + arch = "amdgcn-amd-amdhsa:gfx950", + params = #rock.accel_gemm_params< + kpackPerBlock = 4, + mPerBlock = 16, + nPerBlock = 16, + kpack = 32, + mPerWave = 16, + nPerWave = 16, + mnPerXdl = 16, + splitKFactor = 1, + scheduleVersion = 4, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, + forceUnroll = true> + } : memref<1x1xvector<4xf32>, 5> += memref<1x4xvector<32xf8E5M2>, 5> * memref<1x4xvector<32xf8E5M2>, 5> + return +} + +func.func @accel_gemm_gfx950_scaled_bf8_bf8_32x32x64_double_v2_kpack8(%matrixA : memref<1x2xvector<32xf8E5M2>, 5>, + %matrixB : memref<1x2xvector<32xf8E5M2>, 5>, + %matrixC : memref<1x1xvector<16xf32>, 5>) { + // CHECK-LABEL: func.func @accel_gemm_gfx950_scaled_bf8_bf8_32x32x64_double_v2_kpack8 + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: amdgpu.scaled_mfma 32x32x64 + // CHECK-SAME: vector<32xf8E5M2>{{.*}}vector<32xf8E5M2>{{.*}}vector<16xf32> + // CHECK-NOT: amdgpu.scaled_mfma + %c0 = arith.constant 0 : index + rock.threadwise_gemm_accel %matrixC += 
%matrixA * %matrixB at [%c0, %c0, %c0] features = mfma { + arch = "amdgcn-amd-amdhsa:gfx950", + params = #rock.accel_gemm_params< + kpackPerBlock = 8, + mPerBlock = 32, + nPerBlock = 32, + kpack = 8, + mPerWave = 32, + nPerWave = 32, + mnPerXdl = 32, + splitKFactor = 1, + scheduleVersion = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, + forceUnroll = true> + } : memref<1x1xvector<16xf32>, 5> += memref<1x2xvector<32xf8E5M2>, 5> * memref<1x2xvector<32xf8E5M2>, 5> + return +} + +func.func @accel_gemm_gfx950_scaled_fp8_fp8_16x16x128_double_v4_kpack4(%matrixA : memref<1x4xvector<32xf8E4M3FN>, 5>, + %matrixB : memref<1x4xvector<32xf8E4M3FN>, 5>, + %matrixC : memref<1x1xvector<4xf32>, 5>) { + // CHECK-LABEL: func.func @accel_gemm_gfx950_scaled_fp8_fp8_16x16x128_double_v4_kpack4 + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: amdgpu.scaled_mfma 16x16x128 + // CHECK-SAME: vector<32xf8E4M3FN>{{.*}}vector<32xf8E4M3FN>{{.*}}vector<4xf32> + // CHECK-NOT: amdgpu.scaled_mfma + %c0 = arith.constant 0 : index + rock.threadwise_gemm_accel %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma { + arch = "amdgcn-amd-amdhsa:gfx950", + params = #rock.accel_gemm_params< + kpackPerBlock = 32, + mPerBlock = 16, + nPerBlock = 16, + kpack = 4, + mPerWave = 16, + nPerWave = 16, + mnPerXdl = 16, + splitKFactor = 1, + scheduleVersion = 4, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, + forceUnroll = true> + } : memref<1x1xvector<4xf32>, 5> += memref<1x4xvector<32xf8E4M3FN>, 5> * memref<1x4xvector<32xf8E4M3FN>, 5> + return +} + +func.func @accel_gemm_gfx950_scaled_bf8_bf8_16x16x128_double_v2_kpack1(%matrixA : memref<1x4xvector<32xf8E5M2>, 5>, + %matrixB : memref<1x4xvector<32xf8E5M2>, 5>, + %matrixC : memref<1x1xvector<4xf32>, 5>) { + // CHECK-LABEL: func.func @accel_gemm_gfx950_scaled_bf8_bf8_16x16x128_double_v2_kpack1 + // CHECK: rock.transforming_for + // CHECK-SAME: bounds [1, 1, 1] + // CHECK: amdgpu.scaled_mfma 16x16x128 + 
// CHECK-SAME: vector<32xf8E5M2>{{.*}}vector<32xf8E5M2>{{.*}}vector<4xf32> + // CHECK-NOT: amdgpu.scaled_mfma + %c0 = arith.constant 0 : index + rock.threadwise_gemm_accel %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma { + arch = "amdgcn-amd-amdhsa:gfx950", + params = #rock.accel_gemm_params< + kpackPerBlock = 128, + mPerBlock = 16, + nPerBlock = 16, + kpack = 1, + mPerWave = 16, + nPerWave = 16, + mnPerXdl = 16, + splitKFactor = 1, + scheduleVersion = 2, + outputSwizzle = 2, wavesPerEU = 0, gridGroupSize = 0, + forceUnroll = true> + } : memref<1x1xvector<4xf32>, 5> += memref<1x4xvector<32xf8E5M2>, 5> * memref<1x4xvector<32xf8E5M2>, 5> + return +} diff --git a/mlir/test/Dialect/Rock/ops.mlir b/mlir/test/Dialect/Rock/ops.mlir index fafe8fa665d6..60ffdd2d5a15 100644 --- a/mlir/test/Dialect/Rock/ops.mlir +++ b/mlir/test/Dialect/Rock/ops.mlir @@ -433,6 +433,27 @@ func.func @rock_lds_transpose_load_full_arch(%lds_buffer: memref<128x64xf16, #gp return } +// CHECK-LABEL: func.func @rock_lds_transpose_load_fp8_e4m3 +// CHECK: rock.lds_transpose_load +func.func @rock_lds_transpose_load_fp8_e4m3(%lds_buffer: memref<128x64xf8E4M3FN, #gpu.address_space>) + attributes {arch = "amdgcn-amd-amdhsa:gfx950"} { + %c0 = arith.constant 0 : index + %fragment = rock.lds_transpose_load %lds_buffer[%c0, %c0] + : memref<128x64xf8E4M3FN, #gpu.address_space> -> vector<8xf8E4M3FN> + return +} + +// CHECK-LABEL: func.func @rock_lds_transpose_load_fp8_e5m2 +// CHECK: rock.lds_transpose_load +func.func @rock_lds_transpose_load_fp8_e5m2(%lds_buffer: memref<256x128xf8E5M2, #gpu.address_space>) + attributes {arch = "amdgcn-amd-amdhsa:gfx950"} { + %c0 = arith.constant 0 : index + %c32 = arith.constant 32 : index + %fragment = rock.lds_transpose_load %lds_buffer[%c32, %c0] + : memref<256x128xf8E5M2, #gpu.address_space> -> vector<8xf8E5M2> + return +} + // CHECK-LABEL: func.func @test_lds_transpose_config_attr_16x32 // CHECK: ldsTransposeConfig = #rock.lds_transpose_config func.func 
@test_lds_transpose_config_attr_16x32(%src: memref<8192xf16, #gpu.address_space>, @@ -525,6 +546,52 @@ func.func @test_lds_transpose_config_attr_32x8(%src: memref<2048xf16, #gpu.addre return } +// CHECK-LABEL: func.func @test_lds_transpose_config_attr_16x128 +// CHECK: ldsTransposeConfig = #rock.lds_transpose_config +func.func @test_lds_transpose_config_attr_16x128(%src: memref<2048xf8E4M3FN, #gpu.address_space>, + %dest: memref<32xf8E4M3FN, #gpu.address_space>) + attributes {arch = "amdgcn-amd-amdhsa:gfx950"} { + rock.threadwise_read_into { + forceUnroll, + ldsTransposeConfig = #rock.lds_transpose_config< + dDim = 16, + kDim = 128, + mPerBlock = 16, + nPerBlock = 16, + kPerBlock = 128, + mPerWave = 16, + nPerWave = 16, + doubleBuffering = false, + isOperandA = true + >, + useIndexDiffs + } [](%src) [] -> %dest : memref<2048xf8E4M3FN, #gpu.address_space> -> memref<32xf8E4M3FN, #gpu.address_space> + return +} + +// CHECK-LABEL: func.func @test_lds_transpose_config_attr_32x64 +// CHECK: ldsTransposeConfig = #rock.lds_transpose_config +func.func @test_lds_transpose_config_attr_32x64(%src: memref<2048xf8E4M3FN, #gpu.address_space>, + %dest: memref<32xf8E4M3FN, #gpu.address_space>) + attributes {arch = "amdgcn-amd-amdhsa:gfx950"} { + rock.threadwise_read_into { + forceUnroll, + ldsTransposeConfig = #rock.lds_transpose_config< + dDim = 32, + kDim = 64, + mPerBlock = 32, + nPerBlock = 32, + kPerBlock = 64, + mPerWave = 32, + nPerWave = 32, + doubleBuffering = false, + isOperandA = false + >, + useIndexDiffs + } [](%src) [] -> %dest : memref<2048xf8E4M3FN, #gpu.address_space> -> memref<32xf8E4M3FN, #gpu.address_space> + return +} + // CHECK-LABEL: func.func @test_threadwise_read_into_without_lds_transpose // CHECK: rock.threadwise_read_into {forceUnroll, useIndexDiffs} // CHECK-NOT: ldsTransposeConfig diff --git a/mlir/test/e2e/CMakeLists.txt b/mlir/test/e2e/CMakeLists.txt index ca0d7c2473f3..1ffabef3538a 100644 --- a/mlir/test/e2e/CMakeLists.txt +++ 
b/mlir/test/e2e/CMakeLists.txt @@ -50,6 +50,7 @@ if (ROCMLIR_DRIVER_PR_E2E_TEST_ENABLED) PrConvElementwiseGemmBF16SplitK PrGemmDirectToLDS PrLdsTransposeLoad + PrLdsTransposeLoadFp8 PrLdsTransposeLoadAttention PrConvDirectToLDS PrAttentionDirectToLDS diff --git a/mlir/test/e2e/PrLdsTransposeLoadFp8.cfg b/mlir/test/e2e/PrLdsTransposeLoadFp8.cfg new file mode 100644 index 000000000000..46909aa10a02 --- /dev/null +++ b/mlir/test/e2e/PrLdsTransposeLoadFp8.cfg @@ -0,0 +1,2 @@ +if not 'lds_transpose_load' in config.features: + config.unsupported = True diff --git a/mlir/test/e2e/PrLdsTransposeLoadFp8.toml b/mlir/test/e2e/PrLdsTransposeLoadFp8.toml new file mode 100644 index 000000000000..c331af6e49fb --- /dev/null +++ b/mlir/test/e2e/PrLdsTransposeLoadFp8.toml @@ -0,0 +1,129 @@ +directory = "PrLdsTransposeLoadFp8" +prefix = "rocmlir-gen" +suffix = "--operation gemm --arch %arch %pv %constrained_float_range_random_data %rocmlir_gen_flags | rocmlir-driver -c | mlir-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix=" + +[[axis]] +name = "data type" +values = ["fp8_fp8", "bf8_bf8", "fp8_bf8", "bf8_fp8"] +prefix = "-t " + +# ============================================================================ +# Suite 1: BOTH A and B use LDS transpose (transA=true, transB=false) +# 5 tests covering different MFMA geometries and kpack/kPerBlock combinations +# Only fp8_fp8 and bf8_bf8 +# ============================================================================ +[[suite]] +name = "lds_transpose_both_operands_fp8" + +[[suite.test]] +config = "-g 1 -m 64 -k 128 -n 64 --transA=true --transB=false --perf_config v3:64,64,32,16,16,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + 
+[[suite.test]] +config = "-g 1 -m 128 -k 64 -n 64 --transA=true --transB=false --perf_config v3:128,64,16,32,16,16,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 256 -k 256 -n 64 --transA=true --transB=false --perf_config v3:256,64,8,16,16,16,1,4,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 128 --transA=true --transB=false --perf_config v3:128,128,16,32,32,4,1,4,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 256 -k 128 -n 128 --transA=true --transB=false --perf_config v3:256,128,32,64,32,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +# ============================================================================ +# Suite 2: ONLY A uses LDS transpose (transA=true, transB=true) +# 3 tests - B uses regular load +# Only fp8_fp8 and bf8_bf8 +# ============================================================================ +[[suite]] +name = "lds_transpose_A_only_fp8" + +[[suite.test]] +config = "-g 1 -m 64 -k 64 -n 64 --transA=true --transB=true --perf_config v3:64,64,16,16,16,8,1,3,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 128 -n 64 --transA=true --transB=true --perf_config v3:128,64,8,32,32,16,1,4,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +[[suite.test]] +config = "-g 1 -m 128 -k 256 -n 128 --transA=true --transB=true --perf_config v3:128,128,32,16,16,8,1,4,2,1,1" +[[suite.test.exclude]] +name = "data type" +values = ["fp8_bf8", "bf8_fp8"] + +# ============================================================================ +# Suite 3: ONLY B uses LDS transpose (transA=false, transB=false) +# 3 tests - A uses regular load +# Only fp8_fp8 and bf8_bf8 +# 
============================================================================
+[[suite]]
+name = "lds_transpose_B_only_fp8"
+
+[[suite.test]]
+config = "-g 1 -m 64 -k 64 -n 64 --transA=false --transB=false --perf_config v3:64,64,16,16,16,8,1,3,2,1,1"
+[[suite.test.exclude]]
+name = "data type"
+values = ["fp8_bf8", "bf8_fp8"]
+
+[[suite.test]]
+config = "-g 1 -m 128 -k 128 -n 64 --transA=false --transB=false --perf_config v3:128,64,32,32,16,1,1,3,2,1,1"
+[[suite.test.exclude]]
+name = "data type"
+values = ["fp8_bf8", "bf8_fp8"]
+
+[[suite.test]]
+config = "-g 1 -m 128 -k 256 -n 128 --transA=false --transB=false --perf_config v3:128,128,32,32,32,4,1,4,2,1,1"
+[[suite.test.exclude]]
+name = "data type"
+values = ["fp8_bf8", "bf8_fp8"]
+
+# ============================================================================
+# Suite 4: Mixed FP8/BF8 types (fp8_bf8 and bf8_fp8 only)
+# 4 tests: both operands, only A, and two B-only
+# ============================================================================
+[[suite]]
+name = "lds_transpose_mixed_fp8_bf8"
+
+[[suite.test]]
+config = "-g 1 -m 64 -k 128 -n 64 --transA=true --transB=false --perf_config v3:64,64,32,16,16,8,1,3,2,1,1"
+[[suite.test.exclude]]
+name = "data type"
+values = ["fp8_fp8", "bf8_bf8"]
+
+[[suite.test]]
+config = "-g 1 -m 128 -k 128 -n 128 --transA=true --transB=true --perf_config v3:128,128,16,32,32,8,1,3,2,1,1"
+[[suite.test.exclude]]
+name = "data type"
+values = ["fp8_fp8", "bf8_bf8"]
+
+[[suite.test]]
+config = "-g 1 -m 16 -k 128 -n 32 --transA=false --transB=false --perf_config v3:16,16,16,16,16,8,1,4,2,1,1"
+[[suite.test.exclude]]
+name = "data type"
+values = ["fp8_fp8", "bf8_bf8"]
+
+[[suite.test]]
+config = "-g 1 -m 64 -k 64 -n 64 --transA=false --transB=false --perf_config v3:32,32,32,32,32,4,1,4,2,1,1"
+[[suite.test.exclude]]
+name = "data type"
+values = ["fp8_fp8", "bf8_bf8"]