diff --git a/services/horizon/internal/httpx/router.go b/services/horizon/internal/httpx/router.go index 922f5ac7af..d78042a7cc 100644 --- a/services/horizon/internal/httpx/router.go +++ b/services/horizon/internal/httpx/router.go @@ -19,6 +19,7 @@ import ( "github.com/stellar/go/services/horizon/internal/db2/history" "github.com/stellar/go/services/horizon/internal/ledger" "github.com/stellar/go/services/horizon/internal/paths" + "github.com/stellar/go/services/horizon/internal/render" "github.com/stellar/go/services/horizon/internal/render/sse" "github.com/stellar/go/services/horizon/internal/txsub" "github.com/stellar/go/support/db" @@ -97,7 +98,18 @@ func (r *Router) addMiddleware(config *RouterConfig, r.Use(c.Handler) if rateLimitter != nil { - r.Use(rateLimitter.RateLimit) + r.Use(func(handler http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Exempt streaming requests from rate limits via the HTTP middleware + // because rate limiting for streaming requests are already implemented in + // StreamHandler.ServeStream(). + if render.Negotiate(r) == render.MimeEventStream { + handler.ServeHTTP(w, r) + return + } + rateLimitter.RateLimit(handler).ServeHTTP(w, r) + }) + }) } if config.PrimaryDBSession != nil { diff --git a/services/horizon/internal/middleware_test.go b/services/horizon/internal/middleware_test.go index c6b617c2b3..b44d2cade6 100644 --- a/services/horizon/internal/middleware_test.go +++ b/services/horizon/internal/middleware_test.go @@ -82,16 +82,25 @@ func (suite *RateLimitMiddlewareTestSuite) TestRateLimit_LimitHeaders() { // Sets X-RateLimit-Remaining headers correctly. func (suite *RateLimitMiddlewareTestSuite) TestRateLimit_RemainingHeaders() { + // test that SSE requests are ignored + for i := 0; i < 10; i++ { + w := suite.rh.Get("/", test.RequestHelperStreaming) + assert.Equal(suite.T(), "", w.Header().Get("X-RateLimit-Remaining")) + assert.NotEqual(suite.T(), http.StatusTooManyRequests, w.Code) + } + for i := 0; i < 10; i++ { w := suite.rh.Get("/") expected := 10 - (i + 1) assert.Equal(suite.T(), strconv.Itoa(expected), w.Header().Get("X-RateLimit-Remaining")) + assert.NotEqual(suite.T(), http.StatusTooManyRequests, w.Code) } // confirm remaining stays at 0 for i := 0; i < 10; i++ { w := suite.rh.Get("/") assert.Equal(suite.T(), "0", w.Header().Get("X-RateLimit-Remaining")) + assert.Equal(suite.T(), http.StatusTooManyRequests, w.Code) } }