From d04c56c58547a4485708be125a440a0371d29355 Mon Sep 17 00:00:00 2001 From: Dave Shanley Date: Mon, 27 Apr 2026 07:37:31 -0400 Subject: [PATCH] Address issue #258 Make sure the body bytes are available and reset for reading. --- requests/validate_body.go | 10 +- requests/validate_body_test.go | 273 +++++++++++++++++++++++++++++++++ requests/validate_request.go | 101 ++++++++++-- 3 files changed, 368 insertions(+), 16 deletions(-) diff --git a/requests/validate_body.go b/requests/validate_body.go index 92ae344..1586e5c 100644 --- a/requests/validate_body.go +++ b/requests/validate_body.go @@ -4,10 +4,8 @@ package requests import ( - "bytes" "encoding/json" "fmt" - "io" "net/http" "strings" @@ -92,10 +90,8 @@ func (v *requestBodyValidator) ValidateRequestBodyWithPathItem(request *http.Req return true, nil } - if request != nil && request.Body != nil { - requestBody, _ := io.ReadAll(request.Body) - _ = request.Body.Close() - + if request != nil && (request.Body != nil || request.GetBody != nil) { + requestBody := readAndResetRequestBody(request) stringedBody := string(requestBody) var jsonBody any var prevalidationErrors []*errors.ValidationError @@ -121,7 +117,7 @@ func (v *requestBodyValidator) ValidateRequestBodyWithPathItem(request *http.Req } } - request.Body = io.NopCloser(bytes.NewBuffer(transformedBytes)) + setRequestBody(request, transformedBytes) } } diff --git a/requests/validate_body_test.go b/requests/validate_body_test.go index ed41101..36daa63 100644 --- a/requests/validate_body_test.go +++ b/requests/validate_body_test.go @@ -7,6 +7,7 @@ import ( "bytes" "encoding/json" "fmt" + "io" "net/http" "sync" "testing" @@ -642,6 +643,278 @@ paths: assert.Len(t, errors, 0) } +func TestValidateBody_UsesGetBodyWhenBodyAlreadyConsumed(t *testing.T) { + spec := `openapi: 3.1.0 +paths: + /burgers/createBurger: + post: + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [name, patties, vegetarian] + properties: + name: + type: string + patties: + type: integer + vegetarian: + type: boolean` + + doc, _ := libopenapi.NewDocument([]byte(spec)) + + m, _ := doc.BuildV3Model() + v := NewRequestBodyValidator(&m.Model) + + body := map[string]interface{}{ + "name": "Big Mac", + "patties": 2, + "vegetarian": true, + } + bodyBytes, _ := json.Marshal(body) + + request, _ := http.NewRequest(http.MethodPost, "https://things.com/burgers/createBurger", + bytes.NewReader(bodyBytes)) + request.Header.Set("Content-Type", "application/json") + _, _ = io.ReadAll(request.Body) + + valid, validationErrors := v.ValidateRequestBody(request) + require.True(t, valid) + require.Empty(t, validationErrors) + + restoredBody, err := io.ReadAll(request.Body) + require.NoError(t, err) + require.JSONEq(t, string(bodyBytes), string(restoredBody)) + + replayedBody, err := request.GetBody() + require.NoError(t, err) + replayedBytes, err := io.ReadAll(replayedBody) + require.NoError(t, err) + require.NoError(t, replayedBody.Close()) + require.JSONEq(t, string(bodyBytes), string(replayedBytes)) +} + +func TestValidateBody_PrefersAssignedBodyOverStaleGetBody(t *testing.T) { + spec := `openapi: 3.1.0 +paths: + /burgers/createBurger: + post: + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [name, patties, vegetarian] + properties: + name: + type: string + patties: + type: integer + vegetarian: + type: boolean` + + doc, _ := libopenapi.NewDocument([]byte(spec)) + + m, _ := doc.BuildV3Model() + v := NewRequestBodyValidator(&m.Model) + + staleBodyBytes, _ := json.Marshal(map[string]interface{}{ + "name": "Big Mac", + "patties": false, + "vegetarian": true, + }) + currentBodyBytes, _ := json.Marshal(map[string]interface{}{ + "name": "Big Mac", + "patties": 2, + "vegetarian": true, + }) + + request, _ := http.NewRequest(http.MethodPost, "https://things.com/burgers/createBurger", + bytes.NewReader(staleBodyBytes)) + request.Header.Set("Content-Type", "application/json") + request.Body = io.NopCloser(bytes.NewReader(currentBodyBytes)) + + valid, validationErrors := v.ValidateRequestBody(request) + require.True(t, valid) + require.Empty(t, validationErrors) +} + +func TestValidateBody_DoesNotUseStaleGetBodyForConsumedDifferentBodySameLength(t *testing.T) { + spec := `openapi: 3.1.0 +paths: + /burgers/createBurger: + post: + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [patties] + properties: + patties: + type: integer` + + doc, _ := libopenapi.NewDocument([]byte(spec)) + + m, _ := doc.BuildV3Model() + v := NewRequestBodyValidator(&m.Model) + + staleBodyBytes := []byte(`{"patties":12345}`) + currentBodyBytes := []byte(`{"patties":false}`) + require.Len(t, currentBodyBytes, len(staleBodyBytes)) + + request, _ := http.NewRequest(http.MethodPost, "https://things.com/burgers/createBurger", + bytes.NewReader(staleBodyBytes)) + request.Header.Set("Content-Type", "application/json") + request.Body = io.NopCloser(bytes.NewReader(currentBodyBytes)) + _, _ = io.ReadAll(request.Body) + + valid, validationErrors := v.ValidateRequestBody(request) + require.False(t, valid) + require.Len(t, validationErrors, 1) + require.Equal(t, "POST request body is empty for '/burgers/createBurger'", validationErrors[0].Message) + + replayedBody, err := request.GetBody() + require.NoError(t, err) + replayedBytes, err := io.ReadAll(replayedBody) + require.NoError(t, err) + require.NoError(t, replayedBody.Close()) + require.Empty(t, replayedBytes) +} + +func TestValidateBody_DoesNotUseStaleGetBodyForExplicitEmptyBody(t *testing.T) { + spec := `openapi: 3.1.0 +paths: + /burgers/createBurger: + post: + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [name, patties, vegetarian] + properties: + name: + type: string + patties: + type: integer + vegetarian: + type: boolean` + + doc, _ := libopenapi.NewDocument([]byte(spec)) + + m, _ := doc.BuildV3Model() + v := NewRequestBodyValidator(&m.Model) + + staleBodyBytes, _ := json.Marshal(map[string]interface{}{ + "name": "Big Mac", + "patties": 2, + "vegetarian": true, + }) + + tests := []struct { + name string + body io.ReadCloser + }{ + { + name: "http no body", + body: http.NoBody, + }, + { + name: "empty reader", + body: io.NopCloser(bytes.NewReader(nil)), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + request, _ := http.NewRequest(http.MethodPost, "https://things.com/burgers/createBurger", + bytes.NewReader(staleBodyBytes)) + request.Header.Set("Content-Type", "application/json") + request.Body = tc.body + + valid, validationErrors := v.ValidateRequestBody(request) + require.False(t, valid) + require.Len(t, validationErrors, 1) + require.Equal(t, "POST request body is empty for '/burgers/createBurger'", validationErrors[0].Message) + + replayedBody, err := request.GetBody() + require.NoError(t, err) + replayedBytes, err := io.ReadAll(replayedBody) + require.NoError(t, err) + require.NoError(t, replayedBody.Close()) + require.Empty(t, replayedBytes) + }) + } +} + +func TestRequestBodyHelpers_NilRequest(t *testing.T) { + setRequestBody(nil, []byte(`{"ok":true}`)) + require.Nil(t, readAndResetRequestBody(nil)) +} + +type requestBodyReaderTestBody struct{} + +func (r *requestBodyReaderTestBody) Read(_ []byte) (int, error) { + return 0, io.EOF +} + +func (r *requestBodyReaderTestBody) Close() error { + return nil +} + +type failingReplayableBody struct{} + +func (r *failingReplayableBody) Read(_ []byte) (int, error) { + return 0, io.EOF +} + +func (r *failingReplayableBody) Close() error { + return nil +} + +func (r *failingReplayableBody) ReadAt(_ []byte, _ int64) (int, error) { + return 0, io.ErrUnexpectedEOF +} + +func (r *failingReplayableBody) Size() int64 { + return 1 +} + +func TestRequestBodyReader_DefensiveBranches(t *testing.T) { + require.Nil(t, requestBodyReader(nil)) + require.Nil(t, requestBodyReader(http.NoBody)) + + var nilBody *requestBodyReaderTestBody + require.Nil(t, requestBodyReader(nilBody)) + + body := &requestBodyReaderTestBody{} + require.Same(t, body, requestBodyReader(body)) +} + +func TestRequestBodySnapshot_DefensiveBranches(t *testing.T) { + snapshot, ok := requestBodySnapshot(nil) + require.False(t, ok) + require.Nil(t, snapshot) + + snapshot, ok = requestBodySnapshot(&http.Request{Body: &requestBodyReaderTestBody{}}) + require.False(t, ok) + require.Nil(t, snapshot) + + snapshot, ok = requestBodySnapshot(&http.Request{Body: io.NopCloser(bytes.NewReader(nil))}) + require.False(t, ok) + require.Nil(t, snapshot) + + snapshot, ok = requestBodySnapshot(&http.Request{Body: &failingReplayableBody{}}) + require.False(t, ok) + require.Nil(t, snapshot) +} + func TestValidateBody_ValidBasicSchema_WithFullContentTypeHeader(t *testing.T) { spec := `openapi: 3.1.0 paths: diff --git a/requests/validate_request.go b/requests/validate_request.go index 251b8d1..b933788 100644 --- a/requests/validate_request.go +++ b/requests/validate_request.go @@ -39,6 +39,97 @@ type ValidateRequestSchemaInput struct { BodyRequired bool // Optional: Whether the request body is required (default false) } +type replayableBody interface { + io.ReaderAt + Size() int64 +} + +func setRequestBody(request *http.Request, body []byte) { + if request == nil { + return + } + bodyCopy := append([]byte(nil), body...) + request.Body = io.NopCloser(bytes.NewReader(bodyCopy)) + request.ContentLength = int64(len(bodyCopy)) + request.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(bodyCopy)), nil + } +} + +func requestBodySnapshot(request *http.Request) ([]byte, bool) { + if request == nil || request.Body == nil || request.Body == http.NoBody { + return nil, false + } + reader := requestBodyReader(request.Body) + body, ok := reader.(replayableBody) + if !ok { + return nil, false + } + size := body.Size() + if size <= 0 { + return nil, false + } + snapshot, err := io.ReadAll(io.NewSectionReader(body, 0, size)) + if err != nil { + return nil, false + } + return snapshot, true +} + +func requestBodyReader(body io.ReadCloser) io.Reader { + if body == nil || body == http.NoBody { + return nil + } + + value := reflect.ValueOf(body) + if value.Kind() == reflect.Ptr { + if value.IsNil() { + return nil + } + value = value.Elem() + } + if value.Kind() == reflect.Struct { + field := value.FieldByName("Reader") + if field.IsValid() && field.CanInterface() { + if reader, ok := field.Interface().(io.Reader); ok { + return reader + } + } + } + return body +} + +func readAndResetRequestBody(request *http.Request) []byte { + if request == nil { + return nil + } + + var requestBody []byte + bodyRead := false + bodySnapshot, hasBodySnapshot := requestBodySnapshot(request) + if request.Body != nil { + requestBody, _ = io.ReadAll(request.Body) + _ = request.Body.Close() + bodyRead = true + } + + if len(requestBody) == 0 && hasBodySnapshot && request.GetBody != nil { + if body, err := request.GetBody(); err == nil && body != nil { + replayedBody, _ := io.ReadAll(body) + _ = body.Close() + if bytes.Equal(replayedBody, bodySnapshot) { + requestBody = replayedBody + bodyRead = true + } + } + } + + if bodyRead { + setRequestBody(request, requestBody) + } + return requestBody +} + // ValidateRequestSchema will validate a http.Request pointer against a schema. // If validation fails, it will return a list of validation errors as the second return value. // The schema will be stored and reused from cache if available, otherwise it will be compiled on each call. @@ -146,15 +237,7 @@ func ValidateRequestSchema(input *ValidateRequestSchemaInput) (bool, []*errors.V request := input.Request schema := input.Schema - var requestBody []byte - if request != nil && request.Body != nil { - requestBody, _ = io.ReadAll(request.Body) - - // close the request body, so it can be re-read later by another player in the chain - _ = request.Body.Close() - request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) - - } + requestBody := readAndResetRequestBody(request) var decodedObj interface{}