Skip to content

Commit 373fddb

Browse files
committed
fix(auth): Support fallback OAuth token and prefixless credential lookups in session state
Session state might store authentication responses as raw string tokens instead of AuthCredential objects, or under custom credential keys without the standard "temp:" prefix. Add robust fallback handling to resolve raw token strings, check for prefixless keys, and scan state values for any Google OAuth access tokens starting with "ya29."
1 parent 9670ce2 commit 373fddb

2 files changed

Lines changed: 100 additions & 7 deletions

File tree

src/google/adk/auth/auth_handler.py

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,13 @@ async def exchange_auth_token(
5555
return exchange_result.credential
5656

5757
async def parse_and_store_auth_response(self, state: State) -> None:
58+
credential_key = self.auth_config.credential_key
59+
if not credential_key:
60+
raise ValueError("credential_key is empty.")
5861

59-
credential_key = "temp:" + self.auth_config.credential_key
62+
temp_credential_key = "temp:" + credential_key
6063

61-
state[credential_key] = self.auth_config.exchanged_auth_credential
64+
state[temp_credential_key] = self.auth_config.exchanged_auth_credential
6265
if not isinstance(
6366
self.auth_config.auth_scheme, SecurityBase
6467
) or self.auth_config.auth_scheme.type_ not in (
@@ -67,15 +70,54 @@ async def parse_and_store_auth_response(self, state: State) -> None:
6770
):
6871
return
6972

70-
state[credential_key] = await self.exchange_auth_token()
73+
state[temp_credential_key] = await self.exchange_auth_token()
7174

7275
def _validate(self) -> None:
73-
if not self.auth_scheme:
76+
if not self.auth_config.auth_scheme:
7477
raise ValueError("auth_scheme is empty.")
7578

76-
def get_auth_response(self, state: State) -> AuthCredential:
77-
credential_key = "temp:" + self.auth_config.credential_key
78-
return state.get(credential_key, None)
79+
def get_auth_response(self, state: State) -> AuthCredential | None:
80+
# 1. Try reading the temp credential key (standard ADK flow)
81+
credential_key = self.auth_config.credential_key
82+
if not credential_key:
83+
return None
84+
85+
temp_credential_key = "temp:" + credential_key
86+
val = state.get(temp_credential_key, None)
87+
if val is not None:
88+
if isinstance(val, AuthCredential):
89+
return val
90+
if isinstance(val, str) and val:
91+
return self._build_oauth2_credential(val)
92+
93+
# 2. Try reading the credential key without the 'temp:' prefix
94+
val = state.get(credential_key, None)
95+
if val is not None:
96+
if isinstance(val, AuthCredential):
97+
return val
98+
if isinstance(val, str) and val:
99+
return self._build_oauth2_credential(val)
100+
101+
# 3. Fallback: scan the state for any active Google OAuth access token (ya29.*)
102+
try:
103+
state_dict = state.to_dict() if hasattr(state, "to_dict") else state
104+
if isinstance(state_dict, dict):
105+
for k, v in state_dict.items():
106+
if isinstance(v, str) and v.startswith("ya29."):
107+
return self._build_oauth2_credential(v)
108+
except Exception: # pylint: disable=broad-except
109+
pass
110+
111+
return None
112+
113+
def _build_oauth2_credential(self, token: str) -> AuthCredential:
114+
from .auth_credential import AuthCredentialTypes
115+
from .auth_credential import OAuth2Auth
116+
117+
return AuthCredential(
118+
auth_type=AuthCredentialTypes.OAUTH2,
119+
oauth2=OAuth2Auth(access_token=token),
120+
)
79121

80122
def generate_auth_request(self) -> AuthConfig:
81123
if not isinstance(

tests/unittests/auth/test_auth_handler.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,57 @@ def test_get_auth_response_not_exists(self, auth_config):
503503
result = handler.get_auth_response(state)
504504
assert result is None
505505

506+
def test_get_auth_response_temp_prefix_str_token(self, auth_config):
507+
"""Test retrieving a string token stored under temp prefix in state."""
508+
handler = AuthHandler(auth_config)
509+
state = MockState()
510+
credential_key = auth_config.credential_key
511+
state["temp:" + credential_key] = "ya29.mock_token"
512+
513+
result = handler.get_auth_response(state)
514+
515+
assert result is not None
516+
assert result.auth_type == AuthCredentialTypes.OAUTH2
517+
assert result.oauth2.access_token == "ya29.mock_token"
518+
519+
def test_get_auth_response_no_prefix_credential(
520+
self, auth_config, oauth2_credentials_with_auth_uri
521+
):
522+
"""Test retrieving a credential stored under the key without prefix."""
523+
handler = AuthHandler(auth_config)
524+
state = MockState()
525+
credential_key = auth_config.credential_key
526+
state[credential_key] = oauth2_credentials_with_auth_uri
527+
528+
result = handler.get_auth_response(state)
529+
530+
assert result == oauth2_credentials_with_auth_uri
531+
532+
def test_get_auth_response_no_prefix_str_token(self, auth_config):
533+
"""Test retrieving a string token stored under the key without prefix."""
534+
handler = AuthHandler(auth_config)
535+
state = MockState()
536+
credential_key = auth_config.credential_key
537+
state[credential_key] = "ya29.mock_token_no_prefix"
538+
539+
result = handler.get_auth_response(state)
540+
541+
assert result is not None
542+
assert result.auth_type == AuthCredentialTypes.OAUTH2
543+
assert result.oauth2.access_token == "ya29.mock_token_no_prefix"
544+
545+
def test_get_auth_response_fallback_google_token(self, auth_config):
546+
"""Test retrieving fallback Google token from state via scanning."""
547+
handler = AuthHandler(auth_config)
548+
state = MockState()
549+
state["some_other_key"] = "ya29.fallback_google_token"
550+
551+
result = handler.get_auth_response(state)
552+
553+
assert result is not None
554+
assert result.auth_type == AuthCredentialTypes.OAUTH2
555+
assert result.oauth2.access_token == "ya29.fallback_google_token"
556+
506557

507558
class TestParseAndStoreAuthResponse:
508559
"""Tests for the parse_and_store_auth_response method."""

0 commit comments

Comments
 (0)