diff --git a/src/gpuhunt/providers/vastai.py b/src/gpuhunt/providers/vastai.py index a50b397..b823fb1 100644 --- a/src/gpuhunt/providers/vastai.py +++ b/src/gpuhunt/providers/vastai.py @@ -2,6 +2,7 @@ import logging import re from collections import defaultdict +from collections.abc import Iterable from typing import Any, Literal import requests @@ -25,16 +26,16 @@ def __init__( self, extra_filters: dict[str, dict[Operators, FilterValue]] | None = None, community_cloud: bool = True, + order: Iterable[tuple[str, str]] = [("score", "desc")], ): self.extra_filters = extra_filters self.community_cloud = community_cloud + self.order = list(order) def get( self, query_filter: QueryFilter | None = None, balance_resources: bool = True ) -> list[RawCatalogItem]: - filters: dict[str, Any] = self.make_filters( - query_filter or QueryFilter(), community_cloud=self.community_cloud - ) + filters: dict[str, Any] = self.make_filters(query_filter or QueryFilter()) if self.extra_filters: for key, constraints in self.extra_filters.items(): for op, value in constraints.items(): @@ -85,10 +86,7 @@ def get( instance_offers.append(spot_offer) return instance_offers - @staticmethod - def make_filters( - q: QueryFilter, community_cloud: bool = True - ) -> dict[str, dict[Operators, FilterValue]]: + def make_filters(self, q: QueryFilter) -> dict[str, dict[Operators, FilterValue]]: filters = defaultdict(dict) if q.min_cpu is not None: filters["cpu_cores"]["gte"] = q.min_cpu @@ -132,11 +130,11 @@ def make_filters( # Datacenter offers map to Vast's "server cloud" scope. # When community_cloud is enabled, keep scope unfiltered so both # server and community offers are returned. - if not community_cloud: + if not self.community_cloud: filters["datacenter"]["eq"] = True filters["rentable"]["eq"] = True filters["rented"]["eq"] = False - filters["order"] = [["score", "desc"]] + filters["order"] = self.order return filters @staticmethod diff --git a/src/tests/providers/test_vastai.py b/src/tests/providers/test_vastai.py index ffb5f7c..8eb5473 100644 --- a/src/tests/providers/test_vastai.py +++ b/src/tests/providers/test_vastai.py @@ -3,12 +3,12 @@ def test_make_filters_defaults_to_datacenter_only(): - filters = VastAIProvider.make_filters(QueryFilter(), community_cloud=False) + filters = VastAIProvider(community_cloud=False).make_filters(QueryFilter()) assert filters["datacenter"]["eq"] is True assert "external" not in filters def test_make_filters_does_not_constrain_scope_when_community_cloud_enabled(): - filters = VastAIProvider.make_filters(QueryFilter(), community_cloud=True) + filters = VastAIProvider(community_cloud=True).make_filters(QueryFilter()) assert "datacenter" not in filters assert "external" not in filters