diff --git a/config/confighttp/confighttp.go b/config/confighttp/confighttp.go index ebf4f97c351..9294b6ee78c 100644 --- a/config/confighttp/confighttp.go +++ b/config/confighttp/confighttp.go @@ -224,6 +224,9 @@ type HTTPServerSettings struct { // Auth for this receiver Auth *configauth.Authentication `mapstructure:"auth"` + // RateLimit for this receiver + RateLimit *RateLimit `mapstructure:"rate_limit"` + // MaxRequestBodySize sets the maximum request body size in bytes MaxRequestBodySize int64 `mapstructure:"max_request_body_size"` @@ -309,6 +312,17 @@ func (hss *HTTPServerSettings) ToServer(host component.Host, settings component. handler = authInterceptor(handler, server) } + // The RateLimit interceptor should always be right after auth to ensure + // the request rate is within an acceptable threshold. + if hss.RateLimit != nil { + limiter, err := hss.RateLimit.rateLimiter(host.GetExtensions()) + if err != nil { + return nil, err + } + + handler = rateLimitInterceptor(handler, limiter) + } + // TODO: emit a warning when non-empty CorsHeaders and empty CorsOrigins. if hss.CORS != nil && len(hss.CORS.AllowedOrigins) > 0 { co := cors.Options{ diff --git a/config/confighttp/ratelimit.go b/config/confighttp/ratelimit.go new file mode 100644 index 00000000000..cfbe7bbc5cb --- /dev/null +++ b/config/confighttp/ratelimit.go @@ -0,0 +1,51 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package confighttp // import "go.opentelemetry.io/collector/config/confighttp" + +import ( + "context" + "fmt" + "net/http" + + "go.opentelemetry.io/collector/component" + "go.opentelemetry.io/collector/extension" +) + +// RateLimit defines rate limiter settings for the receiver. +type RateLimit struct { + // RateLimiterID specifies the name of the extension to use in order to rate limit the incoming data point. + RateLimiterID component.ID `mapstructure:"rate_limiter"` +} + +type rateLimiter interface { + extension.Extension + + Take(context.Context, string, http.Header) error +} + +// rateLimiter attempts to select the appropriate rateLimiter from the list of extensions, +// based on the component id of the extension. If a rateLimiter is not found, an error is returned. +func (rl RateLimit) rateLimiter(extensions map[component.ID]component.Component) (rateLimiter, error) { + if ext, found := extensions[rl.RateLimiterID]; found { + if limiter, ok := ext.(rateLimiter); ok { + return limiter, nil + } + return nil, fmt.Errorf("extension %q is not a rate limit", rl.RateLimiterID) + } + return nil, fmt.Errorf("rate limit %q not found", rl.RateLimiterID) +} + +// rateLimitInterceptor adds interceptor for rate limit check. +// It returns TooManyRequests(429) status code if rate limiter rejects the request. +func rateLimitInterceptor(next http.Handler, limiter rateLimiter) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + err := limiter.Take(r.Context(), r.URL.Path, r.Header) + if err != nil { + http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) + return + } + + next.ServeHTTP(w, r) + }) +} diff --git a/config/confighttp/ratelimit_test.go b/config/confighttp/ratelimit_test.go new file mode 100644 index 00000000000..b06f47010cd --- /dev/null +++ b/config/confighttp/ratelimit_test.go @@ -0,0 +1,111 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package confighttp + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/collector/component" + "go.opentelemetry.io/collector/component/componenttest" +) + +func TestServerRateLimit(t *testing.T) { + // prepare + hss := HTTPServerSettings{ + Endpoint: "localhost:0", + RateLimit: &RateLimit{ + RateLimiterID: component.NewID("mock"), + }, + } + + limiter := &mockRateLimiter{} + + host := &mockHost{ + ext: map[component.ID]component.Component{ + component.NewID("mock"): limiter, + }, + } + + var handlerCalled bool + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handlerCalled = true + }) + + srv, err := hss.ToServer(host, componenttest.NewNopTelemetrySettings(), handler) + require.NoError(t, err) + + // test + srv.Handler.ServeHTTP(&httptest.ResponseRecorder{}, httptest.NewRequest("GET", "/", nil)) + + // verify + assert.True(t, handlerCalled) + assert.Equal(t, 1, limiter.calls) +} + +func TestInvalidServerRateLimit(t *testing.T) { + hss := HTTPServerSettings{ + RateLimit: &RateLimit{ + RateLimiterID: component.NewID("non-existing"), + }, + } + + srv, err := hss.ToServer(componenttest.NewNopHost(), componenttest.NewNopTelemetrySettings(), http.NewServeMux()) + require.Error(t, err) + require.Nil(t, srv) +} + +func TestRejectedServerRateLimit(t *testing.T) { + // prepare + hss := HTTPServerSettings{ + Endpoint: "localhost:0", + RateLimit: &RateLimit{ + RateLimiterID: component.NewID("mock"), + }, + } + host := &mockHost{ + ext: map[component.ID]component.Component{ + component.NewID("mock"): &mockRateLimiter{ + err: errors.New("rate limited"), + }, + }, + } + + srv, err := hss.ToServer(host, componenttest.NewNopTelemetrySettings(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + require.NoError(t, err) + + // test + response := &httptest.ResponseRecorder{} + srv.Handler.ServeHTTP(response, httptest.NewRequest("GET", "/", nil)) + + // verify + assert.Equal(t, response.Result().StatusCode, http.StatusTooManyRequests) + assert.Equal(t, response.Result().Status, fmt.Sprintf("%v %s", http.StatusTooManyRequests, http.StatusText(http.StatusTooManyRequests))) +} + +// Mocks + +type mockRateLimiter struct { + calls int + err error +} + +func (m *mockRateLimiter) Take(context.Context, string, http.Header) error { + m.calls++ + return m.err +} + +func (m *mockRateLimiter) Start(_ context.Context, _ component.Host) error { + return nil +} + +func (m *mockRateLimiter) Shutdown(_ context.Context) error { + return nil +}