Skip to content

Commit d755d4c

Browse files
committed
Updated OpenML classes which require repr
1 parent a76333e commit d755d4c

4 files changed

Lines changed: 40 additions & 78 deletions

File tree

openml/base.py

Lines changed: 3 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,22 @@
11
# License: BSD 3-Clause
22
from __future__ import annotations
33

4-
import re
54
import webbrowser
65
from abc import ABC, abstractmethod
7-
from typing import Iterable, Sequence
6+
from typing import Sequence
87

98
import xmltodict
109

1110
import openml._api_calls
1211
import openml.config
12+
from openml.utils import ReprMixin
1313

1414
from .utils import _get_rest_api_type_alias, _tag_openml_base
1515

1616

17-
class OpenMLBase(ABC):
17+
class OpenMLBase(ReprMixin, ABC):
1818
"""Base object for functionality that is shared across entities."""
1919

20-
def __repr__(self) -> str:
21-
body_fields = self._get_repr_body_fields()
22-
return self._apply_repr_template(body_fields)
23-
2420
@property
2521
@abstractmethod
2622
def id(self) -> int | None:
@@ -60,34 +56,6 @@ def _get_repr_body_fields(self) -> Sequence[tuple[str, str | int | list[str] | N
6056
"""
6157
# Should be implemented in the base class.
6258

63-
def _apply_repr_template(
64-
self,
65-
body_fields: Iterable[tuple[str, str | int | list[str] | None]],
66-
) -> str:
67-
"""Generates the header and formats the body for string representation of the object.
68-
69-
Parameters
70-
----------
71-
body_fields: List[Tuple[str, str]]
72-
A list of (name, value) pairs to display in the body of the __repr__.
73-
"""
74-
# We add spaces between capitals, e.g. ClassificationTask -> Classification Task
75-
name_with_spaces = re.sub(
76-
r"(\w)([A-Z])",
77-
r"\1 \2",
78-
self.__class__.__name__[len("OpenML") :],
79-
)
80-
header_text = f"OpenML {name_with_spaces}"
81-
header = f"{header_text}\n{'=' * len(header_text)}\n"
82-
83-
_body_fields: list[tuple[str, str | int | list[str]]] = [
84-
(k, "None" if v is None else v) for k, v in body_fields
85-
]
86-
longest_field_name_length = max(len(name) for name, _ in _body_fields)
87-
field_line_format = f"{{:.<{longest_field_name_length}}}: {{}}"
88-
body = "\n".join(field_line_format.format(name, value) for name, value in _body_fields)
89-
return header + body
90-
9159
@abstractmethod
9260
def _to_dict(self) -> dict[str, dict]:
9361
"""Creates a dictionary representation of self.

openml/datasets/data_feature.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
if TYPE_CHECKING:
77
from IPython.lib import pretty
88

9+
from openml.utils import ReprMixin
910

10-
class OpenMLDataFeature:
11+
12+
class OpenMLDataFeature(ReprMixin):
1113
"""
1214
Data Feature (a.k.a. Attribute) object.
1315
@@ -74,8 +76,20 @@ def __init__( # noqa: PLR0913
7476
self.number_missing_values = number_missing_values
7577
self.ontologies = ontologies
7678

77-
def __repr__(self) -> str:
78-
return "[%d - %s (%s)]" % (self.index, self.name, self.data_type)
79+
def _get_repr_body_fields(self) -> Sequence[tuple[str, str | int | list[str] | None]]:
80+
"""Collect all information to display in the __repr__ body."""
81+
fields: dict[str, int | str | None] = {
82+
"Index": self.index,
83+
"Name": self.name,
84+
"Data Type": self.data_type,
85+
}
86+
87+
order = [
88+
"Index",
89+
"Name",
90+
"Data Type",
91+
]
92+
return [(key, fields[key]) for key in order if key in fields]
7993

8094
def __eq__(self, other: Any) -> bool:
8195
return isinstance(other, OpenMLDataFeature) and self.__dict__ == other.__dict__

openml/setups/setup.py

Lines changed: 13 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
# License: BSD 3-Clause
22
from __future__ import annotations
33

4-
from typing import Any
4+
from typing import Any, Sequence
55

66
import openml.config
77
import openml.flows
8+
from openml.utils import ReprMixin
89

910

10-
class OpenMLSetup:
11+
class OpenMLSetup(ReprMixin):
1112
"""Setup object (a.k.a. Configuration).
1213
1314
Parameters
@@ -43,30 +44,21 @@ def _to_dict(self) -> dict[str, Any]:
4344
else None,
4445
}
4546

46-
def __repr__(self) -> str:
47-
header = "OpenML Setup"
48-
header = f"{header}\n{'=' * len(header)}\n"
49-
50-
fields = {
47+
def _get_repr_body_fields(self) -> Sequence[tuple[str, str | int | list[str] | None]]:
48+
"""Collect all information to display in the __repr__ body."""
49+
fields: dict[str, int | str | None] = {
5150
"Setup ID": self.setup_id,
5251
"Flow ID": self.flow_id,
5352
"Flow URL": openml.flows.OpenMLFlow.url_for_id(self.flow_id),
54-
"# of Parameters": (
55-
len(self.parameters) if self.parameters is not None else float("nan")
56-
),
53+
"# of Parameters": (len(self.parameters) if self.parameters is not None else "nan"),
5754
}
5855

5956
# determines the order in which the information will be printed
6057
order = ["Setup ID", "Flow ID", "Flow URL", "# of Parameters"]
61-
_fields = [(key, fields[key]) for key in order if key in fields]
62-
63-
longest_field_name_length = max(len(name) for name, _ in _fields)
64-
field_line_format = f"{{:.<{longest_field_name_length}}}: {{}}"
65-
body = "\n".join(field_line_format.format(name, value) for name, value in _fields)
66-
return header + body
58+
return [(key, fields[key]) for key in order if key in fields]
6759

6860

69-
class OpenMLParameter:
61+
class OpenMLParameter(ReprMixin):
7062
"""Parameter object (used in setup).
7163
7264
Parameters
@@ -123,11 +115,9 @@ def _to_dict(self) -> dict[str, Any]:
123115
"value": self.value,
124116
}
125117

126-
def __repr__(self) -> str:
127-
header = "OpenML Parameter"
128-
header = f"{header}\n{'=' * len(header)}\n"
129-
130-
fields = {
118+
def _get_repr_body_fields(self) -> Sequence[tuple[str, str | int | list[str] | None]]:
119+
"""Collect all information to display in the __repr__ body."""
120+
fields: dict[str, int | str | None] = {
131121
"ID": self.id,
132122
"Flow ID": self.flow_id,
133123
# "Flow Name": self.flow_name,
@@ -156,9 +146,4 @@ def __repr__(self) -> str:
156146
parameter_default,
157147
parameter_value,
158148
]
159-
_fields = [(key, fields[key]) for key in order if key in fields]
160-
161-
longest_field_name_length = max(len(name) for name, _ in _fields)
162-
field_line_format = f"{{:.<{longest_field_name_length}}}: {{}}"
163-
body = "\n".join(field_line_format.format(name, value) for name, value in _fields)
164-
return header + body
149+
return [(key, fields[key]) for key in order if key in fields]

openml/tasks/split.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
import pickle
55
from collections import OrderedDict
66
from pathlib import Path
7-
from typing import Any
7+
from typing import Any, Sequence
88
from typing_extensions import NamedTuple
99

1010
import arff # type: ignore
1111
import numpy as np
1212

13+
from openml.utils import ReprMixin
14+
1315

1416
class Split(NamedTuple):
1517
"""A single split of a dataset."""
@@ -18,7 +20,7 @@ class Split(NamedTuple):
1820
test: np.ndarray
1921

2022

21-
class OpenMLSplit:
23+
class OpenMLSplit(ReprMixin):
2224
"""OpenML Split object.
2325
2426
This class manages train-test splits for a dataset across multiple
@@ -63,10 +65,8 @@ def __init__(
6365
self.folds = len(self.split[0])
6466
self.samples = len(self.split[0][0])
6567

66-
def __repr__(self) -> str:
67-
header = "OpenML Split"
68-
header = f"{header}\n{'=' * len(header)}\n"
69-
68+
def _get_repr_body_fields(self) -> Sequence[tuple[str, str | int | list[str] | None]]:
69+
"""Collect all information to display in the __repr__ body."""
7070
fields = {
7171
"Name": self.name,
7272
"Description": (
@@ -79,12 +79,7 @@ def __repr__(self) -> str:
7979

8080
order = ["Name", "Description", "Repeats", "Folds", "Samples"]
8181

82-
_fields = [(key, fields[key]) for key in order if key in fields]
83-
84-
longest_field_name_length = max(len(name) for name, _ in _fields)
85-
field_line_format = f"{{:.<{longest_field_name_length}}}: {{}}"
86-
body = "\n".join(field_line_format.format(name, value) for name, value in _fields)
87-
return header + body
82+
return [(key, fields[key]) for key in order if key in fields]
8883

8984
def __eq__(self, other: Any) -> bool:
9085
if (

0 commit comments

Comments
 (0)