Skip to content

Commit 93033a0

Browse files
committed
fixed extensions
1 parent e8c64f2 commit 93033a0

1 file changed

Lines changed: 49 additions & 45 deletions

File tree

tests/test_extensions/test_functions.py

Lines changed: 49 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import inspect
77
import numpy as np
88
import pytest
9-
9+
from unittest.mock import patch
1010
import openml.testing
1111
from openml.extensions import Extension, get_extension_by_flow, get_extension_by_model, register_extension
1212

@@ -109,43 +109,48 @@ def instantiate_model_from_hpo_class(self, model, trace_iteration):
109109
return DummyModel()
110110

111111

112-
def _unregister():
113-
# "Un-register" the test extensions
114-
openml.extensions.extensions.clear()
115-
116112

117113
class TestInit(openml.testing.TestBase):
118-
def setUp(self):
119-
super().setUp()
120-
_unregister()
121114

122115
def test_get_extension_by_flow(self):
123-
assert get_extension_by_flow(DummyFlow()) is None
124-
with pytest.raises(ValueError, match="No extension registered which can handle flow:"):
125-
get_extension_by_flow(DummyFlow(), raise_if_no_extension=True)
126-
register_extension(DummyExtension1)
127-
assert isinstance(get_extension_by_flow(DummyFlow()), DummyExtension1)
128-
register_extension(DummyExtension2)
129-
assert isinstance(get_extension_by_flow(DummyFlow()), DummyExtension1)
130-
register_extension(DummyExtension1)
131-
with pytest.raises(
132-
ValueError, match="Multiple extensions registered which can handle flow:"
133-
):
134-
get_extension_by_flow(DummyFlow())
116+
# We replace the global list with a new empty list [] ONLY for this block
117+
with patch("openml.extensions.extensions", []):
118+
assert get_extension_by_flow(DummyFlow()) is None
119+
120+
with pytest.raises(ValueError, match="No extension registered which can handle flow:"):
121+
get_extension_by_flow(DummyFlow(), raise_if_no_extension=True)
122+
123+
register_extension(DummyExtension1)
124+
assert isinstance(get_extension_by_flow(DummyFlow()), DummyExtension1)
125+
126+
register_extension(DummyExtension2)
127+
assert isinstance(get_extension_by_flow(DummyFlow()), DummyExtension1)
128+
129+
register_extension(DummyExtension1)
130+
with pytest.raises(
131+
ValueError, match="Multiple extensions registered which can handle flow:"
132+
):
133+
get_extension_by_flow(DummyFlow())
135134

136135
def test_get_extension_by_model(self):
137-
assert get_extension_by_model(DummyModel()) is None
138-
with pytest.raises(ValueError, match="No extension registered which can handle model:"):
139-
get_extension_by_model(DummyModel(), raise_if_no_extension=True)
140-
register_extension(DummyExtension1)
141-
assert isinstance(get_extension_by_model(DummyModel()), DummyExtension1)
142-
register_extension(DummyExtension2)
143-
assert isinstance(get_extension_by_model(DummyModel()), DummyExtension1)
144-
register_extension(DummyExtension1)
145-
with pytest.raises(
146-
ValueError, match="Multiple extensions registered which can handle model:"
147-
):
148-
get_extension_by_model(DummyModel())
136+
# Again, we start with a fresh empty list automatically
137+
with patch("openml.extensions.extensions", []):
138+
assert get_extension_by_model(DummyModel()) is None
139+
140+
with pytest.raises(ValueError, match="No extension registered which can handle model:"):
141+
get_extension_by_model(DummyModel(), raise_if_no_extension=True)
142+
143+
register_extension(DummyExtension1)
144+
assert isinstance(get_extension_by_model(DummyModel()), DummyExtension1)
145+
146+
register_extension(DummyExtension2)
147+
assert isinstance(get_extension_by_model(DummyModel()), DummyExtension1)
148+
149+
register_extension(DummyExtension1)
150+
with pytest.raises(
151+
ValueError, match="Multiple extensions registered which can handle model:"
152+
):
153+
get_extension_by_model(DummyModel())
149154

150155

151156
def test_flow_to_model_with_defaults():
@@ -194,34 +199,33 @@ class InvalidFlow:
194199
ext.flow_to_model(flow)
195200

196201

202+
@patch("openml.extensions.extensions", [])
197203
def test_extension_not_found_error_message():
198204
"""Test error message contains helpful information."""
199205
class UnknownModel:
200206
pass
201207

202-
_unregister()
203-
204208
with pytest.raises(ValueError, match="No extension registered"):
205209
get_extension_by_model(UnknownModel(), raise_if_no_extension=True)
206210

207211

208212
def test_register_same_extension_twice():
209213
"""Test behavior when registering same extension twice."""
210-
register_extension(DummyExtension)
211-
register_extension(DummyExtension)
214+
# Using a context manager here to isolate the list
215+
with patch("openml.extensions.extensions", []):
216+
register_extension(DummyExtension)
217+
register_extension(DummyExtension)
212218

213-
matches = [
214-
ext for ext in openml.extensions.extensions
215-
if ext is DummyExtension
216-
]
217-
218-
assert len(matches) == 2
219+
matches = [
220+
ext for ext in openml.extensions.extensions
221+
if ext is DummyExtension
222+
]
223+
assert len(matches) == 2
219224

220225

226+
@patch("openml.extensions.extensions", [])
221227
def test_extension_priority_order():
222-
"""Test that extensions are checked in registration order."""
223-
_unregister()
224-
228+
"""Test that extensions are checked in registration order."""
225229
class DummyExtensionA(DummyExtension):
226230
pass
227231
class DummyExtensionB(DummyExtension):

0 commit comments

Comments
 (0)