Skip to content

Commit

Permalink
Addressing comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
beautifulentropy committed Nov 6, 2023
1 parent 81df1d7 commit 45d1849
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 159 deletions.
52 changes: 24 additions & 28 deletions ratelimits/bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,52 +5,48 @@ import (
"net"
)

// Bucket identifies a specific subscriber rate limit bucket to the Limiter.
type Bucket struct {
name Name
key string
}

// BucketWithCost is a bucket with an associated cost.
type BucketWithCost struct {
Bucket
cost int64
}
// BucketId should only be created using the New*BucketId functions. It is used
// by the Limiter to look up the bucket and limit overrides for a specific
// subscriber and limit.
type BucketId struct {
// limit is the name of the associated rate limit. It is used for looking up
// default limits.
limit Name

// WithCost returns a BucketWithCost for the provided cost.
func (b Bucket) WithCost(cost int64) BucketWithCost {
return BucketWithCost{b, cost}
// bucketKey is the limit Name enum (e.g. "1") concatenated with the
// subscriber identifier specific to the associate limit Name type.
bucketKey string
}

// NewRegistrationsPerIPAddressBucket returns a Bucket for the provided IP
// NewRegistrationsPerIPAddressBucketId returns a BucketId for the provided IP
// address.
func NewRegistrationsPerIPAddressBucket(ip net.IP) (Bucket, error) {
func NewRegistrationsPerIPAddressBucketId(ip net.IP) (BucketId, error) {
id := ip.String()
err := validateIdForName(NewRegistrationsPerIPAddress, id)
if err != nil {
return Bucket{}, err
return BucketId{}, err
}
return Bucket{
name: NewRegistrationsPerIPAddress,
key: joinWithColon(NewRegistrationsPerIPAddress.EnumString(), id),
return BucketId{
limit: NewRegistrationsPerIPAddress,
bucketKey: joinWithColon(NewRegistrationsPerIPAddress.EnumString(), id),
}, nil
}

// NewRegistrationsPerIPv6RangeBucket returns a Bucket for the /48 IPv6 range
// containing the provided IPv6 address.
func NewRegistrationsPerIPv6RangeBucket(ip net.IP) (Bucket, error) {
// NewRegistrationsPerIPv6RangeBucketId returns a BucketId for the /48 IPv6
// range containing the provided IPv6 address.
func NewRegistrationsPerIPv6RangeBucketId(ip net.IP) (BucketId, error) {
if ip.To4() != nil {
return Bucket{}, fmt.Errorf("invalid IPv6 address, %q must be an IPv6 address", ip.String())
return BucketId{}, fmt.Errorf("invalid IPv6 address, %q must be an IPv6 address", ip.String())
}
ipMask := net.CIDRMask(48, 128)
ipNet := &net.IPNet{IP: ip.Mask(ipMask), Mask: ipMask}
id := ipNet.String()
err := validateIdForName(NewRegistrationsPerIPv6Range, id)
if err != nil {
return Bucket{}, err
return BucketId{}, err
}
return Bucket{
name: NewRegistrationsPerIPv6Range,
key: joinWithColon(NewRegistrationsPerIPv6Range.EnumString(), id),
return BucketId{
limit: NewRegistrationsPerIPv6Range,
bucketKey: joinWithColon(NewRegistrationsPerIPv6Range.EnumString(), id),
}, nil
}
84 changes: 49 additions & 35 deletions ratelimits/limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,20 @@ func NewLimiter(clk clock.Clock, source source, defaults, overrides string, stat
return limiter, nil
}

// Transaction is a cost to be spent or refunded from a specific BucketId.
type Transaction struct {
BucketId
cost int64
}

// NewTransaction creates a new Transaction for the provided BucketId and cost.
func NewTransaction(b BucketId, cost int64) Transaction {
return Transaction{
BucketId: b,
cost: cost,
}
}

type Decision struct {
// Allowed is true if the bucket possessed enough capacity to allow the
// request given the cost.
Expand Down Expand Up @@ -123,37 +137,37 @@ type Decision struct {
// satisfy the cost and represents the hypothetical state of the bucket IF the
// cost WERE to be deducted. If no bucket exists it will NOT be created. No
// state is persisted to the underlying datastore.
func (l *Limiter) Check(ctx context.Context, bucket BucketWithCost) (*Decision, error) {
if bucket.cost < 0 {
func (l *Limiter) Check(ctx context.Context, txn Transaction) (*Decision, error) {
if txn.cost < 0 {
return nil, ErrInvalidCostForCheck
}

limit, err := l.getLimit(bucket.name, bucket.key)
limit, err := l.getLimit(txn.limit, txn.bucketKey)
if err != nil {
if errors.Is(err, errLimitDisabled) {
return disabledLimitDecision, nil
}
return nil, err
}

if bucket.cost > limit.Burst {
if txn.cost > limit.Burst {
return nil, ErrInvalidCostOverLimit
}

// Remove cancellation from the request context so that transactions are not
// interrupted by a client disconnect.
ctx = context.WithoutCancel(ctx)
tat, err := l.source.Get(ctx, bucket.key)
tat, err := l.source.Get(ctx, txn.bucketKey)
if err != nil {
if !errors.Is(err, ErrBucketNotFound) {
return nil, err
}
// First request from this client. No need to initialize the bucket
// because this is a check, not a spend. A TAT of "now" is equivalent to
// a full bucket.
return maybeSpend(l.clk, limit, l.clk.Now(), bucket.cost), nil
return maybeSpend(l.clk, limit, l.clk.Now(), txn.cost), nil
}
return maybeSpend(l.clk, limit, tat, bucket.cost), nil
return maybeSpend(l.clk, limit, tat, txn.cost), nil
}

// Spend attempts to deduct the cost from the provided bucket's capacity. The
Expand All @@ -162,37 +176,37 @@ func (l *Limiter) Check(ctx context.Context, bucket BucketWithCost) (*Decision,
// be created WITH the cost factored into its initial state. The new bucket
// state is persisted to the underlying datastore, if applicable, before
// returning.
func (l *Limiter) Spend(ctx context.Context, bucket BucketWithCost) (*Decision, error) {
if bucket.cost <= 0 {
func (l *Limiter) Spend(ctx context.Context, txn Transaction) (*Decision, error) {
if txn.cost <= 0 {
return nil, ErrInvalidCost
}

limit, err := l.getLimit(bucket.name, bucket.key)
limit, err := l.getLimit(txn.limit, txn.bucketKey)
if err != nil {
if errors.Is(err, errLimitDisabled) {
return disabledLimitDecision, nil
}
return nil, err
}

if bucket.cost > limit.Burst {
if txn.cost > limit.Burst {
return nil, ErrInvalidCostOverLimit
}

start := l.clk.Now()
status := Denied
defer func() {
l.spendLatency.WithLabelValues(bucket.name.String(), status).Observe(l.clk.Since(start).Seconds())
l.spendLatency.WithLabelValues(txn.limit.String(), status).Observe(l.clk.Since(start).Seconds())
}()

// Remove cancellation from the request context so that transactions are not
// interrupted by a client disconnect.
ctx = context.WithoutCancel(ctx)
tat, err := l.source.Get(ctx, bucket.key)
tat, err := l.source.Get(ctx, txn.bucketKey)
if err != nil {
if errors.Is(err, ErrBucketNotFound) {
// First request from this client.
d, err := l.initialize(ctx, limit, bucket)
d, err := l.initialize(ctx, limit, txn)
if err != nil {
return nil, err
}
Expand All @@ -204,19 +218,19 @@ func (l *Limiter) Spend(ctx context.Context, bucket BucketWithCost) (*Decision,
return nil, err
}

d := maybeSpend(l.clk, limit, tat, bucket.cost)
d := maybeSpend(l.clk, limit, tat, txn.cost)

if limit.isOverride {
// Calculate the current utilization of the override limit.
utilization := float64(limit.Burst-d.Remaining) / float64(limit.Burst)
l.overrideUsageGauge.WithLabelValues(bucket.name.String(), bucket.key).Set(utilization)
l.overrideUsageGauge.WithLabelValues(txn.limit.String(), txn.bucketKey).Set(utilization)
}

if !d.Allowed {
return d, nil
}

err = l.source.Set(ctx, bucket.key, d.newTAT)
err = l.source.Set(ctx, txn.bucketKey, d.newTAT)
if err != nil {
return nil, err
}
Expand All @@ -235,12 +249,12 @@ func (l *Limiter) Spend(ctx context.Context, bucket BucketWithCost) (*Decision,
// instance, if a bucket has a maximum capacity of 10 and currently has 5
// requests remaining, a refund request of 7 will result in the bucket reaching
// its maximum capacity of 10, not 12.
func (l *Limiter) Refund(ctx context.Context, bucket BucketWithCost) (*Decision, error) {
if bucket.cost <= 0 {
func (l *Limiter) Refund(ctx context.Context, txn Transaction) (*Decision, error) {
if txn.cost <= 0 {
return nil, ErrInvalidCost
}

limit, err := l.getLimit(bucket.name, bucket.key)
limit, err := l.getLimit(txn.limit, txn.bucketKey)
if err != nil {
if errors.Is(err, errLimitDisabled) {
return disabledLimitDecision, nil
Expand All @@ -251,57 +265,57 @@ func (l *Limiter) Refund(ctx context.Context, bucket BucketWithCost) (*Decision,
// Remove cancellation from the request context so that transactions are not
// interrupted by a client disconnect.
ctx = context.WithoutCancel(ctx)
tat, err := l.source.Get(ctx, bucket.key)
tat, err := l.source.Get(ctx, txn.bucketKey)
if err != nil {
return nil, err
}
d := maybeRefund(l.clk, limit, tat, bucket.cost)
d := maybeRefund(l.clk, limit, tat, txn.cost)
if !d.Allowed {
// The bucket is already at maximum capacity.
return d, nil
}
return d, l.source.Set(ctx, bucket.key, d.newTAT)
return d, l.source.Set(ctx, txn.bucketKey, d.newTAT)
}

// Reset resets the specified bucket to its maximum capacity. The new bucket
// state is persisted to the underlying datastore before returning.
func (l *Limiter) Reset(ctx context.Context, bucket Bucket) error {
func (l *Limiter) Reset(ctx context.Context, bucketId BucketId) error {
// Remove cancellation from the request context so that transactions are not
// interrupted by a client disconnect.
ctx = context.WithoutCancel(ctx)
return l.source.Delete(ctx, bucket.key)
return l.source.Delete(ctx, bucketId.bucketKey)
}

// initialize creates a new bucket and sets its TAT to now, which is equivalent
// to a full bucket. The new bucket state is persisted to the underlying
// datastore before returning.
func (l *Limiter) initialize(ctx context.Context, rl limit, bucket BucketWithCost) (*Decision, error) {
d := maybeSpend(l.clk, rl, l.clk.Now(), bucket.cost)
func (l *Limiter) initialize(ctx context.Context, rl limit, txn Transaction) (*Decision, error) {
d := maybeSpend(l.clk, rl, l.clk.Now(), txn.cost)

// Remove cancellation from the request context so that transactions are not
// interrupted by a client disconnect.
ctx = context.WithoutCancel(ctx)
err := l.source.Set(ctx, bucket.key, d.newTAT)
err := l.source.Set(ctx, txn.bucketKey, d.newTAT)
if err != nil {
return nil, err
}
return d, nil

}

// getLimit returns the limit for the specified by name and bucketKey, name is
// required, bucketKey is optional. If bucketKey is left unspecified, the
// default limit for the limit specified by name is returned. If no default
// limit exists for the specified name, errLimitDisabled is returned.
func (l *Limiter) getLimit(name Name, bucketKey string) (limit, error) {
// getLimit returns the limit for the specified by name and id, name is
// required, id is optional. If id is left unspecified, the default limit for
// the limit specified by name is returned. If no default limit exists for the
// specified name, errLimitDisabled is returned.
func (l *Limiter) getLimit(name Name, id string) (limit, error) {
if !name.isValid() {
// This should never happen. Callers should only be specifying the limit
// Name enums defined in this package.
return limit{}, fmt.Errorf("specified name enum %q, is invalid", name)
}
if bucketKey != "" {
if id != "" {
// Check for override.
ol, ok := l.overrides[bucketKey]
ol, ok := l.overrides[id]
if ok {
return ol, nil
}
Expand Down
Loading

0 comments on commit 45d1849

Please sign in to comment.