From 7fd825bc7509e80ab6ea9577f474bb3ce3b8caee Mon Sep 17 00:00:00 2001 From: Samantha Date: Fri, 1 Dec 2023 17:10:34 -0500 Subject: [PATCH] Addressing commment, part 2 --- ratelimits/limiter.go | 80 ++++---------------------------------- ratelimits/source.go | 14 ------- ratelimits/source_redis.go | 15 ------- 3 files changed, 7 insertions(+), 102 deletions(-) diff --git a/ratelimits/limiter.go b/ratelimits/limiter.go index 6807bda722e..6eecd48eebc 100644 --- a/ratelimits/limiter.go +++ b/ratelimits/limiter.go @@ -113,55 +113,7 @@ func (l *Limiter) Check(ctx context.Context, txn Transaction) (*Decision, error) // state is persisted to the underlying datastore, if applicable, before // returning. func (l *Limiter) Spend(ctx context.Context, txn Transaction) (*Decision, error) { - if txn.allowOnly() { - return allowedDecision, nil - } - - start := l.clk.Now() - status := Denied - defer func() { - l.spendLatency.WithLabelValues(txn.limit.name.String(), status).Observe(l.clk.Since(start).Seconds()) - }() - - // Remove cancellation from the request context so that transactions are not - // interrupted by a client disconnect. - ctx = context.WithoutCancel(ctx) - tat, err := l.source.Get(ctx, txn.bucketKey) - if err != nil { - if !errors.Is(err, ErrBucketNotFound) { - return nil, err - - } - // First request from this client. - tat = l.clk.Now() - } - - d := maybeSpend(l.clk, txn.limit, tat, txn.cost) - - if txn.limit.isOverride { - // Calculate the current utilization of the override limit. - utilization := float64(txn.limit.Burst-d.Remaining) / float64(txn.limit.Burst) - l.overrideUsageGauge.WithLabelValues(txn.limit.name.String(), txn.bucketKey).Set(utilization) - } - - if !d.Allowed { - if txn.spendOnly() { - return allowedDecision, nil - } - return d, nil - } - - if tat == d.newTAT || txn.checkOnly() { - // Don't update storage - return d, nil - } - - err = l.source.Set(ctx, txn.bucketKey, d.newTAT) - if err != nil { - return nil, err - } - status = Allowed - return d, nil + return l.BatchSpend(ctx, []Transaction{txn}) } func prepareBatch(txns []Transaction) ([]Transaction, []string, error) { @@ -284,26 +236,7 @@ func (l *Limiter) BatchSpend(ctx context.Context, txns []Transaction) (*Decision // requests remaining, a refund request of 7 will result in the bucket reaching // its maximum capacity of 10, not 12. func (l *Limiter) Refund(ctx context.Context, txn Transaction) (*Decision, error) { - if txn.allowOnly() { - return allowedDecision, nil - } - // Remove cancellation from the request context so that transactions are not - // interrupted by a client disconnect. - ctx = context.WithoutCancel(ctx) - tat, err := l.source.Get(ctx, txn.bucketKey) - if err != nil { - return nil, err - } - d := maybeRefund(l.clk, txn.limit, tat, txn.cost) - if tat == d.newTAT || txn.checkOnly() { - return maybeRefund(l.clk, txn.limit, tat, 0), nil - } - if d.Allowed { - // Persist the new bucket state. - return d, l.source.Set(ctx, txn.bucketKey, d.newTAT) - } - // Bucket is already full. - return d, nil + return l.BatchRefund(ctx, []Transaction{txn}) } // BatchRefund attempts to refund all or some of the costs to the provided @@ -344,12 +277,13 @@ func (l *Limiter) BatchRefund(ctx context.Context, txns []Transaction) (*Decisio continue } - d := maybeRefund(l.clk, txn.limit, tat, txn.cost) - if tat == d.newTAT || txn.checkOnly() { - d = maybeRefund(l.clk, txn.limit, tat, 0) + var cost int64 + if !txn.checkOnly() { + cost = txn.cost } + d := maybeRefund(l.clk, txn.limit, tat, cost) batchDecision.merge(d) - if d.Allowed && (tat != d.newTAT) { + if d.Allowed && tat != d.newTAT { // New bucket state should be persisted. newTATs[txn.bucketKey] = d.newTAT } diff --git a/ratelimits/source.go b/ratelimits/source.go index 0f8151993fa..77f43b73961 100644 --- a/ratelimits/source.go +++ b/ratelimits/source.go @@ -12,13 +12,6 @@ var ErrBucketNotFound = fmt.Errorf("bucket not found") // source is an interface for creating and modifying TATs. type source interface { - // Set stores the TAT at the specified bucketKey (formatted as 'name:id'). - // Implementations MUST ensure non-blocking operations by either: - // a) applying a deadline or timeout to the context WITHIN the method, or - // b) guaranteeing the operation will not block indefinitely (e.g. via - // the underlying storage client implementation). - Set(ctx context.Context, bucketKey string, tat time.Time) error - // BatchSet stores the TATs at the specified bucketKeys (formatted as // 'name:id'). Implementations MUST ensure non-blocking operations by // either: @@ -63,13 +56,6 @@ func newInmem() *inmem { return &inmem{m: make(map[string]time.Time)} } -func (in *inmem) Set(_ context.Context, bucketKey string, tat time.Time) error { - in.Lock() - defer in.Unlock() - in.m[bucketKey] = tat - return nil -} - func (in *inmem) BatchSet(_ context.Context, bucketKeys map[string]time.Time) error { in.Lock() defer in.Unlock() diff --git a/ratelimits/source_redis.go b/ratelimits/source_redis.go index 1251f16c828..2c807c9d4e8 100644 --- a/ratelimits/source_redis.go +++ b/ratelimits/source_redis.go @@ -68,21 +68,6 @@ func resultForError(err error) string { return "failed" } -// Set stores the TAT at the specified bucketKey. It returns an error if the -// operation failed and nil otherwise. -func (r *RedisSource) Set(ctx context.Context, bucketKey string, tat time.Time) error { - start := r.clk.Now() - - err := r.client.Set(ctx, bucketKey, tat.UnixNano(), 0).Err() - if err != nil { - r.latency.With(prometheus.Labels{"call": "set", "result": resultForError(err)}).Observe(time.Since(start).Seconds()) - return err - } - - r.latency.With(prometheus.Labels{"call": "set", "result": "success"}).Observe(time.Since(start).Seconds()) - return nil -} - // BatchSet stores TATs at the specified bucketKeys using a pipelined Redis // Transaction in order to reduce the number of round-trips to each Redis shard. // An error is returned if the operation failed and nil otherwise.