Skip to content

Commit 6c0ddf9

Browse files
committed
feat(security): add subdomain wildcard support to allowed_hosts and allowed_origins
1 parent 3d7b311 commit 6c0ddf9

2 files changed

Lines changed: 264 additions & 20 deletions

File tree

src/mcp/server/transport_security.py

Lines changed: 46 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""DNS rebinding protection for MCP server transports."""
22

33
import logging
4+
from urllib.parse import urlparse
45

56
from pydantic import BaseModel, Field
67
from starlette.requests import Request
@@ -22,12 +23,24 @@ class TransportSecuritySettings(BaseModel):
2223
allowed_hosts: list[str] = Field(default_factory=list)
2324
"""List of allowed Host header values.
2425
26+
Supports exact matches, port wildcards, and subdomain wildcards:
27+
28+
- ``"example.com"`` — exact match
29+
- ``"example.com:*"`` — any port on that host
30+
- ``"*.example.com"`` — any subdomain (or the base domain itself)
31+
2532
Only applies when `enable_dns_rebinding_protection` is `True`.
2633
"""
2734

2835
allowed_origins: list[str] = Field(default_factory=list)
2936
"""List of allowed Origin header values.
3037
38+
Supports exact matches, port wildcards, and subdomain wildcards:
39+
40+
- ``"https://example.com"`` — exact match
41+
- ``"https://example.com:*"`` — any port on that origin
42+
- ``"https://*.example.com"`` — any subdomain (or the base domain itself) with HTTPS
43+
3144
Only applies when `enable_dns_rebinding_protection` is `True`.
3245
"""
3346

@@ -40,46 +53,61 @@ def __init__(self, settings: TransportSecuritySettings | None = None):
4053
# If not specified, disable DNS rebinding protection by default for backwards compatibility
4154
self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False)
4255

43-
def _validate_host(self, host: str | None) -> bool: # pragma: no cover
56+
def _validate_host(self, host: str | None) -> bool:
4457
"""Validate the Host header against allowed values."""
4558
if not host:
4659
logger.warning("Missing Host header in request")
4760
return False
4861

49-
# Check exact match first
5062
if host in self.settings.allowed_hosts:
5163
return True
5264

53-
# Check wildcard port patterns
65+
# Strip port for subdomain wildcard matching
66+
host_without_port = host.split(":")[0]
67+
5468
for allowed in self.settings.allowed_hosts:
5569
if allowed.endswith(":*"):
56-
# Extract base host from pattern
70+
# Port wildcard: e.g., "example.com:*" matches "example.com:8080"
5771
base_host = allowed[:-2]
58-
# Check if the actual host starts with base host and has a port
5972
if host.startswith(base_host + ":"):
6073
return True
74+
elif allowed.startswith("*."):
75+
# Subdomain wildcard: e.g., "*.example.com" matches "example.com"
76+
# and "sub.example.com" (port is ignored)
77+
suffix = allowed[2:]
78+
if host_without_port == suffix or host_without_port.endswith("." + suffix):
79+
return True
6180

6281
logger.warning(f"Invalid Host header: {host}")
6382
return False
6483

65-
def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover
84+
def _validate_origin(self, origin: str | None) -> bool:
6685
"""Validate the Origin header against allowed values."""
6786
# Origin can be absent for same-origin requests
6887
if not origin:
6988
return True
7089

71-
# Check exact match first
7290
if origin in self.settings.allowed_origins:
7391
return True
7492

75-
# Check wildcard port patterns
7693
for allowed in self.settings.allowed_origins:
7794
if allowed.endswith(":*"):
78-
# Extract base origin from pattern
95+
# Port wildcard: e.g., "https://example.com:*" matches "https://example.com:8080"
7996
base_origin = allowed[:-2]
80-
# Check if the actual origin starts with base origin and has a port
8197
if origin.startswith(base_origin + ":"):
8298
return True
99+
elif "://*." in allowed:
100+
# Subdomain wildcard: e.g., "https://*.example.com" matches
101+
# "https://example.com" and "https://sub.example.com"
102+
parsed_allowed = urlparse(allowed)
103+
parsed_origin = urlparse(origin)
104+
if parsed_allowed.scheme != parsed_origin.scheme:
105+
continue
106+
# hostname is "*.suffix" because "://*." is in the pattern
107+
suffix = (parsed_allowed.hostname or "")[2:]
108+
origin_hostname = parsed_origin.hostname or ""
109+
if origin_hostname == suffix or origin_hostname.endswith("." + suffix):
110+
return True
83111

84112
logger.warning(f"Invalid Origin header: {origin}")
85113
return False
@@ -94,7 +122,7 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res
94122
Returns None if validation passes, or an error Response if validation fails.
95123
"""
96124
# Always validate Content-Type for POST requests
97-
if is_post: # pragma: no branch
125+
if is_post:
98126
content_type = request.headers.get("content-type")
99127
if not self._validate_content_type(content_type):
100128
return Response("Invalid Content-Type header", status_code=400)
@@ -103,14 +131,12 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res
103131
if not self.settings.enable_dns_rebinding_protection:
104132
return None
105133

