From 60d30653886bbd0655e89f2004f6cc84d1b89b01 Mon Sep 17 00:00:00 2001 From: David Blain Date: Sun, 29 Sep 2024 17:58:17 +0200 Subject: [PATCH 01/97] refactor: Implemented StreamedOperator which runs all mapped tasks within the same TaskInstance using an asyncio event loop on the same worker instance --- airflow/exceptions.py | 13 +- airflow/models/baseoperator.py | 17 +- airflow/models/mappedoperator.py | 2 +- airflow/models/streamedoperator.py | 335 +++++++++++++++++++++++++++++ 4 files changed, 360 insertions(+), 7 deletions(-) create mode 100644 airflow/models/streamedoperator.py diff --git a/airflow/exceptions.py b/airflow/exceptions.py index 55dd02fdae313..e1aca74aac4aa 100644 --- a/airflow/exceptions.py +++ b/airflow/exceptions.py @@ -31,7 +31,7 @@ import datetime from collections.abc import Sized - from airflow.models import DAG, DagRun + from airflow.models import DAG, DagRun, TaskInstance class AirflowException(Exception): @@ -84,6 +84,17 @@ def serialize(self): return f"{cls.__module__}.{cls.__name__}", (), {"reschedule_date": self.reschedule_date} +class AirflowRescheduleTaskInstanceException(AirflowRescheduleException): + """ + Raise when the task should be re-scheduled for a specific TaskInstance at a later time. + + :param task_instance: The task instance that should be rescheduled + """ + def __init__(self, task_instance: TaskInstance): + super().__init__(reschedule_date=task_instance.next_retry_datetime()) + self.task_instance = task_instance + + class InvalidStatsNameException(AirflowException): """Raise when name of the stats is invalid.""" diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 20656586ba01e..24864e69be029 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -391,10 +391,11 @@ def decorator(cls, func): @wraps(func) def wrapper(self, *args, **kwargs): from airflow.decorators.base import DecoratedOperator + from airflow.models.streamedoperator import StreamedOperator sentinel = kwargs.pop(f"{self.__class__.__name__}__sentinel", None) - if not cls.test_mode and not sentinel == _sentinel and not isinstance(self, DecoratedOperator): + if not cls.test_mode and not sentinel == _sentinel and not isinstance(self, DecoratedOperator) and not isinstance(self, StreamedOperator): message = f"{self.__class__.__name__}.{func.__name__} cannot be called outside TaskInstance!" if not self.allow_nested_operators: raise AirflowException(message) @@ -1723,20 +1724,26 @@ def defer( """ raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout) - def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, context: Context): - """Call this method when a deferred task is resumed.""" + @classmethod + def next_callable(cls, operator, next_method, next_kwargs) -> Callable[[Context, Any], Any]: + """Get the next callable from given operator.""" # __fail__ is a special signal value for next_method that indicates # this task was scheduled specifically to fail.
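# (When a trigger fails inside the triggerer, the task is resumed with
# next_method="__fail__" and next_kwargs carrying the "error" message and,
# when available, the "traceback" captured on the triggerer side, which is
# exactly what the branch below reads back out of next_kwargs.)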
if next_method == "__fail__": next_kwargs = next_kwargs or {} traceback = next_kwargs.get("traceback") if traceback is not None: - self.log.error("Trigger failed:\n%s", "\n".join(traceback)) + logging.error("Trigger failed:\n%s", "\n".join(traceback)) raise TaskDeferralError(next_kwargs.get("error", "Unknown")) # Grab the callable off the Operator/Task and add in any kwargs - execute_callable = getattr(self, next_method) + execute_callable = getattr(operator, next_method) if next_kwargs: execute_callable = functools.partial(execute_callable, **next_kwargs) + return execute_callable + + def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, context: Context): + """Call this method when a deferred task is resumed.""" + execute_callable = self.next_callable(self, next_method, next_kwargs) return execute_callable(context) def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> BaseOperator: diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 8a9e790ea7fc6..7cb44ca486b06 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -89,7 +89,7 @@ TaskStateChangeCallbackAttrType = Union[None, TaskStateChangeCallback, List[TaskStateChangeCallback]] -ValidationSource = Union[Literal["expand"], Literal["partial"]] +ValidationSource = Union[Literal["expand"], Literal["partial"], Literal["stream"]] def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None: diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py new file mode 100644 index 0000000000000..f02a879fb0173 --- /dev/null +++ b/airflow/models/streamedoperator.py @@ -0,0 +1,335 @@ +from __future__ import annotations + +import asyncio +import logging +import os +from asyncio import AbstractEventLoop, iscoroutinefunction, Semaphore, ensure_future +from contextlib import contextmanager +from datetime import datetime +from math import ceil +from time import sleep +from typing import Any, Sequence, Callable, Generator + +import jinja2 +from sqlalchemy.orm import Session + +from airflow import XComArg +from airflow.exceptions import ( + TaskDeferred, + AirflowException, + AirflowRescheduleTaskInstanceException, +) +from airflow.models import BaseOperator, Operator, TaskInstance +from airflow.models.expandinput import ( + ExpandInput, + DictOfListsExpandInput, + OperatorExpandArgument, + _needs_run_time_resolution, +) +from airflow.models.mappedoperator import ( + ensure_xcomarg_return_value, + validate_mapping_kwargs, +) +from airflow.triggers.base import TriggerEvent, BaseTrigger +from airflow.utils import timezone +from airflow.utils.context import Context, context_get_outlet_events +from airflow.utils.helpers import prevent_duplicates +from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.operator_helpers import ExecutionCallableRunner +from airflow.utils.task_instance_session import get_current_task_instance_session + + +@contextmanager +def event_loop() -> Generator[AbstractEventLoop, None, None]: + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + yield loop + + +async def run_trigger(trigger: BaseTrigger) -> list[TriggerEvent]: + events = [] + async for event in trigger.run(): + events.append(event) + return events + + +class OperatorMethodExecutor(LoggingMixin): + def __init__( + self, + semaphore: Semaphore, + operator: Operator, + context: Context, + 
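# context of the parent streamed task; its "ti" entry is replaced with the per-mapped-task instance by the context property below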
task_instance: TaskInstance, + ): + super().__init__() + self._semaphore = semaphore + self.operator = operator + self.__context = context + self._task_instance = task_instance + + @property + def task_instance(self) -> TaskInstance: + # TODO: If we want a specialized TaskInstance for the StreamedOperator, + # we could inherit from TaskInstanceDependencies + return self._task_instance + + @property + def context(self) -> Context: + return {**self.__context, **{"ti": self.task_instance}} + + async def _run_callable(self, method: Callable, *args, **kwargs): + self.log.debug("semaphore: %s (%s)", self._semaphore, self._semaphore.locked()) + async with self._semaphore: + while self.task_instance.try_number <= self.operator.retries: + if self.log.isEnabledFor(logging.INFO): + self.log.info( + "Attempting running task %s of %s for %s with map_index %s.", + self.task_instance.try_number, + self.operator.retries, + type(self.operator).__name__, + self.task_instance.map_index, + ) + + try: + outlet_events = context_get_outlet_events(self.context) + callable_runner = ExecutionCallableRunner( + func=method, outlet_events=outlet_events, logger=self.log + ) + if iscoroutinefunction(method): + return await callable_runner.run(*args, **kwargs) + return callable_runner.run(*args, **kwargs) + except AirflowException as e: + if self.task_instance.next_try_number > self.operator.retries: + self.log.error( + "Max number of attempts for %s with map_index %s failed due to: %s", + type(self.operator).__name__, + self.task_instance.map_index, + e, + ) + raise e + + self.log.error("An error occurred: %s", e) + self.task_instance.try_number += 1 + self.task_instance.end_date = timezone.utcnow() + + raise AirflowRescheduleTaskInstanceException( + task_instance=self.task_instance + ) + + async def run_deferrable(self, context: Context, task_deferred: TaskDeferred): + event = next(iter(await run_trigger(task_deferred.trigger))) + + self.log.debug("event: %s", event) + self.log.debug("next_method: %s", task_deferred.method_name) + + if task_deferred.method_name: + next_method = BaseOperator.next_callable( + self.operator, task_deferred.method_name, task_deferred.kwargs + ) + result = next_method(context, event.payload) + self.log.debug("result: %s", result) + return result + + async def run(self, method: Callable, *args, **kwargs): + self.operator.pre_execute(context=self.context) + self.task_instance._run_execute_callback( + context=self.context, task=self.operator + ) + + try: + return await self._run_callable( + method, *(list(args or ()) + [self.context]), **kwargs + ) + except TaskDeferred as task_deferred: + return await self._run_callable( + self.run_deferrable, *[self.context, task_deferred] + ) + finally: + self.operator.post_execute(context=self.context) + + +class StreamedOperator(BaseOperator): + _operator_class: type[BaseOperator] | dict[str, Any] + _expand_input: ExpandInput + _partial_kwargs: dict[str, Any] + + def __init__( + self, + *, + operator_class: type[BaseOperator] | dict[str, Any], + expand_input: ExpandInput, + partial_kwargs: dict[str, Any] | None = None, + **kwargs: Any, + ): + super().__init__(**kwargs) + self._operator_class = operator_class + self._expand_input = expand_input + self._partial_kwargs = partial_kwargs or {} + self._mapped_kwargs = [] + self._semaphore = Semaphore(self.max_active_tis_per_dag) + XComArg.apply_upstream_relationship(self, self._expand_input.value) + + @property + def operator_name(self) -> str: + return self._operator_class.__name__ + + def _unmap_operator(self, 
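# index of the entry in self._mapped_kwargs to materialize as a concrete operator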
index): + self.log.debug("index: %s", index) + kwargs = { + **self._partial_kwargs, + **{"task_id": f"{self._partial_kwargs.get('task_id')}_{index}"}, + **self._mapped_kwargs[index], + } + self.log.debug("kwargs: %s", kwargs) + self.log.debug("operator_class: %s", self._operator_class) + return self._operator_class(**kwargs) + + def _resolve_expand_input(self, context: Context, session: Session): + for key, value in self._expand_input.value.items(): + if _needs_run_time_resolution(value): + value = value.resolve(context=context, session=session) + + if isinstance(value, Sequence) and not isinstance(value, (str, bytes)): + value = list(value) + + self.log.debug("resolved_value: %s", value) + + if isinstance(value, list): + self._mapped_kwargs.extend([{key: item} for item in value]) + else: + self._mapped_kwargs.append({key: value}) + self.log.debug("resolve_expand_input: %s", self._mapped_kwargs) + + def render_template_fields( + self, + context: Context, + jinja_env: jinja2.Environment | None = None, + ) -> None: + session = get_current_task_instance_session() + self._resolve_expand_input(context=context, session=session) + + def _run_futures( + self, context: Context, futures, results: list[Any] | None = None + ) -> list[Any]: + reschedule_date: datetime | None = None + results = results or [] + failed_futures = [] + + with event_loop() as loop: + for result in loop.run_until_complete( + asyncio.gather(*futures, return_exceptions=True) + ): + if isinstance(result, Exception): + if not isinstance(result, AirflowRescheduleTaskInstanceException): + raise result + reschedule_date = result.reschedule_date + failed_futures.append( + ensure_future(self._run_task(context, result.task_instance)) + ) + else: + results.append(result) + + if not failed_futures: + return list(filter(None, results)) + + # session = get_current_task_instance_session() + # TaskInstance._set_state(context["ti"], TaskInstanceState.UP_FOR_RETRY, session) + + # Calculate delay before the next retry + delay = reschedule_date - timezone.utcnow() + delay_seconds = ceil(delay.total_seconds()) + + self.log.debug("delay_seconds: %s", delay_seconds) + + if delay_seconds > 0: + self.log.info( + "Attempting to run %s failed tasks within %s seconds...", + len(failed_futures), + delay_seconds, + ) + + sleep(delay_seconds) + + # TaskInstance._set_state(context["ti"], TaskInstanceState.RUNNING, session) + + return self._run_futures(context, failed_futures, results) + + async def _run_task(self, context: Context, task_instance: TaskInstance): + operator = task_instance.task + self.log.debug("operator: %s", operator) + result = await OperatorMethodExecutor( + semaphore=self._semaphore, + operator=operator, + context=context, + task_instance=task_instance, + ).run(operator.execute) + self.log.debug("result: %s", result) + self.log.debug("do_xcom_push: %s", operator.do_xcom_push) + if operator.do_xcom_push: + return result + + def _create_future(self, context: Context, index: int): + operator = self._unmap_operator(index) + operator.render_template_fields(context=context) + task_instance = TaskInstance( + task=operator, + execution_date=context["ti"].execution_date, + run_id=context["ti"].run_id, + state=context["ti"].state, + map_index=index, + ) + return asyncio.ensure_future(self._run_task(context, task_instance)) + + def execute(self, context: Context): + self.log.info( + "Executing %s mapped tasks on %s with %s workers", + len(self._mapped_kwargs), + self._operator_class.__name__, + self.max_active_tis_per_dag, + ) + + return 
self._run_futures( + context=context, + futures=[ + self._create_future(context, index) + for index, mapped_kwargs in enumerate(self._mapped_kwargs) + ], + ) + + +def stream(self, **mapped_kwargs: OperatorExpandArgument) -> StreamedOperator: + if not mapped_kwargs: + raise TypeError("no arguments to expand against") + validate_mapping_kwargs(self.operator_class, "stream", mapped_kwargs) + prevent_duplicates( + self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified" + ) + + expand_input = DictOfListsExpandInput(mapped_kwargs) + ensure_xcomarg_return_value(expand_input.value) + + partial_kwargs = self.kwargs.copy() + task_id = partial_kwargs.pop("task_id") + dag = partial_kwargs.pop("dag") + task_group = partial_kwargs.pop("task_group") + start_date = partial_kwargs.pop("start_date") + end_date = partial_kwargs.pop("end_date") + max_active_tis_per_dag = ( + partial_kwargs.pop("max_active_tis_per_dag", None) or os.cpu_count() + ) + + return StreamedOperator( + task_id=task_id, + dag=dag, + task_group=task_group, + start_date=start_date, + end_date=end_date, + max_active_tis_per_dag=max_active_tis_per_dag, + operator_class=self.operator_class, + expand_input=expand_input, + retries=0, + partial_kwargs=self.kwargs.copy(), + ) From effb0ea2499d4ef3650e664f6a76a0b2016bc27c Mon Sep 17 00:00:00 2001 From: David Blain Date: Sun, 29 Sep 2024 19:47:36 +0200 Subject: [PATCH 02/97] refactor: Changed return type of next_callable method --- airflow/models/baseoperator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 24864e69be029..ca24ede475e10 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -1725,7 +1725,7 @@ def defer( raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout) @classmethod - def next_callable(cls, operator, next_method, next_kwargs) -> Callable[[Context, Any], Any]: + def next_callable(cls, operator, next_method, next_kwargs) -> Callable[[Context], Any]: """Get the next callable from given operator.""" # __fail__ is a special signal value for next_method that indicates # this task was scheduled specifically to fail. 
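The public surface is already visible at this point in the series: partial(...).stream(...) is a drop-in alternative to partial(...).expand(...) that keeps every mapped input inside a single task instance on one worker, with concurrency capped by max_active_tis_per_dag. A minimal usage sketch (the DAG and MyOperator are illustrative placeholders, not part of these patches; the call shape mirrors the unit tests added later in the series):

    from datetime import datetime

    from airflow.models.dag import DAG

    with DAG("streamed_example", schedule=None, start_date=datetime(2024, 1, 1)):
        # .expand() would fan out to one task instance per element at runtime;
        # .stream() runs all three elements inside this single task instance,
        # retrying failed elements in-process rather than through the scheduler.
        MyOperator.partial(task_id="fetch", retries=2).stream(arg2=["a", "b", "c"])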
From d013e3d7a2477e5c3b3f32c4e4b623825e2f775c Mon Sep 17 00:00:00 2001 From: David Blain Date: Sun, 29 Sep 2024 19:50:14 +0200 Subject: [PATCH 03/97] refactor: Changed operator type to BaseOperator in OperatorMethodExecutor --- airflow/models/streamedoperator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index f02a879fb0173..254062717bf28 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -19,7 +19,7 @@ AirflowException, AirflowRescheduleTaskInstanceException, ) -from airflow.models import BaseOperator, Operator, TaskInstance +from airflow.models import BaseOperator, TaskInstance from airflow.models.expandinput import ( ExpandInput, DictOfListsExpandInput, @@ -60,7 +60,7 @@ class OperatorMethodExecutor(LoggingMixin): def __init__( self, semaphore: Semaphore, - operator: Operator, + operator: BaseOperator, context: Context, task_instance: TaskInstance, ): From 2a827d2d1356fa81c3625e7f168bd8f76d615196 Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 30 Sep 2024 12:41:08 +0200 Subject: [PATCH 04/97] refactor: Fixed some static checks --- airflow/exceptions.py | 7 ++-- airflow/models/baseoperator.py | 7 +++- airflow/models/streamedoperator.py | 59 +++++++++++++++--------------- 3 files changed, 40 insertions(+), 33 deletions(-) diff --git a/airflow/exceptions.py b/airflow/exceptions.py index e1aca74aac4aa..f509ab964b7da 100644 --- a/airflow/exceptions.py +++ b/airflow/exceptions.py @@ -86,10 +86,11 @@ def serialize(self): class AirflowRescheduleTaskInstanceException(AirflowRescheduleException): """ - Raise when the task should be re-scheduled for a specific TaskInstance at a later time. + Raise when the task should be re-scheduled for a specific TaskInstance at a later time. + + :param task_instance: The task instance that should be rescheduled + """ - :param task_instance: The task instance that should be rescheduled - """ def __init__(self, task_instance: TaskInstance): super().__init__(reschedule_date=task_instance.next_retry_datetime()) self.task_instance = task_instance diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index ca24ede475e10..0fe3fabbcd92c 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -395,7 +395,12 @@ def wrapper(self, *args, **kwargs): sentinel = kwargs.pop(f"{self.__class__.__name__}__sentinel", None) - if not cls.test_mode and not sentinel == _sentinel and not isinstance(self, DecoratedOperator) and not isinstance(self, StreamedOperator): + if ( + not cls.test_mode + and not sentinel == _sentinel + and not isinstance(self, DecoratedOperator) + and not isinstance(self, StreamedOperator) + ): message = f"{self.__class__.__name__}.{func.__name__} cannot be called outside TaskInstance!" if not self.allow_nested_operators: raise AirflowException(message) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index 254062717bf28..7b443466906fc 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -1,28 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. from __future__ import annotations import asyncio import logging import os -from asyncio import AbstractEventLoop, iscoroutinefunction, Semaphore, ensure_future +from asyncio import AbstractEventLoop, Semaphore, ensure_future, iscoroutinefunction from contextlib import contextmanager from datetime import datetime from math import ceil from time import sleep -from typing import Any, Sequence, Callable, Generator +from typing import Any, Callable, Generator, Sequence import jinja2 from sqlalchemy.orm import Session from airflow import XComArg from airflow.exceptions import ( - TaskDeferred, AirflowException, AirflowRescheduleTaskInstanceException, + TaskDeferred, ) from airflow.models import BaseOperator, TaskInstance from airflow.models.expandinput import ( - ExpandInput, DictOfListsExpandInput, + ExpandInput, OperatorExpandArgument, _needs_run_time_resolution, ) @@ -30,7 +47,7 @@ ensure_xcomarg_return_value, validate_mapping_kwargs, ) -from airflow.triggers.base import TriggerEvent, BaseTrigger +from airflow.triggers.base import BaseTrigger, TriggerEvent from airflow.utils import timezone from airflow.utils.context import Context, context_get_outlet_events from airflow.utils.helpers import prevent_duplicates @@ -115,9 +132,7 @@ async def _run_callable(self, method: Callable, *args, **kwargs): self.task_instance.try_number += 1 self.task_instance.end_date = timezone.utcnow() - raise AirflowRescheduleTaskInstanceException( - task_instance=self.task_instance - ) + raise AirflowRescheduleTaskInstanceException(task_instance=self.task_instance) async def run_deferrable(self, context: Context, task_deferred: TaskDeferred): event = next(iter(await run_trigger(task_deferred.trigger))) @@ -135,18 +150,12 @@ async def run_deferrable(self, context: Context, task_deferred: TaskDeferred): async def run(self, method: Callable, *args, **kwargs): self.operator.pre_execute(context=self.context) - self.task_instance._run_execute_callback( - context=self.context, task=self.operator - ) + self.task_instance._run_execute_callback(context=self.context, task=self.operator) try: - return await self._run_callable( - method, *(list(args or ()) + [self.context]), **kwargs - ) + return await self._run_callable(method, *(list(args or ()) + [self.context]), **kwargs) except TaskDeferred as task_deferred: - return await self._run_callable( - self.run_deferrable, *[self.context, task_deferred] - ) + return await self._run_callable(self.run_deferrable, *[self.context, task_deferred]) finally: self.operator.post_execute(context=self.context) @@ -211,24 +220,18 @@ def render_template_fields( session = get_current_task_instance_session() self._resolve_expand_input(context=context, session=session) - def _run_futures( - self, context: Context, futures, results: list[Any] | None = None - ) -> list[Any]: + def _run_futures(self, context: Context, futures, results: list[Any] | None = None) -> list[Any]: reschedule_date: datetime | None = None results = results or [] failed_futures = [] with event_loop() as loop: - for result in loop.run_until_complete( - asyncio.gather(*futures, return_exceptions=True) - ): + for result 
in loop.run_until_complete(asyncio.gather(*futures, return_exceptions=True)): if isinstance(result, Exception): if not isinstance(result, AirflowRescheduleTaskInstanceException): raise result reschedule_date = result.reschedule_date - failed_futures.append( - ensure_future(self._run_task(context, result.task_instance)) - ) + failed_futures.append(ensure_future(self._run_task(context, result.task_instance))) else: results.append(result) @@ -304,9 +307,7 @@ def stream(self, **mapped_kwargs: OperatorExpandArgument) -> StreamedOperator: if not mapped_kwargs: raise TypeError("no arguments to expand against") validate_mapping_kwargs(self.operator_class, "stream", mapped_kwargs) - prevent_duplicates( - self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified" - ) + prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified") expand_input = DictOfListsExpandInput(mapped_kwargs) ensure_xcomarg_return_value(expand_input.value) From e9e0ad4ad83fea981fa6a41617d755886e7db042 Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 30 Sep 2024 12:56:17 +0200 Subject: [PATCH 05/97] refactor: Initialise Semaphore correctly --- airflow/models/streamedoperator.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index 7b443466906fc..539f05b4b3331 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -178,7 +178,7 @@ def __init__( self._expand_input = expand_input self._partial_kwargs = partial_kwargs or {} self._mapped_kwargs = [] - self._semaphore = Semaphore(self.max_active_tis_per_dag) + self._semaphore = Semaphore(self.max_active_tis_per_dag or os.cpu_count()) XComArg.apply_upstream_relationship(self, self._expand_input.value) @property @@ -318,9 +318,7 @@ def stream(self, **mapped_kwargs: OperatorExpandArgument) -> StreamedOperator: task_group = partial_kwargs.pop("task_group") start_date = partial_kwargs.pop("start_date") end_date = partial_kwargs.pop("end_date") - max_active_tis_per_dag = ( - partial_kwargs.pop("max_active_tis_per_dag", None) or os.cpu_count() - ) + max_active_tis_per_dag = (partial_kwargs.pop("max_active_tis_per_dag", None)) return StreamedOperator( task_id=task_id, From 20f02ed5d5078babc05cbf33809d37a43431f293 Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 30 Sep 2024 18:36:01 +0200 Subject: [PATCH 06/97] refactor: Reformatted StreamedOperator --- airflow/models/streamedoperator.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index 539f05b4b3331..3b14b47d33c5d 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -296,10 +296,7 @@ def execute(self, context: Context): return self._run_futures( context=context, - futures=[ - self._create_future(context, index) - for index, mapped_kwargs in enumerate(self._mapped_kwargs) - ], + futures=[self._create_future(context, index) for index, mapped_kwargs in enumerate(self._mapped_kwargs)], ) @@ -318,7 +315,7 @@ def stream(self, **mapped_kwargs: OperatorExpandArgument) -> StreamedOperator: task_group = partial_kwargs.pop("task_group") start_date = partial_kwargs.pop("start_date") end_date = partial_kwargs.pop("end_date") - max_active_tis_per_dag = (partial_kwargs.pop("max_active_tis_per_dag", None)) + max_active_tis_per_dag = partial_kwargs.pop("max_active_tis_per_dag", None) return StreamedOperator( task_id=task_id, From 
5d97c317d6a5c9246bef6c6815e08cdf9d4912ec Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 30 Sep 2024 18:48:41 +0200 Subject: [PATCH 07/97] refactor: Fixed some mypy issues --- airflow/models/baseoperator.py | 2 +- airflow/models/streamedoperator.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 0fe3fabbcd92c..b03b205058eef 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -1730,7 +1730,7 @@ def defer( raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout) @classmethod - def next_callable(cls, operator, next_method, next_kwargs) -> Callable[[Context], Any]: + def next_callable(cls, operator, next_method, next_kwargs) -> Callable[[Context, ...], Any]: """Get the next callable from given operator.""" # __fail__ is a special signal value for next_method that indicates # this task was scheduled specifically to fail. diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index 3b14b47d33c5d..a831cb54f80f9 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -168,7 +168,7 @@ class StreamedOperator(BaseOperator): def __init__( self, *, - operator_class: type[BaseOperator] | dict[str, Any], + operator_class: type[BaseOperator], expand_input: ExpandInput, partial_kwargs: dict[str, Any] | None = None, **kwargs: Any, ): @@ -177,7 +177,7 @@ def __init__( self._operator_class = operator_class self._expand_input = expand_input self._partial_kwargs = partial_kwargs or {} - self._mapped_kwargs = [] + self._mapped_kwargs: list[dict] = [] self._semaphore = Semaphore(self.max_active_tis_per_dag or os.cpu_count()) XComArg.apply_upstream_relationship(self, self._expand_input.value) @property From 513edd01a0dafb829a8f9ef0c55f4a81b1c286b8 Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 30 Sep 2024 19:05:04 +0200 Subject: [PATCH 08/97] refactor: Reformatted StreamedOperator --- airflow/models/streamedoperator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index a831cb54f80f9..658a92951f735 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -295,7 +295,9 @@ def execute(self, context: Context): return self._run_futures( context=context, - futures=[self._create_future(context, index) for index, mapped_kwargs in enumerate(self._mapped_kwargs)], + futures=[ + self._create_future(context, index) for index, mapped_kwargs in enumerate(self._mapped_kwargs) + ], ) From d142ebfd135db38748cbdad888dc03b0587f56fb Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 30 Sep 2024 19:23:54 +0200 Subject: [PATCH 09/97] refactor: Changed next_callable to instance method in BaseOperator --- airflow/models/baseoperator.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index b03b205058eef..5b14e27a4da64 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -1729,8 +1729,7 @@ def defer( """ raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs,
timeout=timeout) - @classmethod - def next_callable(cls, operator, next_method, next_kwargs) -> Callable[[Context, ...], Any]: + def next_callable(self, operator, next_method, next_kwargs) -> Callable[[Context, ...], Any]: """Get the next callable from given operator.""" # __fail__ is a special signal value for next_method that indicates # this task was scheduled specifically to fail. @@ -1738,7 +1737,7 @@ def next_callable(cls, operator, next_method, next_kwargs) -> Callable[[Context, next_kwargs = next_kwargs or {} traceback = next_kwargs.get("traceback") if traceback is not None: - logging.error("Trigger failed:\n%s", "\n".join(traceback)) + self.log.error("Trigger failed:\n%s", "\n".join(traceback)) raise TaskDeferralError(next_kwargs.get("error", "Unknown")) # Grab the callable off the Operator/Task and add in any kwargs execute_callable = getattr(operator, next_method) From 985dbbeefa22ed6c110a7112d03e9aba1894f6cd Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 30 Sep 2024 19:35:37 +0200 Subject: [PATCH 10/97] refactor: Added docstrings in StreamedOperator --- airflow/models/streamedoperator.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index 658a92951f735..3a34147ed6d8d 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -25,10 +25,7 @@ from datetime import datetime from math import ceil from time import sleep -from typing import Any, Callable, Generator, Sequence - -import jinja2 -from sqlalchemy.orm import Session +from typing import TYPE_CHECKING, Any, Callable, Generator, Sequence from airflow import XComArg from airflow.exceptions import ( @@ -47,7 +44,6 @@ ensure_xcomarg_return_value, validate_mapping_kwargs, ) -from airflow.triggers.base import BaseTrigger, TriggerEvent from airflow.utils import timezone from airflow.utils.context import Context, context_get_outlet_events from airflow.utils.helpers import prevent_duplicates @@ -55,6 +51,12 @@ from airflow.utils.operator_helpers import ExecutionCallableRunner from airflow.utils.task_instance_session import get_current_task_instance_session +if TYPE_CHECKING: + import jinja2 + + from sqlalchemy.orm import Session + from airflow.triggers.base import BaseTrigger, TriggerEvent + @contextmanager def event_loop() -> Generator[AbstractEventLoop, None, None]: @@ -73,7 +75,17 @@ async def run_trigger(trigger: BaseTrigger) -> list[TriggerEvent]: return events -class OperatorMethodExecutor(LoggingMixin): +class OperatorExecutor(LoggingMixin): + """ + Run an operator with given task context and task instance. + + If the execute function raises a TaskDeferred exception, then the trigger instance within the + TaskDeferred exception will be executed with the given context and task instance. The operator + or trigger will always be executed in an async way. 
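+ Concurrency is bounded by the semaphore passed in by the parent StreamedOperator.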
+ + :meta private: + """ + def __init__( self, semaphore: Semaphore, @@ -161,6 +173,8 @@ async def run(self, method: Callable, *args, **kwargs): class StreamedOperator(BaseOperator): + """Object representing a streamed operator in a DAG.""" + _operator_class: type[BaseOperator] | dict[str, Any] _expand_input: ExpandInput _partial_kwargs: dict[str, Any] @@ -263,7 +277,7 @@ def _run_futures(self, context: Context, futures, results: list[Any] | None = No async def _run_task(self, context: Context, task_instance: TaskInstance): operator = task_instance.task self.log.debug("operator: %s", operator) - result = await OperatorMethodExecutor( + result = await OperatorExecutor( semaphore=self._semaphore, operator=operator, context=context, From 542727a0433144893c5375a4592cfa423b3593f9 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 1 Oct 2024 08:16:33 +0200 Subject: [PATCH 11/97] refactor: Changed next_callable method in BaseOperator back to classmethod --- airflow/models/baseoperator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 5b14e27a4da64..5e46a4db9d1ed 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -1729,7 +1729,8 @@ def defer( """ raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout) - def next_callable(self, operator, next_method, next_kwargs) -> Callable[[Context, ...], Any]: + @classmethod + def next_callable(cls, operator, next_method, next_kwargs) -> Callable[[Context, ...], Any]: """Get the next callable from given operator.""" # __fail__ is a special signal value for next_method that indicates # this task was scheduled specifically to fail. @@ -1737,7 +1738,7 @@ def next_callable(self, operator, next_method, next_kwargs) -> Callable[[Context, next_kwargs = next_kwargs or {} traceback = next_kwargs.get("traceback") if traceback is not None: - self.log.error("Trigger failed:\n%s", "\n".join(traceback)) + cls.logger().error("Trigger failed:\n%s", "\n".join(traceback)) raise TaskDeferralError(next_kwargs.get("error", "Unknown")) # Grab the callable off the Operator/Task and add in any kwargs execute_callable = getattr(operator, next_method) From cb00425254c74aa2888efe0984d35243bbbc20b4 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 1 Oct 2024 08:28:02 +0200 Subject: [PATCH 12/97] refactor: Changed typing of next_callable method in BaseOperator --- airflow/models/baseoperator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 5e46a4db9d1ed..25f7d5c278dad 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -1730,7 +1730,7 @@ def defer( raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout) @classmethod - def next_callable(cls, operator, next_method, next_kwargs) -> Callable[[Context, ...], Any]: + def next_callable(cls, operator, next_method, next_kwargs) -> Callable[..., Any]: """Get the next callable from given operator.""" # __fail__ is a special signal value for next_method that indicates # this task was scheduled specifically to fail.
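Taken together, patches 01-12 complete the in-process retry loop: every mapped element runs as a coroutine under the shared semaphore, a retryable failure raises AirflowRescheduleTaskInstanceException carrying its TaskInstance, and the parent gathers those exceptions and re-runs only the failed elements once the retry delay has passed. Condensed from _run_futures (logging and edge cases elided, names simplified; not a verbatim quote of the patch):

    for result in loop.run_until_complete(asyncio.gather(*futures, return_exceptions=True)):
        if isinstance(result, AirflowRescheduleTaskInstanceException):
            reschedule_date = result.reschedule_date  # when the element may run again
            failed.append(ensure_future(self._run_task(context, result.task_instance)))
        elif isinstance(result, Exception):
            raise result  # a non-retryable error fails the whole streamed task
        else:
            succeeded.append(result)
    # sleep until reschedule_date, then gather only the failed futures again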
From 08a7c7f9519557823c0ba6701418777f2649f84e Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 1 Oct 2024 08:29:46 +0200 Subject: [PATCH 13/97] refactor: Fixed some typing issues StreamedOperator --- airflow/models/streamedoperator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index 3a34147ed6d8d..0e2fc57b4368a 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -33,7 +33,7 @@ AirflowRescheduleTaskInstanceException, TaskDeferred, ) -from airflow.models import BaseOperator, TaskInstance +from airflow.models.baseoperator import BaseOperator from airflow.models.expandinput import ( DictOfListsExpandInput, ExpandInput, @@ -44,6 +44,7 @@ ensure_xcomarg_return_value, validate_mapping_kwargs, ) +from airflow.models.taskinstance import TaskInstance from airflow.utils import timezone from airflow.utils.context import Context, context_get_outlet_events from airflow.utils.helpers import prevent_duplicates From 2712b85845f2f0cc9867a4ce7e071410815cbaf5 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 1 Oct 2024 08:33:03 +0200 Subject: [PATCH 14/97] refactor: Initialise reschedule_date as utcnow by default --- airflow/models/streamedoperator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index 0e2fc57b4368a..c5e58b83fab3e 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -236,7 +236,7 @@ def render_template_fields( self._resolve_expand_input(context=context, session=session) def _run_futures(self, context: Context, futures, results: list[Any] | None = None) -> list[Any]: - reschedule_date: datetime | None = None + reschedule_date = timezone.utcnow() results = results or [] failed_futures = [] From 33bb6b75882b85e433daad7fce0146a454d2caaf Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 1 Oct 2024 10:07:03 +0200 Subject: [PATCH 15/97] refactor: Reorganized imports in StreamedOperator --- airflow/models/streamedoperator.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index c5e58b83fab3e..e7f5213f3b5d8 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -22,7 +22,6 @@ import os from asyncio import AbstractEventLoop, Semaphore, ensure_future, iscoroutinefunction from contextlib import contextmanager -from datetime import datetime from math import ceil from time import sleep from typing import TYPE_CHECKING, Any, Callable, Generator, Sequence @@ -54,8 +53,8 @@ if TYPE_CHECKING: import jinja2 - from sqlalchemy.orm import Session + from airflow.triggers.base import BaseTrigger, TriggerEvent From 8906fa9a5fe00493b5093827e068f8a2a125f517 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 1 Oct 2024 11:10:18 +0200 Subject: [PATCH 16/97] refactor: Force cast task to BaseOperator --- airflow/models/streamedoperator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index e7f5213f3b5d8..efb8e2db96fb4 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -24,7 +24,7 @@ from contextlib import contextmanager from math import ceil from time import sleep -from typing import TYPE_CHECKING, Any, Callable, Generator, Sequence +from typing import TYPE_CHECKING, Any, Callable, 
Generator, Sequence, cast from airflow import XComArg from airflow.exceptions import ( @@ -275,7 +275,7 @@ def _run_futures(self, context: Context, futures, results: list[Any] | None = No return self._run_futures(context, failed_futures, results) async def _run_task(self, context: Context, task_instance: TaskInstance): - operator = task_instance.task + operator: BaseOperator = cast(BaseOperator, task_instance.task) self.log.debug("operator: %s", operator) result = await OperatorExecutor( semaphore=self._semaphore, From eaa734a08a3a8ffec57a18e3b0ff64cbc8491201 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 1 Oct 2024 11:31:38 +0200 Subject: [PATCH 17/97] refactor: Make sure semaphore is correctly initialized with int value --- airflow/models/streamedoperator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index efb8e2db96fb4..cf313471016f4 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -192,7 +192,7 @@ def __init__( self._expand_input = expand_input self._partial_kwargs = partial_kwargs or {} self._mapped_kwargs: list[dict] = [] - self._semaphore = Semaphore(self.max_active_tis_per_dag or os.cpu_count()) + self._semaphore = Semaphore(self.max_active_tis_per_dag or os.cpu_count() or 1) XComArg.apply_upstream_relationship(self, self._expand_input.value) @property From d2e93c3ec6efede6c2db4060d4309bc52ada451d Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 1 Oct 2024 11:35:34 +0200 Subject: [PATCH 18/97] refactor: Cast operator to type[BaseOperator] --- airflow/models/streamedoperator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index cf313471016f4..80637364aea11 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -339,7 +339,7 @@ def stream(self, **mapped_kwargs: OperatorExpandArgument) -> StreamedOperator: start_date=start_date, end_date=end_date, max_active_tis_per_dag=max_active_tis_per_dag, - operator_class=self.operator_class, + operator_class=cast(type[BaseOperator], self.operator_class), expand_input=expand_input, retries=0, partial_kwargs=self.kwargs.copy(), From de3f31424b4044fca30a33f7d7fadf58e5024fd9 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 1 Oct 2024 11:38:45 +0200 Subject: [PATCH 19/97] refactor: Check if expand_input value is of type dict --- airflow/models/streamedoperator.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index 80637364aea11..06682bf84e60d 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -211,20 +211,21 @@ def _unmap_operator(self, index): return self._operator_class(**kwargs) def _resolve_expand_input(self, context: Context, session: Session): - for key, value in self._expand_input.value.items(): - if _needs_run_time_resolution(value): - value = value.resolve(context=context, session=session) + if isinstance(self._expand_input.value, dict): + for key, value in self._expand_input.value.items(): + if _needs_run_time_resolution(value): + value = value.resolve(context=context, session=session) - if isinstance(value, Sequence) and not isinstance(value, (str, bytes)): - value = list(value) + if isinstance(value, Sequence) and not isinstance(value, (str, bytes)): + value = list(value) - 
self.log.debug("resolved_value: %s", value) + self.log.debug("resolved_value: %s", value) - if isinstance(value, list): - self._mapped_kwargs.extend([{key: item} for item in value]) - else: - self._mapped_kwargs.append({key: value}) - self.log.debug("resolve_expand_input: %s", self._mapped_kwargs) + if isinstance(value, list): + self._mapped_kwargs.extend([{key: item} for item in value]) + else: + self._mapped_kwargs.append({key: value}) + self.log.debug("resolve_expand_input: %s", self._mapped_kwargs) def render_template_fields( self, From ceb024f9f67db993dd20ed3823661cdf93077bbc Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 1 Oct 2024 11:53:40 +0200 Subject: [PATCH 20/97] refactor: Operator class is of type BaseOperator only --- airflow/models/streamedoperator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index 06682bf84e60d..781531f31888c 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -175,7 +175,7 @@ async def run(self, method: Callable, *args, **kwargs): class StreamedOperator(BaseOperator): """Object representing a streamed operator in a DAG.""" - _operator_class: type[BaseOperator] | dict[str, Any] + _operator_class: type[BaseOperator] _expand_input: ExpandInput _partial_kwargs: dict[str, Any] @@ -340,7 +340,7 @@ def stream(self, **mapped_kwargs: OperatorExpandArgument) -> StreamedOperator: start_date=start_date, end_date=end_date, max_active_tis_per_dag=max_active_tis_per_dag, - operator_class=cast(type[BaseOperator], self.operator_class), + operator_class=self.operator_class, expand_input=expand_input, retries=0, partial_kwargs=self.kwargs.copy(), From c7ec7c09f72683dc5675cf6ea0371d35661e55eb Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 1 Oct 2024 12:04:42 +0200 Subject: [PATCH 21/97] refactor: Close the event loop if it was a newly created one in event_loop context manager --- airflow/models/streamedoperator.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index 781531f31888c..ea3c99e609928 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -60,14 +60,22 @@ @contextmanager def event_loop() -> Generator[AbstractEventLoop, None, None]: + new_event_loop = False + loop = None try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - yield loop - - + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + new_event_loop = True + yield loop + finally: + if new_event_loop and loop is not None: + loop.close() + + +# TODO: Check def _run_inline_trigger(trigger) method from DAG, could be refactored so it uses this method async def run_trigger(trigger: BaseTrigger) -> list[TriggerEvent]: events = [] async for event in trigger.run(): From 555b68bb80082a9129d53e35270761fad11de262 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 1 Oct 2024 12:09:54 +0200 Subject: [PATCH 22/97] refactor: Refactored run_trigger method --- airflow/models/streamedoperator.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index ea3c99e609928..8dba1c6d4e8ae 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -76,11 +76,9 
@@ def event_loop() -> Generator[AbstractEventLoop, None, None]: # TODO: Check def _run_inline_trigger(trigger) method from DAG, could be refactored so it uses this method -async def run_trigger(trigger: BaseTrigger) -> list[TriggerEvent]: - events = [] +async def run_trigger(trigger: BaseTrigger) -> TriggerEvent: async for event in trigger.run(): - events.append(event) - return events + return event class OperatorExecutor(LoggingMixin): @@ -155,7 +153,7 @@ async def _run_callable(self, method: Callable, *args, **kwargs): raise AirflowRescheduleTaskInstanceException(task_instance=self.task_instance) async def run_deferrable(self, context: Context, task_deferred: TaskDeferred): - event = next(iter(await run_trigger(task_deferred.trigger))) + event = await run_trigger(task_deferred.trigger) self.log.debug("event: %s", event) self.log.debug("next_method: %s", task_deferred.method_name) From b2b9cee1332f76f85131a412a2dcb08b136f8fc0 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 1 Oct 2024 12:16:53 +0200 Subject: [PATCH 23/97] refactor: Moved run_trigger method to BaseTrigger module so it can be reused by StreamedOperator and DAG --- airflow/models/dag.py | 6 ++---- airflow/models/streamedoperator.py | 9 +-------- airflow/triggers/base.py | 5 +++++ 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 91f8aec7302cb..92f5c180a7d30 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -3449,11 +3449,9 @@ def get_current_dag(cls) -> DAG | None: def _run_inline_trigger(trigger): - async def _run_inline_trigger_main(): - async for event in trigger.run(): - return event + from airflow.triggers.base import run_trigger - return asyncio.run(_run_inline_trigger_main()) + return asyncio.run(run_trigger(trigger)) def _run_task( diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index 8dba1c6d4e8ae..b7f08c88fdda8 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -44,6 +44,7 @@ validate_mapping_kwargs, ) from airflow.models.taskinstance import TaskInstance +from airflow.triggers.base import run_trigger from airflow.utils import timezone from airflow.utils.context import Context, context_get_outlet_events from airflow.utils.helpers import prevent_duplicates @@ -55,8 +56,6 @@ import jinja2 from sqlalchemy.orm import Session - from airflow.triggers.base import BaseTrigger, TriggerEvent - @contextmanager def event_loop() -> Generator[AbstractEventLoop, None, None]: @@ -75,12 +74,6 @@ def event_loop() -> Generator[AbstractEventLoop, None, None]: loop.close() -# TODO: Check def _run_inline_trigger(trigger) method from DAG, could be refactored so it uses this method -async def run_trigger(trigger: BaseTrigger) -> TriggerEvent: - async for event in trigger.run(): - return event - - class OperatorExecutor(LoggingMixin): """ Run an operator with given task context and task instance. 
diff --git a/airflow/triggers/base.py b/airflow/triggers/base.py index bc1da861f3c2d..18783c7fa2adf 100644 --- a/airflow/triggers/base.py +++ b/airflow/triggers/base.py @@ -37,6 +37,11 @@ log = logging.getLogger(__name__) +async def run_trigger(trigger: BaseTrigger) -> TriggerEvent: + async for event in trigger.run(): + return event + + @dataclass class StartTriggerArgs: """Arguments required for start task execution from triggerer.""" From cb6f0de33a75eefc239793b5662f8d9dd462241e Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 1 Oct 2024 14:48:35 +0200 Subject: [PATCH 24/97] refactor: It is possible the run_trigger method doesn't yield any TriggerEvent --- airflow/triggers/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/triggers/base.py b/airflow/triggers/base.py index 18783c7fa2adf..c3ee57c443b65 100644 --- a/airflow/triggers/base.py +++ b/airflow/triggers/base.py @@ -37,7 +37,7 @@ log = logging.getLogger(__name__) -async def run_trigger(trigger: BaseTrigger) -> TriggerEvent: +async def run_trigger(trigger: BaseTrigger) -> TriggerEvent | None: async for event in trigger.run(): return event From 6c18b749011c041ee9f58e29cb82e5941354a59e Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 1 Oct 2024 14:57:16 +0200 Subject: [PATCH 25/97] refactor: Return None in run_trigger method of no events are returned --- airflow/triggers/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/airflow/triggers/base.py b/airflow/triggers/base.py index c3ee57c443b65..c973c819db252 100644 --- a/airflow/triggers/base.py +++ b/airflow/triggers/base.py @@ -40,6 +40,7 @@ async def run_trigger(trigger: BaseTrigger) -> TriggerEvent | None: async for event in trigger.run(): return event + return None @dataclass From be49bfe1e10db67301680f4eb1c93332d9a90eb9 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 1 Oct 2024 14:59:14 +0200 Subject: [PATCH 26/97] refactor: Check if returned event in _run_deferrable method isn't None --- airflow/models/streamedoperator.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index b7f08c88fdda8..4df4794b1eba5 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -145,19 +145,21 @@ async def _run_callable(self, method: Callable, *args, **kwargs): raise AirflowRescheduleTaskInstanceException(task_instance=self.task_instance) - async def run_deferrable(self, context: Context, task_deferred: TaskDeferred): + async def _run_deferrable(self, context: Context, task_deferred: TaskDeferred): event = await run_trigger(task_deferred.trigger) self.log.debug("event: %s", event) - self.log.debug("next_method: %s", task_deferred.method_name) - if task_deferred.method_name: - next_method = BaseOperator.next_callable( - self.operator, task_deferred.method_name, task_deferred.kwargs - ) - result = next_method(context, event.payload) - self.log.debug("result: %s", result) - return result + if event: + self.log.debug("next_method: %s", task_deferred.method_name) + + if task_deferred.method_name: + next_method = BaseOperator.next_callable( + self.operator, task_deferred.method_name, task_deferred.kwargs + ) + result = next_method(context, event.payload) + self.log.debug("result: %s", result) + return result async def run(self, method: Callable, *args, **kwargs): self.operator.pre_execute(context=self.context) @@ -166,7 +168,7 @@ async def run(self, method: Callable, *args, **kwargs): try: return await 
self._run_callable(method, *(list(args or ()) + [self.context]), **kwargs) except TaskDeferred as task_deferred: - return await self._run_callable(self.run_deferrable, *[self.context, task_deferred]) + return await self._run_callable(self._run_deferrable, *[self.context, task_deferred]) finally: self.operator.post_execute(context=self.context) From f514f90f5a386f0436b56da100d206e8c9c42173 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 2 Oct 2024 13:33:54 +0200 Subject: [PATCH 27/97] refactor: Added stream method on PartialOperator --- airflow/models/mappedoperator.py | 33 ++++++++++++++++++++++++++ airflow/models/streamedoperator.py | 38 ------------------------------ 2 files changed, 33 insertions(+), 38 deletions(-) diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 7cb44ca486b06..3e682a026337a 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -80,6 +80,7 @@ ) from airflow.models.operator import Operator from airflow.models.param import ParamsDict + from airflow.models.streamedoperator import StreamedOperator from airflow.models.xcom_arg import XComArg from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.utils.context import Context @@ -239,6 +240,38 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator: ) return op + def stream(self, **mapped_kwargs: OperatorExpandArgument) -> StreamedOperator: + from airflow.models.streamedoperator import StreamedOperator + + if not mapped_kwargs: + raise TypeError("no arguments to expand against") + validate_mapping_kwargs(self.operator_class, "stream", mapped_kwargs) + prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified") + + expand_input = DictOfListsExpandInput(mapped_kwargs) + ensure_xcomarg_return_value(expand_input.value) + + partial_kwargs = self.kwargs.copy() + task_id = partial_kwargs.pop("task_id") + dag = partial_kwargs.pop("dag") + task_group = partial_kwargs.pop("task_group") + start_date = partial_kwargs.pop("start_date") + end_date = partial_kwargs.pop("end_date") + max_active_tis_per_dag = partial_kwargs.pop("max_active_tis_per_dag", None) + + return StreamedOperator( + task_id=task_id, + dag=dag, + task_group=task_group, + start_date=start_date, + end_date=end_date, + max_active_tis_per_dag=max_active_tis_per_dag, + operator_class=self.operator_class, + expand_input=expand_input, + retries=0, + partial_kwargs=self.kwargs.copy(), + ) + @attr.define( kw_only=True, diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index 4df4794b1eba5..b7808f9c24dff 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -34,20 +34,13 @@ ) from airflow.models.baseoperator import BaseOperator from airflow.models.expandinput import ( - DictOfListsExpandInput, ExpandInput, - OperatorExpandArgument, _needs_run_time_resolution, ) -from airflow.models.mappedoperator import ( - ensure_xcomarg_return_value, - validate_mapping_kwargs, -) from airflow.models.taskinstance import TaskInstance from airflow.triggers.base import run_trigger from airflow.utils import timezone from airflow.utils.context import Context, context_get_outlet_events -from airflow.utils.helpers import prevent_duplicates from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.operator_helpers import ExecutionCallableRunner from airflow.utils.task_instance_session import get_current_task_instance_session @@ -315,34 +308,3 @@ def execute(self, 
context: Context): self._create_future(context, index) for index, mapped_kwargs in enumerate(self._mapped_kwargs) ], ) - - -def stream(self, **mapped_kwargs: OperatorExpandArgument) -> StreamedOperator: - if not mapped_kwargs: - raise TypeError("no arguments to expand against") - validate_mapping_kwargs(self.operator_class, "stream", mapped_kwargs) - prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified") - - expand_input = DictOfListsExpandInput(mapped_kwargs) - ensure_xcomarg_return_value(expand_input.value) - - partial_kwargs = self.kwargs.copy() - task_id = partial_kwargs.pop("task_id") - dag = partial_kwargs.pop("dag") - task_group = partial_kwargs.pop("task_group") - start_date = partial_kwargs.pop("start_date") - end_date = partial_kwargs.pop("end_date") - max_active_tis_per_dag = partial_kwargs.pop("max_active_tis_per_dag", None) - - return StreamedOperator( - task_id=task_id, - dag=dag, - task_group=task_group, - start_date=start_date, - end_date=end_date, - max_active_tis_per_dag=max_active_tis_per_dag, - operator_class=self.operator_class, - expand_input=expand_input, - retries=0, - partial_kwargs=self.kwargs.copy(), - ) From c0e8c460f39c04200a50ee98cd0ef39baa43b5dd Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 2 Oct 2024 13:34:14 +0200 Subject: [PATCH 28/97] refactor: Added unit test for StreamedOperator --- tests/models/test_streamedoperator.py | 1751 +++++++++++++++++++++++++ 1 file changed, 1751 insertions(+) create mode 100644 tests/models/test_streamedoperator.py diff --git a/tests/models/test_streamedoperator.py b/tests/models/test_streamedoperator.py new file mode 100644 index 0000000000000..17aab897d2498 --- /dev/null +++ b/tests/models/test_streamedoperator.py @@ -0,0 +1,1751 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
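+
+# Editorial note (not part of the original patch): this suite appears to
+# mirror the existing MappedOperator tests, substituting the new
+# ``.stream()`` API for ``.expand()`` so that the parse-time behaviour of
+# streamed tasks can be compared one-to-one with classic mapped tasks.
+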
+from __future__ import annotations + +from collections import defaultdict +from datetime import timedelta +from typing import TYPE_CHECKING +from unittest.mock import patch + +import pendulum +import pytest +from sqlalchemy import select + +from airflow.decorators import setup, task, task_group, teardown +from airflow.exceptions import AirflowSkipException +from airflow.models.baseoperator import BaseOperator +from airflow.models.dag import DAG +from airflow.models.mappedoperator import MappedOperator +from airflow.models.param import ParamsDict +from airflow.models.streamedoperator import StreamedOperator +from airflow.models.taskinstance import TaskInstance +from airflow.models.taskmap import TaskMap +from airflow.models.xcom_arg import XComArg +from airflow.operators.python import PythonOperator +from airflow.utils.state import TaskInstanceState +from airflow.utils.task_group import TaskGroup +from airflow.utils.task_instance_session import set_current_task_instance_session +from airflow.utils.trigger_rule import TriggerRule +from airflow.utils.xcom import XCOM_RETURN_KEY +from tests.models import DEFAULT_DATE +from tests.test_utils.mapping import expand_mapped_task +from tests.test_utils.mock_operators import MockOperator, MockOperatorWithNestedFields, NestedFields + +pytestmark = pytest.mark.db_test + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +def test_task_mapping_with_dag(): + with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE) as dag: + task1 = BaseOperator(task_id="op1") + literal = ["a", "b", "c"] + mapped = MockOperator.partial(task_id="task_2").stream(arg2=literal) + finish = MockOperator(task_id="finish") + + task1 >> mapped >> finish + + assert task1.downstream_list == [mapped] + assert mapped in dag.tasks + assert mapped.task_group == dag.task_group + # At parse time there should only be three tasks! + assert len(dag.tasks) == 3 + + assert finish.upstream_list == [mapped] + assert mapped.downstream_list == [finish] + + +@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode +@patch("airflow.models.abstractoperator.AbstractOperator.render_template") +def test_task_mapping_with_dag_and_list_of_pandas_dataframe(mock_render_template, caplog): + class UnrenderableClass: + def __bool__(self): + raise ValueError("Similar to Pandas DataFrames, this class raises an exception.") + + class CustomOperator(BaseOperator): + template_fields = ("arg",) + + def __init__(self, arg, **kwargs): + super().__init__(**kwargs) + self.arg = arg + + def execute(self, context: Context): + pass + + with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE) as dag: + task1 = CustomOperator(task_id="op1", arg=None) + unrenderable_values = [UnrenderableClass(), UnrenderableClass()] + mapped = CustomOperator.partial(task_id="task_2").stream(arg=unrenderable_values) + task1 >> mapped + dag.test() + assert ( + "Unable to check if the value of type 'UnrenderableClass' is False for task 'task_2', field 'arg'" + in caplog.text + ) + mock_render_template.assert_called() + + +def test_task_mapping_without_dag_context(): + with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE) as dag: + task1 = BaseOperator(task_id="op1") + literal = ["a", "b", "c"] + streamed = MockOperator.partial(task_id="task_2").stream(arg2=literal) + + task1 >> streamed + + assert isinstance(streamed, StreamedOperator) + assert streamed in dag.tasks + assert task1.downstream_list == [streamed] + # At parse time there should only be two tasks! 
+ assert len(dag.tasks) == 2 + + +def test_task_mapping_default_args(): + default_args = {"start_date": DEFAULT_DATE.now(), "owner": "test"} + with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE, default_args=default_args): + task1 = BaseOperator(task_id="op1") + literal = ["a", "b", "c"] + streamed = MockOperator.partial(task_id="task_2").stream(arg2=literal) + + task1 >> streamed + + assert streamed.owner == "test" + assert streamed.start_date == pendulum.instance(default_args["start_date"]) + + +def test_task_mapping_override_default_args(): + default_args = {"retries": 2, "start_date": DEFAULT_DATE.now()} + with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE, default_args=default_args): + literal = ["a", "b", "c"] + streamed = MockOperator.partial(task_id="task", retries=1).stream(arg2=literal) + + # retries should be 0 because it will be applied on the streamed tasks + assert streamed.retries == 0 + # start_date should be equal to default_args["start_date"] because it is not provided as partial arg + assert streamed.start_date == pendulum.instance(default_args["start_date"]) + # owner should be equal to Airflow default owner (airflow) because it is not provided at all + assert streamed.owner == "airflow" + + +def test_map_unknown_arg_raises(): + with pytest.raises(TypeError, match=r"argument 'file'"): + BaseOperator.partial(task_id="a").stream(file=[1, 2, {"a": "b"}]) + + +def test_map_xcom_arg(): + """Test that dependencies are correct when mapping with an XComArg""" + with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE): + task1 = BaseOperator(task_id="op1") + mapped = MockOperator.partial(task_id="task_2").stream(arg2=task1.output) + finish = MockOperator(task_id="finish") + + mapped >> finish + + assert task1.downstream_list == [mapped] + + +@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode +def test_map_xcom_arg_multiple_upstream_xcoms(dag_maker, session): + """Test that the correct number of downstream tasks are generated when mapping with an XComArg""" + + class PushExtraXComOperator(BaseOperator): + """Push an extra XCom value along with the default return value.""" + + def __init__(self, return_value, **kwargs): + super().__init__(**kwargs) + self.return_value = return_value + + def execute(self, context): + context["task_instance"].xcom_push(key="extra_key", value="extra_value") + return self.return_value + + with dag_maker("test-dag", session=session, start_date=DEFAULT_DATE) as dag: + upstream_return = [1, 2, 3] + task1 = PushExtraXComOperator(return_value=upstream_return, task_id="task_1") + task2 = PushExtraXComOperator.partial(task_id="task_2").stream(return_value=task1.output) + task3 = PushExtraXComOperator.partial(task_id="task_3").stream(return_value=task2.output) + + dr = dag_maker.create_dagrun() + ti_1 = dr.get_task_instance("task_1", session) + ti_1.run() + + ti_2s, _ = task2.expand_mapped_task(dr.run_id, session=session) + for ti in ti_2s: + ti.refresh_from_task(dag.get_task("task_2")) + ti.run() + + ti_3s, _ = task3.expand_mapped_task(dr.run_id, session=session) + for ti in ti_3s: + ti.refresh_from_task(dag.get_task("task_3")) + ti.run() + + assert len(ti_3s) == len(ti_2s) == len(upstream_return) + + +def test_partial_on_instance() -> None: + """`.partial` on an instance should fail -- it's only designed to be called on classes""" + with pytest.raises(TypeError): + MockOperator(task_id="a").partial() + + +def test_partial_on_class() -> None: + # Test that we accept args for superclasses too + op = 
MockOperator.partial(task_id="a", arg1="a", trigger_rule=TriggerRule.ONE_FAILED) + assert op.kwargs["arg1"] == "a" + assert op.kwargs["trigger_rule"] == TriggerRule.ONE_FAILED + + +def test_partial_on_class_invalid_ctor_args() -> None: + """Test that when we pass invalid args to partial(). + + I.e. if an arg is not known on the class or any of its parent classes we error at parse time + """ + with pytest.raises(TypeError, match=r"arguments 'foo', 'bar'"): + MockOperator.partial(task_id="a", foo="bar", bar=2) + + +@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode +@pytest.mark.parametrize( + ["num_existing_tis", "expected"], + ( + pytest.param(0, [(0, None), (1, None), (2, None)], id="only-unmapped-ti-exists"), + pytest.param( + 3, + [(0, "success"), (1, "success"), (2, "success")], + id="all-tis-exist", + ), + pytest.param( + 5, + [ + (0, "success"), + (1, "success"), + (2, "success"), + (3, TaskInstanceState.REMOVED), + (4, TaskInstanceState.REMOVED), + ], + id="tis-to-be-removed", + ), + ), +) +def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expected): + literal = [1, 2, {"a": "b"}] + with dag_maker(session=session): + task1 = BaseOperator(task_id="op1") + mapped = MockOperator.partial(task_id="task_2").stream(arg2=task1.output) + + dr = dag_maker.create_dagrun() + + session.add( + TaskMap( + dag_id=dr.dag_id, + task_id=task1.task_id, + run_id=dr.run_id, + map_index=-1, + length=len(literal), + keys=None, + ) + ) + + if num_existing_tis: + # Remove the map_index=-1 TI when we're creating other TIs + session.query(TaskInstance).filter( + TaskInstance.dag_id == mapped.dag_id, + TaskInstance.task_id == mapped.task_id, + TaskInstance.run_id == dr.run_id, + ).delete() + + for index in range(num_existing_tis): + # Give the existing TIs a state to make sure we don't change them + ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index, state=TaskInstanceState.SUCCESS) + session.add(ti) + session.flush() + + mapped.expand_mapped_task(dr.run_id, session=session) + + indices = ( + session.query(TaskInstance.map_index, TaskInstance.state) + .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id) + .order_by(TaskInstance.map_index) + .all() + ) + + assert indices == expected + + +@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode +def test_expand_mapped_task_failed_state_in_db(dag_maker, session): + """ + This test tries to recreate a faulty state in the database and checks if we can recover from it. + The state that happens is that there exists mapped task instances and the unmapped task instance. + So we have instances with map_index [-1, 0, 1]. The -1 task instances should be removed in this case. 
+ """ + literal = [1, 2] + with dag_maker(session=session): + task1 = BaseOperator(task_id="op1") + mapped = MockOperator.partial(task_id="task_2").stream(arg2=task1.output) + + dr = dag_maker.create_dagrun() + + session.add( + TaskMap( + dag_id=dr.dag_id, + task_id=task1.task_id, + run_id=dr.run_id, + map_index=-1, + length=len(literal), + keys=None, + ) + ) + + for index in range(2): + # Give the existing TIs a state to make sure we don't change them + ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index, state=TaskInstanceState.SUCCESS) + session.add(ti) + session.flush() + + indices = ( + session.query(TaskInstance.map_index, TaskInstance.state) + .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id) + .order_by(TaskInstance.map_index) + .all() + ) + # Make sure we have the faulty state in the database + assert indices == [(-1, None), (0, "success"), (1, "success")] + + mapped.expand_mapped_task(dr.run_id, session=session) + + indices = ( + session.query(TaskInstance.map_index, TaskInstance.state) + .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id) + .order_by(TaskInstance.map_index) + .all() + ) + # The -1 index should be cleaned up + assert indices == [(0, "success"), (1, "success")] + + +@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode +def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session): + with dag_maker(session=session): + task1 = BaseOperator(task_id="op1") + mapped = MockOperator.partial(task_id="task_2").stream(arg2=task1.output) + + dr = dag_maker.create_dagrun() + + expand_mapped_task(mapped, dr.run_id, task1.task_id, length=0, session=session) + + indices = ( + session.query(TaskInstance.map_index, TaskInstance.state) + .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id) + .order_by(TaskInstance.map_index) + .all() + ) + + assert indices == [(-1, TaskInstanceState.SKIPPED)] + + +def test_mapped_task_applies_default_args_classic(dag_maker): + with dag_maker(default_args={"execution_timeout": timedelta(minutes=30)}) as dag: + MockOperator(task_id="simple", arg1=None, arg2=0) + MockOperator.partial(task_id="mapped").stream(arg1=[1], arg2=[2, 3]) + + assert dag.get_task("simple").execution_timeout == timedelta(minutes=30) + assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30) + + +def test_mapped_task_applies_default_args_taskflow(dag_maker): + with dag_maker(default_args={"execution_timeout": timedelta(minutes=30)}) as dag: + + @dag.task + def simple(arg): + pass + + @dag.task + def mapped(arg): + pass + + simple(arg=0) + mapped.stream(arg=[1, 2]) + + assert dag.get_task("simple").execution_timeout == timedelta(minutes=30) + assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30) + + +@pytest.mark.parametrize( + "dag_params, task_params, expected_partial_params", + [ + pytest.param(None, None, ParamsDict(), id="none"), + pytest.param({"a": -1}, None, ParamsDict({"a": -1}), id="dag"), + pytest.param(None, {"b": -2}, ParamsDict({"b": -2}), id="task"), + pytest.param({"a": -1}, {"b": -2}, ParamsDict({"a": -1, "b": -2}), id="merge"), + ], +) +def test_mapped_expand_against_params(dag_maker, dag_params, task_params, expected_partial_params): + with dag_maker(params=dag_params) as dag: + MockOperator.partial(task_id="t", params=task_params).stream(params=[{"c": "x"}, {"d": 1}]) + + t = dag.get_task("t") + assert isinstance(t, MappedOperator) + assert t.params == expected_partial_params + assert t.expand_input.value == 
{"params": [{"c": "x"}, {"d": 1}]} + + +@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode +def test_mapped_render_template_fields_validating_operator(dag_maker, session, tmp_path): + file_template_dir = tmp_path / "path" / "to" + file_template_dir.mkdir(parents=True, exist_ok=True) + file_template = file_template_dir / "file.ext" + file_template.write_text("loaded data") + + with set_current_task_instance_session(session=session): + + class MyOperator(BaseOperator): + template_fields = ("partial_template", "map_template", "file_template") + template_ext = (".ext",) + + def __init__( + self, partial_template, partial_static, map_template, map_static, file_template, **kwargs + ): + for value in [partial_template, partial_static, map_template, map_static, file_template]: + assert isinstance(value, str), "value should have been resolved before unmapping" + super().__init__(**kwargs) + self.partial_template = partial_template + self.partial_static = partial_static + self.map_template = map_template + self.map_static = map_static + self.file_template = file_template + + def execute(self, context): + pass + + with dag_maker(session=session, template_searchpath=tmp_path.__fspath__()): + task1 = BaseOperator(task_id="op1") + output1 = task1.output + mapped = MyOperator.partial( + task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" + ).stream(map_template=output1, map_static=output1, file_template=["/path/to/file.ext"]) + + dr = dag_maker.create_dagrun() + ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) + + ti.xcom_push(key=XCOM_RETURN_KEY, value=["{{ ds }}"], session=session) + + session.add( + TaskMap( + dag_id=dr.dag_id, + task_id=task1.task_id, + run_id=dr.run_id, + map_index=-1, + length=1, + keys=None, + ) + ) + session.flush() + + mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session) + mapped_ti.map_index = 0 + + assert isinstance(mapped_ti.task, MappedOperator) + mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) + assert isinstance(mapped_ti.task, MyOperator) + + assert mapped_ti.task.partial_template == "a", "Should be templated!" + assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!" + assert mapped_ti.task.map_template == "{{ ds }}", "Should not be templated!" + assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!" + assert mapped_ti.task.file_template == "loaded data", "Should be templated!" 
+ + +@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode +def test_mapped_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, tmp_path): + file_template_dir = tmp_path / "path" / "to" + file_template_dir.mkdir(parents=True, exist_ok=True) + file_template = file_template_dir / "file.ext" + file_template.write_text("loaded data") + + with set_current_task_instance_session(session=session): + + class MyOperator(BaseOperator): + template_fields = ("partial_template", "map_template", "file_template") + template_ext = (".ext",) + + def __init__( + self, partial_template, partial_static, map_template, map_static, file_template, **kwargs + ): + for value in [partial_template, partial_static, map_template, map_static, file_template]: + assert isinstance(value, str), "value should have been resolved before unmapping" + super().__init__(**kwargs) + self.partial_template = partial_template + self.partial_static = partial_static + self.map_template = map_template + self.map_static = map_static + self.file_template = file_template + + def execute(self, context): + pass + + with dag_maker(session=session, template_searchpath=tmp_path.__fspath__()): + mapped = MyOperator.partial( + task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" + ).expand_kwargs( + [{"map_template": "{{ ds }}", "map_static": "{{ ds }}", "file_template": "/path/to/file.ext"}] + ) + + dr = dag_maker.create_dagrun() + + mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session, map_index=0) + + assert isinstance(mapped_ti.task, MappedOperator) + mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) + assert isinstance(mapped_ti.task, MyOperator) + + assert mapped_ti.task.partial_template == "a", "Should be templated!" + assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be templated!" + assert mapped_ti.task.map_template == "2016-01-01", "Should be templated!" + assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!" + assert mapped_ti.task.file_template == "loaded data", "Should be templated!" 
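+
+
+# Editorial note (not part of the original patch): contrast this with the
+# ``stream()`` variant above -- with ``expand_kwargs`` the mapped
+# ``map_template`` value *is* rendered ("{{ ds }}" becomes "2016-01-01"),
+# whereas a value mapped through ``stream()`` is passed along untouched.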
+ + +@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode +def test_mapped_render_nested_template_fields(dag_maker, session): + with dag_maker(session=session): + MockOperatorWithNestedFields.partial( + task_id="t", arg2=NestedFields(field_1="{{ ti.task_id }}", field_2="value_2") + ).stream(arg1=["{{ ti.task_id }}", ["s", "{{ ti.task_id }}"]]) + + dr = dag_maker.create_dagrun() + decision = dr.task_instance_scheduling_decisions() + tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis} + assert len(tis) == 2 + + ti = tis[("t", 0)] + ti.run(session=session) + assert ti.task.arg1 == "t" + assert ti.task.arg2.field_1 == "t" + assert ti.task.arg2.field_2 == "value_2" + + ti = tis[("t", 1)] + ti.run(session=session) + assert ti.task.arg1 == ["s", "t"] + assert ti.task.arg2.field_1 == "t" + assert ti.task.arg2.field_2 == "value_2" + + +@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode +@pytest.mark.parametrize( + ["num_existing_tis", "expected"], + ( + pytest.param(0, [(0, None), (1, None), (2, None)], id="only-unmapped-ti-exists"), + pytest.param( + 3, + [(0, "success"), (1, "success"), (2, "success")], + id="all-tis-exist", + ), + pytest.param( + 5, + [ + (0, "success"), + (1, "success"), + (2, "success"), + (3, TaskInstanceState.REMOVED), + (4, TaskInstanceState.REMOVED), + ], + id="tis-to-be-removed", + ), + ), +) +def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis, expected): + literal = [{"arg1": "a"}, {"arg1": "b"}, {"arg1": "c"}] + with dag_maker(session=session): + task1 = BaseOperator(task_id="op1") + mapped = MockOperator.partial(task_id="task_2").expand_kwargs(task1.output) + + dr = dag_maker.create_dagrun() + + session.add( + TaskMap( + dag_id=dr.dag_id, + task_id=task1.task_id, + run_id=dr.run_id, + map_index=-1, + length=len(literal), + keys=None, + ) + ) + + if num_existing_tis: + # Remove the map_index=-1 TI when we're creating other TIs + session.query(TaskInstance).filter( + TaskInstance.dag_id == mapped.dag_id, + TaskInstance.task_id == mapped.task_id, + TaskInstance.run_id == dr.run_id, + ).delete() + + for index in range(num_existing_tis): + # Give the existing TIs a state to make sure we don't change them + ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index, state=TaskInstanceState.SUCCESS) + session.add(ti) + session.flush() + + mapped.expand_mapped_task(dr.run_id, session=session) + + indices = ( + session.query(TaskInstance.map_index, TaskInstance.state) + .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id) + .order_by(TaskInstance.map_index) + .all() + ) + + assert indices == expected + + +def _create_mapped_with_name_template_classic(*, task_id, map_names, template): + class HasMapName(BaseOperator): + def __init__(self, *, map_name: str, **kwargs): + super().__init__(**kwargs) + self.map_name = map_name + + def execute(self, context): + context["map_name"] = self.map_name + + return HasMapName.partial(task_id=task_id, map_index_template=template).stream( + map_name=map_names, + ) + + +def _create_mapped_with_name_template_taskflow(*, task_id, map_names, template): + from airflow.operators.python import get_current_context + + @task(task_id=task_id, map_index_template=template) + def task1(map_name): + context = get_current_context() + context["map_name"] = map_name + + return task1.stream(map_name=map_names) + + +def _create_named_map_index_renders_on_failure_classic(*, task_id, map_names, template): + class 
HasMapName(BaseOperator): + def __init__(self, *, map_name: str, **kwargs): + super().__init__(**kwargs) + self.map_name = map_name + + def execute(self, context): + context["map_name"] = self.map_name + raise AirflowSkipException("Imagine this task failed!") + + return HasMapName.partial(task_id=task_id, map_index_template=template).stream( + map_name=map_names, + ) + + +def _create_named_map_index_renders_on_failure_taskflow(*, task_id, map_names, template): + from airflow.operators.python import get_current_context + + @task(task_id=task_id, map_index_template=template) + def task1(map_name): + context = get_current_context() + context["map_name"] = map_name + raise AirflowSkipException("Imagine this task failed!") + + return task1.stream(map_name=map_names) + + +@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode +@pytest.mark.parametrize( + "template, expected_rendered_names", + [ + pytest.param(None, [None, None], id="unset"), + pytest.param("", ["", ""], id="constant"), + pytest.param("{{ ti.task_id }}-{{ ti.map_index }}", ["task1-0", "task1-1"], id="builtin"), + pytest.param("{{ ti.task_id }}-{{ map_name }}", ["task1-a", "task1-b"], id="custom"), + ], +) +@pytest.mark.parametrize( + "create_mapped_task", + [ + pytest.param(_create_mapped_with_name_template_classic, id="classic"), + pytest.param(_create_mapped_with_name_template_taskflow, id="taskflow"), + pytest.param(_create_named_map_index_renders_on_failure_classic, id="classic-failure"), + pytest.param(_create_named_map_index_renders_on_failure_taskflow, id="taskflow-failure"), + ], +) +def test_expand_mapped_task_instance_with_named_index( + dag_maker, + session, + create_mapped_task, + template, + expected_rendered_names, +) -> None: + """Test that the correct number of downstream tasks are generated when mapping with an XComArg""" + with dag_maker("test-dag", session=session, start_date=DEFAULT_DATE): + create_mapped_task(task_id="task1", map_names=["a", "b"], template=template) + + dr = dag_maker.create_dagrun() + tis = dr.get_task_instances() + for ti in tis: + ti.run() + session.flush() + + indices = session.scalars( + select(TaskInstance.rendered_map_index) + .where( + TaskInstance.dag_id == "test-dag", + TaskInstance.task_id == "task1", + TaskInstance.run_id == dr.run_id, + ) + .order_by(TaskInstance.map_index) + ).all() + + assert indices == expected_rendered_names + + +@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode +@pytest.mark.parametrize( + "map_index, expected", + [ + pytest.param(0, "2016-01-01", id="0"), + pytest.param(1, 2, id="1"), + ], +) +def test_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, map_index, expected): + with set_current_task_instance_session(session=session): + with dag_maker(session=session): + task1 = BaseOperator(task_id="op1") + mapped = MockOperator.partial(task_id="a", arg2="{{ ti.task_id }}").expand_kwargs(task1.output) + + dr = dag_maker.create_dagrun() + ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) + + ti.xcom_push(key=XCOM_RETURN_KEY, value=[{"arg1": "{{ ds }}"}, {"arg1": 2}], session=session) + + session.add( + TaskMap( + dag_id=dr.dag_id, + task_id=task1.task_id, + run_id=dr.run_id, + map_index=-1, + length=2, + keys=None, + ) + ) + session.flush() + + ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session) + ti.refresh_from_task(mapped) + ti.map_index = map_index + assert isinstance(ti.task, MappedOperator) + 
mapped.render_template_fields(context=ti.get_template_context(session=session)) + assert isinstance(ti.task, MockOperator) + assert ti.task.arg1 == expected + assert ti.task.arg2 == "a" + + +def test_xcomarg_property_of_mapped_operator(dag_maker): + with dag_maker("test_xcomarg_property_of_mapped_operator"): + op_a = MockOperator.partial(task_id="a").stream(arg1=["x", "y", "z"]) + dag_maker.create_dagrun() + + assert op_a.output == XComArg(op_a) + + +def test_set_xcomarg_dependencies_with_mapped_operator(dag_maker): + with dag_maker("test_set_xcomargs_dependencies_with_mapped_operator"): + op1 = MockOperator.partial(task_id="op1").stream(arg1=[1, 2, 3]) + op2 = MockOperator.partial(task_id="op2").stream(arg2=["a", "b", "c"]) + op3 = MockOperator(task_id="op3", arg1=op1.output) + op4 = MockOperator(task_id="op4", arg1=[op1.output, op2.output]) + op5 = MockOperator(task_id="op5", arg1={"op1": op1.output, "op2": op2.output}) + + assert op1 in op3.upstream_list + assert op1 in op4.upstream_list + assert op2 in op4.upstream_list + assert op1 in op5.upstream_list + assert op2 in op5.upstream_list + + +def test_all_xcomargs_from_mapped_tasks_are_consumable(dag_maker, session): + class PushXcomOperator(MockOperator): + def __init__(self, arg1, **kwargs): + super().__init__(arg1=arg1, **kwargs) + + def execute(self, context): + return self.arg1 + + class ConsumeXcomOperator(PushXcomOperator): + def execute(self, context): + assert set(self.arg1) == {1, 2, 3} + + with dag_maker("test_all_xcomargs_from_mapped_tasks_are_consumable"): + op1 = PushXcomOperator.partial(task_id="op1").stream(arg1=[1, 2, 3]) + ConsumeXcomOperator(task_id="op2", arg1=op1.output) + + dr = dag_maker.create_dagrun() + tis = dr.get_task_instances(session=session) + for ti in tis: + ti.run() + + +def test_task_mapping_with_task_group_context(): + with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE) as dag: + task1 = BaseOperator(task_id="op1") + finish = MockOperator(task_id="finish") + + with TaskGroup("test-group") as group: + literal = ["a", "b", "c"] + mapped = MockOperator.partial(task_id="task_2").stream(arg2=literal) + + task1 >> group >> finish + + assert task1.downstream_list == [mapped] + assert mapped.upstream_list == [task1] + + assert mapped in dag.tasks + assert mapped.task_group == group + + assert finish.upstream_list == [mapped] + assert mapped.downstream_list == [finish] + + +def test_task_mapping_with_explicit_task_group(): + with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE) as dag: + task1 = BaseOperator(task_id="op1") + finish = MockOperator(task_id="finish") + + group = TaskGroup("test-group") + literal = ["a", "b", "c"] + mapped = MockOperator.partial(task_id="task_2", task_group=group).stream(arg2=literal) + + task1 >> group >> finish + + assert task1.downstream_list == [mapped] + assert mapped.upstream_list == [task1] + + assert mapped in dag.tasks + assert mapped.task_group == group + + assert finish.upstream_list == [mapped] + assert mapped.downstream_list == [finish] + + +class TestMappedSetupTeardown: + @staticmethod + def get_states(dr): + ti_dict = defaultdict(dict) + for ti in dr.get_task_instances(): + if ti.map_index == -1: + ti_dict[ti.task_id] = ti.state + else: + ti_dict[ti.task_id][ti.map_index] = ti.state + return dict(ti_dict) + + def classic_operator(self, task_id, ret=None, partial=False, fail=False): + def success_callable(ret=None): + def inner(*args, **kwargs): + print(args) + print(kwargs) + if ret: + return ret + + return inner + + def failure_callable(): + def 
inner(*args, **kwargs): + print(args) + print(kwargs) + raise ValueError("fail") + + return inner + + kwargs = dict(task_id=task_id) + if not fail: + kwargs.update(python_callable=success_callable(ret=ret)) + else: + kwargs.update(python_callable=failure_callable()) + if partial: + return PythonOperator.partial(**kwargs) + else: + return PythonOperator(**kwargs) + + @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode + @pytest.mark.parametrize("type_", ["taskflow", "classic"]) + def test_one_to_many_work_failed(self, type_, dag_maker): + """ + Work task failed. Setup maps to teardown. Should have 3 teardowns all successful even + though the work task has failed. + """ + if type_ == "taskflow": + with dag_maker() as dag: + + @setup + def my_setup(): + print("setting up multiple things") + return [1, 2, 3] + + @task + def my_work(val): + print(f"doing work with multiple things: {val}") + raise ValueError("fail!") + + @teardown + def my_teardown(val): + print(f"teardown: {val}") + + s = my_setup() + t = my_teardown.stream(val=s) + with t: + my_work(s) + else: + + @task + def my_work(val): + print(f"work: {val}") + raise ValueError("i fail") + + with dag_maker() as dag: + my_setup = self.classic_operator("my_setup", [[1], [2], [3]]) + my_teardown = self.classic_operator("my_teardown", partial=True) + t = my_teardown.stream(op_args=my_setup.output) + with t.as_teardown(setups=my_setup): + my_work(my_setup.output) + + dr = dag.test() + states = self.get_states(dr) + expected = { + "my_setup": "success", + "my_work": "failed", + "my_teardown": {0: "success", 1: "success", 2: "success"}, + } + assert states == expected + + @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode + @pytest.mark.parametrize("type_", ["taskflow", "classic"]) + def test_many_one_explicit_odd_setup_mapped_setups_fail(self, type_, dag_maker): + """ + one unmapped setup goes to two different teardowns + one mapped setup goes to same teardown + mapped setups fail + teardowns should still run + """ + if type_ == "taskflow": + with dag_maker() as dag: + + @task + def other_setup(): + print("other setup") + return "other setup" + + @task + def other_work(): + print("other work") + return "other work" + + @task + def other_teardown(): + print("other teardown") + return "other teardown" + + @task + def my_setup(val): + print(f"setup: {val}") + raise ValueError("fail") + return val + + @task + def my_work(val): + print(f"work: {val}") + + @task + def my_teardown(val): + print(f"teardown: {val}") + + s = my_setup.stream(val=["data1.json", "data2.json", "data3.json"]) + o_setup = other_setup() + o_teardown = other_teardown() + with o_teardown.as_teardown(setups=o_setup): + other_work() + t = my_teardown(s).as_teardown(setups=s) + with t: + my_work(s) + o_setup >> t + else: + with dag_maker() as dag: + + @task + def other_work(): + print("other work") + return "other work" + + @task + def my_work(val): + print(f"work: {val}") + + my_teardown = self.classic_operator("my_teardown") + + my_setup = self.classic_operator("my_setup", partial=True, fail=True) + s = my_setup.stream(op_args=[["data1.json"], ["data2.json"], ["data3.json"]]) + o_setup = self.classic_operator("other_setup") + o_teardown = self.classic_operator("other_teardown") + with o_teardown.as_teardown(setups=o_setup): + other_work() + t = my_teardown.as_teardown(setups=s) + with t: + my_work(s.output) + o_setup >> t + + dr = dag.test() + states = self.get_states(dr) + expected = { + "my_setup": {0: "failed", 1: 
"failed", 2: "failed"}, + "other_setup": "success", + "other_teardown": "success", + "other_work": "success", + "my_teardown": "success", + "my_work": "upstream_failed", + } + assert states == expected + + @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode + @pytest.mark.parametrize("type_", ["taskflow", "classic"]) + def test_many_one_explicit_odd_setup_all_setups_fail(self, type_, dag_maker): + """ + one unmapped setup goes to two different teardowns + one mapped setup goes to same teardown + all setups fail + teardowns should not run + """ + if type_ == "taskflow": + with dag_maker() as dag: + + @task + def other_setup(): + print("other setup") + raise ValueError("fail") + return "other setup" + + @task + def other_work(): + print("other work") + return "other work" + + @task + def other_teardown(): + print("other teardown") + return "other teardown" + + @task + def my_setup(val): + print(f"setup: {val}") + raise ValueError("fail") + return val + + @task + def my_work(val): + print(f"work: {val}") + + @task + def my_teardown(val): + print(f"teardown: {val}") + + s = my_setup.stream(val=["data1.json", "data2.json", "data3.json"]) + o_setup = other_setup() + o_teardown = other_teardown() + with o_teardown.as_teardown(setups=o_setup): + other_work() + t = my_teardown(s).as_teardown(setups=s) + with t: + my_work(s) + o_setup >> t + else: + with dag_maker() as dag: + + @task + def other_setup(): + print("other setup") + raise ValueError("fail") + return "other setup" + + @task + def other_work(): + print("other work") + return "other work" + + @task + def other_teardown(): + print("other teardown") + return "other teardown" + + @task + def my_work(val): + print(f"work: {val}") + + my_setup = self.classic_operator("my_setup", partial=True, fail=True) + s = my_setup.stream(op_args=[["data1.json"], ["data2.json"], ["data3.json"]]) + o_setup = other_setup() + o_teardown = other_teardown() + with o_teardown.as_teardown(setups=o_setup): + other_work() + my_teardown = self.classic_operator("my_teardown") + t = my_teardown.as_teardown(setups=s) + with t: + my_work(s.output) + o_setup >> t + + dr = dag.test() + states = self.get_states(dr) + expected = { + "my_teardown": "upstream_failed", + "other_setup": "failed", + "other_work": "upstream_failed", + "other_teardown": "upstream_failed", + "my_setup": {0: "failed", 1: "failed", 2: "failed"}, + "my_work": "upstream_failed", + } + assert states == expected + + @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode + @pytest.mark.parametrize("type_", ["taskflow", "classic"]) + def test_many_one_explicit_odd_setup_one_mapped_fails(self, type_, dag_maker): + """ + one unmapped setup goes to two different teardowns + one mapped setup goes to same teardown + one of the mapped setup instances fails + teardowns should all run + """ + if type_ == "taskflow": + with dag_maker() as dag: + + @task + def other_setup(): + print("other setup") + return "other setup" + + @task + def other_work(): + print("other work") + return "other work" + + @task + def other_teardown(): + print("other teardown") + return "other teardown" + + @task + def my_setup(val): + if val == "data2.json": + raise ValueError("fail!") + elif val == "data3.json": + raise AirflowSkipException("skip!") + print(f"setup: {val}") + return val + + @task + def my_work(val): + print(f"work: {val}") + + @task + def my_teardown(val): + print(f"teardown: {val}") + + s = my_setup.stream(val=["data1.json", "data2.json", "data3.json"]) + o_setup = 
other_setup() + o_teardown = other_teardown() + with o_teardown.as_teardown(setups=o_setup): + other_work() + t = my_teardown(s).as_teardown(setups=s) + with t: + my_work(s) + o_setup >> t + else: + with dag_maker() as dag: + + @task + def other_setup(): + print("other setup") + return "other setup" + + @task + def other_work(): + print("other work") + return "other work" + + @task + def other_teardown(): + print("other teardown") + return "other teardown" + + def my_setup_callable(val): + if val == "data2.json": + raise ValueError("fail!") + elif val == "data3.json": + raise AirflowSkipException("skip!") + print(f"setup: {val}") + return val + + my_setup = PythonOperator.partial(task_id="my_setup", python_callable=my_setup_callable) + + @task + def my_work(val): + print(f"work: {val}") + + def my_teardown_callable(val): + print(f"teardown: {val}") + + s = my_setup.stream(op_args=[["data1.json"], ["data2.json"], ["data3.json"]]) + o_setup = other_setup() + o_teardown = other_teardown() + with o_teardown.as_teardown(setups=o_setup): + other_work() + my_teardown = PythonOperator( + task_id="my_teardown", op_args=[s.output], python_callable=my_teardown_callable + ) + t = my_teardown.as_teardown(setups=s) + with t: + my_work(s.output) + o_setup >> t + + dr = dag.test() + states = self.get_states(dr) + expected = { + "my_setup": {0: "success", 1: "failed", 2: "skipped"}, + "other_setup": "success", + "other_teardown": "success", + "other_work": "success", + "my_teardown": "success", + "my_work": "upstream_failed", + } + assert states == expected + + @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode + @pytest.mark.parametrize("type_", ["taskflow", "classic"]) + def test_one_to_many_as_teardown(self, type_, dag_maker): + """ + 1 setup mapping to 3 teardowns + 1 work task + work fails + teardowns succeed + dagrun should be failure + """ + if type_ == "taskflow": + with dag_maker() as dag: + + @task + def my_setup(): + print("setting up multiple things") + return [1, 2, 3] + + @task + def my_work(val): + print(f"doing work with multiple things: {val}") + raise ValueError("this fails") + return val + + @task + def my_teardown(val): + print(f"teardown: {val}") + + s = my_setup() + t = my_teardown.stream(val=s).as_teardown(setups=s) + with t: + my_work(s) + else: + with dag_maker() as dag: + + @task + def my_work(val): + print(f"doing work with multiple things: {val}") + raise ValueError("this fails") + return val + + my_teardown = self.classic_operator(task_id="my_teardown", partial=True) + + s = self.classic_operator(task_id="my_setup", ret=[[1], [2], [3]]) + t = my_teardown.stream(op_args=s.output).as_teardown(setups=s) + with t: + my_work(s) + dr = dag.test() + states = self.get_states(dr) + expected = { + "my_setup": "success", + "my_teardown": {0: "success", 1: "success", 2: "success"}, + "my_work": "failed", + } + assert states == expected + + @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode + @pytest.mark.parametrize("type_", ["taskflow", "classic"]) + def test_one_to_many_as_teardown_on_failure_fail_dagrun(self, type_, dag_maker): + """ + 1 setup mapping to 3 teardowns + 1 work task + work succeeds + all but one teardown succeed + on_failure_fail_dagrun=True + dagrun should be success + """ + if type_ == "taskflow": + with dag_maker() as dag: + + @task + def my_setup(): + print("setting up multiple things") + return [1, 2, 3] + + @task + def my_work(val): + print(f"doing work with multiple things: {val}") + return val + + @task 
+ def my_teardown(val): + print(f"teardown: {val}") + if val == 2: + raise ValueError("failure") + + s = my_setup() + t = my_teardown.stream(val=s).as_teardown(setups=s, on_failure_fail_dagrun=True) + with t: + my_work(s) + # todo: if on_failure_fail_dagrun=True, should we still regard the WORK task as a leaf? + else: + with dag_maker() as dag: + + @task + def my_work(val): + print(f"doing work with multiple things: {val}") + return val + + def my_teardown_callable(val): + print(f"teardown: {val}") + if val == 2: + raise ValueError("failure") + + s = self.classic_operator(task_id="my_setup", ret=[[1], [2], [3]]) + my_teardown = PythonOperator.partial( + task_id="my_teardown", python_callable=my_teardown_callable + ).stream(op_args=s.output) + t = my_teardown.as_teardown(setups=s, on_failure_fail_dagrun=True) + with t: + my_work(s.output) + + dr = dag.test() + states = self.get_states(dr) + expected = { + "my_setup": "success", + "my_teardown": {0: "success", 1: "failed", 2: "success"}, + "my_work": "success", + } + assert states == expected + + @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode + @pytest.mark.parametrize("type_", ["taskflow", "classic"]) + def test_mapped_task_group_simple(self, type_, dag_maker, session): + """ + Mapped task group wherein there's a simple s >> w >> t pipeline. + When s is skipped, all should be skipped + When s is failed, all should be upstream failed + """ + if type_ == "taskflow": + with dag_maker() as dag: + + @setup + def my_setup(val): + if val == "data2.json": + raise ValueError("fail!") + elif val == "data3.json": + raise AirflowSkipException("skip!") + print(f"setup: {val}") + + @task + def my_work(val): + print(f"work: {val}") + + @teardown + def my_teardown(val): + print(f"teardown: {val}") + + @task_group + def file_transforms(filename): + s = my_setup(filename) + t = my_teardown(filename) + s >> t + with t: + my_work(filename) + + file_transforms.stream(filename=["data1.json", "data2.json", "data3.json"]) + else: + with dag_maker() as dag: + + def my_setup_callable(val): + if val == "data2.json": + raise ValueError("fail!") + elif val == "data3.json": + raise AirflowSkipException("skip!") + print(f"setup: {val}") + + @task + def my_work(val): + print(f"work: {val}") + + def my_teardown_callable(val): + print(f"teardown: {val}") + + @task_group + def file_transforms(filename): + s = PythonOperator( + task_id="my_setup", python_callable=my_setup_callable, op_args=filename + ) + t = PythonOperator( + task_id="my_teardown", python_callable=my_teardown_callable, op_args=filename + ) + with t.as_teardown(setups=s): + my_work(filename) + + file_transforms.stream(filename=[["data1.json"], ["data2.json"], ["data3.json"]]) + dr = dag.test() + states = self.get_states(dr) + expected = { + "file_transforms.my_setup": {0: "success", 1: "failed", 2: "skipped"}, + "file_transforms.my_work": {0: "success", 1: "upstream_failed", 2: "skipped"}, + "file_transforms.my_teardown": {0: "success", 1: "upstream_failed", 2: "skipped"}, + } + + assert states == expected + + @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode + @pytest.mark.parametrize("type_", ["taskflow", "classic"]) + def test_mapped_task_group_work_fail_or_skip(self, type_, dag_maker): + """ + Mapped task group wherein there's a simple s >> w >> t pipeline. 
+ When w is skipped, teardown should still run + When w is failed, teardown should still run + """ + if type_ == "taskflow": + with dag_maker() as dag: + + @setup + def my_setup(val): + print(f"setup: {val}") + + @task + def my_work(val): + if val == "data2.json": + raise ValueError("fail!") + elif val == "data3.json": + raise AirflowSkipException("skip!") + print(f"work: {val}") + + @teardown + def my_teardown(val): + print(f"teardown: {val}") + + @task_group + def file_transforms(filename): + s = my_setup(filename) + t = my_teardown(filename).as_teardown(setups=s) + with t: + my_work(filename) + + file_transforms.stream(filename=["data1.json", "data2.json", "data3.json"]) + else: + with dag_maker() as dag: + + @task + def my_work(vals): + val = vals[0] + if val == "data2.json": + raise ValueError("fail!") + elif val == "data3.json": + raise AirflowSkipException("skip!") + print(f"work: {val}") + + @teardown + def my_teardown(val): + print(f"teardown: {val}") + + def null_callable(val): + pass + + @task_group + def file_transforms(filename): + s = PythonOperator(task_id="my_setup", python_callable=null_callable, op_args=filename) + t = PythonOperator(task_id="my_teardown", python_callable=null_callable, op_args=filename) + t = t.as_teardown(setups=s) + with t: + my_work(filename) + + file_transforms.stream(filename=[["data1.json"], ["data2.json"], ["data3.json"]]) + dr = dag.test() + states = self.get_states(dr) + expected = { + "file_transforms.my_setup": {0: "success", 1: "success", 2: "success"}, + "file_transforms.my_teardown": {0: "success", 1: "success", 2: "success"}, + "file_transforms.my_work": {0: "success", 1: "failed", 2: "skipped"}, + } + assert states == expected + + @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode + @pytest.mark.parametrize("type_", ["taskflow", "classic"]) + def test_teardown_many_one_explicit(self, type_, dag_maker): + """-- passing + one mapped setup going to one unmapped work + 3 diff states for setup: success / failed / skipped + teardown still runs, and receives the xcom from the single successful setup + """ + if type_ == "taskflow": + with dag_maker() as dag: + + @task + def my_setup(val): + if val == "data2.json": + raise ValueError("fail!") + elif val == "data3.json": + raise AirflowSkipException("skip!") + print(f"setup: {val}") + return val + + @task + def my_work(val): + print(f"work: {val}") + + @task + def my_teardown(val): + print(f"teardown: {val}") + + s = my_setup.stream(val=["data1.json", "data2.json", "data3.json"]) + with my_teardown(s).as_teardown(setups=s): + my_work(s) + else: + with dag_maker() as dag: + + def my_setup_callable(val): + if val == "data2.json": + raise ValueError("fail!") + elif val == "data3.json": + raise AirflowSkipException("skip!") + print(f"setup: {val}") + return val + + @task + def my_work(val): + print(f"work: {val}") + + s = PythonOperator.partial(task_id="my_setup", python_callable=my_setup_callable) + s = s.stream(op_args=[["data1.json"], ["data2.json"], ["data3.json"]]) + t = self.classic_operator("my_teardown") + with t.as_teardown(setups=s): + my_work(s.output) + + dr = dag.test() + states = self.get_states(dr) + expected = { + "my_setup": {0: "success", 1: "failed", 2: "skipped"}, + "my_teardown": "success", + "my_work": "upstream_failed", + } + assert states == expected + + @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode + def test_one_to_many_with_teardown_and_fail_stop(self, dag_maker): + """ + With fail_stop enabled, the teardown 
for an already-completed setup + should not be skipped. + """ + with dag_maker(fail_stop=True) as dag: + + @task + def my_setup(): + print("setting up multiple things") + return [1, 2, 3] + + @task + def my_work(val): + print(f"doing work with multiple things: {val}") + raise ValueError("this fails") + return val + + @task + def my_teardown(val): + print(f"teardown: {val}") + + s = my_setup() + t = my_teardown.stream(val=s).as_teardown(setups=s) + with t: + my_work(s) + + dr = dag.test() + states = self.get_states(dr) + expected = { + "my_setup": "success", + "my_teardown": {0: "success", 1: "success", 2: "success"}, + "my_work": "failed", + } + assert states == expected + + @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode + def test_one_to_many_with_teardown_and_fail_stop_more_tasks(self, dag_maker): + """ + when fail_stop enabled, teardowns should run according to their setups. + in this case, the second teardown skips because its setup skips. + """ + with dag_maker(fail_stop=True) as dag: + for num in (1, 2): + with TaskGroup(f"tg_{num}"): + + @task + def my_setup(): + print("setting up multiple things") + return [1, 2, 3] + + @task + def my_work(val): + print(f"doing work with multiple things: {val}") + raise ValueError("this fails") + return val + + @task + def my_teardown(val): + print(f"teardown: {val}") + + s = my_setup() + t = my_teardown.stream(val=s).as_teardown(setups=s) + with t: + my_work(s) + tg1, tg2 = dag.task_group.children.values() + tg1 >> tg2 + dr = dag.test() + states = self.get_states(dr) + expected = { + "tg_1.my_setup": "success", + "tg_1.my_teardown": {0: "success", 1: "success", 2: "success"}, + "tg_1.my_work": "failed", + "tg_2.my_setup": "skipped", + "tg_2.my_teardown": "skipped", + "tg_2.my_work": "skipped", + } + assert states == expected + + @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode + def test_one_to_many_with_teardown_and_fail_stop_more_tasks_mapped_setup(self, dag_maker): + """ + when fail_stop enabled, teardowns should run according to their setups. + in this case, the second teardown skips because its setup skips. 
+ """ + with dag_maker(fail_stop=True) as dag: + for num in (1, 2): + with TaskGroup(f"tg_{num}"): + + @task + def my_pre_setup(): + print("input to the setup") + return [1, 2, 3] + + @task + def my_setup(val): + print("setting up multiple things") + return val + + @task + def my_work(val): + print(f"doing work with multiple things: {val}") + raise ValueError("this fails") + return val + + @task + def my_teardown(val): + print(f"teardown: {val}") + + s = my_setup.stream(val=my_pre_setup()) + t = my_teardown.stream(val=s).as_teardown(setups=s) + with t: + my_work(s) + tg1, tg2 = dag.task_group.children.values() + tg1 >> tg2 + dr = dag.test() + states = self.get_states(dr) + expected = { + "tg_1.my_pre_setup": "success", + "tg_1.my_setup": {0: "success", 1: "success", 2: "success"}, + "tg_1.my_teardown": {0: "success", 1: "success", 2: "success"}, + "tg_1.my_work": "failed", + "tg_2.my_pre_setup": "skipped", + "tg_2.my_setup": "skipped", + "tg_2.my_teardown": "skipped", + "tg_2.my_work": "skipped", + } + assert states == expected + + @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode + def test_skip_one_mapped_task_from_task_group_with_generator(self, dag_maker): + with dag_maker() as dag: + + @task + def make_list(): + return [1, 2, 3] + + @task + def double(n): + if n == 2: + raise AirflowSkipException() + return n * 2 + + @task + def last(n): ... + + @task_group + def group(n: int) -> None: + last(double(n)) + + group.stream(n=make_list()) + + dr = dag.test() + states = self.get_states(dr) + expected = { + "group.double": {0: "success", 1: "skipped", 2: "success"}, + "group.last": {0: "success", 1: "skipped", 2: "success"}, + "make_list": "success", + } + assert states == expected + + @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode + def test_skip_one_mapped_task_from_task_group(self, dag_maker): + with dag_maker() as dag: + + @task + def double(n): + if n == 2: + raise AirflowSkipException() + return n * 2 + + @task + def last(n): ... 
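+
+            # Note (editorial, not in the original patch): the skip raised in
+            # double() for n == 2 is expected to cascade to the corresponding
+            # mapped instance of last(), so map index 1 of both tasks ends up
+            # "skipped" in the expected states below.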
+ + @task_group + def group(n: int) -> None: + last(double(n)) + + group.stream(n=[1, 2, 3]) + + dr = dag.test() + states = self.get_states(dr) + expected = { + "group.double": {0: "success", 1: "skipped", 2: "success"}, + "group.last": {0: "success", 1: "skipped", 2: "success"}, + } + assert states == expected From e18b0662facd3f382062d61f94e246b3a5d0365c Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 2 Oct 2024 17:11:09 +0200 Subject: [PATCH 29/97] refactor: Added stream method to _TaskDecorator class --- airflow/decorators/base.py | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index bb9602d50c1cd..9be40e2b56414 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -56,7 +56,7 @@ ListOfDictsExpandInput, is_mappable, ) -from airflow.models.mappedoperator import MappedOperator, ensure_xcomarg_return_value +from airflow.models.mappedoperator import MappedOperator, ensure_xcomarg_return_value, validate_mapping_kwargs from airflow.models.pool import Pool from airflow.models.xcom_arg import XComArg from airflow.typing_compat import ParamSpec, Protocol @@ -532,6 +532,39 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg: ) return XComArg(operator=operator) + def stream(self, **mapped_kwargs: OperatorExpandArgument) -> XComArg: + from airflow.models.streamedoperator import StreamedOperator + + if not mapped_kwargs: + raise TypeError("no arguments to expand against") + validate_mapping_kwargs(self.operator_class, "stream", mapped_kwargs) + prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified") + + expand_input = DictOfListsExpandInput(mapped_kwargs) + ensure_xcomarg_return_value(expand_input.value) + + partial_kwargs = self.kwargs.copy() + task_id = partial_kwargs.pop("task_id") + dag = partial_kwargs.pop("dag") + task_group = partial_kwargs.pop("task_group") + start_date = partial_kwargs.pop("start_date") + end_date = partial_kwargs.pop("end_date") + max_active_tis_per_dag = partial_kwargs.pop("max_active_tis_per_dag", None) + + operator = StreamedOperator( + task_id=task_id, + dag=dag, + task_group=task_group, + start_date=start_date, + end_date=end_date, + max_active_tis_per_dag=max_active_tis_per_dag, + operator_class=self.operator_class, + expand_input=expand_input, + retries=0, + partial_kwargs=self.kwargs.copy(), + ) + return XComArg(operator=operator) + def partial(self, **kwargs: Any) -> _TaskDecorator[FParams, FReturn, OperatorSubclass]: self._validate_arg_names("partial", kwargs) old_kwargs = self.kwargs.get("op_kwargs", {}) From a688e4cf270bb48897a1aa35f2091d03646f1a9e Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 2 Oct 2024 17:58:40 +0200 Subject: [PATCH 30/97] refactor: Suppress AttributeError when close on loop fails as some event loops don't support it --- airflow/models/streamedoperator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index b7808f9c24dff..87fbba789440d 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -21,7 +21,7 @@ import logging import os from asyncio import AbstractEventLoop, Semaphore, ensure_future, iscoroutinefunction -from contextlib import contextmanager +from contextlib import contextmanager, suppress from math import ceil from time import sleep from typing import TYPE_CHECKING, Any, Callable, Generator, 
Sequence, cast @@ -64,7 +64,8 @@ def event_loop() -> Generator[AbstractEventLoop, None, None]: yield loop finally: if new_event_loop and loop is not None: - loop.close() + with suppress(AttributeError): + loop.close() class OperatorExecutor(LoggingMixin): From a7c79c17b3518031db4d4e2ac0b3ea387486650c Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 2 Oct 2024 18:52:21 +0200 Subject: [PATCH 31/97] refactor: Removed invocation of validate_mapping_kwargs in stream method of _TaskDecorator --- airflow/decorators/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 9be40e2b56414..8c77c7f2d0200 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -56,7 +56,7 @@ ListOfDictsExpandInput, is_mappable, ) -from airflow.models.mappedoperator import MappedOperator, ensure_xcomarg_return_value, validate_mapping_kwargs +from airflow.models.mappedoperator import MappedOperator, ensure_xcomarg_return_value from airflow.models.pool import Pool from airflow.models.xcom_arg import XComArg from airflow.typing_compat import ParamSpec, Protocol @@ -537,7 +537,6 @@ def stream(self, **mapped_kwargs: OperatorExpandArgument) -> XComArg: if not mapped_kwargs: raise TypeError("no arguments to expand against") - validate_mapping_kwargs(self.operator_class, "stream", mapped_kwargs) prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified") expand_input = DictOfListsExpandInput(mapped_kwargs) From 90cf03cf0c1d873e52c17e2334ae8fb402038606 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 2 Oct 2024 18:58:31 +0200 Subject: [PATCH 32/97] refactor: test_mapped_expand_against_params should check on instance of StreamedOperator --- tests/models/test_streamedoperator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_streamedoperator.py b/tests/models/test_streamedoperator.py index 17aab897d2498..2f618afec1d7f 100644 --- a/tests/models/test_streamedoperator.py +++ b/tests/models/test_streamedoperator.py @@ -401,7 +401,7 @@ def test_mapped_expand_against_params(dag_maker, dag_params, task_params, expect MockOperator.partial(task_id="t", params=task_params).stream(params=[{"c": "x"}, {"d": 1}]) t = dag.get_task("t") - assert isinstance(t, MappedOperator) + assert isinstance(t, StreamedOperator) assert t.params == expected_partial_params assert t.expand_input.value == {"params": [{"c": "x"}, {"d": 1}]} From 6f8399fb04132cff0649293dfd7494acae7261f2 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 2 Oct 2024 18:59:03 +0200 Subject: [PATCH 33/97] refactor: expand_input and partial_kwargs should be public instead of protected in StreamedOperator --- airflow/models/streamedoperator.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index 87fbba789440d..e86f88443dbd9 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -171,8 +171,8 @@ class StreamedOperator(BaseOperator): """Object representing a streamed operator in a DAG.""" _operator_class: type[BaseOperator] - _expand_input: ExpandInput - _partial_kwargs: dict[str, Any] + expand_input: ExpandInput + partial_kwargs: dict[str, Any] def __init__( self, @@ -184,11 +184,11 @@ def __init__( ): super().__init__(**kwargs) self._operator_class = operator_class - self._expand_input = expand_input - self._partial_kwargs = partial_kwargs or {} + 
self.expand_input = expand_input + self.partial_kwargs = partial_kwargs or {} self._mapped_kwargs: list[dict] = [] self._semaphore = Semaphore(self.max_active_tis_per_dag or os.cpu_count() or 1) - XComArg.apply_upstream_relationship(self, self._expand_input.value) + XComArg.apply_upstream_relationship(self, self.expand_input.value) @property def operator_name(self) -> str: @@ -197,8 +197,8 @@ def operator_name(self) -> str: def _unmap_operator(self, index): self.log.debug("index: %s", index) kwargs = { - **self._partial_kwargs, - **{"task_id": f"{self._partial_kwargs.get('task_id')}_{index}"}, + **self.partial_kwargs, + **{"task_id": f"{self.partial_kwargs.get('task_id')}_{index}"}, **self._mapped_kwargs[index], } self.log.debug("kwargs: %s", kwargs) @@ -206,8 +206,8 @@ def _unmap_operator(self, index): return self._operator_class(**kwargs) def _resolve_expand_input(self, context: Context, session: Session): - if isinstance(self._expand_input.value, dict): - for key, value in self._expand_input.value.items(): + if isinstance(self.expand_input.value, dict): + for key, value in self.expand_input.value.items(): if _needs_run_time_resolution(value): value = value.resolve(context=context, session=session) From 51e3d0c99ac0a97940eadb85cf918ddf42662255 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 2 Oct 2024 21:13:23 +0200 Subject: [PATCH 34/97] refactor: Skip test related to stream on TaskGroups as this isn't supported yet --- tests/models/test_streamedoperator.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/models/test_streamedoperator.py b/tests/models/test_streamedoperator.py index 2f618afec1d7f..28353c67f5ab7 100644 --- a/tests/models/test_streamedoperator.py +++ b/tests/models/test_streamedoperator.py @@ -20,6 +20,7 @@ from collections import defaultdict from datetime import timedelta from typing import TYPE_CHECKING +from unittest import skip from unittest.mock import patch import pendulum @@ -1353,6 +1354,7 @@ def my_teardown_callable(val): } assert states == expected + @skip("Stream is not yet implemented on TaskGroup") @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode @pytest.mark.parametrize("type_", ["taskflow", "classic"]) def test_mapped_task_group_simple(self, type_, dag_maker, session): @@ -1428,6 +1430,7 @@ def file_transforms(filename): assert states == expected + @skip("Stream is not yet implemented on TaskGroup") @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode @pytest.mark.parametrize("type_", ["taskflow", "classic"]) def test_mapped_task_group_work_fail_or_skip(self, type_, dag_maker): @@ -1691,6 +1694,7 @@ def my_teardown(val): } assert states == expected + @skip("Stream is not yet implemented on TaskGroup") @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode def test_skip_one_mapped_task_from_task_group_with_generator(self, dag_maker): with dag_maker() as dag: @@ -1723,6 +1727,7 @@ def group(n: int) -> None: } assert states == expected + @skip("Stream is not yet implemented on TaskGroup") @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode def test_skip_one_mapped_task_from_task_group(self, dag_maker): with dag_maker() as dag: From d607ab13b16aefeb28f02761ab0889923dedfaaf Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 3 Oct 2024 08:09:45 +0200 Subject: [PATCH 35/97] refactor: Use pytest.mark.skip instead of unittest.skip --- tests/models/test_streamedoperator.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 
deletions(-) diff --git a/tests/models/test_streamedoperator.py b/tests/models/test_streamedoperator.py index 28353c67f5ab7..92e9455d943fd 100644 --- a/tests/models/test_streamedoperator.py +++ b/tests/models/test_streamedoperator.py @@ -20,7 +20,6 @@ from collections import defaultdict from datetime import timedelta from typing import TYPE_CHECKING -from unittest import skip from unittest.mock import patch import pendulum @@ -1354,7 +1353,7 @@ def my_teardown_callable(val): } assert states == expected - @skip("Stream is not yet implemented on TaskGroup") + @pytest.mark.skip("Stream is not yet implemented on TaskGroup") @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode @pytest.mark.parametrize("type_", ["taskflow", "classic"]) def test_mapped_task_group_simple(self, type_, dag_maker, session): @@ -1430,7 +1429,7 @@ def file_transforms(filename): assert states == expected - @skip("Stream is not yet implemented on TaskGroup") + @pytest.mark.skip("Stream is not yet implemented on TaskGroup") @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode @pytest.mark.parametrize("type_", ["taskflow", "classic"]) def test_mapped_task_group_work_fail_or_skip(self, type_, dag_maker): @@ -1694,7 +1693,7 @@ def my_teardown(val): } assert states == expected - @skip("Stream is not yet implemented on TaskGroup") + @pytest.mark.skip("Stream is not yet implemented on TaskGroup") @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode def test_skip_one_mapped_task_from_task_group_with_generator(self, dag_maker): with dag_maker() as dag: @@ -1727,7 +1726,7 @@ def group(n: int) -> None: } assert states == expected - @skip("Stream is not yet implemented on TaskGroup") + @pytest.mark.skip("Stream is not yet implemented on TaskGroup") @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode def test_skip_one_mapped_task_from_task_group(self, dag_maker): with dag_maker() as dag: From ad3022e7d42204eeb34bd86bc828c2c1b4d36037 Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 3 Oct 2024 10:57:20 +0200 Subject: [PATCH 36/97] refactor: Added _get_specified_expand_input method to StreamedOperator --- airflow/models/streamedoperator.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index e86f88443dbd9..6ea7a13197961 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -194,6 +194,9 @@ def __init__( def operator_name(self) -> str: return self._operator_class.__name__ + def _get_specified_expand_input(self) -> ExpandInput: + return self.expand_input + def _unmap_operator(self, index): self.log.debug("index: %s", index) kwargs = { From c0eb1c64d4e5a71216290410122792779d9249ea Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 3 Oct 2024 11:13:35 +0200 Subject: [PATCH 37/97] refactor: Fixed logging of the number of workers being used in StreamedOperator, and only print the attempt logging statement once when running a task that raises a TaskDeferred exception --- airflow/models/streamedoperator.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index 6ea7a13197961..661b4a0e9467b 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -106,15 +106,6 @@ async def _run_callable(self, method: Callable, *args, **kwargs):
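
A note on the pattern this hunk trims: _run_callable retries a callable a bounded number of times while holding an asyncio.Semaphore, so at most max_active_tis_per_dag mapped task instances run at once. A minimal sketch of that shape, with simplified names and a plain Exception instead of Airflow's exception hierarchy (illustrative only, not the code in this patch):

    import asyncio

    async def run_with_retries(semaphore: asyncio.Semaphore, func, retries: int, *args):
        async with semaphore:  # cap how many task instances run concurrently
            attempt = 1
            while True:
                try:
                    return func(*args)
                except Exception:
                    if attempt > retries:
                        raise  # attempts exhausted; the caller reschedules or fails the task
                    attempt += 1

Because the semaphore is acquired once and held across all retries of the same task instance, the "Attempting running task" message only needs to be logged once per run, which is what this commit changes.
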
self.log.debug("semaphore: %s (%s)", self._semaphore, self._semaphore.locked()) async with self._semaphore: while self.task_instance.try_number <= self.operator.retries: - if self.log.isEnabledFor(logging.INFO): - self.log.info( - "Attempting running task %s of %s for %s with map_index %s.", - self.task_instance.try_number, - self.operator.retries, - type(self.operator).__name__, - self.task_instance.map_index, - ) - try: outlet_events = context_get_outlet_events(self.context) callable_runner = ExecutionCallableRunner( @@ -133,7 +124,6 @@ async def _run_callable(self, method: Callable, *args, **kwargs): ) raise e - self.log.error("An error occurred: %s", e) self.task_instance.try_number += 1 self.task_instance.end_date = timezone.utcnow() @@ -156,6 +146,15 @@ async def _run_deferrable(self, context: Context, task_deferred: TaskDeferred): return result async def run(self, method: Callable, *args, **kwargs): + if self.log.isEnabledFor(logging.INFO): + self.log.info( + "Attempting running task %s of %s for %s with map_index %s.", + self.task_instance.try_number, + self.operator.retries, + type(self.operator).__name__, + self.task_instance.map_index, + ) + self.operator.pre_execute(context=self.context) self.task_instance._run_execute_callback(context=self.context, task=self.operator) @@ -187,7 +186,9 @@ def __init__( self.expand_input = expand_input self.partial_kwargs = partial_kwargs or {} self._mapped_kwargs: list[dict] = [] - self._semaphore = Semaphore(self.max_active_tis_per_dag or os.cpu_count() or 1) + if not self.max_active_tis_per_dag: + self.max_active_tis_per_dag = os.cpu_count() or 1 + self._semaphore = Semaphore(self.max_active_tis_per_dag) XComArg.apply_upstream_relationship(self, self.expand_input.value) @property From 26782ff37a9a821f657abff29b293485c4291fb6 Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 3 Oct 2024 11:21:22 +0200 Subject: [PATCH 38/97] refactor: Pop task_group with None when key doesn't exist --- airflow/decorators/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 8c77c7f2d0200..9f4b0f880aa26 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -545,7 +545,7 @@ def stream(self, **mapped_kwargs: OperatorExpandArgument) -> XComArg: partial_kwargs = self.kwargs.copy() task_id = partial_kwargs.pop("task_id") dag = partial_kwargs.pop("dag") - task_group = partial_kwargs.pop("task_group") + task_group = partial_kwargs.pop("task_group", None) start_date = partial_kwargs.pop("start_date") end_date = partial_kwargs.pop("end_date") max_active_tis_per_dag = partial_kwargs.pop("max_active_tis_per_dag", None) From 8cdbd46247ba24edbb6b5bfc95f2ec2218c072ba Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 3 Oct 2024 11:28:26 +0200 Subject: [PATCH 39/97] refactor: Pop dag or task_group with None when key doesn't exist and lookup from current context when stream is invoked from TaskDecorator class --- airflow/decorators/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 9f4b0f880aa26..9f3bd93fa8efc 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -544,8 +544,8 @@ def stream(self, **mapped_kwargs: OperatorExpandArgument) -> XComArg: partial_kwargs = self.kwargs.copy() task_id = partial_kwargs.pop("task_id") - dag = partial_kwargs.pop("dag") - task_group = partial_kwargs.pop("task_group", None) + dag = partial_kwargs.pop("dag", None) or 
DagContext.get_current_dag() + task_group = partial_kwargs.pop("task_group", None) or TaskGroupContext.get_current_task_group(dag) start_date = partial_kwargs.pop("start_date") end_date = partial_kwargs.pop("end_date") max_active_tis_per_dag = partial_kwargs.pop("max_active_tis_per_dag", None) From 3409993ed2849960bd1b7ca8a94ed6b345405fb5 Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 3 Oct 2024 12:58:27 +0200 Subject: [PATCH 40/97] refactor: Added info logging statement when a task completes successfully in StreamedOperator --- airflow/models/streamedoperator.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index 661b4a0e9467b..ff3591d761ef5 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -165,6 +165,15 @@ async def run(self, method: Callable, *args, **kwargs): finally: self.operator.post_execute(context=self.context) + if self.log.isEnabledFor(logging.INFO): + self.log.info( + "Task %s of %s for %s with map_index %s finished successfully.", + self.task_instance.try_number, + self.operator.retries, + type(self.operator).__name__, + self.task_instance.map_index, + ) + class StreamedOperator(BaseOperator): """Object representing a streamed operator in a DAG.""" From 26a557a6dbfa60645f9914c36120dabb4360df5f Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 3 Oct 2024 13:31:26 +0200 Subject: [PATCH 41/97] refactor: Fixed info logging statement when a task completes successfully in StreamedOperator --- airflow/models/streamedoperator.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index ff3591d761ef5..fb25fb38a7be1 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -167,11 +167,10 @@ async def run(self, method: Callable, *args, **kwargs): if self.log.isEnabledFor(logging.INFO): self.log.info( - "Task %s of %s for %s with map_index %s finished successfully.", - self.task_instance.try_number, - self.operator.retries, - type(self.operator).__name__, + "Task instance %s for %s finished successfully in %s attempts.", self.task_instance.map_index, + type(self.operator).__name__, + self.task_instance.next_try_number, ) From 52fb43ce18c7f2009fec2ebad1d69a3d85bc4217 Mon Sep 17 00:00:00 2001 From: David Blain Date: Fri, 4 Oct 2024 09:58:42 +0200 Subject: [PATCH 42/97] refactor: expand_input attribute in StreamedOperator should be a shallow copy (and hopefully avoid TypeError: cannot pickle 'Context' object) --- airflow/models/streamedoperator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index fb25fb38a7be1..965b0dbd47bb7 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -180,6 +180,7 @@ class StreamedOperator(BaseOperator): _operator_class: type[BaseOperator] expand_input: ExpandInput partial_kwargs: dict[str, Any] + shallow_copy_attrs: Sequence[str] = ("expand_input",) def __init__( self, From cf54d74b3ef5637956097b24a7f840b5381e16b8 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 8 Oct 2024 10:57:07 +0200 Subject: [PATCH 43/97] refactor: task_id can be retrieved directly from self in the _unmap_operator method --- airflow/models/streamedoperator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index
965b0dbd47bb7..e8577d2298a90 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -211,7 +211,7 @@ def _unmap_operator(self, index): self.log.debug("index: %s", index) kwargs = { **self.partial_kwargs, - **{"task_id": f"{self.partial_kwargs.get('task_id')}_{index}"}, + **{"task_id": f"{self.task_id}_{index}"}, **self._mapped_kwargs[index], } self.log.debug("kwargs: %s", kwargs) From 00a598cfd1c26b4924ec4ff6626f8b1db67ad230 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 8 Oct 2024 13:45:29 +0200 Subject: [PATCH 44/97] refactor: Updated shallow_copy_attrs in StreamedOperator --- airflow/models/streamedoperator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index e8577d2298a90..b2f2905684bf0 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -180,7 +180,7 @@ class StreamedOperator(BaseOperator): _operator_class: type[BaseOperator] expand_input: ExpandInput partial_kwargs: dict[str, Any] - shallow_copy_attrs: Sequence[str] = ("expand_input",) + shallow_copy_attrs: Sequence[str] = ("expand_input", "partial_kwargs", "_log", "_semaphore") def __init__( self, From 7d1151226d6ba6115f78be7e6bb9a5f95de3244f Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 8 Oct 2024 13:58:34 +0200 Subject: [PATCH 45/97] refactor: Use existing is_mappable method instead of custom isinstance check --- airflow/models/streamedoperator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index b2f2905684bf0..e45c0294787ac 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -35,6 +35,7 @@ from airflow.models.baseoperator import BaseOperator from airflow.models.expandinput import ( ExpandInput, + is_mappable, _needs_run_time_resolution, ) from airflow.models.taskinstance import TaskInstance @@ -224,7 +225,7 @@ def _resolve_expand_input(self, context: Context, session: Session): if _needs_run_time_resolution(value): value = value.resolve(context=context, session=session) - if isinstance(value, Sequence) and not isinstance(value, (str, bytes)): + if is_mappable(value): value = list(value) self.log.debug("resolved_value: %s", value) From b76ccf517226ed56568eb29272e267660bde9740 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 9 Oct 2024 13:25:23 +0200 Subject: [PATCH 46/97] refactor: Reorganized imports of StreamedOperator --- airflow/models/streamedoperator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index e45c0294787ac..8c574e7d7988b 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -35,8 +35,8 @@ from airflow.models.baseoperator import BaseOperator from airflow.models.expandinput import ( ExpandInput, - is_mappable, _needs_run_time_resolution, + is_mappable, ) from airflow.models.taskinstance import TaskInstance from airflow.triggers.base import run_trigger From 14be647c1e660f10afa2e2610400dd31e99d2ac3 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 9 Oct 2024 13:28:55 +0200 Subject: [PATCH 47/97] refactor: Ignore type when evaluating value as list --- airflow/models/streamedoperator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index 8c574e7d7988b..5cded01f865e6 100644 --- 
a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -226,7 +226,7 @@ def _resolve_expand_input(self, context: Context, session: Session): value = value.resolve(context=context, session=session) if is_mappable(value): - value = list(value) + value = list(value) # type: ignore self.log.debug("resolved_value: %s", value) From d9ba001671fff7465f70bf7db9a6b74a36edb05c Mon Sep 17 00:00:00 2001 From: David Blain Date: Fri, 22 Nov 2024 17:12:37 +0100 Subject: [PATCH 48/97] refactor: Use ThreadPoolExecutor to execute tasks concurrently instead of purely relying on asyncio --- airflow/models/mappedoperator.py | 4 +- airflow/models/streamedoperator.py | 172 +++++++++++++++++++---------- 2 files changed, 118 insertions(+), 58 deletions(-) diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 7056f5344d19d..af9032ccc0043 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -40,7 +40,7 @@ DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, DEFAULT_WEIGHT_RULE, AbstractOperator, - NotMapped, + NotMapped, DEFAULT_TASK_EXECUTION_TIMEOUT, ) from airflow.models.expandinput import ( DictOfListsExpandInput, @@ -258,6 +258,7 @@ def stream(self, **mapped_kwargs: OperatorExpandArgument) -> StreamedOperator: start_date = partial_kwargs.pop("start_date") end_date = partial_kwargs.pop("end_date") max_active_tis_per_dag = partial_kwargs.pop("max_active_tis_per_dag", None) + execution_timeout = partial_kwargs.pop("execution_timeout", DEFAULT_TASK_EXECUTION_TIMEOUT) return StreamedOperator( task_id=task_id, @@ -265,6 +266,7 @@ def stream(self, **mapped_kwargs: OperatorExpandArgument) -> StreamedOperator: task_group=task_group, start_date=start_date, end_date=end_date, + execution_timeout=execution_timeout, max_active_tis_per_dag=max_active_tis_per_dag, operator_class=self.operator_class, expand_input=expand_input, diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index 5cded01f865e6..203f0603f197d 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -20,7 +20,8 @@ import asyncio import logging import os -from asyncio import AbstractEventLoop, Semaphore, ensure_future, iscoroutinefunction +from asyncio import AbstractEventLoop, Semaphore, ensure_future, iscoroutinefunction, wait_for +from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager, suppress from math import ceil from time import sleep @@ -82,13 +83,11 @@ class OperatorExecutor(LoggingMixin): def __init__( self, - semaphore: Semaphore, operator: BaseOperator, context: Context, task_instance: TaskInstance, ): super().__init__() - self._semaphore = semaphore self.operator = operator self.__context = context self._task_instance = task_instance @@ -104,31 +103,31 @@ def context(self) -> Context: return {**self.__context, **{"ti": self.task_instance}} async def _run_callable(self, method: Callable, *args, **kwargs): - self.log.debug("semaphore: %s (%s)", self._semaphore, self._semaphore.locked()) - async with self._semaphore: - while self.task_instance.try_number <= self.operator.retries: - try: - outlet_events = context_get_outlet_events(self.context) - callable_runner = ExecutionCallableRunner( - func=method, outlet_events=outlet_events, logger=self.log + while self.task_instance.try_number <= self.operator.retries: + try: + outlet_events = context_get_outlet_events(self.context) + callable_runner = ExecutionCallableRunner( + func=method, outlet_events=outlet_events, 
logger=self.log + ) + if iscoroutinefunction(method): + return await callable_runner.run(*args, **kwargs) + return callable_runner.run(*args, **kwargs) + except AirflowException as e: + if self.task_instance.next_try_number > self.operator.retries: + self.log.error( + "Max number of attempts for %s with map_index %s failed due to: %s", + type(self.operator).__name__, + self.task_instance.map_index, + e, ) - if iscoroutinefunction(method): - return await callable_runner.run(*args, **kwargs) - return callable_runner.run(*args, **kwargs) - except AirflowException as e: - if self.task_instance.next_try_number > self.operator.retries: - self.log.error( - "Max number of attempts for %s with map_index %s failed due to: %s", - type(self.operator).__name__, - self.task_instance.map_index, - e, - ) - raise e + raise e - self.task_instance.try_number += 1 - self.task_instance.end_date = timezone.utcnow() + self.task_instance.try_number += 1 + self.task_instance.end_date = timezone.utcnow() - raise AirflowRescheduleTaskInstanceException(task_instance=self.task_instance) + raise AirflowRescheduleTaskInstanceException( + task_instance=self.task_instance + ) async def _run_deferrable(self, context: Context, task_deferred: TaskDeferred): event = await run_trigger(task_deferred.trigger) @@ -157,12 +156,18 @@ async def run(self, method: Callable, *args, **kwargs): ) self.operator.pre_execute(context=self.context) - self.task_instance._run_execute_callback(context=self.context, task=self.operator) + self.task_instance._run_execute_callback( + context=self.context, task=self.operator + ) try: - return await self._run_callable(method, *(list(args or ()) + [self.context]), **kwargs) + return await self._run_callable( + method, *(list(args or ()) + [self.context]), **kwargs + ) except TaskDeferred as task_deferred: - return await self._run_callable(self._run_deferrable, *[self.context, task_deferred]) + return await self._run_callable( + self._run_deferrable, *[self.context, task_deferred] + ) finally: self.operator.post_execute(context=self.context) @@ -181,7 +186,13 @@ class StreamedOperator(BaseOperator): _operator_class: type[BaseOperator] expand_input: ExpandInput partial_kwargs: dict[str, Any] - shallow_copy_attrs: Sequence[str] = ("expand_input", "partial_kwargs", "_log", "_semaphore") + # each operator should override this class attr for shallow copy attrs. 
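
Context for the shallow_copy_attrs declaration that follows: Airflow deep-copies operators, for example when a DAG is copied, and attributes such as an ExpandInput that may hold a Context, a logger, or a semaphore either cannot or should not be deep-copied. A rough sketch of the mechanism, assuming the same approach as BaseOperator.__deepcopy__ rather than quoting it:

    import copy

    class ShallowCopyDemo:
        shallow_copy_attrs = ("expand_input", "partial_kwargs", "_log")

        def __deepcopy__(self, memo):
            cls = self.__class__
            result = cls.__new__(cls)
            memo[id(self)] = result
            for key, value in self.__dict__.items():
                if key in self.shallow_copy_attrs:
                    result.__dict__[key] = value  # copied by reference, not duplicated
                else:
                    result.__dict__[key] = copy.deepcopy(value, memo)
            return result

Listing "expand_input" this way is what works around the "TypeError: cannot pickle 'Context' object" mentioned in PATCH 42.
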
+ shallow_copy_attrs: Sequence[str] = ( + "expand_input", + "partial_kwargs", + "_log", + "_semaphore", + ) def __init__( self, @@ -198,7 +209,6 @@ def __init__( self._mapped_kwargs: list[dict] = [] if not self.max_active_tis_per_dag: self.max_active_tis_per_dag = os.cpu_count() or 1 - self._semaphore = Semaphore(self.max_active_tis_per_dag) XComArg.apply_upstream_relationship(self, self.expand_input.value) @property @@ -212,7 +222,7 @@ def _unmap_operator(self, index): self.log.debug("index: %s", index) kwargs = { **self.partial_kwargs, - **{"task_id": f"{self.task_id}_{index}"}, + **{"task_id": f"{self.partial_kwargs.get('task_id')}_{index}"}, **self._mapped_kwargs[index], } self.log.debug("kwargs: %s", kwargs) @@ -244,23 +254,67 @@ def render_template_fields( session = get_current_task_instance_session() self._resolve_expand_input(context=context, session=session) - def _run_futures(self, context: Context, futures, results: list[Any] | None = None) -> list[Any]: + @classmethod + def run_until_complete(cls, func, *args, **kwargs): + try: + loop = asyncio.get_event_loop() + return loop.run_until_complete(func(*args, **kwargs)) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop.run_until_complete(func(*args, **kwargs)) + + async def _run_tasks( + self, context: Context, tasks, results: list[Any] | None = None + ) -> list[Any]: reschedule_date = timezone.utcnow() results = results or [] - failed_futures = [] + failed_tasks = [] with event_loop() as loop: - for result in loop.run_until_complete(asyncio.gather(*futures, return_exceptions=True)): - if isinstance(result, Exception): - if not isinstance(result, AirflowRescheduleTaskInstanceException): - raise result - reschedule_date = result.reschedule_date - failed_futures.append(ensure_future(self._run_task(context, result.task_instance))) - else: - results.append(result) + with ThreadPoolExecutor( + max_workers=self.max_active_tis_per_dag + ) as executor: + futures = [ + ( + loop.run_in_executor( + executor, + self.run_until_complete, + self._run_operator, + context, + task, + ), + task, + ) + for task in tasks + ] + + for future, task in futures: + try: + # Apply timeout for each task + result = await wait_for( + future, timeout=self.execution_timeout.total_seconds() + ) + self.log.info("result: %s", result) + results.append(result) + except asyncio.TimeoutError: + self.log.warning( + "Task timed out after %s seconds: %s", + self.execution_timeout.total_seconds(), + task.task_id, + ) + if task.next_try_number > self.retries: + raise e + failed_tasks.append(task) + except AirflowRescheduleTaskInstanceException as e: + reschedule_date = e.reschedule_date + failed_tasks.append(e.task_instance) + except Exception as e: + self.log.exception("Unexpected error: %s", e) + raise e - if not failed_futures: - return list(filter(None, results)) + if not failed_tasks: + return list(filter(None, results)) # session = get_current_task_instance_session() # TaskInstance._set_state(context["ti"], TaskInstanceState.UP_FOR_RETRY, session) @@ -274,21 +328,20 @@ def _run_futures(self, context: Context, futures, results: list[Any] | None = No if delay_seconds > 0: self.log.info( "Attempting to run %s failed tasks within %s seconds...", - len(failed_futures), + len(failed_tasks), delay_seconds, ) - sleep(delay_seconds) + await sleep(delay_seconds) # TaskInstance._set_state(context["ti"], TaskInstanceState.RUNNING, session) - return self._run_futures(context, failed_futures, results) + return await 
self._run_tasks(context, failed_tasks, results) - async def _run_task(self, context: Context, task_instance: TaskInstance): + async def _run_operator(self, context: Context, task_instance: TaskInstance): operator: BaseOperator = cast(BaseOperator, task_instance.task) self.log.debug("operator: %s", operator) result = await OperatorExecutor( - semaphore=self._semaphore, operator=operator, context=context, task_instance=task_instance, @@ -298,7 +351,7 @@ async def _run_task(self, context: Context, task_instance: TaskInstance): if operator.do_xcom_push: return result - def _create_future(self, context: Context, index: int): + def _create_task(self, context: Context, index: int): operator = self._unmap_operator(index) operator.render_template_fields(context=context) task_instance = TaskInstance( @@ -307,19 +360,24 @@ def _create_future(self, context: Context, index: int): state=context["ti"].state, map_index=index, ) - return asyncio.ensure_future(self._run_task(context, task_instance)) + return task_instance def execute(self, context: Context): self.log.info( - "Executing %s mapped tasks on %s with %s workers", + "Executing %s mapped tasks on %s with %s threads and timeout of %s", len(self._mapped_kwargs), self._operator_class.__name__, self.max_active_tis_per_dag, + self.execution_timeout.total_seconds(), ) - return self._run_futures( - context=context, - futures=[ - self._create_future(context, index) for index, mapped_kwargs in enumerate(self._mapped_kwargs) - ], - ) + with event_loop() as loop: + loop.run_until_complete( + self._run_tasks( + context=context, + tasks=[ + self._create_task(context, index) + for index, _ in enumerate(self._mapped_kwargs) + ], + ) + ) From 1e10f27cc5b12d3c9966fe507fe66e2a95d27ef7 Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 25 Nov 2024 08:44:12 +0100 Subject: [PATCH 49/97] refactor: Reorganized imports in StreamedOperator --- airflow/models/streamedoperator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index 203f0603f197d..4d58600def4fc 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -20,7 +20,7 @@ import asyncio import logging import os -from asyncio import AbstractEventLoop, Semaphore, ensure_future, iscoroutinefunction, wait_for +from asyncio import AbstractEventLoop, TimeoutError, iscoroutinefunction, wait_for from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager, suppress from math import ceil @@ -295,9 +295,9 @@ async def _run_tasks( result = await wait_for( future, timeout=self.execution_timeout.total_seconds() ) - self.log.info("result: %s", result) + self.log.debug("result: %s", result) results.append(result) - except asyncio.TimeoutError: + except TimeoutError: self.log.warning( "Task timed out after %s seconds: %s", self.execution_timeout.total_seconds(), From 11200e40d8d91933d9ad826e8de07abe97d7d9ea Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 28 Nov 2024 16:56:44 +0100 Subject: [PATCH 50/97] refactor: Refactored StreamedOperator using ThreadPool so we can also support timeouts --- airflow/models/streamedoperator.py | 178 +++++++++++++---------------- 1 file changed, 80 insertions(+), 98 deletions(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index 4d58600def4fc..77d15468a3355 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -20,19 +20,21 @@ import asyncio import logging 
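
Before the diff continues: the ThreadPool idiom this patch adopts is worth seeing in isolation. A hedged sketch (the run_all helper is hypothetical, not part of the patch) of pool.apply_async combined with AsyncResult.get(timeout=...), whose main caveat is that a timed-out worker thread is abandoned rather than killed:

    from multiprocessing import TimeoutError
    from multiprocessing.pool import ThreadPool

    def run_all(tasks, max_workers=4, timeout=60.0):
        results, timed_out = [], []
        with ThreadPool(processes=max_workers) as pool:
            futures = [(task, pool.apply_async(task)) for task in tasks]
            for task, future in futures:
                try:
                    # get() bounds the wait for the result; the worker thread
                    # itself keeps running if the timeout fires.
                    results.append(future.get(timeout=timeout))
                except TimeoutError:
                    timed_out.append(task)  # candidates for retry or reschedule
        return results, timed_out

For example, run_all([lambda: 1, lambda: 2]) returns ([1, 2], []), while a task that sleeps past the timeout lands in the second list instead.
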
import os -from asyncio import AbstractEventLoop, TimeoutError, iscoroutinefunction, wait_for -from concurrent.futures import ThreadPoolExecutor -from contextlib import contextmanager, suppress +from asyncio import TimeoutError, iscoroutinefunction, wait_for +from datetime import timedelta from math import ceil +from multiprocessing.pool import ThreadPool from time import sleep -from typing import TYPE_CHECKING, Any, Callable, Generator, Sequence, cast +from typing import TYPE_CHECKING, Any, Callable, Sequence, cast, Iterable from airflow import XComArg from airflow.exceptions import ( AirflowException, + AirflowTaskTimeout, AirflowRescheduleTaskInstanceException, TaskDeferred, ) +from airflow.models.abstractoperator import DEFAULT_TASK_EXECUTION_TIMEOUT from airflow.models.baseoperator import BaseOperator from airflow.models.expandinput import ( ExpandInput, @@ -52,24 +54,6 @@ from sqlalchemy.orm import Session -@contextmanager -def event_loop() -> Generator[AbstractEventLoop, None, None]: - new_event_loop = False - loop = None - try: - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - new_event_loop = True - yield loop - finally: - if new_event_loop and loop is not None: - with suppress(AttributeError): - loop.close() - - class OperatorExecutor(LoggingMixin): """ Run an operator with given task context and task instance. @@ -86,11 +70,13 @@ def __init__( operator: BaseOperator, context: Context, task_instance: TaskInstance, + timeout, ): super().__init__() self.operator = operator self.__context = context self._task_instance = task_instance + self.timeout = timeout @property def task_instance(self) -> TaskInstance: @@ -155,18 +141,26 @@ async def run(self, method: Callable, *args, **kwargs): self.task_instance.map_index, ) + if self.task_instance.try_number == 0: + self.operator.render_template_fields(context=self.context) self.operator.pre_execute(context=self.context) self.task_instance._run_execute_callback( context=self.context, task=self.operator ) try: - return await self._run_callable( - method, *(list(args or ()) + [self.context]), **kwargs + return await wait_for( + self._run_callable( + method, *(list(args or ()) + [self.context]), **kwargs + ), + timeout=self.timeout, ) except TaskDeferred as task_deferred: - return await self._run_callable( - self._run_deferrable, *[self.context, task_deferred] + return await wait_for( + self._run_callable( + self._run_deferrable, *[self.context, task_deferred] + ), + timeout=self.timeout, ) finally: self.operator.post_execute(context=self.context) @@ -200,12 +194,16 @@ def __init__( operator_class: type[BaseOperator], expand_input: ExpandInput, partial_kwargs: dict[str, Any] | None = None, + timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT, **kwargs: Any, ): super().__init__(**kwargs) self._operator_class = operator_class self.expand_input = expand_input self.partial_kwargs = partial_kwargs or {} + self.timeout = ( + timeout.total_seconds() if isinstance(timeout, timedelta) else timeout + ) self._mapped_kwargs: list[dict] = [] if not self.max_active_tis_per_dag: self.max_active_tis_per_dag = os.cpu_count() or 1 @@ -222,7 +220,7 @@ def _unmap_operator(self, index): self.log.debug("index: %s", index) kwargs = { **self.partial_kwargs, - **{"task_id": f"{self.partial_kwargs.get('task_id')}_{index}"}, + **{"task_id": f"{self.task_id}_{index}"}, **self._mapped_kwargs[index], } self.log.debug("kwargs: %s", kwargs) @@ -254,66 +252,51 @@ def 
render_template_fields( session = get_current_task_instance_session() self._resolve_expand_input(context=context, session=session) - @classmethod - def run_until_complete(cls, func, *args, **kwargs): - try: - loop = asyncio.get_event_loop() - return loop.run_until_complete(func(*args, **kwargs)) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - return loop.run_until_complete(func(*args, **kwargs)) - - async def _run_tasks( - self, context: Context, tasks, results: list[Any] | None = None + def _run_task(self, context: Context, task: TaskInstance): + loop = ( + asyncio.new_event_loop() + ) # Always open new event loop as this is executed in multithreaded + asyncio.set_event_loop(loop) + return loop.run_until_complete(self._run_operator(context, task)) + + def _run_tasks( + self, + context: Context, + tasks: Iterable[TaskInstance], + results: list[Any] | None = None, ) -> list[Any]: - reschedule_date = timezone.utcnow() + exception: BaseException | None = None results = results or [] - failed_tasks = [] - - with event_loop() as loop: - with ThreadPoolExecutor( - max_workers=self.max_active_tis_per_dag - ) as executor: - futures = [ - ( - loop.run_in_executor( - executor, - self.run_until_complete, - self._run_operator, - context, - task, - ), - task, - ) - for task in tasks - ] - - for future, task in futures: - try: - # Apply timeout for each task - result = await wait_for( - future, timeout=self.execution_timeout.total_seconds() - ) - self.log.debug("result: %s", result) - results.append(result) - except TimeoutError: - self.log.warning( - "Task timed out after %s seconds: %s", - self.execution_timeout.total_seconds(), - task.task_id, - ) - if task.next_try_number > self.retries: - raise e + reschedule_date = timezone.utcnow() + failed_tasks: list[TaskInstance] = [] + + with ThreadPool(processes=self.max_active_tis_per_dag) as pool: + futures = [ + (task, pool.apply_async(self._run_task, (context, task))) + for task in tasks + ] + + for task, future in futures: + try: + result = future.get(timeout=self.timeout) + results.append(result) + except TimeoutError as e: + self.log.warning("A timeout occurred for task_id %s", task.task_id) + if task.next_try_number > self.retries: + exception = AirflowTaskTimeout(e) + else: + reschedule_date = task.next_retry_datetime() failed_tasks.append(task) - except AirflowRescheduleTaskInstanceException as e: - reschedule_date = e.reschedule_date - failed_tasks.append(e.task_instance) - except Exception as e: - self.log.exception("Unexpected error: %s", e) - raise e + except AirflowRescheduleTaskInstanceException as e: + reschedule_date = e.reschedule_date + failed_tasks.append(e.task) + except AirflowException as e: + self.log.error("An exception occurred for task_id %s", task.task_id) + exception = e if not failed_tasks: + if exception: + raise exception return list(filter(None, results)) # session = get_current_task_instance_session() @@ -332,11 +315,11 @@ async def _run_tasks( delay_seconds, ) - await sleep(delay_seconds) + sleep(delay_seconds) # TaskInstance._set_state(context["ti"], TaskInstanceState.RUNNING, session) - return await self._run_tasks(context, failed_tasks, results) + return self._run_tasks(context, failed_tasks, results) async def _run_operator(self, context: Context, task_instance: TaskInstance): operator: BaseOperator = cast(BaseOperator, task_instance.task) @@ -351,9 +334,8 @@ async def _run_operator(self, context: Context, task_instance: TaskInstance): if operator.do_xcom_push: return result - 
def _create_task(self, context: Context, index: int) -> TaskInstance: operator = self._unmap_operator(index) task_instance = TaskInstance( task=operator, run_id=context["ti"].run_id, @@ -364,20 +346,20 @@ def _create_task(self, context: Context, index: int): def execute(self, context: Context): self.log.info( - "Executing %s mapped tasks on %s with %s threads and timeout of %s", + "Executing %s mapped tasks on %s with %s threads and timeout %s", len(self._mapped_kwargs), self._operator_class.__name__, self.max_active_tis_per_dag, - self.execution_timeout.total_seconds(), + self.timeout, ) - with event_loop() as loop: - loop.run_until_complete( - self._run_tasks( - context=context, - tasks=[ - self._create_task(context, index) - for index, _ in enumerate(self._mapped_kwargs) - ], - ) - ) + results = self._run_tasks( + context=context, + tasks=map( + lambda index: self._create_task(context, index[0]), + enumerate(self._mapped_kwargs), + ), + ) + + if self.do_xcom_push: + return results From 41381f9c31c02edb26ac8fd8a9abd92dc69134ef Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 28 Nov 2024 18:45:24 +0100 Subject: [PATCH 51/97] fix: Don't set the newly created event loop, just use it locally and close it after execution --- airflow/models/streamedoperator.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index 77d15468a3355..983ed2e8983e0 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -253,11 +253,12 @@ def render_template_fields( self._resolve_expand_input(context=context, session=session) def _run_task(self, context: Context, task: TaskInstance): - loop = ( - asyncio.new_event_loop() - ) # Always open new event loop as this is executed in multithreaded - asyncio.set_event_loop(loop) - return loop.run_until_complete(self._run_operator(context, task)) + # Always open new event loop as this is executed in multithreaded + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(self._run_operator(context, task)) + finally: + loop.close() def _run_tasks( self, From c3c06f71b66a095ee5ad952758bb6fb9eb69bbb5 Mon Sep 17 00:00:00 2001 From: David Blain Date: Sat, 30 Nov 2024 17:41:53 +0100 Subject: [PATCH 52/97] refactor: Reworked StreamedOperator retry handling to raise a reschedule exception instead of looping, and added a module-level stream helper --- airflow/models/streamedoperator.py | 117 +++++++++++++++++++++-------- 1 file changed, 86 insertions(+), 31 deletions(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index 983ed2e8983e0..535882ce75fb0 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -20,7 +20,8 @@ import asyncio import logging import os -from asyncio import TimeoutError, iscoroutinefunction, wait_for +from multiprocessing import TimeoutError +from asyncio import iscoroutinefunction, wait_for, ensure_future from datetime import timedelta from math import ceil from multiprocessing.pool import ThreadPool from time import sleep from typing import TYPE_CHECKING, Any, Callable, Sequence, cast, Iterable from airflow import XComArg from airflow.exceptions import ( AirflowException, AirflowTaskTimeout, - AirflowRescheduleTaskInstanceException, - TaskDeferred, + TaskDeferred, AirflowRescheduleException, ) from airflow.models.abstractoperator import DEFAULT_TASK_EXECUTION_TIMEOUT from airflow.models.baseoperator import BaseOperator from airflow.models.expandinput import ( ExpandInput, _needs_run_time_resolution, -
is_mappable, OperatorExpandArgument, DictOfListsExpandInput, ) +from airflow.models.mappedoperator import validate_mapping_kwargs, ensure_xcomarg_return_value from airflow.models.taskinstance import TaskInstance -from airflow.triggers.base import run_trigger +from airflow.triggers.base import BaseTrigger, TriggerEvent from airflow.utils import timezone from airflow.utils.context import Context, context_get_outlet_events +from airflow.utils.helpers import prevent_duplicates from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.operator_helpers import ExecutionCallableRunner from airflow.utils.task_instance_session import get_current_task_instance_session @@ -54,6 +56,24 @@ from sqlalchemy.orm import Session +# TODO: Check _run_inline_trigger method from DAG, could be refactored so it uses this method +async def run_trigger(trigger: BaseTrigger) -> TriggerEvent | None: + async for event in trigger.run(): + return event + + +class AirflowRescheduleTaskInstanceException(AirflowRescheduleException): + """ + Raise when the task should be re-scheduled for a specific TaskInstance at a later time. + + :param task_instance: The task instance that should be rescheduled + """ + + def __init__(self, task: TaskInstance): + super().__init__(reschedule_date=task.next_retry_datetime()) + self.task = task + + class OperatorExecutor(LoggingMixin): """ Run an operator with given task context and task instance. @@ -89,31 +109,30 @@ def context(self) -> Context: return {**self.__context, **{"ti": self.task_instance}} async def _run_callable(self, method: Callable, *args, **kwargs): - while self.task_instance.try_number <= self.operator.retries: - try: - outlet_events = context_get_outlet_events(self.context) - callable_runner = ExecutionCallableRunner( - func=method, outlet_events=outlet_events, logger=self.log - ) - if iscoroutinefunction(method): - return await callable_runner.run(*args, **kwargs) - return callable_runner.run(*args, **kwargs) - except AirflowException as e: - if self.task_instance.next_try_number > self.operator.retries: - self.log.error( - "Max number of attempts for %s with map_index %s failed due to: %s", - type(self.operator).__name__, - self.task_instance.map_index, - e, - ) - raise e - - self.task_instance.try_number += 1 - self.task_instance.end_date = timezone.utcnow() - - raise AirflowRescheduleTaskInstanceException( - task_instance=self.task_instance + try: + outlet_events = context_get_outlet_events(self.context) + callable_runner = ExecutionCallableRunner( + func=method, outlet_events=outlet_events, logger=self.log + ) + if iscoroutinefunction(method): + return await callable_runner.run(*args, **kwargs) + return callable_runner.run(*args, **kwargs) + except AirflowException as e: + if self.task_instance.next_try_number > self.operator.retries: + self.log.error( + "Max number of attempts for %s with map_index %s failed due to: %s", + type(self.operator).__name__, + self.task_instance.map_index, + e, ) + raise e + + self.task_instance.try_number += 1 + self.task_instance.end_date = timezone.utcnow() + + raise AirflowRescheduleTaskInstanceException( + task=self.task_instance + ) async def _run_deferrable(self, context: Context, task_deferred: TaskDeferred): event = await run_trigger(task_deferred.trigger) @@ -256,7 +275,7 @@ def _run_task(self, context: Context, task: TaskInstance): # Always open new event loop as this is executed in multithreaded loop = asyncio.new_event_loop() try: - return loop.run_until_complete(self._run_operator(context, task)) + return 
loop.run_until_complete(ensure_future(self._run_operator(context, task), loop=loop)) finally: loop.close() @@ -329,7 +348,12 @@ async def _run_operator(self, context: Context, task_instance: TaskInstance): operator=operator, context=context, task_instance=task_instance, - ).run(operator.execute) + timeout=self.timeout, + ).run( + lambda _context: operator.execute.__wrapped__( + self=operator, context=_context + ) + ) # TODO: change back to operator.execute once ExecutorSafeguard is fixed self.log.debug("result: %s", result) self.log.debug("do_xcom_push: %s", operator.do_xcom_push) if operator.do_xcom_push: @@ -364,3 +388,34 @@ def execute(self, context: Context): if self.do_xcom_push: return results + + +def stream(self, **mapped_kwargs: OperatorExpandArgument) -> StreamedOperator: + if not mapped_kwargs: + raise TypeError("no arguments to expand against") + validate_mapping_kwargs(self.operator_class, "stream", mapped_kwargs) + prevent_duplicates( + self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified" + ) + + expand_input = DictOfListsExpandInput(mapped_kwargs) + ensure_xcomarg_return_value(expand_input.value) + + kwargs = {} + + for parameter_name in BaseOperator._comps: + parameter_value = self.kwargs.get(parameter_name) + if parameter_value: + kwargs[parameter_name] = parameter_value + + # We don't retry the whole stream operator, we retry the individual tasks + kwargs["retries"] = 0 + # We don't want to time out the whole stream operator, we only time out the individual tasks + kwargs["timeout"] = kwargs.pop("execution_timeout", None) + + return StreamedOperator( + **kwargs, + operator_class=self.operator_class, + expand_input=expand_input, + partial_kwargs=self.kwargs, + ) From 493972ad162d9d63c77461350c9dce45f76f5c5e Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 2 Dec 2024 16:33:32 +0100 Subject: [PATCH 53/97] feature: Fix StreamedOperator so it can run triggers in async mode and normal --- airflow/exceptions.py | 6 +- airflow/models/mappedoperator.py | 38 ++-- airflow/models/streamedoperator.py | 292 ++++++++++++++--------------- 3 files changed, 159 insertions(+), 177 deletions(-) diff --git a/airflow/exceptions.py b/airflow/exceptions.py index e9adb297af258..12aa0edec5a74 100644 --- a/airflow/exceptions.py +++ b/airflow/exceptions.py @@ -91,9 +91,9 @@ class AirflowRescheduleTaskInstanceException(AirflowRescheduleException): :param task_instance: The task instance that should be rescheduled """ - def __init__(self, task_instance: TaskInstance): - super().__init__(reschedule_date=task_instance.next_retry_datetime()) - self.task_instance = task_instance + def __init__(self, task: TaskInstance): + super().__init__(reschedule_date=task.next_retry_datetime()) + self.task = task class InvalidStatsNameException(AirflowException): diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 3a59d0350dd2f..b03720650cdc1 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -26,7 +26,6 @@ import attr import methodtools - from airflow.exceptions import UnmappableOperator from airflow.models.abstractoperator import ( DEFAULT_EXECUTOR, @@ -41,8 +40,7 @@ DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, DEFAULT_WEIGHT_RULE, AbstractOperator, - NotMapped, DEFAULT_TASK_EXECUTION_TIMEOUT, -) + NotMapped, ) from airflow.models.expandinput import ( DictOfListsExpandInput, ListOfDictsExpandInput, @@ -246,32 +244,30 @@ def stream(self, **mapped_kwargs: OperatorExpandArgument) -> StreamedOperator: if not 
mapped_kwargs: + raise TypeError("no arguments to expand against") + validate_mapping_kwargs(self.operator_class, "stream", mapped_kwargs) + prevent_duplicates( + self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified" + ) + + expand_input = DictOfListsExpandInput(mapped_kwargs) + ensure_xcomarg_return_value(expand_input.value) + + kwargs = {} + + for parameter_name in BaseOperator._comps: + parameter_value = self.kwargs.get(parameter_name) + if parameter_value: + kwargs[parameter_name] = parameter_value + + # We don't retry the whole stream operator, we retry the individual tasks + kwargs["retries"] = 0 + # We don't want to time out the whole stream operator, we only time out the individual tasks + kwargs["timeout"] = kwargs.pop("execution_timeout", None) + + return StreamedOperator( + **kwargs, + operator_class=self.operator_class, + expand_input=expand_input, + partial_kwargs=self.kwargs, + ) From 493972ad162d9d63c77461350c9dce45f76f5c5e Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 2 Dec 2024 16:33:32 +0100 Subject: [PATCH 53/97] feature: Fix StreamedOperator so it can run triggers in both async and normal mode --- airflow/exceptions.py | 6 +- airflow/models/mappedoperator.py | 38 ++-- airflow/models/streamedoperator.py | 292 ++++++++++++++--------------- 3 files changed, 159 insertions(+), 177 deletions(-) diff --git a/airflow/exceptions.py b/airflow/exceptions.py index e9adb297af258..12aa0edec5a74 100644 --- a/airflow/exceptions.py +++ b/airflow/exceptions.py @@ -91,9 +91,9 @@ class AirflowRescheduleTaskInstanceException(AirflowRescheduleException): :param task_instance: The task instance that should be rescheduled """ - def __init__(self, task_instance: TaskInstance): - super().__init__(reschedule_date=task_instance.next_retry_datetime()) - self.task_instance = task_instance + def __init__(self, task: TaskInstance): + super().__init__(reschedule_date=task.next_retry_datetime()) + self.task = task class InvalidStatsNameException(AirflowException): diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 3a59d0350dd2f..b03720650cdc1 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -26,7 +26,6 @@ import attr import methodtools - from airflow.exceptions import UnmappableOperator from airflow.models.abstractoperator import ( DEFAULT_EXECUTOR, @@ -41,8 +40,7 @@ DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, DEFAULT_WEIGHT_RULE, AbstractOperator, - NotMapped, DEFAULT_TASK_EXECUTION_TIMEOUT, -) + NotMapped, ) from airflow.models.expandinput import ( DictOfListsExpandInput, ListOfDictsExpandInput, @@ -246,32 +244,30 @@ def stream(self, **mapped_kwargs: OperatorExpandArgument) -> StreamedOperator: if not
airflow.utils.task_instance_session import get_current_task_instance_session @@ -56,47 +57,33 @@ from sqlalchemy.orm import Session -# TODO: Check _run_inline_trigger method from DAG, could be refactored so it uses this method -async def run_trigger(trigger: BaseTrigger) -> TriggerEvent | None: - async for event in trigger.run(): - return event - - -class AirflowRescheduleTaskInstanceException(AirflowRescheduleException): - """ - Raise when the task should be re-scheduled for a specific TaskInstance at a later time. - - :param task_instance: The task instance that should be rescheduled - """ - - def __init__(self, task: TaskInstance): - super().__init__(reschedule_date=task.next_retry_datetime()) - self.task = task - - -class OperatorExecutor(LoggingMixin): - """ - Run an operator with given task context and task instance. - - If the execute function raises a TaskDeferred exception, then the trigger instance within the - TaskDeferred exception will be executed with the given context and task instance. The operator - or trigger will always be executed in an async way. - - :meta private: - """ - +@contextmanager +def event_loop() -> Generator[AbstractEventLoop, None, None]: + new_event_loop = False + loop = None + try: + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + new_event_loop = True + yield loop + finally: + if new_event_loop and loop is not None: + with suppress(AttributeError): + loop.close() + + +class TaskExecutor(LoggingMixin): def __init__( self, - operator: BaseOperator, context: Context, task_instance: TaskInstance, - timeout, ): super().__init__() - self.operator = operator self.__context = context self._task_instance = task_instance - self.timeout = timeout @property def task_instance(self) -> TaskInstance: @@ -108,15 +95,67 @@ def task_instance(self) -> TaskInstance: def context(self) -> Context: return {**self.__context, **{"ti": self.task_instance}} - async def _run_callable(self, method: Callable, *args, **kwargs): + @property + def operator(self) -> BaseOperator: + return self.task_instance.task + + @abstractmethod + def run(self, *args, **kwargs): + raise NotImplementedError() + + def _handle_result(self, result): + """ + Common logic to handle result and post-execution tasks. + """ + self.operator.post_execute(context=self.context) + if self.log.isEnabledFor(logging.INFO): + self.log.info( + "Task instance %s for %s finished successfully in %s attempts.", + self.task_instance.map_index, + type(self.operator).__name__, + self.task_instance.next_try_number, + ) + if self.operator.do_xcom_push: + return result + return None + + +class OperatorExecutor(TaskExecutor): + """ + Run an operator with given task context and task instance. + + If the execute function raises a TaskDeferred exception, then the trigger instance within the + TaskDeferred exception will be executed with the given context and task instance. The operator + or trigger will always be executed in an async way. 
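
A compact illustration of the defer-and-resume flow described here, using a hypothetical run_inline helper rather than the executor classes in this patch: execute() is attempted, a TaskDeferred is caught, the trigger is drained for its first event, and the operator resumes at the method named by the deferral.

    from airflow.exceptions import TaskDeferred

    async def run_inline(operator, context):
        try:
            return operator.execute(context)
        except TaskDeferred as deferred:
            event = None
            async for ev in deferred.trigger.run():  # wait for the first TriggerEvent
                event = ev
                break
            next_method = getattr(operator, deferred.method_name)
            return next_method(context, event.payload)

The TriggerExecutor below additionally recurses, because the resumed method may itself defer again.
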
+ + :meta private: + """ + + def run(self): + if self.log.isEnabledFor(logging.INFO): + self.log.info( + "Attempting running task %s of %s for %s with map_index %s.", + self.task_instance.try_number, + self.operator.retries, + type(self.operator).__name__, + self.task_instance.map_index, + ) + + if self.task_instance.try_number == 0: + self.operator.render_template_fields(context=self.context) + self.operator.pre_execute(context=self.context) + self.task_instance._run_execute_callback( + context=self.context, task=self.operator + ) try: outlet_events = context_get_outlet_events(self.context) callable_runner = ExecutionCallableRunner( - func=method, outlet_events=outlet_events, logger=self.log + func=self.operator.execute, + outlet_events=outlet_events, + logger=self.log, ) - if iscoroutinefunction(method): - return await callable_runner.run(*args, **kwargs) - return callable_runner.run(*args, **kwargs) + result = callable_runner.run(self.operator, self.context) + return self._handle_result(result) except AirflowException as e: if self.task_instance.next_try_number > self.operator.retries: self.log.error( @@ -134,7 +173,9 @@ async def _run_callable(self, method: Callable, *args, **kwargs): task=self.task_instance ) - async def _run_deferrable(self, context: Context, task_deferred: TaskDeferred): + +class TriggerExecutor(TaskExecutor): + async def run(self, task_deferred: TaskDeferred): event = await run_trigger(task_deferred.trigger) self.log.debug("event: %s", event) @@ -143,54 +184,14 @@ async def _run_deferrable(self, context: Context, task_deferred: TaskDeferred): self.log.debug("next_method: %s", task_deferred.method_name) if task_deferred.method_name: - next_method = BaseOperator.next_callable( - self.operator, task_deferred.method_name, task_deferred.kwargs - ) - result = next_method(context, event.payload) - self.log.debug("result: %s", result) - return result - - async def run(self, method: Callable, *args, **kwargs): - if self.log.isEnabledFor(logging.INFO): - self.log.info( - "Attempting running task %s of %s for %s with map_index %s.", - self.task_instance.try_number, - self.operator.retries, - type(self.operator).__name__, - self.task_instance.map_index, - ) - - if self.task_instance.try_number == 0: - self.operator.render_template_fields(context=self.context) - self.operator.pre_execute(context=self.context) - self.task_instance._run_execute_callback( - context=self.context, task=self.operator - ) - - try: - return await wait_for( - self._run_callable( - method, *(list(args or ()) + [self.context]), **kwargs - ), - timeout=self.timeout, - ) - except TaskDeferred as task_deferred: - return await wait_for( - self._run_callable( - self._run_deferrable, *[self.context, task_deferred] - ), - timeout=self.timeout, - ) - finally: - self.operator.post_execute(context=self.context) - - if self.log.isEnabledFor(logging.INFO): - self.log.info( - "Task instance %s for %s finished successfully in %s attempts.", - self.task_instance.map_index, - type(self.operator).__name__, - self.task_instance.next_try_number, - ) + try: + next_method = BaseOperator.next_callable( + self.operator, task_deferred.method_name, task_deferred.kwargs + ) + result = next_method(self.context, event.payload) + return self._handle_result(result) + except TaskDeferred as task_deferred: + return await self.run(task_deferred=task_deferred) class StreamedOperator(BaseOperator): @@ -204,7 +205,6 @@ class StreamedOperator(BaseOperator): "expand_input", "partial_kwargs", "_log", - "_semaphore", ) def __init__( @@ -271,14 
+271,6 @@ def render_template_fields( session = get_current_task_instance_session() self._resolve_expand_input(context=context, session=session) - def _run_task(self, context: Context, task: TaskInstance): - # Always open new event loop as this is executed in multithreaded - loop = asyncio.new_event_loop() - try: - return loop.run_until_complete(ensure_future(self._run_operator(context, task), loop=loop)) - finally: - loop.close() - def _run_tasks( self, context: Context, @@ -288,18 +280,26 @@ def _run_tasks( exception: BaseException | None = None results = results or [] reschedule_date = timezone.utcnow() + deferred_tasks: list[Future] = [] failed_tasks: list[TaskInstance] = [] with ThreadPool(processes=self.max_active_tis_per_dag) as pool: futures = [ - (task, pool.apply_async(self._run_task, (context, task))) + (task, pool.apply_async(self._run_operator, (context, task))) for task in tasks ] for task, future in futures: try: result = future.get(timeout=self.timeout) - results.append(result) + if isinstance(result, TaskDeferred): + deferred_tasks.append( + ensure_future( + self._run_deferrable(context=context, task=task, task_deferred=result) + ) + ) + else: + results.append(result) except TimeoutError as e: self.log.warning("A timeout occurred for task_id %s", task.task_id) if task.next_try_number > self.retries: @@ -314,6 +314,23 @@ def _run_tasks( self.log.error("An exception occurred for task_id %s", task.task_id) exception = e + if deferred_tasks: + self.log.info("Running %s deferred tasks", len(deferred_tasks)) + + with event_loop() as loop: + for result in loop.run_until_complete( + asyncio.gather(*deferred_tasks, return_exceptions=True) + ): + if isinstance(result, Exception): + if not isinstance(result, AirflowRescheduleTaskInstanceException): + exception = result + reschedule_date = result.reschedule_date + failed_tasks.append(result.task) + else: + results.append(result) + + deferred_tasks.clear() + if not failed_tasks: if exception: raise exception @@ -341,23 +358,22 @@ def _run_tasks( return self._run_tasks(context, failed_tasks, results) - async def _run_operator(self, context: Context, task_instance: TaskInstance): - operator: BaseOperator = cast(BaseOperator, task_instance.task) - self.log.debug("operator: %s", operator) - result = await OperatorExecutor( - operator=operator, + @classmethod + def _run_operator(cls, context: Context, task_instance: TaskInstance): + try: + return OperatorExecutor( + context=context, + task_instance=task_instance, + ).run() + except TaskDeferred as task_deferred: + return task_deferred + + @classmethod + async def _run_deferrable(cls, context: Context, task: TaskInstance, task_deferred: TaskDeferred): + return await TriggerExecutor( context=context, - task_instance=task_instance, - timeout=self.timeout, - ).run( - lambda _context: operator.execute.__wrapped__( - self=operator, context=_context - ) - ) # TODO: change back to operator.execute once ExecutorSafeguard is fixed - self.log.debug("result: %s", result) - self.log.debug("do_xcom_push: %s", operator.do_xcom_push) - if operator.do_xcom_push: - return result + task_instance=task, + ).run(task_deferred) def _create_task(self, context: Context, index: int) -> TaskInstance: operator = self._unmap_operator(index) @@ -388,34 +404,4 @@ def execute(self, context: Context): if self.do_xcom_push: return results - - -def stream(self, **mapped_kwargs: OperatorExpandArgument) -> StreamedOperator: - if not mapped_kwargs: - raise TypeError("no arguments to expand against") - 
validate_mapping_kwargs(self.operator_class, "stream", mapped_kwargs) - prevent_duplicates( - self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified" - ) - - expand_input = DictOfListsExpandInput(mapped_kwargs) - ensure_xcomarg_return_value(expand_input.value) - - kwargs = {} - - for parameter_name in BaseOperator._comps: - parameter_value = self.kwargs.get(parameter_name) - if parameter_value: - kwargs[parameter_name] = parameter_value - - # We don't retry the whole stream operator, we retry the individual tasks - kwargs["retries"] = 0 - # We don't want to time out the whole stream operator, we only time out the individual tasks - kwargs["timeout"] = kwargs.pop("execution_timeout", None) - - return StreamedOperator( - **kwargs, - operator_class=self.operator_class, - expand_input=expand_input, - partial_kwargs=self.kwargs, - ) + return None From 5ddb9500a4b1753c1db9df387e543930f29eb61a Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 2 Dec 2024 17:23:55 +0100 Subject: [PATCH 54/97] feature: Re-added semaphore for async execution --- airflow/models/mappedoperator.py | 8 ++--- airflow/models/streamedoperator.py | 54 +++++++++++++++--------------- 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index b03720650cdc1..4ba01b840e1f6 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -26,6 +26,7 @@ import attr import methodtools + from airflow.exceptions import UnmappableOperator from airflow.models.abstractoperator import ( DEFAULT_EXECUTOR, @@ -40,7 +41,8 @@ DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING, DEFAULT_WEIGHT_RULE, AbstractOperator, - NotMapped, ) + NotMapped, +) from airflow.models.expandinput import ( DictOfListsExpandInput, ListOfDictsExpandInput, @@ -244,9 +246,7 @@ def stream(self, **mapped_kwargs: OperatorExpandArgument) -> StreamedOperator: if not mapped_kwargs: raise TypeError("no arguments to expand against") validate_mapping_kwargs(self.operator_class, "stream", mapped_kwargs) - prevent_duplicates( - self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified" - ) + prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified") expand_input = DictOfListsExpandInput(mapped_kwargs) ensure_xcomarg_return_value(expand_input.value) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index c6f1cc7b824a3..9cdd28f44b580 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -21,14 +21,15 @@ import logging import os from abc import abstractmethod -from asyncio import AbstractEventLoop, Future, ensure_future +from asyncio import AbstractEventLoop, Future, Semaphore, ensure_future +from collections.abc import Generator, Iterable, Sequence from contextlib import contextmanager, suppress from datetime import timedelta from math import ceil from multiprocessing import TimeoutError from multiprocessing.pool import ThreadPool from time import sleep -from typing import TYPE_CHECKING, Any, Generator, Iterable, Sequence +from typing import TYPE_CHECKING, Any from airflow import XComArg from airflow.exceptions import ( @@ -41,8 +42,8 @@ from airflow.models.baseoperator import BaseOperator from airflow.models.expandinput import ( ExpandInput, - is_mappable, _needs_run_time_resolution, + is_mappable, ) from airflow.models.taskinstance import TaskInstance from airflow.triggers.base import run_trigger @@ -144,13 +145,12 @@ def run(self): if 
self.task_instance.try_number == 0: self.operator.render_template_fields(context=self.context) self.operator.pre_execute(context=self.context) - self.task_instance._run_execute_callback( - context=self.context, task=self.operator - ) + self.task_instance._run_execute_callback(context=self.context, task=self.operator) try: outlet_events = context_get_outlet_events(self.context) + # TODO: change back to operator.execute once ExecutorSafeguard is fixed callable_runner = ExecutionCallableRunner( - func=self.operator.execute, + func=self.operator.execute.__wrapped__, outlet_events=outlet_events, logger=self.log, ) @@ -169,9 +169,7 @@ def run(self): self.task_instance.try_number += 1 self.task_instance.end_date = timezone.utcnow() - raise AirflowRescheduleTaskInstanceException( - task=self.task_instance - ) + raise AirflowRescheduleTaskInstanceException(task=self.task_instance) class TriggerExecutor(TaskExecutor): @@ -205,6 +203,7 @@ class StreamedOperator(BaseOperator): "expand_input", "partial_kwargs", "_log", + "_semaphore", ) def __init__( @@ -220,12 +219,11 @@ def __init__( self._operator_class = operator_class self.expand_input = expand_input self.partial_kwargs = partial_kwargs or {} - self.timeout = ( - timeout.total_seconds() if isinstance(timeout, timedelta) else timeout - ) + self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout self._mapped_kwargs: list[dict] = [] if not self.max_active_tis_per_dag: self.max_active_tis_per_dag = os.cpu_count() or 1 + self._semaphore = Semaphore(self.max_active_tis_per_dag) XComArg.apply_upstream_relationship(self, self.expand_input.value) @property @@ -284,10 +282,7 @@ def _run_tasks( failed_tasks: list[TaskInstance] = [] with ThreadPool(processes=self.max_active_tis_per_dag) as pool: - futures = [ - (task, pool.apply_async(self._run_operator, (context, task))) - for task in tasks - ] + futures = [(task, pool.apply_async(self._run_operator, (context, task))) for task in tasks] for task, future in futures: try: @@ -295,7 +290,11 @@ def _run_tasks( if isinstance(result, TaskDeferred): deferred_tasks.append( ensure_future( - self._run_deferrable(context=context, task=task, task_deferred=result) + self._run_deferrable( + context=context, + task=task, + task_deferred=result, + ) ) ) else: @@ -322,10 +321,11 @@ def _run_tasks( asyncio.gather(*deferred_tasks, return_exceptions=True) ): if isinstance(result, Exception): - if not isinstance(result, AirflowRescheduleTaskInstanceException): + if isinstance(result, AirflowRescheduleTaskInstanceException): + reschedule_date = result.reschedule_date + failed_tasks.append(result.task) + else: exception = result - reschedule_date = result.reschedule_date - failed_tasks.append(result.task) else: results.append(result) @@ -368,12 +368,12 @@ def _run_operator(cls, context: Context, task_instance: TaskInstance): except TaskDeferred as task_deferred: return task_deferred - @classmethod - async def _run_deferrable(cls, context: Context, task: TaskInstance, task_deferred: TaskDeferred): - return await TriggerExecutor( - context=context, - task_instance=task, - ).run(task_deferred) + async def _run_deferrable(self, context: Context, task: TaskInstance, task_deferred: TaskDeferred): + async with self._semaphore: + return await TriggerExecutor( + context=context, + task_instance=task, + ).run(task_deferred) def _create_task(self, context: Context, index: int) -> TaskInstance: operator = self._unmap_operator(index) From 6626671354adbaa4b952cca8dc9b18932278871b Mon Sep 17 00:00:00 2001 From: 
David Blain Date: Tue, 3 Dec 2024 09:59:05 +0100 Subject: [PATCH 55/97] feature: Import gather from asyncio StreamedOperator --- airflow/models/streamedoperator.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index 9cdd28f44b580..f7bd9c6df538c 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -21,7 +21,7 @@ import logging import os from abc import abstractmethod -from asyncio import AbstractEventLoop, Future, Semaphore, ensure_future +from asyncio import AbstractEventLoop, Future, Semaphore, ensure_future, gather from collections.abc import Generator, Iterable, Sequence from contextlib import contextmanager, suppress from datetime import timedelta @@ -125,9 +125,8 @@ class OperatorExecutor(TaskExecutor): """ Run an operator with given task context and task instance. - If the execute function raises a TaskDeferred exception, then the trigger instance within the - TaskDeferred exception will be executed with the given context and task instance. The operator - or trigger will always be executed in an async way. + If the execute function raises a TaskDeferred exception, then the trigger will be executed in an + async way using the TriggerExecutor. :meta private: """ @@ -173,6 +172,16 @@ def run(self): class TriggerExecutor(TaskExecutor): + """ + Run a trigger with given task deferred exception. + + If the next method raises a TaskDeferred exception, then the trigger instance will be re-executed with + the given TaskDeferred exception until no more TaskDeferred exceptions occur. The trigger will always + be executed in an async way. + + :meta private: + """ + async def run(self, task_deferred: TaskDeferred): event = await run_trigger(task_deferred.trigger) @@ -318,7 +327,7 @@ def _run_tasks( with event_loop() as loop: for result in loop.run_until_complete( - asyncio.gather(*deferred_tasks, return_exceptions=True) + gather(*deferred_tasks, return_exceptions=True) ): if isinstance(result, Exception): if isinstance(result, AirflowRescheduleTaskInstanceException): From bc754841c49c91c81e36732ec4040c307c17cb5e Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 3 Dec 2024 21:16:48 +0100 Subject: [PATCH 56/97] refactored: Refactored BaseExecutor as a context manager to handle common code between OperatorExecutor and TriggerExecutor --- airflow/models/streamedoperator.py | 148 +++++++++++++++-------------- 1 file changed, 78 insertions(+), 70 deletions(-) diff --git a/airflow/models/streamedoperator.py b/airflow/models/streamedoperator.py index f7bd9c6df538c..16ed401a3e701 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/streamedoperator.py @@ -22,14 +22,13 @@ import os from abc import abstractmethod from asyncio import AbstractEventLoop, Future, Semaphore, ensure_future, gather -from collections.abc import Generator, Iterable, Sequence from contextlib import contextmanager, suppress from datetime import timedelta from math import ceil from multiprocessing import TimeoutError from multiprocessing.pool import ThreadPool from time import sleep -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Generator, Iterable, Sequence from airflow import XComArg from airflow.exceptions import ( @@ -42,17 +41,18 @@ from airflow.models.baseoperator import BaseOperator from airflow.models.expandinput import ( ExpandInput, - _needs_run_time_resolution, is_mappable, + _needs_run_time_resolution, ) from 
airflow.models.taskinstance import TaskInstance -from airflow.triggers.base import run_trigger from airflow.utils import timezone from airflow.utils.context import Context, context_get_outlet_events from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.operator_helpers import ExecutionCallableRunner from airflow.utils.task_instance_session import get_current_task_instance_session +from airflow.triggers.base import run_trigger + if TYPE_CHECKING: import jinja2 from sqlalchemy.orm import Session @@ -104,10 +104,43 @@ def operator(self) -> BaseOperator: def run(self, *args, **kwargs): raise NotImplementedError() - def _handle_result(self, result): - """ - Common logic to handle result and post-execution tasks. - """ + def __enter__(self): + if self.log.isEnabledFor(logging.INFO): + self.log.info( + "Attempting running task %s of %s for %s with map_index %s.", + self.task_instance.try_number, + self.operator.retries, + type(self.operator).__name__, + self.task_instance.map_index, + ) + + if self.task_instance.try_number == 0: + self.operator.render_template_fields(context=self.context) + self.operator.pre_execute(context=self.context) + self.task_instance._run_execute_callback( + context=self.context, task=self.operator + ) + + return self + + def __exit__(self, exc_type, exc_value, traceback): + if exc_value: + if isinstance(exc_value, AirflowException): + if self.task_instance.next_try_number > self.operator.retries: + self.log.error( + "Max number of attempts for %s with map_index %s failed due to: %s", + type(self.operator).__name__, + self.task_instance.map_index, + exc_value, + ) + raise exc_value + + self.task_instance.try_number += 1 + self.task_instance.end_date = timezone.utcnow() + + raise AirflowRescheduleTaskInstanceException(task=self.task_instance) + raise exc_value + self.operator.post_execute(context=self.context) if self.log.isEnabledFor(logging.INFO): self.log.info( @@ -116,9 +149,6 @@ def _handle_result(self, result): type(self.operator).__name__, self.task_instance.next_try_number, ) - if self.operator.do_xcom_push: - return result - return None class OperatorExecutor(TaskExecutor): @@ -131,44 +161,14 @@ class OperatorExecutor(TaskExecutor): :meta private: """ - def run(self): - if self.log.isEnabledFor(logging.INFO): - self.log.info( - "Attempting running task %s of %s for %s with map_index %s.", - self.task_instance.try_number, - self.operator.retries, - type(self.operator).__name__, - self.task_instance.map_index, - ) - - if self.task_instance.try_number == 0: - self.operator.render_template_fields(context=self.context) - self.operator.pre_execute(context=self.context) - self.task_instance._run_execute_callback(context=self.context, task=self.operator) - try: - outlet_events = context_get_outlet_events(self.context) - # TODO: change back to operator.execute once ExecutorSafeguard is fixed - callable_runner = ExecutionCallableRunner( - func=self.operator.execute.__wrapped__, - outlet_events=outlet_events, - logger=self.log, - ) - result = callable_runner.run(self.operator, self.context) - return self._handle_result(result) - except AirflowException as e: - if self.task_instance.next_try_number > self.operator.retries: - self.log.error( - "Max number of attempts for %s with map_index %s failed due to: %s", - type(self.operator).__name__, - self.task_instance.map_index, - e, - ) - raise e - - self.task_instance.try_number += 1 - self.task_instance.end_date = timezone.utcnow() - - raise AirflowRescheduleTaskInstanceException(task=self.task_instance) + 
def run(self, *args, **kwargs): + outlet_events = context_get_outlet_events(self.context) + # TODO: change back to operator.execute once ExecutorSafeguard is fixed + return ExecutionCallableRunner( + func=self.operator.execute.__wrapped__, + outlet_events=outlet_events, + logger=self.log, + ).run(self.operator, self.context) class TriggerExecutor(TaskExecutor): @@ -195,8 +195,12 @@ async def run(self, task_deferred: TaskDeferred): next_method = BaseOperator.next_callable( self.operator, task_deferred.method_name, task_deferred.kwargs ) - result = next_method(self.context, event.payload) - return self._handle_result(result) + outlet_events = context_get_outlet_events(self.context) + return ExecutionCallableRunner( + func=next_method, + outlet_events=outlet_events, + logger=self.log, + ).run(self.context, event.payload) except TaskDeferred as task_deferred: return await self.run(task_deferred=task_deferred) @@ -228,7 +232,9 @@ def __init__( self._operator_class = operator_class self.expand_input = expand_input self.partial_kwargs = partial_kwargs or {} - self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout + self.timeout = ( + timeout.total_seconds() if isinstance(timeout, timedelta) else timeout + ) self._mapped_kwargs: list[dict] = [] if not self.max_active_tis_per_dag: self.max_active_tis_per_dag = os.cpu_count() or 1 @@ -291,7 +297,10 @@ def _run_tasks( failed_tasks: list[TaskInstance] = [] with ThreadPool(processes=self.max_active_tis_per_dag) as pool: - futures = [(task, pool.apply_async(self._run_operator, (context, task))) for task in tasks] + futures = [ + (task, pool.apply_async(self._run_operator, (context, task))) + for task in tasks + ] for task, future in futures: try: @@ -306,7 +315,7 @@ def _run_tasks( ) ) ) - else: + elif result: results.append(result) except TimeoutError as e: self.log.warning("A timeout occurred for task_id %s", task.task_id) @@ -329,13 +338,15 @@ def _run_tasks( for result in loop.run_until_complete( gather(*deferred_tasks, return_exceptions=True) ): + self.log.debug("result: %s", result) + if isinstance(result, Exception): if isinstance(result, AirflowRescheduleTaskInstanceException): reschedule_date = result.reschedule_date failed_tasks.append(result.task) else: exception = result - else: + elif result: results.append(result) deferred_tasks.clear() @@ -343,7 +354,7 @@ def _run_tasks( if not failed_tasks: if exception: raise exception - return list(filter(None, results)) + return results # session = get_current_task_instance_session() # TaskInstance._set_state(context["ti"], TaskInstanceState.UP_FOR_RETRY, session) @@ -369,20 +380,17 @@ def _run_tasks( @classmethod def _run_operator(cls, context: Context, task_instance: TaskInstance): - try: - return OperatorExecutor( - context=context, - task_instance=task_instance, - ).run() - except TaskDeferred as task_deferred: - return task_deferred - - async def _run_deferrable(self, context: Context, task: TaskInstance, task_deferred: TaskDeferred): - async with self._semaphore: - return await TriggerExecutor( - context=context, - task_instance=task, - ).run(task_deferred) + with OperatorExecutor(context=context, task_instance=task_instance) as executor: + try: + return executor.run() + except TaskDeferred as task_deferred: + return task_deferred + + async def _run_deferrable( + self, context: Context, task: TaskInstance, task_deferred: TaskDeferred + ): + with TriggerExecutor(context=context, task_instance=task) as executor: + return await executor.run(task_deferred) def 
_create_task(self, context: Context, index: int) -> TaskInstance: operator = self._unmap_operator(index) From 065dc6f4e5c57c049a7eadaff182ccd4e18e6e7b Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 4 Dec 2024 07:51:56 +0100 Subject: [PATCH 57/97] refactored: Renamed StreamedOperator to IterableOperator --- ...treamedoperator.py => iterableoperator.py} | 2 +- airflow/models/mappedoperator.py | 6 +- ...edoperator.py => test_iterableoperator.py} | 128 +++++++++--------- 3 files changed, 68 insertions(+), 68 deletions(-) rename airflow/models/{streamedoperator.py => iterableoperator.py} (99%) rename tests/models/{test_streamedoperator.py => test_iterableoperator.py} (93%) diff --git a/airflow/models/streamedoperator.py b/airflow/models/iterableoperator.py similarity index 99% rename from airflow/models/streamedoperator.py rename to airflow/models/iterableoperator.py index 16ed401a3e701..92086fa101961 100644 --- a/airflow/models/streamedoperator.py +++ b/airflow/models/iterableoperator.py @@ -205,7 +205,7 @@ async def run(self, task_deferred: TaskDeferred): return await self.run(task_deferred=task_deferred) -class StreamedOperator(BaseOperator): +class IterableOperator(BaseOperator): """Object representing a streamed operator in a DAG.""" _operator_class: type[BaseOperator] diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 4ba01b840e1f6..1192d0da82f25 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -78,9 +78,9 @@ OperatorExpandArgument, OperatorExpandKwargsArgument, ) + from airflow.models.iterableoperator import IterableOperator from airflow.models.operator import Operator from airflow.models.param import ParamsDict - from airflow.models.streamedoperator import StreamedOperator from airflow.models.xcom_arg import XComArg from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.utils.context import Context @@ -240,8 +240,8 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator: ) return op - def stream(self, **mapped_kwargs: OperatorExpandArgument) -> StreamedOperator: - from airflow.models.streamedoperator import StreamedOperator + def iterate(self, **mapped_kwargs: OperatorExpandArgument) -> IterableOperator: + from airflow.models.iterableoperator import IterableOperator if not mapped_kwargs: raise TypeError("no arguments to expand against") diff --git a/tests/models/test_streamedoperator.py b/tests/models/test_iterableoperator.py similarity index 93% rename from tests/models/test_streamedoperator.py rename to tests/models/test_iterableoperator.py index 92e9455d943fd..9d8667bfc07a3 100644 --- a/tests/models/test_streamedoperator.py +++ b/tests/models/test_iterableoperator.py @@ -32,7 +32,7 @@ from airflow.models.dag import DAG from airflow.models.mappedoperator import MappedOperator from airflow.models.param import ParamsDict -from airflow.models.streamedoperator import StreamedOperator +from airflow.models.IterableOperator import IterableOperator from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap from airflow.models.xcom_arg import XComArg @@ -56,7 +56,7 @@ def test_task_mapping_with_dag(): with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE) as dag: task1 = BaseOperator(task_id="op1") literal = ["a", "b", "c"] - mapped = MockOperator.partial(task_id="task_2").stream(arg2=literal) + mapped = MockOperator.partial(task_id="task_2").iterate(arg2=literal) finish = MockOperator(task_id="finish") task1 >> mapped >> finish @@ 
-91,7 +91,7 @@ def execute(self, context: Context): with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE) as dag: task1 = CustomOperator(task_id="op1", arg=None) unrenderable_values = [UnrenderableClass(), UnrenderableClass()] - mapped = CustomOperator.partial(task_id="task_2").stream(arg=unrenderable_values) + mapped = CustomOperator.partial(task_id="task_2").iterate(arg=unrenderable_values) task1 >> mapped dag.test() assert ( @@ -105,13 +105,13 @@ def test_task_mapping_without_dag_context(): with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE) as dag: task1 = BaseOperator(task_id="op1") literal = ["a", "b", "c"] - streamed = MockOperator.partial(task_id="task_2").stream(arg2=literal) + iterable = MockOperator.partial(task_id="task_2").iterate(arg2=literal) - task1 >> streamed + task1 >> iterable - assert isinstance(streamed, StreamedOperator) - assert streamed in dag.tasks - assert task1.downstream_list == [streamed] + assert isinstance(iterable, IterableOperator) + assert iterable in dag.tasks + assert task1.downstream_list == [iterable] # At parse time there should only be two tasks! assert len(dag.tasks) == 2 @@ -121,38 +121,38 @@ def test_task_mapping_default_args(): with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE, default_args=default_args): task1 = BaseOperator(task_id="op1") literal = ["a", "b", "c"] - streamed = MockOperator.partial(task_id="task_2").stream(arg2=literal) + iterable = MockOperator.partial(task_id="task_2").iterate(arg2=literal) - task1 >> streamed + task1 >> iterable - assert streamed.owner == "test" - assert streamed.start_date == pendulum.instance(default_args["start_date"]) + assert iterable.owner == "test" + assert iterable.start_date == pendulum.instance(default_args["start_date"]) def test_task_mapping_override_default_args(): default_args = {"retries": 2, "start_date": DEFAULT_DATE.now()} with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE, default_args=default_args): literal = ["a", "b", "c"] - streamed = MockOperator.partial(task_id="task", retries=1).stream(arg2=literal) + iterable = MockOperator.partial(task_id="task", retries=1).iterate(arg2=literal) - # retries should be 0 because it will be applied on the streamed tasks - assert streamed.retries == 0 + # retries should be 0 because it will be applied on the iterable tasks + assert iterable.retries == 0 # start_date should be equal to default_args["start_date"] because it is not provided as partial arg - assert streamed.start_date == pendulum.instance(default_args["start_date"]) + assert iterable.start_date == pendulum.instance(default_args["start_date"]) # owner should be equal to Airflow default owner (airflow) because it is not provided at all - assert streamed.owner == "airflow" + assert iterable.owner == "airflow" def test_map_unknown_arg_raises(): with pytest.raises(TypeError, match=r"argument 'file'"): - BaseOperator.partial(task_id="a").stream(file=[1, 2, {"a": "b"}]) + BaseOperator.partial(task_id="a").iterate(file=[1, 2, {"a": "b"}]) def test_map_xcom_arg(): """Test that dependencies are correct when mapping with an XComArg""" with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE): task1 = BaseOperator(task_id="op1") - mapped = MockOperator.partial(task_id="task_2").stream(arg2=task1.output) + mapped = MockOperator.partial(task_id="task_2").iterate(arg2=task1.output) finish = MockOperator(task_id="finish") mapped >> finish @@ -178,8 +178,8 @@ def execute(self, context): with dag_maker("test-dag", session=session, start_date=DEFAULT_DATE) as dag: 
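As background for this hunk, the chained expansion it exercises reduces to the sketch below; a minimal illustration, assuming the PushExtraXComOperator helper defined in this test module, where each .iterate() call builds a single IterableOperator at parse time and the upstream XComArg is only resolved when the mapped kwargs expand at runtime:

    with DAG("test-iterate-chaining", schedule=None, start_date=DEFAULT_DATE):
        # task_1 pushes [1, 2, 3] as its return-value XCom
        task1 = PushExtraXComOperator(return_value=[1, 2, 3], task_id="task_1")
        # task_2 expands over task_1's output, task_3 over task_2's
        task2 = PushExtraXComOperator.partial(task_id="task_2").iterate(return_value=task1.output)
        task3 = PushExtraXComOperator.partial(task_id="task_3").iterate(return_value=task2.output)
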
upstream_return = [1, 2, 3] task1 = PushExtraXComOperator(return_value=upstream_return, task_id="task_1") - task2 = PushExtraXComOperator.partial(task_id="task_2").stream(return_value=task1.output) - task3 = PushExtraXComOperator.partial(task_id="task_3").stream(return_value=task2.output) + task2 = PushExtraXComOperator.partial(task_id="task_2").iterate(return_value=task1.output) + task3 = PushExtraXComOperator.partial(task_id="task_3").iterate(return_value=task2.output) dr = dag_maker.create_dagrun() ti_1 = dr.get_task_instance("task_1", session) @@ -247,7 +247,7 @@ def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expec literal = [1, 2, {"a": "b"}] with dag_maker(session=session): task1 = BaseOperator(task_id="op1") - mapped = MockOperator.partial(task_id="task_2").stream(arg2=task1.output) + mapped = MockOperator.partial(task_id="task_2").iterate(arg2=task1.output) dr = dag_maker.create_dagrun() @@ -298,7 +298,7 @@ def test_expand_mapped_task_failed_state_in_db(dag_maker, session): literal = [1, 2] with dag_maker(session=session): task1 = BaseOperator(task_id="op1") - mapped = MockOperator.partial(task_id="task_2").stream(arg2=task1.output) + mapped = MockOperator.partial(task_id="task_2").iterate(arg2=task1.output) dr = dag_maker.create_dagrun() @@ -344,7 +344,7 @@ def test_expand_mapped_task_failed_state_in_db(dag_maker, session): def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session): with dag_maker(session=session): task1 = BaseOperator(task_id="op1") - mapped = MockOperator.partial(task_id="task_2").stream(arg2=task1.output) + mapped = MockOperator.partial(task_id="task_2").iterate(arg2=task1.output) dr = dag_maker.create_dagrun() @@ -363,7 +363,7 @@ def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session): def test_mapped_task_applies_default_args_classic(dag_maker): with dag_maker(default_args={"execution_timeout": timedelta(minutes=30)}) as dag: MockOperator(task_id="simple", arg1=None, arg2=0) - MockOperator.partial(task_id="mapped").stream(arg1=[1], arg2=[2, 3]) + MockOperator.partial(task_id="mapped").iterate(arg1=[1], arg2=[2, 3]) assert dag.get_task("simple").execution_timeout == timedelta(minutes=30) assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30) @@ -381,7 +381,7 @@ def mapped(arg): pass simple(arg=0) - mapped.stream(arg=[1, 2]) + mapped.iterate(arg=[1, 2]) assert dag.get_task("simple").execution_timeout == timedelta(minutes=30) assert dag.get_task("mapped").execution_timeout == timedelta(minutes=30) @@ -398,10 +398,10 @@ def mapped(arg): ) def test_mapped_expand_against_params(dag_maker, dag_params, task_params, expected_partial_params): with dag_maker(params=dag_params) as dag: - MockOperator.partial(task_id="t", params=task_params).stream(params=[{"c": "x"}, {"d": 1}]) + MockOperator.partial(task_id="t", params=task_params).iterate(params=[{"c": "x"}, {"d": 1}]) t = dag.get_task("t") - assert isinstance(t, StreamedOperator) + assert isinstance(t, IterableOperator) assert t.params == expected_partial_params assert t.expand_input.value == {"params": [{"c": "x"}, {"d": 1}]} @@ -439,7 +439,7 @@ def execute(self, context): output1 = task1.output mapped = MyOperator.partial( task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" - ).stream(map_template=output1, map_static=output1, file_template=["/path/to/file.ext"]) + ).iterate(map_template=output1, map_static=output1, file_template=["/path/to/file.ext"]) dr = dag_maker.create_dagrun() ti: TaskInstance = 
dr.get_task_instance(task1.task_id, session=session) @@ -527,7 +527,7 @@ def test_mapped_render_nested_template_fields(dag_maker, session): with dag_maker(session=session): MockOperatorWithNestedFields.partial( task_id="t", arg2=NestedFields(field_1="{{ ti.task_id }}", field_2="value_2") - ).stream(arg1=["{{ ti.task_id }}", ["s", "{{ ti.task_id }}"]]) + ).iterate(arg1=["{{ ti.task_id }}", ["s", "{{ ti.task_id }}"]]) dr = dag_maker.create_dagrun() decision = dr.task_instance_scheduling_decisions() @@ -624,7 +624,7 @@ def __init__(self, *, map_name: str, **kwargs): def execute(self, context): context["map_name"] = self.map_name - return HasMapName.partial(task_id=task_id, map_index_template=template).stream( + return HasMapName.partial(task_id=task_id, map_index_template=template).iterate( map_name=map_names, ) @@ -637,7 +637,7 @@ def task1(map_name): context = get_current_context() context["map_name"] = map_name - return task1.stream(map_name=map_names) + return task1.iterate(map_name=map_names) def _create_named_map_index_renders_on_failure_classic(*, task_id, map_names, template): @@ -650,7 +650,7 @@ def execute(self, context): context["map_name"] = self.map_name raise AirflowSkipException("Imagine this task failed!") - return HasMapName.partial(task_id=task_id, map_index_template=template).stream( + return HasMapName.partial(task_id=task_id, map_index_template=template).iterate( map_name=map_names, ) @@ -664,7 +664,7 @@ def task1(map_name): context["map_name"] = map_name raise AirflowSkipException("Imagine this task failed!") - return task1.stream(map_name=map_names) + return task1.iterate(map_name=map_names) @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode @@ -759,7 +759,7 @@ def test_expand_kwargs_render_template_fields_validating_operator(dag_maker, ses def test_xcomarg_property_of_mapped_operator(dag_maker): with dag_maker("test_xcomarg_property_of_mapped_operator"): - op_a = MockOperator.partial(task_id="a").stream(arg1=["x", "y", "z"]) + op_a = MockOperator.partial(task_id="a").iterate(arg1=["x", "y", "z"]) dag_maker.create_dagrun() assert op_a.output == XComArg(op_a) @@ -767,8 +767,8 @@ def test_xcomarg_property_of_mapped_operator(dag_maker): def test_set_xcomarg_dependencies_with_mapped_operator(dag_maker): with dag_maker("test_set_xcomargs_dependencies_with_mapped_operator"): - op1 = MockOperator.partial(task_id="op1").stream(arg1=[1, 2, 3]) - op2 = MockOperator.partial(task_id="op2").stream(arg2=["a", "b", "c"]) + op1 = MockOperator.partial(task_id="op1").iterate(arg1=[1, 2, 3]) + op2 = MockOperator.partial(task_id="op2").iterate(arg2=["a", "b", "c"]) op3 = MockOperator(task_id="op3", arg1=op1.output) op4 = MockOperator(task_id="op4", arg1=[op1.output, op2.output]) op5 = MockOperator(task_id="op5", arg1={"op1": op1.output, "op2": op2.output}) @@ -793,7 +793,7 @@ def execute(self, context): assert set(self.arg1) == {1, 2, 3} with dag_maker("test_all_xcomargs_from_mapped_tasks_are_consumable"): - op1 = PushXcomOperator.partial(task_id="op1").stream(arg1=[1, 2, 3]) + op1 = PushXcomOperator.partial(task_id="op1").iterate(arg1=[1, 2, 3]) ConsumeXcomOperator(task_id="op2", arg1=op1.output) dr = dag_maker.create_dagrun() @@ -809,7 +809,7 @@ def test_task_mapping_with_task_group_context(): with TaskGroup("test-group") as group: literal = ["a", "b", "c"] - mapped = MockOperator.partial(task_id="task_2").stream(arg2=literal) + mapped = MockOperator.partial(task_id="task_2").iterate(arg2=literal) task1 >> group >> finish @@ -830,7 +830,7 @@ def 
test_task_mapping_with_explicit_task_group(): group = TaskGroup("test-group") literal = ["a", "b", "c"] - mapped = MockOperator.partial(task_id="task_2", task_group=group).stream(arg2=literal) + mapped = MockOperator.partial(task_id="task_2", task_group=group).iterate(arg2=literal) task1 >> group >> finish @@ -908,7 +908,7 @@ def my_teardown(val): print(f"teardown: {val}") s = my_setup() - t = my_teardown.stream(val=s) + t = my_teardown.iterate(val=s) with t: my_work(s) else: @@ -921,7 +921,7 @@ def my_work(val): with dag_maker() as dag: my_setup = self.classic_operator("my_setup", [[1], [2], [3]]) my_teardown = self.classic_operator("my_teardown", partial=True) - t = my_teardown.stream(op_args=my_setup.output) + t = my_teardown.iterate(op_args=my_setup.output) with t.as_teardown(setups=my_setup): my_work(my_setup.output) @@ -975,7 +975,7 @@ def my_work(val): def my_teardown(val): print(f"teardown: {val}") - s = my_setup.stream(val=["data1.json", "data2.json", "data3.json"]) + s = my_setup.iterate(val=["data1.json", "data2.json", "data3.json"]) o_setup = other_setup() o_teardown = other_teardown() with o_teardown.as_teardown(setups=o_setup): @@ -999,7 +999,7 @@ def my_work(val): my_teardown = self.classic_operator("my_teardown") my_setup = self.classic_operator("my_setup", partial=True, fail=True) - s = my_setup.stream(op_args=[["data1.json"], ["data2.json"], ["data3.json"]]) + s = my_setup.iterate(op_args=[["data1.json"], ["data2.json"], ["data3.json"]]) o_setup = self.classic_operator("other_setup") o_teardown = self.classic_operator("other_teardown") with o_teardown.as_teardown(setups=o_setup): @@ -1063,7 +1063,7 @@ def my_work(val): def my_teardown(val): print(f"teardown: {val}") - s = my_setup.stream(val=["data1.json", "data2.json", "data3.json"]) + s = my_setup.iterate(val=["data1.json", "data2.json", "data3.json"]) o_setup = other_setup() o_teardown = other_teardown() with o_teardown.as_teardown(setups=o_setup): @@ -1096,7 +1096,7 @@ def my_work(val): print(f"work: {val}") my_setup = self.classic_operator("my_setup", partial=True, fail=True) - s = my_setup.stream(op_args=[["data1.json"], ["data2.json"], ["data3.json"]]) + s = my_setup.iterate(op_args=[["data1.json"], ["data2.json"], ["data3.json"]]) o_setup = other_setup() o_teardown = other_teardown() with o_teardown.as_teardown(setups=o_setup): @@ -1163,7 +1163,7 @@ def my_work(val): def my_teardown(val): print(f"teardown: {val}") - s = my_setup.stream(val=["data1.json", "data2.json", "data3.json"]) + s = my_setup.iterate(val=["data1.json", "data2.json", "data3.json"]) o_setup = other_setup() o_teardown = other_teardown() with o_teardown.as_teardown(setups=o_setup): @@ -1207,7 +1207,7 @@ def my_work(val): def my_teardown_callable(val): print(f"teardown: {val}") - s = my_setup.stream(op_args=[["data1.json"], ["data2.json"], ["data3.json"]]) + s = my_setup.iterate(op_args=[["data1.json"], ["data2.json"], ["data3.json"]]) o_setup = other_setup() o_teardown = other_teardown() with o_teardown.as_teardown(setups=o_setup): @@ -1261,7 +1261,7 @@ def my_teardown(val): print(f"teardown: {val}") s = my_setup() - t = my_teardown.stream(val=s).as_teardown(setups=s) + t = my_teardown.iterate(val=s).as_teardown(setups=s) with t: my_work(s) else: @@ -1276,7 +1276,7 @@ def my_work(val): my_teardown = self.classic_operator(task_id="my_teardown", partial=True) s = self.classic_operator(task_id="my_setup", ret=[[1], [2], [3]]) - t = my_teardown.stream(op_args=s.output).as_teardown(setups=s) + t = 
my_teardown.iterate(op_args=s.output).as_teardown(setups=s) with t: my_work(s) dr = dag.test() @@ -1319,7 +1319,7 @@ def my_teardown(val): raise ValueError("failure") s = my_setup() - t = my_teardown.stream(val=s).as_teardown(setups=s, on_failure_fail_dagrun=True) + t = my_teardown.iterate(val=s).as_teardown(setups=s, on_failure_fail_dagrun=True) with t: my_work(s) # todo: if on_failure_fail_dagrun=True, should we still regard the WORK task as a leaf? @@ -1339,7 +1339,7 @@ def my_teardown_callable(val): s = self.classic_operator(task_id="my_setup", ret=[[1], [2], [3]]) my_teardown = PythonOperator.partial( task_id="my_teardown", python_callable=my_teardown_callable - ).stream(op_args=s.output) + ).iterate(op_args=s.output) t = my_teardown.as_teardown(setups=s, on_failure_fail_dagrun=True) with t: my_work(s.output) @@ -1389,7 +1389,7 @@ def file_transforms(filename): with t: my_work(filename) - file_transforms.stream(filename=["data1.json", "data2.json", "data3.json"]) + file_transforms.iterate(filename=["data1.json", "data2.json", "data3.json"]) else: with dag_maker() as dag: @@ -1418,7 +1418,7 @@ def file_transforms(filename): with t.as_teardown(setups=s): my_work(filename) - file_transforms.stream(filename=[["data1.json"], ["data2.json"], ["data3.json"]]) + file_transforms.iterate(filename=[["data1.json"], ["data2.json"], ["data3.json"]]) dr = dag.test() states = self.get_states(dr) expected = { @@ -1464,7 +1464,7 @@ def file_transforms(filename): with t: my_work(filename) - file_transforms.stream(filename=["data1.json", "data2.json", "data3.json"]) + file_transforms.iterate(filename=["data1.json", "data2.json", "data3.json"]) else: with dag_maker() as dag: @@ -1492,7 +1492,7 @@ def file_transforms(filename): with t: my_work(filename) - file_transforms.stream(filename=[["data1.json"], ["data2.json"], ["data3.json"]]) + file_transforms.iterate(filename=[["data1.json"], ["data2.json"], ["data3.json"]]) dr = dag.test() states = self.get_states(dr) expected = { @@ -1530,7 +1530,7 @@ def my_work(val): def my_teardown(val): print(f"teardown: {val}") - s = my_setup.stream(val=["data1.json", "data2.json", "data3.json"]) + s = my_setup.iterate(val=["data1.json", "data2.json", "data3.json"]) with my_teardown(s).as_teardown(setups=s): my_work(s) else: @@ -1549,7 +1549,7 @@ def my_work(val): print(f"work: {val}") s = PythonOperator.partial(task_id="my_setup", python_callable=my_setup_callable) - s = s.stream(op_args=[["data1.json"], ["data2.json"], ["data3.json"]]) + s = s.iterate(op_args=[["data1.json"], ["data2.json"], ["data3.json"]]) t = self.classic_operator("my_teardown") with t.as_teardown(setups=s): my_work(s.output) @@ -1587,7 +1587,7 @@ def my_teardown(val): print(f"teardown: {val}") s = my_setup() - t = my_teardown.stream(val=s).as_teardown(setups=s) + t = my_teardown.iterate(val=s).as_teardown(setups=s) with t: my_work(s) @@ -1626,7 +1626,7 @@ def my_teardown(val): print(f"teardown: {val}") s = my_setup() - t = my_teardown.stream(val=s).as_teardown(setups=s) + t = my_teardown.iterate(val=s).as_teardown(setups=s) with t: my_work(s) tg1, tg2 = dag.task_group.children.values() @@ -1673,8 +1673,8 @@ def my_work(val): def my_teardown(val): print(f"teardown: {val}") - s = my_setup.stream(val=my_pre_setup()) - t = my_teardown.stream(val=s).as_teardown(setups=s) + s = my_setup.iterate(val=my_pre_setup()) + t = my_teardown.iterate(val=s).as_teardown(setups=s) with t: my_work(s) tg1, tg2 = dag.task_group.children.values() @@ -1715,7 +1715,7 @@ def last(n): ... 
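The task-group variant in the next hunk is easier to follow with the whole pattern in one place; a sketch under stated assumptions (the decorated helpers mirror the ones this test defines, with @task/@task_group taken from airflow.decorators):

    from airflow.decorators import task, task_group

    @task
    def make_list():
        return [1, 2, 3]

    @task
    def double(n):
        return n * 2

    @task
    def last(n):
        print(n)

    @task_group
    def group(n: int) -> None:
        last(double(n))

    # one group expansion per element of make_list()'s XCom
    group.iterate(n=make_list())
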
def group(n: int) -> None: last(double(n)) - group.stream(n=make_list()) + group.iterate(n=make_list()) dr = dag.test() states = self.get_states(dr) @@ -1744,7 +1744,7 @@ def last(n): ... def group(n: int) -> None: last(double(n)) - group.stream(n=[1, 2, 3]) + group.iterate(n=[1, 2, 3]) dr = dag.test() states = self.get_states(dr) From 1fd075f45504592324a2755445019cf5ce97ccc5 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 4 Dec 2024 08:54:15 +0100 Subject: [PATCH 58/97] refactor: Make sure the run_deferrable method uses the semaphore to prevent too much parallelism --- airflow/models/iterableoperator.py | 34 +++++++++++------------------- 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py index 92086fa101961..102f2ba2ba536 100644 --- a/airflow/models/iterableoperator.py +++ b/airflow/models/iterableoperator.py @@ -22,13 +22,14 @@ import os from abc import abstractmethod from asyncio import AbstractEventLoop, Future, Semaphore, ensure_future, gather +from collections.abc import Generator, Iterable, Sequence from contextlib import contextmanager, suppress from datetime import timedelta from math import ceil from multiprocessing import TimeoutError from multiprocessing.pool import ThreadPool from time import sleep -from typing import TYPE_CHECKING, Any, Generator, Iterable, Sequence +from typing import TYPE_CHECKING, Any from airflow import XComArg from airflow.exceptions import ( @@ -41,18 +42,17 @@ from airflow.models.baseoperator import BaseOperator from airflow.models.expandinput import ( ExpandInput, - is_mappable, _needs_run_time_resolution, + is_mappable, ) from airflow.models.taskinstance import TaskInstance +from airflow.triggers.base import run_trigger from airflow.utils import timezone from airflow.utils.context import Context, context_get_outlet_events from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.operator_helpers import ExecutionCallableRunner from airflow.utils.task_instance_session import get_current_task_instance_session -from airflow.triggers.base import run_trigger - if TYPE_CHECKING: import jinja2 from sqlalchemy.orm import Session @@ -117,9 +117,7 @@ def __enter__(self): if self.task_instance.try_number == 0: self.operator.render_template_fields(context=self.context) self.operator.pre_execute(context=self.context) - self.task_instance._run_execute_callback( - context=self.context, task=self.operator - ) + self.task_instance._run_execute_callback(context=self.context, task=self.operator) return self @@ -232,9 +230,7 @@ def __init__( self._operator_class = operator_class self.expand_input = expand_input self.partial_kwargs = partial_kwargs or {} - self.timeout = ( - timeout.total_seconds() if isinstance(timeout, timedelta) else timeout - ) + self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout self._mapped_kwargs: list[dict] = [] if not self.max_active_tis_per_dag: self.max_active_tis_per_dag = os.cpu_count() or 1 @@ -297,10 +293,7 @@ def _run_tasks( failed_tasks: list[TaskInstance] = [] with ThreadPool(processes=self.max_active_tis_per_dag) as pool: - futures = [ - (task, pool.apply_async(self._run_operator, (context, task))) - for task in tasks - ] + futures = [(task, pool.apply_async(self._run_operator, (context, task))) for task in tasks] for task, future in futures: try: @@ -335,9 +328,7 @@ def _run_tasks( self.log.info("Running %s deferred tasks", len(deferred_tasks)) with event_loop() as loop: - for result in 
loop.run_until_complete( - gather(*deferred_tasks, return_exceptions=True) - ): + for result in loop.run_until_complete(gather(*deferred_tasks, return_exceptions=True)): self.log.debug("result: %s", result) if isinstance(result, Exception): @@ -386,11 +377,10 @@ def _run_operator(cls, context: Context, task_instance: TaskInstance): except TaskDeferred as task_deferred: return task_deferred - async def _run_deferrable( - self, context: Context, task: TaskInstance, task_deferred: TaskDeferred - ): - with TriggerExecutor(context=context, task_instance=task) as executor: - return await executor.run(task_deferred) + async def _run_deferrable(self, context: Context, task: TaskInstance, task_deferred: TaskDeferred): + async with self._semaphore: + async with TriggerExecutor(context=context, task_instance=task) as executor: + return await executor.run(task_deferred) def _create_task(self, context: Context, index: int) -> TaskInstance: operator = self._unmap_operator(index) From 73a7779a205ee59622076af6b2590ace2a46d3cd Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 4 Dec 2024 10:11:17 +0100 Subject: [PATCH 59/97] refactor: Make sure the run_deferrable method uses the semaphore to prevent too much parallelism --- airflow/models/iterableoperator.py | 49 ++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py index 102f2ba2ba536..33dded2f9bb0a 100644 --- a/airflow/models/iterableoperator.py +++ b/airflow/models/iterableoperator.py @@ -85,6 +85,7 @@ def __init__( super().__init__() self.__context = context self._task_instance = task_instance + self._is_async_mode: bool = False # Flag to track sync/async mode @property def task_instance(self) -> TaskInstance: @@ -100,6 +101,10 @@ def context(self) -> Context: def operator(self) -> BaseOperator: return self.task_instance.task + @property + def mode(self) -> str: + return "async" if self._is_async_mode else "sync" + @abstractmethod def run(self, *args, **kwargs): raise NotImplementedError() @@ -107,20 +112,26 @@ def run(self, *args, **kwargs): def __enter__(self): if self.log.isEnabledFor(logging.INFO): self.log.info( - "Attempting running task %s of %s for %s with map_index %s.", + "Attempting running task %s of %s for %s with map_index %s in %s mode.", self.task_instance.try_number, self.operator.retries, type(self.operator).__name__, self.task_instance.map_index, + self.mode, ) if self.task_instance.try_number == 0: self.operator.render_template_fields(context=self.context) self.operator.pre_execute(context=self.context) - self.task_instance._run_execute_callback(context=self.context, task=self.operator) - + self.task_instance._run_execute_callback( + context=self.context, task=self.operator + ) return self + async def __aenter__(self): + self._is_async_mode = True + return self.__enter__() + def __exit__(self, exc_type, exc_value, traceback): if exc_value: if isinstance(exc_value, AirflowException): @@ -138,16 +149,19 @@ def __exit__(self, exc_type, exc_value, traceback): raise AirflowRescheduleTaskInstanceException(task=self.task_instance) raise exc_value - self.operator.post_execute(context=self.context) if self.log.isEnabledFor(logging.INFO): self.log.info( - "Task instance %s for %s finished successfully in %s attempts.", + "Task instance %s for %s finished successfully in %s attempts in %s mode.", self.task_instance.map_index, type(self.operator).__name__, self.task_instance.next_try_number, + self.mode, ) + async def 
__aexit__(self, exc_type, exc_value, traceback): + self.__exit__(exc_type, exc_value, traceback) + class OperatorExecutor(TaskExecutor): """ @@ -230,7 +244,9 @@ def __init__( self._operator_class = operator_class self.expand_input = expand_input self.partial_kwargs = partial_kwargs or {} - self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout + self.timeout = ( + timeout.total_seconds() if isinstance(timeout, timedelta) else timeout + ) self._mapped_kwargs: list[dict] = [] if not self.max_active_tis_per_dag: self.max_active_tis_per_dag = os.cpu_count() or 1 @@ -293,7 +309,10 @@ def _run_tasks( failed_tasks: list[TaskInstance] = [] with ThreadPool(processes=self.max_active_tis_per_dag) as pool: - futures = [(task, pool.apply_async(self._run_operator, (context, task))) for task in tasks] + futures = [ + (task, pool.apply_async(self._run_operator, (context, task))) + for task in tasks + ] for task, future in futures: try: @@ -328,7 +347,9 @@ def _run_tasks( self.log.info("Running %s deferred tasks", len(deferred_tasks)) with event_loop() as loop: - for result in loop.run_until_complete(gather(*deferred_tasks, return_exceptions=True)): + for result in loop.run_until_complete( + gather(*deferred_tasks, return_exceptions=True) + ): self.log.debug("result: %s", result) if isinstance(result, Exception): @@ -371,13 +392,15 @@ def _run_tasks( @classmethod def _run_operator(cls, context: Context, task_instance: TaskInstance): - with OperatorExecutor(context=context, task_instance=task_instance) as executor: - try: + try: + with OperatorExecutor(context=context, task_instance=task_instance) as executor: return executor.run() - except TaskDeferred as task_deferred: - return task_deferred + except TaskDeferred as task_deferred: + return task_deferred - async def _run_deferrable(self, context: Context, task: TaskInstance, task_deferred: TaskDeferred): + async def _run_deferrable( + self, context: Context, task: TaskInstance, task_deferred: TaskDeferred + ): async with self._semaphore: async with TriggerExecutor(context=context, task_instance=task) as executor: return await executor.run(task_deferred) From a5b238a5c317e99070e1988edd7fe25835467b16 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 4 Dec 2024 19:24:43 +0100 Subject: [PATCH 60/97] refactor: Updated IterableOperator --- airflow/models/iterableoperator.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py index 33dded2f9bb0a..787c513ebc754 100644 --- a/airflow/models/iterableoperator.py +++ b/airflow/models/iterableoperator.py @@ -177,10 +177,10 @@ def run(self, *args, **kwargs): outlet_events = context_get_outlet_events(self.context) # TODO: change back to operator.execute once ExecutorSafeguard is fixed return ExecutionCallableRunner( - func=self.operator.execute.__wrapped__, + func=self.operator.execute, outlet_events=outlet_events, logger=self.log, - ).run(self.operator, self.context) + ).run(self.context) class TriggerExecutor(TaskExecutor): @@ -393,7 +393,9 @@ def _run_tasks( @classmethod def _run_operator(cls, context: Context, task_instance: TaskInstance): try: - with OperatorExecutor(context=context, task_instance=task_instance) as executor: + with OperatorExecutor( + context=context, task_instance=task_instance + ) as executor: return executor.run() except TaskDeferred as task_deferred: return task_deferred From ed0a9bbdc1112b0393cd8a4dcdc2b4f5dec7bb01 Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 
16 Dec 2024 12:09:07 +0100 Subject: [PATCH 61/97] refactor: Reorganized imports TestIterableOperator --- tests/models/test_iterableoperator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/test_iterableoperator.py b/tests/models/test_iterableoperator.py index 9d8667bfc07a3..def7436dde5a5 100644 --- a/tests/models/test_iterableoperator.py +++ b/tests/models/test_iterableoperator.py @@ -30,9 +30,9 @@ from airflow.exceptions import AirflowSkipException from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG +from airflow.models.iterableoperator import IterableOperator from airflow.models.mappedoperator import MappedOperator from airflow.models.param import ParamsDict -from airflow.models.IterableOperator import IterableOperator from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap import TaskMap from airflow.models.xcom_arg import XComArg @@ -42,6 +42,7 @@ from airflow.utils.task_instance_session import set_current_task_instance_session from airflow.utils.trigger_rule import TriggerRule from airflow.utils.xcom import XCOM_RETURN_KEY + from tests.models import DEFAULT_DATE from tests.test_utils.mapping import expand_mapped_task from tests.test_utils.mock_operators import MockOperator, MockOperatorWithNestedFields, NestedFields From ad73a8beab56ddd2fdd4e6ddb9acde3ba35016ea Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 16 Dec 2024 12:09:54 +0100 Subject: [PATCH 62/97] refactor: Fixed import IterableOperator in MappedOperator --- airflow/models/mappedoperator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 1192d0da82f25..2b5a90b4b3e9a 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -263,7 +263,7 @@ def iterate(self, **mapped_kwargs: OperatorExpandArgument) -> IterableOperator: # We don't want to time out the whole stream operator, we only time out the individual tasks kwargs["timeout"] = kwargs.pop("execution_timeout", None) - return StreamedOperator( + return IterableOperator( **kwargs, operator_class=self.operator_class, expand_input=expand_input, From e2f3065c4c8647fad5029e06970cff199a74f377 Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 16 Dec 2024 12:10:35 +0100 Subject: [PATCH 63/97] refactor: Removed unused import of DAG in exceptions module --- airflow/exceptions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/exceptions.py b/airflow/exceptions.py index 5aab8d093b910..64e00abce0162 100644 --- a/airflow/exceptions.py +++ b/airflow/exceptions.py @@ -31,7 +31,7 @@ if TYPE_CHECKING: from collections.abc import Sized - from airflow.models import DAG, DagRun, TaskInstance + from airflow.models import DagRun, TaskInstance class AirflowException(Exception): From 00d1de1dd7910d8e9f04b4e6c3cdb3ccf86ca081 Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 16 Dec 2024 12:14:15 +0100 Subject: [PATCH 64/97] refactor: Reformatted IterableOperator --- airflow/models/iterableoperator.py | 31 ++++++++++++------------------ 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py index 787c513ebc754..e301f9add8c66 100644 --- a/airflow/models/iterableoperator.py +++ b/airflow/models/iterableoperator.py @@ -123,9 +123,7 @@ def __enter__(self): if self.task_instance.try_number == 0: self.operator.render_template_fields(context=self.context) 
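The __enter__ body reformatted in this hunk is one half of a dual sync/async lifecycle; a simplified, self-contained sketch (the _LifecycleSketch name is hypothetical, for illustration only) of the delegation pattern, where the async protocol methods reuse the sync hooks so first-attempt setup and retry bookkeeping live in one place:

    class _LifecycleSketch:
        def __enter__(self):
            # first-attempt setup goes here (render templates, pre_execute, callbacks)
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            # retry bookkeeping / post_execute goes here
            return False  # never swallow the caller's exception

        async def __aenter__(self):
            # async entry simply delegates to the sync setup
            return self.__enter__()

        async def __aexit__(self, exc_type, exc_value, traceback):
            return self.__exit__(exc_type, exc_value, traceback)
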
self.operator.pre_execute(context=self.context) - self.task_instance._run_execute_callback( - context=self.context, task=self.operator - ) + self.task_instance._run_execute_callback(context=self.context, task=self.operator) return self async def __aenter__(self): @@ -176,6 +174,12 @@ class OperatorExecutor(TaskExecutor): def run(self, *args, **kwargs): outlet_events = context_get_outlet_events(self.context) # TODO: change back to operator.execute once ExecutorSafeguard is fixed + if hasattr(self.operator.execute, "__wrapped__"): + return ExecutionCallableRunner( + func=self.operator.execute.__wrapped__, + outlet_events=outlet_events, + logger=self.log, + ).run(self.operator, self.context) return ExecutionCallableRunner( func=self.operator.execute, outlet_events=outlet_events, @@ -244,9 +248,7 @@ def __init__( self._operator_class = operator_class self.expand_input = expand_input self.partial_kwargs = partial_kwargs or {} - self.timeout = ( - timeout.total_seconds() if isinstance(timeout, timedelta) else timeout - ) + self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout self._mapped_kwargs: list[dict] = [] if not self.max_active_tis_per_dag: self.max_active_tis_per_dag = os.cpu_count() or 1 @@ -309,10 +311,7 @@ def _run_tasks( failed_tasks: list[TaskInstance] = [] with ThreadPool(processes=self.max_active_tis_per_dag) as pool: - futures = [ - (task, pool.apply_async(self._run_operator, (context, task))) - for task in tasks - ] + futures = [(task, pool.apply_async(self._run_operator, (context, task))) for task in tasks] for task, future in futures: try: @@ -347,9 +346,7 @@ def _run_tasks( self.log.info("Running %s deferred tasks", len(deferred_tasks)) with event_loop() as loop: - for result in loop.run_until_complete( - gather(*deferred_tasks, return_exceptions=True) - ): + for result in loop.run_until_complete(gather(*deferred_tasks, return_exceptions=True)): self.log.debug("result: %s", result) if isinstance(result, Exception): @@ -393,16 +390,12 @@ def _run_tasks( @classmethod def _run_operator(cls, context: Context, task_instance: TaskInstance): try: - with OperatorExecutor( - context=context, task_instance=task_instance - ) as executor: + with OperatorExecutor(context=context, task_instance=task_instance) as executor: return executor.run() except TaskDeferred as task_deferred: return task_deferred - async def _run_deferrable( - self, context: Context, task: TaskInstance, task_deferred: TaskDeferred - ): + async def _run_deferrable(self, context: Context, task: TaskInstance, task_deferred: TaskDeferred): async with self._semaphore: async with TriggerExecutor(context=context, task_instance=task) as executor: return await executor.run(task_deferred) From 8a48456bdf021e8d2ad2c5c21066f5fc7b118020 Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 16 Dec 2024 13:29:48 +0100 Subject: [PATCH 65/97] refactor: Current time should be taken before execution of tasks in IterableOperator --- airflow/models/iterableoperator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py index e301f9add8c66..8c16a4a258b91 100644 --- a/airflow/models/iterableoperator.py +++ b/airflow/models/iterableoperator.py @@ -304,6 +304,7 @@ def _run_tasks( tasks: Iterable[TaskInstance], results: list[Any] | None = None, ) -> list[Any]: + now = timezone.utcnow() exception: BaseException | None = None results = results or [] reschedule_date = timezone.utcnow() @@ -369,7 +370,7 @@ def _run_tasks( # 
TaskInstance._set_state(context["ti"], TaskInstanceState.UP_FOR_RETRY, session) # Calculate delay before the next retry - delay = reschedule_date - timezone.utcnow() + delay = reschedule_date - now delay_seconds = ceil(delay.total_seconds()) self.log.debug("delay_seconds: %s", delay_seconds) From cd73affdecb2ecd49a436fa221672defc28fd013 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 28 Jan 2025 13:22:24 +0100 Subject: [PATCH 66/97] refactor: Fixed some static checks --- airflow/models/iterableoperator.py | 33 +++++-------------- .../airflow/sdk/definitions/mappedoperator.py | 11 ++----- 2 files changed, 11 insertions(+), 33 deletions(-) diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py index 177888cbcf7c1..5f71856660652 100644 --- a/airflow/models/iterableoperator.py +++ b/airflow/models/iterableoperator.py @@ -124,9 +124,7 @@ def __enter__(self): if self.task_instance.try_number == 0: self.operator.render_template_fields(context=self.context) self.operator.pre_execute(context=self.context) - self.task_instance._run_execute_callback( - context=self.context, task=self.operator - ) + self.task_instance._run_execute_callback(context=self.context, task=self.operator) return self async def __aenter__(self): @@ -251,9 +249,7 @@ def __init__( self._operator_class = operator_class self.expand_input = expand_input self.partial_kwargs = partial_kwargs or {} - self.timeout = ( - timeout.total_seconds() if isinstance(timeout, timedelta) else timeout - ) + self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout self._mapped_kwargs: list[dict] = [] if not self.max_active_tis_per_dag: self.max_active_tis_per_dag = os.cpu_count() or 1 @@ -296,9 +292,7 @@ def _resolve(self, value, context: Context, session: Session): def _resolve_expand_input(self, context: Context, session: Session): if isinstance(self.expand_input.value, XComArg): - resolved_input = self.expand_input.value.resolve( - context=context, session=session - ) + resolved_input = self.expand_input.value.resolve(context=context, session=session) else: resolved_input = self.expand_input.value @@ -309,9 +303,7 @@ def _resolve_expand_input(self, context: Context, session: Session): ) else: - value = self._resolve( - value=resolved_input, context=context, session=session - ) + value = self._resolve(value=resolved_input, context=context, session=session) if isinstance(value, dict): for key, item in self._resolve( @@ -346,10 +338,7 @@ def _run_tasks( failed_tasks: list[TaskInstance] = [] with ThreadPool(processes=self.max_active_tis_per_dag) as pool: - futures = [ - (task, pool.apply_async(self._run_operator, (context, task))) - for task in tasks - ] + futures = [(task, pool.apply_async(self._run_operator, (context, task))) for task in tasks] for task, future in futures: try: @@ -384,9 +373,7 @@ def _run_tasks( self.log.info("Running %s deferred tasks", len(deferred_tasks)) with event_loop() as loop: - for result in loop.run_until_complete( - gather(*deferred_tasks, return_exceptions=True) - ): + for result in loop.run_until_complete(gather(*deferred_tasks, return_exceptions=True)): self.log.debug("result: %s", result) if isinstance(result, Exception): @@ -430,16 +417,12 @@ def _run_tasks( @classmethod def _run_operator(cls, context: Context, task_instance: TaskInstance): try: - with OperatorExecutor( - context=context, task_instance=task_instance - ) as executor: + with OperatorExecutor(context=context, task_instance=task_instance) as executor: return executor.run() except 
TaskDeferred as task_deferred: return task_deferred - async def _run_deferrable( - self, context: Context, task: TaskInstance, task_deferred: TaskDeferred - ): + async def _run_deferrable(self, context: Context, task: TaskInstance, task_deferred: TaskDeferred): async with self._semaphore: async with TriggerExecutor(context=context, task_instance=task) as executor: return await executor.run(task_deferred) diff --git a/task_sdk/src/airflow/sdk/definitions/mappedoperator.py b/task_sdk/src/airflow/sdk/definitions/mappedoperator.py index b32fd0c16dd5f..4ac35fa4f4384 100644 --- a/task_sdk/src/airflow/sdk/definitions/mappedoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/mappedoperator.py @@ -86,7 +86,7 @@ TaskStateChangeCallbackAttrType = Union[None, TaskStateChangeCallback, list[TaskStateChangeCallback]] -ValidationSource = Union[Literal["expand"], Literal["partial"]] +ValidationSource = Union[Literal["expand"], Literal["iterate"], Literal["partial"]] def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None: @@ -242,24 +242,19 @@ def iterate(self, **mapped_kwargs: OperatorExpandArgument) -> IterableOperator: if not mapped_kwargs: raise TypeError("no arguments to expand against") validate_mapping_kwargs(self.operator_class, "iterate", mapped_kwargs) - prevent_duplicates( - self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified" - ) + prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified") return self._iterate(DictOfListsExpandInput(mapped_kwargs)) def iterate_kwargs(self, kwargs: OperatorExpandKwargsArgument) -> IterableOperator: if isinstance(kwargs, Sequence): for item in kwargs: if not isinstance(item, (XComArg, Mapping)): - raise TypeError( - f"expected XComArg or list[dict], not {type(kwargs).__name__}" - ) + raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") elif not isinstance(kwargs, XComArg): raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") return self._iterate(ListOfDictsExpandInput(kwargs)) def _iterate(self, expand_input) -> IterableOperator: - from airflow.models.iterableoperator import IterableOperator ensure_xcomarg_return_value(expand_input.value) From 9cb1442f2e830d7cf1558cabf316c60138c58d44 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 28 Jan 2025 14:08:44 +0100 Subject: [PATCH 67/97] refactor: Fixed static checks --- airflow/models/iterableoperator.py | 5 +---- tests/models/test_iterableoperator.py | 1 - 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py index 5f71856660652..221fca1fb6467 100644 --- a/airflow/models/iterableoperator.py +++ b/airflow/models/iterableoperator.py @@ -298,10 +298,7 @@ def _resolve_expand_input(self, context: Context, session: Session): if isinstance(resolved_input, _MapResult): for value in resolved_input: - self._mapped_kwargs.append( - self._resolve(value=value, context=context, session=session) - ) - + self._mapped_kwargs.append(self._resolve(value=value, context=context, session=session)) else: value = self._resolve(value=resolved_input, context=context, session=session) diff --git a/tests/models/test_iterableoperator.py b/tests/models/test_iterableoperator.py index cfdf305fbb30f..0413676d1c866 100644 --- a/tests/models/test_iterableoperator.py +++ b/tests/models/test_iterableoperator.py @@ -33,7 +33,6 @@ from airflow.models.taskinstance import TaskInstance from airflow.models.taskmap 
import TaskMap from airflow.providers.standard.operators.python import PythonOperator - from airflow.utils.state import TaskInstanceState from airflow.utils.task_group import TaskGroup from airflow.utils.task_instance_session import set_current_task_instance_session From 0f431353128cfb2cf48024a1dfeaa9f720d61401 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 28 Jan 2025 14:10:41 +0100 Subject: [PATCH 68/97] refactor: Fixed import of from _MapResult --- airflow/models/iterableoperator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py index 221fca1fb6467..51eebee2127c8 100644 --- a/airflow/models/iterableoperator.py +++ b/airflow/models/iterableoperator.py @@ -46,7 +46,7 @@ is_mappable, ) from airflow.models.taskinstance import TaskInstance -from airflow.models.xcom_arg import _MapResult +from airflow.sdk.definitions.xcom_arg import _MapResult from airflow.triggers.base import run_trigger from airflow.utils import timezone from airflow.utils.context import Context, context_get_outlet_events From 75d16805cc0cebe6e6a02a838a6a10e0868c73e8 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 28 Jan 2025 14:12:45 +0100 Subject: [PATCH 69/97] refactor: Fixed import of BaseOperator --- airflow/models/iterableoperator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py index 51eebee2127c8..d2541c8c9a6a5 100644 --- a/airflow/models/iterableoperator.py +++ b/airflow/models/iterableoperator.py @@ -39,13 +39,13 @@ TaskDeferred, ) from airflow.models.abstractoperator import DEFAULT_TASK_EXECUTION_TIMEOUT -from airflow.models.baseoperator import BaseOperator from airflow.models.expandinput import ( ExpandInput, _needs_run_time_resolution, is_mappable, ) from airflow.models.taskinstance import TaskInstance +from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.xcom_arg import _MapResult from airflow.triggers.base import run_trigger from airflow.utils import timezone From 4f36167b8c15a83f52924655d71eeab64fa1f01f Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 28 Jan 2025 14:29:32 +0100 Subject: [PATCH 70/97] refactor: Added next_callable method on BaseOperator --- .../airflow/sdk/definitions/baseoperator.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py index e7ecec69411ba..1af3822483a5d 100644 --- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py @@ -33,6 +33,7 @@ import attrs +from airflow.exceptions import TaskDeferralError from airflow.models.param import ParamsDict from airflow.sdk.definitions._internal.abstractoperator import ( DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, @@ -1434,3 +1435,20 @@ def render_template_fields( if not jinja_env: jinja_env = self.get_template_env() self._do_render_template_fields(self, self.template_fields, context, jinja_env, set()) + + @classmethod + def next_callable(cls, operator, next_method, next_kwargs) -> Callable[..., Any]: + """Get the next callable from given operator.""" + # __fail__ is a special signal value for next_method that indicates + # this task was scheduled specifically to fail. 
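+        # A trigger signals this by resuming the task with next_method="__fail__",
+        # passing an optional traceback and error message through next_kwargs.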
+ if next_method == "__fail__": + next_kwargs = next_kwargs or {} + traceback = next_kwargs.get("traceback") + if traceback is not None: + cls.logger().error("Trigger failed:\n%s", "\n".join(traceback)) + raise TaskDeferralError(next_kwargs.get("error", "Unknown")) + # Grab the callable off the Operator/Task and add in any kwargs + execute_callable = getattr(operator, next_method) + if next_kwargs: + execute_callable = partial(execute_callable, **next_kwargs) + return execute_callable From 5a403633908668982820855f05441214420d47e0 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 28 Jan 2025 17:00:31 +0100 Subject: [PATCH 71/97] refactor: Fixed more static checks --- airflow/models/iterableoperator.py | 2 ++ task_sdk/src/airflow/sdk/definitions/mappedoperator.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py index d2541c8c9a6a5..a2fd8d38fce48 100644 --- a/airflow/models/iterableoperator.py +++ b/airflow/models/iterableoperator.py @@ -78,6 +78,8 @@ def event_loop() -> Generator[AbstractEventLoop, None, None]: class TaskExecutor(LoggingMixin): + """Base class to run an operator or trigger with given task context and task instance.""" + def __init__( self, context: Context, diff --git a/task_sdk/src/airflow/sdk/definitions/mappedoperator.py b/task_sdk/src/airflow/sdk/definitions/mappedoperator.py index 4ac35fa4f4384..a9db41ff9824a 100644 --- a/task_sdk/src/airflow/sdk/definitions/mappedoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/mappedoperator.py @@ -32,6 +32,8 @@ ListOfDictsExpandInput, is_mappable, ) +from airflow.models.xcom_arg import XComArg +from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions._internal.abstractoperator import ( DEFAULT_EXECUTOR, DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, @@ -74,8 +76,6 @@ ) from airflow.models.iterableoperator import IterableOperator from airflow.models.param import ParamsDict - from airflow.models.xcom_arg import XComArg - from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG from airflow.sdk.types import Operator from airflow.ti_deps.deps.base_ti_dep import BaseTIDep From 87034e747c44626193a01f26fd02cc8e571707a0 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 29 Jan 2025 10:14:01 +0100 Subject: [PATCH 72/97] refactor: Fixed import BaseOperator --- airflow/models/iterableoperator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py index a2fd8d38fce48..0490b96f988ae 100644 --- a/airflow/models/iterableoperator.py +++ b/airflow/models/iterableoperator.py @@ -38,6 +38,7 @@ AirflowTaskTimeout, TaskDeferred, ) +from airflow.models import BaseOperator from airflow.models.abstractoperator import DEFAULT_TASK_EXECUTION_TIMEOUT from airflow.models.expandinput import ( ExpandInput, @@ -45,7 +46,6 @@ is_mappable, ) from airflow.models.taskinstance import TaskInstance -from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.xcom_arg import _MapResult from airflow.triggers.base import run_trigger from airflow.utils import timezone From 8020251b6cb326d6fefa14bdb2cdf74b32eae22e Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 29 Jan 2025 10:19:48 +0100 Subject: [PATCH 73/97] refactor: Fixed some imports --- airflow/models/iterableoperator.py | 10 ++++++++-- task_sdk/src/airflow/sdk/definitions/mappedoperator.py | 4 ++-- 2 files changed, 10 
insertions(+), 4 deletions(-) diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py index 0490b96f988ae..b5bfe5ba8fd81 100644 --- a/airflow/models/iterableoperator.py +++ b/airflow/models/iterableoperator.py @@ -31,7 +31,6 @@ from time import sleep from typing import TYPE_CHECKING, Any -from airflow import XComArg from airflow.exceptions import ( AirflowException, AirflowRescheduleTaskInstanceException, @@ -49,11 +48,18 @@ from airflow.sdk.definitions.xcom_arg import _MapResult from airflow.triggers.base import run_trigger from airflow.utils import timezone -from airflow.utils.context import Context, context_get_outlet_events from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.operator_helpers import ExecutionCallableRunner from airflow.utils.task_instance_session import get_current_task_instance_session +try: + from airflow.sdk.definitions.context import Context, context_get_outlet_events + from airflow.sdk.definitions.xcom_args import XComArg +except ImportError: + # TODO: Remove once provider drops support for Airflow 2 + from airflow.utils.context import Context, context_get_outlet_events + from airflow import XComArg + if TYPE_CHECKING: import jinja2 from sqlalchemy.orm import Session diff --git a/task_sdk/src/airflow/sdk/definitions/mappedoperator.py b/task_sdk/src/airflow/sdk/definitions/mappedoperator.py index a9db41ff9824a..4ac35fa4f4384 100644 --- a/task_sdk/src/airflow/sdk/definitions/mappedoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/mappedoperator.py @@ -32,8 +32,6 @@ ListOfDictsExpandInput, is_mappable, ) -from airflow.models.xcom_arg import XComArg -from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions._internal.abstractoperator import ( DEFAULT_EXECUTOR, DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST, @@ -76,6 +74,8 @@ ) from airflow.models.iterableoperator import IterableOperator from airflow.models.param import ParamsDict + from airflow.models.xcom_arg import XComArg + from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG from airflow.sdk.types import Operator from airflow.ti_deps.deps.base_ti_dep import BaseTIDep From 8bcffe6db166da202ac810afc05cc1eeb6211ea9 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 29 Jan 2025 11:20:00 +0100 Subject: [PATCH 74/97] refactor: Try to fix some mypy issues --- airflow/models/iterableoperator.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py index b5bfe5ba8fd81..0fdb5122ae7ee 100644 --- a/airflow/models/iterableoperator.py +++ b/airflow/models/iterableoperator.py @@ -46,19 +46,21 @@ ) from airflow.models.taskinstance import TaskInstance from airflow.sdk.definitions.xcom_arg import _MapResult +from airflow.sdk.definitions._internal.abstractoperator import Operator from airflow.triggers.base import run_trigger from airflow.utils import timezone +from airflow.utils.context import context_get_outlet_events from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.operator_helpers import ExecutionCallableRunner from airflow.utils.task_instance_session import get_current_task_instance_session try: - from airflow.sdk.definitions.context import Context, context_get_outlet_events + from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.xcom_args import XComArg except ImportError: # TODO: Remove once provider drops support for Airflow 2 - from 
airflow.utils.context import Context, context_get_outlet_events
     from airflow import XComArg
+    from airflow.utils.context import Context

 if TYPE_CHECKING:
     import jinja2
     from sqlalchemy.orm import Session
@@ -107,7 +109,7 @@ def context(self) -> Context:
         return {**self.__context, **{"ti": self.task_instance}}

     @property
-    def operator(self) -> BaseOperator:
+    def operator(self) -> Operator:
         return self.task_instance.task

     @property

From dae0a913fe0c8e92179d0fe911b0b1b88770ed66 Mon Sep 17 00:00:00 2001
From: David Blain
Date: Wed, 29 Jan 2025 11:21:19 +0100
Subject: [PATCH 75/97] refactor: Removed test not applicable to IterableOperator

---
 tests/models/test_iterableoperator.py | 24 ------------------------
 1 file changed, 24 deletions(-)

diff --git a/tests/models/test_iterableoperator.py b/tests/models/test_iterableoperator.py
index 0413676d1c866..beb9d7904f07e 100644
--- a/tests/models/test_iterableoperator.py
+++ b/tests/models/test_iterableoperator.py
@@ -567,30 +567,6 @@ def test_expand_mapped_task_instance_with_named_index(
     assert indices == expected_rendered_names


-@pytest.mark.parametrize(
-    "create_mapped_task",
-    [
-        pytest.param(_create_mapped_with_name_template_classic, id="classic"),
-        pytest.param(_create_mapped_with_name_template_taskflow, id="taskflow"),
-    ],
-)
-def test_expand_mapped_task_task_instance_mutation_hook(dag_maker, session, create_mapped_task) -> None:
-    """Test that the tast_instance_mutation_hook is called."""
-    expected_map_index = [0, 1, 2]
-
-    with dag_maker(session=session):
-        task1 = BaseOperator(task_id="op1")
-        mapped = MockOperator.partial(task_id="task_2").iterate(arg2=task1.output)
-
-    dr = dag_maker.create_dagrun()
-
-    with mock.patch("airflow.settings.task_instance_mutation_hook") as mock_hook:
-        expand_mapped_task(mapped, dr.run_id, task1.task_id, length=len(expected_map_index), session=session)
-
-    for index, call in enumerate(mock_hook.call_args_list):
-        assert call.args[0].map_index == expected_map_index[index]
-
-
 @pytest.mark.parametrize(
     "map_index, expected",
     [

From e2727bbeb57461296f7b3ac967cd2759d18aa275 Mon Sep 17 00:00:00 2001
From: David Blain
Date: Wed, 29 Jan 2025 11:26:54 +0100
Subject: [PATCH 76/97] refactor: State can no longer be assigned from the context to the TaskInstance

---
 airflow/models/iterableoperator.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py
index 0fdb5122ae7ee..d59a10f07ccb6 100644
--- a/airflow/models/iterableoperator.py
+++ b/airflow/models/iterableoperator.py
@@ -439,7 +439,6 @@ def _create_task(self, context: Context, index: int) -> TaskInstance:
         task_instance = TaskInstance(
             task=operator,
             run_id=context["ti"].run_id,
-            state=context["ti"].state,
             map_index=index,
         )
         return task_instance

From 0c80f9ade5eacabf611ebd4483d4dff4c28bc8bc Mon Sep 17 00:00:00 2001
From: David Blain
Date: Wed, 29 Jan 2025 11:31:06 +0100
Subject: [PATCH 77/97] refactor: Removed stream from decorators, as the method is deprecated anyway; iterate and iterate_kwargs should be implemented there correctly instead, but at the moment they are not implemented yet and raise NotImplementedError.
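
For reference, a rough sketch of the task-flow usage this is meant to enable
once implemented (hypothetical API, mirroring how expand() works today):

    from airflow.decorators import task

    @task
    def double(value: int) -> int:
        # each value would run as its own sub-task inside a single task instance
        return value * 2

    double.iterate(value=[1, 2, 3])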
--- airflow/decorators/base.py | 34 ++++------------------------------ 1 file changed, 4 insertions(+), 30 deletions(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 97e3c291ea582..4d694583fd5a7 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -533,37 +533,11 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg: ) return XComArg(operator=operator) - def stream(self, **mapped_kwargs: OperatorExpandArgument) -> XComArg: - from airflow.models.streamedoperator import StreamedOperator + def iterate(self, **mapped_kwargs: OperatorExpandArgument) -> XComArg: + raise NotImplementedError - if not mapped_kwargs: - raise TypeError("no arguments to expand against") - prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified") - - expand_input = DictOfListsExpandInput(mapped_kwargs) - ensure_xcomarg_return_value(expand_input.value) - - partial_kwargs = self.kwargs.copy() - task_id = partial_kwargs.pop("task_id") - dag = partial_kwargs.pop("dag", None) or DagContext.get_current_dag() - task_group = partial_kwargs.pop("task_group", None) or TaskGroupContext.get_current_task_group(dag) - start_date = partial_kwargs.pop("start_date") - end_date = partial_kwargs.pop("end_date") - max_active_tis_per_dag = partial_kwargs.pop("max_active_tis_per_dag", None) - - operator = StreamedOperator( - task_id=task_id, - dag=dag, - task_group=task_group, - start_date=start_date, - end_date=end_date, - max_active_tis_per_dag=max_active_tis_per_dag, - operator_class=self.operator_class, - expand_input=expand_input, - retries=0, - partial_kwargs=self.kwargs.copy(), - ) - return XComArg(operator=operator) + def iterate_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg: + raise NotImplementedError def partial(self, **kwargs: Any) -> _TaskDecorator[FParams, FReturn, OperatorSubclass]: self._validate_arg_names("partial", kwargs) From 77489e131d60b092c6c0325622977f836127fc5f Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 29 Jan 2025 11:49:05 +0100 Subject: [PATCH 78/97] refactor: Next callable method should be instance method instead of class method on BaseOperator --- airflow/models/baseoperator.py | 11 +++++------ airflow/models/iterableoperator.py | 4 ++-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 3047a86f22aed..26dc1c4f3fc1f 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -154,7 +154,7 @@ def decorator(cls, func): @wraps(func) def wrapper(self, *args, **kwargs): from airflow.decorators.base import DecoratedOperator - from airflow.models.streamedoperator import StreamedOperator + from airflow.models.iterableoperator import IterableOperator sentinel_key = f"{self.__class__.__name__}__sentinel" sentinel = kwargs.pop(sentinel_key, None) @@ -170,7 +170,7 @@ def wrapper(self, *args, **kwargs): not cls.test_mode and not sentinel == _sentinel and not isinstance(self, DecoratedOperator) - and not isinstance(self, StreamedOperator) + and not isinstance(self, IterableOperator) ): message = f"{self.__class__.__name__}.{func.__name__} cannot be called outside TaskInstance!" 
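             # Nested operator invocations are only tolerated when the operator
             # explicitly opts in via allow_nested_operators; otherwise this raises.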
if not self.allow_nested_operators: @@ -761,8 +761,7 @@ def defer( """ raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout) - @classmethod - def next_callable(cls, operator, next_method, next_kwargs) -> Callable[..., Any]: + def next_callable(self, next_method, next_kwargs) -> Callable[..., Any]: """Get the next callable from given operator.""" # __fail__ is a special signal value for next_method that indicates # this task was scheduled specifically to fail. @@ -770,13 +769,13 @@ def next_callable(cls, operator, next_method, next_kwargs) -> Callable[..., Any] next_kwargs = next_kwargs or {} traceback = next_kwargs.get("traceback") if traceback is not None: - cls.log.error("Trigger failed:\n%s", "\n".join(traceback)) + self.log.error("Trigger failed:\n%s", "\n".join(traceback)) if (error := next_kwargs.get("error", "Unknown")) == TriggerFailureReason.TRIGGER_TIMEOUT: raise TaskDeferralTimeout(error) else: raise TaskDeferralError(error) # Grab the callable off the Operator/Task and add in any kwargs - execute_callable = getattr(operator, next_method) + execute_callable = getattr(self, next_method) if next_kwargs: execute_callable = functools.partial(execute_callable, **next_kwargs) return execute_callable diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py index d59a10f07ccb6..2a302999eb00a 100644 --- a/airflow/models/iterableoperator.py +++ b/airflow/models/iterableoperator.py @@ -219,8 +219,8 @@ async def run(self, task_deferred: TaskDeferred): if task_deferred.method_name: try: - next_method = BaseOperator.next_callable( - self.operator, task_deferred.method_name, task_deferred.kwargs + next_method = self.operator.next_callable( + task_deferred.method_name, task_deferred.kwargs ) outlet_events = context_get_outlet_events(self.context) return ExecutionCallableRunner( From 454b8f4e0e95ab612d0cc32c5d3780d8d6532b41 Mon Sep 17 00:00:00 2001 From: David Blain Date: Fri, 31 Jan 2025 14:50:38 +0100 Subject: [PATCH 79/97] refactor: Only keep sdk imports --- airflow/models/iterableoperator.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py index 2a302999eb00a..6b5b56a62c2e2 100644 --- a/airflow/models/iterableoperator.py +++ b/airflow/models/iterableoperator.py @@ -45,8 +45,9 @@ is_mappable, ) from airflow.models.taskinstance import TaskInstance -from airflow.sdk.definitions.xcom_arg import _MapResult from airflow.sdk.definitions._internal.abstractoperator import Operator +from airflow.sdk.definitions.context import Context +from airflow.sdk.definitions.xcom_arg import _MapResult, XComArg from airflow.triggers.base import run_trigger from airflow.utils import timezone from airflow.utils.context import context_get_outlet_events @@ -54,14 +55,6 @@ from airflow.utils.operator_helpers import ExecutionCallableRunner from airflow.utils.task_instance_session import get_current_task_instance_session -try: - from airflow.sdk.definitions.context import Context - from airflow.sdk.definitions.xcom_args import XComArg -except ImportError: - # TODO: Remove once provider drops support for Airflow 2 - from airflow import XComArg - from airflow.utils.context import Contexts - if TYPE_CHECKING: import jinja2 from sqlalchemy.orm import Session From 44876f382d9c6c5a17a172a306598cd8276c49ac Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 13 Feb 2025 10:39:29 +0100 Subject: [PATCH 80/97] refactor: Updated IterableOperator which optimizes 
execution of deferrable tasks --- airflow/models/iterableoperator.py | 105 +++++++++++++++-------------- 1 file changed, 55 insertions(+), 50 deletions(-) diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py index 6b5b56a62c2e2..c01e0565b76f0 100644 --- a/airflow/models/iterableoperator.py +++ b/airflow/models/iterableoperator.py @@ -27,7 +27,7 @@ from datetime import timedelta from math import ceil from multiprocessing import TimeoutError -from multiprocessing.pool import ThreadPool +from multiprocessing.pool import ApplyResult, ThreadPool from time import sleep from typing import TYPE_CHECKING, Any @@ -226,7 +226,7 @@ async def run(self, task_deferred: TaskDeferred): class IterableOperator(BaseOperator): - """Object representing a streamed operator in a DAG.""" + """Object representing an iterable operator in a DAG.""" _operator_class: type[BaseOperator] expand_input: ExpandInput @@ -330,62 +330,67 @@ def _run_tasks( tasks: Iterable[TaskInstance], results: list[Any] | None = None, ) -> list[Any]: - now = timezone.utcnow() exception: BaseException | None = None results = results or [] reschedule_date = timezone.utcnow() - deferred_tasks: list[Future] = [] failed_tasks: list[TaskInstance] = [] with ThreadPool(processes=self.max_active_tis_per_dag) as pool: - futures = [(task, pool.apply_async(self._run_operator, (context, task))) for task in tasks] - - for task, future in futures: - try: - result = future.get(timeout=self.timeout) - if isinstance(result, TaskDeferred): - deferred_tasks.append( - ensure_future( - self._run_deferrable( - context=context, - task=task, - task_deferred=result, + futures = {pool.apply_async(self._run_operator, (context, task)): task for task in tasks} + + while futures: + self.log.info("Number of remaining futures: %s", len(futures)) + + deferred_tasks: list[Future] = [] + + for future in filter(ApplyResult.ready, list(futures.keys())): + task = futures.pop(future) + + try: + result = future.get(timeout=self.timeout) + if isinstance(result, TaskDeferred): + deferred_tasks.append( + ensure_future( + self._run_deferrable( + context=context, + task=task, + task_deferred=result, + ) ) ) - ) - elif result: - results.append(result) - except TimeoutError as e: - self.log.warning("A timeout occurred for task_id %s", task.task_id) - if task.next_try_number > self.retries: - exception = AirflowTaskTimeout(e) - else: - reschedule_date = task.next_retry_datetime() - failed_tasks.append(task) - except AirflowRescheduleTaskInstanceException as e: - reschedule_date = e.reschedule_date - failed_tasks.append(e.task) - except AirflowException as e: - self.log.error("An exception occurred for task_id %s", task.task_id) - exception = e - - if deferred_tasks: - self.log.info("Running %s deferred tasks", len(deferred_tasks)) - - with event_loop() as loop: - for result in loop.run_until_complete(gather(*deferred_tasks, return_exceptions=True)): - self.log.debug("result: %s", result) - - if isinstance(result, Exception): - if isinstance(result, AirflowRescheduleTaskInstanceException): - reschedule_date = result.reschedule_date - failed_tasks.append(result.task) + elif result: + results.append(result) + except TimeoutError as e: + self.log.warning("A timeout occurred for task_id %s", task.task_id) + if task.next_try_number > self.retries: + exception = AirflowTaskTimeout(e) else: - exception = result - elif result: - results.append(result) - - deferred_tasks.clear() + reschedule_date = task.next_retry_datetime() + failed_tasks.append(task) + except 
AirflowRescheduleTaskInstanceException as e: + reschedule_date = e.reschedule_date + failed_tasks.append(e.task) + except AirflowException as e: + self.log.error("An exception occurred for task_id %s", task.task_id) + exception = e + + if deferred_tasks: + self.log.info("Running %s deferred tasks", len(deferred_tasks)) + + with event_loop() as loop: + for result in loop.run_until_complete(gather(*deferred_tasks, return_exceptions=True)): + self.log.debug("result: %s", result) + + if isinstance(result, Exception): + if isinstance(result, AirflowRescheduleTaskInstanceException): + reschedule_date = result.reschedule_date + failed_tasks.append(result.task) + else: + exception = result + elif result: + results.append(result) + elif futures: + sleep(min(len(futures), os.cpu_count())) if not failed_tasks: if exception: @@ -396,7 +401,7 @@ def _run_tasks( # TaskInstance._set_state(context["ti"], TaskInstanceState.UP_FOR_RETRY, session) # Calculate delay before the next retry - delay = reschedule_date - now + delay = reschedule_date - timezone.utcnow() delay_seconds = ceil(delay.total_seconds()) self.log.debug("delay_seconds: %s", delay_seconds) From d30c568913a4e129c9924033ab6f31bb0f0fb537 Mon Sep 17 00:00:00 2001 From: David Blain Date: Fri, 14 Feb 2025 13:48:19 +0100 Subject: [PATCH 81/97] refactor: Set state on TaskInstance otherwise xcom push and pulls won't work with map_indexes --- airflow/models/iterableoperator.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py index c01e0565b76f0..a2ac0f732fc29 100644 --- a/airflow/models/iterableoperator.py +++ b/airflow/models/iterableoperator.py @@ -53,6 +53,7 @@ from airflow.utils.context import context_get_outlet_events from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.operator_helpers import ExecutionCallableRunner +from airflow.utils.state import TaskInstanceState from airflow.utils.task_instance_session import get_current_task_instance_session if TYPE_CHECKING: @@ -127,6 +128,7 @@ def __enter__(self): if self.task_instance.try_number == 0: self.operator.render_template_fields(context=self.context) self.operator.pre_execute(context=self.context) + self.task_instance.set_state(TaskInstanceState.RUNNING) self.task_instance._run_execute_callback(context=self.context, task=self.operator) return self @@ -148,10 +150,12 @@ def __exit__(self, exc_type, exc_value, traceback): self.task_instance.try_number += 1 self.task_instance.end_date = timezone.utcnow() + self.task_instance.set_state(TaskInstanceState.FAILED) raise AirflowRescheduleTaskInstanceException(task=self.task_instance) raise exc_value self.operator.post_execute(context=self.context) + self.task_instance.set_state(TaskInstanceState.SUCCESS) if self.log.isEnabledFor(logging.INFO): self.log.info( "Task instance %s for %s finished successfully in %s attempts in %s mode.", From f18cd4d211bdd1d15dfdefeb1370813757db82ae Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 25 Feb 2025 21:58:21 +0100 Subject: [PATCH 82/97] refactor: Fixed task_id generation without duplicated nested task group ids and only returned XCom at end of execution instead of accumulation during execution which leads to more memory consumption --- airflow/models/iterableoperator.py | 101 +++++++++--------- .../airflow/sdk/definitions/mappedoperator.py | 6 +- 2 files changed, 55 insertions(+), 52 deletions(-) diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py index 
a2ac0f732fc29..94148df250cf1 100644 --- a/airflow/models/iterableoperator.py +++ b/airflow/models/iterableoperator.py @@ -47,7 +47,7 @@ from airflow.models.taskinstance import TaskInstance from airflow.sdk.definitions._internal.abstractoperator import Operator from airflow.sdk.definitions.context import Context -from airflow.sdk.definitions.xcom_arg import _MapResult, XComArg +from airflow.sdk.definitions.xcom_arg import XComArg, _MapResult from airflow.triggers.base import run_trigger from airflow.utils import timezone from airflow.utils.context import context_get_outlet_events @@ -55,6 +55,7 @@ from airflow.utils.operator_helpers import ExecutionCallableRunner from airflow.utils.state import TaskInstanceState from airflow.utils.task_instance_session import get_current_task_instance_session +from airflow.utils.xcom import XCOM_RETURN_KEY if TYPE_CHECKING: import jinja2 @@ -88,16 +89,18 @@ def __init__( task_instance: TaskInstance, ): super().__init__() - self.__context = context + self.__context = dict(context.items()) self._task_instance = task_instance self._is_async_mode: bool = False # Flag to track sync/async mode @property def task_instance(self) -> TaskInstance: - # TODO: If we want a specialized TaskInstance for the StreamedOperator, - # we could inherit from TaskInstanceDependencies return self._task_instance + @property + def task_index(self) -> int: + return int(self._task_instance.task_id.rsplit("_", 1)[-1]) + @property def context(self) -> Context: return {**self.__context, **{"ti": self.task_instance}} @@ -120,15 +123,15 @@ def __enter__(self): "Attempting running task %s of %s for %s with map_index %s in %s mode.", self.task_instance.try_number, self.operator.retries, - type(self.operator).__name__, - self.task_instance.map_index, + self.task_instance.task_id, + self.task_index, self.mode, ) if self.task_instance.try_number == 0: self.operator.render_template_fields(context=self.context) self.operator.pre_execute(context=self.context) - self.task_instance.set_state(TaskInstanceState.RUNNING) + self.task_instance.set_state(TaskInstanceState.SCHEDULED) self.task_instance._run_execute_callback(context=self.context, task=self.operator) return self @@ -142,25 +145,26 @@ def __exit__(self, exc_type, exc_value, traceback): if self.task_instance.next_try_number > self.operator.retries: self.log.error( "Max number of attempts for %s with map_index %s failed due to: %s", - type(self.operator).__name__, - self.task_instance.map_index, + self.task_instance.task_id, + self.task_index, exc_value, ) raise exc_value self.task_instance.try_number += 1 self.task_instance.end_date = timezone.utcnow() - self.task_instance.set_state(TaskInstanceState.FAILED) - + self.task_instance.set_state(TaskInstanceState.UP_FOR_RESCHEDULE) raise AirflowRescheduleTaskInstanceException(task=self.task_instance) + + self.task_instance.set_state(TaskInstanceState.FAILED) raise exc_value self.operator.post_execute(context=self.context) self.task_instance.set_state(TaskInstanceState.SUCCESS) if self.log.isEnabledFor(logging.INFO): self.log.info( "Task instance %s for %s finished successfully in %s attempts in %s mode.", - self.task_instance.map_index, - type(self.operator).__name__, + self.task_index, + self.task_instance.task_id, self.task_instance.next_try_number, self.mode, ) @@ -216,9 +220,7 @@ async def run(self, task_deferred: TaskDeferred): if task_deferred.method_name: try: - next_method = self.operator.next_callable( - task_deferred.method_name, task_deferred.kwargs - ) + next_method = 
self.operator.next_callable(task_deferred.method_name, task_deferred.kwargs) outlet_events = context_get_outlet_events(self.context) return ExecutionCallableRunner( func=next_method, @@ -271,15 +273,15 @@ def _get_specified_expand_input(self) -> ExpandInput: return self.expand_input def _unmap_operator(self, index): - self.log.debug("index: %s", index) kwargs = { **self.partial_kwargs, **{"task_id": f"{self.task_id}_{index}"}, **self._mapped_kwargs[index], } + self.log.debug("index: %s", index) self.log.debug("kwargs: %s", kwargs) self.log.debug("operator_class: %s", self._operator_class) - return self._operator_class(**kwargs) + return self._operator_class(**kwargs, _airflow_from_mapped=True) def _resolve(self, value, context: Context, session: Session): if isinstance(value, dict): @@ -332,10 +334,8 @@ def _run_tasks( self, context: Context, tasks: Iterable[TaskInstance], - results: list[Any] | None = None, - ) -> list[Any]: + ) -> Iterable[Any] | None: exception: BaseException | None = None - results = results or [] reschedule_date = timezone.utcnow() failed_tasks: list[TaskInstance] = [] @@ -357,13 +357,11 @@ def _run_tasks( ensure_future( self._run_deferrable( context=context, - task=task, + task_instance=task, task_deferred=result, ) ) ) - elif result: - results.append(result) except TimeoutError as e: self.log.warning("A timeout occurred for task_id %s", task.task_id) if task.next_try_number > self.retries: @@ -382,7 +380,9 @@ def _run_tasks( self.log.info("Running %s deferred tasks", len(deferred_tasks)) with event_loop() as loop: - for result in loop.run_until_complete(gather(*deferred_tasks, return_exceptions=True)): + for result in loop.run_until_complete( + gather(*deferred_tasks, return_exceptions=True) + ): self.log.debug("result: %s", result) if isinstance(result, Exception): @@ -391,18 +391,18 @@ def _run_tasks( failed_tasks.append(result.task) else: exception = result - elif result: - results.append(result) elif futures: sleep(min(len(futures), os.cpu_count())) if not failed_tasks: if exception: raise exception - return results - - # session = get_current_task_instance_session() - # TaskInstance._set_state(context["ti"], TaskInstanceState.UP_FOR_RETRY, session) + if self.do_xcom_push: + return [ + self.xcom_pull(context=context, task_ids=f"{self.task_id}_{index}", dag_id=self.dag_id) + for index in range(len(self._mapped_kwargs)) + ] + return None # Calculate delay before the next retry delay = reschedule_date - timezone.utcnow() @@ -419,31 +419,34 @@ def _run_tasks( sleep(delay_seconds) - # TaskInstance._set_state(context["ti"], TaskInstanceState.RUNNING, session) - - return self._run_tasks(context, failed_tasks, results) + return self._run_tasks(context, failed_tasks) - @classmethod - def _run_operator(cls, context: Context, task_instance: TaskInstance): + def _run_operator(self, context: Context, task_instance: TaskInstance): try: with OperatorExecutor(context=context, task_instance=task_instance) as executor: - return executor.run() + result = executor.run() + if self.do_xcom_push: + task_instance.xcom_push(key=XCOM_RETURN_KEY, value=result) + return result except TaskDeferred as task_deferred: return task_deferred - async def _run_deferrable(self, context: Context, task: TaskInstance, task_deferred: TaskDeferred): + async def _run_deferrable( + self, context: Context, task_instance: TaskInstance, task_deferred: TaskDeferred + ): async with self._semaphore: - async with TriggerExecutor(context=context, task_instance=task) as executor: - return await 
executor.run(task_deferred) + async with TriggerExecutor(context=context, task_instance=task_instance) as executor: + result = await executor.run(task_deferred) + if self.do_xcom_push: + task_instance.xcom_push(key=XCOM_RETURN_KEY, value=result) + return result - def _create_task(self, context: Context, index: int) -> TaskInstance: + def _create_task(self, run_id: str, index: int) -> TaskInstance: operator = self._unmap_operator(index) - task_instance = TaskInstance( + return TaskInstance( task=operator, - run_id=context["ti"].run_id, - map_index=index, + run_id=run_id, ) - return task_instance def execute(self, context: Context): self.log.info( @@ -454,14 +457,10 @@ def execute(self, context: Context): self.timeout, ) - results = self._run_tasks( + return self._run_tasks( context=context, tasks=map( - lambda index: self._create_task(context, index[0]), + lambda index: self._create_task(context["ti"].run_id, index[0]), enumerate(self._mapped_kwargs), ), ) - - if self.do_xcom_push: - return results - return None diff --git a/task_sdk/src/airflow/sdk/definitions/mappedoperator.py b/task_sdk/src/airflow/sdk/definitions/mappedoperator.py index 3661924a462e8..2c70954b36140 100644 --- a/task_sdk/src/airflow/sdk/definitions/mappedoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/mappedoperator.py @@ -69,6 +69,7 @@ OperatorExpandArgument, OperatorExpandKwargsArgument, ) + from airflow.models.iterableoperator import IterableOperator from airflow.models.xcom_arg import XComArg from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG @@ -282,7 +283,10 @@ def _iterate(self, expand_input) -> IterableOperator: # We don't want to time out the whole stream operator, we only time out the individual tasks kwargs["timeout"] = kwargs.pop("execution_timeout", None) kwargs["max_active_tis_per_dag"] = self.kwargs.get("max_active_tis_per_dag") - + self.kwargs.pop("task_group", None) + self.kwargs["task_id"] = kwargs["task_id"] = kwargs["task_id"].rsplit(".", 1)[-1] + self.kwargs["do_xcom_push"] = self.kwargs.get("do_xcom_push", True) + self._expand_called = True return IterableOperator( **kwargs, operator_class=self.operator_class, From 1375476ba9aba909c0fb134e2f2dbfa24ea9e877 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 26 Feb 2025 15:13:25 +0100 Subject: [PATCH 83/97] refactor: Refactored IterableOperator which now returns an XComIterable which lazily returns index results to reduce memory consumption --- airflow/models/iterableoperator.py | 68 +++++++++++++++++++++++++++--- 1 file changed, 63 insertions(+), 5 deletions(-) diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py index 94148df250cf1..26a1828a85488 100644 --- a/airflow/models/iterableoperator.py +++ b/airflow/models/iterableoperator.py @@ -22,7 +22,7 @@ import os from abc import abstractmethod from asyncio import AbstractEventLoop, Future, Semaphore, ensure_future, gather -from collections.abc import Generator, Iterable, Sequence +from collections.abc import Generator, Iterable, Iterator, Sequence from contextlib import contextmanager, suppress from datetime import timedelta from math import ceil @@ -45,9 +45,11 @@ is_mappable, ) from airflow.models.taskinstance import TaskInstance +from airflow.models.xcom import XCom from airflow.sdk.definitions._internal.abstractoperator import Operator from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.xcom_arg import XComArg, _MapResult +from airflow.serialization import serde from 
airflow.triggers.base import run_trigger
 from airflow.utils import timezone
 from airflow.utils.context import context_get_outlet_events
@@ -61,6 +63,8 @@
     import jinja2
     from sqlalchemy.orm import Session

+serde._extra_allowed = serde._extra_allowed.union({"airflow.models.iterableoperator.XComIterable"})
+

 @contextmanager
 def event_loop() -> Generator[AbstractEventLoop, None, None]:
@@ -80,6 +84,58 @@ def event_loop() -> Generator[AbstractEventLoop, None, None]:
         loop.close()


+class XComIterable(Iterator, Sequence):
+    """An iterable that lazily fetches XCom values one by one instead of loading all at once."""
+
+    def __init__(self, task_id: str, dag_id: str, run_id: str, length: int):
+        self.task_id = task_id
+        self.dag_id = dag_id
+        self.run_id = run_id
+        self.length = length
+        self.index = 0
+
+    def __iter__(self) -> Iterator:
+        self.index = 0
+        return self
+
+    def __next__(self):
+        if self.index >= self.length:
+            raise StopIteration
+
+        value = self[self.index]
+        self.index += 1
+        return value
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, index: int):
+        """Allows direct indexing, making this work like a sequence."""
+        if not (0 <= index < self.length):
+            raise IndexError
+
+        return XCom.get_one(
+            key=XCOM_RETURN_KEY,
+            dag_id=self.dag_id,
+            task_id=f"{self.task_id}_{index}",
+            run_id=self.run_id,
+        )
+
+    def serialize(self):
+        """Ensure the object is JSON serializable."""
+        return {
+            "task_id": self.task_id,
+            "dag_id": self.dag_id,
+            "run_id": self.run_id,
+            "length": self.length,
+        }
+
+    @classmethod
+    def deserialize(cls, data: dict, version: int):
+        """Ensure the object is JSON deserializable."""
+        return XComIterable(**data)
+
+
 class TaskExecutor(LoggingMixin):
     """Base class to run an operator or trigger with given task context and task instance."""

@@ -398,10 +454,12 @@ def _run_tasks(
             if exception:
                 raise exception
             if self.do_xcom_push:
-                return [
-                    self.xcom_pull(context=context, task_ids=f"{self.task_id}_{index}", dag_id=self.dag_id)
-                    for index in range(len(self._mapped_kwargs))
-                ]
+                return XComIterable(
+                    task_id=self.task_id,
+                    dag_id=self.dag_id,
+                    run_id=context["run_id"],
+                    length=len(self._mapped_kwargs),
+                )
             return None

         # Calculate delay before the next retry

From 0eefacd9d9304acad39826c08cbbaa26e8f0d983 Mon Sep 17 00:00:00 2001
From: David Blain
Date: Wed, 19 Mar 2025 18:02:49 +0100
Subject: [PATCH 84/97] refactor: Refactored IterableOperator to support streaming of iterable inputs

---
 airflow/models/iterableoperator.py          | 275 +++++++++++++-----
 .../src/airflow/sdk/definitions/xcom_arg.py |  30 +-
 2 files changed, 232 insertions(+), 73 deletions(-)

diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py
index 26a1828a85488..68c0cfe63a883 100644
--- a/airflow/models/iterableoperator.py
+++ b/airflow/models/iterableoperator.py
@@ -136,6 +136,98 @@ def deserialize(cls, data: dict, version: int):
         return XComIterable(**data)


+# TODO: should be moved to correct location as this will be used by streamable operators
+class DeferredIterable(Iterator):
+    """An iterable that lazily fetches new results by re-running its trigger whenever the buffered results are exhausted."""
+
+    def __init__(self, results: list[Any] | Any, trigger: BaseTrigger, operator: BaseOperator, next_method: str, context: Context | None = None):
+        self.results = results.copy() if isinstance(results, list) else [results]
+        self.trigger = trigger
+        self.operator = operator
+        self.next_method = next_method
+        self.context = context
+        self.index = 0
+        self._loop = None
+
+    @provide_session
+    def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True) -> DeferredIterable:
+        return DeferredIterable(
+            results=self.results,
+            trigger=self.trigger,
+            operator=self.operator,
+            next_method=self.next_method,
+            context={**context},
+        )
+
+    @property
+    def loop(self):
+        if not self._loop:
+            self._loop = asyncio.new_event_loop()
+        return self._loop
+
+    def __iter__(self) -> Iterator:
+        return self
+
+    def __next__(self):
+        if self.index < len(self.results):
+            result = self.results[self.index]
+            self.index += 1
+            return result
+
+        # No more results; attempt to load the next page using the trigger
+        logging.info("No more results. Running trigger: %s", self.trigger)
+
+        event = self.loop.run_until_complete(run_trigger(self.trigger))
+        iterator = getattr(self.operator, self.next_method)(self.context, event.payload)
+        if not iterator:
+            raise StopIteration
+
+        self.trigger = iterator.trigger
+        self.results.extend(iterator.results)
+        # Return the first newly fetched result; any further ones stay buffered in order
+        result = self.results[self.index]
+        self.index += 1
+        return result
+
+    def __len__(self):
+        return len(self.results)
+
+    def __getitem__(self, index: int):
+        if not (0 <= index < len(self)):
+            raise IndexError
+
+        return self.results[index]
+
+    def __del__(self):
+        if self._loop:
+            self._loop.close()
+
+    def serialize(self):
+        """Ensure the object is JSON serializable."""
+        return {
+            "results": self.results,
+            "trigger": self.trigger.serialize(),
+            "operator": SerializedBaseOperator.serialize_operator(self.operator),
+            "next_method": self.next_method,
+        }
+
+    @classmethod
+    def deserialize(cls, data: dict, version: int):
+        """Ensure the object is JSON deserializable."""
+        trigger_class = import_string(data["trigger"][0])
+        trigger = trigger_class(**data["trigger"][1])
+        operator_class = import_string(f"{data['operator']['_task_module']}.{data['operator']['_task_type']}")
+        operator_kwargs = {
+            key: value for key, value in data["operator"].items()
+            if key in inspect.signature(operator_class.__init__).parameters or key in operator_class._comps
+        }
+        operator = operator_class(**operator_kwargs)
+        return DeferredIterable(
+            results=data["results"],
+            trigger=trigger,
+            operator=operator,
+            next_method=data["next_method"]
+        )
+
+
 class TaskExecutor(LoggingMixin):
     """Base class to run an operator or trigger with given task context and task instance."""

@@ -287,6 +379,20 @@ async def run(self, task_deferred: TaskDeferred):
         return await self.run(task_deferred=task_deferred)


+class AtomicBoolean:
+    def __init__(self, initial: bool = False):
+        self._value = initial
+        self._lock = Lock()
+
+    def set(self, value: bool):
+        with self._lock:
+            self._value = value
+
+    def get(self) -> bool:
+        with self._lock:
+            return self._value
+
+
 class IterableOperator(BaseOperator):
     """Object representing an iterable operator in a DAG."""

@@ -299,6 +405,7 @@ class IterableOperator(BaseOperator):
         "partial_kwargs",
         "_log",
         "_semaphore",
+        "_completed",
     )

     def __init__(
@@ -315,10 +422,12 @@ def __init__(
         self.expand_input = expand_input
         self.partial_kwargs = partial_kwargs or {}
         self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
-        self._mapped_kwargs: list[dict] = []
+        self._mapped_kwargs: Iterable[dict] = []
         if not self.max_active_tis_per_dag:
             self.max_active_tis_per_dag = os.cpu_count() or 1
         self._semaphore = Semaphore(self.max_active_tis_per_dag)
+        self._completed = AtomicBoolean()
+        self._number_of_tasks: int = 0
         XComArg.apply_upstream_relationship(self, self.expand_input.value)

     @property
@@ -328,15 +437,17 @@ def
operator_name(self) -> str: def _get_specified_expand_input(self) -> ExpandInput: return self.expand_input - def _unmap_operator(self, index): + def _unmap_operator(self, index: int, mapped_kwargs: dict): kwargs = { **self.partial_kwargs, **{"task_id": f"{self.task_id}_{index}"}, - **self._mapped_kwargs[index], + **mapped_kwargs, } + self._number_of_tasks += 1 self.log.debug("index: %s", index) self.log.debug("kwargs: %s", kwargs) self.log.debug("operator_class: %s", self._operator_class) + self.log.debug("number_of_tasks: %s", self._number_of_tasks) return self._operator_class(**kwargs, _airflow_from_mapped=True) def _resolve(self, value, context: Context, session: Session): @@ -347,7 +458,7 @@ def _resolve(self, value, context: Context, session: Session): item = item.resolve(context=context, session=session) if is_mappable(item): - item = list(item) # type: ignore + item = iter(item) # type: ignore self.log.debug("resolved_value: %s", item) @@ -355,28 +466,39 @@ def _resolve(self, value, context: Context, session: Session): return value + def _lazy_mapped_kwargs(self, input, context: Context, session: Session): + self.log.debug("_lazy_mapped_kwargs value: %s", input) + + value = self._resolve(value=input, context=context, session=session) + + self.log.debug("resolved value: %s", value) + + if isinstance(value, dict): + for key, item in value.items(): + if not isinstance(item, (Sequence, Iterable)) or isinstance(item, str): + yield {key: item} + else: + for sub_item in item: + yield {key: sub_item} + def _resolve_expand_input(self, context: Context, session: Session): + self.log.debug("resolve_expand_input: %s", self.expand_input) + if isinstance(self.expand_input.value, XComArg): - resolved_input = self.expand_input.value.resolve(context=context, session=session) + resolved_input = self.expand_input.value.resolve( + context=context, session=session + ) else: resolved_input = self.expand_input.value + self.log.debug("resolved_input: %s", resolved_input) + if isinstance(resolved_input, _MapResult): - for value in resolved_input: - self._mapped_kwargs.append(self._resolve(value=value, context=context, session=session)) + self._mapped_kwargs = map(lambda value: self._resolve(value=value, context=context, session=session), resolved_input) else: - value = self._resolve(value=resolved_input, context=context, session=session) + self._mapped_kwargs = iter(self._lazy_mapped_kwargs(input=resolved_input, context=context, session=session)) - if isinstance(value, dict): - for key, item in self._resolve( - value=resolved_input, context=context, session=session - ).items(): - if isinstance(item, list): - self._mapped_kwargs.extend([{key: item} for item in item]) - else: - self._mapped_kwargs.append({key: item}) - - self.log.debug("mapped_kwargs: %s", self._mapped_kwargs) + self.log.info("mapped_kwargs: %s", self._mapped_kwargs) def render_template_fields( self, @@ -394,43 +516,65 @@ def _run_tasks( exception: BaseException | None = None reschedule_date = timezone.utcnow() failed_tasks: list[TaskInstance] = [] - - with ThreadPool(processes=self.max_active_tis_per_dag) as pool: - futures = {pool.apply_async(self._run_operator, (context, task)): task for task in tasks} - - while futures: - self.log.info("Number of remaining futures: %s", len(futures)) - + task_queue = Queue() + + # Task Producer + def task_producer(): + self.log.info("Started producing tasks: %s", tasks) + for task in iter(tasks): + self.log.info("Created task: %s", task) + task_queue.put(task) + self._completed.set(True) + 
self.log.info("Finished producing tasks: %s", self._completed.get()) + + # Task Consumer + def task_consumer(): + self.log.info("Started consuming tasks") + nonlocal exception, reschedule_date + + self.log.info("task_queue : %s", task_queue.qsize()) + self.log.info("completed: %s", self._completed.get()) + + while not task_queue.empty() or not self._completed.get(): deferred_tasks: list[Future] = [] - for future in filter(ApplyResult.ready, list(futures.keys())): - task = futures.pop(future) - - try: - result = future.get(timeout=self.timeout) - if isinstance(result, TaskDeferred): - deferred_tasks.append( - ensure_future( - self._run_deferrable( - context=context, - task_instance=task, - task_deferred=result, + with ThreadPool(processes=self.max_active_tis_per_dag) as pool: + futures = {} + + while not task_queue.empty(): + task = task_queue.get(timeout=1) + self.log.info("received task : %s", task) + futures[pool.apply_async(self._run_operator, (context, task))] = task + + self.log.debug("futures: %s", futures) + + for future in filter(ApplyResult.ready, list(futures.keys())): + task = futures.pop(future) + try: + result = future.get(timeout=self.timeout) + if isinstance(result, TaskDeferred): + deferred_tasks.append( + ensure_future( + self._run_deferrable( + context=context, + task_instance=task, + task_deferred=result, + ) ) ) - ) - except TimeoutError as e: - self.log.warning("A timeout occurred for task_id %s", task.task_id) - if task.next_try_number > self.retries: - exception = AirflowTaskTimeout(e) - else: - reschedule_date = task.next_retry_datetime() - failed_tasks.append(task) - except AirflowRescheduleTaskInstanceException as e: - reschedule_date = e.reschedule_date - failed_tasks.append(e.task) - except AirflowException as e: - self.log.error("An exception occurred for task_id %s", task.task_id) - exception = e + except TimeoutError as e: + self.log.warning("A timeout occurred for task_id %s", task.task_id) + if task.next_try_number > self.retries: + exception = AirflowTaskTimeout(e) + else: + reschedule_date = task.next_retry_datetime() + failed_tasks.append(task) + except AirflowRescheduleTaskInstanceException as e: + reschedule_date = e.reschedule_date + failed_tasks.append(e.task) + except AirflowException as e: + self.log.error("An exception occurred for task_id %s", task.task_id) + exception = e if deferred_tasks: self.log.info("Running %s deferred tasks", len(deferred_tasks)) @@ -447,18 +591,27 @@ def _run_tasks( failed_tasks.append(result.task) else: exception = result - elif futures: - sleep(min(len(futures), os.cpu_count())) + + producer_thread = Thread(target=task_producer) + consumer_thread = Thread(target=task_consumer) + + producer_thread.start() + consumer_thread.start() + producer_thread.join() + consumer_thread.join() if not failed_tasks: if exception: raise exception + + self.log.info("Finished consuming tasks: %s", self._number_of_tasks) + if self.do_xcom_push: return XComIterable( task_id=self.task_id, dag_id=self.dag_id, run_id=context["run_id"], - length=len(self._mapped_kwargs), + length=self._number_of_tasks, ) return None @@ -499,26 +652,18 @@ async def _run_deferrable( task_instance.xcom_push(key=XCOM_RETURN_KEY, value=result) return result - def _create_task(self, run_id: str, index: int) -> TaskInstance: - operator = self._unmap_operator(index) + def _create_task(self, run_id: str, index: int, mapped_kwargs: dict) -> TaskInstance: + operator = self._unmap_operator(index, mapped_kwargs) return TaskInstance( task=operator, run_id=run_id, ) def 
execute(self, context: Context):
-        self.log.info(
-            "Executing %s mapped tasks on %s with %s threads and timeout %s",
-            len(self._mapped_kwargs),
-            self._operator_class.__name__,
-            self.max_active_tis_per_dag,
-            self.timeout,
-        )
-
         return self._run_tasks(
             context=context,
             tasks=map(
-                lambda index: self._create_task(context["ti"].run_id, index[0]),
+                lambda mapped_kwargs: self._create_task(context["ti"].run_id, mapped_kwargs[0], mapped_kwargs[1]),
                 enumerate(self._mapped_kwargs),
             ),
         )
diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
index d080cc7ff3b13..88426a66f976a 100644
--- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
+++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
@@ -382,21 +382,33 @@ def _get_callable_name(f: Callable | str) -> str:
     return ""
 
 
-class _MapResult(Sequence):
-    def __init__(self, value: Sequence | dict, callables: MapCallables) -> None:
+class _MapResult(Sequence, Iterable):
+    def __init__(self, value: Sequence | Iterable, callables: list) -> None:
         self.value = value
         self.callables = callables
 
-    def __getitem__(self, index: Any) -> Any:
-        value = self.value[index]
+    def __getitem__(self, index: int) -> Any:
+        if not (0 <= index < len(self)):
+            raise IndexError
 
-        for f in self.callables:
-            value = f(value)
-        return value
+        value = self.value[index]
+        return self._apply_callables(value)
 
     def __len__(self) -> int:
+        # Only sized values (sequences and dicts) have a length; a lazy iterable does not.
+        if not isinstance(self.value, (Sequence, dict)):
+            raise TypeError
+
         return len(self.value)
 
+    def __iter__(self) -> Iterator:
+        for item in iter(self.value):
+            yield self._apply_callables(item)
+
+    def _apply_callables(self, value):
+        for func in self.callables:
+            value = func(value)
+        return value
+
 
 class MapXComArg(XComArg):
     """
@@ -434,8 +446,10 @@ def map(self, f: Callable[[Any], Any]) -> MapXComArg:
 
     def resolve(self, context: Mapping[str, Any]) -> Any:
         value = self.arg.resolve(context)
-        if not isinstance(value, (Sequence, dict)):
+        if not isinstance(value, (Sequence, Iterable, dict)):
             raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}")
+        if isinstance(value, Iterable) and hasattr(value, "resolve"):  # TODO: should check if it's DeferredIterable
+            value = value.resolve(context)
         return _MapResult(value, self.callables)
 
 
From d20b24e8ecb7149350b50b0d4350c388e64b01ec Mon Sep 17 00:00:00 2001
From: David Blain
Date: Wed, 19 Mar 2025 18:09:39 +0100
Subject: [PATCH 85/97] refactor: No need to copy context in resolve method of
 DeferredIterable

---
 airflow/models/iterableoperator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py
index 68c0cfe63a883..3af94c10e8fc1 100644
--- a/airflow/models/iterableoperator.py
+++ b/airflow/models/iterableoperator.py
@@ -156,7 +156,7 @@ def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_x
             trigger=self.trigger,
             operator=self.operator,
             next_method=self.next_method,
-            context={**context},
+            context=context,
         )
 
     @property
From 41856fda2dd103239ec58abf1f9f90f6be3b23fe Mon Sep 17 00:00:00 2001
From: David Blain
Date: Thu, 20 Mar 2025 08:45:37 +0100
Subject: [PATCH 86/97] refactor: Inherit LoggingMixin for DeferredIterable

---
 airflow/models/iterableoperator.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py
index 3af94c10e8fc1..ccaa96a0ce868 100644
--- a/airflow/models/iterableoperator.py
+++ 
b/airflow/models/iterableoperator.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 import asyncio
+import inspect
 import logging
 import os
 from abc import abstractmethod
@@ -50,11 +51,14 @@
 from airflow.sdk.definitions.context import Context
 from airflow.sdk.definitions.xcom_arg import XComArg, _MapResult
 from airflow.serialization import serde
-from airflow.triggers.base import run_trigger
+from airflow.serialization.serialized_objects import SerializedBaseOperator
+from airflow.triggers.base import run_trigger, BaseTrigger
 from airflow.utils import timezone
 from airflow.utils.context import context_get_outlet_events
 from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.module_loading import import_string
 from airflow.utils.operator_helpers import ExecutionCallableRunner
+from airflow.utils.session import provide_session, NEW_SESSION
 from airflow.utils.state import TaskInstanceState
 from airflow.utils.task_instance_session import get_current_task_instance_session
 from airflow.utils.xcom import XCOM_RETURN_KEY
@@ -137,10 +141,11 @@ def deserialize(cls, data: dict, version: int):
 
 
 # TODO: should be moved to correct location as this will be used by streamable operators
-class DeferredIterable(Iterator):
+class DeferredIterable(Iterator, LoggingMixin):
     """An iterable that lazily fetches XCom values one by one instead of loading all at once."""
 
     def __init__(self, results: list[Any] | Any, trigger: BaseTrigger, operator: BaseOperator, next_method: str, context: Context | None = None):
+        super().__init__()
         self.results = results.copy() if isinstance(results, list) else [results]
         self.trigger = trigger
         self.operator = operator
@@ -175,7 +180,7 @@ def __next__(self):
             return result
 
         # No more results; attempt to load the next page using the trigger
-        logging.info("No more results. Running trigger: %s", self.trigger)
+        self.log.info("No more results. Running trigger: %s", self.trigger)
 
         event = self.loop.run_until_complete(run_trigger(self.trigger))
         iterator = getattr(self.operator, self.next_method)(self.context, event.payload)
From ea769cba77b988bb8c35e98c3870b77cdc7d0714 Mon Sep 17 00:00:00 2001
From: David Blain
Date: Thu, 20 Mar 2025 08:52:26 +0100
Subject: [PATCH 87/97] refactor: Moved DeferredIterable to dedicated iterable
 module under airflow models

---
 airflow/models/iterable.py         | 127 +++++++++++++++++++++++++++++
 airflow/models/iterableoperator.py | 101 +----------------------
 2 files changed, 130 insertions(+), 98 deletions(-)
 create mode 100644 airflow/models/iterable.py

diff --git a/airflow/models/iterable.py b/airflow/models/iterable.py
new file mode 100644
index 0000000000000..8c5eca58ff8f7
--- /dev/null
+++ b/airflow/models/iterable.py
@@ -0,0 +1,127 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
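+
+# DeferredIterable (below) lazily pages XCom results: __next__ first drains
+# the results fetched so far, then re-runs the trigger and feeds the resulting
+# event payload to operator.<next_method> to fetch the next page.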
+from __future__ import annotations + +import asyncio +import inspect +from collections.abc import Iterator +from typing import TYPE_CHECKING, Any + +from airflow.serialization.serialized_objects import SerializedBaseOperator +from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.module_loading import import_string +from airflow.utils.session import provide_session, NEW_SESSION + +if TYPE_CHECKING: + from airflow.models import BaseOperator + from airflow.triggers.base import BaseTrigger, run_trigger + from airflow.utils.context import Context + + from sqlalchemy.orm import Session + + +class DeferredIterable(Iterator, LoggingMixin): + """An iterable that lazily fetches XCom values one by one instead of loading all at once.""" + + def __init__(self, results: list[Any] | Any, trigger: BaseTrigger, operator: BaseOperator, next_method: str, context: Context | None = None): + super().__init__() + self.results = results.copy() if isinstance(results, list) else [results] + self.trigger = trigger + self.operator = operator + self.next_method = next_method + self.context = context + self.index = 0 + self._loop = None + + @provide_session + def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True) -> DeferredIterable: + return DeferredIterable( + results=self.results, + trigger=self.trigger, + operator=self.operator, + next_method=self.next_method, + context=context, + ) + + @property + def loop(self): + if not self._loop: + self._loop = asyncio.new_event_loop() + return self._loop + + def __iter__(self) -> Iterator: + return self + + def __next__(self): + if self.index < len(self.results): + result = self.results[self.index] + self.index += 1 + return result + + # No more results; attempt to load the next page using the trigger + self.log.info("No more results. 
Running trigger: %s", self.trigger) + + event = self.loop.run_until_complete(run_trigger(self.trigger)) + iterator = getattr(self.operator, self.next_method)(self.context, event.payload) + if not iterator: + raise StopIteration + + self.trigger = iterator.trigger + self.results.extend(iterator.results) + self.index += 1 + return self.results[-1] + + def __len__(self): + return len(self.results) + + def __getitem__(self, index: int): + if not (0 <= index < len(self)): + raise IndexError + + return self.results[index] + + def __del__(self): + if self._loop: + self._loop.close() + + def serialize(self): + """Ensure the object is JSON serializable.""" + return { + "results": self.results, + "trigger": self.trigger.serialize(), + "operator": SerializedBaseOperator.serialize_operator(self.operator), + "next_method": self.next_method, + } + + @classmethod + def deserialize(cls, data: dict, version: int): + """Ensure the object is JSON deserializable.""" + trigger_class = import_string(data["trigger"][0]) + trigger = trigger_class(**data["trigger"][1]) + operator_class = import_string(f"{data['operator']['_task_module']}.{data['operator']['_task_type']}") + operator_kwargs = { + key: value for key, value in data["operator"].items() + if key in inspect.signature(operator_class.__init__).parameters or key in operator_class._comps + } + operator = operator_class(**operator_kwargs) + return DeferredIterable( + results=data["results"], + trigger=trigger, + operator=operator, + next_method=data["next_method"] + ) diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py index ccaa96a0ce868..e44cf6f69b556 100644 --- a/airflow/models/iterableoperator.py +++ b/airflow/models/iterableoperator.py @@ -18,7 +18,6 @@ from __future__ import annotations import asyncio -import inspect import logging import os from abc import abstractmethod @@ -29,6 +28,8 @@ from math import ceil from multiprocessing import TimeoutError from multiprocessing.pool import ApplyResult, ThreadPool +from queue import Queue +from threading import Lock, Thread from time import sleep from typing import TYPE_CHECKING, Any @@ -51,14 +52,11 @@ from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.xcom_arg import XComArg, _MapResult from airflow.serialization import serde -from airflow.serialization.serialized_objects import SerializedBaseOperator -from airflow.triggers.base import run_trigger, BaseTrigger +from airflow.triggers.base import run_trigger from airflow.utils import timezone from airflow.utils.context import context_get_outlet_events from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.module_loading import import_string from airflow.utils.operator_helpers import ExecutionCallableRunner -from airflow.utils.session import provide_session, NEW_SESSION from airflow.utils.state import TaskInstanceState from airflow.utils.task_instance_session import get_current_task_instance_session from airflow.utils.xcom import XCOM_RETURN_KEY @@ -140,99 +138,6 @@ def deserialize(cls, data: dict, version: int): return XComIterable(**data) -# TODO: should be moved to correct location as this will be used by streamable operators -class DeferredIterable(Iterator, LoggingMixin): - """An iterable that lazily fetches XCom values one by one instead of loading all at once.""" - - def __init__(self, results: list[Any] | Any, trigger: BaseTrigger, operator: BaseOperator, next_method: str, context: Context | None = None): - super().__init__() - self.results = results.copy() if 
isinstance(results, list) else [results] - self.trigger = trigger - self.operator = operator - self.next_method = next_method - self.context = context - self.index = 0 - self._loop = None - - @provide_session - def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True) -> DeferredIterable: - return DeferredIterable( - results=self.results, - trigger=self.trigger, - operator=self.operator, - next_method=self.next_method, - context=context, - ) - - @property - def loop(self): - if not self._loop: - self._loop = asyncio.new_event_loop() - return self._loop - - def __iter__(self) -> Iterator: - return self - - def __next__(self): - if self.index < len(self.results): - result = self.results[self.index] - self.index += 1 - return result - - # No more results; attempt to load the next page using the trigger - self.log.info("No more results. Running trigger: %s", self.trigger) - - event = self.loop.run_until_complete(run_trigger(self.trigger)) - iterator = getattr(self.operator, self.next_method)(self.context, event.payload) - if not iterator: - raise StopIteration - - self.trigger = iterator.trigger - self.results.extend(iterator.results) - self.index += 1 - return self.results[-1] - - def __len__(self): - return len(self.results) - - def __getitem__(self, index: int): - if not (0 <= index < len(self)): - raise IndexError - - return self.results[index] - - def __del__(self): - if self._loop: - self._loop.close() - - def serialize(self): - """Ensure the object is JSON serializable.""" - return { - "results": self.results, - "trigger": self.trigger.serialize(), - "operator": SerializedBaseOperator.serialize_operator(self.operator), - "next_method": self.next_method, - } - - @classmethod - def deserialize(cls, data: dict, version: int): - """Ensure the object is JSON deserializable.""" - trigger_class = import_string(data["trigger"][0]) - trigger = trigger_class(**data["trigger"][1]) - operator_class = import_string(f"{data['operator']['_task_module']}.{data['operator']['_task_type']}") - operator_kwargs = { - key: value for key, value in data["operator"].items() - if key in inspect.signature(operator_class.__init__).parameters or key in operator_class._comps - } - operator = operator_class(**operator_kwargs) - return DeferredIterable( - results=data["results"], - trigger=trigger, - operator=operator, - next_method=data["next_method"] - ) - - class TaskExecutor(LoggingMixin): """Base class to run an operator or trigger with given task context and task instance.""" From 9958becaad3c422f72b81e4266699f09fc021d4a Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 20 Mar 2025 12:03:21 +0100 Subject: [PATCH 88/97] refactor: Also use ThreadPool to execute task producer and consumer instead of Threads --- airflow/models/iterableoperator.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py index e44cf6f69b556..27f9474d6572c 100644 --- a/airflow/models/iterableoperator.py +++ b/airflow/models/iterableoperator.py @@ -29,7 +29,7 @@ from multiprocessing import TimeoutError from multiprocessing.pool import ApplyResult, ThreadPool from queue import Queue -from threading import Lock, Thread +from threading import Lock from time import sleep from typing import TYPE_CHECKING, Any @@ -431,7 +431,7 @@ def _run_tasks( # Task Producer def task_producer(): self.log.info("Started producing tasks: %s", tasks) - for task in iter(tasks): + for task in tasks: self.log.info("Created 
task: %s", task) task_queue.put(task) self._completed.set(True) @@ -494,7 +494,6 @@ def task_consumer(): gather(*deferred_tasks, return_exceptions=True) ): self.log.debug("result: %s", result) - if isinstance(result, Exception): if isinstance(result, AirflowRescheduleTaskInstanceException): reschedule_date = result.reschedule_date @@ -502,13 +501,9 @@ def task_consumer(): else: exception = result - producer_thread = Thread(target=task_producer) - consumer_thread = Thread(target=task_consumer) - - producer_thread.start() - consumer_thread.start() - producer_thread.join() - consumer_thread.join() + with ThreadPool(processes=2) as executor: + executor.apply_async(task_producer) + executor.apply(task_consumer) if not failed_tasks: if exception: From 03d831209e836d2a14bb944251628d25e88ef840 Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 20 Mar 2025 15:49:52 +0100 Subject: [PATCH 89/97] refactor: Reformatted files --- airflow/models/iterable.py | 41 +++++++++++++------ airflow/models/iterableoperator.py | 16 +++++--- .../src/airflow/sdk/definitions/xcom_arg.py | 4 +- 3 files changed, 41 insertions(+), 20 deletions(-) diff --git a/airflow/models/iterable.py b/airflow/models/iterable.py index 8c5eca58ff8f7..969cc6bf83607 100644 --- a/airflow/models/iterable.py +++ b/airflow/models/iterable.py @@ -22,23 +22,31 @@ from collections.abc import Iterator from typing import TYPE_CHECKING, Any +from airflow.exceptions import AirflowException from airflow.serialization.serialized_objects import SerializedBaseOperator from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.module_loading import import_string -from airflow.utils.session import provide_session, NEW_SESSION +from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: + from sqlalchemy.orm import Session + from airflow.models import BaseOperator from airflow.triggers.base import BaseTrigger, run_trigger from airflow.utils.context import Context - from sqlalchemy.orm import Session - class DeferredIterable(Iterator, LoggingMixin): """An iterable that lazily fetches XCom values one by one instead of loading all at once.""" - def __init__(self, results: list[Any] | Any, trigger: BaseTrigger, operator: BaseOperator, next_method: str, context: Context | None = None): + def __init__( + self, + results: list[Any] | Any, + trigger: BaseTrigger, + operator: BaseOperator, + next_method: str, + context: Context | None = None, + ): super().__init__() self.results = results.copy() if isinstance(results, list) else [results] self.trigger = trigger @@ -49,7 +57,9 @@ def __init__(self, results: list[Any] | Any, trigger: BaseTrigger, operator: Bas self._loop = None @provide_session - def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True) -> DeferredIterable: + def resolve( + self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True + ) -> DeferredIterable: return DeferredIterable( results=self.results, trigger=self.trigger, @@ -76,8 +86,12 @@ def __next__(self): # No more results; attempt to load the next page using the trigger self.log.info("No more results. 
Running trigger: %s", self.trigger) - event = self.loop.run_until_complete(run_trigger(self.trigger)) - iterator = getattr(self.operator, self.next_method)(self.context, event.payload) + try: + event = self.loop.run_until_complete(run_trigger(self.trigger)) + iterator = getattr(self.operator, self.next_method)(self.context, event.payload) + except Exception as e: + raise AirflowException from e + if not iterator: raise StopIteration @@ -114,14 +128,15 @@ def deserialize(cls, data: dict, version: int): trigger_class = import_string(data["trigger"][0]) trigger = trigger_class(**data["trigger"][1]) operator_class = import_string(f"{data['operator']['_task_module']}.{data['operator']['_task_type']}") + operator_parameters = ( + set(inspect.signature(operator_class.__init__).parameters) + .union(set(operator_class._comps)) + .union(operator_class.template_fields) + ) operator_kwargs = { - key: value for key, value in data["operator"].items() - if key in inspect.signature(operator_class.__init__).parameters or key in operator_class._comps + key: value for key, value in data["operator"].items() if key in operator_parameters } operator = operator_class(**operator_kwargs) return DeferredIterable( - results=data["results"], - trigger=trigger, - operator=operator, - next_method=data["next_method"] + results=data["results"], trigger=trigger, operator=operator, next_method=data["next_method"] ) diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py index 27f9474d6572c..717079663cbef 100644 --- a/airflow/models/iterableoperator.py +++ b/airflow/models/iterableoperator.py @@ -395,18 +395,20 @@ def _resolve_expand_input(self, context: Context, session: Session): self.log.debug("resolve_expand_input: %s", self.expand_input) if isinstance(self.expand_input.value, XComArg): - resolved_input = self.expand_input.value.resolve( - context=context, session=session - ) + resolved_input = self.expand_input.value.resolve(context=context, session=session) else: resolved_input = self.expand_input.value self.log.debug("resolved_input: %s", resolved_input) if isinstance(resolved_input, _MapResult): - self._mapped_kwargs = map(lambda value: self._resolve(value=value, context=context, session=session), resolved_input) + self._mapped_kwargs = map( + lambda value: self._resolve(value=value, context=context, session=session), resolved_input + ) else: - self._mapped_kwargs = iter(self._lazy_mapped_kwargs(input=resolved_input, context=context, session=session)) + self._mapped_kwargs = iter( + self._lazy_mapped_kwargs(input=resolved_input, context=context, session=session) + ) self.log.info("mapped_kwargs: %s", self._mapped_kwargs) @@ -568,7 +570,9 @@ def execute(self, context: Context): return self._run_tasks( context=context, tasks=map( - lambda mapped_kwargs: self._create_task(context["ti"].run_id, mapped_kwargs[0], mapped_kwargs[1]), + lambda mapped_kwargs: self._create_task( + context["ti"].run_id, mapped_kwargs[0], mapped_kwargs[1] + ), enumerate(self._mapped_kwargs), ), ) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 88426a66f976a..aba2f8836fbc6 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -448,7 +448,9 @@ def resolve(self, context: Mapping[str, Any]) -> Any: value = self.arg.resolve(context) if not isinstance(value, (Sequence, Iterable, dict)): raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}") - if 
isinstance(value, Iterable) and hasattr(value, "resolve"):  # TODO: should check if it's DeferredIterable
+        if isinstance(value, Iterable) and hasattr(
+            value, "resolve"
+        ):  # TODO: should check if it's DeferredIterable
             value = value.resolve(context)
         return _MapResult(value, self.callables)
 
From 7ee67cbf05049bfeb6eee1e532ffe6fab826836c Mon Sep 17 00:00:00 2001
From: David Blain
Date: Sun, 23 Mar 2025 08:59:26 +0100
Subject: [PATCH 90/97] refactor: Refactored the way the operator is retrieved
 after trigger deserialization in DeferredIterable

---
 airflow/models/iterable.py | 28 +++++++++++++++-------------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/airflow/models/iterable.py b/airflow/models/iterable.py
index 969cc6bf83607..7630da1cdf3b6 100644
--- a/airflow/models/iterable.py
+++ b/airflow/models/iterable.py
@@ -18,12 +18,10 @@
 from __future__ import annotations
 
 import asyncio
-import inspect
 from collections.abc import Iterator
 from typing import TYPE_CHECKING, Any
 
 from airflow.exceptions import AirflowException
-from airflow.serialization.serialized_objects import SerializedBaseOperator
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.module_loading import import_string
 from airflow.utils.session import NEW_SESSION, provide_session
@@ -118,25 +116,29 @@ def serialize(self):
         return {
             "results": self.results,
             "trigger": self.trigger.serialize(),
-            "operator": SerializedBaseOperator.serialize_operator(self.operator),
+            "dag_fileloc": self.operator.dag.fileloc,
+            "dag_id": self.operator.dag_id,
+            "task_id": self.operator.task_id,
             "next_method": self.next_method,
         }
 
+    @classmethod
+    def get_operator_from_dag(cls, dag_fileloc: str, dag_id: str, task_id: str) -> BaseOperator:
+        """Load a DAG using DagBag and get the operator by task_id."""
+
+        from airflow.models import DagBag
+
+        dag_bag = DagBag(dag_folder=None)  # Avoid loading all DAGs
+        dag_bag.process_file(dag_fileloc)
+        return dag_bag.dags[dag_id].get_task(task_id)
+
     @classmethod
     def deserialize(cls, data: dict, version: int):
         """Ensure the object is JSON deserializable."""
+
         trigger_class = import_string(data["trigger"][0])
         trigger = trigger_class(**data["trigger"][1])
-        operator_class = import_string(f"{data['operator']['_task_module']}.{data['operator']['_task_type']}")
-        operator_parameters = (
-            set(inspect.signature(operator_class.__init__).parameters)
-            .union(set(operator_class._comps))
-            .union(operator_class.template_fields)
-        )
-        operator_kwargs = {
-            key: value for key, value in data["operator"].items() if key in operator_parameters
-        }
-        operator = operator_class(**operator_kwargs)
+        operator = cls.get_operator_from_dag(data["dag_fileloc"], data["dag_id"], data["task_id"])
         return DeferredIterable(
             results=data["results"], trigger=trigger, operator=operator, next_method=data["next_method"]
         )
From 15f9b5c799c65ed4ab22024be0344e3854cae7f6 Mon Sep 17 00:00:00 2001
From: David Blain
Date: Sun, 23 Mar 2025 08:59:53 +0100
Subject: [PATCH 91/97] refactor: Refactored producing of tasks in
 IterableOperator

---
 airflow/models/iterableoperator.py | 182 +++++++++++++----------------
 1 file changed, 83 insertions(+), 99 deletions(-)

diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py
index 717079663cbef..bc30f200445b8 100644
--- a/airflow/models/iterableoperator.py
+++ b/airflow/models/iterableoperator.py
@@ -29,7 +29,7 @@
 from multiprocessing import TimeoutError
 from multiprocessing.pool import ApplyResult, ThreadPool
 from queue import Queue
-from threading import Lock
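+# Thread backs the task producer below: _run_tasks polls
+# producer_thread.is_alive() rather than an AtomicBoolean completion flag.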
+from threading import Thread from time import sleep from typing import TYPE_CHECKING, Any @@ -289,20 +289,6 @@ async def run(self, task_deferred: TaskDeferred): return await self.run(task_deferred=task_deferred) -class AtomicBoolean: - def __init__(self, initial: bool = False): - self._value = initial - self._lock = Lock() - - def set(self, value: bool): - with self._lock: - self._value = value - - def get(self) -> bool: - with self._lock: - return self._value - - class IterableOperator(BaseOperator): """Object representing an iterable operator in a DAG.""" @@ -336,7 +322,6 @@ def __init__( if not self.max_active_tis_per_dag: self.max_active_tis_per_dag = os.cpu_count() or 1 self._semaphore = Semaphore(self.max_active_tis_per_dag) - self._completed = AtomicBoolean() self._number_of_tasks: int = 0 XComArg.apply_upstream_relationship(self, self.expand_input.value) @@ -402,15 +387,11 @@ def _resolve_expand_input(self, context: Context, session: Session): self.log.debug("resolved_input: %s", resolved_input) if isinstance(resolved_input, _MapResult): - self._mapped_kwargs = map( - lambda value: self._resolve(value=value, context=context, session=session), resolved_input - ) + self._mapped_kwargs = map(lambda value: self._resolve(value=value, context=context, session=session), resolved_input) else: - self._mapped_kwargs = iter( - self._lazy_mapped_kwargs(input=resolved_input, context=context, session=session) - ) + self._mapped_kwargs = iter(self._lazy_mapped_kwargs(input=resolved_input, context=context, session=session)) - self.log.info("mapped_kwargs: %s", self._mapped_kwargs) + self.log.debug("mapped_kwargs: %s", self._mapped_kwargs) def render_template_fields( self, @@ -425,87 +406,93 @@ def _run_tasks( context: Context, tasks: Iterable[TaskInstance], ) -> Iterable[Any] | None: - exception: BaseException | None = None + exception: Exception | None = None reschedule_date = timezone.utcnow() failed_tasks: list[TaskInstance] = [] task_queue = Queue() # Task Producer def task_producer(): - self.log.info("Started producing tasks: %s", tasks) - for task in tasks: - self.log.info("Created task: %s", task) - task_queue.put(task) - self._completed.set(True) - self.log.info("Finished producing tasks: %s", self._completed.get()) + try: + self.log.info("Started producing tasks") + for task in tasks: + self.log.info("Created task: %s", task) + task_queue.put(task) + self.log.info("Finished producing tasks") + except Exception as e: + self.log.error("Exception in task_producer: %s", e) + task_queue.put(e) + + producer_thread = Thread(target=task_producer) + producer_thread.start() + + self.log.debug("task_queue : %s", task_queue.qsize()) # Task Consumer - def task_consumer(): - self.log.info("Started consuming tasks") - nonlocal exception, reschedule_date - - self.log.info("task_queue : %s", task_queue.qsize()) - self.log.info("completed: %s", self._completed.get()) - - while not task_queue.empty() or not self._completed.get(): - deferred_tasks: list[Future] = [] - - with ThreadPool(processes=self.max_active_tis_per_dag) as pool: - futures = {} - - while not task_queue.empty(): - task = task_queue.get(timeout=1) - self.log.info("received task : %s", task) - futures[pool.apply_async(self._run_operator, (context, task))] = task - - self.log.debug("futures: %s", futures) - - for future in filter(ApplyResult.ready, list(futures.keys())): - task = futures.pop(future) - try: - result = future.get(timeout=self.timeout) - if isinstance(result, TaskDeferred): - deferred_tasks.append( - ensure_future( - 
self._run_deferrable( - context=context, - task_instance=task, - task_deferred=result, - ) - ) - ) - except TimeoutError as e: - self.log.warning("A timeout occurred for task_id %s", task.task_id) - if task.next_try_number > self.retries: - exception = AirflowTaskTimeout(e) + while not task_queue.empty() or producer_thread.is_alive(): + deferred_tasks: list[Future] = [] + + with ThreadPool(processes=self.max_active_tis_per_dag) as pool: + futures = {} + + while not task_queue.empty(): + task = task_queue.get(timeout=1) + self.log.info("received task : %s", task) + + # Check if the task is an exception and stop immediately + if isinstance(task, Exception): + producer_thread.join() + raise AirflowException("An exception occurred in the producer thread") from task + + futures[pool.apply_async(self._run_operator, (context, task))] = task + + self.log.debug("futures: %s", futures) + + for future in filter(ApplyResult.ready, list(futures.keys())): + task = futures.pop(future) + try: + result = future.get(timeout=self.timeout) + if isinstance(result, TaskDeferred): + deferred_tasks.append( + ensure_future( + self._run_deferrable( + context=context, + task_instance=task, + task_deferred=result, + ), + ) + ) + except TimeoutError as e: + self.log.warning("A timeout occurred for task_id %s", task.task_id) + if task.next_try_number > self.retries: + exception = AirflowTaskTimeout(e) + else: + reschedule_date = task.next_retry_datetime() + failed_tasks.append(task) + except AirflowRescheduleTaskInstanceException as e: + reschedule_date = e.reschedule_date + failed_tasks.append(e.task) + except AirflowException as e: + self.log.error("An exception occurred for task_id %s", task.task_id) + exception = e + + if deferred_tasks: + self.log.info("Running %s deferred tasks", len(deferred_tasks)) + + with event_loop() as loop: + for result in loop.run_until_complete( + gather(*deferred_tasks, return_exceptions=True) + ): + self.log.debug("result: %s", result) + + if isinstance(result, Exception): + if isinstance(result, AirflowRescheduleTaskInstanceException): + reschedule_date = result.reschedule_date + failed_tasks.append(result.task) else: - reschedule_date = task.next_retry_datetime() - failed_tasks.append(task) - except AirflowRescheduleTaskInstanceException as e: - reschedule_date = e.reschedule_date - failed_tasks.append(e.task) - except AirflowException as e: - self.log.error("An exception occurred for task_id %s", task.task_id) - exception = e - - if deferred_tasks: - self.log.info("Running %s deferred tasks", len(deferred_tasks)) - - with event_loop() as loop: - for result in loop.run_until_complete( - gather(*deferred_tasks, return_exceptions=True) - ): - self.log.debug("result: %s", result) - if isinstance(result, Exception): - if isinstance(result, AirflowRescheduleTaskInstanceException): - reschedule_date = result.reschedule_date - failed_tasks.append(result.task) - else: - exception = result - - with ThreadPool(processes=2) as executor: - executor.apply_async(task_producer) - executor.apply(task_consumer) + exception = result + + producer_thread.join() if not failed_tasks: if exception: @@ -534,7 +521,6 @@ def task_consumer(): len(failed_tasks), delay_seconds, ) - sleep(delay_seconds) return self._run_tasks(context, failed_tasks) @@ -570,9 +556,7 @@ def execute(self, context: Context): return self._run_tasks( context=context, tasks=map( - lambda mapped_kwargs: self._create_task( - context["ti"].run_id, mapped_kwargs[0], mapped_kwargs[1] - ), + lambda mapped_kwargs: 
self._create_task(context["ti"].run_id, mapped_kwargs[0], mapped_kwargs[1]), enumerate(self._mapped_kwargs), ), ) From ad4372843456174063bbc3540950d19228b313b2 Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 24 Mar 2025 20:21:46 +0100 Subject: [PATCH 92/97] refactor: Improved IterableOperator and DeferredIterable --- airflow/models/iterable.py | 28 +++++----- airflow/models/iterableoperator.py | 85 +++++++++++++++--------------- 2 files changed, 59 insertions(+), 54 deletions(-) diff --git a/airflow/models/iterable.py b/airflow/models/iterable.py index 7630da1cdf3b6..171bfaa262478 100644 --- a/airflow/models/iterable.py +++ b/airflow/models/iterable.py @@ -81,20 +81,26 @@ def __next__(self): self.index += 1 return result - # No more results; attempt to load the next page using the trigger + if not self.trigger: + raise StopIteration + self.log.info("No more results. Running trigger: %s", self.trigger) try: event = self.loop.run_until_complete(run_trigger(self.trigger)) - iterator = getattr(self.operator, self.next_method)(self.context, event.payload) + next_method = getattr(self.operator, self.next_method) + self.log.debug("Triggering next method: %s", self.next_method) + results = next_method(self.context, event.payload) except Exception as e: raise AirflowException from e - if not iterator: - raise StopIteration + if isinstance(results, DeferredIterable): + self.trigger = results.trigger + self.results.extend(results.results) + else: + self.trigger = None + self.results.extend(results) - self.trigger = iterator.trigger - self.results.extend(iterator.results) self.index += 1 return self.results[-1] @@ -115,21 +121,19 @@ def serialize(self): """Ensure the object is JSON serializable.""" return { "results": self.results, - "trigger": self.trigger.serialize(), - "dag_fileloc": self.operator.dag.fileloc, + "trigger": self.trigger.serialize() if self.trigger else None, "dag_id": self.operator.dag_id, "task_id": self.operator.task_id, "next_method": self.next_method, } @classmethod - def get_operator_from_dag(cls, dag_fileloc: str, dag_id: str, task_id: str) -> BaseOperator: + def get_operator_from_dag(cls, dag_id: str, task_id: str) -> BaseOperator: """Loads a DAG using DagBag and gets the operator by task_id.""" from airflow.models import DagBag - dag_bag = DagBag(dag_folder=None) # Avoid loading all DAGs - dag_bag.process_file(dag_fileloc) + dag_bag = DagBag(dag_folder=None) return dag_bag.dags[dag_id].get_task(task_id) @classmethod @@ -138,7 +142,7 @@ def deserialize(cls, data: dict, version: int): trigger_class = import_string(data["trigger"][0]) trigger = trigger_class(**data["trigger"][1]) - operator = cls.get_operator_from_dag(data["dag_fileloc"], data["dag_id"], data["task_id"]) + operator = cls.get_operator_from_dag(data["dag_id"], data["task_id"]) return DeferredIterable( results=data["results"], trigger=trigger, operator=operator, next_method=data["next_method"] ) diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py index bc30f200445b8..38b86708c064b 100644 --- a/airflow/models/iterableoperator.py +++ b/airflow/models/iterableoperator.py @@ -31,7 +31,7 @@ from queue import Queue from threading import Thread from time import sleep -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Coroutine from airflow.exceptions import ( AirflowException, @@ -300,8 +300,9 @@ class IterableOperator(BaseOperator): "expand_input", "partial_kwargs", "_log", + "_task_queue", + "_producer_thread", "_semaphore", - "_completed", ) def 
__init__( @@ -321,6 +322,8 @@ def __init__( self._mapped_kwargs: Iterable[dict] = [] if not self.max_active_tis_per_dag: self.max_active_tis_per_dag = os.cpu_count() or 1 + self._task_queue = Queue() + self._producer_thread: Thread | None = None self._semaphore = Semaphore(self.max_active_tis_per_dag) self._number_of_tasks: int = 0 XComArg.apply_upstream_relationship(self, self.expand_input.value) @@ -404,44 +407,26 @@ def render_template_fields( def _run_tasks( self, context: Context, - tasks: Iterable[TaskInstance], ) -> Iterable[Any] | None: exception: Exception | None = None reschedule_date = timezone.utcnow() failed_tasks: list[TaskInstance] = [] - task_queue = Queue() - # Task Producer - def task_producer(): - try: - self.log.info("Started producing tasks") - for task in tasks: - self.log.info("Created task: %s", task) - task_queue.put(task) - self.log.info("Finished producing tasks") - except Exception as e: - self.log.error("Exception in task_producer: %s", e) - task_queue.put(e) - - producer_thread = Thread(target=task_producer) - producer_thread.start() - - self.log.debug("task_queue : %s", task_queue.qsize()) + self.log.info("task_queue : %s", self._task_queue.qsize()) # Task Consumer - while not task_queue.empty() or producer_thread.is_alive(): - deferred_tasks: list[Future] = [] + while not self._task_queue.empty() or self._producer_thread.is_alive(): + futures = {} + deferred_tasks: list[Coroutine[Any, Any, Any]] = [] with ThreadPool(processes=self.max_active_tis_per_dag) as pool: - futures = {} - - while not task_queue.empty(): - task = task_queue.get(timeout=1) - self.log.info("received task : %s", task) + while not self._task_queue.empty(): + task = self._task_queue.get(timeout=1) + self.log.debug("received task : %s", task) # Check if the task is an exception and stop immediately if isinstance(task, Exception): - producer_thread.join() + self._producer_thread.join() raise AirflowException("An exception occurred in the producer thread") from task futures[pool.apply_async(self._run_operator, (context, task))] = task @@ -452,15 +437,16 @@ def task_producer(): task = futures.pop(future) try: result = future.get(timeout=self.timeout) + + self.log.debug("result: %s", result) + if isinstance(result, TaskDeferred): deferred_tasks.append( - ensure_future( - self._run_deferrable( - context=context, - task_instance=task, - task_deferred=result, - ), - ) + self._run_deferrable( + context=context, + task_instance=task, + task_deferred=result, + ), ) except TimeoutError as e: self.log.warning("A timeout occurred for task_id %s", task.task_id) @@ -492,7 +478,7 @@ def task_producer(): else: exception = result - producer_thread.join() + self._producer_thread.join() if not failed_tasks: if exception: @@ -509,6 +495,9 @@ def task_producer(): ) return None + for failed_task in failed_tasks: + self._task_queue.put(failed_task) + # Calculate delay before the next retry delay = reschedule_date - timezone.utcnow() delay_seconds = ceil(delay.total_seconds()) @@ -523,7 +512,7 @@ def task_producer(): ) sleep(delay_seconds) - return self._run_tasks(context, failed_tasks) + return self._run_tasks(context) def _run_operator(self, context: Context, task_instance: TaskInstance): try: @@ -553,10 +542,22 @@ def _create_task(self, run_id: str, index: int, mapped_kwargs: dict) -> TaskInst ) def execute(self, context: Context): - return self._run_tasks( - context=context, - tasks=map( + def task_producer(): + tasks = map( lambda mapped_kwargs: self._create_task(context["ti"].run_id, mapped_kwargs[0], 
mapped_kwargs[1]), enumerate(self._mapped_kwargs), - ), - ) + ) + + try: + self.log.info("Started producing tasks") + for task in tasks: + self._task_queue.put(task) + self.log.info("Finished producing tasks") + except Exception as e: + self.log.error("Exception in task_producer: %s", e) + self._task_queue.put(e) + + self._producer_thread = Thread(target=task_producer) + self._producer_thread.start() + + return self._run_tasks(context=context) From d9d3daca766c84c603d27f33b5a02361770eddf0 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 26 Mar 2025 18:03:25 +0100 Subject: [PATCH 93/97] refactor: Made next_kwargs optional for next_callable method in BaseOperator --- airflow/models/iterableoperator.py | 104 ++++++++++-------- .../airflow/sdk/definitions/baseoperator.py | 17 +-- 2 files changed, 67 insertions(+), 54 deletions(-) diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py index 38b86708c064b..32f5c61fc961d 100644 --- a/airflow/models/iterableoperator.py +++ b/airflow/models/iterableoperator.py @@ -31,7 +31,7 @@ from queue import Queue from threading import Thread from time import sleep -from typing import TYPE_CHECKING, Any, Coroutine +from typing import TYPE_CHECKING, Any from airflow.exceptions import ( AirflowException, @@ -75,6 +75,8 @@ def event_loop() -> Generator[AbstractEventLoop, None, None]: try: try: loop = asyncio.get_event_loop() + if loop.is_closed(): + raise RuntimeError except RuntimeError: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) @@ -302,7 +304,7 @@ class IterableOperator(BaseOperator): "_log", "_task_queue", "_producer_thread", - "_semaphore", + # "_semaphore", ) def __init__( @@ -324,7 +326,6 @@ def __init__( self.max_active_tis_per_dag = os.cpu_count() or 1 self._task_queue = Queue() self._producer_thread: Thread | None = None - self._semaphore = Semaphore(self.max_active_tis_per_dag) self._number_of_tasks: int = 0 XComArg.apply_upstream_relationship(self, self.expand_input.value) @@ -390,9 +391,13 @@ def _resolve_expand_input(self, context: Context, session: Session): self.log.debug("resolved_input: %s", resolved_input) if isinstance(resolved_input, _MapResult): - self._mapped_kwargs = map(lambda value: self._resolve(value=value, context=context, session=session), resolved_input) + self._mapped_kwargs = map( + lambda value: self._resolve(value=value, context=context, session=session), resolved_input + ) else: - self._mapped_kwargs = iter(self._lazy_mapped_kwargs(input=resolved_input, context=context, session=session)) + self._mapped_kwargs = iter( + self._lazy_mapped_kwargs(input=resolved_input, context=context, session=session) + ) self.log.debug("mapped_kwargs: %s", self._mapped_kwargs) @@ -417,7 +422,7 @@ def _run_tasks( # Task Consumer while not self._task_queue.empty() or self._producer_thread.is_alive(): futures = {} - deferred_tasks: list[Coroutine[Any, Any, Any]] = [] + deferred_tasks: list[Future] = [] with ThreadPool(processes=self.max_active_tis_per_dag) as pool: while not self._task_queue.empty(): @@ -431,44 +436,48 @@ def _run_tasks( futures[pool.apply_async(self._run_operator, (context, task))] = task - self.log.debug("futures: %s", futures) + self.log.debug("futures: %s", futures) - for future in filter(ApplyResult.ready, list(futures.keys())): - task = futures.pop(future) - try: - result = future.get(timeout=self.timeout) - - self.log.debug("result: %s", result) - - if isinstance(result, TaskDeferred): - deferred_tasks.append( - self._run_deferrable( - context=context, - task_instance=task, 
- task_deferred=result, - ), - ) - except TimeoutError as e: - self.log.warning("A timeout occurred for task_id %s", task.task_id) - if task.next_try_number > self.retries: - exception = AirflowTaskTimeout(e) - else: - reschedule_date = task.next_retry_datetime() - failed_tasks.append(task) - except AirflowRescheduleTaskInstanceException as e: - reschedule_date = e.reschedule_date - failed_tasks.append(e.task) - except AirflowException as e: - self.log.error("An exception occurred for task_id %s", task.task_id) - exception = e - - if deferred_tasks: - self.log.info("Running %s deferred tasks", len(deferred_tasks)) - - with event_loop() as loop: - for result in loop.run_until_complete( - gather(*deferred_tasks, return_exceptions=True) - ): + with event_loop() as loop: + semaphore = Semaphore(self.max_active_tis_per_dag, loop=loop) + + for future in filter(ApplyResult.ready, list(futures.keys())): + task = futures.pop(future) + try: + result = future.get(timeout=self.timeout) + + self.log.debug("result: %s", result) + + if isinstance(result, TaskDeferred): + deferred_tasks.append( + ensure_future( + self._run_deferrable( + semaphore=semaphore, + context=context, + task_instance=task, + task_deferred=result, + ), + loop=loop, + ), + ) + except TimeoutError as e: + self.log.warning("A timeout occurred for task_id %s", task.task_id) + if task.next_try_number > self.retries: + exception = AirflowTaskTimeout(e) + else: + reschedule_date = task.next_retry_datetime() + failed_tasks.append(task) + except AirflowRescheduleTaskInstanceException as e: + reschedule_date = e.reschedule_date + failed_tasks.append(e.task) + except AirflowException as e: + self.log.error("An exception occurred for task_id %s", task.task_id) + exception = e + + if deferred_tasks: + self.log.info("Running %s deferred tasks", len(deferred_tasks)) + + for result in loop.run_until_complete(gather(*deferred_tasks, return_exceptions=True)): self.log.debug("result: %s", result) if isinstance(result, Exception): @@ -525,11 +534,12 @@ def _run_operator(self, context: Context, task_instance: TaskInstance): return task_deferred async def _run_deferrable( - self, context: Context, task_instance: TaskInstance, task_deferred: TaskDeferred + self, semaphore: Semaphore, context: Context, task_instance: TaskInstance, task_deferred: TaskDeferred ): - async with self._semaphore: + async with semaphore: async with TriggerExecutor(context=context, task_instance=task_instance) as executor: result = await executor.run(task_deferred) + self.log.info("_run_deferrable: %s", result) if self.do_xcom_push: task_instance.xcom_push(key=XCOM_RETURN_KEY, value=result) return result @@ -544,7 +554,9 @@ def _create_task(self, run_id: str, index: int, mapped_kwargs: dict) -> TaskInst def execute(self, context: Context): def task_producer(): tasks = map( - lambda mapped_kwargs: self._create_task(context["ti"].run_id, mapped_kwargs[0], mapped_kwargs[1]), + lambda mapped_kwargs: self._create_task( + context["ti"].run_id, mapped_kwargs[0], mapped_kwargs[1] + ), enumerate(self._mapped_kwargs), ) diff --git a/task-sdk/src/airflow/sdk/definitions/baseoperator.py b/task-sdk/src/airflow/sdk/definitions/baseoperator.py index 863af4f7e6a69..5834a95c3ef68 100644 --- a/task-sdk/src/airflow/sdk/definitions/baseoperator.py +++ b/task-sdk/src/airflow/sdk/definitions/baseoperator.py @@ -1580,21 +1580,22 @@ def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, execute_callable = getattr(self, next_method) return execute_callable(context, 
**next_kwargs)
 
-    @classmethod
-    def next_callable(cls, operator, next_method, next_kwargs) -> Callable[..., Any]:
+    def next_callable(
+        self, next_method: str, next_kwargs: dict[str, Any] | None = None
+    ) -> Callable[..., Any]:
+        """Get the next callable from given operator."""
         from airflow.exceptions import TaskDeferralError
 
-        """Get the next callable from given operator."""
         # __fail__ is a special signal value for next_method that indicates
         # this task was scheduled specifically to fail.
         if next_method == "__fail__":
             next_kwargs = next_kwargs or {}
             traceback = next_kwargs.get("traceback")
             if traceback is not None:
-                cls.logger().error("Trigger failed:\n%s", "\n".join(traceback))
+                self.log.error("Trigger failed:\n%s", "\n".join(traceback))
             raise TaskDeferralError(next_kwargs.get("error", "Unknown"))
         # Grab the callable off the Operator/Task and add in any kwargs
-        execute_callable = getattr(operator, next_method)
+        execute_callable = getattr(self, next_method)
         if next_kwargs:
             execute_callable = partial(execute_callable, **next_kwargs)
         return execute_callable
@@ -1835,9 +1836,9 @@ def chain_linear(*elements: DependencyMixin | Sequence[DependencyMixin]):
 
     E.g.: suppose you want precedence like so::
 
-            â•­â"€op2â"€â•® â•­â"€op4â"€â•®
-        op1â"€â"¤ â"œâ"€â"œâ"€op5â"€â"¤â"€op7
-            â•°-op3â"€â•¯ â•°-op6â"€â•¯
+            ╭─op2─╮ ╭─op4─╮
+        op1─┤     ├─├─op5─┤─op7
+            ╰-op3─╯ ╰-op6─╯
 
     Then you can accomplish like so::
 
From 07f98cc02ce639bcf8a03cc44145d3664d027cd5 Mon Sep 17 00:00:00 2001
From: David Blain
Date: Thu, 27 Mar 2025 15:52:32 +0100
Subject: [PATCH 94/97] refactor: Improved lazy evaluation of tasks passed to
 the _run_tasks method of the IterableOperator

---
 airflow/models/iterable.py                  | 144 +++++++--
 airflow/models/iterableoperator.py          | 281 +++++++-----------
 .../src/airflow/sdk/definitions/xcom_arg.py |   6 +-
 3 files changed, 221 insertions(+), 210 deletions(-)

diff --git a/airflow/models/iterable.py b/airflow/models/iterable.py
index 171bfaa262478..f8db7634c33de 100644
--- a/airflow/models/iterable.py
+++ b/airflow/models/iterable.py
@@ -19,19 +19,105 @@
 
 import asyncio
 from collections.abc import Iterator
-from typing import TYPE_CHECKING, Any
+from contextlib import contextmanager, suppress
+from typing import TYPE_CHECKING, Any, Generator, Sequence
 
 from airflow.exceptions import AirflowException
+
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.module_loading import import_string
-from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.xcom import XCOM_RETURN_KEY
+from airflow.serialization import serde
+
+try:
+    from airflow.sdk.definitions._internal.abstractoperator import Operator
+    from airflow.sdk.definitions.context import Context
+    from airflow.sdk.definitions.xcom_arg import XComArg, MapXComArg
+    from airflow.sdk.execution_time.xcom import XCom
+except ImportError:
+    # TODO: Remove once provider drops support for Airflow 2
+    from airflow.models.baseoperator import BaseOperator as Operator
+    from airflow.models.xcom_arg import XComArg, MapXComArg
+    from airflow.models import XCom
+    from airflow.utils.context import Context
 
 if TYPE_CHECKING:
+    from asyncio import AbstractEventLoop
     from sqlalchemy.orm import Session
 
-    from airflow.models import BaseOperator
     from airflow.triggers.base import BaseTrigger, run_trigger
-    from airflow.utils.context import Context
+
+
+@contextmanager
+def event_loop() -> Generator[AbstractEventLoop, None, None]:
+    new_event_loop = False
+    loop = None
+    try:
+        try:
+            loop = 
asyncio.get_event_loop() + if loop.is_closed(): + raise RuntimeError + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + new_event_loop = True + yield loop + finally: + if new_event_loop and loop is not None: + with suppress(AttributeError): + loop.close() + + +class XComIterable(Iterator, Sequence): + """An iterable that lazily fetches XCom values one by one instead of loading all at once.""" + + def __init__(self, task_id: str, dag_id: str, run_id: str, length: int): + self.task_id = task_id + self.dag_id = dag_id + self.run_id = run_id + self.length = length + self.index = 0 + + def __iter__(self) -> Iterator: + self.index = 0 + return self + + def __next__(self): + if self.index >= self.length: + raise StopIteration + + value = self[self.index] + self.index += 1 + return value + + def __len__(self): + return self.length + + def __getitem__(self, index: int): + """Allows direct indexing, making this work like a sequence.""" + if not (0 <= index < self.length): + raise IndexError + + return XCom.get_one( + key=XCOM_RETURN_KEY, + dag_id=self.dag_id, + task_id=f"{self.task_id}_{index}", + run_id=self.run_id, + ) + + def serialize(self): + """Ensure the object is JSON serializable.""" + return { + "task_id": self.task_id, + "dag_id": self.dag_id, + "run_id": self.run_id, + "length": self.length, + } + + @classmethod + def deserialize(cls, data: dict, version: int): + """Ensure the object is JSON deserializable.""" + return XComIterable(**data) class DeferredIterable(Iterator, LoggingMixin): @@ -41,7 +127,7 @@ def __init__( self, results: list[Any] | Any, trigger: BaseTrigger, - operator: BaseOperator, + operator: Operator, next_method: str, context: Context | None = None, ): @@ -52,26 +138,16 @@ def __init__( self.next_method = next_method self.context = context self.index = 0 - self._loop = None - @provide_session - def resolve( - self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True - ) -> DeferredIterable: + def resolve(self, context: Context) -> DeferredIterable: return DeferredIterable( results=self.results, trigger=self.trigger, operator=self.operator, next_method=self.next_method, - context=context, + context=context ) - @property - def loop(self): - if not self._loop: - self._loop = asyncio.new_event_loop() - return self._loop - def __iter__(self) -> Iterator: return self @@ -87,11 +163,13 @@ def __next__(self): self.log.info("No more results. 
Running trigger: %s", self.trigger) try: - event = self.loop.run_until_complete(run_trigger(self.trigger)) - next_method = getattr(self.operator, self.next_method) - self.log.debug("Triggering next method: %s", self.next_method) - results = next_method(self.context, event.payload) + with event_loop() as loop: + event = loop.run_until_complete(run_trigger(self.trigger)) + next_method = getattr(self.operator, self.next_method) + self.log.debug("Triggering next method: %s", self.next_method) + results = next_method(self.context, event.payload) except Exception as e: + self.log.exception(e) raise AirflowException from e if isinstance(results, DeferredIterable): @@ -113,27 +191,28 @@ def __getitem__(self, index: int): return self.results[index] - def __del__(self): - if self._loop: - self._loop.close() - def serialize(self): """Ensure the object is JSON serializable.""" return { "results": self.results, "trigger": self.trigger.serialize() if self.trigger else None, + "dag_fileloc": self.operator.dag.fileloc, "dag_id": self.operator.dag_id, "task_id": self.operator.task_id, "next_method": self.next_method, } @classmethod - def get_operator_from_dag(cls, dag_id: str, task_id: str) -> BaseOperator: + def get_operator_from_dag(cls, dag_fileloc: str, dag_id: str, task_id: str) -> Operator: """Loads a DAG using DagBag and gets the operator by task_id.""" from airflow.models import DagBag - dag_bag = DagBag(dag_folder=None) + dag_bag = DagBag(dag_folder=None) # Avoid loading all DAGs + processed_dags = dag_bag.process_file(dag_fileloc) + cls.logger().info("processed_dags: %s", processed_dags) + cls.logger().info("dag_bag: %s", dag_bag) + cls.logger().info("dags: %s", dag_bag.dags) return dag_bag.dags[dag_id].get_task(task_id) @classmethod @@ -142,7 +221,16 @@ def deserialize(cls, data: dict, version: int): trigger_class = import_string(data["trigger"][0]) trigger = trigger_class(**data["trigger"][1]) - operator = cls.get_operator_from_dag(data["dag_id"], data["task_id"]) + operator = cls.get_operator_from_dag(data["dag_fileloc"], data["dag_id"], data["task_id"]) return DeferredIterable( results=data["results"], trigger=trigger, operator=operator, next_method=data["next_method"] ) + + +# This is a workaround to allow the DeferredIterable and XComIterable classes to be serialized +serde._extra_allowed = serde._extra_allowed.union( + { + f"{XComIterable.__module__}.{XComIterable.__class__.__name__}", + f"{DeferredIterable.__module__}.{DeferredIterable.__class__.__name__}", + } +) diff --git a/airflow/models/iterableoperator.py b/airflow/models/iterableoperator.py index 32f5c61fc961d..7a511b014612d 100644 --- a/airflow/models/iterableoperator.py +++ b/airflow/models/iterableoperator.py @@ -17,21 +17,20 @@ # under the License. 
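+# _run_tasks (below) consumes tasks lazily in windows of
+# max_active_tis_per_dag * 2 using more_itertools.ichunked, replacing the
+# queue fed by a producer thread.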
from __future__ import annotations -import asyncio import logging import os from abc import abstractmethod -from asyncio import AbstractEventLoop, Future, Semaphore, ensure_future, gather -from collections.abc import Generator, Iterable, Iterator, Sequence -from contextlib import contextmanager, suppress +from asyncio import Semaphore, gather +from collections.abc import Iterable, Iterator, Sequence +from contextlib import suppress from datetime import timedelta +from json import JSONDecodeError from math import ceil +from more_itertools import ichunked from multiprocessing import TimeoutError from multiprocessing.pool import ApplyResult, ThreadPool -from queue import Queue -from threading import Thread from time import sleep -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Coroutine from airflow.exceptions import ( AirflowException, @@ -46,12 +45,11 @@ _needs_run_time_resolution, is_mappable, ) +from airflow.models.iterable import event_loop, XComIterable from airflow.models.taskinstance import TaskInstance -from airflow.models.xcom import XCom from airflow.sdk.definitions._internal.abstractoperator import Operator from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.xcom_arg import XComArg, _MapResult -from airflow.serialization import serde from airflow.triggers.base import run_trigger from airflow.utils import timezone from airflow.utils.context import context_get_outlet_events @@ -65,80 +63,6 @@ import jinja2 from sqlalchemy.orm import Session -serde._extra_allowed = serde._extra_allowed.union({"infrabel.operators.iterableoperator.XComIterable"}) - - -@contextmanager -def event_loop() -> Generator[AbstractEventLoop, None, None]: - new_event_loop = False - loop = None - try: - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - raise RuntimeError - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - new_event_loop = True - yield loop - finally: - if new_event_loop and loop is not None: - with suppress(AttributeError): - loop.close() - - -class XComIterable(Iterator, Sequence): - """An iterable that lazily fetches XCom values one by one instead of loading all at once.""" - - def __init__(self, task_id: str, dag_id: str, run_id: str, length: int): - self.task_id = task_id - self.dag_id = dag_id - self.run_id = run_id - self.length = length - self.index = 0 - - def __iter__(self) -> Iterator: - self.index = 0 - return self - - def __next__(self): - if self.index >= self.length: - raise StopIteration - - value = self[self.index] - self.index += 1 - return value - - def __len__(self): - return self.length - - def __getitem__(self, index: int): - """Allows direct indexing, making this work like a sequence.""" - if not (0 <= index < self.length): - raise IndexError - - return XCom.get_one( - key=XCOM_RETURN_KEY, - dag_id=self.dag_id, - task_id=f"{self.task_id}_{index}", - run_id=self.run_id, - ) - - def serialize(self): - """Ensure the object is JSON serializable.""" - return { - "task_id": self.task_id, - "dag_id": self.dag_id, - "run_id": self.run_id, - "length": self.length, - } - - @classmethod - def deserialize(cls, data: dict, version: int): - """Ensure the object is JSON deserializable.""" - return XComIterable(**data) - class TaskExecutor(LoggingMixin): """Base class to run an operator or trigger with given task context and task instance.""" @@ -192,7 +116,9 @@ def __enter__(self): self.operator.render_template_fields(context=self.context) 
self.operator.pre_execute(context=self.context) self.task_instance.set_state(TaskInstanceState.SCHEDULED) - self.task_instance._run_execute_callback(context=self.context, task=self.operator) + self.task_instance._run_execute_callback( + context=self.context, task=self.operator + ) return self async def __aenter__(self): @@ -280,7 +206,9 @@ async def run(self, task_deferred: TaskDeferred): if task_deferred.method_name: try: - next_method = self.operator.next_callable(task_deferred.method_name, task_deferred.kwargs) + next_method = self.operator.next_callable( + task_deferred.method_name, task_deferred.kwargs + ) outlet_events = context_get_outlet_events(self.context) return ExecutionCallableRunner( func=next_method, @@ -302,9 +230,7 @@ class IterableOperator(BaseOperator): "expand_input", "partial_kwargs", "_log", - "_task_queue", - "_producer_thread", - # "_semaphore", + "_semaphore", ) def __init__( @@ -320,12 +246,13 @@ def __init__( self._operator_class = operator_class self.expand_input = expand_input self.partial_kwargs = partial_kwargs or {} - self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout + self.timeout = ( + timeout.total_seconds() if isinstance(timeout, timedelta) else timeout + ) self._mapped_kwargs: Iterable[dict] = [] if not self.max_active_tis_per_dag: self.max_active_tis_per_dag = os.cpu_count() or 1 - self._task_queue = Queue() - self._producer_thread: Thread | None = None + self._semaphore = Semaphore(self.max_active_tis_per_dag) self._number_of_tasks: int = 0 XComArg.apply_upstream_relationship(self, self.expand_input.value) @@ -384,20 +311,18 @@ def _resolve_expand_input(self, context: Context, session: Session): self.log.debug("resolve_expand_input: %s", self.expand_input) if isinstance(self.expand_input.value, XComArg): - resolved_input = self.expand_input.value.resolve(context=context, session=session) + resolved_input = self.expand_input.value.resolve( + context=context, session=session + ) else: resolved_input = self.expand_input.value self.log.debug("resolved_input: %s", resolved_input) if isinstance(resolved_input, _MapResult): - self._mapped_kwargs = map( - lambda value: self._resolve(value=value, context=context, session=session), resolved_input - ) + self._mapped_kwargs = map(lambda value: self._resolve(value=value, context=context, session=session), resolved_input) else: - self._mapped_kwargs = iter( - self._lazy_mapped_kwargs(input=resolved_input, context=context, session=session) - ) + self._mapped_kwargs = iter(self._lazy_mapped_kwargs(input=resolved_input, context=context, session=session)) self.log.debug("mapped_kwargs: %s", self._mapped_kwargs) @@ -412,37 +337,27 @@ def render_template_fields( def _run_tasks( self, context: Context, + tasks: Iterator[TaskInstance], ) -> Iterable[Any] | None: - exception: Exception | None = None + exception: BaseException | None = None reschedule_date = timezone.utcnow() + futures: dict[ApplyResult, TaskInstance] = {} failed_tasks: list[TaskInstance] = [] + chunked_tasks: Iterator[Iterable[TaskInstance]] = ichunked(tasks, (self.max_active_tis_per_dag * 2)) - self.log.info("task_queue : %s", self._task_queue.qsize()) - - # Task Consumer - while not self._task_queue.empty() or self._producer_thread.is_alive(): - futures = {} - deferred_tasks: list[Future] = [] + with ThreadPool(processes=self.max_active_tis_per_dag) as pool: + for task in next(chunked_tasks, []): + future = pool.apply_async(self._run_operator, (context, task)) + futures[future] = task - with 
ThreadPool(processes=self.max_active_tis_per_dag) as pool: - while not self._task_queue.empty(): - task = self._task_queue.get(timeout=1) - self.log.debug("received task : %s", task) + while futures: + self.log.info("Number of remaining futures: %s", len(futures)) - # Check if the task is an exception and stop immediately - if isinstance(task, Exception): - self._producer_thread.join() - raise AirflowException("An exception occurred in the producer thread") from task - - futures[pool.apply_async(self._run_operator, (context, task))] = task - - self.log.debug("futures: %s", futures) - - with event_loop() as loop: - semaphore = Semaphore(self.max_active_tis_per_dag, loop=loop) + deferred_tasks: list[Coroutine[Any, Any, Any]] = [] for future in filter(ApplyResult.ready, list(futures.keys())): task = futures.pop(future) + try: result = future.get(timeout=self.timeout) @@ -450,18 +365,16 @@ def _run_tasks( if isinstance(result, TaskDeferred): deferred_tasks.append( - ensure_future( - self._run_deferrable( - semaphore=semaphore, - context=context, - task_instance=task, - task_deferred=result, - ), - loop=loop, - ), + self._run_deferrable( + context=context, + task_instance=task, + task_deferred=result, + ) ) except TimeoutError as e: - self.log.warning("A timeout occurred for task_id %s", task.task_id) + self.log.warning( + "A timeout occurred for task_id %s", task.task_id + ) if task.next_try_number > self.retries: exception = AirflowTaskTimeout(e) else: @@ -471,30 +384,38 @@ def _run_tasks( reschedule_date = e.reschedule_date failed_tasks.append(e.task) except AirflowException as e: - self.log.error("An exception occurred for task_id %s", task.task_id) + self.log.error( + "An exception occurred for task_id %s", task.task_id + ) exception = e + for task in next(chunked_tasks, []): + future = pool.apply_async(self._run_operator, (context, task)) + futures[future] = task + if deferred_tasks: self.log.info("Running %s deferred tasks", len(deferred_tasks)) - for result in loop.run_until_complete(gather(*deferred_tasks, return_exceptions=True)): - self.log.debug("result: %s", result) - - if isinstance(result, Exception): - if isinstance(result, AirflowRescheduleTaskInstanceException): - reschedule_date = result.reschedule_date - failed_tasks.append(result.task) - else: - exception = result - - self._producer_thread.join() + with event_loop() as loop: + for result in loop.run_until_complete( + gather(*deferred_tasks, return_exceptions=True) + ): + self.log.debug("result: %s", result) + + if isinstance(result, Exception): + if isinstance( + result, AirflowRescheduleTaskInstanceException + ): + reschedule_date = result.reschedule_date + failed_tasks.append(result.task) + else: + exception = result + # elif futures: + # sleep(min(len(futures), os.cpu_count())) if not failed_tasks: if exception: raise exception - - self.log.info("Finished consuming tasks: %s", self._number_of_tasks) - if self.do_xcom_push: return XComIterable( task_id=self.task_id, @@ -504,9 +425,6 @@ def _run_tasks( ) return None - for failed_task in failed_tasks: - self._task_queue.put(failed_task) - # Calculate delay before the next retry delay = reschedule_date - timezone.utcnow() delay_seconds = ceil(delay.total_seconds()) @@ -519,27 +437,44 @@ def _run_tasks( len(failed_tasks), delay_seconds, ) + sleep(delay_seconds) - return self._run_tasks(context) + return self._run_tasks(context, iter(failed_tasks)) + + @classmethod + def _xcom_pull(cls, task_instance: TaskInstance): + with suppress(JSONDecodeError): + return 
task_instance.xcom_pull(task_ids=task_instance.task_id, dag_id=task_instance.dag_id) + return None def _run_operator(self, context: Context, task_instance: TaskInstance): try: - with OperatorExecutor(context=context, task_instance=task_instance) as executor: - result = executor.run() - if self.do_xcom_push: - task_instance.xcom_push(key=XCOM_RETURN_KEY, value=result) - return result + result = self._xcom_pull(task_instance) + + self.log.debug("result: %s", result) + + if not result: + with OperatorExecutor( + context=context, task_instance=task_instance + ) as executor: + result = executor.run() + if self.do_xcom_push: + task_instance.xcom_push(key=XCOM_RETURN_KEY, value=result) + else: + self.log.info("Task instance %s already completed.", task_instance.task_id) + return result except TaskDeferred as task_deferred: return task_deferred async def _run_deferrable( - self, semaphore: Semaphore, context: Context, task_instance: TaskInstance, task_deferred: TaskDeferred + self, context: Context, task_instance: TaskInstance, task_deferred: TaskDeferred ): - async with semaphore: - async with TriggerExecutor(context=context, task_instance=task_instance) as executor: + async with self._semaphore: + async with TriggerExecutor( + context=context, task_instance=task_instance + ) as executor: result = await executor.run(task_deferred) - self.log.info("_run_deferrable: %s", result) if self.do_xcom_push: task_instance.xcom_push(key=XCOM_RETURN_KEY, value=result) return result @@ -552,24 +487,12 @@ def _create_task(self, run_id: str, index: int, mapped_kwargs: dict) -> TaskInst ) def execute(self, context: Context): - def task_producer(): - tasks = map( - lambda mapped_kwargs: self._create_task( - context["ti"].run_id, mapped_kwargs[0], mapped_kwargs[1] - ), - enumerate(self._mapped_kwargs), - ) - - try: - self.log.info("Started producing tasks") - for task in tasks: - self._task_queue.put(task) - self.log.info("Finished producing tasks") - except Exception as e: - self.log.error("Exception in task_producer: %s", e) - self._task_queue.put(e) - - self._producer_thread = Thread(target=task_producer) - self._producer_thread.start() - - return self._run_tasks(context=context) + return self._run_tasks( + context=context, + tasks=iter( + map( + lambda mapped_kwargs: self._create_task(context["ti"].run_id, mapped_kwargs[0], mapped_kwargs[1]), + enumerate(self._mapped_kwargs), + ) + ), + ) diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index aba2f8836fbc6..372447571df8d 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -33,6 +33,8 @@ from airflow.utils.trigger_rule import TriggerRule from airflow.utils.xcom import XCOM_RETURN_KEY +from airflow.models.iterable import DeferredIterable + if TYPE_CHECKING: from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.edges import EdgeModifier @@ -448,9 +450,7 @@ def resolve(self, context: Mapping[str, Any]) -> Any: value = self.arg.resolve(context) if not isinstance(value, (Sequence, Iterable, dict)): raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}") - if isinstance(value, Iterable) and hasattr( - value, "resolve" - ): # TODO: should check if it's DeferredIterable + if isinstance(value, DeferredIterable): value = value.resolve(context) return _MapResult(value, self.callables) From e70691af665eb0bbb56c71a6498ea407b46627eb Mon Sep 17 00:00:00 2001 From: David Blain Date: 
Fri, 28 Mar 2025 07:57:24 +0100 Subject: [PATCH 95/97] refactor: Improved run_tasks method of IterableOperator --- airflow-core/src/airflow/models/iterable.py | 5 +- .../src/airflow/models/iterableoperator.py | 79 +++++++++++-------- 2 files changed, 51 insertions(+), 33 deletions(-) diff --git a/airflow-core/src/airflow/models/iterable.py b/airflow-core/src/airflow/models/iterable.py index 5657aa35b5ec8..448925bb7580e 100644 --- a/airflow-core/src/airflow/models/iterable.py +++ b/airflow-core/src/airflow/models/iterable.py @@ -161,9 +161,10 @@ def __next__(self): try: with event_loop() as loop: + self.log.info("Running trigger: %s", self.trigger) event = loop.run_until_complete(run_trigger(self.trigger)) next_method = getattr(self.operator, self.next_method) - self.log.debug("Triggering next method: %s", self.next_method) + self.log.info("Triggering next method: %s", self.next_method) results = next_method(self.context, event.payload) except Exception as e: self.log.exception(e) @@ -202,6 +203,7 @@ def serialize(self): @classmethod def get_operator_from_dag(cls, dag_fileloc: str, dag_id: str, task_id: str) -> Operator: """Loads a DAG using DagBag and gets the operator by task_id.""" + from airflow.models import DagBag dag_bag = DagBag(dag_folder=None) # Avoid loading all DAGs @@ -214,6 +216,7 @@ def get_operator_from_dag(cls, dag_fileloc: str, dag_id: str, task_id: str) -> O @classmethod def deserialize(cls, data: dict, version: int): """Ensure the object is JSON deserializable.""" + trigger_class = import_string(data["trigger"][0]) trigger = trigger_class(**data["trigger"][1]) operator = cls.get_operator_from_dag(data["dag_fileloc"], data["dag_id"], data["task_id"]) diff --git a/airflow-core/src/airflow/models/iterableoperator.py b/airflow-core/src/airflow/models/iterableoperator.py index 1034c7a511cab..4d968fb5ed16f 100644 --- a/airflow-core/src/airflow/models/iterableoperator.py +++ b/airflow-core/src/airflow/models/iterableoperator.py @@ -117,7 +117,9 @@ def __enter__(self): self.operator.render_template_fields(context=self.context) self.operator.pre_execute(context=self.context) self.task_instance.set_state(TaskInstanceState.SCHEDULED) - self.task_instance._run_execute_callback(context=self.context, task=self.operator) + self.task_instance._run_execute_callback( + context=self.context, task=self.operator + ) return self async def __aenter__(self): @@ -205,7 +207,9 @@ async def run(self, task_deferred: TaskDeferred): if task_deferred.method_name: try: - next_method = self.operator.next_callable(task_deferred.method_name, task_deferred.kwargs) + next_method = self.operator.next_callable( + task_deferred.method_name, task_deferred.kwargs + ) outlet_events = context_get_outlet_events(self.context) return ExecutionCallableRunner( func=next_method, @@ -243,7 +247,9 @@ def __init__( self._operator_class = operator_class self.expand_input = expand_input self.partial_kwargs = partial_kwargs or {} - self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout + self.timeout = ( + timeout.total_seconds() if isinstance(timeout, timedelta) else timeout + ) self._mapped_kwargs: Iterable[dict] = [] if not self.max_active_tis_per_dag: self.max_active_tis_per_dag = os.cpu_count() or 1 @@ -306,20 +312,18 @@ def _resolve_expand_input(self, context: Context, session: Session): self.log.debug("resolve_expand_input: %s", self.expand_input) if isinstance(self.expand_input.value, XComArg): - resolved_input = self.expand_input.value.resolve(context=context, session=session) + 
resolved_input = self.expand_input.value.resolve( + context=context, session=session + ) else: resolved_input = self.expand_input.value self.log.debug("resolved_input: %s", resolved_input) if isinstance(resolved_input, _MapResult): - self._mapped_kwargs = map( - lambda value: self._resolve(value=value, context=context, session=session), resolved_input - ) + self._mapped_kwargs = map(lambda value: self._resolve(value=value, context=context, session=session), resolved_input) else: - self._mapped_kwargs = iter( - self._lazy_mapped_kwargs(input=resolved_input, context=context, session=session) - ) + self._mapped_kwargs = iter(self._lazy_mapped_kwargs(input=resolved_input, context=context, session=session)) self.log.debug("mapped_kwargs: %s", self._mapped_kwargs) @@ -338,6 +342,7 @@ def _run_tasks( ) -> Iterable[Any] | None: exception: BaseException | None = None reschedule_date = timezone.utcnow() + prev_futures_count = 0 futures: dict[ApplyResult, TaskInstance] = {} failed_tasks: list[TaskInstance] = [] chunked_tasks: Iterator[Iterable[TaskInstance]] = ichunked(tasks, (self.max_active_tis_per_dag * 2)) @@ -348,11 +353,16 @@ def _run_tasks( futures[future] = task while futures: - self.log.info("Number of remaining futures: %s", len(futures)) + futures_count = len(futures) + + if futures_count != prev_futures_count: + self.log.info("Number of remaining futures: %s", futures_count) + prev_futures_count = futures_count deferred_tasks: list[Coroutine[Any, Any, Any]] = [] + ready_futures = [future for future in futures.keys() if future.ready()] - for future in filter(ApplyResult.ready, list(futures.keys())): + for future in ready_futures: task = futures.pop(future) try: @@ -369,17 +379,21 @@ def _run_tasks( ) ) except TimeoutError as e: - self.log.warning("A timeout occurred for task_id %s", task.task_id) + self.log.warning( + "A timeout occurred for task_id %s", task.task_id + ) if task.next_try_number > self.retries: exception = AirflowTaskTimeout(e) else: - reschedule_date = task.next_retry_datetime() + reschedule_date = min(reschedule_date, task.next_retry_datetime()) failed_tasks.append(task) except AirflowRescheduleTaskInstanceException as e: - reschedule_date = e.reschedule_date + reschedule_date = min(reschedule_date, e.reschedule_date) failed_tasks.append(e.task) except AirflowException as e: - self.log.error("An exception occurred for task_id %s", task.task_id) + self.log.error( + "An exception occurred for task_id %s", task.task_id + ) exception = e for task in next(chunked_tasks, []): @@ -391,18 +405,20 @@ def _run_tasks( with event_loop() as loop: for result in loop.run_until_complete( - gather(*deferred_tasks, return_exceptions=True) + gather(*[loop.create_task(task) for task in deferred_tasks], return_exceptions=True) ): self.log.debug("result: %s", result) if isinstance(result, Exception): - if isinstance(result, AirflowRescheduleTaskInstanceException): - reschedule_date = result.reschedule_date + if isinstance( + result, AirflowRescheduleTaskInstanceException + ): + reschedule_date = min(reschedule_date, result.reschedule_date) failed_tasks.append(result.task) else: exception = result - # elif futures: - # sleep(min(len(futures), os.cpu_count())) + elif not ready_futures and futures: + sleep(len(futures) * 0.1) if not failed_tasks: if exception: @@ -417,12 +433,9 @@ def _run_tasks( return None # Calculate delay before the next retry - delay = reschedule_date - timezone.utcnow() - delay_seconds = ceil(delay.total_seconds()) - - self.log.debug("delay_seconds: %s", delay_seconds) + if 
reschedule_date > timezone.utcnow(): + delay_seconds = ceil((reschedule_date - timezone.utcnow()).total_seconds()) - if delay_seconds > 0: self.log.info( "Attempting to run %s failed tasks within %s seconds...", len(failed_tasks), @@ -445,8 +458,10 @@ def _run_operator(self, context: Context, task_instance: TaskInstance): self.log.debug("result: %s", result) - if not result: - with OperatorExecutor(context=context, task_instance=task_instance) as executor: + if result is None: + with OperatorExecutor( + context=context, task_instance=task_instance + ) as executor: result = executor.run() if self.do_xcom_push: task_instance.xcom_push(key=XCOM_RETURN_KEY, value=result) @@ -460,7 +475,9 @@ async def _run_deferrable( self, context: Context, task_instance: TaskInstance, task_deferred: TaskDeferred ): async with self._semaphore: - async with TriggerExecutor(context=context, task_instance=task_instance) as executor: + async with TriggerExecutor( + context=context, task_instance=task_instance + ) as executor: result = await executor.run(task_deferred) if self.do_xcom_push: task_instance.xcom_push(key=XCOM_RETURN_KEY, value=result) @@ -478,9 +495,7 @@ def execute(self, context: Context): context=context, tasks=iter( map( - lambda mapped_kwargs: self._create_task( - context["ti"].run_id, mapped_kwargs[0], mapped_kwargs[1] - ), + lambda mapped_kwargs: self._create_task(context["ti"].run_id, mapped_kwargs[0], mapped_kwargs[1]), enumerate(self._mapped_kwargs), ) ), From 97c64467e56f573c1823b2484a73743e83c2881d Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 7 Apr 2025 11:03:33 +0200 Subject: [PATCH 96/97] refactor: Updated IterableOperator with DeferredIterable --- airflow-core/src/airflow/models/iterable.py | 53 +++++++++---- .../src/airflow/models/iterableoperator.py | 76 +++++++++++++------ .../src/airflow/sdk/definitions/xcom_arg.py | 6 +- 3 files changed, 96 insertions(+), 39 deletions(-) diff --git a/airflow-core/src/airflow/models/iterable.py b/airflow-core/src/airflow/models/iterable.py index 448925bb7580e..93b8f304bf955 100644 --- a/airflow-core/src/airflow/models/iterable.py +++ b/airflow-core/src/airflow/models/iterable.py @@ -18,13 +18,14 @@ from __future__ import annotations import asyncio -from collections.abc import Generator, Iterator, Sequence +from collections.abc import Generator, Iterable, Iterator, Sequence from contextlib import contextmanager, suppress from typing import TYPE_CHECKING, Any from airflow.exceptions import AirflowException from airflow.serialization import serde from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.mixins import ResolveMixin from airflow.utils.module_loading import import_string from airflow.utils.xcom import XCOM_RETURN_KEY @@ -117,7 +118,7 @@ def deserialize(cls, data: dict, version: int): return XComIterable(**data) -class DeferredIterable(Iterator, LoggingMixin): +class DeferredIterable(Iterator, ResolveMixin, LoggingMixin): """An iterable that lazily fetches XCom values one by one instead of loading all at once.""" def __init__( @@ -136,7 +137,12 @@ def __init__( self.context = context self.index = 0 - def resolve(self, context: Context) -> DeferredIterable: + def iter_references(self) -> Iterable[tuple[Operator, str]]: + yield self.operator, XCOM_RETURN_KEY + + def resolve( + self, context: Context, *, include_xcom: bool = True + ) -> DeferredIterable: return DeferredIterable( results=self.results, trigger=self.trigger, @@ -159,10 +165,25 @@ def __next__(self): self.log.info("No more results. 
Running trigger: %s", self.trigger) + if not self.context: + raise AirflowException("Context is required to run the trigger.") + + results = self._execute_trigger() + + if isinstance(results, (list, set)): + self.results.extend(results) + else: + self.results.append(results) + + self.index += 1 + return self.results[-1] + + def _execute_trigger(self): try: with event_loop() as loop: self.log.info("Running trigger: %s", self.trigger) event = loop.run_until_complete(run_trigger(self.trigger)) + self.operator.render_template_fields(context=self.context) next_method = getattr(self.operator, self.next_method) self.log.info("Triggering next method: %s", self.next_method) results = next_method(self.context, event.payload) @@ -172,15 +193,13 @@ def __next__(self): if isinstance(results, DeferredIterable): self.trigger = results.trigger - self.results.extend(results.results) - else: - self.trigger = None - self.results.extend(results) + return results.results - self.index += 1 - return self.results[-1] + self.trigger = None + return results def __len__(self): + # TODO: maybe we should raise an exception here as you can't know the total length of an iterable in advance return len(self.results) def __getitem__(self, index: int): @@ -201,14 +220,15 @@ def serialize(self): } @classmethod - def get_operator_from_dag(cls, dag_fileloc: str, dag_id: str, task_id: str) -> Operator: + def get_operator_from_dag( + cls, dag_fileloc: str, dag_id: str, task_id: str + ) -> Operator: """Loads a DAG using DagBag and gets the operator by task_id.""" from airflow.models import DagBag dag_bag = DagBag(dag_folder=None) # Avoid loading all DAGs - processed_dags = dag_bag.process_file(dag_fileloc) - cls.logger().info("processed_dags: %s", processed_dags) + dag_bag.process_file(dag_fileloc) cls.logger().info("dag_bag: %s", dag_bag) cls.logger().info("dags: %s", dag_bag.dags) return dag_bag.dags[dag_id].get_task(task_id) @@ -219,9 +239,14 @@ def deserialize(cls, data: dict, version: int): trigger_class = import_string(data["trigger"][0]) trigger = trigger_class(**data["trigger"][1]) - operator = cls.get_operator_from_dag(data["dag_fileloc"], data["dag_id"], data["task_id"]) + operator = cls.get_operator_from_dag( + data["dag_fileloc"], data["dag_id"], data["task_id"] + ) return DeferredIterable( - results=data["results"], trigger=trigger, operator=operator, next_method=data["next_method"] + results=data["results"], + trigger=trigger, + operator=operator, + next_method=data["next_method"], ) diff --git a/airflow-core/src/airflow/models/iterableoperator.py b/airflow-core/src/airflow/models/iterableoperator.py index 4d968fb5ed16f..d1711033c9579 100644 --- a/airflow-core/src/airflow/models/iterableoperator.py +++ b/airflow-core/src/airflow/models/iterableoperator.py @@ -77,6 +77,7 @@ def __init__( self.__context = dict(context.items()) self._task_instance = task_instance self._is_async_mode: bool = False # Flag to track sync/async mode + self._result: Any | None = None @property def task_instance(self) -> TaskInstance: @@ -99,8 +100,16 @@ def mode(self) -> str: return "async" if self._is_async_mode else "sync" @abstractmethod + def execute(self, *args, **kwargs): + raise NotImplementedError + def run(self, *args, **kwargs): - raise NotImplementedError() + self._result = self.execute(*args, **kwargs) + return self._result + + async def run_deferred(self, *args, **kwargs): + self._result = await self.execute(*args, **kwargs) + return self._result def __enter__(self): if self.log.isEnabledFor(logging.INFO): @@ -145,7 +154,9 @@ 
def __exit__(self, exc_type, exc_value, traceback): self.task_instance.set_state(TaskInstanceState.FAILED) raise exc_value - self.operator.post_execute(context=self.context) + if self.operator.do_xcom_push: + self.task_instance.xcom_push(key=XCOM_RETURN_KEY, value=self._result) + self.operator.post_execute(context=self.context, result=self._result) self.task_instance.set_state(TaskInstanceState.SUCCESS) if self.log.isEnabledFor(logging.INFO): self.log.info( @@ -170,7 +181,7 @@ class OperatorExecutor(TaskExecutor): :meta private: """ - def run(self, *args, **kwargs): + def execute(self, *args, **kwargs): outlet_events = context_get_outlet_events(self.context) # TODO: change back to operator.execute once ExecutorSafeguard is fixed if hasattr(self.operator.execute, "__wrapped__"): @@ -197,7 +208,7 @@ class TriggerExecutor(TaskExecutor): :meta private: """ - async def run(self, task_deferred: TaskDeferred): + async def execute(self, task_deferred: TaskDeferred): event = await run_trigger(task_deferred.trigger) self.log.debug("event: %s", event) @@ -217,7 +228,7 @@ async def run(self, task_deferred: TaskDeferred): logger=self.log, ).run(self.context, event.payload) except TaskDeferred as task_deferred: - return await self.run(task_deferred=task_deferred) + return await self.execute(task_deferred=task_deferred) class IterableOperator(BaseOperator): @@ -321,9 +332,18 @@ def _resolve_expand_input(self, context: Context, session: Session): self.log.debug("resolved_input: %s", resolved_input) if isinstance(resolved_input, _MapResult): - self._mapped_kwargs = map(lambda value: self._resolve(value=value, context=context, session=session), resolved_input) + self._mapped_kwargs = map( + lambda value: self._resolve( + value=value, context=context, session=session + ), + resolved_input, + ) else: - self._mapped_kwargs = iter(self._lazy_mapped_kwargs(input=resolved_input, context=context, session=session)) + self._mapped_kwargs = iter( + self._lazy_mapped_kwargs( + input=resolved_input, context=context, session=session + ) + ) self.log.debug("mapped_kwargs: %s", self._mapped_kwargs) @@ -345,7 +365,9 @@ def _run_tasks( prev_futures_count = 0 futures: dict[ApplyResult, TaskInstance] = {} failed_tasks: list[TaskInstance] = [] - chunked_tasks: Iterator[Iterable[TaskInstance]] = ichunked(tasks, (self.max_active_tis_per_dag * 2)) + chunked_tasks: Iterator[Iterable[TaskInstance]] = ichunked( + tasks, (self.max_active_tis_per_dag * 2) + ) with ThreadPool(processes=self.max_active_tis_per_dag) as pool: for task in next(chunked_tasks, []): @@ -385,7 +407,9 @@ def _run_tasks( if task.next_try_number > self.retries: exception = AirflowTaskTimeout(e) else: - reschedule_date = min(reschedule_date, task.next_retry_datetime()) + reschedule_date = min( + reschedule_date, task.next_retry_datetime() + ) failed_tasks.append(task) except AirflowRescheduleTaskInstanceException as e: reschedule_date = min(reschedule_date, e.reschedule_date) @@ -405,7 +429,10 @@ def _run_tasks( with event_loop() as loop: for result in loop.run_until_complete( - gather(*[loop.create_task(task) for task in deferred_tasks], return_exceptions=True) + gather( + *[loop.create_task(task) for task in deferred_tasks], + return_exceptions=True, + ) ): self.log.debug("result: %s", result) @@ -413,7 +440,9 @@ def _run_tasks( if isinstance( result, AirflowRescheduleTaskInstanceException ): - reschedule_date = min(reschedule_date, result.reschedule_date) + reschedule_date = min( + reschedule_date, result.reschedule_date + ) failed_tasks.append(result.task) 
else: exception = result @@ -449,7 +478,9 @@ def _run_tasks( @classmethod def _xcom_pull(cls, task_instance: TaskInstance): with suppress(JSONDecodeError): - return task_instance.xcom_pull(task_ids=task_instance.task_id, dag_id=task_instance.dag_id) + return task_instance.xcom_pull( + task_ids=task_instance.task_id, dag_id=task_instance.dag_id + ) return None def _run_operator(self, context: Context, task_instance: TaskInstance): @@ -462,11 +493,11 @@ def _run_operator(self, context: Context, task_instance: TaskInstance): with OperatorExecutor( context=context, task_instance=task_instance ) as executor: - result = executor.run() - if self.do_xcom_push: - task_instance.xcom_push(key=XCOM_RETURN_KEY, value=result) + return executor.run() else: - self.log.info("Task instance %s already completed.", task_instance.task_id) + self.log.info( + "Task instance %s already completed.", task_instance.task_id + ) return result except TaskDeferred as task_deferred: return task_deferred @@ -478,12 +509,11 @@ async def _run_deferrable( async with TriggerExecutor( context=context, task_instance=task_instance ) as executor: - result = await executor.run(task_deferred) - if self.do_xcom_push: - task_instance.xcom_push(key=XCOM_RETURN_KEY, value=result) - return result + return await executor.run_deferred(task_deferred) - def _create_task(self, run_id: str, index: int, mapped_kwargs: dict) -> TaskInstance: + def _create_task( + self, run_id: str, index: int, mapped_kwargs: dict + ) -> TaskInstance: operator = self._unmap_operator(index, mapped_kwargs) return TaskInstance( task=operator, @@ -495,7 +525,9 @@ def execute(self, context: Context): context=context, tasks=iter( map( - lambda mapped_kwargs: self._create_task(context["ti"].run_id, mapped_kwargs[0], mapped_kwargs[1]), + lambda mapped_kwargs: self._create_task( + context["ti"].run_id, mapped_kwargs[0], mapped_kwargs[1] + ), enumerate(self._mapped_kwargs), ) ), diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index e82353a6b813c..4a23db9ad413c 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -33,8 +33,6 @@ from airflow.utils.trigger_rule import TriggerRule from airflow.utils.xcom import XCOM_RETURN_KEY -from airflow.models.iterable import DeferredIterable - if TYPE_CHECKING: from airflow.sdk.bases.operator import BaseOperator from airflow.sdk.definitions.edges import EdgeModifier @@ -358,6 +356,8 @@ def resolve(self, context: Mapping[str, Any]) -> Any: default=NOTSET, ) if not isinstance(result, ArgNotSet): + if isinstance(result, ResolveMixin): + return result.resolve(context) return result if self.key == XCOM_RETURN_KEY: return None @@ -450,7 +450,7 @@ def resolve(self, context: Mapping[str, Any]) -> Any: value = self.arg.resolve(context) if not isinstance(value, (Sequence, Iterable, dict)): raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}") - if isinstance(value, DeferredIterable): + if isinstance(value, ResolveMixin): value = value.resolve(context) return _MapResult(value, self.callables) From 8fd3b907e3604a890460206546d6620127014fcc Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 6 May 2025 21:10:44 +0200 Subject: [PATCH 97/97] refactor: Refactored IterableOperator to be Airflow 3 compliant --- airflow-core/src/airflow/models/iterable.py | 22 +- .../src/airflow/models/iterableoperator.py | 360 ++++++++++-------- 2 files changed, 207 insertions(+), 175 deletions(-) diff --git 
a/airflow-core/src/airflow/models/iterable.py b/airflow-core/src/airflow/models/iterable.py index 93b8f304bf955..5401eb0cb99b4 100644 --- a/airflow-core/src/airflow/models/iterable.py +++ b/airflow-core/src/airflow/models/iterable.py @@ -66,7 +66,7 @@ def event_loop() -> Generator[AbstractEventLoop, None, None]: loop.close() -class XComIterable(Iterator, Sequence): +class XComIterable(Sequence): """An iterable that lazily fetches XCom values one by one instead of loading all at once.""" def __init__(self, task_id: str, dag_id: str, run_id: str, length: int): @@ -97,9 +97,9 @@ def __getitem__(self, index: int): raise IndexError return XCom.get_one( - key=XCOM_RETURN_KEY, + key=f"{self.task_id}_{index}", dag_id=self.dag_id, - task_id=f"{self.task_id}_{index}", + task_id=self.task_id, run_id=self.run_id, ) @@ -140,9 +140,7 @@ def __init__( def iter_references(self) -> Iterable[tuple[Operator, str]]: yield self.operator, XCOM_RETURN_KEY - def resolve( - self, context: Context, *, include_xcom: bool = True - ) -> DeferredIterable: + def resolve(self, context: Context) -> DeferredIterable: return DeferredIterable( results=self.results, trigger=self.trigger, @@ -220,14 +218,11 @@ def serialize(self): } @classmethod - def get_operator_from_dag( - cls, dag_fileloc: str, dag_id: str, task_id: str - ) -> Operator: + def get_operator_from_dag(cls, dag_fileloc: str, dag_id: str, task_id: str) -> Operator: """Loads a DAG using DagBag and gets the operator by task_id.""" - from airflow.models import DagBag - dag_bag = DagBag(dag_folder=None) # Avoid loading all DAGs + dag_bag = DagBag(collect_dags=False) # Avoid loading all DAGs dag_bag.process_file(dag_fileloc) cls.logger().info("dag_bag: %s", dag_bag) cls.logger().info("dags: %s", dag_bag.dags) @@ -236,12 +231,9 @@ def get_operator_from_dag( @classmethod def deserialize(cls, data: dict, version: int): """Ensure the object is JSON deserializable.""" - trigger_class = import_string(data["trigger"][0]) trigger = trigger_class(**data["trigger"][1]) - operator = cls.get_operator_from_dag( - data["dag_fileloc"], data["dag_id"], data["task_id"] - ) + operator = cls.get_operator_from_dag(data["dag_fileloc"], data["dag_id"], data["task_id"]) return DeferredIterable( results=data["results"], trigger=trigger, diff --git a/airflow-core/src/airflow/models/iterableoperator.py b/airflow-core/src/airflow/models/iterableoperator.py index d1711033c9579..9c1de5ce31c38 100644 --- a/airflow-core/src/airflow/models/iterableoperator.py +++ b/airflow-core/src/airflow/models/iterableoperator.py @@ -21,13 +21,11 @@ import os from abc import abstractmethod from asyncio import Semaphore, gather -from collections.abc import Coroutine, Iterable, Iterator, Sequence -from contextlib import suppress +from collections.abc import Coroutine, Iterable, Sequence +from concurrent.futures import Future, ThreadPoolExecutor, as_completed from datetime import timedelta -from json import JSONDecodeError from math import ceil from multiprocessing import TimeoutError -from multiprocessing.pool import ApplyResult, ThreadPool from time import sleep from typing import TYPE_CHECKING, Any @@ -43,26 +41,63 @@ from airflow.models.abstractoperator import DEFAULT_TASK_EXECUTION_TIMEOUT from airflow.models.expandinput import ( ExpandInput, - _needs_run_time_resolution, - is_mappable, ) from airflow.models.iterable import XComIterable, event_loop from airflow.models.taskinstance import TaskInstance -from airflow.sdk.definitions._internal.abstractoperator import Operator +from airflow.sdk.bases.operator 
import BaseOperator +from airflow.sdk.definitions._internal.mixins import ResolveMixin +from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet from airflow.sdk.definitions.context import Context -from airflow.sdk.definitions.xcom_arg import XComArg, _MapResult +from airflow.sdk.definitions.xcom_arg import XComArg +from airflow.sdk.execution_time.callback_runner import ( + create_executable_runner, +) +from airflow.sdk.execution_time.context import context_get_outlet_events from airflow.triggers.base import run_trigger from airflow.utils import timezone from airflow.utils.context import context_get_outlet_events from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.operator_helpers import ExecutionCallableRunner from airflow.utils.state import TaskInstanceState -from airflow.utils.task_instance_session import get_current_task_instance_session from airflow.utils.xcom import XCOM_RETURN_KEY if TYPE_CHECKING: import jinja2 - from sqlalchemy.orm import Session + + +class VolatileTaskInstance(TaskInstance): + """Volatile task instance to run an operator which handles XCom's in memory.""" + + _xcoms: dict[str, Any] = {} + + def xcom_pull( + self, + task_ids: str | Iterable[str] | None = None, + dag_id: str | None = None, + key: str = XCOM_RETURN_KEY, + include_prior_dates: bool = False, # TODO: Add support for this + *, + map_indexes: int | Iterable[int] | None | ArgNotSet = NOTSET, + default: Any = None, + run_id: str | None = None, + ) -> Any: + key = f"{self.task_id}_{self.dag_id}_{key}" + if map_indexes is not None and (not isinstance(map_indexes, int) or map_indexes >= 0): + key += f"_{map_indexes}" + return self._xcoms.get(key, default) + + def xcom_push(self, key: str, value: Any): + key = f"{self.task_id}_{self.dag_id}_{key}" + if self.map_index is not None and self.map_index >= 0: + key += f"_{self.map_index}" + self._xcoms[key] = value + + @property + def next_try_number(self) -> int: + return self.try_number + 1 + + @property + def xcom_key(self) -> str: + return f"{self.task_id}_{self.map_index}" class TaskExecutor(LoggingMixin): @@ -83,16 +118,32 @@ def __init__( def task_instance(self) -> TaskInstance: return self._task_instance + @property + def dag_id(self) -> str: + return self._task_instance.dag_id + + @property + def task_id(self) -> str: + return self._task_instance.task_id + @property def task_index(self) -> int: - return int(self._task_instance.task_id.rsplit("_", 1)[-1]) + # return int(self._task_instance.task_id.rsplit("_", 1)[-1]) + return self._task_instance.map_index + + @property + def key(self): + return self.task_instance.xcom_key @property def context(self) -> Context: - return {**self.__context, **{"ti": self.task_instance}} + return { + **self.__context, + **{"ti": self.task_instance, "task_instance": self.task_instance}, + } @property - def operator(self) -> Operator: + def operator(self) -> BaseOperator: return self.task_instance.task @property @@ -125,10 +176,7 @@ def __enter__(self): if self.task_instance.try_number == 0: self.operator.render_template_fields(context=self.context) self.operator.pre_execute(context=self.context) - self.task_instance.set_state(TaskInstanceState.SCHEDULED) - self.task_instance._run_execute_callback( - context=self.context, task=self.operator - ) + self.task_instance._run_execute_callback(context=self.context, task=self.operator) return self async def __aenter__(self): @@ -145,19 +193,26 @@ def __exit__(self, exc_type, exc_value, traceback): self.task_index, exc_value, ) + if 
self.task_instance.task.on_failure_callback: + self.task_instance.task.on_failure_callback( + {**self.context, **{"exception": exc_value}} + ) + self.task_instance.state = TaskInstanceState.FAILED raise exc_value self.task_instance.try_number += 1 self.task_instance.end_date = timezone.utcnow() - self.task_instance.set_state(TaskInstanceState.UP_FOR_RESCHEDULE) + if self.task_instance.task.on_retry_callback: + self.task_instance.task.on_retry_callback({**self.context, **{"exception": exc_value}}) + self.task_instance.state = TaskInstanceState.UP_FOR_RESCHEDULE raise AirflowRescheduleTaskInstanceException(task=self.task_instance) - self.task_instance.set_state(TaskInstanceState.FAILED) raise exc_value - if self.operator.do_xcom_push: - self.task_instance.xcom_push(key=XCOM_RETURN_KEY, value=self._result) + + self.task_instance.state = TaskInstanceState.SUCCESS + if self.task_instance.task.on_success_callback: + self.task_instance.task.on_success_callback(self.context) self.operator.post_execute(context=self.context, result=self._result) - self.task_instance.set_state(TaskInstanceState.SUCCESS) if self.log.isEnabledFor(logging.INFO): self.log.info( "Task instance %s for %s finished successfully in %s attempts in %s mode.", @@ -185,12 +240,12 @@ def execute(self, *args, **kwargs): outlet_events = context_get_outlet_events(self.context) # TODO: change back to operator.execute once ExecutorSafeguard is fixed if hasattr(self.operator.execute, "__wrapped__"): - return ExecutionCallableRunner( + return create_executable_runner( func=self.operator.execute.__wrapped__, outlet_events=outlet_events, logger=self.log, ).run(self.operator, self.context) - return ExecutionCallableRunner( + return create_executable_runner( func=self.operator.execute, outlet_events=outlet_events, logger=self.log, @@ -218,11 +273,9 @@ async def execute(self, task_deferred: TaskDeferred): if task_deferred.method_name: try: - next_method = self.operator.next_callable( - task_deferred.method_name, task_deferred.kwargs - ) + next_method = self.operator.next_callable(task_deferred.method_name, task_deferred.kwargs) outlet_events = context_get_outlet_events(self.context) - return ExecutionCallableRunner( + return create_executable_runner( func=next_method, outlet_events=outlet_events, logger=self.log, @@ -258,9 +311,7 @@ def __init__( self._operator_class = operator_class self.expand_input = expand_input self.partial_kwargs = partial_kwargs or {} - self.timeout = ( - timeout.total_seconds() if isinstance(timeout, timedelta) else timeout - ) + self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout self._mapped_kwargs: Iterable[dict] = [] if not self.max_active_tis_per_dag: self.max_active_tis_per_dag = os.cpu_count() or 1 @@ -272,31 +323,35 @@ def __init__( def operator_name(self) -> str: return self._operator_class.__name__ + @property + def task_type(self) -> str: + return self._operator_class.__name__ + + @property + def chunk_size(self) -> int: + return self.max_active_tis_per_dag * 2 + def _get_specified_expand_input(self) -> ExpandInput: return self.expand_input - def _unmap_operator(self, index: int, mapped_kwargs: dict): + def _unmap_operator(self, mapped_kwargs: dict): kwargs = { **self.partial_kwargs, - **{"task_id": f"{self.task_id}_{index}"}, + **{"task_id": self.task_id}, **mapped_kwargs, } self._number_of_tasks += 1 - self.log.debug("index: %s", index) self.log.debug("kwargs: %s", kwargs) self.log.debug("operator_class: %s", self._operator_class) self.log.debug("number_of_tasks: %s", 
self._number_of_tasks) return self._operator_class(**kwargs, _airflow_from_mapped=True) - def _resolve(self, value, context: Context, session: Session): + def _resolve(self, value, context: Context): if isinstance(value, dict): for key in value: item = value[key] - if _needs_run_time_resolution(item): - item = item.resolve(context=context, session=session) - - if is_mappable(item): - item = iter(item) # type: ignore + if isinstance(item, ResolveMixin): + item = item.resolve(context=context) self.log.debug("resolved_value: %s", item) @@ -304,46 +359,44 @@ def _resolve(self, value, context: Context, session: Session): return value - def _lazy_mapped_kwargs(self, input, context: Context, session: Session): - self.log.debug("_lazy_mapped_kwargs value: %s", input) + def _lazy_mapped_kwargs(self, value, context: Context) -> Iterable[dict]: + self.log.debug("_lazy_mapped_kwargs resolved_value: %s", value) - value = self._resolve(value=input, context=context, session=session) + resolved_value = self._resolve(value=value, context=context) - self.log.debug("resolved value: %s", value) + self.log.debug("resolved resolved_value: %s", resolved_value) - if isinstance(value, dict): - for key, item in value.items(): + if isinstance(resolved_value, dict): + for key, item in resolved_value.items(): if not isinstance(item, (Sequence, Iterable)) or isinstance(item, str): yield {key: item} else: for sub_item in item: yield {key: sub_item} - def _resolve_expand_input(self, context: Context, session: Session): + def _resolve_expand_input(self, context: Context): self.log.debug("resolve_expand_input: %s", self.expand_input) - if isinstance(self.expand_input.value, XComArg): - resolved_input = self.expand_input.value.resolve( - context=context, session=session - ) + # TODO: check how to use the correct ResolvMixin type + if isinstance(self.expand_input.value, ResolveMixin): + resolved_input = self.expand_input.value.resolve(context=context) else: resolved_input = self.expand_input.value self.log.debug("resolved_input: %s", resolved_input) - if isinstance(resolved_input, _MapResult): + # Once _MapResult inherits from _MappableResult in Airflow, check only for _MappableResult + if type(resolved_input).__name__ in { + "_FilterResult", + "_MapResult", + "_LazyMapResult", + }: self._mapped_kwargs = map( - lambda value: self._resolve( - value=value, context=context, session=session - ), + lambda value: self._resolve(value=value, context=context), resolved_input, ) else: - self._mapped_kwargs = iter( - self._lazy_mapped_kwargs( - input=resolved_input, context=context, session=session - ) - ) + self._mapped_kwargs = iter(self._lazy_mapped_kwargs(value=resolved_input, context=context)) self.log.debug("mapped_kwargs: %s", self._mapped_kwargs) @@ -352,26 +405,29 @@ def render_template_fields( context: Context, jinja_env: jinja2.Environment | None = None, ) -> None: - session = get_current_task_instance_session() - self._resolve_expand_input(context=context, session=session) + self._resolve_expand_input(context=context) + + def _xcom_push(self, context: Context, task: TaskInstance, value: Any) -> None: + self.log.info("Pushing XCom %s", task.map_index) + + context["ti"].xcom_push(key=task.xcom_key, value=value) def _run_tasks( self, context: Context, - tasks: Iterator[TaskInstance], - ) -> Iterable[Any] | None: + tasks: Iterable[TaskInstance], + ) -> None: exception: BaseException | None = None reschedule_date = timezone.utcnow() prev_futures_count = 0 - futures: dict[ApplyResult, TaskInstance] = {} + futures: dict[Future, 
TaskInstance] = {} + failed_tasks: list[TaskInstance] = [] - chunked_tasks: Iterator[Iterable[TaskInstance]] = ichunked( - tasks, (self.max_active_tis_per_dag * 2) - ) + chunked_tasks = ichunked(tasks, self.chunk_size) - with ThreadPool(processes=self.max_active_tis_per_dag) as pool: + with ThreadPoolExecutor(max_workers=self.max_active_tis_per_dag) as pool: for task in next(chunked_tasks, []): - future = pool.apply_async(self._run_operator, (context, task)) + future = pool.submit(self._run_operator, context, task) futures[future] = task while futures: @@ -381,73 +437,78 @@ def _run_tasks( self.log.info("Number of remaining futures: %s", futures_count) prev_futures_count = futures_count - deferred_tasks: list[Coroutine[Any, Any, Any]] = [] - ready_futures = [future for future in futures.keys() if future.ready()] + deferred_tasks: dict[Coroutine[Any, Any, Any], TaskInstance] = {} + ready_futures = False - for future in ready_futures: - task = futures.pop(future) + with event_loop() as loop: + for future in as_completed(futures.keys()): + task = futures.pop(future) + ready_futures = True - try: - result = future.get(timeout=self.timeout) + try: + result = future.result(timeout=self.timeout) - self.log.debug("result: %s", result) + self.log.debug("result: %s", result) - if isinstance(result, TaskDeferred): - deferred_tasks.append( - self._run_deferrable( + if isinstance(result, TaskDeferred): + deferred_task = loop.create_task( + self._run_deferrable( + context=context, + task_instance=task, + task_deferred=result, + ) + ) + deferred_tasks[deferred_task] = task + elif result and task.task.do_xcom_push: + self._xcom_push( context=context, - task_instance=task, - task_deferred=result, + task=task, + value=result, ) - ) - except TimeoutError as e: - self.log.warning( - "A timeout occurred for task_id %s", task.task_id - ) - if task.next_try_number > self.retries: - exception = AirflowTaskTimeout(e) - else: - reschedule_date = min( - reschedule_date, task.next_retry_datetime() - ) - failed_tasks.append(task) - except AirflowRescheduleTaskInstanceException as e: - reschedule_date = min(reschedule_date, e.reschedule_date) - failed_tasks.append(e.task) - except AirflowException as e: - self.log.error( - "An exception occurred for task_id %s", task.task_id - ) - exception = e - - for task in next(chunked_tasks, []): - future = pool.apply_async(self._run_operator, (context, task)) - futures[future] = task - - if deferred_tasks: - self.log.info("Running %s deferred tasks", len(deferred_tasks)) - - with event_loop() as loop: - for result in loop.run_until_complete( - gather( - *[loop.create_task(task) for task in deferred_tasks], - return_exceptions=True, - ) - ): + except TimeoutError as e: + self.log.warning("A timeout occurred for task_id %s", task.task_id) + if task.next_try_number > self.retries: + exception = AirflowTaskTimeout(e) + else: + reschedule_date = min(reschedule_date, task.next_retry_datetime()) + failed_tasks.append(task) + except AirflowRescheduleTaskInstanceException as e: + reschedule_date = min(reschedule_date, e.reschedule_date) + failed_tasks.append(e.task) + except AirflowException as e: + self.log.error("An exception occurred for task_id %s", task.task_id) + exception = e + + if len(futures) < self.chunk_size: + for task in next(chunked_tasks, []): + future = pool.submit(self._run_operator, context, task) + futures[future] = task + + if deferred_tasks: + self.log.info("Running %s deferred tasks", len(deferred_tasks)) + + deferred_task_keys = list(deferred_tasks.keys()) + 
results = loop.run_until_complete(gather(*deferred_task_keys, return_exceptions=True)) + + for future, result in zip(deferred_task_keys, results): + task = deferred_tasks[future] + self.log.debug("result: %s", result) if isinstance(result, Exception): - if isinstance( - result, AirflowRescheduleTaskInstanceException - ): - reschedule_date = min( - reschedule_date, result.reschedule_date - ) - failed_tasks.append(result.task) + if isinstance(result, AirflowRescheduleTaskInstanceException): + reschedule_date = min(reschedule_date, result.reschedule_date) + failed_tasks.append(task) else: exception = result - elif not ready_futures and futures: - sleep(len(futures) * 0.1) + elif result and task.task.do_xcom_push: + self._xcom_push( + context=context, + task=task, + value=result, + ) + elif not ready_futures and futures: + sleep(len(futures) * 0.1) if not failed_tasks: if exception: @@ -459,7 +520,6 @@ def _run_tasks( run_id=context["run_id"], length=self._number_of_tasks, ) - return None # Calculate delay before the next retry if reschedule_date > timezone.utcnow(): @@ -473,32 +533,12 @@ def _run_tasks( sleep(delay_seconds) - return self._run_tasks(context, iter(failed_tasks)) - - @classmethod - def _xcom_pull(cls, task_instance: TaskInstance): - with suppress(JSONDecodeError): - return task_instance.xcom_pull( - task_ids=task_instance.task_id, dag_id=task_instance.dag_id - ) - return None + return self._run_tasks(context, failed_tasks) def _run_operator(self, context: Context, task_instance: TaskInstance): try: - result = self._xcom_pull(task_instance) - - self.log.debug("result: %s", result) - - if result is None: - with OperatorExecutor( - context=context, task_instance=task_instance - ) as executor: - return executor.run() - else: - self.log.info( - "Task instance %s already completed.", task_instance.task_id - ) - return result + with OperatorExecutor(context=context, task_instance=task_instance) as executor: + return executor.run() except TaskDeferred as task_deferred: return task_deferred @@ -506,18 +546,16 @@ async def _run_deferrable( self, context: Context, task_instance: TaskInstance, task_deferred: TaskDeferred ): async with self._semaphore: - async with TriggerExecutor( - context=context, task_instance=task_instance - ) as executor: + async with TriggerExecutor(context=context, task_instance=task_instance) as executor: return await executor.run_deferred(task_deferred) - def _create_task( - self, run_id: str, index: int, mapped_kwargs: dict - ) -> TaskInstance: - operator = self._unmap_operator(index, mapped_kwargs) - return TaskInstance( + def _create_task(self, run_id: str, index: int, mapped_kwargs: dict) -> TaskInstance: + operator = self._unmap_operator(mapped_kwargs) + return VolatileTaskInstance( task=operator, run_id=run_id, + state=TaskInstanceState.SCHEDULED.value, + map_index=index, ) def execute(self, context: Context): @@ -526,7 +564,9 @@ def execute(self, context: Context): tasks=iter( map( lambda mapped_kwargs: self._create_task( - context["ti"].run_id, mapped_kwargs[0], mapped_kwargs[1] + context["ti"].run_id, + mapped_kwargs[0], + mapped_kwargs[1], ), enumerate(self._mapped_kwargs), )
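
For orientation, here is a minimal sketch of how the IterableOperator built up by these patches could be wired into a DAG file. The series never shows a public entry point (no stream()/expand() helper appears in these hunks), so the direct constructor call, the BashOperator import path, and all argument values below are assumptions for illustration, not the author's documented API:

    # Hypothetical wiring of an IterableOperator. The keyword arguments mirror
    # the __init__ shown above (operator_class, expand_input, partial_kwargs,
    # timeout); task_id and max_active_tis_per_dag come from BaseOperator.
    from datetime import timedelta

    from airflow.models.expandinput import DictOfListsExpandInput
    from airflow.models.iterableoperator import IterableOperator
    from airflow.providers.standard.operators.bash import BashOperator  # import path may differ by version

    fetch_all = IterableOperator(
        task_id="fetch_all",
        operator_class=BashOperator,
        expand_input=DictOfListsExpandInput({"bash_command": ["echo 1", "echo 2", "echo 3"]}),
        partial_kwargs={"env": {"STAGE": "dev"}},
        timeout=timedelta(minutes=5),
        max_active_tis_per_dag=4,
    )

Each dict produced from expand_input becomes one VolatileTaskInstance with its own map_index, all executed inside the worker that runs this single task.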
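
The _run_tasks method above interleaves three concerns: chunked submission to a thread pool, collection of deferred coroutines, and rescheduling of failed tasks against the earliest retry deadline. A simplified, self-contained model of just the submit/collect/retry skeleton may make the control flow easier to follow; triggers and XComs are omitted, and the function name and fixed five-second backoff are illustrative, not taken from the patches:

    import time
    from concurrent.futures import ThreadPoolExecutor, as_completed
    from datetime import datetime, timedelta, timezone
    from math import ceil

    from more_itertools import ichunked  # the same chunking helper the patches use

    def run_all(tasks, max_workers=4, retries=2, attempt=0):
        """Run callables chunk by chunk; sleep until the earliest deadline, then re-run failures."""
        failed, deadlines = [], []
        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            # Keep at most max_workers * 2 tasks in flight per chunk, mirroring chunk_size above.
            for chunk in ichunked(tasks, max_workers * 2):
                futures = {pool.submit(task): task for task in chunk}
                for future in as_completed(futures):
                    task = futures.pop(future)
                    try:
                        future.result()
                    except Exception:
                        # Record the earliest moment this task may run again.
                        deadlines.append(datetime.now(timezone.utc) + timedelta(seconds=5))
                        failed.append(task)
        if failed and attempt < retries:
            delay = ceil((min(deadlines) - datetime.now(timezone.utc)).total_seconds())
            if delay > 0:
                time.sleep(delay)
            run_all(failed, max_workers, retries, attempt + 1)

Called as run_all([lambda i=i: print(i) for i in range(10)]). The real method additionally refills the pool as futures complete, funnels TaskDeferred results onto an asyncio event loop guarded by a semaphore, and re-raises any non-reschedulable exception once the loop drains, but the retry shape is the same.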
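
Patch 97 also keeps intermediate XCom traffic for the mapped sub-tasks in memory: VolatileTaskInstance stores values in a process-local dict whose key combines task_id, dag_id, the XCom key and, for mapped tasks, the map_index. A stripped-down model of that keying scheme (the free functions here are stand-ins for the methods above):

    # In-memory XCom store modelling VolatileTaskInstance: map_index < 0 is
    # treated as "not mapped", matching the checks in xcom_push above.
    _xcoms: dict[str, object] = {}

    def xcom_push(task_id: str, dag_id: str, map_index: int, key: str, value: object) -> None:
        full_key = f"{task_id}_{dag_id}_{key}"
        if map_index >= 0:
            full_key += f"_{map_index}"
        _xcoms[full_key] = value

    def xcom_pull(task_id: str, dag_id: str, map_index: int, key: str, default=None):
        full_key = f"{task_id}_{dag_id}_{key}"
        if map_index >= 0:
            full_key += f"_{map_index}"
        return _xcoms.get(full_key, default)

    xcom_push("fetch_all", "my_dag", 3, "return_value", {"rows": 42})
    assert xcom_pull("fetch_all", "my_dag", 3, "return_value") == {"rows": 42}

Results that must survive the process are pushed onto the real task instance under a per-index key (_xcom_push above) and read back lazily through the XComIterable returned at the end of _run_tasks.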