diff --git a/ra/ra.go b/ra/ra.go index ab19f842de55..6dd293b61035 100644 --- a/ra/ra.go +++ b/ra/ra.go @@ -20,7 +20,6 @@ import ( "github.com/jmhodges/clock" "github.com/prometheus/client_golang/prometheus" - "github.com/weppos/publicsuffix-go/publicsuffix" "golang.org/x/crypto/ocsp" "google.golang.org/grpc" "google.golang.org/protobuf/proto" @@ -1377,26 +1376,6 @@ func (ra *RegistrationAuthorityImpl) getSCTs(ctx context.Context, cert []byte, e return scts, nil } -// domainsForRateLimiting transforms a list of FQDNs into a list of eTLD+1's -// for the purpose of rate limiting. It also de-duplicates the output -// domains. Exact public suffix matches are included. -func domainsForRateLimiting(names []string) []string { - var domains []string - for _, name := range names { - domain, err := publicsuffix.Domain(name) - if err != nil { - // The only possible errors are: - // (1) publicsuffix.Domain is giving garbage values - // (2) the public suffix is the domain itself - // We assume 2 and include the original name in the result. - domains = append(domains, name) - } else { - domains = append(domains, domain) - } - } - return core.UniqueLowerNames(domains) -} - // enforceNameCounts uses the provided count RPC to find a count of certificates // for each of the names. 
If the count for any of the names exceeds the limit // for the given registration then the names out of policy are returned to be @@ -1451,7 +1430,7 @@ func (ra *RegistrationAuthorityImpl) checkCertificatesPerNameLimit(ctx context.C return nil } - tldNames := domainsForRateLimiting(names) + tldNames := ratelimits.DomainsForRateLimiting(names) namesOutOfLimit, earliest, err := ra.enforceNameCounts(ctx, tldNames, limit, regID) if err != nil { return fmt.Errorf("checking certificates per name limit for %q: %s", diff --git a/ra/ra_test.go b/ra/ra_test.go index d95dd3322507..02e17d6232e5 100644 --- a/ra/ra_test.go +++ b/ra/ra_test.go @@ -1113,26 +1113,6 @@ func TestAuthzFailedRateLimitingNewOrder(t *testing.T) { testcase() } -func TestDomainsForRateLimiting(t *testing.T) { - domains := domainsForRateLimiting([]string{}) - test.AssertEquals(t, len(domains), 0) - - domains = domainsForRateLimiting([]string{"www.example.com", "example.com"}) - test.AssertDeepEquals(t, domains, []string{"example.com"}) - - domains = domainsForRateLimiting([]string{"www.example.com", "example.com", "www.example.co.uk"}) - test.AssertDeepEquals(t, domains, []string{"example.co.uk", "example.com"}) - - domains = domainsForRateLimiting([]string{"www.example.com", "example.com", "www.example.co.uk", "co.uk"}) - test.AssertDeepEquals(t, domains, []string{"co.uk", "example.co.uk", "example.com"}) - - domains = domainsForRateLimiting([]string{"foo.bar.baz.www.example.com", "baz.example.com"}) - test.AssertDeepEquals(t, domains, []string{"example.com"}) - - domains = domainsForRateLimiting([]string{"github.io", "foo.github.io", "bar.github.io"}) - test.AssertDeepEquals(t, domains, []string{"bar.github.io", "foo.github.io", "github.io"}) -} - type mockSAWithNameCounts struct { mocks.StorageAuthority nameCounts *sapb.CountByNames diff --git a/ratelimits/bucket.go b/ratelimits/bucket.go index 501d1fd2c447..a60965c019f1 100644 --- a/ratelimits/bucket.go +++ b/ratelimits/bucket.go @@ -1,17 +1,21 @@ 
package ratelimits import ( + "errors" "fmt" "net" + "strconv" + + "github.com/letsencrypt/boulder/core" ) // BucketId should only be created using the New*BucketId functions. It is used // by the Limiter to look up the bucket and limit overrides for a specific // subscriber and limit. type BucketId struct { - // limit is the name of the associated rate limit. It is used for looking up - // default limits. - limit Name + // limitName is the name of the associated rate limit. It is used for + // looking up default limits. + limitName Name // bucketKey is the limit Name enum (e.g. "1") concatenated with the // subscriber identifier specific to the associate limit Name type. @@ -27,7 +31,7 @@ func NewRegistrationsPerIPAddressBucketId(ip net.IP) (BucketId, error) { return BucketId{}, err } return BucketId{ - limit: NewRegistrationsPerIPAddress, + limitName: NewRegistrationsPerIPAddress, bucketKey: joinWithColon(NewRegistrationsPerIPAddress.EnumString(), id), }, nil } @@ -46,7 +50,113 @@ func NewRegistrationsPerIPv6RangeBucketId(ip net.IP) (BucketId, error) { return BucketId{}, err } return BucketId{ - limit: NewRegistrationsPerIPv6Range, + limitName: NewRegistrationsPerIPv6Range, bucketKey: joinWithColon(NewRegistrationsPerIPv6Range.EnumString(), id), }, nil } + +// NewOrdersPerAccountBucketId returns a BucketId for the provided ACME +// registration Id. +func NewOrdersPerAccountBucketId(regId int64) (BucketId, error) { + id := strconv.FormatInt(regId, 10) + err := validateIdForName(NewOrdersPerAccount, id) + if err != nil { + return BucketId{}, err + } + return BucketId{ + limitName: NewOrdersPerAccount, + bucketKey: joinWithColon(NewOrdersPerAccount.EnumString(), id), + }, nil +} + +// NewFailedAuthorizationsPerAccountBucketId returns a BucketId for the provided +// ACME registration Id. 
+func NewFailedAuthorizationsPerAccountBucketId(regId int64) (BucketId, error) { + id := strconv.FormatInt(regId, 10) + err := validateIdForName(FailedAuthorizationsPerAccount, id) + if err != nil { + return BucketId{}, err + } + return BucketId{ + limitName: FailedAuthorizationsPerAccount, + bucketKey: joinWithColon(FailedAuthorizationsPerAccount.EnumString(), id), + }, nil +} + +// NewCertificatesPerDomainBucketId returns a BucketId for the provided order +// domain name. +func NewCertificatesPerDomainBucketId(orderName string) (BucketId, error) { + err := validateIdForName(CertificatesPerDomain, orderName) + if err != nil { + return BucketId{}, err + } + return BucketId{ + limitName: CertificatesPerDomain, + bucketKey: joinWithColon(CertificatesPerDomain.EnumString(), orderName), + }, nil +} + +// newCertificatesPerDomainPerAccountBucketId is only referenced internally. +// Buckets for CertificatesPerDomainPerAccount are created by calling +// NewCertificatesPerDomainTransactions(). +func newCertificatesPerDomainPerAccountBucketId(regId int64) (BucketId, error) { + id := strconv.FormatInt(regId, 10) + err := validateIdForName(CertificatesPerDomainPerAccount, id) + if err != nil { + return BucketId{}, err + } + return BucketId{ + limitName: CertificatesPerDomainPerAccount, + bucketKey: joinWithColon(CertificatesPerDomainPerAccount.EnumString(), id), + }, nil +} + +// NewCertificatesPerDomainTransactions returns a slice of Transactions for the +// provided order domain names. The cost specified will be applied per eTLD+1 +// name present in the orderDomains. +// +// Note: the CertificatesPerDomain limit is special in that it can be overridden +// by overrides to the CertificatesPerDomainPerAccount limit. This occurs when +// the CertificatesPerDomainPerAccount override allows for higher throughput +// than the CertificatesPerDomain limit.
In these cases, the cost will be +// consumed from the CertificatesPerDomainPerAccount bucket AND from the +// CertificatesPerDomain bucket, if capacity exists. +func NewCertificatesPerDomainTransactions(limiter *Limiter, regId int64, orderDomains []string, cost int64) ([]Transaction, error) { + id, err := newCertificatesPerDomainPerAccountBucketId(regId) + if err != nil { + return nil, err + } + certsPerDomainPerAccountLimit, err := limiter.getLimit(CertificatesPerDomainPerAccount, id.bucketKey) + if err != nil { + if !errors.Is(err, errLimitDisabled) { + return nil, err + } + } + + var txns []Transaction + var regIdBucketCost int64 + for _, name := range DomainsForRateLimiting(orderDomains) { + bucketId, err := NewCertificatesPerDomainBucketId(name) + if err != nil { + return nil, err + } + regIdBucketCost++ + txns = append(txns, NewTransaction(bucketId, cost, certsPerDomainPerAccountLimit.isOverride)) + } + txns = append(txns, NewTransaction(id, regIdBucketCost*cost, certsPerDomainPerAccountLimit.isOverride)) + return txns, nil +} + +// NewCertificatesPerFQDNSetBucket returns a BucketId for the provided order +// domain names. +func NewCertificatesPerFQDNSetBucket(orderNames []string) (BucketId, error) { + id := string(core.HashNames(orderNames)) + err := validateIdForName(CertificatesPerFQDNSet, id) + if err != nil { + return BucketId{}, err + } + return BucketId{ + limitName: CertificatesPerFQDNSet, + bucketKey: joinWithColon(CertificatesPerFQDNSet.EnumString(), id), + }, nil +} diff --git a/ratelimits/limiter.go b/ratelimits/limiter.go index 46c727b1d399..2a4aece39b4d 100644 --- a/ratelimits/limiter.go +++ b/ratelimits/limiter.go @@ -4,6 +4,8 @@ import ( "context" "errors" "fmt" + "math" + "slices" "time" "github.com/jmhodges/clock" @@ -35,7 +37,7 @@ var errLimitDisabled = errors.New("limit disabled") // disabledLimitDecision is an "allowed" *Decision that should be returned when // a checked limit is found to be disabled. 
-var disabledLimitDecision = &Decision{true, 0, 0, 0, time.Time{}} +var disabledLimitDecision = &Decision{Allowed: true, Remaining: math.MaxInt64} // Limiter provides a high-level interface for rate limiting requests by // utilizing a leaky bucket-style approach. @@ -99,13 +101,21 @@ func NewLimiter(clk clock.Clock, source source, defaults, overrides string, stat type Transaction struct { BucketId cost int64 + + // optimistic indicates to the limiter that the transaction should be + // spent if possible, but should not be denied if the bucket lacks the + // capacity to satisfy the cost. + optimistic bool } -// NewTransaction creates a new Transaction for the provided BucketId and cost. -func NewTransaction(b BucketId, cost int64) Transaction { +// NewTransaction creates a new Transaction for the provided BucketId, cost, and +// a bool indicating whether the transaction is optimistic. Optimistic requests +// will not be denied if the bucket lacks the capacity to satisfy the cost. +func NewTransaction(b BucketId, cost int64, optimistic bool) Transaction { return Transaction{ - BucketId: b, - cost: cost, + BucketId: b, + cost: cost, + optimistic: optimistic, } } @@ -142,7 +152,7 @@ func (l *Limiter) Check(ctx context.Context, txn Transaction) (*Decision, error) return nil, ErrInvalidCostForCheck } - limit, err := l.getLimit(txn.limit, txn.bucketKey) + limit, err := l.getLimit(txn.limitName, txn.bucketKey) if err != nil { if errors.Is(err, errLimitDisabled) { return disabledLimitDecision, nil @@ -181,7 +191,7 @@ func (l *Limiter) Spend(ctx context.Context, txn Transaction) (*Decision, error) return nil, ErrInvalidCost } - limit, err := l.getLimit(txn.limit, txn.bucketKey) + limit, err := l.getLimit(txn.limitName, txn.bucketKey) if err != nil { if errors.Is(err, errLimitDisabled) { return disabledLimitDecision, nil @@ -196,7 +206,7 @@ func (l *Limiter) Spend(ctx context.Context, txn Transaction) (*Decision, error) start := l.clk.Now() status := Denied defer func() { - 
l.spendLatency.WithLabelValues(txn.limit.String(), status).Observe(l.clk.Since(start).Seconds()) + l.spendLatency.WithLabelValues(txn.limitName.String(), status).Observe(l.clk.Since(start).Seconds()) }() // Remove cancellation from the request context so that transactions are not @@ -223,7 +233,7 @@ func (l *Limiter) Spend(ctx context.Context, txn Transaction) (*Decision, error) if limit.isOverride { // Calculate the current utilization of the override limit. utilization := float64(limit.Burst-d.Remaining) / float64(limit.Burst) - l.overrideUsageGauge.WithLabelValues(txn.limit.String(), txn.bucketKey).Set(utilization) + l.overrideUsageGauge.WithLabelValues(txn.limitName.String(), txn.bucketKey).Set(utilization) } if !d.Allowed { @@ -238,6 +248,132 @@ func (l *Limiter) Spend(ctx context.Context, txn Transaction) (*Decision, error) return d, nil } +type batchTransaction struct { + Transaction + limit limit +} + +func (l *Limiter) prepareBatch(txns []Transaction) ([]batchTransaction, []string, error) { + var batchTxns []batchTransaction + var bucketKeys []string + for _, txn := range txns { + if txn.cost <= 0 { + return nil, nil, ErrInvalidCost + } + limit, err := l.getLimit(txn.limitName, txn.bucketKey) + if err != nil { + if errors.Is(err, errLimitDisabled) { + continue + } + return nil, nil, err + } + if txn.cost > limit.Burst { + return nil, nil, ErrInvalidCostOverLimit + } + if slices.Contains(bucketKeys, txn.bucketKey) { + return nil, nil, fmt.Errorf("found duplicate bucket %q in batch", txn.bucketKey) + } + bucketKeys = append(bucketKeys, txn.bucketKey) + batchTxns = append(batchTxns, batchTransaction{txn, limit}) + } + return batchTxns, bucketKeys, nil +} + +type batchDecision struct { + *Decision +} + +func newBatchDecision() *batchDecision { + return &batchDecision{ + Decision: &Decision{ + Allowed: true, + Remaining: math.MaxInt64, + }, + } +} + +func (d *batchDecision) consolidate(in *Decision) { + d.Allowed = d.Allowed && in.Allowed + d.Remaining = 
min(d.Remaining, in.Remaining) + d.RetryIn = max(d.RetryIn, in.RetryIn) + d.ResetIn = max(d.ResetIn, in.ResetIn) + if in.newTAT.After(d.newTAT) { + d.newTAT = in.newTAT + } +} + +// BatchSpend attempts to deduct the costs from the provided buckets' +// capacities. The following rules are applied to consolidate the Decisions for +// each transaction: +// - Allowed is true if all of the Decisions were allowed, +// - Remaining is the smallest value across all Decisions, and +// - RetryIn and ResetIn are the largest values across all Decisions. +// +// Non-existent buckets will be created WITH the cost factored into the initial +// state. New bucket states are persisted to the underlying datastore, if +// applicable, before returning. +func (l *Limiter) BatchSpend(ctx context.Context, txns []Transaction) (*Decision, error) { + batch, bucketKeys, err := l.prepareBatch(txns) + if err != nil { + return nil, err + } + if len(batch) <= 0 { + // All buckets were disabled. + return disabledLimitDecision, nil + } + + // Remove cancellation from the request context so that transactions are not + // interrupted by a client disconnect. + ctx = context.WithoutCancel(ctx) + tats, err := l.source.BatchGet(ctx, bucketKeys) + if err != nil { + return nil, err + } + + start := l.clk.Now() + batchDecision := newBatchDecision() + newTATs := make(map[string]time.Time) + + for _, txn := range batch { + tat, exists := tats[txn.bucketKey] + if !exists { + // First request from this client. Initialize the bucket with a TAT of + // "now", which is equivalent to a full bucket. + tat = l.clk.Now() + } + + // Spend the cost and update the consolidated decision. + d := maybeSpend(l.clk, txn.limit, tat, txn.cost) + if d.Allowed { + newTATs[txn.bucketKey] = d.newTAT + } + + if txn.limit.isOverride { + // Calculate the current utilization of the override limit. 
+ utilization := float64(txn.limit.Burst-d.Remaining) / float64(txn.limit.Burst) + l.overrideUsageGauge.WithLabelValues(txn.limitName.String(), txn.bucketKey).Set(utilization) + } + + if !d.Allowed && txn.optimistic { + // Suppress denial for optimistic transaction. + d = disabledLimitDecision + } + batchDecision.consolidate(d) + } + + if batchDecision.Allowed { + // Persist the batch. + err = l.source.BatchSet(ctx, newTATs) + if err != nil { + return nil, err + } + l.spendLatency.WithLabelValues("batch", Allowed).Observe(l.clk.Since(start).Seconds()) + } else { + l.spendLatency.WithLabelValues("batch", Denied).Observe(l.clk.Since(start).Seconds()) + } + return batchDecision.Decision, nil +} + // Refund attempts to refund all of the cost to the capacity of the specified // bucket. The returned *Decision indicates whether the refund was successful // and represents the current state of the bucket. The new bucket state is @@ -254,7 +390,7 @@ func (l *Limiter) Refund(ctx context.Context, txn Transaction) (*Decision, error return nil, ErrInvalidCost } - limit, err := l.getLimit(txn.limit, txn.bucketKey) + limit, err := l.getLimit(txn.limitName, txn.bucketKey) if err != nil { if errors.Is(err, errLimitDisabled) { return disabledLimitDecision, nil @@ -277,6 +413,62 @@ func (l *Limiter) Refund(ctx context.Context, txn Transaction) (*Decision, error return d, l.source.Set(ctx, txn.bucketKey, d.newTAT) } +// BatchRefund attempts to refund all or some of the costs to the provided +// buckets' capacities. The following rules are applied to consolidate the +// Decisions for each transaction: +// - Allowed is true if all of the Decisions were allowed, +// - Remaining is the smallest value across all Decisions, and +// - RetryIn and ResetIn are the largest values across all Decisions. +// +// Non-existent buckets within the batch are disregarded without error, as this +// is equivalent to the bucket being full. 
The new bucket state is persisted to +// the underlying datastore, if applicable, before returning. +func (l *Limiter) BatchRefund(ctx context.Context, txns []Transaction) (*Decision, error) { + batch, bucketKeys, err := l.prepareBatch(txns) + if err != nil { + return nil, err + } + if len(batch) <= 0 { + // All buckets were disabled. + return disabledLimitDecision, nil + } + + // Remove cancellation from the request context so that transactions are not + // interrupted by a client disconnect. + ctx = context.WithoutCancel(ctx) + tats, err := l.source.BatchGet(ctx, bucketKeys) + if err != nil { + return nil, err + } + + batchDecision := newBatchDecision() + newTATs := make(map[string]time.Time) + + for _, txn := range batch { + tat, exists := tats[txn.bucketKey] + if !exists { + // Ignore non-existent bucket. + continue + } + + // Refund the cost and update the consolidated decision. + d := maybeRefund(l.clk, txn.limit, tat, txn.cost) + if d.Allowed { + newTATs[txn.bucketKey] = d.newTAT + } + batchDecision.consolidate(d) + } + + if len(newTATs) > 0 { + // Persist the batch. + err = l.source.BatchSet(ctx, newTATs) + if err != nil { + return nil, err + } + } + return batchDecision.Decision, nil +} + // Reset resets the specified bucket to its maximum capacity. The new bucket // state is persisted to the underlying datastore before returning. 
func (l *Limiter) Reset(ctx context.Context, bucketId BucketId) error { diff --git a/ratelimits/limiter_test.go b/ratelimits/limiter_test.go index 7f1088102793..40204130cac2 100644 --- a/ratelimits/limiter_test.go +++ b/ratelimits/limiter_test.go @@ -58,8 +58,8 @@ func Test_Limiter_CheckWithLimitNoExist(t *testing.T) { testCtx, limiters, _, testIP := setup(t) for name, l := range limiters { t.Run(name, func(t *testing.T) { - bucketId := BucketId{limit: Name(9999), bucketKey: testIP} - _, err := l.Check(testCtx, NewTransaction(bucketId, 1)) + bucketId := BucketId{limitName: Name(9999), bucketKey: testIP} + _, err := l.Check(testCtx, NewTransaction(bucketId, 1, false)) test.AssertError(t, err, "should error") }) } @@ -67,7 +67,7 @@ func Test_Limiter_CheckWithLimitNoExist(t *testing.T) { func Test_Limiter_CheckWithLimitOverrides(t *testing.T) { t.Parallel() - testCtx, limiters, clk, _ := setup(t) + testCtx, limiters, clk, testIP := setup(t) for name, l := range limiters { t.Run(name, func(t *testing.T) { // Verify our overrideUsageGauge is being set correctly. 0.0 == 0% of @@ -81,21 +81,21 @@ func Test_Limiter_CheckWithLimitOverrides(t *testing.T) { // Attempt to check a spend of 41 requests (a cost > the limit burst // capacity), this should fail with a specific error. - _, err = l.Check(testCtx, NewTransaction(overriddenBucketId, 41)) + _, err = l.Check(testCtx, NewTransaction(overriddenBucketId, 41, false)) test.AssertErrorIs(t, err, ErrInvalidCostOverLimit) // Attempt to spend 41 requests (a cost > the limit burst capacity), // this should fail with a specific error. - _, err = l.Spend(testCtx, NewTransaction(overriddenBucketId, 41)) + _, err = l.Spend(testCtx, NewTransaction(overriddenBucketId, 41, false)) test.AssertErrorIs(t, err, ErrInvalidCostOverLimit) // Attempt to spend all 40 requests, this should succeed. 
- d, err := l.Spend(testCtx, NewTransaction(overriddenBucketId, 40)) + d, err := l.Spend(testCtx, NewTransaction(overriddenBucketId, 40, false)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") // Attempting to spend 1 more, this should fail. - d, err = l.Spend(testCtx, NewTransaction(overriddenBucketId, 1)) + d, err = l.Spend(testCtx, NewTransaction(overriddenBucketId, 1, false)) test.AssertNotError(t, err, "should not error") test.Assert(t, !d.Allowed, "should not be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -115,7 +115,7 @@ func Test_Limiter_CheckWithLimitOverrides(t *testing.T) { clk.Add(d.RetryIn) // We should be allowed to spend 1 more request. - d, err = l.Spend(testCtx, NewTransaction(overriddenBucketId, 1)) + d, err = l.Spend(testCtx, NewTransaction(overriddenBucketId, 1, false)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -126,19 +126,53 @@ func Test_Limiter_CheckWithLimitOverrides(t *testing.T) { // Quickly spend 40 requests in a row. for i := 0; i < 40; i++ { - d, err = l.Spend(testCtx, NewTransaction(overriddenBucketId, 1)) + d, err = l.Spend(testCtx, NewTransaction(overriddenBucketId, 1, false)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(39-i)) } // Attempting to spend 1 more, this should fail. - d, err = l.Spend(testCtx, NewTransaction(overriddenBucketId, 1)) + d, err = l.Spend(testCtx, NewTransaction(overriddenBucketId, 1, false)) test.AssertNotError(t, err, "should not error") test.Assert(t, !d.Allowed, "should not be allowed") test.AssertEquals(t, d.Remaining, int64(0)) test.AssertEquals(t, d.ResetIn, time.Second) + // Wait 1 second for a full bucket reset. 
+ clk.Add(d.ResetIn) + + testIP := net.ParseIP(testIP) + normalBucket, err := NewRegistrationsPerIPAddressBucketId(testIP) + test.AssertNotError(t, err, "should not error") + + // Spend the same bucket but in a batch with bucket subject to + // default limits. This should succeed, but the decision should + // reflect that of the default bucket. + d, err = l.BatchSpend(testCtx, []Transaction{NewTransaction(overriddenBucketId, 1, false), NewTransaction(normalBucket, 1, false)}) + test.AssertNotError(t, err, "should not error") + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(19)) + test.AssertEquals(t, d.RetryIn, time.Duration(0)) + test.AssertEquals(t, d.ResetIn, time.Millisecond*50) + + // Refund quota to both buckets. This should succeed, but the + // decision should reflect that of the default bucket. + d, err = l.BatchRefund(testCtx, []Transaction{NewTransaction(overriddenBucketId, 1, false), NewTransaction(normalBucket, 1, false)}) + test.AssertNotError(t, err, "should not error") + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(20)) + test.AssertEquals(t, d.RetryIn, time.Duration(0)) + test.AssertEquals(t, d.ResetIn, time.Duration(0)) + + // Once more. + d, err = l.BatchSpend(testCtx, []Transaction{NewTransaction(overriddenBucketId, 1, false), NewTransaction(normalBucket, 1, false)}) + test.AssertNotError(t, err, "should not error") + test.Assert(t, d.Allowed, "should be allowed") + test.AssertEquals(t, d.Remaining, int64(19)) + test.AssertEquals(t, d.RetryIn, time.Duration(0)) + test.AssertEquals(t, d.ResetIn, time.Millisecond*50) + // Reset between tests. err = l.Reset(testCtx, overriddenBucketId) test.AssertNotError(t, err, "should not error") @@ -156,7 +190,7 @@ func Test_Limiter_InitializationViaCheckAndSpend(t *testing.T) { // Check on an empty bucket should return the theoretical next state // of that bucket if the cost were spent. 
- d, err := l.Check(testCtx, NewTransaction(bucketId, 1)) + d, err := l.Check(testCtx, NewTransaction(bucketId, 1, false)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(19)) @@ -167,7 +201,7 @@ func Test_Limiter_InitializationViaCheckAndSpend(t *testing.T) { // However, that cost should not be spent yet, a 0 cost check should // tell us that we actually have 20 remaining. - d, err = l.Check(testCtx, NewTransaction(bucketId, 0)) + d, err = l.Check(testCtx, NewTransaction(bucketId, 0, false)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(20)) @@ -180,7 +214,7 @@ func Test_Limiter_InitializationViaCheckAndSpend(t *testing.T) { // Similar to above, but we'll use Spend() to actually initialize // the bucket. Spend should return the same result as Check. - d, err = l.Spend(testCtx, NewTransaction(bucketId, 1)) + d, err = l.Spend(testCtx, NewTransaction(bucketId, 1, false)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(19)) @@ -191,7 +225,7 @@ func Test_Limiter_InitializationViaCheckAndSpend(t *testing.T) { // However, that cost should not be spent yet, a 0 cost check should // tell us that we actually have 19 remaining. - d, err = l.Check(testCtx, NewTransaction(bucketId, 0)) + d, err = l.Check(testCtx, NewTransaction(bucketId, 0, false)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(19)) @@ -212,19 +246,19 @@ func Test_Limiter_RefundAndSpendCostErr(t *testing.T) { test.AssertNotError(t, err, "should not error") // Spend a cost of 0, which should fail. 
- _, err = l.Spend(testCtx, NewTransaction(bucketId, 0)) + _, err = l.Spend(testCtx, NewTransaction(bucketId, 0, false)) test.AssertErrorIs(t, err, ErrInvalidCost) // Spend a negative cost, which should fail. - _, err = l.Spend(testCtx, NewTransaction(bucketId, -1)) + _, err = l.Spend(testCtx, NewTransaction(bucketId, -1, false)) test.AssertErrorIs(t, err, ErrInvalidCost) // Refund a cost of 0, which should fail. - _, err = l.Refund(testCtx, NewTransaction(bucketId, 0)) + _, err = l.Refund(testCtx, NewTransaction(bucketId, 0, false)) test.AssertErrorIs(t, err, ErrInvalidCost) // Refund a negative cost, which should fail. - _, err = l.Refund(testCtx, NewTransaction(bucketId, -1)) + _, err = l.Refund(testCtx, NewTransaction(bucketId, -1, false)) test.AssertErrorIs(t, err, ErrInvalidCost) }) } @@ -238,7 +272,7 @@ func Test_Limiter_CheckWithBadCost(t *testing.T) { bucketId, err := NewRegistrationsPerIPAddressBucketId(net.ParseIP(testIP)) test.AssertNotError(t, err, "should not error") - _, err = l.Check(testCtx, NewTransaction(bucketId, -1)) + _, err = l.Check(testCtx, NewTransaction(bucketId, -1, false)) test.AssertErrorIs(t, err, ErrInvalidCostForCheck) }) } @@ -254,18 +288,18 @@ func Test_Limiter_DefaultLimits(t *testing.T) { // Attempt to spend 21 requests (a cost > the limit burst capacity), // this should fail with a specific error. - _, err = l.Spend(testCtx, NewTransaction(bucketId, 21)) + _, err = l.Spend(testCtx, NewTransaction(bucketId, 21, false)) test.AssertErrorIs(t, err, ErrInvalidCostOverLimit) // Attempt to spend all 20 requests, this should succeed. - d, err := l.Spend(testCtx, NewTransaction(bucketId, 20)) + d, err := l.Spend(testCtx, NewTransaction(bucketId, 20, false)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) test.AssertEquals(t, d.ResetIn, time.Second) // Attempting to spend 1 more, this should fail. 
- d, err = l.Spend(testCtx, NewTransaction(bucketId, 1)) + d, err = l.Spend(testCtx, NewTransaction(bucketId, 1, false)) test.AssertNotError(t, err, "should not error") test.Assert(t, !d.Allowed, "should not be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -279,7 +313,7 @@ func Test_Limiter_DefaultLimits(t *testing.T) { clk.Add(d.RetryIn) // We should be allowed to spend 1 more request. - d, err = l.Spend(testCtx, NewTransaction(bucketId, 1)) + d, err = l.Spend(testCtx, NewTransaction(bucketId, 1, false)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -290,14 +324,14 @@ func Test_Limiter_DefaultLimits(t *testing.T) { // Quickly spend 20 requests in a row. for i := 0; i < 20; i++ { - d, err = l.Spend(testCtx, NewTransaction(bucketId, 1)) + d, err = l.Spend(testCtx, NewTransaction(bucketId, 1, false)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(19-i)) } // Attempting to spend 1 more, this should fail. - d, err = l.Spend(testCtx, NewTransaction(bucketId, 1)) + d, err = l.Spend(testCtx, NewTransaction(bucketId, 1, false)) test.AssertNotError(t, err, "should not error") test.Assert(t, !d.Allowed, "should not be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -315,19 +349,19 @@ func Test_Limiter_RefundAndReset(t *testing.T) { test.AssertNotError(t, err, "should not error") // Attempt to spend all 20 requests, this should succeed. - d, err := l.Spend(testCtx, NewTransaction(bucketId, 20)) + d, err := l.Spend(testCtx, NewTransaction(bucketId, 20, false)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) test.AssertEquals(t, d.ResetIn, time.Second) // Refund 10 requests. 
- d, err = l.Refund(testCtx, NewTransaction(bucketId, 10)) + d, err = l.Refund(testCtx, NewTransaction(bucketId, 10, false)) test.AssertNotError(t, err, "should not error") test.AssertEquals(t, d.Remaining, int64(10)) // Spend 10 requests, this should succeed. - d, err = l.Spend(testCtx, NewTransaction(bucketId, 10)) + d, err = l.Spend(testCtx, NewTransaction(bucketId, 10, false)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -337,7 +371,7 @@ func Test_Limiter_RefundAndReset(t *testing.T) { test.AssertNotError(t, err, "should not error") // Attempt to spend 20 more requests, this should succeed. - d, err = l.Spend(testCtx, NewTransaction(bucketId, 20)) + d, err = l.Spend(testCtx, NewTransaction(bucketId, 20, false)) test.AssertNotError(t, err, "should not error") test.Assert(t, d.Allowed, "should be allowed") test.AssertEquals(t, d.Remaining, int64(0)) @@ -347,7 +381,7 @@ func Test_Limiter_RefundAndReset(t *testing.T) { clk.Add(d.ResetIn) // Refund 1 requests above our limit, this should fail. - d, err = l.Refund(testCtx, NewTransaction(bucketId, 1)) + d, err = l.Refund(testCtx, NewTransaction(bucketId, 1, false)) test.AssertNotError(t, err, "should not error") test.Assert(t, !d.Allowed, "should not be allowed") test.AssertEquals(t, d.Remaining, int64(20)) diff --git a/ratelimits/source.go b/ratelimits/source.go index 84935daefd39..0f8151993fa4 100644 --- a/ratelimits/source.go +++ b/ratelimits/source.go @@ -19,6 +19,14 @@ type source interface { // the underlying storage client implementation). Set(ctx context.Context, bucketKey string, tat time.Time) error + // BatchSet stores the TATs at the specified bucketKeys (formatted as + // 'name:id'). Implementations MUST ensure non-blocking operations by + // either: + // a) applying a deadline or timeout to the context WITHIN the method, or + // b) guaranteeing the operation will not block indefinitely (e.g. 
via + // the underlying storage client implementation). + BatchSet(ctx context.Context, bucketKeys map[string]time.Time) error + // Get retrieves the TAT associated with the specified bucketKey (formatted // as 'name:id'). Implementations MUST ensure non-blocking operations by // either: @@ -27,6 +35,14 @@ type source interface { // the underlying storage client implementation). Get(ctx context.Context, bucketKey string) (time.Time, error) + // BatchGet retrieves the TATs associated with the specified bucketKeys + // (formatted as 'name:id'). Implementations MUST ensure non-blocking + // operations by either: + // a) applying a deadline or timeout to the context WITHIN the method, or + // b) guaranteeing the operation will not block indefinitely (e.g. via + // the underlying storage client implementation). + BatchGet(ctx context.Context, bucketKeys []string) (map[string]time.Time, error) + // Delete removes the TAT associated with the specified bucketKey (formatted // as 'name:id'). Implementations MUST ensure non-blocking operations by // either: @@ -54,6 +70,15 @@ func (in *inmem) Set(_ context.Context, bucketKey string, tat time.Time) error { return nil } +func (in *inmem) BatchSet(_ context.Context, bucketKeys map[string]time.Time) error { + in.Lock() + defer in.Unlock() + for k, v := range bucketKeys { + in.m[k] = v + } + return nil +} + func (in *inmem) Get(_ context.Context, bucketKey string) (time.Time, error) { in.RLock() defer in.RUnlock() @@ -64,6 +89,20 @@ func (in *inmem) Get(_ context.Context, bucketKey string) (time.Time, error) { return tat, nil } +func (in *inmem) BatchGet(_ context.Context, bucketKeys []string) (map[string]time.Time, error) { + in.RLock() + defer in.RUnlock() + tats := make(map[string]time.Time, len(bucketKeys)) + for _, k := range bucketKeys { + tat, ok := in.m[k] + if !ok { + tats[k] = time.Time{} + } + tats[k] = tat + } + return tats, nil +} + func (in *inmem) Delete(_ context.Context, bucketKey string) error { in.Lock() defer 
in.Unlock() diff --git a/ratelimits/source_redis.go b/ratelimits/source_redis.go index 5664058fdf0f..79b1d9867a4f 100644 --- a/ratelimits/source_redis.go +++ b/ratelimits/source_redis.go @@ -83,6 +83,26 @@ func (r *RedisSource) Set(ctx context.Context, bucketKey string, tat time.Time) return nil } +// BatchSet stores TATs at the specified bucketKeys using a pipelined Redis +// transaction in order to reduce the number of round-trips to each Redis shard. +// An error is returned if the operation failed and nil otherwise. +func (r *RedisSource) BatchSet(ctx context.Context, buckets map[string]time.Time) error { + start := r.clk.Now() + + pipeline := r.client.Pipeline() + for bucketKey, tat := range buckets { + pipeline.Set(ctx, bucketKey, tat.UTC().UnixNano(), 0) + } + _, err := pipeline.Exec(ctx) + if err != nil { + r.latency.With(prometheus.Labels{"call": "batchset", "result": resultForError(err)}).Observe(time.Since(start).Seconds()) + return err + } + + r.latency.With(prometheus.Labels{"call": "batchset", "result": "success"}).Observe(time.Since(start).Seconds()) + return nil +} + // Get retrieves the TAT at the specified bucketKey. An error is returned if the // operation failed and nil otherwise. If the bucketKey does not exist, // ErrBucketNotFound is returned. @@ -104,6 +124,43 @@ func (r *RedisSource) Get(ctx context.Context, bucketKey string) (time.Time, err return time.Unix(0, tatNano).UTC(), nil } +// BatchGet retrieves the TATs at the specified bucketKeys using a pipelined +// Redis transaction in order to reduce the number of round-trips to each Redis +// shard. An error is returned if the operation failed and nil otherwise. If a +// bucketKey does not exist, it WILL NOT be included in the returned map. 
+func (r *RedisSource) BatchGet(ctx context.Context, bucketKeys []string) (map[string]time.Time, error) { + start := r.clk.Now() + + pipeline := r.client.Pipeline() + for _, bucketKey := range bucketKeys { + pipeline.Get(ctx, bucketKey) + } + results, err := pipeline.Exec(ctx) + if err != nil { + r.latency.With(prometheus.Labels{"call": "batchget", "result": resultForError(err)}).Observe(time.Since(start).Seconds()) + if !errors.Is(err, redis.Nil) { + return nil, err + } + } + + tats := make(map[string]time.Time, len(bucketKeys)) + for i, result := range results { + tatNano, err := result.(*redis.StringCmd).Int64() + if err != nil { + if errors.Is(err, redis.Nil) { + // Bucket key does not exist. + continue + } + r.latency.With(prometheus.Labels{"call": "batchget", "result": resultForError(err)}).Observe(time.Since(start).Seconds()) + return nil, err + } + tats[bucketKeys[i]] = time.Unix(0, tatNano).UTC() + } + + r.latency.With(prometheus.Labels{"call": "batchget", "result": "success"}).Observe(time.Since(start).Seconds()) + return tats, nil +} + // Delete deletes the TAT at the specified bucketKey ('name:id'). It returns an // error if the operation failed and nil otherwise. A nil return value does not // indicate that the bucketKey existed. 
diff --git a/ratelimits/source_redis_test.go b/ratelimits/source_redis_test.go index 13790b63bb91..6347d1d71d9e 100644 --- a/ratelimits/source_redis_test.go +++ b/ratelimits/source_redis_test.go @@ -2,6 +2,7 @@ package ratelimits import ( "testing" + "time" "github.com/letsencrypt/boulder/cmd" "github.com/letsencrypt/boulder/metrics" @@ -68,3 +69,37 @@ func Test_RedisSource_Ping(t *testing.T) { err = missingSecondShardSource.Ping(context.Background()) test.AssertError(t, err, "Ping should not error") } + +func Test_RedisSource_BatchSetAndGet(t *testing.T) { + clk := clock.NewFake() + s := newTestRedisSource(clk, map[string]string{ + "shard1": "10.33.33.4:4218", + "shard2": "10.33.33.5:4218", + }) + + now := clk.Now() + val1 := now.Add(time.Second) + val2 := now.Add(time.Second * 2) + val3 := now.Add(time.Second * 3) + + set := map[string]time.Time{ + "test1": val1, + "test2": val2, + "test3": val3, + } + + err := s.BatchSet(context.Background(), set) + test.AssertNotError(t, err, "BatchSet() should not error") + + got, err := s.BatchGet(context.Background(), []string{"test1", "test2", "test3"}) + test.AssertNotError(t, err, "BatchGet() should not error") + + for k, v := range set { + test.Assert(t, got[k].Equal(v), "BatchGet() should return the values set by BatchSet()") + } + + // Test that BatchGet() returns a zero time for a key that does not exist. 
+ got, err = s.BatchGet(context.Background(), []string{"test1", "test4", "test3"}) + test.AssertNotError(t, err, "BatchGet() should not error when a key isn't found") + test.Assert(t, got["test4"].IsZero(), "BatchGet() should return a zero time for a key that does not exist") +} diff --git a/ratelimits/utilities.go b/ratelimits/utilities.go index 8a7cbca70877..dd5a1167eca8 100644 --- a/ratelimits/utilities.go +++ b/ratelimits/utilities.go @@ -2,9 +2,32 @@ package ratelimits import ( "strings" + + "github.com/letsencrypt/boulder/core" + "github.com/weppos/publicsuffix-go/publicsuffix" ) // joinWithColon joins the provided args with a colon. func joinWithColon(args ...string) string { return strings.Join(args, ":") } + +// DomainsForRateLimiting transforms a list of FQDNs into a list of eTLD+1's +// for the purpose of rate limiting. It also de-duplicates the output +// domains. Exact public suffix matches are included. +func DomainsForRateLimiting(names []string) []string { + var domains []string + for _, name := range names { + domain, err := publicsuffix.Domain(name) + if err != nil { + // The only possible errors are: + // (1) publicsuffix.Domain is giving garbage values + // (2) the public suffix is the domain itself + // We assume 2 and include the original name in the result. 
+ domains = append(domains, name) + } else { + domains = append(domains, domain) + } + } + return core.UniqueLowerNames(domains) +} diff --git a/ratelimits/utilities_test.go b/ratelimits/utilities_test.go new file mode 100644 index 000000000000..9c68d3a6e899 --- /dev/null +++ b/ratelimits/utilities_test.go @@ -0,0 +1,27 @@ +package ratelimits + +import ( + "testing" + + "github.com/letsencrypt/boulder/test" +) + +func TestDomainsForRateLimiting(t *testing.T) { + domains := DomainsForRateLimiting([]string{}) + test.AssertEquals(t, len(domains), 0) + + domains = DomainsForRateLimiting([]string{"www.example.com", "example.com"}) + test.AssertDeepEquals(t, domains, []string{"example.com"}) + + domains = DomainsForRateLimiting([]string{"www.example.com", "example.com", "www.example.co.uk"}) + test.AssertDeepEquals(t, domains, []string{"example.co.uk", "example.com"}) + + domains = DomainsForRateLimiting([]string{"www.example.com", "example.com", "www.example.co.uk", "co.uk"}) + test.AssertDeepEquals(t, domains, []string{"co.uk", "example.co.uk", "example.com"}) + + domains = DomainsForRateLimiting([]string{"foo.bar.baz.www.example.com", "baz.example.com"}) + test.AssertDeepEquals(t, domains, []string{"example.com"}) + + domains = DomainsForRateLimiting([]string{"github.io", "foo.github.io", "bar.github.io"}) + test.AssertDeepEquals(t, domains, []string{"bar.github.io", "foo.github.io", "github.io"}) +} diff --git a/wfe2/wfe.go b/wfe2/wfe.go index 5c7932730fa4..e3ad1ad7c299 100644 --- a/wfe2/wfe.go +++ b/wfe2/wfe.go @@ -644,7 +644,7 @@ func (wfe *WebFrontEndImpl) checkNewAccountLimits(ctx context.Context, ip net.IP return } - decision, err := wfe.limiter.Spend(ctx, ratelimits.NewTransaction(bucketId, 1)) + decision, err := wfe.limiter.Spend(ctx, ratelimits.NewTransaction(bucketId, 1, false)) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPAddress) return @@ -661,7 +661,7 @@ func (wfe *WebFrontEndImpl) checkNewAccountLimits(ctx context.Context, ip net.IP return 
} - _, err = wfe.limiter.Spend(ctx, ratelimits.NewTransaction(bucketId, 1)) + _, err = wfe.limiter.Spend(ctx, ratelimits.NewTransaction(bucketId, 1, false)) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPv6Range) } @@ -692,7 +692,7 @@ func (wfe *WebFrontEndImpl) refundNewAccountLimits(ctx context.Context, ip net.I return } - _, err = wfe.limiter.Refund(ctx, ratelimits.NewTransaction(bucketId, 1)) + _, err = wfe.limiter.Refund(ctx, ratelimits.NewTransaction(bucketId, 1, false)) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPAddress) return @@ -708,7 +708,7 @@ func (wfe *WebFrontEndImpl) refundNewAccountLimits(ctx context.Context, ip net.I return } - _, err = wfe.limiter.Refund(ctx, ratelimits.NewTransaction(bucketId, 1)) + _, err = wfe.limiter.Refund(ctx, ratelimits.NewTransaction(bucketId, 1, false)) if err != nil { warn(err, ratelimits.NewRegistrationsPerIPv6Range) }