Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion mlir/include/mlir/Dialect/Rock/IR/MfmaInsnGroup.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@ enum class MfmaTypeId : uint32_t {
Fp8Fp8TyId,
Fp8Bf8TyId,
Bf8Fp8TyId,
Bf8Bf8TyId
Bf8Bf8TyId,
// FP8 via scaled MFMA (uses mfma_scale_f32_16x16x128_f8f6f4 with cbsz=0)
// These provide larger K dimension (128 for 16x16, 64 for 32x32)
Fp8Fp8ScaledTyId,
Fp8Bf8ScaledTyId,
Bf8Fp8ScaledTyId,
Bf8Bf8ScaledTyId
};

struct MfmaInsnInfo {
Expand Down Expand Up @@ -155,6 +161,10 @@ class MfmaInsnGroup {
bool isCoherentWithK(int64_t kPack, int64_t kPerBlock,
int64_t scheduleVersion);
SmallString<16> getROCDLIntrinsicName() { return groupAttr.insn; }

// Check if this is FP8 using scaled MFMA (mfma_scale with cbsz=0, blgp=0)
// These instructions have larger K dimension (128 for 16x16, 64 for 32x32)
bool isScaledFp8() const;
};

} // namespace rock
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Rock/IR/RockAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ def Rock_LDSTransposeConfigAttr : Rock_Attr<"LDSTransposeConfig", []> {
and tiling parameters.

- DDim: Matrix-multiply accelerator instruction D dimension (M or N, typically 16 or 32)
- KDim: Matrix-multiply accelerator instruction K dimension (typically 8, 16, or 32)
- KDim: Matrix-multiply accelerator instruction K dimension (8, 16, 32, 64, or 128)
- mPerBlock: M dimension size per block
- nPerBlock: N dimension size per block
- kPerBlock: K dimension size per block
Expand Down
17 changes: 12 additions & 5 deletions mlir/include/mlir/Dialect/Rock/IR/RockOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1229,9 +1229,10 @@ defvar SameShapeVectorOfI1 = [{
def Rock_LDSTransposeLoadOp
: Rock_Op<"lds_transpose_load", [DeclareOpInterfaceMethods<
MemoryEffectsOpInterface>]>,
Arguments<(ins Arg<MemRefOf<[F16, BF16]>, "LDS source buffer">:$source,
Arguments<(ins Arg<MemRefOf<[F16, BF16, F8E4M3FN, F8E5M2]>,
"LDS source buffer">:$source,
Variadic<Index>:$indices)>,
Results<(outs VectorOfLengthAndType<[4], [F16, BF16]>:$result)> {
Results<(outs AnyVectorOfNonZeroRank:$result)> {
let summary =
"Hardware-assisted LDS transpose load for matrix accelerator tile";
let description = [{
Expand All @@ -1240,9 +1241,15 @@ def Rock_LDSTransposeLoadOp
The tile dimensions match the selected matrix-multiply accelerator
instruction geometry (`dDim × instrK`), where:
- `dDim` is the accelerator M/N dimension (e.g., 16 or 32)
- `instrK` is the accelerator K dimension (e.g., 8, 16, or 32)
The operation returns a vector of 4 elements per thread containing
transposed elements in a layout suitable for matrix accelerator instructions.
- `instrK` is the accelerator K dimension (8, 16, 32, 64, or 128)

For 16-bit types (f16, bf16):
- Uses ds_read_tr16_b64 instruction
- Returns vector<4xtype> (4 elements per thread)

For 8-bit types (f8E4M3FN, f8E5M2 - OCP FP8 for gfx950):
- Uses ds_read_tr8_b64 instruction
- Returns vector<8xtype> (8 elements per thread)

Benefits:
- Reduces LDS bank conflicts through optimized access patterns
Expand Down
17 changes: 13 additions & 4 deletions mlir/include/mlir/Dialect/Rock/utility/LdsTransposeLoad.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,17 @@
// to the LDS transpose load operation in an accelerator-friendly layout.
//
// It is intended to simplify the IR generation logic and ensure
// consistent handling of f16/bf16 matrix accelerator tile loads from LDS
// memory.
// consistent handling of f16/bf16/fp8/bf8 matrix accelerator tile loads from
// LDS memory.
//
// Supported element types:
// - f16, bf16: uses ds_read_tr16_b64 (returns 4 elements per thread)
// - f8E4M3FN, f8E5M2 (OCP FP8): uses ds_read_tr8_b64 (returns 8 elements)
//
// Supported MFMA geometries:
// - Standard: (16,16), (16,32), (32,8), (32,16) - single-rate or double-rate
// - Scaled FP8: (16,128) - quad-rate (4 ds_read_tr8 calls per K tile)
// - Scaled FP8: (32,64) - quad-rate (4 ds_read_tr8 calls per K tile)
//
//===----------------------------------------------------------------------===//

Expand All @@ -43,7 +52,7 @@ enum class OperandKind { A, B };
// Build LDS transpose config attribute from already-computed MFMA params.
// Used in BlockwiseLoadTileToThreadwise when decision was made upstream.
// Requires mfmaDDim > 0 and mfmaKDim > 0 (asserted).
// Valid combinations: (16,16), (16,32), (32,8), (32,16)
// Valid combinations: (16,16), (16,32), (16,128), (32,8), (32,16), (32,64)
LDSTransposeConfigAttr buildTransposeAttrFromParams(
PatternRewriter &rewriter, int64_t mfmaDDim, int64_t mfmaKDim,
int64_t mPerBlock, int64_t nPerBlock, int64_t kPerBlock, int64_t mPerWave,
Expand All @@ -60,7 +69,7 @@ struct LDSTransposeDecision {
bool enableA{false}; // Enable for operand A
bool enableB{false}; // Enable for operand B
int64_t mfmaDDim{0}; // MFMA D dimension (M or N, 16 or 32)
int64_t mfmaKDim{0}; // MFMA K dimension (8, 16, or 32)
int64_t mfmaKDim{0}; // MFMA K dimension (8, 16, 32, 64, or 128)
};

// Decides whether to enable LDS transpose for operands A and B
Expand Down
111 changes: 110 additions & 1 deletion mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,13 @@ static auto getMfmaInsnInfoMap = []() -> const llvm::StringMap<MfmaInsnInfo> & {
{ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName(),
{MfmaTypeId::Bf8Bf8TyId, 16, 32, 1}},

// fp4
// Scaled MFMA instructions (FP4 and scaled FP8 types)
// Note: FP8 scaled types (Fp8Fp8ScaledTyId, Fp8Bf8ScaledTyId, etc.)
// use the same underlying instruction with identical (mfmaDDim, k,
// blocksMfma).
// Since deriveAttr only uses those fields (not MfmaTypeId), we only need
// one entry per instruction. The type differentiation happens elsewhere
// via cbsz/blgp parameters at code generation time.
{ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName(),
{MfmaTypeId::Fp4TyId, 16, 128, 1}},
{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName(),
Expand Down Expand Up @@ -448,6 +454,25 @@ static auto getMfmaInsnGroupAttrMapGfx950 = []() {
{{MfmaTypeId::Fp4TyId, 32, 32},
{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName()}},

// FP8 via scaled MFMA (cbsz=0, blgp=0 mode)
// 16x16 with K=128, 32x32 with K=64
{{MfmaTypeId::Fp8Fp8ScaledTyId, 16, 16},
{ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName()}},
{{MfmaTypeId::Fp8Fp8ScaledTyId, 32, 32},
{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName()}},
{{MfmaTypeId::Fp8Bf8ScaledTyId, 16, 16},
{ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName()}},
{{MfmaTypeId::Fp8Bf8ScaledTyId, 32, 32},
{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName()}},
{{MfmaTypeId::Bf8Fp8ScaledTyId, 16, 16},
{ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName()}},
{{MfmaTypeId::Bf8Fp8ScaledTyId, 32, 32},
{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName()}},
{{MfmaTypeId::Bf8Bf8ScaledTyId, 16, 16},
{ROCDL::mfma_scale_f32_16x16x128_f8f6f4::getOperationName()}},
{{MfmaTypeId::Bf8Bf8ScaledTyId, 32, 32},
{ROCDL::mfma_scale_f32_32x32x64_f8f6f4::getOperationName()}},

// i8 double rate
{{MfmaTypeId::I8TyId, 32, 32},
{ROCDL::mfma_i32_32x32x32_i8::getOperationName()}},
Expand Down Expand Up @@ -583,6 +608,28 @@ static MfmaTypeId convertTypesToId(Type dataTypeA, Type dataTypeB) {
llvm_unreachable("Unsupported input argument type.");
}

// Map a native FP8 MfmaTypeId to its scaled-MFMA counterpart used by the
// gfx950 mfma_scale_* instructions. Returns std::nullopt for any id that is
// not one of the four native FP8 combinations.
static std::optional<MfmaTypeId> getScaledFp8TypeId(MfmaTypeId nativeTypeId) {
  if (nativeTypeId == MfmaTypeId::Fp8Fp8TyId)
    return MfmaTypeId::Fp8Fp8ScaledTyId;
  if (nativeTypeId == MfmaTypeId::Fp8Bf8TyId)
    return MfmaTypeId::Fp8Bf8ScaledTyId;
  if (nativeTypeId == MfmaTypeId::Bf8Fp8TyId)
    return MfmaTypeId::Bf8Fp8ScaledTyId;
  if (nativeTypeId == MfmaTypeId::Bf8Bf8TyId)
    return MfmaTypeId::Bf8Bf8ScaledTyId;
  return std::nullopt;
}

// Report whether `typeId` is one of the four native (non-scaled) FP8
// operand combinations.
static bool isNativeFp8TypeId(MfmaTypeId typeId) {
  switch (typeId) {
  case MfmaTypeId::Fp8Fp8TyId:
  case MfmaTypeId::Fp8Bf8TyId:
  case MfmaTypeId::Bf8Fp8TyId:
  case MfmaTypeId::Bf8Bf8TyId:
    return true;
  default:
    return false;
  }
}

FailureOr<MfmaInsnGroup>
MfmaInsnGroup::select(Type elementTypeA, Type elementTypeB, StringRef arch,
int64_t mnPerXdl, int64_t kPack, int64_t kPackPerBlock,
Expand Down Expand Up @@ -624,6 +671,44 @@ MfmaInsnGroup::select(Type elementTypeA, Type elementTypeB, StringRef arch,
};

auto selectForGfx950 = [&]() {
int64_t kPerBlock = kPack * kPackPerBlock;

// For FP8 types, try scaled MFMA first if kPerBlock is large enough
// Scaled MFMA has K=128 for 16x16 and K=64 for 32x32
if (isNativeFp8TypeId(key.type)) {
int64_t scaledK = (mPerMfmaGroup == 16) ? 128 : 64;
if (kPerBlock >= scaledK) {
LLVM_DEBUG(llvm::dbgs()
<< ">>> Trying scaled FP8: kPerBlock=" << kPerBlock
<< " >= scaledK=" << scaledK << "\n");
auto scaledTypeId = getScaledFp8TypeId(key.type);
if (scaledTypeId) {
MfmaInsnGroupSelectKey scaledKey = {*scaledTypeId, mPerMfmaGroup,
nPerMfmaGroup};
const auto &gfx950Map = getMfmaInsnGroupAttrMapGfx950();
auto it = gfx950Map.find(scaledKey);
if (it != gfx950Map.end()) {
MfmaInsnGroupAttr groupAttr = (*it).second;
auto maybeInsn = MfmaInsn::select(groupAttr.insn);
if (succeeded(maybeInsn)) {
auto scaledResult = MfmaInsnGroup(elementTypeA, elementTypeB,
*maybeInsn, groupAttr);
if (scaledResult.isCoherentWithK(kPack, kPackPerBlock,
scheduleVersion)) {
LLVM_DEBUG(llvm::dbgs() << ">>> SELECTED SCALED FP8 MFMA: K="
<< maybeInsn->getAttr().k << "\n");
result = scaledResult;
return;
}
}
}
}
LLVM_DEBUG(
llvm::dbgs()
<< ">>> Scaled FP8 MFMA not suitable, falling back to native\n");
}
}

// gfx950 has double rate instructions. Select from those first.
selectFrom(getMfmaInsnGroupAttrMapGfx950());
if (succeeded(result)) {
Expand Down Expand Up @@ -714,3 +799,27 @@ bool MfmaInsnGroup::isCoherentWithK(int64_t kpack, int64_t kPerBlock,
int64_t scheduleVersion) {
return insn.isCoherentWithK(kpack, kPerBlock, scheduleVersion);
}

bool MfmaInsnGroup::isScaledFp8() const {
// Check if the instruction is a scaled MFMA
// (rocdl.mfma.scale.f32.*x*x*.f8f6f4)
StringRef insnName = groupAttr.insn;
bool isScaledInsn = insnName.contains("mfma.scale.f32.16x16x128.f8f6f4") ||
insnName.contains("mfma.scale.f32.32x32x64.f8f6f4");
if (!isScaledInsn)
return false;

// Check if the element type is FP8 (not FP4)
// FP8 types: Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ
// FP4 types: Float4E2M1FN
bool isFp8A = isa<Float8E4M3FNType>(elementTypeA) ||
isa<Float8E4M3FNUZType>(elementTypeA) ||
isa<Float8E5M2Type>(elementTypeA) ||
isa<Float8E5M2FNUZType>(elementTypeA);
bool isFp8B = isa<Float8E4M3FNType>(elementTypeB) ||
isa<Float8E4M3FNUZType>(elementTypeB) ||
isa<Float8E5M2Type>(elementTypeB) ||
isa<Float8E5M2FNUZType>(elementTypeB);

return isFp8A && isFp8B;
}
35 changes: 30 additions & 5 deletions mlir/lib/Dialect/Rock/IR/RockDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -560,8 +560,11 @@ LogicalResult TransformMapAttr::verify(

// Returns true when (dDim, kDim) is an MFMA geometry supported by the LDS
// transpose load path. Accepted pairs:
//   Standard:   (16,16), (16,32), (32,8), (32,16)
//   Scaled FP8: (16,128), (32,64)
static bool isValidLdsTransposeMfmaGeometry(int64_t dDim, int64_t kDim) {
  switch (dDim) {
  case 16:
    return kDim == 16 || kDim == 32 || kDim == 128;
  case 32:
    return kDim == 8 || kDim == 16 || kDim == 64;
  default:
    return false;
  }
}

LogicalResult LDSTransposeConfigAttr::verify(
Expand All @@ -571,9 +574,10 @@ LogicalResult LDSTransposeConfigAttr::verify(

// Validate MFMA geometry
if (!isValidLdsTransposeMfmaGeometry(dDim, kDim)) {
return emitError() << "invalid MFMA geometry (" << dDim << "x" << kDim
<< ") for LDS transpose - valid combinations: "
"(16,16), (16,32), (32,8), (32,16)";
return emitError()
<< "invalid MFMA geometry (" << dDim << "x" << kDim
<< ") for LDS transpose - valid combinations: "
"(16,16), (16,32), (16,128), (32,8), (32,16), (32,64)";
}

// Validate positive dimensions
Expand Down Expand Up @@ -2161,6 +2165,27 @@ LogicalResult LDSTransposeLoadOp::verify() {
<< srcElemType << ")";
}

// Verify result vector length based on element type:
// - 16-bit types (f16, bf16): ds_read_tr16_b64 returns 4 elements
// - 8-bit types (f8E4M3FN, f8E5M2 - OCP FP8 for gfx950): ds_read_tr8_b64
// returns 8 elements
int64_t expectedVecLen;
if (srcElemType.isF16() || srcElemType.isBF16()) {
expectedVecLen = 4;
} else if (isa<Float8E4M3FNType>(srcElemType) ||
isa<Float8E5M2Type>(srcElemType)) {
expectedVecLen = 8;
} else {
return emitOpError("unsupported element type for LDS transpose load: ")
<< srcElemType;
}

if (resultType.getNumElements() != expectedVecLen) {
return emitOpError("expected result vector of ")
<< expectedVecLen << " elements for " << srcElemType
<< " type, but got " << resultType.getNumElements();
}

// Check hardware support using AmdArchDb
StringRef arch = rock::getArchValue(*this);
AmdArchInfo archInfo = rock::lookupArchInfo(arch);
Expand Down
26 changes: 26 additions & 0 deletions mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,22 @@ void MfmaEmitter::emitThreadwiseLoop(OpBuilder &b, Location loc, Value argA,
VectorType vectorType = mfmaGroup.getRetType();
auto outputOffset = llvm::to_vector(regCOffset);
bool isScaled = scaleA && scaleB;
bool isScaledFp8 = mfmaGroup.isScaledFp8();

// For scaled FP8 MFMA without explicit scale buffers, create neutral scales.
// In cbsz=0, blgp=0 mode, a scale exponent value of 0 means no scaling
// because 2^0 = 1. For Float8E8M0FNU (an exponent-only format), the call
// getFloatAttr(scaleType, 0.0) is used to produce the encoding with
// exponent = 0 (all-zero bit pattern), which corresponds to a scale of 1.
Value neutralScaleA, neutralScaleB;
if (isScaledFp8 && !isScaled) {
Type scaleType = b.getType<Float8E8M0FNUType>();
auto neutralScaleAttr = b.getFloatAttr(scaleType, 0.0);
neutralScaleA =
arith::ConstantOp::create(b, loc, scaleType, neutralScaleAttr);
neutralScaleB =
arith::ConstantOp::create(b, loc, scaleType, neutralScaleAttr);
}

for (int64_t i = 0; i < nResultVectors; ++i) {
Value offset = b.createOrFold<arith::ConstantIndexOp>(loc, i);
Expand All @@ -203,11 +219,21 @@ void MfmaEmitter::emitThreadwiseLoop(OpBuilder &b, Location loc, Value argA,

Value vectorD;
if (isScaled) {
// Explicit scale buffers provided (FP4 or scaled FP8 with explicit scales)
auto mfma = amdgpu::ScaledMFMAOp::create(
b, loc, vectorType, mfmaDDim, mfmaDDim, mfmaAttr.k, argA, argB,
vectorC, scaleA, scaleB, /*scalesIdxA=*/0, /*scalesIdxB=*/0);
vectorD = mfma.getDestD();
} else if (isScaledFp8) {
// Scaled FP8 MFMA (K=128 for 16x16, K=64 for 32x32) without explicit scales
// Use neutral scale values (0) which means 2^0 = 1 (no scaling)
auto mfma = amdgpu::ScaledMFMAOp::create(
b, loc, vectorType, mfmaDDim, mfmaDDim, mfmaAttr.k, argA, argB,
vectorC, neutralScaleA, neutralScaleB,
/*scalesIdxA=*/0, /*scalesIdxB=*/0);
vectorD = mfma.getDestD();
} else {
// Regular MFMA
auto mfma = amdgpu::MFMAOp::create(
b, loc, vectorType, mfmaDDim, mfmaDDim, mfmaAttr.k,
mfmaAttr.blocksMfma, argA, argB, vectorC, /*cbsz=*/imms[i].cbsz,
Expand Down
Loading