Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package config

import (
"context"
"log/slog"
"net/http"

v3 "github.com/pb33f/libopenapi/datamodel/high/v3"
"github.com/santhosh-tekuri/jsonschema/v6"

"github.com/pb33f/libopenapi-validator/cache"
Expand All @@ -18,6 +21,18 @@ type RegexCache interface {
Store(key, value any) // Set a compiled regex to the cache
}

// AuthenticationFunc validates a security scheme for an HTTP request.
// Return nil when the scheme is satisfied; return an error to fail the current security requirement.
type AuthenticationFunc func(context.Context, *AuthenticationInput) error

// AuthenticationInput contains the request and OpenAPI security scheme details passed to an AuthenticationFunc.
type AuthenticationInput struct {
Request *http.Request
SecuritySchemeName string
SecurityScheme *v3.SecurityScheme
Scopes []string
}

// ValidationOptions A container for validation configuration.
//
// Generally fluent With... style functions are used to establish the desired behavior.
Expand All @@ -27,6 +42,7 @@ type ValidationOptions struct {
FormatAssertions bool
ContentAssertions bool
SecurityValidation bool
AuthenticationFunc AuthenticationFunc
OpenAPIMode bool // Enable OpenAPI-specific vocabulary validation
AllowScalarCoercion bool // Enable string->boolean/number coercion
Formats map[string]func(v any) error
Expand Down Expand Up @@ -77,6 +93,7 @@ func WithExistingOpts(options *ValidationOptions) Option {
o.FormatAssertions = options.FormatAssertions
o.ContentAssertions = options.ContentAssertions
o.SecurityValidation = options.SecurityValidation
o.AuthenticationFunc = options.AuthenticationFunc
o.OpenAPIMode = options.OpenAPIMode
o.AllowScalarCoercion = options.AllowScalarCoercion
o.Formats = options.Formats
Expand Down Expand Up @@ -140,6 +157,14 @@ func WithoutSecurityValidation() Option {
}
}

// WithAuthenticationFunc sets a custom function for validating security requirements.
// When set, the function is authoritative for all security scheme types, including oauth2 and openIdConnect.
func WithAuthenticationFunc(fn AuthenticationFunc) Option {
return func(o *ValidationOptions) {
o.AuthenticationFunc = fn
}
}

// WithCustomFormat adds custom formats and their validators that checks for custom 'format' assertions
// When you add different validators with the same name, they will be overridden,
// and only the last registration will take effect.
Expand Down
38 changes: 38 additions & 0 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package config

import (
"context"
"log/slog"
"sync"
"testing"
Expand Down Expand Up @@ -79,6 +80,25 @@ func TestWithoutSecurityValidation(t *testing.T) {
assert.Nil(t, opts.RegexCache)
}

func TestWithAuthenticationFunc(t *testing.T) {
called := false
authFn := func(ctx context.Context, input *AuthenticationInput) error {
called = true
assert.NotNil(t, ctx)
assert.Equal(t, "ApiKeyAuth", input.SecuritySchemeName)
return nil
}

opts := NewValidationOptions(WithAuthenticationFunc(authFn))

assert.True(t, opts.SecurityValidation)
assert.NotNil(t, opts.AuthenticationFunc)
assert.NoError(t, opts.AuthenticationFunc(context.Background(), &AuthenticationInput{
SecuritySchemeName: "ApiKeyAuth",
}))
assert.True(t, called)
}

func TestWithRegexEngine(t *testing.T) {
// Test with nil regex engine (valid)
var mockEngine jsonschema.RegexpEngine = nil
Expand Down Expand Up @@ -260,6 +280,24 @@ func TestWithExistingOpts_SecurityValidationCopied(t *testing.T) {
assert.True(t, opts2.SecurityValidation)
}

func TestWithExistingOpts_AuthenticationFuncCopied(t *testing.T) {
called := false
authFn := func(context.Context, *AuthenticationInput) error {
called = true
return nil
}

original := &ValidationOptions{
AuthenticationFunc: authFn,
}

opts := NewValidationOptions(WithExistingOpts(original))

assert.NotNil(t, opts.AuthenticationFunc)
assert.NoError(t, opts.AuthenticationFunc(context.Background(), &AuthenticationInput{}))
assert.True(t, called)
}

// Tests for new OpenAPI and scalar coercion configuration options

func TestWithOpenAPIMode(t *testing.T) {
Expand Down
42 changes: 41 additions & 1 deletion parameters/validate_security.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
v3 "github.com/pb33f/libopenapi/datamodel/high/v3"
"github.com/pb33f/libopenapi/orderedmap"

"github.com/pb33f/libopenapi-validator/config"
"github.com/pb33f/libopenapi-validator/errors"
"github.com/pb33f/libopenapi-validator/helpers"
"github.com/pb33f/libopenapi-validator/paths"
Expand Down Expand Up @@ -84,7 +85,7 @@ func (v *paramValidator) ValidateSecurityWithPathItem(request *http.Request, pat
}

secScheme := v.document.Components.SecuritySchemes.GetOrZero(secName)
schemeValid, schemeErrors := v.validateSecurityScheme(secScheme, sec, request, pathValue)
schemeValid, schemeErrors := v.validateSecurityScheme(secName, secScheme, pair.Value(), sec, request, pathValue)
if !schemeValid {
requirementSatisfied = false
requirementErrors = append(requirementErrors, schemeErrors...)
Expand All @@ -103,11 +104,17 @@ func (v *paramValidator) ValidateSecurityWithPathItem(request *http.Request, pat

// validateSecurityScheme checks if a single security scheme is satisfied by the request.
func (v *paramValidator) validateSecurityScheme(
secName string,
secScheme *v3.SecurityScheme,
scopes []string,
sec *base.SecurityRequirement,
request *http.Request,
pathValue string,
) (bool, []*errors.ValidationError) {
if v.options.AuthenticationFunc != nil {
return v.validateAuthenticationFunc(secName, secScheme, scopes, sec, request, pathValue)
}

switch strings.ToLower(secScheme.Type) {
case "http":
return v.validateHTTPSecurityScheme(secScheme, sec, request, pathValue)
Expand All @@ -118,6 +125,39 @@ func (v *paramValidator) validateSecurityScheme(
return true, nil
}

func (v *paramValidator) validateAuthenticationFunc(
secName string,
secScheme *v3.SecurityScheme,
scopes []string,
sec *base.SecurityRequirement,
request *http.Request,
pathValue string,
) (bool, []*errors.ValidationError) {
authErr := v.options.AuthenticationFunc(request.Context(), &config.AuthenticationInput{
Request: request,
SecuritySchemeName: secName,
SecurityScheme: secScheme,
Scopes: scopes,
})
if authErr == nil {
return true, nil
}

validationErrors := []*errors.ValidationError{
{
Message: fmt.Sprintf("Authentication failed for security scheme '%s'", secName),
Reason: authErr.Error(),
ValidationType: helpers.SecurityValidation,
ValidationSubType: secScheme.Type,
SpecLine: sec.GoLow().Requirements.ValueNode.Line,
SpecCol: sec.GoLow().Requirements.ValueNode.Column,
HowToFix: fmt.Sprintf("Provide valid credentials for security scheme '%s'", secName),
},
}
errors.PopulateValidationErrors(validationErrors, request, pathValue)
return false, validationErrors
}

func (v *paramValidator) validateHTTPSecurityScheme(
secScheme *v3.SecurityScheme,
sec *base.SecurityRequirement,
Expand Down
Loading
Loading