From 8f414323c5edd86afc042948f482d2e084cc6860 Mon Sep 17 00:00:00 2001 From: "Derrick J. Wippler" Date: Mon, 13 May 2024 16:41:34 -0500 Subject: [PATCH] Fixed race condition in tokenBucket() --- Makefile | 3 + algorithms.go | 594 +++++++++++++++++++----------------- benchmark_cache_test.go | 16 +- cache.go | 7 + cache_manager.go | 9 +- cluster/cluster.go | 17 ++ cmd/gubernator/main_test.go | 5 +- functional_test.go | 10 +- gubernator.go | 28 +- lrucache.go | 11 +- otter.go | 25 +- otter_test.go | 9 +- store.go | 79 ++--- store_test.go | 91 ++++++ 14 files changed, 542 insertions(+), 362 deletions(-) diff --git a/Makefile b/Makefile index d98f86d..6ecd257 100644 --- a/Makefile +++ b/Makefile @@ -45,6 +45,9 @@ clean-proto: ## Clean the generated source files from the protobuf sources @find . -name "*.pb.go" -type f -delete @find . -name "*.pb.*.go" -type f -delete +.PHONY: validate +validate: lint test + go mod tidy && git diff --exit-code .PHONY: proto proto: ## Build protos diff --git a/algorithms.go b/algorithms.go index c923161..a118294 100644 --- a/algorithms.go +++ b/algorithms.go @@ -18,14 +18,26 @@ package gubernator import ( "context" + "errors" "github.com/mailgun/holster/v4/clock" "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" - "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" ) +var errAlreadyExistsInCache = errors.New("already exists in cache") + +type rateContext struct { + context.Context + ReqState RateLimitReqState + Request *RateLimitReq + CacheItem *CacheItem + Store Store + Cache Cache + // TODO: Remove + InstanceID string +} + // ### NOTE ### // The both token and leaky follow the same semantic which allows for requests of more than the limit // to be rejected, but subsequent requests within the same window that are under the limit to succeed. @@ -34,223 +46,240 @@ import ( // with 100 emails and the request will succeed. You can override this default behavior with `DRAIN_OVER_LIMIT` // Implements token bucket algorithm for rate limiting. https://en.wikipedia.org/wiki/Token_bucket -func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqState RateLimitReqState) (resp *RateLimitResp, err error) { +func tokenBucket(ctx rateContext) (resp *RateLimitResp, err error) { tokenBucketTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("tokenBucket")) defer tokenBucketTimer.ObserveDuration() - - // Get rate limit from cache. - hashKey := r.HashKey() - item, ok := c.GetItem(hashKey) - - if s != nil && !ok { - // Cache miss. - // Check our store for the item. - if item, ok = s.Get(ctx, r); ok { - c.Add(item) + var ok bool + // TODO: Remove + //fmt.Printf("[%s] tokenBucket()\n", ctx.InstanceID) + + // Get rate limit from cache + hashKey := ctx.Request.HashKey() + ctx.CacheItem, ok = ctx.Cache.GetItem(hashKey) + + // If not in the cache, check the store if provided + if ctx.Store != nil && !ok { + if ctx.CacheItem, ok = ctx.Store.Get(ctx, ctx.Request); ok { + if !ctx.Cache.Add(ctx.CacheItem) { + // Someone else added a new token bucket item to the cache for this + // rate limit before we did, so we retry by calling ourselves recursively. + return tokenBucket(ctx) + } } } - // Sanity checks. 
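+	// Note on the recursive retry above: when Cache.Add() reports that another request
+	// won the race to insert this key, the retry is expected to find the winner's item
+	// via GetItem() on the next pass and fall through to the normal update path, so in
+	// practice the recursion terminates after a single extra pass.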
- if ok { - if item.Value == nil { - msgPart := "tokenBucket: Invalid cache item; Value is nil" - trace.SpanFromContext(ctx).AddEvent(msgPart, trace.WithAttributes( - attribute.String("hashKey", hashKey), - attribute.String("key", r.UniqueKey), - attribute.String("name", r.Name), - )) - logrus.Error(msgPart) - ok = false - } else if item.Key != hashKey { - msgPart := "tokenBucket: Invalid cache item; key mismatch" - trace.SpanFromContext(ctx).AddEvent(msgPart, trace.WithAttributes( - attribute.String("itemKey", item.Key), - attribute.String("hashKey", hashKey), - attribute.String("name", r.Name), - )) - logrus.Error(msgPart) - ok = false + // If no item was found, or the item is expired. + if ctx.CacheItem == nil || ctx.CacheItem.IsExpired() { + // Initialize the Token bucket item + rl, err := InitTokenBucketItem(ctx) + if err != nil && errors.Is(err, errAlreadyExistsInCache) { + // Someone else added a new token bucket item to the cache for this + // rate limit before we did, so we retry by calling ourselves recursively. + return tokenBucket(ctx) } + return rl, err } - if ok { - // Item found in cache or store. - if HasBehavior(r.Behavior, Behavior_RESET_REMAINING) { - c.Remove(hashKey) + // Gain exclusive rights to this item while we calculate the rate limit + ctx.CacheItem.mutex.Lock() + defer ctx.CacheItem.mutex.Unlock() - if s != nil { - s.Remove(ctx, hashKey) - } - return &RateLimitResp{ - Status: Status_UNDER_LIMIT, - Limit: r.Limit, - Remaining: r.Limit, - ResetTime: 0, - }, nil + t, ok := ctx.CacheItem.Value.(*TokenBucketItem) + if !ok { + // Client switched algorithms; perhaps due to a migration? + ctx.Cache.Remove(hashKey) + if ctx.Store != nil { + ctx.Store.Remove(ctx, hashKey) } - t, ok := item.Value.(*TokenBucketItem) - if !ok { - // Client switched algorithms; perhaps due to a migration? - trace.SpanFromContext(ctx).AddEvent("Client switched algorithms; perhaps due to a migration?") + ctx.CacheItem = nil - c.Remove(hashKey) + rl, err := InitTokenBucketItem(ctx) + if err != nil && errors.Is(err, errAlreadyExistsInCache) { + return tokenBucket(ctx) + } + return rl, err + } - if s != nil { - s.Remove(ctx, hashKey) - } + if HasBehavior(ctx.Request.Behavior, Behavior_RESET_REMAINING) { + t.Remaining = ctx.Request.Limit + t.Limit = ctx.Request.Limit + t.Status = Status_UNDER_LIMIT - return tokenBucketNewItem(ctx, s, c, r, reqState) + if ctx.Store != nil { + ctx.Store.OnChange(ctx, ctx.Request, ctx.CacheItem) } - // Update the limit if it changed. - if t.Limit != r.Limit { - // Add difference to remaining. - t.Remaining += r.Limit - t.Limit - if t.Remaining < 0 { - t.Remaining = 0 - } - t.Limit = r.Limit - } + return &RateLimitResp{ + Status: Status_UNDER_LIMIT, + Limit: ctx.Request.Limit, + Remaining: ctx.Request.Limit, + ResetTime: 0, + }, nil + } - rl := &RateLimitResp{ - Status: t.Status, - Limit: r.Limit, - Remaining: t.Remaining, - ResetTime: item.ExpireAt, + // Update the limit if it changed. + if t.Limit != ctx.Request.Limit { + // Add difference to remaining. + t.Remaining += ctx.Request.Limit - t.Limit + if t.Remaining < 0 { + t.Remaining = 0 } + t.Limit = ctx.Request.Limit + } - // If the duration config changed, update the new ExpireAt. 
- if t.Duration != r.Duration { - span := trace.SpanFromContext(ctx) - span.AddEvent("Duration changed") - expire := t.CreatedAt + r.Duration - if HasBehavior(r.Behavior, Behavior_DURATION_IS_GREGORIAN) { - expire, err = GregorianExpiration(clock.Now(), r.Duration) - if err != nil { - return nil, err - } - } + rl := &RateLimitResp{ + Status: t.Status, + Limit: ctx.Request.Limit, + Remaining: t.Remaining, + ResetTime: ctx.CacheItem.ExpireAt, + } - // If our new duration means we are currently expired. - createdAt := *r.CreatedAt - if expire <= createdAt { - // Renew item. - span.AddEvent("Limit has expired") - expire = createdAt + r.Duration - t.CreatedAt = createdAt - t.Remaining = t.Limit + // If the duration config changed, update the new ExpireAt. + if t.Duration != ctx.Request.Duration { + span := trace.SpanFromContext(ctx) + span.AddEvent("Duration changed") + expire := t.CreatedAt + ctx.Request.Duration + if HasBehavior(ctx.Request.Behavior, Behavior_DURATION_IS_GREGORIAN) { + expire, err = GregorianExpiration(clock.Now(), ctx.Request.Duration) + if err != nil { + return nil, err } - - item.ExpireAt = expire - t.Duration = r.Duration - rl.ResetTime = expire } - if s != nil && reqState.IsOwner { - defer func() { - s.OnChange(ctx, r, item) - }() + // If our new duration means we are currently expired. + createdAt := *ctx.Request.CreatedAt + if expire <= createdAt { + // Renew item. + span.AddEvent("Limit has expired") + expire = createdAt + ctx.Request.Duration + t.CreatedAt = createdAt + t.Remaining = t.Limit } - // Client is only interested in retrieving the current status or - // updating the rate limit config. - if r.Hits == 0 { - return rl, nil - } + ctx.CacheItem.ExpireAt = expire + t.Duration = ctx.Request.Duration + rl.ResetTime = expire + } - // If we are already at the limit. - if rl.Remaining == 0 && r.Hits > 0 { - trace.SpanFromContext(ctx).AddEvent("Already over the limit") - if reqState.IsOwner { - metricOverLimitCounter.Add(1) - } - rl.Status = Status_OVER_LIMIT - t.Status = rl.Status - return rl, nil + if ctx.Store != nil && ctx.ReqState.IsOwner { + defer func() { + ctx.Store.OnChange(ctx, ctx.Request, ctx.CacheItem) + }() + } + + // Client is only interested in retrieving the current status or + // updating the rate limit config. + if ctx.Request.Hits == 0 { + return rl, nil + } + + // If we are already at the limit. + if rl.Remaining == 0 && ctx.Request.Hits > 0 { + trace.SpanFromContext(ctx).AddEvent("Already over the limit") + if ctx.ReqState.IsOwner { + metricOverLimitCounter.Add(1) } + rl.Status = Status_OVER_LIMIT + t.Status = rl.Status + return rl, nil + } + + // If requested hits takes the remainder. + if t.Remaining == ctx.Request.Hits { + trace.SpanFromContext(ctx).AddEvent("At the limit") + t.Remaining = 0 + rl.Remaining = 0 + return rl, nil + } - // If requested hits takes the remainder. - if t.Remaining == r.Hits { - trace.SpanFromContext(ctx).AddEvent("At the limit") + // If requested is more than available, then return over the limit + // without updating the cache. + if ctx.Request.Hits > t.Remaining { + trace.SpanFromContext(ctx).AddEvent("Over the limit") + if ctx.ReqState.IsOwner { + metricOverLimitCounter.Add(1) + } + rl.Status = Status_OVER_LIMIT + if HasBehavior(ctx.Request.Behavior, Behavior_DRAIN_OVER_LIMIT) { + // DRAIN_OVER_LIMIT behavior drains the remaining counter. t.Remaining = 0 rl.Remaining = 0 - return rl, nil } - - // If requested is more than available, then return over the limit - // without updating the cache. 
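+	// For example, with 10 tokens remaining a request for 25 hits is rejected with
+	// OVER_LIMIT while the 10 remaining tokens are left untouched, so a later request
+	// for 10 or fewer hits can still succeed. With DRAIN_OVER_LIMIT set, the same
+	// rejected request also empties the bucket until it resets.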
- if r.Hits > t.Remaining { - trace.SpanFromContext(ctx).AddEvent("Over the limit") - if reqState.IsOwner { - metricOverLimitCounter.Add(1) - } - rl.Status = Status_OVER_LIMIT - if HasBehavior(r.Behavior, Behavior_DRAIN_OVER_LIMIT) { - // DRAIN_OVER_LIMIT behavior drains the remaining counter. - t.Remaining = 0 - rl.Remaining = 0 - } - return rl, nil - } - - t.Remaining -= r.Hits - rl.Remaining = t.Remaining return rl, nil } - // Item is not found in cache or store, create new. - return tokenBucketNewItem(ctx, s, c, r, reqState) + t.Remaining -= ctx.Request.Hits + rl.Remaining = t.Remaining + return rl, nil } -// Called by tokenBucket() when adding a new item in the store. -func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqState RateLimitReqState) (resp *RateLimitResp, err error) { - createdAt := *r.CreatedAt - expire := createdAt + r.Duration +// InitTokenBucketItem will create a new item if the passed item is nil, else it will update the provided item. +func InitTokenBucketItem(ctx rateContext) (resp *RateLimitResp, err error) { + createdAt := *ctx.Request.CreatedAt + expire := createdAt + ctx.Request.Duration - t := &TokenBucketItem{ - Limit: r.Limit, - Duration: r.Duration, - Remaining: r.Limit - r.Hits, + t := TokenBucketItem{ + Limit: ctx.Request.Limit, + Duration: ctx.Request.Duration, + Remaining: ctx.Request.Limit - ctx.Request.Hits, CreatedAt: createdAt, } // Add a new rate limit to the cache. - if HasBehavior(r.Behavior, Behavior_DURATION_IS_GREGORIAN) { - expire, err = GregorianExpiration(clock.Now(), r.Duration) + if HasBehavior(ctx.Request.Behavior, Behavior_DURATION_IS_GREGORIAN) { + expire, err = GregorianExpiration(clock.Now(), ctx.Request.Duration) if err != nil { return nil, err } } - item := &CacheItem{ - Algorithm: Algorithm_TOKEN_BUCKET, - Key: r.HashKey(), - Value: t, - ExpireAt: expire, - } - rl := &RateLimitResp{ Status: Status_UNDER_LIMIT, - Limit: r.Limit, + Limit: ctx.Request.Limit, Remaining: t.Remaining, ResetTime: expire, } // Client could be requesting that we always return OVER_LIMIT. 
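+	// In that case the bucket is still created, but the response is OVER_LIMIT and
+	// Remaining is restored to the full limit, so the oversized request does not
+	// drain tokens it was never allowed to consume.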
- if r.Hits > r.Limit { + if ctx.Request.Hits > ctx.Request.Limit { trace.SpanFromContext(ctx).AddEvent("Over the limit") - if reqState.IsOwner { + if ctx.ReqState.IsOwner { metricOverLimitCounter.Add(1) } rl.Status = Status_OVER_LIMIT - rl.Remaining = r.Limit - t.Remaining = r.Limit + rl.Remaining = ctx.Request.Limit + t.Remaining = ctx.Request.Limit } - c.Add(item) + // If the cache item already exists, update it + if ctx.CacheItem != nil { + ctx.CacheItem.mutex.Lock() + ctx.CacheItem.Algorithm = Algorithm_TOKEN_BUCKET + ctx.CacheItem.ExpireAt = expire + in, ok := ctx.CacheItem.Value.(*TokenBucketItem) + if !ok { + // Likely the store gave us the wrong cache type + ctx.CacheItem.mutex.Unlock() + ctx.CacheItem = nil + return InitTokenBucketItem(ctx) + } + *in = t + ctx.CacheItem.mutex.Unlock() + } else { + // else create a new cache item and add it to the cache + ctx.CacheItem = &CacheItem{ + Algorithm: Algorithm_TOKEN_BUCKET, + Key: ctx.Request.HashKey(), + Value: &t, + ExpireAt: expire, + } + if !ctx.Cache.Add(ctx.CacheItem) { + return rl, errAlreadyExistsInCache + } + } - if s != nil && reqState.IsOwner { - s.OnChange(ctx, r, item) + if ctx.Store != nil && ctx.ReqState.IsOwner { + ctx.Store.OnChange(ctx, ctx.Request, ctx.CacheItem) } return rl, nil @@ -261,6 +290,8 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqStat leakyBucketTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.getRateLimit_leakyBucket")) defer leakyBucketTimer.ObserveDuration() + // TODO(thrawn01): Test for race conditions, and fix + if r.Burst == 0 { r.Burst = r.Limit } @@ -272,165 +303,158 @@ func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqStat item, ok := c.GetItem(hashKey) if s != nil && !ok { - // Cache miss. - // Check our store for the item. + // Cache missed, check our store for the item. if item, ok = s.Get(ctx, r); ok { - c.Add(item) + if !c.Add(item) { + // Someone else added a new leaky bucket item to the cache for this + // rate limit before we did, so we retry by calling ourselves recursively. + return leakyBucket(ctx, s, c, r, reqState) + } } } - // Sanity checks. - if ok { - if item.Value == nil { - msgPart := "leakyBucket: Invalid cache item; Value is nil" - trace.SpanFromContext(ctx).AddEvent(msgPart, trace.WithAttributes( - attribute.String("hashKey", hashKey), - attribute.String("key", r.UniqueKey), - attribute.String("name", r.Name), - )) - logrus.Error(msgPart) - ok = false - } else if item.Key != hashKey { - msgPart := "leakyBucket: Invalid cache item; key mismatch" - trace.SpanFromContext(ctx).AddEvent(msgPart, trace.WithAttributes( - attribute.String("itemKey", item.Key), - attribute.String("hashKey", hashKey), - attribute.String("name", r.Name), - )) - logrus.Error(msgPart) - ok = false + if !ok { + rl, err := leakyBucketNewItem(ctx, s, c, r, reqState) + if err != nil && errors.Is(err, errAlreadyExistsInCache) { + // Someone else added a new leaky bucket item to the cache for this + // rate limit before we did, so we retry by calling ourselves recursively. + return leakyBucket(ctx, s, c, r, reqState) } + return rl, err } - if ok { - // Item found in cache or store. + // Item found in cache or store. + b, ok := item.Value.(*LeakyBucketItem) + if !ok { + // Client switched algorithms; perhaps due to a migration? + c.Remove(hashKey) + if s != nil { + s.Remove(ctx, hashKey) + } - b, ok := item.Value.(*LeakyBucketItem) - if !ok { - // Client switched algorithms; perhaps due to a migration? 
- c.Remove(hashKey) + rl, err := leakyBucketNewItem(ctx, s, c, r, reqState) + if err != nil && errors.Is(err, errAlreadyExistsInCache) { + return leakyBucket(ctx, s, c, r, reqState) + } + return rl, err + } - if s != nil { - s.Remove(ctx, hashKey) - } + // Gain exclusive rights to this item while we calculate the rate limit + b.mutex.Lock() + defer b.mutex.Unlock() - return leakyBucketNewItem(ctx, s, c, r, reqState) - } + if HasBehavior(r.Behavior, Behavior_RESET_REMAINING) { + b.Remaining = float64(r.Burst) + } - if HasBehavior(r.Behavior, Behavior_RESET_REMAINING) { + // Update burst, limit and duration if they changed + if b.Burst != r.Burst { + if r.Burst > int64(b.Remaining) { b.Remaining = float64(r.Burst) } + b.Burst = r.Burst + } - // Update burst, limit and duration if they changed - if b.Burst != r.Burst { - if r.Burst > int64(b.Remaining) { - b.Remaining = float64(r.Burst) - } - b.Burst = r.Burst - } - - b.Limit = r.Limit - b.Duration = r.Duration - - duration := r.Duration - rate := float64(duration) / float64(r.Limit) + b.Limit = r.Limit + b.Duration = r.Duration - if HasBehavior(r.Behavior, Behavior_DURATION_IS_GREGORIAN) { - d, err := GregorianDuration(clock.Now(), r.Duration) - if err != nil { - return nil, err - } - n := clock.Now() - expire, err := GregorianExpiration(n, r.Duration) - if err != nil { - return nil, err - } + duration := r.Duration + rate := float64(duration) / float64(r.Limit) - // Calculate the rate using the entire duration of the gregorian interval - // IE: Minute = 60,000 milliseconds, etc.. etc.. - rate = float64(d) / float64(r.Limit) - // Update the duration to be the end of the gregorian interval - duration = expire - (n.UnixNano() / 1000000) + if HasBehavior(r.Behavior, Behavior_DURATION_IS_GREGORIAN) { + d, err := GregorianDuration(clock.Now(), r.Duration) + if err != nil { + return nil, err } - - if r.Hits != 0 { - c.UpdateExpiration(r.HashKey(), createdAt+duration) + n := clock.Now() + expire, err := GregorianExpiration(n, r.Duration) + if err != nil { + return nil, err } - // Calculate how much leaked out of the bucket since the last time we leaked a hit - elapsed := createdAt - b.UpdatedAt - leak := float64(elapsed) / rate + // Calculate the rate using the entire duration of the gregorian interval + // IE: Minute = 60,000 milliseconds, etc.. etc.. + rate = float64(d) / float64(r.Limit) + // Update the duration to be the end of the gregorian interval + duration = expire - (n.UnixNano() / 1000000) + } - if int64(leak) > 0 { - b.Remaining += leak - b.UpdatedAt = createdAt - } + if r.Hits != 0 { + c.UpdateExpiration(r.HashKey(), createdAt+duration) + } - if int64(b.Remaining) > b.Burst { - b.Remaining = float64(b.Burst) - } + // Calculate how much leaked out of the bucket since the last time we leaked a hit + elapsed := createdAt - b.UpdatedAt + leak := float64(elapsed) / rate - rl := &RateLimitResp{ - Limit: b.Limit, - Remaining: int64(b.Remaining), - Status: Status_UNDER_LIMIT, - ResetTime: createdAt + (b.Limit-int64(b.Remaining))*int64(rate), - } + if int64(leak) > 0 { + b.Remaining += leak + b.UpdatedAt = createdAt + } - // TODO: Feature missing: check for Duration change between item/request. 
+ if int64(b.Remaining) > b.Burst { + b.Remaining = float64(b.Burst) + } - if s != nil && reqState.IsOwner { - defer func() { - s.OnChange(ctx, r, item) - }() - } + rl := &RateLimitResp{ + Limit: b.Limit, + Remaining: int64(b.Remaining), + Status: Status_UNDER_LIMIT, + ResetTime: createdAt + (b.Limit-int64(b.Remaining))*int64(rate), + } - // If we are already at the limit - if int64(b.Remaining) == 0 && r.Hits > 0 { - if reqState.IsOwner { - metricOverLimitCounter.Add(1) - } - rl.Status = Status_OVER_LIMIT - return rl, nil - } + // TODO: Feature missing: check for Duration change between item/request. - // If requested hits takes the remainder - if int64(b.Remaining) == r.Hits { - b.Remaining = 0 - rl.Remaining = int64(b.Remaining) - rl.ResetTime = createdAt + (rl.Limit-rl.Remaining)*int64(rate) - return rl, nil - } + if s != nil && reqState.IsOwner { + defer func() { + s.OnChange(ctx, r, item) + }() + } - // If requested is more than available, then return over the limit - // without updating the bucket, unless `DRAIN_OVER_LIMIT` is set. - if r.Hits > int64(b.Remaining) { - if reqState.IsOwner { - metricOverLimitCounter.Add(1) - } - rl.Status = Status_OVER_LIMIT + // If we are already at the limit + if int64(b.Remaining) == 0 && r.Hits > 0 { + if reqState.IsOwner { + metricOverLimitCounter.Add(1) + } + rl.Status = Status_OVER_LIMIT + return rl, nil + } - // DRAIN_OVER_LIMIT behavior drains the remaining counter. - if HasBehavior(r.Behavior, Behavior_DRAIN_OVER_LIMIT) { - b.Remaining = 0 - rl.Remaining = 0 - } + // If requested hits takes the remainder + if int64(b.Remaining) == r.Hits { + b.Remaining = 0 + rl.Remaining = int64(b.Remaining) + rl.ResetTime = createdAt + (rl.Limit-rl.Remaining)*int64(rate) + return rl, nil + } - return rl, nil + // If requested is more than available, then return over the limit + // without updating the bucket, unless `DRAIN_OVER_LIMIT` is set. + if r.Hits > int64(b.Remaining) { + if reqState.IsOwner { + metricOverLimitCounter.Add(1) } + rl.Status = Status_OVER_LIMIT - // Client is only interested in retrieving the current status - if r.Hits == 0 { - return rl, nil + // DRAIN_OVER_LIMIT behavior drains the remaining counter. + if HasBehavior(r.Behavior, Behavior_DRAIN_OVER_LIMIT) { + b.Remaining = 0 + rl.Remaining = 0 } - b.Remaining -= float64(r.Hits) - rl.Remaining = int64(b.Remaining) - rl.ResetTime = createdAt + (rl.Limit-rl.Remaining)*int64(rate) return rl, nil } - return leakyBucketNewItem(ctx, s, c, r, reqState) + // Client is only interested in retrieving the current status + if r.Hits == 0 { + return rl, nil + } + + b.Remaining -= float64(r.Hits) + rl.Remaining = int64(b.Remaining) + rl.ResetTime = createdAt + (rl.Limit-rl.Remaining)*int64(rate) + return rl, nil + } // Called by leakyBucket() when adding a new item in the store. 
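+// If another request races us and inserts the same key first, Cache.Add() refuses the
+// insert and leakyBucketNewItem returns errAlreadyExistsInCache so that leakyBucket()
+// can retry against the item that is now in the cache.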
@@ -483,7 +507,9 @@ func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, Value: &b, } - c.Add(item) + if !c.Add(item) { + return nil, errAlreadyExistsInCache + } if s != nil && reqState.IsOwner { s.OnChange(ctx, r, item) diff --git a/benchmark_cache_test.go b/benchmark_cache_test.go index 98c2b68..1f849d8 100644 --- a/benchmark_cache_test.go +++ b/benchmark_cache_test.go @@ -39,7 +39,7 @@ func BenchmarkCache(b *testing.B) { cache, err := testCase.NewTestCache() require.NoError(b, err) expire := clock.Now().Add(time.Hour).UnixMilli() - keys := GenerateRandomKeys() + keys := GenerateRandomKeys(defaultNumKeys) for _, key := range keys { item := &gubernator.CacheItem{ @@ -64,7 +64,7 @@ func BenchmarkCache(b *testing.B) { cache, err := testCase.NewTestCache() require.NoError(b, err) expire := clock.Now().Add(time.Hour).UnixMilli() - keys := GenerateRandomKeys() + keys := GenerateRandomKeys(defaultNumKeys) mask := len(keys) - 1 b.ReportAllocs() @@ -85,7 +85,7 @@ func BenchmarkCache(b *testing.B) { cache, err := testCase.NewTestCache() require.NoError(b, err) expire := clock.Now().Add(time.Hour).UnixMilli() - keys := GenerateRandomKeys() + keys := GenerateRandomKeys(defaultNumKeys) for _, key := range keys { item := &gubernator.CacheItem{ @@ -129,7 +129,7 @@ func BenchmarkCache(b *testing.B) { cache, err := testCase.NewTestCache() require.NoError(b, err) expire := clock.Now().Add(time.Hour).UnixMilli() - keys := GenerateRandomKeys() + keys := GenerateRandomKeys(defaultNumKeys) var mutex sync.Mutex var task func(key string) @@ -172,11 +172,11 @@ func BenchmarkCache(b *testing.B) { } } -const cacheSize = 32768 +const defaultNumKeys = 32768 -func GenerateRandomKeys() []string { - keys := make([]string, 0, cacheSize) - for i := 0; i < cacheSize; i++ { +func GenerateRandomKeys(size int) []string { + keys := make([]string, 0, size) + for i := 0; i < size; i++ { keys = append(keys, gubernator.RandomString(20)) } return keys diff --git a/cache.go b/cache.go index 0fd431a..dbeea21 100644 --- a/cache.go +++ b/cache.go @@ -16,6 +16,8 @@ limitations under the License. 
package gubernator +import "sync" + type Cache interface { Add(item *CacheItem) bool UpdateExpiration(key string, expireAt int64) bool @@ -27,6 +29,7 @@ type Cache interface { } type CacheItem struct { + mutex sync.Mutex Algorithm Algorithm Key string Value interface{} @@ -41,6 +44,10 @@ type CacheItem struct { } func (item *CacheItem) IsExpired() bool { + // TODO(thrawn01): Eliminate the need for this mutex lock + item.mutex.Lock() + defer item.mutex.Unlock() + now := MillisecondNow() // If the entry is invalidated diff --git a/cache_manager.go b/cache_manager.go index 6542bdf..3514ce1 100644 --- a/cache_manager.go +++ b/cache_manager.go @@ -58,7 +58,14 @@ func (m *cacheManager) GetRateLimit(ctx context.Context, req *RateLimitReq, stat switch req.Algorithm { case Algorithm_TOKEN_BUCKET: - rlResponse, err = tokenBucket(ctx, m.conf.Store, m.cache, req, state) + rlResponse, err = tokenBucket(rateContext{ + Store: m.conf.Store, + Cache: m.cache, + ReqState: state, + Request: req, + Context: ctx, + InstanceID: m.conf.InstanceID, + }) if err != nil { msg := "Error in tokenBucket" countError(err, msg) diff --git a/cluster/cluster.go b/cluster/cluster.go index 3fef87e..ad3714e 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -53,6 +53,23 @@ func GetRandomPeer(dc string) gubernator.PeerInfo { return local[rand.Intn(len(local))] } +// GetRandomDaemon returns a random daemon from the cluster +func GetRandomDaemon(dc string) *gubernator.Daemon { + var local []*gubernator.Daemon + + for _, d := range daemons { + if d.PeerInfo.DataCenter == dc { + local = append(local, d) + } + } + + if len(local) == 0 { + panic(fmt.Sprintf("failed to find random daemon for dc '%s'", dc)) + } + + return local[rand.Intn(len(local))] +} + // GetPeers returns a list of all peers in the cluster func GetPeers() []gubernator.PeerInfo { return peers diff --git a/cmd/gubernator/main_test.go b/cmd/gubernator/main_test.go index 4f1364e..f374d7f 100644 --- a/cmd/gubernator/main_test.go +++ b/cmd/gubernator/main_test.go @@ -15,9 +15,10 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + cli "github.com/gubernator-io/gubernator/v2/cmd/gubernator" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "golang.org/x/net/proxy" ) @@ -78,9 +79,9 @@ func TestCLI(t *testing.T) { time.Sleep(time.Second * 1) err = c.Process.Signal(syscall.SIGTERM) - require.NoError(t, err, out.String()) <-waitCh + require.NoError(t, err, out.String()) assert.Contains(t, out.String(), tt.contains) }) } diff --git a/functional_test.go b/functional_test.go index d93d352..ba8be28 100644 --- a/functional_test.go +++ b/functional_test.go @@ -2252,8 +2252,14 @@ func waitForIdle(timeout clock.Duration, daemons ...*guber.Daemon) error { if err != nil { return err } - ggql := metrics["gubernator_global_queue_length"] - gsql := metrics["gubernator_global_send_queue_length"] + ggql, ok := metrics["gubernator_global_queue_length"] + if !ok { + return errors.New("gubernator_global_queue_length not found") + } + gsql, ok := metrics["gubernator_global_send_queue_length"] + if !ok { + return errors.New("gubernator_global_send_queue_length not found") + } if ggql.Value == 0 && gsql.Value == 0 { return nil diff --git a/gubernator.go b/gubernator.go index 1ae40d4..1ef4fac 100644 --- a/gubernator.go +++ b/gubernator.go @@ -420,12 +420,26 @@ func (s *V1Instance) getGlobalRateLimit(ctx context.Context, req *RateLimitReq) func (s *V1Instance) UpdatePeerGlobals(ctx context.Context, r *UpdatePeerGlobalsReq) 
(*UpdatePeerGlobalsResp, error) { defer prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.UpdatePeerGlobals")).ObserveDuration() now := MillisecondNow() + for _, g := range r.Globals { - item := &CacheItem{ - ExpireAt: g.Status.ResetTime, - Algorithm: g.Algorithm, - Key: g.Key, + item, _, err := s.cache.GetCacheItem(ctx, g.Key) + if err != nil { + return nil, err } + + if item == nil { + item = &CacheItem{ + ExpireAt: g.Status.ResetTime, + Algorithm: g.Algorithm, + Key: g.Key, + } + err := s.cache.AddCacheItem(ctx, g.Key, item) + if err != nil { + return nil, fmt.Errorf("during CacheManager.AddCacheItem(): %w", err) + } + } + + item.mutex.Lock() switch g.Algorithm { case Algorithm_LEAKY_BUCKET: item.Value = &LeakyBucketItem{ @@ -444,12 +458,8 @@ func (s *V1Instance) UpdatePeerGlobals(ctx context.Context, r *UpdatePeerGlobals CreatedAt: now, } } - err := s.cache.AddCacheItem(ctx, g.Key, item) - if err != nil { - return nil, errors.Wrap(err, "Error in CacheManager.AddCacheItem") - } + item.mutex.Unlock() } - return &UpdatePeerGlobalsResp{}, nil } diff --git a/lrucache.go b/lrucache.go index 8a415c9..5bef041 100644 --- a/lrucache.go +++ b/lrucache.go @@ -119,11 +119,12 @@ func (c *LRUCache) GetItem(key string) (item *CacheItem, ok bool) { if ele, hit := c.cache[key]; hit { entry := ele.Value.(*CacheItem) - if entry.IsExpired() { - c.removeElement(ele) - metricCacheAccess.WithLabelValues("miss").Add(1) - return - } + // TODO(thrawn01): Remove + //if entry.IsExpired() { + // c.removeElement(ele) + // metricCacheAccess.WithLabelValues("miss").Add(1) + // return + //} metricCacheAccess.WithLabelValues("hit").Add(1) c.ll.MoveToFront(ele) diff --git a/otter.go b/otter.go index f04fc8e..9f60c55 100644 --- a/otter.go +++ b/otter.go @@ -37,9 +37,9 @@ func NewOtterCache(size int) (*OtterCache, error) { // Add adds a new CacheItem to the cache. The key must be provided via CacheItem.Key // returns true if the item was added to the cache; false if the item was too large -// for the cache. +// for the cache or already exists in the cache. func (o *OtterCache) Add(item *CacheItem) bool { - return o.cache.Set(item.Key, item) + return o.cache.SetIfAbsent(item.Key, item) } // GetItem returns an item in the cache that corresponds to the provided key @@ -50,16 +50,17 @@ func (o *OtterCache) GetItem(key string) (*CacheItem, bool) { return nil, false } - if item.IsExpired() { - metricCacheAccess.WithLabelValues("miss").Add(1) - // If the item is expired, just return `nil` - // - // We avoid the explicit deletion of the expired item to avoid acquiring a mutex lock in otter. - // Explicit deletions in otter require a mutex, which can cause performance bottlenecks - // under high concurrency scenarios. By allowing the item to be evicted naturally by - // otter's eviction mechanism, we avoid impacting performance under high concurrency. - return nil, false - } + // TODO(thrawn01): Remove + //if item.IsExpired() { + // metricCacheAccess.WithLabelValues("miss").Add(1) + // // If the item is expired, just return `nil` + // // + // // We avoid the explicit deletion of the expired item to avoid acquiring a mutex lock in otter. + // // Explicit deletions in otter require a mutex, which can cause performance bottlenecks + // // under high concurrency scenarios. By allowing the item to be evicted naturally by + // // otter's eviction mechanism, we avoid impacting performance under high concurrency. 
+ // return nil, false + //} metricCacheAccess.WithLabelValues("hit").Add(1) return item, true } diff --git a/otter_test.go b/otter_test.go index 6eb629d..9c84df1 100644 --- a/otter_test.go +++ b/otter_test.go @@ -65,13 +65,18 @@ func TestOtterCache(t *testing.T) { } cache.Add(item1) - // Update same key. + // Update same key is refused item2 := &gubernator.CacheItem{ Key: key, Value: "new value", ExpireAt: expireAt, } - cache.Add(item2) + assert.False(t, cache.Add(item2)) + + // Fetch and update the CacheItem + update, ok := cache.GetItem(key) + assert.True(t, ok) + update.Value = "new value" // Verify. verifyItem, ok := cache.GetItem(key) diff --git a/store.go b/store.go index 1c23461..089ea50 100644 --- a/store.go +++ b/store.go @@ -16,7 +16,10 @@ limitations under the License. package gubernator -import "context" +import ( + "context" + "sync" +) // PERSISTENT STORE DETAILS @@ -27,6 +30,7 @@ import "context" // Both interfaces can be implemented simultaneously to ensure data is always saved to persistent storage. type LeakyBucketItem struct { + mutex sync.Mutex Limit int64 Duration int64 Remaining float64 @@ -47,18 +51,18 @@ type TokenBucketItem struct { // to maximize performance of gubernator. // Implementations MUST be threadsafe. type Store interface { - // Called by gubernator *after* a rate limit item is updated. It's up to the store to + // OnChange is called by gubernator *after* a rate limit item is updated. It's up to the store to // decide if this rate limit item should be persisted in the store. It's up to the // store to expire old rate limit items. The CacheItem represents the current state of // the rate limit item *after* the RateLimitReq has been applied. OnChange(ctx context.Context, r *RateLimitReq, item *CacheItem) - // Called by gubernator when a rate limit is missing from the cache. It's up to the store + // Get is called by gubernator when a rate limit is missing from the cache. It's up to the store // to decide if this request is fulfilled. Should return true if the request is fulfilled // and false if the request is not fulfilled or doesn't exist in the store. Get(ctx context.Context, r *RateLimitReq) (*CacheItem, bool) - // Called by gubernator when an existing rate limit should be removed from the store. + // Remove ic called by gubernator when an existing rate limit should be removed from the store. // NOTE: This is NOT called when an rate limit expires from the cache, store implementors // must expire rate limits in the store. 
Remove(ctx context.Context, key string) @@ -77,39 +81,40 @@ type Loader interface { Save(chan *CacheItem) error } -func NewMockStore() *MockStore { - ml := &MockStore{ - Called: make(map[string]int), - CacheItems: make(map[string]*CacheItem), - } - ml.Called["OnChange()"] = 0 - ml.Called["Remove()"] = 0 - ml.Called["Get()"] = 0 - return ml -} - -type MockStore struct { - Called map[string]int - CacheItems map[string]*CacheItem -} - -var _ Store = &MockStore{} - -func (ms *MockStore) OnChange(ctx context.Context, r *RateLimitReq, item *CacheItem) { - ms.Called["OnChange()"] += 1 - ms.CacheItems[item.Key] = item -} - -func (ms *MockStore) Get(ctx context.Context, r *RateLimitReq) (*CacheItem, bool) { - ms.Called["Get()"] += 1 - item, ok := ms.CacheItems[r.HashKey()] - return item, ok -} - -func (ms *MockStore) Remove(ctx context.Context, key string) { - ms.Called["Remove()"] += 1 - delete(ms.CacheItems, key) -} +// TODO Remove +//func NewMockStore() *MockStore { +// ml := &MockStore{ +// Called: make(map[string]int), +// CacheItems: make(map[string]*CacheItem), +// } +// ml.Called["OnChange()"] = 0 +// ml.Called["Remove()"] = 0 +// ml.Called["Get()"] = 0 +// return ml +//} +// +//type MockStore struct { +// Called map[string]int +// CacheItems map[string]*CacheItem +//} +// +//var _ Store = &MockStore{} +// +//func (ms *MockStore) OnChange(ctx context.Context, r *RateLimitReq, item *CacheItem) { +// ms.Called["OnChange()"] += 1 +// ms.CacheItems[item.Key] = item +//} +// +//func (ms *MockStore) Get(ctx context.Context, r *RateLimitReq) (*CacheItem, bool) { +// ms.Called["Get()"] += 1 +// item, ok := ms.CacheItems[r.HashKey()] +// return item, ok +//} +// +//func (ms *MockStore) Remove(ctx context.Context, key string) { +// ms.Called["Remove()"] += 1 +// delete(ms.CacheItems, key) +//} func NewMockLoader() *MockLoader { ml := &MockLoader{ diff --git a/store_test.go b/store_test.go index e7c58f6..ff29df0 100644 --- a/store_test.go +++ b/store_test.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "net" + "sync" "testing" "github.com/gubernator-io/gubernator/v2" @@ -124,6 +125,96 @@ func TestLoader(t *testing.T) { assert.Equal(t, gubernator.Status_UNDER_LIMIT, item.Status) } +type NoOpStore struct{} + +func (ms *NoOpStore) Remove(ctx context.Context, key string) {} +func (ms *NoOpStore) OnChange(ctx context.Context, r *gubernator.RateLimitReq, item *gubernator.CacheItem) { +} + +func (ms *NoOpStore) Get(ctx context.Context, r *gubernator.RateLimitReq) (*gubernator.CacheItem, bool) { + return &gubernator.CacheItem{ + Algorithm: gubernator.Algorithm_TOKEN_BUCKET, + Key: r.HashKey(), + Value: gubernator.TokenBucketItem{ + CreatedAt: gubernator.MillisecondNow(), + Duration: gubernator.Minute * 60, + Limit: 1_000, + Remaining: 1_000, + Status: 0, + }, + ExpireAt: 0, + }, true +} + +// The goal of this test is to generate some race conditions where multiple routines load from the store and or +// add items to the cache in parallel thus creating a race condition the code must then handle. 
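+// The NoOpStore above always reports a hit, so every cache miss goes through the
+// Store.Get() path and then races the other goroutines on Cache.Add(); the per-item
+// mutexes and the errAlreadyExistsInCache retry are what this is intended to exercise.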
+func TestHighContentionFromStore(t *testing.T) { + const ( + numGoroutines = 1_000 + numKeys = 400 + ) + store := &NoOpStore{} + srv := newV1Server(t, "localhost:0", gubernator.Config{ + Behaviors: gubernator.BehaviorConfig{ + GlobalSyncWait: clock.Millisecond * 50, // Suitable for testing but not production + GlobalTimeout: clock.Second, + }, + Store: store, + }) + client, err := gubernator.DialV1Server(srv.listener.Addr().String(), nil) + require.NoError(t, err) + + keys := GenerateRandomKeys(numKeys) + + var wg sync.WaitGroup + var ready sync.WaitGroup + wg.Add(numGoroutines) + ready.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func() { + ready.Wait() + for idx := 0; idx < numKeys; idx++ { + _, err := client.GetRateLimits(context.Background(), &gubernator.GetRateLimitsReq{ + Requests: []*gubernator.RateLimitReq{ + { + Name: keys[idx], + UniqueKey: "high_contention_", + Algorithm: gubernator.Algorithm_TOKEN_BUCKET, + Duration: gubernator.Minute * 60, + Limit: numKeys, + Hits: 1, + }, + }, + }) + require.NoError(t, err) + } + wg.Done() + }() + ready.Done() + } + wg.Wait() + + for idx := 0; idx < numKeys; idx++ { + resp, err := client.GetRateLimits(context.Background(), &gubernator.GetRateLimitsReq{ + Requests: []*gubernator.RateLimitReq{ + { + Name: keys[idx], + UniqueKey: "high_contention_", + Algorithm: gubernator.Algorithm_TOKEN_BUCKET, + Duration: gubernator.Minute * 60, + Limit: numKeys, + Hits: 0, + }, + }, + }) + require.NoError(t, err) + assert.Equal(t, int64(0), resp.Responses[0].Remaining) + } + + assert.NoError(t, srv.Close()) +} + func TestStore(t *testing.T) { ctx := context.Background() setup := func() (*MockStore2, *v1Server, gubernator.V1Client) {