Skip to content

Commit 6884723

Browse files
committed
Handle CUDA checkpoint restore arg layouts
1 parent a4e89b5 commit 6884723

8 files changed

Lines changed: 120 additions & 54 deletions

File tree

cuda_bindings/build_hooks.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,32 @@ def __init__(self, name, members):
7878
self._name = name
7979
self._member_names = []
8080
self._member_types = []
81+
self._member_declarators = []
8182
for var_name, var_type, _ in members:
82-
var_type = var_type[0]
83-
var_type = var_type.removeprefix("struct ")
84-
var_type = var_type.removeprefix("union ")
83+
base_type = var_type[0]
84+
base_type = base_type.removeprefix("struct ")
85+
base_type = base_type.removeprefix("union ")
8586

8687
self._member_names += [var_name]
87-
self._member_types += [var_type]
88+
self._member_types += [base_type]
89+
self._member_declarators += [tuple(var_type[1:])]
90+
91+
def member_type(self, member_name):
92+
try:
93+
return self._member_types[self._member_names.index(member_name)]
94+
except ValueError:
95+
return None
96+
97+
def member_array_length(self, member_name):
98+
try:
99+
declarators = self._member_declarators[self._member_names.index(member_name)]
100+
except ValueError:
101+
return None
102+
103+
for declarator in declarators:
104+
if isinstance(declarator, list) and len(declarator) == 1:
105+
return declarator[0]
106+
return None
88107

89108
def discoverMembers(self, memberDict, prefix, seen=None):
90109
if seen is None:
@@ -161,6 +180,9 @@ def _parse_headers(header_dict, include_path_list, parser_caching):
161180
# Since we only support 64 bit architectures, we can inline the sizeof(T*) to 8 and then compute the
162181
# result in Python. The arithmetic expression is preserved to help with clarity and understanding
163182
r"char reserved\[52 - sizeof\(CUcheckpointGpuPair \*\)\];": rf"char reserved[{52 - 8}];",
183+
r"char reserved\[64 - sizeof\(CUcheckpointGpuPair \*\) - sizeof\(unsigned int\)\];": (
184+
rf"char reserved[{64 - 8 - 4}];"
185+
),
164186
}
165187

166188
print(f'Parsing headers in "{include_path_list}" (Caching = {parser_caching})', flush=True)
@@ -310,6 +332,13 @@ def _build_cuda_bindings(strip=False):
310332
found_types, found_functions, found_values, found_struct, struct_list = _parse_headers(
311333
header_dict, include_path_list, parser_caching
312334
)
335+
struct_field_types = {}
336+
struct_field_array_lengths = {}
337+
for struct_name, struct in struct_list.items():
338+
for member_name in struct._member_names:
339+
key = f"{struct_name}.{member_name}"
340+
struct_field_types[key] = struct.member_type(member_name)
341+
struct_field_array_lengths[key] = struct.member_array_length(member_name)
313342

