From 5c28f3aca2aa32703e72871a584ffe79b8de64e4 Mon Sep 17 00:00:00 2001 From: Tony Coconate Date: Fri, 5 Jun 2026 11:39:52 -0500 Subject: [PATCH] fix(auth): Prevent duplicate OAuth prompts and fix tool resumption mapping --- src/google/adk/auth/auth_preprocessor.py | 69 ++++++++++++++++--- src/google/adk/flows/llm_flows/functions.py | 10 +++ .../openapi_spec_parser/tool_auth_handler.py | 4 +- tests/unittests/auth/test_toolset_auth.py | 28 +++++++- 4 files changed, 97 insertions(+), 14 deletions(-) diff --git a/src/google/adk/auth/auth_preprocessor.py b/src/google/adk/auth/auth_preprocessor.py index b0fa1e0ba8..6e34ce9be9 100644 --- a/src/google/adk/auth/auth_preprocessor.py +++ b/src/google/adk/auth/auth_preprocessor.py @@ -34,7 +34,7 @@ # Prefix used by toolset auth credential IDs. # Auth requests with this prefix are for toolset authentication (before tool # listing) and don't require resuming a function call. -TOOLSET_AUTH_CREDENTIAL_ID_PREFIX = '_adk_toolset_auth_' +TOOLSET_AUTH_CREDENTIAL_ID_PREFIX = "_adk_toolset_auth_" async def _store_auth_and_collect_resume_targets( @@ -80,18 +80,57 @@ async def _store_auth_and_collect_resume_targets( except TypeError: continue - # Step 2: Store credentials. Merge credential_key from the original - # request into the client's auth response before storing. + authorized_keys: set[str] = set() for fc_id in auth_fc_ids: if fc_id not in auth_responses: continue auth_config = AuthConfig.model_validate(auth_responses[fc_id]) requested_auth_config = requested_auth_config_by_id.get(fc_id) - if ( - requested_auth_config - and requested_auth_config.credential_key is not None - ): - auth_config.credential_key = requested_auth_config.credential_key + if requested_auth_config: + credential_key = getattr(requested_auth_config, "credential_key", None) + if credential_key is not None: + auth_config.credential_key = credential_key + raw_auth_credential = getattr( + requested_auth_config, "raw_auth_credential", None + ) + if raw_auth_credential: + if auth_config.raw_auth_credential is None: + auth_config.raw_auth_credential = raw_auth_credential + elif auth_config.raw_auth_credential.oauth2 and getattr( + raw_auth_credential, "oauth2", None + ): + target = auth_config.raw_auth_credential.oauth2 + source = raw_auth_credential.oauth2 + for field in [ + "client_id", + "client_secret", + "redirect_uri", + "token_endpoint_auth_method", + ]: + if getattr(target, field) is None: + setattr(target, field, getattr(source, field)) + exchanged_auth_credential = getattr( + requested_auth_config, "exchanged_auth_credential", None + ) + if exchanged_auth_credential: + if auth_config.exchanged_auth_credential is None: + auth_config.exchanged_auth_credential = exchanged_auth_credential + elif auth_config.exchanged_auth_credential.oauth2 and getattr( + exchanged_auth_credential, "oauth2", None + ): + target = auth_config.exchanged_auth_credential.oauth2 + source = exchanged_auth_credential.oauth2 + for field in [ + "client_id", + "client_secret", + "redirect_uri", + "token_endpoint_auth_method", + ]: + if getattr(target, field) is None: + setattr(target, field, getattr(source, field)) + if auth_config.credential_key: + authorized_keys.add(auth_config.credential_key) + await AuthHandler(auth_config=auth_config).parse_and_store_auth_response( state=state ) @@ -121,6 +160,16 @@ async def _store_auth_and_collect_resume_targets( continue tools_to_resume.add(args.function_call_id) + for event in events: + actions = getattr(event, "actions", None) + if actions and actions.requested_auth_configs: + for ( + original_fc_id, + config, + ) in actions.requested_auth_configs.items(): + if config.credential_key in authorized_keys: + tools_to_resume.add(original_fc_id) + return tools_to_resume @@ -132,7 +181,7 @@ async def run_async( self, invocation_context: InvocationContext, llm_request: LlmRequest ) -> AsyncGenerator[Event, None]: agent = invocation_context.agent - if not hasattr(agent, 'canonical_tools'): + if not hasattr(agent, "canonical_tools"): return events = invocation_context.session.events if not events: @@ -147,7 +196,7 @@ async def run_async( last_event_with_content = event break - if not last_event_with_content or last_event_with_content.author != 'user': + if not last_event_with_content or last_event_with_content.author != "user": return responses = last_event_with_content.get_function_responses() diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 259d40b6b6..2515d179d8 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -295,7 +295,17 @@ def build_auth_request_event( parts = [] long_running_tool_ids = set() + deduplicated_requests: Dict[str, AuthConfig] = {} + seen_keys = set() for function_call_id, auth_config in auth_requests.items(): + key = auth_config.credential_key + if key is None: + deduplicated_requests[function_call_id] = auth_config + elif key not in seen_keys: + seen_keys.add(key) + deduplicated_requests[function_call_id] = auth_config + + for function_call_id, auth_config in deduplicated_requests.items(): request_euc_function_call = types.FunctionCall( name=REQUEST_EUC_FUNCTION_CALL_NAME, id=generate_client_function_call_id(), diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py index 0d78a5759b..88a3632242 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py @@ -80,7 +80,7 @@ def _get_legacy_credential_key( if auth_credential else "" ) - return f"{scheme_name}_{credential_name}_existing_exchanged_credential" + return f"user:{scheme_name}_{credential_name}_existing_exchanged_credential" def get_credential_key( self, @@ -114,7 +114,7 @@ def get_credential_key( # persisted. temp: namespace will be cleared after current run. but tool # want access token to be there stored across runs - return f"{scheme_name}_{credential_name}_existing_exchanged_credential" + return f"user:{scheme_name}_{credential_name}_existing_exchanged_credential" def get_credential( self, diff --git a/tests/unittests/auth/test_toolset_auth.py b/tests/unittests/auth/test_toolset_auth.py index 7c231aba43..72fcd619de 100644 --- a/tests/unittests/auth/test_toolset_auth.py +++ b/tests/unittests/auth/test_toolset_auth.py @@ -406,9 +406,12 @@ def test_multiple_auth_requests_create_multiple_parts( self, mock_invocation_context ): """Test that multiple auth requests create multiple function call parts.""" + config1 = create_oauth2_auth_config() + config2 = create_oauth2_auth_config() + config2.credential_key = "different_key" auth_requests = { - "call_1": create_oauth2_auth_config(), - "call_2": create_oauth2_auth_config(), + "call_1": config1, + "call_2": config2, } event = build_auth_request_event(mock_invocation_context, auth_requests) @@ -419,6 +422,27 @@ def test_multiple_auth_requests_create_multiple_parts( } assert function_call_ids == {"call_1", "call_2"} + def test_duplicate_auth_requests_are_deduplicated( + self, mock_invocation_context + ): + """Test that auth requests with the same credential key are deduplicated.""" + config1 = create_oauth2_auth_config() + config2 = create_oauth2_auth_config() + # Ensure they have the same credential key + assert config1.credential_key == config2.credential_key + + auth_requests = { + "call_1": config1, + "call_2": config2, + } + + event = build_auth_request_event(mock_invocation_context, auth_requests) + + assert len(event.content.parts) == 1 + fc = event.content.parts[0].function_call + assert fc.name == REQUEST_EUC_FUNCTION_CALL_NAME + assert fc.args["functionCallId"] == "call_1" + def test_always_adds_long_running_tool_ids(self, mock_invocation_context): """Test that long_running_tool_ids is always set.""" auth_requests = {"call_123": create_oauth2_auth_config()}