Skip to content

Commit 503bfac

Browse files
committed
cuda.core: simplify runtime error naming path
`_check_error()` only routes `runtime.cudaError_t` instances into `_check_runtime_error()`, so consulting `cudaGetErrorName()` and keeping a fallback for unknown values does not improve the normal `cuda.core` path. The Windows hybrid cudart issue is that the runtime name table can lag the generated enum table, so using `error.name` directly is both simpler and a better match for the values the code already has. With the runtime path now relying on enum members, the runtime-side tests no longer need to account for `UNEXPECTED ERROR CODE` in this loop or keep a separate monkeypatch test for avoiding the runtime name lookup. Made-with: Cursor
1 parent 82c35fd commit 503bfac

2 files changed

Lines changed: 4 additions & 38 deletions

File tree

cuda_core/cuda/core/_utils/cuda_utils.pyx

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -149,13 +149,6 @@ cdef object _RUNTIME_SUCCESS = runtime.cudaError_t.cudaSuccess
149149
cdef object _NVRTC_SUCCESS = nvrtc.nvrtcResult.NVRTC_SUCCESS
150150

151151

152-
def _known_runtime_error_name(error):
153-
try:
154-
return runtime.cudaError_t(error).name
155-
except ValueError:
156-
return None
157-
158-
159152
cpdef inline int _check_driver_error(cydriver.CUresult error) except?-1 nogil:
160153
if error == cydriver.CUresult.CUDA_SUCCESS:
161154
return 0
@@ -178,12 +171,9 @@ cpdef inline int _check_driver_error(cydriver.CUresult error) except?-1 nogil:
178171
cpdef inline int _check_runtime_error(error) except?-1:
179172
if error == _RUNTIME_SUCCESS:
180173
return 0
181-
name_err, name = runtime.cudaGetErrorName(error)
182-
if name_err != _RUNTIME_SUCCESS:
183-
raise CUDAError(f"UNEXPECTED ERROR CODE: {error}")
184-
# Windows hybrid cudart can lag the generated bindings' enum table.
185-
# Prefer the binding name for values the bindings know.
186-
name = _known_runtime_error_name(error) or name.decode()
174+
# `_check_error()` reaches this path only for `runtime.cudaError_t` values.
175+
# Use the enum name directly because Windows hybrid cudart can lag that table.
176+
name = error.name
187177
expl = RUNTIME_CUDA_ERROR_EXPLANATIONS.get(int(error))
188178
if expl is not None:
189179
raise CUDAError(f"{name}: {expl}")

cuda_core/tests/test_cuda_utils.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -48,22 +48,14 @@ def test_check_driver_error():
4848

4949

5050
def test_check_runtime_error():
51-
num_unexpected = 0
5251
for error in runtime.cudaError_t:
5352
if error == runtime.cudaError_t.cudaSuccess:
5453
assert cuda_utils._check_runtime_error(error) == 0
5554
else:
5655
with pytest.raises(cuda_utils.CUDAError) as e:
5756
cuda_utils._check_runtime_error(error)
5857
msg = str(e)
59-
if "UNEXPECTED ERROR CODE" in msg:
60-
num_unexpected += 1
61-
else:
62-
# Example repr(error): <cudaError_t.cudaErrorUnknown: 999>
63-
enum_name = repr(error).split(".", 1)[1].split(":", 1)[0]
64-
assert enum_name in msg
65-
# Smoke test: We don't want most to be unexpected.
66-
assert num_unexpected < len(driver.CUresult) * 0.5
58+
assert error.name in msg
6759

6860

6961
def test_driver_error_enum_has_non_empty_docstring():
@@ -179,22 +171,6 @@ def test_check_runtime_error_attaches_explanation():
179171
assert str(e.value) != f"{name.decode()}: {desc.decode()}"
180172

181173

182-
def test_check_runtime_error_uses_binding_name_for_known_runtime_error(monkeypatch):
183-
error = runtime.cudaError_t.cudaErrorInvalidValue
184-
runtime_name = b"runtime-provided-name"
185-
186-
def cuda_get_error_name(_error):
187-
return runtime.cudaError_t.cudaSuccess, runtime_name
188-
189-
monkeypatch.setattr(cuda_utils.runtime, "cudaGetErrorName", cuda_get_error_name)
190-
191-
with pytest.raises(cuda_utils.CUDAError) as e:
192-
cuda_utils._check_runtime_error(error)
193-
194-
assert str(e.value).startswith(f"{error.name}: ")
195-
assert runtime_name.decode() not in str(e.value)
196-
197-
198174
def test_precondition():
199175
def checker(*args, what=""):
200176
if args[0] < 0:

0 commit comments

Comments
 (0)