From 017f0aa09d7fd802bd1760836e329734ea642180 Mon Sep 17 00:00:00 2001 From: Sergio VS Date: Tue, 23 Nov 2021 11:12:06 +0100 Subject: [PATCH] fix: reset request after reset user values on keep-alive connections (#1162) * fix: reset request after reset user values on keep-alive connections * test: add test for reset request after reset user values --- server.go | 6 ++++-- server_test.go | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/server.go b/server.go index d7b7f0f789..5ffc891745 100644 --- a/server.go +++ b/server.go @@ -2118,6 +2118,7 @@ func (s *Server) serveConn(c net.Conn) (err error) { br, err = acquireByteReader(&ctx) } + reqReset = false ctx.Request.isTLS = isTLS ctx.Response.Header.noDefaultContentType = s.NoDefaultContentType ctx.Response.Header.noDefaultDate = s.NoDefaultDate @@ -2302,8 +2303,6 @@ func (s *Server) serveConn(c net.Conn) (err error) { if !ctx.IsGet() && ctx.IsHead() { ctx.Response.SkipBody = true } - reqReset = true - ctx.Request.Reset() hijackHandler = ctx.hijackHandler ctx.hijackHandler = nil @@ -2404,6 +2403,9 @@ func (s *Server) serveConn(c net.Conn) (err error) { s.setState(c, StateIdle) ctx.userValues.Reset() + reqReset = true + ctx.Request.Reset() + if atomic.LoadInt32(&s.stop) == 1 { err = nil break diff --git a/server_test.go b/server_test.go index 5678022d85..b42197c985 100644 --- a/server_test.go +++ b/server_test.go @@ -25,6 +25,15 @@ import ( // Make sure RequestCtx implements context.Context var _ context.Context = &RequestCtx{} +type closerWithRequestCtx struct { + ctx *RequestCtx + closeFunc func(ctx *RequestCtx) error +} + +func (c *closerWithRequestCtx) Close() error { + return c.closeFunc(c.ctx) +} + func TestServerCRNLAfterPost_Pipeline(t *testing.T) { t.Parallel() @@ -2032,6 +2041,43 @@ func TestRequestCtxWriteString(t *testing.T) { } } +func TestServeConnKeepRequestValuesUntilResetUserValues(t *testing.T) { + t.Parallel() + + reqStr := "POST /foo HTTP/1.0\r\nHost: google.com\r\nContent-Type: application/octet-stream\r\nContent-Length: 0\r\nConnection: keep-alive\r\n\r\n" + + rw := &readWriter{} + rw.r.WriteString(reqStr) + + var resultReqStr string + + ch := make(chan struct{}) + go func() { + err := ServeConn(rw, func(ctx *RequestCtx) { + ctx.SetUserValue("myKey", &closerWithRequestCtx{ + ctx: ctx, + closeFunc: func(closerCtx *RequestCtx) error { + resultReqStr = closerCtx.Request.String() + return nil + }}) + }) + if err != nil { + t.Errorf("unexpected error in ServeConn: %s", err) + } + close(ch) + }() + + select { + case <-ch: + case <-time.After(time.Second): + t.Fatal("timeout") + } + + if resultReqStr != reqStr { + t.Errorf("Request == %s, want %s", resultReqStr, reqStr) + } +} + func TestServeConnNonHTTP11KeepAlive(t *testing.T) { t.Parallel()