Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 62 additions & 14 deletions data_rentgen/db/repositories/job_dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Literal

from sqlalchemy import ARRAY, Integer, any_, bindparam, cast, func, or_, select, tuple_
from sqlalchemy import ARRAY, Integer, any_, bindparam, cast, func, literal, select, tuple_
from sqlalchemy.orm import aliased

from data_rentgen.db.models.job_dependency import JobDependency
from data_rentgen.db.repositories.base import Repository
Expand All @@ -26,6 +27,59 @@
JobDependency.to_job_id == bindparam("to_job_id"),
)

# Depth-limited dependency traversal, expressed as recursive CTEs.
#
# Bind parameters shared by every query below:
#   job_ids - array of job ids the traversal starts from.
#   depth   - maximum number of hops to follow. The anchor part always emits
#             rows with depth=1, so any value below 1 behaves exactly like 1.

# --- UPSTREAM: follow edges backwards (to_job_id -> from_job_id) -------------
# Anchor: every edge pointing INTO one of the start jobs.
upstream_jobs_query_base_part = (
    select(
        JobDependency,
        literal(1).label("depth"),
    )
    .select_from(JobDependency)
    .where(JobDependency.to_job_id == any_(bindparam("job_ids")))
)
upstream_jobs_query_cte = upstream_jobs_query_base_part.cte(name="upstream_jobs_query", recursive=True)

# Recursive step: edges pointing into the source job of an already collected
# edge, as long as the depth limit has not been reached yet.
upstream_jobs_query_recursive_part = (
    select(
        JobDependency,
        (upstream_jobs_query_cte.c.depth + 1).label("depth"),
    )
    .select_from(JobDependency)
    .where(
        upstream_jobs_query_cte.c.depth < bindparam("depth"),
        JobDependency.to_job_id == upstream_jobs_query_cte.c.from_job_id,
    )
)


upstream_jobs_query_cte = upstream_jobs_query_cte.union(upstream_jobs_query_recursive_part)
# "depth" is part of every CTE row, so the UNION above only de-duplicates
# (edge, depth) pairs: an edge reachable through paths of different length
# (or through a cycle) shows up once per depth. The outer select drops
# "depth", hence DISTINCT is required to return each edge exactly once.
upstream_entities_query = select(aliased(JobDependency, upstream_jobs_query_cte)).distinct()

# --- DOWNSTREAM: follow edges forwards (from_job_id -> to_job_id) ------------
# Anchor: every edge going OUT of one of the start jobs.
downstream_jobs_query_base_part = (
    select(
        JobDependency,
        literal(1).label("depth"),
    )
    .select_from(JobDependency)
    .where(JobDependency.from_job_id == any_(bindparam("job_ids")))
)
downstream_jobs_query_cte = downstream_jobs_query_base_part.cte(name="downstream_jobs_query", recursive=True)

# Recursive step: edges going out of the target job of an already collected
# edge, as long as the depth limit has not been reached yet.
downstream_jobs_query_recursive_part = (
    select(
        JobDependency,
        (downstream_jobs_query_cte.c.depth + 1).label("depth"),
    )
    .select_from(JobDependency)
    .where(
        downstream_jobs_query_cte.c.depth < bindparam("depth"),
        JobDependency.from_job_id == downstream_jobs_query_cte.c.to_job_id,
    )
)

downstream_jobs_query_cte = downstream_jobs_query_cte.union(downstream_jobs_query_recursive_part)
# Same reasoning as for the upstream query: collapse edges found at several depths.
downstream_entities_query = select(aliased(JobDependency, downstream_jobs_query_cte)).distinct()

# --- BOTH: edges found in either direction -----------------------------------
# UNION removes edges that were discovered by both traversals.
both_entities_query = select(aliased(JobDependency, (upstream_entities_query.union(downstream_entities_query)).cte()))


