diff --git a/cumm/common.py b/cumm/common.py index a69d3a8..2e20b45 100644 --- a/cumm/common.py +++ b/cumm/common.py @@ -278,7 +278,7 @@ def _get_cuda_include_lib(): except: pass - linux_cuda_root = Path("/usr/local/cuda") + linux_cuda_root = Path(os.getenv("CUDA_HOME", "/usr/local/cuda")) include = linux_cuda_root / f"include" lib64 = linux_cuda_root / f"lib64" assert linux_cuda_root.exists(), f"can't find cuda in {linux_cuda_root} install via cuda installer or conda first." @@ -857,4 +857,4 @@ def device_function(self): code = pccm.code() code.arg("a, b", "float") code.raw("return a + b;") - return code.ret("float") \ No newline at end of file + return code.ret("float")