From c30ea62015eb099412f9f74e67dc6fc2901727d6 Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Fri, 12 Jun 2026 19:44:05 +0200 Subject: [PATCH] Update Nebius provider - Add `RTXPRO6000` - Add missing fabrics - Support the `NEBIUS_ACCESS_TOKEN` environment variable for easier local testing --- pyproject.toml | 2 +- src/gpuhunt/__main__.py | 15 ++++++++++----- src/gpuhunt/providers/nebius.py | 21 +++++++++++++++++---- src/integrity_tests/test_nebius.py | 2 +- 4 files changed, 29 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ea4840c..d8a61f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ gcp = [ "google-cloud-tpu", ] nebius = [ - "nebius>=0.2.40,<0.4", + "nebius>=0.3.20,<0.4", ] oci = [ "oci", diff --git a/src/gpuhunt/__main__.py b/src/gpuhunt/__main__.py index b28928f..620e538 100644 --- a/src/gpuhunt/__main__.py +++ b/src/gpuhunt/__main__.py @@ -89,11 +89,16 @@ def main(): from gpuhunt.providers.nebius import NebiusProvider provider = NebiusProvider( - credentials=PKReader( - filename=os.getenv("NEBIUS_PRIVATE_KEY_FILE"), - public_key_id=os.getenv("NEBIUS_PUBLIC_KEY_ID"), - service_account_id=os.getenv("NEBIUS_SERVICE_ACCOUNT_ID"), - ), + credentials=( + # temporary user token from `nebius iam get-access-token` + os.getenv("NEBIUS_ACCESS_TOKEN") + # or service account credentials + or PKReader( + filename=os.getenv("NEBIUS_PRIVATE_KEY_FILE"), + public_key_id=os.getenv("NEBIUS_PUBLIC_KEY_ID"), + service_account_id=os.getenv("NEBIUS_SERVICE_ACCOUNT_ID"), + ) + ) ) elif args.provider == "oci": from gpuhunt.providers.oci import OCICredentials, OCIProvider diff --git a/src/gpuhunt/providers/nebius.py b/src/gpuhunt/providers/nebius.py index 083e6c3..dfd014e 100644 --- a/src/gpuhunt/providers/nebius.py +++ b/src/gpuhunt/providers/nebius.py @@ -15,6 +15,7 @@ InstanceSpec, ListPlatformsRequest, ListPlatformsResponse, + Platform, PlatformServiceClient, PreemptibleSpec, Preset, @@ -41,6 +42,11 @@ logger = logging.getLogger(__name__) TIMEOUT = 7 +# When GPU name from the platform ID does not match the standard dstack GPU name. +# https://docs.nebius.com/compute/virtual-machines/types +GPU_NAME_OVERRIDES = { + "rtx6000": "rtxpro6000", +} @dataclass(frozen=True) @@ -58,6 +64,9 @@ class InfinibandFabric: InfinibandFabric("fabric-5", "gpu-h200-sxm", "eu-west1"), InfinibandFabric("fabric-6", "gpu-h100-sxm", "eu-north1"), InfinibandFabric("fabric-7", "gpu-h200-sxm", "eu-north1"), + InfinibandFabric("eu-north2-a", "gpu-h200-sxm", "eu-north2"), + InfinibandFabric("me-west1-a", "gpu-b200-sxm-a", "me-west1"), + InfinibandFabric("uk-south1-a", "gpu-b300-sxm", "uk-south1"), InfinibandFabric("us-central1-a", "gpu-h200-sxm", "us-central1"), InfinibandFabric("us-central1-b", "gpu-b200-sxm", "us-central1"), ] @@ -81,7 +90,7 @@ def get( platforms = list_platforms(sdk, project_id).items for platform in platforms: logger.info("Processing %s/%s", region, platform.metadata.name) - gpu = get_gpu_info(platform.metadata.name) + gpu = get_gpu_info(platform) for preset in platform.spec.presets: for spot in [False] + ( [True] if platform.status.allowed_for_preemptibles else [] @@ -149,13 +158,17 @@ def list_platforms(sdk: SDK, project_id: str) -> ListPlatformsResponse: return PlatformServiceClient(sdk).list(req, per_retry_timeout=TIMEOUT).wait() -def get_gpu_info(platform: str) -> AcceleratorInfo | None: - m = re.match(r"gpu-([^-]+)-", platform) +def get_gpu_info(platform: Platform) -> AcceleratorInfo | None: + m = re.match(r"gpu-([^-]+)", platform.metadata.name) if m is None: return None gpu_name = m.group(1) + gpu_name = GPU_NAME_OVERRIDES.get(gpu_name, gpu_name) accelerator_info = find_accelerators(names=[gpu_name], vendors=[AcceleratorVendor.NVIDIA]) - if len(accelerator_info) != 1: + if ( + len(accelerator_info) != 1 + or accelerator_info[0].memory != platform.spec.gpu_memory_gigabytes + ): return None return accelerator_info[0] diff --git a/src/integrity_tests/test_nebius.py b/src/integrity_tests/test_nebius.py index 86f504c..81131db 100644 --- a/src/integrity_tests/test_nebius.py +++ b/src/integrity_tests/test_nebius.py @@ -12,7 +12,7 @@ def data_rows(catalog_dir: Path) -> list[dict]: return list(csv.DictReader(f)) -@pytest.mark.parametrize("gpu", ["L40S", "H100", "H200", "B200", ""]) +@pytest.mark.parametrize("gpu", ["RTXPRO6000", "L40S", "H100", "H200", "B200", ""]) def test_gpu_present(gpu: str, data_rows: list[dict]): assert gpu in map(itemgetter("gpu_name"), data_rows)