Rocm jaxlib v0.9.0 dev scaled dot#644

Open
ScXfjiang wants to merge 12 commits into rocm-jaxlib-v0.9.0 from rocm-jaxlib-v0.9.0_dev_scaled_dot

Conversation

@ScXfjiang

Replace the old MX type implementation with the new one.

Reference PR for the new implementation: openxla#38333

Cherry-picked from 3f9ea02 with conflict resolution:
- hipblaslt.cc: use new ApplyConfig with scaled dot fusion support
- thunk_emitter.cc: add IsCublasLtMatmulMx to existing inline dispatch

Made-with: Cursor
@claude

claude bot commented Feb 27, 2026

Review Summary: This PR replaces the old BlockScalingRewriter-based MX type handling on ROCm with a new approach that routes kScaledDot operations through the autotuner HipblasLtBackend. 5 findings posted as inline comments covering proto wire-compatibility, default workspace sizing for MX ops, integer division in validation, FP4 test tolerance, and semantic broadening of IsCublasGemm. Overall the architectural direction is sound. The main items to address are the workspace size mismatch in GetDefaultConfig and verifying the proto field number swap will not corrupt existing caches.

@claude

claude bot commented Feb 27, 2026

PR Review Summary: This PR replaces the old ROCm MX type implementation in BlockScalingRewriter with a new approach that routes kScaledDot operations through the autotuner HipblasLtBackend. Good architectural direction. Key findings (see inline comments):

  • Bug: GetDefaultConfig workspace size mismatch for MX scaled-dot.
  • Bug: IsCublasLtMatmulMx included in IsCublasGemm without EmitAsyncCustomCallStart support.
  • Wire-format breakage: proto field numbers swapped in GemmConfigProto.
  • FP4 test tolerance may be too tight.
  • Global FP4 random-distribution change in literal_util.cc.

The overall refactoring looks correct.

@claude

claude bot commented Mar 13, 2026

Re-review Summary (v2)

No new issues found. All findings from this review pass were already covered by prior inline comments.

Resolved (fixed in this revision):

  • FP8 scale pointer regression in hip_blas_lt.cc — fixed via scale_mode switch
  • Integer division concern in hipblaslt.cc — safe due to existing k%32 guard
  • Async thunk wrapping of MX custom calls in cublas_cudnn.cc — removed
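The integer-division resolution above can be sketched with a hypothetical helper (the actual guard lives in hipblaslt.cc; the block size of 32 is assumed from the k % 32 check the review cites):

```python
BLOCK = 32  # MX block size assumed from the k % 32 guard discussed in review

def scale_blocks(k: int) -> int:
    # With k % 32 == 0 validated first, the integer division k // 32 is
    # exact, so no scale blocks are silently truncated away.
    if k % BLOCK != 0:
        raise ValueError(f"contracting dim k={k} must be a multiple of {BLOCK}")
    return k // BLOCK

print(scale_blocks(128))  # 4
```

The upstream divisibility check is what makes the later integer division safe, which is why the finding was closed without a code change at the division site.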

Still open (from prior review, no changes in this revision):

  • Proto wire-format field number swap (gpu_blas_lt.proto fields 13-14)
  • backend_configs.proto field 21 type change (bool to int32)
  • GetDefaultConfig missing workspace size for MX scaled-dot
  • ApplyConfig replacing fusion with custom call (unusual autotuner pattern)
  • Removal of f8e8m0fnu/f4e2m1fn random data specializations in literal_util.cc
  • FP4 test tolerance may be too tight
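On the FP4 tolerance point, the coarseness of f4e2m1fn values explains why a tight tolerance is risky. A sketch of the representable set, using the standard e2m1 interpretation (2 exponent bits with bias 1, 1 mantissa bit, no inf/nan; not taken from the PR):

```python
def e2m1_values():
    # Enumerate non-negative f4e2m1fn values; e == 0 is the subnormal
    # range, contributing 0 and 0.5.
    vals = set()
    for e in range(4):
        for m in range(2):
            if e == 0:
                vals.add(m * 0.5)                      # subnormals: 0, 0.5
            else:
                vals.add((1 + m * 0.5) * 2.0 ** (e - 1))
    return sorted(vals)

assert e2m1_values() == [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
# Nearest-value rounding error reaches 1.0 near 5.0 (halfway between 4
# and 6), i.e. ~20% relative error, so FP4 comparisons need a generous
# tolerance rather than a typical float rtol.
```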

@claude

claude bot commented Mar 13, 2026

Re-review Summary

Re-reviewed the latest diff. All previously flagged issues have been examined against the current code:

  • Resolved: Several prior findings (.value() usage, size_t loop index, workspace size, swapped scale variables, dynamic_slice_thunk_test arg, integer division check) have been addressed in subsequent commits.
  • Acknowledged/Declined: Default workspace size for MX in GetDefaultConfig — developer confirmed no change needed (design decision).
  • Unchanged (previously flagged): Proto field number swap (wire compatibility), FP4 test tolerance, PopulateWithRandomFloatingPointData specialization removal, async custom call wrapping for MX, and a few style nits remain from prior rounds.

No new issues found in this round. The PR is converging well.

🤖 Generated with Claude Code
