diff --git a/cmd/altinity-mcp/main.go b/cmd/altinity-mcp/main.go index 2a21c27..577ef32 100644 --- a/cmd/altinity-mcp/main.go +++ b/cmd/altinity-mcp/main.go @@ -454,9 +454,15 @@ func (a *application) startHTTPServer(cfg config.Config, mcpServer *mcp.Server) tokenInjector := a.createTokenInjector() dtInjector := a.dynamicToolsInjector + // Stateless: true makes the streamable HTTP transport carry no per-pod + // session state — each request stands alone. Required for replicas>=2 + // behind a non-sticky LB, where consecutive tool calls from one client + // may land on different pods. Trade-off: server-initiated requests + // (sampling, roots/list, etc.) are not supported; altinity-mcp only + // uses client-initiated tool calls so this is safe. httpServer := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server { return mcpServer - }, nil) + }, &mcp.StreamableHTTPOptions{Stateless: true}) mux := http.NewServeMux() transportHandler := serverInjector(tokenInjector(dtInjector(httpServer))) @@ -484,9 +490,15 @@ func (a *application) startHTTPServer(cfg config.Config, mcpServer *mcp.Server) httpHandler = stripTrailingSlash(corsHandler(mux)) } else { // Use standard HTTP server without dynamic paths + // Stateless: true makes the streamable HTTP transport carry no per-pod + // session state — each request stands alone. Required for replicas>=2 + // behind a non-sticky LB, where consecutive tool calls from one client + // may land on different pods. Trade-off: server-initiated requests + // (sampling, roots/list, etc.) are not supported; altinity-mcp only + // uses client-initiated tool calls so this is safe. httpServer := mcp.NewStreamableHTTPHandler(func(r *http.Request) *mcp.Server { return mcpServer - }, nil) + }, &mcp.StreamableHTTPOptions{Stateless: true}) dtInjector := a.dynamicToolsInjector mux := http.NewServeMux() transportHandler := serverInjector(dtInjector(httpServer)) @@ -964,8 +976,6 @@ type application struct { mcpServer *altinitymcp.ClickHouseJWEServer httpSrv *http.Server httpSrvMutex sync.RWMutex - oauthState *oauthStateStore - oauthStateMu sync.Mutex configFile string configMutex sync.RWMutex stopConfigReload chan struct{} @@ -985,15 +995,6 @@ func (a *application) getHTTPServer() *http.Server { return a.httpSrv } -func (a *application) getOAuthStateStore() *oauthStateStore { - a.oauthStateMu.Lock() - defer a.oauthStateMu.Unlock() - if a.oauthState == nil { - a.oauthState = newOAuthStateStore() - } - return a.oauthState -} - func newApplication(ctx context.Context, cfg config.Config, cmd CommandInterface) (*application, error) { if err := validateOAuthRuntimeConfig(cfg); err != nil { return nil, err @@ -1052,7 +1053,6 @@ func newApplication(ctx context.Context, cfg config.Config, cmd CommandInterface app := &application{ config: cfg, mcpServer: mcpServer, - oauthState: newOAuthStateStore(), configFile: cmd.String("config"), stopConfigReload: make(chan struct{}), } diff --git a/cmd/altinity-mcp/main_test.go b/cmd/altinity-mcp/main_test.go index 709c327..d3d37ff 100644 --- a/cmd/altinity-mcp/main_test.go +++ b/cmd/altinity-mcp/main_test.go @@ -3908,39 +3908,6 @@ func TestHealthHandler_CHUnavailable(t *testing.T) { require.Equal(t, "unhealthy", body["status"]) } -func TestOAuthStateStoreEviction(t *testing.T) { - t.Parallel() - store := newOAuthStateStore() - - // Fill pending auth to capacity - for i := 0; i < maxOAuthStateEntries; i++ { - store.putPendingAuth(fmt.Sprintf("pending-%d", i), oauthPendingAuth{ - ExpiresAt: time.Now().Add(time.Duration(i) * time.Second), - }) - } - - // Adding one more should evict the oldest - store.putPendingAuth("new-pending", oauthPendingAuth{ExpiresAt: time.Now().Add(time.Hour)}) - _, ok := store.consumePendingAuth("pending-0") // oldest should be evicted - require.False(t, ok) - _, ok = store.consumePendingAuth("new-pending") - require.True(t, ok) - - // Fill auth codes to capacity - for i := 0; i < maxOAuthStateEntries; i++ { - store.putAuthCode(fmt.Sprintf("code-%d", i), oauthIssuedCode{ - ExpiresAt: time.Now().Add(time.Duration(i) * time.Second), - }) - } - - // Adding one more should evict the oldest - store.putAuthCode("new-code", oauthIssuedCode{ExpiresAt: time.Now().Add(time.Hour)}) - _, ok = store.consumeAuthCode("code-0") - require.False(t, ok) - _, ok = store.consumeAuthCode("new-code") - require.True(t, ok) -} - func TestToolInputSettingsCLIFlag(t *testing.T) { cases := []struct { name string diff --git a/cmd/altinity-mcp/oauth_server.go b/cmd/altinity-mcp/oauth_server.go index c6edd24..c1c663b 100644 --- a/cmd/altinity-mcp/oauth_server.go +++ b/cmd/altinity-mcp/oauth_server.go @@ -15,7 +15,6 @@ import ( "net/url" "slices" "strings" - "sync" "time" "github.com/altinity/altinity-mcp/pkg/config" @@ -100,107 +99,15 @@ type oauthIssuedCode struct { AccessTokenExpiry time.Time } -// maxOAuthStateEntries caps each map in the state store to prevent memory -// exhaustion from floods of unauthenticated /oauth/authorize requests. -const maxOAuthStateEntries = 10000 - -type oauthStateStore struct { - mu sync.Mutex - pendingAuth map[string]oauthPendingAuth - authCodes map[string]oauthIssuedCode -} - -func newOAuthStateStore() *oauthStateStore { - return &oauthStateStore{ - pendingAuth: make(map[string]oauthPendingAuth), - authCodes: make(map[string]oauthIssuedCode), - } -} - -func (s *oauthStateStore) cleanupExpiredLocked(now time.Time) { - for key, pending := range s.pendingAuth { - if !pending.ExpiresAt.IsZero() && now.After(pending.ExpiresAt) { - delete(s.pendingAuth, key) - } - } - for key, issued := range s.authCodes { - if !issued.ExpiresAt.IsZero() && now.After(issued.ExpiresAt) { - delete(s.authCodes, key) - } - } -} - -// evictOldestPendingLocked removes the entry with the earliest expiry. -func (s *oauthStateStore) evictOldestPendingLocked() { - var oldestKey string - var oldestTime time.Time - for key, pending := range s.pendingAuth { - if oldestKey == "" || pending.ExpiresAt.Before(oldestTime) { - oldestKey = key - oldestTime = pending.ExpiresAt - } - } - if oldestKey != "" { - delete(s.pendingAuth, oldestKey) - } -} - -// evictOldestCodeLocked removes the entry with the earliest expiry. -func (s *oauthStateStore) evictOldestCodeLocked() { - var oldestKey string - var oldestTime time.Time - for key, issued := range s.authCodes { - if oldestKey == "" || issued.ExpiresAt.Before(oldestTime) { - oldestKey = key - oldestTime = issued.ExpiresAt - } - } - if oldestKey != "" { - delete(s.authCodes, oldestKey) - } -} - -func (s *oauthStateStore) putPendingAuth(id string, pending oauthPendingAuth) { - s.mu.Lock() - defer s.mu.Unlock() - s.cleanupExpiredLocked(time.Now()) - if len(s.pendingAuth) >= maxOAuthStateEntries { - s.evictOldestPendingLocked() - } - s.pendingAuth[id] = pending -} - -func (s *oauthStateStore) consumePendingAuth(id string) (oauthPendingAuth, bool) { - s.mu.Lock() - defer s.mu.Unlock() - s.cleanupExpiredLocked(time.Now()) - pending, ok := s.pendingAuth[id] - if ok { - delete(s.pendingAuth, id) - } - return pending, ok -} - -func (s *oauthStateStore) putAuthCode(id string, issued oauthIssuedCode) { - s.mu.Lock() - defer s.mu.Unlock() - s.cleanupExpiredLocked(time.Now()) - if len(s.authCodes) >= maxOAuthStateEntries { - s.evictOldestCodeLocked() - } - s.authCodes[id] = issued -} - -func (s *oauthStateStore) consumeAuthCode(id string) (oauthIssuedCode, bool) { - s.mu.Lock() - defer s.mu.Unlock() - s.cleanupExpiredLocked(time.Now()) - issued, ok := s.authCodes[id] - if ok { - delete(s.authCodes, id) - } - return issued, ok -} +// OAuth pending-auth and issued-code state are encoded as stateless JWE tokens +// (see encodePendingAuth / encodeAuthCode below) so any replica can decode +// state minted by any other replica. There is no in-memory store, no eviction, +// and no per-pod size cap — expiry is enforced by the `exp` claim inside each +// JWE. Single-use on auth codes is intentionally NOT enforced server-side: +// codes are bound to the client's PKCE verifier (RFC 7636) and live for at +// most defaultAuthCodeTTLSeconds, so replay within the TTL is limited to +// whoever holds the verifier — i.e. the legitimate client. Trading strict +// RFC 6749 §4.1.2 single-use for zero shared state across replicas. func writeOAuthTokenError(w http.ResponseWriter, status int, code, description string) { w.Header().Set("Content-Type", "application/json") @@ -264,8 +171,10 @@ const oauthKidV1 = "v1" // Bumping the /vN suffix in any single label rotates that one key without // disturbing the others. const ( - hkdfInfoOAuthClientID = "altinity-mcp/oauth/client-id/v1" - hkdfInfoOAuthRefresh = "altinity-mcp/oauth/refresh-token/v1" + hkdfInfoOAuthClientID = "altinity-mcp/oauth/client-id/v1" + hkdfInfoOAuthRefresh = "altinity-mcp/oauth/refresh-token/v1" + hkdfInfoOAuthPendingAuth = "altinity-mcp/oauth/pending-auth/v1" + hkdfInfoOAuthAuthCode = "altinity-mcp/oauth/auth-code/v1" ) // encodeOAuthJWE emits a JWE-wrapped JSON document of `claims`, encrypted @@ -336,6 +245,139 @@ func decodeOAuthJWE(secret []byte, info string, token string) (map[string]interf return jwe_auth.ParseAndDecryptJWE(token, secret, secret) } +// encodePendingAuth wraps an oauthPendingAuth into a stateless JWE used as the +// `state` parameter sent to the upstream IdP at /authorize. Any replica with +// the shared signing_secret can decode it at /callback. +func (a *application) encodePendingAuth(p oauthPendingAuth) (string, error) { + secret, err := a.mustJWESecret() + if err != nil { + return "", err + } + claims := map[string]interface{}{ + "client_id": p.ClientID, + "redirect_uri": p.RedirectURI, + "scope": p.Scope, + "client_state": p.ClientState, + "code_challenge": p.CodeChallenge, + "code_challenge_method": p.CodeChallengeMethod, + "resource": p.Resource, + "upstream_pkce_verifier": p.UpstreamPKCEVerifier, + "exp": p.ExpiresAt.Unix(), + } + return encodeOAuthJWE(secret, hkdfInfoOAuthPendingAuth, claims) +} + +// decodePendingAuth is the inverse of encodePendingAuth. Returns (pending, +// false) when the token is unparseable, tampered, expired, or carries claims +// outside the JWE whitelist. +func (a *application) decodePendingAuth(token string) (oauthPendingAuth, bool) { + secret := a.oauthJWESecret() + if len(secret) == 0 { + return oauthPendingAuth{}, false + } + claims, err := decodeOAuthJWE(secret, hkdfInfoOAuthPendingAuth, token) + if err != nil { + return oauthPendingAuth{}, false + } + p := oauthPendingAuth{ + ClientID: stringFromClaims(claims, "client_id"), + RedirectURI: stringFromClaims(claims, "redirect_uri"), + Scope: stringFromClaims(claims, "scope"), + ClientState: stringFromClaims(claims, "client_state"), + CodeChallenge: stringFromClaims(claims, "code_challenge"), + CodeChallengeMethod: stringFromClaims(claims, "code_challenge_method"), + Resource: stringFromClaims(claims, "resource"), + UpstreamPKCEVerifier: stringFromClaims(claims, "upstream_pkce_verifier"), + ExpiresAt: unixFromClaims(claims, "exp"), + } + return p, true +} + +// encodeAuthCode wraps an oauthIssuedCode into a stateless JWE used as the +// `code` parameter returned to the MCP client at /callback. Redeemed at +// /token by decodeAuthCode on any replica. +func (a *application) encodeAuthCode(c oauthIssuedCode) (string, error) { + secret, err := a.mustJWESecret() + if err != nil { + return "", err + } + claims := map[string]interface{}{ + "client_id": c.ClientID, + "redirect_uri": c.RedirectURI, + "scope": c.Scope, + "code_challenge": c.CodeChallenge, + "code_challenge_method": c.CodeChallengeMethod, + "resource": c.Resource, + "upstream_bearer_token": c.UpstreamBearerToken, + "upstream_refresh_token": c.UpstreamRefreshToken, + "upstream_token_type": c.UpstreamTokenType, + "sub": c.Subject, + "email": c.Email, + "name": c.Name, + "hd": c.HostedDomain, + "email_verified": c.EmailVerified, + "exp": c.ExpiresAt.Unix(), + "access_token_exp": c.AccessTokenExpiry.Unix(), + } + return encodeOAuthJWE(secret, hkdfInfoOAuthAuthCode, claims) +} + +// decodeAuthCode is the inverse of encodeAuthCode. +func (a *application) decodeAuthCode(token string) (oauthIssuedCode, bool) { + secret := a.oauthJWESecret() + if len(secret) == 0 { + return oauthIssuedCode{}, false + } + claims, err := decodeOAuthJWE(secret, hkdfInfoOAuthAuthCode, token) + if err != nil { + return oauthIssuedCode{}, false + } + c := oauthIssuedCode{ + ClientID: stringFromClaims(claims, "client_id"), + RedirectURI: stringFromClaims(claims, "redirect_uri"), + Scope: stringFromClaims(claims, "scope"), + CodeChallenge: stringFromClaims(claims, "code_challenge"), + CodeChallengeMethod: stringFromClaims(claims, "code_challenge_method"), + Resource: stringFromClaims(claims, "resource"), + UpstreamBearerToken: stringFromClaims(claims, "upstream_bearer_token"), + UpstreamRefreshToken: stringFromClaims(claims, "upstream_refresh_token"), + UpstreamTokenType: stringFromClaims(claims, "upstream_token_type"), + Subject: stringFromClaims(claims, "sub"), + Email: stringFromClaims(claims, "email"), + Name: stringFromClaims(claims, "name"), + HostedDomain: stringFromClaims(claims, "hd"), + ExpiresAt: unixFromClaims(claims, "exp"), + AccessTokenExpiry: unixFromClaims(claims, "access_token_exp"), + } + if v, ok := claims["email_verified"].(bool); ok { + c.EmailVerified = v + } + return c, true +} + +func stringFromClaims(claims map[string]interface{}, key string) string { + if v, ok := claims[key].(string); ok { + return v + } + return "" +} + +func unixFromClaims(claims map[string]interface{}, key string) time.Time { + v, ok := claims[key] + if !ok { + return time.Time{} + } + switch t := v.(type) { + case float64: + return time.Unix(int64(t), 0) + case int64: + return time.Unix(t, 0) + case int: + return time.Unix(int64(t), 0) + } + return time.Time{} +} + func normalizeURL(raw string) string { return strings.TrimRight(strings.TrimSpace(raw), "/") } @@ -666,14 +708,6 @@ func (a *application) createMCPAuthInjector(cfg config.Config) func(http.Handler } } -func randomToken(prefix string) string { - buf := make([]byte, 24) - if _, err := rand.Read(buf); err != nil { - panic(err) - } - return prefix + base64.RawURLEncoding.EncodeToString(buf) -} - func pkceChallenge(verifier string) string { sum := sha256.Sum256([]byte(verifier)) return base64.RawURLEncoding.EncodeToString(sum[:]) @@ -1252,8 +1286,7 @@ func (a *application) handleOAuthAuthorize(w http.ResponseWriter, r *http.Reques return } - callbackState := randomToken("oas_") - a.getOAuthStateStore().putPendingAuth(callbackState, oauthPendingAuth{ + callbackState, err := a.encodePendingAuth(oauthPendingAuth{ ClientID: clientID, RedirectURI: redirectURI, Scope: sanitizeScope(q.Get("scope")), @@ -1264,6 +1297,11 @@ func (a *application) handleOAuthAuthorize(w http.ResponseWriter, r *http.Reques UpstreamPKCEVerifier: upstreamVerifier, ExpiresAt: time.Now().Add(time.Duration(defaultPendingAuthTTLSeconds) * time.Second), }) + if err != nil { + log.Error().Err(err).Msg("Failed to encode pending-auth JWE") + http.Error(w, "Failed to initialize OAuth state", http.StatusInternalServerError) + return + } cfg := a.GetCurrentConfig() authURL, err := a.resolveUpstreamAuthURL() @@ -1302,7 +1340,7 @@ func (a *application) handleOAuthCallback(w http.ResponseWriter, r *http.Request return } - pending, ok := a.getOAuthStateStore().consumePendingAuth(requestID) + pending, ok := a.decodePendingAuth(requestID) if !ok { http.Error(w, "Unknown OAuth request", http.StatusBadRequest) return @@ -1429,7 +1467,6 @@ func (a *application) handleOAuthCallback(w http.ResponseWriter, r *http.Request // is not registered and clients redirect directly to the upstream IdP. // Wrap the upstream tokens in our short-lived issued code; /token unwraps // them in handleOAuthTokenAuthCode. - authCode := randomToken("oac_") issuedCode := oauthIssuedCode{ ClientID: pending.ClientID, RedirectURI: pending.RedirectURI, @@ -1450,7 +1487,12 @@ func (a *application) handleOAuthCallback(w http.ResponseWriter, r *http.Request Msg("upstream_offline_access=true but upstream did not return a refresh_token; check IdP application config (offline_access scope, refresh_token grant, audience)") } } - a.getOAuthStateStore().putAuthCode(authCode, issuedCode) + authCode, err := a.encodeAuthCode(issuedCode) + if err != nil { + log.Error().Err(err).Msg("Failed to encode auth-code JWE") + http.Error(w, "Failed to issue authorization code", http.StatusInternalServerError) + return + } redirect, err := url.Parse(pending.RedirectURI) if err != nil { @@ -1602,7 +1644,7 @@ func (a *application) handleOAuthTokenAuthCode(w http.ResponseWriter, r *http.Re writeOAuthTokenError(w, http.StatusUnauthorized, "invalid_client", "client authentication failed") return } - issued, ok := a.getOAuthStateStore().consumeAuthCode(r.Form.Get("code")) + issued, ok := a.decodeAuthCode(r.Form.Get("code")) if !ok { log.Debug().Msg("OAuth token request rejected: unknown or expired authorization code") writeOAuthTokenError(w, http.StatusBadRequest, "invalid_grant", "invalid authorization code") diff --git a/cmd/altinity-mcp/oauth_server_test.go b/cmd/altinity-mcp/oauth_server_test.go index 687fc6c..cdebb43 100644 --- a/cmd/altinity-mcp/oauth_server_test.go +++ b/cmd/altinity-mcp/oauth_server_test.go @@ -1099,86 +1099,158 @@ func TestCanonicalResourceURL(t *testing.T) { } } -func TestOAuthStateStoreSizeCap(t *testing.T) { +// newJWEStateTestApp builds a minimal application wired with a SigningSecret +// for exercising the stateless JWE encode/decode helpers in isolation. +func newJWEStateTestApp(secret string) *application { + cfg := config.Config{ + Server: config.ServerConfig{ + OAuth: config.OAuthConfig{ + Enabled: true, + SigningSecret: secret, + }, + }, + } + return &application{config: cfg} +} + +func TestOAuthStateJWERoundTrip(t *testing.T) { t.Parallel() - t.Run("pending_auth_evicts_oldest_at_cap", func(t *testing.T) { - t.Parallel() - store := newOAuthStateStore() - // Fill to capacity with entries that expire far in the future - for i := 0; i < maxOAuthStateEntries; i++ { - store.putPendingAuth(fmt.Sprintf("p_%d", i), oauthPendingAuth{ - ClientID: "client", - ExpiresAt: time.Now().Add(time.Hour), - }) + const secret = "test-jwe-state-secret-32-byte-key" + + t.Run("pending_auth_round_trip", func(t *testing.T) { + t.Parallel() + app := newJWEStateTestApp(secret) + want := oauthPendingAuth{ + ClientID: "cid", + RedirectURI: "https://client.example/cb", + Scope: "openid email", + ClientState: "abc", + CodeChallenge: "ch", + CodeChallengeMethod: "S256", + Resource: "https://mcp.example/", + UpstreamPKCEVerifier: "verifier", + ExpiresAt: time.Now().Add(10 * time.Minute).Truncate(time.Second), } - require.Equal(t, maxOAuthStateEntries, len(store.pendingAuth)) + tok, err := app.encodePendingAuth(want) + require.NoError(t, err) + require.NotEmpty(t, tok) - // Insert one with the earliest expiry to make it the eviction target - store.pendingAuth["earliest"] = oauthPendingAuth{ - ClientID: "early", - ExpiresAt: time.Now().Add(-time.Minute), + got, ok := app.decodePendingAuth(tok) + require.True(t, ok) + require.Equal(t, want.ClientID, got.ClientID) + require.Equal(t, want.RedirectURI, got.RedirectURI) + require.Equal(t, want.Scope, got.Scope) + require.Equal(t, want.ClientState, got.ClientState) + require.Equal(t, want.CodeChallenge, got.CodeChallenge) + require.Equal(t, want.CodeChallengeMethod, got.CodeChallengeMethod) + require.Equal(t, want.Resource, got.Resource) + require.Equal(t, want.UpstreamPKCEVerifier, got.UpstreamPKCEVerifier) + require.Equal(t, want.ExpiresAt.Unix(), got.ExpiresAt.Unix()) + }) + + t.Run("auth_code_round_trip", func(t *testing.T) { + t.Parallel() + app := newJWEStateTestApp(secret) + want := oauthIssuedCode{ + ClientID: "cid", + RedirectURI: "https://client.example/cb", + Scope: "openid email", + CodeChallenge: "ch", + CodeChallengeMethod: "S256", + Resource: "https://mcp.example/", + UpstreamBearerToken: "upstream-bearer", + UpstreamRefreshToken: "upstream-refresh", + UpstreamTokenType: "Bearer", + Subject: "user-1", + Email: "u@example.com", + Name: "User", + HostedDomain: "example.com", + EmailVerified: true, + ExpiresAt: time.Now().Add(60 * time.Second).Truncate(time.Second), + AccessTokenExpiry: time.Now().Add(time.Hour).Truncate(time.Second), } + tok, err := app.encodeAuthCode(want) + require.NoError(t, err) - // Next put should evict "earliest" and stay at cap - store.putPendingAuth("overflow", oauthPendingAuth{ - ClientID: "new", - ExpiresAt: time.Now().Add(time.Hour), - }) - // expired entries cleaned + oldest evicted, should not exceed cap - require.LessOrEqual(t, len(store.pendingAuth), maxOAuthStateEntries) - _, ok := store.pendingAuth["earliest"] - require.False(t, ok, "earliest entry should have been evicted") - _, ok = store.pendingAuth["overflow"] - require.True(t, ok, "new entry should be present") + got, ok := app.decodeAuthCode(tok) + require.True(t, ok) + require.Equal(t, want, got) }) - t.Run("auth_codes_evicts_oldest_at_cap", func(t *testing.T) { + t.Run("cross_pod_portable_with_shared_secret", func(t *testing.T) { t.Parallel() - store := newOAuthStateStore() - for i := 0; i < maxOAuthStateEntries; i++ { - store.putAuthCode(fmt.Sprintf("c_%d", i), oauthIssuedCode{ - ClientID: "client", - ExpiresAt: time.Now().Add(time.Hour), - }) - } - require.Equal(t, maxOAuthStateEntries, len(store.authCodes)) + // Simulate two replicas: separate application instances, identical secret. + podA := newJWEStateTestApp(secret) + podB := newJWEStateTestApp(secret) + mintedOnA, err := podA.encodePendingAuth(oauthPendingAuth{ + ClientID: "cid", + ExpiresAt: time.Now().Add(10 * time.Minute), + }) + require.NoError(t, err) + got, ok := podB.decodePendingAuth(mintedOnA) + require.True(t, ok) + require.Equal(t, "cid", got.ClientID) + }) - store.authCodes["earliest"] = oauthIssuedCode{ - ClientID: "early", - ExpiresAt: time.Now().Add(-time.Minute), - } + t.Run("cross_pod_rejected_with_different_secret", func(t *testing.T) { + t.Parallel() + podA := newJWEStateTestApp(secret) + podB := newJWEStateTestApp("a-different-secret-32-bytes-long!") + mintedOnA, err := podA.encodeAuthCode(oauthIssuedCode{ + ClientID: "cid", + ExpiresAt: time.Now().Add(60 * time.Second), + }) + require.NoError(t, err) + _, ok := podB.decodeAuthCode(mintedOnA) + require.False(t, ok, "JWE minted with a different secret must not decode") + }) - store.putAuthCode("overflow", oauthIssuedCode{ - ClientID: "new", - ExpiresAt: time.Now().Add(time.Hour), + t.Run("expired_auth_code_rejected", func(t *testing.T) { + t.Parallel() + app := newJWEStateTestApp(secret) + tok, err := app.encodeAuthCode(oauthIssuedCode{ + ClientID: "cid", + ExpiresAt: time.Now().Add(-1 * time.Second), }) - require.LessOrEqual(t, len(store.authCodes), maxOAuthStateEntries) - _, ok := store.authCodes["earliest"] - require.False(t, ok, "earliest entry should have been evicted") - _, ok = store.authCodes["overflow"] - require.True(t, ok, "new entry should be present") + require.NoError(t, err) + _, ok := app.decodeAuthCode(tok) + require.False(t, ok, "expired auth code must be rejected by JWE exp validation") }) - t.Run("expired_entries_cleaned_before_cap_check", func(t *testing.T) { + t.Run("expired_pending_auth_rejected", func(t *testing.T) { t.Parallel() - store := newOAuthStateStore() - // Fill with already-expired entries - for i := 0; i < maxOAuthStateEntries; i++ { - store.pendingAuth[fmt.Sprintf("exp_%d", i)] = oauthPendingAuth{ - ClientID: "client", - ExpiresAt: time.Now().Add(-time.Second), - } - } - require.Equal(t, maxOAuthStateEntries, len(store.pendingAuth)) + app := newJWEStateTestApp(secret) + tok, err := app.encodePendingAuth(oauthPendingAuth{ + ClientID: "cid", + ExpiresAt: time.Now().Add(-1 * time.Second), + }) + require.NoError(t, err) + _, ok := app.decodePendingAuth(tok) + require.False(t, ok) + }) - // putPendingAuth cleans expired first, so this should succeed without eviction - store.putPendingAuth("fresh", oauthPendingAuth{ - ClientID: "new", - ExpiresAt: time.Now().Add(time.Hour), + t.Run("tampered_token_rejected", func(t *testing.T) { + t.Parallel() + app := newJWEStateTestApp(secret) + tok, err := app.encodeAuthCode(oauthIssuedCode{ + ClientID: "cid", + ExpiresAt: time.Now().Add(60 * time.Second), }) - require.Equal(t, 1, len(store.pendingAuth)) - _, ok := store.pendingAuth["fresh"] - require.True(t, ok) + require.NoError(t, err) + // Flip a byte in the JWE ciphertext. + bs := []byte(tok) + bs[len(bs)/2] ^= 0x01 + _, ok := app.decodeAuthCode(string(bs)) + require.False(t, ok) + }) + + t.Run("decode_missing_secret_fails_cleanly", func(t *testing.T) { + t.Parallel() + app := newJWEStateTestApp("") + _, ok := app.decodePendingAuth("anything") + require.False(t, ok) + _, ok = app.decodeAuthCode("anything") + require.False(t, ok) }) } @@ -2215,54 +2287,6 @@ func TestNormalizeURL(t *testing.T) { } } -func TestOAuthStateStore(t *testing.T) { - t.Parallel() - - t.Run("put_and_consume_pending_auth", func(t *testing.T) { - t.Parallel() - store := newOAuthStateStore() - pending := oauthPendingAuth{ExpiresAt: time.Now().Add(time.Hour)} - store.putPendingAuth("key1", pending) - - got, ok := store.consumePendingAuth("key1") - require.True(t, ok) - require.Equal(t, pending.ExpiresAt.Unix(), got.ExpiresAt.Unix()) - - _, ok = store.consumePendingAuth("key1") - require.False(t, ok) - }) - - t.Run("put_and_consume_auth_code", func(t *testing.T) { - t.Parallel() - store := newOAuthStateStore() - issued := oauthIssuedCode{ExpiresAt: time.Now().Add(time.Hour)} - store.putAuthCode("code1", issued) - - got, ok := store.consumeAuthCode("code1") - require.True(t, ok) - require.Equal(t, issued.ExpiresAt.Unix(), got.ExpiresAt.Unix()) - - _, ok = store.consumeAuthCode("code1") - require.False(t, ok) - }) - - t.Run("expired_entries_cleaned_up", func(t *testing.T) { - t.Parallel() - store := newOAuthStateStore() - store.putPendingAuth("expired", oauthPendingAuth{ExpiresAt: time.Now().Add(-time.Hour)}) - store.putAuthCode("expired", oauthIssuedCode{ExpiresAt: time.Now().Add(-time.Hour)}) - - // Next put triggers cleanup - store.putPendingAuth("fresh", oauthPendingAuth{ExpiresAt: time.Now().Add(time.Hour)}) - store.putAuthCode("fresh", oauthIssuedCode{ExpiresAt: time.Now().Add(time.Hour)}) - - _, ok := store.consumePendingAuth("expired") - require.False(t, ok) - _, ok = store.consumeAuthCode("expired") - require.False(t, ok) - }) -} - func TestSanitizeScope(t *testing.T) { t.Parallel() require.Equal(t, "read write", sanitizeScope(" read write ")) diff --git a/pkg/jwe_auth/jwe_auth.go b/pkg/jwe_auth/jwe_auth.go index 3ac6c90..e9eb7f7 100644 --- a/pkg/jwe_auth/jwe_auth.go +++ b/pkg/jwe_auth/jwe_auth.go @@ -265,6 +265,8 @@ func validateClaimsWhitelist(claims map[string]interface{}) error { "upstream_bearer_token": true, "upstream_refresh_token": true, "upstream_token_type": true, + "upstream_pkce_verifier": true, + "resource": true, "access_token_exp": true, "email": true, "name": true,