diff --git a/csrc/hybrid_ep/jit/compiler.cu b/csrc/hybrid_ep/jit/compiler.cu index 128aba4d..de485b2b 100644 --- a/csrc/hybrid_ep/jit/compiler.cu +++ b/csrc/hybrid_ep/jit/compiler.cu @@ -17,6 +17,27 @@ inline std::string get_env(std::string name) { return std::string(env); } +std::string get_nixl_lib_dir(const std::string& nixl_home) { + const std::string lib_dir = nixl_home + "/lib"; + const std::string x86_lib_dir = lib_dir + "/x86_64-linux-gnu"; + const std::string arm_lib_dir = lib_dir + "/aarch64-linux-gnu"; +#if defined(__aarch64__) + std::vector candidates = {arm_lib_dir, x86_lib_dir, lib_dir}; +#elif defined(__x86_64__) + std::vector candidates = {x86_lib_dir, arm_lib_dir, lib_dir}; +#else + std::vector candidates = {lib_dir, x86_lib_dir, arm_lib_dir}; +#endif + + for (const auto& path : candidates) { + if (std::filesystem::exists(path)) { + return path; + } + } + + return lib_dir; +} + std::string get_jit_dir() { std::string cache_dir = get_env("HYBRID_EP_CACHE_DIR"); if (cache_dir.empty()) { @@ -68,7 +89,7 @@ NVCCCompiler::NVCCCompiler(std::string base_path, std::string comm_id): include += " -I" + nixl_home + "/include "; include += " -I" + nixl_home + "/include/gpu/ucx "; include += " -I" + ucx_home + "/include "; - std::string nixl_lib = nixl_home + "/lib/x86_64-linux-gnu"; + std::string nixl_lib = get_nixl_lib_dir(nixl_home); library += " -L" + nixl_lib + " -lnixl -lnixl_build -lnixl_common "; library += " -Xlinker -rpath -Xlinker " + nixl_lib + " "; #else