314343
# Generate code from .in templates
315344
path_list = [
@@ -332,6 +361,8 @@ def _build_cuda_bindings(strip=False):
332361
"found_values": found_values,
333362
"found_struct": found_struct,
334363
"struct_list": struct_list,
364+
"struct_field_types": struct_field_types,
365+
"struct_field_array_lengths": struct_field_array_lengths,
335366
"os": os,
336367
"sys": sys,
337368
"platform": platform,

cuda_bindings/cuda/bindings/_bindings/cydriver.pxd.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
33

4-
# This code was automatically generated with version 12.9.0, generator version 49a8141. Do not modify it directly.
4+
# This code was automatically generated with version 12.9.0, generator version 0.3.1.dev1711+g875fec45. Do not modify it directly.
55
from cuda.bindings.cydriver cimport *
66

77
{{if 'cuGetErrorString' in found_functions}}

cuda_bindings/cuda/bindings/_bindings/cydriver.pyx.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
33

4-
# This code was automatically generated with version 12.9.0, generator version 0.3.1.dev1630+gadce055ea.d20260422. Do not modify it directly.
4+
# This code was automatically generated with version 12.9.0, generator version 0.3.1.dev1711+g875fec45. Do not modify it directly.
55
{{if 'Windows' == platform.system()}}
66
import os
77
cimport cuda.bindings._lib.windll as windll

cuda_bindings/cuda/bindings/cydriver.pxd.in

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
33

4-
# This code was automatically generated with version 12.9.0, generator version 49a8141. Do not modify it directly.
4+
# This code was automatically generated with version 12.9.0, generator version 0.3.1.dev1711+g875fec45. Do not modify it directly.
55

66
from libc.stdint cimport uint32_t, uint64_t
77

@@ -2311,7 +2311,12 @@ cdef extern from "cuda.h":
23112311
ctypedef CUcheckpointCheckpointArgs_st CUcheckpointCheckpointArgs
23122312

23132313
cdef struct CUcheckpointRestoreArgs_st:
2314-
cuuint64_t reserved[8]
2314+
{{if struct_field_types.get('CUcheckpointRestoreArgs_st.reserved') == 'char'}}
2315+
char reserved[{{struct_field_array_lengths['CUcheckpointRestoreArgs_st.reserved']}}]
2316+
{{endif}}
2317+
{{if struct_field_types.get('CUcheckpointRestoreArgs_st.reserved') == 'cuuint64_t'}}
2318+
cuuint64_t reserved[{{struct_field_array_lengths['CUcheckpointRestoreArgs_st.reserved']}}]
2319+
{{endif}}
23152320

23162321
ctypedef CUcheckpointRestoreArgs_st CUcheckpointRestoreArgs
23172322

cuda_bindings/cuda/bindings/cydriver.pyx.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
33

4-
# This code was automatically generated with version 12.9.0, generator version 49a8141. Do not modify it directly.
4+
# This code was automatically generated with version 12.9.0, generator version 0.3.1.dev1711+g875fec45. Do not modify it directly.
55
cimport cuda.bindings._bindings.cydriver as cydriver
66

77
{{if 'cuGetErrorString' in found_functions}}

cuda_bindings/cuda/bindings/driver.pxd.in

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2021-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
33

4-
# This code was automatically generated with version 12.9.0, generator version 0.3.1.dev1588+g61faef43a. Do not modify it directly.
4+
# This code was automatically generated with version 12.9.0, generator version 0.3.1.dev1711+g875fec45. Do not modify it directly.
55
cimport cuda.bindings.cydriver as cydriver
66

77
include "_lib/utils.pxd"
@@ -5097,7 +5097,11 @@ cdef class CUcheckpointRestoreArgs_st:
50975097

50985098
Attributes
50995099
----------
5100-
{{if 'CUcheckpointRestoreArgs_st.reserved' in found_struct}}
5100+
{{if struct_field_types.get('CUcheckpointRestoreArgs_st.reserved') == 'char'}}
5101+
reserved : bytes
5102+
Reserved for future use, must be zeroed
5103+
{{endif}}
5104+
{{if struct_field_types.get('CUcheckpointRestoreArgs_st.reserved') == 'cuuint64_t'}}
51015105
reserved : list[cuuint64_t]
51025106
Reserved for future use, must be zeroed
51035107
{{endif}}
@@ -10560,7 +10564,11 @@ cdef class CUcheckpointRestoreArgs(CUcheckpointRestoreArgs_st):
1056010564

1056110565
Attributes
1056210566
----------
10563-
{{if 'CUcheckpointRestoreArgs_st.reserved' in found_struct}}
10567+
{{if struct_field_types.get('CUcheckpointRestoreArgs_st.reserved') == 'char'}}
10568+
reserved : bytes
10569+
Reserved for future use, must be zeroed
10570+
{{endif}}
10571+
{{if struct_field_types.get('CUcheckpointRestoreArgs_st.reserved') == 'cuuint64_t'}}
1056410572
reserved : list[cuuint64_t]
1056510573
Reserved for future use, must be zeroed
1056610574
{{endif}}

0 commit comments

Comments
 (0)