Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ type Client interface {
Application() ApplicationClient
VirtualServer() VirtualServerClient
User() UserClient
Oidc() OidcClient
}

type client struct {
Expand All @@ -27,3 +28,7 @@ func (c *client) VirtualServer() VirtualServerClient {
func (c *client) User() UserClient {
return NewUserClient(c.transport)
}

func (c *client) Oidc() OidcClient {
return NewOidcClient(c.transport)
}
213 changes: 213 additions & 0 deletions client/oidc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
package client

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
)

var (
ErrAuthorizationPending = errors.New("authorization_pending")
ErrAccessDenied = errors.New("access_denied")
ErrExpiredToken = errors.New("expired_token")
ErrSlowDown = errors.New("slow_down")
ErrInvalidUserCode = errors.New("invalid_user_code")
)

type DeviceAuthorizationResponse struct {
DeviceCode string `json:"device_code"`
UserCode string `json:"user_code"`
VerificationUri string `json:"verification_uri"`
VerificationUriComplete string `json:"verification_uri_complete"`
ExpiresIn int `json:"expires_in"`
Interval int `json:"interval"`
}

type DeviceTokenResponse struct {
TokenType string `json:"token_type"`
IdToken string `json:"id_token"`
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
Scope string `json:"scope"`
ExpiresIn int `json:"expires_in"`
}

type OidcClient interface {
BeginDeviceFlow(ctx context.Context, clientId string, scope string) (DeviceAuthorizationResponse, error)
PollDeviceToken(ctx context.Context, clientId string, deviceCode string) (DeviceTokenResponse, error)
PostActivate(ctx context.Context, userCode string) (loginToken string, err error)
VerifyPassword(ctx context.Context, loginToken string, username string, password string) error
FinishLogin(ctx context.Context, loginToken string) error
}

type oidcClient struct {
transport *Transport
}

func NewOidcClient(transport *Transport) OidcClient {
return &oidcClient{transport: transport}
}

func (o *oidcClient) BeginDeviceFlow(ctx context.Context, clientId string, scope string) (DeviceAuthorizationResponse, error) {
formValues := url.Values{
"client_id": {clientId},
"scope": {scope},
}

req, err := o.transport.NewOidcRequest(ctx, http.MethodPost, "/device", strings.NewReader(formValues.Encode()))
if err != nil {
return DeviceAuthorizationResponse{}, fmt.Errorf("creating request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

resp, err := o.transport.Do(req)
if err != nil {
return DeviceAuthorizationResponse{}, fmt.Errorf("doing request: %w", err)
}
defer resp.Body.Close() //nolint:errcheck

var result DeviceAuthorizationResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return DeviceAuthorizationResponse{}, fmt.Errorf("decoding response: %w", err)
}

return result, nil
}

func (o *oidcClient) PollDeviceToken(ctx context.Context, clientId string, deviceCode string) (DeviceTokenResponse, error) {
formValues := url.Values{
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
"client_id": {clientId},
"device_code": {deviceCode},
}

req, err := o.transport.NewOidcRequest(ctx, http.MethodPost, "/token", strings.NewReader(formValues.Encode()))
if err != nil {
return DeviceTokenResponse{}, fmt.Errorf("creating request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

resp, err := o.transport.DoRaw(req)
if err != nil {
return DeviceTokenResponse{}, fmt.Errorf("doing request: %w", err)
}
defer resp.Body.Close() //nolint:errcheck

if resp.StatusCode == http.StatusBadRequest {
var oauthErr struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description"`
}
_ = json.NewDecoder(resp.Body).Decode(&oauthErr)
switch oauthErr.Error {
case "authorization_pending":
return DeviceTokenResponse{}, ErrAuthorizationPending
case "access_denied":
return DeviceTokenResponse{}, ErrAccessDenied
case "expired_token":
return DeviceTokenResponse{}, ErrExpiredToken
case "slow_down":
return DeviceTokenResponse{}, ErrSlowDown
default:
return DeviceTokenResponse{}, fmt.Errorf("oauth error %s: %s", oauthErr.Error, oauthErr.ErrorDescription)
}
}

if resp.StatusCode != http.StatusOK {
return DeviceTokenResponse{}, ApiError{Message: resp.Status, Code: resp.StatusCode}
}

var result DeviceTokenResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return DeviceTokenResponse{}, fmt.Errorf("decoding response: %w", err)
}

return result, nil
}

func (o *oidcClient) PostActivate(ctx context.Context, userCode string) (string, error) {
formValues := url.Values{"user_code": {userCode}}

req, err := o.transport.NewOidcRequest(ctx, http.MethodPost, "/activate", strings.NewReader(formValues.Encode()))
if err != nil {
return "", fmt.Errorf("creating request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

resp, err := o.transport.DoNoRedirect(req)
if err != nil {
return "", fmt.Errorf("doing request: %w", err)
}
defer resp.Body.Close() //nolint:errcheck

if resp.StatusCode == http.StatusNotFound {
return "", ErrInvalidUserCode
}
if resp.StatusCode != http.StatusFound {
return "", ApiError{Message: resp.Status, Code: resp.StatusCode}
}

location := resp.Header.Get("Location")
parsed, err := url.Parse(location)
if err != nil {
return "", fmt.Errorf("parsing redirect location: %w", err)
}

token := parsed.Query().Get("token")
if token == "" {
return "", fmt.Errorf("no login token in redirect location")
}

return token, nil
}

func (o *oidcClient) VerifyPassword(ctx context.Context, loginToken string, username string, password string) error {
body, err := json.Marshal(map[string]string{"username": username, "password": password})
if err != nil {
return fmt.Errorf("marshaling request: %w", err)
}

req, err := o.transport.NewRootRequest(ctx, http.MethodPost, fmt.Sprintf("/logins/%s/verify-password", loginToken), bytes.NewReader(body))
if err != nil {
return fmt.Errorf("creating request: %w", err)
}

resp, err := o.transport.DoRaw(req)
if err != nil {
return fmt.Errorf("doing request: %w", err)
}
defer resp.Body.Close() //nolint:errcheck

if resp.StatusCode == http.StatusUnauthorized {
return fmt.Errorf("invalid credentials")
}
if resp.StatusCode >= 400 {
return ApiError{Message: resp.Status, Code: resp.StatusCode}
}

return nil
}

func (o *oidcClient) FinishLogin(ctx context.Context, loginToken string) error {
req, err := o.transport.NewRootRequest(ctx, http.MethodPost, fmt.Sprintf("/logins/%s/finish-login", loginToken), nil)
if err != nil {
return fmt.Errorf("creating request: %w", err)
}

resp, err := o.transport.DoNoRedirect(req)
if err != nil {
return fmt.Errorf("doing request: %w", err)
}
defer resp.Body.Close() //nolint:errcheck

if resp.StatusCode >= 400 {
return ApiError{Message: resp.Status, Code: resp.StatusCode}
}

return nil
}
147 changes: 147 additions & 0 deletions client/oidc_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package client

import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/suite"
)

type OidcClientSuite struct {
suite.Suite
}

func TestOidcClientSuite(t *testing.T) {
t.Parallel()
suite.Run(t, new(OidcClientSuite))
}

func (s *OidcClientSuite) TestBeginDeviceFlow_HappyPath() {
expected := DeviceAuthorizationResponse{
DeviceCode: "device-code-abc",
UserCode: "ABCD-EFGH",
VerificationUri: "http://localhost/oidc/test/activate",
VerificationUriComplete: "http://localhost/oidc/test/activate?user_code=ABCD-EFGH",
ExpiresIn: 600,
Interval: 5,
}

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.Equal(http.MethodPost, r.Method)
s.Equal("/oidc/test/device", r.URL.Path)
s.Equal("application/x-www-form-urlencoded", r.Header.Get("Content-Type"))

s.NoError(r.ParseForm())
s.Equal("my-app", r.Form.Get("client_id"))
s.Equal("openid profile", r.Form.Get("scope"))

w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(expected)
}))
defer server.Close()

testee := NewClient(server.URL, "test").Oidc()

result, err := testee.BeginDeviceFlow(s.T().Context(), "my-app", "openid profile")

s.Require().NoError(err)
s.Equal(expected, result)
}

