Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions src/google/adk/cli/cli_tools_click.py
Original file line number Diff line number Diff line change
Expand Up @@ -2205,15 +2205,33 @@ def migrate():
default="INFO",
help="Optional. Set the logging level",
)
@click.option( # type: ignore[untyped-decorator]
"--allow-unsafe-unpickling",
"--allow_unsafe_unpickling",
is_flag=True,
default=False,
help=(
"Optional. Allow unsafe pickle loading for trusted legacy session"
" databases."
),
)
def cli_migrate_session(
*, source_db_url: str, dest_db_url: str, log_level: str
*,
source_db_url: str,
dest_db_url: str,
log_level: str,
allow_unsafe_unpickling: bool,
):
"""Migrates a session database to the latest schema version."""
logs.setup_adk_logger(getattr(logging, log_level.upper()))
try:
from ..sessions.migration import migration_runner

migration_runner.upgrade(source_db_url, dest_db_url)
migration_runner.upgrade(
source_db_url,
dest_db_url,
allow_unsafe_unpickling=allow_unsafe_unpickling,
)
click.secho("Migration check and upgrade process finished.", fg="green")
except Exception as e:
click.secho(f"Migration failed: {e}", fg="red", err=True)
Expand Down
6 changes: 3 additions & 3 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@
from ...telemetry.tracing import tracer
from ...tools.base_toolset import BaseToolset
from ...tools.tool_context import ToolContext
from ...utils.context_utils import Aclosing
from ...utils import model_name_utils
from ...utils.context_utils import Aclosing
from .audio_cache_manager import AudioCacheManager
from .functions import build_auth_request_event

