Skip to content

Commit 8e7bbac

Browse files
committed
Fix missing binding version checks alongside driver version checks
Two source locations and one test helper only checked the driver version when gating features that also require the corresponding cuda-bindings version. When bindings are older than the driver, the driver check passes but the binding attribute/symbol is missing, causing AttributeError or similar runtime failures. - graph/_subclasses.pyx: _check_node_get_params() now also checks binding_version() >= (13, 2, 0) - _module.pyx: _get_arguments_info() now also checks cy_binding_version() >= (12, 4, 0) - tests/graph/test_graph_definition.py: _driver_has_node_get_params() renamed to _has_node_get_params() and checks both versions Closes #2052
1 parent 9cc3420 commit 8e7bbac

3 files changed

Lines changed: 14 additions & 7 deletions

File tree

cuda_core/cuda/core/_module.pyx

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ from cuda.core._utils.clear_error_support import (
3232
raise_code_path_meant_to_be_unreachable,
3333
)
3434
from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
35-
from cuda.core._utils.version cimport cy_driver_version
35+
from cuda.core._utils.version cimport cy_binding_version, cy_driver_version
3636
from cuda.core._utils.cuda_utils import driver
3737
from cuda.bindings cimport cydriver
3838

@@ -463,6 +463,11 @@ cdef class Kernel:
463463
"Driver version 12.4 or newer is required for this function. "
464464
f"Using driver version {'.'.join(map(str, cy_driver_version()))}"
465465
)
466+
if cy_binding_version() < (12, 4, 0):
467+
raise NotImplementedError(
468+
"cuda.bindings 12.4 or newer is required for this function. "
469+
f"Using binding version {'.'.join(map(str, cy_binding_version()))}"
470+
)
466471
cdef size_t arg_pos = 0
467472
cdef list param_info_data = []
468473
cdef cydriver.CUkernel cu_kernel = as_cu(self._h_kernel)

cuda_core/cuda/core/graph/_subclasses.pyx

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,10 @@ cdef bint _version_checked = False
6060
cdef bint _check_node_get_params():
6161
global _has_cuGraphNodeGetParams, _version_checked
6262
if not _version_checked:
63-
from cuda.core._utils.version import driver_version
64-
_has_cuGraphNodeGetParams = driver_version() >= (13, 2, 0)
63+
from cuda.core._utils.version import binding_version, driver_version
64+
_has_cuGraphNodeGetParams = (
65+
driver_version() >= (13, 2, 0) and binding_version() >= (13, 2, 0)
66+
)
6567
_version_checked = True
6668
return _has_cuGraphNodeGetParams
6769

cuda_core/tests/graph/test_graph_definition.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,13 @@ def _skip_if_no_managed_mempool():
4848
pytest.skip("Device does not support managed memory pool operations")
4949

5050

51-
def _driver_has_node_get_params():
52-
from cuda.core._utils.version import driver_version
51+
def _has_node_get_params():
52+
from cuda.core._utils.version import binding_version, driver_version
5353

54-
return driver_version() >= (13, 2, 0)
54+
return driver_version() >= (13, 2, 0) and binding_version() >= (13, 2, 0)
5555

5656

57-
_HAS_NODE_GET_PARAMS = _driver_has_node_get_params()
57+
_HAS_NODE_GET_PARAMS = _has_node_get_params()
5858

5959

6060
def _bindings_major_version():

0 commit comments

Comments
 (0)