Skip to content

Commit baa3e6f

Browse files
authored
Merge branch 'main' into issue2027
2 parents edc1e8b + db4e1c6 commit baa3e6f

7 files changed

Lines changed: 112 additions & 54 deletions

File tree

cuda_core/build_hooks.py

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import glob
1212
import os
1313
import re
14-
import subprocess
1514
import sys
1615
import tempfile
1716
import zipfile
@@ -185,28 +184,6 @@ def get_sources(mod_name):
185184
# related to free-threading builds.
186185
extra_compile_args += ["-DCYTHON_TRACE_NOGIL=1", "-DCYTHON_USE_SYS_MONITORING=0"]
187186

188-
# On Windows, _tensor_bridge.pyx needs a stub import library so the MSVC
189-
# linker can resolve the AOTI symbols (they live in torch_cpu.dll at
190-
# runtime). We generate the .lib from a .def file at build time.
191-
# Note: aoti_torch_get_current_cuda_stream lives in torch_cuda.dll and
192-
# is resolved lazily at runtime (not via the stub lib) — see
193-
# _tensor_bridge.pyx.
194-
_aoti_extra_link_args = []
195-
if sys.platform == "win32":
196-
_def_file = os.path.join("cuda", "core", "_include", "aoti_shim.def")
197-
_lib_file = os.path.join("build", "aoti_shim.lib")
198-
os.makedirs("build", exist_ok=True)
199-
subprocess.check_call( # noqa: S603
200-
["lib", f"/DEF:{_def_file}", f"/OUT:{_lib_file}", "/MACHINE:X64"], # noqa: S607
201-
stdout=subprocess.DEVNULL,
202-
)
203-
_aoti_extra_link_args = [_lib_file]
204-
205-
def get_extra_link_args(mod_name):
206-
if mod_name == "_tensor_bridge" and _aoti_extra_link_args:
207-
return extra_link_args + _aoti_extra_link_args
208-
return extra_link_args
209-
210187
ext_modules = tuple(
211188
Extension(
212189
f"cuda.core.{mod.replace(os.path.sep, '.')}",
@@ -218,7 +195,7 @@ def get_extra_link_args(mod_name):
218195
+ all_include_dirs,
219196
language="c++",
220197
extra_compile_args=extra_compile_args,
221-
extra_link_args=get_extra_link_args(mod),
198+
extra_link_args=extra_link_args,
222199
)
223200
for mod in module_names()
224201
)

