Fix issue with missing upcast/downcast for bf16 libdevice calls. #660
zoranjovanovic-ns wants to merge 1 commit into rocm-jaxlib-v0.9.0
Conversation
    (output_type.isBF16() ||
        (output_type.isF16() &&
        !HasF16Implementation(OpInfo<OpTy>::kFunctionID, triple_)))) {
nit: indentation misalignment. !HasF16Implementation(...) is inside the nested (output_type.isF16() && ...) parenthesized expression but is indented to the same level as the opening (output_type.isF16() line. It should be indented one additional space to align with the content of the inner parentheses. Same issue at lines 217-219 in the downcast condition.
    if (res.getType() != output_type ||
        (output_type.isF16() &&
        !HasF16Implementation(OpInfo<OpTy>::kFunctionID, triple_))) {

        (output_type.isBF16() ||
        (output_type.isF16() &&
        !HasF16Implementation(OpInfo<OpTy>::kFunctionID, triple_)))) {
question: The res.getType() != output_type comparison checks a shaped type (e.g. tensor<1024xf32>) against an element type (bf16 or f16). These will always differ, making the entire condition always true regardless of the bf16/f16 sub-conditions. Is this intentional? If so, the bf16/f16 checks here are redundant. If the intent was res.getType() != shaped_output_type (comparing shaped types), that would make the sub-conditions meaningful.
Also, the same indentation nit as above applies here.
output_type is not necessarily an element type; it could be a shaped type. Can you double-check that? Please consider the entire function to be sure.
Good question. After examining the full function, output_type is indeed an element type, not a shaped type -- see line 177, where it is extracted via maybe_shaped_type.getElementType() (or falls back to op.getType() for scalar ops).

However, res.getType() returns the result type of ExternElementwiseOp, which is set to casted_inputs[0].getType(). When the inputs are upcast to F32, casted_inputs[0].getType() is the shaped tensor type (e.g. tensor<1024xf32>), while output_type is just bf16 or f16. So in the tensor case, res.getType() != output_type will always be true (a shaped type compared against an element type), making the bf16/f16 sub-conditions redundant. This works in practice because the downcast Cast(builder, res, output_type) handles the mismatch gracefully, but the condition logic is misleading.

The indentation nit from the original review also still applies here (line 219 should have one additional space).
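The always-true comparison can be sketched with a toy model of the type hierarchy (hypothetical Python classes standing in for MLIR's types; this is not the real MLIR API):

```python
from dataclasses import dataclass

# Illustrative stand-ins for MLIR's element types and shaped types.
@dataclass(frozen=True)
class FloatType:
    width: int   # bit width, e.g. 16 for bf16/f16, 32 for f32
    kind: str    # "bf16", "f16", or "f32"

@dataclass(frozen=True)
class RankedTensorType:
    shape: tuple            # e.g. (1024,)
    element_type: FloatType

bf16 = FloatType(16, "bf16")
f32 = FloatType(32, "f32")

# output_type after maybe_shaped_type.getElementType(): an element type.
output_type = bf16
# res.getType() after upcasting the inputs: a shaped tensor type.
res_type = RankedTensorType((1024,), f32)

# A shaped type never equals an element type, so this is always True
# and the bf16/f16 sub-conditions cannot change the outcome:
always_true = res_type != output_type

# The comparison that would make the sub-conditions meaningful compares
# element type against element type:
element_mismatch = res_type.element_type != output_type
```

In the toy model, as in the real code, the first comparison is vacuously true for any tensor result, which is why the sub-conditions are dead weight there.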
    if (output_type_is_16bit_float &&
        (output_type.isF16() &&
        !HasF16Implementation(OpInfo<OpTy>::kFunctionID, triple_))) {

        (output_type.isBF16() ||
suggestion: Missing test coverage for the fixed scenario. The existing bf16 tests (exp_bf16, log_bf16 in triton_xla_math_to_libdevice.mlir) cover ops that have HasFastF32Approximation == true (kExp, kLog), which already took a different code path. A test for a bf16 op without a fast approximation (e.g. sin_bf16, tanh_bf16, rsqrt_bf16) would directly exercise this new upcast/downcast fix and prevent regressions.
Can you provide a test for a bf16 op without a fast approximation (e.g. sin_bf16)?
Agreed -- adding a test for a bf16 op without a fast F32 approximation (e.g. math.sin on bf16) would directly exercise the new upcast/downcast code path being fixed in this PR. The existing exp_bf16 and log_bf16 tests take the HasFastF32Approximation path, so they do not cover this fix. This is still needed to prevent regressions.
I am asking not the PR's author but you, Claude, to provide a test for a bf16 op without a fast approximation (e.g. sin_bf16). Can you do that here in the comment?
Here is an example sin_bf16 test that exercises the upcast/downcast code path:

    // -----
    func.func @sin_bf16(%arg0: tensor<1024xbf16>) -> tensor<1024xbf16> {
      %result = math.sin %arg0 : tensor<1024xbf16>
      return %result : tensor<1024xbf16>
    }
    // CHECK: %[[CAST:.*]] = arith.extf %arg0 : tensor<1024xbf16> to tensor<1024xf32>
    // CHECK: tt.extern_elementwise %[[CAST]]
    // CHECK-SAME: {libname = "libdevice", libpath = "/path/to/libdevice",
    // CHECK-SAME: pure = true, symbol = "__nv_sinf"}
    // CHECK: arith.truncf {{.*}} : tensor<1024xf32> to tensor<1024xbf16>

math.sin has no HasFastF32Approximation, so it takes the regular libdevice path. This test verifies that bf16 inputs are upcast to f32 before __nv_sinf and truncated back to bf16. Without the fix, the upcast/downcast is skipped for bf16.
Review Summary: The fix is correct and necessary. Before this change, bf16 operands were passed directly to libdevice functions that have no native bf16 support, producing incorrect IR. Adding output_type.isBF16() to the upcast/downcast conditions properly handles bf16 by casting through f32. Three inline comments were posted, covering: indentation alignment, the downcast condition's type-comparison question, and the test-coverage gap for bf16 ops without fast f32 approximations.
i-chaochen left a comment
are we missing this UT on CI?
In the ROCm CI we are not running tests that are not tagged, which includes these targets:
//xla/backends/gpu/codegen/triton/ir/tests:all
//xla/backends/gpu/codegen/triton/transforms/tests:all
//xla/backends/gpu/codegen/emitters/transforms/tests:all
//xla/backends/gpu/codegen/emitters/ir/tests:all
//xla/codegen/emitters/transforms/tests:all
//xla/codegen/emitters/ir/tests:all
//xla/codegen/xtile/ir/tests:all
//xla/codegen/xtile/ir/transforms/tests:all
Re-review Summary
The diff is unchanged since the prior review, and all three previous findings still apply. No new issues found; see the inline threads for details.
Hmmm... we need to bring up the CPU UT pipeline on our internal CI to check all of these IR-check UTs.
📝 Summary of Changes
Introduced missing upcast/downcast for bf16 type
🎯 Justification
upcast/downcast are necessary because there is no native bf16 implementation in libdevice
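The need for the round trip can be illustrated numerically with a small sketch that emulates bf16 as the top 16 bits of an IEEE-754 float32 (truncation rounding for brevity; math.sin stands in for an f32-only libdevice routine such as __nv_sinf — the rounding mode and stand-in function are illustrative assumptions, not the PR's implementation):

```python
import math
import struct

def f32_to_bf16_bits(x: float) -> int:
    """Truncate a float32 to bf16 (keep the top 16 bits); illustrative rounding."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return bits >> 16

def bf16_bits_to_f32(b: int) -> float:
    """Widen bf16 bits back to float32 (exact: the low mantissa bits are zero)."""
    return struct.unpack("<f", struct.pack("<I", b << 16))[0]

# A bf16-representable input, as it would arrive at the libdevice call site.
x = bf16_bits_to_f32(f32_to_bf16_bits(0.7))

# Upcast: bf16 -> f32 is exact, so the f32 routine sees the value unchanged.
y_f32 = math.sin(x)  # stand-in for the f32 libdevice function

# Downcast: round the f32 result back to bf16 for the op's result type.
y_bf16 = bf16_bits_to_f32(f32_to_bf16_bits(y_f32))
```

Without the upcast, bf16 bit patterns would be handed to a routine expecting f32 operands; the cast through f32 on the way in and back to bf16 on the way out is what the fixed conditions now insert.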
🚀 Kind of Contribution
🐛 Bug Fix
🧪 Unit Tests:
triton_xla_math_to_libdevice.mlir