Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 38 additions & 11 deletions internal/controller/oidc_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"

"github.com/steveiliop56/tinyauth/internal/service"
"github.com/steveiliop56/tinyauth/internal/utils"
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
Expand Down Expand Up @@ -376,22 +377,48 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
return
}

var token string

authorization := c.GetHeader("Authorization")
if authorization != "" {
tokenType, bearerToken, ok := strings.Cut(authorization, " ")
if !ok {
tlog.App.Warn().Msg("OIDC userinfo accessed with malformed authorization header")
c.JSON(401, gin.H{
"error": "invalid_request",
})
return
}

tokenType, token, ok := strings.Cut(authorization, " ")
if strings.ToLower(tokenType) != "bearer" {
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type")
c.JSON(401, gin.H{
"error": "invalid_request",
})
return
}

if !ok {
token = bearerToken
} else if c.Request.Method == http.MethodPost {
if c.ContentType() != "application/x-www-form-urlencoded" {
tlog.App.Warn().Msg("OIDC userinfo POST accessed with invalid content type")
c.JSON(400, gin.H{
"error": "invalid_request",
})
return
}
token = c.PostForm("access_token")
if token == "" {
tlog.App.Warn().Msg("OIDC userinfo POST accessed without access_token in body")
c.JSON(401, gin.H{
"error": "invalid_request",
})
return
}
} else {
tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header")
c.JSON(401, gin.H{
"error": "invalid_grant",
})
return
}

if strings.ToLower(tokenType) != "bearer" {
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token type")
c.JSON(401, gin.H{
"error": "invalid_grant",
"error": "invalid_request",
})
return
}
Expand Down
122 changes: 122 additions & 0 deletions internal/controller/oidc_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,128 @@ func TestOIDCController(t *testing.T) {
assert.False(t, ok, "Did not expect email claim in userinfo response")
},
},
{
description: "Ensure userinfo forbids access with no authorization header",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)

var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
{
description: "Ensure userinfo forbids access with malformed authorization header",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
req.Header.Set("Authorization", "Bearer")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)

var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
{
description: "Ensure userinfo forbids access with invalid token type",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
req.Header.Set("Authorization", "Basic some-token")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)

var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
{
description: "Ensure userinfo forbids access with empty bearer token",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("GET", "/api/oidc/userinfo", nil)
req.Header.Set("Authorization", "Bearer ")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)

var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, "invalid_grant", res["error"])
},
},
{
description: "Ensure userinfo POST rejects missing access token in body",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader(""))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
router.ServeHTTP(recorder, req)
assert.Equal(t, 401, recorder.Code)

var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
{
description: "Ensure userinfo POST rejects wrong content type",
middlewares: []gin.HandlerFunc{},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader(`{"access_token":"some-token"}`))
req.Header.Set("Content-Type", "application/json")
router.ServeHTTP(recorder, req)
assert.Equal(t, 400, recorder.Code)

var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"])
},
},
{
description: "Ensure userinfo accepts access token via POST body",
middlewares: []gin.HandlerFunc{
simpleCtx,
},
run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) {
tokenTest, found := getTestByDescription("Ensure we can get a token with a valid request")
assert.True(t, found, "Token test not found")
tokenRecorder := httptest.NewRecorder()
tokenTest(t, router, tokenRecorder)

var tokenRes map[string]any
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
assert.NoError(t, err)

accessToken := tokenRes["access_token"].(string)
assert.NotEmpty(t, accessToken)

body := url.Values{}
body.Set("access_token", accessToken)
req := httptest.NewRequest("POST", "/api/oidc/userinfo", strings.NewReader(body.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code)

var userInfoRes map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes)
assert.NoError(t, err)

_, ok := userInfoRes["sub"]
assert.True(t, ok, "Expected sub claim in userinfo response")
},
},
{
description: "Ensure plain PKCE succeeds",
middlewares: []gin.HandlerFunc{
Expand Down