Skip to content

Commit 8d05331

Browse files
committed
added the tests
1 parent cefae00 commit 8d05331

1 file changed

Lines changed: 61 additions & 0 deletions

File tree

tests/test_runs/test_run_functions.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,6 +1096,67 @@ def test_local_run_metric_score(self):
10961096

10971097
self._test_local_evaluations(run)
10981098

1099+
@pytest.mark.sklearn()
1100+
@pytest.mark.uses_test_server()
1101+
def test_run_flow_on_task_basic(self):
1102+
"""Test that run_flow_on_task executes successfully with basic flow and task."""
1103+
# construct sci-kit learn classifier
1104+
clf = Pipeline(
1105+
steps=[
1106+
("imputer", SimpleImputer(strategy="most_frequent")),
1107+
("encoder", OneHotEncoder(handle_unknown="ignore")),
1108+
("estimator", RandomForestClassifier(n_estimators=5, random_state=42)),
1109+
],
1110+
)
1111+
1112+
# convert model to flow
1113+
flow = self.extension.model_to_flow(clf)
1114+
1115+
# download task
1116+
task = openml.tasks.get_task(119) # diabetes; holdout
1117+
1118+
# invoke run_flow_on_task (refactored function under test)
1119+
run = openml.runs.run_flow_on_task(
1120+
flow=flow,
1121+
task=task,
1122+
upload_flow=False,
1123+
)
1124+
1125+
# verify run was created successfully
1126+
assert run.task_id == task.task_id
1127+
assert run.flow_name == flow.name
1128+
assert run.dataset_id == task.dataset_id
1129+
assert run.data_content is not None
1130+
assert len(run.data_content) > 0
1131+
1132+
TestBase._mark_entity_for_removal("run", run.run_id)
1133+
TestBase.logger.info(f"collected from test_run_flow_on_task_basic: {run.run_id}")
1134+
1135+
@pytest.mark.sklearn()
1136+
@pytest.mark.uses_test_server()
1137+
def test_run_flow_on_task_with_flow_tags(self):
1138+
"""Test run_flow_on_task with custom flow tags (for the flow, not the run)."""
1139+
clf = RandomForestClassifier(n_estimators=5, random_state=42)
1140+
flow = self.extension.model_to_flow(clf)
1141+
task = openml.tasks.get_task(119)
1142+
1143+
# invoke run_flow_on_task with custom flow tags
1144+
# Note: flow_tags are tags for the flow object, not the run
1145+
run = openml.runs.run_flow_on_task(
1146+
flow=flow,
1147+
task=task,
1148+
flow_tags=["test_flow_tag_1", "test_flow_tag_2"],
1149+
upload_flow=False,
1150+
)
1151+
1152+
# verify run was created successfully
1153+
assert run.task_id == task.task_id
1154+
assert run.flow_name == flow.name
1155+
assert run.data_content is not None
1156+
1157+
TestBase._mark_entity_for_removal("run", run.run_id)
1158+
TestBase.logger.info(f"collected from test_run_flow_on_task_with_flow_tags: {run.run_id}")
1159+
10991160
@pytest.mark.production()
11001161
def test_online_run_metric_score(self):
11011162
self.use_production_server()

0 commit comments

Comments
 (0)