diff --git a/cmd/gateway/main.go b/cmd/gateway/main.go index 74936da..46dcf5a 100644 --- a/cmd/gateway/main.go +++ b/cmd/gateway/main.go @@ -5,9 +5,11 @@ package main import ( "context" + "github.com/redis/go-redis/v9" "go.uber.org/zap" "github.com/alpnuhoglu/gamemesh/internal/gateway" + "github.com/alpnuhoglu/gamemesh/internal/player" "github.com/alpnuhoglu/gamemesh/pkg/auth" "github.com/alpnuhoglu/gamemesh/pkg/config" "github.com/alpnuhoglu/gamemesh/pkg/logger" @@ -45,9 +47,28 @@ func main() { m := metrics.New(cfg.ServiceName) tokens := auth.NewTokenManager(cfg.JWTSecret, cfg.JWTExpiry, cfg.JWTIssuer) + // The gateway is the boundary node that enforces JWT revocation (logout) so + // a revoked token is rejected here, before any upstream spends CPU on it. It + // reads the same cluster Redis the player service writes sessions to; a + // short-TTL in-process cache keeps that check off the Redis hot path. + rdb := redis.NewClient(&redis.Options{ + Addr: cfg.RedisAddr, + Password: cfg.RedisPassword, + DB: cfg.RedisDB, + }) + if err := tracing.InstrumentRedis(rdb); err != nil { + log.Fatal("failed to instrument redis", zap.Error(err)) + } + var sessions middleware.SessionChecker + store := player.NewSessionStore(rdb) + if cfg.SessionCacheEnabled { + store = player.NewCachedSessionStore(store, cfg.SessionCacheTTL) + } + sessions = store + engine := server.NewEngine(cfg, log, m) engine.Use(middleware.RateLimit(cfg.RateLimitRPS, cfg.RateLimitBurst)) - gateway.RegisterRoutes(engine, cfg, tokens, log) + gateway.RegisterRoutes(engine, cfg, tokens, sessions, log) if err := server.Run(engine, cfg.HTTPPort, log); err != nil { log.Fatal("server exited", zap.Error(err)) diff --git a/internal/gateway/router.go b/internal/gateway/router.go index 2175afc..060e8c2 100644 --- a/internal/gateway/router.go +++ b/internal/gateway/router.go @@ -12,13 +12,17 @@ import ( // RegisterRoutes wires every public route to its upstream service. // Routes are declared explicitly (no blanket wildcards) so the gateway is // also the authoritative, reviewable map of the public API surface. -func RegisterRoutes(r *gin.Engine, cfg *config.Config, tokens *auth.TokenManager, log *zap.Logger) { +// +// sessions enforces server-side revocation on protected routes; pass a +// cache-backed checker so the per-request check stays off the Redis hot path. A +// nil sessions disables revocation (JWT-only) — handy for tests. +func RegisterRoutes(r *gin.Engine, cfg *config.Config, tokens *auth.TokenManager, sessions middleware.SessionChecker, log *zap.Logger) { playerProxy := newProxy(cfg.PlayerServiceURL, log) matchProxy := newProxy(cfg.MatchmakingServiceURL, log) lbProxy := newProxy(cfg.LeaderboardServiceURL, log) wsProxy := newProxy(cfg.WebsocketServiceURL, log) - authRequired := middleware.Auth(tokens) + authRequired := middleware.Auth(tokens, sessions) v1 := r.Group("/api/v1") diff --git a/internal/gateway/router_test.go b/internal/gateway/router_test.go index f2d2560..52aca74 100644 --- a/internal/gateway/router_test.go +++ b/internal/gateway/router_test.go @@ -43,7 +43,7 @@ func newGatewayUnderTest(t *testing.T, upstreamURL string) (*httptest.Server, *a gin.SetMode(gin.TestMode) r := gin.New() r.Use(middleware.RequestID()) - RegisterRoutes(r, cfg, tokens, zap.NewNop()) + RegisterRoutes(r, cfg, tokens, nil, zap.NewNop()) // nil sessions → JWT-only srv := httptest.NewServer(r) t.Cleanup(srv.Close) diff --git a/internal/player/session.go b/internal/player/session.go index b21ac69..5bf5ef3 100644 --- a/internal/player/session.go +++ b/internal/player/session.go @@ -2,6 +2,7 @@ package player import ( "context" + "sync" "time" "github.com/redis/go-redis/v9" @@ -42,3 +43,122 @@ func (s *redisSessionStore) Exists(ctx context.Context, jti string) (bool, error n, err := s.rdb.Exists(ctx, sessionKey(jti)).Result() return n > 0, err } + +// cachedSessionStore wraps a SessionStore with a short-lived, in-process +// positive cache so the gateway can enforce JWT revocation on every +// authenticated request without hitting Redis each time. +// +// Positive cache: a "valid" verdict for a JTI is trusted for ttl (e.g. 5s) +// before Redis is consulted again. That ttl is the worst-case window a revoked +// token stays usable on this replica — an accepted eventual-consistency +// trade-off that avoids the multi-replica complexity of a synchronised +// revocation (negative) list. Delete drops the local entry immediately, so a +// logout served by the same replica revokes at once. +// +// Concurrency: the access pattern is read-mostly (one Redis lookup per JTI per +// ttl window, then many lock-free-ish RLock reads), so a single RWMutex is +// ample — at the gateway's RPS and CPU budget a sharded map would be +// over-engineering. Because this is a decorator behind the SessionStore +// interface, the internal map can be swapped for shards later WITHOUT touching +// any caller if profiling ever shows contention. The RWMutex + map + janitor +// shape deliberately mirrors middleware.ipLimiter for codebase cohesion. +type cachedSessionStore struct { + inner SessionStore + ttl time.Duration + + mu sync.RWMutex + entries map[string]sessionEntry +} + +type sessionEntry struct { + valid bool + expiresAt time.Time +} + +// NewCachedSessionStore wraps store with an in-process positive cache of the +// given ttl and starts a background janitor that evicts expired entries. +func NewCachedSessionStore(store SessionStore, ttl time.Duration) SessionStore { + c := &cachedSessionStore{ + inner: store, + ttl: ttl, + entries: make(map[string]sessionEntry), + } + go c.janitor() + return c +} + +// Save delegates to Redis and seeds the local cache so the just-created session +// is immediately served from memory (a login is normally followed by requests). +func (c *cachedSessionStore) Save(ctx context.Context, jti, playerID string, ttl time.Duration) error { + if err := c.inner.Save(ctx, jti, playerID, ttl); err != nil { + return err + } + c.set(jti, true) + return nil +} + +// Delete removes the session from Redis AND evicts the local entry so a logout +// handled by this replica takes effect at once (no ttl wait on this node). +func (c *cachedSessionStore) Delete(ctx context.Context, jti string) error { + c.evict(jti) + return c.inner.Delete(ctx, jti) +} + +// Exists serves a cached verdict when fresh; otherwise it consults Redis once +// and caches the result for ttl. +func (c *cachedSessionStore) Exists(ctx context.Context, jti string) (bool, error) { + if v, ok := c.get(jti); ok { + return v, nil + } + exists, err := c.inner.Exists(ctx, jti) + if err != nil { + // Fail open is unsafe for revocation; surface the error and let the + // caller decide (the middleware rejects on error). + return false, err + } + c.set(jti, exists) + return exists, nil +} + +func (c *cachedSessionStore) get(jti string) (bool, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + e, ok := c.entries[jti] + if !ok || time.Now().After(e.expiresAt) { + return false, false + } + return e.valid, true +} + +func (c *cachedSessionStore) set(jti string, valid bool) { + c.mu.Lock() + c.entries[jti] = sessionEntry{valid: valid, expiresAt: time.Now().Add(c.ttl)} + c.mu.Unlock() +} + +func (c *cachedSessionStore) evict(jti string) { + c.mu.Lock() + delete(c.entries, jti) + c.mu.Unlock() +} + +// janitor evicts expired entries every ttl so the map cannot grow unbounded +// from one-shot JTIs. Mirrors middleware.ipLimiter.janitor. +func (c *cachedSessionStore) janitor() { + for range time.Tick(c.ttl) { + now := time.Now() + c.mu.Lock() + // Pre-size the kill list to the current map length: at most every entry + // is expired, so the slice never grows during collection. + expired := make([]string, 0, len(c.entries)) + for jti, e := range c.entries { + if now.After(e.expiresAt) { + expired = append(expired, jti) + } + } + for _, jti := range expired { + delete(c.entries, jti) + } + c.mu.Unlock() + } +} diff --git a/internal/player/session_test.go b/internal/player/session_test.go index 40c95d1..90a069a 100644 --- a/internal/player/session_test.go +++ b/internal/player/session_test.go @@ -2,6 +2,9 @@ package player import ( "context" + "errors" + "sync" + "sync/atomic" "testing" "time" @@ -11,6 +14,45 @@ import ( "github.com/stretchr/testify/require" ) +// countingStore is an in-memory SessionStore that records how many times each +// method hits it, so cache tests can prove Redis was (not) consulted. +type countingStore struct { + mu sync.Mutex + present map[string]bool + existsCalls int32 + failExists bool +} + +func newCountingStore() *countingStore { + return &countingStore{present: map[string]bool{}} +} + +func (s *countingStore) Save(_ context.Context, jti, _ string, _ time.Duration) error { + s.mu.Lock() + s.present[jti] = true + s.mu.Unlock() + return nil +} + +func (s *countingStore) Delete(_ context.Context, jti string) error { + s.mu.Lock() + delete(s.present, jti) + s.mu.Unlock() + return nil +} + +func (s *countingStore) Exists(_ context.Context, jti string) (bool, error) { + atomic.AddInt32(&s.existsCalls, 1) + if s.failExists { + return false, errors.New("redis down") + } + s.mu.Lock() + defer s.mu.Unlock() + return s.present[jti], nil +} + +func (s *countingStore) calls() int32 { return atomic.LoadInt32(&s.existsCalls) } + func TestSessionStoreLifecycle(t *testing.T) { mr := miniredis.RunT(t) rdb := redis.NewClient(&redis.Options{Addr: mr.Addr()}) @@ -46,3 +88,89 @@ func TestSessionStoreTTLExpiry(t *testing.T) { require.NoError(t, err) assert.False(t, exists, "session must expire with the JWT") } + +func TestCachedSessionStore_CacheHitSkipsRedis(t *testing.T) { + inner := newCountingStore() + require.NoError(t, inner.Save(context.Background(), "jti", "p1", time.Hour)) + cache := NewCachedSessionStore(inner, time.Minute) + ctx := context.Background() + + // First lookup misses the cache → hits Redis once. + ok, err := cache.Exists(ctx, "jti") + require.NoError(t, err) + assert.True(t, ok) + assert.Equal(t, int32(1), inner.calls()) + + // Subsequent lookups within the TTL are served from cache → no more Redis. + for i := 0; i < 50; i++ { + ok, err = cache.Exists(ctx, "jti") + require.NoError(t, err) + assert.True(t, ok) + } + assert.Equal(t, int32(1), inner.calls(), "cache hits must not hit Redis") +} + +func TestCachedSessionStore_TTLReChecksRedis(t *testing.T) { + inner := newCountingStore() + require.NoError(t, inner.Save(context.Background(), "jti", "p1", time.Hour)) + cache := NewCachedSessionStore(inner, 20*time.Millisecond) + ctx := context.Background() + + _, _ = cache.Exists(ctx, "jti") + assert.Equal(t, int32(1), inner.calls()) + + // After the cache entry expires, the next lookup re-checks Redis. + time.Sleep(40 * time.Millisecond) + _, _ = cache.Exists(ctx, "jti") + assert.Equal(t, int32(2), inner.calls(), "expired cache entry must re-check Redis") +} + +func TestCachedSessionStore_DeleteInvalidatesImmediately(t *testing.T) { + inner := newCountingStore() + require.NoError(t, inner.Save(context.Background(), "jti", "p1", time.Hour)) + cache := NewCachedSessionStore(inner, time.Minute) + ctx := context.Background() + + // Warm the cache with a "valid" verdict. + ok, _ := cache.Exists(ctx, "jti") + require.True(t, ok) + + // Logout on this replica: Delete must evict the local entry at once, so the + // next check re-reads Redis (now absent) rather than serving the stale cache. + require.NoError(t, cache.Delete(ctx, "jti")) + ok, err := cache.Exists(ctx, "jti") + require.NoError(t, err) + assert.False(t, ok, "deleted session must not be served from cache") +} + +func TestCachedSessionStore_PropagatesError(t *testing.T) { + inner := newCountingStore() + inner.failExists = true + cache := NewCachedSessionStore(inner, time.Minute) + + _, err := cache.Exists(context.Background(), "jti") + require.Error(t, err, "store errors must surface (caller fails closed)") +} + +func TestCachedSessionStore_ConcurrentAccess(t *testing.T) { + inner := newCountingStore() + require.NoError(t, inner.Save(context.Background(), "jti", "p1", time.Hour)) + cache := NewCachedSessionStore(inner, 5*time.Millisecond) + ctx := context.Background() + + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 20; j++ { + _, _ = cache.Exists(ctx, "jti") + if j%5 == 0 { + _ = cache.Delete(ctx, "jti") + _ = cache.Save(ctx, "jti", "p1", time.Hour) + } + } + }() + } + wg.Wait() +} diff --git a/pkg/config/config.go b/pkg/config/config.go index dc5068e..af8ec3a 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -91,6 +91,14 @@ type Config struct { LogSlowRequestThreshold time.Duration // requests slower than this always log TraceHighVolumeEvents []string // event types whose consumer spans are sampled down TraceHighVolumeSampleRatio float64 // keep-ratio for those high-volume spans + + // Session revocation cache. The gateway checks the session store on every + // authenticated request to honour server-side logout; an in-process positive + // cache keeps that off the Redis hot path. SessionCacheTTL is how long a + // "valid" verdict is trusted before re-checking Redis — i.e. the worst-case + // window a revoked token stays usable (eventual consistency, accepted). + SessionCacheEnabled bool + SessionCacheTTL time.Duration } // Load reads configuration for the named service from the environment. @@ -163,6 +171,9 @@ func Load(serviceName string) *Config { LogSlowRequestThreshold: getEnvDuration("LOG_SLOW_REQUEST_THRESHOLD", time.Second), TraceHighVolumeEvents: splitCSV(getEnv("TRACE_HIGHVOLUME_EVENTS", "LeaderboardUpdated")), TraceHighVolumeSampleRatio: getEnvFloat("TRACE_HIGHVOLUME_SAMPLE_RATIO", 0.01), + + SessionCacheEnabled: getEnvBool("SESSION_CACHE_ENABLED", true), + SessionCacheTTL: getEnvDuration("SESSION_CACHE_TTL", 5*time.Second), } } diff --git a/pkg/middleware/middleware.go b/pkg/middleware/middleware.go index fedf759..19cd6af 100644 --- a/pkg/middleware/middleware.go +++ b/pkg/middleware/middleware.go @@ -4,6 +4,7 @@ package middleware import ( + "context" "net/http" "strconv" "strings" @@ -209,10 +210,25 @@ func RateLimit(rps float64, burst int) gin.HandlerFunc { } } +// SessionChecker reports whether a session (by JWT ID / JTI) is still active. +// It is a minimal, locally-defined interface so this package depends on a +// behaviour, not on internal/player — Go structural typing lets the player +// SessionStore satisfy it for free. A *cached* checker should be passed so this +// per-request call does not hit Redis every time. +type SessionChecker interface { + Exists(ctx context.Context, jti string) (bool, error) +} + // Auth validates the Bearer token and stores the caller identity in the // context. Used by the gateway (and the WebSocket service directly, since // browsers cannot set headers on WS upgrade requests). -func Auth(tm *auth.TokenManager) gin.HandlerFunc { +// +// When sessions is non-nil, Auth also enforces server-side revocation: after +// the JWT is cryptographically valid, the session (by JTI) must still exist, so +// a logged-out token is rejected at the gateway before reaching any upstream. +// A nil sessions disables the check (JWT-only), preserving the previous +// behaviour for callers/tests that do not wire a store. +func Auth(tm *auth.TokenManager, sessions SessionChecker) gin.HandlerFunc { return func(c *gin.Context) { header := c.GetHeader("Authorization") const prefix = "Bearer " @@ -225,6 +241,24 @@ func Auth(tm *auth.TokenManager) gin.HandlerFunc { httpx.Error(c, http.StatusUnauthorized, "invalid or expired token") return } + + // Enforce server-side revocation (logout) when a session store is wired. + // The store is expected to be cache-backed, so this is normally an + // in-process lookup, not a Redis round trip. On a store error we reject + // rather than fail open: a revocation check that cannot run must not let + // a possibly-revoked token through. + if sessions != nil { + active, serr := sessions.Exists(c.Request.Context(), claims.ID) + if serr != nil { + httpx.Error(c, http.StatusServiceUnavailable, "session check unavailable") + return + } + if !active { + httpx.Error(c, http.StatusUnauthorized, "token revoked") + return + } + } + c.Set(CtxUserID, claims.PlayerID) c.Set(CtxUsername, claims.Username) c.Set(CtxTokenJTI, claims.ID) diff --git a/pkg/middleware/middleware_test.go b/pkg/middleware/middleware_test.go index bcead3e..23ee014 100644 --- a/pkg/middleware/middleware_test.go +++ b/pkg/middleware/middleware_test.go @@ -1,6 +1,8 @@ package middleware import ( + "context" + "errors" "net/http" "net/http/httptest" "testing" @@ -169,7 +171,7 @@ func TestAuthMiddleware(t *testing.T) { gin.SetMode(gin.TestMode) r := gin.New() - r.Use(Auth(tm)) + r.Use(Auth(tm, nil)) // nil sessions → JWT-only (revocation check disabled) r.GET("/protected", func(c *gin.Context) { c.JSON(200, gin.H{ "user_id": c.GetString(CtxUserID), @@ -205,3 +207,39 @@ func TestAuthMiddleware(t *testing.T) { assert.Contains(t, w.Body.String(), playerID.String()) assert.Contains(t, w.Body.String(), jti) } + +// stubSessions is a fixed-verdict SessionChecker for the revocation tests. +type stubSessions struct { + active bool + err error +} + +func (s stubSessions) Exists(context.Context, string) (bool, error) { return s.active, s.err } + +func TestAuthMiddlewareRevocation(t *testing.T) { + tm := auth.NewTokenManager("test-secret", time.Hour, "gamemesh") + token, _, err := tm.Generate(uuid.New(), "alice") + require.NoError(t, err) + + newReq := func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer "+token) + return req + } + run := func(sessions SessionChecker) int { + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(Auth(tm, sessions)) + r.GET("/protected", func(c *gin.Context) { c.JSON(200, gin.H{}) }) + w := httptest.NewRecorder() + r.ServeHTTP(w, newReq()) + return w.Code + } + + // Active session → allowed. + assert.Equal(t, http.StatusOK, run(stubSessions{active: true})) + // Revoked (logged-out) session → 401, even though the JWT is cryptographically valid. + assert.Equal(t, http.StatusUnauthorized, run(stubSessions{active: false})) + // Store unavailable → fail closed (503), never let a possibly-revoked token through. + assert.Equal(t, http.StatusServiceUnavailable, run(stubSessions{err: errors.New("redis down")})) +}