From c96c473f89c3866a8ad51f01d61f9800369bc80b Mon Sep 17 00:00:00 2001
From: Samantha
Date: Thu, 9 Nov 2023 16:19:19 -0500
Subject: [PATCH] WIP

---
 ra/ra.go                        |  23 +---
 ra/ra_test.go                   |  20 ---
 ratelimits/bucket.go            | 129 +++++++++++++++++-
 ratelimits/limiter.go           | 228 ++++++++++++++++++++++++++++++--
 ratelimits/limiter_test.go      |  38 +++++-
 ratelimits/names.go             |  16 ++-
 ratelimits/source.go            |  39 ++++++
 ratelimits/source_redis.go      |  57 ++++++++
 ratelimits/source_redis_test.go |  35 +++++
 ratelimits/utilities.go         |  23 ++++
 ratelimits/utilities_test.go    |  27 ++++
 11 files changed, 569 insertions(+), 66 deletions(-)
 create mode 100644 ratelimits/utilities_test.go

diff --git a/ra/ra.go b/ra/ra.go
index ab19f842de5..6dd293b6103 100644
--- a/ra/ra.go
+++ b/ra/ra.go
@@ -20,7 +20,6 @@ import (
 	"github.com/jmhodges/clock"
 	"github.com/prometheus/client_golang/prometheus"
-	"github.com/weppos/publicsuffix-go/publicsuffix"
 	"golang.org/x/crypto/ocsp"
 	"google.golang.org/grpc"
 	"google.golang.org/protobuf/proto"
@@ -1377,26 +1376,6 @@ func (ra *RegistrationAuthorityImpl) getSCTs(ctx context.Context, cert []byte, e
 	return scts, nil
 }
 
-// domainsForRateLimiting transforms a list of FQDNs into a list of eTLD+1's
-// for the purpose of rate limiting. It also de-duplicates the output
-// domains. Exact public suffix matches are included.
-func domainsForRateLimiting(names []string) []string {
-	var domains []string
-	for _, name := range names {
-		domain, err := publicsuffix.Domain(name)
-		if err != nil {
-			// The only possible errors are:
-			// (1) publicsuffix.Domain is giving garbage values
-			// (2) the public suffix is the domain itself
-			// We assume 2 and include the original name in the result.
-			domains = append(domains, name)
-		} else {
-			domains = append(domains, domain)
-		}
-	}
-	return core.UniqueLowerNames(domains)
-}
-
 // enforceNameCounts uses the provided count RPC to find a count of certificates
 // for each of the names. If the count for any of the names exceeds the limit
 // for the given registration then the names out of policy are returned to be
@@ -1451,7 +1430,7 @@ func (ra *RegistrationAuthorityImpl) checkCertificatesPerNameLimit(ctx context.C
 		return nil
 	}
 
-	tldNames := domainsForRateLimiting(names)
+	tldNames := ratelimits.DomainsForRateLimiting(names)
 	namesOutOfLimit, earliest, err := ra.enforceNameCounts(ctx, tldNames, limit, regID)
 	if err != nil {
 		return fmt.Errorf("checking certificates per name limit for %q: %s",
diff --git a/ra/ra_test.go b/ra/ra_test.go
index d95dd332250..02e17d6232e 100644
--- a/ra/ra_test.go
+++ b/ra/ra_test.go
@@ -1113,26 +1113,6 @@ func TestAuthzFailedRateLimitingNewOrder(t *testing.T) {
 	testcase()
 }
 
-func TestDomainsForRateLimiting(t *testing.T) {
-	domains := domainsForRateLimiting([]string{})
-	test.AssertEquals(t, len(domains), 0)
-
-	domains = domainsForRateLimiting([]string{"www.example.com", "example.com"})
-	test.AssertDeepEquals(t, domains, []string{"example.com"})
-
-	domains = domainsForRateLimiting([]string{"www.example.com", "example.com", "www.example.co.uk"})
-	test.AssertDeepEquals(t, domains, []string{"example.co.uk", "example.com"})
-
-	domains = domainsForRateLimiting([]string{"www.example.com", "example.com", "www.example.co.uk", "co.uk"})
-	test.AssertDeepEquals(t, domains, []string{"co.uk", "example.co.uk", "example.com"})
-
-	domains = domainsForRateLimiting([]string{"foo.bar.baz.www.example.com", "baz.example.com"})
-	test.AssertDeepEquals(t, domains, []string{"example.com"})
-
-	domains = domainsForRateLimiting([]string{"github.io", "foo.github.io", "bar.github.io"})
-	test.AssertDeepEquals(t, domains, []string{"bar.github.io", "foo.github.io", "github.io"})
-}
-
 type mockSAWithNameCounts struct {
 	mocks.StorageAuthority
 	nameCounts *sapb.CountByNames
diff --git a/ratelimits/bucket.go b/ratelimits/bucket.go
index 501d1fd2c44..22c7c84100a 100644
--- a/ratelimits/bucket.go
+++ b/ratelimits/bucket.go
@@ -1,17 +1,21 @@
 package ratelimits
 
 import (
+	"errors"
 	"fmt"
 	"net"
+	"strconv"
+
+	"github.com/letsencrypt/boulder/core"
 )
 
 // BucketId should only be created using the New*BucketId functions. It is used
 // by the Limiter to look up the bucket and limit overrides for a specific
 // subscriber and limit.
 type BucketId struct {
-	// limit is the name of the associated rate limit. It is used for looking up
-	// default limits.
-	limit Name
+	// limitName is the name of the associated rate limit. It is used for
+	// looking up default limits.
+	limitName Name
 
 	// bucketKey is the limit Name enum (e.g. "1") concatenated with the
 	// subscriber identifier specific to the associate limit Name type.
@@ -27,7 +31,7 @@ func NewRegistrationsPerIPAddressBucketId(ip net.IP) (BucketId, error) {
 		return BucketId{}, err
 	}
 	return BucketId{
-		limit:     NewRegistrationsPerIPAddress,
+		limitName: NewRegistrationsPerIPAddress,
 		bucketKey: joinWithColon(NewRegistrationsPerIPAddress.EnumString(), id),
 	}, nil
 }
@@ -46,7 +50,122 @@ func NewRegistrationsPerIPv6RangeBucketId(ip net.IP) (BucketId, error) {
 		return BucketId{}, err
 	}
 	return BucketId{
-		limit:     NewRegistrationsPerIPv6Range,
+		limitName: NewRegistrationsPerIPv6Range,
 		bucketKey: joinWithColon(NewRegistrationsPerIPv6Range.EnumString(), id),
 	}, nil
}
+
+// NewOrdersPerAccountBucketId returns a BucketId for the provided ACME
+// registration Id.
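+//
+// A usage sketch (hypothetical; the regId, limiter, and error handling shown
+// here are illustrative only):
+//
+//	bucketId, err := NewOrdersPerAccountBucketId(12345)
+//	if err != nil {
+//		return err
+//	}
+//	decision, err := limiter.Spend(ctx, NewTransaction(bucketId, 1))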
+func NewOrdersPerAccountBucketId(regId int64) (BucketId, error) {
+	id := strconv.FormatInt(regId, 10)
+	err := validateIdForName(NewOrdersPerAccount, id)
+	if err != nil {
+		return BucketId{}, err
+	}
+	return BucketId{
+		limitName: NewOrdersPerAccount,
+		bucketKey: joinWithColon(NewOrdersPerAccount.EnumString(), id),
+	}, nil
+}
+
+// NewFailedAuthorizationsPerAccountBucketId returns a BucketId for the provided
+// ACME registration Id.
+func NewFailedAuthorizationsPerAccountBucketId(regId int64) (BucketId, error) {
+	id := strconv.FormatInt(regId, 10)
+	err := validateIdForName(FailedAuthorizationsPerAccount, id)
+	if err != nil {
+		return BucketId{}, err
+	}
+	return BucketId{
+		limitName: FailedAuthorizationsPerAccount,
+		bucketKey: joinWithColon(FailedAuthorizationsPerAccount.EnumString(), id),
+	}, nil
+}
+
+// NewCertificatesPerDomainBucketId returns a BucketId for the provided order
+// domain name.
+func NewCertificatesPerDomainBucketId(orderName string) (BucketId, error) {
+	err := validateIdForName(CertificatesPerDomain, orderName)
+	if err != nil {
+		return BucketId{}, err
+	}
+	return BucketId{
+		limitName: CertificatesPerDomain,
+		bucketKey: joinWithColon(CertificatesPerDomain.EnumString(), orderName),
+	}, nil
+}
+
+// newCertificatesPerDomainPerAccountBucketId is only referenced internally.
+// Buckets for CertificatesPerDomainPerAccount are created by calling
+// NewCertificatesPerDomainTransactions().
+func newCertificatesPerDomainPerAccountBucketId(regId int64) (BucketId, error) {
+	id := strconv.FormatInt(regId, 10)
+	err := validateIdForName(CertificatesPerDomainPerAccount, id)
+	if err != nil {
+		return BucketId{}, err
+	}
+	return BucketId{
+		limitName: CertificatesPerDomainPerAccount,
+		bucketKey: joinWithColon(CertificatesPerDomainPerAccount.EnumString(), id),
+	}, nil
+}
+
+// NewCertificatesPerDomainTransactions returns a slice of Transactions for the
+// provided order domain names. The cost specified will be applied per eTLD+1
+// name present in the orderDomains.
+//
+// Note: when overrides to the CertificatesPerDomainPerAccount are configured
+// for the subscriber, the cost:
+//   - MUST be consumed from the CertificatesPerDomainPerAccount bucket and
+//   - SHOULD be consumed from each CertificatesPerDomain bucket, if possible.
+//
+// When a CertificatesPerDomainPerAccount override is configured, all of the
+// CertificatesPerDomain transactions returned by this function will be marked
+// as optimistic and the combined cost of all of these transactions will be
+// specified in a CertificatesPerDomainPerAccount transaction as well.
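+//
+// For example (an illustrative sketch): an order for www.example.com and
+// www.example.org, placed by an account with a CertificatesPerDomainPerAccount
+// override and a cost of 1, yields three transactions:
+//
+//	an optimistic spend of 1 against CertificatesPerDomain:example.com
+//	an optimistic spend of 1 against CertificatesPerDomain:example.org
+//	a spend of 2 against CertificatesPerDomainPerAccount:<regId>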
+func NewCertificatesPerDomainTransactions(limiter *Limiter, regId int64, orderDomains []string, cost int64) ([]Transaction, error) {
+	id, err := newCertificatesPerDomainPerAccountBucketId(regId)
+	if err != nil {
+		return nil, err
+	}
+	certsPerDomainPerAccountLimit, err := limiter.getLimit(CertificatesPerDomainPerAccount, id.bucketKey)
+	if err != nil {
+		if !errors.Is(err, errLimitDisabled) {
+			return nil, err
+		}
+	}
+
+	var txns []Transaction
+	var certsPerDomainPerAccountCost int64
+	for _, name := range DomainsForRateLimiting(orderDomains) {
+		bucketId, err := NewCertificatesPerDomainBucketId(name)
+		if err != nil {
+			return nil, err
+		}
+		certsPerDomainPerAccountCost += cost
+		if certsPerDomainPerAccountLimit.isOverride {
+			txns = append(txns, newOptimisticTransaction(bucketId, cost))
+		} else {
+			txns = append(txns, NewTransaction(bucketId, cost))
+		}
+	}
+	if certsPerDomainPerAccountLimit.isOverride {
+		txns = append(txns, NewTransaction(id, certsPerDomainPerAccountCost))
+	}
+	return txns, nil
+}
+
+// NewCertificatesPerFQDNSetBucket returns a BucketId for the provided order
+// domain names.
+func NewCertificatesPerFQDNSetBucket(orderNames []string) (BucketId, error) {
+	id := string(core.HashNames(orderNames))
+	err := validateIdForName(CertificatesPerFQDNSet, id)
+	if err != nil {
+		return BucketId{}, err
+	}
+	return BucketId{
+		limitName: CertificatesPerFQDNSet,
+		bucketKey: joinWithColon(CertificatesPerFQDNSet.EnumString(), id),
+	}, nil
+}
diff --git a/ratelimits/limiter.go b/ratelimits/limiter.go
index 46c727b1d39..ffe622d8218 100644
--- a/ratelimits/limiter.go
+++ b/ratelimits/limiter.go
@@ -4,6 +4,8 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"math"
+	"slices"
 	"time"
 
 	"github.com/jmhodges/clock"
@@ -35,7 +37,7 @@ var errLimitDisabled = errors.New("limit disabled")
 
 // disabledLimitDecision is an "allowed" *Decision that should be returned when
 // a checked limit is found to be disabled.
-var disabledLimitDecision = &Decision{true, 0, 0, 0, time.Time{}}
+var disabledLimitDecision = &Decision{Allowed: true, Remaining: math.MaxInt64}
 
 // Limiter provides a high-level interface for rate limiting requests by
 // utilizing a leaky bucket-style approach.
@@ -99,6 +101,12 @@ func NewLimiter(clk clock.Clock, source source, defaults, overrides string, stat
 type Transaction struct {
 	BucketId
 	cost int64
+
+	// optimistic indicates to the limiter that the cost should be spent if
+	// possible, but should not be denied if the bucket lacks the capacity to
+	// satisfy the cost. Note: optimistic transactions are only supported by
+	// limiter.BatchSpend().
+	optimistic bool
 }
 
 // NewTransaction creates a new Transaction for the provided BucketId and cost.
@@ -109,6 +117,18 @@ func NewTransaction(b BucketId, cost int64) Transaction {
 	}
 }
 
+// newOptimisticTransaction creates a new optimistic Transaction for the
+// provided BucketId and cost. Optimistic transactions will not be denied if the
+// bucket lacks the capacity to satisfy the cost. Note: optimistic transactions
+// are only supported by limiter.BatchSpend().
+func newOptimisticTransaction(b BucketId, cost int64) Transaction {
+	return Transaction{
+		BucketId:   b,
+		cost:       cost,
+		optimistic: true,
+	}
+}
+
 type Decision struct {
 	// Allowed is true if the bucket possessed enough capacity to allow the
 	// request given the cost.
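A minimal within-package sketch of how the two transaction flavors combine in a batch (illustrative only: the Limiter l, ctx, and the ids shown are assumed; newOptimisticTransaction is unexported, so this pattern applies only inside the ratelimits package):

	perDomainId, _ := NewCertificatesPerDomainBucketId("example.com")
	perAccountId, _ := newCertificatesPerDomainPerAccountBucketId(12345)
	txns := []Transaction{
		// Spent when capacity allows, but never denies the batch by itself.
		newOptimisticTransaction(perDomainId, 1),
		// Must be satisfied for the consolidated Decision to be Allowed.
		NewTransaction(perAccountId, 1),
	}
	decision, _ := l.BatchSpend(ctx, txns)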
@@ -142,7 +162,7 @@ func (l *Limiter) Check(ctx context.Context, txn Transaction) (*Decision, error)
 		return nil, ErrInvalidCostForCheck
 	}
 
-	limit, err := l.getLimit(txn.limit, txn.bucketKey)
+	limit, err := l.getLimit(txn.limitName, txn.bucketKey)
 	if err != nil {
 		if errors.Is(err, errLimitDisabled) {
 			return disabledLimitDecision, nil
@@ -181,7 +201,7 @@ func (l *Limiter) Spend(ctx context.Context, txn Transaction) (*Decision, error)
 		return nil, ErrInvalidCost
 	}
 
-	limit, err := l.getLimit(txn.limit, txn.bucketKey)
+	limit, err := l.getLimit(txn.limitName, txn.bucketKey)
 	if err != nil {
 		if errors.Is(err, errLimitDisabled) {
 			return disabledLimitDecision, nil
@@ -196,7 +216,7 @@ func (l *Limiter) Spend(ctx context.Context, txn Transaction) (*Decision, error)
 	start := l.clk.Now()
 	status := Denied
 	defer func() {
-		l.spendLatency.WithLabelValues(txn.limit.String(), status).Observe(l.clk.Since(start).Seconds())
+		l.spendLatency.WithLabelValues(txn.limitName.String(), status).Observe(l.clk.Since(start).Seconds())
 	}()
 
 	// Remove cancellation from the request context so that transactions are not
@@ -223,7 +243,7 @@ func (l *Limiter) Spend(ctx context.Context, txn Transaction) (*Decision, error)
 	if limit.isOverride {
 		// Calculate the current utilization of the override limit.
 		utilization := float64(limit.Burst-d.Remaining) / float64(limit.Burst)
-		l.overrideUsageGauge.WithLabelValues(txn.limit.String(), txn.bucketKey).Set(utilization)
+		l.overrideUsageGauge.WithLabelValues(txn.limitName.String(), txn.bucketKey).Set(utilization)
 	}
 
 	if !d.Allowed {
@@ -238,6 +258,132 @@ func (l *Limiter) Spend(ctx context.Context, txn Transaction) (*Decision, error)
 	return d, nil
 }
 
+type batchTransaction struct {
+	Transaction
+	limit limit
+}
+
+func (l *Limiter) prepareBatch(txns []Transaction) ([]batchTransaction, []string, error) {
+	var batchTxns []batchTransaction
+	var bucketKeys []string
+	for _, txn := range txns {
+		if txn.cost <= 0 {
+			return nil, nil, ErrInvalidCost
+		}
+		limit, err := l.getLimit(txn.limitName, txn.bucketKey)
+		if err != nil {
+			if errors.Is(err, errLimitDisabled) {
+				continue
+			}
+			return nil, nil, err
+		}
+		if txn.cost > limit.Burst {
+			return nil, nil, ErrInvalidCostOverLimit
+		}
+		if slices.Contains(bucketKeys, txn.bucketKey) {
+			return nil, nil, fmt.Errorf("found duplicate bucket %q in batch", txn.bucketKey)
+		}
+		bucketKeys = append(bucketKeys, txn.bucketKey)
+		batchTxns = append(batchTxns, batchTransaction{txn, limit})
+	}
+	return batchTxns, bucketKeys, nil
+}
+
+type batchDecision struct {
+	*Decision
+}
+
+func newBatchDecision() *batchDecision {
+	return &batchDecision{
+		Decision: &Decision{
+			Allowed:   true,
+			Remaining: math.MaxInt64,
+		},
+	}
+}
+
+func (d *batchDecision) consolidate(in *Decision) {
+	d.Allowed = d.Allowed && in.Allowed
+	d.Remaining = min(d.Remaining, in.Remaining)
+	d.RetryIn = max(d.RetryIn, in.RetryIn)
+	d.ResetIn = max(d.ResetIn, in.ResetIn)
+	if in.newTAT.After(d.newTAT) {
+		d.newTAT = in.newTAT
+	}
+}
+
+// BatchSpend attempts to deduct the costs from the provided buckets'
+// capacities. The following rules are applied to consolidate the Decisions for
+// each transaction:
+//   - Allowed is true if all of the Decisions were allowed,
+//   - Remaining is the smallest value across all Decisions, and
+//   - RetryIn and ResetIn are the largest values across all Decisions.
+//
+// Non-existent buckets will be created WITH the cost factored into the initial
+// state. New bucket states are persisted to the underlying datastore, if
+// applicable, before returning.
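+//
+// A typical call site pairs this with the Transaction constructors above
+// (a minimal sketch; error handling elided):
+//
+//	txns, _ := NewCertificatesPerDomainTransactions(l, regId, orderDomains, 1)
+//	decision, _ := l.BatchSpend(ctx, txns)
+//	if !decision.Allowed {
+//		// Deny the request; decision.RetryIn indicates when to retry.
+//	}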
+func (l *Limiter) BatchSpend(ctx context.Context, txns []Transaction) (*Decision, error) {
+	batch, bucketKeys, err := l.prepareBatch(txns)
+	if err != nil {
+		return nil, err
+	}
+	if len(batch) <= 0 {
+		// All buckets were disabled.
+		return disabledLimitDecision, nil
+	}
+
+	// Remove cancellation from the request context so that transactions are not
+	// interrupted by a client disconnect.
+	ctx = context.WithoutCancel(ctx)
+	tats, err := l.source.BatchGet(ctx, bucketKeys)
+	if err != nil {
+		return nil, err
+	}
+
+	start := l.clk.Now()
+	batchDecision := newBatchDecision()
+	newTATs := make(map[string]time.Time)
+
+	for _, txn := range batch {
+		tat, exists := tats[txn.bucketKey]
+		if !exists {
+			// First request from this client. Initialize the bucket with a TAT of
+			// "now", which is equivalent to a full bucket.
+			tat = l.clk.Now()
+		}
+
+		// Spend the cost and update the consolidated decision.
+		d := maybeSpend(l.clk, txn.limit, tat, txn.cost)
+		if d.Allowed {
+			newTATs[txn.bucketKey] = d.newTAT
+		}
+
+		if txn.limit.isOverride {
+			// Calculate the current utilization of the override limit.
+			utilization := float64(txn.limit.Burst-d.Remaining) / float64(txn.limit.Burst)
+			l.overrideUsageGauge.WithLabelValues(txn.limitName.String(), txn.bucketKey).Set(utilization)
+		}
+
+		if !d.Allowed && txn.optimistic {
+			// Suppress denial for optimistic transaction.
+			d = disabledLimitDecision
+		}
+		batchDecision.consolidate(d)
+	}
+
+	if batchDecision.Allowed {
+		// Persist the batch.
+		err = l.source.BatchSet(ctx, newTATs)
+		if err != nil {
+			return nil, err
+		}
+		l.spendLatency.WithLabelValues("batch", Allowed).Observe(l.clk.Since(start).Seconds())
+	} else {
+		l.spendLatency.WithLabelValues("batch", Denied).Observe(l.clk.Since(start).Seconds())
+	}
+	return batchDecision.Decision, nil
+}
+
 // Refund attempts to refund all of the cost to the capacity of the specified
 // bucket. The returned *Decision indicates whether the refund was successful
 // and represents the current state of the bucket. The new bucket state is
@@ -254,7 +400,7 @@ func (l *Limiter) Refund(ctx context.Context, txn Transaction) (*Decision, error)
 		return nil, ErrInvalidCost
 	}
 
-	limit, err := l.getLimit(txn.limit, txn.bucketKey)
+	limit, err := l.getLimit(txn.limitName, txn.bucketKey)
 	if err != nil {
 		if errors.Is(err, errLimitDisabled) {
 			return disabledLimitDecision, nil
@@ -277,6 +423,62 @@ func (l *Limiter) Refund(ctx context.Context, txn Transaction) (*Decision, error)
 	return d, l.source.Set(ctx, txn.bucketKey, d.newTAT)
 }
 
+// BatchRefund attempts to refund all or some of the costs to the provided
+// buckets' capacities. The following rules are applied to consolidate the
+// Decisions for each transaction:
+//   - Allowed is true if all of the Decisions were allowed,
+//   - Remaining is the smallest value across all Decisions, and
+//   - RetryIn and ResetIn are the largest values across all Decisions.
+//
+// Non-existent buckets within the batch are disregarded without error, as this
+// is equivalent to the bucket being full. The new bucket state is persisted to
+// the underlying datastore, if applicable, before returning.
+func (l *Limiter) BatchRefund(ctx context.Context, txns []Transaction) (*Decision, error) {
+	batch, bucketKeys, err := l.prepareBatch(txns)
+	if err != nil {
+		return nil, err
+	}
+	if len(batch) <= 0 {
+		// All buckets were disabled.
+		return disabledLimitDecision, nil
+	}
+
+	// Remove cancellation from the request context so that transactions are not
+	// interrupted by a client disconnect.
+	ctx = context.WithoutCancel(ctx)
+	tats, err := l.source.BatchGet(ctx, bucketKeys)
+	if err != nil {
+		return nil, err
+	}
+
+	batchDecision := newBatchDecision()
+	newTATs := make(map[string]time.Time)
+
+	for _, txn := range batch {
+		tat, exists := tats[txn.bucketKey]
+		if !exists {
+			// Ignore non-existent bucket.
+			continue
+		}
+
+		// Refund the cost and update the consolidated decision.
+		d := maybeRefund(l.clk, txn.limit, tat, txn.cost)
+		if d.Allowed {
+			newTATs[txn.bucketKey] = d.newTAT
+		}
+		batchDecision.consolidate(d)
+	}
+
+	if len(newTATs) > 0 {
+		// Persist the batch.
+		err = l.source.BatchSet(ctx, newTATs)
+		if err != nil {
+			return nil, err
+		}
+	}
+	return batchDecision.Decision, nil
+}
+
 // Reset resets the specified bucket to its maximum capacity. The new bucket
 // state is persisted to the underlying datastore before returning.
 func (l *Limiter) Reset(ctx context.Context, bucketId BucketId) error {
@@ -302,19 +504,19 @@ func (l *Limiter) initialize(ctx context.Context, rl limit, txn Transaction) (*D
 	return d, nil
 }
 
-// getLimit returns the limit for the specified by name and id, name is
-// required, id is optional. If id is left unspecified, the default limit for
-// the limit specified by name is returned. If no default limit exists for the
-// specified name, errLimitDisabled is returned.
-func (l *Limiter) getLimit(name Name, id string) (limit, error) {
+// getLimit returns the limit specified by name and bucketKey; name is
+// required, bucketKey is optional. If bucketKey is left unspecified, the
+// default limit for the limit specified by name is returned. If no default
+// limit exists for the specified name, errLimitDisabled is returned.
+func (l *Limiter) getLimit(name Name, bucketKey string) (limit, error) {
 	if !name.isValid() {
 		// This should never happen. Callers should only be specifying the limit
 		// Name enums defined in this package.
 		return limit{}, fmt.Errorf("specified name enum %q, is invalid", name)
 	}
-	if id != "" {
+	if bucketKey != "" {
 		// Check for override.
-		ol, ok := l.overrides[id]
+		ol, ok := l.overrides[bucketKey]
 		if ok {
 			return ol, nil
 		}
diff --git a/ratelimits/limiter_test.go b/ratelimits/limiter_test.go
index 7f108810279..44dcdaa1f03 100644
--- a/ratelimits/limiter_test.go
+++ b/ratelimits/limiter_test.go
@@ -58,7 +58,7 @@ func Test_Limiter_CheckWithLimitNoExist(t *testing.T) {
 	testCtx, limiters, _, testIP := setup(t)
 	for name, l := range limiters {
 		t.Run(name, func(t *testing.T) {
-			bucketId := BucketId{limit: Name(9999), bucketKey: testIP}
+			bucketId := BucketId{limitName: Name(9999), bucketKey: testIP}
 			_, err := l.Check(testCtx, NewTransaction(bucketId, 1))
 			test.AssertError(t, err, "should error")
 		})
@@ -67,7 +67,7 @@ func Test_Limiter_CheckWithLimitNoExist(t *testing.T) {
 
 func Test_Limiter_CheckWithLimitOverrides(t *testing.T) {
 	t.Parallel()
-	testCtx, limiters, clk, _ := setup(t)
+	testCtx, limiters, clk, testIP := setup(t)
 	for name, l := range limiters {
 		t.Run(name, func(t *testing.T) {
 			// Verify our overrideUsageGauge is being set correctly. 0.0 == 0% of
@@ -139,6 +139,40 @@ func Test_Limiter_CheckWithLimitOverrides(t *testing.T) {
 			test.AssertEquals(t, d.Remaining, int64(0))
 			test.AssertEquals(t, d.ResetIn, time.Second)
 
+			// Wait 1 second for a full bucket reset.
+			clk.Add(d.ResetIn)
+
+			testIP := net.ParseIP(testIP)
+			normalBucket, err := NewRegistrationsPerIPAddressBucketId(testIP)
+			test.AssertNotError(t, err, "should not error")
+
+			// Spend the same bucket but in a batch with a bucket subject to
+			// default limits.
+			// This should succeed, but the decision should reflect that of the default bucket.
+			d, err = l.BatchSpend(testCtx, []Transaction{NewTransaction(overriddenBucketId, 1), NewTransaction(normalBucket, 1)})
+			test.AssertNotError(t, err, "should not error")
+			test.Assert(t, d.Allowed, "should be allowed")
+			test.AssertEquals(t, d.Remaining, int64(19))
+			test.AssertEquals(t, d.RetryIn, time.Duration(0))
+			test.AssertEquals(t, d.ResetIn, time.Millisecond*50)
+
+			// Refund quota to both buckets. This should succeed, but the
+			// decision should reflect that of the default bucket.
+			d, err = l.BatchRefund(testCtx, []Transaction{NewTransaction(overriddenBucketId, 1), NewTransaction(normalBucket, 1)})
+			test.AssertNotError(t, err, "should not error")
+			test.Assert(t, d.Allowed, "should be allowed")
+			test.AssertEquals(t, d.Remaining, int64(20))
+			test.AssertEquals(t, d.RetryIn, time.Duration(0))
+			test.AssertEquals(t, d.ResetIn, time.Duration(0))
+
+			// Once more.
+			d, err = l.BatchSpend(testCtx, []Transaction{NewTransaction(overriddenBucketId, 1), NewTransaction(normalBucket, 1)})
+			test.AssertNotError(t, err, "should not error")
+			test.Assert(t, d.Allowed, "should be allowed")
+			test.AssertEquals(t, d.Remaining, int64(19))
+			test.AssertEquals(t, d.RetryIn, time.Duration(0))
+			test.AssertEquals(t, d.ResetIn, time.Millisecond*50)
+
 			// Reset between tests.
 			err = l.Reset(testCtx, overriddenBucketId)
 			test.AssertNotError(t, err, "should not error")
diff --git a/ratelimits/names.go b/ratelimits/names.go
index b0d581e7657..f06246bb4f0 100644
--- a/ratelimits/names.go
+++ b/ratelimits/names.go
@@ -41,17 +41,25 @@ const (
 	NewOrdersPerAccount
 
 	// FailedAuthorizationsPerAccount uses bucket key 'enum:regId', where regId
-	// is the ACME registration Id of the account.
+	// is the ACME registration Id of the account. Cost MUST be consumed from
+	// this bucket only when the authorization is considered "failed". It SHOULD
+	// be checked before new authorizations are created.
 	FailedAuthorizationsPerAccount
 
 	// CertificatesPerDomain uses bucket key 'enum:domain', where domain is a
-	// domain name in the issued certificate.
+	// domain name in the issued certificate. When overrides to the
+	// CertificatesPerDomainPerAccount are configured for a subscriber, the
+	// cost:
+	//   - MUST be consumed from the CertificatesPerDomainPerAccount bucket and
+	//   - SHOULD be consumed from each CertificatesPerDomain bucket, if possible.
 	CertificatesPerDomain
 
 	// CertificatesPerDomainPerAccount uses bucket key 'enum:regId', where regId
 	// is the ACME registration Id of the account. This limit is never checked
-	// or enforced by the Limiter. Instead, it is used to override the
-	// CertificatesPerDomain limit for the specified account.
+	// or enforced by the Limiter. It is only used to provide an override for
+	// the CertificatesPerDomain limit. When overrides are configured the cost:
+	//   - MUST be consumed from the CertificatesPerDomainPerAccount bucket and
+	//   - SHOULD be consumed from each CertificatesPerDomain bucket, if possible.
 	CertificatesPerDomainPerAccount
 
 	// CertificatesPerFQDNSet uses bucket key 'enum:fqdnSet', where fqdnSet is a
diff --git a/ratelimits/source.go b/ratelimits/source.go
index 84935daefd3..0f8151993fa 100644
--- a/ratelimits/source.go
+++ b/ratelimits/source.go
@@ -19,6 +19,14 @@ type source interface {
 	// the underlying storage client implementation).
 	Set(ctx context.Context, bucketKey string, tat time.Time) error
 
+	// BatchSet stores the TATs at the specified bucketKeys (formatted as
+	// 'name:id'). Implementations MUST ensure non-blocking operations by
+	// either:
+	//   a) applying a deadline or timeout to the context WITHIN the method, or
+	//   b) guaranteeing the operation will not block indefinitely (e.g. via
+	//      the underlying storage client implementation).
+	BatchSet(ctx context.Context, bucketKeys map[string]time.Time) error
+
 	// Get retrieves the TAT associated with the specified bucketKey (formatted
 	// as 'name:id'). Implementations MUST ensure non-blocking operations by
 	// either:
@@ -27,6 +35,14 @@ type source interface {
 	// the underlying storage client implementation).
 	Get(ctx context.Context, bucketKey string) (time.Time, error)
 
+	// BatchGet retrieves the TATs associated with the specified bucketKeys
+	// (formatted as 'name:id'). Implementations MUST ensure non-blocking
+	// operations by either:
+	//   a) applying a deadline or timeout to the context WITHIN the method, or
+	//   b) guaranteeing the operation will not block indefinitely (e.g. via
+	//      the underlying storage client implementation).
+	BatchGet(ctx context.Context, bucketKeys []string) (map[string]time.Time, error)
+
 	// Delete removes the TAT associated with the specified bucketKey (formatted
 	// as 'name:id'). Implementations MUST ensure non-blocking operations by
 	// either:
@@ -54,6 +70,15 @@ func (in *inmem) Set(_ context.Context, bucketKey string, tat time.Time) error {
 	return nil
 }
 
+func (in *inmem) BatchSet(_ context.Context, bucketKeys map[string]time.Time) error {
+	in.Lock()
+	defer in.Unlock()
+	for k, v := range bucketKeys {
+		in.m[k] = v
+	}
+	return nil
+}
+
 func (in *inmem) Get(_ context.Context, bucketKey string) (time.Time, error) {
 	in.RLock()
 	defer in.RUnlock()
@@ -64,6 +89,20 @@ func (in *inmem) Get(_ context.Context, bucketKey string) (time.Time, error) {
 	return tat, nil
 }
 
+func (in *inmem) BatchGet(_ context.Context, bucketKeys []string) (map[string]time.Time, error) {
+	in.RLock()
+	defer in.RUnlock()
+	tats := make(map[string]time.Time, len(bucketKeys))
+	for _, k := range bucketKeys {
+		tat, ok := in.m[k]
+		if !ok {
+			tat = time.Time{}
+		}
+		tats[k] = tat
+	}
+	return tats, nil
+}
+
 func (in *inmem) Delete(_ context.Context, bucketKey string) error {
 	in.Lock()
 	defer in.Unlock()
diff --git a/ratelimits/source_redis.go b/ratelimits/source_redis.go
index 5664058fdf0..79b1d9867a4 100644
--- a/ratelimits/source_redis.go
+++ b/ratelimits/source_redis.go
@@ -83,6 +83,26 @@ func (r *RedisSource) Set(ctx context.Context, bucketKey string, tat time.Time)
 	return nil
 }
 
+// BatchSet stores TATs at the specified bucketKeys using a pipelined Redis
+// transaction in order to reduce the number of round-trips to each Redis shard.
+// An error is returned if the operation failed and nil otherwise.
+func (r *RedisSource) BatchSet(ctx context.Context, buckets map[string]time.Time) error {
+	start := r.clk.Now()
+
+	pipeline := r.client.Pipeline()
+	for bucketKey, tat := range buckets {
+		pipeline.Set(ctx, bucketKey, tat.UTC().UnixNano(), 0)
+	}
+	_, err := pipeline.Exec(ctx)
+	if err != nil {
+		r.latency.With(prometheus.Labels{"call": "batchset", "result": resultForError(err)}).Observe(time.Since(start).Seconds())
+		return err
+	}
+
+	r.latency.With(prometheus.Labels{"call": "batchset", "result": "success"}).Observe(time.Since(start).Seconds())
+	return nil
+}
+
 // Get retrieves the TAT at the specified bucketKey.
 // An error is returned if the operation failed and nil otherwise. If the
 // bucketKey does not exist, ErrBucketNotFound is returned.
@@ -104,6 +124,43 @@ func (r *RedisSource) Get(ctx context.Context, bucketKey string) (time.Time, err
 	return time.Unix(0, tatNano).UTC(), nil
 }
 
+// BatchGet retrieves the TATs at the specified bucketKeys using a pipelined
+// Redis transaction in order to reduce the number of round-trips to each Redis
+// shard. An error is returned if the operation failed and nil otherwise. If a
+// bucketKey does not exist, it WILL NOT be included in the returned map.
+func (r *RedisSource) BatchGet(ctx context.Context, bucketKeys []string) (map[string]time.Time, error) {
+	start := r.clk.Now()
+
+	pipeline := r.client.Pipeline()
+	for _, bucketKey := range bucketKeys {
+		pipeline.Get(ctx, bucketKey)
+	}
+	results, err := pipeline.Exec(ctx)
+	if err != nil {
+		r.latency.With(prometheus.Labels{"call": "batchget", "result": resultForError(err)}).Observe(time.Since(start).Seconds())
+		if !errors.Is(err, redis.Nil) {
+			return nil, err
+		}
+	}
+
+	tats := make(map[string]time.Time, len(bucketKeys))
+	for i, result := range results {
+		tatNano, err := result.(*redis.StringCmd).Int64()
+		if err != nil {
+			if errors.Is(err, redis.Nil) {
+				// Bucket key does not exist.
+				continue
+			}
+			r.latency.With(prometheus.Labels{"call": "batchget", "result": resultForError(err)}).Observe(time.Since(start).Seconds())
+			return nil, err
+		}
+		tats[bucketKeys[i]] = time.Unix(0, tatNano).UTC()
+	}
+
+	r.latency.With(prometheus.Labels{"call": "batchget", "result": "success"}).Observe(time.Since(start).Seconds())
+	return tats, nil
+}
+
 // Delete deletes the TAT at the specified bucketKey ('name:id'). It returns an
 // error if the operation failed and nil otherwise. A nil return value does not
 // indicate that the bucketKey existed.
diff --git a/ratelimits/source_redis_test.go b/ratelimits/source_redis_test.go
index 13790b63bb9..6347d1d71d9 100644
--- a/ratelimits/source_redis_test.go
+++ b/ratelimits/source_redis_test.go
@@ -2,6 +2,7 @@ package ratelimits
 
 import (
 	"testing"
+	"time"
 
 	"github.com/letsencrypt/boulder/cmd"
 	"github.com/letsencrypt/boulder/metrics"
@@ -68,3 +69,37 @@ func Test_RedisSource_Ping(t *testing.T) {
 	err = missingSecondShardSource.Ping(context.Background())
 	test.AssertError(t, err, "Ping should not error")
 }
+
+func Test_RedisSource_BatchSetAndGet(t *testing.T) {
+	clk := clock.NewFake()
+	s := newTestRedisSource(clk, map[string]string{
+		"shard1": "10.33.33.4:4218",
+		"shard2": "10.33.33.5:4218",
+	})
+
+	now := clk.Now()
+	val1 := now.Add(time.Second)
+	val2 := now.Add(time.Second * 2)
+	val3 := now.Add(time.Second * 3)
+
+	set := map[string]time.Time{
+		"test1": val1,
+		"test2": val2,
+		"test3": val3,
+	}
+
+	err := s.BatchSet(context.Background(), set)
+	test.AssertNotError(t, err, "BatchSet() should not error")
+
+	got, err := s.BatchGet(context.Background(), []string{"test1", "test2", "test3"})
+	test.AssertNotError(t, err, "BatchGet() should not error")
+
+	for k, v := range set {
+		test.Assert(t, got[k].Equal(v), "BatchGet() should return the values set by BatchSet()")
+	}
+
+	// Test that a key which does not exist is omitted from the returned map.
+	got, err = s.BatchGet(context.Background(), []string{"test1", "test4", "test3"})
+	test.AssertNotError(t, err, "BatchGet() should not error when a key isn't found")
+	test.Assert(t, got["test4"].IsZero(), "BatchGet() should return a zero time for a key that does not exist")
+}
diff --git a/ratelimits/utilities.go b/ratelimits/utilities.go
index 8a7cbca7087..dd5a1167eca 100644
--- a/ratelimits/utilities.go
+++ b/ratelimits/utilities.go
@@ -2,9 +2,32 @@ package ratelimits
 
 import (
 	"strings"
+
+	"github.com/letsencrypt/boulder/core"
+	"github.com/weppos/publicsuffix-go/publicsuffix"
 )
 
 // joinWithColon joins the provided args with a colon.
 func joinWithColon(args ...string) string {
 	return strings.Join(args, ":")
 }
+
+// DomainsForRateLimiting transforms a list of FQDNs into a list of eTLD+1's
+// for the purpose of rate limiting. It also de-duplicates the output
+// domains. Exact public suffix matches are included.
+func DomainsForRateLimiting(names []string) []string {
+	var domains []string
+	for _, name := range names {
+		domain, err := publicsuffix.Domain(name)
+		if err != nil {
+			// The only possible errors are:
+			// (1) publicsuffix.Domain is giving garbage values
+			// (2) the public suffix is the domain itself
+			// We assume 2 and include the original name in the result.
+			domains = append(domains, name)
+		} else {
+			domains = append(domains, domain)
+		}
+	}
+	return core.UniqueLowerNames(domains)
+}
diff --git a/ratelimits/utilities_test.go b/ratelimits/utilities_test.go
new file mode 100644
index 00000000000..9c68d3a6e89
--- /dev/null
+++ b/ratelimits/utilities_test.go
@@ -0,0 +1,27 @@
+package ratelimits
+
+import (
+	"testing"
+
+	"github.com/letsencrypt/boulder/test"
+)
+
+func TestDomainsForRateLimiting(t *testing.T) {
+	domains := DomainsForRateLimiting([]string{})
+	test.AssertEquals(t, len(domains), 0)
+
+	domains = DomainsForRateLimiting([]string{"www.example.com", "example.com"})
+	test.AssertDeepEquals(t, domains, []string{"example.com"})
+
+	domains = DomainsForRateLimiting([]string{"www.example.com", "example.com", "www.example.co.uk"})
+	test.AssertDeepEquals(t, domains, []string{"example.co.uk", "example.com"})
+
+	domains = DomainsForRateLimiting([]string{"www.example.com", "example.com", "www.example.co.uk", "co.uk"})
+	test.AssertDeepEquals(t, domains, []string{"co.uk", "example.co.uk", "example.com"})
+
+	domains = DomainsForRateLimiting([]string{"foo.bar.baz.www.example.com", "baz.example.com"})
+	test.AssertDeepEquals(t, domains, []string{"example.com"})
+
+	domains = DomainsForRateLimiting([]string{"github.io", "foo.github.io", "bar.github.io"})
+	test.AssertDeepEquals(t, domains, []string{"bar.github.io", "foo.github.io", "github.io"})
+}
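Taken together, these pieces support a spend-then-refund flow around order issuance. A minimal sketch (hypothetical wiring: l, regId, domains, and finalizeOrder are assumptions, with finalizeOrder standing in for the caller's own issuance step):

	txns, err := NewCertificatesPerDomainTransactions(l, regId, domains, 1)
	if err != nil {
		return err
	}
	decision, err := l.BatchSpend(ctx, txns)
	if err != nil {
		return err
	}
	if !decision.Allowed {
		return fmt.Errorf("rate limited, retry in %s", decision.RetryIn)
	}
	if err := finalizeOrder(ctx); err != nil {
		// Return the spent quota; BatchRefund disregards buckets that
		// no longer exist.
		_, _ = l.BatchRefund(ctx, txns)
		return err
	}
	return nil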