Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ gcp = [
"google-cloud-tpu",
]
nebius = [
"nebius>=0.2.40,<0.4",
"nebius>=0.3.20,<0.4",
]
oci = [
"oci",
Expand Down
15 changes: 10 additions & 5 deletions src/gpuhunt/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,16 @@ def main():
from gpuhunt.providers.nebius import NebiusProvider

provider = NebiusProvider(
credentials=PKReader(
filename=os.getenv("NEBIUS_PRIVATE_KEY_FILE"),
public_key_id=os.getenv("NEBIUS_PUBLIC_KEY_ID"),
service_account_id=os.getenv("NEBIUS_SERVICE_ACCOUNT_ID"),
),
credentials=(
# temporary user token from `nebius iam get-access-token`
os.getenv("NEBIUS_ACCESS_TOKEN")
# or service account credentials
or PKReader(
filename=os.getenv("NEBIUS_PRIVATE_KEY_FILE"),
public_key_id=os.getenv("NEBIUS_PUBLIC_KEY_ID"),
service_account_id=os.getenv("NEBIUS_SERVICE_ACCOUNT_ID"),
)
)
)
elif args.provider == "oci":
from gpuhunt.providers.oci import OCICredentials, OCIProvider
Expand Down
21 changes: 17 additions & 4 deletions src/gpuhunt/providers/nebius.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
InstanceSpec,
ListPlatformsRequest,
ListPlatformsResponse,
Platform,
PlatformServiceClient,
PreemptibleSpec,
Preset,
Expand All @@ -41,6 +42,11 @@

logger = logging.getLogger(__name__)
TIMEOUT = 7
# Overrides for cases where the GPU name parsed from the platform ID does not match the standard dstack GPU name.
# https://docs.nebius.com/compute/virtual-machines/types
GPU_NAME_OVERRIDES = {
"rtx6000": "rtxpro6000",
}


@dataclass(frozen=True)
Expand All @@ -58,6 +64,9 @@ class InfinibandFabric:
InfinibandFabric("fabric-5", "gpu-h200-sxm", "eu-west1"),
InfinibandFabric("fabric-6", "gpu-h100-sxm", "eu-north1"),
InfinibandFabric("fabric-7", "gpu-h200-sxm", "eu-north1"),
InfinibandFabric("eu-north2-a", "gpu-h200-sxm", "eu-north2"),
InfinibandFabric("me-west1-a", "gpu-b200-sxm-a", "me-west1"),
InfinibandFabric("uk-south1-a", "gpu-b300-sxm", "uk-south1"),
InfinibandFabric("us-central1-a", "gpu-h200-sxm", "us-central1"),
InfinibandFabric("us-central1-b", "gpu-b200-sxm", "us-central1"),
]
Expand All @@ -81,7 +90,7 @@ def get(
platforms = list_platforms(sdk, project_id).items
for platform in platforms:
logger.info("Processing %s/%s", region, platform.metadata.name)
gpu = get_gpu_info(platform.metadata.name)
gpu = get_gpu_info(platform)
for preset in platform.spec.presets:
for spot in [False] + (
[True] if platform.status.allowed_for_preemptibles else []
Expand Down Expand Up @@ -149,13 +158,17 @@ def list_platforms(sdk: SDK, project_id: str) -> ListPlatformsResponse:
return PlatformServiceClient(sdk).list(req, per_retry_timeout=TIMEOUT).wait()


def get_gpu_info(platform: Platform) -> AcceleratorInfo | None:
    """Resolve the NVIDIA accelerator offered by a platform.

    The GPU name is the first dash-delimited token after the ``gpu-`` prefix
    of ``platform.metadata.name``; it is then mapped through
    ``GPU_NAME_OVERRIDES`` before the lookup.

    Returns ``None`` when:
      * the platform name does not start with ``gpu-<name>``;
      * ``find_accelerators`` does not return exactly one match; or
      * the matched accelerator's ``memory`` differs from
        ``platform.spec.gpu_memory_gigabytes``.
    """
    # No trailing dash is required, so a name that ends right after the GPU
    # token still matches.
    m = re.match(r"gpu-([^-]+)", platform.metadata.name)
    if m is None:
        return None
    gpu_name = m.group(1)
    gpu_name = GPU_NAME_OVERRIDES.get(gpu_name, gpu_name)
    accelerator_info = find_accelerators(names=[gpu_name], vendors=[AcceleratorVendor.NVIDIA])
    if (
        len(accelerator_info) != 1
        # Reject a name match whose memory size disagrees with the platform spec.
        or accelerator_info[0].memory != platform.spec.gpu_memory_gigabytes
    ):
        return None
    return accelerator_info[0]

Expand Down
2 changes: 1 addition & 1 deletion src/integrity_tests/test_nebius.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def data_rows(catalog_dir: Path) -> list[dict]:
return list(csv.DictReader(f))


@pytest.mark.parametrize("gpu", ["RTXPRO6000", "L40S", "H100", "H200", "B200", ""])
def test_gpu_present(gpu: str, data_rows: list[dict]):
    """Check that at least one catalog row has ``gpu_name`` equal to *gpu*.

    NOTE(review): the empty-string case presumably covers rows without a GPU
    (CPU-only offers) - confirm against the catalog format.
    """
    assert gpu in map(itemgetter("gpu_name"), data_rows)

Expand Down
Loading