Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion src/google/adk/plugins/bigquery_agent_analytics_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,24 @@ def _get_tool_origin(tool: "BaseTool") -> str:
})


def _is_sensitive_json_string(value: str) -> bool:
"""Checks if a string is a JSON blob containing sensitive keys.

This catches opaque JSON-encoded strings (e.g. serialized credential
caches) whose inner keys would not be caught by the dict-level
_SENSITIVE_KEYS check.
"""
if not value or value[0] not in ("{", "["):
return False
try:
parsed = json.loads(value)
except (json.JSONDecodeError, ValueError):
return False
if isinstance(parsed, dict):
return bool(_SENSITIVE_KEYS & {k.lower() for k in parsed})
return False


def _recursive_smart_truncate(
obj: Any, max_len: int, seen: Optional[set[int]] = None
) -> tuple[Any, bool]:
Expand Down Expand Up @@ -266,10 +284,19 @@ def _recursive_smart_truncate(
for k, v in obj.items():
if isinstance(k, str):
k_lower = k.lower()
if k_lower in _SENSITIVE_KEYS or k_lower.startswith("temp:"):
if (
k_lower in _SENSITIVE_KEYS
or k_lower.startswith("temp:")
or k_lower.startswith("secret:")
):
new_dict[k] = "[REDACTED]"
continue

# Detect JSON-encoded strings that contain sensitive keys.
if isinstance(v, str) and _is_sensitive_json_string(v):
new_dict[k] = "[REDACTED]"
continue

val, trunc = _recursive_smart_truncate(v, max_len, seen)
if trunc:
truncated_any = True
Expand Down
4 changes: 3 additions & 1 deletion src/google/adk/sessions/_session_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def extract_state_delta(
deltas["app"][key.removeprefix(State.APP_PREFIX)] = state[key]
elif key.startswith(State.USER_PREFIX):
deltas["user"][key.removeprefix(State.USER_PREFIX)] = state[key]
elif not key.startswith(State.TEMP_PREFIX):
elif not key.startswith(State.TEMP_PREFIX) and not key.startswith(
State.SECRET_PREFIX
):
deltas["session"][key] = state[key]
return deltas
87 changes: 87 additions & 0 deletions src/google/adk/sessions/base_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,22 @@ class BaseSessionService(abc.ABC):
The service provides a set of methods for managing sessions and events.
"""

@property
def _secret_state_cache(
self,
) -> dict[tuple[str, str, str], dict[str, Any]]:
"""Process-local cache for secret-scoped state.

Keyed by (app_name, user_id, session_id).
Lazily initialized to avoid requiring subclasses to call
super().__init__().
"""
try:
return self.__secret_state_cache
except AttributeError:
self.__secret_state_cache: dict[tuple[str, str, str], dict[str, Any]] = {}
return self.__secret_state_cache

@abc.abstractmethod
async def create_session(
self,
Expand Down Expand Up @@ -120,6 +136,10 @@ async def append_event(self, session: Session, event: Event) -> Event:
# read temp values (e.g. output_key='temp:my_key' in SequentialAgent).
self._apply_temp_state(session, event)
event = self._trim_temp_delta_state(event)
# Apply secret-scoped state to in-memory session and process cache
# BEFORE trimming, so the session retains secret values across turns.
self._apply_secret_state(session, event)
event = self._trim_secret_delta_state(event)
self._update_session_state(session, event)
session.events.append(event)
return event
Expand Down Expand Up @@ -154,6 +174,73 @@ def _trim_temp_delta_state(self, event: Event) -> Event:
}
return event

def _apply_secret_state(self, session: Session, event: Event) -> None:
"""Applies secret-scoped state to in-memory session and process cache.

Secret state survives across turns (via the process-local cache) but
is never persisted to storage. The event delta is trimmed separately
by _trim_secret_delta_state.
"""
if not event.actions or not event.actions.state_delta:
return
cache_key = (session.app_name, session.user_id, session.id)
for key, value in event.actions.state_delta.items():
if key.startswith(State.SECRET_PREFIX):
session.state[key] = value
self._secret_state_cache.setdefault(cache_key, {})[key] = value

def _trim_secret_delta_state(self, event: Event) -> Event:
"""Removes secret-scoped keys from event delta before persistence."""
if not event.actions or not event.actions.state_delta:
return event
event.actions.state_delta = {
key: value
for key, value in event.actions.state_delta.items()
if not key.startswith(State.SECRET_PREFIX)
}
return event

def _seed_secret_state_on_create(
self,
*,
app_name: str,
user_id: str,
session_id: str,
state: Optional[dict[str, Any]],
) -> Optional[dict[str, Any]]:
"""Extracts secret-scoped keys from initial state into the cache.

Returns the state dict with secret keys removed (for persistence)
but seeds them in the process-local cache so get_session() can
restore them.
"""
if not state:
return state
secret_keys = {
k: v for k, v in state.items() if k.startswith(State.SECRET_PREFIX)
}
if not secret_keys:
return state
cache_key = (app_name, user_id, session_id)
self._secret_state_cache.setdefault(cache_key, {}).update(secret_keys)
return {
k: v for k, v in state.items() if not k.startswith(State.SECRET_PREFIX)
}

def _restore_secret_state(self, session: Session) -> None:
"""Merges cached secret state into an in-memory session."""
cache_key = (session.app_name, session.user_id, session.id)
secret_state = self._secret_state_cache.get(cache_key, {})
for key, value in secret_state.items():
session.state[key] = value

def _evict_secret_state(
self, app_name: str, user_id: str, session_id: str
) -> None:
"""Removes cached secret state for a deleted session."""
cache_key = (app_name, user_id, session_id)
self._secret_state_cache.pop(cache_key, None)

def _update_session_state(self, session: Session, event: Event) -> None:
"""Updates the session state based on the event."""
if not event.actions or not event.actions.state_delta:
Expand Down
27 changes: 27 additions & 0 deletions src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,21 @@ async def create_session(
defaults={"app_name": app_name, "user_id": user_id, "state": {}},
)

# Extract secret keys before they are dropped by
# extract_state_delta; we'll seed the cache after the session
# is committed and has a final ID.
secret_keys = {}
if state:
secret_keys = {
k: v for k, v in state.items() if k.startswith(State.SECRET_PREFIX)
}
if secret_keys:
state = {
k: v
for k, v in state.items()
if not k.startswith(State.SECRET_PREFIX)
}

# Extract state deltas
state_deltas = _session_util.extract_state_delta(state)
app_state_delta = state_deltas["app"]
Expand Down Expand Up @@ -482,6 +497,13 @@ async def create_session(
session = storage_session.to_session(
state=merged_state, is_sqlite=is_sqlite
)

# Seed secret state into the cache now that session.id is resolved.
if secret_keys:
cache_key = (app_name, user_id, session.id)
self._secret_state_cache.setdefault(cache_key, {}).update(secret_keys)
for key, value in secret_keys.items():
session.state[key] = value
return session

@override
Expand Down Expand Up @@ -547,6 +569,7 @@ async def get_session(
session = storage_session.to_session(
state=merged_state, events=events, is_sqlite=is_sqlite
)
self._restore_secret_state(session)
return session

@override
Expand Down Expand Up @@ -615,6 +638,7 @@ async def delete_session(
)
await sql_session.execute(stmt)
await sql_session.commit()
self._evict_secret_state(app_name, user_id, session_id)

@override
async def append_event(self, session: Session, event: Event) -> Event:
Expand All @@ -627,6 +651,9 @@ async def append_event(self, session: Session, event: Event) -> Event:
self._apply_temp_state(session, event)
# Trim temp state before persisting
event = self._trim_temp_delta_state(event)
# Apply secret state to in-memory session and process cache.
self._apply_secret_state(session, event)
event = self._trim_secret_delta_state(event)

# 1. Validate the session has not gone stale.
# 2. Update session attributes based on event config.
Expand Down
30 changes: 23 additions & 7 deletions src/google/adk/sessions/in_memory_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,22 @@ def _create_session_impl(
app_name=app_name, user_id=user_id, session_id=session_id
):
raise AlreadyExistsError(f'Session with id {session_id} already exists.')

session_id = (
session_id.strip()
if session_id and session_id.strip()
else platform_uuid.new_uuid()
)

# Seed secret state into the process-local cache before
# extract_state_delta, which would otherwise drop secret keys.
state = self._seed_secret_state_on_create(
app_name=app_name,
user_id=user_id,
session_id=session_id,
state=state,
)

state_deltas = _session_util.extract_state_delta(state)
app_state_delta = state_deltas['app']
user_state_delta = state_deltas['user']
Expand All @@ -129,11 +145,6 @@ def _create_session_impl(
user_state_delta
)

session_id = (
session_id.strip()
if session_id and session_id.strip()
else platform_uuid.new_uuid()
)
session = Session(
app_name=app_name,
user_id=user_id,
Expand All @@ -149,7 +160,9 @@ def _create_session_impl(
self.sessions[app_name][user_id][session_id] = session

copied_session = _copy_session(session)
return self._merge_state(app_name, user_id, copied_session)
merged = self._merge_state(app_name, user_id, copied_session)
self._restore_secret_state(merged)
return merged

@override
async def get_session(
Expand Down Expand Up @@ -219,7 +232,9 @@ def _get_session_impl(
copied_session.events = copied_session.events[i + 1 :]

# Return a copy of the session object with merged state.
return self._merge_state(app_name, user_id, copied_session)
merged = self._merge_state(app_name, user_id, copied_session)
self._restore_secret_state(merged)
return merged

def _merge_state(
self, app_name: str, user_id: str, copied_session: Session
Expand Down Expand Up @@ -311,6 +326,7 @@ def _delete_session_impl(
return

self.sessions[app_name][user_id].pop(session_id)
self._evict_secret_state(app_name, user_id, session_id)

@override
async def append_event(self, session: Session, event: Event) -> Event:
Expand Down
21 changes: 19 additions & 2 deletions src/google/adk/sessions/sqlite_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,15 @@ async def create_session(
f"Session with id {session_id} already exists."
)

# Seed secret state into the process-local cache before
# extract_state_delta, which would otherwise drop secret keys.
state = self._seed_secret_state_on_create(
app_name=app_name,
user_id=user_id,
session_id=session_id,
state=state,
)

# Extract state deltas
state_deltas = _session_util.extract_state_delta(state)
app_state_delta = state_deltas["app"]
Expand Down Expand Up @@ -218,14 +227,16 @@ async def create_session(
merged_state = _merge_state(
storage_app_state, storage_user_state, session_state
)
return Session(
session = Session(
app_name=app_name,
user_id=user_id,
id=session_id,
state=merged_state,
events=[],
last_update_time=now,
)
self._restore_secret_state(session)
return session

@override
async def get_session(
Expand Down Expand Up @@ -284,14 +295,16 @@ async def get_session(
for event_data in reversed(storage_events_data)
]

return Session(
session = Session(
app_name=app_name,
user_id=user_id,
id=session_id,
state=merged_state,
events=events,
last_update_time=last_update_time,
)
self._restore_secret_state(session)
return session

@override
async def list_sessions(
Expand Down Expand Up @@ -358,6 +371,7 @@ async def delete_session(
(app_name, user_id, session_id),
)
await db.commit()
self._evict_secret_state(app_name, user_id, session_id)

@override
async def append_event(self, session: Session, event: Event) -> Event:
Expand All @@ -369,6 +383,9 @@ async def append_event(self, session: Session, event: Event) -> Event:
self._apply_temp_state(session, event)
# Trim temp state before persisting
event = self._trim_temp_delta_state(event)
# Apply secret state to in-memory session and process cache.
self._apply_secret_state(session, event)
event = self._trim_secret_delta_state(event)
event_timestamp = event.timestamp

async with self._get_db_connection() as db:
Expand Down
1 change: 1 addition & 0 deletions src/google/adk/sessions/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class State:
APP_PREFIX = "app:"
USER_PREFIX = "user:"
TEMP_PREFIX = "temp:"
SECRET_PREFIX = "secret:"

def __init__(self, value: dict[str, Any], delta: dict[str, Any]):
"""
Expand Down
Loading