|
6 | 6 | import inspect |
7 | 7 | import numpy as np |
8 | 8 | import pytest |
9 | | - |
| 9 | +from unittest.mock import patch |
10 | 10 | import openml.testing |
11 | 11 | from openml.extensions import Extension, get_extension_by_flow, get_extension_by_model, register_extension |
12 | 12 |
|
@@ -109,43 +109,48 @@ def instantiate_model_from_hpo_class(self, model, trace_iteration): |
109 | 109 | return DummyModel() |
110 | 110 |
|
111 | 111 |
|
112 | | -def _unregister(): |
113 | | - # "Un-register" the test extensions |
114 | | - openml.extensions.extensions.clear() |
115 | | - |
116 | 112 |
|
117 | 113 | class TestInit(openml.testing.TestBase): |
118 | | - def setUp(self): |
119 | | - super().setUp() |
120 | | - _unregister() |
121 | 114 |
|
122 | 115 | def test_get_extension_by_flow(self): |
123 | | - assert get_extension_by_flow(DummyFlow()) is None |
124 | | - with pytest.raises(ValueError, match="No extension registered which can handle flow:"): |
125 | | - get_extension_by_flow(DummyFlow(), raise_if_no_extension=True) |
126 | | - register_extension(DummyExtension1) |
127 | | - assert isinstance(get_extension_by_flow(DummyFlow()), DummyExtension1) |
128 | | - register_extension(DummyExtension2) |
129 | | - assert isinstance(get_extension_by_flow(DummyFlow()), DummyExtension1) |
130 | | - register_extension(DummyExtension1) |
131 | | - with pytest.raises( |
132 | | - ValueError, match="Multiple extensions registered which can handle flow:" |
133 | | - ): |
134 | | - get_extension_by_flow(DummyFlow()) |
| 116 | + # We replace the global list with a new empty list [] ONLY for this block |
| 117 | + with patch("openml.extensions.extensions", []): |
| 118 | + assert get_extension_by_flow(DummyFlow()) is None |
| 119 | + |
| 120 | + with pytest.raises(ValueError, match="No extension registered which can handle flow:"): |
| 121 | + get_extension_by_flow(DummyFlow(), raise_if_no_extension=True) |
| 122 | + |
| 123 | + register_extension(DummyExtension1) |
| 124 | + assert isinstance(get_extension_by_flow(DummyFlow()), DummyExtension1) |
| 125 | + |
| 126 | + register_extension(DummyExtension2) |
| 127 | + assert isinstance(get_extension_by_flow(DummyFlow()), DummyExtension1) |
| 128 | + |
| 129 | + register_extension(DummyExtension1) |
| 130 | + with pytest.raises( |
| 131 | + ValueError, match="Multiple extensions registered which can handle flow:" |
| 132 | + ): |
| 133 | + get_extension_by_flow(DummyFlow()) |
135 | 134 |
|
136 | 135 | def test_get_extension_by_model(self): |
137 | | - assert get_extension_by_model(DummyModel()) is None |
138 | | - with pytest.raises(ValueError, match="No extension registered which can handle model:"): |
139 | | - get_extension_by_model(DummyModel(), raise_if_no_extension=True) |
140 | | - register_extension(DummyExtension1) |
141 | | - assert isinstance(get_extension_by_model(DummyModel()), DummyExtension1) |
142 | | - register_extension(DummyExtension2) |
143 | | - assert isinstance(get_extension_by_model(DummyModel()), DummyExtension1) |
144 | | - register_extension(DummyExtension1) |
145 | | - with pytest.raises( |
146 | | - ValueError, match="Multiple extensions registered which can handle model:" |
147 | | - ): |
148 | | - get_extension_by_model(DummyModel()) |
| 136 | + # Again, we start with a fresh empty list automatically |
| 137 | + with patch("openml.extensions.extensions", []): |
| 138 | + assert get_extension_by_model(DummyModel()) is None |
| 139 | + |
| 140 | + with pytest.raises(ValueError, match="No extension registered which can handle model:"): |
| 141 | + get_extension_by_model(DummyModel(), raise_if_no_extension=True) |
| 142 | + |
| 143 | + register_extension(DummyExtension1) |
| 144 | + assert isinstance(get_extension_by_model(DummyModel()), DummyExtension1) |
| 145 | + |
| 146 | + register_extension(DummyExtension2) |
| 147 | + assert isinstance(get_extension_by_model(DummyModel()), DummyExtension1) |
| 148 | + |
| 149 | + register_extension(DummyExtension1) |
| 150 | + with pytest.raises( |
| 151 | + ValueError, match="Multiple extensions registered which can handle model:" |
| 152 | + ): |
| 153 | + get_extension_by_model(DummyModel()) |
149 | 154 |
|
150 | 155 |
|
151 | 156 | def test_flow_to_model_with_defaults(): |
@@ -194,34 +199,33 @@ class InvalidFlow: |
194 | 199 | ext.flow_to_model(flow) |
195 | 200 |
|
196 | 201 |
|
| 202 | +@patch("openml.extensions.extensions", []) |
197 | 203 | def test_extension_not_found_error_message(): |
198 | 204 | """Test error message contains helpful information.""" |
199 | 205 | class UnknownModel: |
200 | 206 | pass |
201 | 207 |
|
202 | | - _unregister() |
203 | | - |
204 | 208 | with pytest.raises(ValueError, match="No extension registered"): |
205 | 209 | get_extension_by_model(UnknownModel(), raise_if_no_extension=True) |
206 | 210 |
|
207 | 211 |
|
208 | 212 | def test_register_same_extension_twice(): |
209 | 213 | """Test behavior when registering same extension twice.""" |
210 | | - register_extension(DummyExtension) |
211 | | - register_extension(DummyExtension) |
| 214 | + # Using a context manager here to isolate the list |
| 215 | + with patch("openml.extensions.extensions", []): |
| 216 | + register_extension(DummyExtension) |
| 217 | + register_extension(DummyExtension) |
212 | 218 |
|
213 | | - matches = [ |
214 | | - ext for ext in openml.extensions.extensions |
215 | | - if ext is DummyExtension |
216 | | - ] |
217 | | - |
218 | | - assert len(matches) == 2 |
| 219 | + matches = [ |
| 220 | + ext for ext in openml.extensions.extensions |
| 221 | + if ext is DummyExtension |
| 222 | + ] |
| 223 | + assert len(matches) == 2 |
219 | 224 |
|
220 | 225 |
|
| 226 | +@patch("openml.extensions.extensions", []) |
221 | 227 | def test_extension_priority_order(): |
222 | | - """Test that extensions are checked in registration order.""" |
223 | | - _unregister() |
224 | | - |
| 228 | + """Test that extensions are checked in registration order.""" |
225 | 229 | class DummyExtensionA(DummyExtension): |
226 | 230 | pass |
227 | 231 | class DummyExtensionB(DummyExtension): |
|
0 commit comments