From 70ba8523cdf99ae937a2883085990508b8cef52c Mon Sep 17 00:00:00 2001 From: Sthitaprajna Sahoo Date: Tue, 24 Mar 2026 02:01:04 +0000 Subject: [PATCH 1/2] feat(a2a): expose a2a_task_store and a2a_push_config_store in get_fast_api_app MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Problem `get_fast_api_app()` unconditionally creates an `InMemoryTaskStore` and `InMemoryPushNotificationConfigStore`, making it impossible for callers to inject persistent or shared stores without patching ADK internals. This is especially painful in production deployments where: - Multiple replicas need a shared task store to route A2A callbacks correctly - Operators want task state to survive server restarts (e.g. SQLite/Postgres) ## Solution Adds two new optional keyword arguments to `get_fast_api_app()`: - `a2a_task_store: Optional[Any] = None` - `a2a_push_config_store: Optional[Any] = None` When `None` (the default), the existing `InMemory*` defaults are used — fully backward-compatible. When provided, the caller-supplied instances are forwarded directly to `DefaultRequestHandler`. This mirrors the pattern introduced for the lower-level `to_a2a()` helper in PR #3839. ## Usage from a2a.server.tasks import DatabaseTaskStore from sqlalchemy.ext.asyncio import create_async_engine engine = create_async_engine("postgresql+asyncpg://user:pw@host/db") app = get_fast_api_app( agents_dir="agents/", web=True, a2a=True, a2a_task_store=DatabaseTaskStore(engine), ) ## Tests Two new test cases added to tests/unittests/cli/test_fast_api.py: - test_a2a_uses_in_memory_task_store_by_default - test_a2a_custom_task_store_bypasses_in_memory_default --- src/google/adk/cli/fast_api.py | 28 +++++- tests/unittests/cli/test_fast_api.py | 132 +++++++++++++++++++++++++++ 2 files changed, 158 insertions(+), 2 deletions(-) diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 0b6f3fb6fe..2b0bda6498 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -83,6 +83,8 @@ def get_fast_api_app( allow_origins: Optional[list[str]] = None, web: bool, a2a: bool = False, + a2a_task_store: Optional[Any] = None, + a2a_push_config_store: Optional[Any] = None, host: str = "127.0.0.1", port: int = 8000, url_prefix: Optional[str] = None, @@ -125,6 +127,23 @@ def get_fast_api_app( allow_origins: List of allowed origins for CORS. web: Whether to enable the web UI and serve its assets. a2a: Whether to enable Agent-to-Agent (A2A) protocol support. + a2a_task_store: Optional A2A TaskStore instance. Defaults to + InMemoryTaskStore when a2a=True. Pass a DatabaseTaskStore (from the + a2a-sdk) for persistence across server restarts and horizontal replicas. + Example:: + + from a2a.server.tasks import DatabaseTaskStore + from sqlalchemy.ext.asyncio import create_async_engine + + engine = create_async_engine("postgresql+asyncpg://user:pw@host/db") + app = get_fast_api_app( + agents_dir="agents/", web=True, a2a=True, + a2a_task_store=DatabaseTaskStore(engine), + ) + + a2a_push_config_store: Optional A2A PushNotificationConfigStore instance. + Defaults to InMemoryPushNotificationConfigStore when a2a=True. Pass a + DatabasePushNotificationConfigStore for persistence across restarts. host: Host address for the server (defaults to 127.0.0.1). port: Port number for the server (defaults to 8000). url_prefix: Optional prefix for all URL routes. @@ -561,7 +580,8 @@ async def get_agent_builder( base_path = Path.cwd() / agents_dir # the root agents directory should be an existing folder if base_path.exists() and base_path.is_dir(): - a2a_task_store = InMemoryTaskStore() + if a2a_task_store is None: + a2a_task_store = InMemoryTaskStore() def create_a2a_runner_loader(captured_app_name: str): """Factory function to create A2A runner with proper closure.""" @@ -589,7 +609,11 @@ async def _get_a2a_runner_async() -> Runner: runner=create_a2a_runner_loader(app_name), ) - push_config_store = InMemoryPushNotificationConfigStore() + push_config_store = ( + a2a_push_config_store + if a2a_push_config_store is not None + else InMemoryPushNotificationConfigStore() + ) request_handler = DefaultRequestHandler( agent_executor=agent_executor, diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index 15bc908ddb..330871bebc 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -1547,6 +1547,138 @@ def test_a2a_disabled_by_default(test_app): logger.info("A2A disabled by default test passed") +def test_a2a_uses_in_memory_task_store_by_default( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, + temp_agents_dir_with_a2a, + monkeypatch, +): + """Test that InMemoryTaskStore is created when no task_store is provided.""" + with ( + patch("signal.signal", return_value=None), + patch( + "google.adk.cli.fast_api.create_session_service_from_options", + return_value=mock_session_service, + ), + patch( + "google.adk.cli.fast_api.create_artifact_service_from_options", + return_value=mock_artifact_service, + ), + patch( + "google.adk.cli.fast_api.create_memory_service_from_options", + return_value=mock_memory_service, + ), + patch( + "google.adk.cli.fast_api.AgentLoader", + return_value=mock_agent_loader, + ), + patch( + "google.adk.cli.fast_api.LocalEvalSetsManager", + return_value=mock_eval_sets_manager, + ), + patch( + "google.adk.cli.fast_api.LocalEvalSetResultsManager", + return_value=mock_eval_set_results_manager, + ), + patch("a2a.server.tasks.InMemoryTaskStore") as mock_task_store_class, + patch("a2a.server.tasks.InMemoryPushNotificationConfigStore"), + patch("google.adk.a2a.executor.a2a_agent_executor.A2aAgentExecutor"), + patch("a2a.server.request_handlers.DefaultRequestHandler"), + patch("a2a.server.apps.A2AStarletteApplication") as mock_a2a_app, + ): + mock_a2a_app.return_value.routes.return_value = [] + monkeypatch.chdir(temp_agents_dir_with_a2a) + + _ = get_fast_api_app( + agents_dir=".", + web=True, + session_service_uri="", + artifact_service_uri="", + memory_service_uri="", + allow_origins=["*"], + a2a=True, + host="127.0.0.1", + port=8000, + ) + + mock_task_store_class.assert_called_once() + + +def test_a2a_custom_task_store_bypasses_in_memory_default( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, + temp_agents_dir_with_a2a, + monkeypatch, +): + """Test that a custom task_store is forwarded and InMemoryTaskStore is not created.""" + custom_task_store = MagicMock() + + with ( + patch("signal.signal", return_value=None), + patch( + "google.adk.cli.fast_api.create_session_service_from_options", + return_value=mock_session_service, + ), + patch( + "google.adk.cli.fast_api.create_artifact_service_from_options", + return_value=mock_artifact_service, + ), + patch( + "google.adk.cli.fast_api.create_memory_service_from_options", + return_value=mock_memory_service, + ), + patch( + "google.adk.cli.fast_api.AgentLoader", + return_value=mock_agent_loader, + ), + patch( + "google.adk.cli.fast_api.LocalEvalSetsManager", + return_value=mock_eval_sets_manager, + ), + patch( + "google.adk.cli.fast_api.LocalEvalSetResultsManager", + return_value=mock_eval_set_results_manager, + ), + patch("a2a.server.tasks.InMemoryTaskStore") as mock_task_store_class, + patch("a2a.server.tasks.InMemoryPushNotificationConfigStore"), + patch("google.adk.a2a.executor.a2a_agent_executor.A2aAgentExecutor"), + patch( + "a2a.server.request_handlers.DefaultRequestHandler" + ) as mock_handler, + patch("a2a.server.apps.A2AStarletteApplication") as mock_a2a_app, + ): + mock_a2a_app.return_value.routes.return_value = [] + monkeypatch.chdir(temp_agents_dir_with_a2a) + + _ = get_fast_api_app( + agents_dir=".", + web=True, + session_service_uri="", + artifact_service_uri="", + memory_service_uri="", + allow_origins=["*"], + a2a=True, + a2a_task_store=custom_task_store, + host="127.0.0.1", + port=8000, + ) + + # InMemoryTaskStore must NOT be instantiated when a custom store is supplied + mock_task_store_class.assert_not_called() + + # The custom store must be passed through to DefaultRequestHandler + call_kwargs = mock_handler.call_args.kwargs + assert call_kwargs["task_store"] is custom_task_store + + def test_patch_memory(test_app, create_test_session, mock_memory_service): """Test adding a session to memory.""" info = create_test_session From 00919924ea4bd550e87f2d3463495459d01e02e5 Mon Sep 17 00:00:00 2001 From: Sthitaprajna Sahoo Date: Tue, 24 Mar 2026 02:06:30 +0000 Subject: [PATCH 2/2] fixup: code style and convention alignment - hoist a2a_push_config_store init outside per-agent loop, matching the task_store pattern (both now use if-None guard at the top) - remove redundant local push_config_store variable; use a2a_push_config_store directly in DefaultRequestHandler call - trim docstring: drop inline Example block, shorten to match the one-liner style used by the other params in the same function - rename test to test_a2a_custom_task_store_is_used (shorter, matches existing naming style) - fix test mock setup to use an explicit mock_a2a_app_instance like the test_app_with_a2a fixture does - simplify test assertion: drop intermediate call_kwargs variable --- src/google/adk/cli/fast_api.py | 31 +++++++--------------------- tests/unittests/cli/test_fast_api.py | 20 +++++++++--------- 2 files changed, 18 insertions(+), 33 deletions(-) diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 2b0bda6498..67342f780a 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -128,22 +128,11 @@ def get_fast_api_app( web: Whether to enable the web UI and serve its assets. a2a: Whether to enable Agent-to-Agent (A2A) protocol support. a2a_task_store: Optional A2A TaskStore instance. Defaults to - InMemoryTaskStore when a2a=True. Pass a DatabaseTaskStore (from the - a2a-sdk) for persistence across server restarts and horizontal replicas. - Example:: - - from a2a.server.tasks import DatabaseTaskStore - from sqlalchemy.ext.asyncio import create_async_engine - - engine = create_async_engine("postgresql+asyncpg://user:pw@host/db") - app = get_fast_api_app( - agents_dir="agents/", web=True, a2a=True, - a2a_task_store=DatabaseTaskStore(engine), - ) - - a2a_push_config_store: Optional A2A PushNotificationConfigStore instance. - Defaults to InMemoryPushNotificationConfigStore when a2a=True. Pass a - DatabasePushNotificationConfigStore for persistence across restarts. + InMemoryTaskStore. Pass a custom store (e.g. DatabaseTaskStore) to + persist task state across restarts or share it across replicas. + a2a_push_config_store: Optional A2A PushNotificationConfigStore. + Defaults to InMemoryPushNotificationConfigStore. Pass a custom store + for persistence across restarts. host: Host address for the server (defaults to 127.0.0.1). port: Port number for the server (defaults to 8000). url_prefix: Optional prefix for all URL routes. @@ -582,6 +571,8 @@ async def get_agent_builder( if base_path.exists() and base_path.is_dir(): if a2a_task_store is None: a2a_task_store = InMemoryTaskStore() + if a2a_push_config_store is None: + a2a_push_config_store = InMemoryPushNotificationConfigStore() def create_a2a_runner_loader(captured_app_name: str): """Factory function to create A2A runner with proper closure.""" @@ -609,16 +600,10 @@ async def _get_a2a_runner_async() -> Runner: runner=create_a2a_runner_loader(app_name), ) - push_config_store = ( - a2a_push_config_store - if a2a_push_config_store is not None - else InMemoryPushNotificationConfigStore() - ) - request_handler = DefaultRequestHandler( agent_executor=agent_executor, task_store=a2a_task_store, - push_config_store=push_config_store, + push_config_store=a2a_push_config_store, ) with (p / "agent.json").open("r", encoding="utf-8") as f: diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index 330871bebc..26e6649bf9 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -1557,7 +1557,7 @@ def test_a2a_uses_in_memory_task_store_by_default( temp_agents_dir_with_a2a, monkeypatch, ): - """Test that InMemoryTaskStore is created when no task_store is provided.""" + """Test that InMemoryTaskStore is used when no task store is provided.""" with ( patch("signal.signal", return_value=None), patch( @@ -1590,7 +1590,9 @@ def test_a2a_uses_in_memory_task_store_by_default( patch("a2a.server.request_handlers.DefaultRequestHandler"), patch("a2a.server.apps.A2AStarletteApplication") as mock_a2a_app, ): - mock_a2a_app.return_value.routes.return_value = [] + mock_a2a_app_instance = MagicMock() + mock_a2a_app_instance.routes.return_value = [] + mock_a2a_app.return_value = mock_a2a_app_instance monkeypatch.chdir(temp_agents_dir_with_a2a) _ = get_fast_api_app( @@ -1608,7 +1610,7 @@ def test_a2a_uses_in_memory_task_store_by_default( mock_task_store_class.assert_called_once() -def test_a2a_custom_task_store_bypasses_in_memory_default( +def test_a2a_custom_task_store_is_used( mock_session_service, mock_artifact_service, mock_memory_service, @@ -1618,7 +1620,7 @@ def test_a2a_custom_task_store_bypasses_in_memory_default( temp_agents_dir_with_a2a, monkeypatch, ): - """Test that a custom task_store is forwarded and InMemoryTaskStore is not created.""" + """Test that a custom task store is forwarded to DefaultRequestHandler.""" custom_task_store = MagicMock() with ( @@ -1655,7 +1657,9 @@ def test_a2a_custom_task_store_bypasses_in_memory_default( ) as mock_handler, patch("a2a.server.apps.A2AStarletteApplication") as mock_a2a_app, ): - mock_a2a_app.return_value.routes.return_value = [] + mock_a2a_app_instance = MagicMock() + mock_a2a_app_instance.routes.return_value = [] + mock_a2a_app.return_value = mock_a2a_app_instance monkeypatch.chdir(temp_agents_dir_with_a2a) _ = get_fast_api_app( @@ -1671,12 +1675,8 @@ def test_a2a_custom_task_store_bypasses_in_memory_default( port=8000, ) - # InMemoryTaskStore must NOT be instantiated when a custom store is supplied mock_task_store_class.assert_not_called() - - # The custom store must be passed through to DefaultRequestHandler - call_kwargs = mock_handler.call_args.kwargs - assert call_kwargs["task_store"] is custom_task_store + assert mock_handler.call_args.kwargs["task_store"] is custom_task_store def test_patch_memory(test_app, create_test_session, mock_memory_service):