From 0a0f56396ff99f8417ff9cf5fe8b27545ef3cfca Mon Sep 17 00:00:00 2001 From: jli10004 Date: Thu, 19 Mar 2026 09:26:48 +0000 Subject: [PATCH 01/11] squash: pr/v0.1_gfx12 branch changes (87 commits) Squashed commits from pr/v0.1_gfx12 branch up to last merge with main. This includes gfx1250 WMMA GEMM kernels, global prefetch, triple buffer, TDM async copy, MoE kernel ports, blockscale API, and various fixes. Original commits: - b2b94a2 sync pre_v0.1 (Feng Shijie) - dd8c6fe update header macro (Feng Shijie) - eafa2d6 add separate target-specific rocdl dialect (Feng Shijie) - 3bd4150 Add utility nbmodules (Feng Shijie) - 51e45a5 Add universalMma Atom (Feng Shijie) - 866e2ed fix example02 (Feng Shijie) - 8b29d66 Add DLTensorAdaptor for torch Tensor support (Feng Shijie) - 44c6a3b Fix Python compatibility and remove hardcoded paths (jli) - ce61b2b Add logger and EnvManager (Feng Shijie) - bba5422 Refact Python module (Feng Shijie) - 6ca89f6 Fix missing module (Feng Shijie) - 8c784fb Add right inverse (Xudong Yuan) - 65f6a31 Add numeric typing (Feng Shijie) - a3e25cb Add ASTRewriter and improve jit_function cache mechanism (Feng Shijie) - 9b56be5 unwrap dsl_type before calling ir Op (Feng Shijie) - 8b64168 [MLIR][python] Upgrade to LLVM 23 (jli) - 7502ba4 Refactor Python bindings and improve DSL module exports (Feng Shijie) - 26d45a9 fix missing expore in primitive (Feng Shijie) - a250f90 Add tiled_copy partition (Feng Shijie) - 8085ea0 Pre v0.1 gemm (Felix Li) - 0ba4bb7 [FLYDSL]: add recast_layout op (Xudong Yuan) - 2679f23 Pre v0.1 gemm fix (Felix Li) - 890c860 add compile only and dumpir (Felix Li) - 6cee3f1 add version and wheel build (coderfeli) - 98fad64 port docs (coderfeli) - 04e7c46 build whl and dist version ok, upload pypi ok (coderfeli) - 0a8f4fe add aot example (coderfeli) - 178d332 [Tool][fly-opt] Add fly-opt tool and lit-based test suite (jli10004) - 1570126 Apply clang-format to fly-opt.cpp (jli10004) - 45e31d1 [Tests][Lit] Add lit tests to run_tests.sh and fix fly-opt build integration (jli10004) - 023e49b port gemm main (coderfeli) - 47ef125 merge latest gemm (coderfeli) - 2aca991 port async copy (coderfeli) - 9b24bcb add norm and softmax, and fix some style (coderfeli) - 68e29cc rm useless and fix 950 (coderfeli) - 9494711 cherry-pick pre_v0.1 (coderfeli) - cfd3bb1 temp remove profiler (coderfeli) - 0329979 [FLYDSL]: Bug fixes for algebra not being the simplest (Xudong Yuan) - 4db284c [Compiler][CacheKey] improve JIT cache key to hash entire compiler toolchain (jli10004) - 94c51bb [Bugfix] fix HIP graph capture segfault on PyTorch 2.9 / ROCm 7.1 (jli10004) - 2997922 [Test] fix run_tests.sh (jli10004) - 325b6f5 port moe gemm kernels to new flydsl runtime (xzhu) - 188698d port tests (felix) - 96a8b32 [FLYDSL]: add product test (Xudong Yuan) - c0b32c7 [BugFix] fix buffer descriptor flags and add missing ROCDL Python wrappers (jli10004) - fb6e20a add very naive wmma gemm for gfx1250 (aoli26) - 8f6d5cd refactor gfx1250 gemm & prepared for AM perf (aoli26) - 115ebca fix AM simulator target error (aoli26) - 1b1f9c6 fix gfx12 AM tcp assert failed introduced by torch (aoli26) - 96677b5 Add ROCDL subpackage (Feng Shijie) - bb2a7bf Rename BufferCopy op (Feng Shijie) - d5dc144 feat: dump IR with cache bypass and improved ISA readability (xzhu) - 906f8ee add TDM async copy WMMA GEMM kernel for gfx1250 (aoli26) - 4b690f6 [FLYDSL]: add logical_divide 2D by-mode mlir tests (Xudong Yuan) - 9a3d2c7 Dev blockscale new api (Felix Li) - f3fdae5 add global prefetch & triple buffer (aoli26) - 6e22816 
remove cshuffle, preshuffle (aoli26) - 06ede86 bump llvm version to turn on global_prefetch_b8 (aoli26) - b3d5880 tdm sgpr descriptor & coalesced frag load (aoli26) - 8fad0d1 resolve merged issues (aoli26) - d8297ab add k-subtile for better pipeline (aoli26) - 5aeebdb support mmaAtom for gfx1250 wmma (aoli26) Contributors: Feng Shijie, Felix Li, Xudong Yuan, aoli26, xzhu, jli10004, coderfeli --- after_build.sh | 10 + build_flydsl.sh | 5 + dump_ir.sh | 56 ++ env.sh | 5 + hsakmt_counters.csv | 20 + include/flydsl-c/FlyROCDLDialect.h | 20 + include/flydsl/Dialect/Fly/IR/FlyDialect.td | 1 - include/flydsl/Dialect/FlyROCDL/IR/Dialect.td | 1 - include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td | 25 + ir_dump/00_origin_ir.mlir | 83 +++ ir_dump/01_GpuKernelOutliningPass.mlir | 83 +++ ir_dump/02_FlyCanonicalizePass.mlir | 79 +++ ir_dump/03_FlyLayoutLoweringPass.mlir | 78 +++ ir_dump/04_FlyToROCDLConversionPass.mlir | 93 +++ ir_dump/05_Canonicalizer.mlir | 59 ++ ir_dump/06_SCFToControlFlowPass.mlir | 39 ++ ir_dump/07_ConvertGpuOpsToROCDLOps.mlir | 35 ++ ir_dump/08_GpuROCDLAttachTarget.mlir | 55 ++ ir_dump/09_SCFToControlFlowPass.mlir | 55 ++ ir_dump/10_ConvertControlFlowToLLVMPass.mlir | 55 ++ ir_dump/11_FlyGpuStreamMarkPass.mlir | 55 ++ ir_dump/12_GpuToLLVMConversionPass.mlir | 65 ++ ir_dump/13_FlyGpuStreamInjectPass.mlir | 65 ++ ir_dump/14_ArithToLLVMConversionPass.mlir | 65 ++ ir_dump/15_ConvertFuncToLLVMPass.mlir | 65 ++ ir_dump/16_ReconcileUnrealizedCastsPass.mlir | 65 ++ ir_dump/17_GpuModuleToBinaryPass.mlir | 44 ++ kernels/wmma_gemm_gfx1250.py | 556 ++++++++++++++++++ kernels/wmma_gemm_simple.py | 255 ++++++++ lib/Bindings/Python/FlyROCDLExtension.cpp | 75 +++ lib/CAPI/Dialect/FlyROCDL/FlyROCDLDialect.cpp | 44 ++ lib/Conversion/FlyToROCDL/FlyToROCDL.cpp | 168 ++++++ lib/Dialect/FlyROCDL/CMakeLists.txt | 1 + lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp | 158 +++++ python/flydsl/_mlir | 1 + python/flydsl/compiler/jit_function.py | 7 + python/flydsl/expr/__init__.py | 2 +- python/flydsl/expr/rocdl.py | 51 ++ python/flydsl/expr/rocdl/__init__.py | 12 + python/flydsl/expr/rocdl/universal.py | 10 + python/flydsl/expr/tdm_ops.py | 502 ++++++++++++++++ python/flydsl/runtime/device.py | 14 +- tests/kernels/test_wmma_gemm_gfx1250.py | 133 +++++ tests/kernels/test_wmma_gemm_simple.py | 122 ++++ thirdparty/llvm-hash.txt | 2 +- 45 files changed, 3384 insertions(+), 10 deletions(-) create mode 100644 after_build.sh create mode 100755 build_flydsl.sh create mode 100755 dump_ir.sh create mode 100644 env.sh create mode 100644 hsakmt_counters.csv create mode 100644 ir_dump/00_origin_ir.mlir create mode 100644 ir_dump/01_GpuKernelOutliningPass.mlir create mode 100644 ir_dump/02_FlyCanonicalizePass.mlir create mode 100644 ir_dump/03_FlyLayoutLoweringPass.mlir create mode 100644 ir_dump/04_FlyToROCDLConversionPass.mlir create mode 100644 ir_dump/05_Canonicalizer.mlir create mode 100644 ir_dump/06_SCFToControlFlowPass.mlir create mode 100644 ir_dump/07_ConvertGpuOpsToROCDLOps.mlir create mode 100644 ir_dump/08_GpuROCDLAttachTarget.mlir create mode 100644 ir_dump/09_SCFToControlFlowPass.mlir create mode 100644 ir_dump/10_ConvertControlFlowToLLVMPass.mlir create mode 100644 ir_dump/11_FlyGpuStreamMarkPass.mlir create mode 100644 ir_dump/12_GpuToLLVMConversionPass.mlir create mode 100644 ir_dump/13_FlyGpuStreamInjectPass.mlir create mode 100644 ir_dump/14_ArithToLLVMConversionPass.mlir create mode 100644 ir_dump/15_ConvertFuncToLLVMPass.mlir create mode 100644 ir_dump/16_ReconcileUnrealizedCastsPass.mlir create mode 100644 
ir_dump/17_GpuModuleToBinaryPass.mlir create mode 100644 kernels/wmma_gemm_gfx1250.py create mode 100644 kernels/wmma_gemm_simple.py create mode 100644 lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp create mode 120000 python/flydsl/_mlir create mode 100644 python/flydsl/expr/tdm_ops.py create mode 100644 tests/kernels/test_wmma_gemm_gfx1250.py create mode 100644 tests/kernels/test_wmma_gemm_simple.py diff --git a/after_build.sh b/after_build.sh new file mode 100644 index 00000000..ffc921a0 --- /dev/null +++ b/after_build.sh @@ -0,0 +1,10 @@ +#export MLIR_PATH=/home/jli10004/flydsl/llvm-project/buildmlir +#export PYTHONPATH=/data/jli/flydsl-ws/flydsl-prev/build/python_packages/:/data/jli/flydsl-ws/flydsl-prev/flydsl_/src:/data/jli/flydsl-ws/flydsl-prev:$PYTHONPATH +#export PATH=/data/jli/flydsl-ws/flydsl-prev/build/bin:$PATH + +MLIR_INSTALL=/home/jli10004/flydsl/mi450-cmodel-env/llvm-project/mlir_install +MLIR_PATH=/home/jli10004/flydsl/mi450-cmodel-env/llvm-project/build-flydsl + +export PYTHONPATH=/home/jli10004/flydsl/mi450-cmodel-env/flydsl-prev/build-fly/python_packages:/home/jli10004/flydsl/mi450-cmodel-env/llvm-project/build-flydsl/tools/mlir/python_packages/mlir_core:$PYTHONPATH +export PYTHONPATH=/home/jli10004/flydsl/mi450-cmodel-env/flydsl-prev/python:$PYTHONPATH +export LD_LIBRARY_PATH=$MLIR_INSTALL/lib:$LD_LIBRARY_PATH diff --git a/build_flydsl.sh b/build_flydsl.sh new file mode 100755 index 00000000..74ba47f3 --- /dev/null +++ b/build_flydsl.sh @@ -0,0 +1,5 @@ +export MLIR_PATH=/home/jli10004/flydsl/mi450-cmodel-env/llvm-project/mlir_install +mkdir -p build-fly && cd build-fly +cmake .. -DMLIR_DIR=/home/jli10004/flydsl/mi450-cmodel-env/llvm-project/mlir_install/lib/cmake/mlir -GNinja +NPROC=$(nproc 2>/dev/null || echo 4) +ninja -j${NPROC} diff --git a/dump_ir.sh b/dump_ir.sh new file mode 100755 index 00000000..ce5e1c92 --- /dev/null +++ b/dump_ir.sh @@ -0,0 +1,56 @@ +#!/usr/bin/bash +# Usage: ./dump_ir.sh <example.py> [output_dir] +# +# Runs the example with IR printing enabled, then splits each pass's +# IR dump into a numbered file under <output_dir> (default: ./ir_dump). +set -euo pipefail + +EXAMPLE="${1:?Usage: $0 <example.py> [output_dir]}" +OUTDIR="${2:-./ir_dump}" + +rm -rf "$OUTDIR" ~/.flydsl/cache/ +mkdir -p "$OUTDIR" + +cd /home/jli10004/flydsl/flydsl-prev +source after_build.sh 2>/dev/null || true + +export HIP_VISIBLE_DEVICES=0 +export FLYDSL_DEBUG_PRINT_ORIGIN_IR=1 +export FLYDSL_DEBUG_PRINT_AFTER_ALL=1 +export FLYDSL_DEBUG_LOG_TO_CONSOLE=1 +export FLYDSL_DEBUG_LOG_LEVEL=INFO + +python "$EXAMPLE" >"$OUTDIR/_raw.txt" 2>&1 + +python3 -c " +import re, sys, os + +outdir = '$OUTDIR' +with open(os.path.join(outdir, '_raw.txt')) as f: + text = f.read() + +sections = [] + +# Origin IR +m = re.search(r'Origin IR:\s*\n(module\b.*?)(?=\n// -----// IR Dump|\Z)', text, re.DOTALL) +if m: + sections.append(('origin_ir', m.group(1).rstrip())) + +# Per-pass IR +marker = re.compile(r'^// -----// IR Dump After (\S+)(?: \(.*?\))? 
//----- //\$', re.MULTILINE) +hits = list(marker.finditer(text)) +for i, h in enumerate(hits): + end = hits[i+1].start() if i+1 < len(hits) else len(text) + body = text[h.end()+1:end].rstrip() + sections.append((h.group(1), body)) + +for seq, (name, body) in enumerate(sections): + safe = re.sub(r'[^\w-]', '_', name) + fn = f'{seq:02d}_{safe}.mlir' + with open(os.path.join(outdir, fn), 'w') as f: + f.write(body + '\n') + print(f' {fn}') + +os.remove(os.path.join(outdir, '_raw.txt')) +print(f'\n{len(sections)} IR files written to {outdir}/') +" diff --git a/env.sh b/env.sh new file mode 100644 index 00000000..4fbd1ccb --- /dev/null +++ b/env.sh @@ -0,0 +1,5 @@ +export MLIR_PATH=/data/jli/flydsl-ws/llvm-project/build-flydsl/mlir_install +export PYTHONPATH=/data/jli/flydsl-ws/FlyDSL/.flir/build/python_packages/flydsl:/data/jli/flydsl-ws/FlyDSL/flydsl/src:/data/jli/flydsl-ws/FlyDSL:$PYTHONPATH +export PATH=/data/jli/flydsl-ws/FlyDSL/.flir/build/bin:$PATH +export SHOW_IR=1 +export PATH=/data/jli/flydsl-ws/llvm-project/build-flydsl/mlir_install/bin:$PATH diff --git a/hsakmt_counters.csv b/hsakmt_counters.csv new file mode 100644 index 00000000..e2ef75da --- /dev/null +++ b/hsakmt_counters.csv @@ -0,0 +1,20 @@ +hsakmtmodel.executor.0.jitcu.num_instr_analyzed,hsakmtmodel.executor.0.jitcu.num_instr_executed,hsakmtmodel.executor.0.jitcu.num_rts_primcache_hits,hsakmtmodel.executor.0.jitcu.num_rts_primcache_misses,hsakmtmodel.executor.0.jitcu.sq_perf_sel_insts_flat,hsakmtmodel.executor.0.jitcu.sq_perf_sel_insts_global_scratch_load,hsakmtmodel.executor.0.jitcu.sq_perf_sel_insts_global_scratch_store,hsakmtmodel.executor.0.jitcu.sq_perf_sel_insts_lds,hsakmtmodel.executor.0.jitcu.sq_perf_sel_insts_salu,hsakmtmodel.executor.0.jitcu.sq_perf_sel_insts_smem,hsakmtmodel.executor.0.jitcu.sq_perf_sel_insts_tex,hsakmtmodel.executor.0.jitcu.sq_perf_sel_insts_valu,hsakmtmodel.executor.0.jitcu.sq_perf_sel_insts_valu_xdlmacc,hsakmtmodel.executor.0.jitcu.sq_perf_sel_insts_waves,hsakmtmodel.executor.0.jitcu.sq_perf_sel_insts_waves_created,hsakmtmodel.executor.0.jitcu.sq_perf_sel_insts_waves_finished,hsakmtmodel.jitcu.num_instr_analyzed,hsakmtmodel.jitcu.num_instr_executed,hsakmtmodel.jitcu.num_rts_primcache_hits,hsakmtmodel.jitcu.num_rts_primcache_misses,hsakmtmodel.jitcu.sq_perf_sel_insts_flat,hsakmtmodel.jitcu.sq_perf_sel_insts_global_scratch_load,hsakmtmodel.jitcu.sq_perf_sel_insts_global_scratch_store,hsakmtmodel.jitcu.sq_perf_sel_insts_lds,hsakmtmodel.jitcu.sq_perf_sel_insts_salu,hsakmtmodel.jitcu.sq_perf_sel_insts_smem,hsakmtmodel.jitcu.sq_perf_sel_insts_tex,hsakmtmodel.jitcu.sq_perf_sel_insts_valu,hsakmtmodel.jitcu.sq_perf_sel_insts_valu_xdlmacc,hsakmtmodel.jitcu.sq_perf_sel_insts_waves,hsakmtmodel.jitcu.sq_perf_sel_insts_waves_created,hsakmtmodel.jitcu.sq_perf_sel_insts_waves_finished, +21248,21248,0,0,0,0,0,512,2176,2304,512,13184,0,384,384,384,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, +37376,37376,0,0,0,0,0,3584,2944,2304,3584,20864,0,384,384,384,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, +37376,37376,0,0,0,0,0,3584,2944,2304,3584,20864,0,384,384,384,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, +57426,57426,0,0,0,0,0,7510,3371,2304,7510,30674,0,384,384,384,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, +33792,32768,0,0,0,0,1024,0,19456,1024,1024,3072,0,1024,1024,1024,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, +33792,32768,0,0,0,0,1024,0,19456,1024,1024,3072,0,1024,1024,1024,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, 
+50176,50176,0,0,0,1024,1024,0,14336,2048,2048,14336,0,1024,10224668160,23365632,0,0,0,0,516096,0,2393088,21504,516096,13854720,0,132096,3072,3072,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, +93825,93825,0,0,0,0,0,14896,2840,2304,14896,49139,0,384,384,384,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, +3806208,3677184,0,027584,27584,0,0,0,0,0,1664,2752,2304,1664,16064,0,384,384,384,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, +117184,114704,0,0,0,4104,4008,20480,264,24,12208,31464,8192,136,8,8,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, +37376,37376,0,0,0,0,0,3584,2944,2304,3584,20864,0,384,384,384,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, +37376,37376,0,0,0,0,0,3584,2944,2304,3584,20864,0,384,384,384,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, +27904,27904,0,0,0,0,0,1792,2432,2304,1792,16384,0,384,384,384,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, +0, +10716,10548,0,0,0,92,72,0,4584,60,164,1900,0,20,20,20,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, +3976,3952,0,0,0,144,72,0,628,352,216,1396,0,20,20,20,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, +24643584,24643584,0,0,0,516096,516096,0,3870720,2451456,1032192,8386560,0,129024,129024,129024,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, +21248,21248,0,0,0,0,0,512,2176,2304,512,13184,0,384,384,384,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, +45158400,44126208,0,0,0,516096,516096,0,11612160,1935360,1032192,16128000,0,129024,129024,129024,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, diff --git a/include/flydsl-c/FlyROCDLDialect.h b/include/flydsl-c/FlyROCDLDialect.h index 0e2777ed..f4640d25 100644 --- a/include/flydsl-c/FlyROCDLDialect.h +++ b/include/flydsl-c/FlyROCDLDialect.h @@ -30,6 +30,26 @@ MLIR_CAPI_EXPORTED MlirType mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetElemTyA(MlirType MLIR_CAPI_EXPORTED MlirType mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetElemTyB(MlirType type); MLIR_CAPI_EXPORTED MlirType mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetElemTyAcc(MlirType type); +//===----------------------------------------------------------------------===// +// MmaAtomGFX1250_WMMAType +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAFlyROCDLMmaAtomGFX1250_WMMAType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetTypeID(void); + +// Constructor +MLIR_CAPI_EXPORTED MlirType mlirFlyROCDLMmaAtomGFX1250_WMMATypeGet(int32_t m, int32_t n, int32_t k, + MlirType elemTyA, MlirType elemTyB, + MlirType elemTyAcc); + +// Accessors +MLIR_CAPI_EXPORTED int32_t mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetM(MlirType type); +MLIR_CAPI_EXPORTED int32_t mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetN(MlirType type); +MLIR_CAPI_EXPORTED int32_t mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetK(MlirType type); +MLIR_CAPI_EXPORTED MlirType mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetElemTyA(MlirType type); +MLIR_CAPI_EXPORTED MlirType mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetElemTyB(MlirType type); +MLIR_CAPI_EXPORTED MlirType mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetElemTyAcc(MlirType type); + //===----------------------------------------------------------------------===// // CopyOpCDNA3BufferLDSTType //===----------------------------------------------------------------------===// diff --git a/include/flydsl/Dialect/Fly/IR/FlyDialect.td b/include/flydsl/Dialect/Fly/IR/FlyDialect.td index 73dd85cd..48ac923a 100644 --- a/include/flydsl/Dialect/Fly/IR/FlyDialect.td +++ b/include/flydsl/Dialect/Fly/IR/FlyDialect.td @@ -16,7 +16,6 @@ def Fly_Dialect : Dialect { let useDefaultTypePrinterParser = 1; let useDefaultAttributePrinterParser = 1; - let usePropertiesForAttributes = 1; } class Fly_Type traits = []> diff --git 
a/include/flydsl/Dialect/FlyROCDL/IR/Dialect.td b/include/flydsl/Dialect/FlyROCDL/IR/Dialect.td index bae2c63b..60c2ba26 100644 --- a/include/flydsl/Dialect/FlyROCDL/IR/Dialect.td +++ b/include/flydsl/Dialect/FlyROCDL/IR/Dialect.td @@ -16,7 +16,6 @@ def FlyROCDL_Dialect : Dialect { ]; let useDefaultTypePrinterParser = 1; - let usePropertiesForAttributes = 1; } class FlyxROCL_MmaAtom traits = []> diff --git a/include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td b/include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td index e3275f06..5c73a089 100644 --- a/include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td +++ b/include/flydsl/Dialect/FlyROCDL/IR/MmaAtom.td @@ -30,4 +30,29 @@ def FlyROCDL_MmaAtomCDNA3_MFMA : FlyxROCL_MmaAtom<"MmaAtomCDNA3_MFMA", "atom.cdn // MmaAtom CDNA4 //===----------------------------------------------------------------------===// + + +//===----------------------------------------------------------------------===// +// MmaAtom GFX1250 — WMMA wave32 +//===----------------------------------------------------------------------===// + +def FlyROCDL_MmaAtomGFX1250_WMMA : FlyxROCL_MmaAtom<"MmaAtomGFX1250_WMMA", "atom.gfx1250.wmma", []> { + let parameters = (ins + "int32_t":$m, + "int32_t":$n, + "int32_t":$k, + "Type":$elemTyA, + "Type":$elemTyB, + "Type":$elemTyAcc + ); + let assemblyFormat = "`<` custom($m, $n, $k) `,` `(` $elemTyA `,` $elemTyB `)` `->` $elemTyAcc `>`"; + + let builders = [ + TypeBuilderWithInferredContext<(ins "int32_t":$m, "int32_t":$n, "int32_t":$k, "Type":$elemTyA, "Type":$elemTyB, "Type":$elemTyAcc), [{ + return $_get(elemTyA.getContext(), m, n, k, elemTyA, elemTyB, elemTyAcc); + }]> + ]; + let genVerifyDecl = 1; +} + #endif // FLYROCDL_MMAATOM diff --git a/ir_dump/00_origin_ir.mlir b/ir_dump/00_origin_ir.mlir new file mode 100644 index 00000000..c4596836 --- /dev/null +++ b/ir_dump/00_origin_ir.mlir @@ -0,0 +1,83 @@ +module attributes {gpu.container_module} { + gpu.module @kernels [#rocdl.target] { + gpu.func @vectorAddKernel_0(%arg0: !fly.memref>, %arg1: !fly.memref>, %arg2: !fly.memref>) kernel { + %block_id_x = gpu.block_id x + %0 = arith.index_cast %block_id_x : index to i32 + %thread_id_x = gpu.thread_id x + %1 = arith.index_cast %thread_id_x : index to i32 + %2 = fly.make_shape() : () -> !fly.int_tuple<64> + %3 = fly.make_stride() : () -> !fly.int_tuple<1> + %4 = fly.make_layout(%2, %3) : (!fly.int_tuple<64>, !fly.int_tuple<1>) -> !fly.layout<64:1> + %5 = fly.logical_divide(%arg0, %4) : (!fly.memref>, !fly.layout<64:1>) -> !fly.memref> + %6 = fly.make_shape() : () -> !fly.int_tuple<64> + %7 = fly.make_stride() : () -> !fly.int_tuple<1> + %8 = fly.make_layout(%6, %7) : (!fly.int_tuple<64>, !fly.int_tuple<1>) -> !fly.layout<64:1> + %9 = fly.logical_divide(%arg1, %8) : (!fly.memref>, !fly.layout<64:1>) -> !fly.memref> + %10 = fly.make_shape() : () -> !fly.int_tuple<64> + %11 = fly.make_stride() : () -> !fly.int_tuple<1> + %12 = fly.make_layout(%10, %11) : (!fly.int_tuple<64>, !fly.int_tuple<1>) -> !fly.layout<64:1> + %13 = fly.logical_divide(%arg2, %12) : (!fly.memref>, !fly.layout<64:1>) -> !fly.memref> + %14 = fly.make_coord(%0) : (i32) -> !fly.int_tuple<(*,?)> + %15 = fly.slice(%5, %14) : (!fly.memref>, !fly.int_tuple<(*,?)>) -> !fly.memref> + %16 = fly.make_coord(%0) : (i32) -> !fly.int_tuple<(*,?)> + %17 = fly.slice(%9, %16) : (!fly.memref>, !fly.int_tuple<(*,?)>) -> !fly.memref> + %18 = fly.make_coord(%0) : (i32) -> !fly.int_tuple<(*,?)> + %19 = fly.slice(%13, %18) : (!fly.memref>, !fly.int_tuple<(*,?)>) -> !fly.memref> + %20 = fly.make_shape() : () -> 
!fly.int_tuple<1> + %21 = fly.make_stride() : () -> !fly.int_tuple<1> + %22 = fly.make_layout(%20, %21) : (!fly.int_tuple<1>, !fly.int_tuple<1>) -> !fly.layout<1:1> + %23 = fly.logical_divide(%15, %22) : (!fly.memref>, !fly.layout<1:1>) -> !fly.memref> + %24 = fly.make_shape() : () -> !fly.int_tuple<1> + %25 = fly.make_stride() : () -> !fly.int_tuple<1> + %26 = fly.make_layout(%24, %25) : (!fly.int_tuple<1>, !fly.int_tuple<1>) -> !fly.layout<1:1> + %27 = fly.logical_divide(%17, %26) : (!fly.memref>, !fly.layout<1:1>) -> !fly.memref> + %28 = fly.make_shape() : () -> !fly.int_tuple<1> + %29 = fly.make_stride() : () -> !fly.int_tuple<1> + %30 = fly.make_layout(%28, %29) : (!fly.int_tuple<1>, !fly.int_tuple<1>) -> !fly.layout<1:1> + %31 = fly.logical_divide(%19, %30) : (!fly.memref>, !fly.layout<1:1>) -> !fly.memref> + %32 = fly.make_atom : () -> !fly.atom.universal_copy<32> + %33 = fly.make_shape() : () -> !fly.int_tuple<1> + %34 = fly.make_stride() : () -> !fly.int_tuple<1> + %35 = fly.make_layout(%33, %34) : (!fly.int_tuple<1>, !fly.int_tuple<1>) -> !fly.layout<1:1> + %36 = fly.memref.alloca(%35) : (!fly.layout<1:1>) -> !fly.memref + %37 = fly.make_shape() : () -> !fly.int_tuple<1> + %38 = fly.make_stride() : () -> !fly.int_tuple<1> + %39 = fly.make_layout(%37, %38) : (!fly.int_tuple<1>, !fly.int_tuple<1>) -> !fly.layout<1:1> + %40 = fly.memref.alloca(%39) : (!fly.layout<1:1>) -> !fly.memref + %41 = fly.make_shape() : () -> !fly.int_tuple<1> + %42 = fly.make_stride() : () -> !fly.int_tuple<1> + %43 = fly.make_layout(%41, %42) : (!fly.int_tuple<1>, !fly.int_tuple<1>) -> !fly.layout<1:1> + %44 = fly.memref.alloca(%43) : (!fly.layout<1:1>) -> !fly.memref + %45 = fly.make_coord(%1) : (i32) -> !fly.int_tuple<(*,?)> + %46 = fly.slice(%23, %45) : (!fly.memref>, !fly.int_tuple<(*,?)>) -> !fly.memref> + fly.copy_atom_call(%32, %46, %36) : (!fly.atom.universal_copy<32>, !fly.memref>, !fly.memref) -> () + %47 = fly.make_coord(%1) : (i32) -> !fly.int_tuple<(*,?)> + %48 = fly.slice(%27, %47) : (!fly.memref>, !fly.int_tuple<(*,?)>) -> !fly.memref> + fly.copy_atom_call(%32, %48, %40) : (!fly.atom.universal_copy<32>, !fly.memref>, !fly.memref) -> () + %49 = fly.memref.load_vec(%36) : (!fly.memref) -> vector<1xf32> + %50 = fly.memref.load_vec(%40) : (!fly.memref) -> vector<1xf32> + %51 = arith.addf %49, %50 : vector<1xf32> + fly.memref.store_vec(%51, %44) : (vector<1xf32>, !fly.memref) -> () + %52 = fly.make_coord(%1) : (i32) -> !fly.int_tuple<(*,?)> + %53 = fly.slice(%31, %52) : (!fly.memref>, !fly.int_tuple<(*,?)>) -> !fly.memref> + fly.copy_atom_call(%32, %44, %53) : (!fly.atom.universal_copy<32>, !fly.memref, !fly.memref>) -> () + gpu.return + } + } + func.func @vectorAdd(%arg0: !fly.memref>, %arg1: !fly.memref>, %arg2: !fly.memref>, %arg3: i32, %arg4: !gpu.async.token) attributes {llvm.emit_c_interface} { + %c64_i32 = arith.constant 64 : i32 + %0 = arith.addi %arg3, %c64_i32 : i32 + %c1_i32 = arith.constant 1 : i32 + %1 = arith.subi %0, %c1_i32 : i32 + %c64_i32_0 = arith.constant 64 : i32 + %2 = arith.floordivsi %1, %c64_i32_0 : i32 + %3 = arith.index_cast %2 : i32 to index + %c1 = arith.constant 1 : index + %c1_1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1_2 = arith.constant 1 : index + %c1_3 = arith.constant 1 : index + gpu.launch_func <%arg4 : !gpu.async.token> @kernels::@vectorAddKernel_0 blocks in (%3, %c1, %c1_1) threads in (%c64, %c1_2, %c1_3) args(%arg0 : !fly.memref>, %arg1 : !fly.memref>, %arg2 : !fly.memref>) + return + } +} diff --git 
a/ir_dump/01_GpuKernelOutliningPass.mlir b/ir_dump/01_GpuKernelOutliningPass.mlir new file mode 100644 index 00000000..c4596836 --- /dev/null +++ b/ir_dump/01_GpuKernelOutliningPass.mlir @@ -0,0 +1,83 @@ +module attributes {gpu.container_module} { + gpu.module @kernels [#rocdl.target] { + gpu.func @vectorAddKernel_0(%arg0: !fly.memref>, %arg1: !fly.memref>, %arg2: !fly.memref>) kernel { + %block_id_x = gpu.block_id x + %0 = arith.index_cast %block_id_x : index to i32 + %thread_id_x = gpu.thread_id x + %1 = arith.index_cast %thread_id_x : index to i32 + %2 = fly.make_shape() : () -> !fly.int_tuple<64> + %3 = fly.make_stride() : () -> !fly.int_tuple<1> + %4 = fly.make_layout(%2, %3) : (!fly.int_tuple<64>, !fly.int_tuple<1>) -> !fly.layout<64:1> + %5 = fly.logical_divide(%arg0, %4) : (!fly.memref>, !fly.layout<64:1>) -> !fly.memref> + %6 = fly.make_shape() : () -> !fly.int_tuple<64> + %7 = fly.make_stride() : () -> !fly.int_tuple<1> + %8 = fly.make_layout(%6, %7) : (!fly.int_tuple<64>, !fly.int_tuple<1>) -> !fly.layout<64:1> + %9 = fly.logical_divide(%arg1, %8) : (!fly.memref>, !fly.layout<64:1>) -> !fly.memref> + %10 = fly.make_shape() : () -> !fly.int_tuple<64> + %11 = fly.make_stride() : () -> !fly.int_tuple<1> + %12 = fly.make_layout(%10, %11) : (!fly.int_tuple<64>, !fly.int_tuple<1>) -> !fly.layout<64:1> + %13 = fly.logical_divide(%arg2, %12) : (!fly.memref>, !fly.layout<64:1>) -> !fly.memref> + %14 = fly.make_coord(%0) : (i32) -> !fly.int_tuple<(*,?)> + %15 = fly.slice(%5, %14) : (!fly.memref>, !fly.int_tuple<(*,?)>) -> !fly.memref> + %16 = fly.make_coord(%0) : (i32) -> !fly.int_tuple<(*,?)> + %17 = fly.slice(%9, %16) : (!fly.memref>, !fly.int_tuple<(*,?)>) -> !fly.memref> + %18 = fly.make_coord(%0) : (i32) -> !fly.int_tuple<(*,?)> + %19 = fly.slice(%13, %18) : (!fly.memref>, !fly.int_tuple<(*,?)>) -> !fly.memref> + %20 = fly.make_shape() : () -> !fly.int_tuple<1> + %21 = fly.make_stride() : () -> !fly.int_tuple<1> + %22 = fly.make_layout(%20, %21) : (!fly.int_tuple<1>, !fly.int_tuple<1>) -> !fly.layout<1:1> + %23 = fly.logical_divide(%15, %22) : (!fly.memref>, !fly.layout<1:1>) -> !fly.memref> + %24 = fly.make_shape() : () -> !fly.int_tuple<1> + %25 = fly.make_stride() : () -> !fly.int_tuple<1> + %26 = fly.make_layout(%24, %25) : (!fly.int_tuple<1>, !fly.int_tuple<1>) -> !fly.layout<1:1> + %27 = fly.logical_divide(%17, %26) : (!fly.memref>, !fly.layout<1:1>) -> !fly.memref> + %28 = fly.make_shape() : () -> !fly.int_tuple<1> + %29 = fly.make_stride() : () -> !fly.int_tuple<1> + %30 = fly.make_layout(%28, %29) : (!fly.int_tuple<1>, !fly.int_tuple<1>) -> !fly.layout<1:1> + %31 = fly.logical_divide(%19, %30) : (!fly.memref>, !fly.layout<1:1>) -> !fly.memref> + %32 = fly.make_atom : () -> !fly.atom.universal_copy<32> + %33 = fly.make_shape() : () -> !fly.int_tuple<1> + %34 = fly.make_stride() : () -> !fly.int_tuple<1> + %35 = fly.make_layout(%33, %34) : (!fly.int_tuple<1>, !fly.int_tuple<1>) -> !fly.layout<1:1> + %36 = fly.memref.alloca(%35) : (!fly.layout<1:1>) -> !fly.memref + %37 = fly.make_shape() : () -> !fly.int_tuple<1> + %38 = fly.make_stride() : () -> !fly.int_tuple<1> + %39 = fly.make_layout(%37, %38) : (!fly.int_tuple<1>, !fly.int_tuple<1>) -> !fly.layout<1:1> + %40 = fly.memref.alloca(%39) : (!fly.layout<1:1>) -> !fly.memref + %41 = fly.make_shape() : () -> !fly.int_tuple<1> + %42 = fly.make_stride() : () -> !fly.int_tuple<1> + %43 = fly.make_layout(%41, %42) : (!fly.int_tuple<1>, !fly.int_tuple<1>) -> !fly.layout<1:1> + %44 = fly.memref.alloca(%43) : (!fly.layout<1:1>) -> 
!fly.memref + %45 = fly.make_coord(%1) : (i32) -> !fly.int_tuple<(*,?)> + %46 = fly.slice(%23, %45) : (!fly.memref>, !fly.int_tuple<(*,?)>) -> !fly.memref> + fly.copy_atom_call(%32, %46, %36) : (!fly.atom.universal_copy<32>, !fly.memref>, !fly.memref) -> () + %47 = fly.make_coord(%1) : (i32) -> !fly.int_tuple<(*,?)> + %48 = fly.slice(%27, %47) : (!fly.memref>, !fly.int_tuple<(*,?)>) -> !fly.memref> + fly.copy_atom_call(%32, %48, %40) : (!fly.atom.universal_copy<32>, !fly.memref>, !fly.memref) -> () + %49 = fly.memref.load_vec(%36) : (!fly.memref) -> vector<1xf32> + %50 = fly.memref.load_vec(%40) : (!fly.memref) -> vector<1xf32> + %51 = arith.addf %49, %50 : vector<1xf32> + fly.memref.store_vec(%51, %44) : (vector<1xf32>, !fly.memref) -> () + %52 = fly.make_coord(%1) : (i32) -> !fly.int_tuple<(*,?)> + %53 = fly.slice(%31, %52) : (!fly.memref>, !fly.int_tuple<(*,?)>) -> !fly.memref> + fly.copy_atom_call(%32, %44, %53) : (!fly.atom.universal_copy<32>, !fly.memref, !fly.memref>) -> () + gpu.return + } + } + func.func @vectorAdd(%arg0: !fly.memref>, %arg1: !fly.memref>, %arg2: !fly.memref>, %arg3: i32, %arg4: !gpu.async.token) attributes {llvm.emit_c_interface} { + %c64_i32 = arith.constant 64 : i32 + %0 = arith.addi %arg3, %c64_i32 : i32 + %c1_i32 = arith.constant 1 : i32 + %1 = arith.subi %0, %c1_i32 : i32 + %c64_i32_0 = arith.constant 64 : i32 + %2 = arith.floordivsi %1, %c64_i32_0 : i32 + %3 = arith.index_cast %2 : i32 to index + %c1 = arith.constant 1 : index + %c1_1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %c1_2 = arith.constant 1 : index + %c1_3 = arith.constant 1 : index + gpu.launch_func <%arg4 : !gpu.async.token> @kernels::@vectorAddKernel_0 blocks in (%3, %c1, %c1_1) threads in (%c64, %c1_2, %c1_3) args(%arg0 : !fly.memref>, %arg1 : !fly.memref>, %arg2 : !fly.memref>) + return + } +} diff --git a/ir_dump/02_FlyCanonicalizePass.mlir b/ir_dump/02_FlyCanonicalizePass.mlir new file mode 100644 index 00000000..b61fd81d --- /dev/null +++ b/ir_dump/02_FlyCanonicalizePass.mlir @@ -0,0 +1,79 @@ +module attributes {gpu.container_module} { + gpu.module @kernels [#rocdl.target] { + gpu.func @vectorAddKernel_0(%arg0: !fly.memref>, %arg1: !fly.memref>, %arg2: !fly.memref>) kernel { + %block_id_x = gpu.block_id x + %0 = arith.index_cast %block_id_x : index to i32 + %thread_id_x = gpu.thread_id x + %1 = arith.index_cast %thread_id_x : index to i32 + %2 = fly.make_int_tuple() : () -> !fly.int_tuple<64> + %3 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %4 = fly.make_layout(%2, %3) : (!fly.int_tuple<64>, !fly.int_tuple<1>) -> !fly.layout<64:1> + %5 = fly.logical_divide(%arg0, %4) : (!fly.memref>, !fly.layout<64:1>) -> !fly.memref> + %6 = fly.make_int_tuple() : () -> !fly.int_tuple<64> + %7 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %8 = fly.make_layout(%6, %7) : (!fly.int_tuple<64>, !fly.int_tuple<1>) -> !fly.layout<64:1> + %9 = fly.logical_divide(%arg1, %8) : (!fly.memref>, !fly.layout<64:1>) -> !fly.memref> + %10 = fly.make_int_tuple() : () -> !fly.int_tuple<64> + %11 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %12 = fly.make_layout(%10, %11) : (!fly.int_tuple<64>, !fly.int_tuple<1>) -> !fly.layout<64:1> + %13 = fly.logical_divide(%arg2, %12) : (!fly.memref>, !fly.layout<64:1>) -> !fly.memref> + %14 = fly.make_int_tuple(%0) : (i32) -> !fly.int_tuple<(*,?)> + %15 = fly.slice(%5, %14) : (!fly.memref>, !fly.int_tuple<(*,?)>) -> !fly.memref> + %16 = fly.make_int_tuple(%0) : (i32) -> !fly.int_tuple<(*,?)> + %17 = fly.slice(%9, %16) : (!fly.memref>, 
!fly.int_tuple<(*,?)>) -> !fly.memref> + %18 = fly.make_int_tuple(%0) : (i32) -> !fly.int_tuple<(*,?)> + %19 = fly.slice(%13, %18) : (!fly.memref>, !fly.int_tuple<(*,?)>) -> !fly.memref> + %20 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %21 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %22 = fly.make_layout(%20, %21) : (!fly.int_tuple<1>, !fly.int_tuple<1>) -> !fly.layout<1:1> + %23 = fly.logical_divide(%15, %22) : (!fly.memref>, !fly.layout<1:1>) -> !fly.memref> + %24 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %25 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %26 = fly.make_layout(%24, %25) : (!fly.int_tuple<1>, !fly.int_tuple<1>) -> !fly.layout<1:1> + %27 = fly.logical_divide(%17, %26) : (!fly.memref>, !fly.layout<1:1>) -> !fly.memref> + %28 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %29 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %30 = fly.make_layout(%28, %29) : (!fly.int_tuple<1>, !fly.int_tuple<1>) -> !fly.layout<1:1> + %31 = fly.logical_divide(%19, %30) : (!fly.memref>, !fly.layout<1:1>) -> !fly.memref> + %32 = fly.make_atom : () -> !fly.atom.universal_copy<32> + %33 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %34 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %35 = fly.make_layout(%33, %34) : (!fly.int_tuple<1>, !fly.int_tuple<1>) -> !fly.layout<1:1> + %36 = fly.memref.alloca(%35) : (!fly.layout<1:1>) -> !fly.memref + %37 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %38 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %39 = fly.make_layout(%37, %38) : (!fly.int_tuple<1>, !fly.int_tuple<1>) -> !fly.layout<1:1> + %40 = fly.memref.alloca(%39) : (!fly.layout<1:1>) -> !fly.memref + %41 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %42 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %43 = fly.make_layout(%41, %42) : (!fly.int_tuple<1>, !fly.int_tuple<1>) -> !fly.layout<1:1> + %44 = fly.memref.alloca(%43) : (!fly.layout<1:1>) -> !fly.memref + %45 = fly.make_int_tuple(%1) : (i32) -> !fly.int_tuple<(*,?)> + %46 = fly.slice(%23, %45) : (!fly.memref>, !fly.int_tuple<(*,?)>) -> !fly.memref> + fly.copy_atom_call(%32, %46, %36) : (!fly.atom.universal_copy<32>, !fly.memref>, !fly.memref) -> () + %47 = fly.make_int_tuple(%1) : (i32) -> !fly.int_tuple<(*,?)> + %48 = fly.slice(%27, %47) : (!fly.memref>, !fly.int_tuple<(*,?)>) -> !fly.memref> + fly.copy_atom_call(%32, %48, %40) : (!fly.atom.universal_copy<32>, !fly.memref>, !fly.memref) -> () + %49 = fly.memref.load_vec(%36) : (!fly.memref) -> vector<1xf32> + %50 = fly.memref.load_vec(%40) : (!fly.memref) -> vector<1xf32> + %51 = arith.addf %49, %50 : vector<1xf32> + fly.memref.store_vec(%51, %44) : (vector<1xf32>, !fly.memref) -> () + %52 = fly.make_int_tuple(%1) : (i32) -> !fly.int_tuple<(*,?)> + %53 = fly.slice(%31, %52) : (!fly.memref>, !fly.int_tuple<(*,?)>) -> !fly.memref> + fly.copy_atom_call(%32, %44, %53) : (!fly.atom.universal_copy<32>, !fly.memref, !fly.memref>) -> () + gpu.return + } + } + func.func @vectorAdd(%arg0: !fly.memref>, %arg1: !fly.memref>, %arg2: !fly.memref>, %arg3: i32, %arg4: !gpu.async.token) attributes {llvm.emit_c_interface} { + %c64 = arith.constant 64 : index + %c1 = arith.constant 1 : index + %c1_i32 = arith.constant 1 : i32 + %c64_i32 = arith.constant 64 : i32 + %0 = arith.addi %arg3, %c64_i32 : i32 + %1 = arith.subi %0, %c1_i32 : i32 + %2 = arith.floordivsi %1, %c64_i32 : i32 + %3 = arith.index_cast %2 : i32 to index + gpu.launch_func <%arg4 : !gpu.async.token> @kernels::@vectorAddKernel_0 blocks in (%3, %c1, %c1) threads in (%c64, %c1, %c1) 
args(%arg0 : !fly.memref>, %arg1 : !fly.memref>, %arg2 : !fly.memref>) + return + } +} diff --git a/ir_dump/03_FlyLayoutLoweringPass.mlir b/ir_dump/03_FlyLayoutLoweringPass.mlir new file mode 100644 index 00000000..5b828cf6 --- /dev/null +++ b/ir_dump/03_FlyLayoutLoweringPass.mlir @@ -0,0 +1,78 @@ +module attributes {gpu.container_module} { + gpu.module @kernels [#rocdl.target] { + gpu.func @vectorAddKernel_0(%arg0: !fly.ptr>, %arg1: !llvm.struct, struct)>, %arg2: !fly.ptr>, %arg3: !fly.ptr>) kernel { + %c64_i32 = arith.constant 64 : i32 + %block_id_x = gpu.block_id x + %0 = arith.index_cast %block_id_x : index to i32 + %thread_id_x = gpu.thread_id x + %1 = arith.index_cast %thread_id_x : index to i32 + %2 = arith.muli %0, %c64_i32 : i32 + %3 = fly.make_int_tuple(%2) : (i32) -> !fly.int_tuple + %4 = fly.add_offset(%arg0, %3) : (!fly.ptr>, !fly.int_tuple) -> !fly.ptr> + %5 = arith.muli %0, %c64_i32 : i32 + %6 = fly.make_int_tuple(%5) : (i32) -> !fly.int_tuple + %7 = fly.add_offset(%arg2, %6) : (!fly.ptr>, !fly.int_tuple) -> !fly.ptr> + %8 = arith.muli %0, %c64_i32 : i32 + %9 = fly.make_int_tuple(%8) : (i32) -> !fly.int_tuple + %10 = fly.add_offset(%arg3, %9) : (!fly.ptr>, !fly.int_tuple) -> !fly.ptr> + %11 = fly.make_atom : () -> !fly.atom.universal_copy<32> + %12 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %13 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %14 = fly.make_layout(%12, %13) : (!fly.int_tuple<1>, !fly.int_tuple<1>) -> !fly.layout<1:1> + %15 = fly.memref.alloca(%14) : (!fly.layout<1:1>) -> !fly.memref + %16 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %17 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %18 = fly.make_layout(%16, %17) : (!fly.int_tuple<1>, !fly.int_tuple<1>) -> !fly.layout<1:1> + %19 = fly.memref.alloca(%18) : (!fly.layout<1:1>) -> !fly.memref + %20 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %21 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %22 = fly.make_layout(%20, %21) : (!fly.int_tuple<1>, !fly.int_tuple<1>) -> !fly.layout<1:1> + %23 = fly.memref.alloca(%22) : (!fly.layout<1:1>) -> !fly.memref + %24 = fly.make_int_tuple(%1) : (i32) -> !fly.int_tuple + %25 = fly.add_offset(%4, %24) : (!fly.ptr>, !fly.int_tuple) -> !fly.ptr> + %26 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %27 = fly.make_int_tuple() : () -> !fly.int_tuple<0> + %28 = fly.make_layout(%26, %27) : (!fly.int_tuple<1>, !fly.int_tuple<0>) -> !fly.layout<1:0> + %29 = fly.make_view(%25, %28) : (!fly.ptr>, !fly.layout<1:0>) -> !fly.memref> + fly.copy_atom_call(%11, %29, %15) : (!fly.atom.universal_copy<32>, !fly.memref>, !fly.memref) -> () + %30 = fly.make_int_tuple(%1) : (i32) -> !fly.int_tuple + %31 = fly.add_offset(%7, %30) : (!fly.ptr>, !fly.int_tuple) -> !fly.ptr> + %32 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %33 = fly.make_int_tuple() : () -> !fly.int_tuple<0> + %34 = fly.make_layout(%32, %33) : (!fly.int_tuple<1>, !fly.int_tuple<0>) -> !fly.layout<1:0> + %35 = fly.make_view(%31, %34) : (!fly.ptr>, !fly.layout<1:0>) -> !fly.memref> + fly.copy_atom_call(%11, %35, %19) : (!fly.atom.universal_copy<32>, !fly.memref>, !fly.memref) -> () + %36 = fly.memref.load_vec(%15) : (!fly.memref) -> vector<1xf32> + %37 = fly.memref.load_vec(%19) : (!fly.memref) -> vector<1xf32> + %38 = arith.addf %36, %37 : vector<1xf32> + fly.memref.store_vec(%38, %23) : (vector<1xf32>, !fly.memref) -> () + %39 = fly.make_int_tuple(%1) : (i32) -> !fly.int_tuple + %40 = fly.add_offset(%10, %39) : (!fly.ptr>, !fly.int_tuple) -> !fly.ptr> + %41 = fly.make_int_tuple() : () -> 
!fly.int_tuple<1> + %42 = fly.make_int_tuple() : () -> !fly.int_tuple<0> + %43 = fly.make_layout(%41, %42) : (!fly.int_tuple<1>, !fly.int_tuple<0>) -> !fly.layout<1:0> + %44 = fly.make_view(%40, %43) : (!fly.ptr>, !fly.layout<1:0>) -> !fly.memref> + fly.copy_atom_call(%11, %23, %44) : (!fly.atom.universal_copy<32>, !fly.memref, !fly.memref>) -> () + gpu.return + } + } + func.func @vectorAdd(%arg0: !fly.ptr>, %arg1: !llvm.struct, struct)>, %arg2: !fly.ptr>, %arg3: !fly.ptr>, %arg4: i32, %arg5: !gpu.async.token) attributes {llvm.emit_c_interface} { + %0 = llvm.mlir.undef : !llvm.struct, struct)> + %1 = llvm.mlir.undef : !llvm.struct + %2 = llvm.mlir.undef : !llvm.struct + %c64_i32 = arith.constant 64 : i32 + %c1_i32 = arith.constant 1 : i32 + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %3 = llvm.extractvalue %arg1[0, 0] : !llvm.struct, struct)> + %4 = arith.addi %arg4, %c64_i32 : i32 + %5 = arith.subi %4, %c1_i32 : i32 + %6 = arith.floordivsi %5, %c64_i32 : i32 + %7 = arith.index_cast %6 : i32 to index + %8 = llvm.insertvalue %3, %2[0] : !llvm.struct + %9 = llvm.insertvalue %8, %0[0] : !llvm.struct, struct)> + %10 = llvm.insertvalue %1, %9[1] : !llvm.struct, struct)> + gpu.launch_func <%arg5 : !gpu.async.token> @kernels::@vectorAddKernel_0 blocks in (%7, %c1, %c1) threads in (%c64, %c1, %c1) args(%arg0 : !fly.ptr>, %10 : !llvm.struct, struct)>, %arg2 : !fly.ptr>, %arg3 : !fly.ptr>) + return + } +} diff --git a/ir_dump/04_FlyToROCDLConversionPass.mlir b/ir_dump/04_FlyToROCDLConversionPass.mlir new file mode 100644 index 00000000..57a4b418 --- /dev/null +++ b/ir_dump/04_FlyToROCDLConversionPass.mlir @@ -0,0 +1,93 @@ +module attributes {gpu.container_module} { + gpu.module @kernels [#rocdl.target] { + gpu.func @vectorAddKernel_0(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) kernel { + %c64_i32 = arith.constant 64 : i32 + %block_id_x = gpu.block_id x + %0 = arith.index_cast %block_id_x : index to i32 + %thread_id_x = gpu.thread_id x + %1 = arith.index_cast %thread_id_x : index to i32 + %2 = arith.muli %0, %c64_i32 : i32 + %3 = fly.make_int_tuple(%2) : (i32) -> !fly.int_tuple + %4 = arith.index_cast %2 : i32 to index + %5 = arith.index_cast %4 : index to i64 + %6 = llvm.getelementptr %arg0[%5] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %7 = arith.muli %0, %c64_i32 : i32 + %8 = fly.make_int_tuple(%7) : (i32) -> !fly.int_tuple + %9 = arith.index_cast %7 : i32 to index + %10 = arith.index_cast %9 : index to i64 + %11 = llvm.getelementptr %arg2[%10] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %12 = arith.muli %0, %c64_i32 : i32 + %13 = fly.make_int_tuple(%12) : (i32) -> !fly.int_tuple + %14 = arith.index_cast %12 : i32 to index + %15 = arith.index_cast %14 : index to i64 + %16 = llvm.getelementptr %arg3[%15] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %17 = fly.make_atom : () -> !fly.atom.universal_copy<32> + %18 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %19 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %20 = fly.make_layout(%18, %19) : (!fly.int_tuple<1>, !fly.int_tuple<1>) -> !fly.layout<1:1> + %c1_i64 = arith.constant 1 : i64 + %21 = llvm.alloca %c1_i64 x f32 : (i64) -> !llvm.ptr<5> + %22 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %23 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %24 = fly.make_layout(%22, %23) : (!fly.int_tuple<1>, !fly.int_tuple<1>) -> !fly.layout<1:1> + %c1_i64_0 = arith.constant 1 : i64 + %25 = llvm.alloca %c1_i64_0 x f32 : (i64) -> !llvm.ptr<5> + %26 = fly.make_int_tuple() : 
() -> !fly.int_tuple<1> + %27 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %28 = fly.make_layout(%26, %27) : (!fly.int_tuple<1>, !fly.int_tuple<1>) -> !fly.layout<1:1> + %c1_i64_1 = arith.constant 1 : i64 + %29 = llvm.alloca %c1_i64_1 x f32 : (i64) -> !llvm.ptr<5> + %30 = fly.make_int_tuple(%1) : (i32) -> !fly.int_tuple + %31 = arith.index_cast %1 : i32 to index + %32 = arith.index_cast %31 : index to i64 + %33 = llvm.getelementptr %6[%32] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %34 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %35 = fly.make_int_tuple() : () -> !fly.int_tuple<0> + %36 = fly.make_layout(%34, %35) : (!fly.int_tuple<1>, !fly.int_tuple<0>) -> !fly.layout<1:0> + %c4_i64 = arith.constant 4 : i64 + "llvm.intr.memcpy"(%21, %33, %c4_i64) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () + %37 = fly.make_int_tuple(%1) : (i32) -> !fly.int_tuple + %38 = arith.index_cast %1 : i32 to index + %39 = arith.index_cast %38 : index to i64 + %40 = llvm.getelementptr %11[%39] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %41 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %42 = fly.make_int_tuple() : () -> !fly.int_tuple<0> + %43 = fly.make_layout(%41, %42) : (!fly.int_tuple<1>, !fly.int_tuple<0>) -> !fly.layout<1:0> + %c4_i64_2 = arith.constant 4 : i64 + "llvm.intr.memcpy"(%25, %40, %c4_i64_2) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () + %44 = llvm.load %21 : !llvm.ptr<5> -> vector<1xf32> + %45 = llvm.load %25 : !llvm.ptr<5> -> vector<1xf32> + %46 = arith.addf %44, %45 : vector<1xf32> + llvm.store %46, %29 : vector<1xf32>, !llvm.ptr<5> + %47 = fly.make_int_tuple(%1) : (i32) -> !fly.int_tuple + %48 = arith.index_cast %1 : i32 to index + %49 = arith.index_cast %48 : index to i64 + %50 = llvm.getelementptr %16[%49] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %51 = fly.make_int_tuple() : () -> !fly.int_tuple<1> + %52 = fly.make_int_tuple() : () -> !fly.int_tuple<0> + %53 = fly.make_layout(%51, %52) : (!fly.int_tuple<1>, !fly.int_tuple<0>) -> !fly.layout<1:0> + %c4_i64_3 = arith.constant 4 : i64 + "llvm.intr.memcpy"(%50, %29, %c4_i64_3) <{isVolatile = false}> : (!llvm.ptr<1>, !llvm.ptr<5>, i64) -> () + gpu.return + } + } + func.func @vectorAdd(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>, %arg4: i32, %arg5: !gpu.async.token) attributes {llvm.emit_c_interface} { + %0 = llvm.mlir.undef : !llvm.struct, struct)> + %1 = llvm.mlir.undef : !llvm.struct + %2 = llvm.mlir.undef : !llvm.struct + %c64_i32 = arith.constant 64 : i32 + %c1_i32 = arith.constant 1 : i32 + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %3 = llvm.extractvalue %arg1[0, 0] : !llvm.struct, struct)> + %4 = arith.addi %arg4, %c64_i32 : i32 + %5 = arith.subi %4, %c1_i32 : i32 + %6 = arith.floordivsi %5, %c64_i32 : i32 + %7 = arith.index_cast %6 : i32 to index + %8 = llvm.insertvalue %3, %2[0] : !llvm.struct + %9 = llvm.insertvalue %8, %0[0] : !llvm.struct, struct)> + %10 = llvm.insertvalue %1, %9[1] : !llvm.struct, struct)> + gpu.launch_func <%arg5 : !gpu.async.token> @kernels::@vectorAddKernel_0 blocks in (%7, %c1, %c1) threads in (%c64, %c1, %c1) args(%arg0 : !llvm.ptr<1>, %10 : !llvm.struct, struct)>, %arg2 : !llvm.ptr<1>, %arg3 : !llvm.ptr<1>) + return + } +} diff --git a/ir_dump/05_Canonicalizer.mlir b/ir_dump/05_Canonicalizer.mlir new file mode 100644 index 00000000..227a8222 --- /dev/null +++ b/ir_dump/05_Canonicalizer.mlir @@ -0,0 +1,59 @@ +module attributes {gpu.container_module} { + gpu.module @kernels 
[#rocdl.target] { + gpu.func @vectorAddKernel_0(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) kernel { + %c4_i64 = arith.constant 4 : i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i32 = arith.constant 64 : i32 + %block_id_x = gpu.block_id x + %0 = arith.index_cast %block_id_x : index to i32 + %thread_id_x = gpu.thread_id x + %1 = arith.muli %0, %c64_i32 : i32 + %2 = arith.index_cast %1 : i32 to index + %3 = arith.index_cast %2 : index to i64 + %4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %5 = arith.muli %0, %c64_i32 : i32 + %6 = arith.index_cast %5 : i32 to index + %7 = arith.index_cast %6 : index to i64 + %8 = llvm.getelementptr %arg2[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %9 = arith.muli %0, %c64_i32 : i32 + %10 = arith.index_cast %9 : i32 to index + %11 = arith.index_cast %10 : index to i64 + %12 = llvm.getelementptr %arg3[%11] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %13 = llvm.alloca %c1_i64 x f32 : (i64) -> !llvm.ptr<5> + %14 = llvm.alloca %c1_i64 x f32 : (i64) -> !llvm.ptr<5> + %15 = llvm.alloca %c1_i64 x f32 : (i64) -> !llvm.ptr<5> + %16 = arith.index_cast %thread_id_x : index to i64 + %17 = llvm.getelementptr %4[%16] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%13, %17, %c4_i64) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () + %18 = arith.index_cast %thread_id_x : index to i64 + %19 = llvm.getelementptr %8[%18] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%14, %19, %c4_i64) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () + %20 = llvm.load %13 : !llvm.ptr<5> -> vector<1xf32> + %21 = llvm.load %14 : !llvm.ptr<5> -> vector<1xf32> + %22 = arith.addf %20, %21 : vector<1xf32> + llvm.store %22, %15 : vector<1xf32>, !llvm.ptr<5> + %23 = arith.index_cast %thread_id_x : index to i64 + %24 = llvm.getelementptr %12[%23] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%24, %15, %c4_i64) <{isVolatile = false}> : (!llvm.ptr<1>, !llvm.ptr<5>, i64) -> () + gpu.return + } + } + func.func @vectorAdd(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>, %arg4: i32, %arg5: !gpu.async.token) attributes {llvm.emit_c_interface} { + %c63_i32 = arith.constant 63 : i32 + %0 = llvm.mlir.undef : !llvm.struct, struct)> + %1 = llvm.mlir.undef : !llvm.struct + %2 = llvm.mlir.undef : !llvm.struct + %c64_i32 = arith.constant 64 : i32 + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %3 = llvm.extractvalue %arg1[0, 0] : !llvm.struct, struct)> + %4 = arith.addi %arg4, %c63_i32 : i32 + %5 = arith.floordivsi %4, %c64_i32 : i32 + %6 = arith.index_cast %5 : i32 to index + %7 = llvm.insertvalue %3, %2[0] : !llvm.struct + %8 = llvm.insertvalue %7, %0[0] : !llvm.struct, struct)> + %9 = llvm.insertvalue %1, %8[1] : !llvm.struct, struct)> + gpu.launch_func <%arg5 : !gpu.async.token> @kernels::@vectorAddKernel_0 blocks in (%6, %c1, %c1) threads in (%c64, %c1, %c1) args(%arg0 : !llvm.ptr<1>, %9 : !llvm.struct, struct)>, %arg2 : !llvm.ptr<1>, %arg3 : !llvm.ptr<1>) + return + } +} diff --git a/ir_dump/06_SCFToControlFlowPass.mlir b/ir_dump/06_SCFToControlFlowPass.mlir new file mode 100644 index 00000000..2dec6698 --- /dev/null +++ b/ir_dump/06_SCFToControlFlowPass.mlir @@ -0,0 +1,39 @@ +gpu.module @kernels [#rocdl.target] { + gpu.func @vectorAddKernel_0(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) kernel { + %c4_i64 = arith.constant 4 : 
i64 + %c1_i64 = arith.constant 1 : i64 + %c64_i32 = arith.constant 64 : i32 + %block_id_x = gpu.block_id x + %0 = arith.index_cast %block_id_x : index to i32 + %thread_id_x = gpu.thread_id x + %1 = arith.muli %0, %c64_i32 : i32 + %2 = arith.index_cast %1 : i32 to index + %3 = arith.index_cast %2 : index to i64 + %4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %5 = arith.muli %0, %c64_i32 : i32 + %6 = arith.index_cast %5 : i32 to index + %7 = arith.index_cast %6 : index to i64 + %8 = llvm.getelementptr %arg2[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %9 = arith.muli %0, %c64_i32 : i32 + %10 = arith.index_cast %9 : i32 to index + %11 = arith.index_cast %10 : index to i64 + %12 = llvm.getelementptr %arg3[%11] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %13 = llvm.alloca %c1_i64 x f32 : (i64) -> !llvm.ptr<5> + %14 = llvm.alloca %c1_i64 x f32 : (i64) -> !llvm.ptr<5> + %15 = llvm.alloca %c1_i64 x f32 : (i64) -> !llvm.ptr<5> + %16 = arith.index_cast %thread_id_x : index to i64 + %17 = llvm.getelementptr %4[%16] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%13, %17, %c4_i64) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () + %18 = arith.index_cast %thread_id_x : index to i64 + %19 = llvm.getelementptr %8[%18] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%14, %19, %c4_i64) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () + %20 = llvm.load %13 : !llvm.ptr<5> -> vector<1xf32> + %21 = llvm.load %14 : !llvm.ptr<5> -> vector<1xf32> + %22 = arith.addf %20, %21 : vector<1xf32> + llvm.store %22, %15 : vector<1xf32>, !llvm.ptr<5> + %23 = arith.index_cast %thread_id_x : index to i64 + %24 = llvm.getelementptr %12[%23] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%24, %15, %c4_i64) <{isVolatile = false}> : (!llvm.ptr<1>, !llvm.ptr<5>, i64) -> () + gpu.return + } +} diff --git a/ir_dump/07_ConvertGpuOpsToROCDLOps.mlir b/ir_dump/07_ConvertGpuOpsToROCDLOps.mlir new file mode 100644 index 00000000..6e862f77 --- /dev/null +++ b/ir_dump/07_ConvertGpuOpsToROCDLOps.mlir @@ -0,0 +1,35 @@ +gpu.module @kernels [#rocdl.target] attributes {llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9"} { + llvm.func @vectorAddKernel_0(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) attributes {gpu.kernel, rocdl.kernel} { + %0 = llvm.mlir.constant(4 : i64) : i64 + %1 = llvm.mlir.constant(1 : i64) : i64 + %2 = llvm.mlir.constant(64 : i32) : i32 + %3 = rocdl.workgroup.id.x : i32 + %4 = llvm.sext %3 : i32 to i64 + %5 = llvm.trunc %4 : i64 to i32 + %6 = rocdl.workitem.id.x : i32 + %7 = llvm.sext %6 : i32 to i64 + %8 = llvm.mul %5, %2 : i32 + %9 = llvm.sext %8 : i32 to i64 + %10 = llvm.getelementptr %arg0[%9] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %11 = llvm.mul %5, %2 : i32 + %12 = llvm.sext %11 : i32 to i64 + %13 = llvm.getelementptr %arg2[%12] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %14 = llvm.mul %5, %2 : i32 + %15 = llvm.sext %14 : i32 to i64 + %16 = llvm.getelementptr %arg3[%15] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %17 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %18 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %19 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %20 = llvm.getelementptr %10[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, 
f32 + "llvm.intr.memcpy"(%17, %20, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () + %21 = llvm.getelementptr %13[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%18, %21, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () + %22 = llvm.load %17 : !llvm.ptr<5> -> vector<1xf32> + %23 = llvm.load %18 : !llvm.ptr<5> -> vector<1xf32> + %24 = llvm.fadd %22, %23 : vector<1xf32> + llvm.store %24, %19 : vector<1xf32>, !llvm.ptr<5> + %25 = llvm.getelementptr %16[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%25, %19, %0) <{isVolatile = false}> : (!llvm.ptr<1>, !llvm.ptr<5>, i64) -> () + llvm.return + } +} diff --git a/ir_dump/08_GpuROCDLAttachTarget.mlir b/ir_dump/08_GpuROCDLAttachTarget.mlir new file mode 100644 index 00000000..496e5e86 --- /dev/null +++ b/ir_dump/08_GpuROCDLAttachTarget.mlir @@ -0,0 +1,55 @@ +module attributes {gpu.container_module} { + gpu.module @kernels [#rocdl.target] attributes {llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9"} { + llvm.func @vectorAddKernel_0(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) attributes {gpu.kernel, rocdl.kernel} { + %0 = llvm.mlir.constant(4 : i64) : i64 + %1 = llvm.mlir.constant(1 : i64) : i64 + %2 = llvm.mlir.constant(64 : i32) : i32 + %3 = rocdl.workgroup.id.x : i32 + %4 = llvm.sext %3 : i32 to i64 + %5 = llvm.trunc %4 : i64 to i32 + %6 = rocdl.workitem.id.x : i32 + %7 = llvm.sext %6 : i32 to i64 + %8 = llvm.mul %5, %2 : i32 + %9 = llvm.sext %8 : i32 to i64 + %10 = llvm.getelementptr %arg0[%9] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %11 = llvm.mul %5, %2 : i32 + %12 = llvm.sext %11 : i32 to i64 + %13 = llvm.getelementptr %arg2[%12] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %14 = llvm.mul %5, %2 : i32 + %15 = llvm.sext %14 : i32 to i64 + %16 = llvm.getelementptr %arg3[%15] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %17 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %18 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %19 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %20 = llvm.getelementptr %10[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%17, %20, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () + %21 = llvm.getelementptr %13[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%18, %21, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () + %22 = llvm.load %17 : !llvm.ptr<5> -> vector<1xf32> + %23 = llvm.load %18 : !llvm.ptr<5> -> vector<1xf32> + %24 = llvm.fadd %22, %23 : vector<1xf32> + llvm.store %24, %19 : vector<1xf32>, !llvm.ptr<5> + %25 = llvm.getelementptr %16[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%25, %19, %0) <{isVolatile = false}> : (!llvm.ptr<1>, !llvm.ptr<5>, i64) -> () + llvm.return + } + } + func.func @vectorAdd(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>, %arg4: i32, %arg5: !gpu.async.token) attributes {llvm.emit_c_interface} { + %c63_i32 = arith.constant 63 : i32 + %0 = llvm.mlir.undef : !llvm.struct, struct)> + %1 = llvm.mlir.undef : !llvm.struct + %2 = llvm.mlir.undef : !llvm.struct + %c64_i32 = arith.constant 64 : i32 + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %3 = llvm.extractvalue %arg1[0, 0] : 
!llvm.struct, struct)> + %4 = arith.addi %arg4, %c63_i32 : i32 + %5 = arith.floordivsi %4, %c64_i32 : i32 + %6 = arith.index_cast %5 : i32 to index + %7 = llvm.insertvalue %3, %2[0] : !llvm.struct + %8 = llvm.insertvalue %7, %0[0] : !llvm.struct, struct)> + %9 = llvm.insertvalue %1, %8[1] : !llvm.struct, struct)> + gpu.launch_func <%arg5 : !gpu.async.token> @kernels::@vectorAddKernel_0 blocks in (%6, %c1, %c1) threads in (%c64, %c1, %c1) args(%arg0 : !llvm.ptr<1>, %9 : !llvm.struct, struct)>, %arg2 : !llvm.ptr<1>, %arg3 : !llvm.ptr<1>) + return + } +} diff --git a/ir_dump/09_SCFToControlFlowPass.mlir b/ir_dump/09_SCFToControlFlowPass.mlir new file mode 100644 index 00000000..496e5e86 --- /dev/null +++ b/ir_dump/09_SCFToControlFlowPass.mlir @@ -0,0 +1,55 @@ +module attributes {gpu.container_module} { + gpu.module @kernels [#rocdl.target] attributes {llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9"} { + llvm.func @vectorAddKernel_0(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) attributes {gpu.kernel, rocdl.kernel} { + %0 = llvm.mlir.constant(4 : i64) : i64 + %1 = llvm.mlir.constant(1 : i64) : i64 + %2 = llvm.mlir.constant(64 : i32) : i32 + %3 = rocdl.workgroup.id.x : i32 + %4 = llvm.sext %3 : i32 to i64 + %5 = llvm.trunc %4 : i64 to i32 + %6 = rocdl.workitem.id.x : i32 + %7 = llvm.sext %6 : i32 to i64 + %8 = llvm.mul %5, %2 : i32 + %9 = llvm.sext %8 : i32 to i64 + %10 = llvm.getelementptr %arg0[%9] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %11 = llvm.mul %5, %2 : i32 + %12 = llvm.sext %11 : i32 to i64 + %13 = llvm.getelementptr %arg2[%12] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %14 = llvm.mul %5, %2 : i32 + %15 = llvm.sext %14 : i32 to i64 + %16 = llvm.getelementptr %arg3[%15] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %17 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %18 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %19 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %20 = llvm.getelementptr %10[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%17, %20, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () + %21 = llvm.getelementptr %13[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%18, %21, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () + %22 = llvm.load %17 : !llvm.ptr<5> -> vector<1xf32> + %23 = llvm.load %18 : !llvm.ptr<5> -> vector<1xf32> + %24 = llvm.fadd %22, %23 : vector<1xf32> + llvm.store %24, %19 : vector<1xf32>, !llvm.ptr<5> + %25 = llvm.getelementptr %16[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%25, %19, %0) <{isVolatile = false}> : (!llvm.ptr<1>, !llvm.ptr<5>, i64) -> () + llvm.return + } + } + func.func @vectorAdd(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>, %arg4: i32, %arg5: !gpu.async.token) attributes {llvm.emit_c_interface} { + %c63_i32 = arith.constant 63 : i32 + %0 = llvm.mlir.undef : !llvm.struct, struct)> + %1 = llvm.mlir.undef : !llvm.struct + %2 = llvm.mlir.undef : !llvm.struct + %c64_i32 = arith.constant 64 : i32 + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %3 = llvm.extractvalue %arg1[0, 0] : !llvm.struct, struct)> + %4 = arith.addi %arg4, %c63_i32 : i32 + %5 = arith.floordivsi %4, %c64_i32 : i32 + %6 = arith.index_cast 
%5 : i32 to index + %7 = llvm.insertvalue %3, %2[0] : !llvm.struct + %8 = llvm.insertvalue %7, %0[0] : !llvm.struct, struct)> + %9 = llvm.insertvalue %1, %8[1] : !llvm.struct, struct)> + gpu.launch_func <%arg5 : !gpu.async.token> @kernels::@vectorAddKernel_0 blocks in (%6, %c1, %c1) threads in (%c64, %c1, %c1) args(%arg0 : !llvm.ptr<1>, %9 : !llvm.struct, struct)>, %arg2 : !llvm.ptr<1>, %arg3 : !llvm.ptr<1>) + return + } +} diff --git a/ir_dump/10_ConvertControlFlowToLLVMPass.mlir b/ir_dump/10_ConvertControlFlowToLLVMPass.mlir new file mode 100644 index 00000000..496e5e86 --- /dev/null +++ b/ir_dump/10_ConvertControlFlowToLLVMPass.mlir @@ -0,0 +1,55 @@ +module attributes {gpu.container_module} { + gpu.module @kernels [#rocdl.target] attributes {llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9"} { + llvm.func @vectorAddKernel_0(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) attributes {gpu.kernel, rocdl.kernel} { + %0 = llvm.mlir.constant(4 : i64) : i64 + %1 = llvm.mlir.constant(1 : i64) : i64 + %2 = llvm.mlir.constant(64 : i32) : i32 + %3 = rocdl.workgroup.id.x : i32 + %4 = llvm.sext %3 : i32 to i64 + %5 = llvm.trunc %4 : i64 to i32 + %6 = rocdl.workitem.id.x : i32 + %7 = llvm.sext %6 : i32 to i64 + %8 = llvm.mul %5, %2 : i32 + %9 = llvm.sext %8 : i32 to i64 + %10 = llvm.getelementptr %arg0[%9] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %11 = llvm.mul %5, %2 : i32 + %12 = llvm.sext %11 : i32 to i64 + %13 = llvm.getelementptr %arg2[%12] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %14 = llvm.mul %5, %2 : i32 + %15 = llvm.sext %14 : i32 to i64 + %16 = llvm.getelementptr %arg3[%15] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %17 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %18 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %19 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %20 = llvm.getelementptr %10[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%17, %20, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () + %21 = llvm.getelementptr %13[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%18, %21, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () + %22 = llvm.load %17 : !llvm.ptr<5> -> vector<1xf32> + %23 = llvm.load %18 : !llvm.ptr<5> -> vector<1xf32> + %24 = llvm.fadd %22, %23 : vector<1xf32> + llvm.store %24, %19 : vector<1xf32>, !llvm.ptr<5> + %25 = llvm.getelementptr %16[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%25, %19, %0) <{isVolatile = false}> : (!llvm.ptr<1>, !llvm.ptr<5>, i64) -> () + llvm.return + } + } + func.func @vectorAdd(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>, %arg4: i32, %arg5: !gpu.async.token) attributes {llvm.emit_c_interface} { + %c63_i32 = arith.constant 63 : i32 + %0 = llvm.mlir.undef : !llvm.struct, struct)> + %1 = llvm.mlir.undef : !llvm.struct + %2 = llvm.mlir.undef : !llvm.struct + %c64_i32 = arith.constant 64 : i32 + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %3 = llvm.extractvalue %arg1[0, 0] : !llvm.struct, struct)> + %4 = arith.addi %arg4, %c63_i32 : i32 + %5 = arith.floordivsi %4, %c64_i32 : i32 + %6 = arith.index_cast %5 : i32 to index + %7 = llvm.insertvalue %3, %2[0] : !llvm.struct + %8 = llvm.insertvalue %7, %0[0] : 
!llvm.struct, struct)> + %9 = llvm.insertvalue %1, %8[1] : !llvm.struct, struct)> + gpu.launch_func <%arg5 : !gpu.async.token> @kernels::@vectorAddKernel_0 blocks in (%6, %c1, %c1) threads in (%c64, %c1, %c1) args(%arg0 : !llvm.ptr<1>, %9 : !llvm.struct, struct)>, %arg2 : !llvm.ptr<1>, %arg3 : !llvm.ptr<1>) + return + } +} diff --git a/ir_dump/11_FlyGpuStreamMarkPass.mlir b/ir_dump/11_FlyGpuStreamMarkPass.mlir new file mode 100644 index 00000000..ea5f7edc --- /dev/null +++ b/ir_dump/11_FlyGpuStreamMarkPass.mlir @@ -0,0 +1,55 @@ +module attributes {gpu.container_module} { + gpu.module @kernels [#rocdl.target] attributes {llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9"} { + llvm.func @vectorAddKernel_0(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) attributes {gpu.kernel, rocdl.kernel} { + %0 = llvm.mlir.constant(4 : i64) : i64 + %1 = llvm.mlir.constant(1 : i64) : i64 + %2 = llvm.mlir.constant(64 : i32) : i32 + %3 = rocdl.workgroup.id.x : i32 + %4 = llvm.sext %3 : i32 to i64 + %5 = llvm.trunc %4 : i64 to i32 + %6 = rocdl.workitem.id.x : i32 + %7 = llvm.sext %6 : i32 to i64 + %8 = llvm.mul %5, %2 : i32 + %9 = llvm.sext %8 : i32 to i64 + %10 = llvm.getelementptr %arg0[%9] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %11 = llvm.mul %5, %2 : i32 + %12 = llvm.sext %11 : i32 to i64 + %13 = llvm.getelementptr %arg2[%12] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %14 = llvm.mul %5, %2 : i32 + %15 = llvm.sext %14 : i32 to i64 + %16 = llvm.getelementptr %arg3[%15] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %17 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %18 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %19 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %20 = llvm.getelementptr %10[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%17, %20, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () + %21 = llvm.getelementptr %13[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%18, %21, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () + %22 = llvm.load %17 : !llvm.ptr<5> -> vector<1xf32> + %23 = llvm.load %18 : !llvm.ptr<5> -> vector<1xf32> + %24 = llvm.fadd %22, %23 : vector<1xf32> + llvm.store %24, %19 : vector<1xf32>, !llvm.ptr<5> + %25 = llvm.getelementptr %16[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%25, %19, %0) <{isVolatile = false}> : (!llvm.ptr<1>, !llvm.ptr<5>, i64) -> () + llvm.return + } + } + func.func @vectorAdd(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>, %arg4: i32, %arg5: !gpu.async.token) attributes {fly.stream_arg_index = 5 : index, llvm.emit_c_interface} { + %c63_i32 = arith.constant 63 : i32 + %0 = llvm.mlir.undef : !llvm.struct, struct)> + %1 = llvm.mlir.undef : !llvm.struct + %2 = llvm.mlir.undef : !llvm.struct + %c64_i32 = arith.constant 64 : i32 + %c1 = arith.constant 1 : index + %c64 = arith.constant 64 : index + %3 = llvm.extractvalue %arg1[0, 0] : !llvm.struct, struct)> + %4 = arith.addi %arg4, %c63_i32 : i32 + %5 = arith.floordivsi %4, %c64_i32 : i32 + %6 = arith.index_cast %5 : i32 to index + %7 = llvm.insertvalue %3, %2[0] : !llvm.struct + %8 = llvm.insertvalue %7, %0[0] : !llvm.struct, struct)> + %9 = llvm.insertvalue %1, %8[1] : !llvm.struct, struct)> + 
gpu.launch_func <%arg5 : !gpu.async.token> @kernels::@vectorAddKernel_0 blocks in (%6, %c1, %c1) threads in (%c64, %c1, %c1) args(%arg0 : !llvm.ptr<1>, %9 : !llvm.struct, struct)>, %arg2 : !llvm.ptr<1>, %arg3 : !llvm.ptr<1>) + return + } +} diff --git a/ir_dump/12_GpuToLLVMConversionPass.mlir b/ir_dump/12_GpuToLLVMConversionPass.mlir new file mode 100644 index 00000000..ab9375d7 --- /dev/null +++ b/ir_dump/12_GpuToLLVMConversionPass.mlir @@ -0,0 +1,65 @@ +module attributes {gpu.container_module} { + gpu.module @kernels [#rocdl.target] attributes {llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9"} { + llvm.func @vectorAddKernel_0(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) attributes {gpu.kernel, rocdl.kernel} { + %0 = llvm.mlir.constant(4 : i64) : i64 + %1 = llvm.mlir.constant(1 : i64) : i64 + %2 = llvm.mlir.constant(64 : i32) : i32 + %3 = rocdl.workgroup.id.x : i32 + %4 = llvm.sext %3 : i32 to i64 + %5 = llvm.trunc %4 : i64 to i32 + %6 = rocdl.workitem.id.x : i32 + %7 = llvm.sext %6 : i32 to i64 + %8 = llvm.mul %5, %2 : i32 + %9 = llvm.sext %8 : i32 to i64 + %10 = llvm.getelementptr %arg0[%9] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %11 = llvm.mul %5, %2 : i32 + %12 = llvm.sext %11 : i32 to i64 + %13 = llvm.getelementptr %arg2[%12] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %14 = llvm.mul %5, %2 : i32 + %15 = llvm.sext %14 : i32 to i64 + %16 = llvm.getelementptr %arg3[%15] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %17 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %18 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %19 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %20 = llvm.getelementptr %10[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%17, %20, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () + %21 = llvm.getelementptr %13[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%18, %21, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () + %22 = llvm.load %17 : !llvm.ptr<5> -> vector<1xf32> + %23 = llvm.load %18 : !llvm.ptr<5> -> vector<1xf32> + %24 = llvm.fadd %22, %23 : vector<1xf32> + llvm.store %24, %19 : vector<1xf32>, !llvm.ptr<5> + %25 = llvm.getelementptr %16[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%25, %19, %0) <{isVolatile = false}> : (!llvm.ptr<1>, !llvm.ptr<5>, i64) -> () + llvm.return + } + } + llvm.func @vectorAdd(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>, %arg4: i32, %arg5: !llvm.ptr) attributes {fly.stream_arg_index = 5 : index, llvm.emit_c_interface} { + %0 = llvm.mlir.constant(63 : i32) : i32 + %1 = llvm.mlir.undef : !llvm.struct, struct)> + %2 = llvm.mlir.undef : !llvm.struct + %3 = llvm.mlir.undef : !llvm.struct + %4 = llvm.mlir.constant(64 : i32) : i32 + %5 = llvm.mlir.constant(1 : index) : i64 + %6 = llvm.mlir.constant(64 : index) : i64 + %7 = llvm.extractvalue %arg1[0, 0] : !llvm.struct, struct)> + %8 = llvm.add %arg4, %0 : i32 + %9 = llvm.sdiv %8, %4 : i32 + %10 = llvm.mul %9, %4 : i32 + %11 = llvm.icmp "ne" %8, %10 : i32 + %12 = llvm.mlir.constant(0 : i32) : i32 + %13 = llvm.icmp "slt" %8, %12 : i32 + %14 = llvm.mlir.constant(false) : i1 + %15 = llvm.icmp "ne" %13, %14 : i1 + %16 = llvm.and %11, %15 : i1 + %17 = llvm.mlir.constant(-1 : i32) : i32 
+ %18 = llvm.add %9, %17 : i32 + %19 = llvm.select %16, %18, %9 : i1, i32 + %20 = llvm.sext %19 : i32 to i64 + %21 = llvm.insertvalue %7, %3[0] : !llvm.struct + %22 = llvm.insertvalue %21, %1[0] : !llvm.struct, struct)> + %23 = llvm.insertvalue %2, %22[1] : !llvm.struct, struct)> + gpu.launch_func @kernels::@vectorAddKernel_0 blocks in (%20, %5, %5) threads in (%6, %5, %5) : i64 args(%arg0 : !llvm.ptr<1>, %23 : !llvm.struct, struct)>, %arg2 : !llvm.ptr<1>, %arg3 : !llvm.ptr<1>) + llvm.return + } +} diff --git a/ir_dump/13_FlyGpuStreamInjectPass.mlir b/ir_dump/13_FlyGpuStreamInjectPass.mlir new file mode 100644 index 00000000..0144d3db --- /dev/null +++ b/ir_dump/13_FlyGpuStreamInjectPass.mlir @@ -0,0 +1,65 @@ +module attributes {gpu.container_module} { + gpu.module @kernels [#rocdl.target] attributes {llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9"} { + llvm.func @vectorAddKernel_0(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) attributes {gpu.kernel, rocdl.kernel} { + %0 = llvm.mlir.constant(4 : i64) : i64 + %1 = llvm.mlir.constant(1 : i64) : i64 + %2 = llvm.mlir.constant(64 : i32) : i32 + %3 = rocdl.workgroup.id.x : i32 + %4 = llvm.sext %3 : i32 to i64 + %5 = llvm.trunc %4 : i64 to i32 + %6 = rocdl.workitem.id.x : i32 + %7 = llvm.sext %6 : i32 to i64 + %8 = llvm.mul %5, %2 : i32 + %9 = llvm.sext %8 : i32 to i64 + %10 = llvm.getelementptr %arg0[%9] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %11 = llvm.mul %5, %2 : i32 + %12 = llvm.sext %11 : i32 to i64 + %13 = llvm.getelementptr %arg2[%12] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %14 = llvm.mul %5, %2 : i32 + %15 = llvm.sext %14 : i32 to i64 + %16 = llvm.getelementptr %arg3[%15] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %17 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %18 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %19 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %20 = llvm.getelementptr %10[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%17, %20, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () + %21 = llvm.getelementptr %13[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%18, %21, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () + %22 = llvm.load %17 : !llvm.ptr<5> -> vector<1xf32> + %23 = llvm.load %18 : !llvm.ptr<5> -> vector<1xf32> + %24 = llvm.fadd %22, %23 : vector<1xf32> + llvm.store %24, %19 : vector<1xf32>, !llvm.ptr<5> + %25 = llvm.getelementptr %16[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%25, %19, %0) <{isVolatile = false}> : (!llvm.ptr<1>, !llvm.ptr<5>, i64) -> () + llvm.return + } + } + llvm.func @vectorAdd(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>, %arg4: i32, %arg5: !llvm.ptr) attributes {llvm.emit_c_interface} { + %0 = llvm.mlir.constant(63 : i32) : i32 + %1 = llvm.mlir.undef : !llvm.struct, struct)> + %2 = llvm.mlir.undef : !llvm.struct + %3 = llvm.mlir.undef : !llvm.struct + %4 = llvm.mlir.constant(64 : i32) : i32 + %5 = llvm.mlir.constant(1 : index) : i64 + %6 = llvm.mlir.constant(64 : index) : i64 + %7 = llvm.extractvalue %arg1[0, 0] : !llvm.struct, struct)> + %8 = llvm.add %arg4, %0 : i32 + %9 = llvm.sdiv %8, %4 : i32 + %10 = llvm.mul %9, %4 : i32 + %11 = llvm.icmp "ne" %8, %10 : i32 + 
%12 = llvm.mlir.constant(0 : i32) : i32 + %13 = llvm.icmp "slt" %8, %12 : i32 + %14 = llvm.mlir.constant(false) : i1 + %15 = llvm.icmp "ne" %13, %14 : i1 + %16 = llvm.and %11, %15 : i1 + %17 = llvm.mlir.constant(-1 : i32) : i32 + %18 = llvm.add %9, %17 : i32 + %19 = llvm.select %16, %18, %9 : i1, i32 + %20 = llvm.sext %19 : i32 to i64 + %21 = llvm.insertvalue %7, %3[0] : !llvm.struct + %22 = llvm.insertvalue %21, %1[0] : !llvm.struct, struct)> + %23 = llvm.insertvalue %2, %22[1] : !llvm.struct, struct)> + gpu.launch_func <%arg5 : !llvm.ptr> @kernels::@vectorAddKernel_0 blocks in (%20, %5, %5) threads in (%6, %5, %5) : i64 args(%arg0 : !llvm.ptr<1>, %23 : !llvm.struct, struct)>, %arg2 : !llvm.ptr<1>, %arg3 : !llvm.ptr<1>) + llvm.return + } +} diff --git a/ir_dump/14_ArithToLLVMConversionPass.mlir b/ir_dump/14_ArithToLLVMConversionPass.mlir new file mode 100644 index 00000000..0144d3db --- /dev/null +++ b/ir_dump/14_ArithToLLVMConversionPass.mlir @@ -0,0 +1,65 @@ +module attributes {gpu.container_module} { + gpu.module @kernels [#rocdl.target] attributes {llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9"} { + llvm.func @vectorAddKernel_0(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) attributes {gpu.kernel, rocdl.kernel} { + %0 = llvm.mlir.constant(4 : i64) : i64 + %1 = llvm.mlir.constant(1 : i64) : i64 + %2 = llvm.mlir.constant(64 : i32) : i32 + %3 = rocdl.workgroup.id.x : i32 + %4 = llvm.sext %3 : i32 to i64 + %5 = llvm.trunc %4 : i64 to i32 + %6 = rocdl.workitem.id.x : i32 + %7 = llvm.sext %6 : i32 to i64 + %8 = llvm.mul %5, %2 : i32 + %9 = llvm.sext %8 : i32 to i64 + %10 = llvm.getelementptr %arg0[%9] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %11 = llvm.mul %5, %2 : i32 + %12 = llvm.sext %11 : i32 to i64 + %13 = llvm.getelementptr %arg2[%12] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %14 = llvm.mul %5, %2 : i32 + %15 = llvm.sext %14 : i32 to i64 + %16 = llvm.getelementptr %arg3[%15] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %17 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %18 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %19 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %20 = llvm.getelementptr %10[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%17, %20, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () + %21 = llvm.getelementptr %13[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%18, %21, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () + %22 = llvm.load %17 : !llvm.ptr<5> -> vector<1xf32> + %23 = llvm.load %18 : !llvm.ptr<5> -> vector<1xf32> + %24 = llvm.fadd %22, %23 : vector<1xf32> + llvm.store %24, %19 : vector<1xf32>, !llvm.ptr<5> + %25 = llvm.getelementptr %16[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%25, %19, %0) <{isVolatile = false}> : (!llvm.ptr<1>, !llvm.ptr<5>, i64) -> () + llvm.return + } + } + llvm.func @vectorAdd(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>, %arg4: i32, %arg5: !llvm.ptr) attributes {llvm.emit_c_interface} { + %0 = llvm.mlir.constant(63 : i32) : i32 + %1 = llvm.mlir.undef : !llvm.struct, struct)> + %2 = llvm.mlir.undef : !llvm.struct + %3 = llvm.mlir.undef : !llvm.struct + %4 = llvm.mlir.constant(64 : i32) : i32 + %5 = 
llvm.mlir.constant(1 : index) : i64 + %6 = llvm.mlir.constant(64 : index) : i64 + %7 = llvm.extractvalue %arg1[0, 0] : !llvm.struct, struct)> + %8 = llvm.add %arg4, %0 : i32 + %9 = llvm.sdiv %8, %4 : i32 + %10 = llvm.mul %9, %4 : i32 + %11 = llvm.icmp "ne" %8, %10 : i32 + %12 = llvm.mlir.constant(0 : i32) : i32 + %13 = llvm.icmp "slt" %8, %12 : i32 + %14 = llvm.mlir.constant(false) : i1 + %15 = llvm.icmp "ne" %13, %14 : i1 + %16 = llvm.and %11, %15 : i1 + %17 = llvm.mlir.constant(-1 : i32) : i32 + %18 = llvm.add %9, %17 : i32 + %19 = llvm.select %16, %18, %9 : i1, i32 + %20 = llvm.sext %19 : i32 to i64 + %21 = llvm.insertvalue %7, %3[0] : !llvm.struct + %22 = llvm.insertvalue %21, %1[0] : !llvm.struct, struct)> + %23 = llvm.insertvalue %2, %22[1] : !llvm.struct, struct)> + gpu.launch_func <%arg5 : !llvm.ptr> @kernels::@vectorAddKernel_0 blocks in (%20, %5, %5) threads in (%6, %5, %5) : i64 args(%arg0 : !llvm.ptr<1>, %23 : !llvm.struct, struct)>, %arg2 : !llvm.ptr<1>, %arg3 : !llvm.ptr<1>) + llvm.return + } +} diff --git a/ir_dump/15_ConvertFuncToLLVMPass.mlir b/ir_dump/15_ConvertFuncToLLVMPass.mlir new file mode 100644 index 00000000..0144d3db --- /dev/null +++ b/ir_dump/15_ConvertFuncToLLVMPass.mlir @@ -0,0 +1,65 @@ +module attributes {gpu.container_module} { + gpu.module @kernels [#rocdl.target] attributes {llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9"} { + llvm.func @vectorAddKernel_0(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) attributes {gpu.kernel, rocdl.kernel} { + %0 = llvm.mlir.constant(4 : i64) : i64 + %1 = llvm.mlir.constant(1 : i64) : i64 + %2 = llvm.mlir.constant(64 : i32) : i32 + %3 = rocdl.workgroup.id.x : i32 + %4 = llvm.sext %3 : i32 to i64 + %5 = llvm.trunc %4 : i64 to i32 + %6 = rocdl.workitem.id.x : i32 + %7 = llvm.sext %6 : i32 to i64 + %8 = llvm.mul %5, %2 : i32 + %9 = llvm.sext %8 : i32 to i64 + %10 = llvm.getelementptr %arg0[%9] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %11 = llvm.mul %5, %2 : i32 + %12 = llvm.sext %11 : i32 to i64 + %13 = llvm.getelementptr %arg2[%12] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %14 = llvm.mul %5, %2 : i32 + %15 = llvm.sext %14 : i32 to i64 + %16 = llvm.getelementptr %arg3[%15] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %17 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %18 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %19 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %20 = llvm.getelementptr %10[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%17, %20, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () + %21 = llvm.getelementptr %13[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%18, %21, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () + %22 = llvm.load %17 : !llvm.ptr<5> -> vector<1xf32> + %23 = llvm.load %18 : !llvm.ptr<5> -> vector<1xf32> + %24 = llvm.fadd %22, %23 : vector<1xf32> + llvm.store %24, %19 : vector<1xf32>, !llvm.ptr<5> + %25 = llvm.getelementptr %16[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%25, %19, %0) <{isVolatile = false}> : (!llvm.ptr<1>, !llvm.ptr<5>, i64) -> () + llvm.return + } + } + llvm.func @vectorAdd(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>, %arg4: i32, %arg5: !llvm.ptr) 
attributes {llvm.emit_c_interface} { + %0 = llvm.mlir.constant(63 : i32) : i32 + %1 = llvm.mlir.undef : !llvm.struct, struct)> + %2 = llvm.mlir.undef : !llvm.struct + %3 = llvm.mlir.undef : !llvm.struct + %4 = llvm.mlir.constant(64 : i32) : i32 + %5 = llvm.mlir.constant(1 : index) : i64 + %6 = llvm.mlir.constant(64 : index) : i64 + %7 = llvm.extractvalue %arg1[0, 0] : !llvm.struct, struct)> + %8 = llvm.add %arg4, %0 : i32 + %9 = llvm.sdiv %8, %4 : i32 + %10 = llvm.mul %9, %4 : i32 + %11 = llvm.icmp "ne" %8, %10 : i32 + %12 = llvm.mlir.constant(0 : i32) : i32 + %13 = llvm.icmp "slt" %8, %12 : i32 + %14 = llvm.mlir.constant(false) : i1 + %15 = llvm.icmp "ne" %13, %14 : i1 + %16 = llvm.and %11, %15 : i1 + %17 = llvm.mlir.constant(-1 : i32) : i32 + %18 = llvm.add %9, %17 : i32 + %19 = llvm.select %16, %18, %9 : i1, i32 + %20 = llvm.sext %19 : i32 to i64 + %21 = llvm.insertvalue %7, %3[0] : !llvm.struct + %22 = llvm.insertvalue %21, %1[0] : !llvm.struct, struct)> + %23 = llvm.insertvalue %2, %22[1] : !llvm.struct, struct)> + gpu.launch_func <%arg5 : !llvm.ptr> @kernels::@vectorAddKernel_0 blocks in (%20, %5, %5) threads in (%6, %5, %5) : i64 args(%arg0 : !llvm.ptr<1>, %23 : !llvm.struct, struct)>, %arg2 : !llvm.ptr<1>, %arg3 : !llvm.ptr<1>) + llvm.return + } +} diff --git a/ir_dump/16_ReconcileUnrealizedCastsPass.mlir b/ir_dump/16_ReconcileUnrealizedCastsPass.mlir new file mode 100644 index 00000000..0144d3db --- /dev/null +++ b/ir_dump/16_ReconcileUnrealizedCastsPass.mlir @@ -0,0 +1,65 @@ +module attributes {gpu.container_module} { + gpu.module @kernels [#rocdl.target] attributes {llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9"} { + llvm.func @vectorAddKernel_0(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) attributes {gpu.kernel, rocdl.kernel} { + %0 = llvm.mlir.constant(4 : i64) : i64 + %1 = llvm.mlir.constant(1 : i64) : i64 + %2 = llvm.mlir.constant(64 : i32) : i32 + %3 = rocdl.workgroup.id.x : i32 + %4 = llvm.sext %3 : i32 to i64 + %5 = llvm.trunc %4 : i64 to i32 + %6 = rocdl.workitem.id.x : i32 + %7 = llvm.sext %6 : i32 to i64 + %8 = llvm.mul %5, %2 : i32 + %9 = llvm.sext %8 : i32 to i64 + %10 = llvm.getelementptr %arg0[%9] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %11 = llvm.mul %5, %2 : i32 + %12 = llvm.sext %11 : i32 to i64 + %13 = llvm.getelementptr %arg2[%12] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %14 = llvm.mul %5, %2 : i32 + %15 = llvm.sext %14 : i32 to i64 + %16 = llvm.getelementptr %arg3[%15] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + %17 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %18 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %19 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> + %20 = llvm.getelementptr %10[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%17, %20, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () + %21 = llvm.getelementptr %13[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + "llvm.intr.memcpy"(%18, %21, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () + %22 = llvm.load %17 : !llvm.ptr<5> -> vector<1xf32> + %23 = llvm.load %18 : !llvm.ptr<5> -> vector<1xf32> + %24 = llvm.fadd %22, %23 : vector<1xf32> + llvm.store %24, %19 : vector<1xf32>, !llvm.ptr<5> + %25 = llvm.getelementptr %16[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 + 
"llvm.intr.memcpy"(%25, %19, %0) <{isVolatile = false}> : (!llvm.ptr<1>, !llvm.ptr<5>, i64) -> () + llvm.return + } + } + llvm.func @vectorAdd(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>, %arg4: i32, %arg5: !llvm.ptr) attributes {llvm.emit_c_interface} { + %0 = llvm.mlir.constant(63 : i32) : i32 + %1 = llvm.mlir.undef : !llvm.struct, struct)> + %2 = llvm.mlir.undef : !llvm.struct + %3 = llvm.mlir.undef : !llvm.struct + %4 = llvm.mlir.constant(64 : i32) : i32 + %5 = llvm.mlir.constant(1 : index) : i64 + %6 = llvm.mlir.constant(64 : index) : i64 + %7 = llvm.extractvalue %arg1[0, 0] : !llvm.struct, struct)> + %8 = llvm.add %arg4, %0 : i32 + %9 = llvm.sdiv %8, %4 : i32 + %10 = llvm.mul %9, %4 : i32 + %11 = llvm.icmp "ne" %8, %10 : i32 + %12 = llvm.mlir.constant(0 : i32) : i32 + %13 = llvm.icmp "slt" %8, %12 : i32 + %14 = llvm.mlir.constant(false) : i1 + %15 = llvm.icmp "ne" %13, %14 : i1 + %16 = llvm.and %11, %15 : i1 + %17 = llvm.mlir.constant(-1 : i32) : i32 + %18 = llvm.add %9, %17 : i32 + %19 = llvm.select %16, %18, %9 : i1, i32 + %20 = llvm.sext %19 : i32 to i64 + %21 = llvm.insertvalue %7, %3[0] : !llvm.struct + %22 = llvm.insertvalue %21, %1[0] : !llvm.struct, struct)> + %23 = llvm.insertvalue %2, %22[1] : !llvm.struct, struct)> + gpu.launch_func <%arg5 : !llvm.ptr> @kernels::@vectorAddKernel_0 blocks in (%20, %5, %5) threads in (%6, %5, %5) : i64 args(%arg0 : !llvm.ptr<1>, %23 : !llvm.struct, struct)>, %arg2 : !llvm.ptr<1>, %arg3 : !llvm.ptr<1>) + llvm.return + } +} diff --git a/ir_dump/17_GpuModuleToBinaryPass.mlir b/ir_dump/17_GpuModuleToBinaryPass.mlir new file mode 100644 index 00000000..354e120d --- /dev/null +++ b/ir_dump/17_GpuModuleToBinaryPass.mlir @@ -0,0 +1,44 @@ +module attributes {gpu.container_module} { + gpu.binary @kernels [#gpu.object<#rocdl.target, kernels = <[#gpu.kernel_metadata<"vectorAddKernel_0", !llvm.func, struct, struct)>, ptr<1>, ptr<1>)>, metadata = {agpr_count = 0 : i64, group_segment_fixed_size = 0 : i64, max_flat_workgroup_size = 256 : i64, private_segment_fixed_size = 0 : i64, reqd_workgroup_size = array, sgpr_count = 16 : i64, sgpr_spill_count = 0 : i64, vgpr_count = 3 : i64, vgpr_spill_count = 0 : i64, wavefront_size = 64 : i64, workgroup_size_hint = array}>]>, bin = 
"\7FELF\02\01\01@\04\00\00\00\00\00\00\00\03\00\E0\00\01\00\00\00\00\00\00\00\00\00\00\00@\00\00\00\00\00\00\00\A8\0F\00\00\00\00\00\00L\05\00\00@\008\00\08\00@\00\0F\00\0D\00\06\00\00\00\04\00\00\00@\00\00\00\00\00\00\00@\00\00\00\00\00\00\00@\00\00\00\00\00\00\00\C0\01\00\00\00\00\00\00\C0\01\00\00\00\00\00\00\08\00\00\00\00\00\00\00\01\00\00\00\04\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\C0\05\00\00\00\00\00\00\C0\05\00\00\00\00\00\00\00\10\00\00\00\00\00\00\01\00\00\00\05\00\00\00\00\06\00\00\00\00\00\00\00\16\00\00\00\00\00\00\00\16\00\00\00\00\00\00\80\04\00\00\00\00\00\00\80\04\00\00\00\00\00\00\00\10\00\00\00\00\00\00\01\00\00\00\06\00\00\00\80\0A\00\00\00\00\00\00\80*\00\00\00\00\00\00\80*\00\00\00\00\00\00p\00\00\00\00\00\00\00\80\05\00\00\00\00\00\00\00\10\00\00\00\00\00\00\02\00\00\00\06\00\00\00\80\0A\00\00\00\00\00\00\80*\00\00\00\00\00\00\80*\00\00\00\00\00\00p\00\00\00\00\00\00\00p\00\00\00\00\00\00\00\08\00\00\00\00\00\00\00R\E5td\04\00\00\00\80\0A\00\00\00\00\00\00\80*\00\00\00\00\00\00\80*\00\00\00\00\00\00p\00\00\00\00\00\00\00\80\05\00\00\00\00\00\00\01\00\00\00\00\00\00\00Q\E5td\06\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\04\00\00\00\04\00\00\00\00\02\00\00\00\00\00\00\00\02\00\00\00\00\00\00\00\02\00\00\00\00\00\00\98\02\00\00\00\00\00\00\98\02\00\00\00\00\00\00\04\00\00\00\00\00\00\00\07\00\00\00\83\02\00\00 \00\00\00AMDGPU\00\00\83\AEamdhsa.kernels\91\DE\00\10\AB.agpr_count\00\A5.args\94\84\AE.address_space\A6global\A7.offset\00\A5.size\08\AB.value_kind\ADglobal_buffer\83\A7.offset\08\A5.size\04\AB.value_kind\A8by_value\84\AE.address_space\A6global\A7.offset\10\A5.size\08\AB.value_kind\ADglobal_buffer\84\AE.address_space\A6global\A7.offset\18\A5.size\08\AB.value_kind\ADglobal_buffer\B9.group_segment_fixed_size\00\B6.kernarg_segment_align\08\B5.kernarg_segment_size \B8.max_flat_workgroup_size\CD\01\00\A5.name\B1vectorAddKernel_0\BB.private_segment_fixed_size\00\AB.sgpr_count\10\B1.sgpr_spill_count\00\A7.symbol\B4vectorAddKernel_0.kd\B8.uniform_work_group_size\01\B3.uses_dynamic_stack\C2\AB.vgpr_count\03\B1.vgpr_spill_count\00\AF.wavefront_size@\ADamdhsa.target\B9amdgcn-amd-amdhsa--gfx942\AEamdhsa.version\92\01\02\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\12\03\07\00\00\16\00\00\00\00\00\00`\00\00\00\00\00\00\00\13\00\00\00\11\00\06\00\80\05\00\00\00\00\00\00@\00\00\00\00\00\00\00\01\00\00\00\01\00\00\00\01\00\00\00\1A\00\00\00\00\00\00\00\00\C0\02\00\01\00\00\00\B0\CA%\C6\EFj+\BF\03\00\00\00\03\00\00\00\01\00\00\00\02\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00vectorAddKernel_0\00vectorAddKernel_0.kd\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00 
\00\00\00\00\00\00\00\80\10\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00@\00\AF\00\84\00\00\00\08\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\02\06\C0\00\00\00\00\00\01\0A\C0\10\00\00\00\02\86\00\8E\00\9F\01\90\00\82\80\8E\7F\C0\8C\BF\08\00\02\80\09\01\03\82\04\00\04\80\82\00\00$\05\01\05\82\00\80P\DC\00\00\02\01\00\80P\DC\00\00\04\02\06\00\00\80\07\01\01\82p\0F\8C\BF\01\05\02\02\00\80p\DC\00\01\00\00\00\00\81\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00
\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\06\00\00\00\00\00\00\00\98\04\00\00\00\00\00\00\0B\00\00\00\00\00\00\00\18\00\00\00\00\00\00\00\05\00\00\00\00\00\00\00$\05\00\00\00\00\00\00\0A\00\00\00\00\00\00\00(\00\00\00\00\00\00\00\F5\FE\FFo\00\00\00\00\E0\04\00\00\00\00\00\00\04\00\00\00\00\00\00\00\04\05\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00Linker: AMD LLD 20.0.0 (/longer_pathname_so_that_rpms_can_support_packaging_the_debug_info_for_all_os_profiles/src/llvm-project/llvm 1b0eada6b0ee93e2e694c8c146d23fca90bc11c5)\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\00\00\F1\FF\03\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\1C\00\00\00\00\00\F1\FF\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\007\00\00\00\00\00\F1\FF\0A\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00W\00\00\00\00\00\F1\FF\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00{\00\00\00\00\00\F1\FF\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\9E\00\00\00\00\00\F1\FF\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\B9\00\00\00\00\00\F1\FF\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\DD\00\00\00\00\00\F1\FF\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\03\01\00\00\00\00\F1\FF\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00#\01\00\00\00\00\F1\FF\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00G\01\00\00\00\00\F1\FF\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00[\01\00\00\00\00\F1\FF\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00o\01\00\00\00\00\F1\FF\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\83\01\00\00\00\00\F1\FF\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\C7\01\00\00\00\02\08\00\80*\00\00\00\00\00\00\00\00\00\00\00\00\00\00\A0\01\00\00\12\03\07\00\00\16\00\00\00\00\00\00`\00\00\00\00\00\00\00\B2\01\00\00\11\00\06\00\80\05\00\00\00\00\00\00@\00\00\00\00\00\00\00\00.note\00.dynsym\00.gnu.hash\00.hash\00.dynstr\00.rodata\00.text\00.dynamic\00.relro_padding\00.AMDGPU.gpr_maximums\00.comment\00.symtab\00.shstrtab\00.strtab\00\00vectorAddKernel_0.num_vgpr\00vectorAddKernel_0.num_agpr\00vectorAddKernel_0.numbered_sgpr\00vectorAddKernel_0.num_named_barrier\00vectorAddKernel_0.private_seg_size\00vectorAddKernel_0.uses_vcc\00vectorAddKernel_0.uses_flat_scratch\00vectorAddKernel_0.has_dyn_sized_stack\00vectorAddKernel_0.has_recursion\00vectorAddKernel_0.has_indirect_call\00amdgpu.max_num_vgpr\00amdgpu.max_num_agpr\00amdgpu.max_num_sgpr\00amdgpu.max_num_named_barrier\00vectorAddKernel_0\00vectorAddKernel_0.kd\00_DYNAMIC\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\07\00\00\00\02\00\00\00\00\00\00\00\00\02\00\00\00\00\00\00\00\02\00\00\00\00\00\00\98\02\00\00\00\00\00\00\00\00\00\00\00\00\00\00\04\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\07\00\00\00\0B\00\00\00\02\00\00\00\00\00\00\00\98\04\00\00\00\00\00\00\98\04\00\00\00\00\00\00H\00\00\00\00\00\00\00\05\00\00\00\01\00\00\00\08\00\00\00\00\00\00\00\18\00\00\00\00\00\00\00\0F\00\00\00\F6\FF\FFo\02\00\00\00\00\00\00\00\E0\04\00\00\00\00\00\00\E0\04\00\00\00\00\00\00$\00\00\00\00\00\00\00\02\00\00\00\00\00\00\00\08\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\19\00\00\00\05\00\00\00\02\00\00\00\00\00\00\00\04\05\00\00\00\00\00\00\04\05\00\00\00\0
0\00\00 \00\00\00\00\00\00\00\02\00\00\00\00\00\00\00\04\00\00\00\00\00\00\00\04\00\00\00\00\00\00\00\1F\00\00\00\03\00\00\00\02\00\00\00\00\00\00\00$\05\00\00\00\00\00\00$\05\00\00\00\00\00\00(\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00'\00\00\00\01\00\00\00\02\00\00\00\00\00\00\00\80\05\00\00\00\00\00\00\80\05\00\00\00\00\00\00@\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00@\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00/\00\00\00\01\00\00\00\06\00\00\00\00\00\00\00\00\16\00\00\00\00\00\00\00\06\00\00\00\00\00\00\80\04\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\005\00\00\00\06\00\00\00\03\00\00\00\00\00\00\00\80*\00\00\00\00\00\00\80\0A\00\00\00\00\00\00p\00\00\00\00\00\00\00\05\00\00\00\00\00\00\00\08\00\00\00\00\00\00\00\10\00\00\00\00\00\00\00>\00\00\00\08\00\00\00\03\00\00\00\00\00\00\00\F0*\00\00\00\00\00\00\F0\0A\00\00\00\00\00\00\10\05\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00M\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\F0\0A\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00b\00\00\00\01\00\00\000\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\F0\0A\00\00\00\00\00\00\AF\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\00\00\00\00\01\00\00\00\00\00\00\00k\00\00\00\02\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\A0\0B\00\00\00\00\00\00\B0\01\00\00\00\00\00\00\0E\00\00\00\10\00\00\00\08\00\00\00\00\00\00\00\18\00\00\00\00\00\00\00s\00\00\00\03\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00P\0D\00\00\00\00\00\00\85\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00}\00\00\00\03\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\D5\0D\00\00\00\00\00\00\D0\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00">] + llvm.func @vectorAdd(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>, %arg4: i32, %arg5: !llvm.ptr) attributes {llvm.emit_c_interface} { + %0 = llvm.mlir.constant(63 : i32) : i32 + %1 = llvm.mlir.undef : !llvm.struct, struct)> + %2 = llvm.mlir.undef : !llvm.struct + %3 = llvm.mlir.undef : !llvm.struct + %4 = llvm.mlir.constant(64 : i32) : i32 + %5 = llvm.mlir.constant(1 : index) : i64 + %6 = llvm.mlir.constant(64 : index) : i64 + %7 = llvm.extractvalue %arg1[0, 0] : !llvm.struct, struct)> + %8 = llvm.add %arg4, %0 : i32 + %9 = llvm.sdiv %8, %4 : i32 + %10 = llvm.mul %9, %4 : i32 + %11 = llvm.icmp "ne" %8, %10 : i32 + %12 = llvm.mlir.constant(0 : i32) : i32 + %13 = llvm.icmp "slt" %8, %12 : i32 + %14 = llvm.mlir.constant(false) : i1 + %15 = llvm.icmp "ne" %13, %14 : i1 + %16 = llvm.and %11, %15 : i1 + %17 = llvm.mlir.constant(-1 : i32) : i32 + %18 = llvm.add %9, %17 : i32 + %19 = llvm.select %16, %18, %9 : i1, i32 + %20 = llvm.sext %19 : i32 to i64 + %21 = llvm.insertvalue %7, %3[0] : !llvm.struct + %22 = llvm.insertvalue %21, %1[0] : !llvm.struct, struct)> + %23 = llvm.insertvalue %2, %22[1] : !llvm.struct, struct)> + gpu.launch_func <%arg5 : !llvm.ptr> @kernels::@vectorAddKernel_0 blocks in (%20, %5, %5) threads in (%6, %5, %5) : i64 args(%arg0 : !llvm.ptr<1>, %23 : !llvm.struct, struct)>, %arg2 : !llvm.ptr<1>, %arg3 : !llvm.ptr<1>) + llvm.return + } +} + + +================================================== +Test 1: Eager execution +================================================== +[Eager] 
Result correct: True + +================================================== +Test 2: CUDA Graph Capture +================================================== +[Graph Capture] Result correct: True + +All passed: True diff --git a/kernels/wmma_gemm_gfx1250.py b/kernels/wmma_gemm_gfx1250.py new file mode 100644 index 00000000..cc3e8892 --- /dev/null +++ b/kernels/wmma_gemm_gfx1250.py @@ -0,0 +1,556 @@ +"""TDM async copy WMMA GEMM kernel for gfx1250. + +Supports double-buffer (2-stage) and triple-buffer (3-stage) pipelining +with TDM (Tensor Data Mover) hardware async copy for both A and B tiles. +""" + +import flydsl.compiler as flyc +import flydsl.expr as fx + +from flydsl._mlir import ir +from flydsl.compiler.kernel_function import CompilationContext +from flydsl.expr import arith, buffer_ops, gpu, range_constexpr, rocdl, tdm_ops, vector +from flydsl.expr.arith import _to_raw as _raw +from flydsl.expr.typing import T +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +from flydsl._mlir.dialects import memref as memref_d +from flydsl.expr.gpu import lds_space +from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr, get_op_result_or_value +from flydsl._mlir.extras import types as mlir_T + +from kernels.layout_utils import crd2idx, idx2crd + +WMMA_M, WMMA_N, WMMA_K = 16, 16, 32 +WAVE_SIZE = 32 + +LDS_PAD_A = 8 +LDS_PAD_B = 8 + + +def compile_wmma_gemm_tdm( + *, + M: int = 0, + N: int = 0, + K: int, + tile_m: int = 256, + tile_n: int = 256, + tile_k: int = 128, + m_warp: int = 2, + n_warp: int = 4, + in_dtype: str = "fp16", + num_buffers: int = 2, + waves_per_eu: int = None, + l2_prefetch_distance: int = 2, +): + """Compile a WMMA GEMM kernel with TDM async copy and multi-stage buffering. + + Returns a JitFunction: launch_fn(arg_c, arg_a, arg_b, M, N, stream) + + Args: + num_buffers: Number of LDS buffers (2=double, 3=triple buffering). + waves_per_eu: Occupancy hint (None = default, 1-4 = limit occupancy). + l2_prefetch_distance: Number of k-tiles ahead to prefetch into L2. + 0 = disabled, 2 = typical value. 
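+
+    Example (illustrative sketch only, not taken from the test suite; assumes
+    contiguous fp16 torch tensors on a gfx1250 device, with B laid out N x K
+    as the TDM descriptors below expect):
+
+        launch_fn = compile_wmma_gemm_tdm(K=4096, num_buffers=3)
+        launch_fn(c, a, b, M, N, stream)  # c: MxN f32, a: MxK, b: NxK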
+ """ + _ = (M, N) + if num_buffers not in (2, 3): + raise ValueError(f"num_buffers must be 2 or 3, got {num_buffers}") + use_triple_buffer = num_buffers == 3 + if in_dtype not in ("fp16", "bf16"): + raise ValueError(f"in_dtype must be 'fp16' or 'bf16', got {in_dtype!r}") + is_f16 = in_dtype == "fp16" + elem_bytes = 2 + + block_threads = m_warp * n_warp * WAVE_SIZE + + if K % tile_k != 0: + raise ValueError(f"K must be divisible by tile_k={tile_k}, got K={K}") + if tile_k % WMMA_K != 0: + raise ValueError(f"tile_k must be a multiple of {WMMA_K}, got {tile_k}") + if tile_m % WMMA_M != 0: + raise ValueError(f"tile_m must be a multiple of {WMMA_M}, got {tile_m}") + if tile_n % WMMA_N != 0: + raise ValueError(f"tile_n must be a multiple of {WMMA_N}, got {tile_n}") + if (tile_k & (tile_k - 1)) != 0: + raise ValueError(f"tile_k must be a power of 2 for TDM async copy, got {tile_k}") + + warp_tile_m = tile_m // m_warp + warp_tile_n = tile_n // n_warp + if warp_tile_m % WMMA_M != 0: + raise ValueError(f"warp_tile_m={warp_tile_m} must be a multiple of {WMMA_M}") + if warp_tile_n % WMMA_N != 0: + raise ValueError(f"warp_tile_n={warp_tile_n} must be a multiple of {WMMA_N}") + + num_k_tiles = K // tile_k + if num_k_tiles < num_buffers: + raise ValueError( + f"{num_buffers}-stage buffering requires num_k_tiles >= {num_buffers}, " + f"got {num_k_tiles} (K={K}, tile_k={tile_k})") + + gpu_arch = str(get_hip_arch(timeout_s=300)) + assert gpu_arch.startswith("gfx1250"), f"Expected gfx1250, got {gpu_arch}" + + wmma_op = rocdl.wmma_f32_16x16x32_f16 if is_f16 else rocdl.wmma_f32_16x16x32_bf16 + k_wmma_steps = tile_k // WMMA_K + + def _elem_type(): + return T.f16 if is_f16 else T.bf16 + + wmma_m_rep = warp_tile_m // WMMA_M + wmma_n_rep = warp_tile_n // WMMA_N + n_accs = wmma_m_rep * wmma_n_rep + + lds_a_stride = tile_k + LDS_PAD_A + lds_b_stride = tile_k + LDS_PAD_B + lds_a_elems = tile_m * lds_a_stride + LDS_PAD_A + lds_b_elems = tile_n * lds_b_stride + LDS_PAD_A + + buf_size_elems = lds_a_elems + lds_b_elems + + # --- LDS allocation --- + num_warps = m_warp * n_warp + + if use_triple_buffer: + # Triple-buffer: 3 separate allocators (ping/pong/pang) + allocator_ping = SmemAllocator(None, arch=gpu_arch, global_sym_name="wmma_tdm_ping") + allocator_pong = SmemAllocator(None, arch=gpu_arch, global_sym_name="wmma_tdm_pong") + allocator_pang = SmemAllocator(None, arch=gpu_arch, global_sym_name="wmma_tdm_pang") + + ping_offset = allocator_ping._align(allocator_ping.ptr, 16) + allocator_ping.ptr = ping_offset + buf_size_elems * elem_bytes + pong_offset = allocator_pong._align(allocator_pong.ptr, 16) + allocator_pong.ptr = pong_offset + buf_size_elems * elem_bytes + pang_offset = allocator_pang._align(allocator_pang.ptr, 16) + allocator_pang.ptr = pang_offset + buf_size_elems * elem_bytes + + lds_a_offset_ping = ping_offset + lds_b_offset_ping = ping_offset + lds_a_elems * elem_bytes + lds_a_offset_pong = pong_offset + lds_b_offset_pong = pong_offset + lds_a_elems * elem_bytes + lds_a_offset_pang = pang_offset + lds_b_offset_pang = pang_offset + lds_a_elems * elem_bytes + + allocator_dbuf = None + else: + # Double-buffer: unified allocator with dynamic buffer selection + allocator_ping = None + allocator_pong = None + allocator_pang = None + + allocator_dbuf = SmemAllocator(None, arch=gpu_arch, global_sym_name="wmma_tdm_dbuf") + _dbuf0_off = allocator_dbuf._align(allocator_dbuf.ptr, 16) + allocator_dbuf.ptr = _dbuf0_off + buf_size_elems * elem_bytes + _dbuf1_off = allocator_dbuf._align(allocator_dbuf.ptr, 16) + 
allocator_dbuf.ptr = _dbuf1_off + buf_size_elems * elem_bytes + + lds_a_off_b0 = _dbuf0_off + lds_b_off_b0 = _dbuf0_off + lds_a_elems * elem_bytes + lds_a_off_b1 = _dbuf1_off + lds_b_off_b1 = _dbuf1_off + lds_a_elems * elem_bytes + + @flyc.kernel + def kernel_wmma_gemm_tdm( + arg_c: fx.Tensor, + arg_a: fx.Tensor, + arg_b: fx.Tensor, + i32_m: fx.Int32, + i32_n: fx.Int32, + ): + tx = gpu.thread_id("x") + bx = gpu.block_id("x") + by = gpu.block_id("y") + + n_stride = arith.index_cast(T.index, i32_n.ir_value()) + blk_m = bx * arith.index(tile_m) + blk_n = by * arith.index(tile_n) + + # --- Thread/wave decomposition --- + layout_thr = fx.make_layout( + (m_warp, n_warp, 2, 16), + (n_warp * WAVE_SIZE, WAVE_SIZE, 16, 1)) + thr_coord = idx2crd(tx, layout_thr) + wave_m_idx, wave_n_idx, lane_kgrp, lane16 = ( + thr_coord[0], thr_coord[1], thr_coord[2], thr_coord[3]) + + warp_m_base = wave_m_idx * arith.index(warp_tile_m) + warp_n_base = wave_n_idx * arith.index(warp_tile_n) + + elem_ty = _elem_type() + + # --- Buffer resources --- + m_idx = arith.index_cast(T.index, i32_m.ir_value()) + a_nrec = m_idx * arith.index(K * elem_bytes) + b_nrec = n_stride * arith.index(K * elem_bytes) + c_nrec = m_idx * n_stride * arith.index(4) # f32 output + a_rsrc = buffer_ops.create_buffer_resource(arg_a, num_records_bytes=a_nrec) + b_rsrc = buffer_ops.create_buffer_resource(arg_b, num_records_bytes=b_nrec) + c_rsrc = buffer_ops.create_buffer_resource(arg_c, num_records_bytes=c_nrec) + + # --- TDM async copy helpers --- + def copy_a_to_lds(k_base, lds_a_mem_ref): + desc = tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_a, lds_memref=lds_a_mem_ref, + global_offset=(blk_m, k_base), + tensor_shape=(tile_m, tile_k), strides=(K, 1), + tile_shape=(tile_m, tile_k), elem_bytes=elem_bytes, + pad_interval=tile_k, pad_amount=LDS_PAD_A, + num_warps=num_warps) + tdm_ops.tensor_load_2d(desc) + + def copy_b_to_lds(k_base, lds_b_mem_ref): + desc = tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_b, lds_memref=lds_b_mem_ref, + global_offset=(blk_n, k_base), + tensor_shape=(tile_n, tile_k), strides=(K, 1), + tile_shape=(tile_n, tile_k), elem_bytes=elem_bytes, + pad_interval=tile_k, pad_amount=LDS_PAD_B, + num_warps=num_warps) + tdm_ops.tensor_load_2d(desc) + + layout_smem_a = fx.make_layout((tile_m, lds_a_stride), (lds_a_stride, 1)) + layout_smem_b = fx.make_layout((tile_n, lds_b_stride), (lds_b_stride, 1)) + + # --- LDS load helpers --- + from flydsl._mlir.dialects import vector as vec_d + + FRAG_K_ELEMS = 8 + + def _get_lds_memref(lds_ptr): + """Get the raw memref value from SmemPtr or raw memref.""" + if isinstance(lds_ptr, SmemPtr): + return get_op_result_or_value(lds_ptr.get()) + return get_op_result_or_value(lds_ptr) + + def load_wmma_frag(lds_ptr, row_base, k_base, lds_layout): + """Load one 16x32 WMMA fragment from LDS using vectorized 128-bit loads. + + Uses vector.load to read 8 contiguous fp16 elements at once, + avoiding scalar load + v_perm/v_alignbit overhead. + Two 128-bit loads per fragment (2 × 8 fp16 = 16 fp16 values). 
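+
+        Per-lane addressing (restating the index math below, where
+        lane16 = lane % 16 and lane_kgrp = lane // 16 in a 32-lane wave):
+
+            load 0: row = row_base + lane16, k = k_base + lane_kgrp * 8
+            load 1: row = row_base + lane16, k = k_base + (2 + lane_kgrp) * 8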
+ """ + raw_memref = _get_lds_memref(lds_ptr) + row = row_base + lane16 + vec8_ty = ir.VectorType.get([8], elem_ty) + + # Two K-groups per fragment: + # Group 0 (values 0-7): k = k_base + lane_kgrp * 8 + # Group 1 (values 8-15): k = k_base + (2 + lane_kgrp) * 8 + k0 = k_base + lane_kgrp * arith.index(8) + k1 = k_base + (arith.index(2) + lane_kgrp) * arith.index(8) + + off0 = crd2idx((row, k0), lds_layout) + off1 = crd2idx((row, k1), lds_layout) + + idx0 = [get_op_result_or_value(off0)] + idx1 = [get_op_result_or_value(off1)] + + v0 = vec_d.load(vec8_ty, raw_memref, idx0) + v1 = vec_d.load(vec8_ty, raw_memref, idx1) + + # Concatenate two vec<8> into vec<16> via vector.shuffle + mask = ir.DenseI64ArrayAttr.get(list(range(16))) + return vec_d.shuffle(v0, v1, mask) + + # --- K-subtile load/compute helpers --- + # Number of ds_load_b128 per K-subtile: + # B frags: wmma_n_rep * 2, A frags: wmma_m_rep * 2 + LOADS_PER_SUBTILE = (wmma_m_rep + wmma_n_rep) * 2 + + def load_k_subtile_frags(lds_a_ptr, lds_b_ptr, ks): + """Batch-load all A and B fragments for one K-subtile (no wait).""" + k_off = arith.index(ks * WMMA_K) + + b_frags = [load_wmma_frag( + lds_b_ptr, warp_n_base + arith.index(wn * WMMA_N), + k_off, layout_smem_b) + for wn in range_constexpr(wmma_n_rep)] + + a_frags = [load_wmma_frag( + lds_a_ptr, warp_m_base + arith.index(wm * WMMA_M), + k_off, layout_smem_a) + for wm in range_constexpr(wmma_m_rep)] + + return a_frags, b_frags + + def do_k_subtile_wmma(a_frags, b_frags, accs): + """Execute all WMMAs for one K-subtile using pre-loaded fragments.""" + for wm in range_constexpr(wmma_m_rep): + for wn in range_constexpr(wmma_n_rep): + idx = wm * wmma_n_rep + wn + accs[idx] = wmma_op( + T.vec(8, T.f32), + b_frags[wn], a_frags[wm], + accs[idx], + signA=False, signB=False, modC=0, + reuseA=False, reuseB=False, + ).result + return accs + + # --- Compute on one LDS buffer (K-subtile pipelined) --- + def compute_tile(accs_in, lds_a_ptr, lds_b_ptr): + rocdl.sched_barrier(0) + current_accs = list(accs_in) + + if k_wmma_steps == 1: + a_frags, b_frags = load_k_subtile_frags(lds_a_ptr, lds_b_ptr, 0) + rocdl.s_wait_dscnt(0) + current_accs = do_k_subtile_wmma(a_frags, b_frags, current_accs) + else: + # Prologue: batch-load K-subtile 0 + prev_a, prev_b = load_k_subtile_frags(lds_a_ptr, lds_b_ptr, 0) + + # Main K-loop: overlap load[ks+1] with compute[ks] + for ks in range_constexpr(k_wmma_steps - 1): + next_a, next_b = load_k_subtile_frags( + lds_a_ptr, lds_b_ptr, ks + 1) + rocdl.s_wait_dscnt(LOADS_PER_SUBTILE) + current_accs = do_k_subtile_wmma(prev_a, prev_b, current_accs) + prev_a, prev_b = next_a, next_b + + # Epilogue: wait for last subtile, then compute + rocdl.s_wait_dscnt(0) + current_accs = do_k_subtile_wmma(prev_a, prev_b, current_accs) + + return current_accs + + # --- Scheduling --- + def hot_loop_scheduler(): + rocdl.sched_barrier(0) + + # --- Epilogue: vectorized buffer_store_b128 --- + def epilogue(final_accs): + for wm in range_constexpr(wmma_m_rep): + for wn in range_constexpr(wmma_n_rep): + idx = wm * wmma_n_rep + wn + row = blk_m + warp_m_base + arith.index(wm * WMMA_M) + lane16 + col_base = (blk_n + warp_n_base + arith.index(wn * WMMA_N) + + lane_kgrp * arith.index(8)) + for half in range_constexpr(2): + col = col_base + arith.index(half * 4) + c_off = row * n_stride + col + vals = [vector.extract( + final_accs[idx], + static_position=[half * 4 + vi], + dynamic_position=[]) + for vi in range_constexpr(4)] + vec4 = vector.from_elements(T.vec(4, T.f32), vals) + buffer_ops.buffer_store(vec4, 
c_rsrc, c_off) + + # --- Pipeline helpers --- + def wait_and_barrier(outstanding=0): + tdm_ops.tensor_wait(outstanding) + gpu.barrier() + + def _compute_and_schedule(accs_in, lds_a, lds_b): + accs_out = compute_tile(accs_in, lds_a, lds_b) + hot_loop_scheduler() + return accs_out + + def _l2_prefetch(k_base): + if l2_prefetch_distance <= 0: + return + pf_k = k_base + arith.index(l2_prefetch_distance * tile_k) + tdm_ops.l2_prefetch_tile( + arg_a, (blk_m, pf_k), (tile_m, tile_k), (K, 1), + elem_bytes=elem_bytes, thread_id=tx, block_threads=block_threads) + tdm_ops.l2_prefetch_tile( + arg_b, (blk_n, pf_k), (tile_n, tile_k), (K, 1), + elem_bytes=elem_bytes, thread_id=tx, block_threads=block_threads) + + # ====== Multi-stage pipeline ====== + acc_zero = arith.constant_vector(0.0, T.vec(8, T.f32)) + accs = [acc_zero] * n_accs + + if use_triple_buffer: + # ====== Triple-buffer 3-stage pipeline ====== + base_ptr_ping = allocator_ping.get_base() + base_ptr_pong = allocator_pong.get_base() + base_ptr_pang = allocator_pang.get_base() + + lds_a_ping = SmemPtr(base_ptr_ping, lds_a_offset_ping, elem_ty, shape=(lds_a_elems,)) + lds_a_pong = SmemPtr(base_ptr_pong, lds_a_offset_pong, elem_ty, shape=(lds_a_elems,)) + lds_a_pang = SmemPtr(base_ptr_pang, lds_a_offset_pang, elem_ty, shape=(lds_a_elems,)) + lds_b_ping = SmemPtr(base_ptr_ping, lds_b_offset_ping, elem_ty, shape=(lds_b_elems,)) + lds_b_pong = SmemPtr(base_ptr_pong, lds_b_offset_pong, elem_ty, shape=(lds_b_elems,)) + lds_b_pang = SmemPtr(base_ptr_pang, lds_b_offset_pang, elem_ty, shape=(lds_b_elems,)) + + lds_a_ping_mem = lds_a_ping.get() + lds_a_pong_mem = lds_a_pong.get() + lds_a_pang_mem = lds_a_pang.get() + lds_b_ping_mem = lds_b_ping.get() + lds_b_pong_mem = lds_b_pong.get() + lds_b_pang_mem = lds_b_pang.get() + + # Prologue: load first 2 tiles + copy_b_to_lds(arith.index(0), lds_b_pong_mem) + copy_a_to_lds(arith.index(0), lds_a_pong_mem) + copy_b_to_lds(arith.index(tile_k), lds_b_ping_mem) + copy_a_to_lds(arith.index(tile_k), lds_a_ping_mem) + wait_and_barrier(outstanding=2) + + _safe_iters = max(0, num_k_tiles - 2) // 3 + _tiles_in_loop = _safe_iters * 3 + _tail_start = _tiles_in_loop + _n_tail = num_k_tiles - _tail_start + safe_loop_bound = _safe_iters * 3 * tile_k + + if _safe_iters > 0: + for iv, state in range(0, safe_loop_bound, tile_k * 3, init=list(accs)): + accs_in = list(state) + + copy_a_to_lds(iv + arith.index(tile_k * 2), lds_a_pang_mem) + copy_b_to_lds(iv + arith.index(tile_k * 2), lds_b_pang_mem) + _l2_prefetch(iv) + accs_in = _compute_and_schedule(accs_in, lds_a_pong, lds_b_pong) + wait_and_barrier(outstanding=2) + + copy_a_to_lds(iv + arith.index(tile_k * 3), lds_a_pong_mem) + copy_b_to_lds(iv + arith.index(tile_k * 3), lds_b_pong_mem) + _l2_prefetch(iv + arith.index(tile_k)) + accs_in = _compute_and_schedule(accs_in, lds_a_ping, lds_b_ping) + wait_and_barrier(outstanding=2) + + copy_a_to_lds(iv + arith.index(tile_k * 4), lds_a_ping_mem) + copy_b_to_lds(iv + arith.index(tile_k * 4), lds_b_ping_mem) + _l2_prefetch(iv + arith.index(tile_k * 2)) + accs_in = _compute_and_schedule(accs_in, lds_a_pang, lds_b_pang) + wait_and_barrier(outstanding=2) + + results = yield list(accs_in) + accs = list(results) + + t0 = _tail_start + if _n_tail == 2: + accs = _compute_and_schedule(accs, lds_a_pong, lds_b_pong) + wait_and_barrier(outstanding=0) + accs = compute_tile(accs, lds_a_ping, lds_b_ping) + elif _n_tail == 3: + copy_a_to_lds(arith.index((t0 + 2) * tile_k), lds_a_pang_mem) + copy_b_to_lds(arith.index((t0 + 2) * tile_k), lds_b_pang_mem) 
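+            # Three tiles remain: the (t0 + 2) loads above stream into pang
+            # while we drain pong -> ping -> pang, syncing on the TDM counter
+            # between stages.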
+ accs = _compute_and_schedule(accs, lds_a_pong, lds_b_pong) + wait_and_barrier(outstanding=2) + accs = _compute_and_schedule(accs, lds_a_ping, lds_b_ping) + wait_and_barrier(outstanding=0) + accs = compute_tile(accs, lds_a_pang, lds_b_pang) + elif _n_tail == 4: + copy_a_to_lds(arith.index((t0 + 2) * tile_k), lds_a_pang_mem) + copy_b_to_lds(arith.index((t0 + 2) * tile_k), lds_b_pang_mem) + accs = _compute_and_schedule(accs, lds_a_pong, lds_b_pong) + wait_and_barrier(outstanding=2) + copy_a_to_lds(arith.index((t0 + 3) * tile_k), lds_a_pong_mem) + copy_b_to_lds(arith.index((t0 + 3) * tile_k), lds_b_pong_mem) + accs = _compute_and_schedule(accs, lds_a_ping, lds_b_ping) + wait_and_barrier(outstanding=2) + accs = _compute_and_schedule(accs, lds_a_pang, lds_b_pang) + wait_and_barrier(outstanding=0) + accs = compute_tile(accs, lds_a_pong, lds_b_pong) + else: + raise RuntimeError(f"Unexpected _n_tail={_n_tail}") + + else: + # ====== Double-buffer 2-stage SCF pipeline ====== + base_uni = allocator_dbuf.get_base() + _a_vtype = mlir_T.memref(lds_a_elems, elem_ty, memory_space=lds_space()) + _b_vtype = mlir_T.memref(lds_b_elems, elem_ty, memory_space=lds_space()) + + def _mk_a_view(off_val): + return memref_d.view(_a_vtype, base_uni, off_val, sizes=[]) + def _mk_b_view(off_val): + return memref_d.view(_b_vtype, base_uni, off_val, sizes=[]) + + c_a0 = arith.index(lds_a_off_b0) + c_b0 = arith.index(lds_b_off_b0) + c_a1 = arith.index(lds_a_off_b1) + c_b1 = arith.index(lds_b_off_b1) + + # Prologue: load k=0 → buf0 + copy_a_to_lds(arith.index(0), _mk_a_view(c_a0)) + copy_b_to_lds(arith.index(0), _mk_b_view(c_b0)) + wait_and_barrier() + + # Main loop: each iteration loads NEXT tile, computes CURRENT + main_end = (num_k_tiles - 1) * tile_k + init_st = list(accs) + [arith.index(0)] + + for iv, state in range(0, main_end, tile_k, init=init_st): + accs_in = list(state[:n_accs]) + buf_flag = state[n_accs] + is_buf0 = arith.cmpi(arith.CmpIPredicate.eq, buf_flag, arith.index(0)) + + comp_a = arith.select(is_buf0, c_a0, c_a1) + comp_b = arith.select(is_buf0, c_b0, c_b1) + load_a = arith.select(is_buf0, c_a1, c_a0) + load_b = arith.select(is_buf0, c_b1, c_b0) + + next_k = iv + arith.index(tile_k) + copy_a_to_lds(next_k, _mk_a_view(load_a)) + copy_b_to_lds(next_k, _mk_b_view(load_b)) + _l2_prefetch(iv) + + accs_in = compute_tile(accs_in, _mk_a_view(comp_a), _mk_b_view(comp_b)) + hot_loop_scheduler() + wait_and_barrier(outstanding=2) + + next_flag = arith.select(is_buf0, arith.index(1), arith.index(0)) + results = yield list(accs_in) + [next_flag] + + accs = list(results[:n_accs]) + last_flag = results[n_accs] + + # Tail: compute the last loaded tile + is_last_b0 = arith.cmpi(arith.CmpIPredicate.eq, last_flag, arith.index(0)) + tail_a = arith.select(is_last_b0, c_a0, c_a1) + tail_b = arith.select(is_last_b0, c_b0, c_b1) + accs = compute_tile(accs, _mk_a_view(tail_a), _mk_b_view(tail_b)) + + epilogue(accs) + + cache_tag = (in_dtype, K, tile_m, tile_n, tile_k, m_warp, n_warp, + num_buffers, waves_per_eu, l2_prefetch_distance) + + @flyc.jit + def launch_wmma_gemm_tdm( + arg_c: fx.Tensor, + arg_a: fx.Tensor, + arg_b: fx.Tensor, + i32_m: fx.Int32, + i32_n: fx.Int32, + stream: fx.Stream, + ): + _ = cache_tag + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + if use_triple_buffer: + allocator_ping.finalized = False + allocator_pong.finalized = False + allocator_pang.finalized = False + allocator_ping.finalize() + allocator_pong.finalize() + allocator_pang.finalize() + else: + 
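+                # Same as the triple-buffer branch above: clear the finalized
+                # flag so the allocator re-emits its LDS global in this module.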
allocator_dbuf.finalized = False
+                allocator_dbuf.finalize()
+
+        idx_m = arith.index_cast(T.index, i32_m.ir_value())
+        idx_n = arith.index_cast(T.index, i32_n.ir_value())
+        gx = _raw((idx_m + arith.index(tile_m - 1)) / arith.index(tile_m))
+        gy = _raw((idx_n + arith.index(tile_n - 1)) / arith.index(tile_n))
+
+        launcher = kernel_wmma_gemm_tdm(arg_c, arg_a, arg_b, i32_m, i32_n)
+        if waves_per_eu is not None:
+            _wpe = int(waves_per_eu)
+            if _wpe >= 1:
+                for op in ctx.gpu_module_body.operations:
+                    if hasattr(op, 'attributes') and op.OPERATION_NAME == "gpu.func":
+                        op.attributes["rocdl.waves_per_eu"] = ir.IntegerAttr.get(
+                            ir.IntegerType.get_signless(32), _wpe)
+        launcher.launch(
+            grid=(gx, gy, 1),
+            block=(block_threads, 1, 1),
+            stream=stream,
+        )
+
+    return launch_wmma_gemm_tdm
+
+
+__all__ = ["compile_wmma_gemm_tdm"]
diff --git a/kernels/wmma_gemm_simple.py b/kernels/wmma_gemm_simple.py
new file mode 100644
index 00000000..0187369d
--- /dev/null
+++ b/kernels/wmma_gemm_simple.py
@@ -0,0 +1,255 @@
+"""WMMA GEMM kernel for gfx1250."""
+
+import flydsl.compiler as flyc
+import flydsl.expr as fx
+
+from flydsl._mlir import ir
+from flydsl.compiler.kernel_function import CompilationContext
+from flydsl.expr import arith, buffer_ops, gpu, range_constexpr, rocdl, vector
+from flydsl.expr.arith import _to_raw as _raw
+from flydsl.expr.typing import T
+from flydsl.runtime.device import get_rocm_arch as get_hip_arch
+from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr
+
+from kernels.layout_utils import crd2idx, idx2crd, get as layout_get
+
+WMMA_M, WMMA_N, WMMA_K = 16, 16, 32
+WAVE_SIZE = 32
+
+
+def compile_wmma_gemm(
+    *,
+    M: int = 0,
+    N: int = 0,
+    K: int,
+    tile_m: int = 64,
+    tile_n: int = 128,
+    tile_k: int = WMMA_K,
+    in_dtype: str = "fp16",
+    block_threads: int = 128,
+):
+    """Compile a WMMA GEMM kernel using the @flyc.kernel API.
+
+    Returns a JitFunction that auto-compiles and executes when called.
+    Signature: launch_fn(arg_c, arg_a, arg_b, M, N, stream)
+
+    Compile-time constants: K, tile_m/n/k, in_dtype (determine loop structure).
+    Runtime parameters: M, N (passed as i32 kernel args).
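+
+    Example (hypothetical shapes; gfx1250 only, tensors pre-flattened as in
+    the tests):
+        launch_fn = compile_wmma_gemm(K=128, tile_m=64, tile_n=128, tile_k=32)
+        launch_fn(c_flat, a_flat, b_flat, m, n, stream)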
+ """ + _ = (M, N) + if in_dtype not in ("fp16", "bf16"): + raise ValueError(f"in_dtype must be 'fp16' or 'bf16', got {in_dtype!r}") + is_fp4 = in_dtype == "fp4" + is_int4 = in_dtype == "int4" + is_int8 = (in_dtype == "int8") or is_int4 + is_f16 = in_dtype == "fp16" + is_bf16 = in_dtype == "bf16" + is_f16_or_bf16 = is_f16 or is_bf16 + elem_bytes = 1 if (in_dtype in ("fp8", "int8", "int4", "fp4")) else 2 + + if K % tile_k != 0: + raise ValueError(f"K must be divisible by tile_k={tile_k}, got K={K}") + if tile_k % WMMA_K != 0: + raise ValueError(f"tile_k must be a multiple of {WMMA_K}, got {tile_k}") + if tile_m % WMMA_M != 0: + raise ValueError(f"tile_m must be a multiple of {WMMA_M}, got {tile_m}") + + waves_per_block = block_threads // WAVE_SIZE + if tile_n % (waves_per_block * WMMA_N) != 0: + raise ValueError( + f"tile_n must be a multiple of waves_per_block*{WMMA_N}={waves_per_block * WMMA_N}, got {tile_n}" + ) + + gpu_arch = str(get_hip_arch(timeout_s=300)) + assert gpu_arch.startswith("gfx1250"), f"Expected a gfx1250 architecture, got {gpu_arch}" + + wmma_op = rocdl.wmma_f32_16x16x32_f16 if is_f16 else rocdl.wmma_f32_16x16x32_bf16 + k_wmma_steps = tile_k // WMMA_K + + def _elem_type(): + return T.f16 if is_f16 else T.bf16 + + warp_tile_n = tile_n // waves_per_block + wmma_m_rep = tile_m // WMMA_M + wmma_n_rep = warp_tile_n // WMMA_N + n_accs = wmma_m_rep * wmma_n_rep + + lds_a_elems = tile_m * tile_k + lds_b_elems = tile_k * tile_n + lds_a_offset = 0 + lds_b_offset = lds_a_elems * elem_bytes + + allocator = SmemAllocator(None, arch=gpu_arch, global_sym_name="wmma_gemm_smem") + allocator.ptr = lds_b_offset + lds_b_elems * elem_bytes + + total_vec_a = tile_m * (tile_k // 4) + total_vec_b = tile_k * (tile_n // 4) + if total_vec_a % block_threads != 0 or total_vec_b % block_threads != 0: + raise ValueError( + f"vectorized copy requires vec slots divisible by block_threads: " + f"A={total_vec_a}, B={total_vec_b}, block_threads={block_threads}" + ) + vec_iters_a = total_vec_a // block_threads + vec_iters_b = total_vec_b // block_threads + + @flyc.kernel + def kernel_wmma_gemm( + arg_c: fx.Tensor, + arg_a: fx.Tensor, + arg_b: fx.Tensor, + i32_m: fx.Int32, + i32_n: fx.Int32, + ): + tx = gpu.thread_id("x") + bx = gpu.block_id("x") + by = gpu.block_id("y") + + n_stride = arith.index_cast(T.index, i32_n.ir_value()) + blk_m = bx * arith.index(tile_m) + blk_n = by * arith.index(tile_n) + + layout_thr = fx.make_layout((waves_per_block, WAVE_SIZE), (WAVE_SIZE, 1)) + layout_lane = fx.make_layout((2, 16), (16, 1)) + layout_lds_a = fx.make_layout((tile_m, tile_k), (tile_k, 1)) + layout_lds_b = fx.make_layout((tile_k, tile_n), (tile_n, 1)) + layout_vec_a = fx.make_layout((tile_m, tile_k // 4), (tile_k // 4, 1)) + layout_vec_b = fx.make_layout((tile_k, tile_n // 4), (tile_n // 4, 1)) + + thr = idx2crd(tx, layout_thr) + wave_id = layout_get(thr, 0) + lane = layout_get(thr, 1) + + lc = idx2crd(lane, layout_lane) + lane_kgrp = layout_get(lc, 0) # 0/1 + lane16 = layout_get(lc, 1) # 0..15 + warp_n_off = wave_id * arith.index(warp_tile_n) + + elem_ty = _elem_type() + base_ptr = allocator.get_base() + lds_a = SmemPtr(base_ptr, lds_a_offset, elem_ty, shape=(lds_a_elems,)) + lds_b = SmemPtr(base_ptr, lds_b_offset, elem_ty, shape=(lds_b_elems,)) + lds_a_mem = lds_a.get() + lds_b_mem = lds_b.get() + + a_rsrc = buffer_ops.create_buffer_resource(arg_a, max_size=True) + b_rsrc = buffer_ops.create_buffer_resource(arg_b, max_size=True) + vec4_elem_ty = T.vec(4, elem_ty) + + acc_zero = arith.constant_vector(0.0, T.vec(8, 
T.f32)) + accs = [acc_zero] * n_accs + + for kblk in range_constexpr(K // tile_k): + k_base = arith.index(kblk * tile_k) + + for t in range_constexpr(vec_iters_a): + vec_idx = tx + arith.index(t * block_threads) + a_crd = idx2crd(vec_idx, layout_vec_a) + a_m = layout_get(a_crd, 0) + a_kv = layout_get(a_crd, 1) + a_k = a_kv * arith.index(4) + + g_off = (blk_m + a_m) * arith.index(K) + (k_base + a_k) + v_i16 = buffer_ops.buffer_load(a_rsrc, g_off, vec_width=4, dtype=T.i16) + v = vector.bitcast(vec4_elem_ty, v_i16) + lds_off = crd2idx((a_m, a_k), layout_lds_a) + vector.store(v, lds_a_mem, [lds_off]) + + for t in range_constexpr(vec_iters_b): + vec_idx = tx + arith.index(t * block_threads) + b_crd = idx2crd(vec_idx, layout_vec_b) + b_k = layout_get(b_crd, 0) + b_nv = layout_get(b_crd, 1) + b_n = b_nv * arith.index(4) + + g_off = (k_base + b_k) * n_stride + (blk_n + b_n) + v_i16 = buffer_ops.buffer_load(b_rsrc, g_off, vec_width=4, dtype=T.i16) + v = vector.bitcast(vec4_elem_ty, v_i16) + lds_off = crd2idx((b_k, b_n), layout_lds_b) + vector.store(v, lds_b_mem, [lds_off]) + + gpu.barrier() + + for ks in range_constexpr(k_wmma_steps): + k_step = arith.index(ks * WMMA_K) + + b_frags = [] + for wn in range_constexpr(wmma_n_rep): + n_off = warp_n_off + arith.index(wn * WMMA_N) + vals = [] + for k0 in range_constexpr(2): + for k1 in range_constexpr(8): + kk = k_step + (arith.index(k0 * 2) + lane_kgrp) * arith.index(8) + arith.index(k1) + off = crd2idx((kk, n_off + lane16), layout_lds_b) + vals.append(lds_b.load([off])) + b_frags.append(vector.from_elements(T.vec(16, elem_ty), vals)) + + for wm in range_constexpr(wmma_m_rep): + m_off = arith.index(wm * WMMA_M) + a_vals = [] + for k0 in range_constexpr(2): + for k1 in range_constexpr(8): + kk = k_step + (arith.index(k0 * 2) + lane_kgrp) * arith.index(8) + arith.index(k1) + off = crd2idx((m_off + lane16, kk), layout_lds_a) + a_vals.append(lds_a.load([off])) + a_frag = vector.from_elements(T.vec(16, elem_ty), a_vals) + + for wn in range_constexpr(wmma_n_rep): + acc_idx = wm * wmma_n_rep + wn + accs[acc_idx] = wmma_op( + T.vec(8, T.f32), + a_frag, + b_frags[wn], + accs[acc_idx], + signA=False, + signB=False, + modC=0, + reuseA=False, + reuseB=False, + ).result + + gpu.barrier() + + for wm in range_constexpr(wmma_m_rep): + for wn in range_constexpr(wmma_n_rep): + acc_idx = wm * wmma_n_rep + wn + m_base = blk_m + arith.index(wm * WMMA_M) + n_base = blk_n + warp_n_off + arith.index(wn * WMMA_N) + for mi in range_constexpr(8): + row = m_base + lane_kgrp * arith.index(8) + arith.index(mi) + col = n_base + lane16 + c_off = row * n_stride + col + c_val = vector.extract(accs[acc_idx], static_position=[mi], dynamic_position=[]) + fx.memref_store(c_val, arg_c, c_off) + + cache_tag = (in_dtype, K, tile_m, tile_n, tile_k, block_threads) + + @flyc.jit + def launch_wmma_gemm( + arg_c: fx.Tensor, + arg_a: fx.Tensor, + arg_b: fx.Tensor, + i32_m: fx.Int32, + i32_n: fx.Int32, + stream: fx.Stream, + ): + _ = cache_tag + allocator.finalized = False + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + allocator.finalize() + + idx_m = arith.index_cast(T.index, i32_m.ir_value()) + idx_n = arith.index_cast(T.index, i32_n.ir_value()) + gx = _raw((idx_m + arith.index(tile_m - 1)) / arith.index(tile_m)) + gy = _raw((idx_n + arith.index(tile_n - 1)) / arith.index(tile_n)) + + kernel_wmma_gemm(arg_c, arg_a, arg_b, i32_m, i32_n).launch( + grid=(gx, gy, 1), + block=(block_threads, 1, 1), + stream=stream, + ) + + return launch_wmma_gemm + + +__all__ = 
["compile_wmma_gemm"] diff --git a/lib/Bindings/Python/FlyROCDLExtension.cpp b/lib/Bindings/Python/FlyROCDLExtension.cpp index 12aeb117..330e482d 100644 --- a/lib/Bindings/Python/FlyROCDLExtension.cpp +++ b/lib/Bindings/Python/FlyROCDLExtension.cpp @@ -99,6 +99,80 @@ struct PyMmaAtomCDNA3_MFMAType : PyConcreteType { } }; +struct PyMmaAtomGFX1250_WMMAType : PyConcreteType { + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFlyROCDLMmaAtomGFX1250_WMMAType; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetTypeID; + static constexpr const char *pyClassName = "MmaAtomGFX1250_WMMAType"; + using Base::Base; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](int32_t m, int32_t n, int32_t k, PyType &elemTyA, PyType &elemTyB, PyType &elemTyAcc, + DefaultingPyMlirContext context) { + return PyMmaAtomGFX1250_WMMAType(context->getRef(), + wrap(::mlir::fly_rocdl::MmaAtomGFX1250_WMMAType::get( + m, n, k, unwrap(static_cast(elemTyA)), + unwrap(static_cast(elemTyB)), + unwrap(static_cast(elemTyAcc))))); + }, + "m"_a, "n"_a, "k"_a, "elem_ty_a"_a, "elem_ty_b"_a, "elem_ty_acc"_a, nb::kw_only(), + "context"_a = nb::none(), + "Create a MmaAtomGFX1250_WMMAType with m, n, k dimensions and element types"); + + c.def_prop_ro("m", [](PyMmaAtomGFX1250_WMMAType &self) -> int32_t { + return mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetM(self); + }); + c.def_prop_ro("n", [](PyMmaAtomGFX1250_WMMAType &self) -> int32_t { + return mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetN(self); + }); + c.def_prop_ro("k", [](PyMmaAtomGFX1250_WMMAType &self) -> int32_t { + return mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetK(self); + }); + c.def_prop_ro("elem_ty_a", [](PyMmaAtomGFX1250_WMMAType &self) -> MlirType { + return mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetElemTyA(self); + }); + c.def_prop_ro("elem_ty_b", [](PyMmaAtomGFX1250_WMMAType &self) -> MlirType { + return mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetElemTyB(self); + }); + c.def_prop_ro("elem_ty_acc", [](PyMmaAtomGFX1250_WMMAType &self) -> MlirType { + return mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetElemTyAcc(self); + }); + + c.def_prop_ro("thr_layout", [](PyMmaAtomGFX1250_WMMAType &self) -> MlirType { + auto ty = + ::mlir::cast<::mlir::fly::MmaAtomTypeInterface>(unwrap(static_cast(self))); + auto attr = ::mlir::cast<::mlir::fly::LayoutAttr>(ty.getThrLayout()); + return wrap(::mlir::fly::LayoutType::get(attr)); + }); + c.def_prop_ro("shape_mnk", [](PyMmaAtomGFX1250_WMMAType &self) -> MlirType { + auto ty = + ::mlir::cast<::mlir::fly::MmaAtomTypeInterface>(unwrap(static_cast(self))); + auto attr = ::mlir::cast<::mlir::fly::IntTupleAttr>(ty.getShapeMNK()); + return wrap(::mlir::fly::IntTupleType::get(attr)); + }); + c.def_prop_ro("tv_layout_a", [](PyMmaAtomGFX1250_WMMAType &self) -> MlirType { + auto ty = + ::mlir::cast<::mlir::fly::MmaAtomTypeInterface>(unwrap(static_cast(self))); + auto attr = ::mlir::cast<::mlir::fly::LayoutAttr>(ty.getThrValLayoutA()); + return wrap(::mlir::fly::LayoutType::get(attr)); + }); + c.def_prop_ro("tv_layout_b", [](PyMmaAtomGFX1250_WMMAType &self) -> MlirType { + auto ty = + ::mlir::cast<::mlir::fly::MmaAtomTypeInterface>(unwrap(static_cast(self))); + auto attr = ::mlir::cast<::mlir::fly::LayoutAttr>(ty.getThrValLayoutB()); + return wrap(::mlir::fly::LayoutType::get(attr)); + }); + c.def_prop_ro("tv_layout_c", [](PyMmaAtomGFX1250_WMMAType &self) -> MlirType { + auto ty = + ::mlir::cast<::mlir::fly::MmaAtomTypeInterface>(unwrap(static_cast(self))); + auto attr = 
+      return wrap(::mlir::fly::LayoutType::get(attr));
+    });
+  }
+};
+
 struct PyCopyOpCDNA3BufferLDSTType : PyConcreteType<PyCopyOpCDNA3BufferLDSTType> {
   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFlyROCDLCopyOpCDNA3BufferLDSTType;
   static constexpr GetTypeIDFunctionTy getTypeIdFunction =
@@ -131,5 +205,6 @@ NB_MODULE(_fly_rocdl, m) {
   m.doc() = "MLIR Python FlyROCDL Extension";
 
   ::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly_rocdl::PyMmaAtomCDNA3_MFMAType::bind(m);
+  ::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly_rocdl::PyMmaAtomGFX1250_WMMAType::bind(m);
   ::mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::fly_rocdl::PyCopyOpCDNA3BufferLDSTType::bind(m);
 }
diff --git a/lib/CAPI/Dialect/FlyROCDL/FlyROCDLDialect.cpp b/lib/CAPI/Dialect/FlyROCDL/FlyROCDLDialect.cpp
index 07bdcf51..6fc89ee7 100644
--- a/lib/CAPI/Dialect/FlyROCDL/FlyROCDLDialect.cpp
+++ b/lib/CAPI/Dialect/FlyROCDL/FlyROCDLDialect.cpp
@@ -55,6 +55,50 @@ MlirType mlirFlyROCDLMmaAtomCDNA3_MFMATypeGetElemTyAcc(MlirType type) {
   return wrap(cast<MmaAtomCDNA3_MFMAType>(unwrap(type)).getElemTyAcc());
 }
 
+//===----------------------------------------------------------------------===//
+// MmaAtomGFX1250_WMMAType
+//===----------------------------------------------------------------------===//
+
+bool mlirTypeIsAFlyROCDLMmaAtomGFX1250_WMMAType(MlirType type) {
+  return isa<MmaAtomGFX1250_WMMAType>(unwrap(type));
+}
+
+MlirTypeID mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetTypeID(void) {
+  return wrap(MmaAtomGFX1250_WMMAType::getTypeID());
+}
+
+MlirType mlirFlyROCDLMmaAtomGFX1250_WMMATypeGet(int32_t m, int32_t n, int32_t k,
+                                                MlirType elemTyA,
+                                                MlirType elemTyB,
+                                                MlirType elemTyAcc) {
+  return wrap(MmaAtomGFX1250_WMMAType::get(m, n, k, unwrap(elemTyA),
+                                           unwrap(elemTyB), unwrap(elemTyAcc)));
+}
+
+int32_t mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetM(MlirType type) {
+  return cast<MmaAtomGFX1250_WMMAType>(unwrap(type)).getM();
+}
+
+int32_t mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetN(MlirType type) {
+  return cast<MmaAtomGFX1250_WMMAType>(unwrap(type)).getN();
+}
+
+int32_t mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetK(MlirType type) {
+  return cast<MmaAtomGFX1250_WMMAType>(unwrap(type)).getK();
+}
+
+MlirType mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetElemTyA(MlirType type) {
+  return wrap(cast<MmaAtomGFX1250_WMMAType>(unwrap(type)).getElemTyA());
+}
+
+MlirType mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetElemTyB(MlirType type) {
+  return wrap(cast<MmaAtomGFX1250_WMMAType>(unwrap(type)).getElemTyB());
+}
+
+MlirType mlirFlyROCDLMmaAtomGFX1250_WMMATypeGetElemTyAcc(MlirType type) {
+  return wrap(cast<MmaAtomGFX1250_WMMAType>(unwrap(type)).getElemTyAcc());
+}
+
 //===----------------------------------------------------------------------===//
 // CopyOpCDNA3BufferLDSTType
 //===----------------------------------------------------------------------===//
diff --git a/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp b/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp
index 433ba625..9681a7cb 100644
--- a/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp
+++ b/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp
@@ -693,6 +693,8 @@ class MmaAtomCallLowering : public OpConversionPattern<MmaAtomCall> {
       return lowerUniversalFMA(op, rewriter, loc, universalFma, dPtr, aPtr, bPtr, cPtr);
     else if (auto cdna3Mfma = dyn_cast<fly_rocdl::MmaAtomCDNA3_MFMAType>(mmaAtomType))
       return lowerCDNA3MFMA(op, rewriter, loc, cdna3Mfma, dPtr, aPtr, bPtr, cPtr);
+    else if (auto gfx1250Wmma = dyn_cast<fly_rocdl::MmaAtomGFX1250_WMMAType>(mmaAtomType))
+      return lowerGFX1250WMMA(op, rewriter, loc, gfx1250Wmma, dPtr, aPtr, bPtr, cPtr);
 
     return rewriter.notifyMatchFailure(op, "unsupported MmaAtom type");
   }
@@ -849,6 +851,172 @@ class MmaAtomCallLowering : public OpConversionPattern<MmaAtomCall> {
 
     return rewriter.notifyMatchFailure(op, "no matching ROCDL MFMA intrinsic");
   }
+
+  static Type
getWmmaABType(MLIRContext *ctx, int32_t m, int32_t k, Type elemTy) {
+    if (m <= 0 || k <= 0)
+      return nullptr;
+
+    Type i32Ty = IntegerType::get(ctx, 32);
+
+    // fp8/bf8 WMMA operands are packed into i32 vectors.
+    if (isF8(elemTy)) {
+      if (k == 16)
+        return VectorType::get({2}, i32Ty);
+      if (k == 64)
+        return VectorType::get({8}, i32Ty);
+      if (k == 128)
+        return VectorType::get({16}, i32Ty);
+      return nullptr;
+    }
+
+    // Integer WMMA operands are packed into i32 vectors.
+    if (elemTy.isInteger(8)) {
+      if (k == 16 || k == 32)
+        return VectorType::get({4}, i32Ty);
+      if (k == 64)
+        return VectorType::get({8}, i32Ty);
+      return nullptr;
+    }
+
+    int64_t abElemsPerLane = static_cast<int64_t>(m) * static_cast<int64_t>(k) / 32;
+    if (abElemsPerLane <= 0 || (static_cast<int64_t>(m) * static_cast<int64_t>(k)) % 32 != 0)
+      return nullptr;
+    return VectorType::get({abElemsPerLane}, elemTy);
+  }
+
+  static int64_t getWmmaAccVecSize(int32_t m, int32_t k, Type elemTyA, Type elemTyB,
+                                   Type elemTyAcc) {
+    // Current backend wiring only dispatches ROCDL ops that exist in this
+    // MLIR version; keep sizing generic per supported WMMA shape/type family.
+    if (m != 16)
+      return 0;
+
+    // NOTE: rocdl.wmma.f64.16x16x4.f64 is not exposed in the current MLIR
+    // ROCDL dialect build, so f64 is intentionally not dispatched here.
+    if (k == 4 && elemTyA.isF32() && elemTyB.isF32() && elemTyAcc.isF32())
+      return 8;
+
+    if (k == 32 && elemTyA.isF16() && elemTyB.isF16() && elemTyAcc.isF32())
+      return 8;
+    if (k == 32 && elemTyA.isF16() && elemTyB.isF16() && elemTyAcc.isF16())
+      return 8;
+    if (k == 32 && elemTyA.isBF16() && elemTyB.isBF16() && elemTyAcc.isF32())
+      return 8;
+    if (k == 32 && elemTyA.isBF16() && elemTyB.isBF16() && elemTyAcc.isBF16())
+      return 8;
+
+    if (k == 64 && isF8(elemTyA) && isF8(elemTyB) && elemTyAcc.isF32())
+      return 8;
+    if (k == 64 && isF8(elemTyA) && isF8(elemTyB) && elemTyAcc.isF16())
+      return 8;
+    if (k == 128 && isF8(elemTyA) && isF8(elemTyB) && elemTyAcc.isF32())
+      return 8;
+    if (k == 128 && isF8(elemTyA) && isF8(elemTyB) && elemTyAcc.isF16())
+      return 8;
+
+    if (k == 64 && elemTyA.isInteger(8) && elemTyB.isInteger(8) && elemTyAcc.isInteger(32))
+      return 8;
+
+    return 0;
+  }
+
+  enum class WmmaVariant { ModsAllReuse, ModsC, ModsABClamp };
+
+  template <WmmaVariant Variant, typename WmmaOp>
+  LogicalResult emitWmma(MmaAtomCall op, ConversionPatternRewriter &rewriter, Location loc,
+                         Type abTyA, Type abTyB, VectorType accTy, Value aPtr, Value bPtr,
+                         Value cPtr, Value dPtr) const {
+    Value a = LLVM::LoadOp::create(rewriter, loc, abTyA, aPtr);
+    Value b = LLVM::LoadOp::create(rewriter, loc, abTyB, bPtr);
+    Value c = LLVM::LoadOp::create(rewriter, loc, accTy, cPtr);
+    Value res;
+    if constexpr (Variant == WmmaVariant::ModsAllReuse) {
+      res = WmmaOp::create(rewriter, loc, accTy,
+                           /*signA=*/false, a, /*signB=*/false, b,
+                           /*modC=*/(uint16_t)0, c)
+                .getResult();
+    } else if constexpr (Variant == WmmaVariant::ModsC) {
+      res = WmmaOp::create(rewriter, loc, accTy, a, b,
+                           /*modC=*/(uint16_t)0, c,
+                           /*reuseA=*/false, /*reuseB=*/false)
+                .getResult();
+    } else {
+      static_assert(Variant == WmmaVariant::ModsABClamp);
+      res = WmmaOp::create(rewriter, loc, accTy,
+                           /*signA=*/false, a, /*signB=*/false, b, c,
+                           /*reuseA=*/false, /*reuseB=*/false, /*clamp=*/false)
+                .getResult();
+    }
+    LLVM::StoreOp::create(rewriter, loc, res, dPtr);
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+  LogicalResult lowerGFX1250WMMA(MmaAtomCall op, ConversionPatternRewriter &rewriter, Location loc,
+                                 fly_rocdl::MmaAtomGFX1250_WMMAType atomTy, Value dPtr, Value aPtr,
+                                 Value bPtr, Value cPtr) const {
+    int32_t m = atomTy.getM();
+    int32_t n = atomTy.getN();
+    int32_t k = atomTy.getK();
+    Type elemTyA = atomTy.getElemTyA();
+    Type elemTyB = atomTy.getElemTyB();
+    Type elemTyAcc = atomTy.getElemTyAcc();
+    MLIRContext *ctx = rewriter.getContext();
+
+    Type abTyA = getWmmaABType(ctx, m, k, elemTyA);
+    Type abTyB = getWmmaABType(ctx, m, k, elemTyB);
+    if (!abTyA || !abTyB)
+      return rewriter.notifyMatchFailure(op, "unsupported A/B element packing for WMMA");
+
+    int64_t accVecSize = getWmmaAccVecSize(m, k, elemTyA, elemTyB, elemTyAcc);
+    if (accVecSize == 0)
+      return rewriter.notifyMatchFailure(op, "unsupported MNK/type combination for WMMA");
+
+    VectorType accTy = VectorType::get({accVecSize}, elemTyAcc);
+
+#define DISPATCH_WMMA(M_, K_, PRED, OP, VARIANT)                                             \
+  if (m == M_ && n == M_ && k == K_ && (PRED))                                               \
+    return emitWmma<WmmaVariant::VARIANT, ROCDL::OP>(op, rewriter, loc, abTyA, abTyB, accTy, \
+                                                     aPtr, bPtr, cPtr, dPtr);
+
+#define DISPATCH_WMMA_FP8(K_, ACC_PRED, ACC_PREFIX)                   \
+  DISPATCH_WMMA(16, K_, isFP8(elemTyA) && isFP8(elemTyB) && ACC_PRED, \
+                wmma_##ACC_PREFIX##_16x16x##K_##_fp8_fp8, ModsC)      \
+  DISPATCH_WMMA(16, K_, isFP8(elemTyA) && isBF8(elemTyB) && ACC_PRED, \
+                wmma_##ACC_PREFIX##_16x16x##K_##_fp8_bf8, ModsC)      \
+  DISPATCH_WMMA(16, K_, isBF8(elemTyA) && isFP8(elemTyB) && ACC_PRED, \
+                wmma_##ACC_PREFIX##_16x16x##K_##_bf8_fp8, ModsC)      \
+  DISPATCH_WMMA(16, K_, isBF8(elemTyA) && isBF8(elemTyB) && ACC_PRED, \
+                wmma_##ACC_PREFIX##_16x16x##K_##_bf8_bf8, ModsC)
+
+    DISPATCH_WMMA(16, 4, elemTyA.isF32() && elemTyB.isF32() && elemTyAcc.isF32(),
+                  wmma_f32_16x16x4_f32, ModsAllReuse)
+
+    DISPATCH_WMMA(16, 32, elemTyA.isF16() && elemTyB.isF16() && elemTyAcc.isF32(),
+                  wmma_f32_16x16x32_f16, ModsAllReuse)
+    DISPATCH_WMMA(16, 32, elemTyA.isBF16() && elemTyB.isBF16() && elemTyAcc.isF32(),
+                  wmma_f32_16x16x32_bf16, ModsAllReuse)
+    DISPATCH_WMMA(16, 32, elemTyA.isF16() && elemTyB.isF16() && elemTyAcc.isF16(),
+                  wmma_f16_16x16x32_f16, ModsAllReuse)
+    DISPATCH_WMMA(16, 32, elemTyA.isBF16() && elemTyB.isBF16() && elemTyAcc.isBF16(),
+                  wmma_bf16_16x16x32_bf16, ModsAllReuse)
+
+    // bf16f32 WMMA requires C:f32 and D:bf16. Current MmaAtom interface carries
+    // one accumulator type, so mixed C/D typing is not representable yet.
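+
+    // fp8/bf8 (K=64/128): all four A/B format combinations share the ModsC
+    // builder signature, so they expand through the DISPATCH_WMMA_FP8 helper.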
+
+    DISPATCH_WMMA_FP8(64, elemTyAcc.isF32(), f32)
+    DISPATCH_WMMA_FP8(64, elemTyAcc.isF16(), f16)
+    DISPATCH_WMMA_FP8(128, elemTyAcc.isF32(), f32)
+    DISPATCH_WMMA_FP8(128, elemTyAcc.isF16(), f16)
+
+    DISPATCH_WMMA(16, 64, elemTyA.isInteger(8) && elemTyB.isInteger(8) && elemTyAcc.isInteger(32),
+                  wmma_i32_16x16x64_iu8, ModsABClamp)
+
+#undef DISPATCH_WMMA_FP8
+#undef DISPATCH_WMMA
+
+    return rewriter.notifyMatchFailure(op, "no matching ROCDL WMMA intrinsic");
+  }
 };
 
 /// Lower `gpu.launch_func` kernel operands so that any `!fly.memref` values are
diff --git a/lib/Dialect/FlyROCDL/CMakeLists.txt b/lib/Dialect/FlyROCDL/CMakeLists.txt
index 0151891b..f8283598 100644
--- a/lib/Dialect/FlyROCDL/CMakeLists.txt
+++ b/lib/Dialect/FlyROCDL/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRFlyROCDLDialect
   Dialect.cpp
   CDNA3/MmaAtom.cpp
   CDNA3/CopyAtom.cpp
+  GFX1250/MmaAtom.cpp
 
   DEPENDS
   MLIRFlyROCDLIncGen
diff --git a/lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp b/lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp
new file mode 100644
index 00000000..6521583f
--- /dev/null
+++ b/lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp
@@ -0,0 +1,158 @@
+#include "flydsl/Dialect/Fly/IR/FlyDialect.h"
+#include "flydsl/Dialect/FlyROCDL/IR/Dialect.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+#include "flydsl/Dialect/Fly/Utils/ThrValLayoutMacro.h.inc"
+
+using namespace mlir;
+using namespace mlir::fly;
+
+namespace gfx1250 {
+
+static int getElemBits(Type ty) {
+  if (ty.isF32() || ty.isInteger(32))
+    return 32;
+  if (ty.isF16() || ty.isBF16())
+    return 16;
+  if (isa<Float8E4M3FNType>(ty) || isa<Float8E5M2Type>(ty) ||
+      ty.isInteger(8))
+    return 8;
+  return 0;
+}
+
+// A/B matrix register layout for GFX1250 WMMA (wave32).
+//
+// The A matrix is M×K (M=16, K varies by instruction). The 32 lanes split
+// into two groups of 16 (group = lane/16). Both groups hold different slices
+// of the K dimension.
+//
+// For 32-bit elements (f32, K=4):
+//   Each lane holds K/2 values. Group g covers K = g*(K/2) .. (g+1)*(K/2)-1.
+//   No sub-element packing. 2 VGPRs per lane.
+//   Formula: K = (l/16)*2 + v
+//
+// For sub-32-bit elements (f16/bf16 K=32, fp8/bf8/i8 K=64/128):
+//   Each lane holds K/2 values, organized in blocks of 8. Within each
+//   block, group 0 holds the lower 8 K-values, group 1 holds the upper 8.
+//   Formula: K = block*16 + (l/16)*8 + within_block
+//   where block = flat_val / 8, within_block = flat_val % 8.
+//
+// Reference space is column-major (M,K) with stride (1, M=16).
+// The B matrix (N×K) uses the identical layout with N substituted for M.
+LayoutAttr getThrValLayoutAB(MLIRContext *ctx, int32_t K, Type elemTy) {
+  auto getContext = [&]() { return ctx; };
+
+  int elemBits = getElemBits(elemTy);
+  int valsPerLane = K / 2;
+
+  if (elemBits == 32) {
+    // f32 16×4: 2 values/lane, no sub-element packing.
+    // pos = (l%16)*1 + (l/16)*(valsPerLane*16) + v*16
+    return FxLayout(FxShape(FxThr(16, 2), FxVal(valsPerLane)),
+                    FxStride(FxThr(1, valsPerLane * 16), FxVal(16)));
+  }
+
+  // Sub-32-bit: interleaving block of 8 values between lane groups.
+  // pos = (l%16)*1 + (l/16)*128 + val_within*16 [+ block*256]
+  int numBlocks = valsPerLane / 8;
+  if (numBlocks == 1) {
+    return FxLayout(FxShape(FxThr(16, 2), FxVal(8)),
+                    FxStride(FxThr(1, 128), FxVal(16)));
+  }
+  return FxLayout(FxShape(FxThr(16, 2), FxVal(8, numBlocks)),
+                  FxStride(FxThr(1, 128), FxVal(16, 256)));
+}
+
+// C/D matrix register layout for GFX1250 WMMA (wave32).
+//
+// C/D is always 16×16 (M×N). Lane l covers N = l%16. The two lane groups
+// cover M=0..7 (group 0) and M=8..15 (group 1).
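+// Worked example: lane 20 is in group 1, so it covers column N = 20 % 16 = 4
+// across rows M = 8..15 of the 16x16 tile.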
+//
+// 32-bit accumulator (f32, i32): 8 VGPRs, one element per VGPR.
+//   M = (l/16)*8 + v
+//
+// 16-bit accumulator (f16, bf16): 4 VGPRs, two packed sub-elements each.
+//   M = (l/16)*8 + v*2 + s
+//
+// Reference space is column-major (M,N) with stride (1, M=16).
+LayoutAttr getThrValLayoutCD(MLIRContext *ctx, Type elemTyAcc) {
+  auto getContext = [&]() { return ctx; };
+
+  int elemBits = getElemBits(elemTyAcc);
+  if (elemBits >= 32) {
+    return FxLayout(FxShape(FxThr(16, 2), FxVal(8)),
+                    FxStride(FxThr(16, 8), FxVal(1)));
+  }
+  // 16-bit: 4 VGPRs × 2 sub-elements = 8 values.
+  return FxLayout(FxShape(FxThr(16, 2), FxVal(4, 2)),
+                  FxStride(FxThr(16, 8), FxVal(2, 1)));
+}
+
+} // namespace gfx1250
+
+namespace mlir::fly_rocdl {
+
+bool MmaAtomGFX1250_WMMAType::isStatic() const { return true; }
+
+Attribute MmaAtomGFX1250_WMMAType::getThrLayout() const {
+  return FxLayout(FxC(32), FxC(1));
+}
+
+Attribute MmaAtomGFX1250_WMMAType::getShapeMNK() const {
+  return IntTupleAttr::get(
+      ArrayAttr::get(getContext(), {FxC(getM()), FxC(getN()), FxC(getK())}));
+}
+
+Attribute MmaAtomGFX1250_WMMAType::getThrValLayoutA() const {
+  return gfx1250::getThrValLayoutAB(getContext(), getK(), getElemTyA());
+}
+
+Attribute MmaAtomGFX1250_WMMAType::getThrValLayoutB() const {
+  return gfx1250::getThrValLayoutAB(getContext(), getK(), getElemTyB());
+}
+
+Attribute MmaAtomGFX1250_WMMAType::getThrValLayoutC() const {
+  return gfx1250::getThrValLayoutCD(getContext(), getElemTyAcc());
+}
+
+LogicalResult
+MmaAtomGFX1250_WMMAType::verify(function_ref<InFlightDiagnostic()> emitError,
+                                int32_t m, int32_t n, int32_t k, Type elemTyA,
+                                Type elemTyB, Type elemTyAcc) {
+  if (m != 16 || n != 16)
+    return emitError() << "GFX1250 WMMA requires M=N=16, got " << m << "x"
+                       << n;
+
+  auto isF8 = [](Type ty) {
+    return isa<Float8E4M3FNType>(ty) || isa<Float8E5M2Type>(ty);
+  };
+
+  bool valid = false;
+
+  if (k == 4 && elemTyA.isF32() && elemTyB.isF32() && elemTyAcc.isF32())
+    valid = true;
+
+  if (k == 32 && elemTyA.isF16() && elemTyB.isF16() &&
+      (elemTyAcc.isF32() || elemTyAcc.isF16()))
+    valid = true;
+  if (k == 32 && elemTyA.isBF16() && elemTyB.isBF16() &&
+      (elemTyAcc.isF32() || elemTyAcc.isBF16()))
+    valid = true;
+
+  if ((k == 64 || k == 128) && isF8(elemTyA) && isF8(elemTyB) &&
+      (elemTyAcc.isF32() || elemTyAcc.isF16()))
+    valid = true;
+
+  if (k == 64 && elemTyA.isInteger(8) && elemTyB.isInteger(8) &&
+      elemTyAcc.isInteger(32))
+    valid = true;
+
+  if (!valid) {
+    return emitError() << "unsupported GFX1250 WMMA configuration: " << m
+                       << "x" << n << "x" << k << " with A=" << elemTyA
+                       << ", B=" << elemTyB << ", Acc=" << elemTyAcc;
+  }
+  return success();
+}
+
+} // namespace mlir::fly_rocdl
diff --git a/python/flydsl/_mlir b/python/flydsl/_mlir
new file mode 120000
index 00000000..eba03cff
--- /dev/null
+++ b/python/flydsl/_mlir
@@ -0,0 +1 @@
+/home/jli10004/flydsl/mi450-cmodel-env/flydsl-prev/build-fly/python_packages/flydsl/_mlir
\ No newline at end of file
diff --git a/python/flydsl/compiler/jit_function.py b/python/flydsl/compiler/jit_function.py
index d1876427..5bd12ebb 100644
--- a/python/flydsl/compiler/jit_function.py
+++ b/python/flydsl/compiler/jit_function.py
@@ -316,6 +316,13 @@ def _pipeline_fragments(*, chip: str) -> list:
             "gpu-module-to-binary{format=fatbin}",
         ]
 
+    @staticmethod
+    def _use_wave64(chip: str) -> bool:
+        chip = str(chip)
+        if chip.startswith("gfx12"):
+            return False
+        return True
+
     @classmethod
     def compile(cls, module: ir.Module, *, chip: str = None, func_name: str = "") -> ir.Module:
         module.operation.verify()
diff --git 
a/python/flydsl/expr/__init__.py b/python/flydsl/expr/__init__.py index 892eb6bf..b654baa4 100644 --- a/python/flydsl/expr/__init__.py +++ b/python/flydsl/expr/__init__.py @@ -4,4 +4,4 @@ from .gpu import * from .derived import * -from . import arith, vector, gpu, buffer_ops, rocdl +from . import arith, vector, gpu, buffer_ops, rocdl, tdm_ops diff --git a/python/flydsl/expr/rocdl.py b/python/flydsl/expr/rocdl.py index cb1656b4..b2a979a0 100644 --- a/python/flydsl/expr/rocdl.py +++ b/python/flydsl/expr/rocdl.py @@ -18,6 +18,7 @@ from .._mlir._mlir_libs._fly_rocdl import CopyOpCDNA3BufferLDSTType from .._mlir._mlir_libs._fly_rocdl import MmaAtomCDNA3_MFMAType +from .._mlir._mlir_libs._fly_rocdl import MmaAtomGFX1250_WMMAType BufferLDST = lambda bit_size: CopyOpCDNA3BufferLDSTType.get(bit_size) # noqa: E731 BufferLDST32b = lambda: CopyOpCDNA3BufferLDSTType.get(32) # noqa: E731 @@ -49,6 +50,29 @@ def MFMA(m, n, k, elem_type, elem_type_b=None, elem_type_acc=None): return MmaAtomCDNA3_MFMAType.get(m, n, k, ty, ty_b, ty_acc) +def WMMA(m, n, k, elem_type, elem_type_b=None, elem_type_acc=None): + """Create a WMMA MMA atom type for GFX1250 (wave32). + + Args: + m, n, k: WMMA tile dimensions. + elem_type: Element type for A operand. + elem_type_b: Element type for B operand (defaults to elem_type). + elem_type_acc: Element type for accumulator (defaults to elem_type). + """ + from .._mlir import ir + + if isinstance(elem_type, type) and hasattr(elem_type, 'ir_type'): + ty = elem_type.ir_type + elif isinstance(elem_type, ir.Type): + ty = elem_type + else: + raise TypeError(f"WMMA: unsupported elem_type {elem_type}") + + ty_b = ty if elem_type_b is None else (elem_type_b.ir_type if hasattr(elem_type_b, 'ir_type') else elem_type_b) + ty_acc = ty if elem_type_acc is None else (elem_type_acc.ir_type if hasattr(elem_type_acc, 'ir_type') else elem_type_acc) + return MmaAtomGFX1250_WMMAType.get(m, n, k, ty, ty_b, ty_acc) + + def make_buffer_tensor(memref, alignment=4, loc=None, ip=None): """Convert a global-address-space fly memref to a buffer_desc memref. @@ -84,6 +108,7 @@ def make_buffer_tensor(memref, alignment=4, loc=None, ip=None): return _prim.make_view(bd_ptr, layout, loc=loc, ip=ip) # Keep references to ODS-generated builders so we can wrap them without losing access. +_ods_wave_id = wave_id # ODS: wave_id(res, ...) -> i32 _ods_mfma_f32_16x16x16f16 = mfma_f32_16x16x16f16 _ods_mfma_f32_16x16x16bf16_1k = globals().get("mfma_f32_16x16x16bf16_1k", None) _ods_mfma_f32_16x16x32_fp8_fp8 = mfma_f32_16x16x32_fp8_fp8 @@ -292,6 +317,20 @@ def wmma_i32_16x16x32_iu4(result_type, operands, *, loc=None, ip=None): ops = [_unwrap_wmma_operand(v, loc=loc) for v in operands] return _ods_wmma_i32_16x16x32_iu4(result_type, ops, loc=loc, ip=ip).result +def wave_id(): + """Get wave-id-in-workgroup as SGPR (via TTMP8[29:25]). + + On gfx1250 this reads an architected SGPR, so the result stays in + the SGPR pipeline and all derived computations are automatically + scalarized by LLVM uniformity analysis. + + Returns: + i32 value (SGPR) with the wave ID within the workgroup. 
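+
+    Example (per-wave descriptor offsets, mirroring the tdm_ops usage):
+        wid = rocdl.wave_id()                  # i32 in an SGPR
+        wave = arith.index_cast(T.index, wid)  # index for address math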
+ """ + from .._mlir import ir + i32 = ir.IntegerType.get_signless(32) + return _ods_wave_id(i32) + __all__ = [ # Thread/Block/Grid IDs and dimensions @@ -300,6 +339,7 @@ def wmma_i32_16x16x32_iu4(result_type, operands, *, loc=None, ip=None): 'workgroup_dim_x', 'workgroup_dim_y', 'workgroup_dim_z', 'grid_dim_x', 'grid_dim_y', 'grid_dim_z', 'wavefrontsize', + 'wave_id', # Synchronization 'barrier', 's_barrier', 's_barrier_signal', 's_barrier_wait', @@ -367,9 +407,20 @@ def wmma_i32_16x16x32_iu4(result_type, operands, *, loc=None, ip=None): # MMA atom types 'MmaAtomCDNA3_MFMAType', 'MFMA', + 'MmaAtomGFX1250_WMMAType', 'WMMA', # Convenience wrappers 'make_buffer_tensor', + + # gfx1250 TDM - descriptor-driven tile copy (preferred over per-lane) + 'tensor_load_to_lds', # 4-group, up to 5D tensor + 'tensor_load_to_lds_d2', # 2-group, up to 2D tensor + 'tensor_store_from_lds', # 4-group store + 'tensor_store_from_lds_d2', # 2-group store + 's_wait_tensorcnt', + + # gfx1250 L2 prefetch + 'global_prefetch', # per-lane 1-byte prefetch hint ] diff --git a/python/flydsl/expr/rocdl/__init__.py b/python/flydsl/expr/rocdl/__init__.py index cce39593..7c0f393e 100644 --- a/python/flydsl/expr/rocdl/__init__.py +++ b/python/flydsl/expr/rocdl/__init__.py @@ -14,6 +14,7 @@ from ..._mlir.dialects.rocdl import * # noqa: F401,F403 # Keep references to ODS-generated builders so we can wrap them without losing access. +_ods_wave_id = wave_id # ODS: wave_id(res, ...) -> i32 _ods_mfma_f32_16x16x16f16 = mfma_f32_16x16x16f16 _ods_mfma_f32_16x16x16bf16_1k = globals().get("mfma_f32_16x16x16bf16_1k", None) _ods_mfma_f32_16x16x32_fp8_fp8 = mfma_f32_16x16x32_fp8_fp8 @@ -101,6 +102,17 @@ def mfma_scale_f32_16x16x128_f8f6f4(result_type, operands, *, loc=None, ip=None) ).result +def wave_id(): + """Get wave-id-in-workgroup as SGPR (via TTMP8[29:25]). + + Returns: + i32 value (SGPR) with the wave ID within the workgroup. + """ + from ..._mlir import ir + i32 = ir.IntegerType.get_signless(32) + return _ods_wave_id(i32) + + # ── New high-level helpers from universal.py ────────────────────────── from .universal import * # noqa: F401,F403 diff --git a/python/flydsl/expr/rocdl/universal.py b/python/flydsl/expr/rocdl/universal.py index d7a6f463..47ddbe3d 100644 --- a/python/flydsl/expr/rocdl/universal.py +++ b/python/flydsl/expr/rocdl/universal.py @@ -4,6 +4,7 @@ from ..._mlir.dialects.fly import LayoutType, PointerType from ..._mlir.dialects.fly import MemRefType as FlyMemRefType from ..._mlir.dialects.fly_rocdl import CopyOpCDNA3BufferLDSTType, MmaAtomCDNA3_MFMAType +from ..._mlir._mlir_libs._fly_rocdl import MmaAtomGFX1250_WMMAType from ..primitive import ( get_iter, get_layout, @@ -28,6 +29,15 @@ def MFMA(m, n, k, elem_ty_ab, elem_ty_acc=None): return MmaAtomCDNA3_MFMAType.get(m, n, k, ty_ab, ty_ab, ty_acc) +def WMMA(m, n, k, elem_ty_ab, elem_ty_acc=None): + ty_ab = elem_ty_ab.ir_type if hasattr(elem_ty_ab, "ir_type") else elem_ty_ab + if elem_ty_acc is None: + ty_acc = ir.F32Type.get() + else: + ty_acc = elem_ty_acc.ir_type if hasattr(elem_ty_acc, "ir_type") else elem_ty_acc + return MmaAtomGFX1250_WMMAType.get(m, n, k, ty_ab, ty_ab, ty_acc) + + def make_buffer_tensor(tensor: Tensor) -> Tensor: def _elem_bit_width(elem_ty): if hasattr(elem_ty, "width"): diff --git a/python/flydsl/expr/tdm_ops.py b/python/flydsl/expr/tdm_ops.py new file mode 100644 index 00000000..d8ebd8de --- /dev/null +++ b/python/flydsl/expr/tdm_ops.py @@ -0,0 +1,502 @@ +"""TDM (Tensor Data Mover) operations for gfx1250. 
+ +High-level Python API that encapsulates TDM descriptor construction, +analogous to how buffer_ops.py wraps buffer resource descriptors. + +The TDM hardware on gfx1250 provides descriptor-driven DMA for +Global <-> LDS transfers. This module hides the bitfield packing +behind a clean API: + + desc = tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_a, lds_memref=lds_a_mem, + global_offset=(blk_m, k_base), + tensor_shape=(tile_m, K), strides=(K, 1), + tile_shape=(tile_m, tile_k), + elem_bytes=2, + pad_interval=64, pad_amount=8, + num_warps=8, + ) + tdm_ops.tensor_load_2d(desc) + tdm_ops.tensor_wait(0) +""" + +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import Optional, Sequence, Tuple, Union + +from .._mlir import ir +from .._mlir.dialects import ( + arith as std_arith, + llvm as llvm_dialect, + memref as memref_dialect, + rocdl, +) +from ..expr import arith, vector +from ..expr.arith import _to_raw as _raw +from ..expr.typing import T +from ..expr.utils.arith import ArithValue as _ArithValue + +__all__ = [ + "TDMDescriptor2D", + "make_tensor_descriptor_2d", + "tensor_load_2d", + "tensor_store_2d", + "tensor_wait", + "compute_padding_encoding", + "compute_warp_distribution", + "l2_prefetch_tile", +] + + +# --------------------------------------------------------------------------- +# Pure-Python helpers (compile-time, no IR emission) +# --------------------------------------------------------------------------- + +def compute_padding_encoding( + pad_interval_elems: int, + pad_amount_elems: int, + elem_bits: int = 16, +) -> Tuple[int, int]: + """Compute TDM descriptor padding bitfield values. + + Follows Triton TDMUtility.cpp convention: + padIntervalInDwords = pad_interval_elems * elem_bits / 32 + padAmountInDwords = pad_amount_elems * elem_bits / 32 + encoded_interval = log2(padIntervalInDwords) - 1 + encoded_amount = padAmountInDwords - 1 + + Args: + pad_interval_elems: Padding interval in elements (e.g. tile_k = 64). + pad_amount_elems: Padding amount in elements (e.g. LDS_PAD = 8). + elem_bits: Bits per element (16 for f16/bf16, 32 for f32). + + Returns: + (encoded_interval, encoded_amount) ready for descriptor bits. + """ + dword_bits = 32 + interval_dw = pad_interval_elems * elem_bits // dword_bits + amount_dw = pad_amount_elems * elem_bits // dword_bits + if interval_dw <= 0 or amount_dw <= 0: + return (0, 0) + assert interval_dw & (interval_dw - 1) == 0, ( + f"padIntervalInDwords must be power-of-2, got {interval_dw}" + ) + encoded_interval = int(math.log2(interval_dw)) - 1 + encoded_amount = amount_dw - 1 + return (encoded_interval, encoded_amount) + + +def compute_warp_distribution( + block_shape: Sequence[int], + num_warps: int, +) -> Tuple[list, list]: + """Compute per-warp block sub-tile after distributing warps. + + Mirrors Triton's tdmGetWarpDistribution + tdmGetAdjustedBlockShape + from TDMCommon.h. + + Args: + block_shape: Full tile shape, e.g. [tile_m, tile_k]. + num_warps: Total number of warps in the workgroup. + + Returns: + (warps_per_dim, block_per_warp) — how many warps along each dim + and the sub-tile size each warp handles. 
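+
+    Example (values follow directly from the greedy loop below):
+        compute_warp_distribution([64, 32], 8)  # -> ([8, 1], [8, 32])
+        compute_warp_distribution([4, 64], 8)   # -> ([4, 2], [1, 32])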
+ """ + ndims = len(block_shape) + warps = [1] * ndims + remaining = num_warps + for i in range(ndims): + while remaining > 1 and warps[i] * 2 <= block_shape[i]: + warps[i] *= 2 + remaining //= 2 + if remaining > 1: + warps[-1] *= remaining + block_per_warp = [ + (block_shape[i] + warps[i] - 1) // warps[i] + for i in range(ndims) + ] + return warps, block_per_warp + + +# --------------------------------------------------------------------------- +# Descriptor data class +# --------------------------------------------------------------------------- + +@dataclass +class TDMDescriptor2D: + """Holds constructed GROUP0 and GROUP1 vectors for tensor_load_to_lds_d2.""" + dgroup0: object # vector<4xi32> MLIR Value + dgroup1: object # vector<8xi32> MLIR Value + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + +def _unwrap(value): + """Unwrap ArithValue wrappers to get raw ir.Value.""" + max_depth = 10 + depth = 0 + while depth < max_depth and not isinstance(value, ir.Value): + if hasattr(value, "_value"): + value = value._value + elif hasattr(value, "value"): + value = value.value + else: + break + depth += 1 + return value + + +def _i32_const(v: int) -> ir.Value: + """Emit an i32 constant, handling negative / unsigned values.""" + i32 = ir.IntegerType.get_signless(32) + if v > 0x7FFFFFFF: + v = int(v - 2**32) + return _unwrap(std_arith.ConstantOp(i32, ir.IntegerAttr.get(i32, v)).result) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +def make_tensor_descriptor_2d( + global_ptr, + lds_memref, + global_offset: Tuple, + tensor_shape: Tuple[int, int], + strides: Tuple[int, int], + tile_shape: Tuple[int, int], + elem_bytes: int = 2, + pad_interval: int = 0, + pad_amount: int = 0, + num_warps: int = 1, + cache_policy: int = 0, + pred: int = 1, +) -> TDMDescriptor2D: + """Build a 2D TDM descriptor for tensor_load_to_lds_d2. + + Convention (matching ISA): + dim0 = innermost (fastest-varying, e.g. K for row-major A) + dim1 = outermost (e.g. M for row-major A) + tensor_shape = (outer_size, inner_size) in user order + strides = (outer_stride, inner_stride) + tile_shape = (outer_tile, inner_tile) + global_offset is (outer_offset, inner_offset) — MLIR index Values + + Per-warp distribution is handled internally when num_warps > 1: + each wave computes its own LDS and global offsets so that all waves + collectively cover the full tile. + + Padding params are in ELEMENTS (converted to dwords for encoding). + + Args: + global_ptr: The global tensor (fx.Tensor or fly memref value). + lds_memref: The LDS memref value (already the correct buffer slot). + global_offset: (outer_idx, inner_idx) as MLIR index values. + tensor_shape: (outer_size, inner_size) as Python ints. + strides: (outer_stride, inner_stride) as Python ints. + tile_shape: (outer_tile, inner_tile) as Python ints. + elem_bytes: Element size in bytes (2 for f16/bf16, 4 for f32). + pad_interval: Padding interval in elements (0 to disable). + pad_amount: Padding amount in elements (0 to disable). + num_warps: Total warps in the workgroup. + cache_policy: Cache policy (0 = default). + pred: Predicate (1 = enabled). + + Returns: + TDMDescriptor2D with dgroup0 and dgroup1 ready for tensor_load_2d. 
+ """ + from .._mlir.dialects import fly as _fly_d + + outer_size, inner_size = tensor_shape + outer_stride, inner_stride = strides + outer_tile, inner_tile = tile_shape + outer_off, inner_off = global_offset + + # -- Warp distribution -- + warps_per_dim, block_per_warp = compute_warp_distribution( + [outer_tile, inner_tile], num_warps, + ) + bpw_outer, bpw_inner = block_per_warp + warps_dim0 = warps_per_dim[0] + + if num_warps > 1: + # Auto-acquire SGPR wave_id via hardware register (TTMP8[29:25]). + # This keeps the entire descriptor address chain in SALU, + from . import rocdl as _rocdl_ext + _wid_i32 = _rocdl_ext.wave_id() + wave_id = arith.index_cast(T.index, _wid_i32) + warp_coord_outer = wave_id % arith.index(warps_dim0) + warp_coord_inner = wave_id / arith.index(warps_dim0) + warp_off_outer = warp_coord_outer * arith.index(bpw_outer) + warp_off_inner = warp_coord_inner * arith.index(bpw_inner) + else: + warp_off_outer = arith.index(0) + warp_off_inner = arith.index(0) + + # -- Global address (byte address for descriptor) -- + glb_ptr_type = ir.Type.parse("!llvm.ptr<1>") + i64 = ir.IntegerType.get_signless(64) + a_raw = global_ptr.__fly_values__()[0] + glb_ptr = _fly_d.extract_aligned_pointer_as_index(glb_ptr_type, a_raw) + glb_base_i64 = _ArithValue(llvm_dialect.ptrtoint(i64, glb_ptr)) + glb_elem_off = ( + (outer_off + warp_off_outer) * arith.index(outer_stride) + + (inner_off + warp_off_inner) * arith.index(inner_stride) + ) + glb_byte_off = glb_elem_off * arith.index(elem_bytes) + glb_byte_off_i64 = arith.index_cast(T.i64, glb_byte_off) + glb_addr_i64 = glb_base_i64 + glb_byte_off_i64 + + # -- LDS address (byte address within shared memory) -- + lds_base_idx = _ArithValue(memref_dialect.extract_aligned_pointer_as_index(lds_memref)) + # Compute padded LDS stride (elements) for the outer dim + if pad_interval > 0 and pad_amount > 0: + lds_inner_stride = inner_tile + pad_amount # padded row width + else: + lds_inner_stride = inner_tile + lds_warp_elem_off = ( + warp_off_outer * arith.index(lds_inner_stride) + warp_off_inner + ) + lds_warp_byte_off = lds_warp_elem_off * arith.index(elem_bytes) + lds_addr_i32 = arith.index_cast(T.i32, lds_base_idx + lds_warp_byte_off) + + # ================================================================ + # GROUP0 (vector<4xi32>): pred, lds_addr, global_addr_lo/hi + # ================================================================ + g0_s0 = arith.constant(pred, type=T.i32) + g0_s1 = lds_addr_i32 + i32 = ir.IntegerType.get_signless(32) + g0_s2 = _ArithValue(std_arith.TruncIOp(i32, _raw(glb_addr_i64)).result) + hi_raw = _ArithValue(_raw(glb_addr_i64)).shrui(arith.constant(32, type=T.i64)) + g0_s3 = ( + _ArithValue(std_arith.TruncIOp(i32, _raw(hi_raw)).result) + | arith.constant(1 << 31, type=T.i32) # type field = 2 in [31:30] + ) + dgroup0 = vector.from_elements( + T.vec(4, T.i32), [g0_s0, g0_s1, g0_s2, g0_s3] + ) + + # ================================================================ + # GROUP1 (vector<8xi32>): config + tensor dims + strides + tile + # ================================================================ + # Descriptor dim ordering: dim0=innermost, dim1=outermost + tdim0 = bpw_inner # innermost extent per warp + tdim1 = bpw_outer # outermost extent per warp + tile_d0 = bpw_inner # block dim0 per warp + tile_d1 = bpw_outer # block dim1 per warp + # stride_dim0 in descriptor = outermost stride in elements + stride0 = outer_stride + + # data_size = log2(elem_bytes) + data_size_code = int(math.log2(elem_bytes)) + + # Padding encoding + if 
pad_interval > 0 and pad_amount > 0: + elem_bits = elem_bytes * 8 + enc_interval, enc_amount = compute_padding_encoding( + pad_interval, pad_amount, elem_bits + ) + pad_enable = 1 + else: + enc_interval, enc_amount = 0, 0 + pad_enable = 0 + + # sgpr0: config bitfields + g1_s0_val = ( + (0) # workgroup_mask [15:0] + | (data_size_code << 16) # data_size [17:16] + | (0 << 18) # atomic_barrier_enable + | (0 << 19) # iterate_enable + | (pad_enable << 20) # pad_enable + | (0 << 21) # early_timeout + | (enc_interval << 22) # pad_interval [24:22] + | (enc_amount << 25) # pad_amount [31:25] + ) + g1_s0 = arith.constant(g1_s0_val, type=T.i32) + + # sgpr1: atomic_barrier_addr[15:0]=0 | tensor_dim0_lo[31:16] + g1_s1 = arith.constant((tdim0 & 0xFFFF) << 16, type=T.i32) + + # sgpr2: tensor_dim0_hi[15:0] | tensor_dim1_lo[31:16] + g1_s2 = arith.constant( + ((tdim0 >> 16) & 0xFFFF) | ((tdim1 & 0xFFFF) << 16), + type=T.i32, + ) + + # sgpr3: tensor_dim1_hi[15:0] | tile_dim0[31:16] + g1_s3 = arith.constant( + ((tdim1 >> 16) & 0xFFFF) | (tile_d0 << 16), + type=T.i32, + ) + + # sgpr4: tile_dim1[15:0] | tile_dim2[31:16]=0 + g1_s4 = arith.constant(tile_d1 & 0xFFFF, type=T.i32) + + # sgpr5: tensor_dim0_stride (low 32 bits) — stride of outermost dim + g1_s5 = arith.constant(stride0 & 0xFFFFFFFF, type=T.i32) + + # sgpr6-7: for 2D, no higher-dim strides + g1_s6 = arith.constant(0, type=T.i32) + g1_s7 = arith.constant(0, type=T.i32) + + dgroup1 = vector.from_elements( + T.vec(8, T.i32), + [g1_s0, g1_s1, g1_s2, g1_s3, g1_s4, g1_s5, g1_s6, g1_s7], + ) + + return TDMDescriptor2D(dgroup0=dgroup0, dgroup1=dgroup1) + + +def _zero_dgroup_v4i32(): + """Create a zero vector<4xi32> for unused descriptor groups.""" + z = arith.constant(0, type=T.i32) + return vector.from_elements(T.vec(4, T.i32), [z, z, z, z]) + + +def _zero_dgroup_v8i32(): + """Create a zero vector<8xi32> for unused descriptor groups.""" + z = arith.constant(0, type=T.i32) + return vector.from_elements(T.vec(8, T.i32), [z, z, z, z, z, z, z, z]) + + +def tensor_load_2d( + desc: TDMDescriptor2D, + cache_policy: int = 0, +) -> None: + """Issue a TDM 2D async load (Global -> LDS). + + Each wave in the workgroup calls this with its own descriptor + (as built by make_tensor_descriptor_2d). All waves together + cover the full tile. + + Uses the unified 5-group intrinsic with dgroup2/dgroup3/dgroup4 + zero-initialized for 2D tensors. + + Args: + desc: TDMDescriptor2D from make_tensor_descriptor_2d. + cache_policy: Cache policy (0 = default). + """ + dg2 = _raw(_zero_dgroup_v4i32()) + dg3 = _raw(_zero_dgroup_v4i32()) + dg4 = _raw(_zero_dgroup_v8i32()) + rocdl.tensor_load_to_lds( + _raw(desc.dgroup0), _raw(desc.dgroup1), dg2, dg3, dg4, cache_policy + ) + + +def tensor_store_2d( + desc: TDMDescriptor2D, + cache_policy: int = 0, +) -> None: + """Issue a TDM 2D async store (LDS -> Global). + + Uses the unified 5-group intrinsic with dgroup2/dgroup3/dgroup4 + zero-initialized for 2D tensors. + + Args: + desc: TDMDescriptor2D (with LDS source and global destination). + cache_policy: Cache policy (0 = default). + """ + dg2 = _raw(_zero_dgroup_v4i32()) + dg3 = _raw(_zero_dgroup_v4i32()) + dg4 = _raw(_zero_dgroup_v8i32()) + rocdl.tensor_store_from_lds( + _raw(desc.dgroup0), _raw(desc.dgroup1), dg2, dg3, dg4, cache_policy + ) + + +def tensor_wait(count: int = 0) -> None: + """Wait for outstanding TDM tensor operations. + + Issues s_wait_tensorcnt. + + Args: + count: Number of outstanding operations to allow (0 = wait for all). 
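+
+    Example:
+        tensor_load_2d(desc)  # enqueue the async Global -> LDS copy
+        tensor_wait(0)        # block until all outstanding TDM copies land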
+ """ + rocdl.s_wait_tensorcnt(count) + + +# --------------------------------------------------------------------------- +# L2 prefetch +# --------------------------------------------------------------------------- + +# Scope constants for global_prefetch +PREFETCH_SCOPE_SE = 8 # SE scope = L2 cache +PREFETCH_SCOPE_DEVICE = 16 # Device scope + +def l2_prefetch_tile( + global_ptr, + global_offset: Tuple, + tile_shape: Tuple[int, int], + strides: Tuple[int, int], + elem_bytes: int = 2, + num_warps: int = 1, + wave_id=None, + thread_id=None, + block_threads: int = 256, + scope: int = PREFETCH_SCOPE_SE, +) -> None: + """Issue per-lane L2 cache prefetch hints for a 2D tile. + + Each lane in the workgroup prefetches 1 byte at a distinct global address + within the tile, distributing prefetch coverage across the tile. + + For a tile of outer×inner elements, each lane covers a unique row offset. + Multiple calls (from successive iterations) accumulate coverage. + + Args: + global_ptr: The global tensor (fx.Tensor). + global_offset: (outer_idx, inner_idx) as MLIR index values. + tile_shape: (outer_size, inner_size) in elements. + strides: (outer_stride, inner_stride) in elements. + elem_bytes: Element size in bytes. + num_warps: Total warps in the workgroup. + wave_id: Current wave ID (MLIR index). Unused; thread_id used instead. + thread_id: Workgroup-local thread ID (MLIR index value). + block_threads: Total threads in the workgroup. + scope: Prefetch scope (default: SE = L2). + """ + from .._mlir.dialects import ( + fly as _fly_d, + llvm as llvm_dialect, + ) + + outer_size, inner_size = tile_shape + outer_stride, inner_stride = strides + outer_off, inner_off = global_offset + + # Get global base address as i64 + glb_ptr_type = ir.Type.parse("!llvm.ptr<1>") + i64 = ir.IntegerType.get_signless(64) + a_raw = global_ptr.__fly_values__()[0] + glb_ptr = _fly_d.extract_aligned_pointer_as_index(glb_ptr_type, a_raw) + glb_base_i64 = _ArithValue(llvm_dialect.ptrtoint(i64, glb_ptr)) + + # Each thread prefetches one row of the tile. + # thread_id maps to an outer-dim offset within the tile. + # Total rows = outer_size; if block_threads > outer_size, some threads + # wrap and prefetch additional cachelines. + # For simplicity, each thread prefetches row[tid % outer_size], col=0. + tile_row = thread_id % arith.index(outer_size) + + elem_off = ( + (outer_off + tile_row) * arith.index(outer_stride) + + inner_off * arith.index(inner_stride) + ) + byte_off = elem_off * arith.index(elem_bytes) + byte_off_i64 = arith.index_cast(T.i64, byte_off) + addr_i64 = glb_base_i64 + byte_off_i64 + + # Convert i64 address to pointer + ptr_val = llvm_dialect.inttoptr(glb_ptr_type, _raw(addr_i64)) + + # Issue prefetch hint via ROCDL dialect op. + # NOTE: rocdl.global_prefetch lowers to llvm.amdgcn.global.prefetch, which + # requires LLVM ISel support for gfx1250 global_prefetch_b8. If the LLVM + # build lacks this pattern, the instruction will be silently dropped. 
+ rocdl.global_prefetch(ptr_val, scope) diff --git a/python/flydsl/runtime/device.py b/python/flydsl/runtime/device.py index c42833f1..36a3faef 100644 --- a/python/flydsl/runtime/device.py +++ b/python/flydsl/runtime/device.py @@ -4,13 +4,13 @@ from typing import Optional -def _arch_from_rocm_agent_enumerator() -> Optional[str]: +def _arch_from_rocm_agent_enumerator(timeout_s: int = 5) -> Optional[str]: """Query rocm_agent_enumerator (standard ROCm tool) for the first GPU arch.""" try: out = subprocess.check_output( ["rocm_agent_enumerator", "-name"], text=True, - timeout=5, + timeout=timeout_s, stderr=subprocess.DEVNULL, ) for line in out.splitlines(): @@ -23,9 +23,11 @@ def _arch_from_rocm_agent_enumerator() -> Optional[str]: @functools.lru_cache(maxsize=None) -def get_rocm_arch() -> str: - """Best-effort ROCm GPU arch string (e.g. 'gfx942').""" - env = os.environ.get("FLYDSL_GPU_ARCH") or os.environ.get("HSA_OVERRIDE_GFX_VERSION") +def get_rocm_arch(timeout_s: int = 5) -> str: + """Best-effort ROCm GPU arch string (e.g. 'gfx942') without torch.""" + env = (os.environ.get("FLYDSL_GPU_ARCH") + or os.environ.get("HSA_OVERRIDE_GFX_VERSION") + ) if env: if env.startswith("gfx"): return env @@ -33,7 +35,7 @@ def get_rocm_arch() -> str: parts = env.split(".") return f"gfx{parts[0]}{parts[1]}{parts[2]}" - arch = _arch_from_rocm_agent_enumerator() + arch = _arch_from_rocm_agent_enumerator(timeout_s=timeout_s) if arch: return arch.split(":", 1)[0] diff --git a/tests/kernels/test_wmma_gemm_gfx1250.py b/tests/kernels/test_wmma_gemm_gfx1250.py new file mode 100644 index 00000000..068302d3 --- /dev/null +++ b/tests/kernels/test_wmma_gemm_gfx1250.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python3 +"""WMMA GEMM using TDM tests for gfx1250. + +Kernel implementation lives in `kernels/wmma_gemm_gfx1250.py`. +This file is the correctness harness. +""" + +import os +import sys + +import pytest +import torch + +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +_PYFLIR_SRC = os.path.join(_REPO_ROOT, "flydsl", "src") +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) +if _PYFLIR_SRC not in sys.path: + sys.path.insert(0, _PYFLIR_SRC) + +from flydsl.runtime.device import get_rocm_arch +from kernels.wmma_gemm_gfx1250 import compile_wmma_gemm_tdm +from tests.test_common import verify_output + + +if not torch.cuda.is_available(): + pytest.skip("CUDA/ROCm not available. 
Skipping GPU tests.", allow_module_level=True) + + +@pytest.mark.parametrize("in_dtype", ["fp16", "bf16"]) +@pytest.mark.parametrize( + "M, N, K, tile_m, tile_n, tile_k", + [ + (128, 128, 64, 64, 128, 32), + (128, 128, 256, 64, 128, 128), + (256, 256, 256, 64, 256, 128), + (256, 256, 192, 64, 256, 64), + (256, 512, 256, 64, 256, 128), + (512, 512, 512, 64, 256, 128), + (201, 179, 128, 64, 128, 64), + (300, 399, 256, 64, 256, 128), + (256, 256, 256, 256, 256, 128), + (1024, 1024, 1024, 256, 256, 128), + (512, 512, 512, 256, 256, 128), + ], +) +@pytest.mark.parametrize("num_buffers", [2, 3]) +def test_wmma_gemm_tdm(in_dtype, M, N, K, tile_m, tile_n, tile_k, + num_buffers, + m_warp=2, n_warp=4, l2_prefetch_distance=2): + arch = str(get_rocm_arch(timeout_s=300)) + if arch != "gfx1250": + pytest.skip(f"WMMA requires gfx1250, got {arch}") + + num_k_tiles = K // tile_k + if num_buffers == 3 and num_k_tiles < 3: + pytest.skip(f"Triple buffer requires num_k_tiles >= 3, got {num_k_tiles}") + + lds_pad = 8 + elem_bytes = 2 + a_buf = tile_m * (tile_k + lds_pad) * elem_bytes + b_buf = tile_n * (tile_k + lds_pad) * elem_bytes + total_lds = (a_buf + b_buf) * num_buffers + if total_lds > 327680: + pytest.skip(f"LDS budget exceeded: {total_lds} > 327680") + + print(f"Running WMMA GEMM TDM: M={M}, N={N}, K={K}, " + f"dtype={in_dtype}, bufs={num_buffers}") + + torch_dtype = torch.float16 if in_dtype == "fp16" else torch.bfloat16 + device = torch.device("cuda") + torch.manual_seed(0) + + mpad = (M + tile_m - 1) // tile_m * tile_m + npad = (N + tile_n - 1) // tile_n * tile_n + + a = torch.randn((M, K), dtype=torch_dtype, device='cpu').cuda() + b = torch.randn((N, K), dtype=torch_dtype, device='cpu').cuda() + + a_pad = torch.zeros((mpad, K), dtype=torch_dtype, device=device) + b_pad = torch.zeros((npad, K), dtype=torch_dtype, device=device) + a_pad[:M, :] = a + b_pad[:N, :] = b + + c_pad = torch.zeros((mpad, npad), dtype=torch.float32, device=device) + + launch_fn = compile_wmma_gemm_tdm( + M=mpad, N=npad, K=K, + tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, + m_warp=m_warp, n_warp=n_warp, in_dtype=in_dtype, + num_buffers=num_buffers, + l2_prefetch_distance=l2_prefetch_distance, + ) + launch_fn( + c_pad.contiguous().view(-1), + a_pad.contiguous().view(-1), + b_pad.contiguous().view(-1), + mpad, npad, torch.cuda.current_stream(), + ) + torch.cuda.synchronize() + + ref = torch.mm(a.cpu().to(torch.float32), b.cpu().to(torch.float32).T) + rtol = 3e-2 + atol = 3e-2 + assert verify_output(c_pad[:M, :N].cpu().to(torch.float32), ref, rtol=rtol, atol=atol) + print("PASSED") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("-M", type=int, default=1024) + parser.add_argument("-N", type=int, default=1024) + parser.add_argument("-K", type=int, default=1024) + parser.add_argument("--tile-m", type=int, default=128) + parser.add_argument("--tile-n", type=int, default=256) + parser.add_argument("--tile-k", type=int, default=128) + parser.add_argument("--m-warp", type=int, default=2) + parser.add_argument("--n-warp", type=int, default=4) + parser.add_argument("--dtype", type=str, default="fp16", choices=["fp16", "bf16"]) + parser.add_argument("--num-buffers", type=int, default=2, choices=[2, 3]) + parser.add_argument("--l2-prefetch-distance", type=int, default=0) + args = parser.parse_args() + + test_wmma_gemm_tdm( + args.dtype, args.M, args.N, args.K, + args.tile_m, args.tile_n, args.tile_k, + num_buffers=args.num_buffers, + m_warp=args.m_warp, + n_warp=args.n_warp, + 
l2_prefetch_distance=args.l2_prefetch_distance,
+ )
diff --git a/tests/kernels/test_wmma_gemm_simple.py b/tests/kernels/test_wmma_gemm_simple.py
new file mode 100644
index 00000000..f54fe039
--- /dev/null
+++ b/tests/kernels/test_wmma_gemm_simple.py
@@ -0,0 +1,122 @@
+#!/usr/bin/env python3
+"""WMMA GEMM tests for gfx1250 — @flyc.kernel API.
+
+Kernel implementation lives in `kernels/wmma_gemm_simple.py`.
+This file is the correctness + perf harness.
+"""
+
+import os
+import sys
+
+import pytest
+import torch
+
+_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
+_PYFLIR_SRC = os.path.join(_REPO_ROOT, "flydsl", "src")
+if _REPO_ROOT not in sys.path:
+ sys.path.insert(0, _REPO_ROOT)
+if _PYFLIR_SRC not in sys.path:
+ sys.path.insert(0, _PYFLIR_SRC)
+
+from flydsl.runtime.device import get_rocm_arch
+from kernels.wmma_gemm_simple import compile_wmma_gemm
+from tests.test_common import verify_output
+
+
+if not torch.cuda.is_available():
+ pytest.skip("CUDA/ROCm not available. Skipping GPU tests.", allow_module_level=True)
+
+
+@pytest.mark.parametrize("in_dtype", ["fp16", "bf16"])
+@pytest.mark.parametrize(
+ "M, N, K, tile_m, tile_n, tile_k, block_threads",
+ [
+ (32, 32, 32, 32, 32, 32, 32),
+ (64, 64, 32, 64, 64, 32, 128),
+ (128, 128, 32, 64, 128, 32, 256),
+ (128, 128, 64, 64, 128, 32, 256),
+ (256, 256, 32, 64, 64, 32, 128),
+ (200, 180, 64, 64, 64, 32, 128),
+ (128, 128, 128, 64, 128, 64, 256),
+ ],
+)
+def test_wmma_gemm(in_dtype, M, N, K, tile_m, tile_n, tile_k, block_threads):
+ # rocm_agent_enumerator is very slow on the AM simulator; use a large
+ # timeout so arch detection does not time out and fall back to gfx942.
+ arch = str(get_rocm_arch(timeout_s=300))
+ if arch != "gfx1250":
+ pytest.skip(f"WMMA requires gfx1250, got {arch}")
+ print(f"Running WMMA GEMM test with: M={M}, N={N}, K={K}, "
+ f"tile_m={tile_m}, tile_n={tile_n}, tile_k={tile_k}, "
+ f"block_threads={block_threads}, dtype={in_dtype}, arch={arch}")
+
+ torch_dtype = torch.float16 if in_dtype == "fp16" else torch.bfloat16
+ device = torch.device("cuda")
+ torch.manual_seed(0)
+
+ # Pad M/N to tile boundaries
+ mpad = (M + tile_m - 1) // tile_m * tile_m
+ npad = (N + tile_n - 1) // tile_n * tile_n
+
+ # torch.randn on the GPU is unreliable on the gfx1250 AM simulator,
+ # so generate on the CPU and copy to the device.
+ a = torch.randn((M, K), dtype=torch_dtype, device='cpu').cuda()
+ b = torch.randn((K, N), dtype=torch_dtype, device='cpu').cuda()
+
+ a_pad = torch.zeros((mpad, K), dtype=torch_dtype, device=device)
+ b_pad = torch.zeros((K, npad), dtype=torch_dtype, device=device)
+ a_pad[:M, :] = a
+ b_pad[:, :N] = b
+ c_pad = torch.zeros((mpad, npad), dtype=torch.float32, device=device)
+
+ launch_fn = compile_wmma_gemm(
+ M=mpad,
+ N=npad,
+ K=K,
+ tile_m=tile_m,
+ tile_n=tile_n,
+ tile_k=tile_k,
+ in_dtype=in_dtype,
+ block_threads=block_threads,
+ )
+ launch_fn(
+ c_pad.contiguous().view(-1),
+ a_pad.contiguous().view(-1),
+ b_pad.contiguous().view(-1),
+ mpad,
+ npad,
+ torch.cuda.current_stream(),
+ )
+ torch.cuda.synchronize()
+
+ ref = torch.matmul(a.cpu().to(torch.float32), b.cpu().to(torch.float32))
+ assert verify_output(c_pad[:M, :N].cpu(), ref, rtol=3e-2, atol=3e-2)
+ print("✓ PASSED")
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-M", type=int, default=256, help='problem M size')
+ parser.add_argument("-N", type=int, default=256, help='problem N size')
+ parser.add_argument("-K", type=int, default=1024, help='problem K size')
+ parser.add_argument("--tile_m", type=int, default=256)
+ 
parser.add_argument("--tile_n", type=int, default=256) + parser.add_argument("--tile_k", type=int, default=128) + parser.add_argument("--dtype", type=str, default="fp16", choices=["fp16", "bf16"], + help="Input data type") + args = parser.parse_args() + + WARP_SIZE = 32 + BLOCK_THREADS = min(args.tile_n, 8 * WARP_SIZE) + + test_wmma_gemm( + args.dtype, + args.M, + args.N, + args.K, + args.tile_m, + args.tile_n, + args.tile_k, + BLOCK_THREADS, + ) diff --git a/thirdparty/llvm-hash.txt b/thirdparty/llvm-hash.txt index 4faf2ea9..978cdc8d 100644 --- a/thirdparty/llvm-hash.txt +++ b/thirdparty/llvm-hash.txt @@ -1 +1 @@ -ac5dc54d509169d387fcfd495d71853d81c46484 +27d654c4c4e6eb7c19e46af20500200e793da7c7 From cc9b7e7cfe63c6d28955f8710e8fd434105ce8c1 Mon Sep 17 00:00:00 2001 From: aoli26 Date: Thu, 12 Mar 2026 09:43:31 +0000 Subject: [PATCH 02/11] refactor multi-stage pipeline --- kernels/wmma_gemm_gfx1250.py | 319 ++++++++++++++--------------------- 1 file changed, 122 insertions(+), 197 deletions(-) diff --git a/kernels/wmma_gemm_gfx1250.py b/kernels/wmma_gemm_gfx1250.py index cc3e8892..0559ec11 100644 --- a/kernels/wmma_gemm_gfx1250.py +++ b/kernels/wmma_gemm_gfx1250.py @@ -13,10 +13,7 @@ from flydsl.expr.arith import _to_raw as _raw from flydsl.expr.typing import T from flydsl.runtime.device import get_rocm_arch as get_hip_arch -from flydsl._mlir.dialects import memref as memref_d -from flydsl.expr.gpu import lds_space from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr, get_op_result_or_value -from flydsl._mlir.extras import types as mlir_T from kernels.layout_utils import crd2idx, idx2crd @@ -26,6 +23,47 @@ LDS_PAD_A = 8 LDS_PAD_B = 8 +_STAGE_NAMES = ("ping", "pong", "pang") + + +def _make_tail_plan(num_buffers, pre_loaded, extra): + """Compute a compile-time tail execution plan for the N-stage pipeline. + + Returns a list of (load_stage, compute_stage, outstanding) tuples, one per + tail step. outstanding=-1 means "last step, use compute_tile (no barrier)". + + Args: + num_buffers: total number of pipeline stages. + pre_loaded: stages already loaded and ready to compute (= num_buffers - 1). + extra: additional tiles that must be loaded in the tail. 
+ """ + steps = pre_loaded + extra + plan = [] + for i in range(steps): + compute_stage = ( + i if i < pre_loaded + else (i - pre_loaded + num_buffers - 1) % num_buffers + ) + load_stage = ( + (i + num_buffers - 1) % num_buffers if i < extra + else None + ) + is_last = (i == steps - 1) + if is_last: + outstanding = -1 + else: + j = i + 1 + next_compute = ( + j if j < pre_loaded + else (j - pre_loaded + num_buffers - 1) % num_buffers + ) + outstanding = ( + 2 if (load_stage is not None and load_stage != next_compute) + else 0 + ) + plan.append((load_stage, compute_stage, outstanding)) + return plan + def compile_wmma_gemm_tdm( *, @@ -55,7 +93,6 @@ def compile_wmma_gemm_tdm( _ = (M, N) if num_buffers not in (2, 3): raise ValueError(f"num_buffers must be 2 or 3, got {num_buffers}") - use_triple_buffer = num_buffers == 3 if in_dtype not in ("fp16", "bf16"): raise ValueError(f"in_dtype must be 'fp16' or 'bf16', got {in_dtype!r}") is_f16 = in_dtype == "fp16" @@ -110,43 +147,24 @@ def _elem_type(): # --- LDS allocation --- num_warps = m_warp * n_warp - if use_triple_buffer: - # Triple-buffer: 3 separate allocators (ping/pong/pang) - allocator_ping = SmemAllocator(None, arch=gpu_arch, global_sym_name="wmma_tdm_ping") - allocator_pong = SmemAllocator(None, arch=gpu_arch, global_sym_name="wmma_tdm_pong") - allocator_pang = SmemAllocator(None, arch=gpu_arch, global_sym_name="wmma_tdm_pang") - - ping_offset = allocator_ping._align(allocator_ping.ptr, 16) - allocator_ping.ptr = ping_offset + buf_size_elems * elem_bytes - pong_offset = allocator_pong._align(allocator_pong.ptr, 16) - allocator_pong.ptr = pong_offset + buf_size_elems * elem_bytes - pang_offset = allocator_pang._align(allocator_pang.ptr, 16) - allocator_pang.ptr = pang_offset + buf_size_elems * elem_bytes - - lds_a_offset_ping = ping_offset - lds_b_offset_ping = ping_offset + lds_a_elems * elem_bytes - lds_a_offset_pong = pong_offset - lds_b_offset_pong = pong_offset + lds_a_elems * elem_bytes - lds_a_offset_pang = pang_offset - lds_b_offset_pang = pang_offset + lds_a_elems * elem_bytes - - allocator_dbuf = None - else: - # Double-buffer: unified allocator with dynamic buffer selection - allocator_ping = None - allocator_pong = None - allocator_pang = None - - allocator_dbuf = SmemAllocator(None, arch=gpu_arch, global_sym_name="wmma_tdm_dbuf") - _dbuf0_off = allocator_dbuf._align(allocator_dbuf.ptr, 16) - allocator_dbuf.ptr = _dbuf0_off + buf_size_elems * elem_bytes - _dbuf1_off = allocator_dbuf._align(allocator_dbuf.ptr, 16) - allocator_dbuf.ptr = _dbuf1_off + buf_size_elems * elem_bytes - - lds_a_off_b0 = _dbuf0_off - lds_b_off_b0 = _dbuf0_off + lds_a_elems * elem_bytes - lds_a_off_b1 = _dbuf1_off - lds_b_off_b1 = _dbuf1_off + lds_a_elems * elem_bytes + stage_allocators = [] + stage_a_offsets = [] + stage_b_offsets = [] + for i in range(num_buffers): + name = _STAGE_NAMES[i] + alloc = SmemAllocator(None, arch=gpu_arch, global_sym_name=f"wmma_tdm_{name}") + off = alloc._align(alloc.ptr, 16) + alloc.ptr = off + buf_size_elems * elem_bytes + stage_allocators.append(alloc) + stage_a_offsets.append(off) + stage_b_offsets.append(off + lds_a_elems * elem_bytes) + + # Compile-time pipeline parameters + pre_loaded = num_buffers - 1 # stages pre-loaded in prologue + loop_iters = (num_k_tiles - pre_loaded) // num_buffers + _tail_start = loop_iters * num_buffers # index of first un-computed tile in tail + extra = num_k_tiles - _tail_start - pre_loaded + tail_plan = _make_tail_plan(num_buffers, pre_loaded, extra) @flyc.kernel def 
kernel_wmma_gemm_tdm( @@ -160,7 +178,6 @@ def kernel_wmma_gemm_tdm( bx = gpu.block_id("x") by = gpu.block_id("y") - n_stride = arith.index_cast(T.index, i32_n.ir_value()) blk_m = bx * arith.index(tile_m) blk_n = by * arith.index(tile_n) @@ -177,13 +194,10 @@ def kernel_wmma_gemm_tdm( elem_ty = _elem_type() - # --- Buffer resources --- + # --- Epilogue setup --- m_idx = arith.index_cast(T.index, i32_m.ir_value()) - a_nrec = m_idx * arith.index(K * elem_bytes) - b_nrec = n_stride * arith.index(K * elem_bytes) - c_nrec = m_idx * n_stride * arith.index(4) # f32 output - a_rsrc = buffer_ops.create_buffer_resource(arg_a, num_records_bytes=a_nrec) - b_rsrc = buffer_ops.create_buffer_resource(arg_b, num_records_bytes=b_nrec) + n_stride = arith.index(N) + c_nrec = m_idx * n_stride * arith.index(4) c_rsrc = buffer_ops.create_buffer_resource(arg_c, num_records_bytes=c_nrec) # --- TDM async copy helpers --- @@ -361,147 +375,64 @@ def _l2_prefetch(k_base): acc_zero = arith.constant_vector(0.0, T.vec(8, T.f32)) accs = [acc_zero] * n_accs - if use_triple_buffer: - # ====== Triple-buffer 3-stage pipeline ====== - base_ptr_ping = allocator_ping.get_base() - base_ptr_pong = allocator_pong.get_base() - base_ptr_pang = allocator_pang.get_base() - - lds_a_ping = SmemPtr(base_ptr_ping, lds_a_offset_ping, elem_ty, shape=(lds_a_elems,)) - lds_a_pong = SmemPtr(base_ptr_pong, lds_a_offset_pong, elem_ty, shape=(lds_a_elems,)) - lds_a_pang = SmemPtr(base_ptr_pang, lds_a_offset_pang, elem_ty, shape=(lds_a_elems,)) - lds_b_ping = SmemPtr(base_ptr_ping, lds_b_offset_ping, elem_ty, shape=(lds_b_elems,)) - lds_b_pong = SmemPtr(base_ptr_pong, lds_b_offset_pong, elem_ty, shape=(lds_b_elems,)) - lds_b_pang = SmemPtr(base_ptr_pang, lds_b_offset_pang, elem_ty, shape=(lds_b_elems,)) - - lds_a_ping_mem = lds_a_ping.get() - lds_a_pong_mem = lds_a_pong.get() - lds_a_pang_mem = lds_a_pang.get() - lds_b_ping_mem = lds_b_ping.get() - lds_b_pong_mem = lds_b_pong.get() - lds_b_pang_mem = lds_b_pang.get() - - # Prologue: load first 2 tiles - copy_b_to_lds(arith.index(0), lds_b_pong_mem) - copy_a_to_lds(arith.index(0), lds_a_pong_mem) - copy_b_to_lds(arith.index(tile_k), lds_b_ping_mem) - copy_a_to_lds(arith.index(tile_k), lds_a_ping_mem) - wait_and_barrier(outstanding=2) - - _safe_iters = max(0, num_k_tiles - 2) // 3 - _tiles_in_loop = _safe_iters * 3 - _tail_start = _tiles_in_loop - _n_tail = num_k_tiles - _tail_start - safe_loop_bound = _safe_iters * 3 * tile_k - - if _safe_iters > 0: - for iv, state in range(0, safe_loop_bound, tile_k * 3, init=list(accs)): - accs_in = list(state) - - copy_a_to_lds(iv + arith.index(tile_k * 2), lds_a_pang_mem) - copy_b_to_lds(iv + arith.index(tile_k * 2), lds_b_pang_mem) - _l2_prefetch(iv) - accs_in = _compute_and_schedule(accs_in, lds_a_pong, lds_b_pong) - wait_and_barrier(outstanding=2) - - copy_a_to_lds(iv + arith.index(tile_k * 3), lds_a_pong_mem) - copy_b_to_lds(iv + arith.index(tile_k * 3), lds_b_pong_mem) - _l2_prefetch(iv + arith.index(tile_k)) - accs_in = _compute_and_schedule(accs_in, lds_a_ping, lds_b_ping) - wait_and_barrier(outstanding=2) - - copy_a_to_lds(iv + arith.index(tile_k * 4), lds_a_ping_mem) - copy_b_to_lds(iv + arith.index(tile_k * 4), lds_b_ping_mem) - _l2_prefetch(iv + arith.index(tile_k * 2)) - accs_in = _compute_and_schedule(accs_in, lds_a_pang, lds_b_pang) + # Build per-stage SmemPtrs (one per pipeline stage) + base_ptrs = [sa.get_base() for sa in stage_allocators] + stages_a = [ + SmemPtr(base_ptrs[i], stage_a_offsets[i], elem_ty, shape=(lds_a_elems,)) + for i in 
range_constexpr(num_buffers) + ] + stages_b = [ + SmemPtr(base_ptrs[i], stage_b_offsets[i], elem_ty, shape=(lds_b_elems,)) + for i in range_constexpr(num_buffers) + ] + stages_a_mem = [stages_a[i].get() for i in range_constexpr(num_buffers)] + stages_b_mem = [stages_b[i].get() for i in range_constexpr(num_buffers)] + + # Prologue: load first (num_buffers - 1) tiles into stages 0..(num_buffers-2) + for i in range_constexpr(pre_loaded): + copy_a_to_lds(arith.index(i * tile_k), stages_a_mem[i]) + copy_b_to_lds(arith.index(i * tile_k), stages_b_mem[i]) + # Wait until stage[0] is ready; allow later-stage loads to still be in flight + # outstanding = 2 * (num_buffers - 2): 0 for double-buffer, 2 for triple-buffer + wait_and_barrier(outstanding=2 * (num_buffers - 2)) + + # Main loop: each iteration covers (num_buffers) K-tiles + # Sub-phase s: load the "next" tile, compute the "current" tile, then barrier + # load_stage = (s + num_buffers - 1) % num_buffers + # load_offset = iv + (s + num_buffers - 1) * tile_k + # compute stage[s] + main_end = loop_iters * num_buffers * tile_k + + if loop_iters > 0: + for iv, state in range(0, main_end, num_buffers * tile_k, init=list(accs)): + accs_in = list(state) + for s in range_constexpr(num_buffers): + _load_stage = (s + num_buffers - 1) % num_buffers + _load_k_off = (s + num_buffers - 1) * tile_k + copy_a_to_lds(iv + arith.index(_load_k_off), stages_a_mem[_load_stage]) + copy_b_to_lds(iv + arith.index(_load_k_off), stages_b_mem[_load_stage]) + _l2_prefetch(iv + arith.index(s * tile_k)) + accs_in = _compute_and_schedule(accs_in, stages_a[s], stages_b[s]) wait_and_barrier(outstanding=2) - - results = yield list(accs_in) - accs = list(results) - - t0 = _tail_start - if _n_tail == 2: - accs = _compute_and_schedule(accs, lds_a_pong, lds_b_pong) - wait_and_barrier(outstanding=0) - accs = compute_tile(accs, lds_a_ping, lds_b_ping) - elif _n_tail == 3: - copy_a_to_lds(arith.index((t0 + 2) * tile_k), lds_a_pang_mem) - copy_b_to_lds(arith.index((t0 + 2) * tile_k), lds_b_pang_mem) - accs = _compute_and_schedule(accs, lds_a_pong, lds_b_pong) - wait_and_barrier(outstanding=2) - accs = _compute_and_schedule(accs, lds_a_ping, lds_b_ping) - wait_and_barrier(outstanding=0) - accs = compute_tile(accs, lds_a_pang, lds_b_pang) - elif _n_tail == 4: - copy_a_to_lds(arith.index((t0 + 2) * tile_k), lds_a_pang_mem) - copy_b_to_lds(arith.index((t0 + 2) * tile_k), lds_b_pang_mem) - accs = _compute_and_schedule(accs, lds_a_pong, lds_b_pong) - wait_and_barrier(outstanding=2) - copy_a_to_lds(arith.index((t0 + 3) * tile_k), lds_a_pong_mem) - copy_b_to_lds(arith.index((t0 + 3) * tile_k), lds_b_pong_mem) - accs = _compute_and_schedule(accs, lds_a_ping, lds_b_ping) - wait_and_barrier(outstanding=2) - accs = _compute_and_schedule(accs, lds_a_pang, lds_b_pang) - wait_and_barrier(outstanding=0) - accs = compute_tile(accs, lds_a_pong, lds_b_pong) + results = yield list(accs_in) + accs = list(results) + + # Tail: handle remaining tiles using the compile-time plan + # Each plan step: optionally load one tile, compute one stage, then wait. + # outstanding=-1 → last step: use compute_tile (no barrier). 
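+ # For example, num_buffers=2 with one extra tail tile yields the plan
+ # [(1, 0, 0), (None, 1, -1)]: load stage 1 while computing stage 0,
+ # drain all loads at the barrier, then compute stage 1 and fall
+ # through to the epilogue without another barrier.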
+ _extra_j = 0 + for _load_stage, _compute_stage, _outstanding in tail_plan: + if _load_stage is not None: + _k_off = (_tail_start + pre_loaded + _extra_j) * tile_k + copy_a_to_lds(arith.index(_k_off), stages_a_mem[_load_stage]) + copy_b_to_lds(arith.index(_k_off), stages_b_mem[_load_stage]) + _extra_j += 1 + if _outstanding == -1: + accs = compute_tile(accs, stages_a[_compute_stage], stages_b[_compute_stage]) else: - raise RuntimeError(f"Unexpected _n_tail={_n_tail}") - - else: - # ====== Double-buffer 2-stage SCF pipeline ====== - base_uni = allocator_dbuf.get_base() - _a_vtype = mlir_T.memref(lds_a_elems, elem_ty, memory_space=lds_space()) - _b_vtype = mlir_T.memref(lds_b_elems, elem_ty, memory_space=lds_space()) - - def _mk_a_view(off_val): - return memref_d.view(_a_vtype, base_uni, off_val, sizes=[]) - def _mk_b_view(off_val): - return memref_d.view(_b_vtype, base_uni, off_val, sizes=[]) - - c_a0 = arith.index(lds_a_off_b0) - c_b0 = arith.index(lds_b_off_b0) - c_a1 = arith.index(lds_a_off_b1) - c_b1 = arith.index(lds_b_off_b1) - - # Prologue: load k=0 → buf0 - copy_a_to_lds(arith.index(0), _mk_a_view(c_a0)) - copy_b_to_lds(arith.index(0), _mk_b_view(c_b0)) - wait_and_barrier() - - # Main loop: each iteration loads NEXT tile, computes CURRENT - main_end = (num_k_tiles - 1) * tile_k - init_st = list(accs) + [arith.index(0)] - - for iv, state in range(0, main_end, tile_k, init=init_st): - accs_in = list(state[:n_accs]) - buf_flag = state[n_accs] - is_buf0 = arith.cmpi(arith.CmpIPredicate.eq, buf_flag, arith.index(0)) - - comp_a = arith.select(is_buf0, c_a0, c_a1) - comp_b = arith.select(is_buf0, c_b0, c_b1) - load_a = arith.select(is_buf0, c_a1, c_a0) - load_b = arith.select(is_buf0, c_b1, c_b0) - - next_k = iv + arith.index(tile_k) - copy_a_to_lds(next_k, _mk_a_view(load_a)) - copy_b_to_lds(next_k, _mk_b_view(load_b)) - _l2_prefetch(iv) - - accs_in = compute_tile(accs_in, _mk_a_view(comp_a), _mk_b_view(comp_b)) - hot_loop_scheduler() - wait_and_barrier(outstanding=2) - - next_flag = arith.select(is_buf0, arith.index(1), arith.index(0)) - results = yield list(accs_in) + [next_flag] - - accs = list(results[:n_accs]) - last_flag = results[n_accs] - - # Tail: compute the last loaded tile - is_last_b0 = arith.cmpi(arith.CmpIPredicate.eq, last_flag, arith.index(0)) - tail_a = arith.select(is_last_b0, c_a0, c_a1) - tail_b = arith.select(is_last_b0, c_b0, c_b1) - accs = compute_tile(accs, _mk_a_view(tail_a), _mk_b_view(tail_b)) + accs = _compute_and_schedule( + accs, stages_a[_compute_stage], stages_b[_compute_stage]) + wait_and_barrier(outstanding=_outstanding) epilogue(accs) @@ -520,16 +451,10 @@ def launch_wmma_gemm_tdm( _ = cache_tag ctx = CompilationContext.get_current() with ir.InsertionPoint(ctx.gpu_module_body): - if use_triple_buffer: - allocator_ping.finalized = False - allocator_pong.finalized = False - allocator_pang.finalized = False - allocator_ping.finalize() - allocator_pong.finalize() - allocator_pang.finalize() - else: - allocator_dbuf.finalized = False - allocator_dbuf.finalize() + for alloc in stage_allocators: + alloc.finalized = False + for alloc in stage_allocators: + alloc.finalize() idx_m = arith.index_cast(T.index, i32_m.ir_value()) idx_n = arith.index_cast(T.index, i32_n.ir_value()) From 0e5cdbc178ba5ecddc562c11ca6ad4fa638786d8 Mon Sep 17 00:00:00 2001 From: aoli26 Date: Thu, 12 Mar 2026 13:20:17 +0000 Subject: [PATCH 03/11] pre-calc epilogue addresses to eliminate all s_set_vgpr_msb --- kernels/wmma_gemm_gfx1250.py | 39 ++++++++++++++++++++++++++++-------- 1 file 
changed, 31 insertions(+), 8 deletions(-) diff --git a/kernels/wmma_gemm_gfx1250.py b/kernels/wmma_gemm_gfx1250.py index 0559ec11..8b547a13 100644 --- a/kernels/wmma_gemm_gfx1250.py +++ b/kernels/wmma_gemm_gfx1250.py @@ -301,13 +301,14 @@ def do_k_subtile_wmma(a_frags, b_frags, accs): return accs # --- Compute on one LDS buffer (K-subtile pipelined) --- - def compute_tile(accs_in, lds_a_ptr, lds_b_ptr): - rocdl.sched_barrier(0) + def compute_tile(accs_in, lds_a_ptr, lds_b_ptr, emit_filler=None): current_accs = list(accs_in) if k_wmma_steps == 1: a_frags, b_frags = load_k_subtile_frags(lds_a_ptr, lds_b_ptr, 0) rocdl.s_wait_dscnt(0) + if emit_filler is not None: + emit_filler() current_accs = do_k_subtile_wmma(a_frags, b_frags, current_accs) else: # Prologue: batch-load K-subtile 0 @@ -321,8 +322,10 @@ def compute_tile(accs_in, lds_a_ptr, lds_b_ptr): current_accs = do_k_subtile_wmma(prev_a, prev_b, current_accs) prev_a, prev_b = next_a, next_b - # Epilogue: wait for last subtile, then compute rocdl.s_wait_dscnt(0) + if emit_filler is not None: + rocdl.sched_barrier(0) + emit_filler() current_accs = do_k_subtile_wmma(prev_a, prev_b, current_accs) return current_accs @@ -332,23 +335,35 @@ def hot_loop_scheduler(): rocdl.sched_barrier(0) # --- Epilogue: vectorized buffer_store_b128 --- - def epilogue(final_accs): + def epilogue_prepare_addrs(): + """Precompute all epilogue store addresses (VALU only, no stores). """ + addrs = [] for wm in range_constexpr(wmma_m_rep): for wn in range_constexpr(wmma_n_rep): - idx = wm * wmma_n_rep + wn row = blk_m + warp_m_base + arith.index(wm * WMMA_M) + lane16 col_base = (blk_n + warp_n_base + arith.index(wn * WMMA_N) + lane_kgrp * arith.index(8)) for half in range_constexpr(2): col = col_base + arith.index(half * 4) c_off = row * n_stride + col + addrs.append(c_off) + return addrs + + def epilogue_stores(final_accs, addrs): + """Execute buffer_store using precomputed addresses.""" + addr_idx = 0 + for wm in range_constexpr(wmma_m_rep): + for wn in range_constexpr(wmma_n_rep): + idx = wm * wmma_n_rep + wn + for half in range_constexpr(2): vals = [vector.extract( final_accs[idx], static_position=[half * 4 + vi], dynamic_position=[]) for vi in range_constexpr(4)] vec4 = vector.from_elements(T.vec(4, T.f32), vals) - buffer_ops.buffer_store(vec4, c_rsrc, c_off) + buffer_ops.buffer_store(vec4, c_rsrc, addrs[addr_idx]) + addr_idx += 1 # --- Pipeline helpers --- def wait_and_barrier(outstanding=0): @@ -356,6 +371,7 @@ def wait_and_barrier(outstanding=0): gpu.barrier() def _compute_and_schedule(accs_in, lds_a, lds_b): + rocdl.sched_barrier(0) accs_out = compute_tile(accs_in, lds_a, lds_b) hot_loop_scheduler() return accs_out @@ -428,13 +444,20 @@ def _l2_prefetch(k_base): copy_b_to_lds(arith.index(_k_off), stages_b_mem[_load_stage]) _extra_j += 1 if _outstanding == -1: - accs = compute_tile(accs, stages_a[_compute_stage], stages_b[_compute_stage]) + epi_addrs_box = [None] + + def _emit_epi_addrs(): + epi_addrs_box[0] = epilogue_prepare_addrs() + + accs = compute_tile( + accs, stages_a[_compute_stage], stages_b[_compute_stage], + emit_filler=_emit_epi_addrs) else: accs = _compute_and_schedule( accs, stages_a[_compute_stage], stages_b[_compute_stage]) wait_and_barrier(outstanding=_outstanding) - epilogue(accs) + epilogue_stores(accs, epi_addrs_box[0]) cache_tag = (in_dtype, K, tile_m, tile_n, tile_k, m_warp, n_warp, num_buffers, waves_per_eu, l2_prefetch_distance) From 3c95640e555dacd93e974e526b7713f2d78ffe99 Mon Sep 17 00:00:00 2001 From: aoli26 Date: Fri, 13 Mar 2026 
03:37:45 +0000 Subject: [PATCH 04/11] use ds_load_tr16_b128 to eliminate B bank conflicts --- kernels/wmma_gemm_gfx1250.py | 134 +++++++++++++++--------- python/flydsl/expr/rocdl.py | 38 +++++++ python/flydsl/expr/rocdl/__init__.py | 37 +++++++ tests/kernels/test_wmma_gemm_gfx1250.py | 10 +- 4 files changed, 165 insertions(+), 54 deletions(-) diff --git a/kernels/wmma_gemm_gfx1250.py b/kernels/wmma_gemm_gfx1250.py index 8b547a13..8fd4ccaf 100644 --- a/kernels/wmma_gemm_gfx1250.py +++ b/kernels/wmma_gemm_gfx1250.py @@ -15,7 +15,7 @@ from flydsl.runtime.device import get_rocm_arch as get_hip_arch from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr, get_op_result_or_value -from kernels.layout_utils import crd2idx, idx2crd +from kernels.layout_utils import idx2crd WMMA_M, WMMA_N, WMMA_K = 16, 16, 32 WAVE_SIZE = 32 @@ -138,9 +138,9 @@ def _elem_type(): n_accs = wmma_m_rep * wmma_n_rep lds_a_stride = tile_k + LDS_PAD_A - lds_b_stride = tile_k + LDS_PAD_B + lds_b_stride = tile_n + LDS_PAD_B lds_a_elems = tile_m * lds_a_stride + LDS_PAD_A - lds_b_elems = tile_n * lds_b_stride + LDS_PAD_A + lds_b_elems = tile_k * lds_b_stride + LDS_PAD_B buf_size_elems = lds_a_elems + lds_b_elems @@ -214,74 +214,104 @@ def copy_a_to_lds(k_base, lds_a_mem_ref): def copy_b_to_lds(k_base, lds_b_mem_ref): desc = tdm_ops.make_tensor_descriptor_2d( global_ptr=arg_b, lds_memref=lds_b_mem_ref, - global_offset=(blk_n, k_base), - tensor_shape=(tile_n, tile_k), strides=(K, 1), - tile_shape=(tile_n, tile_k), elem_bytes=elem_bytes, - pad_interval=tile_k, pad_amount=LDS_PAD_B, + global_offset=(k_base, blk_n), + tensor_shape=(tile_k, tile_n), strides=(N, 1), + tile_shape=(tile_k, tile_n), elem_bytes=elem_bytes, + pad_interval=tile_n, pad_amount=LDS_PAD_B, num_warps=num_warps) tdm_ops.tensor_load_2d(desc) - layout_smem_a = fx.make_layout((tile_m, lds_a_stride), (lds_a_stride, 1)) - layout_smem_b = fx.make_layout((tile_n, lds_b_stride), (lds_b_stride, 1)) - # --- LDS load helpers --- - from flydsl._mlir.dialects import vector as vec_d - - FRAG_K_ELEMS = 8 - def _get_lds_memref(lds_ptr): """Get the raw memref value from SmemPtr or raw memref.""" if isinstance(lds_ptr, SmemPtr): return get_op_result_or_value(lds_ptr.get()) return get_op_result_or_value(lds_ptr) - def load_wmma_frag(lds_ptr, row_base, k_base, lds_layout): + def _precompute_a_lane_bases(lds_ptr): + """Precompute per-wm A fragment lane base addresses. + + Returns (lds_buffer, bases) where bases[wm] = + (warp_m_base + wm*WMMA_M + lane16) * lds_a_stride + lane_kgrp * 8 + """ + lds_buffer = _get_lds_memref(lds_ptr) + row_stride_off = (warp_m_base + lane16) * arith.index(lds_a_stride) + k_lane_off = lane_kgrp * arith.index(8) + bases = [] + for wm in range_constexpr(wmma_m_rep): + a_base = row_stride_off + arith.index(wm * WMMA_M * lds_a_stride) + k_lane_off + bases.append(a_base) + return lds_buffer, bases + + def load_wmma_frag(a_lds_buffer, a_lane_base, ks): """Load one 16x32 WMMA fragment from LDS using vectorized 128-bit loads. - Uses vector.load to read 8 contiguous fp16 elements at once, - avoiding scalar load + v_perm/v_alignbit overhead. - Two 128-bit loads per fragment (2 × 8 fp16 = 16 fp16 values). + a_lane_base is precomputed by _precompute_a_lane_bases. + ks is the K-subtile index (compile-time constant). 
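+
+ For example, ks=1 issues two vec<8> loads at a_lane_base + 32 and
+ a_lane_base + 48 (ks * WMMA_K and ks * WMMA_K + 16) and shuffles
+ them into a single vec<16> fragment.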
""" - raw_memref = _get_lds_memref(lds_ptr) - row = row_base + lane16 vec8_ty = ir.VectorType.get([8], elem_ty) - # Two K-groups per fragment: - # Group 0 (values 0-7): k = k_base + lane_kgrp * 8 - # Group 1 (values 8-15): k = k_base + (2 + lane_kgrp) * 8 - k0 = k_base + lane_kgrp * arith.index(8) - k1 = k_base + (arith.index(2) + lane_kgrp) * arith.index(8) + off0 = a_lane_base + arith.index(ks * WMMA_K) + off1 = a_lane_base + arith.index(ks * WMMA_K + 16) - off0 = crd2idx((row, k0), lds_layout) - off1 = crd2idx((row, k1), lds_layout) + v0 = vector.load_op(vec8_ty, a_lds_buffer, [off0]) + v1 = vector.load_op(vec8_ty, a_lds_buffer, [off1]) - idx0 = [get_op_result_or_value(off0)] - idx1 = [get_op_result_or_value(off1)] + return vector.shuffle(v0, v1, list(range(16))) - v0 = vec_d.load(vec8_ty, raw_memref, idx0) - v1 = vec_d.load(vec8_ty, raw_memref, idx1) + def _precompute_b_lane_bases(lds_ptr): + """Precompute per-wn B fragment lane base addresses. - # Concatenate two vec<8> into vec<16> via vector.shuffle - mask = ir.DenseI64ArrayAttr.get(list(range(16))) - return vec_d.shuffle(v0, v1, mask) + Returns a list of (lds_buffer, b_lane_base) for each wn. + b_lane_base = (lane_kgrp*8 + lane8) * lds_b_stride + + (warp_n_base + wn*WMMA_N + lane_ngrp*8) + where lane8 = lane16 % 8, lane_ngrp = lane16 / 8. + + After precompute, lane8/lane_ngrp are dead → frees VGPRs. + """ + lds_buffer = _get_lds_memref(lds_ptr) + lane8 = lane16 % arith.index(8) + lane_ngrp = lane16 / arith.index(8) + k_lane_off = (lane_kgrp * arith.index(8) + lane8) * arith.index(lds_b_stride) + n_lane_off = lane_ngrp * arith.index(8) + bases = [] + for wn in range_constexpr(wmma_n_rep): + n_col = warp_n_base + arith.index(wn * WMMA_N) + n_lane_off + b_base = k_lane_off + n_col + bases.append(b_base) + return lds_buffer, bases + + def load_wmma_frag_tr(lds_buffer, b_lane_base, ks): + """Load one 16x32 WMMA B fragment using ds_load_tr16_b128. + + b_lane_base is precomputed by _precompute_b_lane_bases. + ks is the K-subtile index (compile-time constant from range_constexpr). + The K offset is folded into a compile-time constant multiplication. + """ + vec8_ty = ir.VectorType.get([8], elem_ty) + results = [] + for k_half in range_constexpr(2): + k_row_off = (ks * WMMA_K + k_half * 16) * lds_b_stride + elem_off = b_lane_base + arith.index(k_row_off) + v = rocdl.lds_transpose_load(vec8_ty, lds_buffer, elem_off, elem_bytes) + results.append(v) + return vector.shuffle(results[0], results[1], list(range(16))) # --- K-subtile load/compute helpers --- - # Number of ds_load_b128 per K-subtile: - # B frags: wmma_n_rep * 2, A frags: wmma_m_rep * 2 + # Number of LDS loads per K-subtile: + # B frags: wmma_n_rep * 2 (ds_load_tr16_b128), A frags: wmma_m_rep * 2 LOADS_PER_SUBTILE = (wmma_m_rep + wmma_n_rep) * 2 - def load_k_subtile_frags(lds_a_ptr, lds_b_ptr, ks): - """Batch-load all A and B fragments for one K-subtile (no wait).""" - k_off = arith.index(ks * WMMA_K) + def load_k_subtile_frags(a_lds_buffer, a_bases, b_lds_buffer, b_bases, ks): + """Batch-load all A and B fragments for one K-subtile (no wait). - b_frags = [load_wmma_frag( - lds_b_ptr, warp_n_base + arith.index(wn * WMMA_N), - k_off, layout_smem_b) + All base addresses are precomputed by _precompute_{a,b}_lane_bases. + ks is the K-subtile index (compile-time constant). 
+ """ + b_frags = [load_wmma_frag_tr(b_lds_buffer, b_bases[wn], ks) for wn in range_constexpr(wmma_n_rep)] - a_frags = [load_wmma_frag( - lds_a_ptr, warp_m_base + arith.index(wm * WMMA_M), - k_off, layout_smem_a) + a_frags = [load_wmma_frag(a_lds_buffer, a_bases[wm], ks) for wm in range_constexpr(wmma_m_rep)] return a_frags, b_frags @@ -304,20 +334,26 @@ def do_k_subtile_wmma(a_frags, b_frags, accs): def compute_tile(accs_in, lds_a_ptr, lds_b_ptr, emit_filler=None): current_accs = list(accs_in) + # Precompute all lane bases once per tile + a_lds_buffer, a_bases = _precompute_a_lane_bases(lds_a_ptr) + b_lds_buffer, b_bases = _precompute_b_lane_bases(lds_b_ptr) + if k_wmma_steps == 1: - a_frags, b_frags = load_k_subtile_frags(lds_a_ptr, lds_b_ptr, 0) + a_frags, b_frags = load_k_subtile_frags( + a_lds_buffer, a_bases, b_lds_buffer, b_bases, 0) rocdl.s_wait_dscnt(0) if emit_filler is not None: emit_filler() current_accs = do_k_subtile_wmma(a_frags, b_frags, current_accs) else: # Prologue: batch-load K-subtile 0 - prev_a, prev_b = load_k_subtile_frags(lds_a_ptr, lds_b_ptr, 0) + prev_a, prev_b = load_k_subtile_frags( + a_lds_buffer, a_bases, b_lds_buffer, b_bases, 0) # Main K-loop: overlap load[ks+1] with compute[ks] for ks in range_constexpr(k_wmma_steps - 1): next_a, next_b = load_k_subtile_frags( - lds_a_ptr, lds_b_ptr, ks + 1) + a_lds_buffer, a_bases, b_lds_buffer, b_bases, ks + 1) rocdl.s_wait_dscnt(LOADS_PER_SUBTILE) current_accs = do_k_subtile_wmma(prev_a, prev_b, current_accs) prev_a, prev_b = next_a, next_b @@ -384,7 +420,7 @@ def _l2_prefetch(k_base): arg_a, (blk_m, pf_k), (tile_m, tile_k), (K, 1), elem_bytes=elem_bytes, thread_id=tx, block_threads=block_threads) tdm_ops.l2_prefetch_tile( - arg_b, (blk_n, pf_k), (tile_n, tile_k), (K, 1), + arg_b, (pf_k, blk_n), (tile_k, tile_n), (N, 1), elem_bytes=elem_bytes, thread_id=tx, block_threads=block_threads) # ====== Multi-stage pipeline ====== diff --git a/python/flydsl/expr/rocdl.py b/python/flydsl/expr/rocdl.py index b2a979a0..c00340eb 100644 --- a/python/flydsl/expr/rocdl.py +++ b/python/flydsl/expr/rocdl.py @@ -332,6 +332,43 @@ def wave_id(): return _ods_wave_id(i32) +def lds_transpose_load(result_type, lds_memref, elem_offset, elem_bytes): + """Transpose-load from LDS memref via ds_load_tr16_b128 (gfx1250). + + Args: + result_type: Vector result type, e.g. ``VectorType.get([8], f16)``. + lds_memref: LDS memref value (address-space 3), typically from + ``SmemPtr.get()`` or ``get_op_result_or_value(...)``. + elem_offset: Per-lane linearized element offset into the memref + (ArithValue / ir.Value of index type / Python int). + elem_bytes: Element size in bytes (Python int, e.g. 2 for f16). + + Returns: + Loaded and transposed vector ``ir.Value``. + """ + from .._mlir import ir as _ir + from .._mlir.dialects import ( + llvm as _llvm, + memref as _memref, + rocdl as _rocdl, + ) + from . 
import arith as _arith + from .arith import _to_raw + from .typing import T + from .utils.arith import ArithValue as _AV + + lds_ptr_ty = _ir.Type.parse("!llvm.ptr<3>") + raw_memref = _arith.unwrap(lds_memref) + lds_base = _memref.extract_aligned_pointer_as_index(raw_memref) + + byte_off = _AV(_arith.unwrap(elem_offset, index=True)) * _arith.index(elem_bytes) + total_byte_idx = _AV(lds_base) + byte_off + addr_i32 = _to_raw(_arith.index_cast(T.i32, total_byte_idx)) + ptr_val = _llvm.inttoptr(lds_ptr_ty, addr_i32) + + return _rocdl.ds_load_tr16_b128(result_type, ptr_val) + + __all__ = [ # Thread/Block/Grid IDs and dimensions 'workitem_id_x', 'workitem_id_y', 'workitem_id_z', @@ -411,6 +448,7 @@ def wave_id(): # Convenience wrappers 'make_buffer_tensor', + 'lds_transpose_load', # memref-level wrapper for gfx1250 ds_load_tr16_b128 # gfx1250 TDM - descriptor-driven tile copy (preferred over per-lane) 'tensor_load_to_lds', # 4-group, up to 5D tensor diff --git a/python/flydsl/expr/rocdl/__init__.py b/python/flydsl/expr/rocdl/__init__.py index 7c0f393e..14efa911 100644 --- a/python/flydsl/expr/rocdl/__init__.py +++ b/python/flydsl/expr/rocdl/__init__.py @@ -113,6 +113,43 @@ def wave_id(): return _ods_wave_id(i32) +def lds_transpose_load(result_type, lds_memref, elem_offset, elem_bytes): + """Transpose-load from LDS memref via ds_load_tr16_b128 (gfx1250). + + Args: + result_type: Vector result type, e.g. ``VectorType.get([8], f16)``. + lds_memref: LDS memref value (address-space 3), typically from + ``SmemPtr.get()`` or ``get_op_result_or_value(...)``. + elem_offset: Per-lane linearized element offset into the memref + (ArithValue / ir.Value of index type / Python int). + elem_bytes: Element size in bytes (Python int, e.g. 2 for f16). + + Returns: + Loaded and transposed vector ``ir.Value``. + """ + from ..._mlir import ir as _ir + from ..._mlir.dialects import ( + llvm as _llvm, + memref as _memref, + rocdl as _rocdl, + ) + from .. 
import arith as _arith + from ..arith import _to_raw + from ..typing import T + from ..utils.arith import ArithValue as _AV + + lds_ptr_ty = _ir.Type.parse("!llvm.ptr<3>") + raw_memref = _arith.unwrap(lds_memref) + lds_base = _memref.extract_aligned_pointer_as_index(raw_memref) + + byte_off = _AV(_arith.unwrap(elem_offset, index=True)) * _arith.index(elem_bytes) + total_byte_idx = _AV(lds_base) + byte_off + addr_i32 = _to_raw(_arith.index_cast(T.i32, total_byte_idx)) + ptr_val = _llvm.inttoptr(lds_ptr_ty, addr_i32) + + return _rocdl.ds_load_tr16_b128(result_type, ptr_val) + + # ── New high-level helpers from universal.py ────────────────────────── from .universal import * # noqa: F401,F403 diff --git a/tests/kernels/test_wmma_gemm_gfx1250.py b/tests/kernels/test_wmma_gemm_gfx1250.py index 068302d3..5eb7d13b 100644 --- a/tests/kernels/test_wmma_gemm_gfx1250.py +++ b/tests/kernels/test_wmma_gemm_gfx1250.py @@ -59,7 +59,7 @@ def test_wmma_gemm_tdm(in_dtype, M, N, K, tile_m, tile_n, tile_k, lds_pad = 8 elem_bytes = 2 a_buf = tile_m * (tile_k + lds_pad) * elem_bytes - b_buf = tile_n * (tile_k + lds_pad) * elem_bytes + b_buf = tile_k * (tile_n + lds_pad) * elem_bytes total_lds = (a_buf + b_buf) * num_buffers if total_lds > 327680: pytest.skip(f"LDS budget exceeded: {total_lds} > 327680") @@ -75,12 +75,12 @@ def test_wmma_gemm_tdm(in_dtype, M, N, K, tile_m, tile_n, tile_k, npad = (N + tile_n - 1) // tile_n * tile_n a = torch.randn((M, K), dtype=torch_dtype, device='cpu').cuda() - b = torch.randn((N, K), dtype=torch_dtype, device='cpu').cuda() + b = torch.randn((K, N), dtype=torch_dtype, device='cpu').cuda() a_pad = torch.zeros((mpad, K), dtype=torch_dtype, device=device) - b_pad = torch.zeros((npad, K), dtype=torch_dtype, device=device) + b_pad = torch.zeros((K, npad), dtype=torch_dtype, device=device) a_pad[:M, :] = a - b_pad[:N, :] = b + b_pad[:, :N] = b c_pad = torch.zeros((mpad, npad), dtype=torch.float32, device=device) @@ -99,7 +99,7 @@ def test_wmma_gemm_tdm(in_dtype, M, N, K, tile_m, tile_n, tile_k, ) torch.cuda.synchronize() - ref = torch.mm(a.cpu().to(torch.float32), b.cpu().to(torch.float32).T) + ref = torch.mm(a.cpu().to(torch.float32), b.cpu().to(torch.float32)) rtol = 3e-2 atol = 3e-2 assert verify_output(c_pad[:M, :N].cpu().to(torch.float32), ref, rtol=rtol, atol=atol) From bb8f2a6756d5fab2766e5cdc67bf7b3ea1b3832d Mon Sep 17 00:00:00 2001 From: aoli26 Date: Sat, 14 Mar 2026 00:57:24 +0000 Subject: [PATCH 05/11] support mcast --- kernels/wmma_gemm_gfx1250.py | 103 +++++++++++++++---- lib/Runtime/FlyRocmRuntimeWrappers.cpp | 55 ++++++++++ python/flydsl/__init__.py | 33 ++++++ python/flydsl/compiler/kernel_function.py | 14 +++ python/flydsl/expr/gpu.py | 119 +++++++++++++++++++++- python/flydsl/expr/tdm_ops.py | 19 +++- python/flydsl/utils/smem_allocator.py | 4 +- tests/kernels/test_wmma_gemm_gfx1250.py | 34 ++++++- 8 files changed, 351 insertions(+), 30 deletions(-) diff --git a/kernels/wmma_gemm_gfx1250.py b/kernels/wmma_gemm_gfx1250.py index 8fd4ccaf..469d0ee3 100644 --- a/kernels/wmma_gemm_gfx1250.py +++ b/kernels/wmma_gemm_gfx1250.py @@ -10,6 +10,7 @@ from flydsl._mlir import ir from flydsl.compiler.kernel_function import CompilationContext from flydsl.expr import arith, buffer_ops, gpu, range_constexpr, rocdl, tdm_ops, vector +from flydsl._mlir.dialects import llvm as llvm_dialect from flydsl.expr.arith import _to_raw as _raw from flydsl.expr.typing import T from flydsl.runtime.device import get_rocm_arch as get_hip_arch @@ -79,6 +80,8 @@ def compile_wmma_gemm_tdm( num_buffers: 
int = 2, waves_per_eu: int = None, l2_prefetch_distance: int = 2, + cluster_m: int = 1, + cluster_n: int = 1, ): """Compile a WMMA GEMM kernel with TDM async copy and multi-stage buffering. @@ -89,6 +92,8 @@ def compile_wmma_gemm_tdm( waves_per_eu: Occupancy hint (None = default, 1-4 = limit occupancy). l2_prefetch_distance: Number of k-tiles ahead to prefetch into L2. 0 = disabled, 2 = typical value. + cluster_m: Cluster dimension along M (WG rows per cluster, 1=disabled). + cluster_n: Cluster dimension along N (WG cols per cluster, 1=disabled). """ _ = (M, N) if num_buffers not in (2, 3): @@ -98,6 +103,20 @@ def compile_wmma_gemm_tdm( is_f16 = in_dtype == "fp16" elem_bytes = 2 + use_cluster = cluster_m > 1 or cluster_n > 1 + if use_cluster: + if cluster_m * cluster_n > 16: + raise ValueError( + f"cluster_m * cluster_n must be <= 16, got {cluster_m}*{cluster_n}={cluster_m * cluster_n}") + if cluster_m < 1 or cluster_n < 1: + raise ValueError(f"cluster dims must be >= 1, got ({cluster_m}, {cluster_n})") + effective_waves_per_eu = waves_per_eu + if use_cluster and effective_waves_per_eu is None: + # Cluster mode can deadlock if a workgroup is split and only a subset + # of its waves are resident while hitting early workgroup barriers. + # Use conservative occupancy by default for cluster-enabled kernels. + effective_waves_per_eu = 1 + block_threads = m_warp * n_warp * WAVE_SIZE if K % tile_k != 0: @@ -174,6 +193,15 @@ def kernel_wmma_gemm_tdm( i32_m: fx.Int32, i32_n: fx.Int32, ): + # Enable back-to-back WMMA issue (SCHED_MODE bit[4] = DISABLE_VALU_STALL) + # hwreg(26, 4, 1) = HW_REG_SCHED_MODE, offset=4, size=1 + llvm_dialect.inline_asm( + None, [], # void result, no operands + "s_setreg_imm32_b32 hwreg(26, 4, 1), 1", + "", # no constraints + has_side_effects=True, + ) + tx = gpu.thread_id("x") bx = gpu.block_id("x") by = gpu.block_id("y") @@ -181,6 +209,15 @@ def kernel_wmma_gemm_tdm( blk_m = bx * arith.index(tile_m) blk_n = by * arith.index(tile_n) + # --- Cluster MCAST setup --- + if use_cluster: + local_x, local_y = gpu.compute_cluster_position(bx, by, cluster_m, cluster_n) + a_mcast_mask, b_mcast_mask = gpu.compute_mcast_masks( + local_x, local_y, cluster_m, cluster_n) + else: + a_mcast_mask = 0 + b_mcast_mask = 0 + # --- Thread/wave decomposition --- layout_thr = fx.make_layout( (m_warp, n_warp, 2, 16), @@ -200,7 +237,7 @@ def kernel_wmma_gemm_tdm( c_nrec = m_idx * n_stride * arith.index(4) c_rsrc = buffer_ops.create_buffer_resource(arg_c, num_records_bytes=c_nrec) - # --- TDM async copy helpers --- + # --- TDM async copy helpers (MCAST-aware) --- def copy_a_to_lds(k_base, lds_a_mem_ref): desc = tdm_ops.make_tensor_descriptor_2d( global_ptr=arg_a, lds_memref=lds_a_mem_ref, @@ -208,7 +245,8 @@ def copy_a_to_lds(k_base, lds_a_mem_ref): tensor_shape=(tile_m, tile_k), strides=(K, 1), tile_shape=(tile_m, tile_k), elem_bytes=elem_bytes, pad_interval=tile_k, pad_amount=LDS_PAD_A, - num_warps=num_warps) + num_warps=num_warps, + workgroup_mask=a_mcast_mask) tdm_ops.tensor_load_2d(desc) def copy_b_to_lds(k_base, lds_b_mem_ref): @@ -218,7 +256,8 @@ def copy_b_to_lds(k_base, lds_b_mem_ref): tensor_shape=(tile_k, tile_n), strides=(N, 1), tile_shape=(tile_k, tile_n), elem_bytes=elem_bytes, pad_interval=tile_n, pad_amount=LDS_PAD_B, - num_warps=num_warps) + num_warps=num_warps, + workgroup_mask=b_mcast_mask) tdm_ops.tensor_load_2d(desc) # --- LDS load helpers --- @@ -406,16 +445,30 @@ def wait_and_barrier(outstanding=0): tdm_ops.tensor_wait(outstanding) gpu.barrier() + def 
wait_and_cluster_barrier(outstanding=0): + """Fused WG barrier + cluster sync: reduces instruction overhead + by issuing the cluster signal while tensor_wait is still draining, + then waiting for both to complete.""" + tdm_ops.tensor_wait(outstanding) + if use_cluster: + gpu.cluster_barrier() + else: + gpu.barrier() + def _compute_and_schedule(accs_in, lds_a, lds_b): rocdl.sched_barrier(0) accs_out = compute_tile(accs_in, lds_a, lds_b) hot_loop_scheduler() return accs_out + _effective_l2_pf = l2_prefetch_distance + if use_cluster and l2_prefetch_distance > 0: + _effective_l2_pf = max(1, l2_prefetch_distance - 1) + def _l2_prefetch(k_base): - if l2_prefetch_distance <= 0: + if _effective_l2_pf <= 0: return - pf_k = k_base + arith.index(l2_prefetch_distance * tile_k) + pf_k = k_base + arith.index(_effective_l2_pf * tile_k) tdm_ops.l2_prefetch_tile( arg_a, (blk_m, pf_k), (tile_m, tile_k), (K, 1), elem_bytes=elem_bytes, thread_id=tx, block_threads=block_threads) @@ -444,15 +497,12 @@ def _l2_prefetch(k_base): for i in range_constexpr(pre_loaded): copy_a_to_lds(arith.index(i * tile_k), stages_a_mem[i]) copy_b_to_lds(arith.index(i * tile_k), stages_b_mem[i]) - # Wait until stage[0] is ready; allow later-stage loads to still be in flight - # outstanding = 2 * (num_buffers - 2): 0 for double-buffer, 2 for triple-buffer wait_and_barrier(outstanding=2 * (num_buffers - 2)) # Main loop: each iteration covers (num_buffers) K-tiles - # Sub-phase s: load the "next" tile, compute the "current" tile, then barrier - # load_stage = (s + num_buffers - 1) % num_buffers - # load_offset = iv + (s + num_buffers - 1) * tile_k - # compute stage[s] + # Sub-phase s: load next tile (MCAST), compute current tile, then barrier + # The last sub-phase uses wait_and_cluster_barrier to fuse the WG + # barrier with cluster sync for the NEXT iteration's MCAST loads. main_end = loop_iters * num_buffers * tile_k if loop_iters > 0: @@ -465,13 +515,17 @@ def _l2_prefetch(k_base): copy_b_to_lds(iv + arith.index(_load_k_off), stages_b_mem[_load_stage]) _l2_prefetch(iv + arith.index(s * tile_k)) accs_in = _compute_and_schedule(accs_in, stages_a[s], stages_b[s]) - wait_and_barrier(outstanding=2) + if s == num_buffers - 1: + wait_and_cluster_barrier(outstanding=2) + else: + wait_and_barrier(outstanding=2) results = yield list(accs_in) accs = list(results) # Tail: handle remaining tiles using the compile-time plan - # Each plan step: optionally load one tile, compute one stage, then wait. # outstanding=-1 → last step: use compute_tile (no barrier). 
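+ # The main loop's last sub-phase already issued a cluster sync via
+ # wait_and_cluster_barrier; when loop_iters == 0 no cluster sync has
+ # happened yet, so emit one before the tail's first MCAST loads.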
+ if loop_iters == 0 and use_cluster:
+ gpu.cluster_barrier()
_extra_j = 0
for _load_stage, _compute_stage, _outstanding in tail_plan:
if _load_stage is not None:
@@ -491,12 +545,16 @@ def _emit_epi_addrs():
else:
accs = _compute_and_schedule(
accs, stages_a[_compute_stage], stages_b[_compute_stage])
- wait_and_barrier(outstanding=_outstanding)
+ if use_cluster and _load_stage is not None:
+ wait_and_cluster_barrier(outstanding=_outstanding)
+ else:
+ wait_and_barrier(outstanding=_outstanding)
epilogue_stores(accs, epi_addrs_box[0])
cache_tag = (in_dtype, K, tile_m, tile_n, tile_k, m_warp, n_warp,
- num_buffers, waves_per_eu, l2_prefetch_distance)
+ num_buffers, effective_waves_per_eu, l2_prefetch_distance,
+ cluster_m, cluster_n)
@flyc.jit
def launch_wmma_gemm_tdm(
@@ -521,17 +579,22 @@ def launch_wmma_gemm_tdm(
gy = _raw((idx_n + arith.index(tile_n - 1)) / arith.index(tile_n))
launcher = kernel_wmma_gemm_tdm(arg_c, arg_a, arg_b, i32_m, i32_n)
- if waves_per_eu is not None:
- _wpe = int(waves_per_eu)
- if _wpe >= 1:
- for op in ctx.gpu_module_body.operations:
- if hasattr(op, 'attributes') and op.OPERATION_NAME == "gpu.func":
+ for op in ctx.gpu_module_body.operations:
+ if hasattr(op, 'attributes') and op.OPERATION_NAME == "gpu.func":
+ if effective_waves_per_eu is not None:
+ _wpe = int(effective_waves_per_eu)
+ if _wpe >= 1:
op.attributes["rocdl.waves_per_eu"] = ir.IntegerAttr.get(
ir.IntegerType.get_signless(32), _wpe)
+ if use_cluster:
+ op.attributes["rocdl.cluster_dims"] = ir.StringAttr.get(
+ f"{cluster_m},{cluster_n},1")
+ cluster_arg = (cluster_m, cluster_n, 1) if use_cluster else None
launcher.launch(
grid=(gx, gy, 1),
block=(block_threads, 1, 1),
stream=stream,
+ cluster=cluster_arg,
)
return launch_wmma_gemm_tdm
diff --git a/lib/Runtime/FlyRocmRuntimeWrappers.cpp b/lib/Runtime/FlyRocmRuntimeWrappers.cpp
index 2981e219..f4037620 100644
--- a/lib/Runtime/FlyRocmRuntimeWrappers.cpp
+++ b/lib/Runtime/FlyRocmRuntimeWrappers.cpp
@@ -66,6 +66,61 @@ extern "C" void mgpuLaunchKernel(hipFunction_t function, intptr_t gridX,
stream, params, extra));
}
+extern "C" void mgpuLaunchClusterKernel(hipFunction_t function,
+ intptr_t clusterX, intptr_t clusterY,
+ intptr_t clusterZ,
+ intptr_t gridX, intptr_t gridY,
+ intptr_t gridZ,
+ intptr_t blockX, intptr_t blockY,
+ intptr_t blockZ, int32_t smem,
+ hipStream_t stream, void **params,
+ void **extra, size_t /*paramsCount*/) {
+ hipLaunchAttribute attrs[1];
+ attrs[0].id = hipLaunchAttributeClusterDimension;
+ attrs[0].value.clusterDim.x = static_cast<unsigned int>(clusterX);
+ attrs[0].value.clusterDim.y = static_cast<unsigned int>(clusterY);
+ attrs[0].value.clusterDim.z = static_cast<unsigned int>(clusterZ);
+
+ HIP_LAUNCH_CONFIG config{};
+ config.gridDimX = static_cast<unsigned int>(gridX);
+ config.gridDimY = static_cast<unsigned int>(gridY);
+ config.gridDimZ = static_cast<unsigned int>(gridZ);
+ config.blockDimX = static_cast<unsigned int>(blockX);
+ config.blockDimY = static_cast<unsigned int>(blockY);
+ config.blockDimZ = static_cast<unsigned int>(blockZ);
+ config.sharedMemBytes = static_cast<unsigned int>(smem);
+ config.hStream = stream;
+ config.attrs = attrs;
+ config.numAttrs = 1;
+
+ hipError_t err = hipDrvLaunchKernelEx(&config, function, params, extra);
+ if (err == hipSuccess)
+ return;
+
+ const bool requestedRealCluster =
+ (clusterX > 1) || (clusterY > 1) || (clusterZ > 1);
+ if (requestedRealCluster) {
+ fprintf(stderr,
+ "[mgpuLaunchClusterKernel] hipDrvLaunchKernelEx failed (err=%d) "
+ "for requested cluster=(%ld,%ld,%ld); not falling back to "
+ "hipModuleLaunchKernel.\n",
+ static_cast<int>(err), static_cast<long>(clusterX),
+ static_cast<long>(clusterY), static_cast<long>(clusterZ));
+ HIP_REPORT_IF_ERROR(err);
+ HIP_REPORT_IF_ERROR(hipModuleLaunchKernel(function, 0, 0, 0, 0, 0, 0, smem,
+ stream, params, extra));
+ return;
+ }
+
+ fprintf(stderr,
+ "[mgpuLaunchClusterKernel] hipDrvLaunchKernelEx failed (err=%d) "
+ "for cluster=(1,1,1); falling back to hipModuleLaunchKernel.\n",
+ static_cast<int>(err));
+ HIP_REPORT_IF_ERROR(hipModuleLaunchKernel(function, gridX, gridY, gridZ,
+ blockX, blockY, blockZ, smem,
+ stream, params, extra));
+}
+
extern "C" hipStream_t mgpuStreamCreate() {
hipStream_t stream = nullptr;
HIP_REPORT_IF_ERROR(hipStreamCreate(&stream));
diff --git a/python/flydsl/__init__.py b/python/flydsl/__init__.py
index dbd27816..9f9e3849 100644
--- a/python/flydsl/__init__.py
+++ b/python/flydsl/__init__.py
@@ -1,5 +1,38 @@
+import ctypes
+import os
+
_BASE_VERSION = "0.1.0"
+
+# Workaround: resolve FFM simulator "LLVM ERROR: Option 'greedy' already exists!"
+def _maybe_preload_system_comgr() -> None:
+ disable = os.environ.get("FLYDSL_DISABLE_COMGR_PRELOAD", "").strip().lower()
+ if disable in {"1", "true", "yes", "on"}:
+ return
+
+ model_path = os.environ.get("GFX1250_MODEL_PATH", "")
+ hsa_model_lib = os.environ.get("HSA_MODEL_LIB", "")
+ in_ffm_session = ("ffm-lite" in hsa_model_lib) or ("ffmlite" in model_path)
+ if not in_ffm_session:
+ return
+
+ system_comgr = os.environ.get(
+ "FLYDSL_COMGR_PRELOAD_PATH", "/opt/rocm/lib/libamd_comgr.so.3"
+ )
+ sim_comgr = os.path.join(model_path, "rocm", "libamd_comgr.so.3")
+ if not (os.path.exists(system_comgr) and os.path.exists(sim_comgr)):
+ return
+
+ mode = getattr(os, "RTLD_NOW", 0) | getattr(os, "RTLD_GLOBAL", 0)
+ try:
+ ctypes.CDLL(system_comgr, mode=mode)
+ except OSError:
+ # Keep import robust if the host ROCm stack differs.
+ pass
+
+
+_maybe_preload_system_comgr()
+
try:
from ._version import __version__
except ImportError:
diff --git a/python/flydsl/compiler/kernel_function.py b/python/flydsl/compiler/kernel_function.py
index ee41bd39..bd754f60 100644
--- a/python/flydsl/compiler/kernel_function.py
+++ b/python/flydsl/compiler/kernel_function.py
@@ -237,6 +237,7 @@ def launch(
block: DimType = (1, 1, 1),
smem: Union[int, ir.Value] = 0,
stream: Optional[ir.Value] = None,
+ cluster: Optional[DimType] = None,
) -> None:
"""Emit gpu.launch_func operation with the given configuration.
@@ -245,6 +246,8 @@ def launch(
block: Block dimensions (x, y, z). Can be int, ir.Value, tuple, or list.
smem: Dynamic shared memory size in bytes. Can be int or ir.Value.
stream: CUDA/HIP stream as ir.Value. None means default stream.
+ cluster: Cluster dimensions (x, y, z) for workgroup clustering.
+ None means no clustering. Enables MCAST and cluster barriers.
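+
+ Example (illustrative values):
+ launcher.launch(grid=(gx, gy, 1), block=(256, 1, 1),
+ stream=stream, cluster=(2, 2, 1))
+ launches with 2x2 workgroup clusters; the grid is expected to be
+ divisible by the cluster dims (the gfx1250 tests skip otherwise).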
""" launch_loc = create_caller_location(depth=2) @@ -277,6 +280,15 @@ def launch( async_deps = [stream_val] if stream_val is not None else None + cluster_size = None + if cluster is not None: + cx, cy, cz = _normalize_dim(cluster) + cluster_size = ( + _to_index_value(cx), + _to_index_value(cy), + _to_index_value(cz), + ) + gpu.LaunchFuncOp( ["kernels", self._kernel_name], (grid_x, grid_y, grid_z), @@ -284,6 +296,8 @@ def launch( kernel_operands, async_dependencies=async_deps, dynamic_shared_memory_size=smem_val, + async_object=async_object, + cluster_size=cluster_size, loc=launch_loc, ip=None, ) diff --git a/python/flydsl/expr/gpu.py b/python/flydsl/expr/gpu.py index be4d08e7..87985e3b 100644 --- a/python/flydsl/expr/gpu.py +++ b/python/flydsl/expr/gpu.py @@ -14,9 +14,12 @@ """ from .._mlir import ir -from .._mlir.dialects import gpu +from .._mlir.dialects import gpu, rocdl, scf from .._mlir.ir import Attribute from .typing import Tuple3D +from . import arith as _arith_ext +from . import rocdl as _rocdl_ext +from .typing import T thread_id = gpu.thread_id block_id = gpu.block_id @@ -52,6 +55,112 @@ class SharedAllocator: pass +# ========================================================================= +# Cluster operations (gfx1250 workgroup clustering) +# ========================================================================= + +CLUSTER_BARRIER_ID = -3 +# For cluster sync, wait on the cluster user barrier itself. +CLUSTER_WAIT_ALL = CLUSTER_BARRIER_ID + + +def is_wave_leader(): + """Return true for wave-0 inside the workgroup.""" + return _arith_ext.cmpi( + _arith_ext.CmpIPredicate.eq, + _rocdl_ext.wave_id(), + _arith_ext.constant(0, type=T.i32), + ) + + +def cluster_signal_once_per_wg(): + """Signal cluster barrier from exactly one wave per workgroup.""" + if_op = scf.IfOp(is_wave_leader(), [], has_else=False, loc=ir.Location.unknown()) + if len(if_op.regions[0].blocks) == 0: + if_op.regions[0].blocks.append(*[]) + with ir.InsertionPoint(if_op.regions[0].blocks[0]): + rocdl.s_barrier_signal(CLUSTER_BARRIER_ID) + scf.YieldOp([]) + + +def cluster_wait(): + """Wait on the cluster user barrier.""" + rocdl.s_barrier_wait(CLUSTER_WAIT_ALL) + + +def cluster_barrier(): + """Workgroup + cluster barrier with one-wave signal semantics. + + This is the safe default for kernels using cluster multicast: + 1) synchronize waves inside each workgroup + 2) signal cluster barrier once per workgroup (wave-0 only) + 3) wait for all workgroups in the cluster + """ + gpu.barrier() + cluster_signal_once_per_wg() + cluster_wait() + + +def compute_cluster_position(bx, by, cluster_m: _int, cluster_n: _int): + """Compute a workgroup's (row, col) position within its cluster. + + Args: + bx: Block index X (M direction), MLIR index value. + by: Block index Y (N direction), MLIR index value. + cluster_m: Cluster dimension along M (number of WG rows per cluster). + cluster_n: Cluster dimension along N (number of WG cols per cluster). + + Returns: + (local_x, local_y) as MLIR index values — position within the cluster. + """ + local_x = bx % _arith_ext.index(cluster_m) + local_y = by % _arith_ext.index(cluster_n) + return local_x, local_y + + +def compute_mcast_masks(local_x, local_y, cluster_m: _int, cluster_n: _int): + """Compute MCAST workgroup_mask values for A and B matrices. 
+ + Hardware flat WG index within a cluster uses X-inner ordering + (MI400 Shader Programming, TTMP6 layout, section 3.5.5.1): + + flat_wg_id = wg_x + wg_y * nwg_x = local_x + local_y * cluster_m + + where cluster_dims = (cluster_m, cluster_n, 1), so nwg_x = cluster_m. + + A mask: WGs sharing the same M-tile row (same local_x, varying local_y). + Bits: {local_x + ly * cluster_m : ly in 0..cluster_n-1} + B mask: WGs sharing the same N-tile column (same local_y, varying local_x). + Bits: {lx + local_y * cluster_m : lx in 0..cluster_m-1} + + Args: + local_x: WG row within cluster (MLIR index, 0..cluster_m-1). + local_y: WG column within cluster (MLIR index, 0..cluster_n-1). + cluster_m: Cluster rows (Python int). + cluster_n: Cluster columns (Python int). + + Returns: + (a_mask, b_mask) as MLIR i32 values for TDM workgroup_mask. + """ + local_x_i32 = _arith_ext.index_cast(T.i32, local_x) + local_y_i32 = _arith_ext.index_cast(T.i32, local_y) + cluster_m_i32 = _arith_ext.constant(cluster_m, type=T.i32) + + # A mask: pattern has bits at strides of cluster_m, shifted by local_x + a_pattern_val = 0 + for ly in range(cluster_n): + a_pattern_val |= (1 << (ly * cluster_m)) + a_pattern = _arith_ext.constant(a_pattern_val, type=T.i32) + a_mask = _arith_ext.shli(a_pattern, local_x_i32) + + # B mask: cluster_m contiguous low bits, shifted by local_y * cluster_m + b_pattern = _arith_ext.constant((1 << cluster_m) - 1, type=T.i32) + col_base = _arith_ext.muli(local_y_i32, cluster_m_i32) + b_mask = _arith_ext.shli(b_pattern, col_base) + + return a_mask, b_mask + + __all__ = [ "thread_id", "block_id", @@ -63,4 +172,12 @@ class SharedAllocator: "smem_space", "lds_space", "SharedAllocator", + "is_wave_leader", + "cluster_signal_once_per_wg", + "cluster_wait", + "cluster_barrier", + "compute_cluster_position", + "compute_mcast_masks", + "CLUSTER_BARRIER_ID", + "CLUSTER_WAIT_ALL", ] diff --git a/python/flydsl/expr/tdm_ops.py b/python/flydsl/expr/tdm_ops.py index d8ebd8de..706c0d8f 100644 --- a/python/flydsl/expr/tdm_ops.py +++ b/python/flydsl/expr/tdm_ops.py @@ -176,6 +176,7 @@ def make_tensor_descriptor_2d( num_warps: int = 1, cache_policy: int = 0, pred: int = 1, + workgroup_mask: Union[int, "ir.Value"] = 0, ) -> TDMDescriptor2D: """Build a 2D TDM descriptor for tensor_load_to_lds_d2. @@ -206,6 +207,10 @@ def make_tensor_descriptor_2d( num_warps: Total warps in the workgroup. cache_policy: Cache policy (0 = default). pred: Predicate (1 = enabled). + workgroup_mask: MCAST workgroup mask [15:0] for TDM GROUP1 descriptor. + int: compile-time constant folded into descriptor. + ir.Value (i32 SGPR): runtime mask, ORed with upper config bits. + 0 = no multicast (default). Returns: TDMDescriptor2D with dgroup0 and dgroup1 ready for tensor_load_2d. 
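+
+ Example (sketch): a runtime mask from ``gpu.compute_mcast_masks`` is
+ passed straight through, so every workgroup sharing the same M-tile
+ row receives the broadcast A tile:
+
+ desc = make_tensor_descriptor_2d(..., workgroup_mask=a_mask)
+ tensor_load_2d(desc)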
@@ -307,9 +312,8 @@ def make_tensor_descriptor_2d(
pad_enable = 0
# sgpr0: config bitfields
- g1_s0_val = (
- (0) # workgroup_mask [15:0]
- | (data_size_code << 16) # data_size [17:16]
+ g1_s0_upper = (
+ (data_size_code << 16) # data_size [17:16]
| (0 << 18) # atomic_barrier_enable
| (0 << 19) # iterate_enable
| (pad_enable << 20) # pad_enable
@@ -317,7 +321,14 @@ def make_tensor_descriptor_2d(
| (enc_interval << 22) # pad_interval [24:22]
| (enc_amount << 25) # pad_amount [31:25]
)
- g1_s0 = arith.constant(g1_s0_val, type=T.i32)
+
+ if isinstance(workgroup_mask, int):
+ g1_s0_val = (workgroup_mask & 0xFFFF) | g1_s0_upper
+ g1_s0 = arith.constant(g1_s0_val, type=T.i32)
+ else:
+ upper_const = arith.constant(g1_s0_upper, type=T.i32)
+ mask_i32 = arith.andi(workgroup_mask, arith.constant(0xFFFF, type=T.i32))
+ g1_s0 = arith.ori(upper_const, mask_i32)
# sgpr1: atomic_barrier_addr[15:0]=0 | tensor_dim0_lo[31:16]
g1_s1 = arith.constant((tdim0 & 0xFFFF) << 16, type=T.i32)
diff --git a/python/flydsl/utils/smem_allocator.py b/python/flydsl/utils/smem_allocator.py
index bf7a47a8..33eefa8a 100644
--- a/python/flydsl/utils/smem_allocator.py
+++ b/python/flydsl/utils/smem_allocator.py
@@ -209,10 +209,12 @@ def get_base(self):
SMEM_CAPACITY_MAP = {
# ===================== AMD CDNA Architectures (Data Center Compute Cards) =====================
# CDNA 3 (MI300 Series) - 64KB LDS per CU
- "gfx942": 65536, # MI300A / MI300X: 64KB LDS per CU
+ "gfx942": 65536, # MI300A / MI300X: 64KB LDS per CU
# CDNA 4 (MI350 Series) - 160KB LDS per CU (key upgrade for CDNA4)
"gfx950": 163840, # MI350 Series: 160KB LDS per CU
"gfx1201": 65536, # RDNA4: 64KB LDS per WGP
+ # GFX1250 (MI450 Series) - 320KB LDS (WGP$ unified, 5 × 64KB segments)
+ "gfx1250": 327680, # MI450: 320KB configurable as LDS
}
def check_smem_capacity(allocated_bytes: int, arch: str = None):
diff --git a/tests/kernels/test_wmma_gemm_gfx1250.py b/tests/kernels/test_wmma_gemm_gfx1250.py
index 5eb7d13b..0f1e351c 100644
--- a/tests/kernels/test_wmma_gemm_gfx1250.py
+++ b/tests/kernels/test_wmma_gemm_gfx1250.py
@@ -47,7 +48,8 @@ @pytest.mark.parametrize("num_buffers", [2, 3])
def test_wmma_gemm_tdm(in_dtype, M, N, K, tile_m, tile_n, tile_k,
num_buffers,
- m_warp=2, n_warp=4, l2_prefetch_distance=2):
+ m_warp=2, n_warp=4, l2_prefetch_distance=2,
+ cluster_m=1, cluster_n=1):
arch = str(get_rocm_arch(timeout_s=300))
if arch != "gfx1250":
pytest.skip(f"WMMA requires gfx1250, got {arch}")
@@ -64,15 +65,34 @@ def test_wmma_gemm_tdm(in_dtype, M, N, K, tile_m, tile_n, tile_k,
if total_lds > 327680:
pytest.skip(f"LDS budget exceeded: {total_lds} > 327680")
- print(f"Running WMMA GEMM TDM: M={M}, N={N}, K={K}, "
- f"dtype={in_dtype}, bufs={num_buffers}")
-
torch_dtype = torch.float16 if in_dtype == "fp16" else torch.bfloat16
device = torch.device("cuda")
torch.manual_seed(0)
mpad = (M + tile_m - 1) // tile_m * tile_m
npad = (N + tile_n - 1) // tile_n * tile_n
+ wg_m = mpad // tile_m
+ wg_n = npad // tile_n
+
+ if cluster_m < 1 or cluster_n < 1:
+ pytest.skip(f"Invalid cluster dims: ({cluster_m}, {cluster_n}), both must be >= 1")
+ if cluster_m > 1 or cluster_n > 1:
+ if wg_m < cluster_m or wg_n < cluster_n:
+ pytest.skip(
+ "Cluster dims exceed launch grid: "
+ f"wg_grid=({wg_m},{wg_n}), cluster=({cluster_m},{cluster_n})"
+ )
+ if (wg_m % cluster_m) != 0 or (wg_n % cluster_n) != 0:
+ pytest.skip(
+ "WG grid must be divisible by cluster dims: "
+ f"wg_grid=({wg_m},{wg_n}), cluster=({cluster_m},{cluster_n})"
+ )
+
+ print(
+ f"Running WMMA GEMM TDM: 
M={M}, N={N}, K={K}, " + f"dtype={in_dtype}, bufs={num_buffers}, " + f"cluster=({cluster_m},{cluster_n})" + ) a = torch.randn((M, K), dtype=torch_dtype, device='cpu').cuda() b = torch.randn((K, N), dtype=torch_dtype, device='cpu').cuda() @@ -90,6 +110,8 @@ def test_wmma_gemm_tdm(in_dtype, M, N, K, tile_m, tile_n, tile_k, m_warp=m_warp, n_warp=n_warp, in_dtype=in_dtype, num_buffers=num_buffers, l2_prefetch_distance=l2_prefetch_distance, + cluster_m=cluster_m, + cluster_n=cluster_n, ) launch_fn( c_pad.contiguous().view(-1), @@ -121,6 +143,8 @@ def test_wmma_gemm_tdm(in_dtype, M, N, K, tile_m, tile_n, tile_k, parser.add_argument("--dtype", type=str, default="fp16", choices=["fp16", "bf16"]) parser.add_argument("--num-buffers", type=int, default=2, choices=[2, 3]) parser.add_argument("--l2-prefetch-distance", type=int, default=0) + parser.add_argument("--cluster-m", type=int, default=1) + parser.add_argument("--cluster-n", type=int, default=1) args = parser.parse_args() test_wmma_gemm_tdm( @@ -130,4 +154,6 @@ def test_wmma_gemm_tdm(in_dtype, M, N, K, tile_m, tile_n, tile_k, m_warp=args.m_warp, n_warp=args.n_warp, l2_prefetch_distance=args.l2_prefetch_distance, + cluster_m=args.cluster_m, + cluster_n=args.cluster_n, ) From 6d841f483167eaa97f1030df99078f3b017f0db3 Mon Sep 17 00:00:00 2001 From: aoli26 Date: Sat, 14 Mar 2026 14:32:53 +0000 Subject: [PATCH 06/11] add mcast unit tests --- kernels/wmma_gemm_gfx1250.py | 2 +- python/flydsl/expr/gpu.py | 12 +--- python/flydsl/expr/rocdl.py | 80 ++++++++++++++++++++++++- python/flydsl/expr/rocdl/__init__.py | 63 +++++++++++++++++++ tests/kernels/test_wmma_gemm_gfx1250.py | 39 +++++++++--- 5 files changed, 177 insertions(+), 19 deletions(-) diff --git a/kernels/wmma_gemm_gfx1250.py b/kernels/wmma_gemm_gfx1250.py index 469d0ee3..cd8d55ee 100644 --- a/kernels/wmma_gemm_gfx1250.py +++ b/kernels/wmma_gemm_gfx1250.py @@ -211,7 +211,7 @@ def kernel_wmma_gemm_tdm( # --- Cluster MCAST setup --- if use_cluster: - local_x, local_y = gpu.compute_cluster_position(bx, by, cluster_m, cluster_n) + local_x, local_y = gpu.compute_cluster_position() a_mcast_mask, b_mcast_mask = gpu.compute_mcast_masks( local_x, local_y, cluster_m, cluster_n) else: diff --git a/python/flydsl/expr/gpu.py b/python/flydsl/expr/gpu.py index 87985e3b..cac80ea2 100644 --- a/python/flydsl/expr/gpu.py +++ b/python/flydsl/expr/gpu.py @@ -101,20 +101,14 @@ def cluster_barrier(): cluster_wait() -def compute_cluster_position(bx, by, cluster_m: _int, cluster_n: _int): +def compute_cluster_position(): """Compute a workgroup's (row, col) position within its cluster. - Args: - bx: Block index X (M direction), MLIR index value. - by: Block index Y (N direction), MLIR index value. - cluster_m: Cluster dimension along M (number of WG rows per cluster). - cluster_n: Cluster dimension along N (number of WG cols per cluster). - Returns: (local_x, local_y) as MLIR index values — position within the cluster. 
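+
+    Example (hedged sketch; assumes the kernel was launched with a 2x2
+    cluster, mirroring the kernel_wmma_gemm_tdm usage above)::
+
+        local_x, local_y = gpu.compute_cluster_position()
+        a_mask, b_mask = gpu.compute_mcast_masks(local_x, local_y, 2, 2)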
""" - local_x = bx % _arith_ext.index(cluster_m) - local_y = by % _arith_ext.index(cluster_n) + local_x = _arith_ext.index_cast(T.index, _rocdl_ext.cluster_workgroup_id_x()) + local_y = _arith_ext.index_cast(T.index, _rocdl_ext.cluster_workgroup_id_y()) return local_x, local_y diff --git a/python/flydsl/expr/rocdl.py b/python/flydsl/expr/rocdl.py index c00340eb..debe89e0 100644 --- a/python/flydsl/expr/rocdl.py +++ b/python/flydsl/expr/rocdl.py @@ -109,6 +109,14 @@ def make_buffer_tensor(memref, alignment=4, loc=None, ip=None): # Keep references to ODS-generated builders so we can wrap them without losing access. _ods_wave_id = wave_id # ODS: wave_id(res, ...) -> i32 +_ods_cluster_workgroup_id_x = cluster_workgroup_id_x +_ods_cluster_workgroup_id_y = cluster_workgroup_id_y +_ods_cluster_workgroup_id_z = cluster_workgroup_id_z +_ods_cluster_load_async_to_lds_b8 = cluster_load_async_to_lds_b8 +_ods_cluster_load_async_to_lds_b32 = cluster_load_async_to_lds_b32 +_ods_cluster_load_async_to_lds_b64 = cluster_load_async_to_lds_b64 +_ods_cluster_load_async_to_lds_b128 = cluster_load_async_to_lds_b128 +_ods_s_wait_asynccnt = s_wait_asynccnt _ods_mfma_f32_16x16x16f16 = mfma_f32_16x16x16f16 _ods_mfma_f32_16x16x16bf16_1k = globals().get("mfma_f32_16x16x16bf16_1k", None) _ods_mfma_f32_16x16x32_fp8_fp8 = mfma_f32_16x16x32_fp8_fp8 @@ -332,6 +340,70 @@ def wave_id(): return _ods_wave_id(i32) +def cluster_workgroup_id_x(): + """Get workgroup position within cluster along X (SGPR, gfx1250). """ + from .._mlir import ir + i32 = ir.IntegerType.get_signless(32) + return _ods_cluster_workgroup_id_x(i32) + + +def cluster_workgroup_id_y(): + """Get workgroup position within cluster along Y (SGPR, gfx1250). """ + from .._mlir import ir + i32 = ir.IntegerType.get_signless(32) + return _ods_cluster_workgroup_id_y(i32) + + +def cluster_workgroup_id_z(): + """Get workgroup position within cluster along Z (SGPR, gfx1250). """ + from .._mlir import ir + i32 = ir.IntegerType.get_signless(32) + return _ods_cluster_workgroup_id_z(i32) + + +def cluster_load_async_to_lds(global_ptr, lds_ptr, size_bytes, offset=0, cpol=0, mask=None): + """Per-lane cluster broadcast load: Global -> LDS with MCAST (gfx1250). + + Args: + global_ptr: ``!llvm.ptr<1>`` — global address space pointer. + lds_ptr: ``!llvm.ptr<3>`` — LDS address space pointer. + size_bytes: Load width: 1, 4, 8, or 16 bytes (selects b8/b32/b64/b128). + offset: Byte offset (int, default 0). + cpol: Cache policy (int, default 0). + mask: i32 workgroup_mask for MCAST broadcast. None means no mask + (falls back to non-cluster global_load_async_to_lds). + + Raises: + ValueError: If ``size_bytes`` is not 1, 4, 8, or 16. + """ + _dispatch = { + 1: _ods_cluster_load_async_to_lds_b8, + 4: _ods_cluster_load_async_to_lds_b32, + 8: _ods_cluster_load_async_to_lds_b64, + 16: _ods_cluster_load_async_to_lds_b128, + } + fn = _dispatch.get(size_bytes) + if fn is None: + raise ValueError( + f"cluster_load_async_to_lds: size_bytes must be 1, 4, 8, or 16, " + f"got {size_bytes}") + if mask is None: + from .._mlir import ir + from . import arith as _arith + mask = _arith.unwrap(_arith.constant(0, type=ir.IntegerType.get_signless(32))) + fn(global_ptr, lds_ptr, offset, cpol, mask) + + +def s_wait_asynccnt(count=0): + """Wait for outstanding async load/store operations (ASYNCcnt counter). + + Args: + count: Maximum number of outstanding operations to allow. + 0 = wait for all. 
+ """ + _ods_s_wait_asynccnt(count) + + def lds_transpose_load(result_type, lds_memref, elem_offset, elem_bytes): """Transpose-load from LDS memref via ds_load_tr16_b128 (gfx1250). @@ -377,11 +449,11 @@ def lds_transpose_load(result_type, lds_memref, elem_offset, elem_bytes): 'grid_dim_x', 'grid_dim_y', 'grid_dim_z', 'wavefrontsize', 'wave_id', - + # Synchronization 'barrier', 's_barrier', 's_barrier_signal', 's_barrier_wait', 's_waitcnt', 's_wait_loadcnt', 's_wait_storecnt', - 's_wait_dscnt', 's_wait_expcnt', + 's_wait_dscnt', 's_wait_expcnt', 's_wait_asynccnt', # Matrix operations - MFMA (Matrix Fused Multiply-Add) 'mfma_f32_32x32x8f16', 'mfma_f32_16x16x16f16', @@ -459,6 +531,10 @@ def lds_transpose_load(result_type, lds_memref, elem_offset, elem_bytes): # gfx1250 L2 prefetch 'global_prefetch', # per-lane 1-byte prefetch hint + + # Cluster (gfx1250 workgroup clustering) + 'cluster_workgroup_id_x', 'cluster_workgroup_id_y', 'cluster_workgroup_id_z', + 'cluster_load_async_to_lds', # per-lane MCAST load (Global → LDS) ] diff --git a/python/flydsl/expr/rocdl/__init__.py b/python/flydsl/expr/rocdl/__init__.py index 14efa911..4eb2b4b2 100644 --- a/python/flydsl/expr/rocdl/__init__.py +++ b/python/flydsl/expr/rocdl/__init__.py @@ -15,6 +15,14 @@ # Keep references to ODS-generated builders so we can wrap them without losing access. _ods_wave_id = wave_id # ODS: wave_id(res, ...) -> i32 +_ods_cluster_workgroup_id_x = cluster_workgroup_id_x +_ods_cluster_workgroup_id_y = cluster_workgroup_id_y +_ods_cluster_workgroup_id_z = cluster_workgroup_id_z +_ods_cluster_load_async_to_lds_b8 = cluster_load_async_to_lds_b8 +_ods_cluster_load_async_to_lds_b32 = cluster_load_async_to_lds_b32 +_ods_cluster_load_async_to_lds_b64 = cluster_load_async_to_lds_b64 +_ods_cluster_load_async_to_lds_b128 = cluster_load_async_to_lds_b128 +_ods_s_wait_asynccnt = s_wait_asynccnt _ods_mfma_f32_16x16x16f16 = mfma_f32_16x16x16f16 _ods_mfma_f32_16x16x16bf16_1k = globals().get("mfma_f32_16x16x16bf16_1k", None) _ods_mfma_f32_16x16x32_fp8_fp8 = mfma_f32_16x16x32_fp8_fp8 @@ -113,6 +121,61 @@ def wave_id(): return _ods_wave_id(i32) +def cluster_workgroup_id_x(): + """Get workgroup position within cluster along X (SGPR, gfx1250). """ + from ..._mlir import ir + i32 = ir.IntegerType.get_signless(32) + return _ods_cluster_workgroup_id_x(i32) + + +def cluster_workgroup_id_y(): + """Get workgroup position within cluster along Y (SGPR, gfx1250). """ + from ..._mlir import ir + i32 = ir.IntegerType.get_signless(32) + return _ods_cluster_workgroup_id_y(i32) + + +def cluster_workgroup_id_z(): + """Get workgroup position within cluster along Z (SGPR, gfx1250). """ + from ..._mlir import ir + i32 = ir.IntegerType.get_signless(32) + return _ods_cluster_workgroup_id_z(i32) + + +def cluster_load_async_to_lds(global_ptr, lds_ptr, size_bytes, offset=0, cpol=0, mask=None): + """Per-lane cluster broadcast load: Global -> LDS with MCAST (gfx1250). + + Args: + global_ptr: ``!llvm.ptr<1>`` -- global address space pointer. + lds_ptr: ``!llvm.ptr<3>`` -- LDS address space pointer. + size_bytes: Load width: 1, 4, 8, or 16 bytes (selects b8/b32/b64/b128). + offset: Byte offset (int, default 0). + cpol: Cache policy (int, default 0). + mask: i32 workgroup_mask for MCAST broadcast. None means no mask. 
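+
+    Example (hedged sketch; pointer names are illustrative)::
+
+        # size_bytes=16 selects the b128 variant via the dispatch below
+        rocdl.cluster_load_async_to_lds(gptr, lptr, 16, mask=mcast_mask)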
+ """ + _dispatch = { + 1: _ods_cluster_load_async_to_lds_b8, + 4: _ods_cluster_load_async_to_lds_b32, + 8: _ods_cluster_load_async_to_lds_b64, + 16: _ods_cluster_load_async_to_lds_b128, + } + fn = _dispatch.get(size_bytes) + if fn is None: + raise ValueError( + f"cluster_load_async_to_lds: size_bytes must be 1, 4, 8, or 16, " + f"got {size_bytes}") + if mask is None: + from ..._mlir import ir + from .. import arith as _arith + mask = _arith.unwrap(_arith.constant(0, type=ir.IntegerType.get_signless(32))) + fn(global_ptr, lds_ptr, offset, cpol, mask) + + +def s_wait_asynccnt(count=0): + """Wait for outstanding async load/store operations (ASYNCcnt counter).""" + _ods_s_wait_asynccnt(count) + + def lds_transpose_load(result_type, lds_memref, elem_offset, elem_bytes): """Transpose-load from LDS memref via ds_load_tr16_b128 (gfx1250). diff --git a/tests/kernels/test_wmma_gemm_gfx1250.py b/tests/kernels/test_wmma_gemm_gfx1250.py index 0f1e351c..784f81d2 100644 --- a/tests/kernels/test_wmma_gemm_gfx1250.py +++ b/tests/kernels/test_wmma_gemm_gfx1250.py @@ -8,9 +8,6 @@ import os import sys -import pytest -import torch - _REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) _PYFLIR_SRC = os.path.join(_REPO_ROOT, "flydsl", "src") if _REPO_ROOT not in sys.path: @@ -18,6 +15,12 @@ if _PYFLIR_SRC not in sys.path: sys.path.insert(0, _PYFLIR_SRC) +# workaround for simulator +import flydsl # noqa: E402,F401 -- preload system comgr before torch/HIP loads LLVM + +import pytest +import torch + from flydsl.runtime.device import get_rocm_arch from kernels.wmma_gemm_gfx1250 import compile_wmma_gemm_tdm from tests.test_common import verify_output @@ -49,6 +52,7 @@ def test_wmma_gemm_tdm(in_dtype, M, N, K, tile_m, tile_n, tile_k, num_buffers, m_warp=2, n_warp=4, l2_prefetch_distance=2, cluster_m=1, cluster_n=1): + """Non-cluster GEMM correctness test.""" arch = str(get_rocm_arch(timeout_s=300)) if arch != "gfx1250": pytest.skip(f"WMMA requires gfx1250, got {arch}") @@ -66,7 +70,6 @@ def test_wmma_gemm_tdm(in_dtype, M, N, K, tile_m, tile_n, tile_k, pytest.skip(f"LDS budget exceeded: {total_lds} > 327680") torch_dtype = torch.float16 if in_dtype == "fp16" else torch.bfloat16 - device = torch.device("cuda") torch.manual_seed(0) mpad = (M + tile_m - 1) // tile_m * tile_m @@ -97,12 +100,12 @@ def test_wmma_gemm_tdm(in_dtype, M, N, K, tile_m, tile_n, tile_k, a = torch.randn((M, K), dtype=torch_dtype, device='cpu').cuda() b = torch.randn((K, N), dtype=torch_dtype, device='cpu').cuda() - a_pad = torch.zeros((mpad, K), dtype=torch_dtype, device=device) - b_pad = torch.zeros((K, npad), dtype=torch_dtype, device=device) + a_pad = torch.zeros((mpad, K), dtype=torch_dtype, device='cpu').cuda() + b_pad = torch.zeros((K, npad), dtype=torch_dtype, device='cpu').cuda() a_pad[:M, :] = a b_pad[:, :N] = b - c_pad = torch.zeros((mpad, npad), dtype=torch.float32, device=device) + c_pad = torch.zeros((mpad, npad), dtype=torch.float32, device='cpu').cuda() launch_fn = compile_wmma_gemm_tdm( M=mpad, N=npad, K=K, @@ -128,6 +131,28 @@ def test_wmma_gemm_tdm(in_dtype, M, N, K, tile_m, tile_n, tile_k, print("PASSED") +@pytest.mark.parametrize("in_dtype", ["fp16"]) +@pytest.mark.parametrize( + "M, N, K, tile_m, tile_n, tile_k", + [ + (1024, 1024, 1024, 128, 256, 128), + (2048, 2048, 1024, 128, 256, 128), + (2048, 2048, 2048, 128, 256, 128), + (4096, 4096, 1024, 128, 256, 128), + ], +) +@pytest.mark.parametrize("cluster_m, cluster_n", [(2, 2), (4, 4)]) +def test_wmma_gemm_tdm_mcast(in_dtype, M, N, K, tile_m, tile_n, 
tile_k, + cluster_m, cluster_n): + """Cluster multicast GEMM correctness test (large shapes only).""" + test_wmma_gemm_tdm( + in_dtype, M, N, K, tile_m, tile_n, tile_k, + num_buffers=2, m_warp=2, n_warp=4, + l2_prefetch_distance=2, + cluster_m=cluster_m, cluster_n=cluster_n, + ) + + if __name__ == "__main__": import argparse From 9fad5968ff847ee7d56a39de5eadb2cf60127136 Mon Sep 17 00:00:00 2001 From: aoli26 Date: Sun, 15 Mar 2026 15:01:10 +0000 Subject: [PATCH 07/11] add gemm mxfp4 support --- kernels/mxfp4_gemm_gfx1250.py | 715 +++++++++++++++++++++++ kernels/pipeline_utils.py | 43 ++ kernels/wmma_gemm_gfx1250.py | 39 +- python/flydsl/expr/rocdl.py | 77 +++ python/flydsl/expr/rocdl/__init__.py | 74 +++ tests/kernels/test_mxfp4_gemm_gfx1250.py | 178 ++++++ tests/test_common.py | 2 +- 7 files changed, 1090 insertions(+), 38 deletions(-) create mode 100644 kernels/mxfp4_gemm_gfx1250.py create mode 100644 kernels/pipeline_utils.py create mode 100644 tests/kernels/test_mxfp4_gemm_gfx1250.py diff --git a/kernels/mxfp4_gemm_gfx1250.py b/kernels/mxfp4_gemm_gfx1250.py new file mode 100644 index 00000000..ec961d6c --- /dev/null +++ b/kernels/mxfp4_gemm_gfx1250.py @@ -0,0 +1,715 @@ +"""MXFP4 GEMM kernel for gfx1250. + +Uses V_WMMA_SCALE_F32_16X16X128_F8F6F4 with FP4 (E2M1) data and E8M0 block scales. +Supports N-stage buffering (2/3/4), TDM async copy, cluster MCAST. +""" + +import flydsl.compiler as flyc +import flydsl.expr as fx + +from flydsl._mlir import ir +from flydsl.compiler.kernel_function import CompilationContext +from flydsl.expr import arith, buffer_ops, gpu, range_constexpr, rocdl, tdm_ops, vector +from flydsl._mlir.dialects import llvm as llvm_dialect, memref as memref_dialect +from flydsl.expr.arith import _to_raw as _raw +from flydsl.expr.typing import T +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr, get_op_result_or_value + +from kernels.layout_utils import idx2crd +from kernels.pipeline_utils import make_tail_plan + +# WMMA tile dimensions for MXFP4 +WMMA_M, WMMA_N, WMMA_K = 16, 16, 128 +WAVE_SIZE = 32 +PACK_FACTOR = 2 # 2 FP4 elements per byte +SCALE_BLOCK = 32 # 32 FP4 elements per E8M0 scale +SCALES_PER_WMMA = WMMA_K // SCALE_BLOCK # 4 + +# LDS padding in bytes (4 DWORDs = 16 bytes, matches SP3) +LDS_PAD_A_BYTES = 16 +LDS_PAD_B_BYTES = 16 + +_STAGE_NAMES = ("ping", "pong", "pang", "pung") + + +def compile_mxfp4_gemm( + *, + M: int = 0, + N: int = 0, + K: int, + tile_m: int = 128, + tile_n: int = 128, + tile_k: int = 128, + m_warp: int = 2, + n_warp: int = 2, + num_buffers: int = 2, + waves_per_eu: int = None, + l2_prefetch_distance: int = 2, + cluster_m: int = 1, + cluster_n: int = 1, +): + """Compile an MXFP4 GEMM kernel with TDM async copy and multi-stage buffering. 
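+
+    Example (hedged sketch; shapes are illustrative and must pass the
+    divisibility checks below)::
+
+        launch_fn = compile_mxfp4_gemm(K=1024, tile_m=128, tile_n=128, tile_k=128)
+        launch_fn(c, a, b, a_scale, b_scale, M, N, stream)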
+ + Returns a JitFunction: launch_fn(arg_c, arg_a, arg_b, arg_a_scale, arg_b_scale, M, N, stream) + """ + _ = (M, N) + if num_buffers not in (2, 3, 4): + raise ValueError(f"num_buffers must be 2, 3, or 4, got {num_buffers}") + + use_cluster = cluster_m > 1 or cluster_n > 1 + if use_cluster: + if cluster_m * cluster_n > 16: + raise ValueError( + f"cluster_m * cluster_n must be <= 16, got {cluster_m}*{cluster_n}") + effective_waves_per_eu = waves_per_eu + if use_cluster and effective_waves_per_eu is None: + effective_waves_per_eu = 1 + + num_warps = m_warp * n_warp + block_threads = num_warps * WAVE_SIZE + + packed_tile_k = tile_k // PACK_FACTOR # bytes along K in LDS per row + scale_k_per_tile = tile_k // SCALE_BLOCK + K_packed = K // PACK_FACTOR + K_scale = K // SCALE_BLOCK + + if K % tile_k != 0: + raise ValueError(f"K must be divisible by tile_k={tile_k}, got K={K}") + if tile_k % WMMA_K != 0: + raise ValueError(f"tile_k must be a multiple of {WMMA_K}, got {tile_k}") + if tile_m % WMMA_M != 0: + raise ValueError(f"tile_m must be a multiple of {WMMA_M}, got {tile_m}") + if tile_n % WMMA_N != 0: + raise ValueError(f"tile_n must be a multiple of {WMMA_N}, got {tile_n}") + if packed_tile_k % 4 != 0: + raise ValueError(f"packed_tile_k must be a multiple of 4, got {packed_tile_k}") + if scale_k_per_tile % 4 != 0: + raise ValueError( + f"scale_k_per_tile must be a multiple of 4 (tile_k >= 128), got {scale_k_per_tile}") + + warp_tile_m = tile_m // m_warp + warp_tile_n = tile_n // n_warp + if warp_tile_m % WMMA_M != 0: + raise ValueError(f"warp_tile_m={warp_tile_m} must be a multiple of {WMMA_M}") + if warp_tile_n % WMMA_N != 0: + raise ValueError(f"warp_tile_n={warp_tile_n} must be a multiple of {WMMA_N}") + + num_k_tiles = K // tile_k + if num_k_tiles < num_buffers: + raise ValueError( + f"{num_buffers}-stage buffering requires num_k_tiles >= {num_buffers}, " + f"got {num_k_tiles}") + + gpu_arch = str(get_hip_arch(timeout_s=300)) + assert gpu_arch.startswith("gfx1250"), f"Expected gfx1250, got {gpu_arch}" + + k_wmma_steps = tile_k // WMMA_K + wmma_m_rep = warp_tile_m // WMMA_M + wmma_n_rep = warp_tile_n // WMMA_N + n_accs = wmma_m_rep * wmma_n_rep + + lds_a_stride_bytes = packed_tile_k + LDS_PAD_A_BYTES + lds_b_stride_bytes = packed_tile_k + LDS_PAD_B_BYTES + + lds_a_data_bytes = tile_m * lds_a_stride_bytes + lds_b_data_bytes = tile_n * lds_b_stride_bytes + lds_a_scale_bytes = tile_m * scale_k_per_tile + lds_b_scale_bytes = tile_n * scale_k_per_tile + + stage_allocators = [] + stage_a_data_off = [] + stage_b_data_off = [] + stage_a_scale_off = [] + stage_b_scale_off = [] + + for i in range(num_buffers): + name = _STAGE_NAMES[i] + alloc = SmemAllocator(None, arch=gpu_arch, global_sym_name=f"mxfp4_{name}") + + off = alloc._align(alloc.ptr, 16) + stage_a_data_off.append(off) + alloc.ptr = off + lds_a_data_bytes + + off = alloc._align(alloc.ptr, 16) + stage_b_data_off.append(off) + alloc.ptr = off + lds_b_data_bytes + + off = alloc._align(alloc.ptr, 16) + stage_a_scale_off.append(off) + alloc.ptr = off + lds_a_scale_bytes + + off = alloc._align(alloc.ptr, 16) + stage_b_scale_off.append(off) + alloc.ptr = off + lds_b_scale_bytes + + stage_allocators.append(alloc) + + pre_loaded = num_buffers - 1 + loop_iters = (num_k_tiles - pre_loaded) // num_buffers + _tail_start = loop_iters * num_buffers + extra = num_k_tiles - _tail_start - pre_loaded + _raw_tail_plan = make_tail_plan(num_buffers, pre_loaded, extra) + + # Number of TDM loads per step: A_data + B_data + A_scale + B_scale = 4 + TDM_LOADS_PER_STEP 
= 4 + + # Scale tail plan outstanding values: make_tail_plan uses 2 (for fp16's A+B), + # but MXFP4 has 4 loads per step (A_data + B_data + A_scale + B_scale). + tail_plan = [ + (ls, cs, o * TDM_LOADS_PER_STEP // 2 if o > 0 else o) + for ls, cs, o in _raw_tail_plan + ] + + # Number of LDS loads per K-subtile (for s_wait_dscnt): + # A frag: wmma_m_rep * 2 ds_load_b128 + # B frag: wmma_n_rep * 2 ds_load_b128 + # A scale: wmma_m_rep ds_load_b32 + # B scale: wmma_n_rep ds_load_b32 + LOADS_PER_SUBTILE = wmma_m_rep * 2 + wmma_n_rep * 2 + wmma_m_rep + wmma_n_rep + + @flyc.kernel + def kernel_mxfp4_gemm( + arg_c: fx.Tensor, + arg_a: fx.Tensor, + arg_b: fx.Tensor, + arg_a_scale: fx.Tensor, + arg_b_scale: fx.Tensor, + i32_m: fx.Int32, + i32_n: fx.Int32, + ): + # Disable VALU stall for back-to-back WMMA + llvm_dialect.inline_asm( + None, [], + "s_setreg_imm32_b32 hwreg(26, 4, 1), 1", + "", has_side_effects=True, + ) + + tx = gpu.thread_id("x") + bx = gpu.block_id("x") + by = gpu.block_id("y") + + blk_m = bx * arith.index(tile_m) + blk_n = by * arith.index(tile_n) + + if use_cluster: + local_x, local_y = gpu.compute_cluster_position() + a_mcast_mask, b_mcast_mask = gpu.compute_mcast_masks( + local_x, local_y, cluster_m, cluster_n) + else: + a_mcast_mask = 0 + b_mcast_mask = 0 + + layout_thr = fx.make_layout( + (m_warp, n_warp, 2, 16), + (n_warp * WAVE_SIZE, WAVE_SIZE, 16, 1)) + thr_coord = idx2crd(tx, layout_thr) + wave_m_idx, wave_n_idx, lane_kgrp, lane16 = ( + thr_coord[0], thr_coord[1], thr_coord[2], thr_coord[3]) + + warp_m_base = wave_m_idx * arith.index(warp_tile_m) + warp_n_base = wave_n_idx * arith.index(warp_tile_n) + + m_idx = arith.index_cast(T.index, i32_m.ir_value()) + n_idx = arith.index_cast(T.index, i32_n.ir_value()) + n_stride = arith.index(N) + c_nrec = m_idx * n_stride * arith.index(4) + c_rsrc = buffer_ops.create_buffer_resource(arg_c, num_records_bytes=c_nrec) + + def _get_lds_memref(lds_ptr): + if isinstance(lds_ptr, SmemPtr): + return get_op_result_or_value(lds_ptr.get()) + return get_op_result_or_value(lds_ptr) + + def copy_a_data_to_lds(k_base, lds_mem_ref): + k_packed_off = k_base / arith.index(PACK_FACTOR) + desc = tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_a, lds_memref=lds_mem_ref, + global_offset=(blk_m, k_packed_off), + tensor_shape=(tile_m, packed_tile_k), + strides=(K_packed, 1), + tile_shape=(tile_m, packed_tile_k), + elem_bytes=1, + pad_interval=packed_tile_k, pad_amount=LDS_PAD_A_BYTES, + num_warps=num_warps, + workgroup_mask=a_mcast_mask) + tdm_ops.tensor_load_2d(desc) + + def copy_b_data_to_lds(k_base, lds_mem_ref): + k_packed_off = k_base / arith.index(PACK_FACTOR) + desc = tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_b, lds_memref=lds_mem_ref, + global_offset=(blk_n, k_packed_off), + tensor_shape=(tile_n, packed_tile_k), + strides=(K_packed, 1), + tile_shape=(tile_n, packed_tile_k), + elem_bytes=1, + pad_interval=packed_tile_k, pad_amount=LDS_PAD_B_BYTES, + num_warps=num_warps, + workgroup_mask=b_mcast_mask) + tdm_ops.tensor_load_2d(desc) + + def copy_a_scale_to_lds(k_base, lds_mem_ref): + k_scale_off = k_base / arith.index(SCALE_BLOCK) + desc = tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_a_scale, lds_memref=lds_mem_ref, + global_offset=(blk_m, k_scale_off), + tensor_shape=(tile_m, scale_k_per_tile), + strides=(K_scale, 1), + tile_shape=(tile_m, scale_k_per_tile), + elem_bytes=1, + pad_interval=0, pad_amount=0, + num_warps=num_warps, + workgroup_mask=a_mcast_mask) + tdm_ops.tensor_load_2d(desc) + + def copy_b_scale_to_lds(k_base, 
lds_mem_ref): + k_scale_off = k_base / arith.index(SCALE_BLOCK) + desc = tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_b_scale, lds_memref=lds_mem_ref, + global_offset=(blk_n, k_scale_off), + tensor_shape=(tile_n, scale_k_per_tile), + strides=(K_scale, 1), + tile_shape=(tile_n, scale_k_per_tile), + elem_bytes=1, + pad_interval=0, pad_amount=0, + num_warps=num_warps, + workgroup_mask=b_mcast_mask) + tdm_ops.tensor_load_2d(desc) + + def issue_all_tdm_loads(k_base, a_mem, b_mem, as_mem, bs_mem): + copy_a_data_to_lds(k_base, a_mem) + copy_b_data_to_lds(k_base, b_mem) + copy_a_scale_to_lds(k_base, as_mem) + copy_b_scale_to_lds(k_base, bs_mem) + + elem_ty_lds = T.f16 + + def _precompute_a_lane_bases(lds_ptr): + """Precompute per-wm A fragment lane base addresses (in BYTES). + + Each lane loads 32 bytes = 64 FP4 (one K-half). + lane16 → M-row, lane_kgrp → K-half (0 or 1). + """ + lds_buffer = _get_lds_memref(lds_ptr) + row_base = (warp_m_base + lane16) * arith.index(lds_a_stride_bytes) + k_half_off = lane_kgrp * arith.index(32) # 32 bytes = 64 FP4 + bases = [] + for wm in range_constexpr(wmma_m_rep): + base = row_base + arith.index(wm * WMMA_M * lds_a_stride_bytes) + k_half_off + bases.append(base) + return lds_buffer, bases + + def _lds_load_b128(lds_buffer, byte_offset): + """Load 16 bytes from LDS at given byte offset via ds_load_b128.""" + from flydsl._mlir.dialects import llvm as _llvm, memref as _memref + lds_ptr_ty = ir.Type.parse("!llvm.ptr<3>") + raw_memref = arith.unwrap(lds_buffer) + lds_base = _memref.extract_aligned_pointer_as_index(raw_memref) + from flydsl.expr.arith import ArithValue as _AV + total_byte = _AV(lds_base) + byte_offset + addr_i32 = _raw(arith.index_cast(T.i32, total_byte)) + ptr_val = _llvm.inttoptr(lds_ptr_ty, addr_i32) + vec4_i32_ty = ir.VectorType.get([4], ir.IntegerType.get_signless(32)) + return llvm_dialect.load(vec4_i32_ty, ptr_val) + + def load_a_frag(lds_buffer, a_lane_base, ks): + """Load one 16x128 FP4 A-fragment from LDS. + + Returns vector<8xi32> (8 VGPRs, 64 FP4 per lane). + 2 x ds_load_b128 via direct LDS pointer load. + """ + k_byte_off = arith.index(ks * WMMA_K // PACK_FACTOR) # bytes per K-subtile + byte_off = a_lane_base + k_byte_off + v0 = _lds_load_b128(lds_buffer, byte_off) + v1 = _lds_load_b128(lds_buffer, byte_off + arith.index(16)) + return vector.shuffle(v0, v1, list(range(8))) + + def _precompute_b_lane_bases(lds_ptr): + """Precompute per-wn B fragment lane base addresses (in BYTES). + + B stored as [tile_n, packed_tile_k + pad] in LDS. + lane16 → N-row, lane_kgrp → K-half. + """ + lds_buffer = _get_lds_memref(lds_ptr) + row_base = (warp_n_base + lane16) * arith.index(lds_b_stride_bytes) + k_half_off = lane_kgrp * arith.index(32) + bases = [] + for wn in range_constexpr(wmma_n_rep): + base = row_base + arith.index(wn * WMMA_N * lds_b_stride_bytes) + k_half_off + bases.append(base) + return lds_buffer, bases + + def load_b_frag(lds_buffer, b_lane_base, ks): + """Load one 128x16 FP4 B-fragment from LDS. Same pattern as A.""" + k_byte_off = arith.index(ks * WMMA_K // PACK_FACTOR) + byte_off = b_lane_base + k_byte_off + v0 = _lds_load_b128(lds_buffer, byte_off) + v1 = _lds_load_b128(lds_buffer, byte_off + arith.index(16)) + return vector.shuffle(v0, v1, list(range(8))) + + def _precompute_scale_lane_bases(lds_ptr, warp_base, reps): + """Precompute scale lane bases (in BYTES). + + Scale LDS layout: [tile_m_or_n, scale_k_per_tile] bytes. + lane16 → M/N row. 
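+
+        Worked example (hedged): with scale_k_per_tile=4 (tile_k=128),
+        warp_base=0, lane16=3 and reps=2, the bases are
+        [3*4, 3*4 + 16*4] = [12, 76] bytes.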
+ """ + lds_buffer = _get_lds_memref(lds_ptr) + row_base = (warp_base + lane16) * arith.index(scale_k_per_tile) + bases = [] + for w in range_constexpr(reps): + base = row_base + arith.index(w * WMMA_M * scale_k_per_tile) + bases.append(base) + return lds_buffer, bases + + def _shuffle_scale_i32(val): + """Swap bytes 1 and 2 of an i32 scale value via v_perm_b32. + + FP4 data VGPR layout splits K=128 as: + V0-V3 lanes 0-15: K=0..31, V4-V7 lanes 0-15: K=64..95 + V0-V3 lanes 16-31: K=32..63, V4-V7 lanes 16-31: K=96..127 + The WMMA_SCALE hardware processes data in VGPR-group-first order, + so the scale i32 byte-to-K-block mapping is [0, 2, 1, 3]: + byte0 → K=0..31, byte1 → K=64..95, byte2 → K=32..63, byte3 → K=96..127 + Memory stores scales sequentially [K0,K1,K2,K3], so we swap bytes 1↔2 + to produce [K0,K2,K1,K3] using a single v_perm_b32 instruction. + """ + i32_ty = ir.IntegerType.get_signless(32) + return llvm_dialect.inline_asm( + i32_ty, [_raw(val) if not isinstance(val, ir.Value) else val], + "v_perm_b32 $0, $1, $1, 0x03010200", + "=v,v", has_side_effects=False, + ) + + def load_scale(lds_buffer, scale_base, ks): + """Load scale for one 16x128 WMMA from LDS. + + Returns i32 (1 VGPR) containing 4 packed E8M0 scale values, + shuffled to match the WMMA_SCALE instruction's byte-to-K-block + mapping: [K0, K2, K1, K3]. + ds_load_b32 via direct LDS pointer load. + """ + from flydsl._mlir.dialects import llvm as _llvm, memref as _memref + lds_ptr_ty = ir.Type.parse("!llvm.ptr<3>") + raw_memref = arith.unwrap(lds_buffer) + lds_base = _memref.extract_aligned_pointer_as_index(raw_memref) + # scale_k_per_tile bytes per row, ks-th group = ks * SCALES_PER_WMMA bytes + byte_off = scale_base + arith.index(ks * SCALES_PER_WMMA) + from flydsl.expr.arith import ArithValue as _AV + total_byte = _AV(lds_base) + byte_off + addr_i32 = _raw(arith.index_cast(T.i32, total_byte)) + ptr_val = _llvm.inttoptr(lds_ptr_ty, addr_i32) + i32_ty = ir.IntegerType.get_signless(32) + raw_scale = llvm_dialect.load(i32_ty, ptr_val) + return _shuffle_scale_i32(raw_scale) + + def load_k_subtile_frags(a_buf, a_bases, b_buf, b_bases, + as_buf, as_bases, bs_buf, bs_bases, ks): + """Batch-load all A/B fragments and scales for one K-subtile.""" + # Load B frags first (gives more time for A frags to arrive) + b_frags = [load_b_frag(b_buf, b_bases[wn], ks) + for wn in range_constexpr(wmma_n_rep)] + a_frags = [load_a_frag(a_buf, a_bases[wm], ks) + for wm in range_constexpr(wmma_m_rep)] + # Load scales + b_scales = [load_scale(bs_buf, bs_bases[wn], ks) + for wn in range_constexpr(wmma_n_rep)] + a_scales = [load_scale(as_buf, as_bases[wm], ks) + for wm in range_constexpr(wmma_m_rep)] + return a_frags, b_frags, a_scales, b_scales + + def do_k_subtile_wmma(a_frags, b_frags, a_scales, b_scales, accs): + """Execute all WMMAs for one K-subtile with scales. + + Uses wmma_scale_f32_16x16x128_f8f6f4 (gfx1250 wave32) with: + fmtA=4 (FP4/E2M1), fmtB=4 (FP4/E2M1), + scaleAType=0 (opsel lo), scaleBType=0 (opsel lo). + fmtScaleA/B defaults to 0 (E8M0). + + Operands are passed as (B, A) instead of (A, B) to compensate + for the WMMA output VGPR layout where lane16→col and + lane_kgrp→row_group. Swapping computes C^T, making the output + match the epilogue's row-major store pattern. 
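+
+        Hedged note: this is the usual (A @ B).T == B.T @ A.T identity;
+        passing (b_frag, a_frag) makes the hardware produce the transposed
+        16x16 tile directly in the register layout the epilogue expects.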
+ """ + for wm in range_constexpr(wmma_m_rep): + for wn in range_constexpr(wmma_n_rep): + idx = wm * wmma_n_rep + wn + accs[idx] = rocdl.wmma_scale_f32_16x16x128_f8f6f4( + T.vec(8, T.f32), + b_frags[wn], a_frags[wm], accs[idx], + b_scales[wn], a_scales[wm], + fmtA=4, fmtB=4, + scaleAType=0, scaleBType=0, + ) + return accs + + def compute_tile(accs_in, lds_a, lds_b, lds_as, lds_bs, emit_filler=None): + current_accs = list(accs_in) + + a_buf, a_bases = _precompute_a_lane_bases(lds_a) + b_buf, b_bases = _precompute_b_lane_bases(lds_b) + as_buf, as_bases = _precompute_scale_lane_bases(lds_as, warp_m_base, wmma_m_rep) + bs_buf, bs_bases = _precompute_scale_lane_bases(lds_bs, warp_n_base, wmma_n_rep) + + if k_wmma_steps == 1: + frags = load_k_subtile_frags( + a_buf, a_bases, b_buf, b_bases, + as_buf, as_bases, bs_buf, bs_bases, 0) + rocdl.s_wait_dscnt(0) + if emit_filler is not None: + emit_filler() + current_accs = do_k_subtile_wmma(*frags, current_accs) + else: + prev = load_k_subtile_frags( + a_buf, a_bases, b_buf, b_bases, + as_buf, as_bases, bs_buf, bs_bases, 0) + + # Main K-loop: overlap load[ks+1] with compute[ks] + for ks in range_constexpr(k_wmma_steps - 1): + next_frags = load_k_subtile_frags( + a_buf, a_bases, b_buf, b_bases, + as_buf, as_bases, bs_buf, bs_bases, ks + 1) + rocdl.s_wait_dscnt(LOADS_PER_SUBTILE) + current_accs = do_k_subtile_wmma(*prev, current_accs) + prev = next_frags + + # Epilogue + rocdl.s_wait_dscnt(0) + if emit_filler is not None: + rocdl.sched_barrier(0) + emit_filler() + current_accs = do_k_subtile_wmma(*prev, current_accs) + + return current_accs + + def hot_loop_scheduler(): + rocdl.sched_barrier(0) + + # --- Epilogue: vectorized buffer_store_b128 --- + # WMMA output VGPR layout (wave32, 16x16 tile): + # lane16 (lane_id % 16) → N column + # lane_kgrp (lane_id / 16) → M row group (0=rows 0-7, 1=rows 8-15) + # element[i] → M row offset within group + # We compensate by swapping A/B operands in the WMMA call (see + # do_k_subtile_wmma) so the WMMA effectively computes C^T, making + # the output VGPR layout match this epilogue's store pattern: + # lane16 → M row, lane_kgrp*8 + ele → N column group. 
+ def epilogue_prepare_addrs(): + addrs = [] + for wm in range_constexpr(wmma_m_rep): + for wn in range_constexpr(wmma_n_rep): + row = blk_m + warp_m_base + arith.index(wm * WMMA_M) + lane16 + col_base = (blk_n + warp_n_base + arith.index(wn * WMMA_N) + + lane_kgrp * arith.index(8)) + for half in range_constexpr(2): + col = col_base + arith.index(half * 4) + c_off = row * n_stride + col + addrs.append(c_off) + return addrs + + def epilogue_stores(final_accs, addrs): + addr_idx = 0 + for wm in range_constexpr(wmma_m_rep): + for wn in range_constexpr(wmma_n_rep): + idx = wm * wmma_n_rep + wn + for half in range_constexpr(2): + vals = [vector.extract( + final_accs[idx], + static_position=[half * 4 + vi], + dynamic_position=[]) + for vi in range_constexpr(4)] + vec4 = vector.from_elements(T.vec(4, T.f32), vals) + buffer_ops.buffer_store(vec4, c_rsrc, addrs[addr_idx]) + addr_idx += 1 + + def wait_and_barrier(outstanding=0): + tdm_ops.tensor_wait(outstanding) + gpu.barrier() + + def wait_and_cluster_barrier(outstanding=0): + tdm_ops.tensor_wait(outstanding) + if use_cluster: + gpu.cluster_barrier() + else: + gpu.barrier() + + def _compute_and_schedule(accs_in, a, b, a_s, b_s): + rocdl.sched_barrier(0) + accs_out = compute_tile(accs_in, a, b, a_s, b_s) + hot_loop_scheduler() + return accs_out + + _effective_l2_pf = l2_prefetch_distance + if use_cluster and l2_prefetch_distance > 0: + _effective_l2_pf = max(1, l2_prefetch_distance - 1) + + def _l2_prefetch(k_base): + if _effective_l2_pf <= 0: + return + pf_k = k_base + arith.index(_effective_l2_pf * tile_k) + pf_k_packed = pf_k / arith.index(PACK_FACTOR) + tdm_ops.l2_prefetch_tile( + arg_a, (blk_m, pf_k_packed), (tile_m, packed_tile_k), (K_packed, 1), + elem_bytes=1, thread_id=tx, block_threads=block_threads) + tdm_ops.l2_prefetch_tile( + arg_b, (blk_n, pf_k_packed), (tile_n, packed_tile_k), (K_packed, 1), + elem_bytes=1, thread_id=tx, block_threads=block_threads) + + acc_zero = arith.constant_vector(0.0, T.vec(8, T.f32)) + accs = [acc_zero] * n_accs + + # Build per-stage SmemPtrs using f16 element type for addressing. + # FP4 packed data (1 byte = 2 FP4) + scale (1 byte E8M0) both + # addressed in f16 units (2 bytes). This matches the fp16 kernel's + # proven vector.load_op pattern. 
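+    # Hedged worked example: tile_m=128, tile_k=128 gives packed_tile_k=64,
+    # lds_a_stride_bytes=64+16=80, lds_a_data_bytes=128*80=10240 bytes,
+    # i.e. lds_a_data_f16=5120 f16 addressing units below.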
+ lds_a_data_f16 = lds_a_data_bytes // 2 + lds_b_data_f16 = lds_b_data_bytes // 2 + lds_a_scale_f16 = lds_a_scale_bytes // 2 + lds_b_scale_f16 = lds_b_scale_bytes // 2 + + base_ptrs = [sa.get_base() for sa in stage_allocators] + + stages_a = [ + SmemPtr(base_ptrs[i], stage_a_data_off[i], elem_ty_lds, shape=(lds_a_data_f16,)) + for i in range_constexpr(num_buffers) + ] + stages_b = [ + SmemPtr(base_ptrs[i], stage_b_data_off[i], elem_ty_lds, shape=(lds_b_data_f16,)) + for i in range_constexpr(num_buffers) + ] + stages_as = [ + SmemPtr(base_ptrs[i], stage_a_scale_off[i], elem_ty_lds, shape=(lds_a_scale_f16,)) + for i in range_constexpr(num_buffers) + ] + stages_bs = [ + SmemPtr(base_ptrs[i], stage_b_scale_off[i], elem_ty_lds, shape=(lds_b_scale_f16,)) + for i in range_constexpr(num_buffers) + ] + + # Get memrefs for TDM (raw memref values) + stages_a_mem = [stages_a[i].get() for i in range_constexpr(num_buffers)] + stages_b_mem = [stages_b[i].get() for i in range_constexpr(num_buffers)] + stages_as_mem = [stages_as[i].get() for i in range_constexpr(num_buffers)] + stages_bs_mem = [stages_bs[i].get() for i in range_constexpr(num_buffers)] + + # Prologue: load first (num_buffers - 1) tiles + for i in range_constexpr(pre_loaded): + issue_all_tdm_loads( + arith.index(i * tile_k), + stages_a_mem[i], stages_b_mem[i], + stages_as_mem[i], stages_bs_mem[i]) + # Wait for all but the last batch of TDM loads + wait_and_barrier(outstanding=TDM_LOADS_PER_STEP * (num_buffers - 2)) + + # Main loop + main_end = loop_iters * num_buffers * tile_k + + if loop_iters > 0: + for iv, state in range(0, main_end, num_buffers * tile_k, init=list(accs)): + accs_in = list(state) + for s in range_constexpr(num_buffers): + _load_stage = (s + num_buffers - 1) % num_buffers + _load_k_off = (s + num_buffers - 1) * tile_k + issue_all_tdm_loads( + iv + arith.index(_load_k_off), + stages_a_mem[_load_stage], stages_b_mem[_load_stage], + stages_as_mem[_load_stage], stages_bs_mem[_load_stage]) + _l2_prefetch(iv + arith.index(s * tile_k)) + accs_in = _compute_and_schedule( + accs_in, + stages_a[s], stages_b[s], + stages_as[s], stages_bs[s]) + if s == num_buffers - 1: + wait_and_cluster_barrier(outstanding=TDM_LOADS_PER_STEP) + else: + wait_and_barrier(outstanding=TDM_LOADS_PER_STEP) + results = yield list(accs_in) + accs = list(results) + + # Tail + if loop_iters == 0 and use_cluster: + gpu.cluster_barrier() + _extra_j = 0 + for _load_stage, _compute_stage, _outstanding in tail_plan: + if _load_stage is not None: + _k_off = (_tail_start + pre_loaded + _extra_j) * tile_k + issue_all_tdm_loads( + arith.index(_k_off), + stages_a_mem[_load_stage], stages_b_mem[_load_stage], + stages_as_mem[_load_stage], stages_bs_mem[_load_stage]) + _extra_j += 1 + if _outstanding == -1: + epi_addrs_box = [None] + + def _emit_epi_addrs(): + epi_addrs_box[0] = epilogue_prepare_addrs() + + accs = compute_tile( + accs, + stages_a[_compute_stage], stages_b[_compute_stage], + stages_as[_compute_stage], stages_bs[_compute_stage], + emit_filler=_emit_epi_addrs) + else: + accs = _compute_and_schedule( + accs, + stages_a[_compute_stage], stages_b[_compute_stage], + stages_as[_compute_stage], stages_bs[_compute_stage]) + if use_cluster and _load_stage is not None: + wait_and_cluster_barrier(outstanding=_outstanding) + else: + wait_and_barrier(outstanding=_outstanding) + + epilogue_stores(accs, epi_addrs_box[0]) + + cache_tag = (K, tile_m, tile_n, tile_k, m_warp, n_warp, + num_buffers, effective_waves_per_eu, l2_prefetch_distance, + cluster_m, cluster_n) + + 
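+    # Hedged note: the launcher below does `_ = cache_tag`, which appears to
+    # fold this compile-time config into the JIT cache key so that distinct
+    # configs compile to distinct cached kernels.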
@flyc.jit + def launch_mxfp4_gemm( + arg_c: fx.Tensor, + arg_a: fx.Tensor, + arg_b: fx.Tensor, + arg_a_scale: fx.Tensor, + arg_b_scale: fx.Tensor, + i32_m: fx.Int32, + i32_n: fx.Int32, + stream: fx.Stream, + ): + _ = cache_tag + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + for alloc in stage_allocators: + alloc.finalized = False + for alloc in stage_allocators: + alloc.finalize() + + idx_m = arith.index_cast(T.index, i32_m.ir_value()) + idx_n = arith.index_cast(T.index, i32_n.ir_value()) + gx = _raw((idx_m + arith.index(tile_m - 1)) / arith.index(tile_m)) + gy = _raw((idx_n + arith.index(tile_n - 1)) / arith.index(tile_n)) + + launcher = kernel_mxfp4_gemm( + arg_c, arg_a, arg_b, arg_a_scale, arg_b_scale, i32_m, i32_n) + for op in ctx.gpu_module_body.operations: + if hasattr(op, 'attributes') and op.OPERATION_NAME == "gpu.func": + if effective_waves_per_eu is not None: + _wpe = int(effective_waves_per_eu) + if _wpe >= 1: + op.attributes["rocdl.waves_per_eu"] = ir.IntegerAttr.get( + ir.IntegerType.get_signless(32), _wpe) + if use_cluster: + op.attributes["rocdl.cluster_dims"] = ir.StringAttr.get( + f"{cluster_m},{cluster_n},1") + cluster_arg = (cluster_m, cluster_n, 1) if use_cluster else None + launcher.launch( + grid=(gx, gy, 1), + block=(block_threads, 1, 1), + stream=stream, + cluster=cluster_arg, + ) + + return launch_mxfp4_gemm + + +__all__ = ["compile_mxfp4_gemm"] diff --git a/kernels/pipeline_utils.py b/kernels/pipeline_utils.py new file mode 100644 index 00000000..37fabcc8 --- /dev/null +++ b/kernels/pipeline_utils.py @@ -0,0 +1,43 @@ +"""Shared pipeline utilities for gfx1250 GEMM kernels. """ + + +def make_tail_plan(num_buffers, pre_loaded, extra): + """Compute a compile-time tail execution plan for the N-stage pipeline. + + Returns a list of (load_stage, compute_stage, outstanding) tuples, one per + tail step. outstanding=-1 means "last step, use compute_tile (no barrier)". + + Args: + num_buffers: total number of pipeline stages. + pre_loaded: stages already loaded and ready to compute (= num_buffers - 1). + extra: additional tiles that must be loaded in the tail. + """ + steps = pre_loaded + extra + plan = [] + for i in range(steps): + compute_stage = ( + i if i < pre_loaded + else (i - pre_loaded + num_buffers - 1) % num_buffers + ) + load_stage = ( + (i + num_buffers - 1) % num_buffers if i < extra + else None + ) + is_last = (i == steps - 1) + if is_last: + outstanding = -1 + else: + j = i + 1 + next_compute = ( + j if j < pre_loaded + else (j - pre_loaded + num_buffers - 1) % num_buffers + ) + outstanding = ( + 2 if (load_stage is not None and load_stage != next_compute) + else 0 + ) + plan.append((load_stage, compute_stage, outstanding)) + return plan + + +__all__ = ["make_tail_plan"] diff --git a/kernels/wmma_gemm_gfx1250.py b/kernels/wmma_gemm_gfx1250.py index cd8d55ee..730f94a3 100644 --- a/kernels/wmma_gemm_gfx1250.py +++ b/kernels/wmma_gemm_gfx1250.py @@ -17,6 +17,7 @@ from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr, get_op_result_or_value from kernels.layout_utils import idx2crd +from kernels.pipeline_utils import make_tail_plan WMMA_M, WMMA_N, WMMA_K = 16, 16, 32 WAVE_SIZE = 32 @@ -27,43 +28,7 @@ _STAGE_NAMES = ("ping", "pong", "pang") -def _make_tail_plan(num_buffers, pre_loaded, extra): - """Compute a compile-time tail execution plan for the N-stage pipeline. - - Returns a list of (load_stage, compute_stage, outstanding) tuples, one per - tail step. 
outstanding=-1 means "last step, use compute_tile (no barrier)". - - Args: - num_buffers: total number of pipeline stages. - pre_loaded: stages already loaded and ready to compute (= num_buffers - 1). - extra: additional tiles that must be loaded in the tail. - """ - steps = pre_loaded + extra - plan = [] - for i in range(steps): - compute_stage = ( - i if i < pre_loaded - else (i - pre_loaded + num_buffers - 1) % num_buffers - ) - load_stage = ( - (i + num_buffers - 1) % num_buffers if i < extra - else None - ) - is_last = (i == steps - 1) - if is_last: - outstanding = -1 - else: - j = i + 1 - next_compute = ( - j if j < pre_loaded - else (j - pre_loaded + num_buffers - 1) % num_buffers - ) - outstanding = ( - 2 if (load_stage is not None and load_stage != next_compute) - else 0 - ) - plan.append((load_stage, compute_stage, outstanding)) - return plan +_make_tail_plan = make_tail_plan def compile_wmma_gemm_tdm( diff --git a/python/flydsl/expr/rocdl.py b/python/flydsl/expr/rocdl.py index debe89e0..4896815b 100644 --- a/python/flydsl/expr/rocdl.py +++ b/python/flydsl/expr/rocdl.py @@ -108,6 +108,12 @@ def make_buffer_tensor(memref, alignment=4, loc=None, ip=None): return _prim.make_view(bd_ptr, layout, loc=loc, ip=ip) # Keep references to ODS-generated builders so we can wrap them without losing access. +_ods_wmma_scale_f32_16x16x128_f8f6f4 = ( + globals().get("wmma_scale_f32_16x16x128_f8f6f4", None) +) +_ods_wmma_scale_f32_32x16x128_f4 = ( + globals().get("wmma_scale_f32_32x16x128_f4", None) +) _ods_wave_id = wave_id # ODS: wave_id(res, ...) -> i32 _ods_cluster_workgroup_id_x = cluster_workgroup_id_x _ods_cluster_workgroup_id_y = cluster_workgroup_id_y @@ -325,6 +331,75 @@ def wmma_i32_16x16x32_iu4(result_type, operands, *, loc=None, ip=None): ops = [_unwrap_wmma_operand(v, loc=loc) for v in operands] return _ods_wmma_i32_16x16x32_iu4(result_type, ops, loc=loc, ip=ip).result + +# --- WMMA Scale variants (gfx1250 mxfp4) --- + +def wmma_scale_f32_16x16x128_f8f6f4(result_type, a, b, c, scaleA, scaleB, + *, fmtA=4, fmtB=4, modC=0, + scaleAType=0, fmtScaleA=0, + scaleBType=0, fmtScaleB=0, + reuseA=False, reuseB=False, + loc=None, ip=None): + """V_WMMA_SCALE_F32_16X16X128_F8F6F4 for gfx1250 (wave32). + + Operand types (wave32): + a: vector<8xi32> (16x128 FP4 data) + b: vector<8xi32> (128x16 FP4 data) + c: vector<8xf32> (16x16 FP32 accumulator) + scaleA: i32 (A scale VGPR) + scaleB: i32 (B scale VGPR) + + fmtA/fmtB: data type encoding (4=FP4/E2M1) + scaleAType/scaleBType: opsel – selects lo/hi 16-bit half of scale VGPR (0=lo, 1=hi) + fmtScaleA/fmtScaleB: scale format (0=E8M0, 1=E5M3, 2=E4M3) + """ + if _ods_wmma_scale_f32_16x16x128_f8f6f4 is None: + raise AttributeError("ROCDL op not found: wmma_scale_f32_16x16x128_f8f6f4") + a_v = _unwrap_mfma_operand(a, loc=loc) + b_v = _unwrap_mfma_operand(b, loc=loc) + c_v = _unwrap_mfma_operand(c, loc=loc) + sA = _unwrap_mfma_operand(scaleA, loc=loc) + sB = _unwrap_mfma_operand(scaleB, loc=loc) + return _ods_wmma_scale_f32_16x16x128_f8f6f4( + result_type, a_v, b_v, c_v, sA, sB, + fmtA=fmtA, fmtB=fmtB, modC=modC, + scaleAType=scaleAType, fmtScaleA=fmtScaleA, + scaleBType=scaleBType, fmtScaleB=fmtScaleB, + reuseA=reuseA, reuseB=reuseB, + loc=loc, ip=ip, + ).result + + +def wmma_scale_f32_32x16x128_f4(result_type, a, b, c, scaleA, scaleB, + *, modC=0, + scaleAType=0, fmtScaleA=0, + scaleBType=0, fmtScaleB=0, + reuseA=False, reuseB=False, + loc=None, ip=None): + """V_WMMA_SCALE_F32_32X16X128_F4 for gfx1250 (wave32). 
+ + Operand types (wave32): + a: vector<16xi32> (32x128 FP4 data) + b: vector<8xi32> (128x16 FP4 data) + c: vector<16xf32> (32x16 FP32 accumulator) + scaleA: i32 (A scale VGPR) + scaleB: i32 (B scale VGPR) + """ + if _ods_wmma_scale_f32_32x16x128_f4 is None: + raise AttributeError("ROCDL op not found: wmma_scale_f32_32x16x128_f4") + a_v = _unwrap_mfma_operand(a, loc=loc) + b_v = _unwrap_mfma_operand(b, loc=loc) + c_v = _unwrap_mfma_operand(c, loc=loc) + sA = _unwrap_mfma_operand(scaleA, loc=loc) + sB = _unwrap_mfma_operand(scaleB, loc=loc) + return _ods_wmma_scale_f32_32x16x128_f4( + result_type, a_v, b_v, c_v, sA, sB, + modC=modC, + scaleAType=scaleAType, fmtScaleA=fmtScaleA, + scaleBType=scaleBType, fmtScaleB=fmtScaleB, + reuseA=reuseA, reuseB=reuseB, + loc=loc, ip=ip, + ).result def wave_id(): """Get wave-id-in-workgroup as SGPR (via TTMP8[29:25]). @@ -475,6 +550,8 @@ def lds_transpose_load(result_type, lds_memref, elem_offset, elem_bytes): 'wmma_f32_16x16x16_fp8_fp8', 'wmma_f32_16x16x16_fp8_bf8', 'wmma_f32_16x16x16_bf8_fp8', 'wmma_f32_16x16x16_bf8_bf8', 'wmma_i32_16x16x32_iu4', + 'wmma_scale_f32_16x16x128_f8f6f4', # gfx1250 WMMA_SCALE 16x16x128 (FP4/FP6/FP8) + 'wmma_scale_f32_32x16x128_f4', # gfx1250 WMMA_SCALE 32x16x128 (FP4 only) # Matrix operations - SMFMAC (Sparse Matrix FMA) 'smfmac_f32_32x32x16_f16', 'smfmac_f32_32x32x16_bf16', diff --git a/python/flydsl/expr/rocdl/__init__.py b/python/flydsl/expr/rocdl/__init__.py index 4eb2b4b2..bcc101ba 100644 --- a/python/flydsl/expr/rocdl/__init__.py +++ b/python/flydsl/expr/rocdl/__init__.py @@ -14,6 +14,12 @@ from ..._mlir.dialects.rocdl import * # noqa: F401,F403 # Keep references to ODS-generated builders so we can wrap them without losing access. +_ods_wmma_scale_f32_16x16x128_f8f6f4 = ( + globals().get("wmma_scale_f32_16x16x128_f8f6f4", None) +) +_ods_wmma_scale_f32_32x16x128_f4 = ( + globals().get("wmma_scale_f32_32x16x128_f4", None) +) _ods_wave_id = wave_id # ODS: wave_id(res, ...) -> i32 _ods_cluster_workgroup_id_x = cluster_workgroup_id_x _ods_cluster_workgroup_id_y = cluster_workgroup_id_y @@ -110,6 +116,74 @@ def mfma_scale_f32_16x16x128_f8f6f4(result_type, operands, *, loc=None, ip=None) ).result +def wmma_scale_f32_16x16x128_f8f6f4(result_type, a, b, c, scaleA, scaleB, + *, fmtA=4, fmtB=4, modC=0, + scaleAType=0, fmtScaleA=0, + scaleBType=0, fmtScaleB=0, + reuseA=False, reuseB=False, + loc=None, ip=None): + """V_WMMA_SCALE_F32_16X16X128_F8F6F4 for gfx1250 (wave32). 
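+
+    Example (hedged; mirrors the mxfp4 kernel's call shape, with FP4 data
+    and E8M0 scales)::
+
+        acc = rocdl.wmma_scale_f32_16x16x128_f8f6f4(
+            T.vec(8, T.f32), a_frag, b_frag, acc, scale_a, scale_b,
+            fmtA=4, fmtB=4)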
+ + Operand types (wave32): + a: vector<8xi32> (16x128 FP4 data) + b: vector<8xi32> (128x16 FP4 data) + c: vector<8xf32> (16x16 FP32 accumulator) + scaleA: i32 (A scale VGPR) + scaleB: i32 (B scale VGPR) + + fmtA/fmtB: data type encoding (4=FP4/E2M1) + scaleAType/scaleBType: opsel – selects lo/hi 16-bit half of scale VGPR (0=lo, 1=hi) + fmtScaleA/fmtScaleB: scale format (0=E8M0, 1=E5M3, 2=E4M3) + """ + if _ods_wmma_scale_f32_16x16x128_f8f6f4 is None: + raise AttributeError("ROCDL op not found: wmma_scale_f32_16x16x128_f8f6f4") + a_v = _unwrap_mfma_operand(a, loc=loc) + b_v = _unwrap_mfma_operand(b, loc=loc) + c_v = _unwrap_mfma_operand(c, loc=loc) + sA = _unwrap_mfma_operand(scaleA, loc=loc) + sB = _unwrap_mfma_operand(scaleB, loc=loc) + return _ods_wmma_scale_f32_16x16x128_f8f6f4( + result_type, a_v, b_v, c_v, sA, sB, + fmtA=fmtA, fmtB=fmtB, modC=modC, + scaleAType=scaleAType, fmtScaleA=fmtScaleA, + scaleBType=scaleBType, fmtScaleB=fmtScaleB, + reuseA=reuseA, reuseB=reuseB, + loc=loc, ip=ip, + ).result + + +def wmma_scale_f32_32x16x128_f4(result_type, a, b, c, scaleA, scaleB, + *, modC=0, + scaleAType=0, fmtScaleA=0, + scaleBType=0, fmtScaleB=0, + reuseA=False, reuseB=False, + loc=None, ip=None): + """V_WMMA_SCALE_F32_32X16X128_F4 for gfx1250 (wave32). + + Operand types (wave32): + a: vector<16xi32> (32x128 FP4 data) + b: vector<8xi32> (128x16 FP4 data) + c: vector<16xf32> (32x16 FP32 accumulator) + scaleA: i32 (A scale VGPR) + scaleB: i32 (B scale VGPR) + """ + if _ods_wmma_scale_f32_32x16x128_f4 is None: + raise AttributeError("ROCDL op not found: wmma_scale_f32_32x16x128_f4") + a_v = _unwrap_mfma_operand(a, loc=loc) + b_v = _unwrap_mfma_operand(b, loc=loc) + c_v = _unwrap_mfma_operand(c, loc=loc) + sA = _unwrap_mfma_operand(scaleA, loc=loc) + sB = _unwrap_mfma_operand(scaleB, loc=loc) + return _ods_wmma_scale_f32_32x16x128_f4( + result_type, a_v, b_v, c_v, sA, sB, + modC=modC, + scaleAType=scaleAType, fmtScaleA=fmtScaleA, + scaleBType=scaleBType, fmtScaleB=fmtScaleB, + reuseA=reuseA, reuseB=reuseB, + loc=loc, ip=ip, + ).result + + def wave_id(): """Get wave-id-in-workgroup as SGPR (via TTMP8[29:25]). diff --git a/tests/kernels/test_mxfp4_gemm_gfx1250.py b/tests/kernels/test_mxfp4_gemm_gfx1250.py new file mode 100644 index 00000000..7fc8e541 --- /dev/null +++ b/tests/kernels/test_mxfp4_gemm_gfx1250.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python3 +"""MXFP4 GEMM correctness tests for gfx1250. + +Kernel implementation: kernels/mxfp4_gemm_gfx1250.py +""" + +import os +import sys + +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +_PYFLIR_SRC = os.path.join(_REPO_ROOT, "flydsl", "src") +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) +if _PYFLIR_SRC not in sys.path: + sys.path.insert(0, _PYFLIR_SRC) + +# workaround for simulator +import flydsl # noqa: E402,F401 -- preload system comgr before torch/HIP loads LLVM + +import pytest +import torch + +from flydsl.runtime.device import get_rocm_arch +from kernels.mxfp4_gemm_gfx1250 import compile_mxfp4_gemm +from tests.kernels.utils import fp4_utils + + +if not torch.cuda.is_available(): + pytest.skip("CUDA/ROCm not available. Skipping GPU tests.", allow_module_level=True) + + +SCALE_BLOCK = 32 + + +def random_mxfp4_packed(rows: int, cols: int, *, device="cpu") -> torch.Tensor: + """Generate random packed MXFP4 data [rows, cols//2] uint8. 
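+
+    Example (hedged)::
+
+        a_packed = random_mxfp4_packed(128, 256)  # uint8 tensor of shape [128, 128]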
""" + assert cols % 2 == 0 + unpacked = torch.randint(0, 16, (rows, cols), dtype=torch.uint8, device=device) + return fp4_utils.pack_uint4(unpacked) + + +def random_e8m0(rows: int, cols: int, *, low_exp=127, high_exp=132, + device="cpu") -> torch.Tensor: + """Generate random E8M0 scale bytes [rows, cols] uint8. """ + return torch.randint(low_exp, high_exp + 1, (rows, cols), + dtype=torch.uint8, device=device) + + +def reference_mxfp4_gemm(a_packed, b_packed, a_scale, b_scale, M, N, K): + """Reference MXFP4 GEMM: D = (A * A_scale) @ (B * B_scale)^T. + + Args: + a_packed: [M, K//2] uint8 packed FP4 + b_packed: [N, K//2] uint8 packed FP4 + a_scale: [M, K//SCALE_BLOCK] uint8 E8M0 + b_scale: [N, K//SCALE_BLOCK] uint8 E8M0 + + Returns: + [M, N] float32 result. + """ + a_f32 = fp4_utils.mxfp4_to_f32(a_packed.view(torch.uint8))[:M, :K] + b_f32 = fp4_utils.mxfp4_to_f32(b_packed.view(torch.uint8))[:N, :K] + + a_sc = fp4_utils.e8m0_to_f32(a_scale.view(torch.uint8)) + b_sc = fp4_utils.e8m0_to_f32(b_scale.view(torch.uint8)) + + a_sc_exp = a_sc.repeat_interleave(SCALE_BLOCK, dim=-1)[:M, :K] + b_sc_exp = b_sc.repeat_interleave(SCALE_BLOCK, dim=-1)[:N, :K] + + return torch.matmul(a_f32 * a_sc_exp, (b_f32 * b_sc_exp).T) + + +@pytest.mark.parametrize( + "M, N, K, tile_m, tile_n, tile_k", + [ + (128, 128, 256, 128, 128, 128), + (128, 128, 512, 128, 128, 128), + (128, 128, 1024, 128, 128, 128), + ], +) +@pytest.mark.parametrize("num_buffers", [2]) +def test_mxfp4_gemm(M, N, K, tile_m, tile_n, tile_k, num_buffers, + m_warp=2, n_warp=2, l2_prefetch_distance=0, + cluster_m=1, cluster_n=1): + """MXFP4 GEMM correctness unit test.""" + arch = str(get_rocm_arch(timeout_s=300)) + if arch != "gfx1250": + pytest.skip(f"WMMA_SCALE requires gfx1250, got {arch}") + + num_k_tiles = K // tile_k + if num_buffers > 1 and num_k_tiles < num_buffers: + pytest.skip(f"{num_buffers}-buf requires num_k_tiles >= {num_buffers}") + + torch.manual_seed(0) + + print(f"\nRunning MXFP4 GEMM: M={M}, N={N}, K={K}, " + f"tiles=({tile_m},{tile_n},{tile_k}), bufs={num_buffers}") + + a_packed = random_mxfp4_packed(M, K) + b_packed = random_mxfp4_packed(N, K) + a_scale = random_e8m0(M, K // SCALE_BLOCK) + b_scale = random_e8m0(N, K // SCALE_BLOCK) + + ref = reference_mxfp4_gemm(a_packed, b_packed, a_scale, b_scale, M, N, K) + print(f"Ref stats: min={ref.min():.2f}, max={ref.max():.2f}, " + f"mean={ref.mean():.2f}, std={ref.std():.2f}") + + a_gpu = a_packed.cuda() + b_gpu = b_packed.cuda() + as_gpu = a_scale.cuda() + bs_gpu = b_scale.cuda() + c_gpu = torch.zeros(M, N, dtype=torch.float32, device="cpu").cuda() + + launch_fn = compile_mxfp4_gemm( + M=M, N=N, K=K, + tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, + m_warp=m_warp, n_warp=n_warp, + num_buffers=num_buffers, + l2_prefetch_distance=l2_prefetch_distance, + cluster_m=cluster_m, cluster_n=cluster_n, + ) + launch_fn( + c_gpu.contiguous().view(-1), + a_gpu.contiguous().view(-1), + b_gpu.contiguous().view(-1), + as_gpu.contiguous().view(-1), + bs_gpu.contiguous().view(-1), + M, N, torch.cuda.current_stream(), + ) + torch.cuda.synchronize() + + c_out = c_gpu.cpu() + + print(f"Out stats: min={c_out.min():.2f}, max={c_out.max():.2f}, " + f"mean={c_out.mean():.2f}, std={c_out.std():.2f}") + + if c_out.abs().max() < 1e-10: + print("WARNING: kernel output is all zeros!") + + diff = (c_out - ref).abs() + print(f"Abs diff: max={diff.max():.4f}, mean={diff.mean():.4f}") + + cos_sim = torch.nn.functional.cosine_similarity( + c_out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item() + print(f"Cosine 
similarity: {cos_sim:.6f}") + + torch.testing.assert_close(c_out, ref, rtol=1e-5, atol=1e-8) + print("PASSED") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("-M", type=int, default=128) + parser.add_argument("-N", type=int, default=128) + parser.add_argument("-K", type=int, default=256) + parser.add_argument("--tile-m", type=int, default=128) + parser.add_argument("--tile-n", type=int, default=128) + parser.add_argument("--tile-k", type=int, default=128) + parser.add_argument("--m-warp", type=int, default=2) + parser.add_argument("--n-warp", type=int, default=2) + parser.add_argument("--num-buffers", type=int, default=2, choices=[2, 3, 4]) + parser.add_argument("--l2-prefetch-distance", type=int, default=0) + parser.add_argument("--cluster-m", type=int, default=1) + parser.add_argument("--cluster-n", type=int, default=1) + args = parser.parse_args() + + test_mxfp4_gemm( + args.M, args.N, args.K, + args.tile_m, args.tile_n, args.tile_k, + num_buffers=args.num_buffers, + m_warp=args.m_warp, + n_warp=args.n_warp, + l2_prefetch_distance=args.l2_prefetch_distance, + cluster_m=args.cluster_m, + cluster_n=args.cluster_n, + ) diff --git a/tests/test_common.py b/tests/test_common.py index 2f27592e..276cfeab 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -419,7 +419,7 @@ def checkAllclose( def verify_output(c_out, c_ref, atol=1e-2, rtol=1e-2, msg='', logits_diff_threshold=2e-3): - if checkAllclose(c_out, c_ref, rtol=atol, atol=atol) < 0.05: + if checkAllclose(c_out, c_ref, rtol=rtol, atol=atol) < 0.05: return True # Calculate various error metrics From e2735b348e2ca1e946b3a8ecfed4b52bf621b00a Mon Sep 17 00:00:00 2001 From: aoli26 Date: Wed, 18 Mar 2026 14:12:01 +0000 Subject: [PATCH 08/11] add mxfp4 scale preshuffle optimization --- kernels/mxfp4_gemm_gfx1250.py | 143 ++++++++++++++++------- python/flydsl/expr/rocdl.py | 2 +- python/flydsl/expr/rocdl/__init__.py | 2 +- tests/kernels/test_mxfp4_gemm_gfx1250.py | 69 +++++++++-- 4 files changed, 167 insertions(+), 49 deletions(-) diff --git a/kernels/mxfp4_gemm_gfx1250.py b/kernels/mxfp4_gemm_gfx1250.py index ec961d6c..035a2b96 100644 --- a/kernels/mxfp4_gemm_gfx1250.py +++ b/kernels/mxfp4_gemm_gfx1250.py @@ -48,6 +48,7 @@ def compile_mxfp4_gemm( l2_prefetch_distance: int = 2, cluster_m: int = 1, cluster_n: int = 1, + scale_preshuffle: bool = True, ): """Compile an MXFP4 GEMM kernel with TDM async copy and multi-stage buffering. 
@@ -116,6 +117,9 @@ def compile_mxfp4_gemm( lds_b_data_bytes = tile_n * lds_b_stride_bytes lds_a_scale_bytes = tile_m * scale_k_per_tile lds_b_scale_bytes = tile_n * scale_k_per_tile + # Interleaved scale layout: [WMMA_M * m_warp, wmma_m_rep * scale_k_per_tile] + interleaved_scale_cols_a = wmma_m_rep * scale_k_per_tile + interleaved_scale_cols_b = wmma_n_rep * scale_k_per_tile stage_allocators = [] stage_a_data_off = [] @@ -164,9 +168,12 @@ def compile_mxfp4_gemm( # Number of LDS loads per K-subtile (for s_wait_dscnt): # A frag: wmma_m_rep * 2 ds_load_b128 # B frag: wmma_n_rep * 2 ds_load_b128 - # A scale: wmma_m_rep ds_load_b32 - # B scale: wmma_n_rep ds_load_b32 - LOADS_PER_SUBTILE = wmma_m_rep * 2 + wmma_n_rep * 2 + wmma_m_rep + wmma_n_rep + # A scale: 1 ds_load_b128 (interleave) or wmma_m_rep ds_load_b32 + # B scale: 1 ds_load_b128 (interleave) or wmma_n_rep ds_load_b32 + if scale_preshuffle: + LOADS_PER_SUBTILE = wmma_m_rep * 2 + wmma_n_rep * 2 + 1 + 1 + else: + LOADS_PER_SUBTILE = wmma_m_rep * 2 + wmma_n_rep * 2 + wmma_m_rep + wmma_n_rep @flyc.kernel def kernel_mxfp4_gemm( @@ -251,30 +258,59 @@ def copy_b_data_to_lds(k_base, lds_mem_ref): def copy_a_scale_to_lds(k_base, lds_mem_ref): k_scale_off = k_base / arith.index(SCALE_BLOCK) - desc = tdm_ops.make_tensor_descriptor_2d( - global_ptr=arg_a_scale, lds_memref=lds_mem_ref, - global_offset=(blk_m, k_scale_off), - tensor_shape=(tile_m, scale_k_per_tile), - strides=(K_scale, 1), - tile_shape=(tile_m, scale_k_per_tile), - elem_bytes=1, - pad_interval=0, pad_amount=0, - num_warps=num_warps, - workgroup_mask=a_mcast_mask) + if scale_preshuffle: + # Interleaved global: [M // wmma_m_rep, wmma_m_rep * K_scale] + outer_off = blk_m / arith.index(wmma_m_rep) + inner_off = k_scale_off * arith.index(wmma_m_rep) + desc = tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_a_scale, lds_memref=lds_mem_ref, + global_offset=(outer_off, inner_off), + tensor_shape=(WMMA_M * m_warp, interleaved_scale_cols_a), + strides=(wmma_m_rep * K_scale, 1), + tile_shape=(WMMA_M * m_warp, interleaved_scale_cols_a), + elem_bytes=1, + pad_interval=0, pad_amount=0, + num_warps=num_warps, + workgroup_mask=a_mcast_mask) + else: + desc = tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_a_scale, lds_memref=lds_mem_ref, + global_offset=(blk_m, k_scale_off), + tensor_shape=(tile_m, scale_k_per_tile), + strides=(K_scale, 1), + tile_shape=(tile_m, scale_k_per_tile), + elem_bytes=1, + pad_interval=0, pad_amount=0, + num_warps=num_warps, + workgroup_mask=a_mcast_mask) tdm_ops.tensor_load_2d(desc) def copy_b_scale_to_lds(k_base, lds_mem_ref): k_scale_off = k_base / arith.index(SCALE_BLOCK) - desc = tdm_ops.make_tensor_descriptor_2d( - global_ptr=arg_b_scale, lds_memref=lds_mem_ref, - global_offset=(blk_n, k_scale_off), - tensor_shape=(tile_n, scale_k_per_tile), - strides=(K_scale, 1), - tile_shape=(tile_n, scale_k_per_tile), - elem_bytes=1, - pad_interval=0, pad_amount=0, - num_warps=num_warps, - workgroup_mask=b_mcast_mask) + if scale_preshuffle: + outer_off = blk_n / arith.index(wmma_n_rep) + inner_off = k_scale_off * arith.index(wmma_n_rep) + desc = tdm_ops.make_tensor_descriptor_2d( + global_ptr=arg_b_scale, lds_memref=lds_mem_ref, + global_offset=(outer_off, inner_off), + tensor_shape=(WMMA_N * n_warp, interleaved_scale_cols_b), + strides=(wmma_n_rep * K_scale, 1), + tile_shape=(WMMA_N * n_warp, interleaved_scale_cols_b), + elem_bytes=1, + pad_interval=0, pad_amount=0, + num_warps=num_warps, + workgroup_mask=b_mcast_mask) + else: + desc = tdm_ops.make_tensor_descriptor_2d( 
@@ -251,30 +258,59 @@ def copy_b_data_to_lds(k_base, lds_mem_ref):
 
         def copy_a_scale_to_lds(k_base, lds_mem_ref):
             k_scale_off = k_base / arith.index(SCALE_BLOCK)
-            desc = tdm_ops.make_tensor_descriptor_2d(
-                global_ptr=arg_a_scale, lds_memref=lds_mem_ref,
-                global_offset=(blk_m, k_scale_off),
-                tensor_shape=(tile_m, scale_k_per_tile),
-                strides=(K_scale, 1),
-                tile_shape=(tile_m, scale_k_per_tile),
-                elem_bytes=1,
-                pad_interval=0, pad_amount=0,
-                num_warps=num_warps,
-                workgroup_mask=a_mcast_mask)
+            if scale_preshuffle:
+                # Interleaved global: [M // wmma_m_rep, wmma_m_rep * K_scale]
+                outer_off = blk_m / arith.index(wmma_m_rep)
+                inner_off = k_scale_off * arith.index(wmma_m_rep)
+                desc = tdm_ops.make_tensor_descriptor_2d(
+                    global_ptr=arg_a_scale, lds_memref=lds_mem_ref,
+                    global_offset=(outer_off, inner_off),
+                    tensor_shape=(WMMA_M * m_warp, interleaved_scale_cols_a),
+                    strides=(wmma_m_rep * K_scale, 1),
+                    tile_shape=(WMMA_M * m_warp, interleaved_scale_cols_a),
+                    elem_bytes=1,
+                    pad_interval=0, pad_amount=0,
+                    num_warps=num_warps,
+                    workgroup_mask=a_mcast_mask)
+            else:
+                desc = tdm_ops.make_tensor_descriptor_2d(
+                    global_ptr=arg_a_scale, lds_memref=lds_mem_ref,
+                    global_offset=(blk_m, k_scale_off),
+                    tensor_shape=(tile_m, scale_k_per_tile),
+                    strides=(K_scale, 1),
+                    tile_shape=(tile_m, scale_k_per_tile),
+                    elem_bytes=1,
+                    pad_interval=0, pad_amount=0,
+                    num_warps=num_warps,
+                    workgroup_mask=a_mcast_mask)
             tdm_ops.tensor_load_2d(desc)
 
         def copy_b_scale_to_lds(k_base, lds_mem_ref):
             k_scale_off = k_base / arith.index(SCALE_BLOCK)
-            desc = tdm_ops.make_tensor_descriptor_2d(
-                global_ptr=arg_b_scale, lds_memref=lds_mem_ref,
-                global_offset=(blk_n, k_scale_off),
-                tensor_shape=(tile_n, scale_k_per_tile),
-                strides=(K_scale, 1),
-                tile_shape=(tile_n, scale_k_per_tile),
-                elem_bytes=1,
-                pad_interval=0, pad_amount=0,
-                num_warps=num_warps,
-                workgroup_mask=b_mcast_mask)
+            if scale_preshuffle:
+                outer_off = blk_n / arith.index(wmma_n_rep)
+                inner_off = k_scale_off * arith.index(wmma_n_rep)
+                desc = tdm_ops.make_tensor_descriptor_2d(
+                    global_ptr=arg_b_scale, lds_memref=lds_mem_ref,
+                    global_offset=(outer_off, inner_off),
+                    tensor_shape=(WMMA_N * n_warp, interleaved_scale_cols_b),
+                    strides=(wmma_n_rep * K_scale, 1),
+                    tile_shape=(WMMA_N * n_warp, interleaved_scale_cols_b),
+                    elem_bytes=1,
+                    pad_interval=0, pad_amount=0,
+                    num_warps=num_warps,
+                    workgroup_mask=b_mcast_mask)
+            else:
+                desc = tdm_ops.make_tensor_descriptor_2d(
+                    global_ptr=arg_b_scale, lds_memref=lds_mem_ref,
+                    global_offset=(blk_n, k_scale_off),
+                    tensor_shape=(tile_n, scale_k_per_tile),
+                    strides=(K_scale, 1),
+                    tile_shape=(tile_n, scale_k_per_tile),
+                    elem_bytes=1,
+                    pad_interval=0, pad_amount=0,
+                    num_warps=num_warps,
+                    workgroup_mask=b_mcast_mask)
             tdm_ops.tensor_load_2d(desc)
 
         def issue_all_tdm_loads(k_base, a_mem, b_mem, as_mem, bs_mem):
@@ -348,19 +384,25 @@ def load_b_frag(lds_buffer, b_lane_base, ks):
             v1 = _lds_load_b128(lds_buffer, byte_off + arith.index(16))
             return vector.shuffle(v0, v1, list(range(8)))
 
-        def _precompute_scale_lane_bases(lds_ptr, warp_base, reps):
+        def _precompute_scale_lane_bases(lds_ptr, warp_base, reps, interleaved_cols=0):
             """Precompute scale lane bases (in BYTES).
 
-            Scale LDS layout: [tile_m_or_n, scale_k_per_tile] bytes.
-            lane16 → M/N row.
+            Original layout: [tile_m_or_n, scale_k_per_tile] bytes.
+            Interleaved layout: [WMMA_M * m_or_n_warp, wmma_rep * scale_k_per_tile] bytes.
             """
             lds_buffer = _get_lds_memref(lds_ptr)
-            row_base = (warp_base + lane16) * arith.index(scale_k_per_tile)
-            bases = []
-            for w in range_constexpr(reps):
-                base = row_base + arith.index(w * WMMA_M * scale_k_per_tile)
-                bases.append(base)
-            return lds_buffer, bases
+            if scale_preshuffle and interleaved_cols > 0:
+                # Interleaved: row = (warp_base / reps) + lane16, stride = interleaved_cols
+                warp_lds_row = warp_base / arith.index(reps) + lane16
+                base = warp_lds_row * arith.index(interleaved_cols)
+                return lds_buffer, [base]  # single base for b128 load
+            else:
+                row_base = (warp_base + lane16) * arith.index(scale_k_per_tile)
+                bases = []
+                for w in range_constexpr(reps):
+                    base = row_base + arith.index(w * WMMA_M * scale_k_per_tile)
+                    bases.append(base)
+                return lds_buffer, bases
 
         def _shuffle_scale_i32(val):
             """Swap bytes 1 and 2 of an i32 scale value via v_perm_b32.
@@ -401,8 +443,25 @@ def load_scale(lds_buffer, scale_base, ks):
             ptr_val = _llvm.inttoptr(lds_ptr_ty, addr_i32)
             i32_ty = ir.IntegerType.get_signless(32)
             raw_scale = llvm_dialect.load(i32_ty, ptr_val)
+            if scale_preshuffle:
+                return raw_scale
             return _shuffle_scale_i32(raw_scale)
 
+        def load_scale_b128(lds_buffer, scale_base, reps):
+            """Load all wmma_rep scales via 1 ds_load_b128.
+
+            Interleaved LDS layout places all reps i32 values contiguously.
+            Returns list of reps i32 values extracted from vec<4xi32>.
+            """
+            v = _lds_load_b128(lds_buffer, scale_base)
+            results = []
+            for i in range_constexpr(reps):
+                vi = vector.extract(v, static_position=[i], dynamic_position=[])
+                if not scale_preshuffle:
+                    vi = _shuffle_scale_i32(vi)
+                results.append(vi)
+            return results
+
         def load_k_subtile_frags(a_buf, a_bases, b_buf, b_bases,
                                  as_buf, as_bases, bs_buf, bs_bases, ks):
             """Batch-load all A/B fragments and scales for one K-subtile."""
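The scale_preshuffle early-out added to load_scale above works because the host has already applied the byte swap that _shuffle_scale_i32 would otherwise perform per load. A host-side sketch of that swap on one packed i32 (illustrative helper; the actual v_perm_b32 selector value is not part of this patch):

    def swap_bytes_1_2(x: int) -> int:
        # Same effect as _shuffle_scale_i32: bytes [0, 1, 2, 3] -> [0, 2, 1, 3].
        b = x.to_bytes(4, "little")
        return int.from_bytes(bytes([b[0], b[2], b[1], b[3]]), "little")

    assert swap_bytes_1_2(0x33221100) == 0x33112200

This is the same reordering as the grouped[:, :, [0, 2, 1, 3]] step in the test helper below.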
+ """ + v = _lds_load_b128(lds_buffer, scale_base) + results = [] + for i in range_constexpr(reps): + vi = vector.extract(v, static_position=[i], dynamic_position=[]) + if not scale_preshuffle: + vi = _shuffle_scale_i32(vi) + results.append(vi) + return results + def load_k_subtile_frags(a_buf, a_bases, b_buf, b_bases, as_buf, as_bases, bs_buf, bs_bases, ks): """Batch-load all A/B fragments and scales for one K-subtile.""" @@ -412,10 +471,14 @@ def load_k_subtile_frags(a_buf, a_bases, b_buf, b_bases, a_frags = [load_a_frag(a_buf, a_bases[wm], ks) for wm in range_constexpr(wmma_m_rep)] # Load scales - b_scales = [load_scale(bs_buf, bs_bases[wn], ks) - for wn in range_constexpr(wmma_n_rep)] - a_scales = [load_scale(as_buf, as_bases[wm], ks) - for wm in range_constexpr(wmma_m_rep)] + if scale_preshuffle: + b_scales = load_scale_b128(bs_buf, bs_bases[0], wmma_n_rep) + a_scales = load_scale_b128(as_buf, as_bases[0], wmma_m_rep) + else: + b_scales = [load_scale(bs_buf, bs_bases[wn], ks) + for wn in range_constexpr(wmma_n_rep)] + a_scales = [load_scale(as_buf, as_bases[wm], ks) + for wm in range_constexpr(wmma_m_rep)] return a_frags, b_frags, a_scales, b_scales def do_k_subtile_wmma(a_frags, b_frags, a_scales, b_scales, accs): @@ -448,8 +511,10 @@ def compute_tile(accs_in, lds_a, lds_b, lds_as, lds_bs, emit_filler=None): a_buf, a_bases = _precompute_a_lane_bases(lds_a) b_buf, b_bases = _precompute_b_lane_bases(lds_b) - as_buf, as_bases = _precompute_scale_lane_bases(lds_as, warp_m_base, wmma_m_rep) - bs_buf, bs_bases = _precompute_scale_lane_bases(lds_bs, warp_n_base, wmma_n_rep) + as_buf, as_bases = _precompute_scale_lane_bases( + lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a) + bs_buf, bs_bases = _precompute_scale_lane_bases( + lds_bs, warp_n_base, wmma_n_rep, interleaved_scale_cols_b) if k_wmma_steps == 1: frags = load_k_subtile_frags( @@ -663,7 +728,7 @@ def _emit_epi_addrs(): cache_tag = (K, tile_m, tile_n, tile_k, m_warp, n_warp, num_buffers, effective_waves_per_eu, l2_prefetch_distance, - cluster_m, cluster_n) + cluster_m, cluster_n, scale_preshuffle) @flyc.jit def launch_mxfp4_gemm( diff --git a/python/flydsl/expr/rocdl.py b/python/flydsl/expr/rocdl.py index 4896815b..27c3fada 100644 --- a/python/flydsl/expr/rocdl.py +++ b/python/flydsl/expr/rocdl.py @@ -349,7 +349,7 @@ def wmma_scale_f32_16x16x128_f8f6f4(result_type, a, b, c, scaleA, scaleB, scaleA: i32 (A scale VGPR) scaleB: i32 (B scale VGPR) - fmtA/fmtB: data type encoding (4=FP4/E2M1) + fmtA/fmtB: data type encoding (0=FP8/E4M3, 1=FP8/E5M2, 2=FP6/E2M3, 3=FP6/E3M2, 4=FP4/E2M1) scaleAType/scaleBType: opsel – selects lo/hi 16-bit half of scale VGPR (0=lo, 1=hi) fmtScaleA/fmtScaleB: scale format (0=E8M0, 1=E5M3, 2=E4M3) """ diff --git a/python/flydsl/expr/rocdl/__init__.py b/python/flydsl/expr/rocdl/__init__.py index bcc101ba..a788bd23 100644 --- a/python/flydsl/expr/rocdl/__init__.py +++ b/python/flydsl/expr/rocdl/__init__.py @@ -131,7 +131,7 @@ def wmma_scale_f32_16x16x128_f8f6f4(result_type, a, b, c, scaleA, scaleB, scaleA: i32 (A scale VGPR) scaleB: i32 (B scale VGPR) - fmtA/fmtB: data type encoding (4=FP4/E2M1) + fmtA/fmtB: data type encoding (0=FP8/E4M3, 1=FP8/E5M2, 2=FP6/E2M3, 3=FP6/E3M2, 4=FP4/E2M1) scaleAType/scaleBType: opsel – selects lo/hi 16-bit half of scale VGPR (0=lo, 1=hi) fmtScaleA/fmtScaleB: scale format (0=E8M0, 1=E5M3, 2=E4M3) """ diff --git a/tests/kernels/test_mxfp4_gemm_gfx1250.py b/tests/kernels/test_mxfp4_gemm_gfx1250.py index 7fc8e541..99ac58b2 100644 --- a/tests/kernels/test_mxfp4_gemm_gfx1250.py 
diff --git a/tests/kernels/test_mxfp4_gemm_gfx1250.py b/tests/kernels/test_mxfp4_gemm_gfx1250.py
index 7fc8e541..99ac58b2 100644
--- a/tests/kernels/test_mxfp4_gemm_gfx1250.py
+++ b/tests/kernels/test_mxfp4_gemm_gfx1250.py
@@ -32,6 +32,24 @@
 SCALE_BLOCK = 32
 
 
+def preshuffle_e8m0_scale(scale: torch.Tensor, warp_tile: int,
+                          scale_k_per_tile: int = 4,
+                          WMMA_DIM: int = 16) -> torch.Tensor:
+    """Preshuffle E8M0 scale for WMMA_SCALE: byte swap + interleave for ds_load_b128."""
+    _, K_scale = scale.shape
+    assert K_scale % 4 == 0, f"K_scale must be divisible by 4, got {K_scale}"
+
+    grouped = scale.view(-1, K_scale // 4, 4)
+    shuffled = grouped[:, :, [0, 2, 1, 3]].contiguous()
+    scale = shuffled.view(-1, K_scale)
+
+    wmma_rep = warp_tile // WMMA_DIM
+    k_groups = K_scale // scale_k_per_tile
+    g = scale.view(-1, wmma_rep, WMMA_DIM, k_groups, scale_k_per_tile)
+    g = g.permute(0, 2, 3, 1, 4).contiguous()
+    return g.reshape(-1, k_groups * wmma_rep * scale_k_per_tile)
+
+
 def random_mxfp4_packed(rows: int, cols: int, *, device="cpu") -> torch.Tensor:
     """Generate random packed MXFP4 data [rows, cols//2] uint8."""
     assert cols % 2 == 0
@@ -71,17 +89,18 @@ def reference_mxfp4_gemm(a_packed, b_packed, a_scale, b_scale, M, N, K):
 
 
 @pytest.mark.parametrize(
-    "M, N, K, tile_m, tile_n, tile_k",
+    "M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp",
     [
-        (128, 128, 256, 128, 128, 128),
-        (128, 128, 512, 128, 128, 128),
-        (128, 128, 1024, 128, 128, 128),
+        (128, 128, 256, 128, 128, 128, 2, 2),
+        (128, 128, 512, 128, 128, 128, 2, 2),
+        (128, 128, 1024, 128, 128, 128, 2, 2),
+        (1024, 1024, 1024, 128, 256, 128, 2, 4),
     ],
 )
 @pytest.mark.parametrize("num_buffers", [2])
-def test_mxfp4_gemm(M, N, K, tile_m, tile_n, tile_k, num_buffers,
-                    m_warp=2, n_warp=2, l2_prefetch_distance=0,
-                    cluster_m=1, cluster_n=1):
+def test_mxfp4_gemm(M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp,
+                    num_buffers, l2_prefetch_distance=0,
+                    cluster_m=1, cluster_n=1, scale_preshuffle=True):
     """MXFP4 GEMM correctness unit test."""
     arch = str(get_rocm_arch(timeout_s=300))
     if arch != "gfx1250":
@@ -93,8 +112,9 @@ def test_mxfp4_gemm(M, N, K, tile_m, tile_n, tile_k, num_buffers,
 
     torch.manual_seed(0)
 
+    mcast_str = f", cluster=({cluster_m},{cluster_n})" if cluster_m > 1 or cluster_n > 1 else ""
     print(f"\nRunning MXFP4 GEMM: M={M}, N={N}, K={K}, "
-          f"tiles=({tile_m},{tile_n},{tile_k}), bufs={num_buffers}")
+          f"tiles=({tile_m},{tile_n},{tile_k}), bufs={num_buffers}{mcast_str}")
 
     a_packed = random_mxfp4_packed(M, K)
     b_packed = random_mxfp4_packed(N, K)
@@ -105,6 +125,11 @@ def test_mxfp4_gemm(M, N, K, tile_m, tile_n, tile_k, num_buffers,
     print(f"Ref stats: min={ref.min():.2f}, max={ref.max():.2f}, "
           f"mean={ref.mean():.2f}, std={ref.std():.2f}")
 
+    if scale_preshuffle:
+        skt = tile_k // SCALE_BLOCK
+        a_scale = preshuffle_e8m0_scale(a_scale, tile_m // m_warp, scale_k_per_tile=skt)
+        b_scale = preshuffle_e8m0_scale(b_scale, tile_n // n_warp, scale_k_per_tile=skt)
+
     a_gpu = a_packed.cuda()
     b_gpu = b_packed.cuda()
     as_gpu = a_scale.cuda()
@@ -118,6 +143,7 @@ def test_mxfp4_gemm(M, N, K, tile_m, tile_n, tile_k, num_buffers,
         num_buffers=num_buffers,
         l2_prefetch_distance=l2_prefetch_distance,
         cluster_m=cluster_m, cluster_n=cluster_n,
+        scale_preshuffle=scale_preshuffle,
     )
     launch_fn(
         c_gpu.contiguous().view(-1),
@@ -148,6 +174,31 @@ def test_mxfp4_gemm(M, N, K, tile_m, tile_n, tile_k, num_buffers,
     print("PASSED")
 
 
+@pytest.mark.parametrize(
+    "M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, cluster_m, cluster_n",
+    [
+        # 2x2 cluster: needs >= 2 tile-rows and 2 tile-cols
+        (256, 256, 256, 128, 128, 128, 2, 2, 2, 2),
+        (1024, 1024, 1024, 128, 256, 128, 2, 4, 2, 2),
+        # 1x2 cluster: B shared along N
+        (128, 256, 256, 128, 128, 128, 2, 2, 1, 2),
+        # 2x1 cluster: A shared along M
+        (256, 128, 256, 128, 128, 128, 2, 2, 2, 1),
+    ],
+)
+@pytest.mark.parametrize("num_buffers", [2])
+def test_mxfp4_gemm_mcast(M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp,
+                          cluster_m, cluster_n, num_buffers):
+    """MXFP4 GEMM correctness test with cluster MCAST."""
+    test_mxfp4_gemm(
+        M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp,
+        num_buffers=num_buffers,
+        l2_prefetch_distance=2,
+        cluster_m=cluster_m, cluster_n=cluster_n,
+        scale_preshuffle=True,
+    )
+
+
 if __name__ == "__main__":
     import argparse
 
@@ -164,6 +215,7 @@ def test_mxfp4_gemm(M, N, K, tile_m, tile_n, tile_k, num_buffers,
     parser.add_argument("--l2-prefetch-distance", type=int, default=0)
     parser.add_argument("--cluster-m", type=int, default=1)
     parser.add_argument("--cluster-n", type=int, default=1)
+    parser.add_argument("--no-scale-preshuffle", action="store_true", default=False)
     args = parser.parse_args()
 
     test_mxfp4_gemm(
@@ -175,4 +227,5 @@ def test_mxfp4_gemm(M, N, K, tile_m, tile_n, tile_k, num_buffers,
         l2_prefetch_distance=args.l2_prefetch_distance,
         cluster_m=args.cluster_m,
         cluster_n=args.cluster_n,
+        scale_preshuffle=not args.no_scale_preshuffle,
     )
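To make preshuffle_e8m0_scale concrete (illustrative sizes): for M = 128, K = 1024 the E8M0 scale tensor is [128, 32], since K_scale = K / SCALE_BLOCK. With tile_m = 128 and m_warp = 2, warp_tile = 64 and wmma_rep = 4, so the helper returns [32, 128]:

    # Hypothetical shape check; uint8 zeros stand in for real scales.
    out = preshuffle_e8m0_scale(torch.zeros(128, 32, dtype=torch.uint8), warp_tile=64)
    assert out.shape == (32, 128)  # 16 lane rows per warp-tile block, wmma_rep * K_scale cols

Each output row carries all wmma_rep replicas of a lane's scales back to back, which is what the kernel's single ds_load_b128 per operand relies on. The interleave order within a row is corrected in the next commit of this series; the output shape is unchanged.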

From 734e3179f7bf5053d5d3958d87ce06a89cb53bc3 Mon Sep 17 00:00:00 2001
From: aoli26
Date: Thu, 19 Mar 2026 03:05:15 +0000
Subject: [PATCH 09/11] fix scale preshuffle k-subtile permute and tile-m offset

---
 kernels/mxfp4_gemm_gfx1250.py            | 26 +++++++++++++++-----------
 tests/kernels/test_mxfp4_gemm_gfx1250.py |  8 +++++---
 2 files changed, 20 insertions(+), 14 deletions(-)

diff --git a/kernels/mxfp4_gemm_gfx1250.py b/kernels/mxfp4_gemm_gfx1250.py
index 035a2b96..8ac7836f 100644
--- a/kernels/mxfp4_gemm_gfx1250.py
+++ b/kernels/mxfp4_gemm_gfx1250.py
@@ -171,7 +171,9 @@ def compile_mxfp4_gemm(
     # A scale: 1 ds_load_b128 (interleave) or wmma_m_rep ds_load_b32
     # B scale: 1 ds_load_b128 (interleave) or wmma_n_rep ds_load_b32
     if scale_preshuffle:
-        LOADS_PER_SUBTILE = wmma_m_rep * 2 + wmma_n_rep * 2 + 1 + 1
+        a_scale_b128_loads = (wmma_m_rep + 3) // 4
+        b_scale_b128_loads = (wmma_n_rep + 3) // 4
+        LOADS_PER_SUBTILE = wmma_m_rep * 2 + wmma_n_rep * 2 + a_scale_b128_loads + b_scale_b128_loads
     else:
         LOADS_PER_SUBTILE = wmma_m_rep * 2 + wmma_n_rep * 2 + wmma_m_rep + wmma_n_rep
 
@@ -447,16 +449,18 @@ def load_scale(lds_buffer, scale_base, ks):
                 return raw_scale
             return _shuffle_scale_i32(raw_scale)
 
-        def load_scale_b128(lds_buffer, scale_base, reps):
-            """Load all wmma_rep scales via 1 ds_load_b128.
-
-            Interleaved LDS layout places all reps i32 values contiguously.
-            Returns list of reps i32 values extracted from vec<4xi32>.
-            """
-            v = _lds_load_b128(lds_buffer, scale_base)
+        def load_scale_b128(lds_buffer, scale_base, reps, ks=0):
+            """Load all wmma_rep scales via ds_load_b128(s) for K-subtile *ks*."""
+            ks_byte_off = ks * reps * SCALES_PER_WMMA
+            eff_base = scale_base if ks_byte_off == 0 else scale_base + arith.index(ks_byte_off)
+            num_loads = (reps + 3) // 4
+            vecs = []
+            for ld in range_constexpr(num_loads):
+                off = eff_base if ld == 0 else eff_base + arith.index(ld * 16)
+                vecs.append(_lds_load_b128(lds_buffer, off))
             results = []
             for i in range_constexpr(reps):
-                vi = vector.extract(v, static_position=[i], dynamic_position=[])
+                vi = vector.extract(vecs[i // 4], static_position=[i % 4], dynamic_position=[])
                 if not scale_preshuffle:
                     vi = _shuffle_scale_i32(vi)
                 results.append(vi)
@@ -472,8 +476,8 @@ def load_k_subtile_frags(a_buf, a_bases, b_buf, b_bases,
                        for wm in range_constexpr(wmma_m_rep)]
             # Load scales
             if scale_preshuffle:
-                b_scales = load_scale_b128(bs_buf, bs_bases[0], wmma_n_rep)
-                a_scales = load_scale_b128(as_buf, as_bases[0], wmma_m_rep)
+                b_scales = load_scale_b128(bs_buf, bs_bases[0], wmma_n_rep, ks)
+                a_scales = load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks)
             else:
                 b_scales = [load_scale(bs_buf, bs_bases[wn], ks)
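A quick numeric check of the generalized loader (illustrative values; with wmma_rep = 4, as in the configs above, num_loads stays 1 and the new ks offset is the only behavioral change):

    reps, ks, SCALES_PER_WMMA = 8, 2, 4        # hypothetical 8-replica config
    ks_byte_off = ks * reps * SCALES_PER_WMMA  # 64 bytes: skip two K-subtiles of scales
    num_loads = (reps + 3) // 4                # 2 ds_load_b128, issued 16 bytes apart
    picks = [(i // 4, i % 4) for i in range(reps)]
    # picks -> [(0,0)..(0,3), (1,0)..(1,3)]: which vec<4xi32>, which lane within it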
""" + ks_byte_off = ks * reps * SCALES_PER_WMMA + eff_base = scale_base if ks_byte_off == 0 else scale_base + arith.index(ks_byte_off) + num_loads = (reps + 3) // 4 + vecs = [] + for ld in range_constexpr(num_loads): + off = eff_base if ld == 0 else eff_base + arith.index(ld * 16) + vecs.append(_lds_load_b128(lds_buffer, off)) results = [] for i in range_constexpr(reps): - vi = vector.extract(v, static_position=[i], dynamic_position=[]) + vi = vector.extract(vecs[i // 4], static_position=[i % 4], dynamic_position=[]) if not scale_preshuffle: vi = _shuffle_scale_i32(vi) results.append(vi) @@ -472,8 +476,8 @@ def load_k_subtile_frags(a_buf, a_bases, b_buf, b_bases, for wm in range_constexpr(wmma_m_rep)] # Load scales if scale_preshuffle: - b_scales = load_scale_b128(bs_buf, bs_bases[0], wmma_n_rep) - a_scales = load_scale_b128(as_buf, as_bases[0], wmma_m_rep) + b_scales = load_scale_b128(bs_buf, bs_bases[0], wmma_n_rep, ks) + a_scales = load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) else: b_scales = [load_scale(bs_buf, bs_bases[wn], ks) for wn in range_constexpr(wmma_n_rep)] diff --git a/tests/kernels/test_mxfp4_gemm_gfx1250.py b/tests/kernels/test_mxfp4_gemm_gfx1250.py index 99ac58b2..e44f5c07 100644 --- a/tests/kernels/test_mxfp4_gemm_gfx1250.py +++ b/tests/kernels/test_mxfp4_gemm_gfx1250.py @@ -43,11 +43,13 @@ def preshuffle_e8m0_scale(scale: torch.Tensor, warp_tile: int, shuffled = grouped[:, :, [0, 2, 1, 3]].contiguous() scale = shuffled.view(-1, K_scale) + SCALES_PER_WMMA = 4 wmma_rep = warp_tile // WMMA_DIM k_groups = K_scale // scale_k_per_tile - g = scale.view(-1, wmma_rep, WMMA_DIM, k_groups, scale_k_per_tile) - g = g.permute(0, 2, 3, 1, 4).contiguous() - return g.reshape(-1, k_groups * wmma_rep * scale_k_per_tile) + k_wmma_steps = scale_k_per_tile // SCALES_PER_WMMA + g = scale.view(-1, wmma_rep, WMMA_DIM, k_groups, k_wmma_steps, SCALES_PER_WMMA) + g = g.permute(0, 2, 3, 4, 1, 5).contiguous() + return g.reshape(-1, k_groups * k_wmma_steps * wmma_rep * SCALES_PER_WMMA) def random_mxfp4_packed(rows: int, cols: int, *, device="cpu") -> torch.Tensor: From 00ea83c4a681c9196885f93fe7479bb9e0dfd65d Mon Sep 17 00:00:00 2001 From: jli10004 Date: Thu, 19 Mar 2026 09:40:37 +0000 Subject: [PATCH 10/11] fix rebase issues: add missing GFX1250 interface methods and fix launch params - Add getValTypeA/B/C/D() to MmaAtomGFX1250_WMMAType (required by MmaAtomTypeInterface added on main) - Remove stale async_object param from LaunchFuncOp, use async_dependencies consistent with main --- lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp | 5 +++++ python/flydsl/compiler/kernel_function.py | 1 - 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp b/lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp index 6521583f..e8af1e49 100644 --- a/lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp +++ b/lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp @@ -94,6 +94,11 @@ namespace mlir::fly_rocdl { bool MmaAtomGFX1250_WMMAType::isStatic() const { return true; } +Type MmaAtomGFX1250_WMMAType::getValTypeA() const { return getElemTyA(); } +Type MmaAtomGFX1250_WMMAType::getValTypeB() const { return getElemTyB(); } +Type MmaAtomGFX1250_WMMAType::getValTypeC() const { return getElemTyAcc(); } +Type MmaAtomGFX1250_WMMAType::getValTypeD() const { return getElemTyAcc(); } + Attribute MmaAtomGFX1250_WMMAType::getThrLayout() const { return FxLayout(FxC(32), FxC(1)); } diff --git a/python/flydsl/compiler/kernel_function.py b/python/flydsl/compiler/kernel_function.py index bd754f60..cadd51fb 100644 --- 

From 00ea83c4a681c9196885f93fe7479bb9e0dfd65d Mon Sep 17 00:00:00 2001
From: jli10004
Date: Thu, 19 Mar 2026 09:40:37 +0000
Subject: [PATCH 10/11] fix rebase issues: add missing GFX1250 interface
 methods and fix launch params

- Add getValTypeA/B/C/D() to MmaAtomGFX1250_WMMAType (required by
  MmaAtomTypeInterface added on main)
- Remove stale async_object param from LaunchFuncOp, use async_dependencies
  consistent with main
---
 lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp  | 5 +++++
 python/flydsl/compiler/kernel_function.py | 1 -
 2 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp b/lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp
index 6521583f..e8af1e49 100644
--- a/lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp
+++ b/lib/Dialect/FlyROCDL/GFX1250/MmaAtom.cpp
@@ -94,6 +94,11 @@ namespace mlir::fly_rocdl {
 
 bool MmaAtomGFX1250_WMMAType::isStatic() const { return true; }
 
+Type MmaAtomGFX1250_WMMAType::getValTypeA() const { return getElemTyA(); }
+Type MmaAtomGFX1250_WMMAType::getValTypeB() const { return getElemTyB(); }
+Type MmaAtomGFX1250_WMMAType::getValTypeC() const { return getElemTyAcc(); }
+Type MmaAtomGFX1250_WMMAType::getValTypeD() const { return getElemTyAcc(); }
+
 Attribute MmaAtomGFX1250_WMMAType::getThrLayout() const {
   return FxLayout(FxC(32), FxC(1));
 }
diff --git a/python/flydsl/compiler/kernel_function.py b/python/flydsl/compiler/kernel_function.py
index bd754f60..cadd51fb 100644
--- a/python/flydsl/compiler/kernel_function.py
+++ b/python/flydsl/compiler/kernel_function.py
@@ -296,7 +296,6 @@ def launch(
         kernel_operands,
         async_dependencies=async_deps,
         dynamic_shared_memory_size=smem_val,
-        async_object=async_object,
         cluster_size=cluster_size,
         loc=launch_loc,
         ip=None,

From dfe9202db5fcff9a013edfffd7bffecd25fe717d Mon Sep 17 00:00:00 2001
From: jli10004
Date: Thu, 19 Mar 2026 09:47:47 +0000
Subject: [PATCH 11/11] remove local build scripts, ir_dump, and _mlir symlink
 from repo

---
 after_build.sh                               | 10 ---
 build_flydsl.sh                              |  5 --
 dump_ir.sh                                   | 56 ------------
 env.sh                                       |  5 --
 hsakmt_counters.csv                          | 20 -----
 ir_dump/00_origin_ir.mlir                    | 83 -----------------
 ir_dump/01_GpuKernelOutliningPass.mlir       | 83 -----------------
 ir_dump/02_FlyCanonicalizePass.mlir          | 79 ----------------
 ir_dump/03_FlyLayoutLoweringPass.mlir        | 78 ----------------
 ir_dump/04_FlyToROCDLConversionPass.mlir     | 93 --------------------
 ir_dump/05_Canonicalizer.mlir                | 59 -------------
 ir_dump/06_SCFToControlFlowPass.mlir         | 39 --------
 ir_dump/07_ConvertGpuOpsToROCDLOps.mlir      | 35 --------
 ir_dump/08_GpuROCDLAttachTarget.mlir         | 55 ------------
 ir_dump/09_SCFToControlFlowPass.mlir         | 55 ------------
 ir_dump/10_ConvertControlFlowToLLVMPass.mlir | 55 ------------
 ir_dump/11_FlyGpuStreamMarkPass.mlir         | 55 ------------
 ir_dump/12_GpuToLLVMConversionPass.mlir      | 65 --------------
 ir_dump/13_FlyGpuStreamInjectPass.mlir       | 65 --------------
 ir_dump/14_ArithToLLVMConversionPass.mlir    | 65 --------------
 ir_dump/15_ConvertFuncToLLVMPass.mlir        | 65 --------------
 ir_dump/16_ReconcileUnrealizedCastsPass.mlir | 65 --------------
 ir_dump/17_GpuModuleToBinaryPass.mlir        | 44 ---------
 python/flydsl/_mlir                          |  1 -
 24 files changed, 1235 deletions(-)
 delete mode 100644 after_build.sh
 delete mode 100755 build_flydsl.sh
 delete mode 100755 dump_ir.sh
 delete mode 100644 env.sh
 delete mode 100644 hsakmt_counters.csv
 delete mode 100644 ir_dump/00_origin_ir.mlir
 delete mode 100644 ir_dump/01_GpuKernelOutliningPass.mlir
 delete mode 100644 ir_dump/02_FlyCanonicalizePass.mlir
 delete mode 100644 ir_dump/03_FlyLayoutLoweringPass.mlir
 delete mode 100644 ir_dump/04_FlyToROCDLConversionPass.mlir
 delete mode 100644 ir_dump/05_Canonicalizer.mlir
 delete mode 100644 ir_dump/06_SCFToControlFlowPass.mlir
 delete mode 100644 ir_dump/07_ConvertGpuOpsToROCDLOps.mlir
 delete mode 100644 ir_dump/08_GpuROCDLAttachTarget.mlir
 delete mode 100644 ir_dump/09_SCFToControlFlowPass.mlir
 delete mode 100644 ir_dump/10_ConvertControlFlowToLLVMPass.mlir
 delete mode 100644 ir_dump/11_FlyGpuStreamMarkPass.mlir
 delete mode 100644 ir_dump/12_GpuToLLVMConversionPass.mlir
 delete mode 100644 ir_dump/13_FlyGpuStreamInjectPass.mlir
 delete mode 100644 ir_dump/14_ArithToLLVMConversionPass.mlir
 delete mode 100644 ir_dump/15_ConvertFuncToLLVMPass.mlir
 delete mode 100644 ir_dump/16_ReconcileUnrealizedCastsPass.mlir
 delete mode 100644 ir_dump/17_GpuModuleToBinaryPass.mlir
 delete mode 120000 python/flydsl/_mlir

diff --git a/after_build.sh b/after_build.sh
deleted file mode 100644
index ffc921a0..00000000
--- a/after_build.sh
+++ /dev/null
@@ -1,10 +0,0 @@
-#export MLIR_PATH=/home/jli10004/flydsl/llvm-project/buildmlir
-#export PYTHONPATH=/data/jli/flydsl-ws/flydsl-prev/build/python_packages/:/data/jli/flydsl-ws/flydsl-prev/flydsl_/src:/data/jli/flydsl-ws/flydsl-prev:$PYTHONPATH
-#export PATH=/data/jli/flydsl-ws/flydsl-prev/build/bin:$PATH
-
-MLIR_INSTALL=/home/jli10004/flydsl/mi450-cmodel-env/llvm-project/mlir_install
-MLIR_PATH=/home/jli10004/flydsl/mi450-cmodel-env/llvm-project/build-flydsl
-
-export PYTHONPATH=/home/jli10004/flydsl/mi450-cmodel-env/flydsl-prev/build-fly/python_packages:/home/jli10004/flydsl/mi450-cmodel-env/llvm-project/build-flydsl/tools/mlir/python_packages/mlir_core:$PYTHONPATH
-export PYTHONPATH=/home/jli10004/flydsl/mi450-cmodel-env/flydsl-prev/python:$PYTHONPATH
-export LD_LIBRARY_PATH=$MLIR_INSTALL/lib:$LD_LIBRARY_PATH
diff --git a/build_flydsl.sh b/build_flydsl.sh
deleted file mode 100755
index 74ba47f3..00000000
--- a/build_flydsl.sh
+++ /dev/null
@@ -1,5 +0,0 @@
-export MLIR_PATH=/home/jli10004/flydsl/mi450-cmodel-env/llvm-project/mlir_install
-mkdir -p build-fly && cd build-fly
-cmake .. -DMLIR_DIR=/home/jli10004/flydsl/mi450-cmodel-env/llvm-project/mlir_install/lib/cmake/mlir -GNinja
-NPROC=$(nproc 2>/dev/null || echo 4)
-ninja -j${NPROC}
diff --git a/dump_ir.sh b/dump_ir.sh
deleted file mode 100755
index ce5e1c92..00000000
--- a/dump_ir.sh
+++ /dev/null
@@ -1,56 +0,0 @@
-#!/usr/bin/bash
-# Usage: ./dump_ir.sh <example.py> [output_dir]
-#
-# Runs the example with IR printing enabled, then splits each pass's
-# IR dump into a numbered file under the output dir (default: ./ir_dump).
-set -euo pipefail
-
-EXAMPLE="${1:?Usage: $0 <example.py> [output_dir]}"
-OUTDIR="${2:-./ir_dump}"
-
-rm -rf "$OUTDIR" ~/.flydsl/cache/
-mkdir -p "$OUTDIR"
-
-cd /home/jli10004/flydsl/flydsl-prev
-source after_build.sh 2>/dev/null || true
-
-export HIP_VISIBLE_DEVICES=0
-export FLYDSL_DEBUG_PRINT_ORIGIN_IR=1
-export FLYDSL_DEBUG_PRINT_AFTER_ALL=1
-export FLYDSL_DEBUG_LOG_TO_CONSOLE=1
-export FLYDSL_DEBUG_LOG_LEVEL=INFO
-
-python "$EXAMPLE" >"$OUTDIR/_raw.txt" 2>&1
-
-python3 -c "
-import re, sys, os
-
-outdir = '$OUTDIR'
-with open(os.path.join(outdir, '_raw.txt')) as f:
-    text = f.read()
-
-sections = []
-
-# Origin IR
-m = re.search(r'Origin IR:\s*\n(module\b.*?)(?=\n// -----// IR Dump|\Z)', text, re.DOTALL)
-if m:
-    sections.append(('origin_ir', m.group(1).rstrip()))
-
-# Per-pass IR
-marker = re.compile(r'^// -----// IR Dump After (\S+)(?: \(.*?\))? //----- //\$', re.MULTILINE)
-hits = list(marker.finditer(text))
-for i, h in enumerate(hits):
-    end = hits[i+1].start() if i+1 < len(hits) else len(text)
-    body = text[h.end()+1:end].rstrip()
-    sections.append((h.group(1), body))
-
-for seq, (name, body) in enumerate(sections):
-    safe = re.sub(r'[^\w-]', '_', name)
-    fn = f'{seq:02d}_{safe}.mlir'
-    with open(os.path.join(outdir, fn), 'w') as f:
-        f.write(body + '\n')
-    print(f'  {fn}')
-
-os.remove(os.path.join(outdir, '_raw.txt'))
-print(f'\n{len(sections)} IR files written to {outdir}/')
-"
diff --git a/env.sh b/env.sh
deleted file mode 100644
index 4fbd1ccb..00000000
--- a/env.sh
+++ /dev/null
@@ -1,5 +0,0 @@
-export MLIR_PATH=/data/jli/flydsl-ws/llvm-project/build-flydsl/mlir_install
-export PYTHONPATH=/data/jli/flydsl-ws/FlyDSL/.flir/build/python_packages/flydsl:/data/jli/flydsl-ws/FlyDSL/flydsl/src:/data/jli/flydsl-ws/FlyDSL:$PYTHONPATH
-export PATH=/data/jli/flydsl-ws/FlyDSL/.flir/build/bin:$PATH
-export SHOW_IR=1
-export PATH=/data/jli/flydsl-ws/llvm-project/build-flydsl/mlir_install/bin:$PATH
diff --git a/hsakmt_counters.csv b/hsakmt_counters.csv
deleted file mode 100644
index e2ef75da..00000000
--- a/hsakmt_counters.csv
+++ /dev/null
@@ -1,20 +0,0 @@
-hsakmtmodel.executor.0.jitcu.num_instr_analyzed,hsakmtmodel.executor.0.jitcu.num_instr_executed,hsakmtmodel.executor.0.jitcu.num_rts_primcache_hits,hsakmtmodel.executor.0.jitcu.num_rts_primcache_misses,hsakmtmodel.executor.0.jitcu.sq_perf_sel_insts_flat,hsakmtmodel.executor.0.jitcu.sq_perf_sel_insts_global_scratch_load,hsakmtmodel.executor.0.jitcu.sq_perf_sel_insts_global_scratch_store,hsakmtmodel.executor.0.jitcu.sq_perf_sel_insts_lds,hsakmtmodel.executor.0.jitcu.sq_perf_sel_insts_salu,hsakmtmodel.executor.0.jitcu.sq_perf_sel_insts_smem,hsakmtmodel.executor.0.jitcu.sq_perf_sel_insts_tex,hsakmtmodel.executor.0.jitcu.sq_perf_sel_insts_valu,hsakmtmodel.executor.0.jitcu.sq_perf_sel_insts_valu_xdlmacc,hsakmtmodel.executor.0.jitcu.sq_perf_sel_insts_waves,hsakmtmodel.executor.0.jitcu.sq_perf_sel_insts_waves_created,hsakmtmodel.executor.0.jitcu.sq_perf_sel_insts_waves_finished,hsakmtmodel.jitcu.num_instr_analyzed,hsakmtmodel.jitcu.num_instr_executed,hsakmtmodel.jitcu.num_rts_primcache_hits,hsakmtmodel.jitcu.num_rts_primcache_misses,hsakmtmodel.jitcu.sq_perf_sel_insts_flat,hsakmtmodel.jitcu.sq_perf_sel_insts_global_scratch_load,hsakmtmodel.jitcu.sq_perf_sel_insts_global_scratch_store,hsakmtmodel.jitcu.sq_perf_sel_insts_lds,hsakmtmodel.jitcu.sq_perf_sel_insts_salu,hsakmtmodel.jitcu.sq_perf_sel_insts_smem,hsakmtmodel.jitcu.sq_perf_sel_insts_tex,hsakmtmodel.jitcu.sq_perf_sel_insts_valu,hsakmtmodel.jitcu.sq_perf_sel_insts_valu_xdlmacc,hsakmtmodel.jitcu.sq_perf_sel_insts_waves,hsakmtmodel.jitcu.sq_perf_sel_insts_waves_created,hsakmtmodel.jitcu.sq_perf_sel_insts_waves_finished,
-21248,21248,0,0,0,0,0,512,2176,2304,512,13184,0,384,384,384,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
-37376,37376,0,0,0,0,0,3584,2944,2304,3584,20864,0,384,384,384,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
-37376,37376,0,0,0,0,0,3584,2944,2304,3584,20864,0,384,384,384,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
-57426,57426,0,0,0,0,0,7510,3371,2304,7510,30674,0,384,384,384,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
-33792,32768,0,0,0,0,1024,0,19456,1024,1024,3072,0,1024,1024,1024,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
-33792,32768,0,0,0,0,1024,0,19456,1024,1024,3072,0,1024,1024,1024,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
-50176,50176,0,0,0,1024,1024,0,14336,2048,2048,14336,0,1024,10224668160,23365632,0,0,0,0,516096,0,2393088,21504,516096,13854720,0,132096,3072,3072,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
-93825,93825,0,0,0,0,0,14896,2840,2304,14896,49139,0,384,384,384,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
-3806208,3677184,0,027584,27584,0,0,0,0,0,1664,2752,2304,1664,16064,0,384,384,384,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
-117184,114704,0,0,0,4104,4008,20480,264,24,12208,31464,8192,136,8,8,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
-37376,37376,0,0,0,0,0,3584,2944,2304,3584,20864,0,384,384,384,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
-37376,37376,0,0,0,0,0,3584,2944,2304,3584,20864,0,384,384,384,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
-27904,27904,0,0,0,0,0,1792,2432,2304,1792,16384,0,384,384,384,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
-0,
-10716,10548,0,0,0,92,72,0,4584,60,164,1900,0,20,20,20,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
-3976,3952,0,0,0,144,72,0,628,352,216,1396,0,20,20,20,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
-24643584,24643584,0,0,0,516096,516096,0,3870720,2451456,1032192,8386560,0,129024,129024,129024,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
-21248,21248,0,0,0,0,0,512,2176,2304,512,13184,0,384,384,384,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
-45158400,44126208,0,0,0,516096,516096,0,11612160,1935360,1032192,16128000,0,129024,129024,129024,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
diff --git a/ir_dump/00_origin_ir.mlir b/ir_dump/00_origin_ir.mlir
deleted file mode 100644
index c4596836..00000000
--- a/ir_dump/00_origin_ir.mlir
+++ /dev/null
@@ -1,83 +0,0 @@
 [83 deleted lines: fly-dialect IR of the vectorAdd example before any passes; dump body omitted]
diff --git a/ir_dump/01_GpuKernelOutliningPass.mlir b/ir_dump/01_GpuKernelOutliningPass.mlir
deleted file mode 100644
index c4596836..00000000
--- a/ir_dump/01_GpuKernelOutliningPass.mlir
+++ /dev/null
@@ -1,83 +0,0 @@
 [83 deleted lines: identical blob to 00_origin_ir.mlir (same index hash c4596836); body omitted]
diff --git a/ir_dump/02_FlyCanonicalizePass.mlir b/ir_dump/02_FlyCanonicalizePass.mlir
deleted file mode 100644
index b61fd81d..00000000
--- a/ir_dump/02_FlyCanonicalizePass.mlir
+++ /dev/null
@@ -1,79 +0,0 @@
 [79 deleted lines: vectorAdd IR after FlyCanonicalizePass; body omitted]
diff --git a/ir_dump/03_FlyLayoutLoweringPass.mlir b/ir_dump/03_FlyLayoutLoweringPass.mlir
deleted file mode 100644
index 5b828cf6..00000000
--- a/ir_dump/03_FlyLayoutLoweringPass.mlir
+++ /dev/null
@@ -1,78 +0,0 @@
 [78 deleted lines: vectorAdd IR after FlyLayoutLoweringPass; body omitted]
diff --git a/ir_dump/04_FlyToROCDLConversionPass.mlir b/ir_dump/04_FlyToROCDLConversionPass.mlir
deleted file mode 100644
index 57a4b418..00000000
--- a/ir_dump/04_FlyToROCDLConversionPass.mlir
+++ /dev/null
@@ -1,93 +0,0 @@
 [93 deleted lines: vectorAdd IR after FlyToROCDLConversionPass; body omitted]
diff --git a/ir_dump/05_Canonicalizer.mlir b/ir_dump/05_Canonicalizer.mlir
deleted file mode 100644
index 227a8222..00000000
--- a/ir_dump/05_Canonicalizer.mlir
+++ /dev/null
@@ -1,59 +0,0 @@
 [59 deleted lines: vectorAdd IR after Canonicalizer; body omitted]
diff --git a/ir_dump/06_SCFToControlFlowPass.mlir b/ir_dump/06_SCFToControlFlowPass.mlir
deleted file mode 100644
index 2dec6698..00000000
--- a/ir_dump/06_SCFToControlFlowPass.mlir
+++ /dev/null
@@ -1,39 +0,0 @@
 [39 deleted lines: vectorAdd gpu.module after SCFToControlFlowPass; body omitted]
diff --git a/ir_dump/07_ConvertGpuOpsToROCDLOps.mlir b/ir_dump/07_ConvertGpuOpsToROCDLOps.mlir
deleted file mode 100644
index 6e862f77..00000000
--- a/ir_dump/07_ConvertGpuOpsToROCDLOps.mlir
+++ /dev/null
@@ -1,35 +0,0 @@
 [35 deleted lines: vectorAdd gpu.module after ConvertGpuOpsToROCDLOps; body omitted]
diff --git a/ir_dump/08_GpuROCDLAttachTarget.mlir b/ir_dump/08_GpuROCDLAttachTarget.mlir
deleted file mode 100644
index 496e5e86..00000000
--- a/ir_dump/08_GpuROCDLAttachTarget.mlir
+++ /dev/null
@@ -1,55 +0,0 @@
 [55 deleted lines: vectorAdd module after GpuROCDLAttachTarget; body omitted]
blocks in (%6, %c1, %c1) threads in (%c64, %c1, %c1) args(%arg0 : !llvm.ptr<1>, %9 : !llvm.struct, struct)>, %arg2 : !llvm.ptr<1>, %arg3 : !llvm.ptr<1>) - return - } -} diff --git a/ir_dump/09_SCFToControlFlowPass.mlir b/ir_dump/09_SCFToControlFlowPass.mlir deleted file mode 100644 index 496e5e86..00000000 --- a/ir_dump/09_SCFToControlFlowPass.mlir +++ /dev/null @@ -1,55 +0,0 @@ -module attributes {gpu.container_module} { - gpu.module @kernels [#rocdl.target] attributes {llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9"} { - llvm.func @vectorAddKernel_0(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) attributes {gpu.kernel, rocdl.kernel} { - %0 = llvm.mlir.constant(4 : i64) : i64 - %1 = llvm.mlir.constant(1 : i64) : i64 - %2 = llvm.mlir.constant(64 : i32) : i32 - %3 = rocdl.workgroup.id.x : i32 - %4 = llvm.sext %3 : i32 to i64 - %5 = llvm.trunc %4 : i64 to i32 - %6 = rocdl.workitem.id.x : i32 - %7 = llvm.sext %6 : i32 to i64 - %8 = llvm.mul %5, %2 : i32 - %9 = llvm.sext %8 : i32 to i64 - %10 = llvm.getelementptr %arg0[%9] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - %11 = llvm.mul %5, %2 : i32 - %12 = llvm.sext %11 : i32 to i64 - %13 = llvm.getelementptr %arg2[%12] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - %14 = llvm.mul %5, %2 : i32 - %15 = llvm.sext %14 : i32 to i64 - %16 = llvm.getelementptr %arg3[%15] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - %17 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> - %18 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> - %19 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> - %20 = llvm.getelementptr %10[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - "llvm.intr.memcpy"(%17, %20, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () - %21 = llvm.getelementptr %13[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - "llvm.intr.memcpy"(%18, %21, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () - %22 = llvm.load %17 : !llvm.ptr<5> -> vector<1xf32> - %23 = llvm.load %18 : !llvm.ptr<5> -> vector<1xf32> - %24 = llvm.fadd %22, %23 : vector<1xf32> - llvm.store %24, %19 : vector<1xf32>, !llvm.ptr<5> - %25 = llvm.getelementptr %16[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - "llvm.intr.memcpy"(%25, %19, %0) <{isVolatile = false}> : (!llvm.ptr<1>, !llvm.ptr<5>, i64) -> () - llvm.return - } - } - func.func @vectorAdd(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>, %arg4: i32, %arg5: !gpu.async.token) attributes {llvm.emit_c_interface} { - %c63_i32 = arith.constant 63 : i32 - %0 = llvm.mlir.undef : !llvm.struct, struct)> - %1 = llvm.mlir.undef : !llvm.struct - %2 = llvm.mlir.undef : !llvm.struct - %c64_i32 = arith.constant 64 : i32 - %c1 = arith.constant 1 : index - %c64 = arith.constant 64 : index - %3 = llvm.extractvalue %arg1[0, 0] : !llvm.struct, struct)> - %4 = arith.addi %arg4, %c63_i32 : i32 - %5 = arith.floordivsi %4, %c64_i32 : i32 - %6 = arith.index_cast %5 : i32 to index - %7 = llvm.insertvalue %3, %2[0] : !llvm.struct - %8 = llvm.insertvalue %7, %0[0] : !llvm.struct, struct)> - %9 = llvm.insertvalue %1, %8[1] : !llvm.struct, struct)> - gpu.launch_func <%arg5 : !gpu.async.token> @kernels::@vectorAddKernel_0 blocks in (%6, %c1, %c1) threads in (%c64, %c1, %c1) args(%arg0 : !llvm.ptr<1>, %9 : !llvm.struct, struct)>, %arg2 : 
!llvm.ptr<1>, %arg3 : !llvm.ptr<1>) - return - } -} diff --git a/ir_dump/10_ConvertControlFlowToLLVMPass.mlir b/ir_dump/10_ConvertControlFlowToLLVMPass.mlir deleted file mode 100644 index 496e5e86..00000000 --- a/ir_dump/10_ConvertControlFlowToLLVMPass.mlir +++ /dev/null @@ -1,55 +0,0 @@ -module attributes {gpu.container_module} { - gpu.module @kernels [#rocdl.target] attributes {llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9"} { - llvm.func @vectorAddKernel_0(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) attributes {gpu.kernel, rocdl.kernel} { - %0 = llvm.mlir.constant(4 : i64) : i64 - %1 = llvm.mlir.constant(1 : i64) : i64 - %2 = llvm.mlir.constant(64 : i32) : i32 - %3 = rocdl.workgroup.id.x : i32 - %4 = llvm.sext %3 : i32 to i64 - %5 = llvm.trunc %4 : i64 to i32 - %6 = rocdl.workitem.id.x : i32 - %7 = llvm.sext %6 : i32 to i64 - %8 = llvm.mul %5, %2 : i32 - %9 = llvm.sext %8 : i32 to i64 - %10 = llvm.getelementptr %arg0[%9] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - %11 = llvm.mul %5, %2 : i32 - %12 = llvm.sext %11 : i32 to i64 - %13 = llvm.getelementptr %arg2[%12] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - %14 = llvm.mul %5, %2 : i32 - %15 = llvm.sext %14 : i32 to i64 - %16 = llvm.getelementptr %arg3[%15] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - %17 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> - %18 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> - %19 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> - %20 = llvm.getelementptr %10[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - "llvm.intr.memcpy"(%17, %20, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () - %21 = llvm.getelementptr %13[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - "llvm.intr.memcpy"(%18, %21, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () - %22 = llvm.load %17 : !llvm.ptr<5> -> vector<1xf32> - %23 = llvm.load %18 : !llvm.ptr<5> -> vector<1xf32> - %24 = llvm.fadd %22, %23 : vector<1xf32> - llvm.store %24, %19 : vector<1xf32>, !llvm.ptr<5> - %25 = llvm.getelementptr %16[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - "llvm.intr.memcpy"(%25, %19, %0) <{isVolatile = false}> : (!llvm.ptr<1>, !llvm.ptr<5>, i64) -> () - llvm.return - } - } - func.func @vectorAdd(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>, %arg4: i32, %arg5: !gpu.async.token) attributes {llvm.emit_c_interface} { - %c63_i32 = arith.constant 63 : i32 - %0 = llvm.mlir.undef : !llvm.struct, struct)> - %1 = llvm.mlir.undef : !llvm.struct - %2 = llvm.mlir.undef : !llvm.struct - %c64_i32 = arith.constant 64 : i32 - %c1 = arith.constant 1 : index - %c64 = arith.constant 64 : index - %3 = llvm.extractvalue %arg1[0, 0] : !llvm.struct, struct)> - %4 = arith.addi %arg4, %c63_i32 : i32 - %5 = arith.floordivsi %4, %c64_i32 : i32 - %6 = arith.index_cast %5 : i32 to index - %7 = llvm.insertvalue %3, %2[0] : !llvm.struct - %8 = llvm.insertvalue %7, %0[0] : !llvm.struct, struct)> - %9 = llvm.insertvalue %1, %8[1] : !llvm.struct, struct)> - gpu.launch_func <%arg5 : !gpu.async.token> @kernels::@vectorAddKernel_0 blocks in (%6, %c1, %c1) threads in (%c64, %c1, %c1) args(%arg0 : !llvm.ptr<1>, %9 : !llvm.struct, struct)>, %arg2 : !llvm.ptr<1>, %arg3 : !llvm.ptr<1>) - return - } -} diff --git a/ir_dump/11_FlyGpuStreamMarkPass.mlir 
b/ir_dump/11_FlyGpuStreamMarkPass.mlir deleted file mode 100644 index ea5f7edc..00000000 --- a/ir_dump/11_FlyGpuStreamMarkPass.mlir +++ /dev/null @@ -1,55 +0,0 @@ -module attributes {gpu.container_module} { - gpu.module @kernels [#rocdl.target] attributes {llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9"} { - llvm.func @vectorAddKernel_0(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) attributes {gpu.kernel, rocdl.kernel} { - %0 = llvm.mlir.constant(4 : i64) : i64 - %1 = llvm.mlir.constant(1 : i64) : i64 - %2 = llvm.mlir.constant(64 : i32) : i32 - %3 = rocdl.workgroup.id.x : i32 - %4 = llvm.sext %3 : i32 to i64 - %5 = llvm.trunc %4 : i64 to i32 - %6 = rocdl.workitem.id.x : i32 - %7 = llvm.sext %6 : i32 to i64 - %8 = llvm.mul %5, %2 : i32 - %9 = llvm.sext %8 : i32 to i64 - %10 = llvm.getelementptr %arg0[%9] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - %11 = llvm.mul %5, %2 : i32 - %12 = llvm.sext %11 : i32 to i64 - %13 = llvm.getelementptr %arg2[%12] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - %14 = llvm.mul %5, %2 : i32 - %15 = llvm.sext %14 : i32 to i64 - %16 = llvm.getelementptr %arg3[%15] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - %17 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> - %18 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> - %19 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> - %20 = llvm.getelementptr %10[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - "llvm.intr.memcpy"(%17, %20, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () - %21 = llvm.getelementptr %13[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - "llvm.intr.memcpy"(%18, %21, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () - %22 = llvm.load %17 : !llvm.ptr<5> -> vector<1xf32> - %23 = llvm.load %18 : !llvm.ptr<5> -> vector<1xf32> - %24 = llvm.fadd %22, %23 : vector<1xf32> - llvm.store %24, %19 : vector<1xf32>, !llvm.ptr<5> - %25 = llvm.getelementptr %16[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - "llvm.intr.memcpy"(%25, %19, %0) <{isVolatile = false}> : (!llvm.ptr<1>, !llvm.ptr<5>, i64) -> () - llvm.return - } - } - func.func @vectorAdd(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>, %arg4: i32, %arg5: !gpu.async.token) attributes {fly.stream_arg_index = 5 : index, llvm.emit_c_interface} { - %c63_i32 = arith.constant 63 : i32 - %0 = llvm.mlir.undef : !llvm.struct, struct)> - %1 = llvm.mlir.undef : !llvm.struct - %2 = llvm.mlir.undef : !llvm.struct - %c64_i32 = arith.constant 64 : i32 - %c1 = arith.constant 1 : index - %c64 = arith.constant 64 : index - %3 = llvm.extractvalue %arg1[0, 0] : !llvm.struct, struct)> - %4 = arith.addi %arg4, %c63_i32 : i32 - %5 = arith.floordivsi %4, %c64_i32 : i32 - %6 = arith.index_cast %5 : i32 to index - %7 = llvm.insertvalue %3, %2[0] : !llvm.struct - %8 = llvm.insertvalue %7, %0[0] : !llvm.struct, struct)> - %9 = llvm.insertvalue %1, %8[1] : !llvm.struct, struct)> - gpu.launch_func <%arg5 : !gpu.async.token> @kernels::@vectorAddKernel_0 blocks in (%6, %c1, %c1) threads in (%c64, %c1, %c1) args(%arg0 : !llvm.ptr<1>, %9 : !llvm.struct, struct)>, %arg2 : !llvm.ptr<1>, %arg3 : !llvm.ptr<1>) - return - } -} diff --git a/ir_dump/12_GpuToLLVMConversionPass.mlir b/ir_dump/12_GpuToLLVMConversionPass.mlir deleted file mode 100644 index 
ab9375d7..00000000 --- a/ir_dump/12_GpuToLLVMConversionPass.mlir +++ /dev/null @@ -1,65 +0,0 @@ -module attributes {gpu.container_module} { - gpu.module @kernels [#rocdl.target] attributes {llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9"} { - llvm.func @vectorAddKernel_0(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) attributes {gpu.kernel, rocdl.kernel} { - %0 = llvm.mlir.constant(4 : i64) : i64 - %1 = llvm.mlir.constant(1 : i64) : i64 - %2 = llvm.mlir.constant(64 : i32) : i32 - %3 = rocdl.workgroup.id.x : i32 - %4 = llvm.sext %3 : i32 to i64 - %5 = llvm.trunc %4 : i64 to i32 - %6 = rocdl.workitem.id.x : i32 - %7 = llvm.sext %6 : i32 to i64 - %8 = llvm.mul %5, %2 : i32 - %9 = llvm.sext %8 : i32 to i64 - %10 = llvm.getelementptr %arg0[%9] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - %11 = llvm.mul %5, %2 : i32 - %12 = llvm.sext %11 : i32 to i64 - %13 = llvm.getelementptr %arg2[%12] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - %14 = llvm.mul %5, %2 : i32 - %15 = llvm.sext %14 : i32 to i64 - %16 = llvm.getelementptr %arg3[%15] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - %17 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> - %18 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> - %19 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> - %20 = llvm.getelementptr %10[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - "llvm.intr.memcpy"(%17, %20, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () - %21 = llvm.getelementptr %13[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - "llvm.intr.memcpy"(%18, %21, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () - %22 = llvm.load %17 : !llvm.ptr<5> -> vector<1xf32> - %23 = llvm.load %18 : !llvm.ptr<5> -> vector<1xf32> - %24 = llvm.fadd %22, %23 : vector<1xf32> - llvm.store %24, %19 : vector<1xf32>, !llvm.ptr<5> - %25 = llvm.getelementptr %16[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - "llvm.intr.memcpy"(%25, %19, %0) <{isVolatile = false}> : (!llvm.ptr<1>, !llvm.ptr<5>, i64) -> () - llvm.return - } - } - llvm.func @vectorAdd(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>, %arg4: i32, %arg5: !llvm.ptr) attributes {fly.stream_arg_index = 5 : index, llvm.emit_c_interface} { - %0 = llvm.mlir.constant(63 : i32) : i32 - %1 = llvm.mlir.undef : !llvm.struct, struct)> - %2 = llvm.mlir.undef : !llvm.struct - %3 = llvm.mlir.undef : !llvm.struct - %4 = llvm.mlir.constant(64 : i32) : i32 - %5 = llvm.mlir.constant(1 : index) : i64 - %6 = llvm.mlir.constant(64 : index) : i64 - %7 = llvm.extractvalue %arg1[0, 0] : !llvm.struct, struct)> - %8 = llvm.add %arg4, %0 : i32 - %9 = llvm.sdiv %8, %4 : i32 - %10 = llvm.mul %9, %4 : i32 - %11 = llvm.icmp "ne" %8, %10 : i32 - %12 = llvm.mlir.constant(0 : i32) : i32 - %13 = llvm.icmp "slt" %8, %12 : i32 - %14 = llvm.mlir.constant(false) : i1 - %15 = llvm.icmp "ne" %13, %14 : i1 - %16 = llvm.and %11, %15 : i1 - %17 = llvm.mlir.constant(-1 : i32) : i32 - %18 = llvm.add %9, %17 : i32 - %19 = llvm.select %16, %18, %9 : i1, i32 - %20 = llvm.sext %19 : i32 to i64 - %21 = llvm.insertvalue %7, %3[0] : !llvm.struct - %22 = llvm.insertvalue %21, %1[0] : !llvm.struct, struct)> - %23 = llvm.insertvalue %2, %22[1] : !llvm.struct, struct)> - gpu.launch_func @kernels::@vectorAddKernel_0 blocks in (%20, %5, %5) threads in 
(%6, %5, %5) : i64 args(%arg0 : !llvm.ptr<1>, %23 : !llvm.struct, struct)>, %arg2 : !llvm.ptr<1>, %arg3 : !llvm.ptr<1>) - llvm.return - } -} diff --git a/ir_dump/13_FlyGpuStreamInjectPass.mlir b/ir_dump/13_FlyGpuStreamInjectPass.mlir deleted file mode 100644 index 0144d3db..00000000 --- a/ir_dump/13_FlyGpuStreamInjectPass.mlir +++ /dev/null @@ -1,65 +0,0 @@ -module attributes {gpu.container_module} { - gpu.module @kernels [#rocdl.target] attributes {llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9"} { - llvm.func @vectorAddKernel_0(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) attributes {gpu.kernel, rocdl.kernel} { - %0 = llvm.mlir.constant(4 : i64) : i64 - %1 = llvm.mlir.constant(1 : i64) : i64 - %2 = llvm.mlir.constant(64 : i32) : i32 - %3 = rocdl.workgroup.id.x : i32 - %4 = llvm.sext %3 : i32 to i64 - %5 = llvm.trunc %4 : i64 to i32 - %6 = rocdl.workitem.id.x : i32 - %7 = llvm.sext %6 : i32 to i64 - %8 = llvm.mul %5, %2 : i32 - %9 = llvm.sext %8 : i32 to i64 - %10 = llvm.getelementptr %arg0[%9] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - %11 = llvm.mul %5, %2 : i32 - %12 = llvm.sext %11 : i32 to i64 - %13 = llvm.getelementptr %arg2[%12] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - %14 = llvm.mul %5, %2 : i32 - %15 = llvm.sext %14 : i32 to i64 - %16 = llvm.getelementptr %arg3[%15] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - %17 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> - %18 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> - %19 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> - %20 = llvm.getelementptr %10[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - "llvm.intr.memcpy"(%17, %20, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () - %21 = llvm.getelementptr %13[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - "llvm.intr.memcpy"(%18, %21, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () - %22 = llvm.load %17 : !llvm.ptr<5> -> vector<1xf32> - %23 = llvm.load %18 : !llvm.ptr<5> -> vector<1xf32> - %24 = llvm.fadd %22, %23 : vector<1xf32> - llvm.store %24, %19 : vector<1xf32>, !llvm.ptr<5> - %25 = llvm.getelementptr %16[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - "llvm.intr.memcpy"(%25, %19, %0) <{isVolatile = false}> : (!llvm.ptr<1>, !llvm.ptr<5>, i64) -> () - llvm.return - } - } - llvm.func @vectorAdd(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>, %arg4: i32, %arg5: !llvm.ptr) attributes {llvm.emit_c_interface} { - %0 = llvm.mlir.constant(63 : i32) : i32 - %1 = llvm.mlir.undef : !llvm.struct, struct)> - %2 = llvm.mlir.undef : !llvm.struct - %3 = llvm.mlir.undef : !llvm.struct - %4 = llvm.mlir.constant(64 : i32) : i32 - %5 = llvm.mlir.constant(1 : index) : i64 - %6 = llvm.mlir.constant(64 : index) : i64 - %7 = llvm.extractvalue %arg1[0, 0] : !llvm.struct, struct)> - %8 = llvm.add %arg4, %0 : i32 - %9 = llvm.sdiv %8, %4 : i32 - %10 = llvm.mul %9, %4 : i32 - %11 = llvm.icmp "ne" %8, %10 : i32 - %12 = llvm.mlir.constant(0 : i32) : i32 - %13 = llvm.icmp "slt" %8, %12 : i32 - %14 = llvm.mlir.constant(false) : i1 - %15 = llvm.icmp "ne" %13, %14 : i1 - %16 = llvm.and %11, %15 : i1 - %17 = llvm.mlir.constant(-1 : i32) : i32 - %18 = llvm.add %9, %17 : i32 - %19 = llvm.select %16, %18, %9 : i1, i32 - %20 = llvm.sext %19 : i32 to i64 - %21 = 
llvm.insertvalue %7, %3[0] : !llvm.struct - %22 = llvm.insertvalue %21, %1[0] : !llvm.struct, struct)> - %23 = llvm.insertvalue %2, %22[1] : !llvm.struct, struct)> - gpu.launch_func <%arg5 : !llvm.ptr> @kernels::@vectorAddKernel_0 blocks in (%20, %5, %5) threads in (%6, %5, %5) : i64 args(%arg0 : !llvm.ptr<1>, %23 : !llvm.struct, struct)>, %arg2 : !llvm.ptr<1>, %arg3 : !llvm.ptr<1>) - llvm.return - } -} diff --git a/ir_dump/14_ArithToLLVMConversionPass.mlir b/ir_dump/14_ArithToLLVMConversionPass.mlir deleted file mode 100644 index 0144d3db..00000000 --- a/ir_dump/14_ArithToLLVMConversionPass.mlir +++ /dev/null @@ -1,65 +0,0 @@ -module attributes {gpu.container_module} { - gpu.module @kernels [#rocdl.target] attributes {llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9"} { - llvm.func @vectorAddKernel_0(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) attributes {gpu.kernel, rocdl.kernel} { - %0 = llvm.mlir.constant(4 : i64) : i64 - %1 = llvm.mlir.constant(1 : i64) : i64 - %2 = llvm.mlir.constant(64 : i32) : i32 - %3 = rocdl.workgroup.id.x : i32 - %4 = llvm.sext %3 : i32 to i64 - %5 = llvm.trunc %4 : i64 to i32 - %6 = rocdl.workitem.id.x : i32 - %7 = llvm.sext %6 : i32 to i64 - %8 = llvm.mul %5, %2 : i32 - %9 = llvm.sext %8 : i32 to i64 - %10 = llvm.getelementptr %arg0[%9] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - %11 = llvm.mul %5, %2 : i32 - %12 = llvm.sext %11 : i32 to i64 - %13 = llvm.getelementptr %arg2[%12] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - %14 = llvm.mul %5, %2 : i32 - %15 = llvm.sext %14 : i32 to i64 - %16 = llvm.getelementptr %arg3[%15] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - %17 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> - %18 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> - %19 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> - %20 = llvm.getelementptr %10[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - "llvm.intr.memcpy"(%17, %20, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () - %21 = llvm.getelementptr %13[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - "llvm.intr.memcpy"(%18, %21, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () - %22 = llvm.load %17 : !llvm.ptr<5> -> vector<1xf32> - %23 = llvm.load %18 : !llvm.ptr<5> -> vector<1xf32> - %24 = llvm.fadd %22, %23 : vector<1xf32> - llvm.store %24, %19 : vector<1xf32>, !llvm.ptr<5> - %25 = llvm.getelementptr %16[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - "llvm.intr.memcpy"(%25, %19, %0) <{isVolatile = false}> : (!llvm.ptr<1>, !llvm.ptr<5>, i64) -> () - llvm.return - } - } - llvm.func @vectorAdd(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>, %arg4: i32, %arg5: !llvm.ptr) attributes {llvm.emit_c_interface} { - %0 = llvm.mlir.constant(63 : i32) : i32 - %1 = llvm.mlir.undef : !llvm.struct, struct)> - %2 = llvm.mlir.undef : !llvm.struct - %3 = llvm.mlir.undef : !llvm.struct - %4 = llvm.mlir.constant(64 : i32) : i32 - %5 = llvm.mlir.constant(1 : index) : i64 - %6 = llvm.mlir.constant(64 : index) : i64 - %7 = llvm.extractvalue %arg1[0, 0] : !llvm.struct, struct)> - %8 = llvm.add %arg4, %0 : i32 - %9 = llvm.sdiv %8, %4 : i32 - %10 = llvm.mul %9, %4 : i32 - %11 = llvm.icmp "ne" %8, %10 : i32 - %12 = llvm.mlir.constant(0 : i32) : i32 - %13 = llvm.icmp "slt" %8, %12 : i32 - %14 = 
llvm.mlir.constant(false) : i1 - %15 = llvm.icmp "ne" %13, %14 : i1 - %16 = llvm.and %11, %15 : i1 - %17 = llvm.mlir.constant(-1 : i32) : i32 - %18 = llvm.add %9, %17 : i32 - %19 = llvm.select %16, %18, %9 : i1, i32 - %20 = llvm.sext %19 : i32 to i64 - %21 = llvm.insertvalue %7, %3[0] : !llvm.struct - %22 = llvm.insertvalue %21, %1[0] : !llvm.struct, struct)> - %23 = llvm.insertvalue %2, %22[1] : !llvm.struct, struct)> - gpu.launch_func <%arg5 : !llvm.ptr> @kernels::@vectorAddKernel_0 blocks in (%20, %5, %5) threads in (%6, %5, %5) : i64 args(%arg0 : !llvm.ptr<1>, %23 : !llvm.struct, struct)>, %arg2 : !llvm.ptr<1>, %arg3 : !llvm.ptr<1>) - llvm.return - } -} diff --git a/ir_dump/15_ConvertFuncToLLVMPass.mlir b/ir_dump/15_ConvertFuncToLLVMPass.mlir deleted file mode 100644 index 0144d3db..00000000 --- a/ir_dump/15_ConvertFuncToLLVMPass.mlir +++ /dev/null @@ -1,65 +0,0 @@ -module attributes {gpu.container_module} { - gpu.module @kernels [#rocdl.target] attributes {llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9"} { - llvm.func @vectorAddKernel_0(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) attributes {gpu.kernel, rocdl.kernel} { - %0 = llvm.mlir.constant(4 : i64) : i64 - %1 = llvm.mlir.constant(1 : i64) : i64 - %2 = llvm.mlir.constant(64 : i32) : i32 - %3 = rocdl.workgroup.id.x : i32 - %4 = llvm.sext %3 : i32 to i64 - %5 = llvm.trunc %4 : i64 to i32 - %6 = rocdl.workitem.id.x : i32 - %7 = llvm.sext %6 : i32 to i64 - %8 = llvm.mul %5, %2 : i32 - %9 = llvm.sext %8 : i32 to i64 - %10 = llvm.getelementptr %arg0[%9] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - %11 = llvm.mul %5, %2 : i32 - %12 = llvm.sext %11 : i32 to i64 - %13 = llvm.getelementptr %arg2[%12] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - %14 = llvm.mul %5, %2 : i32 - %15 = llvm.sext %14 : i32 to i64 - %16 = llvm.getelementptr %arg3[%15] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - %17 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> - %18 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> - %19 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> - %20 = llvm.getelementptr %10[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - "llvm.intr.memcpy"(%17, %20, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () - %21 = llvm.getelementptr %13[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - "llvm.intr.memcpy"(%18, %21, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () - %22 = llvm.load %17 : !llvm.ptr<5> -> vector<1xf32> - %23 = llvm.load %18 : !llvm.ptr<5> -> vector<1xf32> - %24 = llvm.fadd %22, %23 : vector<1xf32> - llvm.store %24, %19 : vector<1xf32>, !llvm.ptr<5> - %25 = llvm.getelementptr %16[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - "llvm.intr.memcpy"(%25, %19, %0) <{isVolatile = false}> : (!llvm.ptr<1>, !llvm.ptr<5>, i64) -> () - llvm.return - } - } - llvm.func @vectorAdd(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>, %arg4: i32, %arg5: !llvm.ptr) attributes {llvm.emit_c_interface} { - %0 = llvm.mlir.constant(63 : i32) : i32 - %1 = llvm.mlir.undef : !llvm.struct, struct)> - %2 = llvm.mlir.undef : !llvm.struct - %3 = llvm.mlir.undef : !llvm.struct - %4 = llvm.mlir.constant(64 : i32) : i32 - %5 = llvm.mlir.constant(1 : index) : i64 - %6 = llvm.mlir.constant(64 : index) : i64 - %7 = llvm.extractvalue 
%arg1[0, 0] : !llvm.struct, struct)> - %8 = llvm.add %arg4, %0 : i32 - %9 = llvm.sdiv %8, %4 : i32 - %10 = llvm.mul %9, %4 : i32 - %11 = llvm.icmp "ne" %8, %10 : i32 - %12 = llvm.mlir.constant(0 : i32) : i32 - %13 = llvm.icmp "slt" %8, %12 : i32 - %14 = llvm.mlir.constant(false) : i1 - %15 = llvm.icmp "ne" %13, %14 : i1 - %16 = llvm.and %11, %15 : i1 - %17 = llvm.mlir.constant(-1 : i32) : i32 - %18 = llvm.add %9, %17 : i32 - %19 = llvm.select %16, %18, %9 : i1, i32 - %20 = llvm.sext %19 : i32 to i64 - %21 = llvm.insertvalue %7, %3[0] : !llvm.struct - %22 = llvm.insertvalue %21, %1[0] : !llvm.struct, struct)> - %23 = llvm.insertvalue %2, %22[1] : !llvm.struct, struct)> - gpu.launch_func <%arg5 : !llvm.ptr> @kernels::@vectorAddKernel_0 blocks in (%20, %5, %5) threads in (%6, %5, %5) : i64 args(%arg0 : !llvm.ptr<1>, %23 : !llvm.struct, struct)>, %arg2 : !llvm.ptr<1>, %arg3 : !llvm.ptr<1>) - llvm.return - } -} diff --git a/ir_dump/16_ReconcileUnrealizedCastsPass.mlir b/ir_dump/16_ReconcileUnrealizedCastsPass.mlir deleted file mode 100644 index 0144d3db..00000000 --- a/ir_dump/16_ReconcileUnrealizedCastsPass.mlir +++ /dev/null @@ -1,65 +0,0 @@ -module attributes {gpu.container_module} { - gpu.module @kernels [#rocdl.target] attributes {llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9"} { - llvm.func @vectorAddKernel_0(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) attributes {gpu.kernel, rocdl.kernel} { - %0 = llvm.mlir.constant(4 : i64) : i64 - %1 = llvm.mlir.constant(1 : i64) : i64 - %2 = llvm.mlir.constant(64 : i32) : i32 - %3 = rocdl.workgroup.id.x : i32 - %4 = llvm.sext %3 : i32 to i64 - %5 = llvm.trunc %4 : i64 to i32 - %6 = rocdl.workitem.id.x : i32 - %7 = llvm.sext %6 : i32 to i64 - %8 = llvm.mul %5, %2 : i32 - %9 = llvm.sext %8 : i32 to i64 - %10 = llvm.getelementptr %arg0[%9] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - %11 = llvm.mul %5, %2 : i32 - %12 = llvm.sext %11 : i32 to i64 - %13 = llvm.getelementptr %arg2[%12] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - %14 = llvm.mul %5, %2 : i32 - %15 = llvm.sext %14 : i32 to i64 - %16 = llvm.getelementptr %arg3[%15] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - %17 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> - %18 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> - %19 = llvm.alloca %1 x f32 : (i64) -> !llvm.ptr<5> - %20 = llvm.getelementptr %10[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - "llvm.intr.memcpy"(%17, %20, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () - %21 = llvm.getelementptr %13[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - "llvm.intr.memcpy"(%18, %21, %0) <{isVolatile = false}> : (!llvm.ptr<5>, !llvm.ptr<1>, i64) -> () - %22 = llvm.load %17 : !llvm.ptr<5> -> vector<1xf32> - %23 = llvm.load %18 : !llvm.ptr<5> -> vector<1xf32> - %24 = llvm.fadd %22, %23 : vector<1xf32> - llvm.store %24, %19 : vector<1xf32>, !llvm.ptr<5> - %25 = llvm.getelementptr %16[%7] : (!llvm.ptr<1>, i64) -> !llvm.ptr<1>, f32 - "llvm.intr.memcpy"(%25, %19, %0) <{isVolatile = false}> : (!llvm.ptr<1>, !llvm.ptr<5>, i64) -> () - llvm.return - } - } - llvm.func @vectorAdd(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>, %arg4: i32, %arg5: !llvm.ptr) attributes {llvm.emit_c_interface} { - %0 = llvm.mlir.constant(63 : i32) : i32 - %1 = 
llvm.mlir.undef : !llvm.struct, struct)> - %2 = llvm.mlir.undef : !llvm.struct - %3 = llvm.mlir.undef : !llvm.struct - %4 = llvm.mlir.constant(64 : i32) : i32 - %5 = llvm.mlir.constant(1 : index) : i64 - %6 = llvm.mlir.constant(64 : index) : i64 - %7 = llvm.extractvalue %arg1[0, 0] : !llvm.struct, struct)> - %8 = llvm.add %arg4, %0 : i32 - %9 = llvm.sdiv %8, %4 : i32 - %10 = llvm.mul %9, %4 : i32 - %11 = llvm.icmp "ne" %8, %10 : i32 - %12 = llvm.mlir.constant(0 : i32) : i32 - %13 = llvm.icmp "slt" %8, %12 : i32 - %14 = llvm.mlir.constant(false) : i1 - %15 = llvm.icmp "ne" %13, %14 : i1 - %16 = llvm.and %11, %15 : i1 - %17 = llvm.mlir.constant(-1 : i32) : i32 - %18 = llvm.add %9, %17 : i32 - %19 = llvm.select %16, %18, %9 : i1, i32 - %20 = llvm.sext %19 : i32 to i64 - %21 = llvm.insertvalue %7, %3[0] : !llvm.struct - %22 = llvm.insertvalue %21, %1[0] : !llvm.struct, struct)> - %23 = llvm.insertvalue %2, %22[1] : !llvm.struct, struct)> - gpu.launch_func <%arg5 : !llvm.ptr> @kernels::@vectorAddKernel_0 blocks in (%20, %5, %5) threads in (%6, %5, %5) : i64 args(%arg0 : !llvm.ptr<1>, %23 : !llvm.struct, struct)>, %arg2 : !llvm.ptr<1>, %arg3 : !llvm.ptr<1>) - llvm.return - } -} diff --git a/ir_dump/17_GpuModuleToBinaryPass.mlir b/ir_dump/17_GpuModuleToBinaryPass.mlir deleted file mode 100644 index 354e120d..00000000 --- a/ir_dump/17_GpuModuleToBinaryPass.mlir +++ /dev/null @@ -1,44 +0,0 @@ -module attributes {gpu.container_module} { - gpu.binary @kernels [#gpu.object<#rocdl.target, kernels = <[#gpu.kernel_metadata<"vectorAddKernel_0", !llvm.func, struct, struct)>, ptr<1>, ptr<1>)>, metadata = {agpr_count = 0 : i64, group_segment_fixed_size = 0 : i64, max_flat_workgroup_size = 256 : i64, private_segment_fixed_size = 0 : i64, reqd_workgroup_size = array, sgpr_count = 16 : i64, sgpr_spill_count = 0 : i64, vgpr_count = 3 : i64, vgpr_spill_count = 0 : i64, wavefront_size = 64 : i64, workgroup_size_hint = array}>]>, bin = "\7FELF\02\01\01@\04\00\00\00\00\00\00\00\03\00\E0\00\01\00\00\00\00\00\00\00\00\00\00\00@\00\00\00\00\00\00\00\A8\0F\00\00\00\00\00\00L\05\00\00@\008\00\08\00@\00\0F\00\0D\00\06\00\00\00\04\00\00\00@\00\00\00\00\00\00\00@\00\00\00\00\00\00\00@\00\00\00\00\00\00\00\C0\01\00\00\00\00\00\00\C0\01\00\00\00\00\00\00\08\00\00\00\00\00\00\00\01\00\00\00\04\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\C0\05\00\00\00\00\00\00\C0\05\00\00\00\00\00\00\00\10\00\00\00\00\00\00\01\00\00\00\05\00\00\00\00\06\00\00\00\00\00\00\00\16\00\00\00\00\00\00\00\16\00\00\00\00\00\00\80\04\00\00\00\00\00\00\80\04\00\00\00\00\00\00\00\10\00\00\00\00\00\00\01\00\00\00\06\00\00\00\80\0A\00\00\00\00\00\00\80*\00\00\00\00\00\00\80*\00\00\00\00\00\00p\00\00\00\00\00\00\00\80\05\00\00\00\00\00\00\00\10\00\00\00\00\00\00\02\00\00\00\06\00\00\00\80\0A\00\00\00\00\00\00\80*\00\00\00\00\00\00\80*\00\00\00\00\00\00p\00\00\00\00\00\00\00p\00\00\00\00\00\00\00\08\00\00\00\00\00\00\00R\E5td\04\00\00\00\80\0A\00\00\00\00\00\00\80*\00\00\00\00\00\00\80*\00\00\00\00\00\00p\00\00\00\00\00\00\00\80\05\00\00\00\00\00\00\01\00\00\00\00\00\00\00Q\E5td\06\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\04\00\00\00\04\00\00\00\00\02\00\00\00\00\00\00\00\02\00\00\00\00\00\00\00\02\00\00\00\00\00\00\98\02\00\00\00\00\00\00\98\02\00\00\00\00\00\00\04\00\00\00\00\00\00\00\07\00\00\00\83\02\00\00 
\00\00\00AMDGPU\00\00\83\AEamdhsa.kernels\91\DE\00\10\AB.agpr_count\00\A5.args\94\84\AE.address_space\A6global\A7.offset\00\A5.size\08\AB.value_kind\ADglobal_buffer\83\A7.offset\08\A5.size\04\AB.value_kind\A8by_value\84\AE.address_space\A6global\A7.offset\10\A5.size\08\AB.value_kind\ADglobal_buffer\84\AE.address_space\A6global\A7.offset\18\A5.size\08\AB.value_kind\ADglobal_buffer\B9.group_segment_fixed_size\00\B6.kernarg_segment_align\08\B5.kernarg_segment_size \B8.max_flat_workgroup_size\CD\01\00\A5.name\B1vectorAddKernel_0\BB.private_segment_fixed_size\00\AB.sgpr_count\10\B1.sgpr_spill_count\00\A7.symbol\B4vectorAddKernel_0.kd\B8.uniform_work_group_size\01\B3.uses_dynamic_stack\C2\AB.vgpr_count\03\B1.vgpr_spill_count\00\AF.wavefront_size@\ADamdhsa.target\B9amdgcn-amd-amdhsa--gfx942\AEamdhsa.version\92\01\02\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\12\03\07\00\00\16\00\00\00\00\00\00`\00\00\00\00\00\00\00\13\00\00\00\11\00\06\00\80\05\00\00\00\00\00\00@\00\00\00\00\00\00\00\01\00\00\00\01\00\00\00\01\00\00\00\1A\00\00\00\00\00\00\00\00\C0\02\00\01\00\00\00\B0\CA%\C6\EFj+\BF\03\00\00\00\03\00\00\00\01\00\00\00\02\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00vectorAddKernel_0\00vectorAddKernel_0.kd\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00 \00\00\00\00\00\00\00\80\10\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00@\00\AF\00\84\00\00\00\08\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\02\06\C0\00\00\00\00\00\01\0A\C0\10\00\00\00\02\86\00\8E\00\9F\01\90\00\82\80\8E\7F\C0\8C\BF\08\00\02\80\09\01\03\82\04\00\04\80\82\00\00$\05\01\05\82\00\80P\DC\00\00\02\01\00\80P\DC\00\00\04\02\06\00\00\80\07\01\01\82p\0F\8C\BF\01\05\02\02\00\80p\DC\00\01\00\00\00\00\81\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\
BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\00\00\80\BF\06\00\00\00\00\00\00\00\98\04\00\00\00\00\00\00\0B\00\00\00\00\00\00\00\18\00\00\00\00\00\00\00\05\00\00\00\00\00\00\00$\05\00\00\00\00\00\00\0A\00\00\00\00\00\00\00(\00\00\00\00\00\00\00\F5\FE\FFo\00\00\00\00\E0\04\00\00\00\00\00\00\04\00\00\00\00\00\00\00\04\05\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00Linker: AMD LLD 20.0.0 (/longer_pathname_so_that_rpms_can_support_packaging_the_debug_info_for_all_os_profiles/src/llvm-project/llvm 
1b0eada6b0ee93e2e694c8c146d23fca90bc11c5)\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\00\00\F1\FF\03\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\1C\00\00\00\00\00\F1\FF\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\007\00\00\00\00\00\F1\FF\0A\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00W\00\00\00\00\00\F1\FF\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00{\00\00\00\00\00\F1\FF\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\9E\00\00\00\00\00\F1\FF\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\B9\00\00\00\00\00\F1\FF\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\DD\00\00\00\00\00\F1\FF\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\03\01\00\00\00\00\F1\FF\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00#\01\00\00\00\00\F1\FF\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00G\01\00\00\00\00\F1\FF\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00[\01\00\00\00\00\F1\FF\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00o\01\00\00\00\00\F1\FF\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\83\01\00\00\00\00\F1\FF\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\C7\01\00\00\00\02\08\00\80*\00\00\00\00\00\00\00\00\00\00\00\00\00\00\A0\01\00\00\12\03\07\00\00\16\00\00\00\00\00\00`\00\00\00\00\00\00\00\B2\01\00\00\11\00\06\00\80\05\00\00\00\00\00\00@\00\00\00\00\00\00\00\00.note\00.dynsym\00.gnu.hash\00.hash\00.dynstr\00.rodata\00.text\00.dynamic\00.relro_padding\00.AMDGPU.gpr_maximums\00.comment\00.symtab\00.shstrtab\00.strtab\00\00vectorAddKernel_0.num_vgpr\00vectorAddKernel_0.num_agpr\00vectorAddKernel_0.numbered_sgpr\00vectorAddKernel_0.num_named_barrier\00vectorAddKernel_0.private_seg_size\00vectorAddKernel_0.uses_vcc\00vectorAddKernel_0.uses_flat_scratch\00vectorAddKernel_0.has_dyn_sized_stack\00vectorAddKernel_0.has_recursion\00vectorAddKernel_0.has_indirect_call\00amdgpu.max_num_vgpr\00amdgpu.max_num_agpr\00amdgpu.max_num_sgpr\00amdgpu.max_num_named_barrier\00vectorAddKernel_0\00vectorAddKernel_0.kd\00_DYNAMIC\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\07\00\00\00\02\00\00\00\00\00\00\00\00\02\00\00\00\00\00\00\00\02\00\00\00\00\00\00\98\02\00\00\00\00\00\00\00\00\00\00\00\00\00\00\04\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\07\00\00\00\0B\00\00\00\02\00\00\00\00\00\00\00\98\04\00\00\00\00\00\00\98\04\00\00\00\00\00\00H\00\00\00\00\00\00\00\05\00\00\00\01\00\00\00\08\00\00\00\00\00\00\00\18\00\00\00\00\00\00\00\0F\00\00\00\F6\FF\FFo\02\00\00\00\00\00\00\00\E0\04\00\00\00\00\00\00\E0\04\00\00\00\00\00\00$\00\00\00\00\00\00\00\02\00\00\00\00\00\00\00\08\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\19\00\00\00\05\00\00\00\02\00\00\00\00\00\00\00\04\05\00\00\00\00\00\00\04\05\00\00\00\00\00\00 
\00\00\00\00\00\00\00\02\00\00\00\00\00\00\00\04\00\00\00\00\00\00\00\04\00\00\00\00\00\00\00\1F\00\00\00\03\00\00\00\02\00\00\00\00\00\00\00$\05\00\00\00\00\00\00$\05\00\00\00\00\00\00(\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00'\00\00\00\01\00\00\00\02\00\00\00\00\00\00\00\80\05\00\00\00\00\00\00\80\05\00\00\00\00\00\00@\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00@\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00/\00\00\00\01\00\00\00\06\00\00\00\00\00\00\00\00\16\00\00\00\00\00\00\00\06\00\00\00\00\00\00\80\04\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\005\00\00\00\06\00\00\00\03\00\00\00\00\00\00\00\80*\00\00\00\00\00\00\80\0A\00\00\00\00\00\00p\00\00\00\00\00\00\00\05\00\00\00\00\00\00\00\08\00\00\00\00\00\00\00\10\00\00\00\00\00\00\00>\00\00\00\08\00\00\00\03\00\00\00\00\00\00\00\F0*\00\00\00\00\00\00\F0\0A\00\00\00\00\00\00\10\05\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00M\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\F0\0A\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00b\00\00\00\01\00\00\000\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\F0\0A\00\00\00\00\00\00\AF\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\00\00\00\00\01\00\00\00\00\00\00\00k\00\00\00\02\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\A0\0B\00\00\00\00\00\00\B0\01\00\00\00\00\00\00\0E\00\00\00\10\00\00\00\08\00\00\00\00\00\00\00\18\00\00\00\00\00\00\00s\00\00\00\03\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00P\0D\00\00\00\00\00\00\85\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00}\00\00\00\03\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00\D5\0D\00\00\00\00\00\00\D0\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\01\00\00\00\00\00\00\00\00\00\00\00\00\00\00\00">] - llvm.func @vectorAdd(%arg0: !llvm.ptr<1>, %arg1: !llvm.struct, struct)>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>, %arg4: i32, %arg5: !llvm.ptr) attributes {llvm.emit_c_interface} { - %0 = llvm.mlir.constant(63 : i32) : i32 - %1 = llvm.mlir.undef : !llvm.struct, struct)> - %2 = llvm.mlir.undef : !llvm.struct - %3 = llvm.mlir.undef : !llvm.struct - %4 = llvm.mlir.constant(64 : i32) : i32 - %5 = llvm.mlir.constant(1 : index) : i64 - %6 = llvm.mlir.constant(64 : index) : i64 - %7 = llvm.extractvalue %arg1[0, 0] : !llvm.struct, struct)> - %8 = llvm.add %arg4, %0 : i32 - %9 = llvm.sdiv %8, %4 : i32 - %10 = llvm.mul %9, %4 : i32 - %11 = llvm.icmp "ne" %8, %10 : i32 - %12 = llvm.mlir.constant(0 : i32) : i32 - %13 = llvm.icmp "slt" %8, %12 : i32 - %14 = llvm.mlir.constant(false) : i1 - %15 = llvm.icmp "ne" %13, %14 : i1 - %16 = llvm.and %11, %15 : i1 - %17 = llvm.mlir.constant(-1 : i32) : i32 - %18 = llvm.add %9, %17 : i32 - %19 = llvm.select %16, %18, %9 : i1, i32 - %20 = llvm.sext %19 : i32 to i64 - %21 = llvm.insertvalue %7, %3[0] : !llvm.struct - %22 = llvm.insertvalue %21, %1[0] : !llvm.struct, struct)> - %23 = llvm.insertvalue %2, %22[1] : !llvm.struct, struct)> - gpu.launch_func <%arg5 : !llvm.ptr> @kernels::@vectorAddKernel_0 blocks in (%20, %5, %5) threads in (%6, %5, %5) : i64 args(%arg0 : !llvm.ptr<1>, %23 : !llvm.struct, struct)>, %arg2 : !llvm.ptr<1>, %arg3 : !llvm.ptr<1>) - llvm.return - } -} - - -================================================== -Test 1: Eager execution -================================================== -[Eager] Result 
correct: True
-
-==================================================
-Test 2: CUDA Graph Capture
-==================================================
-[Graph Capture] Result correct: True
-
-All passed: True
diff --git a/python/flydsl/_mlir b/python/flydsl/_mlir
deleted file mode 120000
index eba03cff..00000000
--- a/python/flydsl/_mlir
+++ /dev/null
@@ -1 +0,0 @@
-/home/jli10004/flydsl/mi450-cmodel-env/flydsl-prev/build-fly/python_packages/flydsl/_mlir
\ No newline at end of file
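
Every stage of the deleted dumps above lowers the same vectorAddKernel_0 body: each thread stages one f32 element from each input through a private-stack alloca (the 4-byte llvm.intr.memcpy into addrspace 5), adds the two as vector<1xf32>, and copies the result back out. A minimal Python model of that per-thread logic (names are illustrative, not part of this patch):

def vector_add_kernel(a, b, c, bid, tid, block=64):
    # Matches the index math in the dumps: base = workgroup.id.x * 64,
    # element = base + workitem.id.x; one f32 add per thread, no bounds guard.
    i = bid * block + tid
    c[i] = a[i] + b[i]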
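
On the host side, @vectorAdd sizes the launch grid as floordivsi(arg4 + 63, 64), which for non-negative sizes equals ceil(n / 64); the later LLVM-level dumps spell the same floor division out as an sdiv followed by a sign-and-remainder correction (the icmp/and/select sequence). The equivalent arithmetic, as a sketch:

def grid_size(n, block=64):
    # (n + block - 1) // block == ceil(n / block) for n >= 0
    return (n + block - 1) // block

assert grid_size(1) == 1 and grid_size(64) == 1 and grid_size(65) == 2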
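
The trailing log lines ("Eager execution", "CUDA Graph Capture", "All passed: True") come from a correctness check run in two modes. A hedged reconstruction of such a harness in PyTorch follows; the flydsl entry point vector_add, the tensor size, and the warm-up choice are assumptions, and only the two-mode structure mirrors the log:

import torch

def run_checks(vector_add, n=4096):
    a = torch.randn(n, device="cuda")
    b = torch.randn(n, device="cuda")
    out = torch.empty_like(a)

    # Test 1: eager launch (also serves as the warm-up pass that graph
    # capture generally requires).
    vector_add(a, b, out, n)
    eager_ok = torch.allclose(out, a + b)
    print(f"[Eager] Result correct: {eager_ok}")

    # Test 2: capture the same launch into a graph and replay it
    # (torch.cuda maps to HIP graphs on ROCm builds).
    out.zero_()
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        vector_add(a, b, out, n)
    g.replay()
    torch.cuda.synchronize()
    graph_ok = torch.allclose(out, a + b)
    print(f"[Graph Capture] Result correct: {graph_ok}")

    print(f"All passed: {eager_ok and graph_ok}")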