fix(ci): configure ROCm library paths for JAX wheel tests#4012
fix(ci): configure ROCm library paths for JAX wheel tests#4012
Conversation
ROCm tarball installs to /opt/rocm-<version>/ but JAX wheels have RUNPATH hardcoded to /opt/rocm/lib and missing rocm_sysdeps/lib/. Add symlink, LD_LIBRARY_PATH, and JAX_PLATFORMS=rocm to fix. Fixes #3627 Signed-off-by: Yanyao Wang <wangyanyao@msn.com>
| # Set ROCm environment variables for dynamic linker. | ||
| # Both lib paths are required: main ROCm libs + vendored sysdeps libs | ||
| echo "ROCM_PATH=${DEST}" >> "$GITHUB_ENV" | ||
| echo "LD_LIBRARY_PATH=${DEST}/lib:${DEST}/lib/rocm_sysdeps/lib${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}" >> "$GITHUB_ENV" |
There was a problem hiding this comment.
I think adding /opt/rocm/lib to LD_LIBRARY_PATH is somehow acceptable but I don't think adding /opt/rocm/lib/rocm_sysdeps/lib is the right fix and needs another solution. The question is why those sysdeps aren't get picked up correctly and if we add this, the issue will be kind of masked in CI.
There was a problem hiding this comment.
Good point — adding rocm_sysdeps/lib to LD_LIBRARY_PATH is intentionally a short-term CI unblock, not the proper fix. Let me explain why it's needed here and what the real fix looks like.
The proper fix (JAX side)
I'll submit a fix to JAX's jaxlib/rocm/rocm_rpath.bzl to add rocm_sysdeps/lib to _WHEEL_RPATHS:
Once that lands and new wheels are published, the LD_LIBRARY_PATH line here can be simplified to just ${DEST}/lib (or removed entirely if the symlink is sufficient).
Why this CI PR is still needed
Even after the JAX RPATH fix, this workflow still needs:
- ln -sfn — ROCm installs to /opt/rocm-version/ but RUNPATH is hardcoded to /opt/rocm/lib
- JAX_PLATFORMS=rocm — without this, tests silently pass on CPU (896 passed, 0 on GPU)
So this PR unblocks CI now; the JAX-side fix eliminates the need for rocm_sysdeps/lib in LD_LIBRARY_PATH going forward.
There was a problem hiding this comment.
JAX_PLATFORMS=rocm is fine and doesn't concern me. I am also aware that you need to set some path but pointing to /opt/rocm/lib (and symlinking) should be it. So the what we set LD_LIBRARY_PATH should be different or rather we only want this as a short term fix.
There was a problem hiding this comment.
I submitted a JAX PR (ROCm/jax#737). After it is merged and confirmed working, I can remove the LD_LIBRARY_PATH part from this PR.
JAX wheels have RUNPATH hardcoded to /opt/rocm/lib but ROCm installs to /opt/rocm-<version>/. Add symlink to resolve this. Set JAX_PLATFORMS=rocm to prevent silent CPU fallback. Requires ROCm/jax#737 for rocm_sysdeps/lib RUNPATH coverage. Fixes #3627 Signed-off-by: Yanyao Wang <wangyanyao@msn.com>
|
Submitted a follow-up PR after ROCm/jax#737 was merged. |
marbre
left a comment
There was a problem hiding this comment.
Can you please update the PR description? I think you'd still need to set LD_LIBRARY_PATH as band-aid or not?
|
Updated the PR description. |
|
Any update on the test plan? |
|
JAX wheel builds are currently blocked by #3876 (missing libamd_comgr_stub.a). This affects all JAX versions (v0.8.2 & v0.9.0) across all Python versions and architectures. The latest nightly run (#626, today Mar 19) still fails with the same error. This means no JAX wheel tests can run end-to-end in CI right now. |
Motivation
JAX wheel tests in CI pass on CPU but GPU backend is completely broken, masking the issue:
/opt/rocm-<version>/without a/opt/rocmsymlink, so the hardcoded RUNPATH/opt/rocm/libin JAX wheels never resolvesJAX_PLATFORMS=rocm, JAX silently falls back to CPU — 896 tests passed on CPU, none on GPUFixes #3627
Changes
ln -sfn "${DEST}" /opt/rocm— create symlink so RUNPATH/opt/rocm/liband/opt/rocm/lib/rocm_sysdeps/libin JAX wheels resolve to the versioned install pathJAX_PLATFORMS: rocm— force ROCm backend in test step, preventing silent CPU fallbackWhy LD_LIBRARY_PATH is not needed
JAX plugin
.sofiles have all 25librocm_sysdeps_*.soas direct DT_NEEDED entries (XLA'sBUILD.tplusessrcsforsystem_libs). RUNPATH applies to all direct dependencies, so there is no transitivity issue.With the companion JAX fix (ROCm/jax#737 / jax-ml/jax#35978) adding
/opt/rocm/lib/rocm_sysdeps/libto the wheel RUNPATH, the symlink alone is sufficient:/opt/rocm/lib→${DEST}/lib→ main ROCm libs (libamdhip64.so.7,libhsa-runtime64.so.1, etc.)/opt/rocm/lib/rocm_sysdeps/lib→${DEST}/lib/rocm_sysdeps/lib→ 25 vendored sysdeps libsDependency: This PR requires ROCm/jax#737 to be merged and new wheels published before the
rocm_sysdepslibraries can be resolved withoutLD_LIBRARY_PATH.Test Plan
rocm_plugin_extensionimports successfully (noImportError)ROCm wheel install foundinstead ofNo ROCm wheel installation foundTest Result
Pending CI validation
Submission Checklist