Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 4 additions & 21 deletions components/dags/src/pinta_dags/dags/calculate_dem_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pinta_dags.config import AirflowVariable
from pinta_dags.tasks import (
build_job_connection_uri_task,
find_production_area_tile_geometries,
get_database_name,
initialize_dem_tables,
merge_dem_staging_tables,
Expand Down Expand Up @@ -75,26 +76,6 @@ def calculate_dem_diff_dag() -> None:
# Precondition: the production area must already have its job database
# provisioned and database_name set for production area by orchestrator DAG.

@task.docker(**config.PINTA_CONTAINER_TASK_ARGS)
def find_production_area(
connection_uri: str,
production_area_id: str,
) -> list[str]:
import sqlalchemy
import sqlmodel
from geoalchemy2.shape import to_shape
from pinta_db.primary_db.models.management import ProductionArea

engine = sqlalchemy.create_engine(connection_uri)
with sqlmodel.Session(engine) as session:
statement = sqlmodel.select(ProductionArea).where(
ProductionArea.id == production_area_id
)
area_in_db = session.exec(statement).first()
if not area_in_db:
return []
return [to_shape(tile.geom).wkt for tile in area_in_db.tiles]

@task.docker(
**config.PINTA_CONTAINER_TASK_ARGS,
max_active_tis_per_dag=_get_max_parallel_pipelines(),
Expand Down Expand Up @@ -155,7 +136,9 @@ def cluster_diff_polygons(job_connection_uri: str) -> None:
database_name=database_name,
),
)
tile_wkt_list = find_production_area(primary_connection_uri, prod_area_id)
tile_wkt_list = find_production_area_tile_geometries.override(
task_id="find_production_area"
)(primary_connection_uri, prod_area_id)

init_diff_task = initialize_dem_tables.override(
task_id="initialize_diff_tables"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@

from pinta_dags import config

# How often each triggered child DAG is polled for completion.
_TRIGGER_POKE_INTERVAL_SECONDS = 5


def create_calculate_rasters_for_production_area_dag( # noqa: PLR0915
*,
Expand Down Expand Up @@ -47,6 +50,11 @@ def create_calculate_rasters_for_production_area_dag( # noqa: PLR0915
type="boolean",
description="Cluster difference polygons in the DEM diff DAG",
),
"initialize_dem_preview": Param(
default=True,
type="boolean",
description="Initialize DEM preview table",
),
},
is_paused_upon_creation=False,
)
Expand Down Expand Up @@ -105,6 +113,7 @@ def should_run(*, requested: bool) -> bool:
trigger_dag_id=constants.DAG_ID_CALCULATE_REFERENCE_DEM,
conf={"id": "{{ params.id }}"},
wait_for_completion=True,
poke_interval=_TRIGGER_POKE_INTERVAL_SECONDS,
)

trigger_calculate_dem_diff = TriggerDagRunOperator(
Expand All @@ -115,6 +124,15 @@ def should_run(*, requested: bool) -> bool:
"cluster": "{{ params.cluster_diff_polygons }}",
},
wait_for_completion=True,
poke_interval=_TRIGGER_POKE_INTERVAL_SECONDS,
)

trigger_initialize_dem_preview = TriggerDagRunOperator(
task_id="trigger_initialize_dem_preview",
trigger_dag_id=constants.DAG_ID_INITIALIZE_DEM_PREVIEW,
conf={"id": "{{ params.id }}"},
wait_for_completion=True,
poke_interval=_TRIGGER_POKE_INTERVAL_SECONDS,
)

