diff --git a/include/flydsl-c/FlyROCDLDialect.h b/include/flydsl-c/FlyROCDLDialect.h index 0e2777ed..f4640d25 100644 --- a/include/flydsl-c/FlyROCDLDialect.h +++ b/include/flydsl-c/FlyROCDLDialect.h @@ -30,6 +30,26 @@ MLIR_CAPI_EXPORTED MlirType mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetElemTyA(MlirType MLIR_CAPI_EXPORTED MlirType mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetElemTyB(MlirType type); MLIR_CAPI_EXPORTED MlirType mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetElemTyAcc(MlirType type); +//===----------------------------------------------------------------------===// +// MmaAtomGFX1250_WMMAType +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAFlyROCDLMmaAtomGFX1250_WMMAType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetTypeID(void); + +// Constructor +MLIR_CAPI_EXPORTED MlirType mlirFlyROCDLMmaAtomGFX1250_WMMATypeGet(int32_t m, int32_t n, int32_t k, + MlirType elemTyA, MlirType elemTyB, + MlirType elemTyAcc); + +// Accessors +MLIR_CAPI_EXPORTED int32_t mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetM(MlirType type); +MLIR_CAPI_EXPORTED int32_t mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetN(MlirType type); +MLIR_CAPI_EXPORTED int32_t mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetK(MlirType type); +MLIR_CAPI_EXPORTED MlirType mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetElemTyA(MlirType type); +MLIR_CAPI_EXPORTED MlirType mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetElemTyB(MlirType type); +MLIR_CAPI_EXPORTED MlirType mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetElemTyAcc(MlirType type); + //===----------------------------------------------------------------------===// // CopyOpCDNA3BufferLDSTType //===----------------------------------------------------------------------===// diff --git a/include/flydsl/Dialect/Fly/IR/FlyDialect.td b/include/flydsl/Dialect/Fly/IR/FlyDialect.td index 73dd85cd..48ac923a 100644 --- a/include/flydsl/Dialect/Fly/IR/FlyDialect.td +++ b/include/flydsl/Dialect/Fly/IR/FlyDialect.td @@ -16,7 
+16,6 @@ def Fly_Dialect : Dialect { let useDefaultTypePrinterParser = 1; let useDefaultAttributePrinterParser = 1; - let usePropertiesForAttributes = 1; } class Fly_Type<string name, list<Trait> traits = []> diff --git a/include/flydsl/Dialect/FlyROCDL/IR/Dialect.td b/include/flydsl/Dialect/FlyROCDL/IR/Dialect.td index bae2c63b..60c2ba26 100644 --- a/include/flydsl/Dialect/FlyROCDL/IR/Dialect.td +++ b/include/flydsl/Dialect/FlyROCDL/IR/Dialect.td @@ -16,7 +16,6 @@ def FlyROCDL_Dialect : Dialect { ]; let useDefaultTypePrinterParser = 1; - let usePropertiesForAttributes = 1; } class FlyxROCL_MmaAtom<string name, string mnemonic, list<Trait> traits = []> diff --git a/include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td b/include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td index e3275f06..5c73a089 100644 --- a/include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td +++ b/include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td @@ -30,4 +30,29 @@ def FlyROCDL_MmaAtomCDNA3_MFMA : FlyxROCL_MmaAtom<"MmaAtomCDNA3_MFMA", "atom.cdn // MmaAtom CDNA4 //===----------------------------------------------------------------------===// + + +//===----------------------------------------------------------------------===// +// MmaAtom GFX1250 — WMMA wave32 +//===----------------------------------------------------------------------===// + +def FlyROCDL_MmaAtomGFX1250_WMMA : FlyxROCL_MmaAtom<"MmaAtomGFX1250_WMMA", "atom.gfx1250.wmma", []> { + let parameters = (ins + "int32_t":$m, + "int32_t":$n, + "int32_t":$k, + "Type":$elemTyA, + "Type":$elemTyB, + "Type":$elemTyAcc + ); + let assemblyFormat = "`<` custom<MNK>($m, $n, $k) `,` `(` $elemTyA `,` $elemTyB `)` `->` $elemTyAcc `>`"; + + let builders = [ + TypeBuilderWithInferredContext<(ins "int32_t":$m, "int32_t":$n, "int32_t":$k, "Type":$elemTyA, "Type":$elemTyB, "Type":$elemTyAcc), [{ + return $_get(elemTyA.getContext(), m, n, k, elemTyA, elemTyB, elemTyAcc); + }]> + ]; + let genVerifyDecl = 1; +} + #endif // FLYROCDL_MMAATOM diff --git a/kernels/mxfp4_gemm_gfx1250.py b/kernels/mxfp4_gemm_gfx1250.py new file mode 100644 index
00000000..8ac7836f --- /dev/null +++ b/kernels/mxfp4_gemm_gfx1250.py @@ -0,0 +1,784 @@ +"""MXFP4 GEMM kernel for gfx1250. + +Uses V_WMMA_SCALE_F32_16X16X128_F8F6F4 with FP4 (E2M1) data and E8M0 block scales. +Supports N-stage buffering (2/3/4), TDM async copy, cluster MCAST. +""" + +import flydsl.compiler as flyc +import flydsl.expr as fx + +from flydsl._mlir import ir +from flydsl.compiler.kernel_function import CompilationContext +from flydsl.expr import arith, buffer_ops, gpu, range_constexpr, rocdl, tdm_ops, vector +from flydsl._mlir.dialects import llvm as llvm_dialect, memref as memref_dialect +from flydsl.expr.arith import _to_raw as _raw +from flydsl.expr.typing import T +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr, get_op_result_or_value + +from kernels.layout_utils import idx2crd +from kernels.pipeline_utils import make_tail_plan + +# WMMA tile dimensions for MXFP4 +WMMA_M, WMMA_N, WMMA_K = 16, 16, 128 +WAVE_SIZE = 32 +PACK_FACTOR = 2 # 2 FP4 elements per byte +SCALE_BLOCK = 32 # 32 FP4 elements per E8M0 scale +SCALES_PER_WMMA = WMMA_K // SCALE_BLOCK # 4 + +# LDS padding in bytes (4 DWORDs = 16 bytes, matches SP3) +LDS_PAD_A_BYTES = 16 +LDS_PAD_B_BYTES = 16 + +_STAGE_NAMES = ("ping", "pong", "pang", "pung") + + +def compile_mxfp4_gemm( + *, + M: int = 0, + N: int = 0, + K: int, + tile_m: int = 128, + tile_n: int = 128, + tile_k: int = 128, + m_warp: int = 2, + n_warp: int = 2, + num_buffers: int = 2, + waves_per_eu: int = None, + l2_prefetch_distance: int = 2, + cluster_m: int = 1, + cluster_n: int = 1, + scale_preshuffle: bool = True, +): + """Compile an MXFP4 GEMM kernel with TDM async copy and multi-stage buffering. 
+ + Returns a JitFunction: launch_fn(arg_c, arg_a, arg_b, arg_a_scale, arg_b_scale, M, N, stream) + """ + _ = (M, N) + if num_buffers not in (2, 3, 4): + raise ValueError(f"num_buffers must be 2, 3, or 4, got {num_buffers}") + + use_cluster = cluster_m > 1 or cluster_n > 1 + if use_cluster: + if cluster_m * cluster_n > 16: + raise ValueError( + f"cluster_m * cluster_n must be <= 16, got {cluster_m}*{cluster_n}") + effective_waves_per_eu = waves_per_eu + if use_cluster and effective_waves_per_eu is None: + effective_waves_per_eu = 1 + + num_warps = m_warp * n_warp + block_threads = num_warps * WAVE_SIZE + + packed_tile_k = tile_k // PACK_FACTOR # bytes along K in LDS per row + scale_k_per_tile = tile_k // SCALE_BLOCK + K_packed = K // PACK_FACTOR + K_scale = K // SCALE_BLOCK + + if K % tile_k != 0: + raise ValueError(f"K must be divisible by tile_k={tile_k}, got K={K}") + if tile_k % WMMA_K != 0: + raise ValueError(f"tile_k must be a multiple of {WMMA_K}, got {tile_k}") + if tile_m % WMMA_M != 0: + raise ValueError(f"tile_m must be a multiple of {WMMA_M}, got {tile_m}") + if tile_n % WMMA_N != 0: + raise ValueError(f"tile_n must be a multiple of {WMMA_N}, got {tile_n}") + if packed_tile_k % 4 != 0: + raise ValueError(f"packed_tile_k must be a multiple of 4, got {packed_tile_k}") + if scale_k_per_tile % 4 != 0: + raise ValueError( + f"scale_k_per_tile must be a multiple of 4 (tile_k >= 128), got {scale_k_per_tile}") + + warp_tile_m = tile_m // m_warp + warp_tile_n = tile_n // n_warp + if warp_tile_m % WMMA_M != 0: + raise ValueError(f"warp_tile_m={warp_tile_m} must be a multiple of {WMMA_M}") + if warp_tile_n % WMMA_N != 0: + raise ValueError(f"warp_tile_n={warp_tile_n} must be a multiple of {WMMA_N}") + + num_k_tiles = K // tile_k + if num_k_tiles < num_buffers: + raise ValueError( + f"{num_buffers}-stage buffering requires num_k_tiles >= {num_buffers}, " + f"got {num_k_tiles}") + + gpu_arch = str(get_hip_arch(timeout_s=300)) + assert gpu_arch.startswith("gfx1250"), 
f"Expected gfx1250, got {gpu_arch}" + + k_wmma_steps = tile_k // WMMA_K + wmma_m_rep = warp_tile_m // WMMA_M + wmma_n_rep = warp_tile_n // WMMA_N + n_accs = wmma_m_rep * wmma_n_rep + + lds_a_stride_bytes = packed_tile_k + LDS_PAD_A_BYTES + lds_b_stride_bytes = packed_tile_k + LDS_PAD_B_BYTES + + lds_a_data_bytes = tile_m * lds_a_stride_bytes + lds_b_data_bytes = tile_n * lds_b_stride_bytes + lds_a_scale_bytes = tile_m * scale_k_per_tile + lds_b_scale_bytes = tile_n * scale_k_per_tile + # Interleaved scale layout: [WMMA_M * m_warp, wmma_m_rep * scale_k_per_tile] + interleaved_scale_cols_a = wmma_m_rep * scale_k_per_tile + interleaved_scale_cols_b = wmma_n_rep * scale_k_per_tile + + stage_allocators = [] + stage_a_data_off = [] + stage_b_data_off = [] + stage_a_scale_off = [] + stage_b_scale_off = [] + + for i in range(num_buffers): + name = _STAGE_NAMES[i] + alloc = SmemAllocator(None, arch=gpu_arch, global_sym_name=f"mxfp4_{name}") + + off = alloc._align(alloc.ptr, 16) + stage_a_data_off.append(off) + alloc.ptr = off + lds_a_data_bytes + + off = alloc._align(alloc.ptr, 16) + stage_b_data_off.append(off) + alloc.ptr = off + lds_b_data_bytes + + off = alloc._align(alloc.ptr, 16) + stage_a_scale_off.append(off) + alloc.ptr = off + lds_a_scale_bytes + + off = alloc._align(alloc.ptr, 16) + stage_b_scale_off.append(off) + alloc.ptr = off + lds_b_scale_bytes + + stage_allocators.append(alloc) + + pre_loaded = num_buffers - 1 + loop_iters = (num_k_tiles - pre_loaded) // num_buffers + _tail_start = loop_iters * num_buffers + extra = num_k_tiles - _tail_start - pre_loaded + _raw_tail_plan = make_tail_plan(num_buffers, pre_loaded, extra) + + # Number of TDM loads per step: A_data + B_data + A_scale + B_scale = 4 + TDM_LOADS_PER_STEP = 4 + + # Scale tail plan outstanding values: make_tail_plan uses 2 (for fp16's A+B), + # but MXFP4 has 4 loads per step (A_data + B_data + A_scale + B_scale). 
+ tail_plan = [ + (ls, cs, o * TDM_LOADS_PER_STEP // 2 if o > 0 else o) + for ls, cs, o in _raw_tail_plan + ] + + # Number of LDS loads per K-subtile (for s_wait_dscnt): + # A frag: wmma_m_rep * 2 ds_load_b128 + # B frag: wmma_n_rep * 2 ds_load_b128 + # A scale: 1 ds_load_b128 (interleave) or wmma_m_rep ds_load_b32 + # B scale: 1 ds_load_b128 (interleave) or wmma_n_rep ds_load_b32 + if scale_preshuffle: + a_scale_b128_loads = (wmma_m_rep + 3) // 4 + b_scale_b128_loads = (wmma_n_rep + 3) // 4 + LOADS_PER_SUBTILE = wmma_m_rep * 2 + wmma_n_rep * 2 + a_scale_b128_loads + b_scale_b128_loads + else: + LOADS_PER_SUBTILE = wmma_m_rep * 2 + wmma_n_rep * 2 + wmma_m_rep + wmma_n_rep + + @flyc.kernel + def kernel_mxfp4_gemm( + arg_c: fx.Tensor, + arg_a: fx.Tensor, + arg_b: fx.Tensor, + arg_a_scale: fx.Tensor, + arg_b_scale: fx.Tensor, + i32_m: fx.Int32, + i32_n: fx.Int32, + ): + # Disable VALU stall for back-to-back WMMA + llvm_dialect.inline_asm( + None, [], + "s_setreg_imm32_b32 hwreg(26, 4, 1), 1", + "", has_side_effects=True, + ) + + tx = gpu.thread_id("x") + bx = gpu.block_id("x") + by = gpu.block_id("y") + + blk_m = bx * arith.index(tile_m) + blk_n = by * arith.index(tile_n) + + if use_cluster: + local_x, local_y = gpu.compute_cluster_position() + a_mcast_mask, b_mcast_mask = gpu.compute_mcast_masks( + local_x, local_y, cluster_m, cluster_n) + else: + a_mcast_mask = 0 + b_mcast_mask = 0 + + layout_thr = fx.make_layout( + (m_warp, n_warp, 2, 16), + (n_warp * WAVE_SIZE, WAVE_SIZE, 16, 1)) + thr_coord = idx2crd(tx, layout_thr) + wave_m_idx, wave_n_idx, lane_kgrp, lane16 = ( + thr_coord[0], thr_coord[1], thr_coord[2], thr_coord[3]) + + warp_m_base = wave_m_idx * arith.index(warp_tile_m) + warp_n_base = wave_n_idx * arith.index(warp_tile_n) + + m_idx = arith.index_cast(T.index, i32_m.ir_value()) + n_idx = arith.index_cast(T.index, i32_n.ir_value()) + n_stride = arith.index(N) + c_nrec = m_idx * n_stride * arith.index(4) + c_rsrc = buffer_ops.create_buffer_resource(arg_c, 
num_records_bytes=c_nrec) + + def _get_lds_memref(lds_ptr): + if isinstance(lds_ptr, SmemPtr): + return get_op_result_or_value(lds_ptr.get()) + return get_op_result_or_value(lds_ptr) + + def copy_a_data_to_lds(k_base, lds_mem_ref): + k_packed_off = k_base / arith.index(PACK_FACTOR) + desc = tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_a, lds_memref=lds_mem_ref, + global_offset=(blk_m, k_packed_off), + tensor_shape=(tile_m, packed_tile_k), + strides=(K_packed, 1), + tile_shape=(tile_m, packed_tile_k), + elem_bytes=1, + pad_interval=packed_tile_k, pad_amount=LDS_PAD_A_BYTES, + num_warps=num_warps, + workgroup_mask=a_mcast_mask) + tdm_ops.tensor_load_2d(desc) + + def copy_b_data_to_lds(k_base, lds_mem_ref): + k_packed_off = k_base / arith.index(PACK_FACTOR) + desc = tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_b, lds_memref=lds_mem_ref, + global_offset=(blk_n, k_packed_off), + tensor_shape=(tile_n, packed_tile_k), + strides=(K_packed, 1), + tile_shape=(tile_n, packed_tile_k), + elem_bytes=1, + pad_interval=packed_tile_k, pad_amount=LDS_PAD_B_BYTES, + num_warps=num_warps, + workgroup_mask=b_mcast_mask) + tdm_ops.tensor_load_2d(desc) + + def copy_a_scale_to_lds(k_base, lds_mem_ref): + k_scale_off = k_base / arith.index(SCALE_BLOCK) + if scale_preshuffle: + # Interleaved global: [M // wmma_m_rep, wmma_m_rep * K_scale] + outer_off = blk_m / arith.index(wmma_m_rep) + inner_off = k_scale_off * arith.index(wmma_m_rep) + desc = tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_a_scale, lds_memref=lds_mem_ref, + global_offset=(outer_off, inner_off), + tensor_shape=(WMMA_M * m_warp, interleaved_scale_cols_a), + strides=(wmma_m_rep * K_scale, 1), + tile_shape=(WMMA_M * m_warp, interleaved_scale_cols_a), + elem_bytes=1, + pad_interval=0, pad_amount=0, + num_warps=num_warps, + workgroup_mask=a_mcast_mask) + else: + desc = tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_a_scale, lds_memref=lds_mem_ref, + global_offset=(blk_m, k_scale_off), + 
tensor_shape=(tile_m, scale_k_per_tile), + strides=(K_scale, 1), + tile_shape=(tile_m, scale_k_per_tile), + elem_bytes=1, + pad_interval=0, pad_amount=0, + num_warps=num_warps, + workgroup_mask=a_mcast_mask) + tdm_ops.tensor_load_2d(desc) + + def copy_b_scale_to_lds(k_base, lds_mem_ref): + k_scale_off = k_base / arith.index(SCALE_BLOCK) + if scale_preshuffle: + outer_off = blk_n / arith.index(wmma_n_rep) + inner_off = k_scale_off * arith.index(wmma_n_rep) + desc = tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_b_scale, lds_memref=lds_mem_ref, + global_offset=(outer_off, inner_off), + tensor_shape=(WMMA_N * n_warp, interleaved_scale_cols_b), + strides=(wmma_n_rep * K_scale, 1), + tile_shape=(WMMA_N * n_warp, interleaved_scale_cols_b), + elem_bytes=1, + pad_interval=0, pad_amount=0, + num_warps=num_warps, + workgroup_mask=b_mcast_mask) + else: + desc = tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_b_scale, lds_memref=lds_mem_ref, + global_offset=(blk_n, k_scale_off), + tensor_shape=(tile_n, scale_k_per_tile), + strides=(K_scale, 1), + tile_shape=(tile_n, scale_k_per_tile), + elem_bytes=1, + pad_interval=0, pad_amount=0, + num_warps=num_warps, + workgroup_mask=b_mcast_mask) + tdm_ops.tensor_load_2d(desc) + + def issue_all_tdm_loads(k_base, a_mem, b_mem, as_mem, bs_mem): + copy_a_data_to_lds(k_base, a_mem) + copy_b_data_to_lds(k_base, b_mem) + copy_a_scale_to_lds(k_base, as_mem) + copy_b_scale_to_lds(k_base, bs_mem) + + elem_ty_lds = T.f16 + + def _precompute_a_lane_bases(lds_ptr): + """Precompute per-wm A fragment lane base addresses (in BYTES). + + Each lane loads 32 bytes = 64 FP4 (one K-half). + lane16 → M-row, lane_kgrp → K-half (0 or 1). 
+ """ + lds_buffer = _get_lds_memref(lds_ptr) + row_base = (warp_m_base + lane16) * arith.index(lds_a_stride_bytes) + k_half_off = lane_kgrp * arith.index(32) # 32 bytes = 64 FP4 + bases = [] + for wm in range_constexpr(wmma_m_rep): + base = row_base + arith.index(wm * WMMA_M * lds_a_stride_bytes) + k_half_off + bases.append(base) + return lds_buffer, bases + + def _lds_load_b128(lds_buffer, byte_offset): + """Load 16 bytes from LDS at given byte offset via ds_load_b128.""" + from flydsl._mlir.dialects import llvm as _llvm, memref as _memref + lds_ptr_ty = ir.Type.parse("!llvm.ptr<3>") + raw_memref = arith.unwrap(lds_buffer) + lds_base = _memref.extract_aligned_pointer_as_index(raw_memref) + from flydsl.expr.arith import ArithValue as _AV + total_byte = _AV(lds_base) + byte_offset + addr_i32 = _raw(arith.index_cast(T.i32, total_byte)) + ptr_val = _llvm.inttoptr(lds_ptr_ty, addr_i32) + vec4_i32_ty = ir.VectorType.get([4], ir.IntegerType.get_signless(32)) + return llvm_dialect.load(vec4_i32_ty, ptr_val) + + def load_a_frag(lds_buffer, a_lane_base, ks): + """Load one 16x128 FP4 A-fragment from LDS. + + Returns vector<8xi32> (8 VGPRs, 64 FP4 per lane). + 2 x ds_load_b128 via direct LDS pointer load. + """ + k_byte_off = arith.index(ks * WMMA_K // PACK_FACTOR) # bytes per K-subtile + byte_off = a_lane_base + k_byte_off + v0 = _lds_load_b128(lds_buffer, byte_off) + v1 = _lds_load_b128(lds_buffer, byte_off + arith.index(16)) + return vector.shuffle(v0, v1, list(range(8))) + + def _precompute_b_lane_bases(lds_ptr): + """Precompute per-wn B fragment lane base addresses (in BYTES). + + B stored as [tile_n, packed_tile_k + pad] in LDS. + lane16 → N-row, lane_kgrp → K-half. 
+ """ + lds_buffer = _get_lds_memref(lds_ptr) + row_base = (warp_n_base + lane16) * arith.index(lds_b_stride_bytes) + k_half_off = lane_kgrp * arith.index(32) + bases = [] + for wn in range_constexpr(wmma_n_rep): + base = row_base + arith.index(wn * WMMA_N * lds_b_stride_bytes) + k_half_off + bases.append(base) + return lds_buffer, bases + + def load_b_frag(lds_buffer, b_lane_base, ks): + """Load one 128x16 FP4 B-fragment from LDS. Same pattern as A.""" + k_byte_off = arith.index(ks * WMMA_K // PACK_FACTOR) + byte_off = b_lane_base + k_byte_off + v0 = _lds_load_b128(lds_buffer, byte_off) + v1 = _lds_load_b128(lds_buffer, byte_off + arith.index(16)) + return vector.shuffle(v0, v1, list(range(8))) + + def _precompute_scale_lane_bases(lds_ptr, warp_base, reps, interleaved_cols=0): + """Precompute scale lane bases (in BYTES). + + Original layout: [tile_m_or_n, scale_k_per_tile] bytes. + Interleaved layout: [WMMA_M * m_or_n_warp, wmma_rep * scale_k_per_tile] bytes. + """ + lds_buffer = _get_lds_memref(lds_ptr) + if scale_preshuffle and interleaved_cols > 0: + # Interleaved: row = (warp_base / reps) + lane16, stride = interleaved_cols + warp_lds_row = warp_base / arith.index(reps) + lane16 + base = warp_lds_row * arith.index(interleaved_cols) + return lds_buffer, [base] # single base for b128 load + else: + row_base = (warp_base + lane16) * arith.index(scale_k_per_tile) + bases = [] + for w in range_constexpr(reps): + base = row_base + arith.index(w * WMMA_M * scale_k_per_tile) + bases.append(base) + return lds_buffer, bases + + def _shuffle_scale_i32(val): + """Swap bytes 1 and 2 of an i32 scale value via v_perm_b32. 
+ + FP4 data VGPR layout splits K=128 as: + V0-V3 lanes 0-15: K=0..31, V4-V7 lanes 0-15: K=64..95 + V0-V3 lanes 16-31: K=32..63, V4-V7 lanes 16-31: K=96..127 + The WMMA_SCALE hardware processes data in VGPR-group-first order, + so the scale i32 byte-to-K-block mapping is [0, 2, 1, 3]: + byte0 → K=0..31, byte1 → K=64..95, byte2 → K=32..63, byte3 → K=96..127 + Memory stores scales sequentially [K0,K1,K2,K3], so we swap bytes 1↔2 + to produce [K0,K2,K1,K3] using a single v_perm_b32 instruction. + """ + i32_ty = ir.IntegerType.get_signless(32) + return llvm_dialect.inline_asm( + i32_ty, [_raw(val) if not isinstance(val, ir.Value) else val], + "v_perm_b32 $0, $1, $1, 0x03010200", + "=v,v", has_side_effects=False, + ) + + def load_scale(lds_buffer, scale_base, ks): + """Load scale for one 16x128 WMMA from LDS. + + Returns i32 (1 VGPR) containing 4 packed E8M0 scale values, + shuffled to match the WMMA_SCALE instruction's byte-to-K-block + mapping: [K0, K2, K1, K3]. + ds_load_b32 via direct LDS pointer load. + """ + from flydsl._mlir.dialects import llvm as _llvm, memref as _memref + lds_ptr_ty = ir.Type.parse("!llvm.ptr<3>") + raw_memref = arith.unwrap(lds_buffer) + lds_base = _memref.extract_aligned_pointer_as_index(raw_memref) + # scale_k_per_tile bytes per row, ks-th group = ks * SCALES_PER_WMMA bytes + byte_off = scale_base + arith.index(ks * SCALES_PER_WMMA) + from flydsl.expr.arith import ArithValue as _AV + total_byte = _AV(lds_base) + byte_off + addr_i32 = _raw(arith.index_cast(T.i32, total_byte)) + ptr_val = _llvm.inttoptr(lds_ptr_ty, addr_i32) + i32_ty = ir.IntegerType.get_signless(32) + raw_scale = llvm_dialect.load(i32_ty, ptr_val) + if scale_preshuffle: + return raw_scale + return _shuffle_scale_i32(raw_scale) + + def load_scale_b128(lds_buffer, scale_base, reps, ks=0): + """Load all wmma_rep scales via ds_load_b128(s) for K-subtile *ks*. 
""" + ks_byte_off = ks * reps * SCALES_PER_WMMA + eff_base = scale_base if ks_byte_off == 0 else scale_base + arith.index(ks_byte_off) + num_loads = (reps + 3) // 4 + vecs = [] + for ld in range_constexpr(num_loads): + off = eff_base if ld == 0 else eff_base + arith.index(ld * 16) + vecs.append(_lds_load_b128(lds_buffer, off)) + results = [] + for i in range_constexpr(reps): + vi = vector.extract(vecs[i // 4], static_position=[i % 4], dynamic_position=[]) + if not scale_preshuffle: + vi = _shuffle_scale_i32(vi) + results.append(vi) + return results + + def load_k_subtile_frags(a_buf, a_bases, b_buf, b_bases, + as_buf, as_bases, bs_buf, bs_bases, ks): + """Batch-load all A/B fragments and scales for one K-subtile.""" + # Load B frags first (gives more time for A frags to arrive) + b_frags = [load_b_frag(b_buf, b_bases[wn], ks) + for wn in range_constexpr(wmma_n_rep)] + a_frags = [load_a_frag(a_buf, a_bases[wm], ks) + for wm in range_constexpr(wmma_m_rep)] + # Load scales + if scale_preshuffle: + b_scales = load_scale_b128(bs_buf, bs_bases[0], wmma_n_rep, ks) + a_scales = load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) + else: + b_scales = [load_scale(bs_buf, bs_bases[wn], ks) + for wn in range_constexpr(wmma_n_rep)] + a_scales = [load_scale(as_buf, as_bases[wm], ks) + for wm in range_constexpr(wmma_m_rep)] + return a_frags, b_frags, a_scales, b_scales + + def do_k_subtile_wmma(a_frags, b_frags, a_scales, b_scales, accs): + """Execute all WMMAs for one K-subtile with scales. + + Uses wmma_scale_f32_16x16x128_f8f6f4 (gfx1250 wave32) with: + fmtA=4 (FP4/E2M1), fmtB=4 (FP4/E2M1), + scaleAType=0 (opsel lo), scaleBType=0 (opsel lo). + fmtScaleA/B defaults to 0 (E8M0). + + Operands are passed as (B, A) instead of (A, B) to compensate + for the WMMA output VGPR layout where lane16→col and + lane_kgrp→row_group. Swapping computes C^T, making the output + match the epilogue's row-major store pattern. 
+ """ + for wm in range_constexpr(wmma_m_rep): + for wn in range_constexpr(wmma_n_rep): + idx = wm * wmma_n_rep + wn + accs[idx] = rocdl.wmma_scale_f32_16x16x128_f8f6f4( + T.vec(8, T.f32), + b_frags[wn], a_frags[wm], accs[idx], + b_scales[wn], a_scales[wm], + fmtA=4, fmtB=4, + scaleAType=0, scaleBType=0, + ) + return accs + + def compute_tile(accs_in, lds_a, lds_b, lds_as, lds_bs, emit_filler=None): + current_accs = list(accs_in) + + a_buf, a_bases = _precompute_a_lane_bases(lds_a) + b_buf, b_bases = _precompute_b_lane_bases(lds_b) + as_buf, as_bases = _precompute_scale_lane_bases( + lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a) + bs_buf, bs_bases = _precompute_scale_lane_bases( + lds_bs, warp_n_base, wmma_n_rep, interleaved_scale_cols_b) + + if k_wmma_steps == 1: + frags = load_k_subtile_frags( + a_buf, a_bases, b_buf, b_bases, + as_buf, as_bases, bs_buf, bs_bases, 0) + rocdl.s_wait_dscnt(0) + if emit_filler is not None: + emit_filler() + current_accs = do_k_subtile_wmma(*frags, current_accs) + else: + prev = load_k_subtile_frags( + a_buf, a_bases, b_buf, b_bases, + as_buf, as_bases, bs_buf, bs_bases, 0) + + # Main K-loop: overlap load[ks+1] with compute[ks] + for ks in range_constexpr(k_wmma_steps - 1): + next_frags = load_k_subtile_frags( + a_buf, a_bases, b_buf, b_bases, + as_buf, as_bases, bs_buf, bs_bases, ks + 1) + rocdl.s_wait_dscnt(LOADS_PER_SUBTILE) + current_accs = do_k_subtile_wmma(*prev, current_accs) + prev = next_frags + + # Epilogue + rocdl.s_wait_dscnt(0) + if emit_filler is not None: + rocdl.sched_barrier(0) + emit_filler() + current_accs = do_k_subtile_wmma(*prev, current_accs) + + return current_accs + + def hot_loop_scheduler(): + rocdl.sched_barrier(0) + + # --- Epilogue: vectorized buffer_store_b128 --- + # WMMA output VGPR layout (wave32, 16x16 tile): + # lane16 (lane_id % 16) → N column + # lane_kgrp (lane_id / 16) → M row group (0=rows 0-7, 1=rows 8-15) + # element[i] → M row offset within group + # We compensate by swapping 
A/B operands in the WMMA call (see + # do_k_subtile_wmma) so the WMMA effectively computes C^T, making + # the output VGPR layout match this epilogue's store pattern: + # lane16 → M row, lane_kgrp*8 + ele → N column group. + def epilogue_prepare_addrs(): + addrs = [] + for wm in range_constexpr(wmma_m_rep): + for wn in range_constexpr(wmma_n_rep): + row = blk_m + warp_m_base + arith.index(wm * WMMA_M) + lane16 + col_base = (blk_n + warp_n_base + arith.index(wn * WMMA_N) + + lane_kgrp * arith.index(8)) + for half in range_constexpr(2): + col = col_base + arith.index(half * 4) + c_off = row * n_stride + col + addrs.append(c_off) + return addrs + + def epilogue_stores(final_accs, addrs): + addr_idx = 0 + for wm in range_constexpr(wmma_m_rep): + for wn in range_constexpr(wmma_n_rep): + idx = wm * wmma_n_rep + wn + for half in range_constexpr(2): + vals = [vector.extract( + final_accs[idx], + static_position=[half * 4 + vi], + dynamic_position=[]) + for vi in range_constexpr(4)] + vec4 = vector.from_elements(T.vec(4, T.f32), vals) + buffer_ops.buffer_store(vec4, c_rsrc, addrs[addr_idx]) + addr_idx += 1 + + def wait_and_barrier(outstanding=0): + tdm_ops.tensor_wait(outstanding) + gpu.barrier() + + def wait_and_cluster_barrier(outstanding=0): + tdm_ops.tensor_wait(outstanding) + if use_cluster: + gpu.cluster_barrier() + else: + gpu.barrier() + + def _compute_and_schedule(accs_in, a, b, a_s, b_s): + rocdl.sched_barrier(0) + accs_out = compute_tile(accs_in, a, b, a_s, b_s) + hot_loop_scheduler() + return accs_out + + _effective_l2_pf = l2_prefetch_distance + if use_cluster and l2_prefetch_distance > 0: + _effective_l2_pf = max(1, l2_prefetch_distance - 1) + + def _l2_prefetch(k_base): + if _effective_l2_pf <= 0: + return + pf_k = k_base + arith.index(_effective_l2_pf * tile_k) + pf_k_packed = pf_k / arith.index(PACK_FACTOR) + tdm_ops.l2_prefetch_tile( + arg_a, (blk_m, pf_k_packed), (tile_m, packed_tile_k), (K_packed, 1), + elem_bytes=1, thread_id=tx, 
block_threads=block_threads) + tdm_ops.l2_prefetch_tile( + arg_b, (blk_n, pf_k_packed), (tile_n, packed_tile_k), (K_packed, 1), + elem_bytes=1, thread_id=tx, block_threads=block_threads) + + acc_zero = arith.constant_vector(0.0, T.vec(8, T.f32)) + accs = [acc_zero] * n_accs + + # Build per-stage SmemPtrs using f16 element type for addressing. + # FP4 packed data (1 byte = 2 FP4) + scale (1 byte E8M0) both + # addressed in f16 units (2 bytes). This matches the fp16 kernel's + # proven vector.load_op pattern. + lds_a_data_f16 = lds_a_data_bytes // 2 + lds_b_data_f16 = lds_b_data_bytes // 2 + lds_a_scale_f16 = lds_a_scale_bytes // 2 + lds_b_scale_f16 = lds_b_scale_bytes // 2 + + base_ptrs = [sa.get_base() for sa in stage_allocators] + + stages_a = [ + SmemPtr(base_ptrs[i], stage_a_data_off[i], elem_ty_lds, shape=(lds_a_data_f16,)) + for i in range_constexpr(num_buffers) + ] + stages_b = [ + SmemPtr(base_ptrs[i], stage_b_data_off[i], elem_ty_lds, shape=(lds_b_data_f16,)) + for i in range_constexpr(num_buffers) + ] + stages_as = [ + SmemPtr(base_ptrs[i], stage_a_scale_off[i], elem_ty_lds, shape=(lds_a_scale_f16,)) + for i in range_constexpr(num_buffers) + ] + stages_bs = [ + SmemPtr(base_ptrs[i], stage_b_scale_off[i], elem_ty_lds, shape=(lds_b_scale_f16,)) + for i in range_constexpr(num_buffers) + ] + + # Get memrefs for TDM (raw memref values) + stages_a_mem = [stages_a[i].get() for i in range_constexpr(num_buffers)] + stages_b_mem = [stages_b[i].get() for i in range_constexpr(num_buffers)] + stages_as_mem = [stages_as[i].get() for i in range_constexpr(num_buffers)] + stages_bs_mem = [stages_bs[i].get() for i in range_constexpr(num_buffers)] + + # Prologue: load first (num_buffers - 1) tiles + for i in range_constexpr(pre_loaded): + issue_all_tdm_loads( + arith.index(i * tile_k), + stages_a_mem[i], stages_b_mem[i], + stages_as_mem[i], stages_bs_mem[i]) + # Wait for all but the last batch of TDM loads + wait_and_barrier(outstanding=TDM_LOADS_PER_STEP * (num_buffers - 
2)) + + # Main loop + main_end = loop_iters * num_buffers * tile_k + + if loop_iters > 0: + for iv, state in range(0, main_end, num_buffers * tile_k, init=list(accs)): + accs_in = list(state) + for s in range_constexpr(num_buffers): + _load_stage = (s + num_buffers - 1) % num_buffers + _load_k_off = (s + num_buffers - 1) * tile_k + issue_all_tdm_loads( + iv + arith.index(_load_k_off), + stages_a_mem[_load_stage], stages_b_mem[_load_stage], + stages_as_mem[_load_stage], stages_bs_mem[_load_stage]) + _l2_prefetch(iv + arith.index(s * tile_k)) + accs_in = _compute_and_schedule( + accs_in, + stages_a[s], stages_b[s], + stages_as[s], stages_bs[s]) + if s == num_buffers - 1: + wait_and_cluster_barrier(outstanding=TDM_LOADS_PER_STEP) + else: + wait_and_barrier(outstanding=TDM_LOADS_PER_STEP) + results = yield list(accs_in) + accs = list(results) + + # Tail + if loop_iters == 0 and use_cluster: + gpu.cluster_barrier() + _extra_j = 0 + for _load_stage, _compute_stage, _outstanding in tail_plan: + if _load_stage is not None: + _k_off = (_tail_start + pre_loaded + _extra_j) * tile_k + issue_all_tdm_loads( + arith.index(_k_off), + stages_a_mem[_load_stage], stages_b_mem[_load_stage], + stages_as_mem[_load_stage], stages_bs_mem[_load_stage]) + _extra_j += 1 + if _outstanding == -1: + epi_addrs_box = [None] + + def _emit_epi_addrs(): + epi_addrs_box[0] = epilogue_prepare_addrs() + + accs = compute_tile( + accs, + stages_a[_compute_stage], stages_b[_compute_stage], + stages_as[_compute_stage], stages_bs[_compute_stage], + emit_filler=_emit_epi_addrs) + else: + accs = _compute_and_schedule( + accs, + stages_a[_compute_stage], stages_b[_compute_stage], + stages_as[_compute_stage], stages_bs[_compute_stage]) + if use_cluster and _load_stage is not None: + wait_and_cluster_barrier(outstanding=_outstanding) + else: + wait_and_barrier(outstanding=_outstanding) + + epilogue_stores(accs, epi_addrs_box[0]) + + cache_tag = (K, tile_m, tile_n, tile_k, m_warp, n_warp, + num_buffers, 
effective_waves_per_eu, l2_prefetch_distance, + cluster_m, cluster_n, scale_preshuffle) + + @flyc.jit + def launch_mxfp4_gemm( + arg_c: fx.Tensor, + arg_a: fx.Tensor, + arg_b: fx.Tensor, + arg_a_scale: fx.Tensor, + arg_b_scale: fx.Tensor, + i32_m: fx.Int32, + i32_n: fx.Int32, + stream: fx.Stream, + ): + _ = cache_tag + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + for alloc in stage_allocators: + alloc.finalized = False + for alloc in stage_allocators: + alloc.finalize() + + idx_m = arith.index_cast(T.index, i32_m.ir_value()) + idx_n = arith.index_cast(T.index, i32_n.ir_value()) + gx = _raw((idx_m + arith.index(tile_m - 1)) / arith.index(tile_m)) + gy = _raw((idx_n + arith.index(tile_n - 1)) / arith.index(tile_n)) + + launcher = kernel_mxfp4_gemm( + arg_c, arg_a, arg_b, arg_a_scale, arg_b_scale, i32_m, i32_n) + for op in ctx.gpu_module_body.operations: + if hasattr(op, 'attributes') and op.OPERATION_NAME == "gpu.func": + if effective_waves_per_eu is not None: + _wpe = int(effective_waves_per_eu) + if _wpe >= 1: + op.attributes["rocdl.waves_per_eu"] = ir.IntegerAttr.get( + ir.IntegerType.get_signless(32), _wpe) + if use_cluster: + op.attributes["rocdl.cluster_dims"] = ir.StringAttr.get( + f"{cluster_m},{cluster_n},1") + cluster_arg = (cluster_m, cluster_n, 1) if use_cluster else None + launcher.launch( + grid=(gx, gy, 1), + block=(block_threads, 1, 1), + stream=stream, + cluster=cluster_arg, + ) + + return launch_mxfp4_gemm + + +__all__ = ["compile_mxfp4_gemm"] diff --git a/kernels/pipeline_utils.py b/kernels/pipeline_utils.py new file mode 100644 index 00000000..37fabcc8 --- /dev/null +++ b/kernels/pipeline_utils.py @@ -0,0 +1,43 @@ +"""Shared pipeline utilities for gfx1250 GEMM kernels. """ + + +def make_tail_plan(num_buffers, pre_loaded, extra): + """Compute a compile-time tail execution plan for the N-stage pipeline. + + Returns a list of (load_stage, compute_stage, outstanding) tuples, one per + tail step. 
outstanding=-1 means "last step, use compute_tile (no barrier)". + + Args: + num_buffers: total number of pipeline stages. + pre_loaded: stages already loaded and ready to compute (= num_buffers - 1). + extra: additional tiles that must be loaded in the tail. + """ + steps = pre_loaded + extra + plan = [] + for i in range(steps): + compute_stage = ( + i if i < pre_loaded + else (i - pre_loaded + num_buffers - 1) % num_buffers + ) + load_stage = ( + (i + num_buffers - 1) % num_buffers if i < extra + else None + ) + is_last = (i == steps - 1) + if is_last: + outstanding = -1 + else: + j = i + 1 + next_compute = ( + j if j < pre_loaded + else (j - pre_loaded + num_buffers - 1) % num_buffers + ) + outstanding = ( + 2 if (load_stage is not None and load_stage != next_compute) + else 0 + ) + plan.append((load_stage, compute_stage, outstanding)) + return plan + + +__all__ = ["make_tail_plan"] diff --git a/kernels/wmma_gemm_gfx1250.py b/kernels/wmma_gemm_gfx1250.py new file mode 100644 index 00000000..730f94a3 --- /dev/null +++ b/kernels/wmma_gemm_gfx1250.py @@ -0,0 +1,568 @@ +"""TDM async copy WMMA GEMM kernel for gfx1250. + +Supports double-buffer (2-stage) and triple-buffer (3-stage) pipelining +with TDM (Tensor Data Mover) hardware async copy for both A and B tiles. 
+""" + +import flydsl.compiler as flyc +import flydsl.expr as fx + +from flydsl._mlir import ir +from flydsl.compiler.kernel_function import CompilationContext +from flydsl.expr import arith, buffer_ops, gpu, range_constexpr, rocdl, tdm_ops, vector +from flydsl._mlir.dialects import llvm as llvm_dialect +from flydsl.expr.arith import _to_raw as _raw +from flydsl.expr.typing import T +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr, get_op_result_or_value + +from kernels.layout_utils import idx2crd +from kernels.pipeline_utils import make_tail_plan + +WMMA_M, WMMA_N, WMMA_K = 16, 16, 32 +WAVE_SIZE = 32 + +LDS_PAD_A = 8 +LDS_PAD_B = 8 + +_STAGE_NAMES = ("ping", "pong", "pang") + + +_make_tail_plan = make_tail_plan + + +def compile_wmma_gemm_tdm( + *, + M: int = 0, + N: int = 0, + K: int, + tile_m: int = 256, + tile_n: int = 256, + tile_k: int = 128, + m_warp: int = 2, + n_warp: int = 4, + in_dtype: str = "fp16", + num_buffers: int = 2, + waves_per_eu: int = None, + l2_prefetch_distance: int = 2, + cluster_m: int = 1, + cluster_n: int = 1, +): + """Compile a WMMA GEMM kernel with TDM async copy and multi-stage buffering. + + Returns a JitFunction: launch_fn(arg_c, arg_a, arg_b, M, N, stream) + + Args: + num_buffers: Number of LDS buffers (2=double, 3=triple buffering). + waves_per_eu: Occupancy hint (None = default, 1-4 = limit occupancy). + l2_prefetch_distance: Number of k-tiles ahead to prefetch into L2. + 0 = disabled, 2 = typical value. + cluster_m: Cluster dimension along M (WG rows per cluster, 1=disabled). + cluster_n: Cluster dimension along N (WG cols per cluster, 1=disabled). 
+ """ + _ = (M, N) + if num_buffers not in (2, 3): + raise ValueError(f"num_buffers must be 2 or 3, got {num_buffers}") + if in_dtype not in ("fp16", "bf16"): + raise ValueError(f"in_dtype must be 'fp16' or 'bf16', got {in_dtype!r}") + is_f16 = in_dtype == "fp16" + elem_bytes = 2 + + use_cluster = cluster_m > 1 or cluster_n > 1 + if use_cluster: + if cluster_m * cluster_n > 16: + raise ValueError( + f"cluster_m * cluster_n must be <= 16, got {cluster_m}*{cluster_n}={cluster_m * cluster_n}") + if cluster_m < 1 or cluster_n < 1: + raise ValueError(f"cluster dims must be >= 1, got ({cluster_m}, {cluster_n})") + effective_waves_per_eu = waves_per_eu + if use_cluster and effective_waves_per_eu is None: + # Cluster mode can deadlock if a workgroup is split and only a subset + # of its waves are resident while hitting early workgroup barriers. + # Use conservative occupancy by default for cluster-enabled kernels. + effective_waves_per_eu = 1 + + block_threads = m_warp * n_warp * WAVE_SIZE + + if K % tile_k != 0: + raise ValueError(f"K must be divisible by tile_k={tile_k}, got K={K}") + if tile_k % WMMA_K != 0: + raise ValueError(f"tile_k must be a multiple of {WMMA_K}, got {tile_k}") + if tile_m % WMMA_M != 0: + raise ValueError(f"tile_m must be a multiple of {WMMA_M}, got {tile_m}") + if tile_n % WMMA_N != 0: + raise ValueError(f"tile_n must be a multiple of {WMMA_N}, got {tile_n}") + if (tile_k & (tile_k - 1)) != 0: + raise ValueError(f"tile_k must be a power of 2 for TDM async copy, got {tile_k}") + + warp_tile_m = tile_m // m_warp + warp_tile_n = tile_n // n_warp + if warp_tile_m % WMMA_M != 0: + raise ValueError(f"warp_tile_m={warp_tile_m} must be a multiple of {WMMA_M}") + if warp_tile_n % WMMA_N != 0: + raise ValueError(f"warp_tile_n={warp_tile_n} must be a multiple of {WMMA_N}") + + num_k_tiles = K // tile_k + if num_k_tiles < num_buffers: + raise ValueError( + f"{num_buffers}-stage buffering requires num_k_tiles >= {num_buffers}, " + f"got {num_k_tiles} (K={K}, 
tile_k={tile_k})") + + gpu_arch = str(get_hip_arch(timeout_s=300)) + assert gpu_arch.startswith("gfx1250"), f"Expected gfx1250, got {gpu_arch}" + + wmma_op = rocdl.wmma_f32_16x16x32_f16 if is_f16 else rocdl.wmma_f32_16x16x32_bf16 + k_wmma_steps = tile_k // WMMA_K + + def _elem_type(): + return T.f16 if is_f16 else T.bf16 + + wmma_m_rep = warp_tile_m // WMMA_M + wmma_n_rep = warp_tile_n // WMMA_N + n_accs = wmma_m_rep * wmma_n_rep + + lds_a_stride = tile_k + LDS_PAD_A + lds_b_stride = tile_n + LDS_PAD_B + lds_a_elems = tile_m * lds_a_stride + LDS_PAD_A + lds_b_elems = tile_k * lds_b_stride + LDS_PAD_B + + buf_size_elems = lds_a_elems + lds_b_elems + + # --- LDS allocation --- + num_warps = m_warp * n_warp + + stage_allocators = [] + stage_a_offsets = [] + stage_b_offsets = [] + for i in range(num_buffers): + name = _STAGE_NAMES[i] + alloc = SmemAllocator(None, arch=gpu_arch, global_sym_name=f"wmma_tdm_{name}") + off = alloc._align(alloc.ptr, 16) + alloc.ptr = off + buf_size_elems * elem_bytes + stage_allocators.append(alloc) + stage_a_offsets.append(off) + stage_b_offsets.append(off + lds_a_elems * elem_bytes) + + # Compile-time pipeline parameters + pre_loaded = num_buffers - 1 # stages pre-loaded in prologue + loop_iters = (num_k_tiles - pre_loaded) // num_buffers + _tail_start = loop_iters * num_buffers # index of first un-computed tile in tail + extra = num_k_tiles - _tail_start - pre_loaded + tail_plan = _make_tail_plan(num_buffers, pre_loaded, extra) + + @flyc.kernel + def kernel_wmma_gemm_tdm( + arg_c: fx.Tensor, + arg_a: fx.Tensor, + arg_b: fx.Tensor, + i32_m: fx.Int32, + i32_n: fx.Int32, + ): + # Enable back-to-back WMMA issue (SCHED_MODE bit[4] = DISABLE_VALU_STALL) + # hwreg(26, 4, 1) = HW_REG_SCHED_MODE, offset=4, size=1 + llvm_dialect.inline_asm( + None, [], # void result, no operands + "s_setreg_imm32_b32 hwreg(26, 4, 1), 1", + "", # no constraints + has_side_effects=True, + ) + + tx = gpu.thread_id("x") + bx = gpu.block_id("x") + by = 
gpu.block_id("y") + + blk_m = bx * arith.index(tile_m) + blk_n = by * arith.index(tile_n) + + # --- Cluster MCAST setup --- + if use_cluster: + local_x, local_y = gpu.compute_cluster_position() + a_mcast_mask, b_mcast_mask = gpu.compute_mcast_masks( + local_x, local_y, cluster_m, cluster_n) + else: + a_mcast_mask = 0 + b_mcast_mask = 0 + + # --- Thread/wave decomposition --- + layout_thr = fx.make_layout( + (m_warp, n_warp, 2, 16), + (n_warp * WAVE_SIZE, WAVE_SIZE, 16, 1)) + thr_coord = idx2crd(tx, layout_thr) + wave_m_idx, wave_n_idx, lane_kgrp, lane16 = ( + thr_coord[0], thr_coord[1], thr_coord[2], thr_coord[3]) + + warp_m_base = wave_m_idx * arith.index(warp_tile_m) + warp_n_base = wave_n_idx * arith.index(warp_tile_n) + + elem_ty = _elem_type() + + # --- Epilogue setup --- + m_idx = arith.index_cast(T.index, i32_m.ir_value()) + n_stride = arith.index(N) + c_nrec = m_idx * n_stride * arith.index(4) + c_rsrc = buffer_ops.create_buffer_resource(arg_c, num_records_bytes=c_nrec) + + # --- TDM async copy helpers (MCAST-aware) --- + def copy_a_to_lds(k_base, lds_a_mem_ref): + desc = tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_a, lds_memref=lds_a_mem_ref, + global_offset=(blk_m, k_base), + tensor_shape=(tile_m, tile_k), strides=(K, 1), + tile_shape=(tile_m, tile_k), elem_bytes=elem_bytes, + pad_interval=tile_k, pad_amount=LDS_PAD_A, + num_warps=num_warps, + workgroup_mask=a_mcast_mask) + tdm_ops.tensor_load_2d(desc) + + def copy_b_to_lds(k_base, lds_b_mem_ref): + desc = tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_b, lds_memref=lds_b_mem_ref, + global_offset=(k_base, blk_n), + tensor_shape=(tile_k, tile_n), strides=(N, 1), + tile_shape=(tile_k, tile_n), elem_bytes=elem_bytes, + pad_interval=tile_n, pad_amount=LDS_PAD_B, + num_warps=num_warps, + workgroup_mask=b_mcast_mask) + tdm_ops.tensor_load_2d(desc) + + # --- LDS load helpers --- + def _get_lds_memref(lds_ptr): + """Get the raw memref value from SmemPtr or raw memref.""" + if isinstance(lds_ptr, 
SmemPtr): + return get_op_result_or_value(lds_ptr.get()) + return get_op_result_or_value(lds_ptr) + + def _precompute_a_lane_bases(lds_ptr): + """Precompute per-wm A fragment lane base addresses. + + Returns (lds_buffer, bases) where bases[wm] = + (warp_m_base + wm*WMMA_M + lane16) * lds_a_stride + lane_kgrp * 8 + """ + lds_buffer = _get_lds_memref(lds_ptr) + row_stride_off = (warp_m_base + lane16) * arith.index(lds_a_stride) + k_lane_off = lane_kgrp * arith.index(8) + bases = [] + for wm in range_constexpr(wmma_m_rep): + a_base = row_stride_off + arith.index(wm * WMMA_M * lds_a_stride) + k_lane_off + bases.append(a_base) + return lds_buffer, bases + + def load_wmma_frag(a_lds_buffer, a_lane_base, ks): + """Load one 16x32 WMMA fragment from LDS using vectorized 128-bit loads. + + a_lane_base is precomputed by _precompute_a_lane_bases. + ks is the K-subtile index (compile-time constant). + """ + vec8_ty = ir.VectorType.get([8], elem_ty) + + off0 = a_lane_base + arith.index(ks * WMMA_K) + off1 = a_lane_base + arith.index(ks * WMMA_K + 16) + + v0 = vector.load_op(vec8_ty, a_lds_buffer, [off0]) + v1 = vector.load_op(vec8_ty, a_lds_buffer, [off1]) + + return vector.shuffle(v0, v1, list(range(16))) + + def _precompute_b_lane_bases(lds_ptr): + """Precompute per-wn B fragment lane base addresses. + + Returns a list of (lds_buffer, b_lane_base) for each wn. + b_lane_base = (lane_kgrp*8 + lane8) * lds_b_stride + + (warp_n_base + wn*WMMA_N + lane_ngrp*8) + where lane8 = lane16 % 8, lane_ngrp = lane16 / 8. + + After precompute, lane8/lane_ngrp are dead → frees VGPRs. 
+ """ + lds_buffer = _get_lds_memref(lds_ptr) + lane8 = lane16 % arith.index(8) + lane_ngrp = lane16 / arith.index(8) + k_lane_off = (lane_kgrp * arith.index(8) + lane8) * arith.index(lds_b_stride) + n_lane_off = lane_ngrp * arith.index(8) + bases = [] + for wn in range_constexpr(wmma_n_rep): + n_col = warp_n_base + arith.index(wn * WMMA_N) + n_lane_off + b_base = k_lane_off + n_col + bases.append(b_base) + return lds_buffer, bases + + def load_wmma_frag_tr(lds_buffer, b_lane_base, ks): + """Load one 16x32 WMMA B fragment using ds_load_tr16_b128. + + b_lane_base is precomputed by _precompute_b_lane_bases. + ks is the K-subtile index (compile-time constant from range_constexpr). + The K offset is folded into a compile-time constant multiplication. + """ + vec8_ty = ir.VectorType.get([8], elem_ty) + results = [] + for k_half in range_constexpr(2): + k_row_off = (ks * WMMA_K + k_half * 16) * lds_b_stride + elem_off = b_lane_base + arith.index(k_row_off) + v = rocdl.lds_transpose_load(vec8_ty, lds_buffer, elem_off, elem_bytes) + results.append(v) + return vector.shuffle(results[0], results[1], list(range(16))) + + # --- K-subtile load/compute helpers --- + # Number of LDS loads per K-subtile: + # B frags: wmma_n_rep * 2 (ds_load_tr16_b128), A frags: wmma_m_rep * 2 + LOADS_PER_SUBTILE = (wmma_m_rep + wmma_n_rep) * 2 + + def load_k_subtile_frags(a_lds_buffer, a_bases, b_lds_buffer, b_bases, ks): + """Batch-load all A and B fragments for one K-subtile (no wait). + + All base addresses are precomputed by _precompute_{a,b}_lane_bases. + ks is the K-subtile index (compile-time constant). 
+ """ + b_frags = [load_wmma_frag_tr(b_lds_buffer, b_bases[wn], ks) + for wn in range_constexpr(wmma_n_rep)] + + a_frags = [load_wmma_frag(a_lds_buffer, a_bases[wm], ks) + for wm in range_constexpr(wmma_m_rep)] + + return a_frags, b_frags + + def do_k_subtile_wmma(a_frags, b_frags, accs): + """Execute all WMMAs for one K-subtile using pre-loaded fragments.""" + for wm in range_constexpr(wmma_m_rep): + for wn in range_constexpr(wmma_n_rep): + idx = wm * wmma_n_rep + wn + accs[idx] = wmma_op( + T.vec(8, T.f32), + b_frags[wn], a_frags[wm], + accs[idx], + signA=False, signB=False, modC=0, + reuseA=False, reuseB=False, + ).result + return accs + + # --- Compute on one LDS buffer (K-subtile pipelined) --- + def compute_tile(accs_in, lds_a_ptr, lds_b_ptr, emit_filler=None): + current_accs = list(accs_in) + + # Precompute all lane bases once per tile + a_lds_buffer, a_bases = _precompute_a_lane_bases(lds_a_ptr) + b_lds_buffer, b_bases = _precompute_b_lane_bases(lds_b_ptr) + + if k_wmma_steps == 1: + a_frags, b_frags = load_k_subtile_frags( + a_lds_buffer, a_bases, b_lds_buffer, b_bases, 0) + rocdl.s_wait_dscnt(0) + if emit_filler is not None: + emit_filler() + current_accs = do_k_subtile_wmma(a_frags, b_frags, current_accs) + else: + # Prologue: batch-load K-subtile 0 + prev_a, prev_b = load_k_subtile_frags( + a_lds_buffer, a_bases, b_lds_buffer, b_bases, 0) + + # Main K-loop: overlap load[ks+1] with compute[ks] + for ks in range_constexpr(k_wmma_steps - 1): + next_a, next_b = load_k_subtile_frags( + a_lds_buffer, a_bases, b_lds_buffer, b_bases, ks + 1) + rocdl.s_wait_dscnt(LOADS_PER_SUBTILE) + current_accs = do_k_subtile_wmma(prev_a, prev_b, current_accs) + prev_a, prev_b = next_a, next_b + + rocdl.s_wait_dscnt(0) + if emit_filler is not None: + rocdl.sched_barrier(0) + emit_filler() + current_accs = do_k_subtile_wmma(prev_a, prev_b, current_accs) + + return current_accs + + # --- Scheduling --- + def hot_loop_scheduler(): + rocdl.sched_barrier(0) + + # --- Epilogue: 
vectorized buffer_store_b128 --- + def epilogue_prepare_addrs(): + """Precompute all epilogue store addresses (VALU only, no stores). """ + addrs = [] + for wm in range_constexpr(wmma_m_rep): + for wn in range_constexpr(wmma_n_rep): + row = blk_m + warp_m_base + arith.index(wm * WMMA_M) + lane16 + col_base = (blk_n + warp_n_base + arith.index(wn * WMMA_N) + + lane_kgrp * arith.index(8)) + for half in range_constexpr(2): + col = col_base + arith.index(half * 4) + c_off = row * n_stride + col + addrs.append(c_off) + return addrs + + def epilogue_stores(final_accs, addrs): + """Execute buffer_store using precomputed addresses.""" + addr_idx = 0 + for wm in range_constexpr(wmma_m_rep): + for wn in range_constexpr(wmma_n_rep): + idx = wm * wmma_n_rep + wn + for half in range_constexpr(2): + vals = [vector.extract( + final_accs[idx], + static_position=[half * 4 + vi], + dynamic_position=[]) + for vi in range_constexpr(4)] + vec4 = vector.from_elements(T.vec(4, T.f32), vals) + buffer_ops.buffer_store(vec4, c_rsrc, addrs[addr_idx]) + addr_idx += 1 + + # --- Pipeline helpers --- + def wait_and_barrier(outstanding=0): + tdm_ops.tensor_wait(outstanding) + gpu.barrier() + + def wait_and_cluster_barrier(outstanding=0): + """Fused WG barrier + cluster sync: reduces instruction overhead + by issuing the cluster signal while tensor_wait is still draining, + then waiting for both to complete.""" + tdm_ops.tensor_wait(outstanding) + if use_cluster: + gpu.cluster_barrier() + else: + gpu.barrier() + + def _compute_and_schedule(accs_in, lds_a, lds_b): + rocdl.sched_barrier(0) + accs_out = compute_tile(accs_in, lds_a, lds_b) + hot_loop_scheduler() + return accs_out + + _effective_l2_pf = l2_prefetch_distance + if use_cluster and l2_prefetch_distance > 0: + _effective_l2_pf = max(1, l2_prefetch_distance - 1) + + def _l2_prefetch(k_base): + if _effective_l2_pf <= 0: + return + pf_k = k_base + arith.index(_effective_l2_pf * tile_k) + tdm_ops.l2_prefetch_tile( + arg_a, (blk_m, pf_k), 
(tile_m, tile_k), (K, 1), + elem_bytes=elem_bytes, thread_id=tx, block_threads=block_threads) + tdm_ops.l2_prefetch_tile( + arg_b, (pf_k, blk_n), (tile_k, tile_n), (N, 1), + elem_bytes=elem_bytes, thread_id=tx, block_threads=block_threads) + + # ====== Multi-stage pipeline ====== + acc_zero = arith.constant_vector(0.0, T.vec(8, T.f32)) + accs = [acc_zero] * n_accs + + # Build per-stage SmemPtrs (one per pipeline stage) + base_ptrs = [sa.get_base() for sa in stage_allocators] + stages_a = [ + SmemPtr(base_ptrs[i], stage_a_offsets[i], elem_ty, shape=(lds_a_elems,)) + for i in range_constexpr(num_buffers) + ] + stages_b = [ + SmemPtr(base_ptrs[i], stage_b_offsets[i], elem_ty, shape=(lds_b_elems,)) + for i in range_constexpr(num_buffers) + ] + stages_a_mem = [stages_a[i].get() for i in range_constexpr(num_buffers)] + stages_b_mem = [stages_b[i].get() for i in range_constexpr(num_buffers)] + + # Prologue: load first (num_buffers - 1) tiles into stages 0..(num_buffers-2) + for i in range_constexpr(pre_loaded): + copy_a_to_lds(arith.index(i * tile_k), stages_a_mem[i]) + copy_b_to_lds(arith.index(i * tile_k), stages_b_mem[i]) + wait_and_barrier(outstanding=2 * (num_buffers - 2)) + + # Main loop: each iteration covers (num_buffers) K-tiles + # Sub-phase s: load next tile (MCAST), compute current tile, then barrier + # The last sub-phase uses wait_and_cluster_barrier to fuse the WG + # barrier with cluster sync for the NEXT iteration's MCAST loads. 
+ main_end = loop_iters * num_buffers * tile_k + + if loop_iters > 0: + for iv, state in range(0, main_end, num_buffers * tile_k, init=list(accs)): + accs_in = list(state) + for s in range_constexpr(num_buffers): + _load_stage = (s + num_buffers - 1) % num_buffers + _load_k_off = (s + num_buffers - 1) * tile_k + copy_a_to_lds(iv + arith.index(_load_k_off), stages_a_mem[_load_stage]) + copy_b_to_lds(iv + arith.index(_load_k_off), stages_b_mem[_load_stage]) + _l2_prefetch(iv + arith.index(s * tile_k)) + accs_in = _compute_and_schedule(accs_in, stages_a[s], stages_b[s]) + if s == num_buffers - 1: + wait_and_cluster_barrier(outstanding=2) + else: + wait_and_barrier(outstanding=2) + results = yield list(accs_in) + accs = list(results) + + # Tail: handle remaining tiles using the compile-time plan + # outstanding=-1 → last step: use compute_tile (no barrier). + if loop_iters == 0 and use_cluster: + gpu.cluster_barrier() + _extra_j = 0 + for _load_stage, _compute_stage, _outstanding in tail_plan: + if _load_stage is not None: + _k_off = (_tail_start + pre_loaded + _extra_j) * tile_k + copy_a_to_lds(arith.index(_k_off), stages_a_mem[_load_stage]) + copy_b_to_lds(arith.index(_k_off), stages_b_mem[_load_stage]) + _extra_j += 1 + if _outstanding == -1: + epi_addrs_box = [None] + + def _emit_epi_addrs(): + epi_addrs_box[0] = epilogue_prepare_addrs() + + accs = compute_tile( + accs, stages_a[_compute_stage], stages_b[_compute_stage], + emit_filler=_emit_epi_addrs) + else: + accs = _compute_and_schedule( + accs, stages_a[_compute_stage], stages_b[_compute_stage]) + if use_cluster and _load_stage is not None: + wait_and_cluster_barrier(outstanding=_outstanding) + else: + wait_and_barrier(outstanding=_outstanding) + + epilogue_stores(accs, epi_addrs_box[0]) + + cache_tag = (in_dtype, K, tile_m, tile_n, tile_k, m_warp, n_warp, + num_buffers, effective_waves_per_eu, l2_prefetch_distance, + cluster_m, cluster_n) + + @flyc.jit + def launch_wmma_gemm_tdm( + arg_c: fx.Tensor, + arg_a: 
fx.Tensor, + arg_b: fx.Tensor, + i32_m: fx.Int32, + i32_n: fx.Int32, + stream: fx.Stream, + ): + _ = cache_tag + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + for alloc in stage_allocators: + alloc.finalized = False + for alloc in stage_allocators: + alloc.finalize() + + idx_m = arith.index_cast(T.index, i32_m.ir_value()) + idx_n = arith.index_cast(T.index, i32_n.ir_value()) + gx = _raw((idx_m + arith.index(tile_m - 1)) / arith.index(tile_m)) + gy = _raw((idx_n + arith.index(tile_n - 1)) / arith.index(tile_n)) + + launcher = kernel_wmma_gemm_tdm(arg_c, arg_a, arg_b, i32_m, i32_n) + for op in ctx.gpu_module_body.operations: + if hasattr(op, 'attributes') and op.OPERATION_NAME == "gpu.func": + if effective_waves_per_eu is not None: + _wpe = int(effective_waves_per_eu) + if _wpe >= 1: + op.attributes["rocdl.waves_per_eu"] = ir.IntegerAttr.get( + ir.IntegerType.get_signless(32), _wpe) + if use_cluster: + op.attributes["rocdl.cluster_dims"] = ir.StringAttr.get( + f"{cluster_m},{cluster_n},1") + cluster_arg = (cluster_m, cluster_n, 1) if use_cluster else None + launcher.launch( + grid=(gx, gy, 1), + block=(block_threads, 1, 1), + stream=stream, + cluster=cluster_arg, + ) + + return launch_wmma_gemm_tdm + + +__all__ = ["compile_wmma_gemm_tdm"] diff --git a/kernels/wmma_gemm_simple.py b/kernels/wmma_gemm_simple.py new file mode 100644 index 00000000..0187369d --- /dev/null +++ b/kernels/wmma_gemm_simple.py @@ -0,0 +1,255 @@ +"""WMMA GEMM kernel for gfx1250.""" + +import flydsl.compiler as flyc +import flydsl.expr as fx + +from flydsl._mlir import ir +from flydsl.compiler.kernel_function import CompilationContext +from flydsl.expr import arith, buffer_ops, gpu, range_constexpr, rocdl, vector +from flydsl.expr.arith import _to_raw as _raw +from flydsl.expr.typing import T +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr + +from kernels.layout_utils import 
crd2idx, idx2crd, get as layout_get + +WMMA_M, WMMA_N, WMMA_K = 16, 16, 32 +WAVE_SIZE = 32 + + +def compile_wmma_gemm( + *, + M: int = 0, + N: int = 0, + K: int, + tile_m: int = 64, + tile_n: int = 128, + tile_k: int = WMMA_K, + in_dtype: str = "fp16", + block_threads: int = 128, +): + """Compile a WMMA GEMM kernel using the @flyc.kernel API. + + Returns a JitFunction that auto-compiles and executes when called. + Signature: launch_fn(arg_c, arg_a, arg_b, arg_scale_a, arg_scale_b, M, N, stream) + + Compile-time constants: K, tile_m/n/k, in_dtype (determine loop structure). + Runtime parameters: M, N (passed as i32 kernel args). + """ + _ = (M, N) + if in_dtype not in ("fp16", "bf16"): + raise ValueError(f"in_dtype must be 'fp16' or 'bf16', got {in_dtype!r}") + is_fp4 = in_dtype == "fp4" + is_int4 = in_dtype == "int4" + is_int8 = (in_dtype == "int8") or is_int4 + is_f16 = in_dtype == "fp16" + is_bf16 = in_dtype == "bf16" + is_f16_or_bf16 = is_f16 or is_bf16 + elem_bytes = 1 if (in_dtype in ("fp8", "int8", "int4", "fp4")) else 2 + + if K % tile_k != 0: + raise ValueError(f"K must be divisible by tile_k={tile_k}, got K={K}") + if tile_k % WMMA_K != 0: + raise ValueError(f"tile_k must be a multiple of {WMMA_K}, got {tile_k}") + if tile_m % WMMA_M != 0: + raise ValueError(f"tile_m must be a multiple of {WMMA_M}, got {tile_m}") + + waves_per_block = block_threads // WAVE_SIZE + if tile_n % (waves_per_block * WMMA_N) != 0: + raise ValueError( + f"tile_n must be a multiple of waves_per_block*{WMMA_N}={waves_per_block * WMMA_N}, got {tile_n}" + ) + + gpu_arch = str(get_hip_arch(timeout_s=300)) + assert gpu_arch.startswith("gfx1250"), f"Expected a gfx1250 architecture, got {gpu_arch}" + + wmma_op = rocdl.wmma_f32_16x16x32_f16 if is_f16 else rocdl.wmma_f32_16x16x32_bf16 + k_wmma_steps = tile_k // WMMA_K + + def _elem_type(): + return T.f16 if is_f16 else T.bf16 + + warp_tile_n = tile_n // waves_per_block + wmma_m_rep = tile_m // WMMA_M + wmma_n_rep = warp_tile_n // WMMA_N + 
n_accs = wmma_m_rep * wmma_n_rep + + lds_a_elems = tile_m * tile_k + lds_b_elems = tile_k * tile_n + lds_a_offset = 0 + lds_b_offset = lds_a_elems * elem_bytes + + allocator = SmemAllocator(None, arch=gpu_arch, global_sym_name="wmma_gemm_smem") + allocator.ptr = lds_b_offset + lds_b_elems * elem_bytes + + total_vec_a = tile_m * (tile_k // 4) + total_vec_b = tile_k * (tile_n // 4) + if total_vec_a % block_threads != 0 or total_vec_b % block_threads != 0: + raise ValueError( + f"vectorized copy requires vec slots divisible by block_threads: " + f"A={total_vec_a}, B={total_vec_b}, block_threads={block_threads}" + ) + vec_iters_a = total_vec_a // block_threads + vec_iters_b = total_vec_b // block_threads + + @flyc.kernel + def kernel_wmma_gemm( + arg_c: fx.Tensor, + arg_a: fx.Tensor, + arg_b: fx.Tensor, + i32_m: fx.Int32, + i32_n: fx.Int32, + ): + tx = gpu.thread_id("x") + bx = gpu.block_id("x") + by = gpu.block_id("y") + + n_stride = arith.index_cast(T.index, i32_n.ir_value()) + blk_m = bx * arith.index(tile_m) + blk_n = by * arith.index(tile_n) + + layout_thr = fx.make_layout((waves_per_block, WAVE_SIZE), (WAVE_SIZE, 1)) + layout_lane = fx.make_layout((2, 16), (16, 1)) + layout_lds_a = fx.make_layout((tile_m, tile_k), (tile_k, 1)) + layout_lds_b = fx.make_layout((tile_k, tile_n), (tile_n, 1)) + layout_vec_a = fx.make_layout((tile_m, tile_k // 4), (tile_k // 4, 1)) + layout_vec_b = fx.make_layout((tile_k, tile_n // 4), (tile_n // 4, 1)) + + thr = idx2crd(tx, layout_thr) + wave_id = layout_get(thr, 0) + lane = layout_get(thr, 1) + + lc = idx2crd(lane, layout_lane) + lane_kgrp = layout_get(lc, 0) # 0/1 + lane16 = layout_get(lc, 1) # 0..15 + warp_n_off = wave_id * arith.index(warp_tile_n) + + elem_ty = _elem_type() + base_ptr = allocator.get_base() + lds_a = SmemPtr(base_ptr, lds_a_offset, elem_ty, shape=(lds_a_elems,)) + lds_b = SmemPtr(base_ptr, lds_b_offset, elem_ty, shape=(lds_b_elems,)) + lds_a_mem = lds_a.get() + lds_b_mem = lds_b.get() + + a_rsrc = 
buffer_ops.create_buffer_resource(arg_a, max_size=True) + b_rsrc = buffer_ops.create_buffer_resource(arg_b, max_size=True) + vec4_elem_ty = T.vec(4, elem_ty) + + acc_zero = arith.constant_vector(0.0, T.vec(8, T.f32)) + accs = [acc_zero] * n_accs + + for kblk in range_constexpr(K // tile_k): + k_base = arith.index(kblk * tile_k) + + for t in range_constexpr(vec_iters_a): + vec_idx = tx + arith.index(t * block_threads) + a_crd = idx2crd(vec_idx, layout_vec_a) + a_m = layout_get(a_crd, 0) + a_kv = layout_get(a_crd, 1) + a_k = a_kv * arith.index(4) + + g_off = (blk_m + a_m) * arith.index(K) + (k_base + a_k) + v_i16 = buffer_ops.buffer_load(a_rsrc, g_off, vec_width=4, dtype=T.i16) + v = vector.bitcast(vec4_elem_ty, v_i16) + lds_off = crd2idx((a_m, a_k), layout_lds_a) + vector.store(v, lds_a_mem, [lds_off]) + + for t in range_constexpr(vec_iters_b): + vec_idx = tx + arith.index(t * block_threads) + b_crd = idx2crd(vec_idx, layout_vec_b) + b_k = layout_get(b_crd, 0) + b_nv = layout_get(b_crd, 1) + b_n = b_nv * arith.index(4) + + g_off = (k_base + b_k) * n_stride + (blk_n + b_n) + v_i16 = buffer_ops.buffer_load(b_rsrc, g_off, vec_width=4, dtype=T.i16) + v = vector.bitcast(vec4_elem_ty, v_i16) + lds_off = crd2idx((b_k, b_n), layout_lds_b) + vector.store(v, lds_b_mem, [lds_off]) + + gpu.barrier() + + for ks in range_constexpr(k_wmma_steps): + k_step = arith.index(ks * WMMA_K) + + b_frags = [] + for wn in range_constexpr(wmma_n_rep): + n_off = warp_n_off + arith.index(wn * WMMA_N) + vals = [] + for k0 in range_constexpr(2): + for k1 in range_constexpr(8): + kk = k_step + (arith.index(k0 * 2) + lane_kgrp) * arith.index(8) + arith.index(k1) + off = crd2idx((kk, n_off + lane16), layout_lds_b) + vals.append(lds_b.load([off])) + b_frags.append(vector.from_elements(T.vec(16, elem_ty), vals)) + + for wm in range_constexpr(wmma_m_rep): + m_off = arith.index(wm * WMMA_M) + a_vals = [] + for k0 in range_constexpr(2): + for k1 in range_constexpr(8): + kk = k_step + (arith.index(k0 * 2) 
+ lane_kgrp) * arith.index(8) + arith.index(k1) + off = crd2idx((m_off + lane16, kk), layout_lds_a) + a_vals.append(lds_a.load([off])) + a_frag = vector.from_elements(T.vec(16, elem_ty), a_vals) + + for wn in range_constexpr(wmma_n_rep): + acc_idx = wm * wmma_n_rep + wn + accs[acc_idx] = wmma_op( + T.vec(8, T.f32), + a_frag, + b_frags[wn], + accs[acc_idx], + signA=False, + signB=False, + modC=0, + reuseA=False, + reuseB=False, + ).result + + gpu.barrier() + + for wm in range_constexpr(wmma_m_rep): + for wn in range_constexpr(wmma_n_rep): + acc_idx = wm * wmma_n_rep + wn + m_base = blk_m + arith.index(wm * WMMA_M) + n_base = blk_n + warp_n_off + arith.index(wn * WMMA_N) + for mi in range_constexpr(8): + row = m_base + lane_kgrp * arith.index(8) + arith.index(mi) + col = n_base + lane16 + c_off = row * n_stride + col + c_val = vector.extract(accs[acc_idx], static_position=[mi], dynamic_position=[]) + fx.memref_store(c_val, arg_c, c_off) + + cache_tag = (in_dtype, K, tile_m, tile_n, tile_k, block_threads) + + @flyc.jit + def launch_wmma_gemm( + arg_c: fx.Tensor, + arg_a: fx.Tensor, + arg_b: fx.Tensor, + i32_m: fx.Int32, + i32_n: fx.Int32, + stream: fx.Stream, + ): + _ = cache_tag + allocator.finalized = False + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + allocator.finalize() + + idx_m = arith.index_cast(T.index, i32_m.ir_value()) + idx_n = arith.index_cast(T.index, i32_n.ir_value()) + gx = _raw((idx_m + arith.index(tile_m - 1)) / arith.index(tile_m)) + gy = _raw((idx_n + arith.index(tile_n - 1)) / arith.index(tile_n)) + + kernel_wmma_gemm(arg_c, arg_a, arg_b, i32_m, i32_n).launch( + grid=(gx, gy, 1), + block=(block_threads, 1, 1), + stream=stream, + ) + + return launch_wmma_gemm + + +__all__ = ["compile_wmma_gemm"] diff --git a/lib/Bindings/Python/FlyROCDLExtension.cpp b/lib/Bindings/Python/FlyROCDLExtension.cpp index 12aeb117..330e482d 100644 --- a/lib/Bindings/Python/FlyROCDLExtension.cpp +++ 
b/lib/Bindings/Python/FlyROCDLExtension.cpp @@ -99,6 +99,80 @@ struct PyMmaAtomCDNA3_MFMAType : PyConcreteType { } }; +struct PyMmaAtomGFX1250_WMMAType : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFlyROCDLMmaAtomGFX1250_WMMAType; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetTypeID; + static constexpr const char *pyClassName = "MmaAtomGFX1250_WMMAType"; + using Base::Base; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](int32_t m, int32_t n, int32_t k, PyType &elemTyA, PyType &elemTyB, PyType &elemTyAcc, + DefaultingPyMlirContext context) { + return PyMmaAtomGFX1250_WMMAType(context->getRef(), + wrap(::mlir::fly_rocdl::MmaAtomGFX1250_WMMAType::get( + m, n, k, unwrap(static_cast(elemTyA)), + unwrap(static_cast(elemTyB)), + unwrap(static_cast(elemTyAcc))))); + }, + "m"_a, "n"_a, "k"_a, "elem_ty_a"_a, "elem_ty_b"_a, "elem_ty_acc"_a, nb::kw_only(), + "context"_a = nb::none(), + "Create a MmaAtomGFX1250_WMMAType with m, n, k dimensions and element types"); + + c.def_prop_ro("m", [](PyMmaAtomGFX1250_WMMAType &self) -> int32_t { + return mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetM(self); + }); + c.def_prop_ro("n", [](PyMmaAtomGFX1250_WMMAType &self) -> int32_t { + return mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetN(self); + }); + c.def_prop_ro("k", [](PyMmaAtomGFX1250_WMMAType &self) -> int32_t { + return mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetK(self); + }); + c.def_prop_ro("elem_ty_a", [](PyMmaAtomGFX1250_WMMAType &self) -> MlirType { + return mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetElemTyA(self); + }); + c.def_prop_ro("elem_ty_b", [](PyMmaAtomGFX1250_WMMAType &self) -> MlirType { + return mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetElemTyB(self); + }); + c.def_prop_ro("elem_ty_acc", [](PyMmaAtomGFX1250_WMMAType &self) -> MlirType { + return mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetElemTyAcc(self); + }); + + c.def_prop_ro("thr_layout", [](PyMmaAtomGFX1250_WMMAType &self) -> MlirType { 
+ auto ty = + ::mlir::cast<::mlir::fly::MmaAtomTypeInterface>(unwrap(static_cast(self))); + auto attr = ::mlir::cast<::mlir::fly::LayoutAttr>(ty.getThrLayout()); + return wrap(::mlir::fly::LayoutType::get(attr)); + }); + c.def_prop_ro("shape_mnk", [](PyMmaAtomGFX1250_WMMAType &self) -> MlirType { + auto ty = + ::mlir::cast<::mlir::fly::MmaAtomTypeInterface>(unwrap(static_cast(self))); + auto attr = ::mlir::cast<::mlir::fly::IntTupleAttr>(ty.getShapeMNK()); + return wrap(::mlir::fly::IntTupleType::get(attr)); + }); + c.def_prop_ro("tv_layout_a", [](PyMmaAtomGFX1250_WMMAType &self) -> MlirType { + auto ty = + ::mlir::cast<::mlir::fly::MmaAtomTypeInterface>(unwrap(static_cast(self))); + auto attr = ::mlir::cast<::mlir::fly::LayoutAttr>(ty.getThrValLayoutA()); + return wrap(::mlir::fly::LayoutType::get(attr)); + }); + c.def_prop_ro("tv_layout_b", [](PyMmaAtomGFX1250_WMMAType &self) -> MlirType { + auto ty = + ::mlir::cast<::mlir::fly::MmaAtomTypeInterface>(unwrap(static_cast(self))); + auto attr = ::mlir::cast<::mlir::fly::LayoutAttr>(ty.getThrValLayoutB()); + return wrap(::mlir::fly::LayoutType::get(attr)); + }); + c.def_prop_ro("tv_layout_c", [](PyMmaAtomGFX1250_WMMAType &self) -> MlirType { + auto ty = + ::mlir::cast<::mlir::fly::MmaAtomTypeInterface>(unwrap(static_cast(self))); + auto attr = ::mlir::cast<::mlir::fly::LayoutAttr>(ty.getThrValLayoutC()); + return wrap(::mlir::fly::LayoutType::get(attr)); + }); + } +}; + struct PyCopyOpCDNA3BufferLDSTType : PyConcreteType { static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFlyROCDLCopyOpCDNA3BufferLDSTType; static constexpr GetTypeIDFunctionTy getTypeIdFunction = @@ -131,5 +205,6 @@ NB_MODULE(_fly_rocdl, m) { m.doc() = "MLIR Python FlyROCDL Extension"; ::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly_rocdl::PyMmaAtomCDNA3_MFMAType::bind(m); + ::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly_rocdl::PyMmaAtomGFX1250_WMMAType::bind(m); 
::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly_rocdl::PyCopyOpCDNA3BufferLDSTType::bind(m); } diff --git a/lib/CAPI/Dialect/FlyROCDL/FlyROCDLDialect.cpp b/lib/CAPI/Dialect/FlyROCDL/FlyROCDLDialect.cpp index 07bdcf51..6fc89ee7 100644 --- a/lib/CAPI/Dialect/FlyROCDL/FlyROCDLDialect.cpp +++ b/lib/CAPI/Dialect/FlyROCDL/FlyROCDLDialect.cpp @@ -55,6 +55,50 @@ MlirType mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetElemTyAcc(MlirType type) { return wrap(cast(unwrap(type)).getElemTyAcc()); } +//===----------------------------------------------------------------------===// +// MmaAtomGFX1250_WMMAType +//===----------------------------------------------------------------------===// + +bool mlirTypeIsAFlyROCDLMmaAtomGFX1250_WMMAType(MlirType type) { + return isa(unwrap(type)); +} + +MlirTypeID mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetTypeID(void) { + return wrap(MmaAtomGFX1250_WMMAType::getTypeID()); +} + +MlirType mlirFlyROCDLMmaAtomGFX1250_WMMATypeGet(int32_t m, int32_t n, int32_t k, + MlirType elemTyA, + MlirType elemTyB, + MlirType elemTyAcc) { + return wrap(MmaAtomGFX1250_WMMAType::get(m, n, k, unwrap(elemTyA), + unwrap(elemTyB), unwrap(elemTyAcc))); +} + +int32_t mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetM(MlirType type) { + return cast(unwrap(type)).getM(); +} + +int32_t mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetN(MlirType type) { + return cast(unwrap(type)).getN(); +} + +int32_t mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetK(MlirType type) { + return cast(unwrap(type)).getK(); +} + +MlirType mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetElemTyA(MlirType type) { + return wrap(cast(unwrap(type)).getElemTyA()); +} + +MlirType mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetElemTyB(MlirType type) { + return wrap(cast(unwrap(type)).getElemTyB()); +} + +MlirType mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetElemTyAcc(MlirType type) { + return wrap(cast(unwrap(type)).getElemTyAcc()); +} + //===----------------------------------------------------------------------===// // CopyOpCDNA3BufferLDSTType 
//===----------------------------------------------------------------------===// diff --git a/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp b/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp index 433ba625..9681a7cb 100644 --- a/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp +++ b/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp @@ -693,6 +693,8 @@ class MmaAtomCallLowering : public OpConversionPattern { return lowerUniversalFMA(op, rewriter, loc, universalFma, dPtr, aPtr, bPtr, cPtr); else if (auto cdna3Mfma = dyn_cast(mmaAtomType)) return lowerCDNA3MFMA(op, rewriter, loc, cdna3Mfma, dPtr, aPtr, bPtr, cPtr); + else if (auto gfx1250Wmma = dyn_cast(mmaAtomType)) + return lowerGFX1250WMMA(op, rewriter, loc, gfx1250Wmma, dPtr, aPtr, bPtr, cPtr); return rewriter.notifyMatchFailure(op, "unsupported MmaAtom type"); } @@ -849,6 +851,172 @@ class MmaAtomCallLowering : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "no matching ROCDL MFMA intrinsic"); } + + static Type getWmmaABType(MLIRContext *ctx, int32_t m, int32_t k, Type elemTy) { + if (m <= 0 || k <= 0) + return nullptr; + + Type i32Ty = IntegerType::get(ctx, 32); + + // fp8/bf8 WMMA operands are packed into i32 vectors. + if (isF8(elemTy)) { + if (k == 16) + return VectorType::get({2}, i32Ty); + if (k == 64) + return VectorType::get({8}, i32Ty); + if (k == 128) + return VectorType::get({16}, i32Ty); + return nullptr; + } + + // Integer WMMA operands are packed into i32 vectors. 
+ if (elemTy.isInteger(8)) { + if (k == 16 || k == 32) + return VectorType::get({4}, i32Ty); + if (k == 64) + return VectorType::get({8}, i32Ty); + return nullptr; + } + + int64_t abElemsPerLane = static_cast(m) * static_cast(k) / 32; + if (abElemsPerLane <= 0 || (static_cast(m) * static_cast(k)) % 32 != 0) + return nullptr; + return VectorType::get({abElemsPerLane}, elemTy); + } + + static int64_t getWmmaAccVecSize(int32_t m, int32_t k, Type elemTyA, Type elemTyB, + Type elemTyAcc) { + // Current backend wiring only dispatches ROCDL ops that exist in this + // MLIR version; keep sizing generic per supported WMMA shape/type family. + if (m != 16) + return 0; + + // NOTE: rocdl.wmma.f64.16x16x4.f64 is not exposed in the current MLIR + // ROCDL dialect build, so f64 is intentionally not dispatched here. + if (k == 4 && elemTyA.isF32() && elemTyB.isF32() && elemTyAcc.isF32()) + return 8; + + if (k == 32 && elemTyA.isF16() && elemTyB.isF16() && elemTyAcc.isF32()) + return 8; + if (k == 32 && elemTyA.isF16() && elemTyB.isF16() && elemTyAcc.isF16()) + return 8; + if (k == 32 && elemTyA.isBF16() && elemTyB.isBF16() && elemTyAcc.isF32()) + return 8; + if (k == 32 && elemTyA.isBF16() && elemTyB.isBF16() && elemTyAcc.isBF16()) + return 8; + + if (k == 64 && isF8(elemTyA) && isF8(elemTyB) && elemTyAcc.isF32()) + return 8; + if (k == 64 && isF8(elemTyA) && isF8(elemTyB) && elemTyAcc.isF16()) + return 8; + if (k == 128 && isF8(elemTyA) && isF8(elemTyB) && elemTyAcc.isF32()) + return 8; + if (k == 128 && isF8(elemTyA) && isF8(elemTyB) && elemTyAcc.isF16()) + return 8; + + if (k == 64 && elemTyA.isInteger(8) && elemTyB.isInteger(8) && elemTyAcc.isInteger(32)) + return 8; + + return 0; + } + + enum class WmmaVariant { ModsAllReuse, ModsC, ModsABClamp }; + + template + LogicalResult emitWmma(MmaAtomCall op, ConversionPatternRewriter &rewriter, Location loc, + Type abTyA, Type abTyB, VectorType accTy, Value aPtr, Value bPtr, + Value cPtr, Value dPtr) const { + Value a = 
LLVM::LoadOp::create(rewriter, loc, abTyA, aPtr); + Value b = LLVM::LoadOp::create(rewriter, loc, abTyB, bPtr); + Value c = LLVM::LoadOp::create(rewriter, loc, accTy, cPtr); + Value res; + if constexpr (Variant == WmmaVariant::ModsAllReuse) { + res = WmmaOp::create(rewriter, loc, accTy, + /*signA=*/false, a, /*signB=*/false, b, + /*modC=*/(uint16_t)0, c) + .getResult(); + } else if constexpr (Variant == WmmaVariant::ModsC) { + res = WmmaOp::create(rewriter, loc, accTy, a, b, + /*modC=*/(uint16_t)0, c, + /*reuseA=*/false, /*reuseB=*/false) + .getResult(); + } else { + static_assert(Variant == WmmaVariant::ModsABClamp); + res = WmmaOp::create(rewriter, loc, accTy, + /*signA=*/false, a, /*signB=*/false, b, c, + /*reuseA=*/false, /*reuseB=*/false, /*clamp=*/false) + .getResult(); + } + LLVM::StoreOp::create(rewriter, loc, res, dPtr); + rewriter.eraseOp(op); + return success(); + } + + LogicalResult lowerGFX1250WMMA(MmaAtomCall op, ConversionPatternRewriter &rewriter, Location loc, + fly_rocdl::MmaAtomGFX1250_WMMAType atomTy, Value dPtr, Value aPtr, + Value bPtr, Value cPtr) const { + int32_t m = atomTy.getM(); + int32_t n = atomTy.getN(); + int32_t k = atomTy.getK(); + Type elemTyA = atomTy.getElemTyA(); + Type elemTyB = atomTy.getElemTyB(); + Type elemTyAcc = atomTy.getElemTyAcc(); + MLIRContext *ctx = rewriter.getContext(); + + Type abTyA = getWmmaABType(ctx, m, k, elemTyA); + Type abTyB = getWmmaABType(ctx, m, k, elemTyB); + if (!abTyA || !abTyB) + return rewriter.notifyMatchFailure(op, "unsupported A/B element packing for WMMA"); + + int64_t accVecSize = getWmmaAccVecSize(m, k, elemTyA, elemTyB, elemTyAcc); + if (accVecSize == 0) + return rewriter.notifyMatchFailure(op, "unsupported MNK/type combination for WMMA"); + + VectorType accTy = VectorType::get({accVecSize}, elemTyAcc); + +#define DISPATCH_WMMA(M_, K_, PRED, OP, VARIANT) \ + if (m == M_ && n == M_ && k == K_ && (PRED)) \ + return emitWmma(op, rewriter, loc, abTyA, abTyB, accTy, \ + aPtr, bPtr, cPtr, dPtr); 
+ +#define DISPATCH_WMMA_FP8(K_, ACC_PRED, ACC_PREFIX) \ + DISPATCH_WMMA(16, K_, isFP8(elemTyA) && isFP8(elemTyB) && ACC_PRED, \ + wmma_##ACC_PREFIX##_16x16x##K_##_fp8_fp8, ModsC) \ + DISPATCH_WMMA(16, K_, isFP8(elemTyA) && isBF8(elemTyB) && ACC_PRED, \ + wmma_##ACC_PREFIX##_16x16x##K_##_fp8_bf8, ModsC) \ + DISPATCH_WMMA(16, K_, isBF8(elemTyA) && isFP8(elemTyB) && ACC_PRED, \ + wmma_##ACC_PREFIX##_16x16x##K_##_bf8_fp8, ModsC) \ + DISPATCH_WMMA(16, K_, isBF8(elemTyA) && isBF8(elemTyB) && ACC_PRED, \ + wmma_##ACC_PREFIX##_16x16x##K_##_bf8_bf8, ModsC) + + DISPATCH_WMMA(16, 4, elemTyA.isF32() && elemTyB.isF32() && elemTyAcc.isF32(), + wmma_f32_16x16x4_f32, ModsAllReuse) + + DISPATCH_WMMA(16, 32, elemTyA.isF16() && elemTyB.isF16() && elemTyAcc.isF32(), + wmma_f32_16x16x32_f16, ModsAllReuse) + DISPATCH_WMMA(16, 32, elemTyA.isBF16() && elemTyB.isBF16() && elemTyAcc.isF32(), + wmma_f32_16x16x32_bf16, ModsAllReuse) + DISPATCH_WMMA(16, 32, elemTyA.isF16() && elemTyB.isF16() && elemTyAcc.isF16(), + wmma_f16_16x16x32_f16, ModsAllReuse) + DISPATCH_WMMA(16, 32, elemTyA.isBF16() && elemTyB.isBF16() && elemTyAcc.isBF16(), + wmma_bf16_16x16x32_bf16, ModsAllReuse) + + // bf16f32 WMMA requires C:f32 and D:bf16. Current MmaAtom interface carries + // one accumulator type, so mixed C/D typing is not representable yet. 
+ + DISPATCH_WMMA_FP8(64, elemTyAcc.isF32(), f32) + DISPATCH_WMMA_FP8(64, elemTyAcc.isF16(), f16) + DISPATCH_WMMA_FP8(128, elemTyAcc.isF32(), f32) + DISPATCH_WMMA_FP8(128, elemTyAcc.isF16(), f16) + + DISPATCH_WMMA(16, 64, elemTyA.isInteger(8) && elemTyB.isInteger(8) && elemTyAcc.isInteger(32), + wmma_i32_16x16x64_iu8, ModsABClamp) + +#undef DISPATCH_WMMA_FP8 +#undef DISPATCH_WMMA + + return rewriter.notifyMatchFailure(op, "no matching ROCDL WMMA intrinsic"); + } }; /// Lower `gpu.launch_func` kernel operands so that any `!fly.memref` values are diff --git a/lib/Dialect/FlyROCDL/CMakeLists.txt b/lib/Dialect/FlyROCDL/CMakeLists.txt index 0151891b..f8283598 100644 --- a/lib/Dialect/FlyROCDL/CMakeLists.txt +++ b/lib/Dialect/FlyROCDL/CMakeLists.txt @@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRFlyROCDLDialect Dialect.cpp CDNA3/MmaAtom.cpp CDNA3/CopyAtom.cpp + GFX1250/MmaAtom.cpp DEPENDS MLIRFlyROCDLIncGen diff --git a/lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp b/lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp new file mode 100644 index 00000000..e8af1e49 --- /dev/null +++ b/lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp @@ -0,0 +1,163 @@ +#include "flydsl/Dialect/Fly/IR/FlyDialect.h" +#include "flydsl/Dialect/FlyROCDL/IR/Dialect.h" +#include "mlir/IR/BuiltinTypes.h" + +#include "flydsl/Dialect/Fly/Utils/ThrValLayoutMacro.h.inc" + +using namespace mlir; +using namespace mlir::fly; + +namespace gfx1250 { + +static int getElemBits(Type ty) { + if (ty.isF32() || ty.isInteger(32)) + return 32; + if (ty.isF16() || ty.isBF16()) + return 16; + if (isa(ty) || isa(ty) || + ty.isInteger(8)) + return 8; + return 0; +} + +// A/B matrix register layout for GFX1250 WMMA (wave32). +// +// The A matrix is M×K (M=16, K varies by instruction). The 32 lanes split +// into two groups of 16 (group = lane/16). Both groups hold different slices +// of the K dimension. +// +// For 32-bit elements (f32, K=4): +// Each lane holds K/2 values. Group g covers K = g*(K/2) .. (g+1)*(K/2)-1. 
+// No sub-element packing. 2 VGPRs per lane. +// Formula: K = (l/16)*2 + v +// +// For sub-32-bit elements (f16/bf16 K=32, fp8/bf8/i8 K=64/128): +// Each lane holds K/2 values, organized in blocks of 8. Within each +// block, group 0 holds the lower 8 K-values, group 1 holds the upper 8. +// Formula: K = block*16 + (l/16)*8 + within_block +// where block = flat_val / 8, within_block = flat_val % 8. +// +// Reference space is column-major (M,K) with stride (1, M=16). +// The B matrix (N×K) uses the identical layout with N substituted for M. +LayoutAttr getThrValLayoutAB(MLIRContext *ctx, int32_t K, Type elemTy) { + auto getContext = [&]() { return ctx; }; + + int elemBits = getElemBits(elemTy); + int valsPerLane = K / 2; + + if (elemBits == 32) { + // f32 16×4: 2 values/lane, no sub-element packing. + // pos = (l%16)*1 + (l/16)*(valsPerLane*16) + v*16 + return FxLayout(FxShape(FxThr(16, 2), FxVal(valsPerLane)), + FxStride(FxThr(1, valsPerLane * 16), FxVal(16))); + } + + // Sub-32-bit: interleaving block of 8 values between lane groups. + // pos = (l%16)*1 + (l/16)*128 + val_within*16 [+ block*256] + int numBlocks = valsPerLane / 8; + if (numBlocks == 1) { + return FxLayout(FxShape(FxThr(16, 2), FxVal(8)), + FxStride(FxThr(1, 128), FxVal(16))); + } + return FxLayout(FxShape(FxThr(16, 2), FxVal(8, numBlocks)), + FxStride(FxThr(1, 128), FxVal(16, 256))); +} + +// C/D matrix register layout for GFX1250 WMMA (wave32). +// +// C/D is always 16×16 (M×N). Lane l covers N = l%16. The two lane groups +// cover M=0..7 (group 0) and M=8..15 (group 1). +// +// 32-bit accumulator (f32, i32): 8 VGPRs, one element per VGPR. +// M = (l/16)*8 + v +// +// 16-bit accumulator (f16, bf16): 4 VGPRs, two packed sub-elements each. +// M = (l/16)*8 + v*2 + s +// +// Reference space is column-major (M,N) with stride (1, M=16). 
+LayoutAttr getThrValLayoutCD(MLIRContext *ctx, Type elemTyAcc) { + auto getContext = [&]() { return ctx; }; + + int elemBits = getElemBits(elemTyAcc); + if (elemBits >= 32) { + return FxLayout(FxShape(FxThr(16, 2), FxVal(8)), + FxStride(FxThr(16, 8), FxVal(1))); + } + // 16-bit: 4 VGPRs × 2 sub-elements = 8 values. + return FxLayout(FxShape(FxThr(16, 2), FxVal(4, 2)), + FxStride(FxThr(16, 8), FxVal(2, 1))); +} + +} // namespace gfx1250 + +namespace mlir::fly_rocdl { + +bool MmaAtomGFX1250_WMMAType::isStatic() const { return true; } + +Type MmaAtomGFX1250_WMMAType::getValTypeA() const { return getElemTyA(); } +Type MmaAtomGFX1250_WMMAType::getValTypeB() const { return getElemTyB(); } +Type MmaAtomGFX1250_WMMAType::getValTypeC() const { return getElemTyAcc(); } +Type MmaAtomGFX1250_WMMAType::getValTypeD() const { return getElemTyAcc(); } + +Attribute MmaAtomGFX1250_WMMAType::getThrLayout() const { + return FxLayout(FxC(32), FxC(1)); +} + +Attribute MmaAtomGFX1250_WMMAType::getShapeMNK() const { + return IntTupleAttr::get( + ArrayAttr::get(getContext(), {FxC(getM()), FxC(getN()), FxC(getK())})); +} + +Attribute MmaAtomGFX1250_WMMAType::getThrValLayoutA() const { + return gfx1250::getThrValLayoutAB(getContext(), getK(), getElemTyA()); +} + +Attribute MmaAtomGFX1250_WMMAType::getThrValLayoutB() const { + return gfx1250::getThrValLayoutAB(getContext(), getK(), getElemTyB()); +} + +Attribute MmaAtomGFX1250_WMMAType::getThrValLayoutC() const { + return gfx1250::getThrValLayoutCD(getContext(), getElemTyAcc()); +} + +LogicalResult +MmaAtomGFX1250_WMMAType::verify(function_ref emitError, + int32_t m, int32_t n, int32_t k, Type elemTyA, + Type elemTyB, Type elemTyAcc) { + if (m != 16 || n != 16) + return emitError() << "GFX1250 WMMA requires M=N=16, got " << m << "x" + << n; + + auto isF8 = [](Type ty) { + return isa(ty) || isa(ty); + }; + + bool valid = false; + + if (k == 4 && elemTyA.isF32() && elemTyB.isF32() && elemTyAcc.isF32()) + valid = true; + + if (k == 32 && 
elemTyA.isF16() && elemTyB.isF16() && + (elemTyAcc.isF32() || elemTyAcc.isF16())) + valid = true; + if (k == 32 && elemTyA.isBF16() && elemTyB.isBF16() && + (elemTyAcc.isF32() || elemTyAcc.isBF16())) + valid = true; + + if ((k == 64 || k == 128) && isF8(elemTyA) && isF8(elemTyB) && + (elemTyAcc.isF32() || elemTyAcc.isF16())) + valid = true; + + if (k == 64 && elemTyA.isInteger(8) && elemTyB.isInteger(8) && + elemTyAcc.isInteger(32)) + valid = true; + + if (!valid) { + return emitError() << "unsupported GFX1250 WMMA configuration: " << m + << "x" << n << "x" << k << " with A=" << elemTyA + << ", B=" << elemTyB << ", Acc=" << elemTyAcc; + } + return success(); +} + +} // namespace mlir::fly_rocdl diff --git a/lib/Runtime/FlyRocmRuntimeWrappers.cpp b/lib/Runtime/FlyRocmRuntimeWrappers.cpp index 2981e219..f4037620 100644 --- a/lib/Runtime/FlyRocmRuntimeWrappers.cpp +++ b/lib/Runtime/FlyRocmRuntimeWrappers.cpp @@ -66,6 +66,61 @@ extern "C" void mgpuLaunchKernel(hipFunction_t function, intptr_t gridX, stream, params, extra)); } +extern "C" void mgpuLaunchClusterKernel(hipFunction_t function, + intptr_t clusterX, intptr_t clusterY, + intptr_t clusterZ, + intptr_t gridX, intptr_t gridY, + intptr_t gridZ, + intptr_t blockX, intptr_t blockY, + intptr_t blockZ, int32_t smem, + hipStream_t stream, void **params, + void **extra, size_t /*paramsCount*/) { + hipLaunchAttribute attrs[1]; + attrs[0].id = hipLaunchAttributeClusterDimension; + attrs[0].value.clusterDim.x = static_cast(clusterX); + attrs[0].value.clusterDim.y = static_cast(clusterY); + attrs[0].value.clusterDim.z = static_cast(clusterZ); + + HIP_LAUNCH_CONFIG config{}; + config.gridDimX = static_cast(gridX); + config.gridDimY = static_cast(gridY); + config.gridDimZ = static_cast(gridZ); + config.blockDimX = static_cast(blockX); + config.blockDimY = static_cast(blockY); + config.blockDimZ = static_cast(blockZ); + config.sharedMemBytes = static_cast(smem); + config.hStream = stream; + config.attrs = attrs; + 
config.numAttrs = 1; + + hipError_t err = hipDrvLaunchKernelEx(&config, function, params, extra); + if (err == hipSuccess) + return; + + const bool requestedRealCluster = + (clusterX > 1) || (clusterY > 1) || (clusterZ > 1); + if (requestedRealCluster) { + fprintf(stderr, + "[mgpuLaunchClusterKernel] hipDrvLaunchKernelEx failed (err=%d) " + "for requested cluster=(%ld,%ld,%ld); not falling back to " + "hipModuleLaunchKernel.\n", + static_cast(err), static_cast(clusterX), + static_cast(clusterY), static_cast(clusterZ)); + HIP_REPORT_IF_ERROR(err); + HIP_REPORT_IF_ERROR(hipModuleLaunchKernel(function, 0, 0, 0, 0, 0, 0, smem, + stream, params, extra)); + return; + } + + fprintf(stderr, + "[mgpuLaunchClusterKernel] hipDrvLaunchKernelEx failed (err=%d) " + "for cluster=(1,1,1); falling back to hipModuleLaunchKernel.\n", + static_cast(err)); + HIP_REPORT_IF_ERROR(hipModuleLaunchKernel(function, gridX, gridY, gridZ, + blockX, blockY, blockZ, smem, + stream, params, extra)); +} + extern "C" hipStream_t mgpuStreamCreate() { hipStream_t stream = nullptr; HIP_REPORT_IF_ERROR(hipStreamCreate(&stream)); diff --git a/python/flydsl/__init__.py b/python/flydsl/__init__.py index dbd27816..9f9e3849 100644 --- a/python/flydsl/__init__.py +++ b/python/flydsl/__init__.py @@ -1,5 +1,38 @@ +import ctypes +import os + _BASE_VERSION = "0.1.0" + +# Workaround: resolve FFM simulator "LLVM ERROR: Option 'greedy' already exists!" 
+def _maybe_preload_system_comgr() -> None: + disable = os.environ.get("FLYDSL_DISABLE_COMGR_PRELOAD", "").strip().lower() + if disable in {"1", "true", "yes", "on"}: + return + + model_path = os.environ.get("GFX1250_MODEL_PATH", "") + hsa_model_lib = os.environ.get("HSA_MODEL_LIB", "") + in_ffm_session = ("ffm-lite" in hsa_model_lib) or ("ffmlite" in model_path) + if not in_ffm_session: + return + + system_comgr = os.environ.get( + "FLYDSL_COMGR_PRELOAD_PATH", "/opt/rocm/lib/libamd_comgr.so.3" + ) + sim_comgr = os.path.join(model_path, "rocm", "libamd_comgr.so.3") + if not (os.path.exists(system_comgr) and os.path.exists(sim_comgr)): + return + + mode = getattr(os, "RTLD_NOW", 0) | getattr(os, "RTLD_GLOBAL", 0) + try: + ctypes.CDLL(system_comgr, mode=mode) + except OSError: + # Keep import robust if the host ROCm stack differs. + pass + + +_maybe_preload_system_comgr() + try: from ._version import __version__ except ImportError: diff --git a/python/flydsl/compiler/jit_function.py b/python/flydsl/compiler/jit_function.py index d1876427..5bd12ebb 100644 --- a/python/flydsl/compiler/jit_function.py +++ b/python/flydsl/compiler/jit_function.py @@ -316,6 +316,13 @@ def _pipeline_fragments(*, chip: str) -> list: "gpu-module-to-binary{format=fatbin}", ] + @staticmethod + def _use_wave64(chip: str) -> bool: + chip = str(chip) + if chip.startswith("gfx12"): + return False + return True + @classmethod def compile(cls, module: ir.Module, *, chip: str = None, func_name: str = "") -> ir.Module: module.operation.verify() diff --git a/python/flydsl/compiler/kernel_function.py b/python/flydsl/compiler/kernel_function.py index ee41bd39..cadd51fb 100644 --- a/python/flydsl/compiler/kernel_function.py +++ b/python/flydsl/compiler/kernel_function.py @@ -237,6 +237,7 @@ def launch( block: DimType = (1, 1, 1), smem: Union[int, ir.Value] = 0, stream: Optional[ir.Value] = None, + cluster: Optional[DimType] = None, ) -> None: """Emit gpu.launch_func operation with the given configuration. 
@@ -245,6 +246,8 @@ def launch( block: Block dimensions (x, y, z). Can be int, ir.Value, tuple, or list. smem: Dynamic shared memory size in bytes. Can be int or ir.Value. stream: CUDA/HIP stream as ir.Value. None means default stream. + cluster: Cluster dimensions (x, y, z) for workgroup clustering. + None means no clustering. Enables MCAST and cluster barriers. """ launch_loc = create_caller_location(depth=2) @@ -277,6 +280,15 @@ def launch( async_deps = [stream_val] if stream_val is not None else None + cluster_size = None + if cluster is not None: + cx, cy, cz = _normalize_dim(cluster) + cluster_size = ( + _to_index_value(cx), + _to_index_value(cy), + _to_index_value(cz), + ) + gpu.LaunchFuncOp( ["kernels", self._kernel_name], (grid_x, grid_y, grid_z), @@ -284,6 +296,7 @@ def launch( kernel_operands, async_dependencies=async_deps, dynamic_shared_memory_size=smem_val, + cluster_size=cluster_size, loc=launch_loc, ip=None, ) diff --git a/python/flydsl/expr/__init__.py b/python/flydsl/expr/__init__.py index 892eb6bf..b654baa4 100644 --- a/python/flydsl/expr/__init__.py +++ b/python/flydsl/expr/__init__.py @@ -4,4 +4,4 @@ from .gpu import * from .derived import * -from . import arith, vector, gpu, buffer_ops, rocdl +from . import arith, vector, gpu, buffer_ops, rocdl, tdm_ops diff --git a/python/flydsl/expr/gpu.py b/python/flydsl/expr/gpu.py index be4d08e7..cac80ea2 100644 --- a/python/flydsl/expr/gpu.py +++ b/python/flydsl/expr/gpu.py @@ -14,9 +14,12 @@ """ from .._mlir import ir -from .._mlir.dialects import gpu +from .._mlir.dialects import gpu, rocdl, scf from .._mlir.ir import Attribute from .typing import Tuple3D +from . import arith as _arith_ext +from . 
import rocdl as _rocdl_ext +from .typing import T thread_id = gpu.thread_id block_id = gpu.block_id @@ -52,6 +55,106 @@ class SharedAllocator: pass +# ========================================================================= +# Cluster operations (gfx1250 workgroup clustering) +# ========================================================================= + +CLUSTER_BARRIER_ID = -3 +# For cluster sync, wait on the cluster user barrier itself. +CLUSTER_WAIT_ALL = CLUSTER_BARRIER_ID + + +def is_wave_leader(): + """Return true for wave-0 inside the workgroup.""" + return _arith_ext.cmpi( + _arith_ext.CmpIPredicate.eq, + _rocdl_ext.wave_id(), + _arith_ext.constant(0, type=T.i32), + ) + + +def cluster_signal_once_per_wg(): + """Signal cluster barrier from exactly one wave per workgroup.""" + if_op = scf.IfOp(is_wave_leader(), [], has_else=False, loc=ir.Location.unknown()) + if len(if_op.regions[0].blocks) == 0: + if_op.regions[0].blocks.append(*[]) + with ir.InsertionPoint(if_op.regions[0].blocks[0]): + rocdl.s_barrier_signal(CLUSTER_BARRIER_ID) + scf.YieldOp([]) + + +def cluster_wait(): + """Wait on the cluster user barrier.""" + rocdl.s_barrier_wait(CLUSTER_WAIT_ALL) + + +def cluster_barrier(): + """Workgroup + cluster barrier with one-wave signal semantics. + + This is the safe default for kernels using cluster multicast: + 1) synchronize waves inside each workgroup + 2) signal cluster barrier once per workgroup (wave-0 only) + 3) wait for all workgroups in the cluster + """ + gpu.barrier() + cluster_signal_once_per_wg() + cluster_wait() + + +def compute_cluster_position(): + """Compute a workgroup's (row, col) position within its cluster. + + Returns: + (local_x, local_y) as MLIR index values — position within the cluster. 
+ """ + local_x = _arith_ext.index_cast(T.index, _rocdl_ext.cluster_workgroup_id_x()) + local_y = _arith_ext.index_cast(T.index, _rocdl_ext.cluster_workgroup_id_y()) + return local_x, local_y + + +def compute_mcast_masks(local_x, local_y, cluster_m: _int, cluster_n: _int): + """Compute MCAST workgroup_mask values for A and B matrices. + + Hardware flat WG index within a cluster uses X-inner ordering + (MI400 Shader Programming, TTMP6 layout, section 3.5.5.1): + + flat_wg_id = wg_x + wg_y * nwg_x = local_x + local_y * cluster_m + + where cluster_dims = (cluster_m, cluster_n, 1), so nwg_x = cluster_m. + + A mask: WGs sharing the same M-tile row (same local_x, varying local_y). + Bits: {local_x + ly * cluster_m : ly in 0..cluster_n-1} + B mask: WGs sharing the same N-tile column (same local_y, varying local_x). + Bits: {lx + local_y * cluster_m : lx in 0..cluster_m-1} + + Args: + local_x: WG row within cluster (MLIR index, 0..cluster_m-1). + local_y: WG column within cluster (MLIR index, 0..cluster_n-1). + cluster_m: Cluster rows (Python int). + cluster_n: Cluster columns (Python int). + + Returns: + (a_mask, b_mask) as MLIR i32 values for TDM workgroup_mask. 
+ """ + local_x_i32 = _arith_ext.index_cast(T.i32, local_x) + local_y_i32 = _arith_ext.index_cast(T.i32, local_y) + cluster_m_i32 = _arith_ext.constant(cluster_m, type=T.i32) + + # A mask: pattern has bits at strides of cluster_m, shifted by local_x + a_pattern_val = 0 + for ly in range(cluster_n): + a_pattern_val |= (1 << (ly * cluster_m)) + a_pattern = _arith_ext.constant(a_pattern_val, type=T.i32) + a_mask = _arith_ext.shli(a_pattern, local_x_i32) + + # B mask: cluster_m contiguous low bits, shifted by local_y * cluster_m + b_pattern = _arith_ext.constant((1 << cluster_m) - 1, type=T.i32) + col_base = _arith_ext.muli(local_y_i32, cluster_m_i32) + b_mask = _arith_ext.shli(b_pattern, col_base) + + return a_mask, b_mask + + __all__ = [ "thread_id", "block_id", @@ -63,4 +166,12 @@ class SharedAllocator: "smem_space", "lds_space", "SharedAllocator", + "is_wave_leader", + "cluster_signal_once_per_wg", + "cluster_wait", + "cluster_barrier", + "compute_cluster_position", + "compute_mcast_masks", + "CLUSTER_BARRIER_ID", + "CLUSTER_WAIT_ALL", ] diff --git a/python/flydsl/expr/rocdl.py b/python/flydsl/expr/rocdl.py index cb1656b4..27c3fada 100644 --- a/python/flydsl/expr/rocdl.py +++ b/python/flydsl/expr/rocdl.py @@ -18,6 +18,7 @@ from .._mlir._mlir_libs._fly_rocdl import CopyOpCDNA3BufferLDSTType from .._mlir._mlir_libs._fly_rocdl import MmaAtomCDNA3_MFMAType +from .._mlir._mlir_libs._fly_rocdl import MmaAtomGFX1250_WMMAType BufferLDST = lambda bit_size: CopyOpCDNA3BufferLDSTType.get(bit_size) # noqa: E731 BufferLDST32b = lambda: CopyOpCDNA3BufferLDSTType.get(32) # noqa: E731 @@ -49,6 +50,29 @@ def MFMA(m, n, k, elem_type, elem_type_b=None, elem_type_acc=None): return MmaAtomCDNA3_MFMAType.get(m, n, k, ty, ty_b, ty_acc) +def WMMA(m, n, k, elem_type, elem_type_b=None, elem_type_acc=None): + """Create a WMMA MMA atom type for GFX1250 (wave32). + + Args: + m, n, k: WMMA tile dimensions. + elem_type: Element type for A operand. 
+ elem_type_b: Element type for B operand (defaults to elem_type). + elem_type_acc: Element type for accumulator (defaults to elem_type). + """ + from .._mlir import ir + + if isinstance(elem_type, type) and hasattr(elem_type, 'ir_type'): + ty = elem_type.ir_type + elif isinstance(elem_type, ir.Type): + ty = elem_type + else: + raise TypeError(f"WMMA: unsupported elem_type {elem_type}") + + ty_b = ty if elem_type_b is None else (elem_type_b.ir_type if hasattr(elem_type_b, 'ir_type') else elem_type_b) + ty_acc = ty if elem_type_acc is None else (elem_type_acc.ir_type if hasattr(elem_type_acc, 'ir_type') else elem_type_acc) + return MmaAtomGFX1250_WMMAType.get(m, n, k, ty, ty_b, ty_acc) + + def make_buffer_tensor(memref, alignment=4, loc=None, ip=None): """Convert a global-address-space fly memref to a buffer_desc memref. @@ -84,6 +108,21 @@ def make_buffer_tensor(memref, alignment=4, loc=None, ip=None): return _prim.make_view(bd_ptr, layout, loc=loc, ip=ip) # Keep references to ODS-generated builders so we can wrap them without losing access. +_ods_wmma_scale_f32_16x16x128_f8f6f4 = ( + globals().get("wmma_scale_f32_16x16x128_f8f6f4", None) +) +_ods_wmma_scale_f32_32x16x128_f4 = ( + globals().get("wmma_scale_f32_32x16x128_f4", None) +) +_ods_wave_id = wave_id # ODS: wave_id(res, ...) 
-> i32 +_ods_cluster_workgroup_id_x = cluster_workgroup_id_x +_ods_cluster_workgroup_id_y = cluster_workgroup_id_y +_ods_cluster_workgroup_id_z = cluster_workgroup_id_z +_ods_cluster_load_async_to_lds_b8 = cluster_load_async_to_lds_b8 +_ods_cluster_load_async_to_lds_b32 = cluster_load_async_to_lds_b32 +_ods_cluster_load_async_to_lds_b64 = cluster_load_async_to_lds_b64 +_ods_cluster_load_async_to_lds_b128 = cluster_load_async_to_lds_b128 +_ods_s_wait_asynccnt = s_wait_asynccnt _ods_mfma_f32_16x16x16f16 = mfma_f32_16x16x16f16 _ods_mfma_f32_16x16x16bf16_1k = globals().get("mfma_f32_16x16x16bf16_1k", None) _ods_mfma_f32_16x16x32_fp8_fp8 = mfma_f32_16x16x32_fp8_fp8 @@ -293,6 +332,190 @@ def wmma_i32_16x16x32_iu4(result_type, operands, *, loc=None, ip=None): return _ods_wmma_i32_16x16x32_iu4(result_type, ops, loc=loc, ip=ip).result +# --- WMMA Scale variants (gfx1250 mxfp4) --- + +def wmma_scale_f32_16x16x128_f8f6f4(result_type, a, b, c, scaleA, scaleB, + *, fmtA=4, fmtB=4, modC=0, + scaleAType=0, fmtScaleA=0, + scaleBType=0, fmtScaleB=0, + reuseA=False, reuseB=False, + loc=None, ip=None): + """V_WMMA_SCALE_F32_16X16X128_F8F6F4 for gfx1250 (wave32). 
+ + Operand types (wave32): + a: vector<8xi32> (16x128 FP4 data) + b: vector<8xi32> (128x16 FP4 data) + c: vector<8xf32> (16x16 FP32 accumulator) + scaleA: i32 (A scale VGPR) + scaleB: i32 (B scale VGPR) + + fmtA/fmtB: data type encoding (0=FP8/E4M3, 1=FP8/E5M2, 2=FP6/E2M3, 3=FP6/E3M2, 4=FP4/E2M1) + scaleAType/scaleBType: opsel – selects lo/hi 16-bit half of scale VGPR (0=lo, 1=hi) + fmtScaleA/fmtScaleB: scale format (0=E8M0, 1=E5M3, 2=E4M3) + """ + if _ods_wmma_scale_f32_16x16x128_f8f6f4 is None: + raise AttributeError("ROCDL op not found: wmma_scale_f32_16x16x128_f8f6f4") + a_v = _unwrap_mfma_operand(a, loc=loc) + b_v = _unwrap_mfma_operand(b, loc=loc) + c_v = _unwrap_mfma_operand(c, loc=loc) + sA = _unwrap_mfma_operand(scaleA, loc=loc) + sB = _unwrap_mfma_operand(scaleB, loc=loc) + return _ods_wmma_scale_f32_16x16x128_f8f6f4( + result_type, a_v, b_v, c_v, sA, sB, + fmtA=fmtA, fmtB=fmtB, modC=modC, + scaleAType=scaleAType, fmtScaleA=fmtScaleA, + scaleBType=scaleBType, fmtScaleB=fmtScaleB, + reuseA=reuseA, reuseB=reuseB, + loc=loc, ip=ip, + ).result + + +def wmma_scale_f32_32x16x128_f4(result_type, a, b, c, scaleA, scaleB, + *, modC=0, + scaleAType=0, fmtScaleA=0, + scaleBType=0, fmtScaleB=0, + reuseA=False, reuseB=False, + loc=None, ip=None): + """V_WMMA_SCALE_F32_32X16X128_F4 for gfx1250 (wave32). 
+ + Operand types (wave32): + a: vector<16xi32> (32x128 FP4 data) + b: vector<8xi32> (128x16 FP4 data) + c: vector<16xf32> (32x16 FP32 accumulator) + scaleA: i32 (A scale VGPR) + scaleB: i32 (B scale VGPR) + """ + if _ods_wmma_scale_f32_32x16x128_f4 is None: + raise AttributeError("ROCDL op not found: wmma_scale_f32_32x16x128_f4") + a_v = _unwrap_mfma_operand(a, loc=loc) + b_v = _unwrap_mfma_operand(b, loc=loc) + c_v = _unwrap_mfma_operand(c, loc=loc) + sA = _unwrap_mfma_operand(scaleA, loc=loc) + sB = _unwrap_mfma_operand(scaleB, loc=loc) + return _ods_wmma_scale_f32_32x16x128_f4( + result_type, a_v, b_v, c_v, sA, sB, + modC=modC, + scaleAType=scaleAType, fmtScaleA=fmtScaleA, + scaleBType=scaleBType, fmtScaleB=fmtScaleB, + reuseA=reuseA, reuseB=reuseB, + loc=loc, ip=ip, + ).result +def wave_id(): + """Get wave-id-in-workgroup as SGPR (via TTMP8[29:25]). + + On gfx1250 this reads an architected SGPR, so the result stays in + the SGPR pipeline and all derived computations are automatically + scalarized by LLVM uniformity analysis. + + Returns: + i32 value (SGPR) with the wave ID within the workgroup. + """ + from .._mlir import ir + i32 = ir.IntegerType.get_signless(32) + return _ods_wave_id(i32) + + +def cluster_workgroup_id_x(): + """Get workgroup position within cluster along X (SGPR, gfx1250). """ + from .._mlir import ir + i32 = ir.IntegerType.get_signless(32) + return _ods_cluster_workgroup_id_x(i32) + + +def cluster_workgroup_id_y(): + """Get workgroup position within cluster along Y (SGPR, gfx1250). """ + from .._mlir import ir + i32 = ir.IntegerType.get_signless(32) + return _ods_cluster_workgroup_id_y(i32) + + +def cluster_workgroup_id_z(): + """Get workgroup position within cluster along Z (SGPR, gfx1250). 
""" + from .._mlir import ir + i32 = ir.IntegerType.get_signless(32) + return _ods_cluster_workgroup_id_z(i32) + + +def cluster_load_async_to_lds(global_ptr, lds_ptr, size_bytes, offset=0, cpol=0, mask=None): + """Per-lane cluster broadcast load: Global -> LDS with MCAST (gfx1250). + + Args: + global_ptr: ``!llvm.ptr<1>`` — global address space pointer. + lds_ptr: ``!llvm.ptr<3>`` — LDS address space pointer. + size_bytes: Load width: 1, 4, 8, or 16 bytes (selects b8/b32/b64/b128). + offset: Byte offset (int, default 0). + cpol: Cache policy (int, default 0). + mask: i32 workgroup_mask for MCAST broadcast. None means no mask + (falls back to non-cluster global_load_async_to_lds). + + Raises: + ValueError: If ``size_bytes`` is not 1, 4, 8, or 16. + """ + _dispatch = { + 1: _ods_cluster_load_async_to_lds_b8, + 4: _ods_cluster_load_async_to_lds_b32, + 8: _ods_cluster_load_async_to_lds_b64, + 16: _ods_cluster_load_async_to_lds_b128, + } + fn = _dispatch.get(size_bytes) + if fn is None: + raise ValueError( + f"cluster_load_async_to_lds: size_bytes must be 1, 4, 8, or 16, " + f"got {size_bytes}") + if mask is None: + from .._mlir import ir + from . import arith as _arith + mask = _arith.unwrap(_arith.constant(0, type=ir.IntegerType.get_signless(32))) + fn(global_ptr, lds_ptr, offset, cpol, mask) + + +def s_wait_asynccnt(count=0): + """Wait for outstanding async load/store operations (ASYNCcnt counter). + + Args: + count: Maximum number of outstanding operations to allow. + 0 = wait for all. + """ + _ods_s_wait_asynccnt(count) + + +def lds_transpose_load(result_type, lds_memref, elem_offset, elem_bytes): + """Transpose-load from LDS memref via ds_load_tr16_b128 (gfx1250). + + Args: + result_type: Vector result type, e.g. ``VectorType.get([8], f16)``. + lds_memref: LDS memref value (address-space 3), typically from + ``SmemPtr.get()`` or ``get_op_result_or_value(...)``. 
+ elem_offset: Per-lane linearized element offset into the memref + (ArithValue / ir.Value of index type / Python int). + elem_bytes: Element size in bytes (Python int, e.g. 2 for f16). + + Returns: + Loaded and transposed vector ``ir.Value``. + """ + from .._mlir import ir as _ir + from .._mlir.dialects import ( + llvm as _llvm, + memref as _memref, + rocdl as _rocdl, + ) + from . import arith as _arith + from .arith import _to_raw + from .typing import T + from .utils.arith import ArithValue as _AV + + lds_ptr_ty = _ir.Type.parse("!llvm.ptr<3>") + raw_memref = _arith.unwrap(lds_memref) + lds_base = _memref.extract_aligned_pointer_as_index(raw_memref) + + byte_off = _AV(_arith.unwrap(elem_offset, index=True)) * _arith.index(elem_bytes) + total_byte_idx = _AV(lds_base) + byte_off + addr_i32 = _to_raw(_arith.index_cast(T.i32, total_byte_idx)) + ptr_val = _llvm.inttoptr(lds_ptr_ty, addr_i32) + + return _rocdl.ds_load_tr16_b128(result_type, ptr_val) + + __all__ = [ # Thread/Block/Grid IDs and dimensions 'workitem_id_x', 'workitem_id_y', 'workitem_id_z', @@ -300,11 +523,12 @@ def wmma_i32_16x16x32_iu4(result_type, operands, *, loc=None, ip=None): 'workgroup_dim_x', 'workgroup_dim_y', 'workgroup_dim_z', 'grid_dim_x', 'grid_dim_y', 'grid_dim_z', 'wavefrontsize', - + 'wave_id', + # Synchronization 'barrier', 's_barrier', 's_barrier_signal', 's_barrier_wait', 's_waitcnt', 's_wait_loadcnt', 's_wait_storecnt', - 's_wait_dscnt', 's_wait_expcnt', + 's_wait_dscnt', 's_wait_expcnt', 's_wait_asynccnt', # Matrix operations - MFMA (Matrix Fused Multiply-Add) 'mfma_f32_32x32x8f16', 'mfma_f32_16x16x16f16', @@ -326,6 +550,8 @@ def wmma_i32_16x16x32_iu4(result_type, operands, *, loc=None, ip=None): 'wmma_f32_16x16x16_fp8_fp8', 'wmma_f32_16x16x16_fp8_bf8', 'wmma_f32_16x16x16_bf8_fp8', 'wmma_f32_16x16x16_bf8_bf8', 'wmma_i32_16x16x32_iu4', + 'wmma_scale_f32_16x16x128_f8f6f4', # gfx1250 WMMA_SCALE 16x16x128 (FP4/FP6/FP8) + 'wmma_scale_f32_32x16x128_f4', # gfx1250 WMMA_SCALE 32x16x128 (FP4 
only) # Matrix operations - SMFMAC (Sparse Matrix FMA) 'smfmac_f32_32x32x16_f16', 'smfmac_f32_32x32x16_bf16', @@ -367,9 +593,25 @@ def wmma_i32_16x16x32_iu4(result_type, operands, *, loc=None, ip=None): # MMA atom types 'MmaAtomCDNA3_MFMAType', 'MFMA', + 'MmaAtomGFX1250_WMMAType', 'WMMA', # Convenience wrappers 'make_buffer_tensor', + 'lds_transpose_load', # memref-level wrapper for gfx1250 ds_load_tr16_b128 + + # gfx1250 TDM - descriptor-driven tile copy (preferred over per-lane) + 'tensor_load_to_lds', # 4-group, up to 5D tensor + 'tensor_load_to_lds_d2', # 2-group, up to 2D tensor + 'tensor_store_from_lds', # 4-group store + 'tensor_store_from_lds_d2', # 2-group store + 's_wait_tensorcnt', + + # gfx1250 L2 prefetch + 'global_prefetch', # per-lane 1-byte prefetch hint + + # Cluster (gfx1250 workgroup clustering) + 'cluster_workgroup_id_x', 'cluster_workgroup_id_y', 'cluster_workgroup_id_z', + 'cluster_load_async_to_lds', # per-lane MCAST load (Global → LDS) ] diff --git a/python/flydsl/expr/rocdl/__init__.py b/python/flydsl/expr/rocdl/__init__.py index cce39593..a788bd23 100644 --- a/python/flydsl/expr/rocdl/__init__.py +++ b/python/flydsl/expr/rocdl/__init__.py @@ -14,6 +14,21 @@ from ..._mlir.dialects.rocdl import * # noqa: F401,F403 # Keep references to ODS-generated builders so we can wrap them without losing access. +_ods_wmma_scale_f32_16x16x128_f8f6f4 = ( + globals().get("wmma_scale_f32_16x16x128_f8f6f4", None) +) +_ods_wmma_scale_f32_32x16x128_f4 = ( + globals().get("wmma_scale_f32_32x16x128_f4", None) +) +_ods_wave_id = wave_id # ODS: wave_id(res, ...) 
-> i32 +_ods_cluster_workgroup_id_x = cluster_workgroup_id_x +_ods_cluster_workgroup_id_y = cluster_workgroup_id_y +_ods_cluster_workgroup_id_z = cluster_workgroup_id_z +_ods_cluster_load_async_to_lds_b8 = cluster_load_async_to_lds_b8 +_ods_cluster_load_async_to_lds_b32 = cluster_load_async_to_lds_b32 +_ods_cluster_load_async_to_lds_b64 = cluster_load_async_to_lds_b64 +_ods_cluster_load_async_to_lds_b128 = cluster_load_async_to_lds_b128 +_ods_s_wait_asynccnt = s_wait_asynccnt _ods_mfma_f32_16x16x16f16 = mfma_f32_16x16x16f16 _ods_mfma_f32_16x16x16bf16_1k = globals().get("mfma_f32_16x16x16bf16_1k", None) _ods_mfma_f32_16x16x32_fp8_fp8 = mfma_f32_16x16x32_fp8_fp8 @@ -101,6 +116,177 @@ def mfma_scale_f32_16x16x128_f8f6f4(result_type, operands, *, loc=None, ip=None) ).result +def wmma_scale_f32_16x16x128_f8f6f4(result_type, a, b, c, scaleA, scaleB, + *, fmtA=4, fmtB=4, modC=0, + scaleAType=0, fmtScaleA=0, + scaleBType=0, fmtScaleB=0, + reuseA=False, reuseB=False, + loc=None, ip=None): + """V_WMMA_SCALE_F32_16X16X128_F8F6F4 for gfx1250 (wave32). 
+ + Operand types (wave32): + a: vector<8xi32> (16x128 FP4 data) + b: vector<8xi32> (128x16 FP4 data) + c: vector<8xf32> (16x16 FP32 accumulator) + scaleA: i32 (A scale VGPR) + scaleB: i32 (B scale VGPR) + + fmtA/fmtB: data type encoding (0=FP8/E4M3, 1=FP8/E5M2, 2=FP6/E2M3, 3=FP6/E3M2, 4=FP4/E2M1) + scaleAType/scaleBType: opsel – selects lo/hi 16-bit half of scale VGPR (0=lo, 1=hi) + fmtScaleA/fmtScaleB: scale format (0=E8M0, 1=E5M3, 2=E4M3) + """ + if _ods_wmma_scale_f32_16x16x128_f8f6f4 is None: + raise AttributeError("ROCDL op not found: wmma_scale_f32_16x16x128_f8f6f4") + a_v = _unwrap_mfma_operand(a, loc=loc) + b_v = _unwrap_mfma_operand(b, loc=loc) + c_v = _unwrap_mfma_operand(c, loc=loc) + sA = _unwrap_mfma_operand(scaleA, loc=loc) + sB = _unwrap_mfma_operand(scaleB, loc=loc) + return _ods_wmma_scale_f32_16x16x128_f8f6f4( + result_type, a_v, b_v, c_v, sA, sB, + fmtA=fmtA, fmtB=fmtB, modC=modC, + scaleAType=scaleAType, fmtScaleA=fmtScaleA, + scaleBType=scaleBType, fmtScaleB=fmtScaleB, + reuseA=reuseA, reuseB=reuseB, + loc=loc, ip=ip, + ).result + + +def wmma_scale_f32_32x16x128_f4(result_type, a, b, c, scaleA, scaleB, + *, modC=0, + scaleAType=0, fmtScaleA=0, + scaleBType=0, fmtScaleB=0, + reuseA=False, reuseB=False, + loc=None, ip=None): + """V_WMMA_SCALE_F32_32X16X128_F4 for gfx1250 (wave32). 
+ + Operand types (wave32): + a: vector<16xi32> (32x128 FP4 data) + b: vector<8xi32> (128x16 FP4 data) + c: vector<16xf32> (32x16 FP32 accumulator) + scaleA: i32 (A scale VGPR) + scaleB: i32 (B scale VGPR) + """ + if _ods_wmma_scale_f32_32x16x128_f4 is None: + raise AttributeError("ROCDL op not found: wmma_scale_f32_32x16x128_f4") + a_v = _unwrap_mfma_operand(a, loc=loc) + b_v = _unwrap_mfma_operand(b, loc=loc) + c_v = _unwrap_mfma_operand(c, loc=loc) + sA = _unwrap_mfma_operand(scaleA, loc=loc) + sB = _unwrap_mfma_operand(scaleB, loc=loc) + return _ods_wmma_scale_f32_32x16x128_f4( + result_type, a_v, b_v, c_v, sA, sB, + modC=modC, + scaleAType=scaleAType, fmtScaleA=fmtScaleA, + scaleBType=scaleBType, fmtScaleB=fmtScaleB, + reuseA=reuseA, reuseB=reuseB, + loc=loc, ip=ip, + ).result + + +def wave_id(): + """Get wave-id-in-workgroup as SGPR (via TTMP8[29:25]). + + Returns: + i32 value (SGPR) with the wave ID within the workgroup. + """ + from ..._mlir import ir + i32 = ir.IntegerType.get_signless(32) + return _ods_wave_id(i32) + + +def cluster_workgroup_id_x(): + """Get workgroup position within cluster along X (SGPR, gfx1250). """ + from ..._mlir import ir + i32 = ir.IntegerType.get_signless(32) + return _ods_cluster_workgroup_id_x(i32) + + +def cluster_workgroup_id_y(): + """Get workgroup position within cluster along Y (SGPR, gfx1250). """ + from ..._mlir import ir + i32 = ir.IntegerType.get_signless(32) + return _ods_cluster_workgroup_id_y(i32) + + +def cluster_workgroup_id_z(): + """Get workgroup position within cluster along Z (SGPR, gfx1250). """ + from ..._mlir import ir + i32 = ir.IntegerType.get_signless(32) + return _ods_cluster_workgroup_id_z(i32) + + +def cluster_load_async_to_lds(global_ptr, lds_ptr, size_bytes, offset=0, cpol=0, mask=None): + """Per-lane cluster broadcast load: Global -> LDS with MCAST (gfx1250). + + Args: + global_ptr: ``!llvm.ptr<1>`` -- global address space pointer. + lds_ptr: ``!llvm.ptr<3>`` -- LDS address space pointer. 
+ size_bytes: Load width: 1, 4, 8, or 16 bytes (selects b8/b32/b64/b128). + offset: Byte offset (int, default 0). + cpol: Cache policy (int, default 0). + mask: i32 workgroup_mask for MCAST broadcast. None means no mask. + """ + _dispatch = { + 1: _ods_cluster_load_async_to_lds_b8, + 4: _ods_cluster_load_async_to_lds_b32, + 8: _ods_cluster_load_async_to_lds_b64, + 16: _ods_cluster_load_async_to_lds_b128, + } + fn = _dispatch.get(size_bytes) + if fn is None: + raise ValueError( + f"cluster_load_async_to_lds: size_bytes must be 1, 4, 8, or 16, " + f"got {size_bytes}") + if mask is None: + from ..._mlir import ir + from .. import arith as _arith + mask = _arith.unwrap(_arith.constant(0, type=ir.IntegerType.get_signless(32))) + fn(global_ptr, lds_ptr, offset, cpol, mask) + + +def s_wait_asynccnt(count=0): + """Wait for outstanding async load/store operations (ASYNCcnt counter).""" + _ods_s_wait_asynccnt(count) + + +def lds_transpose_load(result_type, lds_memref, elem_offset, elem_bytes): + """Transpose-load from LDS memref via ds_load_tr16_b128 (gfx1250). + + Args: + result_type: Vector result type, e.g. ``VectorType.get([8], f16)``. + lds_memref: LDS memref value (address-space 3), typically from + ``SmemPtr.get()`` or ``get_op_result_or_value(...)``. + elem_offset: Per-lane linearized element offset into the memref + (ArithValue / ir.Value of index type / Python int). + elem_bytes: Element size in bytes (Python int, e.g. 2 for f16). + + Returns: + Loaded and transposed vector ``ir.Value``. + """ + from ..._mlir import ir as _ir + from ..._mlir.dialects import ( + llvm as _llvm, + memref as _memref, + rocdl as _rocdl, + ) + from .. 
import arith as _arith + from ..arith import _to_raw + from ..typing import T + from ..utils.arith import ArithValue as _AV + + lds_ptr_ty = _ir.Type.parse("!llvm.ptr<3>") + raw_memref = _arith.unwrap(lds_memref) + lds_base = _memref.extract_aligned_pointer_as_index(raw_memref) + + byte_off = _AV(_arith.unwrap(elem_offset, index=True)) * _arith.index(elem_bytes) + total_byte_idx = _AV(lds_base) + byte_off + addr_i32 = _to_raw(_arith.index_cast(T.i32, total_byte_idx)) + ptr_val = _llvm.inttoptr(lds_ptr_ty, addr_i32) + + return _rocdl.ds_load_tr16_b128(result_type, ptr_val) + + # ── New high-level helpers from universal.py ────────────────────────── from .universal import * # noqa: F401,F403 diff --git a/python/flydsl/expr/rocdl/universal.py b/python/flydsl/expr/rocdl/universal.py index d7a6f463..47ddbe3d 100644 --- a/python/flydsl/expr/rocdl/universal.py +++ b/python/flydsl/expr/rocdl/universal.py @@ -4,6 +4,7 @@ from ..._mlir.dialects.fly import LayoutType, PointerType from ..._mlir.dialects.fly import MemRefType as FlyMemRefType from ..._mlir.dialects.fly_rocdl import CopyOpCDNA3BufferLDSTType, MmaAtomCDNA3_MFMAType +from ..._mlir._mlir_libs._fly_rocdl import MmaAtomGFX1250_WMMAType from ..primitive import ( get_iter, get_layout, @@ -28,6 +29,15 @@ def MFMA(m, n, k, elem_ty_ab, elem_ty_acc=None): return MmaAtomCDNA3_MFMAType.get(m, n, k, ty_ab, ty_ab, ty_acc) +def WMMA(m, n, k, elem_ty_ab, elem_ty_acc=None): + ty_ab = elem_ty_ab.ir_type if hasattr(elem_ty_ab, "ir_type") else elem_ty_ab + if elem_ty_acc is None: + ty_acc = ir.F32Type.get() + else: + ty_acc = elem_ty_acc.ir_type if hasattr(elem_ty_acc, "ir_type") else elem_ty_acc + return MmaAtomGFX1250_WMMAType.get(m, n, k, ty_ab, ty_ab, ty_acc) + + def make_buffer_tensor(tensor: Tensor) -> Tensor: def _elem_bit_width(elem_ty): if hasattr(elem_ty, "width"): diff --git a/python/flydsl/expr/tdm_ops.py b/python/flydsl/expr/tdm_ops.py new file mode 100644 index 00000000..706c0d8f --- /dev/null +++ 
b/python/flydsl/expr/tdm_ops.py @@ -0,0 +1,513 @@ +"""TDM (Tensor Data Mover) operations for gfx1250. + +High-level Python API that encapsulates TDM descriptor construction, +analogous to how buffer_ops.py wraps buffer resource descriptors. + +The TDM hardware on gfx1250 provides descriptor-driven DMA for +Global <-> LDS transfers. This module hides the bitfield packing +behind a clean API: + + desc = tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_a, lds_memref=lds_a_mem, + global_offset=(blk_m, k_base), + tensor_shape=(tile_m, K), strides=(K, 1), + tile_shape=(tile_m, tile_k), + elem_bytes=2, + pad_interval=64, pad_amount=8, + num_warps=8, + ) + tdm_ops.tensor_load_2d(desc) + tdm_ops.tensor_wait(0) +""" + +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import Optional, Sequence, Tuple, Union + +from .._mlir import ir +from .._mlir.dialects import ( + arith as std_arith, + llvm as llvm_dialect, + memref as memref_dialect, + rocdl, +) +from ..expr import arith, vector +from ..expr.arith import _to_raw as _raw +from ..expr.typing import T +from ..expr.utils.arith import ArithValue as _ArithValue + +__all__ = [ + "TDMDescriptor2D", + "make_tensor_descriptor_2d", + "tensor_load_2d", + "tensor_store_2d", + "tensor_wait", + "compute_padding_encoding", + "compute_warp_distribution", + "l2_prefetch_tile", +] + + +# --------------------------------------------------------------------------- +# Pure-Python helpers (compile-time, no IR emission) +# --------------------------------------------------------------------------- + +def compute_padding_encoding( + pad_interval_elems: int, + pad_amount_elems: int, + elem_bits: int = 16, +) -> Tuple[int, int]: + """Compute TDM descriptor padding bitfield values. 
+ + Follows Triton TDMUtility.cpp convention: + padIntervalInDwords = pad_interval_elems * elem_bits / 32 + padAmountInDwords = pad_amount_elems * elem_bits / 32 + encoded_interval = log2(padIntervalInDwords) - 1 + encoded_amount = padAmountInDwords - 1 + + Args: + pad_interval_elems: Padding interval in elements (e.g. tile_k = 64). + pad_amount_elems: Padding amount in elements (e.g. LDS_PAD = 8). + elem_bits: Bits per element (16 for f16/bf16, 32 for f32). + + Returns: + (encoded_interval, encoded_amount) ready for descriptor bits. + """ + dword_bits = 32 + interval_dw = pad_interval_elems * elem_bits // dword_bits + amount_dw = pad_amount_elems * elem_bits // dword_bits + if interval_dw <= 0 or amount_dw <= 0: + return (0, 0) + assert interval_dw & (interval_dw - 1) == 0, ( + f"padIntervalInDwords must be power-of-2, got {interval_dw}" + ) + encoded_interval = int(math.log2(interval_dw)) - 1 + encoded_amount = amount_dw - 1 + return (encoded_interval, encoded_amount) + + +def compute_warp_distribution( + block_shape: Sequence[int], + num_warps: int, +) -> Tuple[list, list]: + """Compute per-warp block sub-tile after distributing warps. + + Mirrors Triton's tdmGetWarpDistribution + tdmGetAdjustedBlockShape + from TDMCommon.h. + + Args: + block_shape: Full tile shape, e.g. [tile_m, tile_k]. + num_warps: Total number of warps in the workgroup. + + Returns: + (warps_per_dim, block_per_warp) — how many warps along each dim + and the sub-tile size each warp handles. 
+ """ + ndims = len(block_shape) + warps = [1] * ndims + remaining = num_warps + for i in range(ndims): + while remaining > 1 and warps[i] * 2 <= block_shape[i]: + warps[i] *= 2 + remaining //= 2 + if remaining > 1: + warps[-1] *= remaining + block_per_warp = [ + (block_shape[i] + warps[i] - 1) // warps[i] + for i in range(ndims) + ] + return warps, block_per_warp + + +# --------------------------------------------------------------------------- +# Descriptor data class +# --------------------------------------------------------------------------- + +@dataclass +class TDMDescriptor2D: + """Holds constructed GROUP0 and GROUP1 vectors for tensor_load_to_lds_d2.""" + dgroup0: object # vector<4xi32> MLIR Value + dgroup1: object # vector<8xi32> MLIR Value + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + +def _unwrap(value): + """Unwrap ArithValue wrappers to get raw ir.Value.""" + max_depth = 10 + depth = 0 + while depth < max_depth and not isinstance(value, ir.Value): + if hasattr(value, "_value"): + value = value._value + elif hasattr(value, "value"): + value = value.value + else: + break + depth += 1 + return value + + +def _i32_const(v: int) -> ir.Value: + """Emit an i32 constant, handling negative / unsigned values.""" + i32 = ir.IntegerType.get_signless(32) + if v > 0x7FFFFFFF: + v = int(v - 2**32) + return _unwrap(std_arith.ConstantOp(i32, ir.IntegerAttr.get(i32, v)).result) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +def make_tensor_descriptor_2d( + global_ptr, + lds_memref, + global_offset: Tuple, + tensor_shape: Tuple[int, int], + strides: Tuple[int, int], + tile_shape: Tuple[int, int], + elem_bytes: int = 2, + pad_interval: int = 0, + pad_amount: int = 0, + num_warps: int = 1, + cache_policy: 
int = 0, + pred: int = 1, + workgroup_mask: Union[int, "ir.Value"] = 0, +) -> TDMDescriptor2D: + """Build a 2D TDM descriptor for tensor_load_to_lds_d2. + + Convention (matching ISA): + dim0 = innermost (fastest-varying, e.g. K for row-major A) + dim1 = outermost (e.g. M for row-major A) + tensor_shape = (outer_size, inner_size) in user order + strides = (outer_stride, inner_stride) + tile_shape = (outer_tile, inner_tile) + global_offset is (outer_offset, inner_offset) — MLIR index Values + + Per-warp distribution is handled internally when num_warps > 1: + each wave computes its own LDS and global offsets so that all waves + collectively cover the full tile. + + Padding params are in ELEMENTS (converted to dwords for encoding). + + Args: + global_ptr: The global tensor (fx.Tensor or fly memref value). + lds_memref: The LDS memref value (already the correct buffer slot). + global_offset: (outer_idx, inner_idx) as MLIR index values. + tensor_shape: (outer_size, inner_size) as Python ints. + strides: (outer_stride, inner_stride) as Python ints. + tile_shape: (outer_tile, inner_tile) as Python ints. + elem_bytes: Element size in bytes (2 for f16/bf16, 4 for f32). + pad_interval: Padding interval in elements (0 to disable). + pad_amount: Padding amount in elements (0 to disable). + num_warps: Total warps in the workgroup. + cache_policy: Cache policy (0 = default). + pred: Predicate (1 = enabled). + workgroup_mask: MCAST workgroup mask [15:0] for TDM GROUP1 descriptor. + int: compile-time constant folded into descriptor. + ir.Value (i32 SGPR): runtime mask, ORed with upper config bits. + 0 = no multicast (default). + + Returns: + TDMDescriptor2D with dgroup0 and dgroup1 ready for tensor_load_2d. 
+ """ + from .._mlir.dialects import fly as _fly_d + + outer_size, inner_size = tensor_shape + outer_stride, inner_stride = strides + outer_tile, inner_tile = tile_shape + outer_off, inner_off = global_offset + + # -- Warp distribution -- + warps_per_dim, block_per_warp = compute_warp_distribution( + [outer_tile, inner_tile], num_warps, + ) + bpw_outer, bpw_inner = block_per_warp + warps_dim0 = warps_per_dim[0] + + if num_warps > 1: + # Auto-acquire SGPR wave_id via hardware register (TTMP8[29:25]). + # This keeps the entire descriptor address chain in SALU, + from . import rocdl as _rocdl_ext + _wid_i32 = _rocdl_ext.wave_id() + wave_id = arith.index_cast(T.index, _wid_i32) + warp_coord_outer = wave_id % arith.index(warps_dim0) + warp_coord_inner = wave_id / arith.index(warps_dim0) + warp_off_outer = warp_coord_outer * arith.index(bpw_outer) + warp_off_inner = warp_coord_inner * arith.index(bpw_inner) + else: + warp_off_outer = arith.index(0) + warp_off_inner = arith.index(0) + + # -- Global address (byte address for descriptor) -- + glb_ptr_type = ir.Type.parse("!llvm.ptr<1>") + i64 = ir.IntegerType.get_signless(64) + a_raw = global_ptr.__fly_values__()[0] + glb_ptr = _fly_d.extract_aligned_pointer_as_index(glb_ptr_type, a_raw) + glb_base_i64 = _ArithValue(llvm_dialect.ptrtoint(i64, glb_ptr)) + glb_elem_off = ( + (outer_off + warp_off_outer) * arith.index(outer_stride) + + (inner_off + warp_off_inner) * arith.index(inner_stride) + ) + glb_byte_off = glb_elem_off * arith.index(elem_bytes) + glb_byte_off_i64 = arith.index_cast(T.i64, glb_byte_off) + glb_addr_i64 = glb_base_i64 + glb_byte_off_i64 + + # -- LDS address (byte address within shared memory) -- + lds_base_idx = _ArithValue(memref_dialect.extract_aligned_pointer_as_index(lds_memref)) + # Compute padded LDS stride (elements) for the outer dim + if pad_interval > 0 and pad_amount > 0: + lds_inner_stride = inner_tile + pad_amount # padded row width + else: + lds_inner_stride = inner_tile + lds_warp_elem_off = 
( + warp_off_outer * arith.index(lds_inner_stride) + warp_off_inner + ) + lds_warp_byte_off = lds_warp_elem_off * arith.index(elem_bytes) + lds_addr_i32 = arith.index_cast(T.i32, lds_base_idx + lds_warp_byte_off) + + # ================================================================ + # GROUP0 (vector<4xi32>): pred, lds_addr, global_addr_lo/hi + # ================================================================ + g0_s0 = arith.constant(pred, type=T.i32) + g0_s1 = lds_addr_i32 + i32 = ir.IntegerType.get_signless(32) + g0_s2 = _ArithValue(std_arith.TruncIOp(i32, _raw(glb_addr_i64)).result) + hi_raw = _ArithValue(_raw(glb_addr_i64)).shrui(arith.constant(32, type=T.i64)) + g0_s3 = ( + _ArithValue(std_arith.TruncIOp(i32, _raw(hi_raw)).result) + | arith.constant(1 << 31, type=T.i32) # type field = 2 in [31:30] + ) + dgroup0 = vector.from_elements( + T.vec(4, T.i32), [g0_s0, g0_s1, g0_s2, g0_s3] + ) + + # ================================================================ + # GROUP1 (vector<8xi32>): config + tensor dims + strides + tile + # ================================================================ + # Descriptor dim ordering: dim0=innermost, dim1=outermost + tdim0 = bpw_inner # innermost extent per warp + tdim1 = bpw_outer # outermost extent per warp + tile_d0 = bpw_inner # block dim0 per warp + tile_d1 = bpw_outer # block dim1 per warp + # stride_dim0 in descriptor = outermost stride in elements + stride0 = outer_stride + + # data_size = log2(elem_bytes) + data_size_code = int(math.log2(elem_bytes)) + + # Padding encoding + if pad_interval > 0 and pad_amount > 0: + elem_bits = elem_bytes * 8 + enc_interval, enc_amount = compute_padding_encoding( + pad_interval, pad_amount, elem_bits + ) + pad_enable = 1 + else: + enc_interval, enc_amount = 0, 0 + pad_enable = 0 + + # sgpr0: config bitfields + g1_s0_upper = ( + (data_size_code << 16) # data_size [17:16] + | (0 << 18) # atomic_barrier_enable + | (0 << 19) # iterate_enable + | (pad_enable << 20) # pad_enable + | (0 << 
21) # early_timeout + | (enc_interval << 22) # pad_interval [24:22] + | (enc_amount << 25) # pad_amount [31:25] + ) + + if isinstance(workgroup_mask, int): + g1_s0_val = (workgroup_mask & 0xFFFF) | g1_s0_upper + g1_s0 = arith.constant(g1_s0_val, type=T.i32) + else: + upper_const = arith.constant(g1_s0_upper, type=T.i32) + mask_i32 = arith.andi(workgroup_mask, arith.constant(0xFFFF, type=T.i32)) + g1_s0 = arith.ori(upper_const, mask_i32) + + # sgpr1: atomic_barrier_addr[15:0]=0 | tensor_dim0_lo[31:16] + g1_s1 = arith.constant((tdim0 & 0xFFFF) << 16, type=T.i32) + + # sgpr2: tensor_dim0_hi[15:0] | tensor_dim1_lo[31:16] + g1_s2 = arith.constant( + ((tdim0 >> 16) & 0xFFFF) | ((tdim1 & 0xFFFF) << 16), + type=T.i32, + ) + + # sgpr3: tensor_dim1_hi[15:0] | tile_dim0[31:16] + g1_s3 = arith.constant( + ((tdim1 >> 16) & 0xFFFF) | (tile_d0 << 16), + type=T.i32, + ) + + # sgpr4: tile_dim1[15:0] | tile_dim2[31:16]=0 + g1_s4 = arith.constant(tile_d1 & 0xFFFF, type=T.i32) + + # sgpr5: tensor_dim0_stride (low 32 bits) — stride of outermost dim + g1_s5 = arith.constant(stride0 & 0xFFFFFFFF, type=T.i32) + + # sgpr6-7: for 2D, no higher-dim strides + g1_s6 = arith.constant(0, type=T.i32) + g1_s7 = arith.constant(0, type=T.i32) + + dgroup1 = vector.from_elements( + T.vec(8, T.i32), + [g1_s0, g1_s1, g1_s2, g1_s3, g1_s4, g1_s5, g1_s6, g1_s7], + ) + + return TDMDescriptor2D(dgroup0=dgroup0, dgroup1=dgroup1) + + +def _zero_dgroup_v4i32(): + """Create a zero vector<4xi32> for unused descriptor groups.""" + z = arith.constant(0, type=T.i32) + return vector.from_elements(T.vec(4, T.i32), [z, z, z, z]) + + +def _zero_dgroup_v8i32(): + """Create a zero vector<8xi32> for unused descriptor groups.""" + z = arith.constant(0, type=T.i32) + return vector.from_elements(T.vec(8, T.i32), [z, z, z, z, z, z, z, z]) + + +def tensor_load_2d( + desc: TDMDescriptor2D, + cache_policy: int = 0, +) -> None: + """Issue a TDM 2D async load (Global -> LDS). 
+ + Each wave in the workgroup calls this with its own descriptor + (as built by make_tensor_descriptor_2d). All waves together + cover the full tile. + + Uses the unified 5-group intrinsic with dgroup2/dgroup3/dgroup4 + zero-initialized for 2D tensors. + + Args: + desc: TDMDescriptor2D from make_tensor_descriptor_2d. + cache_policy: Cache policy (0 = default). + """ + dg2 = _raw(_zero_dgroup_v4i32()) + dg3 = _raw(_zero_dgroup_v4i32()) + dg4 = _raw(_zero_dgroup_v8i32()) + rocdl.tensor_load_to_lds( + _raw(desc.dgroup0), _raw(desc.dgroup1), dg2, dg3, dg4, cache_policy + ) + + +def tensor_store_2d( + desc: TDMDescriptor2D, + cache_policy: int = 0, +) -> None: + """Issue a TDM 2D async store (LDS -> Global). + + Uses the unified 5-group intrinsic with dgroup2/dgroup3/dgroup4 + zero-initialized for 2D tensors. + + Args: + desc: TDMDescriptor2D (with LDS source and global destination). + cache_policy: Cache policy (0 = default). + """ + dg2 = _raw(_zero_dgroup_v4i32()) + dg3 = _raw(_zero_dgroup_v4i32()) + dg4 = _raw(_zero_dgroup_v8i32()) + rocdl.tensor_store_from_lds( + _raw(desc.dgroup0), _raw(desc.dgroup1), dg2, dg3, dg4, cache_policy + ) + + +def tensor_wait(count: int = 0) -> None: + """Wait for outstanding TDM tensor operations. + + Issues s_wait_tensorcnt. + + Args: + count: Number of outstanding operations to allow (0 = wait for all). 
+ """ + rocdl.s_wait_tensorcnt(count) + + +# --------------------------------------------------------------------------- +# L2 prefetch +# --------------------------------------------------------------------------- + +# Scope constants for global_prefetch +PREFETCH_SCOPE_SE = 8 # SE scope = L2 cache +PREFETCH_SCOPE_DEVICE = 16 # Device scope + +def l2_prefetch_tile( + global_ptr, + global_offset: Tuple, + tile_shape: Tuple[int, int], + strides: Tuple[int, int], + elem_bytes: int = 2, + num_warps: int = 1, + wave_id=None, + thread_id=None, + block_threads: int = 256, + scope: int = PREFETCH_SCOPE_SE, +) -> None: + """Issue per-lane L2 cache prefetch hints for a 2D tile. + + Each lane in the workgroup prefetches 1 byte at a distinct global address + within the tile, distributing prefetch coverage across the tile. + + For a tile of outer×inner elements, each lane covers a unique row offset. + Multiple calls (from successive iterations) accumulate coverage. + + Args: + global_ptr: The global tensor (fx.Tensor). + global_offset: (outer_idx, inner_idx) as MLIR index values. + tile_shape: (outer_size, inner_size) in elements. + strides: (outer_stride, inner_stride) in elements. + elem_bytes: Element size in bytes. + num_warps: Total warps in the workgroup. + wave_id: Current wave ID (MLIR index). Unused; thread_id used instead. + thread_id: Workgroup-local thread ID (MLIR index value). + block_threads: Total threads in the workgroup. + scope: Prefetch scope (default: SE = L2). 
+ """ + from .._mlir.dialects import ( + fly as _fly_d, + llvm as llvm_dialect, + ) + + outer_size, inner_size = tile_shape + outer_stride, inner_stride = strides + outer_off, inner_off = global_offset + + # Get global base address as i64 + glb_ptr_type = ir.Type.parse("!llvm.ptr<1>") + i64 = ir.IntegerType.get_signless(64) + a_raw = global_ptr.__fly_values__()[0] + glb_ptr = _fly_d.extract_aligned_pointer_as_index(glb_ptr_type, a_raw) + glb_base_i64 = _ArithValue(llvm_dialect.ptrtoint(i64, glb_ptr)) + + # Each thread prefetches one row of the tile. + # thread_id maps to an outer-dim offset within the tile. + # Total rows = outer_size; if block_threads > outer_size, some threads + # wrap and prefetch additional cachelines. + # For simplicity, each thread prefetches row[tid % outer_size], col=0. + tile_row = thread_id % arith.index(outer_size) + + elem_off = ( + (outer_off + tile_row) * arith.index(outer_stride) + + inner_off * arith.index(inner_stride) + ) + byte_off = elem_off * arith.index(elem_bytes) + byte_off_i64 = arith.index_cast(T.i64, byte_off) + addr_i64 = glb_base_i64 + byte_off_i64 + + # Convert i64 address to pointer + ptr_val = llvm_dialect.inttoptr(glb_ptr_type, _raw(addr_i64)) + + # Issue prefetch hint via ROCDL dialect op. + # NOTE: rocdl.global_prefetch lowers to llvm.amdgcn.global.prefetch, which + # requires LLVM ISel support for gfx1250 global_prefetch_b8. If the LLVM + # build lacks this pattern, the instruction will be silently dropped. 
+ rocdl.global_prefetch(ptr_val, scope) diff --git a/python/flydsl/runtime/device.py b/python/flydsl/runtime/device.py index c42833f1..36a3faef 100644 --- a/python/flydsl/runtime/device.py +++ b/python/flydsl/runtime/device.py @@ -4,13 +4,13 @@ from typing import Optional -def _arch_from_rocm_agent_enumerator() -> Optional[str]: +def _arch_from_rocm_agent_enumerator(timeout_s: int = 5) -> Optional[str]: """Query rocm_agent_enumerator (standard ROCm tool) for the first GPU arch.""" try: out = subprocess.check_output( ["rocm_agent_enumerator", "-name"], text=True, - timeout=5, + timeout=timeout_s, stderr=subprocess.DEVNULL, ) for line in out.splitlines(): @@ -23,9 +23,11 @@ def _arch_from_rocm_agent_enumerator() -> Optional[str]: @functools.lru_cache(maxsize=None) -def get_rocm_arch() -> str: - """Best-effort ROCm GPU arch string (e.g. 'gfx942').""" - env = os.environ.get("FLYDSL_GPU_ARCH") or os.environ.get("HSA_OVERRIDE_GFX_VERSION") +def get_rocm_arch(timeout_s: int = 5) -> str: + """Best-effort ROCm GPU arch string (e.g. 
'gfx942') without torch.""" + env = (os.environ.get("FLYDSL_GPU_ARCH") + or os.environ.get("HSA_OVERRIDE_GFX_VERSION") + ) if env: if env.startswith("gfx"): return env @@ -33,7 +35,7 @@ def get_rocm_arch() -> str: parts = env.split(".") return f"gfx{parts[0]}{parts[1]}{parts[2]}" - arch = _arch_from_rocm_agent_enumerator() + arch = _arch_from_rocm_agent_enumerator(timeout_s=timeout_s) if arch: return arch.split(":", 1)[0] diff --git a/python/flydsl/utils/smem_allocator.py b/python/flydsl/utils/smem_allocator.py index bf7a47a8..33eefa8a 100644 --- a/python/flydsl/utils/smem_allocator.py +++ b/python/flydsl/utils/smem_allocator.py @@ -209,10 +209,12 @@ def get_base(self): SMEM_CAPACITY_MAP = { # ===================== AMD CDNA Architectures (Data Center Compute Cards) ===================== # CDNA 3 (MI300 Series) - 64KB LDS per CU - "gfx942": 65536, # MI300A / MI300X: 64KB LDS per CU + "gfx942": 65536, # MI300A / MI300X: 64KB LDS per CU # CDNA 4 (MI350 Series) - 160KB LDS per CU (key upgrade for CDNA4) "gfx950": 163840, # MI300C / MI300X Enhanced Models: 64KB LDS per CU "gfx1201": 65536, # RDNA4: 64KB LDS per WGP + # GFX1250 (MI450 Series) - 320KB LDS (WGP$ unified, 5 × 64KB segments) + "gfx1250": 327680, # MI450: 320KB configurable as LDS } def check_smem_capacity(allocated_bytes: int, arch: str = None): diff --git a/tests/kernels/test_mxfp4_gemm_gfx1250.py b/tests/kernels/test_mxfp4_gemm_gfx1250.py new file mode 100644 index 00000000..e44f5c07 --- /dev/null +++ b/tests/kernels/test_mxfp4_gemm_gfx1250.py @@ -0,0 +1,233 @@ +#!/usr/bin/env python3 +"""MXFP4 GEMM correctness tests for gfx1250. 
+ +Kernel implementation: kernels/mxfp4_gemm_gfx1250.py +""" + +import os +import sys + +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +_PYFLIR_SRC = os.path.join(_REPO_ROOT, "flydsl", "src") +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) +if _PYFLIR_SRC not in sys.path: + sys.path.insert(0, _PYFLIR_SRC) + +# workaround for simulator +import flydsl # noqa: E402,F401 -- preload system comgr before torch/HIP loads LLVM + +import pytest +import torch + +from flydsl.runtime.device import get_rocm_arch +from kernels.mxfp4_gemm_gfx1250 import compile_mxfp4_gemm +from tests.kernels.utils import fp4_utils + + +if not torch.cuda.is_available(): + pytest.skip("CUDA/ROCm not available. Skipping GPU tests.", allow_module_level=True) + + +SCALE_BLOCK = 32 + + +def preshuffle_e8m0_scale(scale: torch.Tensor, warp_tile: int, + scale_k_per_tile: int = 4, + WMMA_DIM: int = 16) -> torch.Tensor: + """Preshuffle E8M0 scale for WMMA_SCALE: byte swap + interleave for ds_load_b128. """ + _, K_scale = scale.shape + assert K_scale % 4 == 0, f"K_scale must be divisible by 4, got {K_scale}" + + grouped = scale.view(-1, K_scale // 4, 4) + shuffled = grouped[:, :, [0, 2, 1, 3]].contiguous() + scale = shuffled.view(-1, K_scale) + + SCALES_PER_WMMA = 4 + wmma_rep = warp_tile // WMMA_DIM + k_groups = K_scale // scale_k_per_tile + k_wmma_steps = scale_k_per_tile // SCALES_PER_WMMA + g = scale.view(-1, wmma_rep, WMMA_DIM, k_groups, k_wmma_steps, SCALES_PER_WMMA) + g = g.permute(0, 2, 3, 4, 1, 5).contiguous() + return g.reshape(-1, k_groups * k_wmma_steps * wmma_rep * SCALES_PER_WMMA) + + +def random_mxfp4_packed(rows: int, cols: int, *, device="cpu") -> torch.Tensor: + """Generate random packed MXFP4 data [rows, cols//2] uint8. 
""" + assert cols % 2 == 0 + unpacked = torch.randint(0, 16, (rows, cols), dtype=torch.uint8, device=device) + return fp4_utils.pack_uint4(unpacked) + + +def random_e8m0(rows: int, cols: int, *, low_exp=127, high_exp=132, + device="cpu") -> torch.Tensor: + """Generate random E8M0 scale bytes [rows, cols] uint8. """ + return torch.randint(low_exp, high_exp + 1, (rows, cols), + dtype=torch.uint8, device=device) + + +def reference_mxfp4_gemm(a_packed, b_packed, a_scale, b_scale, M, N, K): + """Reference MXFP4 GEMM: D = (A * A_scale) @ (B * B_scale)^T. + + Args: + a_packed: [M, K//2] uint8 packed FP4 + b_packed: [N, K//2] uint8 packed FP4 + a_scale: [M, K//SCALE_BLOCK] uint8 E8M0 + b_scale: [N, K//SCALE_BLOCK] uint8 E8M0 + + Returns: + [M, N] float32 result. + """ + a_f32 = fp4_utils.mxfp4_to_f32(a_packed.view(torch.uint8))[:M, :K] + b_f32 = fp4_utils.mxfp4_to_f32(b_packed.view(torch.uint8))[:N, :K] + + a_sc = fp4_utils.e8m0_to_f32(a_scale.view(torch.uint8)) + b_sc = fp4_utils.e8m0_to_f32(b_scale.view(torch.uint8)) + + a_sc_exp = a_sc.repeat_interleave(SCALE_BLOCK, dim=-1)[:M, :K] + b_sc_exp = b_sc.repeat_interleave(SCALE_BLOCK, dim=-1)[:N, :K] + + return torch.matmul(a_f32 * a_sc_exp, (b_f32 * b_sc_exp).T) + + +@pytest.mark.parametrize( + "M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp", + [ + (128, 128, 256, 128, 128, 128, 2, 2), + (128, 128, 512, 128, 128, 128, 2, 2), + (128, 128, 1024, 128, 128, 128, 2, 2), + (1024, 1024, 1024, 128, 256, 128, 2, 4), + ], +) +@pytest.mark.parametrize("num_buffers", [2]) +def test_mxfp4_gemm(M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, + num_buffers, l2_prefetch_distance=0, + cluster_m=1, cluster_n=1, scale_preshuffle=True): + """MXFP4 GEMM correctness unit test.""" + arch = str(get_rocm_arch(timeout_s=300)) + if arch != "gfx1250": + pytest.skip(f"WMMA_SCALE requires gfx1250, got {arch}") + + num_k_tiles = K // tile_k + if num_buffers > 1 and num_k_tiles < num_buffers: + pytest.skip(f"{num_buffers}-buf requires num_k_tiles >= 
{num_buffers}") + + torch.manual_seed(0) + + mcast_str = f", cluster=({cluster_m},{cluster_n})" if cluster_m > 1 or cluster_n > 1 else "" + print(f"\nRunning MXFP4 GEMM: M={M}, N={N}, K={K}, " + f"tiles=({tile_m},{tile_n},{tile_k}), bufs={num_buffers}{mcast_str}") + + a_packed = random_mxfp4_packed(M, K) + b_packed = random_mxfp4_packed(N, K) + a_scale = random_e8m0(M, K // SCALE_BLOCK) + b_scale = random_e8m0(N, K // SCALE_BLOCK) + + ref = reference_mxfp4_gemm(a_packed, b_packed, a_scale, b_scale, M, N, K) + print(f"Ref stats: min={ref.min():.2f}, max={ref.max():.2f}, " + f"mean={ref.mean():.2f}, std={ref.std():.2f}") + + if scale_preshuffle: + skt = tile_k // SCALE_BLOCK + a_scale = preshuffle_e8m0_scale(a_scale, tile_m // m_warp, scale_k_per_tile=skt) + b_scale = preshuffle_e8m0_scale(b_scale, tile_n // n_warp, scale_k_per_tile=skt) + + a_gpu = a_packed.cuda() + b_gpu = b_packed.cuda() + as_gpu = a_scale.cuda() + bs_gpu = b_scale.cuda() + c_gpu = torch.zeros(M, N, dtype=torch.float32, device="cpu").cuda() + + launch_fn = compile_mxfp4_gemm( + M=M, N=N, K=K, + tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, + m_warp=m_warp, n_warp=n_warp, + num_buffers=num_buffers, + l2_prefetch_distance=l2_prefetch_distance, + cluster_m=cluster_m, cluster_n=cluster_n, + scale_preshuffle=scale_preshuffle, + ) + launch_fn( + c_gpu.contiguous().view(-1), + a_gpu.contiguous().view(-1), + b_gpu.contiguous().view(-1), + as_gpu.contiguous().view(-1), + bs_gpu.contiguous().view(-1), + M, N, torch.cuda.current_stream(), + ) + torch.cuda.synchronize() + + c_out = c_gpu.cpu() + + print(f"Out stats: min={c_out.min():.2f}, max={c_out.max():.2f}, " + f"mean={c_out.mean():.2f}, std={c_out.std():.2f}") + + if c_out.abs().max() < 1e-10: + print("WARNING: kernel output is all zeros!") + + diff = (c_out - ref).abs() + print(f"Abs diff: max={diff.max():.4f}, mean={diff.mean():.4f}") + + cos_sim = torch.nn.functional.cosine_similarity( + c_out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item() 
+ print(f"Cosine similarity: {cos_sim:.6f}") + + torch.testing.assert_close(c_out, ref, rtol=1e-5, atol=1e-8) + print("PASSED") + + +@pytest.mark.parametrize( + "M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, cluster_m, cluster_n", + [ + # 2x2 cluster: needs >= 2 tile-rows and 2 tile-cols + (256, 256, 256, 128, 128, 128, 2, 2, 2, 2), + (1024, 1024, 1024, 128, 256, 128, 2, 4, 2, 2), + # 1x2 cluster: B shared along N + (128, 256, 256, 128, 128, 128, 2, 2, 1, 2), + # 2x1 cluster: A shared along M + (256, 128, 256, 128, 128, 128, 2, 2, 2, 1), + ], +) +@pytest.mark.parametrize("num_buffers", [2]) +def test_mxfp4_gemm_mcast(M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, + cluster_m, cluster_n, num_buffers): + """MXFP4 GEMM correctness test with cluster MCAST.""" + test_mxfp4_gemm( + M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, + num_buffers=num_buffers, + l2_prefetch_distance=2, + cluster_m=cluster_m, cluster_n=cluster_n, + scale_preshuffle=True, + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("-M", type=int, default=128) + parser.add_argument("-N", type=int, default=128) + parser.add_argument("-K", type=int, default=256) + parser.add_argument("--tile-m", type=int, default=128) + parser.add_argument("--tile-n", type=int, default=128) + parser.add_argument("--tile-k", type=int, default=128) + parser.add_argument("--m-warp", type=int, default=2) + parser.add_argument("--n-warp", type=int, default=2) + parser.add_argument("--num-buffers", type=int, default=2, choices=[2, 3, 4]) + parser.add_argument("--l2-prefetch-distance", type=int, default=0) + parser.add_argument("--cluster-m", type=int, default=1) + parser.add_argument("--cluster-n", type=int, default=1) + parser.add_argument("--no-scale-preshuffle", action="store_true", default=False) + args = parser.parse_args() + + test_mxfp4_gemm( + args.M, args.N, args.K, + args.tile_m, args.tile_n, args.tile_k, + num_buffers=args.num_buffers, + 
m_warp=args.m_warp, + n_warp=args.n_warp, + l2_prefetch_distance=args.l2_prefetch_distance, + cluster_m=args.cluster_m, + cluster_n=args.cluster_n, + scale_preshuffle=not args.no_scale_preshuffle, + ) diff --git a/tests/kernels/test_wmma_gemm_gfx1250.py b/tests/kernels/test_wmma_gemm_gfx1250.py new file mode 100644 index 00000000..784f81d2 --- /dev/null +++ b/tests/kernels/test_wmma_gemm_gfx1250.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python3 +"""WMMA GEMM using TDM tests for gfx1250. + +Kernel implementation lives in `kernels/wmma_gemm_gfx1250.py`. +This file is the correctness harness. +""" + +import os +import sys + +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +_PYFLIR_SRC = os.path.join(_REPO_ROOT, "flydsl", "src") +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) +if _PYFLIR_SRC not in sys.path: + sys.path.insert(0, _PYFLIR_SRC) + +# workaround for simulator +import flydsl # noqa: E402,F401 -- preload system comgr before torch/HIP loads LLVM + +import pytest +import torch + +from flydsl.runtime.device import get_rocm_arch +from kernels.wmma_gemm_gfx1250 import compile_wmma_gemm_tdm +from tests.test_common import verify_output + + +if not torch.cuda.is_available(): + pytest.skip("CUDA/ROCm not available. 
Skipping GPU tests.", allow_module_level=True) + + +@pytest.mark.parametrize("in_dtype", ["fp16", "bf16"]) +@pytest.mark.parametrize( + "M, N, K, tile_m, tile_n, tile_k", + [ + (128, 128, 64, 64, 128, 32), + (128, 128, 256, 64, 128, 128), + (256, 256, 256, 64, 256, 128), + (256, 256, 192, 64, 256, 64), + (256, 512, 256, 64, 256, 128), + (512, 512, 512, 64, 256, 128), + (201, 179, 128, 64, 128, 64), + (300, 399, 256, 64, 256, 128), + (256, 256, 256, 256, 256, 128), + (1024, 1024, 1024, 256, 256, 128), + (512, 512, 512, 256, 256, 128), + ], +) +@pytest.mark.parametrize("num_buffers", [2, 3]) +def test_wmma_gemm_tdm(in_dtype, M, N, K, tile_m, tile_n, tile_k, + num_buffers, + m_warp=2, n_warp=4, l2_prefetch_distance=2, + cluster_m=1, cluster_n=1): + """Non-cluster GEMM correctness test.""" + arch = str(get_rocm_arch(timeout_s=300)) + if arch != "gfx1250": + pytest.skip(f"WMMA requires gfx1250, got {arch}") + + num_k_tiles = K // tile_k + if num_buffers == 3 and num_k_tiles < 3: + pytest.skip(f"Triple buffer requires num_k_tiles >= 3, got {num_k_tiles}") + + lds_pad = 8 + elem_bytes = 2 + a_buf = tile_m * (tile_k + lds_pad) * elem_bytes + b_buf = tile_k * (tile_n + lds_pad) * elem_bytes + total_lds = (a_buf + b_buf) * num_buffers + if total_lds > 327680: + pytest.skip(f"LDS budget exceeded: {total_lds} > 327680") + + torch_dtype = torch.float16 if in_dtype == "fp16" else torch.bfloat16 + torch.manual_seed(0) + + mpad = (M + tile_m - 1) // tile_m * tile_m + npad = (N + tile_n - 1) // tile_n * tile_n + wg_m = mpad // tile_m + wg_n = npad // tile_n + + if cluster_m < 1 or cluster_n < 1: + pytest.skip(f"Invalid cluster dims: ({cluster_m}, {cluster_n}), both must be >= 1") + if cluster_m > 1 or cluster_n > 1: + if wg_m < cluster_m or wg_n < cluster_n: + pytest.skip( + "Cluster dims exceed launch grid: " + f"wg_grid=({wg_m},{wg_n}), cluster=({cluster_m},{cluster_n})" + ) + if (wg_m % cluster_m) != 0 or (wg_n % cluster_n) != 0: + pytest.skip( + "WG grid must be divisible by 
cluster dims: " + f"wg_grid=({wg_m},{wg_n}), cluster=({cluster_m},{cluster_n})" + ) + + print( + f"Running WMMA GEMM TDM: M={M}, N={N}, K={K}, " + f"dtype={in_dtype}, bufs={num_buffers}, " + f"cluster=({cluster_m},{cluster_n})" + ) + + a = torch.randn((M, K), dtype=torch_dtype, device='cpu').cuda() + b = torch.randn((K, N), dtype=torch_dtype, device='cpu').cuda() + + a_pad = torch.zeros((mpad, K), dtype=torch_dtype, device='cpu').cuda() + b_pad = torch.zeros((K, npad), dtype=torch_dtype, device='cpu').cuda() + a_pad[:M, :] = a + b_pad[:, :N] = b + + c_pad = torch.zeros((mpad, npad), dtype=torch.float32, device='cpu').cuda() + + launch_fn = compile_wmma_gemm_tdm( + M=mpad, N=npad, K=K, + tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, + m_warp=m_warp, n_warp=n_warp, in_dtype=in_dtype, + num_buffers=num_buffers, + l2_prefetch_distance=l2_prefetch_distance, + cluster_m=cluster_m, + cluster_n=cluster_n, + ) + launch_fn( + c_pad.contiguous().view(-1), + a_pad.contiguous().view(-1), + b_pad.contiguous().view(-1), + mpad, npad, torch.cuda.current_stream(), + ) + torch.cuda.synchronize() + + ref = torch.mm(a.cpu().to(torch.float32), b.cpu().to(torch.float32)) + rtol = 3e-2 + atol = 3e-2 + assert verify_output(c_pad[:M, :N].cpu().to(torch.float32), ref, rtol=rtol, atol=atol) + print("PASSED") + + +@pytest.mark.parametrize("in_dtype", ["fp16"]) +@pytest.mark.parametrize( + "M, N, K, tile_m, tile_n, tile_k", + [ + (1024, 1024, 1024, 128, 256, 128), + (2048, 2048, 1024, 128, 256, 128), + (2048, 2048, 2048, 128, 256, 128), + (4096, 4096, 1024, 128, 256, 128), + ], +) +@pytest.mark.parametrize("cluster_m, cluster_n", [(2, 2), (4, 4)]) +def test_wmma_gemm_tdm_mcast(in_dtype, M, N, K, tile_m, tile_n, tile_k, + cluster_m, cluster_n): + """Cluster multicast GEMM correctness test (large shapes only).""" + test_wmma_gemm_tdm( + in_dtype, M, N, K, tile_m, tile_n, tile_k, + num_buffers=2, m_warp=2, n_warp=4, + l2_prefetch_distance=2, + cluster_m=cluster_m, cluster_n=cluster_n, + ) + + +if 
__name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("-M", type=int, default=1024) + parser.add_argument("-N", type=int, default=1024) + parser.add_argument("-K", type=int, default=1024) + parser.add_argument("--tile-m", type=int, default=128) + parser.add_argument("--tile-n", type=int, default=256) + parser.add_argument("--tile-k", type=int, default=128) + parser.add_argument("--m-warp", type=int, default=2) + parser.add_argument("--n-warp", type=int, default=4) + parser.add_argument("--dtype", type=str, default="fp16", choices=["fp16", "bf16"]) + parser.add_argument("--num-buffers", type=int, default=2, choices=[2, 3]) + parser.add_argument("--l2-prefetch-distance", type=int, default=0) + parser.add_argument("--cluster-m", type=int, default=1) + parser.add_argument("--cluster-n", type=int, default=1) + args = parser.parse_args() + + test_wmma_gemm_tdm( + args.dtype, args.M, args.N, args.K, + args.tile_m, args.tile_n, args.tile_k, + num_buffers=args.num_buffers, + m_warp=args.m_warp, + n_warp=args.n_warp, + l2_prefetch_distance=args.l2_prefetch_distance, + cluster_m=args.cluster_m, + cluster_n=args.cluster_n, + ) diff --git a/tests/kernels/test_wmma_gemm_simple.py b/tests/kernels/test_wmma_gemm_simple.py new file mode 100644 index 00000000..f54fe039 --- /dev/null +++ b/tests/kernels/test_wmma_gemm_simple.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python3 +"""WMMA GEMM tests for gfx1250 — @flyc.kernel API. + +Kernel implementation lives in `kernels/wmma_gemm_simple.py`. +This file is the correctness + perf harness. 
+""" + +import os +import sys + +import pytest +import torch + +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +_PYFLIR_SRC = os.path.join(_REPO_ROOT, "flydsl", "src") +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) +if _PYFLIR_SRC not in sys.path: + sys.path.insert(0, _PYFLIR_SRC) + +from flydsl.runtime.device import get_rocm_arch +from kernels.wmma_gemm_simple import compile_wmma_gemm +from tests.test_common import verify_output + + +if not torch.cuda.is_available(): + pytest.skip("CUDA/ROCm not available. Skipping GPU tests.", allow_module_level=True) + + +@pytest.mark.parametrize("in_dtype", ["fp16", "bf16"]) +@pytest.mark.parametrize( + "M, N, K, tile_m, tile_n, tile_k, block_threads", + [ + (32, 32, 32, 32, 32, 32, 32), + (64, 64, 32, 64, 64, 32, 128), + (128, 128, 32, 64, 128, 32, 256), + (128, 128, 64, 64, 128, 32, 256), + (256, 256, 32, 64, 64, 32, 128), + (200, 180, 64, 64, 64, 32, 128), + (128, 128, 128, 64, 128, 64, 256), + ], +) +def test_wmma_gemm(in_dtype, M, N, K, tile_m, tile_n, tile_k, block_threads): + # rocm_agent_enumerator is very slow on AM simulator, + # set large timeout to avoid timeout and fallback to gfx942 + arch = str(get_rocm_arch(timeout_s=300)) + if arch != "gfx1250": + pytest.skip(f"WMMA requires gfx1250, got {arch}") + print(f"Running WMMA GEMM test with: M={M}, N={N}, K={K}, " + f"tile_m={tile_m}, tile_n={tile_n}, tile_k={tile_k}, " + f"block_threads={block_threads}, dtype={in_dtype}, arch={arch}") + + torch_dtype = torch.float16 if in_dtype == "fp16" else torch.bfloat16 + device = torch.device("cuda") + torch.manual_seed(0) + + # Pad M/N to tile boundaries + mpad = (M + tile_m - 1) // tile_m * tile_m + npad = (N + tile_n - 1) // tile_n * tile_n + + # torch gpu randn has some issues on gfx1250 AM simulator + a = torch.randn((M, K), dtype=torch_dtype, device='cpu').cuda() + b = torch.randn((K, N), dtype=torch_dtype, device='cpu').cuda() + + a_pad = torch.zeros((mpad, K), 
dtype=torch_dtype, device=device) + b_pad = torch.zeros((K, npad), dtype=torch_dtype, device=device) + a_pad[:M, :] = a + b_pad[:, :N] = b + c_pad = torch.zeros((mpad, npad), dtype=torch.float32, device=device) + + launch_fn = compile_wmma_gemm( + M=mpad, + N=npad, + K=K, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + in_dtype=in_dtype, + block_threads=block_threads, + ) + launch_fn( + c_pad.contiguous().view(-1), + a_pad.contiguous().view(-1), + b_pad.contiguous().view(-1), + mpad, + npad, + torch.cuda.current_stream(), + ) + torch.cuda.synchronize() + + ref = torch.matmul(a.cpu().to(torch.float32), b.cpu().to(torch.float32)) + assert verify_output(c_pad[:M, :N].cpu(), ref, rtol=3e-2, atol=3e-2) + print("✓ PASSED") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("-M", type=int, default=256, help='problem M size') + parser.add_argument("-N", type=int, default=256, help='problem N size') + parser.add_argument("-K", type=int, default=1024, help='problem K size') + parser.add_argument("--tile_m", type=int, default=256) + parser.add_argument("--tile_n", type=int, default=256) + parser.add_argument("--tile_k", type=int, default=128) + parser.add_argument("--dtype", type=str, default="fp16", choices=["fp16", "bf16"], + help="Input data type") + args = parser.parse_args() + + WARP_SIZE = 32 + BLOCK_THREADS = min(args.tile_n, 8 * WARP_SIZE) + + test_wmma_gemm( + args.dtype, + args.M, + args.N, + args.K, + args.tile_m, + args.tile_n, + args.tile_k, + BLOCK_THREADS, + ) diff --git a/tests/test_common.py b/tests/test_common.py index 2f27592e..276cfeab 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -419,7 +419,7 @@ def checkAllclose( def verify_output(c_out, c_ref, atol=1e-2, rtol=1e-2, msg='', logits_diff_threshold=2e-3): - if checkAllclose(c_out, c_ref, rtol=atol, atol=atol) < 0.05: + if checkAllclose(c_out, c_ref, rtol=rtol, atol=atol) < 0.05: return True # Calculate various error metrics 
diff --git a/thirdparty/llvm-hash.txt b/thirdparty/llvm-hash.txt index 4faf2ea9..978cdc8d 100644 --- a/thirdparty/llvm-hash.txt +++ b/thirdparty/llvm-hash.txt @@ -1 +1 @@ -ac5dc54d509169d387fcfd495d71853d81c46484 +27d654c4c4e6eb7c19e46af20500200e793da7c7