Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 59 additions & 10 deletions src/google/adk/auth/auth_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
# Prefix used by toolset auth credential IDs.
# Auth requests with this prefix are for toolset authentication (before tool
# listing) and don't require resuming a function call.
TOOLSET_AUTH_CREDENTIAL_ID_PREFIX = '_adk_toolset_auth_'
TOOLSET_AUTH_CREDENTIAL_ID_PREFIX = "_adk_toolset_auth_"


async def _store_auth_and_collect_resume_targets(
Expand Down Expand Up @@ -80,18 +80,57 @@ async def _store_auth_and_collect_resume_targets(
except TypeError:
continue

# Step 2: Store credentials. Merge credential_key from the original
# request into the client's auth response before storing.
authorized_keys: set[str] = set()
for fc_id in auth_fc_ids:
if fc_id not in auth_responses:
continue
auth_config = AuthConfig.model_validate(auth_responses[fc_id])
requested_auth_config = requested_auth_config_by_id.get(fc_id)
if (
requested_auth_config
and requested_auth_config.credential_key is not None
):
auth_config.credential_key = requested_auth_config.credential_key
if requested_auth_config:
credential_key = getattr(requested_auth_config, "credential_key", None)
if credential_key is not None:
auth_config.credential_key = credential_key
raw_auth_credential = getattr(
requested_auth_config, "raw_auth_credential", None
)
if raw_auth_credential:
if auth_config.raw_auth_credential is None:
auth_config.raw_auth_credential = raw_auth_credential
elif auth_config.raw_auth_credential.oauth2 and getattr(
raw_auth_credential, "oauth2", None
):
target = auth_config.raw_auth_credential.oauth2
source = raw_auth_credential.oauth2
for field in [
"client_id",
"client_secret",
"redirect_uri",
"token_endpoint_auth_method",
]:
if getattr(target, field) is None:
setattr(target, field, getattr(source, field))
exchanged_auth_credential = getattr(
requested_auth_config, "exchanged_auth_credential", None
)
if exchanged_auth_credential:
if auth_config.exchanged_auth_credential is None:
auth_config.exchanged_auth_credential = exchanged_auth_credential
elif auth_config.exchanged_auth_credential.oauth2 and getattr(
exchanged_auth_credential, "oauth2", None
):
target = auth_config.exchanged_auth_credential.oauth2
source = exchanged_auth_credential.oauth2
for field in [
"client_id",
"client_secret",
"redirect_uri",
"token_endpoint_auth_method",
]:
if getattr(target, field) is None:
setattr(target, field, getattr(source, field))
if auth_config.credential_key:
authorized_keys.add(auth_config.credential_key)

await AuthHandler(auth_config=auth_config).parse_and_store_auth_response(
state=state
)
Expand Down Expand Up @@ -121,6 +160,16 @@ async def _store_auth_and_collect_resume_targets(
continue
tools_to_resume.add(args.function_call_id)

for event in events:
actions = getattr(event, "actions", None)
if actions and actions.requested_auth_configs:
for (
original_fc_id,
config,
) in actions.requested_auth_configs.items():
if config.credential_key in authorized_keys:
tools_to_resume.add(original_fc_id)

return tools_to_resume


Expand All @@ -132,7 +181,7 @@ async def run_async(
self, invocation_context: InvocationContext, llm_request: LlmRequest
) -> AsyncGenerator[Event, None]:
agent = invocation_context.agent
if not hasattr(agent, 'canonical_tools'):
if not hasattr(agent, "canonical_tools"):
return
events = invocation_context.session.events
if not events:
Expand All @@ -147,7 +196,7 @@ async def run_async(
last_event_with_content = event
break

if not last_event_with_content or last_event_with_content.author != 'user':
if not last_event_with_content or last_event_with_content.author != "user":
return

responses = last_event_with_content.get_function_responses()
Expand Down
10 changes: 10 additions & 0 deletions src/google/adk/flows/llm_flows/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,17 @@ def build_auth_request_event(
parts = []
long_running_tool_ids = set()

deduplicated_requests: Dict[str, AuthConfig] = {}
seen_keys = set()
for function_call_id, auth_config in auth_requests.items():
key = auth_config.credential_key
if key is None:
deduplicated_requests[function_call_id] = auth_config
elif key not in seen_keys:
seen_keys.add(key)
deduplicated_requests[function_call_id] = auth_config

for function_call_id, auth_config in deduplicated_requests.items():
request_euc_function_call = types.FunctionCall(
name=REQUEST_EUC_FUNCTION_CALL_NAME,
id=generate_client_function_call_id(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _get_legacy_credential_key(
if auth_credential
else ""
)
return f"{scheme_name}_{credential_name}_existing_exchanged_credential"
return f"user:{scheme_name}_{credential_name}_existing_exchanged_credential"

def get_credential_key(
self,
Expand Down Expand Up @@ -114,7 +114,7 @@ def get_credential_key(
# persisted. temp: namespace will be cleared after current run. but tool
# want access token to be there stored across runs

return f"{scheme_name}_{credential_name}_existing_exchanged_credential"
return f"user:{scheme_name}_{credential_name}_existing_exchanged_credential"

def get_credential(
self,
Expand Down
28 changes: 26 additions & 2 deletions tests/unittests/auth/test_toolset_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,9 +406,12 @@ def test_multiple_auth_requests_create_multiple_parts(
self, mock_invocation_context
):
"""Test that multiple auth requests create multiple function call parts."""
config1 = create_oauth2_auth_config()
config2 = create_oauth2_auth_config()
config2.credential_key = "different_key"
auth_requests = {
"call_1": create_oauth2_auth_config(),
"call_2": create_oauth2_auth_config(),
"call_1": config1,
"call_2": config2,
}

event = build_auth_request_event(mock_invocation_context, auth_requests)
Expand All @@ -419,6 +422,27 @@ def test_multiple_auth_requests_create_multiple_parts(
}
assert function_call_ids == {"call_1", "call_2"}

def test_duplicate_auth_requests_are_deduplicated(
self, mock_invocation_context
):
"""Test that auth requests with the same credential key are deduplicated."""
config1 = create_oauth2_auth_config()
config2 = create_oauth2_auth_config()
# Ensure they have the same credential key
assert config1.credential_key == config2.credential_key

auth_requests = {
"call_1": config1,
"call_2": config2,
}

event = build_auth_request_event(mock_invocation_context, auth_requests)

assert len(event.content.parts) == 1
fc = event.content.parts[0].function_call
assert fc.name == REQUEST_EUC_FUNCTION_CALL_NAME
assert fc.args["functionCallId"] == "call_1"

def test_always_adds_long_running_tool_ids(self, mock_invocation_context):
"""Test that long_running_tool_ids is always set."""
auth_requests = {"call_123": create_oauth2_auth_config()}
Expand Down
Loading