diff --git a/agent/consul/rate/handler.go b/agent/consul/rate/handler.go index 12ed974c6a87..a67a104cc46f 100644 --- a/agent/consul/rate/handler.go +++ b/agent/consul/rate/handler.go @@ -215,6 +215,7 @@ func NewHandlerWithLimiter( logger: logger, } h.globalCfg.Store(&cfg) + h.ipCfg.Store(&IPLimitConfig{}) return h } @@ -248,45 +249,39 @@ func (h *Handler) Allow(op Operation) error { return nil } - for _, l := range h.limits(op) { - if l.mode == ModeDisabled { - continue - } - - if h.limiter.Allow(l.ent) { - continue - } - - // TODO(NET-1382): is this the correct log-level? - - enforced := l.mode == ModeEnforcing - h.logger.Debug("RPC exceeded allowed rate limit", - "rpc", op.Name, - "source_addr", op.SourceAddr, - "limit_type", l.desc, - "limit_enforced", enforced, - ) - - metrics.IncrCounterWithLabels([]string{"rpc", "rate_limit", "exceeded"}, 1, []metrics.Label{ - { - Name: "limit_type", - Value: l.desc, - }, - { - Name: "op", - Value: op.Name, - }, - { - Name: "mode", - Value: l.mode.String(), - }, - }) - - if enforced { - if h.leaderStatusProvider.IsLeader() && op.Type == OperationTypeWrite { - return ErrRetryLater + allow, throttledLimits := h.allowAllLimits(h.limits(op)) + + if !allow { + for _, l := range throttledLimits { + enforced := l.mode == ModeEnforcing + h.logger.Debug("RPC exceeded allowed rate limit", + "rpc", op.Name, + "source_addr", op.SourceAddr, + "limit_type", l.desc, + "limit_enforced", enforced, + ) + + metrics.IncrCounterWithLabels([]string{"rpc", "rate_limit", "exceeded"}, 1, []metrics.Label{ + { + Name: "limit_type", + Value: l.desc, + }, + { + Name: "op", + Value: op.Name, + }, + { + Name: "mode", + Value: l.mode.String(), + }, + }) + + if enforced { + if h.leaderStatusProvider.IsLeader() && op.Type == OperationTypeWrite { + return ErrRetryLater + } + return ErrRetryElsewhere } - return ErrRetryElsewhere } } return nil @@ -320,6 +315,23 @@ type limit struct { desc string } +func (h *Handler) allowAllLimits(limits []limit) (bool, []limit) { + allow := true + throttledLimits := make([]limit, 0) + + for _, l := range limits { + if l.mode == ModeDisabled { + continue + } + + if !h.limiter.Allow(l.ent) { + throttledLimits = append(throttledLimits, l) + allow = false + } + } + return allow, throttledLimits +} + // limits returns the limits to check for the given operation (e.g. global + // ip-based + tenant-based). func (h *Handler) limits(op Operation) []limit { @@ -329,6 +341,14 @@ func (h *Handler) limits(op Operation) []limit { limits = append(limits, *global) } + if ipGlobal := h.ipGlobalLimit(op); ipGlobal != nil { + limits = append(limits, *ipGlobal) + } + + if ipCategory := h.ipCategoryLimit(op); ipCategory != nil { + limits = append(limits, *ipCategory) + } + return limits } @@ -354,23 +374,23 @@ func (h *Handler) globalLimit(op Operation) *limit { var ( // globalWrite identifies the global rate limit applied to write operations. - globalWrite = globalLimit("global.write") + globalWrite = limitedEntity("global.write") // globalRead identifies the global rate limit applied to read operations. - globalRead = globalLimit("global.read") + globalRead = limitedEntity("global.read") // globalIPRead identifies the global rate limit applied to read operations. - globalIPRead = globalLimit("global.ip.read") + globalIPRead = limitedEntity("global.ip.read") // globalIPWrite identifies the global rate limit applied to read operations. - globalIPWrite = globalLimit("global.ip.write") + globalIPWrite = limitedEntity("global.ip.write") ) -// globalLimit represents a limit that applies to all writes or reads. -type globalLimit []byte +// limitedEntity convert the string type to Multilimiter.LimitedEntity +type limitedEntity []byte // Key satisfies the multilimiter.LimitedEntity interface. -func (prefix globalLimit) Key() multilimiter.KeyType { +func (prefix limitedEntity) Key() multilimiter.KeyType { return multilimiter.Key(prefix, nil) } diff --git a/agent/consul/rate/handler_oss.go b/agent/consul/rate/handler_oss.go index 1316ae6cc96c..fc33a69487f8 100644 --- a/agent/consul/rate/handler_oss.go +++ b/agent/consul/rate/handler_oss.go @@ -6,9 +6,16 @@ package rate -type IPLimitConfig struct { -} +type IPLimitConfig struct{} func (h *Handler) UpdateIPConfig(cfg IPLimitConfig) { // noop } + +func (h *Handler) ipGlobalLimit(op Operation) *limit { + return nil +} + +func (h *Handler) ipCategoryLimit(op Operation) *limit { + return nil +} diff --git a/agent/consul/rate/handler_test.go b/agent/consul/rate/handler_test.go index 8f1b465f473f..0311cb60089d 100644 --- a/agent/consul/rate/handler_test.go +++ b/agent/consul/rate/handler_test.go @@ -5,20 +5,18 @@ package rate import ( "bytes" - "context" + "github.com/hashicorp/consul/agent/metrics" + "github.com/stretchr/testify/require" "net" "net/netip" "testing" "golang.org/x/time/rate" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" - "github.com/hashicorp/go-hclog" + "github.com/stretchr/testify/mock" "github.com/hashicorp/consul/agent/consul/multilimiter" - "github.com/hashicorp/consul/agent/metrics" ) // @@ -226,10 +224,10 @@ func TestHandler(t *testing.T) { for desc, tc := range testCases { t.Run(desc, func(t *testing.T) { sink := metrics.TestSetupMetrics(t, "") - limiter := newMockLimiter(t) + limiter := multilimiter.NewMockRateLimiter(t) limiter.On("UpdateConfig", mock.Anything, mock.Anything).Return() for _, c := range tc.checks { - limiter.On("Allow", c.limit).Return(c.allow) + limiter.On("Allow", mock.Anything).Return(c.allow) } leaderStatusProvider := NewMockLeaderStatusProvider(t) @@ -376,7 +374,7 @@ func TestAllow(t *testing.T) { type testCase struct { description string cfg *HandlerConfig - expectedAllowCalls int + expectedAllowCalls bool } testCases := []testCase{ { @@ -390,7 +388,7 @@ func TestAllow(t *testing.T) { }, }, }, - expectedAllowCalls: 0, + expectedAllowCalls: false, }, { description: "RateLimiter gets called when mode is permissive.", @@ -403,7 +401,7 @@ func TestAllow(t *testing.T) { }, }, }, - expectedAllowCalls: 1, + expectedAllowCalls: true, }, { description: "RateLimiter gets called when mode is enforcing.", @@ -416,14 +414,14 @@ func TestAllow(t *testing.T) { }, }, }, - expectedAllowCalls: 1, + expectedAllowCalls: true, }, } for _, tc := range testCases { t.Run(tc.description, func(t *testing.T) { mockRateLimiter := multilimiter.NewMockRateLimiter(t) - if tc.expectedAllowCalls > 0 { + if tc.expectedAllowCalls { mockRateLimiter.On("Allow", mock.Anything).Return(func(entity multilimiter.LimitedEntity) bool { return true }) } mockRateLimiter.On("UpdateConfig", mock.Anything, mock.Anything).Return() @@ -435,31 +433,7 @@ func TestAllow(t *testing.T) { addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:1234")) mockRateLimiter.Calls = nil handler.Allow(Operation{Name: "test", SourceAddr: addr}) - mockRateLimiter.AssertNumberOfCalls(t, "Allow", tc.expectedAllowCalls) + mockRateLimiter.AssertExpectations(t) }) } } - -var _ multilimiter.RateLimiter = (*mockLimiter)(nil) - -func newMockLimiter(t *testing.T) *mockLimiter { - l := &mockLimiter{} - l.Mock.Test(t) - - t.Cleanup(func() { l.AssertExpectations(t) }) - - return l -} - -type mockLimiter struct { - mock.Mock -} - -func (m *mockLimiter) Allow(v multilimiter.LimitedEntity) bool { return m.Called(v).Bool(0) } -func (m *mockLimiter) Run(ctx context.Context) { m.Called(ctx) } -func (m *mockLimiter) UpdateConfig(cfg multilimiter.LimiterConfig, prefix []byte) { - m.Called(cfg, prefix) -} -func (m *mockLimiter) DeleteConfig(prefix []byte) { - m.Called(prefix) -}