Backport from upstream to 0.9.1 branch#672

Open
hsharsha wants to merge 13 commits into rocm-jaxlib-v0.9.1 from v0.9.1-test-ci

Conversation

@hsharsha
Collaborator

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

@claude

claude bot commented Mar 17, 2026

Claude Review Summary

This PR changes only README.md, downgrading the H1 heading to H2. Since it is a dummy commit to trigger CI, it should not be merged as-is — see inline comment for details and a suggested fix.

zoranjovanovic-ns and others added 8 commits March 18, 2026 11:01
Imported from GitHub PR openxla#38507

📝 Summary of Changes
Updated functions from triton/support.c and triton/support_test to execute
correctly on ROCm.

🎯 Justification
support_test was failing on ROCm.
🚀 Kind of Contribution
Please remove what does not apply: 🐛 Bug Fix, ⚡️ Performance Improvement,
✨ New Feature, ♻️ Cleanup, 📚 Documentation, 🧪 Tests

📊 Benchmark (for Performance Improvements)
Please measure and include speedups for one of the public HLOs in
`compiler/xla/tools/benchmarks/hlo/`.

🧪 Unit Tests:
Used existing triton/support_test

🧪 Execution Tests:
What execution tests were added? For example, a new optimization should be
tested with an end-to-end execution test triggering the optimization and
asserting correctness. Please provide test cases running with at most 2 GPUs.

Copybara import of the project:

--
e067431 by zoranjovanovic-ns <126815388+zoranjovanovic-ns@users.noreply.github.com>:

Fixed triton_support_test on rocm.

--
16f1f07 by Zoran Jovanovic <zjovanov@amd.com>:

Review comments.

--
731fb44 by Zoran Jovanovic <zjovanov@amd.com>:

Code review 2

Merging this change closes openxla#38507

COPYBARA_INTEGRATE_REVIEW=openxla#38507 from ROCm:ci_rocm-fix-triton-support-4 731fb44
PiperOrigin-RevId: 877900520
…evice_test for ROCm.

Imported from GitHub PR openxla#38742

📝 Summary of Changes
Created expected output for FuseSubchannelDequantizationWithTranspose in triton/fusion_emitter_int4_device_test on ROCm.

🎯 Justification
triton/fusion_emitter_int4_device_test was failing on ROCm.

🚀 Kind of Contribution
Please remove what does not apply: 🧪 Tests

📊 Benchmark (for Performance Improvements)
Please measure and include speedups for one of the public HLOs in
`compiler/xla/tools/benchmarks/hlo/`.

🧪 Unit Tests:
triton/fusion_emitter_int4_device_test

🧪 Execution Tests:
What execution tests were added? For example, a new optimization should be
tested with an end-to-end execution test triggering the optimization and
asserting correctness. Please provide test cases running with at most 2 GPUs.

Copybara import of the project:

--
d00e6d1 by Zoran Jovanovic <zjovanov@amd.com>:

Fix expected output in fusion_emitter_int4_device_test for ROCm.

Merging this change closes openxla#38742

COPYBARA_INTEGRATE_REVIEW=openxla#38742 from ROCm:rocm-fusion_emitter_int4_device_test d00e6d1
PiperOrigin-RevId: 880942122
…u in gpu_triton_cu…

Imported from GitHub PR openxla#38801

…stom_call_test for ROCm

📝 Summary of Changes
skipped CanNotEmitTritonCustomCallOnPreAmpereGpu in gpu_triton_custom_call_test for ROCm

🎯 Justification
The unit test was failing because Triton custom calls do work on ROCm.

🚀 Kind of Contribution
Please remove what does not apply:  🧪 Tests

📊 Benchmark (for Performance Improvements)
Please measure and include speedups for one of the public HLOs in
`compiler/xla/tools/benchmarks/hlo/`.

🧪 Unit Tests:
gpu_triton_custom_call_test

🧪 Execution Tests:
What execution tests were added? For example, a new optimization should be
tested with an end-to-end execution test triggering the optimization and
asserting correctness. Please provide test cases running with at most 2 GPUs.

Copybara import of the project:

--
6cf15ac by Zoran Jovanovic <zjovanov@amd.com>:

[ROCm] Skip CanNotEmitTritonCustomCallOnPreAmpereGpu in gpu_triton_custom_call_test for ROCm

Merging this change closes openxla#38801

