diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 0b6f3fb6fe..67342f780a 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -83,6 +83,8 @@ def get_fast_api_app( allow_origins: Optional[list[str]] = None, web: bool, a2a: bool = False, + a2a_task_store: Optional[Any] = None, + a2a_push_config_store: Optional[Any] = None, host: str = "127.0.0.1", port: int = 8000, url_prefix: Optional[str] = None, @@ -125,6 +127,12 @@ def get_fast_api_app( allow_origins: List of allowed origins for CORS. web: Whether to enable the web UI and serve its assets. a2a: Whether to enable Agent-to-Agent (A2A) protocol support. + a2a_task_store: Optional A2A TaskStore instance. Defaults to + InMemoryTaskStore. Pass a custom store (e.g. DatabaseTaskStore) to + persist task state across restarts or share it across replicas. + a2a_push_config_store: Optional A2A PushNotificationConfigStore. + Defaults to InMemoryPushNotificationConfigStore. Pass a custom store + for persistence across restarts. host: Host address for the server (defaults to 127.0.0.1). port: Port number for the server (defaults to 8000). url_prefix: Optional prefix for all URL routes. @@ -561,7 +569,10 @@ async def get_agent_builder( base_path = Path.cwd() / agents_dir # the root agents directory should be an existing folder if base_path.exists() and base_path.is_dir(): - a2a_task_store = InMemoryTaskStore() + if a2a_task_store is None: + a2a_task_store = InMemoryTaskStore() + if a2a_push_config_store is None: + a2a_push_config_store = InMemoryPushNotificationConfigStore() def create_a2a_runner_loader(captured_app_name: str): """Factory function to create A2A runner with proper closure.""" @@ -589,12 +600,10 @@ async def _get_a2a_runner_async() -> Runner: runner=create_a2a_runner_loader(app_name), ) - push_config_store = InMemoryPushNotificationConfigStore() - request_handler = DefaultRequestHandler( agent_executor=agent_executor, task_store=a2a_task_store, - push_config_store=push_config_store, + push_config_store=a2a_push_config_store, ) with (p / "agent.json").open("r", encoding="utf-8") as f: diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index 15bc908ddb..26e6649bf9 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -1547,6 +1547,138 @@ def test_a2a_disabled_by_default(test_app): logger.info("A2A disabled by default test passed") +def test_a2a_uses_in_memory_task_store_by_default( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, + temp_agents_dir_with_a2a, + monkeypatch, +): + """Test that InMemoryTaskStore is used when no task store is provided.""" + with ( + patch("signal.signal", return_value=None), + patch( + "google.adk.cli.fast_api.create_session_service_from_options", + return_value=mock_session_service, + ), + patch( + "google.adk.cli.fast_api.create_artifact_service_from_options", + return_value=mock_artifact_service, + ), + patch( + "google.adk.cli.fast_api.create_memory_service_from_options", + return_value=mock_memory_service, + ), + patch( + "google.adk.cli.fast_api.AgentLoader", + return_value=mock_agent_loader, + ), + patch( + "google.adk.cli.fast_api.LocalEvalSetsManager", + return_value=mock_eval_sets_manager, + ), + patch( + "google.adk.cli.fast_api.LocalEvalSetResultsManager", + return_value=mock_eval_set_results_manager, + ), + patch("a2a.server.tasks.InMemoryTaskStore") as mock_task_store_class, + patch("a2a.server.tasks.InMemoryPushNotificationConfigStore"), + patch("google.adk.a2a.executor.a2a_agent_executor.A2aAgentExecutor"), + patch("a2a.server.request_handlers.DefaultRequestHandler"), + patch("a2a.server.apps.A2AStarletteApplication") as mock_a2a_app, + ): + mock_a2a_app_instance = MagicMock() + mock_a2a_app_instance.routes.return_value = [] + mock_a2a_app.return_value = mock_a2a_app_instance + monkeypatch.chdir(temp_agents_dir_with_a2a) + + _ = get_fast_api_app( + agents_dir=".", + web=True, + session_service_uri="", + artifact_service_uri="", + memory_service_uri="", + allow_origins=["*"], + a2a=True, + host="127.0.0.1", + port=8000, + ) + + mock_task_store_class.assert_called_once() + + +def test_a2a_custom_task_store_is_used( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, + temp_agents_dir_with_a2a, + monkeypatch, +): + """Test that a custom task store is forwarded to DefaultRequestHandler.""" + custom_task_store = MagicMock() + + with ( + patch("signal.signal", return_value=None), + patch( + "google.adk.cli.fast_api.create_session_service_from_options", + return_value=mock_session_service, + ), + patch( + "google.adk.cli.fast_api.create_artifact_service_from_options", + return_value=mock_artifact_service, + ), + patch( + "google.adk.cli.fast_api.create_memory_service_from_options", + return_value=mock_memory_service, + ), + patch( + "google.adk.cli.fast_api.AgentLoader", + return_value=mock_agent_loader, + ), + patch( + "google.adk.cli.fast_api.LocalEvalSetsManager", + return_value=mock_eval_sets_manager, + ), + patch( + "google.adk.cli.fast_api.LocalEvalSetResultsManager", + return_value=mock_eval_set_results_manager, + ), + patch("a2a.server.tasks.InMemoryTaskStore") as mock_task_store_class, + patch("a2a.server.tasks.InMemoryPushNotificationConfigStore"), + patch("google.adk.a2a.executor.a2a_agent_executor.A2aAgentExecutor"), + patch( + "a2a.server.request_handlers.DefaultRequestHandler" + ) as mock_handler, + patch("a2a.server.apps.A2AStarletteApplication") as mock_a2a_app, + ): + mock_a2a_app_instance = MagicMock() + mock_a2a_app_instance.routes.return_value = [] + mock_a2a_app.return_value = mock_a2a_app_instance + monkeypatch.chdir(temp_agents_dir_with_a2a) + + _ = get_fast_api_app( + agents_dir=".", + web=True, + session_service_uri="", + artifact_service_uri="", + memory_service_uri="", + allow_origins=["*"], + a2a=True, + a2a_task_store=custom_task_store, + host="127.0.0.1", + port=8000, + ) + + mock_task_store_class.assert_not_called() + assert mock_handler.call_args.kwargs["task_store"] is custom_task_store + + def test_patch_memory(test_app, create_test_session, mock_memory_service): """Test adding a session to memory.""" info = create_test_session