Skip to content

Commit c2b9e1a

Browse files
committed
commiting latest cahnges
1 parent 510b286 commit c2b9e1a

5 files changed

Lines changed: 380 additions & 9 deletions

File tree

openml/_api/resources/base.py

Lines changed: 143 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from __future__ import annotations
22

33
from abc import ABC, abstractmethod
4-
from typing import TYPE_CHECKING
4+
from typing import TYPE_CHECKING, Any
55

66
if TYPE_CHECKING:
7+
from build.lib.openml.tasks.task import TaskType
78
from requests import Response
89

910
from openml._api.http import HTTPClient
@@ -22,10 +23,148 @@ def get(self, dataset_id: int) -> OpenMLDataset | tuple[OpenMLDataset, Response]
2223

2324

2425
class TasksAPI(ResourceAPI, ABC):
26+
# Single task retrieval (V1 and V2)
2527
@abstractmethod
2628
def get(
2729
self,
2830
task_id: int,
29-
*,
30-
return_response: bool = False,
31-
) -> OpenMLTask | tuple[OpenMLTask, Response]: ...
31+
download_splits: bool = False, # noqa: FBT001, FBT002
32+
**get_dataset_kwargs: Any,
33+
) -> OpenMLTask:
34+
"""
35+
API v1:
36+
GET /task/{task_id}
37+
38+
API v2:
39+
GET /tasks/{task_id}
40+
"""
41+
...
42+
43+
# # Multiple task retrieval (V1 only)
44+
# @abstractmethod
45+
# def get_tasks(
46+
# self,
47+
# task_ids: list[int],
48+
# **kwargs: Any,
49+
# ) -> list[OpenMLTask]:
50+
# """
51+
# Retrieve multiple tasks.
52+
53+
# API v1:
54+
# Implemented via repeated GET /task/{task_id}
55+
56+
# API v2:
57+
# Not currently supported
58+
59+
# Parameters
60+
# ----------
61+
# task_ids : list[int]
62+
63+
# Returns
64+
# -------
65+
# list[OpenMLTask]
66+
# """
67+
# ...
68+
69+
# # Task listing (V1 only)
70+
# @abstractmethod
71+
# def list_tasks(
72+
# self,
73+
# *,
74+
# task_type: TaskType | None = None,
75+
# offset: int | None = None,
76+
# size: int | None = None,
77+
# **filters: Any,
78+
# ):
79+
# """
80+
# List tasks with filters.
81+
82+
# API v1:
83+
# GET /task/list
84+
85+
# API v2:
86+
# Not available.
87+
88+
# Returns
89+
# -------
90+
# pandas.DataFrame
91+
# """
92+
# ...
93+
94+
# # Task creation (V1 only)
95+
# @abstractmethod
96+
# def create_task(
97+
# self,
98+
# task_type: TaskType,
99+
# dataset_id: int,
100+
# estimation_procedure_id: int,
101+
# **kwargs: Any,
102+
# ) -> OpenMLTask:
103+
# """
104+
# Create a new task.
105+
106+
# API v1:
107+
# POST /task
108+
109+
# API v2:
110+
# Not supported.
111+
112+
# Returns
113+
# -------
114+
# OpenMLTask
115+
# """
116+
# ...
117+
118+
# # Task deletion (V1 only)
119+
# @abstractmethod
120+
# def delete_task(self, task_id: int) -> bool:
121+
# """
122+
# Delete a task.
123+
124+
# API v1:
125+
# DELETE /task/{task_id}
126+
127+
# API v2:
128+
# Not supported.
129+
130+
# Returns
131+
# -------
132+
# bool
133+
# """
134+
# ...
135+
136+
# # Task type listing (V2 only)
137+
# @abstractmethod
138+
# def list_task_types(self) -> list[dict[str, Any]]:
139+
# """
140+
# List all task types.
141+
142+
# API v2:
143+
# GET /tasktype/list
144+
145+
# API v1:
146+
# Not available.
147+
148+
# Returns
149+
# -------
150+
# list[dict]
151+
# """
152+
# ...
153+
154+
# # Task type retrieval (V2 only)
155+
# @abstractmethod
156+
# def get_task_type(self, task_type_id: int) -> dict[str, Any]:
157+
# """
158+
# Retrieve a single task type.
159+
160+
# API v2:
161+
# GET /tasktype/{task_type_id}
162+
163+
# API v1:
164+
# Not available.
165+
166+
# Returns
167+
# -------
168+
# dict
169+
# """
170+
# ...

