Skip to content

Commit

Permalink
Merge pull request #3456 from Zagan202/3405-extend-LoggingMiddleware
Browse files Browse the repository at this point in the history
go/support/http 3405 extend logging middleware
  • Loading branch information
Zagan202 authored Mar 10, 2021
2 parents 37db31b + 5f72b89 commit c908d78
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 22 deletions.
65 changes: 43 additions & 22 deletions support/http/logging_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package http

import (
stdhttp "net/http"
"strings"
"time"

"github.com/go-chi/chi"
Expand All @@ -10,6 +11,11 @@ import (
"github.com/stellar/go/support/log"
)

// Options allow the middleware logger to accept additional information.
type Options struct {
ExtraHeaders []string
}

// SetLogger is a middleware that sets a logger on the context.
func SetLoggerMiddleware(l *log.Entry) func(stdhttp.Handler) stdhttp.Handler {
return func(next stdhttp.Handler) stdhttp.Handler {
Expand All @@ -24,39 +30,54 @@ func SetLoggerMiddleware(l *log.Entry) func(stdhttp.Handler) stdhttp.Handler {

// LoggingMiddleware is a middleware that logs requests to the logger.
func LoggingMiddleware(next stdhttp.Handler) stdhttp.Handler {
return stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) {
mw := mutil.WrapWriter(w)
ctx := log.PushContext(r.Context(), func(l *log.Entry) *log.Entry {
return l.WithFields(log.F{
"req": middleware.GetReqID(r.Context()),
})
})
return LoggingMiddlewareWithOptions(Options{})(next)
}

r = r.WithContext(ctx)
// LoggingMiddlewareWithOptions is a middleware that logs requests to the logger.
// Requires an Options struct to accept additional information.
func LoggingMiddlewareWithOptions(options Options) func(stdhttp.Handler) stdhttp.Handler {
return func(next stdhttp.Handler) stdhttp.Handler {
return stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) {
mw := mutil.WrapWriter(w)
ctx := log.PushContext(r.Context(), func(l *log.Entry) *log.Entry {
return l.WithFields(log.F{
"req": middleware.GetReqID(r.Context()),
})
})
r = r.WithContext(ctx)

logStartOfRequest(r)
logStartOfRequest(r, options.ExtraHeaders)

then := time.Now()
next.ServeHTTP(mw, r)
duration := time.Since(then)
then := time.Now()
next.ServeHTTP(mw, r)
duration := time.Since(then)

logEndOfRequest(r, duration, mw)
})
logEndOfRequest(r, duration, mw)
})
}
}

// logStartOfRequest emits the logline that reports that an http request is
// beginning processing.
func logStartOfRequest(
r *stdhttp.Request,
extraHeaders []string,
) {
l := log.Ctx(r.Context()).WithFields(log.F{
"subsys": "http",
"path": r.URL.String(),
"method": r.Method,
"ip": r.RemoteAddr,
"host": r.Host,
"useragent": r.Header.Get("User-Agent"),
})
fields := log.F{}
for _, header := range extraHeaders {
// Strips "-" characters and lowercases new logrus.Fields keys to be uniform with the other keys in the logger.
// Simplifies querying extended fields.
var headerkey = strings.ToLower(strings.ReplaceAll(header, "-", ""))
fields[headerkey] = r.Header.Get(header)
}
fields["subsys"] = "http"
fields["path"] = r.URL.String()
fields["method"] = r.Method
fields["ip"] = r.RemoteAddr
fields["host"] = r.Host
fields["useragent"] = r.Header.Get("User-Agent")
l := log.Ctx(r.Context()).WithFields(fields)

l.Info("starting request")
}

Expand Down
103 changes: 103 additions & 0 deletions support/http/logging_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,22 @@ import (
"github.com/stretchr/testify/assert"
)

// setXFFMiddleware sets "X-Forwarded-For" header to test LoggingMiddlewareWithOptions.
func setXFFMiddleware(next stdhttp.Handler) stdhttp.Handler {
return stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) {
r.Header.Set("X-Forwarded-For", "203.0.113.195")
next.ServeHTTP(w, r)
})
}

// setContentMD5MiddleWare sets header to test LoggingMiddlewareWithOptions.
func setContentMD5Middleware(next stdhttp.Handler) stdhttp.Handler {
return stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) {
r.Header.Set("Content-MD5", "U3RlbGxhciBpcyBBd2Vzb21lIQ==")
next.ServeHTTP(w, r)
})
}

func TestHTTPMiddleware(t *testing.T) {
done := log.DefaultLogger.StartTest(log.InfoLevel)
mux := chi.NewMux()
Expand Down Expand Up @@ -82,6 +98,93 @@ func TestHTTPMiddleware(t *testing.T) {
}
}

