Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
import io.modelcontextprotocol.json.McpJsonMapper;
import io.modelcontextprotocol.util.KeepAliveScheduler;
import jakarta.servlet.AsyncContext;
import jakarta.servlet.AsyncEvent;
import jakarta.servlet.AsyncListener;
import jakarta.servlet.ServletException;
import jakarta.servlet.annotation.WebServlet;
import jakarta.servlet.http.HttpServlet;
Expand Down Expand Up @@ -317,6 +319,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
// Check if this is a replay request
if (request.getHeader(HttpHeaders.LAST_EVENT_ID) != null) {
String lastId = request.getHeader(HttpHeaders.LAST_EVENT_ID);
registerAsyncLifecycle(asyncContext, sessionId, sessionTransport::close);

try {
session.replay(lastId)
Expand All @@ -330,44 +333,21 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
}
catch (Exception e) {
logger.error("Failed to replay message: {}", e.getMessage());
asyncContext.complete();
sessionTransport.close();
}
});
}
catch (Exception e) {
logger.error("Failed to replay messages: {}", e.getMessage());
asyncContext.complete();
sessionTransport.close();
}
}
else {
// Establish new listening stream
McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session
.listeningStream(sessionTransport);

asyncContext.addListener(new jakarta.servlet.AsyncListener() {
@Override
public void onComplete(jakarta.servlet.AsyncEvent event) throws IOException {
logger.debug("SSE connection completed for session: {}", sessionId);
listeningStream.close();
}

@Override
public void onTimeout(jakarta.servlet.AsyncEvent event) throws IOException {
logger.debug("SSE connection timed out for session: {}", sessionId);
listeningStream.close();
}

@Override
public void onError(jakarta.servlet.AsyncEvent event) throws IOException {
logger.debug("SSE connection error for session: {}", sessionId);
listeningStream.close();
}

@Override
public void onStartAsync(jakarta.servlet.AsyncEvent event) throws IOException {
// No action needed
}
});
registerAsyncLifecycle(asyncContext, sessionId, listeningStream::close);
}
}
catch (Exception e) {
Expand Down Expand Up @@ -519,6 +499,7 @@ else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) {

HttpServletStreamableMcpSessionTransport sessionTransport = new HttpServletStreamableMcpSessionTransport(
sessionId, asyncContext, response.getWriter());
registerAsyncLifecycle(asyncContext, sessionId, sessionTransport::close);

try {
session.responseStream(jsonrpcRequest, sessionTransport)
Expand All @@ -527,7 +508,7 @@ else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) {
}
catch (Exception e) {
logger.error("Failed to handle request stream: {}", e.getMessage());
asyncContext.complete();
sessionTransport.close();
}
}
else {
Expand Down Expand Up @@ -557,6 +538,32 @@ else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) {
}
}

private void registerAsyncLifecycle(AsyncContext asyncContext, String sessionId, Runnable onClose) {
asyncContext.addListener(new AsyncListener() {
@Override
public void onComplete(AsyncEvent event) throws IOException {
logger.debug("SSE async context completed for session: {}", sessionId);
onClose.run();
}

@Override
public void onTimeout(AsyncEvent event) throws IOException {
logger.debug("SSE async context timed out for session: {}", sessionId);
onClose.run();
}

@Override
public void onError(AsyncEvent event) throws IOException {
logger.debug("SSE async context errored for session: {}", sessionId);
onClose.run();
}

@Override
public void onStartAsync(AsyncEvent event) throws IOException {
}
});
}

/**
* Handles DELETE requests for session deletion.
* @param request The HTTP servlet request
Expand Down Expand Up @@ -747,8 +754,7 @@ public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message, String messageId
}
catch (Exception e) {
logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage());
HttpServletStreamableServerTransportProvider.this.sessions.remove(this.sessionId);
this.asyncContext.complete();
this.close();
}
finally {
lock.unlock();
Expand Down Expand Up @@ -792,8 +798,6 @@ public void close() {
}

this.closed = true;

// HttpServletStreamableServerTransportProvider.this.sessions.remove(this.sessionId);
this.asyncContext.complete();
logger.debug("Successfully completed async context for session {}", sessionId);
}
Expand Down
Loading