From 2a26098499e99e8d1a873d9e8412ff2b010b906a Mon Sep 17 00:00:00 2001
From: zjli2013
Date: Mon, 30 Mar 2026 17:57:05 +0800
Subject: [PATCH] [ROCm] Add AMD MI300X support: HIP compat fix, optimization
 knowledge, rocm-smi fallback

Enable AutoKernel on AMD Instinct MI300X (gfx942, CDNA3) with ROCm/HIP backend.

Changes:
- kernels/fused_mlp.py: replace tl.math.tanh with sigmoid identity
  (tl.math.tanh is unavailable on the Triton HIP backend and crashes at compile)
- knowledge/amd_cdna3_optimization.md: MI300X architecture guide for agents
  (304 CUs, 64-thread wavefronts, MFMA, LDS, HBM3 hierarchy, perf counters)
- knowledge/workload_guidance.md: bottleneck-aware optimization strategies
  (memory/compute/latency-bound decision framework for Triton and HIP)
- program.md: add AMD CDNA3 (gfx942) Tier 5 optimization playbook
- prepare.py: add rocm-smi fallback when nvidia-smi is unavailable

Tested on MI300X (gfx942), ROCm 6.4, PyTorch 2.6.0, Triton 3.2.0 HIP backend.
All 9 Triton starter kernels pass correctness checks on MI300X.

AMD optimization knowledge sourced from AMD-AGI/GEAK with attribution.
Made-with: Cursor
---
 kernels/fused_mlp.py                |   4 +-
 knowledge/amd_cdna3_optimization.md | 249 ++++++++++++++++++++++++++++
 knowledge/workload_guidance.md      | 134 +++++++++++++++
 prepare.py                          |  24 ++-
 program.md                          |  26 +++
 5 files changed, 432 insertions(+), 5 deletions(-)
 create mode 100644 knowledge/amd_cdna3_optimization.md
 create mode 100644 knowledge/workload_guidance.md

diff --git a/kernels/fused_mlp.py b/kernels/fused_mlp.py
index 3c91ced..d3e3ff5 100644
--- a/kernels/fused_mlp.py
+++ b/kernels/fused_mlp.py
@@ -89,7 +89,9 @@ def fused_gate_up_kernel(
         gate_activated = acc_gate * tl.sigmoid(acc_gate)
     else:
         # GELU approximation
-        gate_activated = 0.5 * acc_gate * (1.0 + tl.math.tanh(0.7978845608 * (acc_gate + 0.044715 * acc_gate * acc_gate * acc_gate)))
+        # tl.math.tanh is unavailable on the HIP backend; use the sigmoid identity: tanh(x) = 2*sigmoid(2x) - 1
+        gelu_arg = 0.7978845608 * (acc_gate + 0.044715 * acc_gate * acc_gate * acc_gate)
+        gate_activated = 0.5 * acc_gate * (1.0 + 2.0 * tl.sigmoid(2.0 * gelu_arg) - 1.0)
 
     result = gate_activated * acc_up

diff --git a/knowledge/amd_cdna3_optimization.md b/knowledge/amd_cdna3_optimization.md
new file mode 100644
index 0000000..1bf527d
--- /dev/null
+++ b/knowledge/amd_cdna3_optimization.md
@@ -0,0 +1,249 @@
+# AMD CDNA3 (MI300X, gfx942) Optimization Reference
+
+> Curated from the [AMD-AGI/GEAK](https://github.com/AMD-AGI/GEAK) knowledge base.
+> Use this reference when optimizing Triton or HIP kernels on MI300-series GPUs.
+
+---
+
+## 1. Hardware Architecture
+
+### MI300X Key Specs
+
+| Parameter | Value |
+|-----------|-------|
+| Architecture | CDNA3 (3D chiplet, 5nm + 6nm) |
+| Compute Units | 304 CUs |
+| Matrix Cores | 1216 (MFMA engines) |
+| HBM3 Memory | 192 GB |
+| Memory Bandwidth | 5.3 TB/s peak |
+| FP16 / BF16 Peak | 1307 TFLOPS |
+| FP32 Peak | 163 TFLOPS |
+| FP64 Peak | 163 TFLOPS |
+| INT8 Peak | 2614 TOPS |
+| LDS per CU | 64 KB |
+| Infinity Cache (LLC) | 256 MB (shared across XCDs) |
+| Wavefront Size | **64 threads** (not 32) |
+| Max VGPRs per CU | 65536 |
+| TDP | 750 W |
+| Interconnect | Infinity Fabric, 128 GB/s per link |
+
+### Variant Detection
+
+All MI300-series GPUs share the gfx942 ISA and 304 CUs. Detect via
+`torch.cuda.get_device_properties(0).gcnArchName`, e.g. `"gfx942:sramecc+:xnack-"`.
+Some variants (MI308XHF etc.) may report different device names but share the same architecture and specs.
+
+### Memory Hierarchy
+
+```
+HBM3 (5.3 TB/s, 192 GB)
+  ↓
+Infinity Cache / LLC (256 MB, shared across XCDs)
+  ↓
+L2 Cache (4 MB per XCD, shared by that XCD's CUs)
+  ↓
+L1 Cache (per CU, 16-32 KB)
+  ↓
+LDS (64 KB per CU, explicitly managed — maps to Triton shared memory)
+  ↓
+Registers (65536 VGPRs per CU, partitioned across active wavefronts)
+```
+
+---
+
+## 2. Triton on MI300X: Key Differences from NVIDIA
+
+### Wavefront = 64 (not Warp = 32)
+
+On AMD, Triton's `num_warps` actually counts 64-thread wavefronts:
+- `num_warps=4` → 4 × 64 = 256 threads per block (good default for CDNA)
+- `num_warps=8` → 8 × 64 = 512 threads per block
+
+### Recommended Autotune Configs for gfx942
+
+```python
+import triton
+
+MI300_MATMUL_CONFIGS = [
+    triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32},
+                  num_warps=4, num_stages=2),
+    triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32},
+                  num_warps=8, num_stages=2),
+    triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32},
+                  num_warps=8, num_stages=2),
+    triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32},
+                  num_warps=8, num_stages=2),
+]
+
+MI300_ELEMENTWISE_CONFIGS = [
+    triton.Config({'BLOCK_SIZE': 256}, num_warps=4),
+    triton.Config({'BLOCK_SIZE': 512}, num_warps=8),
+    triton.Config({'BLOCK_SIZE': 1024}, num_warps=8),
+    triton.Config({'BLOCK_SIZE': 2048}, num_warps=8),
+]
+```
+
+### `waves_per_eu` Occupancy Hint
+
+Controls how many wavefronts per execution unit the compiler targets:
+
+```python
+triton.Config({'BLOCK_SIZE': 256}, num_warps=4, waves_per_eu=2)  # compute-bound
+triton.Config({'BLOCK_SIZE': 256}, num_warps=4, waves_per_eu=4)  # memory-bound
+```
+
+- `waves_per_eu=0`: auto (compiler decides)
+- `waves_per_eu=2-4`: good for compute-bound kernels (more registers per wavefront)
+- `waves_per_eu=4-8`: good for memory-bound kernels (more wavefronts to hide latency)
+
+### HIP Backend Limitations
+
+| Feature | Status |
+|---------|--------|
+| `tl.dot` | Works — maps to MFMA |
+| `tl.trans` | Works |
+| `tl.sigmoid` | Works |
+| `tl.math.tanh` | **NOT available** — use `2*tl.sigmoid(2*x) - 1` |
+| `tl.math.rsqrt` | Works |
+| `num_stages` | **Must be >= 1** (0 crashes). Use 2 for default pipelining. |
+| `tl.load(..., eviction_policy=...)` | May be ignored on HIP |
+
+### Environment Variables for Debugging
+
+```bash
+export TRITON_PRINT_AUTOTUNING=1   # show autotune results
+export TRITON_ALWAYS_COMPILE=1     # bypass cache
+export TRITON_DEBUG=1              # debug mode
+export MLIR_ENABLE_DUMP=1          # dump MLIR IR
+export HIP_VISIBLE_DEVICES=0       # select GPU
+```
+
+---
+
+## 3. Occupancy Tuning on CDNA
+
+> Source: GEAK `knowledge-base/amd-knowledge-base/layer-6-extended/optimize-guides/silu_optim/occupancy-tuning.md`
+
+### Resource Limits per CU
+
+| Resource | Limit | Impact |
+|----------|-------|--------|
+| Wavefronts per CU | 32-40 (arch dependent) | Hard cap on parallelism |
+| VGPRs per CU | 65536 | Registers per wavefront = registers_per_thread × 64 |
+| LDS per CU | 64 KB | Shared among all blocks on the CU |
+| Wavefront size | 64 threads | Fixed |
+
+### Occupancy Calculation
+
+```
+Occupancy = Active_Wavefronts / Max_Wavefronts_per_CU
+
+Registers_per_Wavefront = Registers_per_Thread × 64
+Max_Concurrent_Wavefronts = 65536 / Registers_per_Wavefront
+```
+
+Example: 32 registers/thread → 2048 registers/wavefront → 32 concurrent wavefronts → ~80% occupancy (at a 40-wavefront cap).
+
+### Occupancy Sweet Spots (Memory-Bound Kernels)
+
+| Occupancy | Performance | When to Use |
+|-----------|-------------|-------------|
+| 25-40% | Sub-optimal | Avoid unless compute-bound with high ILP |
+| 50-75% | Good | Balanced for most kernels |
+| **75-90%** | **Sweet spot** | Memory-bound kernels (softmax, layernorm, etc.) |
+| 100% | Not always best | May over-constrain register usage |
+
+---
+
+## 4. Memory Coalescing on CDNA
+
+> Source: GEAK `knowledge-base/amd-knowledge-base/layer-6-extended/optimize-guides/silu_optim/memory-coalescing-hip.md`
+
+### Transaction Sizes
+
+| Cache Level | Line Size | Notes |
+|-------------|-----------|-------|
+| L1 Cache Line | 64 bytes | 16 × fp32 or 32 × bf16 |
+| L2 Cache Line | 128 bytes | 32 × fp32 or 64 × bf16 |
+| Optimal Transaction | 128-256 bytes | Full wavefront: 64 threads × 4 bytes = 256 bytes |
+
+### Coalescing Efficiency
+
+| Pattern | Efficiency | Bandwidth (MI300X) |
+|---------|------------|--------------------|
+| Perfect coalescing | 100% | 3700-4700 GB/s (70-90% of peak) |
+| Stride-2 access | 25-50% | 1300-2600 GB/s |
+| Random access | 1.5% | 80-265 GB/s |
+
+### Rules for Triton
+
+1. The standard `pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)` pattern is coalesced.
+2. Use vectorized loads when possible (the compiler usually handles this).
+3. Avoid strided pointer arithmetic that breaks contiguity.
+4. BLOCK_SIZE should be a multiple of 64 (the wavefront size).
+
+---
+
+## 5. Performance Counters for Profiling
+
+> Source: GEAK `knowledge-base/amd-knowledge-base/layer-1-hardware/amd-gpu-arch/mi300-mi200-performance-counters.md`
+
+### Key Counters to Collect
+
+```bash
+# Compute utilization
+rocprof -i counters.txt ./your_app
+# counters.txt:
+# pmc : SQ_INSTS_VALU SQ_INSTS_MFMA_MOPS_FP16 SQ_INSTS_MFMA_MOPS_BF16
+# pmc : SQ_ACTIVE_INST_VALU SQ_ACTIVE_INST_MFMA SQ_BUSY_CYCLES
+# pmc : SQ_WAVE_CYCLES SQ_WAVES
+# pmc : TCC_HIT TCC_MISS TCC_REQ
+# pmc : TCC_EA_RDREQ TCC_EA_WRREQ
+```
+
+### Derived Metrics
+
+| Metric | Formula | Target |
+|--------|---------|--------|
+| GPU Utilization | `GRBM_GUI_ACTIVE / GRBM_COUNT` | >90% |
+| MFMA Efficiency | `SQ_ACTIVE_INST_MFMA / SQ_BUSY_CYCLES` | >50% for matmul |
+| L2 Hit Rate | `TCC_HIT / TCC_REQ` | >80% for tiled kernels |
+| Wavefront Occupancy | `SQ_WAVE_CYCLES / (CUs × Max_Waves × Total_Cycles)` | >60% |
+| Memory BW Utilization | `(TCC_EA_RDREQ × 32B + TCC_EA_WRREQ × 64B) / Time / Peak_BW` | >70% for mem-bound |
+
+### Bottleneck Identification
+
+| Indicator | Bottleneck | Action |
+|-----------|-----------|--------|
+| High `TCP_PENDING_STALL_CYCLES` | Memory-bound | Improve blocking, increase SRAM/L2 reuse |
+| Low L2 hit rate | Memory-bound | Better tiling, prefetch |
+| High `SQ_ACTIVE_INST_MFMA` | Compute-bound | Already using matrix cores well |
+| Low `SPI_CSN_BUSY` | Launch-bound | Fuse kernels, use persistent patterns |
+| High `SQ_LDS_BANK_CONFLICT` | LDS contention | Adjust access stride, add padding |
+
+---
+
+## 6. CDNA3 vs NVIDIA Architecture Comparison
+
+| Feature | MI300X (CDNA3) | H100 (Hopper) |
+|---------|----------------|---------------|
+| Wavefront/Warp Size | 64 | 32 |
+| FP16 Peak TFLOPS | 1307 | 989 (SXM) |
+| Memory BW | 5.3 TB/s | 3.35 TB/s |
+| Memory Capacity | 192 GB HBM3 | 80 GB HBM3 |
+| Shared Mem / LDS | 64 KB/CU | 228 KB/SM (configurable) |
+| Last-Level Cache | 256 MB (Infinity Cache) | 50 MB (L2) |
+| Matrix Ops | MFMA | WGMMA / HMMA |
+| Async Copy | Compiler-managed | TMA (hardware) |
+| Sparsity | Not available | 2:4 structured |
+
+**Implication for kernel optimization**: MI300X has higher peak compute and bandwidth but less
+per-CU shared memory (64 KB vs 228 KB). Use smaller shared-memory tiles and lean more on the
+256 MB Infinity Cache (5× the capacity of H100's 50 MB L2).
+
+---
+
+## References
+
+- [GEAK: GPU Efficient Automatic Kernel optimizer](https://github.com/AMD-AGI/GEAK)
+- [CDNA3 Architecture Guide](https://www.amd.com/en/technologies/cdna)
+- [MI300 Performance Counters](https://instinct.docs.amd.com/latest/gpu-arch/mi300-mi200-performance-counters.html)
+- [Triton on ROCm](https://github.com/ROCm/triton)
+- [ROCm HIP Performance Guidelines](https://rocm.docs.amd.com/projects/HIP/en/latest/how-to/performance_guidelines.html)

diff --git a/knowledge/workload_guidance.md b/knowledge/workload_guidance.md
new file mode 100644
index 0000000..d1be0f4
--- /dev/null
+++ b/knowledge/workload_guidance.md
@@ -0,0 +1,134 @@
+# Workload-Aware Optimization Strategy
+
+> Curated from [AMD-AGI/GEAK](https://github.com/AMD-AGI/GEAK) `workload_guidance.py`.
+> This framework guides the optimization agent to prioritize profiling-driven kernel-body
+> rewrites over parameter sweeps, for both Triton and HIP backends.
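The bottleneck-to-strategy mapping this file describes can be sketched as a small lookup. This is an illustrative host-side sketch, not part of the patch: the function name `pick_strategy` is hypothetical, and the field names and thresholds are taken from the flowchart in this file (70%/50% of peak bandwidth, 60%/40% of peak TFLOPS).

```python
# Hypothetical sketch of the bottleneck -> strategy-family mapping.
# Thresholds mirror the strategy-selection flowchart; bench.py field
# names are assumptions, not a real API.

def pick_strategy(bottleneck: str,
                  peak_bw_frac: float = 0.0,
                  peak_flops_frac: float = 0.0) -> str:
    """Map a profiled bottleneck to the preferred optimization family."""
    if bottleneck == "memory":
        if peak_bw_frac > 0.70:
            # Near peak bandwidth: only doing less memory work helps.
            return "algorithmic fusion / fewer memory ops"
        return "coalescing / blocking / access patterns"
    if bottleneck == "compute":
        if peak_flops_frac > 0.60:
            # Near roofline: small algorithmic changes only.
            return "near roofline: small algorithmic changes"
        return "check tl.dot utilization and data layouts"
    if bottleneck == "latency":
        return "kernel fusion / persistent patterns"
    raise ValueError(f"unknown bottleneck: {bottleneck}")
```

A planner following the GEAK policy would fill most attempts from the returned family before touching autotune or wrapper code.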
+
+---
+
+## Core Principle
+
+**Prefer kernel-body algorithmic changes over autotune parameter sweeps.**
+
+The GEAK project's empirical finding: pure `@triton.autotune` config sweeps or
+`num_warps` / `num_stages` / `BLOCK_*` parameter searches without kernel-body changes yield
+limited gains (<5%). Kernel-body rewrites (tiling, fusion, math reformulation) typically
+yield 20-50%+ improvements.
+
+---
+
+## Triton Backend Strategy
+
+### Bottleneck: Memory-Bound
+
+When `bench.py` reports `bottleneck: memory` or a high `% peak bandwidth`:
+
+**Prefer First:**
+1. Algorithmic kernel-body rewrites that change the reduction tree, tiling scheme, or math formulation
+2. Operation fusion — merge adjacent work into the Triton kernel body to eliminate memory round-trips
+3. Memory-access rewrites: better blocking, fewer redundant loads/stores, higher SRAM/L2 reuse
+4. Masking, pointer-arithmetic, or load/store simplifications that reduce HBM traffic
+
+**Consider Next:**
+- Shape-specialized kernel variants for different input regimes
+- Vectorized or blocked load/store patterns as part of a broader traffic-reduction plan
+- Kernel-body memory-layout and live-range cleanup
+
+**Deprioritize:**
+- `@triton.autotune`-only config sweeps
+- Pure `num_warps` / `num_stages` / `BLOCK_*` parameter search without a body change
+- Python dispatch, import-routing, or wrapper-only edits
+
+### Bottleneck: Compute-Bound
+
+When `bench.py` reports `bottleneck: compute` or a high `% peak TFLOPS`:
+
+**Prefer First:**
+1. Instruction-count reduction and control-flow simplification inside hot loops
+2. MFMA / `tl.dot`-friendly reformulations, cheaper math primitives
+3. Algorithmic approximations when correctness permits
+
+**Consider Next:**
+- Register-pressure and live-range reductions for better compiler scheduling
+- Shape-specialized variants
+
+**Deprioritize:**
+- Same as memory-bound
+
+### Bottleneck: Latency-Bound
+
+When the kernel is very short (small shapes, launch overhead dominates):
+
+**Prefer First:**
+1. Fuse adjacent short kernels so each launch does materially more work
+2. Persistent or multi-tile kernel patterns that amortize launch overhead
+3. Increase work per program
+
+**Consider Next:**
+- Shape-specialized kernel variants for small vs large shapes
+
+---
+
+## HIP Backend Strategy
+
+### Bottleneck: Memory-Bound
+
+**Prefer First:**
+1. Algorithmic HIP kernel-body rewrites (search / reduction / tiling structure)
+2. Coalescing, vectorized access, or LDS staging to raise effective bandwidth
+3. Global-memory traffic reduction by fusing steps or recomputing cheap values
+
+**Consider Next:**
+- Wavefront-level memory-access reordering or bank-conflict reduction
+- Size-specialized kernel variants
+
+**Deprioritize:**
+- Launch-config or occupancy-only tuning
+- Wrapper / dispatch / copy-path edits
+
+### Bottleneck: Compute-Bound
+
+**Prefer First:**
+1. Instruction-count reduction, branch simplification, cheaper per-thread math
+2. Wave intrinsics, MFMA-friendly decomposition, unrolled inner loops
+
+### Bottleneck: LDS-Bound
+
+**Prefer First:**
+1. LDS-bank-conflict reduction and staged-access restructuring
+2. Move transient data from LDS to registers when it reduces LDS pressure
+
+---
+
+## Strategy Selection Flowchart
+
+```
+bench.py output
+ │
+ ├── bottleneck: memory
+ │     ├── % peak BW > 70% → focus on algorithmic fusion / fewer memory ops
+ │     └── % peak BW < 50% → focus on coalescing / blocking / access patterns
+ │
+ ├── bottleneck: compute
+ │     ├── % peak TFLOPS > 60% → near roofline, try smaller algorithmic changes
+ │     └── % peak TFLOPS < 40% → tl.dot not utilized well, check data layouts
+ │
+ └── bottleneck: latency
+       └── focus on kernel fusion and persistent patterns
+```
+
+---
+
+## Planning Policy (from GEAK)
+
+1. Fill most optimization attempts with "Prefer First" strategies
+2. Only add autotune / launch / wrapper attempts after at least 3 preferred-family attempts
+3. Skip the iteration if only low-priority ideas remain — move to the next kernel
+4. Each iteration should change ONE thing in the kernel body and measure its effect
+
+---
+
+## References
+
+- [GEAK workload_guidance.py](https://github.com/AMD-AGI/GEAK/blob/main/src/minisweagent/agents/heterogeneous/workload_guidance.py)
+- [GEAK mini_kernel_strategy_list.yaml](https://github.com/AMD-AGI/GEAK/blob/main/src/minisweagent/config/mini_kernel_strategy_list.yaml)

diff --git a/prepare.py b/prepare.py
index 95aaa73..882677c 100644
--- a/prepare.py
+++ b/prepare.py
@@ -104,14 +104,16 @@ def verify_environment() -> None:
     cc_major = props.major
     cc_minor = props.minor
 
-    # Driver and CUDA runtime versions
-    # torch.version.cuda gives the CUDA toolkit version PyTorch was compiled with
+    # Runtime version: CUDA toolkit or ROCm/HIP version
     cuda_version = torch.version.cuda or "unknown"
+    hip_version = getattr(torch.version, "hip", None)
+    if hip_version:
+        cuda_version = f"ROCm {hip_version}"
 
-    # nvidia-smi driver version -- fall back gracefully
+    # GPU driver version -- try nvidia-smi first, then rocm-smi as fallback
+    import subprocess
     driver_str = "unknown"
     try:
-        import subprocess
         result = subprocess.run(
             ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader,nounits"],
             capture_output=True, text=True, timeout=5,
@@ -121,6 +123,20 @@ def verify_environment() -> None:
     except Exception:
         pass
 
+    if driver_str == "unknown":
+        try:
+            result = subprocess.run(
+                ["rocm-smi", "--showdriverversion"],
+                capture_output=True, text=True, timeout=5,
+            )
+            if result.returncode == 0:
+                for line in result.stdout.strip().split("\n"):
+                    if "Driver" in line:
+                        driver_str = line.split(":")[-1].strip()
+                        break
+        except Exception:
+            pass
+
     print(f"GPU: {gpu_name}")
     print(f"  Memory: {mem_gb:.1f} GB")
     print(f"  SM Count: {sm_count}")

diff --git a/program.md b/program.md
index 87c8749..2d4f2be 100644
--- a/program.md
+++ b/program.md
@@ -527,6 +527,21 @@ Once block sizes are tuned, memory is usually the bottleneck.
 - L4: very memory-bandwidth limited.
 - RTX 4090: 128 SMs but consumer-grade memory bandwidth.
 
+**MI300X (CDNA3, gfx942):** *(source: [GEAK](https://github.com/AMD-AGI/GEAK) knowledge-base)*
+- 304 CUs, wavefront size = 64 (not 32). All warp-level reasoning must use 64-wide.
+- HBM3: 5.3 TB/s peak bandwidth, 192 GB capacity. Memory-bound kernels should target >70% of this.
+- MFMA (Matrix Fused Multiply-Add) instructions: 1307 TFLOPS FP16/BF16. Use `tl.dot`, which maps to MFMA.
+- LDS: 64 KB per CU; Triton shared memory maps to LDS. Avoid >64 KB per block.
+- BLOCK_SIZE = 256 is often optimal for CDNA (vs 128 for NVIDIA). Try 128/256 in autotune configs.
+- `num_warps`: on AMD each Triton "warp" is a wavefront of 64 threads. 4 warps = 256 threads.
+- `waves_per_eu`: occupancy hint. Add to `triton.Config(kwargs, num_warps=N, waves_per_eu=M)`.
+  Typical values: 0 (auto), 2-4 for compute-bound, 4-8 for memory-bound.
+- `tl.math.tanh` is NOT available on the HIP backend. Use the sigmoid identity: `2*tl.sigmoid(2*x) - 1`.
+- No `cp.async` or TMA equivalent; the Triton compiler handles async prefetch internally.
+- Prefer BF16 over FP16 for training workloads (same throughput, better dynamic range).
+- Profile with `rocprof --hip-trace` or `rocprofv3`. Key counters: `SQ_INSTS_MFMA_MOPS_FP16`,
+  `TCC_HIT/TCC_REQ` (L2 hit rate), `SQ_WAVE_CYCLES` (occupancy), `TCP_PENDING_STALL_CYCLES` (memory stall).
+
 **Typical gains**: 5-15% from architecture-specific tuning.
 
 ### Tier 6: Kernel-Specific Tricks
@@ -677,6 +692,17 @@ asm volatile("cp.async.wait_group 0;");
 - Smaller shared memory per SM -- use smaller tiles.
 - High FP16 throughput but consumer-grade memory bandwidth.
 
+**AMD CDNA3 (gfx942, MI300X) — HIP backend:** *(source: [GEAK](https://github.com/AMD-AGI/GEAK) knowledge-base)*
+- Compile with `hipcc -O3 --offload-arch=gfx942`.
+- Wavefront = 64 threads (vs CUDA warp = 32). `__shfl*` intrinsics exist in HIP but operate at width 64; audit any width-32 assumptions.
+- MFMA intrinsics replace NVIDIA's WMMA/MMA. Use rocBLAS or rocWMMA for matrix ops.
+- LDS: 64 KB per CU. Use `__launch_bounds__(256, 4)` for occupancy (max 256 threads per block, min 4 wavefronts per EU).
+- VGPRs: 65536 per CU. Target 20-32 registers/thread for good occupancy (75-100%).
+- `__restrict__` and `__builtin_nontemporal_load/store` for memory-hint control.
+- No `cp.async`; use `__builtin_amdgcn_global_load_lds()` for async LDS fills (advanced).
+- Memory coalescing: same principle as CUDA, with 128-byte transactions and a 64-thread wavefront.
+- Profile: `rocprof --stats`, `rocprofv3`. Key counters: `SQ_INSTS_MFMA_MOPS_*`, `TCC_*`, `SQ_WAVE_CYCLES`.
+
 **Typical gains**: 5-15% from arch-specific tuning.
 
 ### CUDA Tier 6: Kernel-Specific Tricks
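
The identity behind the fused_mlp.py hunk, tanh(x) = 2*sigmoid(2x) - 1, can be sanity-checked on the host with plain Python (no GPU or Triton required). This is an illustrative sketch: the helper names are hypothetical, and the constants mirror the tanh-approximation GELU used in the kernel.

```python
# Host-side check that the HIP-safe GELU form in the patch matches the
# original tl.math.tanh form: tanh(a) == 2*sigmoid(2a) - 1.
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def gelu_tanh(x: float) -> float:
    # Reference tanh-approximation GELU, as in the original kernel line.
    a = 0.7978845608 * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(a))

def gelu_sigmoid(x: float) -> float:
    # HIP-safe form from the patch: tanh(a) replaced by 2*sigmoid(2a) - 1.
    a = 0.7978845608 * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + 2.0 * sigmoid(2.0 * a) - 1.0)

for x in (-3.0, -0.5, 0.0, 0.5, 3.0):
    assert abs(gelu_tanh(x) - gelu_sigmoid(x)) < 1e-12
```

The same check in fp32 on device would show only rounding-level differences, which is why the 9 starter kernels still pass correctness on MI300X.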