From df04286f795274db3c489a4b3bb5543d264fc69d Mon Sep 17 00:00:00 2001 From: karo Date: Thu, 19 Mar 2026 10:32:15 +0100 Subject: [PATCH 01/10] feat: add `deviceFlowEnabled` property to `Application` Introduced `deviceFlowEnabled` to `Application` model and repository. Updated migration, mapping functions, and SQL queries to support the new property. Signed-off-by: karo --- .../postgres/migrations/18_deviceFlow.sql | 6 ++++++ internal/repositories/applications.go | 18 ++++++++++++++++++ internal/repositories/postgres/applications.go | 10 ++++++++++ 3 files changed, 34 insertions(+) create mode 100644 internal/database/postgres/migrations/18_deviceFlow.sql diff --git a/internal/database/postgres/migrations/18_deviceFlow.sql b/internal/database/postgres/migrations/18_deviceFlow.sql new file mode 100644 index 00000000..b9f7baab --- /dev/null +++ b/internal/database/postgres/migrations/18_deviceFlow.sql @@ -0,0 +1,6 @@ +-- +migrate Up + +alter table applications add column device_flow_enabled boolean not null default false; + +-- +migrate Down + diff --git a/internal/repositories/applications.go b/internal/repositories/applications.go index c9be147f..8151c6a3 100644 --- a/internal/repositories/applications.go +++ b/internal/repositories/applications.go @@ -27,6 +27,7 @@ const ( ApplicationChangeRedirectUris ApplicationChangePostLogoutRedirectUris ApplicationChangeSystemApplication + ApplicationChangeDeviceFlowEnabled ) type Application struct { @@ -48,6 +49,8 @@ type Application struct { claimsMappingScript *string accessTokenHeaderType string + + deviceFlowEnabled bool } func NewApplication(virtualServerId uuid.UUID, projectId uuid.UUID, name string, displayName string, type_ ApplicationType, redirectUris []string) *Application { @@ -78,6 +81,7 @@ func NewApplicationFromDB( systemApplication bool, claimsMappingScript *string, accessTokenHeaderType string, + deviceFlowEnabled bool, ) *Application { return &Application{ BaseModel: base, @@ -93,6 +97,7 @@ func NewApplicationFromDB( systemApplication: systemApplication, claimsMappingScript: claimsMappingScript, accessTokenHeaderType: accessTokenHeaderType, + deviceFlowEnabled: deviceFlowEnabled, } } @@ -211,6 +216,19 @@ func (a *Application) SetSystemApplication(systemApplication bool) { a.TrackChange(ApplicationChangeSystemApplication) } +func (a *Application) DeviceFlowEnabled() bool { + return a.deviceFlowEnabled +} + +func (a *Application) SetDeviceFlowEnabled(deviceFlowEnabled bool) { + if a.deviceFlowEnabled == deviceFlowEnabled { + return + } + + a.deviceFlowEnabled = deviceFlowEnabled + a.TrackChange(ApplicationChangeDeviceFlowEnabled) +} + type ApplicationFilter struct { PagingInfo OrderInfo diff --git a/internal/repositories/postgres/applications.go b/internal/repositories/postgres/applications.go index 9f23fb67..5d45c240 100644 --- a/internal/repositories/postgres/applications.go +++ b/internal/repositories/postgres/applications.go @@ -29,6 +29,7 @@ type postgresApplication struct { systemApplication bool claimsMappingScript sql.NullString accessTokenHeaderType string + deviceFlowEnabled bool } func mapApplication(a *repositories.Application) *postgresApplication { @@ -45,6 +46,7 @@ func mapApplication(a *repositories.Application) *postgresApplication { systemApplication: a.SystemApplication(), claimsMappingScript: pghelpers.WrapStringPointer(a.ClaimsMappingScript()), accessTokenHeaderType: a.AccessTokenHeaderType(), + deviceFlowEnabled: a.DeviceFlowEnabled(), } } @@ -62,6 +64,7 @@ func (a *postgresApplication) Map() *repositories.Application { a.systemApplication, pghelpers.UnwrapNullString(a.claimsMappingScript), a.accessTokenHeaderType, + a.deviceFlowEnabled, ) } @@ -83,6 +86,7 @@ func (a *postgresApplication) scan(row pghelpers.Row, additionalPtrs ...any) err &a.systemApplication, &a.claimsMappingScript, &a.accessTokenHeaderType, + &a.deviceFlowEnabled, } ptrs = append(ptrs, additionalPtrs...) @@ -121,6 +125,7 @@ func (r *ApplicationRepository) selectQuery(filter *repositories.ApplicationFilt "system_application", "claims_mapping_script", "access_token_header_type", + "device_flow_enabled", ).From("applications") if filter.HasName() { @@ -244,6 +249,7 @@ func (r *ApplicationRepository) ExecuteInsert(ctx context.Context, tx *sql.Tx, a "system_application", "claims_mapping_script", "access_token_header_type", + "device_flow_enabled", ). Values( mapped.id, @@ -260,6 +266,7 @@ func (r *ApplicationRepository) ExecuteInsert(ctx context.Context, tx *sql.Tx, a mapped.systemApplication, mapped.claimsMappingScript, mapped.accessTokenHeaderType, + mapped.deviceFlowEnabled, ). Returning("xmin") @@ -316,6 +323,9 @@ func (r *ApplicationRepository) ExecuteUpdate(ctx context.Context, tx *sql.Tx, a case repositories.ApplicationChangeSystemApplication: s.SetMore(s.Assign("system_application", mapped.systemApplication)) + case repositories.ApplicationChangeDeviceFlowEnabled: + s.SetMore(s.Assign("device_flow_enabled", mapped.deviceFlowEnabled)) + default: return fmt.Errorf("updating field %v is not supported", field) } From 71778fc6aa5521e41ed94df0dfa24b002b249eb6 Mon Sep 17 00:00:00 2001 From: karo Date: Thu, 19 Mar 2026 10:32:47 +0100 Subject: [PATCH 02/10] feat: add support for device code flow tokens Added `OidcDeviceCodeTokenType` and `OidcUserCodeTokenType` to token types. Introduced `StoreToken` method for token service and `DeviceCodeInfo` structure for managing device code flow data. Updated `LoginInfo` to include `deviceCode`. Signed-off-by: karo --- internal/jsonTypes/DeviceCodeInfo.go | 17 +++++++++++++++++ internal/jsonTypes/LoginInfo.go | 1 + internal/services/tokens.go | 15 +++++++++++++++ 3 files changed, 33 insertions(+) create mode 100644 internal/jsonTypes/DeviceCodeInfo.go diff --git a/internal/jsonTypes/DeviceCodeInfo.go b/internal/jsonTypes/DeviceCodeInfo.go new file mode 100644 index 00000000..ee23882f --- /dev/null +++ b/internal/jsonTypes/DeviceCodeInfo.go @@ -0,0 +1,17 @@ +package jsonTypes + +type DeviceCodeStatus string + +const ( + DeviceCodeStatusPending DeviceCodeStatus = "pending" + DeviceCodeStatusAuthorized DeviceCodeStatus = "authorized" + DeviceCodeStatusDenied DeviceCodeStatus = "denied" +) + +type DeviceCodeInfo struct { + VirtualServerName string + ClientId string + GrantedScopes []string + Status string + UserId *string +} diff --git a/internal/jsonTypes/LoginInfo.go b/internal/jsonTypes/LoginInfo.go index f17032ec..d62737c5 100644 --- a/internal/jsonTypes/LoginInfo.go +++ b/internal/jsonTypes/LoginInfo.go @@ -28,6 +28,7 @@ type LoginInfo struct { UserId uuid.UUID `json:"userId"` OriginalUrl string `json:"originalUrl"` TotpSecret string `json:"totpSecret"` + DeviceCode string `json:"deviceCode"` } func NewLoginInfo(virtualServer *repositories.VirtualServer, application *repositories.Application, originalUrl string) LoginInfo { diff --git a/internal/services/tokens.go b/internal/services/tokens.go index ac9421f5..27ad00bd 100644 --- a/internal/services/tokens.go +++ b/internal/services/tokens.go @@ -20,6 +20,8 @@ const ( LoginSessionTokenType TokenType = "login_session" OidcCodeTokenType TokenType = "oidc_code" OidcRefreshTokenTokenType TokenType = "oidc_refresh_token" + OidcDeviceCodeTokenType TokenType = "oidc_device_code" + OidcUserCodeTokenType TokenType = "oidc_user_code" ) func (t TokenType) Key(token string) string { @@ -33,6 +35,7 @@ type TokenService interface { UpdateToken(ctx context.Context, tokenType TokenType, token string, value string, expiration time.Duration) error GetToken(ctx context.Context, tokenType TokenType, token string) (string, error) DeleteToken(ctx context.Context, tokenType TokenType, token string) error + StoreToken(ctx context.Context, tokenType TokenType, token string, value string, expiration time.Duration) error } type tokenService struct { @@ -100,3 +103,15 @@ func (t *tokenService) DeleteToken(ctx context.Context, tokenType TokenType, tok return nil } + +func (t *tokenService) StoreToken(ctx context.Context, tokenType TokenType, token string, value string, expiration time.Duration) error { + scope := middlewares.GetScope(ctx) + kvStore := ioc.GetDependency[keyValue.Store](scope) + + err := kvStore.Set(ctx, tokenType.Key(token), value, keyValue.WithExpiration(expiration)) + if err != nil { + return fmt.Errorf("storing token in kv: %w", err) + } + + return nil +} From f61fc425e0739be4c8df46fe14f8a30c9edbdd5b Mon Sep 17 00:00:00 2001 From: karo Date: Thu, 19 Mar 2026 11:00:36 +0100 Subject: [PATCH 03/10] feat: handle `deviceFlowEnabled` in `PatchApplication` logic Added support for updating the `deviceFlowEnabled` property in the `PatchApplication` command and HTTP handler. Updated DTO and application update logic accordingly. Signed-off-by: karo --- internal/commands/PatchApplication.go | 5 +++++ internal/handlers/applications.go | 2 ++ 2 files changed, 7 insertions(+) diff --git a/internal/commands/PatchApplication.go b/internal/commands/PatchApplication.go index 102a870c..aa84246e 100644 --- a/internal/commands/PatchApplication.go +++ b/internal/commands/PatchApplication.go @@ -22,6 +22,7 @@ type PatchApplication struct { DisplayName *string ClaimsMappingScript *string AccessTokenHeaderType *string + DeviceFlowEnabled *bool } func (a PatchApplication) LogRequest() bool { @@ -87,6 +88,10 @@ func HandlePatchApplication(ctx context.Context, command PatchApplication) (*Pat application.SetAccessTokenHeaderType(*command.AccessTokenHeaderType) } + if command.DeviceFlowEnabled != nil { + application.SetDeviceFlowEnabled(*command.DeviceFlowEnabled) + } + dbContext.Applications().Update(application) return &PatchApplicationResponse{}, nil } diff --git a/internal/handlers/applications.go b/internal/handlers/applications.go index 614822f7..cd9fe48e 100644 --- a/internal/handlers/applications.go +++ b/internal/handlers/applications.go @@ -195,6 +195,7 @@ func GetApplication(w http.ResponseWriter, r *http.Request) { type PatchApplicationRequestDto struct { DisplayName *string `json:"displayName"` ClaimsMappingScript *string `json:"customClaimsMappingScript"` + DeviceFlowEnabled *bool `json:"deviceFlowEnabled"` } // PatchApplication updates fields of a specific application by ID @@ -246,6 +247,7 @@ func PatchApplication(w http.ResponseWriter, r *http.Request) { ApplicationId: appId, DisplayName: utils.TrimSpace(dto.DisplayName), ClaimsMappingScript: dto.ClaimsMappingScript, + DeviceFlowEnabled: dto.DeviceFlowEnabled, }) if err != nil { utils.HandleHttpError(w, err) From 1abfa2774066c9222b8b837ac67ad0686698b21c Mon Sep 17 00:00:00 2001 From: karo Date: Thu, 19 Mar 2026 11:09:22 +0100 Subject: [PATCH 04/10] feat: add routes for device activation and flow handling Registered new OIDC routes for initiating device flow, handling activation pages, and success handling. Signed-off-by: karo --- internal/server/server.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/internal/server/server.go b/internal/server/server.go index eccf6854..afd9b47a 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -61,6 +61,10 @@ func Serve(dp *ioc.DependencyProvider, serverConfig config.ServerConfig) { oidcRouter.HandleFunc("/token", handlers.OidcToken).Methods(http.MethodPost, http.MethodOptions) oidcRouter.HandleFunc("/userinfo", handlers.OidcUserinfo).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) oidcRouter.HandleFunc("/end_session", handlers.OidcEndSession).Methods(http.MethodGet, http.MethodOptions) + oidcRouter.HandleFunc("/device", handlers.BeginDeviceFlow).Methods(http.MethodPost, http.MethodOptions) + oidcRouter.HandleFunc("/activate", handlers.GetActivatePage).Methods(http.MethodGet) + oidcRouter.HandleFunc("/activate", handlers.PostActivatePage).Methods(http.MethodPost) + oidcRouter.HandleFunc("/activate/success", handlers.ActivateSuccess).Methods(http.MethodGet) loginRouter := r.PathPrefix("/logins").Subrouter() From 97b0e8aa150e713eb09ddca79e3d2d1b68b638b5 Mon Sep 17 00:00:00 2001 From: karo Date: Thu, 19 Mar 2026 11:13:52 +0100 Subject: [PATCH 05/10] feat: implement complete OIDC device code flow Added `DeviceAuthorizationEndpoint` to OpenID configuration. Introduced support for the device code grant type and accompanying OIDC endpoints, including device flow initiation, activation, and token issuance. Updated `DeviceCodeInfo` structure and validation logic. Signed-off-by: karo --- internal/handlers/login.go | 36 +++ internal/handlers/oidc.go | 423 ++++++++++++++++++++++++++- internal/jsonTypes/DeviceCodeInfo.go | 3 +- 3 files changed, 455 insertions(+), 7 deletions(-) diff --git a/internal/handlers/login.go b/internal/handlers/login.go index 657442d4..6a81a3fd 100644 --- a/internal/handlers/login.go +++ b/internal/handlers/login.go @@ -544,6 +544,42 @@ func FinishLogin(w http.ResponseWriter, r *http.Request) { return } + if loginInfo.DeviceCode != "" { + deviceCodeInfoString, err := tokenService.GetToken(ctx, services.OidcDeviceCodeTokenType, loginInfo.DeviceCode) + if err != nil { + utils.HandleHttpError(w, fmt.Errorf("getting device code info: %w", err)) + return + } + + var deviceCodeInfo jsonTypes.DeviceCodeInfo + if err := json.Unmarshal([]byte(deviceCodeInfoString), &deviceCodeInfo); err != nil { + utils.HandleHttpError(w, fmt.Errorf("unmarshaling device code info: %w", err)) + return + } + + userIdStr := loginInfo.UserId + deviceCodeInfo.Status = string(jsonTypes.DeviceCodeStatusAuthorized) + deviceCodeInfo.UserId = userIdStr.String() + + updatedInfoJson, err := json.Marshal(deviceCodeInfo) + if err != nil { + utils.HandleHttpError(w, fmt.Errorf("marshaling device code info: %w", err)) + return + } + + if err := tokenService.UpdateToken(ctx, services.OidcDeviceCodeTokenType, loginInfo.DeviceCode, string(updatedInfoJson), 10*time.Minute); err != nil { + utils.HandleHttpError(w, fmt.Errorf("updating device code info: %w", err)) + return + } + + if deviceCodeInfo.UserCode != "" { + _ = tokenService.DeleteToken(ctx, services.OidcUserCodeTokenType, deviceCodeInfo.UserCode) + } + + http.Redirect(w, r, fmt.Sprintf("%s/oidc/%s/activate/success", config.C.Server.ExternalUrl, loginInfo.VirtualServerName), http.StatusFound) + return + } + http.Redirect(w, r, loginInfo.OriginalUrl, http.StatusFound) } diff --git a/internal/handlers/oidc.go b/internal/handlers/oidc.go index c11f8ded..1423d36d 100644 --- a/internal/handlers/oidc.go +++ b/internal/handlers/oidc.go @@ -172,6 +172,7 @@ type OpenIdConfigurationResponseDto struct { TokenEndpoint string `json:"token_endpoint"` UserinfoEndpoint string `json:"userinfo_endpoint"` EndSessionEndpoint string `json:"end_session_endpoint"` + DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint"` JwksUri string `json:"jwks_uri"` ResponseTypesSupported []string `json:"response_types_supported"` SubjectTypesSupported []string `json:"subject_types_supported"` @@ -217,18 +218,19 @@ func WellKnownOpenIdConfiguration(w http.ResponseWriter, r *http.Request) { responseDto := OpenIdConfigurationResponseDto{ Issuer: fmt.Sprintf("%s/oidc/%s", config.C.Server.ExternalUrl, vsName), - AuthorizationEndpoint: fmt.Sprintf("%s/oidc/%s/authorize", config.C.Server.ExternalUrl, vsName), - TokenEndpoint: fmt.Sprintf("%s/oidc/%s/token", config.C.Server.ExternalUrl, vsName), - UserinfoEndpoint: fmt.Sprintf("%s/oidc/%s/userinfo", config.C.Server.ExternalUrl, vsName), - EndSessionEndpoint: fmt.Sprintf("%s/oidc/%s/end_session", config.C.Server.ExternalUrl, vsName), - JwksUri: fmt.Sprintf("%s/oidc/%s/.well-known/jwks.json", config.C.Server.ExternalUrl, vsName), + AuthorizationEndpoint: fmt.Sprintf("%s/oidc/%s/authorize", config.C.Server.ExternalUrl, vsName), + TokenEndpoint: fmt.Sprintf("%s/oidc/%s/token", config.C.Server.ExternalUrl, vsName), + UserinfoEndpoint: fmt.Sprintf("%s/oidc/%s/userinfo", config.C.Server.ExternalUrl, vsName), + EndSessionEndpoint: fmt.Sprintf("%s/oidc/%s/end_session", config.C.Server.ExternalUrl, vsName), + DeviceAuthorizationEndpoint: fmt.Sprintf("%s/oidc/%s/device", config.C.Server.ExternalUrl, vsName), + JwksUri: fmt.Sprintf("%s/oidc/%s/.well-known/jwks.json", config.C.Server.ExternalUrl, vsName), ResponseTypesSupported: []string{"code"}, // TODO: maybe support more RequestParameterSupported: true, SubjectTypesSupported: []string{"public"}, IdTokenSigningAlgValuesSupported: []string{string(virtualServer.SigningAlgorithm())}, TokenEndpointAuthMethodsSupported: []string{"client_secret_basic", "client_secret_post"}, - GrantTypesSupported: []string{"authorization_code", "refresh_token", "urn:ietf:params:oauth:grant-type:token-exchange"}, + GrantTypesSupported: []string{"authorization_code", "refresh_token", "urn:ietf:params:oauth:grant-type:token-exchange", "urn:ietf:params:oauth:grant-type:device_code"}, ScopesSupported: []string{"openid", "email", "profile"}, // TODO: get from db ClaimsSupported: []string{"sub", "name", "email"}, // TODO: get from db @@ -837,6 +839,9 @@ func OidcToken(w http.ResponseWriter, r *http.Request) { case "urn:ietf:params:oauth:grant-type:token-exchange": handleTokenExchange(w, r) + case "urn:ietf:params:oauth:grant-type:device_code": + handleDeviceCodeGrant(w, r) + default: utils.HandleHttpError(w, fmt.Errorf("unsupported grant type: %s", grantType)) return @@ -1682,3 +1687,409 @@ func handleTokenExchange(w http.ResponseWriter, r *http.Request) { return } } + +type DeviceAuthorizationResponse struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationUri string `json:"verification_uri"` + VerificationUriComplete string `json:"verification_uri_complete"` + ExpiresIn int `json:"expires_in"` + Interval int `json:"interval"` +} + +func generateUserCode() string { + const chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + b := utils.GetSecureRandomBytes(8) + result := make([]byte, 9) // XXXX-XXXX + for i := 0; i < 4; i++ { + result[i] = chars[int(b[i])%len(chars)] + } + result[4] = '-' + for i := 0; i < 4; i++ { + result[5+i] = chars[int(b[4+i])%len(chars)] + } + return string(result) +} + +func BeginDeviceFlow(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + scope := middlewares.GetScope(ctx) + + err := r.ParseForm() + if err != nil { + utils.HandleHttpError(w, err) + return + } + + vsName, err := middlewares.GetVirtualServerName(ctx) + if err != nil { + utils.HandleHttpError(w, err) + return + } + + clientId := r.Form.Get("client_id") + if clientId == "" { + writeOAuthError(w, "invalid_client", "client_id is required") + return + } + + scopeParam := r.Form.Get("scope") + scopes := strings.Split(scopeParam, " ") + + if !slices.Contains(scopes, "openid") { + writeOAuthError(w, "invalid_scope", "required openid scope missing") + return + } + + dbContext := ioc.GetDependency[database.Context](scope) + + virtualServerFilter := repositories.NewVirtualServerFilter().Name(vsName) + virtualServer, err := dbContext.VirtualServers().FirstOrNil(ctx, virtualServerFilter) + if err != nil { + utils.HandleHttpError(w, fmt.Errorf("getting virtual server: %w", err)) + return + } + if virtualServer == nil { + utils.HandleHttpError(w, fmt.Errorf("virtual server not found")) + return + } + + applicationFilter := repositories.NewApplicationFilter(). + VirtualServerId(virtualServer.Id()). + Name(clientId) + application, err := dbContext.Applications().FirstOrNil(ctx, applicationFilter) + if err != nil { + utils.HandleHttpError(w, fmt.Errorf("getting application: %w", err)) + return + } + if application == nil { + writeOAuthError(w, "invalid_client", "application not found") + return + } + + if !application.DeviceFlowEnabled() { + writeOAuthError(w, "unauthorized_client", "device flow is not enabled for this application") + return + } + + userCode := generateUserCode() + + deviceCodeInfo := jsonTypes.DeviceCodeInfo{ + VirtualServerName: vsName, + ClientId: clientId, + GrantedScopes: scopes, + Status: string(jsonTypes.DeviceCodeStatusPending), + UserCode: userCode, + } + + deviceCodeInfoJson, err := json.Marshal(deviceCodeInfo) + if err != nil { + utils.HandleHttpError(w, fmt.Errorf("marshaling device code info: %w", err)) + return + } + + tokenService := ioc.GetDependency[services.TokenService](scope) + + deviceCode, err := tokenService.GenerateAndStoreToken(ctx, services.OidcDeviceCodeTokenType, string(deviceCodeInfoJson), 10*time.Minute) + if err != nil { + utils.HandleHttpError(w, fmt.Errorf("generating device code: %w", err)) + return + } + + err = tokenService.StoreToken(ctx, services.OidcUserCodeTokenType, userCode, deviceCode, 10*time.Minute) + if err != nil { + utils.HandleHttpError(w, fmt.Errorf("storing user code: %w", err)) + return + } + + verificationUri := fmt.Sprintf("%s/%s/activate", config.C.Frontend.ExternalUrl, vsName) + verificationUriComplete := fmt.Sprintf("%s?user_code=%s", verificationUri, userCode) + + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + response := DeviceAuthorizationResponse{ + DeviceCode: deviceCode, + UserCode: userCode, + VerificationUri: verificationUri, + VerificationUriComplete: verificationUriComplete, + ExpiresIn: 600, + Interval: 5, + } + err = json.NewEncoder(w).Encode(response) + if err != nil { + utils.HandleHttpError(w, fmt.Errorf("encoding response: %w", err)) + return + } +} + +func handleDeviceCodeGrant(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + scope := middlewares.GetScope(ctx) + + clientId, clientSecret, hasBasicAuth := r.BasicAuth() + if !hasBasicAuth { + clientId = r.Form.Get("client_id") + clientSecret = "" + } + + deviceCode := r.Form.Get("device_code") + if deviceCode == "" { + writeOAuthError(w, "invalid_request", "device_code is required") + return + } + + tokenService := ioc.GetDependency[services.TokenService](scope) + valueString, err := tokenService.GetToken(ctx, services.OidcDeviceCodeTokenType, deviceCode) + if err != nil { + writeOAuthError(w, "expired_token", "device code has expired or is invalid") + return + } + + var deviceCodeInfo jsonTypes.DeviceCodeInfo + err = json.Unmarshal([]byte(valueString), &deviceCodeInfo) + if err != nil { + utils.HandleHttpError(w, fmt.Errorf("unmarshaling device code info: %w", err)) + return + } + + switch deviceCodeInfo.Status { + case string(jsonTypes.DeviceCodeStatusPending): + writeOAuthError(w, "authorization_pending", "the authorization request is still pending") + return + case string(jsonTypes.DeviceCodeStatusDenied): + writeOAuthError(w, "access_denied", "the user denied the authorization request") + return + case string(jsonTypes.DeviceCodeStatusAuthorized): + // continue + default: + writeOAuthError(w, "server_error", "unexpected device code status") + return + } + + if deviceCodeInfo.ClientId != clientId { + writeOAuthError(w, "invalid_client", "client_id mismatch") + return + } + + userIdStr := deviceCodeInfo.UserId + userId, err := uuid.Parse(userIdStr) + if err != nil { + utils.HandleHttpError(w, fmt.Errorf("parsing user id: %w", err)) + return + } + + err = tokenService.DeleteToken(ctx, services.OidcDeviceCodeTokenType, deviceCode) + if err != nil { + utils.HandleHttpError(w, fmt.Errorf("deleting device code: %w", err)) + return + } + + dbContext := ioc.GetDependency[database.Context](scope) + + virtualServerFilter := repositories.NewVirtualServerFilter().Name(deviceCodeInfo.VirtualServerName) + virtualServer, err := dbContext.VirtualServers().FirstOrNil(ctx, virtualServerFilter) + if err != nil { + utils.HandleHttpError(w, fmt.Errorf("getting virtual server: %w", err)) + return + } + if virtualServer == nil { + utils.HandleHttpError(w, fmt.Errorf("virtual server not found")) + return + } + + applicationFilter := repositories.NewApplicationFilter(). + VirtualServerId(virtualServer.Id()). + Name(clientId) + application, err := dbContext.Applications().FirstOrNil(ctx, applicationFilter) + if err != nil { + utils.HandleHttpError(w, fmt.Errorf("getting application: %w", err)) + return + } + if application == nil { + utils.HandleHttpError(w, fmt.Errorf("application not found")) + return + } + + _ = clientSecret // public clients don't need a secret + + userFilter := repositories.NewUserFilter().Id(userId) + user, err := dbContext.Users().FirstOrNil(ctx, userFilter) + if err != nil { + utils.HandleHttpError(w, fmt.Errorf("getting user: %w", err)) + return + } + if user == nil { + utils.HandleHttpError(w, fmt.Errorf("user not found")) + return + } + + keyService := ioc.GetDependency[services.KeyService](scope) + keyPair := keyService.GetKey(deviceCodeInfo.VirtualServerName, virtualServer.SigningAlgorithm()) + + clockService := ioc.GetDependency[clock.Service](scope) + now := clockService.Now() + + tokenDuration := time.Hour + + params := TokenGenerationParams{ + UserId: userId, + VirtualServerName: deviceCodeInfo.VirtualServerName, + ClientId: clientId, + ApplicationId: application.Id(), + GrantedScopes: deviceCodeInfo.GrantedScopes, + UserDisplayName: user.DisplayName(), + UserPrimaryEmail: user.PrimaryEmail(), + ExternalUrl: config.C.Server.ExternalUrl, + KeyPair: keyPair, + IssuedAt: now, + AccessTokenExpiry: tokenDuration, + IdTokenExpiry: tokenDuration, + RefreshTokenExpiry: tokenDuration, + AccessTokenHeaderType: application.AccessTokenHeaderType(), + } + + tokens, err := generateTokens(ctx, params, tokenService) + if err != nil { + utils.HandleHttpError(w, err) + return + } + + w.Header().Set("Cache-Control", "no-store") + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + scopeString := strings.Join(deviceCodeInfo.GrantedScopes, " ") + response := CodeFlowResponse{ + TokenType: "Bearer", + IdToken: tokens.IdToken, + AccessToken: tokens.AccessToken, + RefreshToken: tokens.RefreshToken, + Scope: scopeString, + ExpiresIn: tokens.ExpiresIn, + } + err = json.NewEncoder(w).Encode(response) + if err != nil { + utils.HandleHttpError(w, fmt.Errorf("encoding response: %w", err)) + return + } +} + +func GetActivatePage(w http.ResponseWriter, r *http.Request) { + vsName, err := middlewares.GetVirtualServerName(r.Context()) + if err != nil { + utils.HandleHttpError(w, err) + return + } + + redirectUrl := fmt.Sprintf("%s/%s/activate", config.C.Frontend.ExternalUrl, vsName) + if userCode := r.URL.Query().Get("user_code"); userCode != "" { + redirectUrl += "?user_code=" + url.QueryEscape(userCode) + } + http.Redirect(w, r, redirectUrl, http.StatusFound) +} + +func PostActivatePage(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + scope := middlewares.GetScope(ctx) + + err := r.ParseForm() + if err != nil { + utils.HandleHttpError(w, err) + return + } + + vsName, err := middlewares.GetVirtualServerName(ctx) + if err != nil { + utils.HandleHttpError(w, err) + return + } + + userCode := r.Form.Get("user_code") + if userCode == "" { + http.Error(w, "user_code is required", http.StatusBadRequest) + return + } + + tokenService := ioc.GetDependency[services.TokenService](scope) + + deviceCode, err := tokenService.GetToken(ctx, services.OidcUserCodeTokenType, userCode) + if err != nil { + http.Error(w, "Invalid or expired device code. Please request a new one.", http.StatusNotFound) + return + } + + deviceCodeInfoString, err := tokenService.GetToken(ctx, services.OidcDeviceCodeTokenType, deviceCode) + if err != nil { + http.Error(w, "Invalid or expired device code. Please request a new one.", http.StatusNotFound) + return + } + + var deviceCodeInfo jsonTypes.DeviceCodeInfo + err = json.Unmarshal([]byte(deviceCodeInfoString), &deviceCodeInfo) + if err != nil { + utils.HandleHttpError(w, fmt.Errorf("unmarshaling device code info: %w", err)) + return + } + + if deviceCodeInfo.Status != string(jsonTypes.DeviceCodeStatusPending) { + http.Error(w, "This device code has already been used.", http.StatusBadRequest) + return + } + + dbContext := ioc.GetDependency[database.Context](scope) + + virtualServerFilter := repositories.NewVirtualServerFilter().Name(vsName) + virtualServer, err := dbContext.VirtualServers().FirstOrNil(ctx, virtualServerFilter) + if err != nil { + utils.HandleHttpError(w, fmt.Errorf("getting virtual server: %w", err)) + return + } + if virtualServer == nil { + utils.HandleHttpError(w, fmt.Errorf("virtual server not found")) + return + } + + applicationFilter := repositories.NewApplicationFilter(). + VirtualServerId(virtualServer.Id()). + Name(deviceCodeInfo.ClientId) + application, err := dbContext.Applications().FirstOrNil(ctx, applicationFilter) + if err != nil { + utils.HandleHttpError(w, fmt.Errorf("getting application: %w", err)) + return + } + if application == nil { + utils.HandleHttpError(w, fmt.Errorf("application not found")) + return + } + + loginInfo := jsonTypes.NewLoginInfo(virtualServer, application, fmt.Sprintf("%s/oidc/%s/activate", config.C.Server.ExternalUrl, vsName)) + loginInfo.DeviceCode = deviceCode + + loginInfoString, err := json.Marshal(loginInfo) + if err != nil { + utils.HandleHttpError(w, fmt.Errorf("marshaling login info: %w", err)) + return + } + + loginSessionToken, err := tokenService.GenerateAndStoreToken(ctx, services.LoginSessionTokenType, string(loginInfoString), 15*time.Minute) + if err != nil { + utils.HandleHttpError(w, fmt.Errorf("generating login session token: %w", err)) + return + } + + redirectUrl := fmt.Sprintf("%s/login?token=%s", config.C.Frontend.ExternalUrl, loginSessionToken) + http.Redirect(w, r, redirectUrl, http.StatusFound) +} + +func ActivateSuccess(w http.ResponseWriter, r *http.Request) { + vsName, err := middlewares.GetVirtualServerName(r.Context()) + if err != nil { + utils.HandleHttpError(w, err) + return + } + + redirectUrl := fmt.Sprintf("%s/%s/activate/success", config.C.Frontend.ExternalUrl, vsName) + http.Redirect(w, r, redirectUrl, http.StatusFound) +} diff --git a/internal/jsonTypes/DeviceCodeInfo.go b/internal/jsonTypes/DeviceCodeInfo.go index ee23882f..d18f8749 100644 --- a/internal/jsonTypes/DeviceCodeInfo.go +++ b/internal/jsonTypes/DeviceCodeInfo.go @@ -13,5 +13,6 @@ type DeviceCodeInfo struct { ClientId string GrantedScopes []string Status string - UserId *string + UserId string + UserCode string } From 4715bfde8f21aef40c97ab217d2070ac59599fc2 Mon Sep 17 00:00:00 2001 From: karo Date: Thu, 19 Mar 2026 11:14:20 +0100 Subject: [PATCH 06/10] feat: add OIDC client with device flow and token polling support Implemented `OidcClient` to handle the device code flow, including initiating the device authorization and polling for tokens. Extended `Transport` with `NewOidcRequest` and `DoRaw` methods to support OIDC-related HTTP requests. Updated the main client interface to expose the `Oidc` client. Signed-off-by: karo --- client/client.go | 5 ++ client/oidc.go | 126 ++++++++++++++++++++++++++++++++++++++++++++ client/transport.go | 13 +++++ 3 files changed, 144 insertions(+) create mode 100644 client/oidc.go diff --git a/client/client.go b/client/client.go index 0516452d..5910f7a0 100644 --- a/client/client.go +++ b/client/client.go @@ -4,6 +4,7 @@ type Client interface { Application() ApplicationClient VirtualServer() VirtualServerClient User() UserClient + Oidc() OidcClient } type client struct { @@ -27,3 +28,7 @@ func (c *client) VirtualServer() VirtualServerClient { func (c *client) User() UserClient { return NewUserClient(c.transport) } + +func (c *client) Oidc() OidcClient { + return NewOidcClient(c.transport) +} diff --git a/client/oidc.go b/client/oidc.go new file mode 100644 index 00000000..01207279 --- /dev/null +++ b/client/oidc.go @@ -0,0 +1,126 @@ +package client + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "strings" +) + +var ( + ErrAuthorizationPending = errors.New("authorization_pending") + ErrAccessDenied = errors.New("access_denied") + ErrExpiredToken = errors.New("expired_token") + ErrSlowDown = errors.New("slow_down") +) + +type DeviceAuthorizationResponse struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationUri string `json:"verification_uri"` + VerificationUriComplete string `json:"verification_uri_complete"` + ExpiresIn int `json:"expires_in"` + Interval int `json:"interval"` +} + +type DeviceTokenResponse struct { + TokenType string `json:"token_type"` + IdToken string `json:"id_token"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + Scope string `json:"scope"` + ExpiresIn int `json:"expires_in"` +} + +type OidcClient interface { + BeginDeviceFlow(ctx context.Context, clientId string, scope string) (DeviceAuthorizationResponse, error) + PollDeviceToken(ctx context.Context, clientId string, deviceCode string) (DeviceTokenResponse, error) +} + +type oidcClient struct { + transport *Transport +} + +func NewOidcClient(transport *Transport) OidcClient { + return &oidcClient{transport: transport} +} + +func (o *oidcClient) BeginDeviceFlow(ctx context.Context, clientId string, scope string) (DeviceAuthorizationResponse, error) { + formValues := url.Values{ + "client_id": {clientId}, + "scope": {scope}, + } + + req, err := o.transport.NewOidcRequest(ctx, http.MethodPost, "/device", strings.NewReader(formValues.Encode())) + if err != nil { + return DeviceAuthorizationResponse{}, fmt.Errorf("creating request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := o.transport.Do(req) + if err != nil { + return DeviceAuthorizationResponse{}, fmt.Errorf("doing request: %w", err) + } + defer resp.Body.Close() + + var result DeviceAuthorizationResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return DeviceAuthorizationResponse{}, fmt.Errorf("decoding response: %w", err) + } + + return result, nil +} + +func (o *oidcClient) PollDeviceToken(ctx context.Context, clientId string, deviceCode string) (DeviceTokenResponse, error) { + formValues := url.Values{ + "grant_type": {"urn:ietf:params:oauth:grant-type:device_code"}, + "client_id": {clientId}, + "device_code": {deviceCode}, + } + + req, err := o.transport.NewOidcRequest(ctx, http.MethodPost, "/token", strings.NewReader(formValues.Encode())) + if err != nil { + return DeviceTokenResponse{}, fmt.Errorf("creating request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := o.transport.DoRaw(req) + if err != nil { + return DeviceTokenResponse{}, fmt.Errorf("doing request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusBadRequest { + var oauthErr struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + } + _ = json.NewDecoder(resp.Body).Decode(&oauthErr) + switch oauthErr.Error { + case "authorization_pending": + return DeviceTokenResponse{}, ErrAuthorizationPending + case "access_denied": + return DeviceTokenResponse{}, ErrAccessDenied + case "expired_token": + return DeviceTokenResponse{}, ErrExpiredToken + case "slow_down": + return DeviceTokenResponse{}, ErrSlowDown + default: + return DeviceTokenResponse{}, fmt.Errorf("oauth error %s: %s", oauthErr.Error, oauthErr.ErrorDescription) + } + } + + if resp.StatusCode != http.StatusOK { + return DeviceTokenResponse{}, ApiError{Message: resp.Status, Code: resp.StatusCode} + } + + var result DeviceTokenResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return DeviceTokenResponse{}, fmt.Errorf("decoding response: %w", err) + } + + return result, nil +} diff --git a/client/transport.go b/client/transport.go index 3ffa8d45..56c25d8c 100644 --- a/client/transport.go +++ b/client/transport.go @@ -69,6 +69,10 @@ func (t *Transport) NewTenantRequest(ctx context.Context, method string, endpoin return t.NewRootRequest(ctx, method, fmt.Sprintf("/api/virtual-servers/%s%s", t.virtualServer, endpoint), body) } +func (t *Transport) NewOidcRequest(ctx context.Context, method string, endpoint string, body io.Reader) (*http.Request, error) { + return t.NewRootRequest(ctx, method, fmt.Sprintf("/oidc/%s%s", t.virtualServer, endpoint), body) +} + func (t *Transport) NewRootRequest(ctx context.Context, method string, endpoint string, body io.Reader) (*http.Request, error) { base, err := url.Parse(t.baseURL) if err != nil { @@ -109,3 +113,12 @@ func (t *Transport) Do(req *http.Request) (*http.Response, error) { return response, nil } + +// DoRaw executes the request and returns the raw response without status code checking. +func (t *Transport) DoRaw(req *http.Request) (*http.Response, error) { + response, err := t.client.Do(req) + if err != nil { + return nil, fmt.Errorf("doing request: %w", err) + } + return response, nil +} From 01136192266842046dfaa704ec909f34751673b5 Mon Sep 17 00:00:00 2001 From: karo Date: Thu, 19 Mar 2026 11:14:27 +0100 Subject: [PATCH 07/10] feat: extend OIDC client with activation, login, and password verification Added `PostActivate`, `VerifyPassword`, and `FinishLogin` methods to `OidcClient`. Updated `Transport` with `DoNoRedirect` to support requests without redirections. Refactored e2e tests to utilize new client methods for device flow and activation scenarios. Signed-off-by: karo --- client/oidc.go | 87 ++++++++++++++ client/oidc_test.go | 147 +++++++++++++++++++++++ client/transport.go | 15 +++ tests/e2e/deviceflow_test.go | 219 +++++++++++++++++++++++++++++++++++ tests/e2e/harness.go | 4 + 5 files changed, 472 insertions(+) create mode 100644 client/oidc_test.go create mode 100644 tests/e2e/deviceflow_test.go diff --git a/client/oidc.go b/client/oidc.go index 01207279..b73bc3ad 100644 --- a/client/oidc.go +++ b/client/oidc.go @@ -1,6 +1,7 @@ package client import ( + "bytes" "context" "encoding/json" "errors" @@ -15,6 +16,7 @@ var ( ErrAccessDenied = errors.New("access_denied") ErrExpiredToken = errors.New("expired_token") ErrSlowDown = errors.New("slow_down") + ErrInvalidUserCode = errors.New("invalid_user_code") ) type DeviceAuthorizationResponse struct { @@ -38,6 +40,9 @@ type DeviceTokenResponse struct { type OidcClient interface { BeginDeviceFlow(ctx context.Context, clientId string, scope string) (DeviceAuthorizationResponse, error) PollDeviceToken(ctx context.Context, clientId string, deviceCode string) (DeviceTokenResponse, error) + PostActivate(ctx context.Context, userCode string) (loginToken string, err error) + VerifyPassword(ctx context.Context, loginToken string, username string, password string) error + FinishLogin(ctx context.Context, loginToken string) error } type oidcClient struct { @@ -124,3 +129,85 @@ func (o *oidcClient) PollDeviceToken(ctx context.Context, clientId string, devic return result, nil } + +func (o *oidcClient) PostActivate(ctx context.Context, userCode string) (string, error) { + formValues := url.Values{"user_code": {userCode}} + + req, err := o.transport.NewOidcRequest(ctx, http.MethodPost, "/activate", strings.NewReader(formValues.Encode())) + if err != nil { + return "", fmt.Errorf("creating request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := o.transport.DoNoRedirect(req) + if err != nil { + return "", fmt.Errorf("doing request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNotFound { + return "", ErrInvalidUserCode + } + if resp.StatusCode != http.StatusFound { + return "", ApiError{Message: resp.Status, Code: resp.StatusCode} + } + + location := resp.Header.Get("Location") + parsed, err := url.Parse(location) + if err != nil { + return "", fmt.Errorf("parsing redirect location: %w", err) + } + + token := parsed.Query().Get("token") + if token == "" { + return "", fmt.Errorf("no login token in redirect location") + } + + return token, nil +} + +func (o *oidcClient) VerifyPassword(ctx context.Context, loginToken string, username string, password string) error { + body, err := json.Marshal(map[string]string{"username": username, "password": password}) + if err != nil { + return fmt.Errorf("marshaling request: %w", err) + } + + req, err := o.transport.NewRootRequest(ctx, http.MethodPost, fmt.Sprintf("/logins/%s/verify-password", loginToken), bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("creating request: %w", err) + } + + resp, err := o.transport.DoRaw(req) + if err != nil { + return fmt.Errorf("doing request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusUnauthorized { + return fmt.Errorf("invalid credentials") + } + if resp.StatusCode >= 400 { + return ApiError{Message: resp.Status, Code: resp.StatusCode} + } + + return nil +} + +func (o *oidcClient) FinishLogin(ctx context.Context, loginToken string) error { + req, err := o.transport.NewRootRequest(ctx, http.MethodPost, fmt.Sprintf("/logins/%s/finish-login", loginToken), nil) + if err != nil { + return fmt.Errorf("creating request: %w", err) + } + + resp, err := o.transport.DoNoRedirect(req) + if err != nil { + return fmt.Errorf("doing request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode >= 400 { + return ApiError{Message: resp.Status, Code: resp.StatusCode} + } + + return nil +} diff --git a/client/oidc_test.go b/client/oidc_test.go new file mode 100644 index 00000000..755b6a40 --- /dev/null +++ b/client/oidc_test.go @@ -0,0 +1,147 @@ +package client + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/suite" +) + +type OidcClientSuite struct { + suite.Suite +} + +func TestOidcClientSuite(t *testing.T) { + t.Parallel() + suite.Run(t, new(OidcClientSuite)) +} + +func (s *OidcClientSuite) TestBeginDeviceFlow_HappyPath() { + expected := DeviceAuthorizationResponse{ + DeviceCode: "device-code-abc", + UserCode: "ABCD-EFGH", + VerificationUri: "http://localhost/oidc/test/activate", + VerificationUriComplete: "http://localhost/oidc/test/activate?user_code=ABCD-EFGH", + ExpiresIn: 600, + Interval: 5, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.Equal(http.MethodPost, r.Method) + s.Equal("/oidc/test/device", r.URL.Path) + s.Equal("application/x-www-form-urlencoded", r.Header.Get("Content-Type")) + + s.NoError(r.ParseForm()) + s.Equal("my-app", r.Form.Get("client_id")) + s.Equal("openid profile", r.Form.Get("scope")) + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(expected) + })) + defer server.Close() + + testee := NewClient(server.URL, "test").Oidc() + + result, err := testee.BeginDeviceFlow(s.T().Context(), "my-app", "openid profile") + + s.Require().NoError(err) + s.Equal(expected, result) +} + +func (s *OidcClientSuite) TestBeginDeviceFlow_RejectsUnknownApp() { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{ + "error": "invalid_client", + "error_description": "application not found", + }) + })) + defer server.Close() + + testee := NewClient(server.URL, "test").Oidc() + + _, err := testee.BeginDeviceFlow(s.T().Context(), "nonexistent", "openid") + + s.Require().Error(err) +} + +func (s *OidcClientSuite) TestPollDeviceToken_HappyPath() { + expected := DeviceTokenResponse{ + TokenType: "Bearer", + IdToken: "id-token-value", + AccessToken: "access-token-value", + RefreshToken: "refresh-token-value", + Scope: "openid", + ExpiresIn: 3600, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s.Equal(http.MethodPost, r.Method) + s.Equal("/oidc/test/token", r.URL.Path) + s.Equal("application/x-www-form-urlencoded", r.Header.Get("Content-Type")) + + s.NoError(r.ParseForm()) + s.Equal("urn:ietf:params:oauth:grant-type:device_code", r.Form.Get("grant_type")) + s.Equal("my-app", r.Form.Get("client_id")) + s.Equal("device-code-abc", r.Form.Get("device_code")) + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(expected) + })) + defer server.Close() + + testee := NewClient(server.URL, "test").Oidc() + + result, err := testee.PollDeviceToken(s.T().Context(), "my-app", "device-code-abc") + + s.Require().NoError(err) + s.Equal(expected, result) +} + +func (s *OidcClientSuite) TestPollDeviceToken_AuthorizationPending() { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "authorization_pending"}) + })) + defer server.Close() + + testee := NewClient(server.URL, "test").Oidc() + + _, err := testee.PollDeviceToken(s.T().Context(), "my-app", "device-code-abc") + + s.ErrorIs(err, ErrAuthorizationPending) +} + +func (s *OidcClientSuite) TestPollDeviceToken_AccessDenied() { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "access_denied"}) + })) + defer server.Close() + + testee := NewClient(server.URL, "test").Oidc() + + _, err := testee.PollDeviceToken(s.T().Context(), "my-app", "device-code-abc") + + s.ErrorIs(err, ErrAccessDenied) +} + +func (s *OidcClientSuite) TestPollDeviceToken_ExpiredToken() { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]string{"error": "expired_token"}) + })) + defer server.Close() + + testee := NewClient(server.URL, "test").Oidc() + + _, err := testee.PollDeviceToken(s.T().Context(), "my-app", "device-code-abc") + + s.ErrorIs(err, ErrExpiredToken) +} diff --git a/client/transport.go b/client/transport.go index 56c25d8c..2b79a037 100644 --- a/client/transport.go +++ b/client/transport.go @@ -122,3 +122,18 @@ func (t *Transport) DoRaw(req *http.Request) (*http.Response, error) { } return response, nil } + +// DoNoRedirect executes the request without following redirects. +func (t *Transport) DoNoRedirect(req *http.Request) (*http.Response, error) { + noRedirectClient := &http.Client{ + Transport: t.client.Transport, + CheckRedirect: func(_ *http.Request, _ []*http.Request) error { + return http.ErrUseLastResponse + }, + } + response, err := noRedirectClient.Do(req) + if err != nil { + return nil, fmt.Errorf("doing request: %w", err) + } + return response, nil +} diff --git a/tests/e2e/deviceflow_test.go b/tests/e2e/deviceflow_test.go new file mode 100644 index 00000000..ed4cb66d --- /dev/null +++ b/tests/e2e/deviceflow_test.go @@ -0,0 +1,219 @@ +//go:build e2e + +package e2e + +import ( + "Keyline/client" + "Keyline/internal/authentication" + "Keyline/internal/commands" + "Keyline/internal/database" + "Keyline/internal/middlewares" + "Keyline/internal/repositories" + "Keyline/utils" + "context" + "fmt" + + "github.com/The127/ioc" + "github.com/The127/mediatr" + "github.com/google/uuid" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +const ( + deviceAppName = "test-device-app" + deviceUserUsername = "test-device-user" + deviceUserPassword = "test-device-password-1" +) + +var _ = Describe("Device Authorization Grant", Ordered, func() { + var h *harness + var deviceAppId uuid.UUID + + BeforeAll(func() { + h = newE2eTestHarness(nil) + var err error + deviceAppId, err = setupDeviceFlowFixtures(h.Scope()) + Expect(err).ToNot(HaveOccurred()) + _ = deviceAppId + }) + + AfterAll(func() { + h.Close() + }) + + Describe("POST /device", func() { + It("rejects missing client_id", func() { + _, err := h.Client().Oidc().BeginDeviceFlow(h.Ctx(), "", "openid") + Expect(err).To(HaveOccurred()) + }) + + It("rejects application without device flow enabled", func() { + _, err := h.Client().Oidc().BeginDeviceFlow(h.Ctx(), commands.AdminApplicationName, "openid") + Expect(err).To(HaveOccurred()) + }) + + It("rejects missing openid scope", func() { + _, err := h.Client().Oidc().BeginDeviceFlow(h.Ctx(), deviceAppName, "profile") + Expect(err).To(HaveOccurred()) + }) + + It("returns device authorization response", func() { + resp, err := h.Client().Oidc().BeginDeviceFlow(h.Ctx(), deviceAppName, "openid") + Expect(err).ToNot(HaveOccurred()) + Expect(resp.DeviceCode).ToNot(BeEmpty()) + Expect(resp.UserCode).ToNot(BeEmpty()) + Expect(resp.VerificationUri).ToNot(BeEmpty()) + Expect(resp.VerificationUriComplete).ToNot(BeEmpty()) + Expect(resp.ExpiresIn).To(BeEquivalentTo(600)) + Expect(resp.Interval).To(BeEquivalentTo(5)) + }) + }) + + Describe("POST /activate", func() { + It("rejects unknown user_code", func() { + _, err := h.Client().Oidc().PostActivate(h.Ctx(), "XXXX-XXXX") + Expect(err).To(MatchError(client.ErrInvalidUserCode)) + }) + }) + + Describe("Full device authorization flow", func() { + It("issues tokens after user approves", func() { + // Step 1: CLI requests device authorization + deviceResp, err := h.Client().Oidc().BeginDeviceFlow(h.Ctx(), deviceAppName, "openid") + Expect(err).ToNot(HaveOccurred()) + Expect(deviceResp.DeviceCode).ToNot(BeEmpty()) + Expect(deviceResp.UserCode).ToNot(BeEmpty()) + + // Step 2: Poll while pending + _, pollErr := h.Client().Oidc().PollDeviceToken(h.Ctx(), deviceAppName, deviceResp.DeviceCode) + Expect(pollErr).To(MatchError(client.ErrAuthorizationPending)) + + // Step 3: User submits user_code on activation page + loginToken, err := h.Client().Oidc().PostActivate(h.Ctx(), deviceResp.UserCode) + Expect(err).ToNot(HaveOccurred()) + Expect(loginToken).ToNot(BeEmpty()) + + // Step 4: User completes login (verify password) + err = h.Client().Oidc().VerifyPassword(h.Ctx(), loginToken, deviceUserUsername, deviceUserPassword) + Expect(err).ToNot(HaveOccurred()) + + // Step 5: Finish login — marks device code as authorized + err = h.Client().Oidc().FinishLogin(h.Ctx(), loginToken) + Expect(err).ToNot(HaveOccurred()) + + // Step 6: CLI polls and receives tokens + tokenResp, err := h.Client().Oidc().PollDeviceToken(h.Ctx(), deviceAppName, deviceResp.DeviceCode) + Expect(err).ToNot(HaveOccurred()) + Expect(tokenResp.AccessToken).ToNot(BeEmpty()) + Expect(tokenResp.IdToken).ToNot(BeEmpty()) + Expect(tokenResp.RefreshToken).ToNot(BeEmpty()) + Expect(tokenResp.TokenType).To(Equal("Bearer")) + }) + + It("rejects double use of device_code", func() { + deviceResp, err := h.Client().Oidc().BeginDeviceFlow(h.Ctx(), deviceAppName, "openid") + Expect(err).ToNot(HaveOccurred()) + + loginToken, err := h.Client().Oidc().PostActivate(h.Ctx(), deviceResp.UserCode) + Expect(err).ToNot(HaveOccurred()) + + err = h.Client().Oidc().VerifyPassword(h.Ctx(), loginToken, deviceUserUsername, deviceUserPassword) + Expect(err).ToNot(HaveOccurred()) + + err = h.Client().Oidc().FinishLogin(h.Ctx(), loginToken) + Expect(err).ToNot(HaveOccurred()) + + // First poll succeeds + _, err = h.Client().Oidc().PollDeviceToken(h.Ctx(), deviceAppName, deviceResp.DeviceCode) + Expect(err).ToNot(HaveOccurred()) + + // Second poll fails — one-time use + _, err = h.Client().Oidc().PollDeviceToken(h.Ctx(), deviceAppName, deviceResp.DeviceCode) + Expect(err).To(MatchError(client.ErrExpiredToken)) + }) + + It("returns expired_token for unknown device_code", func() { + _, err := h.Client().Oidc().PollDeviceToken(h.Ctx(), deviceAppName, "nonexistent-device-code") + Expect(err).To(MatchError(client.ErrExpiredToken)) + }) + }) +}) + +func setupDeviceFlowFixtures(scope *ioc.DependencyProvider) (uuid.UUID, error) { + subscope := scope.NewScope() + defer subscope.Close() + + ctx := context.Background() + ctx = middlewares.ContextWithScope(ctx, subscope) + ctx = authentication.ContextWithCurrentUser(ctx, authentication.SystemUser()) + + m := ioc.GetDependency[mediatr.Mediator](subscope) + dbContext := ioc.GetDependency[database.Context](subscope) + + _, err := mediatr.Send[*commands.CreateProjectResponse](ctx, m, commands.CreateProject{ + VirtualServerName: "test-vs", + Slug: "device-flow-project", + Name: "Device Flow Project", + }) + if err != nil { + return uuid.Nil, fmt.Errorf("creating project: %w", err) + } + if err := dbContext.SaveChanges(ctx); err != nil { + return uuid.Nil, fmt.Errorf("saving project: %w", err) + } + + appResp, err := mediatr.Send[*commands.CreateApplicationResponse](ctx, m, commands.CreateApplication{ + VirtualServerName: "test-vs", + ProjectSlug: "device-flow-project", + Name: deviceAppName, + DisplayName: "Test Device App", + Type: repositories.ApplicationTypePublic, + RedirectUris: []string{"http://localhost:9999/callback"}, + PostLogoutRedirectUris: []string{}, + }) + if err != nil { + return uuid.Nil, fmt.Errorf("creating application: %w", err) + } + if err := dbContext.SaveChanges(ctx); err != nil { + return uuid.Nil, fmt.Errorf("saving application: %w", err) + } + + _, err = mediatr.Send[*commands.PatchApplicationResponse](ctx, m, commands.PatchApplication{ + VirtualServerName: "test-vs", + ProjectSlug: "device-flow-project", + ApplicationId: appResp.Id, + DeviceFlowEnabled: utils.Ptr(true), + }) + if err != nil { + return uuid.Nil, fmt.Errorf("enabling device flow: %w", err) + } + if err := dbContext.SaveChanges(ctx); err != nil { + return uuid.Nil, fmt.Errorf("saving device flow flag: %w", err) + } + + userResp, err := mediatr.Send[*commands.CreateUserResponse](ctx, m, commands.CreateUser{ + VirtualServerName: "test-vs", + DisplayName: "Test Device User", + Username: deviceUserUsername, + Email: deviceUserUsername + "@test.local", + EmailVerified: true, + }) + if err != nil { + return uuid.Nil, fmt.Errorf("creating user: %w", err) + } + if err := dbContext.SaveChanges(ctx); err != nil { + return uuid.Nil, fmt.Errorf("saving user: %w", err) + } + + passwordCred := repositories.NewCredential(userResp.Id, &repositories.CredentialPasswordDetails{ + HashedPassword: utils.HashPassword(deviceUserPassword), + Temporary: false, + }) + dbContext.Credentials().Insert(passwordCred) + if err := dbContext.SaveChanges(ctx); err != nil { + return uuid.Nil, fmt.Errorf("saving password credential: %w", err) + } + + return appResp.Id, nil +} diff --git a/tests/e2e/harness.go b/tests/e2e/harness.go index 631e7fd4..86b5a9aa 100644 --- a/tests/e2e/harness.go +++ b/tests/e2e/harness.go @@ -93,6 +93,10 @@ func (h *harness) ApiUrl() string { return h.serverUrl } +func (h *harness) Scope() *ioc.DependencyProvider { + return h.scope +} + func newE2eTestHarness(tokenSourceGenerator func(ctx context.Context, url string) oauth2.TokenSource) *harness { ctx := context.Background() dc := ioc.NewDependencyCollection() From 8bec5174c76bb3d0447c794110351871997ae692 Mon Sep 17 00:00:00 2001 From: karo Date: Thu, 19 Mar 2026 11:20:21 +0100 Subject: [PATCH 08/10] feat: support `deviceFlowEnabled` in application creation and virtual server commands Extended DTOs, handlers, commands, and configuration to include the `deviceFlowEnabled` property for applications. Signed-off-by: karo --- cmd/api/main.go | 13 +++++++------ internal/commands/CreateApplication.go | 2 ++ internal/commands/CreateVirtualServer.go | 14 ++++++++------ internal/config/config.go | 1 + internal/handlers/applications.go | 2 ++ 5 files changed, 20 insertions(+), 12 deletions(-) diff --git a/cmd/api/main.go b/cmd/api/main.go index c7e35095..c039a92a 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -180,12 +180,13 @@ func initApplication(dp *ioc.DependencyProvider) { var apps []commands.CreateVirtualServerProjectApplication = nil for _, app := range project.Applications { apps = append(apps, commands.CreateVirtualServerProjectApplication{ - Name: app.Name, - DisplayName: app.DisplayName, - Type: app.Type, - HashedSecret: app.HashedSecret, - RedirectUris: app.RedirectUris, - PostLogoutUris: app.PostLogoutRedirectUris, + Name: app.Name, + DisplayName: app.DisplayName, + Type: app.Type, + HashedSecret: app.HashedSecret, + RedirectUris: app.RedirectUris, + PostLogoutUris: app.PostLogoutRedirectUris, + DeviceFlowEnabled: app.DeviceFlowEnabled, }) } diff --git a/internal/commands/CreateApplication.go b/internal/commands/CreateApplication.go index ec25f598..ee981c8f 100644 --- a/internal/commands/CreateApplication.go +++ b/internal/commands/CreateApplication.go @@ -27,6 +27,7 @@ type CreateApplication struct { HashedSecret *string AccessTokenHeaderType string + DeviceFlowEnabled bool } func (c CreateApplication) LogRequest() bool { @@ -88,6 +89,7 @@ func HandleCreateApplication(ctx context.Context, command CreateApplication) (*C application.SetPostLogoutRedirectUris(command.PostLogoutRedirectUris) application.SetAccessTokenHeaderType(command.AccessTokenHeaderType) + application.SetDeviceFlowEnabled(command.DeviceFlowEnabled) dbContext.Applications().Insert(application) diff --git a/internal/commands/CreateVirtualServer.go b/internal/commands/CreateVirtualServer.go index 1227ff2d..21bc8c8f 100644 --- a/internal/commands/CreateVirtualServer.go +++ b/internal/commands/CreateVirtualServer.go @@ -56,12 +56,13 @@ type CreateVirtualServerProjectRole struct { } type CreateVirtualServerProjectApplication struct { - Name string - DisplayName string - Type string - HashedSecret *string - RedirectUris []string - PostLogoutUris []string + Name string + DisplayName string + Type string + HashedSecret *string + RedirectUris []string + PostLogoutUris []string + DeviceFlowEnabled bool } type CreateVirtualServerProject struct { @@ -162,6 +163,7 @@ func HandleCreateVirtualServer(ctx context.Context, command CreateVirtualServer) newApp.SetHashedSecret(*app.HashedSecret) } newApp.SetPostLogoutRedirectUris(app.PostLogoutUris) + newApp.SetDeviceFlowEnabled(app.DeviceFlowEnabled) dbContext.Applications().Insert(newApp) } diff --git a/internal/config/config.go b/internal/config/config.go index 28d54914..7abad46d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -147,6 +147,7 @@ type InitialProjectConfig struct { HashedSecret *string RedirectUris []string PostLogoutRedirectUris []string + DeviceFlowEnabled bool } ResourceServers []struct { Slug string diff --git a/internal/handlers/applications.go b/internal/handlers/applications.go index cd9fe48e..82260b54 100644 --- a/internal/handlers/applications.go +++ b/internal/handlers/applications.go @@ -24,6 +24,7 @@ type CreateApplicationRequestDto struct { PostLogoutUris []string `json:"postLogoutUris" validate:"dive,url"` Type string `json:"type" validate:"required,oneof=public confidential"` AccessTokenHeaderType *string `json:"accessTokenHeaderType" validate:"omitempty,oneof=at+jwt JWT"` + DeviceFlowEnabled bool `json:"deviceFlowEnabled"` } type CreateApplicationResponseDto struct { @@ -86,6 +87,7 @@ func CreateApplication(w http.ResponseWriter, r *http.Request) { RedirectUris: dto.RedirectUris, PostLogoutRedirectUris: utils.EmptyIfNil(dto.PostLogoutUris), AccessTokenHeaderType: accessTokenHeaderType, + DeviceFlowEnabled: dto.DeviceFlowEnabled, }) if err != nil { utils.HandleHttpError(w, err) From 00db6e92306640b14ce0750db41a7af3baeadefe Mon Sep 17 00:00:00 2001 From: karo Date: Thu, 19 Mar 2026 11:49:22 +0100 Subject: [PATCH 09/10] refactor: replace `pq.ByteaArray` with `sql.NullString` in audit log mapping Updated audit log mapping to use `sql.NullString` for nullable string fields. Removed reliance on `pq.ByteaArray` and introduced helper functions for wrapping and unwrapping nullable values. Signed-off-by: karo --- internal/repositories/postgres/auditlogs.go | 22 +++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/internal/repositories/postgres/auditlogs.go b/internal/repositories/postgres/auditlogs.go index ab4d1e92..922df3d1 100644 --- a/internal/repositories/postgres/auditlogs.go +++ b/internal/repositories/postgres/auditlogs.go @@ -11,7 +11,6 @@ import ( "fmt" "github.com/google/uuid" - "github.com/lib/pq" "github.com/huandu/go-sqlbuilder" ) @@ -22,15 +21,23 @@ type postgresAuditLog struct { userId *uuid.UUID requestType string request string - response *pq.ByteaArray + response sql.NullString allowed bool - allowReasonType string + allowReasonType sql.NullString allowReason *string } func mapAuditLog(auditLog *repositories.AuditLog) *postgresAuditLog { return &postgresAuditLog{ postgresBaseModel: mapBase(auditLog.BaseModel), + virtualServerId: auditLog.VirtualServerId(), + userId: auditLog.UserId(), + requestType: auditLog.RequestType(), + request: auditLog.Request(), + response: pghelpers.WrapStringPointer(auditLog.Response()), + allowed: auditLog.Allowed(), + allowReasonType: pghelpers.WrapStringPointer(auditLog.AllowReasonType()), + allowReason: auditLog.AllowReason(), } } @@ -41,9 +48,9 @@ func (a *postgresAuditLog) Map() *repositories.AuditLog { a.userId, a.requestType, a.request, - nil, // TODO + pghelpers.UnwrapNullString(a.response), a.allowed, - &a.allowReasonType, + pghelpers.UnwrapNullString(a.allowReasonType), a.allowReason, ) } @@ -153,8 +160,11 @@ func (r *AuditLogRepository) ExecuteInsert(ctx context.Context, tx *sql.Tx, audi mapped := mapAuditLog(auditLog) s := sqlbuilder.InsertInto("audit_logs"). - Cols("virtual_server_id", "user_id", "request_type", "request", "response", "allowed", "allow_reason_type", "allow_reason"). + Cols("id", "audit_created_at", "audit_updated_at", "virtual_server_id", "user_id", "request_type", "request", "response", "allowed", "allow_reason_type", "allow_reason"). Values( + mapped.id, + mapped.auditCreatedAt, + mapped.auditUpdatedAt, mapped.virtualServerId, mapped.userId, mapped.requestType, From fc688ec4fb9150e5b17cb4d858420287c239b2d1 Mon Sep 17 00:00:00 2001 From: karo Date: Thu, 19 Mar 2026 12:01:11 +0100 Subject: [PATCH 10/10] chore: suppress errcheck linter warnings for `defer resp.Body.Close()` calls in OIDC client Signed-off-by: karo --- client/oidc.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/client/oidc.go b/client/oidc.go index b73bc3ad..240903f0 100644 --- a/client/oidc.go +++ b/client/oidc.go @@ -69,7 +69,7 @@ func (o *oidcClient) BeginDeviceFlow(ctx context.Context, clientId string, scope if err != nil { return DeviceAuthorizationResponse{}, fmt.Errorf("doing request: %w", err) } - defer resp.Body.Close() + defer resp.Body.Close() //nolint:errcheck var result DeviceAuthorizationResponse if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { @@ -96,7 +96,7 @@ func (o *oidcClient) PollDeviceToken(ctx context.Context, clientId string, devic if err != nil { return DeviceTokenResponse{}, fmt.Errorf("doing request: %w", err) } - defer resp.Body.Close() + defer resp.Body.Close() //nolint:errcheck if resp.StatusCode == http.StatusBadRequest { var oauthErr struct { @@ -143,7 +143,7 @@ func (o *oidcClient) PostActivate(ctx context.Context, userCode string) (string, if err != nil { return "", fmt.Errorf("doing request: %w", err) } - defer resp.Body.Close() + defer resp.Body.Close() //nolint:errcheck if resp.StatusCode == http.StatusNotFound { return "", ErrInvalidUserCode @@ -181,7 +181,7 @@ func (o *oidcClient) VerifyPassword(ctx context.Context, loginToken string, user if err != nil { return fmt.Errorf("doing request: %w", err) } - defer resp.Body.Close() + defer resp.Body.Close() //nolint:errcheck if resp.StatusCode == http.StatusUnauthorized { return fmt.Errorf("invalid credentials") @@ -203,7 +203,7 @@ func (o *oidcClient) FinishLogin(ctx context.Context, loginToken string) error { if err != nil { return fmt.Errorf("doing request: %w", err) } - defer resp.Body.Close() + defer resp.Body.Close() //nolint:errcheck if resp.StatusCode >= 400 { return ApiError{Message: resp.Status, Code: resp.StatusCode}