@@ -14,16 +14,16 @@ from warnings import warn
1414
1515from cuda.bindings import driver, nvrtc
1616
17- from libc.stdint cimport intptr_t
17+ from libcpp.vector cimport vector
1818
1919from ._resource_handles cimport (
20- NvrtcProgramHandle,
21- NvvmProgramHandle,
22- as_intptr,
20+ as_cu,
21+ as_py,
2322 create_nvrtc_program_handle,
2423 create_nvvm_program_handle,
2524)
2625from cuda.bindings cimport cynvrtc, cynvvm
26+ from cuda.core._utils.cuda_utils cimport HANDLE_RETURN_NVRTC, HANDLE_RETURN_NVVM
2727from cuda.core._device import Device
2828from cuda.core._linker import Linker, LinkerHandleT, LinkerOptions
2929from cuda.core._module import ObjectCode
@@ -40,8 +40,11 @@ from cuda.core._utils.cuda_utils import (
4040
4141__all__ = [" Program" , " ProgramOptions" ]
4242
43- ProgramHandleT = nvrtc.nvrtcProgram | LinkerHandleT
44- """ Type alias for program handle types across different backends."""
43+ ProgramHandleT = nvrtc.nvrtcProgram | int | LinkerHandleT
44+ """ Type alias for program handle types across different backends.
45+
46+ The ``int`` type covers NVVM handles, which don't have a wrapper class.
47+ """
4548
4649
4750# =============================================================================
@@ -76,8 +79,8 @@ cdef class Program:
7679 if self ._linker:
7780 self ._linker.close()
7881 # Reset handles - the C++ shared_ptr destructor handles cleanup
79- self ._h_nvrtc = NvrtcProgramHandle ()
80- self ._h_nvvm = NvvmProgramHandle ()
82+ self ._h_nvrtc.reset ()
83+ self ._h_nvvm.reset ()
8184
8285 def compile (
8386 self , target_type: str , name_expressions: tuple | list = (), logs = None
@@ -120,14 +123,11 @@ cdef class Program:
120123 handle , call ``int(Program.handle )``.
121124 """
122125 if self._backend == "NVRTC":
123- ptr = as_intptr(self ._h_nvrtc)
124- return nvrtc.nvrtcProgram(ptr ) if ptr else None
126+ return as_py(self._h_nvrtc )
125127 elif self._backend == "NVVM":
126- # NVVM uses raw integers for handles , not wrapper classes
127- ptr = as_intptr(self ._h_nvvm)
128- return ptr if ptr else None
128+ return as_py(self._h_nvvm ) # returns int (NVVM uses raw integers )
129129 else:
130- return self._linker.handle if self._linker else None
130+ return self._linker.handle
131131
132132 @staticmethod
133133 def driver_can_load_nvrtc_ptx_output() -> bool:
@@ -392,7 +392,7 @@ class ProgramOptions:
392392 def _prepare_nvvm_options(self, as_bytes: bool = True) -> list[bytes] | list[str]:
393393 return _prepare_nvvm_options_impl(self, as_bytes)
394394
395- def as_bytes(self, backend: str) -> list[bytes]:
395+ def as_bytes(self, backend: str, target_type: str | None = None ) -> list[bytes]:
396396 """ Convert program options to bytes format for the specified backend.
397397
398398 This method transforms the program options into a format suitable for the
@@ -403,6 +403,9 @@ class ProgramOptions:
403403 ----------
404404 backend : str
405405 The compiler backend to prepare options for . Must be either " nvrtc" or " nvvm" .
406+ target_type : str , optional
407+ The compilation target type (e.g., " ptx" , " cubin" , " ltoir" ). Some backends
408+ require additional options based on the target type .
406409
407410 Returns
408411 -------
@@ -425,7 +428,10 @@ class ProgramOptions:
425428 if backend == "nvrtc":
426429 return self._prepare_nvrtc_options()
427430 elif backend == "nvvm":
428- return self._prepare_nvvm_options(as_bytes=True)
431+ options = self._prepare_nvvm_options(as_bytes=True)
432+ if target_type == "ltoir" and b"-gen-lto" not in options:
433+ options.append(b"-gen-lto")
434+ return options
429435 else:
430436 raise ValueError(f"Unknown backend '{backend}'. Must be one of: 'nvrtc', 'nvvm'")
431437
@@ -530,15 +536,27 @@ cdef inline object _translate_program_options(object options):
530536
531537cdef inline int Program_init(Program self, object code, str code_type, object options) except -1:
532538 """ Initialize a Program instance."""
539+ cdef cynvrtc.nvrtcProgram nvrtc_prog
540+ cdef cynvvm.nvvmProgram nvvm_prog
541+ cdef bytes code_bytes
542+ cdef const char* code_ptr
543+ cdef const char* name_ptr
544+ cdef size_t code_len
545+
533546 self._options = options = check_or_create_options(ProgramOptions, options, "Program options")
534547 code_type = code_type.lower()
535548
536549 if code_type == "c++":
537550 assert_type(code, str)
538551 # TODO: support pre-loaded headers & include names
539- # TODO: allow tuples once NVIDIA/cuda-python#72 is resolved
540- py_prog = handle_return(nvrtc.nvrtcCreateProgram(code.encode(), options._name, 0, [], []))
541- self._h_nvrtc = create_nvrtc_program_handle(<cynvrtc.nvrtcProgram><intptr_t>int(py_prog))
552+ code_bytes = code.encode()
553+ code_ptr = <const char*>code_bytes
554+ name_ptr = <const char*>options._name
555+
556+ with nogil:
557+ HANDLE_RETURN_NVRTC(NULL, cynvrtc.nvrtcCreateProgram(
558+ &nvrtc_prog, code_ptr, name_ptr, 0, NULL, NULL))
559+ self._h_nvrtc = create_nvrtc_program_handle(nvrtc_prog)
542560 self._backend = "NVRTC"
543561 self._linker = None
544562
@@ -550,15 +568,21 @@ cdef inline int Program_init(Program self, object code, str code_type, object op
550568 self._backend = self._linker.backend
551569
552570 elif code_type == "nvvm":
571+ _get_nvvm_module() # Validate NVVM availability
553572 if isinstance(code, str):
554573 code = code.encode("utf-8")
555574 elif not isinstance(code, (bytes, bytearray)):
556575 raise TypeError("NVVM IR code must be provided as str, bytes, or bytearray")
557576
558- nvvm = _get_nvvm_module()
559- py_prog = nvvm.create_program()
560- nvvm.add_module_to_program(py_prog, code, len(code), options._name.decode())
561- self._h_nvvm = create_nvvm_program_handle(<cynvvm.nvvmProgram><intptr_t>int(py_prog))
577+ code_ptr = <const char*>(<bytes>code)
578+ name_ptr = <const char*>options._name
579+ code_len = len(code)
580+
581+ with nogil:
582+ HANDLE_RETURN_NVVM(NULL, cynvvm.nvvmCreateProgram(&nvvm_prog))
583+ self._h_nvvm = create_nvvm_program_handle(nvvm_prog) # RAII from here
584+ with nogil:
585+ HANDLE_RETURN_NVVM(nvvm_prog, cynvvm.nvvmAddModuleToProgram(nvvm_prog, code_ptr, code_len, name_ptr))
562586 self._backend = "NVVM"
563587 self._linker = None
564588
@@ -571,115 +595,149 @@ cdef inline int Program_init(Program self, object code, str code_type, object op
571595
572596
573597cdef object Program_compile_nvrtc(Program self, str target_type, object name_expressions, object logs):
574- """ Compile using NVRTC backend."""
575- if target_type == "ptx" and not _can_load_generated_ptx():
576- warn(
577- "The CUDA driver version is older than the backend version. "
578- "The generated ptx will not be loadable by the current driver.",
579- stacklevel=2,
580- category=RuntimeWarning,
581- )
582-
583- # Create Python wrapper for handle_return calls that need it
584- py_handle = nvrtc.nvrtcProgram(as_intptr(self._h_nvrtc))
585-
598+ """ Compile using NVRTC backend and return ObjectCode ."""
599+ cdef cynvrtc.nvrtcProgram prog = as_cu(self._h_nvrtc)
600+ cdef size_t output_size = 0
601+ cdef size_t logsize = 0
602+ cdef vector[const char*] options_vec
603+ cdef char* data_ptr = NULL
604+ cdef bytes name_bytes
605+ cdef const char* name_ptr = NULL
606+ cdef const char* lowered_name = NULL
607+ cdef dict symbol_mapping = {}
608+
609+ # Add name expressions before compilation
586610 if name_expressions:
587611 for n in name_expressions:
588- handle_return(
589- nvrtc.nvrtcAddNameExpression(py_handle, n.encode()),
590- handle=py_handle,
591- )
592-
593- options = self._options.as_bytes("nvrtc")
594- handle_return(
595- nvrtc.nvrtcCompileProgram(py_handle, len(options), options),
596- handle=py_handle,
597- )
598-
599- size_func = getattr(nvrtc, f"nvrtcGet{target_type.upper()}Size")
600- comp_func = getattr(nvrtc, f"nvrtcGet{target_type.upper()}")
601- size = handle_return(size_func(py_handle), handle=py_handle)
602- data = b" " * size
603- handle_return(comp_func(py_handle, data), handle=py_handle)
604-
605- symbol_mapping = {}
612+ name_bytes = n.encode() if isinstance(n, str) else n
613+ name_ptr = <const char*>name_bytes
614+ HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcAddNameExpression(prog, name_ptr))
615+
616+ # Build options array
617+ options_list = self._options.as_bytes("nvrtc", target_type)
618+ options_vec.resize(len(options_list))
619+ for i in range(len(options_list)):
620+ options_vec[i] = <const char*>(<bytes>options_list[i])
621+
622+ # Compile
623+ with nogil:
624+ HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcCompileProgram(prog, <int>options_vec.size(), options_vec.data()))
625+
626+ # Get compiled output based on target type
627+ if target_type == "ptx":
628+ HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetPTXSize(prog, &output_size))
629+ data = bytearray(output_size)
630+ data_ptr = <char*>(<bytearray>data)
631+ with nogil:
632+ HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetPTX(prog, data_ptr))
633+ elif target_type == "cubin":
634+ HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetCUBINSize(prog, &output_size))
635+ data = bytearray(output_size)
636+ data_ptr = <char*>(<bytearray>data)
637+ with nogil:
638+ HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetCUBIN(prog, data_ptr))
639+ else: # ltoir
640+ HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetLTOIRSize(prog, &output_size))
641+ data = bytearray(output_size)
642+ data_ptr = <char*>(<bytearray>data)
643+ with nogil:
644+ HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetLTOIR(prog, data_ptr))
645+
646+ # Get lowered names after compilation
606647 if name_expressions:
607648 for n in name_expressions:
608- symbol_mapping[n] = handle_return(
609- nvrtc.nvrtcGetLoweredName(py_handle, n.encode()), handle=py_handle
610- )
649+ name_bytes = n.encode() if isinstance(n, str) else n
650+ name_ptr = <const char*>name_bytes
651+ HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetLoweredName(prog, name_ptr, &lowered_name))
652+ symbol_mapping[n] = lowered_name.decode() if lowered_name != NULL else None
611653
654+ # Get compilation log if requested
612655 if logs is not None:
613- logsize = handle_return(nvrtc .nvrtcGetProgramLogSize(py_handle), handle=py_handle )
656+ HANDLE_RETURN_NVRTC(prog, cynvrtc .nvrtcGetProgramLogSize(prog, &logsize) )
614657 if logsize > 1:
615- log = b" " * logsize
616- handle_return(nvrtc.nvrtcGetProgramLog(py_handle, log), handle=py_handle)
658+ log = bytearray(logsize)
659+ data_ptr = <char*>(<bytearray>log)
660+ with nogil:
661+ HANDLE_RETURN_NVRTC(prog, cynvrtc.nvrtcGetProgramLog(prog, data_ptr))
617662 logs.write(log.decode("utf-8", errors="backslashreplace"))
618663
619- return ObjectCode._init(data, target_type, symbol_mapping=symbol_mapping, name=self._options.name)
664+ return ObjectCode._init(bytes( data) , target_type, symbol_mapping=symbol_mapping, name=self._options.name)
620665
621666
622667cdef object Program_compile_nvvm(Program self, str target_type, object logs):
623- """ Compile using NVVM backend."""
624- if target_type not in ("ptx", "ltoir"):
625- raise ValueError(f'NVVM backend only supports target_type="ptx", "ltoir", got "{target_type}"')
626-
627- # TODO: flip to True when NVIDIA/cuda-python#1354 is resolved and CUDA 12 is dropped
628- nvvm_options = self._options._prepare_nvvm_options(as_bytes=False)
629- if target_type == "ltoir" and "-gen-lto" not in nvvm_options:
630- nvvm_options.append("-gen-lto")
631-
632- nvvm = _get_nvvm_module()
633- # NVVM uses raw integers for handles
634- py_handle = as_intptr(self._h_nvvm)
635-
636- try:
637- nvvm.verify_program(py_handle, len(nvvm_options), nvvm_options)
638- nvvm.compile_program(py_handle, len(nvvm_options), nvvm_options)
639- except Exception as e:
640- # Capture NVVM program log on error
641- error_log = ""
642- try:
643- logsize = nvvm.get_program_log_size(py_handle)
644- if logsize > 1:
645- log = bytearray(logsize)
646- nvvm.get_program_log(py_handle, log)
647- error_log = log.decode("utf-8", errors="backslashreplace")
648- except Exception:
649- pass
650- e.args = (e.args[0] + (f"\n NVVM program log: {error_log}" if error_log else ""), *e.args[1:])
651- raise
652-
653- size = nvvm.get_compiled_result_size(py_handle)
654- data = bytearray(size)
655- nvvm.get_compiled_result(py_handle, data)
656-
668+ """ Compile using NVVM backend and return ObjectCode."""
669+ cdef cynvvm.nvvmProgram prog = as_cu(self._h_nvvm)
670+ cdef size_t output_size = 0
671+ cdef size_t logsize = 0
672+ cdef vector[const char*] options_vec
673+ cdef char* data_ptr = NULL
674+
675+ # Build options array
676+ options_list = self._options.as_bytes("nvvm", target_type)
677+ options_vec.resize(len(options_list))
678+ for i in range(len(options_list)):
679+ options_vec[i] = <const char*>(<bytes>options_list[i])
680+
681+ # Compile
682+ with nogil:
683+ HANDLE_RETURN_NVVM(prog, cynvvm.nvvmVerifyProgram(prog, <int>options_vec.size(), options_vec.data()))
684+ HANDLE_RETURN_NVVM(prog, cynvvm.nvvmCompileProgram(prog, <int>options_vec.size(), options_vec.data()))
685+
686+ # Get compiled result
687+ HANDLE_RETURN_NVVM(prog, cynvvm.nvvmGetCompiledResultSize(prog, &output_size))
688+ data = bytearray(output_size)
689+ data_ptr = <char*>(<bytearray>data)
690+ with nogil:
691+ HANDLE_RETURN_NVVM(prog, cynvvm.nvvmGetCompiledResult(prog, data_ptr))
692+
693+ # Get compilation log if requested
657694 if logs is not None:
658- logsize = nvvm.get_program_log_size(py_handle )
695+ HANDLE_RETURN_NVVM(prog, cynvvm.nvvmGetProgramLogSize(prog, &logsize) )
659696 if logsize > 1:
660697 log = bytearray(logsize)
661- nvvm.get_program_log(py_handle, log)
698+ data_ptr = <char*>(<bytearray>log)
699+ with nogil:
700+ HANDLE_RETURN_NVVM(prog, cynvvm.nvvmGetProgramLog(prog, data_ptr))
662701 logs.write(log.decode("utf-8", errors="backslashreplace"))
663702
664- return ObjectCode._init(data, target_type, name=self._options.name)
703+ return ObjectCode._init(bytes(data), target_type, name=self._options.name)
704+
705+ # Supported target types per backend
706+ cdef dict SUPPORTED_TARGETS = {
707+ "NVRTC": ("ptx", "cubin", "ltoir"),
708+ "NVVM": ("ptx", "ltoir"),
709+ "nvJitLink": ("cubin", "ptx"),
710+ "driver": ("cubin", "ptx"),
711+ }
665712
666713
667714cdef object Program_compile(Program self, str target_type, object name_expressions, object logs):
668715 """ Compile the program to the specified target type ."""
669- supported_target_types = ("ptx", "cubin", "ltoir")
670- if target_type not in supported_target_types:
671- raise ValueError(f'Unsupported target_type="{target_type}" ({supported_target_types=})')
716+ # Validate target_type for this backend
717+ supported = SUPPORTED_TARGETS.get(self._backend)
718+ if supported is None:
719+ raise ValueError(f'Unknown backend="{self._backend}"')
720+ if target_type not in supported:
721+ raise ValueError(
722+ f'Unsupported target_type="{target_type}" for {self._backend} '
723+ f'(supported: {", ".join(repr(t) for t in supported)})'
724+ )
672725
673726 if self._backend == "NVRTC":
727+ if target_type == "ptx" and not _can_load_generated_ptx():
728+ warn(
729+ "The CUDA driver version is older than the backend version. "
730+ "The generated ptx will not be loadable by the current driver.",
731+ stacklevel=2,
732+ category=RuntimeWarning,
733+ )
674734 return Program_compile_nvrtc(self, target_type, name_expressions, logs)
735+
675736 elif self._backend == "NVVM":
676737 return Program_compile_nvvm(self, target_type, logs)
677738
678- # Linker backend (PTX code type)
679- supported_backends = ("nvJitLink", "driver")
680- if self._backend not in supported_backends:
681- raise ValueError(f'Unsupported backend="{self._backend}" ({supported_backends=})')
682- return self._linker.link(target_type)
739+ else:
740+ return self._linker.link(target_type)
683741
684742
685743cdef inline list _prepare_nvrtc_options_impl(object opts):
0 commit comments