Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/google/adk/evaluation/local_eval_set_results_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import logging
import os
import re

from typing_extensions import override

Expand Down Expand Up @@ -67,6 +68,7 @@ def get_eval_set_result(
self, app_name: str, eval_set_result_id: str
) -> EvalSetResult:
"""Returns an EvalSetResult identified by app_name and eval_set_result_id."""
self._validate_id("Eval Set Result ID", eval_set_result_id)
# Load the eval set result file data.
maybe_eval_result_file_path = (
os.path.join(
Expand Down Expand Up @@ -97,4 +99,12 @@ def list_eval_set_results(self, app_name: str) -> list[str]:
return eval_result_files

def _get_eval_history_dir(self, app_name: str) -> str:
self._validate_id("App Name", app_name)
return os.path.join(self._agents_dir, app_name, _ADK_EVAL_HISTORY_DIR)

def _validate_id(self, id_name: str, id_value: str):
pattern = r"^[a-zA-Z0-9_\-\.]+$"
if not bool(re.fullmatch(pattern, id_value)) or ".." in id_value:
raise ValueError(
f"Invalid {id_name}. {id_name} should have the `{pattern}` format and not contain `..`",
)
14 changes: 8 additions & 6 deletions src/google/adk/evaluation/local_eval_sets_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def get_eval_set(self, app_name: str, eval_set_id: str) -> Optional[EvalSet]:
try:
eval_set_file_path = self._get_eval_set_file_path(app_name, eval_set_id)
return load_eval_set_from_file(eval_set_file_path, eval_set_id)
except FileNotFoundError:
except (FileNotFoundError, ValueError):
return None

@override
Expand All @@ -211,8 +211,6 @@ def create_eval_set(self, app_name: str, eval_set_id: str) -> EvalSet:
Raises:
ValueError: If Eval Set ID is not valid or an eval set already exists.
"""
self._validate_id(id_name="Eval Set ID", id_value=eval_set_id)

# Define the file path
new_eval_set_path = self._get_eval_set_file_path(app_name, eval_set_id)

Expand Down Expand Up @@ -247,6 +245,7 @@ def list_eval_sets(self, app_name: str) -> list[str]:
Raises:
NotFoundError: If the eval directory for the app is not found.
"""
self._validate_id("App Name", app_name)
eval_set_file_path = os.path.join(self._agents_dir, app_name)
eval_sets = []
try:
Expand All @@ -266,6 +265,7 @@ def get_eval_case(
self, app_name: str, eval_set_id: str, eval_case_id: str
) -> Optional[EvalCase]:
"""Returns an EvalCase if found; otherwise, None."""
self._validate_id("Eval Case ID", eval_case_id)
eval_set = self.get_eval_set(app_name, eval_set_id)
if not eval_set:
return None
Expand Down Expand Up @@ -310,17 +310,19 @@ def delete_eval_case(
self._save_eval_set(app_name, eval_set_id, updated_eval_set)

def _get_eval_set_file_path(self, app_name: str, eval_set_id: str) -> str:
self._validate_id("App Name", app_name)
self._validate_id("Eval Set ID", eval_set_id)
return os.path.join(
self._agents_dir,
app_name,
eval_set_id + _EVAL_SET_FILE_EXTENSION,
)

def _validate_id(self, id_name: str, id_value: str):
pattern = r"^[a-zA-Z0-9_]+$"
if not bool(re.fullmatch(pattern, id_value)):
pattern = r"^[a-zA-Z0-9_\-\.]+$"
if not bool(re.fullmatch(pattern, id_value)) or ".." in id_value:
raise ValueError(
f"Invalid {id_name}. {id_name} should have the `{pattern}` format",
f"Invalid {id_name}. {id_name} should have the `{pattern}` format and not contain `..`",
)

def _write_eval_set_to_path(self, eval_set_path: str, eval_set: EvalSet):
Expand Down
58 changes: 58 additions & 0 deletions src/google/adk/sessions/_session_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,71 @@

from __future__ import annotations

import builtins
import pickle
from typing import Any
from typing import Optional
from typing import Type
from typing import TypeVar

from .state import State


class RestrictedUnpickler(pickle.Unpickler):
"""A restricted unpickler that only allows specific safe classes."""

SAFE_BUILTINS = {
"int",
"float",
"complex",
"bool",
"str",
"bytes",
"bytearray",
"list",
"tuple",
"set",
"frozenset",
"dict",
"NoneType",
}

def find_class(self, module, name):
# Only allow safe builtins
if module == "builtins" and name in self.SAFE_BUILTINS:
return getattr(builtins, name)

# Allow datetime classes
if module == "datetime":
import datetime

return getattr(datetime, name)

# Allow uuid classes
if module == "_uuid" or module == "uuid":
import uuid

return getattr(uuid, name)

# Allow google.adk classes
if module.startswith("google.adk."):
return super().find_class(module, name)

# Allow google.genai classes
if module.startswith("google.genai."):
return super().find_class(module, name)

# Allow pydantic classes (needed for EventActions which is a BaseModel)
if module.startswith("pydantic") or module.startswith("pydantic_core"):
return super().find_class(module, name)

# Allow collections
if module == "collections":
return super().find_class(module, name)

# Forbid everything else
raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")

M = TypeVar("M")


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import argparse
from datetime import datetime
from datetime import timezone
import io
import json
import logging
import pickle
Expand Down Expand Up @@ -59,7 +60,7 @@ def _row_to_event(row: dict) -> Event:
if actions_val is not None:
try:
if isinstance(actions_val, bytes):
actions = pickle.loads(actions_val)
actions = _session_util.RestrictedUnpickler(io.BytesIO(actions_val)).load()
else: # for spanner - it might return object directly
actions = actions_val
except Exception as e:
Expand Down
3 changes: 2 additions & 1 deletion src/google/adk/sessions/schemas/v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from datetime import datetime
from datetime import timezone
import io
import json
import pickle
from typing import Any
Expand Down Expand Up @@ -87,7 +88,7 @@ def process_result_value(self, value, dialect):
"""Ensures the raw bytes from the database are unpickled back into a Python object."""
if value is not None:
if dialect.name in ("spanner+spanner", "mysql"):
return pickle.loads(value)
return _session_util.RestrictedUnpickler(io.BytesIO(value)).load()
return value


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,11 @@ def test_list_eval_set_results_empty(self):
# No eval set results saved for the app
results = self.manager.list_eval_set_results(self.app_name)
assert results == []

def test_get_eval_history_dir_invalid_app_name(self):
with pytest.raises(ValueError, match="Invalid App Name"):
self.manager.list_eval_set_results("../invalid")

def test_get_eval_set_result_invalid_id(self):
with pytest.raises(ValueError, match="Invalid Eval Set Result ID"):
self.manager.get_eval_set_result(self.app_name, "../invalid_id")
11 changes: 10 additions & 1 deletion tests/unittests/evaluation/test_local_eval_sets_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,11 +390,20 @@ def test_local_eval_sets_manager_create_eval_set_invalid_id(
self, local_eval_sets_manager
):
app_name = "test_app"
eval_set_id = "invalid-id"
eval_set_id = "invalid/id"

with pytest.raises(ValueError, match="Invalid Eval Set ID"):
local_eval_sets_manager.create_eval_set(app_name, eval_set_id)

def test_local_eval_sets_manager_create_eval_set_invalid_app_name(
self, local_eval_sets_manager
):
app_name = "../test_app"
eval_set_id = "test_eval_set"

with pytest.raises(ValueError, match="Invalid App Name"):
local_eval_sets_manager.create_eval_set(app_name, eval_set_id)

def test_local_eval_sets_manager_create_eval_set_already_exists(
self, local_eval_sets_manager, mocker
):
Expand Down
24 changes: 24 additions & 0 deletions tests/unittests/sessions/test_dynamic_pickle_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import pickle
import os
from unittest import mock

from google.adk.sessions.schemas.v0 import DynamicPickleType
Expand Down Expand Up @@ -179,3 +180,26 @@ def test_roundtrip_pickle_dialects(pickle_type, dialect_name):
# Simulate result (DB -> Python)
result_value = pickle_type.process_result_value(bound_value, mock_dialect)
assert result_value == original_data


class MaliciousObj:
def __reduce__(self):
return (os.system, ('echo "HACKED"',))


@pytest.mark.parametrize(
"dialect_name",
[
pytest.param("mysql", id="mysql"),
pytest.param("spanner+spanner", id="spanner"),
],
)
def test_process_result_value_restricts_unsafe_pickle(pickle_type, dialect_name):
"""Test that DynamicPickleType restricts unsafe unpickling."""
mock_dialect = mock.Mock()
mock_dialect.name = dialect_name

malicious_data = pickle.dumps(MaliciousObj())

with pytest.raises(pickle.UnpicklingError, match="global 'posix.system' is forbidden|global 'nt.system' is forbidden|global 'os.system' is forbidden"):
pickle_type.process_result_value(malicious_data, mock_dialect)