diff --git a/gzip.go b/gzip.go
index 957fc92..fa9e636 100644
--- a/gzip.go
+++ b/gzip.go
@@ -3,8 +3,10 @@ package gziphandler // import "github.com/NYTimes/gziphandler"
 import (
 	"bufio"
 	"compress/gzip"
+	"crypto/rand"
 	"fmt"
 	"io"
+	"math/big"
 	"mime"
 	"net"
 	"net/http"
@@ -41,6 +43,10 @@ const (
 // gzipWriterPools.
 var gzipWriterPools [gzip.BestCompression - gzip.BestSpeed + 2]*sync.Pool
 
+// ascii is the alphabet htbFileName draws from when it builds the random
+// gzip header file name used for HTB (Heal The Breach).
+const ascii = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+
 func init() {
 	for i := gzip.BestSpeed; i <= gzip.BestCompression; i++ {
 		addLevelPool(i)
@@ -81,9 +87,10 @@ type GzipResponseWriter struct {
 
 	code int // Saves the WriteHeader value.
 
-	minSize int    // Specifies the minimum response size to gzip. If the response length is bigger than this value, it is compressed.
-	buf     []byte // Holds the first part of the write before reaching the minSize or the end of the write.
-	ignore  bool   // If true, then we immediately passthru writes to the underlying ResponseWriter.
+	minSize     int                    // Specifies the minimum response size to gzip. If the response length is bigger than this value, it is compressed.
+	htbFileName func() (string, error) // Holds a function that returns a random string for HTB. Setting a function enables dependency injection.
+	buf         []byte                 // Holds the first part of the write before reaching the minSize or the end of the write.
+	ignore      bool                   // If true, then we immediately passthru writes to the underlying ResponseWriter.
 
 	contentTypes []parsedContentType // Only compress if the response is one of these content-types. All are accepted if empty.
 }
@@ -169,6 +176,16 @@ func (w *GzipResponseWriter) startGzip() error {
 	if len(w.buf) > 0 {
 		// Initialize the GZIP response.
 		w.init()
+
+		// Handle HTB. Modifying the header needs to happen before the first call to write.
+		if w.htbFileName != nil {
+			htbName, err := w.htbFileName()
+			if err != nil {
+				return fmt.Errorf("gziphandler: generating HTB file name: %w", err)
+			}
+			w.gw.Header.Name = htbName
+		}
+
 		n, err := w.gw.Write(w.buf)
 
 		// This should never happen (per io.Writer docs), but if the write didn't
@@ -182,6 +199,44 @@ func (w *GzipResponseWriter) startGzip() error {
 	return nil
 }
 
+// htbFileName returns a random string of ASCII letters whose length is drawn
+// uniformly from [1, maxSize]. It returns an empty string when maxSize is not
+// positive. Randomness comes from crypto/rand.
+func htbFileName(maxSize int) (string, error) {
+	// rand.Int panics on a non-positive bound, so reject it up front.
+	if maxSize <= 0 {
+		return "", nil
+	}
+
+	// rand.Int draws from [0, maxSize); shift by one so that the configured
+	// maximum is reachable and the name is never empty.
+	n, err := rand.Int(rand.Reader, big.NewInt(int64(maxSize)))
+	if err != nil {
+		return "", err
+	}
+	size := int(n.Int64()) + 1
+
+	// Random bytes at or above limit are discarded; mapping them with a plain
+	// modulo would make the first letters of ascii more likely than the rest.
+	limit := 256 - 256%len(ascii)
+	name := make([]byte, 0, size)
+	buf := make([]byte, size)
+	for len(name) < size {
+		// Fetch all still-missing letters with a single read instead of one
+		// big.Int draw per letter.
+		buf = buf[:size-len(name)]
+		if _, err := rand.Read(buf); err != nil {
+			return "", err
+		}
+		for _, c := range buf {
+			if int(c) < limit {
+				name = append(name, ascii[int(c)%len(ascii)])
+			}
+		}
+	}
+	return string(name), nil
+}
+
 // startPlain writes to sent bytes and buffer the underlying ResponseWriter without gzip.
 func (w *GzipResponseWriter) startPlain() error {
 	if w.code != 0 {
@@ -329,6 +384,11 @@ func GzipHandlerWithOpts(opts ...option) (func(http.Handler) http.Handler, error
 					minSize:        c.minSize,
 					contentTypes:   c.contentTypes,
 				}
+				if c.htbSize > 0 {
+					gw.htbFileName = func() (string, error) {
+						return htbFileName(c.htbSize)
+					}
+				}
 				defer gw.Close()
 
 				if _, ok := w.(http.CloseNotifier); ok {
@@ -378,6 +438,7 @@ func (pct parsedContentType) equals(mediaType string, params map[string]string)
 type config struct {
 	minSize      int
 	level        int
+	htbSize      int
 	contentTypes []parsedContentType
 }
 
@@ -407,6 +468,16 @@ func CompressionLevel(level int) option {
 	}
 }
 
+// HTBSize enables Heal The Breach (HTB): gzipped responses get a random file
+// name of at most size ASCII letters in their gzip header. A size of zero or
+// less leaves HTB disabled.
+// See https://ieeexplore.ieee.org/document/9754554
+func HTBSize(size int) option {
+	return func(c *config) {
+		c.htbSize = size
+	}
+}
+
 // ContentTypes specifies a list of content types to compare
 // the Content-Type header to before compressing. If none
 // match, the response will be returned as-is.
diff --git a/gzip_test.go b/gzip_test.go
index bed7f52..c338e87 100644
--- a/gzip_test.go
+++ b/gzip_test.go
@@ -79,6 +79,27 @@ func TestGzipHandler(t *testing.T) {
 	handler.ServeHTTP(res3, req3)
 
 	assert.Equal(t, http.DetectContentType([]byte(testBody)), res3.Header().Get("Content-Type"))
+
+	// An injected HTB generator must put its (fixed) name into the gzip header.
+	const fixedName = "12345"
+	middleware, _ := handlerWithCustomHTB(func() (string, error) {
+		return fixedName, nil
+	})
+	htbHandler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		io.WriteString(w, testBody)
+	}))
+
+	req4, _ := http.NewRequest("GET", "/whatever", nil)
+	req4.Header.Set("Accept-Encoding", "gzip")
+	resp4 := httptest.NewRecorder()
+	htbHandler.ServeHTTP(resp4, req4)
+	res4 := resp4.Result()
+
+	want := gzipStrLevelWithFileName(testBody, fixedName, gzip.DefaultCompression)
+	assert.Equal(t, 200, res4.StatusCode)
+	assert.Equal(t, "gzip", res4.Header.Get("Content-Encoding"))
+	assert.Equal(t, "Accept-Encoding", res4.Header.Get("Vary"))
+	assert.Equal(t, want, resp4.Body.Bytes())
 }
 
 func TestGzipHandlerSmallBodyNoCompression(t *testing.T) {
@@ -625,6 +646,17 @@ func gzipStrLevel(s string, lvl int) []byte {
 	return b.Bytes()
 }
 
+// gzipStrLevelWithFileName gzips s at level lvl with n stored as the file
+// name in the gzip header and returns the compressed bytes.
+func gzipStrLevelWithFileName(s, n string, lvl int) []byte {
+	out := new(bytes.Buffer)
+	zw, _ := gzip.NewWriterLevel(out, lvl)
+	zw.Header.Name = n
+	io.WriteString(zw, s)
+	zw.Close()
+	return out.Bytes()
+}
+
 func benchmark(b *testing.B, parallel bool, size int) {
 	bin, err := ioutil.ReadFile("testdata/benchmark.json")
 	if err != nil {
@@ -671,3 +703,40 @@ func newTestHandler(body string) http.Handler {
 		}
 	}))
 }
+
+// handlerWithCustomHTB builds a gzip middleware with the default compression
+// level and minimum size whose response writers use htb to generate the HTB
+// file name. The returned error is always nil.
+func handlerWithCustomHTB(htb func() (string, error)) (func(http.Handler) http.Handler, error) {
+	cfg := &config{
+		minSize: DefaultMinSize,
+		level:   gzip.DefaultCompression,
+	}
+	middleware := func(next http.Handler) http.Handler {
+		index := poolIndex(cfg.level)
+
+		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			w.Header().Add(vary, acceptEncoding)
+			if !acceptsGzip(r) {
+				next.ServeHTTP(w, r)
+				return
+			}
+
+			gw := &GzipResponseWriter{
+				ResponseWriter: w,
+				index:          index,
+				minSize:        cfg.minSize,
+				contentTypes:   cfg.contentTypes,
+				htbFileName:    htb,
+			}
+			defer gw.Close()
+
+			if _, ok := w.(http.CloseNotifier); ok {
+				next.ServeHTTP(GzipResponseWriterWithCloseNotify{gw}, r)
+				return
+			}
+			next.ServeHTTP(gw, r)
+		})
+	}
+	return middleware, nil
+}