@@ -191,3 +191,40 @@ def test_update_intervals_for_new_snapshots(
191191 else :
192192 assert not snapshot .dev_intervals
193193 state_sync_mock .add_interval .assert_not_called ()
194+
195+
196+ def test_state_based_airflow_evaluator_with_restatements (
197+ sushi_context : Context , mocker : MockerFixture
198+ ):
199+ model_fqn = sushi_context .get_model ("sushi.waiter_revenue_by_day" ).fqn
200+ downstream_model_fqn = sushi_context .get_model ("sushi.top_waiters" ).fqn
201+
202+ plan = PlanBuilder (
203+ sushi_context ._context_diff ("prod" ),
204+ sushi_context .engine_adapter .SCHEMA_DIFFER ,
205+ restate_models = [sushi_context .get_model ("sushi.waiter_revenue_by_day" ).fqn ],
206+ ).build ()
207+
208+ mwaa_client_mock = mocker .Mock ()
209+ mwaa_client_mock .wait_for_dag_run_completion .return_value = True
210+ mwaa_client_mock .wait_for_first_dag_run .return_value = "test_plan_application_dag_run_id"
211+ mwaa_client_mock .set_variable .return_value = "" , ""
212+
213+ state_sync_mock = mocker .Mock ()
214+
215+ plan_dag_spec_mock = mocker .Mock ()
216+
217+ create_plan_dag_spec_mock = mocker .patch ("sqlmesh.schedulers.airflow.plan.create_plan_dag_spec" )
218+ create_plan_dag_spec_mock .return_value = plan_dag_spec_mock
219+
220+ plan_dag_state_mock = mocker .Mock ()
221+ mocker .patch (
222+ "sqlmesh.schedulers.airflow.plan.PlanDagState.from_state_sync" ,
223+ return_value = plan_dag_state_mock ,
224+ )
225+
226+ evaluator = MWAAPlanEvaluator (mwaa_client_mock , state_sync_mock )
227+ evaluator .evaluate (plan )
228+
229+ plan_application_request = create_plan_dag_spec_mock .call_args [0 ][0 ]
230+ assert plan_application_request .restatements .keys () == {model_fqn , downstream_model_fqn }
0 commit comments