Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ eval = [
# go/keep-sorted start
"Jinja2>=3.1.4,<4.0.0", # For eval template rendering
"gepa>=0.1.0",
"google-cloud-aiplatform[evaluation]>=1.100.0",
"google-cloud-aiplatform[evaluation]>=1.143.0",
"pandas>=2.2.3",
"rouge-score>=0.1.2",
"tabulate>=0.9.0",
Expand All @@ -125,7 +125,7 @@ test = [
"kubernetes>=29.0.0", # For GkeCodeExecutor
"langchain-community>=0.3.17",
"langgraph>=0.2.60, <0.4.8", # For LangGraphAgent
"litellm>=1.75.5, <2.0.0", # For LiteLLM tests
"litellm>=1.75.5, <=1.82.6", # For LiteLLM tests. Upper bound pinned: versions 1.82.7+ compromised in supply chain attack.
"llama-index-readers-file>=0.4.0", # For retrieval tests
"openai>=1.100.2", # For LiteLLM
"opentelemetry-instrumentation-google-genai>=0.3b0, <1.0.0",
Expand Down Expand Up @@ -158,7 +158,7 @@ extensions = [
"kubernetes>=29.0.0", # For GkeCodeExecutor
"k8s-agent-sandbox>=0.1.1.post2", # For GkeCodeExecutor sandbox mode
"langgraph>=0.2.60, <0.4.8", # For LangGraphAgent
"litellm>=1.75.5, <2.0.0", # For LiteLlm class. Currently has OpenAI limitations. TODO: once LiteLlm fix it
"litellm>=1.75.5, <=1.82.6", # For LiteLlm class. Upper bound pinned: versions 1.82.7+ compromised in supply chain attack.
"llama-index-readers-file>=0.4.0", # For retrieval using LlamaIndex.
"llama-index-embeddings-google-genai>=0.3.0", # For files retrieval using LlamaIndex.
"lxml>=5.3.0", # For load_web_page tool.
Expand Down
180 changes: 178 additions & 2 deletions src/google/adk/cli/adk_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import json
import logging
import os
import re
import sys
import time
import traceback
Expand Down Expand Up @@ -140,6 +141,158 @@ def _parse_cors_origins(
return literal_origins, combined_regex


def _is_origin_allowed(
origin: str,
allowed_literal_origins: list[str],
allowed_origin_regex: Optional[re.Pattern[str]],
) -> bool:
"""Check whether the given origin matches the allowed origins."""
if "*" in allowed_literal_origins:
return True
if origin in allowed_literal_origins:
return True
if allowed_origin_regex is not None:
return allowed_origin_regex.fullmatch(origin) is not None
return False


def _normalize_origin_scheme(scheme: str) -> str:
"""Normalize request schemes to the browser Origin scheme space."""
if scheme == "ws":
return "http"
if scheme == "wss":
return "https"
return scheme


def _strip_optional_quotes(value: str) -> str:
"""Strip a single pair of wrapping quotes from a header value."""
if len(value) >= 2 and value[0] == '"' and value[-1] == '"':
return value[1:-1]
return value


def _get_scope_header(
scope: dict[str, Any], header_name: bytes
) -> Optional[str]:
"""Return the first matching header value from an ASGI scope."""
for candidate_name, candidate_value in scope.get("headers", []):
if candidate_name == header_name:
return candidate_value.decode("latin-1").split(",", 1)[0].strip()
return None


def _get_request_origin(scope: dict[str, Any]) -> Optional[str]:
"""Compute the effective origin for the current HTTP/WebSocket request."""
forwarded = _get_scope_header(scope, b"forwarded")
if forwarded is not None:
proto = None
host = None
for element in forwarded.split(",", 1)[0].split(";"):
if "=" not in element:
continue
name, value = element.split("=", 1)
if name.strip().lower() == "proto":
proto = _strip_optional_quotes(value.strip())
elif name.strip().lower() == "host":
host = _strip_optional_quotes(value.strip())
if proto is not None and host is not None:
return f"{_normalize_origin_scheme(proto)}://{host}"

host = _get_scope_header(scope, b"x-forwarded-host")
if host is None:
host = _get_scope_header(scope, b"host")
if host is None:
return None

proto = _get_scope_header(scope, b"x-forwarded-proto")
if proto is None:
proto = scope.get("scheme", "http")
return f"{_normalize_origin_scheme(proto)}://{host}"


def _is_request_origin_allowed(
origin: str,
scope: dict[str, Any],
allowed_literal_origins: list[str],
allowed_origin_regex: Optional[re.Pattern[str]],
has_configured_allowed_origins: bool,
) -> bool:
"""Validate an Origin header against explicit config or same-origin."""
if has_configured_allowed_origins and _is_origin_allowed(
origin, allowed_literal_origins, allowed_origin_regex
):
return True

request_origin = _get_request_origin(scope)
if request_origin is None:
return False
return origin == request_origin


_SAFE_HTTP_METHODS = frozenset({"GET", "HEAD", "OPTIONS"})


class _OriginCheckMiddleware:
"""ASGI middleware that blocks cross-origin state-changing requests."""

def __init__(
self,
app: Any,
has_configured_allowed_origins: bool,
allowed_origins: list[str],
allowed_origin_regex: Optional[re.Pattern[str]],
) -> None:
self._app = app
self._has_configured_allowed_origins = has_configured_allowed_origins
self._allowed_origins = allowed_origins
self._allowed_origin_regex = allowed_origin_regex

async def __call__(
self,
scope: dict[str, Any],
receive: Any,
send: Any,
) -> None:
if scope["type"] != "http":
await self._app(scope, receive, send)
return

method = scope.get("method", "GET")
if method in _SAFE_HTTP_METHODS:
await self._app(scope, receive, send)
return

origin = _get_scope_header(scope, b"origin")
if origin is None:
await self._app(scope, receive, send)
return

if _is_request_origin_allowed(
origin,
scope,
self._allowed_origins,
self._allowed_origin_regex,
self._has_configured_allowed_origins,
):
await self._app(scope, receive, send)
return

response_body = b"Forbidden: origin not allowed"
await send({
"type": "http.response.start",
"status": 403,
"headers": [
(b"content-type", b"text/plain"),
(b"content-length", str(len(response_body)).encode()),
],
})
await send({
"type": "http.response.body",
"body": response_body,
})


class ApiServerSpanExporter(export_lib.SpanExporter):

def __init__(self, trace_dict):
Expand Down Expand Up @@ -759,8 +912,12 @@ async def internal_lifespan(app: FastAPI):
# Run the FastAPI server.
app = FastAPI(lifespan=internal_lifespan)

has_configured_allowed_origins = bool(allow_origins)
if allow_origins:
literal_origins, combined_regex = _parse_cors_origins(allow_origins)
compiled_origin_regex = (
re.compile(combined_regex) if combined_regex is not None else None
)
app.add_middleware(
CORSMiddleware,
allow_origins=literal_origins,
Expand All @@ -769,6 +926,16 @@ async def internal_lifespan(app: FastAPI):
allow_methods=["*"],
allow_headers=["*"],
)
else:
literal_origins = []
compiled_origin_regex = None

app.add_middleware(
_OriginCheckMiddleware,
has_configured_allowed_origins=has_configured_allowed_origins,
allowed_origins=literal_origins,
allowed_origin_regex=compiled_origin_regex,
)

@app.get("/health")
async def health() -> dict[str, str]:
Expand Down Expand Up @@ -1755,14 +1922,23 @@ async def run_agent_live(
enable_affective_dialog: bool | None = Query(default=None),
enable_session_resumption: bool | None = Query(default=None),
) -> None:
ws_origin = websocket.headers.get("origin")
if ws_origin is not None and not _is_request_origin_allowed(
ws_origin,
websocket.scope,
literal_origins,
compiled_origin_regex,
has_configured_allowed_origins,
):
await websocket.close(code=1008, reason="Origin not allowed")
return

await websocket.accept()

session = await self.session_service.get_session(
app_name=app_name, user_id=user_id, session_id=session_id
)
if not session:
# Accept first so that the client is aware of connection establishment,
# then close with a specific code.
await websocket.close(code=1002, reason="Session not found")
return

Expand Down
15 changes: 14 additions & 1 deletion src/google/adk/cli/cli_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import sys
import traceback
from typing import Final
from typing import Literal
from typing import Optional
import warnings

Expand Down Expand Up @@ -1162,6 +1163,9 @@ def to_gke(
memory_service_uri: Optional[str] = None,
use_local_storage: bool = False,
a2a: bool = False,
service_type: Literal[
'ClusterIP', 'NodePort', 'LoadBalancer'
] = 'ClusterIP',
):
"""Deploys an agent to Google Kubernetes Engine(GKE).

Expand Down Expand Up @@ -1189,6 +1193,7 @@ def to_gke(
artifact_service_uri: The URI of the artifact service.
memory_service_uri: The URI of the memory service.
use_local_storage: Whether to use local .adk storage in the container.
service_type: The Kubernetes Service type (default: ClusterIP).
"""
click.secho(
'\n🚀 Starting ADK Agent Deployment to GKE...', fg='cyan', bold=True
Expand Down Expand Up @@ -1326,7 +1331,7 @@ def to_gke(
metadata:
name: {service_name}
spec:
type: LoadBalancer
type: {service_type}
selector:
app: {service_name}
ports:
Expand Down Expand Up @@ -1380,3 +1385,11 @@ def to_gke(
click.secho(
'\n🎉 Deployment to GKE finished successfully!', fg='cyan', bold=True
)
if service_type == 'ClusterIP':
click.echo(
'\nThe service is only reachable from within the cluster.'
' To access it locally, run:'
f'\n kubectl port-forward svc/{service_name} {port}:{port}'
'\n\nTo expose the service externally, add a Gateway or'
' re-deploy with --service_type=LoadBalancer.'
)
13 changes: 13 additions & 0 deletions src/google/adk/cli/cli_tools_click.py
Original file line number Diff line number Diff line change
Expand Up @@ -2209,6 +2209,17 @@ def cli_deploy_agent_engine(
default="INFO",
help="Optional. Set the logging level",
)
@click.option(
"--service_type",
type=click.Choice(["ClusterIP", "LoadBalancer"], case_sensitive=True),
default="ClusterIP",
show_default=True,
help=(
"Optional. The Kubernetes Service type for the deployed agent."
" ClusterIP (default) keeps the service cluster-internal;"
" use LoadBalancer to expose a public IP."
),
)
@click.option(
"--temp_folder",
type=str,
Expand Down Expand Up @@ -2252,6 +2263,7 @@ def cli_deploy_gke(
otel_to_cloud: bool,
with_ui: bool,
adk_version: str,
service_type: str,
log_level: Optional[str] = None,
session_service_uri: Optional[str] = None,
artifact_service_uri: Optional[str] = None,
Expand Down Expand Up @@ -2283,6 +2295,7 @@ def cli_deploy_gke(
with_ui=with_ui,
log_level=log_level,
adk_version=adk_version,
service_type=service_type,
session_service_uri=session_service_uri,
artifact_service_uri=artifact_service_uri,
memory_service_uri=memory_service_uri,
Expand Down
Loading