Skip to content

Commit

Permalink
Add additional hlog logging handlers (rs#607)
Browse files Browse the repository at this point in the history
* Add HTTPVersionHandler.

* Add RemoteIPHandler.

* Add trimPort to HostHandler.

* Add EtagHandler.

* Add ResponseHeaderHandler.

* Add TestGetHost.

* Call AccessHandler's f also on panic.
  • Loading branch information
mitar authored Nov 8, 2023
1 parent e7034c2 commit bb14b8b
Show file tree
Hide file tree
Showing 2 changed files with 229 additions and 8 deletions.
108 changes: 101 additions & 7 deletions hlog/hlog.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package hlog

import (
"context"
"net"
"net/http"
"strings"
"time"

"github.com/rs/xid"
Expand Down Expand Up @@ -89,6 +91,35 @@ func RemoteAddrHandler(fieldKey string) func(next http.Handler) http.Handler {
}
}

func getHost(hostPort string) string {
if hostPort == "" {
return ""
}

host, _, err := net.SplitHostPort(hostPort)
if err != nil {
return hostPort
}
return host
}

// RemoteIPHandler is similar to RemoteAddrHandler, but logs only
// an IP, not a port.
func RemoteIPHandler(fieldKey string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ip := getHost(r.RemoteAddr)
if ip != "" {
log := zerolog.Ctx(r.Context())
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str(fieldKey, ip)
})
}
next.ServeHTTP(w, r)
})
}
}

// UserAgentHandler adds the request's user-agent as a field to the context's logger
// using fieldKey as field key.
func UserAgentHandler(fieldKey string) func(next http.Handler) http.Handler {
Expand Down Expand Up @@ -135,6 +166,21 @@ func ProtoHandler(fieldKey string) func(next http.Handler) http.Handler {
}
}

// HTTPVersionHandler is similar to ProtoHandler, but it does not store the "HTTP/"
// prefix in the protocol name.
func HTTPVersionHandler(fieldKey string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
proto := strings.TrimPrefix(r.Proto, "HTTP/")
log := zerolog.Ctx(r.Context())
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str(fieldKey, proto)
})
next.ServeHTTP(w, r)
})
}
}

type idKey struct{}

// IDFromRequest returns the unique id associated to the request if any.
Expand Down Expand Up @@ -205,27 +251,75 @@ func CustomHeaderHandler(fieldKey, header string) func(next http.Handler) http.H
}
}

// EtagHandler adds Etag header from response's header as a field to
// the context's logger using fieldKey as field key.
func EtagHandler(fieldKey string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
etag := w.Header().Get("Etag")
if etag != "" {
etag = strings.ReplaceAll(etag, `"`, "")
log := zerolog.Ctx(r.Context())
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str(fieldKey, etag)
})
}
}()
next.ServeHTTP(w, r)
})
}
}

func ResponseHeaderHandler(fieldKey, headerName string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
value := w.Header().Get(headerName)
if value != "" {
log := zerolog.Ctx(r.Context())
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str(fieldKey, value)
})
}
}()
next.ServeHTTP(w, r)
})
}
}

// AccessHandler returns a handler that call f after each request.
func AccessHandler(f func(r *http.Request, status, size int, duration time.Duration)) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
lw := mutil.WrapWriter(w)
defer func() {
f(r, lw.Status(), lw.BytesWritten(), time.Since(start))
}()
next.ServeHTTP(lw, r)
f(r, lw.Status(), lw.BytesWritten(), time.Since(start))
})
}
}

// HostHandler adds the request's host as a field to the context's logger
// using fieldKey as field key.
func HostHandler(fieldKey string) func(next http.Handler) http.Handler {
// using fieldKey as field key. If trimPort is set to true, then port is
// removed from the host.
func HostHandler(fieldKey string, trimPort ...bool) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log := zerolog.Ctx(r.Context())
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str(fieldKey, r.Host)
})
var host string
if len(trimPort) > 0 && trimPort[0] {
host = getHost(r.Host)
} else {
host = r.Host
}
if host != "" {
log := zerolog.Ctx(r.Context())
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
return c.Str(fieldKey, host)
})
}
next.ServeHTTP(w, r)
})
}
Expand Down
129 changes: 128 additions & 1 deletion hlog/hlog_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,38 @@ func TestRemoteAddrHandlerIPv6(t *testing.T) {
}
}

func TestRemoteIPHandler(t *testing.T) {
out := &bytes.Buffer{}
r := &http.Request{
RemoteAddr: "1.2.3.4:1234",
}
h := RemoteIPHandler("ip")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
l := FromRequest(r)
l.Log().Msg("")
}))
h = NewHandler(zerolog.New(out))(h)
h.ServeHTTP(nil, r)
if want, got := `{"ip":"1.2.3.4"}`+"\n", decodeIfBinary(out); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}

