diff --git a/.env.example b/.env.example index b386fd6..ff6a91a 100644 --- a/.env.example +++ b/.env.example @@ -285,3 +285,15 @@ EXPIRED_TOKEN_CLEANUP_INTERVAL=1h # How often to run the cleanup (default: # CORS_ALLOWED_METHODS=GET,POST,PUT,DELETE,OPTIONS # CORS_ALLOWED_HEADERS=Origin,Content-Type,Authorization # CORS_MAX_AGE=12h + +# ============================================================ +# Private Key JWT (RFC 7523) — Client Assertion Authentication +# ============================================================ +# When enabled, OAuth clients can authenticate to /oauth/token and +# /oauth/introspect by presenting a signed JWT assertion instead of +# client_secret. See docs/PRIVATE_KEY_JWT.md. +# PRIVATE_KEY_JWT_ENABLED=true +# JWKS_FETCH_TIMEOUT=10s # HTTP timeout when fetching a client's jwks_uri +# JWKS_CACHE_TTL=1h # Cached JWKS lifetime before refetch +# CLIENT_ASSERTION_MAX_LIFETIME=5m # Reject assertions whose exp-iat exceeds this +# CLIENT_ASSERTION_CLOCK_SKEW=30s # Tolerance for exp/nbf/iat skew diff --git a/docs/PRIVATE_KEY_JWT.md b/docs/PRIVATE_KEY_JWT.md new file mode 100644 index 0000000..a041c7a --- /dev/null +++ b/docs/PRIVATE_KEY_JWT.md @@ -0,0 +1,247 @@ +# Private Key JWT Client Authentication (RFC 7523) + +AuthGate supports **`private_key_jwt`** at the token endpoint — clients authenticate by signing a short-lived JWT assertion with their private key instead of presenting a long-lived `client_secret`. This is the authentication method [recommended by the MCP OAuth Client Credentials extension](https://modelcontextprotocol.io/extensions/auth/oauth-client-credentials) for machine-to-machine flows. + +- **RFC 7521** — Assertion Framework for OAuth 2.0 Client Authentication +- **RFC 7523** — JWT Profile for OAuth 2.0 Client Authentication +- **RFC 7591 §2.1** — Dynamic client registration of `jwks`/`jwks_uri` + +## Why use it + +| Aspect | `client_secret_basic`/`_post` | `private_key_jwt` | +| ------------------------- | --------------------------------------------------------------------- | -------------------------------------------------------- | +| Credential lifetime | Long-lived, stored on server | Short-lived (minutes), regenerated per call | +| Transmitted over the wire | Raw secret (even under TLS, exposed to intermediaries) | JWT signed with private key; private key stays on client | +| Server storage | bcrypt hash of secret (attacker who exfiltrates DB can offline-crack) | Public key only (no useful secret to steal) | +| Rotation | Requires re-issuing secret | Rotate JWKS; no downtime | +| Replay protection | None at protocol level | Built-in via `jti` + short `exp` | + +See [RFC 7523 §1](https://datatracker.ietf.org/doc/html/rfc7523#section-1) for more background. + +## Enabling the feature + +`private_key_jwt` is enabled by default. Set the following environment variables to configure it: + +```bash +PRIVATE_KEY_JWT_ENABLED=true # Feature flag (default: true) +JWKS_FETCH_TIMEOUT=10s # HTTP timeout when fetching jwks_uri (default: 10s) +JWKS_CACHE_TTL=1h # JWKS cache lifetime (default: 1h) +CLIENT_ASSERTION_MAX_LIFETIME=5m # Reject assertions whose exp-iat exceeds this (default: 5m) +CLIENT_ASSERTION_CLOCK_SKEW=30s # Tolerance for exp/nbf/iat skew (default: 30s) +``` + +When enabled, the OIDC discovery document lists the new method: + +```bash +curl https://authgate.example.com/.well-known/openid-configuration | jq .token_endpoint_auth_methods_supported +# ["client_secret_basic","client_secret_post","none","private_key_jwt"] + +curl https://authgate.example.com/.well-known/openid-configuration | jq .token_endpoint_auth_signing_alg_values_supported +# ["RS256","ES256"] +``` + +## Supported algorithms + +- **RS256** — 2048-bit (or larger) RSA with SHA-256 +- **ES256** — ECDSA P-256 with SHA-256 + +`HS256` and other symmetric algorithms are rejected by design (they provide no advantage over `client_secret_*`). `EdDSA` is not currently supported; file an issue if you need it. + +## Registering a client + +### Option 1 — Dynamic Client Registration (RFC 7591) + +Enable DCR (`ENABLE_DYNAMIC_CLIENT_REGISTRATION=true`), then POST a client metadata document: + +```bash +curl -X POST https://authgate.example.com/oauth/register \ + -H "Content-Type: application/json" \ + -d '{ + "client_name": "my-service", + "grant_types": ["client_credentials"], + "token_endpoint_auth_method": "private_key_jwt", + "token_endpoint_auth_signing_alg": "RS256", + "scope": "email profile", + "jwks": { + "keys": [ + { + "kty": "RSA", + "use": "sig", + "kid": "2026-04-12", + "alg": "RS256", + "n": "0vx7agoebGcQ...", + "e": "AQAB" + } + ] + } + }' +``` + +Alternative: provide `jwks_uri` instead of an inline `jwks`: + +```json +{ + "client_name": "my-service", + "grant_types": ["client_credentials"], + "token_endpoint_auth_method": "private_key_jwt", + "token_endpoint_auth_signing_alg": "ES256", + "jwks_uri": "https://my-service.example.com/.well-known/jwks.json" +} +``` + +`jwks_uri` and `jwks` are **mutually exclusive**; exactly one must be present. + +Registered clients start in `pending` status and require admin approval (standard DCR behaviour). The response does **not** include a `client_secret` — `private_key_jwt` clients have no shared secret. + +### Option 2 — Service-layer API + +Callers with direct access to the service layer can use `services.CreateClientRequest` with the new fields: + +```go +req := services.CreateClientRequest{ + ClientName: "my-service", + ClientType: core.ClientTypeConfidential, + EnableClientCredentialsFlow: true, + TokenEndpointAuthMethod: models.TokenEndpointAuthPrivateKeyJWT, + TokenEndpointAuthSigningAlg: "RS256", + JWKS: jwkSetJSON, // or JWKSURI: "https://..." + IsAdminCreated: true, // active immediately +} +resp, err := clientService.CreateClient(ctx, req) +``` + +### Admin UI + +The admin web form does not yet expose the new fields. Admins can register `private_key_jwt` clients via the DCR endpoint above or via a small helper script that calls the service layer directly. This will be added in a follow-up. + +## Requesting a token + +The client signs a JWT and presents it as `client_assertion` at `/oauth/token`: + +```http +POST /oauth/token HTTP/1.1 +Host: authgate.example.com +Content-Type: application/x-www-form-urlencoded + +grant_type=client_credentials +&scope=read write +&client_assertion_type=urn:ietf:params:oauth:client-assertion-type:jwt-bearer +&client_assertion= +``` + +### Required JWT claims + +| Claim | Value | +| ----- | ----------------------------------------------------------------------------- | +| `iss` | Your client_id | +| `sub` | Your client_id (must match `iss`) | +| `aud` | AuthGate's token endpoint URL, or the issuer URL | +| `iat` | Issued-at timestamp (now) | +| `exp` | Expiration (max `iat + CLIENT_ASSERTION_MAX_LIFETIME`; recommend 1–5 minutes) | +| `jti` | Unique per-assertion identifier (required — replay protection) | + +### JWT header + +`alg` must match `token_endpoint_auth_signing_alg` that was registered for the client. Include `kid` so AuthGate can pick the right key from your JWK Set: + +```json +{ + "alg": "RS256", + "kid": "2026-04-12", + "typ": "JWT" +} +``` + +## Client examples + +### Python (official MCP SDK) + +```python +from mcp.client.auth.extensions.client_credentials import ( + PrivateKeyJWTOAuthProvider, + SignedJWTParameters, +) +from mcp.client.streamable_http import streamablehttp_client +from mcp import ClientSession + +jwt_params = SignedJWTParameters( + issuer="my-service", # client_id + subject="my-service", # must match issuer + signing_key=open("private_key.pem").read(), + signing_algorithm="RS256", + lifetime_seconds=300, +) + +provider = PrivateKeyJWTOAuthProvider( + server_url="https://authgate.example.com/mcp", + client_id="my-service", + assertion_provider=jwt_params.create_assertion_provider(), + scopes="read write", +) +``` + +The SDK obtains the token endpoint URL from AuthGate's `/.well-known/openid-configuration` and handles assertion signing + token refresh automatically. + +### curl (debugging) + +Generate a key pair and sign a JWT manually (using `python -c` or `jose-util`), then: + +```bash +ASSERTION=$(python3 - <<'PY' +import jwt, time, uuid +priv = open('private_key.pem').read() +claims = { + 'iss': 'my-service', + 'sub': 'my-service', + 'aud': 'https://authgate.example.com/oauth/token', + 'iat': int(time.time()), + 'exp': int(time.time()) + 300, + 'jti': str(uuid.uuid4()), +} +print(jwt.encode(claims, priv, algorithm='RS256', headers={'kid': '2026-04-12'})) +PY +) + +curl -X POST https://authgate.example.com/oauth/token \ + -d grant_type=client_credentials \ + -d scope="read write" \ + -d client_assertion_type=urn:ietf:params:oauth:client-assertion-type:jwt-bearer \ + -d client_assertion="$ASSERTION" +``` + +## Key rotation + +- **`jwks_uri`**: publish both old and new keys during overlap; AuthGate re-fetches when it encounters a `kid` it doesn't know (bypassing `JWKS_CACHE_TTL`). This supports zero-downtime rotation. +- **Inline `jwks`**: update the client via DCR or admin API with the new JWK Set. Old clients already issued assertions against the old key continue to verify until reassertions start arriving with the new `kid`. + +Best practice: keep both keys published for at least twice the longest `exp` you expect (e.g. 10 minutes for 5-minute assertions) before retiring the old one. + +## Which endpoints accept `private_key_jwt` + +Currently supported: + +- `POST /oauth/token` for **`grant_type=client_credentials`** — primary use case (MCP M2M). +- `POST /oauth/introspect` — so Resource Servers can authenticate via the same JWT. + +The other grants (`authorization_code`, `refresh_token`, `device_code`) continue to use `client_secret_*` or public-client (no auth) modes. Extending them is tracked as a follow-up — the shared authenticator (`internal/handlers/client_auth.go`) is already in place, only the per-grant wiring is pending. + +## Troubleshooting + +| Symptom | Likely cause | +| ------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------- | +| `invalid_client` immediately | `iss` / `sub` / `aud` mismatch, or `exp` missing/past | +| `invalid_client` after a few seconds | Clock skew — check `CLIENT_ASSERTION_CLOCK_SKEW` | +| `invalid_client` on re-use | `jti` replay — each assertion must be unique | +| `invalid_client` with correct claims | `kid` missing or not in registered JWKS; check server logs for `no matching JWK for kid` | +| `invalid_client` after key rotation | `JWKS_CACHE_TTL` not yet expired and `kid` mismatch didn't trigger refresh — verify new key ships with a fresh `kid` | +| `unauthorized_client` instead of `invalid_client` | Client registered but `client_credentials` flow not enabled | + +Check `/admin/audit` with filter `event_type=CLIENT_ASSERTION_FAILED` for the exact reason logged server-side. + +## Security notes + +- **Private keys must never leave the client.** Use a secrets manager (AWS Secrets Manager, GCP Secret Manager, HashiCorp Vault) or a KMS-backed sign API (AWS KMS, GCP Cloud KMS). +- Keep assertion lifetime short (≤ 5 minutes) — `CLIENT_ASSERTION_MAX_LIFETIME` caps it server-side. +- Use a cryptographically-random `jti` per assertion (a UUIDv4 is fine). +- `PRIVATE_KEY_JWT_ENABLED=false` immediately rejects all assertion-based authentication and hides the method from discovery — a useful kill-switch if key material is suspected to be compromised and you need time to investigate. +- `jti` replay protection currently uses an in-memory cache per instance. For multi-instance deployments where a client might hit different replicas within the same assertion lifetime, promote the jti cache to Redis (tracked as a follow-up). diff --git a/internal/bootstrap/handlers.go b/internal/bootstrap/handlers.go index d91d17f..53e29e5 100644 --- a/internal/bootstrap/handlers.go +++ b/internal/bootstrap/handlers.go @@ -4,12 +4,15 @@ import ( "crypto" "embed" "net/http" + "strings" "github.com/go-authgate/authgate/internal/auth" + "github.com/go-authgate/authgate/internal/cache" "github.com/go-authgate/authgate/internal/config" "github.com/go-authgate/authgate/internal/core" "github.com/go-authgate/authgate/internal/handlers" "github.com/go-authgate/authgate/internal/services" + "github.com/go-authgate/authgate/internal/util" ) // handlerSet holds all HTTP handlers and required services @@ -50,6 +53,50 @@ func initializeHandlers(deps handlerDeps) handlerSet { // Build JWKS handler from the token provider's public key info jwksHandler := buildJWKSHandler(deps.tokenProvider, deps.cfg) + // Build the optional RFC 7523 private_key_jwt client authenticator. + // Memory caches are used locally — a follow-up PR may wire redis variants + // when multi-instance deployment needs coordinated jti replay protection. + var clientAuth *handlers.ClientAuthenticator + if deps.cfg.PrivateKeyJWTEnabled { + // Pass 0 to disable the background reaper: these caches aren't wired + // into the shutdown manager, and lazy expiration on Get is sufficient + // given the short TTLs used here (JWKS hour, jti ≤ assertion lifetime). + jwksCache := cache.NewMemoryCache[util.JWKSet](0) + jtiCache := cache.NewMemoryCache[bool](0) + jwksFetcher := services.NewJWKSFetcher( + jwksCache, + deps.cfg.JWKSFetchTimeout, + deps.cfg.JWKSCacheTTL, + ) + tokenEndpoint := strings.TrimRight(deps.cfg.BaseURL, "/") + "/oauth/token" + issuer := strings.TrimRight(deps.cfg.BaseURL, "/") + verifier := services.NewClientAssertionVerifier( + deps.services.client, + jwksFetcher, + jtiCache, + deps.auditService, + services.ClientAssertionConfig{ + Enabled: true, + ExpectedAudiences: []string{tokenEndpoint, issuer}, + MaxLifetime: deps.cfg.ClientAssertionMaxLifetime, + ClockSkew: deps.cfg.ClientAssertionClockSkew, + }, + ) + clientAuth = handlers.NewClientAuthenticator( + deps.services.client, + verifier, + ) + } + + tokenHandler := handlers.NewTokenHandler( + deps.services.token, + deps.services.authorization, + deps.cfg, + ) + if clientAuth != nil { + tokenHandler = tokenHandler.WithClientAuthenticator(clientAuth) + } + return handlerSet{ auth: handlers.NewAuthHandler( deps.services.user, @@ -62,11 +109,7 @@ func initializeHandlers(deps handlerDeps) handlerSet { deps.services.authorization, deps.cfg, ), - token: handlers.NewTokenHandler( - deps.services.token, - deps.services.authorization, - deps.cfg, - ), + token: tokenHandler, client: handlers.NewClientHandler(deps.services.client, deps.services.authorization), userClient: handlers.NewUserClientHandler(deps.services.client), session: handlers.NewSessionHandler(deps.services.token), diff --git a/internal/config/config.go b/internal/config/config.go index d3a9a3c..5eb2056 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -100,6 +100,13 @@ type Config struct { // Client Credentials Flow settings (RFC 6749 §4.4) ClientCredentialsTokenExpiration time.Duration // Access token lifetime for client_credentials grant (default: 1h, same as JWTExpiration) + // Private Key JWT settings (RFC 7523 client_assertion-based authentication) + PrivateKeyJWTEnabled bool // Enable private_key_jwt token endpoint authentication (default: true) + JWKSFetchTimeout time.Duration // HTTP timeout for fetching remote JWKS (default: 10s) + JWKSCacheTTL time.Duration // Cached JWKS lifetime before refetch (default: 1h) + ClientAssertionMaxLifetime time.Duration // Maximum allowed assertion exp-iat window (default: 5m) + ClientAssertionClockSkew time.Duration // Clock skew tolerance when validating exp/nbf/iat (default: 30s) + // OAuth settings // GitHub OAuth GitHubOAuthEnabled bool @@ -304,6 +311,13 @@ func Load() *Config { time.Hour, ), // 1 hour default; keep short — no refresh token means no rotation mechanism + // Private Key JWT (RFC 7523) settings + PrivateKeyJWTEnabled: getEnvBool("PRIVATE_KEY_JWT_ENABLED", true), + JWKSFetchTimeout: getEnvDuration("JWKS_FETCH_TIMEOUT", 10*time.Second), + JWKSCacheTTL: getEnvDuration("JWKS_CACHE_TTL", time.Hour), + ClientAssertionMaxLifetime: getEnvDuration("CLIENT_ASSERTION_MAX_LIFETIME", 5*time.Minute), + ClientAssertionClockSkew: getEnvDuration("CLIENT_ASSERTION_CLOCK_SKEW", 30*time.Second), + // OAuth settings // GitHub OAuth GitHubOAuthEnabled: getEnvBool("GITHUB_OAUTH_ENABLED", false), diff --git a/internal/handlers/client_auth.go b/internal/handlers/client_auth.go new file mode 100644 index 0000000..b66a5bc --- /dev/null +++ b/internal/handlers/client_auth.go @@ -0,0 +1,150 @@ +package handlers + +import ( + "context" + "errors" + "strings" + + "github.com/go-authgate/authgate/internal/models" + "github.com/go-authgate/authgate/internal/services" + + "github.com/gin-gonic/gin" +) + +// Form parameters defined in RFC 7521 §4.2 for assertion-based client authentication. +const ( + formClientAssertion = "client_assertion" + formClientAssertionType = "client_assertion_type" +) + +// Errors returned by the shared client authenticator. All map to OAuth +// invalid_client; the distinction is for logs/tests and for choosing the +// correct WWW-Authenticate header. +var ( + ErrClientAuthRequired = errors.New("client authentication required") + ErrClientAuthMismatch = errors.New("client_id mismatch between parameters and assertion") + ErrClientAuthMethodUnmet = errors.New( + "client authentication method does not match client registration", + ) + ErrClientAuthSecretBad = errors.New("invalid client secret") +) + +// AuthenticatedClient carries the outcome of a successful authentication at the +// token endpoint. +type AuthenticatedClient struct { + Client *models.OAuthApplication + Method string // client_secret_basic | client_secret_post | private_key_jwt | none +} + +// ClientAuthenticator performs RFC 6749 §2.3 + RFC 7521 §4.2 authentication at +// the token endpoint. It supports client_secret_basic, client_secret_post, +// private_key_jwt, and none (public clients). Assertion-audience validation +// lives on ClientAssertionVerifier; this type only dispatches. +type ClientAuthenticator struct { + clientService *services.ClientService + assertionVerifier *services.ClientAssertionVerifier +} + +// NewClientAuthenticator wires a new authenticator. assertionVerifier may be +// nil, in which case private_key_jwt assertions are rejected. +func NewClientAuthenticator( + cs *services.ClientService, + av *services.ClientAssertionVerifier, +) *ClientAuthenticator { + return &ClientAuthenticator{ + clientService: cs, + assertionVerifier: av, + } +} + +// Authenticate inspects the request and returns the authenticated client. +// requireConfidential=true forces the caller to present credentials (for grants +// like client_credentials that are restricted to confidential clients); +// requireConfidential=false still verifies credentials when they are supplied, +// but allows public clients to pass with only a client_id. +func (a *ClientAuthenticator) Authenticate( + c *gin.Context, + requireConfidential bool, +) (*AuthenticatedClient, error) { + assertion := c.PostForm(formClientAssertion) + assertionType := c.PostForm(formClientAssertionType) + if assertion != "" || assertionType != "" { + return a.authenticateViaAssertion(c.Request.Context(), c, assertion, assertionType) + } + + clientID, secret, cameFromHeader := parseClientCredentials(c) + if clientID == "" { + return nil, ErrClientAuthRequired + } + + client, err := a.clientService.GetClientWithSecret(c.Request.Context(), clientID) + if err != nil { + return nil, ErrClientAuthRequired + } + if !client.IsActive() { + return nil, ErrClientAuthRequired + } + + // Enforce registration consistency: a private_key_jwt client must use an + // assertion (which takes the branch above), not a shared secret. + if client.UsesPrivateKeyJWT() { + return nil, ErrClientAuthMethodUnmet + } + + if client.UsesClientSecret() { + if secret == "" || !client.ValidateClientSecret([]byte(secret)) { + return nil, ErrClientAuthSecretBad + } + observed := models.TokenEndpointAuthClientSecretBasic + if !cameFromHeader { + observed = models.TokenEndpointAuthClientSecretPost + } + // Enforce the registered method contract: a client registered as + // client_secret_basic must use Basic Auth, and client_secret_post + // must use form-body credentials. Accepting either mode would let a + // stolen secret be presented through whichever channel happens to + // be easier for the attacker. + if client.TokenEndpointAuthMethod != observed { + return nil, ErrClientAuthMethodUnmet + } + return &AuthenticatedClient{ + Client: client, + Method: observed, + }, nil + } + + // Public client (method=none). Only allowed when requireConfidential=false. + if requireConfidential { + return nil, ErrClientAuthRequired + } + return &AuthenticatedClient{ + Client: client, + Method: models.TokenEndpointAuthNone, + }, nil +} + +func (a *ClientAuthenticator) authenticateViaAssertion( + ctx context.Context, + c *gin.Context, + assertion, assertionType string, +) (*AuthenticatedClient, error) { + if a.assertionVerifier == nil { + return nil, ErrClientAuthRequired + } + client, err := a.assertionVerifier.Verify(ctx, assertion, assertionType) + if err != nil { + return nil, err + } + // RFC 7521 §4.2: if client_id is also sent as a form param, it must match + // the authenticated client. + if formID := strings.TrimSpace( + c.PostForm("client_id"), + ); formID != "" && + formID != client.ClientID { + return nil, ErrClientAuthMismatch + } + return &AuthenticatedClient{ + Client: client, + Method: models.TokenEndpointAuthPrivateKeyJWT, + }, nil +} diff --git a/internal/handlers/oidc.go b/internal/handlers/oidc.go index 89f2511..2979b7f 100644 --- a/internal/handlers/oidc.go +++ b/internal/handlers/oidc.go @@ -42,20 +42,21 @@ func NewOIDCHandler( // discoveryMetadata holds the OIDC Provider Metadata returned by the discovery endpoint. type discoveryMetadata struct { - Issuer string `json:"issuer"` - AuthorizationEndpoint string `json:"authorization_endpoint"` - TokenEndpoint string `json:"token_endpoint"` - UserinfoEndpoint string `json:"userinfo_endpoint"` - RevocationEndpoint string `json:"revocation_endpoint"` - JwksURI string `json:"jwks_uri,omitempty"` - ResponseTypesSupported []string `json:"response_types_supported"` - SubjectTypesSupported []string `json:"subject_types_supported"` - IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported,omitempty"` - ScopesSupported []string `json:"scopes_supported"` - TokenEndpointAuthMethods []string `json:"token_endpoint_auth_methods_supported"` - GrantTypesSupported []string `json:"grant_types_supported"` - ClaimsSupported []string `json:"claims_supported"` - CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"` + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + UserinfoEndpoint string `json:"userinfo_endpoint"` + RevocationEndpoint string `json:"revocation_endpoint"` + JwksURI string `json:"jwks_uri,omitempty"` + ResponseTypesSupported []string `json:"response_types_supported"` + SubjectTypesSupported []string `json:"subject_types_supported"` + IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported,omitempty"` + ScopesSupported []string `json:"scopes_supported"` + TokenEndpointAuthMethods []string `json:"token_endpoint_auth_methods_supported"` + TokenEndpointAuthSigningAlgValuesSupported []string `json:"token_endpoint_auth_signing_alg_values_supported,omitempty"` + GrantTypesSupported []string `json:"grant_types_supported"` + ClaimsSupported []string `json:"claims_supported"` + CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"` } // Discovery godoc @@ -89,11 +90,12 @@ func (h *OIDCHandler) Discovery(c *gin.Context) { SubjectTypesSupported: []string{"public"}, IDTokenSigningAlgValuesSupported: idTokenAlgs, ScopesSupported: scopes, - TokenEndpointAuthMethods: []string{ - "client_secret_basic", - "client_secret_post", - "none", - }, + TokenEndpointAuthMethods: tokenEndpointAuthMethods( + h.config.PrivateKeyJWTEnabled, + ), + TokenEndpointAuthSigningAlgValuesSupported: tokenEndpointAuthSigningAlgs( + h.config.PrivateKeyJWTEnabled, + ), GrantTypesSupported: []string{ GrantTypeAuthorizationCode, GrantTypeDeviceCode, @@ -177,6 +179,30 @@ func (h *OIDCHandler) UserInfo(c *gin.Context) { c.JSON(http.StatusOK, claims) } +// tokenEndpointAuthMethods returns the list of supported client authentication +// methods, conditionally including private_key_jwt (RFC 7523). +func tokenEndpointAuthMethods(privateKeyJWTEnabled bool) []string { + methods := []string{ + models.TokenEndpointAuthClientSecretBasic, + models.TokenEndpointAuthClientSecretPost, + models.TokenEndpointAuthNone, + } + if privateKeyJWTEnabled { + methods = append(methods, models.TokenEndpointAuthPrivateKeyJWT) + } + return methods +} + +// tokenEndpointAuthSigningAlgs returns the list of JWT signing algorithms accepted +// for client_assertion. The list is empty (omitted from discovery) when +// private_key_jwt is disabled. +func tokenEndpointAuthSigningAlgs(privateKeyJWTEnabled bool) []string { + if !privateKeyJWTEnabled { + return nil + } + return []string{models.AssertionAlgRS256, models.AssertionAlgES256} +} + // buildUserInfoClaims constructs UserInfo response claims based on the granted scopes. // sub and iss are always included. profile and email scopes gate their respective claims. func buildUserInfoClaims(userID, issuer, scopes string, user *models.User) map[string]any { diff --git a/internal/handlers/registration.go b/internal/handlers/registration.go index 56d2317..13b0b28 100644 --- a/internal/handlers/registration.go +++ b/internal/handlers/registration.go @@ -2,6 +2,7 @@ package handlers import ( "crypto/subtle" + "encoding/json" "errors" "net/http" "strings" @@ -36,12 +37,15 @@ func NewRegistrationHandler( // clientRegistrationRequest represents the RFC 7591 §2 registration request body. type clientRegistrationRequest struct { - ClientName string `json:"client_name"` - RedirectURIs []string `json:"redirect_uris"` - GrantTypes []string `json:"grant_types"` - TokenEPAuth string `json:"token_endpoint_auth_method"` - Scope string `json:"scope"` - ClientURI string `json:"client_uri"` + ClientName string `json:"client_name"` + RedirectURIs []string `json:"redirect_uris"` + GrantTypes []string `json:"grant_types"` + TokenEPAuth string `json:"token_endpoint_auth_method"` + Scope string `json:"scope"` + ClientURI string `json:"client_uri"` + JWKSURI string `json:"jwks_uri,omitempty"` + JWKS json.RawMessage `json:"jwks,omitempty"` + TokenEPAuthSigningAlg string `json:"token_endpoint_auth_signing_alg,omitempty"` } // Register godoc @@ -52,7 +56,7 @@ type clientRegistrationRequest struct { // @Accept json // @Produce json // @Param request body clientRegistrationRequest true "Client registration request" -// @Success 201 {object} object{client_id=string,client_secret=string,client_name=string,redirect_uris=[]string,grant_types=[]string,token_endpoint_auth_method=string,scope=string,client_id_issued_at=int,client_secret_expires_at=int} "Client registered successfully" +// @Success 201 {object} object{client_id=string,client_secret=string,client_secret_expires_at=int,jwks_uri=string,token_endpoint_auth_signing_alg=string,client_name=string,redirect_uris=[]string,grant_types=[]string,token_endpoint_auth_method=string,scope=string,client_id_issued_at=int} "Client registered successfully. client_secret and client_secret_expires_at are only present for client_secret_basic/post auth methods; jwks_uri and token_endpoint_auth_signing_alg are present for private_key_jwt." // @Failure 400 {object} object{error=string,error_description=string} "Invalid client metadata" // @Failure 401 {object} object{error=string,error_description=string} "Invalid or missing initial access token" // @Failure 403 {object} object{error=string,error_description=string} "Dynamic registration is disabled" @@ -104,10 +108,17 @@ func (h *RegistrationHandler) Register(c *gin.Context) { // 5. Determine grant types from request enableDeviceFlow := false enableAuthCodeFlow := false + enableClientCredentials := false if len(req.GrantTypes) == 0 { - // RFC 7591 §2: default grant_type is "authorization_code" - enableAuthCodeFlow = true + // RFC 7591 §2: default grant_type is "authorization_code", except + // private_key_jwt registrations typically target the client_credentials + // flow — skip the redirect_uri requirement in that case. + if req.TokenEPAuth == models.TokenEndpointAuthPrivateKeyJWT { + enableClientCredentials = true + } else { + enableAuthCodeFlow = true + } } else { for _, gt := range req.GrantTypes { switch gt { @@ -115,12 +126,14 @@ func (h *RegistrationHandler) Register(c *gin.Context) { enableAuthCodeFlow = true case GrantTypeDeviceCode, GrantTypeDeviceCodeShort: enableDeviceFlow = true + case GrantTypeClientCredentials: + enableClientCredentials = true default: respondOAuthError( c, http.StatusBadRequest, "invalid_client_metadata", - "Unsupported grant_type: "+gt+". Supported: authorization_code, device_code", + "Unsupported grant_type: "+gt+". Supported: authorization_code, device_code, client_credentials", ) return } @@ -130,25 +143,80 @@ func (h *RegistrationHandler) Register(c *gin.Context) { // 6. Determine auth method → client type (RFC 7591 §2: default is "client_secret_basic") authMethod := req.TokenEPAuth if authMethod == "" { - authMethod = "client_secret_basic" + authMethod = models.TokenEndpointAuthClientSecretBasic } var clientType core.ClientType switch authMethod { - case "none": + case models.TokenEndpointAuthNone: clientType = core.ClientTypePublic - case "client_secret_basic", "client_secret_post": + case models.TokenEndpointAuthClientSecretBasic, models.TokenEndpointAuthClientSecretPost: + clientType = core.ClientTypeConfidential + case models.TokenEndpointAuthPrivateKeyJWT: + if !h.config.PrivateKeyJWTEnabled { + respondOAuthError( + c, + http.StatusBadRequest, + "invalid_client_metadata", + "private_key_jwt is not enabled on this server", + ) + return + } clientType = core.ClientTypeConfidential default: respondOAuthError( c, http.StatusBadRequest, "invalid_client_metadata", - "Unsupported token_endpoint_auth_method: "+req.TokenEPAuth+". Supported: none, client_secret_basic, client_secret_post", + "Unsupported token_endpoint_auth_method: "+req.TokenEPAuth+". Supported: none, client_secret_basic, client_secret_post, private_key_jwt", ) return } + // 6b. For private_key_jwt, require key material (RFC 7591 §2.1). + var ( + jwksInline string + signingAlg string + ) + if authMethod == models.TokenEndpointAuthPrivateKeyJWT { + hasURI := strings.TrimSpace(req.JWKSURI) != "" + hasInline := len(req.JWKS) > 0 && !isJSONNull(req.JWKS) + if !hasURI && !hasInline { + respondOAuthError( + c, + http.StatusBadRequest, + "invalid_client_metadata", + "private_key_jwt requires either jwks_uri or jwks", + ) + return + } + if hasURI && hasInline { + respondOAuthError( + c, + http.StatusBadRequest, + "invalid_client_metadata", + "jwks_uri and jwks are mutually exclusive", + ) + return + } + if hasInline { + jwksInline = string(req.JWKS) + } + signingAlg = req.TokenEPAuthSigningAlg + if signingAlg == "" { + signingAlg = models.AssertionAlgRS256 + } + if signingAlg != models.AssertionAlgRS256 && signingAlg != models.AssertionAlgES256 { + respondOAuthError( + c, + http.StatusBadRequest, + "invalid_client_metadata", + "Unsupported token_endpoint_auth_signing_alg: "+req.TokenEPAuthSigningAlg+". Supported: RS256, ES256", + ) + return + } + } + // 7. Validate scopes (only user-safe scopes allowed) scope := strings.TrimSpace(req.Scope) if scope != "" { @@ -170,14 +238,19 @@ func (h *RegistrationHandler) Register(c *gin.Context) { // 8. Create the client via service (pending status, requires admin approval) createReq := services.CreateClientRequest{ - ClientName: req.ClientName, - Description: req.ClientURI, - Scopes: scope, - RedirectURIs: req.RedirectURIs, - ClientType: clientType, - EnableDeviceFlow: enableDeviceFlow, - EnableAuthCodeFlow: enableAuthCodeFlow, - IsAdminCreated: false, // Dynamic registration → pending approval + ClientName: req.ClientName, + Description: req.ClientURI, + Scopes: scope, + RedirectURIs: req.RedirectURIs, + ClientType: clientType, + EnableDeviceFlow: enableDeviceFlow, + EnableAuthCodeFlow: enableAuthCodeFlow, + IsAdminCreated: false, // Dynamic registration → pending approval + EnableClientCredentialsFlow: enableClientCredentials, + TokenEndpointAuthMethod: authMethod, + TokenEndpointAuthSigningAlg: signingAlg, + JWKSURI: req.JWKSURI, + JWKS: jwksInline, } resp, err := h.clientService.CreateClient(c.Request.Context(), createReq) @@ -219,17 +292,35 @@ func (h *RegistrationHandler) Register(c *gin.Context) { // 10. Build RFC 7591 §3.2.1 response grantTypes := buildResponseGrantTypes(app) - c.JSON(http.StatusCreated, gin.H{ + response := gin.H{ "client_id": app.ClientID, - "client_secret": resp.ClientSecretPlain, "client_name": app.ClientName, "redirect_uris": app.RedirectURIs, "grant_types": grantTypes, "token_endpoint_auth_method": authMethod, "scope": app.Scopes, "client_id_issued_at": app.CreatedAt.Unix(), - "client_secret_expires_at": 0, // 0 = does not expire (RFC 7591 §3.2.1) - }) + } + // Only client_secret_* clients receive a shared secret in the response. + if app.UsesClientSecret() { + response["client_secret"] = resp.ClientSecretPlain + response["client_secret_expires_at"] = 0 // RFC 7591 §3.2.1: 0 = does not expire + } + if app.UsesPrivateKeyJWT() { + if app.JWKSURI != "" { + response["jwks_uri"] = app.JWKSURI + } + response["token_endpoint_auth_signing_alg"] = app.TokenEndpointAuthSigningAlg + } + c.JSON(http.StatusCreated, response) +} + +// isJSONNull reports whether the given raw JSON represents the null literal +// (or only whitespace around it). Used to distinguish "jwks": null from a +// missing or present jwks field. +func isJSONNull(raw json.RawMessage) bool { + s := strings.TrimSpace(string(raw)) + return s == "" || s == "null" } // buildResponseGrantTypes converts the OAuthApplication's enabled flows into diff --git a/internal/handlers/registration_test.go b/internal/handlers/registration_test.go index 43f0e90..5cf1a8f 100644 --- a/internal/handlers/registration_test.go +++ b/internal/handlers/registration_test.go @@ -40,6 +40,7 @@ func setupRegistrationTestEnvWithOpts(t *testing.T, opts registrationTestOpts) * BaseURL: "http://localhost:8080", EnableDynamicClientRegistration: opts.enabled, DynamicClientRegistrationToken: opts.token, + PrivateKeyJWTEnabled: true, } s, err := store.New(context.Background(), "sqlite", ":memory:", &config.Config{}) @@ -284,6 +285,72 @@ func TestRegister_UnsupportedAuthMethod(t *testing.T) { w := postRegister(t, r, map[string]any{ "client_name": "My App", + "token_endpoint_auth_method": "mTLS", + }) + + assert.Equal(t, http.StatusBadRequest, w.Code) + var resp map[string]any + require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) + assert.Equal(t, "invalid_client_metadata", resp["error"]) + assert.Contains(t, resp["error_description"], "mTLS") +} + +// ─── Success: private_key_jwt registration with inline jwks ───────────────── + +func TestRegister_PrivateKeyJWT_InlineJWKS(t *testing.T) { + r := setupRegistrationTestEnv(t, true) + + jwks := map[string]any{ + "keys": []map[string]any{ + { + "kty": "RSA", + "use": "sig", + "kid": "test", + "alg": "RS256", + "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw", + "e": "AQAB", + }, + }, + } + w := postRegister(t, r, map[string]any{ + "client_name": "Machine Client", + "token_endpoint_auth_method": "private_key_jwt", + "token_endpoint_auth_signing_alg": "RS256", + "jwks": jwks, + }) + + assert.Equal(t, http.StatusCreated, w.Code) + var resp map[string]any + require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) + assert.Equal(t, "private_key_jwt", resp["token_endpoint_auth_method"]) + assert.Equal(t, "RS256", resp["token_endpoint_auth_signing_alg"]) + // private_key_jwt clients must not receive a shared secret. + _, hasSecret := resp["client_secret"] + assert.False(t, hasSecret, "private_key_jwt client should not receive a client_secret") +} + +func TestRegister_PrivateKeyJWT_JWKSURI(t *testing.T) { + r := setupRegistrationTestEnv(t, true) + + w := postRegister(t, r, map[string]any{ + "client_name": "Machine Client", + "token_endpoint_auth_method": "private_key_jwt", + "token_endpoint_auth_signing_alg": "ES256", + "jwks_uri": "https://example.com/.well-known/jwks.json", + }) + + assert.Equal(t, http.StatusCreated, w.Code) + var resp map[string]any + require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) + assert.Equal(t, "private_key_jwt", resp["token_endpoint_auth_method"]) + assert.Equal(t, "https://example.com/.well-known/jwks.json", resp["jwks_uri"]) +} + +func TestRegister_PrivateKeyJWT_MissingKeyMaterial(t *testing.T) { + r := setupRegistrationTestEnv(t, true) + + w := postRegister(t, r, map[string]any{ + "client_name": "Machine Client", "token_endpoint_auth_method": "private_key_jwt", }) @@ -291,7 +358,20 @@ func TestRegister_UnsupportedAuthMethod(t *testing.T) { var resp map[string]any require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) assert.Equal(t, "invalid_client_metadata", resp["error"]) - assert.Contains(t, resp["error_description"], "private_key_jwt") + assert.Contains(t, resp["error_description"], "jwks") +} + +func TestRegister_PrivateKeyJWT_BothKeysProvided(t *testing.T) { + r := setupRegistrationTestEnv(t, true) + + w := postRegister(t, r, map[string]any{ + "client_name": "Machine Client", + "token_endpoint_auth_method": "private_key_jwt", + "jwks_uri": "https://example.com/jwks", + "jwks": map[string]any{"keys": []any{}}, + }) + + assert.Equal(t, http.StatusBadRequest, w.Code) } // ─── Error: unsupported scope ──────────────────────────────────────────────── diff --git a/internal/handlers/token.go b/internal/handlers/token.go index 235ce37..472a986 100644 --- a/internal/handlers/token.go +++ b/internal/handlers/token.go @@ -43,6 +43,7 @@ type TokenHandler struct { tokenService *services.TokenService authorizationService *services.AuthorizationService config *config.Config + clientAuthenticator *ClientAuthenticator // optional; when nil, falls back to secret-only auth } func NewTokenHandler( @@ -57,6 +58,13 @@ func NewTokenHandler( } } +// WithClientAuthenticator attaches a shared ClientAuthenticator so the token +// endpoint can accept RFC 7523 private_key_jwt in addition to client_secret. +func (h *TokenHandler) WithClientAuthenticator(a *ClientAuthenticator) *TokenHandler { + h.clientAuthenticator = a + return h +} + // buildTokenResponse constructs a standard OAuth 2.0 token response (RFC 6749 §5.1). func buildTokenResponse(accessToken, refreshToken *models.AccessToken, idToken string) gin.H { expiresIn := max(int(time.Until(accessToken.ExpiresAt).Seconds()), 0) @@ -278,45 +286,69 @@ func (h *TokenHandler) TokenInfo(c *gin.Context) { // Introspect godoc // // @Summary Introspect token (RFC 7662) -// @Description Determine the active state and metadata of an OAuth 2.0 token. Requires client authentication via HTTP Basic Auth or form-body client credentials. +// @Description Determine the active state and metadata of an OAuth 2.0 token. Requires client authentication via HTTP Basic Auth, form-body client_id/client_secret, or RFC 7523 private_key_jwt (client_assertion + client_assertion_type). // @Tags OAuth // @Accept x-www-form-urlencoded // @Produce json -// @Param token formData string true "The token to introspect" -// @Param token_type_hint formData string false "Hint about the type of token: 'access_token' or 'refresh_token'" -// @Param client_id formData string false "Client ID (alternative to HTTP Basic Auth)" -// @Param client_secret formData string false "Client secret (alternative to HTTP Basic Auth)" +// @Param token formData string true "The token to introspect" +// @Param token_type_hint formData string false "Hint about the type of token: 'access_token' or 'refresh_token'" +// @Param client_id formData string false "Client ID (alternative to HTTP Basic Auth)" +// @Param client_secret formData string false "Client secret (alternative to HTTP Basic Auth)" +// @Param client_assertion formData string false "Signed JWT assertion (RFC 7523 private_key_jwt); use with client_assertion_type" +// @Param client_assertion_type formData string false "Must be urn:ietf:params:oauth:client-assertion-type:jwt-bearer when client_assertion is present" // @Success 200 {object} object{active=bool,scope=string,client_id=string,username=string,token_type=string,exp=int,iat=int,sub=string,iss=string,jti=string} "Token introspection response" // @Failure 401 {object} object{error=string,error_description=string} "Client authentication failed" // @Router /oauth/introspect [post] func (h *TokenHandler) Introspect(c *gin.Context) { - // 1. Authenticate the calling client (RFC 7662 §2.1) - clientID, clientSecret := parseClientCredentials(c) - if clientID == "" || clientSecret == "" { - c.Header("WWW-Authenticate", `Basic realm="authgate"`) - respondOAuthError( - c, - http.StatusUnauthorized, - errInvalidClient, - "Client authentication required", - ) - return - } - - // Verify client credentials - if err := h.tokenService.AuthenticateClient( - c.Request.Context(), - clientID, - clientSecret, - ); err != nil { - c.Header("WWW-Authenticate", `Basic realm="authgate"`) - respondOAuthError( - c, - http.StatusUnauthorized, - errInvalidClient, - "Client authentication failed", - ) - return + // 1. Authenticate the calling client (RFC 7662 §2.1). + // Prefer the shared authenticator so private_key_jwt is accepted alongside + // client_secret_* methods. + var clientID string + if h.clientAuthenticator != nil { + authed, err := h.clientAuthenticator.Authenticate(c, true) + if err != nil { + // Only advertise a Basic challenge when Basic is actually an + // applicable method for this caller — otherwise a private_key_jwt + // client would see a misleading WWW-Authenticate header. + if shouldChallengeBasic(c) { + c.Header("WWW-Authenticate", `Basic realm="authgate"`) + } + respondOAuthError( + c, + http.StatusUnauthorized, + errInvalidClient, + "Client authentication failed", + ) + return + } + clientID = authed.Client.ClientID + } else { + id, secret, _ := parseClientCredentials(c) + if id == "" || secret == "" { + c.Header("WWW-Authenticate", `Basic realm="authgate"`) + respondOAuthError( + c, + http.StatusUnauthorized, + errInvalidClient, + "Client authentication required", + ) + return + } + if err := h.tokenService.AuthenticateClient( + c.Request.Context(), + id, + secret, + ); err != nil { + c.Header("WWW-Authenticate", `Basic realm="authgate"`) + respondOAuthError( + c, + http.StatusUnauthorized, + errInvalidClient, + "Client authentication failed", + ) + return + } + clientID = id } // 2. Get the token parameter (RFC 7662 §2.1: REQUIRED) @@ -408,12 +440,43 @@ func (h *TokenHandler) Revoke(c *gin.Context) { } // handleClientCredentialsGrant handles the client_credentials grant type (RFC 6749 §4.4). -// Client authentication is accepted via HTTP Basic Auth (preferred per RFC 6749 §2.3.1) -// or as client_id / client_secret form-body parameters. +// Client authentication is accepted via HTTP Basic Auth (preferred per RFC 6749 §2.3.1), +// client_id / client_secret form-body parameters, or RFC 7523 private_key_jwt assertions. // Only confidential clients with the client_credentials flow enabled may use this endpoint. // No refresh token is issued in the response. func (h *TokenHandler) handleClientCredentialsGrant(c *gin.Context) { - clientID, clientSecret := parseClientCredentials(c) + requestedScopes := c.PostForm("scope") // Optional + + // Prefer the shared authenticator when wired — it unifies Basic Auth, + // form-body credentials, and private_key_jwt assertions. + if h.clientAuthenticator != nil { + authed, err := h.clientAuthenticator.Authenticate(c, true) + if err != nil { + if shouldChallengeBasic(c) { + c.Header("WWW-Authenticate", `Basic realm="authgate"`) + } + respondOAuthError( + c, + http.StatusUnauthorized, + errInvalidClient, + "Client authentication failed", + ) + return + } + accessToken, err := h.tokenService.IssueClientCredentialsTokenForClient( + c.Request.Context(), + authed.Client, + requestedScopes, + ) + if err != nil { + h.writeClientCredentialsError(c, err) + return + } + c.JSON(http.StatusOK, buildTokenResponse(accessToken, nil, "")) + return + } + + clientID, clientSecret, _ := parseClientCredentials(c) if clientID == "" || clientSecret == "" { c.Header("WWW-Authenticate", `Basic realm="authgate"`) respondOAuthError( @@ -425,8 +488,6 @@ func (h *TokenHandler) handleClientCredentialsGrant(c *gin.Context) { return } - requestedScopes := c.PostForm("scope") // Optional - accessToken, err := h.tokenService.IssueClientCredentialsToken( c.Request.Context(), clientID, @@ -434,35 +495,7 @@ func (h *TokenHandler) handleClientCredentialsGrant(c *gin.Context) { requestedScopes, ) if err != nil { - switch { - case errors.Is(err, services.ErrInvalidClientCredentials), - errors.Is(err, services.ErrClientNotConfidential): - // RFC 6749 §5.2: use 401 + WWW-Authenticate for invalid_client - c.Header("WWW-Authenticate", `Basic realm="authgate"`) - respondOAuthError( - c, - http.StatusUnauthorized, - errInvalidClient, - "Client authentication failed", - ) - case errors.Is(err, services.ErrClientCredentialsFlowDisabled): - respondOAuthError(c, http.StatusBadRequest, errUnauthorizedClient, - "Client credentials flow is not enabled for this client") - case errors.Is(err, token.ErrInvalidScope): - respondOAuthError( - c, - http.StatusBadRequest, - errInvalidScope, - "Requested scope exceeds client permissions or contains restricted scopes (openid, offline_access are not permitted)", - ) - default: - respondOAuthError( - c, - http.StatusInternalServerError, - errServerError, - "Token issuance failed", - ) - } + h.writeClientCredentialsError(c, err) return } @@ -470,6 +503,51 @@ func (h *TokenHandler) handleClientCredentialsGrant(c *gin.Context) { c.JSON(http.StatusOK, buildTokenResponse(accessToken, nil, "")) } +// shouldChallengeBasic reports whether an HTTP Basic Auth challenge is +// meaningful in a 401 response to the given request. We skip the challenge +// when the caller explicitly used private_key_jwt (client_assertion is +// present) because Basic is not an applicable method for that client. +func shouldChallengeBasic(c *gin.Context) bool { + return c.PostForm(formClientAssertion) == "" && + c.PostForm(formClientAssertionType) == "" +} + +// writeClientCredentialsError maps service-layer errors from the client_credentials +// flow to RFC 6749 error responses. Shared between the classic and shared-authenticator +// code paths. +func (h *TokenHandler) writeClientCredentialsError(c *gin.Context, err error) { + switch { + case errors.Is(err, services.ErrInvalidClientCredentials), + errors.Is(err, services.ErrClientNotConfidential): + if shouldChallengeBasic(c) { + c.Header("WWW-Authenticate", `Basic realm="authgate"`) + } + respondOAuthError( + c, + http.StatusUnauthorized, + errInvalidClient, + "Client authentication failed", + ) + case errors.Is(err, services.ErrClientCredentialsFlowDisabled): + respondOAuthError(c, http.StatusBadRequest, errUnauthorizedClient, + "Client credentials flow is not enabled for this client") + case errors.Is(err, token.ErrInvalidScope): + respondOAuthError( + c, + http.StatusBadRequest, + errInvalidScope, + "Requested scope exceeds client permissions or contains restricted scopes (openid, offline_access are not permitted)", + ) + default: + respondOAuthError( + c, + http.StatusInternalServerError, + errServerError, + "Token issuance failed", + ) + } +} + // handleAuthorizationCodeGrant handles the authorization_code grant type (RFC 6749 §4.1.3). func (h *TokenHandler) handleAuthorizationCodeGrant(c *gin.Context) { code := c.PostForm("code") diff --git a/internal/handlers/token_private_key_jwt_test.go b/internal/handlers/token_private_key_jwt_test.go new file mode 100644 index 0000000..48528c1 --- /dev/null +++ b/internal/handlers/token_private_key_jwt_test.go @@ -0,0 +1,278 @@ +package handlers + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "math/big" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/go-authgate/authgate/internal/cache" + "github.com/go-authgate/authgate/internal/config" + "github.com/go-authgate/authgate/internal/core" + "github.com/go-authgate/authgate/internal/metrics" + "github.com/go-authgate/authgate/internal/models" + "github.com/go-authgate/authgate/internal/services" + "github.com/go-authgate/authgate/internal/store" + "github.com/go-authgate/authgate/internal/token" + "github.com/go-authgate/authgate/internal/util" + + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// setupPKJWTEnv builds a test router whose /oauth/token endpoint accepts RFC 7523 +// client_assertion in addition to the classic client_secret paths. +func setupPKJWTEnv(t *testing.T) (*gin.Engine, *store.Store, string) { + t.Helper() + gin.SetMode(gin.TestMode) + + baseURL := "https://authgate.test" + cfg := &config.Config{ + BaseURL: baseURL, + JWTExpiration: time.Hour, + ClientCredentialsTokenExpiration: time.Hour, + JWTSecret: "test-secret-32-chars-long!!!!!!!", + PrivateKeyJWTEnabled: true, + JWKSFetchTimeout: 2 * time.Second, + JWKSCacheTTL: time.Minute, + ClientAssertionMaxLifetime: 5 * time.Minute, + ClientAssertionClockSkew: 30 * time.Second, + } + + s, err := store.New(context.Background(), "sqlite", ":memory:", &config.Config{}) + require.NoError(t, err) + + localProvider, err := token.NewLocalTokenProvider(cfg) + require.NoError(t, err) + auditSvc := services.NewNoopAuditService() + clientSvc := services.NewClientService(s, auditSvc, nil, 0, nil, 0) + deviceSvc := services.NewDeviceService(s, cfg, auditSvc, metrics.NewNoopMetrics(), clientSvc) + tokenSvc := services.NewTokenService( + s, cfg, deviceSvc, localProvider, auditSvc, metrics.NewNoopMetrics(), + cache.NewNoopCache[models.AccessToken](), clientSvc, + ) + authzSvc := services.NewAuthorizationService(s, cfg, auditSvc, tokenSvc, clientSvc) + + jwksCache := cache.NewMemoryCache[util.JWKSet](0) + jtiCache := cache.NewMemoryCache[bool](0) + t.Cleanup(func() { + _ = jwksCache.Close() + _ = jtiCache.Close() + }) + fetcher := services.NewJWKSFetcher(jwksCache, cfg.JWKSFetchTimeout, cfg.JWKSCacheTTL) + tokenEndpoint := strings.TrimRight(baseURL, "/") + "/oauth/token" + verifier := services.NewClientAssertionVerifier( + clientSvc, fetcher, jtiCache, auditSvc, + services.ClientAssertionConfig{ + Enabled: true, + ExpectedAudiences: []string{tokenEndpoint, baseURL}, + MaxLifetime: cfg.ClientAssertionMaxLifetime, + ClockSkew: cfg.ClientAssertionClockSkew, + }, + ) + clientAuth := NewClientAuthenticator(clientSvc, verifier) + handler := NewTokenHandler(tokenSvc, authzSvc, cfg).WithClientAuthenticator(clientAuth) + + r := gin.New() + r.POST("/oauth/token", handler.Token) + r.POST("/oauth/introspect", handler.Introspect) + + return r, s, tokenEndpoint +} + +// seedPKJWTClient inserts a confidential client whose token endpoint auth method +// is private_key_jwt with the given inline JWK Set. +func seedPKJWTClient( + t *testing.T, + s *store.Store, + jwk util.JWK, + alg string, +) *models.OAuthApplication { + t.Helper() + blob, err := json.Marshal(util.JWKSet{Keys: []util.JWK{jwk}}) + require.NoError(t, err) + client := &models.OAuthApplication{ + ClientID: uuid.New().String(), + ClientName: "MCP M2M client", + UserID: uuid.New().String(), + Scopes: "read write", + GrantTypes: "client_credentials", + ClientType: core.ClientTypeConfidential.String(), + EnableClientCredentialsFlow: true, + Status: models.ClientStatusActive, + TokenEndpointAuthMethod: models.TokenEndpointAuthPrivateKeyJWT, + TokenEndpointAuthSigningAlg: alg, + JWKS: string(blob), + } + require.NoError(t, s.CreateClient(client)) + return client +} + +func rsaPKJWTFixture(t *testing.T, kid string) (*rsa.PrivateKey, util.JWK) { + t.Helper() + priv, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + return priv, util.JWK{ + Kty: "RSA", + Use: "sig", + Kid: kid, + Alg: "RS256", + N: base64.RawURLEncoding.EncodeToString(priv.PublicKey.N.Bytes()), + E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(priv.PublicKey.E)).Bytes()), + } +} + +func signAssertion( + t *testing.T, + priv *rsa.PrivateKey, + kid, clientID, audience string, +) string { + t.Helper() + now := time.Now() + claims := jwt.MapClaims{ + "iss": clientID, + "sub": clientID, + "aud": audience, + "iat": now.Unix(), + "exp": now.Add(time.Minute).Unix(), + "jti": uuid.NewString(), + } + tok := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tok.Header["kid"] = kid + out, err := tok.SignedString(priv) + require.NoError(t, err) + return out +} + +// ─── Tests ─────────────────────────────────────────────────────────────────── + +func TestPrivateKeyJWT_ClientCredentials_Success(t *testing.T) { + r, s, aud := setupPKJWTEnv(t) + priv, jwk := rsaPKJWTFixture(t, "k1") + client := seedPKJWTClient(t, s, jwk, "RS256") + + form := url.Values{ + "grant_type": {"client_credentials"}, + "client_assertion_type": {services.AssertionType}, + "client_assertion": {signAssertion(t, priv, "k1", client.ClientID, aud)}, + "scope": {"read"}, + } + w := postToken(t, r, form, nil) + + require.Equal(t, http.StatusOK, w.Code, "body=%s", w.Body.String()) + var resp map[string]any + require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) + assert.NotEmpty(t, resp["access_token"]) + assert.Equal(t, "Bearer", resp["token_type"]) + assert.Equal(t, "read", resp["scope"]) + // RFC 6749 §4.4.3 — no refresh_token + _, hasRefresh := resp["refresh_token"] + assert.False(t, hasRefresh) +} + +func TestPrivateKeyJWT_ClientCredentials_InvalidAud(t *testing.T) { + r, s, _ := setupPKJWTEnv(t) + priv, jwk := rsaPKJWTFixture(t, "k1") + client := seedPKJWTClient(t, s, jwk, "RS256") + + form := url.Values{ + "grant_type": {"client_credentials"}, + "client_assertion_type": {services.AssertionType}, + "client_assertion": { + signAssertion(t, priv, "k1", client.ClientID, "https://attacker.example"), + }, + } + w := postToken(t, r, form, nil) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var resp map[string]any + require.NoError(t, json.NewDecoder(w.Body).Decode(&resp)) + assert.Equal(t, "invalid_client", resp["error"]) +} + +func TestPrivateKeyJWT_ClientCredentials_ReplayRejected(t *testing.T) { + r, s, aud := setupPKJWTEnv(t) + priv, jwk := rsaPKJWTFixture(t, "k1") + client := seedPKJWTClient(t, s, jwk, "RS256") + + assertion := signAssertion(t, priv, "k1", client.ClientID, aud) + form := url.Values{ + "grant_type": {"client_credentials"}, + "client_assertion_type": {services.AssertionType}, + "client_assertion": {assertion}, + } + // First use succeeds + w := postToken(t, r, form, nil) + require.Equal(t, http.StatusOK, w.Code, "first call body=%s", w.Body.String()) + // Second use of the exact same assertion must be rejected as replay. + w2 := postToken(t, r, form, nil) + require.Equal(t, http.StatusUnauthorized, w2.Code) +} + +func TestPrivateKeyJWT_Introspect_Success(t *testing.T) { + r, s, aud := setupPKJWTEnv(t) + priv, jwk := rsaPKJWTFixture(t, "k1") + client := seedPKJWTClient(t, s, jwk, "RS256") + + // First obtain an access token via client_credentials. + ccForm := url.Values{ + "grant_type": {"client_credentials"}, + "client_assertion_type": {services.AssertionType}, + "client_assertion": {signAssertion(t, priv, "k1", client.ClientID, aud)}, + } + w := postToken(t, r, ccForm, nil) + require.Equal(t, http.StatusOK, w.Code) + var tokenResp map[string]any + require.NoError(t, json.NewDecoder(w.Body).Decode(&tokenResp)) + accessToken := tokenResp["access_token"].(string) + require.NotEmpty(t, accessToken) + + // Now introspect it, authenticating via a fresh assertion. + introForm := url.Values{ + "token": {accessToken}, + "client_assertion_type": {services.AssertionType}, + "client_assertion": {signAssertion(t, priv, "k1", client.ClientID, aud)}, + } + req, err := http.NewRequest( + http.MethodPost, + "/oauth/introspect", + strings.NewReader(introForm.Encode()), + ) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code, "body=%s", rec.Body.String()) + var introResp map[string]any + require.NoError(t, json.NewDecoder(rec.Body).Decode(&introResp)) + assert.Equal(t, true, introResp["active"]) + assert.Equal(t, client.ClientID, introResp["client_id"]) +} + +func TestPrivateKeyJWT_DisabledMethodRejectsSecretAttempt(t *testing.T) { + r, s, _ := setupPKJWTEnv(t) + _, jwk := rsaPKJWTFixture(t, "k1") + client := seedPKJWTClient(t, s, jwk, "RS256") + + // Attempt to authenticate a private_key_jwt client with a shared secret. + // The client has no secret hash stored, so this must be rejected. + form := url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {client.ClientID}, + "client_secret": {"anything"}, + } + w := postToken(t, r, form, nil) + require.Equal(t, http.StatusUnauthorized, w.Code) +} diff --git a/internal/handlers/utils.go b/internal/handlers/utils.go index 20e028a..6eddef8 100644 --- a/internal/handlers/utils.go +++ b/internal/handlers/utils.go @@ -24,13 +24,13 @@ func parsePaginationParams(c *gin.Context) store.PaginationParams { // parseClientCredentials extracts client_id and client_secret from the request // using HTTP Basic Auth (preferred per RFC 6749 §2.3.1) or form-body parameters. -func parseClientCredentials(c *gin.Context) (clientID, clientSecret string) { - clientID, clientSecret, ok := c.Request.BasicAuth() - if !ok { - clientID = c.PostForm("client_id") - clientSecret = c.PostForm("client_secret") +// fromHeader is true when credentials were taken from the Authorization header, +// letting callers distinguish client_secret_basic from client_secret_post. +func parseClientCredentials(c *gin.Context) (clientID, clientSecret string, fromHeader bool) { + if id, pw, ok := c.Request.BasicAuth(); ok { + return id, pw, true } - return clientID, clientSecret + return c.PostForm("client_id"), c.PostForm("client_secret"), false } // respondOAuthError writes an RFC-compliant OAuth error JSON response. diff --git a/internal/models/audit_log.go b/internal/models/audit_log.go index 551abd8..7ef7f44 100644 --- a/internal/models/audit_log.go +++ b/internal/models/audit_log.go @@ -73,6 +73,10 @@ const ( // Token Introspection events (RFC 7662) EventTokenIntrospected EventType = "TOKEN_INTROSPECTED" + // Client authentication events (RFC 7523 private_key_jwt) + EventClientAssertionVerified EventType = "CLIENT_ASSERTION_VERIFIED" + EventClientAssertionFailed EventType = "CLIENT_ASSERTION_FAILED" + // Audit events EventTypeAuditLogView EventType = "AUDIT_LOG_VIEWED" EventTypeAuditLogExported EventType = "AUDIT_LOG_EXPORTED" diff --git a/internal/models/oauth_application.go b/internal/models/oauth_application.go index 8fe215a..93edd24 100644 --- a/internal/models/oauth_application.go +++ b/internal/models/oauth_application.go @@ -23,6 +23,20 @@ const ( ClientStatusInactive ClientStatus = "inactive" // Admin rejected or disabled ) +// Token endpoint authentication methods (RFC 7591 §2). +const ( + TokenEndpointAuthNone = "none" // Public client, no authentication + TokenEndpointAuthClientSecretBasic = "client_secret_basic" // HTTP Basic (default) + TokenEndpointAuthClientSecretPost = "client_secret_post" // client_secret in form body + TokenEndpointAuthPrivateKeyJWT = "private_key_jwt" // RFC 7523 JWT Bearer Assertion +) + +// Client assertion signing algorithms supported for private_key_jwt. +const ( + AssertionAlgRS256 = "RS256" + AssertionAlgES256 = "ES256" +) + // Base32 characters, but lowercased. const lowerBase32Chars = "abcdefghijklmnopqrstuvwxyz234567" @@ -32,7 +46,7 @@ var base32Lower = base32.NewEncoding(lowerBase32Chars).WithPadding(base32.NoPadd type OAuthApplication struct { ID int64 `gorm:"primaryKey;autoIncrement"` ClientID string `gorm:"uniqueIndex;not null"` - ClientSecret string `gorm:"not null"` // bcrypt hashed secret + ClientSecret string `gorm:"not null;default:''"` // bcrypt hashed secret; empty for public / private_key_jwt clients ClientName string `gorm:"not null"` Description string `gorm:"type:text"` UserID string `gorm:"not null"` @@ -42,8 +56,12 @@ type OAuthApplication struct { ClientType string `gorm:"not null;default:'public'"` // "confidential" or "public" EnableDeviceFlow bool `gorm:"not null;default:true"` EnableAuthCodeFlow bool `gorm:"not null;default:false"` - EnableClientCredentialsFlow bool `gorm:"not null;default:false"` // Client Credentials Grant (RFC 6749 §4.4); confidential clients only - Status string `gorm:"not null;default:'active'"` // ClientStatusPending / ClientStatusActive / ClientStatusInactive + EnableClientCredentialsFlow bool `gorm:"not null;default:false"` // Client Credentials Grant (RFC 6749 §4.4); confidential clients only + Status string `gorm:"not null;default:'active'"` // ClientStatusPending / ClientStatusActive / ClientStatusInactive + TokenEndpointAuthMethod string `gorm:"not null;default:'client_secret_basic'"` // RFC 7591 §2 + TokenEndpointAuthSigningAlg string `gorm:"type:varchar(10);not null;default:''"` // RS256 | ES256 (required for private_key_jwt) + JWKSURI string `gorm:"type:varchar(500);not null;default:''"` // Remote JWKS endpoint URL (mutually exclusive with JWKS) + JWKS string `gorm:"type:text;not null;default:''"` // Inline JWK Set JSON (mutually exclusive with JWKSURI) CreatedBy string CreatedAt time.Time UpdatedAt time.Time @@ -110,3 +128,54 @@ func (OAuthApplication) TableName() string { func (app *OAuthApplication) IsActive() bool { return app.Status == ClientStatusActive } + +// UsesPrivateKeyJWT reports whether the client authenticates using JWT Bearer +// Assertions (RFC 7523) at the token endpoint. +func (app *OAuthApplication) UsesPrivateKeyJWT() bool { + return app.TokenEndpointAuthMethod == TokenEndpointAuthPrivateKeyJWT +} + +// UsesClientSecret reports whether the client authenticates using a shared +// secret (either HTTP Basic or form-body). Public clients and private_key_jwt +// clients return false. +func (app *OAuthApplication) UsesClientSecret() bool { + return app.TokenEndpointAuthMethod == TokenEndpointAuthClientSecretBasic || + app.TokenEndpointAuthMethod == TokenEndpointAuthClientSecretPost +} + +// ValidateKeyMaterial verifies that a private_key_jwt client has exactly one of +// JWKSURI or JWKS set, and that the signing algorithm is supported. For other +// auth methods, it verifies no key material is present. +func (app *OAuthApplication) ValidateKeyMaterial() error { + if !app.UsesPrivateKeyJWT() { + if app.JWKSURI != "" || app.JWKS != "" || app.TokenEndpointAuthSigningAlg != "" { + return errors.New( + "JWKS, JWKS URI, and signing algorithm must be empty when token_endpoint_auth_method is not private_key_jwt", + ) + } + return nil + } + + // private_key_jwt validation + hasURI := strings.TrimSpace(app.JWKSURI) != "" + hasInline := strings.TrimSpace(app.JWKS) != "" + if !hasURI && !hasInline { + return errors.New("private_key_jwt requires either jwks_uri or jwks to be provided") + } + if hasURI && hasInline { + return errors.New("private_key_jwt requires jwks_uri and jwks to be mutually exclusive") + } + + switch app.TokenEndpointAuthSigningAlg { + case AssertionAlgRS256, AssertionAlgES256: + return nil + case "": + return errors.New( + "private_key_jwt requires token_endpoint_auth_signing_alg to be set (RS256 or ES256)", + ) + default: + return errors.New( + "unsupported token_endpoint_auth_signing_alg: only RS256 and ES256 are supported", + ) + } +} diff --git a/internal/services/client.go b/internal/services/client.go index 38a16d5..91a7bb8 100644 --- a/internal/services/client.go +++ b/internal/services/client.go @@ -63,8 +63,60 @@ var ( ErrInvalidClientStatus = errors.New( "status must be \"active\", \"inactive\", or \"pending\"", ) + ErrPrivateKeyJWTRequiresConfidential = errors.New( + "private_key_jwt requires a confidential client", + ) + ErrInvalidTokenEndpointAuthMethod = errors.New( + "invalid token_endpoint_auth_method", + ) ) +// validTokenEndpointAuthMethod reports whether m is one of the recognised +// RFC 7591 §2 values AuthGate supports. +func validTokenEndpointAuthMethod(m string) bool { + switch m { + case models.TokenEndpointAuthNone, + models.TokenEndpointAuthClientSecretBasic, + models.TokenEndpointAuthClientSecretPost, + models.TokenEndpointAuthPrivateKeyJWT: + return true + } + return false +} + +// resolveTokenEndpointAuthMethod picks the default auth method for a given +// client type when the caller did not specify one. +func resolveTokenEndpointAuthMethod(method string, clientType core.ClientType) string { + if method != "" { + return method + } + if clientType == core.ClientTypePublic { + return models.TokenEndpointAuthNone + } + return models.TokenEndpointAuthClientSecretBasic +} + +// validateInlineJWKS parses the inline JWKS JSON (if present) and verifies +// every key can be converted to a usable public key. Called on create/update +// so malformed registrations are rejected immediately, rather than failing +// opaquely at assertion-verification time. +func validateInlineJWKS(jwks string) error { + jwks = strings.TrimSpace(jwks) + if jwks == "" { + return nil + } + set, err := util.ParseJWKSet(jwks) + if err != nil { + return err + } + for i := range set.Keys { + if _, err := set.Keys[i].ToPublicKey(); err != nil { + return fmt.Errorf("jwks[%d]: %w", i, err) + } + } + return nil +} + // validateRedirectURIs checks that every URI in the slice is an absolute http/https // URI without a fragment, as required by RFC 6749. func validateRedirectURIs(uris []string) error { @@ -143,6 +195,14 @@ type CreateClientRequest struct { EnableAuthCodeFlow bool // Enable Authorization Code Flow (RFC 6749) EnableClientCredentialsFlow bool // Enable Client Credentials Grant (RFC 6749 §4.4); confidential clients only IsAdminCreated bool // When true: Status=active; when false: Status=pending + + // Token endpoint authentication (RFC 7591 §2). When empty, a default is + // selected based on ClientType. Setting this to "private_key_jwt" (RFC 7523) + // requires JWKSURI or JWKS plus TokenEndpointAuthSigningAlg. + TokenEndpointAuthMethod string + TokenEndpointAuthSigningAlg string // RS256 | ES256 + JWKSURI string // Mutually exclusive with JWKS + JWKS string // Inline JWK Set JSON } type UpdateClientRequest struct { @@ -155,6 +215,12 @@ type UpdateClientRequest struct { EnableDeviceFlow bool EnableAuthCodeFlow bool EnableClientCredentialsFlow bool // Enable Client Credentials Grant (RFC 6749 §4.4); confidential clients only + + // Token endpoint authentication (see CreateClientRequest). + TokenEndpointAuthMethod string + TokenEndpointAuthSigningAlg string + JWKSURI string + JWKS string } type ClientResponse struct { @@ -190,6 +256,40 @@ func (s *ClientService) CreateClient( return nil, err } + // Token endpoint authentication method (RFC 7591 §2). Enforce full + // method ↔ client type consistency so downstream code that keys off + // either field cannot disagree on a client's auth contract. + authMethod := resolveTokenEndpointAuthMethod(req.TokenEndpointAuthMethod, clientType) + if !validTokenEndpointAuthMethod(authMethod) { + return nil, ErrInvalidTokenEndpointAuthMethod + } + switch authMethod { + case models.TokenEndpointAuthNone: + if clientType != core.ClientTypePublic { + return nil, ErrInvalidTokenEndpointAuthMethod + } + case models.TokenEndpointAuthClientSecretBasic, + models.TokenEndpointAuthClientSecretPost: + if clientType != core.ClientTypeConfidential { + return nil, ErrInvalidTokenEndpointAuthMethod + } + case models.TokenEndpointAuthPrivateKeyJWT: + if clientType != core.ClientTypeConfidential { + return nil, ErrPrivateKeyJWTRequiresConfidential + } + // Only client_credentials and introspection currently authenticate + // via the shared ClientAuthenticator. Enabling other grants on a + // private_key_jwt client would produce a client that can register + // but cannot actually exchange codes or refresh tokens, because + // those paths still expect a shared secret. + if req.EnableAuthCodeFlow || req.EnableDeviceFlow { + return nil, fmt.Errorf( + "%w: private_key_jwt is currently supported only for the client_credentials grant; disable authorization_code and device_code flows", + ErrInvalidClientData, + ) + } + } + // Generate client ID clientID := uuid.New().String() @@ -232,12 +332,28 @@ func (s *ClientService) CreateClient( EnableClientCredentialsFlow: enableClientCredentials, Status: clientStatus, CreatedBy: req.CreatedBy, + TokenEndpointAuthMethod: authMethod, + TokenEndpointAuthSigningAlg: req.TokenEndpointAuthSigningAlg, + JWKSURI: strings.TrimSpace(req.JWKSURI), + JWKS: strings.TrimSpace(req.JWKS), } - // Generate client secret - clientSecret, err := client.GenerateClientSecret(ctx) - if err != nil { - return nil, err + if err := client.ValidateKeyMaterial(); err != nil { + return nil, fmt.Errorf("%w: %s", ErrInvalidClientData, err.Error()) + } + if err := validateInlineJWKS(client.JWKS); err != nil { + return nil, fmt.Errorf("%w: invalid jwks: %s", ErrInvalidClientData, err.Error()) + } + + // Generate a shared secret only for the two client_secret_* auth methods. + // Public (none) and private_key_jwt clients do not have a secret. + var clientSecret string + if client.UsesClientSecret() { + var err error + clientSecret, err = client.GenerateClientSecret(ctx) + if err != nil { + return nil, err + } } if err := s.store.CreateClient(client); err != nil { @@ -333,6 +449,67 @@ func (s *ClientService) UpdateClient( enableClientCredentials, ) + // Token endpoint authentication fields are atomic: when the caller + // specifies a method they must also provide its key material, and when + // they omit the method the existing configuration is preserved. This + // prevents forms that don't surface the new fields (e.g. the legacy + // admin UI) from silently wiping a private_key_jwt client's JWKS. + if req.TokenEndpointAuthMethod != "" { + if !validTokenEndpointAuthMethod(req.TokenEndpointAuthMethod) { + return ErrInvalidTokenEndpointAuthMethod + } + switch req.TokenEndpointAuthMethod { + case models.TokenEndpointAuthNone: + if clientType != core.ClientTypePublic { + return ErrInvalidTokenEndpointAuthMethod + } + case models.TokenEndpointAuthClientSecretBasic, + models.TokenEndpointAuthClientSecretPost: + if clientType != core.ClientTypeConfidential { + return ErrInvalidTokenEndpointAuthMethod + } + // Switching to client_secret_* from a method that never stored + // a secret (private_key_jwt or none) would leave the client + // unauthenticatable. Require an explicit RegenerateSecret call + // to mint one so the operator receives the new plaintext. + if client.ClientSecret == "" { + return fmt.Errorf( + "%w: switching to %s requires generating a new secret first (use RegenerateSecret)", + ErrInvalidClientData, + req.TokenEndpointAuthMethod, + ) + } + case models.TokenEndpointAuthPrivateKeyJWT: + if clientType != core.ClientTypeConfidential { + return ErrPrivateKeyJWTRequiresConfidential + } + // See the matching guard in CreateClient — authorization_code and + // device_code still expect a shared secret, so enabling them on a + // private_key_jwt client produces an unusable configuration. + if req.EnableAuthCodeFlow || req.EnableDeviceFlow { + return fmt.Errorf( + "%w: private_key_jwt is currently supported only for the client_credentials grant; disable authorization_code and device_code flows", + ErrInvalidClientData, + ) + } + } + client.TokenEndpointAuthMethod = req.TokenEndpointAuthMethod + client.TokenEndpointAuthSigningAlg = req.TokenEndpointAuthSigningAlg + client.JWKSURI = strings.TrimSpace(req.JWKSURI) + client.JWKS = strings.TrimSpace(req.JWKS) + } + // Clear the shared secret when switching away from client_secret_* methods, + // so a stale hash cannot authenticate a reconfigured client. + if !client.UsesClientSecret() { + client.ClientSecret = "" + } + if err := client.ValidateKeyMaterial(); err != nil { + return fmt.Errorf("%w: %s", ErrInvalidClientData, err.Error()) + } + if err := validateInlineJWKS(client.JWKS); err != nil { + return fmt.Errorf("%w: invalid jwks: %s", ErrInvalidClientData, err.Error()) + } + err = s.store.UpdateClient(client) if err != nil { return err diff --git a/internal/services/client_assertion.go b/internal/services/client_assertion.go new file mode 100644 index 0000000..bf61656 --- /dev/null +++ b/internal/services/client_assertion.go @@ -0,0 +1,405 @@ +package services + +import ( + "context" + "errors" + "fmt" + "log" + "slices" + "strings" + "sync" + "time" + + "github.com/go-authgate/authgate/internal/cache" + "github.com/go-authgate/authgate/internal/core" + "github.com/go-authgate/authgate/internal/models" + "github.com/go-authgate/authgate/internal/util" + + "github.com/golang-jwt/jwt/v5" +) + +// AssertionType is the sole value allowed for the client_assertion_type parameter +// when using JWT Bearer Assertions (RFC 7523 §2.2). +const AssertionType = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" + +// jtiCacheKeyPrefix namespaces the per-jti replay protection entries so they +// can coexist with other keys in a shared cache backend. +const jtiCacheKeyPrefix = "pkjwt:jti:" + +// Errors returned by ClientAssertionVerifier. All are presented to the caller as +// OAuth invalid_client — the distinction is useful for audit logs and tests. +var ( + ErrAssertionFeatureDisabled = errors.New("private_key_jwt is disabled") + ErrAssertionTypeInvalid = errors.New("invalid client_assertion_type") + ErrAssertionMalformed = errors.New("malformed client_assertion") + ErrAssertionIssuerMismatch = errors.New("client_assertion iss/sub mismatch") + ErrAssertionClientUnknown = errors.New( + "client_assertion issuer is not a registered client", + ) + ErrAssertionClientInactive = errors.New("client is not active") + ErrAssertionMethodNotAllowed = errors.New("client is not configured for private_key_jwt") + ErrAssertionKeyLookup = errors.New("unable to resolve client signing key") + ErrAssertionSignatureInvalid = errors.New("client_assertion signature is invalid") + ErrAssertionAlgorithmMismatch = errors.New( + "client_assertion algorithm does not match client registration", + ) + ErrAssertionAudienceInvalid = errors.New("client_assertion audience is invalid") + ErrAssertionExpired = errors.New("client_assertion is expired") + ErrAssertionNotYetValid = errors.New("client_assertion is not yet valid") + ErrAssertionLifetimeTooLong = errors.New("client_assertion lifetime exceeds server maximum") + ErrAssertionInvalidTimeWindow = errors.New("client_assertion exp must be strictly after iat") + ErrAssertionMissingJTI = errors.New("client_assertion is missing jti") + ErrAssertionJTIReplay = errors.New("client_assertion jti was already used") + ErrAssertionJTICacheUnavailable = errors.New("client_assertion jti replay cache unavailable") + ErrAssertionMissingRequiredTime = errors.New("client_assertion is missing required time claims") +) + +// ClientAssertionConfig controls the verifier's behaviour. All durations are +// positive; the caller is responsible for providing sensible defaults. +type ClientAssertionConfig struct { + Enabled bool + ExpectedAudiences []string // at least one must be present in the aud claim + MaxLifetime time.Duration + ClockSkew time.Duration +} + +// ClientAssertionVerifier validates JWT Bearer Assertions presented as +// client_assertion at the token endpoint (RFC 7523). +type ClientAssertionVerifier struct { + clientService *ClientService + jwksFetcher *JWKSFetcher + jtiCache core.Cache[bool] + auditService core.AuditLogger + cfg ClientAssertionConfig + + // jtiLocks shards the jti replay Get+Set critical section per client to + // keep honest traffic free of cross-client contention while still closing + // the TOCTOU window for concurrent requests carrying the same jti. + jtiLocks sync.Map // map[string]*sync.Mutex, keyed by client_id +} + +// NewClientAssertionVerifier wires the verifier. auditService may be nil (no-op). +// jtiCache must be supplied — it is required for RFC 7523 §3 replay prevention. +func NewClientAssertionVerifier( + clientService *ClientService, + jwksFetcher *JWKSFetcher, + jtiCache core.Cache[bool], + auditService core.AuditLogger, + cfg ClientAssertionConfig, +) *ClientAssertionVerifier { + if auditService == nil { + auditService = NewNoopAuditService() + } + if cfg.MaxLifetime <= 0 { + cfg.MaxLifetime = 5 * time.Minute + } + if cfg.ClockSkew <= 0 { + cfg.ClockSkew = 30 * time.Second + } + return &ClientAssertionVerifier{ + clientService: clientService, + jwksFetcher: jwksFetcher, + jtiCache: jtiCache, + auditService: auditService, + cfg: cfg, + } +} + +// Verify validates the provided JWT assertion and returns the authenticated +// OAuth client. All error returns are safe to surface as OAuth invalid_client. +func (v *ClientAssertionVerifier) Verify( + ctx context.Context, + assertion, assertionType string, +) (*models.OAuthApplication, error) { + if !v.cfg.Enabled { + return nil, ErrAssertionFeatureDisabled + } + if assertionType != AssertionType { + return nil, ErrAssertionTypeInvalid + } + if strings.TrimSpace(assertion) == "" { + return nil, ErrAssertionMalformed + } + + parser := jwt.NewParser(jwt.WithoutClaimsValidation()) + tok, _, err := parser.ParseUnverified(assertion, jwt.MapClaims{}) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrAssertionMalformed, err) + } + claims, ok := tok.Claims.(jwt.MapClaims) + if !ok { + return nil, ErrAssertionMalformed + } + + iss, _ := claims["iss"].(string) + sub, _ := claims["sub"].(string) + if iss == "" || sub == "" || iss != sub { + v.logFailure(ctx, iss, ErrAssertionIssuerMismatch.Error()) + return nil, ErrAssertionIssuerMismatch + } + + client, err := v.clientService.GetClient(ctx, iss) + if err != nil { + v.logFailure(ctx, iss, "client lookup failed") + if errors.Is(err, ErrClientNotFound) { + return nil, ErrAssertionClientUnknown + } + return nil, ErrAssertionClientUnknown + } + if !client.IsActive() { + v.logFailure(ctx, iss, "client inactive") + return nil, ErrAssertionClientInactive + } + if !client.UsesPrivateKeyJWT() { + v.logFailure(ctx, iss, "client not configured for private_key_jwt") + return nil, ErrAssertionMethodNotAllowed + } + + // Algorithm must match the one registered with the client. + if tok.Method.Alg() != client.TokenEndpointAuthSigningAlg { + v.logFailure(ctx, iss, fmt.Sprintf( + "algorithm mismatch: header=%s registered=%s", + tok.Method.Alg(), client.TokenEndpointAuthSigningAlg, + )) + return nil, ErrAssertionAlgorithmMismatch + } + + kid, _ := tok.Header["kid"].(string) + jwkSet, err := v.resolveJWKS(ctx, client, kid) + if err != nil { + v.logFailure(ctx, iss, err.Error()) + return nil, fmt.Errorf("%w: %v", ErrAssertionKeyLookup, err) + } + jwk := jwkSet.FindByKid(kid) + if jwk == nil { + v.logFailure(ctx, iss, fmt.Sprintf("no matching JWK for kid=%q", kid)) + return nil, fmt.Errorf("%w: no matching kid", ErrAssertionKeyLookup) + } + pubKey, err := jwk.ToPublicKey() + if err != nil { + v.logFailure(ctx, iss, fmt.Sprintf("public key decode failed: %v", err)) + return nil, fmt.Errorf("%w: %v", ErrAssertionKeyLookup, err) + } + + // Verify signature with strict algorithm enforcement. Library-side claim + // validation is disabled so our custom skew and lifetime caps apply below. + verifyParser := jwt.NewParser( + jwt.WithValidMethods([]string{client.TokenEndpointAuthSigningAlg}), + jwt.WithoutClaimsValidation(), + ) + if _, err := verifyParser.Parse(assertion, func(_ *jwt.Token) (any, error) { + return pubKey, nil + }); err != nil { + v.logFailure(ctx, iss, fmt.Sprintf("signature verification failed: %v", err)) + return nil, ErrAssertionSignatureInvalid + } + + if err := v.validateTimeClaims(claims); err != nil { + v.logFailure(ctx, iss, err.Error()) + return nil, err + } + if err := v.validateAudience(claims); err != nil { + v.logFailure(ctx, iss, err.Error()) + return nil, err + } + if err := v.checkJTIReplay(ctx, iss, claims); err != nil { + v.logFailure(ctx, iss, err.Error()) + return nil, err + } + + v.logSuccess(ctx, client) + return client, nil +} + +func (v *ClientAssertionVerifier) resolveJWKS( + ctx context.Context, + client *models.OAuthApplication, + kid string, +) (*util.JWKSet, error) { + if client.JWKS != "" { + set, err := util.ParseJWKSet(client.JWKS) + if err != nil { + return nil, fmt.Errorf("parse inline JWKS: %w", err) + } + return set, nil + } + if client.JWKSURI == "" { + return nil, errors.New("client has no JWKS configured") + } + if v.jwksFetcher == nil { + return nil, errors.New("JWKS fetcher not configured") + } + return v.jwksFetcher.GetWithRefresh(ctx, client.JWKSURI, kid) +} + +func (v *ClientAssertionVerifier) validateTimeClaims(claims jwt.MapClaims) error { + now := time.Now() + skew := v.cfg.ClockSkew + + expF, ok := claims["exp"].(float64) + if !ok { + return ErrAssertionMissingRequiredTime + } + iatF, ok := claims["iat"].(float64) + if !ok { + return ErrAssertionMissingRequiredTime + } + exp := time.Unix(int64(expF), 0) + iat := time.Unix(int64(iatF), 0) + + if now.After(exp.Add(skew)) { + return ErrAssertionExpired + } + if iat.Sub(now) > skew { + return ErrAssertionNotYetValid + } + if nbfF, ok := claims["nbf"].(float64); ok { + nbf := time.Unix(int64(nbfF), 0) + if nbf.Sub(now) > skew { + return ErrAssertionNotYetValid + } + } + // exp must be strictly after iat: a zero or negative lifetime is + // nonsensical and would otherwise pass the MaxLifetime bound below. + if !exp.After(iat) { + return ErrAssertionInvalidTimeWindow + } + if exp.Sub(iat) > v.cfg.MaxLifetime { + return ErrAssertionLifetimeTooLong + } + return nil +} + +func (v *ClientAssertionVerifier) validateAudience(claims jwt.MapClaims) error { + raw, ok := claims["aud"] + if !ok { + return fmt.Errorf("%w: missing aud", ErrAssertionAudienceInvalid) + } + audValues := extractAudienceValues(raw) + if len(audValues) == 0 { + return fmt.Errorf("%w: empty aud", ErrAssertionAudienceInvalid) + } + for _, expected := range v.cfg.ExpectedAudiences { + if expected == "" { + continue + } + if slices.Contains(audValues, expected) { + return nil + } + } + return fmt.Errorf("%w: aud %v not accepted", ErrAssertionAudienceInvalid, audValues) +} + +func extractAudienceValues(raw any) []string { + switch v := raw.(type) { + case string: + return []string{v} + case []any: + out := make([]string, 0, len(v)) + for _, item := range v { + if s, ok := item.(string); ok && s != "" { + out = append(out, s) + } + } + return out + case []string: + return v + default: + return nil + } +} + +func (v *ClientAssertionVerifier) checkJTIReplay( + ctx context.Context, + clientID string, + claims jwt.MapClaims, +) error { + jti, _ := claims["jti"].(string) + if strings.TrimSpace(jti) == "" { + return ErrAssertionMissingJTI + } + if v.jtiCache == nil { + // Without a cache we cannot prevent replay. Fail closed using the + // dedicated sentinel so callers/tests can distinguish this from an + // actual replay via errors.Is. + return ErrAssertionJTICacheUnavailable + } + key := jtiCacheKeyPrefix + clientID + ":" + jti + + // Serialise the Get+Set pair per client so two concurrent requests + // carrying the same jti cannot both observe a cache miss before either + // Set lands. Sharding by client keeps honest high-throughput traffic + // free of cross-client contention. + lockIface, _ := v.jtiLocks.LoadOrStore(clientID, &sync.Mutex{}) + lock := lockIface.(*sync.Mutex) + lock.Lock() + defer lock.Unlock() + + switch _, err := v.jtiCache.Get(ctx, key); { + case err == nil: + return ErrAssertionJTIReplay + case errors.Is(err, cache.ErrCacheMiss): + // expected — the jti has not been seen; fall through to record it. + default: + // Backend error (e.g. Redis unavailable). Fail closed so we do not + // silently accept replays while the cache is degraded. Use a + // distinct error so audit logs don't misreport a cache outage as + // a replay attempt. + log.Printf("[ClientAssertion] jti cache lookup failed: %v", err) + return ErrAssertionJTICacheUnavailable + } + // TTL = remaining assertion lifetime + clock skew. If exp is absent, + // fall back to MaxLifetime (defensive). + ttl := v.cfg.MaxLifetime + v.cfg.ClockSkew + if expF, ok := claims["exp"].(float64); ok { + remaining := time.Until(time.Unix(int64(expF), 0)) + v.cfg.ClockSkew + if remaining > 0 { + ttl = remaining + } + } + if err := v.jtiCache.Set(ctx, key, true, ttl); err != nil { + // Set failure after a miss: reject the assertion rather than + // silently skipping replay tracking. Use the dedicated cache- + // unavailable error so audit logs are accurate. + log.Printf("[ClientAssertion] failed to record jti %s: %v", jti, err) + return ErrAssertionJTICacheUnavailable + } + return nil +} + +func (v *ClientAssertionVerifier) logSuccess( + ctx context.Context, + client *models.OAuthApplication, +) { + v.auditService.Log(ctx, core.AuditLogEntry{ + EventType: models.EventClientAssertionVerified, + Severity: models.SeverityInfo, + ActorUserID: "client:" + client.ClientID, + ResourceType: models.ResourceClient, + ResourceID: client.ClientID, + ResourceName: client.ClientName, + Action: "client_assertion verified", + Details: models.AuditDetails{ + "signing_alg": client.TokenEndpointAuthSigningAlg, + }, + Success: true, + }) +} + +func (v *ClientAssertionVerifier) logFailure( + ctx context.Context, + issuer, reason string, +) { + v.auditService.Log(ctx, core.AuditLogEntry{ + EventType: models.EventClientAssertionFailed, + Severity: models.SeverityWarning, + ActorUserID: "client:" + issuer, + ResourceType: models.ResourceClient, + ResourceID: issuer, + Action: "client_assertion rejected", + Details: models.AuditDetails{ + "reason": reason, + }, + Success: false, + }) +} diff --git a/internal/services/client_assertion_test.go b/internal/services/client_assertion_test.go new file mode 100644 index 0000000..7ba2b09 --- /dev/null +++ b/internal/services/client_assertion_test.go @@ -0,0 +1,435 @@ +package services + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "math/big" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-authgate/authgate/internal/cache" + "github.com/go-authgate/authgate/internal/core" + "github.com/go-authgate/authgate/internal/models" + "github.com/go-authgate/authgate/internal/store" + "github.com/go-authgate/authgate/internal/util" + + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +// --- helpers ------------------------------------------------------------ + +const testAudience = "https://authgate.test/oauth/token" + +type rsaFixture struct { + priv *rsa.PrivateKey + jwk util.JWK +} + +func newRSAFixture(t *testing.T, kid string) rsaFixture { + t.Helper() + priv, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + return rsaFixture{ + priv: priv, + jwk: util.JWK{ + Kty: "RSA", + Use: "sig", + Kid: kid, + Alg: "RS256", + N: base64.RawURLEncoding.EncodeToString(priv.PublicKey.N.Bytes()), + E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(priv.PublicKey.E)).Bytes()), + }, + } +} + +type ecFixture struct { + priv *ecdsa.PrivateKey + jwk util.JWK +} + +func newECFixture(t *testing.T, kid string) ecFixture { + t.Helper() + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + byteLen := 32 + xBytes := make([]byte, byteLen) + yBytes := make([]byte, byteLen) + copy(xBytes[byteLen-len(priv.PublicKey.X.Bytes()):], priv.PublicKey.X.Bytes()) + copy(yBytes[byteLen-len(priv.PublicKey.Y.Bytes()):], priv.PublicKey.Y.Bytes()) + return ecFixture{ + priv: priv, + jwk: util.JWK{ + Kty: "EC", + Use: "sig", + Kid: kid, + Alg: "ES256", + Crv: "P-256", + X: base64.RawURLEncoding.EncodeToString(xBytes), + Y: base64.RawURLEncoding.EncodeToString(yBytes), + }, + } +} + +func signRS256(t *testing.T, priv *rsa.PrivateKey, kid string, claims jwt.MapClaims) string { + t.Helper() + tok := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tok.Header["kid"] = kid + out, err := tok.SignedString(priv) + require.NoError(t, err) + return out +} + +func signES256(t *testing.T, priv *ecdsa.PrivateKey, kid string, claims jwt.MapClaims) string { + t.Helper() + tok := jwt.NewWithClaims(jwt.SigningMethodES256, claims) + tok.Header["kid"] = kid + out, err := tok.SignedString(priv) + require.NoError(t, err) + return out +} + +func signHS256(t *testing.T, secret []byte, claims jwt.MapClaims) string { + t.Helper() + tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + out, err := tok.SignedString(secret) + require.NoError(t, err) + return out +} + +type verifierFixture struct { + verifier *ClientAssertionVerifier + store *store.Store + cs *ClientService +} + +func newVerifierFixture(t *testing.T) *verifierFixture { + t.Helper() + s := setupTestStore(t) + cs := NewClientService(s, nil, nil, 0, nil, 0) + fetcher := NewJWKSFetcher(cache.NewMemoryCache[util.JWKSet](0), 2*time.Second, time.Minute) + jtiCache := cache.NewMemoryCache[bool](0) + t.Cleanup(func() { _ = jtiCache.Close() }) + v := NewClientAssertionVerifier(cs, fetcher, jtiCache, NewNoopAuditService(), + ClientAssertionConfig{ + Enabled: true, + ExpectedAudiences: []string{testAudience}, + MaxLifetime: 5 * time.Minute, + ClockSkew: 30 * time.Second, + }) + return &verifierFixture{verifier: v, store: s, cs: cs} +} + +func seedRSAClient( + t *testing.T, + s *store.Store, + jwk util.JWK, + alg string, +) *models.OAuthApplication { + t.Helper() + set := util.JWKSet{Keys: []util.JWK{jwk}} + blob, err := json.Marshal(set) + require.NoError(t, err) + client := &models.OAuthApplication{ + ClientID: uuid.New().String(), + ClientName: "pkjwt-client", + UserID: uuid.New().String(), + Scopes: "read write", + GrantTypes: "client_credentials", + ClientType: core.ClientTypeConfidential.String(), + EnableClientCredentialsFlow: true, + Status: models.ClientStatusActive, + TokenEndpointAuthMethod: models.TokenEndpointAuthPrivateKeyJWT, + TokenEndpointAuthSigningAlg: alg, + JWKS: string(blob), + } + require.NoError(t, s.CreateClient(client)) + return client +} + +func baseClaims(clientID string) jwt.MapClaims { + now := time.Now() + return jwt.MapClaims{ + "iss": clientID, + "sub": clientID, + "aud": testAudience, + "iat": now.Unix(), + "exp": now.Add(time.Minute).Unix(), + "jti": uuid.NewString(), + } +} + +// --- success paths ------------------------------------------------------ + +func TestClientAssertion_RS256_Success(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "k1") + client := seedRSAClient(t, f.store, rsa.jwk, "RS256") + + token := signRS256(t, rsa.priv, "k1", baseClaims(client.ClientID)) + got, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.NoError(t, err) + require.Equal(t, client.ClientID, got.ClientID) +} + +func TestClientAssertion_ES256_Success(t *testing.T) { + f := newVerifierFixture(t) + ec := newECFixture(t, "ec1") + set := util.JWKSet{Keys: []util.JWK{ec.jwk}} + blob, err := json.Marshal(set) + require.NoError(t, err) + + client := &models.OAuthApplication{ + ClientID: uuid.New().String(), + ClientName: "pkjwt-ec", + UserID: uuid.New().String(), + Scopes: "read", + GrantTypes: "client_credentials", + ClientType: core.ClientTypeConfidential.String(), + EnableClientCredentialsFlow: true, + Status: models.ClientStatusActive, + TokenEndpointAuthMethod: models.TokenEndpointAuthPrivateKeyJWT, + TokenEndpointAuthSigningAlg: "ES256", + JWKS: string(blob), + } + require.NoError(t, f.store.CreateClient(client)) + + token := signES256(t, ec.priv, "ec1", baseClaims(client.ClientID)) + got, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.NoError(t, err) + require.Equal(t, client.ClientID, got.ClientID) +} + +func TestClientAssertion_JWKSURI_Success(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "k1") + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + blob, _ := json.Marshal(util.JWKSet{Keys: []util.JWK{rsa.jwk}}) + _, _ = w.Write(blob) + })) + defer srv.Close() + + client := &models.OAuthApplication{ + ClientID: uuid.New().String(), + ClientName: "pkjwt-uri", + UserID: uuid.New().String(), + Scopes: "read", + GrantTypes: "client_credentials", + ClientType: core.ClientTypeConfidential.String(), + EnableClientCredentialsFlow: true, + Status: models.ClientStatusActive, + TokenEndpointAuthMethod: models.TokenEndpointAuthPrivateKeyJWT, + TokenEndpointAuthSigningAlg: "RS256", + JWKSURI: srv.URL, + } + require.NoError(t, f.store.CreateClient(client)) + + token := signRS256(t, rsa.priv, "k1", baseClaims(client.ClientID)) + got, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.NoError(t, err) + require.Equal(t, client.ClientID, got.ClientID) +} + +// --- failure paths ------------------------------------------------------ + +func TestClientAssertion_FeatureDisabled(t *testing.T) { + f := newVerifierFixture(t) + f.verifier.cfg.Enabled = false + _, err := f.verifier.Verify(context.Background(), "anything", AssertionType) + require.ErrorIs(t, err, ErrAssertionFeatureDisabled) +} + +func TestClientAssertion_WrongAssertionType(t *testing.T) { + f := newVerifierFixture(t) + _, err := f.verifier.Verify(context.Background(), "anything", "urn:other") + require.ErrorIs(t, err, ErrAssertionTypeInvalid) +} + +func TestClientAssertion_Malformed(t *testing.T) { + f := newVerifierFixture(t) + _, err := f.verifier.Verify(context.Background(), "not.a.jwt", AssertionType) + require.ErrorIs(t, err, ErrAssertionMalformed) +} + +func TestClientAssertion_IssSubMismatch(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "k1") + client := seedRSAClient(t, f.store, rsa.jwk, "RS256") + + claims := baseClaims(client.ClientID) + claims["sub"] = "different" + token := signRS256(t, rsa.priv, "k1", claims) + _, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.ErrorIs(t, err, ErrAssertionIssuerMismatch) +} + +func TestClientAssertion_UnknownClient(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "k1") + claims := baseClaims("ghost-client") + token := signRS256(t, rsa.priv, "k1", claims) + _, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.ErrorIs(t, err, ErrAssertionClientUnknown) +} + +func TestClientAssertion_MethodNotAllowed(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "k1") + // Seed a client that registered as client_secret_basic, not private_key_jwt. + client := &models.OAuthApplication{ + ClientID: uuid.New().String(), + ClientName: "secret-client", + UserID: uuid.New().String(), + Scopes: "read", + ClientType: core.ClientTypeConfidential.String(), + EnableClientCredentialsFlow: true, + Status: models.ClientStatusActive, + TokenEndpointAuthMethod: models.TokenEndpointAuthClientSecretBasic, + } + require.NoError(t, f.store.CreateClient(client)) + + token := signRS256(t, rsa.priv, "k1", baseClaims(client.ClientID)) + _, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.ErrorIs(t, err, ErrAssertionMethodNotAllowed) +} + +func TestClientAssertion_AlgorithmMismatch(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "k1") + // Register the client as ES256, but sign as RS256 — mismatch. + client := seedRSAClient(t, f.store, rsa.jwk, "ES256") + + token := signRS256(t, rsa.priv, "k1", baseClaims(client.ClientID)) + _, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.ErrorIs(t, err, ErrAssertionAlgorithmMismatch) +} + +func TestClientAssertion_HS256Rejected(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "k1") + client := seedRSAClient(t, f.store, rsa.jwk, "RS256") + + // Try to sign with HS256 — algorithm mismatch should reject it regardless. + token := signHS256(t, []byte("shared-secret"), baseClaims(client.ClientID)) + _, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.ErrorIs(t, err, ErrAssertionAlgorithmMismatch) +} + +func TestClientAssertion_BadSignature(t *testing.T) { + f := newVerifierFixture(t) + rsa1 := newRSAFixture(t, "k1") + rsa2 := newRSAFixture(t, "k1") + // Register rsa1's public key, but sign with rsa2's private key. + client := seedRSAClient(t, f.store, rsa1.jwk, "RS256") + + token := signRS256(t, rsa2.priv, "k1", baseClaims(client.ClientID)) + _, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.ErrorIs(t, err, ErrAssertionSignatureInvalid) +} + +func TestClientAssertion_WrongAudience(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "k1") + client := seedRSAClient(t, f.store, rsa.jwk, "RS256") + + claims := baseClaims(client.ClientID) + claims["aud"] = "https://other.example.com/token" + token := signRS256(t, rsa.priv, "k1", claims) + _, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.ErrorIs(t, err, ErrAssertionAudienceInvalid) +} + +func TestClientAssertion_Expired(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "k1") + client := seedRSAClient(t, f.store, rsa.jwk, "RS256") + + claims := baseClaims(client.ClientID) + claims["iat"] = time.Now().Add(-2 * time.Minute).Unix() + claims["exp"] = time.Now().Add(-time.Minute).Unix() + token := signRS256(t, rsa.priv, "k1", claims) + _, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.ErrorIs(t, err, ErrAssertionExpired) +} + +func TestClientAssertion_LifetimeTooLong(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "k1") + client := seedRSAClient(t, f.store, rsa.jwk, "RS256") + + now := time.Now() + claims := jwt.MapClaims{ + "iss": client.ClientID, + "sub": client.ClientID, + "aud": testAudience, + "iat": now.Unix(), + "exp": now.Add(1 * time.Hour).Unix(), // exceeds MaxLifetime=5m + "jti": uuid.NewString(), + } + token := signRS256(t, rsa.priv, "k1", claims) + _, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.ErrorIs(t, err, ErrAssertionLifetimeTooLong) +} + +func TestClientAssertion_MissingJTI(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "k1") + client := seedRSAClient(t, f.store, rsa.jwk, "RS256") + + claims := baseClaims(client.ClientID) + delete(claims, "jti") + token := signRS256(t, rsa.priv, "k1", claims) + _, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.ErrorIs(t, err, ErrAssertionMissingJTI) +} + +func TestClientAssertion_JTIReplay(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "k1") + client := seedRSAClient(t, f.store, rsa.jwk, "RS256") + + claims := baseClaims(client.ClientID) + token := signRS256(t, rsa.priv, "k1", claims) + + // First use: accepted. + _, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.NoError(t, err) + // Second use of the same token: rejected as replay. + _, err = f.verifier.Verify(context.Background(), token, AssertionType) + require.ErrorIs(t, err, ErrAssertionJTIReplay) +} + +func TestClientAssertion_UnknownKid(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "registered") + client := seedRSAClient(t, f.store, rsa.jwk, "RS256") + + // Sign with a different kid — no matching JWK. + token := signRS256(t, rsa.priv, "mystery", baseClaims(client.ClientID)) + _, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.ErrorIs(t, err, ErrAssertionKeyLookup) +} + +func TestClientAssertion_InactiveClient(t *testing.T) { + f := newVerifierFixture(t) + rsa := newRSAFixture(t, "k1") + client := seedRSAClient(t, f.store, rsa.jwk, "RS256") + // Disable the client after registration. + client.Status = models.ClientStatusInactive + require.NoError(t, f.store.UpdateClient(client)) + // Clear any cached version. + _ = f.cs.clientCache.Delete(context.Background(), client.ClientID) + + token := signRS256(t, rsa.priv, "k1", baseClaims(client.ClientID)) + _, err := f.verifier.Verify(context.Background(), token, AssertionType) + require.ErrorIs(t, err, ErrAssertionClientInactive) +} diff --git a/internal/services/jwks_fetcher.go b/internal/services/jwks_fetcher.go new file mode 100644 index 0000000..091dd9a --- /dev/null +++ b/internal/services/jwks_fetcher.go @@ -0,0 +1,205 @@ +package services + +import ( + "context" + "errors" + "fmt" + "io" + "log" + "net/http" + "sync" + "time" + + "github.com/go-authgate/authgate/internal/core" + "github.com/go-authgate/authgate/internal/util" +) + +// Errors returned by JWKSFetcher. +var ( + ErrJWKSFetchFailed = errors.New("failed to fetch JWKS") + ErrJWKSTooLarge = errors.New("JWKS response exceeds size limit") +) + +// Maximum JWKS body size accepted over the network (defense against slowloris / huge responses). +const jwksMaxBodyBytes = 1 << 20 // 1 MiB + +// Minimum interval between forced refreshes of the same JWKS URI. +// Prevents attackers from triggering unbounded refetches by sending assertions +// with unknown kids — legitimate rotations still succeed after this cooldown. +const jwksRefreshCooldown = 30 * time.Second + +// JWKSFetcher retrieves and caches JWK Sets from remote jwks_uri endpoints. +// It is safe for concurrent use. +type JWKSFetcher struct { + httpClient *http.Client + cache core.Cache[util.JWKSet] + ttl time.Duration + + // lastRefresh tracks when each uri was last force-refreshed so kid-miss + // driven refetches respect jwksRefreshCooldown. + lastRefresh sync.Map // map[string]time.Time +} + +// NewJWKSFetcher constructs a JWKSFetcher. +// timeout controls the HTTP request timeout; ttl controls the cache lifetime +// when a cache is provided. cache may be nil, in which case every call goes +// straight to the remote — callers that need kid-miss refresh semantics +// should supply a non-nil cache. +func NewJWKSFetcher(cache core.Cache[util.JWKSet], timeout, ttl time.Duration) *JWKSFetcher { + if timeout <= 0 { + timeout = 10 * time.Second + } + if ttl <= 0 { + ttl = time.Hour + } + return &JWKSFetcher{ + httpClient: &http.Client{Timeout: timeout}, + cache: cache, + ttl: ttl, + } +} + +// Get returns the cached JWKS for uri, fetching it on cache miss. Use +// GetWithRefresh when verifying a signature against a specific kid so that a +// cache stale against a rotated signing key triggers a fresh fetch. +func (f *JWKSFetcher) Get(ctx context.Context, uri string) (*util.JWKSet, error) { + return f.getCached(ctx, uri) +} + +// GetWithRefresh returns the JWKS for uri. If the cached version does not +// contain a key with the requested kid, the cache is bypassed and the +// document is refetched — this supports runtime key rotation without waiting +// for the TTL to expire. Refreshes for the same uri are rate-limited to +// prevent malformed or attacker-crafted kids from triggering unbounded +// refetches of the remote endpoint. +func (f *JWKSFetcher) GetWithRefresh(ctx context.Context, uri, kid string) (*util.JWKSet, error) { + set, err := f.getCached(ctx, uri) + if err != nil { + return nil, err + } + if kid == "" || set.FindByKid(kid) != nil { + return set, nil + } + // With no cache there is nothing to refresh — getCached already dialed + // the remote for this call, so an additional fetch is pure overhead. + if f.cache == nil { + return set, nil + } + prev, stored, ok := f.tryReserveRefresh(uri) + if !ok { + return set, nil + } + // Bypass the cache: fetch fresh directly and best-effort update the + // cache afterward. This keeps key rotation working even if the cache + // backend is degraded and Delete/Set fails — callers still see the + // freshly rotated keys. + fresh, err := f.fetch(ctx, uri) + if err != nil { + // Roll back the reservation so the next caller can retry, rather + // than being suppressed by the cooldown after a transient error. + f.rollbackRefreshReservation(uri, prev, stored) + return nil, err + } + if setErr := f.cache.Set(ctx, uri, fresh, f.ttl); setErr != nil { + log.Printf("[JWKSFetcher] cache set after refresh failed for %s: %v", uri, setErr) + // Best-effort delete: without it the stale cached JWKS would + // keep being returned until the cooldown expires, since + // subsequent callers skip refresh while lastRefresh is fresh. + if delErr := f.cache.Delete(ctx, uri); delErr != nil { + log.Printf( + "[JWKSFetcher] cache delete after failed refresh set failed for %s: %v", + uri, delErr, + ) + } + } + return &fresh, nil +} + +// tryReserveRefresh attempts to reserve a refresh for uri, subject to the +// per-URI cooldown. LoadOrStore + CompareAndSwap together ensure that at most +// one concurrent caller wins the right to refresh once the cooldown has +// elapsed. On success it returns the previous timestamp (zero when the uri +// was never refreshed) and the timestamp it stored, so rollbackRefreshReservation +// can undo the reservation if the subsequent network fetch fails. +func (f *JWKSFetcher) tryReserveRefresh(uri string) (prev, stored time.Time, ok bool) { + for { + now := time.Now() + existing, loaded := f.lastRefresh.LoadOrStore(uri, now) + if !loaded { + // First-ever reservation for this uri. + return time.Time{}, now, true + } + prevTime := existing.(time.Time) + if now.Sub(prevTime) < jwksRefreshCooldown { + return time.Time{}, time.Time{}, false + } + // CompareAndSwap ensures only one concurrent caller wins the + // post-cooldown refresh decision; losers retry the loop and fall + // back into the cooldown branch. + if f.lastRefresh.CompareAndSwap(uri, prevTime, now) { + return prevTime, now, true + } + } +} + +// rollbackRefreshReservation reverts a reservation made by tryReserveRefresh +// after the refresh fetch failed. Without this, a transient fetch error would +// strand key rotation until the cooldown elapses. +func (f *JWKSFetcher) rollbackRefreshReservation(uri string, prev, stored time.Time) { + if prev.IsZero() { + // Best-effort: if someone else updated since, leave their value alone. + f.lastRefresh.CompareAndDelete(uri, stored) + return + } + f.lastRefresh.CompareAndSwap(uri, stored, prev) +} + +func (f *JWKSFetcher) getCached(ctx context.Context, uri string) (*util.JWKSet, error) { + if f.cache == nil { + set, err := f.fetch(ctx, uri) + if err != nil { + return nil, err + } + return &set, nil + } + set, err := f.cache.GetWithFetch(ctx, uri, f.ttl, + func(ctx context.Context, _ string) (util.JWKSet, error) { + return f.fetch(ctx, uri) + }) + if err != nil { + return nil, err + } + return &set, nil +} + +func (f *JWKSFetcher) fetch(ctx context.Context, uri string) (util.JWKSet, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri, nil) + if err != nil { + return util.JWKSet{}, fmt.Errorf("%w: %v", ErrJWKSFetchFailed, err) + } + // RFC 7517 §8.5 defines application/jwk-set+json; many JWKS endpoints + // honour Accept strictly and would 406 on plain application/json alone. + req.Header.Set("Accept", "application/jwk-set+json, application/json") + resp, err := f.httpClient.Do(req) + if err != nil { + return util.JWKSet{}, fmt.Errorf("%w: %v", ErrJWKSFetchFailed, err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return util.JWKSet{}, fmt.Errorf( + "%w: unexpected status %d", ErrJWKSFetchFailed, resp.StatusCode, + ) + } + body, err := io.ReadAll(io.LimitReader(resp.Body, jwksMaxBodyBytes+1)) + if err != nil { + return util.JWKSet{}, fmt.Errorf("%w: %v", ErrJWKSFetchFailed, err) + } + if len(body) > jwksMaxBodyBytes { + return util.JWKSet{}, ErrJWKSTooLarge + } + set, err := util.ParseJWKSet(string(body)) + if err != nil { + return util.JWKSet{}, fmt.Errorf("%w: %v", ErrJWKSFetchFailed, err) + } + return *set, nil +} diff --git a/internal/services/jwks_fetcher_test.go b/internal/services/jwks_fetcher_test.go new file mode 100644 index 0000000..c2b67b7 --- /dev/null +++ b/internal/services/jwks_fetcher_test.go @@ -0,0 +1,183 @@ +package services + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "errors" + "math/big" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/go-authgate/authgate/internal/cache" + "github.com/go-authgate/authgate/internal/util" +) + +func jwkSetJSON(t *testing.T, keys []util.JWK) string { + t.Helper() + b, err := json.Marshal(util.JWKSet{Keys: keys}) + if err != nil { + t.Fatalf("marshal JWKS: %v", err) + } + return string(b) +} + +func rsaJWKFixture(t *testing.T, kid string) (util.JWK, *rsa.PrivateKey) { + t.Helper() + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + return util.JWK{ + Kty: "RSA", + Use: "sig", + Kid: kid, + Alg: "RS256", + N: base64.RawURLEncoding.EncodeToString(priv.PublicKey.N.Bytes()), + E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(priv.PublicKey.E)).Bytes()), + }, priv +} + +func TestJWKSFetcher_CacheHit(t *testing.T) { + jwk, _ := rsaJWKFixture(t, "k1") + var hits int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + atomic.AddInt32(&hits, 1) + _, _ = w.Write([]byte(jwkSetJSON(t, []util.JWK{jwk}))) + })) + defer srv.Close() + + mc := cache.NewMemoryCache[util.JWKSet](0) + defer mc.Close() + f := NewJWKSFetcher(mc, 2*time.Second, time.Minute) + + ctx := context.Background() + if _, err := f.Get(ctx, srv.URL); err != nil { + t.Fatalf("first Get: %v", err) + } + if _, err := f.Get(ctx, srv.URL); err != nil { + t.Fatalf("second Get: %v", err) + } + if n := atomic.LoadInt32(&hits); n != 1 { + t.Fatalf("expected 1 HTTP hit (cache should serve the second), got %d", n) + } +} + +func TestJWKSFetcher_RefreshOnKidMiss(t *testing.T) { + oldJWK, _ := rsaJWKFixture(t, "old") + newJWK, _ := rsaJWKFixture(t, "new") + var hits int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + n := atomic.AddInt32(&hits, 1) + if n == 1 { + _, _ = w.Write([]byte(jwkSetJSON(t, []util.JWK{oldJWK}))) + return + } + _, _ = w.Write([]byte(jwkSetJSON(t, []util.JWK{newJWK}))) + })) + defer srv.Close() + + mc := cache.NewMemoryCache[util.JWKSet](0) + defer mc.Close() + f := NewJWKSFetcher(mc, 2*time.Second, time.Minute) + + ctx := context.Background() + // Warm cache with the old JWKS. + if _, err := f.Get(ctx, srv.URL); err != nil { + t.Fatalf("warm: %v", err) + } + // Ask for a kid that exists only in the refreshed set — fetcher must refetch. + set, err := f.GetWithRefresh(ctx, srv.URL, "new") + if err != nil { + t.Fatalf("GetWithRefresh: %v", err) + } + if got := set.FindByKid("new"); got == nil { + t.Fatal("expected refresh to find new kid") + } + if n := atomic.LoadInt32(&hits); n != 2 { + t.Fatalf("expected 2 HTTP hits after kid miss, got %d", n) + } +} + +func TestJWKSFetcher_RefreshCooldown(t *testing.T) { + jwk, _ := rsaJWKFixture(t, "registered") + var hits int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + atomic.AddInt32(&hits, 1) + _, _ = w.Write([]byte(jwkSetJSON(t, []util.JWK{jwk}))) + })) + defer srv.Close() + + mc := cache.NewMemoryCache[util.JWKSet](0) + defer mc.Close() + f := NewJWKSFetcher(mc, 2*time.Second, time.Minute) + + ctx := context.Background() + // First unknown-kid request triggers a refetch (cache miss forces one more). + _, err := f.GetWithRefresh(ctx, srv.URL, "unknown") + if err != nil { + t.Fatalf("first call: %v", err) + } + firstHits := atomic.LoadInt32(&hits) + // Second unknown-kid request within the cooldown must NOT refetch. + _, err = f.GetWithRefresh(ctx, srv.URL, "still-unknown") + if err != nil { + t.Fatalf("second call: %v", err) + } + if got := atomic.LoadInt32(&hits); got != firstHits { + t.Fatalf("expected no additional refetch within cooldown; hits went %d → %d", + firstHits, got) + } +} + +func TestJWKSFetcher_Non200Fails(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + mc := cache.NewMemoryCache[util.JWKSet](0) + defer mc.Close() + f := NewJWKSFetcher(mc, 2*time.Second, time.Minute) + _, err := f.Get(context.Background(), srv.URL) + if !errors.Is(err, ErrJWKSFetchFailed) { + t.Fatalf("expected ErrJWKSFetchFailed, got %v", err) + } +} + +func TestJWKSFetcher_InvalidBody(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("not json")) + })) + defer srv.Close() + + mc := cache.NewMemoryCache[util.JWKSet](0) + defer mc.Close() + f := NewJWKSFetcher(mc, 2*time.Second, time.Minute) + _, err := f.Get(context.Background(), srv.URL) + if !errors.Is(err, ErrJWKSFetchFailed) { + t.Fatalf("expected ErrJWKSFetchFailed, got %v", err) + } +} + +func TestJWKSFetcher_BodyTooLarge(t *testing.T) { + large := strings.Repeat("a", jwksMaxBodyBytes+10) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte(large)) + })) + defer srv.Close() + + mc := cache.NewMemoryCache[util.JWKSet](0) + defer mc.Close() + f := NewJWKSFetcher(mc, 2*time.Second, time.Minute) + _, err := f.Get(context.Background(), srv.URL) + if !errors.Is(err, ErrJWKSTooLarge) { + t.Fatalf("expected ErrJWKSTooLarge, got %v", err) + } +} diff --git a/internal/services/token_client_credentials.go b/internal/services/token_client_credentials.go index 9a33c9a..18e4700 100644 --- a/internal/services/token_client_credentials.go +++ b/internal/services/token_client_credentials.go @@ -41,11 +41,49 @@ func (s *TokenService) IssueClientCredentialsToken( return nil, ErrClientCredentialsFlowDisabled } - // 4. Authenticate the client via its secret + // 4. Refuse to accept a shared secret for clients whose registered auth + // method is not secret-based (e.g. private_key_jwt). This prevents a + // downgrade attack where a stale or unexpected secret hash on such a + // client could be presented to bypass the assertion requirement. + if !client.UsesClientSecret() { + return nil, ErrInvalidClientCredentials + } + + // 5. Authenticate the client via its secret if !client.ValidateClientSecret([]byte(clientSecret)) { return nil, ErrInvalidClientCredentials } + return s.issueClientCredentialsTokenForClient(ctx, client, requestedScopes) +} + +// IssueClientCredentialsTokenForClient issues a client_credentials access token +// for a client that has already been authenticated (e.g. via RFC 7523 +// private_key_jwt). The caller is responsible for authentication; this method +// only checks client type, flow enablement, and scope bounds. +func (s *TokenService) IssueClientCredentialsTokenForClient( + ctx context.Context, + client *models.OAuthApplication, + requestedScopes string, +) (*models.AccessToken, error) { + if client == nil || !client.IsActive() { + return nil, ErrInvalidClientCredentials + } + if core.ClientType(client.ClientType) != core.ClientTypeConfidential { + return nil, ErrClientNotConfidential + } + if !client.EnableClientCredentialsFlow { + return nil, ErrClientCredentialsFlowDisabled + } + return s.issueClientCredentialsTokenForClient(ctx, client, requestedScopes) +} + +func (s *TokenService) issueClientCredentialsTokenForClient( + ctx context.Context, + client *models.OAuthApplication, + requestedScopes string, +) (*models.AccessToken, error) { + clientID := client.ClientID // 5. Resolve effective scopes effectiveScopes := requestedScopes if effectiveScopes == "" { @@ -137,6 +175,11 @@ func (s *TokenService) AuthenticateClient( if !client.IsActive() { return ErrInvalidClientCredentials } + // Guard against auth-method downgrade: a client registered for + // private_key_jwt or none must not be authenticable via a shared secret. + if !client.UsesClientSecret() { + return ErrInvalidClientCredentials + } if !client.ValidateClientSecret([]byte(clientSecret)) { return ErrInvalidClientCredentials } diff --git a/internal/util/jwk.go b/internal/util/jwk.go new file mode 100644 index 0000000..aad588d --- /dev/null +++ b/internal/util/jwk.go @@ -0,0 +1,195 @@ +package util + +import ( + "crypto" + "crypto/ecdh" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "math" + "math/big" +) + +// Known JWK key types and curves (RFC 7517 §4.1, RFC 7518 §6). +const ( + JWKTypeRSA = "RSA" + JWKTypeEC = "EC" + JWKCurveP256 = "P-256" +) + +// JWK represents a single JSON Web Key (RFC 7517). +// Only the fields required for RSA and EC P-256 signing keys are modelled; additional +// JWK fields (e.g. x5c, x5t) are ignored on parse and omitted on marshal. +type JWK struct { + Kty string `json:"kty"` + Use string `json:"use,omitempty"` + Kid string `json:"kid,omitempty"` + Alg string `json:"alg,omitempty"` + // RSA + N string `json:"n,omitempty"` + E string `json:"e,omitempty"` + // EC + Crv string `json:"crv,omitempty"` + X string `json:"x,omitempty"` + Y string `json:"y,omitempty"` +} + +// JWKSet is a set of JSON Web Keys (RFC 7517 §5). +type JWKSet struct { + Keys []JWK `json:"keys"` +} + +// Errors returned from this package. +var ( + ErrInvalidJWKS = errors.New("invalid JWK Set") + ErrUnsupportedKeyType = errors.New("unsupported JWK key type") + ErrJWKFieldMissing = errors.New("JWK missing required field") + ErrJWKInvalidEncoding = errors.New("JWK contains invalid base64url encoding") + ErrJWKNoMatchingKey = errors.New("no JWK matches requested kid") + ErrJWKAlgorithmMismatch = errors.New("JWK algorithm does not match requested algorithm") +) + +// ParseJWKSet unmarshals a JWKS JSON document and validates that it contains +// at least one key. It does not verify individual keys — use ToPublicKey for that. +func ParseJWKSet(jsonBlob string) (*JWKSet, error) { + var set JWKSet + if err := json.Unmarshal([]byte(jsonBlob), &set); err != nil { + return nil, fmt.Errorf("%w: %v", ErrInvalidJWKS, err) + } + if len(set.Keys) == 0 { + return nil, fmt.Errorf("%w: keys array is empty", ErrInvalidJWKS) + } + return &set, nil +} + +// FindByKid returns the key whose Kid matches, or nil if not found. +// If kid is empty and there is exactly one key, that key is returned; otherwise nil. +func (s *JWKSet) FindByKid(kid string) *JWK { + if kid == "" { + if len(s.Keys) == 1 { + return &s.Keys[0] + } + return nil + } + for i := range s.Keys { + if s.Keys[i].Kid == kid { + return &s.Keys[i] + } + } + return nil +} + +// ToPublicKey converts a JWK to a crypto.PublicKey. Only RSA and EC P-256 keys +// are supported. The returned key can be used with golang-jwt/jwt for signature +// verification. +func (k *JWK) ToPublicKey() (crypto.PublicKey, error) { + switch k.Kty { + case JWKTypeRSA: + return rsaKeyFromJWK(k) + case JWKTypeEC: + return ecKeyFromJWK(k) + default: + return nil, fmt.Errorf("%w: kty=%q", ErrUnsupportedKeyType, k.Kty) + } +} + +func rsaKeyFromJWK(k *JWK) (*rsa.PublicKey, error) { + if k.N == "" || k.E == "" { + return nil, fmt.Errorf("%w: RSA key requires n and e", ErrJWKFieldMissing) + } + nBytes, err := base64.RawURLEncoding.DecodeString(k.N) + if err != nil { + return nil, fmt.Errorf("%w: n: %v", ErrJWKInvalidEncoding, err) + } + eBytes, err := base64.RawURLEncoding.DecodeString(k.E) + if err != nil { + return nil, fmt.Errorf("%w: e: %v", ErrJWKInvalidEncoding, err) + } + n := new(big.Int).SetBytes(nBytes) + e := new(big.Int).SetBytes(eBytes) + if !e.IsInt64() { + return nil, fmt.Errorf("%w: RSA exponent overflow", ErrInvalidJWKS) + } + // E must be an odd integer > 1 per PKCS#1; reject weak/degenerate exponents + // (e.g. 0, 1, even values) that make signature verification trivially unsafe. + eInt := e.Int64() + if eInt < 3 || eInt&1 == 0 { + return nil, fmt.Errorf( + "%w: RSA exponent must be odd and >= 3 (got %d)", + ErrInvalidJWKS, eInt, + ) + } + // rsa.PublicKey.E is an int, which is 32 bits on some platforms. Reject + // exponents that would truncate when cast rather than silently producing + // a corrupted key. + if eInt > math.MaxInt { + return nil, fmt.Errorf( + "%w: RSA exponent %d exceeds platform int range", + ErrInvalidJWKS, eInt, + ) + } + if n.BitLen() < 2048 { + return nil, fmt.Errorf( + "%w: RSA modulus must be >= 2048 bits (got %d)", + ErrInvalidJWKS, + n.BitLen(), + ) + } + return &rsa.PublicKey{N: n, E: int(eInt)}, nil +} + +func ecKeyFromJWK(k *JWK) (*ecdsa.PublicKey, error) { + if k.Crv == "" || k.X == "" || k.Y == "" { + return nil, fmt.Errorf("%w: EC key requires crv, x, and y", ErrJWKFieldMissing) + } + var ( + curve elliptic.Curve + ecdhCrv ecdh.Curve + byteLen int + ) + switch k.Crv { + case JWKCurveP256: + curve = elliptic.P256() + ecdhCrv = ecdh.P256() + byteLen = 32 + default: + return nil, fmt.Errorf( + "%w: curve %q (only P-256 is supported)", + ErrUnsupportedKeyType, + k.Crv, + ) + } + xBytes, err := base64.RawURLEncoding.DecodeString(k.X) + if err != nil { + return nil, fmt.Errorf("%w: x: %v", ErrJWKInvalidEncoding, err) + } + yBytes, err := base64.RawURLEncoding.DecodeString(k.Y) + if err != nil { + return nil, fmt.Errorf("%w: y: %v", ErrJWKInvalidEncoding, err) + } + if len(xBytes) > byteLen || len(yBytes) > byteLen { + return nil, fmt.Errorf("%w: EC coordinate exceeds curve byte length", ErrInvalidJWKS) + } + // Left-pad x/y to fixed width, then build SEC1 uncompressed point (0x04 || X || Y) + // and delegate on-curve validation to crypto/ecdh. + xPad := make([]byte, byteLen) + yPad := make([]byte, byteLen) + copy(xPad[byteLen-len(xBytes):], xBytes) + copy(yPad[byteLen-len(yBytes):], yBytes) + uncompressed := make([]byte, 0, 1+2*byteLen) + uncompressed = append(uncompressed, 0x04) + uncompressed = append(uncompressed, xPad...) + uncompressed = append(uncompressed, yPad...) + if _, err := ecdhCrv.NewPublicKey(uncompressed); err != nil { + return nil, fmt.Errorf("%w: EC point is not on curve %s: %v", ErrInvalidJWKS, k.Crv, err) + } + return &ecdsa.PublicKey{ + Curve: curve, + X: new(big.Int).SetBytes(xPad), + Y: new(big.Int).SetBytes(yPad), + }, nil +} diff --git a/internal/util/jwk_test.go b/internal/util/jwk_test.go new file mode 100644 index 0000000..eae6a10 --- /dev/null +++ b/internal/util/jwk_test.go @@ -0,0 +1,199 @@ +package util + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/json" + "errors" + "math/big" + "testing" +) + +// helpers --------------------------------------------------------------- + +func jwkFromRSA(t *testing.T, pub *rsa.PublicKey, kid string) JWK { + t.Helper() + return JWK{ + Kty: "RSA", + Use: "sig", + Kid: kid, + Alg: "RS256", + N: base64.RawURLEncoding.EncodeToString(pub.N.Bytes()), + E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(pub.E)).Bytes()), + } +} + +func jwkFromEC(t *testing.T, pub *ecdsa.PublicKey, kid string) JWK { + t.Helper() + byteLen := (pub.Curve.Params().BitSize + 7) / 8 + xBytes := make([]byte, byteLen) + yBytes := make([]byte, byteLen) + copy(xBytes[byteLen-len(pub.X.Bytes()):], pub.X.Bytes()) + copy(yBytes[byteLen-len(pub.Y.Bytes()):], pub.Y.Bytes()) + return JWK{ + Kty: "EC", + Use: "sig", + Kid: kid, + Alg: "ES256", + Crv: pub.Curve.Params().Name, + X: base64.RawURLEncoding.EncodeToString(xBytes), + Y: base64.RawURLEncoding.EncodeToString(yBytes), + } +} + +// tests ----------------------------------------------------------------- + +func TestParseJWKSet_RSA_Roundtrip(t *testing.T) { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + set := JWKSet{Keys: []JWK{jwkFromRSA(t, &priv.PublicKey, "test-key")}} + blob, err := json.Marshal(set) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + parsed, err := ParseJWKSet(string(blob)) + if err != nil { + t.Fatalf("ParseJWKSet: %v", err) + } + if len(parsed.Keys) != 1 { + t.Fatalf("expected 1 key, got %d", len(parsed.Keys)) + } + k := parsed.FindByKid("test-key") + if k == nil { + t.Fatal("FindByKid: nil") + } + pub, err := k.ToPublicKey() + if err != nil { + t.Fatalf("ToPublicKey: %v", err) + } + rsaPub, ok := pub.(*rsa.PublicKey) + if !ok { + t.Fatalf("expected *rsa.PublicKey, got %T", pub) + } + if rsaPub.N.Cmp(priv.PublicKey.N) != 0 || rsaPub.E != priv.PublicKey.E { + t.Fatal("RSA key mismatch after roundtrip") + } +} + +func TestParseJWKSet_EC_Roundtrip(t *testing.T) { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + set := JWKSet{Keys: []JWK{jwkFromEC(t, &priv.PublicKey, "ec-key")}} + blob, err := json.Marshal(set) + if err != nil { + t.Fatalf("marshal: %v", err) + } + parsed, err := ParseJWKSet(string(blob)) + if err != nil { + t.Fatalf("ParseJWKSet: %v", err) + } + k := parsed.FindByKid("ec-key") + pub, err := k.ToPublicKey() + if err != nil { + t.Fatalf("ToPublicKey: %v", err) + } + ecPub, ok := pub.(*ecdsa.PublicKey) + if !ok { + t.Fatalf("expected *ecdsa.PublicKey, got %T", pub) + } + if ecPub.X.Cmp(priv.PublicKey.X) != 0 || ecPub.Y.Cmp(priv.PublicKey.Y) != 0 { + t.Fatal("EC key mismatch after roundtrip") + } +} + +func TestParseJWKSet_EmptyKeys(t *testing.T) { + _, err := ParseJWKSet(`{"keys":[]}`) + if !errors.Is(err, ErrInvalidJWKS) { + t.Fatalf("expected ErrInvalidJWKS, got %v", err) + } +} + +func TestParseJWKSet_InvalidJSON(t *testing.T) { + _, err := ParseJWKSet(`not json`) + if !errors.Is(err, ErrInvalidJWKS) { + t.Fatalf("expected ErrInvalidJWKS, got %v", err) + } +} + +func TestFindByKid_SingleKeyNoKid(t *testing.T) { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + set := &JWKSet{Keys: []JWK{jwkFromRSA(t, &priv.PublicKey, "")}} + if got := set.FindByKid(""); got == nil { + t.Fatal("expected single key match when kid empty") + } +} + +func TestFindByKid_MultipleKeysEmptyKid(t *testing.T) { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + set := &JWKSet{Keys: []JWK{ + jwkFromRSA(t, &priv.PublicKey, "a"), + jwkFromRSA(t, &priv.PublicKey, "b"), + }} + if got := set.FindByKid(""); got != nil { + t.Fatal("expected nil when multiple keys and empty kid") + } +} + +func TestToPublicKey_UnsupportedKty(t *testing.T) { + k := &JWK{Kty: "oct"} + _, err := k.ToPublicKey() + if !errors.Is(err, ErrUnsupportedKeyType) { + t.Fatalf("expected ErrUnsupportedKeyType, got %v", err) + } +} + +func TestToPublicKey_RSA_MissingField(t *testing.T) { + k := &JWK{Kty: "RSA", N: "AQAB"} // missing e + _, err := k.ToPublicKey() + if !errors.Is(err, ErrJWKFieldMissing) { + t.Fatalf("expected ErrJWKFieldMissing, got %v", err) + } +} + +func TestToPublicKey_RSA_WeakKey(t *testing.T) { + priv, err := rsa.GenerateKey(rand.Reader, 1024) // below min 2048 + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + k := jwkFromRSA(t, &priv.PublicKey, "weak") + _, err = k.ToPublicKey() + if !errors.Is(err, ErrInvalidJWKS) { + t.Fatalf("expected ErrInvalidJWKS for weak key, got %v", err) + } +} + +func TestToPublicKey_EC_UnsupportedCurve(t *testing.T) { + k := &JWK{Kty: "EC", Crv: "P-384", X: "AA", Y: "AA"} + _, err := k.ToPublicKey() + if !errors.Is(err, ErrUnsupportedKeyType) { + t.Fatalf("expected ErrUnsupportedKeyType, got %v", err) + } +} + +func TestToPublicKey_EC_PointNotOnCurve(t *testing.T) { + k := &JWK{ + Kty: "EC", + Crv: "P-256", + // Valid base64 but nonsense coordinates → not on curve + X: base64.RawURLEncoding.EncodeToString([]byte{1, 2, 3, 4}), + Y: base64.RawURLEncoding.EncodeToString([]byte{5, 6, 7, 8}), + } + _, err := k.ToPublicKey() + if !errors.Is(err, ErrInvalidJWKS) { + t.Fatalf("expected ErrInvalidJWKS for off-curve point, got %v", err) + } +}