Skip to content

Commit 20e37a2

Browse files
committed
experimental model framework
1 parent 7975eb5 commit 20e37a2

15 files changed

Lines changed: 311 additions & 2 deletions

File tree

openml/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
OpenMLSupervisedTask,
4949
OpenMLTask,
5050
)
51+
from openml._get import get
5152

5253

5354
def populate_cache(
@@ -120,4 +121,5 @@ def populate_cache(
120121
"utils",
121122
"_api_calls",
122123
"__version__",
124+
"get",
123125
]

openml/_get.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""Global get dispatch utility."""
2+
3+
# currently just a forward to models
4+
# to discuss and possibly
5+
# todo: add global get utility here
6+
# in general, e.g., datasets will not have same name as models etc
7+
from openml.models import get
8+
9+
__all__ = ["get"]

openml/base/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
"""Module of base classes."""
2+
3+
from openml.base._base import OpenMLBase
4+
from openml.base._base_pkg import _BasePkg
5+
6+
__all__ = ["_BasePkg", "OpenMLBase"]

openml/base.py renamed to openml/base/_base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010

1111
import openml._api_calls
1212
import openml.config
13-
14-
from .utils import _get_rest_api_type_alias, _tag_openml_base
13+
from openml.utils import _get_rest_api_type_alias, _tag_openml_base
1514

1615

1716
class OpenMLBase(ABC):

openml/base/_base_pkg.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
"""Base Packager class."""
2+
3+
import inspect
4+
from pathlib import Path
5+
import sys
6+
import textwrap
7+
8+
from skbase.base import BaseObject
9+
from skbase.utils.dependencies import _check_estimator_deps
10+
11+
12+
class _BasePkg(BaseObject):
13+
14+
_tags = {
15+
"python_dependencies": None,
16+
"python_version": None,
17+
# package register and manifest
18+
"pkg_id": None, # object id contained, "__multiple" if multiple
19+
"pkg_obj": "reference", # or "code"
20+
"pkg_obj_type": None, # openml API type
21+
"pkg_compression": "zlib", # compression
22+
}
23+
24+
def __init__(self):
25+
super().__init__()
26+
27+
def materialize(self):
28+
try:
29+
_check_estimator_deps(obj=self)
30+
except ModuleNotFoundError as e:
31+
# prettier message, so the reference is to the pkg_id
32+
# currently, we cannot simply pass the object name to skbase
33+
# in the error message, so this is a hack
34+
# todo: fix this in scikit-base
35+
msg = str(e)
36+
if len(msg) > 11:
37+
msg = msg[11:]
38+
raise ModuleNotFoundError(msg) from e
39+
40+
return self._materialize()
41+
42+
def _materialize(self):
43+
raise RuntimeError("abstract method")
44+
45+
def serialize(self):
46+
cls_str = class_to_source(type(self))
47+
compress_method = self.get_tag("pkg_compression")
48+
if compress_method in [None, "None"]:
49+
return cls_str
50+
51+
cls_str = cls_str.encode("utf-8")
52+
exec(f"import {compress_method}")
53+
compressed_str = eval(f"{compress_method}.compress(cls_str)")
54+
55+
return compressed_str
56+
57+
58+
def _has_source(obj) -> bool:
59+
"""
60+
Return True if inspect.getsource(obj) should succeed.
61+
"""
62+
module_name = getattr(obj, "__module__", None)
63+
if not module_name or module_name not in sys.modules:
64+
return False
65+
66+
module = sys.modules[module_name]
67+
file = getattr(module, "__file__", None)
68+
if not file:
69+
return False
70+
71+
return Path(file).suffix == ".py"
72+
73+
74+
def class_to_source(cls) -> str:
75+
"""Return full source definition of python class as string.
76+
77+
Parameters
78+
----------
79+
cls : class to serialize
80+
81+
Returns
82+
-------
83+
str : complete definition of cls, as str.
84+
Imports are not contained or serialized.
85+
"""""
86+
87+
# Fast path: class has retrievable source
88+
if _has_source(cls):
89+
source = inspect.getsource(cls)
90+
return textwrap.dedent(source)
91+
92+
# Fallback for dynamically created classes
93+
lines = []
94+
95+
bases = [base.__name__ for base in cls.__bases__ if base is not object]
96+
base_str = f"({', '.join(bases)})" if bases else ""
97+
lines.append(f"class {cls.__name__}{base_str}:")
98+
99+
body_added = False
100+
101+
for name, value in cls.__dict__.items():
102+
if name.startswith("__") and name.endswith("__"):
103+
continue
104+
105+
if inspect.isfunction(value):
106+
if _has_source(value):
107+
method_src = inspect.getsource(value)
108+
method_src = textwrap.indent(textwrap.dedent(method_src), " ")
109+
lines.append(method_src)
110+
else:
111+
lines.append(f" def {name}(self): ...")
112+
body_added = True
113+
else:
114+
lines.append(f" {name} = {repr(value)}")
115+
body_added = True
116+
117+
if not body_added:
118+
lines.append(" pass")
119+
120+
return "\n".join(lines)

openml/models/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Module with packaging adapters."""
2+
3+
from openml.models._get import get
4+
5+
__all__ = ["get"]

openml/models/_get.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
2+
"""Model retrieval utility."""
3+
4+
from functools import lru_cache
5+
6+
7+
def get(id: str):
8+
"""Retrieve model object with unique identifier.
9+
10+
Parameter
11+
---------
12+
id : str
13+
unique identifier of object to retrieve
14+
15+
Returns
16+
-------
17+
class
18+
retrieved object
19+
20+
Raises
21+
------
22+
ModuleNotFoundError
23+
if dependencies of object to retrieve are not satisfied
24+
"""
25+
26+
id_lookup = _id_lookup()
27+
obj = id_lookup.get(id)
28+
if obj is None:
29+
raise ValueError(
30+
f"Error in openml.get, object with package id {id} "
31+
"does not exist."
32+
)
33+
return obj().materialize()
34+
35+
36+
# todo: need to generalize this later to more types
37+
# currently intentionally retrieves only classifiers
38+
# todo: replace this, optionally, by database backend
39+
def _id_lookup(obj_type=None):
40+
return _id_lookup_cached(obj_type=obj_type).copy()
41+
42+
43+
@lru_cache
44+
def _id_lookup_cached(obj_type=None):
45+
all_objs = _all_objects(obj_type=obj_type)
46+
47+
# todo: generalize that pkg can contain more than one object
48+
lookup_dict = {obj.get_class_tag("pkg_id"): obj for obj in all_objs}
49+
50+
return lookup_dict
51+
52+
53+
@lru_cache
54+
def _all_objects(obj_type=None):
55+
from skbase.lookup import all_objects
56+
57+
from openml.models.apis._classifier import _ModelPkgClassifier
58+
59+
clses = all_objects(
60+
object_types=_ModelPkgClassifier, package_name="openml", return_names=False
61+
)
62+
63+
return clses

openml/models/apis/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Module with packaging adapters."""
2+
3+
from openml.models.apis._classifier import _ModelPkgClassifier
4+
5+
__all__ = ["_ModelPkgClassifier"]

openml/models/apis/_classifier.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
"""Base package for sklearn classifiers."""
2+
3+
from openml.models.base import _OpenmlModelPkg
4+
5+
6+
class _ModelPkgClassifier(_OpenmlModelPkg):
7+
8+
_tags = {
9+
# tags specific to API type
10+
"pkg_obj_type": "classifier",
11+
}
12+
13+
def get_obj_tags(self):
14+
"""Return tags of the object as a dictionary."""
15+
return {} # this needs to be implemented
16+
17+
def get_obj_param_names(self):
18+
"""Return parameter names of the object as a list.
19+
20+
Returns
21+
-------
22+
list: names of object parameters
23+
"""
24+
return list(self.materialize()().get_params().keys())

openml/models/base/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Module with packaging adapters."""
2+
3+
from openml.models.base._base import _OpenmlModelPkg
4+
5+
__all__ = ["_OpenmlModelPkg"]

0 commit comments

Comments
 (0)