Skip to content

Commit

Permalink
g
Browse files Browse the repository at this point in the history
  • Loading branch information
komuw committed Aug 16, 2024
1 parent 2cf73be commit 0460c9c
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 28 deletions.
19 changes: 8 additions & 11 deletions middleware/log_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"context"
"fmt"
"io"
"log/slog"
mathRand "math/rand/v2"
"net/http"
"net/http/httptest"
Expand All @@ -24,7 +23,7 @@ const (
someLatencyMS = 3
)

func toLog(t *testing.T, l *slog.Logger) func(w http.ResponseWriter, r http.Request, statusCode int, fields []any) {
func toLog(t *testing.T, buf *bytes.Buffer) func(w http.ResponseWriter, r http.Request, statusCode int, fields []any) {
t.Helper()

const (
Expand All @@ -33,6 +32,8 @@ func toLog(t *testing.T, l *slog.Logger) func(w http.ResponseWriter, r http.Requ
rateShedSamplePercent = 10
)

l := log.New(context.Background(), buf, 500)

return func(w http.ResponseWriter, r http.Request, statusCode int, fields []any) {
// Each request should get its own context. That's why we call `log.WithID` for every request.
reqL := log.WithID(r.Context(), l)
Expand Down Expand Up @@ -84,16 +85,12 @@ func someLogHandler(successMsg string) http.HandlerFunc {
func TestLogMiddleware(t *testing.T) {
t.Parallel()

getLogger := func(w io.Writer) *slog.Logger {
return log.New(context.Background(), w, 500)
}

t.Run("success", func(t *testing.T) {
t.Parallel()

logOutput := &bytes.Buffer{}
successMsg := "hello"
wrappedHandler := logger(someLogHandler(successMsg), getLogger(logOutput))
wrappedHandler := logger(someLogHandler(successMsg), toLog(t, logOutput))

rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodHead, "/someUri", nil)
Expand All @@ -116,7 +113,7 @@ func TestLogMiddleware(t *testing.T) {
logOutput := &bytes.Buffer{}
errorMsg := "someLogHandler failed"
successMsg := "hello"
wrappedHandler := logger(someLogHandler(successMsg), getLogger(logOutput))
wrappedHandler := logger(someLogHandler(successMsg), toLog(t, logOutput))

rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodHead, "/someUri", nil)
Expand Down Expand Up @@ -152,7 +149,7 @@ func TestLogMiddleware(t *testing.T) {
logOutput := &bytes.Buffer{}
successMsg := "hello"
errorMsg := "someLogHandler failed"
wrappedHandler := logger(someLogHandler(successMsg), getLogger(logOutput))
wrappedHandler := logger(someLogHandler(successMsg), toLog(t, logOutput))

{
// first request that succeds
Expand Down Expand Up @@ -231,7 +228,7 @@ func TestLogMiddleware(t *testing.T) {

logOutput := &bytes.Buffer{}
successMsg := "hello"
wrappedHandler := logger(someLogHandler(successMsg), getLogger(logOutput))
wrappedHandler := logger(someLogHandler(successMsg), toLog(t, logOutput))

someLogID := "hey-some-log-id:" + id.New()

Expand Down Expand Up @@ -261,7 +258,7 @@ func TestLogMiddleware(t *testing.T) {
successMsg := "hello"
// for this concurrency test, we have to re-use the same wrappedHandler
// so that state is shared and thus we can see if there is any state which is not handled correctly.
wrappedHandler := logger(someLogHandler(successMsg), getLogger(logOutput))
wrappedHandler := logger(someLogHandler(successMsg), toLog(t, logOutput))

runhandler := func() {
rec := httptest.NewRecorder()
Expand Down
7 changes: 7 additions & 0 deletions middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ const (
allowHeader = "Allow"
)

// TODO: remove
func todoRemove(w http.ResponseWriter, r http.Request, statusCode int, fields []any) {
return
}

// allDefaultMiddlewares is a middleware that bundles all the default/core middlewares into one.
//
// example usage:
Expand Down Expand Up @@ -168,7 +173,9 @@ func allDefaultMiddlewares(
rateLimit,
),
),
todoRemove,
),
todoRemove,
),
),
strategy,
Expand Down
25 changes: 8 additions & 17 deletions middleware/recoverer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,14 @@ package middleware

import (
"bytes"
"context"
"fmt"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"os"
"strings"
"sync"
"testing"

"github.com/komuw/ong/errors"
"github.com/komuw/ong/log"

"go.akshayshah.org/attest"
)
Expand Down Expand Up @@ -48,16 +43,12 @@ func anotherHandlerThatPanics() http.HandlerFunc {
func TestPanic(t *testing.T) {
t.Parallel()

getLogger := func(w io.Writer) *slog.Logger {
return log.New(context.Background(), w, 500)
}

t.Run("ok if no panic", func(t *testing.T) {
t.Parallel()

logOutput := &bytes.Buffer{}
msg := "hello"
wrappedHandler := recoverer(handlerThatPanics(msg, false, nil), getLogger(logOutput))
wrappedHandler := recoverer(handlerThatPanics(msg, false, nil), toLog(t, logOutput))

rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/someUri", nil)
Expand All @@ -74,7 +65,7 @@ func TestPanic(t *testing.T) {

logOutput := &bytes.Buffer{}
msg := "hello"
wrappedHandler := recoverer(handlerThatPanics(msg, true, nil), getLogger(logOutput))
wrappedHandler := recoverer(handlerThatPanics(msg, true, nil), toLog(t, logOutput))

rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/someUri", nil)
Expand Down Expand Up @@ -104,7 +95,7 @@ func TestPanic(t *testing.T) {
msg := "hello"
errMsg := "99 problems"
err := errors.New(errMsg)
wrappedHandler := recoverer(handlerThatPanics(msg, false, err), getLogger(logOutput))
wrappedHandler := recoverer(handlerThatPanics(msg, false, err), toLog(t, logOutput))

rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/someUri", nil)
Expand Down Expand Up @@ -132,7 +123,7 @@ func TestPanic(t *testing.T) {
t.Parallel()

logOutput := &bytes.Buffer{}
wrappedHandler := recoverer(anotherHandlerThatPanics(), getLogger(logOutput))
wrappedHandler := recoverer(anotherHandlerThatPanics(), toLog(t, logOutput))

rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/someUri", nil)
Expand All @@ -141,19 +132,19 @@ func TestPanic(t *testing.T) {
res := rec.Result()
defer res.Body.Close()
attest.Equal(t, res.StatusCode, http.StatusInternalServerError)
attest.Subsequence(t, logOutput.String(), "middleware/recoverer_test.go:42") // line where panic happened.
attest.Subsequence(t, logOutput.String(), "middleware/recoverer_test.go:37") // line where panic happened.
})

t.Run("concurrency safe", func(t *testing.T) {
t.Parallel()

// &bytes.Buffer{} is not concurrency safe, so we use os.Stderr instead.
logOutput := os.Stderr
// If &bytes.Buffer{} is not concurrency safe, we can use os.Stderr instead.
logOutput := &bytes.Buffer{}
msg := "hey"
err := errors.New(msg)
// for this concurrency test, we have to re-use the same wrappedHandler
// so that state is shared and thus we can see if there is any state which is not handled correctly.
wrappedHandler := recoverer(handlerThatPanics(msg, false, err), getLogger(logOutput))
wrappedHandler := recoverer(handlerThatPanics(msg, false, err), toLog(t, logOutput))

runhandler := func() {
rec := httptest.NewRecorder()
Expand Down

0 comments on commit 0460c9c

Please sign in to comment.