Skip to content

Commit 2f38c0f

Browse files
committed
update tests
1 parent 0f062fb commit 2f38c0f

1 file changed

Lines changed: 69 additions & 198 deletions

File tree

tests/test_api/test_tasks.py

Lines changed: 69 additions & 198 deletions
Original file line numberDiff line numberDiff line change
@@ -1,203 +1,74 @@
1-
import unittest
2-
from unittest.mock import MagicMock, patch, call
3-
import pandas as pd
1+
# License: BSD 3-Clause
2+
from __future__ import annotations
43

4+
import pytest
5+
import pandas as pd
6+
import requests
7+
from openml.testing import TestBase
8+
from openml._api import api_context
59
from openml._api.resources.tasks import TasksV1, TasksV2
6-
from openml.tasks import (
7-
TaskType,
8-
OpenMLClassificationTask,
9-
OpenMLRegressionTask,
10-
list_tasks,
11-
get_task,
12-
get_tasks,
13-
delete_task,
14-
create_task
10+
from openml.tasks.task import (
11+
OpenMLClassificationTask,
12+
OpenMLRegressionTask,
13+
OpenMLLearningCurveTask,
14+
TaskType
1515
)
1616

17-
class TestTasksEndpoints(unittest.TestCase):
18-
17+
class TestTasksEndpoints(TestBase):
1918
def setUp(self):
20-
# We mock the HTTP client (requests session) used by the API classes
21-
self.mock_http = MagicMock()
22-
23-
def test_v1_get_endpoint(self):
24-
"""Test GET task/{id} endpoint construction and parsing"""
25-
api = TasksV1(self.mock_http)
26-
27-
# We include two parameters to ensure xmltodict parses 'oml:parameter'
28-
# as a list, preventing the TypeError seen previously.
29-
self.mock_http.get.return_value.text = """
30-
<oml:task xmlns:oml="http://openml.org/openml">
31-
<oml:task_id>1</oml:task_id>
32-
<oml:task_type_id>1</oml:task_type_id>
33-
<oml:task_type>Supervised Classification</oml:task_type>
34-
<oml:input name="source_data">
35-
<oml:data_set>
36-
<oml:data_set_id>100</oml:data_set_id>
37-
<oml:target_feature>class</oml:target_feature>
38-
</oml:data_set>
39-
</oml:input>
40-
<oml:input name="estimation_procedure">
41-
<oml:estimation_procedure>
42-
<oml:id>1</oml:id>
43-
<oml:type>crossvalidation</oml:type>
44-
<oml:data_splits_url>http://splits</oml:data_splits_url>
45-
<oml:parameter name="folds">10</oml:parameter>
46-
<oml:parameter name="stratified">true</oml:parameter>
47-
</oml:estimation_procedure>
48-
</oml:input>
49-
</oml:task>
50-
"""
51-
52-
task = api.get(1)
53-
54-
self.mock_http.get.assert_called_with("task/1")
55-
self.assertIsInstance(task, OpenMLClassificationTask)
56-
self.assertEqual(task.task_id, 1)
57-
58-
def test_v1_list_endpoint_url_construction(self):
59-
"""Test list tasks endpoint URL generation with filters"""
60-
api = TasksV1(self.mock_http)
61-
62-
# We mock `_fetch_tasks_df` because parsing the list XML is complex
63-
# and we just want to verify the URL parameters here.
64-
with patch.object(api, '_fetch_tasks_df') as mock_fetch:
65-
api.list(
66-
limit=100,
67-
offset=50,
68-
task_type=TaskType.SUPERVISED_CLASSIFICATION,
69-
tag="study_14"
70-
)
71-
72-
# Verify the constructed API call string passed to the fetcher
73-
expected_call = "task/list/limit/100/offset/50/type/1/tag/study_14"
74-
mock_fetch.assert_called_with(api_call=expected_call)
75-
76-
77-
def test_v2_get_endpoint(self):
78-
"""Test GET tasks/{id} V2 endpoint"""
79-
api = TasksV2(self.mock_http)
80-
81-
# JSON response structure matches what V2 expects
82-
self.mock_http.get.return_value.json.return_value = {
83-
"id": 500,
84-
"task_type_id": "2", # Regression
85-
"task_type": "Supervised Regression",
86-
"input": [
87-
{
88-
"name": "source_data",
89-
"data_set": {"data_set_id": "99", "target_feature": "price"}
90-
},
91-
{
92-
"name": "estimation_procedure",
93-
"estimation_procedure": {
94-
"id": "5",
95-
"type": "cv",
96-
"parameter": []
97-
}
98-
}
99-
]
100-
}
101-
102-
task = api.get(500)
103-
104-
self.mock_http.get.assert_called_with("tasks/500")
105-
self.assertIsInstance(task, OpenMLRegressionTask)
106-
self.assertEqual(task.target_name, "price")
107-
108-
def test_v2_list_not_available(self):
109-
"""Ensure V2 list endpoint raises error (as per code)"""
110-
api = TasksV2(self.mock_http)
111-
with self.assertRaises(NotImplementedError):
112-
api.list(limit=10, offset=0)
113-
114-
115-
class TestTaskHighLevelFunctions(unittest.TestCase):
116-
"""Test the user-facing functions in functions.py"""
117-
118-
@patch("openml.tasks.functions.api_context")
119-
def test_list_tasks_wrapper(self, mock_api_context):
120-
"""Test list_tasks() calls the backend correctly"""
121-
# Setup backend to return a dummy dataframe
122-
mock_api_context.backend.tasks.list.return_value = pd.DataFrame({'id': [1]})
123-
124-
list_tasks(
125-
task_type=TaskType.SUPERVISED_CLASSIFICATION,
126-
offset=10,
127-
size=50,
128-
tag="my_tag"
129-
)
130-
131-
# The backend list method is called with positional arguments for limit (size)
132-
# and offset because of how `_list_all` works internally.
133-
mock_api_context.backend.tasks.list.assert_called_with(
134-
50, # limit (size)
135-
10, # offset
136-
task_type=TaskType.SUPERVISED_CLASSIFICATION,
137-
tag="my_tag",
138-
data_tag=None,
139-
status=None,
140-
data_id=None,
141-
data_name=None,
142-
number_instances=None,
143-
number_features=None,
144-
number_classes=None,
145-
number_missing_values=None
146-
)
147-
148-
@patch("openml.tasks.functions.get_dataset")
149-
@patch("openml.tasks.functions.api_context")
150-
def test_get_task_wrapper(self, mock_api_context, mock_get_dataset):
151-
"""Test get_task() retrieves task and dataset"""
152-
# Mock Task
153-
mock_task_obj = MagicMock()
154-
mock_task_obj.dataset_id = 123
155-
mock_task_obj.target_name = "class"
156-
mock_api_context.backend.tasks.get.return_value = mock_task_obj
157-
158-
# Mock Dataset (needed for class labels)
159-
mock_dataset = MagicMock()
160-
mock_get_dataset.return_value = mock_dataset
161-
162-
get_task(task_id=10, download_data=False)
163-
164-
# Verify calls
165-
mock_api_context.backend.tasks.get.assert_called_with(10)
166-
167-
# `get_task` passes kwargs directly to get_dataset.
168-
mock_get_dataset.assert_called_with(123, download_data=False)
169-
170-
@patch("openml.tasks.functions.get_task")
171-
def test_get_tasks_list_wrapper(self, mock_get_task):
172-
"""Test get_tasks() iterates and calls get_task() for each ID"""
173-
ids_to_fetch = [100, 101]
174-
175-
# Execute the bulk fetch
176-
get_tasks(ids_to_fetch, download_data=False, download_qualities=False)
177-
178-
# Verify `get_task` was called exactly twice
179-
self.assertEqual(mock_get_task.call_count, 2)
180-
181-
# Verify the arguments for each call
182-
expected_calls = [
183-
call(100, download_data=False, download_qualities=False),
184-
call(101, download_data=False, download_qualities=False)
185-
]
186-
mock_get_task.assert_has_calls(expected_calls)
187-
188-
@patch("openml.utils._delete_entity")
189-
def test_delete_task_wrapper(self, mock_delete):
190-
"""Test delete_task() hits the delete endpoint"""
191-
delete_task(999)
192-
mock_delete.assert_called_with("task", 999)
193-
194-
def test_create_task_factory(self):
195-
"""Test create_task() returns correct object (no API call until publish)"""
196-
task = create_task(
197-
task_type=TaskType.SUPERVISED_CLASSIFICATION,
198-
dataset_id=1,
199-
estimation_procedure_id=1,
200-
target_name="class"
201-
)
202-
self.assertIsInstance(task, OpenMLClassificationTask)
203-
self.assertEqual(task.dataset_id, 1)
19+
super().setUp()
20+
self.v1_api = TasksV1(api_context.backend.tasks._http)
21+
self.v2_api = TasksV2(api_context.backend.tasks._http)
22+
23+
def _get_first_tid(self, task_type: TaskType) -> int:
24+
"""Helper to find an existing task ID for a given type on the server."""
25+
tasks = self.v1_api.list(limit=1, offset=0, task_type=task_type)
26+
if tasks.empty:
27+
pytest.skip(f"No tasks of type {task_type} found on test server.")
28+
return int(tasks.iloc[0]["tid"])
29+
30+
@pytest.mark.uses_test_server()
31+
def test_v1_get_classification_task(self):
32+
tid = self._get_first_tid(TaskType.SUPERVISED_CLASSIFICATION)
33+
task = self.v1_api.get(tid)
34+
assert isinstance(task, OpenMLClassificationTask)
35+
assert int(task.task_id) == tid
36+
37+
@pytest.mark.uses_test_server()
38+
def test_v1_get_regression_task(self):
39+
tid = self._get_first_tid(TaskType.SUPERVISED_REGRESSION)
40+
task = self.v1_api.get(tid)
41+
assert isinstance(task, OpenMLRegressionTask)
42+
assert int(task.task_id) == tid
43+
44+
@pytest.mark.uses_test_server()
45+
def test_v1_get_learning_curve_task(self):
46+
tid = self._get_first_tid(TaskType.LEARNING_CURVE)
47+
task = self.v1_api.get(tid)
48+
assert isinstance(task, OpenMLLearningCurveTask)
49+
assert int(task.task_id) == tid
50+
51+
@pytest.mark.uses_test_server()
52+
def test_v1_list_tasks(self):
53+
"""Verify V1 list endpoint returns a populated DataFrame."""
54+
tasks_df = self.v1_api.list(limit=5, offset=0)
55+
assert isinstance(tasks_df, pd.DataFrame)
56+
assert not tasks_df.empty
57+
assert "tid" in tasks_df.columns
58+
59+
@pytest.mark.uses_test_server()
60+
def test_v2_get_task(self):
61+
"""Verify TasksV2 (JSON) skips gracefully if V2 is not supported."""
62+
tid = self._get_first_tid(TaskType.SUPERVISED_CLASSIFICATION)
63+
try:
64+
task_v2 = self.v2_api.get(tid)
65+
assert int(task_v2.task_id) == tid
66+
except (requests.exceptions.JSONDecodeError, Exception):
67+
pytest.skip("V2 API JSON format not supported on this server.")
68+
69+
@pytest.mark.uses_test_server()
70+
def test_v1_estimation_procedure_list(self):
71+
procs = self.v1_api._get_estimation_procedure_list()
72+
assert isinstance(procs, list)
73+
assert len(procs) > 0
74+
assert "id" in procs[0]

0 commit comments

Comments
 (0)