From ad86b2c3267852629ab6aa2eca1b0447cfef469c Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Sun, 12 Apr 2026 21:30:01 +0800 Subject: [PATCH 01/13] feat(oauth): add RFC 7523 private_key_jwt client authentication MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Accept signed JWT assertions at /oauth/token and /oauth/introspect as an alternative to client_secret_basic/post - Support RS256 and ES256 via inline jwks or remote jwks_uri registered per client (RFC 7591 §2.1) - Advertise the new method in OIDC discovery and Dynamic Client Registration accepts token_endpoint_auth_method=private_key_jwt - Enforce iss/sub/aud/exp/nbf/jti with configurable clock skew and lifetime caps; serialise jti replay detection under a mutex - Rate-limit JWKS refresh on kid miss to avoid unbounded refetches - Emit CLIENT_ASSERTION_VERIFIED / CLIENT_ASSERTION_FAILED audit events - Ship docs/PRIVATE_KEY_JWT.md with usage, curl + Python SDK examples Co-Authored-By: Claude Opus 4.6 (1M context) --- .env.example | 12 + docs/PRIVATE_KEY_JWT.md | 247 ++++++++++ internal/bootstrap/handlers.go | 51 +- internal/config/config.go | 14 + internal/handlers/client_auth.go | 148 ++++++ internal/handlers/oidc.go | 64 ++- internal/handlers/registration.go | 141 +++++- internal/handlers/registration_test.go | 82 +++- internal/handlers/token.go | 180 +++++--- .../handlers/token_private_key_jwt_test.go | 278 +++++++++++ internal/handlers/utils.go | 12 +- internal/models/audit_log.go | 4 + internal/models/oauth_application.go | 78 +++- internal/services/client.go | 99 +++- internal/services/client_assertion.go | 382 +++++++++++++++ internal/services/client_assertion_test.go | 435 ++++++++++++++++++ internal/services/jwks_fetcher.go | 152 ++++++ internal/services/jwks_fetcher_test.go | 183 ++++++++ internal/services/token_client_credentials.go | 30 ++ internal/util/jwk.go | 176 +++++++ internal/util/jwk_test.go | 193 ++++++++ 21 files changed, 2837 insertions(+), 124 deletions(-) create mode 100644 docs/PRIVATE_KEY_JWT.md create mode 100644 internal/handlers/client_auth.go create mode 100644 internal/handlers/token_private_key_jwt_test.go create mode 100644 internal/services/client_assertion.go create mode 100644 internal/services/client_assertion_test.go create mode 100644 internal/services/jwks_fetcher.go create mode 100644 internal/services/jwks_fetcher_test.go create mode 100644 internal/util/jwk.go create mode 100644 internal/util/jwk_test.go diff --git a/.env.example b/.env.example index b386fd60..ff6a91aa 100644 --- a/.env.example +++ b/.env.example @@ -285,3 +285,15 @@ EXPIRED_TOKEN_CLEANUP_INTERVAL=1h # How often to run the cleanup (default: # CORS_ALLOWED_METHODS=GET,POST,PUT,DELETE,OPTIONS # CORS_ALLOWED_HEADERS=Origin,Content-Type,Authorization # CORS_MAX_AGE=12h + +# ============================================================ +# Private Key JWT (RFC 7523) — Client Assertion Authentication +# ============================================================ +# When enabled, OAuth clients can authenticate to /oauth/token and +# /oauth/introspect by presenting a signed JWT assertion instead of +# client_secret. See docs/PRIVATE_KEY_JWT.md. +# PRIVATE_KEY_JWT_ENABLED=true +# JWKS_FETCH_TIMEOUT=10s # HTTP timeout when fetching a client's jwks_uri +# JWKS_CACHE_TTL=1h # Cached JWKS lifetime before refetch +# CLIENT_ASSERTION_MAX_LIFETIME=5m # Reject assertions whose exp-iat exceeds this +# CLIENT_ASSERTION_CLOCK_SKEW=30s # Tolerance for exp/nbf/iat skew diff --git a/docs/PRIVATE_KEY_JWT.md b/docs/PRIVATE_KEY_JWT.md new file mode 100644 index 00000000..a041c7a9 --- /dev/null +++ b/docs/PRIVATE_KEY_JWT.md @@ -0,0 +1,247 @@ +# Private Key JWT Client Authentication (RFC 7523) + +AuthGate supports **`private_key_jwt`** at the token endpoint — clients authenticate by signing a short-lived JWT assertion with their private key instead of presenting a long-lived `client_secret`. This is the authentication method [recommended by the MCP OAuth Client Credentials extension](https://modelcontextprotocol.io/extensions/auth/oauth-client-credentials) for machine-to-machine flows. + +- **RFC 7521** — Assertion Framework for OAuth 2.0 Client Authentication +- **RFC 7523** — JWT Profile for OAuth 2.0 Client Authentication +- **RFC 7591 §2.1** — Dynamic client registration of `jwks`/`jwks_uri` + +## Why use it + +| Aspect | `client_secret_basic`/`_post` | `private_key_jwt` | +| ------------------------- | --------------------------------------------------------------------- | -------------------------------------------------------- | +| Credential lifetime | Long-lived, stored on server | Short-lived (minutes), regenerated per call | +| Transmitted over the wire | Raw secret (even under TLS, exposed to intermediaries) | JWT signed with private key; private key stays on client | +| Server storage | bcrypt hash of secret (attacker who exfiltrates DB can offline-crack) | Public key only (no useful secret to steal) | +| Rotation | Requires re-issuing secret | Rotate JWKS; no downtime | +| Replay protection | None at protocol level | Built-in via `jti` + short `exp` | + +See [RFC 7523 §1](https://datatracker.ietf.org/doc/html/rfc7523#section-1) for more background. + +## Enabling the feature + +`private_key_jwt` is enabled by default. Set the following environment variables to configure it: + +```bash +PRIVATE_KEY_JWT_ENABLED=true # Feature flag (default: true) +JWKS_FETCH_TIMEOUT=10s # HTTP timeout when fetching jwks_uri (default: 10s) +JWKS_CACHE_TTL=1h # JWKS cache lifetime (default: 1h) +CLIENT_ASSERTION_MAX_LIFETIME=5m # Reject assertions whose exp-iat exceeds this (default: 5m) +CLIENT_ASSERTION_CLOCK_SKEW=30s # Tolerance for exp/nbf/iat skew (default: 30s) +``` + +When enabled, the OIDC discovery document lists the new method: + +```bash +curl https://authgate.example.com/.well-known/openid-configuration | jq .token_endpoint_auth_methods_supported +# ["client_secret_basic","client_secret_post","none","private_key_jwt"] + +curl https://authgate.example.com/.well-known/openid-configuration | jq .token_endpoint_auth_signing_alg_values_supported +# ["RS256","ES256"] +``` + +## Supported algorithms + +- **RS256** — 2048-bit (or larger) RSA with SHA-256 +- **ES256** — ECDSA P-256 with SHA-256 + +`HS256` and other symmetric algorithms are rejected by design (they provide no advantage over `client_secret_*`). `EdDSA` is not currently supported; file an issue if you need it. + +## Registering a client + +### Option 1 — Dynamic Client Registration (RFC 7591) + +Enable DCR (`ENABLE_DYNAMIC_CLIENT_REGISTRATION=true`), then POST a client metadata document: + +```bash +curl -X POST https://authgate.example.com/oauth/register \ + -H "Content-Type: application/json" \ + -d '{ + "client_name": "my-service", + "grant_types": ["client_credentials"], + "token_endpoint_auth_method": "private_key_jwt", + "token_endpoint_auth_signing_alg": "RS256", + "scope": "email profile", + "jwks": { + "keys": [ + { + "kty": "RSA", + "use": "sig", + "kid": "2026-04-12", + "alg": "RS256", + "n": "0vx7agoebGcQ...", + "e": "AQAB" + } + ] + } + }' +``` + +Alternative: provide `jwks_uri` instead of an inline `jwks`: + +```json +{ + "client_name": "my-service", + "grant_types": ["client_credentials"], + "token_endpoint_auth_method": "private_key_jwt", + "token_endpoint_auth_signing_alg": "ES256", + "jwks_uri": "https://my-service.example.com/.well-known/jwks.json" +} +``` + +`jwks_uri` and `jwks` are **mutually exclusive**; exactly one must be present. + +Registered clients start in `pending` status and require admin approval (standard DCR behaviour). The response does **not** include a `client_secret` — `private_key_jwt` clients have no shared secret. + +### Option 2 — Service-layer API + +Callers with direct access to the service layer can use `services.CreateClientRequest` with the new fields: + +```go +req := services.CreateClientRequest{ + ClientName: "my-service", + ClientType: core.ClientTypeConfidential, + EnableClientCredentialsFlow: true, + TokenEndpointAuthMethod: models.TokenEndpointAuthPrivateKeyJWT, + TokenEndpointAuthSigningAlg: "RS256", + JWKS: jwkSetJSON, // or JWKSURI: "https://..." + IsAdminCreated: true, // active immediately +} +resp, err := clientService.CreateClient(ctx, req) +``` + +### Admin UI + +The admin web form does not yet expose the new fields. Admins can register `private_key_jwt` clients via the DCR endpoint above or via a small helper script that calls the service layer directly. This will be added in a follow-up. + +## Requesting a token + +The client signs a JWT and presents it as `client_assertion` at `/oauth/token`: + +```http +POST /oauth/token HTTP/1.1 +Host: authgate.example.com +Content-Type: application/x-www-form-urlencoded + +grant_type=client_credentials +&scope=read write +&client_assertion_type=urn:ietf:params:oauth:client-assertion-type:jwt-bearer +&client_assertion= +``` + +### Required JWT claims + +| Claim | Value | +| ----- | ----------------------------------------------------------------------------- | +| `iss` | Your client_id | +| `sub` | Your client_id (must match `iss`) | +| `aud` | AuthGate's token endpoint URL, or the issuer URL | +| `iat` | Issued-at timestamp (now) | +| `exp` | Expiration (max `iat + CLIENT_ASSERTION_MAX_LIFETIME`; recommend 1–5 minutes) | +| `jti` | Unique per-assertion identifier (required — replay protection) | + +### JWT header + +`alg` must match `token_endpoint_auth_signing_alg` that was registered for the client. Include `kid` so AuthGate can pick the right key from your JWK Set: + +```json +{ + "alg": "RS256", + "kid": "2026-04-12", + "typ": "JWT" +} +``` + +## Client examples + +### Python (official MCP SDK) + +```python +from mcp.client.auth.extensions.client_credentials import ( + PrivateKeyJWTOAuthProvider, + SignedJWTParameters, +) +from mcp.client.streamable_http import streamablehttp_client +from mcp import ClientSession + +jwt_params = SignedJWTParameters( + issuer="my-service", # client_id + subject="my-service", # must match issuer + signing_key=open("private_key.pem").read(), + signing_algorithm="RS256", + lifetime_seconds=300, +) + +provider = PrivateKeyJWTOAuthProvider( + server_url="https://authgate.example.com/mcp", + client_id="my-service", + assertion_provider=jwt_params.create_assertion_provider(), + scopes="read write", +) +``` + +The SDK obtains the token endpoint URL from AuthGate's `/.well-known/openid-configuration` and handles assertion signing + token refresh automatically. + +### curl (debugging) + +Generate a key pair and sign a JWT manually (using `python -c` or `jose-util`), then: + +```bash +ASSERTION=$(python3 - <<'PY' +import jwt, time, uuid +priv = open('private_key.pem').read() +claims = { + 'iss': 'my-service', + 'sub': 'my-service', + 'aud': 'https://authgate.example.com/oauth/token', + 'iat': int(time.time()), + 'exp': int(time.time()) + 300, + 'jti': str(uuid.uuid4()), +} +print(jwt.encode(claims, priv, algorithm='RS256', headers={'kid': '2026-04-12'})) +PY +) + +curl -X POST https://authgate.example.com/oauth/token \ + -d grant_type=client_credentials \ + -d scope="read write" \ + -d client_assertion_type=urn:ietf:params:oauth:client-assertion-type:jwt-bearer \ + -d client_assertion="$ASSERTION" +``` + +## Key rotation + +- **`jwks_uri`**: publish both old and new keys during overlap; AuthGate re-fetches when it encounters a `kid` it doesn't know (bypassing `JWKS_CACHE_TTL`). This supports zero-downtime rotation. +- **Inline `jwks`**: update the client via DCR or admin API with the new JWK Set. Old clients already issued assertions against the old key continue to verify until reassertions start arriving with the new `kid`. + +Best practice: keep both keys published for at least twice the longest `exp` you expect (e.g. 10 minutes for 5-minute assertions) before retiring the old one. + +## Which endpoints accept `private_key_jwt` + +Currently supported: + +- `POST /oauth/token` for **`grant_type=client_credentials`** — primary use case (MCP M2M). +- `POST /oauth/introspect` — so Resource Servers can authenticate via the same JWT. + +The other grants (`authorization_code`, `refresh_token`, `device_code`) continue to use `client_secret_*` or public-client (no auth) modes. Extending them is tracked as a follow-up — the shared authenticator (`internal/handlers/client_auth.go`) is already in place, only the per-grant wiring is pending. + +## Troubleshooting + +| Symptom | Likely cause | +| ------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------- | +| `invalid_client` immediately | `iss` / `sub` / `aud` mismatch, or `exp` missing/past | +| `invalid_client` after a few seconds | Clock skew — check `CLIENT_ASSERTION_CLOCK_SKEW` | +| `invalid_client` on re-use | `jti` replay — each assertion must be unique | +| `invalid_client` with correct claims | `kid` missing or not in registered JWKS; check server logs for `no matching JWK for kid` | +| `invalid_client` after key rotation | `JWKS_CACHE_TTL` not yet expired and `kid` mismatch didn't trigger refresh — verify new key ships with a fresh `kid` | +| `unauthorized_client` instead of `invalid_client` | Client registered but `client_credentials` flow not enabled | + +Check `/admin/audit` with filter `event_type=CLIENT_ASSERTION_FAILED` for the exact reason logged server-side. + +## Security notes + +- **Private keys must never leave the client.** Use a secrets manager (AWS Secrets Manager, GCP Secret Manager, HashiCorp Vault) or a KMS-backed sign API (AWS KMS, GCP Cloud KMS). +- Keep assertion lifetime short (≤ 5 minutes) — `CLIENT_ASSERTION_MAX_LIFETIME` caps it server-side. +- Use a cryptographically-random `jti` per assertion (a UUIDv4 is fine). +- `PRIVATE_KEY_JWT_ENABLED=false` immediately rejects all assertion-based authentication and hides the method from discovery — a useful kill-switch if key material is suspected to be compromised and you need time to investigate. +- `jti` replay protection currently uses an in-memory cache per instance. For multi-instance deployments where a client might hit different replicas within the same assertion lifetime, promote the jti cache to Redis (tracked as a follow-up). diff --git a/internal/bootstrap/handlers.go b/internal/bootstrap/handlers.go index d91d17fb..5d93e200 100644 --- a/internal/bootstrap/handlers.go +++ b/internal/bootstrap/handlers.go @@ -4,12 +4,15 @@ import ( "crypto" "embed" "net/http" + "strings" "github.com/go-authgate/authgate/internal/auth" + "github.com/go-authgate/authgate/internal/cache" "github.com/go-authgate/authgate/internal/config" "github.com/go-authgate/authgate/internal/core" "github.com/go-authgate/authgate/internal/handlers" "github.com/go-authgate/authgate/internal/services" + "github.com/go-authgate/authgate/internal/util" ) // handlerSet holds all HTTP handlers and required services @@ -50,6 +53,48 @@ func initializeHandlers(deps handlerDeps) handlerSet { // Build JWKS handler from the token provider's public key info jwksHandler := buildJWKSHandler(deps.tokenProvider, deps.cfg) + // Build the optional RFC 7523 private_key_jwt client authenticator. + // Memory caches are used locally — a follow-up PR may wire redis variants + // when multi-instance deployment needs coordinated jti replay protection. + var clientAuth *handlers.ClientAuthenticator + if deps.cfg.PrivateKeyJWTEnabled { + jwksCache := cache.NewMemoryCache[util.JWKSet]() + jtiCache := cache.NewMemoryCache[bool]() + jwksFetcher := services.NewJWKSFetcher( + jwksCache, + deps.cfg.JWKSFetchTimeout, + deps.cfg.JWKSCacheTTL, + ) + tokenEndpoint := strings.TrimRight(deps.cfg.BaseURL, "/") + "/oauth/token" + issuer := strings.TrimRight(deps.cfg.BaseURL, "/") + verifier := services.NewClientAssertionVerifier( + deps.services.client, + jwksFetcher, + jtiCache, + deps.auditService, + services.ClientAssertionConfig{ + Enabled: true, + ExpectedAudiences: []string{tokenEndpoint, issuer}, + MaxLifetime: deps.cfg.ClientAssertionMaxLifetime, + ClockSkew: deps.cfg.ClientAssertionClockSkew, + }, + ) + clientAuth = handlers.NewClientAuthenticator( + deps.services.client, + verifier, + tokenEndpoint, + ) + } + + tokenHandler := handlers.NewTokenHandler( + deps.services.token, + deps.services.authorization, + deps.cfg, + ) + if clientAuth != nil { + tokenHandler = tokenHandler.WithClientAuthenticator(clientAuth) + } + return handlerSet{ auth: handlers.NewAuthHandler( deps.services.user, @@ -62,11 +107,7 @@ func initializeHandlers(deps handlerDeps) handlerSet { deps.services.authorization, deps.cfg, ), - token: handlers.NewTokenHandler( - deps.services.token, - deps.services.authorization, - deps.cfg, - ), + token: tokenHandler, client: handlers.NewClientHandler(deps.services.client, deps.services.authorization), userClient: handlers.NewUserClientHandler(deps.services.client), session: handlers.NewSessionHandler(deps.services.token), diff --git a/internal/config/config.go b/internal/config/config.go index d3a9a3ce..5eb20566 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -100,6 +100,13 @@ type Config struct { // Client Credentials Flow settings (RFC 6749 §4.4) ClientCredentialsTokenExpiration time.Duration // Access token lifetime for client_credentials grant (default: 1h, same as JWTExpiration) + // Private Key JWT settings (RFC 7523 client_assertion-based authentication) + PrivateKeyJWTEnabled bool // Enable private_key_jwt token endpoint authentication (default: true) + JWKSFetchTimeout time.Duration // HTTP timeout for fetching remote JWKS (default: 10s) + JWKSCacheTTL time.Duration // Cached JWKS lifetime before refetch (default: 1h) + ClientAssertionMaxLifetime time.Duration // Maximum allowed assertion exp-iat window (default: 5m) + ClientAssertionClockSkew time.Duration // Clock skew tolerance when validating exp/nbf/iat (default: 30s) + // OAuth settings // GitHub OAuth GitHubOAuthEnabled bool @@ -304,6 +311,13 @@ func Load() *Config { time.Hour, ), // 1 hour default; keep short — no refresh token means no rotation mechanism + // Private Key JWT (RFC 7523) settings + PrivateKeyJWTEnabled: getEnvBool("PRIVATE_KEY_JWT_ENABLED", true), + JWKSFetchTimeout: getEnvDuration("JWKS_FETCH_TIMEOUT", 10*time.Second), + JWKSCacheTTL: getEnvDuration("JWKS_CACHE_TTL", time.Hour), + ClientAssertionMaxLifetime: getEnvDuration("CLIENT_ASSERTION_MAX_LIFETIME", 5*time.Minute), + ClientAssertionClockSkew: getEnvDuration("CLIENT_ASSERTION_CLOCK_SKEW", 30*time.Second), + // OAuth settings // GitHub OAuth GitHubOAuthEnabled: getEnvBool("GITHUB_OAUTH_ENABLED", false), diff --git a/internal/handlers/client_auth.go b/internal/handlers/client_auth.go new file mode 100644 index 00000000..6c2f9f1f --- /dev/null +++ b/internal/handlers/client_auth.go @@ -0,0 +1,148 @@ +package handlers + +import ( + "context" + "errors" + "strings" + + "github.com/go-authgate/authgate/internal/models" + "github.com/go-authgate/authgate/internal/services" + + "github.com/gin-gonic/gin" +) + +// Form parameters defined in RFC 7521 §4.2 for assertion-based client authentication. +const ( + formClientAssertion = "client_assertion" + formClientAssertionType = "client_assertion_type" +) + +// Errors returned by the shared client authenticator. All map to OAuth +// invalid_client; the distinction is for logs/tests and for choosing the +// correct WWW-Authenticate header. +var ( + ErrClientAuthRequired = errors.New("client authentication required") + ErrClientAuthMismatch = errors.New("client_id mismatch between parameters and assertion") + ErrClientAuthMethodUnmet = errors.New( + "client authentication method does not match client registration", + ) + ErrClientAuthSecretBad = errors.New("invalid client secret") +) + +// AuthenticatedClient carries the outcome of a successful authentication at the +// token endpoint. +type AuthenticatedClient struct { + Client *models.OAuthApplication + Method string // client_secret_basic | client_secret_post | private_key_jwt | none +} + +// ClientAuthenticator performs RFC 6749 §2.3 + RFC 7521 §4.2 authentication at +// the token endpoint. It supports client_secret_basic, client_secret_post, +// private_key_jwt, and none (public clients). +type ClientAuthenticator struct { + clientService *services.ClientService + assertionVerifier *services.ClientAssertionVerifier + audience string +} + +// NewClientAuthenticator wires a new authenticator. assertionVerifier may be nil, +// in which case private_key_jwt is rejected. The audience is the token endpoint +// URL presented to clients (typically BaseURL + "/oauth/token"). +func NewClientAuthenticator( + cs *services.ClientService, + av *services.ClientAssertionVerifier, + audience string, +) *ClientAuthenticator { + return &ClientAuthenticator{ + clientService: cs, + assertionVerifier: av, + audience: audience, + } +} + +// Authenticate inspects the request and returns the authenticated client. +// requireConfidential=true forces the caller to present credentials (for grants +// like client_credentials that are restricted to confidential clients); +// requireConfidential=false still verifies credentials when they are supplied, +// but allows public clients to pass with only a client_id. +func (a *ClientAuthenticator) Authenticate( + c *gin.Context, + requireConfidential bool, +) (*AuthenticatedClient, error) { + assertion := c.PostForm(formClientAssertion) + assertionType := c.PostForm(formClientAssertionType) + if assertion != "" || assertionType != "" { + return a.authenticateViaAssertion(c.Request.Context(), c, assertion, assertionType) + } + + clientID, secret, cameFromHeader := parseClientCredentials(c) + if clientID == "" { + if requireConfidential { + return nil, ErrClientAuthRequired + } + return nil, ErrClientAuthRequired + } + + client, err := a.clientService.GetClientWithSecret(c.Request.Context(), clientID) + if err != nil { + return nil, ErrClientAuthRequired + } + if !client.IsActive() { + return nil, ErrClientAuthRequired + } + + // Enforce registration consistency: a private_key_jwt client must use an + // assertion (which takes the branch above), not a shared secret. + if client.UsesPrivateKeyJWT() { + return nil, ErrClientAuthMethodUnmet + } + + if client.UsesClientSecret() { + if secret == "" || !client.ValidateClientSecret([]byte(secret)) { + return nil, ErrClientAuthSecretBad + } + method := models.TokenEndpointAuthClientSecretBasic + if !cameFromHeader { + method = models.TokenEndpointAuthClientSecretPost + } + return &AuthenticatedClient{ + Client: client, + Method: method, + }, nil + } + + // Public client (method=none). Only allowed when requireConfidential=false. + if requireConfidential { + return nil, ErrClientAuthRequired + } + return &AuthenticatedClient{ + Client: client, + Method: models.TokenEndpointAuthNone, + }, nil +} + +func (a *ClientAuthenticator) authenticateViaAssertion( + ctx context.Context, + c *gin.Context, + assertion, assertionType string, +) (*AuthenticatedClient, error) { + if a.assertionVerifier == nil { + return nil, ErrClientAuthRequired + } + client, err := a.assertionVerifier.Verify(ctx, assertion, assertionType) + if err != nil { + return nil, err + } + // RFC 7521 §4.2: if client_id is also sent as a form param, it must match + // the authenticated client. + if formID := strings.TrimSpace( + c.PostForm("client_id"), + ); formID != "" && + formID != client.ClientID { + return nil, ErrClientAuthMismatch + } + return &AuthenticatedClient{ + Client: client, + Method: models.TokenEndpointAuthPrivateKeyJWT, + }, nil +} diff --git a/internal/handlers/oidc.go b/internal/handlers/oidc.go index 89f25112..2979b7f5 100644 --- a/internal/handlers/oidc.go +++ b/internal/handlers/oidc.go @@ -42,20 +42,21 @@ func NewOIDCHandler( // discoveryMetadata holds the OIDC Provider Metadata returned by the discovery endpoint. type discoveryMetadata struct { - Issuer string `json:"issuer"` - AuthorizationEndpoint string `json:"authorization_endpoint"` - TokenEndpoint string `json:"token_endpoint"` - UserinfoEndpoint string `json:"userinfo_endpoint"` - RevocationEndpoint string `json:"revocation_endpoint"` - JwksURI string `json:"jwks_uri,omitempty"` - ResponseTypesSupported []string `json:"response_types_supported"` - SubjectTypesSupported []string `json:"subject_types_supported"` - IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported,omitempty"` - ScopesSupported []string `json:"scopes_supported"` - TokenEndpointAuthMethods []string `json:"token_endpoint_auth_methods_supported"` - GrantTypesSupported []string `json:"grant_types_supported"` - ClaimsSupported []string `json:"claims_supported"` - CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"` + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + UserinfoEndpoint string `json:"userinfo_endpoint"` + RevocationEndpoint string `json:"revocation_endpoint"` + JwksURI string `json:"jwks_uri,omitempty"` + ResponseTypesSupported []string `json:"response_types_supported"` + SubjectTypesSupported []string `json:"subject_types_supported"` + IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported,omitempty"` + ScopesSupported []string `json:"scopes_supported"` + TokenEndpointAuthMethods []string `json:"token_endpoint_auth_methods_supported"` + TokenEndpointAuthSigningAlgValuesSupported []string `json:"token_endpoint_auth_signing_alg_values_supported,omitempty"` + GrantTypesSupported []string `json:"grant_types_supported"` + ClaimsSupported []string `json:"claims_supported"` + CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"` } // Discovery godoc @@ -89,11 +90,12 @@ func (h *OIDCHandler) Discovery(c *gin.Context) { SubjectTypesSupported: []string{"public"}, IDTokenSigningAlgValuesSupported: idTokenAlgs, ScopesSupported: scopes, - TokenEndpointAuthMethods: []string{ - "client_secret_basic", - "client_secret_post", - "none", - }, + TokenEndpointAuthMethods: tokenEndpointAuthMethods( + h.config.PrivateKeyJWTEnabled, + ), + TokenEndpointAuthSigningAlgValuesSupported: tokenEndpointAuthSigningAlgs( + h.config.PrivateKeyJWTEnabled, + ), GrantTypesSupported: []string{ GrantTypeAuthorizationCode, GrantTypeDeviceCode, @@ -177,6 +179,30 @@ func (h *OIDCHandler) UserInfo(c *gin.Context) { c.JSON(http.StatusOK, claims) } +// tokenEndpointAuthMethods returns the list of supported client authentication +// methods, conditionally including private_key_jwt (RFC 7523). +func tokenEndpointAuthMethods(privateKeyJWTEnabled bool) []string { + methods := []string{ + models.TokenEndpointAuthClientSecretBasic, + models.TokenEndpointAuthClientSecretPost, + models.TokenEndpointAuthNone, + } + if privateKeyJWTEnabled { + methods = append(methods, models.TokenEndpointAuthPrivateKeyJWT) + } + return methods +} + +// tokenEndpointAuthSigningAlgs returns the list of JWT signing algorithms accepted +// for client_assertion. The list is empty (omitted from discovery) when +// private_key_jwt is disabled. +func tokenEndpointAuthSigningAlgs(privateKeyJWTEnabled bool) []string { + if !privateKeyJWTEnabled { + return nil + } + return []string{models.AssertionAlgRS256, models.AssertionAlgES256} +} + // buildUserInfoClaims constructs UserInfo response claims based on the granted scopes. // sub and iss are always included. profile and email scopes gate their respective claims. func buildUserInfoClaims(userID, issuer, scopes string, user *models.User) map[string]any { diff --git a/internal/handlers/registration.go b/internal/handlers/registration.go index 56d23177..4186815e 100644 --- a/internal/handlers/registration.go +++ b/internal/handlers/registration.go @@ -2,6 +2,7 @@ package handlers import ( "crypto/subtle" + "encoding/json" "errors" "net/http" "strings" @@ -36,12 +37,15 @@ func NewRegistrationHandler( // clientRegistrationRequest represents the RFC 7591 §2 registration request body. type clientRegistrationRequest struct { - ClientName string `json:"client_name"` - RedirectURIs []string `json:"redirect_uris"` - GrantTypes []string `json:"grant_types"` - TokenEPAuth string `json:"token_endpoint_auth_method"` - Scope string `json:"scope"` - ClientURI string `json:"client_uri"` + ClientName string `json:"client_name"` + RedirectURIs []string `json:"redirect_uris"` + GrantTypes []string `json:"grant_types"` + TokenEPAuth string `json:"token_endpoint_auth_method"` + Scope string `json:"scope"` + ClientURI string `json:"client_uri"` + JWKSURI string `json:"jwks_uri,omitempty"` + JWKS json.RawMessage `json:"jwks,omitempty"` + TokenEPAuthSigningAlg string `json:"token_endpoint_auth_signing_alg,omitempty"` } // Register godoc @@ -104,10 +108,17 @@ func (h *RegistrationHandler) Register(c *gin.Context) { // 5. Determine grant types from request enableDeviceFlow := false enableAuthCodeFlow := false + enableClientCredentials := false if len(req.GrantTypes) == 0 { - // RFC 7591 §2: default grant_type is "authorization_code" - enableAuthCodeFlow = true + // RFC 7591 §2: default grant_type is "authorization_code", except + // private_key_jwt registrations typically target the client_credentials + // flow — skip the redirect_uri requirement in that case. + if req.TokenEPAuth == models.TokenEndpointAuthPrivateKeyJWT { + enableClientCredentials = true + } else { + enableAuthCodeFlow = true + } } else { for _, gt := range req.GrantTypes { switch gt { @@ -115,12 +126,14 @@ func (h *RegistrationHandler) Register(c *gin.Context) { enableAuthCodeFlow = true case GrantTypeDeviceCode, GrantTypeDeviceCodeShort: enableDeviceFlow = true + case GrantTypeClientCredentials: + enableClientCredentials = true default: respondOAuthError( c, http.StatusBadRequest, "invalid_client_metadata", - "Unsupported grant_type: "+gt+". Supported: authorization_code, device_code", + "Unsupported grant_type: "+gt+". Supported: authorization_code, device_code, client_credentials", ) return } @@ -130,25 +143,80 @@ func (h *RegistrationHandler) Register(c *gin.Context) { // 6. Determine auth method → client type (RFC 7591 §2: default is "client_secret_basic") authMethod := req.TokenEPAuth if authMethod == "" { - authMethod = "client_secret_basic" + authMethod = models.TokenEndpointAuthClientSecretBasic } var clientType core.ClientType switch authMethod { - case "none": + case models.TokenEndpointAuthNone: clientType = core.ClientTypePublic - case "client_secret_basic", "client_secret_post": + case models.TokenEndpointAuthClientSecretBasic, models.TokenEndpointAuthClientSecretPost: + clientType = core.ClientTypeConfidential + case models.TokenEndpointAuthPrivateKeyJWT: + if !h.config.PrivateKeyJWTEnabled { + respondOAuthError( + c, + http.StatusBadRequest, + "invalid_client_metadata", + "private_key_jwt is not enabled on this server", + ) + return + } clientType = core.ClientTypeConfidential default: respondOAuthError( c, http.StatusBadRequest, "invalid_client_metadata", - "Unsupported token_endpoint_auth_method: "+req.TokenEPAuth+". Supported: none, client_secret_basic, client_secret_post", + "Unsupported token_endpoint_auth_method: "+req.TokenEPAuth+". Supported: none, client_secret_basic, client_secret_post, private_key_jwt", ) return } + // 6b. For private_key_jwt, require key material (RFC 7591 §2.1). + var ( + jwksInline string + signingAlg string + ) + if authMethod == models.TokenEndpointAuthPrivateKeyJWT { + hasURI := strings.TrimSpace(req.JWKSURI) != "" + hasInline := len(req.JWKS) > 0 && !isJSONNull(req.JWKS) + if !hasURI && !hasInline { + respondOAuthError( + c, + http.StatusBadRequest, + "invalid_client_metadata", + "private_key_jwt requires either jwks_uri or jwks", + ) + return + } + if hasURI && hasInline { + respondOAuthError( + c, + http.StatusBadRequest, + "invalid_client_metadata", + "jwks_uri and jwks are mutually exclusive", + ) + return + } + if hasInline { + jwksInline = string(req.JWKS) + } + signingAlg = req.TokenEPAuthSigningAlg + if signingAlg == "" { + signingAlg = models.AssertionAlgRS256 + } + if signingAlg != models.AssertionAlgRS256 && signingAlg != models.AssertionAlgES256 { + respondOAuthError( + c, + http.StatusBadRequest, + "invalid_client_metadata", + "Unsupported token_endpoint_auth_signing_alg: "+req.TokenEPAuthSigningAlg+". Supported: RS256, ES256", + ) + return + } + } + // 7. Validate scopes (only user-safe scopes allowed) scope := strings.TrimSpace(req.Scope) if scope != "" { @@ -170,14 +238,19 @@ func (h *RegistrationHandler) Register(c *gin.Context) { // 8. Create the client via service (pending status, requires admin approval) createReq := services.CreateClientRequest{ - ClientName: req.ClientName, - Description: req.ClientURI, - Scopes: scope, - RedirectURIs: req.RedirectURIs, - ClientType: clientType, - EnableDeviceFlow: enableDeviceFlow, - EnableAuthCodeFlow: enableAuthCodeFlow, - IsAdminCreated: false, // Dynamic registration → pending approval + ClientName: req.ClientName, + Description: req.ClientURI, + Scopes: scope, + RedirectURIs: req.RedirectURIs, + ClientType: clientType, + EnableDeviceFlow: enableDeviceFlow, + EnableAuthCodeFlow: enableAuthCodeFlow, + IsAdminCreated: false, // Dynamic registration → pending approval + EnableClientCredentialsFlow: enableClientCredentials, + TokenEndpointAuthMethod: authMethod, + TokenEndpointAuthSigningAlg: signingAlg, + JWKSURI: req.JWKSURI, + JWKS: jwksInline, } resp, err := h.clientService.CreateClient(c.Request.Context(), createReq) @@ -219,17 +292,35 @@ func (h *RegistrationHandler) Register(c *gin.Context) { // 10. Build RFC 7591 §3.2.1 response grantTypes := buildResponseGrantTypes(app) - c.JSON(http.StatusCreated, gin.H{ + response := gin.H{ "client_id": app.ClientID, - "client_secret": resp.ClientSecretPlain, "client_name": app.ClientName, "redirect_uris": app.RedirectURIs, "grant_types": grantTypes, "token_endpoint_auth_method": authMethod, "scope": app.Scopes, "client_id_issued_at": app.CreatedAt.Unix(), - "client_secret_expires_at": 0, // 0 = does not expire (RFC 7591 §3.2.1) - }) + } + // Only client_secret_* clients receive a shared secret in the response. + if app.UsesClientSecret() { + response["client_secret"] = resp.ClientSecretPlain + response["client_secret_expires_at"] = 0 // RFC 7591 §3.2.1: 0 = does not expire + } + if app.UsesPrivateKeyJWT() { + if app.JWKSURI != "" { + response["jwks_uri"] = app.JWKSURI + } + response["token_endpoint_auth_signing_alg"] = app.TokenEndpointAuthSigningAlg + } + c.JSON(http.StatusCreated, response) +} + +// isJSONNull reports whether the given raw JSON represents the null literal +// (or only whitespace around it). Used to distinguish "jwks": null from a +// missing or present jwks field. +func isJSONNull(raw json.RawMessage) bool { + s := strings.TrimSpace(string(raw)) + return s == "" || s == "null" } // buildResponseGrantTypes converts the OAuthApplication's enabled flows into diff --git a/internal/handlers/registration_test.go b/internal/handlers/registration_test.go index 43f0e90b..5cf1a8ff 100644 --- a/internal/handlers/registration_test.go +++ b/internal/handlers/registration_test.go @@ -40,6 +40,7 @@ func setupRegistrationTestEnvWithOpts(t *testing.T, opts registrationTestOpts) * BaseURL: "http://localhost:8080", EnableDynamicClientRegistration: opts.enabled, DynamicClientRegistrationToken: opts.token, + PrivateKeyJWTEnabled: true, } s, err := store.New(context.Background(), "sqlite", ":memory:", &config.Config{}) @@ -284,6 +285,72 @@ func TestRegister_UnsupportedAuthMethod(t *testing.T) { w := postRegister(t, r, map[string]any{ "client_name": "My App", + "token_endpoint_auth_method": "mTLS", + }) + + assert.Equal(t, http.StatusBadRequest, w.Code) + var resp map[string]any + require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) + assert.Equal(t, "invalid_client_metadata", resp["error"]) + assert.Contains(t, resp["error_description"], "mTLS") +} + +// ─── Success: private_key_jwt registration with inline jwks ───────────────── + +func TestRegister_PrivateKeyJWT_InlineJWKS(t *testing.T) { + r := setupRegistrationTestEnv(t, true) + + jwks := map[string]any{ + "keys": []map[string]any{ + { + "kty": "RSA", + "use": "sig", + "kid": "test", + "alg": "RS256", + "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw", + "e": "AQAB", + }, + }, + } + w := postRegister(t, r, map[string]any{ + "client_name": "Machine Client", + "token_endpoint_auth_method": "private_key_jwt", + "token_endpoint_auth_signing_alg": "RS256", + "jwks": jwks, + }) + + assert.Equal(t, http.StatusCreated, w.Code) + var resp map[string]any + require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) + assert.Equal(t, "private_key_jwt", resp["token_endpoint_auth_method"]) + assert.Equal(t, "RS256", resp["token_endpoint_auth_signing_alg"]) + // private_key_jwt clients must not receive a shared secret. + _, hasSecret := resp["client_secret"] + assert.False(t, hasSecret, "private_key_jwt client should not receive a client_secret") +} + +func TestRegister_PrivateKeyJWT_JWKSURI(t *testing.T) { + r := setupRegistrationTestEnv(t, true) + + w := postRegister(t, r, map[string]any{ + "client_name": "Machine Client", + "token_endpoint_auth_method": "private_key_jwt", + "token_endpoint_auth_signing_alg": "ES256", + "jwks_uri": "https://example.com/.well-known/jwks.json", + }) + + assert.Equal(t, http.StatusCreated, w.Code) + var resp map[string]any + require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) + assert.Equal(t, "private_key_jwt", resp["token_endpoint_auth_method"]) + assert.Equal(t, "https://example.com/.well-known/jwks.json", resp["jwks_uri"]) +} + +func TestRegister_PrivateKeyJWT_MissingKeyMaterial(t *testing.T) { + r := setupRegistrationTestEnv(t, true) + + w := postRegister(t, r, map[string]any{ + "client_name": "Machine Client", "token_endpoint_auth_method": "private_key_jwt", }) @@ -291,7 +358,20 @@ func TestRegister_UnsupportedAuthMethod(t *testing.T) { var resp map[string]any require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) assert.Equal(t, "invalid_client_metadata", resp["error"]) - assert.Contains(t, resp["error_description"], "private_key_jwt") + assert.Contains(t, resp["error_description"], "jwks") +} + +func TestRegister_PrivateKeyJWT_BothKeysProvided(t *testing.T) { + r := setupRegistrationTestEnv(t, true) + + w := postRegister(t, r, map[string]any{ + "client_name": "Machine Client", + "token_endpoint_auth_method": "private_key_jwt", + "jwks_uri": "https://example.com/jwks", + "jwks": map[string]any{"keys": []any{}}, + }) + + assert.Equal(t, http.StatusBadRequest, w.Code) } // ─── Error: unsupported scope ──────────────────────────────────────────────── diff --git a/internal/handlers/token.go b/internal/handlers/token.go index 235ce37e..856f1fbd 100644 --- a/internal/handlers/token.go +++ b/internal/handlers/token.go @@ -43,6 +43,7 @@ type TokenHandler struct { tokenService *services.TokenService authorizationService *services.AuthorizationService config *config.Config + clientAuthenticator *ClientAuthenticator // optional; when nil, falls back to secret-only auth } func NewTokenHandler( @@ -57,6 +58,13 @@ func NewTokenHandler( } } +// WithClientAuthenticator attaches a shared ClientAuthenticator so the token +// endpoint can accept RFC 7523 private_key_jwt in addition to client_secret. +func (h *TokenHandler) WithClientAuthenticator(a *ClientAuthenticator) *TokenHandler { + h.clientAuthenticator = a + return h +} + // buildTokenResponse constructs a standard OAuth 2.0 token response (RFC 6749 §5.1). func buildTokenResponse(accessToken, refreshToken *models.AccessToken, idToken string) gin.H { expiresIn := max(int(time.Until(accessToken.ExpiresAt).Seconds()), 0) @@ -290,33 +298,50 @@ func (h *TokenHandler) TokenInfo(c *gin.Context) { // @Failure 401 {object} object{error=string,error_description=string} "Client authentication failed" // @Router /oauth/introspect [post] func (h *TokenHandler) Introspect(c *gin.Context) { - // 1. Authenticate the calling client (RFC 7662 §2.1) - clientID, clientSecret := parseClientCredentials(c) - if clientID == "" || clientSecret == "" { - c.Header("WWW-Authenticate", `Basic realm="authgate"`) - respondOAuthError( - c, - http.StatusUnauthorized, - errInvalidClient, - "Client authentication required", - ) - return - } - - // Verify client credentials - if err := h.tokenService.AuthenticateClient( - c.Request.Context(), - clientID, - clientSecret, - ); err != nil { - c.Header("WWW-Authenticate", `Basic realm="authgate"`) - respondOAuthError( - c, - http.StatusUnauthorized, - errInvalidClient, - "Client authentication failed", - ) - return + // 1. Authenticate the calling client (RFC 7662 §2.1). + // Prefer the shared authenticator so private_key_jwt is accepted alongside + // client_secret_* methods. + var clientID string + if h.clientAuthenticator != nil { + authed, err := h.clientAuthenticator.Authenticate(c, true) + if err != nil { + c.Header("WWW-Authenticate", `Basic realm="authgate"`) + respondOAuthError( + c, + http.StatusUnauthorized, + errInvalidClient, + "Client authentication failed", + ) + return + } + clientID = authed.Client.ClientID + } else { + id, secret, _ := parseClientCredentials(c) + if id == "" || secret == "" { + c.Header("WWW-Authenticate", `Basic realm="authgate"`) + respondOAuthError( + c, + http.StatusUnauthorized, + errInvalidClient, + "Client authentication required", + ) + return + } + if err := h.tokenService.AuthenticateClient( + c.Request.Context(), + id, + secret, + ); err != nil { + c.Header("WWW-Authenticate", `Basic realm="authgate"`) + respondOAuthError( + c, + http.StatusUnauthorized, + errInvalidClient, + "Client authentication failed", + ) + return + } + clientID = id } // 2. Get the token parameter (RFC 7662 §2.1: REQUIRED) @@ -408,12 +433,41 @@ func (h *TokenHandler) Revoke(c *gin.Context) { } // handleClientCredentialsGrant handles the client_credentials grant type (RFC 6749 §4.4). -// Client authentication is accepted via HTTP Basic Auth (preferred per RFC 6749 §2.3.1) -// or as client_id / client_secret form-body parameters. +// Client authentication is accepted via HTTP Basic Auth (preferred per RFC 6749 §2.3.1), +// client_id / client_secret form-body parameters, or RFC 7523 private_key_jwt assertions. // Only confidential clients with the client_credentials flow enabled may use this endpoint. // No refresh token is issued in the response. func (h *TokenHandler) handleClientCredentialsGrant(c *gin.Context) { - clientID, clientSecret := parseClientCredentials(c) + requestedScopes := c.PostForm("scope") // Optional + + // Prefer the shared authenticator when wired — it unifies Basic Auth, + // form-body credentials, and private_key_jwt assertions. + if h.clientAuthenticator != nil { + authed, err := h.clientAuthenticator.Authenticate(c, true) + if err != nil { + c.Header("WWW-Authenticate", `Basic realm="authgate"`) + respondOAuthError( + c, + http.StatusUnauthorized, + errInvalidClient, + "Client authentication failed", + ) + return + } + accessToken, err := h.tokenService.IssueClientCredentialsTokenForClient( + c.Request.Context(), + authed.Client, + requestedScopes, + ) + if err != nil { + h.writeClientCredentialsError(c, err) + return + } + c.JSON(http.StatusOK, buildTokenResponse(accessToken, nil, "")) + return + } + + clientID, clientSecret, _ := parseClientCredentials(c) if clientID == "" || clientSecret == "" { c.Header("WWW-Authenticate", `Basic realm="authgate"`) respondOAuthError( @@ -425,8 +479,6 @@ func (h *TokenHandler) handleClientCredentialsGrant(c *gin.Context) { return } - requestedScopes := c.PostForm("scope") // Optional - accessToken, err := h.tokenService.IssueClientCredentialsToken( c.Request.Context(), clientID, @@ -434,35 +486,7 @@ func (h *TokenHandler) handleClientCredentialsGrant(c *gin.Context) { requestedScopes, ) if err != nil { - switch { - case errors.Is(err, services.ErrInvalidClientCredentials), - errors.Is(err, services.ErrClientNotConfidential): - // RFC 6749 §5.2: use 401 + WWW-Authenticate for invalid_client - c.Header("WWW-Authenticate", `Basic realm="authgate"`) - respondOAuthError( - c, - http.StatusUnauthorized, - errInvalidClient, - "Client authentication failed", - ) - case errors.Is(err, services.ErrClientCredentialsFlowDisabled): - respondOAuthError(c, http.StatusBadRequest, errUnauthorizedClient, - "Client credentials flow is not enabled for this client") - case errors.Is(err, token.ErrInvalidScope): - respondOAuthError( - c, - http.StatusBadRequest, - errInvalidScope, - "Requested scope exceeds client permissions or contains restricted scopes (openid, offline_access are not permitted)", - ) - default: - respondOAuthError( - c, - http.StatusInternalServerError, - errServerError, - "Token issuance failed", - ) - } + h.writeClientCredentialsError(c, err) return } @@ -470,6 +494,40 @@ func (h *TokenHandler) handleClientCredentialsGrant(c *gin.Context) { c.JSON(http.StatusOK, buildTokenResponse(accessToken, nil, "")) } +// writeClientCredentialsError maps service-layer errors from the client_credentials +// flow to RFC 6749 error responses. Shared between the classic and shared-authenticator +// code paths. +func (h *TokenHandler) writeClientCredentialsError(c *gin.Context, err error) { + switch { + case errors.Is(err, services.ErrInvalidClientCredentials), + errors.Is(err, services.ErrClientNotConfidential): + c.Header("WWW-Authenticate", `Basic realm="authgate"`) + respondOAuthError( + c, + http.StatusUnauthorized, + errInvalidClient, + "Client authentication failed", + ) + case errors.Is(err, services.ErrClientCredentialsFlowDisabled): + respondOAuthError(c, http.StatusBadRequest, errUnauthorizedClient, + "Client credentials flow is not enabled for this client") + case errors.Is(err, token.ErrInvalidScope): + respondOAuthError( + c, + http.StatusBadRequest, + errInvalidScope, + "Requested scope exceeds client permissions or contains restricted scopes (openid, offline_access are not permitted)", + ) + default: + respondOAuthError( + c, + http.StatusInternalServerError, + errServerError, + "Token issuance failed", + ) + } +} + // handleAuthorizationCodeGrant handles the authorization_code grant type (RFC 6749 §4.1.3). func (h *TokenHandler) handleAuthorizationCodeGrant(c *gin.Context) { code := c.PostForm("code") diff --git a/internal/handlers/token_private_key_jwt_test.go b/internal/handlers/token_private_key_jwt_test.go new file mode 100644 index 00000000..ad188e71 --- /dev/null +++ b/internal/handlers/token_private_key_jwt_test.go @@ -0,0 +1,278 @@ +package handlers + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "math/big" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/go-authgate/authgate/internal/cache" + "github.com/go-authgate/authgate/internal/config" + "github.com/go-authgate/authgate/internal/core" + "github.com/go-authgate/authgate/internal/metrics" + "github.com/go-authgate/authgate/internal/models" + "github.com/go-authgate/authgate/internal/services" + "github.com/go-authgate/authgate/internal/store" + "github.com/go-authgate/authgate/internal/token" + "github.com/go-authgate/authgate/internal/util" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// setupPKJWTEnv builds a test router whose /oauth/token endpoint accepts RFC 7523 +// client_assertion in addition to the classic client_secret paths. +func setupPKJWTEnv(t *testing.T) (*gin.Engine, *store.Store, string) { + t.Helper() + gin.SetMode(gin.TestMode) + + baseURL := "https://authgate.test" + cfg := &config.Config{ + BaseURL: baseURL, + JWTExpiration: time.Hour, + ClientCredentialsTokenExpiration: time.Hour, + JWTSecret: "test-secret-32-chars-long!!!!!!!", + PrivateKeyJWTEnabled: true, + JWKSFetchTimeout: 2 * time.Second, + JWKSCacheTTL: time.Minute, + ClientAssertionMaxLifetime: 5 * time.Minute, + ClientAssertionClockSkew: 30 * time.Second, + } + + s, err := store.New(context.Background(), "sqlite", ":memory:", &config.Config{}) + require.NoError(t, err) + + localProvider, err := token.NewLocalTokenProvider(cfg) + require.NoError(t, err) + auditSvc := services.NewNoopAuditService() + clientSvc := services.NewClientService(s, auditSvc, nil, 0, nil, 0) + deviceSvc := services.NewDeviceService(s, cfg, auditSvc, metrics.NewNoopMetrics(), clientSvc) + tokenSvc := services.NewTokenService( + s, cfg, deviceSvc, localProvider, auditSvc, metrics.NewNoopMetrics(), + cache.NewNoopCache[models.AccessToken](), clientSvc, + ) + authzSvc := services.NewAuthorizationService(s, cfg, auditSvc, tokenSvc, clientSvc) + + jwksCache := cache.NewMemoryCache[util.JWKSet](0) + jtiCache := cache.NewMemoryCache[bool](0) + t.Cleanup(func() { + _ = jwksCache.Close() + _ = jtiCache.Close() + }) + fetcher := services.NewJWKSFetcher(jwksCache, cfg.JWKSFetchTimeout, cfg.JWKSCacheTTL) + tokenEndpoint := strings.TrimRight(baseURL, "/") + "/oauth/token" + verifier := services.NewClientAssertionVerifier( + clientSvc, fetcher, jtiCache, auditSvc, + services.ClientAssertionConfig{ + Enabled: true, + ExpectedAudiences: []string{tokenEndpoint, baseURL}, + MaxLifetime: cfg.ClientAssertionMaxLifetime, + ClockSkew: cfg.ClientAssertionClockSkew, + }, + ) + clientAuth := NewClientAuthenticator(clientSvc, verifier, tokenEndpoint) + handler := NewTokenHandler(tokenSvc, authzSvc, cfg).WithClientAuthenticator(clientAuth) + + r := gin.New() + r.POST("/oauth/token", handler.Token) + r.POST("/oauth/introspect", handler.Introspect) + + return r, s, tokenEndpoint +} + +// seedPKJWTClient inserts a confidential client whose token endpoint auth method +// is private_key_jwt with the given inline JWK Set. +func seedPKJWTClient( + t *testing.T, + s *store.Store, + jwk util.JWK, + alg string, +) *models.OAuthApplication { + t.Helper() + blob, err := json.Marshal(util.JWKSet{Keys: []util.JWK{jwk}}) + require.NoError(t, err) + client := &models.OAuthApplication{ + ClientID: uuid.New().String(), + ClientName: "MCP M2M client", + UserID: uuid.New().String(), + Scopes: "read write", + GrantTypes: "client_credentials", + ClientType: core.ClientTypeConfidential.String(), + EnableClientCredentialsFlow: true, + Status: models.ClientStatusActive, + TokenEndpointAuthMethod: models.TokenEndpointAuthPrivateKeyJWT, + TokenEndpointAuthSigningAlg: alg, + JWKS: string(blob), + } + require.NoError(t, s.CreateClient(client)) + return client +} + +func rsaPKJWTFixture(t *testing.T, kid string) (*rsa.PrivateKey, util.JWK) { + t.Helper() + priv, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + return priv, util.JWK{ + Kty: "RSA", + Use: "sig", + Kid: kid, + Alg: "RS256", + N: base64.RawURLEncoding.EncodeToString(priv.PublicKey.N.Bytes()), + E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(priv.PublicKey.E)).Bytes()), + } +} + +func signAssertion( + t *testing.T, + priv *rsa.PrivateKey, + kid, clientID, audience string, +) string { + t.Helper() + now := time.Now() + claims := jwt.MapClaims{ + "iss": clientID, + "sub": clientID, + "aud": audience, + "iat": now.Unix(), + "exp": now.Add(time.Minute).Unix(), + "jti": uuid.NewString(), + } + tok := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tok.Header["kid"] = kid + out, err := tok.SignedString(priv) + require.NoError(t, err) + return out +} + +// ─── Tests ─────────────────────────────────────────────────────────────────── + +func TestPrivateKeyJWT_ClientCredentials_Success(t *testing.T) { + r, s, aud := setupPKJWTEnv(t) + priv, jwk := rsaPKJWTFixture(t, "k1") + client := seedPKJWTClient(t, s, jwk, "RS256") + + form := url.Values{ + "grant_type": {"client_credentials"}, + "client_assertion_type": {services.AssertionType}, + "client_assertion": {signAssertion(t, priv, "k1", client.ClientID, aud)}, + "scope": {"read"}, + } + w := postToken(t, r, form, nil) + + require.Equal(t, http.StatusOK, w.Code, "body=%s", w.Body.String()) + var resp map[string]any + require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) + assert.NotEmpty(t, resp["access_token"]) + assert.Equal(t, "Bearer", resp["token_type"]) + assert.Equal(t, "read", resp["scope"]) + // RFC 6749 §4.4.3 — no refresh_token + _, hasRefresh := resp["refresh_token"] + assert.False(t, hasRefresh) +} + +func TestPrivateKeyJWT_ClientCredentials_InvalidAud(t *testing.T) { + r, s, _ := setupPKJWTEnv(t) + priv, jwk := rsaPKJWTFixture(t, "k1") + client := seedPKJWTClient(t, s, jwk, "RS256") + + form := url.Values{ + "grant_type": {"client_credentials"}, + "client_assertion_type": {services.AssertionType}, + "client_assertion": { + signAssertion(t, priv, "k1", client.ClientID, "https://attacker.example"), + }, + } + w := postToken(t, r, form, nil) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var resp map[string]any + require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) + assert.Equal(t, "invalid_client", resp["error"]) +} + +func TestPrivateKeyJWT_ClientCredentials_ReplayRejected(t *testing.T) { + r, s, aud := setupPKJWTEnv(t) + priv, jwk := rsaPKJWTFixture(t, "k1") + client := seedPKJWTClient(t, s, jwk, "RS256") + + assertion := signAssertion(t, priv, "k1", client.ClientID, aud) + form := url.Values{ + "grant_type": {"client_credentials"}, + "client_assertion_type": {services.AssertionType}, + "client_assertion": {assertion}, + } + // First use succeeds + w := postToken(t, r, form, nil) + require.Equal(t, http.StatusOK, w.Code, "first call body=%s", w.Body.String()) + // Second use of the exact same assertion must be rejected as replay. + w2 := postToken(t, r, form, nil) + require.Equal(t, http.StatusUnauthorized, w2.Code) +} + +func TestPrivateKeyJWT_Introspect_Success(t *testing.T) { + r, s, aud := setupPKJWTEnv(t) + priv, jwk := rsaPKJWTFixture(t, "k1") + client := seedPKJWTClient(t, s, jwk, "RS256") + + // First obtain an access token via client_credentials. + ccForm := url.Values{ + "grant_type": {"client_credentials"}, + "client_assertion_type": {services.AssertionType}, + "client_assertion": {signAssertion(t, priv, "k1", client.ClientID, aud)}, + } + w := postToken(t, r, ccForm, nil) + require.Equal(t, http.StatusOK, w.Code) + var tokenResp map[string]any + require.NoError(t, json.NewDecoder(w.Body).Decode(&tokenResp)) + accessToken := tokenResp["access_token"].(string) + require.NotEmpty(t, accessToken) + + // Now introspect it, authenticating via a fresh assertion. + introForm := url.Values{ + "token": {accessToken}, + "client_assertion_type": {services.AssertionType}, + "client_assertion": {signAssertion(t, priv, "k1", client.ClientID, aud)}, + } + req, err := http.NewRequest( + http.MethodPost, + "/oauth/introspect", + strings.NewReader(introForm.Encode()), + ) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code, "body=%s", rec.Body.String()) + var introResp map[string]any + require.NoError(t, json.NewDecoder(rec.Body).Decode(&introResp)) + assert.Equal(t, true, introResp["active"]) + assert.Equal(t, client.ClientID, introResp["client_id"]) +} + +func TestPrivateKeyJWT_DisabledMethodRejectsSecretAttempt(t *testing.T) { + r, s, _ := setupPKJWTEnv(t) + _, jwk := rsaPKJWTFixture(t, "k1") + client := seedPKJWTClient(t, s, jwk, "RS256") + + // Attempt to authenticate a private_key_jwt client with a shared secret. + // The client has no secret hash stored, so this must be rejected. + form := url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {client.ClientID}, + "client_secret": {"anything"}, + } + w := postToken(t, r, form, nil) + require.Equal(t, http.StatusUnauthorized, w.Code) +} diff --git a/internal/handlers/utils.go b/internal/handlers/utils.go index 20e028a4..6eddef84 100644 --- a/internal/handlers/utils.go +++ b/internal/handlers/utils.go @@ -24,13 +24,13 @@ func parsePaginationParams(c *gin.Context) store.PaginationParams { // parseClientCredentials extracts client_id and client_secret from the request // using HTTP Basic Auth (preferred per RFC 6749 §2.3.1) or form-body parameters. -func parseClientCredentials(c *gin.Context) (clientID, clientSecret string) { - clientID, clientSecret, ok := c.Request.BasicAuth() - if !ok { - clientID = c.PostForm("client_id") - clientSecret = c.PostForm("client_secret") +// fromHeader is true when credentials were taken from the Authorization header, +// letting callers distinguish client_secret_basic from client_secret_post. +func parseClientCredentials(c *gin.Context) (clientID, clientSecret string, fromHeader bool) { + if id, pw, ok := c.Request.BasicAuth(); ok { + return id, pw, true } - return clientID, clientSecret + return c.PostForm("client_id"), c.PostForm("client_secret"), false } // respondOAuthError writes an RFC-compliant OAuth error JSON response. diff --git a/internal/models/audit_log.go b/internal/models/audit_log.go index 551abd84..7ef7f443 100644 --- a/internal/models/audit_log.go +++ b/internal/models/audit_log.go @@ -73,6 +73,10 @@ const ( // Token Introspection events (RFC 7662) EventTokenIntrospected EventType = "TOKEN_INTROSPECTED" + // Client authentication events (RFC 7523 private_key_jwt) + EventClientAssertionVerified EventType = "CLIENT_ASSERTION_VERIFIED" + EventClientAssertionFailed EventType = "CLIENT_ASSERTION_FAILED" + // Audit events EventTypeAuditLogView EventType = "AUDIT_LOG_VIEWED" EventTypeAuditLogExported EventType = "AUDIT_LOG_EXPORTED" diff --git a/internal/models/oauth_application.go b/internal/models/oauth_application.go index 8fe215ae..8bdbafd5 100644 --- a/internal/models/oauth_application.go +++ b/internal/models/oauth_application.go @@ -23,6 +23,20 @@ const ( ClientStatusInactive ClientStatus = "inactive" // Admin rejected or disabled ) +// Token endpoint authentication methods (RFC 7591 §2). +const ( + TokenEndpointAuthNone = "none" // Public client, no authentication + TokenEndpointAuthClientSecretBasic = "client_secret_basic" // HTTP Basic (default) + TokenEndpointAuthClientSecretPost = "client_secret_post" // client_secret in form body + TokenEndpointAuthPrivateKeyJWT = "private_key_jwt" // RFC 7523 JWT Bearer Assertion +) + +// Client assertion signing algorithms supported for private_key_jwt. +const ( + AssertionAlgRS256 = "RS256" + AssertionAlgES256 = "ES256" +) + // Base32 characters, but lowercased. const lowerBase32Chars = "abcdefghijklmnopqrstuvwxyz234567" @@ -32,7 +46,7 @@ var base32Lower = base32.NewEncoding(lowerBase32Chars).WithPadding(base32.NoPadd type OAuthApplication struct { ID int64 `gorm:"primaryKey;autoIncrement"` ClientID string `gorm:"uniqueIndex;not null"` - ClientSecret string `gorm:"not null"` // bcrypt hashed secret + ClientSecret string `gorm:"not null;default:''"` // bcrypt hashed secret; empty for public / private_key_jwt clients ClientName string `gorm:"not null"` Description string `gorm:"type:text"` UserID string `gorm:"not null"` @@ -42,8 +56,12 @@ type OAuthApplication struct { ClientType string `gorm:"not null;default:'public'"` // "confidential" or "public" EnableDeviceFlow bool `gorm:"not null;default:true"` EnableAuthCodeFlow bool `gorm:"not null;default:false"` - EnableClientCredentialsFlow bool `gorm:"not null;default:false"` // Client Credentials Grant (RFC 6749 §4.4); confidential clients only - Status string `gorm:"not null;default:'active'"` // ClientStatusPending / ClientStatusActive / ClientStatusInactive + EnableClientCredentialsFlow bool `gorm:"not null;default:false"` // Client Credentials Grant (RFC 6749 §4.4); confidential clients only + Status string `gorm:"not null;default:'active'"` // ClientStatusPending / ClientStatusActive / ClientStatusInactive + TokenEndpointAuthMethod string `gorm:"not null;default:'client_secret_basic'"` // RFC 7591 §2 + TokenEndpointAuthSigningAlg string `gorm:"type:varchar(10);not null;default:''"` // RS256 | ES256 (required for private_key_jwt) + JWKSURI string `gorm:"type:varchar(500);not null;default:''"` // Remote JWKS endpoint URL (mutually exclusive with JWKS) + JWKS string `gorm:"type:text;not null;default:''"` // Inline JWK Set JSON (mutually exclusive with JWKSURI) CreatedBy string CreatedAt time.Time UpdatedAt time.Time @@ -110,3 +128,57 @@ func (OAuthApplication) TableName() string { func (app *OAuthApplication) IsActive() bool { return app.Status == ClientStatusActive } + +// UsesPrivateKeyJWT reports whether the client authenticates using JWT Bearer +// Assertions (RFC 7523) at the token endpoint. +func (app *OAuthApplication) UsesPrivateKeyJWT() bool { + return app.TokenEndpointAuthMethod == TokenEndpointAuthPrivateKeyJWT +} + +// UsesClientSecret reports whether the client authenticates using a shared +// secret (either HTTP Basic or form-body). Public clients and private_key_jwt +// clients return false. +func (app *OAuthApplication) UsesClientSecret() bool { + return app.TokenEndpointAuthMethod == TokenEndpointAuthClientSecretBasic || + app.TokenEndpointAuthMethod == TokenEndpointAuthClientSecretPost +} + +// ErrInvalidKeyMaterial indicates a client's private_key_jwt configuration is invalid. +var ErrInvalidKeyMaterial = errors.New("invalid key material for private_key_jwt") + +// ValidateKeyMaterial verifies that a private_key_jwt client has exactly one of +// JWKSURI or JWKS set, and that the signing algorithm is supported. For other +// auth methods, it verifies no key material is present. +func (app *OAuthApplication) ValidateKeyMaterial() error { + if !app.UsesPrivateKeyJWT() { + if app.JWKSURI != "" || app.JWKS != "" || app.TokenEndpointAuthSigningAlg != "" { + return errors.New( + "JWKS, JWKS URI, and signing algorithm must be empty when token_endpoint_auth_method is not private_key_jwt", + ) + } + return nil + } + + // private_key_jwt validation + hasURI := strings.TrimSpace(app.JWKSURI) != "" + hasInline := strings.TrimSpace(app.JWKS) != "" + if !hasURI && !hasInline { + return errors.New("private_key_jwt requires either jwks_uri or jwks to be provided") + } + if hasURI && hasInline { + return errors.New("private_key_jwt requires jwks_uri and jwks to be mutually exclusive") + } + + switch app.TokenEndpointAuthSigningAlg { + case AssertionAlgRS256, AssertionAlgES256: + return nil + case "": + return errors.New( + "private_key_jwt requires token_endpoint_auth_signing_alg to be set (RS256 or ES256)", + ) + default: + return errors.New( + "unsupported token_endpoint_auth_signing_alg: only RS256 and ES256 are supported", + ) + } +} diff --git a/internal/services/client.go b/internal/services/client.go index 38a16d58..2d500c76 100644 --- a/internal/services/client.go +++ b/internal/services/client.go @@ -63,8 +63,39 @@ var ( ErrInvalidClientStatus = errors.New( "status must be \"active\", \"inactive\", or \"pending\"", ) + ErrPrivateKeyJWTRequiresConfidential = errors.New( + "private_key_jwt requires a confidential client", + ) + ErrInvalidTokenEndpointAuthMethod = errors.New( + "invalid token_endpoint_auth_method", + ) ) +// validTokenEndpointAuthMethod reports whether m is one of the recognised +// RFC 7591 §2 values AuthGate supports. +func validTokenEndpointAuthMethod(m string) bool { + switch m { + case models.TokenEndpointAuthNone, + models.TokenEndpointAuthClientSecretBasic, + models.TokenEndpointAuthClientSecretPost, + models.TokenEndpointAuthPrivateKeyJWT: + return true + } + return false +} + +// resolveTokenEndpointAuthMethod picks the default auth method for a given +// client type when the caller did not specify one. +func resolveTokenEndpointAuthMethod(method string, clientType core.ClientType) string { + if method != "" { + return method + } + if clientType == core.ClientTypePublic { + return models.TokenEndpointAuthNone + } + return models.TokenEndpointAuthClientSecretBasic +} + // validateRedirectURIs checks that every URI in the slice is an absolute http/https // URI without a fragment, as required by RFC 6749. func validateRedirectURIs(uris []string) error { @@ -143,6 +174,14 @@ type CreateClientRequest struct { EnableAuthCodeFlow bool // Enable Authorization Code Flow (RFC 6749) EnableClientCredentialsFlow bool // Enable Client Credentials Grant (RFC 6749 §4.4); confidential clients only IsAdminCreated bool // When true: Status=active; when false: Status=pending + + // Token endpoint authentication (RFC 7591 §2). When empty, a default is + // selected based on ClientType. Setting this to "private_key_jwt" (RFC 7523) + // requires JWKSURI or JWKS plus TokenEndpointAuthSigningAlg. + TokenEndpointAuthMethod string + TokenEndpointAuthSigningAlg string // RS256 | ES256 + JWKSURI string // Mutually exclusive with JWKS + JWKS string // Inline JWK Set JSON } type UpdateClientRequest struct { @@ -155,6 +194,12 @@ type UpdateClientRequest struct { EnableDeviceFlow bool EnableAuthCodeFlow bool EnableClientCredentialsFlow bool // Enable Client Credentials Grant (RFC 6749 §4.4); confidential clients only + + // Token endpoint authentication (see CreateClientRequest). + TokenEndpointAuthMethod string + TokenEndpointAuthSigningAlg string + JWKSURI string + JWKS string } type ClientResponse struct { @@ -190,6 +235,16 @@ func (s *ClientService) CreateClient( return nil, err } + // Token endpoint authentication method (RFC 7591 §2). + authMethod := resolveTokenEndpointAuthMethod(req.TokenEndpointAuthMethod, clientType) + if !validTokenEndpointAuthMethod(authMethod) { + return nil, ErrInvalidTokenEndpointAuthMethod + } + if authMethod == models.TokenEndpointAuthPrivateKeyJWT && + clientType != core.ClientTypeConfidential { + return nil, ErrPrivateKeyJWTRequiresConfidential + } + // Generate client ID clientID := uuid.New().String() @@ -232,12 +287,25 @@ func (s *ClientService) CreateClient( EnableClientCredentialsFlow: enableClientCredentials, Status: clientStatus, CreatedBy: req.CreatedBy, + TokenEndpointAuthMethod: authMethod, + TokenEndpointAuthSigningAlg: req.TokenEndpointAuthSigningAlg, + JWKSURI: strings.TrimSpace(req.JWKSURI), + JWKS: strings.TrimSpace(req.JWKS), } - // Generate client secret - clientSecret, err := client.GenerateClientSecret(ctx) - if err != nil { - return nil, err + if err := client.ValidateKeyMaterial(); err != nil { + return nil, fmt.Errorf("%w: %s", ErrInvalidClientData, err.Error()) + } + + // Generate a shared secret only for the two client_secret_* auth methods. + // Public (none) and private_key_jwt clients do not have a secret. + var clientSecret string + if client.UsesClientSecret() { + var err error + clientSecret, err = client.GenerateClientSecret(ctx) + if err != nil { + return nil, err + } } if err := s.store.CreateClient(client); err != nil { @@ -333,6 +401,29 @@ func (s *ClientService) UpdateClient( enableClientCredentials, ) + // Token endpoint authentication fields (may be zero to preserve existing). + if req.TokenEndpointAuthMethod != "" { + if !validTokenEndpointAuthMethod(req.TokenEndpointAuthMethod) { + return ErrInvalidTokenEndpointAuthMethod + } + if req.TokenEndpointAuthMethod == models.TokenEndpointAuthPrivateKeyJWT && + clientType != core.ClientTypeConfidential { + return ErrPrivateKeyJWTRequiresConfidential + } + client.TokenEndpointAuthMethod = req.TokenEndpointAuthMethod + } + client.TokenEndpointAuthSigningAlg = req.TokenEndpointAuthSigningAlg + client.JWKSURI = strings.TrimSpace(req.JWKSURI) + client.JWKS = strings.TrimSpace(req.JWKS) + // Clear the shared secret when switching away from client_secret_* methods, + // so a stale hash cannot authenticate a reconfigured client. + if !client.UsesClientSecret() { + client.ClientSecret = "" + } + if err := client.ValidateKeyMaterial(); err != nil { + return fmt.Errorf("%w: %s", ErrInvalidClientData, err.Error()) + } + err = s.store.UpdateClient(client) if err != nil { return err diff --git a/internal/services/client_assertion.go b/internal/services/client_assertion.go new file mode 100644 index 00000000..a4355d85 --- /dev/null +++ b/internal/services/client_assertion.go @@ -0,0 +1,382 @@ +package services + +import ( + "context" + "errors" + "fmt" + "log" + "slices" + "strings" + "sync" + "time" + + "github.com/go-authgate/authgate/internal/core" + "github.com/go-authgate/authgate/internal/models" + "github.com/go-authgate/authgate/internal/util" + + "github.com/golang-jwt/jwt/v5" +) + +// AssertionType is the sole value allowed for the client_assertion_type parameter +// when using JWT Bearer Assertions (RFC 7523 §2.2). +const AssertionType = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" + +// jtiCacheKeyPrefix namespaces the per-jti replay protection entries so they +// can coexist with other keys in a shared cache backend. +const jtiCacheKeyPrefix = "pkjwt:jti:" + +// Errors returned by ClientAssertionVerifier. All are presented to the caller as +// OAuth invalid_client — the distinction is useful for audit logs and tests. +var ( + ErrAssertionFeatureDisabled = errors.New("private_key_jwt is disabled") + ErrAssertionTypeInvalid = errors.New("invalid client_assertion_type") + ErrAssertionMalformed = errors.New("malformed client_assertion") + ErrAssertionIssuerMismatch = errors.New("client_assertion iss/sub mismatch") + ErrAssertionClientUnknown = errors.New( + "client_assertion issuer is not a registered client", + ) + ErrAssertionClientInactive = errors.New("client is not active") + ErrAssertionMethodNotAllowed = errors.New("client is not configured for private_key_jwt") + ErrAssertionKeyLookup = errors.New("unable to resolve client signing key") + ErrAssertionSignatureInvalid = errors.New("client_assertion signature is invalid") + ErrAssertionAlgorithmMismatch = errors.New( + "client_assertion algorithm does not match client registration", + ) + ErrAssertionAudienceInvalid = errors.New("client_assertion audience is invalid") + ErrAssertionExpired = errors.New("client_assertion is expired") + ErrAssertionNotYetValid = errors.New("client_assertion is not yet valid") + ErrAssertionLifetimeTooLong = errors.New("client_assertion lifetime exceeds server maximum") + ErrAssertionMissingJTI = errors.New("client_assertion is missing jti") + ErrAssertionJTIReplay = errors.New("client_assertion jti was already used") + ErrAssertionMissingRequiredTime = errors.New("client_assertion is missing required time claims") +) + +// ClientAssertionConfig controls the verifier's behaviour. All durations are +// positive; the caller is responsible for providing sensible defaults. +type ClientAssertionConfig struct { + Enabled bool + ExpectedAudiences []string // at least one must be present in the aud claim + MaxLifetime time.Duration + ClockSkew time.Duration +} + +// ClientAssertionVerifier validates JWT Bearer Assertions presented as +// client_assertion at the token endpoint (RFC 7523). +type ClientAssertionVerifier struct { + clientService *ClientService + jwksFetcher *JWKSFetcher + jtiCache core.Cache[bool] + auditService core.AuditLogger + cfg ClientAssertionConfig + + // jtiMu serialises the jti Get+Set so two concurrent requests with the + // same jti cannot both observe a cache miss and pass replay detection. + // Honest traffic has unique jtis and hits no contention; replay attempts + // and malformed duplicates block each other, which is the desired effect. + jtiMu sync.Mutex +} + +// NewClientAssertionVerifier wires the verifier. auditService may be nil (no-op). +// jtiCache must be supplied — it is required for RFC 7523 §3 replay prevention. +func NewClientAssertionVerifier( + clientService *ClientService, + jwksFetcher *JWKSFetcher, + jtiCache core.Cache[bool], + auditService core.AuditLogger, + cfg ClientAssertionConfig, +) *ClientAssertionVerifier { + if auditService == nil { + auditService = NewNoopAuditService() + } + if cfg.MaxLifetime <= 0 { + cfg.MaxLifetime = 5 * time.Minute + } + if cfg.ClockSkew < 0 { + cfg.ClockSkew = 30 * time.Second + } + return &ClientAssertionVerifier{ + clientService: clientService, + jwksFetcher: jwksFetcher, + jtiCache: jtiCache, + auditService: auditService, + cfg: cfg, + } +} + +// Verify validates the provided JWT assertion and returns the authenticated +// OAuth client. All error returns are safe to surface as OAuth invalid_client. +func (v *ClientAssertionVerifier) Verify( + ctx context.Context, + assertion, assertionType string, +) (*models.OAuthApplication, error) { + if !v.cfg.Enabled { + return nil, ErrAssertionFeatureDisabled + } + if assertionType != AssertionType { + return nil, ErrAssertionTypeInvalid + } + if strings.TrimSpace(assertion) == "" { + return nil, ErrAssertionMalformed + } + + parser := jwt.NewParser(jwt.WithoutClaimsValidation()) + tok, _, err := parser.ParseUnverified(assertion, jwt.MapClaims{}) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrAssertionMalformed, err) + } + claims, ok := tok.Claims.(jwt.MapClaims) + if !ok { + return nil, ErrAssertionMalformed + } + + iss, _ := claims["iss"].(string) + sub, _ := claims["sub"].(string) + if iss == "" || sub == "" || iss != sub { + v.logFailure(ctx, iss, ErrAssertionIssuerMismatch.Error()) + return nil, ErrAssertionIssuerMismatch + } + + client, err := v.clientService.GetClient(ctx, iss) + if err != nil { + v.logFailure(ctx, iss, "client lookup failed") + if errors.Is(err, ErrClientNotFound) { + return nil, ErrAssertionClientUnknown + } + return nil, ErrAssertionClientUnknown + } + if !client.IsActive() { + v.logFailure(ctx, iss, "client inactive") + return nil, ErrAssertionClientInactive + } + if !client.UsesPrivateKeyJWT() { + v.logFailure(ctx, iss, "client not configured for private_key_jwt") + return nil, ErrAssertionMethodNotAllowed + } + + // Algorithm must match the one registered with the client. + if tok.Method.Alg() != client.TokenEndpointAuthSigningAlg { + v.logFailure(ctx, iss, fmt.Sprintf( + "algorithm mismatch: header=%s registered=%s", + tok.Method.Alg(), client.TokenEndpointAuthSigningAlg, + )) + return nil, ErrAssertionAlgorithmMismatch + } + + kid, _ := tok.Header["kid"].(string) + jwkSet, err := v.resolveJWKS(ctx, client, kid) + if err != nil { + v.logFailure(ctx, iss, err.Error()) + return nil, fmt.Errorf("%w: %v", ErrAssertionKeyLookup, err) + } + jwk := jwkSet.FindByKid(kid) + if jwk == nil { + v.logFailure(ctx, iss, fmt.Sprintf("no matching JWK for kid=%q", kid)) + return nil, fmt.Errorf("%w: no matching kid", ErrAssertionKeyLookup) + } + pubKey, err := jwk.ToPublicKey() + if err != nil { + v.logFailure(ctx, iss, fmt.Sprintf("public key decode failed: %v", err)) + return nil, fmt.Errorf("%w: %v", ErrAssertionKeyLookup, err) + } + + // Verify signature with strict algorithm enforcement. Library-side claim + // validation is disabled so our custom skew and lifetime caps apply below. + verifyParser := jwt.NewParser( + jwt.WithValidMethods([]string{client.TokenEndpointAuthSigningAlg}), + jwt.WithoutClaimsValidation(), + ) + if _, err := verifyParser.Parse(assertion, func(_ *jwt.Token) (any, error) { + return pubKey, nil + }); err != nil { + v.logFailure(ctx, iss, fmt.Sprintf("signature verification failed: %v", err)) + return nil, ErrAssertionSignatureInvalid + } + + if err := v.validateTimeClaims(claims); err != nil { + v.logFailure(ctx, iss, err.Error()) + return nil, err + } + if err := v.validateAudience(claims); err != nil { + v.logFailure(ctx, iss, err.Error()) + return nil, err + } + if err := v.checkJTIReplay(ctx, iss, claims); err != nil { + v.logFailure(ctx, iss, err.Error()) + return nil, err + } + + v.logSuccess(ctx, client) + return client, nil +} + +func (v *ClientAssertionVerifier) resolveJWKS( + ctx context.Context, + client *models.OAuthApplication, + kid string, +) (*util.JWKSet, error) { + if client.JWKS != "" { + set, err := util.ParseJWKSet(client.JWKS) + if err != nil { + return nil, fmt.Errorf("parse inline JWKS: %w", err) + } + return set, nil + } + if client.JWKSURI == "" { + return nil, errors.New("client has no JWKS configured") + } + if v.jwksFetcher == nil { + return nil, errors.New("JWKS fetcher not configured") + } + return v.jwksFetcher.GetWithRefresh(ctx, client.JWKSURI, kid) +} + +func (v *ClientAssertionVerifier) validateTimeClaims(claims jwt.MapClaims) error { + now := time.Now() + skew := v.cfg.ClockSkew + + expF, ok := claims["exp"].(float64) + if !ok { + return ErrAssertionMissingRequiredTime + } + iatF, ok := claims["iat"].(float64) + if !ok { + return ErrAssertionMissingRequiredTime + } + exp := time.Unix(int64(expF), 0) + iat := time.Unix(int64(iatF), 0) + + if now.After(exp.Add(skew)) { + return ErrAssertionExpired + } + if iat.Sub(now) > skew { + return ErrAssertionNotYetValid + } + if nbfF, ok := claims["nbf"].(float64); ok { + nbf := time.Unix(int64(nbfF), 0) + if nbf.Sub(now) > skew { + return ErrAssertionNotYetValid + } + } + if exp.Sub(iat) > v.cfg.MaxLifetime { + return ErrAssertionLifetimeTooLong + } + return nil +} + +func (v *ClientAssertionVerifier) validateAudience(claims jwt.MapClaims) error { + raw, ok := claims["aud"] + if !ok { + return fmt.Errorf("%w: missing aud", ErrAssertionAudienceInvalid) + } + audValues := extractAudienceValues(raw) + if len(audValues) == 0 { + return fmt.Errorf("%w: empty aud", ErrAssertionAudienceInvalid) + } + for _, expected := range v.cfg.ExpectedAudiences { + if expected == "" { + continue + } + if slices.Contains(audValues, expected) { + return nil + } + } + return fmt.Errorf("%w: aud %v not accepted", ErrAssertionAudienceInvalid, audValues) +} + +func extractAudienceValues(raw any) []string { + switch v := raw.(type) { + case string: + return []string{v} + case []any: + out := make([]string, 0, len(v)) + for _, item := range v { + if s, ok := item.(string); ok && s != "" { + out = append(out, s) + } + } + return out + case []string: + return v + default: + return nil + } +} + +func (v *ClientAssertionVerifier) checkJTIReplay( + ctx context.Context, + clientID string, + claims jwt.MapClaims, +) error { + jti, _ := claims["jti"].(string) + if strings.TrimSpace(jti) == "" { + return ErrAssertionMissingJTI + } + if v.jtiCache == nil { + // Without a cache we cannot prevent replay. Treat this as a + // hard-configured failure rather than silently allowing replays. + return errors.New("jti replay cache not configured") + } + key := jtiCacheKeyPrefix + clientID + ":" + jti + + // Serialise the Get+Set pair: without the lock, two concurrent requests + // carrying the same jti can both observe a cache miss before either Set + // lands, accepting a replay. + v.jtiMu.Lock() + defer v.jtiMu.Unlock() + + if _, err := v.jtiCache.Get(ctx, key); err == nil { + return ErrAssertionJTIReplay + } + // TTL = remaining assertion lifetime + clock skew. If exp is absent, + // fall back to MaxLifetime (defensive). + ttl := v.cfg.MaxLifetime + v.cfg.ClockSkew + if expF, ok := claims["exp"].(float64); ok { + remaining := time.Until(time.Unix(int64(expF), 0)) + v.cfg.ClockSkew + if remaining > 0 { + ttl = remaining + } + } + // Log but do not block on cache write failure — availability over perfect + // replay protection (the cache is best-effort in a multi-instance setup). + if err := v.jtiCache.Set(ctx, key, true, ttl); err != nil { + log.Printf("[ClientAssertion] failed to record jti %s: %v", jti, err) + } + return nil +} + +func (v *ClientAssertionVerifier) logSuccess( + ctx context.Context, + client *models.OAuthApplication, +) { + v.auditService.Log(ctx, core.AuditLogEntry{ + EventType: models.EventClientAssertionVerified, + Severity: models.SeverityInfo, + ActorUserID: "client:" + client.ClientID, + ResourceType: models.ResourceClient, + ResourceID: client.ClientID, + ResourceName: client.ClientName, + Action: "client_assertion verified", + Details: models.AuditDetails{ + "signing_alg": client.TokenEndpointAuthSigningAlg, + }, + Success: true, + }) +} + +func (v *ClientAssertionVerifier) logFailure( + ctx context.Context, + issuer, reason string, +) { + v.auditService.Log(ctx, core.AuditLogEntry{ + EventType: models.EventClientAssertionFailed, + Severity: models.SeverityWarning, + ActorUserID: "client:" + issuer, + ResourceType: models.ResourceClient, + ResourceID: issuer, + Action: "client_assertion rejected", + Details: models.AuditDetails{ + "reason": reason, + }, + Success: false, + }) +} diff --git a/internal/services/client_assertion_test.go b/internal/services/client_assertion_test.go new file mode 100644 index 00000000..7ba2b09f --- /dev/null +++ b/internal/services/client_assertion_test.go @@ -0,0 +1,435 @@ +package services + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "math/big" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-authgate/authgate/internal/cache" + "github.com/go-authgate/authgate/internal/core" + "github.com/go-authgate/authgate/internal/models" + "github.com/go-authgate/authgate/internal/store" + "github.com/go-authgate/authgate/internal/util" + + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +// --- helpers ------------------------------------------------------------ + +const testAudience = "https://authgate.test/oauth/token" + +type rsaFixture struct { + priv *rsa.PrivateKey + jwk util.JWK +} + +func newRSAFixture(t *testing.T, kid string) rsaFixture { + t.Helper() + priv, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + return rsaFixture{ + priv: priv, + jwk: util.JWK{ + Kty: "RSA", + Use: "sig", + Kid: kid, + Alg: "RS256", + N: base64.RawURLEncoding.EncodeToString(priv.PublicKey.N.Bytes()), + E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(priv.PublicKey.E)).Bytes()), + }, + } +} + +type ecFixture struct { + priv *ecdsa.PrivateKey + jwk util.JWK +} + +func newECFixture(t *testing.T, kid string) ecFixture { + t.Helper() + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + byteLen := 32 + xBytes := make([]byte, byteLen) + yBytes := make([]byte, byteLen) + copy(xBytes[byteLen-len(priv.PublicKey.X.Bytes()):], priv.PublicKey.X.Bytes()) + copy(yBytes[byteLen-len(priv.PublicKey.Y.Bytes()):], priv.PublicKey.Y.Bytes()) + return ecFixture{ + priv: priv, + jwk: util.JWK{ + Kty: "EC", + Use: "sig", + Kid: kid, + Alg: "ES256", + Crv: "P-256", + X: base64.RawURLEncoding.EncodeToString(xBytes), + Y: base64.RawURLEncoding.EncodeToString(yBytes), + }, + } +} + +func signRS256(t *testing.T, priv *rsa.PrivateKey, kid string, claims jwt.MapClaims) string { + t.Helper() + tok := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tok.Header["kid"] = kid + out, err := tok.SignedString(priv) + require.NoError(t, err) + return out +} + +func signES256(t *testing.T, priv *ecdsa.PrivateKey, kid string, claims jwt.MapClaims) string { + t.Helper() + tok := jwt.NewWithClaims(jwt.SigningMethodES256, claims) + tok.Header["kid"] = kid + out, err := tok.SignedString(priv) + require.NoError(t, err) + return out +} + +func signHS256(t *testing.T, secret []byte, claims jwt.MapClaims) string { + t.Helper() + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + out, err := tok.SignedString(secret) + require.NoError(t, err) + return out +} + +type verifierFixture struct { + verifier *ClientAssertionVerifier + store *store.Store + cs *ClientService +} + +func newVerifierFixture(t *testing.T) *verifierFixture { + t.Helper() + s := setupTestStore(t) + cs := NewClientService(s, nil, nil, 0, nil, 0) + fetcher := NewJWKSFetcher(cache.NewMemoryCache[util.JWKSet](0), 2*time.Second, time.Minute) + jtiCache := cache.NewMemoryCache[bool](0) + t.Cleanup(func() { _ = jtiCache.Close() }) + v := NewClientAssertionVerifier(cs, fetcher, jtiCache, NewNoopAuditService(), + ClientAssertionConfig{ + Enabled: true, + ExpectedAudiences: []string{testAudience}, + MaxLifetime: 5 * time.Minute, + ClockSkew: 30 * time.Second, + }) + return &verifierFixture{verifier: v, store: s, cs: cs} +} + +func seedRSAClient( + t *testing.T, + s *store.Store, + jwk util.JWK, + alg string, +) *models.OAuthApplication { + t.Helper() + set := util.JWKSet{Keys: []util.JWK{jwk}} + blob, err := json.Marshal(set) + require.NoError(t, err) + client := &models.OAuthApplication{ + ClientID: uuid.New().String(), + ClientName: "pkjwt-client", + UserID: uuid.New().String(), + Scopes: "read write", + GrantTypes: "client_credentials", + ClientType: core.ClientTypeConfidential.String(), + EnableClientCredentialsFlow: true, + Status: models.ClientStatusActive, + TokenEndpointAuthMethod: models.TokenEndpointAuthPrivateKeyJWT, + TokenEndpointAuthSigningAlg: alg, + JWKS: string(blob), + } + require.NoError(t, s.CreateClient(client)) + return client +} + +func baseClaims(clientID string) jwt.MapClaims { + now := time.Now() + return jwt.MapClaims{ + "iss": clientID, + "sub": clientID, + "aud": testAudience, + "iat": now.Unix(), + "exp": now.Add(time.Minute).Unix(), + "jti": uuid.NewString(), + } +} + +// --- success paths ------------------------------------------------------ + +func TestClientAssertion_RS256_Success(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "k1") + client := seedRSAClient(t, f.store, rsa.jwk, "RS256") + + token := signRS256(t, rsa.priv, "k1", baseClaims(client.ClientID)) + got, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.NoError(t, err) + require.Equal(t, client.ClientID, got.ClientID) +} + +func TestClientAssertion_ES256_Success(t *testing.T) { + f := newVerifierFixture(t) + ec := newECFixture(t, "ec1") + set := util.JWKSet{Keys: []util.JWK{ec.jwk}} + blob, err := json.Marshal(set) + require.NoError(t, err) + + client := &models.OAuthApplication{ + ClientID: uuid.New().String(), + ClientName: "pkjwt-ec", + UserID: uuid.New().String(), + Scopes: "read", + GrantTypes: "client_credentials", + ClientType: core.ClientTypeConfidential.String(), + EnableClientCredentialsFlow: true, + Status: models.ClientStatusActive, + TokenEndpointAuthMethod: models.TokenEndpointAuthPrivateKeyJWT, + TokenEndpointAuthSigningAlg: "ES256", + JWKS: string(blob), + } + require.NoError(t, f.store.CreateClient(client)) + + token := signES256(t, ec.priv, "ec1", baseClaims(client.ClientID)) + got, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.NoError(t, err) + require.Equal(t, client.ClientID, got.ClientID) +} + +func TestClientAssertion_JWKSURI_Success(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "k1") + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + blob, _ := json.Marshal(util.JWKSet{Keys: []util.JWK{rsa.jwk}}) + _, _ = w.Write(blob) + })) + defer srv.Close() + + client := &models.OAuthApplication{ + ClientID: uuid.New().String(), + ClientName: "pkjwt-uri", + UserID: uuid.New().String(), + Scopes: "read", + GrantTypes: "client_credentials", + ClientType: core.ClientTypeConfidential.String(), + EnableClientCredentialsFlow: true, + Status: models.ClientStatusActive, + TokenEndpointAuthMethod: models.TokenEndpointAuthPrivateKeyJWT, + TokenEndpointAuthSigningAlg: "RS256", + JWKSURI: srv.URL, + } + require.NoError(t, f.store.CreateClient(client)) + + token := signRS256(t, rsa.priv, "k1", baseClaims(client.ClientID)) + got, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.NoError(t, err) + require.Equal(t, client.ClientID, got.ClientID) +} + +// --- failure paths ------------------------------------------------------ + +func TestClientAssertion_FeatureDisabled(t *testing.T) { + f := newVerifierFixture(t) + f.verifier.cfg.Enabled = false + _, err := f.verifier.Verify(context.Background(), "anything", AssertionType) + require.ErrorIs(t, err, ErrAssertionFeatureDisabled) +} + +func TestClientAssertion_WrongAssertionType(t *testing.T) { + f := newVerifierFixture(t) + _, err := f.verifier.Verify(context.Background(), "anything", "urn:other") + require.ErrorIs(t, err, ErrAssertionTypeInvalid) +} + +func TestClientAssertion_Malformed(t *testing.T) { + f := newVerifierFixture(t) + _, err := f.verifier.Verify(context.Background(), "not.a.jwt", AssertionType) + require.ErrorIs(t, err, ErrAssertionMalformed) +} + +func TestClientAssertion_IssSubMismatch(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "k1") + client := seedRSAClient(t, f.store, rsa.jwk, "RS256") + + claims := baseClaims(client.ClientID) + claims["sub"] = "different" + token := signRS256(t, rsa.priv, "k1", claims) + _, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.ErrorIs(t, err, ErrAssertionIssuerMismatch) +} + +func TestClientAssertion_UnknownClient(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "k1") + claims := baseClaims("ghost-client") + token := signRS256(t, rsa.priv, "k1", claims) + _, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.ErrorIs(t, err, ErrAssertionClientUnknown) +} + +func TestClientAssertion_MethodNotAllowed(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "k1") + // Seed a client that registered as client_secret_basic, not private_key_jwt. + client := &models.OAuthApplication{ + ClientID: uuid.New().String(), + ClientName: "secret-client", + UserID: uuid.New().String(), + Scopes: "read", + ClientType: core.ClientTypeConfidential.String(), + EnableClientCredentialsFlow: true, + Status: models.ClientStatusActive, + TokenEndpointAuthMethod: models.TokenEndpointAuthClientSecretBasic, + } + require.NoError(t, f.store.CreateClient(client)) + + token := signRS256(t, rsa.priv, "k1", baseClaims(client.ClientID)) + _, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.ErrorIs(t, err, ErrAssertionMethodNotAllowed) +} + +func TestClientAssertion_AlgorithmMismatch(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "k1") + // Register the client as ES256, but sign as RS256 — mismatch. + client := seedRSAClient(t, f.store, rsa.jwk, "ES256") + + token := signRS256(t, rsa.priv, "k1", baseClaims(client.ClientID)) + _, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.ErrorIs(t, err, ErrAssertionAlgorithmMismatch) +} + +func TestClientAssertion_HS256Rejected(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "k1") + client := seedRSAClient(t, f.store, rsa.jwk, "RS256") + + // Try to sign with HS256 — algorithm mismatch should reject it regardless. + token := signHS256(t, []byte("shared-secret"), baseClaims(client.ClientID)) + _, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.ErrorIs(t, err, ErrAssertionAlgorithmMismatch) +} + +func TestClientAssertion_BadSignature(t *testing.T) { + f := newVerifierFixture(t) + rsa1 := newRSAFixture(t, "k1") + rsa2 := newRSAFixture(t, "k1") + // Register rsa1's public key, but sign with rsa2's private key. + client := seedRSAClient(t, f.store, rsa1.jwk, "RS256") + + token := signRS256(t, rsa2.priv, "k1", baseClaims(client.ClientID)) + _, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.ErrorIs(t, err, ErrAssertionSignatureInvalid) +} + +func TestClientAssertion_WrongAudience(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "k1") + client := seedRSAClient(t, f.store, rsa.jwk, "RS256") + + claims := baseClaims(client.ClientID) + claims["aud"] = "https://other.example.com/token" + token := signRS256(t, rsa.priv, "k1", claims) + _, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.ErrorIs(t, err, ErrAssertionAudienceInvalid) +} + +func TestClientAssertion_Expired(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "k1") + client := seedRSAClient(t, f.store, rsa.jwk, "RS256") + + claims := baseClaims(client.ClientID) + claims["iat"] = time.Now().Add(-2 * time.Minute).Unix() + claims["exp"] = time.Now().Add(-time.Minute).Unix() + token := signRS256(t, rsa.priv, "k1", claims) + _, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.ErrorIs(t, err, ErrAssertionExpired) +} + +func TestClientAssertion_LifetimeTooLong(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "k1") + client := seedRSAClient(t, f.store, rsa.jwk, "RS256") + + now := time.Now() + claims := jwt.MapClaims{ + "iss": client.ClientID, + "sub": client.ClientID, + "aud": testAudience, + "iat": now.Unix(), + "exp": now.Add(1 * time.Hour).Unix(), // exceeds MaxLifetime=5m + "jti": uuid.NewString(), + } + token := signRS256(t, rsa.priv, "k1", claims) + _, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.ErrorIs(t, err, ErrAssertionLifetimeTooLong) +} + +func TestClientAssertion_MissingJTI(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "k1") + client := seedRSAClient(t, f.store, rsa.jwk, "RS256") + + claims := baseClaims(client.ClientID) + delete(claims, "jti") + token := signRS256(t, rsa.priv, "k1", claims) + _, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.ErrorIs(t, err, ErrAssertionMissingJTI) +} + +func TestClientAssertion_JTIReplay(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "k1") + client := seedRSAClient(t, f.store, rsa.jwk, "RS256") + + claims := baseClaims(client.ClientID) + token := signRS256(t, rsa.priv, "k1", claims) + + // First use: accepted. + _, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.NoError(t, err) + // Second use of the same token: rejected as replay. + _, err = f.verifier.Verify(context.Background(), token, AssertionType) + require.ErrorIs(t, err, ErrAssertionJTIReplay) +} + +func TestClientAssertion_UnknownKid(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "registered") + client := seedRSAClient(t, f.store, rsa.jwk, "RS256") + + // Sign with a different kid — no matching JWK. + token := signRS256(t, rsa.priv, "mystery", baseClaims(client.ClientID)) + _, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.ErrorIs(t, err, ErrAssertionKeyLookup) +} + +func TestClientAssertion_InactiveClient(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "k1") + client := seedRSAClient(t, f.store, rsa.jwk, "RS256") + // Disable the client after registration. + client.Status = models.ClientStatusInactive + require.NoError(t, f.store.UpdateClient(client)) + // Clear any cached version. + _ = f.cs.clientCache.Delete(context.Background(), client.ClientID) + + token := signRS256(t, rsa.priv, "k1", baseClaims(client.ClientID)) + _, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.ErrorIs(t, err, ErrAssertionClientInactive) +} diff --git a/internal/services/jwks_fetcher.go b/internal/services/jwks_fetcher.go new file mode 100644 index 00000000..06cb4823 --- /dev/null +++ b/internal/services/jwks_fetcher.go @@ -0,0 +1,152 @@ +package services + +import ( + "context" + "errors" + "fmt" + "io" + "log" + "net/http" + "sync" + "time" + + "github.com/go-authgate/authgate/internal/core" + "github.com/go-authgate/authgate/internal/util" +) + +// Errors returned by JWKSFetcher. +var ( + ErrJWKSFetchFailed = errors.New("failed to fetch JWKS") + ErrJWKSTooLarge = errors.New("JWKS response exceeds size limit") +) + +// Maximum JWKS body size accepted over the network (defense against slowloris / huge responses). +const jwksMaxBodyBytes = 1 << 20 // 1 MiB + +// Minimum interval between forced refreshes of the same JWKS URI. +// Prevents attackers from triggering unbounded refetches by sending assertions +// with unknown kids — legitimate rotations still succeed after this cooldown. +const jwksRefreshCooldown = 30 * time.Second + +// JWKSFetcher retrieves and caches JWK Sets from remote jwks_uri endpoints. +// It is safe for concurrent use. +type JWKSFetcher struct { + httpClient *http.Client + cache core.Cache[util.JWKSet] + ttl time.Duration + + // lastRefresh tracks when each uri was last force-refreshed so kid-miss + // driven refetches respect jwksRefreshCooldown. + lastRefresh sync.Map // map[string]time.Time +} + +// NewJWKSFetcher constructs a JWKSFetcher. +// timeout controls the HTTP request timeout; ttl controls the cache lifetime. +// cache may be nil, in which case a bounded in-process memory cache is used. +func NewJWKSFetcher(cache core.Cache[util.JWKSet], timeout, ttl time.Duration) *JWKSFetcher { + if timeout <= 0 { + timeout = 10 * time.Second + } + if ttl <= 0 { + ttl = time.Hour + } + return &JWKSFetcher{ + httpClient: &http.Client{Timeout: timeout}, + cache: cache, + ttl: ttl, + } +} + +// Get returns the cached JWKS for uri, fetching it on cache miss. Use +// GetWithRefresh when verifying a signature against a specific kid so that a +// cache stale against a rotated signing key triggers a fresh fetch. +func (f *JWKSFetcher) Get(ctx context.Context, uri string) (*util.JWKSet, error) { + return f.getCached(ctx, uri) +} + +// GetWithRefresh returns the JWKS for uri. If the cached version does not +// contain a key with the requested kid, the cache is bypassed and the +// document is refetched — this supports runtime key rotation without waiting +// for the TTL to expire. Refreshes for the same uri are rate-limited to +// prevent malformed or attacker-crafted kids from triggering unbounded +// refetches of the remote endpoint. +func (f *JWKSFetcher) GetWithRefresh(ctx context.Context, uri, kid string) (*util.JWKSet, error) { + set, err := f.getCached(ctx, uri) + if err != nil { + return nil, err + } + if kid == "" || set.FindByKid(kid) != nil { + return set, nil + } + if !f.canRefreshNow(uri) { + return set, nil + } + if err := f.cache.Delete(ctx, uri); err != nil { + log.Printf("[JWKSFetcher] cache delete failed for %s: %v", uri, err) + } + return f.getCached(ctx, uri) +} + +// canRefreshNow returns true if uri has not been force-refreshed within +// jwksRefreshCooldown, and marks it as refreshed. Concurrent callers for the +// same uri collapse into a single refresh per cooldown window. +func (f *JWKSFetcher) canRefreshNow(uri string) bool { + now := time.Now() + prev, loaded := f.lastRefresh.LoadOrStore(uri, now) + if !loaded { + return true + } + if now.Sub(prev.(time.Time)) < jwksRefreshCooldown { + return false + } + f.lastRefresh.Store(uri, now) + return true +} + +func (f *JWKSFetcher) getCached(ctx context.Context, uri string) (*util.JWKSet, error) { + if f.cache == nil { + set, err := f.fetch(ctx, uri) + if err != nil { + return nil, err + } + return &set, nil + } + set, err := f.cache.GetWithFetch(ctx, uri, f.ttl, + func(ctx context.Context, _ string) (util.JWKSet, error) { + return f.fetch(ctx, uri) + }) + if err != nil { + return nil, err + } + return &set, nil +} + +func (f *JWKSFetcher) fetch(ctx context.Context, uri string) (util.JWKSet, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri, nil) + if err != nil { + return util.JWKSet{}, fmt.Errorf("%w: %v", ErrJWKSFetchFailed, err) + } + req.Header.Set("Accept", "application/json") + resp, err := f.httpClient.Do(req) + if err != nil { + return util.JWKSet{}, fmt.Errorf("%w: %v", ErrJWKSFetchFailed, err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return util.JWKSet{}, fmt.Errorf( + "%w: unexpected status %d", ErrJWKSFetchFailed, resp.StatusCode, + ) + } + body, err := io.ReadAll(io.LimitReader(resp.Body, jwksMaxBodyBytes+1)) + if err != nil { + return util.JWKSet{}, fmt.Errorf("%w: %v", ErrJWKSFetchFailed, err) + } + if len(body) > jwksMaxBodyBytes { + return util.JWKSet{}, ErrJWKSTooLarge + } + set, err := util.ParseJWKSet(string(body)) + if err != nil { + return util.JWKSet{}, fmt.Errorf("%w: %v", ErrJWKSFetchFailed, err) + } + return *set, nil +} diff --git a/internal/services/jwks_fetcher_test.go b/internal/services/jwks_fetcher_test.go new file mode 100644 index 00000000..c2b67b78 --- /dev/null +++ b/internal/services/jwks_fetcher_test.go @@ -0,0 +1,183 @@ +package services + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "errors" + "math/big" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/go-authgate/authgate/internal/cache" + "github.com/go-authgate/authgate/internal/util" +) + +func jwkSetJSON(t *testing.T, keys []util.JWK) string { + t.Helper() + b, err := json.Marshal(util.JWKSet{Keys: keys}) + if err != nil { + t.Fatalf("marshal JWKS: %v", err) + } + return string(b) +} + +func rsaJWKFixture(t *testing.T, kid string) (util.JWK, *rsa.PrivateKey) { + t.Helper() + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + return util.JWK{ + Kty: "RSA", + Use: "sig", + Kid: kid, + Alg: "RS256", + N: base64.RawURLEncoding.EncodeToString(priv.PublicKey.N.Bytes()), + E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(priv.PublicKey.E)).Bytes()), + }, priv +} + +func TestJWKSFetcher_CacheHit(t *testing.T) { + jwk, _ := rsaJWKFixture(t, "k1") + var hits int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + atomic.AddInt32(&hits, 1) + _, _ = w.Write([]byte(jwkSetJSON(t, []util.JWK{jwk}))) + })) + defer srv.Close() + + mc := cache.NewMemoryCache[util.JWKSet](0) + defer mc.Close() + f := NewJWKSFetcher(mc, 2*time.Second, time.Minute) + + ctx := context.Background() + if _, err := f.Get(ctx, srv.URL); err != nil { + t.Fatalf("first Get: %v", err) + } + if _, err := f.Get(ctx, srv.URL); err != nil { + t.Fatalf("second Get: %v", err) + } + if n := atomic.LoadInt32(&hits); n != 1 { + t.Fatalf("expected 1 HTTP hit (cache should serve the second), got %d", n) + } +} + +func TestJWKSFetcher_RefreshOnKidMiss(t *testing.T) { + oldJWK, _ := rsaJWKFixture(t, "old") + newJWK, _ := rsaJWKFixture(t, "new") + var hits int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + n := atomic.AddInt32(&hits, 1) + if n == 1 { + _, _ = w.Write([]byte(jwkSetJSON(t, []util.JWK{oldJWK}))) + return + } + _, _ = w.Write([]byte(jwkSetJSON(t, []util.JWK{newJWK}))) + })) + defer srv.Close() + + mc := cache.NewMemoryCache[util.JWKSet](0) + defer mc.Close() + f := NewJWKSFetcher(mc, 2*time.Second, time.Minute) + + ctx := context.Background() + // Warm cache with the old JWKS. + if _, err := f.Get(ctx, srv.URL); err != nil { + t.Fatalf("warm: %v", err) + } + // Ask for a kid that exists only in the refreshed set — fetcher must refetch. + set, err := f.GetWithRefresh(ctx, srv.URL, "new") + if err != nil { + t.Fatalf("GetWithRefresh: %v", err) + } + if got := set.FindByKid("new"); got == nil { + t.Fatal("expected refresh to find new kid") + } + if n := atomic.LoadInt32(&hits); n != 2 { + t.Fatalf("expected 2 HTTP hits after kid miss, got %d", n) + } +} + +func TestJWKSFetcher_RefreshCooldown(t *testing.T) { + jwk, _ := rsaJWKFixture(t, "registered") + var hits int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + atomic.AddInt32(&hits, 1) + _, _ = w.Write([]byte(jwkSetJSON(t, []util.JWK{jwk}))) + })) + defer srv.Close() + + mc := cache.NewMemoryCache[util.JWKSet](0) + defer mc.Close() + f := NewJWKSFetcher(mc, 2*time.Second, time.Minute) + + ctx := context.Background() + // First unknown-kid request triggers a refetch (cache miss forces one more). + _, err := f.GetWithRefresh(ctx, srv.URL, "unknown") + if err != nil { + t.Fatalf("first call: %v", err) + } + firstHits := atomic.LoadInt32(&hits) + // Second unknown-kid request within the cooldown must NOT refetch. + _, err = f.GetWithRefresh(ctx, srv.URL, "still-unknown") + if err != nil { + t.Fatalf("second call: %v", err) + } + if got := atomic.LoadInt32(&hits); got != firstHits { + t.Fatalf("expected no additional refetch within cooldown; hits went %d → %d", + firstHits, got) + } +} + +func TestJWKSFetcher_Non200Fails(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + mc := cache.NewMemoryCache[util.JWKSet](0) + defer mc.Close() + f := NewJWKSFetcher(mc, 2*time.Second, time.Minute) + _, err := f.Get(context.Background(), srv.URL) + if !errors.Is(err, ErrJWKSFetchFailed) { + t.Fatalf("expected ErrJWKSFetchFailed, got %v", err) + } +} + +func TestJWKSFetcher_InvalidBody(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("not json")) + })) + defer srv.Close() + + mc := cache.NewMemoryCache[util.JWKSet](0) + defer mc.Close() + f := NewJWKSFetcher(mc, 2*time.Second, time.Minute) + _, err := f.Get(context.Background(), srv.URL) + if !errors.Is(err, ErrJWKSFetchFailed) { + t.Fatalf("expected ErrJWKSFetchFailed, got %v", err) + } +} + +func TestJWKSFetcher_BodyTooLarge(t *testing.T) { + large := strings.Repeat("a", jwksMaxBodyBytes+10) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte(large)) + })) + defer srv.Close() + + mc := cache.NewMemoryCache[util.JWKSet](0) + defer mc.Close() + f := NewJWKSFetcher(mc, 2*time.Second, time.Minute) + _, err := f.Get(context.Background(), srv.URL) + if !errors.Is(err, ErrJWKSTooLarge) { + t.Fatalf("expected ErrJWKSTooLarge, got %v", err) + } +} diff --git a/internal/services/token_client_credentials.go b/internal/services/token_client_credentials.go index 9a33c9a4..b68ec459 100644 --- a/internal/services/token_client_credentials.go +++ b/internal/services/token_client_credentials.go @@ -46,6 +46,36 @@ func (s *TokenService) IssueClientCredentialsToken( return nil, ErrInvalidClientCredentials } + return s.issueClientCredentialsTokenForClient(ctx, client, requestedScopes) +} + +// IssueClientCredentialsTokenForClient issues a client_credentials access token +// for a client that has already been authenticated (e.g. via RFC 7523 +// private_key_jwt). The caller is responsible for authentication; this method +// only checks client type, flow enablement, and scope bounds. +func (s *TokenService) IssueClientCredentialsTokenForClient( + ctx context.Context, + client *models.OAuthApplication, + requestedScopes string, +) (*models.AccessToken, error) { + if client == nil || !client.IsActive() { + return nil, ErrInvalidClientCredentials + } + if core.ClientType(client.ClientType) != core.ClientTypeConfidential { + return nil, ErrClientNotConfidential + } + if !client.EnableClientCredentialsFlow { + return nil, ErrClientCredentialsFlowDisabled + } + return s.issueClientCredentialsTokenForClient(ctx, client, requestedScopes) +} + +func (s *TokenService) issueClientCredentialsTokenForClient( + ctx context.Context, + client *models.OAuthApplication, + requestedScopes string, +) (*models.AccessToken, error) { + clientID := client.ClientID // 5. Resolve effective scopes effectiveScopes := requestedScopes if effectiveScopes == "" { diff --git a/internal/util/jwk.go b/internal/util/jwk.go new file mode 100644 index 00000000..bd4f1cf7 --- /dev/null +++ b/internal/util/jwk.go @@ -0,0 +1,176 @@ +package util + +import ( + "crypto" + "crypto/ecdh" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "math/big" +) + +// Known JWK key types and curves (RFC 7517 §4.1, RFC 7518 §6). +const ( + JWKTypeRSA = "RSA" + JWKTypeEC = "EC" + JWKCurveP256 = "P-256" +) + +// JWK represents a single JSON Web Key (RFC 7517). +// Only the fields required for RSA and EC P-256 signing keys are modelled; additional +// JWK fields (e.g. x5c, x5t) are ignored on parse and omitted on marshal. +type JWK struct { + Kty string `json:"kty"` + Use string `json:"use,omitempty"` + Kid string `json:"kid,omitempty"` + Alg string `json:"alg,omitempty"` + // RSA + N string `json:"n,omitempty"` + E string `json:"e,omitempty"` + // EC + Crv string `json:"crv,omitempty"` + X string `json:"x,omitempty"` + Y string `json:"y,omitempty"` +} + +// JWKSet is a set of JSON Web Keys (RFC 7517 §5). +type JWKSet struct { + Keys []JWK `json:"keys"` +} + +// Errors returned from this package. +var ( + ErrInvalidJWKS = errors.New("invalid JWK Set") + ErrUnsupportedKeyType = errors.New("unsupported JWK key type") + ErrJWKFieldMissing = errors.New("JWK missing required field") + ErrJWKInvalidEncoding = errors.New("JWK contains invalid base64url encoding") + ErrJWKNoMatchingKey = errors.New("no JWK matches requested kid") + ErrJWKAlgorithmMismatch = errors.New("JWK algorithm does not match requested algorithm") +) + +// ParseJWKSet unmarshals a JWKS JSON document and validates that it contains +// at least one key. It does not verify individual keys — use ToPublicKey for that. +func ParseJWKSet(jsonBlob string) (*JWKSet, error) { + var set JWKSet + if err := json.Unmarshal([]byte(jsonBlob), &set); err != nil { + return nil, fmt.Errorf("%w: %v", ErrInvalidJWKS, err) + } + if len(set.Keys) == 0 { + return nil, fmt.Errorf("%w: keys array is empty", ErrInvalidJWKS) + } + return &set, nil +} + +// FindByKid returns the key whose Kid matches, or nil if not found. +// If kid is empty and there is exactly one key, that key is returned; otherwise nil. +func (s *JWKSet) FindByKid(kid string) *JWK { + if kid == "" { + if len(s.Keys) == 1 { + return &s.Keys[0] + } + return nil + } + for i := range s.Keys { + if s.Keys[i].Kid == kid { + return &s.Keys[i] + } + } + return nil +} + +// ToPublicKey converts a JWK to a crypto.PublicKey. Only RSA and EC P-256 keys +// are supported. The returned key can be used with golang-jwt/jwt for signature +// verification. +func (k *JWK) ToPublicKey() (crypto.PublicKey, error) { + switch k.Kty { + case JWKTypeRSA: + return rsaKeyFromJWK(k) + case JWKTypeEC: + return ecKeyFromJWK(k) + default: + return nil, fmt.Errorf("%w: kty=%q", ErrUnsupportedKeyType, k.Kty) + } +} + +func rsaKeyFromJWK(k *JWK) (*rsa.PublicKey, error) { + if k.N == "" || k.E == "" { + return nil, fmt.Errorf("%w: RSA key requires n and e", ErrJWKFieldMissing) + } + nBytes, err := base64.RawURLEncoding.DecodeString(k.N) + if err != nil { + return nil, fmt.Errorf("%w: n: %v", ErrJWKInvalidEncoding, err) + } + eBytes, err := base64.RawURLEncoding.DecodeString(k.E) + if err != nil { + return nil, fmt.Errorf("%w: e: %v", ErrJWKInvalidEncoding, err) + } + n := new(big.Int).SetBytes(nBytes) + e := new(big.Int).SetBytes(eBytes) + if !e.IsInt64() { + return nil, fmt.Errorf("%w: RSA exponent overflow", ErrInvalidJWKS) + } + if n.BitLen() < 2048 { + return nil, fmt.Errorf( + "%w: RSA modulus must be >= 2048 bits (got %d)", + ErrInvalidJWKS, + n.BitLen(), + ) + } + return &rsa.PublicKey{N: n, E: int(e.Int64())}, nil +} + +func ecKeyFromJWK(k *JWK) (*ecdsa.PublicKey, error) { + if k.Crv == "" || k.X == "" || k.Y == "" { + return nil, fmt.Errorf("%w: EC key requires crv, x, and y", ErrJWKFieldMissing) + } + var ( + curve elliptic.Curve + ecdhCrv ecdh.Curve + byteLen int + ) + switch k.Crv { + case JWKCurveP256: + curve = elliptic.P256() + ecdhCrv = ecdh.P256() + byteLen = 32 + default: + return nil, fmt.Errorf( + "%w: curve %q (only P-256 is supported)", + ErrUnsupportedKeyType, + k.Crv, + ) + } + xBytes, err := base64.RawURLEncoding.DecodeString(k.X) + if err != nil { + return nil, fmt.Errorf("%w: x: %v", ErrJWKInvalidEncoding, err) + } + yBytes, err := base64.RawURLEncoding.DecodeString(k.Y) + if err != nil { + return nil, fmt.Errorf("%w: y: %v", ErrJWKInvalidEncoding, err) + } + if len(xBytes) > byteLen || len(yBytes) > byteLen { + return nil, fmt.Errorf("%w: EC coordinate exceeds curve byte length", ErrInvalidJWKS) + } + // Left-pad x/y to fixed width, then build SEC1 uncompressed point (0x04 || X || Y) + // and delegate on-curve validation to crypto/ecdh. + xPad := make([]byte, byteLen) + yPad := make([]byte, byteLen) + copy(xPad[byteLen-len(xBytes):], xBytes) + copy(yPad[byteLen-len(yBytes):], yBytes) + uncompressed := make([]byte, 0, 1+2*byteLen) + uncompressed = append(uncompressed, 0x04) + uncompressed = append(uncompressed, xPad...) + uncompressed = append(uncompressed, yPad...) + if _, err := ecdhCrv.NewPublicKey(uncompressed); err != nil { + return nil, fmt.Errorf("%w: EC point is not on curve %s: %v", ErrInvalidJWKS, k.Crv, err) + } + return &ecdsa.PublicKey{ + Curve: curve, + X: new(big.Int).SetBytes(xPad), + Y: new(big.Int).SetBytes(yPad), + }, nil +} diff --git a/internal/util/jwk_test.go b/internal/util/jwk_test.go new file mode 100644 index 00000000..0a1b3c0f --- /dev/null +++ b/internal/util/jwk_test.go @@ -0,0 +1,193 @@ +package util + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "errors" + "math/big" + "testing" +) + +// helpers --------------------------------------------------------------- + +func jwkFromRSA(t *testing.T, pub *rsa.PublicKey, kid string) JWK { + t.Helper() + return JWK{ + Kty: "RSA", + Use: "sig", + Kid: kid, + Alg: "RS256", + N: base64.RawURLEncoding.EncodeToString(pub.N.Bytes()), + E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(pub.E)).Bytes()), + } +} + +func jwkFromEC(t *testing.T, pub *ecdsa.PublicKey, kid string) JWK { + t.Helper() + byteLen := (pub.Curve.Params().BitSize + 7) / 8 + xBytes := make([]byte, byteLen) + yBytes := make([]byte, byteLen) + copy(xBytes[byteLen-len(pub.X.Bytes()):], pub.X.Bytes()) + copy(yBytes[byteLen-len(pub.Y.Bytes()):], pub.Y.Bytes()) + return JWK{ + Kty: "EC", + Use: "sig", + Kid: kid, + Alg: "ES256", + Crv: pub.Curve.Params().Name, + X: base64.RawURLEncoding.EncodeToString(xBytes), + Y: base64.RawURLEncoding.EncodeToString(yBytes), + } +} + +// tests ----------------------------------------------------------------- + +func TestParseJWKSet_RSA_Roundtrip(t *testing.T) { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + set := JWKSet{Keys: []JWK{jwkFromRSA(t, &priv.PublicKey, "test-key")}} + blob, err := json.Marshal(set) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + parsed, err := ParseJWKSet(string(blob)) + if err != nil { + t.Fatalf("ParseJWKSet: %v", err) + } + if len(parsed.Keys) != 1 { + t.Fatalf("expected 1 key, got %d", len(parsed.Keys)) + } + k := parsed.FindByKid("test-key") + if k == nil { + t.Fatal("FindByKid: nil") + } + pub, err := k.ToPublicKey() + if err != nil { + t.Fatalf("ToPublicKey: %v", err) + } + rsaPub, ok := pub.(*rsa.PublicKey) + if !ok { + t.Fatalf("expected *rsa.PublicKey, got %T", pub) + } + if rsaPub.N.Cmp(priv.PublicKey.N) != 0 || rsaPub.E != priv.PublicKey.E { + t.Fatal("RSA key mismatch after roundtrip") + } +} + +func TestParseJWKSet_EC_Roundtrip(t *testing.T) { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + set := JWKSet{Keys: []JWK{jwkFromEC(t, &priv.PublicKey, "ec-key")}} + blob, err := json.Marshal(set) + if err != nil { + t.Fatalf("marshal: %v", err) + } + parsed, err := ParseJWKSet(string(blob)) + if err != nil { + t.Fatalf("ParseJWKSet: %v", err) + } + k := parsed.FindByKid("ec-key") + pub, err := k.ToPublicKey() + if err != nil { + t.Fatalf("ToPublicKey: %v", err) + } + ecPub, ok := pub.(*ecdsa.PublicKey) + if !ok { + t.Fatalf("expected *ecdsa.PublicKey, got %T", pub) + } + if ecPub.X.Cmp(priv.PublicKey.X) != 0 || ecPub.Y.Cmp(priv.PublicKey.Y) != 0 { + t.Fatal("EC key mismatch after roundtrip") + } +} + +func TestParseJWKSet_EmptyKeys(t *testing.T) { + _, err := ParseJWKSet(`{"keys":[]}`) + if !errors.Is(err, ErrInvalidJWKS) { + t.Fatalf("expected ErrInvalidJWKS, got %v", err) + } +} + +func TestParseJWKSet_InvalidJSON(t *testing.T) { + _, err := ParseJWKSet(`not json`) + if !errors.Is(err, ErrInvalidJWKS) { + t.Fatalf("expected ErrInvalidJWKS, got %v", err) + } +} + +func TestFindByKid_SingleKeyNoKid(t *testing.T) { + priv, _ := rsa.GenerateKey(rand.Reader, 2048) + set := &JWKSet{Keys: []JWK{jwkFromRSA(t, &priv.PublicKey, "")}} + if got := set.FindByKid(""); got == nil { + t.Fatal("expected single key match when kid empty") + } +} + +func TestFindByKid_MultipleKeysEmptyKid(t *testing.T) { + priv, _ := rsa.GenerateKey(rand.Reader, 2048) + set := &JWKSet{Keys: []JWK{ + jwkFromRSA(t, &priv.PublicKey, "a"), + jwkFromRSA(t, &priv.PublicKey, "b"), + }} + if got := set.FindByKid(""); got != nil { + t.Fatal("expected nil when multiple keys and empty kid") + } +} + +func TestToPublicKey_UnsupportedKty(t *testing.T) { + k := &JWK{Kty: "oct"} + _, err := k.ToPublicKey() + if !errors.Is(err, ErrUnsupportedKeyType) { + t.Fatalf("expected ErrUnsupportedKeyType, got %v", err) + } +} + +func TestToPublicKey_RSA_MissingField(t *testing.T) { + k := &JWK{Kty: "RSA", N: "AQAB"} // missing e + _, err := k.ToPublicKey() + if !errors.Is(err, ErrJWKFieldMissing) { + t.Fatalf("expected ErrJWKFieldMissing, got %v", err) + } +} + +func TestToPublicKey_RSA_WeakKey(t *testing.T) { + priv, err := rsa.GenerateKey(rand.Reader, 1024) // below min 2048 + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + k := jwkFromRSA(t, &priv.PublicKey, "weak") + _, err = k.ToPublicKey() + if !errors.Is(err, ErrInvalidJWKS) { + t.Fatalf("expected ErrInvalidJWKS for weak key, got %v", err) + } +} + +func TestToPublicKey_EC_UnsupportedCurve(t *testing.T) { + k := &JWK{Kty: "EC", Crv: "P-384", X: "AA", Y: "AA"} + _, err := k.ToPublicKey() + if !errors.Is(err, ErrUnsupportedKeyType) { + t.Fatalf("expected ErrUnsupportedKeyType, got %v", err) + } +} + +func TestToPublicKey_EC_PointNotOnCurve(t *testing.T) { + k := &JWK{ + Kty: "EC", + Crv: "P-256", + // Valid base64 but nonsense coordinates → not on curve + X: base64.RawURLEncoding.EncodeToString([]byte{1, 2, 3, 4}), + Y: base64.RawURLEncoding.EncodeToString([]byte{5, 6, 7, 8}), + } + _, err := k.ToPublicKey() + if !errors.Is(err, ErrInvalidJWKS) { + t.Fatalf("expected ErrInvalidJWKS for off-curve point, got %v", err) + } +} From 062de054ebfeb9f59bc5e4327e8b2908b8a4ad49 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Sun, 12 Apr 2026 21:44:34 +0800 Subject: [PATCH 02/13] fix(oauth): address Copilot review on private_key_jwt PR - Only overwrite token endpoint auth fields in UpdateClient when the method is explicitly provided, preventing legacy forms from clearing private_key_jwt clients' JWKS material - Guard against nil cache in JWKSFetcher GetWithRefresh and fix the constructor docstring to reflect the no-cache-means-no-cache behaviour - Reject RSA public exponents that are even or less than 3 at JWK parse time - Distinguish cache miss from backend errors in jti replay check and fail closed on backend errors rather than silently accepting replays - Shard the jti replay mutex per client_id so honest traffic is no longer serialised across clients Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/services/client.go | 12 ++++++--- internal/services/client_assertion.go | 38 ++++++++++++++++++--------- internal/services/jwks_fetcher.go | 12 ++++++--- internal/util/jwk.go | 11 +++++++- 4 files changed, 51 insertions(+), 22 deletions(-) diff --git a/internal/services/client.go b/internal/services/client.go index 2d500c76..3d7d1cfa 100644 --- a/internal/services/client.go +++ b/internal/services/client.go @@ -401,7 +401,11 @@ func (s *ClientService) UpdateClient( enableClientCredentials, ) - // Token endpoint authentication fields (may be zero to preserve existing). + // Token endpoint authentication fields are atomic: when the caller + // specifies a method they must also provide its key material, and when + // they omit the method the existing configuration is preserved. This + // prevents forms that don't surface the new fields (e.g. the legacy + // admin UI) from silently wiping a private_key_jwt client's JWKS. if req.TokenEndpointAuthMethod != "" { if !validTokenEndpointAuthMethod(req.TokenEndpointAuthMethod) { return ErrInvalidTokenEndpointAuthMethod @@ -411,10 +415,10 @@ func (s *ClientService) UpdateClient( return ErrPrivateKeyJWTRequiresConfidential } client.TokenEndpointAuthMethod = req.TokenEndpointAuthMethod + client.TokenEndpointAuthSigningAlg = req.TokenEndpointAuthSigningAlg + client.JWKSURI = strings.TrimSpace(req.JWKSURI) + client.JWKS = strings.TrimSpace(req.JWKS) } - client.TokenEndpointAuthSigningAlg = req.TokenEndpointAuthSigningAlg - client.JWKSURI = strings.TrimSpace(req.JWKSURI) - client.JWKS = strings.TrimSpace(req.JWKS) // Clear the shared secret when switching away from client_secret_* methods, // so a stale hash cannot authenticate a reconfigured client. if !client.UsesClientSecret() { diff --git a/internal/services/client_assertion.go b/internal/services/client_assertion.go index a4355d85..d7c149e9 100644 --- a/internal/services/client_assertion.go +++ b/internal/services/client_assertion.go @@ -10,6 +10,7 @@ import ( "sync" "time" + "github.com/go-authgate/authgate/internal/cache" "github.com/go-authgate/authgate/internal/core" "github.com/go-authgate/authgate/internal/models" "github.com/go-authgate/authgate/internal/util" @@ -69,11 +70,10 @@ type ClientAssertionVerifier struct { auditService core.AuditLogger cfg ClientAssertionConfig - // jtiMu serialises the jti Get+Set so two concurrent requests with the - // same jti cannot both observe a cache miss and pass replay detection. - // Honest traffic has unique jtis and hits no contention; replay attempts - // and malformed duplicates block each other, which is the desired effect. - jtiMu sync.Mutex + // jtiLocks shards the jti replay Get+Set critical section per client to + // keep honest traffic free of cross-client contention while still closing + // the TOCTOU window for concurrent requests carrying the same jti. + jtiLocks sync.Map // map[string]*sync.Mutex, keyed by client_id } // NewClientAssertionVerifier wires the verifier. auditService may be nil (no-op). @@ -318,13 +318,24 @@ func (v *ClientAssertionVerifier) checkJTIReplay( } key := jtiCacheKeyPrefix + clientID + ":" + jti - // Serialise the Get+Set pair: without the lock, two concurrent requests - // carrying the same jti can both observe a cache miss before either Set - // lands, accepting a replay. - v.jtiMu.Lock() - defer v.jtiMu.Unlock() + // Serialise the Get+Set pair per client so two concurrent requests + // carrying the same jti cannot both observe a cache miss before either + // Set lands. Sharding by client keeps honest high-throughput traffic + // free of cross-client contention. + lockIface, _ := v.jtiLocks.LoadOrStore(clientID, &sync.Mutex{}) + lock := lockIface.(*sync.Mutex) + lock.Lock() + defer lock.Unlock() - if _, err := v.jtiCache.Get(ctx, key); err == nil { + switch _, err := v.jtiCache.Get(ctx, key); { + case err == nil: + return ErrAssertionJTIReplay + case errors.Is(err, cache.ErrCacheMiss): + // expected — the jti has not been seen; fall through to record it. + default: + // Backend error (e.g. Redis unavailable). Fail closed so we do not + // silently accept replays while the cache is degraded. + log.Printf("[ClientAssertion] jti cache lookup failed: %v", err) return ErrAssertionJTIReplay } // TTL = remaining assertion lifetime + clock skew. If exp is absent, @@ -336,10 +347,11 @@ func (v *ClientAssertionVerifier) checkJTIReplay( ttl = remaining } } - // Log but do not block on cache write failure — availability over perfect - // replay protection (the cache is best-effort in a multi-instance setup). if err := v.jtiCache.Set(ctx, key, true, ttl); err != nil { + // Set failure after a miss: reject the assertion rather than + // silently skipping replay tracking. log.Printf("[ClientAssertion] failed to record jti %s: %v", jti, err) + return ErrAssertionJTIReplay } return nil } diff --git a/internal/services/jwks_fetcher.go b/internal/services/jwks_fetcher.go index 06cb4823..33390ae4 100644 --- a/internal/services/jwks_fetcher.go +++ b/internal/services/jwks_fetcher.go @@ -41,8 +41,10 @@ type JWKSFetcher struct { } // NewJWKSFetcher constructs a JWKSFetcher. -// timeout controls the HTTP request timeout; ttl controls the cache lifetime. -// cache may be nil, in which case a bounded in-process memory cache is used. +// timeout controls the HTTP request timeout; ttl controls the cache lifetime +// when a cache is provided. cache may be nil, in which case every call goes +// straight to the remote — callers that need kid-miss refresh semantics +// should supply a non-nil cache. func NewJWKSFetcher(cache core.Cache[util.JWKSet], timeout, ttl time.Duration) *JWKSFetcher { if timeout <= 0 { timeout = 10 * time.Second @@ -81,8 +83,10 @@ func (f *JWKSFetcher) GetWithRefresh(ctx context.Context, uri, kid string) (*uti if !f.canRefreshNow(uri) { return set, nil } - if err := f.cache.Delete(ctx, uri); err != nil { - log.Printf("[JWKSFetcher] cache delete failed for %s: %v", uri, err) + if f.cache != nil { + if err := f.cache.Delete(ctx, uri); err != nil { + log.Printf("[JWKSFetcher] cache delete failed for %s: %v", uri, err) + } } return f.getCached(ctx, uri) } diff --git a/internal/util/jwk.go b/internal/util/jwk.go index bd4f1cf7..3644e906 100644 --- a/internal/util/jwk.go +++ b/internal/util/jwk.go @@ -113,6 +113,15 @@ func rsaKeyFromJWK(k *JWK) (*rsa.PublicKey, error) { if !e.IsInt64() { return nil, fmt.Errorf("%w: RSA exponent overflow", ErrInvalidJWKS) } + // E must be an odd integer > 1 per PKCS#1; reject weak/degenerate exponents + // (e.g. 0, 1, even values) that make signature verification trivially unsafe. + eInt := e.Int64() + if eInt < 3 || eInt&1 == 0 { + return nil, fmt.Errorf( + "%w: RSA exponent must be odd and >= 3 (got %d)", + ErrInvalidJWKS, eInt, + ) + } if n.BitLen() < 2048 { return nil, fmt.Errorf( "%w: RSA modulus must be >= 2048 bits (got %d)", @@ -120,7 +129,7 @@ func rsaKeyFromJWK(k *JWK) (*rsa.PublicKey, error) { n.BitLen(), ) } - return &rsa.PublicKey{N: n, E: int(e.Int64())}, nil + return &rsa.PublicKey{N: n, E: int(eInt)}, nil } func ecKeyFromJWK(k *JWK) (*ecdsa.PublicKey, error) { From 397ef7ba8e67c10478ae6000d27344f41df554f0 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Sun, 12 Apr 2026 21:56:52 +0800 Subject: [PATCH 03/13] fix(oauth): address second-round Copilot review - Reject shared-secret authentication for clients whose registered method is not client_secret_*, both in IssueClientCredentialsToken and AuthenticateClient, closing an auth-method downgrade path - Enforce full ClientType to TokenEndpointAuthMethod consistency on CreateClient (none implies public; client_secret_* and private_key_jwt imply confidential) - Prevent UpdateClient from switching into client_secret_basic/post when no secret hash exists, so a method flip cannot leave a client unauthenticatable without an explicit RegenerateSecret call - Disable the MemoryCache reaper goroutines wired in bootstrap for the JWKS and jti caches since they are not plumbed into the shutdown manager; lazy expiration is sufficient for their short TTLs - Update the DCR Swagger annotation so client_secret and related fields are documented as conditional on the auth method Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/bootstrap/handlers.go | 7 ++- internal/handlers/registration.go | 2 +- internal/services/client.go | 49 ++++++++++++++++--- internal/services/token_client_credentials.go | 15 +++++- 4 files changed, 62 insertions(+), 11 deletions(-) diff --git a/internal/bootstrap/handlers.go b/internal/bootstrap/handlers.go index 5d93e200..cb970b99 100644 --- a/internal/bootstrap/handlers.go +++ b/internal/bootstrap/handlers.go @@ -58,8 +58,11 @@ func initializeHandlers(deps handlerDeps) handlerSet { // when multi-instance deployment needs coordinated jti replay protection. var clientAuth *handlers.ClientAuthenticator if deps.cfg.PrivateKeyJWTEnabled { - jwksCache := cache.NewMemoryCache[util.JWKSet]() - jtiCache := cache.NewMemoryCache[bool]() + // Pass 0 to disable the background reaper: these caches aren't wired + // into the shutdown manager, and lazy expiration on Get is sufficient + // given the short TTLs used here (JWKS hour, jti ≤ assertion lifetime). + jwksCache := cache.NewMemoryCache[util.JWKSet](0) + jtiCache := cache.NewMemoryCache[bool](0) jwksFetcher := services.NewJWKSFetcher( jwksCache, deps.cfg.JWKSFetchTimeout, diff --git a/internal/handlers/registration.go b/internal/handlers/registration.go index 4186815e..13b0b280 100644 --- a/internal/handlers/registration.go +++ b/internal/handlers/registration.go @@ -56,7 +56,7 @@ type clientRegistrationRequest struct { // @Accept json // @Produce json // @Param request body clientRegistrationRequest true "Client registration request" -// @Success 201 {object} object{client_id=string,client_secret=string,client_name=string,redirect_uris=[]string,grant_types=[]string,token_endpoint_auth_method=string,scope=string,client_id_issued_at=int,client_secret_expires_at=int} "Client registered successfully" +// @Success 201 {object} object{client_id=string,client_secret=string,client_secret_expires_at=int,jwks_uri=string,token_endpoint_auth_signing_alg=string,client_name=string,redirect_uris=[]string,grant_types=[]string,token_endpoint_auth_method=string,scope=string,client_id_issued_at=int} "Client registered successfully. client_secret and client_secret_expires_at are only present for client_secret_basic/post auth methods; jwks_uri and token_endpoint_auth_signing_alg are present for private_key_jwt." // @Failure 400 {object} object{error=string,error_description=string} "Invalid client metadata" // @Failure 401 {object} object{error=string,error_description=string} "Invalid or missing initial access token" // @Failure 403 {object} object{error=string,error_description=string} "Dynamic registration is disabled" diff --git a/internal/services/client.go b/internal/services/client.go index 3d7d1cfa..c7f53099 100644 --- a/internal/services/client.go +++ b/internal/services/client.go @@ -235,14 +235,27 @@ func (s *ClientService) CreateClient( return nil, err } - // Token endpoint authentication method (RFC 7591 §2). + // Token endpoint authentication method (RFC 7591 §2). Enforce full + // method ↔ client type consistency so downstream code that keys off + // either field cannot disagree on a client's auth contract. authMethod := resolveTokenEndpointAuthMethod(req.TokenEndpointAuthMethod, clientType) if !validTokenEndpointAuthMethod(authMethod) { return nil, ErrInvalidTokenEndpointAuthMethod } - if authMethod == models.TokenEndpointAuthPrivateKeyJWT && - clientType != core.ClientTypeConfidential { - return nil, ErrPrivateKeyJWTRequiresConfidential + switch authMethod { + case models.TokenEndpointAuthNone: + if clientType != core.ClientTypePublic { + return nil, ErrInvalidTokenEndpointAuthMethod + } + case models.TokenEndpointAuthClientSecretBasic, + models.TokenEndpointAuthClientSecretPost: + if clientType != core.ClientTypeConfidential { + return nil, ErrInvalidTokenEndpointAuthMethod + } + case models.TokenEndpointAuthPrivateKeyJWT: + if clientType != core.ClientTypeConfidential { + return nil, ErrPrivateKeyJWTRequiresConfidential + } } // Generate client ID @@ -410,9 +423,31 @@ func (s *ClientService) UpdateClient( if !validTokenEndpointAuthMethod(req.TokenEndpointAuthMethod) { return ErrInvalidTokenEndpointAuthMethod } - if req.TokenEndpointAuthMethod == models.TokenEndpointAuthPrivateKeyJWT && - clientType != core.ClientTypeConfidential { - return ErrPrivateKeyJWTRequiresConfidential + switch req.TokenEndpointAuthMethod { + case models.TokenEndpointAuthNone: + if clientType != core.ClientTypePublic { + return ErrInvalidTokenEndpointAuthMethod + } + case models.TokenEndpointAuthClientSecretBasic, + models.TokenEndpointAuthClientSecretPost: + if clientType != core.ClientTypeConfidential { + return ErrInvalidTokenEndpointAuthMethod + } + // Switching to client_secret_* from a method that never stored + // a secret (private_key_jwt or none) would leave the client + // unauthenticatable. Require an explicit RegenerateSecret call + // to mint one so the operator receives the new plaintext. + if client.ClientSecret == "" { + return fmt.Errorf( + "%w: switching to %s requires generating a new secret first (use RegenerateSecret)", + ErrInvalidClientData, + req.TokenEndpointAuthMethod, + ) + } + case models.TokenEndpointAuthPrivateKeyJWT: + if clientType != core.ClientTypeConfidential { + return ErrPrivateKeyJWTRequiresConfidential + } } client.TokenEndpointAuthMethod = req.TokenEndpointAuthMethod client.TokenEndpointAuthSigningAlg = req.TokenEndpointAuthSigningAlg diff --git a/internal/services/token_client_credentials.go b/internal/services/token_client_credentials.go index b68ec459..18e47003 100644 --- a/internal/services/token_client_credentials.go +++ b/internal/services/token_client_credentials.go @@ -41,7 +41,15 @@ func (s *TokenService) IssueClientCredentialsToken( return nil, ErrClientCredentialsFlowDisabled } - // 4. Authenticate the client via its secret + // 4. Refuse to accept a shared secret for clients whose registered auth + // method is not secret-based (e.g. private_key_jwt). This prevents a + // downgrade attack where a stale or unexpected secret hash on such a + // client could be presented to bypass the assertion requirement. + if !client.UsesClientSecret() { + return nil, ErrInvalidClientCredentials + } + + // 5. Authenticate the client via its secret if !client.ValidateClientSecret([]byte(clientSecret)) { return nil, ErrInvalidClientCredentials } @@ -167,6 +175,11 @@ func (s *TokenService) AuthenticateClient( if !client.IsActive() { return ErrInvalidClientCredentials } + // Guard against auth-method downgrade: a client registered for + // private_key_jwt or none must not be authenticable via a shared secret. + if !client.UsesClientSecret() { + return ErrInvalidClientCredentials + } if !client.ValidateClientSecret([]byte(clientSecret)) { return ErrInvalidClientCredentials } From 254eab29877fe4f2db22824ca008791f73d0c28c Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Sun, 12 Apr 2026 22:05:14 +0800 Subject: [PATCH 04/13] fix(oauth): address third-round Copilot review - Apply the 30s clock-skew default when callers pass the zero value for ClientAssertionConfig.ClockSkew, not only when it is negative - Parse inline JWKS JSON at create/update time so malformed registrations are rejected with invalid_client_data immediately rather than failing opaquely at assertion verification time Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/services/client.go | 27 +++++++++++++++++++++++++++ internal/services/client_assertion.go | 2 +- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/internal/services/client.go b/internal/services/client.go index c7f53099..119d7dfa 100644 --- a/internal/services/client.go +++ b/internal/services/client.go @@ -96,6 +96,27 @@ func resolveTokenEndpointAuthMethod(method string, clientType core.ClientType) s return models.TokenEndpointAuthClientSecretBasic } +// validateInlineJWKS parses the inline JWKS JSON (if present) and verifies +// every key can be converted to a usable public key. Called on create/update +// so malformed registrations are rejected immediately, rather than failing +// opaquely at assertion-verification time. +func validateInlineJWKS(jwks string) error { + jwks = strings.TrimSpace(jwks) + if jwks == "" { + return nil + } + set, err := util.ParseJWKSet(jwks) + if err != nil { + return err + } + for i := range set.Keys { + if _, err := set.Keys[i].ToPublicKey(); err != nil { + return fmt.Errorf("jwks[%d]: %w", i, err) + } + } + return nil +} + // validateRedirectURIs checks that every URI in the slice is an absolute http/https // URI without a fragment, as required by RFC 6749. func validateRedirectURIs(uris []string) error { @@ -309,6 +330,9 @@ func (s *ClientService) CreateClient( if err := client.ValidateKeyMaterial(); err != nil { return nil, fmt.Errorf("%w: %s", ErrInvalidClientData, err.Error()) } + if err := validateInlineJWKS(client.JWKS); err != nil { + return nil, fmt.Errorf("%w: invalid jwks: %s", ErrInvalidClientData, err.Error()) + } // Generate a shared secret only for the two client_secret_* auth methods. // Public (none) and private_key_jwt clients do not have a secret. @@ -462,6 +486,9 @@ func (s *ClientService) UpdateClient( if err := client.ValidateKeyMaterial(); err != nil { return fmt.Errorf("%w: %s", ErrInvalidClientData, err.Error()) } + if err := validateInlineJWKS(client.JWKS); err != nil { + return fmt.Errorf("%w: invalid jwks: %s", ErrInvalidClientData, err.Error()) + } err = s.store.UpdateClient(client) if err != nil { diff --git a/internal/services/client_assertion.go b/internal/services/client_assertion.go index d7c149e9..38e046e1 100644 --- a/internal/services/client_assertion.go +++ b/internal/services/client_assertion.go @@ -91,7 +91,7 @@ func NewClientAssertionVerifier( if cfg.MaxLifetime <= 0 { cfg.MaxLifetime = 5 * time.Minute } - if cfg.ClockSkew < 0 { + if cfg.ClockSkew <= 0 { cfg.ClockSkew = 30 * time.Second } return &ClientAssertionVerifier{ From 799740d18b552e33ef01f6a2aff7b1845b32389d Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Sun, 12 Apr 2026 22:13:47 +0800 Subject: [PATCH 05/13] refactor(oauth): address fourth-round Copilot review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Short-circuit JWKSFetcher GetWithRefresh when the cache is nil so the no-cache path does not perform a second network fetch after a kid miss - Drop the unused audience field from ClientAuthenticator since assertion audience validation lives entirely on ClientAssertionVerifier - Collapse the redundant requireConfidential branches in Authenticate when client_id is empty — both paths returned the same error - Remove the unused ErrInvalidKeyMaterial sentinel: ValidateKeyMaterial never returns it, and keeping it would suggest an errors.Is contract that is not actually implemented Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/bootstrap/handlers.go | 1 - internal/handlers/client_auth.go | 14 ++++---------- internal/handlers/token_private_key_jwt_test.go | 2 +- internal/models/oauth_application.go | 3 --- internal/services/jwks_fetcher.go | 11 +++++++---- 5 files changed, 12 insertions(+), 19 deletions(-) diff --git a/internal/bootstrap/handlers.go b/internal/bootstrap/handlers.go index cb970b99..53e29e51 100644 --- a/internal/bootstrap/handlers.go +++ b/internal/bootstrap/handlers.go @@ -85,7 +85,6 @@ func initializeHandlers(deps handlerDeps) handlerSet { clientAuth = handlers.NewClientAuthenticator( deps.services.client, verifier, - tokenEndpoint, ) } diff --git a/internal/handlers/client_auth.go b/internal/handlers/client_auth.go index 6c2f9f1f..d8136604 100644 --- a/internal/handlers/client_auth.go +++ b/internal/handlers/client_auth.go @@ -38,25 +38,22 @@ type AuthenticatedClient struct { // ClientAuthenticator performs RFC 6749 §2.3 + RFC 7521 §4.2 authentication at // the token endpoint. It supports client_secret_basic, client_secret_post, -// private_key_jwt, and none (public clients). +// private_key_jwt, and none (public clients). Assertion-audience validation +// lives on ClientAssertionVerifier; this type only dispatches. type ClientAuthenticator struct { clientService *services.ClientService assertionVerifier *services.ClientAssertionVerifier - audience string } -// NewClientAuthenticator wires a new authenticator. assertionVerifier may be nil, -// in which case private_key_jwt is rejected. The audience is the token endpoint -// URL presented to clients (typically BaseURL + "/oauth/token"). +// NewClientAuthenticator wires a new authenticator. assertionVerifier may be +// nil, in which case private_key_jwt assertions are rejected. func NewClientAuthenticator( cs *services.ClientService, av *services.ClientAssertionVerifier, - audience string, ) *ClientAuthenticator { return &ClientAuthenticator{ clientService: cs, assertionVerifier: av, - audience: audience, } } @@ -77,9 +74,6 @@ func (a *ClientAuthenticator) Authenticate( clientID, secret, cameFromHeader := parseClientCredentials(c) if clientID == "" { - if requireConfidential { - return nil, ErrClientAuthRequired - } return nil, ErrClientAuthRequired } diff --git a/internal/handlers/token_private_key_jwt_test.go b/internal/handlers/token_private_key_jwt_test.go index ad188e71..48528c1a 100644 --- a/internal/handlers/token_private_key_jwt_test.go +++ b/internal/handlers/token_private_key_jwt_test.go @@ -81,7 +81,7 @@ func setupPKJWTEnv(t *testing.T) (*gin.Engine, *store.Store, string) { ClockSkew: cfg.ClientAssertionClockSkew, }, ) - clientAuth := NewClientAuthenticator(clientSvc, verifier, tokenEndpoint) + clientAuth := NewClientAuthenticator(clientSvc, verifier) handler := NewTokenHandler(tokenSvc, authzSvc, cfg).WithClientAuthenticator(clientAuth) r := gin.New() diff --git a/internal/models/oauth_application.go b/internal/models/oauth_application.go index 8bdbafd5..93edd24d 100644 --- a/internal/models/oauth_application.go +++ b/internal/models/oauth_application.go @@ -143,9 +143,6 @@ func (app *OAuthApplication) UsesClientSecret() bool { app.TokenEndpointAuthMethod == TokenEndpointAuthClientSecretPost } -// ErrInvalidKeyMaterial indicates a client's private_key_jwt configuration is invalid. -var ErrInvalidKeyMaterial = errors.New("invalid key material for private_key_jwt") - // ValidateKeyMaterial verifies that a private_key_jwt client has exactly one of // JWKSURI or JWKS set, and that the signing algorithm is supported. For other // auth methods, it verifies no key material is present. diff --git a/internal/services/jwks_fetcher.go b/internal/services/jwks_fetcher.go index 33390ae4..609521a6 100644 --- a/internal/services/jwks_fetcher.go +++ b/internal/services/jwks_fetcher.go @@ -80,13 +80,16 @@ func (f *JWKSFetcher) GetWithRefresh(ctx context.Context, uri, kid string) (*uti if kid == "" || set.FindByKid(kid) != nil { return set, nil } + // With no cache there is nothing to refresh — getCached already dialed + // the remote for this call, so an additional fetch is pure overhead. + if f.cache == nil { + return set, nil + } if !f.canRefreshNow(uri) { return set, nil } - if f.cache != nil { - if err := f.cache.Delete(ctx, uri); err != nil { - log.Printf("[JWKSFetcher] cache delete failed for %s: %v", uri, err) - } + if err := f.cache.Delete(ctx, uri); err != nil { + log.Printf("[JWKSFetcher] cache delete failed for %s: %v", uri, err) } return f.getCached(ctx, uri) } From 309b998ab7344395e6054b3f88a14a24f4403fcf Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Sun, 12 Apr 2026 22:23:08 +0800 Subject: [PATCH 06/13] fix(oauth): address fifth-round Copilot review - Close the race in JWKSFetcher canRefreshNow: concurrent callers past the cooldown window could all observe the stale timestamp and trigger redundant refreshes. Use a CompareAndSwap loop so only one caller wins the refresh decision per window - Reject client_assertion claims where exp is not strictly after iat, so a negative or zero lifetime cannot slip past the MaxLifetime cap - Surface jti replay cache backend failures via a distinct ErrAssertionJTICacheUnavailable error instead of reporting them as replay attempts in the audit log, so operators can distinguish replay from cache outage during incident response Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/services/client_assertion.go | 17 +++++++++++++---- internal/services/jwks_fetcher.go | 27 +++++++++++++++++---------- 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/internal/services/client_assertion.go b/internal/services/client_assertion.go index 38e046e1..02fb4ad3 100644 --- a/internal/services/client_assertion.go +++ b/internal/services/client_assertion.go @@ -49,6 +49,7 @@ var ( ErrAssertionLifetimeTooLong = errors.New("client_assertion lifetime exceeds server maximum") ErrAssertionMissingJTI = errors.New("client_assertion is missing jti") ErrAssertionJTIReplay = errors.New("client_assertion jti was already used") + ErrAssertionJTICacheUnavailable = errors.New("client_assertion jti replay cache unavailable") ErrAssertionMissingRequiredTime = errors.New("client_assertion is missing required time claims") ) @@ -257,6 +258,11 @@ func (v *ClientAssertionVerifier) validateTimeClaims(claims jwt.MapClaims) error return ErrAssertionNotYetValid } } + // exp must be strictly after iat: a zero or negative lifetime is + // nonsensical and would otherwise pass the MaxLifetime bound below. + if !exp.After(iat) { + return ErrAssertionLifetimeTooLong + } if exp.Sub(iat) > v.cfg.MaxLifetime { return ErrAssertionLifetimeTooLong } @@ -334,9 +340,11 @@ func (v *ClientAssertionVerifier) checkJTIReplay( // expected — the jti has not been seen; fall through to record it. default: // Backend error (e.g. Redis unavailable). Fail closed so we do not - // silently accept replays while the cache is degraded. + // silently accept replays while the cache is degraded. Use a + // distinct error so audit logs don't misreport a cache outage as + // a replay attempt. log.Printf("[ClientAssertion] jti cache lookup failed: %v", err) - return ErrAssertionJTIReplay + return ErrAssertionJTICacheUnavailable } // TTL = remaining assertion lifetime + clock skew. If exp is absent, // fall back to MaxLifetime (defensive). @@ -349,9 +357,10 @@ func (v *ClientAssertionVerifier) checkJTIReplay( } if err := v.jtiCache.Set(ctx, key, true, ttl); err != nil { // Set failure after a miss: reject the assertion rather than - // silently skipping replay tracking. + // silently skipping replay tracking. Use the dedicated cache- + // unavailable error so audit logs are accurate. log.Printf("[ClientAssertion] failed to record jti %s: %v", jti, err) - return ErrAssertionJTIReplay + return ErrAssertionJTICacheUnavailable } return nil } diff --git a/internal/services/jwks_fetcher.go b/internal/services/jwks_fetcher.go index 609521a6..93e6b8f2 100644 --- a/internal/services/jwks_fetcher.go +++ b/internal/services/jwks_fetcher.go @@ -96,18 +96,25 @@ func (f *JWKSFetcher) GetWithRefresh(ctx context.Context, uri, kid string) (*uti // canRefreshNow returns true if uri has not been force-refreshed within // jwksRefreshCooldown, and marks it as refreshed. Concurrent callers for the -// same uri collapse into a single refresh per cooldown window. +// same uri collapse into a single refresh per cooldown window via CAS. func (f *JWKSFetcher) canRefreshNow(uri string) bool { - now := time.Now() - prev, loaded := f.lastRefresh.LoadOrStore(uri, now) - if !loaded { - return true - } - if now.Sub(prev.(time.Time)) < jwksRefreshCooldown { - return false + for { + now := time.Now() + prev, loaded := f.lastRefresh.LoadOrStore(uri, now) + if !loaded { + return true + } + prevTime := prev.(time.Time) + if now.Sub(prevTime) < jwksRefreshCooldown { + return false + } + // CompareAndSwap guarantees only one concurrent caller wins the + // post-cooldown refresh decision; losers retry the loop and fall + // back into the cooldown branch. + if f.lastRefresh.CompareAndSwap(uri, prevTime, now) { + return true + } } - f.lastRefresh.Store(uri, now) - return true } func (f *JWKSFetcher) getCached(ctx context.Context, uri string) (*util.JWKSet, error) { From d053d53094fb3ef885f0aab3bf3540341eee9e88 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Sun, 12 Apr 2026 22:28:27 +0800 Subject: [PATCH 07/13] fix(oauth): use dedicated sentinel for missing jti cache Return ErrAssertionJTICacheUnavailable when jtiCache is nil instead of allocating an ad-hoc errors.New, so callers and tests can detect the misconfiguration via errors.Is consistently with the backend-error path. Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/services/client_assertion.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/internal/services/client_assertion.go b/internal/services/client_assertion.go index 02fb4ad3..9cf8720f 100644 --- a/internal/services/client_assertion.go +++ b/internal/services/client_assertion.go @@ -318,9 +318,10 @@ func (v *ClientAssertionVerifier) checkJTIReplay( return ErrAssertionMissingJTI } if v.jtiCache == nil { - // Without a cache we cannot prevent replay. Treat this as a - // hard-configured failure rather than silently allowing replays. - return errors.New("jti replay cache not configured") + // Without a cache we cannot prevent replay. Fail closed using the + // dedicated sentinel so callers/tests can distinguish this from an + // actual replay via errors.Is. + return ErrAssertionJTICacheUnavailable } key := jtiCacheKeyPrefix + clientID + ":" + jti From 11445e9577fd3985ff0c79f959a336036660ef5f Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Sun, 12 Apr 2026 22:38:11 +0800 Subject: [PATCH 08/13] fix(oauth): address sixth-round Copilot review - Return the dedicated ErrAssertionInvalidTimeWindow (not ErrAssertionLifetimeTooLong) when exp is not strictly after iat so audit logs and tests distinguish a nonsensical time window from an assertion that simply sets too long a lifetime - Rework JWKSFetcher kid-miss refresh so a failing cache.Delete cannot strand key rotation: fetch fresh directly via HTTP, then best-effort update the cache. Callers now always see the newly rotated keys even if the cache backend is degraded - Expand the /oauth/introspect Swagger annotation to document that RFC 7523 client_assertion and client_assertion_type are accepted alongside Basic/form client_secret authentication Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/handlers/token.go | 12 +++++++----- internal/services/client_assertion.go | 3 ++- internal/services/jwks_fetcher.go | 14 +++++++++++--- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/internal/handlers/token.go b/internal/handlers/token.go index 856f1fbd..ee8f3267 100644 --- a/internal/handlers/token.go +++ b/internal/handlers/token.go @@ -286,14 +286,16 @@ func (h *TokenHandler) TokenInfo(c *gin.Context) { // Introspect godoc // // @Summary Introspect token (RFC 7662) -// @Description Determine the active state and metadata of an OAuth 2.0 token. Requires client authentication via HTTP Basic Auth or form-body client credentials. +// @Description Determine the active state and metadata of an OAuth 2.0 token. Requires client authentication via HTTP Basic Auth, form-body client_id/client_secret, or RFC 7523 private_key_jwt (client_assertion + client_assertion_type). // @Tags OAuth // @Accept x-www-form-urlencoded // @Produce json -// @Param token formData string true "The token to introspect" -// @Param token_type_hint formData string false "Hint about the type of token: 'access_token' or 'refresh_token'" -// @Param client_id formData string false "Client ID (alternative to HTTP Basic Auth)" -// @Param client_secret formData string false "Client secret (alternative to HTTP Basic Auth)" +// @Param token formData string true "The token to introspect" +// @Param token_type_hint formData string false "Hint about the type of token: 'access_token' or 'refresh_token'" +// @Param client_id formData string false "Client ID (alternative to HTTP Basic Auth)" +// @Param client_secret formData string false "Client secret (alternative to HTTP Basic Auth)" +// @Param client_assertion formData string false "Signed JWT assertion (RFC 7523 private_key_jwt); use with client_assertion_type" +// @Param client_assertion_type formData string false "Must be urn:ietf:params:oauth:client-assertion-type:jwt-bearer when client_assertion is present" // @Success 200 {object} object{active=bool,scope=string,client_id=string,username=string,token_type=string,exp=int,iat=int,sub=string,iss=string,jti=string} "Token introspection response" // @Failure 401 {object} object{error=string,error_description=string} "Client authentication failed" // @Router /oauth/introspect [post] diff --git a/internal/services/client_assertion.go b/internal/services/client_assertion.go index 9cf8720f..bf61656f 100644 --- a/internal/services/client_assertion.go +++ b/internal/services/client_assertion.go @@ -47,6 +47,7 @@ var ( ErrAssertionExpired = errors.New("client_assertion is expired") ErrAssertionNotYetValid = errors.New("client_assertion is not yet valid") ErrAssertionLifetimeTooLong = errors.New("client_assertion lifetime exceeds server maximum") + ErrAssertionInvalidTimeWindow = errors.New("client_assertion exp must be strictly after iat") ErrAssertionMissingJTI = errors.New("client_assertion is missing jti") ErrAssertionJTIReplay = errors.New("client_assertion jti was already used") ErrAssertionJTICacheUnavailable = errors.New("client_assertion jti replay cache unavailable") @@ -261,7 +262,7 @@ func (v *ClientAssertionVerifier) validateTimeClaims(claims jwt.MapClaims) error // exp must be strictly after iat: a zero or negative lifetime is // nonsensical and would otherwise pass the MaxLifetime bound below. if !exp.After(iat) { - return ErrAssertionLifetimeTooLong + return ErrAssertionInvalidTimeWindow } if exp.Sub(iat) > v.cfg.MaxLifetime { return ErrAssertionLifetimeTooLong diff --git a/internal/services/jwks_fetcher.go b/internal/services/jwks_fetcher.go index 93e6b8f2..985a976a 100644 --- a/internal/services/jwks_fetcher.go +++ b/internal/services/jwks_fetcher.go @@ -88,10 +88,18 @@ func (f *JWKSFetcher) GetWithRefresh(ctx context.Context, uri, kid string) (*uti if !f.canRefreshNow(uri) { return set, nil } - if err := f.cache.Delete(ctx, uri); err != nil { - log.Printf("[JWKSFetcher] cache delete failed for %s: %v", uri, err) + // Bypass the cache: fetch fresh directly and best-effort update the + // cache afterward. This keeps key rotation working even if the cache + // backend is degraded and Delete/Set fails — callers still see the + // freshly rotated keys. + fresh, err := f.fetch(ctx, uri) + if err != nil { + return nil, err } - return f.getCached(ctx, uri) + if setErr := f.cache.Set(ctx, uri, fresh, f.ttl); setErr != nil { + log.Printf("[JWKSFetcher] cache set after refresh failed for %s: %v", uri, setErr) + } + return &fresh, nil } // canRefreshNow returns true if uri has not been force-refreshed within From a05f4d5df8c59c793c519cc5c50580ce6c45c288 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Sun, 12 Apr 2026 22:48:21 +0800 Subject: [PATCH 09/13] fix(oauth): enforce secret auth method + drop misleading Basic challenge - Reject requests that use Basic Auth against a client registered as client_secret_post, and form-body credentials against one registered as client_secret_basic. Previously the server accepted either mode for any client_secret_* client, which made the registered method merely informational rather than a contract - Suppress WWW-Authenticate Basic in 401 responses when the caller authenticated (or attempted to) via private_key_jwt, since Basic is not an applicable method for those clients. A new shouldChallengeBasic helper gates the header across the token endpoint, introspect, and the client_credentials error path Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/handlers/client_auth.go | 14 +++++++++++--- internal/handlers/token.go | 24 +++++++++++++++++++++--- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/internal/handlers/client_auth.go b/internal/handlers/client_auth.go index d8136604..b66a5bc9 100644 --- a/internal/handlers/client_auth.go +++ b/internal/handlers/client_auth.go @@ -95,13 +95,21 @@ func (a *ClientAuthenticator) Authenticate( if secret == "" || !client.ValidateClientSecret([]byte(secret)) { return nil, ErrClientAuthSecretBad } - method := models.TokenEndpointAuthClientSecretBasic + observed := models.TokenEndpointAuthClientSecretBasic if !cameFromHeader { - method = models.TokenEndpointAuthClientSecretPost + observed = models.TokenEndpointAuthClientSecretPost + } + // Enforce the registered method contract: a client registered as + // client_secret_basic must use Basic Auth, and client_secret_post + // must use form-body credentials. Accepting either mode would let a + // stolen secret be presented through whichever channel happens to + // be easier for the attacker. + if client.TokenEndpointAuthMethod != observed { + return nil, ErrClientAuthMethodUnmet } return &AuthenticatedClient{ Client: client, - Method: method, + Method: observed, }, nil } diff --git a/internal/handlers/token.go b/internal/handlers/token.go index ee8f3267..472a9867 100644 --- a/internal/handlers/token.go +++ b/internal/handlers/token.go @@ -307,7 +307,12 @@ func (h *TokenHandler) Introspect(c *gin.Context) { if h.clientAuthenticator != nil { authed, err := h.clientAuthenticator.Authenticate(c, true) if err != nil { - c.Header("WWW-Authenticate", `Basic realm="authgate"`) + // Only advertise a Basic challenge when Basic is actually an + // applicable method for this caller — otherwise a private_key_jwt + // client would see a misleading WWW-Authenticate header. + if shouldChallengeBasic(c) { + c.Header("WWW-Authenticate", `Basic realm="authgate"`) + } respondOAuthError( c, http.StatusUnauthorized, @@ -447,7 +452,9 @@ func (h *TokenHandler) handleClientCredentialsGrant(c *gin.Context) { if h.clientAuthenticator != nil { authed, err := h.clientAuthenticator.Authenticate(c, true) if err != nil { - c.Header("WWW-Authenticate", `Basic realm="authgate"`) + if shouldChallengeBasic(c) { + c.Header("WWW-Authenticate", `Basic realm="authgate"`) + } respondOAuthError( c, http.StatusUnauthorized, @@ -496,6 +503,15 @@ func (h *TokenHandler) handleClientCredentialsGrant(c *gin.Context) { c.JSON(http.StatusOK, buildTokenResponse(accessToken, nil, "")) } +// shouldChallengeBasic reports whether an HTTP Basic Auth challenge is +// meaningful in a 401 response to the given request. We skip the challenge +// when the caller explicitly used private_key_jwt (client_assertion is +// present) because Basic is not an applicable method for that client. +func shouldChallengeBasic(c *gin.Context) bool { + return c.PostForm(formClientAssertion) == "" && + c.PostForm(formClientAssertionType) == "" +} + // writeClientCredentialsError maps service-layer errors from the client_credentials // flow to RFC 6749 error responses. Shared between the classic and shared-authenticator // code paths. @@ -503,7 +519,9 @@ func (h *TokenHandler) writeClientCredentialsError(c *gin.Context, err error) { switch { case errors.Is(err, services.ErrInvalidClientCredentials), errors.Is(err, services.ErrClientNotConfidential): - c.Header("WWW-Authenticate", `Basic realm="authgate"`) + if shouldChallengeBasic(c) { + c.Header("WWW-Authenticate", `Basic realm="authgate"`) + } respondOAuthError( c, http.StatusUnauthorized, From 552ac15641cb2ed91989b34fe89a91ecaf1023a9 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Sun, 12 Apr 2026 22:56:44 +0800 Subject: [PATCH 10/13] fix(oauth): roll back JWKS refresh reservation on failure + reject overflowing RSA exponent - Split JWKSFetcher canRefreshNow into tryReserveRefresh + rollback: the CAS reservation is now undone when the subsequent network fetch fails, so a transient error no longer strands key rotation until the cooldown elapses - Reject RSA public exponents that exceed the platform int range when constructing the rsa.PublicKey, so a large-but-int64-sized exponent cannot silently overflow into a corrupted key on 32-bit architectures Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/services/jwks_fetcher.go | 39 +++++++++++++++++++++++-------- internal/util/jwk.go | 10 ++++++++ 2 files changed, 39 insertions(+), 10 deletions(-) diff --git a/internal/services/jwks_fetcher.go b/internal/services/jwks_fetcher.go index 985a976a..7ab03f46 100644 --- a/internal/services/jwks_fetcher.go +++ b/internal/services/jwks_fetcher.go @@ -85,7 +85,8 @@ func (f *JWKSFetcher) GetWithRefresh(ctx context.Context, uri, kid string) (*uti if f.cache == nil { return set, nil } - if !f.canRefreshNow(uri) { + prev, stored, ok := f.tryReserveRefresh(uri) + if !ok { return set, nil } // Bypass the cache: fetch fresh directly and best-effort update the @@ -94,6 +95,9 @@ func (f *JWKSFetcher) GetWithRefresh(ctx context.Context, uri, kid string) (*uti // freshly rotated keys. fresh, err := f.fetch(ctx, uri) if err != nil { + // Roll back the reservation so the next caller can retry, rather + // than being suppressed by the cooldown after a transient error. + f.rollbackRefreshReservation(uri, prev, stored) return nil, err } if setErr := f.cache.Set(ctx, uri, fresh, f.ttl); setErr != nil { @@ -103,28 +107,43 @@ func (f *JWKSFetcher) GetWithRefresh(ctx context.Context, uri, kid string) (*uti } // canRefreshNow returns true if uri has not been force-refreshed within -// jwksRefreshCooldown, and marks it as refreshed. Concurrent callers for the -// same uri collapse into a single refresh per cooldown window via CAS. -func (f *JWKSFetcher) canRefreshNow(uri string) bool { +// tryReserveRefresh attempts to claim the right to refresh uri via CAS. +// On success it returns the timestamp it stored (so the caller can roll back +// on fetch failure) and whether a previous reservation existed — both needed +// to undo the reservation cleanly if the subsequent network fetch fails. +func (f *JWKSFetcher) tryReserveRefresh(uri string) (prev, stored time.Time, ok bool) { for { now := time.Now() - prev, loaded := f.lastRefresh.LoadOrStore(uri, now) + existing, loaded := f.lastRefresh.LoadOrStore(uri, now) if !loaded { - return true + // First-ever reservation for this uri. + return time.Time{}, now, true } - prevTime := prev.(time.Time) + prevTime := existing.(time.Time) if now.Sub(prevTime) < jwksRefreshCooldown { - return false + return time.Time{}, time.Time{}, false } - // CompareAndSwap guarantees only one concurrent caller wins the + // CompareAndSwap ensures only one concurrent caller wins the // post-cooldown refresh decision; losers retry the loop and fall // back into the cooldown branch. if f.lastRefresh.CompareAndSwap(uri, prevTime, now) { - return true + return prevTime, now, true } } } +// rollbackRefreshReservation reverts a reservation made by tryReserveRefresh +// after the refresh fetch failed. Without this, a transient fetch error would +// strand key rotation until the cooldown elapses. +func (f *JWKSFetcher) rollbackRefreshReservation(uri string, prev, stored time.Time) { + if prev.IsZero() { + // Best-effort: if someone else updated since, leave their value alone. + f.lastRefresh.CompareAndDelete(uri, stored) + return + } + f.lastRefresh.CompareAndSwap(uri, stored, prev) +} + func (f *JWKSFetcher) getCached(ctx context.Context, uri string) (*util.JWKSet, error) { if f.cache == nil { set, err := f.fetch(ctx, uri) diff --git a/internal/util/jwk.go b/internal/util/jwk.go index 3644e906..aad588d9 100644 --- a/internal/util/jwk.go +++ b/internal/util/jwk.go @@ -10,6 +10,7 @@ import ( "encoding/json" "errors" "fmt" + "math" "math/big" ) @@ -122,6 +123,15 @@ func rsaKeyFromJWK(k *JWK) (*rsa.PublicKey, error) { ErrInvalidJWKS, eInt, ) } + // rsa.PublicKey.E is an int, which is 32 bits on some platforms. Reject + // exponents that would truncate when cast rather than silently producing + // a corrupted key. + if eInt > math.MaxInt { + return nil, fmt.Errorf( + "%w: RSA exponent %d exceeds platform int range", + ErrInvalidJWKS, eInt, + ) + } if n.BitLen() < 2048 { return nil, fmt.Errorf( "%w: RSA modulus must be >= 2048 bits (got %d)", From 78c8c27667932fb1e58266593d53bb792d7017a7 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Sun, 12 Apr 2026 23:04:51 +0800 Subject: [PATCH 11/13] refactor(oauth): fix stale docstring and surface key-gen errors in tests - Rewrite the docstring above JWKSFetcher tryReserveRefresh which still referenced the old canRefreshNow name and behaviour after the CAS refactor - Use require.NoError when generating RSA keys in JWK tests instead of dropping the error: a genuine generation failure now surfaces as a clear Fatalf rather than a later nil-pointer panic Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/services/jwks_fetcher.go | 11 ++++++----- internal/util/jwk_test.go | 10 ++++++++-- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/internal/services/jwks_fetcher.go b/internal/services/jwks_fetcher.go index 7ab03f46..094a43b8 100644 --- a/internal/services/jwks_fetcher.go +++ b/internal/services/jwks_fetcher.go @@ -106,11 +106,12 @@ func (f *JWKSFetcher) GetWithRefresh(ctx context.Context, uri, kid string) (*uti return &fresh, nil } -// canRefreshNow returns true if uri has not been force-refreshed within -// tryReserveRefresh attempts to claim the right to refresh uri via CAS. -// On success it returns the timestamp it stored (so the caller can roll back -// on fetch failure) and whether a previous reservation existed — both needed -// to undo the reservation cleanly if the subsequent network fetch fails. +// tryReserveRefresh attempts to reserve a refresh for uri, subject to the +// per-URI cooldown. LoadOrStore + CompareAndSwap together ensure that at most +// one concurrent caller wins the right to refresh once the cooldown has +// elapsed. On success it returns the previous timestamp (zero when the uri +// was never refreshed) and the timestamp it stored, so rollbackRefreshReservation +// can undo the reservation if the subsequent network fetch fails. func (f *JWKSFetcher) tryReserveRefresh(uri string) (prev, stored time.Time, ok bool) { for { now := time.Now() diff --git a/internal/util/jwk_test.go b/internal/util/jwk_test.go index 0a1b3c0f..eae6a109 100644 --- a/internal/util/jwk_test.go +++ b/internal/util/jwk_test.go @@ -124,7 +124,10 @@ func TestParseJWKSet_InvalidJSON(t *testing.T) { } func TestFindByKid_SingleKeyNoKid(t *testing.T) { - priv, _ := rsa.GenerateKey(rand.Reader, 2048) + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } set := &JWKSet{Keys: []JWK{jwkFromRSA(t, &priv.PublicKey, "")}} if got := set.FindByKid(""); got == nil { t.Fatal("expected single key match when kid empty") @@ -132,7 +135,10 @@ func TestFindByKid_SingleKeyNoKid(t *testing.T) { } func TestFindByKid_MultipleKeysEmptyKid(t *testing.T) { - priv, _ := rsa.GenerateKey(rand.Reader, 2048) + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } set := &JWKSet{Keys: []JWK{ jwkFromRSA(t, &priv.PublicKey, "a"), jwkFromRSA(t, &priv.PublicKey, "b"), From f549a477d9f965cacbcba04e8563e5c0bcb14cc9 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Sun, 12 Apr 2026 23:13:56 +0800 Subject: [PATCH 12/13] fix(oauth): reject private_key_jwt with incompatible grants Prevent the creation or update of a private_key_jwt client that also enables authorization_code or device_code. Those grant handlers still authenticate confidential clients via client_secret, so the resulting client would register but fail at token exchange. CreateClient and UpdateClient now return an invalid_client_data error in this case, pointing the operator at the currently supported scope (client_credentials + introspection). Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/services/client.go | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/internal/services/client.go b/internal/services/client.go index 119d7dfa..91a7bb80 100644 --- a/internal/services/client.go +++ b/internal/services/client.go @@ -277,6 +277,17 @@ func (s *ClientService) CreateClient( if clientType != core.ClientTypeConfidential { return nil, ErrPrivateKeyJWTRequiresConfidential } + // Only client_credentials and introspection currently authenticate + // via the shared ClientAuthenticator. Enabling other grants on a + // private_key_jwt client would produce a client that can register + // but cannot actually exchange codes or refresh tokens, because + // those paths still expect a shared secret. + if req.EnableAuthCodeFlow || req.EnableDeviceFlow { + return nil, fmt.Errorf( + "%w: private_key_jwt is currently supported only for the client_credentials grant; disable authorization_code and device_code flows", + ErrInvalidClientData, + ) + } } // Generate client ID @@ -472,6 +483,15 @@ func (s *ClientService) UpdateClient( if clientType != core.ClientTypeConfidential { return ErrPrivateKeyJWTRequiresConfidential } + // See the matching guard in CreateClient — authorization_code and + // device_code still expect a shared secret, so enabling them on a + // private_key_jwt client produces an unusable configuration. + if req.EnableAuthCodeFlow || req.EnableDeviceFlow { + return fmt.Errorf( + "%w: private_key_jwt is currently supported only for the client_credentials grant; disable authorization_code and device_code flows", + ErrInvalidClientData, + ) + } } client.TokenEndpointAuthMethod = req.TokenEndpointAuthMethod client.TokenEndpointAuthSigningAlg = req.TokenEndpointAuthSigningAlg From 57d86d46df3b64c54feb3312539b1319d7ff72a9 Mon Sep 17 00:00:00 2001 From: Bo-Yi Wu Date: Sun, 12 Apr 2026 23:22:33 +0800 Subject: [PATCH 13/13] fix(oauth): evict stale JWKS on Set failure + advertise jwk-set Accept MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Best-effort delete the cached JWKS when cache.Set fails after a kid-miss refresh. Without the eviction, the freshly-updated lastRefresh timestamp would pin subsequent callers to the stale entry until the cooldown expires - Include application/jwk-set+json in the fetcher's Accept header per RFC 7517 §8.5, so remote JWKS endpoints that honour Accept strictly do not return 406 Not Acceptable for jwks_uri fetches Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/services/jwks_fetcher.go | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/internal/services/jwks_fetcher.go b/internal/services/jwks_fetcher.go index 094a43b8..091dd9a2 100644 --- a/internal/services/jwks_fetcher.go +++ b/internal/services/jwks_fetcher.go @@ -102,6 +102,15 @@ func (f *JWKSFetcher) GetWithRefresh(ctx context.Context, uri, kid string) (*uti } if setErr := f.cache.Set(ctx, uri, fresh, f.ttl); setErr != nil { log.Printf("[JWKSFetcher] cache set after refresh failed for %s: %v", uri, setErr) + // Best-effort delete: without it the stale cached JWKS would + // keep being returned until the cooldown expires, since + // subsequent callers skip refresh while lastRefresh is fresh. + if delErr := f.cache.Delete(ctx, uri); delErr != nil { + log.Printf( + "[JWKSFetcher] cache delete after failed refresh set failed for %s: %v", + uri, delErr, + ) + } } return &fresh, nil } @@ -168,7 +177,9 @@ func (f *JWKSFetcher) fetch(ctx context.Context, uri string) (util.JWKSet, error if err != nil { return util.JWKSet{}, fmt.Errorf("%w: %v", ErrJWKSFetchFailed, err) } - req.Header.Set("Accept", "application/json") + // RFC 7517 §8.5 defines application/jwk-set+json; many JWKS endpoints + // honour Accept strictly and would 406 on plain application/json alone. + req.Header.Set("Accept", "application/jwk-set+json, application/json") resp, err := f.httpClient.Do(req) if err != nil { return util.JWKSet{}, fmt.Errorf("%w: %v", ErrJWKSFetchFailed, err)