diff --git a/cuda_bindings/tests/test_nvjitlink.py b/cuda_bindings/tests/test_nvjitlink.py index 5c6ca98ea73..ff403eb0b2b 100644 --- a/cuda_bindings/tests/test_nvjitlink.py +++ b/cuda_bindings/tests/test_nvjitlink.py @@ -66,12 +66,13 @@ def check_nvjitlink_usable(): def get_dummy_ltoir(): def CHECK_NVRTC(err): if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: - raise RuntimeError(f"Nvrtc Error: {err}") + raise RuntimeError(repr(err)) empty_cplusplus_kernel = "__global__ void A() {}" err, program_handle = nvrtc.nvrtcCreateProgram(empty_cplusplus_kernel.encode(), b"", 0, [], []) CHECK_NVRTC(err) - nvrtc.nvrtcCompileProgram(program_handle, 1, [b"-dlto"]) + err = nvrtc.nvrtcCompileProgram(program_handle, 1, [b"-dlto"])[0] + CHECK_NVRTC(err) err, size = nvrtc.nvrtcGetLTOIRSize(program_handle) CHECK_NVRTC(err) empty_kernel_ltoir = b" " * size