@task.docker(
Expand Down Expand Up @@ -184,8 +202,17 @@ def set_processing_status_failed(
task_id="should_calculate_dem_diff",
trigger_rule=TriggerRule.NONE_FAILED,
)(requested=cast("bool", "{{ params.calculate_dem_diff }}"))

triggers = [trigger_calculate_reference_dem, trigger_calculate_dem_diff]
# The DEM preview copy is independent of the reference DEM -> DEM diff
# chain, so it runs in parallel straight off ensure_database.
preview_gate = should_run.override(
task_id="should_initialize_dem_preview",
)(requested=cast("bool", "{{ params.initialize_dem_preview }}"))

triggers = [
trigger_calculate_reference_dem,
trigger_calculate_dem_diff,
trigger_initialize_dem_preview,
]
status_completed = set_processing_status_completed(
primary_connection_uri, prod_area_id
)
Expand All @@ -197,6 +224,8 @@ def set_processing_status_failed(
# DEM so it can compare against freshly computed reference rasters.
ensure_database >> reference_gate >> trigger_calculate_reference_dem
trigger_calculate_reference_dem >> diff_gate >> trigger_calculate_dem_diff
# Runs in parallel with the reference DEM -> DEM diff chain.
ensure_database >> preview_gate >> trigger_initialize_dem_preview
# Also depend on ensure_database so a failure there (before any trigger
# runs) still resolves the processing status instead of leaving the
# production area stuck in STARTED.
Expand Down
24 changes: 4 additions & 20 deletions components/dags/src/pinta_dags/dags/calculate_reference_dem.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pinta_dags.config import AirflowVariable
from pinta_dags.tasks import (
build_job_connection_uri_task,
find_production_area_tile_paths,
get_database_name,
initialize_dem_tables,
merge_dem_staging_tables,
Expand Down Expand Up @@ -65,25 +66,6 @@ def calculate_reference_dem_dag() -> None:
# Precondition: the production area must already have its job database
# provisioned and database_name set for production area by orchestrator DAG.

@task.docker(**config.PINTA_CONTAINER_TASK_ARGS)
def find_production_area(
connection_uri: str,
production_area_id: str,
) -> list[str]:
import sqlalchemy
import sqlmodel
from pinta_db.primary_db.models.management import ProductionArea

engine = sqlalchemy.create_engine(connection_uri)
with sqlmodel.Session(engine) as session:
statement = sqlmodel.select(ProductionArea).where(
ProductionArea.id == production_area_id
)
area_in_db = session.exec(statement).first()
if not area_in_db:
return []
return [tile.file_path for tile in area_in_db.tiles]

@task.docker(
**config.PINTA_CONTAINER_TASK_ARGS,
max_active_tis_per_dag=_get_max_parallel_pipelines(),
Expand Down Expand Up @@ -133,7 +115,9 @@ def blast2dem( # noqa: PLR0913
database_name = cast(
"str", get_database_name(primary_connection_uri, prod_area_id)
)
file_paths = find_production_area(primary_connection_uri, prod_area_id)
file_paths = find_production_area_tile_paths.override(
task_id="find_production_area"
)(primary_connection_uri, prod_area_id)

job_db_uri = cast(
"str",
Expand Down
157 changes: 157 additions & 0 deletions components/dags/src/pinta_dags/dags/initialize_dem_preview.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# Copyright (c) 2026 National Land Survey of Finland
# (https://www.maanmittauslaitos.fi/en).
# This file is part of the Pinta.
# Licensed under the MIT License; see the repository LICENSE file.

from typing import cast

from airflow.sdk import DAG, Param, Variable, dag, task
from pinta_common import constants
from pinta_db.job_db.models.user import DemPreview
from pinta_db.job_db.schema import Schema
from pinta_db.primary_db.models.dem import Dem as PrimaryDem
from pinta_db.primary_db.schema import Schema as PrimarySchema

from pinta_dags import config
from pinta_dags.config import AirflowVariable
from pinta_dags.tasks import (
build_job_connection_uri_task,
find_production_area_tile_geometries,
get_database_name,
initialize_dem_tables,
merge_dem_staging_tables,
)

FROM_DB_SCHEMA = PrimarySchema.DEM.value
FROM_DB_TABLE = PrimaryDem.__tablename__
TO_DB_SCHEMA = Schema.USER.value
TO_DB_TABLE = DemPreview.__tablename__


def _get_max_parallel_pipelines() -> int:
# Reuses the reference DEM parallelism variable
var = AirflowVariable.CALCULATE_REFERENCE_DEM_MAX_PARALLEL_PIPELINES
max_parallel = int(Variable.get(var, 4))
if max_parallel < 1:
msg = f"{var} must be at least 1"
raise ValueError(msg)
return max_parallel


def _get_staging_tables() -> int:
# Reuses the reference DEM staging tables variable
var = AirflowVariable.CALCULATE_REFERENCE_DEM_STAGING_TABLES
staging_tables = int(Variable.get(var, 1))
if staging_tables < 0:
msg = f"{var} must be at least 0"
raise ValueError(msg)
return staging_tables


def create_initialize_dem_preview_dag(
*,
dag_id: str,
) -> DAG:
@dag(
dag_id=dag_id,
tags=[dag_id],
dag_display_name="Initialize DEM preview",
schedule=None,
params={
"id": Param(
"",
type="string",
format="uuid",
description=("Production area id as UUID"),
)
},
is_paused_upon_creation=False,
)
def initialize_dem_preview_dag() -> None:
# Precondition: the production area must already have its job database
# provisioned and database_name set for production area by orchestrator DAG.

@task.docker(
**config.PINTA_CONTAINER_TASK_ARGS,
max_active_tis_per_dag=_get_max_parallel_pipelines(),
)
def copy_dem_preview( # noqa: PLR0913
primary_connection_uri: str,
job_connection_uri: str,
tile_wkt: str,
staging_tables: int,
from_schema: str,
from_table: str,
to_schema: str,
to_table: str,
) -> None:
import sqlalchemy
import sqlmodel
from pinta_processing import pipelines

with (
sqlmodel.Session(
sqlalchemy.create_engine(primary_connection_uri)
) as primary_session,
sqlmodel.Session(
sqlalchemy.create_engine(job_connection_uri)
) as job_session,
):
pipeline = pipelines.postgis_to_postgis(
from_session=primary_session,
from_schema=from_schema,
from_table=from_table,
to_session=job_session,
to_schema=to_schema,
to_table=to_table,
tile_wkt=tile_wkt,
staging_tables=staging_tables,
)
pipeline.execute()

primary_connection_uri = config.connection_uri_template("pinta_processing_db")
job_connection_uri = config.connection_uri_template("pinta_job_db")
staging_tables = _get_staging_tables()

prod_area_id = "{{ params.id }}"
database_name = cast(
"str", get_database_name(primary_connection_uri, prod_area_id)
)
job_db_uri = cast(
"str",
build_job_connection_uri_task(
base_uri=job_connection_uri,
database_name=database_name,
),
)
tile_wkt_list = find_production_area_tile_geometries.override(
task_id="find_production_area"
)(primary_connection_uri, prod_area_id)

initialize_task = initialize_dem_tables(
job_db_uri, TO_DB_SCHEMA, TO_DB_TABLE, staging_tables
)
copied_tiles = copy_dem_preview.partial(
primary_connection_uri=primary_connection_uri,
job_connection_uri=job_db_uri,
staging_tables=staging_tables,
from_schema=FROM_DB_SCHEMA,
from_table=FROM_DB_TABLE,
to_schema=TO_DB_SCHEMA,
to_table=TO_DB_TABLE,
).expand(tile_wkt=tile_wkt_list)
(
tile_wkt_list
>> initialize_task
>> copied_tiles
>> merge_dem_staging_tables(
job_db_uri, TO_DB_SCHEMA, TO_DB_TABLE, staging_tables
)
)

return initialize_dem_preview_dag()


DAG_ID = constants.DAG_ID_INITIALIZE_DEM_PREVIEW

globals()[DAG_ID] = create_initialize_dem_preview_dag(dag_id=DAG_ID)
45 changes: 45 additions & 0 deletions components/dags/src/pinta_dags/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,51 @@ def get_database_name(
return area_in_db.database_name


@task.docker(**config.PINTA_CONTAINER_TASK_ARGS)
def find_production_area_tile_paths(
connection_uri: str,
production_area_id: str,
) -> list[str]:
"""Return the source file paths of the production area's point cloud tiles."""
import sqlalchemy
import sqlmodel
from pinta_db.primary_db.models.management import ProductionArea

engine = sqlalchemy.create_engine(connection_uri)
with sqlmodel.Session(engine) as session:
area_in_db = session.exec(
sqlmodel.select(ProductionArea).where(
ProductionArea.id == production_area_id
)
).first()
if not area_in_db:
return []
return [tile.file_path for tile in area_in_db.tiles]


@task.docker(**config.PINTA_CONTAINER_TASK_ARGS)
def find_production_area_tile_geometries(
connection_uri: str,
production_area_id: str,
) -> list[str]:
"""Return the geometries (as WKT) of the production area's point cloud tiles."""
import sqlalchemy
import sqlmodel
from geoalchemy2.shape import to_shape
from pinta_db.primary_db.models.management import ProductionArea

engine = sqlalchemy.create_engine(connection_uri)
with sqlmodel.Session(engine) as session:
area_in_db = session.exec(
sqlmodel.select(ProductionArea).where(
ProductionArea.id == production_area_id
)
).first()
if not area_in_db:
return []
return [to_shape(tile.geom).wkt for tile in area_in_db.tiles]


@task
def build_job_connection_uri_task(
base_uri: str,
Expand Down
Loading
Loading