diff --git a/gzip.go b/gzip.go index fbc9939..23efacc 100644 --- a/gzip.go +++ b/gzip.go @@ -117,6 +117,7 @@ func (w *GzipResponseWriter) Close() error { err := w.gw.Close() gzipWriterPools[w.index].Put(w.gw) + w.gw = nil return err } diff --git a/gzip_test.go b/gzip_test.go index f2d44e0..80f54c7 100644 --- a/gzip_test.go +++ b/gzip_test.go @@ -215,6 +215,31 @@ func TestGzipHandlerContentLength(t *testing.T) { assert.NotEqual(t, b, body) } +func TestGzipDoubleClose(t *testing.T) { + // reset the pool for the default compression so we can make sure duplicates + // aren't added back by double close + addLevelPool(gzip.DefaultCompression) + + handler := GzipHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // call close here and it'll get called again interally by + // NewGzipLevelHandler's handler defer + w.Write([]byte("test")) + w.(io.Closer).Close() + })) + + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Accept-Encoding", "gzip") + w := httptest.NewRecorder() + handler.ServeHTTP(w, r) + + // the second close shouldn't have added the same writer + // so we pull out 2 writers from the pool and make sure they're different + w1 := gzipWriterPools[poolIndex(gzip.DefaultCompression)].Get() + w2 := gzipWriterPools[poolIndex(gzip.DefaultCompression)].Get() + // assert.NotEqual looks at the value and not the address, so we use regular == + assert.False(t, w1 == w2) +} + // -------------------------------------------------------------------- func BenchmarkGzipHandler_S2k(b *testing.B) { benchmark(b, false, 2048) }