Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,005 changes: 1,005 additions & 0 deletions mcp/distributed_test.go

Large diffs are not rendered by default.

711 changes: 711 additions & 0 deletions mcp/distributed_testutil_test.go

Large diffs are not rendered by default.

30 changes: 13 additions & 17 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@
package mcp

import (
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"encoding/gob"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -1039,7 +1037,7 @@ func (ss *ServerSession) initialized(ctx context.Context, params *InitializedPar
params = new(InitializedParams)
}
var wasInit, wasInitd bool
ss.updateState(func(state *ServerSessionState) {
ss.updateState(ctx, func(state *ServerSessionState) {
wasInit = state.InitializeParams != nil
wasInitd = state.InitializedParams != nil
if wasInit && !wasInitd {
Expand Down Expand Up @@ -1113,13 +1111,14 @@ type ServerSession struct {
state ServerSessionState
}

func (ss *ServerSession) updateState(mut func(*ServerSessionState)) {
// FORK: distributed-sessions - added ctx parameter
func (ss *ServerSession) updateState(ctx context.Context, mut func(*ServerSessionState)) {
ss.mu.Lock()
mut(&ss.state)
copy := ss.state
ss.mu.Unlock()
if c, ok := ss.mcpConn.(serverConnection); ok {
c.sessionUpdated(copy)
c.sessionUpdated(ctx, copy)
}
}

Expand Down Expand Up @@ -1457,7 +1456,7 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam
if params == nil {
return nil, fmt.Errorf("%w: \"params\" must be be provided", jsonrpc2.ErrInvalidParams)
}
ss.updateState(func(state *ServerSessionState) {
ss.updateState(ctx, func(state *ServerSessionState) {
state.InitializeParams = params
})

Expand Down Expand Up @@ -1485,8 +1484,8 @@ func (ss *ServerSession) cancel(context.Context, *CancelledParams) (Result, erro
return nil, nil
}

func (ss *ServerSession) setLevel(_ context.Context, params *SetLoggingLevelParams) (*emptyResult, error) {
ss.updateState(func(state *ServerSessionState) {
func (ss *ServerSession) setLevel(ctx context.Context, params *SetLoggingLevelParams) (*emptyResult, error) {
ss.updateState(ctx, func(state *ServerSessionState) {
state.LogLevel = params.Level
})
ss.server.opts.Logger.Info("client log level set", "level", params.Level)
Expand Down Expand Up @@ -1527,21 +1526,20 @@ func (ss *ServerSession) startKeepalive(interval time.Duration) {
}

// pageToken is the internal structure for the opaque pagination cursor.
// It will be Gob-encoded and then Base64-encoded for use as a string token.
// It will be JSON-encoded and then Base64-encoded for use as a string token.
type pageToken struct {
LastUID string // The unique ID of the last resource seen.
LastUID string `json:"last_uid"` // The unique ID of the last resource seen.
}

// encodeCursor encodes a unique identifier (UID) into a opaque pagination cursor
// by serializing a pageToken struct.
func encodeCursor(uid string) (string, error) {
var buf bytes.Buffer
token := pageToken{LastUID: uid}
encoder := gob.NewEncoder(&buf)
if err := encoder.Encode(token); err != nil {
encodedBytes, err := json.Marshal(token)
if err != nil {
return "", fmt.Errorf("failed to encode page token: %w", err)
}
return base64.URLEncoding.EncodeToString(buf.Bytes()), nil
return base64.URLEncoding.EncodeToString(encodedBytes), nil
}

// decodeCursor decodes an opaque pagination cursor into the original pageToken struct.
Expand All @@ -1552,9 +1550,7 @@ func decodeCursor(cursor string) (*pageToken, error) {
}

var token pageToken
buf := bytes.NewBuffer(decodedBytes)
decoder := gob.NewDecoder(buf)
if err := decoder.Decode(&token); err != nil {
if err := json.Unmarshal(decodedBytes, &token); err != nil {
return nil, fmt.Errorf("failed to decode page token: %w, cursor: %v", err, cursor)
}
return &token, nil
Expand Down
119 changes: 119 additions & 0 deletions mcp/session_backend.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.

// FORK: distributed-sessions
// This file implements the SessionBackend interface for distributed
// session management across multiple server replicas.

package mcp

import (
"context"
"errors"
)

// ErrSubscriptionSuperseded is returned by Subscribe when another subscriber
// took over the session's message stream. This typically happens during
// failover when a new pod claims ownership of the session.
var ErrSubscriptionSuperseded = errors.New("mcp: subscription superseded")

// ErrSessionNotFound is returned when a session lookup fails because the
// session does not exist in the backend.
var ErrSessionNotFound = errors.New("mcp: session not found")

// SessionData holds the persistent state for a distributed session.
// This data is stored in the SessionBackend and can be retrieved by any
// server replica.
type SessionData struct {
// SessionID is the unique identifier for this session, generated by
// the SessionBackend.Create method.
SessionID string `json:"sessionId"`

// State contains the MCP protocol state for the session.
State *ServerSessionState `json:"state,omitempty"`

// UserID is the authenticated user ID, used to prevent session hijacking.
// If non-empty, subsequent requests must have the same user ID.
UserID string `json:"userId,omitempty"`
}

// MessageHandler is called by Subscribe for each message that needs to be
// delivered to the session's SSE stream.
//
// Return nil to acknowledge the message (it will be removed from the queue).
// Return an error to signal delivery failure (message may be redelivered).
type MessageHandler func(ctx context.Context, msg []byte) error

// SessionBackend provides session persistence and cross-pod message routing
// for multi-replica MCP server deployments.
//
// Implementations must be safe for concurrent use by multiple goroutines.
//
// The interface combines two responsibilities:
// 1. Session CRUD: Persisting session metadata so any replica can handle requests
// 2. Message routing: Delivering messages from non-owner pods to the SSE-owner pod
//
// Example flow for a distributed deployment:
// 1. Client connects to Pod A, which calls Create() and stores session
// 2. Client's SSE stream connects to Pod A, which calls Subscribe()
// 3. Client's POST request hits Pod B (via load balancer)
// 4. Pod B calls Get() to find the session, then Publish() to route the response
// 5. Pod A's Subscribe handler receives the message and writes to SSE
type SessionBackend interface {
// Create persists a new session and returns its unique ID.
// The implementation generates the session ID (e.g., UUID, ULID).
//
// The returned ID must be globally unique across all sessions.
Create(ctx context.Context, data *SessionData) (string, error)

// Get retrieves session data by ID.
// Returns ErrSessionNotFound if the session does not exist.
Get(ctx context.Context, id string) (*SessionData, error)

// Update persists changes to an existing session.
// The SessionID field in data must match id.
// Returns ErrSessionNotFound if the session does not exist.
Update(ctx context.Context, id string, data *SessionData) error

// Delete removes a session from the backend.
// This should also clean up any associated message queues.
// Returns nil if the session does not exist (idempotent).
Delete(ctx context.Context, id string) error

// Touch updates the session's last activity timestamp without
// modifying other data. Called on each POST request to signal activity.
//
// Backend implementations SHOULD use this to extend the session's TTL.
// For example, a Redis implementation might call EXPIRE to reset the TTL.
// This is the primary mechanism for distributed session timeout management,
// since local timers are lost on pod restarts.
//
// Returns ErrSessionNotFound if the session does not exist.
Touch(ctx context.Context, id string) error

// Publish sends a message to the session's message queue.
// The message will be delivered to the Subscribe handler on the
// pod that owns the session's SSE stream.
//
// Messages are delivered in FIFO order per session.
Publish(ctx context.Context, sessionID string, msg []byte) error

// Subscribe starts receiving messages for a session.
//
// The handler is called for each message in order. Subscribe blocks
// until one of:
// - ctx is cancelled: returns ctx.Err()
// - handler returns an error: returns that error (message not acked)
// - another subscriber takes over: returns ErrSubscriptionSuperseded
// - session is deleted: returns ErrSessionNotFound
//
// When handler returns nil, the message is acknowledged and removed
// from the queue. When handler returns an error, the message may be
// redelivered to another subscriber.
//
// Only one subscriber should be active per session at a time.
// Implementations may enforce this by returning ErrSubscriptionSuperseded
// to the previous subscriber when a new one connects.
Subscribe(ctx context.Context, sessionID string, handler MessageHandler) error
}
Loading
Loading