diff --git a/cuda_bindings/tests/test_nvjitlink.py b/cuda_bindings/tests/test_nvjitlink.py index 000ef52e075..5c6ca98ea73 100644 --- a/cuda_bindings/tests/test_nvjitlink.py +++ b/cuda_bindings/tests/test_nvjitlink.py @@ -1,5 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. -# +# Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE import pytest @@ -11,15 +10,13 @@ PTX_VERSIONS = ["5.0", "6.4", "7.0", "8.5"] -def ptx_header(version, arch): - return f""" -.version {version} -.target {arch} +PTX_HEADER = """\ +.version {VERSION} +.target {ARCH} .address_size 64 """ - -ptx_kernel = """ +PTX_KERNEL = """ .visible .entry _Z6kernelPi( .param .u64 _Z6kernelPi_param_0 ) @@ -36,20 +33,21 @@ def ptx_header(version, arch): } """ -minimal_ptx_kernel = """ -.func _MinimalKernel() -{ - ret; -} -""" -ptx_kernel_bytes = [ - (ptx_header(version, arch) + ptx_kernel).encode("utf-8") for version, arch in zip(PTX_VERSIONS, ARCHITECTURES) -] -minimal_ptx_kernel_bytes = [ - (ptx_header(version, arch) + minimal_ptx_kernel).encode("utf-8") - for version, arch in zip(PTX_VERSIONS, ARCHITECTURES) -] +def _build_arch_ptx_parametrized_callable(): + av = tuple(zip(ARCHITECTURES, PTX_VERSIONS)) + return pytest.mark.parametrize( + ("arch", "ptx_bytes"), + [(a, (PTX_HEADER.format(VERSION=v, ARCH=a) + PTX_KERNEL).encode("utf-8")) for a, v in av], + ids=[f"{a}_{v}" for a, v in av], + ) + + +ARCH_PTX_PARAMETRIZED_CALLABLE = _build_arch_ptx_parametrized_callable() + + +def arch_ptx_parametrized(func): + return ARCH_PTX_PARAMETRIZED_CALLABLE(func) def check_nvjitlink_usable(): @@ -108,17 +106,17 @@ def test_complete_empty(option): nvjitlink.destroy(handle) -@pytest.mark.parametrize("option, ptx_bytes", zip(ARCHITECTURES, ptx_kernel_bytes)) -def test_add_data(option, ptx_bytes): - handle = nvjitlink.create(1, [f"-arch={option}"]) +@arch_ptx_parametrized +def test_add_data(arch, ptx_bytes): + handle = nvjitlink.create(1, [f"-arch={arch}"]) nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data") nvjitlink.complete(handle) nvjitlink.destroy(handle) -@pytest.mark.parametrize("option, ptx_bytes", zip(ARCHITECTURES, ptx_kernel_bytes)) -def test_add_file(option, ptx_bytes, tmp_path): - handle = nvjitlink.create(1, [f"-arch={option}"]) +@arch_ptx_parametrized +def test_add_file(arch, ptx_bytes, tmp_path): + handle = nvjitlink.create(1, [f"-arch={arch}"]) file_path = tmp_path / "test_file.cubin" file_path.write_bytes(ptx_bytes) nvjitlink.add_file(handle, nvjitlink.InputType.ANY, str(file_path)) @@ -126,9 +124,9 @@ def test_add_file(option, ptx_bytes, tmp_path): nvjitlink.destroy(handle) -@pytest.mark.parametrize("option", ARCHITECTURES) -def test_get_error_log(option): - handle = nvjitlink.create(1, [f"-arch={option}"]) +@pytest.mark.parametrize("arch", ARCHITECTURES) +def test_get_error_log(arch): + handle = nvjitlink.create(1, [f"-arch={arch}"]) nvjitlink.complete(handle) log_size = nvjitlink.get_error_log_size(handle) log = bytearray(log_size) @@ -137,9 +135,9 @@ def test_get_error_log(option): nvjitlink.destroy(handle) -@pytest.mark.parametrize("option, ptx_bytes", zip(ARCHITECTURES, ptx_kernel_bytes)) -def test_get_info_log(option, ptx_bytes): - handle = nvjitlink.create(1, [f"-arch={option}"]) +@arch_ptx_parametrized +def test_get_info_log(arch, ptx_bytes): + handle = nvjitlink.create(1, [f"-arch={arch}"]) nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data") nvjitlink.complete(handle) log_size = nvjitlink.get_info_log_size(handle) @@ -149,9 +147,9 @@ def test_get_info_log(option, ptx_bytes): nvjitlink.destroy(handle) -@pytest.mark.parametrize("option, ptx_bytes", zip(ARCHITECTURES, ptx_kernel_bytes)) -def test_get_linked_cubin(option, ptx_bytes): - handle = nvjitlink.create(1, [f"-arch={option}"]) +@arch_ptx_parametrized +def test_get_linked_cubin(arch, ptx_bytes): + handle = nvjitlink.create(1, [f"-arch={arch}"]) nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data") nvjitlink.complete(handle) cubin_size = nvjitlink.get_linked_cubin_size(handle) @@ -161,9 +159,9 @@ def test_get_linked_cubin(option, ptx_bytes): nvjitlink.destroy(handle) -@pytest.mark.parametrize("option", ARCHITECTURES) -def test_get_linked_ptx(option, get_dummy_ltoir): - handle = nvjitlink.create(3, [f"-arch={option}", "-lto", "-ptx"]) +@pytest.mark.parametrize("arch", ARCHITECTURES) +def test_get_linked_ptx(arch, get_dummy_ltoir): + handle = nvjitlink.create(3, [f"-arch={arch}", "-lto", "-ptx"]) nvjitlink.add_data(handle, nvjitlink.InputType.LTOIR, get_dummy_ltoir, len(get_dummy_ltoir), "test_data") nvjitlink.complete(handle) ptx_size = nvjitlink.get_linked_ptx_size(handle)