diff --git a/handler.go b/handler.go index 7aa6278..512b317 100644 --- a/handler.go +++ b/handler.go @@ -30,6 +30,31 @@ func ChainHandlers(handlers ...ResponseHandler) ResponseHandler { } } +// KeepRespBodyHandlers Combine multiple ResponseHandler and ensure that each processor has access to the original response body +func KeepRespBodyHandlers(handlers ...ResponseHandler) ResponseHandler { + return func(r *http.Response) error { + for _, h := range handlers { + if h == nil { + continue + } + + var dup io.ReadCloser + r.Body, dup = dupReadCloser(r.Body) + if err := h(r); err != nil { + return err + } + r.Body = dup + } + return nil + } +} + +func dupReadCloser(reader io.ReadCloser) (io.ReadCloser, io.ReadCloser) { + var buf bytes.Buffer + tee := io.TeeReader(reader, &buf) + return io.NopCloser(tee), io.NopCloser(&buf) +} + func consumeBody(res *http.Response) (err error) { const maxDiscardSize = 640 * 1 << 10 if _, err = io.CopyN(io.Discard, res.Body, maxDiscardSize); err == io.EOF { diff --git a/handler_test.go b/handler_test.go index 7e74c16..adcf952 100644 --- a/handler_test.go +++ b/handler_test.go @@ -1,8 +1,10 @@ package requests_test import ( + "bytes" "context" "fmt" + "io" "net/http" "os" "path/filepath" @@ -32,3 +34,37 @@ func BenchmarkBuilder_ToFile(b *testing.B) { be.NilErr(b, err) } } + +// TestKeepRespBodyHandlers tests the KeepRespBodyHandlers function. +func TestKeepRespBodyHandlers(t *testing.T) { + type Common struct { + ID int `json:"id"` + } + + type Book struct { + Common + Name string `json:"name"` + } + + var ( + book Book + common Common + str string + ) + + handler := requests.KeepRespBodyHandlers( + requests.ToJSON(&common), + requests.ToJSON(&book), + requests.ToString(&str), + ) + + err := handler(&http.Response{ + Body: io.NopCloser(bytes.NewReader([]byte(`{"id":1, "name":"孙子兵法"}`))), + }) + + be.NilErr(t, err) + be.Equal(t, 1, common.ID) + be.Equal(t, 1, book.ID) + be.Equal(t, "孙子兵法", book.Name) + be.Equal(t, `{"id":1, "name":"孙子兵法"}`, str) +}