Skip to content

Commit 757bdf0

Browse files
committed
Add SM arch bitcode lookup
1 parent 5c17600 commit 757bdf0

2 files changed

Lines changed: 19 additions & 14 deletions

File tree

cuda_pathfinder/cuda/pathfinder/_static_libs/find_bitcode_lib.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ def _no_such_file_in_dir(dir_path: str, filename: str, error_messages: list[str]
7575
attachments.append(f' Directory does not exist: "{dir_path}"')
7676

7777

78-
def _filename_with_sm_arch(filename: str, sm_arch: str) -> str:
79-
if not sm_arch:
78+
def _filename_with_sm_arch(filename: str, sm_arch: str | None) -> str:
79+
if sm_arch is None:
8080
return filename
8181

8282
if not re.match(r"^sm[0-9]+[a-z]?$", sm_arch):
@@ -87,7 +87,7 @@ def _filename_with_sm_arch(filename: str, sm_arch: str) -> str:
8787

8888

8989
class _FindBitcodeLib:
90-
def __init__(self, name: str, sm_arch: str = "") -> None:
90+
def __init__(self, name: str, sm_arch: str | None = None) -> None:
9191
if name not in _SUPPORTED_BITCODE_LIBS_INFO: # Updated reference
9292
raise ValueError(f"Unknown bitcode library: '{name}'. Supported: {', '.join(SUPPORTED_BITCODE_LIBS)}")
9393
self.name: str = name
@@ -142,20 +142,22 @@ def raise_not_found_error(self) -> NoReturn:
142142
raise BitcodeLibNotFoundError(f'Failure finding "{self.filename}": {err}\n{att}')
143143

144144

145-
def locate_bitcode_lib(name: str, *, sm_arch: str = "") -> LocatedBitcodeLib:
145+
def locate_bitcode_lib(name: str, *, sm_arch: str | None = None) -> LocatedBitcodeLib:
146146
"""Locate a bitcode library by name.
147147
148-
When ``sm_arch`` is set, locate the architecture-specific bitcode filename
149-
with ``_{sm_arch}`` inserted before the ``.bc`` suffix.
148+
When ``sm_arch`` is not ``None``, locate the architecture-specific bitcode
149+
filename with ``_{sm_arch}`` inserted before the ``.bc`` suffix.
150150
151151
Args:
152152
name: Name of the supported bitcode library to locate.
153153
sm_arch: Optional SM architecture suffix, such as ``"sm90"`` or
154-
``"sm90a"``. If set, it must match ``sm[0-9]+[a-z]?``.
154+
``"sm90a"``. If not ``None``, it must match
155+
``sm[0-9]+[a-z]?``.
155156
156157
Raises:
157158
ValueError: If ``name`` is not a supported bitcode library, or if
158-
``sm_arch`` is set but does not match ``sm[0-9]+[a-z]?``.
159+
``sm_arch`` is not ``None`` and does not match
160+
``sm[0-9]+[a-z]?``.
159161
BitcodeLibNotFoundError: If the bitcode library cannot be found.
160162
"""
161163
finder = _FindBitcodeLib(name, sm_arch)
@@ -191,20 +193,22 @@ def locate_bitcode_lib(name: str, *, sm_arch: str = "") -> LocatedBitcodeLib:
191193

192194

193195
@functools.cache
194-
def find_bitcode_lib(name: str, sm_arch: str = "") -> str:
196+
def find_bitcode_lib(name: str, sm_arch: str | None = None) -> str:
195197
"""Find the absolute path to a bitcode library.
196198
197-
When ``sm_arch`` is set, find the architecture-specific bitcode filename
198-
with ``_{sm_arch}`` inserted before the ``.bc`` suffix.
199+
When ``sm_arch`` is not ``None``, find the architecture-specific bitcode
200+
filename with ``_{sm_arch}`` inserted before the ``.bc`` suffix.
199201
200202
Args:
201203
name: Name of the supported bitcode library to find.
202204
sm_arch: Optional SM architecture suffix, such as ``"sm90"`` or
203-
``"sm90a"``. If set, it must match ``sm[0-9]+[a-z]?``.
205+
``"sm90a"``. If not ``None``, it must match
206+
``sm[0-9]+[a-z]?``.
204207
205208
Raises:
206209
ValueError: If ``name`` is not a supported bitcode library, or if
207-
``sm_arch`` is set but does not match ``sm[0-9]+[a-z]?``.
210+
``sm_arch`` is not ``None`` and does not match
211+
``sm[0-9]+[a-z]?``.
208212
BitcodeLibNotFoundError: If the bitcode library cannot be found.
209213
"""
210-
return locate_bitcode_lib(name, sm_arch).abs_path
214+
return locate_bitcode_lib(name, sm_arch=sm_arch).abs_path

cuda_pathfinder/tests/test_find_bitcode_lib.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ def test_find_bitcode_lib_invalid_name():
296296
@pytest.mark.parametrize(
297297
"sm_arch",
298298
[
299+
"",
299300
"../sm90",
300301
"compute90",
301302
"sm_90",

0 commit comments

Comments
 (0)