Skip to content

Commit 4ab5256

Browse files
author
Neelagiri65
committed
fix: add SSRF protection to RestApiTool._request()
RestApiTool._request() passes URLs straight to httpx with no validation. load_web_page got SSRF protections after #4368 (hostname blocking, DNS pre-resolution, IP range filtering, scheme restriction) but RestApiTool was not updated. This extracts the SSRF validation logic into a shared _ssrf_protection module and applies it in _request() before making HTTP calls. Also sets a finite timeout (30s) instead of None. Blocked: localhost, *.localhost, loopback, link-local (169.254.x.x), private ranges (10.x, 172.16-31.x, 192.168.x), non-http(s) schemes. Related: #4368
1 parent 9670ce2 commit 4ab5256

5 files changed

Lines changed: 611 additions & 82 deletions

File tree

Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Shared SSRF protection helpers for tools that make HTTP requests.
16+
17+
Two layers:
18+
19+
1. ``validate_url`` rejects bad schemes, missing/blocked hostnames, and any
20+
DNS result that includes a non-globally-routable IP. It returns a
21+
``ValidatedTarget`` so callers can use the pre-resolved address list.
22+
23+
2. ``send_pinned_async`` issues an ``httpx`` request against the validated IP
24+
literal directly, preserves the ``Host`` header, and sets the TLS server
25+
name via ``request.extensions["sni_hostname"]``. Together with (1) this
26+
closes the DNS rebinding window between URL validation and connect: even
27+
if the attacker flips the DNS record after validation, the socket goes to
28+
the IP we validated and the cert check uses the original hostname.
29+
30+
The hostname and address helpers defined here are also imported by
31+
``load_web_page``, so that ``requests``-based sync caller shares the same
32+
resolution and blocking rules.
33+
"""
34+
35+
from __future__ import annotations
36+
37+
from dataclasses import dataclass
38+
import ipaddress
39+
import socket
40+
from typing import Any
41+
from typing import Optional
42+
from urllib.parse import ParseResult
43+
from urllib.parse import urlparse
44+
from urllib.parse import urlunparse
45+
46+
import httpx
47+
48+
_ALLOWED_URL_SCHEMES = frozenset({"http", "https"})
49+
_DEFAULT_PORT_BY_SCHEME = {"http": 80, "https": 443}
50+
_ResolvedAddress = ipaddress.IPv4Address | ipaddress.IPv6Address
51+
52+
53+
@dataclass(frozen=True)
class ValidatedTarget:
    """A URL that passed validation, with its resolved addresses cached.

    Instances are built by ``validate_url`` and consumed by
    ``send_pinned_async``. The dataclass is frozen, so the vetted address
    list cannot be swapped out between validation and the request.
    """

    # The URL string exactly as it was handed to ``validate_url``.
    url: str
    # ``urlparse(url)`` result; reused when the host is rewritten to an IP.
    parsed: ParseResult
    # Lower-cased URL scheme.
    scheme: str
    # ``parsed.hostname``; also used as the TLS SNI name when sending.
    hostname: str
    # Value for the ``Host`` header: hostname plus any non-default port.
    host_header: str
    # Addresses the hostname resolved to, none of which were blocked.
    addresses: tuple[_ResolvedAddress, ...]
63+
64+
65+
def _format_host(hostname: str) -> str:
66+
if ":" in hostname:
67+
return f"[{hostname}]"
68+
return hostname
69+
70+
71+
def _build_host_header(
    *,
    hostname: str,
    scheme: str,
    explicit_port: Optional[int],
) -> str:
    """Build the ``Host`` header value for a request.

    Args:
      hostname: Host name or IP literal (IPv6 literals get bracketed).
      scheme: URL scheme; must be a key of ``_DEFAULT_PORT_BY_SCHEME``.
      explicit_port: Port given in the URL, or ``None`` if there was none.

    Returns:
      The formatted host, followed by ``:port`` only when the URL carried a
      port that differs from the scheme's default.

    Raises:
      KeyError: If ``scheme`` has no known default port.
    """
    host = _format_host(hostname)
    scheme_default = _DEFAULT_PORT_BY_SCHEME[scheme]
    needs_port = explicit_port is not None and explicit_port != scheme_default
    return f"{host}:{explicit_port}" if needs_port else host
82+
83+
84+
def is_blocked_hostname(hostname: str) -> bool:
    """Return True for ``localhost`` and any name under ``.localhost``.

    Trailing dots and letter case are ignored before the comparison.
    """
    canonical = hostname.rstrip(".").lower()
    # Only the right-most DNS label decides: "localhost" itself has no dot,
    # and every "*.localhost" name ends in that same label.
    _, _, last_label = canonical.rpartition(".")
    return last_label == "localhost"
88+
89+
90+
def is_blocked_address(address: _ResolvedAddress) -> bool:
    """Return True for any IP an outbound request must not connect to.

    An address is blocked when it is not globally routable or when it is a
    multicast address.

    ``is_global`` is relied on for the private (RFC 1918), loopback,
    link-local (including 169.254.169.254), reserved, shared and unspecified
    ranges across IPv4 and IPv6, which avoids drift between hand maintained
    lists in different tools. Multicast is checked explicitly because
    ``is_global`` is derived from the special-purpose registries and does not
    reliably report multicast ranges (224.0.0.0/4, ff00::/8) as non-global.

    Args:
      address: A parsed IPv4 or IPv6 address.

    Returns:
      True if a request to ``address`` should be refused.
    """
    return address.is_multicast or not address.is_global
99+
100+
101+
def _parse_ip_literal(hostname: str) -> Optional[_ResolvedAddress]:
102+
try:
103+
return ipaddress.ip_address(hostname)
104+
except ValueError:
105+
return None
106+
107+
108+
def resolve_host_addresses(hostname: str) -> tuple[_ResolvedAddress, ...]:
    """Return the IP addresses ``hostname`` resolves to.

    An IP literal is returned as a one-element tuple without touching DNS.
    Otherwise ``socket.getaddrinfo`` is queried and every IPv4/IPv6 result is
    collected, de-duplicated, and returned in the order it was reported.

    Args:
      hostname: DNS name or IP literal.

    Returns:
      A non-empty tuple of parsed addresses.

    Raises:
      ValueError: If resolution fails or yields no IPv4/IPv6 address. The
        same exception type as a bad scheme, so callers need one handler.
    """
    as_literal = _parse_ip_literal(hostname)
    if as_literal is not None:
        return (as_literal,)

    try:
        records = socket.getaddrinfo(
            hostname,
            None,
            type=socket.SOCK_STREAM,
            proto=socket.IPPROTO_TCP,
        )
    except (socket.gaierror, UnicodeError) as exc:
        raise ValueError(f"Unable to resolve host: {hostname}") from exc

    # Each record is (family, type, proto, canonname, sockaddr); the IP text
    # is the first element of sockaddr.
    ip_families = (socket.AF_INET, socket.AF_INET6)
    resolved = [
        ipaddress.ip_address(record[4][0])
        for record in records
        if record[0] in ip_families
    ]
    if not resolved:
        raise ValueError(f"Unable to resolve host: {hostname}")

    # dict.fromkeys drops duplicates but keeps first-seen order, so callers
    # iterating the tuple still try the first reported address first.
    return tuple(dict.fromkeys(resolved))
141+
142+
143+
def validate_url(url: str) -> ValidatedTarget:
    """Check ``url`` against the SSRF rules and resolve its host.

    The checks run in this order: scheme, presence of a hostname, port
    syntax, blocked hostname, then DNS resolution and address filtering. If
    even one resolved address is blocked the whole host is rejected, so a
    mixed record set such as ``[8.8.8.8, 127.0.0.1]`` cannot slip through.

    The resolved addresses are returned so the caller can connect to a vetted
    IP instead of resolving again at connect time, which is what closes the
    DNS rebinding window between validation and the request.

    Args:
      url: Absolute URL to validate.

    Returns:
      A ``ValidatedTarget`` describing the URL and its vetted addresses.

    Raises:
      ValueError: For an unsupported scheme, a missing or blocked hostname,
        an invalid port, a resolution failure, or any blocked address.
    """
    parts = urlparse(url)

    normalized_scheme = parts.scheme.lower()
    if normalized_scheme not in _ALLOWED_URL_SCHEMES:
        raise ValueError(f"Unsupported url scheme: {url}")

    host = parts.hostname
    if not host:
        raise ValueError(f"URL is missing a hostname: {url}")

    # ParseResult.port parses lazily and raises ValueError on a bad port.
    try:
        port = parts.port
    except ValueError as exc:
        raise ValueError(f"Invalid url port: {url}") from exc

    if is_blocked_hostname(host):
        raise ValueError(f"Blocked host: {host}")

    resolved = resolve_host_addresses(host)
    for candidate in resolved:
        if is_blocked_address(candidate):
            raise ValueError(f"Blocked host: {host}")

    header_value = _build_host_header(
        hostname=host,
        scheme=normalized_scheme,
        explicit_port=port,
    )
    return ValidatedTarget(
        url=url,
        parsed=parts,
        scheme=normalized_scheme,
        hostname=host,
        host_header=header_value,
        addresses=resolved,
    )
189+
190+
191+
def rewrite_url_host(parsed: ParseResult, ip: str) -> str:
    """Return the URL in ``parsed`` with its host replaced by ``ip``.

    The netloc is rebuilt from the (bracketed, if IPv6) IP literal plus the
    original explicit port, if any. Nothing else from the old netloc is
    carried over; the remaining URL components are left as they were.
    """
    new_netloc = _format_host(ip)
    explicit_port = parsed.port
    if explicit_port is not None:
        new_netloc = f"{new_netloc}:{explicit_port}"
    return parsed._replace(netloc=new_netloc).geturl()
197+
198+
199+
async def send_pinned_async(
    client: httpx.AsyncClient,
    target: ValidatedTarget,
    **request_params: Any,
) -> httpx.Response:
    """Send a request to ``target`` via ``client`` with the IP pinned.

    The URL is rewritten to use a validated IP literal so the connection
    bypasses DNS at send time. The original hostname is preserved in the
    ``Host`` header (for HTTP routing) and in the ``sni_hostname`` request
    extension (for TLS verification, consumed by ``httpcore``).

    Addresses are tried in order. The next address is tried only when the
    connection itself could not be established (``httpx.ConnectError`` or
    ``httpx.ConnectTimeout``). Any failure after that point propagates
    immediately, because the request may already have reached the server and
    replaying it against another address could repeat a non-idempotent
    operation. All addresses in the tuple have already passed
    ``is_blocked_address``, so this loop never reaches a private IP.

    Args:
      client: The ``httpx.AsyncClient`` used to send the request.
      target: Result of ``validate_url`` for the URL being requested.
      **request_params: Keyword arguments for ``client.request``. Any ``url``
        entry is discarded, and any caller-supplied ``Host`` header is
        replaced, in favour of the values derived from ``target``.

    Returns:
      The response from the first address that accepted the connection.

    Raises:
      ValueError: If ``target.addresses`` is empty.
      httpx.ConnectError, httpx.ConnectTimeout: If no address could be
        connected to; the error from the last attempt is raised.
      httpx.HTTPError: Any other transport error, raised without retrying.
    """

    def _is_host_header(name: Any) -> bool:
        # Header names are case-insensitive and may be given as bytes.
        text = name.decode("latin-1") if isinstance(name, bytes) else str(name)
        return text.lower() == "host"

    # The caller's URL is never sent as-is; only the pinned rewrite is used.
    request_params.pop("url", None)

    # Drop every spelling of a caller-supplied Host header ("host", "HOST",
    # ...) so exactly one Host header, the validated one, goes on the wire.
    caller_headers = dict(request_params.pop("headers", None) or {})
    headers = {
        name: value
        for name, value in caller_headers.items()
        if not _is_host_header(name)
    }
    headers["Host"] = target.host_header

    base_extensions = request_params.pop("extensions", None) or {}
    extensions = {**base_extensions, "sni_hostname": target.hostname}

    last_error: Optional[Exception] = None
    for address in target.addresses:
        pinned_url = rewrite_url_host(target.parsed, str(address))
        try:
            return await client.request(
                url=pinned_url,
                headers=headers,
                extensions=extensions,
                **request_params,
            )
        except (httpx.ConnectError, httpx.ConnectTimeout) as exc:
            # Nothing was sent on this attempt, so another address is safe.
            last_error = exc

    if last_error is None:
        # The loop body never ran. Raise explicitly rather than assert, so
        # the check still holds when Python runs with -O.
        raise ValueError(
            f"No validated addresses to connect to for host: {target.hostname}"
        )
    raise last_error

src/google/adk/tools/load_web_page.py

Lines changed: 8 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
"""Tool for web browse."""
1818

1919
from dataclasses import dataclass
20-
import ipaddress
21-
import socket
2220
from typing import Any
2321
from urllib.parse import ParseResult
2422
from urllib.parse import urlparse
@@ -28,9 +26,14 @@
2826
from requests.utils import get_environ_proxies
2927
from requests.utils import select_proxy
3028

31-
_ALLOWED_URL_SCHEMES = frozenset({'http', 'https'})
32-
_DEFAULT_PORT_BY_SCHEME = {'http': 80, 'https': 443}
33-
_ResolvedAddress = ipaddress.IPv4Address | ipaddress.IPv6Address
29+
from ._ssrf_protection import _ALLOWED_URL_SCHEMES
30+
from ._ssrf_protection import _build_host_header
31+
from ._ssrf_protection import _parse_ip_literal
32+
from ._ssrf_protection import _ResolvedAddress
33+
from ._ssrf_protection import is_blocked_address as _is_blocked_address
34+
from ._ssrf_protection import is_blocked_hostname as _is_blocked_hostname
35+
from ._ssrf_protection import resolve_host_addresses as _resolve_host_addresses
36+
from ._ssrf_protection import rewrite_url_host as _rewrite_url_host
3437

3538

3639
@dataclass(frozen=True)
@@ -96,25 +99,6 @@ def _failed_to_fetch_message(url: str) -> str:
9699
return f'Failed to fetch url: {url}'
97100

98101

99-
def _format_host(hostname: str) -> str:
100-
if ':' in hostname:
101-
return f'[{hostname}]'
102-
return hostname
103-
104-
105-
def _default_port_for_scheme(scheme: str) -> int:
106-
return _DEFAULT_PORT_BY_SCHEME[scheme]
107-
108-
109-
def _build_host_header(
110-
*, hostname: str, scheme: str, explicit_port: int | None
111-
) -> str:
112-
formatted_hostname = _format_host(hostname)
113-
if explicit_port is None or explicit_port == _default_port_for_scheme(scheme):
114-
return formatted_hostname
115-
return f'{formatted_hostname}:{explicit_port}'
116-
117-
118102
def _parse_request_target(url: str) -> _RequestTarget:
119103
parsed_url = urlparse(url)
120104
scheme = parsed_url.scheme.lower()
@@ -142,52 +126,6 @@ def _parse_request_target(url: str) -> _RequestTarget:
142126
)
143127

144128

145-
def _parse_ip_literal(hostname: str) -> _ResolvedAddress | None:
146-
try:
147-
return ipaddress.ip_address(hostname)
148-
except ValueError:
149-
return None
150-
151-
152-
def _is_blocked_hostname(hostname: str) -> bool:
153-
normalized_hostname = hostname.rstrip('.').lower()
154-
return normalized_hostname == 'localhost' or normalized_hostname.endswith(
155-
'.localhost'
156-
)
157-
158-
159-
def _is_blocked_address(address: _ResolvedAddress) -> bool:
160-
return not address.is_global
161-
162-
163-
def _resolve_host_addresses(hostname: str) -> tuple[_ResolvedAddress, ...]:
164-
resolved_address = _parse_ip_literal(hostname)
165-
166-
if resolved_address is not None:
167-
return (resolved_address,)
168-
169-
try:
170-
address_info = socket.getaddrinfo(
171-
hostname,
172-
None,
173-
type=socket.SOCK_STREAM,
174-
proto=socket.IPPROTO_TCP,
175-
)
176-
except (socket.gaierror, UnicodeError) as exc:
177-
raise ValueError(f'Unable to resolve host: {hostname}') from exc
178-
179-
resolved_addresses: list[_ResolvedAddress] = []
180-
for family, _, _, _, sockaddr in address_info:
181-
if family not in (socket.AF_INET, socket.AF_INET6):
182-
continue
183-
resolved_addresses.append(ipaddress.ip_address(sockaddr[0]))
184-
185-
if not resolved_addresses:
186-
raise ValueError(f'Unable to resolve host: {hostname}')
187-
188-
return tuple(resolved_addresses)
189-
190-
191129
def _get_proxy_url(url: str) -> str | None:
192130
proxies = get_environ_proxies(url)
193131
return select_proxy(url, proxies)
@@ -200,16 +138,6 @@ def _resolve_direct_addresses(hostname: str) -> tuple[_ResolvedAddress, ...]:
200138
return resolved_addresses
201139

202140

203-
def _rewrite_url_host(parsed_url: ParseResult, hostname: str) -> str:
204-
explicit_port = parsed_url.port
205-
formatted_hostname = _format_host(hostname)
206-
if explicit_port is None:
207-
rewritten_netloc = formatted_hostname
208-
else:
209-
rewritten_netloc = f'{formatted_hostname}:{explicit_port}'
210-
return parsed_url._replace(netloc=rewritten_netloc).geturl()
211-
212-
213141
def _fetch_direct_response(
214142
*,
215143
url: str,

src/google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -611,9 +611,22 @@ async def _request(
611611
httpx_client_factory: Optional[HttpxClientFactory] = None,
612612
**request_params,
613613
) -> httpx.Response:
614+
# SSRF defence:
615+
# 1. validate_url rejects bad schemes, localhost-style names, and any
616+
# hostname whose DNS records include a non-globally-routable IP.
617+
# 2. send_pinned_async issues the request against the validated IP so the
618+
# socket can't be flipped by a DNS rebinding between this validation
619+
# and the connect that follows. The Host header and TLS SNI keep the
620+
# original hostname so cert verification still works.
621+
from ..._ssrf_protection import send_pinned_async
622+
from ..._ssrf_protection import validate_url
623+
624+
target = validate_url(request_params.get("url", ""))
614625
verify = request_params.pop("verify", True)
626+
615627
if httpx_client_factory is not None:
616628
async with httpx_client_factory() as client:
617-
return await client.request(**request_params)
629+
return await send_pinned_async(client, target, **request_params)
630+
618631
async with httpx.AsyncClient(verify=verify, timeout=None) as client:
619-
return await client.request(**request_params)
632+
return await send_pinned_async(client, target, **request_params)

0 commit comments

Comments
 (0)