class AirflowRescheduleTaskInstanceException(AirflowRescheduleException):
    """
    Raise when a specific TaskInstance should be re-scheduled at a later time.

    The reschedule timestamp is derived from the task instance's own retry
    policy (``next_retry_datetime``), so callers only supply the instance.

    :param task: The task instance that should be rescheduled
    """

    def __init__(self, task: TaskInstance):
        # Delegate the reschedule timestamp to the TI's retry/backoff policy.
        super().__init__(reschedule_date=task.next_retry_datetime())
        self.task = task
@contextmanager
def event_loop() -> Generator[AbstractEventLoop, None, None]:
    """
    Yield a usable asyncio event loop.

    Reuses the current loop when one exists and is still open; otherwise a
    fresh loop is created, installed as the current loop, and closed again
    when the context exits.
    """
    created_here = False
    loop: AbstractEventLoop | None = None
    try:
        try:
            loop = asyncio.get_event_loop()
            if loop.is_closed():
                # A closed loop is unusable; fall through to creating one.
                raise RuntimeError
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            created_here = True
        yield loop
    finally:
        # Only tear down loops we created ourselves; never close a caller's.
        if created_here and loop is not None:
            with suppress(AttributeError):
                loop.close()
class XComIterable(Sequence):
    """
    A sequence that lazily fetches XCom values one by one instead of loading all at once.

    Element ``i`` is stored under the XCom key ``"<task_id>_<i>"`` and is only
    fetched from the XCom backend when accessed.

    :param task_id: Task id that pushed the per-index XCom values
    :param dag_id: Dag id the values belong to
    :param run_id: Run id the values belong to
    :param length: Total number of elements available
    """

    def __init__(self, task_id: str, dag_id: str, run_id: str, length: int):
        self.task_id = task_id
        self.dag_id = dag_id
        self.run_id = run_id
        self.length = length
        # NOTE: iteration state lives on the instance, so nested/concurrent
        # iterations over the same object interfere with each other.
        self.index = 0

    def __iter__(self) -> Iterator:
        self.index = 0
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration

        value = self[self.index]
        self.index += 1
        return value

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, index: int):
        """Fetch one element lazily; supports negative indices like a regular sequence."""
        if index < 0:
            index += self.length
        if not (0 <= index < self.length):
            raise IndexError(index)

        return XCom.get_one(
            key=f"{self.task_id}_{index}",
            dag_id=self.dag_id,
            task_id=self.task_id,
            run_id=self.run_id,
        )

    def serialize(self) -> dict:
        """Return a JSON-serializable representation of this iterable."""
        return {
            "task_id": self.task_id,
            "dag_id": self.dag_id,
            "run_id": self.run_id,
            "length": self.length,
        }

    @classmethod
    def deserialize(cls, data: dict, version: int) -> XComIterable:
        """Rebuild an instance from :meth:`serialize` output (honors subclasses via ``cls``)."""
        return cls(**data)
class DeferredIterable(Iterator, ResolveMixin, LoggingMixin):
    """
    An iterator over results that can be extended lazily by running a deferral trigger.

    Already-known results are served first; once exhausted, the stored trigger
    is run (synchronously, on a private event loop) and the operator's
    ``next_method`` is invoked with the trigger event to produce more results.

    :param results: Initial result(s) already produced by the operator
    :param trigger: Trigger to run once the known results are exhausted
    :param operator: Operator owning the deferral
    :param next_method: Name of the operator method to call with the trigger event
    :param context: Task context; required to actually run the trigger
    """

    def __init__(
        self,
        results: list[Any] | Any,
        trigger: BaseTrigger,
        operator: Operator,
        next_method: str,
        context: Context | None = None,
    ):
        super().__init__()
        # Copy to avoid sharing mutable state with the caller.
        self.results = results.copy() if isinstance(results, list) else [results]
        self.trigger = trigger
        self.operator = operator
        self.next_method = next_method
        self.context = context
        self.index = 0

    def iter_references(self) -> Iterable[tuple[Operator, str]]:
        yield self.operator, XCOM_RETURN_KEY

    def resolve(self, context: Context) -> DeferredIterable:
        return DeferredIterable(
            results=self.results,
            trigger=self.trigger,
            operator=self.operator,
            next_method=self.next_method,
            context=context,
        )

    def __iter__(self) -> Iterator:
        return self

    def __next__(self):
        if self.index < len(self.results):
            result = self.results[self.index]
            self.index += 1
            return result

        if not self.trigger:
            raise StopIteration

        self.log.info("No more results. Running trigger: %s", self.trigger)

        if not self.context:
            raise AirflowException("Context is required to run the trigger.")

        results = self._execute_trigger()

        if isinstance(results, (list, set)):
            self.results.extend(results)
        else:
            self.results.append(results)

        self.index += 1
        return self.results[-1]

    def _execute_trigger(self):
        # Imported lazily: the module-level import exists only under
        # TYPE_CHECKING, so referencing it at runtime would raise NameError.
        from airflow.triggers.base import run_trigger

        try:
            with event_loop() as loop:
                self.log.info("Running trigger: %s", self.trigger)
                event = loop.run_until_complete(run_trigger(self.trigger))
                self.operator.render_template_fields(context=self.context)
                next_method = getattr(self.operator, self.next_method)
                self.log.info("Triggering next method: %s", self.next_method)
                results = next_method(self.context, event.payload)
        except Exception as e:
            self.log.exception(e)
            raise AirflowException from e

        if isinstance(results, DeferredIterable):
            # The callback deferred again: keep its trigger and absorb its results.
            self.trigger = results.trigger
            return results.results

        self.trigger = None
        return results

    def __len__(self) -> int:
        # NOTE: only counts results fetched so far; the true length of a
        # deferred iterable cannot be known in advance.
        return len(self.results)

    def __getitem__(self, index: int):
        if not (0 <= index < len(self)):
            raise IndexError(index)

        return self.results[index]

    def serialize(self) -> dict:
        """Return a JSON-serializable representation of this iterable."""
        return {
            "results": self.results,
            "trigger": self.trigger.serialize() if self.trigger else None,
            "dag_fileloc": self.operator.dag.fileloc,
            "dag_id": self.operator.dag_id,
            "task_id": self.operator.task_id,
            "next_method": self.next_method,
        }

    @classmethod
    def get_operator_from_dag(cls, dag_fileloc: str, dag_id: str, task_id: str) -> Operator:
        """Load the DAG file and return the operator identified by ``task_id``."""
        from airflow.models import DagBag

        dag_bag = DagBag(collect_dags=False)  # Avoid loading all DAGs
        dag_bag.process_file(dag_fileloc)
        cls.logger().info("dag_bag: %s", dag_bag)
        cls.logger().info("dags: %s", dag_bag.dags)
        return dag_bag.dags[dag_id].get_task(task_id)

    @classmethod
    def deserialize(cls, data: dict, version: int) -> DeferredIterable:
        """Rebuild an instance from :meth:`serialize` output."""
        trigger_class = import_string(data["trigger"][0])
        trigger = trigger_class(**data["trigger"][1])
        operator = cls.get_operator_from_dag(data["dag_fileloc"], data["dag_id"], data["task_id"])
        return cls(
            results=data["results"],
            trigger=trigger,
            operator=operator,
            next_method=data["next_method"],
        )


# Workaround to let serde round-trip these classes.
# NOTE: must use ``__name__`` on the class itself; ``SomeClass.__class__.__name__``
# is the *metaclass* name ("type"), which would register the wrong entry twice.
serde._extra_allowed = serde._extra_allowed.union(
    {
        f"{XComIterable.__module__}.{XComIterable.__name__}",
        f"{DeferredIterable.__module__}.{DeferredIterable.__name__}",
    }
)
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging +import os +from abc import abstractmethod +from asyncio import Semaphore, gather +from collections.abc import Coroutine, Iterable, Sequence +from concurrent.futures import Future, ThreadPoolExecutor, as_completed +from datetime import timedelta +from math import ceil +from multiprocessing import TimeoutError +from time import sleep +from typing import TYPE_CHECKING, Any + +from more_itertools import ichunked + +from airflow.exceptions import ( + AirflowException, + AirflowRescheduleTaskInstanceException, + AirflowTaskTimeout, + TaskDeferred, +) +from airflow.models import BaseOperator +from airflow.models.abstractoperator import DEFAULT_TASK_EXECUTION_TIMEOUT +from airflow.models.expandinput import ( + ExpandInput, +) +from airflow.models.iterable import XComIterable, event_loop +from airflow.models.taskinstance import TaskInstance +from airflow.sdk.bases.operator import BaseOperator +from airflow.sdk.definitions._internal.mixins import ResolveMixin +from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet +from airflow.sdk.definitions.context import Context +from airflow.sdk.definitions.xcom_arg import XComArg +from airflow.sdk.execution_time.callback_runner import ( + create_executable_runner, +) +from airflow.sdk.execution_time.context import context_get_outlet_events +from airflow.triggers.base import run_trigger +from airflow.utils import timezone +from airflow.utils.context import context_get_outlet_events +from airflow.utils.log.logging_mixin import LoggingMixin +from 
class VolatileTaskInstance(TaskInstance):
    """Task instance variant that keeps XComs in process memory instead of the metadata DB."""

    # Class-level store shared by every volatile TI in this process; keys are
    # namespaced with task_id/dag_id (plus map_index for mapped pushes).
    # NOTE(review): the cross-instance sharing looks intentional so sibling
    # mapped tasks can read each other's values — confirm before changing.
    _xcoms: dict[str, Any] = {}

    def xcom_pull(
        self,
        task_ids: str | Iterable[str] | None = None,
        dag_id: str | None = None,
        key: str = XCOM_RETURN_KEY,
        include_prior_dates: bool = False,  # TODO: Add support for this
        *,
        map_indexes: int | Iterable[int] | None | ArgNotSet = NOTSET,
        default: Any = None,
        run_id: str | None = None,
    ) -> Any:
        # NOTE(review): ``task_ids``/``dag_id``/``run_id`` are ignored — lookups
        # always use this TI's own ids. The default ``NOTSET`` sentinel is also
        # appended to the key verbatim; verify this matches xcom_push's keys.
        key = f"{self.task_id}_{self.dag_id}_{key}"
        if map_indexes is not None and (not isinstance(map_indexes, int) or map_indexes >= 0):
            key += f"_{map_indexes}"
        return self._xcoms.get(key, default)

    def xcom_push(self, key: str, value: Any):
        """Store *value* in the in-memory XCom dict under a namespaced key."""
        key = f"{self.task_id}_{self.dag_id}_{key}"
        if self.map_index is not None and self.map_index >= 0:
            key += f"_{self.map_index}"
        self._xcoms[key] = value

    @property
    def next_try_number(self) -> int:
        return self.try_number + 1

    @property
    def xcom_key(self) -> str:
        # Key under which this mapped task's final result is pushed.
        return f"{self.task_id}_{self.map_index}"


