fix(ci): configure ROCm library paths for JAX wheel tests #4012

Open
WBobby wants to merge 2 commits into main from
users/yanywang/16-issue-3627-fix

Contributor

@WBobby WBobby commented Mar 17, 2026

Motivation

JAX wheel tests in CI pass on CPU while the GPU backend is completely broken, so the passing result masks the issue:

  1. ROCm tarball installs to /opt/rocm-<version>/ without a /opt/rocm symlink, so the hardcoded RUNPATH /opt/rocm/lib in JAX wheels never resolves
  2. Without JAX_PLATFORMS=rocm, JAX silently falls back to CPU — 896 tests passed on CPU, none on GPU

Fixes #3627

Changes

  • ln -sfn "${DEST}" /opt/rocm — create symlink so RUNPATH /opt/rocm/lib and /opt/rocm/lib/rocm_sysdeps/lib in JAX wheels resolve to the versioned install path
  • JAX_PLATFORMS: rocm — force ROCm backend in test step, preventing silent CPU fallback
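
A minimal sketch of the two changes above, run against a scratch directory so it needs no root access. The helper name link_rocm, the demo_root variable, and the rocm-1.2.3 directory are placeholders for illustration and are not taken from the workflow; in CI the link target would be ${DEST} and the link name /opt/rocm.

```shell
# Sketch only: link_rocm and the demo paths are hypothetical names.

# Point a stable path (e.g. /opt/rocm) at a versioned install directory.
# -s: symbolic link, -f: replace an existing link,
# -n: do not descend into an existing link that points at a directory.
link_rocm() {
  dest="$1"   # versioned install, e.g. /opt/rocm-<version>
  link="$2"   # stable path baked into the wheel RUNPATH
  ln -sfn "$dest" "$link"
}

# Demo in a scratch directory; "rocm-1.2.3" is a placeholder version.
demo_root="$(mktemp -d)"
mkdir -p "$demo_root/rocm-1.2.3/lib/rocm_sysdeps/lib"
link_rocm "$demo_root/rocm-1.2.3" "$demo_root/rocm"

# Force the ROCm backend. In GitHub Actions this would normally be set in
# the step's env: block or appended to "$GITHUB_ENV" instead of exported.
export JAX_PLATFORMS=rocm
```

Because of -n, running link_rocm a second time replaces the link rather than creating a nested link inside the versioned directory, which keeps the step safe to re-run.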

Why LD_LIBRARY_PATH is not needed

JAX plugin .so files have all 25 librocm_sysdeps_*.so as direct DT_NEEDED entries (XLA's BUILD.tpl uses srcs for system_libs). RUNPATH applies to all direct dependencies, so there is no transitivity issue.
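
One way to check the DT_NEEDED and RUNPATH claim on an installed wheel is to read the dynamic section with readelf -d and filter it. The sketch below assumes the usual binutils output layout (tag in parentheses, value in square brackets); the helper names needed_libs and runpath_dirs are hypothetical, and the plugin path in the usage comment is a placeholder.

```shell
# Sketch only: helper names are hypothetical.
# Both helpers read `readelf -d` output on stdin.

# Print each direct dependency (DT_NEEDED entry), one per line.
needed_libs() {
  sed -n 's/.*(NEEDED).*\[\(.*\)].*/\1/p'
}

# Print the RUNPATH search list (colon-separated), if one is present.
runpath_dirs() {
  sed -n 's/.*(RUNPATH).*\[\(.*\)].*/\1/p'
}

# Usage against a real file (the path is a placeholder):
#   readelf -d path/to/plugin.so | needed_libs | grep rocm_sysdeps
#   readelf -d path/to/plugin.so | runpath_dirs
```

If every librocm_sysdeps_*.so appears in the needed_libs output of the plugin itself, the plugin's own RUNPATH is what the loader consults for them, which is the point made above.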

With the companion JAX fix (ROCm/jax#737 / jax-ml/jax#35978) adding /opt/rocm/lib/rocm_sysdeps/lib to the wheel RUNPATH, the symlink alone is sufficient:

  • /opt/rocm/lib → ${DEST}/lib → main ROCm libs (libamdhip64.so.7, libhsa-runtime64.so.1, etc.)
  • /opt/rocm/lib/rocm_sysdeps/lib → ${DEST}/lib/rocm_sysdeps/lib → 25 vendored sysdeps libs
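
The resolution described above could be verified in the workflow with a small check like the following. This is a sketch under assumptions: check_runpath_dirs is a hypothetical helper, not part of the PR, and it only confirms that each RUNPATH entry is an existing directory and prints where it physically resolves.

```shell
# Sketch only: check_runpath_dirs is a hypothetical helper.
# For each entry of a colon-separated RUNPATH, print where it resolves to,
# and return non-zero if any entry is not an existing directory.
check_runpath_dirs() {
  runpath="$1"
  status=0
  old_ifs="$IFS"
  IFS=':'
  for dir in $runpath; do
    if [ -d "$dir" ]; then
      # pwd -P prints the physical path with symlinks resolved.
      printf '%s -> %s\n' "$dir" "$(cd "$dir" && pwd -P)"
    else
      printf 'missing: %s\n' "$dir" >&2
      status=1
    fi
  done
  IFS="$old_ifs"
  return "$status"
}

# Usage after the symlink step:
#   check_runpath_dirs /opt/rocm/lib:/opt/rocm/lib/rocm_sysdeps/lib
```

A non-zero exit here would surface a missing symlink or a missing rocm_sysdeps directory before the tests run, instead of as an ImportError later.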

Dependency: This PR requires ROCm/jax#737 to be merged and new wheels published before the rocm_sysdeps libraries can be resolved without LD_LIBRARY_PATH.

Test Plan

  • CI workflow triggers on a GPU runner with ROCm tarball install
  • rocm_plugin_extension imports successfully (no ImportError)
  • JAX tests run on ROCm backend, not CPU fallback
  • Verify logs show that a ROCm wheel install was found instead of "No ROCm wheel installation found"
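
The "not CPU fallback" item in the plan could be enforced with a guard step along these lines. It is a sketch, not the workflow's actual step: assert_not_cpu is a hypothetical helper, and the only JAX call used is jax.default_backend(). The guard deliberately checks only for "cpu" and makes no assumption about the exact string reported for a ROCm device.

```shell
# Sketch only: assert_not_cpu is a hypothetical helper.
# Fail when the reported JAX backend is "cpu".
assert_not_cpu() {
  backend="$1"
  if [ "$backend" = "cpu" ]; then
    echo "error: JAX fell back to the CPU backend" >&2
    return 1
  fi
  echo "backend ok: $backend"
}

# Usage in CI (assumes jax is importable in the test environment):
#   assert_not_cpu "$(python -c 'import jax; print(jax.default_backend())')"
```

With JAX_PLATFORMS=rocm set, a broken GPU setup should already fail at backend initialization; the guard is an extra safety net in case that variable is dropped from the step later.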

Test Result

Pending CI validation

Submission Checklist

ROCm tarball installs to /opt/rocm-<version>/ but JAX wheels have
RUNPATH hardcoded to /opt/rocm/lib and missing rocm_sysdeps/lib/.
Add symlink, LD_LIBRARY_PATH, and JAX_PLATFORMS=rocm to fix.

Fixes #3627

Signed-off-by: Yanyao Wang <wangyanyao@msn.com>
# Set ROCm environment variables for dynamic linker.
# Both lib paths are required: main ROCm libs + vendored sysdeps libs
echo "ROCM_PATH=${DEST}" >> "$GITHUB_ENV"
echo "LD_LIBRARY_PATH=${DEST}/lib:${DEST}/lib/rocm_sysdeps/lib${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}" >> "$GITHUB_ENV"
Member

I think adding /opt/rocm/lib to LD_LIBRARY_PATH is somewhat acceptable, but I don't think adding /opt/rocm/lib/rocm_sysdeps/lib is the right fix; that needs another solution. The question is why those sysdeps aren't being picked up correctly. If we add this, the issue will be masked in CI.

Contributor Author

@WBobby WBobby Mar 17, 2026

Good point — adding rocm_sysdeps/lib to LD_LIBRARY_PATH is intentionally a short-term CI unblock, not the proper fix. Let me explain why it's needed here and what the real fix looks like.

The proper fix (JAX side)

I'll submit a fix to JAX's jaxlib/rocm/rocm_rpath.bzl to add rocm_sysdeps/lib to _WHEEL_RPATHS.

Once that lands and new wheels are published, the LD_LIBRARY_PATH line here can be simplified to just ${DEST}/lib (or removed entirely if the symlink is sufficient).

Why this CI PR is still needed

Even after the JAX RPATH fix, this workflow still needs:

  1. ln -sfn: ROCm installs to /opt/rocm-<version>/ but RUNPATH is hardcoded to /opt/rocm/lib
  2. JAX_PLATFORMS=rocm: without this, tests silently pass on CPU (896 passed, 0 on GPU)

So this PR unblocks CI now; the JAX-side fix eliminates the need for rocm_sysdeps/lib in LD_LIBRARY_PATH going forward.

Member

JAX_PLATFORMS=rocm is fine and doesn't concern me. I am also aware that you need to set some path, but pointing to /opt/rocm/lib (and symlinking) should be it. So what we set LD_LIBRARY_PATH to should be different, or rather we only want this as a short-term fix.

Contributor Author

I submitted a JAX PR (ROCm/jax#737). After it is merged and confirmed working, I can remove the LD_LIBRARY_PATH part from this PR.

@WBobby WBobby requested review from ScottTodd and marbre March 17, 2026 16:17
JAX wheels have RUNPATH hardcoded to /opt/rocm/lib but ROCm installs
to /opt/rocm-<version>/. Add symlink to resolve this. Set
JAX_PLATFORMS=rocm to prevent silent CPU fallback.

Requires ROCm/jax#737 for rocm_sysdeps/lib RUNPATH coverage.

Fixes #3627

Signed-off-by: Yanyao Wang <wangyanyao@msn.com>
Contributor Author

WBobby commented Mar 18, 2026

Submitted a follow-up PR after ROCm/jax#737 was merged.

Member

@marbre marbre left a comment

Can you please update the PR description? I think you'd still need to set LD_LIBRARY_PATH as a band-aid, or not?

Contributor Author

WBobby commented Mar 18, 2026

Updated the PR description.

@WBobby WBobby requested a review from marbre March 18, 2026 14:46
Member

marbre commented Mar 19, 2026

Any update on the test plan?

Contributor Author

WBobby commented Mar 19, 2026

JAX wheel builds are currently blocked by #3876 (missing libamd_comgr_stub.a). This affects all JAX versions (v0.8.2 & v0.9.0) across all Python versions and architectures. The latest nightly run (#626, today Mar 19) still fails with the same error. This means no JAX wheel tests can run end-to-end in CI right now.

Development

Successfully merging this pull request may close these issues.

[Issue]: JAX wheels not built with rocm support
