Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion csrc/hybrid_ep/jit/compiler.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,27 @@ inline std::string get_env(std::string name) {
return std::string(env);
}

std::string get_nixl_lib_dir(const std::string& nixl_home) {
const std::string lib_dir = nixl_home + "/lib";
const std::string x86_lib_dir = lib_dir + "/x86_64-linux-gnu";
const std::string arm_lib_dir = lib_dir + "/aarch64-linux-gnu";
#if defined(__aarch64__)
std::vector<std::string> candidates = {arm_lib_dir, x86_lib_dir, lib_dir};
#elif defined(__x86_64__)
std::vector<std::string> candidates = {x86_lib_dir, arm_lib_dir, lib_dir};
#else
std::vector<std::string> candidates = {lib_dir, x86_lib_dir, arm_lib_dir};
#endif

for (const auto& path : candidates) {
if (std::filesystem::exists(path)) {
return path;
}
}

return lib_dir;
}

std::string get_jit_dir() {
std::string cache_dir = get_env("HYBRID_EP_CACHE_DIR");
if (cache_dir.empty()) {
Expand Down Expand Up @@ -68,7 +89,7 @@ NVCCCompiler::NVCCCompiler(std::string base_path, std::string comm_id):
include += " -I" + nixl_home + "/include ";
include += " -I" + nixl_home + "/include/gpu/ucx ";
include += " -I" + ucx_home + "/include ";
std::string nixl_lib = nixl_home + "/lib/x86_64-linux-gnu";
std::string nixl_lib = get_nixl_lib_dir(nixl_home);
library += " -L" + nixl_lib + " -lnixl -lnixl_build -lnixl_common ";
library += " -Xlinker -rpath -Xlinker " + nixl_lib + " ";
#else
Expand Down