class TaskExecutor(LoggingMixin):
    """
    Base class to run an operator or trigger with given task context and task instance.

    Used as a sync or async context manager: ``__enter__`` performs first-try
    setup (template rendering, pre-execute, execute callback) and ``__exit__``
    handles success/retry/failure bookkeeping, converting retryable failures
    into :class:`AirflowRescheduleTaskInstanceException`.
    """

    def __init__(self, context: Context, task_instance: TaskInstance):
        super().__init__()
        self.__context = dict(context.items())
        self._task_instance = task_instance
        self._is_async_mode: bool = False  # Flag to track sync/async mode
        self._result: Any | None = None

    @property
    def task_instance(self) -> TaskInstance:
        return self._task_instance

    @property
    def dag_id(self) -> str:
        return self._task_instance.dag_id

    @property
    def task_id(self) -> str:
        return self._task_instance.task_id

    @property
    def task_index(self) -> int:
        return self._task_instance.map_index

    @property
    def key(self):
        return self.task_instance.xcom_key

    @property
    def context(self) -> Context:
        # Rebuild the context with this executor's own TI so callbacks and
        # template rendering see the per-map-index instance.
        return {
            **self.__context,
            "ti": self.task_instance,
            "task_instance": self.task_instance,
        }

    @property
    def operator(self) -> BaseOperator:
        return self.task_instance.task

    @property
    def mode(self) -> str:
        return "async" if self._is_async_mode else "sync"

    @abstractmethod
    def execute(self, *args, **kwargs):
        raise NotImplementedError

    def run(self, *args, **kwargs):
        self._result = self.execute(*args, **kwargs)
        return self._result

    async def run_deferred(self, *args, **kwargs):
        self._result = await self.execute(*args, **kwargs)
        return self._result

    def __enter__(self):
        if self.log.isEnabledFor(logging.INFO):
            self.log.info(
                "Attempting running task %s of %s for %s with map_index %s in %s mode.",
                self.task_instance.try_number,
                self.operator.retries,
                self.task_instance.task_id,
                self.task_index,
                self.mode,
            )

        if self.task_instance.try_number == 0:
            # First attempt only: render templates and run the pre-execute hooks.
            # NOTE(review): source formatting was ambiguous on whether the
            # execute callback runs on every attempt — confirm.
            self.operator.render_template_fields(context=self.context)
            self.operator.pre_execute(context=self.context)
            self.task_instance._run_execute_callback(context=self.context, task=self.operator)
        return self

    async def __aenter__(self):
        self._is_async_mode = True
        return self.__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_value:
            if isinstance(exc_value, AirflowException):
                if self.task_instance.next_try_number > self.operator.retries:
                    # Retries exhausted: mark failed and propagate.
                    self.log.error(
                        "Max number of attempts for %s with map_index %s failed due to: %s",
                        self.task_instance.task_id,
                        self.task_index,
                        exc_value,
                    )
                    if self.task_instance.task.on_failure_callback:
                        self.task_instance.task.on_failure_callback(
                            {**self.context, "exception": exc_value}
                        )
                    self.task_instance.state = TaskInstanceState.FAILED
                    raise exc_value

                # Retryable failure: bump the attempt counter and ask the
                # caller to reschedule this specific task instance.
                self.task_instance.try_number += 1
                self.task_instance.end_date = timezone.utcnow()
                if self.task_instance.task.on_retry_callback:
                    self.task_instance.task.on_retry_callback({**self.context, "exception": exc_value})
                self.task_instance.state = TaskInstanceState.UP_FOR_RESCHEDULE
                raise AirflowRescheduleTaskInstanceException(task=self.task_instance)

            # Non-Airflow exceptions are never retried here.
            raise exc_value

        self.task_instance.state = TaskInstanceState.SUCCESS
        if self.task_instance.task.on_success_callback:
            self.task_instance.task.on_success_callback(self.context)
        self.operator.post_execute(context=self.context, result=self._result)
        if self.log.isEnabledFor(logging.INFO):
            self.log.info(
                "Task instance %s for %s finished successfully in %s attempts in %s mode.",
                self.task_index,
                self.task_instance.task_id,
                self.task_instance.next_try_number,
                self.mode,
            )

    async def __aexit__(self, exc_type, exc_value, traceback):
        self.__exit__(exc_type, exc_value, traceback)


class OperatorExecutor(TaskExecutor):
    """
    Run an operator with given task context and task instance.

    If the execute function raises a TaskDeferred exception, the trigger is
    expected to be executed asynchronously via the TriggerExecutor.

    :meta private:
    """

    def execute(self, *args, **kwargs):
        outlet_events = context_get_outlet_events(self.context)
        # TODO: change back to operator.execute once ExecutorSafeguard is fixed.
        # The unwrapped function is unbound, so the operator itself must be
        # passed as the first argument in that branch.
        if hasattr(self.operator.execute, "__wrapped__"):
            return create_executable_runner(
                func=self.operator.execute.__wrapped__,
                outlet_events=outlet_events,
                logger=self.log,
            ).run(self.operator, self.context)
        return create_executable_runner(
            func=self.operator.execute,
            outlet_events=outlet_events,
            logger=self.log,
        ).run(self.context)


class TriggerExecutor(TaskExecutor):
    """
    Run a trigger with given task deferred exception.

    If the next method defers again, the new trigger is executed recursively
    until no more TaskDeferred exceptions occur. Always runs asynchronously.

    :meta private:
    """

    async def execute(self, task_deferred: TaskDeferred):
        event = await run_trigger(task_deferred.trigger)

        self.log.debug("event: %s", event)

        if event:
            self.log.debug("next_method: %s", task_deferred.method_name)

            if task_deferred.method_name:
                try:
                    next_method = self.operator.next_callable(task_deferred.method_name, task_deferred.kwargs)
                    outlet_events = context_get_outlet_events(self.context)
                    return create_executable_runner(
                        func=next_method,
                        outlet_events=outlet_events,
                        logger=self.log,
                    ).run(self.context, event.payload)
                except TaskDeferred as nested_deferral:
                    return await self.execute(task_deferred=nested_deferral)
