Skip to content

Commit db4e1c6

Browse files
rwgk and cursoragent authored
Move cuda.core Windows stub-lib generation into build_ext. (#2033)
Reuse setuptools' initialized MSVC compiler when linking _tensor_bridge so Windows source builds do not need separate toolchain discovery in build_hooks.py. Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent fec00bf commit db4e1c6

4 files changed

Lines changed: 59 additions & 29 deletions

File tree

cuda_core/build_hooks.py

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import glob
1212
import os
1313
import re
14-
import subprocess
1514
import sys
1615
import tempfile
1716
import zipfile
@@ -185,28 +184,6 @@ def get_sources(mod_name):
185184
# related to free-threading builds.
186185
extra_compile_args += ["-DCYTHON_TRACE_NOGIL=1", "-DCYTHON_USE_SYS_MONITORING=0"]
187186

188-
# On Windows, _tensor_bridge.pyx needs a stub import library so the MSVC
189-
# linker can resolve the AOTI symbols (they live in torch_cpu.dll at
190-
# runtime). We generate the .lib from a .def file at build time.
191-
# Note: aoti_torch_get_current_cuda_stream lives in torch_cuda.dll and
192-
# is resolved lazily at runtime (not via the stub lib) — see
193-
# _tensor_bridge.pyx.
194-
_aoti_extra_link_args = []
195-
if sys.platform == "win32":
196-
_def_file = os.path.join("cuda", "core", "_include", "aoti_shim.def")
197-
_lib_file = os.path.join("build", "aoti_shim.lib")
198-
os.makedirs("build", exist_ok=True)
199-
subprocess.check_call( # noqa: S603
200-
["lib", f"/DEF:{_def_file}", f"/OUT:{_lib_file}", "/MACHINE:X64"], # noqa: S607
201-
stdout=subprocess.DEVNULL,
202-
)
203-
_aoti_extra_link_args = [_lib_file]
204-
205-
def get_extra_link_args(mod_name):
206-
if mod_name == "_tensor_bridge" and _aoti_extra_link_args:
207-
return extra_link_args + _aoti_extra_link_args
208-
return extra_link_args
209-
210187
ext_modules = tuple(
211188
Extension(
212189
f"cuda.core.{mod.replace(os.path.sep, '.')}",
@@ -218,7 +195,7 @@ def get_extra_link_args(mod_name):
218195
+ all_include_dirs,
219196
language="c++",
220197
extra_compile_args=extra_compile_args,
221-
extra_link_args=get_extra_link_args(mod),
198+
extra_link_args=extra_link_args,
222199
)
223200
for mod in module_names()
224201
)

cuda_core/cuda/core/_include/aoti_shim.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
; At runtime the symbols resolve from torch_cpu.dll (loaded by 'import torch').
55
;
66
; IMPORTANT: Keep this export list in sync with the AOTI_SHIM_API declarations
7-
; in aoti_shim.h. build_hooks.py turns this file into the stub import library
7+
; in aoti_shim.h. setup.py turns this file into the stub import library
88
; that MSVC uses to link _tensor_bridge, so any added/removed/renamed AOTI
99
; symbol must be updated in both files.
1010
LIBRARY torch_cpu.dll

cuda_core/cuda/core/_include/aoti_shim.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@ typedef struct AtenTensorOpaque* AtenTensorHandle;
5252

5353
/*
5454
* IMPORTANT: Keep the AOTI_SHIM_API declaration list below in sync with
55-
* aoti_shim.def. On Windows, build_hooks.py turns that .def file into the
56-
* stub import library that MSVC needs to link _tensor_bridge without making
57-
* PyTorch a build-time dependency. If you add, remove, or rename an
58-
* imported AOTI symbol here, update aoti_shim.def in the same change.
55+
* aoti_shim.def. On Windows, setup.py generates that stub import library
56+
* during build_ext so MSVC can link _tensor_bridge without making PyTorch a
57+
* build-time dependency. If you add, remove, or rename an imported AOTI
58+
* symbol here, update aoti_shim.def in the same change.
5959
*
6060
* Exception: aoti_torch_get_current_cuda_stream lives in torch_cuda (not
6161
* torch_cpu) and is resolved lazily at runtime — see _tensor_bridge.pyx.

cuda_core/setup.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import os
6+
from pathlib import Path
67

78
import build_hooks # our build backend
89
from setuptools import setup
@@ -11,11 +12,63 @@
1112

1213
# Default build parallelism is half the detected CPU count. os.cpu_count() may
# return None when the count cannot be determined, and the default expression
# is evaluated even when CUDA_PYTHON_PARALLEL_LEVEL is set, so fall back to 1
# before halving instead of raising TypeError at import time.
nthreads = int(os.environ.get("CUDA_PYTHON_PARALLEL_LEVEL", (os.cpu_count() or 1) // 2))
# CUDA_PYTHON_COVERAGE is parsed as an integer flag; "0" (the default) is off.
coverage_mode = bool(int(os.environ.get("CUDA_PYTHON_COVERAGE", "0")))
# Locations used when generating the Windows stub import library that the
# _tensor_bridge extension links against; all are anchored at this file's
# directory so they do not depend on the current working directory.
_ROOT_DIR = Path(__file__).resolve().parent
_AOTI_SHIM_DEF_FILE = _ROOT_DIR / "cuda" / "core" / "_include" / "aoti_shim.def"
_AOTI_SHIM_LIB_FILE = _ROOT_DIR / "build" / "aoti_shim.lib"
# Fully qualified name of the extension that needs the stub import library.
_TENSOR_BRIDGE_EXT_NAME = "cuda.core._tensor_bridge"
19+
20+
21+
def _ensure_compiler_initialized(compiler, plat_name):
    """Run the compiler's deferred ``initialize`` step if it has not run yet.

    Nothing happens when ``compiler`` has no callable ``initialize`` attribute
    or when its ``initialized`` attribute is already truthy.

    Args:
        compiler: Compiler object to initialize.
        plat_name: Platform name forwarded to ``initialize``; when ``None``,
            ``initialize`` is called without arguments.
    """
    init = getattr(compiler, "initialize", None)
    if not callable(init):
        return
    if getattr(compiler, "initialized", False):
        return

    # Only pass the platform name along when the caller actually has one.
    init_args = () if plat_name is None else (plat_name,)
    init(*init_args)
29+
30+
def _build_aoti_shim_lib(compiler):
    """Generate the AOTI stub import library and return its path as a string.

    The librarian executable is taken from the ``lib`` attribute of the
    already-initialized compiler rather than being discovered separately, and
    the command is run through ``compiler.spawn``.

    Args:
        compiler: Initialized compiler object exposing ``lib`` and ``spawn``.

    Returns:
        ``str(_AOTI_SHIM_LIB_FILE)``, the path of the generated library.

    Raises:
        RuntimeError: If ``compiler.lib`` is missing or empty.
    """
    librarian = getattr(compiler, "lib", None)
    if not librarian:
        raise RuntimeError("MSVC compiler did not expose lib.exe after initialization.")

    # Make sure the output directory exists before the librarian writes to it.
    output_dir = _AOTI_SHIM_LIB_FILE.parent
    output_dir.mkdir(exist_ok=True)

    def_option = f"/DEF:{_AOTI_SHIM_DEF_FILE}"
    out_option = f"/OUT:{_AOTI_SHIM_LIB_FILE}"
    command = [librarian, def_option, out_option, "/MACHINE:X64"]
    compiler.spawn(command)
    return str(_AOTI_SHIM_LIB_FILE)
1447

1548

1649
class build_ext(_build_ext):  # noqa: N801
    """``build_ext`` command with parallel builds and Windows stub-lib setup."""

    def _configure_windows_tensor_bridge(self):
        """Attach the AOTI stub import library to the tensor-bridge extension.

        Does nothing unless ``os.name`` is ``"nt"`` and the active compiler's
        ``compiler_type`` is ``"msvc"``.

        Raises:
            RuntimeError: If no extension named ``_TENSOR_BRIDGE_EXT_NAME`` is
                present in ``self.extensions``.
        """
        if not (os.name == "nt" and getattr(self.compiler, "compiler_type", None) == "msvc"):
            return

        # _tensor_bridge imports AOTI symbols from torch_cpu.dll, which on
        # Windows requires a stub import library for the MSVC linker.
        bridge = next(
            (candidate for candidate in self.extensions if candidate.name == _TENSOR_BRIDGE_EXT_NAME),
            None,
        )
        if bridge is None:
            raise RuntimeError(f"Failed to find extension {_TENSOR_BRIDGE_EXT_NAME!r} for Windows build.")

        _ensure_compiler_initialized(self.compiler, self.plat_name)
        stub_path = _build_aoti_shim_lib(self.compiler)
        current_args = list(bridge.extra_link_args or [])
        # Skip the append when the stub library is already among the link arguments.
        if stub_path not in current_args:
            bridge.extra_link_args = current_args + [stub_path]

    def build_extensions(self):
        """Set the parallel level, prepare the Windows stub lib, then build."""
        self.parallel = nthreads
        self._configure_windows_tensor_bridge()
        super().build_extensions()
2073

2174

0 commit comments

Comments
 (0)