1- import unittest
2- from unittest .mock import MagicMock , patch , call
3- import pandas as pd
1+ # License: BSD 3-Clause
2+ from __future__ import annotations
43
4+ import pytest
5+ import pandas as pd
6+ import requests
7+ from openml .testing import TestBase
8+ from openml ._api import api_context
59from openml ._api .resources .tasks import TasksV1 , TasksV2
6- from openml .tasks import (
7- TaskType ,
8- OpenMLClassificationTask ,
9- OpenMLRegressionTask ,
10- list_tasks ,
11- get_task ,
12- get_tasks ,
13- delete_task ,
14- create_task
10+ from openml .tasks .task import (
11+ OpenMLClassificationTask ,
12+ OpenMLRegressionTask ,
13+ OpenMLLearningCurveTask ,
14+ TaskType
1515)
1616
17- class TestTasksEndpoints (unittest .TestCase ):
18-
17+ class TestTasksEndpoints (TestBase ):
1918 def setUp (self ):
20- # We mock the HTTP client (requests session) used by the API classes
21- self .mock_http = MagicMock ()
22-
23- def test_v1_get_endpoint (self ):
24- """Test GET task/{id} endpoint construction and parsing"""
25- api = TasksV1 (self .mock_http )
26-
27- # We include two parameters to ensure xmltodict parses 'oml:parameter'
28- # as a list, preventing the TypeError seen previously.
29- self .mock_http .get .return_value .text = """
30- <oml:task xmlns:oml="http://openml.org/openml">
31- <oml:task_id>1</oml:task_id>
32- <oml:task_type_id>1</oml:task_type_id>
33- <oml:task_type>Supervised Classification</oml:task_type>
34- <oml:input name="source_data">
35- <oml:data_set>
36- <oml:data_set_id>100</oml:data_set_id>
37- <oml:target_feature>class</oml:target_feature>
38- </oml:data_set>
39- </oml:input>
40- <oml:input name="estimation_procedure">
41- <oml:estimation_procedure>
42- <oml:id>1</oml:id>
43- <oml:type>crossvalidation</oml:type>
44- <oml:data_splits_url>http://splits</oml:data_splits_url>
45- <oml:parameter name="folds">10</oml:parameter>
46- <oml:parameter name="stratified">true</oml:parameter>
47- </oml:estimation_procedure>
48- </oml:input>
49- </oml:task>
50- """
51-
52- task = api .get (1 )
53-
54- self .mock_http .get .assert_called_with ("task/1" )
55- self .assertIsInstance (task , OpenMLClassificationTask )
56- self .assertEqual (task .task_id , 1 )
57-
58- def test_v1_list_endpoint_url_construction (self ):
59- """Test list tasks endpoint URL generation with filters"""
60- api = TasksV1 (self .mock_http )
61-
62- # We mock `_fetch_tasks_df` because parsing the list XML is complex
63- # and we just want to verify the URL parameters here.
64- with patch .object (api , '_fetch_tasks_df' ) as mock_fetch :
65- api .list (
66- limit = 100 ,
67- offset = 50 ,
68- task_type = TaskType .SUPERVISED_CLASSIFICATION ,
69- tag = "study_14"
70- )
71-
72- # Verify the constructed API call string passed to the fetcher
73- expected_call = "task/list/limit/100/offset/50/type/1/tag/study_14"
74- mock_fetch .assert_called_with (api_call = expected_call )
75-
76-
77- def test_v2_get_endpoint (self ):
78- """Test GET tasks/{id} V2 endpoint"""
79- api = TasksV2 (self .mock_http )
80-
81- # JSON response structure matches what V2 expects
82- self .mock_http .get .return_value .json .return_value = {
83- "id" : 500 ,
84- "task_type_id" : "2" , # Regression
85- "task_type" : "Supervised Regression" ,
86- "input" : [
87- {
88- "name" : "source_data" ,
89- "data_set" : {"data_set_id" : "99" , "target_feature" : "price" }
90- },
91- {
92- "name" : "estimation_procedure" ,
93- "estimation_procedure" : {
94- "id" : "5" ,
95- "type" : "cv" ,
96- "parameter" : []
97- }
98- }
99- ]
100- }
101-
102- task = api .get (500 )
103-
104- self .mock_http .get .assert_called_with ("tasks/500" )
105- self .assertIsInstance (task , OpenMLRegressionTask )
106- self .assertEqual (task .target_name , "price" )
107-
108- def test_v2_list_not_available (self ):
109- """Ensure V2 list endpoint raises error (as per code)"""
110- api = TasksV2 (self .mock_http )
111- with self .assertRaises (NotImplementedError ):
112- api .list (limit = 10 , offset = 0 )
113-
114-
115- class TestTaskHighLevelFunctions (unittest .TestCase ):
116- """Test the user-facing functions in functions.py"""
117-
118- @patch ("openml.tasks.functions.api_context" )
119- def test_list_tasks_wrapper (self , mock_api_context ):
120- """Test list_tasks() calls the backend correctly"""
121- # Setup backend to return a dummy dataframe
122- mock_api_context .backend .tasks .list .return_value = pd .DataFrame ({'id' : [1 ]})
123-
124- list_tasks (
125- task_type = TaskType .SUPERVISED_CLASSIFICATION ,
126- offset = 10 ,
127- size = 50 ,
128- tag = "my_tag"
129- )
130-
131- # The backend list method is called with positional arguments for limit (size)
132- # and offset because of how `_list_all` works internally.
133- mock_api_context .backend .tasks .list .assert_called_with (
134- 50 , # limit (size)
135- 10 , # offset
136- task_type = TaskType .SUPERVISED_CLASSIFICATION ,
137- tag = "my_tag" ,
138- data_tag = None ,
139- status = None ,
140- data_id = None ,
141- data_name = None ,
142- number_instances = None ,
143- number_features = None ,
144- number_classes = None ,
145- number_missing_values = None
146- )
147-
148- @patch ("openml.tasks.functions.get_dataset" )
149- @patch ("openml.tasks.functions.api_context" )
150- def test_get_task_wrapper (self , mock_api_context , mock_get_dataset ):
151- """Test get_task() retrieves task and dataset"""
152- # Mock Task
153- mock_task_obj = MagicMock ()
154- mock_task_obj .dataset_id = 123
155- mock_task_obj .target_name = "class"
156- mock_api_context .backend .tasks .get .return_value = mock_task_obj
157-
158- # Mock Dataset (needed for class labels)
159- mock_dataset = MagicMock ()
160- mock_get_dataset .return_value = mock_dataset
161-
162- get_task (task_id = 10 , download_data = False )
163-
164- # Verify calls
165- mock_api_context .backend .tasks .get .assert_called_with (10 )
166-
167- # `get_task` passes kwargs directly to get_dataset.
168- mock_get_dataset .assert_called_with (123 , download_data = False )
169-
170- @patch ("openml.tasks.functions.get_task" )
171- def test_get_tasks_list_wrapper (self , mock_get_task ):
172- """Test get_tasks() iterates and calls get_task() for each ID"""
173- ids_to_fetch = [100 , 101 ]
174-
175- # Execute the bulk fetch
176- get_tasks (ids_to_fetch , download_data = False , download_qualities = False )
177-
178- # Verify `get_task` was called exactly twice
179- self .assertEqual (mock_get_task .call_count , 2 )
180-
181- # Verify the arguments for each call
182- expected_calls = [
183- call (100 , download_data = False , download_qualities = False ),
184- call (101 , download_data = False , download_qualities = False )
185- ]
186- mock_get_task .assert_has_calls (expected_calls )
187-
188- @patch ("openml.utils._delete_entity" )
189- def test_delete_task_wrapper (self , mock_delete ):
190- """Test delete_task() hits the delete endpoint"""
191- delete_task (999 )
192- mock_delete .assert_called_with ("task" , 999 )
193-
194- def test_create_task_factory (self ):
195- """Test create_task() returns correct object (no API call until publish)"""
196- task = create_task (
197- task_type = TaskType .SUPERVISED_CLASSIFICATION ,
198- dataset_id = 1 ,
199- estimation_procedure_id = 1 ,
200- target_name = "class"
201- )
202- self .assertIsInstance (task , OpenMLClassificationTask )
203- self .assertEqual (task .dataset_id , 1 )
19+ super ().setUp ()
20+ self .v1_api = TasksV1 (api_context .backend .tasks ._http )
21+ self .v2_api = TasksV2 (api_context .backend .tasks ._http )
22+
23+ def _get_first_tid (self , task_type : TaskType ) -> int :
24+ """Helper to find an existing task ID for a given type on the server."""
25+ tasks = self .v1_api .list (limit = 1 , offset = 0 , task_type = task_type )
26+ if tasks .empty :
27+ pytest .skip (f"No tasks of type { task_type } found on test server." )
28+ return int (tasks .iloc [0 ]["tid" ])
29+
30+ @pytest .mark .uses_test_server ()
31+ def test_v1_get_classification_task (self ):
32+ tid = self ._get_first_tid (TaskType .SUPERVISED_CLASSIFICATION )
33+ task = self .v1_api .get (tid )
34+ assert isinstance (task , OpenMLClassificationTask )
35+ assert int (task .task_id ) == tid
36+
37+ @pytest .mark .uses_test_server ()
38+ def test_v1_get_regression_task (self ):
39+ tid = self ._get_first_tid (TaskType .SUPERVISED_REGRESSION )
40+ task = self .v1_api .get (tid )
41+ assert isinstance (task , OpenMLRegressionTask )
42+ assert int (task .task_id ) == tid
43+
44+ @pytest .mark .uses_test_server ()
45+ def test_v1_get_learning_curve_task (self ):
46+ tid = self ._get_first_tid (TaskType .LEARNING_CURVE )
47+ task = self .v1_api .get (tid )
48+ assert isinstance (task , OpenMLLearningCurveTask )
49+ assert int (task .task_id ) == tid
50+
51+ @pytest .mark .uses_test_server ()
52+ def test_v1_list_tasks (self ):
53+ """Verify V1 list endpoint returns a populated DataFrame."""
54+ tasks_df = self .v1_api .list (limit = 5 , offset = 0 )
55+ assert isinstance (tasks_df , pd .DataFrame )
56+ assert not tasks_df .empty
57+ assert "tid" in tasks_df .columns
58+
59+ @pytest .mark .uses_test_server ()
60+ def test_v2_get_task (self ):
61+ """Verify TasksV2 (JSON) skips gracefully if V2 is not supported."""
62+ tid = self ._get_first_tid (TaskType .SUPERVISED_CLASSIFICATION )
63+ try :
64+ task_v2 = self .v2_api .get (tid )
65+ assert int (task_v2 .task_id ) == tid
66+ except (requests .exceptions .JSONDecodeError , Exception ):
67+ pytest .skip ("V2 API JSON format not supported on this server." )
68+
69+ @pytest .mark .uses_test_server ()
70+ def test_v1_estimation_procedure_list (self ):
71+ procs = self .v1_api ._get_estimation_procedure_list ()
72+ assert isinstance (procs , list )
73+ assert len (procs ) > 0
74+ assert "id" in procs [0 ]
0 commit comments