Skip to content

Commit 82c35fd

Browse files
committed
cuda.core: prefer binding names for runtime errors
Use the generated runtime error enum as the name source for known CUDA Runtime errors so error messages remain stable when the runtime name table differs from the installed bindings. Made-with: Cursor
1 parent aac5bf5 commit 82c35fd

2 files changed

Lines changed: 26 additions & 1 deletion

File tree

cuda_core/cuda/core/_utils/cuda_utils.pyx

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,13 @@ cdef object _RUNTIME_SUCCESS = runtime.cudaError_t.cudaSuccess
149149
cdef object _NVRTC_SUCCESS = nvrtc.nvrtcResult.NVRTC_SUCCESS
150150

151151

152+
def _known_runtime_error_name(error):
153+
try:
154+
return runtime.cudaError_t(error).name
155+
except ValueError:
156+
return None
157+
158+
152159
cpdef inline int _check_driver_error(cydriver.CUresult error) except?-1 nogil:
153160
if error == cydriver.CUresult.CUDA_SUCCESS:
154161
return 0
@@ -174,7 +181,9 @@ cpdef inline int _check_runtime_error(error) except?-1:
174181
name_err, name = runtime.cudaGetErrorName(error)
175182
if name_err != _RUNTIME_SUCCESS:
176183
raise CUDAError(f"UNEXPECTED ERROR CODE: {error}")
177-
name = name.decode()
184+
# Windows hybrid cudart can lag the generated bindings' enum table.
185+
# Prefer the binding name for values the bindings know.
186+
name = _known_runtime_error_name(error) or name.decode()
178187
expl = RUNTIME_CUDA_ERROR_EXPLANATIONS.get(int(error))
179188
if expl is not None:
180189
raise CUDAError(f"{name}: {expl}")

cuda_core/tests/test_cuda_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,22 @@ def test_check_runtime_error_attaches_explanation():
179179
assert str(e.value) != f"{name.decode()}: {desc.decode()}"
180180

181181

182+
def test_check_runtime_error_uses_binding_name_for_known_runtime_error(monkeypatch):
183+
error = runtime.cudaError_t.cudaErrorInvalidValue
184+
runtime_name = b"runtime-provided-name"
185+
186+
def cuda_get_error_name(_error):
187+
return runtime.cudaError_t.cudaSuccess, runtime_name
188+
189+
monkeypatch.setattr(cuda_utils.runtime, "cudaGetErrorName", cuda_get_error_name)
190+
191+
with pytest.raises(cuda_utils.CUDAError) as e:
192+
cuda_utils._check_runtime_error(error)
193+
194+
assert str(e.value).startswith(f"{error.name}: ")
195+
assert runtime_name.decode() not in str(e.value)
196+
197+
182198
def test_precondition():
183199
def checker(*args, what=""):
184200
if args[0] < 0:

0 commit comments

Comments
 (0)