Backport from upstream to 0.9.1 branch #672
hsharsha wants to merge 13 commits into rocm-jaxlib-v0.9.1 from
Conversation
Claude Review Summary

This PR changes only
Imported from GitHub PR openxla#38507

📝 Summary of Changes: Updated functions from triton/support.c and triton/support_test to execute correctly on ROCm.
🎯 Justification: support_test was failing on ROCm.
🚀 Kind of Contribution: 🐛 Bug Fix, ⚡️ Performance Improvement, ✨ New Feature, ♻️ Cleanup, 📚 Documentation, 🧪 Tests
🧪 Unit Tests: Used existing triton/support_test

Copybara import of the project:
-- e067431 by zoranjovanovic-ns <126815388+zoranjovanovic-ns@users.noreply.github.com>: Fixed triton_support_test on rocm.
-- 16f1f07 by Zoran Jovanovic <zjovanov@amd.com>: Review comments.
-- 731fb44 by Zoran Jovanovic <zjovanov@amd.com>: Code review 2

Merging this change closes openxla#38507
COPYBARA_INTEGRATE_REVIEW=openxla#38507 from ROCm:ci_rocm-fix-triton-support-4 731fb44
PiperOrigin-RevId: 877900520
Imported from GitHub PR openxla#38742

📝 Summary of Changes: Created expected output for FuseSubchannelDequantizationWithTranspose in triton/fusion_emitter_int4_device_test on ROCm.
🎯 Justification: triton/fusion_emitter_int4_device_test was failing.
🚀 Kind of Contribution: 🧪 Tests
🧪 Unit Tests: triton/fusion_emitter_int4_device_test

Copybara import of the project:
-- d00e6d1 by Zoran Jovanovic <zjovanov@amd.com>: Fix expected output in fusion_emitter_int4_device_test for ROCm.

Merging this change closes openxla#38742
COPYBARA_INTEGRATE_REVIEW=openxla#38742 from ROCm:rocm-fusion_emitter_int4_device_test d00e6d1
PiperOrigin-RevId: 880942122
Imported from GitHub PR openxla#38801

📝 Summary of Changes: Skipped CanNotEmitTritonCustomCallOnPreAmpereGpu in gpu_triton_custom_call_test for ROCm.
🎯 Justification: The unit test was failing because Triton custom calls work on ROCm.
🚀 Kind of Contribution: 🧪 Tests
🧪 Unit Tests: gpu_triton_custom_call_test

Copybara import of the project:
-- 6cf15ac by Zoran Jovanovic <zjovanov@amd.com>: [ROCm] Skip CanNotEmitTritonCustomCallOnPreAmpereGpu in gpu_triton_custom_call_test for ROCm

Merging this change closes openxla#38801
COPYBARA_INTEGRATE_REVIEW=openxla#38801 from ROCm:rocm-fix-gpu_triton_custom_call_test 6cf15ac
PiperOrigin-RevId: 881953512
Imported from GitHub PR openxla#38759

📝 Summary of Changes: Modified IsDotAlgorithmSupportedByTriton to reflect the implementation on ROCm.
🎯 Justification: triton/dot_algorithms_test was failing.
🚀 Kind of Contribution: 🐛 Bug Fix, 🧪 Tests
🧪 Unit Tests: triton/dot_algorithms_test

Copybara import of the project:
-- 757877c by Zoran Jovanovic <zjovanov@amd.com>: Fixed dot_algorithms_test. Updated support_legacy and test itself.
-- 5685c9c by Zoran Jovanovic <zjovanov@amd.com>: Review comments.
-- d801633 by Zoran Jovanovic <zjovanov@amd.com>: Review comment.

