Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions components/dags/src/pinta_dags/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ class AirflowVariable(enum.StrEnum):
)
CALCULATE_DEM_DIFF_STAGING_TABLES = "pinta_calculate_dem_diff_staging_tables"

# Maximum number of update area dissolve pipelines running in parallel.
DISSOLVE_UPDATE_AREAS_MAX_PARALLEL_PIPELINES = (
"pinta_dissolve_update_areas_max_parallel_pipelines"
)


def connection_uri_template(conn_id: str) -> str:
"""Jinja template for a connection's SQLAlchemy URI with the psycopg3 driver."""
Expand Down
142 changes: 142 additions & 0 deletions components/dags/src/pinta_dags/dags/dissolve_update_areas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Copyright (c) 2026 National Land Survey of Finland
# (https://www.maanmittauslaitos.fi/en).
# This file is part of the Pinta.
# Licensed under the MIT License; see the repository LICENSE file.

import datetime
from typing import cast

from airflow.sdk import DAG, Param, Variable, dag, task
from pinta_common import constants

from pinta_dags import config
from pinta_dags.config import AirflowVariable
from pinta_dags.tasks import (
build_job_connection_uri_task,
find_update_area_geometries,
get_database_name,
set_processing_status_completed,
set_processing_status_failed,
set_processing_status_started,
)


def _get_max_parallel_pipelines() -> int:
    """Return the configured cap on concurrently running dissolve pipelines.

    The value is read from the Airflow variable named by
    ``AirflowVariable.DISSOLVE_UPDATE_AREAS_MAX_PARALLEL_PIPELINES``; ``4`` is
    passed to ``Variable.get`` as the fallback value.

    Returns:
        The maximum number of parallel pipelines, always at least 1.

    Raises:
        ValueError: If the configured value is not an integer or is below 1.
    """
    var = AirflowVariable.DISSOLVE_UPDATE_AREAS_MAX_PARALLEL_PIPELINES
    raw_value = Variable.get(var, 4)
    try:
        max_parallel = int(raw_value)
    except ValueError as err:
        # A bare int() error does not say which variable is misconfigured;
        # name it so the failure can be fixed from the message alone.
        msg = f"{var} must be an integer, got {raw_value!r}"
        raise ValueError(msg) from err
    if max_parallel < 1:
        msg = f"{var} must be at least 1"
        raise ValueError(msg)
    return max_parallel


