Enable ROCm Triton backend for AllReduce #684

mfrancepillois wants to merge 4 commits into rocm-jaxlib-v0.9.1 from …
Conversation
Force-pushed from 08ee4a7 to 1daa0d9.
Resolved review threads:
- xla/backends/gpu/codegen/triton/transforms/triton_xla_lower_extern_atomics_pass.cc (4 threads, 3 outdated)
- xla/backends/gpu/codegen/triton/transforms/triton_xla_implement_extern_atomics_rocm_pass.cc (3 outdated threads)
Claude Code Review Summary

This PR enables the ROCm Triton backend for AllReduce collective operations. Key issues found (see inline comments for details):
- Missing test coverage for …
Re-review Summary

Re-reviewed the latest diff. All 8 previously flagged issues remain unaddressed; no new findings. See existing inline comments for details on the key open items.

🤖 Generated with Claude Code
Force-pushed from a866b69 to e885196.
Re-review Summary

Excellent progress: 7 of 8 previous findings have been addressed in this revision. 1 minor item remaining: …

Overall this is looking solid. The two-stage lowering design is clean and the ROCm atomics implementation is well-structured.

🤖 Generated with Claude Code
Resolved review threads:
- xla/backends/gpu/codegen/triton/transforms/triton_xla_lower_extern_atomics_pass.cc (2 outdated threads)
- xla/backends/gpu/codegen/triton/transforms/triton_xla_implement_extern_atomics_rocm_pass.cc (4 threads, 3 outdated)
- xla/backends/gpu/codegen/triton/transforms/tests/triton_xla_lower_extern_atomics.mlir (1 thread)
Re-review Summary

Good progress since the last review: 6 of 8 previous findings have been addressed (mask handling, validation, syncscope, signed/unsigned comparison, scope parsing). The include ordering issue in … remains. New findings: see inline comments.

🤖 Generated with Claude Code
Re-review Summary

Re-reviewed the latest diff. All previously flagged issues remain applicable; no new findings. See the prior review for the key outstanding items.

🤖 Generated with Claude Code
Re-review Summary (commit 90ec42c)

All previously flagged issues have been addressed in this revision; key fixes verified. No new issues found. One minor prior note (include ordering of …) remains.

🤖 Generated with Claude Code
Resolved review thread:
- xla/backends/gpu/codegen/triton/transforms/triton_xla_lower_extern_get_tid_pass.cc
```cpp
#if defined(TENSORFLOW_USE_ROCM)
// ROCm: Use constant value directly as ROCDL dialect doesn't define memory
// space enum
static constexpr int32_t kGlobalAddressSpace = 1;
```

Comment: https://mlir.llvm.org/docs/Dialects/GPU/#gpu-address-spaces should be the same for both.
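To illustrate the reviewer's point, here is a minimal sketch (not the PR's code) of sharing one target-neutral address-space mapping instead of a ROCm-only constant. The enum name and helper are hypothetical; the numeric values follow the convention that global, workgroup, and private memory map to LLVM address spaces 1, 3, and 5 on both NVPTX and AMDGPU backends.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical shared mapping, usable by both the CUDA and ROCm lowerings.
// Values mirror the gpu dialect address spaces described in the MLIR docs.
enum class GpuAddressSpace : int32_t {
  kGlobal = 1,     // gpu.address_space<global>
  kWorkgroup = 3,  // gpu.address_space<workgroup>
  kPrivate = 5,    // gpu.address_space<private>
};

// Converts the target-neutral enum to the LLVM addrspace number.
constexpr int32_t AddressSpaceToLLVM(GpuAddressSpace as) {
  return static_cast<int32_t>(as);
}
```

With this in place, `kGlobalAddressSpace = 1` would no longer need an `#if defined(TENSORFLOW_USE_ROCM)` guard.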
```cpp
bool is_supported = false;

// CUDA: Requires compute capability 9.0+ (Hopper or newer)
if (device_info.cuda_compute_capability().major >= 9) {
```

Comment: Where did this check exist before this change?
```cpp
// Add XLA custom pass to implement extern_elementwise atomic functions
// This must run after MLIR->LLVM conversion but before final optimizations
pm->addPass(
    mlir::triton::xla::CreateTritonXLAImplementExternAtomicsROCmPass());
```

Comment: Use it for CUDA too. It should be fairly late in the pipeline though.

Reply: I would prefer to merge the ROCm support first, and then extend it to CUDA, as we do not really have enough Nvidia GPU versions to make sure that this change will not affect Nvidia performance.

Comment: I literally have to check one, because of PTX. Just to see the lowering.
```cpp
  pm->addPass(mt::createTritonAMDGPULowerInstructionSchedHintsPass(
      rocm_cc.gfx_version(), num_stages));
}
pm->addPass(mt::createConvertBuiltinFuncToLLVMPass(rocm_cc.gfx_version(),
```

Comment: Why would we want to delay the pass until this point? This type of external call is usually lowered in the ConvertTritonAMDGPUToLLVMPass pass, so earlier in the pipeline. I would have been more inclined to put the pass early in the pipeline, right after ConvertTritonAMDGPUToLLVMPass.

Reply: ConvertBuiltinFuncToLLVMPass serves the same purpose as yours and is the last one, so I presume it is better not to expand your calls, which might give other passes a chance to mess with them in expanded form.
```cpp
#include "rocm/include/hip/hip_runtime.h"
#include "rocm/include/hip/hip_version.h"
#include "rocm/rocm_config.h"
#include <unistd.h>
```

Comment: Avoid drive-by formatting changes if possible.

Reply: I tried to avoid it, but otherwise the clang-format check fails.

Comment: Hmm, odd. Must be a local thing. We might have pushed things without checking the CI.
```cpp
// Check peer access capability and cache the result
TF_ASSIGN_OR_RETURN(hipDevice_t peer_device, GetDevice(i));
peer_access_cache_[i] = CanEnablePeerAccess(device_, peer_device);
```

Comment: I don't think there is a need to cache this. Leave it for a separate effort.

Reply: This mimics the CUDA implementation. And since this function is called every time a Collective Emitter thunk is launched, I think it's better to use a cache rather than calling the API each time.

Comment: Oh, I see. I guess they have some good reason why they don't cache the value; this used to get called only on client creation and thunk prepare later on. Maybe just use std::vector. You can retrofit the other overloads of CanEnablePeerAccessTo to use it while you are at it. It might still be better to split it from the overall change.
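The caching scheme under discussion can be sketched in isolation. This is a hypothetical stand-in for the executor code (class name, members, and the injected query function are all invented for illustration): a `std::vector` of optionals indexed by device ordinal, filled lazily so the expensive peer-access API is hit at most once per peer.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <vector>

// Hypothetical sketch of per-ordinal peer-access caching, along the lines
// suggested in the review. The real query (hipDeviceCanAccessPeer etc.) is
// injected as a callback so the sketch stays self-contained.
class PeerAccessCache {
 public:
  PeerAccessCache(int device_count,
                  std::function<bool(int)> can_enable_peer_access)
      : cache_(device_count), query_(std::move(can_enable_peer_access)) {}

  // Returns the cached answer, querying the underlying API only on the
  // first call for a given peer ordinal.
  bool CanEnablePeerAccessTo(int other_device_ordinal) {
    std::optional<bool>& slot = cache_[other_device_ordinal];
    if (!slot.has_value()) {
      ++query_count_;  // counts how often the real API would be hit
      slot = query_(other_device_ordinal);
    }
    return *slot;
  }

  int query_count() const { return query_count_; }

 private:
  std::vector<std::optional<bool>> cache_;  // one slot per device ordinal
  std::function<bool(int)> query_;
  int query_count_ = 0;
};
```

Repeated thunk launches then reuse the cached value, which is the author's motivation above; the reviewer's counterpoint is that the call sites may be rare enough that caching is not worth the added state.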
```cpp
// Check if ptr is a tensor type (vectorized operation)
auto tensor_type = mlir::dyn_cast<mlir::RankedTensorType>(ptr.getType());

if (tensor_type) {
```

Comment: Maybe unify both paths. Split into args prepare plus one elementwise_extern emit at the end.
```cpp
    mlir::RankedTensorType::get(tensor_type.getShape(), elem_type);

// Create expected value tensor (splat scalar to tensor)
auto expected_tensor =
```

Comment: Is there no need to handle the scalar case, as above?
```cpp
namespace {

// Helper to parse syncscope from function name
// Function names follow pattern: xla_atomic_*_<semantic>_<scope>[_<comparator>]
```

Comment: The comparator is not optional, and neither is the scope. Just turn it into std::string::contains.
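A minimal sketch of the substring-based lookup the reviewer suggests. The scope token spellings (`_sys`, `_gpu`, `_cta`) are assumptions for illustration, not the PR's actual naming; the empty-string-means-system-scope convention follows the AMDGPU memory model note quoted later in this review. `std::string::contains` is C++23, so the sketch uses `find`, which behaves identically here.

```cpp
#include <cassert>
#include <string>

// Hypothetical scope parser: since every generated function name carries an
// explicit scope token, a plain substring test is enough. Returns the LLVM
// syncscope string ("" = system scope, per the AMDGPU memory model).
std::string ParseSyncScopeSketch(const std::string& func_name) {
  if (func_name.find("_sys") != std::string::npos) return "";          // system
  if (func_name.find("_gpu") != std::string::npos) return "agent";     // device
  if (func_name.find("_cta") != std::string::npos) return "workgroup"; // block
  return "";  // default to the conservative system scope
}
```

Because the tokens are mandatory parts of the name, there is no need for positional parsing of the `xla_atomic_*_<semantic>_<scope>_<comparator>` pattern.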
```cpp
// Replace each call inline
for (auto call_op : calls_to_replace) {
  std::string callee_name = call_op.getCallee()->str();
```

Resolved review thread:
- xla/backends/gpu/codegen/triton/transforms/triton_xla_implement_extern_atomics_rocm_pass.cc
```cpp
// Atomic block: perform atomic exchange
builder.setInsertionPointToStart(atomic_block);
auto atomic_xchg = LLVM::AtomicRMWOp::create(
```

Comment: Use an atomic store (https://mlir.llvm.org/docs/Dialects/LLVM/#llvmstore-llvmstoreop); this is not optimal. Use https://mlir.llvm.org/docs/Dialects/LLVM/#llvmmlirpoison-llvmpoisonop for the result.

Reply: We expect the result not to be used. If it is, it is a bug.
```cpp
// Loop block: spin wait
builder.setInsertionPointToStart(loop_block);
auto loaded = LLVM::LoadOp::create(
    builder, loc, i32_type, addr, 4, false, false, false, false,
```

Comment: Is it an atomic load? It is hard to follow. Maybe inline-comment (/* arg_name */ false) each of the bool args.

Comment: Never mind, I see it from the test, but still add the comments.
```cpp
    mlir::ValueRange{exit_block->getArgument(0)});
call_op.erase();
} else {
  // Unmasked spin wait: direct loop
```

Comment: Can you unify these two paths? Maybe via a lambda that you can call in both places.
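The lambda-based unification the reviewer asks for can be illustrated with a host-side analogue (this is not the MLIR pass code; the function, its signature, and the `masked_off` flag are invented to model the masked/unmasked split). The spin-wait body lives in one lambda; a masked-off lane bypasses the wait and reports the expected value, while the unmasked path runs the loop directly.

```cpp
#include <atomic>
#include <cassert>
#include <cstdint>

// Hypothetical host-side model of the two spin-wait paths sharing one body.
// `flag` stands in for the signal word in global memory; `masked_off` models
// a predicated-off lane that must not touch memory.
int32_t SpinWaitSketch(std::atomic<int32_t>& flag, int32_t expected,
                       bool masked_off) {
  // Shared loop body, callable from both paths (the reviewer's lambda).
  auto spin = [&]() -> int32_t {
    int32_t loaded;
    do {
      loaded = flag.load(std::memory_order_acquire);  // atomic load
    } while (loaded != expected);
    return loaded;
  };

  if (masked_off) {
    return expected;  // masked path: skip the wait entirely
  }
  return spin();  // unmasked path: direct loop
}
```

In the actual pass, the lambda would emit the loop/exit blocks once, and the masked variant would only add the surrounding branch on the mask value.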
```cpp
// Clean up unused extern function declarations
llvm::SmallVector<LLVM::LLVMFuncOp> to_erase;
module.walk([&](LLVM::LLVMFuncOp func) {
```
```cpp
// Function names follow pattern: xla_atomic_*_<semantic>_<scope>[_<comparator>]
std::string ParseSyncScope(const std::string& func_name) {
  // Per AMDGPU memory model (Table 31):
  // - "" (empty) = system scope (cross-device visibility)
```
```mlir
// CHECK-NOT: llvm.call @xla_get_thread_id
// CHECK: [[TID:%.*]] = llvm.call_intrinsic "llvm.amdgcn.workitem.id.x"() : () -> i32
// CHECK: llvm.return [[TID]]
%tid = llvm.call @xla_get_thread_id() : () -> i32
```

Comment: Maybe prefix them with __triton_xla.
```cpp
auto value = operands[1];
mlir::Value mask = operands.size() > 2 ? operands[2] : mlir::Value{};

std::string syncscope = ParseSyncScope(callee_name);
```
```cpp
// Exit block: phi node to select result
exit_block->addArgument(i32_type, loc);
call_op.replaceAllUsesWith(
    mlir::ValueRange{exit_block->getArgument(0)});
```

Comment: Is the result expected to be used? If not, poison it.
📝 Summary of Changes

This PR enables the ROCm Triton backend for AllReduce (collective emitter). To this end, it:

- Adds new passes lowering `triton_xla.atomic_write`, `triton_xla.atomic_spin_wait`, and `triton_xla.get_tid`. These passes rely on `extern_elementwise` Triton operations, thereby avoiding the use of target-specific inline assembly. The `extern_elementwise` ops are then caught later in the compiler pipeline and replaced by LLVM intrinsics.
- Implements `RocmExecutor::CanEnablePeerAccessTo(int other_device_ordinal)` (this API is required to enable the collective_emitter thunk).

🎯 Justification

Prior to this PR, the `triton_xla.get_tid`, `triton_xla.atomic_write`, and `triton_xla.atomic_spin_wait` operations were lowered using PTX assembly; the AllReduce Triton backend was therefore only available for the CUDA target. This PR adds a new way to lower these operations using only Triton operations via `extern_elementwise`. Thanks to that, the Triton backend for AllReduce is now available for the ROCm target.

🚀 Kind of Contribution

✨ New Feature

🧪 Unit Tests

This PR includes a LIT test checking the lowering of atomic operations.