diff --git a/src/apps/api/app.py b/src/apps/api/app.py index 38bdd01..abae1bb 100644 --- a/src/apps/api/app.py +++ b/src/apps/api/app.py @@ -2,6 +2,7 @@ from fastapi import FastAPI +from src.apps.api.bootstrap import ensure_builtin_agent_roles from src.apps.api.routers import ( agents_router, health_router, @@ -24,3 +25,8 @@ app.include_router(agents_router) app.include_router(runs_router) app.include_router(reviews_router) + + +@app.on_event("startup") +def bootstrap_defaults() -> None: + ensure_builtin_agent_roles() diff --git a/src/apps/api/bootstrap.py b/src/apps/api/bootstrap.py new file mode 100644 index 0000000..e419b6e --- /dev/null +++ b/src/apps/api/bootstrap.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from src.apps.api.deps import engine +from src.packages.core.db.models import AgentRoleORM + + +BUILTIN_ROLES: tuple[dict, ...] = ( + { + "role_name": "planner_agent", + "description": "Built-in planner for demo preprocessing", + "capabilities": ["task:planner_preprocess"], + "input_schema": { + "supported_task_types": ["planner_preprocess"], + "supports_concurrency": True, + "allows_auto_retry": False, + }, + "output_schema": {}, + }, + { + "role_name": "worker_agent", + "description": "Built-in worker for demo execution", + "capabilities": ["task:worker_execute"], + "input_schema": { + "supported_task_types": ["worker_execute"], + "supports_concurrency": True, + "allows_auto_retry": False, + }, + "output_schema": {}, + }, + { + "role_name": "reviewer_agent", + "description": "Built-in reviewer for demo validation", + "capabilities": ["task:reviewer_validate"], + "input_schema": { + "supported_task_types": ["reviewer_validate"], + "supports_concurrency": True, + "allows_auto_retry": False, + }, + "output_schema": {}, + }, +) + + +def ensure_builtin_agent_roles() -> None: + with Session(engine) as session: + with session.begin(): + for config in BUILTIN_ROLES: + existing = session.scalar( + select(AgentRoleORM).where(AgentRoleORM.role_name == config["role_name"]) + ) + if existing is not None: + existing.description = config["description"] + existing.capabilities = config["capabilities"] + existing.input_schema = config["input_schema"] + existing.output_schema = config["output_schema"] + existing.timeout_seconds = 300 + existing.max_retries = 0 + existing.enabled = True + existing.version = "1.0.0" + continue + session.add( + AgentRoleORM( + role_name=config["role_name"], + description=config["description"], + capabilities=config["capabilities"], + input_schema=config["input_schema"], + output_schema=config["output_schema"], + timeout_seconds=300, + max_retries=0, + enabled=True, + version="1.0.0", + ) + ) diff --git a/src/apps/worker/builtin_agents.py b/src/apps/worker/builtin_agents.py new file mode 100644 index 0000000..d04cb57 --- /dev/null +++ b/src/apps/worker/builtin_agents.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from typing import Any, Protocol + +from src.packages.core.db.models import TaskORM + + +class WorkerAgent(Protocol): + def run(self, task: TaskORM, context: dict[str, Any]) -> dict[str, Any]: + ... + + +class EchoWorkerAgent: + def run(self, task: TaskORM, context: dict[str, Any]) -> dict[str, Any]: + return { + "status": "ok", + "task_id": task.id, + "task_type": task.task_type, + "echo": task.input_payload, + "context": context, + } + + +class FailingWorkerAgent: + def run(self, task: TaskORM, context: dict[str, Any]) -> dict[str, Any]: + raise RuntimeError(f"Agent execution failed for task {task.id}") + + +class PlannerWorkerAgent: + def run(self, task: TaskORM, context: dict[str, Any]) -> dict[str, Any]: + text = str(task.input_payload.get("text", "")).strip() + tags = [task.task_type, "demo", "planned"] + steps = [step for step in text.split(" ") if step] + return { + "status": "ok", + "stage": "planner", + "task_id": task.id, + "normalized_text": text, + "tags": tags, + "steps": steps, + "context": context, + } + + +class DefaultWorkerAgent: + def run(self, task: TaskORM, context: dict[str, Any]) -> dict[str, Any]: + return { + "status": "ok", + "stage": "worker", + "task_id": task.id, + "result": { + "summary": f"processed task_type={task.task_type}", + "input": task.input_payload, + }, + "context": context, + } + + +class ReviewerWorkerAgent: + def run(self, task: TaskORM, context: dict[str, Any]) -> dict[str, Any]: + raw = task.input_payload.get("raw_output") + validation_passed = bool(raw) or bool(task.input_payload.get("allow_empty")) + needs_manual_review = bool(task.input_payload.get("force_manual_review")) or not validation_passed + return { + "status": "ok", + "stage": "reviewer", + "task_id": task.id, + "validation_passed": validation_passed, + "needs_manual_review": needs_manual_review, + "notes": "manual review required due to failed validation" + if needs_manual_review + else "auto review passed", + "context": context, + } diff --git a/src/apps/worker/registry.py b/src/apps/worker/registry.py index 58a522b..8aefc2e 100644 --- a/src/apps/worker/registry.py +++ b/src/apps/worker/registry.py @@ -4,6 +4,15 @@ from src.packages.sdk.base_agent import BaseAgent from src.packages.core.db.models import TaskORM +from src.apps.worker.builtin_agents import ( + DefaultWorkerAgent as BuiltinDefaultWorkerAgent, + EchoWorkerAgent, + FailingWorkerAgent, + PlannerWorkerAgent, + ReviewerWorkerAgent, + WorkerAgent, +) + class DefaultWorkerAgent(BaseAgent): role_name = "default_worker" @@ -39,3 +48,19 @@ def build_default_registry() -> AgentRegistry: registry = AgentRegistry() registry.register("default_worker", DefaultWorkerAgent()) return registry + + +def get_worker_agent(role_name: str) -> WorkerAgent: + agents: dict[str, WorkerAgent] = { + "default_worker": EchoWorkerAgent(), + "echo_worker": EchoWorkerAgent(), + "failing_worker": FailingWorkerAgent(), + "planner_agent": PlannerWorkerAgent(), + "worker_agent": BuiltinDefaultWorkerAgent(), + "reviewer_agent": ReviewerWorkerAgent(), + } + + if role_name not in agents: + raise KeyError(f"No worker agent registered for role {role_name}") + + return agents[role_name] diff --git a/src/apps/worker/service.py b/src/apps/worker/service.py new file mode 100644 index 0000000..4e26438 --- /dev/null +++ b/src/apps/worker/service.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any + +from sqlalchemy import Select, select +from sqlalchemy.orm import Session + +from src.apps.worker.registry import get_worker_agent +from src.packages.core.db.models import ( + AgentRoleORM, + AssignmentORM, + EventLogORM, + ExecutionRunORM, + TaskORM, +) +from src.packages.core.task_state_machine import transition_task_status + + +class WorkerService: + def __init__(self, db: Session): + self.db = db + + def _emit_execution_event( + self, + *, + event_type: str, + task: TaskORM, + run: ExecutionRunORM, + agent_role: AgentRoleORM, + message: str, + ) -> None: + self.db.add( + EventLogORM( + batch_id=task.batch_id, + task_id=task.id, + run_id=run.id, + event_type=event_type, + event_status=run.run_status, + message=message, + payload={ + "task_id": task.id, + "run_id": run.id, + "agent_role_id": agent_role.id, + "role_name": agent_role.role_name, + }, + ) + ) + + def _active_assignment_query(self, task_id: str) -> Select[tuple[AssignmentORM]]: + return ( + select(AssignmentORM) + .where( + AssignmentORM.task_id == task_id, + AssignmentORM.assignment_status == "active", + ) + .order_by(AssignmentORM.assigned_at.desc()) + .limit(1) + ) + + def claim_next_task(self) -> TaskORM | None: + task = self.db.scalar( + select(TaskORM) + .where(TaskORM.status == "queued") + .order_by(TaskORM.created_at.asc()) + .with_for_update(skip_locked=True) + .limit(1) + ) + if task is None: + return None + + transition_task_status( + self.db, + task, + to_status="running", + reason="Worker claimed queued task", + source="worker", + ) + self.db.flush() + return task + + def execute_task(self, task: TaskORM) -> ExecutionRunORM: + assignment = self.db.scalar(self._active_assignment_query(task.id)) + if assignment is None: + raise RuntimeError(f"No active assignment found for task {task.id}") + + agent_role = self.db.get(AgentRoleORM, assignment.agent_role_id) + if agent_role is None: + raise RuntimeError(f"Assigned agent role {assignment.agent_role_id} not found") + + started_at = datetime.now(timezone.utc) + + run = ExecutionRunORM( + task_id=task.id, + agent_role_id=agent_role.id, + run_status="running", + started_at=started_at, + logs=[f"Execution started for role {agent_role.role_name}"], + input_snapshot=task.input_payload, + output_snapshot={}, + error_message=None, + token_usage={}, + latency_ms=None, + ) + self.db.add(run) + self.db.flush() + + self._emit_execution_event( + event_type="task_execution_started", + task=task, + run=run, + agent_role=agent_role, + message="Worker execution started", + ) + + try: + agent = get_worker_agent(agent_role.role_name) + result = agent.run( + task, + { + "agent_role_id": agent_role.id, + "role_name": agent_role.role_name, + "assignment_id": assignment.id, + }, + ) + finished_at = datetime.now(timezone.utc) + run.run_status = "succeeded" + run.finished_at = finished_at + run.output_snapshot = result + run.logs = [*run.logs, "Execution completed successfully"] + run.latency_ms = max(int((finished_at - started_at).total_seconds() * 1000), 0) + assignment.assignment_status = "fulfilled" + transition_task_status( + self.db, + task, + to_status="success", + reason="Worker execution completed successfully", + source="worker", + run_id=run.id, + ) + self._emit_execution_event( + event_type="task_execution_finished", + task=task, + run=run, + agent_role=agent_role, + message="Worker execution finished", + ) + self.db.flush() + return run + except Exception as exc: + finished_at = datetime.now(timezone.utc) + run.run_status = "failed" + run.finished_at = finished_at + run.error_message = str(exc) + run.logs = [*run.logs, f"Execution failed: {exc}"] + run.latency_ms = max(int((finished_at - started_at).total_seconds() * 1000), 0) + transition_task_status( + self.db, + task, + to_status="failed", + reason=f"Worker execution failed: {exc}", + source="worker", + run_id=run.id, + ) + self._emit_execution_event( + event_type="task_execution_failed", + task=task, + run=run, + agent_role=agent_role, + message=str(exc), + ) + self.db.flush() + return run + + def run_once(self) -> ExecutionRunORM | None: + with self.db.begin(): + task = self.claim_next_task() + if task is None: + return None + return self.execute_task(task) diff --git a/src/tests/test_builtin_demo_chain.py b/src/tests/test_builtin_demo_chain.py new file mode 100644 index 0000000..defc804 --- /dev/null +++ b/src/tests/test_builtin_demo_chain.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +import os +import sys +import uuid +from pathlib import Path + +from fastapi.testclient import TestClient +from sqlalchemy import create_engine, text +from sqlalchemy.orm import Session + + +ROOT = Path(__file__).resolve().parents[2] +DEMO_PREFIX = "builtin-demo-" + +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + + +def _database_url() -> str: + database_url = os.getenv("DATABASE_URL") + if not database_url: + env_file = ROOT / ".env" + if env_file.exists(): + for line in env_file.read_text(encoding="utf-8").splitlines(): + if line.startswith("DATABASE_URL="): + database_url = line.split("=", 1)[1].strip() + break + if not database_url: + raise RuntimeError("DATABASE_URL is not set") + return database_url + + +def _cleanup_database() -> None: + engine = create_engine(_database_url()) + with engine.begin() as conn: + conn.execute( + text("DELETE FROM task_batches WHERE title LIKE :prefix"), + {"prefix": f"{DEMO_PREFIX}%"}, + ) + + +def _demo_payload(suffix: str) -> dict: + return { + "title": f"{DEMO_PREFIX}batch-{suffix}", + "description": "built-in three role demo", + "created_by": "pytest", + "metadata": {"suite": "builtin-chain"}, + "tasks": [ + { + "client_task_id": "task_1", + "title": f"{DEMO_PREFIX}planner-{suffix}", + "task_type": "planner_preprocess", + "priority": "medium", + "input_payload": {"text": "draft implementation plan"}, + "expected_output_schema": {"type": "object"}, + "dependency_client_task_ids": [], + }, + { + "client_task_id": "task_2", + "title": f"{DEMO_PREFIX}worker-{suffix}", + "task_type": "worker_execute", + "priority": "medium", + "input_payload": {"text": "execute planned task"}, + "expected_output_schema": {"type": "object"}, + "dependency_client_task_ids": ["task_1"], + }, + { + "client_task_id": "task_3", + "title": f"{DEMO_PREFIX}reviewer-{suffix}", + "task_type": "reviewer_validate", + "priority": "medium", + "input_payload": {"raw_output": {"status": "ok"}}, + "expected_output_schema": {"type": "object"}, + "dependency_client_task_ids": ["task_2"], + }, + ], + } + + +_cleanup_database() + +from src.apps.api.app import app # noqa: E402 +from src.apps.worker.service import WorkerService # noqa: E402 + +def setup_function() -> None: + _cleanup_database() + + +def teardown_function() -> None: + _cleanup_database() + + +def test_builtin_roles_seeded_once_and_demo_chain_runs() -> None: + with TestClient(app) as client: + engine = create_engine(_database_url()) + with engine.connect() as conn: + planner_count = conn.execute( + text("SELECT count(*) FROM agent_roles WHERE role_name = 'planner_agent'") + ).scalar_one() + worker_count = conn.execute( + text("SELECT count(*) FROM agent_roles WHERE role_name = 'worker_agent'") + ).scalar_one() + reviewer_count = conn.execute( + text("SELECT count(*) FROM agent_roles WHERE role_name = 'reviewer_agent'") + ).scalar_one() + + assert planner_count == 1 + assert worker_count == 1 + assert reviewer_count == 1 + + suffix = uuid.uuid4().hex[:8] + create_response = client.post("/task-batches", json=_demo_payload(suffix)) + assert create_response.status_code == 201 + created = create_response.json() + created_ids = [task["task_id"] for task in created["tasks"]] + + with Session(engine) as session: + worker = WorkerService(session) + first_run = worker.run_once() + second_run = worker.run_once() + third_run = worker.run_once() + assert first_run is not None + assert second_run is not None + assert third_run is not None + run_summaries = [ + {"run_id": first_run.id, "task_id": first_run.task_id}, + {"run_id": second_run.id, "task_id": second_run.task_id}, + {"run_id": third_run.id, "task_id": third_run.task_id}, + ] + + assert {item["task_id"] for item in run_summaries} == set(created_ids) + + reviewer_task_id = created["tasks"][2]["task_id"] + run_body = {} + for summary in run_summaries: + reviewer_run_response = client.get(f"/runs/{summary['run_id']}") + assert reviewer_run_response.status_code == 200 + candidate = reviewer_run_response.json() + if candidate["task_id"] == reviewer_task_id: + run_body = candidate + break + + assert run_body["task_id"] == reviewer_task_id + assert run_body["output_snapshot"]["stage"] == "reviewer" + assert run_body["output_snapshot"]["validation_passed"] is True + assert run_body["output_snapshot"]["needs_manual_review"] is False + + for task_id in created_ids: + task_response = client.get(f"/tasks/{task_id}") + assert task_response.status_code == 200 + assert task_response.json()["status"] == "success" + + events_response = client.get(f"/tasks/{task_id}/events") + assert events_response.status_code == 200 + statuses = [ + event["event_status"] + for event in events_response.json() + if event["event_type"] == "task_status_changed" + ] + assert statuses == ["queued", "running", "success"]