File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -149,6 +149,13 @@ cdef object _RUNTIME_SUCCESS = runtime.cudaError_t.cudaSuccess
149149cdef object _NVRTC_SUCCESS = nvrtc.nvrtcResult.NVRTC_SUCCESS
150150
151151
152+ def _known_runtime_error_name (error ):
153+ try :
154+ return runtime.cudaError_t(error).name
155+ except ValueError :
156+ return None
157+
158+
152159cpdef inline int _check_driver_error(cydriver.CUresult error) except ?- 1 nogil:
153160 if error == cydriver.CUresult.CUDA_SUCCESS:
154161 return 0
@@ -174,7 +181,9 @@ cpdef inline int _check_runtime_error(error) except?-1:
174181 name_err, name = runtime.cudaGetErrorName(error)
175182 if name_err != _RUNTIME_SUCCESS:
176183 raise CUDAError(f" UNEXPECTED ERROR CODE: {error}" )
177- name = name.decode()
184+ # Windows hybrid cudart can lag the generated bindings' enum table.
185+ # Prefer the binding name for values the bindings know.
186+ name = _known_runtime_error_name(error) or name.decode()
178187 expl = RUNTIME_CUDA_ERROR_EXPLANATIONS.get(int (error))
179188 if expl is not None :
180189 raise CUDAError(f" {name}: {expl}" )
Original file line number Diff line number Diff line change @@ -179,6 +179,22 @@ def test_check_runtime_error_attaches_explanation():
179179 assert str (e .value ) != f"{ name .decode ()} : { desc .decode ()} "
180180
181181
182+ def test_check_runtime_error_uses_binding_name_for_known_runtime_error (monkeypatch ):
183+ error = runtime .cudaError_t .cudaErrorInvalidValue
184+ runtime_name = b"runtime-provided-name"
185+
186+ def cuda_get_error_name (_error ):
187+ return runtime .cudaError_t .cudaSuccess , runtime_name
188+
189+ monkeypatch .setattr (cuda_utils .runtime , "cudaGetErrorName" , cuda_get_error_name )
190+
191+ with pytest .raises (cuda_utils .CUDAError ) as e :
192+ cuda_utils ._check_runtime_error (error )
193+
194+ assert str (e .value ).startswith (f"{ error .name } : " )
195+ assert runtime_name .decode () not in str (e .value )
196+
197+
182198def test_precondition ():
183199 def checker (* args , what = "" ):
184200 if args [0 ] < 0 :
You can’t perform that action at this time.
0 commit comments