func (s *OidcClientSuite) TestBeginDeviceFlow_RejectsUnknownApp() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
_ = json.NewEncoder(w).Encode(map[string]string{
"error": "invalid_client",
"error_description": "application not found",
})
}))
defer server.Close()

testee := NewClient(server.URL, "test").Oidc()

_, err := testee.BeginDeviceFlow(s.T().Context(), "nonexistent", "openid")

s.Require().Error(err)
}

func (s *OidcClientSuite) TestPollDeviceToken_HappyPath() {
expected := DeviceTokenResponse{
TokenType: "Bearer",
IdToken: "id-token-value",
AccessToken: "access-token-value",
RefreshToken: "refresh-token-value",
Scope: "openid",
ExpiresIn: 3600,
}

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
s.Equal(http.MethodPost, r.Method)
s.Equal("/oidc/test/token", r.URL.Path)
s.Equal("application/x-www-form-urlencoded", r.Header.Get("Content-Type"))

s.NoError(r.ParseForm())
s.Equal("urn:ietf:params:oauth:grant-type:device_code", r.Form.Get("grant_type"))
s.Equal("my-app", r.Form.Get("client_id"))
s.Equal("device-code-abc", r.Form.Get("device_code"))

w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(expected)
}))
defer server.Close()

testee := NewClient(server.URL, "test").Oidc()

result, err := testee.PollDeviceToken(s.T().Context(), "my-app", "device-code-abc")

s.Require().NoError(err)
s.Equal(expected, result)
}

func (s *OidcClientSuite) TestPollDeviceToken_AuthorizationPending() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
_ = json.NewEncoder(w).Encode(map[string]string{"error": "authorization_pending"})
}))
defer server.Close()

testee := NewClient(server.URL, "test").Oidc()

_, err := testee.PollDeviceToken(s.T().Context(), "my-app", "device-code-abc")

s.ErrorIs(err, ErrAuthorizationPending)
}

func (s *OidcClientSuite) TestPollDeviceToken_AccessDenied() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
_ = json.NewEncoder(w).Encode(map[string]string{"error": "access_denied"})
}))
defer server.Close()

testee := NewClient(server.URL, "test").Oidc()

_, err := testee.PollDeviceToken(s.T().Context(), "my-app", "device-code-abc")

s.ErrorIs(err, ErrAccessDenied)
}

func (s *OidcClientSuite) TestPollDeviceToken_ExpiredToken() {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
_ = json.NewEncoder(w).Encode(map[string]string{"error": "expired_token"})
}))
defer server.Close()

testee := NewClient(server.URL, "test").Oidc()

_, err := testee.PollDeviceToken(s.T().Context(), "my-app", "device-code-abc")

s.ErrorIs(err, ErrExpiredToken)
}
Loading