diff --git a/src/gpuhunt/_internal/constraints.py b/src/gpuhunt/_internal/constraints.py index 5b039aa..6cf19bf 100644 --- a/src/gpuhunt/_internal/constraints.py +++ b/src/gpuhunt/_internal/constraints.py @@ -306,6 +306,12 @@ def is_nvidia_superchip(gpu_name: str) -> bool: KNOWN_TENSTORRENT_ACCELERATORS: list[TenstorrentAcceleratorInfo] = [ TenstorrentAcceleratorInfo(name="n150", memory=12), TenstorrentAcceleratorInfo(name="n300", memory=24), + TenstorrentAcceleratorInfo(name="tt-galaxy-wh", memory=12), + TenstorrentAcceleratorInfo(name="p100a", memory=28), + TenstorrentAcceleratorInfo(name="p150", memory=32), + TenstorrentAcceleratorInfo(name="p300", memory=32), + TenstorrentAcceleratorInfo(name="p300", memory=64), + TenstorrentAcceleratorInfo(name="tt-galaxy-bh", memory=32), ] KNOWN_ACCELERATORS: list[ diff --git a/src/tests/_internal/test_constraints.py b/src/tests/_internal/test_constraints.py index 7d115da..817781f 100644 --- a/src/tests/_internal/test_constraints.py +++ b/src/tests/_internal/test_constraints.py @@ -1,7 +1,12 @@ import pytest from gpuhunt import CatalogItem, QueryFilter -from gpuhunt._internal.constraints import correct_gpu_memory_gib, matches +from gpuhunt._internal.constraints import ( + correct_gpu_memory_gib, + find_accelerators, + get_gpu_vendor, + matches, +) from gpuhunt._internal.models import AcceleratorVendor @@ -222,3 +227,26 @@ def test_matches_cpu_instances_with_zero_gpu_count_and_gpu_memory(self): ) def test_correct_gpu_memory(gpu_name: str, memory_mib: float, expected_memory_gib: int) -> None: assert correct_gpu_memory_gib(gpu_name, memory_mib) == expected_memory_gib + + +@pytest.mark.parametrize( + ("gpu_name", "expected_memories_gib"), + [ + ("n150", {12}), + ("n300", {24}), + ("tt-galaxy-wh", {12}), + ("p100a", {28}), + ("p150", {32}), + ("p300", {32, 64}), + ("tt-galaxy-bh", {32}), + ], +) +def test_tenstorrent_accelerators(gpu_name: str, expected_memories_gib: set[int]) -> None: + accelerators = find_accelerators( + names=[gpu_name.upper()], + vendors=[AcceleratorVendor.TENSTORRENT], + ) + + assert {accelerator.name for accelerator in accelerators} == {gpu_name} + assert {accelerator.memory for accelerator in accelerators} == expected_memories_gib + assert get_gpu_vendor(gpu_name.upper()) == AcceleratorVendor.TENSTORRENT