From 13835ed311d8712f247b2cd23637102927460a4a Mon Sep 17 00:00:00 2001 From: "Derrick J. Wippler" Date: Thu, 9 May 2024 17:01:50 -0500 Subject: [PATCH 1/6] Otter is now available as a cache option --- benchmark_cache_test.go | 117 +++-- cache_manager.go | 153 ++++++ workers_test.go => cache_manager_test.go | 30 +- cmd/gubernator/main.go | 2 +- cmd/gubernator/main_test.go | 2 +- config.go | 6 +- daemon.go | 15 +- functional_test.go | 2 + global.go | 2 +- go.mod | 4 +- go.sum | 8 +- gubernator.go | 53 +- lrucache.go | 28 +- lrucache_test.go | 6 +- otter.go | 111 ++++ otter_test.go | 405 +++++++++++++++ peer_client.go | 7 +- tls.go | 7 +- workers.go | 626 ----------------------- workers_internal_test.go | 84 --- 20 files changed, 834 insertions(+), 834 deletions(-) create mode 100644 cache_manager.go rename workers_test.go => cache_manager_test.go (83%) create mode 100644 otter.go create mode 100644 otter_test.go delete mode 100644 workers.go delete mode 100644 workers_internal_test.go diff --git a/benchmark_cache_test.go b/benchmark_cache_test.go index e19ee7e..98c2b68 100644 --- a/benchmark_cache_test.go +++ b/benchmark_cache_test.go @@ -1,66 +1,80 @@ package gubernator_test import ( - "strconv" + "math/rand" "sync" "testing" "time" "github.com/gubernator-io/gubernator/v2" "github.com/mailgun/holster/v4/clock" + "github.com/stretchr/testify/require" ) func BenchmarkCache(b *testing.B) { testCases := []struct { Name string - NewTestCache func() gubernator.Cache + NewTestCache func() (gubernator.Cache, error) LockRequired bool }{ { Name: "LRUCache", - NewTestCache: func() gubernator.Cache { - return gubernator.NewLRUCache(0) + NewTestCache: func() (gubernator.Cache, error) { + return gubernator.NewLRUCache(0), nil }, LockRequired: true, }, + { + Name: "OtterCache", + NewTestCache: func() (gubernator.Cache, error) { + return gubernator.NewOtterCache(0) + }, + LockRequired: false, + }, } for _, testCase := range testCases { b.Run(testCase.Name, func(b *testing.B) { b.Run("Sequential reads", func(b *testing.B) { - cache := testCase.NewTestCache() + cache, err := testCase.NewTestCache() + require.NoError(b, err) expire := clock.Now().Add(time.Hour).UnixMilli() + keys := GenerateRandomKeys() - for i := 0; i < b.N; i++ { - key := strconv.Itoa(i) + for _, key := range keys { item := &gubernator.CacheItem{ Key: key, - Value: i, + Value: "value:" + key, ExpireAt: expire, } cache.Add(item) } + mask := len(keys) - 1 b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - key := strconv.Itoa(i) - _, _ = cache.GetItem(key) + index := int(rand.Uint32() & uint32(mask)) + _, _ = cache.GetItem(keys[index&mask]) } }) b.Run("Sequential writes", func(b *testing.B) { - cache := testCase.NewTestCache() + cache, err := testCase.NewTestCache() + require.NoError(b, err) expire := clock.Now().Add(time.Hour).UnixMilli() + keys := GenerateRandomKeys() + mask := len(keys) - 1 b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { + index := int(rand.Uint32() & uint32(mask)) item := &gubernator.CacheItem{ - Key: strconv.Itoa(i), - Value: i, + Key: keys[index&mask], + Value: "value:" + keys[index&mask], ExpireAt: expire, } cache.Add(item) @@ -68,93 +82,102 @@ func BenchmarkCache(b *testing.B) { }) b.Run("Concurrent reads", func(b *testing.B) { - cache := testCase.NewTestCache() + cache, err := testCase.NewTestCache() + require.NoError(b, err) expire := clock.Now().Add(time.Hour).UnixMilli() + keys := GenerateRandomKeys() - for i := 0; i < b.N; i++ { - key := strconv.Itoa(i) + for _, key := range keys { item := 
&gubernator.CacheItem{ Key: key, - Value: i, + Value: "value:" + key, ExpireAt: expire, } cache.Add(item) } - var wg sync.WaitGroup var mutex sync.Mutex - var task func(i int) + var task func(key string) if testCase.LockRequired { - task = func(i int) { + task = func(key string) { mutex.Lock() defer mutex.Unlock() - key := strconv.Itoa(i) _, _ = cache.GetItem(key) - wg.Done() } } else { - task = func(i int) { - key := strconv.Itoa(i) + task = func(key string) { _, _ = cache.GetItem(key) - wg.Done() } } b.ReportAllocs() b.ResetTimer() - for i := 0; i < b.N; i++ { - wg.Add(1) - go task(i) - } + mask := len(keys) - 1 + + b.RunParallel(func(pb *testing.PB) { + index := int(rand.Uint32() & uint32(mask)) + for pb.Next() { + task(keys[index&mask]) + } + }) - wg.Wait() }) b.Run("Concurrent writes", func(b *testing.B) { - cache := testCase.NewTestCache() + cache, err := testCase.NewTestCache() + require.NoError(b, err) expire := clock.Now().Add(time.Hour).UnixMilli() + keys := GenerateRandomKeys() - var wg sync.WaitGroup var mutex sync.Mutex - var task func(i int) + var task func(key string) if testCase.LockRequired { - task = func(i int) { + task = func(key string) { mutex.Lock() defer mutex.Unlock() item := &gubernator.CacheItem{ - Key: strconv.Itoa(i), - Value: i, + Key: key, + Value: "value:" + key, ExpireAt: expire, } cache.Add(item) - wg.Done() } } else { - task = func(i int) { + task = func(key string) { item := &gubernator.CacheItem{ - Key: strconv.Itoa(i), - Value: i, + Key: key, + Value: "value:" + key, ExpireAt: expire, } cache.Add(item) - wg.Done() } } + mask := len(keys) - 1 b.ReportAllocs() b.ResetTimer() - for i := 0; i < b.N; i++ { - wg.Add(1) - go task(i) - } - - wg.Wait() + b.RunParallel(func(pb *testing.PB) { + index := int(rand.Uint32() & uint32(mask)) + for pb.Next() { + task(keys[index&mask]) + } + }) }) }) } } + +const cacheSize = 32768 + +func GenerateRandomKeys() []string { + keys := make([]string, 0, cacheSize) + for i := 0; i < cacheSize; i++ { + keys = append(keys, gubernator.RandomString(20)) + } + return keys +} diff --git a/cache_manager.go b/cache_manager.go new file mode 100644 index 0000000..6542bdf --- /dev/null +++ b/cache_manager.go @@ -0,0 +1,153 @@ +/* +Copyright 2024 Derrick J. Wippler + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/

+
+package gubernator
+
+import (
+	"context"
+	"sync"
+
+	"github.com/pkg/errors"
+)
+
+type CacheManager interface {
+	GetRateLimit(context.Context, *RateLimitReq, RateLimitReqState) (*RateLimitResp, error)
+	GetCacheItem(context.Context, string) (*CacheItem, bool, error)
+	AddCacheItem(context.Context, string, *CacheItem) error
+	Store(ctx context.Context) error
+	Load(context.Context) error
+	Close() error
+}
+
+type cacheManager struct {
+	conf  Config
+	cache Cache
+}
+
+// NewCacheManager creates a new instance of the CacheManager interface using
+// the cache returned by Config.CacheFactory
+func NewCacheManager(conf Config) (CacheManager, error) {
+
+	cache, err := conf.CacheFactory(conf.CacheSize)
+	if err != nil {
+		return nil, err
+	}
+	return &cacheManager{
+		cache: cache,
+		conf:  conf,
+	}, nil
+}
+
+// GetRateLimit fetches the item from the cache if it exists, and performs the appropriate rate limit calculation
+func (m *cacheManager) GetRateLimit(ctx context.Context, req *RateLimitReq, state RateLimitReqState) (*RateLimitResp, error) {
+	var rlResponse *RateLimitResp
+	var err error
+
+	switch req.Algorithm {
+	case Algorithm_TOKEN_BUCKET:
+		rlResponse, err = tokenBucket(ctx, m.conf.Store, m.cache, req, state)
+		if err != nil {
+			msg := "Error in tokenBucket"
+			countError(err, msg)
+		}
+
+	case Algorithm_LEAKY_BUCKET:
+		rlResponse, err = leakyBucket(ctx, m.conf.Store, m.cache, req, state)
+		if err != nil {
+			msg := "Error in leakyBucket"
+			countError(err, msg)
+		}
+
+	default:
+		err = errors.Errorf("Invalid rate limit algorithm '%d'", req.Algorithm)
+	}
+
+	return rlResponse, err
+}
+
+// Store saves every cache item into persistent storage provided via Config.Loader
+func (m *cacheManager) Store(ctx context.Context) error {
+	out := make(chan *CacheItem, 500)
+	var wg sync.WaitGroup
+
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		for item := range m.cache.Each() {
+			select {
+			case out <- item:
+
+			case <-ctx.Done():
+				return
+			}
+		}
+	}()
+
+	go func() {
+		wg.Wait()
+		close(out)
+	}()
+
+	if ctx.Err() != nil {
+		return ctx.Err()
+	}
+
+	if err := m.conf.Loader.Save(out); err != nil {
+		return errors.Wrap(err, "while calling m.conf.Loader.Save()")
+	}
+	return nil
+}
+
+// Close closes the cache manager
+func (m *cacheManager) Close() error {
+	return m.cache.Close()
+}
+
+// Load cache items from persistent storage provided via Config.Loader
+func (m *cacheManager) Load(ctx context.Context) error {
+	ch, err := m.conf.Loader.Load()
+	if err != nil {
+		return errors.Wrap(err, "Error in loader.Load")
+	}
+
+	for {
+		var item *CacheItem
+		var ok bool
+
+		select {
+		case item, ok = <-ch:
+			if !ok {
+				return nil
+			}
+		case <-ctx.Done():
+			return ctx.Err()
+		}
+		_ = m.cache.Add(item)
+	}
+}
+
+// GetCacheItem returns an item from the cache
+func (m *cacheManager) GetCacheItem(_ context.Context, key string) (*CacheItem, bool, error) {
+	item, ok := m.cache.GetItem(key)
+	return item, ok, nil
+}
+
+// AddCacheItem adds an item to the cache. The CacheItem.Key should be set correctly, else the item
+// will not be stored under the correct key.
+func (m *cacheManager) AddCacheItem(_ context.Context, _ string, item *CacheItem) error { + _ = m.cache.Add(item) + return nil +} diff --git a/workers_test.go b/cache_manager_test.go similarity index 83% rename from workers_test.go rename to cache_manager_test.go index 4e77960..81e3d3c 100644 --- a/workers_test.go +++ b/cache_manager_test.go @@ -28,7 +28,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestGubernatorPool(t *testing.T) { +func TestCacheManager(t *testing.T) { ctx := context.Background() testCases := []struct { @@ -43,7 +43,7 @@ func TestGubernatorPool(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { // Setup mock data. const NumCacheItems = 100 - cacheItems := []*guber.CacheItem{} + var cacheItems []*guber.CacheItem for i := 0; i < NumCacheItems; i++ { cacheItems = append(cacheItems, &guber.CacheItem{ Key: fmt.Sprintf("Foobar%04d", i), @@ -55,15 +55,16 @@ func TestGubernatorPool(t *testing.T) { t.Run("Load()", func(t *testing.T) { mockLoader := &MockLoader2{} mockCache := &MockCache{} - conf := &guber.Config{ - CacheFactory: func(maxSize int) guber.Cache { - return mockCache + conf := guber.Config{ + CacheFactory: func(maxSize int) (guber.Cache, error) { + return mockCache, nil }, Loader: mockLoader, Workers: testCase.workers, } assert.NoError(t, conf.SetDefaults()) - chp := guber.NewWorkerPool(conf) + manager, err := guber.NewCacheManager(conf) + require.NoError(t, err) // Mock Loader. fakeLoadCh := make(chan *guber.CacheItem, NumCacheItems) @@ -79,31 +80,32 @@ func TestGubernatorPool(t *testing.T) { } // Call code. - err := chp.Load(ctx) + err = manager.Load(ctx) // Verify. - require.NoError(t, err, "Error in chp.Load") + require.NoError(t, err, "Error in manager.Load") }) t.Run("Store()", func(t *testing.T) { mockLoader := &MockLoader2{} mockCache := &MockCache{} - conf := &guber.Config{ - CacheFactory: func(maxSize int) guber.Cache { - return mockCache + conf := guber.Config{ + CacheFactory: func(maxSize int) (guber.Cache, error) { + return mockCache, nil }, Loader: mockLoader, Workers: testCase.workers, } require.NoError(t, conf.SetDefaults()) - chp := guber.NewWorkerPool(conf) + chp, err := guber.NewCacheManager(conf) + require.NoError(t, err) // Mock Loader. mockLoader.On("Save", mock.Anything).Once().Return(nil). Run(func(args mock.Arguments) { // Verify items sent over the channel passed to Save(). saveCh := args.Get(0).(chan *guber.CacheItem) - savedItems := []*guber.CacheItem{} + var savedItems []*guber.CacheItem for item := range saveCh { savedItems = append(savedItems, item) } @@ -124,7 +126,7 @@ func TestGubernatorPool(t *testing.T) { mockCache.On("Each").Times(testCase.workers).Return(eachCh) // Call code. - err := chp.Store(ctx) + err = chp.Store(ctx) // Verify. 
require.NoError(t, err, "Error in chp.Store") diff --git a/cmd/gubernator/main.go b/cmd/gubernator/main.go index 8b54023..3b20856 100644 --- a/cmd/gubernator/main.go +++ b/cmd/gubernator/main.go @@ -31,7 +31,7 @@ import ( "github.com/mailgun/holster/v4/tracing" "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/sdk/resource" - semconv "go.opentelemetry.io/otel/semconv/v1.21.0" + semconv "go.opentelemetry.io/otel/semconv/v1.24.0" "k8s.io/klog/v2" ) diff --git a/cmd/gubernator/main_test.go b/cmd/gubernator/main_test.go index 8c4e10e..4f1364e 100644 --- a/cmd/gubernator/main_test.go +++ b/cmd/gubernator/main_test.go @@ -78,7 +78,7 @@ func TestCLI(t *testing.T) { time.Sleep(time.Second * 1) err = c.Process.Signal(syscall.SIGTERM) - require.NoError(t, err) + require.NoError(t, err, out.String()) <-waitCh assert.Contains(t, out.String(), tt.contains) diff --git a/config.go b/config.go index 7a80724..c0918e4 100644 --- a/config.go +++ b/config.go @@ -80,7 +80,7 @@ type Config struct { Behaviors BehaviorConfig // (Optional) The cache implementation - CacheFactory func(maxSize int) Cache + CacheFactory func(maxSize int) (Cache, error) // (Optional) A persistent store implementation. Allows the implementor the ability to store the rate limits this // instance of gubernator owns. It's up to the implementor to decide what rate limits to persist. @@ -141,8 +141,8 @@ func (c *Config) SetDefaults() error { setter.SetDefault(&c.Logger, logrus.New().WithField("category", "gubernator")) if c.CacheFactory == nil { - c.CacheFactory = func(maxSize int) Cache { - return NewLRUCache(maxSize) + c.CacheFactory = func(maxSize int) (Cache, error) { + return NewOtterCache(maxSize) } } diff --git a/daemon.go b/daemon.go index c45b5d6..a2701f7 100644 --- a/daemon.go +++ b/daemon.go @@ -90,16 +90,21 @@ func (s *Daemon) Start(ctx context.Context) error { s.promRegister = prometheus.NewRegistry() - // The LRU cache for storing rate limits. - cacheCollector := NewLRUCacheCollector() + // The cache for storing rate limits. + cacheCollector := NewCacheCollector() if err := s.promRegister.Register(cacheCollector); err != nil { return errors.Wrap(err, "during call to promRegister.Register()") } - cacheFactory := func(maxSize int) Cache { - cache := NewLRUCache(maxSize) + cacheFactory := func(maxSize int) (Cache, error) { + //cache := NewLRUCache(maxSize) + // TODO: Enable Otter as default or provide a config option + cache, err := NewOtterCache(maxSize) + if err != nil { + return nil, err + } cacheCollector.AddCache(cache) - return cache + return cache, nil } // Handler to collect duration and API access metrics for GRPC diff --git a/functional_test.go b/functional_test.go index 8ea08fe..d93d352 100644 --- a/functional_test.go +++ b/functional_test.go @@ -2294,6 +2294,8 @@ func getPeerCounters(t *testing.T, peers []*guber.Daemon, name string) map[strin } func sendHit(t *testing.T, d *guber.Daemon, req *guber.RateLimitReq, expectStatus guber.Status, expectRemaining int64) { + t.Helper() + if req.Hits != 0 { t.Logf("Sending %d hits to peer %s", req.Hits, d.InstanceID) } diff --git a/global.go b/global.go index c5fe167..effbf65 100644 --- a/global.go +++ b/global.go @@ -242,7 +242,7 @@ func (gm *globalManager) broadcastPeers(ctx context.Context, updates map[string] // Get current rate limit state. 
grlReq := proto.Clone(update).(*RateLimitReq) grlReq.Hits = 0 - status, err := gm.instance.workerPool.GetRateLimit(ctx, grlReq, reqState) + status, err := gm.instance.cache.GetRateLimit(ctx, grlReq, reqState) if err != nil { gm.log.WithError(err).Error("while retrieving rate limit status") continue diff --git a/go.mod b/go.mod index 482dfc0..8f2518e 100644 --- a/go.mod +++ b/go.mod @@ -5,12 +5,12 @@ go 1.21 toolchain go1.21.9 require ( - github.com/OneOfOne/xxhash v1.2.8 github.com/davecgh/go-spew v1.1.1 github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.0 github.com/hashicorp/memberlist v0.5.0 github.com/mailgun/errors v0.1.5 github.com/mailgun/holster/v4 v4.19.0 + github.com/maypok86/otter v1.2.1 github.com/miekg/dns v1.1.50 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.13.0 @@ -45,7 +45,9 @@ require ( github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/coreos/go-semver v0.3.0 // indirect github.com/coreos/go-systemd/v22 v22.3.2 // indirect + github.com/dolthub/maphash v0.1.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/gammazero/deque v0.2.1 // indirect github.com/go-logr/logr v1.4.1 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect diff --git a/go.sum b/go.sum index 89711f6..461b937 100644 --- a/go.sum +++ b/go.sum @@ -52,8 +52,6 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= github.com/NYTimes/gziphandler v0.0.0-20170623195520-56545f4a5d46/go.mod h1:3wb06e3pkSAbeQ52E9H9iFoQsEEwGN64994WTCIhntQ= -github.com/OneOfOne/xxhash v1.2.8 h1:31czK/TI9sNkxIKfaUfGlU47BAxQ0ztGgd9vPyqimf8= -github.com/OneOfOne/xxhash v1.2.8/go.mod h1:eZbhyaAYD41SGSSsnmcpxVoRiQ/MPUTjUdIIOT9Um7Q= github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= github.com/Shopify/toxiproxy v2.1.4+incompatible h1:TKdv8HiTLgE5wdJuEML90aBgNWsokNbMijUGhmcoBJc= @@ -106,6 +104,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE= +github.com/dolthub/maphash v0.1.0 h1:bsQ7JsF4FkkWyrP3oCnFJgrCUAFbFf3kOl4L/QxPDyQ= +github.com/dolthub/maphash v0.1.0/go.mod h1:gkg4Ch4CdCDu5h6PMriVLawB7koZ+5ijb9puGMV50a4= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/elazarl/goproxy v0.0.0-20180725130230-947c36da3153/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc= github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs= @@ -128,6 +128,8 @@ github.com/form3tech-oss/jwt-go v3.2.3+incompatible/go.mod h1:pbq4aXjuKjdthFRnoD github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/gammazero/deque 
v0.2.1 h1:qSdsbG6pgp6nL7A0+K/B7s12mcCY/5l5SIUpMOl+dC0= +github.com/gammazero/deque v0.2.1/go.mod h1:LFroj8x4cMYCukHJDbxFCkT+r9AndaJnFMuZDV34tuU= github.com/getkin/kin-openapi v0.76.0/go.mod h1:660oXbgy5JFMKreazJaQTw7o+X00qeSyhcnluiMv+Xg= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= @@ -306,6 +308,8 @@ github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaO github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/maypok86/otter v1.2.1 h1:xyvMW+t0vE1sKt/++GTkznLitEl7D/msqXkAbLwiC1M= +github.com/maypok86/otter v1.2.1/go.mod h1:mKLfoI7v1HOmQMwFgX4QkRk23mX6ge3RDvjdHOWG4R4= github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso= github.com/miekg/dns v1.1.50 h1:DQUfb9uc6smULcREF09Uc+/Gd46YWqJd5DbpPE9xkcA= github.com/miekg/dns v1.1.50/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME= diff --git a/gubernator.go b/gubernator.go index ff6812a..1ae40d4 100644 --- a/gubernator.go +++ b/gubernator.go @@ -45,12 +45,12 @@ const ( type V1Instance struct { UnimplementedV1Server UnimplementedPeersV1Server - global *globalManager - peerMutex sync.RWMutex - log FieldLogger - conf Config - isClosed bool - workerPool *WorkerPool + global *globalManager + cache CacheManager + peerMutex sync.RWMutex + log FieldLogger + conf Config + isClosed bool } type RateLimitReqState struct { @@ -83,14 +83,6 @@ var ( Name: "gubernator_check_error_counter", Help: "The number of errors while checking rate limits.", }, []string{"error"}) - metricCommandCounter = prometheus.NewCounterVec(prometheus.CounterOpts{ - Name: "gubernator_command_counter", - Help: "The count of commands processed by each worker in WorkerPool.", - }, []string{"worker", "method"}) - metricWorkerQueue = prometheus.NewGaugeVec(prometheus.GaugeOpts{ - Name: "gubernator_worker_queue_length", - Help: "The count of requests queued up in WorkerPool.", - }, []string{"method", "worker"}) // Batch behavior. metricBatchSendRetries = prometheus.NewCounterVec(prometheus.CounterOpts{ @@ -126,7 +118,10 @@ func NewV1Instance(conf Config) (s *V1Instance, err error) { conf: conf, } - s.workerPool = NewWorkerPool(&conf) + s.cache, err = NewCacheManager(conf) + if err != nil { + return nil, fmt.Errorf("during NewCacheManager(): %w", err) + } s.global = newGlobalManager(conf.Behaviors, s) // Register our instance with all GRPC servers @@ -140,9 +135,9 @@ func NewV1Instance(conf Config) (s *V1Instance, err error) { } // Load the cache. - err = s.workerPool.Load(ctx) + err = s.cache.Load(ctx) if err != nil { - return nil, errors.Wrap(err, "Error in workerPool.Load") + return nil, errors.Wrap(err, "Error in CacheManager.Load") } return s, nil @@ -158,19 +153,19 @@ func (s *V1Instance) Close() (err error) { s.global.Close() if s.conf.Loader != nil { - err = s.workerPool.Store(ctx) + err = s.cache.Store(ctx) if err != nil { s.log.WithError(err). - Error("Error in workerPool.Store") - return errors.Wrap(err, "Error in workerPool.Store") + Error("Error in CacheManager.Store") + return errors.Wrap(err, "Error in CacheManager.Store") } } - err = s.workerPool.Close() + err = s.cache.Close() if err != nil { s.log.WithError(err). 
-			Error("Error in workerPool.Close")
-		return errors.Wrap(err, "Error in workerPool.Close")
+			Error("Error in CacheManager.Close")
+		return errors.Wrap(err, "Error in CacheManager.Close")
 	}
 
 	s.isClosed = true
@@ -449,9 +444,9 @@ func (s *V1Instance) UpdatePeerGlobals(ctx context.Context, r *UpdatePeerGlobals
 				CreatedAt: now,
 			}
 		}
-		err := s.workerPool.AddCacheItem(ctx, g.Key, item)
+		err := s.cache.AddCacheItem(ctx, g.Key, item)
 		if err != nil {
-			return nil, errors.Wrap(err, "Error in workerPool.AddCacheItem")
+			return nil, errors.Wrap(err, "Error in CacheManager.AddCacheItem")
 		}
 	}
 
@@ -595,9 +590,9 @@ func (s *V1Instance) getLocalRateLimit(ctx context.Context, r *RateLimitReq, req
 	defer func() { tracing.EndScope(ctx, err) }()
 	defer prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.getLocalRateLimit")).ObserveDuration()
 
-	resp, err := s.workerPool.GetRateLimit(ctx, r, reqState)
+	resp, err := s.cache.GetRateLimit(ctx, r, reqState)
 	if err != nil {
-		return nil, errors.Wrap(err, "during workerPool.GetRateLimit")
+		return nil, errors.Wrap(err, "during CacheManager.GetRateLimit")
 	}
 
 	// If global behavior, then broadcast update to all peers.
@@ -742,12 +737,10 @@ func (s *V1Instance) Describe(ch chan<- *prometheus.Desc) {
 	metricBatchSendDuration.Describe(ch)
 	metricBatchSendRetries.Describe(ch)
 	metricCheckErrorCounter.Describe(ch)
-	metricCommandCounter.Describe(ch)
 	metricConcurrentChecks.Describe(ch)
 	metricFuncTimeDuration.Describe(ch)
 	metricGetRateLimitCounter.Describe(ch)
 	metricOverLimitCounter.Describe(ch)
-	metricWorkerQueue.Describe(ch)
 	s.global.metricBroadcastDuration.Describe(ch)
 	s.global.metricGlobalQueueLength.Describe(ch)
 	s.global.metricGlobalSendDuration.Describe(ch)
@@ -760,12 +753,10 @@ func (s *V1Instance) Collect(ch chan<- prometheus.Metric) {
 	metricBatchSendDuration.Collect(ch)
 	metricBatchSendRetries.Collect(ch)
 	metricCheckErrorCounter.Collect(ch)
-	metricCommandCounter.Collect(ch)
 	metricConcurrentChecks.Collect(ch)
 	metricFuncTimeDuration.Collect(ch)
 	metricGetRateLimitCounter.Collect(ch)
 	metricOverLimitCounter.Collect(ch)
-	metricWorkerQueue.Collect(ch)
 	s.global.metricBroadcastDuration.Collect(ch)
 	s.global.metricGlobalQueueLength.Collect(ch)
 	s.global.metricGlobalSendDuration.Collect(ch)
diff --git a/lrucache.go b/lrucache.go
index 0386720..8a415c9 100644
--- a/lrucache.go
+++ b/lrucache.go
@@ -20,6 +20,7 @@ package gubernator
 
 import (
 	"container/list"
+	"sync"
 	"sync/atomic"
 
 	"github.com/mailgun/holster/v4/clock"
@@ -32,18 +33,19 @@ import (
 type LRUCache struct {
 	cache     map[string]*list.Element
 	ll        *list.List
+	mu        sync.Mutex
 	cacheSize int
 	cacheLen  int64
 }
 
-// LRUCacheCollector provides prometheus metrics collector for LRUCache.
+// CacheCollector provides a prometheus metrics collector for Cache implementations.
 // Register only one collector, add one or more caches to this collector.
-type LRUCacheCollector struct {
+type CacheCollector struct {
 	caches []Cache
 }
 
 var _ Cache = &LRUCache{}
-var _ prometheus.Collector = &LRUCacheCollector{}
+var _ prometheus.Collector = &CacheCollector{}
 
 var metricCacheSize = prometheus.NewGauge(prometheus.GaugeOpts{
 	Name: "gubernator_cache_size",
@@ -86,6 +88,9 @@ func (c *LRUCache) Each() chan *CacheItem {
 
 // Add adds a value to the cache.
func (c *LRUCache) Add(item *CacheItem) bool { + c.mu.Lock() + defer c.mu.Unlock() + // If the key already exist, set the new value if ee, ok := c.cache[item.Key]; ok { c.ll.MoveToFront(ee) @@ -109,6 +114,8 @@ func MillisecondNow() int64 { // GetItem returns the item stored in the cache func (c *LRUCache) GetItem(key string) (item *CacheItem, ok bool) { + c.mu.Lock() + defer c.mu.Unlock() if ele, hit := c.cache[key]; hit { entry := ele.Value.(*CacheItem) @@ -162,6 +169,9 @@ func (c *LRUCache) Size() int64 { // UpdateExpiration updates the expiration time for the key func (c *LRUCache) UpdateExpiration(key string, expireAt int64) bool { + c.mu.Lock() + defer c.mu.Unlock() + if ele, hit := c.cache[key]; hit { entry := ele.Value.(*CacheItem) entry.ExpireAt = expireAt @@ -177,33 +187,33 @@ func (c *LRUCache) Close() error { return nil } -func NewLRUCacheCollector() *LRUCacheCollector { - return &LRUCacheCollector{ +func NewCacheCollector() *CacheCollector { + return &CacheCollector{ caches: []Cache{}, } } // AddCache adds a Cache object to be tracked by the collector. -func (collector *LRUCacheCollector) AddCache(cache Cache) { +func (collector *CacheCollector) AddCache(cache Cache) { collector.caches = append(collector.caches, cache) } // Describe fetches prometheus metrics to be registered -func (collector *LRUCacheCollector) Describe(ch chan<- *prometheus.Desc) { +func (collector *CacheCollector) Describe(ch chan<- *prometheus.Desc) { metricCacheSize.Describe(ch) metricCacheAccess.Describe(ch) metricCacheUnexpiredEvictions.Describe(ch) } // Collect fetches metric counts and gauges from the cache -func (collector *LRUCacheCollector) Collect(ch chan<- prometheus.Metric) { +func (collector *CacheCollector) Collect(ch chan<- prometheus.Metric) { metricCacheSize.Set(collector.getSize()) metricCacheSize.Collect(ch) metricCacheAccess.Collect(ch) metricCacheUnexpiredEvictions.Collect(ch) } -func (collector *LRUCacheCollector) getSize() float64 { +func (collector *CacheCollector) getSize() float64 { var size float64 for _, cache := range collector.caches { diff --git a/lrucache_test.go b/lrucache_test.go index 51f33bc..402dd83 100644 --- a/lrucache_test.go +++ b/lrucache_test.go @@ -316,7 +316,7 @@ func TestLRUCache(t *testing.T) { } }() - collector := gubernator.NewLRUCacheCollector() + collector := gubernator.NewCacheCollector() collector.AddCache(cache) go func() { @@ -342,7 +342,7 @@ func TestLRUCache(t *testing.T) { promRegister := prometheus.NewRegistry() // The LRU cache for storing rate limits. - cacheCollector := gubernator.NewLRUCacheCollector() + cacheCollector := gubernator.NewCacheCollector() err := promRegister.Register(cacheCollector) require.NoError(t, err) @@ -389,7 +389,7 @@ func TestLRUCache(t *testing.T) { promRegister := prometheus.NewRegistry() // The LRU cache for storing rate limits. - cacheCollector := gubernator.NewLRUCacheCollector() + cacheCollector := gubernator.NewCacheCollector() err := promRegister.Register(cacheCollector) require.NoError(t, err) diff --git a/otter.go b/otter.go new file mode 100644 index 0000000..f04fc8e --- /dev/null +++ b/otter.go @@ -0,0 +1,111 @@ +package gubernator + +import ( + "fmt" + + "github.com/mailgun/holster/v4/setter" + "github.com/maypok86/otter" +) + +type OtterCache struct { + cache otter.Cache[string, *CacheItem] +} + +// NewOtterCache returns a new cache backed by otter. If size is 0, then +// the cache is created with a default cache size. 
+func NewOtterCache(size int) (*OtterCache, error) {
+	setter.SetDefault(&size, 150_000)
+	b, err := otter.NewBuilder[string, *CacheItem](size)
+	if err != nil {
+		return nil, fmt.Errorf("during otter.NewBuilder(): %w", err)
+	}
+
+	b.DeletionListener(func(key string, value *CacheItem, cause otter.DeletionCause) {
+		if cause == otter.Size {
+			metricCacheUnexpiredEvictions.Add(1)
+		}
+	})
+
+	cache, err := b.Build()
+	if err != nil {
+		return nil, fmt.Errorf("during otter.Builder.Build(): %w", err)
+	}
+	return &OtterCache{
+		cache: cache,
+	}, nil
+}
+
+// Add adds a new CacheItem to the cache. The key must be provided via CacheItem.Key.
+// Returns true if the item was added to the cache; false if the item was too large
+// for the cache.
+func (o *OtterCache) Add(item *CacheItem) bool {
+	return o.cache.Set(item.Key, item)
+}
+
+// GetItem returns an item in the cache that corresponds to the provided key
+func (o *OtterCache) GetItem(key string) (*CacheItem, bool) {
+	item, ok := o.cache.Get(key)
+	if !ok {
+		metricCacheAccess.WithLabelValues("miss").Add(1)
+		return nil, false
+	}
+
+	if item.IsExpired() {
+		metricCacheAccess.WithLabelValues("miss").Add(1)
+		// If the item is expired, just return `nil`.
+		//
+		// We avoid explicitly deleting the expired item because explicit deletions
+		// in otter must acquire a mutex, which can become a performance bottleneck
+		// under high concurrency. Instead, we let the item be evicted naturally by
+		// otter's eviction mechanism, avoiding that cost under high concurrency.
+		return nil, false
+	}
+	metricCacheAccess.WithLabelValues("hit").Add(1)
+	return item, true
+}
+
+// UpdateExpiration will update an item in the cache with a new expiration date.
+// Returns true if the item exists in the cache and was updated.
+func (o *OtterCache) UpdateExpiration(key string, expireAt int64) bool {
+	item, ok := o.cache.Get(key)
+	if !ok {
+		return false
+	}
+
+	item.ExpireAt = expireAt
+	return true
+}
+
+// Each returns a channel which the caller can use to iterate through
+// all the items in the cache.
+func (o *OtterCache) Each() chan *CacheItem {
+	ch := make(chan *CacheItem)
+
+	go func() {
+		o.cache.Range(func(_ string, v *CacheItem) bool {
+			ch <- v
+			return true
+		})
+		close(ch)
+	}()
+	return ch
+}
+
+// Remove explicitly removes an item from the cache.
+// NOTE: A deletion call to otter requires a mutex to perform;
+// if possible, avoid performing explicit removal from the cache.
+// Instead, prefer to let the item be evicted naturally.
+func (o *OtterCache) Remove(key string) {
+	o.cache.Delete(key)
+}
+
+// Size returns the current number of items in the cache
+func (o *OtterCache) Size() int64 {
+	return int64(o.cache.Size())
+}
+
+// Close closes the cache and all associated background processes
+func (o *OtterCache) Close() error {
+	o.cache.Close()
+	return nil
+}
diff --git a/otter_test.go b/otter_test.go
new file mode 100644
index 0000000..6eb629d
--- /dev/null
+++ b/otter_test.go
@@ -0,0 +1,405 @@
+package gubernator_test
+
+import (
+	"strconv"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/gubernator-io/gubernator/v2"
+	"github.com/mailgun/holster/v4/clock"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestOtterCache(t *testing.T) {
+	const iterations = 1000
+	const concurrency = 100
+	expireAt := clock.Now().Add(1 * time.Hour).UnixMilli()
+
+	t.Run("Happy path", func(t *testing.T) {
+		cache, err := gubernator.NewOtterCache(0)
+		require.NoError(t, err)
+
+		// Populate cache.
+		for i := 0; i < iterations; i++ {
+			key := strconv.Itoa(i)
+			item := &gubernator.CacheItem{
+				Key:      key,
+				Value:    i,
+				ExpireAt: expireAt,
+			}
+			cache.Add(item)
+		}
+
+		// Validate cache.
+		assert.Equal(t, int64(iterations), cache.Size())
+
+		for i := 0; i < iterations; i++ {
+			key := strconv.Itoa(i)
+			item, ok := cache.GetItem(key)
+			require.True(t, ok)
+			require.NotNil(t, item)
+			assert.Equal(t, item.Value, i)
+		}
+
+		// Clear cache.
+		for i := 0; i < iterations; i++ {
+			key := strconv.Itoa(i)
+			cache.Remove(key)
+		}
+
+		assert.Zero(t, cache.Size())
+	})
+
+	t.Run("Update an existing key", func(t *testing.T) {
+		cache, err := gubernator.NewOtterCache(0)
+		require.NoError(t, err)
+		const key = "foobar"
+
+		// Add key.
+		item1 := &gubernator.CacheItem{
+			Key:      key,
+			Value:    "initial value",
+			ExpireAt: expireAt,
+		}
+		cache.Add(item1)
+
+		// Update same key.
+		item2 := &gubernator.CacheItem{
+			Key:      key,
+			Value:    "new value",
+			ExpireAt: expireAt,
+		}
+		cache.Add(item2)
+
+		// Verify.
+		verifyItem, ok := cache.GetItem(key)
+		require.True(t, ok)
+		assert.Equal(t, item2, verifyItem)
+	})
+
+	t.Run("Concurrent reads", func(t *testing.T) {
+		cache, err := gubernator.NewOtterCache(0)
+		require.NoError(t, err)
+
+		// Populate cache.
+		for i := 0; i < iterations; i++ {
+			key := strconv.Itoa(i)
+			item := &gubernator.CacheItem{
+				Key:      key,
+				Value:    i,
+				ExpireAt: expireAt,
+			}
+			cache.Add(item)
+		}
+
+		assert.Equal(t, int64(iterations), cache.Size())
+		var launchWg, doneWg sync.WaitGroup
+		launchWg.Add(1)
+
+		for thread := 0; thread < concurrency; thread++ {
+			doneWg.Add(1)
+
+			go func() {
+				defer doneWg.Done()
+				launchWg.Wait()
+
+				for i := 0; i < iterations; i++ {
+					key := strconv.Itoa(i)
+					item, ok := cache.GetItem(key)
+					assert.True(t, ok)
+					require.NotNil(t, item)
+					assert.Equal(t, item.Value, i)
+				}
+			}()
+		}
+
+		// Wait for goroutines to finish.
+		launchWg.Done()
+		doneWg.Wait()
+	})
+
+	t.Run("Concurrent writes", func(t *testing.T) {
+		cache, err := gubernator.NewOtterCache(0)
+		require.NoError(t, err)
+		expireAt := clock.Now().Add(1 * time.Hour).UnixMilli()
+		var launchWg, doneWg sync.WaitGroup
+		launchWg.Add(1)
+
+		for thread := 0; thread < concurrency; thread++ {
+			doneWg.Add(1)
+
+			go func() {
+				defer doneWg.Done()
+				launchWg.Wait()
+
+				for i := 0; i < iterations; i++ {
+					key := strconv.Itoa(i)
+					item := &gubernator.CacheItem{
+						Key:      key,
+						Value:    i,
+						ExpireAt: expireAt,
+					}
+					cache.Add(item)
+				}
+			}()
+		}
+
+		// Wait for goroutines to finish.
+ launchWg.Done() + doneWg.Wait() + }) + + t.Run("Concurrent reads and writes", func(t *testing.T) { + cache, err := gubernator.NewOtterCache(0) + require.NoError(t, err) + + // Populate cache. + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.Add(item) + } + + assert.Equal(t, int64(iterations), cache.Size()) + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for thread := 0; thread < concurrency; thread++ { + doneWg.Add(2) + + go func() { + defer doneWg.Done() + launchWg.Wait() + + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item, ok := cache.GetItem(key) + assert.True(t, ok) + require.NotNil(t, item) + assert.Equal(t, item.Value, i) + } + }() + + go func() { + defer doneWg.Done() + launchWg.Wait() + + // Write different keys than the keys we are reading to avoid race on Add() / GetItem() + for i := iterations; i < iterations*2; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.Add(item) + } + }() + } + + // Wait for goroutines to finish. + launchWg.Done() + doneWg.Wait() + }) +} + +func BenchmarkOtterCache(b *testing.B) { + + b.Run("Sequential reads", func(b *testing.B) { + cache, _ := gubernator.NewOtterCache(0) + + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + + // Populate cache. + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.Add(item) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + _, _ = cache.GetItem(key) + } + }) + + b.Run("Sequential writes", func(b *testing.B) { + cache, _ := gubernator.NewOtterCache(0) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.Add(item) + } + }) + + b.Run("Concurrent reads", func(b *testing.B) { + cache, _ := gubernator.NewOtterCache(0) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + + // Populate cache. + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.Add(item) + } + + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + doneWg.Add(1) + + go func() { + defer doneWg.Done() + launchWg.Wait() + + _, _ = cache.GetItem(key) + }() + } + + b.ReportAllocs() + b.ResetTimer() + launchWg.Done() + doneWg.Wait() + }) + + b.Run("Concurrent writes", func(b *testing.B) { + cache, _ := gubernator.NewOtterCache(0) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + doneWg.Add(1) + + go func(i int) { + defer doneWg.Done() + launchWg.Wait() + + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.Add(item) + }(i) + } + + b.ReportAllocs() + b.ResetTimer() + launchWg.Done() + doneWg.Wait() + }) + + b.Run("Concurrent reads and writes of existing keys", func(b *testing.B) { + cache, _ := gubernator.NewOtterCache(0) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + // Populate cache. 
+ for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.Add(item) + } + + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + doneWg.Add(2) + + go func() { + defer doneWg.Done() + launchWg.Wait() + + _, _ = cache.GetItem(key) + }() + + go func(i int) { + defer doneWg.Done() + launchWg.Wait() + + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.Add(item) + }(i) + } + + b.ReportAllocs() + b.ResetTimer() + launchWg.Done() + doneWg.Wait() + }) + + b.Run("Concurrent reads and writes of non-existent keys", func(b *testing.B) { + cache, _ := gubernator.NewOtterCache(0) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for i := 0; i < b.N; i++ { + doneWg.Add(2) + + go func(i int) { + defer doneWg.Done() + launchWg.Wait() + + key := strconv.Itoa(i) + _, _ = cache.GetItem(key) + }(i) + + go func(i int) { + defer doneWg.Done() + launchWg.Wait() + + key := "z" + strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.Add(item) + }(i) + } + + b.ReportAllocs() + b.ResetTimer() + launchWg.Done() + doneWg.Wait() + }) +} diff --git a/peer_client.go b/peer_client.go index 03b29ff..413a5f2 100644 --- a/peer_client.go +++ b/peer_client.go @@ -66,10 +66,9 @@ type response struct { } type request struct { - request *RateLimitReq - reqState RateLimitReqState - resp chan *response - ctx context.Context + request *RateLimitReq + resp chan *response + ctx context.Context } type PeerConfig struct { diff --git a/tls.go b/tls.go index 46cd2d7..25f951e 100644 --- a/tls.go +++ b/tls.go @@ -259,17 +259,20 @@ func SetupTLS(conf *TLSConfig) error { // If user asked for client auth if conf.ClientAuth != tls.NoClientCert { clientPool := x509.NewCertPool() + var certProvided bool if conf.ClientAuthCaPEM != nil { // If client auth CA was provided clientPool.AppendCertsFromPEM(conf.ClientAuthCaPEM.Bytes()) + certProvided = true } else if conf.CaPEM != nil { // else use the servers CA clientPool.AppendCertsFromPEM(conf.CaPEM.Bytes()) + certProvided = true } - // error if neither was provided - if len(clientPool.Subjects()) == 0 { //nolint:all + // error if neither cert was provided + if !certProvided { return errors.New("client auth enabled, but no CA's provided") } diff --git a/workers.go b/workers.go deleted file mode 100644 index 34d99d1..0000000 --- a/workers.go +++ /dev/null @@ -1,626 +0,0 @@ -/* -Copyright 2018-2022 Mailgun Technologies Inc - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package gubernator - -// Thread-safe worker pool for handling concurrent Gubernator requests. -// Ensures requests are synchronized to avoid caching conflicts. -// Handle concurrent requests by sharding cache key space across multiple -// workers. -// Uses hash ring design pattern to distribute requests to an assigned worker. 
-// No mutex locking necessary because each worker has its own data space and -// processes requests sequentially. -// -// Request workflow: -// - A 63-bit hash is generated from an incoming request by its Key/Name -// values. (Actually 64 bit, but we toss out one bit to properly calculate -// the next step.) -// - Workers are assigned equal size hash ranges. The worker is selected by -// choosing the worker index associated with that linear hash value range. -// - The worker has command channels for each method call. The request is -// enqueued to the appropriate channel. -// - The worker pulls the request from the appropriate channel and executes the -// business logic for that method. Then, it sends a response back using the -// requester's provided response channel. - -import ( - "context" - "io" - "strconv" - "sync" - "sync/atomic" - - "github.com/OneOfOne/xxhash" - "github.com/mailgun/holster/v4/errors" - "github.com/mailgun/holster/v4/setter" - "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" - "go.opentelemetry.io/otel/trace" -) - -type WorkerPool struct { - hasher workerHasher - workers []*Worker - workerCacheSize int - hashRingStep uint64 - conf *Config - done chan struct{} -} - -type Worker struct { - name string - conf *Config - cache Cache - getRateLimitRequest chan request - storeRequest chan workerStoreRequest - loadRequest chan workerLoadRequest - addCacheItemRequest chan workerAddCacheItemRequest - getCacheItemRequest chan workerGetCacheItemRequest -} - -type workerHasher interface { - // ComputeHash63 returns a 63-bit hash derived from input. - ComputeHash63(input string) uint64 -} - -// hasher is the default implementation of workerHasher. -type hasher struct{} - -// Method request/response structs. -type workerStoreRequest struct { - ctx context.Context - response chan workerStoreResponse - out chan<- *CacheItem -} - -type workerStoreResponse struct{} - -type workerLoadRequest struct { - ctx context.Context - response chan workerLoadResponse - in <-chan *CacheItem -} - -type workerLoadResponse struct{} - -type workerAddCacheItemRequest struct { - ctx context.Context - response chan workerAddCacheItemResponse - item *CacheItem -} - -type workerAddCacheItemResponse struct { - exists bool -} - -type workerGetCacheItemRequest struct { - ctx context.Context - response chan workerGetCacheItemResponse - key string -} - -type workerGetCacheItemResponse struct { - item *CacheItem - ok bool -} - -var _ io.Closer = &WorkerPool{} -var _ workerHasher = &hasher{} - -var workerCounter int64 - -func NewWorkerPool(conf *Config) *WorkerPool { - setter.SetDefault(&conf.CacheSize, 50_000) - - // Compute hashRingStep as interval between workers' 63-bit hash ranges. - // 64th bit is used here as a max value that is just out of range of 63-bit space to calculate the step. - chp := &WorkerPool{ - workers: make([]*Worker, conf.Workers), - workerCacheSize: conf.CacheSize / conf.Workers, - hasher: newHasher(), - hashRingStep: uint64(1<<63) / uint64(conf.Workers), - conf: conf, - done: make(chan struct{}), - } - - // Create workers. 
- conf.Logger.Infof("Starting %d Gubernator workers...", conf.Workers) - for i := 0; i < conf.Workers; i++ { - chp.workers[i] = chp.newWorker() - go chp.dispatch(chp.workers[i]) - } - - return chp -} - -func newHasher() *hasher { - return &hasher{} -} - -func (ph *hasher) ComputeHash63(input string) uint64 { - return xxhash.ChecksumString64S(input, 0) >> 1 -} - -func (p *WorkerPool) Close() error { - close(p.done) - return nil -} - -// Create a new pool worker instance. -func (p *WorkerPool) newWorker() *Worker { - worker := &Worker{ - conf: p.conf, - cache: p.conf.CacheFactory(p.workerCacheSize), - getRateLimitRequest: make(chan request), - storeRequest: make(chan workerStoreRequest), - loadRequest: make(chan workerLoadRequest), - addCacheItemRequest: make(chan workerAddCacheItemRequest), - getCacheItemRequest: make(chan workerGetCacheItemRequest), - } - workerNumber := atomic.AddInt64(&workerCounter, 1) - 1 - worker.name = strconv.FormatInt(workerNumber, 10) - return worker -} - -// getWorker Returns the request channel associated with the key. -// Hash the key, then lookup hash ring to find the worker. -func (p *WorkerPool) getWorker(key string) *Worker { - hash := p.hasher.ComputeHash63(key) - idx := hash / p.hashRingStep - return p.workers[idx] -} - -// Pool worker for processing Gubernator requests. -// Each worker maintains its own state. -// A hash ring will distribute requests to an assigned worker by key. -// See: getWorker() -func (p *WorkerPool) dispatch(worker *Worker) { - for { - // Dispatch requests from each channel. - select { - case req, ok := <-worker.getRateLimitRequest: - if !ok { - // Channel closed. Unexpected, but should be handled. - logrus.Error("workerPool worker stopped because channel closed") - return - } - - resp := new(response) - resp.rl, resp.err = worker.handleGetRateLimit(req.ctx, req.request, req.reqState, worker.cache) - select { - case req.resp <- resp: - // Success. - - case <-req.ctx.Done(): - // Context canceled. - trace.SpanFromContext(req.ctx).RecordError(resp.err) - } - metricCommandCounter.WithLabelValues(worker.name, "GetRateLimit").Inc() - - case req, ok := <-worker.storeRequest: - if !ok { - // Channel closed. Unexpected, but should be handled. - logrus.Error("workerPool worker stopped because channel closed") - return - } - - worker.handleStore(req, worker.cache) - metricCommandCounter.WithLabelValues(worker.name, "Store").Inc() - - case req, ok := <-worker.loadRequest: - if !ok { - // Channel closed. Unexpected, but should be handled. - logrus.Error("workerPool worker stopped because channel closed") - return - } - - worker.handleLoad(req, worker.cache) - metricCommandCounter.WithLabelValues(worker.name, "Load").Inc() - - case req, ok := <-worker.addCacheItemRequest: - if !ok { - // Channel closed. Unexpected, but should be handled. - logrus.Error("workerPool worker stopped because channel closed") - return - } - - worker.handleAddCacheItem(req, worker.cache) - metricCommandCounter.WithLabelValues(worker.name, "AddCacheItem").Inc() - - case req, ok := <-worker.getCacheItemRequest: - if !ok { - // Channel closed. Unexpected, but should be handled. - logrus.Error("workerPool worker stopped because channel closed") - return - } - - worker.handleGetCacheItem(req, worker.cache) - metricCommandCounter.WithLabelValues(worker.name, "GetCacheItem").Inc() - - case <-p.done: - // Clean up. - return - } - } -} - -// GetRateLimit sends a GetRateLimit request to worker pool. 
-func (p *WorkerPool) GetRateLimit(ctx context.Context, rlRequest *RateLimitReq, reqState RateLimitReqState) (*RateLimitResp, error) { - // Delegate request to assigned channel based on request key. - worker := p.getWorker(rlRequest.HashKey()) - queueGauge := metricWorkerQueue.WithLabelValues("GetRateLimit", worker.name) - queueGauge.Inc() - defer queueGauge.Dec() - handlerRequest := request{ - ctx: ctx, - resp: make(chan *response, 1), - request: rlRequest, - reqState: reqState, - } - - // Send request. - select { - case worker.getRateLimitRequest <- handlerRequest: - // Successfully sent request. - case <-ctx.Done(): - return nil, ctx.Err() - } - - // Wait for response. - select { - case handlerResponse := <-handlerRequest.resp: - // Successfully read response. - return handlerResponse.rl, handlerResponse.err - case <-ctx.Done(): - return nil, ctx.Err() - } -} - -// Handle request received by worker. -func (worker *Worker) handleGetRateLimit(ctx context.Context, req *RateLimitReq, reqState RateLimitReqState, cache Cache) (*RateLimitResp, error) { - defer prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("Worker.handleGetRateLimit")).ObserveDuration() - var rlResponse *RateLimitResp - var err error - - switch req.Algorithm { - case Algorithm_TOKEN_BUCKET: - rlResponse, err = tokenBucket(ctx, worker.conf.Store, cache, req, reqState) - if err != nil { - msg := "Error in tokenBucket" - countError(err, msg) - err = errors.Wrap(err, msg) - trace.SpanFromContext(ctx).RecordError(err) - } - - case Algorithm_LEAKY_BUCKET: - rlResponse, err = leakyBucket(ctx, worker.conf.Store, cache, req, reqState) - if err != nil { - msg := "Error in leakyBucket" - countError(err, msg) - err = errors.Wrap(err, msg) - trace.SpanFromContext(ctx).RecordError(err) - } - - default: - err = errors.Errorf("Invalid rate limit algorithm '%d'", req.Algorithm) - trace.SpanFromContext(ctx).RecordError(err) - metricCheckErrorCounter.WithLabelValues("Invalid algorithm").Add(1) - } - - return rlResponse, err -} - -// Load atomically loads cache from persistent storage. -// Read from persistent storage. Load into each appropriate worker's cache. -// Workers are locked during this load operation to prevent race conditions. -func (p *WorkerPool) Load(ctx context.Context) (err error) { - queueGauge := metricWorkerQueue.WithLabelValues("Load", "") - queueGauge.Inc() - defer queueGauge.Dec() - ch, err := p.conf.Loader.Load() - if err != nil { - return errors.Wrap(err, "Error in loader.Load") - } - - type loadChannel struct { - ch chan *CacheItem - worker *Worker - respChan chan workerLoadResponse - } - - // Map request channel hash to load channel. - loadChMap := map[*Worker]loadChannel{} - - // Send each item to the assigned channel's cache. -MAIN: - for { - var item *CacheItem - var ok bool - - select { - case item, ok = <-ch: - if !ok { - break MAIN - } - // Successfully received item. - - case <-ctx.Done(): - // Context canceled. - return ctx.Err() - } - - worker := p.getWorker(item.Key) - - // Initiate a load channel with each worker. - loadCh, exist := loadChMap[worker] - if !exist { - loadCh = loadChannel{ - ch: make(chan *CacheItem), - worker: worker, - respChan: make(chan workerLoadResponse), - } - loadChMap[worker] = loadCh - - // Tie up the worker while loading. - worker.loadRequest <- workerLoadRequest{ - ctx: ctx, - response: loadCh.respChan, - in: loadCh.ch, - } - } - - // Send item to worker's load channel. - select { - case loadCh.ch <- item: - // Successfully sent item. 
- - case <-ctx.Done(): - // Context canceled. - return ctx.Err() - } - } - - // Clean up. - for _, loadCh := range loadChMap { - close(loadCh.ch) - - // Load response confirms all items have been loaded and the worker - // resumes normal operation. - select { - case <-loadCh.respChan: - // Successfully received response. - - case <-ctx.Done(): - // Context canceled. - return ctx.Err() - } - } - - return nil -} - -func (worker *Worker) handleLoad(request workerLoadRequest, cache Cache) { -MAIN: - for { - var item *CacheItem - var ok bool - - select { - case item, ok = <-request.in: - if !ok { - break MAIN - } - // Successfully received item. - - case <-request.ctx.Done(): - // Context canceled. - return - } - - cache.Add(item) - } - - response := workerLoadResponse{} - - select { - case request.response <- response: - // Successfully sent response. - - case <-request.ctx.Done(): - // Context canceled. - trace.SpanFromContext(request.ctx).RecordError(request.ctx.Err()) - } -} - -// Store atomically stores cache to persistent storage. -// Save all workers' caches to persistent storage. -// Workers are locked during this store operation to prevent race conditions. -func (p *WorkerPool) Store(ctx context.Context) (err error) { - queueGauge := metricWorkerQueue.WithLabelValues("Store", "") - queueGauge.Inc() - defer queueGauge.Dec() - var wg sync.WaitGroup - out := make(chan *CacheItem, 500) - - // Iterate each worker's cache to `out` channel. - for _, worker := range p.workers { - wg.Add(1) - - go func(ctx context.Context, worker *Worker) { - defer wg.Done() - - respChan := make(chan workerStoreResponse) - req := workerStoreRequest{ - ctx: ctx, - response: respChan, - out: out, - } - - select { - case worker.storeRequest <- req: - // Successfully sent request. - select { - case <-respChan: - // Successfully received response. - return - - case <-ctx.Done(): - // Context canceled. - trace.SpanFromContext(ctx).RecordError(ctx.Err()) - return - } - - case <-ctx.Done(): - // Context canceled. - trace.SpanFromContext(ctx).RecordError(ctx.Err()) - return - } - }(ctx, worker) - } - - // When all iterators are done, close `out` channel. - go func() { - wg.Wait() - close(out) - }() - - if ctx.Err() != nil { - return ctx.Err() - } - - if err = p.conf.Loader.Save(out); err != nil { - return errors.Wrap(err, "while calling p.conf.Loader.Save()") - } - - return nil -} - -func (worker *Worker) handleStore(request workerStoreRequest, cache Cache) { - for item := range cache.Each() { - select { - case request.out <- item: - // Successfully sent item. - - case <-request.ctx.Done(): - // Context canceled. - trace.SpanFromContext(request.ctx).RecordError(request.ctx.Err()) - return - } - } - - response := workerStoreResponse{} - - select { - case request.response <- response: - // Successfully sent response. - - case <-request.ctx.Done(): - // Context canceled. - trace.SpanFromContext(request.ctx).RecordError(request.ctx.Err()) - } -} - -// AddCacheItem adds an item to the worker's cache. -func (p *WorkerPool) AddCacheItem(ctx context.Context, key string, item *CacheItem) (err error) { - worker := p.getWorker(key) - queueGauge := metricWorkerQueue.WithLabelValues("AddCacheItem", worker.name) - queueGauge.Inc() - defer queueGauge.Dec() - respChan := make(chan workerAddCacheItemResponse) - req := workerAddCacheItemRequest{ - ctx: ctx, - response: respChan, - item: item, - } - - select { - case worker.addCacheItemRequest <- req: - // Successfully sent request. 
- select { - case <-respChan: - // Successfully received response. - return nil - - case <-ctx.Done(): - // Context canceled. - return ctx.Err() - } - - case <-ctx.Done(): - // Context canceled. - return ctx.Err() - } -} - -func (worker *Worker) handleAddCacheItem(request workerAddCacheItemRequest, cache Cache) { - exists := cache.Add(request.item) - response := workerAddCacheItemResponse{exists} - - select { - case request.response <- response: - // Successfully sent response. - - case <-request.ctx.Done(): - // Context canceled. - trace.SpanFromContext(request.ctx).RecordError(request.ctx.Err()) - } -} - -// GetCacheItem gets item from worker's cache. -func (p *WorkerPool) GetCacheItem(ctx context.Context, key string) (item *CacheItem, found bool, err error) { - worker := p.getWorker(key) - queueGauge := metricWorkerQueue.WithLabelValues("GetCacheItem", worker.name) - queueGauge.Inc() - defer queueGauge.Dec() - respChan := make(chan workerGetCacheItemResponse) - req := workerGetCacheItemRequest{ - ctx: ctx, - response: respChan, - key: key, - } - - select { - case worker.getCacheItemRequest <- req: - // Successfully sent request. - select { - case resp := <-respChan: - // Successfully received response. - return resp.item, resp.ok, nil - - case <-ctx.Done(): - // Context canceled. - return nil, false, ctx.Err() - } - - case <-ctx.Done(): - // Context canceled. - return nil, false, ctx.Err() - } -} - -func (worker *Worker) handleGetCacheItem(request workerGetCacheItemRequest, cache Cache) { - item, ok := cache.GetItem(request.key) - response := workerGetCacheItemResponse{item, ok} - - select { - case request.response <- response: - // Successfully sent response. - - case <-request.ctx.Done(): - // Context canceled. - trace.SpanFromContext(request.ctx).RecordError(request.ctx.Err()) - } -} diff --git a/workers_internal_test.go b/workers_internal_test.go deleted file mode 100644 index 291971a..0000000 --- a/workers_internal_test.go +++ /dev/null @@ -1,84 +0,0 @@ -/* -Copyright 2024 Mailgun Technologies Inc - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package gubernator - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" -) - -type MockHasher struct { - mock.Mock -} - -func (m *MockHasher) ComputeHash63(input string) uint64 { - args := m.Called(input) - retval, _ := args.Get(0).(uint64) - return retval -} - -func TestWorkersInternal(t *testing.T) { - t.Run("getWorker()", func(t *testing.T) { - const concurrency = 32 - conf := &Config{ - Workers: concurrency, - } - require.NoError(t, conf.SetDefaults()) - - // Test that getWorker() interpolates the hash to find the expected worker. 
- testCases := []struct { - Name string - Hash uint64 - ExpectedIdx int - }{ - {"Hash 0%", 0, 0}, - {"Hash 50%", 0x3fff_ffff_ffff_ffff, (concurrency / 2) - 1}, - {"Hash 50% + 1", 0x4000_0000_0000_0000, concurrency / 2}, - {"Hash 100%", 0x7fff_ffff_ffff_ffff, concurrency - 1}, - } - - for _, testCase := range testCases { - t.Run(testCase.Name, func(t *testing.T) { - pool := NewWorkerPool(conf) - defer pool.Close() - mockHasher := &MockHasher{} - pool.hasher = mockHasher - - // Setup mocks. - mockHasher.On("ComputeHash63", mock.Anything).Once().Return(testCase.Hash) - - // Call code. - worker := pool.getWorker("Foobar") - - // Verify - require.NotNil(t, worker) - - var actualIdx int - for ; actualIdx < len(pool.workers); actualIdx++ { - if pool.workers[actualIdx] == worker { - break - } - } - assert.Equal(t, testCase.ExpectedIdx, actualIdx) - mockHasher.AssertExpectations(t) - }) - } - }) -} From 8f414323c5edd86afc042948f482d2e084cc6860 Mon Sep 17 00:00:00 2001 From: "Derrick J. Wippler" Date: Mon, 13 May 2024 16:41:34 -0500 Subject: [PATCH 2/6] Fixed race condition in tokenBucket() --- Makefile | 3 + algorithms.go | 594 +++++++++++++++++++----------------- benchmark_cache_test.go | 16 +- cache.go | 7 + cache_manager.go | 9 +- cluster/cluster.go | 17 ++ cmd/gubernator/main_test.go | 5 +- functional_test.go | 10 +- gubernator.go | 28 +- lrucache.go | 11 +- otter.go | 25 +- otter_test.go | 9 +- store.go | 79 ++--- store_test.go | 91 ++++++ 14 files changed, 542 insertions(+), 362 deletions(-) diff --git a/Makefile b/Makefile index d98f86d..6ecd257 100644 --- a/Makefile +++ b/Makefile @@ -45,6 +45,9 @@ clean-proto: ## Clean the generated source files from the protobuf sources @find . -name "*.pb.go" -type f -delete @find . -name "*.pb.*.go" -type f -delete +.PHONY: validate +validate: lint test + go mod tidy && git diff --exit-code .PHONY: proto proto: ## Build protos diff --git a/algorithms.go b/algorithms.go index c923161..a118294 100644 --- a/algorithms.go +++ b/algorithms.go @@ -18,14 +18,26 @@ package gubernator import ( "context" + "errors" "github.com/mailgun/holster/v4/clock" "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" - "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" ) +var errAlreadyExistsInCache = errors.New("already exists in cache") + +type rateContext struct { + context.Context + ReqState RateLimitReqState + Request *RateLimitReq + CacheItem *CacheItem + Store Store + Cache Cache + // TODO: Remove + InstanceID string +} + // ### NOTE ### // The both token and leaky follow the same semantic which allows for requests of more than the limit // to be rejected, but subsequent requests within the same window that are under the limit to succeed. @@ -34,223 +46,240 @@ import ( // with 100 emails and the request will succeed. You can override this default behavior with `DRAIN_OVER_LIMIT` // Implements token bucket algorithm for rate limiting. https://en.wikipedia.org/wiki/Token_bucket -func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqState RateLimitReqState) (resp *RateLimitResp, err error) { +func tokenBucket(ctx rateContext) (resp *RateLimitResp, err error) { tokenBucketTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("tokenBucket")) defer tokenBucketTimer.ObserveDuration() - - // Get rate limit from cache. - hashKey := r.HashKey() - item, ok := c.GetItem(hashKey) - - if s != nil && !ok { - // Cache miss. - // Check our store for the item. 
- if item, ok = s.Get(ctx, r); ok { - c.Add(item) + var ok bool + // TODO: Remove + //fmt.Printf("[%s] tokenBucket()\n", ctx.InstanceID) + + // Get rate limit from cache + hashKey := ctx.Request.HashKey() + ctx.CacheItem, ok = ctx.Cache.GetItem(hashKey) + + // If not in the cache, check the store if provided + if ctx.Store != nil && !ok { + if ctx.CacheItem, ok = ctx.Store.Get(ctx, ctx.Request); ok { + if !ctx.Cache.Add(ctx.CacheItem) { + // Someone else added a new token bucket item to the cache for this + // rate limit before we did, so we retry by calling ourselves recursively. + return tokenBucket(ctx) + } } } - // Sanity checks. - if ok { - if item.Value == nil { - msgPart := "tokenBucket: Invalid cache item; Value is nil" - trace.SpanFromContext(ctx).AddEvent(msgPart, trace.WithAttributes( - attribute.String("hashKey", hashKey), - attribute.String("key", r.UniqueKey), - attribute.String("name", r.Name), - )) - logrus.Error(msgPart) - ok = false - } else if item.Key != hashKey { - msgPart := "tokenBucket: Invalid cache item; key mismatch" - trace.SpanFromContext(ctx).AddEvent(msgPart, trace.WithAttributes( - attribute.String("itemKey", item.Key), - attribute.String("hashKey", hashKey), - attribute.String("name", r.Name), - )) - logrus.Error(msgPart) - ok = false + // If no item was found, or the item is expired. + if ctx.CacheItem == nil || ctx.CacheItem.IsExpired() { + // Initialize the Token bucket item + rl, err := InitTokenBucketItem(ctx) + if err != nil && errors.Is(err, errAlreadyExistsInCache) { + // Someone else added a new token bucket item to the cache for this + // rate limit before we did, so we retry by calling ourselves recursively. + return tokenBucket(ctx) } + return rl, err } - if ok { - // Item found in cache or store. - if HasBehavior(r.Behavior, Behavior_RESET_REMAINING) { - c.Remove(hashKey) + // Gain exclusive rights to this item while we calculate the rate limit + ctx.CacheItem.mutex.Lock() + defer ctx.CacheItem.mutex.Unlock() - if s != nil { - s.Remove(ctx, hashKey) - } - return &RateLimitResp{ - Status: Status_UNDER_LIMIT, - Limit: r.Limit, - Remaining: r.Limit, - ResetTime: 0, - }, nil + t, ok := ctx.CacheItem.Value.(*TokenBucketItem) + if !ok { + // Client switched algorithms; perhaps due to a migration? + ctx.Cache.Remove(hashKey) + if ctx.Store != nil { + ctx.Store.Remove(ctx, hashKey) } - t, ok := item.Value.(*TokenBucketItem) - if !ok { - // Client switched algorithms; perhaps due to a migration? - trace.SpanFromContext(ctx).AddEvent("Client switched algorithms; perhaps due to a migration?") + ctx.CacheItem = nil - c.Remove(hashKey) + rl, err := InitTokenBucketItem(ctx) + if err != nil && errors.Is(err, errAlreadyExistsInCache) { + return tokenBucket(ctx) + } + return rl, err + } - if s != nil { - s.Remove(ctx, hashKey) - } + if HasBehavior(ctx.Request.Behavior, Behavior_RESET_REMAINING) { + t.Remaining = ctx.Request.Limit + t.Limit = ctx.Request.Limit + t.Status = Status_UNDER_LIMIT - return tokenBucketNewItem(ctx, s, c, r, reqState) + if ctx.Store != nil { + ctx.Store.OnChange(ctx, ctx.Request, ctx.CacheItem) } - // Update the limit if it changed. - if t.Limit != r.Limit { - // Add difference to remaining. 
- t.Remaining += r.Limit - t.Limit - if t.Remaining < 0 { - t.Remaining = 0 - } - t.Limit = r.Limit - } + return &RateLimitResp{ + Status: Status_UNDER_LIMIT, + Limit: ctx.Request.Limit, + Remaining: ctx.Request.Limit, + ResetTime: 0, + }, nil + } - rl := &RateLimitResp{ - Status: t.Status, - Limit: r.Limit, - Remaining: t.Remaining, - ResetTime: item.ExpireAt, + // Update the limit if it changed. + if t.Limit != ctx.Request.Limit { + // Add difference to remaining. + t.Remaining += ctx.Request.Limit - t.Limit + if t.Remaining < 0 { + t.Remaining = 0 } + t.Limit = ctx.Request.Limit + } - // If the duration config changed, update the new ExpireAt. - if t.Duration != r.Duration { - span := trace.SpanFromContext(ctx) - span.AddEvent("Duration changed") - expire := t.CreatedAt + r.Duration - if HasBehavior(r.Behavior, Behavior_DURATION_IS_GREGORIAN) { - expire, err = GregorianExpiration(clock.Now(), r.Duration) - if err != nil { - return nil, err - } - } + rl := &RateLimitResp{ + Status: t.Status, + Limit: ctx.Request.Limit, + Remaining: t.Remaining, + ResetTime: ctx.CacheItem.ExpireAt, + } - // If our new duration means we are currently expired. - createdAt := *r.CreatedAt - if expire <= createdAt { - // Renew item. - span.AddEvent("Limit has expired") - expire = createdAt + r.Duration - t.CreatedAt = createdAt - t.Remaining = t.Limit + // If the duration config changed, update the new ExpireAt. + if t.Duration != ctx.Request.Duration { + span := trace.SpanFromContext(ctx) + span.AddEvent("Duration changed") + expire := t.CreatedAt + ctx.Request.Duration + if HasBehavior(ctx.Request.Behavior, Behavior_DURATION_IS_GREGORIAN) { + expire, err = GregorianExpiration(clock.Now(), ctx.Request.Duration) + if err != nil { + return nil, err } - - item.ExpireAt = expire - t.Duration = r.Duration - rl.ResetTime = expire } - if s != nil && reqState.IsOwner { - defer func() { - s.OnChange(ctx, r, item) - }() + // If our new duration means we are currently expired. + createdAt := *ctx.Request.CreatedAt + if expire <= createdAt { + // Renew item. + span.AddEvent("Limit has expired") + expire = createdAt + ctx.Request.Duration + t.CreatedAt = createdAt + t.Remaining = t.Limit } - // Client is only interested in retrieving the current status or - // updating the rate limit config. - if r.Hits == 0 { - return rl, nil - } + ctx.CacheItem.ExpireAt = expire + t.Duration = ctx.Request.Duration + rl.ResetTime = expire + } - // If we are already at the limit. - if rl.Remaining == 0 && r.Hits > 0 { - trace.SpanFromContext(ctx).AddEvent("Already over the limit") - if reqState.IsOwner { - metricOverLimitCounter.Add(1) - } - rl.Status = Status_OVER_LIMIT - t.Status = rl.Status - return rl, nil + if ctx.Store != nil && ctx.ReqState.IsOwner { + defer func() { + ctx.Store.OnChange(ctx, ctx.Request, ctx.CacheItem) + }() + } + + // Client is only interested in retrieving the current status or + // updating the rate limit config. + if ctx.Request.Hits == 0 { + return rl, nil + } + + // If we are already at the limit. + if rl.Remaining == 0 && ctx.Request.Hits > 0 { + trace.SpanFromContext(ctx).AddEvent("Already over the limit") + if ctx.ReqState.IsOwner { + metricOverLimitCounter.Add(1) } + rl.Status = Status_OVER_LIMIT + t.Status = rl.Status + return rl, nil + } + + // If requested hits takes the remainder. + if t.Remaining == ctx.Request.Hits { + trace.SpanFromContext(ctx).AddEvent("At the limit") + t.Remaining = 0 + rl.Remaining = 0 + return rl, nil + } - // If requested hits takes the remainder. 
- if t.Remaining == r.Hits { - trace.SpanFromContext(ctx).AddEvent("At the limit") + // If requested is more than available, then return over the limit + // without updating the cache. + if ctx.Request.Hits > t.Remaining { + trace.SpanFromContext(ctx).AddEvent("Over the limit") + if ctx.ReqState.IsOwner { + metricOverLimitCounter.Add(1) + } + rl.Status = Status_OVER_LIMIT + if HasBehavior(ctx.Request.Behavior, Behavior_DRAIN_OVER_LIMIT) { + // DRAIN_OVER_LIMIT behavior drains the remaining counter. t.Remaining = 0 rl.Remaining = 0 - return rl, nil } - - // If requested is more than available, then return over the limit - // without updating the cache. - if r.Hits > t.Remaining { - trace.SpanFromContext(ctx).AddEvent("Over the limit") - if reqState.IsOwner { - metricOverLimitCounter.Add(1) - } - rl.Status = Status_OVER_LIMIT - if HasBehavior(r.Behavior, Behavior_DRAIN_OVER_LIMIT) { - // DRAIN_OVER_LIMIT behavior drains the remaining counter. - t.Remaining = 0 - rl.Remaining = 0 - } - return rl, nil - } - - t.Remaining -= r.Hits - rl.Remaining = t.Remaining return rl, nil } - // Item is not found in cache or store, create new. - return tokenBucketNewItem(ctx, s, c, r, reqState) + t.Remaining -= ctx.Request.Hits + rl.Remaining = t.Remaining + return rl, nil } -// Called by tokenBucket() when adding a new item in the store. -func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqState RateLimitReqState) (resp *RateLimitResp, err error) { - createdAt := *r.CreatedAt - expire := createdAt + r.Duration +// InitTokenBucketItem will create a new item if the passed item is nil, else it will update the provided item. +func InitTokenBucketItem(ctx rateContext) (resp *RateLimitResp, err error) { + createdAt := *ctx.Request.CreatedAt + expire := createdAt + ctx.Request.Duration - t := &TokenBucketItem{ - Limit: r.Limit, - Duration: r.Duration, - Remaining: r.Limit - r.Hits, + t := TokenBucketItem{ + Limit: ctx.Request.Limit, + Duration: ctx.Request.Duration, + Remaining: ctx.Request.Limit - ctx.Request.Hits, CreatedAt: createdAt, } // Add a new rate limit to the cache. - if HasBehavior(r.Behavior, Behavior_DURATION_IS_GREGORIAN) { - expire, err = GregorianExpiration(clock.Now(), r.Duration) + if HasBehavior(ctx.Request.Behavior, Behavior_DURATION_IS_GREGORIAN) { + expire, err = GregorianExpiration(clock.Now(), ctx.Request.Duration) if err != nil { return nil, err } } - item := &CacheItem{ - Algorithm: Algorithm_TOKEN_BUCKET, - Key: r.HashKey(), - Value: t, - ExpireAt: expire, - } - rl := &RateLimitResp{ Status: Status_UNDER_LIMIT, - Limit: r.Limit, + Limit: ctx.Request.Limit, Remaining: t.Remaining, ResetTime: expire, } // Client could be requesting that we always return OVER_LIMIT. 
- if r.Hits > r.Limit { + if ctx.Request.Hits > ctx.Request.Limit { trace.SpanFromContext(ctx).AddEvent("Over the limit") - if reqState.IsOwner { + if ctx.ReqState.IsOwner { metricOverLimitCounter.Add(1) } rl.Status = Status_OVER_LIMIT - rl.Remaining = r.Limit - t.Remaining = r.Limit + rl.Remaining = ctx.Request.Limit + t.Remaining = ctx.Request.Limit } - c.Add(item) + // If the cache item already exists, update it + if ctx.CacheItem != nil { + ctx.CacheItem.mutex.Lock() + ctx.CacheItem.Algorithm = Algorithm_TOKEN_BUCKET + ctx.CacheItem.ExpireAt = expire + in, ok := ctx.CacheItem.Value.(*TokenBucketItem) + if !ok { + // Likely the store gave us the wrong cache type + ctx.CacheItem.mutex.Unlock() + ctx.CacheItem = nil + return InitTokenBucketItem(ctx) + } + *in = t + ctx.CacheItem.mutex.Unlock() + } else { + // else create a new cache item and add it to the cache + ctx.CacheItem = &CacheItem{ + Algorithm: Algorithm_TOKEN_BUCKET, + Key: ctx.Request.HashKey(), + Value: &t, + ExpireAt: expire, + } + if !ctx.Cache.Add(ctx.CacheItem) { + return rl, errAlreadyExistsInCache + } + } - if s != nil && reqState.IsOwner { - s.OnChange(ctx, r, item) + if ctx.Store != nil && ctx.ReqState.IsOwner { + ctx.Store.OnChange(ctx, ctx.Request, ctx.CacheItem) } return rl, nil @@ -261,6 +290,8 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqStat leakyBucketTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.getRateLimit_leakyBucket")) defer leakyBucketTimer.ObserveDuration() + // TODO(thrawn01): Test for race conditions, and fix + if r.Burst == 0 { r.Burst = r.Limit } @@ -272,165 +303,158 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqStat item, ok := c.GetItem(hashKey) if s != nil && !ok { - // Cache miss. - // Check our store for the item. + // Cache missed, check our store for the item. if item, ok = s.Get(ctx, r); ok { - c.Add(item) + if !c.Add(item) { + // Someone else added a new leaky bucket item to the cache for this + // rate limit before we did, so we retry by calling ourselves recursively. + return leakyBucket(ctx, s, c, r, reqState) + } } } - // Sanity checks. - if ok { - if item.Value == nil { - msgPart := "leakyBucket: Invalid cache item; Value is nil" - trace.SpanFromContext(ctx).AddEvent(msgPart, trace.WithAttributes( - attribute.String("hashKey", hashKey), - attribute.String("key", r.UniqueKey), - attribute.String("name", r.Name), - )) - logrus.Error(msgPart) - ok = false - } else if item.Key != hashKey { - msgPart := "leakyBucket: Invalid cache item; key mismatch" - trace.SpanFromContext(ctx).AddEvent(msgPart, trace.WithAttributes( - attribute.String("itemKey", item.Key), - attribute.String("hashKey", hashKey), - attribute.String("name", r.Name), - )) - logrus.Error(msgPart) - ok = false + if !ok { + rl, err := leakyBucketNewItem(ctx, s, c, r, reqState) + if err != nil && errors.Is(err, errAlreadyExistsInCache) { + // Someone else added a new leaky bucket item to the cache for this + // rate limit before we did, so we retry by calling ourselves recursively. + return leakyBucket(ctx, s, c, r, reqState) } + return rl, err } - if ok { - // Item found in cache or store. + // Item found in cache or store. + b, ok := item.Value.(*LeakyBucketItem) + if !ok { + // Client switched algorithms; perhaps due to a migration? + c.Remove(hashKey) + if s != nil { + s.Remove(ctx, hashKey) + } - b, ok := item.Value.(*LeakyBucketItem) - if !ok { - // Client switched algorithms; perhaps due to a migration? 
- c.Remove(hashKey) + rl, err := leakyBucketNewItem(ctx, s, c, r, reqState) + if err != nil && errors.Is(err, errAlreadyExistsInCache) { + return leakyBucket(ctx, s, c, r, reqState) + } + return rl, err + } - if s != nil { - s.Remove(ctx, hashKey) - } + // Gain exclusive rights to this item while we calculate the rate limit + b.mutex.Lock() + defer b.mutex.Unlock() - return leakyBucketNewItem(ctx, s, c, r, reqState) - } + if HasBehavior(r.Behavior, Behavior_RESET_REMAINING) { + b.Remaining = float64(r.Burst) + } - if HasBehavior(r.Behavior, Behavior_RESET_REMAINING) { + // Update burst, limit and duration if they changed + if b.Burst != r.Burst { + if r.Burst > int64(b.Remaining) { b.Remaining = float64(r.Burst) } + b.Burst = r.Burst + } - // Update burst, limit and duration if they changed - if b.Burst != r.Burst { - if r.Burst > int64(b.Remaining) { - b.Remaining = float64(r.Burst) - } - b.Burst = r.Burst - } - - b.Limit = r.Limit - b.Duration = r.Duration - - duration := r.Duration - rate := float64(duration) / float64(r.Limit) + b.Limit = r.Limit + b.Duration = r.Duration - if HasBehavior(r.Behavior, Behavior_DURATION_IS_GREGORIAN) { - d, err := GregorianDuration(clock.Now(), r.Duration) - if err != nil { - return nil, err - } - n := clock.Now() - expire, err := GregorianExpiration(n, r.Duration) - if err != nil { - return nil, err - } + duration := r.Duration + rate := float64(duration) / float64(r.Limit) - // Calculate the rate using the entire duration of the gregorian interval - // IE: Minute = 60,000 milliseconds, etc.. etc.. - rate = float64(d) / float64(r.Limit) - // Update the duration to be the end of the gregorian interval - duration = expire - (n.UnixNano() / 1000000) + if HasBehavior(r.Behavior, Behavior_DURATION_IS_GREGORIAN) { + d, err := GregorianDuration(clock.Now(), r.Duration) + if err != nil { + return nil, err } - - if r.Hits != 0 { - c.UpdateExpiration(r.HashKey(), createdAt+duration) + n := clock.Now() + expire, err := GregorianExpiration(n, r.Duration) + if err != nil { + return nil, err } - // Calculate how much leaked out of the bucket since the last time we leaked a hit - elapsed := createdAt - b.UpdatedAt - leak := float64(elapsed) / rate + // Calculate the rate using the entire duration of the gregorian interval + // IE: Minute = 60,000 milliseconds, etc.. etc.. + rate = float64(d) / float64(r.Limit) + // Update the duration to be the end of the gregorian interval + duration = expire - (n.UnixNano() / 1000000) + } - if int64(leak) > 0 { - b.Remaining += leak - b.UpdatedAt = createdAt - } + if r.Hits != 0 { + c.UpdateExpiration(r.HashKey(), createdAt+duration) + } - if int64(b.Remaining) > b.Burst { - b.Remaining = float64(b.Burst) - } + // Calculate how much leaked out of the bucket since the last time we leaked a hit + elapsed := createdAt - b.UpdatedAt + leak := float64(elapsed) / rate - rl := &RateLimitResp{ - Limit: b.Limit, - Remaining: int64(b.Remaining), - Status: Status_UNDER_LIMIT, - ResetTime: createdAt + (b.Limit-int64(b.Remaining))*int64(rate), - } + if int64(leak) > 0 { + b.Remaining += leak + b.UpdatedAt = createdAt + } - // TODO: Feature missing: check for Duration change between item/request. 
+ if int64(b.Remaining) > b.Burst { + b.Remaining = float64(b.Burst) + } - if s != nil && reqState.IsOwner { - defer func() { - s.OnChange(ctx, r, item) - }() - } + rl := &RateLimitResp{ + Limit: b.Limit, + Remaining: int64(b.Remaining), + Status: Status_UNDER_LIMIT, + ResetTime: createdAt + (b.Limit-int64(b.Remaining))*int64(rate), + } - // If we are already at the limit - if int64(b.Remaining) == 0 && r.Hits > 0 { - if reqState.IsOwner { - metricOverLimitCounter.Add(1) - } - rl.Status = Status_OVER_LIMIT - return rl, nil - } + // TODO: Feature missing: check for Duration change between item/request. - // If requested hits takes the remainder - if int64(b.Remaining) == r.Hits { - b.Remaining = 0 - rl.Remaining = int64(b.Remaining) - rl.ResetTime = createdAt + (rl.Limit-rl.Remaining)*int64(rate) - return rl, nil - } + if s != nil && reqState.IsOwner { + defer func() { + s.OnChange(ctx, r, item) + }() + } - // If requested is more than available, then return over the limit - // without updating the bucket, unless `DRAIN_OVER_LIMIT` is set. - if r.Hits > int64(b.Remaining) { - if reqState.IsOwner { - metricOverLimitCounter.Add(1) - } - rl.Status = Status_OVER_LIMIT + // If we are already at the limit + if int64(b.Remaining) == 0 && r.Hits > 0 { + if reqState.IsOwner { + metricOverLimitCounter.Add(1) + } + rl.Status = Status_OVER_LIMIT + return rl, nil + } - // DRAIN_OVER_LIMIT behavior drains the remaining counter. - if HasBehavior(r.Behavior, Behavior_DRAIN_OVER_LIMIT) { - b.Remaining = 0 - rl.Remaining = 0 - } + // If requested hits takes the remainder + if int64(b.Remaining) == r.Hits { + b.Remaining = 0 + rl.Remaining = int64(b.Remaining) + rl.ResetTime = createdAt + (rl.Limit-rl.Remaining)*int64(rate) + return rl, nil + } - return rl, nil + // If requested is more than available, then return over the limit + // without updating the bucket, unless `DRAIN_OVER_LIMIT` is set. + if r.Hits > int64(b.Remaining) { + if reqState.IsOwner { + metricOverLimitCounter.Add(1) } + rl.Status = Status_OVER_LIMIT - // Client is only interested in retrieving the current status - if r.Hits == 0 { - return rl, nil + // DRAIN_OVER_LIMIT behavior drains the remaining counter. + if HasBehavior(r.Behavior, Behavior_DRAIN_OVER_LIMIT) { + b.Remaining = 0 + rl.Remaining = 0 } - b.Remaining -= float64(r.Hits) - rl.Remaining = int64(b.Remaining) - rl.ResetTime = createdAt + (rl.Limit-rl.Remaining)*int64(rate) return rl, nil } - return leakyBucketNewItem(ctx, s, c, r, reqState) + // Client is only interested in retrieving the current status + if r.Hits == 0 { + return rl, nil + } + + b.Remaining -= float64(r.Hits) + rl.Remaining = int64(b.Remaining) + rl.ResetTime = createdAt + (rl.Limit-rl.Remaining)*int64(rate) + return rl, nil + } // Called by leakyBucket() when adding a new item in the store. 
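// A hypothetical helper (not part of this patch) sketching the add-or-retry
// contract both bucket algorithms rely on above: Add() reports false when
// another goroutine inserted the same key first, so the caller re-reads the
// cache and operates on the winner's item instead of its own.
//
//	func addOrRetry(c Cache, item *CacheItem) *CacheItem {
//		if c.Add(item) {
//			return item // we won the insert race
//		}
//		winner, _ := c.GetItem(item.Key) // operate on the existing item
//		return winner
//	}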
@@ -483,7 +507,9 @@ func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, Value: &b, } - c.Add(item) + if !c.Add(item) { + return nil, errAlreadyExistsInCache + } if s != nil && reqState.IsOwner { s.OnChange(ctx, r, item) diff --git a/benchmark_cache_test.go b/benchmark_cache_test.go index 98c2b68..1f849d8 100644 --- a/benchmark_cache_test.go +++ b/benchmark_cache_test.go @@ -39,7 +39,7 @@ func BenchmarkCache(b *testing.B) { cache, err := testCase.NewTestCache() require.NoError(b, err) expire := clock.Now().Add(time.Hour).UnixMilli() - keys := GenerateRandomKeys() + keys := GenerateRandomKeys(defaultNumKeys) for _, key := range keys { item := &gubernator.CacheItem{ @@ -64,7 +64,7 @@ func BenchmarkCache(b *testing.B) { cache, err := testCase.NewTestCache() require.NoError(b, err) expire := clock.Now().Add(time.Hour).UnixMilli() - keys := GenerateRandomKeys() + keys := GenerateRandomKeys(defaultNumKeys) mask := len(keys) - 1 b.ReportAllocs() @@ -85,7 +85,7 @@ func BenchmarkCache(b *testing.B) { cache, err := testCase.NewTestCache() require.NoError(b, err) expire := clock.Now().Add(time.Hour).UnixMilli() - keys := GenerateRandomKeys() + keys := GenerateRandomKeys(defaultNumKeys) for _, key := range keys { item := &gubernator.CacheItem{ @@ -129,7 +129,7 @@ func BenchmarkCache(b *testing.B) { cache, err := testCase.NewTestCache() require.NoError(b, err) expire := clock.Now().Add(time.Hour).UnixMilli() - keys := GenerateRandomKeys() + keys := GenerateRandomKeys(defaultNumKeys) var mutex sync.Mutex var task func(key string) @@ -172,11 +172,11 @@ func BenchmarkCache(b *testing.B) { } } -const cacheSize = 32768 +const defaultNumKeys = 32768 -func GenerateRandomKeys() []string { - keys := make([]string, 0, cacheSize) - for i := 0; i < cacheSize; i++ { +func GenerateRandomKeys(size int) []string { + keys := make([]string, 0, size) + for i := 0; i < size; i++ { keys = append(keys, gubernator.RandomString(20)) } return keys diff --git a/cache.go b/cache.go index 0fd431a..dbeea21 100644 --- a/cache.go +++ b/cache.go @@ -16,6 +16,8 @@ limitations under the License. 
package gubernator +import "sync" + type Cache interface { Add(item *CacheItem) bool UpdateExpiration(key string, expireAt int64) bool @@ -27,6 +29,7 @@ type Cache interface { } type CacheItem struct { + mutex sync.Mutex Algorithm Algorithm Key string Value interface{} @@ -41,6 +44,10 @@ type CacheItem struct { } func (item *CacheItem) IsExpired() bool { + // TODO(thrawn01): Eliminate the need for this mutex lock + item.mutex.Lock() + defer item.mutex.Unlock() + now := MillisecondNow() // If the entry is invalidated diff --git a/cache_manager.go b/cache_manager.go index 6542bdf..3514ce1 100644 --- a/cache_manager.go +++ b/cache_manager.go @@ -58,7 +58,14 @@ func (m *cacheManager) GetRateLimit(ctx context.Context, req *RateLimitReq, stat switch req.Algorithm { case Algorithm_TOKEN_BUCKET: - rlResponse, err = tokenBucket(ctx, m.conf.Store, m.cache, req, state) + rlResponse, err = tokenBucket(rateContext{ + Store: m.conf.Store, + Cache: m.cache, + ReqState: state, + Request: req, + Context: ctx, + InstanceID: m.conf.InstanceID, + }) if err != nil { msg := "Error in tokenBucket" countError(err, msg) diff --git a/cluster/cluster.go b/cluster/cluster.go index 3fef87e..ad3714e 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -53,6 +53,23 @@ func GetRandomPeer(dc string) gubernator.PeerInfo { return local[rand.Intn(len(local))] } +// GetRandomDaemon returns a random daemon from the cluster +func GetRandomDaemon(dc string) *gubernator.Daemon { + var local []*gubernator.Daemon + + for _, d := range daemons { + if d.PeerInfo.DataCenter == dc { + local = append(local, d) + } + } + + if len(local) == 0 { + panic(fmt.Sprintf("failed to find random daemon for dc '%s'", dc)) + } + + return local[rand.Intn(len(local))] +} + // GetPeers returns a list of all peers in the cluster func GetPeers() []gubernator.PeerInfo { return peers diff --git a/cmd/gubernator/main_test.go b/cmd/gubernator/main_test.go index 4f1364e..f374d7f 100644 --- a/cmd/gubernator/main_test.go +++ b/cmd/gubernator/main_test.go @@ -15,9 +15,10 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + cli "github.com/gubernator-io/gubernator/v2/cmd/gubernator" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "golang.org/x/net/proxy" ) @@ -78,9 +79,9 @@ func TestCLI(t *testing.T) { time.Sleep(time.Second * 1) err = c.Process.Signal(syscall.SIGTERM) - require.NoError(t, err, out.String()) <-waitCh + require.NoError(t, err, out.String()) assert.Contains(t, out.String(), tt.contains) }) } diff --git a/functional_test.go b/functional_test.go index d93d352..ba8be28 100644 --- a/functional_test.go +++ b/functional_test.go @@ -2252,8 +2252,14 @@ func waitForIdle(timeout clock.Duration, daemons ...*guber.Daemon) error { if err != nil { return err } - ggql := metrics["gubernator_global_queue_length"] - gsql := metrics["gubernator_global_send_queue_length"] + ggql, ok := metrics["gubernator_global_queue_length"] + if !ok { + return errors.New("gubernator_global_queue_length not found") + } + gsql, ok := metrics["gubernator_global_send_queue_length"] + if !ok { + return errors.New("gubernator_global_send_queue_length not found") + } if ggql.Value == 0 && gsql.Value == 0 { return nil diff --git a/gubernator.go b/gubernator.go index 1ae40d4..1ef4fac 100644 --- a/gubernator.go +++ b/gubernator.go @@ -420,12 +420,26 @@ func (s *V1Instance) getGlobalRateLimit(ctx context.Context, req *RateLimitReq) func (s *V1Instance) UpdatePeerGlobals(ctx context.Context, r *UpdatePeerGlobalsReq) 
(*UpdatePeerGlobalsResp, error) { defer prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.UpdatePeerGlobals")).ObserveDuration() now := MillisecondNow() + for _, g := range r.Globals { - item := &CacheItem{ - ExpireAt: g.Status.ResetTime, - Algorithm: g.Algorithm, - Key: g.Key, + item, _, err := s.cache.GetCacheItem(ctx, g.Key) + if err != nil { + return nil, err } + + if item == nil { + item = &CacheItem{ + ExpireAt: g.Status.ResetTime, + Algorithm: g.Algorithm, + Key: g.Key, + } + err := s.cache.AddCacheItem(ctx, g.Key, item) + if err != nil { + return nil, fmt.Errorf("during CacheManager.AddCacheItem(): %w", err) + } + } + + item.mutex.Lock() switch g.Algorithm { case Algorithm_LEAKY_BUCKET: item.Value = &LeakyBucketItem{ @@ -444,12 +458,8 @@ func (s *V1Instance) UpdatePeerGlobals(ctx context.Context, r *UpdatePeerGlobals CreatedAt: now, } } - err := s.cache.AddCacheItem(ctx, g.Key, item) - if err != nil { - return nil, errors.Wrap(err, "Error in CacheManager.AddCacheItem") - } + item.mutex.Unlock() } - return &UpdatePeerGlobalsResp{}, nil } diff --git a/lrucache.go b/lrucache.go index 8a415c9..5bef041 100644 --- a/lrucache.go +++ b/lrucache.go @@ -119,11 +119,12 @@ func (c *LRUCache) GetItem(key string) (item *CacheItem, ok bool) { if ele, hit := c.cache[key]; hit { entry := ele.Value.(*CacheItem) - if entry.IsExpired() { - c.removeElement(ele) - metricCacheAccess.WithLabelValues("miss").Add(1) - return - } + // TODO(thrawn01): Remove + //if entry.IsExpired() { + // c.removeElement(ele) + // metricCacheAccess.WithLabelValues("miss").Add(1) + // return + //} metricCacheAccess.WithLabelValues("hit").Add(1) c.ll.MoveToFront(ele) diff --git a/otter.go b/otter.go index f04fc8e..9f60c55 100644 --- a/otter.go +++ b/otter.go @@ -37,9 +37,9 @@ func NewOtterCache(size int) (*OtterCache, error) { // Add adds a new CacheItem to the cache. The key must be provided via CacheItem.Key // returns true if the item was added to the cache; false if the item was too large -// for the cache. +// for the cache or already exists in the cache. func (o *OtterCache) Add(item *CacheItem) bool { - return o.cache.Set(item.Key, item) + return o.cache.SetIfAbsent(item.Key, item) } // GetItem returns an item in the cache that corresponds to the provided key @@ -50,16 +50,17 @@ func (o *OtterCache) GetItem(key string) (*CacheItem, bool) { return nil, false } - if item.IsExpired() { - metricCacheAccess.WithLabelValues("miss").Add(1) - // If the item is expired, just return `nil` - // - // We avoid the explicit deletion of the expired item to avoid acquiring a mutex lock in otter. - // Explicit deletions in otter require a mutex, which can cause performance bottlenecks - // under high concurrency scenarios. By allowing the item to be evicted naturally by - // otter's eviction mechanism, we avoid impacting performance under high concurrency. - return nil, false - } + // TODO(thrawn01): Remove + //if item.IsExpired() { + // metricCacheAccess.WithLabelValues("miss").Add(1) + // // If the item is expired, just return `nil` + // // + // // We avoid the explicit deletion of the expired item to avoid acquiring a mutex lock in otter. + // // Explicit deletions in otter require a mutex, which can cause performance bottlenecks + // // under high concurrency scenarios. By allowing the item to be evicted naturally by + // // otter's eviction mechanism, we avoid impacting performance under high concurrency. 
+	//	return nil, false
+	//}

 	metricCacheAccess.WithLabelValues("hit").Add(1)
 	return item, true
 }
diff --git a/otter_test.go b/otter_test.go
index 6eb629d..9c84df1 100644
--- a/otter_test.go
+++ b/otter_test.go
@@ -65,13 +65,18 @@ func TestOtterCache(t *testing.T) {
 		}
 		cache.Add(item1)

-		// Update same key.
+		// Updating the same key is refused
 		item2 := &gubernator.CacheItem{
 			Key:      key,
 			Value:    "new value",
 			ExpireAt: expireAt,
 		}
-		cache.Add(item2)
+		assert.False(t, cache.Add(item2))
+
+		// Fetch and update the CacheItem
+		update, ok := cache.GetItem(key)
+		assert.True(t, ok)
+		update.Value = "new value"

 		// Verify.
 		verifyItem, ok := cache.GetItem(key)
diff --git a/store.go b/store.go
index 1c23461..089ea50 100644
--- a/store.go
+++ b/store.go
@@ -16,7 +16,10 @@ limitations under the License.

 package gubernator

-import "context"
+import (
+	"context"
+	"sync"
+)

 // PERSISTENT STORE DETAILS

@@ -27,6 +30,7 @@ import "context"
 // Both interfaces can be implemented simultaneously to ensure data is always saved to persistent storage.

 type LeakyBucketItem struct {
+	mutex     sync.Mutex
 	Limit     int64
 	Duration  int64
 	Remaining float64
@@ -47,18 +51,18 @@ type TokenBucketItem struct {
 // to maximize performance of gubernator.
 // Implementations MUST be threadsafe.
 type Store interface {
-	// Called by gubernator *after* a rate limit item is updated. It's up to the store to
+	// OnChange is called by gubernator *after* a rate limit item is updated. It's up to the store to
 	// decide if this rate limit item should be persisted in the store. It's up to the
 	// store to expire old rate limit items. The CacheItem represents the current state of
 	// the rate limit item *after* the RateLimitReq has been applied.
 	OnChange(ctx context.Context, r *RateLimitReq, item *CacheItem)

-	// Called by gubernator when a rate limit is missing from the cache. It's up to the store
+	// Get is called by gubernator when a rate limit is missing from the cache. It's up to the store
 	// to decide if this request is fulfilled. Should return true if the request is fulfilled
 	// and false if the request is not fulfilled or doesn't exist in the store.
 	Get(ctx context.Context, r *RateLimitReq) (*CacheItem, bool)

-	// Called by gubernator when an existing rate limit should be removed from the store.
+	// Remove is called by gubernator when an existing rate limit should be removed from the store.
 	// NOTE: This is NOT called when an rate limit expires from the cache, store implementors
 	// must expire rate limits in the store.
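 	// A sketch of a conforming implementation (redisStore and its client are
 	// hypothetical, not part of this package):
 	//
 	//	func (s *redisStore) Remove(ctx context.Context, key string) {
 	//		s.client.Del(ctx, key) // deletion is best-effort; expiry stays the store's job
 	//	}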
 	Remove(ctx context.Context, key string)
@@ -77,39 +81,40 @@ type Loader interface {
 	Save(chan *CacheItem) error
 }

-func NewMockStore() *MockStore {
-	ml := &MockStore{
-		Called:     make(map[string]int),
-		CacheItems: make(map[string]*CacheItem),
-	}
-	ml.Called["OnChange()"] = 0
-	ml.Called["Remove()"] = 0
-	ml.Called["Get()"] = 0
-	return ml
-}
-
-type MockStore struct {
-	Called     map[string]int
-	CacheItems map[string]*CacheItem
-}
-
-var _ Store = &MockStore{}
-
-func (ms *MockStore) OnChange(ctx context.Context, r *RateLimitReq, item *CacheItem) {
-	ms.Called["OnChange()"] += 1
-	ms.CacheItems[item.Key] = item
-}
-
-func (ms *MockStore) Get(ctx context.Context, r *RateLimitReq) (*CacheItem, bool) {
-	ms.Called["Get()"] += 1
-	item, ok := ms.CacheItems[r.HashKey()]
-	return item, ok
-}
-
-func (ms *MockStore) Remove(ctx context.Context, key string) {
-	ms.Called["Remove()"] += 1
-	delete(ms.CacheItems, key)
-}
+// TODO Remove
+//func NewMockStore() *MockStore {
+//	ml := &MockStore{
+//		Called:     make(map[string]int),
+//		CacheItems: make(map[string]*CacheItem),
+//	}
+//	ml.Called["OnChange()"] = 0
+//	ml.Called["Remove()"] = 0
+//	ml.Called["Get()"] = 0
+//	return ml
+//}
+//
+//type MockStore struct {
+//	Called     map[string]int
+//	CacheItems map[string]*CacheItem
+//}
+//
+//var _ Store = &MockStore{}
+//
+//func (ms *MockStore) OnChange(ctx context.Context, r *RateLimitReq, item *CacheItem) {
+//	ms.Called["OnChange()"] += 1
+//	ms.CacheItems[item.Key] = item
+//}
+//
+//func (ms *MockStore) Get(ctx context.Context, r *RateLimitReq) (*CacheItem, bool) {
+//	ms.Called["Get()"] += 1
+//	item, ok := ms.CacheItems[r.HashKey()]
+//	return item, ok
+//}
+//
+//func (ms *MockStore) Remove(ctx context.Context, key string) {
+//	ms.Called["Remove()"] += 1
+//	delete(ms.CacheItems, key)
+//}

 func NewMockLoader() *MockLoader {
 	ml := &MockLoader{
diff --git a/store_test.go b/store_test.go
index e7c58f6..ff29df0 100644
--- a/store_test.go
+++ b/store_test.go
@@ -20,6 +20,7 @@ import (
 	"context"
 	"fmt"
 	"net"
+	"sync"
 	"testing"

 	"github.com/gubernator-io/gubernator/v2"
@@ -124,6 +125,96 @@ func TestLoader(t *testing.T) {
 	assert.Equal(t, gubernator.Status_UNDER_LIMIT, item.Status)
 }

+type NoOpStore struct{}
+
+func (ms *NoOpStore) Remove(ctx context.Context, key string) {}
+func (ms *NoOpStore) OnChange(ctx context.Context, r *gubernator.RateLimitReq, item *gubernator.CacheItem) {
+}
+
+func (ms *NoOpStore) Get(ctx context.Context, r *gubernator.RateLimitReq) (*gubernator.CacheItem, bool) {
+	return &gubernator.CacheItem{
+		Algorithm: gubernator.Algorithm_TOKEN_BUCKET,
+		Key:       r.HashKey(),
+		Value: gubernator.TokenBucketItem{
+			CreatedAt: gubernator.MillisecondNow(),
+			Duration:  gubernator.Minute * 60,
+			Limit:     1_000,
+			Remaining: 1_000,
+			Status:    0,
+		},
+		ExpireAt: 0,
+	}, true
+}
+
+// The goal of this test is to generate race conditions where multiple goroutines load from the store and/or
+// add items to the cache in parallel, which the code must then handle.
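+//
+// A sketch of one interleaving this test provokes (assuming the
+// set-if-absent Add() semantics used by these caches):
+//
+//	goroutine A: GetItem(key) -> miss; Store.Get() -> item; Add(item) -> true
+//	goroutine B: GetItem(key) -> miss; Store.Get() -> item; Add(item) -> false
+//	goroutine B: retries, and GetItem(key) now returns A's item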
+func TestHighContentionFromStore(t *testing.T) { + const ( + numGoroutines = 1_000 + numKeys = 400 + ) + store := &NoOpStore{} + srv := newV1Server(t, "localhost:0", gubernator.Config{ + Behaviors: gubernator.BehaviorConfig{ + GlobalSyncWait: clock.Millisecond * 50, // Suitable for testing but not production + GlobalTimeout: clock.Second, + }, + Store: store, + }) + client, err := gubernator.DialV1Server(srv.listener.Addr().String(), nil) + require.NoError(t, err) + + keys := GenerateRandomKeys(numKeys) + + var wg sync.WaitGroup + var ready sync.WaitGroup + wg.Add(numGoroutines) + ready.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func() { + ready.Wait() + for idx := 0; idx < numKeys; idx++ { + _, err := client.GetRateLimits(context.Background(), &gubernator.GetRateLimitsReq{ + Requests: []*gubernator.RateLimitReq{ + { + Name: keys[idx], + UniqueKey: "high_contention_", + Algorithm: gubernator.Algorithm_TOKEN_BUCKET, + Duration: gubernator.Minute * 60, + Limit: numKeys, + Hits: 1, + }, + }, + }) + require.NoError(t, err) + } + wg.Done() + }() + ready.Done() + } + wg.Wait() + + for idx := 0; idx < numKeys; idx++ { + resp, err := client.GetRateLimits(context.Background(), &gubernator.GetRateLimitsReq{ + Requests: []*gubernator.RateLimitReq{ + { + Name: keys[idx], + UniqueKey: "high_contention_", + Algorithm: gubernator.Algorithm_TOKEN_BUCKET, + Duration: gubernator.Minute * 60, + Limit: numKeys, + Hits: 0, + }, + }, + }) + require.NoError(t, err) + assert.Equal(t, int64(0), resp.Responses[0].Remaining) + } + + assert.NoError(t, srv.Close()) +} + func TestStore(t *testing.T) { ctx := context.Background() setup := func() (*MockStore2, *v1Server, gubernator.V1Client) { From 1954bede0614b10b7c23a2d6bf267b0994903b9c Mon Sep 17 00:00:00 2001 From: "Derrick J. Wippler" Date: Tue, 14 May 2024 16:24:11 -0500 Subject: [PATCH 3/6] Fixed race conditions in leakybucket --- algorithms.go | 219 ++++++++++++++++++++++++--------------------- cache.go | 1 - cache_manager.go | 19 ++-- lrucache.go | 20 ----- mock_cache_test.go | 5 -- otter.go | 23 ----- store.go | 37 -------- store_test.go | 5 +- 8 files changed, 131 insertions(+), 198 deletions(-) diff --git a/algorithms.go b/algorithms.go index a118294..a84c84c 100644 --- a/algorithms.go +++ b/algorithms.go @@ -34,8 +34,6 @@ type rateContext struct { CacheItem *CacheItem Store Store Cache Cache - // TODO: Remove - InstanceID string } // ### NOTE ### @@ -50,8 +48,6 @@ func tokenBucket(ctx rateContext) (resp *RateLimitResp, err error) { tokenBucketTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("tokenBucket")) defer tokenBucketTimer.ObserveDuration() var ok bool - // TODO: Remove - //fmt.Printf("[%s] tokenBucket()\n", ctx.InstanceID) // Get rate limit from cache hashKey := ctx.Request.HashKey() @@ -69,9 +65,8 @@ func tokenBucket(ctx rateContext) (resp *RateLimitResp, err error) { } // If no item was found, or the item is expired. - if ctx.CacheItem == nil || ctx.CacheItem.IsExpired() { - // Initialize the Token bucket item - rl, err := InitTokenBucketItem(ctx) + if !ok || ctx.CacheItem.IsExpired() { + rl, err := initTokenBucketItem(ctx) if err != nil && errors.Is(err, errAlreadyExistsInCache) { // Someone else added a new token bucket item to the cache for this // rate limit before we did, so we retry by calling ourselves recursively. 
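// The recursive retry above is bounded in practice: tokenBucket() only calls
// itself after another goroutine has just inserted this key, so the next pass
// finds the item via GetItem() and takes the normal path. A rough iterative
// rendering of the same control flow (hypothetical and simplified;
// useExisting stands in for the rest of the function body):
//
//	for {
//		item, ok := cache.GetItem(hashKey)
//		if ok && !item.IsExpired() {
//			return useExisting(item)
//		}
//		rl, err := initTokenBucketItem(ctx)
//		if !errors.Is(err, errAlreadyExistsInCache) {
//			return rl, err
//		}
//		// lost the insert race; loop and re-read the winner's item
//	}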
@@ -82,7 +77,6 @@ func tokenBucket(ctx rateContext) (resp *RateLimitResp, err error) { // Gain exclusive rights to this item while we calculate the rate limit ctx.CacheItem.mutex.Lock() - defer ctx.CacheItem.mutex.Unlock() t, ok := ctx.CacheItem.Value.(*TokenBucketItem) if !ok { @@ -91,15 +85,18 @@ func tokenBucket(ctx rateContext) (resp *RateLimitResp, err error) { if ctx.Store != nil { ctx.Store.Remove(ctx, hashKey) } + // Tell init to create a new cache item + ctx.CacheItem.mutex.Unlock() ctx.CacheItem = nil - - rl, err := InitTokenBucketItem(ctx) + rl, err := initTokenBucketItem(ctx) if err != nil && errors.Is(err, errAlreadyExistsInCache) { return tokenBucket(ctx) } return rl, err } + defer ctx.CacheItem.mutex.Unlock() + if HasBehavior(ctx.Request.Behavior, Behavior_RESET_REMAINING) { t.Remaining = ctx.Request.Limit t.Limit = ctx.Request.Limit @@ -213,8 +210,8 @@ func tokenBucket(ctx rateContext) (resp *RateLimitResp, err error) { return rl, nil } -// InitTokenBucketItem will create a new item if the passed item is nil, else it will update the provided item. -func InitTokenBucketItem(ctx rateContext) (resp *RateLimitResp, err error) { +// initTokenBucketItem will create a new item if the passed item is nil, else it will update the provided item. +func initTokenBucketItem(ctx rateContext) (resp *RateLimitResp, err error) { createdAt := *ctx.Request.CreatedAt expire := createdAt + ctx.Request.Duration @@ -254,14 +251,14 @@ func InitTokenBucketItem(ctx rateContext) (resp *RateLimitResp, err error) { // If the cache item already exists, update it if ctx.CacheItem != nil { ctx.CacheItem.mutex.Lock() - ctx.CacheItem.Algorithm = Algorithm_TOKEN_BUCKET + ctx.CacheItem.Algorithm = ctx.Request.Algorithm ctx.CacheItem.ExpireAt = expire in, ok := ctx.CacheItem.Value.(*TokenBucketItem) if !ok { - // Likely the store gave us the wrong cache type + // Likely the store gave us the wrong cache item ctx.CacheItem.mutex.Unlock() ctx.CacheItem = nil - return InitTokenBucketItem(ctx) + return initTokenBucketItem(ctx) } *in = t ctx.CacheItem.mutex.Unlock() @@ -286,134 +283,136 @@ func InitTokenBucketItem(ctx rateContext) (resp *RateLimitResp, err error) { } // Implements leaky bucket algorithm for rate limiting https://en.wikipedia.org/wiki/Leaky_bucket -func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqState RateLimitReqState) (resp *RateLimitResp, err error) { +func leakyBucket(ctx rateContext) (resp *RateLimitResp, err error) { leakyBucketTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.getRateLimit_leakyBucket")) defer leakyBucketTimer.ObserveDuration() + var ok bool - // TODO(thrawn01): Test for race conditions, and fix - - if r.Burst == 0 { - r.Burst = r.Limit + if ctx.Request.Burst == 0 { + ctx.Request.Burst = ctx.Request.Limit } - createdAt := *r.CreatedAt - // Get rate limit from cache. - hashKey := r.HashKey() - item, ok := c.GetItem(hashKey) + hashKey := ctx.Request.HashKey() + ctx.CacheItem, ok = ctx.Cache.GetItem(hashKey) - if s != nil && !ok { + if ctx.Store != nil && !ok { // Cache missed, check our store for the item. - if item, ok = s.Get(ctx, r); ok { - if !c.Add(item) { + if ctx.CacheItem, ok = ctx.Store.Get(ctx, ctx.Request); ok { + if !ctx.Cache.Add(ctx.CacheItem) { // Someone else added a new leaky bucket item to the cache for this // rate limit before we did, so we retry by calling ourselves recursively. 
- return leakyBucket(ctx, s, c, r, reqState) + return leakyBucket(ctx) } } } - if !ok { - rl, err := leakyBucketNewItem(ctx, s, c, r, reqState) + // If no item was found, or the item is expired. + if !ok || ctx.CacheItem.IsExpired() { + rl, err := initLeakyBucketItem(ctx) if err != nil && errors.Is(err, errAlreadyExistsInCache) { // Someone else added a new leaky bucket item to the cache for this // rate limit before we did, so we retry by calling ourselves recursively. - return leakyBucket(ctx, s, c, r, reqState) + return leakyBucket(ctx) } return rl, err } + // Gain exclusive rights to this item while we calculate the rate limit + ctx.CacheItem.mutex.Lock() + // Item found in cache or store. - b, ok := item.Value.(*LeakyBucketItem) + t, ok := ctx.CacheItem.Value.(*LeakyBucketItem) if !ok { // Client switched algorithms; perhaps due to a migration? - c.Remove(hashKey) - if s != nil { - s.Remove(ctx, hashKey) + ctx.Cache.Remove(hashKey) + if ctx.Store != nil { + ctx.Store.Remove(ctx, hashKey) } - - rl, err := leakyBucketNewItem(ctx, s, c, r, reqState) + // Tell init to create a new cache item + ctx.CacheItem.mutex.Unlock() + ctx.CacheItem = nil + rl, err := initLeakyBucketItem(ctx) if err != nil && errors.Is(err, errAlreadyExistsInCache) { - return leakyBucket(ctx, s, c, r, reqState) + return leakyBucket(ctx) } return rl, err } - // Gain exclusive rights to this item while we calculate the rate limit - b.mutex.Lock() - defer b.mutex.Unlock() + defer ctx.CacheItem.mutex.Unlock() - if HasBehavior(r.Behavior, Behavior_RESET_REMAINING) { - b.Remaining = float64(r.Burst) + if HasBehavior(ctx.Request.Behavior, Behavior_RESET_REMAINING) { + t.Remaining = float64(ctx.Request.Burst) } // Update burst, limit and duration if they changed - if b.Burst != r.Burst { - if r.Burst > int64(b.Remaining) { - b.Remaining = float64(r.Burst) + if t.Burst != ctx.Request.Burst { + if ctx.Request.Burst > int64(t.Remaining) { + t.Remaining = float64(ctx.Request.Burst) } - b.Burst = r.Burst + t.Burst = ctx.Request.Burst } - b.Limit = r.Limit - b.Duration = r.Duration + t.Limit = ctx.Request.Limit + t.Duration = ctx.Request.Duration - duration := r.Duration - rate := float64(duration) / float64(r.Limit) + duration := ctx.Request.Duration + rate := float64(duration) / float64(ctx.Request.Limit) - if HasBehavior(r.Behavior, Behavior_DURATION_IS_GREGORIAN) { - d, err := GregorianDuration(clock.Now(), r.Duration) + if HasBehavior(ctx.Request.Behavior, Behavior_DURATION_IS_GREGORIAN) { + d, err := GregorianDuration(clock.Now(), ctx.Request.Duration) if err != nil { return nil, err } n := clock.Now() - expire, err := GregorianExpiration(n, r.Duration) + expire, err := GregorianExpiration(n, ctx.Request.Duration) if err != nil { return nil, err } // Calculate the rate using the entire duration of the gregorian interval // IE: Minute = 60,000 milliseconds, etc.. etc.. 
- rate = float64(d) / float64(r.Limit) + rate = float64(d) / float64(ctx.Request.Limit) // Update the duration to be the end of the gregorian interval duration = expire - (n.UnixNano() / 1000000) } - if r.Hits != 0 { - c.UpdateExpiration(r.HashKey(), createdAt+duration) + createdAt := *ctx.Request.CreatedAt + if ctx.Request.Hits != 0 { + ctx.CacheItem.ExpireAt = createdAt + duration } // Calculate how much leaked out of the bucket since the last time we leaked a hit - elapsed := createdAt - b.UpdatedAt + elapsed := createdAt - t.UpdatedAt leak := float64(elapsed) / rate if int64(leak) > 0 { - b.Remaining += leak - b.UpdatedAt = createdAt + t.Remaining += leak + t.UpdatedAt = createdAt } - if int64(b.Remaining) > b.Burst { - b.Remaining = float64(b.Burst) + if int64(t.Remaining) > t.Burst { + t.Remaining = float64(t.Burst) } rl := &RateLimitResp{ - Limit: b.Limit, - Remaining: int64(b.Remaining), + Limit: t.Limit, + Remaining: int64(t.Remaining), Status: Status_UNDER_LIMIT, - ResetTime: createdAt + (b.Limit-int64(b.Remaining))*int64(rate), + ResetTime: createdAt + (t.Limit-int64(t.Remaining))*int64(rate), } // TODO: Feature missing: check for Duration change between item/request. - if s != nil && reqState.IsOwner { + if ctx.Store != nil && ctx.ReqState.IsOwner { defer func() { - s.OnChange(ctx, r, item) + ctx.Store.OnChange(ctx, ctx.Request, ctx.CacheItem) }() } // If we are already at the limit - if int64(b.Remaining) == 0 && r.Hits > 0 { - if reqState.IsOwner { + if int64(t.Remaining) == 0 && ctx.Request.Hits > 0 { + if ctx.ReqState.IsOwner { metricOverLimitCounter.Add(1) } rl.Status = Status_OVER_LIMIT @@ -421,24 +420,24 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqStat } // If requested hits takes the remainder - if int64(b.Remaining) == r.Hits { - b.Remaining = 0 - rl.Remaining = int64(b.Remaining) + if int64(t.Remaining) == ctx.Request.Hits { + t.Remaining = 0 + rl.Remaining = int64(t.Remaining) rl.ResetTime = createdAt + (rl.Limit-rl.Remaining)*int64(rate) return rl, nil } // If requested is more than available, then return over the limit // without updating the bucket, unless `DRAIN_OVER_LIMIT` is set. - if r.Hits > int64(b.Remaining) { - if reqState.IsOwner { + if ctx.Request.Hits > int64(t.Remaining) { + if ctx.ReqState.IsOwner { metricOverLimitCounter.Add(1) } rl.Status = Status_OVER_LIMIT // DRAIN_OVER_LIMIT behavior drains the remaining counter. - if HasBehavior(r.Behavior, Behavior_DRAIN_OVER_LIMIT) { - b.Remaining = 0 + if HasBehavior(ctx.Request.Behavior, Behavior_DRAIN_OVER_LIMIT) { + t.Remaining = 0 rl.Remaining = 0 } @@ -446,25 +445,25 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqStat } // Client is only interested in retrieving the current status - if r.Hits == 0 { + if ctx.Request.Hits == 0 { return rl, nil } - b.Remaining -= float64(r.Hits) - rl.Remaining = int64(b.Remaining) + t.Remaining -= float64(ctx.Request.Hits) + rl.Remaining = int64(t.Remaining) rl.ResetTime = createdAt + (rl.Limit-rl.Remaining)*int64(rate) return rl, nil } // Called by leakyBucket() when adding a new item in the store. 
-func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqState RateLimitReqState) (resp *RateLimitResp, err error) { - createdAt := *r.CreatedAt - duration := r.Duration - rate := float64(duration) / float64(r.Limit) - if HasBehavior(r.Behavior, Behavior_DURATION_IS_GREGORIAN) { +func initLeakyBucketItem(ctx rateContext) (resp *RateLimitResp, err error) { + createdAt := *ctx.Request.CreatedAt + duration := ctx.Request.Duration + rate := float64(duration) / float64(ctx.Request.Limit) + if HasBehavior(ctx.Request.Behavior, Behavior_DURATION_IS_GREGORIAN) { n := clock.Now() - expire, err := GregorianExpiration(n, r.Duration) + expire, err := GregorianExpiration(n, ctx.Request.Duration) if err != nil { return nil, err } @@ -475,23 +474,23 @@ func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, // Create a new leaky bucket b := LeakyBucketItem{ - Remaining: float64(r.Burst - r.Hits), - Limit: r.Limit, + Remaining: float64(ctx.Request.Burst - ctx.Request.Hits), + Limit: ctx.Request.Limit, Duration: duration, UpdatedAt: createdAt, - Burst: r.Burst, + Burst: ctx.Request.Burst, } rl := RateLimitResp{ Status: Status_UNDER_LIMIT, Limit: b.Limit, - Remaining: r.Burst - r.Hits, - ResetTime: createdAt + (b.Limit-(r.Burst-r.Hits))*int64(rate), + Remaining: ctx.Request.Burst - ctx.Request.Hits, + ResetTime: createdAt + (b.Limit-(ctx.Request.Burst-ctx.Request.Hits))*int64(rate), } // Client could be requesting that we start with the bucket OVER_LIMIT - if r.Hits > r.Burst { - if reqState.IsOwner { + if ctx.Request.Hits > ctx.Request.Burst { + if ctx.ReqState.IsOwner { metricOverLimitCounter.Add(1) } rl.Status = Status_OVER_LIMIT @@ -500,19 +499,33 @@ func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, b.Remaining = 0 } - item := &CacheItem{ - ExpireAt: createdAt + duration, - Algorithm: r.Algorithm, - Key: r.HashKey(), - Value: &b, - } - - if !c.Add(item) { - return nil, errAlreadyExistsInCache + if ctx.CacheItem != nil { + ctx.CacheItem.mutex.Lock() + ctx.CacheItem.Algorithm = ctx.Request.Algorithm + ctx.CacheItem.ExpireAt = createdAt + duration + in, ok := ctx.CacheItem.Value.(*LeakyBucketItem) + if !ok { + // Likely the store gave us the wrong cache item + ctx.CacheItem.mutex.Unlock() + ctx.CacheItem = nil + return initLeakyBucketItem(ctx) + } + *in = b + ctx.CacheItem.mutex.Unlock() + } else { + ctx.CacheItem = &CacheItem{ + ExpireAt: createdAt + duration, + Algorithm: ctx.Request.Algorithm, + Key: ctx.Request.HashKey(), + Value: &b, + } + if !ctx.Cache.Add(ctx.CacheItem) { + return nil, errAlreadyExistsInCache + } } - if s != nil && reqState.IsOwner { - s.OnChange(ctx, r, item) + if ctx.Store != nil && ctx.ReqState.IsOwner { + ctx.Store.OnChange(ctx, ctx.Request, ctx.CacheItem) } return &rl, nil diff --git a/cache.go b/cache.go index dbeea21..70631e4 100644 --- a/cache.go +++ b/cache.go @@ -20,7 +20,6 @@ import "sync" type Cache interface { Add(item *CacheItem) bool - UpdateExpiration(key string, expireAt int64) bool GetItem(key string) (value *CacheItem, ok bool) Each() chan *CacheItem Remove(key string) diff --git a/cache_manager.go b/cache_manager.go index 3514ce1..90555d4 100644 --- a/cache_manager.go +++ b/cache_manager.go @@ -59,12 +59,11 @@ func (m *cacheManager) GetRateLimit(ctx context.Context, req *RateLimitReq, stat switch req.Algorithm { case Algorithm_TOKEN_BUCKET: rlResponse, err = tokenBucket(rateContext{ - Store: m.conf.Store, - Cache: m.cache, - ReqState: state, - Request: req, - Context: ctx, - 
InstanceID: m.conf.InstanceID, + Store: m.conf.Store, + Cache: m.cache, + ReqState: state, + Request: req, + Context: ctx, }) if err != nil { msg := "Error in tokenBucket" @@ -72,7 +71,13 @@ func (m *cacheManager) GetRateLimit(ctx context.Context, req *RateLimitReq, stat } case Algorithm_LEAKY_BUCKET: - rlResponse, err = leakyBucket(ctx, m.conf.Store, m.cache, req, state) + rlResponse, err = leakyBucket(rateContext{ + Store: m.conf.Store, + Cache: m.cache, + ReqState: state, + Request: req, + Context: ctx, + }) if err != nil { msg := "Error in leakyBucket" countError(err, msg) diff --git a/lrucache.go b/lrucache.go index 5bef041..83fad9a 100644 --- a/lrucache.go +++ b/lrucache.go @@ -119,13 +119,6 @@ func (c *LRUCache) GetItem(key string) (item *CacheItem, ok bool) { if ele, hit := c.cache[key]; hit { entry := ele.Value.(*CacheItem) - // TODO(thrawn01): Remove - //if entry.IsExpired() { - // c.removeElement(ele) - // metricCacheAccess.WithLabelValues("miss").Add(1) - // return - //} - metricCacheAccess.WithLabelValues("hit").Add(1) c.ll.MoveToFront(ele) return entry, true @@ -168,19 +161,6 @@ func (c *LRUCache) Size() int64 { return atomic.LoadInt64(&c.cacheLen) } -// UpdateExpiration updates the expiration time for the key -func (c *LRUCache) UpdateExpiration(key string, expireAt int64) bool { - c.mu.Lock() - defer c.mu.Unlock() - - if ele, hit := c.cache[key]; hit { - entry := ele.Value.(*CacheItem) - entry.ExpireAt = expireAt - return true - } - return false -} - func (c *LRUCache) Close() error { c.cache = nil c.ll = nil diff --git a/mock_cache_test.go b/mock_cache_test.go index 3eea640..15f12cb 100644 --- a/mock_cache_test.go +++ b/mock_cache_test.go @@ -34,11 +34,6 @@ func (m *MockCache) Add(item *guber.CacheItem) bool { return args.Bool(0) } -func (m *MockCache) UpdateExpiration(key string, expireAt int64) bool { - args := m.Called(key, expireAt) - return args.Bool(0) -} - func (m *MockCache) GetItem(key string) (value *guber.CacheItem, ok bool) { args := m.Called(key) retval, _ := args.Get(0).(*guber.CacheItem) diff --git a/otter.go b/otter.go index 9f60c55..992bbfc 100644 --- a/otter.go +++ b/otter.go @@ -50,33 +50,10 @@ func (o *OtterCache) GetItem(key string) (*CacheItem, bool) { return nil, false } - // TODO(thrawn01): Remove - //if item.IsExpired() { - // metricCacheAccess.WithLabelValues("miss").Add(1) - // // If the item is expired, just return `nil` - // // - // // We avoid the explicit deletion of the expired item to avoid acquiring a mutex lock in otter. - // // Explicit deletions in otter require a mutex, which can cause performance bottlenecks - // // under high concurrency scenarios. By allowing the item to be evicted naturally by - // // otter's eviction mechanism, we avoid impacting performance under high concurrency. - // return nil, false - //} metricCacheAccess.WithLabelValues("hit").Add(1) return item, true } -// UpdateExpiration will update an item in the cache with a new expiration date. -// returns true if the item exists in the cache and was updated. -func (o *OtterCache) UpdateExpiration(key string, expireAt int64) bool { - item, ok := o.cache.Get(key) - if !ok { - return false - } - - item.ExpireAt = expireAt - return true -} - // Each returns a channel which the call can use to iterate through // all the items in the cache. 
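 // A usage sketch (iterate until the channel is closed; process is a
 // hypothetical callback, not part of this package):
 //
 //	for item := range cache.Each() {
 //		process(item)
 //	}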
func (o *OtterCache) Each() chan *CacheItem { diff --git a/store.go b/store.go index 089ea50..b96868f 100644 --- a/store.go +++ b/store.go @@ -18,7 +18,6 @@ package gubernator import ( "context" - "sync" ) // PERSISTENT STORE DETAILS @@ -30,7 +29,6 @@ import ( // Both interfaces can be implemented simultaneously to ensure data is always saved to persistent storage. type LeakyBucketItem struct { - mutex sync.Mutex Limit int64 Duration int64 Remaining float64 @@ -81,41 +79,6 @@ type Loader interface { Save(chan *CacheItem) error } -// TODO Remove -//func NewMockStore() *MockStore { -// ml := &MockStore{ -// Called: make(map[string]int), -// CacheItems: make(map[string]*CacheItem), -// } -// ml.Called["OnChange()"] = 0 -// ml.Called["Remove()"] = 0 -// ml.Called["Get()"] = 0 -// return ml -//} -// -//type MockStore struct { -// Called map[string]int -// CacheItems map[string]*CacheItem -//} -// -//var _ Store = &MockStore{} -// -//func (ms *MockStore) OnChange(ctx context.Context, r *RateLimitReq, item *CacheItem) { -// ms.Called["OnChange()"] += 1 -// ms.CacheItems[item.Key] = item -//} -// -//func (ms *MockStore) Get(ctx context.Context, r *RateLimitReq) (*CacheItem, bool) { -// ms.Called["Get()"] += 1 -// item, ok := ms.CacheItems[r.HashKey()] -// return item, ok -//} -// -//func (ms *MockStore) Remove(ctx context.Context, key string) { -// ms.Called["Remove()"] += 1 -// delete(ms.CacheItems, key) -//} - func NewMockLoader() *MockLoader { ml := &MockLoader{ Called: make(map[string]int), diff --git a/store_test.go b/store_test.go index ff29df0..f3e1122 100644 --- a/store_test.go +++ b/store_test.go @@ -150,8 +150,9 @@ func (ms *NoOpStore) Get(ctx context.Context, r *gubernator.RateLimitReq) (*gube // add items to the cache in parallel thus creating a race condition the code must then handle. func TestHighContentionFromStore(t *testing.T) { const ( - numGoroutines = 1_000 - numKeys = 400 + // Increase these number to improve the chance of contention, but at the cost of test speed. + numGoroutines = 500 + numKeys = 100 ) store := &NoOpStore{} srv := newV1Server(t, "localhost:0", gubernator.Config{ From 117ddbb7f901a31794e72f9b0f87586b78f82b6f Mon Sep 17 00:00:00 2001 From: "Derrick J. Wippler" Date: Tue, 14 May 2024 17:42:04 -0500 Subject: [PATCH 4/6] Added Otter cost func, and reduced memory and time it takes to run the benchmarks --- Makefile | 2 +- benchmark_cache_test.go | 3 +-- cache.go | 16 ++++++++++------ otter.go | 9 ++++++++- store.go | 23 +++++++++++++---------- 5 files changed, 33 insertions(+), 20 deletions(-) diff --git a/Makefile b/Makefile index 6ecd257..df4b0bd 100644 --- a/Makefile +++ b/Makefile @@ -24,7 +24,7 @@ test: ## Run unit tests and measure code coverage .PHONY: bench bench: ## Run Go benchmarks - go test ./... -bench . -benchtime 5s -timeout 0 -run='^$$' -benchmem + go test ./... -bench . 
-timeout 6m -run='^$$' -benchmem

.PHONY: docker
docker: ## Build Docker image
diff --git a/benchmark_cache_test.go b/benchmark_cache_test.go
index 1f849d8..5e24bb0 100644
--- a/benchmark_cache_test.go
+++ b/benchmark_cache_test.go
@@ -12,6 +12,7 @@ import (
 )
 
 func BenchmarkCache(b *testing.B) {
+    const defaultNumKeys = 8192
     testCases := []struct {
         Name         string
         NewTestCache func() (gubernator.Cache, error)
@@ -172,8 +173,6 @@ func BenchmarkCache(b *testing.B) {
     }
 }
 
-const defaultNumKeys = 32768
-
 func GenerateRandomKeys(size int) []string {
     keys := make([]string, 0, size)
     for i := 0; i < size; i++ {
diff --git a/cache.go b/cache.go
index 70631e4..9a318ad 100644
--- a/cache.go
+++ b/cache.go
@@ -27,19 +27,23 @@ type Cache interface {
     Close() error
 }
 
+// CacheItem is 64 bytes in size, including alignment padding.
+// Since TokenBucketItem and LeakyBucketItem are both 40 bytes in size, a CacheItem with
+// the Value attached takes up 64 + 40 = 104 bytes of space, not counting the size of the key.
 type CacheItem struct {
-    mutex     sync.Mutex
-    Algorithm Algorithm
-    Key       string
-    Value     interface{}
+    mutex sync.Mutex  // 8 bytes
+    Key   string      // 16 bytes
+    Value interface{} // 16 bytes
 
     // Timestamp when rate limit expires in epoch milliseconds.
-    ExpireAt int64
+    ExpireAt int64 // 8 bytes
     // Timestamp when the cache should invalidate this rate limit. This is useful when used in conjunction with
     // a persistent store to ensure our node has the most up to date info from the store. Ignored if set to `0`
     // It is set by the persistent store implementation to indicate when the node should query the persistent store
     // for the latest rate limit data.
-    InvalidAt int64
+    InvalidAt int64     // 8 bytes
+    Algorithm Algorithm // 4 bytes
+    // 4 bytes of padding
 }
 
 func (item *CacheItem) IsExpired() bool {
diff --git a/otter.go b/otter.go
index 992bbfc..31d7d1c 100644
--- a/otter.go
+++ b/otter.go
@@ -14,7 +14,8 @@ type OtterCache struct {
 // NewOtterCache returns a new cache backed by otter. If size is 0, then
 // the cache is created with a default cache size.
 func NewOtterCache(size int) (*OtterCache, error) {
-    setter.SetDefault(&size, 150_000)
+    // Default is 500k bytes
+    setter.SetDefault(&size, 500_000)
     b, err := otter.NewBuilder[string, *CacheItem](size)
     if err != nil {
         return nil, fmt.Errorf("during otter.NewBuilder(): %w", err)
@@ -26,6 +27,12 @@ func NewOtterCache(size int) (*OtterCache, error) {
         }
     })
 
+    b.Cost(func(key string, value *CacheItem) uint32 {
+        // The total size of the CacheItem and bucket item is 104 bytes.
+        // See cache.go:CacheItem definition for details.
+        return uint32(104 + len(value.Key))
+    })
+
     cache, err := b.Build()
     if err != nil {
         return nil, fmt.Errorf("during otter.Builder.Build(): %w", err)
diff --git a/store.go b/store.go
index b96868f..c542d87 100644
--- a/store.go
+++ b/store.go
@@ -28,20 +28,23 @@ import (
 // and `Get()` to keep the in memory cache and persistent store up to date with the latest ratelimit data.
 // Both interfaces can be implemented simultaneously to ensure data is always saved to persistent storage.
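// The 64- and 40-byte figures quoted in the struct comments above and below
// can be sanity-checked with unsafe.Sizeof. A standalone sketch using mirror
// structs with the same field layout ([8]byte stands in for sync.Mutex,
// which is also 8 bytes; the real types live in cache.go and store.go, and
// the results assume a 64-bit platform):
package main

import (
    "fmt"
    "unsafe"
)

type cacheItem struct {
    mutex     [8]byte     // stand-in for sync.Mutex, 8 bytes
    key       string      // 16 bytes
    value     interface{} // 16 bytes
    expireAt  int64       // 8 bytes
    invalidAt int64       // 8 bytes
    algorithm int32       // 4 bytes, then 4 bytes of padding
}

type leakyBucketItem struct {
    limit     int64
    duration  int64
    remaining float64
    updatedAt int64
    burst     int64
}

func main() {
    fmt.Println(unsafe.Sizeof(cacheItem{}))       // 64
    fmt.Println(unsafe.Sizeof(leakyBucketItem{})) // 40
    // 64 + 40 = 104, the fixed per-item cost used by the otter Cost
    // function above (plus len(key) for the key bytes).
}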
+// LeakyBucketItem is 40 bytes aligned in size type LeakyBucketItem struct { - Limit int64 - Duration int64 - Remaining float64 - UpdatedAt int64 - Burst int64 + Limit int64 // 8 bytes + Duration int64 // 8 bytes + Remaining float64 // 8 bytes + UpdatedAt int64 // 8 bytes + Burst int64 // 8 bytes } +// TokenBucketItem is 40 bytes aligned in size type TokenBucketItem struct { - Status Status - Limit int64 - Duration int64 - Remaining int64 - CreatedAt int64 + Limit int64 // 8 bytes + Duration int64 // 8 bytes + Remaining int64 // 8 bytes + CreatedAt int64 // 8 bytes + Status Status // 4 bytes + // 4 bytes of padding } // Store interface allows implementors to off load storage of all or a subset of ratelimits to From 6dff201bba9eec4b16197cc30eb273776c566478 Mon Sep 17 00:00:00 2001 From: "Derrick J. Wippler" Date: Tue, 14 May 2024 20:39:42 -0500 Subject: [PATCH 5/6] Fixed flapping cache eviction test --- cache.go | 1 + lrucache.go | 111 ++++++++++++++++++++++++++++----------------- lrucache_test.go | 3 +- mock_cache_test.go | 4 ++ otter.go | 26 ++++++++--- 5 files changed, 95 insertions(+), 50 deletions(-) diff --git a/cache.go b/cache.go index 9a318ad..ed1444d 100644 --- a/cache.go +++ b/cache.go @@ -24,6 +24,7 @@ type Cache interface { Each() chan *CacheItem Remove(key string) Size() int64 + Stats() CacheStats Close() error } diff --git a/lrucache.go b/lrucache.go index 83fad9a..d3690d4 100644 --- a/lrucache.go +++ b/lrucache.go @@ -34,31 +34,19 @@ type LRUCache struct { cache map[string]*list.Element ll *list.List mu sync.Mutex + stats CacheStats cacheSize int cacheLen int64 } -// CacheCollector provides prometheus metrics collector for LRUCache. -// Register only one collector, add one or more caches to this collector. -type CacheCollector struct { - caches []Cache +type CacheStats struct { + Size int64 + Hit int64 + Miss int64 + UnexpiredEvictions int64 } var _ Cache = &LRUCache{} -var _ prometheus.Collector = &CacheCollector{} - -var metricCacheSize = prometheus.NewGauge(prometheus.GaugeOpts{ - Name: "gubernator_cache_size", - Help: "The number of items in LRU Cache which holds the rate limits.", -}) -var metricCacheAccess = prometheus.NewCounterVec(prometheus.CounterOpts{ - Name: "gubernator_cache_access_count", - Help: "Cache access counts. Label \"type\" = hit|miss.", -}, []string{"type"}) -var metricCacheUnexpiredEvictions = prometheus.NewCounter(prometheus.CounterOpts{ - Name: "gubernator_unexpired_evictions_count", - Help: "Count the number of cache items which were evicted while unexpired.", -}) // NewLRUCache creates a new Cache with a maximum size. 
func NewLRUCache(maxSize int) *LRUCache { @@ -119,12 +107,12 @@ func (c *LRUCache) GetItem(key string) (item *CacheItem, ok bool) { if ele, hit := c.cache[key]; hit { entry := ele.Value.(*CacheItem) - metricCacheAccess.WithLabelValues("hit").Add(1) + c.stats.Hit++ c.ll.MoveToFront(ele) return entry, true } - metricCacheAccess.WithLabelValues("miss").Add(1) + c.stats.Miss++ return } @@ -142,7 +130,7 @@ func (c *LRUCache) removeOldest() { entry := ele.Value.(*CacheItem) if MillisecondNow() < entry.ExpireAt { - metricCacheUnexpiredEvictions.Add(1) + c.stats.UnexpiredEvictions++ } c.removeElement(ele) @@ -168,38 +156,79 @@ func (c *LRUCache) Close() error { return nil } +// Stats returns the current status for the cache +func (c *LRUCache) Stats() CacheStats { + c.mu.Lock() + defer func() { + c.stats = CacheStats{} + c.mu.Unlock() + }() + + c.stats.Size = atomic.LoadInt64(&c.cacheLen) + return c.stats +} + +// CacheCollector provides prometheus metrics collector for LRUCache. +// Register only one collector, add one or more caches to this collector. +type CacheCollector struct { + caches []Cache + metricSize prometheus.Gauge + metricAccess *prometheus.CounterVec + metricUnexpiredEvictions prometheus.Counter +} + func NewCacheCollector() *CacheCollector { return &CacheCollector{ caches: []Cache{}, + metricSize: prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "gubernator_cache_size", + Help: "The number of items in LRU Cache which holds the rate limits.", + }), + metricAccess: prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "gubernator_cache_access_count", + Help: "Cache access counts. Label \"type\" = hit|miss.", + }, []string{"type"}), + metricUnexpiredEvictions: prometheus.NewCounter(prometheus.CounterOpts{ + Name: "gubernator_unexpired_evictions_count", + Help: "Count the number of cache items which were evicted while unexpired.", + }), } } +var _ prometheus.Collector = &CacheCollector{} + // AddCache adds a Cache object to be tracked by the collector. 
-func (collector *CacheCollector) AddCache(cache Cache) { - collector.caches = append(collector.caches, cache) +func (c *CacheCollector) AddCache(cache Cache) { + c.caches = append(c.caches, cache) } // Describe fetches prometheus metrics to be registered -func (collector *CacheCollector) Describe(ch chan<- *prometheus.Desc) { - metricCacheSize.Describe(ch) - metricCacheAccess.Describe(ch) - metricCacheUnexpiredEvictions.Describe(ch) +func (c *CacheCollector) Describe(ch chan<- *prometheus.Desc) { + c.metricSize.Describe(ch) + c.metricAccess.Describe(ch) + c.metricUnexpiredEvictions.Describe(ch) } // Collect fetches metric counts and gauges from the cache -func (collector *CacheCollector) Collect(ch chan<- prometheus.Metric) { - metricCacheSize.Set(collector.getSize()) - metricCacheSize.Collect(ch) - metricCacheAccess.Collect(ch) - metricCacheUnexpiredEvictions.Collect(ch) -} - -func (collector *CacheCollector) getSize() float64 { - var size float64 - - for _, cache := range collector.caches { - size += float64(cache.Size()) +func (c *CacheCollector) Collect(ch chan<- prometheus.Metric) { + stats := c.getStats() + c.metricSize.Set(float64(stats.Size)) + c.metricSize.Collect(ch) + c.metricAccess.WithLabelValues("miss").Add(float64(stats.Miss)) + c.metricAccess.WithLabelValues("hit").Add(float64(stats.Hit)) + c.metricAccess.Collect(ch) + c.metricUnexpiredEvictions.Add(float64(stats.UnexpiredEvictions)) + c.metricUnexpiredEvictions.Collect(ch) +} + +func (c *CacheCollector) getStats() CacheStats { + var total CacheStats + for _, cache := range c.caches { + stats := cache.Stats() + total.Hit += stats.Hit + total.Miss += stats.Miss + total.Size += stats.Size + total.UnexpiredEvictions += stats.UnexpiredEvictions } - - return size + return total } diff --git a/lrucache_test.go b/lrucache_test.go index 402dd83..91262ec 100644 --- a/lrucache_test.go +++ b/lrucache_test.go @@ -27,10 +27,9 @@ import ( "github.com/gubernator-io/gubernator/v2" "github.com/mailgun/holster/v4/clock" "github.com/prometheus/client_golang/prometheus" + dto "github.com/prometheus/client_model/go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - dto "github.com/prometheus/client_model/go" ) func TestLRUCache(t *testing.T) { diff --git a/mock_cache_test.go b/mock_cache_test.go index 15f12cb..536a93b 100644 --- a/mock_cache_test.go +++ b/mock_cache_test.go @@ -55,6 +55,10 @@ func (m *MockCache) Size() int64 { return int64(args.Int(0)) } +func (m *MockCache) Stats() guber.CacheStats { + return guber.CacheStats{} +} + func (m *MockCache) Close() error { args := m.Called() return args.Error(0) diff --git a/otter.go b/otter.go index 31d7d1c..3ba29f9 100644 --- a/otter.go +++ b/otter.go @@ -2,6 +2,7 @@ package gubernator import ( "fmt" + "sync/atomic" "github.com/mailgun/holster/v4/setter" "github.com/maypok86/otter" @@ -9,6 +10,7 @@ import ( type OtterCache struct { cache otter.Cache[string, *CacheItem] + stats CacheStats } // NewOtterCache returns a new cache backed by otter. 
If size is 0, then @@ -21,9 +23,11 @@ func NewOtterCache(size int) (*OtterCache, error) { return nil, fmt.Errorf("during otter.NewBuilder(): %w", err) } + o := &OtterCache{} + b.DeletionListener(func(key string, value *CacheItem, cause otter.DeletionCause) { if cause == otter.Size { - metricCacheUnexpiredEvictions.Add(1) + atomic.AddInt64(&o.stats.UnexpiredEvictions, 1) } }) @@ -33,13 +37,11 @@ func NewOtterCache(size int) (*OtterCache, error) { return uint32(104 + len(value.Key)) }) - cache, err := b.Build() + o.cache, err = b.Build() if err != nil { return nil, fmt.Errorf("during otter.Builder.Build(): %w", err) } - return &OtterCache{ - cache: cache, - }, nil + return o, nil } // Add adds a new CacheItem to the cache. The key must be provided via CacheItem.Key @@ -53,11 +55,11 @@ func (o *OtterCache) Add(item *CacheItem) bool { func (o *OtterCache) GetItem(key string) (*CacheItem, bool) { item, ok := o.cache.Get(key) if !ok { - metricCacheAccess.WithLabelValues("miss").Add(1) + atomic.AddInt64(&o.stats.Miss, 1) return nil, false } - metricCacheAccess.WithLabelValues("hit").Add(1) + atomic.AddInt64(&o.stats.Hit, 1) return item, true } @@ -89,6 +91,16 @@ func (o *OtterCache) Size() int64 { return int64(o.cache.Size()) } +// Stats returns the current cache stats and resets the values to zero +func (o *OtterCache) Stats() CacheStats { + var result CacheStats + result.UnexpiredEvictions = atomic.SwapInt64(&o.stats.UnexpiredEvictions, 0) + result.Miss = atomic.SwapInt64(&o.stats.Miss, 0) + result.Hit = atomic.SwapInt64(&o.stats.Hit, 0) + result.Size = int64(o.cache.Size()) + return result +} + // Close closes the cache and all associated background processes func (o *OtterCache) Close() error { o.cache.Close() From 4e09e4246621ed21298ba17662a96670e7807373 Mon Sep 17 00:00:00 2001 From: "Derrick J. Wippler" Date: Tue, 14 May 2024 23:42:33 -0500 Subject: [PATCH 6/6] Added LRUMutexCache restored LRUCache --- algorithms.go | 8 +- benchmark_cache_test.go | 17 +- cache.go | 105 ++++++- cache_manager.go | 11 +- cache_manager_test.go | 2 +- cluster/cluster.go | 1 + config.go | 6 + daemon.go | 23 +- example.conf | 11 +- lrucache.go | 154 ++++------- lrumutex.go | 163 +++++++++++ lrumutext_test.go | 594 ++++++++++++++++++++++++++++++++++++++++ mock_cache_test.go | 5 + otter.go | 11 +- otter_test.go | 28 +- 15 files changed, 1000 insertions(+), 139 deletions(-) create mode 100644 lrumutex.go create mode 100644 lrumutext_test.go diff --git a/algorithms.go b/algorithms.go index a84c84c..bfeceee 100644 --- a/algorithms.go +++ b/algorithms.go @@ -56,7 +56,7 @@ func tokenBucket(ctx rateContext) (resp *RateLimitResp, err error) { // If not in the cache, check the store if provided if ctx.Store != nil && !ok { if ctx.CacheItem, ok = ctx.Store.Get(ctx, ctx.Request); ok { - if !ctx.Cache.Add(ctx.CacheItem) { + if !ctx.Cache.AddIfNotPresent(ctx.CacheItem) { // Someone else added a new token bucket item to the cache for this // rate limit before we did, so we retry by calling ourselves recursively. return tokenBucket(ctx) @@ -270,7 +270,7 @@ func initTokenBucketItem(ctx rateContext) (resp *RateLimitResp, err error) { Value: &t, ExpireAt: expire, } - if !ctx.Cache.Add(ctx.CacheItem) { + if !ctx.Cache.AddIfNotPresent(ctx.CacheItem) { return rl, errAlreadyExistsInCache } } @@ -299,7 +299,7 @@ func leakyBucket(ctx rateContext) (resp *RateLimitResp, err error) { if ctx.Store != nil && !ok { // Cache missed, check our store for the item. 
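// OtterCache.Stats above drains its counters with atomic.SwapInt64, so each
// Prometheus scrape reports only the activity recorded since the previous
// scrape, without taking any locks. A minimal standalone sketch of that
// snapshot-and-reset pattern:
package main

import (
    "fmt"
    "sync/atomic"
)

type counters struct {
    hit, miss int64
}

func (c *counters) Hit()  { atomic.AddInt64(&c.hit, 1) }
func (c *counters) Miss() { atomic.AddInt64(&c.miss, 1) }

// Snapshot returns the counts accumulated since the last call and zeroes
// them, one atomic step per counter.
func (c *counters) Snapshot() (hit, miss int64) {
    return atomic.SwapInt64(&c.hit, 0), atomic.SwapInt64(&c.miss, 0)
}

func main() {
    var c counters
    c.Hit()
    c.Hit()
    c.Miss()
    fmt.Println(c.Snapshot()) // 2 1
    fmt.Println(c.Snapshot()) // 0 0
}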
if ctx.CacheItem, ok = ctx.Store.Get(ctx, ctx.Request); ok {
-            if !ctx.Cache.Add(ctx.CacheItem) {
+            if !ctx.Cache.AddIfNotPresent(ctx.CacheItem) {
                 // Someone else added a new leaky bucket item to the cache for this
                 // rate limit before we did, so we retry by calling ourselves recursively.
                 return leakyBucket(ctx)
@@ -519,7 +519,7 @@ func initLeakyBucketItem(ctx rateContext) (resp *RateLimitResp, err error) {
             Key:   ctx.Request.HashKey(),
             Value: &b,
         }
-        if !ctx.Cache.Add(ctx.CacheItem) {
+        if !ctx.Cache.AddIfNotPresent(ctx.CacheItem) {
             return nil, errAlreadyExistsInCache
         }
     }
diff --git a/benchmark_cache_test.go b/benchmark_cache_test.go
index 5e24bb0..e2869d3 100644
--- a/benchmark_cache_test.go
+++ b/benchmark_cache_test.go
@@ -25,6 +25,13 @@ func BenchmarkCache(b *testing.B) {
             },
             LockRequired: true,
         },
+        {
+            Name: "LRUMutexCache",
+            NewTestCache: func() (gubernator.Cache, error) {
+                return gubernator.NewLRUMutexCache(0), nil
+            },
+            LockRequired: true,
+        },
         {
             Name: "OtterCache",
             NewTestCache: func() (gubernator.Cache, error) {
@@ -48,7 +55,7 @@
                     Value:    "value:" + key,
                     ExpireAt: expire,
                 }
-                cache.Add(item)
+                cache.AddIfNotPresent(item)
             }
 
             mask := len(keys) - 1
@@ -78,7 +85,7 @@
                     Value:    "value:" + keys[index&mask],
                     ExpireAt: expire,
                 }
-                cache.Add(item)
+                cache.AddIfNotPresent(item)
             }
         })
@@ -94,7 +101,7 @@
                     Value:    "value:" + key,
                     ExpireAt: expire,
                 }
-                cache.Add(item)
+                cache.AddIfNotPresent(item)
             }
 
             var mutex sync.Mutex
@@ -144,7 +151,7 @@
                         Value:    "value:" + key,
                         ExpireAt: expire,
                     }
-                    cache.Add(item)
+                    cache.AddIfNotPresent(item)
                 }
             } else {
                 task = func(key string) {
@@ -153,7 +160,7 @@
                         Value:    "value:" + key,
                         ExpireAt: expire,
                     }
-                    cache.Add(item)
+                    cache.AddIfNotPresent(item)
                 }
             }
diff --git a/cache.go b/cache.go
index ed1444d..c9d095f 100644
--- a/cache.go
+++ b/cache.go
@@ -16,10 +16,25 @@ limitations under the License.
 
 package gubernator
 
-import "sync"
+import (
+    "sync"
+
+    "github.com/prometheus/client_golang/prometheus"
+
+    "github.com/mailgun/holster/v4/clock"
+)
 
 type Cache interface {
+    // Add adds an item, or replaces an item in the cache
+    //
+    // Deprecated: Gubernator algorithms now use AddIfNotPresent.
+    // TODO: Remove this method in v3
     Add(item *CacheItem) bool
+
+    // AddIfNotPresent adds the item to the cache if it doesn't already exist.
+    // Returns true if the item was added, false if the item already exists.
+    AddIfNotPresent(item *CacheItem) bool
+
     GetItem(key string) (value *CacheItem, ok bool)
     Each() chan *CacheItem
     Remove(key string)
@@ -66,3 +81,91 @@ func (item *CacheItem) IsExpired() bool {
 
     return false
 }
+
+func (item *CacheItem) Copy(from *CacheItem) {
+    item.mutex.Lock()
+    defer item.mutex.Unlock()
+
+    item.InvalidAt = from.InvalidAt
+    item.Algorithm = from.Algorithm
+    item.ExpireAt = from.ExpireAt
+    item.Value = from.Value
+    item.Key = from.Key
+}
+
+// MillisecondNow returns unix epoch in milliseconds
+func MillisecondNow() int64 {
+    return clock.Now().UnixNano() / 1000000
+}
+
+type CacheStats struct {
+    Size               int64
+    Hit                int64
+    Miss               int64
+    UnexpiredEvictions int64
+}
+
+// CacheCollector provides a prometheus metrics collector for Cache implementations.
+// Register only one collector, add one or more caches to this collector.
+type CacheCollector struct { + caches []Cache + metricSize prometheus.Gauge + metricAccess *prometheus.CounterVec + metricUnexpiredEvictions prometheus.Counter +} + +func NewCacheCollector() *CacheCollector { + return &CacheCollector{ + caches: []Cache{}, + metricSize: prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "gubernator_cache_size", + Help: "The number of items in LRU Cache which holds the rate limits.", + }), + metricAccess: prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "gubernator_cache_access_count", + Help: "Cache access counts. Label \"type\" = hit|miss.", + }, []string{"type"}), + metricUnexpiredEvictions: prometheus.NewCounter(prometheus.CounterOpts{ + Name: "gubernator_unexpired_evictions_count", + Help: "Count the number of cache items which were evicted while unexpired.", + }), + } +} + +var _ prometheus.Collector = &CacheCollector{} + +// AddCache adds a Cache object to be tracked by the collector. +func (c *CacheCollector) AddCache(cache Cache) { + c.caches = append(c.caches, cache) +} + +// Describe fetches prometheus metrics to be registered +func (c *CacheCollector) Describe(ch chan<- *prometheus.Desc) { + c.metricSize.Describe(ch) + c.metricAccess.Describe(ch) + c.metricUnexpiredEvictions.Describe(ch) +} + +// Collect fetches metric counts and gauges from the cache +func (c *CacheCollector) Collect(ch chan<- prometheus.Metric) { + stats := c.getStats() + c.metricSize.Set(float64(stats.Size)) + c.metricSize.Collect(ch) + c.metricAccess.WithLabelValues("miss").Add(float64(stats.Miss)) + c.metricAccess.WithLabelValues("hit").Add(float64(stats.Hit)) + c.metricAccess.Collect(ch) + c.metricUnexpiredEvictions.Add(float64(stats.UnexpiredEvictions)) + c.metricUnexpiredEvictions.Collect(ch) +} + +func (c *CacheCollector) getStats() CacheStats { + var total CacheStats + for _, cache := range c.caches { + stats := cache.Stats() + total.Hit += stats.Hit + total.Miss += stats.Miss + total.Size += stats.Size + total.UnexpiredEvictions += stats.UnexpiredEvictions + } + return total +} diff --git a/cache_manager.go b/cache_manager.go index 90555d4..5723aa1 100644 --- a/cache_manager.go +++ b/cache_manager.go @@ -147,7 +147,14 @@ func (m *cacheManager) Load(ctx context.Context) error { case <-ctx.Done(): return ctx.Err() } - _ = m.cache.Add(item) + retry: + if !m.cache.AddIfNotPresent(item) { + cItem, ok := m.cache.GetItem(item.Key) + if !ok { + goto retry + } + cItem.Copy(item) + } } } @@ -160,6 +167,6 @@ func (m *cacheManager) GetCacheItem(_ context.Context, key string) (*CacheItem, // AddCacheItem adds an item to the cache. The CacheItem.Key should be set correctly, else the item // will not be added to the cache correctly. func (m *cacheManager) AddCacheItem(_ context.Context, _ string, item *CacheItem) error { - _ = m.cache.Add(item) + _ = m.cache.AddIfNotPresent(item) return nil } diff --git a/cache_manager_test.go b/cache_manager_test.go index 81e3d3c..a206d5f 100644 --- a/cache_manager_test.go +++ b/cache_manager_test.go @@ -76,7 +76,7 @@ func TestCacheManager(t *testing.T) { // Mock Cache. for _, item := range cacheItems { - mockCache.On("Add", item).Once().Return(false) + mockCache.On("AddIfNotPresent", item).Once().Return(true) } // Call code. 
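// The Load hunk above guards a narrow race: AddIfNotPresent can fail because
// another goroutine already owns the key, yet that entry can be evicted
// before GetItem runs, so the code retries until one of the two paths
// succeeds. The same loop written without goto, as a sketch against the
// Cache interface introduced in this series (upsert is a hypothetical
// helper, not part of the patch):
func upsert(c Cache, item *CacheItem) {
    for {
        // Fast path: we inserted the item ourselves.
        if c.AddIfNotPresent(item) {
            return
        }
        // Slow path: another item holds the key; copy our state into it.
        if existing, ok := c.GetItem(item.Key); ok {
            existing.Copy(item)
            return
        }
        // The holder vanished between the two calls; try again.
    }
}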
diff --git a/cluster/cluster.go b/cluster/cluster.go index ad3714e..8eaf100 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -174,6 +174,7 @@ func StartWith(localPeers []gubernator.PeerInfo) error { GRPCListenAddress: peer.GRPCAddress, HTTPListenAddress: peer.HTTPAddress, DataCenter: peer.DataCenter, + CacheProvider: "otter", Behaviors: gubernator.BehaviorConfig{ // Suitable for testing but not production GlobalSyncWait: clock.Millisecond * 50, diff --git a/config.go b/config.go index c0918e4..8760198 100644 --- a/config.go +++ b/config.go @@ -249,6 +249,9 @@ type DaemonConfig struct { // (Optional) TraceLevel sets the tracing level, this controls the number of spans included in a single trace. // Valid options are (tracing.InfoLevel, tracing.DebugLevel) Defaults to tracing.InfoLevel TraceLevel tracing.Level + + // (Optional) CacheProvider specifies which cache implementation to store rate limits in + CacheProvider string } func (d *DaemonConfig) ClientTLS() *tls.Config { @@ -420,7 +423,10 @@ func SetupDaemonConfig(logger *logrus.Logger, configFile io.Reader) (DaemonConfi setter.SetDefault(&conf.DNSPoolConf.ResolvConf, os.Getenv("GUBER_RESOLV_CONF"), "/etc/resolv.conf") setter.SetDefault(&conf.DNSPoolConf.OwnAddress, conf.AdvertiseAddress) + setter.SetDefault(&conf.CacheProvider, os.Getenv("GUBER_CACHE_PROVIDER"), "default-lru") + // PeerPicker Config + // TODO: Deprecated: Remove in GUBER_PEER_PICKER in v3 if pp := os.Getenv("GUBER_PEER_PICKER"); pp != "" { var replicas int var hash string diff --git a/daemon.go b/daemon.go index a2701f7..a67444f 100644 --- a/daemon.go +++ b/daemon.go @@ -97,14 +97,23 @@ func (s *Daemon) Start(ctx context.Context) error { } cacheFactory := func(maxSize int) (Cache, error) { - //cache := NewLRUCache(maxSize) - // TODO: Enable Otter as default or provide a config option - cache, err := NewOtterCache(maxSize) - if err != nil { - return nil, err + // TODO(thrawn01): Make Otter the default in gubernator V3 + switch s.conf.CacheProvider { + case "otter": + cache, err := NewOtterCache(maxSize) + if err != nil { + return nil, err + } + cacheCollector.AddCache(cache) + return cache, nil + case "default-lru", "": + cache := NewLRUMutexCache(maxSize) + cacheCollector.AddCache(cache) + return cache, nil + default: + return nil, errors.Errorf("'GUBER_CACHE_PROVIDER=%s' is invalid; "+ + "choices are ['otter', 'default-lru']", s.conf.CacheProvider) } - cacheCollector.AddCache(cache) - return cache, nil } // Handler to collect duration and API access metrics for GRPC diff --git a/example.conf b/example.conf index 3396a40..254db8d 100644 --- a/example.conf +++ b/example.conf @@ -264,4 +264,13 @@ GUBER_INSTANCE_ID= ############################ # OTEL_EXPORTER_OTLP_PROTOCOL=otlp # OTEL_EXPORTER_OTLP_ENDPOINT=https://api.honeycomb.io -# OTEL_EXPORTER_OTLP_HEADERS=x-honeycomb-team= \ No newline at end of file +# OTEL_EXPORTER_OTLP_HEADERS=x-honeycomb-team= + +############################ +# Cache Providers +############################ +# +# Select the cache provider, available options are 'default-lru', 'otter' +# default-lru - A built in LRU implementation which uses a mutex +# otter - Is a lock-less cache implementation based on S3-FIFO algorithm (https://maypok86.github.io/otter/) +# GUBER_CACHE_PROVIDER=default-lru diff --git a/lrucache.go b/lrucache.go index d3690d4..9001905 100644 --- a/lrucache.go +++ b/lrucache.go @@ -20,35 +20,27 @@ package gubernator import ( "container/list" - "sync" "sync/atomic" - "github.com/mailgun/holster/v4/clock" 
"github.com/mailgun/holster/v4/setter" - "github.com/prometheus/client_golang/prometheus" ) // LRUCache is an LRU cache that supports expiration and is not thread-safe // Be sure to use a mutex to prevent concurrent method calls. +// +// Deprecated: Use LRUMutexCache instead. This will be removed in v3 type LRUCache struct { cache map[string]*list.Element ll *list.List - mu sync.Mutex stats CacheStats cacheSize int cacheLen int64 } -type CacheStats struct { - Size int64 - Hit int64 - Miss int64 - UnexpiredEvictions int64 -} - var _ Cache = &LRUCache{} // NewLRUCache creates a new Cache with a maximum size. +// Deprecated: Use NewLRUMutexCache instead. This will be removed in v3 func NewLRUCache(maxSize int) *LRUCache { setter.SetDefault(&maxSize, 50_000) @@ -59,10 +51,10 @@ func NewLRUCache(maxSize int) *LRUCache { } } -// Each is not thread-safe. Each() maintains a goroutine that iterates. -// Other go routines cannot safely access the Cache while iterating. -// It would be safer if this were done using an iterator or delegate pattern -// that doesn't require a goroutine. May need to reassess functional requirements. +// Each maintains a goroutine that iterates. Other go routines cannot safely +// access the Cache while iterating. It would be safer if this were done +// using an iterator or delegate pattern that doesn't require a goroutine. +// May need to reassess functional requirements. func (c *LRUCache) Each() chan *CacheItem { out := make(chan *CacheItem) go func() { @@ -74,11 +66,9 @@ func (c *LRUCache) Each() chan *CacheItem { return out } -// Add adds a value to the cache. -func (c *LRUCache) Add(item *CacheItem) bool { - c.mu.Lock() - defer c.mu.Unlock() - +// AddIfNotPresent adds a value to the cache. returns true if value was added to the cache, +// false if it already exists in the cache +func (c *LRUCache) AddIfNotPresent(item *CacheItem) bool { // If the key already exist, set the new value if ee, ok := c.cache[item.Key]; ok { c.ll.MoveToFront(ee) @@ -95,24 +85,43 @@ func (c *LRUCache) Add(item *CacheItem) bool { return false } -// MillisecondNow returns unix epoch in milliseconds -func MillisecondNow() int64 { - return clock.Now().UnixNano() / 1000000 +// Add adds a value to the cache. +// Deprecated: Gubernator algorithms now use AddIfNotExists. 
+// This method will be removed in the next major version +func (c *LRUCache) Add(item *CacheItem) bool { + // If the key already exist, set the new value + if ee, ok := c.cache[item.Key]; ok { + c.ll.MoveToFront(ee) + ee.Value = item + return true + } + + ele := c.ll.PushFront(item) + c.cache[item.Key] = ele + if c.cacheSize != 0 && c.ll.Len() > c.cacheSize { + c.removeOldest() + } + atomic.StoreInt64(&c.cacheLen, int64(c.ll.Len())) + return false } // GetItem returns the item stored in the cache func (c *LRUCache) GetItem(key string) (item *CacheItem, ok bool) { - c.mu.Lock() - defer c.mu.Unlock() if ele, hit := c.cache[key]; hit { entry := ele.Value.(*CacheItem) - c.stats.Hit++ + if entry.IsExpired() { + c.removeElement(ele) + atomic.AddInt64(&c.stats.Miss, 1) + return + } + + atomic.AddInt64(&c.stats.Hit, 1) c.ll.MoveToFront(ele) return entry, true } - c.stats.Miss++ + atomic.AddInt64(&c.stats.Miss, 1) return } @@ -130,7 +139,7 @@ func (c *LRUCache) removeOldest() { entry := ele.Value.(*CacheItem) if MillisecondNow() < entry.ExpireAt { - c.stats.UnexpiredEvictions++ + atomic.AddInt64(&c.stats.UnexpiredEvictions, 1) } c.removeElement(ele) @@ -149,6 +158,16 @@ func (c *LRUCache) Size() int64 { return atomic.LoadInt64(&c.cacheLen) } +// UpdateExpiration updates the expiration time for the key +func (c *LRUCache) UpdateExpiration(key string, expireAt int64) bool { + if ele, hit := c.cache[key]; hit { + entry := ele.Value.(*CacheItem) + entry.ExpireAt = expireAt + return true + } + return false +} + func (c *LRUCache) Close() error { c.cache = nil c.ll = nil @@ -158,77 +177,10 @@ func (c *LRUCache) Close() error { // Stats returns the current status for the cache func (c *LRUCache) Stats() CacheStats { - c.mu.Lock() - defer func() { - c.stats = CacheStats{} - c.mu.Unlock() - }() - - c.stats.Size = atomic.LoadInt64(&c.cacheLen) - return c.stats -} - -// CacheCollector provides prometheus metrics collector for LRUCache. -// Register only one collector, add one or more caches to this collector. -type CacheCollector struct { - caches []Cache - metricSize prometheus.Gauge - metricAccess *prometheus.CounterVec - metricUnexpiredEvictions prometheus.Counter -} - -func NewCacheCollector() *CacheCollector { - return &CacheCollector{ - caches: []Cache{}, - metricSize: prometheus.NewGauge(prometheus.GaugeOpts{ - Name: "gubernator_cache_size", - Help: "The number of items in LRU Cache which holds the rate limits.", - }), - metricAccess: prometheus.NewCounterVec(prometheus.CounterOpts{ - Name: "gubernator_cache_access_count", - Help: "Cache access counts. Label \"type\" = hit|miss.", - }, []string{"type"}), - metricUnexpiredEvictions: prometheus.NewCounter(prometheus.CounterOpts{ - Name: "gubernator_unexpired_evictions_count", - Help: "Count the number of cache items which were evicted while unexpired.", - }), - } -} - -var _ prometheus.Collector = &CacheCollector{} - -// AddCache adds a Cache object to be tracked by the collector. 
-func (c *CacheCollector) AddCache(cache Cache) { - c.caches = append(c.caches, cache) -} - -// Describe fetches prometheus metrics to be registered -func (c *CacheCollector) Describe(ch chan<- *prometheus.Desc) { - c.metricSize.Describe(ch) - c.metricAccess.Describe(ch) - c.metricUnexpiredEvictions.Describe(ch) -} - -// Collect fetches metric counts and gauges from the cache -func (c *CacheCollector) Collect(ch chan<- prometheus.Metric) { - stats := c.getStats() - c.metricSize.Set(float64(stats.Size)) - c.metricSize.Collect(ch) - c.metricAccess.WithLabelValues("miss").Add(float64(stats.Miss)) - c.metricAccess.WithLabelValues("hit").Add(float64(stats.Hit)) - c.metricAccess.Collect(ch) - c.metricUnexpiredEvictions.Add(float64(stats.UnexpiredEvictions)) - c.metricUnexpiredEvictions.Collect(ch) -} - -func (c *CacheCollector) getStats() CacheStats { - var total CacheStats - for _, cache := range c.caches { - stats := cache.Stats() - total.Hit += stats.Hit - total.Miss += stats.Miss - total.Size += stats.Size - total.UnexpiredEvictions += stats.UnexpiredEvictions - } - return total + var result CacheStats + result.UnexpiredEvictions = atomic.SwapInt64(&c.stats.UnexpiredEvictions, 0) + result.Miss = atomic.SwapInt64(&c.stats.Miss, 0) + result.Hit = atomic.SwapInt64(&c.stats.Hit, 0) + result.Size = atomic.LoadInt64(&c.cacheLen) + return result } diff --git a/lrumutex.go b/lrumutex.go new file mode 100644 index 0000000..e54f15e --- /dev/null +++ b/lrumutex.go @@ -0,0 +1,163 @@ +/* +Modifications Copyright 2024 Derrick Wippler + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +This work is derived from github.com/golang/groupcache/lru +*/ + +package gubernator + +import ( + "container/list" + "sync" + "sync/atomic" + + "github.com/mailgun/holster/v4/setter" +) + +// LRUMutexCache is a mutex protected LRU cache that supports expiration and is thread-safe +type LRUMutexCache struct { + cache map[string]*list.Element + ll *list.List + mu sync.Mutex + stats CacheStats + cacheSize int + cacheLen int64 +} + +var _ Cache = &LRUMutexCache{} + +// NewLRUMutexCache creates a new Cache with a maximum size. +func NewLRUMutexCache(maxSize int) *LRUMutexCache { + setter.SetDefault(&maxSize, 50_000) + + return &LRUMutexCache{ + cache: make(map[string]*list.Element), + ll: list.New(), + cacheSize: maxSize, + } +} + +// Each maintains a goroutine that iterates over every item in the cache. +// Other go routines operating on this cache will block until all items +// are read from the returned channel. +func (c *LRUMutexCache) Each() chan *CacheItem { + out := make(chan *CacheItem) + go func() { + c.mu.Lock() + defer c.mu.Unlock() + + for _, ele := range c.cache { + out <- ele.Value.(*CacheItem) + } + close(out) + }() + return out +} + +// Add is a noop as it is deprecated +func (c *LRUMutexCache) Add(item *CacheItem) bool { + return false +} + +// AddIfNotPresent adds the item to the cache if it doesn't already exist. +// Returns true if the item was added, false if the item already exists. 
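// Because Each above holds the cache mutex inside its goroutine, every other
// operation on the cache blocks until the returned channel is fully drained,
// and a consumer that stops early would leave the mutex held forever. A
// small usage sketch that drains the channel to completion (countItems is a
// hypothetical helper, not part of the patch):
func countItems(c *LRUMutexCache) int {
    var n int
    // Never break out of this loop early; always read until the channel is
    // closed so the iterating goroutine can release the mutex.
    for range c.Each() {
        n++
    }
    return n
}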
+func (c *LRUMutexCache) AddIfNotPresent(item *CacheItem) bool { + c.mu.Lock() + defer c.mu.Unlock() + + // If the key already exist, do nothing + if _, ok := c.cache[item.Key]; ok { + return false + } + + ele := c.ll.PushFront(item) + c.cache[item.Key] = ele + if c.cacheSize != 0 && c.ll.Len() > c.cacheSize { + c.removeOldest() + } + atomic.StoreInt64(&c.cacheLen, int64(c.ll.Len())) + return true +} + +// GetItem returns the item stored in the cache +func (c *LRUMutexCache) GetItem(key string) (item *CacheItem, ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + if ele, hit := c.cache[key]; hit { + entry := ele.Value.(*CacheItem) + + c.stats.Hit++ + c.ll.MoveToFront(ele) + return entry, true + } + + c.stats.Miss++ + return +} + +// Remove removes the provided key from the cache. +func (c *LRUMutexCache) Remove(key string) { + c.mu.Lock() + defer c.mu.Unlock() + + if ele, hit := c.cache[key]; hit { + c.removeElement(ele) + } +} + +// RemoveOldest removes the oldest item from the cache. +func (c *LRUMutexCache) removeOldest() { + ele := c.ll.Back() + if ele != nil { + entry := ele.Value.(*CacheItem) + + if MillisecondNow() < entry.ExpireAt { + c.stats.UnexpiredEvictions++ + } + + c.removeElement(ele) + } +} + +func (c *LRUMutexCache) removeElement(e *list.Element) { + c.ll.Remove(e) + kv := e.Value.(*CacheItem) + delete(c.cache, kv.Key) + atomic.StoreInt64(&c.cacheLen, int64(c.ll.Len())) +} + +// Size returns the number of items in the cache. +func (c *LRUMutexCache) Size() int64 { + return atomic.LoadInt64(&c.cacheLen) +} + +func (c *LRUMutexCache) Close() error { + c.cache = nil + c.ll = nil + c.cacheLen = 0 + return nil +} + +// Stats returns the current status for the cache +func (c *LRUMutexCache) Stats() CacheStats { + c.mu.Lock() + defer func() { + c.stats = CacheStats{} + c.mu.Unlock() + }() + + c.stats.Size = atomic.LoadInt64(&c.cacheLen) + return c.stats +} diff --git a/lrumutext_test.go b/lrumutext_test.go new file mode 100644 index 0000000..8e4b637 --- /dev/null +++ b/lrumutext_test.go @@ -0,0 +1,594 @@ +/* +Copyright 2018-2022 Mailgun Technologies Inc + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package gubernator_test + +import ( + "fmt" + "math/rand" + "strconv" + "sync" + "testing" + "time" + + "github.com/gubernator-io/gubernator/v2" + "github.com/mailgun/holster/v4/clock" + "github.com/prometheus/client_golang/prometheus" + dto "github.com/prometheus/client_model/go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLRUMutexCache(t *testing.T) { + const iterations = 1000 + const concurrency = 100 + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + + t.Run("Happy path", func(t *testing.T) { + cache := gubernator.NewLRUMutexCache(0) + + // Populate cache. + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + assert.True(t, cache.AddIfNotPresent(item)) + } + + // Validate cache. 
+ assert.Equal(t, int64(iterations), cache.Size()) + + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item, ok := cache.GetItem(key) + require.True(t, ok) + require.NotNil(t, item) + assert.Equal(t, item.Value, i) + } + + // Clear cache. + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + cache.Remove(key) + } + + assert.Zero(t, cache.Size()) + }) + + t.Run("Update an existing key", func(t *testing.T) { + cache := gubernator.NewLRUMutexCache(0) + const key = "foobar" + + // Add key. + item1 := &gubernator.CacheItem{ + Key: key, + Value: "initial value", + ExpireAt: expireAt, + } + require.True(t, cache.AddIfNotPresent(item1)) + + // Update same key. + item2 := &gubernator.CacheItem{ + Key: key, + Value: "new value", + ExpireAt: expireAt, + } + require.False(t, cache.AddIfNotPresent(item2)) + + updateItem, ok := cache.GetItem(item1.Key) + require.True(t, ok) + updateItem.Value = "new value" + + // Verify. + verifyItem, ok := cache.GetItem(key) + require.True(t, ok) + assert.Equal(t, item2, verifyItem) + }) + + t.Run("Concurrent reads", func(t *testing.T) { + cache := gubernator.NewLRUMutexCache(0) + + // Populate cache. + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + assert.True(t, cache.AddIfNotPresent(item)) + } + + assert.Equal(t, int64(iterations), cache.Size()) + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for thread := 0; thread < concurrency; thread++ { + doneWg.Add(1) + + go func() { + defer doneWg.Done() + launchWg.Wait() + + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item, ok := cache.GetItem(key) + assert.True(t, ok) + require.NotNil(t, item) + assert.Equal(t, item.Value, i) + } + }() + } + + // Wait for goroutines to finish. + launchWg.Done() + doneWg.Wait() + }) + + t.Run("Concurrent writes", func(t *testing.T) { + cache := gubernator.NewLRUMutexCache(0) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for thread := 0; thread < concurrency; thread++ { + doneWg.Add(1) + + go func() { + defer doneWg.Done() + launchWg.Wait() + + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + } + }() + } + + // Wait for goroutines to finish. + launchWg.Done() + doneWg.Wait() + }) + + t.Run("Concurrent reads and writes", func(t *testing.T) { + cache := gubernator.NewLRUMutexCache(0) + + // Populate cache. + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + assert.True(t, cache.AddIfNotPresent(item)) + } + + assert.Equal(t, int64(iterations), cache.Size()) + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for thread := 0; thread < concurrency; thread++ { + doneWg.Add(2) + + go func() { + defer doneWg.Done() + launchWg.Wait() + + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item, ok := cache.GetItem(key) + assert.True(t, ok) + require.NotNil(t, item) + assert.Equal(t, item.Value, i) + } + }() + + go func() { + defer doneWg.Done() + launchWg.Wait() + + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + } + }() + } + + // Wait for goroutines to finish. 
+ launchWg.Done() + doneWg.Wait() + }) + + t.Run("Collect metrics during concurrent reads/writes", func(t *testing.T) { + cache := gubernator.NewLRUMutexCache(0) + + // Populate cache. + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + } + + assert.Equal(t, int64(iterations), cache.Size()) + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for thread := 0; thread < concurrency; thread++ { + doneWg.Add(3) + + go func() { + defer doneWg.Done() + launchWg.Wait() + + for i := 0; i < iterations; i++ { + // Get, cache hit. + key := strconv.Itoa(i) + _, _ = cache.GetItem(key) + + // Get, cache miss. + key2 := strconv.Itoa(rand.Intn(1000) + 10000) + _, _ = cache.GetItem(key2) + } + }() + + go func() { + defer doneWg.Done() + launchWg.Wait() + + for i := 0; i < iterations; i++ { + // Add existing. + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + + // Add new. + key2 := strconv.Itoa(rand.Intn(1000) + 20000) + item2 := &gubernator.CacheItem{ + Key: key2, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item2) + } + }() + + collector := gubernator.NewCacheCollector() + collector.AddCache(cache) + + go func() { + defer doneWg.Done() + launchWg.Wait() + + for i := 0; i < iterations; i++ { + // Get metrics. + ch := make(chan prometheus.Metric, 10) + collector.Collect(ch) + } + }() + } + + // Wait for goroutines to finish. + launchWg.Done() + doneWg.Wait() + }) + + t.Run("Check gubernator_unexpired_evictions_count metric is not incremented when expired item is evicted", func(t *testing.T) { + defer clock.Freeze(clock.Now()).Unfreeze() + + promRegister := prometheus.NewRegistry() + + // The LRU cache for storing rate limits. + cacheCollector := gubernator.NewCacheCollector() + err := promRegister.Register(cacheCollector) + require.NoError(t, err) + + cache := gubernator.NewLRUMutexCache(10) + cacheCollector.AddCache(cache) + + // fill cache with short duration cache items + for i := 0; i < 10; i++ { + cache.AddIfNotPresent(&gubernator.CacheItem{ + Algorithm: gubernator.Algorithm_LEAKY_BUCKET, + Key: fmt.Sprintf("short-expiry-%d", i), + Value: "bar", + ExpireAt: clock.Now().Add(5 * time.Minute).UnixMilli(), + }) + } + + // jump forward in time to expire all short duration keys + clock.Advance(6 * time.Minute) + + // add a new cache item to force eviction + cache.AddIfNotPresent(&gubernator.CacheItem{ + Algorithm: gubernator.Algorithm_LEAKY_BUCKET, + Key: "evict1", + Value: "bar", + ExpireAt: clock.Now().Add(1 * time.Hour).UnixMilli(), + }) + + collChan := make(chan prometheus.Metric, 64) + cacheCollector.Collect(collChan) + // Check metrics to verify evicted cache key is expired + <-collChan + <-collChan + <-collChan + m := <-collChan // gubernator_unexpired_evictions_count + met := new(dto.Metric) + _ = m.Write(met) + assert.Contains(t, m.Desc().String(), "gubernator_unexpired_evictions_count") + assert.Equal(t, 0, int(*met.Counter.Value)) + }) + + t.Run("Check gubernator_unexpired_evictions_count metric is incremented when unexpired item is evicted", func(t *testing.T) { + defer clock.Freeze(clock.Now()).Unfreeze() + + promRegister := prometheus.NewRegistry() + + // The LRU cache for storing rate limits. 
+ cacheCollector := gubernator.NewCacheCollector() + err := promRegister.Register(cacheCollector) + require.NoError(t, err) + + cache := gubernator.NewLRUMutexCache(10) + cacheCollector.AddCache(cache) + + // fill cache with long duration cache items + for i := 0; i < 10; i++ { + cache.AddIfNotPresent(&gubernator.CacheItem{ + Algorithm: gubernator.Algorithm_LEAKY_BUCKET, + Key: fmt.Sprintf("long-expiry-%d", i), + Value: "bar", + ExpireAt: clock.Now().Add(1 * time.Hour).UnixMilli(), + }) + } + + // add a new cache item to force eviction + cache.AddIfNotPresent(&gubernator.CacheItem{ + Algorithm: gubernator.Algorithm_LEAKY_BUCKET, + Key: "evict2", + Value: "bar", + ExpireAt: clock.Now().Add(1 * time.Hour).UnixMilli(), + }) + + // Check metrics to verify evicted cache key is *NOT* expired + collChan := make(chan prometheus.Metric, 64) + cacheCollector.Collect(collChan) + <-collChan + <-collChan + <-collChan + m := <-collChan // gubernator_unexpired_evictions_count + met := new(dto.Metric) + _ = m.Write(met) + assert.Contains(t, m.Desc().String(), "gubernator_unexpired_evictions_count") + assert.Equal(t, 1, int(*met.Counter.Value)) + }) +} + +func BenchmarkLRUMutexCache(b *testing.B) { + + b.Run("Sequential reads", func(b *testing.B) { + cache := gubernator.NewLRUMutexCache(b.N) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + + // Populate cache. + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + assert.True(b, cache.AddIfNotPresent(item)) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + _, _ = cache.GetItem(key) + } + }) + + b.Run("Sequential writes", func(b *testing.B) { + cache := gubernator.NewLRUMutexCache(0) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + } + }) + + b.Run("Concurrent reads", func(b *testing.B) { + cache := gubernator.NewLRUMutexCache(b.N) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + + // Populate cache. + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + assert.True(b, cache.AddIfNotPresent(item)) + } + + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + doneWg.Add(1) + + go func() { + defer doneWg.Done() + launchWg.Wait() + + _, _ = cache.GetItem(key) + }() + } + + b.ReportAllocs() + b.ResetTimer() + launchWg.Done() + doneWg.Wait() + }) + + b.Run("Concurrent writes", func(b *testing.B) { + cache := gubernator.NewLRUMutexCache(0) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + doneWg.Add(1) + + go func(i int) { + defer doneWg.Done() + launchWg.Wait() + + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + }(i) + } + + b.ReportAllocs() + b.ResetTimer() + launchWg.Done() + doneWg.Wait() + }) + + b.Run("Concurrent reads and writes of existing keys", func(b *testing.B) { + cache := gubernator.NewLRUMutexCache(0) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + // Populate cache. 
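// The eviction tests above fish a single counter out of the collector by
// draining its Collect channel and decoding into dto.Metric. That pattern
// generalizes into a small helper; a sketch, reusing the prometheus, dto,
// and strings imports already present in this file, and assuming Collect is
// synchronous for this collector (as the tests above already rely on):
func counterValue(col prometheus.Collector, name string) float64 {
    ch := make(chan prometheus.Metric, 64)
    col.Collect(ch)
    close(ch)
    for m := range ch {
        if !strings.Contains(m.Desc().String(), name) {
            continue
        }
        var out dto.Metric
        _ = m.Write(&out)
        if out.Counter != nil {
            return out.Counter.GetValue()
        }
    }
    return 0
}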
+ for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + assert.True(b, cache.AddIfNotPresent(item)) + } + + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + doneWg.Add(2) + + go func() { + defer doneWg.Done() + launchWg.Wait() + + _, _ = cache.GetItem(key) + }() + + go func(i int) { + defer doneWg.Done() + launchWg.Wait() + + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + }(i) + } + + b.ReportAllocs() + b.ResetTimer() + launchWg.Done() + doneWg.Wait() + }) + + b.Run("Concurrent reads and writes of non-existent keys", func(b *testing.B) { + cache := gubernator.NewLRUMutexCache(0) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for i := 0; i < b.N; i++ { + doneWg.Add(2) + + go func(i int) { + defer doneWg.Done() + launchWg.Wait() + + key := strconv.Itoa(i) + _, _ = cache.GetItem(key) + }(i) + + go func(i int) { + defer doneWg.Done() + launchWg.Wait() + + key := "z" + strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + _ = cache.AddIfNotPresent(item) + }(i) + } + + b.ReportAllocs() + b.ResetTimer() + launchWg.Done() + doneWg.Wait() + }) +} diff --git a/mock_cache_test.go b/mock_cache_test.go index 536a93b..6b72b39 100644 --- a/mock_cache_test.go +++ b/mock_cache_test.go @@ -34,6 +34,11 @@ func (m *MockCache) Add(item *guber.CacheItem) bool { return args.Bool(0) } +func (m *MockCache) AddIfNotPresent(item *guber.CacheItem) bool { + args := m.Called(item) + return args.Bool(0) +} + func (m *MockCache) GetItem(key string) (value *guber.CacheItem, ok bool) { args := m.Called(key) retval, _ := args.Get(0).(*guber.CacheItem) diff --git a/otter.go b/otter.go index 3ba29f9..865a454 100644 --- a/otter.go +++ b/otter.go @@ -44,10 +44,15 @@ func NewOtterCache(size int) (*OtterCache, error) { return o, nil } -// Add adds a new CacheItem to the cache. The key must be provided via CacheItem.Key -// returns true if the item was added to the cache; false if the item was too large -// for the cache or already exists in the cache. +// Add is a noop as it is deprecated func (o *OtterCache) Add(item *CacheItem) bool { + return false +} + +// AddIfNotPresent adds a new CacheItem to the cache. The key must be provided via CacheItem.Key +// returns true if the item was added to the cache; false if the item was too large +// for the cache or if the key already exists in the cache. +func (o *OtterCache) AddIfNotPresent(item *CacheItem) bool { return o.cache.SetIfAbsent(item.Key, item) } diff --git a/otter_test.go b/otter_test.go index 9c84df1..650ea2f 100644 --- a/otter_test.go +++ b/otter_test.go @@ -29,7 +29,7 @@ func TestOtterCache(t *testing.T) { Value: i, ExpireAt: expireAt, } - cache.Add(item) + cache.AddIfNotPresent(item) } // Validate cache. 
@@ -63,7 +63,7 @@ func TestOtterCache(t *testing.T) { Value: "initial value", ExpireAt: expireAt, } - cache.Add(item1) + cache.AddIfNotPresent(item1) // Update same key is refused item2 := &gubernator.CacheItem{ @@ -71,7 +71,7 @@ func TestOtterCache(t *testing.T) { Value: "new value", ExpireAt: expireAt, } - assert.False(t, cache.Add(item2)) + assert.False(t, cache.AddIfNotPresent(item2)) // Fetch and update the CacheItem update, ok := cache.GetItem(key) @@ -96,7 +96,7 @@ func TestOtterCache(t *testing.T) { Value: i, ExpireAt: expireAt, } - cache.Add(item) + cache.AddIfNotPresent(item) } assert.Equal(t, int64(iterations), cache.Size()) @@ -146,7 +146,7 @@ func TestOtterCache(t *testing.T) { Value: i, ExpireAt: expireAt, } - cache.Add(item) + cache.AddIfNotPresent(item) } }() } @@ -168,7 +168,7 @@ func TestOtterCache(t *testing.T) { Value: i, ExpireAt: expireAt, } - cache.Add(item) + cache.AddIfNotPresent(item) } assert.Equal(t, int64(iterations), cache.Size()) @@ -203,7 +203,7 @@ func TestOtterCache(t *testing.T) { Value: i, ExpireAt: expireAt, } - cache.Add(item) + cache.AddIfNotPresent(item) } }() } @@ -229,7 +229,7 @@ func BenchmarkOtterCache(b *testing.B) { Value: i, ExpireAt: expireAt, } - cache.Add(item) + cache.AddIfNotPresent(item) } b.ReportAllocs() @@ -255,7 +255,7 @@ func BenchmarkOtterCache(b *testing.B) { Value: i, ExpireAt: expireAt, } - cache.Add(item) + cache.AddIfNotPresent(item) } }) @@ -271,7 +271,7 @@ func BenchmarkOtterCache(b *testing.B) { Value: i, ExpireAt: expireAt, } - cache.Add(item) + cache.AddIfNotPresent(item) } var launchWg, doneWg sync.WaitGroup @@ -314,7 +314,7 @@ func BenchmarkOtterCache(b *testing.B) { Value: i, ExpireAt: expireAt, } - cache.Add(item) + cache.AddIfNotPresent(item) }(i) } @@ -338,7 +338,7 @@ func BenchmarkOtterCache(b *testing.B) { Value: i, ExpireAt: expireAt, } - cache.Add(item) + cache.AddIfNotPresent(item) } for i := 0; i < b.N; i++ { @@ -361,7 +361,7 @@ func BenchmarkOtterCache(b *testing.B) { Value: i, ExpireAt: expireAt, } - cache.Add(item) + cache.AddIfNotPresent(item) }(i) } @@ -398,7 +398,7 @@ func BenchmarkOtterCache(b *testing.B) { Value: i, ExpireAt: expireAt, } - cache.Add(item) + cache.AddIfNotPresent(item) }(i) }
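// Taken together, the series leaves the cache implementation pluggable: the
// daemon's cacheFactory switches on DaemonConfig.CacheProvider, which can
// also be set through the GUBER_CACHE_PROVIDER environment variable. A
// minimal configuration sketch using the fields shown in these patches
// (exampleDaemonConfig and the addresses are illustrative):
func exampleDaemonConfig() gubernator.DaemonConfig {
    return gubernator.DaemonConfig{
        GRPCListenAddress: "localhost:1051",
        HTTPListenAddress: "localhost:1050",
        // "otter" selects the lock-free S3-FIFO cache; "default-lru" (the
        // default) selects the mutex-protected LRUMutexCache.
        CacheProvider: "otter",
    }
}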