func TestRemoteIPHandlerIPv6(t *testing.T) {
out := &bytes.Buffer{}
r := &http.Request{
RemoteAddr: "[2001:db8:a0b:12f0::1]:1234",
}
h := RemoteIPHandler("ip")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
l := FromRequest(r)
l.Log().Msg("")
}))
h = NewHandler(zerolog.New(out))(h)
h.ServeHTTP(nil, r)
if want, got := `{"ip":"2001:db8:a0b:12f0::1"}`+"\n", decodeIfBinary(out); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}

func TestUserAgentHandler(t *testing.T) {
out := &bytes.Buffer{}
r := &http.Request{
Expand Down Expand Up @@ -201,6 +233,46 @@ func TestCustomHeaderHandler(t *testing.T) {
}
}

func TestEtagHandler(t *testing.T) {
out := &bytes.Buffer{}
w := httptest.NewRecorder()
r := &http.Request{}
h := EtagHandler("etag")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Etag", `"abcdef"`)
w.WriteHeader(http.StatusOK)
}))
h2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h.ServeHTTP(w, r)
l := FromRequest(r)
l.Log().Msg("")
})
h3 := NewHandler(zerolog.New(out))(h2)
h3.ServeHTTP(w, r)
if want, got := `{"etag":"abcdef"}`+"\n", decodeIfBinary(out); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}

func TestResponseHeaderHandler(t *testing.T) {
out := &bytes.Buffer{}
w := httptest.NewRecorder()
r := &http.Request{}
h := ResponseHeaderHandler("encoding", "Content-Encoding")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Encoding", `gzip`)
w.WriteHeader(http.StatusOK)
}))
h2 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h.ServeHTTP(w, r)
l := FromRequest(r)
l.Log().Msg("")
})
h3 := NewHandler(zerolog.New(out))(h2)
h3.ServeHTTP(w, r)
if want, got := `{"encoding":"gzip"}`+"\n", decodeIfBinary(out); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}

func TestProtoHandler(t *testing.T) {
out := &bytes.Buffer{}
r := &http.Request{
Expand All @@ -217,6 +289,22 @@ func TestProtoHandler(t *testing.T) {
}
}

func TestHTTPVersionHandler(t *testing.T) {
out := &bytes.Buffer{}
r := &http.Request{
Proto: "HTTP/1.1",
}
h := HTTPVersionHandler("proto")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
l := FromRequest(r)
l.Log().Msg("")
}))
h = NewHandler(zerolog.New(out))(h)
h.ServeHTTP(nil, r)
if want, got := `{"proto":"1.1"}`+"\n", decodeIfBinary(out); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}

func TestCombinedHandlers(t *testing.T) {
out := &bytes.Buffer{}
r := &http.Request{
Expand Down Expand Up @@ -295,14 +383,53 @@ func TestCtxWithID(t *testing.T) {

func TestHostHandler(t *testing.T) {
out := &bytes.Buffer{}
r := &http.Request{Host: "example.com"}
r := &http.Request{Host: "example.com:8080"}
h := HostHandler("host")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
l := FromRequest(r)
l.Log().Msg("")
}))
h = NewHandler(zerolog.New(out))(h)
h.ServeHTTP(nil, r)
if want, got := `{"host":"example.com:8080"}`+"\n", decodeIfBinary(out); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}

func TestHostHandlerWithoutPort(t *testing.T) {
out := &bytes.Buffer{}
r := &http.Request{Host: "example.com:8080"}
h := HostHandler("host", true)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
l := FromRequest(r)
l.Log().Msg("")
}))
h = NewHandler(zerolog.New(out))(h)
h.ServeHTTP(nil, r)
if want, got := `{"host":"example.com"}`+"\n", decodeIfBinary(out); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
}
}

func TestGetHost(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"", ""},
{"example.com:8080", "example.com"},
{"example.com", "example.com"},
{"invalid", "invalid"},
{"192.168.0.1:8080", "192.168.0.1"},
{"[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8080", "2001:0db8:85a3:0000:0000:8a2e:0370:7334"},
{"こんにちは.com:8080", "こんにちは.com"},
}

for _, tt := range tests {
tt := tt
t.Run(tt.input, func(t *testing.T) {
result := getHost(tt.input)
if tt.expected != result {
t.Errorf("Invalid log output, got: %s, want: %s", result, tt.expected)
}
})
}
}

0 comments on commit bb14b8b

Please sign in to comment.