diff --git a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py index a6d1ad2a78..52f0e73d0e 100644 --- a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py +++ b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py @@ -19,6 +19,7 @@ from datetime import datetime from datetime import timezone import json +import io import logging import pickle import sys @@ -38,6 +39,42 @@ logger = logging.getLogger("google_adk." + __name__) +class RestrictedUnpickler(pickle.Unpickler): + """A restricted unpickler that only allows safe classes for events.""" + + def find_class(self, module: str, name: str) -> Any: + # Allow explicit ADK modules needed for unpickling events + safe_modules = { + "google.adk.events.event_actions", + "google.adk.events.ui_widget", + "google.adk.auth.auth_tool", + "google.adk.tools.tool_confirmation", + "google.genai.types", + } + if module in safe_modules: + return super().find_class(module, name) + # Allow safe builtins + if module == "builtins": + safe_builtins = { + "set", + "frozenset", + "dict", + "list", + "tuple", + "bool", + "int", + "float", + "str", + "bytes", + "bytearray", + } + if name in safe_builtins: + import builtins + + return getattr(builtins, name) + raise pickle.UnpicklingError(f"Global '{module}.{name}' is forbidden") + + def _to_datetime_obj(val: Any) -> datetime | Any: """Converts string to datetime if needed.""" if isinstance(val, str): @@ -59,7 +96,7 @@ def _row_to_event(row: dict) -> Event: if actions_val is not None: try: if isinstance(actions_val, bytes): - actions = pickle.loads(actions_val) + actions = RestrictedUnpickler(io.BytesIO(actions_val)).load() else: # for spanner - it might return object directly actions = actions_val except Exception as e: diff --git a/src/google/adk/sessions/schemas/v0.py b/src/google/adk/sessions/schemas/v0.py index 7679a56e5b..5cc0cfaa43 100644 --- a/src/google/adk/sessions/schemas/v0.py +++ b/src/google/adk/sessions/schemas/v0.py @@ -28,6 +28,7 @@ from datetime import datetime from datetime import timezone +import io import json import pickle from typing import Any @@ -62,6 +63,42 @@ from .shared import PreciseTimestamp +class RestrictedUnpickler(pickle.Unpickler): + """A restricted unpickler that only allows safe classes for events.""" + + def find_class(self, module: str, name: str) -> Any: + # Allow explicit ADK modules needed for unpickling events + safe_modules = { + "google.adk.events.event_actions", + "google.adk.events.ui_widget", + "google.adk.auth.auth_tool", + "google.adk.tools.tool_confirmation", + "google.genai.types", + } + if module in safe_modules: + return super().find_class(module, name) + # Allow safe builtins + if module == "builtins": + safe_builtins = { + "set", + "frozenset", + "dict", + "list", + "tuple", + "bool", + "int", + "float", + "str", + "bytes", + "bytearray", + } + if name in safe_builtins: + import builtins + + return getattr(builtins, name) + raise pickle.UnpicklingError(f"Global '{module}.{name}' is forbidden") + + class DynamicPickleType(TypeDecorator): """Represents a type that can be pickled.""" @@ -87,7 +124,7 @@ def process_result_value(self, value, dialect): """Ensures the raw bytes from the database are unpickled back into a Python object.""" if value is not None: if dialect.name in ("spanner+spanner", "mysql"): - return pickle.loads(value) + return RestrictedUnpickler(io.BytesIO(value)).load() return value