COPYBARA_INTEGRATE_REVIEW=openxla#38801 from ROCm:rocm-fix-gpu_triton_custom_call_test 6cf15ac
PiperOrigin-RevId: 881953512
…egacy and test itself.

Imported from GitHub PR openxla#38759

📝 Summary of Changes
Modified IsDotAlgorithmSupportedByTriton to reflect implementation on ROCm.

🎯 Justification
triton/dot_algorithms_test was failing

🚀 Kind of Contribution
Please remove what does not apply: 🐛 Bug Fix,  🧪 Tests

📊 Benchmark (for Performance Improvements)
Please measure and include speedups for one of the public HLOs in
`compiler/xla/tools/benchmarks/hlo/`.

🧪 Unit Tests:
triton/dot_algorithms_test

🧪 Execution Tests:
What execution tests were added? For example, a new optimization should be
tested with an end-to-end execution test triggering the optimization and
asserting correctness. Please provide test cases running with at most 2 GPUs.

Copybara import of the project:

--
757877c by Zoran Jovanovic <zjovanov@amd.com>:

Fixed dot_algorithms_test. Updated support_legacy and test itself.

--
5685c9c by Zoran Jovanovic <zjovanov@amd.com>:

Review comments.

--
d801633 by Zoran Jovanovic <zjovanov@amd.com>:

Review comment.

Merging this change closes openxla#38759

COPYBARA_INTEGRATE_REVIEW=openxla#38759 from ROCm:rocm-fix-dot_algorithms_test d801633
PiperOrigin-RevId: 885489190
@hsharsha changed the title from "Dummy commit to trigger CI" to "Backport from upstream to 0.9.1 branch" on Mar 18, 2026
Comment on lines 195 to +198
llvm::SmallVector<Value, 2> casted_inputs;
if (output_type_is_16bit_float) {
if (output_type_is_16bit_float &&
(output_type.isF16() &&
!HasF16Implementation(OpInfo<OpTy>::kFunctionID, triple_))) {

bug: The condition output_type_is_16bit_float && (output_type.isF16() && !HasF16Implementation(...)) means BF16 inputs will never be upcasted to F32 on any platform anymore. Before this change, the condition was just output_type_is_16bit_float, so both BF16 and F16 were upcasted to F32 before calling libdevice functions. Now, BF16 will be passed directly to the extern function call without upcast, but libdevice does not have native BF16 implementations -- the function name will still resolve to _f32 via ObtainDeviceFunctionName, causing a type mismatch between the BF16 operands and the F32 function signature.

The same issue applies to the downcast condition on line 225-227: (res.getType() != original_output_type) || (output_type.isF16() && !HasF16Implementation(...)) -- for BF16, res.getType() will equal original_output_type (both BF16), so no truncation happens even though the extern call expects/returns F32.

Consider restructuring the upcast condition to something like:

if (output_type_is_16bit_float &&
    !(output_type.isF16() &&
      HasF16Implementation(OpInfo<OpTy>::kFunctionID, triple_))) {

This preserves the existing BF16 upcast behavior while skipping it only for F16 ops that have native F16 implementations on AMD.

(gpu_version.IsRocm() &&
reduce.shape().element_type() == PrimitiveType::F8E4M3FNUZ)) {
return CodegenDecision::Forbid(
"F8E4M3FN and F8E5M2 are not supported for reductions.");

nit: This error message is now stale -- the condition also rejects F8E5M2FNUZ and F8E4M3FNUZ on ROCm, but the message only mentions F8E4M3FN and F8E5M2. Consider updating to include the FNUZ variants so the error is not misleading when debugging on ROCm.

Suggested change
"F8E4M3FN and F8E5M2 are not supported for reductions.");
"F8 types are not supported for reductions.");

return CodegenDecision::Forbid(
"Unsupported GPU architecture for F8E4M3FN/F8E5M2.");
case F8E4M3FNUZ:
"Unsupported GPU architecture for F8E4M3FN.");

nit: This error message mentions only F8E4M3FN, but the case block covers both F8E5M2 and F8E4M3FN. Consider updating it to mention both types.

Suggested change
"Unsupported GPU architecture for F8E4M3FN.");
"Unsupported GPU architecture for F8E4M3FN/F8E5M2.");


