From 25933a83dcf5dbda85684e07bd842b3962c0f122 Mon Sep 17 00:00:00 2001 From: Shivan Date: Fri, 27 Mar 2026 23:07:06 -0700 Subject: [PATCH 1/2] fix(evaluation): Prevent path traversal in local eval managers This commit adds a strict validation regex (^[a-zA-Z0-9_\-\.]+$) and explicit `..` checks for app_name, eval_set_id, eval_case_id, and eval_set_result_id in LocalEvalSetsManager and LocalEvalSetResultsManager. By sanitizing path parameters, this prevents directory traversal attacks when the FastAPI endpoints attempt to read or modify evaluation JSON files on the local filesystem. --- .../evaluation/local_eval_set_results_manager.py | 10 ++++++++++ .../adk/evaluation/local_eval_sets_manager.py | 14 ++++++++------ .../test_local_eval_set_results_manager.py | 8 ++++++++ .../evaluation/test_local_eval_sets_manager.py | 11 ++++++++++- 4 files changed, 36 insertions(+), 7 deletions(-) diff --git a/src/google/adk/evaluation/local_eval_set_results_manager.py b/src/google/adk/evaluation/local_eval_set_results_manager.py index c6da638abe..656d9f411e 100644 --- a/src/google/adk/evaluation/local_eval_set_results_manager.py +++ b/src/google/adk/evaluation/local_eval_set_results_manager.py @@ -16,6 +16,7 @@ import logging import os +import re from typing_extensions import override @@ -67,6 +68,7 @@ def get_eval_set_result( self, app_name: str, eval_set_result_id: str ) -> EvalSetResult: """Returns an EvalSetResult identified by app_name and eval_set_result_id.""" + self._validate_id("Eval Set Result ID", eval_set_result_id) # Load the eval set result file data. maybe_eval_result_file_path = ( os.path.join( @@ -97,4 +99,12 @@ def list_eval_set_results(self, app_name: str) -> list[str]: return eval_result_files def _get_eval_history_dir(self, app_name: str) -> str: + self._validate_id("App Name", app_name) return os.path.join(self._agents_dir, app_name, _ADK_EVAL_HISTORY_DIR) + + def _validate_id(self, id_name: str, id_value: str): + pattern = r"^[a-zA-Z0-9_\-\.]+$" + if not bool(re.fullmatch(pattern, id_value)) or ".." in id_value: + raise ValueError( + f"Invalid {id_name}. {id_name} should have the `{pattern}` format and not contain `..`", + ) diff --git a/src/google/adk/evaluation/local_eval_sets_manager.py b/src/google/adk/evaluation/local_eval_sets_manager.py index 8d2290b911..3f2f0ca77f 100644 --- a/src/google/adk/evaluation/local_eval_sets_manager.py +++ b/src/google/adk/evaluation/local_eval_sets_manager.py @@ -201,7 +201,7 @@ def get_eval_set(self, app_name: str, eval_set_id: str) -> Optional[EvalSet]: try: eval_set_file_path = self._get_eval_set_file_path(app_name, eval_set_id) return load_eval_set_from_file(eval_set_file_path, eval_set_id) - except FileNotFoundError: + except (FileNotFoundError, ValueError): return None @override @@ -211,8 +211,6 @@ def create_eval_set(self, app_name: str, eval_set_id: str) -> EvalSet: Raises: ValueError: If Eval Set ID is not valid or an eval set already exists. """ - self._validate_id(id_name="Eval Set ID", id_value=eval_set_id) - # Define the file path new_eval_set_path = self._get_eval_set_file_path(app_name, eval_set_id) @@ -247,6 +245,7 @@ def list_eval_sets(self, app_name: str) -> list[str]: Raises: NotFoundError: If the eval directory for the app is not found. """ + self._validate_id("App Name", app_name) eval_set_file_path = os.path.join(self._agents_dir, app_name) eval_sets = [] try: @@ -266,6 +265,7 @@ def get_eval_case( self, app_name: str, eval_set_id: str, eval_case_id: str ) -> Optional[EvalCase]: """Returns an EvalCase if found; otherwise, None.""" + self._validate_id("Eval Case ID", eval_case_id) eval_set = self.get_eval_set(app_name, eval_set_id) if not eval_set: return None @@ -310,6 +310,8 @@ def delete_eval_case( self._save_eval_set(app_name, eval_set_id, updated_eval_set) def _get_eval_set_file_path(self, app_name: str, eval_set_id: str) -> str: + self._validate_id("App Name", app_name) + self._validate_id("Eval Set ID", eval_set_id) return os.path.join( self._agents_dir, app_name, @@ -317,10 +319,10 @@ def _get_eval_set_file_path(self, app_name: str, eval_set_id: str) -> str: ) def _validate_id(self, id_name: str, id_value: str): - pattern = r"^[a-zA-Z0-9_]+$" - if not bool(re.fullmatch(pattern, id_value)): + pattern = r"^[a-zA-Z0-9_\-\.]+$" + if not bool(re.fullmatch(pattern, id_value)) or ".." in id_value: raise ValueError( - f"Invalid {id_name}. {id_name} should have the `{pattern}` format", + f"Invalid {id_name}. {id_name} should have the `{pattern}` format and not contain `..`", ) def _write_eval_set_to_path(self, eval_set_path: str, eval_set: EvalSet): diff --git a/tests/unittests/evaluation/test_local_eval_set_results_manager.py b/tests/unittests/evaluation/test_local_eval_set_results_manager.py index 4647392628..5b2c873e29 100644 --- a/tests/unittests/evaluation/test_local_eval_set_results_manager.py +++ b/tests/unittests/evaluation/test_local_eval_set_results_manager.py @@ -174,3 +174,11 @@ def test_list_eval_set_results_empty(self): # No eval set results saved for the app results = self.manager.list_eval_set_results(self.app_name) assert results == [] + + def test_get_eval_history_dir_invalid_app_name(self): + with pytest.raises(ValueError, match="Invalid App Name"): + self.manager.list_eval_set_results("../invalid") + + def test_get_eval_set_result_invalid_id(self): + with pytest.raises(ValueError, match="Invalid Eval Set Result ID"): + self.manager.get_eval_set_result(self.app_name, "../invalid_id") diff --git a/tests/unittests/evaluation/test_local_eval_sets_manager.py b/tests/unittests/evaluation/test_local_eval_sets_manager.py index 3450fb9338..67e089a3db 100644 --- a/tests/unittests/evaluation/test_local_eval_sets_manager.py +++ b/tests/unittests/evaluation/test_local_eval_sets_manager.py @@ -390,11 +390,20 @@ def test_local_eval_sets_manager_create_eval_set_invalid_id( self, local_eval_sets_manager ): app_name = "test_app" - eval_set_id = "invalid-id" + eval_set_id = "invalid/id" with pytest.raises(ValueError, match="Invalid Eval Set ID"): local_eval_sets_manager.create_eval_set(app_name, eval_set_id) + def test_local_eval_sets_manager_create_eval_set_invalid_app_name( + self, local_eval_sets_manager + ): + app_name = "../test_app" + eval_set_id = "test_eval_set" + + with pytest.raises(ValueError, match="Invalid App Name"): + local_eval_sets_manager.create_eval_set(app_name, eval_set_id) + def test_local_eval_sets_manager_create_eval_set_already_exists( self, local_eval_sets_manager, mocker ): From 1658c82783988f8172fd129aa417569ccece02ce Mon Sep 17 00:00:00 2001 From: shivan4030 <9358527+shivan4030@users.noreply.github.com> Date: Sat, 28 Mar 2026 08:54:32 +0000 Subject: [PATCH 2/2] Fix call_tool to use keyword arguments instead of positional args Updated `session.call_tool` invocation in `mcp_tool.py` to use `name=` instead of a positional argument for the tool name. Updated `test_mcp_tool.py` mock assertions accordingly, and removed the TODO comment. --- src/google/adk/tools/mcp_tool/mcp_tool.py | 2 +- tests/unittests/tools/mcp_tool/test_mcp_tool.py | 11 +++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index 9a2fd5fcfd..54f784aa11 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -374,7 +374,7 @@ async def _run_async_impl( resolved_callback = self._resolve_progress_callback(tool_context) response = await session.call_tool( - self._mcp_tool.name, + name=self._mcp_tool.name, arguments=args, progress_callback=resolved_callback, meta=meta_trace_context, diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index d6e39b94f3..9ba56745ab 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -228,9 +228,8 @@ async def test_run_async_impl_no_auth(self): self.mock_session_manager.create_session.assert_called_once_with( headers=None ) - # Fix: call_tool uses 'arguments' parameter, not positional args self.mock_session.call_tool.assert_called_once_with( - "test_tool", arguments=args, progress_callback=None, meta=None + name="test_tool", arguments=args, progress_callback=None, meta=None ) @pytest.mark.asyncio @@ -340,7 +339,7 @@ def inject_context(carrier, context=None) -> None: headers=None ) self.mock_session.call_tool.assert_called_once_with( - "test_tool", + name="test_tool", arguments=args, progress_callback=None, meta={ @@ -870,7 +869,7 @@ async def test_run_async_impl_with_header_provider_no_auth(self): headers=expected_headers ) self.mock_session.call_tool.assert_called_once_with( - "test_tool", arguments=args, progress_callback=None, meta=None + name="test_tool", arguments=args, progress_callback=None, meta=None ) @pytest.mark.asyncio @@ -913,7 +912,7 @@ async def test_run_async_impl_with_header_provider_and_oauth2(self): "X-Tenant-ID": "test-tenant", } self.mock_session.call_tool.assert_called_once_with( - "test_tool", arguments=args, progress_callback=None, meta=None + name="test_tool", arguments=args, progress_callback=None, meta=None ) def test_init_with_progress_callback(self): @@ -967,7 +966,7 @@ async def my_progress_callback( ) # Verify progress_callback was passed to call_tool self.mock_session.call_tool.assert_called_once_with( - "test_tool", + name="test_tool", arguments=args, progress_callback=my_progress_callback, meta=None,