From 373a5f8dc3a5981d3b84bd55469c3c0360cabfe9 Mon Sep 17 00:00:00 2001 From: linxiuqiang <15060002560@163.com> Date: Sat, 13 Jun 2026 23:06:56 +0800 Subject: [PATCH] Close servlet streamable HTTP transports on async lifecycle events --- ...vletStreamableServerTransportProvider.java | 66 +++-- ...treamableServerTransportProviderTests.java | 277 ++++++++++++++++++ 2 files changed, 312 insertions(+), 31 deletions(-) create mode 100644 mcp-core/src/test/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProviderTests.java diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java index e6af4fd0f..e6d1ea10a 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java @@ -32,6 +32,8 @@ import io.modelcontextprotocol.json.McpJsonMapper; import io.modelcontextprotocol.util.KeepAliveScheduler; import jakarta.servlet.AsyncContext; +import jakarta.servlet.AsyncEvent; +import jakarta.servlet.AsyncListener; import jakarta.servlet.ServletException; import jakarta.servlet.annotation.WebServlet; import jakarta.servlet.http.HttpServlet; @@ -317,6 +319,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) // Check if this is a replay request if (request.getHeader(HttpHeaders.LAST_EVENT_ID) != null) { String lastId = request.getHeader(HttpHeaders.LAST_EVENT_ID); + registerAsyncLifecycle(asyncContext, sessionId, sessionTransport::close); try { session.replay(lastId) @@ -330,13 +333,13 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) } catch (Exception e) { logger.error("Failed to replay message: {}", e.getMessage()); - asyncContext.complete(); + sessionTransport.close(); } }); } catch (Exception e) { logger.error("Failed to replay messages: {}", e.getMessage()); - asyncContext.complete(); + sessionTransport.close(); } } else { @@ -344,30 +347,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session .listeningStream(sessionTransport); - asyncContext.addListener(new jakarta.servlet.AsyncListener() { - @Override - public void onComplete(jakarta.servlet.AsyncEvent event) throws IOException { - logger.debug("SSE connection completed for session: {}", sessionId); - listeningStream.close(); - } - - @Override - public void onTimeout(jakarta.servlet.AsyncEvent event) throws IOException { - logger.debug("SSE connection timed out for session: {}", sessionId); - listeningStream.close(); - } - - @Override - public void onError(jakarta.servlet.AsyncEvent event) throws IOException { - logger.debug("SSE connection error for session: {}", sessionId); - listeningStream.close(); - } - - @Override - public void onStartAsync(jakarta.servlet.AsyncEvent event) throws IOException { - // No action needed - } - }); + registerAsyncLifecycle(asyncContext, sessionId, listeningStream::close); } } catch (Exception e) { @@ -519,6 +499,7 @@ else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { HttpServletStreamableMcpSessionTransport sessionTransport = new HttpServletStreamableMcpSessionTransport( sessionId, asyncContext, response.getWriter()); + registerAsyncLifecycle(asyncContext, sessionId, sessionTransport::close); try { session.responseStream(jsonrpcRequest, sessionTransport) @@ -527,7 +508,7 @@ else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { } catch (Exception e) { logger.error("Failed to handle request stream: {}", e.getMessage()); - asyncContext.complete(); + sessionTransport.close(); } } else { @@ -557,6 +538,32 @@ else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { } } + private void registerAsyncLifecycle(AsyncContext asyncContext, String sessionId, Runnable onClose) { + asyncContext.addListener(new AsyncListener() { + @Override + public void onComplete(AsyncEvent event) throws IOException { + logger.debug("SSE async context completed for session: {}", sessionId); + onClose.run(); + } + + @Override + public void onTimeout(AsyncEvent event) throws IOException { + logger.debug("SSE async context timed out for session: {}", sessionId); + onClose.run(); + } + + @Override + public void onError(AsyncEvent event) throws IOException { + logger.debug("SSE async context errored for session: {}", sessionId); + onClose.run(); + } + + @Override + public void onStartAsync(AsyncEvent event) throws IOException { + } + }); + } + /** * Handles DELETE requests for session deletion. * @param request The HTTP servlet request @@ -747,8 +754,7 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId } catch (Exception e) { logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage()); - HttpServletStreamableServerTransportProvider.this.sessions.remove(this.sessionId); - this.asyncContext.complete(); + this.close(); } finally { lock.unlock(); @@ -792,8 +798,6 @@ public void close() { } this.closed = true; - - // HttpServletStreamableServerTransportProvider.this.sessions.remove(this.sessionId); this.asyncContext.complete(); logger.debug("Successfully completed async context for session {}", sessionId); } diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProviderTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProviderTests.java new file mode 100644 index 000000000..a7125681f --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProviderTests.java @@ -0,0 +1,277 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.PrintWriter; +import java.io.StringReader; +import java.io.StringWriter; +import java.time.Duration; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; + +import jakarta.servlet.AsyncContext; +import jakarta.servlet.AsyncEvent; +import jakarta.servlet.AsyncListener; +import jakarta.servlet.ServletOutputStream; +import jakarta.servlet.WriteListener; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import reactor.core.publisher.Mono; + +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.server.McpNotificationHandler; +import io.modelcontextprotocol.server.McpRequestHandler; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpStreamableServerSession; +import io.modelcontextprotocol.spec.json.gson.GsonMcpJsonMapper; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +class HttpServletStreamableServerTransportProviderTests { + + private final McpJsonMapper jsonMapper = new GsonMcpJsonMapper(); + + @Test + void getListenerDisconnectClosesListeningStream() throws Exception { + HttpServletStreamableServerTransportProvider provider = HttpServletStreamableServerTransportProvider.builder() + .jsonMapper(this.jsonMapper) + .mcpEndpoint("/mcp") + .build(); + String sessionId = "session-get"; + McpStreamableServerSession session = createSession(sessionId, Map.of(), Map.of()); + provider.setSessionFactory(request -> new McpStreamableServerSession.McpStreamableServerSessionInit(session, + Mono.just(testInitializeResult()))); + initializeSession(provider, sessionId); + + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + AsyncContext asyncContext = mock(AsyncContext.class); + StringWriter output = new StringWriter(); + PrintWriter writer = new PrintWriter(output, true); + + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getHeader("Accept")).thenReturn("text/event-stream"); + when(request.getHeader(HttpHeaders.MCP_SESSION_ID)).thenReturn(sessionId); + when(request.getHeader(HttpHeaders.LAST_EVENT_ID)).thenReturn(null); + when(request.getHeaderNames()).thenReturn(java.util.Collections.emptyEnumeration()); + when(request.startAsync()).thenReturn(asyncContext); + when(response.getWriter()).thenReturn(writer); + + provider.doGet(request, response); + + provider.notifyClient(sessionId, "server/notification", Map.of("connected", true)).block(); + assertThat(output.toString()).contains("server/notification"); + + ArgumentCaptor listenerCaptor = ArgumentCaptor.forClass(AsyncListener.class); + verify(asyncContext).addListener(listenerCaptor.capture()); + + listenerCaptor.getValue().onError(new AsyncEvent(asyncContext)); + + verify(asyncContext).complete(); + output.getBuffer().setLength(0); + + assertThatThrownBy(() -> provider.notifyClient(sessionId, "server/notification", Map.of("after", true)).block()) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining(sessionId); + assertThat(output.toString()).isEmpty(); + } + + @Test + void getReplayRequestRegistersAsyncListenerAndClosesTransportOnDisconnect() throws Exception { + HttpServletStreamableServerTransportProvider provider = HttpServletStreamableServerTransportProvider.builder() + .jsonMapper(this.jsonMapper) + .mcpEndpoint("/mcp") + .build(); + String sessionId = "session-replay"; + McpStreamableServerSession session = createSession(sessionId, Map.of(), Map.of()); + provider.setSessionFactory(request -> new McpStreamableServerSession.McpStreamableServerSessionInit(session, + Mono.just(testInitializeResult()))); + initializeSession(provider, sessionId); + + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + AsyncContext asyncContext = mock(AsyncContext.class); + StringWriter output = new StringWriter(); + PrintWriter writer = new PrintWriter(output, true); + + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getHeader("Accept")).thenReturn("text/event-stream"); + when(request.getHeader(HttpHeaders.MCP_SESSION_ID)).thenReturn(sessionId); + when(request.getHeader(HttpHeaders.LAST_EVENT_ID)).thenReturn("last-1"); + when(request.getHeaderNames()).thenReturn(java.util.Collections.emptyEnumeration()); + when(request.startAsync()).thenReturn(asyncContext); + when(response.getWriter()).thenReturn(writer); + + provider.doGet(request, response); + + ArgumentCaptor listenerCaptor = ArgumentCaptor.forClass(AsyncListener.class); + verify(asyncContext).addListener(listenerCaptor.capture()); + + listenerCaptor.getValue().onError(new AsyncEvent(asyncContext)); + + verify(asyncContext).complete(); + assertThat(output.toString()).isEmpty(); + } + + @Test + void postStreamingRequestRegistersAsyncListenerAndClosesTransportOnDisconnect() throws Exception { + HttpServletStreamableServerTransportProvider provider = HttpServletStreamableServerTransportProvider.builder() + .jsonMapper(this.jsonMapper) + .mcpEndpoint("/mcp") + .build(); + + String sessionId = "session-post"; + McpRequestHandler echoHandler = (exchange, params) -> Mono.just(params); + McpStreamableServerSession session = createSession(sessionId, Map.of("echo", echoHandler), Map.of()); + provider.setSessionFactory(request -> new McpStreamableServerSession.McpStreamableServerSessionInit(session, + Mono.just(testInitializeResult()))); + + initializeSession(provider, sessionId); + + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + AsyncContext asyncContext = mock(AsyncContext.class); + StringWriter output = new StringWriter(); + PrintWriter writer = new PrintWriter(output, true); + + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getHeader("Accept")).thenReturn("text/event-stream, application/json"); + when(request.getHeader(HttpHeaders.MCP_SESSION_ID)).thenReturn(sessionId); + when(request.getHeaderNames()).thenReturn(java.util.Collections.emptyEnumeration()); + when(request.getReader()).thenReturn(new BufferedReader(new StringReader(this.jsonMapper + .writeValueAsString(new McpSchema.JSONRPCRequest("echo", "request-1", Map.of("value", "ok")))))); + when(request.startAsync()).thenReturn(asyncContext); + when(response.getWriter()).thenReturn(writer); + + provider.doPost(request, response); + + assertThat(output.toString()).contains("\"result\":{\"value\":\"ok\"}"); + + ArgumentCaptor listenerCaptor = ArgumentCaptor.forClass(AsyncListener.class); + verify(asyncContext).addListener(listenerCaptor.capture()); + + listenerCaptor.getValue().onError(new AsyncEvent(asyncContext)); + + verify(asyncContext).complete(); + } + + @Test + void sendFailureClosesOnlyCurrentTransportWithoutRemovingSession() throws Exception { + HttpServletStreamableServerTransportProvider provider = HttpServletStreamableServerTransportProvider.builder() + .jsonMapper(this.jsonMapper) + .mcpEndpoint("/mcp") + .build(); + String sessionId = "session-send-failure"; + McpStreamableServerSession session = createSession(sessionId, Map.of(), Map.of()); + provider.setSessionFactory(request -> new McpStreamableServerSession.McpStreamableServerSessionInit(session, + Mono.just(testInitializeResult()))); + initializeSession(provider, sessionId); + + HttpServletRequest getRequest = mock(HttpServletRequest.class); + HttpServletResponse getResponse = mock(HttpServletResponse.class); + AsyncContext getAsyncContext = mock(AsyncContext.class); + PrintWriter failingWriter = new PrintWriter(new FailingServletOutputStream(), true); + + when(getRequest.getRequestURI()).thenReturn("/mcp"); + when(getRequest.getHeader("Accept")).thenReturn("text/event-stream"); + when(getRequest.getHeader(HttpHeaders.MCP_SESSION_ID)).thenReturn(sessionId); + when(getRequest.getHeader(HttpHeaders.LAST_EVENT_ID)).thenReturn(null); + when(getRequest.getHeaderNames()).thenReturn(java.util.Collections.emptyEnumeration()); + when(getRequest.startAsync()).thenReturn(getAsyncContext); + when(getResponse.getWriter()).thenReturn(failingWriter); + + provider.doGet(getRequest, getResponse); + provider.notifyClient(sessionId, "server/notification", Map.of("boom", true)).block(); + + verify(getAsyncContext).complete(); + + HttpServletRequest deleteRequest = mock(HttpServletRequest.class); + HttpServletResponse deleteResponse = mock(HttpServletResponse.class); + when(deleteRequest.getRequestURI()).thenReturn("/mcp"); + when(deleteRequest.getHeader(HttpHeaders.MCP_SESSION_ID)).thenReturn(sessionId); + when(deleteRequest.getHeaderNames()).thenReturn(java.util.Collections.emptyEnumeration()); + + provider.doDelete(deleteRequest, deleteResponse); + + verify(deleteResponse).setStatus(HttpServletResponse.SC_OK); + } + + private void initializeSession(HttpServletStreamableServerTransportProvider provider, String expectedSessionId) + throws Exception { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + StringWriter output = new StringWriter(); + PrintWriter writer = new PrintWriter(output, true); + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getHeader("Accept")).thenReturn("text/event-stream, application/json"); + when(request.getHeaderNames()).thenReturn(java.util.Collections.emptyEnumeration()); + when(request.getReader()).thenReturn(new BufferedReader(new StringReader(this.jsonMapper.writeValueAsString( + new McpSchema.JSONRPCRequest(McpSchema.METHOD_INITIALIZE, "init-1", testInitializeRequest()))))); + when(response.getWriter()).thenReturn(writer); + + doAnswer(invocation -> { + assertThat(invocation.getArgument(1, String.class)).isEqualTo(expectedSessionId); + return null; + }).when(response).setHeader(eq(HttpHeaders.MCP_SESSION_ID), any(String.class)); + + provider.doPost(request, response); + verify(response).setStatus(HttpServletResponse.SC_OK); + } + + private McpStreamableServerSession createSession(String sessionId, + Map> requestHandlers, + Map notificationHandlers) { + return new McpStreamableServerSession(sessionId, testInitializeRequest().capabilities(), + testInitializeRequest().clientInfo(), Duration.ofSeconds(2), requestHandlers, notificationHandlers); + } + + private McpSchema.InitializeRequest testInitializeRequest() { + return McpSchema.InitializeRequest + .builder("2025-11-25", new McpSchema.ClientCapabilities(null, null, null, null), + new McpSchema.Implementation("test-client", "1.0.0")) + .build(); + } + + private McpSchema.InitializeResult testInitializeResult() { + return McpSchema.InitializeResult + .builder("2025-11-25", new McpSchema.ServerCapabilities(null, null, null, null, null, null), + new McpSchema.Implementation("test-server", "1.0.0")) + .build(); + } + + private static final class FailingServletOutputStream extends ServletOutputStream { + + private final AtomicReference failure = new AtomicReference<>( + new IOException("Client disconnected")); + + @Override + public void write(int b) throws IOException { + throw this.failure.get(); + } + + @Override + public boolean isReady() { + return false; + } + + @Override + public void setWriteListener(WriteListener writeListener) { + } + + } + +}