@@ -122,8 +122,7 @@ def nvcc_smoke(tmpdir) -> str:
122122 return nvcc
123123
124124
125- @pytest .fixture
126- def CUBIN (arch ):
125+ def _build_cubin (arch ):
127126 def CHECK_NVRTC (err ):
128127 if err != nvrtc .nvrtcResult .NVRTC_SUCCESS :
129128 raise RuntimeError (repr (err ))
@@ -142,6 +141,11 @@ def CHECK_NVRTC(err):
142141 return cubin
143142
144143
144+ @pytest .fixture
145+ def CUBIN (arch ):
146+ return _build_cubin (arch )
147+
148+
145149# create a valid LTOIR input for testing
146150@pytest .fixture
147151def LTOIR (arch ):
@@ -261,11 +265,11 @@ def test_nvfatbin_add_ptx(PTX, arch):
261265 nvfatbin .destroy (handle )
262266
263267
264- @ pytest . mark . parametrize ( "arch" , [ "sm_80" ], indirect = True )
265- def test_nvfatbin_add_cubin_ELF_SIZE_MISMATCH ( CUBIN , arch ):
268+ def test_nvfatbin_add_cubin_ELF_SIZE_MISMATCH ():
269+ cubin = _build_cubin ( "sm_80" )
266270 handle = nvfatbin .create ([], 0 )
267271 with pytest .raises (nvfatbin .nvFatbinError , match = "ERROR_ELF_ARCH_MISMATCH" ):
268- nvfatbin .add_cubin (handle , CUBIN , len (CUBIN ), "75" , "inc" )
272+ nvfatbin .add_cubin (handle , cubin , len (cubin ), "75" , "inc" )
269273
270274 nvfatbin .destroy (handle )
271275
@@ -282,11 +286,11 @@ def test_nvfatbin_add_cubin(CUBIN, arch):
282286 nvfatbin .destroy (handle )
283287
284288
285- @ pytest . mark . parametrize ( "arch" , [ "sm_80" ], indirect = True )
286- def test_nvfatbin_add_cubin_ELF_ARCH_MISMATCH ( CUBIN , arch ):
289+ def test_nvfatbin_add_cubin_ELF_ARCH_MISMATCH ():
290+ cubin = _build_cubin ( "sm_80" )
287291 handle = nvfatbin .create ([], 0 )
288292 with pytest .raises (nvfatbin .nvFatbinError , match = "ERROR_ELF_ARCH_MISMATCH" ):
289- nvfatbin .add_cubin (handle , CUBIN , len (CUBIN ), "75" , "inc" )
293+ nvfatbin .add_cubin (handle , cubin , len (cubin ), "75" , "inc" )
290294
291295 nvfatbin .destroy (handle )
292296
0 commit comments