Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🔥 Update: add timeout context middleware #2090

Merged
merged 9 commits into from
Sep 16, 2022
69 changes: 58 additions & 11 deletions middleware/timeout/README.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
# Timeout
Timeout middleware for [Fiber](https://github.com/gofiber/fiber) wraps a `fiber.Handler` with a timeout. If the handler takes longer than the given duration to return, the timeout error is set and forwarded to the centralized [ErrorHandler](https://docs.gofiber.io/error-handling).
Timeout middleware for Fiber. As a `fiber.Handler` wrapper, it creates a context with `context.WithTimeout` and pass it in `UserContext`.

If the context passed executions (eg. DB ops, Http calls) takes longer than the given duration to return, the timeout error is set and forwarded to the centralized `ErrorHandler`.

It has no race conditions, ready to use on production.

### Table of Contents
- [Signatures](#signatures)
- [Examples](#examples)
- [Timeout](#timeout)
- [Table of Contents](#table-of-contents)
- [Signatures](#signatures)
- [Examples](#examples)


### Signatures
```go
func New(h fiber.Handler, t time.Duration) fiber.Handler
func New(handler fiber.Handler, timeout time.Duration, timeoutErrors ...error) fiber.Handler
```

### Examples
Expand All @@ -20,15 +26,56 @@ import (
)
```

After you initiate your Fiber app, you can use the following possibilities:
After you initiate your Fiber app, you can use:
```go
handler := func(ctx *fiber.Ctx) error {
err := ctx.SendString("Hello, World 👋!")
if err != nil {
return err
h := func(c *fiber.Ctx) error {
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
if err := sleepWithContextWithCustomError(c.UserContext(), sleepTime); err != nil {
return fmt.Errorf("%w: execution error", err)
}
return nil
}

app.Get("/foo", timeoutcontext.New(h, 5 * time.Second))

func sleepWithContext(ctx context.Context, d time.Duration) error {
timer := time.NewTimer(d)
select {
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
return context.DeadlineExceeded
case <-timer.C:
}
return nil
}

app.Get("/foo", timeout.New(handler, 5 * time.Second))
```

Use with custom error:
```go
var ErrFooTimeOut = errors.New("foo context canceled")

h := func(c *fiber.Ctx) error {
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
if err := sleepWithContextWithCustomError(c.UserContext(), sleepTime); err != nil {
return fmt.Errorf("%w: execution error", err)
}
return nil
}

app.Get("/foo", timeoutcontext.New(h, 5 * time.Second), ErrFooTimeOut)

func sleepWithContext(ctx context.Context, d time.Duration) error {
timer := time.NewTimer(d)
select {
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
return ErrFooTimeOut
case <-timer.C:
}
return nil
}
```
47 changes: 17 additions & 30 deletions middleware/timeout/timeout.go
Original file line number Diff line number Diff line change
@@ -1,43 +1,30 @@
package timeout

import (
"fmt"
"sync"
"context"
"errors"
"time"

"github.com/gofiber/fiber/v2"
)

var once sync.Once

// New wraps a handler and aborts the process of the handler if the timeout is reached
func New(handler fiber.Handler, timeout time.Duration) fiber.Handler {
once.Do(func() {
fmt.Println("[Warning] timeout contains data race issues, not ready for production!")
})

if timeout <= 0 {
return handler
}

// logic is from fasthttp.TimeoutWithCodeHandler https://github.com/valyala/fasthttp/blob/master/server.go#L418
// New implementation of timeout middleware. Set custom errors(context.DeadlineExceeded vs) for get fiber.ErrRequestTimeout response.
func New(h fiber.Handler, t time.Duration, tErrs ...error) fiber.Handler {
return func(ctx *fiber.Ctx) error {
ch := make(chan struct{}, 1)

go func() {
defer func() {
_ = recover()
}()
_ = handler(ctx)
ch <- struct{}{}
}()

select {
case <-ch:
case <-time.After(timeout):
return fiber.ErrRequestTimeout
timeoutContext, cancel := context.WithTimeout(ctx.UserContext(), t)
defer cancel()
ctx.SetUserContext(timeoutContext)
if err := h(ctx); err != nil {
if errors.Is(err, context.DeadlineExceeded) {
return fiber.ErrRequestTimeout
}
for i := range tErrs {
if errors.Is(err, tErrs[i]) {
return fiber.ErrRequestTimeout
}
}
return err
}

return nil
}
}
125 changes: 77 additions & 48 deletions middleware/timeout/timeout_test.go
Original file line number Diff line number Diff line change
@@ -1,55 +1,84 @@
package timeout

// // go test -run Test_Middleware_Timeout
// func Test_Middleware_Timeout(t *testing.T) {
// app := fiber.New(fiber.Config{DisableStartupMessage: true})
import (
"context"
"errors"
"fmt"
"net/http/httptest"
"testing"
"time"

// h := New(func(c *fiber.Ctx) error {
// sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
// time.Sleep(sleepTime)
// return c.SendString("After " + c.Params("sleepTime") + "ms sleeping")
// }, 5*time.Millisecond)
// app.Get("/test/:sleepTime", h)
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
)

// testTimeout := func(timeoutStr string) {
// resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
// utils.AssertEqual(t, nil, err, "app.Test(req)")
// utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
// go test -run Test_Timeout
func Test_Timeout(t *testing.T) {
// fiber instance
app := fiber.New()
h := New(func(c *fiber.Ctx) error {
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
if err := sleepWithContext(c.UserContext(), sleepTime, context.DeadlineExceeded); err != nil {
return fmt.Errorf("%w: l2 wrap", fmt.Errorf("%w: l1 wrap ", err))
}
return nil
}, 100*time.Millisecond)
app.Get("/test/:sleepTime", h)
testTimeout := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
}
testSucces := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
}
testTimeout("300")
testTimeout("500")
testSucces("50")
testSucces("30")
}

// body, err := ioutil.ReadAll(resp.Body)
// utils.AssertEqual(t, nil, err)
// utils.AssertEqual(t, "Request Timeout", string(body))
// }
// testSucces := func(timeoutStr string) {
// resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
// utils.AssertEqual(t, nil, err, "app.Test(req)")
// utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
var ErrFooTimeOut = errors.New("foo context canceled")

// body, err := ioutil.ReadAll(resp.Body)
// utils.AssertEqual(t, nil, err)
// utils.AssertEqual(t, "After "+timeoutStr+"ms sleeping", string(body))
// }
// go test -run Test_TimeoutWithCustomError
func Test_TimeoutWithCustomError(t *testing.T) {
// fiber instance
app := fiber.New()
h := New(func(c *fiber.Ctx) error {
sleepTime, _ := time.ParseDuration(c.Params("sleepTime") + "ms")
if err := sleepWithContext(c.UserContext(), sleepTime, ErrFooTimeOut); err != nil {
return fmt.Errorf("%w: execution error", err)
}
return nil
}, 100*time.Millisecond, ErrFooTimeOut)
app.Get("/test/:sleepTime", h)
testTimeout := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")
}
testSucces := func(timeoutStr string) {
resp, err := app.Test(httptest.NewRequest("GET", "/test/"+timeoutStr, nil))
utils.AssertEqual(t, nil, err, "app.Test(req)")
utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode, "Status code")
}
testTimeout("300")
testTimeout("500")
testSucces("50")
testSucces("30")
}

// testTimeout("15")
// testSucces("2")
// testTimeout("30")
// testSucces("3")
// }

// // go test -run -v Test_Timeout_Panic
// func Test_Timeout_Panic(t *testing.T) {
// app := fiber.New(fiber.Config{DisableStartupMessage: true})

// app.Get("/panic", recover.New(), New(func(c *fiber.Ctx) error {
// c.Set("dummy", "this should not be here")
// panic("panic in timeout handler")
// }, 5*time.Millisecond))

// resp, err := app.Test(httptest.NewRequest("GET", "/panic", nil))
// utils.AssertEqual(t, nil, err, "app.Test(req)")
// utils.AssertEqual(t, fiber.StatusRequestTimeout, resp.StatusCode, "Status code")

// body, err := ioutil.ReadAll(resp.Body)
// utils.AssertEqual(t, nil, err)
// utils.AssertEqual(t, "Request Timeout", string(body))
// }
func sleepWithContext(ctx context.Context, d time.Duration, te error) error {
timer := time.NewTimer(d)
select {
case <-ctx.Done():
if !timer.Stop() {
<-timer.C
}
return te
case <-timer.C:
}
return nil
}