Skip to content

Commit b2a1bfb

Browse files
committed
fix: validate client id on oidc token endpoint
1 parent f1e869a commit b2a1bfb

5 files changed

Lines changed: 27 additions & 54 deletions

File tree

internal/controller/oidc_controller.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ func (controller *OIDCController) Token(c *gin.Context) {
270270

271271
switch req.GrantType {
272272
case "authorization_code":
273-
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code))
273+
entry, err := controller.oidc.GetCodeEntry(c, controller.oidc.Hash(req.Code), client.ClientID)
274274
if err != nil {
275275
if errors.Is(err, service.ErrCodeNotFound) {
276276
tlog.App.Warn().Msg("Code not found")
@@ -286,6 +286,13 @@ func (controller *OIDCController) Token(c *gin.Context) {
286286
})
287287
return
288288
}
289+
if errors.Is(err, service.ErrInvalidClient) {
290+
tlog.App.Warn().Msg("Invalid client ID")
291+
c.JSON(400, gin.H{
292+
"error": "invalid_client",
293+
})
294+
return
295+
}
289296
tlog.App.Warn().Err(err).Msg("Failed to get OIDC code entry")
290297
c.JSON(400, gin.H{
291298
"error": "server_error",

internal/controller/proxy_controller.go

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -185,11 +185,6 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
185185

186186
tlog.App.Trace().Interface("context", userContext).Msg("User context from request")
187187

188-
if userContext.IsBasicAuth && userContext.TotpEnabled {
189-
tlog.App.Debug().Msg("User has TOTP enabled, denying basic auth access")
190-
userContext.IsLoggedIn = false
191-
}
192-
193188
if userContext.IsLoggedIn {
194189
userAllowed := controller.auth.IsUserAllowed(c, userContext, acls)
195190

internal/controller/proxy_controller_test.go

Lines changed: 9 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ func setupProxyController(t *testing.T, middlewares *[]gin.HandlerFunc) (*gin.En
5959
Username: "testuser",
6060
Password: "$2a$10$ne6z693sTgzT3ePoQ05PgOecUHnBjM7sSNj6M.l5CLUP.f6NyCnt.", // test
6161
},
62+
{
63+
Username: "totpuser",
64+
Password: "$2a$10$ne6z693sTgzT3ePoQ05PgOecUHnBjM7sSNj6M.l5CLUP.f6NyCnt.",
65+
TotpSecret: "foo",
66+
},
6267
},
6368
OauthWhitelist: []string{},
6469
SessionExpiry: 3600,
@@ -79,9 +84,11 @@ func setupProxyController(t *testing.T, middlewares *[]gin.HandlerFunc) (*gin.En
7984
return router, recorder, authService
8085
}
8186

87+
// TODO: Needs tests for context middleware
88+
8289
func TestProxyHandler(t *testing.T) {
8390
// Setup
84-
router, recorder, authService := setupProxyController(t, nil)
91+
router, recorder, _ := setupProxyController(t, nil)
8592

8693
// Test invalid proxy
8794
req := httptest.NewRequest("GET", "/api/auth/invalidproxy", nil)
@@ -144,21 +151,6 @@ func TestProxyHandler(t *testing.T) {
144151
assert.Equal(t, 401, recorder.Code)
145152

146153
// Test logged in user
147-
c := gin.CreateTestContextOnly(recorder, router)
148-
149-
err := authService.CreateSessionCookie(c, &repository.Session{
150-
Username: "testuser",
151-
Name: "testuser",
152-
Email: "testuser@example.com",
153-
Provider: "local",
154-
TotpPending: false,
155-
OAuthGroups: "",
156-
})
157-
158-
assert.NilError(t, err)
159-
160-
cookie := c.Writer.Header().Get("Set-Cookie")
161-
162154
router, recorder, _ = setupProxyController(t, &[]gin.HandlerFunc{
163155
func(c *gin.Context) {
164156
c.Set("context", &config.UserContext{
@@ -177,44 +169,15 @@ func TestProxyHandler(t *testing.T) {
177169
})
178170

179171
req = httptest.NewRequest("GET", "/api/auth/traefik", nil)
180-
req.Header.Set("Cookie", cookie)
181172
req.Header.Set("X-Forwarded-Proto", "https")
182173
req.Header.Set("X-Forwarded-Host", "example.com")
183174
req.Header.Set("X-Forwarded-Uri", "/somepath")
184175
req.Header.Set("Accept", "text/html")
185-
router.ServeHTTP(recorder, req)
186176

177+
router.ServeHTTP(recorder, req)
187178
assert.Equal(t, 200, recorder.Code)
188179

189180
assert.Equal(t, "testuser", recorder.Header().Get("Remote-User"))
190181
assert.Equal(t, "testuser", recorder.Header().Get("Remote-Name"))
191182
assert.Equal(t, "testuser@example.com", recorder.Header().Get("Remote-Email"))
192-
193-
// Ensure basic auth is disabled for TOTP enabled users
194-
router, recorder, _ = setupProxyController(t, &[]gin.HandlerFunc{
195-
func(c *gin.Context) {
196-
c.Set("context", &config.UserContext{
197-
Username: "testuser",
198-
Name: "testuser",
199-
Email: "testuser@example.com",
200-
IsLoggedIn: true,
201-
IsBasicAuth: true,
202-
OAuth: false,
203-
Provider: "local",
204-
TotpPending: false,
205-
OAuthGroups: "",
206-
TotpEnabled: true,
207-
})
208-
c.Next()
209-
},
210-
})
211-
212-
req = httptest.NewRequest("GET", "/api/auth/traefik", nil)
213-
req.Header.Set("X-Forwarded-Proto", "https")
214-
req.Header.Set("X-Forwarded-Host", "example.com")
215-
req.Header.Set("X-Forwarded-Uri", "/somepath")
216-
req.SetBasicAuth("testuser", "test")
217-
router.ServeHTTP(recorder, req)
218-
219-
assert.Equal(t, 401, recorder.Code)
220183
}

internal/middleware/context_middleware.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,13 +182,17 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
182182

183183
user := m.auth.GetLocalUser(basic.Username)
184184

185+
if user.TotpSecret != "" {
186+
tlog.App.Debug().Msg("User with TOTP not allowed to login via basic auth")
187+
return
188+
}
189+
185190
c.Set("context", &config.UserContext{
186191
Username: user.Username,
187192
Name: utils.Capitalize(user.Username),
188193
Email: utils.CompileUserEmail(user.Username, m.config.CookieDomain),
189194
Provider: "local",
190195
IsLoggedIn: true,
191-
TotpEnabled: user.TotpSecret != "",
192196
IsBasicAuth: true,
193197
})
194198
c.Next()

internal/service/oidc_service.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ func (service *OIDCService) ValidateGrantType(grantType string) error {
352352
return nil
353353
}
354354

355-
func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string) (repository.OidcCode, error) {
355+
func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, clientId string) (repository.OidcCode, error) {
356356
oidcCode, err := service.queries.GetOidcCode(c, codeHash)
357357

358358
if err != nil {
@@ -374,6 +374,10 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string) (repos
374374
return repository.OidcCode{}, ErrCodeExpired
375375
}
376376

377+
if oidcCode.ClientID != clientId {
378+
return repository.OidcCode{}, ErrInvalidClient
379+
}
380+
377381
return oidcCode, nil
378382
}
379383

0 commit comments

Comments
 (0)