bool HasF16Implementation(TargetDeviceFunctionID func_id,
llvm::Triple target_triple) {
if(target_triple.isAMDGPU() &&

nit: Missing space between if and ( -- should be if (target_triple.isAMDGPU() per the project's coding style.

Suggested change
if(target_triple.isAMDGPU() &&
if (target_triple.isAMDGPU() &&

}

TEST(CommandBufferConversionPassTest, ConvertWhileThunk) {
if(IsRocm()) {

nit: Missing space between if and ( -- should be if (IsRocm()). Same issue at line 746.

Suggested change
if(IsRocm()) {
if (IsRocm()) {

Comment on lines 195 to +198
llvm::SmallVector<Value, 2> casted_inputs;
if (output_type_is_16bit_float) {
if (output_type_is_16bit_float &&
(output_type.isF16() &&
!HasF16Implementation(OpInfo<OpTy>::kFunctionID, triple_))) {

Bug: BF16 libdevice calls broken by this change.

The new condition:

output_type_is_16bit_float && (output_type.isF16() && !HasF16Implementation(...))

is logically equivalent to:

output_type.isF16() && !HasF16Implementation(...)

because isF16() already implies output_type_is_16bit_float. This means BF16 operands are no longer upcasted to F32 before calling libdevice functions.

However, libdevice/OCML functions don't have native BF16 variants — ObtainDeviceFunctionName will still return an _f32 function name for BF16 inputs. The result is a type mismatch: BF16 operands passed to an F32 function signature, which will produce incorrect results or crash.

The same issue affects the downcast path on lines 225–227 — BF16 outputs are no longer truncated back from F32.

Previously, all 16-bit float types (both F16 and BF16) were upcasted to F32, which was correct for BF16. The fix for F16-on-AMDGPU should preserve the BF16→F32 upcast. Suggested fix:

if (output_type_is_16bit_float &&
    !(output_type.isF16() && HasF16Implementation(...))) {

This upcasts BF16 always, and upcasts F16 only when there's no native F16 implementation.

Member

This is covered by #660

Comment on lines 316 to +322
const HloReduceInstruction& reduce,
const se::GpuComputeCapability& gpu_version) {
if (reduce.shape().element_type() == PrimitiveType::F8E4M3FN ||
reduce.shape().element_type() == PrimitiveType::F8E5M2) {
reduce.shape().element_type() == PrimitiveType::F8E5M2 ||
(gpu_version.IsRocm() &&
reduce.shape().element_type() == PrimitiveType::F8E5M2FNUZ) ||
(gpu_version.IsRocm() &&

nit: The error message says "F8E4M3FN and F8E5M2 are not supported for reductions" but the condition now also rejects F8E5M2FNUZ and F8E4M3FNUZ. Please update the message to mention the FNUZ types as well for accurate debugging.

Suggested change
const HloReduceInstruction& reduce,
const se::GpuComputeCapability& gpu_version) {
if (reduce.shape().element_type() == PrimitiveType::F8E4M3FN ||
reduce.shape().element_type() == PrimitiveType::F8E5M2) {
reduce.shape().element_type() == PrimitiveType::F8E5M2 ||
(gpu_version.IsRocm() &&
reduce.shape().element_type() == PrimitiveType::F8E5M2FNUZ) ||
(gpu_version.IsRocm() &&
if (primitive_util::IsFloat8Type(type)) {
return CodegenDecision::Forbid(
"F8E4M3FN, F8E5M2, F8E4M3FNUZ, and F8E5M2FNUZ are not supported "
"for reductions.");
}


bool HasF16Implementation(TargetDeviceFunctionID func_id,
llvm::Triple target_triple) {
if(target_triple.isAMDGPU() &&

nit: Missing space after if keyword — should be if (target_triple.isAMDGPU() per Google C++ style.

}

TEST(CommandBufferConversionPassTest, ConvertWhileThunk) {
if(IsRocm()) {

nit: Missing space after if keyword — if(IsRocm())if (IsRocm()). Same issue at line 746.

Comment on lines +1627 to +1638
case PC::ALG_DOT_BF16_BF16_F32_X9:
if (GpuComputeComp().IsRocm()) {
// X6 and X9 algorithms on ROCm marked as not supported
// because they often require too much shared memory.
EXPECT_FALSE(result_or_status.value())
<< "algorithms not supported on ROCm";
} else {
ASSERT_TRUE(result_or_status.status().ok())
<< "failed to compile " << algorithm_;
EXPECT_TRUE(result_or_status.value())
<< "wrong result for " << algorithm_;
}

Bug: Missing status assertion in ROCm branch.

The CUDA path correctly does ASSERT_TRUE(result_or_status.status().ok()) before calling .value(), but the ROCm branch goes straight to EXPECT_FALSE(result_or_status.value()). If result_or_status holds a non-OK status, calling .value() will crash.

Please add the status assertion:

#if TENSORFLOW_USE_ROCM
    ASSERT_TRUE(result_or_status.status().ok());
    EXPECT_FALSE(result_or_status.value());
#else

Comment on lines 759 to +764
return CodegenDecision(
element_type != PrimitiveType::F8E4M3FN &&
element_type != PrimitiveType::F8E5M2 &&
!(gpu_version.IsRocm() &&
element_type == PrimitiveType::F8E4M3FNUZ) &&
!(gpu_version.IsRocm() &&

nit: Same stale error message issue — the condition now blocks FNUZ types too, but the message only mentions "F8E4M3FN, F8E5M2 and S4". Please add the FNUZ types to the message.

Comment on lines 324 to 325
return CodegenDecision::Forbid(
"F8E4M3FN and F8E5M2 are not supported for reductions.");

nit: The error message still reads "F8E4M3FN and F8E5M2 are not supported for reductions." but the condition now also forbids F8E5M2FNUZ and F8E4M3FNUZ on ROCm. Consider updating the message to mention the FNUZ variants so it is accurate when triggered on ROCm.

Suggested change
return CodegenDecision::Forbid(
"F8E4M3FN and F8E5M2 are not supported for reductions.");
return CodegenDecision::Forbid(
"F8E4M3FN, F8E5M2, F8E4M3FNUZ and F8E5M2FNUZ are not supported for "
"reductions.");

if (GpuComputeComp().IsRocm()) {
// X6 and X9 algorithms on ROCm marked as not supported
// because they often require too much shared memory.
EXPECT_FALSE(result_or_status.value())

bug: On the ROCm path you call result_or_status.value() without first asserting that result_or_status.status().ok(). If the status is not OK, .value() will either crash or produce a confusing test failure.

The non-ROCm branch correctly guards with ASSERT_TRUE(result_or_status.status().ok()) before calling .value(). Consider adding the same guard here:

      if (GpuComputeComp().IsRocm()) {
        // X6 and X9 algorithms on ROCm marked as not supported
        // because they often require too much shared memory.
        ASSERT_TRUE(result_or_status.status().ok())
            << "failed to compile " << algorithm_;
        EXPECT_FALSE(result_or_status.value())
            << "algorithms not supported on ROCm";

Comment on lines +228 to +239
case F8E5M2FNUZ:
if (gpu_cc.has_value() && !gpu_cc.value().IsRocm()) {
return absl::UnimplementedError(
"F8E5M2FNUZ is not supported on this GPU.");
}
return b.getType<mlir::Float8E5M2FNUZType>();
case F8E4M3FNUZ:
if (gpu_cc.has_value() && !gpu_cc.value().IsRocm()) {
return absl::UnimplementedError(
"F8E4M3FNUZ is not supported on this GPU.");
}
return b.getType<mlir::Float8E4M3FNUZType>();

Default-allow concern for FNUZ types. When gpu_cc is std::nullopt (the default), FNUZ types are unconditionally allowed. This means callers that don't pass a gpu_cc (e.g., tests, other XTile usages) will silently accept FNUZ types even on CUDA, where they're unsupported. Consider inverting the logic to require explicit ROCm opt-in rather than only forbidding when explicitly non-ROCm.

Comment on lines +196 to +197
if (output_type_is_16bit_float &&
(output_type.isF16() &&

nit: output_type_is_16bit_float is redundant here -- if output_type.isF16() is true, then output_type_is_16bit_float is necessarily also true (F16 is a 16-bit float). The outer check can be dropped to simplify the condition:

Suggested change
if (output_type_is_16bit_float &&
(output_type.isF16() &&
if (output_type.isF16() &&


bool HasF16Implementation(TargetDeviceFunctionID func_id,
llvm::Triple target_triple) {
if(target_triple.isAMDGPU() &&

style: Missing space between if and ( -- the rest of the codebase uses if (...).

Suggested change
if(target_triple.isAMDGPU() &&
if (target_triple.isAMDGPU() &&

}

TEST(CommandBufferConversionPassTest, ConvertWhileThunk) {
if(IsRocm()) {

style: Missing space between if and ( -- same issue on line 702.

Suggested change
if(IsRocm()) {
if (IsRocm()) {

case F8E5M2:
case F8E4M3FN:
if (gpu_version.IsCuda()) {
if (gpu_version.IsCuda() || gpu_version.IsRocm()) {

question: The old code required has_ocp_fp8_support() for F8E5M2/F8E4M3FN on ROCm, but the refactored check now unconditionally allows these types for any ROCm device. Was the capability check intentionally dropped, or should older ROCm devices without OCP FP8 support still be gated?

@claude

claude bot commented Mar 18, 2026

Re-review Summary (Claude)

Previous finding on README.md H1→H2 downgrade is resolved (no longer in diff).

New Findings

2 bugs:

  • BF16 libdevice regression (triton_xla_math_to_libdevice.cc): The upcast/downcast condition change accidentally drops BF16→F32 conversion, causing type mismatches with libdevice functions that only have F32 variants.
  • Missing status assertion (dot_algorithms_test.cc): ROCm branch calls .value() without first asserting .status().ok(), risking a crash on non-OK status.

3 nits:

  • Stale error messages in support.cc — FNUZ types now rejected but not mentioned in messages (2 locations).
  • Missing space after if keyword in target_util.cc and command_buffer_conversion_pass_test.cc.

2 design considerations:

  • ALG_DOT_BF16_BF16_F32_X6/X9 allowed in new path but rejected in legacy path — potential inconsistency.
  • FNUZ types default-allowed in emitter_helpers.cc when gpu_cc is nullopt — could silently accept on CUDA.

See inline comments for details and suggested fixes.

🤖 Generated with Claude Code

Comment on lines 324 to 325
return CodegenDecision::Forbid(
"F8E4M3FN and F8E5M2 are not supported for reductions.");

nit: The condition now also forbids F8E5M2FNUZ and F8E4M3FNUZ on ROCm (line 320–321), but the error message still says only "F8E4M3FN and F8E5M2 are not supported for reductions." The same issue exists for the iota error message around line 767. Consider updating both messages to mention the FNUZ variants so the diagnostic is accurate for ROCm users.

Suggested change
return CodegenDecision::Forbid(
"F8E4M3FN and F8E5M2 are not supported for reductions.");
return FailedPrecondition(
"F8E4M3FN, F8E5M2, F8E5M2FNUZ, and F8E4M3FNUZ are not supported for "

if (GpuComputeComp().IsRocm()) {
// X6 and X9 algorithms on ROCm marked as not supported
// because they often require too much shared memory.
EXPECT_FALSE(result_or_status.value())

bug: The ROCm branch calls result_or_status.value() without first asserting result_or_status.status().ok(). The CUDA branch (line 1625) does ASSERT_TRUE(result_or_status.status().ok()) before accessing .value(). If the status is not OK, .value() will crash or produce UB. Consider adding the same assertion here:

Suggested change
EXPECT_FALSE(result_or_status.value())
ASSERT_TRUE(result_or_status.status().ok());
EXPECT_FALSE(result_or_status.value());

Comment on lines +196 to +197
if (output_type_is_16bit_float &&
(output_type.isF16() &&

question/bug: The condition output_type_is_16bit_float && (output_type.isF16() && !HasF16Implementation(...)) is equivalent to output_type.isF16() && !HasF16Implementation(...) — the output_type_is_16bit_float check is redundant since isF16() implies it.

More importantly, this means BF16 inputs are never upcast to F32 anymore (on any target), since the outer && short-circuits when isF16() is false. Previously, all 16-bit floats (including BF16) were upcast to F32. If this is intentional for BF16, please add a comment explaining why. Otherwise, the intended logic is likely:

if (output_type_is_16bit_float &&
    !(output_type.isF16() && HasF16Implementation(op_name, target_triple))) {

This would preserve the BF16→F32 upcast while skipping it only for F16 ops with native AMDGPU support.


bool HasF16Implementation(TargetDeviceFunctionID func_id,
llvm::Triple target_triple) {
if(target_triple.isAMDGPU() &&

nit: Missing space after if keyword — Google C++ style requires if ( not if(.

Suggested change
if(target_triple.isAMDGPU() &&
if (target_triple.isAMDGPU() &&

}

TEST(CommandBufferConversionPassTest, ConvertWhileThunk) {
if(IsRocm()) {

nit: Missing space after if keyword (same on line 702). Google C++ style requires if ( not if(.

Suggested change
if(IsRocm()) {
if (IsRocm()) {

@claude

claude bot commented Mar 18, 2026

Re-review Summary

Previous finding on the README.md H1→H2 change is now resolved (removed from PR).

New findings (6 inline comments):

  • Potential bug in triton_xla_math_to_libdevice.cc: upcast logic change unintentionally skips BF16→F32 upcast on all targets
  • Missing status assertion in dot_algorithms_test.cc ROCm branch before accessing .value()
  • Stale error message in support.cc — doesn't mention FNUZ types after condition was broadened
  • Question on removing has_ocp_fp8_support() / has_nanoo_fp8_support() capability checks for ROCm
  • Style nits: missing if ( spacing in target_util.cc and command_buffer_conversion_pass_test.cc

See inline comments for details and suggested fixes.

Eetusjo and others added 4 commits March 19, 2026 10:29
… autotuner

Imported from GitHub PR openxla#38792

Enable FissionBackend autotuning for ROCm (rocBLAS + hipBLASLt)

- Added HIPBLASLT_FISSION to backend proto
- Updated factory_rocm.cc to register the backends
- xla_gpu_experimental_disable_binary_libraries and xla_gpu_enable_cublaslt behavior mirrors CUDA

Also minor fix: removed dead return in fission_backend.cc

🚀 Kind of Contribution
✨ New Feature

Copybara import of the project:

--
ad88d8c by Eetu Sjöblom <eetu.sjoblom@amd.com>:

Enable rocblas/hipBLASLt fission on ROCm

Merging this change closes openxla#38792

COPYBARA_INTEGRATE_REVIEW=openxla#38792 from ROCm:ci_rocm_enable_fission_upstream ad88d8c
PiperOrigin-RevId: 884317138

(cherry picked from commit 77ecf53)
…ne_level == 0

Imported from GitHub PR openxla#37074

📝 Summary of Changes
Adapt GpuCompiler::AddConvAndGemmAutotuningPass to match the pre-refactor behavior of AMDGPUCompiler.

🎯 Justification
For ROCm we need to run the miopen backend even when autotuning is disabled, in order to decompose unsupported fused convolutions back. There is no runtime fallback.

🚀 Kind of Contribution
🐛 Bug Fix

📊 Benchmark (for Performance Improvements)
N/A

🧪 Unit Tests:
None

🧪 Execution Tests:
None

Copybara import of the project:

--
486498b by Dragan Mladjenovic <Dragan.Mladjenovic@amd.com>:

[ROCm] Reenable miopen autotune when xla_gpu_autotune_level == 0

--
feec478 by Dragan Mladjenovic <Dragan.Mladjenovic@amd.com>:

Fix //xla/backends/gpu/autotuner:miopen_test_amdgpu_any

Merging this change closes openxla#37074

COPYBARA_INTEGRATE_REVIEW=openxla#37074 from ROCm:miopen_autotune feec478
PiperOrigin-RevId: 877364967
Backport upstream fix that prevents a crash in the BlockPingpong
optimization when applied to FMA (non-MFMA) dot operations. The pass
now uses dyn_cast instead of cast for AMDMfmaEncodingAttr and returns
early when the encoding is not MFMA.
@zoranjovanovic-ns

Created PR with review comments from claude:
#698

@hsharsha
Collaborator Author

hsharsha commented Mar 20, 2026

Let us merge this. Verified locally on MI300 that sgpu and mgpu UTs pass

//xla/tools/hlo_opt:tests/list_passes.hlo.test                           PASSED in 0.8s
//xla/tools/hlo_opt:tests/run_pass_with_input.hlo.test                   PASSED in 0.5s

Executed 412 out of 428 tests: 428 tests pass.
There were tests whose specified size is too big. Use the --test_verbose_timeout_warnings command line option to see which ones these are.
//xla/service:reduce_scatter_reassociate_test                            PASSED in 0.1s
//xla/service:scatter_simplifier_test                                    PASSED in 0.2s
//xla/service:sharding_propagation_test                                  PASSED in 0.4s
//xla/service:sharding_remover_test                                      PASSED in 0.2s
//xla/tests:collective_ops_e2e_test_amdgpu_any                           PASSED in 107.6s

Executed 17 out of 17 tests: 17 tests pass.
There were tests whose specified size is too big. Use the --test_verbose_timeout_warnings command line option to see which ones these are.

@i-chaochen
Collaborator

@hsharsha wondering which script you used to test all UTs? Do we still include this stale filter list? https://github.com/ROCm/xla/blob/rocm-jaxlib-v0.9.1/build_tools/rocm/rocm_xla.bazelrc#L61-L78

I think this filter list is invalid now.
