@@ -1096,6 +1096,67 @@ def test_local_run_metric_score(self):
10961096
10971097 self ._test_local_evaluations (run )
10981098
1099+ @pytest .mark .sklearn ()
1100+ @pytest .mark .uses_test_server ()
1101+ def test_run_flow_on_task_basic (self ):
1102+ """Test that run_flow_on_task executes successfully with basic flow and task."""
1103+ # construct sci-kit learn classifier
1104+ clf = Pipeline (
1105+ steps = [
1106+ ("imputer" , SimpleImputer (strategy = "most_frequent" )),
1107+ ("encoder" , OneHotEncoder (handle_unknown = "ignore" )),
1108+ ("estimator" , RandomForestClassifier (n_estimators = 5 , random_state = 42 )),
1109+ ],
1110+ )
1111+
1112+ # convert model to flow
1113+ flow = self .extension .model_to_flow (clf )
1114+
1115+ # download task
1116+ task = openml .tasks .get_task (119 ) # diabetes; holdout
1117+
1118+ # invoke run_flow_on_task (refactored function under test)
1119+ run = openml .runs .run_flow_on_task (
1120+ flow = flow ,
1121+ task = task ,
1122+ upload_flow = False ,
1123+ )
1124+
1125+ # verify run was created successfully
1126+ assert run .task_id == task .task_id
1127+ assert run .flow_name == flow .name
1128+ assert run .dataset_id == task .dataset_id
1129+ assert run .data_content is not None
1130+ assert len (run .data_content ) > 0
1131+
1132+ TestBase ._mark_entity_for_removal ("run" , run .run_id )
1133+ TestBase .logger .info (f"collected from test_run_flow_on_task_basic: { run .run_id } " )
1134+
1135+ @pytest .mark .sklearn ()
1136+ @pytest .mark .uses_test_server ()
1137+ def test_run_flow_on_task_with_flow_tags (self ):
1138+ """Test run_flow_on_task with custom flow tags (for the flow, not the run)."""
1139+ clf = RandomForestClassifier (n_estimators = 5 , random_state = 42 )
1140+ flow = self .extension .model_to_flow (clf )
1141+ task = openml .tasks .get_task (119 )
1142+
1143+ # invoke run_flow_on_task with custom flow tags
1144+ # Note: flow_tags are tags for the flow object, not the run
1145+ run = openml .runs .run_flow_on_task (
1146+ flow = flow ,
1147+ task = task ,
1148+ flow_tags = ["test_flow_tag_1" , "test_flow_tag_2" ],
1149+ upload_flow = False ,
1150+ )
1151+
1152+ # verify run was created successfully
1153+ assert run .task_id == task .task_id
1154+ assert run .flow_name == flow .name
1155+ assert run .data_content is not None
1156+
1157+ TestBase ._mark_entity_for_removal ("run" , run .run_id )
1158+ TestBase .logger .info (f"collected from test_run_flow_on_task_with_flow_tags: { run .run_id } " )
1159+
10991160 @pytest .mark .production ()
11001161 def test_online_run_metric_score (self ):
11011162 self .use_production_server ()
0 commit comments