File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -33,8 +33,26 @@ def _checkpoint_available():
3333 try :
3434 checkpoint ._get_driver ()
3535 return True
36- except RuntimeError :
37- return False
36+ except RuntimeError as exc :
37+ if _checkpoint_unavailable_can_skip (str (exc )):
38+ return False
39+ raise
40+
41+
42+ def _checkpoint_unavailable_can_skip (message ):
43+ if message .startswith (
44+ (
45+ "CUDA checkpointing is not supported by the installed NVIDIA driver." ,
46+ "CUDA checkpointing requires cuda.bindings with CUDA checkpoint API support. Found cuda.bindings " ,
47+ )
48+ ):
49+ return True
50+
51+ return (
52+ checkpoint ._binding_version ()[0 ] == 12
53+ and message
54+ == "CUDA checkpointing requires cuda.bindings with CUDA checkpoint API support. Missing: CUcheckpointGpuPair"
55+ )
3856
3957
4058needs_checkpoint = pytest .mark .skipif (
You can’t perform that action at this time.
0 commit comments