Skip to content

Commit 44fd98e

Browse files
committed
Fix build errors, update tests, remove unused imports
- Change cdef function return types from ObjectCode to object (Cython limitation)
- Remove unused imports: intptr_t, NvrtcProgramHandle, NvvmProgramHandle, as_intptr
- Update as_py(NvvmProgramHandle) to return Python int via PyLong_FromSsize_t
- Update test assertions: remove handle checks after close(), test idempotency instead
- Update NVVM error message regex to match new unified format
1 parent c15b12e commit 44fd98e

5 files changed

Lines changed: 172 additions & 116 deletions

File tree

cuda_core/cuda/core/_cpp/resource_handles.hpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -445,7 +445,8 @@ inline PyObject* as_py(const NvrtcProgramHandle& h) noexcept {
445445
}
446446

447447
inline PyObject* as_py(const NvvmProgramHandle& h) noexcept {
448-
return detail::make_py("cuda.bindings.nvvm", "nvvmProgram", as_intptr(h));
448+
// NVVM bindings use raw integers, not wrapper classes
449+
return PyLong_FromSsize_t(as_intptr(h));
449450
}
450451

451452
} // namespace cuda_core

cuda_core/cuda/core/_program.pyx

Lines changed: 163 additions & 105 deletions
Original file line number | Diff line number | Diff line change
@@ -14,16 +14,16 @@ from warnings import warn
1414

1515
from cuda.bindings import driver, nvrtc
1616

17-
from libc.stdint cimport intptr_t
17+
from libcpp.vector cimport vector
1818

