@@ -60,38 +60,37 @@ cdef int cuPythonInit() except -1 nogil:
6060 except:
6161 handle = None
6262
63- # Else try default search
64- if not handle:
65- LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
66- try:
67- handle = win32api.LoadLibraryEx("nvrtc64_120_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
68- except:
69- pass
70-
71- # Final check if DLLs can be found within pip installations
63+ # Check if DLLs can be found within pip installations
7264 if not handle:
7365 site_packages = [site.getusersitepackages()] + site.getsitepackages()
7466 for sp in site_packages:
7567 mod_path = os.path.join(sp, "nvidia", "cuda_nvrtc", "bin")
76- if not os.path.isdir(mod_path):
77- continue
78- else:
68+ if os.path.isdir(mod_path):
7969 os.add_dll_directory(mod_path)
80- break
81- LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
82- LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
83- try:
84- handle = win32api.LoadLibraryEx(
85- # Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
86- os.path.join(mod_path, "nvrtc64_120_0.dll"),
87- 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
88-
89- # Note: nvrtc64_120_0.dll calls into nvrtc-builtins64_*.dll which is
90- # located in the same mod_path.
91- # Update PATH environ so that the two dlls can find each other
92- os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
93- except:
94- pass
70+ LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
71+ LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
72+ try:
73+ handle = win32api.LoadLibraryEx(
74+ # Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
75+ os.path.join(mod_path, "nvrtc64_120_0.dll"),
76+ 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
77+
78+ # Note: nvrtc64_120_0.dll calls into nvrtc-builtins64_*.dll which is
79+ # located in the same mod_path.
80+ # Update PATH environ so that the two dlls can find each other
81+ os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
82+ except:
83+ pass
84+ else:
85+ break
86+ else:
87+ # Else try default search
88+ # Only reached if DLL wasn't found in any site-package path
89+ LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
90+ try:
91+ handle = win32api.LoadLibraryEx("nvrtc64_120_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
92+ except:
93+ pass
9594
9695 if not handle:
9796 raise RuntimeError('Failed to LoadLibraryEx nvrtc64_120_0.dll')
0 commit comments