From c6069a73c8a8126bb377eed54a793ba991307eaf Mon Sep 17 00:00:00 2001 From: luska Date: Mon, 23 Mar 2026 12:43:32 -0300 Subject: [PATCH] fix(security): validate Origin header on /builder/save endpoint Reject cross-origin POST requests to /builder/save and /builder/app/{app_name}/cancel by checking the Origin header against a set of allowed origins derived from the server's host and port. This prevents CSRF attacks where a malicious website could upload arbitrary agent code to a victim's local ADK instance. --- src/google/adk/cli/fast_api.py | 34 ++++++++++++++++-- tests/unittests/cli/test_fast_api.py | 53 ++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 2 deletions(-) diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 8f78c15f9b..e08b6b33ce 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -27,6 +27,8 @@ import click from fastapi import FastAPI +from fastapi import HTTPException +from fastapi import Request from fastapi import UploadFile from fastapi.responses import FileResponse from fastapi.responses import PlainTextResponse @@ -266,6 +268,30 @@ def tear_down_observer(observer: Observer, _: AdkWebServer): **extra_fast_api_args, ) + _builder_allowed_origins: set[str] = {f"http://{host}:{port}"} + if host in ("0.0.0.0", "127.0.0.1", "localhost", "::1", "::"): + _builder_allowed_origins.update({ + f"http://localhost:{port}", + f"http://127.0.0.1:{port}", + }) + _builder_origin_check_enabled = True + if allow_origins: + for origin in allow_origins: + if origin == "*": + _builder_origin_check_enabled = False + break + if not origin.startswith("regex:"): + _builder_allowed_origins.add(origin) + + def _check_origin(request: Request) -> None: + if not _builder_origin_check_enabled: + return + origin = request.headers.get("origin") + if origin and origin not in _builder_allowed_origins: + raise HTTPException( + status_code=403, detail=f"Origin not allowed: {origin}" + ) + agents_base_path = (Path.cwd() / agents_dir).resolve() def _get_app_root(app_name: str) -> Path: @@ -406,8 +432,11 @@ def ensure_tmp_exists(app_name: str) -> bool: @app.post("/builder/save", response_model_exclude_none=True) async def builder_build( - files: list[UploadFile], tmp: Optional[bool] = False + request: Request, + files: list[UploadFile], + tmp: Optional[bool] = False, ) -> bool: + _check_origin(request) try: if tmp: app_names = set() @@ -472,7 +501,8 @@ async def builder_build( return False @app.post("/builder/app/{app_name}/cancel", response_model_exclude_none=True) - async def builder_cancel(app_name: str) -> bool: + async def builder_cancel(request: Request, app_name: str) -> bool: + _check_origin(request) return cleanup_tmp(app_name) @app.get( diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index 0ea28e6683..b76dceb494 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -1658,6 +1658,59 @@ def test_builder_save_rejects_traversal(builder_test_client, tmp_path): assert not (tmp_path / "app" / "tmp" / "escape.yaml").exists() +@pytest.fixture +def csrf_test_client( + tmp_path, + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, +): + return _create_test_client( + mock_session_service, + mock_artifact_service, + mock_memory_service, + mock_agent_loader, + mock_eval_sets_manager, + mock_eval_set_results_manager, + allow_origins=None, + ) + + +def test_builder_save_rejects_cross_origin(csrf_test_client): + response = csrf_test_client.post( + "/builder/save", + files=[ + ("files", ("app/agent.yaml", b"name: test\n", "application/x-yaml")) + ], + headers={"Origin": "http://evil.com"}, + ) + assert response.status_code == 403 + + +def test_builder_save_allows_same_origin(csrf_test_client): + response = csrf_test_client.post( + "/builder/save", + files=[ + ("files", ("app/agent.yaml", b"name: test\n", "application/x-yaml")) + ], + headers={"Origin": "http://127.0.0.1:8000"}, + ) + assert response.status_code != 403 + + +def test_builder_save_allows_no_origin(csrf_test_client): + response = csrf_test_client.post( + "/builder/save", + files=[ + ("files", ("app/agent.yaml", b"name: test\n", "application/x-yaml")) + ], + ) + assert response.status_code != 403 + + def test_agent_run_resume_without_message_success( test_app, create_test_session ):