Skip to content

Commit 08f9f56

Browse files
committed
fix(test): import SUPPORTED_TARGETS directly instead of parsing pyx
Drop the elaborate tokenize+AST walk that the cross-check test used to extract `SUPPORTED_TARGETS` from the pyx source. Remove the `cdef` qualifier from the dict in `_program.pyx` (a module-level `cdef` variable is not visible from Python; without the qualifier the test can import the dict directly) and rewrite the test as a plain comparison: import both views, check that the NVRTC and NVVM entries match their code-type counterparts (`"c++"` and `"nvvm"`), and check that the two linker backends (nvJitLink, driver) agree with each other and with the `"ptx"` entry. The performance cost of dropping `cdef` is one Python-level dict `.get` per `Program.compile` call, which is negligible against the NVRTC/linker work that follows.
1 parent 547d3f3 commit 08f9f56

2 files changed

Lines changed: 15 additions & 48 deletions

File tree

cuda_core/cuda/core/_program.pyx

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,8 +1035,11 @@ cdef object Program_compile_nvvm(Program self, str target_type, object logs):
10351035

10361036
return ObjectCode._init(bytes(data), target_type, name=self._options.name)
10371037

1038-
# Supported target types per backend
1039-
cdef dict SUPPORTED_TARGETS = {
1038+
# Supported target types per backend. Plain Python-level binding (not
1039+
# ``cdef``) so the cache layer's cross-check test can import it directly
1040+
# instead of parsing this source file. Lookup performance inside
1041+
# ``Program_compile`` is unaffected -- it's one dict ``.get`` per compile.
1042+
SUPPORTED_TARGETS = {
10401043
CompilerBackendType.NVRTC: (ObjectCodeFormatType.PTX, ObjectCodeFormatType.CUBIN, ObjectCodeFormatType.LTOIR),
10411044
CompilerBackendType.NVVM: (ObjectCodeFormatType.PTX, ObjectCodeFormatType.LTOIR),
10421045
CompilerBackendType.NVJITLINK: (ObjectCodeFormatType.CUBIN, ObjectCodeFormatType.PTX),

cuda_core/tests/test_program_cache.py

Lines changed: 10 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -343,56 +343,20 @@ def test_make_program_cache_key_rejects(kwargs, exc_type, match):
343343

344344
def test_make_program_cache_key_supported_targets_matches_program_compile():
345345
"""``_SUPPORTED_TARGETS_BY_CODE_TYPE`` duplicates the backend target
346-
matrix in ``_program.pyx``. Guard against drift: parse the pyx source
347-
with :mod:`tokenize` (which skips string literals and comments) to
348-
extract ``SUPPORTED_TARGETS`` and assert the two views agree."""
349-
import ast
350-
import io
351-
import tokenize
352-
from pathlib import Path
353-
346+
matrix in ``_program.pyx``. Import the pyx-side dict directly and
347+
compare; both linker backends (nvJitLink, cuLink/driver) must agree
348+
with the ``"ptx"`` code-type entry on the cache side."""
349+
from cuda.core._program import SUPPORTED_TARGETS
350+
from cuda.core.typing import CompilerBackendType
354351
from cuda.core.utils._program_cache._keys import _SUPPORTED_TARGETS_BY_CODE_TYPE
355352

356-
backend_to_code_type = {"NVRTC": "c++", "NVVM": "nvvm"}
357-
linker_backends = ("nvJitLink", "driver")
358-
359-
pyx = Path(__file__).parent.parent / "cuda" / "core" / "_program.pyx"
360-
text = pyx.read_text()
361-
marker_idx = text.index("cdef dict SUPPORTED_TARGETS")
362-
tokens = tokenize.generate_tokens(io.StringIO(text[marker_idx:]).readline)
363-
364-
depth = 0
365-
start_offset = None
366-
end_offset = None
367-
lines = text[marker_idx:].splitlines(keepends=True)
368-
line_starts = [0]
369-
for line in lines[:-1]:
370-
line_starts.append(line_starts[-1] + len(line))
371-
372-
def _offset(row, col):
373-
return line_starts[row - 1] + col
374-
375-
for tok in tokens:
376-
if tok.type != tokenize.OP:
377-
continue
378-
if tok.string == "{":
379-
if depth == 0:
380-
start_offset = _offset(tok.start[0], tok.start[1])
381-
depth += 1
382-
elif tok.string == "}":
383-
depth -= 1
384-
if depth == 0:
385-
end_offset = _offset(tok.end[0], tok.end[1])
386-
break
387-
assert start_offset is not None and end_offset is not None, "could not locate SUPPORTED_TARGETS literal"
388-
pyx_targets = ast.literal_eval(text[marker_idx + start_offset : marker_idx + end_offset])
353+
backend_to_code_type = {CompilerBackendType.NVRTC: "c++", CompilerBackendType.NVVM: "nvvm"}
354+
linker_backends = (CompilerBackendType.NVJITLINK, CompilerBackendType.DRIVER)
389355

390356
for backend, code_type in backend_to_code_type.items():
391-
assert frozenset(pyx_targets[backend]) == _SUPPORTED_TARGETS_BY_CODE_TYPE[code_type], (
392-
backend,
393-
code_type,
394-
)
395-
linker_sets = [frozenset(pyx_targets[b]) for b in linker_backends]
357+
pyx_set = frozenset(str(t) for t in SUPPORTED_TARGETS[backend])
358+
assert pyx_set == _SUPPORTED_TARGETS_BY_CODE_TYPE[code_type], (backend, code_type)
359+
linker_sets = [frozenset(str(t) for t in SUPPORTED_TARGETS[b]) for b in linker_backends]
396360
assert all(s == linker_sets[0] for s in linker_sets)
397361
assert linker_sets[0] == _SUPPORTED_TARGETS_BY_CODE_TYPE["ptx"]
398362

0 commit comments

Comments
 (0)