@@ -109,13 +109,32 @@ def __init__(self, name, members):
109109 self ._name = name
110110 self ._member_names = []
111111 self ._member_types = []
112+ self ._member_declarators = []
112113 for var_name , var_type , _ in members :
113- var_type = var_type [0 ]
114- var_type = var_type .removeprefix ("struct " )
115- var_type = var_type .removeprefix ("union " )
114+ base_type = var_type [0 ]
115+ base_type = base_type .removeprefix ("struct " )
116+ base_type = base_type .removeprefix ("union " )
116117
117118 self ._member_names += [var_name ]
118- self ._member_types += [var_type ]
119+ self ._member_types += [base_type ]
120+ self ._member_declarators += [tuple (var_type [1 :])]
121+
122+ def member_type (self , member_name ):
123+ try :
124+ return self ._member_types [self ._member_names .index (member_name )]
125+ except ValueError :
126+ return None
127+
128+ def member_array_length (self , member_name ):
129+ try :
130+ declarators = self ._member_declarators [self ._member_names .index (member_name )]
131+ except ValueError :
132+ return None
133+
134+ for declarator in declarators :
135+ if isinstance (declarator , list ) and len (declarator ) == 1 :
136+ return declarator [0 ]
137+ return None
119138
120139 def discoverMembers (self , memberDict , prefix , seen = None ):
121140 if seen is None :
@@ -192,6 +211,7 @@ def _parse_headers(header_dict, include_path_list, parser_caching):
192211 # Since we only support 64 bit architectures, we can inline the sizeof(T*) to 8 and then compute the
193212 # result in Python. The arithmetic expression is preserved to help with clarity and understanding
194213 r"char reserved\[52 - sizeof\(CUcheckpointGpuPair \*\)\];" : rf"char reserved[{ 52 - 8 } ];" ,
214+ r"char reserved\[64 - sizeof\(CUcheckpointGpuPair \*\) - sizeof\(unsigned int\)\];" : rf"char reserved[{ 64 - 8 - 4 } ];" ,
195215 }
196216
197217 print (f'Parsing headers in "{ include_path_list } " (Caching = { parser_caching } )' , flush = True )
@@ -341,6 +361,13 @@ def _build_cuda_bindings(debug=False):
341361 found_types , found_functions , found_values , found_struct , struct_list = _parse_headers (
342362 header_dict , include_path_list , parser_caching
343363 )
364+ struct_field_types = {}
365+ struct_field_array_lengths = {}
366+ for struct_name , struct in struct_list .items ():
367+ for member_name in struct ._member_names :
368+ key = f"{ struct_name } .{ member_name } "
369+ struct_field_types [key ] = struct .member_type (member_name )
370+ struct_field_array_lengths [key ] = struct .member_array_length (member_name )
344371
345372 # Generate code from .in templates
346373 path_list = [
@@ -363,6 +390,8 @@ def _build_cuda_bindings(debug=False):
363390 "found_values" : found_values ,
364391 "found_struct" : found_struct ,
365392 "struct_list" : struct_list ,
393+ "struct_field_types" : struct_field_types ,
394+ "struct_field_array_lengths" : struct_field_array_lengths ,
366395 "os" : os ,
367396 "sys" : sys ,
368397 "platform" : platform ,
0 commit comments