
Enable ROCm Triton backend for AllReduce#684

Draft
mfrancepillois wants to merge 4 commits into rocm-jaxlib-v0.9.1 from ci_maxime_allreduce_triton_rocm_elementwise_rocm

Conversation

@mfrancepillois

📝 Summary of Changes
This PR enables the ROCm Triton backend for AllReduce (collective emitter).
To this end:

  • Adds new passes to lower the atomic operations triton_xla.atomic_write, triton_xla.atomic_spin_wait, and triton_xla.get_tid. These passes rely on Triton extern_elementwise operations, thereby avoiding the use of target-specific inline assembly. The extern_elementwise ops are caught later in the compiler pipeline and replaced by LLVM intrinsics.
  • Adds the missing API RocmExecutor::CanEnablePeerAccessTo(int other_device_ordinal) to rocm_executor (this API is required to enable the collective_emitter thunk).

🎯 Justification
Prior to this PR, the triton_xla.get_tid, triton_xla.atomic_write, and triton_xla.atomic_spin_wait operations were lowered using PTX assembly, so the AllReduce Triton backend was only available for the CUDA target.
This PR adds a new way to lower these operations using only Triton operations (extern_elementwise).
As a result, the Triton backend for AllReduce is now available for the ROCm target.

🚀 Kind of Contribution
✨ New Feature

🧪 Unit Tests:
This PR includes a LIT test checking the lowering of atomic operations.

@mfrancepillois force-pushed the ci_maxime_allreduce_triton_rocm_elementwise_rocm branch from 08ee4a7 to 1daa0d9 (March 18, 2026 14:51)
@claude

claude bot commented Mar 18, 2026

Claude Code Review Summary

This PR enables the ROCm Triton backend for AllReduce collective operations via tt.extern_elementwise ops and a ROCm-specific LLVM IR implementation pass. The overall architecture — two-stage lowering (high-level ops → extern calls → platform-specific LLVM atomics) — is clean and extensible.

Key issues found (see inline comments for details):

  • Masks silently ignored in both LowerAtomicWriteOp and LowerAtomicSpinWaitOp — the CUDA pass handles masks via predicated stores/waits, but the extern lowering discards them entirely. The spin wait case can cause GPU hangs.
  • Signed vs unsigned comparison mismatch — the ROCm spin wait uses slt (signed) while the CUDA PTX uses setp.lt.u32 (unsigned), which can produce incorrect results for large signal values.
  • Syncscope "one-as" for system scope may not provide cross-device visibility guarantees on AMDGPU — needs verification against the AMDGPU memory model docs.
  • No validation of memory semantics — the CUDA pass rejects unsupported semantics (e.g., ACQUIRE for writes); this pass silently accepts them.
  • default: return "unknown" in enum-to-string switches will generate unmatchable function names instead of failing early.
  • Minor: duplicated scope-parsing logic, fragile substring matching, misplaced #include.
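The signed-vs-unsigned finding above can be seen in a tiny standalone sketch (the helper names are hypothetical, not code from the patch): once a monotonically increasing signal counter crosses 2^31, a signed comparison inverts its answer relative to the unsigned one.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helpers modeling the two lowerings: a signed `slt`-style
// compare versus CUDA's unsigned `setp.lt.u32`-style compare.
bool LessThanSigned(uint32_t a, uint32_t b) {
  return static_cast<int32_t>(a) < static_cast<int32_t>(b);
}

bool LessThanUnsigned(uint32_t a, uint32_t b) { return a < b; }
```

For a signal value of 0x80000000 (2^31), the unsigned compare correctly reports 1 < 2^31, while the signed compare sees -2^31 on the right-hand side and reports the opposite, so a spin wait built on the signed compare could exit early or never terminate.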

Missing test coverage for TritonXLAImplementExternAtomicsROCmPass and TritonXLALowerExternGetTidPass.

@claude

claude bot commented Mar 18, 2026

Re-review Summary

Re-reviewed the latest diff. All 8 previously flagged issues remain unaddressed — no new findings. Key open items:

  • Masks silently discarded in both LowerAtomicWriteOp and LowerAtomicSpinWaitOp (risk of GPU hang / memory corruption)
  • Signed vs unsigned comparison mismatch with CUDA path in spin-wait
  • Incorrect AMDGPU syncscopes ("one-as" alone likely insufficient for system scope)
  • No semantic validation and hardcoded atomic orderings ignoring function-name-encoded semantics
  • Fragile substring matching and duplicated scope-parsing logic
  • Missing test coverage for TritonXLAImplementExternAtomicsROCmPass and TritonXLALowerExternGetTidPass
  • Include ordering issue in rocm_executor.cc

See existing inline comments for details.

🤖 Generated with Claude Code

@mfrancepillois force-pushed the ci_maxime_allreduce_triton_rocm_elementwise_rocm branch from a866b69 to e885196 (March 18, 2026 16:00)
@claude

claude bot commented Mar 18, 2026

Re-review Summary

Excellent progress — 7 of 8 previous findings have been addressed in this revision:

  • Mask handling: Both atomic_write and atomic_spin_wait now properly pass and handle masks
  • Exhaustive switches: default cases replaced with exhaustive switch + LOG(FATAL)
  • Semantic validation: Unsupported semantics now rejected via notifyMatchFailure
  • Sync scope: System scope corrected to empty string per AMDGPU convention
  • Unsigned comparison: Spin-wait comparator now uses ICmpPredicate::ult matching CUDA
  • Helper extraction: ParseSyncScope extracted; EndsWithComparator used for suffix matching
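The "exhaustive switch + LOG(FATAL)" fix above can be illustrated with a minimal sketch (enum and function names are hypothetical): with no default case, adding a new enumerator produces a compiler warning instead of silently emitting an unmatchable "unknown" function-name suffix.

```cpp
#include <cassert>
#include <cstdlib>
#include <string>

// Hypothetical memory-semantic enum as might be used in such a pass.
enum class MemSemantic { kRelaxed, kAcquire, kRelease };

std::string SemanticSuffix(MemSemantic s) {
  switch (s) {  // exhaustive: no default, so new enumerators trigger a warning
    case MemSemantic::kRelaxed:
      return "relaxed";
    case MemSemantic::kAcquire:
      return "acquire";
    case MemSemantic::kRelease:
      return "release";
  }
  std::abort();  // stand-in for LOG(FATAL) in the real pass
}
```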

1 minor item remaining: #include <unistd.h> placement in rocm_executor.cc (between ROCm and XLA headers — should be in the system headers group).

Overall this is looking solid. The two-stage lowering design is clean and the ROCm atomics implementation is well-structured.

🤖 Generated with Claude Code

@claude

claude bot commented Mar 18, 2026

Re-review Summary

Good progress since the last review — 6 of 8 previous findings have been addressed (mask handling, validation, syncscope, signed/unsigned comparison, scope parsing). The include ordering issue in rocm_executor.cc is still present.

New findings (see inline comments):

  • Bug (high): SSA domination issue in masked and unmasked spin-wait — replaceAllUsesWith references a value not defined on all control-flow paths (triton_xla_implement_extern_atomics_rocm_pass.cc:213)
  • Bug (high): Scalar atomic write path silently discards the mask operand (triton_xla_lower_extern_atomics_pass.cc:162-169)
  • Nit: Memory ordering hardcoded to release/acquire even when relaxed was requested — safe but suboptimal
  • Nit: Substring matching in ParseSyncScope is fragile
  • Nit: Mask detection by operand count without documentation
  • Testing gap: No tests for masked operations, ROCm implementation pass, or GetTid pass

🤖 Generated with Claude Code

@claude

claude bot commented Mar 18, 2026

Re-review Summary

Re-reviewed the latest diff. All previously flagged issues remain applicable — no new findings. Key outstanding items from prior review:

  • SSA dominance bug in spin-wait loop (triton_xla_implement_extern_atomics_rocm_pass.cc:213,248) — loaded value defined inside loop block used after exit
  • Hardcoded release ordering ignoring _relaxed_ semantic in function names
  • Missing test coverage for the ROCm-specific atomics pass and the get_tid pass

🤖 Generated with Claude Code

@claude

claude bot commented Mar 18, 2026

Re-review Summary (commit 90ec42c)

All previously flagged issues have been addressed in this revision. Key fixes verified:

  • Mask handling: Scalar and vectorized mask paths now handled correctly with proper operand count checks
  • Memory ordering: Parsed from function names instead of hardcoded
  • SSA domination: Fixed via block arguments
  • Sync scope parsing: Uses robust field extraction
  • Unsigned comparison: Spin-wait loop now uses unsigned icmp as required
  • Test coverage: Comprehensive tests added for masked/unmasked operations in both lowering and ROCm implementation passes

No new issues found. One minor prior note (include ordering of <unistd.h> in rocm_executor.cc) remains outstanding from the initial review.

🤖 Generated with Claude Code

#if defined(TENSORFLOW_USE_ROCM)
// ROCm: Use constant value directly as ROCDL dialect doesn't define memory
// space enum
static constexpr int32_t kGlobalAddressSpace = 1;


bool is_supported = false;

// CUDA: Requires compute capability 9.0+ (Hopper or newer)
if (device_info.cuda_compute_capability().major >= 9) {


Where did this check exist before this change?

Author


// Add XLA custom pass to implement extern_elementwise atomic functions
// This must run after MLIR->LLVM conversion but before final optimizations
pm->addPass(
mlir::triton::xla::CreateTritonXLAImplementExternAtomicsROCmPass());


Use it for cuda too. Should be fairly late in the pipeline though.

Author


I would prefer to merge the ROCm support first and then extend it to CUDA, as we do not really have enough NVIDIA GPU versions to make sure that this change will not affect NVIDIA performance.


I literally have to check one, because of the PTX. Just to see the lowering.


Just to see the ptx.

pm->addPass(mt::createTritonAMDGPULowerInstructionSchedHintsPass(
rocm_cc.gfx_version(), num_stages));
}
pm->addPass(mt::createConvertBuiltinFuncToLLVMPass(rocm_cc.gfx_version(),


Here?

Author


Why would we want to delay the pass until this point?
This type of external call is usually lowered in the ConvertTritonAMDGPUToLLVMPass pass, i.e. earlier in the pipeline. I would have been more inclined to put the pass early in the pipeline, right after ConvertTritonAMDGPUToLLVMPass.


ConvertBuiltinFuncToLLVMPass serves the same purpose as yours, and it is the last one, so I presume it is better not to expand your calls earlier, which may give other passes a chance to mess with them in expanded form.

#include "rocm/include/hip/hip_runtime.h"
#include "rocm/include/hip/hip_version.h"
#include "rocm/rocm_config.h"
#include <unistd.h>


Avoid drive-by formatting changes if possible.

Author


I tried to avoid it, but otherwise the clang-format check fails


Hmm. Odd. Must be a local thing. We might have pushed things without checking the CI.


// Check peer access capability and cache the result
TF_ASSIGN_OR_RETURN(hipDevice_t peer_device, GetDevice(i));
peer_access_cache_[i] = CanEnablePeerAccess(device_, peer_device);


I don't think there is need to cache this. Leave it for the separate effort.

Author


This mimics the CUDA implementation. And since this function is called every time a Collective Emitter thunk is launched, I think it’s better to use a cache rather than calling the API each time.


Oh, I see. I guess they have some good reason why they don't cache the value. This used to get called only on client creation and on thunk prepare later on. Maybe just use std::vector. You can retrofit the other overloads of CanEnablePeerAccessTo to use it while you are at it. It might still be better to split it from the overall change.
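The caching pattern under discussion can be sketched standalone (class and method names are hypothetical, not the PR's code): memoize the potentially expensive peer-access query per device ordinal so it runs once per peer rather than on every thunk launch.

```cpp
#include <optional>
#include <vector>

// Hypothetical sketch of per-ordinal memoization of a peer-access query.
class PeerAccessCache {
 public:
  explicit PeerAccessCache(int num_devices) : cache_(num_devices) {}

  bool CanEnablePeerAccessTo(int ordinal) {
    if (!cache_[ordinal].has_value()) {
      cache_[ordinal] = QueryPeerAccess(ordinal);  // e.g. a driver API call
      ++queries_;
    }
    return *cache_[ordinal];
  }

  int queries() const { return queries_; }

 private:
  // Stub standing in for the real (expensive) driver query.
  bool QueryPeerAccess(int ordinal) { return ordinal % 2 == 0; }

  std::vector<std::optional<bool>> cache_;
  int queries_ = 0;
};
```

A std::vector indexed by ordinal (as suggested above) suffices here because device ordinals are small and dense; repeated lookups hit the cached value.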

// Check if ptr is a tensor type (vectorized operation)
auto tensor_type = mlir::dyn_cast<mlir::RankedTensorType>(ptr.getType());

if (tensor_type) {


Maybe unify both paths: split into argument preparation plus one extern_elementwise emit at the end.

mlir::RankedTensorType::get(tensor_type.getShape(), elem_type);

// Create expected value tensor (splat scalar to tensor)
auto expected_tensor =


Is there no need to handle the scalar case, as above?

namespace {

// Helper to parse syncscope from function name
// Function names follow pattern: xla_atomic_*_<semantic>_<scope>[_<comparator>]


comparator is not optional, nor is scope. Just turn it into std::string::contains.


// Replace each call inline
for (auto call_op : calls_to_replace) {
std::string callee_name = call_op.getCallee()->str();


Keep the StringRef


// Atomic block: perform atomic exchange
builder.setInsertionPointToStart(atomic_block);
auto atomic_xchg = LLVM::AtomicRMWOp::create(


We expect the result not to be used. If it is, it is a bug.

// Loop block: spin wait
builder.setInsertionPointToStart(loop_block);
auto loaded = LLVM::LoadOp::create(
builder, loc, i32_type, addr, 4, false, false, false, false,


Is it an atomic load? It is hard to follow. Maybe inline-comment (/* arg_name */ false) each of the bool args.


NVM. I see it from the test, but still add the comment.
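The semantics this spin-wait loop is lowering can be sketched in plain C++ (a hypothetical analogue, not the MLIR under review): repeatedly load with acquire ordering until the flag reaches the expected value, using an unsigned comparison, and hand the observed value out of the loop, much as the exit block's argument does.

```cpp
#include <atomic>
#include <cstdint>

// Hypothetical analogue of the spin wait: acquire loads in a loop with an
// unsigned "less than" compare; returns the value that satisfied the wait.
uint32_t SpinWaitGE(const std::atomic<uint32_t>& flag, uint32_t expected) {
  uint32_t observed;
  while ((observed = flag.load(std::memory_order_acquire)) < expected) {
    // busy wait until the flag catches up
  }
  return observed;  // analogous to the exit block's argument
}
```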

mlir::ValueRange{exit_block->getArgument(0)});
call_op.erase();
} else {
// Unmasked spin wait: direct loop


Can you unify these two paths? Maybe via a lambda that you can call in both places.


// Clean up unused extern function declarations
llvm::SmallVector<LLVM::LLVMFuncOp> to_erase;
module.walk([&](LLVM::LLVMFuncOp func) {


Nice touch!

// Function names follow pattern: xla_atomic_*_<semantic>_<scope>[_<comparator>]
std::string ParseSyncScope(const std::string& func_name) {
// Per AMDGPU memory model (Table 31):
// - "" (empty) = system scope (cross-device visibility)


Same for nvptx backend.
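A simplified standalone sketch of this kind of name-to-syncscope mapping (the function and the scope keywords are hypothetical, and it assumes the scope is the trailing underscore-delimited field, whereas the real pattern allows a trailing comparator):

```cpp
#include <cassert>
#include <string>

// Hypothetical sketch: take the trailing field of a mangled name like
// "xla_atomic_write_release_system" and map it to an AMDGPU syncscope
// string, where "" (empty) denotes system scope per the AMDGPU memory model.
std::string ParseSyncScopeSketch(const std::string& func_name) {
  const std::string field = func_name.substr(func_name.rfind('_') + 1);
  if (field == "system") return "";        // cross-device visibility
  if (field == "gpu") return "agent";      // single-device scope
  if (field == "cta") return "workgroup";  // block scope
  return "";  // default to the widest (system) scope
}
```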

// CHECK-NOT: llvm.call @xla_get_thread_id
// CHECK: [[TID:%.*]] = llvm.call_intrinsic "llvm.amdgcn.workitem.id.x"() : () -> i32
// CHECK: llvm.return [[TID]]
%tid = llvm.call @xla_get_thread_id() : () -> i32


Maybe prefix them with __triton_xla.

auto value = operands[1];
mlir::Value mask = operands.size() > 2 ? operands[2] : mlir::Value{};

std::string syncscope = ParseSyncScope(callee_name);


StringRefs

// Exit block: phi node to select result
exit_block->addArgument(i32_type, loc);
call_op.replaceAllUsesWith(
mlir::ValueRange{exit_block->getArgument(0)});


Is the result expected to be used? If not, poison it.
