@@ -78,13 +78,32 @@ def __init__(self, name, members):
7878 self ._name = name
7979 self ._member_names = []
8080 self ._member_types = []
81+ self ._member_declarators = []
8182 for var_name , var_type , _ in members :
82- var_type = var_type [0 ]
83- var_type = var_type .removeprefix ("struct " )
84- var_type = var_type .removeprefix ("union " )
83+ base_type = var_type [0 ]
84+ base_type = base_type .removeprefix ("struct " )
85+ base_type = base_type .removeprefix ("union " )
8586
8687 self ._member_names += [var_name ]
87- self ._member_types += [var_type ]
88+ self ._member_types += [base_type ]
89+ self ._member_declarators += [tuple (var_type [1 :])]
90+
91+ def member_type (self , member_name ):
92+ try :
93+ return self ._member_types [self ._member_names .index (member_name )]
94+ except ValueError :
95+ return None
96+
97+ def member_array_length (self , member_name ):
98+ try :
99+ declarators = self ._member_declarators [self ._member_names .index (member_name )]
100+ except ValueError :
101+ return None
102+
103+ for declarator in declarators :
104+ if isinstance (declarator , list ) and len (declarator ) == 1 :
105+ return declarator [0 ]
106+ return None
88107
89108 def discoverMembers (self , memberDict , prefix , seen = None ):
90109 if seen is None :
@@ -161,6 +180,9 @@ def _parse_headers(header_dict, include_path_list, parser_caching):
161180 # Since we only support 64 bit architectures, we can inline the sizeof(T*) to 8 and then compute the
162181 # result in Python. The arithmetic expression is preserved to help with clarity and understanding
163182 r"char reserved\[52 - sizeof\(CUcheckpointGpuPair \*\)\];" : rf"char reserved[{ 52 - 8 } ];" ,
183+ r"char reserved\[64 - sizeof\(CUcheckpointGpuPair \*\) - sizeof\(unsigned int\)\];" : (
184+ rf"char reserved[{ 64 - 8 - 4 } ];"
185+ ),
164186 }
165187
166188 print (f'Parsing headers in "{ include_path_list } " (Caching = { parser_caching } )' , flush = True )
@@ -310,6 +332,13 @@ def _build_cuda_bindings(strip=False):
310332 found_types , found_functions , found_values , found_struct , struct_list = _parse_headers (
311333 header_dict , include_path_list , parser_caching
312334 )
335+ struct_field_types = {}
336+ struct_field_array_lengths = {}
337+ for struct_name , struct in struct_list .items ():
338+ for member_name in struct ._member_names :
339+ key = f"{ struct_name } .{ member_name } "
340+ struct_field_types [key ] = struct .member_type (member_name )
341+ struct_field_array_lengths [key ] = struct .member_array_length (member_name )
313342
314343 # Generate code from .in templates
315344 path_list = [
@@ -332,6 +361,8 @@ def _build_cuda_bindings(strip=False):
332361 "found_values" : found_values ,
333362 "found_struct" : found_struct ,
334363 "struct_list" : struct_list ,
364+ "struct_field_types" : struct_field_types ,
365+ "struct_field_array_lengths" : struct_field_array_lengths ,
335366 "os" : os ,
336367 "sys" : sys ,
337368 "platform" : platform ,
0 commit comments