Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: optimize body buffering. #505

Merged
merged 4 commits into from
Nov 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 120 additions & 52 deletions http/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package http

import (
"fmt"
"io"
"net/http"

Expand All @@ -17,41 +18,108 @@ import (
// rwInterceptor intercepts the ResponseWriter, so it can track response size
// and returned status code.
type rwInterceptor struct {
w http.ResponseWriter
tx types.Transaction
statusCode int
w http.ResponseWriter
tx types.Transaction
statusCode int
proto string
hasStatusCode bool
}

func (i *rwInterceptor) WriteHeader(statusCode int) {
if i.hasStatusCode {
return
}

for k, vv := range i.w.Header() {
for _, v := range vv {
i.tx.AddResponseHeader(k, v)
}
}

i.hasStatusCode = true
i.statusCode = statusCode
if it := i.tx.ProcessResponseHeaders(statusCode, i.proto); it != nil {
i.statusCode = obtainStatusCodeFromInterruptionOrDefault(it, i.statusCode)
}
}

func (i *rwInterceptor) Write(b []byte) (int, error) {
return i.tx.ResponseBodyWriter().Write(b)
}
if !i.hasStatusCode {
i.WriteHeader(http.StatusOK)
}

func (i *rwInterceptor) Header() http.Header {
return i.w.Header()
}
if i.tx.Interrupted() {
// if there is an interruption it must be from phase 4 and hence
// we won't write anything to either the body or the buffer.
return 0, nil
}

func (i *rwInterceptor) StatusCode() int {
return i.statusCode
if i.tx.ResponseBodyAccessible() {
// we only buffer the response body if we are going to access
// to it, otherwise we just send it to the response writer.
return i.tx.ResponseBodyWriter().Write(b)
}

return i.w.Write(b)
}

// ResponseWriter adds Proto to http.ResponseWriter.
type ResponseWriterStatusCodeGetter interface {
http.ResponseWriter
StatusCode() int
func (i *rwInterceptor) Header() http.Header {
return i.w.Header()
}

var _ ResponseWriterStatusCodeGetter = (*rwInterceptor)(nil)
var _ http.ResponseWriter = (*rwInterceptor)(nil)

// wrap wraps the interceptor into a response writer that also preserves
// the http interfaces implemented by the original response writer to avoid
// the observer effect.
// the observer effect. It also returns the response processor which takes care
// of the response body copyback from the transaction buffer.
//
// Heavily inspired in https://github.com/openzipkin/zipkin-go/blob/master/middleware/http/server.go#L218
func wrap(w http.ResponseWriter, tx types.Transaction) ResponseWriterStatusCodeGetter { // nolint:gocyclo
i := &rwInterceptor{w: w, tx: tx}
func wrap(w http.ResponseWriter, r *http.Request, tx types.Transaction) (
http.ResponseWriter,
func(types.Transaction, *http.Request) error,
) { // nolint:gocyclo
i := &rwInterceptor{w: w, tx: tx, proto: r.Proto}

responseProcessor := func(tx types.Transaction, r *http.Request) error {
// We look for interruptions determined at phase 4 (response headers)
// as body hasn't being analized yet.
if tx.Interrupted() {
// phase 4 interruption stops execution
w.WriteHeader(i.statusCode)
return nil
}

if tx.ResponseBodyAccessible() {
if it, err := tx.ProcessResponseBody(); err != nil {
w.WriteHeader(http.StatusInternalServerError)
return err
} else if it != nil {
w.WriteHeader(obtainStatusCodeFromInterruptionOrDefault(it, i.statusCode))
return nil
}

// we release the buffer
reader, err := tx.ResponseBodyReader()
if err != nil {
i.w.WriteHeader(http.StatusInternalServerError)
return fmt.Errorf("failed to release the response body reader: %v", err)
}

// this is the last opportunity we have to report the resolved status code
// as next step is write into the response writer (triggering a 200 in the
// response status code.)
i.w.WriteHeader(i.statusCode)
if _, err := io.Copy(w, reader); err != nil {
i.w.WriteHeader(http.StatusInternalServerError)
return fmt.Errorf("failed to copy the response body: %v", err)
}
} else {
i.w.WriteHeader(i.statusCode)
}

return nil
}

var (
hijacker, isHijacker = i.w.(http.Hijacker)
Expand All @@ -63,103 +131,103 @@ func wrap(w http.ResponseWriter, tx types.Transaction) ResponseWriterStatusCodeG
switch {
case !isHijacker && !isPusher && !isFlusher && !isReader:
return struct {
ResponseWriterStatusCodeGetter
}{i}
http.ResponseWriter
}{i}, responseProcessor
case !isHijacker && !isPusher && !isFlusher && isReader:
return struct {
ResponseWriterStatusCodeGetter
http.ResponseWriter
io.ReaderFrom
}{i, reader}
}{i, reader}, responseProcessor
case !isHijacker && !isPusher && isFlusher && !isReader:
return struct {
ResponseWriterStatusCodeGetter
http.ResponseWriter
http.Flusher
}{i, flusher}
}{i, flusher}, responseProcessor
case !isHijacker && !isPusher && isFlusher && isReader:
return struct {
ResponseWriterStatusCodeGetter
http.ResponseWriter
http.Flusher
io.ReaderFrom
}{i, flusher, reader}
}{i, flusher, reader}, responseProcessor
case !isHijacker && isPusher && !isFlusher && !isReader:
return struct {
ResponseWriterStatusCodeGetter
http.ResponseWriter
http.Pusher
}{i, pusher}
}{i, pusher}, responseProcessor
case !isHijacker && isPusher && !isFlusher && isReader:
return struct {
ResponseWriterStatusCodeGetter
http.ResponseWriter
http.Pusher
io.ReaderFrom
}{i, pusher, reader}
}{i, pusher, reader}, responseProcessor
case !isHijacker && isPusher && isFlusher && !isReader:
return struct {
ResponseWriterStatusCodeGetter
http.ResponseWriter
http.Pusher
http.Flusher
}{i, pusher, flusher}
}{i, pusher, flusher}, responseProcessor
case !isHijacker && isPusher && isFlusher && isReader:
return struct {
ResponseWriterStatusCodeGetter
http.ResponseWriter
http.Pusher
http.Flusher
io.ReaderFrom
}{i, pusher, flusher, reader}
}{i, pusher, flusher, reader}, responseProcessor
case isHijacker && !isPusher && !isFlusher && !isReader:
return struct {
ResponseWriterStatusCodeGetter
http.ResponseWriter
http.Hijacker
}{i, hijacker}
}{i, hijacker}, responseProcessor
case isHijacker && !isPusher && !isFlusher && isReader:
return struct {
ResponseWriterStatusCodeGetter
http.ResponseWriter
http.Hijacker
io.ReaderFrom
}{i, hijacker, reader}
}{i, hijacker, reader}, responseProcessor
case isHijacker && !isPusher && isFlusher && !isReader:
return struct {
ResponseWriterStatusCodeGetter
http.ResponseWriter
http.Hijacker
http.Flusher
}{i, hijacker, flusher}
}{i, hijacker, flusher}, responseProcessor
case isHijacker && !isPusher && isFlusher && isReader:
return struct {
ResponseWriterStatusCodeGetter
http.ResponseWriter
http.Hijacker
http.Flusher
io.ReaderFrom
}{i, hijacker, flusher, reader}
}{i, hijacker, flusher, reader}, responseProcessor
case isHijacker && isPusher && !isFlusher && !isReader:
return struct {
ResponseWriterStatusCodeGetter
http.ResponseWriter
http.Hijacker
http.Pusher
}{i, hijacker, pusher}
}{i, hijacker, pusher}, responseProcessor
case isHijacker && isPusher && !isFlusher && isReader:
return struct {
ResponseWriterStatusCodeGetter
http.ResponseWriter
http.Hijacker
http.Pusher
io.ReaderFrom
}{i, hijacker, pusher, reader}
}{i, hijacker, pusher, reader}, responseProcessor
case isHijacker && isPusher && isFlusher && !isReader:
return struct {
ResponseWriterStatusCodeGetter
http.ResponseWriter
http.Hijacker
http.Pusher
http.Flusher
}{i, hijacker, pusher, flusher}
}{i, hijacker, pusher, flusher}, responseProcessor
case isHijacker && isPusher && isFlusher && isReader:
return struct {
ResponseWriterStatusCodeGetter
http.ResponseWriter
http.Hijacker
http.Pusher
http.Flusher
io.ReaderFrom
}{i, hijacker, pusher, flusher, reader}
}{i, hijacker, pusher, flusher, reader}, responseProcessor
default:
return struct {
ResponseWriterStatusCodeGetter
}{i}
http.ResponseWriter
}{i}, responseProcessor
}
}
46 changes: 46 additions & 0 deletions http/interceptor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright 2022 Juan Pablo Tosso and the OWASP Coraza contributors
// SPDX-License-Identifier: Apache-2.0

// tinygo does not support net.http so this package is not needed for it
//go:build !tinygo
// +build !tinygo

package http

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/corazawaf/coraza/v3"
)

func TestWriteHeader(t *testing.T) {
waf, err := coraza.NewWAF(coraza.NewWAFConfig())
if err != nil {
t.Fatal(err)
}

tx := waf.NewTransaction()
req, _ := http.NewRequest("GET", "", nil)
res := httptest.NewRecorder()
rw, responseProcessor := wrap(res, req, tx)
rw.WriteHeader(204)
rw.WriteHeader(205)
// although we called WriteHeader, status code should be applied until
// responseProcessor is called.
if unwanted, have := 204, res.Code; unwanted == have {
t.Errorf("unexpected status code %d", have)
}

err = responseProcessor(tx, req)
if err != nil {
t.Errorf("unexpected error: %v", err)
}

// although we called a second time with 205, status code should remain the first
// value.
if want, have := 204, res.Code; want != have {
t.Errorf("unexpected status code, want %d, have %d", want, have)
}
}
Loading