From 25933a83dcf5dbda85684e07bd842b3962c0f122 Mon Sep 17 00:00:00 2001 From: Shivan Date: Fri, 27 Mar 2026 23:07:06 -0700 Subject: [PATCH 1/2] fix(evaluation): Prevent path traversal in local eval managers This commit adds a strict validation regex (^[a-zA-Z0-9_\-\.]+$) and explicit `..` checks for app_name, eval_set_id, eval_case_id, and eval_set_result_id in LocalEvalSetsManager and LocalEvalSetResultsManager. By sanitizing path parameters, this prevents directory traversal attacks when the FastAPI endpoints attempt to read or modify evaluation JSON files on the local filesystem. --- .../evaluation/local_eval_set_results_manager.py | 10 ++++++++++ .../adk/evaluation/local_eval_sets_manager.py | 14 ++++++++------ .../test_local_eval_set_results_manager.py | 8 ++++++++ .../evaluation/test_local_eval_sets_manager.py | 11 ++++++++++- 4 files changed, 36 insertions(+), 7 deletions(-) diff --git a/src/google/adk/evaluation/local_eval_set_results_manager.py b/src/google/adk/evaluation/local_eval_set_results_manager.py index c6da638abe..656d9f411e 100644 --- a/src/google/adk/evaluation/local_eval_set_results_manager.py +++ b/src/google/adk/evaluation/local_eval_set_results_manager.py @@ -16,6 +16,7 @@ import logging import os +import re from typing_extensions import override @@ -67,6 +68,7 @@ def get_eval_set_result( self, app_name: str, eval_set_result_id: str ) -> EvalSetResult: """Returns an EvalSetResult identified by app_name and eval_set_result_id.""" + self._validate_id("Eval Set Result ID", eval_set_result_id) # Load the eval set result file data. maybe_eval_result_file_path = ( os.path.join( @@ -97,4 +99,12 @@ def list_eval_set_results(self, app_name: str) -> list[str]: return eval_result_files def _get_eval_history_dir(self, app_name: str) -> str: + self._validate_id("App Name", app_name) return os.path.join(self._agents_dir, app_name, _ADK_EVAL_HISTORY_DIR) + + def _validate_id(self, id_name: str, id_value: str): + pattern = r"^[a-zA-Z0-9_\-\.]+$" + if not bool(re.fullmatch(pattern, id_value)) or ".." in id_value: + raise ValueError( + f"Invalid {id_name}. {id_name} should have the `{pattern}` format and not contain `..`", + ) diff --git a/src/google/adk/evaluation/local_eval_sets_manager.py b/src/google/adk/evaluation/local_eval_sets_manager.py index 8d2290b911..3f2f0ca77f 100644 --- a/src/google/adk/evaluation/local_eval_sets_manager.py +++ b/src/google/adk/evaluation/local_eval_sets_manager.py @@ -201,7 +201,7 @@ def get_eval_set(self, app_name: str, eval_set_id: str) -> Optional[EvalSet]: try: eval_set_file_path = self._get_eval_set_file_path(app_name, eval_set_id) return load_eval_set_from_file(eval_set_file_path, eval_set_id) - except FileNotFoundError: + except (FileNotFoundError, ValueError): return None @override @@ -211,8 +211,6 @@ def create_eval_set(self, app_name: str, eval_set_id: str) -> EvalSet: Raises: ValueError: If Eval Set ID is not valid or an eval set already exists. """ - self._validate_id(id_name="Eval Set ID", id_value=eval_set_id) - # Define the file path new_eval_set_path = self._get_eval_set_file_path(app_name, eval_set_id) @@ -247,6 +245,7 @@ def list_eval_sets(self, app_name: str) -> list[str]: Raises: NotFoundError: If the eval directory for the app is not found. """ + self._validate_id("App Name", app_name) eval_set_file_path = os.path.join(self._agents_dir, app_name) eval_sets = [] try: @@ -266,6 +265,7 @@ def get_eval_case( self, app_name: str, eval_set_id: str, eval_case_id: str ) -> Optional[EvalCase]: """Returns an EvalCase if found; otherwise, None.""" + self._validate_id("Eval Case ID", eval_case_id) eval_set = self.get_eval_set(app_name, eval_set_id) if not eval_set: return None @@ -310,6 +310,8 @@ def delete_eval_case( self._save_eval_set(app_name, eval_set_id, updated_eval_set) def _get_eval_set_file_path(self, app_name: str, eval_set_id: str) -> str: + self._validate_id("App Name", app_name) + self._validate_id("Eval Set ID", eval_set_id) return os.path.join( self._agents_dir, app_name, @@ -317,10 +319,10 @@ def _get_eval_set_file_path(self, app_name: str, eval_set_id: str) -> str: ) def _validate_id(self, id_name: str, id_value: str): - pattern = r"^[a-zA-Z0-9_]+$" - if not bool(re.fullmatch(pattern, id_value)): + pattern = r"^[a-zA-Z0-9_\-\.]+$" + if not bool(re.fullmatch(pattern, id_value)) or ".." in id_value: raise ValueError( - f"Invalid {id_name}. {id_name} should have the `{pattern}` format", + f"Invalid {id_name}. {id_name} should have the `{pattern}` format and not contain `..`", ) def _write_eval_set_to_path(self, eval_set_path: str, eval_set: EvalSet): diff --git a/tests/unittests/evaluation/test_local_eval_set_results_manager.py b/tests/unittests/evaluation/test_local_eval_set_results_manager.py index 4647392628..5b2c873e29 100644 --- a/tests/unittests/evaluation/test_local_eval_set_results_manager.py +++ b/tests/unittests/evaluation/test_local_eval_set_results_manager.py @@ -174,3 +174,11 @@ def test_list_eval_set_results_empty(self): # No eval set results saved for the app results = self.manager.list_eval_set_results(self.app_name) assert results == [] + + def test_get_eval_history_dir_invalid_app_name(self): + with pytest.raises(ValueError, match="Invalid App Name"): + self.manager.list_eval_set_results("../invalid") + + def test_get_eval_set_result_invalid_id(self): + with pytest.raises(ValueError, match="Invalid Eval Set Result ID"): + self.manager.get_eval_set_result(self.app_name, "../invalid_id") diff --git a/tests/unittests/evaluation/test_local_eval_sets_manager.py b/tests/unittests/evaluation/test_local_eval_sets_manager.py index 3450fb9338..67e089a3db 100644 --- a/tests/unittests/evaluation/test_local_eval_sets_manager.py +++ b/tests/unittests/evaluation/test_local_eval_sets_manager.py @@ -390,11 +390,20 @@ def test_local_eval_sets_manager_create_eval_set_invalid_id( self, local_eval_sets_manager ): app_name = "test_app" - eval_set_id = "invalid-id" + eval_set_id = "invalid/id" with pytest.raises(ValueError, match="Invalid Eval Set ID"): local_eval_sets_manager.create_eval_set(app_name, eval_set_id) + def test_local_eval_sets_manager_create_eval_set_invalid_app_name( + self, local_eval_sets_manager + ): + app_name = "../test_app" + eval_set_id = "test_eval_set" + + with pytest.raises(ValueError, match="Invalid App Name"): + local_eval_sets_manager.create_eval_set(app_name, eval_set_id) + def test_local_eval_sets_manager_create_eval_set_already_exists( self, local_eval_sets_manager, mocker ): From 1403b924e08c1eb5110e5a1a5ee3a1348187a03d Mon Sep 17 00:00:00 2001 From: shivan4030 <9358527+shivan4030@users.noreply.github.com> Date: Sat, 28 Mar 2026 09:04:08 +0000 Subject: [PATCH 2/2] fix: use RestrictedUnpickler to prevent insecure deserialization in v0 schema Replaces `pickle.loads` with a custom `RestrictedUnpickler` in `schemas/v0.py` and the SQLAlchemy migration script to prevent arbitrary code execution vulnerabilities when deserializing EventActions objects from the database. Added corresponding tests to verify security controls block malicious unpickling. --- src/google/adk/sessions/_session_util.py | 58 +++++++++++++++++++ .../migrate_from_sqlalchemy_pickle.py | 3 +- src/google/adk/sessions/schemas/v0.py | 3 +- .../sessions/test_dynamic_pickle_type.py | 24 ++++++++ 4 files changed, 86 insertions(+), 2 deletions(-) diff --git a/src/google/adk/sessions/_session_util.py b/src/google/adk/sessions/_session_util.py index 3a92021929..3927223af0 100644 --- a/src/google/adk/sessions/_session_util.py +++ b/src/google/adk/sessions/_session_util.py @@ -15,6 +15,8 @@ from __future__ import annotations +import builtins +import pickle from typing import Any from typing import Optional from typing import Type @@ -22,6 +24,62 @@ from .state import State + +class RestrictedUnpickler(pickle.Unpickler): + """A restricted unpickler that only allows specific safe classes.""" + + SAFE_BUILTINS = { + "int", + "float", + "complex", + "bool", + "str", + "bytes", + "bytearray", + "list", + "tuple", + "set", + "frozenset", + "dict", + "NoneType", + } + + def find_class(self, module, name): + # Only allow safe builtins + if module == "builtins" and name in self.SAFE_BUILTINS: + return getattr(builtins, name) + + # Allow datetime classes + if module == "datetime": + import datetime + + return getattr(datetime, name) + + # Allow uuid classes + if module == "_uuid" or module == "uuid": + import uuid + + return getattr(uuid, name) + + # Allow google.adk classes + if module.startswith("google.adk."): + return super().find_class(module, name) + + # Allow google.genai classes + if module.startswith("google.genai."): + return super().find_class(module, name) + + # Allow pydantic classes (needed for EventActions which is a BaseModel) + if module.startswith("pydantic") or module.startswith("pydantic_core"): + return super().find_class(module, name) + + # Allow collections + if module == "collections": + return super().find_class(module, name) + + # Forbid everything else + raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden") + M = TypeVar("M") diff --git a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py index a6d1ad2a78..cefac12bb0 100644 --- a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py +++ b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py @@ -18,6 +18,7 @@ import argparse from datetime import datetime from datetime import timezone +import io import json import logging import pickle @@ -59,7 +60,7 @@ def _row_to_event(row: dict) -> Event: if actions_val is not None: try: if isinstance(actions_val, bytes): - actions = pickle.loads(actions_val) + actions = _session_util.RestrictedUnpickler(io.BytesIO(actions_val)).load() else: # for spanner - it might return object directly actions = actions_val except Exception as e: diff --git a/src/google/adk/sessions/schemas/v0.py b/src/google/adk/sessions/schemas/v0.py index 7679a56e5b..e920185e09 100644 --- a/src/google/adk/sessions/schemas/v0.py +++ b/src/google/adk/sessions/schemas/v0.py @@ -28,6 +28,7 @@ from datetime import datetime from datetime import timezone +import io import json import pickle from typing import Any @@ -87,7 +88,7 @@ def process_result_value(self, value, dialect): """Ensures the raw bytes from the database are unpickled back into a Python object.""" if value is not None: if dialect.name in ("spanner+spanner", "mysql"): - return pickle.loads(value) + return _session_util.RestrictedUnpickler(io.BytesIO(value)).load() return value diff --git a/tests/unittests/sessions/test_dynamic_pickle_type.py b/tests/unittests/sessions/test_dynamic_pickle_type.py index e1ac56294b..1a0fb44e1a 100644 --- a/tests/unittests/sessions/test_dynamic_pickle_type.py +++ b/tests/unittests/sessions/test_dynamic_pickle_type.py @@ -15,6 +15,7 @@ from __future__ import annotations import pickle +import os from unittest import mock from google.adk.sessions.schemas.v0 import DynamicPickleType @@ -179,3 +180,26 @@ def test_roundtrip_pickle_dialects(pickle_type, dialect_name): # Simulate result (DB -> Python) result_value = pickle_type.process_result_value(bound_value, mock_dialect) assert result_value == original_data + + +class MaliciousObj: + def __reduce__(self): + return (os.system, ('echo "HACKED"',)) + + +@pytest.mark.parametrize( + "dialect_name", + [ + pytest.param("mysql", id="mysql"), + pytest.param("spanner+spanner", id="spanner"), + ], +) +def test_process_result_value_restricts_unsafe_pickle(pickle_type, dialect_name): + """Test that DynamicPickleType restricts unsafe unpickling.""" + mock_dialect = mock.Mock() + mock_dialect.name = dialect_name + + malicious_data = pickle.dumps(MaliciousObj()) + + with pytest.raises(pickle.UnpicklingError, match="global 'posix.system' is forbidden|global 'nt.system' is forbidden|global 'os.system' is forbidden"): + pickle_type.process_result_value(malicious_data, mock_dialect)