diff --git a/ratelimits/bucket.go b/ratelimits/bucket.go index a8901e10ed3..501d1fd2c44 100644 --- a/ratelimits/bucket.go +++ b/ratelimits/bucket.go @@ -5,52 +5,48 @@ import ( "net" ) -// Bucket identifies a specific subscriber rate limit bucket to the Limiter. -type Bucket struct { - name Name - key string -} - -// BucketWithCost is a bucket with an associated cost. -type BucketWithCost struct { - Bucket - cost int64 -} +// BucketId should only be created using the New*BucketId functions. It is used +// by the Limiter to look up the bucket and limit overrides for a specific +// subscriber and limit. +type BucketId struct { + // limit is the name of the associated rate limit. It is used for looking up + // default limits. + limit Name -// WithCost returns a BucketWithCost for the provided cost. -func (b Bucket) WithCost(cost int64) BucketWithCost { - return BucketWithCost{b, cost} + // bucketKey is the limit Name enum (e.g. "1") concatenated with the + // subscriber identifier specific to the associate limit Name type. + bucketKey string } -// NewRegistrationsPerIPAddressBucket returns a Bucket for the provided IP +// NewRegistrationsPerIPAddressBucketId returns a BucketId for the provided IP // address. -func NewRegistrationsPerIPAddressBucket(ip net.IP) (Bucket, error) { +func NewRegistrationsPerIPAddressBucketId(ip net.IP) (BucketId, error) { id := ip.String() err := validateIdForName(NewRegistrationsPerIPAddress, id) if err != nil { - return Bucket{}, err + return BucketId{}, err } - return Bucket{ - name: NewRegistrationsPerIPAddress, - key: joinWithColon(NewRegistrationsPerIPAddress.EnumString(), id), + return BucketId{ + limit: NewRegistrationsPerIPAddress, + bucketKey: joinWithColon(NewRegistrationsPerIPAddress.EnumString(), id), }, nil } -// NewRegistrationsPerIPv6RangeBucket returns a Bucket for the /48 IPv6 range -// containing the provided IPv6 address. -func NewRegistrationsPerIPv6RangeBucket(ip net.IP) (Bucket, error) { +// NewRegistrationsPerIPv6RangeBucketId returns a BucketId for the /48 IPv6 +// range containing the provided IPv6 address. +func NewRegistrationsPerIPv6RangeBucketId(ip net.IP) (BucketId, error) { if ip.To4() != nil { - return Bucket{}, fmt.Errorf("invalid IPv6 address, %q must be an IPv6 address", ip.String()) + return BucketId{}, fmt.Errorf("invalid IPv6 address, %q must be an IPv6 address", ip.String()) } ipMask := net.CIDRMask(48, 128) ipNet := &net.IPNet{IP: ip.Mask(ipMask), Mask: ipMask} id := ipNet.String() err := validateIdForName(NewRegistrationsPerIPv6Range, id) if err != nil { - return Bucket{}, err + return BucketId{}, err } - return Bucket{ - name: NewRegistrationsPerIPv6Range, - key: joinWithColon(NewRegistrationsPerIPv6Range.EnumString(), id), + return BucketId{ + limit: NewRegistrationsPerIPv6Range, + bucketKey: joinWithColon(NewRegistrationsPerIPv6Range.EnumString(), id), }, nil } diff --git a/ratelimits/limiter.go b/ratelimits/limiter.go index 4471a9ee5c7..eab087bf5df 100644 --- a/ratelimits/limiter.go +++ b/ratelimits/limiter.go @@ -95,6 +95,20 @@ func NewLimiter(clk clock.Clock, source source, defaults, overrides string, stat return limiter, nil } +// Transaction is a cost to be spent or refunded from a specific BucketId. +type Transaction struct { + BucketId + cost int64 +} + +// NewTransaction creates a new Transaction for the provided BucketId and cost. +func NewTransaction(b BucketId, cost int64) Transaction { + return Transaction{ + BucketId: b, + cost: cost, + } +} + type Decision struct { // Allowed is true if the bucket possessed enough capacity to allow the // request given the cost. @@ -123,12 +137,12 @@ type Decision struct { // satisfy the cost and represents the hypothetical state of the bucket IF the // cost WERE to be deducted. If no bucket exists it will NOT be created. No // state is persisted to the underlying datastore. -func (l *Limiter) Check(ctx context.Context, bucket BucketWithCost) (*Decision, error) { - if bucket.cost < 0 { +func (l *Limiter) Check(ctx context.Context, txn Transaction) (*Decision, error) { + if txn.cost < 0 { return nil, ErrInvalidCostForCheck } - limit, err := l.getLimit(bucket.name, bucket.key) + limit, err := l.getLimit(txn.limit, txn.bucketKey) if err != nil { if errors.Is(err, errLimitDisabled) { return disabledLimitDecision, nil @@ -136,14 +150,14 @@ func (l *Limiter) Check(ctx context.Context, bucket BucketWithCost) (*Decision, return nil, err } - if bucket.cost > limit.Burst { + if txn.cost > limit.Burst { return nil, ErrInvalidCostOverLimit } // Remove cancellation from the request context so that transactions are not // interrupted by a client disconnect. ctx = context.WithoutCancel(ctx) - tat, err := l.source.Get(ctx, bucket.key) + tat, err := l.source.Get(ctx, txn.bucketKey) if err != nil { if !errors.Is(err, ErrBucketNotFound) { return nil, err @@ -151,9 +165,9 @@ func (l *Limiter) Check(ctx context.Context, bucket BucketWithCost) (*Decision, // First request from this client. No need to initialize the bucket // because this is a check, not a spend. A TAT of "now" is equivalent to // a full bucket. - return maybeSpend(l.clk, limit, l.clk.Now(), bucket.cost), nil + return maybeSpend(l.clk, limit, l.clk.Now(), txn.cost), nil } - return maybeSpend(l.clk, limit, tat, bucket.cost), nil + return maybeSpend(l.clk, limit, tat, txn.cost), nil } // Spend attempts to deduct the cost from the provided bucket's capacity. The @@ -162,12 +176,12 @@ func (l *Limiter) Check(ctx context.Context, bucket BucketWithCost) (*Decision, // be created WITH the cost factored into its initial state. The new bucket // state is persisted to the underlying datastore, if applicable, before // returning. -func (l *Limiter) Spend(ctx context.Context, bucket BucketWithCost) (*Decision, error) { - if bucket.cost <= 0 { +func (l *Limiter) Spend(ctx context.Context, txn Transaction) (*Decision, error) { + if txn.cost <= 0 { return nil, ErrInvalidCost } - limit, err := l.getLimit(bucket.name, bucket.key) + limit, err := l.getLimit(txn.limit, txn.bucketKey) if err != nil { if errors.Is(err, errLimitDisabled) { return disabledLimitDecision, nil @@ -175,24 +189,24 @@ func (l *Limiter) Spend(ctx context.Context, bucket BucketWithCost) (*Decision, return nil, err } - if bucket.cost > limit.Burst { + if txn.cost > limit.Burst { return nil, ErrInvalidCostOverLimit } start := l.clk.Now() status := Denied defer func() { - l.spendLatency.WithLabelValues(bucket.name.String(), status).Observe(l.clk.Since(start).Seconds()) + l.spendLatency.WithLabelValues(txn.limit.String(), status).Observe(l.clk.Since(start).Seconds()) }() // Remove cancellation from the request context so that transactions are not // interrupted by a client disconnect. ctx = context.WithoutCancel(ctx) - tat, err := l.source.Get(ctx, bucket.key) + tat, err := l.source.Get(ctx, txn.bucketKey) if err != nil { if errors.Is(err, ErrBucketNotFound) { // First request from this client. - d, err := l.initialize(ctx, limit, bucket) + d, err := l.initialize(ctx, limit, txn) if err != nil { return nil, err } @@ -204,19 +218,19 @@ func (l *Limiter) Spend(ctx context.Context, bucket BucketWithCost) (*Decision, return nil, err } - d := maybeSpend(l.clk, limit, tat, bucket.cost) + d := maybeSpend(l.clk, limit, tat, txn.cost) if limit.isOverride { // Calculate the current utilization of the override limit. utilization := float64(limit.Burst-d.Remaining) / float64(limit.Burst) - l.overrideUsageGauge.WithLabelValues(bucket.name.String(), bucket.key).Set(utilization) + l.overrideUsageGauge.WithLabelValues(txn.limit.String(), txn.bucketKey).Set(utilization) } if !d.Allowed { return d, nil } - err = l.source.Set(ctx, bucket.key, d.newTAT) + err = l.source.Set(ctx, txn.bucketKey, d.newTAT) if err != nil { return nil, err } @@ -235,12 +249,12 @@ func (l *Limiter) Spend(ctx context.Context, bucket BucketWithCost) (*Decision, // instance, if a bucket has a maximum capacity of 10 and currently has 5 // requests remaining, a refund request of 7 will result in the bucket reaching // its maximum capacity of 10, not 12. -func (l *Limiter) Refund(ctx context.Context, bucket BucketWithCost) (*Decision, error) { - if bucket.cost <= 0 { +func (l *Limiter) Refund(ctx context.Context, txn Transaction) (*Decision, error) { + if txn.cost <= 0 { return nil, ErrInvalidCost } - limit, err := l.getLimit(bucket.name, bucket.key) + limit, err := l.getLimit(txn.limit, txn.bucketKey) if err != nil { if errors.Is(err, errLimitDisabled) { return disabledLimitDecision, nil @@ -251,37 +265,37 @@ func (l *Limiter) Refund(ctx context.Context, bucket BucketWithCost) (*Decision, // Remove cancellation from the request context so that transactions are not // interrupted by a client disconnect. ctx = context.WithoutCancel(ctx) - tat, err := l.source.Get(ctx, bucket.key) + tat, err := l.source.Get(ctx, txn.bucketKey) if err != nil { return nil, err } - d := maybeRefund(l.clk, limit, tat, bucket.cost) + d := maybeRefund(l.clk, limit, tat, txn.cost) if !d.Allowed { // The bucket is already at maximum capacity. return d, nil } - return d, l.source.Set(ctx, bucket.key, d.newTAT) + return d, l.source.Set(ctx, txn.bucketKey, d.newTAT) } // Reset resets the specified bucket to its maximum capacity. The new bucket // state is persisted to the underlying datastore before returning. -func (l *Limiter) Reset(ctx context.Context, bucket Bucket) error { +func (l *Limiter) Reset(ctx context.Context, bucketId BucketId) error { // Remove cancellation from the request context so that transactions are not // interrupted by a client disconnect. ctx = context.WithoutCancel(ctx) - return l.source.Delete(ctx, bucket.key) + return l.source.Delete(ctx, bucketId.bucketKey) } // initialize creates a new bucket and sets its TAT to now, which is equivalent // to a full bucket. The new bucket state is persisted to the underlying // datastore before returning. -func (l *Limiter) initialize(ctx context.Context, rl limit, bucket BucketWithCost) (*Decision, error) { - d := maybeSpend(l.clk, rl, l.clk.Now(), bucket.cost) +func (l *Limiter) initialize(ctx context.Context, rl limit, txn Transaction) (*Decision, error) { + d := maybeSpend(l.clk, rl, l.clk.Now(), txn.cost) // Remove cancellation from the request context so that transactions are not // interrupted by a client disconnect. ctx = context.WithoutCancel(ctx) - err := l.source.Set(ctx, bucket.key, d.newTAT) + err := l.source.Set(ctx, txn.bucketKey, d.newTAT) if err != nil { return nil, err } @@ -289,19 +303,19 @@ func (l *Limiter) initialize(ctx context.Context, rl limit, bucket BucketWithCos } -// getLimit returns the limit for the specified by name and bucketKey, name is -// required, bucketKey is optional. If bucketKey is left unspecified, the -// default limit for the limit specified by name is returned. If no default -// limit exists for the specified name, errLimitDisabled is returned. -func (l *Limiter) getLimit(name Name, bucketKey string) (limit, error) { +// getLimit returns the limit for the specified by name and id, name is +// required, id is optional. If id is left unspecified, the default limit for +// the limit specified by name is returned. If no default limit exists for the +// specified name, errLimitDisabled is returned. +func (l *Limiter) getLimit(name Name, id string) (limit, error) { if !name.isValid() { // This should never happen. Callers should only be specifying the limit // Name enums defined in this package. return limit{}, fmt.Errorf("specified name enum %q, is invalid", name) } - if bucketKey != "" { + if id != "" { // Check for override. - ol, ok := l.overrides[bucketKey] + ol, ok := l.overrides[id] if ok { return ol, nil } diff --git a/ratelimits/limiter_test.go b/ratelimits/limiter_test.go index 3b33215db0d..7f108810279 100644 --- a/ratelimits/limiter_test.go +++ b/ratelimits/limiter_test.go @@ -58,8 +58,8 @@ func Test_Limiter_CheckWithLimitNoExist(t *testing.T) { testCtx, limiters, _, testIP := setup(t) for name, l := range limiters { t.Run(name, func(t *testing.T) { - bucket := Bucket{name: Name(9999), key: testIP} - _, err := l.Check(testCtx, bucket.WithCost(1)) + bucketId := BucketId{limit: Name(9999), bucketKey: testIP} + _, err := l.Check(testCtx, NewTransaction(bucketId, 1)) test.AssertError(t, err, "should error") }) } @@ -76,26 +76,26 @@ func Test_Limiter_CheckWithLimitOverrides(t *testing.T) { "limit": NewRegistrationsPerIPAddress.String(), "bucket_key": joinWithColon(NewRegistrationsPerIPAddress.EnumString(), tenZeroZeroTwo)}, 0) - overriddenBucket, err := NewRegistrationsPerIPAddressBucket(net.ParseIP(tenZeroZeroTwo)) + overriddenBucketId, err := NewRegistrationsPerIPAddressBucketId(net.ParseIP(tenZeroZeroTwo)) test.AssertNotError(t, err, "should not error") // Attempt to check a spend of 41 requests (a cost > the limit burst // capacity), this should fail with a specific error. - _, err = l.Check(testCtx, overriddenBucket.WithCost(41)) + _, err = l.Check(testCtx, NewTransaction(overriddenBucketId, 41)) test.AssertErrorIs(t, err, ErrInvalidCostOverLimit) // Attempt to spend 41 requests (a cost > the limit burst capacity), // this should fail with a specific error. - _, err = l.Spend(testCtx, overriddenBucket.WithCost(41)) + _, err = l.Spend(testCtx, NewTransaction(overriddenBucketId, 41)) test.AssertErrorIs(t, err, ErrInvalidCostOverLimit) // Attempt to spend all 40 requests, this should succeed. - d, err := l.Spend(testCtx, overriddenBucket.WithCost(40)) + d, err := l.Spend(testCtx, NewTransaction(overriddenBucketId, 40)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") // Attempting to spend 1 more, this should fail. - d, err = l.Spend(testCtx, overriddenBucket.WithCost(1)) + d, err = l.Spend(testCtx, NewTransaction(overriddenBucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, !d.Allowed, "should not be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -115,7 +115,7 @@ func Test_Limiter_CheckWithLimitOverrides(t *testing.T) { clk.Add(d.RetryIn) // We should be allowed to spend 1 more request. - d, err = l.Spend(testCtx, overriddenBucket.WithCost(1)) + d, err = l.Spend(testCtx, NewTransaction(overriddenBucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -126,21 +126,21 @@ func Test_Limiter_CheckWithLimitOverrides(t *testing.T) { // Quickly spend 40 requests in a row. for i := 0; i < 40; i++ { - d, err = l.Spend(testCtx, overriddenBucket.WithCost(1)) + d, err = l.Spend(testCtx, NewTransaction(overriddenBucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(39-i)) } // Attempting to spend 1 more, this should fail. - d, err = l.Spend(testCtx, overriddenBucket.WithCost(1)) + d, err = l.Spend(testCtx, NewTransaction(overriddenBucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, !d.Allowed, "should not be allowed") test.AssertEquals(t, d.Remaining, int64(0)) test.AssertEquals(t, d.ResetIn, time.Second) // Reset between tests. - err = l.Reset(testCtx, overriddenBucket) + err = l.Reset(testCtx, overriddenBucketId) test.AssertNotError(t, err, "should not error") }) } @@ -151,12 +151,12 @@ func Test_Limiter_InitializationViaCheckAndSpend(t *testing.T) { testCtx, limiters, _, testIP := setup(t) for name, l := range limiters { t.Run(name, func(t *testing.T) { - bucket, err := NewRegistrationsPerIPAddressBucket(net.ParseIP(testIP)) + bucketId, err := NewRegistrationsPerIPAddressBucketId(net.ParseIP(testIP)) test.AssertNotError(t, err, "should not error") // Check on an empty bucket should return the theoretical next state // of that bucket if the cost were spent. - d, err := l.Check(testCtx, bucket.WithCost(1)) + d, err := l.Check(testCtx, NewTransaction(bucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(19)) @@ -167,7 +167,7 @@ func Test_Limiter_InitializationViaCheckAndSpend(t *testing.T) { // However, that cost should not be spent yet, a 0 cost check should // tell us that we actually have 20 remaining. - d, err = l.Check(testCtx, bucket.WithCost(0)) + d, err = l.Check(testCtx, NewTransaction(bucketId, 0)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(20)) @@ -175,12 +175,12 @@ func Test_Limiter_InitializationViaCheckAndSpend(t *testing.T) { test.AssertEquals(t, d.RetryIn, time.Duration(0)) // Reset our bucket. - err = l.Reset(testCtx, bucket) + err = l.Reset(testCtx, bucketId) test.AssertNotError(t, err, "should not error") // Similar to above, but we'll use Spend() to actually initialize // the bucket. Spend should return the same result as Check. - d, err = l.Spend(testCtx, bucket.WithCost(1)) + d, err = l.Spend(testCtx, NewTransaction(bucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(19)) @@ -191,7 +191,7 @@ func Test_Limiter_InitializationViaCheckAndSpend(t *testing.T) { // However, that cost should not be spent yet, a 0 cost check should // tell us that we actually have 19 remaining. - d, err = l.Check(testCtx, bucket.WithCost(0)) + d, err = l.Check(testCtx, NewTransaction(bucketId, 0)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(19)) @@ -208,23 +208,23 @@ func Test_Limiter_RefundAndSpendCostErr(t *testing.T) { testCtx, limiters, _, testIP := setup(t) for name, l := range limiters { t.Run(name, func(t *testing.T) { - bucket, err := NewRegistrationsPerIPAddressBucket(net.ParseIP(testIP)) + bucketId, err := NewRegistrationsPerIPAddressBucketId(net.ParseIP(testIP)) test.AssertNotError(t, err, "should not error") // Spend a cost of 0, which should fail. - _, err = l.Spend(testCtx, bucket.WithCost(0)) + _, err = l.Spend(testCtx, NewTransaction(bucketId, 0)) test.AssertErrorIs(t, err, ErrInvalidCost) // Spend a negative cost, which should fail. - _, err = l.Spend(testCtx, bucket.WithCost(-1)) + _, err = l.Spend(testCtx, NewTransaction(bucketId, -1)) test.AssertErrorIs(t, err, ErrInvalidCost) // Refund a cost of 0, which should fail. - _, err = l.Refund(testCtx, bucket.WithCost(0)) + _, err = l.Refund(testCtx, NewTransaction(bucketId, 0)) test.AssertErrorIs(t, err, ErrInvalidCost) // Refund a negative cost, which should fail. - _, err = l.Refund(testCtx, bucket.WithCost(-1)) + _, err = l.Refund(testCtx, NewTransaction(bucketId, -1)) test.AssertErrorIs(t, err, ErrInvalidCost) }) } @@ -235,10 +235,10 @@ func Test_Limiter_CheckWithBadCost(t *testing.T) { testCtx, limiters, _, testIP := setup(t) for name, l := range limiters { t.Run(name, func(t *testing.T) { - bucket, err := NewRegistrationsPerIPAddressBucket(net.ParseIP(testIP)) + bucketId, err := NewRegistrationsPerIPAddressBucketId(net.ParseIP(testIP)) test.AssertNotError(t, err, "should not error") - _, err = l.Check(testCtx, bucket.WithCost(-1)) + _, err = l.Check(testCtx, NewTransaction(bucketId, -1)) test.AssertErrorIs(t, err, ErrInvalidCostForCheck) }) } @@ -249,23 +249,23 @@ func Test_Limiter_DefaultLimits(t *testing.T) { testCtx, limiters, clk, testIP := setup(t) for name, l := range limiters { t.Run(name, func(t *testing.T) { - bucket, err := NewRegistrationsPerIPAddressBucket(net.ParseIP(testIP)) + bucketId, err := NewRegistrationsPerIPAddressBucketId(net.ParseIP(testIP)) test.AssertNotError(t, err, "should not error") // Attempt to spend 21 requests (a cost > the limit burst capacity), // this should fail with a specific error. - _, err = l.Spend(testCtx, bucket.WithCost(21)) + _, err = l.Spend(testCtx, NewTransaction(bucketId, 21)) test.AssertErrorIs(t, err, ErrInvalidCostOverLimit) // Attempt to spend all 20 requests, this should succeed. - d, err := l.Spend(testCtx, bucket.WithCost(20)) + d, err := l.Spend(testCtx, NewTransaction(bucketId, 20)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) test.AssertEquals(t, d.ResetIn, time.Second) // Attempting to spend 1 more, this should fail. - d, err = l.Spend(testCtx, bucket.WithCost(1)) + d, err = l.Spend(testCtx, NewTransaction(bucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, !d.Allowed, "should not be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -279,7 +279,7 @@ func Test_Limiter_DefaultLimits(t *testing.T) { clk.Add(d.RetryIn) // We should be allowed to spend 1 more request. - d, err = l.Spend(testCtx, bucket.WithCost(1)) + d, err = l.Spend(testCtx, NewTransaction(bucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -290,14 +290,14 @@ func Test_Limiter_DefaultLimits(t *testing.T) { // Quickly spend 20 requests in a row. for i := 0; i < 20; i++ { - d, err = l.Spend(testCtx, bucket.WithCost(1)) + d, err = l.Spend(testCtx, NewTransaction(bucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(19-i)) } // Attempting to spend 1 more, this should fail. - d, err = l.Spend(testCtx, bucket.WithCost(1)) + d, err = l.Spend(testCtx, NewTransaction(bucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, !d.Allowed, "should not be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -311,33 +311,33 @@ func Test_Limiter_RefundAndReset(t *testing.T) { testCtx, limiters, clk, testIP := setup(t) for name, l := range limiters { t.Run(name, func(t *testing.T) { - bucket, err := NewRegistrationsPerIPAddressBucket(net.ParseIP(testIP)) + bucketId, err := NewRegistrationsPerIPAddressBucketId(net.ParseIP(testIP)) test.AssertNotError(t, err, "should not error") // Attempt to spend all 20 requests, this should succeed. - d, err := l.Spend(testCtx, bucket.WithCost(20)) + d, err := l.Spend(testCtx, NewTransaction(bucketId, 20)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) test.AssertEquals(t, d.ResetIn, time.Second) // Refund 10 requests. - d, err = l.Refund(testCtx, bucket.WithCost(10)) + d, err = l.Refund(testCtx, NewTransaction(bucketId, 10)) test.AssertNotError(t, err, "should not error") test.AssertEquals(t, d.Remaining, int64(10)) // Spend 10 requests, this should succeed. - d, err = l.Spend(testCtx, bucket.WithCost(10)) + d, err = l.Spend(testCtx, NewTransaction(bucketId, 10)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) test.AssertEquals(t, d.ResetIn, time.Second) - err = l.Reset(testCtx, bucket) + err = l.Reset(testCtx, bucketId) test.AssertNotError(t, err, "should not error") // Attempt to spend 20 more requests, this should succeed. - d, err = l.Spend(testCtx, bucket.WithCost(20)) + d, err = l.Spend(testCtx, NewTransaction(bucketId, 20)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -347,7 +347,7 @@ func Test_Limiter_RefundAndReset(t *testing.T) { clk.Add(d.ResetIn) // Refund 1 requests above our limit, this should fail. - d, err = l.Refund(testCtx, bucket.WithCost(1)) + d, err = l.Refund(testCtx, NewTransaction(bucketId, 1)) test.AssertNotError(t, err, "should not error") test.Assert(t, !d.Allowed, "should not be allowed") test.AssertEquals(t, d.Remaining, int64(20)) diff --git a/ratelimits/utilities.go b/ratelimits/utilities.go index dd5a1167eca..8a7cbca7087 100644 --- a/ratelimits/utilities.go +++ b/ratelimits/utilities.go @@ -2,32 +2,9 @@ package ratelimits import ( "strings" - - "github.com/letsencrypt/boulder/core" - "github.com/weppos/publicsuffix-go/publicsuffix" ) // joinWithColon joins the provided args with a colon. func joinWithColon(args ...string) string { return strings.Join(args, ":") } - -// DomainsForRateLimiting transforms a list of FQDNs into a list of eTLD+1's -// for the purpose of rate limiting. It also de-duplicates the output -// domains. Exact public suffix matches are included. -func DomainsForRateLimiting(names []string) []string { - var domains []string - for _, name := range names { - domain, err := publicsuffix.Domain(name) - if err != nil { - // The only possible errors are: - // (1) publicsuffix.Domain is giving garbage values - // (2) the public suffix is the domain itself - // We assume 2 and include the original name in the result. - domains = append(domains, name) - } else { - domains = append(domains, domain) - } - } - return core.UniqueLowerNames(domains) -} diff --git a/ratelimits/utilities_test.go b/ratelimits/utilities_test.go deleted file mode 100644 index 9c68d3a6e89..00000000000 --- a/ratelimits/utilities_test.go +++ /dev/null @@ -1,27 +0,0 @@ -package ratelimits - -import ( - "testing" - - "github.com/letsencrypt/boulder/test" -) - -func TestDomainsForRateLimiting(t *testing.T) { - domains := DomainsForRateLimiting([]string{}) - test.AssertEquals(t, len(domains), 0) - - domains = DomainsForRateLimiting([]string{"www.example.com", "example.com"}) - test.AssertDeepEquals(t, domains, []string{"example.com"}) - - domains = DomainsForRateLimiting([]string{"www.example.com", "example.com", "www.example.co.uk"}) - test.AssertDeepEquals(t, domains, []string{"example.co.uk", "example.com"}) - - domains = DomainsForRateLimiting([]string{"www.example.com", "example.com", "www.example.co.uk", "co.uk"}) - test.AssertDeepEquals(t, domains, []string{"co.uk", "example.co.uk", "example.com"}) - - domains = DomainsForRateLimiting([]string{"foo.bar.baz.www.example.com", "baz.example.com"}) - test.AssertDeepEquals(t, domains, []string{"example.com"}) - - domains = DomainsForRateLimiting([]string{"github.io", "foo.github.io", "bar.github.io"}) - test.AssertDeepEquals(t, domains, []string{"bar.github.io", "foo.github.io", "github.io"}) -} diff --git a/wfe2/wfe.go b/wfe2/wfe.go index d7ae9a29306..4ff3e68cedf 100644 --- a/wfe2/wfe.go +++ b/wfe2/wfe.go @@ -638,13 +638,13 @@ func (wfe *WebFrontEndImpl) checkNewAccountLimits(ctx context.Context, ip net.IP wfe.log.Warningf("checking %s rate limit: %s", limit, err) } - bucket, err := ratelimits.NewRegistrationsPerIPAddressBucket(ip) + bucketId, err := ratelimits.NewRegistrationsPerIPAddressBucketId(ip) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPAddress) return } - decision, err := wfe.limiter.Spend(ctx, bucket.WithCost(1)) + decision, err := wfe.limiter.Spend(ctx, ratelimits.NewTransaction(bucketId, 1)) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPAddress) return @@ -655,13 +655,13 @@ func (wfe *WebFrontEndImpl) checkNewAccountLimits(ctx context.Context, ip net.IP return } - bucket, err = ratelimits.NewRegistrationsPerIPv6RangeBucket(ip) + bucketId, err = ratelimits.NewRegistrationsPerIPv6RangeBucketId(ip) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPv6Range) return } - _, err = wfe.limiter.Spend(ctx, bucket.WithCost(1)) + _, err = wfe.limiter.Spend(ctx, ratelimits.NewTransaction(bucketId, 1)) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPv6Range) } @@ -686,13 +686,13 @@ func (wfe *WebFrontEndImpl) refundNewAccountLimits(ctx context.Context, ip net.I wfe.log.Warningf("refunding %s rate limit: %s", limit, err) } - bucket, err := ratelimits.NewRegistrationsPerIPAddressBucket(ip) + bucketId, err := ratelimits.NewRegistrationsPerIPAddressBucketId(ip) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPAddress) return } - _, err = wfe.limiter.Refund(ctx, bucket.WithCost(1)) + _, err = wfe.limiter.Refund(ctx, ratelimits.NewTransaction(bucketId, 1)) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPAddress) return @@ -702,13 +702,13 @@ func (wfe *WebFrontEndImpl) refundNewAccountLimits(ctx context.Context, ip net.I return } - bucket, err = ratelimits.NewRegistrationsPerIPv6RangeBucket(ip) + bucketId, err = ratelimits.NewRegistrationsPerIPv6RangeBucketId(ip) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPv6Range) return } - _, err = wfe.limiter.Refund(ctx, bucket.WithCost(1)) + _, err = wfe.limiter.Refund(ctx, ratelimits.NewTransaction(bucketId, 1)) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPv6Range) }