Skip to content

Commit 5bb06e5

Browse files
committed
Merge remote-tracking branch 'origin/cybind-catchup-12.9.x' into cybind-catchup-12.9.x
2 parents 433841e + 9a3b242 commit 5bb06e5

1 file changed

Lines changed: 12 additions & 8 deletions

File tree

cuda_bindings/tests/test_nvfatbin.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,7 @@ def nvcc_smoke(tmpdir) -> str:
122122
return nvcc
123123

124124

125-
@pytest.fixture
126-
def CUBIN(arch):
125+
def _build_cubin(arch):
127126
def CHECK_NVRTC(err):
128127
if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
129128
raise RuntimeError(repr(err))
@@ -142,6 +141,11 @@ def CHECK_NVRTC(err):
142141
return cubin
143142

144143

144+
@pytest.fixture
145+
def CUBIN(arch):
146+
return _build_cubin(arch)
147+
148+
145149
# create a valid LTOIR input for testing
146150
@pytest.fixture
147151
def LTOIR(arch):
@@ -261,11 +265,11 @@ def test_nvfatbin_add_ptx(PTX, arch):
261265
nvfatbin.destroy(handle)
262266

263267

264-
@pytest.mark.parametrize("arch", ["sm_80"], indirect=True)
265-
def test_nvfatbin_add_cubin_ELF_SIZE_MISMATCH(CUBIN, arch):
268+
def test_nvfatbin_add_cubin_ELF_SIZE_MISMATCH():
269+
cubin = _build_cubin("sm_80")
266270
handle = nvfatbin.create([], 0)
267271
with pytest.raises(nvfatbin.nvFatbinError, match="ERROR_ELF_ARCH_MISMATCH"):
268-
nvfatbin.add_cubin(handle, CUBIN, len(CUBIN), "75", "inc")
272+
nvfatbin.add_cubin(handle, cubin, len(cubin), "75", "inc")
269273

270274
nvfatbin.destroy(handle)
271275

@@ -282,11 +286,11 @@ def test_nvfatbin_add_cubin(CUBIN, arch):
282286
nvfatbin.destroy(handle)
283287

284288

285-
@pytest.mark.parametrize("arch", ["sm_80"], indirect=True)
286-
def test_nvfatbin_add_cubin_ELF_ARCH_MISMATCH(CUBIN, arch):
289+
def test_nvfatbin_add_cubin_ELF_ARCH_MISMATCH():
290+
cubin = _build_cubin("sm_80")
287291
handle = nvfatbin.create([], 0)
288292
with pytest.raises(nvfatbin.nvFatbinError, match="ERROR_ELF_ARCH_MISMATCH"):
289-
nvfatbin.add_cubin(handle, CUBIN, len(CUBIN), "75", "inc")
293+
nvfatbin.add_cubin(handle, cubin, len(cubin), "75", "inc")
290294

291295
nvfatbin.destroy(handle)
292296

0 commit comments

Comments
 (0)