Skip to content

Commit cee7a22

Browse files
committed
refactor: introduce HeaderAccessor interface for validateHeaders API
Replace Function<String, List<String>> with a dedicated HeaderAccessor interface (getHeader + getHeaderNames) as suggested in review. Transports now pass an HttpServletHeaderAccessor wrapping the request directly, leveraging the servlet's native case-insensitive header lookup instead of extracting all headers into a Map upfront.
1 parent 83c796c commit cee7a22

9 files changed

Lines changed: 198 additions & 100 deletions

mcp-core/src/main/java/io/modelcontextprotocol/server/transport/DefaultServerTransportSecurityValidator.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
import java.util.ArrayList;
88
import java.util.List;
9-
import java.util.function.Function;
109

1110
import io.modelcontextprotocol.util.Assert;
1211

@@ -47,14 +46,14 @@ private DefaultServerTransportSecurityValidator(List<String> allowedOrigins, Lis
4746
}
4847

4948
@Override
50-
public void validateHeaders(Function<String, List<String>> headerAccessor) throws ServerTransportSecurityException {
51-
List<String> originValues = headerAccessor.apply(ORIGIN_HEADER);
49+
public void validateHeaders(HeaderAccessor accessor) throws ServerTransportSecurityException {
50+
List<String> originValues = accessor.getHeader(ORIGIN_HEADER);
5251
if (originValues != null && !originValues.isEmpty()) {
5352
validateOrigin(originValues.get(0));
5453
}
5554

5655
if (!allowedHosts.isEmpty()) {
57-
List<String> hostValues = headerAccessor.apply(HOST_HEADER);
56+
List<String> hostValues = accessor.getHeader(HOST_HEADER);
5857
if (hostValues == null || hostValues.isEmpty()) {
5958
throw new ServerTransportSecurityException(421, "Invalid Host header");
6059
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Copyright 2026-2026 the original author or authors.
3+
*/
4+
5+
package io.modelcontextprotocol.server.transport;
6+
7+
import java.util.List;
8+
9+
/**
10+
* Abstraction for accessing HTTP headers from an incoming request. Implementations should
11+
* provide case-insensitive header name lookups (e.g., when backed by
12+
* {@code HttpServletRequest}).
13+
*
14+
* @author Neeraj Bhatt
15+
* @since 0.16.0
16+
* @see ServerTransportSecurityValidator
17+
*/
18+
public interface HeaderAccessor {
19+
20+
/**
21+
* Returns the values of the specified header, or an empty list if the header is not
22+
* present.
23+
* @param name the header name (case-insensitive)
24+
* @return the list of header values, never {@code null}
25+
*/
26+
List<String> getHeader(String name);
27+
28+
/**
29+
* Returns all header names present in the request.
30+
* @return the list of header names, never {@code null}
31+
*/
32+
List<String> getHeaderNames();
33+
34+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Copyright 2026-2026 the original author or authors.
3+
*/
4+
5+
package io.modelcontextprotocol.server.transport;
6+
7+
import java.util.Collections;
8+
import java.util.List;
9+
10+
import jakarta.servlet.http.HttpServletRequest;
11+
12+
/**
13+
* {@link HeaderAccessor} implementation backed by an {@link HttpServletRequest}. Header
14+
* name lookups are case-insensitive as per the Servlet specification.
15+
*
16+
* <p>
17+
* For internal use only.
18+
*
19+
* @author Neeraj Bhatt
20+
* @since 0.16.0
21+
* @see HeaderAccessor
22+
*/
23+
final class HttpServletHeaderAccessor implements HeaderAccessor {
24+
25+
private final HttpServletRequest request;
26+
27+
HttpServletHeaderAccessor(HttpServletRequest request) {
28+
this.request = request;
29+
}
30+
31+
@Override
32+
public List<String> getHeader(String name) {
33+
return Collections.list(this.request.getHeaders(name));
34+
}
35+
36+
@Override
37+
public List<String> getHeaderNames() {
38+
return Collections.list(this.request.getHeaderNames());
39+
}
40+
41+
}

mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletRequestUtils.java

Lines changed: 0 additions & 40 deletions
This file was deleted.

mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProvider.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -280,8 +280,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
280280
}
281281

282282
try {
283-
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
284-
this.securityValidator.validateHeaders(headers);
283+
this.securityValidator.validateHeaders(new HttpServletHeaderAccessor(request));
285284
}
286285
catch (ServerTransportSecurityException e) {
287286
response.sendError(e.getStatusCode(), e.getMessage());
@@ -353,8 +352,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
353352
}
354353

355354
try {
356-
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
357-
this.securityValidator.validateHeaders(headers);
355+
this.securityValidator.validateHeaders(new HttpServletHeaderAccessor(request));
358356
}
359357
catch (ServerTransportSecurityException e) {
360358
response.sendError(e.getStatusCode(), e.getMessage());

mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStatelessServerTransport.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import java.io.IOException;
99
import java.io.PrintWriter;
1010
import java.util.List;
11-
import java.util.Map;
1211

1312
import org.slf4j.Logger;
1413
import org.slf4j.LoggerFactory;
@@ -134,8 +133,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
134133
}
135134

136135
try {
137-
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
138-
this.securityValidator.validateHeaders(headers);
136+
this.securityValidator.validateHeaders(new HttpServletHeaderAccessor(request));
139137
}
140138
catch (ServerTransportSecurityException e) {
141139
response.sendError(e.getStatusCode(), e.getMessage());

mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import java.time.Duration;
1111
import java.util.ArrayList;
1212
import java.util.List;
13-
import java.util.Map;
1413
import java.util.concurrent.ConcurrentHashMap;
1514
import java.util.concurrent.locks.ReentrantLock;
1615

@@ -271,8 +270,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
271270
}
272271

273272
try {
274-
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
275-
this.securityValidator.validateHeaders(headers);
273+
this.securityValidator.validateHeaders(new HttpServletHeaderAccessor(request));
276274
}
277275
catch (ServerTransportSecurityException e) {
278276
response.sendError(e.getStatusCode(), e.getMessage());
@@ -407,8 +405,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
407405
}
408406

409407
try {
410-
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
411-
this.securityValidator.validateHeaders(headers);
408+
this.securityValidator.validateHeaders(new HttpServletHeaderAccessor(request));
412409
}
413410
catch (ServerTransportSecurityException e) {
414411
response.sendError(e.getStatusCode(), e.getMessage());
@@ -588,8 +585,7 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response
588585
}
589586

590587
try {
591-
Map<String, List<String>> headers = HttpServletRequestUtils.extractHeaders(request);
592-
this.securityValidator.validateHeaders(headers);
588+
this.securityValidator.validateHeaders(new HttpServletHeaderAccessor(request));
593589
}
594590
catch (ServerTransportSecurityException e) {
595591
response.sendError(e.getStatusCode(), e.getMessage());

mcp-core/src/main/java/io/modelcontextprotocol/server/transport/ServerTransportSecurityValidator.java

Lines changed: 53 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,21 @@
44

55
package io.modelcontextprotocol.server.transport;
66

7+
import java.util.ArrayList;
8+
import java.util.Collections;
79
import java.util.List;
810
import java.util.Map;
9-
import java.util.function.Function;
11+
import java.util.stream.Collectors;
1012

1113
/**
1214
* Interface for validating HTTP requests in server transports. Implementations can
1315
* validate Origin headers, Host headers, or any other security-related headers according
1416
* to the MCP specification.
1517
*
1618
* <p>
17-
* New implementations should override {@link #validateHeaders(Function)
18-
* validateHeaders(Function)} for more efficient, case-insensitive header access. The
19-
* older {@link #validateHeaders(Map) validateHeaders(Map)} is deprecated and will be
19+
* New implementations should override {@link #validateHeaders(HeaderAccessor)
20+
* validateHeaders(HeaderAccessor)} for more efficient, case-insensitive header access.
21+
* The older {@link #validateHeaders(Map) validateHeaders(Map)} is deprecated and will be
2022
* removed in a future major version.
2123
*
2224
* @author Daniel Garnier-Moiroux
@@ -29,45 +31,74 @@ public interface ServerTransportSecurityValidator {
2931
* A no-op validator that accepts all requests without validation.
3032
*/
3133
ServerTransportSecurityValidator NOOP = new ServerTransportSecurityValidator() {
34+
@Override
35+
public void validateHeaders(Map<String, List<String>> headers) throws ServerTransportSecurityException {
36+
}
37+
38+
@Override
39+
public void validateHeaders(HeaderAccessor accessor) throws ServerTransportSecurityException {
40+
}
3241
};
3342

3443
/**
3544
* Validates the HTTP headers from an incoming request.
3645
*
3746
* <p>
38-
* The default implementation converts the map into a case-insensitive header accessor
39-
* and delegates to {@link #validateHeaders(Function)}.
47+
* The default implementation converts the map into a {@link HeaderAccessor} and
48+
* delegates to {@link #validateHeaders(HeaderAccessor)}.
4049
* @param headers A map of header names to their values (multi-valued headers
4150
* supported)
4251
* @throws ServerTransportSecurityException if validation fails
43-
* @deprecated Use {@link #validateHeaders(Function)} instead for more efficient,
44-
* case-insensitive header access. This method will be removed in a future major
45-
* version.
52+
* @deprecated Use {@link #validateHeaders(HeaderAccessor)} instead for more
53+
* efficient, case-insensitive header access. This method will be removed in a future
54+
* major version.
4655
*/
4756
@Deprecated
4857
default void validateHeaders(Map<String, List<String>> headers) throws ServerTransportSecurityException {
49-
validateHeaders(name -> headers.entrySet()
50-
.stream()
51-
.filter(e -> e.getKey().equalsIgnoreCase(name))
52-
.map(Map.Entry::getValue)
53-
.findFirst()
54-
.orElse(List.of()));
58+
validateHeaders(new HeaderAccessor() {
59+
@Override
60+
public List<String> getHeader(String name) {
61+
return headers.entrySet()
62+
.stream()
63+
.filter(e -> e.getKey().equalsIgnoreCase(name))
64+
.map(Map.Entry::getValue)
65+
.findFirst()
66+
.orElse(List.of());
67+
}
68+
69+
@Override
70+
public List<String> getHeaderNames() {
71+
return List.copyOf(headers.keySet());
72+
}
73+
});
5574
}
5675

5776
/**
58-
* Validates the HTTP headers from an incoming request using a header accessor
59-
* function.
77+
* Validates the HTTP headers from an incoming request using a {@link HeaderAccessor}.
6078
*
6179
* <p>
6280
* New implementations should override this method. Header name lookup through the
6381
* accessor should be case-insensitive (e.g., when backed by
64-
* {@code HttpServletRequest.getHeaders}).
65-
* @param headerAccessor A function that returns the list of values for a given header
66-
* name, or an empty list if the header is not present.
82+
* {@code HttpServletRequest}).
83+
*
84+
* <p>
85+
* The default implementation collects all headers from the accessor into a
86+
* {@link Map} and delegates to the deprecated {@link #validateHeaders(Map)} method,
87+
* so that existing implementations that only override {@link #validateHeaders(Map)}
88+
* continue to work.
89+
* @param accessor provides access to request headers
6790
* @throws ServerTransportSecurityException if validation fails
6891
*/
69-
default void validateHeaders(Function<String, List<String>> headerAccessor)
70-
throws ServerTransportSecurityException {
92+
default void validateHeaders(HeaderAccessor accessor) throws ServerTransportSecurityException {
93+
var collectedHeaders = accessor.getHeaderNames()
94+
.stream()
95+
.collect(Collectors.<String, String, List<String>>toUnmodifiableMap(String::toLowerCase,
96+
accessor::getHeader, (l1, l2) -> {
97+
var merged = new ArrayList<>(l1);
98+
merged.addAll(l2);
99+
return Collections.unmodifiableList(merged);
100+
}));
101+
validateHeaders(collectedHeaders);
71102
}
72103

73104
}

0 commit comments

Comments
 (0)