def create_dissolve_update_areas_dag(
    *,
    dag_id: str,
) -> DAG:
    """Build the DAG that dissolves the update areas of one production area.

    The DAG receives the production area id through its ``id`` param, stamps
    the area as started, resolves the job database URI, runs one dissolve
    pipeline per update area geometry and finally records either a completed
    or a failed processing status.
    """
    production_area_param = Param(
        "",
        type="string",
        format="uuid",
        description="Production area id as UUID",
    )

    @dag(
        dag_id=dag_id,
        tags=[dag_id],
        dag_display_name="Dissolve update areas",
        schedule=None,
        params={"id": production_area_param},
        is_paused_upon_creation=False,
    )
    def dissolve_update_areas_dag() -> None:
        # Precondition: the orchestrator DAG has already provisioned the job
        # database for the production area and stored its database_name.

        primary_uri = config.connection_uri_template("pinta_processing_db")
        job_base_uri = config.connection_uri_template("pinta_job_db")
        prod_area_id = "{{ params.id }}"

        parallel_cap = _get_max_parallel_pipelines()

        @task.docker(
            **config.PINTA_CONTAINER_TASK_ARGS,
            max_active_tis_per_dag=parallel_cap,
            # Tasks running in parallel may merge into the same base/overview
            # tiles and deadlock on concurrent row updates; retrying lets the
            # losing task go through once the winner has finished.
            retries=3,
            retry_delay=datetime.timedelta(seconds=10),
        )
        def dissolve_update_area(
            primary_connection_uri: str,
            job_connection_uri: str,
            geom_wkt: str,
        ) -> None:
            # Runs the dissolve pipeline for a single update area geometry
            # against the primary and job databases.
            import sqlalchemy
            import sqlmodel
            from pinta_processing import pipelines

            primary_engine = sqlalchemy.create_engine(primary_connection_uri)
            job_engine = sqlalchemy.create_engine(job_connection_uri)
            with (
                sqlmodel.Session(primary_engine) as primary_session,
                sqlmodel.Session(job_engine) as job_session,
            ):
                pipelines.dissolve_update_area(
                    primary_session=primary_session,
                    job_session=job_session,
                    geom_wkt=geom_wkt,
                ).execute()

        started = set_processing_status_started(primary_uri, prod_area_id)
        database_name = cast(
            "str",
            get_database_name(primary_uri, prod_area_id),
        )
        job_db_uri = cast(
            "str",
            build_job_connection_uri_task(
                base_uri=job_base_uri,
                database_name=database_name,
            ),
        )
        geom_wkt_list = find_update_area_geometries(job_db_uri)

        per_area_dissolve = dissolve_update_area.partial(
            primary_connection_uri=primary_uri,
            job_connection_uri=job_db_uri,
        )
        dissolved_areas = per_area_dissolve.expand(geom_wkt=geom_wkt_list)

        completed = set_processing_status_completed(primary_uri, prod_area_id)
        failed = set_processing_status_failed(primary_uri, prod_area_id)

        # STARTED is stamped before any work; the dissolve chain follows.
        started >> database_name
        geom_wkt_list >> dissolved_areas

        # Both final-status tasks hang directly off every step that can fail.
        # That way ONE_FAILED still fires when an early step fails and the
        # mapped task never runs, while NONE_FAILED marks COMPLETED otherwise.
        upstream_steps = [
            started,
            database_name,
            job_db_uri,
            geom_wkt_list,
            dissolved_areas,
        ]
        for final_status in (completed, failed):
            upstream_steps >> final_status

    return dissolve_update_areas_dag()


# Id of the dissolve update areas DAG, taken from pinta_common.constants.
DAG_ID = constants.DAG_ID_DISSOLVE_UPDATE_AREAS

# Build the DAG and bind it to a module-level name equal to its id.
# NOTE(review): presumably this is what lets Airflow's DAG discovery find the
# DAG object in this module — confirm.
globals()[DAG_ID] = create_dissolve_update_areas_dag(dag_id=DAG_ID)
83 changes: 82 additions & 1 deletion components/dags/src/pinta_dags/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

"""Airflow tasks shared across Pinta DAGs."""

from airflow.sdk import task
from airflow.sdk import TriggerRule, task

from pinta_dags import config