106-
# Validate Host header # pragma: no cover
107-
host = request.headers.get("host") # pragma: no cover
108-
if not self._validate_host(host): # pragma: no cover
109-
return Response("Invalid Host header", status_code=421) # pragma: no cover
134+
host = request.headers.get("host")
135+
if not self._validate_host(host):
136+
return Response("Invalid Host header", status_code=421)
110137

111-
# Validate Origin header # pragma: no cover
112-
origin = request.headers.get("origin") # pragma: no cover
113-
if not self._validate_origin(origin): # pragma: no cover
114-
return Response("Invalid Origin header", status_code=403) # pragma: no cover
138+
origin = request.headers.get("origin")
139+
if not self._validate_origin(origin):
140+
return Response("Invalid Origin header", status_code=403)
115141

116-
return None # pragma: no cover
142+
return None
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
"""Unit tests for TransportSecurityMiddleware."""
2+
3+
import pytest
4+
from starlette.requests import Request
5+
6+
from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings
7+
8+
9+
def make_request(headers: dict[str, str], method: str = "GET") -> Request:
10+
scope = {
11+
"type": "http",
12+
"method": method,
13+
"path": "/",
14+
"query_string": b"",
15+
"headers": [(k.lower().encode(), v.encode()) for k, v in headers.items()],
16+
}
17+
return Request(scope)
18+
19+
20+
def make_middleware(
21+
*,
22+
allowed_hosts: list[str] | None = None,
23+
allowed_origins: list[str] | None = None,
24+
) -> TransportSecurityMiddleware:
25+
return TransportSecurityMiddleware(
26+
TransportSecuritySettings(
27+
enable_dns_rebinding_protection=True,
28+
allowed_hosts=allowed_hosts or [],
29+
allowed_origins=allowed_origins or [],
30+
)
31+
)
32+
33+
34+
# ---------------------------------------------------------------------------
35+
# _validate_host
36+
# ---------------------------------------------------------------------------
37+
38+
39+
def test_validate_host_missing_header():
40+
mw = make_middleware(allowed_hosts=["example.com"])
41+
assert mw._validate_host(None) is False
42+
43+
44+
def test_validate_host_exact_match():
45+
mw = make_middleware(allowed_hosts=["example.com"])
46+
assert mw._validate_host("example.com") is True
47+
48+
49+
def test_validate_host_no_match():
50+
mw = make_middleware(allowed_hosts=["example.com"])
51+
assert mw._validate_host("evil.com") is False
52+
53+
54+
def test_validate_host_port_wildcard_matches():
55+
mw = make_middleware(allowed_hosts=["example.com:*"])
56+
assert mw._validate_host("example.com:8080") is True
57+
58+
59+
def test_validate_host_port_wildcard_different_host():
60+
mw = make_middleware(allowed_hosts=["example.com:*"])
61+
assert mw._validate_host("evil.com:8080") is False
62+
63+
64+
def test_validate_host_subdomain_wildcard_base_domain():
65+
# "*.example.com" should match the base domain itself
66+
mw = make_middleware(allowed_hosts=["*.example.com"])
67+
assert mw._validate_host("example.com") is True
68+
69+
70+
def test_validate_host_subdomain_wildcard_with_subdomain():
71+
mw = make_middleware(allowed_hosts=["*.example.com"])
72+
assert mw._validate_host("app.example.com") is True
73+
74+
75+
def test_validate_host_subdomain_wildcard_with_nested_subdomain():
76+
mw = make_middleware(allowed_hosts=["*.example.com"])
77+
assert mw._validate_host("api.staging.example.com") is True
78+
79+
80+
def test_validate_host_subdomain_wildcard_with_port():
81+
# Port should be stripped before subdomain matching
82+
mw = make_middleware(allowed_hosts=["*.example.com"])
83+
assert mw._validate_host("app.example.com:443") is True
84+
85+
86+
def test_validate_host_subdomain_wildcard_no_match():
87+
mw = make_middleware(allowed_hosts=["*.example.com"])
88+
assert mw._validate_host("notexample.com") is False
89+
90+
91+
def test_validate_host_subdomain_wildcard_suffix_collision():
92+
# "fakeexample.com" must not match "*.example.com"
93+
mw = make_middleware(allowed_hosts=["*.example.com"])
94+
assert mw._validate_host("fakeexample.com") is False
95+
96+
97+
# ---------------------------------------------------------------------------
98+
# _validate_origin
99+
# ---------------------------------------------------------------------------
100+
101+
102+
def test_validate_origin_absent():
103+
mw = make_middleware(allowed_origins=["https://example.com"])
104+
assert mw._validate_origin(None) is True
105+
106+
107+
def test_validate_origin_exact_match():
108+
mw = make_middleware(allowed_origins=["https://example.com"])
109+
assert mw._validate_origin("https://example.com") is True
110+
111+
112+
def test_validate_origin_no_match():
113+
mw = make_middleware(allowed_origins=["https://example.com"])
114+
assert mw._validate_origin("https://evil.com") is False
115+
116+
117+
def test_validate_origin_port_wildcard_matches():
118+
mw = make_middleware(allowed_origins=["https://example.com:*"])
119+
assert mw._validate_origin("https://example.com:8443") is True
120+
121+
122+
def test_validate_origin_port_wildcard_different_host():
123+
mw = make_middleware(allowed_origins=["https://example.com:*"])
124+
assert mw._validate_origin("https://evil.com:8443") is False
125+
126+
127+
def test_validate_origin_subdomain_wildcard_base_domain():
128+
# "https://*.example.com" should match the base domain itself
129+
mw = make_middleware(allowed_origins=["https://*.example.com"])
130+
assert mw._validate_origin("https://example.com") is True
131+
132+
133+
def test_validate_origin_subdomain_wildcard_with_subdomain():
134+
mw = make_middleware(allowed_origins=["https://*.example.com"])
135+
assert mw._validate_origin("https://app.example.com") is True
136+
137+
138+
def test_validate_origin_subdomain_wildcard_scheme_mismatch():
139+
mw = make_middleware(allowed_origins=["https://*.example.com"])
140+
assert mw._validate_origin("http://app.example.com") is False
141+
142+
143+
def test_validate_origin_subdomain_wildcard_no_match():
144+
mw = make_middleware(allowed_origins=["https://*.example.com"])
145+
assert mw._validate_origin("https://evil.com") is False
146+
147+
148+
# ---------------------------------------------------------------------------
149+
# validate_request (integration over the public method)
150+
# ---------------------------------------------------------------------------
151+
152+
153+
@pytest.mark.anyio
154+
async def test_validate_request_post_invalid_content_type():
155+
mw = make_middleware(allowed_hosts=["example.com"])
156+
req = make_request({"host": "example.com", "content-type": "text/plain"}, method="POST")
157+
resp = await mw.validate_request(req, is_post=True)
158+
assert resp is not None
159+
assert resp.status_code == 400
160+
161+
162+
@pytest.mark.anyio
163+
async def test_validate_request_post_valid_content_type_protection_disabled():
164+
mw = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False))
165+
req = make_request({"host": "example.com", "content-type": "application/json"}, method="POST")
166+
resp = await mw.validate_request(req, is_post=True)
167+
assert resp is None
168+
169+
170+
@pytest.mark.anyio
171+
async def test_validate_request_get_protection_disabled():
172+
mw = TransportSecurityMiddleware(TransportSecuritySettings(enable_dns_rebinding_protection=False))
173+
req = make_request({"host": "evil.com"}, method="GET")
174+
resp = await mw.validate_request(req, is_post=False)
175+
assert resp is None
176+
177+
178+
@pytest.mark.anyio
179+
async def test_validate_request_get_invalid_host():
180+
mw = make_middleware(allowed_hosts=["example.com"])
181+
req = make_request({"host": "evil.com"}, method="GET")
182+
resp = await mw.validate_request(req, is_post=False)
183+
assert resp is not None
184+
assert resp.status_code == 421
185+
186+
187+
@pytest.mark.anyio
188+
async def test_validate_request_post_invalid_host():
189+
mw = make_middleware(allowed_hosts=["example.com"])
190+
req = make_request({"host": "evil.com", "content-type": "application/json"}, method="POST")
191+
resp = await mw.validate_request(req, is_post=True)
192+
assert resp is not None
193+
assert resp.status_code == 421
194+
195+
196+
@pytest.mark.anyio
197+
async def test_validate_request_invalid_origin():
198+
mw = make_middleware(allowed_hosts=["example.com"], allowed_origins=["https://example.com"])
199+
req = make_request({"host": "example.com", "origin": "https://evil.com"}, method="GET")
200+
resp = await mw.validate_request(req, is_post=False)
201+
assert resp is not None
202+
assert resp.status_code == 403
203+
204+
205+
@pytest.mark.anyio
206+
async def test_validate_request_all_valid():
207+
mw = make_middleware(allowed_hosts=["example.com"], allowed_origins=["https://example.com"])
208+
req = make_request({"host": "example.com", "origin": "https://example.com"}, method="GET")
209+
resp = await mw.validate_request(req, is_post=False)
210+
assert resp is None
211+
212+
213+
@pytest.mark.anyio
214+
async def test_validate_request_wildcard_host_end_to_end():
215+
mw = make_middleware(allowed_hosts=["*.example.com"], allowed_origins=["https://*.example.com"])
216+
req = make_request({"host": "api.example.com", "origin": "https://app.example.com"}, method="GET")
217+
resp = await mw.validate_request(req, is_post=False)
218+
assert resp is None

0 commit comments

Comments
 (0)