Rocm jaxlib v0.9.0 dev scaled dot #644
Review Summary: This PR replaces the old BlockScalingRewriter-based MX type handling on ROCm with a new approach that routes kScaledDot operations through the autotuner HipblasLtBackend. Five findings are posted as inline comments, covering proto wire-compatibility, default workspace sizing for MX ops, integer division in validation, FP4 test tolerance, and the semantic broadening of IsCublasGemm. Overall, the architectural direction is sound. The main items to address are the workspace size mismatch in GetDefaultConfig and verifying that the proto field number swap will not corrupt existing caches.
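To make the cache-corruption concern concrete, here is a minimal sketch of why swapping protobuf field numbers changes how old serialized data is read. It hand-rolls the varint wire encoding rather than using the protobuf library, and the field names `workspace_bytes` and `algorithm_id` are hypothetical placeholders, not the actual fields of GemmConfigProto.

```python
# Sketch only: a tiny varint-only protobuf encoder/decoder.
# Field names below are hypothetical, not taken from GemmConfigProto.

def encode_varint(value: int) -> bytes:
    """Encode a non-negative integer as a protobuf base-128 varint."""
    out = bytearray()
    while True:
        byte = value & 0x7F
        value >>= 7
        if value:
            out.append(byte | 0x80)  # more bytes follow
        else:
            out.append(byte)
            return bytes(out)

def decode_varint(buf: bytes, pos: int) -> tuple[int, int]:
    """Decode a varint starting at pos; return (value, next position)."""
    result = 0
    shift = 0
    while True:
        byte = buf[pos]
        pos += 1
        result |= (byte & 0x7F) << shift
        if not byte & 0x80:
            return result, pos
        shift += 7

def serialize(schema: dict[str, int], values: dict[str, int]) -> bytes:
    """Write each value as (field_number << 3 | wire_type 0) + varint."""
    out = bytearray()
    for name, value in values.items():
        out += encode_varint(schema[name] << 3) + encode_varint(value)
    return bytes(out)

def parse(schema: dict[str, int], buf: bytes) -> dict[str, int]:
    """Read varint fields and map field numbers back to names."""
    names = {number: name for name, number in schema.items()}
    values = {}
    pos = 0
    while pos < len(buf):
        tag, pos = decode_varint(buf, pos)
        value, pos = decode_varint(buf, pos)
        values[names[tag >> 3]] = value
    return values

# Hypothetical "before" and "after" schemas with the two numbers swapped.
OLD_SCHEMA = {"workspace_bytes": 1, "algorithm_id": 2}
NEW_SCHEMA = {"algorithm_id": 1, "workspace_bytes": 2}

# A cache entry written by the old code...
cached = serialize(OLD_SCHEMA, {"workspace_bytes": 4194304, "algorithm_id": 7})

# ...parses without error under the new schema, but the values trade places.
print(parse(OLD_SCHEMA, cached))
print(parse(NEW_SCHEMA, cached))
```

Because the wire format identifies fields by number only, the second parse succeeds silently with the two values exchanged. When both fields share a wire type, as assumed here, no parse error flags the stale cache entry.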
PR Review Summary: This PR replaces the old ROCm MX type implementation in BlockScalingRewriter with a new approach that routes kScaledDot operations through the autotuner HipblasLtBackend. Good architectural direction. Key findings (see inline comments):
1. Bug: GetDefaultConfig workspace size mismatch for MX scaled-dot.
2. Bug: IsCublasLtMatmulMx in IsCublasGemm without EmitAsyncCustomCallStart support.
3. Wire-format breakage: Proto field numbers swapped in GemmConfigProto.
4. FP4 test tolerance may be too tight.
5. Global FP4 random distribution change in literal_util.cc.

The overall refactoring looks correct.
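As a rough illustration of why an FP4 tolerance can be too tight, the sketch below rounds values to the magnitudes representable in the E2M1 4-bit float format (0, 0.5, 1, 1.5, 2, 3, 4, 6) and measures the resulting error in a small dot product. It assumes E2M1 is the FP4 type in question, ignores MX block scaling entirely, and uses simplified tie handling. The input vectors are made up for the example.

```python
# Sketch only: plain E2M1 rounding without block scales.
# Real hardware rounds ties to even; this helper picks the lower neighbor
# on ties, so the sample inputs below avoid exact ties.

FP4_E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def quantize_fp4(x: float) -> float:
    """Round x to the nearest E2M1 value, saturating at +/-6."""
    sign = -1.0 if x < 0 else 1.0
    magnitude = min(abs(x), 6.0)
    nearest = min(FP4_E2M1_MAGNITUDES, key=lambda v: abs(v - magnitude))
    return sign * nearest

def dot(a: list[float], b: list[float]) -> float:
    """Plain dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))

# Made-up inputs for illustration.
lhs = [0.7, 1.2, 2.4, 5.1]
rhs = [1.3, 0.4, 2.6, 3.4]

exact = dot(lhs, rhs)
quantized = dot([quantize_fp4(v) for v in lhs], [quantize_fp4(v) for v in rhs])
relative_error = abs(quantized - exact) / abs(exact)

print(f"exact={exact:.4f} quantized={quantized:.4f} rel_err={relative_error:.4f}")
```

Individual elements can be off by well over 10% (2.4 rounds to 2.0), so a test comparing an FP4 result against a higher-precision reference generally needs a tolerance far looser than typical FP16 or BF16 thresholds.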
Re-review Summary (v2): No new issues found. All findings from this review pass were already covered by prior inline comments. Resolved (fixed in this revision):
Still open (from prior review, no changes in this revision):
Re-review Summary: Re-reviewed the latest diff. All previously flagged issues have been examined against the current code:
No new issues found in this round. The PR is converging well.
🤖 Generated with Claude Code
Replace the old MX type implementation with the new one.
New implementation reference PR: openxla#38333