Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions src/google/adk/cli/fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

import click
from fastapi import FastAPI
from fastapi import HTTPException
from fastapi import Request
from fastapi import UploadFile
from fastapi.responses import FileResponse
from fastapi.responses import PlainTextResponse
Expand Down Expand Up @@ -266,6 +268,30 @@ def tear_down_observer(observer: Observer, _: AdkWebServer):
**extra_fast_api_args,
)

_builder_allowed_origins: set[str] = {f"http://{host}:{port}"}
if host in ("0.0.0.0", "127.0.0.1", "localhost", "::1", "::"):
_builder_allowed_origins.update({
f"http://localhost:{port}",
f"http://127.0.0.1:{port}",
})
_builder_origin_check_enabled = True
if allow_origins:
for origin in allow_origins:
if origin == "*":
_builder_origin_check_enabled = False
break
if not origin.startswith("regex:"):
_builder_allowed_origins.add(origin)

def _check_origin(request: Request) -> None:
if not _builder_origin_check_enabled:
return
origin = request.headers.get("origin")
if origin and origin not in _builder_allowed_origins:
raise HTTPException(
status_code=403, detail=f"Origin not allowed: {origin}"
)

agents_base_path = (Path.cwd() / agents_dir).resolve()

def _get_app_root(app_name: str) -> Path:
Expand Down Expand Up @@ -406,8 +432,11 @@ def ensure_tmp_exists(app_name: str) -> bool:

@app.post("/builder/save", response_model_exclude_none=True)
async def builder_build(
files: list[UploadFile], tmp: Optional[bool] = False
request: Request,
files: list[UploadFile],
tmp: Optional[bool] = False,
) -> bool:
_check_origin(request)
try:
if tmp:
app_names = set()
Expand Down Expand Up @@ -472,7 +501,8 @@ async def builder_build(
return False

@app.post("/builder/app/{app_name}/cancel", response_model_exclude_none=True)
async def builder_cancel(app_name: str) -> bool:
async def builder_cancel(request: Request, app_name: str) -> bool:
_check_origin(request)
return cleanup_tmp(app_name)

@app.get(
Expand Down
53 changes: 53 additions & 0 deletions tests/unittests/cli/test_fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1658,6 +1658,59 @@ def test_builder_save_rejects_traversal(builder_test_client, tmp_path):
assert not (tmp_path / "app" / "tmp" / "escape.yaml").exists()


@pytest.fixture
def csrf_test_client(
tmp_path,
mock_session_service,
mock_artifact_service,
mock_memory_service,
mock_agent_loader,
mock_eval_sets_manager,
mock_eval_set_results_manager,
):
return _create_test_client(
mock_session_service,
mock_artifact_service,
mock_memory_service,
mock_agent_loader,
mock_eval_sets_manager,
mock_eval_set_results_manager,
allow_origins=None,
)


def test_builder_save_rejects_cross_origin(csrf_test_client):
response = csrf_test_client.post(
"/builder/save",
files=[
("files", ("app/agent.yaml", b"name: test\n", "application/x-yaml"))
],
headers={"Origin": "http://evil.com"},
)
assert response.status_code == 403


def test_builder_save_allows_same_origin(csrf_test_client):
response = csrf_test_client.post(
"/builder/save",
files=[
("files", ("app/agent.yaml", b"name: test\n", "application/x-yaml"))
],
headers={"Origin": "http://127.0.0.1:8000"},
)
assert response.status_code != 403


def test_builder_save_allows_no_origin(csrf_test_client):
response = csrf_test_client.post(
"/builder/save",
files=[
("files", ("app/agent.yaml", b"name: test\n", "application/x-yaml"))
],
)
assert response.status_code != 403


def test_agent_run_resume_without_message_success(
test_app, create_test_session
):
Expand Down