cuda_core/cuda/core/_device_resources.pyx

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,18 @@ cdef class SMResourceOptions:
106106
Preferred co-scheduled SM count; the driver tries to satisfy
107107
this but may fall back to ``coscheduled_sm_count``.
108108
(Default to ``None``)
109+
backfill : bool or Sequence[bool], optional
110+
If ``True``, allow the driver to relax the co-scheduling
111+
constraint when assigning SMs. This enables requesting
112+
arbitrary aligned SM counts that the driver would otherwise
113+
reject due to hardware topology constraints.
114+
(Default to ``False``)
109115
"""
110116

111117
count: int | SequenceABC | None = None
112118
coscheduled_sm_count: int | SequenceABC | None = None
113119
preferred_coscheduled_sm_count: int | SequenceABC | None = None
120+
backfill: bool | SequenceABC = False
114121

115122

116123
@dataclass
@@ -172,6 +179,12 @@ cdef inline int _resolve_group_count(SMResourceOptions options) except?-1:
172179
n_groups,
173180
count_is_scalar,
174181
)
182+
_validate_split_field_length(
183+
options.backfill,
184+
"backfill",
185+
n_groups,
186+
count_is_scalar,
187+
)
175188
return n_groups
176189

177190

@@ -243,6 +256,7 @@ IF CUDA_CORE_BUILD_MAJOR >= 13:
243256
cdef list counts = _broadcast_field(options.count, n_groups)
244257
cdef list coscheduled = _broadcast_field(options.coscheduled_sm_count, n_groups)
245258
cdef list preferred = _broadcast_field(options.preferred_coscheduled_sm_count, n_groups)
259+
cdef list backfills = _broadcast_field(options.backfill, n_groups)
246260
cdef int i
247261

248262
for i in range(n_groups):
@@ -252,7 +266,10 @@ IF CUDA_CORE_BUILD_MAJOR >= 13:
252266
params[i].coscheduledSmCount = <unsigned int>(coscheduled[i])
253267
if preferred[i] is not None:
254268
params[i].preferredCoscheduledSmCount = <unsigned int>(preferred[i])
255-
params[i].flags = 0
269+
params[i].flags = (
270+
cydriver.CUdevSmResourceGroup_flags.CU_DEV_SM_RESOURCE_GROUP_BACKFILL
271+
if backfills[i] else 0
272+
)
256273
return 0
257274

258275

cuda_core/cuda/core/_include/aoti_shim.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
; At runtime the symbols resolve from torch_cpu.dll (loaded by 'import torch').
55
;
66
; IMPORTANT: Keep this export list in sync with the AOTI_SHIM_API declarations
7-
; in aoti_shim.h. build_hooks.py turns this file into the stub import library
7+
; in aoti_shim.h. setup.py turns this file into the stub import library
88
; that MSVC uses to link _tensor_bridge, so any added/removed/renamed AOTI
99
; symbol must be updated in both files.
1010
LIBRARY torch_cpu.dll

cuda_core/cuda/core/_include/aoti_shim.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@ typedef struct AtenTensorOpaque* AtenTensorHandle;
5252

5353
/*
5454
* IMPORTANT: Keep the AOTI_SHIM_API declaration list below in sync with
55-
* aoti_shim.def. On Windows, build_hooks.py turns that .def file into the
56-
* stub import library that MSVC needs to link _tensor_bridge without making
57-
* PyTorch a build-time dependency. If you add, remove, or rename an
58-
* imported AOTI symbol here, update aoti_shim.def in the same change.
55+
* aoti_shim.def. On Windows, setup.py generates that stub import library
56+
* during build_ext so MSVC can link _tensor_bridge without making PyTorch a
57+
* build-time dependency. If you add, remove, or rename an imported AOTI
58+
* symbol here, update aoti_shim.def in the same change.
5959
*
6060
* Exception: aoti_torch_get_current_cuda_stream lives in torch_cuda (not
6161
* torch_cpu) and is resolved lazily at runtime — see _tensor_bridge.pyx.

cuda_core/docs/source/api.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,9 @@ execution.
245245
CUDA system information and NVIDIA Management Library (NVML)
246246
------------------------------------------------------------
247247

248+
.. note::
249+
``cuda.core.system`` support requires ``cuda_bindings`` 12.9.6 or later, or 13.2.0 or later.
250+
248251
Basic functions
249252
```````````````
250253

cuda_core/setup.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import os
6+
from pathlib import Path
67

78
import build_hooks # our build backend
89
from setuptools import setup
@@ -11,11 +12,63 @@
1112

1213
nthreads = int(os.environ.get("CUDA_PYTHON_PARALLEL_LEVEL", os.cpu_count() // 2))
1314
coverage_mode = bool(int(os.environ.get("CUDA_PYTHON_COVERAGE", "0")))
15+
_ROOT_DIR = Path(__file__).resolve().parent
16+
_AOTI_SHIM_DEF_FILE = _ROOT_DIR / "cuda" / "core" / "_include" / "aoti_shim.def"
17+
_AOTI_SHIM_LIB_FILE = _ROOT_DIR / "build" / "aoti_shim.lib"
18+
_TENSOR_BRIDGE_EXT_NAME = "cuda.core._tensor_bridge"
19+
20+
21+
def _ensure_compiler_initialized(compiler, plat_name):
22+
initialize = getattr(compiler, "initialize", None)
23+
if callable(initialize) and not getattr(compiler, "initialized", False):
24+
if plat_name is None:
25+
initialize()
26+
else:
27+
initialize(plat_name)
28+
29+
30+
def _build_aoti_shim_lib(compiler):
31+
# Reuse setuptools' initialized MSVC compiler instead of rediscovering
32+
# lib.exe separately in the build backend.
33+
lib_exe = getattr(compiler, "lib", None)
34+
if not lib_exe:
35+
raise RuntimeError("MSVC compiler did not expose lib.exe after initialization.")
36+
37+
_AOTI_SHIM_LIB_FILE.parent.mkdir(exist_ok=True)
38+
compiler.spawn(
39+
[
40+
lib_exe,
41+
f"/DEF:{_AOTI_SHIM_DEF_FILE}",
42+
f"/OUT:{_AOTI_SHIM_LIB_FILE}",
43+
"/MACHINE:X64",
44+
]
45+
)
46+
return str(_AOTI_SHIM_LIB_FILE)
1447

1548

1649
class build_ext(_build_ext):  # noqa: N801
    """build_ext command that sets the parallel level and, for MSVC builds on
    Windows, links the tensor-bridge extension against the AOTI stub library.
    """

    def _configure_windows_tensor_bridge(self):
        """Append the AOTI stub import library to the tensor bridge link args.

        Does nothing unless ``os.name`` is ``"nt"`` and the active compiler
        reports ``compiler_type == "msvc"``.

        Raises
        ------
        RuntimeError
            If no extension named ``_TENSOR_BRIDGE_EXT_NAME`` is registered.
        """
        if not (os.name == "nt" and getattr(self.compiler, "compiler_type", None) == "msvc"):
            return

        # _tensor_bridge imports AOTI symbols from torch_cpu.dll; on Windows
        # the MSVC linker needs a stub import library to resolve them.
        bridge = next(
            (ext for ext in self.extensions if ext.name == _TENSOR_BRIDGE_EXT_NAME),
            None,
        )
        if bridge is None:
            raise RuntimeError(f"Failed to find extension {_TENSOR_BRIDGE_EXT_NAME!r} for Windows build.")

        _ensure_compiler_initialized(self.compiler, self.plat_name)
        shim_lib = _build_aoti_shim_lib(self.compiler)
        existing_args = list(bridge.extra_link_args or [])
        # Do not add the library twice if this step runs more than once.
        if shim_lib not in existing_args:
            bridge.extra_link_args = existing_args + [shim_lib]

    def build_extensions(self):
        """Set the parallel level, prepare the Windows link step, then build."""
        self.parallel = nthreads
        self._configure_windows_tensor_bridge()
        super().build_extensions()
2073

2174

cuda_core/tests/test_green_context.py

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,16 @@ def fill_kernel(init_cuda):
8282
return mod.get_kernel("fill")
8383

8484

85-
def _aligned_half(sm):
86-
"""Compute half the SM count, rounded down to min_partition_size alignment."""
85+
def _safe_two_group_count(sm):
86+
"""Return a safe per-group SM count for a 2-group split.
87+
88+
Uses min_partition_size which is always a valid split size regardless
89+
of hardware topology. Returns None if the device doesn't have enough SMs.
90+
"""
8791
min_size = sm.min_partition_size
88-
half = (sm.sm_count // 2 // min_size) * min_size
89-
return half
92+
if sm.sm_count < 2 * min_size:
93+
return None
94+
return min_size
9095

9196

9297
@contextlib.contextmanager
@@ -238,30 +243,33 @@ def test_discovery_respects_alignment(self, sm_resource):
238243
assert groups[0].sm_count % sm_resource.coscheduled_alignment == 0
239244

240245
def test_two_groups(self, sm_resource):
241-
"""Two-group split with explicit aligned counts."""
242-
half = _aligned_half(sm_resource)
243-
if half < sm_resource.min_partition_size:
246+
"""Two-group split with min_partition_size (always topology-safe)."""
247+
count = _safe_two_group_count(sm_resource)
248+
if count is None:
244249
pytest.skip("Not enough SMs for a 2-group split")
245250

246-
groups, rem = sm_resource.split(SMResourceOptions(count=(half, half)))
251+
groups, rem = sm_resource.split(SMResourceOptions(count=(count, count)))
247252

248253
assert len(groups) == 2
249-
assert groups[0].sm_count > 0
250-
assert groups[1].sm_count > 0
254+
assert groups[0].sm_count >= count
255+
assert groups[1].sm_count >= count
251256
total = groups[0].sm_count + groups[1].sm_count + rem.sm_count
252257
assert total <= sm_resource.sm_count
253258

254-
def test_two_groups_each_meets_request(self, sm_resource):
255-
min_size = sm_resource.min_partition_size
256-
half = _aligned_half(sm_resource)
257-
if half < min_size:
258-
pytest.skip("Not enough SMs for a 2-group split")
259+
def test_two_groups_backfill(self, sm_resource):
260+
"""Two-group split with backfill allows larger partitions."""
261+
align = sm_resource.coscheduled_alignment
262+
if align == 0:
263+
align = sm_resource.min_partition_size
264+
half = (sm_resource.sm_count // 2 // align) * align
265+
if half < sm_resource.min_partition_size:
266+
pytest.skip("Not enough SMs for a 2-group backfill split")
259267

260-
groups, _ = sm_resource.split(SMResourceOptions(count=(min_size, min_size)))
268+
groups, rem = sm_resource.split(SMResourceOptions(count=(half, half), backfill=True))
261269

262270
assert len(groups) == 2
263-
assert groups[0].sm_count >= min_size
264-
assert groups[1].sm_count >= min_size
271+
assert groups[0].sm_count >= half
272+
assert groups[1].sm_count >= half
265273

266274
def test_dry_run_matches_real(self, sm_resource):
267275
"""Dry-run reports the same SM counts as a real split."""
@@ -352,11 +360,11 @@ def test_green_ctx_sm_resources(self, green_ctx, sm_resource):
352360

353361
def test_green_ctx_resources_reflect_partition(self, init_cuda, sm_resource):
354362
"""Two green contexts should have disjoint SM partitions."""
355-
half = _aligned_half(sm_resource)
356-
if half < sm_resource.min_partition_size:
363+
count = _safe_two_group_count(sm_resource)
364+
if count is None:
357365
pytest.skip("Not enough SMs for a 2-group split")
358366

359-
groups, _ = sm_resource.split(SMResourceOptions(count=(half, half)))
367+
groups, _ = sm_resource.split(SMResourceOptions(count=(count, count)))
360368

361369
ctx_a = ctx_b = None
362370
try:
@@ -425,11 +433,11 @@ def test_launch_and_verify(self, init_cuda, green_ctx, fill_kernel):
425433
def test_two_green_contexts_independent(self, init_cuda, sm_resource, fill_kernel):
426434
"""Two SM groups -> two green contexts -> two independent kernels."""
427435
dev = init_cuda
428-
half = _aligned_half(sm_resource)
429-
if half < sm_resource.min_partition_size:
436+
count = _safe_two_group_count(sm_resource)
437+
if count is None:
430438
pytest.skip("Not enough SMs for a 2-group split")
431439

432-
groups, _ = sm_resource.split(SMResourceOptions(count=(half, half)))
440+
groups, _ = sm_resource.split(SMResourceOptions(count=(count, count)))
433441
assert len(groups) == 2
434442

435443
ctx_a = ctx_b = None

0 commit comments

Comments
 (0)