diff --git a/config/config.go b/config/config.go index f67ec0f..31e89de 100644 --- a/config/config.go +++ b/config/config.go @@ -1,8 +1,11 @@ package config import ( + "context" "log/slog" + "net/http" + v3 "github.com/pb33f/libopenapi/datamodel/high/v3" "github.com/santhosh-tekuri/jsonschema/v6" "github.com/pb33f/libopenapi-validator/cache" @@ -18,6 +21,18 @@ type RegexCache interface { Store(key, value any) // Set a compiled regex to the cache } +// AuthenticationFunc validates a security scheme for an HTTP request. +// Return nil when the scheme is satisfied; return an error to fail the current security requirement. +type AuthenticationFunc func(context.Context, *AuthenticationInput) error + +// AuthenticationInput contains the request and OpenAPI security scheme details passed to an AuthenticationFunc. +type AuthenticationInput struct { + Request *http.Request + SecuritySchemeName string + SecurityScheme *v3.SecurityScheme + Scopes []string +} + // ValidationOptions A container for validation configuration. // // Generally fluent With... style functions are used to establish the desired behavior. @@ -27,6 +42,7 @@ type ValidationOptions struct { FormatAssertions bool ContentAssertions bool SecurityValidation bool + AuthenticationFunc AuthenticationFunc OpenAPIMode bool // Enable OpenAPI-specific vocabulary validation AllowScalarCoercion bool // Enable string->boolean/number coercion Formats map[string]func(v any) error @@ -77,6 +93,7 @@ func WithExistingOpts(options *ValidationOptions) Option { o.FormatAssertions = options.FormatAssertions o.ContentAssertions = options.ContentAssertions o.SecurityValidation = options.SecurityValidation + o.AuthenticationFunc = options.AuthenticationFunc o.OpenAPIMode = options.OpenAPIMode o.AllowScalarCoercion = options.AllowScalarCoercion o.Formats = options.Formats @@ -140,6 +157,14 @@ func WithoutSecurityValidation() Option { } } +// WithAuthenticationFunc sets a custom function for validating security requirements. +// When set, the function is authoritative for all security scheme types, including oauth2 and openIdConnect. +func WithAuthenticationFunc(fn AuthenticationFunc) Option { + return func(o *ValidationOptions) { + o.AuthenticationFunc = fn + } +} + // WithCustomFormat adds custom formats and their validators that checks for custom 'format' assertions // When you add different validators with the same name, they will be overridden, // and only the last registration will take effect. diff --git a/config/config_test.go b/config/config_test.go index d2f77dd..c1cea05 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -4,6 +4,7 @@ package config import ( + "context" "log/slog" "sync" "testing" @@ -79,6 +80,25 @@ func TestWithoutSecurityValidation(t *testing.T) { assert.Nil(t, opts.RegexCache) } +func TestWithAuthenticationFunc(t *testing.T) { + called := false + authFn := func(ctx context.Context, input *AuthenticationInput) error { + called = true + assert.NotNil(t, ctx) + assert.Equal(t, "ApiKeyAuth", input.SecuritySchemeName) + return nil + } + + opts := NewValidationOptions(WithAuthenticationFunc(authFn)) + + assert.True(t, opts.SecurityValidation) + assert.NotNil(t, opts.AuthenticationFunc) + assert.NoError(t, opts.AuthenticationFunc(context.Background(), &AuthenticationInput{ + SecuritySchemeName: "ApiKeyAuth", + })) + assert.True(t, called) +} + func TestWithRegexEngine(t *testing.T) { // Test with nil regex engine (valid) var mockEngine jsonschema.RegexpEngine = nil @@ -260,6 +280,24 @@ func TestWithExistingOpts_SecurityValidationCopied(t *testing.T) { assert.True(t, opts2.SecurityValidation) } +func TestWithExistingOpts_AuthenticationFuncCopied(t *testing.T) { + called := false + authFn := func(context.Context, *AuthenticationInput) error { + called = true + return nil + } + + original := &ValidationOptions{ + AuthenticationFunc: authFn, + } + + opts := NewValidationOptions(WithExistingOpts(original)) + + assert.NotNil(t, opts.AuthenticationFunc) + assert.NoError(t, opts.AuthenticationFunc(context.Background(), &AuthenticationInput{})) + assert.True(t, called) +} + // Tests for new OpenAPI and scalar coercion configuration options func TestWithOpenAPIMode(t *testing.T) { diff --git a/parameters/validate_security.go b/parameters/validate_security.go index de4f5b6..6aab8fe 100644 --- a/parameters/validate_security.go +++ b/parameters/validate_security.go @@ -12,6 +12,7 @@ import ( v3 "github.com/pb33f/libopenapi/datamodel/high/v3" "github.com/pb33f/libopenapi/orderedmap" + "github.com/pb33f/libopenapi-validator/config" "github.com/pb33f/libopenapi-validator/errors" "github.com/pb33f/libopenapi-validator/helpers" "github.com/pb33f/libopenapi-validator/paths" @@ -84,7 +85,7 @@ func (v *paramValidator) ValidateSecurityWithPathItem(request *http.Request, pat } secScheme := v.document.Components.SecuritySchemes.GetOrZero(secName) - schemeValid, schemeErrors := v.validateSecurityScheme(secScheme, sec, request, pathValue) + schemeValid, schemeErrors := v.validateSecurityScheme(secName, secScheme, pair.Value(), sec, request, pathValue) if !schemeValid { requirementSatisfied = false requirementErrors = append(requirementErrors, schemeErrors...) @@ -103,11 +104,17 @@ func (v *paramValidator) ValidateSecurityWithPathItem(request *http.Request, pat // validateSecurityScheme checks if a single security scheme is satisfied by the request. func (v *paramValidator) validateSecurityScheme( + secName string, secScheme *v3.SecurityScheme, + scopes []string, sec *base.SecurityRequirement, request *http.Request, pathValue string, ) (bool, []*errors.ValidationError) { + if v.options.AuthenticationFunc != nil { + return v.validateAuthenticationFunc(secName, secScheme, scopes, sec, request, pathValue) + } + switch strings.ToLower(secScheme.Type) { case "http": return v.validateHTTPSecurityScheme(secScheme, sec, request, pathValue) @@ -118,6 +125,39 @@ func (v *paramValidator) validateSecurityScheme( return true, nil } +func (v *paramValidator) validateAuthenticationFunc( + secName string, + secScheme *v3.SecurityScheme, + scopes []string, + sec *base.SecurityRequirement, + request *http.Request, + pathValue string, +) (bool, []*errors.ValidationError) { + authErr := v.options.AuthenticationFunc(request.Context(), &config.AuthenticationInput{ + Request: request, + SecuritySchemeName: secName, + SecurityScheme: secScheme, + Scopes: scopes, + }) + if authErr == nil { + return true, nil + } + + validationErrors := []*errors.ValidationError{ + { + Message: fmt.Sprintf("Authentication failed for security scheme '%s'", secName), + Reason: authErr.Error(), + ValidationType: helpers.SecurityValidation, + ValidationSubType: secScheme.Type, + SpecLine: sec.GoLow().Requirements.ValueNode.Line, + SpecCol: sec.GoLow().Requirements.ValueNode.Column, + HowToFix: fmt.Sprintf("Provide valid credentials for security scheme '%s'", secName), + }, + } + errors.PopulateValidationErrors(validationErrors, request, pathValue) + return false, validationErrors +} + func (v *paramValidator) validateHTTPSecurityScheme( secScheme *v3.SecurityScheme, sec *base.SecurityRequirement, diff --git a/parameters/validate_security_test.go b/parameters/validate_security_test.go index e53d8bc..4ec236a 100644 --- a/parameters/validate_security_test.go +++ b/parameters/validate_security_test.go @@ -4,6 +4,8 @@ package parameters import ( + "context" + stderrors "errors" "net/http" "sync" "testing" @@ -1259,3 +1261,260 @@ components: assert.True(t, valid2) assert.Empty(t, errors2) } + +func TestParamValidator_ValidateSecurity_AuthenticationFunc_OAuth2Scopes(t *testing.T) { + spec := `openapi: 3.1.0 +paths: + /products: + get: + security: + - OAuth2: + - read:products + - write:products +components: + securitySchemes: + OAuth2: + type: oauth2 + flows: + clientCredentials: + tokenUrl: https://example.com/oauth/token + scopes: + read:products: Read products + write:products: Write products +` + + doc, _ := libopenapi.NewDocument([]byte(spec)) + m, _ := doc.BuildV3Model() + + type authContextKey struct{} + var request *http.Request + var gotScopes []string + called := 0 + authFn := func(ctx context.Context, input *config.AuthenticationInput) error { + called++ + assert.Equal(t, "ctx-value", ctx.Value(authContextKey{})) + assert.Equal(t, request, input.Request) + assert.Equal(t, "OAuth2", input.SecuritySchemeName) + assert.Equal(t, "oauth2", input.SecurityScheme.Type) + gotScopes = append([]string(nil), input.Scopes...) + return nil + } + + v := NewParameterValidator(&m.Model, config.WithAuthenticationFunc(authFn)) + request, _ = http.NewRequest(http.MethodGet, "https://things.com/products", nil) + request = request.WithContext(context.WithValue(request.Context(), authContextKey{}, "ctx-value")) + + valid, validationErrors := v.ValidateSecurity(request) + assert.True(t, valid) + assert.Empty(t, validationErrors) + assert.Equal(t, 1, called) + assert.ElementsMatch(t, []string{"read:products", "write:products"}, gotScopes) +} + +func TestParamValidator_ValidateSecurity_AuthenticationFunc_OpenIDConnect(t *testing.T) { + spec := `openapi: 3.1.0 +paths: + /products: + get: + security: + - OpenID: [] +components: + securitySchemes: + OpenID: + type: openIdConnect + openIdConnectUrl: https://example.com/.well-known/openid-configuration +` + + doc, _ := libopenapi.NewDocument([]byte(spec)) + m, _ := doc.BuildV3Model() + + called := false + authFn := func(ctx context.Context, input *config.AuthenticationInput) error { + called = true + assert.NotNil(t, ctx) + assert.Equal(t, "OpenID", input.SecuritySchemeName) + assert.Equal(t, "openIdConnect", input.SecurityScheme.Type) + assert.Empty(t, input.Scopes) + return nil + } + + v := NewParameterValidator(&m.Model, config.WithAuthenticationFunc(authFn)) + request, _ := http.NewRequest(http.MethodGet, "https://things.com/products", nil) + + valid, validationErrors := v.ValidateSecurity(request) + assert.True(t, valid) + assert.Empty(t, validationErrors) + assert.True(t, called) +} + +func TestParamValidator_ValidateSecurity_AuthenticationFunc_ORSuccessAfterFirstFailure(t *testing.T) { + spec := `openapi: 3.1.0 +paths: + /products: + get: + security: + - OAuth2: + - read:products + - ApiKeyAuth: [] +components: + securitySchemes: + OAuth2: + type: oauth2 + flows: + clientCredentials: + tokenUrl: https://example.com/oauth/token + scopes: + read:products: Read products + ApiKeyAuth: + type: apiKey + in: header + name: X-API-Key +` + + doc, _ := libopenapi.NewDocument([]byte(spec)) + m, _ := doc.BuildV3Model() + + calls := make(map[string]int) + authFn := func(_ context.Context, input *config.AuthenticationInput) error { + calls[input.SecuritySchemeName]++ + if input.SecuritySchemeName == "OAuth2" { + return stderrors.New("token missing") + } + return nil + } + + v := NewParameterValidator(&m.Model, config.WithAuthenticationFunc(authFn)) + request, _ := http.NewRequest(http.MethodGet, "https://things.com/products", nil) + + valid, validationErrors := v.ValidateSecurity(request) + assert.True(t, valid) + assert.Empty(t, validationErrors) + assert.Equal(t, 1, calls["OAuth2"]) + assert.Equal(t, 1, calls["ApiKeyAuth"]) +} + +func TestParamValidator_ValidateSecurity_AuthenticationFunc_ANDPartialFailure(t *testing.T) { + spec := `openapi: 3.1.0 +paths: + /products: + get: + security: + - OAuth2: + - read:products + ApiKeyAuth: [] +components: + securitySchemes: + OAuth2: + type: oauth2 + flows: + clientCredentials: + tokenUrl: https://example.com/oauth/token + scopes: + read:products: Read products + ApiKeyAuth: + type: apiKey + in: header + name: X-API-Key +` + + doc, _ := libopenapi.NewDocument([]byte(spec)) + m, _ := doc.BuildV3Model() + + calls := make(map[string]int) + authFn := func(_ context.Context, input *config.AuthenticationInput) error { + calls[input.SecuritySchemeName]++ + if input.SecuritySchemeName == "ApiKeyAuth" { + return stderrors.New("api key denied") + } + return nil + } + + v := NewParameterValidator(&m.Model, config.WithAuthenticationFunc(authFn)) + request, _ := http.NewRequest(http.MethodGet, "https://things.com/products", nil) + + valid, validationErrors := v.ValidateSecurity(request) + assert.False(t, valid) + assert.Len(t, validationErrors, 1) + assert.Equal(t, "Authentication failed for security scheme 'ApiKeyAuth'", validationErrors[0].Message) + assert.Equal(t, "api key denied", validationErrors[0].Reason) + assert.Equal(t, 1, calls["OAuth2"]) + assert.Equal(t, 1, calls["ApiKeyAuth"]) +} + +func TestParamValidator_ValidateSecurity_AuthenticationFunc_ErrorReturned(t *testing.T) { + spec := `openapi: 3.1.0 +paths: + /products: + get: + security: + - OAuth2: + - read:products +components: + securitySchemes: + OAuth2: + type: oauth2 + flows: + clientCredentials: + tokenUrl: https://example.com/oauth/token + scopes: + read:products: Read products +` + + doc, _ := libopenapi.NewDocument([]byte(spec)) + m, _ := doc.BuildV3Model() + + authFn := func(context.Context, *config.AuthenticationInput) error { + return stderrors.New("expired token") + } + + v := NewParameterValidator(&m.Model, config.WithAuthenticationFunc(authFn)) + request, _ := http.NewRequest(http.MethodGet, "https://things.com/products", nil) + + valid, validationErrors := v.ValidateSecurity(request) + assert.False(t, valid) + assert.Len(t, validationErrors, 1) + assert.Equal(t, "Authentication failed for security scheme 'OAuth2'", validationErrors[0].Message) + assert.Equal(t, "expired token", validationErrors[0].Reason) + assert.Equal(t, "security", validationErrors[0].ValidationType) + assert.Equal(t, "oauth2", validationErrors[0].ValidationSubType) + assert.Equal(t, request.Method, validationErrors[0].RequestMethod) + assert.Equal(t, request.URL.Path, validationErrors[0].RequestPath) + assert.Equal(t, "/products", validationErrors[0].SpecPath) +} + +func TestParamValidator_ValidateSecurity_AuthenticationFunc_SkippedWhenSecurityValidationDisabled(t *testing.T) { + spec := `openapi: 3.1.0 +paths: + /products: + get: + security: + - OAuth2: + - read:products +components: + securitySchemes: + OAuth2: + type: oauth2 + flows: + clientCredentials: + tokenUrl: https://example.com/oauth/token + scopes: + read:products: Read products +` + + doc, _ := libopenapi.NewDocument([]byte(spec)) + m, _ := doc.BuildV3Model() + + called := false + authFn := func(context.Context, *config.AuthenticationInput) error { + called = true + return stderrors.New("should not be called") + } + + v := NewParameterValidator(&m.Model, config.WithAuthenticationFunc(authFn), config.WithoutSecurityValidation()) + request, _ := http.NewRequest(http.MethodGet, "https://things.com/products", nil) + + valid, validationErrors := v.ValidateSecurity(request) + assert.True(t, valid) + assert.Empty(t, validationErrors) + assert.False(t, called) +} diff --git a/schema_validation/validate_document.go b/schema_validation/validate_document.go index 1a914d2..d539db4 100644 --- a/schema_validation/validate_document.go +++ b/schema_validation/validate_document.go @@ -8,6 +8,8 @@ import ( "encoding/json" "errors" "fmt" + "strconv" + "strings" "github.com/pb33f/libopenapi" "github.com/santhosh-tekuri/jsonschema/v6" @@ -20,11 +22,200 @@ import ( "github.com/pb33f/libopenapi-validator/helpers" ) -func normalizeJSON(data any) any { - d, _ := json.Marshal(data) +type nonStringMappingKey struct { + Value string + Tag string + Path []string + Line int + Column int + Sequence bool +} + +func normalizeJSON(data any) (any, error) { + d, err := json.Marshal(data) + if err != nil { + return nil, err + } + var normalized any _ = json.Unmarshal(d, &normalized) - return normalized + return normalized, nil +} + +func findNonStringMappingKey(rootNode *yaml.Node) *nonStringMappingKey { + if rootNode == nil { + return nil + } + return findNonStringMappingKeyInNode(rootNode, nil) +} + +func findNonStringMappingKeyInNode(node *yaml.Node, path []string) *nonStringMappingKey { + if node == nil { + return nil + } + + switch node.Kind { + case yaml.DocumentNode: + for _, child := range node.Content { + if found := findNonStringMappingKeyInNode(child, path); found != nil { + return found + } + } + case yaml.MappingNode: + for i := 0; i+1 < len(node.Content); i += 2 { + keyNode := node.Content[i] + valueNode := node.Content[i+1] + if isMergeMappingKey(keyNode) { + if found := findNonStringMappingKeyInMergeValue(valueNode, path); found != nil { + return found + } + continue + } + nextPath := appendPathSegment(path, keyNode.Value) + if !isStringMappingKey(keyNode) { + return &nonStringMappingKey{ + Value: keyNode.Value, + Tag: keyNode.ShortTag(), + Path: nextPath, + Line: keyNode.Line, + Column: keyNode.Column, + Sequence: keyNode.Kind == yaml.SequenceNode, + } + } + if found := findNonStringMappingKeyInNode(valueNode, nextPath); found != nil { + return found + } + } + case yaml.SequenceNode: + for i, child := range node.Content { + if found := findNonStringMappingKeyInNode(child, appendPathSegment(path, strconv.Itoa(i))); found != nil { + return found + } + } + } + + return nil +} + +func findNonStringMappingKeyInMergeValue(node *yaml.Node, path []string) *nonStringMappingKey { + if node == nil { + return nil + } + + switch node.Kind { + case yaml.AliasNode: + return findNonStringMappingKeyInMergeValue(node.Alias, path) + case yaml.SequenceNode: + for _, child := range node.Content { + if found := findNonStringMappingKeyInMergeValue(child, path); found != nil { + return found + } + } + return nil + default: + return findNonStringMappingKeyInNode(node, path) + } +} + +func isStringMappingKey(keyNode *yaml.Node) bool { + if keyNode == nil || keyNode.Kind != yaml.ScalarNode { + return false + } + return keyNode.ShortTag() == "!!str" +} + +func isMergeMappingKey(keyNode *yaml.Node) bool { + if keyNode == nil || keyNode.Kind != yaml.ScalarNode { + return false + } + return keyNode.ShortTag() == "!!merge" && keyNode.Value == "<<" +} + +func appendPathSegment(path []string, segment string) []string { + next := make([]string, 0, len(path)+1) + next = append(next, path...) + return append(next, segment) +} + +func buildJSONPointer(path []string) string { + if len(path) == 0 { + return "" + } + var builder strings.Builder + for _, segment := range path { + builder.WriteByte('/') + builder.WriteString(helpers.EscapeJSONPointerSegment(segment)) + } + return builder.String() +} + +func buildNonStringMappingKeyError(key *nonStringMappingKey) *liberrors.ValidationError { + pointer := buildJSONPointer(key.Path) + reason := fmt.Sprintf("OpenAPI documents require string mapping keys, but found %s key %q at %s", + yamlKeyType(key), key.Value, pointer) + howToFix := "Quote YAML mapping keys that should be strings, because OpenAPI documents must be representable as JSON objects" + + if isOperationResponseStatusCodeKey(key.Path) { + reason = fmt.Sprintf("Response status code keys must be strings, quote %s as %q at %s", + key.Value, key.Value, pointer) + howToFix = fmt.Sprintf("Quote the response status code key, for example use %q instead of %s", + key.Value, key.Value) + } + + return &liberrors.ValidationError{ + ValidationType: helpers.Schema, + ValidationSubType: "document", + Message: "OpenAPI document validation failed", + Reason: reason, + SpecLine: key.Line, + SpecCol: key.Column, + HowToFix: howToFix, + Context: pointer, + } +} + +func yamlKeyType(key *nonStringMappingKey) string { + if key == nil { + return "non-string" + } + if key.Sequence { + return "sequence" + } + return strings.TrimPrefix(key.Tag, "!!") +} + +func isOperationResponseStatusCodeKey(path []string) bool { + if len(path) < 5 || path[0] != "paths" || path[len(path)-2] != "responses" { + return false + } + for _, segment := range path[2 : len(path)-2] { + if isHTTPMethod(segment) { + return true + } + } + return false +} + +func isHTTPMethod(segment string) bool { + switch strings.ToLower(segment) { + case "get", "put", "post", "delete", "options", "head", "patch", "trace": + return true + default: + return false + } +} + +func buildDocumentDecodeError(reason, context string) *liberrors.ValidationError { + return &liberrors.ValidationError{ + ValidationType: helpers.Schema, + ValidationSubType: "document", + Message: "OpenAPI document validation failed", + Reason: reason, + SpecLine: 1, + SpecCol: 0, + HowToFix: "ensure the OpenAPI document is valid YAML/JSON and can be represented as JSON", + Context: context, + } } // ValidateOpenAPIDocument will validate an OpenAPI document against the OpenAPI 2, 3.0 and 3.1 schemas (depending on version) @@ -59,6 +250,12 @@ func ValidateOpenAPIDocumentWithPrecompiled(doc libopenapi.Document, compiledSch return false, validationErrors } + if info.RootNode != nil { + if invalidKey := findNonStringMappingKey(info.RootNode); invalidKey != nil { + return false, []*liberrors.ValidationError{buildNonStringMappingKeyError(invalidKey)} + } + } + // Use the precompiled schema if provided, otherwise compile it jsch := compiledSchema if jsch == nil { @@ -88,11 +285,29 @@ func ValidateOpenAPIDocumentWithPrecompiled(doc libopenapi.Document, compiledSch if err != nil { // Fall back to normalizeJSON if UnmarshalJSON fails if info.SpecJSON != nil { - normalized = normalizeJSON(*info.SpecJSON) + normalized, err = normalizeJSON(*info.SpecJSON) + if err != nil { + return false, []*liberrors.ValidationError{buildDocumentDecodeError( + fmt.Sprintf("The OpenAPI document cannot be converted to JSON: %s", err.Error()), + "SpecJSON", + )} + } + } else { + return false, []*liberrors.ValidationError{buildDocumentDecodeError( + fmt.Sprintf("The document's SpecJSONBytes cannot be decoded as JSON: %s", err.Error()), + "SpecJSONBytes", + )} } } } else if info.SpecJSON != nil { - normalized = normalizeJSON(*info.SpecJSON) + var err error + normalized, err = normalizeJSON(*info.SpecJSON) + if err != nil { + return false, []*liberrors.ValidationError{buildDocumentDecodeError( + fmt.Sprintf("The OpenAPI document cannot be converted to JSON: %s", err.Error()), + "SpecJSON", + )} + } } // Validate the document diff --git a/schema_validation/validate_document_test.go b/schema_validation/validate_document_test.go index 94cbd96..6f8e2f1 100644 --- a/schema_validation/validate_document_test.go +++ b/schema_validation/validate_document_test.go @@ -10,6 +10,7 @@ import ( "github.com/pb33f/libopenapi" "github.com/stretchr/testify/assert" + "go.yaml.in/yaml/v4" "github.com/pb33f/libopenapi-validator/config" liberrors "github.com/pb33f/libopenapi-validator/errors" @@ -53,6 +54,257 @@ func TestValidateDocument_Invalid31(t *testing.T) { assert.Len(t, errors[0].SchemaValidationErrors, 6) } +func TestValidateDocument_UnquotedIntegerResponseCodeHelpfulError(t *testing.T) { + spec := `openapi: 3.1.0 +info: + title: Test + version: 1.0.0 +paths: + /test: + get: + responses: + 200: + description: OK` + + doc, _ := libopenapi.NewDocument([]byte(spec)) + + valid, errors := ValidateOpenAPIDocument(doc) + + assert.False(t, valid) + assert.Len(t, errors, 1) + assert.Equal(t, "OpenAPI document validation failed", errors[0].Message) + assert.Contains(t, errors[0].Reason, "Response status code keys must be strings") + assert.Contains(t, errors[0].Reason, `quote 200 as "200"`) + assert.Contains(t, errors[0].Reason, "/paths/~1test/get/responses/200") + assert.NotContains(t, errors[0].Reason, "got null, want object") + assert.Contains(t, errors[0].HowToFix, `"200"`) + assert.Equal(t, 9, errors[0].SpecLine) + assert.Equal(t, 9, errors[0].SpecCol) + assert.Empty(t, errors[0].SchemaValidationErrors) +} + +func TestValidateDocument_QuotedResponseCodeValid(t *testing.T) { + spec := `openapi: 3.1.0 +info: + title: Test + version: 1.0.0 +paths: + /test: + get: + responses: + "200": + description: OK` + + doc, _ := libopenapi.NewDocument([]byte(spec)) + + valid, errors := ValidateOpenAPIDocument(doc) + + assert.True(t, valid) + assert.Empty(t, errors) +} + +func TestValidateDocument_YAMLMergeKeyValid(t *testing.T) { + spec := `openapi: 3.1.0 +info: + title: Test + version: 1.0.0 +x-base-response: &baseResponse + description: OK +paths: + /test: + get: + responses: + "200": + <<: *baseResponse` + + doc, _ := libopenapi.NewDocument([]byte(spec)) + + valid, errors := ValidateOpenAPIDocument(doc) + + assert.True(t, valid) + assert.Empty(t, errors) +} + +func TestValidateDocument_YAMLMergeKeyDoesNotHideInvalidResponseCode(t *testing.T) { + spec := `openapi: 3.1.0 +info: + title: Test + version: 1.0.0 +x-base-responses: &baseResponses + default: + description: Default +paths: + /test: + get: + responses: + <<: *baseResponses + 200: + description: OK` + + doc, _ := libopenapi.NewDocument([]byte(spec)) + + valid, errors := ValidateOpenAPIDocument(doc) + + assert.False(t, valid) + assert.Len(t, errors, 1) + assert.Contains(t, errors[0].Reason, "Response status code keys must be strings") + assert.Contains(t, errors[0].Reason, `quote 200 as "200"`) + assert.NotContains(t, errors[0].Reason, `merge key "<<"`) + assert.Equal(t, 13, errors[0].SpecLine) + assert.Empty(t, errors[0].SchemaValidationErrors) +} + +func TestValidateDocument_GenericNonStringMappingKeyHelpfulError(t *testing.T) { + spec := `openapi: 3.1.0 +info: + title: Test + version: 1.0.0 +paths: {} +x-values: + 1: one` + + doc, _ := libopenapi.NewDocument([]byte(spec)) + + valid, errors := ValidateOpenAPIDocument(doc) + + assert.False(t, valid) + assert.Len(t, errors, 1) + assert.Equal(t, "OpenAPI document validation failed", errors[0].Message) + assert.Contains(t, errors[0].Reason, "OpenAPI documents require string mapping keys") + assert.Contains(t, errors[0].Reason, `int key "1"`) + assert.Contains(t, errors[0].Reason, "/x-values/1") + assert.NotContains(t, errors[0].Reason, "got null, want object") + assert.Contains(t, errors[0].HowToFix, "Quote YAML mapping keys") + assert.Equal(t, 7, errors[0].SpecLine) + assert.Equal(t, 3, errors[0].SpecCol) + assert.Empty(t, errors[0].SchemaValidationErrors) +} + +func TestNormalizeJSON_ReturnsMarshalError(t *testing.T) { + payload := map[string]interface{}{ + "openapi": "3.1.0", + "invalid": map[interface{}]interface{}{ + 1: "one", + }, + } + + normalized, err := normalizeJSON(payload) + + assert.Nil(t, normalized) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unsupported type: map[interface {}]interface {}") +} + +func TestValidateDocument_NormalizationErrorDoesNotValidateNil(t *testing.T) { + spec := `openapi: 3.1.0 +info: + title: Test + version: 1.0.0 +paths: {}` + + doc, _ := libopenapi.NewDocument([]byte(spec)) + badSpecJSON := map[string]interface{}{ + "openapi": "3.1.0", + "invalid": map[interface{}]interface{}{ + 1: "one", + }, + } + info := doc.GetSpecInfo() + info.SpecJSON = &badSpecJSON + info.SpecJSONBytes = nil + + valid, errors := ValidateOpenAPIDocument(doc) + + assert.False(t, valid) + assert.Len(t, errors, 1) + assert.Equal(t, "OpenAPI document validation failed", errors[0].Message) + assert.Contains(t, errors[0].Reason, "cannot be converted to JSON") + assert.Contains(t, errors[0].Reason, "unsupported type: map[interface {}]interface {}") + assert.NotContains(t, errors[0].Reason, "got null, want object") + assert.Empty(t, errors[0].SchemaValidationErrors) +} + +func TestValidateDocument_CorruptSpecJSONBytesFallbackNormalizationError(t *testing.T) { + spec := `openapi: 3.1.0 +info: + title: Test + version: 1.0.0 +paths: {}` + + doc, _ := libopenapi.NewDocument([]byte(spec)) + badSpecJSON := map[string]interface{}{ + "openapi": "3.1.0", + "invalid": map[interface{}]interface{}{ + 1: "one", + }, + } + corrupt := []byte(`{not valid json!!!}`) + info := doc.GetSpecInfo() + info.SpecJSON = &badSpecJSON + info.SpecJSONBytes = &corrupt + + valid, errors := ValidateOpenAPIDocument(doc) + + assert.False(t, valid) + assert.Len(t, errors, 1) + assert.Contains(t, errors[0].Reason, "cannot be converted to JSON") + assert.NotContains(t, errors[0].Reason, "got null, want object") + assert.Empty(t, errors[0].SchemaValidationErrors) +} + +func TestValidateDocumentHelpers_DefensiveBranches(t *testing.T) { + assert.Nil(t, findNonStringMappingKey(nil)) + assert.Nil(t, findNonStringMappingKeyInNode(nil, nil)) + assert.Nil(t, findNonStringMappingKeyInMergeValue(nil, nil)) + assert.False(t, isStringMappingKey(nil)) + assert.False(t, isMergeMappingKey(nil)) + assert.Equal(t, "", buildJSONPointer(nil)) + assert.Equal(t, "non-string", yamlKeyType(nil)) + assert.False(t, isOperationResponseStatusCodeKey([]string{"paths", "/test", "parameters", "responses", "200"})) + + sequenceKey := &nonStringMappingKey{Sequence: true} + assert.Equal(t, "sequence", yamlKeyType(sequenceKey)) + + mergeKey := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!merge", Value: "<<"} + assert.True(t, isMergeMappingKey(mergeKey)) + + intKey := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!int", Value: "1", Line: 2, Column: 5} + value := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: "one"} + mapping := &yaml.Node{ + Kind: yaml.MappingNode, + Content: []*yaml.Node{intKey, value}, + } + sequence := &yaml.Node{ + Kind: yaml.SequenceNode, + Content: []*yaml.Node{mapping}, + } + + found := findNonStringMappingKeyInNode(sequence, []string{"items"}) + assert.NotNil(t, found) + assert.Equal(t, []string{"items", "0", "1"}, found.Path) + assert.Equal(t, "1", found.Value) + + found = findNonStringMappingKeyInMergeValue(&yaml.Node{Kind: yaml.AliasNode, Alias: mapping}, []string{"merged"}) + assert.NotNil(t, found) + assert.Equal(t, []string{"merged", "1"}, found.Path) + + mergeMapping := &yaml.Node{ + Kind: yaml.MappingNode, + Content: []*yaml.Node{mergeKey, mapping}, + } + found = findNonStringMappingKeyInNode(mergeMapping, []string{"mergeTarget"}) + assert.NotNil(t, found) + assert.Equal(t, []string{"mergeTarget", "1"}, found.Path) + + mergeSequence := &yaml.Node{Kind: yaml.SequenceNode, Content: []*yaml.Node{mapping}} + found = findNonStringMappingKeyInMergeValue(mergeSequence, []string{"mergedSequence"}) + assert.NotNil(t, found) + assert.Equal(t, []string{"mergedSequence", "1"}, found.Path) + + assert.Nil(t, findNonStringMappingKeyInMergeValue(&yaml.Node{Kind: yaml.SequenceNode}, []string{"empty"})) + assert.Nil(t, findNonStringMappingKeyInMergeValue(&yaml.Node{Kind: yaml.AliasNode}, []string{"merged"})) +} + // Helper function to test the validation logic directly func validateOpenAPIDocumentWithMalformedSchema(loadedSchema string, decodedDocument map[string]interface{}) (bool, []*liberrors.ValidationError) { options := config.NewValidationOptions() @@ -277,10 +529,12 @@ func TestValidateDocument_SpecJSONBytesCorrupt_NilSpecJSON(t *testing.T) { info.SpecJSONBytes = &corrupt info.SpecJSON = nil - // Validation should still run (against nil normalized value) and report errors + // Validation should fail before JSON Schema validation instead of validating nil. valid, errs := ValidateOpenAPIDocument(doc) assert.False(t, valid) - assert.NotEmpty(t, errs) + assert.Len(t, errs, 1) + assert.Contains(t, errs[0].Reason, "SpecJSONBytes cannot be decoded as JSON") + assert.Empty(t, errs[0].SchemaValidationErrors) } func TestValidateDocument_SpecJSONBytesCorrupt_FallbackToSpecJSON(t *testing.T) {