openml/_api/resources/tasks.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
class TasksV1(TasksAPI):
2828
@openml.utils.thread_safe_if_oslo_installed
29-
def get_task(
29+
def get(
3030
self,
3131
task_id: int,
3232
download_splits: bool = False, # noqa: FBT001, FBT002
@@ -477,7 +477,7 @@ def get_tasks(
477477
) -> list[OpenMLTask]:
478478
"""Download tasks.
479479
480-
This function iterates :meth:`openml.tasks.get_task`.
480+
This function iterates :meth:`openml.tasks.get`.
481481
482482
Parameters
483483
----------
@@ -511,7 +511,7 @@ def get_tasks(
511511
tasks = []
512512
for task_id in task_ids:
513513
tasks.append(
514-
self.get_task(
514+
self.get(
515515
task_id, download_data=download_data, download_qualities=download_qualities
516516
)
517517
)
@@ -606,14 +606,20 @@ def delete_task(self, task_id: int) -> bool:
606606

607607
class TasksV2(TasksAPI):
608608
@openml.utils.thread_safe_if_oslo_installed
609-
def get_task(
609+
def get(
610610
self,
611611
task_id: int,
612+
download_splits: bool = False, # noqa: FBT001, FBT002
612613
**get_dataset_kwargs: Any,
613614
) -> OpenMLTask:
614615
if not isinstance(task_id, int):
615616
raise TypeError(f"Task id should be integer, is {type(task_id)}")
616617

618+
if download_splits:
619+
warnings.warn(
620+
"`download_splits` is not yet supported in the v2 API and will be ignored.",
621+
stacklevel=2,
622+
)
617623
task = self._get_task_description(task_id)
618624
dataset = get_dataset(task.dataset_id, **get_dataset_kwargs) # Shrivaths work
619625
# List of class labels available in dataset description
@@ -667,3 +673,40 @@ def _create_task_from_json(self, task_json: dict) -> OpenMLTask:
667673
}[task_type_id]
668674

669675
return cls(**common_kwargs)
676+
677+
def list_task_types(self) -> list[dict[str, str | int | None]]:
678+
response = self._http.get("tasktype")
679+
payload = response.json()
680+
681+
return [
682+
{
683+
"id": int(tt["id"]),
684+
"name": tt["name"],
685+
"description": tt["description"] or None,
686+
"creator": tt.get("creator"),
687+
}
688+
for tt in payload["task_types"]["task_type"]
689+
]
690+
691+
def get_task_type(self, task_type_id: int) -> dict[str, Any]:
692+
if not isinstance(task_type_id, int):
693+
raise TypeError("task_type_id must be int")
694+
695+
response = self._http.get(f"tasktype/{task_type_id}")
696+
tt = response.json()["task_type"]
697+
698+
return {
699+
"id": int(tt["id"]),
700+
"name": tt["name"],
701+
"description": tt.get("description"),
702+
"creator": tt.get("creator", []),
703+
"creation_date": tt.get("creation_date"),
704+
"inputs": [
705+
{
706+
"name": i["name"],
707+
"required": i.get("requirement") == "required",
708+
"data_type": i.get("data_type"),
709+
}
710+
for i in tt.get("input", [])
711+
],
712+
}

openml/tasks/functions.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,6 @@ def __list_tasks(api_call: str) -> pd.DataFrame: # noqa: C901, PLR0912
340340
return pd.DataFrame.from_dict(tasks, orient="index")
341341

342342

343-
# /tasktype/list
344343
def get_tasks(
345344
task_ids: list[int],
346345
download_data: bool | None = None,

openml/tasks/task.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from enum import Enum
99
from pathlib import Path
1010
from typing import TYPE_CHECKING, Any, Sequence
11+
from attr import dataclass
1112
from typing_extensions import TypedDict
1213

1314
import openml._api_calls

0 commit comments

Comments
 (0)