
Fix: Scaled Matmul at rocm 9.0 #739

Open
shurale-nkn wants to merge 1 commit into rocm-jaxlib-v0.9.0 from fix_scaled_matmul_dot_test_0.9.0

Conversation


@shurale-nkn shurale-nkn commented Mar 17, 2026

This PR fixes the ScaledMatmul and ScaledDot tests failing on ROCm with:
MLIR translation rule for primitive 'scaled_matmul' not found for platform rocm

On ROCm, we now register a dedicated lowering for scaled_matmul that takes a composite approach: instead of relying on a direct fused kernel, it lowers through lax.scaled_dot.
The resulting HLO can then be fused later by XLA, which matches the intended backend behavior and avoids the missing primitive translation path on ROCm.
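As background (this is not the PR's code), a scaled matmul contracts low-precision operands together with per-block scale factors. A minimal NumPy sketch of the reference semantics that the composite lowering preserves, with an illustrative block size and illustrative names:

```python
import numpy as np

def scaled_matmul_reference(a, b, a_scales, b_scales, block=32):
    """Reference semantics of a block-scaled matmul (illustrative).

    a: (M, K) values with a_scales: (M, K // block),
    b: (K, N) values with b_scales: (K // block, N).
    Each scale applies to a contiguous block of `block` elements
    along the contraction axis K.
    """
    # Dequantize: broadcast each scale over its K-block.
    a_deq = a * np.repeat(a_scales, block, axis=1)
    b_deq = b * np.repeat(b_scales, block, axis=0)
    # Plain matmul on the dequantized operands; a fused backend kernel
    # would produce the same result without materializing a_deq/b_deq.
    return a_deq @ b_deq

M, K, N, block = 4, 64, 8, 32
rng = np.random.default_rng(0)
a = rng.standard_normal((M, K)).astype(np.float32)
b = rng.standard_normal((K, N)).astype(np.float32)
a_scales = np.ones((M, K // block), dtype=np.float32)
b_scales = np.ones((K // block, N), dtype=np.float32)
out = scaled_matmul_reference(a, b, a_scales, b_scales, block)
# With unit scales this reduces to a plain matmul.
assert np.allclose(out, a @ b)
```

Whether the contraction is computed by a fused kernel (the CUDA path) or by composite HLO that XLA fuses afterwards (the ROCm path here), the numerics above are what both must implement.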

What changed

  • Implemented the ROCm lowering by delegating to lax.scaled_dot.
  • Kept the existing CUDA lowering unchanged.
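The registration pattern behind these two bullets can be sketched in plain Python. This is an illustrative stand-in, not JAX's actual machinery (JAX's real hook is `jax.interpreters.mlir.register_lowering`, which accepts a `platform` argument); the point is that one primitive can carry a direct rule on one platform and a composite rule on another:

```python
# Illustrative per-platform lowering registry (hypothetical names).
_lowerings = {}  # (primitive_name, platform) -> lowering rule

def register_lowering(prim, rule, platform):
    _lowerings[(prim, platform)] = rule

def lower(prim, platform, *args):
    rule = _lowerings.get((prim, platform))
    if rule is None:
        # Mirrors the error the failing tests reported.
        raise NotImplementedError(
            f"MLIR translation rule for primitive '{prim}' "
            f"not found for platform {platform}")
    return rule(*args)

# CUDA keeps its direct (here: stubbed) fused lowering.
register_lowering("scaled_matmul",
                  lambda a, b: ("fused_kernel", a, b), "cuda")
# ROCm gets a dedicated rule delegating to a composite built from
# existing ops, mirroring the PR's lowering through lax.scaled_dot.
register_lowering("scaled_matmul",
                  lambda a, b: ("composite_dot", a, b), "rocm")

print(lower("scaled_matmul", "rocm", 1, 2))   # ('composite_dot', 1, 2)
```

A platform with no registered rule still raises the "translation rule not found" error, which is exactly the failure mode this PR closes for ROCm.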