class JobDependencyRepository(Repository[JobDependency]):
async def fetch_bulk(
Expand Down Expand Up @@ -60,25 +114,19 @@ async def get_dependencies(
self,
job_ids: list[int],
direction: Literal["UPSTREAM", "DOWNSTREAM", "BOTH"],
depth: int,
) -> list[JobDependency]:

job_dependency_query = select(JobDependency)
match direction:
case "UPSTREAM":
job_dependency_query = job_dependency_query.where(JobDependency.to_job_id == any_(bindparam("job_ids")))
query = upstream_entities_query
case "DOWNSTREAM":
job_dependency_query = job_dependency_query.where(
JobDependency.from_job_id == any_(bindparam("job_ids"))
)
query = downstream_entities_query
case "BOTH":
job_dependency_query = job_dependency_query.where(
or_(
JobDependency.from_job_id == any_(bindparam("job_ids")),
JobDependency.to_job_id == any_(bindparam("job_ids")),
)
)
scalars = await self._session.scalars(job_dependency_query, {"job_ids": job_ids})
return list(scalars.all())
query = both_entities_query

result = await self._session.scalars(query, {"job_ids": job_ids, "depth": depth})
return list(result.all())

async def _get(self, job_dependency: JobDependencyDTO) -> JobDependency | None:
return await self._session.scalar(
Expand Down
6 changes: 5 additions & 1 deletion data_rentgen/server/api/v1/router/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,11 @@ async def get_job_dependencies(
job_service: Annotated[JobService, Depends()],
current_user: Annotated[User, Depends(get_user())],
) -> JobDependenciesResponseV1:
job_dependencies = await job_service.get_job_dependencies(query_args.start_node_id, query_args.direction)
job_dependencies = await job_service.get_job_dependencies(
start_node_id=query_args.start_node_id,
direction=query_args.direction,
depth=query_args.depth,
)
return JobDependenciesResponseV1(
relations=JobDependenciesRelationsV1(
parents=[
Expand Down
1 change: 1 addition & 0 deletions data_rentgen/server/schemas/v1/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,5 @@ class JobDependenciesQueryV1(BaseModel):
description="Direction of the lineage",
examples=["DOWNSTREAM", "UPSTREAM", "BOTH"],
)
depth: int = Field(description="Depth of dependencies between jobs", default=1)
model_config = ConfigDict(extra="ignore")
16 changes: 13 additions & 3 deletions data_rentgen/server/services/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,15 @@ async def get_job_dependencies(
self,
start_node_id: int,
direction: Literal["UPSTREAM", "DOWNSTREAM", "BOTH"],
depth: int,
) -> JobDependenciesResult:
logger.info("Get Job dependencies with start at job with id %s and direction: %s", start_node_id, direction)
logger.info(
"Get Job dependencies with start at job with id %s and next params: direction: %s, depth: %s",
start_node_id,
direction,
depth,
)
job_ids = {start_node_id}

ancestor_relations = await self._uow.job.list_ancestor_relations([start_node_id])
descendant_relations = await self._uow.job.list_descendant_relations([start_node_id])
Expand All @@ -122,10 +129,13 @@ async def get_job_dependencies(
| {p.child_job_id for p in descendant_relations}
)

dependencies = await self._uow.job_dependency.get_dependencies(job_ids=list(job_ids), direction=direction)
dependencies = await self._uow.job_dependency.get_dependencies(
job_ids=list(job_ids),
direction=direction,
depth=depth,
)
dependency_job_ids = {d.from_job_id for d in dependencies} | {d.to_job_id for d in dependencies}
job_ids |= dependency_job_ids

# return ancestors for all found jobs in the graph
ancestor_relations += await self._uow.job.list_ancestor_relations(list(dependency_job_ids))
job_ids |= {p.parent_job_id for p in ancestor_relations}
Expand Down
1 change: 1 addition & 0 deletions docs/changelog/next_release/412.improvement.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add ``depth`` query parameter to ``GET /v1/jobs/dependencies`` endpoint, allowing control over how many layers of dependency are traversed. Defaults to ``1``.
45 changes: 45 additions & 0 deletions tests/test_server/fixtures/factories/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,51 @@ async def jobs_with_same_parent_job(
await clean_db(async_session)


@pytest_asyncio.fixture
async def job_dependency_depth_chain(
    async_session_maker: Callable[[], AbstractAsyncContextManager[AsyncSession]],
) -> AsyncGenerator[list[Job], None]:
    """
    Create five jobs connected into one straight dependency line.

    The jobs are named ``depth-chain-job-1`` .. ``depth-chain-job-5`` and are
    yielded in creation order, so ``jobs[0]`` is the head of the chain:

        jobs[0] → jobs[1] → jobs[2] → jobs[3] → jobs[4]

    Every arrow is stored as a JobDependency row of type "DIRECT_DEPENDENCY".
    The database is cleaned after the test has finished.
    """
    async with async_session_maker() as async_session:
        location = await create_location(async_session)
        job_type = await create_job_type(async_session)

        # Jobs are created one after another, in chain order.
        jobs = [
            await create_job(
                async_session,
                location_id=location.id,
                job_type_id=job_type.id,
                job_kwargs={"name": f"depth-chain-job-{number}"},
            )
            for number in range(1, 6)
        ]

        # Link each job to its immediate successor.
        edges = [
            JobDependency(
                from_job_id=source.id,
                to_job_id=target.id,
                type="DIRECT_DEPENDENCY",
            )
            for source, target in zip(jobs, jobs[1:])
        ]
        async_session.add_all(edges)
        await async_session.commit()
        async_session.expunge_all()

    yield jobs

    async with async_session_maker() as async_session:
        await clean_db(async_session)


@pytest_asyncio.fixture
async def job_dependency_chain(
async_session_maker: Callable[[], AbstractAsyncContextManager[AsyncSession]],
Expand Down
102 changes: 102 additions & 0 deletions tests/test_server/test_jobs/test_job_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,105 @@ async def test_get_job_dependencies_with_direction_downstream(
},
"nodes": {"jobs": jobs_to_json(expected_nodes)},
}


@pytest.mark.parametrize(
    ["depth", "direction", "expected_dep_indices", "expected_job_indices"],
    [
        (1, "DOWNSTREAM", [(2, 3)], [2, 3]),
        (2, "DOWNSTREAM", [(2, 3), (3, 4)], [2, 3, 4]),
        (1, "UPSTREAM", [(1, 2)], [1, 2]),
        (2, "UPSTREAM", [(0, 1), (1, 2)], [0, 1, 2]),
        (1, "BOTH", [(1, 2), (2, 3)], [1, 2, 3]),
        (2, "BOTH", [(0, 1), (1, 2), (2, 3), (3, 4)], [0, 1, 2, 3, 4]),
        (5, "BOTH", [(0, 1), (1, 2), (2, 3), (3, 4)], [0, 1, 2, 3, 4]),
    ],
    ids=[
        "depth_1-downstream",
        "depth_2-downstream",
        "depth_1-upstream",
        "depth_2-upstream",
        "depth_1-both",
        "depth_2-both",
        "depth_5-both",
    ],
)
async def test_get_job_dependencies_with_depth(
    test_client: AsyncClient,
    # The fixture yields a list (see its declared return type), not a tuple.
    job_dependency_depth_chain: list[Job],
    async_session: AsyncSession,
    mocked_user: MockedUser,
    depth: int,
    direction: str,
    expected_dep_indices: list[tuple[int, int]],
    expected_job_indices: list[int],
):
    """
    Check that ``depth`` limits how far dependencies are followed.

    Fixture chain: job_0 → job_1 → job_2 → job_3 → job_4
    Start node is always job_2 (middle of the chain).

    ``expected_dep_indices`` holds (from, to) index pairs of the edges that
    must be returned, ``expected_job_indices`` the indices of the jobs that
    must be present in ``nodes``.
    """
    jobs = job_dependency_depth_chain
    start_job = jobs[2]

    expected_jobs = await enrich_jobs([jobs[i] for i in expected_job_indices], async_session)
    # The response is compared as a list, so the expected edges are given in index order.
    expected_dependencies = [
        {
            "from": {"kind": "JOB", "id": str(jobs[i].id)},
            "to": {"kind": "JOB", "id": str(jobs[j].id)},
            "type": "DIRECT_DEPENDENCY",
        }
        for i, j in sorted(expected_dep_indices)
    ]

    response = await test_client.get(
        "v1/jobs/dependencies",
        headers={"Authorization": f"Bearer {mocked_user.access_token}"},
        params={"start_node_id": start_job.id, "depth": depth, "direction": direction},
    )
    assert response.status_code == HTTPStatus.OK, response.json()
    assert response.json() == {
        "relations": {
            "parents": [],
            "dependencies": expected_dependencies,
        },
        "nodes": {"jobs": jobs_to_json(expected_jobs)},
    }


@pytest.mark.parametrize(
    ["direction", "start_node_index"],
    [
        ("UPSTREAM", 0),
        ("DOWNSTREAM", 4),
    ],
    ids=["upstream_boundary", "downstream_boundary"],
)
async def test_get_job_dependencies_with_depth_on_boundary(
    test_client: AsyncClient,
    # The fixture yields a list (see its declared return type), not a tuple.
    job_dependency_depth_chain: list[Job],
    async_session: AsyncSession,
    mocked_user: MockedUser,
    direction: str,
    start_node_index: int,
):
    """
    Check traversal from the end of the chain that has nothing further to follow.

    Fixture chain: job_0 → job_1 → job_2 → job_3 → job_4
    Start node is job_0 (looking upstream) or job_4 (looking downstream), so
    no dependencies are expected and only the start job itself is returned.
    """
    jobs = job_dependency_depth_chain
    start_job = jobs[start_node_index]

    [expected_job] = await enrich_jobs([start_job], async_session)

    response = await test_client.get(
        "v1/jobs/dependencies",
        headers={"Authorization": f"Bearer {mocked_user.access_token}"},
        params={"start_node_id": start_job.id, "depth": 2, "direction": direction},
    )
    assert response.status_code == HTTPStatus.OK, response.json()
    assert response.json() == {
        "relations": {
            "parents": [],
            "dependencies": [],
        },
        "nodes": {"jobs": jobs_to_json([expected_job])},
    }
Loading