diff --git a/src/google/adk/auth/credential_service/session_state_credential_service.py b/src/google/adk/auth/credential_service/session_state_credential_service.py index 5559ec6005..45504f9565 100644 --- a/src/google/adk/auth/credential_service/session_state_credential_service.py +++ b/src/google/adk/auth/credential_service/session_state_credential_service.py @@ -19,6 +19,7 @@ from typing_extensions import override from ...agents.callback_context import CallbackContext +from ...sessions.state import State from ...utils.feature_decorator import experimental from ..auth_credential import AuthCredential from ..auth_tool import AuthConfig @@ -54,7 +55,20 @@ async def load_credential( Optional[AuthCredential]: the credential saved in the store. """ - return callback_context.state.get(auth_config.credential_key) + secret_key = State.SECRET_PREFIX + auth_config.credential_key + # Use `in` (not truthiness) so an explicit None is respected. + if secret_key in callback_context.state: + return callback_context.state[secret_key] + # Fall back to legacy unprefixed key, then migrate: copy into + # secret: scope and clear the legacy key so it is removed from + # persistent storage on the next state delta flush. + legacy_key = auth_config.credential_key + if legacy_key in callback_context.state: + value = callback_context.state[legacy_key] + callback_context.state[secret_key] = value + callback_context.state[legacy_key] = None + return value + return None @override async def save_credential( @@ -78,6 +92,6 @@ async def save_credential( None """ - callback_context.state[auth_config.credential_key] = ( + callback_context.state[State.SECRET_PREFIX + auth_config.credential_key] = ( auth_config.exchanged_auth_credential ) diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index fddae0ad33..ab2591a70b 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -218,6 +218,24 @@ def _get_tool_origin(tool: "BaseTool") -> str: }) +def _is_sensitive_json_string(value: str) -> bool: + """Checks if a string is a JSON blob containing sensitive keys. + + This catches opaque JSON-encoded strings (e.g. serialized credential + caches) whose inner keys would not be caught by the dict-level + _SENSITIVE_KEYS check. + """ + if not value or value[0] not in ("{", "["): + return False + try: + parsed = json.loads(value) + except (json.JSONDecodeError, ValueError): + return False + if isinstance(parsed, dict): + return bool(_SENSITIVE_KEYS & {k.lower() for k in parsed}) + return False + + def _recursive_smart_truncate( obj: Any, max_len: int, seen: Optional[set[int]] = None ) -> tuple[Any, bool]: @@ -266,10 +284,19 @@ def _recursive_smart_truncate( for k, v in obj.items(): if isinstance(k, str): k_lower = k.lower() - if k_lower in _SENSITIVE_KEYS or k_lower.startswith("temp:"): + if ( + k_lower in _SENSITIVE_KEYS + or k_lower.startswith("temp:") + or k_lower.startswith("secret:") + ): new_dict[k] = "[REDACTED]" continue + # Detect JSON-encoded strings that contain sensitive keys. + if isinstance(v, str) and _is_sensitive_json_string(v): + new_dict[k] = "[REDACTED]" + continue + val, trunc = _recursive_smart_truncate(v, max_len, seen) if trunc: truncated_any = True diff --git a/src/google/adk/sessions/_session_util.py b/src/google/adk/sessions/_session_util.py index 3a92021929..8a03267d0d 100644 --- a/src/google/adk/sessions/_session_util.py +++ b/src/google/adk/sessions/_session_util.py @@ -45,6 +45,8 @@ def extract_state_delta( deltas["app"][key.removeprefix(State.APP_PREFIX)] = state[key] elif key.startswith(State.USER_PREFIX): deltas["user"][key.removeprefix(State.USER_PREFIX)] = state[key] - elif not key.startswith(State.TEMP_PREFIX): + elif not key.startswith(State.TEMP_PREFIX) and not key.startswith( + State.SECRET_PREFIX + ): deltas["session"][key] = state[key] return deltas diff --git a/src/google/adk/sessions/base_session_service.py b/src/google/adk/sessions/base_session_service.py index af94bb9eeb..5d9febe01f 100644 --- a/src/google/adk/sessions/base_session_service.py +++ b/src/google/adk/sessions/base_session_service.py @@ -57,6 +57,22 @@ class BaseSessionService(abc.ABC): The service provides a set of methods for managing sessions and events. """ + @property + def _secret_state_cache( + self, + ) -> dict[tuple[str, str, str], dict[str, Any]]: + """Process-local cache for secret-scoped state. + + Keyed by (app_name, user_id, session_id). + Lazily initialized to avoid requiring subclasses to call + super().__init__(). + """ + try: + return self.__secret_state_cache + except AttributeError: + self.__secret_state_cache: dict[tuple[str, str, str], dict[str, Any]] = {} + return self.__secret_state_cache + @abc.abstractmethod async def create_session( self, @@ -120,6 +136,10 @@ async def append_event(self, session: Session, event: Event) -> Event: # read temp values (e.g. output_key='temp:my_key' in SequentialAgent). self._apply_temp_state(session, event) event = self._trim_temp_delta_state(event) + # Apply secret-scoped state to in-memory session and process cache + # BEFORE trimming, so the session retains secret values across turns. + self._apply_secret_state(session, event) + event = self._trim_secret_delta_state(event) self._update_session_state(session, event) session.events.append(event) return event @@ -154,6 +174,73 @@ def _trim_temp_delta_state(self, event: Event) -> Event: } return event + def _apply_secret_state(self, session: Session, event: Event) -> None: + """Applies secret-scoped state to in-memory session and process cache. + + Secret state survives across turns (via the process-local cache) but + is never persisted to storage. The event delta is trimmed separately + by _trim_secret_delta_state. + """ + if not event.actions or not event.actions.state_delta: + return + cache_key = (session.app_name, session.user_id, session.id) + for key, value in event.actions.state_delta.items(): + if key.startswith(State.SECRET_PREFIX): + session.state[key] = value + self._secret_state_cache.setdefault(cache_key, {})[key] = value + + def _trim_secret_delta_state(self, event: Event) -> Event: + """Removes secret-scoped keys from event delta before persistence.""" + if not event.actions or not event.actions.state_delta: + return event + event.actions.state_delta = { + key: value + for key, value in event.actions.state_delta.items() + if not key.startswith(State.SECRET_PREFIX) + } + return event + + def _seed_secret_state_on_create( + self, + *, + app_name: str, + user_id: str, + session_id: str, + state: Optional[dict[str, Any]], + ) -> Optional[dict[str, Any]]: + """Extracts secret-scoped keys from initial state into the cache. + + Returns the state dict with secret keys removed (for persistence) + but seeds them in the process-local cache so get_session() can + restore them. + """ + if not state: + return state + secret_keys = { + k: v for k, v in state.items() if k.startswith(State.SECRET_PREFIX) + } + if not secret_keys: + return state + cache_key = (app_name, user_id, session_id) + self._secret_state_cache.setdefault(cache_key, {}).update(secret_keys) + return { + k: v for k, v in state.items() if not k.startswith(State.SECRET_PREFIX) + } + + def _restore_secret_state(self, session: Session) -> None: + """Merges cached secret state into an in-memory session.""" + cache_key = (session.app_name, session.user_id, session.id) + secret_state = self._secret_state_cache.get(cache_key, {}) + for key, value in secret_state.items(): + session.state[key] = value + + def _evict_secret_state( + self, app_name: str, user_id: str, session_id: str + ) -> None: + """Removes cached secret state for a deleted session.""" + cache_key = (app_name, user_id, session_id) + self._secret_state_cache.pop(cache_key, None) + def _update_session_state(self, session: Session, event: Event) -> None: """Updates the session state based on the event.""" if not event.actions or not event.actions.state_delta: diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index d033f1f234..74a5238c03 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -445,6 +445,21 @@ async def create_session( defaults={"app_name": app_name, "user_id": user_id, "state": {}}, ) + # Extract secret keys before they are dropped by + # extract_state_delta; we'll seed the cache after the session + # is committed and has a final ID. + secret_keys = {} + if state: + secret_keys = { + k: v for k, v in state.items() if k.startswith(State.SECRET_PREFIX) + } + if secret_keys: + state = { + k: v + for k, v in state.items() + if not k.startswith(State.SECRET_PREFIX) + } + # Extract state deltas state_deltas = _session_util.extract_state_delta(state) app_state_delta = state_deltas["app"] @@ -482,6 +497,13 @@ async def create_session( session = storage_session.to_session( state=merged_state, is_sqlite=is_sqlite ) + + # Seed secret state into the cache now that session.id is resolved. + if secret_keys: + cache_key = (app_name, user_id, session.id) + self._secret_state_cache.setdefault(cache_key, {}).update(secret_keys) + for key, value in secret_keys.items(): + session.state[key] = value return session @override @@ -547,6 +569,7 @@ async def get_session( session = storage_session.to_session( state=merged_state, events=events, is_sqlite=is_sqlite ) + self._restore_secret_state(session) return session @override @@ -615,6 +638,7 @@ async def delete_session( ) await sql_session.execute(stmt) await sql_session.commit() + self._evict_secret_state(app_name, user_id, session_id) @override async def append_event(self, session: Session, event: Event) -> Event: @@ -627,6 +651,9 @@ async def append_event(self, session: Session, event: Event) -> Event: self._apply_temp_state(session, event) # Trim temp state before persisting event = self._trim_temp_delta_state(event) + # Apply secret state to in-memory session and process cache. + self._apply_secret_state(session, event) + event = self._trim_secret_delta_state(event) # 1. Validate the session has not gone stale. # 2. Update session attributes based on event config. diff --git a/src/google/adk/sessions/in_memory_session_service.py b/src/google/adk/sessions/in_memory_session_service.py index 1bef516086..eb39fe21b0 100644 --- a/src/google/adk/sessions/in_memory_session_service.py +++ b/src/google/adk/sessions/in_memory_session_service.py @@ -118,6 +118,22 @@ def _create_session_impl( app_name=app_name, user_id=user_id, session_id=session_id ): raise AlreadyExistsError(f'Session with id {session_id} already exists.') + + session_id = ( + session_id.strip() + if session_id and session_id.strip() + else platform_uuid.new_uuid() + ) + + # Seed secret state into the process-local cache before + # extract_state_delta, which would otherwise drop secret keys. + state = self._seed_secret_state_on_create( + app_name=app_name, + user_id=user_id, + session_id=session_id, + state=state, + ) + state_deltas = _session_util.extract_state_delta(state) app_state_delta = state_deltas['app'] user_state_delta = state_deltas['user'] @@ -129,11 +145,6 @@ def _create_session_impl( user_state_delta ) - session_id = ( - session_id.strip() - if session_id and session_id.strip() - else platform_uuid.new_uuid() - ) session = Session( app_name=app_name, user_id=user_id, @@ -149,7 +160,9 @@ def _create_session_impl( self.sessions[app_name][user_id][session_id] = session copied_session = _copy_session(session) - return self._merge_state(app_name, user_id, copied_session) + merged = self._merge_state(app_name, user_id, copied_session) + self._restore_secret_state(merged) + return merged @override async def get_session( @@ -219,7 +232,9 @@ def _get_session_impl( copied_session.events = copied_session.events[i + 1 :] # Return a copy of the session object with merged state. - return self._merge_state(app_name, user_id, copied_session) + merged = self._merge_state(app_name, user_id, copied_session) + self._restore_secret_state(merged) + return merged def _merge_state( self, app_name: str, user_id: str, copied_session: Session @@ -311,6 +326,7 @@ def _delete_session_impl( return self.sessions[app_name][user_id].pop(session_id) + self._evict_secret_state(app_name, user_id, session_id) @override async def append_event(self, session: Session, event: Event) -> Event: diff --git a/src/google/adk/sessions/sqlite_session_service.py b/src/google/adk/sessions/sqlite_session_service.py index 427bc3e73e..ef1bdbb32b 100644 --- a/src/google/adk/sessions/sqlite_session_service.py +++ b/src/google/adk/sessions/sqlite_session_service.py @@ -179,6 +179,15 @@ async def create_session( f"Session with id {session_id} already exists." ) + # Seed secret state into the process-local cache before + # extract_state_delta, which would otherwise drop secret keys. + state = self._seed_secret_state_on_create( + app_name=app_name, + user_id=user_id, + session_id=session_id, + state=state, + ) + # Extract state deltas state_deltas = _session_util.extract_state_delta(state) app_state_delta = state_deltas["app"] @@ -218,7 +227,7 @@ async def create_session( merged_state = _merge_state( storage_app_state, storage_user_state, session_state ) - return Session( + session = Session( app_name=app_name, user_id=user_id, id=session_id, @@ -226,6 +235,8 @@ async def create_session( events=[], last_update_time=now, ) + self._restore_secret_state(session) + return session @override async def get_session( @@ -284,7 +295,7 @@ async def get_session( for event_data in reversed(storage_events_data) ] - return Session( + session = Session( app_name=app_name, user_id=user_id, id=session_id, @@ -292,6 +303,8 @@ async def get_session( events=events, last_update_time=last_update_time, ) + self._restore_secret_state(session) + return session @override async def list_sessions( @@ -358,6 +371,7 @@ async def delete_session( (app_name, user_id, session_id), ) await db.commit() + self._evict_secret_state(app_name, user_id, session_id) @override async def append_event(self, session: Session, event: Event) -> Event: @@ -369,6 +383,9 @@ async def append_event(self, session: Session, event: Event) -> Event: self._apply_temp_state(session, event) # Trim temp state before persisting event = self._trim_temp_delta_state(event) + # Apply secret state to in-memory session and process cache. + self._apply_secret_state(session, event) + event = self._trim_secret_delta_state(event) event_timestamp = event.timestamp async with self._get_db_connection() as db: diff --git a/src/google/adk/sessions/state.py b/src/google/adk/sessions/state.py index a6a3bdbbe9..870ea7b076 100644 --- a/src/google/adk/sessions/state.py +++ b/src/google/adk/sessions/state.py @@ -23,6 +23,7 @@ class State: APP_PREFIX = "app:" USER_PREFIX = "user:" TEMP_PREFIX = "temp:" + SECRET_PREFIX = "secret:" def __init__(self, value: dict[str, Any], delta: dict[str, Any]): """ diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index 8025821975..4a3d00602d 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -39,6 +39,7 @@ from .base_session_service import GetSessionConfig from .base_session_service import ListSessionsResponse from .session import Session +from .state import State logger = logging.getLogger('google_adk.' + __name__) @@ -117,6 +118,19 @@ async def create_session( """ reasoning_engine_id = self._get_reasoning_engine_id(app_name) + # Strip secret keys before sending state to the API. + secret_keys = {} + if state: + secret_keys = { + k: v for k, v in state.items() if k.startswith(State.SECRET_PREFIX) + } + if secret_keys: + state = { + k: v + for k, v in state.items() + if not k.startswith(State.SECRET_PREFIX) + } + config = {'session_state': state} if state else {} if session_id: config['session_id'] = session_id @@ -138,6 +152,13 @@ async def create_session( state=getattr(get_session_response, 'session_state', None) or {}, last_update_time=get_session_response.update_time.timestamp(), ) + + # Seed secret state into the cache now that session.id is resolved. + if secret_keys: + cache_key = (app_name, user_id, session.id) + self._secret_state_cache.setdefault(cache_key, {}).update(secret_keys) + for key, value in secret_keys.items(): + session.state[key] = value return session @override @@ -214,6 +235,7 @@ async def get_session( if config.num_recent_events: session.events = session.events[-config.num_recent_events :] + self._restore_secret_state(session) return session @override @@ -260,6 +282,7 @@ async def delete_session( except Exception as e: logger.error('Error deleting session %s: %s', session_id, e) raise + self._evict_secret_state(app_name, user_id, session_id) @override async def append_event(self, session: Session, event: Event) -> Event: diff --git a/src/google/adk/tools/_google_credentials.py b/src/google/adk/tools/_google_credentials.py index 8eb92d9c53..e54b130313 100644 --- a/src/google/adk/tools/_google_credentials.py +++ b/src/google/adk/tools/_google_credentials.py @@ -171,11 +171,17 @@ async def get_valid_credentials( f" {self.credentials_config.external_access_token_key}." ) # First, try to get credentials from the tool context - creds_json = ( - tool_context.state.get(self.credentials_config._token_cache_key, None) - if self.credentials_config._token_cache_key - else None - ) + cache_key = self.credentials_config._token_cache_key + creds_json = tool_context.state.get(cache_key, None) if cache_key else None + # Fall back to legacy unprefixed key, then migrate: copy into + # secret: scope and clear the legacy key so it is removed from + # persistent storage on the next state delta flush. + if creds_json is None and cache_key and cache_key.startswith("secret:"): + legacy_key = cache_key[len("secret:") :] + creds_json = tool_context.state.get(legacy_key, None) + if creds_json is not None: + tool_context.state[cache_key] = creds_json + tool_context.state[legacy_key] = None creds = ( google.oauth2.credentials.Credentials.from_authorized_user_info( json.loads(creds_json), self.credentials_config.scopes diff --git a/src/google/adk/tools/bigquery/bigquery_credentials.py b/src/google/adk/tools/bigquery/bigquery_credentials.py index c491c52ee6..51faeadd46 100644 --- a/src/google/adk/tools/bigquery/bigquery_credentials.py +++ b/src/google/adk/tools/bigquery/bigquery_credentials.py @@ -18,7 +18,7 @@ from ...features import FeatureName from .._google_credentials import BaseGoogleCredentialsConfig -BIGQUERY_TOKEN_CACHE_KEY = "bigquery_token_cache" +BIGQUERY_TOKEN_CACHE_KEY = "secret:bigquery_token_cache" BIGQUERY_SCOPES = [ "https://www.googleapis.com/auth/bigquery", "https://www.googleapis.com/auth/dataplex.read-write", diff --git a/src/google/adk/utils/instructions_utils.py b/src/google/adk/utils/instructions_utils.py index 505b5cf128..730a89b54e 100644 --- a/src/google/adk/utils/instructions_utils.py +++ b/src/google/adk/utils/instructions_utils.py @@ -143,7 +143,12 @@ def _is_valid_state_name(var_name): return var_name.isidentifier() if len(parts) == 2: - prefixes = [State.APP_PREFIX, State.USER_PREFIX, State.TEMP_PREFIX] + prefixes = [ + State.APP_PREFIX, + State.USER_PREFIX, + State.TEMP_PREFIX, + State.SECRET_PREFIX, + ] if (parts[0] + ':') in prefixes: return parts[1].isidentifier() return False diff --git a/tests/unittests/auth/credential_service/test_session_state_credential_service.py b/tests/unittests/auth/credential_service/test_session_state_credential_service.py index 1f997336f1..e7cd383e63 100644 --- a/tests/unittests/auth/credential_service/test_session_state_credential_service.py +++ b/tests/unittests/auth/credential_service/test_session_state_credential_service.py @@ -23,6 +23,7 @@ from google.adk.auth.auth_credential import OAuth2Auth from google.adk.auth.auth_tool import AuthConfig from google.adk.auth.credential_service.session_state_credential_service import SessionStateCredentialService +from google.adk.sessions.state import State import pytest @@ -265,10 +266,11 @@ async def test_state_persistence_across_operations( # Save credential await credential_service.save_credential(auth_config, callback_context) - # Verify state contains the credential - assert auth_config.credential_key in callback_context.state + # Verify state contains the credential under secret: prefix + secret_key = State.SECRET_PREFIX + auth_config.credential_key + assert secret_key in callback_context.state assert ( - callback_context.state[auth_config.credential_key] + callback_context.state[secret_key] == auth_config.exchanged_auth_credential ) @@ -279,9 +281,9 @@ async def test_state_persistence_across_operations( assert result is not None # Verify state still contains the credential - assert auth_config.credential_key in callback_context.state + assert secret_key in callback_context.state assert ( - callback_context.state[auth_config.credential_key] + callback_context.state[secret_key] == auth_config.exchanged_auth_credential ) @@ -300,7 +302,7 @@ async def test_state_persistence_across_operations( await credential_service.save_credential(auth_config, callback_context) # Verify state was updated - assert callback_context.state[auth_config.credential_key] == new_credential + assert callback_context.state[secret_key] == new_credential @pytest.mark.asyncio async def test_credential_key_uniqueness( @@ -344,13 +346,12 @@ async def test_credential_key_uniqueness( await credential_service.save_credential(auth_config1, callback_context) await credential_service.save_credential(auth_config2, callback_context) - # Verify both exist in state with different keys - assert "unique_key_1" in callback_context.state - assert "unique_key_2" in callback_context.state - assert ( - callback_context.state["unique_key_1"] - != callback_context.state["unique_key_2"] - ) + # Verify both exist in state with secret-prefixed keys + sk1 = State.SECRET_PREFIX + "unique_key_1" + sk2 = State.SECRET_PREFIX + "unique_key_2" + assert sk1 in callback_context.state + assert sk2 in callback_context.state + assert callback_context.state[sk1] != callback_context.state[sk2] # Load and verify both credentials result1 = await credential_service.load_credential( @@ -379,10 +380,89 @@ async def test_direct_state_access( redirect_uri="https://direct.com/callback", ), ) - callback_context.state[auth_config.credential_key] = test_credential + callback_context.state[State.SECRET_PREFIX + auth_config.credential_key] = ( + test_credential + ) # Load using the service result = await credential_service.load_credential( auth_config, callback_context ) assert result == test_credential + + @pytest.mark.asyncio + async def test_load_falls_back_to_legacy_unprefixed_key( + self, credential_service, auth_config, callback_context + ): + """Credentials stored under the old unprefixed key are still found.""" + legacy_cred = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="legacy_client", + client_secret="legacy_secret", + ), + ) + # Simulate a session persisted before the secret: migration + callback_context.state[auth_config.credential_key] = legacy_cred + + result = await credential_service.load_credential( + auth_config, callback_context + ) + assert result is not None + assert result.oauth2.client_id == "legacy_client" + # Legacy key should be migrated: secret key populated, legacy cleared + secret_key = State.SECRET_PREFIX + auth_config.credential_key + assert secret_key in callback_context.state + assert callback_context.state[secret_key] == legacy_cred + assert callback_context.state[auth_config.credential_key] is None + + @pytest.mark.asyncio + async def test_secret_key_takes_precedence_over_legacy( + self, credential_service, auth_config, callback_context + ): + """When both keys exist, the secret-prefixed key wins.""" + old_cred = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="old_client", + client_secret="old_secret", + ), + ) + new_cred = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="new_client", + client_secret="new_secret", + ), + ) + callback_context.state[auth_config.credential_key] = old_cred + callback_context.state[State.SECRET_PREFIX + auth_config.credential_key] = ( + new_cred + ) + + result = await credential_service.load_credential( + auth_config, callback_context + ) + assert result.oauth2.client_id == "new_client" + + @pytest.mark.asyncio + async def test_explicit_none_secret_key_not_revived_by_legacy( + self, credential_service, auth_config, callback_context + ): + """Explicit None in secret: key must not fall back to legacy key.""" + old_cred = AuthCredential( + auth_type=AuthCredentialTypes.OAUTH2, + oauth2=OAuth2Auth( + client_id="stale_client", + client_secret="stale_secret", + ), + ) + callback_context.state[auth_config.credential_key] = old_cred + callback_context.state[State.SECRET_PREFIX + auth_config.credential_key] = ( + None + ) + + result = await credential_service.load_credential( + auth_config, callback_context + ) + assert result is None diff --git a/tests/unittests/sessions/test_secret_state.py b/tests/unittests/sessions/test_secret_state.py new file mode 100644 index 0000000000..3dffbf7e5b --- /dev/null +++ b/tests/unittests/sessions/test_secret_state.py @@ -0,0 +1,404 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the secret: session state scope.""" + +import json + +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk.sessions._session_util import extract_state_delta +from google.adk.sessions.base_session_service import BaseSessionService +from google.adk.sessions.database_session_service import DatabaseSessionService +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.sessions.session import Session +from google.adk.sessions.sqlite_session_service import SqliteSessionService +from google.adk.sessions.state import State +from google.adk.utils.instructions_utils import _is_valid_state_name +import pytest + +from .test_session_service import get_session_service +from .test_session_service import SessionServiceType + +# --------------------------------------------------------------------------- +# Unit: extract_state_delta +# --------------------------------------------------------------------------- + + +class TestExtractStateDelta: + + def test_secret_keys_excluded_from_all_buckets(self): + """secret: keys must not appear in app, user, or session buckets.""" + state = { + 'app:a': 1, + 'user:u': 2, + 'temp:t': 3, + 'secret:tok': 'abc', + 'normal': 4, + } + deltas = extract_state_delta(state) + assert deltas['app'] == {'a': 1} + assert deltas['user'] == {'u': 2} + assert deltas['session'] == {'normal': 4} + # secret and temp keys must not be in any bucket + all_values = {} + for bucket in deltas.values(): + all_values.update(bucket) + assert 'secret:tok' not in all_values + assert 'temp:t' not in all_values + + def test_only_secret_keys(self): + """If all keys are secret:, all buckets should be empty.""" + state = {'secret:a': 1, 'secret:b': 2} + deltas = extract_state_delta(state) + assert deltas == {'app': {}, 'user': {}, 'session': {}} + + +# --------------------------------------------------------------------------- +# Unit: State.SECRET_PREFIX +# --------------------------------------------------------------------------- + + +class TestStatePrefix: + + def test_secret_prefix_exists(self): + assert State.SECRET_PREFIX == 'secret:' + + +# --------------------------------------------------------------------------- +# Unit: _is_valid_state_name +# --------------------------------------------------------------------------- + + +class TestValidation: + + def test_secret_prefix_is_valid(self): + assert _is_valid_state_name('secret:token') is True + + def test_secret_prefix_invalid_name(self): + # Non-identifier after prefix + assert _is_valid_state_name('secret:123abc') is False + + +# --------------------------------------------------------------------------- +# Unit: BaseSessionService cache helpers +# --------------------------------------------------------------------------- + + +class TestBaseSessionServiceHelpers: + """Tests the lifecycle helpers on BaseSessionService.""" + + def _make_service(self): + """Create an InMemorySessionService (simplest concrete subclass).""" + return InMemorySessionService() + + def _make_session(self, app='app', user='user', sid='s1'): + return Session( + app_name=app, + user_id=user, + id=sid, + state={}, + ) + + def _make_event(self, state_delta): + return Event( + invocation_id='inv', + author='agent', + actions=EventActions(state_delta=state_delta), + ) + + def test_lazy_cache_initialization(self): + """Cache is lazily created on first access.""" + svc = self._make_service() + cache = svc._secret_state_cache + assert isinstance(cache, dict) + assert len(cache) == 0 + # Same object on second access + assert svc._secret_state_cache is cache + + def test_apply_and_trim_secret_state(self): + svc = self._make_service() + session = self._make_session() + event = self._make_event({ + 'secret:token': 'abc', + 'normal_key': 'val', + }) + + svc._apply_secret_state(session, event) + svc._trim_secret_delta_state(event) + + # Secret key applied to session state + assert session.state.get('secret:token') == 'abc' + # Secret key removed from event delta + assert 'secret:token' not in event.actions.state_delta + # Normal key still in delta + assert event.actions.state_delta.get('normal_key') == 'val' + # Secret key in process cache + cache_key = ('app', 'user', 's1') + assert svc._secret_state_cache[cache_key]['secret:token'] == 'abc' + + def test_seed_and_restore_secret_state(self): + svc = self._make_service() + state = { + 'secret:cred': 'xyz', + 'app:foo': 'bar', + 'normal': 123, + } + cleaned = svc._seed_secret_state_on_create( + app_name='app', + user_id='user', + session_id='s1', + state=state, + ) + # Returned state has secret keys stripped + assert 'secret:cred' not in cleaned + assert cleaned['app:foo'] == 'bar' + assert cleaned['normal'] == 123 + + # Restore into a session + session = self._make_session() + svc._restore_secret_state(session) + assert session.state.get('secret:cred') == 'xyz' + + def test_seed_with_no_secret_keys(self): + svc = self._make_service() + state = {'normal': 1} + result = svc._seed_secret_state_on_create( + app_name='a', user_id='u', session_id='s', state=state + ) + assert result == {'normal': 1} + + def test_seed_with_none_state(self): + svc = self._make_service() + result = svc._seed_secret_state_on_create( + app_name='a', user_id='u', session_id='s', state=None + ) + assert result is None + + def test_evict_secret_state(self): + svc = self._make_service() + svc._secret_state_cache[('a', 'u', 's')] = {'secret:x': 1} + svc._evict_secret_state('a', 'u', 's') + assert ('a', 'u', 's') not in svc._secret_state_cache + + def test_evict_nonexistent_key(self): + """Evicting a non-existent key should not raise.""" + svc = self._make_service() + svc._evict_secret_state('a', 'u', 's') # no-op + + +# --------------------------------------------------------------------------- +# Integration: session service lifecycle (parametrized) +# --------------------------------------------------------------------------- + +_SERVICE_TYPES = [ + SessionServiceType.IN_MEMORY, + SessionServiceType.DATABASE, + SessionServiceType.SQLITE, +] + + +@pytest.fixture(params=_SERVICE_TYPES) +async def session_service(request, tmp_path): + service = get_session_service(request.param, tmp_path) + yield service + if isinstance(service, DatabaseSessionService): + await service.close() + + +class TestSecretStateLifecycle: + """Integration tests for secret: state across session lifecycle.""" + + async def test_append_event_secret_survives_across_turns( + self, session_service + ): + """Secret state set via append_event survives get_session.""" + session = await session_service.create_session( + app_name='app', user_id='user' + ) + event = Event( + invocation_id='inv1', + author='agent', + actions=EventActions( + state_delta={'secret:token': 'abc123', 'visible': 'yes'} + ), + ) + await session_service.append_event(session=session, event=event) + + # Secret is in in-memory session + assert session.state.get('secret:token') == 'abc123' + assert session.state.get('visible') == 'yes' + + # Secret key trimmed from event delta + assert 'secret:token' not in event.actions.state_delta + + # Secret survives get_session (restored from cache) + restored = await session_service.get_session( + app_name='app', + user_id='user', + session_id=session.id, + ) + assert restored.state.get('secret:token') == 'abc123' + assert restored.state.get('visible') == 'yes' + + async def test_create_session_with_secret_state(self, session_service): + """Secret keys in initial state are cached, not persisted.""" + session = await session_service.create_session( + app_name='app', + user_id='user', + state={'secret:init_cred': 'init_val', 'normal': 'nval'}, + ) + # Secret available in returned session + assert session.state.get('secret:init_cred') == 'init_val' + assert session.state.get('normal') == 'nval' + + # Secret survives get_session + restored = await session_service.get_session( + app_name='app', + user_id='user', + session_id=session.id, + ) + assert restored.state.get('secret:init_cred') == 'init_val' + assert restored.state.get('normal') == 'nval' + + async def test_delete_session_evicts_secret_cache(self, session_service): + """Deleting a session evicts its secret cache entry.""" + session = await session_service.create_session( + app_name='app', + user_id='user', + state={'secret:key': 'val'}, + ) + sid = session.id + + await session_service.delete_session( + app_name='app', user_id='user', session_id=sid + ) + + # Cache should be empty for this session + cache_key = ('app', 'user', sid) + assert cache_key not in session_service._secret_state_cache + + async def test_list_sessions_does_not_include_secret_state( + self, session_service + ): + """list_sessions must NOT merge secret state.""" + session = await session_service.create_session( + app_name='app', + user_id='user', + state={'secret:hidden': 'shhh', 'visible': 'ok'}, + ) + + response = await session_service.list_sessions( + app_name='app', user_id='user' + ) + assert len(response.sessions) >= 1 + listed = next(s for s in response.sessions if s.id == session.id) + # Secret should NOT be in listed session state + assert listed.state.get('secret:hidden') is None + # Normal state should be present + assert listed.state.get('visible') == 'ok' + + async def test_secret_and_temp_independent(self, session_service): + """secret: and temp: scopes work independently.""" + session = await session_service.create_session( + app_name='app', user_id='user' + ) + event = Event( + invocation_id='inv1', + author='agent', + actions=EventActions( + state_delta={ + 'secret:s': 'secret_val', + 'temp:t': 'temp_val', + 'normal': 'n', + } + ), + ) + await session_service.append_event(session=session, event=event) + + # Both available in-memory + assert session.state.get('secret:s') == 'secret_val' + assert session.state.get('temp:t') == 'temp_val' + + # Neither in event delta + assert 'secret:s' not in event.actions.state_delta + assert 'temp:t' not in event.actions.state_delta + # Normal key persisted + assert event.actions.state_delta.get('normal') == 'n' + + # After get_session: secret survives, temp does not + restored = await session_service.get_session( + app_name='app', + user_id='user', + session_id=session.id, + ) + assert restored.state.get('secret:s') == 'secret_val' + # temp is lost after get_session (invocation-scoped) + assert restored.state.get('temp:t') is None + + +# --------------------------------------------------------------------------- +# BQAA redaction +# --------------------------------------------------------------------------- + + +class TestBQAARedaction: + + def test_secret_key_redacted(self): + from google.adk.plugins.bigquery_agent_analytics_plugin import _recursive_smart_truncate + + obj = {'secret:token': 'my_secret', 'normal': 'visible'} + result, _ = _recursive_smart_truncate(obj, max_len=-1) + assert result['secret:token'] == '[REDACTED]' + assert result['normal'] == 'visible' + + def test_json_blob_with_sensitive_keys_redacted(self): + from google.adk.plugins.bigquery_agent_analytics_plugin import _recursive_smart_truncate + + cred_json = json.dumps({ + 'access_token': 'ya29.xxx', + 'refresh_token': 'rt_xxx', + 'client_id': 'my_client', + }) + obj = {'bigquery_token_cache': cred_json, 'normal': 'ok'} + result, _ = _recursive_smart_truncate(obj, max_len=-1) + assert result['bigquery_token_cache'] == '[REDACTED]' + assert result['normal'] == 'ok' + + def test_json_blob_without_sensitive_keys_not_redacted(self): + from google.adk.plugins.bigquery_agent_analytics_plugin import _recursive_smart_truncate + + safe_json = json.dumps({'query': 'SELECT 1', 'count': 42}) + obj = {'data': safe_json} + result, _ = _recursive_smart_truncate(obj, max_len=-1) + # Should not be redacted since no sensitive keys + assert result['data'] == safe_json + + def test_non_json_string_not_redacted(self): + from google.adk.plugins.bigquery_agent_analytics_plugin import _recursive_smart_truncate + + obj = {'note': 'this is a normal string'} + result, _ = _recursive_smart_truncate(obj, max_len=-1) + assert result['note'] == 'this is a normal string' + + def test_is_sensitive_json_string_helper(self): + from google.adk.plugins.bigquery_agent_analytics_plugin import _is_sensitive_json_string + + assert _is_sensitive_json_string('not json') is False + assert _is_sensitive_json_string('') is False + assert _is_sensitive_json_string('{"safe_key": 1}') is False + assert _is_sensitive_json_string('{"access_token": "ya29"}') is True + assert _is_sensitive_json_string('{"Client_Secret": "xxx"}') is True + # Array does not trigger + assert _is_sensitive_json_string('[{"access_token": "ya29"}]') is False