Expand Down Expand Up @@ -563,8 +563,8 @@ async def run_live(
if llm_request.live_connect_config is None:
llm_request.live_connect_config = types.LiveConnectConfig()
if llm_request.live_connect_config.history_config is None:
llm_request.live_connect_config.history_config = types.HistoryConfig(
initial_history_in_client_content=True
llm_request.live_connect_config.history_config = (
types.HistoryConfig(initial_history_in_client_content=True)
)

logger.info(
Expand Down
9 changes: 7 additions & 2 deletions src/google/adk/flows/llm_flows/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,15 @@ def _build_basic_request(
llm_request.live_connect_config.realtime_input_config = (
invocation_context.run_config.realtime_input_config
)
active_model_name = getattr(getattr(agent, 'canonical_live_model', None), 'model', None) or llm_request.model
active_model_name = (
getattr(getattr(agent, 'canonical_live_model', None), 'model', None)
or llm_request.model
)
is_gemini_31 = model_name_utils.is_gemini_3_1_flash_live(active_model_name)
llm_request.live_connect_config.enable_affective_dialog = (
None if is_gemini_31 else invocation_context.run_config.enable_affective_dialog
None
if is_gemini_31
else invocation_context.run_config.enable_affective_dialog
)
llm_request.live_connect_config.proactivity = (
None if is_gemini_31 else invocation_context.run_config.proactivity
Expand Down
16 changes: 12 additions & 4 deletions src/google/adk/models/gemini_llm_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,15 @@ async def send_history(self, history: list[types.Content]):
# protocol error (invalid role mid-session), we consolidate previous multi-turn
# interactions into a unified contextual preamble on a single user role turn.
if is_gemini_31 and self._api_backend != GoogleLLMVariant.GEMINI_API:
collapsed_text = "Previous conversation history:\n"
collapsed_text = 'Previous conversation history:\n'
for c in contents:
text_parts = "".join(p.text for p in c.parts if p.text)
text_parts = ''.join(p.text for p in c.parts if p.text)
collapsed_text += f'[{c.role}]: {text_parts}\n'
contents = [types.Content(role='user', parts=[types.Part.from_text(text=collapsed_text)])]
contents = [
types.Content(
role='user', parts=[types.Part.from_text(text=collapsed_text)]
)
]

logger.debug('Sending history to live connection: %s', contents)
await self._gemini_session.send_client_content(
Expand Down Expand Up @@ -281,7 +285,11 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
is_thought = current_is_thought
llm_response.partial = True
# don't yield the merged text event when receiving audio data
if text and not any(p.text for p in content.parts) and not has_inline_data:
if (
text
and not any(p.text for p in content.parts)
and not has_inline_data
):
yield self.__build_full_text_response(text, is_thought)
text = ''
is_thought = False
Expand Down
151 changes: 140 additions & 11 deletions src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import argparse
from datetime import datetime
from datetime import timezone
import io
import json
import logging
import pickle
Expand All @@ -37,6 +38,93 @@

logger = logging.getLogger("google_adk." + __name__)

_ALLOWED_PICKLE_GLOBALS: set[tuple[str, str]] = {
# Builtin containers/primitives.
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also allow datetime.datetime and datetime.timezone? It's quite common for legacy state_delta or other Any fields in EventActions to contain timestamp objects.

("builtins", "dict"),
("builtins", "list"),
("builtins", "set"),
("builtins", "tuple"),
("builtins", "str"),
("builtins", "bytes"),
("builtins", "bytearray"),
("builtins", "int"),
("builtins", "float"),
("builtins", "bool"),
("datetime", "datetime"),
("datetime", "timedelta"),
("datetime", "timezone"),
# Expected pickled payload for v0 session schema events.
("fastapi.openapi.models", "APIKey"),
("fastapi.openapi.models", "APIKeyIn"),
("fastapi.openapi.models", "HTTPBase"),
("fastapi.openapi.models", "HTTPBearer"),
("fastapi.openapi.models", "OAuth2"),
("fastapi.openapi.models", "OAuthFlow"),
("fastapi.openapi.models", "OAuthFlowAuthorizationCode"),
("fastapi.openapi.models", "OAuthFlowClientCredentials"),
("fastapi.openapi.models", "OAuthFlowImplicit"),
("fastapi.openapi.models", "OAuthFlowPassword"),
("fastapi.openapi.models", "OAuthFlows"),
("fastapi.openapi.models", "OpenIdConnect"),
("fastapi.openapi.models", "SecurityBase"),
("fastapi.openapi.models", "SecurityScheme"),
("fastapi.openapi.models", "SecuritySchemeType"),
("google.adk.auth.auth_credential", "AuthCredential"),
("google.adk.auth.auth_credential", "AuthCredentialTypes"),
("google.adk.auth.auth_credential", "HttpAuth"),
("google.adk.auth.auth_credential", "HttpCredentials"),
("google.adk.auth.auth_credential", "OAuth2Auth"),
("google.adk.auth.auth_credential", "ServiceAccountCredential"),
("google.adk.auth.auth_schemes", "CustomAuthScheme"),
("google.adk.auth.auth_schemes", "ExtendedOAuth2"),
("google.adk.auth.auth_schemes", "OAuthGrantType"),
("google.adk.auth.auth_schemes", "OpenIdConnectWithConfig"),
("google.adk.auth.auth_tool", "AuthConfig"),
("google.adk.events.event_actions", "EventActions"),
("google.adk.events.event_actions", "EventCompaction"),
("google.adk.events.ui_widget", "UiWidget"),
("google.adk.tools.tool_confirmation", "ToolConfirmation"),
("google.genai.types", "Blob"),
("google.genai.types", "CodeExecutionResult"),
("google.genai.types", "Content"),
("google.genai.types", "ExecutableCode"),
("google.genai.types", "FileData"),
("google.genai.types", "FunctionCall"),
("google.genai.types", "FunctionResponse"),
("google.genai.types", "FunctionResponseBlob"),
("google.genai.types", "FunctionResponseFileData"),
("google.genai.types", "FunctionResponsePart"),
("google.genai.types", "Part"),
("google.genai.types", "PartMediaResolution"),
("google.genai.types", "VideoMetadata"),
}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing Allowlist Type:

The class EventActions has a field render_ui_widgets which contains a list of UiWidget instances. However, google.adk.events.ui_widget.UiWidget is not registered in the _ALLOWED_PICKLE_GLOBALS allowlist. This means any legacy database events containing UI widgets will fail to unpickle by default and fall back to empty actions unless the unsafe unpickling flag is set.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching this. I pushed 8f0e973 to add google.adk.events.ui_widget.UiWidget to the restricted migration allowlist, with a regression test covering EventActions.render_ui_widgets under the restricted unpickler.

I also re-ran the local checks after the update:

  • uv run --python 3.13 pytest -q tests/unittests/sessions/migration/test_migration.py tests/unittests/cli/utils/test_cli_tools_click.py -> 60 passed
  • uv run --python 3.13 pytest -q tests/unittests/flows/llm_flows/test_base_llm_flow.py tests/unittests/models/test_gemini_llm_connection.py -> 78 passed
  • uv run --python 3.13 mypy src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py src/google/adk/sessions/migration/migration_runner.py -> passed
  • uv run --python 3.13 pre-commit run --all-files -> passed
  • git diff --check -> passed

The PR branch is updated; the new fork workflow runs may still need maintainer approval before GitHub runs them on the latest head. Thanks!



class _RestrictedUnpickler(pickle.Unpickler):
"""Restricted unpickler for migrating legacy v0 schema actions.

The v0 session schema stored `EventActions` as a pickled blob. During
migration we treat the raw bytes read from the source DB as untrusted input
and only allow the minimum set of safe globals needed to reconstruct
`EventActions`.
"""

def find_class(self, module: str, name: str) -> Any: # noqa: ANN001
if (module, name) in _ALLOWED_PICKLE_GLOBALS:
return super().find_class(module, name)
raise pickle.UnpicklingError(
f"Blocked global during migration unpickle: {module}.{name}"
)


def _restricted_pickle_loads(
data: bytes, *, allow_unsafe_unpickling: bool = False
) -> Any:
"""Load a pickle payload using the restricted unpickler by default."""
if allow_unsafe_unpickling:
return pickle.loads(data)
return _RestrictedUnpickler(io.BytesIO(data)).load()


def _to_datetime_obj(val: Any) -> datetime | Any:
"""Converts string to datetime if needed."""
Expand All @@ -51,15 +139,19 @@ def _to_datetime_obj(val: Any) -> datetime | Any:
return val


def _row_to_event(row: dict) -> Event:
def _row_to_event(
row: dict[str, Any], *, allow_unsafe_unpickling: bool = False
) -> Event:
"""Converts event row (dict) to event object, handling missing columns and deserializing."""

actions_val = row.get("actions")
actions = None
if actions_val is not None:
try:
if isinstance(actions_val, bytes):
actions = pickle.loads(actions_val)
actions = _restricted_pickle_loads(
actions_val, allow_unsafe_unpickling=allow_unsafe_unpickling
)
else: # for spanner - it might return object directly
actions = actions_val
except Exception as e:
Expand All @@ -75,17 +167,25 @@ def _row_to_event(row: dict) -> Event:
else:
actions = EventActions()

def _safe_json_load(val):
data = None
def _safe_json_load(val: Any) -> dict[str, Any] | None:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type hint suggests it returns a dict[str, Any], but json.loads could return a list. While it's unlikely for these specific columns, it might be safer to use dict[str, Any] | list[Any] | None or verify it's a dict. Also, the cast below might hide issues if it's actually a list.

if isinstance(val, str):
try:
data = json.loads(val)
except json.JSONDecodeError:
logger.warning(f"Failed to decode JSON for event {row.get('id')}")
return None
elif isinstance(val, dict):
data = val # for postgres JSONB
return data
return val # for postgres JSONB
else:
return None

if isinstance(data, dict):
return data
logger.warning(
f"Expected JSON object for event {row.get('id')}, got"
f" {type(data).__name__}."
)
return None

content_dict = _safe_json_load(row.get("content"))
grounding_metadata_dict = _safe_json_load(row.get("grounding_metadata"))
Expand Down Expand Up @@ -147,23 +247,31 @@ def _safe_json_load(val):
)


def _get_state_dict(state_val: Any) -> dict:
def _get_state_dict(state_val: Any) -> dict[str, Any]:
"""Safely load dict from JSON string or return dict if already dict."""
if isinstance(state_val, dict):
return state_val
if isinstance(state_val, str):
try:
return json.loads(state_val)
data = json.loads(state_val)
except json.JSONDecodeError:
logger.warning(
"Failed to parse state JSON string, defaulting to empty dict."
)
return {}
if isinstance(data, dict):
return data
logger.warning("State JSON was not an object, defaulting to empty dict.")
return {}
return {}


# --- Migration Logic ---
def migrate(source_db_url: str, dest_db_url: str):
def migrate(
source_db_url: str,
dest_db_url: str,
allow_unsafe_unpickling: bool = False,
) -> None:
"""Migrates data from old pickle schema to new JSON schema."""
# Convert async driver URLs to sync URLs for SQLAlchemy's synchronous engine.
# This allows users to provide URLs like 'postgresql+asyncpg://...' and have
Expand All @@ -172,6 +280,11 @@ def migrate(source_db_url: str, dest_db_url: str):
dest_sync_url = _schema_check_utils.to_sync_url(dest_db_url)

logger.info(f"Connecting to source database: {source_db_url}")
if allow_unsafe_unpickling:
logger.warning(
"Unsafe pickle migration mode is enabled. Only use this with a trusted"
" source database."
)
try:
source_engine = create_engine(source_sync_url)
SourceSession = sessionmaker(bind=source_engine)
Expand Down Expand Up @@ -265,7 +378,10 @@ def migrate(source_db_url: str, dest_db_url: str):
text("SELECT * FROM events")
).mappings():
try:
event_obj = _row_to_event(dict(row))
event_obj = _row_to_event(
dict(row),
allow_unsafe_unpickling=allow_unsafe_unpickling,
)
new_event = v1.StorageEvent(
id=event_obj.id,
app_name=row["app_name"],
Expand Down Expand Up @@ -309,9 +425,22 @@ def migrate(source_db_url: str, dest_db_url: str):
required=True,
help="SQLAlchemy URL of destination database",
)
parser.add_argument(
"--allow_unsafe_unpickling",
"--allow-unsafe-unpickling",
action="store_true",
help=(
"Allow legacy pickle payloads to use Python's unsafe pickle loader."
" Only use this with a trusted source database."
),
)
args = parser.parse_args()
try:
migrate(args.source_db_url, args.dest_db_url)
migrate(
args.source_db_url,
args.dest_db_url,
allow_unsafe_unpickling=args.allow_unsafe_unpickling,
)
except Exception as e:
logger.error(f"Migration failed: {e}")
sys.exit(1)
18 changes: 16 additions & 2 deletions src/google/adk/sessions/migration/migration_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@
LATEST_VERSION = _schema_check_utils.LATEST_SCHEMA_VERSION


def upgrade(source_db_url: str, dest_db_url: str):
def upgrade(
source_db_url: str,
dest_db_url: str,
allow_unsafe_unpickling: bool = False,
) -> None:
"""Migrates a database from its current version to the latest version.

If the source database schema is older than the latest version, this
Expand All @@ -61,6 +65,9 @@ def upgrade(source_db_url: str, dest_db_url: str):
source_db_url: The SQLAlchemy URL of the database to migrate from.
dest_db_url: The SQLAlchemy URL of the database to migrate to. This must be
different from source_db_url.
allow_unsafe_unpickling: If true, use Python's unsafe pickle loader for the
legacy pickle migration step. Only use this with a trusted source
database.

Raises:
RuntimeError: If source_db_url and dest_db_url are the same, or if no
Expand Down Expand Up @@ -113,7 +120,14 @@ def upgrade(source_db_url: str, dest_db_url: str):
logger.info(
f"Migrating from {in_url} to {out_url} (schema v{end_version})..."
)
migrate_func(in_url, out_url)
if migrate_func is migrate_from_sqlalchemy_pickle.migrate:
migrate_func(
in_url,
out_url,
allow_unsafe_unpickling=allow_unsafe_unpickling,
)
else:
migrate_func(in_url, out_url)
logger.info("Finished migration step to schema %s.", end_version)
# The output of this step becomes the input for the next step.
in_url = out_url
Expand Down
Loading