diff --git a/cuda_bindings/tests/test_nvfatbin.py b/cuda_bindings/tests/test_nvfatbin.py index db6281a575c..400de0c4b79 100644 --- a/cuda_bindings/tests/test_nvfatbin.py +++ b/cuda_bindings/tests/test_nvfatbin.py @@ -122,8 +122,7 @@ def nvcc_smoke(tmpdir) -> str: return nvcc -@pytest.fixture -def CUBIN(arch): +def _build_cubin(arch): def CHECK_NVRTC(err): if err != nvrtc.nvrtcResult.NVRTC_SUCCESS: raise RuntimeError(repr(err)) @@ -142,6 +141,11 @@ def CHECK_NVRTC(err): return cubin +@pytest.fixture +def CUBIN(arch): + return _build_cubin(arch) + + # create a valid LTOIR input for testing @pytest.fixture def LTOIR(arch): @@ -261,11 +265,11 @@ def test_nvfatbin_add_ptx(PTX, arch): nvfatbin.destroy(handle) -@pytest.mark.parametrize("arch", ["sm_80"], indirect=True) -def test_nvfatbin_add_cubin_ELF_SIZE_MISMATCH(CUBIN, arch): +def test_nvfatbin_add_cubin_ELF_SIZE_MISMATCH(): + cubin = _build_cubin("sm_80") handle = nvfatbin.create([], 0) with pytest.raises(nvfatbin.nvFatbinError, match="ERROR_ELF_ARCH_MISMATCH"): - nvfatbin.add_cubin(handle, CUBIN, len(CUBIN), "75", "inc") + nvfatbin.add_cubin(handle, cubin, len(cubin), "75", "inc") nvfatbin.destroy(handle) @@ -282,11 +286,11 @@ def test_nvfatbin_add_cubin(CUBIN, arch): nvfatbin.destroy(handle) -@pytest.mark.parametrize("arch", ["sm_80"], indirect=True) -def test_nvfatbin_add_cubin_ELF_ARCH_MISMATCH(CUBIN, arch): +def test_nvfatbin_add_cubin_ELF_ARCH_MISMATCH(): + cubin = _build_cubin("sm_80") handle = nvfatbin.create([], 0) with pytest.raises(nvfatbin.nvFatbinError, match="ERROR_ELF_ARCH_MISMATCH"): - nvfatbin.add_cubin(handle, CUBIN, len(CUBIN), "75", "inc") + nvfatbin.add_cubin(handle, cubin, len(cubin), "75", "inc") nvfatbin.destroy(handle)