diff --git a/src/google/adk/cli/api_server.py b/src/google/adk/cli/api_server.py index f9b34c164c..d242960b77 100644 --- a/src/google/adk/cli/api_server.py +++ b/src/google/adk/cli/api_server.py @@ -643,6 +643,9 @@ def __init__( self.trigger_sources = trigger_sources self.default_llm_model = default_llm_model self.default_app_name = os.getenv("ADK_DEFAULT_APP_NAME") + # Registry of active agent-run tasks keyed by session_id, + # enabling cancellation via the /cancel API endpoint. + self.active_tasks: dict[str, asyncio.Task[Any]] = {} async def get_runner_async(self, app_name: str) -> Runner: """Returns the cached runner for the given app.""" @@ -1218,6 +1221,7 @@ async def update_session( return session + @app.get( "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}", response_model_exclude_none=True, @@ -1472,6 +1476,7 @@ async def worker(): raise HTTPException(status_code=404, detail=str(e)) from e worker_task = asyncio.create_task(worker()) + self.active_tasks[req.session_id] = worker_task async def monitor(): try: @@ -1502,6 +1507,7 @@ async def monitor(): raise finally: monitor_task.cancel() + self.active_tasks.pop(req.session_id, None) @app.post("/run_sse") async def run_agent_sse(req: RunAgentRequest) -> StreamingResponse: @@ -1518,11 +1524,6 @@ async def run_agent_sse(req: RunAgentRequest) -> StreamingResponse: _set_telemetry_context_if_needed(runner) # Validate session existence before starting the stream. - # We check directly here instead of eagerly advancing the - # runner's async generator with anext(), because splitting - # generator consumption across two asyncio Tasks (request - # handler vs StreamingResponse) breaks OpenTelemetry context - # detachment. 
if not runner.auto_create_session: session = await self.session_service.get_session( app_name=req.app_name, @@ -1535,59 +1536,81 @@ async def run_agent_sse(req: RunAgentRequest) -> StreamingResponse: detail=f"Session not found: {req.session_id}", ) - # Convert the events to properly formatted SSE - async def event_generator(): - async with Aclosing( - runner.run_async( - user_id=req.user_id, - session_id=req.session_id, - new_message=req.new_message, - state_delta=req.state_delta, - run_config=RunConfig( - streaming_mode=stream_mode, - custom_metadata=req.custom_metadata, - ), - invocation_id=req.invocation_id, - ) - ) as agen: - try: + # Use a queue to bridge the producer task (runs the agent) and + # the StreamingResponse consumer (formats SSE). This lets the + # /cancel endpoint cancel the producer task via the active_tasks + # registry. + event_queue: asyncio.Queue[Event | Exception | None] = asyncio.Queue() + + async def produce_events() -> None: + try: + async with Aclosing( + runner.run_async( + user_id=req.user_id, + session_id=req.session_id, + new_message=req.new_message, + state_delta=req.state_delta, + run_config=RunConfig( + streaming_mode=stream_mode, + custom_metadata=req.custom_metadata, + ), + invocation_id=req.invocation_id, + ) + ) as agen: async for event in agen: - # ADK Web renders artifacts from `actions.artifactDelta` - # during part processing *and* during action processing - # 1) the original event with `artifactDelta` cleared (content) - # 2) a content-less "action-only" event carrying `artifactDelta` - events_to_stream = [event] - if ( - not req.function_call_event_id - and event.actions.artifact_delta - and event.content - and event.content.parts - ): - content_event = event.model_copy(deep=True) - content_event.actions.artifact_delta = {} - artifact_event = event.model_copy(deep=True) - artifact_event.content = None - events_to_stream = [content_event, artifact_event] - - for event_to_stream in events_to_stream: - sse_event = 
event_to_stream.model_dump_json( - exclude_none=True, - by_alias=True, - ) - logger.debug( - "Generated event in agent run streaming: %s", sse_event - ) - yield f"data: {sse_event}\n\n" - except Exception as e: - logger.exception("Error in event_generator: %s", e) - yield f"data: {json.dumps({'error': str(e)})}\n\n" + await event_queue.put(event) + except asyncio.CancelledError: + pass + except Exception as e: # pylint: disable=broad-exception-caught + await event_queue.put(e) + finally: + await event_queue.put(None) # sentinel + + producer_task = asyncio.create_task(produce_events()) + self.active_tasks[req.session_id] = producer_task + + async def event_generator(): + try: + while True: + item = await event_queue.get() + if item is None: + break + if isinstance(item, Exception): + logger.exception("Error in event_generator: %s", item) + yield f"data: {json.dumps({'error': str(item)})}\n\n" + break + + events_to_stream = [item] + if ( + not req.function_call_event_id + and item.actions.artifact_delta + and item.content + and item.content.parts + ): + content_event = item.model_copy(deep=True) + content_event.actions.artifact_delta = {} + artifact_event = item.model_copy(deep=True) + artifact_event.content = None + events_to_stream = [content_event, artifact_event] + + for event_to_stream in events_to_stream: + sse_event = event_to_stream.model_dump_json( + exclude_none=True, + by_alias=True, + ) + logger.debug( + "Generated event in agent run streaming: %s", sse_event + ) + yield f"data: {sse_event}\n\n" + finally: + if not producer_task.done(): + producer_task.cancel() + self.active_tasks.pop(req.session_id, None) - # Returns a streaming response with the proper media type for SSE return StreamingResponse( event_generator(), media_type="text/event-stream", ) - @app.websocket("/run_live") async def run_agent_live( websocket: WebSocket, @@ -1684,6 +1707,8 @@ async def process_messages(): asyncio.create_task(forward_events()), 
@app.post(
    "/apps/{app_name}/users/{user_id}/sessions/{session_id}:cancel",
)
async def cancel_session(
    app_name: str, user_id: str, session_id: str
) -> dict[str, Any]:
  """Cancel an in-progress agent run for the given session.

  Looks up the asyncio.Task registered for *session_id* in
  ``self.active_tasks`` and requests its cancellation.

  Args:
    app_name: App the session belongs to.
    user_id: User the session belongs to.
    session_id: Session whose active run should be cancelled.

  Returns:
    ``{"status": "cancelled", "session_id": session_id}`` once the
    cancellation has been requested.

  Raises:
    HTTPException: 404 if there is no active run for the session, if the
      session does not exist under the given app and user, or if the run
      finished before it could be cancelled.
  """
  # One shared 404 for every failure path, so the response does not
  # reveal whether a session id exists under a different app or user.
  not_found = HTTPException(
      status_code=404,
      detail=f"No active run found for session '{session_id}'",
  )

  task = self.active_tasks.get(session_id)
  if task is None or task.done():
    raise not_found

  # The registry is keyed by session_id alone, so app_name and user_id
  # from the URL would otherwise be ignored. Confirm the session really
  # exists under this app and user before touching its task.
  session = await self.session_service.get_session(
      app_name=app_name,
      user_id=user_id,
      session_id=session_id,
  )
  if not session:
    raise not_found

  # cancel() returns False when the task has already finished, which can
  # happen while the session lookup above was awaited.
  if not task.cancel():
    raise not_found

  logger.info(
      "Cancelled agent run for session %s (app=%s, user=%s)",
      session_id,
      app_name,
      user_id,
  )
  return {"status": "cancelled", "session_id": session_id}
new file mode 100644 index 0000000000..2bf844c8ca --- /dev/null +++ b/tests/unittests/cli/test_cancel_session.py @@ -0,0 +1,269 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +import logging +from typing import Optional +from unittest.mock import patch + +from fastapi.testclient import TestClient +from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.run_config import RunConfig +from google.adk.cli import fast_api as fast_api_module +from google.adk.events.event import Event +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.genai import types +import pytest + +logger = logging.getLogger("google_adk." + __name__) + + +# --------------------------------------------------------------------------- +# Test helpers +# --------------------------------------------------------------------------- + +# Shared mutable flag so the mocked runner can signal that cancellation +# was actually detected (CancelledError caught). 
_cancellation_signal: list[bool] = []


def _make_text_event(text: str) -> Event:
  """Build a model-authored Event whose content is a single text part."""
  return Event(
      author="test_agent",
      invocation_id="invocation_id",
      content=types.Content(
          role="model", parts=[types.Part(text=text)]
      ),
  )


async def _cancellable_run_async(
    self,
    user_id,
    session_id,
    new_message,
    state_delta=None,
    run_config: Optional[RunConfig] = None,
    invocation_id: Optional[str] = None,
):
  """Stand-in for ``Runner.run_async`` that blocks until it is cancelled.

  Yields one event, then sleeps for an hour. When the sleep is
  interrupted by ``CancelledError`` it appends ``True`` to
  ``_cancellation_signal``, yields one more event and re-raises, so a
  test can check that cancellation reached the runner coroutine.
  """
  _cancellation_signal.clear()
  yield _make_text_event("starting run...")
  try:
    await asyncio.sleep(3600)  # cancelled by the /cancel endpoint
  except asyncio.CancelledError:
    _cancellation_signal.append(True)
    yield _make_text_event("run was cancelled")
    raise


@pytest.fixture(autouse=True)
def _clear_cancellation_signal():
  """Reset the shared cancellation signal before each test."""
  _cancellation_signal.clear()


@pytest.fixture
def test_session_info():
  """App and user identifiers shared by the tests."""
  return {
      "app_name": "test_app",
      "user_id": "test_user",
  }


@pytest.fixture
def mock_agent_loader():
  """Minimal agent loader that returns a single LlmAgent."""

  class Loader:

    def load_agent(self, app_name):
      agent = LlmAgent(name=app_name, model="gemini-2.5-flash")
      return agent

    def list_apps(self):
      return ["test_app"]

    def list_app_info(self):
      return [{"name": "test_app", "description": "Test app"}]

  return Loader()


@pytest.fixture
def client(monkeypatch, mock_agent_loader):
  """Yield a TestClient for the FastAPI app with a cancellable runner."""
  monkeypatch.setattr(Runner, "run_async", _cancellable_run_async)
  session_service = InMemorySessionService()

  app = fast_api_module.get_fast_api_app(
      agent_loader=mock_agent_loader,
      session_service=session_service,
  )
  # The cancel test sends /cancel from the main thread while a run request
  # is still in flight on a background thread, and an asyncio task must be
  # cancelled from the event loop that owns it. Keeping the client's
  # context open for the whole test is meant to serve every request from
  # one shared context.
  # NOTE(review): assumes TestClient reuses a single event loop while it
  # is used as a context manager, and that running the app's lifespan
  # here is harmless -- confirm against the Starlette version in use.
  with TestClient(app) as test_client:
    yield test_client
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


class TestCancelSessionEndpoint:
  """Integration tests for POST /apps/.../sessions/...:cancel."""

  def test_cancel_active_run_interrupts_runner(
      self, client, test_session_info, monkeypatch
  ):
    """Start a blocking run, cancel it, and verify the runner was interrupted."""
    import threading

    app_name = test_session_info["app_name"]
    user_id = test_session_info["user_id"]

    # Wrap the blocking runner installed by the `client` fixture so the
    # test learns when the run is about to block. The flag is set when the
    # consumer asks for the event after the first one, which is the call
    # that takes the wrapped runner into its long sleep. This replaces a
    # fixed sleep, which is either too short on a slow machine or wasted
    # time on a fast one.
    blocking_run = Runner.run_async
    run_waiting = threading.Event()

    async def signalling_run(self, *args, **kwargs):
      async for event in blocking_run(self, *args, **kwargs):
        yield event
        run_waiting.set()

    monkeypatch.setattr(Runner, "run_async", signalling_run)

    # 1. Create a session.
    # NOTE(review): assumes the create-session response exposes the id
    # under "session_id" -- confirm against the endpoint's response model.
    create_resp = client.post(
        f"/apps/{app_name}/users/{user_id}/sessions",
        json={"app_name": app_name, "user_id": user_id},
    )
    assert create_resp.status_code == 200
    session_id = create_resp.json()["session_id"]

    # 2. Start a blocking run in a background thread.
    #    Use the TestClient (not raw requests) so the call reaches the
    #    in-memory FastAPI app. TestClient.post() is synchronous and only
    #    returns after the run is cancelled in step 4.
    # NOTE(review): assumes the server exposes this per-session /run
    # route -- confirm; if it does not, step 3 fails with a clear message.
    run_result = {"status": None, "body": None, "error": None}

    def do_run(test_client):
      try:
        resp = test_client.post(
            f"/apps/{app_name}/users/{user_id}"
            f"/sessions/{session_id}/run",
            json={
                "app_name": app_name,
                "user_id": user_id,
                "session_id": session_id,
                "new_message": {
                    "role": "user",
                    "parts": [{"text": "hello"}],
                },
            },
        )
        run_result["status"] = resp.status_code
        run_result["body"] = resp.json() if resp.text else None
      except Exception as e:
        run_result["error"] = str(e)

    run_thread = threading.Thread(
        target=do_run, args=(client,), daemon=True
    )
    run_thread.start()

    # 3. Wait until the runner has produced its first event and is about
    #    to block, instead of guessing with a fixed delay.
    assert run_waiting.wait(timeout=5.0), (
        "The run never reached its blocking point; result so far: "
        f"{run_result}"
    )

    # 4. Cancel the run via the new endpoint
    cancel_resp = client.post(
        f"/apps/{app_name}/users/{user_id}/sessions/{session_id}:cancel",
    )
    assert cancel_resp.status_code == 200
    data = cancel_resp.json()
    assert data["status"] == "cancelled"
    assert data["session_id"] == session_id

    # 5. Wait for the background run to finish (should happen quickly
    #    after cancellation)
    run_thread.join(timeout=5.0)
    assert not run_thread.is_alive(), (
        "Background run thread should have completed after cancellation"
    )

    # 6. Verify the runner actually detected cancellation.
    #    _cancellable_run_async appends to this list when CancelledError
    #    is caught inside the runner coroutine.
    assert len(_cancellation_signal) > 0, (
        "CancelledError was NOT raised inside the runner — "
        "the task.cancel() did not propagate to the agent coroutine"
    )
    logger.info("Run result after cancellation: %s", run_result)

  def test_cancel_nonexistent_session_returns_404(self, client):
    """Cancelling a session with no active run returns 404."""
    resp = client.post(
        "/apps/test_app/users/test_user/sessions/nonexistent:cancel",
    )
    assert resp.status_code == 404
    assert "no active run" in resp.json()["detail"].lower()

  def test_cancel_idempotent_returns_404_on_second_call(self, client):
    """Repeated cancels for a session with no active run all return 404.

    This does not start a run first, so it only checks that the endpoint
    answers the same way every time when there is nothing to cancel.
    """
    url = "/apps/test_app/users/test_user/sessions/nonexistent:cancel"
    assert client.post(url).status_code == 404
    assert client.post(url).status_code == 404


class TestTaskRegistry:
  """Tests for the active_tasks registry lifecycle."""

  def test_registry_cleanup_after_run_completion(
      self, client, test_session_info, monkeypatch
  ):
    """After a run completes normally, /cancel returns 404 (task cleaned up)."""

    async def fast_run(self, **kwargs):
      yield _make_text_event("done")

    monkeypatch.setattr(Runner, "run_async", fast_run)

    app_name = test_session_info["app_name"]
    user_id = test_session_info["user_id"]

    create_resp = client.post(
        f"/apps/{app_name}/users/{user_id}/sessions",
        json={"app_name": app_name, "user_id": user_id},
    )
    assert create_resp.status_code == 200
    session_id = create_resp.json()["session_id"]

    # Run to completion
    run_resp = client.post(
        f"/apps/{app_name}/users/{user_id}/sessions/{session_id}/run",
        json={
            "app_name": app_name,
            "user_id": user_id,
            "session_id": session_id,
            "new_message": {
                "role": "user",
                "parts": [{"text": "hello"}],
            },
        },
    )
    assert run_resp.status_code == 200

    # Task should already be popped from registry
    cancel_resp = client.post(
        f"/apps/{app_name}/users/{user_id}/sessions/{session_id}:cancel",
    )
    assert cancel_resp.status_code == 404