33
44import ctypes
55import pickle
6+ import shutil
7+ import subprocess
68import warnings
79
810import pytest
3335"""
3436
3537
38+ def _nvcc_path ():
39+ return shutil .which ("nvcc" )
40+
41+
3642def _is_nvfatbin_available ():
3743 """Check if nvfatbin bindings are available."""
3844 try :
@@ -44,7 +50,24 @@ def _is_nvfatbin_available():
4450 return False
4551
4652
53+ def _is_nvcc_available ():
54+ return _nvcc_path () is not None
55+
56+
57+ def _compile_with_nvcc (* , tmp_path , kernel , flags , arch , suffix ):
58+ """Compile kernel with nvcc, return content of output path."""
59+ src = tmp_path / "kernel.cu"
60+ src .write_text (kernel )
61+ out = tmp_path / f"kernel{ suffix } "
62+ compile_cmd = [_nvcc_path (), f"-arch={ arch } " , * flags , "-o" , str (out ), str (src )]
63+ result = subprocess .run (compile_cmd , capture_output = True ) # noqa: S603
64+ if result .returncode != 0 :
65+ pytest .fail (f"nvcc failed: { result .stderr } " )
66+ return out .read_bytes ()
67+
68+
4769nvfatbin_available = pytest .mark .skipif (not _is_nvfatbin_available (), reason = "nvfatbin bindings not available" )
70+ nvcc_available = pytest .mark .skipif (not _is_nvcc_available (), reason = "nvcc not in PATH" )
4871
4972
5073@pytest .fixture (scope = "module" )
@@ -172,6 +195,20 @@ def get_saxpy_fatbin(init_cuda):
172195 return bytes (fatbin ), sym_map
173196
174197
198+ @pytest .fixture
199+ def get_saxpy_object (init_cuda , tmp_path ):
200+ dev = Device ()
201+ arch = dev .arch
202+ return _compile_with_nvcc (tmp_path = tmp_path , kernel = SAXPY_KERNEL , flags = ["-dc" ], arch = f"sm_{ arch } " , suffix = ".o" )
203+
204+
205+ @pytest .fixture
206+ def get_saxpy_library (init_cuda , tmp_path ):
207+ dev = Device ()
208+ arch = dev .arch
209+ return _compile_with_nvcc (tmp_path = tmp_path , kernel = SAXPY_KERNEL , flags = ["-lib" ], arch = f"sm_{ arch } " , suffix = ".a" )
210+
211+
175212def test_get_kernel (init_cuda ):
176213 kernel = """extern "C" __global__ void ABC() { }"""
177214
@@ -330,6 +367,56 @@ def test_object_code_load_fatbin_from_file(get_saxpy_fatbin, tmp_path, convert_p
330367 mod_obj .get_kernel ("saxpy<double>" ) # force loading
331368
332369
370+ @nvcc_available
371+ def test_object_code_load_object (get_saxpy_object ):
372+ objct = get_saxpy_object
373+ assert isinstance (objct , bytes )
374+ mod_obj = ObjectCode .from_object (objct )
375+ assert mod_obj .code == objct
376+ assert mod_obj .code_type == "object"
377+ # object doesn't support kernel retrieval directly as it's used for linking
378+ # Test that get_kernel fails for unsupported code type
379+ with pytest .raises (RuntimeError , match = r'Unsupported code type "object"' ):
380+ mod_obj .get_kernel ("saxpy<float>" )
381+
382+
383+ @nvcc_available
384+ def test_object_code_load_object_from_file (get_saxpy_object , tmp_path , convert_path ):
385+ objct = get_saxpy_object
386+ assert isinstance (objct , bytes )
387+ object_file = tmp_path / "test.o"
388+ object_file .write_bytes (objct )
389+ arg = convert_path (object_file )
390+ mod_obj = ObjectCode .from_object (arg )
391+ assert mod_obj .code == str (arg )
392+ assert mod_obj .code_type == "object"
393+
394+
395+ @nvcc_available
396+ def test_object_code_load_library (get_saxpy_library ):
397+ library = get_saxpy_library
398+ assert isinstance (library , bytes )
399+ mod_obj = ObjectCode .from_library (library )
400+ assert mod_obj .code == library
401+ assert mod_obj .code_type == "library"
402+ # object doesn't support kernel retrieval directly as it's used for linking
403+ # Test that get_kernel fails for unsupported code type
404+ with pytest .raises (RuntimeError , match = r'Unsupported code type "library"' ):
405+ mod_obj .get_kernel ("saxpy<float>" )
406+
407+
408+ @nvcc_available
409+ def test_object_code_load_library_from_file (get_saxpy_library , tmp_path , convert_path ):
410+ library = get_saxpy_library
411+ assert isinstance (library , bytes )
412+ library_file = tmp_path / "test.a"
413+ library_file .write_bytes (library )
414+ arg = convert_path (library_file )
415+ mod_obj = ObjectCode .from_library (arg )
416+ assert mod_obj .code == str (arg )
417+ assert mod_obj .code_type == "library"
418+
419+
333420def test_saxpy_arguments (get_saxpy_kernel_cubin , cuda12_4_prerequisite_check ):
334421 krn , _ = get_saxpy_kernel_cubin
335422
0 commit comments