1919
from ._resource_handles cimport (
20-
NvrtcProgramHandle,
21-
NvvmProgramHandle,
22-
as_intptr,
20+
as_cu,
21+
as_py,
2322
create_nvrtc_program_handle,
2423
create_nvvm_program_handle,
2524
)
2625
from cuda.bindings cimport cynvrtc, cynvvm
26+
from cuda.core._utils.cuda_utils cimport HANDLE_RETURN_NVRTC, HANDLE_RETURN_NVVM
2727
from cuda.core._device import Device
2828
from cuda.core._linker import Linker, LinkerHandleT, LinkerOptions
2929
from cuda.core._module import ObjectCode
@@ -40,8 +40,11 @@ from cuda.core._utils.cuda_utils import (
4040

4141
__all__ = ["Program", "ProgramOptions"]
4242

43-
ProgramHandleT = nvrtc.nvrtcProgram | LinkerHandleT
44-
"""Type alias for program handle types across different backends."""
43+
ProgramHandleT = nvrtc.nvrtcProgram | int | LinkerHandleT
44+
"""Type alias for program handle types across different backends.
45+
46+
The ``int`` type covers NVVM handles, which don't have a wrapper class.
47+
"""
4548

4649

4750
# =============================================================================
@@ -76,8 +79,8 @@ cdef class Program:
7679
if self._linker:
7780
self._linker.close()
7881
# Reset handles - the C++ shared_ptr destructor handles cleanup
79-
self._h_nvrtc = NvrtcProgramHandle()
80-
self._h_nvvm = NvvmProgramHandle()
82+
self._h_nvrtc.reset()
83+
self._h_nvvm.reset()
8184

8285
def compile(
8386
self, target_type: str, name_expressions: tuple | list = (), logs = None
@@ -120,14 +123,11 @@ cdef class Program:
120123
handle, call ``int(Program.handle)``.
121124
"""
122125
if self._backend == "NVRTC":
123-
ptr = as_intptr(self._h_nvrtc)
124-
return nvrtc.nvrtcProgram(ptr) if ptr else None
126+
return as_py(self._h_nvrtc)
125127
elif self._backend == "NVVM":
126-
# NVVM uses raw integers for handles, not wrapper classes
127-
ptr = as_intptr(self._h_nvvm)
128-
return ptr if ptr else None
128+
return as_py(self._h_nvvm) # returns int (NVVM uses raw integers)
129129
else:
130-
return self._linker.handle if self._linker else None
130+
return self._linker.handle
131131

132132
@staticmethod
133133
def driver_can_load_nvrtc_ptx_output() -> bool:
@@ -392,7 +392,7 @@ class ProgramOptions:
392392
def _prepare_nvvm_options(self, as_bytes: bool = True) -> list[bytes] | list[str]:
393393
return _prepare_nvvm_options_impl(self, as_bytes)
394394
395-
def as_bytes(self, backend: str) -> list[bytes]:
395+
def as_bytes(self, backend: str, target_type: str | None = None) -> list[bytes]:
396396
"""Convert program options to bytes format for the specified backend.
397397

398398
This method transforms the program options into a format suitable for the
@@ -403,6 +403,9 @@ class ProgramOptions:
403403
----------
404404
backend : str
405405
The compiler backend to prepare options for. Must be either "nvrtc" or "nvvm".
406+
target_type : str, optional
407+
The compilation target type (e.g., "ptx", "cubin", "ltoir"). Some backends
408+
require additional options based on the target type.
406409

407410
Returns
408411
-------
@@ -425,7 +428,10 @@ class ProgramOptions:
425428
if backend == "nvrtc":
426429
return self._prepare_nvrtc_options()
427430
elif backend == "nvvm":
428-
return self._prepare_nvvm_options(as_bytes=True)
431+
options = self._prepare_nvvm_options(as_bytes=True)
432+
if target_type == "ltoir" and b"-gen-lto" not in options:
433+
options.append(b"-gen-lto")
434+
return options
429435
else:
430436
raise ValueError(f"Unknown backend '{backend}'. Must be one of: 'nvrtc', 'nvvm'")
431437
@@ -530,15 +536,27 @@ cdef inline object _translate_program_options(object options):
530536
531537
cdef inline int Program_init(Program self, object code, str code_type, object options) except -1:
532538
"""Initialize a Program instance."""
539+
cdef cynvrtc.nvrtcProgram nvrtc_prog
540+
cdef cynvvm.nvvmProgram nvvm_prog
541+
cdef bytes code_bytes
542+
cdef const char* code_ptr
543+
cdef const char* name_ptr
544+
cdef size_t code_len
545+
533546
self._options = options = check_or_create_options(ProgramOptions, options, "Program options")
534547
code_type = code_type.lower()
535548
536549
if code_type == "c++":
537550
assert_type(code, str)
538551
# TODO: support pre-loaded headers & include names
539-
# TODO: allow tuples once NVIDIA/cuda-python#72 is resolved
540-
py_prog = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), options._name, 0, [], []))
541-
self._h_nvrtc = create_nvrtc_program_handle(<cynvrtc.nvrtcProgram><intptr_t>int(py_prog))
552+
code_bytes = code.encode()
553+
code_ptr = <const char*>code_bytes
554+
name_ptr = <const char*>options._name
555+
556+
with nogil:
557+
HANDLE_RETURN_NVRTC(NULL, cynvrtc.nvrtcCreateProgram(
558+
&nvrtc_prog, code_ptr, name_ptr, 0, NULL, NULL))
559+
self._h_nvrtc = create_nvrtc_program_handle(nvrtc_prog)
542560
self._backend = "NVRTC"
543561
self._linker = None
544562
@@ -550,15 +568,21 @@ cdef inline int Program_init(Program self, object code, str code_type, object op
550568
self._backend = self._linker.backend
551569
552570
elif code_type == "nvvm":
571+
_get_nvvm_module() # Validate NVVM availability
553572
if isinstance(code, str):
554573
code = code.encode("utf-8")
555574
elif not isinstance(code, (bytes, bytearray)):
556575
raise TypeError("NVVM IR code must be provided as str, bytes, or bytearray")
557576
558-
nvvm = _get_nvvm_module()
559-
py_prog = nvvm.create_program()
560-
nvvm.add_module_to_program(py_prog, code, len(code), options._name.decode())
561-
self._h_nvvm = create_nvvm_program_handle(<cynvvm.nvvmProgram><intptr_t>int(py_prog))
577+
code_ptr = <const char*>(<bytes>code)
578+
name_ptr = <const char*>options._name
579+
code_len = len(code)
580+
581+
with nogil:
582+
HANDLE_RETURN_NVVM(NULL, cynvvm.nvvmCreateProgram(&nvvm_prog))
583+
self._h_nvvm = create_nvvm_program_handle(nvvm_prog) # RAII from here
584+
with nogil:
585+
HANDLE_RETURN_NVVM(nvvm_prog, cynvvm.nvvmAddModuleToProgram(nvvm_prog, code_ptr, code_len, name_ptr))
562586
self._backend = "NVVM"
563587
self._linker = None
564588
@@ -571,115 +595,149 @@ cdef inline int Program_init(Program self, object code, str code_type, object op
571595
572596
573597
cdef object Program_compile_nvrtc(Program self, str target_type, object name_expressions, object logs):
574-
"""Compile using NVRTC backend."""
575-
if target_type == "ptx" and not _can_load_generated_ptx():
576-
warn(
577-
"The CUDA driver version is older than the backend version. "
578-
"The generated ptx will not be loadable by the current driver.",
579-
stacklevel=2,
580-
category=RuntimeWarning,
581-
)
582-
583-
# Create Python wrapper for handle_return calls that need it
584-
py_handle = nvrtc.nvrtcProgram(as_intptr(self._h_nvrtc))
585-
598+
"""Compile using NVRTC backend and return ObjectCode."""
599+
cdef cynvrtc.nvrtcProgram prog = as_cu(self._h_nvrtc)
600+
cdef size_t output_size = 0
601+
cdef size_t logsize = 0
602+
cdef vector[const char*] options_vec
603+
cdef char* data_ptr = NULL
604+
cdef bytes name_bytes
605+
cdef const char* name_ptr = NULL
606+
cdef const char* lowered_name = NULL
607+
cdef dict symbol_mapping = {}
608+
609+
# Add name expressions before compilation
586610
if name_expressions:
587611
for n in name_expressions:
588-
handle_return(
589-
nvrtc.nvrtcAddNameExpression(py_handle, n.encode()),
590-
handle=py_handle,
591-
)
592-
593-
options = self._options.as_bytes("nvrtc")
594-
handle_return(
595-
nvrtc.nvrtcCompileProgram(py_handle, len(options), options),
596-
handle=py_handle,
597-
)
598-
599-
size_func = getattr(nvrtc, f"nvrtcGet{target_type.upper()}Size")
600-
comp_func = getattr(nvrtc, f"nvrtcGet{target_type.upper()}")
601-
size = handle_return(size_func(py_handle), handle=py_handle)
602-
data = b" " * size
603-
handle_return(comp_func(py_handle, data), handle=py_handle)
604-
605-
symbol_mapping = {}
612+
name_bytes = n.encode() if isinstance(n, str) else n
613+
name_ptr = <const char*>name_bytes
614+
HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcAddNameExpression(prog, name_ptr))
615+
616+
# Build options array
617+
options_list = self._options.as_bytes("nvrtc", target_type)
618+
options_vec.resize(len(options_list))
619+
for i in range(len(options_list)):
620+
options_vec[i] = <const char*>(<bytes>options_list[i])
621+
622+
# Compile
623+
with nogil:
624+
HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcCompileProgram(prog, <int>options_vec.size(), options_vec.data()))
625+
626+
# Get compiled output based on target type
627+
if target_type == "ptx":
628+
HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetPTXSize(prog, &output_size))
629+
data = bytearray(output_size)
630+
data_ptr = <char*>(<bytearray>data)
631+
with nogil:
632+
HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetPTX(prog, data_ptr))
633+
elif target_type == "cubin":
634+
HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetCUBINSize(prog, &output_size))
635+
data = bytearray(output_size)
636+
data_ptr = <char*>(<bytearray>data)
637+
with nogil:
638+
HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetCUBIN(prog, data_ptr))
639+
else: # ltoir
640+
HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetLTOIRSize(prog, &output_size))
641+
data = bytearray(output_size)
642+
data_ptr = <char*>(<bytearray>data)
643+
with nogil:
644+
HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetLTOIR(prog, data_ptr))
645+
646+
# Get lowered names after compilation
606647
if name_expressions:
607648
for n in name_expressions:
608-
symbol_mapping[n] = handle_return(
609-
nvrtc.nvrtcGetLoweredName(py_handle, n.encode()), handle=py_handle
610-
)
649+
name_bytes = n.encode() if isinstance(n, str) else n
650+
name_ptr = <const char*>name_bytes
651+
HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetLoweredName(prog, name_ptr, &lowered_name))
652+
symbol_mapping[n] = lowered_name.decode() if lowered_name != NULL else None
611653
654+
# Get compilation log if requested
612655
if logs is not None:
613-
logsize = handle_return(nvrtc.nvrtcGetProgramLogSize(py_handle), handle=py_handle)
656+
HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetProgramLogSize(prog, &logsize))
614657
if logsize > 1:
615-
log = b" " * logsize
616-
handle_return(nvrtc.nvrtcGetProgramLog(py_handle, log), handle=py_handle)
658+
log = bytearray(logsize)
659+
data_ptr = <char*>(<bytearray>log)
660+
with nogil:
661+
HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetProgramLog(prog, data_ptr))
617662
logs.write(log.decode("utf-8", errors="backslashreplace"))
618663
619-
return ObjectCode._init(data, target_type, symbol_mapping=symbol_mapping, name=self._options.name)
664+
return ObjectCode._init(bytes(data), target_type, symbol_mapping=symbol_mapping, name=self._options.name)
620665
621666
622667
cdef object Program_compile_nvvm(Program self, str target_type, object logs):
623-
"""Compile using NVVM backend."""
624-
if target_type not in ("ptx", "ltoir"):
625-
raise ValueError(f'NVVM backend only supports target_type="ptx", "ltoir", got "{target_type}"')
626-
627-
# TODO: flip to True when NVIDIA/cuda-python#1354 is resolved and CUDA 12 is dropped
628-
nvvm_options = self._options._prepare_nvvm_options(as_bytes=False)
629-
if target_type == "ltoir" and "-gen-lto" not in nvvm_options:
630-
nvvm_options.append("-gen-lto")
631-
632-
nvvm = _get_nvvm_module()
633-
# NVVM uses raw integers for handles
634-
py_handle = as_intptr(self._h_nvvm)
635-
636-
try:
637-
nvvm.verify_program(py_handle, len(nvvm_options), nvvm_options)
638-
nvvm.compile_program(py_handle, len(nvvm_options), nvvm_options)
639-
except Exception as e:
640-
# Capture NVVM program log on error
641-
error_log = ""
642-
try:
643-
logsize = nvvm.get_program_log_size(py_handle)
644-
if logsize > 1:
645-
log = bytearray(logsize)
646-
nvvm.get_program_log(py_handle, log)
647-
error_log = log.decode("utf-8", errors="backslashreplace")
648-
except Exception:
649-
pass
650-
e.args = (e.args[0] + (f"\nNVVM program log: {error_log}" if error_log else ""), *e.args[1:])
651-
raise
652-
653-
size = nvvm.get_compiled_result_size(py_handle)
654-
data = bytearray(size)
655-
nvvm.get_compiled_result(py_handle, data)
656-
668+
"""Compile using NVVM backend and return ObjectCode."""
669+
cdef cynvvm.nvvmProgram prog = as_cu(self._h_nvvm)
670+
cdef size_t output_size = 0
671+
cdef size_t logsize = 0
672+
cdef vector[const char*] options_vec
673+
cdef char* data_ptr = NULL
674+
675+
# Build options array
676+
options_list = self._options.as_bytes("nvvm", target_type)
677+
options_vec.resize(len(options_list))
678+
for i in range(len(options_list)):
679+
options_vec[i] = <const char*>(<bytes>options_list[i])
680+
681+
# Compile
682+
with nogil:
683+
HANDLE_RETURN_NVVM(prog, cynvvm.nvvmVerifyProgram(prog, <int>options_vec.size(), options_vec.data()))
684+
HANDLE_RETURN_NVVM(prog, cynvvm.nvvmCompileProgram(prog, <int>options_vec.size(), options_vec.data()))
685+
686+
# Get compiled result
687+
HANDLE_RETURN_NVVM(prog, cynvvm.nvvmGetCompiledResultSize(prog, &output_size))
688+
data = bytearray(output_size)
689+
data_ptr = <char*>(<bytearray>data)
690+
with nogil:
691+
HANDLE_RETURN_NVVM(prog, cynvvm.nvvmGetCompiledResult(prog, data_ptr))
692+
693+
# Get compilation log if requested
657694
if logs is not None:
658-
logsize = nvvm.get_program_log_size(py_handle)
695+
HANDLE_RETURN_NVVM(prog, cynvvm.nvvmGetProgramLogSize(prog, &logsize))
659696
if logsize > 1:
660697
log = bytearray(logsize)
661-
nvvm.get_program_log(py_handle, log)
698+
data_ptr = <char*>(<bytearray>log)
699+
with nogil:
700+
HANDLE_RETURN_NVVM(prog, cynvvm.nvvmGetProgramLog(prog, data_ptr))
662701
logs.write(log.decode("utf-8", errors="backslashreplace"))
663702
664-
return ObjectCode._init(data, target_type, name=self._options.name)
703+
return ObjectCode._init(bytes(data), target_type, name=self._options.name)
704+
705+
# Supported target types per backend
706+
cdef dict SUPPORTED_TARGETS = {
707+
"NVRTC": ("ptx", "cubin", "ltoir"),
708+
"NVVM": ("ptx", "ltoir"),
709+
"nvJitLink": ("cubin", "ptx"),
710+
"driver": ("cubin", "ptx"),
711+
}
665712
666713
667714
cdef object Program_compile(Program self, str target_type, object name_expressions, object logs):
668715
"""Compile the program to the specified target type."""
669-
supported_target_types = ("ptx", "cubin", "ltoir")
670-
if target_type not in supported_target_types:
671-
raise ValueError(f'Unsupported target_type="{target_type}" ({supported_target_types=})')
716+
# Validate target_type for this backend
717+
supported = SUPPORTED_TARGETS.get(self._backend)
718+
if supported is None:
719+
raise ValueError(f'Unknown backend="{self._backend}"')
720+
if target_type not in supported:
721+
raise ValueError(
722+
f'Unsupported target_type="{target_type}" for {self._backend} '
723+
f'(supported: {", ".join(repr(t) for t in supported)})'
724+
)
672725
673726
if self._backend == "NVRTC":
727+
if target_type == "ptx" and not _can_load_generated_ptx():
728+
warn(
729+
"The CUDA driver version is older than the backend version. "
730+
"The generated ptx will not be loadable by the current driver.",
731+
stacklevel=2,
732+
category=RuntimeWarning,
733+
)
674734
return Program_compile_nvrtc(self, target_type, name_expressions, logs)
735+
675736
elif self._backend == "NVVM":
676737
return Program_compile_nvvm(self, target_type, logs)
677738
678-
# Linker backend (PTX code type)
679-
supported_backends = ("nvJitLink", "driver")
680-
if self._backend not in supported_backends:
681-
raise ValueError(f'Unsupported backend="{self._backend}" ({supported_backends=})')
682-
return self._linker.link(target_type)
739+
else:
740+
return self._linker.link(target_type)
683741
684742
685743
cdef inline list _prepare_nvrtc_options_impl(object opts):

0 commit comments

Comments
 (0)