From c588016243be6e56e761da0903b2b19cdf6b7656 Mon Sep 17 00:00:00 2001 From: Ayu Date: Sun, 29 Mar 2026 23:07:16 +0800 Subject: [PATCH 1/3] Implement worker task execution flow --- src/apps/api/routers/runs.py | 15 +- src/apps/worker/builtin_agents.py | 26 +++ src/apps/worker/registry.py | 16 ++ src/apps/worker/service.py | 180 ++++++++++++++++++++ src/tests/test_worker_execution.py | 256 +++++++++++++++++++++++++++++ 5 files changed, 492 insertions(+), 1 deletion(-) create mode 100644 src/apps/worker/builtin_agents.py create mode 100644 src/apps/worker/registry.py create mode 100644 src/apps/worker/service.py create mode 100644 src/tests/test_worker_execution.py diff --git a/src/apps/api/routers/runs.py b/src/apps/api/routers/runs.py index f11dfa1..3305040 100644 --- a/src/apps/api/routers/runs.py +++ b/src/apps/api/routers/runs.py @@ -1,5 +1,18 @@ from __future__ import annotations -from fastapi import APIRouter +from fastapi import APIRouter, Depends, HTTPException, status +from sqlalchemy.orm import Session + +from src.apps.api.deps import get_db +from src.packages.core.db.models import ExecutionRunORM +from src.packages.core.schemas import ExecutionRunRead router = APIRouter(prefix="/runs", tags=["runs"]) + + +@router.get("/{run_id}", response_model=ExecutionRunRead) +def get_run(run_id: str, db: Session = Depends(get_db)) -> ExecutionRunRead: + run = db.get(ExecutionRunORM, run_id) + if run is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Execution run not found") + return ExecutionRunRead.model_validate(run) diff --git a/src/apps/worker/builtin_agents.py b/src/apps/worker/builtin_agents.py new file mode 100644 index 0000000..0e1f27d --- /dev/null +++ b/src/apps/worker/builtin_agents.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from typing import Any, Protocol + +from src.packages.core.db.models import TaskORM + + +class WorkerAgent(Protocol): + def run(self, task: TaskORM, context: dict[str, Any]) -> dict[str, Any]: + ... + + +class EchoWorkerAgent: + def run(self, task: TaskORM, context: dict[str, Any]) -> dict[str, Any]: + return { + "status": "ok", + "task_id": task.id, + "task_type": task.task_type, + "echo": task.input_payload, + "context": context, + } + + +class FailingWorkerAgent: + def run(self, task: TaskORM, context: dict[str, Any]) -> dict[str, Any]: + raise RuntimeError(f"Agent execution failed for task {task.id}") diff --git a/src/apps/worker/registry.py b/src/apps/worker/registry.py new file mode 100644 index 0000000..831cb30 --- /dev/null +++ b/src/apps/worker/registry.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from src.apps.worker.builtin_agents import EchoWorkerAgent, FailingWorkerAgent, WorkerAgent + + +def get_worker_agent(role_name: str) -> WorkerAgent: + agents: dict[str, WorkerAgent] = { + "default_worker": EchoWorkerAgent(), + "echo_worker": EchoWorkerAgent(), + "failing_worker": FailingWorkerAgent(), + } + + if role_name not in agents: + raise KeyError(f"No worker agent registered for role {role_name}") + + return agents[role_name] diff --git a/src/apps/worker/service.py b/src/apps/worker/service.py new file mode 100644 index 0000000..4e26438 --- /dev/null +++ b/src/apps/worker/service.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any + +from sqlalchemy import Select, select +from sqlalchemy.orm import Session + +from src.apps.worker.registry import get_worker_agent +from src.packages.core.db.models import ( + AgentRoleORM, + AssignmentORM, + EventLogORM, + ExecutionRunORM, + TaskORM, +) +from src.packages.core.task_state_machine import transition_task_status + + +class WorkerService: + def __init__(self, db: Session): + self.db = db + + def _emit_execution_event( + self, + *, + event_type: str, + task: TaskORM, + run: ExecutionRunORM, + agent_role: AgentRoleORM, + message: str, + ) -> None: + self.db.add( + EventLogORM( + batch_id=task.batch_id, + task_id=task.id, + run_id=run.id, + event_type=event_type, + event_status=run.run_status, + message=message, + payload={ + "task_id": task.id, + "run_id": run.id, + "agent_role_id": agent_role.id, + "role_name": agent_role.role_name, + }, + ) + ) + + def _active_assignment_query(self, task_id: str) -> Select[tuple[AssignmentORM]]: + return ( + select(AssignmentORM) + .where( + AssignmentORM.task_id == task_id, + AssignmentORM.assignment_status == "active", + ) + .order_by(AssignmentORM.assigned_at.desc()) + .limit(1) + ) + + def claim_next_task(self) -> TaskORM | None: + task = self.db.scalar( + select(TaskORM) + .where(TaskORM.status == "queued") + .order_by(TaskORM.created_at.asc()) + .with_for_update(skip_locked=True) + .limit(1) + ) + if task is None: + return None + + transition_task_status( + self.db, + task, + to_status="running", + reason="Worker claimed queued task", + source="worker", + ) + self.db.flush() + return task + + def execute_task(self, task: TaskORM) -> ExecutionRunORM: + assignment = self.db.scalar(self._active_assignment_query(task.id)) + if assignment is None: + raise RuntimeError(f"No active assignment found for task {task.id}") + + agent_role = self.db.get(AgentRoleORM, assignment.agent_role_id) + if agent_role is None: + raise RuntimeError(f"Assigned agent role {assignment.agent_role_id} not found") + + started_at = datetime.now(timezone.utc) + + run = ExecutionRunORM( + task_id=task.id, + agent_role_id=agent_role.id, + run_status="running", + started_at=started_at, + logs=[f"Execution started for role {agent_role.role_name}"], + input_snapshot=task.input_payload, + output_snapshot={}, + error_message=None, + token_usage={}, + latency_ms=None, + ) + self.db.add(run) + self.db.flush() + + self._emit_execution_event( + event_type="task_execution_started", + task=task, + run=run, + agent_role=agent_role, + message="Worker execution started", + ) + + try: + agent = get_worker_agent(agent_role.role_name) + result = agent.run( + task, + { + "agent_role_id": agent_role.id, + "role_name": agent_role.role_name, + "assignment_id": assignment.id, + }, + ) + finished_at = datetime.now(timezone.utc) + run.run_status = "succeeded" + run.finished_at = finished_at + run.output_snapshot = result + run.logs = [*run.logs, "Execution completed successfully"] + run.latency_ms = max(int((finished_at - started_at).total_seconds() * 1000), 0) + assignment.assignment_status = "fulfilled" + transition_task_status( + self.db, + task, + to_status="success", + reason="Worker execution completed successfully", + source="worker", + run_id=run.id, + ) + self._emit_execution_event( + event_type="task_execution_finished", + task=task, + run=run, + agent_role=agent_role, + message="Worker execution finished", + ) + self.db.flush() + return run + except Exception as exc: + finished_at = datetime.now(timezone.utc) + run.run_status = "failed" + run.finished_at = finished_at + run.error_message = str(exc) + run.logs = [*run.logs, f"Execution failed: {exc}"] + run.latency_ms = max(int((finished_at - started_at).total_seconds() * 1000), 0) + transition_task_status( + self.db, + task, + to_status="failed", + reason=f"Worker execution failed: {exc}", + source="worker", + run_id=run.id, + ) + self._emit_execution_event( + event_type="task_execution_failed", + task=task, + run=run, + agent_role=agent_role, + message=str(exc), + ) + self.db.flush() + return run + + def run_once(self) -> ExecutionRunORM | None: + with self.db.begin(): + task = self.claim_next_task() + if task is None: + return None + return self.execute_task(task) diff --git a/src/tests/test_worker_execution.py b/src/tests/test_worker_execution.py new file mode 100644 index 0000000..708d90a --- /dev/null +++ b/src/tests/test_worker_execution.py @@ -0,0 +1,256 @@ +from __future__ import annotations + +import os +import sys +import uuid +from pathlib import Path + +from fastapi.testclient import TestClient +from sqlalchemy import create_engine, text +from sqlalchemy.orm import Session + + +ROOT = Path(__file__).resolve().parents[2] +WORKER_PREFIX = "worker-test-" + +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + + +def _database_url() -> str: + database_url = os.getenv("DATABASE_URL") + if not database_url: + env_file = ROOT / ".env" + if env_file.exists(): + for line in env_file.read_text(encoding="utf-8").splitlines(): + if line.startswith("DATABASE_URL="): + database_url = line.split("=", 1)[1].strip() + break + if not database_url: + raise RuntimeError("DATABASE_URL is not set") + return database_url + + +def _cleanup_database() -> None: + engine = create_engine(_database_url()) + with engine.begin() as conn: + conn.execute( + text("DELETE FROM task_batches WHERE title LIKE :prefix"), + {"prefix": f"{WORKER_PREFIX}%"}, + ) + conn.execute( + text( + "DELETE FROM agent_roles " + "WHERE role_name IN ('default_worker', 'failing_worker', 'echo_worker')" + ) + ) + + +_cleanup_database() + +from src.apps.api.app import app # noqa: E402 +from src.apps.worker.service import WorkerService # noqa: E402 + + +client = TestClient(app) + + +def setup_function() -> None: + _cleanup_database() + + +def teardown_function() -> None: + _cleanup_database() + + +def _register_agent( + *, + role_name: str, + capabilities: list[str], + supported_task_types: list[str], +) -> dict: + payload = { + "role_name": role_name, + "description": "worker test role", + "capabilities": capabilities, + "capability_declaration": { + "supported_task_types": supported_task_types, + "input_requirements": {"properties": {"text": {"type": "string"}}}, + "output_contract": {"type": "object"}, + "supports_concurrency": False, + "allows_auto_retry": False, + }, + "input_schema": {}, + "output_schema": {}, + "timeout_seconds": 300, + "max_retries": 1, + "enabled": True, + "version": "1.0.0", + } + response = client.post("/agents/register", json=payload) + assert response.status_code == 201 + return response.json() + + +def _submit_batch(*, task_type: str, title_suffix: str) -> dict: + payload = { + "title": f"{WORKER_PREFIX}batch-{title_suffix}", + "description": "worker execution batch", + "created_by": "pytest", + "metadata": {"suite": "worker"}, + "tasks": [ + { + "client_task_id": "task_1", + "title": f"{WORKER_PREFIX}task-{title_suffix}-1", + "task_type": task_type, + "priority": "medium", + "input_payload": {"text": "hello"}, + "expected_output_schema": {"type": "object"}, + "dependency_client_task_ids": [], + }, + { + "client_task_id": "task_2", + "title": f"{WORKER_PREFIX}task-{title_suffix}-2", + "task_type": task_type, + "priority": "medium", + "input_payload": {"text": "world"}, + "expected_output_schema": {"type": "object"}, + "dependency_client_task_ids": [], + }, + { + "client_task_id": "task_3", + "title": f"{WORKER_PREFIX}task-{title_suffix}-3", + "task_type": task_type, + "priority": "medium", + "input_payload": {"text": "!"}, + "expected_output_schema": {"type": "object"}, + "dependency_client_task_ids": [], + }, + ], + } + response = client.post("/task-batches", json=payload) + assert response.status_code == 201 + return response.json() + + +def test_worker_executes_queued_task_to_success() -> None: + suffix = uuid.uuid4().hex[:8] + _register_agent( + role_name="echo_worker", + capabilities=["task:echo_task"], + supported_task_types=["echo_task"], + ) + created = _submit_batch(task_type="echo_task", title_suffix=suffix) + + engine = create_engine(_database_url()) + with Session(engine) as session: + worker = WorkerService(session) + run = worker.run_once() + assert run is not None + run_id = run.id + task_id = run.task_id + + task_response = client.get(f"/tasks/{task_id}") + assert task_response.status_code == 200 + assert task_response.json()["status"] == "success" + + run_response = client.get(f"/runs/{run_id}") + assert run_response.status_code == 200 + run_body = run_response.json() + assert run_body["run_status"] == "succeeded" + assert run_body["output_snapshot"]["status"] == "ok" + assert run_body["output_snapshot"]["task_id"] == task_id + assert run_body["output_snapshot"]["echo"]["text"] in {"hello", "world", "!"} + assert run_body["latency_ms"] is not None + + events_response = client.get(f"/tasks/{task_id}/events") + assert events_response.status_code == 200 + event_statuses = [event["event_status"] for event in events_response.json() if event["event_type"] == "task_status_changed"] + assert event_statuses == ["queued", "running", "success"] + + remaining_task_ids = [task["task_id"] for task in created["tasks"] if task["task_id"] != task_id] + assert len(remaining_task_ids) == 2 + + +def test_worker_marks_task_failed_and_persists_error() -> None: + suffix = uuid.uuid4().hex[:8] + _register_agent( + role_name="failing_worker", + capabilities=["task:fail_task"], + supported_task_types=["fail_task"], + ) + _submit_batch(task_type="fail_task", title_suffix=suffix) + + engine = create_engine(_database_url()) + with Session(engine) as session: + worker = WorkerService(session) + run = worker.run_once() + assert run is not None + run_id = run.id + task_id = run.task_id + + task_response = client.get(f"/tasks/{task_id}") + assert task_response.status_code == 200 + assert task_response.json()["status"] == "failed" + + run_response = client.get(f"/runs/{run_id}") + assert run_response.status_code == 200 + run_body = run_response.json() + assert run_body["run_status"] == "failed" + assert "Agent execution failed" in run_body["error_message"] + assert run_body["output_snapshot"] == {} + + events_response = client.get(f"/tasks/{task_id}/events") + assert events_response.status_code == 200 + event_statuses = [event["event_status"] for event in events_response.json() if event["event_type"] == "task_status_changed"] + assert event_statuses == ["queued", "running", "failed"] + + +def test_run_once_returns_none_when_queue_is_empty() -> None: + engine = create_engine(_database_url()) + with Session(engine) as session: + worker = WorkerService(session) + run = worker.run_once() + assert run is None + + with engine.connect() as conn: + run_count = conn.execute( + text( + "SELECT count(*) " + "FROM execution_runs r " + "JOIN tasks t ON r.task_id = t.id " + "WHERE t.title LIKE :prefix" + ), + {"prefix": f"{WORKER_PREFIX}%"}, + ).scalar_one() + assert run_count == 0 + + +def test_worker_does_not_consume_same_task_twice() -> None: + suffix = uuid.uuid4().hex[:8] + _register_agent( + role_name="echo_worker", + capabilities=["task:echo_task"], + supported_task_types=["echo_task"], + ) + _submit_batch(task_type="echo_task", title_suffix=suffix) + + engine = create_engine(_database_url()) + with Session(engine) as first_session: + worker = WorkerService(first_session) + first_run = worker.run_once() + assert first_run is not None + first_task_id = first_run.task_id + + with Session(engine) as second_session: + worker = WorkerService(second_session) + second_run = worker.run_once() + assert second_run is not None + assert second_run.task_id != first_task_id + + with engine.connect() as conn: + remaining_queued = conn.execute( + text("SELECT count(*) FROM tasks WHERE status = 'queued' AND title LIKE :prefix"), + {"prefix": f"{WORKER_PREFIX}%"}, + ).scalar_one() + assert remaining_queued == 1 From 90f7184d0cb441938f089a025e4be3e707276369 Mon Sep 17 00:00:00 2001 From: Ayu Date: Mon, 30 Mar 2026 23:34:47 +0800 Subject: [PATCH 2/3] Add built-in planner worker reviewer demo chain --- src/apps/api/app.py | 6 + src/apps/api/bootstrap.py | 76 +++++++++++++ src/apps/worker/builtin_agents.py | 48 ++++++++ src/apps/worker/registry.py | 12 +- src/tests/test_builtin_demo_chain.py | 161 +++++++++++++++++++++++++++ 5 files changed, 302 insertions(+), 1 deletion(-) create mode 100644 src/apps/api/bootstrap.py create mode 100644 src/tests/test_builtin_demo_chain.py diff --git a/src/apps/api/app.py b/src/apps/api/app.py index 38bdd01..abae1bb 100644 --- a/src/apps/api/app.py +++ b/src/apps/api/app.py @@ -2,6 +2,7 @@ from fastapi import FastAPI +from src.apps.api.bootstrap import ensure_builtin_agent_roles from src.apps.api.routers import ( agents_router, health_router, @@ -24,3 +25,8 @@ app.include_router(agents_router) app.include_router(runs_router) app.include_router(reviews_router) + + +@app.on_event("startup") +def bootstrap_defaults() -> None: + ensure_builtin_agent_roles() diff --git a/src/apps/api/bootstrap.py b/src/apps/api/bootstrap.py new file mode 100644 index 0000000..e419b6e --- /dev/null +++ b/src/apps/api/bootstrap.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from src.apps.api.deps import engine +from src.packages.core.db.models import AgentRoleORM + + +BUILTIN_ROLES: tuple[dict, ...] = ( + { + "role_name": "planner_agent", + "description": "Built-in planner for demo preprocessing", + "capabilities": ["task:planner_preprocess"], + "input_schema": { + "supported_task_types": ["planner_preprocess"], + "supports_concurrency": True, + "allows_auto_retry": False, + }, + "output_schema": {}, + }, + { + "role_name": "worker_agent", + "description": "Built-in worker for demo execution", + "capabilities": ["task:worker_execute"], + "input_schema": { + "supported_task_types": ["worker_execute"], + "supports_concurrency": True, + "allows_auto_retry": False, + }, + "output_schema": {}, + }, + { + "role_name": "reviewer_agent", + "description": "Built-in reviewer for demo validation", + "capabilities": ["task:reviewer_validate"], + "input_schema": { + "supported_task_types": ["reviewer_validate"], + "supports_concurrency": True, + "allows_auto_retry": False, + }, + "output_schema": {}, + }, +) + + +def ensure_builtin_agent_roles() -> None: + with Session(engine) as session: + with session.begin(): + for config in BUILTIN_ROLES: + existing = session.scalar( + select(AgentRoleORM).where(AgentRoleORM.role_name == config["role_name"]) + ) + if existing is not None: + existing.description = config["description"] + existing.capabilities = config["capabilities"] + existing.input_schema = config["input_schema"] + existing.output_schema = config["output_schema"] + existing.timeout_seconds = 300 + existing.max_retries = 0 + existing.enabled = True + existing.version = "1.0.0" + continue + session.add( + AgentRoleORM( + role_name=config["role_name"], + description=config["description"], + capabilities=config["capabilities"], + input_schema=config["input_schema"], + output_schema=config["output_schema"], + timeout_seconds=300, + max_retries=0, + enabled=True, + version="1.0.0", + ) + ) diff --git a/src/apps/worker/builtin_agents.py b/src/apps/worker/builtin_agents.py index 0e1f27d..d04cb57 100644 --- a/src/apps/worker/builtin_agents.py +++ b/src/apps/worker/builtin_agents.py @@ -24,3 +24,51 @@ def run(self, task: TaskORM, context: dict[str, Any]) -> dict[str, Any]: class FailingWorkerAgent: def run(self, task: TaskORM, context: dict[str, Any]) -> dict[str, Any]: raise RuntimeError(f"Agent execution failed for task {task.id}") + + +class PlannerWorkerAgent: + def run(self, task: TaskORM, context: dict[str, Any]) -> dict[str, Any]: + text = str(task.input_payload.get("text", "")).strip() + tags = [task.task_type, "demo", "planned"] + steps = [step for step in text.split(" ") if step] + return { + "status": "ok", + "stage": "planner", + "task_id": task.id, + "normalized_text": text, + "tags": tags, + "steps": steps, + "context": context, + } + + +class DefaultWorkerAgent: + def run(self, task: TaskORM, context: dict[str, Any]) -> dict[str, Any]: + return { + "status": "ok", + "stage": "worker", + "task_id": task.id, + "result": { + "summary": f"processed task_type={task.task_type}", + "input": task.input_payload, + }, + "context": context, + } + + +class ReviewerWorkerAgent: + def run(self, task: TaskORM, context: dict[str, Any]) -> dict[str, Any]: + raw = task.input_payload.get("raw_output") + validation_passed = bool(raw) or bool(task.input_payload.get("allow_empty")) + needs_manual_review = bool(task.input_payload.get("force_manual_review")) or not validation_passed + return { + "status": "ok", + "stage": "reviewer", + "task_id": task.id, + "validation_passed": validation_passed, + "needs_manual_review": needs_manual_review, + "notes": "manual review required due to failed validation" + if needs_manual_review + else "auto review passed", + "context": context, + } diff --git a/src/apps/worker/registry.py b/src/apps/worker/registry.py index 831cb30..7b2dbc7 100644 --- a/src/apps/worker/registry.py +++ b/src/apps/worker/registry.py @@ -1,6 +1,13 @@ from __future__ import annotations -from src.apps.worker.builtin_agents import EchoWorkerAgent, FailingWorkerAgent, WorkerAgent +from src.apps.worker.builtin_agents import ( + DefaultWorkerAgent, + EchoWorkerAgent, + FailingWorkerAgent, + PlannerWorkerAgent, + ReviewerWorkerAgent, + WorkerAgent, +) def get_worker_agent(role_name: str) -> WorkerAgent: @@ -8,6 +15,9 @@ def get_worker_agent(role_name: str) -> WorkerAgent: "default_worker": EchoWorkerAgent(), "echo_worker": EchoWorkerAgent(), "failing_worker": FailingWorkerAgent(), + "planner_agent": PlannerWorkerAgent(), + "worker_agent": DefaultWorkerAgent(), + "reviewer_agent": ReviewerWorkerAgent(), } if role_name not in agents: diff --git a/src/tests/test_builtin_demo_chain.py b/src/tests/test_builtin_demo_chain.py new file mode 100644 index 0000000..defc804 --- /dev/null +++ b/src/tests/test_builtin_demo_chain.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +import os +import sys +import uuid +from pathlib import Path + +from fastapi.testclient import TestClient +from sqlalchemy import create_engine, text +from sqlalchemy.orm import Session + + +ROOT = Path(__file__).resolve().parents[2] +DEMO_PREFIX = "builtin-demo-" + +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + + +def _database_url() -> str: + database_url = os.getenv("DATABASE_URL") + if not database_url: + env_file = ROOT / ".env" + if env_file.exists(): + for line in env_file.read_text(encoding="utf-8").splitlines(): + if line.startswith("DATABASE_URL="): + database_url = line.split("=", 1)[1].strip() + break + if not database_url: + raise RuntimeError("DATABASE_URL is not set") + return database_url + + +def _cleanup_database() -> None: + engine = create_engine(_database_url()) + with engine.begin() as conn: + conn.execute( + text("DELETE FROM task_batches WHERE title LIKE :prefix"), + {"prefix": f"{DEMO_PREFIX}%"}, + ) + + +def _demo_payload(suffix: str) -> dict: + return { + "title": f"{DEMO_PREFIX}batch-{suffix}", + "description": "built-in three role demo", + "created_by": "pytest", + "metadata": {"suite": "builtin-chain"}, + "tasks": [ + { + "client_task_id": "task_1", + "title": f"{DEMO_PREFIX}planner-{suffix}", + "task_type": "planner_preprocess", + "priority": "medium", + "input_payload": {"text": "draft implementation plan"}, + "expected_output_schema": {"type": "object"}, + "dependency_client_task_ids": [], + }, + { + "client_task_id": "task_2", + "title": f"{DEMO_PREFIX}worker-{suffix}", + "task_type": "worker_execute", + "priority": "medium", + "input_payload": {"text": "execute planned task"}, + "expected_output_schema": {"type": "object"}, + "dependency_client_task_ids": ["task_1"], + }, + { + "client_task_id": "task_3", + "title": f"{DEMO_PREFIX}reviewer-{suffix}", + "task_type": "reviewer_validate", + "priority": "medium", + "input_payload": {"raw_output": {"status": "ok"}}, + "expected_output_schema": {"type": "object"}, + "dependency_client_task_ids": ["task_2"], + }, + ], + } + + +_cleanup_database() + +from src.apps.api.app import app # noqa: E402 +from src.apps.worker.service import WorkerService # noqa: E402 + +def setup_function() -> None: + _cleanup_database() + + +def teardown_function() -> None: + _cleanup_database() + + +def test_builtin_roles_seeded_once_and_demo_chain_runs() -> None: + with TestClient(app) as client: + engine = create_engine(_database_url()) + with engine.connect() as conn: + planner_count = conn.execute( + text("SELECT count(*) FROM agent_roles WHERE role_name = 'planner_agent'") + ).scalar_one() + worker_count = conn.execute( + text("SELECT count(*) FROM agent_roles WHERE role_name = 'worker_agent'") + ).scalar_one() + reviewer_count = conn.execute( + text("SELECT count(*) FROM agent_roles WHERE role_name = 'reviewer_agent'") + ).scalar_one() + + assert planner_count == 1 + assert worker_count == 1 + assert reviewer_count == 1 + + suffix = uuid.uuid4().hex[:8] + create_response = client.post("/task-batches", json=_demo_payload(suffix)) + assert create_response.status_code == 201 + created = create_response.json() + created_ids = [task["task_id"] for task in created["tasks"]] + + with Session(engine) as session: + worker = WorkerService(session) + first_run = worker.run_once() + second_run = worker.run_once() + third_run = worker.run_once() + assert first_run is not None + assert second_run is not None + assert third_run is not None + run_summaries = [ + {"run_id": first_run.id, "task_id": first_run.task_id}, + {"run_id": second_run.id, "task_id": second_run.task_id}, + {"run_id": third_run.id, "task_id": third_run.task_id}, + ] + + assert {item["task_id"] for item in run_summaries} == set(created_ids) + + reviewer_task_id = created["tasks"][2]["task_id"] + run_body = {} + for summary in run_summaries: + reviewer_run_response = client.get(f"/runs/{summary['run_id']}") + assert reviewer_run_response.status_code == 200 + candidate = reviewer_run_response.json() + if candidate["task_id"] == reviewer_task_id: + run_body = candidate + break + + assert run_body["task_id"] == reviewer_task_id + assert run_body["output_snapshot"]["stage"] == "reviewer" + assert run_body["output_snapshot"]["validation_passed"] is True + assert run_body["output_snapshot"]["needs_manual_review"] is False + + for task_id in created_ids: + task_response = client.get(f"/tasks/{task_id}") + assert task_response.status_code == 200 + assert task_response.json()["status"] == "success" + + events_response = client.get(f"/tasks/{task_id}/events") + assert events_response.status_code == 200 + statuses = [ + event["event_status"] + for event in events_response.json() + if event["event_type"] == "task_status_changed" + ] + assert statuses == ["queued", "running", "success"] From 0cd696b073a66d7d6fe40052f41d03eae791cd0d Mon Sep 17 00:00:00 2001 From: Ayu Date: Tue, 31 Mar 2026 10:06:23 +0800 Subject: [PATCH 3/3] Keep agent registry in worker registry --- src/apps/worker/registry.py | 44 +++++++++++++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/src/apps/worker/registry.py b/src/apps/worker/registry.py index 7b2dbc7..8aefc2e 100644 --- a/src/apps/worker/registry.py +++ b/src/apps/worker/registry.py @@ -1,7 +1,11 @@ from __future__ import annotations +from src.apps.worker.types import AgentRunner, WorkerContext +from src.packages.sdk.base_agent import BaseAgent +from src.packages.core.db.models import TaskORM + from src.apps.worker.builtin_agents import ( - DefaultWorkerAgent, + DefaultWorkerAgent as BuiltinDefaultWorkerAgent, EchoWorkerAgent, FailingWorkerAgent, PlannerWorkerAgent, @@ -10,13 +14,49 @@ ) +class DefaultWorkerAgent(BaseAgent): + role_name = "default_worker" + capabilities = ["default_worker"] + + def run(self, task: TaskORM, context: WorkerContext) -> dict: + return { + "status": "ok", + "task_id": task.id, + "run_id": context.run_id, + "agent_role": context.agent_role_name, + "echo": task.input_payload, + } + + +class AgentRegistry: + def __init__(self) -> None: + self._agents: dict[str, AgentRunner] = {} + + def register(self, role_name: str, agent: AgentRunner) -> None: + declared_role_name = getattr(agent, "role_name", None) + if declared_role_name and declared_role_name != role_name: + raise ValueError( + f"Agent role_name mismatch: declared={declared_role_name} registered={role_name}" + ) + self._agents[role_name] = agent + + def get(self, role_name: str) -> AgentRunner | None: + return self._agents.get(role_name) + + +def build_default_registry() -> AgentRegistry: + registry = AgentRegistry() + registry.register("default_worker", DefaultWorkerAgent()) + return registry + + def get_worker_agent(role_name: str) -> WorkerAgent: agents: dict[str, WorkerAgent] = { "default_worker": EchoWorkerAgent(), "echo_worker": EchoWorkerAgent(), "failing_worker": FailingWorkerAgent(), "planner_agent": PlannerWorkerAgent(), - "worker_agent": DefaultWorkerAgent(), + "worker_agent": BuiltinDefaultWorkerAgent(), "reviewer_agent": ReviewerWorkerAgent(), }