diff --git a/ra/ra.go b/ra/ra.go index ab19f842de55..6dd293b61035 100644 --- a/ra/ra.go +++ b/ra/ra.go @@ -20,7 +20,6 @@ import ( "github.com/jmhodges/clock" "github.com/prometheus/client_golang/prometheus" - "github.com/weppos/publicsuffix-go/publicsuffix" "golang.org/x/crypto/ocsp" "google.golang.org/grpc" "google.golang.org/protobuf/proto" @@ -1377,26 +1376,6 @@ func (ra *RegistrationAuthorityImpl) getSCTs(ctx context.Context, cert []byte, e return scts, nil } -// domainsForRateLimiting transforms a list of FQDNs into a list of eTLD+1's -// for the purpose of rate limiting. It also de-duplicates the output -// domains. Exact public suffix matches are included. -func domainsForRateLimiting(names []string) []string { - var domains []string - for _, name := range names { - domain, err := publicsuffix.Domain(name) - if err != nil { - // The only possible errors are: - // (1) publicsuffix.Domain is giving garbage values - // (2) the public suffix is the domain itself - // We assume 2 and include the original name in the result. - domains = append(domains, name) - } else { - domains = append(domains, domain) - } - } - return core.UniqueLowerNames(domains) -} - // enforceNameCounts uses the provided count RPC to find a count of certificates // for each of the names. If the count for any of the names exceeds the limit // for the given registration then the names out of policy are returned to be @@ -1451,7 +1430,7 @@ func (ra *RegistrationAuthorityImpl) checkCertificatesPerNameLimit(ctx context.C return nil } - tldNames := domainsForRateLimiting(names) + tldNames := ratelimits.DomainsForRateLimiting(names) namesOutOfLimit, earliest, err := ra.enforceNameCounts(ctx, tldNames, limit, regID) if err != nil { return fmt.Errorf("checking certificates per name limit for %q: %s", diff --git a/ra/ra_test.go b/ra/ra_test.go index d95dd3322507..02e17d6232e5 100644 --- a/ra/ra_test.go +++ b/ra/ra_test.go @@ -1113,26 +1113,6 @@ func TestAuthzFailedRateLimitingNewOrder(t *testing.T) { testcase() } -func TestDomainsForRateLimiting(t *testing.T) { - domains := domainsForRateLimiting([]string{}) - test.AssertEquals(t, len(domains), 0) - - domains = domainsForRateLimiting([]string{"www.example.com", "example.com"}) - test.AssertDeepEquals(t, domains, []string{"example.com"}) - - domains = domainsForRateLimiting([]string{"www.example.com", "example.com", "www.example.co.uk"}) - test.AssertDeepEquals(t, domains, []string{"example.co.uk", "example.com"}) - - domains = domainsForRateLimiting([]string{"www.example.com", "example.com", "www.example.co.uk", "co.uk"}) - test.AssertDeepEquals(t, domains, []string{"co.uk", "example.co.uk", "example.com"}) - - domains = domainsForRateLimiting([]string{"foo.bar.baz.www.example.com", "baz.example.com"}) - test.AssertDeepEquals(t, domains, []string{"example.com"}) - - domains = domainsForRateLimiting([]string{"github.io", "foo.github.io", "bar.github.io"}) - test.AssertDeepEquals(t, domains, []string{"bar.github.io", "foo.github.io", "github.io"}) -} - type mockSAWithNameCounts struct { mocks.StorageAuthority nameCounts *sapb.CountByNames diff --git a/ratelimits/bucket.go b/ratelimits/bucket.go index 501d1fd2c447..8aa369a06f7f 100644 --- a/ratelimits/bucket.go +++ b/ratelimits/bucket.go @@ -3,15 +3,18 @@ package ratelimits import ( "fmt" "net" + "strconv" + + "github.com/letsencrypt/boulder/core" ) // BucketId should only be created using the New*BucketId functions. It is used // by the Limiter to look up the bucket and limit overrides for a specific // subscriber and limit. 
type BucketId struct { - // limit is the name of the associated rate limit. It is used for looking up - // default limits. - limit Name + // limitName is the name of the associated rate limit. It is used for + // looking up default limits. + limitName Name // bucketKey is the limit Name enum (e.g. "1") concatenated with the // subscriber identifier specific to the associated limit Name type. @@ -27,7 +30,7 @@ func NewRegistrationsPerIPAddressBucketId(ip net.IP) (BucketId, error) { return BucketId{}, err } return BucketId{ - limit: NewRegistrationsPerIPAddress, + limitName: NewRegistrationsPerIPAddress, bucketKey: joinWithColon(NewRegistrationsPerIPAddress.EnumString(), id), }, nil } @@ -46,7 +49,109 @@ func NewRegistrationsPerIPv6RangeBucketId(ip net.IP) (BucketId, error) { return BucketId{}, err } return BucketId{ - limit: NewRegistrationsPerIPv6Range, + limitName: NewRegistrationsPerIPv6Range, bucketKey: joinWithColon(NewRegistrationsPerIPv6Range.EnumString(), id), }, nil } + +// NewOrdersPerAccountBucketId returns a BucketId for the provided ACME +// registration Id. +func NewOrdersPerAccountBucketId(regId int64) (BucketId, error) { + id := strconv.FormatInt(regId, 10) + err := validateIdForName(NewOrdersPerAccount, id) + if err != nil { + return BucketId{}, err + } + return BucketId{ + limitName: NewOrdersPerAccount, + bucketKey: joinWithColon(NewOrdersPerAccount.EnumString(), id), + }, nil +} + +// NewFailedAuthorizationsPerAccountBucketId returns a BucketId for the provided +// ACME registration Id. +func NewFailedAuthorizationsPerAccountBucketId(regId int64) (BucketId, error) { + id := strconv.FormatInt(regId, 10) + err := validateIdForName(FailedAuthorizationsPerAccount, id) + if err != nil { + return BucketId{}, err + } + return BucketId{ + limitName: FailedAuthorizationsPerAccount, + bucketKey: joinWithColon(FailedAuthorizationsPerAccount.EnumString(), id), + }, nil +} + +// NewCertificatesPerDomainBucketId returns a BucketId for the provided order +// domain name. +func NewCertificatesPerDomainBucketId(orderName string, suppressDenials bool) (BucketId, error) { + err := validateIdForName(CertificatesPerDomain, orderName) + if err != nil { + return BucketId{}, err + } + return BucketId{ + limitName: CertificatesPerDomain, + bucketKey: joinWithColon(CertificatesPerDomain.EnumString(), orderName), + }, nil +} + +// newCertificatesPerDomainPerAccountBucketId is only referenced internally. +// Buckets for CertificatesPerDomainPerAccount are created by calling +// NewCertificatesPerDomainTransactions(). +func newCertificatesPerDomainPerAccountBucketId(regId int64) (BucketId, error) { + id := strconv.FormatInt(regId, 10) + err := validateIdForName(CertificatesPerDomainPerAccount, id) + if err != nil { + return BucketId{}, err + } + return BucketId{ + limitName: CertificatesPerDomainPerAccount, + bucketKey: joinWithColon(CertificatesPerDomainPerAccount.EnumString(), id), + }, nil +} + +// NewCertificatesPerDomainTransactions returns a slice of Transactions for the +// provided order domain names. The cost specified will be applied per eTLD+1 +// name present in the orderDomains. The CertificatesPerDomain limit is special +// in that it can be overridden by the CertificatesPerDomainPerAccount limit. +// This occurs when the CertificatesPerDomainPerAccount limit allows for higher +// throughput than the CertificatesPerDomain limit.
In these cases, the cost +// will be consumed from the CertificatesPerDomainPerAccount bucket and ALSO +// from the CertificatesPerDomain bucket, if possible. +func NewCertificatesPerDomainTransactions(limiter *Limiter, regId int64, orderDomains []string, cost int64) ([]Transaction, error) { + id, err := newCertificatesPerDomainPerAccountBucketId(regId) + if err != nil { + return nil, err + } + regIdLimit, err := limiter.getLimit(CertificatesPerDomainPerAccount, id.bucketKey) + if err != nil { + return nil, err + } + + var txns []Transaction + var regIdBucketCost int64 + for _, name := range DomainsForRateLimiting(orderDomains) { + bucketId, err := NewCertificatesPerDomainBucketId(name, regIdLimit.isOverride) + if err != nil { + return nil, err + } + regIdBucketCost++ + txns = append(txns, NewTransaction(bucketId, cost)) + } + txns = append(txns, NewTransaction(id, regIdBucketCost*cost)) + return txns, nil +} + +// NewCertificatesPerFQDNSetBucket returns a BucketId for the provided order +// domain names. +func NewCertificatesPerFQDNSetBucket(orderNames []string) (BucketId, error) { + id := string(core.HashNames(orderNames)) + err := validateIdForName(CertificatesPerFQDNSet, id) + if err != nil { + return BucketId{}, err + } + return BucketId{ + limitName: CertificatesPerFQDNSet, + bucketKey: joinWithColon(CertificatesPerFQDNSet.EnumString(), id), + }, nil +} diff --git a/ratelimits/limiter.go b/ratelimits/limiter.go index 46c727b1d399..a14e3e0472b3 100644 --- a/ratelimits/limiter.go +++ b/ratelimits/limiter.go @@ -4,6 +4,8 @@ import ( "context" "errors" "fmt" + "math" + "slices" "time" "github.com/jmhodges/clock" @@ -99,6 +101,10 @@ func NewLimiter(clk clock.Clock, source source, defaults, overrides string, stat type Transaction struct { BucketId cost int64 + + // optimistic is true if the limiter should spend from the bucket if + // capacity exists, but should not return a denied Decision if it does not. + optimistic bool } // NewTransaction creates a new Transaction for the provided BucketId and cost. @@ -142,7 +148,7 @@ func (l *Limiter) Check(ctx context.Context, txn Transaction) (*Decision, error) return nil, ErrInvalidCostForCheck } - limit, err := l.getLimit(txn.limit, txn.bucketKey) + limit, err := l.getLimit(txn.limitName, txn.bucketKey) if err != nil { if errors.Is(err, errLimitDisabled) { return disabledLimitDecision, nil @@ -181,7 +187,7 @@ func (l *Limiter) Spend(ctx context.Context, txn Transaction) (*Decision, error) return nil, ErrInvalidCost } - limit, err := l.getLimit(txn.limit, txn.bucketKey) + limit, err := l.getLimit(txn.limitName, txn.bucketKey) if err != nil { if errors.Is(err, errLimitDisabled) { return disabledLimitDecision, nil @@ -196,7 +202,7 @@ func (l *Limiter) Spend(ctx context.Context, txn Transaction) (*Decision, error) start := l.clk.Now() status := Denied defer func() { - l.spendLatency.WithLabelValues(txn.limit.String(), status).Observe(l.clk.Since(start).Seconds()) + l.spendLatency.WithLabelValues(txn.limitName.String(), status).Observe(l.clk.Since(start).Seconds()) }() // Remove cancellation from the request context so that transactions are not @@ -223,7 +229,7 @@ func (l *Limiter) Spend(ctx context.Context, txn Transaction) (*Decision, error) if limit.isOverride { // Calculate the current utilization of the override limit. 
utilization := float64(limit.Burst-d.Remaining) / float64(limit.Burst) - l.overrideUsageGauge.WithLabelValues(txn.limit.String(), txn.bucketKey).Set(utilization) + l.overrideUsageGauge.WithLabelValues(txn.limitName.String(), txn.bucketKey).Set(utilization) } if !d.Allowed { @@ -238,6 +244,140 @@ func (l *Limiter) Spend(ctx context.Context, txn Transaction) (*Decision, error) return d, nil } +type batchTransaction struct { + Transaction + limit limit +} + +func (l *Limiter) prepareBatch(txns []Transaction) ([]batchTransaction, []string, error) { + var batchTxns []batchTransaction + var bucketKeys []string + var optimisticTxnCount int + for _, txn := range txns { + if txn.cost <= 0 { + return nil, nil, ErrInvalidCost + } + limit, err := l.getLimit(txn.limitName, txn.bucketKey) + if err != nil { + if errors.Is(err, errLimitDisabled) { + continue + } + return nil, nil, err + } + if txn.cost > limit.Burst { + return nil, nil, ErrInvalidCostOverLimit + } + if slices.Contains(bucketKeys, txn.bucketKey) { + return nil, nil, fmt.Errorf("found duplicate bucket %q in batch", txn.bucketKey) + } + if txn.optimistic { + optimisticTxnCount++ + } + bucketKeys = append(bucketKeys, txn.bucketKey) + batchTxns = append(batchTxns, batchTransaction{txn, limit}) + } + if len(batchTxns) == optimisticTxnCount { + // All transactions are optimistic, no need to check. + return nil, nil, nil + } + return batchTxns, bucketKeys, nil +} + +type batchDecision struct { + *Decision +} + +func newBatchDecision() *batchDecision { + return &batchDecision{ + Decision: &Decision{ + Allowed: true, + Remaining: math.MaxInt64, + }, + } +} + +func (d *batchDecision) consolidate(in *Decision) { + d.Allowed = d.Allowed && in.Allowed + d.Remaining = min(d.Remaining, in.Remaining) + d.RetryIn = max(d.RetryIn, in.RetryIn) + d.ResetIn = max(d.ResetIn, in.ResetIn) + if in.newTAT.After(d.newTAT) { + d.newTAT = in.newTAT + } +} + +// BatchSpend attempts to deduct the costs from the provided buckets' +// capacities. The following rules are applied to consolidate the Decisions for +// each transaction: +// - Allowed is true if all of the Decisions were allowed, +// - Remaining is the smallest value across all Decisions, and +// - RetryIn and ResetIn are the largest values across all Decisions. +// +// Non-existent buckets will be created WITH the cost factored into the initial +// state. New bucket states are persisted to the underlying datastore, if +// applicable, before returning. +func (l *Limiter) BatchSpend(ctx context.Context, txns []Transaction) (*Decision, error) { + batch, bucketKeys, err := l.prepareBatch(txns) + if err != nil { + return nil, err + } + if len(batch) <= 0 { + // All buckets were disabled. + return disabledLimitDecision, nil + } + + // Remove cancellation from the request context so that transactions are not + // interrupted by a client disconnect. + ctx = context.WithoutCancel(ctx) + tats, err := l.source.BatchGet(ctx, bucketKeys) + if err != nil { + return nil, err + } + + start := l.clk.Now() + batchDecision := newBatchDecision() + newTATs := make(map[string]time.Time) + + for _, txn := range batch { + tat, exists := tats[txn.bucketKey] + if !exists { + // First request from this client. Initialize the bucket with a TAT of + // "now", which is equivalent to a full bucket. + tat = l.clk.Now() + } + + // Spend the cost and update the consolidated decision. 
+ d := maybeSpend(l.clk, txn.limit, tat, txn.cost) + if d.Allowed { + newTATs[txn.bucketKey] = d.newTAT + } + + if txn.limit.isOverride { + // Calculate the current utilization of the override limit. + utilization := float64(txn.limit.Burst-d.Remaining) / float64(txn.limit.Burst) + l.overrideUsageGauge.WithLabelValues(txn.limitName.String(), txn.bucketKey).Set(utilization) + } + + if !d.Allowed && txn.optimistic { + // Suppress denial for optimistic transaction. + d = disabledLimitDecision + } + batchDecision.consolidate(d) + } + + // Conditionally, spend the batch. + if batchDecision.Allowed { + err = l.source.BatchSet(ctx, newTATs) + if err != nil { + return nil, err + } + l.spendLatency.WithLabelValues("batch", Allowed).Observe(l.clk.Since(start).Seconds()) + } else { + l.spendLatency.WithLabelValues("batch", Denied).Observe(l.clk.Since(start).Seconds()) + } + return batchDecision.Decision, nil +} + // Refund attempts to refund all of the cost to the capacity of the specified // bucket. The returned *Decision indicates whether the refund was successful // and represents the current state of the bucket. The new bucket state is @@ -254,7 +394,7 @@ func (l *Limiter) Refund(ctx context.Context, txn Transaction) (*Decision, error return nil, ErrInvalidCost } - limit, err := l.getLimit(txn.limit, txn.bucketKey) + limit, err := l.getLimit(txn.limitName, txn.bucketKey) if err != nil { if errors.Is(err, errLimitDisabled) { return disabledLimitDecision, nil @@ -277,6 +417,62 @@ func (l *Limiter) Refund(ctx context.Context, txn Transaction) (*Decision, error return d, l.source.Set(ctx, txn.bucketKey, d.newTAT) } +// BatchRefund attempts to refund all or some of the costs to the provided +// buckets' capacities. The following rules are applied to consolidate the +// Decisions for each transaction: +// - Allowed is true if all of the Decisions were allowed, +// - Remaining is the smallest value across all Decisions, and +// - RetryIn and ResetIn are the largest values across all Decisions. +// +// Non-existent buckets within the batch are disregarded without error, as this +// is equivalent to the bucket being full. The new bucket state is persisted to +// the underlying datastore, if applicable, before returning. +func (l *Limiter) BatchRefund(ctx context.Context, txns []Transaction) (*Decision, error) { + batch, bucketKeys, err := l.prepareBatch(txns) + if err != nil { + return nil, err + } + if len(batch) <= 0 { + // All buckets were disabled. + return disabledLimitDecision, nil + } + + // Remove cancellation from the request context so that transactions are not + // interrupted by a client disconnect. + ctx = context.WithoutCancel(ctx) + tats, err := l.source.BatchGet(ctx, bucketKeys) + if err != nil { + return nil, err + } + + batchDecision := newBatchDecision() + newTATs := make(map[string]time.Time) + + for _, txn := range batch { + tat, exists := tats[txn.bucketKey] + if !exists { + // Ignore non-existent buckets. + continue + } + + // Refund the cost and update the consolidated decision. + d := maybeRefund(l.clk, txn.limit, tat, txn.cost) + if d.Allowed { + newTATs[txn.bucketKey] = d.newTAT + } + batchDecision.consolidate(d) + } + + // Conditionally, refund the batch. + if len(newTATs) > 0 { + err = l.source.BatchSet(ctx, newTATs) + if err != nil { + return nil, err + } + } + return batchDecision.Decision, nil +} + // Reset resets the specified bucket to its maximum capacity. The new bucket // state is persisted to the underlying datastore before returning. 
func (l *Limiter) Reset(ctx context.Context, bucketId BucketId) error { diff --git a/ratelimits/limiter_test.go b/ratelimits/limiter_test.go index 7f1088102793..a53155af0074 100644 --- a/ratelimits/limiter_test.go +++ b/ratelimits/limiter_test.go @@ -58,7 +58,7 @@ func Test_Limiter_CheckWithLimitNoExist(t *testing.T) { testCtx, limiters, _, testIP := setup(t) for name, l := range limiters { t.Run(name, func(t *testing.T) { - bucketId := BucketId{limit: Name(9999), bucketKey: testIP} + bucketId := BucketId{limitName: Name(9999), bucketKey: testIP} _, err := l.Check(testCtx, NewTransaction(bucketId, 1)) test.AssertError(t, err, "should error") }) diff --git a/ratelimits/source.go b/ratelimits/source.go index 84935daefd39..0f8151993fa4 100644 --- a/ratelimits/source.go +++ b/ratelimits/source.go @@ -19,6 +19,14 @@ type source interface { // the underlying storage client implementation). Set(ctx context.Context, bucketKey string, tat time.Time) error + // BatchSet stores the TATs at the specified bucketKeys (formatted as + // 'name:id'). Implementations MUST ensure non-blocking operations by + // either: + // a) applying a deadline or timeout to the context WITHIN the method, or + // b) guaranteeing the operation will not block indefinitely (e.g. via + // the underlying storage client implementation). + BatchSet(ctx context.Context, bucketKeys map[string]time.Time) error + // Get retrieves the TAT associated with the specified bucketKey (formatted // as 'name:id'). Implementations MUST ensure non-blocking operations by // either: @@ -27,6 +35,14 @@ type source interface { // the underlying storage client implementation). Get(ctx context.Context, bucketKey string) (time.Time, error) + // BatchGet retrieves the TATs associated with the specified bucketKeys + // (formatted as 'name:id'). Implementations MUST ensure non-blocking + // operations by either: + // a) applying a deadline or timeout to the context WITHIN the method, or + // b) guaranteeing the operation will not block indefinitely (e.g. via + // the underlying storage client implementation). + BatchGet(ctx context.Context, bucketKeys []string) (map[string]time.Time, error) + // Delete removes the TAT associated with the specified bucketKey (formatted // as 'name:id'). 
Implementations MUST ensure non-blocking operations by // either: @@ -54,6 +70,15 @@ func (in *inmem) Set(_ context.Context, bucketKey string, tat time.Time) error { return nil } + +func (in *inmem) BatchSet(_ context.Context, bucketKeys map[string]time.Time) error { + in.Lock() + defer in.Unlock() + for k, v := range bucketKeys { + in.m[k] = v + } + return nil +} + func (in *inmem) Get(_ context.Context, bucketKey string) (time.Time, error) { in.RLock() defer in.RUnlock() @@ -64,6 +89,21 @@ func (in *inmem) Get(_ context.Context, bucketKey string) (time.Time, error) { return tat, nil } + +func (in *inmem) BatchGet(_ context.Context, bucketKeys []string) (map[string]time.Time, error) { + in.RLock() + defer in.RUnlock() + tats := make(map[string]time.Time, len(bucketKeys)) + for _, k := range bucketKeys { + tat, ok := in.m[k] + if !ok { + tats[k] = time.Time{} + continue + } + tats[k] = tat + } + return tats, nil +} + func (in *inmem) Delete(_ context.Context, bucketKey string) error { in.Lock() defer in.Unlock() diff --git a/ratelimits/source_redis.go b/ratelimits/source_redis.go index 5664058fdf0f..79b1d9867a4f 100644 --- a/ratelimits/source_redis.go +++ b/ratelimits/source_redis.go @@ -83,6 +83,26 @@ func (r *RedisSource) Set(ctx context.Context, bucketKey string, tat time.Time) return nil } + +// BatchSet stores TATs at the specified bucketKeys using a pipelined Redis +// transaction in order to reduce the number of round-trips to each Redis shard. +// An error is returned if the operation failed and nil otherwise. +func (r *RedisSource) BatchSet(ctx context.Context, buckets map[string]time.Time) error { + start := r.clk.Now() + + pipeline := r.client.Pipeline() + for bucketKey, tat := range buckets { + pipeline.Set(ctx, bucketKey, tat.UTC().UnixNano(), 0) + } + _, err := pipeline.Exec(ctx) + if err != nil { + r.latency.With(prometheus.Labels{"call": "batchset", "result": resultForError(err)}).Observe(time.Since(start).Seconds()) + return err + } + + r.latency.With(prometheus.Labels{"call": "batchset", "result": "success"}).Observe(time.Since(start).Seconds()) + return nil +} + // Get retrieves the TAT at the specified bucketKey. An error is returned if the // operation failed and nil otherwise. If the bucketKey does not exist, // ErrBucketNotFound is returned. @@ -104,6 +124,43 @@ func (r *RedisSource) Get(ctx context.Context, bucketKey string) (time.Time, err return time.Unix(0, tatNano).UTC(), nil } + +// BatchGet retrieves the TATs at the specified bucketKeys using a pipelined +// Redis transaction in order to reduce the number of round-trips to each Redis +// shard. An error is returned if the operation failed and nil otherwise. If a +// bucketKey does not exist, it WILL NOT be included in the returned map. +func (r *RedisSource) BatchGet(ctx context.Context, bucketKeys []string) (map[string]time.Time, error) { + start := r.clk.Now() + + pipeline := r.client.Pipeline() + for _, bucketKey := range bucketKeys { + pipeline.Get(ctx, bucketKey) + } + results, err := pipeline.Exec(ctx) + if err != nil { + r.latency.With(prometheus.Labels{"call": "batchget", "result": resultForError(err)}).Observe(time.Since(start).Seconds()) + if !errors.Is(err, redis.Nil) { + return nil, err + } + } + + tats := make(map[string]time.Time, len(bucketKeys)) + for i, result := range results { + tatNano, err := result.(*redis.StringCmd).Int64() + if err != nil { + if errors.Is(err, redis.Nil) { + // Bucket key does not exist.
+ continue + } + r.latency.With(prometheus.Labels{"call": "batchget", "result": resultForError(err)}).Observe(time.Since(start).Seconds()) + return nil, err + } + tats[bucketKeys[i]] = time.Unix(0, tatNano).UTC() + } + + r.latency.With(prometheus.Labels{"call": "batchget", "result": "success"}).Observe(time.Since(start).Seconds()) + return tats, nil +} + // Delete deletes the TAT at the specified bucketKey ('name:id'). It returns an // error if the operation failed and nil otherwise. A nil return value does not // indicate that the bucketKey existed. diff --git a/ratelimits/utilities.go b/ratelimits/utilities.go index 8a7cbca70877..dd5a1167eca8 100644 --- a/ratelimits/utilities.go +++ b/ratelimits/utilities.go @@ -2,9 +2,32 @@ package ratelimits import ( "strings" + + "github.com/letsencrypt/boulder/core" + "github.com/weppos/publicsuffix-go/publicsuffix" ) // joinWithColon joins the provided args with a colon. func joinWithColon(args ...string) string { return strings.Join(args, ":") } + +// DomainsForRateLimiting transforms a list of FQDNs into a list of eTLD+1's +// for the purpose of rate limiting. It also de-duplicates the output +// domains. Exact public suffix matches are included. +func DomainsForRateLimiting(names []string) []string { + var domains []string + for _, name := range names { + domain, err := publicsuffix.Domain(name) + if err != nil { + // The only possible errors are: + // (1) publicsuffix.Domain is giving garbage values + // (2) the public suffix is the domain itself + // We assume 2 and include the original name in the result. + domains = append(domains, name) + } else { + domains = append(domains, domain) + } + } + return core.UniqueLowerNames(domains) +} diff --git a/ratelimits/utilities_test.go b/ratelimits/utilities_test.go new file mode 100644 index 000000000000..9c68d3a6e899 --- /dev/null +++ b/ratelimits/utilities_test.go @@ -0,0 +1,27 @@ +package ratelimits + +import ( + "testing" + + "github.com/letsencrypt/boulder/test" +) + +func TestDomainsForRateLimiting(t *testing.T) { + domains := DomainsForRateLimiting([]string{}) + test.AssertEquals(t, len(domains), 0) + + domains = DomainsForRateLimiting([]string{"www.example.com", "example.com"}) + test.AssertDeepEquals(t, domains, []string{"example.com"}) + + domains = DomainsForRateLimiting([]string{"www.example.com", "example.com", "www.example.co.uk"}) + test.AssertDeepEquals(t, domains, []string{"example.co.uk", "example.com"}) + + domains = DomainsForRateLimiting([]string{"www.example.com", "example.com", "www.example.co.uk", "co.uk"}) + test.AssertDeepEquals(t, domains, []string{"co.uk", "example.co.uk", "example.com"}) + + domains = DomainsForRateLimiting([]string{"foo.bar.baz.www.example.com", "baz.example.com"}) + test.AssertDeepEquals(t, domains, []string{"example.com"}) + + domains = DomainsForRateLimiting([]string{"github.io", "foo.github.io", "bar.github.io"}) + test.AssertDeepEquals(t, domains, []string{"bar.github.io", "foo.github.io", "github.io"}) +}
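A minimal usage sketch, assuming a caller (e.g. the WFE or RA) that already has a *ratelimits.Limiter, a registration ID, and the order's DNS names. The checkNewOrderLimits name, the cost of 1, and the error handling are illustrative assumptions and not part of this change; only the exported identifiers introduced above are relied on.

package main // hypothetical harness, not part of this diff

import (
	"context"
	"fmt"

	"github.com/letsencrypt/boulder/ratelimits"
)

// checkNewOrderLimits builds one Transaction per eTLD+1 in the order (plus
// the per-account override bucket) via the new helper, then spends them all
// in a single batched round-trip.
func checkNewOrderLimits(ctx context.Context, limiter *ratelimits.Limiter, regID int64, orderDomains []string) error {
	txns, err := ratelimits.NewCertificatesPerDomainTransactions(limiter, regID, orderDomains, 1)
	if err != nil {
		return err
	}

	// BatchSpend consolidates the per-bucket Decisions: Allowed is true only
	// if every bucket allowed the spend.
	d, err := limiter.BatchSpend(ctx, txns)
	if err != nil {
		return err
	}
	if !d.Allowed {
		return fmt.Errorf("rate limited: retry in %s", d.RetryIn)
	}
	return nil
}

If the order subsequently fails, the same transactions can be passed to BatchRefund to return the spent cost to the buckets.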