diff --git a/algorithms.go b/algorithms.go index a84c84c..bfeceee 100644 --- a/algorithms.go +++ b/algorithms.go @@ -56,7 +56,7 @@ func tokenBucket(ctx rateContext) (resp *RateLimitResp, err error) { // If not in the cache, check the store if provided if ctx.Store != nil && !ok { if ctx.CacheItem, ok = ctx.Store.Get(ctx, ctx.Request); ok { - if !ctx.Cache.Add(ctx.CacheItem) { + if !ctx.Cache.AddIfNotPresent(ctx.CacheItem) { // Someone else added a new token bucket item to the cache for this // rate limit before we did, so we retry by calling ourselves recursively. return tokenBucket(ctx) @@ -270,7 +270,7 @@ func initTokenBucketItem(ctx rateContext) (resp *RateLimitResp, err error) { Value: &t, ExpireAt: expire, } - if !ctx.Cache.Add(ctx.CacheItem) { + if !ctx.Cache.AddIfNotPresent(ctx.CacheItem) { return rl, errAlreadyExistsInCache } } @@ -299,7 +299,7 @@ func leakyBucket(ctx rateContext) (resp *RateLimitResp, err error) { if ctx.Store != nil && !ok { // Cache missed, check our store for the item. if ctx.CacheItem, ok = ctx.Store.Get(ctx, ctx.Request); ok { - if !ctx.Cache.Add(ctx.CacheItem) { + if !ctx.Cache.AddIfNotPresent(ctx.CacheItem) { // Someone else added a new leaky bucket item to the cache for this // rate limit before we did, so we retry by calling ourselves recursively. return leakyBucket(ctx) @@ -519,7 +519,7 @@ func initLeakyBucketItem(ctx rateContext) (resp *RateLimitResp, err error) { Key: ctx.Request.HashKey(), Value: &b, } - if !ctx.Cache.Add(ctx.CacheItem) { + if !ctx.Cache.AddIfNotPresent(ctx.CacheItem) { return nil, errAlreadyExistsInCache } } diff --git a/benchmark_cache_test.go b/benchmark_cache_test.go index 5e24bb0..e2869d3 100644 --- a/benchmark_cache_test.go +++ b/benchmark_cache_test.go @@ -25,6 +25,13 @@ func BenchmarkCache(b *testing.B) { }, LockRequired: true, }, + { + Name: "LRUMutexCache", + NewTestCache: func() (gubernator.Cache, error) { + return gubernator.NewLRUMutexCache(0), nil + }, + LockRequired: true, + }, { Name: "OtterCache", NewTestCache: func() (gubernator.Cache, error) { @@ -48,7 +55,7 @@ func BenchmarkCache(b *testing.B) { Value: "value:" + key, ExpireAt: expire, } - cache.Add(item) + cache.AddIfNotPresent(item) } mask := len(keys) - 1 @@ -78,7 +85,7 @@ func BenchmarkCache(b *testing.B) { Value: "value:" + keys[index&mask], ExpireAt: expire, } - cache.Add(item) + cache.AddIfNotPresent(item) } }) @@ -94,7 +101,7 @@ func BenchmarkCache(b *testing.B) { Value: "value:" + key, ExpireAt: expire, } - cache.Add(item) + cache.AddIfNotPresent(item) } var mutex sync.Mutex @@ -144,7 +151,7 @@ func BenchmarkCache(b *testing.B) { Value: "value:" + key, ExpireAt: expire, } - cache.Add(item) + cache.AddIfNotPresent(item) } } else { task = func(key string) { @@ -153,7 +160,7 @@ func BenchmarkCache(b *testing.B) { Value: "value:" + key, ExpireAt: expire, } - cache.Add(item) + cache.AddIfNotPresent(item) } } diff --git a/cache.go b/cache.go index ed1444d..c9d095f 100644 --- a/cache.go +++ b/cache.go @@ -16,10 +16,25 @@ limitations under the License. package gubernator -import "sync" +import ( + "sync" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/mailgun/holster/v4/clock" +) type Cache interface { + // Add adds an item, or replaces an item in the cache + // + // Deprecated: Gubernator algorithms now use AddIfNotExists. + // TODO: Remove this method in v3 Add(item *CacheItem) bool + + // AddIfNotPresent adds the item to the cache if it doesn't already exist. + // Returns true if the item was added, false if the item already exists. + AddIfNotPresent(item *CacheItem) bool + GetItem(key string) (value *CacheItem, ok bool) Each() chan *CacheItem Remove(key string) @@ -66,3 +81,91 @@ func (item *CacheItem) IsExpired() bool { return false } + +func (item *CacheItem) Copy(from *CacheItem) { + item.mutex.Lock() + defer item.mutex.Unlock() + + item.InvalidAt = from.InvalidAt + item.Algorithm = from.Algorithm + item.ExpireAt = from.ExpireAt + item.Value = from.Value + item.Key = from.Key +} + +// MillisecondNow returns unix epoch in milliseconds +func MillisecondNow() int64 { + return clock.Now().UnixNano() / 1000000 +} + +type CacheStats struct { + Size int64 + Hit int64 + Miss int64 + UnexpiredEvictions int64 +} + +// CacheCollector provides prometheus metrics collector for Cache implementations +// Register only one collector, add one or more caches to this collector. +type CacheCollector struct { + caches []Cache + metricSize prometheus.Gauge + metricAccess *prometheus.CounterVec + metricUnexpiredEvictions prometheus.Counter +} + +func NewCacheCollector() *CacheCollector { + return &CacheCollector{ + caches: []Cache{}, + metricSize: prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "gubernator_cache_size", + Help: "The number of items in LRU Cache which holds the rate limits.", + }), + metricAccess: prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "gubernator_cache_access_count", + Help: "Cache access counts. Label \"type\" = hit|miss.", + }, []string{"type"}), + metricUnexpiredEvictions: prometheus.NewCounter(prometheus.CounterOpts{ + Name: "gubernator_unexpired_evictions_count", + Help: "Count the number of cache items which were evicted while unexpired.", + }), + } +} + +var _ prometheus.Collector = &CacheCollector{} + +// AddCache adds a Cache object to be tracked by the collector. +func (c *CacheCollector) AddCache(cache Cache) { + c.caches = append(c.caches, cache) +} + +// Describe fetches prometheus metrics to be registered +func (c *CacheCollector) Describe(ch chan<- *prometheus.Desc) { + c.metricSize.Describe(ch) + c.metricAccess.Describe(ch) + c.metricUnexpiredEvictions.Describe(ch) +} + +// Collect fetches metric counts and gauges from the cache +func (c *CacheCollector) Collect(ch chan<- prometheus.Metric) { + stats := c.getStats() + c.metricSize.Set(float64(stats.Size)) + c.metricSize.Collect(ch) + c.metricAccess.WithLabelValues("miss").Add(float64(stats.Miss)) + c.metricAccess.WithLabelValues("hit").Add(float64(stats.Hit)) + c.metricAccess.Collect(ch) + c.metricUnexpiredEvictions.Add(float64(stats.UnexpiredEvictions)) + c.metricUnexpiredEvictions.Collect(ch) +} + +func (c *CacheCollector) getStats() CacheStats { + var total CacheStats + for _, cache := range c.caches { + stats := cache.Stats() + total.Hit += stats.Hit + total.Miss += stats.Miss + total.Size += stats.Size + total.UnexpiredEvictions += stats.UnexpiredEvictions + } + return total +} diff --git a/cache_manager.go b/cache_manager.go index 90555d4..5723aa1 100644 --- a/cache_manager.go +++ b/cache_manager.go @@ -147,7 +147,14 @@ func (m *cacheManager) Load(ctx context.Context) error { case <-ctx.Done(): return ctx.Err() } - _ = m.cache.Add(item) + retry: + if !m.cache.AddIfNotPresent(item) { + cItem, ok := m.cache.GetItem(item.Key) + if !ok { + goto retry + } + cItem.Copy(item) + } } } @@ -160,6 +167,6 @@ func (m *cacheManager) GetCacheItem(_ context.Context, key string) (*CacheItem, // AddCacheItem adds an item to the cache. The CacheItem.Key should be set correctly, else the item // will not be added to the cache correctly. func (m *cacheManager) AddCacheItem(_ context.Context, _ string, item *CacheItem) error { - _ = m.cache.Add(item) + _ = m.cache.AddIfNotPresent(item) return nil } diff --git a/cache_manager_test.go b/cache_manager_test.go index 81e3d3c..a206d5f 100644 --- a/cache_manager_test.go +++ b/cache_manager_test.go @@ -76,7 +76,7 @@ func TestCacheManager(t *testing.T) { // Mock Cache. for _, item := range cacheItems { - mockCache.On("Add", item).Once().Return(false) + mockCache.On("AddIfNotPresent", item).Once().Return(true) } // Call code. diff --git a/cluster/cluster.go b/cluster/cluster.go index ad3714e..8eaf100 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -174,6 +174,7 @@ func StartWith(localPeers []gubernator.PeerInfo) error { GRPCListenAddress: peer.GRPCAddress, HTTPListenAddress: peer.HTTPAddress, DataCenter: peer.DataCenter, + CacheProvider: "otter", Behaviors: gubernator.BehaviorConfig{ // Suitable for testing but not production GlobalSyncWait: clock.Millisecond * 50, diff --git a/config.go b/config.go index c0918e4..8760198 100644 --- a/config.go +++ b/config.go @@ -249,6 +249,9 @@ type DaemonConfig struct { // (Optional) TraceLevel sets the tracing level, this controls the number of spans included in a single trace. // Valid options are (tracing.InfoLevel, tracing.DebugLevel) Defaults to tracing.InfoLevel TraceLevel tracing.Level + + // (Optional) CacheProvider specifies which cache implementation to store rate limits in + CacheProvider string } func (d *DaemonConfig) ClientTLS() *tls.Config { @@ -420,7 +423,10 @@ func SetupDaemonConfig(logger *logrus.Logger, configFile io.Reader) (DaemonConfi setter.SetDefault(&conf.DNSPoolConf.ResolvConf, os.Getenv("GUBER_RESOLV_CONF"), "/etc/resolv.conf") setter.SetDefault(&conf.DNSPoolConf.OwnAddress, conf.AdvertiseAddress) + setter.SetDefault(&conf.CacheProvider, os.Getenv("GUBER_CACHE_PROVIDER"), "default-lru") + // PeerPicker Config + // TODO: Deprecated: Remove in GUBER_PEER_PICKER in v3 if pp := os.Getenv("GUBER_PEER_PICKER"); pp != "" { var replicas int var hash string diff --git a/daemon.go b/daemon.go index a2701f7..a67444f 100644 --- a/daemon.go +++ b/daemon.go @@ -97,14 +97,23 @@ func (s *Daemon) Start(ctx context.Context) error { } cacheFactory := func(maxSize int) (Cache, error) { - //cache := NewLRUCache(maxSize) - // TODO: Enable Otter as default or provide a config option - cache, err := NewOtterCache(maxSize) - if err != nil { - return nil, err + // TODO(thrawn01): Make Otter the default in gubernator V3 + switch s.conf.CacheProvider { + case "otter": + cache, err := NewOtterCache(maxSize) + if err != nil { + return nil, err + } + cacheCollector.AddCache(cache) + return cache, nil + case "default-lru", "": + cache := NewLRUMutexCache(maxSize) + cacheCollector.AddCache(cache) + return cache, nil + default: + return nil, errors.Errorf("'GUBER_CACHE_PROVIDER=%s' is invalid; "+ + "choices are ['otter', 'default-lru']", s.conf.CacheProvider) } - cacheCollector.AddCache(cache) - return cache, nil } // Handler to collect duration and API access metrics for GRPC diff --git a/example.conf b/example.conf index 3396a40..254db8d 100644 --- a/example.conf +++ b/example.conf @@ -264,4 +264,13 @@ GUBER_INSTANCE_ID= ############################ # OTEL_EXPORTER_OTLP_PROTOCOL=otlp # OTEL_EXPORTER_OTLP_ENDPOINT=https://api.honeycomb.io -# OTEL_EXPORTER_OTLP_HEADERS=x-honeycomb-team= \ No newline at end of file +# OTEL_EXPORTER_OTLP_HEADERS=x-honeycomb-team= + +############################ +# Cache Providers +############################ +# +# Select the cache provider, available options are 'default-lru', 'otter' +# default-lru - A built in LRU implementation which uses a mutex +# otter - Is a lock-less cache implementation based on S3-FIFO algorithm (https://maypok86.github.io/otter/) +# GUBER_CACHE_PROVIDER=default-lru diff --git a/lrucache.go b/lrucache.go index d3690d4..9001905 100644 --- a/lrucache.go +++ b/lrucache.go @@ -20,35 +20,27 @@ package gubernator import ( "container/list" - "sync" "sync/atomic" - "github.com/mailgun/holster/v4/clock" "github.com/mailgun/holster/v4/setter" - "github.com/prometheus/client_golang/prometheus" ) // LRUCache is an LRU cache that supports expiration and is not thread-safe // Be sure to use a mutex to prevent concurrent method calls. +// +// Deprecated: Use LRUMutexCache instead. This will be removed in v3 type LRUCache struct { cache map[string]*list.Element ll *list.List - mu sync.Mutex stats CacheStats cacheSize int cacheLen int64 } -type CacheStats struct { - Size int64 - Hit int64 - Miss int64 - UnexpiredEvictions int64 -} - var _ Cache = &LRUCache{} // NewLRUCache creates a new Cache with a maximum size. +// Deprecated: Use NewLRUMutexCache instead. This will be removed in v3 func NewLRUCache(maxSize int) *LRUCache { setter.SetDefault(&maxSize, 50_000) @@ -59,10 +51,10 @@ func NewLRUCache(maxSize int) *LRUCache { } } -// Each is not thread-safe. Each() maintains a goroutine that iterates. -// Other go routines cannot safely access the Cache while iterating. -// It would be safer if this were done using an iterator or delegate pattern -// that doesn't require a goroutine. May need to reassess functional requirements. +// Each maintains a goroutine that iterates. Other go routines cannot safely +// access the Cache while iterating. It would be safer if this were done +// using an iterator or delegate pattern that doesn't require a goroutine. +// May need to reassess functional requirements. func (c *LRUCache) Each() chan *CacheItem { out := make(chan *CacheItem) go func() { @@ -74,11 +66,9 @@ func (c *LRUCache) Each() chan *CacheItem { return out } -// Add adds a value to the cache. -func (c *LRUCache) Add(item *CacheItem) bool { - c.mu.Lock() - defer c.mu.Unlock() - +// AddIfNotPresent adds a value to the cache. returns true if value was added to the cache, +// false if it already exists in the cache +func (c *LRUCache) AddIfNotPresent(item *CacheItem) bool { // If the key already exist, set the new value if ee, ok := c.cache[item.Key]; ok { c.ll.MoveToFront(ee) @@ -95,24 +85,43 @@ func (c *LRUCache) Add(item *CacheItem) bool { return false } -// MillisecondNow returns unix epoch in milliseconds -func MillisecondNow() int64 { - return clock.Now().UnixNano() / 1000000 +// Add adds a value to the cache. +// Deprecated: Gubernator algorithms now use AddIfNotExists. +// This method will be removed in the next major version +func (c *LRUCache) Add(item *CacheItem) bool { + // If the key already exist, set the new value + if ee, ok := c.cache[item.Key]; ok { + c.ll.MoveToFront(ee) + ee.Value = item + return true + } + + ele := c.ll.PushFront(item) + c.cache[item.Key] = ele + if c.cacheSize != 0 && c.ll.Len() > c.cacheSize { + c.removeOldest() + } + atomic.StoreInt64(&c.cacheLen, int64(c.ll.Len())) + return false } // GetItem returns the item stored in the cache func (c *LRUCache) GetItem(key string) (item *CacheItem, ok bool) { - c.mu.Lock() - defer c.mu.Unlock() if ele, hit := c.cache[key]; hit { entry := ele.Value.(*CacheItem) - c.stats.Hit++ + if entry.IsExpired() { + c.removeElement(ele) + atomic.AddInt64(&c.stats.Miss, 1) + return + } + + atomic.AddInt64(&c.stats.Hit, 1) c.ll.MoveToFront(ele) return entry, true } - c.stats.Miss++ + atomic.AddInt64(&c.stats.Miss, 1) return } @@ -130,7 +139,7 @@ func (c *LRUCache) removeOldest() { entry := ele.Value.(*CacheItem) if MillisecondNow() < entry.ExpireAt { - c.stats.UnexpiredEvictions++ + atomic.AddInt64(&c.stats.UnexpiredEvictions, 1) } c.removeElement(ele) @@ -149,6 +158,16 @@ func (c *LRUCache) Size() int64 { return atomic.LoadInt64(&c.cacheLen) } +// UpdateExpiration updates the expiration time for the key +func (c *LRUCache) UpdateExpiration(key string, expireAt int64) bool { + if ele, hit := c.cache[key]; hit { + entry := ele.Value.(*CacheItem) + entry.ExpireAt = expireAt + return true + } + return false +} + func (c *LRUCache) Close() error { c.cache = nil c.ll = nil @@ -158,77 +177,10 @@ func (c *LRUCache) Close() error { // Stats returns the current status for the cache func (c *LRUCache) Stats() CacheStats { - c.mu.Lock() - defer func() { - c.stats = CacheStats{} - c.mu.Unlock() - }() - - c.stats.Size = atomic.LoadInt64(&c.cacheLen) - return c.stats -} - -// CacheCollector provides prometheus metrics collector for LRUCache. -// Register only one collector, add one or more caches to this collector. -type CacheCollector struct { - caches []Cache - metricSize prometheus.Gauge - metricAccess *prometheus.CounterVec - metricUnexpiredEvictions prometheus.Counter -} - -func NewCacheCollector() *CacheCollector { - return &CacheCollector{ - caches: []Cache{}, - metricSize: prometheus.NewGauge(prometheus.GaugeOpts{ - Name: "gubernator_cache_size", - Help: "The number of items in LRU Cache which holds the rate limits.", - }), - metricAccess: prometheus.NewCounterVec(prometheus.CounterOpts{ - Name: "gubernator_cache_access_count", - Help: "Cache access counts. Label \"type\" = hit|miss.", - }, []string{"type"}), - metricUnexpiredEvictions: prometheus.NewCounter(prometheus.CounterOpts{ - Name: "gubernator_unexpired_evictions_count", - Help: "Count the number of cache items which were evicted while unexpired.", - }), - } -} - -var _ prometheus.Collector = &CacheCollector{} - -// AddCache adds a Cache object to be tracked by the collector. -func (c *CacheCollector) AddCache(cache Cache) { - c.caches = append(c.caches, cache) -} - -// Describe fetches prometheus metrics to be registered -func (c *CacheCollector) Describe(ch chan<- *prometheus.Desc) { - c.metricSize.Describe(ch) - c.metricAccess.Describe(ch) - c.metricUnexpiredEvictions.Describe(ch) -} - -// Collect fetches metric counts and gauges from the cache -func (c *CacheCollector) Collect(ch chan<- prometheus.Metric) { - stats := c.getStats() - c.metricSize.Set(float64(stats.Size)) - c.metricSize.Collect(ch) - c.metricAccess.WithLabelValues("miss").Add(float64(stats.Miss)) - c.metricAccess.WithLabelValues("hit").Add(float64(stats.Hit)) - c.metricAccess.Collect(ch) - c.metricUnexpiredEvictions.Add(float64(stats.UnexpiredEvictions)) - c.metricUnexpiredEvictions.Collect(ch) -} - -func (c *CacheCollector) getStats() CacheStats { - var total CacheStats - for _, cache := range c.caches { - stats := cache.Stats() - total.Hit += stats.Hit - total.Miss += stats.Miss - total.Size += stats.Size - total.UnexpiredEvictions += stats.UnexpiredEvictions - } - return total + var result CacheStats + result.UnexpiredEvictions = atomic.SwapInt64(&c.stats.UnexpiredEvictions, 0) + result.Miss = atomic.SwapInt64(&c.stats.Miss, 0) + result.Hit = atomic.SwapInt64(&c.stats.Hit, 0) + result.Size = atomic.LoadInt64(&c.cacheLen) + return result } diff --git a/lrumutex.go b/lrumutex.go new file mode 100644 index 0000000..e54f15e --- /dev/null +++ b/lrumutex.go @@ -0,0 +1,163 @@ +/* +Modifications Copyright 2024 Derrick Wippler + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +This work is derived from github.com/golang/groupcache/lru +*/ + +package gubernator + +import ( + "container/list" + "sync" + "sync/atomic" + + "github.com/mailgun/holster/v4/setter" +) + +// LRUMutexCache is a mutex protected LRU cache that supports expiration and is thread-safe +type LRUMutexCache struct { + cache map[string]*list.Element + ll *list.List + mu sync.Mutex + stats CacheStats + cacheSize int + cacheLen int64 +} + +var _ Cache = &LRUMutexCache{} + +// NewLRUMutexCache creates a new Cache with a maximum size. +func NewLRUMutexCache(maxSize int) *LRUMutexCache { + setter.SetDefault(&maxSize, 50_000) + + return &LRUMutexCache{ + cache: make(map[string]*list.Element), + ll: list.New(), + cacheSize: maxSize, + } +} + +// Each maintains a goroutine that iterates over every item in the cache. +// Other go routines operating on this cache will block until all items +// are read from the returned channel. +func (c *LRUMutexCache) Each() chan *CacheItem { + out := make(chan *CacheItem) + go func() { + c.mu.Lock() + defer c.mu.Unlock() + + for _, ele := range c.cache { + out <- ele.Value.(*CacheItem) + } + close(out) + }() + return out +} + +// Add is a noop as it is deprecated +func (c *LRUMutexCache) Add(item *CacheItem) bool { + return false +} + +// AddIfNotPresent adds the item to the cache if it doesn't already exist. +// Returns true if the item was added, false if the item already exists. +func (c *LRUMutexCache) AddIfNotPresent(item *CacheItem) bool { + c.mu.Lock() + defer c.mu.Unlock() + + // If the key already exist, do nothing + if _, ok := c.cache[item.Key]; ok { + return false + } + + ele := c.ll.PushFront(item) + c.cache[item.Key] = ele + if c.cacheSize != 0 && c.ll.Len() > c.cacheSize { + c.removeOldest() + } + atomic.StoreInt64(&c.cacheLen, int64(c.ll.Len())) + return true +} + +// GetItem returns the item stored in the cache +func (c *LRUMutexCache) GetItem(key string) (item *CacheItem, ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + if ele, hit := c.cache[key]; hit { + entry := ele.Value.(*CacheItem) + + c.stats.Hit++ + c.ll.MoveToFront(ele) + return entry, true + } + + c.stats.Miss++ + return +} + +// Remove removes the provided key from the cache. +func (c *LRUMutexCache) Remove(key string) { + c.mu.Lock() + defer c.mu.Unlock() + + if ele, hit := c.cache[key]; hit { + c.removeElement(ele) + } +} + +// RemoveOldest removes the oldest item from the cache. +func (c *LRUMutexCache) removeOldest() { + ele := c.ll.Back() + if ele != nil { + entry := ele.Value.(*CacheItem) + + if MillisecondNow() < entry.ExpireAt { + c.stats.UnexpiredEvictions++ + } + + c.removeElement(ele) + } +} + +func (c *LRUMutexCache) removeElement(e *list.Element) { + c.ll.Remove(e) + kv := e.Value.(*CacheItem) + delete(c.cache, kv.Key) + atomic.StoreInt64(&c.cacheLen, int64(c.ll.Len())) +} + +// Size returns the number of items in the cache. +func (c *LRUMutexCache) Size() int64 { + return atomic.LoadInt64(&c.cacheLen) +} + +func (c *LRUMutexCache) Close() error { + c.cache = nil + c.ll = nil + c.cacheLen = 0 + return nil +} + +// Stats returns the current status for the cache +func (c *LRUMutexCache) Stats() CacheStats { + c.mu.Lock() + defer func() { + c.stats = CacheStats{} + c.mu.Unlock() + }() + + c.stats.Size = atomic.LoadInt64(&c.cacheLen) + return c.stats +} diff --git a/lrumutext_test.go b/lrumutext_test.go new file mode 100644 index 0000000..8e4b637 --- /dev/null +++ b/lrumutext_test.go @@ -0,0 +1,594 @@ +/* +Copyright 2018-2022 Mailgun Technologies Inc + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package gubernator_test + +import ( + "fmt" + "math/rand" + "strconv" + "sync" + "testing" + "time" + + "github.com/gubernator-io/gubernator/v2" + "github.com/mailgun/holster/v4/clock" + "github.com/prometheus/client_golang/prometheus" + dto "github.com/prometheus/client_model/go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLRUMutexCache(t *testing.T) { + const iterations = 1000 + const concurrency = 100 + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + + t.Run("Happy path", func(t *testing.T) { + cache := gubernator.NewLRUMutexCache(0) + + // Populate cache. + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + assert.True(t, cache.AddIfNotPresent(item)) + } + + // Validate cache. + assert.Equal(t, int64(iterations), cache.Size()) + + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item, ok := cache.GetItem(key) + require.True(t, ok) + require.NotNil(t, item) + assert.Equal(t, item.Value, i) + } + + // Clear cache. + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + cache.Remove(key) + } + + assert.Zero(t, cache.Size()) + }) + + t.Run("Update an existing key", func(t *testing.T) { + cache := gubernator.NewLRUMutexCache(0) + const key = "foobar" + + // Add key. + item1 := &gubernator.CacheItem{ + Key: key, + Value: "initial value", + ExpireAt: expireAt, + } + require.True(t, cache.AddIfNotPresent(item1)) + + // Update same key. + item2 := &gubernator.CacheItem{ + Key: key, + Value: "new value", + ExpireAt: expireAt, + } + require.False(t, cache.AddIfNotPresent(item2)) + + updateItem, ok := cache.GetItem(item1.Key) + require.True(t, ok) + updateItem.Value = "new value" + + // Verify. + verifyItem, ok := cache.GetItem(key) + require.True(t, ok) + assert.Equal(t, item2, verifyItem) + }) + + t.Run("Concurrent reads", func(t *testing.T) { + cache := gubernator.NewLRUMutexCache(0) + + // Populate cache. + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + assert.True(t, cache.AddIfNotPresent(item)) + } + + assert.Equal(t, int64(iterations), cache.Size()) + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for thread := 0; thread < concurrency; thread++ { + doneWg.Add(1) + + go func() { + defer doneWg.Done() + launchWg.Wait() + + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item, ok := cache.GetItem(key) + assert.True(t, ok) + require.NotNil(t, item) + assert.Equal(t, item.Value, i) + } + }() + } + + // Wait for goroutines to finish. + launchWg.Done() + doneWg.Wait() + }) + + t.Run("Concurrent writes", func(t *testing.T) { + cache := gubernator.NewLRUMutexCache(0) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for thread := 0; thread < concurrency; thread++ { + doneWg.Add(1) + + go func() { + defer doneWg.Done() + launchWg.Wait() + + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + } + }() + } + + // Wait for goroutines to finish. + launchWg.Done() + doneWg.Wait() + }) + + t.Run("Concurrent reads and writes", func(t *testing.T) { + cache := gubernator.NewLRUMutexCache(0) + + // Populate cache. + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + assert.True(t, cache.AddIfNotPresent(item)) + } + + assert.Equal(t, int64(iterations), cache.Size()) + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for thread := 0; thread < concurrency; thread++ { + doneWg.Add(2) + + go func() { + defer doneWg.Done() + launchWg.Wait() + + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item, ok := cache.GetItem(key) + assert.True(t, ok) + require.NotNil(t, item) + assert.Equal(t, item.Value, i) + } + }() + + go func() { + defer doneWg.Done() + launchWg.Wait() + + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + } + }() + } + + // Wait for goroutines to finish. + launchWg.Done() + doneWg.Wait() + }) + + t.Run("Collect metrics during concurrent reads/writes", func(t *testing.T) { + cache := gubernator.NewLRUMutexCache(0) + + // Populate cache. + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + } + + assert.Equal(t, int64(iterations), cache.Size()) + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for thread := 0; thread < concurrency; thread++ { + doneWg.Add(3) + + go func() { + defer doneWg.Done() + launchWg.Wait() + + for i := 0; i < iterations; i++ { + // Get, cache hit. + key := strconv.Itoa(i) + _, _ = cache.GetItem(key) + + // Get, cache miss. + key2 := strconv.Itoa(rand.Intn(1000) + 10000) + _, _ = cache.GetItem(key2) + } + }() + + go func() { + defer doneWg.Done() + launchWg.Wait() + + for i := 0; i < iterations; i++ { + // Add existing. + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + + // Add new. + key2 := strconv.Itoa(rand.Intn(1000) + 20000) + item2 := &gubernator.CacheItem{ + Key: key2, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item2) + } + }() + + collector := gubernator.NewCacheCollector() + collector.AddCache(cache) + + go func() { + defer doneWg.Done() + launchWg.Wait() + + for i := 0; i < iterations; i++ { + // Get metrics. + ch := make(chan prometheus.Metric, 10) + collector.Collect(ch) + } + }() + } + + // Wait for goroutines to finish. + launchWg.Done() + doneWg.Wait() + }) + + t.Run("Check gubernator_unexpired_evictions_count metric is not incremented when expired item is evicted", func(t *testing.T) { + defer clock.Freeze(clock.Now()).Unfreeze() + + promRegister := prometheus.NewRegistry() + + // The LRU cache for storing rate limits. + cacheCollector := gubernator.NewCacheCollector() + err := promRegister.Register(cacheCollector) + require.NoError(t, err) + + cache := gubernator.NewLRUMutexCache(10) + cacheCollector.AddCache(cache) + + // fill cache with short duration cache items + for i := 0; i < 10; i++ { + cache.AddIfNotPresent(&gubernator.CacheItem{ + Algorithm: gubernator.Algorithm_LEAKY_BUCKET, + Key: fmt.Sprintf("short-expiry-%d", i), + Value: "bar", + ExpireAt: clock.Now().Add(5 * time.Minute).UnixMilli(), + }) + } + + // jump forward in time to expire all short duration keys + clock.Advance(6 * time.Minute) + + // add a new cache item to force eviction + cache.AddIfNotPresent(&gubernator.CacheItem{ + Algorithm: gubernator.Algorithm_LEAKY_BUCKET, + Key: "evict1", + Value: "bar", + ExpireAt: clock.Now().Add(1 * time.Hour).UnixMilli(), + }) + + collChan := make(chan prometheus.Metric, 64) + cacheCollector.Collect(collChan) + // Check metrics to verify evicted cache key is expired + <-collChan + <-collChan + <-collChan + m := <-collChan // gubernator_unexpired_evictions_count + met := new(dto.Metric) + _ = m.Write(met) + assert.Contains(t, m.Desc().String(), "gubernator_unexpired_evictions_count") + assert.Equal(t, 0, int(*met.Counter.Value)) + }) + + t.Run("Check gubernator_unexpired_evictions_count metric is incremented when unexpired item is evicted", func(t *testing.T) { + defer clock.Freeze(clock.Now()).Unfreeze() + + promRegister := prometheus.NewRegistry() + + // The LRU cache for storing rate limits. + cacheCollector := gubernator.NewCacheCollector() + err := promRegister.Register(cacheCollector) + require.NoError(t, err) + + cache := gubernator.NewLRUMutexCache(10) + cacheCollector.AddCache(cache) + + // fill cache with long duration cache items + for i := 0; i < 10; i++ { + cache.AddIfNotPresent(&gubernator.CacheItem{ + Algorithm: gubernator.Algorithm_LEAKY_BUCKET, + Key: fmt.Sprintf("long-expiry-%d", i), + Value: "bar", + ExpireAt: clock.Now().Add(1 * time.Hour).UnixMilli(), + }) + } + + // add a new cache item to force eviction + cache.AddIfNotPresent(&gubernator.CacheItem{ + Algorithm: gubernator.Algorithm_LEAKY_BUCKET, + Key: "evict2", + Value: "bar", + ExpireAt: clock.Now().Add(1 * time.Hour).UnixMilli(), + }) + + // Check metrics to verify evicted cache key is *NOT* expired + collChan := make(chan prometheus.Metric, 64) + cacheCollector.Collect(collChan) + <-collChan + <-collChan + <-collChan + m := <-collChan // gubernator_unexpired_evictions_count + met := new(dto.Metric) + _ = m.Write(met) + assert.Contains(t, m.Desc().String(), "gubernator_unexpired_evictions_count") + assert.Equal(t, 1, int(*met.Counter.Value)) + }) +} + +func BenchmarkLRUMutexCache(b *testing.B) { + + b.Run("Sequential reads", func(b *testing.B) { + cache := gubernator.NewLRUMutexCache(b.N) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + + // Populate cache. + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + assert.True(b, cache.AddIfNotPresent(item)) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + _, _ = cache.GetItem(key) + } + }) + + b.Run("Sequential writes", func(b *testing.B) { + cache := gubernator.NewLRUMutexCache(0) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + } + }) + + b.Run("Concurrent reads", func(b *testing.B) { + cache := gubernator.NewLRUMutexCache(b.N) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + + // Populate cache. + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + assert.True(b, cache.AddIfNotPresent(item)) + } + + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + doneWg.Add(1) + + go func() { + defer doneWg.Done() + launchWg.Wait() + + _, _ = cache.GetItem(key) + }() + } + + b.ReportAllocs() + b.ResetTimer() + launchWg.Done() + doneWg.Wait() + }) + + b.Run("Concurrent writes", func(b *testing.B) { + cache := gubernator.NewLRUMutexCache(0) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + doneWg.Add(1) + + go func(i int) { + defer doneWg.Done() + launchWg.Wait() + + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + }(i) + } + + b.ReportAllocs() + b.ResetTimer() + launchWg.Done() + doneWg.Wait() + }) + + b.Run("Concurrent reads and writes of existing keys", func(b *testing.B) { + cache := gubernator.NewLRUMutexCache(0) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + // Populate cache. + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + assert.True(b, cache.AddIfNotPresent(item)) + } + + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + doneWg.Add(2) + + go func() { + defer doneWg.Done() + launchWg.Wait() + + _, _ = cache.GetItem(key) + }() + + go func(i int) { + defer doneWg.Done() + launchWg.Wait() + + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + }(i) + } + + b.ReportAllocs() + b.ResetTimer() + launchWg.Done() + doneWg.Wait() + }) + + b.Run("Concurrent reads and writes of non-existent keys", func(b *testing.B) { + cache := gubernator.NewLRUMutexCache(0) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for i := 0; i < b.N; i++ { + doneWg.Add(2) + + go func(i int) { + defer doneWg.Done() + launchWg.Wait() + + key := strconv.Itoa(i) + _, _ = cache.GetItem(key) + }(i) + + go func(i int) { + defer doneWg.Done() + launchWg.Wait() + + key := "z" + strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + _ = cache.AddIfNotPresent(item) + }(i) + } + + b.ReportAllocs() + b.ResetTimer() + launchWg.Done() + doneWg.Wait() + }) +} diff --git a/mock_cache_test.go b/mock_cache_test.go index 536a93b..6b72b39 100644 --- a/mock_cache_test.go +++ b/mock_cache_test.go @@ -34,6 +34,11 @@ func (m *MockCache) Add(item *guber.CacheItem) bool { return args.Bool(0) } +func (m *MockCache) AddIfNotPresent(item *guber.CacheItem) bool { + args := m.Called(item) + return args.Bool(0) +} + func (m *MockCache) GetItem(key string) (value *guber.CacheItem, ok bool) { args := m.Called(key) retval, _ := args.Get(0).(*guber.CacheItem) diff --git a/otter.go b/otter.go index 3ba29f9..865a454 100644 --- a/otter.go +++ b/otter.go @@ -44,10 +44,15 @@ func NewOtterCache(size int) (*OtterCache, error) { return o, nil } -// Add adds a new CacheItem to the cache. The key must be provided via CacheItem.Key -// returns true if the item was added to the cache; false if the item was too large -// for the cache or already exists in the cache. +// Add is a noop as it is deprecated func (o *OtterCache) Add(item *CacheItem) bool { + return false +} + +// AddIfNotPresent adds a new CacheItem to the cache. The key must be provided via CacheItem.Key +// returns true if the item was added to the cache; false if the item was too large +// for the cache or if the key already exists in the cache. +func (o *OtterCache) AddIfNotPresent(item *CacheItem) bool { return o.cache.SetIfAbsent(item.Key, item) } diff --git a/otter_test.go b/otter_test.go index 9c84df1..650ea2f 100644 --- a/otter_test.go +++ b/otter_test.go @@ -29,7 +29,7 @@ func TestOtterCache(t *testing.T) { Value: i, ExpireAt: expireAt, } - cache.Add(item) + cache.AddIfNotPresent(item) } // Validate cache. @@ -63,7 +63,7 @@ func TestOtterCache(t *testing.T) { Value: "initial value", ExpireAt: expireAt, } - cache.Add(item1) + cache.AddIfNotPresent(item1) // Update same key is refused item2 := &gubernator.CacheItem{ @@ -71,7 +71,7 @@ func TestOtterCache(t *testing.T) { Value: "new value", ExpireAt: expireAt, } - assert.False(t, cache.Add(item2)) + assert.False(t, cache.AddIfNotPresent(item2)) // Fetch and update the CacheItem update, ok := cache.GetItem(key) @@ -96,7 +96,7 @@ func TestOtterCache(t *testing.T) { Value: i, ExpireAt: expireAt, } - cache.Add(item) + cache.AddIfNotPresent(item) } assert.Equal(t, int64(iterations), cache.Size()) @@ -146,7 +146,7 @@ func TestOtterCache(t *testing.T) { Value: i, ExpireAt: expireAt, } - cache.Add(item) + cache.AddIfNotPresent(item) } }() } @@ -168,7 +168,7 @@ func TestOtterCache(t *testing.T) { Value: i, ExpireAt: expireAt, } - cache.Add(item) + cache.AddIfNotPresent(item) } assert.Equal(t, int64(iterations), cache.Size()) @@ -203,7 +203,7 @@ func TestOtterCache(t *testing.T) { Value: i, ExpireAt: expireAt, } - cache.Add(item) + cache.AddIfNotPresent(item) } }() } @@ -229,7 +229,7 @@ func BenchmarkOtterCache(b *testing.B) { Value: i, ExpireAt: expireAt, } - cache.Add(item) + cache.AddIfNotPresent(item) } b.ReportAllocs() @@ -255,7 +255,7 @@ func BenchmarkOtterCache(b *testing.B) { Value: i, ExpireAt: expireAt, } - cache.Add(item) + cache.AddIfNotPresent(item) } }) @@ -271,7 +271,7 @@ func BenchmarkOtterCache(b *testing.B) { Value: i, ExpireAt: expireAt, } - cache.Add(item) + cache.AddIfNotPresent(item) } var launchWg, doneWg sync.WaitGroup @@ -314,7 +314,7 @@ func BenchmarkOtterCache(b *testing.B) { Value: i, ExpireAt: expireAt, } - cache.Add(item) + cache.AddIfNotPresent(item) }(i) } @@ -338,7 +338,7 @@ func BenchmarkOtterCache(b *testing.B) { Value: i, ExpireAt: expireAt, } - cache.Add(item) + cache.AddIfNotPresent(item) } for i := 0; i < b.N; i++ { @@ -361,7 +361,7 @@ func BenchmarkOtterCache(b *testing.B) { Value: i, ExpireAt: expireAt, } - cache.Add(item) + cache.AddIfNotPresent(item) }(i) } @@ -398,7 +398,7 @@ func BenchmarkOtterCache(b *testing.B) { Value: i, ExpireAt: expireAt, } - cache.Add(item) + cache.AddIfNotPresent(item) }(i) }