1111import glob
1212import os
1313import re
14- import subprocess
1514
1615from Cython .Build import cythonize
1716from setuptools import Extension
2625
2726
@functools.cache
def _get_cuda_paths() -> list[str]:
    """Return the CUDA installation roots named by the environment.

    ``CUDA_PATH`` is consulted first and ``CUDA_HOME`` second. The chosen
    value is split on ``os.pathsep`` so that several roots can be supplied.

    Returns:
        The non-empty path entries, in the order they were given.

    Raises:
        RuntimeError: if neither variable provides at least one non-empty path.
    """
    # Use ``or`` rather than ``environ.get(key, default)`` so that a variable
    # that is set but empty does not mask the other one.
    raw = os.environ.get("CUDA_PATH") or os.environ.get("CUDA_HOME")
    # Drop empty entries produced by leading/trailing/doubled separators;
    # an empty root would otherwise become the relative directory "include".
    cuda_paths = [entry for entry in raw.split(os.pathsep) if entry] if raw else []
    if not cuda_paths:
        raise RuntimeError("Environment variable CUDA_PATH or CUDA_HOME is not set")
    print("CUDA paths:", cuda_paths)
    return cuda_paths
3335
34- return cuda .bindings .__version__ .split ("." )[0 ]
35- except ImportError :
36- pass
3736
38- # for custom overwrite, e.g. in CI
@functools.cache
def _determine_cuda_major_version() -> str:
    """Determine the CUDA major version for building cuda.core.

    This version is used for two purposes:
        1. Determining which cuda-bindings version to install as a build dependency
        2. Setting CUDA_CORE_BUILD_MAJOR for Cython compile-time conditionals

    The version is derived from (in order of priority):
        1. CUDA_CORE_BUILD_MAJOR environment variable (explicit override, e.g. in CI)
        2. CUDA_VERSION macro in cuda.h from CUDA_PATH or CUDA_HOME

    Since CUDA_PATH or CUDA_HOME is required for the build (to provide include
    directories), the cuda.h header should always be available.

    Returns:
        The major version as a decimal string, e.g. ``"12"``.

    Raises:
        RuntimeError: if no override is set and no readable ``include/cuda.h``
            defining ``CUDA_VERSION`` is found under any CUDA root.
    """
    # Explicit override, e.g. in CI. Surrounding whitespace is stripped and an
    # empty value is treated as "not set" so it cannot leak into the
    # requirement string or break the later int() conversion.
    cuda_major = os.environ.get("CUDA_CORE_BUILD_MAJOR", "").strip()
    if cuda_major:
        print("CUDA MAJOR VERSION:", cuda_major)
        return cuda_major

    # Derive from the CUDA headers (the authoritative source for what we compile against).
    # Compiled once; tolerates indentation before '#' and trailing text after the number.
    version_define = re.compile(r"\s*#\s*define\s+CUDA_VERSION\s+(\d+)\b")
    for root in _get_cuda_paths():
        cuda_h = os.path.join(root, "include", "cuda.h")
        try:
            # errors="replace": a stray non-UTF-8 byte in the header must not
            # abort the build with a UnicodeDecodeError (which is not an OSError).
            with open(cuda_h, encoding="utf-8", errors="replace") as f:
                for line in f:
                    m = version_define.match(line)
                    if m:
                        # CUDA_VERSION is e.g. 12020 for 12.2.
                        cuda_major = str(int(m.group(1)) // 1000)
                        print("CUDA MAJOR VERSION:", cuda_major)
                        return cuda_major
        except OSError:
            # Missing or unreadable header under this root; try the next one.
            continue

    # CUDA_PATH or CUDA_HOME is required for the build, so we should not reach here
    # in normal circumstances. Raise an error to make the issue clear.
    raise RuntimeError(
        "Cannot determine CUDA major version. "
        "Set CUDA_CORE_BUILD_MAJOR environment variable, or ensure CUDA_PATH or CUDA_HOME "
        "points to a valid CUDA installation with include/cuda.h."
    )
5582
5683
5784# used later by setup()
@@ -68,25 +95,12 @@ def _build_cuda_core():
6895
6996 # It seems setuptools' wildcard support has problems for namespace packages,
7097 # so we explicitly spell out all Extension instances.
71- root_module = "cuda.core"
72- root_path = f"{ os .path .sep } " .join (root_module .split ("." )) + os .path .sep
73- ext_files = glob .glob (f"{ root_path } /**/*.pyx" , recursive = True )
74-
75- def strip_prefix_suffix (filename ):
76- return filename [len (root_path ) : - 4 ]
77-
78- module_names = (strip_prefix_suffix (f ) for f in ext_files )
79-
80- @functools .cache
81- def get_cuda_paths ():
82- CUDA_PATH = os .environ .get ("CUDA_PATH" , os .environ .get ("CUDA_HOME" , None ))
83- if not CUDA_PATH :
84- raise RuntimeError ("Environment variable CUDA_PATH or CUDA_HOME is not set" )
85- CUDA_PATH = CUDA_PATH .split (os .pathsep )
86- print ("CUDA paths:" , CUDA_PATH )
87- return CUDA_PATH
def module_names():
    """Yield one entry per ``.pyx`` file found recursively under ``cuda/core``.

    Each entry is the file's path with the leading ``cuda/core/`` prefix and
    the trailing ``.pyx`` suffix removed; nested files keep their
    ``os.path.sep`` separators. Paths are resolved relative to the current
    working directory.

    Entries are yielded in sorted order: ``glob`` makes no ordering guarantee,
    and a stable order keeps the generated extension list deterministic.
    """
    root_path = os.path.join("cuda", "core", "")  # trailing separator included
    for filename in sorted(glob.glob(f"{root_path}/**/*.pyx", recursive=True)):
        # Strip the package prefix and the 4-character ".pyx" suffix.
        yield filename[len(root_path) : -4]
88102
89- all_include_dirs = list (os .path .join (root , "include" ) for root in get_cuda_paths ())
103+ all_include_dirs = list (os .path .join (root , "include" ) for root in _get_cuda_paths ())
90104 extra_compile_args = []
91105 if COMPILE_FOR_COVERAGE :
92106 # CYTHON_TRACE_NOGIL indicates to trace nogil functions. It is not
@@ -101,11 +115,11 @@ def get_cuda_paths():
101115 language = "c++" ,
102116 extra_compile_args = extra_compile_args ,
103117 )
104- for mod in module_names
118+ for mod in module_names ()
105119 )
106120
107121 nthreads = int (os .environ .get ("CUDA_PYTHON_PARALLEL_LEVEL" , os .cpu_count () // 2 ))
108- compile_time_env = {"CUDA_CORE_BUILD_MAJOR" : int (_get_proper_cuda_bindings_major_version ())}
122+ compile_time_env = {"CUDA_CORE_BUILD_MAJOR" : int (_determine_cuda_major_version ())}
109123 compiler_directives = {"embedsignature" : True , "warn.deprecated.IF" : False , "freethreading_compatible" : True }
110124 if COMPILE_FOR_COVERAGE :
111125 compiler_directives ["linetrace" ] = True
@@ -132,7 +146,7 @@ def build_wheel(wheel_directory, config_settings=None, metadata_directory=None):
132146
133147
def _get_cuda_bindings_require():
    """Return the build requirement pinning cuda-bindings to the detected CUDA major.

    The result is a one-element list holding a wildcard pin of the form
    ``cuda-bindings==<major>.*``, where ``<major>`` is whatever
    ``_determine_cuda_major_version()`` returns.
    """
    requirement = f"cuda-bindings=={_determine_cuda_major_version()}.*"
    return [requirement]
137151
138152
0 commit comments