Fix issue with missing upcast/downcast for bf16 libdevice calls.#660

Open
zoranjovanovic-ns wants to merge 1 commit into rocm-jaxlib-v0.9.0 from rocm-jaxlib-v0.9.0-math_to_libdevice

Conversation

@zoranjovanovic-ns

📝 Summary of Changes
Introduced the missing upcast/downcast for the bf16 type.

🎯 Justification
The upcast/downcast is necessary because libdevice has no native bf16 implementations.

🚀 Kind of Contribution
🐛 Bug Fix

📊 Benchmark (for Performance Improvements)
Please measure and include speedups for one of the public HLOs in
compiler/xla/tools/benchmarks/hlo/.

🧪 Unit Tests:
triton_xla_math_to_libdevice.mlir

🧪 Execution Tests:
What execution tests were added? For example, a new optimization should be
tested with an end-to-end execution test triggering the optimization and
asserting correctness. Please provide test cases running with at most 2 GPUs.

@zoranjovanovic-ns added the rocm-jaxlib-v0.8.2, cherry-pick-candidate (Mark a PR to be cherry-picked into the next ROCm JAX. Remove iff the latest upstream contains the PR.), Upstream, and rocm-jaxlib-v0.9.0 labels and removed the rocm-jaxlib-v0.8.2 label on Mar 9, 2026
Comment on lines +197 to +199
        (output_type.isBF16() ||
         (output_type.isF16() &&
         !HasF16Implementation(OpInfo<OpTy>::kFunctionID, triple_)))) {

nit: indentation misalignment. !HasF16Implementation(...) is inside the nested (output_type.isF16() && ...) parenthesized expression but is indented to the same level as the opening (output_type.isF16() line. It should be indented one additional space to align with the content of the inner parentheses. Same issue at lines 217-219 in the downcast condition.

Suggested change
        (output_type.isBF16() ||
         (output_type.isF16() &&
         !HasF16Implementation(OpInfo<OpTy>::kFunctionID, triple_)))) {
        (output_type.isBF16() ||
         (output_type.isF16() &&
          !HasF16Implementation(OpInfo<OpTy>::kFunctionID, triple_)))) {

Comment on lines 216 to +219
    if (res.getType() != output_type ||
-       (output_type.isF16() &&
-       !HasF16Implementation(OpInfo<OpTy>::kFunctionID, triple_))) {
+       (output_type.isBF16() ||
+        (output_type.isF16() &&
+        !HasF16Implementation(OpInfo<OpTy>::kFunctionID, triple_)))) {

question: The res.getType() != output_type comparison checks a shaped type (e.g. tensor<1024xf32>) against an element type (bf16 or f16). These will always differ, making the entire condition always true regardless of the bf16/f16 sub-conditions. Is this intentional? If so, the bf16/f16 checks here are redundant. If the intent was res.getType() != shaped_output_type (comparing shaped types), that would make the sub-conditions meaningful.

Also, same indentation nit as above:

Suggested change
if (res.getType() != output_type ||
(output_type.isF16() &&
!HasF16Implementation(OpInfo<OpTy>::kFunctionID, triple_))) {
(output_type.isBF16() ||
(output_type.isF16() &&
!HasF16Implementation(OpInfo<OpTy>::kFunctionID, triple_)))) {
if (res.getType() != output_type ||
(output_type.isBF16() ||
(output_type.isF16() &&
!HasF16Implementation(OpInfo<OpTy>::kFunctionID, triple_)))) {

Member

output_type is not an element type, it could be a shape type. Can you double check that? Please consider the entire function to be sure.

Good question. After examining the full function, output_type is indeed an element type, not a shaped type: see line 177, where it is extracted via maybe_shaped_type.getElementType() (or falls back to op.getType() for scalar ops). However, res.getType() returns the result type of ExternElementwiseOp, which is set to casted_inputs[0].getType(). When the inputs are upcasted to F32, casted_inputs[0].getType() is the shaped tensor type (e.g. tensor<1024xf32>), while output_type is just bf16 or f16.

So in the tensor case, res.getType() != output_type will always be true (comparing a shaped type against an element type), making the bf16/f16 sub-conditions redundant. This works in practice because the downcast Cast(builder, res, output_type) handles the mismatch gracefully, but the condition logic is misleading. The indentation nit from the original review also still applies here (line 219 should have one additional space).

Comment on lines 196 to +197
    if (output_type_is_16bit_float &&
-       (output_type.isF16() &&
-       !HasF16Implementation(OpInfo<OpTy>::kFunctionID, triple_))) {
+       (output_type.isBF16() ||

suggestion: Missing test coverage for the fixed scenario. The existing bf16 tests (exp_bf16, log_bf16 in triton_xla_math_to_libdevice.mlir) cover ops that have HasFastF32Approximation == true (kExp, kLog), which already took a different code path. A test for a bf16 op without a fast approximation (e.g. sin_bf16, tanh_bf16, rsqrt_bf16) would directly exercise this new upcast/downcast fix and prevent regressions.

@nurmukhametov (Member) commented Mar 10, 2026

Can you provide a test for a bf16 op without a fast approximation (e.g. sin_bf16)?

Agreed -- adding a test for a bf16 op without a fast F32 approximation (e.g. math.sin on bf16) would directly exercise the new upcast/downcast code path being fixed in this PR. The existing exp_bf16 and log_bf16 tests take the HasFastF32Approximation path, so they do not cover this fix. This is still needed to prevent regressions.

Member

I am asking not PR's author, but you Claude to provide a test for a bf16 op without a fast approximation (e.g. sin_bf16). Can you do that here in the comment?

Here is an example sin_bf16 test that exercises the upcast/downcast code path:

// -----

func.func @sin_bf16(%arg0: tensor<1024xbf16>) -> tensor<1024xbf16> {
  %result = math.sin %arg0 : tensor<1024xbf16>
  return %result : tensor<1024xbf16>
}

// CHECK:       %[[CAST:.*]] = arith.extf %arg0 : tensor<1024xbf16> to tensor<1024xf32>
// CHECK:       tt.extern_elementwise %[[CAST]]
// CHECK-SAME:    {libname = "libdevice", libpath = "/path/to/libdevice",
// CHECK-SAME:    pure = true, symbol = "__nv_sinf"}
// CHECK:       arith.truncf {{.*}} : tensor<1024xf32> to tensor<1024xbf16>

math.sin has no HasFastF32Approximation, so it takes the regular libdevice path. This test verifies bf16 inputs get upcasted to f32 before __nv_sinf and truncated back to bf16. Without the fix, upcast/downcast is skipped for bf16.

@claude

claude bot commented Mar 9, 2026

Review Summary: The fix is correct and necessary. Before this change, bf16 operands were passed directly to libdevice functions that have no native bf16 support, producing incorrect IR. Adding output_type.isBF16() to the upcast/downcast conditions properly handles bf16 by casting through f32. Three inline comments posted covering: indentation alignment, downcast condition type comparison question, and test coverage gap for bf16 ops without fast f32 approximations.

@i-chaochen (Collaborator) left a comment

are we missing this UT on CI?

@nurmukhametov (Member)

are we missing this UT on CI?

In ROCm CI we do not run tests that are not tagged with gpu, so yes, this UT and similar ones are missed:

//xla/backends/gpu/codegen/triton/ir/tests:all   
//xla/backends/gpu/codegen/triton/transforms/tests:all   
//xla/backends/gpu/codegen/emitters/transforms/tests:all   
//xla/backends/gpu/codegen/emitters/ir/tests:all   
//xla/codegen/emitters/transforms/tests:all  
//xla/codegen/emitters/ir/tests:all   
//xla/codegen/xtile/ir/tests:all   
//xla/codegen/xtile/ir/transforms/tests:all

@claude

claude bot commented Mar 10, 2026

Re-review Summary

The diff is unchanged since the prior review. All 3 previous findings still apply:

  1. Indentation misalignment (upcast and downcast conditions) - not yet addressed
  2. Type comparison concern in downcast guard - replied to developer question with clarification
  3. Missing bf16 test for ops without fast f32 approximation (e.g. sin_bf16, tanh_bf16) - replied confirming the test is still needed

No new issues found. See inline threads for details.

@i-chaochen (Collaborator)

are we missing this UT on CI?

We are not running in ROCm CI tests that are not tagged with gpu, so yes this UT and similar ones are missed:

//xla/backends/gpu/codegen/triton/ir/tests:all   
//xla/backends/gpu/codegen/triton/transforms/tests:all   
//xla/backends/gpu/codegen/emitters/transforms/tests:all   
//xla/backends/gpu/codegen/emitters/ir/tests:all   
//xla/codegen/emitters/transforms/tests:all  
//xla/codegen/emitters/ir/tests:all   
//xla/codegen/xtile/ir/tests:all   
//xla/codegen/xtile/ir/transforms/tests:all

Hmmm... we need to bring up a CPU UT pipeline on our internal CI to check all these IR-check UTs.


Labels

cherry-pick-candidate (Mark a PR to be cherry-picked into the next ROCm JAX. Remove iff the latest upstream contains the PR.), rocm-jaxlib-v0.9.0, Upstream


3 participants