From b71aaf63e024c77f55b91ae7eee78bd0dd73f857 Mon Sep 17 00:00:00 2001 From: Vee Zhang Date: Fri, 13 May 2022 18:12:59 +0800 Subject: [PATCH] middleware: add ReserveRequest and ReserveResponseWriter --- middleware/keep_request_responce.go | 150 ++++++++++++++++++ middleware/keep_request_responce_test.go | 185 +++++++++++++++++++++++ middleware/middleware.go | 9 ++ middleware/middleware_test.go | 11 ++ 4 files changed, 355 insertions(+) create mode 100644 middleware/keep_request_responce.go create mode 100644 middleware/keep_request_responce_test.go create mode 100644 middleware/middleware.go create mode 100644 middleware/middleware_test.go diff --git a/middleware/keep_request_responce.go b/middleware/keep_request_responce.go new file mode 100644 index 0000000..c345cd0 --- /dev/null +++ b/middleware/keep_request_responce.go @@ -0,0 +1,150 @@ +package middleware + +import ( + "bytes" + "context" + "io" + "net/http" + "sync" +) + +type ( + ReserveRequestConfig struct { + Skipper Skipper + } + + ReserveResponseWriterConfig struct { + Skipper Skipper + } + + reservedRequestCtxKey struct{} + reservedResponseWriterCtxKey struct{} + + reservedRequest struct { + r *http.Request + rawBody io.ReadCloser + reservedBody *reservedRequestBody + } + + reservedResponseWriter struct { + w http.ResponseWriter + } + + reservedRequestBody struct { + body io.ReadCloser + teeReader io.Reader + buff bytes.Buffer + once sync.Once + isRead bool + } +) + +func ReserveRequest(config ReserveRequestConfig) func(next http.Handler) http.Handler { + if config.Skipper == nil { + config.Skipper = DefaultSkipper + } + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if config.Skipper(r) { + next.ServeHTTP(w, r) + return + } + + reservedReq := newReservedRequest(r) + reservedReq.next(next, w) + }) + } +} + +func GetRequest(ctx context.Context) (*http.Request, bool) { + v := ctx.Value(reservedRequestCtxKey{}) + if v == nil { + return nil, false + } + return v.(*reservedRequest).getRequest(), true +} + +func ReserveResponseWriter(config ReserveResponseWriterConfig) func(next http.Handler) http.Handler { + if config.Skipper == nil { + config.Skipper = DefaultSkipper + } + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if config.Skipper(r) { + next.ServeHTTP(w, r) + return + } + + reservedResp := newReservedResponseWriter(w) + reservedResp.next(next, r) + }) + } +} + +func GetResponseWriter(ctx context.Context) (http.ResponseWriter, bool) { + v := ctx.Value(reservedResponseWriterCtxKey{}) + if v == nil { + return nil, false + } + return v.(*reservedResponseWriter).getResponseWriter(), true +} + +func newReservedRequest(r *http.Request) *reservedRequest { + return &reservedRequest{ + r: r, + rawBody: r.Body, + reservedBody: newReservedRequestBody(r.Body), + } +} + +func newReservedResponseWriter(w http.ResponseWriter) *reservedResponseWriter { + return &reservedResponseWriter{ + w: w, + } +} + +func newReservedRequestBody(body io.ReadCloser) *reservedRequestBody { + rc := &reservedRequestBody{ + body: body, + } + rc.teeReader = io.TeeReader(body, &rc.buff) + return rc +} + +func (rr *reservedRequest) next(next http.Handler, w http.ResponseWriter) { + r := rr.r + r.Body = rr.reservedBody + r = r.WithContext(context.WithValue(r.Context(), reservedRequestCtxKey{}, rr)) + next.ServeHTTP(w, r) +} + +func (rr *reservedRequest) getRequest() *http.Request { + r := rr.r + if rr.reservedBody.isRead { + r = r.Clone(r.Context()) + r.Body = io.NopCloser(&rr.reservedBody.buff) + } else { + r.Body = rr.rawBody + } + return r +} + +func (rr *reservedResponseWriter) next(next http.Handler, r *http.Request) { + r = r.WithContext(context.WithValue(r.Context(), reservedResponseWriterCtxKey{}, rr)) + next.ServeHTTP(rr.w, r) +} + +func (rr *reservedResponseWriter) getResponseWriter() http.ResponseWriter { + return rr.w +} + +func (rrb *reservedRequestBody) Read(p []byte) (n int, err error) { + rrb.once.Do(func() { + rrb.isRead = true + }) + return rrb.teeReader.Read(p) +} + +func (rrb *reservedRequestBody) Close() error { + return rrb.body.Close() +} diff --git a/middleware/keep_request_responce_test.go b/middleware/keep_request_responce_test.go new file mode 100644 index 0000000..9288c50 --- /dev/null +++ b/middleware/keep_request_responce_test.go @@ -0,0 +1,185 @@ +package middleware + +import ( + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestReserveRequest(t *testing.T) { + tests := []struct { + name string + method string + withMiddleware bool + config ReserveRequestConfig + readTwice bool + }{{ + name: "get:middleware:false", + method: http.MethodGet, + withMiddleware: false, + }, { + name: "get:middleware:true", + method: http.MethodGet, + withMiddleware: true, + }, { + name: "get:middleware:true:skipper", + method: http.MethodGet, + withMiddleware: true, + config: ReserveRequestConfig{ + Skipper: func(*http.Request) bool { + return true + }, + }, + }, { + name: "get:middleware:true:readTwice", + method: http.MethodGet, + withMiddleware: true, + readTwice: true, + }, { + name: "post:middleware:false", + method: http.MethodPost, + withMiddleware: false, + }, { + name: "post:middleware:true", + method: http.MethodPost, + withMiddleware: true, + }, { + name: "post:middleware:true", + method: http.MethodPost, + withMiddleware: true, + }, { + name: "post:middleware:true:skipper", + method: http.MethodPost, + withMiddleware: true, + config: ReserveRequestConfig{ + Skipper: func(*http.Request) bool { + return true + }, + }, + }, { + name: "post:middleware:true:readTwice", + method: http.MethodPost, + withMiddleware: true, + readTwice: true, + }} + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + bodyString := "test body" + + var req *http.Request + if test.method == http.MethodPost { + req = httptest.NewRequest(http.MethodPost, "http://localhost", strings.NewReader(bodyString)) + } else { + req = httptest.NewRequest(http.MethodGet, "http://localhost", nil) + } + + rec := httptest.NewRecorder() + var h http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + assert.NoError(t, r.Body.Close()) + }() + + checkBody := func(readCloser io.ReadCloser) { + bodyBytes, err := ioutil.ReadAll(readCloser) + assert.NoError(t, err) + if test.method == http.MethodPost { + assert.Equal(t, bodyString, string(bodyBytes)) + } else { + assert.Equal(t, "", string(bodyBytes)) + } + } + + if test.readTwice { + checkBody(r.Body) + } + + httpReq, ok := GetRequest(r.Context()) + if test.withMiddleware && (test.config.Skipper == nil || !test.config.Skipper(r)) { + assert.True(t, ok) + assert.NotNil(t, httpReq) + checkBody(httpReq.Body) + } else { + assert.False(t, ok) + assert.Nil(t, httpReq) + } + }) + if test.withMiddleware { + m := ReserveRequest(test.config) + h = m(h) + } + h.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code) + }) + } +} + +func TestReserveResponseWriter(t *testing.T) { + tests := []struct { + name string + method string + withMiddleware bool + config ReserveResponseWriterConfig + }{{ + name: "middleware:false", + method: http.MethodGet, + withMiddleware: false, + }, { + name: "middleware:true", + method: http.MethodGet, + withMiddleware: true, + }, { + name: "middleware:true:skipper", + method: http.MethodGet, + withMiddleware: true, + config: ReserveResponseWriterConfig{ + Skipper: func(*http.Request) bool { + return true + }, + }, + }, { + name: "middleware:true:readTwice", + method: http.MethodGet, + withMiddleware: true, + }} + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + bodyString := "test body" + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + + rec := httptest.NewRecorder() + var h http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + assert.NoError(t, r.Body.Close()) + }() + + httpResp, ok := GetResponseWriter(r.Context()) + if test.withMiddleware && (test.config.Skipper == nil || !test.config.Skipper(r)) { + assert.True(t, ok) + assert.NotNil(t, httpResp) + httpResp.Write([]byte(bodyString)) + } else { + assert.False(t, ok) + assert.Nil(t, httpResp) + } + }) + if test.withMiddleware { + m := ReserveResponseWriter(test.config) + h = m(h) + } + h.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code) + if test.withMiddleware && (test.config.Skipper == nil || !test.config.Skipper(req)) { + assert.Equal(t, bodyString, rec.Body.String()) + } else { + assert.Equal(t, "", rec.Body.String()) + } + }) + } +} diff --git a/middleware/middleware.go b/middleware/middleware.go new file mode 100644 index 0000000..25da84b --- /dev/null +++ b/middleware/middleware.go @@ -0,0 +1,9 @@ +package middleware + +import "net/http" + +type Skipper func(*http.Request) bool + +func DefaultSkipper(*http.Request) bool { + return false +} diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go new file mode 100644 index 0000000..6234397 --- /dev/null +++ b/middleware/middleware_test.go @@ -0,0 +1,11 @@ +package middleware + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDefaultSkipper(t *testing.T) { + assert.False(t, DefaultSkipper(nil)) +}