From c85cb17e20eaf0f2708df16a690a95daf8b1345c Mon Sep 17 00:00:00 2001 From: Jonathan Haas Date: Wed, 15 Apr 2026 15:10:52 -0700 Subject: [PATCH 1/3] Harden session JWT revocation --- db/migrations/0002_session_token_id.sql | 2 + internal/app/service.go | 15 +++ internal/app/service_test.go | 87 +++++++++++++ internal/bootstrap/service.go | 8 +- internal/core/types.go | 4 + internal/crypto/sessionjwt/manager.go | 76 +++++++++-- internal/crypto/sessionjwt/manager_test.go | 140 +++++++++++++++++++++ internal/store/memory/runtime.go | 47 +++++-- internal/store/memory/runtime_test.go | 11 ++ internal/store/postgres/repository.go | 11 +- internal/store/redis/runtime.go | 33 +++++ internal/store/redis/runtime_test.go | 11 ++ 12 files changed, 420 insertions(+), 25 deletions(-) create mode 100644 db/migrations/0002_session_token_id.sql create mode 100644 internal/crypto/sessionjwt/manager_test.go diff --git a/db/migrations/0002_session_token_id.sql b/db/migrations/0002_session_token_id.sql new file mode 100644 index 0000000..6256c67 --- /dev/null +++ b/db/migrations/0002_session_token_id.sql @@ -0,0 +1,2 @@ +ALTER TABLE sessions + ADD COLUMN IF NOT EXISTS token_id TEXT NOT NULL DEFAULT ''; diff --git a/internal/app/service.go b/internal/app/service.go index 8e5dc3d..22f153b 100644 --- a/internal/app/service.go +++ b/internal/app/service.go @@ -131,6 +131,7 @@ func (s *Service) CreateSession(ctx context.Context, req *core.CreateSessionRequ State: core.SessionStateActive, CreatedAt: now, } + session.TokenID = "tok_" + session.ID if err := s.repo.SaveSession(ctx, session); err != nil { return nil, fmt.Errorf("create session %q: save session: %w", session.ID, err) } @@ -446,6 +447,11 @@ func (s *Service) RevokeSession(ctx context.Context, req *core.RevokeSessionRequ return fmt.Errorf("revoke session %q: save session: %w", session.ID, err) } s.metrics.recordSessionTransition(previousState, session.State, session.TenantID) + if s.runtime != nil && session.TokenID != "" { + if err := s.runtime.RevokeSessionToken(ctx, session.TokenID, session.ExpiresAt); err != nil { + return fmt.Errorf("revoke session %q: revoke session token %q: %w", session.ID, session.TokenID, err) + } + } grants, err := s.repo.ListGrantsBySession(ctx, req.SessionID) if err != nil { @@ -836,6 +842,15 @@ func (s *Service) loadActiveSession(ctx context.Context, raw string) (*core.Sess if err != nil { return nil, nil, fmt.Errorf("load active session: verify session token: %w", err) } + if s.runtime != nil && claims.TokenID != "" { + revoked, err := s.runtime.IsSessionTokenRevoked(ctx, claims.TokenID) + if err != nil { + return nil, nil, fmt.Errorf("load active session: check session token revocation: %w", err) + } + if revoked { + return nil, nil, fmt.Errorf("%w: session token revoked", core.ErrForbidden) + } + } session, err := s.repo.GetSession(ctx, claims.SessionID) if err != nil { return nil, nil, fmt.Errorf("load active session %q: load session: %w", claims.SessionID, err) diff --git a/internal/app/service_test.go b/internal/app/service_test.go index 0735893..891f9a0 100644 --- a/internal/app/service_test.go +++ b/internal/app/service_test.go @@ -573,6 +573,93 @@ func TestService_RevokeSessionRevokesOutstandingGrants(t *testing.T) { } } +func TestService_RequestGrantRejectsRevokedSessionToken(t *testing.T) { + t.Parallel() + + ctx := context.Background() + now := testNow() + repo := memstore.NewRepository() + runtime := memstore.NewRuntimeStore() + tools := toolregistry.New() + engine := policy.NewEngine() + signer := mustNewSigner(t) + + mustPutTool(t, ctx, tools, core.Tool{ + TenantID: "t_acme", + Tool: "github", + ManifestHash: "sha256:test", + RuntimeClass: core.RuntimeClassHosted, + AllowedDeliveryModes: []core.DeliveryMode{core.DeliveryModeProxy}, + AllowedCapabilities: []string{"repo.read"}, + TrustTags: []string{"trusted", "github"}, + }) + mustPutPolicy(t, engine, core.Policy{ + TenantID: "t_acme", + Capability: "repo.read", + ResourceKind: core.ResourceKindGitHubRepo, + AllowedDeliveryModes: []core.DeliveryMode{core.DeliveryModeProxy}, + DefaultTTL: 10 * time.Minute, + MaxTTL: 10 * time.Minute, + ApprovalMode: core.ApprovalModeNone, + RequiredToolTags: []string{"trusted", "github"}, + Condition: `request.tool == "github"`, + }) + + svc, err := app.NewService(app.Config{ + Clock: fixedClock(now), + IDs: fixedIDs("sess_revoked"), + Repository: repo, + Runtime: runtime, + Verifier: fakeVerifier{identity: workloadIdentity()}, + SessionTokens: signer, + Policy: engine, + Tools: tools, + Connectors: fakeConnectorResolver{connector: &fakeConnector{kind: "github"}}, + Deliveries: map[core.DeliveryMode]core.DeliveryAdapter{ + core.DeliveryModeProxy: &fakeDeliveryAdapter{ + mode: core.DeliveryModeProxy, + delivery: &core.Delivery{ + Kind: core.DeliveryKindProxyHandle, + Handle: "ph_revoked", + }, + }, + }, + }) + if err != nil { + t.Fatalf("NewService() error = %v", err) + } + + sessionResp, err := svc.CreateSession(ctx, &core.CreateSessionRequest{ + TenantID: "t_acme", + AgentID: "agent_pr_reviewer", + RunID: "run_revoked", + ToolContext: []string{"github"}, + Attestation: &core.Attestation{Kind: core.AttestationKindK8SServiceAccountJWT, Token: "jwt"}, + }) + if err != nil { + t.Fatalf("CreateSession() error = %v", err) + } + + claims, err := signer.Verify(sessionResp.SessionToken) + if err != nil { + t.Fatalf("Verify() error = %v", err) + } + if err := runtime.RevokeSessionToken(ctx, claims.TokenID, claims.ExpiresAt); err != nil { + t.Fatalf("RevokeSessionToken() error = %v", err) + } + + _, err = svc.RequestGrant(ctx, &core.RequestGrantRequest{ + SessionToken: sessionResp.SessionToken, + Tool: "github", + Capability: "repo.read", + ResourceRef: "github:repo:acme/widgets", + DeliveryMode: core.DeliveryModeProxy, + }) + if err == nil || !strings.Contains(err.Error(), "session token revoked") { + t.Fatalf("RequestGrant() error = %v, want revoked token failure", err) + } +} + func mustNewSigner(t *testing.T) core.SessionTokenManager { t.Helper() diff --git a/internal/bootstrap/service.go b/internal/bootstrap/service.go index 5faa0a0..a63d591 100644 --- a/internal/bootstrap/service.go +++ b/internal/bootstrap/service.go @@ -403,18 +403,22 @@ func (v multiVerifier) Verify(ctx context.Context, in *core.Attestation) (*core. } func newSessionTokenManager() (core.SessionTokenManager, error) { + options := []sessionjwt.Option{} + if issuer := strings.TrimSpace(os.Getenv("ASB_SESSION_TOKEN_ISSUER")); issuer != "" { + options = append(options, sessionjwt.WithIssuer(issuer)) + } if path := os.Getenv("ASB_SESSION_SIGNING_PRIVATE_KEY_FILE"); path != "" { privateKey, err := loadEd25519PrivateKey(path) if err != nil { return nil, err } - return sessionjwt.NewManager(privateKey) + return sessionjwt.NewManager(privateKey, options...) } _, privateKey, err := ed25519.GenerateKey(rand.Reader) if err != nil { return nil, err } - return sessionjwt.NewManager(privateKey) + return sessionjwt.NewManager(privateKey, options...) } func newDelegationValidator() (core.DelegationValidator, error) { diff --git a/internal/core/types.go b/internal/core/types.go index 1216e24..c577609 100644 --- a/internal/core/types.go +++ b/internal/core/types.go @@ -133,6 +133,7 @@ type Session struct { TenantID string AgentID string RunID string + TokenID string WorkloadIdentity WorkloadIdentity Delegation *Delegation ToolContext []string @@ -395,6 +396,7 @@ type SessionClaims struct { TenantID string AgentID string RunID string + TokenID string ToolContext []string WorkloadHash string ExpiresAt time.Time @@ -504,6 +506,8 @@ type RuntimeStore interface { CompleteProxyRequest(ctx context.Context, handle string, responseBytes int64) error SaveRelaySession(ctx context.Context, relay *BrowserRelaySession) error GetRelaySession(ctx context.Context, sessionID string) (*BrowserRelaySession, error) + RevokeSessionToken(ctx context.Context, tokenID string, expiresAt time.Time) error + IsSessionTokenRevoked(ctx context.Context, tokenID string) (bool, error) } type GitHubProxyExecutor interface { diff --git a/internal/crypto/sessionjwt/manager.go b/internal/crypto/sessionjwt/manager.go index f0d2a75..d05499d 100644 --- a/internal/crypto/sessionjwt/manager.go +++ b/internal/crypto/sessionjwt/manager.go @@ -12,8 +12,18 @@ import ( type Manager struct { privateKey ed25519.PrivateKey publicKey ed25519.PublicKey + issuer string + clockSkew time.Duration + now func() time.Time } +type Option func(*Manager) + +const ( + defaultIssuer = "asb" + defaultClockSkew = 30 * time.Second +) + type claims struct { SessionID string `json:"sid"` TenantID string `json:"tenant_id"` @@ -24,7 +34,31 @@ type claims struct { jwt.RegisteredClaims } -func NewManager(privateKey ed25519.PrivateKey) (*Manager, error) { +func WithIssuer(issuer string) Option { + return func(manager *Manager) { + if issuer != "" { + manager.issuer = issuer + } + } +} + +func WithClockSkew(clockSkew time.Duration) Option { + return func(manager *Manager) { + if clockSkew >= 0 { + manager.clockSkew = clockSkew + } + } +} + +func WithNowFunc(now func() time.Time) Option { + return func(manager *Manager) { + if now != nil { + manager.now = now + } + } +} + +func NewManager(privateKey ed25519.PrivateKey, options ...Option) (*Manager, error) { if len(privateKey) == 0 { return nil, fmt.Errorf("%w: private key is required", core.ErrInvalidRequest) } @@ -32,13 +66,24 @@ func NewManager(privateKey ed25519.PrivateKey) (*Manager, error) { if !ok { return nil, fmt.Errorf("%w: private key public component is %T, want ed25519.PublicKey", core.ErrInvalidRequest, privateKey.Public()) } - return &Manager{ + manager := &Manager{ privateKey: privateKey, publicKey: publicKey, - }, nil + issuer: defaultIssuer, + clockSkew: defaultClockSkew, + now: time.Now, + } + for _, option := range options { + option(manager) + } + return manager, nil } func (m *Manager) Sign(session *core.Session) (string, error) { + tokenID := session.TokenID + if tokenID == "" { + tokenID = session.ID + } token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims{ SessionID: session.ID, TenantID: session.TenantID, @@ -49,7 +94,10 @@ func (m *Manager) Sign(session *core.Session) (string, error) { RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(session.ExpiresAt), IssuedAt: jwt.NewNumericDate(session.CreatedAt), + NotBefore: jwt.NewNumericDate(session.CreatedAt.Add(-m.clockSkew)), + Issuer: m.issuer, Subject: session.WorkloadIdentity.Subject, + ID: tokenID, }, }) return token.SignedString(m.privateKey) @@ -61,7 +109,7 @@ func (m *Manager) Verify(raw string) (*core.SessionClaims, error) { return nil, fmt.Errorf("%w: unexpected signing method %q", core.ErrUnauthorized, token.Method.Alg()) } return m.publicKey, nil - }) + }, jwt.WithIssuer(m.issuer), jwt.WithIssuedAt(), jwt.WithLeeway(m.clockSkew), jwt.WithTimeFunc(m.now)) if err != nil { return nil, fmt.Errorf("%w: %v", core.ErrUnauthorized, err) } @@ -70,21 +118,27 @@ func (m *Manager) Verify(raw string) (*core.SessionClaims, error) { if !ok || !parsed.Valid { return nil, fmt.Errorf("%w: invalid session token", core.ErrUnauthorized) } + if tokenClaims.ExpiresAt == nil { + return nil, fmt.Errorf("%w: missing exp", core.ErrUnauthorized) + } + if tokenClaims.NotBefore == nil { + return nil, fmt.Errorf("%w: missing nbf", core.ErrUnauthorized) + } + if tokenClaims.ID == "" { + return nil, fmt.Errorf("%w: missing jti", core.ErrUnauthorized) + } + if tokenClaims.SessionID == "" || tokenClaims.TenantID == "" { + return nil, fmt.Errorf("%w: missing required session claims", core.ErrUnauthorized) + } return &core.SessionClaims{ SessionID: tokenClaims.SessionID, TenantID: tokenClaims.TenantID, AgentID: tokenClaims.AgentID, RunID: tokenClaims.RunID, + TokenID: tokenClaims.ID, ToolContext: append([]string(nil), tokenClaims.ToolContext...), WorkloadHash: tokenClaims.WorkloadHash, ExpiresAt: tokenClaims.ExpiresAt.Time, }, nil } - -func (c *claims) Valid() error { - if c.ExpiresAt == nil { - return fmt.Errorf("%w: missing exp", core.ErrUnauthorized) - } - return jwt.NewValidator(jwt.WithTimeFunc(time.Now)).Validate(c.RegisteredClaims) -} diff --git a/internal/crypto/sessionjwt/manager_test.go b/internal/crypto/sessionjwt/manager_test.go new file mode 100644 index 0000000..42d802e --- /dev/null +++ b/internal/crypto/sessionjwt/manager_test.go @@ -0,0 +1,140 @@ +package sessionjwt + +import ( + "crypto/ed25519" + "strings" + "testing" + "time" + + "github.com/evalops/asb/internal/core" + "github.com/golang-jwt/jwt/v5" +) + +func TestManagerSignAddsStandardClaims(t *testing.T) { + t.Parallel() + + now := time.Date(2026, 4, 15, 22, 0, 0, 0, time.UTC) + _, privateKey, err := ed25519.GenerateKey(nil) + if err != nil { + t.Fatalf("GenerateKey() error = %v", err) + } + + manager, err := NewManager(privateKey, WithIssuer("asb.example"), WithNowFunc(func() time.Time { return now })) + if err != nil { + t.Fatalf("NewManager() error = %v", err) + } + + raw, err := manager.Sign(&core.Session{ + ID: "sess_123", + TokenID: "tok_123", + TenantID: "t_acme", + AgentID: "agent_pr_reviewer", + RunID: "run_7f9", + CreatedAt: now, + ExpiresAt: now.Add(15 * time.Minute), + WorkloadIdentity: core.WorkloadIdentity{ + Subject: "system:serviceaccount:agents:runner", + }, + }) + if err != nil { + t.Fatalf("Sign() error = %v", err) + } + + parsed, err := jwt.ParseWithClaims(raw, &claims{}, func(token *jwt.Token) (any, error) { + return manager.publicKey, nil + }) + if err != nil { + t.Fatalf("ParseWithClaims() error = %v", err) + } + tokenClaims, ok := parsed.Claims.(*claims) + if !ok { + t.Fatalf("claims type = %T, want *claims", parsed.Claims) + } + if tokenClaims.ID != "tok_123" { + t.Fatalf("jti = %q, want tok_123", tokenClaims.ID) + } + if tokenClaims.Issuer != "asb.example" { + t.Fatalf("issuer = %q, want asb.example", tokenClaims.Issuer) + } + if tokenClaims.NotBefore == nil { + t.Fatal("nbf = nil, want non-nil") + } + if got := tokenClaims.NotBefore.Time; !got.Equal(now.Add(-defaultClockSkew)) { + t.Fatalf("nbf = %s, want %s", got, now.Add(-defaultClockSkew)) + } + + sessionClaims, err := manager.Verify(raw) + if err != nil { + t.Fatalf("Verify() error = %v", err) + } + if sessionClaims.TokenID != "tok_123" { + t.Fatalf("TokenID = %q, want tok_123", sessionClaims.TokenID) + } +} + +func TestManagerVerifyRejectsMissingJTI(t *testing.T) { + t.Parallel() + + now := time.Date(2026, 4, 15, 22, 0, 0, 0, time.UTC) + _, privateKey, err := ed25519.GenerateKey(nil) + if err != nil { + t.Fatalf("GenerateKey() error = %v", err) + } + + manager, err := NewManager(privateKey, WithNowFunc(func() time.Time { return now })) + if err != nil { + t.Fatalf("NewManager() error = %v", err) + } + + raw, err := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims{ + SessionID: "sess_123", + TenantID: "t_acme", + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(now.Add(15 * time.Minute)), + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now.Add(-defaultClockSkew)), + Issuer: defaultIssuer, + }, + }).SignedString(privateKey) + if err != nil { + t.Fatalf("SignedString() error = %v", err) + } + + if _, err := manager.Verify(raw); err == nil || !strings.Contains(err.Error(), "missing jti") { + t.Fatalf("Verify() error = %v, want missing jti", err) + } +} + +func TestManagerVerifyRejectsTokenBeforeNotBefore(t *testing.T) { + t.Parallel() + + now := time.Date(2026, 4, 15, 22, 0, 0, 0, time.UTC) + _, privateKey, err := ed25519.GenerateKey(nil) + if err != nil { + t.Fatalf("GenerateKey() error = %v", err) + } + + manager, err := NewManager(privateKey, WithNowFunc(func() time.Time { return now })) + if err != nil { + t.Fatalf("NewManager() error = %v", err) + } + + raw, err := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims{ + SessionID: "sess_123", + TenantID: "t_acme", + RegisteredClaims: jwt.RegisteredClaims{ + ID: "tok_123", + ExpiresAt: jwt.NewNumericDate(now.Add(15 * time.Minute)), + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now.Add(time.Minute)), + Issuer: defaultIssuer, + }, + }).SignedString(privateKey) + if err != nil { + t.Fatalf("SignedString() error = %v", err) + } + + if _, err := manager.Verify(raw); err == nil || !strings.Contains(err.Error(), "token is not valid yet") { + t.Fatalf("Verify() error = %v, want not-before validation failure", err) + } +} diff --git a/internal/store/memory/runtime.go b/internal/store/memory/runtime.go index 4582f83..9bc73a1 100644 --- a/internal/store/memory/runtime.go +++ b/internal/store/memory/runtime.go @@ -11,17 +11,19 @@ import ( ) type RuntimeStore struct { - mu sync.RWMutex - budgets map[string]*proxybudget.BudgetTracker - relays map[string]*core.BrowserRelaySession - expires map[string]time.Time + mu sync.RWMutex + budgets map[string]*proxybudget.BudgetTracker + relays map[string]*core.BrowserRelaySession + expires map[string]time.Time + tokenRevocations map[string]time.Time } func NewRuntimeStore() *RuntimeStore { return &RuntimeStore{ - budgets: make(map[string]*proxybudget.BudgetTracker), - relays: make(map[string]*core.BrowserRelaySession), - expires: make(map[string]time.Time), + budgets: make(map[string]*proxybudget.BudgetTracker), + relays: make(map[string]*core.BrowserRelaySession), + expires: make(map[string]time.Time), + tokenRevocations: make(map[string]time.Time), } } @@ -95,3 +97,34 @@ func (s *RuntimeStore) GetRelaySession(_ context.Context, sessionID string) (*co } return &cp, nil } + +func (s *RuntimeStore) RevokeSessionToken(_ context.Context, tokenID string, expiresAt time.Time) error { + if tokenID == "" { + return nil + } + + s.mu.Lock() + defer s.mu.Unlock() + + s.tokenRevocations[tokenID] = expiresAt + return nil +} + +func (s *RuntimeStore) IsSessionTokenRevoked(_ context.Context, tokenID string) (bool, error) { + if tokenID == "" { + return false, nil + } + + s.mu.Lock() + defer s.mu.Unlock() + + expiresAt, ok := s.tokenRevocations[tokenID] + if !ok { + return false, nil + } + if !expiresAt.IsZero() && time.Now().After(expiresAt) { + delete(s.tokenRevocations, tokenID) + return false, nil + } + return true, nil +} diff --git a/internal/store/memory/runtime_test.go b/internal/store/memory/runtime_test.go index 812d87b..466d362 100644 --- a/internal/store/memory/runtime_test.go +++ b/internal/store/memory/runtime_test.go @@ -52,4 +52,15 @@ func TestRuntimeStore_ProxyBudgetsAndRelaySessions(t *testing.T) { if got.KeyID != "key_1" || got.Selectors["username"] != "#username" { t.Fatalf("relay = %#v, want saved relay session", got) } + + if err := store.RevokeSessionToken(context.Background(), "tok_123", time.Now().Add(time.Minute)); err != nil { + t.Fatalf("RevokeSessionToken() error = %v", err) + } + revoked, err := store.IsSessionTokenRevoked(context.Background(), "tok_123") + if err != nil { + t.Fatalf("IsSessionTokenRevoked() error = %v", err) + } + if !revoked { + t.Fatal("IsSessionTokenRevoked() = false, want true") + } } diff --git a/internal/store/postgres/repository.go b/internal/store/postgres/repository.go index 781cb99..30f838c 100644 --- a/internal/store/postgres/repository.go +++ b/internal/store/postgres/repository.go @@ -77,15 +77,16 @@ func (r *Repository) SaveSession(ctx context.Context, session *core.Session) err return err } _, err = r.db.Exec(ctx, ` - INSERT INTO sessions (id, tenant_id, workload_id, delegation_id, agent_id, run_id, tool_context_json, workload_hash, state, expires_at, created_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) + INSERT INTO sessions (id, tenant_id, workload_id, delegation_id, agent_id, run_id, token_id, tool_context_json, workload_hash, state, expires_at, created_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) ON CONFLICT (id) DO UPDATE SET delegation_id = EXCLUDED.delegation_id, + token_id = EXCLUDED.token_id, tool_context_json = EXCLUDED.tool_context_json, workload_hash = EXCLUDED.workload_hash, state = EXCLUDED.state, expires_at = EXCLUDED.expires_at - `, session.ID, session.TenantID, workloadID, delegationID, session.AgentID, session.RunID, toolContextJSON, session.WorkloadHash, string(session.State), session.ExpiresAt, session.CreatedAt) + `, session.ID, session.TenantID, workloadID, delegationID, session.AgentID, session.RunID, session.TokenID, toolContextJSON, session.WorkloadHash, string(session.State), session.ExpiresAt, session.CreatedAt) return err } @@ -106,7 +107,7 @@ func (r *Repository) GetSession(ctx context.Context, sessionID string) (*core.Se err := r.db.QueryRow(ctx, ` SELECT - s.id, s.tenant_id, s.agent_id, s.run_id, s.tool_context_json, s.workload_hash, s.state, s.expires_at, s.created_at, + s.id, s.tenant_id, s.agent_id, s.run_id, s.token_id, s.tool_context_json, s.workload_hash, s.state, s.expires_at, s.created_at, w.identity_type, w.subject, w.issuer, w.metadata_json, d.id, d.issuer, d.subject, d.claims_json, d.expires_at FROM sessions s @@ -114,7 +115,7 @@ func (r *Repository) GetSession(ctx context.Context, sessionID string) (*core.Se LEFT JOIN delegations d ON d.id = s.delegation_id WHERE s.id = $1 `, sessionID).Scan( - &session.ID, &session.TenantID, &session.AgentID, &session.RunID, &toolContextJSON, &session.WorkloadHash, &session.State, &session.ExpiresAt, &session.CreatedAt, + &session.ID, &session.TenantID, &session.AgentID, &session.RunID, &session.TokenID, &toolContextJSON, &session.WorkloadHash, &session.State, &session.ExpiresAt, &session.CreatedAt, &workloadType, &workloadSubject, &workloadIssuer, &workloadJSON, &delegationID, &delegationIssuer, &delegationSubject, &delegationClaimsJSON, &delegationExpiresAt, ) diff --git a/internal/store/redis/runtime.go b/internal/store/redis/runtime.go index 333d3f7..da50b46 100644 --- a/internal/store/redis/runtime.go +++ b/internal/store/redis/runtime.go @@ -173,6 +173,35 @@ func (s *RuntimeStore) GetRelaySession(ctx context.Context, sessionID string) (* }, nil } +func (s *RuntimeStore) RevokeSessionToken(ctx context.Context, tokenID string, expiresAt time.Time) error { + if tokenID == "" { + return nil + } + + key := revokedSessionTokenKey(tokenID) + if err := s.client.Set(ctx, key, "1", 0).Err(); err != nil { + return err + } + if !expiresAt.IsZero() { + if err := s.client.ExpireAt(ctx, key, expiresAt).Err(); err != nil { + return err + } + } + return nil +} + +func (s *RuntimeStore) IsSessionTokenRevoked(ctx context.Context, tokenID string) (bool, error) { + if tokenID == "" { + return false, nil + } + + exists, err := s.client.Exists(ctx, revokedSessionTokenKey(tokenID)).Result() + if err != nil { + return false, err + } + return exists > 0, nil +} + func proxyKey(handle string) string { return "proxy:" + handle } @@ -181,6 +210,10 @@ func relayKey(sessionID string) string { return "relay:" + sessionID } +func revokedSessionTokenKey(tokenID string) string { + return "session_token_revoked:" + tokenID +} + func parseRedisInt(value any) int { switch v := value.(type) { case string: diff --git a/internal/store/redis/runtime_test.go b/internal/store/redis/runtime_test.go index 8ffaab1..5c04c17 100644 --- a/internal/store/redis/runtime_test.go +++ b/internal/store/redis/runtime_test.go @@ -60,4 +60,15 @@ func TestRuntimeStore_ProxyAndRelayState(t *testing.T) { if got.KeyID != "key_1" || got.Selectors["username"] != "#username" { t.Fatalf("relay = %#v, want saved relay session", got) } + + if err := store.RevokeSessionToken(context.Background(), "tok_123", time.Now().Add(time.Minute)); err != nil { + t.Fatalf("RevokeSessionToken() error = %v", err) + } + revoked, err := store.IsSessionTokenRevoked(context.Background(), "tok_123") + if err != nil { + t.Fatalf("IsSessionTokenRevoked() error = %v", err) + } + if !revoked { + t.Fatal("IsSessionTokenRevoked() = false, want true") + } } From 10a9957ecbfceb593cbc0c4636c80667347b0ccc Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Wed, 15 Apr 2026 22:20:38 +0000 Subject: [PATCH 2/3] Confirm sessionjwt skew report is false positive From 03b97ea4962d7cc6190721abf703432aad058d01 Mon Sep 17 00:00:00 2001 From: Jonathan Haas Date: Wed, 15 Apr 2026 15:49:58 -0700 Subject: [PATCH 3/3] test: fix session jwt claim parsing clock --- internal/crypto/sessionjwt/manager_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/crypto/sessionjwt/manager_test.go b/internal/crypto/sessionjwt/manager_test.go index 42d802e..37306a5 100644 --- a/internal/crypto/sessionjwt/manager_test.go +++ b/internal/crypto/sessionjwt/manager_test.go @@ -42,7 +42,7 @@ func TestManagerSignAddsStandardClaims(t *testing.T) { parsed, err := jwt.ParseWithClaims(raw, &claims{}, func(token *jwt.Token) (any, error) { return manager.publicKey, nil - }) + }, jwt.WithTimeFunc(func() time.Time { return now })) if err != nil { t.Fatalf("ParseWithClaims() error = %v", err) }