Expand Down Expand Up @@ -78,6 +78,22 @@ def find_production_area_tile_geometries(
return [to_shape(tile.geom).wkt for tile in area_in_db.tiles]


@task.docker(**config.PINTA_CONTAINER_TASK_ARGS)
def find_update_area_geometries(
    connection_uri: str,
) -> list[str]:
    """Collect the geometry of every update area in the job database as WKT."""
    import sqlalchemy
    import sqlmodel
    from geoalchemy2.shape import to_shape
    from pinta_db.job_db.models.user import UpdateArea

    query = sqlmodel.select(UpdateArea)
    with sqlmodel.Session(sqlalchemy.create_engine(connection_uri)) as session:
        rows = session.exec(query).all()
        # Convert while the session is still open, as the rows belong to it.
        shapes = (to_shape(row.geom) for row in rows)
        return [shape.wkt for shape in shapes]


@task
def build_job_connection_uri_task(
base_uri: str,
Expand Down Expand Up @@ -142,3 +158,68 @@ def merge_dem_staging_tables(
staging_tables=staging_tables,
session=session,
)


@task.docker(**config.PINTA_CONTAINER_TASK_ARGS)
def set_processing_status_started(connection_uri: str, production_area_id: str) -> None:
    """Mark the production area as processing started.

    Looks up the production area by id and sets its ``processing_status`` to
    ``ProcessingStatus.STARTED``. If no such production area exists, nothing
    is written and a warning is logged; the task does not raise.

    Args:
        connection_uri: SQLAlchemy URI of the database holding production areas.
        production_area_id: Id of the production area to update.
    """
    # Imports are function-local, like in the other container tasks of this file.
    import logging

    import sqlalchemy
    import sqlmodel
    from pinta_db.primary_db.models.management import ProcessingStatus, ProductionArea

    engine = sqlalchemy.create_engine(connection_uri)
    with sqlmodel.Session(engine) as session:
        area_in_db = session.exec(
            sqlmodel.select(ProductionArea).where(
                ProductionArea.id == production_area_id
            )
        ).first()
        if not area_in_db:
            # Without this the task would succeed silently while the status
            # stays untouched, which is hard to diagnose afterwards.
            logging.getLogger(__name__).warning(
                "Production area %s not found; processing status not set to %s",
                production_area_id,
                ProcessingStatus.STARTED,
            )
            return
        area_in_db.processing_status = ProcessingStatus.STARTED
        session.commit()


@task.docker(
    **config.PINTA_CONTAINER_TASK_ARGS,
    trigger_rule=TriggerRule.NONE_FAILED,
)
def set_processing_status_completed(
    connection_uri: str, production_area_id: str
) -> None:
    """Mark the production area as processing completed when nothing failed.

    The task is declared with ``TriggerRule.NONE_FAILED``. It looks up the
    production area by id and sets its ``processing_status`` to
    ``ProcessingStatus.COMPLETED``. If no such production area exists, nothing
    is written and a warning is logged; the task does not raise.

    Args:
        connection_uri: SQLAlchemy URI of the database holding production areas.
        production_area_id: Id of the production area to update.
    """
    # Imports are function-local, like in the other container tasks of this file.
    import logging

    import sqlalchemy
    import sqlmodel
    from pinta_db.primary_db.models.management import ProcessingStatus, ProductionArea

    engine = sqlalchemy.create_engine(connection_uri)
    with sqlmodel.Session(engine) as session:
        area_in_db = session.exec(
            sqlmodel.select(ProductionArea).where(
                ProductionArea.id == production_area_id
            )
        ).first()
        if not area_in_db:
            # Without this the task would succeed silently while the status
            # stays untouched, which is hard to diagnose afterwards.
            logging.getLogger(__name__).warning(
                "Production area %s not found; processing status not set to %s",
                production_area_id,
                ProcessingStatus.COMPLETED,
            )
            return
        area_in_db.processing_status = ProcessingStatus.COMPLETED
        session.commit()


@task.docker(
    **config.PINTA_CONTAINER_TASK_ARGS,
    trigger_rule=TriggerRule.ONE_FAILED,
)
def set_processing_status_failed(connection_uri: str, production_area_id: str) -> None:
    """Mark the production area as processing failed when any upstream failed.

    The task is declared with ``TriggerRule.ONE_FAILED``. It looks up the
    production area by id and sets its ``processing_status`` to
    ``ProcessingStatus.FAILURE``. If no such production area exists, nothing
    is written and a warning is logged; the task does not raise.

    Args:
        connection_uri: SQLAlchemy URI of the database holding production areas.
        production_area_id: Id of the production area to update.
    """
    # Imports are function-local, like in the other container tasks of this file.
    import logging

    import sqlalchemy
    import sqlmodel
    from pinta_db.primary_db.models.management import ProcessingStatus, ProductionArea

    engine = sqlalchemy.create_engine(connection_uri)
    with sqlmodel.Session(engine) as session:
        area_in_db = session.exec(
            sqlmodel.select(ProductionArea).where(
                ProductionArea.id == production_area_id
            )
        ).first()
        if not area_in_db:
            # Without this the failure would never be recorded and the task
            # would still report success, which is hard to diagnose afterwards.
            logging.getLogger(__name__).warning(
                "Production area %s not found; processing status not set to %s",
                production_area_id,
                ProcessingStatus.FAILURE,
            )
            return
        area_in_db.processing_status = ProcessingStatus.FAILURE
        session.commit()
Loading
Loading