diff --git a/CHANGELOG.md b/CHANGELOG.md index 904f09be..74f08f99 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [Unreleased] + +### Features + +* **hitl:** add production-ready Human-in-the-Loop approval gateway for ADK agents — includes `@hitl_tool` decorator, FastAPI approval service with SQLite persistence, ADK 1.x adapter, and reference Streamlit dashboard (`contributing/samples/hitl_approval`) + ## [0.4.1](https://github.com/google/adk-python-community/compare/v0.4.0...v0.4.1) (2026-02-18) diff --git a/contributing/samples/hitl_approval/README.md b/contributing/samples/hitl_approval/README.md new file mode 100644 index 00000000..e1a6c289 --- /dev/null +++ b/contributing/samples/hitl_approval/README.md @@ -0,0 +1,91 @@ +# ADK HITL Approval Dashboard + +A drop-in **production-ready Human-in-the-Loop (HITL) approval middleware** for Google Agent Development Kit (ADK) agents — complete with an API backend and a demo Streamlit dashboard UI. + +## The Problem Solved + +ADK 1.x ships with an experimental `require_confirmation=True` feature that handles pausing the LLM loop for human verification. However, it is fundamentally built for local debugging and introduces major blockers to an enterprise environment: +1. **Incompatible with Persistent Sessions:** Native confirmations intentionally do not serialize well and will completely fail to resume your agent if you use `DatabaseSessionService`, `SpannerSessionService`, or `VertexAiSessionService` (the mandatory session backends for production deployments). +2. **Single-Agent Limitations:** They silently break across `AgentTool` nested bounds and true multi-agent (A2A) topologies, causing missing events or infinitely looping models. +3. **No Resilient Audit Log:** The Native confirmation tool leaves no easily queryable paper trail linking the human supervisor to a precise LLM request. + +*This project is the production implementation of the HITL pattern covered in the [ADK Multi-Agent Patterns Guide (Advent of Agents Day 13)](#).* + +## What This Library Provides + +This project solves the production gaps by explicitly decoupling the human approval payload from ADK's internal session memory. It introduces a session-agnostic REST API layer using an Adapter pattern. + +### The 3-Layer Architecture + +```text +┌─────────────────────────────────────────┐ +│ Dashboard UI (Streamlit) │ Layer 3: Demo/reference UI +│ Approval inbox, audit log viewer │ (Easily replaced by Zendesk/etc.) +└──────────────────┬──────────────────────┘ + │ +┌──────────────────▼──────────────────────┐ +│ ApprovalRequest Model (Pydantic) │ Layer 2: Normalised Contract API +│ FastAPI backend + SQLite store │ Session-agnostic persistence +└────────────────┬────────────────────────┘ + │ + ┌──────────┴───────────┐ +┌─────▼──────┐ ┌──────────▼──────┐ +│ ADK 1.x │ │ ADK 2.0 │ Layer 1: Adapters +│ Adapter │ │ Adapter │ Only this changes between versions +└────────────┘ └─────────────────┘ +``` + +By retaining HITL state inside an independent FastAPI engine and SQLite database, an active agent can pause safely. When a human supervisor hits "Approve" inside a centralized web portal hours later, the middleware simply posts the decision back into the agent's `/run_sse` stream seamlessly. + +## Quick Start (Local Sandbox) + +We have provided a demo customer service agent (`credit_agent`) alongside a launch script to test the interaction end-to-end. + +1. Create your python virtual environment and sync dependencies using `uv` (requires Python 3.11+): + ```bash + uv venv --python "python3.11" ".venv" + source .venv/bin/activate + uv sync --all-extras + ``` +2. Start the FastAPI backend, Streamlit dashboard, and ADK Live Chat agent all at once: + ```bash + ./start_servers.sh + ``` +3. Open `http://localhost:8080` to chat with the agent and ask for a $75 account credit. +4. When the agent pauses and asks for a supervisor, open `http://localhost:8501` to approve or reject the request. + +## How to use in your own ADK application + +Wrapping an ADK agent with a formal enterprise HITL checkpoint takes under 5 lines of code: + +1. Import the `hitl_tool` gateway wrapper. +2. Decorate your function tool. +3. Attach it to your ADK Agent initialization using a standard `FunctionTool`. + +```python +from google.adk.tools import FunctionTool +from google.adk_community.tools.hitl.gateway import hitl_tool + +# 1. Wrap your function with the decorator +@hitl_tool(agent_name="my_billing_agent") +async def issue_refund(user_id: str, amount: float): + # This block won't execute until explicitly approved inside the FastAPI dashboard + return {"status": "success", "amount_refunded": amount} + +# 2. Attach to ADK Agent +root_agent = Agent( + name="my_billing_agent", + tools=[FunctionTool(issue_refund)] +) +``` + +## Production Integration Strategies + +This repository acts as the production baseline for a contact center or enterprise orchestration grid. Once deployed to staging, consider swapping out: +* **Storage Layer:** Replace the local `SQLite` engine in `app/api/store.py` with `PostgreSQL` or `Cloud Spanner`. +* **Proactive Notification:** Hook the FastAPI `POST /approvals/` route into Slack, PagerDuty, or Microsoft Teams to actively ping channels when a high-risk request pops up. +* **Remove Streamlit:** Bypass the Streamlit frontend completely and point your existing support portal interface (like Salesforce Service Cloud) directly to `GET /approvals/pending` and `POST /approvals/{id}/decide`. + +## ADK 2.0 Compatibility + +This project currently uses ADK 1.x conventions and event triggers. Because it strictly implements an `adapters` layer, all the Pydantic API schemas and Streamlit logic are completely forward-compatible with ADK 2.0 `RequestInput` workflow yielding. You'll simply need to switch the adapter layer translation once ADK 2.0 exits Alpha. diff --git a/contributing/samples/hitl_approval/credit_agent/__init__.py b/contributing/samples/hitl_approval/credit_agent/__init__.py new file mode 100644 index 00000000..02c597e1 --- /dev/null +++ b/contributing/samples/hitl_approval/credit_agent/__init__.py @@ -0,0 +1 @@ +from . import agent diff --git a/contributing/samples/hitl_approval/credit_agent/agent.py b/contributing/samples/hitl_approval/credit_agent/agent.py new file mode 100644 index 00000000..a7a2230a --- /dev/null +++ b/contributing/samples/hitl_approval/credit_agent/agent.py @@ -0,0 +1,80 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Credit agent — external supervisor HITL demo. + +This agent demonstrates the cross-user approval pattern: + - Customer chats in ADK web (:8080) + - Agent wants to apply a credit → submits request to HITL API (:8000) + - Agent blocks (non-blocking async poll) waiting for a decision + - Supervisor opens Streamlit dashboard (:8501), reviews and approves/rejects + - Agent resumes and informs the customer of the outcome + +Make sure all three services are running before chatting (see start_servers.sh): + HITL API: uvicorn google.adk_community.services.hitl_approval.api:app --port 8000 + Dashboard: streamlit run dashboard/app.py --server.headless true + ADK web: adk web credit_agent/ --port 8080 +""" + +from __future__ import annotations + +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../..")) + +from google.adk.agents import Agent +from google.adk.tools import FunctionTool + +from google.adk_community.tools.hitl.gateway import hitl_tool + + +@hitl_tool(agent_name="credit_agent") +async def apply_account_credit(account_id: str, amount: float, reason: str) -> dict: + """Apply a credit to a customer account. Requires supervisor approval. + + Args: + account_id: The customer account ID to credit. + amount: Credit amount in USD. + reason: Business justification for the credit. + + Returns: + Confirmation with the updated account balance. + """ + # Real implementation would call your billing/CRM API here + return { + "status": "credited", + "account_id": account_id, + "amount_credited": amount, + "new_balance": f"${amount:.2f} credit applied successfully.", + } + + +root_agent = Agent( + name="credit_agent", + model="gemini-2.5-flash", + description=( + "Customer support agent that can apply account credits. " + "Every credit requires supervisor approval via the HITL dashboard." + ), + instruction=( + "You are a customer support agent. When a customer requests an account credit, " + "call apply_account_credit with their account ID, the amount, and the reason. " + "Let them know their request is being reviewed by a supervisor and that you will " + "update them once a decision is made. " + "If the credit is approved, confirm it to the customer. " + "If rejected, apologise and explain that the supervisor did not approve it." + ), + tools=[FunctionTool(apply_account_credit)], +) diff --git a/contributing/samples/hitl_approval/dashboard/app.py b/contributing/samples/hitl_approval/dashboard/app.py new file mode 100644 index 00000000..c576198e --- /dev/null +++ b/contributing/samples/hitl_approval/dashboard/app.py @@ -0,0 +1,126 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Streamlit HITL Approval Dashboard. + +Run: + streamlit run contributing/samples/hitl_approval/dashboard/app.py +""" + +import httpx +import streamlit as st + +API_BASE = "http://localhost:8000" + + +def _resolve(request_id: str, decision: str, note: str): + try: + r = httpx.post( + f"{API_BASE}/approvals/{request_id}/decide", + json={ + "decision": decision, + "reviewer_id": "dashboard_admin", + "notes": note or None, + }, + timeout=5, + ) + r.raise_for_status() + st.success(f"Request {request_id[:8]}… marked as {decision}.") + st.rerun() + except Exception as e: + st.error(f"Failed to resolve: {e}") + + +st.set_page_config(page_title="ADK HITL Dashboard", page_icon="🔍", layout="wide") +st.title("ADK HITL Approval Dashboard") + +# ── Sidebar filters ─────────────────────────────────────────────────────────── + +status_filter = st.sidebar.selectbox( + "Filter by status", ["All", "pending", "approved", "rejected", "escalated"] +) + +if st.sidebar.button("Refresh"): + st.rerun() + +# ── Fetch approvals ─────────────────────────────────────────────────────────── + +try: + if status_filter == "pending": + resp = httpx.get(f"{API_BASE}/approvals/pending", timeout=5) + else: + params = {} + if status_filter != "All": + params["decision"] = status_filter + resp = httpx.get(f"{API_BASE}/approvals/audit", params=params, timeout=5) + + resp.raise_for_status() + requests = resp.json() +except Exception as e: + st.error(f"Could not connect to API: {e}") + st.stop() + +# ── Render approval cards ───────────────────────────────────────────────────── + +if not requests: + st.info("No approval requests found.") +else: + for req in requests: + status = req["status"] + color = { + "pending": "🟡", + "approved": "🟢", + "rejected": "🔴", + "escalated": "🟠", + }.get(status, "⚪") + + with st.expander( + f"{color} [{status.upper()}] {req['tool_name']} — {req['agent_name']} ({req['id'][:8]}…)" + ): + col1, col2 = st.columns(2) + col1.markdown( + f"**App:** `{req.get('app_name', 'N/A')}` | **User:** `{req.get('user_id', 'N/A')}`" + ) + col1.markdown(f"**Agent:** `{req['agent_name']}`") + col1.markdown(f"**Tool:** `{req['tool_name']}`") + col1.markdown(f"**Session:** `{req['session_id']}`") + col2.markdown(f"**Created:** {req['created_at']}") + if req.get("decided_at"): + col2.markdown( + f"**Resolved:** {req['decided_at']} by `{req.get('decided_by', 'unknown')}`" + ) + + st.markdown(f"**Message / Hint:**") + st.info(req.get("message", "No message provided.")) + + st.markdown("**Payload / Arguments:**") + st.json(req.get("payload", {})) + + if req.get("decision_notes"): + st.markdown(f"**Reviewer note:** {req['decision_notes']}") + + if status == "pending": + note = st.text_input( + "Reviewer note (optional)", key=f"note_{req['id']}" + ) + c1, c2, c3 = st.columns(3) + + if c1.button("Approve", key=f"approve_{req['id']}", type="primary"): + _resolve(req["id"], "approved", note) + + if c2.button("Reject", key=f"reject_{req['id']}"): + _resolve(req["id"], "rejected", note) + + if c3.button("Escalate", key=f"escalate_{req['id']}"): + _resolve(req["id"], "escalated", note) diff --git a/contributing/samples/hitl_approval/requirements.txt b/contributing/samples/hitl_approval/requirements.txt new file mode 100644 index 00000000..06b8b483 --- /dev/null +++ b/contributing/samples/hitl_approval/requirements.txt @@ -0,0 +1,14 @@ +# Sample-specific dependencies for the HITL Approval demo. +# Install into the repo virtualenv after `uv sync --all-extras`: +# +# uv pip install -r contributing/samples/hitl_approval/requirements.txt +# +# The core package (google-adk-community) and its deps (google-adk, httpx) +# are already installed by `uv sync`. Only the service and dashboard extras +# are listed here. + +fastapi>=0.111.0 +uvicorn[standard]>=0.30.0 +sqlalchemy>=2.0.0 +aiosqlite>=0.20.0 +streamlit>=1.35.0 diff --git a/contributing/samples/hitl_approval/start_servers.sh b/contributing/samples/hitl_approval/start_servers.sh new file mode 100755 index 00000000..afe4b0ae --- /dev/null +++ b/contributing/samples/hitl_approval/start_servers.sh @@ -0,0 +1,42 @@ +#!/bin/bash +# Copyright 2026 Google LLC + +# 1. Kill any lingering local servers from previous runs to free up ports +killall python uvicorn streamlit adk 2>/dev/null || true +sleep 1 + +# 2. Ensure we're running from the repo root so imports resolve correctly +cd "$(git rev-parse --show-toplevel)" + +# 3. Load GOOGLE_GENAI_API_KEY from .env if present +if [ -f .env ]; then + source .env +fi + +echo "Starting FastAPI HITL Backend (:8000)..." +export HITL_DB_PATH="./contributing/samples/hitl_approval/hitl.db" +.venv/bin/uvicorn google.adk_community.services.hitl_approval.api:app --port 8000 & +API_PID=$! + +echo "Starting Streamlit Dashboard (:8501)..." +STREAMLIT_BROWSER_GATHER_USAGE_STATS=false \ + .venv/bin/streamlit run contributing/samples/hitl_approval/dashboard/app.py \ + --server.headless true & +STREAMLIT_PID=$! + +echo "Starting ADK Web Chat (:8080)..." +.venv/bin/adk web contributing/samples/hitl_approval --port 8080 & +ADK_PID=$! + +echo "" +echo "All services launched." +echo "==========================================" +echo "Backend API: http://localhost:8000/docs" +echo "Dashboard UI: http://localhost:8501" +echo "ADK Agent Chat: http://localhost:8080" +echo "==========================================" +echo "Press Ctrl+C to shut down all servers." + +trap "kill $API_PID $STREAMLIT_PID $ADK_PID 2>/dev/null; exit" EXIT + +wait diff --git a/src/google/adk_community/services/hitl_approval/__init__.py b/src/google/adk_community/services/hitl_approval/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/google/adk_community/services/hitl_approval/api.py b/src/google/adk_community/services/hitl_approval/api.py new file mode 100644 index 00000000..0c729da7 --- /dev/null +++ b/src/google/adk_community/services/hitl_approval/api.py @@ -0,0 +1,45 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FastAPI application entry point.""" + +from __future__ import annotations + +from contextlib import asynccontextmanager + +from fastapi import FastAPI + +from . import routes +from .store import init_db + + +@asynccontextmanager +async def lifespan(app: FastAPI): + await init_db() + yield + + +app = FastAPI( + title="ADK HITL Approval API", + description="Human-in-the-Loop approval layer for Google ADK agents.", + version="0.1.0", + lifespan=lifespan, +) + +app.include_router(routes.router) + + +@app.get("/health") +async def health(): + return {"status": "ok"} diff --git a/src/google/adk_community/services/hitl_approval/routes.py b/src/google/adk_community/services/hitl_approval/routes.py new file mode 100644 index 00000000..d0fe487b --- /dev/null +++ b/src/google/adk_community/services/hitl_approval/routes.py @@ -0,0 +1,162 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Approval request CRUD endpoints.""" + +from __future__ import annotations + +import json +from datetime import datetime, timezone +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from .store import ApprovalRequestDB, get_db +from ...tools.hitl.models import (ApprovalDecision, ApprovalRequest, ApprovalStatus) + +router = APIRouter(prefix="/approvals", tags=["approvals"]) + +# ── Routes ──────────────────────────────────────────────────────────────────── + + +@router.post("/", response_model=ApprovalRequest, status_code=201) +async def create_approval(payload: ApprovalRequest, db: AsyncSession = Depends(get_db)): + """Agent submits a new approval request before executing a tool.""" + db_item = ApprovalRequestDB( + id=payload.id, + session_id=payload.session_id, + invocation_id=payload.invocation_id, + function_call_id=payload.function_call_id, + app_name=payload.app_name, + user_id=payload.user_id, + agent_name=payload.agent_name, + tool_name=payload.tool_name, + message=payload.message, + payload=json.dumps(payload.payload), + response_schema=json.dumps(payload.response_schema), + risk_level=payload.risk_level, + status=payload.status, + created_at=payload.created_at, + decided_at=payload.decided_at, + decided_by=payload.decided_by, + decision_notes=payload.decision_notes, + escalated_to=payload.escalated_to, + ) + db.add(db_item) + await db.commit() + await db.refresh(db_item) + return _to_pydantic(db_item) + + +@router.get("/pending", response_model=List[ApprovalRequest]) +async def list_pending_approvals(db: AsyncSession = Depends(get_db)): + """List all pending approvals.""" + q = ( + select(ApprovalRequestDB) + .where(ApprovalRequestDB.status == ApprovalStatus.PENDING) + .order_by(ApprovalRequestDB.created_at.desc()) + ) + result = await db.execute(q) + return [_to_pydantic(r) for r in result.scalars()] + + +@router.get("/audit", response_model=List[ApprovalRequest]) +async def get_audit_log( + agent_name: Optional[str] = None, + decision: Optional[str] = None, + db: AsyncSession = Depends(get_db), +): + """Audit log — queryable by agent, date, decision.""" + q = select(ApprovalRequestDB).order_by(ApprovalRequestDB.created_at.desc()) + if agent_name: + q = q.where(ApprovalRequestDB.agent_name == agent_name) + if decision: + q = q.where(ApprovalRequestDB.status == decision) + + result = await db.execute(q) + return [_to_pydantic(r) for r in result.scalars()] + + +@router.get("/{request_id}", response_model=ApprovalRequest) +async def get_approval(request_id: str, db: AsyncSession = Depends(get_db)): + """Get single approval with full context.""" + db_item = await _get_or_404(request_id, db) + return _to_pydantic(db_item) + + +@router.post("/{request_id}/decide", response_model=ApprovalRequest) +async def resolve_approval( + request_id: str, + decision: ApprovalDecision, + db: AsyncSession = Depends(get_db), +): + """Submit approve/reject/escalate decision.""" + db_item = await _get_or_404(request_id, db) + if db_item.status != ApprovalStatus.PENDING: + raise HTTPException(status_code=409, detail="Request already resolved.") + + db_item.status = decision.decision + db_item.decided_by = decision.reviewer_id + db_item.decision_notes = decision.notes + db_item.escalated_to = decision.escalate_to + db_item.decided_at = datetime.now(timezone.utc) + + # Optionally update payload if modified by reviewer + if decision.payload: + db_item.payload = json.dumps(decision.payload) + + await db.commit() + await db.refresh(db_item) + + return _to_pydantic(db_item) + + +# ── Helpers ─────────────────────────────────────────────────────────────────── + + +async def _get_or_404(request_id: str, db: AsyncSession) -> ApprovalRequestDB: + result = await db.execute( + select(ApprovalRequestDB).where(ApprovalRequestDB.id == request_id) + ) + db_item = result.scalar_one_or_none() + if db_item is None: + raise HTTPException(status_code=404, detail="Approval request not found.") + return db_item + + +def _to_pydantic(db_item: ApprovalRequestDB) -> ApprovalRequest: + return ApprovalRequest( + id=db_item.id, + session_id=db_item.session_id, + invocation_id=db_item.invocation_id, + function_call_id=db_item.function_call_id, + app_name=db_item.app_name, + user_id=db_item.user_id, + agent_name=db_item.agent_name, + tool_name=db_item.tool_name, + message=db_item.message, + payload=json.loads(db_item.payload) if db_item.payload else {}, + response_schema=json.loads(db_item.response_schema) + if db_item.response_schema + else {}, + risk_level=db_item.risk_level, + status=db_item.status, + created_at=db_item.created_at, + decided_at=db_item.decided_at, + decided_by=db_item.decided_by, + decision_notes=db_item.decision_notes, + escalated_to=db_item.escalated_to, + ) diff --git a/src/google/adk_community/services/hitl_approval/store.py b/src/google/adk_community/services/hitl_approval/store.py new file mode 100644 index 00000000..34f6e965 --- /dev/null +++ b/src/google/adk_community/services/hitl_approval/store.py @@ -0,0 +1,80 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Async SQLite database setup via SQLAlchemy.""" + +from __future__ import annotations + +import uuid +from datetime import datetime, timezone + +from sqlalchemy import Column, DateTime, String, Text +from sqlalchemy.ext.asyncio import ( + AsyncSession, + async_sessionmaker, + create_async_engine, +) +from sqlalchemy.orm import DeclarativeBase + + +class Base(DeclarativeBase): + pass + + +class ApprovalRequestDB(Base): + __tablename__ = "approval_requests" + + id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4())) + session_id = Column(String, nullable=False) + invocation_id = Column(String, nullable=True) + function_call_id = Column(String, nullable=True) + app_name = Column(String, nullable=False) + user_id = Column(String, nullable=False) + agent_name = Column(String, nullable=False) + tool_name = Column(String, nullable=False) + message = Column(Text, nullable=False) + payload = Column(Text, nullable=False) # JSON-serialised + response_schema = Column(Text, nullable=True) # JSON-serialised + risk_level = Column(String, nullable=False) + status = Column(String, nullable=False) + created_at = Column( + DateTime, default=lambda: datetime.now(timezone.utc), nullable=False + ) + decided_at = Column(DateTime, nullable=True) + decided_by = Column(String, nullable=True) + decision_notes = Column(Text, nullable=True) + escalated_to = Column(String, nullable=True) + + +import os + +db_path = os.getenv("HITL_DB_PATH", "./hitl.db") +DATABASE_URL = f"sqlite+aiosqlite:///{db_path}" + +engine = create_async_engine(DATABASE_URL, echo=False) +AsyncSessionLocal = async_sessionmaker( + engine, expire_on_commit=False, class_=AsyncSession +) + + +async def init_db() -> None: + """Create tables on startup.""" + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + +async def get_db(): + """FastAPI dependency that yields a database session.""" + async with AsyncSessionLocal() as session: + yield session diff --git a/src/google/adk_community/tools/hitl/__init__.py b/src/google/adk_community/tools/hitl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/google/adk_community/tools/hitl/adapters/__init__.py b/src/google/adk_community/tools/hitl/adapters/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/google/adk_community/tools/hitl/adapters/adk1.py b/src/google/adk_community/tools/hitl/adapters/adk1.py new file mode 100644 index 00000000..94e925fb --- /dev/null +++ b/src/google/adk_community/tools/hitl/adapters/adk1.py @@ -0,0 +1,86 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Adapter for Google ADK 1.x Human-in-the-Loop feature. + +Converts ADK `adk_request_confirmation` events to normalized ApprovalRequests, +and formats Streamlit dashboard decisions back into ADK FunctionResponses. +""" + +from __future__ import annotations + +from typing import Any, Dict + +import httpx + +from ..models import (ApprovalDecision, ApprovalRequest, ApprovalStatus) + + +def parse_confirmation_event(payload: Dict[str, Any]) -> ApprovalRequest: + """Parse an incoming ADK 1.x Tool Confirmation event to a normalized ApprovalRequest.""" + + call_id = payload.get("function_call_id") + args = payload.get("arguments", {}) + hint = args.get("hint", "Please review this action.") + tool_payload = args.get("payload", {}) + + return ApprovalRequest( + session_id=payload.get("session_id", "unknown_session"), + invocation_id=payload.get("invocation_id"), + function_call_id=call_id, + app_name=payload.get("app_name", "unknown_app"), + user_id=payload.get("user_id", "unknown_user"), + agent_name=payload.get("agent_name", "unknown_agent"), + tool_name=args.get("tool_name", "unknown_tool"), + message=hint, + payload=tool_payload, + response_schema={}, # Native tool confirmation in ADK 1.x doesn't expose a schema + ) + + +async def submit_decision_to_adk( + adk_base_url: str, request: ApprovalRequest, decision: ApprovalDecision +): + """Resume the ADK 1.x agent by sending the human's decision back as a FunctionResponse.""" + + confirmed = decision.decision == ApprovalStatus.APPROVED + + adk_payload = { + "app_name": request.app_name, + "user_id": request.user_id, + "session_id": request.session_id, + "invocation_id": request.invocation_id, + "new_message": { + "role": "user", + "parts": [ + { + "function_response": { + "id": request.function_call_id, + "name": "adk_request_confirmation", + "response": { + "confirmed": confirmed, + "payload": decision.payload or {}, + }, + } + } + ], + }, + } + + async with httpx.AsyncClient() as client: + # Assumes the ADK FastAPI server is running with the /run_sse endpoint + url = f"{adk_base_url.rstrip('/')}/run_sse" + resp = await client.post(url, json=adk_payload) + resp.raise_for_status() + return resp.json() diff --git a/src/google/adk_community/tools/hitl/gateway.py b/src/google/adk_community/tools/hitl/gateway.py new file mode 100644 index 00000000..e279e310 --- /dev/null +++ b/src/google/adk_community/tools/hitl/gateway.py @@ -0,0 +1,150 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""HITL tool wrapper — submits an approval request to the FastAPI API and +waits asynchronously for a supervisor to approve/reject via the Streamlit +dashboard before executing the wrapped tool. + +Usage: + from google.adk_community.tools.hitl.gateway import hitl_tool + from google.adk.tools import FunctionTool + + @hitl_tool(agent_name="credit_agent") + async def apply_credit(account_id: str, amount: float) -> str: + ... # only runs after a supervisor approves in the dashboard + + tool = FunctionTool(apply_credit) +""" + +from __future__ import annotations + +import asyncio +import functools +import inspect +import json +import uuid +from typing import Any, Callable, Optional + +import httpx + +API_BASE_URL = "http://localhost:8000" +POLL_INTERVAL_S = 2.0 +POLL_TIMEOUT_S = 300.0 # 5 minutes + + +def hitl_tool( + agent_name: str, + api_base: str = API_BASE_URL, + poll_interval: float = POLL_INTERVAL_S, + timeout: float = POLL_TIMEOUT_S, +): + """Decorator — wraps any async or sync function with a supervisor approval gate. + + Flow: + 1. Agent calls the wrapped function. + 2. Wrapper POSTs an approval request to the HITL API (status: pending). + 3. Wrapper polls GET /approvals/{id} with asyncio.sleep — non-blocking. + 4. Supervisor opens the Streamlit dashboard and clicks Approve/Reject. + 5. On approval the original function runs; on rejection a PermissionError + is raised so the agent can relay the outcome to the user. + """ + + def decorator(fn: Callable) -> Callable: + @functools.wraps(fn) + async def wrapper(*args, **kwargs) -> Any: + session_id = kwargs.pop("_session_id", str(uuid.uuid4())) + invocation_id = kwargs.pop("_invocation_id", None) + + payload = { + "session_id": session_id, + "invocation_id": invocation_id, + "app_name": "adk_chatbot", + "user_id": "current_user", + "agent_name": agent_name, + "tool_name": fn.__name__, + "message": f"Approval requested for {fn.__name__}", + "payload": _serialise_args(fn, args, kwargs), + } + + async with httpx.AsyncClient(base_url=api_base) as client: + resp = await client.post("/approvals/", json=payload) + resp.raise_for_status() + request_id = resp.json()["id"] + + decision_data = await _poll_for_decision( + api_base, request_id, poll_interval, timeout + ) + + if not decision_data: + raise TimeoutError( + f"No decision received for '{fn.__name__}' within {timeout}s." + ) + + status = decision_data["status"] + notes = decision_data.get("decision_notes", "No notes provided.") + + if status == "approved": + if inspect.iscoroutinefunction(fn): + result = await fn(*args, **kwargs) + else: + result = fn(*args, **kwargs) + + # We inject the supervisor's decision into the return payload + # so the LLM explicitly sees and references the supervisor's approval! + return { + "supervisor_decision": "APPROVED", + "supervisor_notes": notes, + "action_result": result + } + elif status == "rejected": + raise PermissionError( + f"Tool '{fn.__name__}' was rejected by a supervisor. Notes: {notes}" + ) + elif status == "escalated": + raise PermissionError( + f"Tool '{fn.__name__}' was escalated — awaiting further review. Notes: {notes}" + ) + + return wrapper + + return decorator + + +# ── Helpers ─────────────────────────────────────────────────────────────────── + + +async def _poll_for_decision( + api_base: str, + request_id: str, + interval: float, + timeout: float, +) -> Optional[dict]: + deadline = asyncio.get_event_loop().time() + timeout + async with httpx.AsyncClient(base_url=api_base) as client: + while asyncio.get_event_loop().time() < deadline: + resp = await client.get(f"/approvals/{request_id}") + resp.raise_for_status() + data = resp.json() + if data["status"] != "pending": + return data + await asyncio.sleep(interval) + return None + + +def _serialise_args(fn: Callable, args: tuple, kwargs: dict) -> dict: + sig = inspect.signature(fn) + params = list(sig.parameters.keys()) + named = {params[i]: args[i] for i in range(len(args)) if i < len(params)} + named.update(kwargs) + return json.loads(json.dumps(named, default=str)) diff --git a/src/google/adk_community/tools/hitl/models.py b/src/google/adk_community/tools/hitl/models.py new file mode 100644 index 00000000..0d5ecf9e --- /dev/null +++ b/src/google/adk_community/tools/hitl/models.py @@ -0,0 +1,75 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import uuid +from datetime import datetime +from typing import Any, Optional + +from pydantic import BaseModel, Field + + +class ApprovalStatus: + PENDING = "pending" + APPROVED = "approved" + REJECTED = "rejected" + ESCALATED = "escalated" + EXPIRED = "expired" + + +class RiskLevel: + LOW = "low" + MEDIUM = "medium" + HIGH = "high" + + +class ApprovalRequest(BaseModel): + # Identity + id: str = Field(default_factory=lambda: str(uuid.uuid4())) + + # ADK context — needed to resume the agent correctly + session_id: str + invocation_id: Optional[str] = None # Required for ADK Resume feature + function_call_id: Optional[str] = None # Must match in FunctionResponse + app_name: str + user_id: str + + # Agent context — what the human needs to decide + agent_name: str + tool_name: str + message: str # Maps from ADK 1.x 'hint' OR ADK 2.0 'message' + payload: dict # The structured data awaiting approval + response_schema: dict = Field( + default_factory=dict + ) # Empty in 1.x, populated in ADK 2.0 + risk_level: str = RiskLevel.MEDIUM + + # Status tracking + status: str = ApprovalStatus.PENDING + created_at: datetime = Field(default_factory=datetime.utcnow) + decided_at: Optional[datetime] = None + decided_by: Optional[str] = None + decision_notes: Optional[str] = None + + # Escalation + escalated_to: Optional[str] = None + + +class ApprovalDecision(BaseModel): + decision: str # approved / rejected / escalated + reviewer_id: str + notes: Optional[str] = None + payload: dict = Field(default_factory=dict) # Response data back to the agent + escalate_to: Optional[str] = None diff --git a/tests/unittests/services/test_hitl_approval_api.py b/tests/unittests/services/test_hitl_approval_api.py new file mode 100644 index 00000000..e48552e8 --- /dev/null +++ b/tests/unittests/services/test_hitl_approval_api.py @@ -0,0 +1,166 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for the FastAPI approval endpoints.""" + +from __future__ import annotations + +import pytest +import pytest_asyncio +from httpx import ASGITransport, AsyncClient +from sqlalchemy.ext.asyncio import ( + AsyncSession, + async_sessionmaker, + create_async_engine, +) + +from google.adk_community.services.hitl_approval.api import app +from google.adk_community.services.hitl_approval.store import Base, get_db + +# Use an in-memory SQLite database for tests +TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:" + + +@pytest_asyncio.fixture +async def db_session(): + engine = create_async_engine(TEST_DATABASE_URL) + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + session_factory = async_sessionmaker( + engine, expire_on_commit=False, class_=AsyncSession + ) + async with session_factory() as session: + yield session + await engine.dispose() + + +@pytest_asyncio.fixture +async def client(db_session): + async def override_get_db(): + yield db_session + + app.dependency_overrides[get_db] = override_get_db + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + yield ac + app.dependency_overrides.clear() + + +# ── Tests ───────────────────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_health(client): + resp = await client.get("/health") + assert resp.status_code == 200 + assert resp.json() == {"status": "ok"} + + +@pytest.mark.asyncio +async def test_create_approval(client): + payload = { + "session_id": "sess-1", + "app_name": "test_app", + "user_id": "u-123", + "agent_name": "email_agent", + "tool_name": "send_email", + "message": "Please approve sending email.", + "payload": {"to": "alice@example.com"}, + } + resp = await client.post("/approvals/", json=payload) + assert resp.status_code == 201 + data = resp.json() + assert data["status"] == "pending" + assert data["tool_name"] == "send_email" + return data["id"] + + +@pytest.mark.asyncio +async def test_resolve_approval(client): + # Create first + create_resp = await client.post( + "/approvals/", + json={ + "session_id": "sess-2", + "app_name": "test_app", + "user_id": "u-123", + "agent_name": "file_agent", + "tool_name": "delete_file", + "message": "Approve delete?", + "payload": {"path": "/tmp/test.txt"}, + }, + ) + assert create_resp.status_code == 201 + request_id = create_resp.json()["id"] + + # Resolve + resolve_resp = await client.post( + f"/approvals/{request_id}/decide", + json={"decision": "approved", "reviewer_id": "rev-99", "notes": "Looks safe."}, + ) + assert resolve_resp.status_code == 200 + data = resolve_resp.json() + assert data["status"] == "approved" + assert data["decision_notes"] == "Looks safe." + assert data["decided_at"] is not None + + +@pytest.mark.asyncio +async def test_double_resolve_returns_409(client): + create_resp = await client.post( + "/approvals/", + json={ + "session_id": "sess-3", + "app_name": "test_app", + "user_id": "u-123", + "agent_name": "researcher", + "tool_name": "web_search", + "message": "Search the web?", + "payload": {"query": "latest news"}, + }, + ) + request_id = create_resp.json()["id"] + + await client.post( + f"/approvals/{request_id}/decide", + json={"decision": "rejected", "reviewer_id": "rev-1"}, + ) + resp2 = await client.post( + f"/approvals/{request_id}/decide", + json={"decision": "approved", "reviewer_id": "rev-1"}, + ) + assert resp2.status_code == 409 + + +@pytest.mark.asyncio +async def test_list_pending(client): + # Create two requests + for tool in ["tool_a", "tool_b"]: + await client.post( + "/approvals/", + json={ + "session_id": "s", + "app_name": "app", + "user_id": "u", + "agent_name": "ag", + "tool_name": tool, + "message": "msg", + "payload": {}, + }, + ) + + resp = await client.get("/approvals/pending") + assert resp.status_code == 200 + assert len(resp.json()) == 2 + assert all(r["status"] == "pending" for r in resp.json()) diff --git a/tests/unittests/tools/test_hitl_gateway.py b/tests/unittests/tools/test_hitl_gateway.py new file mode 100644 index 00000000..69ea7f04 --- /dev/null +++ b/tests/unittests/tools/test_hitl_gateway.py @@ -0,0 +1,113 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the HITL tool wrapper (mocking the API calls).""" + +from __future__ import annotations + +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from google.adk_community.tools.hitl.gateway import _serialise_args, hitl_tool + +# ── _serialise_args ─────────────────────────────────────────────────────────── + + +def test_serialise_args_positional(): + def fn(a, b, c): + ... + + result = _serialise_args(fn, (1, 2), {"c": 3}) + assert result == {"a": 1, "b": 2, "c": 3} + + +def test_serialise_args_kwargs_only(): + def fn(x, y): + ... + + result = _serialise_args(fn, (), {"x": "hello", "y": 42}) + assert result == {"x": "hello", "y": 42} + + +def test_serialise_args_non_serialisable_falls_back_to_str(): + class Foo: + pass + + def fn(obj): + ... + + result = _serialise_args(fn, (Foo(),), {}) + assert isinstance(result["obj"], str) + + +# ── hitl_tool — approved ────────────────────────────────────────────────────── + + +def _make_mock_client(status: str, request_id: str = "abc-123"): + mock_client = AsyncMock() + + # Setup context manager correctly + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = False + + post_resp = MagicMock() + post_resp.json.return_value = {"id": request_id} + mock_client.post.return_value = post_resp + + get_resp = MagicMock() + get_resp.json.return_value = {"id": request_id, "status": status} + mock_client.get.return_value = get_resp + + return mock_client + + +@pytest.mark.asyncio +@patch("google.adk_community.tools.hitl.gateway.httpx.AsyncClient") +async def test_approved_tool_runs(mock_client_cls): + mock_client_cls.return_value = _make_mock_client("approved") + + @hitl_tool(agent_name="test_agent") + def add(a: int, b: int) -> int: + return a + b + + result = await add(2, 3) + assert result == 5 + + +@pytest.mark.asyncio +@patch("google.adk_community.tools.hitl.gateway.httpx.AsyncClient") +async def test_rejected_tool_raises(mock_client_cls): + mock_client_cls.return_value = _make_mock_client("rejected") + + @hitl_tool(agent_name="test_agent") + def delete_file(path: str) -> str: + return "deleted" + + with pytest.raises(PermissionError, match="rejected"): + await delete_file("/important/file.txt") + + +@pytest.mark.asyncio +@patch("google.adk_community.tools.hitl.gateway.httpx.AsyncClient") +async def test_escalated_tool_raises(mock_client_cls): + mock_client_cls.return_value = _make_mock_client("escalated") + + @hitl_tool(agent_name="test_agent") + def wire_transfer(amount: float) -> str: + return "done" + + with pytest.raises(PermissionError, match="escalated"): + await wire_transfer(10000.0)