diff --git a/airflow-core/.pre-commit-config.yaml b/airflow-core/.pre-commit-config.yaml index 7573eec4e6533..121b51d4e8b55 100644 --- a/airflow-core/.pre-commit-config.yaml +++ b/airflow-core/.pre-commit-config.yaml @@ -376,6 +376,7 @@ repos: ^src/airflow/timetables/assets\.py$| ^src/airflow/timetables/base\.py$| ^src/airflow/timetables/simple\.py$| + ^src/airflow/triggers/base\.py$| ^src/airflow/utils/cli\.py$| ^src/airflow/utils/context\.py$| ^src/airflow/utils/dag_cycle_tester\.py$| diff --git a/airflow-core/src/airflow/api_fastapi/common/dagbag.py b/airflow-core/src/airflow/api_fastapi/common/dagbag.py index 3ca4483ce876a..c7630cde9f7d3 100644 --- a/airflow-core/src/airflow/api_fastapi/common/dagbag.py +++ b/airflow-core/src/airflow/api_fastapi/common/dagbag.py @@ -84,7 +84,7 @@ def get_dag_for_run_or_latest_version( dag: SerializedDAG | None = None if dag_run: if dag_run.created_dag_version_id: - dag = dag_bag._get_dag(dag_run.created_dag_version_id, session=session) + dag = dag_bag.get_dag(dag_run.created_dag_version_id, session=session) if not dag: dag = dag_bag.get_dag_for_run(dag_run, session=session) elif dag_id: diff --git a/airflow-core/src/airflow/executors/workloads/trigger.py b/airflow-core/src/airflow/executors/workloads/trigger.py index 25bca9ce44b13..2959cde6ee380 100644 --- a/airflow-core/src/airflow/executors/workloads/trigger.py +++ b/airflow-core/src/airflow/executors/workloads/trigger.py @@ -35,8 +35,11 @@ class RunTrigger(BaseModel): """ id: int - ti: TaskInstanceDTO | None # Could be none for asset-based triggers. classpath: str # Dot-separated name of the module+fn to import and run this workload. encrypted_kwargs: str + ti: TaskInstanceDTO | None = None # Could be none for asset-based triggers. timeout_after: datetime | None = None type: Literal["RunTrigger"] = Field(init=False, default="RunTrigger") + dag_data: dict | None = ( + None # Serialized Dag model in dict format so it can be deserialized in trigger subprocess. + ) diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 1406283c05cb3..7fb5fdcd0187c 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -48,6 +48,7 @@ from airflow.executors.workloads.task import TaskInstanceDTO from airflow.jobs.base_job_runner import BaseJobRunner from airflow.jobs.job import perform_heartbeat +from airflow.models.dagbag import DBDagBag from airflow.models.trigger import Trigger from airflow.observability.metrics import stats_utils from airflow.sdk.api.datamodels._generated import HITLDetailResponse @@ -81,10 +82,12 @@ _RequestFrame, ) from airflow.sdk.execution_time.supervisor import WatchedSubprocess, make_buffered_socket_reader +from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance +from airflow.serialization.serialized_objects import DagSerialization from airflow.triggers.base import BaseEventTrigger, BaseTrigger, DiscrimatedTriggerEvent, TriggerEvent from airflow.utils.helpers import log_filename_template_renderer from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.session import provide_session +from airflow.utils.session import NEW_SESSION, create_session, provide_session if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -93,6 +96,7 @@ from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI from airflow.jobs.job import Job from airflow.sdk.api.client import Client + from airflow.sdk.definitions.context import Context from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI logger = logging.getLogger(__name__) @@ -626,6 +630,66 @@ def emit_metrics(self): extra_tags={"hostname": self.job.hostname}, ) + @provide_session + def create_workload( + self, + trigger: Trigger, + dag_bag: DBDagBag, + render_log_fname=log_filename_template_renderer(), + session: Session = NEW_SESSION, + ) -> workloads.RunTrigger | None: + if trigger.task_instance is None: + return workloads.RunTrigger( + id=trigger.id, + classpath=trigger.classpath, + encrypted_kwargs=trigger.encrypted_kwargs, + ) + + if not trigger.task_instance.dag_version_id: + # This is to handle 2 to 3 upgrade where TI.dag_version_id can be none + log.warning( + "TaskInstance associated with Trigger has no associated Dag Version, skipping the trigger", + ti_id=trigger.task_instance.id, + ) + return None + + log_path = render_log_fname(ti=trigger.task_instance) + ser_ti = TaskInstanceDTO.model_validate(trigger.task_instance, from_attributes=True) + + # When producing logs from TIs, include the job id producing the logs to disambiguate it. + self.logger_cache[trigger.id] = TriggerLoggingFactory( + log_path=f"{log_path}.trigger.{self.job.id}.log", + ti=ser_ti, # type: ignore + ) + + serialized_dag_model = dag_bag.get_dag_model( + version_id=trigger.task_instance.dag_version_id, + session=session, + ) + + if serialized_dag_model: + task = serialized_dag_model.dag.get_task(trigger.task_instance.task_id) + + # When a TaskInstance of a Trigger contains a task with start_from_trigger enabled, + # it means we need to load the SerializedDagModel so we can build a RuntimeTaskInstance later on which + # will allow us to build a context on which we will render the templated fields. + if task.start_from_trigger: + return workloads.RunTrigger( + id=trigger.id, + classpath=trigger.classpath, + encrypted_kwargs=trigger.encrypted_kwargs, + ti=ser_ti, + timeout_after=trigger.task_instance.trigger_timeout, + dag_data=serialized_dag_model.data, + ) + return workloads.RunTrigger( + id=trigger.id, + classpath=trigger.classpath, + encrypted_kwargs=trigger.encrypted_kwargs, + ti=ser_ti, + timeout_after=trigger.task_instance.trigger_timeout, + ) + def update_triggers(self, requested_trigger_ids: set[int]): """ Request that we update what triggers we're running. @@ -634,8 +698,7 @@ def update_triggers(self, requested_trigger_ids: set[int]): adds them to the dequeues so the subprocess can actually mutate the running trigger set. """ - render_log_fname = log_filename_template_renderer() - + dag_bag = DBDagBag() known_trigger_ids = ( self.running_triggers.union(x[0] for x in self.events) .union(self.cancelling_triggers) @@ -646,60 +709,45 @@ def update_triggers(self, requested_trigger_ids: set[int]): new_trigger_ids = requested_trigger_ids - known_trigger_ids cancel_trigger_ids = self.running_triggers - requested_trigger_ids # Bulk-fetch new trigger records - new_triggers = Trigger.bulk_fetch(new_trigger_ids) - trigger_ids_with_non_task_associations = Trigger.fetch_trigger_ids_with_non_task_associations() - to_create: list[workloads.RunTrigger] = [] - # Add in new triggers - for new_id in new_trigger_ids: - # Check it didn't vanish in the meantime - if new_id not in new_triggers: - log.warning("Trigger disappeared before we could start it", id=new_id) - continue - - new_trigger_orm = new_triggers[new_id] - - # If the trigger is not associated to a task, an asset, or a callback, this means the TaskInstance - # row was updated by either Trigger.submit_event or Trigger.submit_failure - # and can happen when a single trigger Job is being run on multiple TriggerRunners - # in a High-Availability setup. - if new_trigger_orm.task_instance is None and new_id not in trigger_ids_with_non_task_associations: - log.info( - ( - "TaskInstance Trigger is None. It was likely updated by another trigger job. " - "Skipping trigger instantiation." - ), - id=new_id, - ) - continue - - workload = workloads.RunTrigger( - classpath=new_trigger_orm.classpath, - id=new_id, - encrypted_kwargs=new_trigger_orm.encrypted_kwargs, - ti=None, + with create_session() as session: + # Bulk-fetch new trigger records + new_triggers = Trigger.bulk_fetch(new_trigger_ids, session=session) + trigger_ids_with_non_task_associations = Trigger.fetch_trigger_ids_with_non_task_associations( + session=session ) - if new_trigger_orm.task_instance: - log_path = render_log_fname(ti=new_trigger_orm.task_instance) - if not new_trigger_orm.task_instance.dag_version_id: - # This is to handle 2 to 3 upgrade where TI.dag_version_id can be none - log.warning( - "TaskInstance associated with Trigger has no associated Dag Version, skipping the trigger", - ti_id=new_trigger_orm.task_instance.id, - ) + to_create: list[workloads.RunTrigger] = [] + # Add in new triggers + for new_trigger_id in new_trigger_ids: + # Check it didn't vanish in the meantime + if new_trigger_id not in new_triggers: + log.warning("Trigger disappeared before we could start it", id=new_trigger_id) continue - ser_ti = TaskInstanceDTO.model_validate(new_trigger_orm.task_instance, from_attributes=True) - # When producing logs from TIs, include the job id producing the logs to disambiguate it. - self.logger_cache[new_id] = TriggerLoggingFactory( - log_path=f"{log_path}.trigger.{self.job.id}.log", - ti=ser_ti, # type: ignore - ) - workload.ti = ser_ti - workload.timeout_after = new_trigger_orm.task_instance.trigger_timeout + new_trigger_orm = new_triggers[new_trigger_id] + + # If the trigger is not associated to a task, an asset, or a callback, this means the TaskInstance + # row was updated by either Trigger.submit_event or Trigger.submit_failure + # and can happen when a single trigger Job is being run on multiple TriggerRunners + # in a High-Availability setup. + if ( + new_trigger_orm.task_instance is None + and new_trigger_id not in trigger_ids_with_non_task_associations + ): + log.info( + ( + "TaskInstance of Trigger is None. It was likely updated by another trigger job. " + "Skipping trigger instantiation." + ), + id=new_trigger_id, + ) + continue - to_create.append(workload) + if workload := self.create_workload( + trigger=new_trigger_orm, dag_bag=dag_bag, session=session + ): + to_create.append(workload) - self.creating_triggers.extend(to_create) + self.creating_triggers.extend(to_create) if cancel_trigger_ids: # Enqueue orphaned triggers for cancellation @@ -954,9 +1002,19 @@ async def init_comms(self): raise RuntimeError(f"Required first message to be a messages.StartTriggerer, it was {msg}") async def create_triggers(self): + def create_runtime_ti(encoded_dag: dict) -> RuntimeTaskInstance: + task = DagSerialization.from_dict(encoded_dag).get_task(workload.ti.task_id) + + # I need to recreate a TaskInstance from task_runner before invoking get_template_context (airflow.executors.workloads.TaskInstance) + return RuntimeTaskInstance.model_construct( + **workload.ti.model_dump(exclude_unset=True), + task=task, + ) + """Drain the to_create queue and create all new triggers that have been requested in the DB.""" while self.to_create: await asyncio.sleep(0) + context: Context | None = None workload = self.to_create.popleft() trigger_id = workload.id if trigger_id in self.triggers: @@ -984,24 +1042,32 @@ async def create_triggers(self): # that could cause None values in collections. kw = Trigger._decrypt_kwargs(workload.encrypted_kwargs) deserialised_kwargs = {k: smart_decode_trigger_kwargs(v) for k, v in kw.items()} - trigger_instance = trigger_class(**deserialised_kwargs) + + if ti := workload.ti: + trigger_name = f"{ti.dag_id}/{ti.run_id}/{ti.task_id}/{ti.map_index}/{ti.try_number} (ID {trigger_id})" + trigger_instance = trigger_class(**deserialised_kwargs) + + if workload.dag_data: + runtime_ti = create_runtime_ti(workload.dag_data) + context = runtime_ti.get_template_context() + trigger_instance.task_instance = runtime_ti + else: + trigger_instance.task_instance = ti + else: + trigger_name = f"ID {trigger_id}" + trigger_instance = trigger_class(**deserialised_kwargs) except TypeError as err: self.log.error("Trigger failed to inflate", error=err) self.failed_triggers.append((trigger_id, err)) continue trigger_instance.trigger_id = trigger_id trigger_instance.triggerer_job_id = self.job_id - trigger_instance.task_instance = ti = workload.ti trigger_instance.timeout_after = workload.timeout_after - trigger_name = ( - f"{ti.dag_id}/{ti.run_id}/{ti.task_id}/{ti.map_index}/{ti.try_number} (ID {trigger_id})" - if ti - else f"ID {trigger_id}" - ) self.triggers[trigger_id] = { "task": asyncio.create_task( - self.run_trigger(trigger_id, trigger_instance, workload.timeout_after), name=trigger_name + self.run_trigger(trigger_id, trigger_instance, workload.timeout_after, context), + name=trigger_name, ), "is_watcher": isinstance(trigger_instance, BaseEventTrigger), "name": trigger_name, @@ -1168,7 +1234,13 @@ async def block_watchdog(self): ) Stats.incr("triggers.blocked_main_thread") - async def run_trigger(self, trigger_id: int, trigger: BaseTrigger, timeout_after: datetime | None = None): + async def run_trigger( + self, + trigger_id: int, + trigger: BaseTrigger, + timeout_after: datetime | None = None, + context: Context | None = None, + ): """Run a trigger (they are async generators) and push their events into our outbound event deque.""" if not os.environ.get("AIRFLOW_DISABLE_GREENBACK_PORTAL", "").lower() == "true": import greenback @@ -1180,6 +1252,9 @@ async def run_trigger(self, trigger_id: int, trigger: BaseTrigger, timeout_after name = self.triggers[trigger_id]["name"] self.log.info("trigger %s starting", name) try: + if context: + trigger.render_template_fields(context=context) + async for event in trigger.run(): await self.log.ainfo( "Trigger fired event", name=self.triggers[trigger_id]["name"], result=event diff --git a/airflow-core/src/airflow/models/dagbag.py b/airflow-core/src/airflow/models/dagbag.py index e04f77d06df34..dd399f0d8f3a9 100644 --- a/airflow-core/src/airflow/models/dagbag.py +++ b/airflow-core/src/airflow/models/dagbag.py @@ -45,24 +45,27 @@ class DBDagBag: """ def __init__(self, load_op_links: bool = True) -> None: - self._dags: dict[UUID, SerializedDAG] = {} # dag_version_id to dag + self._dags: dict[UUID, SerializedDagModel] = {} # dag_version_id to dag self.load_op_links = load_op_links - def _read_dag(self, serdag: SerializedDagModel) -> SerializedDAG | None: - serdag.load_op_links = self.load_op_links - if dag := serdag.dag: - self._dags[serdag.dag_version_id] = dag + def _read_dag(self, serialized_dag_model: SerializedDagModel) -> SerializedDAG | None: + serialized_dag_model.load_op_links = self.load_op_links + if dag := serialized_dag_model.dag: + self._dags[serialized_dag_model.dag_version_id] = serialized_dag_model return dag - def _get_dag(self, version_id: UUID, session: Session) -> SerializedDAG | None: - if dag := self._dags.get(version_id): - return dag - dag_version = session.get(DagVersion, version_id, options=[joinedload(DagVersion.serialized_dag)]) - if not dag_version: - return None - if not (serdag := dag_version.serialized_dag): - return None - return self._read_dag(serdag) + def get_dag_model(self, version_id: UUID, session: Session) -> SerializedDagModel | None: + if not (serialized_dag_model := self._dags.get(version_id)): + dag_version = session.get(DagVersion, version_id, options=[joinedload(DagVersion.serialized_dag)]) + if not dag_version or not (serialized_dag_model := dag_version.serialized_dag): + return None + self._read_dag(serialized_dag_model) + return serialized_dag_model + + def get_dag(self, version_id: UUID, session: Session) -> SerializedDAG | None: + if serialized_dag_model := self.get_dag_model(version_id=version_id, session=session): + return serialized_dag_model.dag + return None @staticmethod def _version_from_dag_run(dag_run: DagRun, *, session: Session) -> UUID | None: @@ -74,24 +77,24 @@ def _version_from_dag_run(dag_run: DagRun, *, session: Session) -> UUID | None: def get_dag_for_run(self, dag_run: DagRun, session: Session) -> SerializedDAG | None: if version_id := self._version_from_dag_run(dag_run=dag_run, session=session): - return self._get_dag(version_id=version_id, session=session) + return self.get_dag(version_id=version_id, session=session) return None def iter_all_latest_version_dags(self, *, session: Session) -> Generator[SerializedDAG, None, None]: """Walk through all latest version dags available in the database.""" from airflow.models.serialized_dag import SerializedDagModel - for sdm in session.scalars(select(SerializedDagModel)): - if dag := self._read_dag(sdm): + for serialized_dag_model in session.scalars(select(SerializedDagModel)): + if dag := self._read_dag(serialized_dag_model): yield dag def get_latest_version_of_dag(self, dag_id: str, *, session: Session) -> SerializedDAG | None: """Get the latest version of a dag by its id.""" from airflow.models.serialized_dag import SerializedDagModel - if not (serdag := SerializedDagModel.get(dag_id, session=session)): + if not (serialized_dag_model := SerializedDagModel.get(dag_id, session=session)): return None - return self._read_dag(serdag) + return self._read_dag(serialized_dag_model) def generate_md5_hash(context): diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py index c37713da4d843..ef916c11a0293 100644 --- a/airflow-core/src/airflow/models/dagrun.py +++ b/airflow-core/src/airflow/models/dagrun.py @@ -1961,7 +1961,14 @@ def schedule_tis( debug_try_number_check = self.log.isEnabledFor(logging.DEBUG) expected_try_number_by_ti_id: dict[UUID, tuple[int, int, str | None]] = {} for ti in schedulable_tis: - if ti.is_schedulable: + if not ti.is_schedulable: + empty_ti_ids.append(ti.id) + # The defer_task method will check "start_trigger_args" to see whether the operator + # start execution from triggerer. If so, we'll also check "start_from_trigger" + # to see whether this feature is turned on and defer this task. + # If not, we'll add this "ti" into "schedulable_ti_ids" and later + # execute it to run in the worker. + elif not ti.defer_task(session=session): schedulable_ti_ids.append(ti.id) if ti.state == TaskInstanceState.UP_FOR_RESCHEDULE: reschedule_ti_ids.add(ti.id) @@ -1973,25 +1980,6 @@ def schedule_tis( ti.try_number, ti.state, ) - # Check "start_trigger_args" to see whether the operator supports - # start execution from triggerer. If so, we'll check "start_from_trigger" - # to see whether this feature is turned on and defer this task. - # If not, we'll add this "ti" into "schedulable_ti_ids" and later - # execute it to run in the worker. - # TODO TaskSDK: This is disabled since we haven't figured out how - # to render start_from_trigger in the scheduler. If we need to - # render the value in a worker, it kind of defeats the purpose of - # this feature (which is to save a worker process if possible). - # elif task.start_trigger_args is not None: - # if task.expand_start_from_trigger(context=ti.get_template_context()): - # ti.start_date = timezone.utcnow() - # if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE: - # ti.try_number += 1 - # ti.defer_task(exception=None, session=session) - # else: - # schedulable_ti_ids.append(ti.id) - else: - empty_ti_ids.append(ti.id) count = 0 # Don't only check if the TI.id is in id_chunk diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index 443468161f8ab..33a1559964898 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -119,7 +119,7 @@ from airflow.serialization.definitions.dag import SerializedDAG from airflow.serialization.definitions.mappedoperator import Operator from airflow.serialization.definitions.taskgroup import SerializedTaskGroup - + from airflow.triggers.base import StartTriggerArgs PAST_DEPENDS_MET = "past_depends_met" @@ -1583,6 +1583,73 @@ def update_heartbeat(self): .values(last_heartbeat_at=timezone.utcnow()) ) + @property + def start_trigger_args(self) -> StartTriggerArgs | None: + if self.task and self.task.start_from_trigger is True: + return self.task.start_trigger_args + return None + + # TODO: We have some code duplication here and in the _create_ti_state_update_query_and_update_state + # method of the task_instances module in the execution api when a TIDeferredStatePayload is being + # processed. This is because of a TaskInstance being updated differently using SQLAlchemy. + # If we use the approach from the execution api as common code in the DagRun schedule_tis method, + # the side effect is the changes done to the task instance aren't picked up by the scheduler and + # thus the task instance isn't processed until the scheduler is restarted. + @provide_session + def defer_task(self, session: Session = NEW_SESSION) -> bool: + """ + Mark the task as deferred and sets up the trigger that is needed to resume it when TaskDeferred is raised. + + :meta: private + """ + from airflow.models.trigger import Trigger + + if TYPE_CHECKING: + assert self.start_date + assert isinstance(self.task, Operator) + + if start_trigger_args := self.start_trigger_args: + trigger_kwargs = start_trigger_args.trigger_kwargs or {} + timeout = start_trigger_args.timeout + + # Calculate timeout too if it was passed + if timeout is not None: + self.trigger_timeout = timezone.utcnow() + timeout + else: + self.trigger_timeout = None + + trigger_row = Trigger( + classpath=start_trigger_args.trigger_cls, + kwargs=trigger_kwargs, + ) + + # First, make the trigger entry + session.add(trigger_row) + session.flush() + + # Then, update ourselves so it matches the deferral request + # Keep an eye on the logic in `check_and_change_state_before_execution()` + # depending on self.next_method semantics + self.state = TaskInstanceState.DEFERRED + self.trigger_id = trigger_row.id + self.next_method = start_trigger_args.next_method + self.next_kwargs = start_trigger_args.next_kwargs or {} + + # If an execution_timeout is set, set the timeout to the minimum of + # it and the trigger timeout + if execution_timeout := self.task.execution_timeout: + if self.trigger_timeout: + self.trigger_timeout = min(self.start_date + execution_timeout, self.trigger_timeout) + else: + self.trigger_timeout = self.start_date + execution_timeout + self.start_date = timezone.utcnow() + if self.state != TaskInstanceState.UP_FOR_RESCHEDULE: + self.try_number += 1 + if self.test_mode: + _add_log(event=self.state, task_instance=self, session=session) + return True + return False + @classmethod def fetch_handle_failure_context( cls, diff --git a/airflow-core/src/airflow/serialization/definitions/mappedoperator.py b/airflow-core/src/airflow/serialization/definitions/mappedoperator.py index 1cf6d357e651a..6a8e94f02b7d0 100644 --- a/airflow-core/src/airflow/serialization/definitions/mappedoperator.py +++ b/airflow-core/src/airflow/serialization/definitions/mappedoperator.py @@ -481,9 +481,9 @@ def expand_start_from_trigger(self, *, context: Context) -> bool: return False # TODO (GH-52141): Implement this. log.warning( - "Starting a mapped task from triggerer is currently unsupported", - task_id=self.task_id, - dag_id=self.dag_id, + "Starting a mapped task %r from dag %r on triggerer is currently unsupported", + self.task_id, + self.dag_id, ) return False diff --git a/airflow-core/src/airflow/triggers/base.py b/airflow-core/src/airflow/triggers/base.py index 416558242b8a0..7ca7ed20a7463 100644 --- a/airflow-core/src/airflow/triggers/base.py +++ b/airflow-core/src/airflow/triggers/base.py @@ -21,7 +21,7 @@ from collections.abc import AsyncIterator from dataclasses import dataclass from datetime import timedelta -from typing import Annotated, Any +from typing import TYPE_CHECKING, Annotated, Any import structlog from pydantic import ( @@ -32,11 +32,24 @@ model_serializer, ) +from airflow.sdk.definitions._internal.templater import Templater from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import TaskInstanceState log = structlog.get_logger(logger_name=__name__) +if TYPE_CHECKING: + from typing import TypeAlias + + import jinja2 + + from airflow.models.mappedoperator import MappedOperator + from airflow.models.taskinstance import TaskInstance + from airflow.sdk.definitions.context import Context + from airflow.serialization.serialized_objects import SerializedBaseOperator + + Operator: TypeAlias = MappedOperator | SerializedBaseOperator + @dataclass class StartTriggerArgs: @@ -49,7 +62,7 @@ class StartTriggerArgs: timeout: timedelta | None = None -class BaseTrigger(abc.ABC, LoggingMixin): +class BaseTrigger(abc.ABC, Templater, LoggingMixin): """ Base class for all triggers. @@ -66,14 +79,56 @@ class BaseTrigger(abc.ABC, LoggingMixin): supports_triggerer_queue: bool = True def __init__(self, **kwargs): + super().__init__() # these values are set by triggerer when preparing to run the instance # when run, they are injected into logger record. - self.task_instance = None + self._task_instance = None self.trigger_id = None + self.template_fields = () + self.template_ext = () + self.task_id = None def _set_context(self, context): """Part of LoggingMixin and used mainly for configuration of task logging; not used for triggers.""" - raise NotImplementedError + pass + + @property + def task(self) -> Operator | None: + # We must check if the TaskInstance is the generated Pydantic one or the RuntimeTaskInstance + if self.task_instance and hasattr(self.task_instance, "task"): + return self.task_instance.task + return None + + @property + def task_instance(self) -> TaskInstance: + return self._task_instance + + @task_instance.setter + def task_instance(self, value: TaskInstance | None) -> None: + self._task_instance = value + if self.task_instance: + self.task_id = self.task_instance.task_id + if self.task: + self.template_fields = self.task.template_fields + self.template_ext = self.task.template_ext + + def render_template_fields( + self, + context: Context, + jinja_env: jinja2.Environment | None = None, + ) -> None: + """ + Template all attributes listed in *self.template_fields*. + + This mutates the attributes in-place and is irreversible. + + :param context: Context dict with values to apply on content. + :param jinja_env: Jinja's environment to use for rendering. + """ + if not jinja_env: + jinja_env = self.get_template_env() + # We only need to render templated fields if templated fields are part of the start_trigger_args + self._do_render_template_fields(self, self.template_fields, context, jinja_env, set()) @abc.abstractmethod def serialize(self) -> tuple[str, dict[str, Any]]: diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index 802a34192e352..cb18cea93ae12 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -111,9 +111,9 @@ def create_trigger_in_db(session, trigger, operator=None): session.merge(testing_bundle) session.flush() - dag_model = DagModel(dag_id="test_dag", bundle_name=bundle_name) - dag = DAG(dag_id=dag_model.dag_id, schedule="@daily", start_date=pendulum.datetime(2023, 1, 1)) date = pendulum.datetime(2023, 1, 1) + dag_model = DagModel(dag_id="test_dag", bundle_name=bundle_name) + dag = DAG(dag_id=dag_model.dag_id, schedule="@daily", start_date=date) run = DagRun( dag_id=dag_model.dag_id, run_id="test_run", @@ -256,6 +256,7 @@ def send_msg_spy(self, msg, *args, **kwargs): classpath=trigger.serialize()[0], encrypted_kwargs=trigger_orm.encrypted_kwargs, kind="RunTrigger", + dag_data=ANY, ) ) # OK, now remove it from the DB diff --git a/airflow-core/tests/unit/models/test_dagbag.py b/airflow-core/tests/unit/models/test_dagbag.py index 3b5b98877262e..8f56e0200fc2a 100644 --- a/airflow-core/tests/unit/models/test_dagbag.py +++ b/airflow-core/tests/unit/models/test_dagbag.py @@ -16,8 +16,14 @@ # under the License. from __future__ import annotations +from unittest.mock import MagicMock, patch + import pytest +from airflow.models.dagbag import DBDagBag +from airflow.models.serialized_dag import SerializedDagModel +from airflow.serialization.serialized_objects import SerializedDAG + pytestmark = pytest.mark.db_test # This file previously contained tests for DagBag functionality, but those tests @@ -26,3 +32,76 @@ # # Tests for models-specific functionality (DBDagBag, DagPriorityParsingRequest, etc.) # would remain in this file, but currently no such tests exist. + + +class TestDBDagBag: + def setup_method(self): + self.db_dag_bag = DBDagBag() + self.session = MagicMock() + + def test__read_dag_stores_and_returns_dag(self): + """It should store the SerializedDagModel in _dags and return the dag.""" + mock_dag = MagicMock(spec=SerializedDAG) + mock_serdag = MagicMock(spec=SerializedDagModel) + mock_serdag.dag = mock_dag + mock_serdag.dag_version_id = "v1" + + result = self.db_dag_bag._read_dag(mock_serdag) + + assert result == mock_dag + assert self.db_dag_bag._dags["v1"] == mock_serdag + assert mock_serdag.load_op_links is True + + def test__read_dag_returns_none_when_no_dag(self): + """It should return None and not modify _dags when no DAG is present.""" + mock_serdag = MagicMock(spec=SerializedDagModel) + mock_serdag.dag = None + mock_serdag.dag_version_id = "v1" + + result = self.db_dag_bag._read_dag(mock_serdag) + + assert result is None + assert "v1" not in self.db_dag_bag._dags + + def test_get_dag_model(self): + """It should return the cached SerializedDagModel if already loaded.""" + mock_serdag = MagicMock(spec=SerializedDagModel) + mock_serdag.dag_version_id = "v1" + mock_dag_version = MagicMock() + mock_dag_version.serialized_dag = mock_serdag + self.session.get.return_value = mock_dag_version + + self.db_dag_bag.get_dag_model("v1", session=self.session) + result = self.db_dag_bag.get_dag_model("v1", session=self.session) + + assert result == mock_serdag + self.session.get.assert_called_once() + + def test_get_dag_model_returns_none_when_not_found(self): + """It should return None if version_id not found in DB.""" + self.session.get.return_value = None + + result = self.db_dag_bag.get_dag_model("v1", session=self.session) + + assert result is None + + def test_get_dag_calls_get_dag_model_and__read_dag(self): + """It should call get_dag_model and then _read_dag.""" + mock_serdag = MagicMock(spec=SerializedDagModel) + mock_serdag.dag_version_id = "v1" + mock_dag = MagicMock(spec=SerializedDAG) + mock_dag_version = MagicMock() + mock_dag_version.serialized_dag = mock_serdag + mock_serdag.dag = mock_dag + self.session.get.return_value = mock_dag_version + + result = self.db_dag_bag.get_dag("v1", session=self.session) + + self.session.get.assert_called_once() + assert result == mock_dag + + def test_get_dag_returns_none_when_model_missing(self): + """It should return None if no SerializedDagModel found.""" + with patch.object(self.db_dag_bag, "get_dag_model", return_value=None): + result = self.db_dag_bag.get_dag("v1", session=self.session) + assert result is None diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index 82b9adc3162ae..e0de7930dbfe6 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -2649,6 +2649,103 @@ def mock_policy(task_instance: TaskInstance): assert ti.max_tries == expected_max_tries +def test_defer_task_returns_false_when_no_start_from_trigger(create_task_instance): + session = mock.MagicMock() + ti = create_task_instance( + dag_id="test_defer_task", + task_id="test_defer_task_op", + ) + assert not ti.defer_task(session=session) + + +def test_defer_task_returns_false_when_no_start_trigger_args(create_task_instance): + session = mock.MagicMock() + ti = create_task_instance( + dag_id="test_defer_task", + task_id="test_defer_task", + start_from_trigger=True, + ) + assert not ti.defer_task(session=session) + + +def test_defer_task(create_task_instance): + from airflow.models.trigger import Trigger + from airflow.triggers.base import StartTriggerArgs + + session = mock.MagicMock() + ti = create_task_instance( + dag_id="test_defer_task", + task_id="test_defer_task_op", + start_from_trigger=True, + start_trigger_args=StartTriggerArgs( + trigger_cls="trigger_cls", + next_method="next_method", + trigger_kwargs={"key": "value"}, + ), + ) + assert ti.defer_task(session=session) + + # Check that session.add was called with a Trigger + assert session.add.call_count == 1 + trigger_row = session.add.call_args[0][0] + assert isinstance(trigger_row, Trigger) + assert trigger_row.classpath == "trigger_cls" + assert trigger_row.kwargs == {"key": "value"} + + # Check that session.flush was called + session.flush.assert_called_once() + + # Check that TaskInstance state was updated + assert ti.state == TaskInstanceState.DEFERRED + assert ti.trigger_id == trigger_row.id + assert ti.next_method == "next_method" + assert ti.next_kwargs == {} + + # Check trigger_timeout is set (should be None since no timeout provided) + assert ti.trigger_timeout is None + + +def test_defer_task_with_trigger_timeout(create_task_instance): + from airflow.models.trigger import Trigger + from airflow.triggers.base import StartTriggerArgs + + session = mock.MagicMock() + timeout = datetime.timedelta(hours=1) + ti = create_task_instance( + dag_id="test_defer_task_with_trigger_timeout", + task_id="test_defer_task_with_trigger_timeout_op", + start_from_trigger=True, + start_trigger_args=StartTriggerArgs( + trigger_cls="trigger_cls", + next_method="next_method", + trigger_kwargs={"key": "value"}, + timeout=timeout, + ), + ) + + # Save start_date to calculate expected trigger_timeout + now = timezone.utcnow() + ti.start_date = now + + ti.defer_task(session=session) + + # Check session interactions + assert session.add.call_count == 1 + trigger_row = session.add.call_args[0][0] + assert isinstance(trigger_row, Trigger) + session.flush.assert_called_once() + + # TaskInstance fields + assert ti.state == TaskInstanceState.DEFERRED + assert ti.trigger_id == trigger_row.id + assert ti.next_method == "next_method" + assert ti.next_kwargs == {} + + # Check trigger_timeout is set correctly (within a small tolerance) + expected_timeout = now + timeout + assert abs((ti.trigger_timeout - expected_timeout).total_seconds()) < 5 + + class TestTaskInstanceRecordTaskMapXComPush: """Test TI.xcom_push() correctly records return values for task-mapping.""" diff --git a/airflow-core/tests/unit/triggers/test_base_trigger.py b/airflow-core/tests/unit/triggers/test_base_trigger.py new file mode 100644 index 0000000000000..53066c46f6a14 --- /dev/null +++ b/airflow-core/tests/unit/triggers/test_base_trigger.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.sdk.bases.operator import BaseOperator +from airflow.triggers.base import BaseTrigger, StartTriggerArgs + + +class DummyOperator(BaseOperator): + template_fields = ("name",) + + +class DummyTrigger(BaseTrigger): + def __init__(self, name: str, **kwargs): + super().__init__(**kwargs) + self.name = name + + def run(self): + return None + + def serialize(self): + return {"name": self.name} + + +@pytest.mark.db_test +def test_render_template_fields(create_task_instance): + op = DummyOperator(task_id="dummy_task") + ti = create_task_instance( + task=op, + start_from_trigger=True, + start_trigger_args=StartTriggerArgs( + trigger_cls=f"{DummyTrigger.__module__}.{DummyTrigger.__qualname__}", + next_method="resume_method", + trigger_kwargs={"name": "Hello {{ name }}"}, + ), + ) + + trigger = DummyTrigger(name="Hello {{ name }}") + + assert not trigger.task_instance + assert not trigger.template_fields + assert not trigger.template_ext + + trigger.task_instance = ti + + assert trigger.task_instance == ti + assert "name" in trigger.template_fields + assert not trigger.template_ext + + trigger.render_template_fields(context={"name": "world"}) + + assert trigger.name == "Hello world" diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index a4b590c545277..4fc35fe351483 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -53,6 +53,7 @@ from airflow.sdk.types import DagRunProtocol, Operator from airflow.serialization.definitions.dag import SerializedDAG from airflow.timetables.base import DagRunInfo, DataInterval + from airflow.triggers.base import StartTriggerArgs from airflow.typing_compat import Self from airflow.utils.state import DagRunState, TaskInstanceState @@ -1564,6 +1565,9 @@ def maker( hostname=None, pid=None, last_heartbeat_at=None, + task: Operator | None = None, + start_from_trigger: bool = False, + start_trigger_args: StartTriggerArgs | None = None, **kwargs, ) -> TaskInstance: timezone = _import_timezone() @@ -1572,26 +1576,33 @@ def maker( if logical_date is NOTSET: # For now: default to having a logical date if None is not explicitly passed. logical_date = timezone.utcnow() - with dag_maker(dag_id, **kwargs): + with dag_maker(dag_id, **kwargs) as dag: op_kwargs = {} op_kwargs["task_display_name"] = task_display_name - task = EmptyOperator( - task_id=task_id, - max_active_tis_per_dag=max_active_tis_per_dag, - max_active_tis_per_dagrun=max_active_tis_per_dagrun, - executor_config=executor_config or {}, - on_success_callback=on_success_callback, - on_execute_callback=on_execute_callback, - on_failure_callback=on_failure_callback, - on_retry_callback=on_retry_callback, - on_skipped_callback=on_skipped_callback, - inlets=inlets, - outlets=outlets, - email=email, - pool=pool, - trigger_rule=trigger_rule, - **op_kwargs, - ) + if not task: + task = EmptyOperator( + task_id=task_id, + max_active_tis_per_dag=max_active_tis_per_dag, + max_active_tis_per_dagrun=max_active_tis_per_dagrun, + executor_config=executor_config or {}, + on_success_callback=on_success_callback, + on_execute_callback=on_execute_callback, + on_failure_callback=on_failure_callback, + on_retry_callback=on_retry_callback, + on_skipped_callback=on_skipped_callback, + inlets=inlets, + outlets=outlets, + email=email, + pool=pool, + trigger_rule=trigger_rule, + **op_kwargs, + ) + else: + task_id = task.task_id + task.dag = dag + task.start_from_trigger = start_from_trigger + task.start_trigger_args = start_trigger_args + if AIRFLOW_V_3_0_PLUS: dagrun_kwargs = { "logical_date": logical_date, diff --git a/task-sdk/src/airflow/sdk/bases/operator.py b/task-sdk/src/airflow/sdk/bases/operator.py index 6e88f0a94adac..4d5905ab73bd8 100644 --- a/task-sdk/src/airflow/sdk/bases/operator.py +++ b/task-sdk/src/airflow/sdk/bases/operator.py @@ -550,6 +550,11 @@ def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> Any: # Store the args passed to init -- we need them to support task.map serialization! self._BaseOperator__init_kwargs.update(kwargs) # type: ignore + # Validate trigger kwargs. + # Make sure method exists as class can depend on metaclass without extending the BaseOperator. + if hasattr(self, "_validate_start_from_trigger_kwargs"): + self._validate_start_from_trigger_kwargs() + # Set upstream task defined by XComArgs passed to template fields of the operator. # BUT: only do this _ONCE_, not once for each class in the hierarchy if not instantiated_from_mapped and func == self.__init__.__wrapped__: # type: ignore[misc] @@ -846,6 +851,14 @@ def say_hello_world(**context): to render templates as native Python types. If False, a Jinja ``Environment`` is used to render templates as string values. If None (default), inherits from the DAG setting. + :param start_from_trigger: If True, the operator starts execution directly in the triggerer, + skipping the initial worker execution phase. In this mode, templated fields are rendered + inside the triggerer instead of the worker. This avoids an extra round trip to a worker, + but may increase load on the triggerer, since the DAG must be serialized in order to + render templated fields. Use with care for DAGs with many tasks or heavy templating. + :param start_trigger_args: Used together with ``start_from_trigger`` to explicitly specify + which operator fields should be passed to the trigger. This helps limit the amount of + data serialized and sent to the triggerer. """ task_id: str @@ -1440,6 +1453,15 @@ def _set_xcomargs_dependency(self, field: str, newvalue: Any) -> None: return XComArg.apply_upstream_relationship(self, newvalue) + def _validate_start_from_trigger_kwargs(self): + if self.start_from_trigger and self.start_trigger_args and self.start_trigger_args.trigger_kwargs: + for name, val in self.start_trigger_args.trigger_kwargs.items(): + if callable(val): + raise ValueError( + f"{self.__class__.__name__} with task_id '{self.task_id}' has a callable in trigger kwargs named " + f"'{name}', which is not allowed when start_from_trigger is enabled." + ) + def on_kill(self) -> None: """ Override this method to clean up subprocesses when a task instance gets killed. diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py index e32bd377f01d0..00b811146a688 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/abstractoperator.py @@ -285,59 +285,6 @@ def _render(self, template, context, dag: DAG | None = None): dag = self.get_dag() return super()._render(template, context, dag=dag) - def _do_render_template_fields( - self, - parent: Any, - template_fields: Iterable[str], - context: Context, - jinja_env: jinja2.Environment, - seen_oids: set[int], - ) -> None: - """Override the base to use custom error logging.""" - for attr_name in template_fields: - try: - value = getattr(parent, attr_name) - except AttributeError: - raise AttributeError( - f"{attr_name!r} is configured as a template field " - f"but {parent.task_type} does not have this attribute." - ) - try: - if not value: - continue - except Exception: - # This may happen if the templated field points to a class which does not support `__bool__`, - # such as Pandas DataFrames: - # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465 - log.info( - "Unable to check if the value of type '%s' is False for task '%s', field '%s'.", - type(value).__name__, - self.task_id, - attr_name, - ) - # We may still want to render custom classes which do not support __bool__ - pass - - try: - if callable(value): - rendered_content = value(context=context, jinja_env=jinja_env) - else: - rendered_content = self.render_template(value, context, jinja_env, seen_oids) - except Exception: - # Mask sensitive values in the template before logging - from airflow.sdk._shared.secrets_masker import redact - - masked_value = redact(value) - log.exception( - "Exception rendering Jinja template for task '%s', field '%s'. Template: %r", - self.task_id, - attr_name, - masked_value, - ) - raise - else: - setattr(parent, attr_name, rendered_content) - def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]: """ Return mapped nodes that are direct dependencies of the current task. diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/templater.py b/task-sdk/src/airflow/sdk/definitions/_internal/templater.py index f094ccd6b2880..cfe4a6100e482 100644 --- a/task-sdk/src/airflow/sdk/definitions/_internal/templater.py +++ b/task-sdk/src/airflow/sdk/definitions/_internal/templater.py @@ -20,7 +20,7 @@ import datetime import logging import os -from collections.abc import Collection, Iterable, Sequence +from collections.abc import Collection, Iterable, Iterator, Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Any @@ -117,6 +117,48 @@ def _should_render_native(self, dag: DAG | None = None) -> bool: return dag.render_template_as_native_obj if dag else False + def _iter_templated_fields( + self, + parent: Any, + template_fields: Iterable[str], + ) -> Iterator[tuple[str, Any]]: + """ + Iterate over template fields yielding ``(attr_name, value)`` pairs for non-empty fields. + + Fields whose value is falsy are skipped. Objects that do not support + ``__bool__`` (e.g. Pandas DataFrames) are still yielded. + """ + for attr_name in template_fields: + try: + value = getattr(parent, attr_name) + except AttributeError: + raise AttributeError( + f"{attr_name!r} is configured as a template field " + f"but {type(parent).__name__} does not have this attribute." + ) + try: + if not value: + continue + except Exception: + # This may happen if the templated field points to a class which does not support + # ``__bool__``, such as Pandas DataFrames: + # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465 + if hasattr(self, "task_id"): + log.info( + "Unable to check if the value of type '%s' is False for task '%s', field '%s'.", + type(value).__name__, + self.task_id, + attr_name, + ) + else: + log.info( + "Unable to check if the value of type '%s' is False for field '%s'.", + type(value).__name__, + attr_name, + ) + # We may still want to render custom classes which do not support __bool__ + yield attr_name, value + def _do_render_template_fields( self, parent: Any, @@ -125,15 +167,47 @@ def _do_render_template_fields( jinja_env: jinja2.Environment, seen_oids: set[int], ) -> None: - for attr_name in template_fields: - value = getattr(parent, attr_name) - rendered_content = self.render_template( - value, - context, - jinja_env, - seen_oids, - ) - if rendered_content: + """ + Render template fields on *parent* in-place. + + For each non-empty field yielded by :meth:`_iter_templated_fields`, the value is + rendered (or called, when it is callable) and the result is written back via + ``setattr``. Rendering errors are logged with masked values before being re-raised. + + :param parent: The object whose attributes will be templated. + :param template_fields: Names of the attributes to render. + :param context: Context dict with values to apply on content. + :param jinja_env: Jinja2 environment to use for rendering. + :param seen_oids: Set of already-rendered object ids used to prevent infinite + recursion on circular references. + """ + for attr_name, value in self._iter_templated_fields(parent, template_fields): + try: + if callable(value): + rendered_content = value(context=context, jinja_env=jinja_env) + else: + rendered_content = self.render_template(value, context, jinja_env, seen_oids) + except Exception: + # Mask sensitive values in the template before logging + from airflow.sdk._shared.secrets_masker import redact + + masked_value = redact(value) + if hasattr(self, "task_id"): + log.exception( + "Exception rendering Jinja template for task '%s', field '%s'. Template: %r", + self.task_id, + attr_name, + masked_value, + ) + else: + log.exception( + "Exception rendering Jinja template for %s, field '%s'. Template: %r", + type(parent).__name__, + attr_name, + masked_value, + ) + raise + else: setattr(parent, attr_name, rendered_content) def _render(self, template, context, dag=None) -> Any: diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py index abc4c86ed8544..7c0540421d438 100644 --- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py +++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py @@ -226,6 +226,16 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator: task_group = partial_kwargs.pop("task_group") start_date = partial_kwargs.pop("start_date", None) end_date = partial_kwargs.pop("end_date", None) + start_from_trigger = ( + partial_kwargs["start_from_trigger"] + if "start_from_trigger" in partial_kwargs + else getattr(self.operator_class, "start_from_trigger", False) + ) + start_trigger_args = ( + partial_kwargs["start_trigger_args"] + if "start_trigger_args" in partial_kwargs + else getattr(self.operator_class, "start_trigger_args", None) + ) try: operator_name = self.operator_class.custom_operator_name # type: ignore @@ -259,8 +269,8 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator: # to BaseOperator.expand() contribute to operator arguments. expand_input_attr="expand_input", # TODO: Move these to task SDK's BaseOperator and remove getattr - start_trigger_args=getattr(self.operator_class, "start_trigger_args", None), - start_from_trigger=bool(getattr(self.operator_class, "start_from_trigger", False)), + start_trigger_args=start_trigger_args, + start_from_trigger=start_from_trigger, ) return op diff --git a/task-sdk/tests/task_sdk/bases/test_operator.py b/task-sdk/tests/task_sdk/bases/test_operator.py index 9e6db88d5cf14..dcb5240a83dc8 100644 --- a/task-sdk/tests/task_sdk/bases/test_operator.py +++ b/task-sdk/tests/task_sdk/bases/test_operator.py @@ -41,6 +41,7 @@ ) from airflow.sdk.definitions.param import ParamsDict from airflow.sdk.definitions.template import literal +from airflow.triggers.base import StartTriggerArgs DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=timezone.utc) @@ -108,9 +109,18 @@ def __init__(self, arg1: str = "", arg2: str = "", **kwargs): super().__init__(**kwargs) self.arg1 = arg1 self.arg2 = arg2 + if self.start_from_trigger: + self.start_trigger_args = StartTriggerArgs( + trigger_cls="trigger_cls", + next_method="next_method", + trigger_kwargs={"arg1": arg1, "arg2": arg2}, + ) class TestBaseOperator: + def setup_method(self, method): + MockOperator.start_from_trigger = False + # Since we have a custom metaclass, lets double check the behaviour of # passing args in the wrong way (args etc) def test_kwargs_only(self): @@ -800,6 +810,16 @@ def test_jinja_env_creation(self, mock_jinja_env): task.render_template_fields(context={"foo": "whatever", "bar": "whatever"}) assert mock_jinja_env.call_count == 1 + def test_validate_start_from_trigger_kwargs(self): + MockOperator.start_from_trigger = True + + with pytest.raises( + ValueError, + match="MockOperator with task_id 'one' has a callable in trigger kwargs named " + "'arg2', which is not allowed when start_from_trigger is enabled.", + ): + MockOperator(task_id="one", arg1="{{ foo }}", arg2=lambda context, jinja_env: "bar") + def test_params_source(self): # Test bug when copying an operator attached to a Dag with DAG( diff --git a/task-sdk/tests/task_sdk/definitions/_internal/test_templater.py b/task-sdk/tests/task_sdk/definitions/_internal/test_templater.py index bcce3c895470c..fccdfe8664cc1 100644 --- a/task-sdk/tests/task_sdk/definitions/_internal/test_templater.py +++ b/task-sdk/tests/task_sdk/definitions/_internal/test_templater.py @@ -18,6 +18,7 @@ from __future__ import annotations from datetime import datetime, timezone +from unittest.mock import MagicMock, NonCallableMagicMock import jinja2 import pytest @@ -111,6 +112,193 @@ def test_not_render_file_literal_value(self): assert rendered_content == "template_file.txt" + def test_do_render_template_fields_basic(self): + """Test that _do_render_template_fields renders a simple string template field in-place.""" + templater = Templater() + templater.template_ext = [] + + parent = MagicMock(spec=["greeting"]) + parent.greeting = "Hello {{ name }}" + + context = {"name": "world"} + jinja_env = templater.get_template_env() + + templater._do_render_template_fields(parent, ["greeting"], context, jinja_env, set()) + + assert parent.greeting == "Hello world" + + def test_do_render_template_fields_multiple_fields(self): + """Test rendering multiple template fields at once.""" + templater = Templater() + templater.template_ext = [] + + parent = MagicMock(spec=["first", "second"]) + parent.first = "Hello {{ name }}" + parent.second = "Date: {{ ds }}" + + context = {"name": "world", "ds": "2024-01-01"} + jinja_env = templater.get_template_env() + + templater._do_render_template_fields(parent, ["first", "second"], context, jinja_env, set()) + + assert parent.first == "Hello world" + assert parent.second == "Date: 2024-01-01" + + def test_do_render_template_fields_callable_value(self): + """Test that callable field values are called with context and jinja_env.""" + templater = Templater() + templater.template_ext = [] + + callback = MagicMock(spec=lambda context, jinja_env: None, return_value="resolved") + parent = MagicMock(spec=["my_field"]) + parent.my_field = callback + + context = {"key": "value"} + jinja_env = templater.get_template_env() + + templater._do_render_template_fields(parent, ["my_field"], context, jinja_env, set()) + + callback.assert_called_once_with(context=context, jinja_env=jinja_env) + assert parent.my_field == "resolved" + + def test_do_render_template_fields_skips_falsy_values(self): + """Test that falsy field values (empty string, None, 0) are skipped.""" + templater = Templater() + templater.template_ext = [] + + parent = MagicMock(spec=["empty_str", "none_val"]) + parent.empty_str = "" + parent.none_val = None + + context = {"name": "world"} + jinja_env = templater.get_template_env() + + templater._do_render_template_fields(parent, ["empty_str", "none_val"], context, jinja_env, set()) + + # Falsy values should not be touched + assert parent.empty_str == "" + assert parent.none_val is None + + def test_do_render_template_fields_missing_attribute(self): + """Test that a missing attribute on parent raises AttributeError.""" + templater = Templater() + templater.template_ext = [] + + parent = MagicMock(spec=["existing"]) + parent.existing = "value" + + context = {} + jinja_env = templater.get_template_env() + + with pytest.raises( + AttributeError, + match="'nonexistent' is configured as a template field", + ): + templater._do_render_template_fields(parent, ["nonexistent"], context, jinja_env, set()) + + def test_do_render_template_fields_exception_logged_with_task_id(self, caplog): + """Test that rendering errors are logged with task_id when available and re-raised.""" + templater = Templater() + templater.template_ext = [] + templater.task_id = "my_task" + + parent = MagicMock(spec=["bad_field"]) + parent.bad_field = "{{ undefined_var }}" + + context = {} + jinja_env = SandboxedEnvironment(undefined=jinja2.StrictUndefined, cache_size=0) + + with pytest.raises(jinja2.UndefinedError): + templater._do_render_template_fields(parent, ["bad_field"], context, jinja_env, set()) + + assert "Exception rendering Jinja template for task 'my_task', field 'bad_field'" in caplog.text + + def test_do_render_template_fields_exception_logged_without_task_id(self, caplog): + """Test that rendering errors are logged with parent type name when no task_id.""" + templater = Templater() + templater.template_ext = [] + + parent = MagicMock(spec=["bad_field"]) + parent.bad_field = "{{ undefined_var }}" + + context = {} + jinja_env = SandboxedEnvironment(undefined=jinja2.StrictUndefined, cache_size=0) + + with pytest.raises(jinja2.UndefinedError): + templater._do_render_template_fields(parent, ["bad_field"], context, jinja_env, set()) + + assert "Exception rendering Jinja template for MagicMock, field 'bad_field'" in caplog.text + + def test_do_render_template_fields_nested_template_fields(self): + """Test rendering nested objects that have their own template_fields.""" + templater = Templater() + templater.template_ext = [] + + inner = NonCallableMagicMock(spec=["template_fields", "message"]) + inner.template_fields = ["message"] + inner.message = "Hello {{ name }}" + + parent = MagicMock(spec=["nested"]) + parent.nested = inner + + context = {"name": "world"} + jinja_env = templater.get_template_env() + + templater._do_render_template_fields(parent, ["nested"], context, jinja_env, set()) + + assert inner.message == "Hello world" + + def test_do_render_template_fields_seen_oids_prevents_reprocessing(self): + """Test that already-seen objects (by id) are not re-rendered.""" + templater = Templater() + templater.template_ext = [] + + parent = MagicMock(spec=["greeting"]) + parent.greeting = "Hello {{ name }}" + + context = {"name": "world"} + jinja_env = templater.get_template_env() + + # Pre-populate seen_oids with the parent's greeting value id + seen_oids = {id(parent.greeting)} + + templater._do_render_template_fields(parent, ["greeting"], context, jinja_env, seen_oids) + + # The value should NOT be rendered because render_template checks + # `id(value) in seen_oids` and short-circuits, returning the original + # unrendered string. + assert parent.greeting == "Hello {{ name }}" + + def test_do_render_template_fields_renders_dict_values(self): + """Test that dict field values have their inner templates rendered.""" + templater = Templater() + templater.template_ext = [] + + parent = MagicMock(spec=["params"]) + parent.params = {"key": "{{ value }}"} + + context = {"value": "rendered"} + jinja_env = templater.get_template_env() + + templater._do_render_template_fields(parent, ["params"], context, jinja_env, set()) + + assert parent.params == {"key": "rendered"} + + def test_do_render_template_fields_renders_list_values(self): + """Test that list field values have their inner templates rendered.""" + templater = Templater() + templater.template_ext = [] + + parent = MagicMock(spec=["items"]) + parent.items = ["{{ a }}", "{{ b }}"] + + context = {"a": "first", "b": "second"} + jinja_env = templater.get_template_env() + + templater._do_render_template_fields(parent, ["items"], context, jinja_env, set()) + + assert parent.items == ["first", "second"] + @pytest.fixture def env():