From 9f16fd3d713323086eb0dfd33c45c53bcec0fb74 Mon Sep 17 00:00:00 2001 From: Boris Tyshkevich Date: Thu, 14 May 2026 18:24:49 +0200 Subject: [PATCH 1/2] oauth: stateless JWE for pending-auth + auth-code (HA-safe) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the per-pod oauthStateStore (in-memory maps for pending-auth and issued auth codes) with stateless JWE tokens, using the existing encodeOAuthJWE/decodeOAuthJWE + HKDF infrastructure already proven for DCR client_id and refresh tokens. Why: in forward mode altinity-mcp is the OAuth AS — /.well-known points clients at MCP's own /authorize and /token, so with replicas>=2 and no sticky sessions the legs of the OAuth dance land on different pods and the in-memory state lookup fails (~75% of the time). Encoding the state into the Google `state` parameter and the MCP auth `code` makes any replica with the shared signing_secret able to decrypt either side. Single-use enforcement on auth codes is intentionally not done server- side: codes are bound to the client's PKCE verifier (RFC 7636) and live 60s, so replay within the TTL is limited to whoever holds the verifier. Trading strict RFC 6749 §4.1.2 single-use for zero shared state. New HKDF labels: altinity-mcp/oauth/pending-auth/v1 altinity-mcp/oauth/auth-code/v1 Whitelist additions in jwe_auth: resource, upstream_pkce_verifier. Removed: oauthStateStore, its mutex, eviction logic, maxOAuthStateEntries, randomToken, application.oauthState/oauthStateMu fields, getOAuthStateStore. Replaced TestOAuthStateStore*/TestOAuthStateStoreEviction with TestOAuthStateJWERoundTrip covering round-trip, cross-pod portability, mismatched-secret rejection, expiry, tamper, and missing secret. Affects forward-mode deployments only (antalya, billing, otel-google). Gating-mode (otel, github via Auth0 CIMD) was already HA-safe — Auth0 owns the OAuth surface there. --- cmd/altinity-mcp/main.go | 12 -- cmd/altinity-mcp/main_test.go | 33 --- cmd/altinity-mcp/oauth_server.go | 278 +++++++++++++++----------- cmd/altinity-mcp/oauth_server_test.go | 248 ++++++++++++----------- pkg/jwe_auth/jwe_auth.go | 2 + 5 files changed, 298 insertions(+), 275 deletions(-) diff --git a/cmd/altinity-mcp/main.go b/cmd/altinity-mcp/main.go index 2a21c27..2fdfd9d 100644 --- a/cmd/altinity-mcp/main.go +++ b/cmd/altinity-mcp/main.go @@ -964,8 +964,6 @@ type application struct { mcpServer *altinitymcp.ClickHouseJWEServer httpSrv *http.Server httpSrvMutex sync.RWMutex - oauthState *oauthStateStore - oauthStateMu sync.Mutex configFile string configMutex sync.RWMutex stopConfigReload chan struct{} @@ -985,15 +983,6 @@ func (a *application) getHTTPServer() *http.Server { return a.httpSrv } -func (a *application) getOAuthStateStore() *oauthStateStore { - a.oauthStateMu.Lock() - defer a.oauthStateMu.Unlock() - if a.oauthState == nil { - a.oauthState = newOAuthStateStore() - } - return a.oauthState -} - func newApplication(ctx context.Context, cfg config.Config, cmd CommandInterface) (*application, error) { if err := validateOAuthRuntimeConfig(cfg); err != nil { return nil, err @@ -1052,7 +1041,6 @@ func newApplication(ctx context.Context, cfg config.Config, cmd CommandInterface app := &application{ config: cfg, mcpServer: mcpServer, - oauthState: newOAuthStateStore(), configFile: cmd.String("config"), stopConfigReload: make(chan struct{}), } diff --git a/cmd/altinity-mcp/main_test.go b/cmd/altinity-mcp/main_test.go index 709c327..d3d37ff 100644 --- a/cmd/altinity-mcp/main_test.go +++ b/cmd/altinity-mcp/main_test.go @@ -3908,39 +3908,6 @@ func TestHealthHandler_CHUnavailable(t *testing.T) { require.Equal(t, "unhealthy", body["status"]) } -func TestOAuthStateStoreEviction(t *testing.T) { - t.Parallel() - store := newOAuthStateStore() - - // Fill pending auth to capacity - for i := 0; i < maxOAuthStateEntries; i++ { - store.putPendingAuth(fmt.Sprintf("pending-%d", i), oauthPendingAuth{ - ExpiresAt: time.Now().Add(time.Duration(i) * time.Second), - }) - } - - // Adding one more should evict the oldest - store.putPendingAuth("new-pending", oauthPendingAuth{ExpiresAt: time.Now().Add(time.Hour)}) - _, ok := store.consumePendingAuth("pending-0") // oldest should be evicted - require.False(t, ok) - _, ok = store.consumePendingAuth("new-pending") - require.True(t, ok) - - // Fill auth codes to capacity - for i := 0; i < maxOAuthStateEntries; i++ { - store.putAuthCode(fmt.Sprintf("code-%d", i), oauthIssuedCode{ - ExpiresAt: time.Now().Add(time.Duration(i) * time.Second), - }) - } - - // Adding one more should evict the oldest - store.putAuthCode("new-code", oauthIssuedCode{ExpiresAt: time.Now().Add(time.Hour)}) - _, ok = store.consumeAuthCode("code-0") - require.False(t, ok) - _, ok = store.consumeAuthCode("new-code") - require.True(t, ok) -} - func TestToolInputSettingsCLIFlag(t *testing.T) { cases := []struct { name string diff --git a/cmd/altinity-mcp/oauth_server.go b/cmd/altinity-mcp/oauth_server.go index c6edd24..c1c663b 100644 --- a/cmd/altinity-mcp/oauth_server.go +++ b/cmd/altinity-mcp/oauth_server.go @@ -15,7 +15,6 @@ import ( "net/url" "slices" "strings" - "sync" "time" "github.com/altinity/altinity-mcp/pkg/config" @@ -100,107 +99,15 @@ type oauthIssuedCode struct { AccessTokenExpiry time.Time } -// maxOAuthStateEntries caps each map in the state store to prevent memory -// exhaustion from floods of unauthenticated /oauth/authorize requests. -const maxOAuthStateEntries = 10000 - -type oauthStateStore struct { - mu sync.Mutex - pendingAuth map[string]oauthPendingAuth - authCodes map[string]oauthIssuedCode -} - -func newOAuthStateStore() *oauthStateStore { - return &oauthStateStore{ - pendingAuth: make(map[string]oauthPendingAuth), - authCodes: make(map[string]oauthIssuedCode), - } -} - -func (s *oauthStateStore) cleanupExpiredLocked(now time.Time) { - for key, pending := range s.pendingAuth { - if !pending.ExpiresAt.IsZero() && now.After(pending.ExpiresAt) { - delete(s.pendingAuth, key) - } - } - for key, issued := range s.authCodes { - if !issued.ExpiresAt.IsZero() && now.After(issued.ExpiresAt) { - delete(s.authCodes, key) - } - } -} - -// evictOldestPendingLocked removes the entry with the earliest expiry. -func (s *oauthStateStore) evictOldestPendingLocked() { - var oldestKey string - var oldestTime time.Time - for key, pending := range s.pendingAuth { - if oldestKey == "" || pending.ExpiresAt.Before(oldestTime) { - oldestKey = key - oldestTime = pending.ExpiresAt - } - } - if oldestKey != "" { - delete(s.pendingAuth, oldestKey) - } -} - -// evictOldestCodeLocked removes the entry with the earliest expiry. -func (s *oauthStateStore) evictOldestCodeLocked() { - var oldestKey string - var oldestTime time.Time - for key, issued := range s.authCodes { - if oldestKey == "" || issued.ExpiresAt.Before(oldestTime) { - oldestKey = key - oldestTime = issued.ExpiresAt - } - } - if oldestKey != "" { - delete(s.authCodes, oldestKey) - } -} - -func (s *oauthStateStore) putPendingAuth(id string, pending oauthPendingAuth) { - s.mu.Lock() - defer s.mu.Unlock() - s.cleanupExpiredLocked(time.Now()) - if len(s.pendingAuth) >= maxOAuthStateEntries { - s.evictOldestPendingLocked() - } - s.pendingAuth[id] = pending -} - -func (s *oauthStateStore) consumePendingAuth(id string) (oauthPendingAuth, bool) { - s.mu.Lock() - defer s.mu.Unlock() - s.cleanupExpiredLocked(time.Now()) - pending, ok := s.pendingAuth[id] - if ok { - delete(s.pendingAuth, id) - } - return pending, ok -} - -func (s *oauthStateStore) putAuthCode(id string, issued oauthIssuedCode) { - s.mu.Lock() - defer s.mu.Unlock() - s.cleanupExpiredLocked(time.Now()) - if len(s.authCodes) >= maxOAuthStateEntries { - s.evictOldestCodeLocked() - } - s.authCodes[id] = issued -} - -func (s *oauthStateStore) consumeAuthCode(id string) (oauthIssuedCode, bool) { - s.mu.Lock() - defer s.mu.Unlock() - s.cleanupExpiredLocked(time.Now()) - issued, ok := s.authCodes[id] - if ok { - delete(s.authCodes, id) - } - return issued, ok -} +// OAuth pending-auth and issued-code state are encoded as stateless JWE tokens +// (see encodePendingAuth / encodeAuthCode below) so any replica can decode +// state minted by any other replica. There is no in-memory store, no eviction, +// and no per-pod size cap — expiry is enforced by the `exp` claim inside each +// JWE. Single-use on auth codes is intentionally NOT enforced server-side: +// codes are bound to the client's PKCE verifier (RFC 7636) and live for at +// most defaultAuthCodeTTLSeconds, so replay within the TTL is limited to +// whoever holds the verifier — i.e. the legitimate client. Trading strict +// RFC 6749 §4.1.2 single-use for zero shared state across replicas. func writeOAuthTokenError(w http.ResponseWriter, status int, code, description string) { w.Header().Set("Content-Type", "application/json") @@ -264,8 +171,10 @@ const oauthKidV1 = "v1" // Bumping the /vN suffix in any single label rotates that one key without // disturbing the others. const ( - hkdfInfoOAuthClientID = "altinity-mcp/oauth/client-id/v1" - hkdfInfoOAuthRefresh = "altinity-mcp/oauth/refresh-token/v1" + hkdfInfoOAuthClientID = "altinity-mcp/oauth/client-id/v1" + hkdfInfoOAuthRefresh = "altinity-mcp/oauth/refresh-token/v1" + hkdfInfoOAuthPendingAuth = "altinity-mcp/oauth/pending-auth/v1" + hkdfInfoOAuthAuthCode = "altinity-mcp/oauth/auth-code/v1" ) // encodeOAuthJWE emits a JWE-wrapped JSON document of `claims`, encrypted @@ -336,6 +245,139 @@ func decodeOAuthJWE(secret []byte, info string, token string) (map[string]interf return jwe_auth.ParseAndDecryptJWE(token, secret, secret) } +// encodePendingAuth wraps an oauthPendingAuth into a stateless JWE used as the +// `state` parameter sent to the upstream IdP at /authorize. Any replica with +// the shared signing_secret can decode it at /callback. +func (a *application) encodePendingAuth(p oauthPendingAuth) (string, error) { + secret, err := a.mustJWESecret() + if err != nil { + return "", err + } + claims := map[string]interface{}{ + "client_id": p.ClientID, + "redirect_uri": p.RedirectURI, + "scope": p.Scope, + "client_state": p.ClientState, + "code_challenge": p.CodeChallenge, + "code_challenge_method": p.CodeChallengeMethod, + "resource": p.Resource, + "upstream_pkce_verifier": p.UpstreamPKCEVerifier, + "exp": p.ExpiresAt.Unix(), + } + return encodeOAuthJWE(secret, hkdfInfoOAuthPendingAuth, claims) +} + +// decodePendingAuth is the inverse of encodePendingAuth. Returns (pending, +// false) when the token is unparseable, tampered, expired, or carries claims +// outside the JWE whitelist. +func (a *application) decodePendingAuth(token string) (oauthPendingAuth, bool) { + secret := a.oauthJWESecret() + if len(secret) == 0 { + return oauthPendingAuth{}, false + } + claims, err := decodeOAuthJWE(secret, hkdfInfoOAuthPendingAuth, token) + if err != nil { + return oauthPendingAuth{}, false + } + p := oauthPendingAuth{ + ClientID: stringFromClaims(claims, "client_id"), + RedirectURI: stringFromClaims(claims, "redirect_uri"), + Scope: stringFromClaims(claims, "scope"), + ClientState: stringFromClaims(claims, "client_state"), + CodeChallenge: stringFromClaims(claims, "code_challenge"), + CodeChallengeMethod: stringFromClaims(claims, "code_challenge_method"), + Resource: stringFromClaims(claims, "resource"), + UpstreamPKCEVerifier: stringFromClaims(claims, "upstream_pkce_verifier"), + ExpiresAt: unixFromClaims(claims, "exp"), + } + return p, true +} + +// encodeAuthCode wraps an oauthIssuedCode into a stateless JWE used as the +// `code` parameter returned to the MCP client at /callback. Redeemed at +// /token by decodeAuthCode on any replica. +func (a *application) encodeAuthCode(c oauthIssuedCode) (string, error) { + secret, err := a.mustJWESecret() + if err != nil { + return "", err + } + claims := map[string]interface{}{ + "client_id": c.ClientID, + "redirect_uri": c.RedirectURI, + "scope": c.Scope, + "code_challenge": c.CodeChallenge, + "code_challenge_method": c.CodeChallengeMethod, + "resource": c.Resource, + "upstream_bearer_token": c.UpstreamBearerToken, + "upstream_refresh_token": c.UpstreamRefreshToken, + "upstream_token_type": c.UpstreamTokenType, + "sub": c.Subject, + "email": c.Email, + "name": c.Name, + "hd": c.HostedDomain, + "email_verified": c.EmailVerified, + "exp": c.ExpiresAt.Unix(), + "access_token_exp": c.AccessTokenExpiry.Unix(), + } + return encodeOAuthJWE(secret, hkdfInfoOAuthAuthCode, claims) +} + +// decodeAuthCode is the inverse of encodeAuthCode. +func (a *application) decodeAuthCode(token string) (oauthIssuedCode, bool) { + secret := a.oauthJWESecret() + if len(secret) == 0 { + return oauthIssuedCode{}, false + } + claims, err := decodeOAuthJWE(secret, hkdfInfoOAuthAuthCode, token) + if err != nil { + return oauthIssuedCode{}, false + } + c := oauthIssuedCode{ + ClientID: stringFromClaims(claims, "client_id"), + RedirectURI: stringFromClaims(claims, "redirect_uri"), + Scope: stringFromClaims(claims, "scope"), + CodeChallenge: stringFromClaims(claims, "code_challenge"), + CodeChallengeMethod: stringFromClaims(claims, "code_challenge_method"), + Resource: stringFromClaims(claims, "resource"), + UpstreamBearerToken: stringFromClaims(claims, "upstream_bearer_token"), + UpstreamRefreshToken: stringFromClaims(claims, "upstream_refresh_token"), + UpstreamTokenType: stringFromClaims(claims, "upstream_token_type"), + Subject: stringFromClaims(claims, "sub"), + Email: stringFromClaims(claims, "email"), + Name: stringFromClaims(claims, "name"), + HostedDomain: stringFromClaims(claims, "hd"), + ExpiresAt: unixFromClaims(claims, "exp"), + AccessTokenExpiry: unixFromClaims(claims, "access_token_exp"), + } + if v, ok := claims["email_verified"].(bool); ok { + c.EmailVerified = v + } + return c, true +} + +func stringFromClaims(claims map[string]interface{}, key string) string { + if v, ok := claims[key].(string); ok { + return v + } + return "" +} + +func unixFromClaims(claims map[string]interface{}, key string) time.Time { + v, ok := claims[key] + if !ok { + return time.Time{} + } + switch t := v.(type) { + case float64: + return time.Unix(int64(t), 0) + case int64: + return time.Unix(t, 0) + case int: + return time.Unix(int64(t), 0) + } + return time.Time{} +} + func normalizeURL(raw string) string { return strings.TrimRight(strings.TrimSpace(raw), "/") } @@ -666,14 +708,6 @@ func (a *application) createMCPAuthInjector(cfg config.Config) func(http.Handler } } -func randomToken(prefix string) string { - buf := make([]byte, 24) - if _, err := rand.Read(buf); err != nil { - panic(err) - } - return prefix + base64.RawURLEncoding.EncodeToString(buf) -} - func pkceChallenge(verifier string) string { sum := sha256.Sum256([]byte(verifier)) return base64.RawURLEncoding.EncodeToString(sum[:]) @@ -1252,8 +1286,7 @@ func (a *application) handleOAuthAuthorize(w http.ResponseWriter, r *http.Reques return } - callbackState := randomToken("oas_") - a.getOAuthStateStore().putPendingAuth(callbackState, oauthPendingAuth{ + callbackState, err := a.encodePendingAuth(oauthPendingAuth{ ClientID: clientID, RedirectURI: redirectURI, Scope: sanitizeScope(q.Get("scope")), @@ -1264,6 +1297,11 @@ func (a *application) handleOAuthAuthorize(w http.ResponseWriter, r *http.Reques UpstreamPKCEVerifier: upstreamVerifier, ExpiresAt: time.Now().Add(time.Duration(defaultPendingAuthTTLSeconds) * time.Second), }) + if err != nil { + log.Error().Err(err).Msg("Failed to encode pending-auth JWE") + http.Error(w, "Failed to initialize OAuth state", http.StatusInternalServerError) + return + } cfg := a.GetCurrentConfig() authURL, err := a.resolveUpstreamAuthURL() @@ -1302,7 +1340,7 @@ func (a *application) handleOAuthCallback(w http.ResponseWriter, r *http.Request return } - pending, ok := a.getOAuthStateStore().consumePendingAuth(requestID) + pending, ok := a.decodePendingAuth(requestID) if !ok { http.Error(w, "Unknown OAuth request", http.StatusBadRequest) return @@ -1429,7 +1467,6 @@ func (a *application) handleOAuthCallback(w http.ResponseWriter, r *http.Request // is not registered and clients redirect directly to the upstream IdP. // Wrap the upstream tokens in our short-lived issued code; /token unwraps // them in handleOAuthTokenAuthCode. - authCode := randomToken("oac_") issuedCode := oauthIssuedCode{ ClientID: pending.ClientID, RedirectURI: pending.RedirectURI, @@ -1450,7 +1487,12 @@ func (a *application) handleOAuthCallback(w http.ResponseWriter, r *http.Request Msg("upstream_offline_access=true but upstream did not return a refresh_token; check IdP application config (offline_access scope, refresh_token grant, audience)") } } - a.getOAuthStateStore().putAuthCode(authCode, issuedCode) + authCode, err := a.encodeAuthCode(issuedCode) + if err != nil { + log.Error().Err(err).Msg("Failed to encode auth-code JWE") + http.Error(w, "Failed to issue authorization code", http.StatusInternalServerError) + return + } redirect, err := url.Parse(pending.RedirectURI) if err != nil { @@ -1602,7 +1644,7 @@ func (a *application) handleOAuthTokenAuthCode(w http.ResponseWriter, r *http.Re writeOAuthTokenError(w, http.StatusUnauthorized, "invalid_client", "client authentication failed") return } - issued, ok := a.getOAuthStateStore().consumeAuthCode(r.Form.Get("code")) + issued, ok := a.decodeAuthCode(r.Form.Get("code")) if !ok { log.Debug().Msg("OAuth token request rejected: unknown or expired authorization code") writeOAuthTokenError(w, http.StatusBadRequest, "invalid_grant", "invalid authorization code") diff --git a/cmd/altinity-mcp/oauth_server_test.go b/cmd/altinity-mcp/oauth_server_test.go index 687fc6c..cdebb43 100644 --- a/cmd/altinity-mcp/oauth_server_test.go +++ b/cmd/altinity-mcp/oauth_server_test.go @@ -1099,86 +1099,158 @@ func TestCanonicalResourceURL(t *testing.T) { } } -func TestOAuthStateStoreSizeCap(t *testing.T) { +// newJWEStateTestApp builds a minimal application wired with a SigningSecret +// for exercising the stateless JWE encode/decode helpers in isolation. +func newJWEStateTestApp(secret string) *application { + cfg := config.Config{ + Server: config.ServerConfig{ + OAuth: config.OAuthConfig{ + Enabled: true, + SigningSecret: secret, + }, + }, + } + return &application{config: cfg} +} + +func TestOAuthStateJWERoundTrip(t *testing.T) { t.Parallel() - t.Run("pending_auth_evicts_oldest_at_cap", func(t *testing.T) { - t.Parallel() - store := newOAuthStateStore() - // Fill to capacity with entries that expire far in the future - for i := 0; i < maxOAuthStateEntries; i++ { - store.putPendingAuth(fmt.Sprintf("p_%d", i), oauthPendingAuth{ - ClientID: "client", - ExpiresAt: time.Now().Add(time.Hour), - }) + const secret = "test-jwe-state-secret-32-byte-key" + + t.Run("pending_auth_round_trip", func(t *testing.T) { + t.Parallel() + app := newJWEStateTestApp(secret) + want := oauthPendingAuth{ + ClientID: "cid", + RedirectURI: "https://client.example/cb", + Scope: "openid email", + ClientState: "abc", + CodeChallenge: "ch", + CodeChallengeMethod: "S256", + Resource: "https://mcp.example/", + UpstreamPKCEVerifier: "verifier", + ExpiresAt: time.Now().Add(10 * time.Minute).Truncate(time.Second), } - require.Equal(t, maxOAuthStateEntries, len(store.pendingAuth)) + tok, err := app.encodePendingAuth(want) + require.NoError(t, err) + require.NotEmpty(t, tok) - // Insert one with the earliest expiry to make it the eviction target - store.pendingAuth["earliest"] = oauthPendingAuth{ - ClientID: "early", - ExpiresAt: time.Now().Add(-time.Minute), + got, ok := app.decodePendingAuth(tok) + require.True(t, ok) + require.Equal(t, want.ClientID, got.ClientID) + require.Equal(t, want.RedirectURI, got.RedirectURI) + require.Equal(t, want.Scope, got.Scope) + require.Equal(t, want.ClientState, got.ClientState) + require.Equal(t, want.CodeChallenge, got.CodeChallenge) + require.Equal(t, want.CodeChallengeMethod, got.CodeChallengeMethod) + require.Equal(t, want.Resource, got.Resource) + require.Equal(t, want.UpstreamPKCEVerifier, got.UpstreamPKCEVerifier) + require.Equal(t, want.ExpiresAt.Unix(), got.ExpiresAt.Unix()) + }) + + t.Run("auth_code_round_trip", func(t *testing.T) { + t.Parallel() + app := newJWEStateTestApp(secret) + want := oauthIssuedCode{ + ClientID: "cid", + RedirectURI: "https://client.example/cb", + Scope: "openid email", + CodeChallenge: "ch", + CodeChallengeMethod: "S256", + Resource: "https://mcp.example/", + UpstreamBearerToken: "upstream-bearer", + UpstreamRefreshToken: "upstream-refresh", + UpstreamTokenType: "Bearer", + Subject: "user-1", + Email: "u@example.com", + Name: "User", + HostedDomain: "example.com", + EmailVerified: true, + ExpiresAt: time.Now().Add(60 * time.Second).Truncate(time.Second), + AccessTokenExpiry: time.Now().Add(time.Hour).Truncate(time.Second), } + tok, err := app.encodeAuthCode(want) + require.NoError(t, err) - // Next put should evict "earliest" and stay at cap - store.putPendingAuth("overflow", oauthPendingAuth{ - ClientID: "new", - ExpiresAt: time.Now().Add(time.Hour), - }) - // expired entries cleaned + oldest evicted, should not exceed cap - require.LessOrEqual(t, len(store.pendingAuth), maxOAuthStateEntries) - _, ok := store.pendingAuth["earliest"] - require.False(t, ok, "earliest entry should have been evicted") - _, ok = store.pendingAuth["overflow"] - require.True(t, ok, "new entry should be present") + got, ok := app.decodeAuthCode(tok) + require.True(t, ok) + require.Equal(t, want, got) }) - t.Run("auth_codes_evicts_oldest_at_cap", func(t *testing.T) { + t.Run("cross_pod_portable_with_shared_secret", func(t *testing.T) { t.Parallel() - store := newOAuthStateStore() - for i := 0; i < maxOAuthStateEntries; i++ { - store.putAuthCode(fmt.Sprintf("c_%d", i), oauthIssuedCode{ - ClientID: "client", - ExpiresAt: time.Now().Add(time.Hour), - }) - } - require.Equal(t, maxOAuthStateEntries, len(store.authCodes)) + // Simulate two replicas: separate application instances, identical secret. + podA := newJWEStateTestApp(secret) + podB := newJWEStateTestApp(secret) + mintedOnA, err := podA.encodePendingAuth(oauthPendingAuth{ + ClientID: "cid", + ExpiresAt: time.Now().Add(10 * time.Minute), + }) + require.NoError(t, err) + got, ok := podB.decodePendingAuth(mintedOnA) + require.True(t, ok) + require.Equal(t, "cid", got.ClientID) + }) - store.authCodes["earliest"] = oauthIssuedCode{ - ClientID: "early", - ExpiresAt: time.Now().Add(-time.Minute), - } + t.Run("cross_pod_rejected_with_different_secret", func(t *testing.T) { + t.Parallel() + podA := newJWEStateTestApp(secret) + podB := newJWEStateTestApp("a-different-secret-32-bytes-long!") + mintedOnA, err := podA.encodeAuthCode(oauthIssuedCode{ + ClientID: "cid", + ExpiresAt: time.Now().Add(60 * time.Second), + }) + require.NoError(t, err) + _, ok := podB.decodeAuthCode(mintedOnA) + require.False(t, ok, "JWE minted with a different secret must not decode") + }) - store.putAuthCode("overflow", oauthIssuedCode{ - ClientID: "new", - ExpiresAt: time.Now().Add(time.Hour), + t.Run("expired_auth_code_rejected", func(t *testing.T) { + t.Parallel() + app := newJWEStateTestApp(secret) + tok, err := app.encodeAuthCode(oauthIssuedCode{ + ClientID: "cid", + ExpiresAt: time.Now().Add(-1 * time.Second), }) - require.LessOrEqual(t, len(store.authCodes), maxOAuthStateEntries) - _, ok := store.authCodes["earliest"] - require.False(t, ok, "earliest entry should have been evicted") - _, ok = store.authCodes["overflow"] - require.True(t, ok, "new entry should be present") + require.NoError(t, err) + _, ok := app.decodeAuthCode(tok) + require.False(t, ok, "expired auth code must be rejected by JWE exp validation") }) - t.Run("expired_entries_cleaned_before_cap_check", func(t *testing.T) { + t.Run("expired_pending_auth_rejected", func(t *testing.T) { t.Parallel() - store := newOAuthStateStore() - // Fill with already-expired entries - for i := 0; i < maxOAuthStateEntries; i++ { - store.pendingAuth[fmt.Sprintf("exp_%d", i)] = oauthPendingAuth{ - ClientID: "client", - ExpiresAt: time.Now().Add(-time.Second), - } - } - require.Equal(t, maxOAuthStateEntries, len(store.pendingAuth)) + app := newJWEStateTestApp(secret) + tok, err := app.encodePendingAuth(oauthPendingAuth{ + ClientID: "cid", + ExpiresAt: time.Now().Add(-1 * time.Second), + }) + require.NoError(t, err) + _, ok := app.decodePendingAuth(tok) + require.False(t, ok) + }) - // putPendingAuth cleans expired first, so this should succeed without eviction - store.putPendingAuth("fresh", oauthPendingAuth{ - ClientID: "new", - ExpiresAt: time.Now().Add(time.Hour), + t.Run("tampered_token_rejected", func(t *testing.T) { + t.Parallel() + app := newJWEStateTestApp(secret) + tok, err := app.encodeAuthCode(oauthIssuedCode{ + ClientID: "cid", + ExpiresAt: time.Now().Add(60 * time.Second), }) - require.Equal(t, 1, len(store.pendingAuth)) - _, ok := store.pendingAuth["fresh"] - require.True(t, ok) + require.NoError(t, err) + // Flip a byte in the JWE ciphertext. + bs := []byte(tok) + bs[len(bs)/2] ^= 0x01 + _, ok := app.decodeAuthCode(string(bs)) + require.False(t, ok) + }) + + t.Run("decode_missing_secret_fails_cleanly", func(t *testing.T) { + t.Parallel() + app := newJWEStateTestApp("") + _, ok := app.decodePendingAuth("anything") + require.False(t, ok) + _, ok = app.decodeAuthCode("anything") + require.False(t, ok) }) } @@ -2215,54 +2287,6 @@ func TestNormalizeURL(t *testing.T) { } } -func TestOAuthStateStore(t *testing.T) { - t.Parallel() - - t.Run("put_and_consume_pending_auth", func(t *testing.T) { - t.Parallel() - store := newOAuthStateStore() - pending := oauthPendingAuth{ExpiresAt: time.Now().Add(time.Hour)} - store.putPendingAuth("key1", pending) - - got, ok := store.consumePendingAuth("key1") - require.True(t, ok) - require.Equal(t, pending.ExpiresAt.Unix(), got.ExpiresAt.Unix()) - - _, ok = store.consumePendingAuth("key1") - require.False(t, ok) - }) - - t.Run("put_and_consume_auth_code", func(t *testing.T) { - t.Parallel() - store := newOAuthStateStore() - issued := oauthIssuedCode{ExpiresAt: time.Now().Add(time.Hour)} - store.putAuthCode("code1", issued) - - got, ok := store.consumeAuthCode("code1") - require.True(t, ok) - require.Equal(t, issued.ExpiresAt.Unix(), got.ExpiresAt.Unix()) - - _, ok = store.consumeAuthCode("code1") - require.False(t, ok) - }) - - t.Run("expired_entries_cleaned_up", func(t *testing.T) { - t.Parallel() - store := newOAuthStateStore() - store.putPendingAuth("expired", oauthPendingAuth{ExpiresAt: time.Now().Add(-time.Hour)}) - store.putAuthCode("expired", oauthIssuedCode{ExpiresAt: time.Now().Add(-time.Hour)}) - - // Next put triggers cleanup - store.putPendingAuth("fresh", oauthPendingAuth{ExpiresAt: time.Now().Add(time.Hour)}) - store.putAuthCode("fresh", oauthIssuedCode{ExpiresAt: time.Now().Add(time.Hour)}) - - _, ok := store.consumePendingAuth("expired") - require.False(t, ok) - _, ok = store.consumeAuthCode("expired") - require.False(t, ok) - }) -} - func TestSanitizeScope(t *testing.T) { t.Parallel() require.Equal(t, "read write", sanitizeScope(" read write ")) diff --git a/pkg/jwe_auth/jwe_auth.go b/pkg/jwe_auth/jwe_auth.go index 3ac6c90..e9eb7f7 100644 --- a/pkg/jwe_auth/jwe_auth.go +++ b/pkg/jwe_auth/jwe_auth.go @@ -265,6 +265,8 @@ func validateClaimsWhitelist(claims map[string]interface{}) error { "upstream_bearer_token": true, "upstream_refresh_token": true, "upstream_token_type": true, + "upstream_pkce_verifier": true, + "resource": true, "access_token_exp": true, "email": true, "name": true, From 058f43d120be282521c345800858b231a7c01529 Mon Sep 17 00:00:00 2001 From: Boris Tyshkevich Date: Thu, 14 May 2026 18:37:33 +0200 Subject: [PATCH 2/2] mcp: enable Stateless StreamableHTTP transport for HA MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit NewStreamableHTTPHandler defaults to session-tracked mode where each pod issues and validates its own Mcp-Session-Id. Under replicas>=2 with non-sticky load balancing, the MCP `initialize` call lands on whichever pod the LB picks, the client picks ONE returned session-id, and any subsequent tool call that lands on the OTHER pod is rejected with code 32600 "Session terminated". Switch both NewStreamableHTTPHandler call sites to Stateless: true. Each request becomes self-contained, no per-pod session table required. Trade-off: server-initiated requests (sampling, roots/list, log notifications outside an active request) are not supported. altinity-mcp only handles client-initiated tool calls, so this is safe today. Pairs with the JWE OAuth-state refactor in 9f16fd3 — together they make forward-mode and gating+broker_upstream deployments HA-safe. --- cmd/altinity-mcp/main.go | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/cmd/altinity-mcp/main.go b/cmd/altinity-mcp/main.go index 2fdfd9d..577ef32 100644 --- a/cmd/altinity-mcp/main.go +++ b/cmd/altinity-mcp/main.go @@ -454,9 +454,15 @@ func (a *application) startHTTPServer(cfg config.Config, mcpServer *mcp.Server) tokenInjector := a.createTokenInjector() dtInjector := a.dynamicToolsInjector + // Stateless: true makes the streamable HTTP transport carry no per-pod + // session state — each request stands alone. Required for replicas>=2 + // behind a non-sticky LB, where consecutive tool calls from one client + // may land on different pods. Trade-off: server-initiated requests + // (sampling, roots/list, etc.) are not supported; altinity-mcp only + // uses client-initiated tool calls so this is safe. httpServer := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server { return mcpServer - }, nil) + }, &mcp.StreamableHTTPOptions{Stateless: true}) mux := http.NewServeMux() transportHandler := serverInjector(tokenInjector(dtInjector(httpServer))) @@ -484,9 +490,15 @@ func (a *application) startHTTPServer(cfg config.Config, mcpServer *mcp.Server) httpHandler = stripTrailingSlash(corsHandler(mux)) } else { // Use standard HTTP server without dynamic paths + // Stateless: true makes the streamable HTTP transport carry no per-pod + // session state — each request stands alone. Required for replicas>=2 + // behind a non-sticky LB, where consecutive tool calls from one client + // may land on different pods. Trade-off: server-initiated requests + // (sampling, roots/list, etc.) are not supported; altinity-mcp only + // uses client-initiated tool calls so this is safe. httpServer := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server { return mcpServer - }, nil) + }, &mcp.StreamableHTTPOptions{Stateless: true}) dtInjector := a.dynamicToolsInjector mux := http.NewServeMux() transportHandler := serverInjector(dtInjector(httpServer))