-
Notifications
You must be signed in to change notification settings - Fork 3.5k
fix(migration): restrict unpickling of v0 actions blobs #5866
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2e188c0
d1b6f27
fc4d23d
2243d7e
ba776c2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,7 @@ | |
| import argparse | ||
| from datetime import datetime | ||
| from datetime import timezone | ||
| import io | ||
| import json | ||
| import logging | ||
| import pickle | ||
|
|
@@ -37,6 +38,93 @@ | |
|
|
||
| logger = logging.getLogger("google_adk." + __name__) | ||
|
|
||
| _ALLOWED_PICKLE_GLOBALS: set[tuple[str, str]] = { | ||
| # Builtin containers/primitives. | ||
| ("builtins", "dict"), | ||
| ("builtins", "list"), | ||
| ("builtins", "set"), | ||
| ("builtins", "tuple"), | ||
| ("builtins", "str"), | ||
| ("builtins", "bytes"), | ||
| ("builtins", "bytearray"), | ||
| ("builtins", "int"), | ||
| ("builtins", "float"), | ||
| ("builtins", "bool"), | ||
| ("datetime", "datetime"), | ||
| ("datetime", "timedelta"), | ||
| ("datetime", "timezone"), | ||
| # Expected pickled payload for v0 session schema events. | ||
| ("fastapi.openapi.models", "APIKey"), | ||
| ("fastapi.openapi.models", "APIKeyIn"), | ||
| ("fastapi.openapi.models", "HTTPBase"), | ||
| ("fastapi.openapi.models", "HTTPBearer"), | ||
| ("fastapi.openapi.models", "OAuth2"), | ||
| ("fastapi.openapi.models", "OAuthFlow"), | ||
| ("fastapi.openapi.models", "OAuthFlowAuthorizationCode"), | ||
| ("fastapi.openapi.models", "OAuthFlowClientCredentials"), | ||
| ("fastapi.openapi.models", "OAuthFlowImplicit"), | ||
| ("fastapi.openapi.models", "OAuthFlowPassword"), | ||
| ("fastapi.openapi.models", "OAuthFlows"), | ||
| ("fastapi.openapi.models", "OpenIdConnect"), | ||
| ("fastapi.openapi.models", "SecurityBase"), | ||
| ("fastapi.openapi.models", "SecurityScheme"), | ||
| ("fastapi.openapi.models", "SecuritySchemeType"), | ||
| ("google.adk.auth.auth_credential", "AuthCredential"), | ||
| ("google.adk.auth.auth_credential", "AuthCredentialTypes"), | ||
| ("google.adk.auth.auth_credential", "HttpAuth"), | ||
| ("google.adk.auth.auth_credential", "HttpCredentials"), | ||
| ("google.adk.auth.auth_credential", "OAuth2Auth"), | ||
| ("google.adk.auth.auth_credential", "ServiceAccountCredential"), | ||
| ("google.adk.auth.auth_schemes", "CustomAuthScheme"), | ||
| ("google.adk.auth.auth_schemes", "ExtendedOAuth2"), | ||
| ("google.adk.auth.auth_schemes", "OAuthGrantType"), | ||
| ("google.adk.auth.auth_schemes", "OpenIdConnectWithConfig"), | ||
| ("google.adk.auth.auth_tool", "AuthConfig"), | ||
| ("google.adk.events.event_actions", "EventActions"), | ||
| ("google.adk.events.event_actions", "EventCompaction"), | ||
| ("google.adk.events.ui_widget", "UiWidget"), | ||
| ("google.adk.tools.tool_confirmation", "ToolConfirmation"), | ||
| ("google.genai.types", "Blob"), | ||
| ("google.genai.types", "CodeExecutionResult"), | ||
| ("google.genai.types", "Content"), | ||
| ("google.genai.types", "ExecutableCode"), | ||
| ("google.genai.types", "FileData"), | ||
| ("google.genai.types", "FunctionCall"), | ||
| ("google.genai.types", "FunctionResponse"), | ||
| ("google.genai.types", "FunctionResponseBlob"), | ||
| ("google.genai.types", "FunctionResponseFileData"), | ||
| ("google.genai.types", "FunctionResponsePart"), | ||
| ("google.genai.types", "Part"), | ||
| ("google.genai.types", "PartMediaResolution"), | ||
| ("google.genai.types", "VideoMetadata"), | ||
| } | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Missing Allowlist Type: The class `EventActions` has a field `render_ui_widgets` which contains a list of `UiWidget` instances. However, `google.adk.events.ui_widget.UiWidget` is not registered in the `_ALLOWED_PICKLE_GLOBALS` allowlist. This means any legacy database events containing UI widgets will fail to unpickle by default and fall back to empty actions, unless the unsafe unpickling flag is set.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks for catching this. I pushed an update, and I also re-ran the local checks after the update.
The PR branch is updated; the new fork workflow runs may still need maintainer approval before GitHub runs them on the latest head. Thanks! |
||
|
|
||
|
|
||
| class _RestrictedUnpickler(pickle.Unpickler): | ||
| """Restricted unpickler for migrating legacy v0 schema actions. | ||
|
|
||
| The v0 session schema stored `EventActions` as a pickled blob. During | ||
| migration we treat the raw bytes read from the source DB as untrusted input | ||
| and only allow the minimum set of safe globals needed to reconstruct | ||
| `EventActions`. | ||
| """ | ||
|
|
||
| def find_class(self, module: str, name: str) -> Any: # noqa: ANN001 | ||
| if (module, name) in _ALLOWED_PICKLE_GLOBALS: | ||
| return super().find_class(module, name) | ||
| raise pickle.UnpicklingError( | ||
| f"Blocked global during migration unpickle: {module}.{name}" | ||
| ) | ||
|
|
||
|
|
||
| def _restricted_pickle_loads( | ||
| data: bytes, *, allow_unsafe_unpickling: bool = False | ||
| ) -> Any: | ||
| """Load a pickle payload using the restricted unpickler by default.""" | ||
| if allow_unsafe_unpickling: | ||
| return pickle.loads(data) | ||
| return _RestrictedUnpickler(io.BytesIO(data)).load() | ||
|
|
||
|
|
||
| def _to_datetime_obj(val: Any) -> datetime | Any: | ||
| """Converts string to datetime if needed.""" | ||
|
|
@@ -51,15 +139,19 @@ def _to_datetime_obj(val: Any) -> datetime | Any: | |
| return val | ||
|
|
||
|
|
||
| def _row_to_event(row: dict) -> Event: | ||
| def _row_to_event( | ||
| row: dict[str, Any], *, allow_unsafe_unpickling: bool = False | ||
| ) -> Event: | ||
| """Converts event row (dict) to event object, handling missing columns and deserializing.""" | ||
|
|
||
| actions_val = row.get("actions") | ||
| actions = None | ||
| if actions_val is not None: | ||
| try: | ||
| if isinstance(actions_val, bytes): | ||
| actions = pickle.loads(actions_val) | ||
| actions = _restricted_pickle_loads( | ||
| actions_val, allow_unsafe_unpickling=allow_unsafe_unpickling | ||
| ) | ||
| else: # for spanner - it might return object directly | ||
| actions = actions_val | ||
| except Exception as e: | ||
|
|
@@ -75,17 +167,25 @@ def _row_to_event(row: dict) -> Event: | |
| else: | ||
| actions = EventActions() | ||
|
|
||
| def _safe_json_load(val): | ||
| data = None | ||
| def _safe_json_load(val: Any) -> dict[str, Any] | None: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The type hint suggests it returns a |
||
| if isinstance(val, str): | ||
| try: | ||
| data = json.loads(val) | ||
| except json.JSONDecodeError: | ||
| logger.warning(f"Failed to decode JSON for event {row.get('id')}") | ||
| return None | ||
| elif isinstance(val, dict): | ||
| data = val # for postgres JSONB | ||
| return data | ||
| return val # for postgres JSONB | ||
| else: | ||
| return None | ||
|
|
||
| if isinstance(data, dict): | ||
| return data | ||
| logger.warning( | ||
| f"Expected JSON object for event {row.get('id')}, got" | ||
| f" {type(data).__name__}." | ||
| ) | ||
| return None | ||
|
|
||
| content_dict = _safe_json_load(row.get("content")) | ||
| grounding_metadata_dict = _safe_json_load(row.get("grounding_metadata")) | ||
|
|
@@ -147,23 +247,31 @@ def _safe_json_load(val): | |
| ) | ||
|
|
||
|
|
||
| def _get_state_dict(state_val: Any) -> dict: | ||
| def _get_state_dict(state_val: Any) -> dict[str, Any]: | ||
| """Safely load dict from JSON string or return dict if already dict.""" | ||
| if isinstance(state_val, dict): | ||
| return state_val | ||
| if isinstance(state_val, str): | ||
| try: | ||
| return json.loads(state_val) | ||
| data = json.loads(state_val) | ||
| except json.JSONDecodeError: | ||
| logger.warning( | ||
| "Failed to parse state JSON string, defaulting to empty dict." | ||
| ) | ||
| return {} | ||
| if isinstance(data, dict): | ||
| return data | ||
| logger.warning("State JSON was not an object, defaulting to empty dict.") | ||
| return {} | ||
| return {} | ||
|
|
||
|
|
||
| # --- Migration Logic --- | ||
| def migrate(source_db_url: str, dest_db_url: str): | ||
| def migrate( | ||
| source_db_url: str, | ||
| dest_db_url: str, | ||
| allow_unsafe_unpickling: bool = False, | ||
| ) -> None: | ||
| """Migrates data from old pickle schema to new JSON schema.""" | ||
| # Convert async driver URLs to sync URLs for SQLAlchemy's synchronous engine. | ||
| # This allows users to provide URLs like 'postgresql+asyncpg://...' and have | ||
|
|
@@ -172,6 +280,11 @@ def migrate(source_db_url: str, dest_db_url: str): | |
| dest_sync_url = _schema_check_utils.to_sync_url(dest_db_url) | ||
|
|
||
| logger.info(f"Connecting to source database: {source_db_url}") | ||
| if allow_unsafe_unpickling: | ||
| logger.warning( | ||
| "Unsafe pickle migration mode is enabled. Only use this with a trusted" | ||
| " source database." | ||
| ) | ||
| try: | ||
| source_engine = create_engine(source_sync_url) | ||
| SourceSession = sessionmaker(bind=source_engine) | ||
|
|
@@ -265,7 +378,10 @@ def migrate(source_db_url: str, dest_db_url: str): | |
| text("SELECT * FROM events") | ||
| ).mappings(): | ||
| try: | ||
| event_obj = _row_to_event(dict(row)) | ||
| event_obj = _row_to_event( | ||
| dict(row), | ||
| allow_unsafe_unpickling=allow_unsafe_unpickling, | ||
| ) | ||
| new_event = v1.StorageEvent( | ||
| id=event_obj.id, | ||
| app_name=row["app_name"], | ||
|
|
@@ -309,9 +425,22 @@ def migrate(source_db_url: str, dest_db_url: str): | |
| required=True, | ||
| help="SQLAlchemy URL of destination database", | ||
| ) | ||
| parser.add_argument( | ||
| "--allow_unsafe_unpickling", | ||
| "--allow-unsafe-unpickling", | ||
| action="store_true", | ||
| help=( | ||
| "Allow legacy pickle payloads to use Python's unsafe pickle loader." | ||
| " Only use this with a trusted source database." | ||
| ), | ||
| ) | ||
| args = parser.parse_args() | ||
| try: | ||
| migrate(args.source_db_url, args.dest_db_url) | ||
| migrate( | ||
| args.source_db_url, | ||
| args.dest_db_url, | ||
| allow_unsafe_unpickling=args.allow_unsafe_unpickling, | ||
| ) | ||
| except Exception as e: | ||
| logger.error(f"Migration failed: {e}") | ||
| sys.exit(1) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we also allow
`datetime.datetime` and `datetime.timezone`? It's quite common for legacy `state_delta` or other `Any` fields in `EventActions` to contain timestamp objects.