func TestHTTPMiddlewareWithOptions(t *testing.T) {
done := log.DefaultLogger.StartTest(log.InfoLevel)
mux := chi.NewMux()

mux.Use(setXFFMiddleware)
mux.Use(setContentMD5Middleware)
mux.Use(middleware.RequestID)
options := Options{ExtraHeaders: []string{"X-Forwarded-For", "Content-MD5"}}
mux.Use(LoggingMiddlewareWithOptions(options))

mux.Get("/path/{value}", stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) {
log.Ctx(r.Context()).Info("handler log line")
}))
mux.Handle("/not_found", stdhttp.NotFoundHandler())

src := httptest.NewServer(t, mux)
src.GET("/path/1234").Expect().Status(stdhttp.StatusOK)
src.GET("/not_found").Expect().Status(stdhttp.StatusNotFound)
src.GET("/really_not_found").Expect().Status(stdhttp.StatusNotFound)

// get the log buffer and ensure it has both the start and end log lines for
// each request
logged := done()
if assert.Len(t, logged, 7, "unexpected log line count") {
assert.Equal(t, "starting request", logged[0].Message)
assert.Equal(t, "http", logged[0].Data["subsys"])
assert.Equal(t, "GET", logged[0].Data["method"])
assert.NotEmpty(t, logged[0].Data["req"])
assert.Equal(t, "/path/1234", logged[0].Data["path"])
assert.Equal(t, "Go-http-client/1.1", logged[0].Data["useragent"])
assert.Equal(t, "203.0.113.195", logged[0].Data["xforwardedfor"])
assert.Equal(t, "U3RlbGxhciBpcyBBd2Vzb21lIQ==", logged[0].Data["contentmd5"])
assert.Equal(t, 10, len(logged[0].Data))
req1 := logged[0].Data["req"]

assert.Equal(t, "handler log line", logged[1].Message)
assert.Equal(t, req1, logged[1].Data["req"])
assert.Equal(t, 2, len(logged[1].Data))

assert.Equal(t, "finished request", logged[2].Message)
assert.Equal(t, "http", logged[2].Data["subsys"])
assert.Equal(t, "GET", logged[2].Data["method"])
assert.Equal(t, req1, logged[2].Data["req"])
assert.Equal(t, "/path/1234", logged[2].Data["path"])
assert.Equal(t, "/path/{value}", logged[2].Data["route"])
assert.Equal(t, 9, len(logged[2].Data))

assert.Equal(t, "starting request", logged[3].Message)
assert.Equal(t, "http", logged[3].Data["subsys"])
assert.Equal(t, "GET", logged[3].Data["method"])
assert.NotEmpty(t, logged[3].Data["req"])
assert.NotEmpty(t, logged[3].Data["path"])
assert.Equal(t, "Go-http-client/1.1", logged[3].Data["useragent"])
assert.Equal(t, "203.0.113.195", logged[3].Data["xforwardedfor"])
assert.Equal(t, "U3RlbGxhciBpcyBBd2Vzb21lIQ==", logged[3].Data["contentmd5"])
assert.Equal(t, 10, len(logged[3].Data))
req2 := logged[3].Data["req"]

assert.Equal(t, "finished request", logged[4].Message)
assert.Equal(t, "http", logged[4].Data["subsys"])
assert.Equal(t, "GET", logged[4].Data["method"])
assert.Equal(t, req2, logged[4].Data["req"])
assert.Equal(t, "/not_found", logged[4].Data["path"])
assert.Equal(t, "/not_found", logged[4].Data["route"])
assert.Equal(t, 9, len(logged[4].Data))

assert.Equal(t, "starting request", logged[5].Message)
assert.Equal(t, "http", logged[5].Data["subsys"])
assert.Equal(t, "GET", logged[5].Data["method"])
assert.NotEmpty(t, logged[5].Data["req"])
assert.NotEmpty(t, logged[5].Data["path"])
assert.Equal(t, "Go-http-client/1.1", logged[5].Data["useragent"])
assert.Equal(t, "203.0.113.195", logged[5].Data["xforwardedfor"])
assert.Equal(t, "U3RlbGxhciBpcyBBd2Vzb21lIQ==", logged[5].Data["contentmd5"])
assert.Equal(t, 10, len(logged[5].Data))
req3 := logged[5].Data["req"]

assert.Equal(t, "finished request", logged[6].Message)
assert.Equal(t, "http", logged[6].Data["subsys"])
assert.Equal(t, "GET", logged[6].Data["method"])
assert.Equal(t, req3, logged[6].Data["req"])
assert.Equal(t, "/really_not_found", logged[6].Data["path"])
assert.Equal(t, "", logged[6].Data["route"])
assert.Equal(t, 9, len(logged[6].Data))
}
}

func TestHTTPMiddleware_stdlibServeMux(t *testing.T) {
done := log.DefaultLogger.StartTest(log.InfoLevel)

Expand Down

0 comments on commit c908d78

Please sign in to comment.