Skip to content

Commit 7238714

Browse files
update tests structure
1 parent 7e35f4f commit 7238714

1 file changed

Lines changed: 67 additions & 11 deletions

File tree

tests/test_api/test_estimation_procedures.py

Lines changed: 67 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,33 @@
22
from __future__ import annotations
33

44
import pytest
5+
from openml._api.config import settings
56

6-
from openml._api.runtime.core import build_backend
7-
from openml.testing import TestBase
7+
from openml._api.resources.estimation_procedures import EstimationProceduresV1, EstimationProceduresV2
8+
from openml.testing import TestAPIBase
9+
from openml._api.resources.base.fallback import FallbackProxy
810

911

10-
class TestEstimationProceduresV1(TestBase):
12+
class TestEstimationProceduresV1(TestAPIBase):
1113
"""Tests for V1 XML API implementation of estimation procedures."""
1214

1315
_multiprocess_can_split_ = True
1416

1517
def setUp(self) -> None:
1618
super().setUp()
17-
backend = build_backend('v1', strict=True)
18-
self.api = backend.estimation_procedures
19+
self.client = self._get_http_client(
20+
server=settings.api.v1.server,
21+
base_url=settings.api.v1.base_url,
22+
api_key=settings.api.v1.api_key,
23+
timeout=settings.api.v1.timeout,
24+
retries=settings.connection.retries,
25+
retry_policy=settings.connection.retry_policy,
26+
)
27+
self.resource = EstimationProceduresV1(self.client)
1928

2029
@pytest.mark.uses_test_server()
2130
def test_list(self):
22-
procedures = self.api.list()
31+
procedures = self.resource.list()
2332

2433
assert isinstance(procedures, list)
2534
assert len(procedures) > 0
@@ -28,7 +37,7 @@ def test_list(self):
2837

2938
@pytest.mark.uses_test_server()
3039
def test_get_details(self):
31-
details = self.api._get_details()
40+
details = self.resource._get_details()
3241

3342
assert isinstance(details, list)
3443
assert len(details) > 0
@@ -39,20 +48,67 @@ def test_get_details(self):
3948
assert all("task_type_id" in d for d in details)
4049

4150

42-
class TestEstimationProceduresV2(TestBase):
51+
class TestEstimationProceduresV2(TestAPIBase):
4352
"""Tests for V2 JSON API implementation of estimation procedures."""
4453

4554
_multiprocess_can_split_ = True
4655

4756
def setUp(self) -> None:
4857
super().setUp()
49-
backend = build_backend('v2', strict=True)
50-
self.api = backend.estimation_procedures
58+
self.client = self._get_http_client(
59+
server=settings.api.v2.server,
60+
base_url=settings.api.v2.base_url,
61+
api_key=settings.api.v2.api_key,
62+
timeout=settings.api.v2.timeout,
63+
retries=settings.connection.retries,
64+
retry_policy=settings.connection.retry_policy,
65+
)
66+
self.resource = EstimationProceduresV2(self.client)
5167

5268
@pytest.mark.uses_test_server()
5369
def test_list(self):
54-
procedures = self.api.list()
70+
procedures = self.resource.list()
5571

5672
assert isinstance(procedures, list)
5773
assert len(procedures) > 0
5874
assert all(isinstance(p, str) for p in procedures)
75+
76+
77+
class TestEstimationProceduresCombined(TestAPIBase):
78+
def setUp(self):
79+
super().setUp()
80+
self.v1_client = self._get_http_client(
81+
server=settings.api.v1.server,
82+
base_url=settings.api.v1.base_url,
83+
api_key=settings.api.v1.api_key,
84+
timeout=settings.api.v1.timeout,
85+
retries=settings.connection.retries,
86+
retry_policy=settings.connection.retry_policy,
87+
)
88+
self.v2_client = self._get_http_client(
89+
server=settings.api.v2.server,
90+
base_url=settings.api.v2.base_url,
91+
api_key=settings.api.v2.api_key,
92+
timeout=settings.api.v2.timeout,
93+
retries=settings.connection.retries,
94+
retry_policy=settings.connection.retry_policy,
95+
)
96+
self.resource_v1 = EstimationProceduresV1(self.v1_client)
97+
self.resource_v2 = EstimationProceduresV2(self.v2_client)
98+
self.resource_fallback = FallbackProxy(self.resource_v2, self.resource_v1)
99+
100+
@pytest.mark.uses_test_server()
101+
def test_list_matches(self):
102+
output_v1 = self.resource_v1.list()
103+
output_v2 = self.resource_v2.list()
104+
# output_v1 matches output_v2
105+
assert isinstance(output_v1, list)
106+
assert isinstance(output_v2, list)
107+
assert output_v1 == output_v2
108+
109+
@pytest.mark.uses_test_server()
110+
def test_list_fallback(self):
111+
output_fallback = self.resource_fallback.list()
112+
assert isinstance(output_fallback, list)
113+
assert len(output_fallback) > 0
114+
assert all(isinstance(p, str) for p in output_fallback)

0 commit comments

Comments
 (0)