-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsetup.py
More file actions
52 lines (39 loc) · 1.39 KB
/
Copy pathsetup.py
File metadata and controls
52 lines (39 loc) · 1.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import sys
from pathlib import Path
from setuptools import setup
def build_cuda_extensions() -> list:
    """Return the list of CUDA extension modules to hand to ``setup()``.

    The list holds a single ``CUDAExtension`` named ``kernel_lab_cuda``
    compiled from ``bindings.cpp``, ``softmax.cu`` and ``rmsnorm.cu``, all
    located in ``kernel_lab/ops/cuda/csrc`` next to this file.

    Raises:
        RuntimeError: If torch reports no CUDA toolkit (``CUDA_HOME`` is
            ``None``).
    """
    # torch is imported inside the function so that merely loading this
    # module does not require torch to be installed.
    from torch.utils.cpp_extension import CUDA_HOME, CUDAExtension

    if CUDA_HOME is None:
        raise RuntimeError(
            "CUDA toolkit was not found. Build the extension on a Linux host with CUDA installed."
        )

    source_dir = Path(__file__).parent / "kernel_lab" / "ops" / "cuda" / "csrc"
    source_files = [
        str(source_dir / file_name)
        for file_name in ("bindings.cpp", "softmax.cu", "rmsnorm.cu")
    ]
    compile_flags = {
        "cxx": ["-O3"],
        "nvcc": ["-O3", "--use_fast_math"],
    }

    extension = CUDAExtension(
        name="kernel_lab_cuda",
        sources=source_files,
        extra_compile_args=compile_flags,
    )
    return [extension]
def should_build_cuda_extension() -> bool:
    """Return True when ``build_ext`` appears among the command-line arguments."""
    return any(argument == "build_ext" for argument in sys.argv)
if __name__ == "__main__":
    # Start from the bare package metadata; the CUDA extension and its
    # torch-provided build command are attached only when `build_ext` was
    # requested on the command line.
    setup_kwargs = {"name": "kernel_lab_cuda"}

    if should_build_cuda_extension():
        try:
            from torch.utils.cpp_extension import BuildExtension
        except ModuleNotFoundError as err:
            # Replace the bare import failure with an actionable message.
            raise RuntimeError(
                "Building the CUDA extension requires torch in the active environment."
            ) from err

        setup_kwargs.update(
            ext_modules=build_cuda_extensions(),
            cmdclass={"build_ext": BuildExtension},
        )

    setup(**setup_kwargs)