[ROCm] Add AMD MI300X agent optimization support: HIP compat, knowledge base, program playbook #6

Open
ZJLi2013 wants to merge 1 commit into RightNow-AI:main from ZJLi2013:mi300x-rocm-support

Conversation

@ZJLi2013

Summary

Enable AutoKernel's optimization agent to work effectively on AMD Instinct MI300X (gfx942, CDNA3) with ROCm/HIP backend. This complements PR #3 (GPU detection) by adding the pieces needed for agents to actually optimize kernels on MI300X — compatibility fixes, architecture knowledge, and an optimization playbook.
Builds on top of v1.3.0 AMD GPU detection by @andyluo7.

Changes

HIP Compatibility Fix

  • kernels/fused_mlp.py: Replace tl.math.tanh with sigmoid identity (tanh(x) = 2·sigmoid(2x) − 1). tl.math.tanh is unavailable on the Triton HIP backend and crashes at compile time. Numerically equivalent on both CUDA and ROCm.
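The identity behind the replacement can be checked numerically. A minimal sketch in plain NumPy (the kernel itself would use the `tl.sigmoid`-based equivalent; this just verifies the math):

```python
import numpy as np

def tanh_via_sigmoid(x):
    # tanh(x) = 2*sigmoid(2x) - 1; mirrors the sigmoid-identity
    # rewrite in fused_mlp.py, evaluated here with NumPy.
    return 2.0 / (1.0 + np.exp(-2.0 * x)) - 1.0

x = np.linspace(-5.0, 5.0, 101)
assert np.allclose(tanh_via_sigmoid(x), np.tanh(x), atol=1e-12)
```

Because the identity is exact, any difference on either backend comes down to floating-point rounding, which stays far below kernel correctness tolerances.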

Agent Knowledge Base (new files)

  • knowledge/amd_cdna3_optimization.md (249 lines): MI300X architecture reference — 304 CUs, 64-thread wavefronts, MFMA instructions, LDS/HBM3 memory hierarchy, waves_per_eu tuning, ROCProfiler counters. Curated from AMD-AGI/GEAK with attribution.
  • knowledge/workload_guidance.md (134 lines): Bottleneck-aware optimization strategy framework — guides agents to prioritize kernel-body changes over parameter sweeps, with specific "Prefer First / Consider / Deprioritize" lists for memory-bound, compute-bound, and latency-bound workloads.

Agent Playbook

  • program.md: Add MI300X (CDNA3, gfx942) sections to both Triton Tier 5 and CUDA Tier 5, covering wavefront sizing, MFMA, LDS limits, waves_per_eu, tl.math.tanh workaround, profiling counters, and HIP compilation flags.
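To give a flavor of what the CDNA3 playbook entries look like in practice, here is an illustrative Triton autotune fragment passing `waves_per_eu` on the HIP backend (a sketch only — config values are placeholders, it requires an MI300X and a ROCm Triton build to run, and is not code from this PR):

```
import triton

# Illustrative configs for gfx942: waves_per_eu is a HIP-backend
# kernel kwarg that hints how many wavefronts to co-resident per
# execution unit, trading occupancy against register pressure.
configs = [
    triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=4, waves_per_eu=2),
    triton.Config({"BLOCK_M": 64,  "BLOCK_N": 64}, num_warps=4, waves_per_eu=4),
]
```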

ROCm Environment Support

  • prepare.py: Add rocm-smi fallback for driver detection when nvidia-smi is unavailable. Also detects ROCm/HIP version from torch.version.hip.
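The fallback order can be sketched with a small helper (hypothetical name and shape — the actual prepare.py logic may differ):

```python
import shutil

def detect_gpu_driver(tools=("nvidia-smi", "rocm-smi")):
    """Return the first driver tool found on PATH, mirroring the
    nvidia-smi -> rocm-smi fallback order described above.
    (Hypothetical helper; illustrative of prepare.py's behavior.)"""
    for tool in tools:
        if shutil.which(tool):
            return tool
    return None

driver = detect_gpu_driver()
# On a ROCm-only box this yields "rocm-smi"; the HIP version can
# then be read from torch.version.hip when PyTorch is present.
```

Because the ROCm path is only reached after `nvidia-smi` lookup fails, NVIDIA systems never touch it.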

Testing

Tested end-to-end on AMD Instinct MI300X (gfx942, ROCm 6.4, PyTorch 2.6.0, Triton 3.2.0 HIP backend):

  • All 9 Triton starter kernels pass bench.py correctness checks (smoke test, shape sweep, numerical stability, determinism, edge cases)
  • prepare.py correctly detects ROCm driver via rocm-smi
  • Agent optimization loop verified: 5 kernels optimized through iterative edit-benchmark-keep/revert cycles
    Optimization results from closed-loop agent testing (for reference, not included in this PR):
    | Kernel | Starter → Optimized | Key technique |
    |--------|---------------------|---------------|
    | flash_attention | 0.50x → 2.20x | Remove .to(tl.float32) on tl.dot inputs → enable MFMA FP16 |
    | softmax | 1.16x → 2.26x | Multi-row processing (4 rows/program) |
    | matmul | 0.44x → 0.53x | 3-arg tl.dot(a, b, acc) + autotune |
    | rotary_embedding | 0.82x → 1.09x | Native dtype computation |
    | fused_mlp | 0.92x → 1.02x | Grouped tile ordering + autotune |

Key discovery: a single .to(tl.float32) cast on tl.dot inputs bypasses MFMA FP16 instructions entirely on MI300X, causing 4.4x performance loss. This is documented in the knowledge base for future agents.
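The anti-pattern can be shown with a minimal Triton fragment (an illustrative sketch, not the PR's kernel code; it needs a GPU to execute):

```
# Anti-pattern: upcasting tl.dot inputs forces FP32 math and
# bypasses MFMA FP16 instructions on MI300X (CDNA3).
acc += tl.dot(a.to(tl.float32), b.to(tl.float32))

# Preferred: keep FP16 inputs and accumulate in FP32 via the
# accumulator argument, so the MFMA FP16 path is used.
acc = tl.dot(a, b, acc)
```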

Zero NVIDIA impact

  • fused_mlp.py: sigmoid identity is mathematically equivalent to tl.math.tanh; works on both backends
  • prepare.py: rocm-smi path only triggers when nvidia-smi fails (never on NVIDIA)
  • program.md / knowledge/: additive content in new sections, no existing content modified
