From 7ae36e525f3a99d0ab91c5a75d930f0aef66f100 Mon Sep 17 00:00:00 2001 From: Dhia Ayachi Date: Tue, 23 May 2023 15:37:01 -0400 Subject: [PATCH] add necessary plumbing to implement per server ip based rate limiting (#17436) --- agent/consul/rate/handler.go | 45 +++++++++------- agent/consul/rate/handler_test.go | 27 +++------- .../rate/mock_LeaderStatusProvider_test.go | 39 -------------- .../consul/rate/mock_RequestLimitsHandler.go | 6 +-- .../rate/mock_ServersStatusProvider_test.go | 53 +++++++++++++++++++ agent/consul/server.go | 14 +++++ agent/metadata/server.go | 10 ++++ 7 files changed, 113 insertions(+), 81 deletions(-) delete mode 100644 agent/consul/rate/mock_LeaderStatusProvider_test.go create mode 100644 agent/consul/rate/mock_ServersStatusProvider_test.go diff --git a/agent/consul/rate/handler.go b/agent/consul/rate/handler.go index a67a104cc46f5..84faefb6c3e81 100644 --- a/agent/consul/rate/handler.go +++ b/agent/consul/rate/handler.go @@ -8,6 +8,7 @@ import ( "context" "errors" "fmt" + "github.com/hashicorp/consul/agent/metadata" "net" "reflect" "sync/atomic" @@ -153,14 +154,14 @@ type RequestLimitsHandler interface { Allow(op Operation) error UpdateConfig(cfg HandlerConfig) UpdateIPConfig(cfg IPLimitConfig) - Register(leaderStatusProvider LeaderStatusProvider) + Register(serversStatusProvider ServersStatusProvider) } // Handler enforces rate limits for incoming RPCs. type Handler struct { - globalCfg *atomic.Pointer[HandlerConfig] - ipCfg *atomic.Pointer[IPLimitConfig] - leaderStatusProvider LeaderStatusProvider + globalCfg *atomic.Pointer[HandlerConfig] + ipCfg *atomic.Pointer[IPLimitConfig] + serversStatusProvider ServersStatusProvider limiter multilimiter.RateLimiter @@ -186,13 +187,14 @@ type HandlerConfig struct { GlobalLimitConfig GlobalLimitConfig } -//go:generate mockery --name LeaderStatusProvider --inpackage --filename mock_LeaderStatusProvider_test.go -type LeaderStatusProvider interface { +//go:generate mockery --name ServersStatusProvider --inpackage --filename mock_ServersStatusProvider_test.go +type ServersStatusProvider interface { // IsLeader is used to determine whether the operation is being performed // against the cluster leader, such that if it can _only_ be performed by // the leader (e.g. write operations) we don't tell clients to retry against // a different server. IsLeader() bool + IsServer(addr string) bool } func isInfRate(cfg multilimiter.LimiterConfig) bool { @@ -237,11 +239,11 @@ func (h *Handler) Run(ctx context.Context) { // because of an exhausted rate-limit. func (h *Handler) Allow(op Operation) error { - if h.leaderStatusProvider == nil { - h.logger.Error("leaderStatusProvider required to be set via Register(). bailing on rate limiter") + if h.serversStatusProvider == nil { + h.logger.Error("serversStatusProvider required to be set via Register(). bailing on rate limiter") return nil // TODO: panic and make sure to use the server's recovery handler - // panic("leaderStatusProvider required to be set via Register(..)") + // panic("serversStatusProvider required to be set via Register(..)") } cfg := h.globalCfg.Load() @@ -249,7 +251,7 @@ func (h *Handler) Allow(op Operation) error { return nil } - allow, throttledLimits := h.allowAllLimits(h.limits(op)) + allow, throttledLimits := h.allowAllLimits(h.limits(op), h.serversStatusProvider.IsServer(string(metadata.GetIP(op.SourceAddr)))) if !allow { for _, l := range throttledLimits { @@ -277,7 +279,7 @@ func (h *Handler) Allow(op Operation) error { }) if enforced { - if h.leaderStatusProvider.IsLeader() && op.Type == OperationTypeWrite { + if h.serversStatusProvider.IsLeader() && op.Type == OperationTypeWrite { return ErrRetryLater } return ErrRetryElsewhere @@ -305,17 +307,18 @@ func (h *Handler) UpdateConfig(cfg HandlerConfig) { } -func (h *Handler) Register(leaderStatusProvider LeaderStatusProvider) { - h.leaderStatusProvider = leaderStatusProvider +func (h *Handler) Register(serversStatusProvider ServersStatusProvider) { + h.serversStatusProvider = serversStatusProvider } type limit struct { - mode Mode - ent multilimiter.LimitedEntity - desc string + mode Mode + ent multilimiter.LimitedEntity + desc string + applyOnServer bool } -func (h *Handler) allowAllLimits(limits []limit) (bool, []limit) { +func (h *Handler) allowAllLimits(limits []limit, isServer bool) (bool, []limit) { allow := true throttledLimits := make([]limit, 0) @@ -324,6 +327,10 @@ func (h *Handler) allowAllLimits(limits []limit) (bool, []limit) { continue } + if isServer && !l.applyOnServer { + continue + } + if !h.limiter.Allow(l.ent) { throttledLimits = append(throttledLimits, l) allow = false @@ -358,7 +365,7 @@ func (h *Handler) globalLimit(op Operation) *limit { } cfg := h.globalCfg.Load() - lim := &limit{mode: cfg.GlobalLimitConfig.Mode} + lim := &limit{mode: cfg.GlobalLimitConfig.Mode, applyOnServer: true} switch op.Type { case OperationTypeRead: lim.desc = "global/read" @@ -409,4 +416,4 @@ func (nullRequestLimitsHandler) Run(_ context.Context) {} func (nullRequestLimitsHandler) UpdateConfig(_ HandlerConfig) {} -func (nullRequestLimitsHandler) Register(_ LeaderStatusProvider) {} +func (nullRequestLimitsHandler) Register(_ ServersStatusProvider) {} diff --git a/agent/consul/rate/handler_test.go b/agent/consul/rate/handler_test.go index 0311cb60089dc..54a8b86a4b989 100644 --- a/agent/consul/rate/handler_test.go +++ b/agent/consul/rate/handler_test.go @@ -19,22 +19,6 @@ import ( "github.com/hashicorp/consul/agent/consul/multilimiter" ) -// -// Revisit test when handler.go:189 TODO implemented -// -// func TestHandler_Allow_PanicsWhenLeaderStatusProviderNotRegistered(t *testing.T) { -// defer func() { -// err := recover() -// if err == nil { -// t.Fatal("Run should panic") -// } -// }() - -// handler := NewHandler(HandlerConfig{}, hclog.NewNullLogger()) -// handler.Allow(Operation{}) -// // intentionally skip handler.Register(...) -// } - func TestHandler(t *testing.T) { var ( rpcName = "Foo.Bar" @@ -50,6 +34,7 @@ func TestHandler(t *testing.T) { globalMode Mode checks []limitCheck isLeader bool + isServer bool expectErr error expectLog bool expectMetric bool @@ -230,8 +215,9 @@ func TestHandler(t *testing.T) { limiter.On("Allow", mock.Anything).Return(c.allow) } - leaderStatusProvider := NewMockLeaderStatusProvider(t) - leaderStatusProvider.On("IsLeader").Return(tc.isLeader).Maybe() + serversStatusProvider := NewMockServersStatusProvider(t) + serversStatusProvider.On("IsLeader").Return(tc.isLeader).Maybe() + serversStatusProvider.On("IsServer", mock.Anything).Return(tc.isServer).Maybe() var output bytes.Buffer logger := hclog.NewInterceptLogger(&hclog.LoggerOptions{ @@ -252,7 +238,7 @@ func TestHandler(t *testing.T) { limiter, logger, ) - handler.Register(leaderStatusProvider) + handler.Register(serversStatusProvider) require.Equal(t, tc.expectErr, handler.Allow(tc.op)) @@ -426,8 +412,9 @@ func TestAllow(t *testing.T) { } mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return() logger := hclog.NewNullLogger() - delegate := NewMockLeaderStatusProvider(t) + delegate := NewMockServersStatusProvider(t) delegate.On("IsLeader").Return(true).Maybe() + delegate.On("IsServer", mock.Anything).Return(false).Maybe() handler := NewHandlerWithLimiter(*tc.cfg, mockRateLimiter, logger) handler.Register(delegate) addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:1234")) diff --git a/agent/consul/rate/mock_LeaderStatusProvider_test.go b/agent/consul/rate/mock_LeaderStatusProvider_test.go deleted file mode 100644 index 92af311b16023..0000000000000 --- a/agent/consul/rate/mock_LeaderStatusProvider_test.go +++ /dev/null @@ -1,39 +0,0 @@ -// Code generated by mockery v2.20.0. DO NOT EDIT. - -package rate - -import mock "github.com/stretchr/testify/mock" - -// MockLeaderStatusProvider is an autogenerated mock type for the LeaderStatusProvider type -type MockLeaderStatusProvider struct { - mock.Mock -} - -// IsLeader provides a mock function with given fields: -func (_m *MockLeaderStatusProvider) IsLeader() bool { - ret := _m.Called() - - var r0 bool - if rf, ok := ret.Get(0).(func() bool); ok { - r0 = rf() - } else { - r0 = ret.Get(0).(bool) - } - - return r0 -} - -type mockConstructorTestingTNewMockLeaderStatusProvider interface { - mock.TestingT - Cleanup(func()) -} - -// NewMockLeaderStatusProvider creates a new instance of MockLeaderStatusProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewMockLeaderStatusProvider(t mockConstructorTestingTNewMockLeaderStatusProvider) *MockLeaderStatusProvider { - mock := &MockLeaderStatusProvider{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/agent/consul/rate/mock_RequestLimitsHandler.go b/agent/consul/rate/mock_RequestLimitsHandler.go index dcd131af42812..3fa07089b7bd9 100644 --- a/agent/consul/rate/mock_RequestLimitsHandler.go +++ b/agent/consul/rate/mock_RequestLimitsHandler.go @@ -27,9 +27,9 @@ func (_m *MockRequestLimitsHandler) Allow(op Operation) error { return r0 } -// Register provides a mock function with given fields: leaderStatusProvider -func (_m *MockRequestLimitsHandler) Register(leaderStatusProvider LeaderStatusProvider) { - _m.Called(leaderStatusProvider) +// Register provides a mock function with given fields: serversStatusProvider +func (_m *MockRequestLimitsHandler) Register(serversStatusProvider ServersStatusProvider) { + _m.Called(serversStatusProvider) } // Run provides a mock function with given fields: ctx diff --git a/agent/consul/rate/mock_ServersStatusProvider_test.go b/agent/consul/rate/mock_ServersStatusProvider_test.go new file mode 100644 index 0000000000000..42cd710743652 --- /dev/null +++ b/agent/consul/rate/mock_ServersStatusProvider_test.go @@ -0,0 +1,53 @@ +// Code generated by mockery v2.20.0. DO NOT EDIT. + +package rate + +import mock "github.com/stretchr/testify/mock" + +// MockServersStatusProvider is an autogenerated mock type for the ServersStatusProvider type +type MockServersStatusProvider struct { + mock.Mock +} + +// IsLeader provides a mock function with given fields: +func (_m *MockServersStatusProvider) IsLeader() bool { + ret := _m.Called() + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// IsServer provides a mock function with given fields: addr +func (_m *MockServersStatusProvider) IsServer(addr string) bool { + ret := _m.Called(addr) + + var r0 bool + if rf, ok := ret.Get(0).(func(string) bool); ok { + r0 = rf(addr) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +type mockConstructorTestingTNewMockServersStatusProvider interface { + mock.TestingT + Cleanup(func()) +} + +// NewMockServersStatusProvider creates a new instance of MockServersStatusProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewMockServersStatusProvider(t mockConstructorTestingTNewMockServersStatusProvider) *MockServersStatusProvider { + mock := &MockServersStatusProvider{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/agent/consul/server.go b/agent/consul/server.go index 88c05b711d604..294417c5e85f2 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -1660,6 +1660,20 @@ func (s *Server) IsLeader() bool { return s.raft.State() == raft.Leader } +// IsServer checks if this addr is of a server +func (s *Server) IsServer(addr string) bool { + for _, s := range s.raft.GetConfiguration().Configuration().Servers { + a, err := net.ResolveTCPAddr("tcp", string(s.Address)) + if err != nil { + continue + } + if string(metadata.GetIP(a)) == addr { + return true + } + } + return false +} + // LeaderLastContact returns the time of last contact by a leader. // This only makes sense if we are currently a follower. func (s *Server) LeaderLastContact() time.Time { diff --git a/agent/metadata/server.go b/agent/metadata/server.go index b2a7238cb9cac..64c9936909892 100644 --- a/agent/metadata/server.go +++ b/agent/metadata/server.go @@ -221,3 +221,13 @@ func AddFeatureFlags(tags map[string]string, flags ...string) { tags[featureFlagPrefix+flag] = "1" } } + +func GetIP(addr net.Addr) []byte { + switch a := addr.(type) { + case *net.UDPAddr: + return []byte(a.IP.String()) + case *net.TCPAddr: + return []byte(a.IP.String()) + } + return []byte{} +}