Skip to content

Commit

Permalink
Add rate limit to confighttp (#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
foadnh authored Nov 6, 2023
1 parent b8eb083 commit a397030
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 0 deletions.
14 changes: 14 additions & 0 deletions config/confighttp/confighttp.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ type HTTPServerSettings struct {
// Auth for this receiver
Auth *configauth.Authentication `mapstructure:"auth"`

// RateLimit for this receiver
RateLimit *RateLimit `mapstructure:"rate_limit"`

// MaxRequestBodySize sets the maximum request body size in bytes
MaxRequestBodySize int64 `mapstructure:"max_request_body_size"`

Expand Down Expand Up @@ -309,6 +312,17 @@ func (hss *HTTPServerSettings) ToServer(host component.Host, settings component.
handler = authInterceptor(handler, server)
}

// The RateLimit interceptor should always be right after auth to ensure
// the request rate is within an acceptable threshold.
if hss.RateLimit != nil {
limiter, err := hss.RateLimit.rateLimiter(host.GetExtensions())
if err != nil {
return nil, err
}

handler = rateLimitInterceptor(handler, limiter)
}

// TODO: emit a warning when non-empty CorsHeaders and empty CorsOrigins.
if hss.CORS != nil && len(hss.CORS.AllowedOrigins) > 0 {
co := cors.Options{
Expand Down
51 changes: 51 additions & 0 deletions config/confighttp/ratelimit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright The OpenTelemetry Authors
// SPDX-License-Identifier: Apache-2.0

package confighttp // import "go.opentelemetry.io/collector/config/confighttp"

import (
"context"
"fmt"
"net/http"

"go.opentelemetry.io/collector/component"
"go.opentelemetry.io/collector/extension"
)

// RateLimit defines rate limiter settings for the receiver.
type RateLimit struct {
// RateLimiterID specifies the name of the extension to use in order to rate limit the incoming data point.
RateLimiterID component.ID `mapstructure:"rate_limiter"`
}

type rateLimiter interface {
extension.Extension

Take(context.Context, string, http.Header) error
}

// rateLimiter attempts to select the appropriate rateLimiter from the list of extensions,
// based on the component id of the extension. If a rateLimiter is not found, an error is returned.
func (rl RateLimit) rateLimiter(extensions map[component.ID]component.Component) (rateLimiter, error) {
if ext, found := extensions[rl.RateLimiterID]; found {
if limiter, ok := ext.(rateLimiter); ok {
return limiter, nil
}
return nil, fmt.Errorf("extension %q is not a rate limit", rl.RateLimiterID)
}
return nil, fmt.Errorf("rate limit %q not found", rl.RateLimiterID)
}

// rateLimitInterceptor adds interceptor for rate limit check.
// It returns TooManyRequests(429) status code if rate limiter rejects the request.
func rateLimitInterceptor(next http.Handler, limiter rateLimiter) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := limiter.Take(r.Context(), r.URL.Path, r.Header)
if err != nil {
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
return
}

next.ServeHTTP(w, r)
})
}
111 changes: 111 additions & 0 deletions config/confighttp/ratelimit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// Copyright The OpenTelemetry Authors
// SPDX-License-Identifier: Apache-2.0

package confighttp

import (
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/collector/component"
"go.opentelemetry.io/collector/component/componenttest"
)

func TestServerRateLimit(t *testing.T) {
// prepare
hss := HTTPServerSettings{
Endpoint: "localhost:0",
RateLimit: &RateLimit{
RateLimiterID: component.NewID("mock"),
},
}

limiter := &mockRateLimiter{}

host := &mockHost{
ext: map[component.ID]component.Component{
component.NewID("mock"): limiter,
},
}

var handlerCalled bool
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handlerCalled = true
})

srv, err := hss.ToServer(host, componenttest.NewNopTelemetrySettings(), handler)
require.NoError(t, err)

// test
srv.Handler.ServeHTTP(&httptest.ResponseRecorder{}, httptest.NewRequest("GET", "/", nil))

// verify
assert.True(t, handlerCalled)
assert.Equal(t, 1, limiter.calls)
}

func TestInvalidServerRateLimit(t *testing.T) {
hss := HTTPServerSettings{
RateLimit: &RateLimit{
RateLimiterID: component.NewID("non-existing"),
},
}

srv, err := hss.ToServer(componenttest.NewNopHost(), componenttest.NewNopTelemetrySettings(), http.NewServeMux())
require.Error(t, err)
require.Nil(t, srv)
}

func TestRejectedServerRateLimit(t *testing.T) {
// prepare
hss := HTTPServerSettings{
Endpoint: "localhost:0",
RateLimit: &RateLimit{
RateLimiterID: component.NewID("mock"),
},
}
host := &mockHost{
ext: map[component.ID]component.Component{
component.NewID("mock"): &mockRateLimiter{
err: errors.New("rate limited"),
},
},
}

srv, err := hss.ToServer(host, componenttest.NewNopTelemetrySettings(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
require.NoError(t, err)

// test
response := &httptest.ResponseRecorder{}
srv.Handler.ServeHTTP(response, httptest.NewRequest("GET", "/", nil))

// verify
assert.Equal(t, response.Result().StatusCode, http.StatusTooManyRequests)
assert.Equal(t, response.Result().Status, fmt.Sprintf("%v %s", http.StatusTooManyRequests, http.StatusText(http.StatusTooManyRequests)))
}

// Mocks

type mockRateLimiter struct {
calls int
err error
}

func (m *mockRateLimiter) Take(context.Context, string, http.Header) error {
m.calls++
return m.err
}

func (m *mockRateLimiter) Start(_ context.Context, _ component.Host) error {
return nil
}

func (m *mockRateLimiter) Shutdown(_ context.Context) error {
return nil
}

0 comments on commit a397030

Please sign in to comment.