|
26 | 26 |
|
27 | 27 | class TasksV1(TasksAPI): |
28 | 28 | @openml.utils.thread_safe_if_oslo_installed |
29 | | - def get_task( |
| 29 | + def get( |
30 | 30 | self, |
31 | 31 | task_id: int, |
32 | 32 | download_splits: bool = False, # noqa: FBT001, FBT002 |
@@ -477,7 +477,7 @@ def get_tasks( |
477 | 477 | ) -> list[OpenMLTask]: |
478 | 478 | """Download tasks. |
479 | 479 |
|
480 | | - This function iterates :meth:`openml.tasks.get_task`. |
| 480 | + This function iterates :meth:`openml.tasks.get`. |
481 | 481 |
|
482 | 482 | Parameters |
483 | 483 | ---------- |
@@ -511,7 +511,7 @@ def get_tasks( |
511 | 511 | tasks = [] |
512 | 512 | for task_id in task_ids: |
513 | 513 | tasks.append( |
514 | | - self.get_task( |
| 514 | + self.get( |
515 | 515 | task_id, download_data=download_data, download_qualities=download_qualities |
516 | 516 | ) |
517 | 517 | ) |
@@ -606,14 +606,20 @@ def delete_task(self, task_id: int) -> bool: |
606 | 606 |
|
607 | 607 | class TasksV2(TasksAPI): |
608 | 608 | @openml.utils.thread_safe_if_oslo_installed |
609 | | - def get_task( |
| 609 | + def get( |
610 | 610 | self, |
611 | 611 | task_id: int, |
| 612 | + download_splits: bool = False, # noqa: FBT001, FBT002 |
612 | 613 | **get_dataset_kwargs: Any, |
613 | 614 | ) -> OpenMLTask: |
614 | 615 | if not isinstance(task_id, int): |
615 | 616 | raise TypeError(f"Task id should be integer, is {type(task_id)}") |
616 | 617 |
|
| 618 | + if download_splits: |
| 619 | + warnings.warn( |
| 620 | + "`download_splits` is not yet supported in the v2 API and will be ignored.", |
| 621 | + stacklevel=2, |
| 622 | + ) |
617 | 623 | task = self._get_task_description(task_id) |
618 | 624 | dataset = get_dataset(task.dataset_id, **get_dataset_kwargs) # Shrivaths work |
619 | 625 | # List of class labels available in dataset description |
@@ -667,3 +673,40 @@ def _create_task_from_json(self, task_json: dict) -> OpenMLTask: |
667 | 673 | }[task_type_id] |
668 | 674 |
|
669 | 675 | return cls(**common_kwargs) |
| 676 | + |
| 677 | + def list_task_types(self) -> list[dict[str, str | int | None]]: |
| 678 | + response = self._http.get("tasktype") |
| 679 | + payload = response.json() |
| 680 | + |
| 681 | + return [ |
| 682 | + { |
| 683 | + "id": int(tt["id"]), |
| 684 | + "name": tt["name"], |
| 685 | + "description": tt["description"] or None, |
| 686 | + "creator": tt.get("creator"), |
| 687 | + } |
| 688 | + for tt in payload["task_types"]["task_type"] |
| 689 | + ] |
| 690 | + |
| 691 | + def get_task_type(self, task_type_id: int) -> dict[str, Any]: |
| 692 | + if not isinstance(task_type_id, int): |
| 693 | + raise TypeError("task_type_id must be int") |
| 694 | + |
| 695 | + response = self._http.get(f"tasktype/{task_type_id}") |
| 696 | + tt = response.json()["task_type"] |
| 697 | + |
| 698 | + return { |
| 699 | + "id": int(tt["id"]), |
| 700 | + "name": tt["name"], |
| 701 | + "description": tt.get("description"), |
| 702 | + "creator": tt.get("creator", []), |
| 703 | + "creation_date": tt.get("creation_date"), |
| 704 | + "inputs": [ |
| 705 | + { |
| 706 | + "name": i["name"], |
| 707 | + "required": i.get("requirement") == "required", |
| 708 | + "data_type": i.get("data_type"), |
| 709 | + } |
| 710 | + for i in tt.get("input", []) |
| 711 | + ], |
| 712 | + } |
0 commit comments