Skip to content

Commit e0465db

Browse files
TST: Add remaining tests for ObjectCode.from_<library/object>
1 parent 11eb35f commit e0465db

1 file changed

Lines changed: 87 additions & 0 deletions

File tree

cuda_core/tests/test_module.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
import ctypes
55
import pickle
6+
import shutil
7+
import subprocess
68
import warnings
79

810
import pytest
@@ -33,6 +35,10 @@
3335
"""
3436

3537

38+
def _nvcc_path():
39+
return shutil.which("nvcc")
40+
41+
3642
def _is_nvfatbin_available():
3743
"""Check if nvfatbin bindings are available."""
3844
try:
@@ -44,7 +50,24 @@ def _is_nvfatbin_available():
4450
return False
4551

4652

53+
def _is_nvcc_available():
54+
return _nvcc_path() is not None
55+
56+
57+
def _compile_with_nvcc(*, tmp_path, kernel, flags, arch, suffix):
58+
"""Compile kernel with nvcc, return content of output path."""
59+
src = tmp_path / "kernel.cu"
60+
src.write_text(kernel)
61+
out = tmp_path / f"kernel{suffix}"
62+
compile_cmd = [_nvcc_path(), f"-arch={arch}", *flags, "-o", str(out), str(src)]
63+
result = subprocess.run(compile_cmd, capture_output=True) # noqa: S603
64+
if result.returncode != 0:
65+
pytest.fail(f"nvcc failed: {result.stderr}")
66+
return out.read_bytes()
67+
68+
4769
nvfatbin_available = pytest.mark.skipif(not _is_nvfatbin_available(), reason="nvfatbin bindings not available")
70+
nvcc_available = pytest.mark.skipif(not _is_nvcc_available(), reason="nvcc not in PATH")
4871

4972

5073
@pytest.fixture(scope="module")
@@ -172,6 +195,20 @@ def get_saxpy_fatbin(init_cuda):
172195
return bytes(fatbin), sym_map
173196

174197

198+
@pytest.fixture
199+
def get_saxpy_object(init_cuda, tmp_path):
200+
dev = Device()
201+
arch = dev.arch
202+
return _compile_with_nvcc(tmp_path=tmp_path, kernel=SAXPY_KERNEL, flags=["-dc"], arch=f"sm_{arch}", suffix=".o")
203+
204+
205+
@pytest.fixture
206+
def get_saxpy_library(init_cuda, tmp_path):
207+
dev = Device()
208+
arch = dev.arch
209+
return _compile_with_nvcc(tmp_path=tmp_path, kernel=SAXPY_KERNEL, flags=["-lib"], arch=f"sm_{arch}", suffix=".a")
210+
211+
175212
def test_get_kernel(init_cuda):
176213
kernel = """extern "C" __global__ void ABC() { }"""
177214

@@ -330,6 +367,56 @@ def test_object_code_load_fatbin_from_file(get_saxpy_fatbin, tmp_path, convert_p
330367
mod_obj.get_kernel("saxpy<double>") # force loading
331368

332369

370+
@nvcc_available
371+
def test_object_code_load_object(get_saxpy_object):
372+
objct = get_saxpy_object
373+
assert isinstance(objct, bytes)
374+
mod_obj = ObjectCode.from_object(objct)
375+
assert mod_obj.code == objct
376+
assert mod_obj.code_type == "object"
377+
# object doesn't support kernel retrieval directly as it's used for linking
378+
# Test that get_kernel fails for unsupported code type
379+
with pytest.raises(RuntimeError, match=r'Unsupported code type "object"'):
380+
mod_obj.get_kernel("saxpy<float>")
381+
382+
383+
@nvcc_available
384+
def test_object_code_load_object_from_file(get_saxpy_object, tmp_path, convert_path):
385+
objct = get_saxpy_object
386+
assert isinstance(objct, bytes)
387+
object_file = tmp_path / "test.o"
388+
object_file.write_bytes(objct)
389+
arg = convert_path(object_file)
390+
mod_obj = ObjectCode.from_object(arg)
391+
assert mod_obj.code == str(arg)
392+
assert mod_obj.code_type == "object"
393+
394+
395+
@nvcc_available
396+
def test_object_code_load_library(get_saxpy_library):
397+
library = get_saxpy_library
398+
assert isinstance(library, bytes)
399+
mod_obj = ObjectCode.from_library(library)
400+
assert mod_obj.code == library
401+
assert mod_obj.code_type == "library"
402+
# object doesn't support kernel retrieval directly as it's used for linking
403+
# Test that get_kernel fails for unsupported code type
404+
with pytest.raises(RuntimeError, match=r'Unsupported code type "library"'):
405+
mod_obj.get_kernel("saxpy<float>")
406+
407+
408+
@nvcc_available
409+
def test_object_code_load_library_from_file(get_saxpy_library, tmp_path, convert_path):
410+
library = get_saxpy_library
411+
assert isinstance(library, bytes)
412+
library_file = tmp_path / "test.a"
413+
library_file.write_bytes(library)
414+
arg = convert_path(library_file)
415+
mod_obj = ObjectCode.from_library(arg)
416+
assert mod_obj.code == str(arg)
417+
assert mod_obj.code_type == "library"
418+
419+
333420
def test_saxpy_arguments(get_saxpy_kernel_cubin, cuda12_4_prerequisite_check):
334421
krn, _ = get_saxpy_kernel_cubin
335422

0 commit comments

Comments
 (0)