Merging this change closes openxla#38759
COPYBARA_INTEGRATE_REVIEW=openxla#38759 from ROCm:rocm-fix-dot_algorithms_test d801633
PiperOrigin-RevId: 885489190
Force-pushed from d5ab3f3 to 4bda461.
```diff
   llvm::SmallVector<Value, 2> casted_inputs;
-  if (output_type_is_16bit_float) {
+  if (output_type_is_16bit_float &&
+      (output_type.isF16() &&
+       !HasF16Implementation(OpInfo<OpTy>::kFunctionID, triple_))) {
```
bug: The condition output_type_is_16bit_float && (output_type.isF16() && !HasF16Implementation(...)) means BF16 inputs will never be upcasted to F32 on any platform anymore. Before this change, the condition was just output_type_is_16bit_float, so both BF16 and F16 were upcasted to F32 before calling libdevice functions. Now, BF16 will be passed directly to the extern function call without upcast, but libdevice does not have native BF16 implementations -- the function name will still resolve to _f32 via ObtainDeviceFunctionName, causing a type mismatch between the BF16 operands and the F32 function signature.
The same issue applies to the downcast condition on line 225-227: (res.getType() != original_output_type) || (output_type.isF16() && !HasF16Implementation(...)) -- for BF16, res.getType() will equal original_output_type (both BF16), so no truncation happens even though the extern call expects/returns F32.
Consider restructuring the upcast condition to something like:

```cpp
if (output_type_is_16bit_float &&
    !(output_type.isF16() &&
      HasF16Implementation(OpInfo<OpTy>::kFunctionID, triple_))) {
```