class IterableOperator(BaseOperator):
    """
    Operator that lazily expands an ``ExpandInput`` and runs one sub-operator
    per mapped kwargs set, in-process, using a thread pool plus an event loop
    for deferrals.
    """

    _operator_class: type[BaseOperator]
    expand_input: ExpandInput
    partial_kwargs: dict[str, Any]
    # each operator should override this class attr for shallow copy attrs.
    shallow_copy_attrs: Sequence[str] = (
        "expand_input",
        "partial_kwargs",
        "_log",
        "_semaphore",
    )

    def __init__(
        self,
        *,
        operator_class: type[BaseOperator],
        expand_input: ExpandInput,
        partial_kwargs: dict[str, Any] | None = None,
        timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self._operator_class = operator_class
        self.expand_input = expand_input
        self.partial_kwargs = partial_kwargs or {}
        # Normalize to seconds so the value can be passed to Future.result().
        self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
        self._mapped_kwargs: Iterable[dict] = []
        if not self.max_active_tis_per_dag:
            self.max_active_tis_per_dag = os.cpu_count() or 1
        self._semaphore = Semaphore(self.max_active_tis_per_dag)
        self._number_of_tasks: int = 0
        XComArg.apply_upstream_relationship(self, self.expand_input.value)

    @property
    def operator_name(self) -> str:
        return self._operator_class.__name__

    @property
    def task_type(self) -> str:
        return self._operator_class.__name__

    @property
    def chunk_size(self) -> int:
        # Keep twice as many tasks queued as workers so the pool never starves.
        return self.max_active_tis_per_dag * 2

    def _get_specified_expand_input(self) -> ExpandInput:
        return self.expand_input

    def _unmap_operator(self, mapped_kwargs: dict):
        """Instantiate one concrete operator for a single set of mapped kwargs."""
        kwargs = {
            **self.partial_kwargs,
            "task_id": self.task_id,
            **mapped_kwargs,
        }
        self._number_of_tasks += 1
        self.log.debug("kwargs: %s", kwargs)
        self.log.debug("operator_class: %s", self._operator_class)
        self.log.debug("number_of_tasks: %s", self._number_of_tasks)
        return self._operator_class(**kwargs, _airflow_from_mapped=True)

    def _resolve(self, value, context: Context):
        """Resolve any ResolveMixin values inside a mapped-kwargs dict in place."""
        if isinstance(value, dict):
            for key in value:
                item = value[key]
                if isinstance(item, ResolveMixin):
                    item = item.resolve(context=context)

                self.log.debug("resolved_value: %s", item)

                value[key] = item

        return value

    def _lazy_mapped_kwargs(self, value, context: Context) -> Iterable[dict]:
        """Yield one single-key kwargs dict per element of the expand input."""
        self.log.debug("_lazy_mapped_kwargs resolved_value: %s", value)

        resolved_value = self._resolve(value=value, context=context)

        self.log.debug("resolved resolved_value: %s", resolved_value)

        if isinstance(resolved_value, dict):
            for key, item in resolved_value.items():
                # Strings are iterable but must be treated as scalars.
                if not isinstance(item, (Sequence, Iterable)) or isinstance(item, str):
                    yield {key: item}
                else:
                    for sub_item in item:
                        yield {key: sub_item}

    def _resolve_expand_input(self, context: Context):
        """Turn the raw expand input into a lazy stream of mapped kwargs."""
        self.log.debug("resolve_expand_input: %s", self.expand_input)

        # TODO: check how to use the correct ResolveMixin type
        if isinstance(self.expand_input.value, ResolveMixin):
            resolved_input = self.expand_input.value.resolve(context=context)
        else:
            resolved_input = self.expand_input.value

        self.log.debug("resolved_input: %s", resolved_input)

        # Once _MapResult inherits from _MappableResult in Airflow, check only for _MappableResult
        if type(resolved_input).__name__ in {
            "_FilterResult",
            "_MapResult",
            "_LazyMapResult",
        }:
            self._mapped_kwargs = map(
                lambda value: self._resolve(value=value, context=context),
                resolved_input,
            )
        else:
            self._mapped_kwargs = iter(self._lazy_mapped_kwargs(value=resolved_input, context=context))

        self.log.debug("mapped_kwargs: %s", self._mapped_kwargs)

    def render_template_fields(
        self,
        context: Context,
        jinja_env: jinja2.Environment | None = None,
    ) -> None:
        self._resolve_expand_input(context=context)

    def _xcom_push(self, context: Context, task: TaskInstance, value: Any) -> None:
        """Push one mapped task's result under its per-index XCom key."""
        self.log.info("Pushing XCom %s", task.map_index)

        context["ti"].xcom_push(key=task.xcom_key, value=value)

    def _run_tasks(
        self,
        context: Context,
        tasks: Iterable[TaskInstance],
    ) -> XComIterable | None:
        """
        Run task instances in a thread pool, deferring to triggers where needed.

        Retryable failures are collected and re-run recursively after the
        earliest requested reschedule delay has elapsed.
        """
        # The module-level ``TimeoutError`` comes from multiprocessing and is
        # NOT what Future.result raises; alias the concurrent.futures one.
        from concurrent.futures import TimeoutError as FuturesTimeoutError

        exception: BaseException | None = None
        # Earliest datetime at which failed tasks should be retried. Seeding
        # this with utcnow() (as previously done) collapsed every min() to
        # "now" and silently disabled retry backoff.
        reschedule_date = None
        prev_futures_count = 0
        futures: dict[Future, TaskInstance] = {}

        failed_tasks: list[TaskInstance] = []
        chunked_tasks = ichunked(tasks, self.chunk_size)

        def _earliest(candidate):
            nonlocal reschedule_date
            reschedule_date = candidate if reschedule_date is None else min(reschedule_date, candidate)

        with ThreadPoolExecutor(max_workers=self.max_active_tis_per_dag) as pool:
            for task in next(chunked_tasks, []):
                future = pool.submit(self._run_operator, context, task)
                futures[future] = task

            while futures:
                futures_count = len(futures)

                if futures_count != prev_futures_count:
                    self.log.info("Number of remaining futures: %s", futures_count)
                    prev_futures_count = futures_count

                deferred_tasks: dict[Coroutine[Any, Any, Any], TaskInstance] = {}
                ready_futures = False

                with event_loop() as loop:
                    for future in as_completed(futures.keys()):
                        task = futures.pop(future)
                        ready_futures = True

                        try:
                            result = future.result(timeout=self.timeout)

                            self.log.debug("result: %s", result)

                            if isinstance(result, TaskDeferred):
                                deferred_task = loop.create_task(
                                    self._run_deferrable(
                                        context=context,
                                        task_instance=task,
                                        task_deferred=result,
                                    )
                                )
                                deferred_tasks[deferred_task] = task
                            elif result and task.task.do_xcom_push:
                                self._xcom_push(context=context, task=task, value=result)
                        except (FuturesTimeoutError, TimeoutError) as e:
                            self.log.warning("A timeout occurred for task_id %s", task.task_id)
                            if task.next_try_number > self.retries:
                                exception = AirflowTaskTimeout(e)
                            else:
                                _earliest(task.next_retry_datetime())
                                failed_tasks.append(task)
                        except AirflowRescheduleTaskInstanceException as e:
                            _earliest(e.reschedule_date)
                            failed_tasks.append(e.task)
                        except AirflowException as e:
                            self.log.error("An exception occurred for task_id %s", task.task_id)
                            exception = e

                        # Top up the pool from the next chunk as capacity frees.
                        if len(futures) < self.chunk_size:
                            for task in next(chunked_tasks, []):
                                future = pool.submit(self._run_operator, context, task)
                                futures[future] = task

                    if deferred_tasks:
                        self.log.info("Running %s deferred tasks", len(deferred_tasks))

                        deferred_task_keys = list(deferred_tasks.keys())
                        results = loop.run_until_complete(gather(*deferred_task_keys, return_exceptions=True))

                        for deferred_key, result in zip(deferred_task_keys, results):
                            task = deferred_tasks[deferred_key]

                            self.log.debug("result: %s", result)

                            if isinstance(result, Exception):
                                if isinstance(result, AirflowRescheduleTaskInstanceException):
                                    _earliest(result.reschedule_date)
                                    failed_tasks.append(task)
                                else:
                                    exception = result
                            elif result and task.task.do_xcom_push:
                                self._xcom_push(context=context, task=task, value=result)
                    elif not ready_futures and futures:
                        sleep(len(futures) * 0.1)

        if not failed_tasks:
            if exception:
                raise exception
            if self.do_xcom_push:
                return XComIterable(
                    task_id=self.task_id,
                    dag_id=self.dag_id,
                    run_id=context["run_id"],
                    length=self._number_of_tasks,
                )
            # Explicit return: previously control fell through to
            # ``self._run_tasks(context, [])`` and recursed forever when
            # do_xcom_push was disabled.
            return None

        # Wait out the earliest requested backoff before retrying failed tasks.
        if reschedule_date is not None and reschedule_date > timezone.utcnow():
            delay_seconds = ceil((reschedule_date - timezone.utcnow()).total_seconds())

            self.log.info(
                "Attempting to run %s failed tasks within %s seconds...",
                len(failed_tasks),
                delay_seconds,
            )

            sleep(delay_seconds)

        return self._run_tasks(context, failed_tasks)

    def _run_operator(self, context: Context, task_instance: TaskInstance):
        """Run one operator synchronously; a deferral is returned, not raised."""
        try:
            with OperatorExecutor(context=context, task_instance=task_instance) as executor:
                return executor.run()
        except TaskDeferred as task_deferred:
            return task_deferred

    async def _run_deferrable(
        self, context: Context, task_instance: TaskInstance, task_deferred: TaskDeferred
    ):
        """Run a deferred task's trigger, bounded by the concurrency semaphore."""
        async with self._semaphore:
            async with TriggerExecutor(context=context, task_instance=task_instance) as executor:
                return await executor.run_deferred(task_deferred)

    def _create_task(self, run_id: str, index: int, mapped_kwargs: dict) -> TaskInstance:
        """Create an in-memory (volatile) task instance for one mapped kwargs set."""
        operator = self._unmap_operator(mapped_kwargs)
        return VolatileTaskInstance(
            task=operator,
            run_id=run_id,
            state=TaskInstanceState.SCHEDULED.value,
            map_index=index,
        )

    def execute(self, context: Context):
        """Lazily expand the input and run every mapped task instance."""
        run_id = context["ti"].run_id
        return self._run_tasks(
            context=context,
            tasks=(
                self._create_task(run_id, index, kwargs)
                for index, kwargs in enumerate(self._mapped_kwargs)
            ),
        )
async def run_trigger(trigger: BaseTrigger) -> TriggerEvent | None:
    """Run *trigger* and return its first emitted event, or ``None`` if it finishes without one."""
    event_stream = trigger.run().__aiter__()
    try:
        return await event_stream.__anext__()
    except StopAsyncIteration:
        return None
+from __future__ import annotations + +from collections import defaultdict +from typing import TYPE_CHECKING +from unittest.mock import patch + +import pytest +from sqlalchemy import select + +from airflow.decorators import setup, task, task_group, teardown +from airflow.exceptions import AirflowSkipException +from airflow.models.baseoperator import BaseOperator +from airflow.models.dag import DAG +from airflow.models.iterableoperator import IterableOperator +from airflow.models.taskinstance import TaskInstance +from airflow.models.taskmap import TaskMap +from airflow.providers.standard.operators.python import PythonOperator +from airflow.utils.state import TaskInstanceState +from airflow.utils.task_group import TaskGroup +from airflow.utils.task_instance_session import set_current_task_instance_session +from airflow.utils.xcom import XCOM_RETURN_KEY + +from tests.models import DEFAULT_DATE +from tests_common.test_utils.mapping import expand_mapped_task +from tests_common.test_utils.mock_operators import ( + MockOperator, + MockOperatorWithNestedFields, + NestedFields, +) + +pytestmark = pytest.mark.db_test + +if TYPE_CHECKING: + from airflow.sdk.definitions.context import Context + + +@patch("airflow.models.abstractoperator.AbstractOperator.render_template") +def test_task_mapping_with_dag_and_list_of_pandas_dataframe(mock_render_template, caplog): + class UnrenderableClass: + def __bool__(self): + raise ValueError("Similar to Pandas DataFrames, this class raises an exception.") + + class CustomOperator(BaseOperator): + template_fields = ("arg",) + + def __init__(self, arg, **kwargs): + super().__init__(**kwargs) + self.arg = arg + + def execute(self, context: Context): + pass + + with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE) as dag: + task1 = CustomOperator(task_id="op1", arg=None) + unrenderable_values = [UnrenderableClass(), UnrenderableClass()] + mapped = IterableOperator.partial(task_id="task_2").iterate(arg=unrenderable_values) + task1 >> 
mapped + dag.test() + assert ( + "Unable to check if the value of type 'UnrenderableClass' is False for task 'task_2', field 'arg'" + in caplog.text + ) + mock_render_template.assert_called() + + +def test_map_xcom_arg_multiple_upstream_xcoms(dag_maker, session): + """Test that the correct number of downstream tasks are generated when mapping with an XComArg""" + + class PushExtraXComOperator(BaseOperator): + """Push an extra XCom value along with the default return value.""" + + def __init__(self, return_value, **kwargs): + super().__init__(**kwargs) + self.return_value = return_value + + def execute(self, context): + context["task_instance"].xcom_push(key="extra_key", value="extra_value") + return self.return_value + + with dag_maker("test-dag", session=session, start_date=DEFAULT_DATE) as dag: + upstream_return = [1, 2, 3] + task1 = PushExtraXComOperator(return_value=upstream_return, task_id="task_1") + task2 = PushExtraXComOperator.partial(task_id="task_2").iterate(return_value=task1.output) + task3 = PushExtraXComOperator.partial(task_id="task_3").iterate(return_value=task2.output) + + dr = dag_maker.create_dagrun() + ti_1 = dr.get_task_instance("task_1", session) + ti_1.run() + + ti_2s, _ = TaskMap.iterate_mapped_task(task2, dr.run_id, session=session) + for ti in ti_2s: + ti.refresh_from_task(dag.get_task("task_2")) + ti.run() + + ti_3s, _ = TaskMap.iterate_mapped_task(task3, dr.run_id, session=session) + for ti in ti_3s: + ti.refresh_from_task(dag.get_task("task_3")) + ti.run() + + assert len(ti_3s) == len(ti_2s) == len(upstream_return) + + +@pytest.mark.parametrize( + ["num_existing_tis", "expected"], + ( + pytest.param(0, [(0, None), (1, None), (2, None)], id="only-unmapped-ti-exists"), + pytest.param( + 3, + [(0, "success"), (1, "success"), (2, "success")], + id="all-tis-exist", + ), + pytest.param( + 5, + [ + (0, "success"), + (1, "success"), + (2, "success"), + (3, TaskInstanceState.REMOVED), + (4, TaskInstanceState.REMOVED), + ], + 
id="tis-to-be-removed", + ), + ), +) +def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expected): + literal = [1, 2, {"a": "b"}] + with dag_maker(session=session): + task1 = BaseOperator(task_id="op1") + mapped = MockOperator.partial(task_id="task_2").iterate(arg2=task1.output) + + dr = dag_maker.create_dagrun() + + session.add( + TaskMap( + dag_id=dr.dag_id, + task_id=task1.task_id, + run_id=dr.run_id, + map_index=-1, + length=len(literal), + keys=None, + ) + ) + + if num_existing_tis: + # Remove the map_index=-1 TI when we're creating other TIs + session.query(TaskInstance).filter( + TaskInstance.dag_id == mapped.dag_id, + TaskInstance.task_id == mapped.task_id, + TaskInstance.run_id == dr.run_id, + ).delete() + + for index in range(num_existing_tis): + # Give the existing TIs a state to make sure we don't change them + ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index, state=TaskInstanceState.SUCCESS) + session.add(ti) + session.flush() + + TaskMap.iterate_mapped_task(mapped, dr.run_id, session=session) + + indices = ( + session.query(TaskInstance.map_index, TaskInstance.state) + .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id) + .order_by(TaskInstance.map_index) + .all() + ) + + assert indices == expected + + +def test_expand_mapped_task_failed_state_in_db(dag_maker, session): + """ + This test tries to recreate a faulty state in the database and checks if we can recover from it. + The state that happens is that there exists mapped task instances and the unmapped task instance. + So we have instances with map_index [-1, 0, 1]. The -1 task instances should be removed in this case. 
+ """ + literal = [1, 2] + with dag_maker(session=session): + task1 = BaseOperator(task_id="op1") + mapped = MockOperator.partial(task_id="task_2").iterate(arg2=task1.output) + + dr = dag_maker.create_dagrun() + + session.add( + TaskMap( + dag_id=dr.dag_id, + task_id=task1.task_id, + run_id=dr.run_id, + map_index=-1, + length=len(literal), + keys=None, + ) + ) + + for index in range(2): + # Give the existing TIs a state to make sure we don't change them + ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index, state=TaskInstanceState.SUCCESS) + session.add(ti) + session.flush() + + indices = ( + session.query(TaskInstance.map_index, TaskInstance.state) + .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id) + .order_by(TaskInstance.map_index) + .all() + ) + # Make sure we have the faulty state in the database + assert indices == [(-1, None), (0, "success"), (1, "success")] + + TaskMap.iterate_mapped_task(mapped, dr.run_id, session=session) + + indices = ( + session.query(TaskInstance.map_index, TaskInstance.state) + .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id) + .order_by(TaskInstance.map_index) + .all() + ) + # The -1 index should be cleaned up + assert indices == [(0, "success"), (1, "success")] + + +def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session): + with dag_maker(session=session): + task1 = BaseOperator(task_id="op1") + mapped = MockOperator.partial(task_id="task_2").iterate(arg2=task1.output) + + dr = dag_maker.create_dagrun() + + expand_mapped_task(mapped, dr.run_id, task1.task_id, length=0, session=session) + + indices = ( + session.query(TaskInstance.map_index, TaskInstance.state) + .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id) + .order_by(TaskInstance.map_index) + .all() + ) + + assert indices == [(-1, TaskInstanceState.SKIPPED)] + + +class _RenderTemplateFieldsValidationOperator(BaseOperator): + template_fields = ( + "partial_template", + 
"map_template_xcom", + "map_template_literal", + "map_template_file", + ) + template_ext = (".ext",) + + fields_to_test = [ + "partial_template", + "partial_static", + "map_template_xcom", + "map_template_literal", + "map_static", + "map_template_file", + ] + + def __init__( + self, + partial_template, + partial_static, + map_template_xcom, + map_template_literal, + map_static, + map_template_file, + **kwargs, + ): + for field in self.fields_to_test: + setattr(self, field, value := locals()[field]) + assert isinstance(value, str), "value should have been resolved before unmapping" + super().__init__(**kwargs) + + def execute(self, context): + pass + + +def test_mapped_render_template_fields_validating_operator(dag_maker, session, tmp_path): + file_template_dir = tmp_path / "path" / "to" + file_template_dir.mkdir(parents=True, exist_ok=True) + file_template = file_template_dir / "file.ext" + file_template.write_text("loaded data") + + with set_current_task_instance_session(session=session): + with dag_maker(session=session, template_searchpath=tmp_path.__fspath__()): + task1 = BaseOperator(task_id="op1") + output1 = task1.output + mapped = _RenderTemplateFieldsValidationOperator.partial( + task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" + ).iterate( + map_static=output1, + map_template_literal=["{{ ds }}"], + map_template_xcom=output1, + map_template_file=["/path/to/file.ext"], + ) + + dr = dag_maker.create_dagrun() + ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) + ti.xcom_push(key=XCOM_RETURN_KEY, value=["{{ ds }}"], session=session) + session.add( + TaskMap( + dag_id=dr.dag_id, + task_id=task1.task_id, + run_id=dr.run_id, + map_index=-1, + length=1, + keys=None, + ) + ) + session.flush() + + mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session) + mapped_ti.map_index = 0 + assert isinstance(mapped_ti.task, IterableOperator) + 
mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) + assert isinstance(mapped_ti.task, _RenderTemplateFieldsValidationOperator) + + assert mapped_ti.task.partial_template == "a", "Should be rendered!" + assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be rendered!" + assert mapped_ti.task.map_static == "{{ ds }}", "Should not be rendered!" + assert mapped_ti.task.map_template_literal == "2016-01-01", "Should be rendered!" + assert mapped_ti.task.map_template_xcom == "{{ ds }}", "XCom resolved but not double rendered!" + assert mapped_ti.task.map_template_file == "loaded data", "Should be rendered!" + + +def test_mapped_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, tmp_path): + file_template_dir = tmp_path / "path" / "to" + file_template_dir.mkdir(parents=True, exist_ok=True) + file_template = file_template_dir / "file.ext" + file_template.write_text("loaded data") + + with set_current_task_instance_session(session=session): + with dag_maker(session=session, template_searchpath=tmp_path.__fspath__()): + mapped = _RenderTemplateFieldsValidationOperator.partial( + task_id="a", partial_template="{{ ti.task_id }}", partial_static="{{ ti.task_id }}" + ).iterate_kwargs( + [ + { + "map_template_literal": "{{ ds }}", + "map_static": "{{ ds }}", + "map_template_file": "/path/to/file.ext", + # This field is not tested since XCom inside a literal list + # is not rendered (matching BaseOperator rendering behavior). + "map_template_xcom": "", + } + ] + ) + + dr = dag_maker.create_dagrun() + mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session, map_index=0) + assert isinstance(mapped_ti.task, IterableOperator) + mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) + assert isinstance(mapped_ti.task, _RenderTemplateFieldsValidationOperator) + + assert mapped_ti.task.partial_template == "a", "Should be rendered!" 
+ assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be rendered!" + assert mapped_ti.task.map_template_literal == "2016-01-01", "Should be rendered!" + assert mapped_ti.task.map_static == "{{ ds }}", "Should not be rendered!" + assert mapped_ti.task.map_template_file == "loaded data", "Should be rendered!" + + +def test_mapped_render_nested_template_fields(dag_maker, session): + with dag_maker(session=session): + MockOperatorWithNestedFields.partial( + task_id="t", arg2=NestedFields(field_1="{{ ti.task_id }}", field_2="value_2") + ).iterate(arg1=["{{ ti.task_id }}", ["s", "{{ ti.task_id }}"]]) + + dr = dag_maker.create_dagrun() + decision = dr.task_instance_scheduling_decisions() + tis = {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis} + assert len(tis) == 2 + + ti = tis[("t", 0)] + ti.run(session=session) + assert ti.task.arg1 == "t" + assert ti.task.arg2.field_1 == "t" + assert ti.task.arg2.field_2 == "value_2" + + ti = tis[("t", 1)] + ti.run(session=session) + assert ti.task.arg1 == ["s", "t"] + assert ti.task.arg2.field_1 == "t" + assert ti.task.arg2.field_2 == "value_2" + + +@pytest.mark.parametrize( + ["num_existing_tis", "expected"], + ( + pytest.param(0, [(0, None), (1, None), (2, None)], id="only-unmapped-ti-exists"), + pytest.param( + 3, + [(0, "success"), (1, "success"), (2, "success")], + id="all-tis-exist", + ), + pytest.param( + 5, + [ + (0, "success"), + (1, "success"), + (2, "success"), + (3, TaskInstanceState.REMOVED), + (4, TaskInstanceState.REMOVED), + ], + id="tis-to-be-removed", + ), + ), +) +def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis, expected): + literal = [{"arg1": "a"}, {"arg1": "b"}, {"arg1": "c"}] + with dag_maker(session=session): + task1 = BaseOperator(task_id="op1") + mapped = MockOperator.partial(task_id="task_2").iterate_kwargs(task1.output) + + dr = dag_maker.create_dagrun() + + session.add( + TaskMap( + dag_id=dr.dag_id, + task_id=task1.task_id, + 
run_id=dr.run_id, + map_index=-1, + length=len(literal), + keys=None, + ) + ) + + if num_existing_tis: + # Remove the map_index=-1 TI when we're creating other TIs + session.query(TaskInstance).filter( + TaskInstance.dag_id == mapped.dag_id, + TaskInstance.task_id == mapped.task_id, + TaskInstance.run_id == dr.run_id, + ).delete() + + for index in range(num_existing_tis): + # Give the existing TIs a state to make sure we don't change them + ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index, state=TaskInstanceState.SUCCESS) + session.add(ti) + session.flush() + + TaskMap.iterate_mapped_task(mapped, dr.run_id, session=session) + + indices = ( + session.query(TaskInstance.map_index, TaskInstance.state) + .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id) + .order_by(TaskInstance.map_index) + .all() + ) + + assert indices == expected + + +def _create_mapped_with_name_template_classic(*, task_id, map_names, template): + class HasMapName(BaseOperator): + def __init__(self, *, map_name: str, **kwargs): + super().__init__(**kwargs) + self.map_name = map_name + + def execute(self, context): + context["map_name"] = self.map_name + + return HasMapName.partial(task_id=task_id, map_index_template=template).iterate( + map_name=map_names, + ) + + +def _create_mapped_with_name_template_taskflow(*, task_id, map_names, template): + from airflow.providers.standard.operators.python import get_current_context + + @task(task_id=task_id, map_index_template=template) + def task1(map_name): + context = get_current_context() + context["map_name"] = map_name + + return task1.iterate(map_name=map_names) + + +def _create_named_map_index_renders_on_failure_classic(*, task_id, map_names, template): + class HasMapName(BaseOperator): + def __init__(self, *, map_name: str, **kwargs): + super().__init__(**kwargs) + self.map_name = map_name + + def execute(self, context): + context["map_name"] = self.map_name + raise AirflowSkipException("Imagine this task failed!") 
+ + return HasMapName.partial(task_id=task_id, map_index_template=template).iterate( + map_name=map_names, + ) + + +def _create_named_map_index_renders_on_failure_taskflow(*, task_id, map_names, template): + from airflow.providers.standard.operators.python import get_current_context + + @task(task_id=task_id, map_index_template=template) + def task1(map_name): + context = get_current_context() + context["map_name"] = map_name + raise AirflowSkipException("Imagine this task failed!") + + return task1.iterate(map_name=map_names) + + +@pytest.mark.parametrize( + "template, expected_rendered_names", + [ + pytest.param(None, [None, None], id="unset"), + pytest.param("", ["", ""], id="constant"), + pytest.param("{{ ti.task_id }}-{{ ti.map_index }}", ["task1-0", "task1-1"], id="builtin"), + pytest.param("{{ ti.task_id }}-{{ map_name }}", ["task1-a", "task1-b"], id="custom"), + ], +) +@pytest.mark.parametrize( + "create_mapped_task", + [ + pytest.param(_create_mapped_with_name_template_classic, id="classic"), + pytest.param(_create_mapped_with_name_template_taskflow, id="taskflow"), + pytest.param(_create_named_map_index_renders_on_failure_classic, id="classic-failure"), + pytest.param(_create_named_map_index_renders_on_failure_taskflow, id="taskflow-failure"), + ], +) +def test_expand_mapped_task_instance_with_named_index( + dag_maker, + session, + create_mapped_task, + template, + expected_rendered_names, +) -> None: + """Test that the correct number of downstream tasks are generated when mapping with an XComArg""" + with dag_maker("test-dag", session=session, start_date=DEFAULT_DATE): + create_mapped_task(task_id="task1", map_names=["a", "b"], template=template) + + dr = dag_maker.create_dagrun() + tis = dr.get_task_instances() + for ti in tis: + ti.run() + session.flush() + + indices = session.scalars( + select(TaskInstance.rendered_map_index) + .where( + TaskInstance.dag_id == "test-dag", + TaskInstance.task_id == "task1", + TaskInstance.run_id == dr.run_id, + ) + 
.order_by(TaskInstance.map_index) + ).all() + + assert indices == expected_rendered_names + + +@pytest.mark.parametrize( + "map_index, expected", + [ + pytest.param(0, "2016-01-01", id="0"), + pytest.param(1, 2, id="1"), + ], +) +def test_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, map_index, expected): + with set_current_task_instance_session(session=session): + with dag_maker(session=session): + task1 = BaseOperator(task_id="op1") + mapped = MockOperator.partial(task_id="a", arg2="{{ ti.task_id }}").iterate_kwargs(task1.output) + + dr = dag_maker.create_dagrun() + ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session) + + ti.xcom_push(key=XCOM_RETURN_KEY, value=[{"arg1": "{{ ds }}"}, {"arg1": 2}], session=session) + + session.add( + TaskMap( + dag_id=dr.dag_id, + task_id=task1.task_id, + run_id=dr.run_id, + map_index=-1, + length=2, + keys=None, + ) + ) + session.flush() + + ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session) + ti.refresh_from_task(mapped) + ti.map_index = map_index + assert isinstance(ti.task, IterableOperator) + mapped.render_template_fields(context=ti.get_template_context(session=session)) + assert isinstance(ti.task, MockOperator) + assert ti.task.arg1 == expected + assert ti.task.arg2 == "a" + + +def test_all_xcomargs_from_mapped_tasks_are_consumable(dag_maker, session): + class PushXcomOperator(MockOperator): + def __init__(self, arg1, **kwargs): + super().__init__(arg1=arg1, **kwargs) + + def execute(self, context): + return self.arg1 + + class ConsumeXcomOperator(PushXcomOperator): + def execute(self, context): + assert set(self.arg1) == {1, 2, 3} + + with dag_maker("test_all_xcomargs_from_mapped_tasks_are_consumable"): + op1 = PushXcomOperator.partial(task_id="op1").iterate(arg1=[1, 2, 3]) + ConsumeXcomOperator(task_id="op2", arg1=op1.output) + + dr = dag_maker.create_dagrun() + tis = dr.get_task_instances(session=session) + for ti in tis: + ti.run() + + +class 
TestMappedSetupTeardown: + @staticmethod + def get_states(dr): + ti_dict = defaultdict(dict) + for ti in dr.get_task_instances(): + if ti.map_index == -1: + ti_dict[ti.task_id] = ti.state + else: + ti_dict[ti.task_id][ti.map_index] = ti.state + return dict(ti_dict) + + def classic_operator(self, task_id, ret=None, partial=False, fail=False): + def success_callable(ret=None): + def inner(*args, **kwargs): + print(args) + print(kwargs) + if ret: + return ret + + return inner + + def failure_callable(): + def inner(*args, **kwargs): + print(args) + print(kwargs) + raise ValueError("fail") + + return inner + + kwargs = dict(task_id=task_id) + if not fail: + kwargs.update(python_callable=success_callable(ret=ret)) + else: + kwargs.update(python_callable=failure_callable()) + if partial: + return PythonOperator.partial(**kwargs) + else: + return PythonOperator(**kwargs) + + @pytest.mark.parametrize("type_", ["taskflow", "classic"]) + def test_one_to_many_work_failed(self, type_, dag_maker): + """ + Work task failed. Setup maps to teardown. Should have 3 teardowns all successful even + though the work task has failed. 
+ """ + if type_ == "taskflow": + with dag_maker() as dag: + + @setup + def my_setup(): + print("setting up multiple things") + return [1, 2, 3] + + @task + def my_work(val): + print(f"doing work with multiple things: {val}") + raise ValueError("fail!") + + @teardown + def my_teardown(val): + print(f"teardown: {val}") + + s = my_setup() + t = my_teardown.iterate(val=s) + with t: + my_work(s) + else: + + @task + def my_work(val): + print(f"work: {val}") + raise ValueError("i fail") + + with dag_maker() as dag: + my_setup = self.classic_operator("my_setup", [[1], [2], [3]]) + my_teardown = self.classic_operator("my_teardown", partial=True) + t = my_teardown.iterate(op_args=my_setup.output) + with t.as_teardown(setups=my_setup): + my_work(my_setup.output) + + dr = dag.test() + states = self.get_states(dr) + expected = { + "my_setup": "success", + "my_work": "failed", + "my_teardown": {0: "success", 1: "success", 2: "success"}, + } + assert states == expected + + @pytest.mark.parametrize("type_", ["taskflow", "classic"]) + def test_many_one_explicit_odd_setup_mapped_setups_fail(self, type_, dag_maker): + """ + one unmapped setup goes to two different teardowns + one mapped setup goes to same teardown + mapped setups fail + teardowns should still run + """ + if type_ == "taskflow": + with dag_maker() as dag: + + @task + def other_setup(): + print("other setup") + return "other setup" + + @task + def other_work(): + print("other work") + return "other work" + + @task + def other_teardown(): + print("other teardown") + return "other teardown" + + @task + def my_setup(val): + print(f"setup: {val}") + raise ValueError("fail") + return val + + @task + def my_work(val): + print(f"work: {val}") + + @task + def my_teardown(val): + print(f"teardown: {val}") + + s = my_setup.iterate(val=["data1.json", "data2.json", "data3.json"]) + o_setup = other_setup() + o_teardown = other_teardown() + with o_teardown.as_teardown(setups=o_setup): + other_work() + t = 
my_teardown(s).as_teardown(setups=s) + with t: + my_work(s) + o_setup >> t + else: + with dag_maker() as dag: + + @task + def other_work(): + print("other work") + return "other work" + + @task + def my_work(val): + print(f"work: {val}") + + my_teardown = self.classic_operator("my_teardown") + + my_setup = self.classic_operator("my_setup", partial=True, fail=True) + s = my_setup.iterate(op_args=[["data1.json"], ["data2.json"], ["data3.json"]]) + o_setup = self.classic_operator("other_setup") + o_teardown = self.classic_operator("other_teardown") + with o_teardown.as_teardown(setups=o_setup): + other_work() + t = my_teardown.as_teardown(setups=s) + with t: + my_work(s.output) + o_setup >> t + + dr = dag.test() + states = self.get_states(dr) + expected = { + "my_setup": {0: "failed", 1: "failed", 2: "failed"}, + "other_setup": "success", + "other_teardown": "success", + "other_work": "success", + "my_teardown": "success", + "my_work": "upstream_failed", + } + assert states == expected + + @pytest.mark.parametrize("type_", ["taskflow", "classic"]) + def test_many_one_explicit_odd_setup_all_setups_fail(self, type_, dag_maker): + """ + one unmapped setup goes to two different teardowns + one mapped setup goes to same teardown + all setups fail + teardowns should not run + """ + if type_ == "taskflow": + with dag_maker() as dag: + + @task + def other_setup(): + print("other setup") + raise ValueError("fail") + return "other setup" + + @task + def other_work(): + print("other work") + return "other work" + + @task + def other_teardown(): + print("other teardown") + return "other teardown" + + @task + def my_setup(val): + print(f"setup: {val}") + raise ValueError("fail") + return val + + @task + def my_work(val): + print(f"work: {val}") + + @task + def my_teardown(val): + print(f"teardown: {val}") + + s = my_setup.iterate(val=["data1.json", "data2.json", "data3.json"]) + o_setup = other_setup() + o_teardown = other_teardown() + with o_teardown.as_teardown(setups=o_setup): 
+ other_work() + t = my_teardown(s).as_teardown(setups=s) + with t: + my_work(s) + o_setup >> t + else: + with dag_maker() as dag: + + @task + def other_setup(): + print("other setup") + raise ValueError("fail") + return "other setup" + + @task + def other_work(): + print("other work") + return "other work" + + @task + def other_teardown(): + print("other teardown") + return "other teardown" + + @task + def my_work(val): + print(f"work: {val}") + + my_setup = self.classic_operator("my_setup", partial=True, fail=True) + s = my_setup.iterate(op_args=[["data1.json"], ["data2.json"], ["data3.json"]]) + o_setup = other_setup() + o_teardown = other_teardown() + with o_teardown.as_teardown(setups=o_setup): + other_work() + my_teardown = self.classic_operator("my_teardown") + t = my_teardown.as_teardown(setups=s) + with t: + my_work(s.output) + o_setup >> t + + dr = dag.test() + states = self.get_states(dr) + expected = { + "my_teardown": "upstream_failed", + "other_setup": "failed", + "other_work": "upstream_failed", + "other_teardown": "upstream_failed", + "my_setup": {0: "failed", 1: "failed", 2: "failed"}, + "my_work": "upstream_failed", + } + assert states == expected + + @pytest.mark.parametrize("type_", ["taskflow", "classic"]) + def test_many_one_explicit_odd_setup_one_mapped_fails(self, type_, dag_maker): + """ + one unmapped setup goes to two different teardowns + one mapped setup goes to same teardown + one of the mapped setup instances fails + teardowns should all run + """ + if type_ == "taskflow": + with dag_maker() as dag: + + @task + def other_setup(): + print("other setup") + return "other setup" + + @task + def other_work(): + print("other work") + return "other work" + + @task + def other_teardown(): + print("other teardown") + return "other teardown" + + @task + def my_setup(val): + if val == "data2.json": + raise ValueError("fail!") + elif val == "data3.json": + raise AirflowSkipException("skip!") + print(f"setup: {val}") + return val + + @task + def 
my_work(val): + print(f"work: {val}") + + @task + def my_teardown(val): + print(f"teardown: {val}") + + s = my_setup.iterate(val=["data1.json", "data2.json", "data3.json"]) + o_setup = other_setup() + o_teardown = other_teardown() + with o_teardown.as_teardown(setups=o_setup): + other_work() + t = my_teardown(s).as_teardown(setups=s) + with t: + my_work(s) + o_setup >> t + else: + with dag_maker() as dag: + + @task + def other_setup(): + print("other setup") + return "other setup" + + @task + def other_work(): + print("other work") + return "other work" + + @task + def other_teardown(): + print("other teardown") + return "other teardown" + + def my_setup_callable(val): + if val == "data2.json": + raise ValueError("fail!") + elif val == "data3.json": + raise AirflowSkipException("skip!") + print(f"setup: {val}") + return val + + my_setup = PythonOperator.partial(task_id="my_setup", python_callable=my_setup_callable) + + @task + def my_work(val): + print(f"work: {val}") + + def my_teardown_callable(val): + print(f"teardown: {val}") + + s = my_setup.iterate(op_args=[["data1.json"], ["data2.json"], ["data3.json"]]) + o_setup = other_setup() + o_teardown = other_teardown() + with o_teardown.as_teardown(setups=o_setup): + other_work() + my_teardown = PythonOperator( + task_id="my_teardown", op_args=[s.output], python_callable=my_teardown_callable + ) + t = my_teardown.as_teardown(setups=s) + with t: + my_work(s.output) + o_setup >> t + + dr = dag.test() + states = self.get_states(dr) + expected = { + "my_setup": {0: "success", 1: "failed", 2: "skipped"}, + "other_setup": "success", + "other_teardown": "success", + "other_work": "success", + "my_teardown": "success", + "my_work": "upstream_failed", + } + assert states == expected + + @pytest.mark.parametrize("type_", ["taskflow", "classic"]) + def test_one_to_many_as_teardown(self, type_, dag_maker): + """ + 1 setup mapping to 3 teardowns + 1 work task + work fails + teardowns succeed + dagrun should be failure + """ + if 
type_ == "taskflow": + with dag_maker() as dag: + + @task + def my_setup(): + print("setting up multiple things") + return [1, 2, 3] + + @task + def my_work(val): + print(f"doing work with multiple things: {val}") + raise ValueError("this fails") + return val + + @task + def my_teardown(val): + print(f"teardown: {val}") + + s = my_setup() + t = my_teardown.iterate(val=s).as_teardown(setups=s) + with t: + my_work(s) + else: + with dag_maker() as dag: + + @task + def my_work(val): + print(f"doing work with multiple things: {val}") + raise ValueError("this fails") + return val + + my_teardown = self.classic_operator(task_id="my_teardown", partial=True) + + s = self.classic_operator(task_id="my_setup", ret=[[1], [2], [3]]) + t = my_teardown.iterate(op_args=s.output).as_teardown(setups=s) + with t: + my_work(s) + dr = dag.test() + states = self.get_states(dr) + expected = { + "my_setup": "success", + "my_teardown": {0: "success", 1: "success", 2: "success"}, + "my_work": "failed", + } + assert states == expected + + @pytest.mark.parametrize("type_", ["taskflow", "classic"]) + def test_one_to_many_as_teardown_on_failure_fail_dagrun(self, type_, dag_maker): + """ + 1 setup mapping to 3 teardowns + 1 work task + work succeeds + all but one teardown succeed + on_failure_fail_dagrun=True + dagrun should be success + """ + if type_ == "taskflow": + with dag_maker() as dag: + + @task + def my_setup(): + print("setting up multiple things") + return [1, 2, 3] + + @task + def my_work(val): + print(f"doing work with multiple things: {val}") + return val + + @task + def my_teardown(val): + print(f"teardown: {val}") + if val == 2: + raise ValueError("failure") + + s = my_setup() + t = my_teardown.iterate(val=s).as_teardown(setups=s, on_failure_fail_dagrun=True) + with t: + my_work(s) + # todo: if on_failure_fail_dagrun=True, should we still regard the WORK task as a leaf? 
+ else: + with dag_maker() as dag: + + @task + def my_work(val): + print(f"doing work with multiple things: {val}") + return val + + def my_teardown_callable(val): + print(f"teardown: {val}") + if val == 2: + raise ValueError("failure") + + s = self.classic_operator(task_id="my_setup", ret=[[1], [2], [3]]) + my_teardown = PythonOperator.partial( + task_id="my_teardown", python_callable=my_teardown_callable + ).iterate(op_args=s.output) + t = my_teardown.as_teardown(setups=s, on_failure_fail_dagrun=True) + with t: + my_work(s.output) + + dr = dag.test() + states = self.get_states(dr) + expected = { + "my_setup": "success", + "my_teardown": {0: "success", 1: "failed", 2: "success"}, + "my_work": "success", + } + assert states == expected + + @pytest.mark.parametrize("type_", ["taskflow", "classic"]) + def test_mapped_task_group_simple(self, type_, dag_maker, session): + """ + Mapped task group wherein there's a simple s >> w >> t pipeline. + When s is skipped, all should be skipped + When s is failed, all should be upstream failed + """ + if type_ == "taskflow": + with dag_maker() as dag: + + @setup + def my_setup(val): + if val == "data2.json": + raise ValueError("fail!") + elif val == "data3.json": + raise AirflowSkipException("skip!") + print(f"setup: {val}") + + @task + def my_work(val): + print(f"work: {val}") + + @teardown + def my_teardown(val): + print(f"teardown: {val}") + + @task_group + def file_transforms(filename): + s = my_setup(filename) + t = my_teardown(filename) + s >> t + with t: + my_work(filename) + + file_transforms.iterate(filename=["data1.json", "data2.json", "data3.json"]) + else: + with dag_maker() as dag: + + def my_setup_callable(val): + if val == "data2.json": + raise ValueError("fail!") + elif val == "data3.json": + raise AirflowSkipException("skip!") + print(f"setup: {val}") + + @task + def my_work(val): + print(f"work: {val}") + + def my_teardown_callable(val): + print(f"teardown: {val}") + + @task_group + def 
file_transforms(filename): + s = PythonOperator( + task_id="my_setup", python_callable=my_setup_callable, op_args=filename + ) + t = PythonOperator( + task_id="my_teardown", python_callable=my_teardown_callable, op_args=filename + ) + with t.as_teardown(setups=s): + my_work(filename) + + file_transforms.iterate(filename=[["data1.json"], ["data2.json"], ["data3.json"]]) + dr = dag.test() + states = self.get_states(dr) + expected = { + "file_transforms.my_setup": {0: "success", 1: "failed", 2: "skipped"}, + "file_transforms.my_work": {0: "success", 1: "upstream_failed", 2: "skipped"}, + "file_transforms.my_teardown": {0: "success", 1: "upstream_failed", 2: "skipped"}, + } + + assert states == expected + + @pytest.mark.parametrize("type_", ["taskflow", "classic"]) + def test_mapped_task_group_work_fail_or_skip(self, type_, dag_maker): + """ + Mapped task group wherein there's a simple s >> w >> t pipeline. + When w is skipped, teardown should still run + When w is failed, teardown should still run + """ + if type_ == "taskflow": + with dag_maker() as dag: + + @setup + def my_setup(val): + print(f"setup: {val}") + + @task + def my_work(val): + if val == "data2.json": + raise ValueError("fail!") + elif val == "data3.json": + raise AirflowSkipException("skip!") + print(f"work: {val}") + + @teardown + def my_teardown(val): + print(f"teardown: {val}") + + @task_group + def file_transforms(filename): + s = my_setup(filename) + t = my_teardown(filename).as_teardown(setups=s) + with t: + my_work(filename) + + file_transforms.iterate(filename=["data1.json", "data2.json", "data3.json"]) + else: + with dag_maker() as dag: + + @task + def my_work(vals): + val = vals[0] + if val == "data2.json": + raise ValueError("fail!") + elif val == "data3.json": + raise AirflowSkipException("skip!") + print(f"work: {val}") + + @teardown + def my_teardown(val): + print(f"teardown: {val}") + + def null_callable(val): + pass + + @task_group + def file_transforms(filename): + s = 
PythonOperator(task_id="my_setup", python_callable=null_callable, op_args=filename) + t = PythonOperator(task_id="my_teardown", python_callable=null_callable, op_args=filename) + t = t.as_teardown(setups=s) + with t: + my_work(filename) + + file_transforms.iterate(filename=[["data1.json"], ["data2.json"], ["data3.json"]]) + dr = dag.test() + states = self.get_states(dr) + expected = { + "file_transforms.my_setup": {0: "success", 1: "success", 2: "success"}, + "file_transforms.my_teardown": {0: "success", 1: "success", 2: "success"}, + "file_transforms.my_work": {0: "success", 1: "failed", 2: "skipped"}, + } + assert states == expected + + @pytest.mark.parametrize("type_", ["taskflow", "classic"]) + def test_teardown_many_one_explicit(self, type_, dag_maker): + """-- passing + one mapped setup going to one unmapped work + 3 diff states for setup: success / failed / skipped + teardown still runs, and receives the xcom from the single successful setup + """ + if type_ == "taskflow": + with dag_maker() as dag: + + @task + def my_setup(val): + if val == "data2.json": + raise ValueError("fail!") + elif val == "data3.json": + raise AirflowSkipException("skip!") + print(f"setup: {val}") + return val + + @task + def my_work(val): + print(f"work: {val}") + + @task + def my_teardown(val): + print(f"teardown: {val}") + + s = my_setup.iterate(val=["data1.json", "data2.json", "data3.json"]) + with my_teardown(s).as_teardown(setups=s): + my_work(s) + else: + with dag_maker() as dag: + + def my_setup_callable(val): + if val == "data2.json": + raise ValueError("fail!") + elif val == "data3.json": + raise AirflowSkipException("skip!") + print(f"setup: {val}") + return val + + @task + def my_work(val): + print(f"work: {val}") + + s = PythonOperator.partial(task_id="my_setup", python_callable=my_setup_callable) + s = s.iterate(op_args=[["data1.json"], ["data2.json"], ["data3.json"]]) + t = self.classic_operator("my_teardown") + with t.as_teardown(setups=s): + my_work(s.output) + + dr 
= dag.test() + states = self.get_states(dr) + expected = { + "my_setup": {0: "success", 1: "failed", 2: "skipped"}, + "my_teardown": "success", + "my_work": "upstream_failed", + } + assert states == expected + + def test_one_to_many_with_teardown_and_fail_fast(self, dag_maker): + """ + With fail_fast enabled, the teardown for an already-completed setup + should not be skipped. + """ + with dag_maker(fail_fast=True) as dag: + + @task + def my_setup(): + print("setting up multiple things") + return [1, 2, 3] + + @task + def my_work(val): + print(f"doing work with multiple things: {val}") + raise ValueError("this fails") + return val + + @task + def my_teardown(val): + print(f"teardown: {val}") + + s = my_setup() + t = my_teardown.iterate(val=s).as_teardown(setups=s) + with t: + my_work(s) + + dr = dag.test() + states = self.get_states(dr) + expected = { + "my_setup": "success", + "my_teardown": {0: "success", 1: "success", 2: "success"}, + "my_work": "failed", + } + assert states == expected + + def test_one_to_many_with_teardown_and_fail_fast_more_tasks(self, dag_maker): + """ + when fail_fast enabled, teardowns should run according to their setups. + in this case, the second teardown skips because its setup skips. 
+ """ + with dag_maker(fail_fast=True) as dag: + for num in (1, 2): + with TaskGroup(f"tg_{num}"): + + @task + def my_setup(): + print("setting up multiple things") + return [1, 2, 3] + + @task + def my_work(val): + print(f"doing work with multiple things: {val}") + raise ValueError("this fails") + return val + + @task + def my_teardown(val): + print(f"teardown: {val}") + + s = my_setup() + t = my_teardown.iterate(val=s).as_teardown(setups=s) + with t: + my_work(s) + tg1, tg2 = dag.task_group.children.values() + tg1 >> tg2 + dr = dag.test() + states = self.get_states(dr) + expected = { + "tg_1.my_setup": "success", + "tg_1.my_teardown": {0: "success", 1: "success", 2: "success"}, + "tg_1.my_work": "failed", + "tg_2.my_setup": "skipped", + "tg_2.my_teardown": "skipped", + "tg_2.my_work": "skipped", + } + assert states == expected + + def test_one_to_many_with_teardown_and_fail_fast_more_tasks_mapped_setup(self, dag_maker): + """ + when fail_fast enabled, teardowns should run according to their setups. + in this case, the second teardown skips because its setup skips. 
+ """ + with dag_maker(fail_fast=True) as dag: + for num in (1, 2): + with TaskGroup(f"tg_{num}"): + + @task + def my_pre_setup(): + print("input to the setup") + return [1, 2, 3] + + @task + def my_setup(val): + print("setting up multiple things") + return val + + @task + def my_work(val): + print(f"doing work with multiple things: {val}") + raise ValueError("this fails") + return val + + @task + def my_teardown(val): + print(f"teardown: {val}") + + s = my_setup.iterate(val=my_pre_setup()) + t = my_teardown.iterate(val=s).as_teardown(setups=s) + with t: + my_work(s) + tg1, tg2 = dag.task_group.children.values() + tg1 >> tg2 + dr = dag.test() + states = self.get_states(dr) + expected = { + "tg_1.my_pre_setup": "success", + "tg_1.my_setup": {0: "success", 1: "success", 2: "success"}, + "tg_1.my_teardown": {0: "success", 1: "success", 2: "success"}, + "tg_1.my_work": "failed", + "tg_2.my_pre_setup": "skipped", + "tg_2.my_setup": "skipped", + "tg_2.my_teardown": "skipped", + "tg_2.my_work": "skipped", + } + assert states == expected + + def test_skip_one_mapped_task_from_task_group_with_generator(self, dag_maker): + with dag_maker() as dag: + + @task + def make_list(): + return [1, 2, 3] + + @task + def double(n): + if n == 2: + raise AirflowSkipException() + return n * 2 + + @task + def last(n): ... + + @task_group + def group(n: int) -> None: + last(double(n)) + + group.iterate(n=make_list()) + + dr = dag.test() + states = self.get_states(dr) + expected = { + "group.double": {0: "success", 1: "skipped", 2: "success"}, + "group.last": {0: "success", 1: "skipped", 2: "success"}, + "make_list": "success", + } + assert states == expected + + def test_skip_one_mapped_task_from_task_group(self, dag_maker): + with dag_maker() as dag: + + @task + def double(n): + if n == 2: + raise AirflowSkipException() + return n * 2 + + @task + def last(n): ... 
+ + @task_group + def group(n: int) -> None: + last(double(n)) + + group.iterate(n=[1, 2, 3]) + + dr = dag.test() + states = self.get_states(dr) + expected = { + "group.double": {0: "success", 1: "skipped", 2: "success"}, + "group.last": {0: "success", 1: "skipped", 2: "success"}, + } + assert states == expected diff --git a/task-sdk/src/airflow/sdk/bases/decorator.py b/task-sdk/src/airflow/sdk/bases/decorator.py index 6780fe29fe27d..c0bfea661b9fe 100644 --- a/task-sdk/src/airflow/sdk/bases/decorator.py +++ b/task-sdk/src/airflow/sdk/bases/decorator.py @@ -534,6 +534,12 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg: ) return XComArg(operator=operator) + def iterate(self, **mapped_kwargs: OperatorExpandArgument) -> XComArg: + raise NotImplementedError + + def iterate_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg: + raise NotImplementedError + def partial(self, **kwargs: Any) -> _TaskDecorator[FParams, FReturn, OperatorSubclass]: self._validate_arg_names("partial", kwargs) old_kwargs = self.kwargs.get("op_kwargs", {}) diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py index 3f74694aaa7c4..7e6d19af571d3 100644 --- a/task-sdk/src/airflow/sdk/bases/operator.py +++ b/task-sdk/src/airflow/sdk/bases/operator.py @@ -1607,6 +1607,26 @@ def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None, execute_callable = getattr(self, next_method) return execute_callable(context, **next_kwargs) + def next_callable( + self, next_method: str, next_kwargs: dict[str, Any] | None = None + ) -> Callable[..., Any]: + """Get the next callable from given operator.""" + from airflow.exceptions import TaskDeferralError + + # __fail__ is a special signal value for next_method that indicates + # this task was scheduled specifically to fail. 
+ if next_method == "__fail__": + next_kwargs = next_kwargs or {} + traceback = next_kwargs.get("traceback") + if traceback is not None: + self.log.error("Trigger failed:\n%s", "\n".join(traceback)) + raise TaskDeferralError(next_kwargs.get("error", "Unknown")) + # Grab the callable off the Operator/Task and add in any kwargs + execute_callable = getattr(self, next_method) + if next_kwargs: + execute_callable = partial(execute_callable, **next_kwargs) + return execute_callable + + def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None: r""" @@ -1843,9 +1863,9 @@ def chain_linear(*elements: DependencyMixin | Sequence[DependencyMixin]): E.g.: suppose you want precedence like so:: - ╭─op2─╮ ╭─op4─╮ - op1─┤ ├─├─op5─┤─op7 - ╰-op3─╯ ╰-op6─╯ + ╭─op2─╮ ╭─op4─╮ + op1─┤ ├─├─op5─┤─op7 + ╰-op3─╯ ╰-op6─╯ Then you can accomplish like so:: diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py index abcb0366eb4d9..de2a832a7f6f8 100644 --- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py +++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py @@ -68,6 +68,7 @@ OperatorExpandArgument, OperatorExpandKwargsArgument, ) + from airflow.models.iterableoperator import IterableOperator from airflow.sdk.bases.operator import BaseOperator from airflow.sdk.bases.operatorlink import BaseOperatorLink from airflow.sdk.definitions.dag import DAG @@ -84,7 +85,7 @@ TaskStateChangeCallbackAttrType = Union[None, TaskStateChangeCallback, list[TaskStateChangeCallback]] -ValidationSource = Union[Literal["expand"], Literal["partial"]] +ValidationSource = Union[Literal["expand"], Literal["iterate"], Literal["partial"]] def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None: @@ -254,6 +255,47 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator: ) return op + def iterate(self, **mapped_kwargs:
OperatorExpandArgument) -> IterableOperator: + if not mapped_kwargs: + raise TypeError("no arguments to expand against") + validate_mapping_kwargs(self.operator_class, "iterate", mapped_kwargs) + prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified") + return self._iterate(DictOfListsExpandInput(mapped_kwargs)) + + def iterate_kwargs(self, kwargs: OperatorExpandKwargsArgument) -> IterableOperator: + if isinstance(kwargs, Sequence): + for item in kwargs: + if not isinstance(item, (XComArg, Mapping)): + raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") + elif not isinstance(kwargs, XComArg): + raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") + return self._iterate(ListOfDictsExpandInput(kwargs)) + + def _iterate(self, expand_input) -> IterableOperator: + from airflow.models.iterableoperator import IterableOperator + + ensure_xcomarg_return_value(expand_input.value) + kwargs = {} + for parameter_name in BaseOperator._comps: + parameter_value = self.kwargs.get(parameter_name) + if parameter_value: + kwargs[parameter_name] = parameter_value + # We don't retry the whole stream operator, we retry the individual tasks + kwargs["retries"] = 0 + # We don't want to time out the whole stream operator, we only time out the individual tasks + kwargs["timeout"] = kwargs.pop("execution_timeout", None) + kwargs["max_active_tis_per_dag"] = self.kwargs.get("max_active_tis_per_dag") + self.kwargs.pop("task_group", None) + self.kwargs["task_id"] = kwargs["task_id"] = kwargs["task_id"].rsplit(".", 1)[-1] + self.kwargs["do_xcom_push"] = self.kwargs.get("do_xcom_push", True) + self._expand_called = True + return IterableOperator( + **kwargs, + operator_class=self.operator_class, + expand_input=expand_input, + partial_kwargs=self.kwargs, + ) + @attrs.define( kw_only=True, diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 
2a93585304cb0..d7092559579a2 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -351,6 +351,8 @@ def resolve(self, context: Mapping[str, Any]) -> Any: map_indexes=map_indexes, ) if not isinstance(result, ArgNotSet): + if isinstance(result, ResolveMixin): + return result.resolve(context) return result if self.key == XCOM_RETURN_KEY: return None @@ -377,21 +379,33 @@ def _get_callable_name(f: Callable | str) -> str: return "" -class _MapResult(Sequence): - def __init__(self, value: Sequence | dict, callables: MapCallables) -> None: +class _MapResult(Sequence, Iterable): + def __init__(self, value: Sequence | Iterable, callables: list) -> None: self.value = value self.callables = callables - def __getitem__(self, index: Any) -> Any: - value = self.value[index] + def __getitem__(self, index: int) -> Any: + if not (0 <= index < len(self)): + raise IndexError - for f in self.callables: - value = f(value) - return value + value = self.value[index] + return self._apply_callables(value) def __len__(self) -> int: + if isinstance(self.value, Iterable): + raise TypeError + return len(self.value) + def __iter__(self) -> Iterator: + for item in iter(self.value): + yield self._apply_callables(item) + + def _apply_callables(self, value): + for func in self.callables: + value = func(value) + return value + class MapXComArg(XComArg): """ @@ -429,8 +443,10 @@ def map(self, f: Callable[[Any], Any]) -> MapXComArg: def resolve(self, context: Mapping[str, Any]) -> Any: value = self.arg.resolve(context) - if not isinstance(value, (Sequence, dict)): + if not isinstance(value, (Sequence, Iterable, dict)): raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}") + if isinstance(value, ResolveMixin): + value = value.resolve(context) return _MapResult(value, self.callables)