Skip to content

Commit

Permalink
Implement httprate.WithErrorHandler()
Browse files Browse the repository at this point in the history
  • Loading branch information
VojtechVitek committed Jul 26, 2024
1 parent 6aa26b0 commit 6556a11
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 23 deletions.
26 changes: 22 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ to implement the `httprate.LimitCounter` interface to support an atomic incremen

## Backends

- [x] In-memory (built into this package)
- [x] Redis: https://github.com/go-chi/httprate-redis
- [x] Local in-memory backend (default)
- [x] Redis backend: https://github.com/go-chi/httprate-redis

## Example

Expand Down Expand Up @@ -85,12 +85,30 @@ r.Use(httprate.Limit(
10,
time.Minute,
httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "some specific response here", http.StatusTooManyRequests)
http.Error(w, `{"error": "Rate limited. Please slow down."}`, http.StatusTooManyRequests)
}),
))
```

### Customize response headers
### Send specific response for errors returned by the LimitCounter implementation

```go
r.Use(httprate.Limit(
10,
time.Minute,
httprate.WithErrorHandler(func(w http.ResponseWriter, r *http.Request, err error) {
// NOTE: The local in-memory counter is guaranteed not return any errors.
// Other backends may return errors, depending on whether they have
// in-memory fallback mechanism implemented in case of network errors.

http.Error(w, fmt.Sprintf(`{"error": %q}`, err), http.StatusPreconditionRequired)
}),
httprate.WithLimitCounter(customBackend),
))
```


### Send custom custom response headers

```go
r.Use(httprate.Limit(
Expand Down
8 changes: 7 additions & 1 deletion httprate.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,13 @@ func WithKeyByRealIP() Option {

func WithLimitHandler(h http.HandlerFunc) Option {
return func(rl *rateLimiter) {
rl.onRequestLimit = h
rl.onRateLimited = h
}
}

func WithErrorHandler(h func(http.ResponseWriter, *http.Request, error)) Option {
return func(rl *rateLimiter) {
rl.onError = h
}
}

Expand Down
41 changes: 26 additions & 15 deletions limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,23 +44,26 @@ func NewRateLimiter(requestLimit int, windowLength time.Duration, options ...Opt
rl.limitCounter.Config(requestLimit, windowLength)
}

if rl.onRequestLimit == nil {
rl.onRequestLimit = func(w http.ResponseWriter, r *http.Request) {
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
}
if rl.onRateLimited == nil {
rl.onRateLimited = onRateLimited
}

if rl.onError == nil {
rl.onError = onError
}

return rl
}

type rateLimiter struct {
requestLimit int
windowLength time.Duration
keyFn KeyFunc
limitCounter LimitCounter
onRequestLimit http.HandlerFunc
headers ResponseHeaders
mu sync.Mutex
requestLimit int
windowLength time.Duration
keyFn KeyFunc
limitCounter LimitCounter
onRateLimited http.HandlerFunc
onError func(http.ResponseWriter, *http.Request, error)
headers ResponseHeaders
mu sync.Mutex
}

func (l *rateLimiter) Counter() LimitCounter {
Expand All @@ -75,7 +78,7 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
key, err := l.keyFn(r)
if err != nil {
http.Error(w, err.Error(), http.StatusPreconditionRequired)
l.onError(w, r, err)
return
}

Expand All @@ -93,7 +96,7 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler {
_, rateFloat, err := l.calculateRate(key, limit)
if err != nil {
l.mu.Unlock()
http.Error(w, err.Error(), http.StatusPreconditionRequired)
l.onError(w, r, err)
return
}
rate := int(math.Round(rateFloat))
Expand All @@ -108,14 +111,14 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler {

l.mu.Unlock()
setHeader(w, l.headers.RetryAfter, fmt.Sprintf("%d", int(l.windowLength.Seconds()))) // RFC 6585
l.onRequestLimit(w, r)
l.onRateLimited(w, r)
return
}

err = l.limitCounter.IncrementBy(key, currentWindow, increment)
if err != nil {
l.mu.Unlock()
http.Error(w, err.Error(), http.StatusInternalServerError)
l.onError(w, r, err)
return
}
l.mu.Unlock()
Expand Down Expand Up @@ -150,3 +153,11 @@ func setHeader(w http.ResponseWriter, key string, value string) {
w.Header().Set(key, value)
}
}

func onRateLimited(w http.ResponseWriter, r *http.Request) {
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
}

func onError(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, err.Error(), http.StatusPreconditionRequired)
}
9 changes: 6 additions & 3 deletions local_counter.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (

// NewLocalLimitCounter creates an instance of localCounter,
// which is an in-memory implementation of http.LimitCounter.
//
// All methods are guaranteed to always return nil error.
func NewLocalLimitCounter(windowLength time.Duration) *localCounter {
return &localCounter{
windowLength: windowLength,
Expand Down Expand Up @@ -60,10 +62,11 @@ func (c *localCounter) Get(key string, currentWindow, previousWindow time.Time)
return 0, 0, nil
}

// Config implements LimitCounter but is redundant.
func (c *localCounter) Config(requestLimit int, windowLength time.Duration) {}
func (c *localCounter) Config(requestLimit int, windowLength time.Duration) {
c.windowLength = windowLength
c.latestWindow = time.Now().UTC().Truncate(windowLength)
}

// Increment implements LimitCounter but is redundant.
func (c *localCounter) Increment(key string, currentWindow time.Time) error {
return c.IncrementBy(key, currentWindow, 1)
}
Expand Down

0 comments on commit 6556a11

Please sign in to comment.