This preserves the existing BF16 upcast behavior while skipping it only for F16 ops that have native F16 implementations on AMD.
```cpp
      (gpu_version.IsRocm() &&
       reduce.shape().element_type() == PrimitiveType::F8E4M3FNUZ)) {
    return CodegenDecision::Forbid(
        "F8E4M3FN and F8E5M2 are not supported for reductions.");
```
nit: This error message is now stale -- the condition also rejects F8E5M2FNUZ and F8E4M3FNUZ on ROCm, but the message only mentions F8E4M3FN and F8E5M2. Consider updating to include the FNUZ variants so the error is not misleading when debugging on ROCm.
```diff
-        "F8E4M3FN and F8E5M2 are not supported for reductions.");
+        "F8 types are not supported for reductions.");
```
```cpp
    return CodegenDecision::Forbid(
        "Unsupported GPU architecture for F8E4M3FN/F8E5M2.");
  case F8E4M3FNUZ:
        "Unsupported GPU architecture for F8E4M3FN.");
```
nit: This error message mentions only F8E4M3FN, but the case block covers both F8E5M2 and F8E4M3FN. Consider updating it to mention both types.
```diff
-        "Unsupported GPU architecture for F8E4M3FN.");
+        "Unsupported GPU architecture for F8E4M3FN/F8E5M2.");
```
```cpp
bool HasF16Implementation(TargetDeviceFunctionID func_id,
                          llvm::Triple target_triple) {
  if(target_triple.isAMDGPU() &&
```
nit: Missing space between if and ( -- should be if (target_triple.isAMDGPU() per the project's coding style.
```diff
-  if(target_triple.isAMDGPU() &&
+  if (target_triple.isAMDGPU() &&
```
```diff
   llvm::SmallVector<Value, 2> casted_inputs;
-  if (output_type_is_16bit_float) {
+  if (output_type_is_16bit_float &&
+      (output_type.isF16() &&
+       !HasF16Implementation(OpInfo<OpTy>::kFunctionID, triple_))) {
```
Bug: BF16 libdevice calls broken by this change.
The new condition `output_type_is_16bit_float && (output_type.isF16() && !HasF16Implementation(...))` is logically equivalent to `output_type.isF16() && !HasF16Implementation(...)`, because isF16() already implies output_type_is_16bit_float. This means BF16 operands are no longer upcasted to F32 before calling libdevice functions.
However, libdevice/OCML functions don't have native BF16 variants — ObtainDeviceFunctionName will still return an _f32 function name for BF16 inputs. The result is a type mismatch: BF16 operands passed to an F32 function signature, which will produce incorrect results or crash.
The same issue affects the downcast path on lines 225–227 — BF16 outputs are no longer truncated back from F32.
Previously, all 16-bit float types (both F16 and BF16) were upcasted to F32, which was correct for BF16. The fix for F16-on-AMDGPU should preserve the BF16→F32 upcast. Suggested fix:
```cpp
if (output_type_is_16bit_float &&
    !(output_type.isF16() && HasF16Implementation(...))) {
```

This upcasts BF16 always, and upcasts F16 only when there's no native F16 implementation.
```diff
     const HloReduceInstruction& reduce,
     const se::GpuComputeCapability& gpu_version) {
   if (reduce.shape().element_type() == PrimitiveType::F8E4M3FN ||
-      reduce.shape().element_type() == PrimitiveType::F8E5M2) {
+      reduce.shape().element_type() == PrimitiveType::F8E5M2 ||
+      (gpu_version.IsRocm() &&
+       reduce.shape().element_type() == PrimitiveType::F8E5M2FNUZ) ||
+      (gpu_version.IsRocm() &&
```
nit: The error message says "F8E4M3FN and F8E5M2 are not supported for reductions" but the condition now also rejects F8E5M2FNUZ and F8E4M3FNUZ. Please update the message to mention the FNUZ types as well for accurate debugging.
Suggested change:

```cpp
  if (primitive_util::IsFloat8Type(type)) {
    return CodegenDecision::Forbid(
        "F8E4M3FN, F8E5M2, F8E4M3FNUZ, and F8E5M2FNUZ are not supported "
        "for reductions.");
  }
```
```cpp
      case PC::ALG_DOT_BF16_BF16_F32_X9:
        if (GpuComputeComp().IsRocm()) {
          // X6 and X9 algorithms on ROCm marked as not supported
          // because they often require too much shared memory.
          EXPECT_FALSE(result_or_status.value())
              << "algorithms not supported on ROCm";
        } else {
          ASSERT_TRUE(result_or_status.status().ok())
              << "failed to compile " << algorithm_;
          EXPECT_TRUE(result_or_status.value())
              << "wrong result for " << algorithm_;
        }
```
Bug: Missing status assertion in ROCm branch.
The CUDA path correctly does ASSERT_TRUE(result_or_status.status().ok()) before calling .value(), but the ROCm branch goes straight to EXPECT_FALSE(result_or_status.value()). If result_or_status holds a non-OK status, calling .value() will crash.
Please add the status assertion:
```cpp
#if TENSORFLOW_USE_ROCM
  ASSERT_TRUE(result_or_status.status().ok());
  EXPECT_FALSE(result_or_status.value());
#else
```

```cpp
  return CodegenDecision(
      element_type != PrimitiveType::F8E4M3FN &&
      element_type != PrimitiveType::F8E5M2 &&
      !(gpu_version.IsRocm() &&
        element_type == PrimitiveType::F8E4M3FNUZ) &&
      !(gpu_version.IsRocm() &&
```
nit: Same stale error message issue — the condition now blocks FNUZ types too, but the message only mentions "F8E4M3FN, F8E5M2 and S4". Please add the FNUZ types to the message.
```cpp
    return CodegenDecision::Forbid(
        "F8E4M3FN and F8E5M2 are not supported for reductions.");
```
nit: The error message still reads "F8E4M3FN and F8E5M2 are not supported for reductions." but the condition now also forbids F8E5M2FNUZ and F8E4M3FNUZ on ROCm. Consider updating the message to mention the FNUZ variants so it is accurate when triggered on ROCm.
```diff
     return CodegenDecision::Forbid(
-        "F8E4M3FN and F8E5M2 are not supported for reductions.");
+        "F8E4M3FN, F8E5M2, F8E4M3FNUZ and F8E5M2FNUZ are not supported for "
+        "reductions.");
```
```cpp
        if (GpuComputeComp().IsRocm()) {
          // X6 and X9 algorithms on ROCm marked as not supported
          // because they often require too much shared memory.
          EXPECT_FALSE(result_or_status.value())
```
bug: On the ROCm path you call result_or_status.value() without first asserting that result_or_status.status().ok(). If the status is not OK, .value() will either crash or produce a confusing test failure.
The non-ROCm branch correctly guards with ASSERT_TRUE(result_or_status.status().ok()) before calling .value(). Consider adding the same guard here:
```cpp
if (GpuComputeComp().IsRocm()) {
  // X6 and X9 algorithms on ROCm marked as not supported
  // because they often require too much shared memory.
  ASSERT_TRUE(result_or_status.status().ok())
      << "failed to compile " << algorithm_;
  EXPECT_FALSE(result_or_status.value())
      << "algorithms not supported on ROCm";
```

```cpp
    case F8E5M2FNUZ:
      if (gpu_cc.has_value() && !gpu_cc.value().IsRocm()) {
        return absl::UnimplementedError(
            "F8E5M2FNUZ is not supported on this GPU.");
      }
      return b.getType<mlir::Float8E5M2FNUZType>();
    case F8E4M3FNUZ:
      if (gpu_cc.has_value() && !gpu_cc.value().IsRocm()) {
        return absl::UnimplementedError(
            "F8E4M3FNUZ is not supported on this GPU.");
      }
      return b.getType<mlir::Float8E4M3FNUZType>();
```
Default-allow concern for FNUZ types. When gpu_cc is std::nullopt (the default), FNUZ types are unconditionally allowed. This means callers that don't pass a gpu_cc (e.g., tests, other XTile usages) will silently accept FNUZ types even on CUDA, where they're unsupported. Consider inverting the logic to require explicit ROCm opt-in rather than only forbidding when explicitly non-ROCm.
```cpp
  if (output_type_is_16bit_float &&
      (output_type.isF16() &&
```
nit: output_type_is_16bit_float is redundant here -- if output_type.isF16() is true, then output_type_is_16bit_float is necessarily also true (F16 is a 16-bit float). The outer check can be dropped to simplify the condition:

```diff
-  if (output_type_is_16bit_float &&
-      (output_type.isF16() &&
+  if (output_type.isF16() &&
```
```cpp
bool HasF16Implementation(TargetDeviceFunctionID func_id,
                          llvm::Triple target_triple) {
  if(target_triple.isAMDGPU() &&
```
style: Missing space between if and ( -- the rest of the codebase uses if (...).
```diff
-  if(target_triple.isAMDGPU() &&
+  if (target_triple.isAMDGPU() &&
```
```diff
     case F8E5M2:
     case F8E4M3FN:
-      if (gpu_version.IsCuda()) {
+      if (gpu_version.IsCuda() || gpu_version.IsRocm()) {
```
question: The old code required has_ocp_fp8_support() for F8E5M2/F8E4M3FN on ROCm, but the refactored check now unconditionally allows these types for any ROCm device. Was the capability check intentionally dropped, or should older ROCm devices without OCP FP8 support still be gated?
Re-review Summary (Claude)

Previous finding on README.md H1→H2 downgrade is resolved (no longer in diff).

New Findings:
- 2 bugs
- 3 nits
- 2 design considerations

See inline comments for details and suggested fixes.

🤖 Generated with Claude Code
```cpp
    return CodegenDecision::Forbid(
        "F8E4M3FN and F8E5M2 are not supported for reductions.");
```
nit: The condition now also forbids F8E5M2FNUZ and F8E4M3FNUZ on ROCm (line 320–321), but the error message still says only "F8E4M3FN and F8E5M2 are not supported for reductions." The same issue exists for the iota error message around line 767. Consider updating both messages to mention the FNUZ variants so the diagnostic is accurate for ROCm users.
```diff
-    return CodegenDecision::Forbid(
-        "F8E4M3FN and F8E5M2 are not supported for reductions.");
+    return FailedPrecondition(
+        "F8E4M3FN, F8E5M2, F8E5M2FNUZ, and F8E4M3FNUZ are not supported for "
```
```cpp
        if (GpuComputeComp().IsRocm()) {
          // X6 and X9 algorithms on ROCm marked as not supported
          // because they often require too much shared memory.
          EXPECT_FALSE(result_or_status.value())
```
bug: The ROCm branch calls result_or_status.value() without first asserting result_or_status.status().ok(). The CUDA branch (line 1625) does ASSERT_TRUE(result_or_status.status().ok()) before accessing .value(). If the status is not OK, .value() will crash or produce UB. Consider adding the same assertion here:
```diff
-          EXPECT_FALSE(result_or_status.value())
+          ASSERT_TRUE(result_or_status.status().ok());
+          EXPECT_FALSE(result_or_status.value());
```
```cpp
  if (output_type_is_16bit_float &&
      (output_type.isF16() &&
```
question/bug: The condition `output_type_is_16bit_float && (output_type.isF16() && !HasF16Implementation(...))` is equivalent to `output_type.isF16() && !HasF16Implementation(...)` — the output_type_is_16bit_float check is redundant since isF16() implies it.

More importantly, this means BF16 inputs are never upcast to F32 anymore (on any target), since the outer && short-circuits when isF16() is false. Previously, all 16-bit floats (including BF16) were upcast to F32. If this is intentional for BF16, please add a comment explaining why. Otherwise, the intended logic is likely:

```cpp
if (output_type_is_16bit_float &&
    !(output_type.isF16() && HasF16Implementation(op_name, target_triple))) {
```

This would preserve the BF16→F32 upcast while skipping it only for F16 ops with native AMDGPU support.
```cpp
bool HasF16Implementation(TargetDeviceFunctionID func_id,
                          llvm::Triple target_triple) {
  if(target_triple.isAMDGPU() &&
```
nit: Missing space after if keyword — Google C++ style requires if ( not if(.
```diff
-  if(target_triple.isAMDGPU() &&
+  if (target_triple.isAMDGPU() &&
```
Re-review Summary

Previous finding on the README.md H1→H2 change is now resolved (removed from PR).

New findings (6 inline comments):

See inline comments for details and suggested fixes.
Imported from GitHub PR openxla#38792

Enable FissionBackend autotuning for ROCm (rocBLAS + hipBLASLt):
- Added HIPBLASLT_FISSION to backend proto
- Updated factory_rocm.cc to register the backends
- xla_gpu_experimental_disable_binary_libraries and xla_gpu_enable_cublaslt behavior mirror CUDA

Also a minor fix: removed a dead return in fission_backend.cc.

🚀 Kind of Contribution: ✨ New Feature

Copybara import of the project:
-- ad88d8c by Eetu Sjöblom <eetu.sjoblom@amd.com>: Enable rocblas/hipBLASLt fission on ROCm

Merging this change closes openxla#38792
COPYBARA_INTEGRATE_REVIEW=openxla#38792 from ROCm:ci_rocm_enable_fission_upstream ad88d8c
PiperOrigin-RevId: 884317138
(cherry picked from commit 77ecf53)
Imported from GitHub PR openxla#37074

📝 Summary of Changes: Adapt GpuCompiler::AddConvAndGemmAutotuningPass to match the pre-refactor behavior of AMDGPUCompiler.
🎯 Justification: For ROCm we need to run the miopen backend even when autotuning is disabled, in order to decompose unsupported fused convolutions back. There is no runtime fallback.
🚀 Kind of Contribution: 🐛 Bug Fix
📊 Benchmark: N/A
🧪 Unit Tests: None
🧪 Execution Tests: None

Copybara import of the project:
-- 486498b by Dragan Mladjenovic <Dragan.Mladjenovic@amd.com>: [ROCm] Reenable miopen autotune when xla_gpu_autotune_level == 0
-- feec478 by Dragan Mladjenovic <Dragan.Mladjenovic@amd.com>: Fix //xla/backends/gpu/autotuner:miopen_test_amdgpu_any

Merging this change closes openxla#37074
COPYBARA_INTEGRATE_REVIEW=openxla#37074 from ROCm:miopen_autotune feec478
PiperOrigin-RevId: 877364967
Backport upstream fix that prevents a crash in the BlockPingpong optimization when applied to FMA (non-MFMA) dot operations. The pass now uses dyn_cast instead of cast for AMDMfmaEncodingAttr and returns early when the encoding is not MFMA.
Created PR with review comments from Claude:
Let us merge this. Verified locally on MI300 that sgpu and mgpu UTs pass.
@hsharsha I'm wondering which script you used to test all UTs? Do we still include this stale filter list? https://github.com/ROCm/xla/blob/rocm-jaxlib-v0.9.1/build_tools/rocm/rocm_xla.bazelrc#L61-L78 I think these filter lists are invalid now.