Skip to content

Commit f13c754

Browse files
committed
Adjust real guard-rails tests for cu12 nvcc wheels.
This keeps the real host-backed checks strict when an installed nvcc wheel actually ships nvcc, while avoiding a false failure in cu12 wheel environments that only provide lower-level compiler pieces such as ptxas. Made-with: Cursor
1 parent 583af91 commit f13c754

1 file changed

Lines changed: 33 additions & 27 deletions

File tree

cuda_pathfinder/tests/test_compatibility_guard_rails.py

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import pytest
99
from local_helpers import (
10+
have_distribution,
1011
locate_real_cuda_toolkit_version_from_cuda_h,
1112
require_real_cuda_toolkit_version_from_cuda_h,
1213
require_real_driver_cuda_version,
@@ -49,6 +50,7 @@ def _default_process_wide_guard_rails_mode(monkeypatch):
4950

5051
@pytest.fixture
5152
def clear_real_host_probe_caches():
53+
have_distribution.cache_clear()
5254
locate_real_cuda_toolkit_version_from_cuda_h.cache_clear()
5355
locate_nvidia_header_directory_raw.cache_clear()
5456
_resolve_system_loaded_abs_path_in_subprocess.cache_clear()
@@ -57,6 +59,7 @@ def clear_real_host_probe_caches():
5759
driver_info._load_nvidia_dynamic_lib.cache_clear()
5860
driver_info.query_driver_cuda_version.cache_clear()
5961
yield
62+
have_distribution.cache_clear()
6063
locate_real_cuda_toolkit_version_from_cuda_h.cache_clear()
6164
locate_nvidia_header_directory_raw.cache_clear()
6265
_resolve_system_loaded_abs_path_in_subprocess.cache_clear()
@@ -685,16 +688,26 @@ def test_find_nvidia_header_directory_returns_none_when_unresolved(monkeypatch):
685688

686689

687690
@pytest.mark.usefixtures("clear_real_host_probe_caches")
688-
def test_real_wheel_ctk_items_are_compatible(info_summary_append):
689-
real_ctk = require_real_cuda_toolkit_version_from_cuda_h()
691+
def test_real_driver(info_summary_append):
690692
real_driver = require_real_driver_cuda_version()
691693
info_summary_append(
692-
f"real cuda.h CTK version={real_ctk.version.major}.{real_ctk.version.minor} "
693-
f"via {real_ctk.found_via} at {real_ctk.cuda_h_path!r}"
694+
f"real driver CUDA version={real_driver.major}.{real_driver.minor} (encoded={real_driver.encoded})"
694695
)
696+
697+
698+
@pytest.mark.usefixtures("clear_real_host_probe_caches")
699+
def test_real_ctk(info_summary_append):
700+
real_ctk = require_real_cuda_toolkit_version_from_cuda_h()
695701
info_summary_append(
696-
f"real driver CUDA version={real_driver.major}.{real_driver.minor} (encoded={real_driver.encoded})"
702+
f"real cuda.h CTK version={real_ctk.version.major}.{real_ctk.version.minor} "
703+
f"via {real_ctk.found_via} at {real_ctk.cuda_h_path!r}"
697704
)
705+
706+
707+
@pytest.mark.usefixtures("clear_real_host_probe_caches")
708+
def test_real_wheel_ctk_items_are_compatible(info_summary_append):
709+
real_ctk = require_real_cuda_toolkit_version_from_cuda_h()
710+
real_driver = require_real_driver_cuda_version()
698711
guard_rails = CompatibilityGuardRails(
699712
ctk_major=real_ctk.version.major,
700713
ctk_minor=real_ctk.version.minor,
@@ -716,33 +729,29 @@ def test_real_wheel_ctk_items_are_compatible(info_summary_append):
716729
) as exc:
717730
if STRICTNESS == "all_must_work":
718731
raise
719-
info_summary_append(f"real CTK check unavailable: {exc.__class__.__name__}: {exc}")
720-
return
721-
722-
info_summary_append(f"nvrtc={loaded.abs_path!r}")
723-
info_summary_append(f"nvrtc_headers={header_dir!r}")
724-
info_summary_append(f"cudadevrt={static_lib!r}")
725-
info_summary_append(f"libdevice={bitcode_lib!r}")
726-
info_summary_append(f"nvcc={nvcc!r}")
732+
pytest.skip(f"real CTK check unavailable: {exc.__class__.__name__}: {exc}")
727733

728734
assert isinstance(loaded.abs_path, str)
729735
assert header_dir is not None
730-
assert nvcc is not None
731-
for path in (loaded.abs_path, header_dir, static_lib, bitcode_lib, nvcc):
736+
for path in (loaded.abs_path, header_dir, static_lib, bitcode_lib):
732737
_assert_real_ctk_backed_path(path)
738+
if have_distribution(r"^nvidia-cuda-nvcc-cu12$"):
739+
# For CUDA 12, NVIDIA publishes a PyPI package named nvidia-cuda-nvcc-cu12,
740+
# but the wheels only contain nvcc-adjacent compiler components such as
741+
# ptxas, CRT headers, libnvvm, and libdevice; the nvcc executable itself
742+
# is not included.
743+
if nvcc is not None:
744+
# nvcc found elsewhere, e.g. /usr/local or Conda.
745+
_assert_real_ctk_backed_path(nvcc)
746+
else:
747+
assert nvcc is not None
748+
_assert_real_ctk_backed_path(nvcc)
733749

734750

735751
@pytest.mark.usefixtures("clear_real_host_probe_caches")
736752
def test_real_wheel_component_version_does_not_override_ctk_line(info_summary_append):
737753
real_ctk = require_real_cuda_toolkit_version_from_cuda_h()
738754
real_driver = require_real_driver_cuda_version()
739-
info_summary_append(
740-
f"real cuda.h CTK version={real_ctk.version.major}.{real_ctk.version.minor} "
741-
f"via {real_ctk.found_via} at {real_ctk.cuda_h_path!r}"
742-
)
743-
info_summary_append(
744-
f"real driver CUDA version={real_driver.major}.{real_driver.minor} (encoded={real_driver.encoded})"
745-
)
746755
guard_rails = CompatibilityGuardRails(
747756
ctk_major=real_ctk.version.major,
748757
ctk_minor=real_ctk.version.minor,
@@ -754,14 +763,11 @@ def test_real_wheel_component_version_does_not_override_ctk_line(info_summary_ap
754763
except (CompatibilityCheckError, CompatibilityInsufficientMetadataError) as exc:
755764
if STRICTNESS == "all_must_work":
756765
raise
757-
info_summary_append(f"real cufft CTK check unavailable: {exc.__class__.__name__}: {exc}")
758-
return
766+
pytest.skip(f"real cufft CTK check unavailable: {exc.__class__.__name__}: {exc}")
759767

760768
if header_dir is None:
761769
if STRICTNESS == "all_must_work":
762770
raise AssertionError("Expected CTK-backed cufft headers to be discoverable.")
763-
info_summary_append("real cufft CTK check unavailable: cufft headers not found")
764-
return
771+
pytest.skip("real cufft CTK check unavailable: cufft headers not found")
765772

766-
info_summary_append(f"cufft_headers={header_dir!r}")
767773
_assert_real_ctk_backed_path(header_dir)

0 commit comments

Comments
 (0)