diff --git a/.github/workflows/master.yaml b/.github/workflows/master.yaml index 34450a8..83dd798 100644 --- a/.github/workflows/master.yaml +++ b/.github/workflows/master.yaml @@ -12,7 +12,7 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 15 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - uses: actions/setup-go@v5 with: @@ -40,7 +40,7 @@ jobs: comment-on-alert: true - name: Save benchmark JSON to cache - uses: actions/cache/save@v4 + uses: actions/cache/save@v5 with: path: ./cache/benchmark-data.json # Save with commit hash to avoid "cache already exists" diff --git a/Makefile b/Makefile index 56cc731..50b4ffd 100644 --- a/Makefile +++ b/Makefile @@ -15,16 +15,25 @@ $(GOLANGCI_LINT): ## Download Go linter lint: $(GOLANGCI_LINT) ## Run Go linter $(GOLANGCI_LINT) run -v -c .golangci.yml ./... +.PHONY: tidy +tidy: + go mod tidy && git diff --exit-code + +.PHONY: ci +ci: tidy lint test bench + @echo + @echo "\033[32mEVERYTHING PASSED!\033[0m" + .PHONY: test test: ## Run unit tests and measure code coverage - (go test -v -race -p=1 -count=1 -tags holster_test_mode -coverprofile coverage.out ./...; ret=$$?; \ + (go test -v -race -p=1 -count=1 -tags clock_mutex -coverprofile coverage.out ./...; ret=$$?; \ go tool cover -func coverage.out; \ go tool cover -html coverage.out -o coverage.html; \ exit $$ret) .PHONY: bench bench: ## Run Go benchmarks - go test ./... -bench . -benchtime 5s -timeout 0 -run='^$$' -benchmem + go test ./... -bench . -timeout 6m -run='^$$' -benchmem .PHONY: docker docker: ## Build Docker image @@ -45,7 +54,6 @@ clean-proto: ## Clean the generated source files from the protobuf sources @find . -name "*.pb.go" -type f -delete @find . -name "*.pb.*.go" -type f -delete - .PHONY: proto proto: ## Build protos ./buf.gen.yaml diff --git a/README.md b/README.md index 3bc2ab9..3d140f6 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,6 @@ Gubernator is a distributed, high performance, cloud native and stateless rate-l kubernetes or nomad trivial. * Gubernator holds no state on disk, It’s configuration is passed to it by the client on a per-request basis. -* Gubernator provides both GRPC and HTTP access to the API. * It Can be run as a sidecar to services that need rate limiting or as a separate service. * It Can be used as a library to implement a domain-specific rate limiting service. * Supports optional eventually consistent rate limit distribution for extremely @@ -38,8 +37,10 @@ $ docker-compose up -d ``` Now you can make rate limit requests via CURL ``` -# Hit the HTTP API at localhost:1050 (GRPC is at 1051) -$ curl http://localhost:1050/v1/HealthCheck +# Hit the HTTP API at localhost:9080 +$ curl http://localhost:9080/v1/health.check + +# TODO: Update this example # Make a rate limit request $ curl http://localhost:1050/v1/GetRateLimits \ @@ -59,7 +60,7 @@ $ curl http://localhost:1050/v1/GetRateLimits \ ### ProtoBuf Structure -An example rate limit request sent via GRPC might look like the following +An example rate limit request sent with protobuf might look like the following ```yaml rate_limits: # Scopes the request to a specific rate limit @@ -214,7 +215,7 @@ limiting service. When you use the library, your service becomes a full member of the cluster participating in the same consistent hashing and caching as a stand alone -Gubernator server would. All you need to do is provide the GRPC server instance +Gubernator server would. All you need to do is provide the server instance and tell Gubernator where the peers in your cluster are located. 
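Whether Gubernator is embedded as a library or run standalone, callers now reach it through the plain-HTTP client added later in this diff (`client.go`). Below is a minimal sketch of a rate-limit check against the renamed endpoint, assuming the quick-start's `localhost:9080` address; the limit values are placeholders and not part of this change.

```go
// Sketch only: exercises the v3 HTTP client introduced in client.go below.
// The address and rate limit values are assumptions for illustration.
package main

import (
	"context"
	"fmt"
	"log"

	guber "github.com/gubernator-io/gubernator/v3"
)

func main() {
	client, err := guber.NewClient(guber.WithNoTLS("localhost:9080"))
	if err != nil {
		log.Fatal(err)
	}
	defer func() { _ = client.Close(context.Background()) }()

	var resp guber.CheckRateLimitsResponse
	err = client.CheckRateLimits(context.Background(), &guber.CheckRateLimitsRequest{
		Requests: []*guber.RateLimitRequest{{
			Name:      "requests_per_sec",
			UniqueKey: "account:12345",
			Algorithm: guber.Algorithm_TOKEN_BUCKET,
			Duration:  guber.Second,
			Limit:     10,
			Hits:      1,
		}},
	}, &resp)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("%+v\n", &resp)
}
```

The same `Client` interface also exposes `HealthCheck`, which backs the `POST /v1/health.check` example earlier in this README.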
The `cmd/gubernator/main.go` is a great example of how to use Gubernator as a library. @@ -238,21 +239,13 @@ to support rate limit durations longer than a minute, day or month, calls to those rate limits that have durations over a self determined limit. ### API -All methods are accessed via GRPC but are also exposed via HTTP using the -[GRPC Gateway](https://github.com/grpc-ecosystem/grpc-gateway) #### Health Check Health check returns `unhealthy` in the event a peer is reported by etcd or kubernetes as `up` but the server instance is unable to contact that peer via it's advertised address. -###### GRPC -```grpc -rpc HealthCheck (HealthCheckReq) returns (HealthCheckResp) -``` - -###### HTTP ``` -GET /v1/HealthCheck +GET /v1/health.check ``` Example response: @@ -269,14 +262,8 @@ Rate limits can be applied or retrieved using this interface. If the client makes a request to the server with `hits: 0` then current state of the rate limit is retrieved but not incremented. -###### GRPC -```grpc -rpc GetRateLimits (GetRateLimitsReq) returns (GetRateLimitsResp) -``` - -###### HTTP ``` -POST /v1/GetRateLimits +POST /v1/rate-limit.check ``` Example Payload @@ -285,7 +272,7 @@ Example Payload "requests": [ { "name": "requests_per_sec", - "uniqueKey": "account:12345", + "unique_key": "account:12345", "hits": "1", "limit": "10", "duration": "1000" @@ -314,20 +301,10 @@ Example response: ``` ### Deployment -NOTE: Gubernator uses `etcd`, Kubernetes or round-robin DNS to discover peers and +NOTE: Gubernator uses `memberlist` Kubernetes or round-robin DNS to discover peers and establish a cluster. If you don't have either, the docker-compose method is the simplest way to try gubernator out. - -##### Docker with existing etcd cluster -```bash -$ docker run -p 1051:1051 -p 1050:1050 -e GUBER_ETCD_ENDPOINTS=etcd1:2379,etcd2:2379 \ - ghcr.io/gubernator-io/gubernator:latest - -# Hit the HTTP API at localhost:1050 -$ curl http://localhost:1050/v1/HealthCheck -``` - ##### Kubernetes ```bash # Download the kubernetes deployment spec @@ -346,14 +323,15 @@ you can use same fully-qualified domain name to both let your business logic con instances to find `gubernator` and for `gubernator` containers/instances to find each other. ##### TLS -Gubernator supports TLS for both HTTP and GRPC connections. You can see an example with -self signed certs by running `docker-compose-tls.yaml` +Gubernator supports TLS. 
You can see an example with self-signed certs by running +`docker-compose-tls.yaml` ```bash # Run docker compose $ docker-compose -f docker-compose-tls.yaml up -d -# Hit the HTTP API at localhost:1050 (GRPC is at 1051) -$ curl --cacert certs/ca.cert --cert certs/gubernator.pem --key certs/gubernator.key https://localhost:1050/v1/HealthCheck +# Hit the HTTP API at localhost:9080 ++$ curl -X POST --cacert certs/ca.cert --cert certs/gubernator.pem \ + --key certs/gubernator.key https://localhost:9080/v1/health.check ``` ### Configuration diff --git a/algorithms.go b/algorithms.go index c923161..40fdaa3 100644 --- a/algorithms.go +++ b/algorithms.go @@ -18,14 +18,26 @@ package gubernator import ( "context" + "errors" - "github.com/mailgun/holster/v4/clock" + "github.com/kapetan-io/tackle/clock" "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" - "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" ) +var errAlreadyExistsInCache = errors.New("already exists in cache") + +type rateContext struct { + context.Context + // TODO(thrawn01): Roll this into `rateContext` + ReqState RateLimitContext + + Request *RateLimitRequest + CacheItem *CacheItem + Store Store + Cache Cache +} + // ### NOTE ### // The both token and leaky follow the same semantic which allows for requests of more than the limit // to be rejected, but subsequent requests within the same window that are under the limit to succeed. @@ -34,413 +46,425 @@ import ( // with 100 emails and the request will succeed. You can override this default behavior with `DRAIN_OVER_LIMIT` // Implements token bucket algorithm for rate limiting. https://en.wikipedia.org/wiki/Token_bucket -func tokenBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqState RateLimitReqState) (resp *RateLimitResp, err error) { +func tokenBucket(ctx rateContext) (resp *RateLimitResponse, err error) { tokenBucketTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("tokenBucket")) defer tokenBucketTimer.ObserveDuration() - - // Get rate limit from cache. - hashKey := r.HashKey() - item, ok := c.GetItem(hashKey) - - if s != nil && !ok { - // Cache miss. - // Check our store for the item. - if item, ok = s.Get(ctx, r); ok { - c.Add(item) + var ok bool + + // Get rate limit from cache + hashKey := ctx.Request.HashKey() + ctx.CacheItem, ok = ctx.Cache.GetItem(hashKey) + + // If not in the cache, check the store if provided + if ctx.Store != nil && !ok { + if ctx.CacheItem, ok = ctx.Store.Get(ctx, ctx.Request); ok { + if !ctx.Cache.AddIfNotPresent(ctx.CacheItem) { + // Someone else added a new token bucket item to the cache for this + // rate limit before we did, so we retry by calling ourselves recursively. + return tokenBucket(ctx) + } } } - // Sanity checks. - if ok { - if item.Value == nil { - msgPart := "tokenBucket: Invalid cache item; Value is nil" - trace.SpanFromContext(ctx).AddEvent(msgPart, trace.WithAttributes( - attribute.String("hashKey", hashKey), - attribute.String("key", r.UniqueKey), - attribute.String("name", r.Name), - )) - logrus.Error(msgPart) - ok = false - } else if item.Key != hashKey { - msgPart := "tokenBucket: Invalid cache item; key mismatch" - trace.SpanFromContext(ctx).AddEvent(msgPart, trace.WithAttributes( - attribute.String("itemKey", item.Key), - attribute.String("hashKey", hashKey), - attribute.String("name", r.Name), - )) - logrus.Error(msgPart) - ok = false + // If no item was found, or the item is expired. 
+ if !ok || ctx.CacheItem.IsExpired() { + rl, err := initTokenBucketItem(ctx) + if err != nil && errors.Is(err, errAlreadyExistsInCache) { + // Someone else added a new token bucket item to the cache for this + // rate limit before we did, so we retry by calling ourselves recursively. + return tokenBucket(ctx) } + return rl, err } - if ok { - // Item found in cache or store. - if HasBehavior(r.Behavior, Behavior_RESET_REMAINING) { - c.Remove(hashKey) + // Gain exclusive rights to this item while we calculate the rate limit + ctx.CacheItem.mutex.Lock() - if s != nil { - s.Remove(ctx, hashKey) - } - return &RateLimitResp{ - Status: Status_UNDER_LIMIT, - Limit: r.Limit, - Remaining: r.Limit, - ResetTime: 0, - }, nil + t, ok := ctx.CacheItem.Value.(*TokenBucketItem) + if !ok { + // Client switched algorithms; perhaps due to a migration? + ctx.Cache.Remove(hashKey) + if ctx.Store != nil { + ctx.Store.Remove(ctx, hashKey) } - t, ok := item.Value.(*TokenBucketItem) - if !ok { - // Client switched algorithms; perhaps due to a migration? - trace.SpanFromContext(ctx).AddEvent("Client switched algorithms; perhaps due to a migration?") + // Tell init to create a new cache item + ctx.CacheItem.mutex.Unlock() + ctx.CacheItem = nil + rl, err := initTokenBucketItem(ctx) + if err != nil && errors.Is(err, errAlreadyExistsInCache) { + return tokenBucket(ctx) + } + return rl, err + } - c.Remove(hashKey) + defer ctx.CacheItem.mutex.Unlock() - if s != nil { - s.Remove(ctx, hashKey) - } + if HasBehavior(ctx.Request.Behavior, Behavior_RESET_REMAINING) { + t.Remaining = ctx.Request.Limit + t.Limit = ctx.Request.Limit + t.Status = Status_UNDER_LIMIT + + if ctx.Store != nil { + ctx.Store.OnChange(ctx, ctx.Request, ctx.CacheItem) + } + return &RateLimitResponse{ + Status: Status_UNDER_LIMIT, + Limit: ctx.Request.Limit, + Remaining: ctx.Request.Limit, + ResetTime: 0, + }, nil + } - return tokenBucketNewItem(ctx, s, c, r, reqState) + // Update the limit if it changed. + if t.Limit != ctx.Request.Limit { + // Add difference to remaining. + t.Remaining += ctx.Request.Limit - t.Limit + if t.Remaining < 0 { + t.Remaining = 0 } + t.Limit = ctx.Request.Limit + } + + rl := &RateLimitResponse{ + Status: t.Status, + Limit: ctx.Request.Limit, + Remaining: t.Remaining, + ResetTime: ctx.CacheItem.ExpireAt, + } - // Update the limit if it changed. - if t.Limit != r.Limit { - // Add difference to remaining. - t.Remaining += r.Limit - t.Limit - if t.Remaining < 0 { - t.Remaining = 0 + // If the duration config changed, update the new ExpireAt. + if t.Duration != ctx.Request.Duration { + span := trace.SpanFromContext(ctx) + span.AddEvent("Duration changed") + expire := t.CreatedAt + ctx.Request.Duration + if HasBehavior(ctx.Request.Behavior, Behavior_DURATION_IS_GREGORIAN) { + expire, err = GregorianExpiration(clock.Now(), ctx.Request.Duration) + if err != nil { + return nil, err } - t.Limit = r.Limit } - rl := &RateLimitResp{ - Status: t.Status, - Limit: r.Limit, - Remaining: t.Remaining, - ResetTime: item.ExpireAt, + // If our new duration means we are currently expired. + createdAt := *ctx.Request.CreatedAt + if expire <= createdAt { + // Renew item. + span.AddEvent("Limit has expired") + expire = createdAt + ctx.Request.Duration + t.CreatedAt = createdAt + t.Remaining = t.Limit } - // If the duration config changed, update the new ExpireAt. 
- if t.Duration != r.Duration { - span := trace.SpanFromContext(ctx) - span.AddEvent("Duration changed") - expire := t.CreatedAt + r.Duration - if HasBehavior(r.Behavior, Behavior_DURATION_IS_GREGORIAN) { - expire, err = GregorianExpiration(clock.Now(), r.Duration) - if err != nil { - return nil, err - } - } + ctx.CacheItem.ExpireAt = expire + t.Duration = ctx.Request.Duration + rl.ResetTime = expire + } - // If our new duration means we are currently expired. - createdAt := *r.CreatedAt - if expire <= createdAt { - // Renew item. - span.AddEvent("Limit has expired") - expire = createdAt + r.Duration - t.CreatedAt = createdAt - t.Remaining = t.Limit - } + if ctx.Store != nil && ctx.ReqState.IsOwner { + defer func() { + ctx.Store.OnChange(ctx, ctx.Request, ctx.CacheItem) + }() + } - item.ExpireAt = expire - t.Duration = r.Duration - rl.ResetTime = expire - } + // Client is only interested in retrieving the current status or + // updating the rate limit config. + if ctx.Request.Hits == 0 { + return rl, nil + } - if s != nil && reqState.IsOwner { - defer func() { - s.OnChange(ctx, r, item) - }() + // If we are already at the limit. + if rl.Remaining == 0 && ctx.Request.Hits > 0 { + trace.SpanFromContext(ctx).AddEvent("Already over the limit") + if ctx.ReqState.IsOwner { + metricOverLimitCounter.Add(1) } + rl.Status = Status_OVER_LIMIT + t.Status = rl.Status + return rl, nil + } - // Client is only interested in retrieving the current status or - // updating the rate limit config. - if r.Hits == 0 { - return rl, nil - } + // If requested hits takes the remainder. + if t.Remaining == ctx.Request.Hits { + trace.SpanFromContext(ctx).AddEvent("At the limit") + t.Remaining = 0 + rl.Remaining = 0 + return rl, nil + } - // If we are already at the limit. - if rl.Remaining == 0 && r.Hits > 0 { - trace.SpanFromContext(ctx).AddEvent("Already over the limit") - if reqState.IsOwner { - metricOverLimitCounter.Add(1) - } - rl.Status = Status_OVER_LIMIT - t.Status = rl.Status - return rl, nil + // If requested is more than available, then return over the limit + // without updating the cache. + if ctx.Request.Hits > t.Remaining { + trace.SpanFromContext(ctx).AddEvent("Over the limit") + if ctx.ReqState.IsOwner { + metricOverLimitCounter.Add(1) } - - // If requested hits takes the remainder. - if t.Remaining == r.Hits { - trace.SpanFromContext(ctx).AddEvent("At the limit") + rl.Status = Status_OVER_LIMIT + if HasBehavior(ctx.Request.Behavior, Behavior_DRAIN_OVER_LIMIT) { + // DRAIN_OVER_LIMIT behavior drains the remaining counter. t.Remaining = 0 rl.Remaining = 0 - return rl, nil - } - - // If requested is more than available, then return over the limit - // without updating the cache. - if r.Hits > t.Remaining { - trace.SpanFromContext(ctx).AddEvent("Over the limit") - if reqState.IsOwner { - metricOverLimitCounter.Add(1) - } - rl.Status = Status_OVER_LIMIT - if HasBehavior(r.Behavior, Behavior_DRAIN_OVER_LIMIT) { - // DRAIN_OVER_LIMIT behavior drains the remaining counter. - t.Remaining = 0 - rl.Remaining = 0 - } - return rl, nil } - - t.Remaining -= r.Hits - rl.Remaining = t.Remaining return rl, nil } - // Item is not found in cache or store, create new. - return tokenBucketNewItem(ctx, s, c, r, reqState) + t.Remaining -= ctx.Request.Hits + rl.Remaining = t.Remaining + return rl, nil } -// Called by tokenBucket() when adding a new item in the store. 
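Both bucket algorithms now create missing items through `AddIfNotPresent` and retry when another goroutine wins the race (the `errAlreadyExistsInCache` path above). The following is a toy sketch of that create-or-retry pattern in isolation, using a `sync.Map` stand-in rather than the real `Cache` interface; every name in it is illustrative only.

```go
// Toy sketch of the AddIfNotPresent-then-retry flow used by tokenBucket()
// and initTokenBucketItem(); not the real Cache implementation.
package main

import (
	"fmt"
	"sync"
)

type item struct{ remaining int64 }

type cache struct{ m sync.Map }

// AddIfNotPresent reports whether our item won the race to create the key.
func (c *cache) AddIfNotPresent(key string, it *item) bool {
	_, loaded := c.m.LoadOrStore(key, it)
	return !loaded
}

func (c *cache) Get(key string) (*item, bool) {
	v, ok := c.m.Load(key)
	if !ok {
		return nil, false
	}
	return v.(*item), true
}

// getOrInit mirrors the tokenBucket -> initTokenBucketItem flow: if another
// goroutine created the item first, start over and use the winner's item.
func getOrInit(c *cache, key string, limit int64) *item {
	for {
		if it, ok := c.Get(key); ok {
			return it
		}
		if c.AddIfNotPresent(key, &item{remaining: limit}) {
			it, _ := c.Get(key)
			return it
		}
		// Lost the race; loop and fetch the item the winner stored.
	}
}

func main() {
	c := &cache{}
	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			_ = getOrInit(c, "requests_per_sec_account:12345", 10)
		}()
	}
	wg.Wait()
	it, _ := c.Get("requests_per_sec_account:12345")
	fmt.Println(it.remaining) // exactly one initialization survived; prints 10
}
```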
-func tokenBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqState RateLimitReqState) (resp *RateLimitResp, err error) { - createdAt := *r.CreatedAt - expire := createdAt + r.Duration +// initTokenBucketItem will create a new item if the passed item is nil, else it will update the provided item. +func initTokenBucketItem(ctx rateContext) (resp *RateLimitResponse, err error) { + createdAt := *ctx.Request.CreatedAt + expire := createdAt + ctx.Request.Duration - t := &TokenBucketItem{ - Limit: r.Limit, - Duration: r.Duration, - Remaining: r.Limit - r.Hits, + t := TokenBucketItem{ + Limit: ctx.Request.Limit, + Duration: ctx.Request.Duration, + Remaining: ctx.Request.Limit - ctx.Request.Hits, CreatedAt: createdAt, } // Add a new rate limit to the cache. - if HasBehavior(r.Behavior, Behavior_DURATION_IS_GREGORIAN) { - expire, err = GregorianExpiration(clock.Now(), r.Duration) + if HasBehavior(ctx.Request.Behavior, Behavior_DURATION_IS_GREGORIAN) { + expire, err = GregorianExpiration(clock.Now(), ctx.Request.Duration) if err != nil { return nil, err } } - item := &CacheItem{ - Algorithm: Algorithm_TOKEN_BUCKET, - Key: r.HashKey(), - Value: t, - ExpireAt: expire, - } - - rl := &RateLimitResp{ + rl := &RateLimitResponse{ Status: Status_UNDER_LIMIT, - Limit: r.Limit, + Limit: ctx.Request.Limit, Remaining: t.Remaining, ResetTime: expire, } // Client could be requesting that we always return OVER_LIMIT. - if r.Hits > r.Limit { + if ctx.Request.Hits > ctx.Request.Limit { trace.SpanFromContext(ctx).AddEvent("Over the limit") - if reqState.IsOwner { + if ctx.ReqState.IsOwner { metricOverLimitCounter.Add(1) } rl.Status = Status_OVER_LIMIT - rl.Remaining = r.Limit - t.Remaining = r.Limit + rl.Remaining = ctx.Request.Limit + t.Remaining = ctx.Request.Limit } - c.Add(item) + // If the cache item already exists, update it + if ctx.CacheItem != nil { + ctx.CacheItem.mutex.Lock() + ctx.CacheItem.Algorithm = ctx.Request.Algorithm + ctx.CacheItem.ExpireAt = expire + in, ok := ctx.CacheItem.Value.(*TokenBucketItem) + if !ok { + // Likely the store gave us the wrong cache item + ctx.CacheItem.mutex.Unlock() + ctx.CacheItem = nil + return initTokenBucketItem(ctx) + } + *in = t + ctx.CacheItem.mutex.Unlock() + } else { + // else create a new cache item and add it to the cache + ctx.CacheItem = &CacheItem{ + Algorithm: Algorithm_TOKEN_BUCKET, + Key: ctx.Request.HashKey(), + Value: &t, + ExpireAt: expire, + } + if !ctx.Cache.AddIfNotPresent(ctx.CacheItem) { + return rl, errAlreadyExistsInCache + } + } - if s != nil && reqState.IsOwner { - s.OnChange(ctx, r, item) + if ctx.Store != nil && ctx.ReqState.IsOwner { + ctx.Store.OnChange(ctx, ctx.Request, ctx.CacheItem) } return rl, nil } // Implements leaky bucket algorithm for rate limiting https://en.wikipedia.org/wiki/Leaky_bucket -func leakyBucket(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqState RateLimitReqState) (resp *RateLimitResp, err error) { +func leakyBucket(ctx rateContext) (resp *RateLimitResponse, err error) { leakyBucketTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.getRateLimit_leakyBucket")) defer leakyBucketTimer.ObserveDuration() + var ok bool - if r.Burst == 0 { - r.Burst = r.Limit + if ctx.Request.Burst == 0 { + ctx.Request.Burst = ctx.Request.Limit } - createdAt := *r.CreatedAt - // Get rate limit from cache. - hashKey := r.HashKey() - item, ok := c.GetItem(hashKey) - - if s != nil && !ok { - // Cache miss. - // Check our store for the item. 
- if item, ok = s.Get(ctx, r); ok { - c.Add(item) + hashKey := ctx.Request.HashKey() + ctx.CacheItem, ok = ctx.Cache.GetItem(hashKey) + + if ctx.Store != nil && !ok { + // Cache missed, check our store for the item. + if ctx.CacheItem, ok = ctx.Store.Get(ctx, ctx.Request); ok { + if !ctx.Cache.AddIfNotPresent(ctx.CacheItem) { + // Someone else added a new leaky bucket item to the cache for this + // rate limit before we did, so we retry by calling ourselves recursively. + return leakyBucket(ctx) + } } } - // Sanity checks. - if ok { - if item.Value == nil { - msgPart := "leakyBucket: Invalid cache item; Value is nil" - trace.SpanFromContext(ctx).AddEvent(msgPart, trace.WithAttributes( - attribute.String("hashKey", hashKey), - attribute.String("key", r.UniqueKey), - attribute.String("name", r.Name), - )) - logrus.Error(msgPart) - ok = false - } else if item.Key != hashKey { - msgPart := "leakyBucket: Invalid cache item; key mismatch" - trace.SpanFromContext(ctx).AddEvent(msgPart, trace.WithAttributes( - attribute.String("itemKey", item.Key), - attribute.String("hashKey", hashKey), - attribute.String("name", r.Name), - )) - logrus.Error(msgPart) - ok = false + // If no item was found, or the item is expired. + if !ok || ctx.CacheItem.IsExpired() { + rl, err := initLeakyBucketItem(ctx) + if err != nil && errors.Is(err, errAlreadyExistsInCache) { + // Someone else added a new leaky bucket item to the cache for this + // rate limit before we did, so we retry by calling ourselves recursively. + return leakyBucket(ctx) } + return rl, err } - if ok { - // Item found in cache or store. - - b, ok := item.Value.(*LeakyBucketItem) - if !ok { - // Client switched algorithms; perhaps due to a migration? - c.Remove(hashKey) - - if s != nil { - s.Remove(ctx, hashKey) - } + // Gain exclusive rights to this item while we calculate the rate limit + ctx.CacheItem.mutex.Lock() - return leakyBucketNewItem(ctx, s, c, r, reqState) + // Item found in cache or store. + t, ok := ctx.CacheItem.Value.(*LeakyBucketItem) + if !ok { + // Client switched algorithms; perhaps due to a migration? 
+ ctx.Cache.Remove(hashKey) + if ctx.Store != nil { + ctx.Store.Remove(ctx, hashKey) } - - if HasBehavior(r.Behavior, Behavior_RESET_REMAINING) { - b.Remaining = float64(r.Burst) + // Tell init to create a new cache item + ctx.CacheItem.mutex.Unlock() + ctx.CacheItem = nil + rl, err := initLeakyBucketItem(ctx) + if err != nil && errors.Is(err, errAlreadyExistsInCache) { + return leakyBucket(ctx) } + return rl, err + } - // Update burst, limit and duration if they changed - if b.Burst != r.Burst { - if r.Burst > int64(b.Remaining) { - b.Remaining = float64(r.Burst) - } - b.Burst = r.Burst - } + defer ctx.CacheItem.mutex.Unlock() - b.Limit = r.Limit - b.Duration = r.Duration + if HasBehavior(ctx.Request.Behavior, Behavior_RESET_REMAINING) { + t.Remaining = float64(ctx.Request.Burst) + } - duration := r.Duration - rate := float64(duration) / float64(r.Limit) + // Update burst, limit and duration if they changed + if t.Burst != ctx.Request.Burst { + if ctx.Request.Burst > int64(t.Remaining) { + t.Remaining = float64(ctx.Request.Burst) + } + t.Burst = ctx.Request.Burst + } - if HasBehavior(r.Behavior, Behavior_DURATION_IS_GREGORIAN) { - d, err := GregorianDuration(clock.Now(), r.Duration) - if err != nil { - return nil, err - } - n := clock.Now() - expire, err := GregorianExpiration(n, r.Duration) - if err != nil { - return nil, err - } + t.Limit = ctx.Request.Limit + t.Duration = ctx.Request.Duration - // Calculate the rate using the entire duration of the gregorian interval - // IE: Minute = 60,000 milliseconds, etc.. etc.. - rate = float64(d) / float64(r.Limit) - // Update the duration to be the end of the gregorian interval - duration = expire - (n.UnixNano() / 1000000) - } + duration := ctx.Request.Duration + rate := float64(duration) / float64(ctx.Request.Limit) - if r.Hits != 0 { - c.UpdateExpiration(r.HashKey(), createdAt+duration) + if HasBehavior(ctx.Request.Behavior, Behavior_DURATION_IS_GREGORIAN) { + d, err := GregorianDuration(clock.Now(), ctx.Request.Duration) + if err != nil { + return nil, err + } + n := clock.Now() + expire, err := GregorianExpiration(n, ctx.Request.Duration) + if err != nil { + return nil, err } - // Calculate how much leaked out of the bucket since the last time we leaked a hit - elapsed := createdAt - b.UpdatedAt - leak := float64(elapsed) / rate + // Calculate the rate using the entire duration of the gregorian interval + // IE: Minute = 60,000 milliseconds, etc.. etc.. + rate = float64(d) / float64(ctx.Request.Limit) + // Update the duration to be the end of the gregorian interval + duration = expire - (n.UnixNano() / 1000000) + } - if int64(leak) > 0 { - b.Remaining += leak - b.UpdatedAt = createdAt - } + createdAt := *ctx.Request.CreatedAt + if ctx.Request.Hits != 0 { + ctx.CacheItem.ExpireAt = createdAt + duration + } - if int64(b.Remaining) > b.Burst { - b.Remaining = float64(b.Burst) - } + // Calculate how much leaked out of the bucket since the last time we leaked a hit + elapsed := createdAt - t.UpdatedAt + leak := float64(elapsed) / rate - rl := &RateLimitResp{ - Limit: b.Limit, - Remaining: int64(b.Remaining), - Status: Status_UNDER_LIMIT, - ResetTime: createdAt + (b.Limit-int64(b.Remaining))*int64(rate), - } + if int64(leak) > 0 { + t.Remaining += leak + t.UpdatedAt = createdAt + } - // TODO: Feature missing: check for Duration change between item/request. 
+ if int64(t.Remaining) > t.Burst { + t.Remaining = float64(t.Burst) + } - if s != nil && reqState.IsOwner { - defer func() { - s.OnChange(ctx, r, item) - }() - } + rl := &RateLimitResponse{ + Limit: t.Limit, + Remaining: int64(t.Remaining), + Status: Status_UNDER_LIMIT, + ResetTime: createdAt + (t.Limit-int64(t.Remaining))*int64(rate), + } - // If we are already at the limit - if int64(b.Remaining) == 0 && r.Hits > 0 { - if reqState.IsOwner { - metricOverLimitCounter.Add(1) - } - rl.Status = Status_OVER_LIMIT - return rl, nil - } + // TODO: Feature missing: check for Duration change between item/request. - // If requested hits takes the remainder - if int64(b.Remaining) == r.Hits { - b.Remaining = 0 - rl.Remaining = int64(b.Remaining) - rl.ResetTime = createdAt + (rl.Limit-rl.Remaining)*int64(rate) - return rl, nil - } + if ctx.Store != nil && ctx.ReqState.IsOwner { + defer func() { + ctx.Store.OnChange(ctx, ctx.Request, ctx.CacheItem) + }() + } - // If requested is more than available, then return over the limit - // without updating the bucket, unless `DRAIN_OVER_LIMIT` is set. - if r.Hits > int64(b.Remaining) { - if reqState.IsOwner { - metricOverLimitCounter.Add(1) - } - rl.Status = Status_OVER_LIMIT + // If we are already at the limit + if int64(t.Remaining) == 0 && ctx.Request.Hits > 0 { + if ctx.ReqState.IsOwner { + metricOverLimitCounter.Add(1) + } + rl.Status = Status_OVER_LIMIT + return rl, nil + } - // DRAIN_OVER_LIMIT behavior drains the remaining counter. - if HasBehavior(r.Behavior, Behavior_DRAIN_OVER_LIMIT) { - b.Remaining = 0 - rl.Remaining = 0 - } + // If requested hits takes the remainder + if int64(t.Remaining) == ctx.Request.Hits { + t.Remaining = 0 + rl.Remaining = int64(t.Remaining) + rl.ResetTime = createdAt + (rl.Limit-rl.Remaining)*int64(rate) + return rl, nil + } - return rl, nil + // If requested is more than available, then return over the limit + // without updating the bucket, unless `DRAIN_OVER_LIMIT` is set. + if ctx.Request.Hits > int64(t.Remaining) { + if ctx.ReqState.IsOwner { + metricOverLimitCounter.Add(1) } + rl.Status = Status_OVER_LIMIT - // Client is only interested in retrieving the current status - if r.Hits == 0 { - return rl, nil + // DRAIN_OVER_LIMIT behavior drains the remaining counter. + if HasBehavior(ctx.Request.Behavior, Behavior_DRAIN_OVER_LIMIT) { + t.Remaining = 0 + rl.Remaining = 0 } - b.Remaining -= float64(r.Hits) - rl.Remaining = int64(b.Remaining) - rl.ResetTime = createdAt + (rl.Limit-rl.Remaining)*int64(rate) return rl, nil } - return leakyBucketNewItem(ctx, s, c, r, reqState) + // Client is only interested in retrieving the current status + if ctx.Request.Hits == 0 { + return rl, nil + } + + t.Remaining -= float64(ctx.Request.Hits) + rl.Remaining = int64(t.Remaining) + rl.ResetTime = createdAt + (rl.Limit-rl.Remaining)*int64(rate) + return rl, nil + } // Called by leakyBucket() when adding a new item in the store. 
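The leak arithmetic above is compact enough to misread, so here is the same computation on made-up numbers. Only the formulas (`rate = duration / limit`, `leak = elapsed / rate`, and the `ResetTime` expression) come from the code; the constants are invented for illustration.

```go
// Illustrative numbers only: the leak arithmetic used by leakyBucket() above.
package main

import "fmt"

func main() {
	const (
		limit    int64 = 100    // tokens allowed per duration
		duration int64 = 60_000 // one minute, in milliseconds
		burst    int64 = 100
	)
	rate := float64(duration) / float64(limit) // 600ms of "credit" per token

	remaining := 40.0       // tokens left after earlier hits
	elapsed := int64(3_000) // milliseconds since the bucket's UpdatedAt

	// leak = elapsed / rate: 3000 / 600 = 5 tokens drain back into the bucket
	leak := float64(elapsed) / rate
	if int64(leak) > 0 {
		remaining += leak
	}
	if int64(remaining) > burst {
		remaining = float64(burst)
	}

	// ResetTime mirrors: createdAt + (Limit - Remaining) * rate
	createdAt := int64(1_700_000_000_000) // an arbitrary epoch-millisecond stamp
	resetTime := createdAt + (limit-int64(remaining))*int64(rate)

	fmt.Printf("remaining=%d resetTime=%d\n", int64(remaining), resetTime) // remaining=45
}
```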
-func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, reqState RateLimitReqState) (resp *RateLimitResp, err error) { - createdAt := *r.CreatedAt - duration := r.Duration - rate := float64(duration) / float64(r.Limit) - if HasBehavior(r.Behavior, Behavior_DURATION_IS_GREGORIAN) { +func initLeakyBucketItem(ctx rateContext) (resp *RateLimitResponse, err error) { + createdAt := *ctx.Request.CreatedAt + duration := ctx.Request.Duration + rate := float64(duration) / float64(ctx.Request.Limit) + if HasBehavior(ctx.Request.Behavior, Behavior_DURATION_IS_GREGORIAN) { n := clock.Now() - expire, err := GregorianExpiration(n, r.Duration) + expire, err := GregorianExpiration(n, ctx.Request.Duration) if err != nil { return nil, err } @@ -451,23 +475,23 @@ func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, // Create a new leaky bucket b := LeakyBucketItem{ - Remaining: float64(r.Burst - r.Hits), - Limit: r.Limit, + Remaining: float64(ctx.Request.Burst - ctx.Request.Hits), + Limit: ctx.Request.Limit, Duration: duration, UpdatedAt: createdAt, - Burst: r.Burst, + Burst: ctx.Request.Burst, } - rl := RateLimitResp{ + rl := RateLimitResponse{ Status: Status_UNDER_LIMIT, Limit: b.Limit, - Remaining: r.Burst - r.Hits, - ResetTime: createdAt + (b.Limit-(r.Burst-r.Hits))*int64(rate), + Remaining: ctx.Request.Burst - ctx.Request.Hits, + ResetTime: createdAt + (b.Limit-(ctx.Request.Burst-ctx.Request.Hits))*int64(rate), } // Client could be requesting that we start with the bucket OVER_LIMIT - if r.Hits > r.Burst { - if reqState.IsOwner { + if ctx.Request.Hits > ctx.Request.Burst { + if ctx.ReqState.IsOwner { metricOverLimitCounter.Add(1) } rl.Status = Status_OVER_LIMIT @@ -476,17 +500,33 @@ func leakyBucketNewItem(ctx context.Context, s Store, c Cache, r *RateLimitReq, b.Remaining = 0 } - item := &CacheItem{ - ExpireAt: createdAt + duration, - Algorithm: r.Algorithm, - Key: r.HashKey(), - Value: &b, + if ctx.CacheItem != nil { + ctx.CacheItem.mutex.Lock() + ctx.CacheItem.Algorithm = ctx.Request.Algorithm + ctx.CacheItem.ExpireAt = createdAt + duration + in, ok := ctx.CacheItem.Value.(*LeakyBucketItem) + if !ok { + // Likely the store gave us the wrong cache item + ctx.CacheItem.mutex.Unlock() + ctx.CacheItem = nil + return initLeakyBucketItem(ctx) + } + *in = b + ctx.CacheItem.mutex.Unlock() + } else { + ctx.CacheItem = &CacheItem{ + ExpireAt: createdAt + duration, + Algorithm: ctx.Request.Algorithm, + Key: ctx.Request.HashKey(), + Value: &b, + } + if !ctx.Cache.AddIfNotPresent(ctx.CacheItem) { + return nil, errAlreadyExistsInCache + } } - c.Add(item) - - if s != nil && reqState.IsOwner { - s.OnChange(ctx, r, item) + if ctx.Store != nil && ctx.ReqState.IsOwner { + ctx.Store.OnChange(ctx, ctx.Request, ctx.CacheItem) } return &rl, nil diff --git a/benchmark_cache_test.go b/benchmark_cache_test.go index e19ee7e..45b9d0f 100644 --- a/benchmark_cache_test.go +++ b/benchmark_cache_test.go @@ -1,160 +1,182 @@ package gubernator_test import ( - "strconv" + "math/rand" "sync" "testing" "time" - "github.com/gubernator-io/gubernator/v2" - "github.com/mailgun/holster/v4/clock" + "github.com/gubernator-io/gubernator/v3" + "github.com/kapetan-io/tackle/clock" + "github.com/stretchr/testify/require" ) func BenchmarkCache(b *testing.B) { + const defaultNumKeys = 8192 testCases := []struct { Name string - NewTestCache func() gubernator.Cache + NewTestCache func() (gubernator.Cache, error) LockRequired bool }{ { Name: "LRUCache", - NewTestCache: func() gubernator.Cache { - 
return gubernator.NewLRUCache(0) + NewTestCache: func() (gubernator.Cache, error) { + return gubernator.NewLRUCache(0), nil }, LockRequired: true, }, + { + Name: "OtterCache", + NewTestCache: func() (gubernator.Cache, error) { + return gubernator.NewOtterCache(0) + }, + LockRequired: false, + }, } for _, testCase := range testCases { b.Run(testCase.Name, func(b *testing.B) { b.Run("Sequential reads", func(b *testing.B) { - cache := testCase.NewTestCache() + cache, err := testCase.NewTestCache() + require.NoError(b, err) expire := clock.Now().Add(time.Hour).UnixMilli() + keys := GenerateRandomKeys(defaultNumKeys) - for i := 0; i < b.N; i++ { - key := strconv.Itoa(i) + for _, key := range keys { item := &gubernator.CacheItem{ Key: key, - Value: i, + Value: "value:" + key, ExpireAt: expire, } - cache.Add(item) + cache.AddIfNotPresent(item) } + mask := len(keys) - 1 b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - key := strconv.Itoa(i) - _, _ = cache.GetItem(key) + index := int(rand.Uint32() & uint32(mask)) + _, _ = cache.GetItem(keys[index&mask]) } }) b.Run("Sequential writes", func(b *testing.B) { - cache := testCase.NewTestCache() + cache, err := testCase.NewTestCache() + require.NoError(b, err) expire := clock.Now().Add(time.Hour).UnixMilli() + keys := GenerateRandomKeys(defaultNumKeys) + mask := len(keys) - 1 b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { + index := int(rand.Uint32() & uint32(mask)) item := &gubernator.CacheItem{ - Key: strconv.Itoa(i), - Value: i, + Key: keys[index&mask], + Value: "value:" + keys[index&mask], ExpireAt: expire, } - cache.Add(item) + cache.AddIfNotPresent(item) } }) b.Run("Concurrent reads", func(b *testing.B) { - cache := testCase.NewTestCache() + cache, err := testCase.NewTestCache() + require.NoError(b, err) expire := clock.Now().Add(time.Hour).UnixMilli() + keys := GenerateRandomKeys(defaultNumKeys) - for i := 0; i < b.N; i++ { - key := strconv.Itoa(i) + for _, key := range keys { item := &gubernator.CacheItem{ Key: key, - Value: i, + Value: "value:" + key, ExpireAt: expire, } - cache.Add(item) + cache.AddIfNotPresent(item) } - var wg sync.WaitGroup var mutex sync.Mutex - var task func(i int) + var task func(key string) if testCase.LockRequired { - task = func(i int) { + task = func(key string) { mutex.Lock() defer mutex.Unlock() - key := strconv.Itoa(i) _, _ = cache.GetItem(key) - wg.Done() } } else { - task = func(i int) { - key := strconv.Itoa(i) + task = func(key string) { _, _ = cache.GetItem(key) - wg.Done() } } b.ReportAllocs() b.ResetTimer() - for i := 0; i < b.N; i++ { - wg.Add(1) - go task(i) - } + mask := len(keys) - 1 + + b.RunParallel(func(pb *testing.PB) { + index := int(rand.Uint32() & uint32(mask)) + for pb.Next() { + task(keys[index&mask]) + } + }) - wg.Wait() }) b.Run("Concurrent writes", func(b *testing.B) { - cache := testCase.NewTestCache() + cache, err := testCase.NewTestCache() + require.NoError(b, err) expire := clock.Now().Add(time.Hour).UnixMilli() + keys := GenerateRandomKeys(defaultNumKeys) - var wg sync.WaitGroup var mutex sync.Mutex - var task func(i int) + var task func(key string) if testCase.LockRequired { - task = func(i int) { + task = func(key string) { mutex.Lock() defer mutex.Unlock() item := &gubernator.CacheItem{ - Key: strconv.Itoa(i), - Value: i, + Key: key, + Value: "value:" + key, ExpireAt: expire, } - cache.Add(item) - wg.Done() + cache.AddIfNotPresent(item) } } else { - task = func(i int) { + task = func(key string) { item := &gubernator.CacheItem{ - Key: strconv.Itoa(i), - Value: 
i, + Key: key, + Value: "value:" + key, ExpireAt: expire, } - cache.Add(item) - wg.Done() + cache.AddIfNotPresent(item) } } + mask := len(keys) - 1 b.ReportAllocs() b.ResetTimer() - for i := 0; i < b.N; i++ { - wg.Add(1) - go task(i) - } - - wg.Wait() + b.RunParallel(func(pb *testing.PB) { + index := int(rand.Uint32() & uint32(mask)) + for pb.Next() { + task(keys[index&mask]) + } + }) }) }) } } + +func GenerateRandomKeys(size int) []string { + keys := make([]string, 0, size) + for i := 0; i < size; i++ { + keys = append(keys, gubernator.RandomString(20)) + } + return keys +} diff --git a/benchmark_test.go b/benchmark_test.go index 6adc92f..9a8b2bd 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -18,56 +18,82 @@ package gubernator_test import ( "context" + "fmt" + "math/rand" + "runtime" "testing" - guber "github.com/gubernator-io/gubernator/v2" - "github.com/gubernator-io/gubernator/v2/cluster" - "github.com/mailgun/holster/v4/clock" - "github.com/mailgun/holster/v4/syncutil" + guber "github.com/gubernator-io/gubernator/v3" + "github.com/gubernator-io/gubernator/v3/cluster" + "github.com/kapetan-io/tackle/clock" "github.com/stretchr/testify/require" ) +// go test benchmark_test.go -bench=BenchmarkTrace -benchtime=20s -trace=trace.out +// go tool trace trace.out +//func BenchmarkTrace(b *testing.B) { +// if err := cluster.StartWith([]guber.PeerInfo{ +// {HTTPAddress: "127.0.0.1:7980", DataCenter: cluster.DataCenterNone}, +// {HTTPAddress: "127.0.0.1:7981", DataCenter: cluster.DataCenterNone}, +// {HTTPAddress: "127.0.0.1:7982", DataCenter: cluster.DataCenterNone}, +// {HTTPAddress: "127.0.0.1:7983", DataCenter: cluster.DataCenterNone}, +// {HTTPAddress: "127.0.0.1:7984", DataCenter: cluster.DataCenterNone}, +// {HTTPAddress: "127.0.0.1:7985", DataCenter: cluster.DataCenterNone}, +// +// // DataCenterOne +// {HTTPAddress: "127.0.0.1:9880", DataCenter: cluster.DataCenterOne}, +// {HTTPAddress: "127.0.0.1:9881", DataCenter: cluster.DataCenterOne}, +// {HTTPAddress: "127.0.0.1:9882", DataCenter: cluster.DataCenterOne}, +// {HTTPAddress: "127.0.0.1:9883", DataCenter: cluster.DataCenterOne}, +// }); err != nil { +// fmt.Println(err) +// os.Exit(1) +// } +// defer cluster.Stop(context.Background()) +//} + func BenchmarkServer(b *testing.B) { ctx := context.Background() conf := guber.Config{} err := conf.SetDefaults() require.NoError(b, err, "Error in conf.SetDefaults") createdAt := epochMillis(clock.Now()) + d := cluster.GetRandomDaemon(cluster.DataCenterNone) - b.Run("GetPeerRateLimit", func(b *testing.B) { - client, err := guber.NewPeerClient(guber.PeerConfig{ - Info: cluster.GetRandomPeer(cluster.DataCenterNone), - Behavior: conf.Behaviors, - }) - if err != nil { - b.Errorf("Error building client: %s", err) - } + b.Run("Forward", func(b *testing.B) { + client := d.MustClient().(guber.PeerClient) b.ResetTimer() for n := 0; n < b.N; n++ { - _, err := client.GetPeerRateLimit(ctx, &guber.RateLimitReq{ - Name: b.Name(), - UniqueKey: guber.RandomString(10), - // Behavior: guber.Behavior_NO_BATCHING, - Limit: 10, - Duration: 5, - Hits: 1, - CreatedAt: &createdAt, - }) + var resp guber.ForwardResponse + err := client.Forward(ctx, &guber.ForwardRequest{ + Requests: []*guber.RateLimitRequest{ + { + Name: b.Name(), + UniqueKey: guber.RandomString(10), + // Behavior: guber.Behavior_NO_BATCHING, + Limit: 10, + Duration: 5, + Hits: 1, + CreatedAt: &createdAt, + }, + }, + }, &resp) if err != nil { b.Errorf("Error in client.GetPeerRateLimit: %s", err) } } }) - b.Run("GetRateLimits batching", func(b 
*testing.B) { - client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) - require.NoError(b, err, "Error in guber.DialV1Server") + b.Run("CheckRateLimits batching", func(b *testing.B) { + client := cluster.GetRandomDaemon(cluster.DataCenterNone).MustClient() + require.NoError(b, err) b.ResetTimer() for n := 0; n < b.N; n++ { - _, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + var resp guber.CheckRateLimitsResponse + err := client.CheckRateLimits(ctx, &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { Name: b.Name(), UniqueKey: guber.RandomString(10), @@ -76,21 +102,22 @@ func BenchmarkServer(b *testing.B) { Hits: 1, }, }, - }) + }, &resp) if err != nil { b.Errorf("Error in client.GetRateLimits(): %s", err) } } }) - b.Run("GetRateLimits global", func(b *testing.B) { - client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) - require.NoError(b, err, "Error in guber.DialV1Server") + b.Run("CheckRateLimits global", func(b *testing.B) { + client := cluster.GetRandomDaemon(cluster.DataCenterNone).MustClient() + require.NoError(b, err) b.ResetTimer() for n := 0; n < b.N; n++ { - _, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + var resp guber.CheckRateLimitsResponse + err := client.CheckRateLimits(ctx, &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { Name: b.Name(), UniqueKey: guber.RandomString(10), @@ -100,49 +127,97 @@ func BenchmarkServer(b *testing.B) { Hits: 1, }, }, - }) + }, &resp) if err != nil { - b.Errorf("Error in client.GetRateLimits: %s", err) + b.Errorf("Error in client.CheckRateLimits: %s", err) } } }) b.Run("HealthCheck", func(b *testing.B) { - client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) - require.NoError(b, err, "Error in guber.DialV1Server") + client := cluster.GetRandomDaemon(cluster.DataCenterNone).MustClient() + require.NoError(b, err) b.ResetTimer() for n := 0; n < b.N; n++ { - if _, err := client.HealthCheck(ctx, &guber.HealthCheckReq{}); err != nil { + var resp guber.HealthCheckResponse + if err := client.HealthCheck(ctx, &resp); err != nil { b.Errorf("Error in client.HealthCheck: %s", err) } } }) - b.Run("Thundering herd", func(b *testing.B) { - client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) - require.NoError(b, err, "Error in guber.DialV1Server") - b.ResetTimer() - fan := syncutil.NewFanOut(100) + b.Run("Concurrency CheckRateLimits", func(b *testing.B) { + var clients []guber.Client - for n := 0; n < b.N; n++ { - fan.Run(func(o interface{}) error { - _, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + // Create a client for each CPU on the system. This should allow us to simulate the + // maximum contention possible for this system. 
+ for i := 0; i < runtime.NumCPU(); i++ { + client, err := guber.NewClient(guber.WithNoTLS(d.Listener.Addr().String())) + require.NoError(b, err) + clients = append(clients, client) + } + + keys := GenerateRandomKeys(8_000) + keyMask := len(keys) - 1 + clientMask := len(clients) - 1 + var idx int + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + client := clients[idx&clientMask] + idx++ + + for pb.Next() { + keyIdx := int(rand.Uint32() & uint32(clientMask)) + var resp guber.CheckRateLimitsResponse + err = client.CheckRateLimits(ctx, &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { Name: b.Name(), - UniqueKey: guber.RandomString(10), - Limit: 10, - Duration: guber.Second * 5, + UniqueKey: keys[keyIdx&keyMask], + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Minute, + Limit: 100, Hits: 1, }, }, - }) + }, &resp) if err != nil { - b.Errorf("Error in client.GetRateLimits: %s", err) + fmt.Printf("%s\n", err.Error()) } - return nil - }, nil) + } + }) + + for _, client := range clients { + _ = client.Close(context.Background()) + } + }) + + b.Run("Concurrency HealthCheck", func(b *testing.B) { + var clients []guber.Client + + // Create a client for each CPU on the system. This should allow us to simulate the + // maximum contention possible for this system. + for i := 0; i < runtime.NumCPU(); i++ { + client, err := guber.NewClient(guber.WithNoTLS(d.Listener.Addr().String())) + require.NoError(b, err) + clients = append(clients, client) } + mask := len(clients) - 1 + var idx int + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + client := clients[idx&mask] + idx++ + + for pb.Next() { + var resp guber.HealthCheckResponse + if err := client.HealthCheck(ctx, &resp); err != nil { + b.Errorf("Error in client.HealthCheck: %s", err) + } + } + }) }) } diff --git a/buf.gen.yaml b/buf.gen.yaml index 5c62f51..b928594 100755 --- a/buf.gen.yaml +++ b/buf.gen.yaml @@ -5,18 +5,5 @@ plugins: - plugin: buf.build/protocolbuffers/go:v1.32.0 out: ./ opt: paths=source_relative - - plugin: buf.build/grpc/go:v1.3.0 - out: ./ - opt: - - paths=source_relative - - require_unimplemented_servers=false - - plugin: buf.build/grpc-ecosystem/gateway:v2.18.0 # same version in go.mod - out: ./ - opt: - - paths=source_relative - - logtostderr=true - - generate_unbound_methods=true - - plugin: buf.build/grpc/python:v1.57.0 - out: ./python/gubernator - plugin: buf.build/protocolbuffers/python out: ./python/gubernator diff --git a/buf.yaml b/buf.yaml index b6d1351..38bfc22 100644 --- a/buf.yaml +++ b/buf.yaml @@ -8,7 +8,4 @@ breaking: - FILE lint: use: - - DEFAULT - rpc_allow_same_request_response: false - rpc_allow_google_protobuf_empty_requests: true - rpc_allow_google_protobuf_empty_responses: true \ No newline at end of file + - DEFAULT \ No newline at end of file diff --git a/cache.go b/cache.go index 0fd431a..d40050f 100644 --- a/cache.go +++ b/cache.go @@ -16,31 +16,50 @@ limitations under the License. package gubernator +import ( + "sync" + + "github.com/prometheus/client_golang/prometheus" + + "github.com/kapetan-io/tackle/clock" +) + type Cache interface { - Add(item *CacheItem) bool - UpdateExpiration(key string, expireAt int64) bool + // AddIfNotPresent adds the item to the cache if it doesn't already exist. + // Returns true if the item was added, false if the item already exists. 
+ AddIfNotPresent(item *CacheItem) bool GetItem(key string) (value *CacheItem, ok bool) Each() chan *CacheItem Remove(key string) Size() int64 + Stats() CacheStats Close() error } +// CacheItem is 64 bytes aligned in size +// Since both TokenBucketItem and LeakyBucketItem both 40 bytes in size then a CacheItem with +// the Value attached takes up 64 + 40 = 104 bytes of space. Not counting the size of the key. type CacheItem struct { - Algorithm Algorithm - Key string - Value interface{} + mutex sync.Mutex // 8 bytes + Key string // 16 bytes + Value interface{} // 16 bytes // Timestamp when rate limit expires in epoch milliseconds. - ExpireAt int64 + ExpireAt int64 // 8 Bytes // Timestamp when the cache should invalidate this rate limit. This is useful when used in conjunction with // a persistent store to ensure our node has the most up to date info from the store. Ignored if set to `0` // It is set by the persistent store implementation to indicate when the node should query the persistent store // for the latest rate limit data. - InvalidAt int64 + InvalidAt int64 // 8 bytes + Algorithm Algorithm // 4 bytes + // 4 Bytes of Padding } func (item *CacheItem) IsExpired() bool { + // TODO(thrawn01): Eliminate the need for this mutex lock + item.mutex.Lock() + defer item.mutex.Unlock() + now := MillisecondNow() // If the entry is invalidated @@ -55,3 +74,91 @@ func (item *CacheItem) IsExpired() bool { return false } + +func (item *CacheItem) Copy(from *CacheItem) { + item.mutex.Lock() + defer item.mutex.Unlock() + + item.InvalidAt = from.InvalidAt + item.Algorithm = from.Algorithm + item.ExpireAt = from.ExpireAt + item.Value = from.Value + item.Key = from.Key +} + +// MillisecondNow returns unix epoch in milliseconds +func MillisecondNow() int64 { + return clock.Now().UnixNano() / 1000000 +} + +type CacheStats struct { + Size int64 + Hit int64 + Miss int64 + UnexpiredEvictions int64 +} + +// CacheCollector provides prometheus metrics collector for Cache implementations +// Register only one collector, add one or more caches to this collector. +type CacheCollector struct { + caches []Cache + metricSize prometheus.Gauge + metricAccess *prometheus.CounterVec + metricUnexpiredEvictions prometheus.Counter +} + +func NewCacheCollector() *CacheCollector { + return &CacheCollector{ + caches: []Cache{}, + metricSize: prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "gubernator_cache_size", + Help: "The number of items in LRU Cache which holds the rate limits.", + }), + metricAccess: prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "gubernator_cache_access_count", + Help: "Cache access counts. Label \"type\" = hit|miss.", + }, []string{"type"}), + metricUnexpiredEvictions: prometheus.NewCounter(prometheus.CounterOpts{ + Name: "gubernator_unexpired_evictions_count", + Help: "Count the number of cache items which were evicted while unexpired.", + }), + } +} + +var _ prometheus.Collector = &CacheCollector{} + +// AddCache adds a Cache object to be tracked by the collector. 
+func (c *CacheCollector) AddCache(cache Cache) { + c.caches = append(c.caches, cache) +} + +// Describe fetches prometheus metrics to be registered +func (c *CacheCollector) Describe(ch chan<- *prometheus.Desc) { + c.metricSize.Describe(ch) + c.metricAccess.Describe(ch) + c.metricUnexpiredEvictions.Describe(ch) +} + +// Collect fetches metric counts and gauges from the cache +func (c *CacheCollector) Collect(ch chan<- prometheus.Metric) { + stats := c.getStats() + c.metricSize.Set(float64(stats.Size)) + c.metricSize.Collect(ch) + c.metricAccess.WithLabelValues("miss").Add(float64(stats.Miss)) + c.metricAccess.WithLabelValues("hit").Add(float64(stats.Hit)) + c.metricAccess.Collect(ch) + c.metricUnexpiredEvictions.Add(float64(stats.UnexpiredEvictions)) + c.metricUnexpiredEvictions.Collect(ch) +} + +func (c *CacheCollector) getStats() CacheStats { + var total CacheStats + for _, cache := range c.caches { + stats := cache.Stats() + total.Hit += stats.Hit + total.Miss += stats.Miss + total.Size += stats.Size + total.UnexpiredEvictions += stats.UnexpiredEvictions + } + return total +} diff --git a/cache_manager.go b/cache_manager.go new file mode 100644 index 0000000..d24e252 --- /dev/null +++ b/cache_manager.go @@ -0,0 +1,172 @@ +/* +Copyright 2024 Derrick J. Wippler + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package gubernator + +import ( + "context" + "sync" + + "github.com/pkg/errors" +) + +type CacheManager interface { + CheckRateLimit(context.Context, *RateLimitRequest, RateLimitContext) (*RateLimitResponse, error) + GetCacheItem(context.Context, string) (*CacheItem, bool, error) + AddCacheItem(context.Context, string, *CacheItem) error + Store(ctx context.Context) error + Load(context.Context) error + Close() error +} + +type cacheManager struct { + conf Config + cache Cache +} + +// NewCacheManager creates a new instance of the CacheManager interface using +// the cache returned by Config.CacheFactory +func NewCacheManager(conf Config) (CacheManager, error) { + + cache, err := conf.CacheFactory(conf.CacheSize) + if err != nil { + return nil, err + } + return &cacheManager{ + cache: cache, + conf: conf, + }, nil +} + +// GetRateLimit fetches the item from the cache if it exists, and preforms the appropriate rate limit calculation +func (m *cacheManager) CheckRateLimit(ctx context.Context, req *RateLimitRequest, state RateLimitContext) (*RateLimitResponse, error) { + var rlResponse *RateLimitResponse + var err error + + switch req.Algorithm { + case Algorithm_TOKEN_BUCKET: + rlResponse, err = tokenBucket(rateContext{ + Store: m.conf.Store, + Cache: m.cache, + ReqState: state, + Request: req, + Context: ctx, + }) + if err != nil { + msg := "Error in tokenBucket" + countError(err, msg) + } + + case Algorithm_LEAKY_BUCKET: + rlResponse, err = leakyBucket(rateContext{ + Store: m.conf.Store, + Cache: m.cache, + ReqState: state, + Request: req, + Context: ctx, + }) + if err != nil { + msg := "Error in leakyBucket" + countError(err, msg) + } + + default: + err = errors.Errorf("Invalid rate limit algorithm '%d'", req.Algorithm) + } + + return rlResponse, err +} + +// Store saves every cache item into persistent storage provided via Config.Loader +func (m *cacheManager) Store(ctx context.Context) error { + out := make(chan *CacheItem, 500) + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + for item := range m.cache.Each() { + select { + case out <- item: + + case <-ctx.Done(): + return + } + } + }() + + go func() { + wg.Wait() + close(out) + }() + + if ctx.Err() != nil { + return ctx.Err() + } + + if err := m.conf.Loader.Save(out); err != nil { + return errors.Wrap(err, "while calling p.conf.Loader.Save()") + } + return nil +} + +// Close closes the cache manager +func (m *cacheManager) Close() error { + return m.cache.Close() +} + +// Load cache items from persistent storage provided via Config.Loader +func (m *cacheManager) Load(ctx context.Context) error { + ch, err := m.conf.Loader.Load() + if err != nil { + return errors.Wrap(err, "Error in loader.Load") + } + + for { + var item *CacheItem + var ok bool + + select { + case item, ok = <-ch: + if !ok { + return nil + } + case <-ctx.Done(): + return ctx.Err() + } + retry: + if !m.cache.AddIfNotPresent(item) { + cItem, ok := m.cache.GetItem(item.Key) + if !ok { + goto retry + } + cItem.Copy(item) + } + } +} + +// GetCacheItem returns an item from the cache +func (m *cacheManager) GetCacheItem(_ context.Context, key string) (*CacheItem, bool, error) { + item, ok := m.cache.GetItem(key) + return item, ok, nil +} + +// AddCacheItem adds an item to the cache. The CacheItem.Key should be set correctly, else the item +// will not be added to the cache correctly. 
+func (m *cacheManager) AddCacheItem(_ context.Context, _ string, item *CacheItem) error { + _ = m.cache.AddIfNotPresent(item) + return nil +} diff --git a/workers_test.go b/cache_manager_test.go similarity index 80% rename from workers_test.go rename to cache_manager_test.go index 4e77960..7994b08 100644 --- a/workers_test.go +++ b/cache_manager_test.go @@ -22,13 +22,13 @@ import ( "sort" "testing" - guber "github.com/gubernator-io/gubernator/v2" + guber "github.com/gubernator-io/gubernator/v3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) -func TestGubernatorPool(t *testing.T) { +func TestCacheManager(t *testing.T) { ctx := context.Background() testCases := []struct { @@ -43,7 +43,7 @@ func TestGubernatorPool(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { // Setup mock data. const NumCacheItems = 100 - cacheItems := []*guber.CacheItem{} + var cacheItems []*guber.CacheItem for i := 0; i < NumCacheItems; i++ { cacheItems = append(cacheItems, &guber.CacheItem{ Key: fmt.Sprintf("Foobar%04d", i), @@ -55,15 +55,16 @@ func TestGubernatorPool(t *testing.T) { t.Run("Load()", func(t *testing.T) { mockLoader := &MockLoader2{} mockCache := &MockCache{} - conf := &guber.Config{ - CacheFactory: func(maxSize int) guber.Cache { - return mockCache + conf := guber.Config{ + CacheFactory: func(maxSize int) (guber.Cache, error) { + return mockCache, nil }, Loader: mockLoader, Workers: testCase.workers, } assert.NoError(t, conf.SetDefaults()) - chp := guber.NewWorkerPool(conf) + manager, err := guber.NewCacheManager(conf) + require.NoError(t, err) // Mock Loader. fakeLoadCh := make(chan *guber.CacheItem, NumCacheItems) @@ -75,35 +76,36 @@ func TestGubernatorPool(t *testing.T) { // Mock Cache. for _, item := range cacheItems { - mockCache.On("Add", item).Once().Return(false) + mockCache.On("AddIfNotPresent", item).Once().Return(true) } // Call code. - err := chp.Load(ctx) + err = manager.Load(ctx) // Verify. - require.NoError(t, err, "Error in chp.Load") + require.NoError(t, err, "Error in manager.Load") }) t.Run("Store()", func(t *testing.T) { mockLoader := &MockLoader2{} mockCache := &MockCache{} - conf := &guber.Config{ - CacheFactory: func(maxSize int) guber.Cache { - return mockCache + conf := guber.Config{ + CacheFactory: func(maxSize int) (guber.Cache, error) { + return mockCache, nil }, Loader: mockLoader, Workers: testCase.workers, } require.NoError(t, conf.SetDefaults()) - chp := guber.NewWorkerPool(conf) + chp, err := guber.NewCacheManager(conf) + require.NoError(t, err) // Mock Loader. mockLoader.On("Save", mock.Anything).Once().Return(nil). Run(func(args mock.Arguments) { // Verify items sent over the channel passed to Save(). saveCh := args.Get(0).(chan *guber.CacheItem) - savedItems := []*guber.CacheItem{} + var savedItems []*guber.CacheItem for item := range saveCh { savedItems = append(savedItems, item) } @@ -124,7 +126,7 @@ func TestGubernatorPool(t *testing.T) { mockCache.On("Each").Times(testCase.workers).Return(eachCh) // Call code. - err := chp.Store(ctx) + err = chp.Store(ctx) // Verify. require.NoError(t, err, "Error in chp.Store") diff --git a/client.go b/client.go index 430bd22..913a2f0 100644 --- a/client.go +++ b/client.go @@ -17,17 +17,22 @@ limitations under the License. 
package gubernator import ( + "bytes" + "context" crand "crypto/rand" "crypto/tls" + "fmt" "math/rand" + "net/http" "time" - "github.com/mailgun/holster/v4/clock" + "github.com/duh-rpc/duh-go" + v1 "github.com/duh-rpc/duh-go/proto/v1" + "github.com/kapetan-io/tackle/clock" + "github.com/kapetan-io/tackle/set" "github.com/pkg/errors" - "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" + "go.opentelemetry.io/otel/propagation" + "google.golang.org/protobuf/proto" ) const ( @@ -36,32 +41,175 @@ const ( Minute = 60 * Second ) -func (m *RateLimitReq) HashKey() string { +type Client interface { + CheckRateLimits(context.Context, *CheckRateLimitsRequest, *CheckRateLimitsResponse) error + HealthCheck(context.Context, *HealthCheckResponse) error + Close(ctx context.Context) error +} + +func (m *RateLimitRequest) HashKey() string { return m.Name + "_" + m.UniqueKey } -// DialV1Server is a convenience function for dialing gubernator instances -func DialV1Server(server string, tls *tls.Config) (V1Client, error) { - if len(server) == 0 { - return nil, errors.New("server is empty; must provide a server") +type ClientOptions struct { + // Users can provide their own http client with TLS config if needed + Client *http.Client + // The address of endpoint in the format `://:` + Endpoint string +} + +type client struct { + *duh.Client + prop propagation.TraceContext + opts ClientOptions +} + +// NewClient creates a new instance of the Gubernator user client +func NewClient(opts ClientOptions) (Client, error) { + set.Default(&opts.Client, &http.Client{ + Transport: &http.Transport{ + MaxConnsPerHost: 2_000, + MaxIdleConns: 2_000, + MaxIdleConnsPerHost: 2_000, + IdleConnTimeout: 60 * time.Second, + }, + }) + + if len(opts.Endpoint) == 0 { + return nil, errors.New("opts.Endpoint is empty; must provide an address") + } + + return &client{ + Client: &duh.Client{ + Client: opts.Client, + }, + opts: opts, + }, nil +} + +func NewPeerClient(opts ClientOptions) PeerClient { + return &client{ + Client: &duh.Client{ + Client: opts.Client, + }, + opts: opts, + } +} + +func (c *client) CheckRateLimits(ctx context.Context, req *CheckRateLimitsRequest, resp *CheckRateLimitsResponse) error { + payload, err := proto.Marshal(req) + if err != nil { + return duh.NewClientError("while marshaling request payload: %w", err, nil) + } + + r, err := http.NewRequestWithContext(ctx, http.MethodPost, + fmt.Sprintf("%s%s", c.opts.Endpoint, RPCRateLimitCheck), bytes.NewReader(payload)) + if err != nil { + return duh.NewClientError("", err, nil) + } + + r.Header.Set("Content-Type", duh.ContentTypeProtoBuf) + return c.Do(r, resp) +} + +func (c *client) HealthCheck(ctx context.Context, resp *HealthCheckResponse) error { + payload, err := proto.Marshal(&HealthCheckRequest{}) + if err != nil { + return duh.NewClientError("while marshaling request payload: %w", err, nil) + } + + r, err := http.NewRequestWithContext(ctx, http.MethodPost, + fmt.Sprintf("%s%s", c.opts.Endpoint, RPCHealthCheck), bytes.NewReader(payload)) + if err != nil { + return duh.NewClientError("", err, nil) } - // Setup OpenTelemetry interceptor to propagate spans. 
- opts := []grpc.DialOption{ - grpc.WithStatsHandler(otelgrpc.NewClientHandler()), + r.Header.Set("Content-Type", duh.ContentTypeProtoBuf) + return c.Do(r, resp) +} + +func (c *client) Forward(ctx context.Context, req *ForwardRequest, resp *ForwardResponse) error { + payload, err := proto.Marshal(req) + if err != nil { + return duh.NewClientError("while marshaling request payload: %w", err, nil) } - if tls != nil { - opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(tls))) - } else { - opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) + + r, err := http.NewRequestWithContext(ctx, http.MethodPost, + fmt.Sprintf("%s%s", c.opts.Endpoint, RPCPeerForward), bytes.NewReader(payload)) + if err != nil { + return duh.NewClientError("", err, nil) } - conn, err := grpc.NewClient(server, opts...) + c.prop.Inject(ctx, propagation.HeaderCarrier(r.Header)) + r.Header.Set("Content-Type", duh.ContentTypeProtoBuf) + return c.Do(r, resp) +} + +func (c *client) Update(ctx context.Context, req *UpdateRequest) error { + payload, err := proto.Marshal(req) + if err != nil { + return duh.NewClientError("while marshaling request payload: %w", err, nil) + } + r, err := http.NewRequestWithContext(ctx, http.MethodPost, + fmt.Sprintf("%s%s", c.opts.Endpoint, RPCPeerUpdate), bytes.NewReader(payload)) if err != nil { - return nil, errors.Wrapf(err, "failed to dial server %s", server) + return duh.NewClientError("", err, nil) } - return NewV1Client(conn), nil + r.Header.Set("Content-Type", duh.ContentTypeProtoBuf) + return c.Do(r, &v1.Reply{}) +} + +func (c *client) Close(_ context.Context) error { + c.Client.Client.CloseIdleConnections() + return nil +} + +// WithNoTLS returns ClientOptions suitable for use with NON-TLS clients +func WithNoTLS(address string) ClientOptions { + return ClientOptions{ + Endpoint: fmt.Sprintf("http://%s", address), + Client: &http.Client{ + Transport: &http.Transport{ + MaxConnsPerHost: 2_000, + MaxIdleConns: 2_000, + MaxIdleConnsPerHost: 2_000, + IdleConnTimeout: 60 * time.Second, + }, + }, + } +} + +// WithTLS returns ClientOptions suitable for use with TLS clients +func WithTLS(tls *tls.Config, address string) ClientOptions { + return ClientOptions{ + Endpoint: fmt.Sprintf("https://%s", address), + Client: &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: tls, + MaxConnsPerHost: 2_000, + MaxIdleConns: 2_000, + MaxIdleConnsPerHost: 2_000, + IdleConnTimeout: 60 * time.Second, + }, + }, + } +} + +// WithPeerInfo returns ClientOptions using the provided PeerInfo +func WithPeerInfo(info PeerInfo) ClientOptions { + if info.GetTLS() != nil { + return WithTLS(info.GetTLS(), info.HTTPAddress) + } + return WithNoTLS(info.HTTPAddress) +} + +// WithDaemonConfig returns ClientOptions suitable for use by the Daemon +func WithDaemonConfig(conf DaemonConfig, address string) ClientOptions { + if conf.ClientTLS() == nil { + return WithNoTLS(address) + } + return WithTLS(conf.ClientTLS(), address) } // ToTimeStamp is a convenience function to convert a time.Duration @@ -96,10 +244,10 @@ func RandomPeer(peers []PeerInfo) PeerInfo { // RandomString returns a random alpha string of 'n' length func RandomString(n int) string { const alphanumeric = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" - var bytes = make([]byte, n) - _, _ = crand.Read(bytes) - for i, b := range bytes { - bytes[i] = alphanumeric[b%byte(len(alphanumeric))] + var buf = make([]byte, n) + _, _ = crand.Read(buf) + for i, b := range buf { + buf[i] = 
alphanumeric[b%byte(len(alphanumeric))] } - return string(bytes) + return string(buf) } diff --git a/cluster/cluster.go b/cluster/cluster.go index d3999ba..fa09540 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -19,12 +19,12 @@ package cluster import ( "context" "fmt" + "log/slog" "math/rand" - "github.com/gubernator-io/gubernator/v2" - "github.com/mailgun/holster/v4/clock" - "github.com/mailgun/holster/v4/errors" - "github.com/sirupsen/logrus" + "github.com/gubernator-io/gubernator/v3" + "github.com/kapetan-io/errors" + "github.com/kapetan-io/tackle/clock" ) const ( @@ -36,8 +36,14 @@ const ( var daemons []*gubernator.Daemon var peers []gubernator.PeerInfo -// GetRandomPeer returns a random peer from the cluster -func GetRandomPeer(dc string) gubernator.PeerInfo { +// GetRandomPeerOptions returns gubernator.ClientOptions for a random peer in the cluster +func GetRandomPeerOptions(dc string) gubernator.ClientOptions { + info := GetRandomPeerInfo(dc) + return gubernator.WithNoTLS(info.HTTPAddress) +} + +// GetRandomPeerInfo returns a random peer from the cluster +func GetRandomPeerInfo(dc string) gubernator.PeerInfo { var local []gubernator.PeerInfo for _, p := range peers { @@ -53,6 +59,23 @@ func GetRandomPeer(dc string) gubernator.PeerInfo { return local[rand.Intn(len(local))] } +// GetRandomDaemon returns a random daemon from the cluster +func GetRandomDaemon(dc string) *gubernator.Daemon { + var local []*gubernator.Daemon + + for _, d := range daemons { + if d.PeerInfo.DataCenter == dc { + local = append(local, d) + } + } + + if len(local) == 0 { + panic(fmt.Sprintf("failed to find random daemon for dc '%s'", dc)) + } + + return local[rand.Intn(len(local))] +} + // GetPeers returns a list of all peers in the cluster func GetPeers() []gubernator.PeerInfo { return peers @@ -70,7 +93,7 @@ func PeerAt(idx int) gubernator.PeerInfo { // FindOwningPeer finds the peer which owns the rate limit with the provided name and unique key func FindOwningPeer(name, key string) (gubernator.PeerInfo, error) { - p, err := daemons[0].V1Server.GetPeer(context.Background(), name+"_"+key) + p, err := daemons[0].Service.GetPeer(context.Background(), name+"_"+key) if err != nil { return gubernator.PeerInfo{}, err } @@ -79,13 +102,13 @@ func FindOwningPeer(name, key string) (gubernator.PeerInfo, error) { // FindOwningDaemon finds the daemon which owns the rate limit with the provided name and unique key func FindOwningDaemon(name, key string) (*gubernator.Daemon, error) { - p, err := daemons[0].V1Server.GetPeer(context.Background(), name+"_"+key) + p, err := daemons[0].Service.GetPeer(context.Background(), name+"_"+key) if err != nil { return &gubernator.Daemon{}, err } for i, d := range daemons { - if d.PeerInfo.GRPCAddress == p.Info().GRPCAddress { + if d.Config().HTTPListenAddress == p.Info().HTTPAddress { return daemons[i], nil } } @@ -102,7 +125,7 @@ func ListNonOwningDaemons(name, key string) ([]*gubernator.Daemon, error) { var daemons []*gubernator.Daemon for _, d := range GetDaemons() { - if d.PeerInfo.GRPCAddress != owner.PeerInfo.GRPCAddress { + if d.Config().HTTPListenAddress != owner.Config().HTTPListenAddress { daemons = append(daemons, d) } } @@ -121,16 +144,15 @@ func NumOfDaemons() int { // Start a local cluster of gubernator servers func Start(numInstances int) error { - // Ideally we should let the socket choose the port, but then + // Ideally, we should let the socket choose the port, but then // some things like the logger will not be set correctly. 
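// ---------------------------------------------------------------------------
// Illustrative aside (not part of this diff): how a test might combine the
// cluster helpers above with the new client. The cluster size and the empty
// data center argument are example assumptions; peers created by Start() leave
// the data center unset.
// ---------------------------------------------------------------------------
package cluster_test

import (
	"context"
	"testing"

	guber "github.com/gubernator-io/gubernator/v3"
	"github.com/gubernator-io/gubernator/v3/cluster"
	"github.com/stretchr/testify/require"
)

func TestHealthCheckRandomPeer(t *testing.T) {
	require.NoError(t, cluster.Start(3))
	t.Cleanup(func() { cluster.Stop(context.Background()) })

	// Dial whichever peer the helper picks from the local cluster.
	client, err := guber.NewClient(cluster.GetRandomPeerOptions(""))
	require.NoError(t, err)

	var resp guber.HealthCheckResponse
	require.NoError(t, client.HealthCheck(context.Background(), &resp))
}
// ------------------------------ end of aside --------------------------------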
var peers []gubernator.PeerInfo port := 1111 for i := 0; i < numInstances; i++ { peers = append(peers, gubernator.PeerInfo{ HTTPAddress: fmt.Sprintf("localhost:%d", port), - GRPCAddress: fmt.Sprintf("localhost:%d", port+1), }) - port += 2 + port += 1 } return StartWith(peers) } @@ -138,7 +160,9 @@ func Start(numInstances int) error { // Restart the cluster func Restart(ctx context.Context) error { for i := 0; i < len(daemons); i++ { - daemons[i].Close() + if err := daemons[i].Close(ctx); err != nil { + return err + } if err := daemons[i].Start(ctx); err != nil { return err } @@ -148,15 +172,16 @@ func Restart(ctx context.Context) error { } // StartWith a local cluster with specific addresses -func StartWith(localPeers []gubernator.PeerInfo, opts ...option) error { +func StartWith(localPeers []gubernator.PeerInfo, opts ...Option) error { for _, peer := range localPeers { ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) cfg := gubernator.DaemonConfig{ - Logger: logrus.WithField("instance", peer.GRPCAddress), - InstanceID: peer.GRPCAddress, - GRPCListenAddress: peer.GRPCAddress, + Logger: slog.Default().With("instance", peer.HTTPAddress), + InstanceID: peer.HTTPAddress, HTTPListenAddress: peer.HTTPAddress, + AdvertiseAddress: peer.HTTPAddress, DataCenter: peer.DataCenter, + CacheProvider: "otter", Behaviors: gubernator.BehaviorConfig{ // Suitable for testing but not production GlobalSyncWait: clock.Millisecond * 50, @@ -164,24 +189,18 @@ func StartWith(localPeers []gubernator.PeerInfo, opts ...option) error { BatchTimeout: clock.Second * 5, }, } + for _, opt := range opts { opt.Apply(&cfg) } d, err := gubernator.SpawnDaemon(ctx, cfg) cancel() if err != nil { - return errors.Wrapf(err, "while starting server for addr '%s'", peer.GRPCAddress) - } - - p := gubernator.PeerInfo{ - GRPCAddress: d.GRPCListeners[0].Addr().String(), - HTTPAddress: d.HTTPListener.Addr().String(), - DataCenter: peer.DataCenter, + return fmt.Errorf("while starting server for addr '%s': %w", peer.HTTPAddress, err) } - d.PeerInfo = p // Add the peers and daemons to the package level variables - peers = append(peers, p) + peers = append(peers, d.PeerInfo) daemons = append(daemons, d) } @@ -193,15 +212,15 @@ func StartWith(localPeers []gubernator.PeerInfo, opts ...option) error { } // Stop all daemons in the cluster -func Stop() { +func Stop(ctx context.Context) { for _, d := range daemons { - d.Close() + _ = d.Close(ctx) } peers = nil daemons = nil } -type option interface { +type Option interface { Apply(cfg *gubernator.DaemonConfig) } @@ -214,6 +233,6 @@ func (o *eventChannelOption) Apply(cfg *gubernator.DaemonConfig) { } // WithEventChannel sets EventChannel to Gubernator config. -func WithEventChannel(eventChannel chan<- gubernator.HitEvent) option { +func WithEventChannel(eventChannel chan<- gubernator.HitEvent) Option { return &eventChannelOption{eventChannel: eventChannel} } diff --git a/cluster/cluster_test.go b/cluster/cluster_test.go index 16f0c7f..4e7f9d5 100644 --- a/cluster/cluster_test.go +++ b/cluster/cluster_test.go @@ -17,10 +17,11 @@ limitations under the License. 
package cluster_test import ( + "context" "testing" - gubernator "github.com/gubernator-io/gubernator/v2" - "github.com/gubernator-io/gubernator/v2/cluster" + "github.com/gubernator-io/gubernator/v3" + "github.com/gubernator-io/gubernator/v3/cluster" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/goleak" @@ -32,7 +33,9 @@ func TestStartMultipleInstances(t *testing.T) { }) err := cluster.Start(2) require.NoError(t, err) - t.Cleanup(cluster.Stop) + t.Cleanup(func() { + cluster.Stop(context.Background()) + }) assert.Equal(t, 2, len(cluster.GetPeers())) assert.Equal(t, 2, len(cluster.GetDaemons())) @@ -41,7 +44,7 @@ func TestStartMultipleInstances(t *testing.T) { func TestStartOneInstance(t *testing.T) { err := cluster.Start(1) require.NoError(t, err) - defer cluster.Stop() + defer cluster.Stop(context.Background()) assert.Equal(t, 1, len(cluster.GetPeers())) assert.Equal(t, 1, len(cluster.GetDaemons())) @@ -49,28 +52,28 @@ func TestStartOneInstance(t *testing.T) { func TestStartMultipleDaemons(t *testing.T) { peers := []gubernator.PeerInfo{ - {GRPCAddress: "localhost:1111", HTTPAddress: "localhost:1112"}, - {GRPCAddress: "localhost:2222", HTTPAddress: "localhost:2221"}} + {HTTPAddress: "localhost:1111"}, + {HTTPAddress: "localhost:2222"}} err := cluster.StartWith(peers) require.NoError(t, err) - defer cluster.Stop() + defer cluster.Stop(context.Background()) wantPeers := []gubernator.PeerInfo{ - {GRPCAddress: "127.0.0.1:1111", HTTPAddress: "127.0.0.1:1112"}, - {GRPCAddress: "127.0.0.1:2222", HTTPAddress: "127.0.0.1:2221"}, + {HTTPAddress: "127.0.0.1:1111"}, + {HTTPAddress: "127.0.0.1:2222"}, } daemons := cluster.GetDaemons() assert.Equal(t, wantPeers, cluster.GetPeers()) assert.Equal(t, 2, len(daemons)) - assert.Equal(t, "127.0.0.1:1111", daemons[0].GRPCListeners[0].Addr().String()) - assert.Equal(t, "127.0.0.1:2222", daemons[1].GRPCListeners[0].Addr().String()) - assert.Equal(t, "127.0.0.1:2222", cluster.DaemonAt(1).GRPCListeners[0].Addr().String()) - assert.Equal(t, "127.0.0.1:2222", cluster.PeerAt(1).GRPCAddress) + assert.Equal(t, "127.0.0.1:1111", daemons[0].Listener.Addr().String()) + assert.Equal(t, "127.0.0.1:2222", daemons[1].Listener.Addr().String()) + assert.Equal(t, "127.0.0.1:2222", cluster.DaemonAt(1).Listener.Addr().String()) + assert.Equal(t, "127.0.0.1:2222", cluster.PeerAt(1).HTTPAddress) } func TestStartWithInvalidPeer(t *testing.T) { - err := cluster.StartWith([]gubernator.PeerInfo{{GRPCAddress: "1111"}}) + err := cluster.StartWith([]gubernator.PeerInfo{{HTTPAddress: "1111"}}) assert.NotNil(t, err) assert.Nil(t, cluster.GetPeers()) assert.Nil(t, cluster.GetDaemons()) diff --git a/cmd/gubernator-cli/main.go b/cmd/gubernator-cli/main.go index 3c05912..f0faa03 100644 --- a/cmd/gubernator-cli/main.go +++ b/cmd/gubernator-cli/main.go @@ -20,27 +20,27 @@ import ( "context" "flag" "fmt" + "log/slog" "math/rand" "os" "strings" "time" "github.com/davecgh/go-spew/spew" - guber "github.com/gubernator-io/gubernator/v2" - "github.com/mailgun/holster/v4/clock" - "github.com/mailgun/holster/v4/errors" - "github.com/mailgun/holster/v4/setter" - "github.com/mailgun/holster/v4/syncutil" - "github.com/mailgun/holster/v4/tracing" - "github.com/sirupsen/logrus" + guber "github.com/gubernator-io/gubernator/v3" + "github.com/gubernator-io/gubernator/v3/tracing" + "github.com/kapetan-io/tackle/clock" + "github.com/kapetan-io/tackle/set" + "github.com/kapetan-io/tackle/wait" "go.opentelemetry.io/otel/attribute" + sdktrace "go.opentelemetry.io/otel/sdk/trace" 
"go.opentelemetry.io/otel/trace" "golang.org/x/time/rate" ) var ( - log *logrus.Logger - configFile, grpcAddress string + log *slog.Logger + configFile, httpAddress string concurrency uint64 timeout time.Duration checksPerRequest uint64 @@ -49,9 +49,8 @@ var ( ) func main() { - log = logrus.StandardLogger() flag.StringVar(&configFile, "config", "", "Environment config file") - flag.StringVar(&grpcAddress, "e", "", "Gubernator GRPC endpoint address") + flag.StringVar(&httpAddress, "e", "", "Gubernator HTTP endpoint address") flag.Uint64Var(&concurrency, "concurrency", 1, "Concurrent threads (default 1)") flag.DurationVar(&timeout, "timeout", 100*time.Millisecond, "Request timeout (default 100ms)") flag.Uint64Var(&checksPerRequest, "checks", 1, "Rate checks per request (default 1)") @@ -59,68 +58,74 @@ func main() { flag.BoolVar(&quiet, "q", false, "Quiet logging") flag.Parse() - if quiet { - log.SetLevel(logrus.ErrorLevel) - } + log = slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + Level: func() slog.Level { + if quiet { + return slog.LevelError + } + return slog.LevelInfo + }(), + })) // Initialize tracing. res, err := tracing.NewResource("gubernator-cli", "") if err != nil { - log.WithError(err).Fatal("Error in tracing.NewResource") + log.LogAttrs(context.TODO(), slog.LevelError, "Error in tracing.NewResource", + guber.ErrAttr(err), + ) + return } ctx := context.Background() - err = tracing.InitTracing(ctx, - "github.com/gubernator-io/gubernator/v2/cmd/gubernator-cli", - tracing.WithResource(res), - ) + shutdown, err := tracing.InitTracing(ctx, log, "github.com/gubernator-io/gubernator/v3/cmd/gubernator-cli", + sdktrace.WithResource(res)) if err != nil { - log.WithError(err).Warn("Error in tracing.InitTracing") + log.LogAttrs(context.TODO(), slog.LevelWarn, "Error in tracing.InitTracing", + guber.ErrAttr(err), + ) } + defer func() { _ = shutdown(ctx) }() // Print startup message. - startCtx := tracing.StartScope(ctx) + startCtx := tracing.StartScope(ctx, "main") argsMsg := fmt.Sprintf("Command line: %s", strings.Join(os.Args[1:], " ")) log.Info(argsMsg) tracing.EndScope(startCtx, nil) - var client guber.V1Client - err = tracing.CallScope(ctx, func(ctx context.Context) error { - // Print startup message. - cmdLine := strings.Join(os.Args[1:], " ") - logrus.WithContext(ctx).Info("Command line: " + cmdLine) + var client guber.Client + // Print startup message. 
+ cmdLine := strings.Join(os.Args[1:], " ") + log.LogAttrs(ctx, slog.LevelInfo, "Command Line", + slog.String("cmdLine", cmdLine), + ) - configFileReader, err := os.Open(configFile) - if err != nil { - return fmt.Errorf("while opening config file: %s", err) - } - conf, err := guber.SetupDaemonConfig(log, configFileReader) - if err != nil { - return err - } - setter.SetOverride(&conf.GRPCListenAddress, grpcAddress) + configFileReader, err := os.Open(configFile) + exitOnError(err, "Error opening config file") - if configFile == "" && grpcAddress == "" && os.Getenv("GUBER_GRPC_ADDRESS") == "" { - return errors.New("please provide a GRPC endpoint via -e or from a config " + - "file via -config or set the env GUBER_GRPC_ADDRESS") - } + conf, err := guber.SetupDaemonConfig(log, configFileReader) + exitOnError(err, "Error parsing config file") - err = guber.SetupTLS(conf.TLS) - if err != nil { - return err - } + set.Override(&conf.HTTPListenAddress, httpAddress) + + if configFile == "" && httpAddress == "" && os.Getenv("GUBER_GRPC_ADDRESS") == "" { + log.Error("please provide a GRPC endpoint via -e or from a config " + + "file via -config or set the env GUBER_GRPC_ADDRESS") + os.Exit(1) + } - log.WithContext(ctx).Infof("Connecting to '%s'...", conf.GRPCListenAddress) - client, err = guber.DialV1Server(conf.GRPCListenAddress, conf.ClientTLS()) - return err - }) + err = guber.SetupTLS(conf.TLS) + exitOnError(err, "Error setting up TLS") - checkErr(err) + log.LogAttrs(context.TODO(), slog.LevelInfo, "Connecting to", + slog.String("address", conf.HTTPListenAddress), + ) + client, err = guber.NewClient(guber.WithDaemonConfig(conf, conf.HTTPListenAddress)) + exitOnError(err, "Error creating client") // Generate a selection of rate limits with random limits. - var rateLimits []*guber.RateLimitReq + var rateLimits []*guber.RateLimitRequest for i := 0; i < 2000; i++ { - rateLimits = append(rateLimits, &guber.RateLimitReq{ + rateLimits = append(rateLimits, &guber.RateLimitRequest{ Name: fmt.Sprintf("gubernator-cli-%d", i), UniqueKey: guber.RandomString(10), Hits: 1, @@ -131,97 +136,90 @@ func main() { }) } - fan := syncutil.NewFanOut(int(concurrency)) + fan := wait.NewFanOut(int(concurrency)) var limiter *rate.Limiter if reqRate > 0 { l := rate.Limit(reqRate) - log.WithField("reqRate", reqRate).Info("") + log.LogAttrs(context.TODO(), slog.LevelInfo, "rate", + slog.Float64("rate", reqRate), + ) limiter = rate.NewLimiter(l, 1) } // Replay requests in endless loop. 
for { for i := int(0); i < len(rateLimits); i += int(checksPerRequest) { - req := &guber.GetRateLimitsReq{ + req := &guber.CheckRateLimitsRequest{ Requests: rateLimits[i:min(i+int(checksPerRequest), len(rateLimits))], } - fan.Run(func(obj interface{}) error { - req := obj.(*guber.GetRateLimitsReq) - + fan.Run(func() error { if reqRate > 0 { _ = limiter.Wait(ctx) } - sendRequest(ctx, client, req) - return nil - }, req) + }) } } } -func min(a, b int) int { - if a <= b { - return a - } - return b -} - -func checkErr(err error) { - if err != nil { - log.Fatal(err.Error()) - } -} - func randInt(min, max int) int { return rand.Intn(max-min) + min } -func sendRequest(ctx context.Context, client guber.V1Client, req *guber.GetRateLimitsReq) { - ctx = tracing.StartScope(ctx) +func sendRequest(ctx context.Context, client guber.Client, req *guber.CheckRateLimitsRequest) { + ctx = tracing.StartScope(ctx, "sendRequest") defer tracing.EndScope(ctx, nil) ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() // Now hit our cluster with the rate limits - resp, err := client.GetRateLimits(ctx, req) - cancel() - if err != nil { - log.WithContext(ctx).WithError(err).Error("Error in client.GetRateLimits") + var resp guber.CheckRateLimitsResponse + if err := client.CheckRateLimits(ctx, req, &resp); err != nil { + log.LogAttrs(ctx, slog.LevelError, "Error in client.GetRateLimits", + guber.ErrAttr(err), + ) return } // Sanity checks. - if resp == nil { - log.WithContext(ctx).Error("Response object is unexpectedly nil") - return - } if resp.Responses == nil { - log.WithContext(ctx).Error("Responses array is unexpectedly nil") + log.LogAttrs(ctx, slog.LevelError, "Responses array is unexpectedly nil") return } - // Check for overlimit response. - overlimit := false + // Check for over limit response. + overLimit := false for itemNum, resp := range resp.Responses { if resp.Status == guber.Status_OVER_LIMIT { - overlimit = true - log.WithContext(ctx).WithField("name", req.Requests[itemNum].Name). - Info("Overlimit!") + overLimit = true + log.LogAttrs(ctx, slog.LevelInfo, "Overlimit!", + slog.String("name", req.Requests[itemNum].Name), + ) } } - if overlimit { + if overLimit { span := trace.SpanFromContext(ctx) span.SetAttributes( attribute.Bool("overlimit", true), ) if !quiet { - dumpResp := spew.Sdump(resp) - log.WithContext(ctx).Info(dumpResp) + dumpResp := spew.Sdump(&resp) + log.LogAttrs(ctx, slog.LevelInfo, "Dump", + slog.String("value", dumpResp), + ) } } } + +func exitOnError(err error, msg string) { + if err != nil { + fmt.Printf("%s: %s\n", msg, err) + os.Exit(1) + } +} diff --git a/cmd/gubernator-cluster/main.go b/cmd/gubernator-cluster/main.go index 49f23ca..4169044 100644 --- a/cmd/gubernator-cluster/main.go +++ b/cmd/gubernator-cluster/main.go @@ -17,26 +17,25 @@ limitations under the License. 
package main import ( + "context" "fmt" "os" "os/signal" - "github.com/gubernator-io/gubernator/v2" - "github.com/gubernator-io/gubernator/v2/cluster" - "github.com/sirupsen/logrus" + "github.com/gubernator-io/gubernator/v3" + "github.com/gubernator-io/gubernator/v3/cluster" ) // Start a cluster of gubernator instances for use in testing clients func main() { - logrus.SetLevel(logrus.InfoLevel) // Start a local cluster err := cluster.StartWith([]gubernator.PeerInfo{ - {GRPCAddress: "127.0.0.1:9990", HTTPAddress: "127.0.0.1:9980"}, - {GRPCAddress: "127.0.0.1:9991", HTTPAddress: "127.0.0.1:9981"}, - {GRPCAddress: "127.0.0.1:9992", HTTPAddress: "127.0.0.1:9982"}, - {GRPCAddress: "127.0.0.1:9993", HTTPAddress: "127.0.0.1:9983"}, - {GRPCAddress: "127.0.0.1:9994", HTTPAddress: "127.0.0.1:9984"}, - {GRPCAddress: "127.0.0.1:9995", HTTPAddress: "127.0.0.1:9985"}, + {HTTPAddress: "127.0.0.1:9980"}, + {HTTPAddress: "127.0.0.1:9981"}, + {HTTPAddress: "127.0.0.1:9982"}, + {HTTPAddress: "127.0.0.1:9983"}, + {HTTPAddress: "127.0.0.1:9984"}, + {HTTPAddress: "127.0.0.1:9985"}, }) if err != nil { panic(err) @@ -49,7 +48,7 @@ func main() { signal.Notify(c, os.Interrupt) for sig := range c { if sig == os.Interrupt { - cluster.Stop() + cluster.Stop(context.Background()) os.Exit(0) } } diff --git a/cmd/gubernator/main.go b/cmd/gubernator/main.go index 3b20856..cbe5d75 100644 --- a/cmd/gubernator/main.go +++ b/cmd/gubernator/main.go @@ -21,21 +21,23 @@ import ( "flag" "fmt" "io" + "log/slog" "os" "os/signal" "runtime" "strings" "syscall" - "github.com/gubernator-io/gubernator/v2" - "github.com/mailgun/holster/v4/tracing" - "github.com/sirupsen/logrus" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + + "github.com/gubernator-io/gubernator/v3" + "github.com/gubernator-io/gubernator/v3/tracing" "go.opentelemetry.io/otel/sdk/resource" semconv "go.opentelemetry.io/otel/semconv/v1.24.0" "k8s.io/klog/v2" ) -var log = logrus.WithField("category", "gubernator") +var log = slog.Default().With("category", "gubernator") var Version = "dev-build" var tracerCloser io.Closer @@ -50,7 +52,11 @@ func main() { func Main(ctx context.Context) error { var configFile string - logrus.Infof("Gubernator %s (%s/%s)", Version, runtime.GOARCH, runtime.GOOS) + log.LogAttrs(ctx, slog.LevelInfo, "Gubernator", + slog.String("version", Version), + slog.String("arch", runtime.GOARCH), + slog.String("os", runtime.GOOS), + ) flags := flag.NewFlagSet("gubernator", flag.ContinueOnError) flags.SetOutput(io.Discard) flags.StringVar(&configFile, "config", "", "environment config file") @@ -73,7 +79,10 @@ func Main(ctx context.Context) error { semconv.ServiceInstanceID(gubernator.GetInstanceID()), )) if err != nil { - log.WithError(err).Fatal("during tracing.NewResource()") + log.LogAttrs(ctx, slog.LevelError, "during tracing.NewResource()", + gubernator.ErrAttr(err), + ) + return err } defer func() { if tracerCloser != nil { @@ -82,13 +91,14 @@ func Main(ctx context.Context) error { }() // Initialize tracing. 
- err = tracing.InitTracing(ctx, - "github.com/gubernator-io/gubernator/v2", - tracing.WithLevel(gubernator.GetTracingLevel()), - tracing.WithResource(res), - ) + shutdown, err := tracing.InitTracing(ctx, log, + "github.com/gubernator-io/gubernator/v3", + sdktrace.WithResource(res)) if err != nil { - log.WithError(err).Fatal("during tracing.InitTracing()") + log.LogAttrs(ctx, slog.LevelError, "during tracing.InitTracing()", + gubernator.ErrAttr(err), + ) + return err } var configFileReader io.Reader @@ -96,11 +106,14 @@ func Main(ctx context.Context) error { if configFile != "" { configFileReader, err = os.Open(configFile) if err != nil { - log.WithError(err).Fatal("while opening config file") + log.LogAttrs(ctx, slog.LevelError, "while opening config file", + gubernator.ErrAttr(err), + ) + return err } } - conf, err := gubernator.SetupDaemonConfig(logrus.StandardLogger(), configFileReader) + conf, err := gubernator.SetupDaemonConfig(slog.Default(), configFileReader) if err != nil { return fmt.Errorf("while collecting daemon config: %w", err) } @@ -117,8 +130,8 @@ func Main(ctx context.Context) error { select { case <-c: log.Info("caught signal; shutting down") - daemon.Close() - _ = tracing.CloseTracing(context.Background()) + _ = daemon.Close(context.Background()) + _ = shutdown(ctx) return nil case <-ctx.Done(): return ctx.Err() diff --git a/cmd/gubernator/main_test.go b/cmd/gubernator/main_test.go index 1ca08e3..1b51fa1 100644 --- a/cmd/gubernator/main_test.go +++ b/cmd/gubernator/main_test.go @@ -19,7 +19,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/net/proxy" - cli "github.com/gubernator-io/gubernator/v2/cmd/gubernator" + cli "github.com/gubernator-io/gubernator/v3/cmd/gubernator" ) var cliRunning = flag.Bool("test_cli_running", false, "True if running as a child process; used by TestCLI") @@ -46,11 +46,10 @@ func TestCLI(t *testing.T) { { name: "Should start with no config provided", env: []string{ - "GUBER_GRPC_ADDRESS=localhost:1050", - "GUBER_HTTP_ADDRESS=localhost:1051", + "GUBER_HTTP_ADDRESS=localhost:1050", }, args: []string{}, - contains: "HTTP Gateway Listening on", + contains: "HTTP Listening", }, } for _, tt := range tests { @@ -79,9 +78,9 @@ func TestCLI(t *testing.T) { time.Sleep(time.Second * 1) err = c.Process.Signal(syscall.SIGTERM) - require.NoError(t, err) <-waitCh + require.NoError(t, err, out.String()) assert.Contains(t, out.String(), tt.contains) }) } diff --git a/cmd/healthcheck/main.go b/cmd/healthcheck/main.go index 3420e69..e27062f 100644 --- a/cmd/healthcheck/main.go +++ b/cmd/healthcheck/main.go @@ -23,7 +23,7 @@ import ( "net/http" "os" - guber "github.com/gubernator-io/gubernator/v2" + guber "github.com/gubernator-io/gubernator/v3" ) func main() { @@ -31,7 +31,8 @@ func main() { if url == "" { url = "localhost:1050" } - resp, err := http.DefaultClient.Get(fmt.Sprintf("http://%s/v1/HealthCheck", url)) + + resp, err := http.DefaultClient.Get(fmt.Sprintf("http://%s/healthz", url)) if err != nil { panic(err) } @@ -42,7 +43,7 @@ func main() { panic(err) } - var hc guber.HealthCheckResp + var hc guber.HealthCheckResponse if err := json.Unmarshal(body, &hc); err != nil { panic(err) } diff --git a/config.go b/config.go index d9a6417..ee2e55a 100644 --- a/config.go +++ b/config.go @@ -18,30 +18,29 @@ package gubernator import ( "bufio" + "context" "crypto/rand" "crypto/tls" "crypto/x509" "encoding/hex" "fmt" "io" + "log/slog" "net" "os" "runtime" + "slices" "strconv" "strings" "time" "github.com/davecgh/go-spew/spew" - 
"github.com/mailgun/holster/v4/clock" - "github.com/mailgun/holster/v4/setter" - "github.com/mailgun/holster/v4/slice" - "github.com/mailgun/holster/v4/tracing" + "github.com/kapetan-io/tackle/clock" + "github.com/kapetan-io/tackle/set" "github.com/pkg/errors" "github.com/segmentio/fasthash/fnv1" "github.com/segmentio/fasthash/fnv1a" - "github.com/sirupsen/logrus" etcd "go.etcd.io/etcd/client/v3" - "google.golang.org/grpc" ) // BehaviorConfig controls the handling of rate limits in the cluster @@ -72,14 +71,14 @@ type BehaviorConfig struct { type Config struct { InstanceID string - // (Required) A list of GRPC servers to register our instance with - GRPCServers []*grpc.Server + // (Optional) The PeerClient gubernator should use when making requests to other peers in the cluster. + PeerClientFactory func(PeerInfo) PeerClient // (Optional) Adjust how gubernator behaviors are configured Behaviors BehaviorConfig // (Optional) The cache implementation - CacheFactory func(maxSize int) Cache + CacheFactory func(maxSize int) (Cache, error) // (Optional) A persistent store implementation. Allows the implementor the ability to store the rate limits this // instance of gubernator owns. It's up to the implementor to decide what rate limits to persist. @@ -107,12 +106,6 @@ type Config struct { // (Optional) A Logger which implements the declared logger interface (typically *logrus.Entry) Logger FieldLogger - // (Optional) The TLS config used when connecting to gubernator peers - PeerTLS *tls.Config - - // (Optional) If true, will emit traces for GRPC client requests to other peers - PeerTraceGRPC bool - // (Optional) The number of go routine workers used to process concurrent rate limit requests // Default is set to number of CPUs. Workers int @@ -125,31 +118,35 @@ type Config struct { } type HitEvent struct { - Request *RateLimitReq - Response *RateLimitResp + Request *RateLimitRequest + Response *RateLimitResponse } func (c *Config) SetDefaults() error { - setter.SetDefault(&c.Behaviors.BatchTimeout, time.Millisecond*500) - setter.SetDefault(&c.Behaviors.BatchLimit, maxBatchSize) - setter.SetDefault(&c.Behaviors.BatchWait, time.Microsecond*500) + set.Default(&c.Behaviors.BatchTimeout, time.Millisecond*500) + set.Default(&c.Behaviors.BatchLimit, maxBatchSize) + set.Default(&c.Behaviors.BatchWait, time.Microsecond*500) - setter.SetDefault(&c.Behaviors.GlobalTimeout, time.Millisecond*500) - setter.SetDefault(&c.Behaviors.GlobalBatchLimit, maxBatchSize) - setter.SetDefault(&c.Behaviors.GlobalSyncWait, time.Millisecond*100) + set.Default(&c.Behaviors.GlobalTimeout, time.Millisecond*500) + set.Default(&c.Behaviors.GlobalBatchLimit, maxBatchSize) + set.Default(&c.Behaviors.GlobalSyncWait, time.Millisecond*100) - setter.SetDefault(&c.Behaviors.GlobalPeerRequestsConcurrency, 100) + set.Default(&c.Behaviors.GlobalPeerRequestsConcurrency, 100) - setter.SetDefault(&c.LocalPicker, NewReplicatedConsistentHash(nil, defaultReplicas)) - setter.SetDefault(&c.RegionPicker, NewRegionPicker(nil)) + set.Default(&c.LocalPicker, NewReplicatedConsistentHash(nil, DefaultReplicas)) + set.Default(&c.RegionPicker, NewRegionPicker(nil)) - setter.SetDefault(&c.CacheSize, 50_000) - setter.SetDefault(&c.Workers, runtime.NumCPU()) - setter.SetDefault(&c.Logger, logrus.New().WithField("category", "gubernator")) + set.Default(&c.CacheSize, 50_000) + set.Default(&c.Workers, runtime.NumCPU()) + set.Default(&c.InstanceID, GetInstanceID()) + set.Default(&c.Logger, slog.Default().With( + "instance", c.InstanceID, + "category", "gubernator", 
+ )) if c.CacheFactory == nil { - c.CacheFactory = func(maxSize int) Cache { - return NewLRUCache(maxSize) + c.CacheFactory = func(maxSize int) (Cache, error) { + return NewOtterCache(maxSize) } } @@ -158,9 +155,10 @@ func (c *Config) SetDefaults() error { } // Make a copy of the TLS config in case our caller decides to make changes - if c.PeerTLS != nil { - c.PeerTLS = c.PeerTLS.Clone() - } + // TODO(thrawn01): Fix Peer TLS + //if c.PeerTLS != nil { + // c.PeerTLS = c.PeerTLS.Clone() + //} return nil } @@ -170,15 +168,29 @@ type PeerInfo struct { DataCenter string `json:"data-center"` // (Optional) The http address:port of the peer HTTPAddress string `json:"http-address"` - // (Required) The grpc address:port of the peer - GRPCAddress string `json:"grpc-address"` // (Optional) Is true if PeerInfo is for this instance of gubernator IsOwner bool `json:"is-owner,omitempty"` + // tls is private so that will not be serialized if marshalled to json, yaml, etc... + tls *tls.Config } // HashKey returns the hash key used to identify this peer in the Picker. -func (p PeerInfo) HashKey() string { - return p.GRPCAddress +func (p *PeerInfo) HashKey() string { + return p.HTTPAddress +} + +// GetTLS returns the TLS config for this peer if it exists +func (p *PeerInfo) GetTLS() *tls.Config { + return p.tls +} + +// SetTLS sets the TLS config +func (p *PeerInfo) SetTLS(t *tls.Config) { + // SetTLS() is called by Daemon when SetPeers() is called. If a user has already provided a tls config when + // this PeerInfo was created, then we should not overwrite the user provided config. + if p.tls == nil { + p.tls = t + } } type UpdateFunc func([]PeerInfo) @@ -186,9 +198,6 @@ type UpdateFunc func([]PeerInfo) var DebugEnabled = false type DaemonConfig struct { - // (Required) The `address:port` that will accept GRPC requests - GRPCListenAddress string - // (Required) The `address:port` that will accept HTTP requests HTTPListenAddress string @@ -199,12 +208,8 @@ type DaemonConfig struct { // provide client certificate but you want to enforce mTLS in other RPCs (like in K8s) HTTPStatusListenAddress string - // (Optional) Defines the max age connection from client in seconds. - // Default is infinity - GRPCMaxConnectionAgeSeconds int - // (Optional) The `address:port` that is advertised to other Gubernator peers. - // Defaults to `GRPCListenAddress` + // Defaults to `HTTPListenAddress` AdvertiseAddress string // (Optional) The number of items in the cache. Defaults to 50,000 @@ -243,6 +248,16 @@ type DaemonConfig struct { // (Optional) A Logger which implements the declared logger interface (typically *logrus.Entry) Logger FieldLogger + // (Optional) A loader from a persistent store. Allows the implementor the ability to load and save + // the contents of the cache when the gubernator instance is started and stopped + Loader Loader + + // (Optional) A persistent store implementation. Allows the implementor the ability to store the rate limits this + // instance of gubernator owns. It's up to the implementor to decide what rate limits to persist. + // For instance, an implementor might only persist rate limits that have an expiration of + // longer than 1 hour. + Store Store + // (Optional) TLS Configuration; SpawnDaemon() will modify the passed TLS config in an // attempt to build a complete TLS config if one is not provided. 
TLS *TLSConfig @@ -253,12 +268,11 @@ type DaemonConfig struct { // (Optional) Instance ID which is a unique id that identifies this instance of gubernator InstanceID string - // (Optional) TraceLevel sets the tracing level, this controls the number of spans included in a single trace. - // Valid options are (tracing.InfoLevel, tracing.DebugLevel) Defaults to tracing.InfoLevel - TraceLevel tracing.Level - // (Optional) EventChannel receives hit events EventChannel chan<- HitEvent + + // (Optional) CacheProvider specifies which cache implementation to store rate limits in + CacheProvider string } func (d *DaemonConfig) ClientTLS() *tls.Config { @@ -277,8 +291,8 @@ func (d *DaemonConfig) ServerTLS() *tls.Config { // SetupDaemonConfig returns a DaemonConfig object that is the result of merging the lines // in the provided configFile and the environment variables. See `example.conf` for all available config options and their descriptions. -func SetupDaemonConfig(logger *logrus.Logger, configFile io.Reader) (DaemonConfig, error) { - log := logrus.NewEntry(logger) +func SetupDaemonConfig(logger *slog.Logger, configFile io.Reader) (DaemonConfig, error) { + log := logger.With() var conf DaemonConfig var logLevel string var logFormat string @@ -286,56 +300,55 @@ func SetupDaemonConfig(logger *logrus.Logger, configFile io.Reader) (DaemonConfi var err error if configFile != nil { - log.Infof("Loading env config: %s", configFile) if err := fromEnvFile(log, configFile); err != nil { return conf, err } } // Log config - setter.SetDefault(&logFormat, os.Getenv("GUBER_LOG_FORMAT")) + set.Default(&logFormat, os.Getenv("GUBER_LOG_FORMAT")) + slogLevel := &slog.LevelVar{} if logFormat != "" { switch logFormat { case "json": - logger.SetFormatter(&logrus.JSONFormatter{}) + log = slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{ + Level: slogLevel, + })) case "text": - logger.SetFormatter(&logrus.TextFormatter{}) + log = slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + Level: slogLevel, + })) default: return conf, errors.New("GUBER_LOG_FORMAT is invalid; expected value is either json or text") } } - setter.SetDefault(&DebugEnabled, getEnvBool(log, "GUBER_DEBUG")) - setter.SetDefault(&logLevel, os.Getenv("GUBER_LOG_LEVEL")) + set.Default(&DebugEnabled, getEnvBool(log, "GUBER_DEBUG")) + set.Default(&logLevel, os.Getenv("GUBER_LOG_LEVEL")) if DebugEnabled { - logger.SetLevel(logrus.DebugLevel) + slogLevel.Set(slog.LevelDebug) log.Debug("Debug enabled") } else if logLevel != "" { - logrusLogLevel, err := logrus.ParseLevel(logLevel) + err = slogLevel.UnmarshalText([]byte(logLevel)) if err != nil { return conf, errors.Wrap(err, "invalid log level") } - - logger.SetLevel(logrusLogLevel) } // Main config - setter.SetDefault(&conf.GRPCListenAddress, os.Getenv("GUBER_GRPC_ADDRESS"), - fmt.Sprintf("%s:1051", LocalHost())) - setter.SetDefault(&conf.HTTPListenAddress, os.Getenv("GUBER_HTTP_ADDRESS"), + set.Default(&conf.HTTPListenAddress, os.Getenv("GUBER_HTTP_ADDRESS"), fmt.Sprintf("%s:1050", LocalHost())) - setter.SetDefault(&conf.InstanceID, GetInstanceID()) - setter.SetDefault(&conf.HTTPStatusListenAddress, os.Getenv("GUBER_STATUS_HTTP_ADDRESS"), "") - setter.SetDefault(&conf.GRPCMaxConnectionAgeSeconds, getEnvInteger(log, "GUBER_GRPC_MAX_CONN_AGE_SEC"), 0) - setter.SetDefault(&conf.CacheSize, getEnvInteger(log, "GUBER_CACHE_SIZE"), 50_000) - setter.SetDefault(&conf.Workers, getEnvInteger(log, "GUBER_WORKER_COUNT"), 0) - setter.SetDefault(&conf.AdvertiseAddress, 
os.Getenv("GUBER_ADVERTISE_ADDRESS"), conf.GRPCListenAddress) - setter.SetDefault(&conf.DataCenter, os.Getenv("GUBER_DATA_CENTER"), "") - setter.SetDefault(&conf.MetricFlags, getEnvMetricFlags(log, "GUBER_METRIC_FLAGS")) + set.Default(&conf.InstanceID, GetInstanceID()) + set.Default(&conf.HTTPStatusListenAddress, os.Getenv("GUBER_STATUS_HTTP_ADDRESS"), "") + set.Default(&conf.CacheSize, getEnvInteger(log, "GUBER_CACHE_SIZE"), 50_000) + set.Default(&conf.Workers, getEnvInteger(log, "GUBER_WORKER_COUNT"), 0) + set.Default(&conf.AdvertiseAddress, os.Getenv("GUBER_ADVERTISE_ADDRESS"), conf.HTTPListenAddress) + set.Default(&conf.DataCenter, os.Getenv("GUBER_DATA_CENTER"), "") + set.Default(&conf.MetricFlags, getEnvMetricFlags(log, "GUBER_METRIC_FLAGS")) choices := []string{"member-list", "k8s", "etcd", "dns"} - setter.SetDefault(&conf.PeerDiscoveryType, os.Getenv("GUBER_PEER_DISCOVERY_TYPE"), "member-list") - if !slice.ContainsString(conf.PeerDiscoveryType, choices, nil) { + set.Default(&conf.PeerDiscoveryType, os.Getenv("GUBER_PEER_DISCOVERY_TYPE"), "member-list") + if !slices.Contains(choices, conf.PeerDiscoveryType) { return conf, fmt.Errorf("GUBER_PEER_DISCOVERY_TYPE is invalid; choices are [%s]`", strings.Join(choices, ",")) } @@ -353,25 +366,25 @@ func SetupDaemonConfig(logger *logrus.Logger, configFile io.Reader) (DaemonConfi } // Behaviors - setter.SetDefault(&conf.Behaviors.BatchTimeout, getEnvDuration(log, "GUBER_BATCH_TIMEOUT")) - setter.SetDefault(&conf.Behaviors.BatchLimit, getEnvInteger(log, "GUBER_BATCH_LIMIT")) - setter.SetDefault(&conf.Behaviors.BatchWait, getEnvDuration(log, "GUBER_BATCH_WAIT")) - setter.SetDefault(&conf.Behaviors.DisableBatching, getEnvBool(log, "GUBER_DISABLE_BATCHING")) + set.Default(&conf.Behaviors.BatchTimeout, getEnvDuration(log, "GUBER_BATCH_TIMEOUT")) + set.Default(&conf.Behaviors.BatchLimit, getEnvInteger(log, "GUBER_BATCH_LIMIT")) + set.Default(&conf.Behaviors.BatchWait, getEnvDuration(log, "GUBER_BATCH_WAIT")) + set.Default(&conf.Behaviors.DisableBatching, getEnvBool(log, "GUBER_DISABLE_BATCHING")) - setter.SetDefault(&conf.Behaviors.GlobalTimeout, getEnvDuration(log, "GUBER_GLOBAL_TIMEOUT")) - setter.SetDefault(&conf.Behaviors.GlobalBatchLimit, getEnvInteger(log, "GUBER_GLOBAL_BATCH_LIMIT")) - setter.SetDefault(&conf.Behaviors.GlobalSyncWait, getEnvDuration(log, "GUBER_GLOBAL_SYNC_WAIT")) - setter.SetDefault(&conf.Behaviors.ForceGlobal, getEnvBool(log, "GUBER_FORCE_GLOBAL")) + set.Default(&conf.Behaviors.GlobalTimeout, getEnvDuration(log, "GUBER_GLOBAL_TIMEOUT")) + set.Default(&conf.Behaviors.GlobalBatchLimit, getEnvInteger(log, "GUBER_GLOBAL_BATCH_LIMIT")) + set.Default(&conf.Behaviors.GlobalSyncWait, getEnvDuration(log, "GUBER_GLOBAL_SYNC_WAIT")) + set.Default(&conf.Behaviors.ForceGlobal, getEnvBool(log, "GUBER_FORCE_GLOBAL")) // TLS Config if anyHasPrefix("GUBER_TLS_", os.Environ()) { conf.TLS = &TLSConfig{} - setter.SetDefault(&conf.TLS.CaFile, os.Getenv("GUBER_TLS_CA")) - setter.SetDefault(&conf.TLS.CaKeyFile, os.Getenv("GUBER_TLS_CA_KEY")) - setter.SetDefault(&conf.TLS.KeyFile, os.Getenv("GUBER_TLS_KEY")) - setter.SetDefault(&conf.TLS.CertFile, os.Getenv("GUBER_TLS_CERT")) - setter.SetDefault(&conf.TLS.AutoTLS, getEnvBool(log, "GUBER_TLS_AUTO")) - setter.SetDefault(&conf.TLS.MinVersion, getEnvMinVersion(log, "GUBER_TLS_MIN_VERSION")) + set.Default(&conf.TLS.CaFile, os.Getenv("GUBER_TLS_CA")) + set.Default(&conf.TLS.CaKeyFile, os.Getenv("GUBER_TLS_CA_KEY")) + set.Default(&conf.TLS.KeyFile, os.Getenv("GUBER_TLS_KEY")) + 
set.Default(&conf.TLS.CertFile, os.Getenv("GUBER_TLS_CERT")) + set.Default(&conf.TLS.AutoTLS, getEnvBool(log, "GUBER_TLS_AUTO")) + set.Default(&conf.TLS.MinVersion, getEnvMinVersion(log, "GUBER_TLS_MIN_VERSION")) clientAuth := os.Getenv("GUBER_TLS_CLIENT_AUTH") if clientAuth != "" { @@ -388,33 +401,33 @@ func SetupDaemonConfig(logger *logrus.Logger, configFile io.Reader) (DaemonConfi } conf.TLS.ClientAuth = t } - setter.SetDefault(&conf.TLS.ClientAuthKeyFile, os.Getenv("GUBER_TLS_CLIENT_AUTH_KEY")) - setter.SetDefault(&conf.TLS.ClientAuthCertFile, os.Getenv("GUBER_TLS_CLIENT_AUTH_CERT")) - setter.SetDefault(&conf.TLS.ClientAuthCaFile, os.Getenv("GUBER_TLS_CLIENT_AUTH_CA_CERT")) - setter.SetDefault(&conf.TLS.InsecureSkipVerify, getEnvBool(log, "GUBER_TLS_INSECURE_SKIP_VERIFY")) - setter.SetDefault(&conf.TLS.ClientAuthServerName, os.Getenv("GUBER_TLS_CLIENT_AUTH_SERVER_NAME")) + set.Default(&conf.TLS.ClientAuthKeyFile, os.Getenv("GUBER_TLS_CLIENT_AUTH_KEY")) + set.Default(&conf.TLS.ClientAuthCertFile, os.Getenv("GUBER_TLS_CLIENT_AUTH_CERT")) + set.Default(&conf.TLS.ClientAuthCaFile, os.Getenv("GUBER_TLS_CLIENT_AUTH_CA_CERT")) + set.Default(&conf.TLS.InsecureSkipVerify, getEnvBool(log, "GUBER_TLS_INSECURE_SKIP_VERIFY")) + set.Default(&conf.TLS.ClientAuthServerName, os.Getenv("GUBER_TLS_CLIENT_AUTH_SERVER_NAME")) } // ETCD Config - setter.SetDefault(&conf.EtcdPoolConf.KeyPrefix, os.Getenv("GUBER_ETCD_KEY_PREFIX"), "/gubernator-peers") - setter.SetDefault(&conf.EtcdPoolConf.EtcdConfig, &etcd.Config{}) - setter.SetDefault(&conf.EtcdPoolConf.EtcdConfig.Endpoints, getEnvSlice("GUBER_ETCD_ENDPOINTS"), []string{"localhost:2379"}) - setter.SetDefault(&conf.EtcdPoolConf.EtcdConfig.DialTimeout, getEnvDuration(log, "GUBER_ETCD_DIAL_TIMEOUT"), clock.Second*5) - setter.SetDefault(&conf.EtcdPoolConf.EtcdConfig.Username, os.Getenv("GUBER_ETCD_USER")) - setter.SetDefault(&conf.EtcdPoolConf.EtcdConfig.Password, os.Getenv("GUBER_ETCD_PASSWORD")) - setter.SetDefault(&conf.EtcdPoolConf.Advertise.GRPCAddress, os.Getenv("GUBER_ETCD_ADVERTISE_ADDRESS"), conf.AdvertiseAddress) - setter.SetDefault(&conf.EtcdPoolConf.Advertise.DataCenter, os.Getenv("GUBER_ETCD_DATA_CENTER"), conf.DataCenter) - - setter.SetDefault(&conf.MemberListPoolConf.Advertise.GRPCAddress, os.Getenv("GUBER_MEMBERLIST_ADVERTISE_ADDRESS"), conf.AdvertiseAddress) - setter.SetDefault(&conf.MemberListPoolConf.MemberListAddress, os.Getenv("GUBER_MEMBERLIST_ADDRESS"), fmt.Sprintf("%s:7946", advAddr)) - setter.SetDefault(&conf.MemberListPoolConf.KnownNodes, getEnvSlice("GUBER_MEMBERLIST_KNOWN_NODES"), []string{}) - setter.SetDefault(&conf.MemberListPoolConf.Advertise.DataCenter, conf.DataCenter) - setter.SetDefault(&conf.MemberListPoolConf.EncryptionConfig.SecretKeys, getEnvSlice("GUBER_MEMBERLIST_SECRET_KEYS"), []string{}) - setter.SetDefault(&conf.MemberListPoolConf.EncryptionConfig.GossipVerifyIncoming, getEnvBool(log, "GUBER_MEMBERLIST_GOSSIP_VERIFY_INCOMING"), true) - setter.SetDefault(&conf.MemberListPoolConf.EncryptionConfig.GossipVerifyOutgoing, getEnvBool(log, "GUBER_MEMBERLIST_GOSSIP_VERIFY_OUTGOING"), true) + set.Default(&conf.EtcdPoolConf.KeyPrefix, os.Getenv("GUBER_ETCD_KEY_PREFIX"), "/gubernator-peers") + set.Default(&conf.EtcdPoolConf.EtcdConfig, &etcd.Config{}) + set.Default(&conf.EtcdPoolConf.EtcdConfig.Endpoints, getEnvSlice("GUBER_ETCD_ENDPOINTS"), []string{"localhost:2379"}) + set.Default(&conf.EtcdPoolConf.EtcdConfig.DialTimeout, getEnvDuration(log, "GUBER_ETCD_DIAL_TIMEOUT"), clock.Second*5) + 
set.Default(&conf.EtcdPoolConf.EtcdConfig.Username, os.Getenv("GUBER_ETCD_USER")) + set.Default(&conf.EtcdPoolConf.EtcdConfig.Password, os.Getenv("GUBER_ETCD_PASSWORD")) + set.Default(&conf.EtcdPoolConf.Advertise.HTTPAddress, os.Getenv("GUBER_ETCD_ADVERTISE_ADDRESS"), conf.AdvertiseAddress) + set.Default(&conf.EtcdPoolConf.Advertise.DataCenter, os.Getenv("GUBER_ETCD_DATA_CENTER"), conf.DataCenter) + + set.Default(&conf.MemberListPoolConf.Advertise.HTTPAddress, os.Getenv("GUBER_MEMBERLIST_ADVERTISE_ADDRESS"), conf.AdvertiseAddress) + set.Default(&conf.MemberListPoolConf.MemberListAddress, os.Getenv("GUBER_MEMBERLIST_ADDRESS"), fmt.Sprintf("%s:7946", advAddr)) + set.Default(&conf.MemberListPoolConf.KnownNodes, getEnvSlice("GUBER_MEMBERLIST_KNOWN_NODES"), []string{}) + set.Default(&conf.MemberListPoolConf.Advertise.DataCenter, conf.DataCenter) + set.Default(&conf.MemberListPoolConf.EncryptionConfig.SecretKeys, getEnvSlice("GUBER_MEMBERLIST_SECRET_KEYS"), []string{}) + set.Default(&conf.MemberListPoolConf.EncryptionConfig.GossipVerifyIncoming, getEnvBool(log, "GUBER_MEMBERLIST_GOSSIP_VERIFY_INCOMING"), true) + set.Default(&conf.MemberListPoolConf.EncryptionConfig.GossipVerifyOutgoing, getEnvBool(log, "GUBER_MEMBERLIST_GOSSIP_VERIFY_OUTGOING"), true) // Kubernetes Config - setter.SetDefault(&conf.K8PoolConf.Namespace, os.Getenv("GUBER_K8S_NAMESPACE"), "default") + set.Default(&conf.K8PoolConf.Namespace, os.Getenv("GUBER_K8S_NAMESPACE"), "default") conf.K8PoolConf.PodIP = os.Getenv("GUBER_K8S_POD_IP") conf.K8PoolConf.PodPort = os.Getenv("GUBER_K8S_POD_PORT") conf.K8PoolConf.Selector = os.Getenv("GUBER_K8S_ENDPOINTS_SELECTOR") @@ -426,20 +439,23 @@ func SetupDaemonConfig(logger *logrus.Logger, configFile io.Reader) (DaemonConfi } // DNS Config - setter.SetDefault(&conf.DNSPoolConf.FQDN, os.Getenv("GUBER_DNS_FQDN")) - setter.SetDefault(&conf.DNSPoolConf.ResolvConf, os.Getenv("GUBER_RESOLV_CONF"), "/etc/resolv.conf") - setter.SetDefault(&conf.DNSPoolConf.OwnAddress, conf.AdvertiseAddress) + set.Default(&conf.DNSPoolConf.FQDN, os.Getenv("GUBER_DNS_FQDN")) + set.Default(&conf.DNSPoolConf.ResolvConf, os.Getenv("GUBER_RESOLV_CONF"), "/etc/resolv.conf") + set.Default(&conf.DNSPoolConf.OwnAddress, conf.AdvertiseAddress) + + set.Default(&conf.CacheProvider, os.Getenv("GUBER_CACHE_PROVIDER"), "default-lru") // PeerPicker Config + // TODO: Deprecated: Remove in GUBER_PEER_PICKER in v3 if pp := os.Getenv("GUBER_PEER_PICKER"); pp != "" { var replicas int var hash string switch pp { case "replicated-hash": - setter.SetDefault(&replicas, getEnvInteger(log, "GUBER_REPLICATED_HASH_REPLICAS"), defaultReplicas) + set.Default(&replicas, getEnvInteger(log, "GUBER_REPLICATED_HASH_REPLICAS"), DefaultReplicas) conf.Picker = NewReplicatedConsistentHash(nil, replicas) - setter.SetDefault(&hash, os.Getenv("GUBER_PEER_PICKER_HASH"), "fnv1a") + set.Default(&hash, os.Getenv("GUBER_PEER_PICKER_HASH"), "fnv1a") hashFuncs := map[string]HashString64{ "fnv1a": fnv1a.HashString64, "fnv1": fnv1.HashString64, @@ -486,8 +502,6 @@ func SetupDaemonConfig(logger *logrus.Logger, configFile io.Reader) (DaemonConfi log.Debug(spew.Sdump(conf)) } - setter.SetDefault(&conf.TraceLevel, GetTracingLevel()) - return conf, nil } @@ -528,15 +542,15 @@ func setupEtcdTLS(conf *etcd.Config) error { // set `GUBER_ETCD_TLS_ENABLE` and this line will // create a TLS config with no config. 
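// ---------------------------------------------------------------------------
// Illustrative aside (not part of this diff): a small environment config file
// exercising the settings parsed above, in the KEY=value format that
// fromEnvFile() expects. Every value below is a made-up example.
//
//   # Plain HTTP listener and the address advertised to peers
//   GUBER_HTTP_ADDRESS=0.0.0.0:1050
//   GUBER_ADVERTISE_ADDRESS=10.0.0.5:1050
//
//   # Cache implementation: "otter" or "default-lru"
//   GUBER_CACHE_PROVIDER=otter
//   GUBER_CACHE_SIZE=50000
//
//   # Peer discovery: member-list, k8s, etcd or dns
//   GUBER_PEER_DISCOVERY_TYPE=member-list
//   GUBER_MEMBERLIST_ADDRESS=10.0.0.5:7946
//   GUBER_MEMBERLIST_KNOWN_NODES=10.0.0.6:7946
//
//   # Logging
//   GUBER_LOG_FORMAT=json
//   GUBER_LOG_LEVEL=info
// ---------------------------------------------------------------------------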
- setter.SetDefault(&conf.TLS, &tls.Config{}) + set.Default(&conf.TLS, &tls.Config{}) - setter.SetDefault(&tlsCertFile, os.Getenv("GUBER_ETCD_TLS_CERT")) - setter.SetDefault(&tlsKeyFile, os.Getenv("GUBER_ETCD_TLS_KEY")) - setter.SetDefault(&tlsCAFile, os.Getenv("GUBER_ETCD_TLS_CA")) + set.Default(&tlsCertFile, os.Getenv("GUBER_ETCD_TLS_CERT")) + set.Default(&tlsKeyFile, os.Getenv("GUBER_ETCD_TLS_KEY")) + set.Default(&tlsCAFile, os.Getenv("GUBER_ETCD_TLS_CA")) // If the CA file was provided if tlsCAFile != "" { - setter.SetDefault(&conf.TLS, &tls.Config{}) + set.Default(&conf.TLS, &tls.Config{}) var certPool *x509.CertPool = nil if pemBytes, err := os.ReadFile(tlsCAFile); err == nil { @@ -545,7 +559,7 @@ func setupEtcdTLS(conf *etcd.Config) error { } else { return errors.Wrapf(err, "while loading cert CA file '%s'", tlsCAFile) } - setter.SetDefault(&conf.TLS.RootCAs, certPool) + set.Default(&conf.TLS.RootCAs, certPool) conf.TLS.InsecureSkipVerify = false } @@ -556,13 +570,13 @@ func setupEtcdTLS(conf *etcd.Config) error { return errors.Wrapf(err, "while loading cert '%s' and key file '%s'", tlsCertFile, tlsKeyFile) } - setter.SetDefault(&conf.TLS.Certificates, []tls.Certificate{tlsCert}) + set.Default(&conf.TLS.Certificates, []tls.Certificate{tlsCert}) } // If no other TLS config is provided this will force connecting with TLS, // without cert verification if os.Getenv("GUBER_ETCD_TLS_SKIP_VERIFY") != "" { - setter.SetDefault(&conf.TLS, &tls.Config{}) + set.Default(&conf.TLS, &tls.Config{}) conf.TLS.InsecureSkipVerify = true } return nil @@ -577,20 +591,20 @@ func anyHasPrefix(prefix string, items []string) bool { return false } -func getEnvBool(log logrus.FieldLogger, name string) bool { +func getEnvBool(log FieldLogger, name string) bool { v := os.Getenv(name) if v == "" { return false } b, err := strconv.ParseBool(v) if err != nil { - log.WithError(err).Errorf("while parsing '%s' as an boolean", name) + log.LogAttrs(context.TODO(), slog.LevelError, "while parsing boolean", ErrAttr(err), slog.String("name", name)) return false } return b } -func getEnvMinVersion(log logrus.FieldLogger, name string) uint16 { +func getEnvMinVersion(log FieldLogger, name string) uint16 { v := os.Getenv(name) if v == "" { return tls.VersionTLS13 @@ -603,33 +617,42 @@ func getEnvMinVersion(log logrus.FieldLogger, name string) uint16 { } version, ok := minVersion[v] if !ok { - log.WithError(fmt.Errorf("unknown tls version: %s", v)).Errorf("while parsing '%s' as an min tls version, defaulting to 1.3", name) + log.LogAttrs(context.TODO(), slog.LevelError, "while parsing min tls version, defaulting to 1.3", + ErrAttr(fmt.Errorf("unknown tls version: %s", v)), + slog.String("name", name), + ) return tls.VersionTLS13 } return version } -func getEnvInteger(log logrus.FieldLogger, name string) int { +func getEnvInteger(log FieldLogger, name string) int { v := os.Getenv(name) if v == "" { return 0 } i, err := strconv.ParseInt(v, 10, 64) if err != nil { - log.WithError(err).Errorf("while parsing '%s' as an integer", name) + log.LogAttrs(context.TODO(), slog.LevelError, "while parsing as an integer", + ErrAttr(err), + slog.String("name", name), + ) return 0 } return int(i) } -func getEnvDuration(log logrus.FieldLogger, name string) time.Duration { +func getEnvDuration(log FieldLogger, name string) time.Duration { v := os.Getenv(name) if v == "" { return 0 } d, err := time.ParseDuration(v) if err != nil { - log.WithError(err).Errorf("while parsing '%s' as a duration", name) + log.LogAttrs(context.TODO(), slog.LevelError, 
"while parsing as a duration", + ErrAttr(err), + slog.String("name", name), + ) return 0 } return d @@ -645,7 +668,7 @@ func getEnvSlice(name string) []string { // Take values from a file in the format `GUBER_CONF_ITEM=my-value` and sets them as environment variables. // Lines that begin with `#` are ignored -func fromEnvFile(log logrus.FieldLogger, configFile io.Reader) error { +func fromEnvFile(log FieldLogger, configFile io.Reader) error { contents, err := io.ReadAll(configFile) if err != nil { return fmt.Errorf("while reading config file '%s': %s", configFile, err) @@ -657,7 +680,7 @@ func fromEnvFile(log logrus.FieldLogger, configFile io.Reader) error { continue } - log.Debugf("config: [%d] '%s'", i, line) + log.Debug("config", "i", i, "line", line) parts := strings.SplitN(line, "=", 2) if len(parts) != 2 { return errors.Errorf("malformed key=value on line '%d'", i) @@ -695,7 +718,7 @@ func GetInstanceID() string { // 1. The environment variable `GUBER_INSTANCE_ID` // 2. The id of the docker container we are running in // 3. Generate a random id - setter.SetDefault(&id, os.Getenv("GUBER_INSTANCE_ID"), getDockerCID(), generateID()) + set.Default(&id, os.Getenv("GUBER_INSTANCE_ID"), getDockerCID(), generateID()) return id } @@ -726,40 +749,3 @@ func getDockerCID() string { } return "" } - -func GetTracingLevel() tracing.Level { - s := os.Getenv("GUBER_TRACING_LEVEL") - lvl, ok := map[string]tracing.Level{ - "ERROR": tracing.ErrorLevel, - "INFO": tracing.InfoLevel, - "DEBUG": tracing.DebugLevel, - }[s] - if ok { - return lvl - } - return tracing.InfoLevel -} - -// TraceLevelInfoFilter is used with otelgrpc.WithInterceptorFilter() to -// reduce noise by filtering trace propagation on some gRPC methods. -// otelgrpc deprecated use of interceptors in v0.45.0 in favor of stats -// handlers to propagate trace context. -// However, stats handlers do not have a filter feature. 
-// See: https://github.com/open-telemetry/opentelemetry-go-contrib/issues/4575 -// var TraceLevelInfoFilter = otelgrpc.Filter(func(info *otelgrpc.InterceptorInfo) bool { -// if info.UnaryServerInfo != nil { -// if info.UnaryServerInfo.FullMethod == "/pb.gubernator.PeersV1/GetPeerRateLimits" { -// return false -// } -// if info.UnaryServerInfo.FullMethod == "/pb.gubernator.V1/HealthCheck" { -// return false -// } -// } -// if info.Method == "/pb.gubernator.PeersV1/GetPeerRateLimits" { -// return false -// } -// if info.Method == "/pb.gubernator.V1/HealthCheck" { -// return false -// } -// return true -// }) diff --git a/config_test.go b/config_test.go index 74672bd..d6808a4 100644 --- a/config_test.go +++ b/config_test.go @@ -2,22 +2,22 @@ package gubernator import ( "fmt" + "log/slog" "os" "strings" "testing" - "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" ) -func TestParsesGrpcAddress(t *testing.T) { +func TestParsesAddress(t *testing.T) { os.Clearenv() s := ` # a comment -GUBER_GRPC_ADDRESS=10.10.10.10:9000` - daemonConfig, err := SetupDaemonConfig(logrus.StandardLogger(), strings.NewReader(s)) +GUBER_HTTP_ADDRESS=10.10.10.10:9000` + daemonConfig, err := SetupDaemonConfig(slog.New(slog.NewTextHandler(os.Stderr, nil)), strings.NewReader(s)) require.NoError(t, err) - require.Equal(t, "10.10.10.10:9000", daemonConfig.GRPCListenAddress) + require.Equal(t, "10.10.10.10:9000", daemonConfig.HTTPListenAddress) require.NotEmpty(t, daemonConfig.InstanceID) } @@ -25,9 +25,8 @@ func TestDefaultListenAddress(t *testing.T) { os.Clearenv() s := ` # a comment` - daemonConfig, err := SetupDaemonConfig(logrus.StandardLogger(), strings.NewReader(s)) + daemonConfig, err := SetupDaemonConfig(slog.New(slog.NewTextHandler(os.Stderr, nil)), strings.NewReader(s)) require.NoError(t, err) - require.Equal(t, fmt.Sprintf("%s:1051", LocalHost()), daemonConfig.GRPCListenAddress) require.Equal(t, fmt.Sprintf("%s:1050", LocalHost()), daemonConfig.HTTPListenAddress) require.NotEmpty(t, daemonConfig.InstanceID) } @@ -35,7 +34,12 @@ func TestDefaultListenAddress(t *testing.T) { func TestDefaultInstanceId(t *testing.T) { os.Clearenv() s := `` - daemonConfig, err := SetupDaemonConfig(logrus.StandardLogger(), strings.NewReader(s)) + daemonConfig, err := SetupDaemonConfig(slog.New(slog.NewTextHandler(os.Stderr, nil)), strings.NewReader(s)) require.NoError(t, err) require.NotEmpty(t, daemonConfig.InstanceID) + + instanceConfig := Config{} + err = instanceConfig.SetDefaults() + require.NoError(t, err) + require.NotEmpty(t, instanceConfig.InstanceID) } diff --git a/daemon.go b/daemon.go index 24e11f3..bd0ca48 100644 --- a/daemon.go +++ b/daemon.go @@ -22,57 +22,45 @@ import ( "fmt" "io" "log" + "log/slog" "net" "net/http" "strings" "time" - "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" - "github.com/mailgun/holster/v4/errors" - "github.com/mailgun/holster/v4/etcdutil" - "github.com/mailgun/holster/v4/setter" - "github.com/mailgun/holster/v4/syncutil" - "github.com/mailgun/holster/v4/tracing" + "github.com/kapetan-io/errors" + "github.com/kapetan-io/tackle/set" + "github.com/kapetan-io/tackle/wait" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/collectors" "github.com/prometheus/client_golang/prometheus/promhttp" - "github.com/sirupsen/logrus" - "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" - 
"google.golang.org/grpc/keepalive" - "google.golang.org/protobuf/encoding/protojson" + "golang.org/x/net/proxy" ) type Daemon struct { - GRPCListeners []net.Listener - HTTPListener net.Listener - V1Server *V1Instance - InstanceID string - PeerInfo PeerInfo - - log FieldLogger - logWriter *io.PipeWriter - pool PoolInterface - conf DaemonConfig - httpSrv *http.Server - httpSrvNoMTLS *http.Server - grpcSrvs []*grpc.Server - wg syncutil.WaitGroup - statsHandler *GRPCStatsHandler - promRegister *prometheus.Registry - gwCancel context.CancelFunc - instanceConf Config - client V1Client + wg wait.Group + httpServers []*http.Server + pool PoolInterface + conf DaemonConfig + Listener net.Listener + HealthListener net.Listener + log FieldLogger + PeerInfo PeerInfo + Service *Service + InstanceID string + logAdaptor io.WriteCloser } // SpawnDaemon starts a new gubernator daemon according to the provided DaemonConfig. -// This function will block until the daemon responds to connections as specified -// by GRPCListenAddress and HTTPListenAddress +// This function will block until the daemon responds to connections to HTTPListenAddress func SpawnDaemon(ctx context.Context, conf DaemonConfig) (*Daemon, error) { + set.Default(&conf.Logger, slog.Default().With( + "service-id", conf.InstanceID, + "category", "gubernator", + )) s := &Daemon{ + logAdaptor: newLogAdaptor(conf.Logger), InstanceID: conf.InstanceID, log: conf.Logger, conf: conf, @@ -80,410 +68,349 @@ func SpawnDaemon(ctx context.Context, conf DaemonConfig) (*Daemon, error) { return s, s.Start(ctx) } -func (s *Daemon) Start(ctx context.Context) error { +func (d *Daemon) Start(ctx context.Context) error { var err error - setter.SetDefault(&s.log, logrus.WithFields(logrus.Fields{ - "instance": s.conf.InstanceID, - "category": "gubernator", - })) - - s.promRegister = prometheus.NewRegistry() - - // The LRU cache for storing rate limits. - cacheCollector := NewLRUCacheCollector() - if err := s.promRegister.Register(cacheCollector); err != nil { - return errors.Wrap(err, "during call to promRegister.Register()") - } - - cacheFactory := func(maxSize int) Cache { - cache := NewLRUCache(maxSize) - cacheCollector.AddCache(cache) - return cache - } - - // Handler to collect duration and API access metrics for GRPC - s.statsHandler = NewGRPCStatsHandler() - _ = s.promRegister.Register(s.statsHandler) - - var filters []otelgrpc.Option - // otelgrpc deprecated use of interceptors in v0.45.0 in favor of stats - // handlers to propagate trace context. - // However, stats handlers do not have a filter feature. - // See: https://github.com/open-telemetry/opentelemetry-go-contrib/issues/4575 - // if s.conf.TraceLevel != tracing.DebugLevel { - // filters = []otelgrpc.Option{ - // otelgrpc.WithInterceptorFilter(TraceLevelInfoFilter), - // } - // } - - opts := []grpc.ServerOption{ - grpc.StatsHandler(s.statsHandler), - grpc.MaxRecvMsgSize(1024 * 1024), - - // OpenTelemetry instrumentation on gRPC endpoints. - grpc.StatsHandler(otelgrpc.NewServerHandler(filters...)), + // The cache for storing rate limits. 
+ registry := prometheus.NewRegistry() + cacheCollector := NewCacheCollector() + if err := registry.Register(cacheCollector); err != nil { + return fmt.Errorf("during call to promRegister.Register(): %w", err) } - if s.conf.GRPCMaxConnectionAgeSeconds > 0 { - opts = append(opts, grpc.KeepaliveParams(keepalive.ServerParameters{ - MaxConnectionAge: time.Second * time.Duration(s.conf.GRPCMaxConnectionAgeSeconds), - MaxConnectionAgeGrace: time.Second * time.Duration(s.conf.GRPCMaxConnectionAgeSeconds), - })) + cacheFactory := func(maxSize int) (Cache, error) { + switch d.conf.CacheProvider { + case "default-lru": + cache := NewLRUCache(maxSize) + cacheCollector.AddCache(cache) + return cache, nil + case "otter", "": + cache, err := NewOtterCache(maxSize) + if err != nil { + return nil, err + } + cacheCollector.AddCache(cache) + return cache, nil + default: + return nil, errors.Errorf("'GUBER_CACHE_PROVIDER=%s' is invalid; "+ + "choices are ['otter', 'default-lru']", d.conf.CacheProvider) + } } - if err := SetupTLS(s.conf.TLS); err != nil { + if err := SetupTLS(d.conf.TLS); err != nil { return err } - if s.conf.ServerTLS() != nil { - // Create two GRPC server instances, one for TLS and the other for the API Gateway - opts2 := append(opts, grpc.Creds(credentials.NewTLS(s.conf.ServerTLS()))) - s.grpcSrvs = append(s.grpcSrvs, grpc.NewServer(opts2...)) - } - s.grpcSrvs = append(s.grpcSrvs, grpc.NewServer(opts...)) - - // Registers a new gubernator instance with the GRPC server - s.instanceConf = Config{ - PeerTraceGRPC: s.conf.TraceLevel >= tracing.DebugLevel, - PeerTLS: s.conf.ClientTLS(), - DataCenter: s.conf.DataCenter, - LocalPicker: s.conf.Picker, - GRPCServers: s.grpcSrvs, - Logger: s.log, - CacheFactory: cacheFactory, - Behaviors: s.conf.Behaviors, - CacheSize: s.conf.CacheSize, - Workers: s.conf.Workers, - InstanceID: s.conf.InstanceID, - EventChannel: s.conf.EventChannel, - } - - s.V1Server, err = NewV1Instance(s.instanceConf) - if err != nil { - return errors.Wrap(err, "while creating new gubernator instance") - } - - // V1Server instance also implements prometheus.Collector interface - _ = s.promRegister.Register(s.V1Server) - - l, err := net.Listen("tcp", s.conf.GRPCListenAddress) + d.Service, err = NewService(Config{ + PeerClientFactory: func(info PeerInfo) PeerClient { + return NewPeerClient(WithPeerInfo(info)) + }, + CacheFactory: cacheFactory, + EventChannel: d.conf.EventChannel, + InstanceID: d.conf.InstanceID, + DataCenter: d.conf.DataCenter, + CacheSize: d.conf.CacheSize, + Behaviors: d.conf.Behaviors, + Workers: d.conf.Workers, + LocalPicker: d.conf.Picker, + Loader: d.conf.Loader, + Store: d.conf.Store, + Logger: d.log, + }) if err != nil { - return errors.Wrap(err, "while starting GRPC listener") + return errors.Errorf("while creating new gubernator service: %w", err) } - s.GRPCListeners = append(s.GRPCListeners, l) - // Start serving GRPC Requests - s.wg.Go(func() { - s.log.Infof("GRPC Listening on %s ...", l.Addr().String()) - if err := s.grpcSrvs[0].Serve(l); err != nil { - s.log.WithError(err).Error("while starting GRPC server") - } - }) - - var gatewayAddr string - if s.conf.ServerTLS() != nil { - // We start a new local GRPC instance because we can't guarantee the TLS cert provided by the - // user has localhost or the local interface included in the certs' valid hostnames. If they are not - // included, it means the local gateway connections will not be able to connect. 
- l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - return errors.Wrap(err, "while starting GRPC Gateway listener") - } - s.GRPCListeners = append(s.GRPCListeners, l) - - s.wg.Go(func() { - s.log.Infof("GRPC Gateway Listening on %s ...", l.Addr()) - if err := s.grpcSrvs[1].Serve(l); err != nil { - s.log.WithError(err).Error("while starting GRPC Gateway server") - } - }) - gatewayAddr = l.Addr().String() - } else { - gatewayAddr, err = ResolveHostIP(s.conf.GRPCListenAddress) - if err != nil { - return errors.Wrap(err, "while resolving GRPC gateway client address") - } - } + // Service implements prometheus.Collector interface + registry.MustRegister(d.Service) - switch s.conf.PeerDiscoveryType { + switch d.conf.PeerDiscoveryType { case "k8s": // Source our list of peers from kubernetes endpoint API - s.conf.K8PoolConf.OnUpdate = s.V1Server.SetPeers - s.pool, err = NewK8sPool(s.conf.K8PoolConf) + d.conf.K8PoolConf.OnUpdate = d.Service.SetPeers + d.pool, err = NewK8sPool(d.conf.K8PoolConf) if err != nil { - return errors.Wrap(err, "while querying kubernetes API") + return errors.Errorf("while querying kubernetes API: %w", err) } case "etcd": - s.conf.EtcdPoolConf.OnUpdate = s.V1Server.SetPeers + d.conf.EtcdPoolConf.OnUpdate = d.Service.SetPeers // Register ourselves with other peers via ETCD - s.conf.EtcdPoolConf.Client, err = etcdutil.NewClient(s.conf.EtcdPoolConf.EtcdConfig) + d.conf.EtcdPoolConf.Client, err = newEtcdClient(d.conf.EtcdPoolConf.EtcdConfig) if err != nil { - return errors.Wrap(err, "while connecting to etcd") + return errors.Errorf("while connecting to etcd: %w", err) } - s.pool, err = NewEtcdPool(s.conf.EtcdPoolConf) + d.pool, err = NewEtcdPool(d.conf.EtcdPoolConf) if err != nil { - return errors.Wrap(err, "while creating etcd pool") + return errors.Errorf("while creating etcd pool: %w", err) } case "dns": - s.conf.DNSPoolConf.OnUpdate = s.V1Server.SetPeers - s.pool, err = NewDNSPool(s.conf.DNSPoolConf) + d.conf.DNSPoolConf.OnUpdate = d.Service.SetPeers + d.pool, err = NewDNSPool(d.conf.DNSPoolConf) if err != nil { - return errors.Wrap(err, "while creating the DNS pool") + return errors.Errorf("while creating the DNS pool: %w", err) } case "member-list": - s.conf.MemberListPoolConf.OnUpdate = s.V1Server.SetPeers - s.conf.MemberListPoolConf.Logger = s.log + d.conf.MemberListPoolConf.OnUpdate = d.Service.SetPeers + d.conf.MemberListPoolConf.Logger = d.log // Register peer on the member list - s.pool, err = NewMemberListPool(ctx, s.conf.MemberListPoolConf) + d.pool, err = NewMemberListPool(ctx, d.conf.MemberListPoolConf) if err != nil { - return errors.Wrap(err, "while creating member list pool") + return errors.Errorf("while creating member list pool: %w", err) } } - // We override the default Marshaller to enable the `UseProtoNames` option. - // We do this is because the default JSONPb in 2.5.0 marshals proto structs using - // `camelCase`, while all the JSON annotations are `under_score`. - // Our protobuf files follow the convention described here - // https://developers.google.com/protocol-buffers/docs/style#message-and-field-names - // Camel case breaks unmarshalling our GRPC gateway responses with protobuf structs. 
-	gateway := runtime.NewServeMux(
-		runtime.WithMarshalerOption(runtime.MIMEWildcard, &runtime.JSONPb{
-			MarshalOptions: protojson.MarshalOptions{
-				UseProtoNames:   true,
-				EmitUnpopulated: true,
-			},
-			UnmarshalOptions: protojson.UnmarshalOptions{
-				DiscardUnknown: true,
-			},
-		}),
-	)
-
-	// Set up an JSON Gateway API for our GRPC methods
-	var gwCtx context.Context
-	gwCtx, s.gwCancel = context.WithCancel(context.Background())
-	err = RegisterV1HandlerFromEndpoint(gwCtx, gateway, gatewayAddr,
-		[]grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())})
-	if err != nil {
-		return errors.Wrap(err, "while registering GRPC gateway handler")
-	}
-
-	// Serve the JSON Gateway and metrics handlers via standard HTTP/1
-	mux := http.NewServeMux()
-
 	// Optionally collect process metrics
-	if s.conf.MetricFlags.Has(FlagOSMetrics) {
-		s.log.Debug("Collecting OS Metrics")
-		s.promRegister.MustRegister(collectors.NewProcessCollector(
+	if d.conf.MetricFlags.Has(FlagOSMetrics) {
+		d.log.Debug("Collecting OS Metrics")
+		registry.MustRegister(collectors.NewProcessCollector(
 			collectors.ProcessCollectorOpts{Namespace: "gubernator"},
 		))
 	}
 	// Optionally collect golang internal metrics
-	if s.conf.MetricFlags.Has(FlagGolangMetrics) {
-		s.log.Debug("Collecting Golang Metrics")
-		s.promRegister.MustRegister(collectors.NewGoCollector())
+	if d.conf.MetricFlags.Has(FlagGolangMetrics) {
+		d.log.Debug("Collecting Golang Metrics")
+		registry.MustRegister(collectors.NewGoCollector())
+	}
+
+	handler := NewHandler(d.Service, promhttp.InstrumentMetricHandler(
+		registry, promhttp.HandlerFor(registry, promhttp.HandlerOpts{}),
+	))
+	registry.MustRegister(handler)
+
+	if d.conf.ServerTLS() != nil {
+		if err := d.spawnHTTPS(ctx, handler); err != nil {
+			return err
+		}
+		if d.conf.HTTPStatusListenAddress != "" {
+			if err := d.spawnHTTPHealthCheck(ctx, handler, registry); err != nil {
+				return err
+			}
+		}
+	} else {
+		if err := d.spawnHTTP(ctx, handler); err != nil {
+			return err
+		}
 	}
+	d.PeerInfo = PeerInfo{
+		HTTPAddress: d.Listener.Addr().String(),
+		DataCenter:  d.conf.DataCenter,
+	}
+
+	return nil
+}
+
+// spawnHTTPHealthCheck spawns a plain HTTP listener for use by orchestration systems to perform health checks and
+// collect metrics when TLS and client certs are in use.
+func (d *Daemon) spawnHTTPHealthCheck(ctx context.Context, h *Handler, r *prometheus.Registry) error { + mux := http.NewServeMux() + mux.HandleFunc("/healthz", h.HealthZ) mux.Handle("/metrics", promhttp.InstrumentMetricHandler( - s.promRegister, promhttp.HandlerFor(s.promRegister, promhttp.HandlerOpts{}), + r, promhttp.HandlerFor(r, promhttp.HandlerOpts{}), )) - mux.Handle("/", gateway) - s.logWriter = newLogWriter(s.log) - log := log.New(s.logWriter, "", 0) - s.httpSrv = &http.Server{Addr: s.conf.HTTPListenAddress, Handler: mux, ErrorLog: log} + srv := &http.Server{ + ErrorLog: log.New(d.logAdaptor, "", 0), + Addr: d.conf.HTTPStatusListenAddress, + TLSConfig: d.conf.ServerTLS().Clone(), + Handler: mux, + } - s.HTTPListener, err = net.Listen("tcp", s.conf.HTTPListenAddress) + srv.TLSConfig.ClientAuth = tls.NoClientCert + var err error + d.HealthListener, err = net.Listen("tcp", d.conf.HTTPStatusListenAddress) if err != nil { - return errors.Wrap(err, "while starting HTTP listener") + return errors.Errorf("while starting HTTP listener for health metric: %w", err) } - httpListenerAddr := s.HTTPListener.Addr().String() - addrs := []string{httpListenerAddr} - - if s.conf.ServerTLS() != nil { - - // If configured, start another listener at configured address and server only - // /v1/HealthCheck while not requesting or verifying client certificate. - if s.conf.HTTPStatusListenAddress != "" { - muxNoMTLS := http.NewServeMux() - muxNoMTLS.Handle("/v1/HealthCheck", gateway) - s.httpSrvNoMTLS = &http.Server{ - Addr: s.conf.HTTPStatusListenAddress, - Handler: muxNoMTLS, - ErrorLog: log, - TLSConfig: s.conf.ServerTLS().Clone(), + d.wg.Go(func() { + d.log.LogAttrs(ctx, slog.LevelInfo, "HTTPS Health Check Listening", + slog.String("address", d.conf.HTTPStatusListenAddress), + ) + if err := srv.ServeTLS(d.HealthListener, "", ""); err != nil { + if !errors.Is(err, http.ErrServerClosed) { + d.log.LogAttrs(context.TODO(), slog.LevelError, "while starting TLS Status HTTP server", + ErrAttr(err), + ) } - s.httpSrvNoMTLS.TLSConfig.ClientAuth = tls.NoClientCert - httpListener, err := net.Listen("tcp", s.conf.HTTPStatusListenAddress) - if err != nil { - return errors.Wrap(err, "while starting HTTP listener for health metric") - } - httpAddr := httpListener.Addr().String() - addrs = append(addrs, httpAddr) - s.wg.Go(func() { - s.log.Infof("HTTPS Status Handler Listening on %s ...", httpAddr) - if err := s.httpSrvNoMTLS.ServeTLS(httpListener, "", ""); err != nil { - if !errors.Is(err, http.ErrServerClosed) { - s.log.WithError(err).Error("while starting TLS Status HTTP server") - } - } - }) } + }) - // This is to avoid any race conditions that might occur - // since the tls config is a shared pointer. 
- s.httpSrv.TLSConfig = s.conf.ServerTLS().Clone() - s.wg.Go(func() { - s.log.Infof("HTTPS Gateway Listening on %s ...", httpListenerAddr) - if err := s.httpSrv.ServeTLS(s.HTTPListener, "", ""); err != nil { - if !errors.Is(err, http.ErrServerClosed) { - s.log.WithError(err).Error("while starting TLS HTTP server") - } - } - }) - } else { - s.wg.Go(func() { - s.log.Infof("HTTP Gateway Listening on %s ...", httpListenerAddr) - if err := s.httpSrv.Serve(s.HTTPListener); err != nil { - if !errors.Is(err, http.ErrServerClosed) { - s.log.WithError(err).Error("while starting HTTP server") - } + if err := WaitForConnect(ctx, d.HealthListener.Addr().String(), nil); err != nil { + return err + } + + d.httpServers = append(d.httpServers, srv) + return nil +} + +func (d *Daemon) spawnHTTPS(ctx context.Context, mux http.Handler) error { + srv := &http.Server{ + ErrorLog: log.New(d.logAdaptor, "", 0), + TLSConfig: d.conf.ServerTLS().Clone(), + Addr: d.conf.HTTPListenAddress, + Handler: mux, + } + + var err error + d.Listener, err = net.Listen("tcp", d.conf.HTTPListenAddress) + if err != nil { + return errors.Errorf("while starting HTTPS listener: %w", err) + } + + d.wg.Go(func() { + d.log.LogAttrs(context.TODO(), slog.LevelInfo, "HTTPS Listening", + slog.String("address", d.conf.HTTPListenAddress), + ) + if err := srv.ServeTLS(d.Listener, "", ""); err != nil { + if !errors.Is(err, http.ErrServerClosed) { + d.log.LogAttrs(context.TODO(), slog.LevelError, "while starting TLS HTTP server", + ErrAttr(err), + ) } - }) + + } + }) + if err := WaitForConnect(ctx, d.Listener.Addr().String(), d.conf.ClientTLS()); err != nil { + return err } - // Validate we can reach the GRPC and HTTP endpoints before returning - for _, l := range s.GRPCListeners { - addrs = append(addrs, l.Addr().String()) + d.httpServers = append(d.httpServers, srv) + + return nil +} + +func (d *Daemon) spawnHTTP(ctx context.Context, h http.Handler) error { + srv := &http.Server{ + ErrorLog: log.New(d.logAdaptor, "", 0), + Addr: d.conf.HTTPListenAddress, + Handler: h, } - if err := WaitForConnect(ctx, addrs); err != nil { + var err error + d.Listener, err = net.Listen("tcp", d.conf.HTTPListenAddress) + if err != nil { + return errors.Errorf("while starting HTTP listener: %w", err) + } + + d.wg.Go(func() { + d.log.LogAttrs(context.TODO(), slog.LevelInfo, "HTTP Listening", + slog.String("address", d.conf.HTTPListenAddress), + ) + if err := srv.Serve(d.Listener); err != nil { + if !errors.Is(err, http.ErrServerClosed) { + d.log.LogAttrs(context.TODO(), slog.LevelError, "while starting HTTP server", + ErrAttr(err), + ) + } + } + }) + + if err := WaitForConnect(ctx, d.Listener.Addr().String(), nil); err != nil { return err } + d.httpServers = append(d.httpServers, srv) return nil } // Close gracefully closes all server connections and listening sockets -func (s *Daemon) Close() { - if s.httpSrv == nil && s.httpSrvNoMTLS == nil { - return +func (d *Daemon) Close(ctx context.Context) error { + if len(d.httpServers) == 0 { + return nil } - if s.pool != nil { - s.pool.Close() + for _, srv := range d.httpServers { + d.log.LogAttrs(context.TODO(), slog.LevelInfo, "Shutting down server", + slog.String("address", srv.Addr), + ) + _ = srv.Shutdown(ctx) } + d.httpServers = nil - s.log.Infof("HTTP Gateway close for %s ...", s.conf.HTTPListenAddress) - _ = s.httpSrv.Shutdown(context.Background()) - if s.httpSrvNoMTLS != nil { - s.log.Infof("HTTP Status Gateway close for %s ...", s.conf.HTTPStatusListenAddress) - _ = 
s.httpSrvNoMTLS.Shutdown(context.Background())
-	}
-	for i, srv := range s.grpcSrvs {
-		s.log.Infof("GRPC close for %s ...", s.GRPCListeners[i].Addr())
-		srv.GracefulStop()
+	if err := d.Service.Close(ctx); err != nil {
+		return err
 	}
-	s.logWriter.Close()
-	_ = s.V1Server.Close()
-	s.wg.Stop()
-	s.statsHandler.Close()
-	s.gwCancel()
-	s.httpSrv = nil
-	s.httpSrvNoMTLS = nil
-	s.grpcSrvs = nil
+	d.Service = nil
+
+	_ = d.logAdaptor.Close()
+	d.HealthListener = nil
+	d.Listener = nil
+
+	_ = d.wg.Wait()
+	return nil
 }
 
 // SetPeers sets the peers for this daemon
-func (s *Daemon) SetPeers(in []PeerInfo) {
+func (d *Daemon) SetPeers(in []PeerInfo) {
 	peers := make([]PeerInfo, len(in))
 	copy(peers, in)
 
 	for i, p := range peers {
-		if s.conf.GRPCListenAddress == p.GRPCAddress {
+		peers[i].SetTLS(d.conf.ClientTLS())
+		if d.conf.AdvertiseAddress == p.HTTPAddress {
 			peers[i].IsOwner = true
 		}
 	}
-	s.V1Server.SetPeers(peers)
+	d.Service.SetPeers(peers)
 }
 
 // Config returns the current config for this Daemon
-func (s *Daemon) Config() DaemonConfig {
-	return s.conf
+func (d *Daemon) Config() DaemonConfig {
+	return d.conf
 }
 
 // Peers returns the peers this daemon knows about
-func (s *Daemon) Peers() []PeerInfo {
+func (d *Daemon) Peers() []PeerInfo {
 	var peers []PeerInfo
-	for _, client := range s.V1Server.GetPeerList() {
+	for _, client := range d.Service.GetPeerList() {
 		peers = append(peers, client.Info())
 	}
 	return peers
 }
 
-func (s *Daemon) MustClient() V1Client {
-	c, err := s.Client()
+func (d *Daemon) MustClient() Client {
+	c, err := d.Client()
 	if err != nil {
-		panic(fmt.Sprintf("[%s] failed to init daemon client - '%s'", s.InstanceID, err))
+		panic(fmt.Sprintf("[%s] failed to init daemon client - '%s'", d.InstanceID, err))
 	}
 	return c
 }
 
-func (s *Daemon) Client() (V1Client, error) {
-	if s.client != nil {
-		return s.client, nil
+func (d *Daemon) Client() (Client, error) {
+	if d.conf.TLS != nil {
+		return NewClient(WithTLS(d.conf.ClientTLS(), d.Listener.Addr().String()))
 	}
+	return NewClient(WithNoTLS(d.Listener.Addr().String()))
+}
 
-	conn, err := grpc.NewClient(
-		fmt.Sprintf("static:///%s", s.PeerInfo.GRPCAddress),
-		grpc.WithResolvers(NewStaticBuilder()),
-		grpc.WithTransportCredentials(insecure.NewCredentials()))
-	if err != nil {
-		return nil, err
+// WaitForConnect waits until the passed address is accepting connections.
+// It will continue to attempt a connection until context is canceled.
+func WaitForConnect(ctx context.Context, address string, cfg *tls.Config) error {
+	if address == "" {
+		return fmt.Errorf("WaitForConnect() requires a valid address")
 	}
-	s.client = NewV1Client(conn)
-	return s.client, nil
-}
 
-// WaitForConnect returns nil if the list of addresses is listening
-// for connections; will block until context is cancelled.
-func WaitForConnect(ctx context.Context, addresses []string) error {
-	var d net.Dialer
-	var errs []error
+	var errs []string
 	for {
-		errs = nil
-		for _, addr := range addresses {
-			if addr == "" {
-				continue
-			}
-
-			// TODO: golang 1.15.3 introduces tls.DialContext(). When we are ready to drop
-			// support for older versions we can detect tls and use the tls.DialContext to
-			// avoid the `http: TLS handshake error` we get when using TLS.
- conn, err := d.DialContext(ctx, "tcp", addr) - if err != nil { - errs = append(errs, err) - continue - } - _ = conn.Close() + var d proxy.ContextDialer + if cfg != nil { + d = &tls.Dialer{Config: cfg} + } else { + d = &net.Dialer{} } - - if len(errs) == 0 { - break + conn, err := d.DialContext(ctx, "tcp", address) + if err == nil { + _ = conn.Close() + return nil } - - <-ctx.Done() - return ctx.Err() - } - - if len(errs) != 0 { - var errStrings []string - for _, err := range errs { - errStrings = append(errStrings, err.Error()) + errs = append(errs, err.Error()) + if ctx.Err() != nil { + errs = append(errs, ctx.Err().Error()) + return errors.New(strings.Join(errs, "\n")) } - return errors.New(strings.Join(errStrings, "\n")) + time.Sleep(time.Millisecond * 100) + continue } - return nil } diff --git a/dns.go b/dns.go index bc2ff1c..0be8442 100644 --- a/dns.go +++ b/dns.go @@ -18,19 +18,20 @@ package gubernator import ( "context" + "log/slog" "math/rand" "net" "os" "time" - "github.com/mailgun/holster/v4/setter" + "github.com/kapetan-io/tackle/set" + "github.com/miekg/dns" "github.com/pkg/errors" - "github.com/sirupsen/logrus" ) -// Adapted from TimothyYe/godns // DNSResolver represents a dns resolver +// Adapted from TimothyYe/godns type DNSResolver struct { Servers []string random *rand.Rand @@ -118,7 +119,7 @@ type DNSPoolConfig struct { // (Required) Filesystem path to "/etc/resolv.conf", override for testing ResolvConf string - // (Required) Own GRPC address + // (Required) Own advertise address OwnAddress string // (Required) Called when the list of gubernators in the pool updates @@ -135,10 +136,10 @@ type DNSPool struct { } func NewDNSPool(conf DNSPoolConfig) (*DNSPool, error) { - setter.SetDefault(&conf.Logger, logrus.WithField("category", "gubernator")) + set.Default(&conf.Logger, slog.New(slog.NewTextHandler(os.Stderr, nil)).With("category", "gubernator")) if conf.OwnAddress == "" { - return nil, errors.New("Advertise.GRPCAddress is required") + return nil, errors.New("AdvertiseAddress is required") } ctx, cancel := context.WithCancel(context.Background()) @@ -157,12 +158,11 @@ func peer(ip string, self string, ipv6 bool) PeerInfo { if ipv6 { ip = "[" + ip + "]" } - grpc := ip + ":1051" + addr := ip + ":1050" return PeerInfo{ DataCenter: "", - HTTPAddress: ip + ":1050", - GRPCAddress: grpc, - IsOwner: grpc == self, + HTTPAddress: addr, + IsOwner: addr == self, } } diff --git a/etcd.go b/etcd.go index 336dfb1..7e3b67c 100644 --- a/etcd.go +++ b/etcd.go @@ -18,21 +18,28 @@ package gubernator import ( "context" + "crypto/tls" + "crypto/x509" "encoding/json" - - "github.com/mailgun/holster/v4/clock" - "github.com/mailgun/holster/v4/errors" - "github.com/mailgun/holster/v4/setter" - "github.com/mailgun/holster/v4/syncutil" - "github.com/sirupsen/logrus" + "fmt" + "log/slog" + "os" + "strings" + + "github.com/kapetan-io/errors" + "github.com/kapetan-io/tackle/clock" + "github.com/kapetan-io/tackle/set" + "github.com/kapetan-io/tackle/wait" etcd "go.etcd.io/etcd/client/v3" + "google.golang.org/grpc/grpclog" ) const ( - etcdTimeout = clock.Second * 10 - backOffTimeout = clock.Second * 5 - leaseTTL = 30 - defaultBaseKey = "/gubernator/peers/" + etcdTimeout = clock.Second * 10 + backOffTimeout = clock.Second * 5 + leaseTTL = 30 + defaultBaseKey = "/gubernator/peers/" + localEtcdEndpoint = "127.0.0.1:2379" ) type PoolInterface interface { @@ -41,7 +48,7 @@ type PoolInterface interface { type EtcdPool struct { peers map[string]PeerInfo - wg syncutil.WaitGroup + wg wait.Group ctx 
context.Context cancelCtx context.CancelFunc watchChan etcd.WatchChan @@ -71,11 +78,11 @@ type EtcdPoolConfig struct { } func NewEtcdPool(conf EtcdPoolConfig) (*EtcdPool, error) { - setter.SetDefault(&conf.KeyPrefix, defaultBaseKey) - setter.SetDefault(&conf.Logger, logrus.WithField("category", "gubernator")) + set.Default(&conf.KeyPrefix, defaultBaseKey) + set.Default(&conf.Logger, slog.Default().With("category", "gubernator")) - if conf.Advertise.GRPCAddress == "" { - return nil, errors.New("Advertise.GRPCAddress is required") + if conf.Advertise.HTTPAddress == "" { + return nil, errors.New("Advertise.HTTPAddress is required") } if conf.Client == nil { @@ -130,7 +137,10 @@ func (e *EtcdPool) watchPeers() error { select { case <-ready: - e.log.Infof("watching for peer changes '%s' at revision %d", e.conf.KeyPrefix, revision) + e.log.LogAttrs(context.TODO(), slog.LevelInfo, "watching for peer changes", + slog.String("key_prefix", e.conf.KeyPrefix), + slog.Int64("revision", revision), + ) case <-clock.After(etcdTimeout): return errors.New("timed out while waiting for watcher.Watch() to start") } @@ -143,14 +153,14 @@ func (e *EtcdPool) collectPeers(revision *int64) error { resp, err := e.conf.Client.Get(ctx, e.conf.KeyPrefix, etcd.WithPrefix()) if err != nil { - return errors.Wrapf(err, "while fetching peer listing from '%s'", e.conf.KeyPrefix) + return errors.Errorf("while fetching peer listing from '%s': %w", e.conf.KeyPrefix, err) } peers := make(map[string]PeerInfo) // Collect all the peers for _, v := range resp.Kvs { p := e.unMarshallValue(v.Value) - peers[p.GRPCAddress] = p + peers[p.HTTPAddress] = p } e.peers = peers @@ -164,8 +174,10 @@ func (e *EtcdPool) unMarshallValue(v []byte) PeerInfo { // for backward compatible with older gubernator versions if err := json.Unmarshal(v, &p); err != nil { - e.log.WithError(err).Errorf("while unmarshalling peer info from key value") - return PeerInfo{GRPCAddress: string(v)} + e.log.LogAttrs(context.TODO(), slog.LevelError, "while unmarshalling peer info from key value", + ErrAttr(err), + ) + return PeerInfo{HTTPAddress: string(v)} } return p } @@ -175,18 +187,20 @@ func (e *EtcdPool) watch() error { // Initialize watcher if err := e.watchPeers(); err != nil { - return errors.Wrap(err, "while attempting to start watch") + return errors.Errorf("while attempting to start watch: %w", err) } e.wg.Until(func(done chan struct{}) bool { for response := range e.watchChan { if response.Canceled { - e.log.Infof("graceful watch shutdown") + e.log.Info("graceful watch shutdown") return false } if err := response.Err(); err != nil { - e.log.Errorf("watch error: %v", err) + e.log.LogAttrs(context.TODO(), slog.LevelError, "watch error", + ErrAttr(err), + ) goto restart } _ = e.collectPeers(&rev) @@ -203,8 +217,9 @@ func (e *EtcdPool) watch() error { } if err := e.watchPeers(); err != nil { - e.log.WithError(err). 
- Error("while attempting to restart watch") + e.log.LogAttrs(context.TODO(), slog.LevelError, "while attempting to restart watch", + ErrAttr(err), + ) select { case <-clock.After(backOffTimeout): return true @@ -219,12 +234,14 @@ func (e *EtcdPool) watch() error { } func (e *EtcdPool) register(peer PeerInfo) error { - instanceKey := e.conf.KeyPrefix + peer.GRPCAddress - e.log.Infof("Registering peer '%#v' with etcd", peer) + instanceKey := e.conf.KeyPrefix + peer.HTTPAddress + e.log.LogAttrs(context.TODO(), slog.LevelInfo, "Registering peer with etcd", + slog.Any("peer", peer), + ) b, err := json.Marshal(peer) if err != nil { - return errors.Wrap(err, "while marshalling PeerInfo") + return errors.Errorf("while marshalling PeerInfo: %w", err) } var keepAlive <-chan *etcd.LeaseKeepAliveResponse @@ -237,12 +254,12 @@ func (e *EtcdPool) register(peer PeerInfo) error { lease, err = e.conf.Client.Grant(ctx, leaseTTL) if err != nil { - return errors.Wrapf(err, "during grant lease") + return errors.Errorf("during grant lease: %w", err) } _, err = e.conf.Client.Put(ctx, instanceKey, string(b), etcd.WithLease(lease.ID)) if err != nil { - return errors.Wrap(err, "during put") + return errors.Errorf("during put: %w", err) } if keepAlive, err = e.conf.Client.KeepAlive(e.ctx, lease.ID); err != nil { @@ -255,15 +272,16 @@ func (e *EtcdPool) register(peer PeerInfo) error { // Attempt to register our instance with etcd if err = register(); err != nil { - return errors.Wrap(err, "during initial peer registration") + return errors.Errorf("during initial peer registration: %w", err) } e.wg.Until(func(done chan struct{}) bool { // If we have lost our keep alive, register again if keepAlive == nil { if err = register(); err != nil { - e.log.WithError(err). - Error("while attempting to re-register peer") + e.log.LogAttrs(context.TODO(), slog.LevelError, "while attempting to re-register peer", + ErrAttr(err), + ) select { case <-clock.After(backOffTimeout): return true @@ -297,13 +315,15 @@ func (e *EtcdPool) register(peer PeerInfo) error { case <-done: ctx, cancel := context.WithTimeout(context.Background(), etcdTimeout) if _, err := e.conf.Client.Delete(ctx, instanceKey); err != nil { - e.log.WithError(err). - Warn("during etcd delete") + e.log.LogAttrs(context.TODO(), slog.LevelError, "during etcd delete", + ErrAttr(err), + ) } if _, err := e.conf.Client.Revoke(ctx, lease.ID); err != nil { - e.log.WithError(err). - Warn("during lease revoke") + e.log.LogAttrs(context.TODO(), slog.LevelError, "during lease revoke", + ErrAttr(err), + ) } cancel() return false @@ -323,7 +343,7 @@ func (e *EtcdPool) callOnUpdate() { var peers []PeerInfo for _, p := range e.peers { - if p.GRPCAddress == e.conf.Advertise.GRPCAddress { + if p.HTTPAddress == e.conf.Advertise.HTTPAddress { p.IsOwner = true } peers = append(peers, p) @@ -332,13 +352,13 @@ func (e *EtcdPool) callOnUpdate() { e.conf.OnUpdate(peers) } -// Get peers list from etcd. +// GetPeers returns a list of peers from etcd. 
func (e *EtcdPool) GetPeers(ctx context.Context) ([]PeerInfo, error) {
 	keyPrefix := e.conf.KeyPrefix
 
 	resp, err := e.conf.Client.Get(ctx, keyPrefix, etcd.WithPrefix())
 	if err != nil {
-		return nil, errors.Wrapf(err, "while fetching peer listing from '%s'", keyPrefix)
+		return nil, errors.Errorf("while fetching peer listing from '%s': %w", keyPrefix, err)
 	}
 
 	var peers []PeerInfo
@@ -350,3 +370,104 @@ func (e *EtcdPool) GetPeers(ctx context.Context) ([]PeerInfo, error) {
 
 	return peers, nil
 }
+
+func init() {
+	// We check this here to avoid data race with GRPC go routines writing to the logger
+	if os.Getenv("ETCD3_DEBUG") != "" {
+		grpclog.SetLoggerV2(grpclog.NewLoggerV2WithVerbosity(os.Stderr, os.Stderr, os.Stderr, 4))
+	}
+}
+
+// newEtcdClient creates a new etcd.Client with the specified config where blanks
+// are filled from environment variables by newConfig.
+//
+// If the provided config is nil and no environment variables are set, it will
+// return a client connecting without TLS via localhost:2379.
+func newEtcdClient(cfg *etcd.Config) (*etcd.Client, error) {
+	var err error
+	if cfg, err = newConfig(cfg); err != nil {
+		return nil, fmt.Errorf("failed to build etcd config: %w", err)
+	}
+
+	etcdClt, err := etcd.New(*cfg)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create etcd client: %w", err)
+	}
+	return etcdClt, nil
+}
+
+// newConfig creates a new etcd.Config using environment variables. If an
+// existing config is passed, it will fill in missing configuration using
+// environment variables or defaults if they exist on the local system.
+//
+// If no environment variables are set, it will return a config set to
+// connect without TLS via localhost:2379.
+func newConfig(cfg *etcd.Config) (*etcd.Config, error) {
+	var envEndpoint, tlsCertFile, tlsKeyFile, tlsCAFile string
+
+	set.Default(&cfg, &etcd.Config{})
+	set.Default(&cfg.Username, os.Getenv("ETCD3_USER"))
+	set.Default(&cfg.Password, os.Getenv("ETCD3_PASSWORD"))
+	set.Default(&tlsCertFile, os.Getenv("ETCD3_TLS_CERT"))
+	set.Default(&tlsKeyFile, os.Getenv("ETCD3_TLS_KEY"))
+	set.Default(&tlsCAFile, os.Getenv("ETCD3_CA"))
+
+	// Default to 5 second timeout, else connections hang indefinitely
+	set.Default(&cfg.DialTimeout, clock.Second*5)
+	// Or if the user provided a timeout
+	if timeout := os.Getenv("ETCD3_DIAL_TIMEOUT"); timeout != "" {
+		duration, err := clock.ParseDuration(timeout)
+		if err != nil {
+			return nil, errors.Errorf(
+				"ETCD3_DIAL_TIMEOUT='%s' is not a duration (1m|15s|24h): %s", timeout, err)
+		}
+		cfg.DialTimeout = duration
+	}
+
+	defaultCfg := &tls.Config{
+		MinVersion: tls.VersionTLS12,
+	}
+
+	// If the CA file was provided
+	if tlsCAFile != "" {
+		set.Default(&cfg.TLS, defaultCfg)
+
+		var certPool *x509.CertPool = nil
+		if pemBytes, err := os.ReadFile(tlsCAFile); err == nil {
+			certPool = x509.NewCertPool()
+			certPool.AppendCertsFromPEM(pemBytes)
+		} else {
+			return nil, errors.Errorf("while loading cert CA file '%s': %s", tlsCAFile, err)
+		}
+		set.Default(&cfg.TLS.RootCAs, certPool)
+		cfg.TLS.InsecureSkipVerify = false
+	}
+
+	// If the cert and key files are provided attempt to load them
+	if tlsCertFile != "" && tlsKeyFile != "" {
+		set.Default(&cfg.TLS, defaultCfg)
+		tlsCert, err := tls.LoadX509KeyPair(tlsCertFile, tlsKeyFile)
+		if err != nil {
+			return nil, errors.Errorf("while loading cert '%s' and key file '%s': %s",
+				tlsCertFile, tlsKeyFile, err)
+		}
+		set.Default(&cfg.TLS.Certificates, []tls.Certificate{tlsCert})
+	}
+
+	set.Default(&envEndpoint, os.Getenv("ETCD3_ENDPOINT"),
localEtcdEndpoint) + set.Default(&cfg.Endpoints, strings.Split(envEndpoint, ",")) + + // If no other TLS config is provided this will force connecting with TLS, + // without cert verification + if os.Getenv("ETCD3_SKIP_VERIFY") != "" { + set.Default(&cfg.TLS, defaultCfg) + cfg.TLS.InsecureSkipVerify = true + } + + // Enable TLS with no additional configuration + if os.Getenv("ETCD3_ENABLE_TLS") != "" { + set.Default(&cfg.TLS, defaultCfg) + } + + return cfg, nil +} diff --git a/example.conf b/example.conf index f14c9c8..dfe109c 100644 --- a/example.conf +++ b/example.conf @@ -234,10 +234,6 @@ GUBER_INSTANCE_ID= # Set the name of the service which will be reported in traces # OTEL_SERVICE_NAME=gubernator -# Set the tracing level, this controls the number of spans included in a single trace. -# Valid options are (ERROR, INFO, DEBUG) Defaults to "ERROR" -# GUBER_TRACING_LEVEL=ERROR - # Set which sampler to use (always_on, always_off, traceidratio, parentbased_always_on, # parentbased_always_off, parentbased_traceidratio) # OTEL_TRACES_SAMPLER=always_on @@ -264,4 +260,13 @@ GUBER_INSTANCE_ID= ############################ # OTEL_EXPORTER_OTLP_PROTOCOL=otlp # OTEL_EXPORTER_OTLP_ENDPOINT=https://api.honeycomb.io -# OTEL_EXPORTER_OTLP_HEADERS=x-honeycomb-team= \ No newline at end of file +# OTEL_EXPORTER_OTLP_HEADERS=x-honeycomb-team= + +############################ +# Cache Providers +############################ +# +# Select the cache provider, available options are 'default-lru', 'otter' +# default-lru - A built in LRU implementation which uses a mutex +# otter - Is a lock-less cache implementation based on S3-FIFO algorithm (https://maypok86.github.io/otter/) +# GUBER_CACHE_PROVIDER=default-lru diff --git a/flags.go b/flags.go index f805d71..1c78a54 100644 --- a/flags.go +++ b/flags.go @@ -16,6 +16,11 @@ limitations under the License. package gubernator +import ( + "context" + "log/slog" +) + const ( FlagOSMetrics MetricFlags = 1 << iota FlagGolangMetrics @@ -50,7 +55,10 @@ func getEnvMetricFlags(log FieldLogger, name string) MetricFlags { case "golang": result.Set(FlagGolangMetrics, true) default: - log.Errorf("invalid flag '%s' for '%s' valid options are ['os', 'golang']", f, name) + log.LogAttrs(context.TODO(), slog.LevelError, "invalid flag, valid options are ['os', 'golang']", + slog.String("flag", f), + slog.String("name", name), + ) } } return result diff --git a/functional_test.go b/functional_test.go index a7bc2af..85a8d12 100644 --- a/functional_test.go +++ b/functional_test.go @@ -17,7 +17,6 @@ limitations under the License. 
package gubernator_test import ( - "bytes" "context" "fmt" "io" @@ -31,19 +30,16 @@ import ( "testing" "time" - guber "github.com/gubernator-io/gubernator/v2" - "github.com/gubernator-io/gubernator/v2/cluster" + guber "github.com/gubernator-io/gubernator/v3" + "github.com/gubernator-io/gubernator/v3/cluster" + "github.com/kapetan-io/tackle/clock" + "github.com/kapetan-io/tackle/wait" "github.com/mailgun/errors" - "github.com/mailgun/holster/v4/clock" - "github.com/mailgun/holster/v4/syncutil" - "github.com/mailgun/holster/v4/testutil" "github.com/prometheus/common/expfmt" "github.com/prometheus/common/model" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/exp/maps" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" json "google.golang.org/protobuf/encoding/protojson" ) @@ -56,15 +52,15 @@ func TestMain(m *testing.M) { } code := m.Run() - cluster.Stop() + cluster.Stop(context.Background()) // os.Exit doesn't run deferred functions os.Exit(code) } func TestOverTheLimit(t *testing.T) { - client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) - require.NoError(t, err) + client, errs := guber.NewClient(cluster.GetRandomPeerOptions(cluster.DataCenterNone)) + require.Nil(t, errs) tests := []struct { Remaining int64 @@ -85,8 +81,9 @@ func TestOverTheLimit(t *testing.T) { } for _, test := range tests { - resp, err := client.GetRateLimits(context.Background(), &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + var resp guber.CheckRateLimitsResponse + err := client.CheckRateLimits(context.Background(), &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { Name: "test_over_limit", UniqueKey: "account:1234", @@ -97,11 +94,12 @@ func TestOverTheLimit(t *testing.T) { Behavior: 0, }, }, - }) - require.NoError(t, err) + }, &resp) + require.Nil(t, err) rl := resp.Responses[0] + assert.Equal(t, "", rl.Error) assert.Equal(t, test.Status, rl.Status) assert.Equal(t, test.Remaining, rl.Remaining) assert.Equal(t, int64(2), rl.Limit) @@ -116,12 +114,13 @@ func TestMultipleAsync(t *testing.T) { // need to be changed. We want the test to forward both rate limits to other // nodes in the cluster. 
- t.Logf("Asking Peer: %s", cluster.GetPeers()[0].GRPCAddress) - client, errs := guber.DialV1Server(cluster.GetPeers()[0].GRPCAddress, nil) - require.NoError(t, errs) + t.Logf("Asking Peer: %s", cluster.GetPeers()[0].HTTPAddress) + client, errs := guber.NewClient(guber.WithNoTLS(cluster.GetPeers()[0].HTTPAddress)) + require.Nil(t, errs) - resp, err := client.GetRateLimits(context.Background(), &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + var resp guber.CheckRateLimitsResponse + err := client.CheckRateLimits(context.Background(), &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { Name: "test_multiple_async", UniqueKey: "account:9234", @@ -141,8 +140,8 @@ func TestMultipleAsync(t *testing.T) { Behavior: 0, }, }, - }) - require.NoError(t, err) + }, &resp) + require.Nil(t, err) require.Len(t, resp.Responses, 2) @@ -158,11 +157,11 @@ func TestMultipleAsync(t *testing.T) { } func TestTokenBucket(t *testing.T) { - defer clock.Freeze(clock.Now()).Unfreeze() + defer clock.Freeze(clock.Now()).UnFreeze() - addr := cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress - client, err := guber.DialV1Server(addr, nil) - require.NoError(t, err) + addr := cluster.GetRandomPeerInfo(cluster.DataCenterNone).HTTPAddress + client, errs := guber.NewClient(guber.WithNoTLS(addr)) + require.Nil(t, errs) tests := []struct { name string @@ -192,8 +191,9 @@ func TestTokenBucket(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - resp, err := client.GetRateLimits(context.Background(), &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + var resp guber.CheckRateLimitsResponse + err := client.CheckRateLimits(context.Background(), &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { Name: "test_token_bucket", UniqueKey: "account:1234", @@ -203,8 +203,8 @@ func TestTokenBucket(t *testing.T) { Hits: 1, }, }, - }) - require.NoError(t, err) + }, &resp) + require.Nil(t, err) rl := resp.Responses[0] @@ -219,10 +219,10 @@ func TestTokenBucket(t *testing.T) { } func TestTokenBucketGregorian(t *testing.T) { - defer clock.Freeze(clock.Now()).Unfreeze() + defer clock.Freeze(clock.Now()).UnFreeze() - client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) - require.NoError(t, err) + client, errs := guber.NewClient(guber.WithNoTLS(cluster.GetRandomPeerInfo(cluster.DataCenterNone).HTTPAddress)) + require.Nil(t, errs) tests := []struct { Name string @@ -266,8 +266,9 @@ func TestTokenBucketGregorian(t *testing.T) { for _, test := range tests { t.Run(test.Name, func(t *testing.T) { - resp, err := client.GetRateLimits(context.Background(), &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + var resp guber.CheckRateLimitsResponse + err := client.CheckRateLimits(context.Background(), &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { Name: "test_token_bucket_greg", UniqueKey: "account:12345", @@ -278,8 +279,8 @@ func TestTokenBucketGregorian(t *testing.T) { Limit: 60, }, }, - }) - require.NoError(t, err) + }, &resp) + require.Nil(t, err) rl := resp.Responses[0] @@ -294,11 +295,11 @@ func TestTokenBucketGregorian(t *testing.T) { } func TestTokenBucketNegativeHits(t *testing.T) { - defer clock.Freeze(clock.Now()).Unfreeze() + defer clock.Freeze(clock.Now()).UnFreeze() - addr := cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress - client, err := guber.DialV1Server(addr, nil) - require.NoError(t, err) + addr := 
cluster.GetRandomPeerInfo(cluster.DataCenterNone).HTTPAddress + client, errs := guber.NewClient(guber.WithNoTLS(addr)) + require.Nil(t, errs) tests := []struct { name string @@ -339,8 +340,9 @@ func TestTokenBucketNegativeHits(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - resp, err := client.GetRateLimits(context.Background(), &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + var resp guber.CheckRateLimitsResponse + err := client.CheckRateLimits(context.Background(), &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { Name: "test_token_bucket_negative", UniqueKey: "account:12345", @@ -350,8 +352,8 @@ func TestTokenBucketNegativeHits(t *testing.T) { Hits: tt.Hits, }, }, - }) - require.NoError(t, err) + }, &resp) + require.Nil(t, err) rl := resp.Responses[0] @@ -366,9 +368,9 @@ func TestTokenBucketNegativeHits(t *testing.T) { } func TestDrainOverLimit(t *testing.T) { - defer clock.Freeze(clock.Now()).Unfreeze() - client, err := guber.DialV1Server(cluster.PeerAt(0).GRPCAddress, nil) - require.NoError(t, err) + defer clock.Freeze(clock.Now()).UnFreeze() + client, errs := guber.NewClient(cluster.GetRandomPeerOptions(cluster.DataCenterNone)) + require.Nil(t, errs) tests := []struct { Name string @@ -404,8 +406,9 @@ func TestDrainOverLimit(t *testing.T) { for _, test := range tests { ctx := context.Background() t.Run(test.Name, func(t *testing.T) { - resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + var resp guber.CheckRateLimitsResponse + err := client.CheckRateLimits(ctx, &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { Name: "test_drain_over_limit", UniqueKey: fmt.Sprintf("account:1234:%d", idx), @@ -416,7 +419,7 @@ func TestDrainOverLimit(t *testing.T) { Limit: 10, }, }, - }) + }, &resp) require.NoError(t, err) require.Len(t, resp.Responses, 1) @@ -432,16 +435,17 @@ func TestDrainOverLimit(t *testing.T) { } func TestTokenBucketRequestMoreThanAvailable(t *testing.T) { - defer clock.Freeze(clock.Now()).Unfreeze() + defer clock.Freeze(clock.Now()).UnFreeze() - client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) - require.NoError(t, err) + client, errs := guber.NewClient(cluster.GetRandomPeerOptions(cluster.DataCenterNone)) + require.Nil(t, errs) - sendHit := func(status guber.Status, remain int64, hit int64) *guber.RateLimitResp { + sendHit := func(status guber.Status, remain int64, hit int64) *guber.RateLimitResponse { ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) defer cancel() - resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + var resp guber.CheckRateLimitsResponse + err := client.CheckRateLimits(ctx, &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { Name: "test_token_more_than_available", UniqueKey: "account:123456", @@ -451,7 +455,7 @@ func TestTokenBucketRequestMoreThanAvailable(t *testing.T) { Limit: 2000, }, }, - }) + }, &resp) require.NoError(t, err, hit) assert.Equal(t, "", resp.Responses[0].Error) assert.Equal(t, status, resp.Responses[0].Status) @@ -475,10 +479,10 @@ func TestTokenBucketRequestMoreThanAvailable(t *testing.T) { } func TestLeakyBucket(t *testing.T) { - defer clock.Freeze(clock.Now()).Unfreeze() + defer clock.Freeze(clock.Now()).UnFreeze() - client, err := guber.DialV1Server(cluster.PeerAt(0).GRPCAddress, nil) - require.NoError(t, err) + client, errs := 
guber.NewClient(guber.WithNoTLS(cluster.PeerAt(0).HTTPAddress)) + require.Nil(t, errs) tests := []struct { Name string @@ -575,8 +579,9 @@ func TestLeakyBucket(t *testing.T) { for _, test := range tests { t.Run(test.Name, func(t *testing.T) { - resp, err := client.GetRateLimits(context.Background(), &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + var resp guber.CheckRateLimitsResponse + err := client.CheckRateLimits(context.Background(), &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { Name: "test_leaky_bucket", UniqueKey: "account:1234", @@ -586,7 +591,7 @@ func TestLeakyBucket(t *testing.T) { Limit: 10, }, }, - }) + }, &resp) require.NoError(t, err) require.Len(t, resp.Responses, 1) @@ -602,10 +607,10 @@ func TestLeakyBucket(t *testing.T) { } func TestLeakyBucketWithBurst(t *testing.T) { - defer clock.Freeze(clock.Now()).Unfreeze() + defer clock.Freeze(clock.Now()).UnFreeze() - client, err := guber.DialV1Server(cluster.PeerAt(0).GRPCAddress, nil) - require.NoError(t, err) + client, errs := guber.NewClient(guber.WithNoTLS(cluster.PeerAt(0).HTTPAddress)) + require.Nil(t, errs) tests := []struct { Name string @@ -681,8 +686,9 @@ func TestLeakyBucketWithBurst(t *testing.T) { for _, test := range tests { t.Run(test.Name, func(t *testing.T) { - resp, err := client.GetRateLimits(context.Background(), &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + var resp guber.CheckRateLimitsResponse + err := client.CheckRateLimits(context.Background(), &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { Name: "test_leaky_bucket_with_burst", UniqueKey: "account:1234", @@ -693,7 +699,7 @@ func TestLeakyBucketWithBurst(t *testing.T) { Burst: 20, }, }, - }) + }, &resp) require.NoError(t, err) require.Len(t, resp.Responses, 1) @@ -709,10 +715,10 @@ func TestLeakyBucketWithBurst(t *testing.T) { } func TestLeakyBucketGregorian(t *testing.T) { - defer clock.Freeze(clock.Now()).Unfreeze() + defer clock.Freeze(clock.Now().Round(clock.Minute)).UnFreeze() - client, err := guber.DialV1Server(cluster.PeerAt(0).GRPCAddress, nil) - require.NoError(t, err) + client, errs := guber.NewClient(guber.WithNoTLS(cluster.PeerAt(0).HTTPAddress)) + require.Nil(t, errs) tests := []struct { Name string @@ -733,7 +739,7 @@ func TestLeakyBucketGregorian(t *testing.T) { Hits: 1, Remaining: 58, Status: guber.Status_UNDER_LIMIT, - Sleep: clock.Millisecond * 1200, + Sleep: clock.Second, }, { Name: "third hit; leak one hit", @@ -745,7 +751,7 @@ func TestLeakyBucketGregorian(t *testing.T) { // Truncate to the nearest minute. 
now := clock.Now() - trunc := now.Truncate(time.Hour) + trunc := now.Truncate(clock.Hour) trunc = now.Add(now.Sub(trunc)) clock.Advance(now.Sub(trunc)) @@ -754,8 +760,9 @@ func TestLeakyBucketGregorian(t *testing.T) { for _, test := range tests { t.Run(test.Name, func(t *testing.T) { - resp, err := client.GetRateLimits(context.Background(), &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + var resp guber.CheckRateLimitsResponse + err := client.CheckRateLimits(context.Background(), &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { Name: name, UniqueKey: key, @@ -766,24 +773,26 @@ func TestLeakyBucketGregorian(t *testing.T) { Limit: 60, }, }, - }) + }, &resp) + clock.Freeze(clock.Now()) require.NoError(t, err) rl := resp.Responses[0] + assert.Equal(t, test.Status, rl.Status) assert.Equal(t, test.Remaining, rl.Remaining) assert.Equal(t, int64(60), rl.Limit) - assert.Greater(t, rl.ResetTime, now.Unix()) + assert.True(t, rl.ResetTime > now.Unix()) clock.Advance(test.Sleep) }) } } func TestLeakyBucketNegativeHits(t *testing.T) { - defer clock.Freeze(clock.Now()).Unfreeze() + defer clock.Freeze(clock.Now()).UnFreeze() - client, err := guber.DialV1Server(cluster.PeerAt(0).GRPCAddress, nil) - require.NoError(t, err) + client, errs := guber.NewClient(guber.WithNoTLS(cluster.PeerAt(0).HTTPAddress)) + require.Nil(t, errs) tests := []struct { Name string @@ -824,8 +833,9 @@ func TestLeakyBucketNegativeHits(t *testing.T) { for _, test := range tests { t.Run(test.Name, func(t *testing.T) { - resp, err := client.GetRateLimits(context.Background(), &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + var resp guber.CheckRateLimitsResponse + err := client.CheckRateLimits(context.Background(), &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { Name: "test_leaky_bucket_negative", UniqueKey: "account:12345", @@ -835,7 +845,7 @@ func TestLeakyBucketNegativeHits(t *testing.T) { Limit: 10, }, }, - }) + }, &resp) require.NoError(t, err) require.Len(t, resp.Responses, 1) @@ -852,16 +862,17 @@ func TestLeakyBucketNegativeHits(t *testing.T) { func TestLeakyBucketRequestMoreThanAvailable(t *testing.T) { // Freeze time so we don't leak during the test - defer clock.Freeze(clock.Now()).Unfreeze() + defer clock.Freeze(clock.Now()).UnFreeze() - client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) - require.NoError(t, err) + client, errs := guber.NewClient(cluster.GetRandomPeerOptions(cluster.DataCenterNone)) + require.Nil(t, errs) - sendHit := func(status guber.Status, remain int64, hits int64) *guber.RateLimitResp { + sendHit := func(status guber.Status, remain int64, hits int64) *guber.RateLimitResponse { ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) defer cancel() - resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + var resp guber.CheckRateLimitsResponse + err := client.CheckRateLimits(ctx, &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { Name: "test_leaky_more_than_available", UniqueKey: "account:123456", @@ -871,7 +882,7 @@ func TestLeakyBucketRequestMoreThanAvailable(t *testing.T) { Limit: 2000, }, }, - }) + }, &resp) require.NoError(t, err) assert.Equal(t, "", resp.Responses[0].Error) assert.Equal(t, status, resp.Responses[0].Status) @@ -895,16 +906,16 @@ func TestLeakyBucketRequestMoreThanAvailable(t *testing.T) { } func TestMissingFields(t *testing.T) { - client, err := 
guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) - require.NoError(t, err) + client, errs := guber.NewClient(guber.WithNoTLS(cluster.GetRandomPeerInfo(cluster.DataCenterNone).HTTPAddress)) + require.Nil(t, errs) tests := []struct { - Req *guber.RateLimitReq + Req *guber.RateLimitRequest Status guber.Status Error string }{ { - Req: &guber.RateLimitReq{ + Req: &guber.RateLimitRequest{ Name: "test_missing_fields", UniqueKey: "account:1234", Hits: 1, @@ -915,7 +926,7 @@ func TestMissingFields(t *testing.T) { Status: guber.Status_UNDER_LIMIT, }, { - Req: &guber.RateLimitReq{ + Req: &guber.RateLimitRequest{ Name: "test_missing_fields", UniqueKey: "account:12345", Hits: 1, @@ -926,7 +937,7 @@ func TestMissingFields(t *testing.T) { Status: guber.Status_OVER_LIMIT, }, { - Req: &guber.RateLimitReq{ + Req: &guber.RateLimitRequest{ UniqueKey: "account:1234", Hits: 1, Duration: 10000, @@ -936,7 +947,7 @@ func TestMissingFields(t *testing.T) { Status: guber.Status_UNDER_LIMIT, }, { - Req: &guber.RateLimitReq{ + Req: &guber.RateLimitRequest{ Name: "test_missing_fields", Hits: 1, Duration: 10000, @@ -948,10 +959,11 @@ func TestMissingFields(t *testing.T) { } for i, test := range tests { - resp, err := client.GetRateLimits(context.Background(), &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{test.Req}, - }) - require.NoError(t, err) + var resp guber.CheckRateLimitsResponse + err := client.CheckRateLimits(context.Background(), &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{test.Req}, + }, &resp) + require.Nil(t, err) assert.Equal(t, test.Error, resp.Responses[0].Error, i) assert.Equal(t, test.Status, resp.Responses[0].Status, i) } @@ -966,11 +978,12 @@ func TestGlobalRateLimits(t *testing.T) { require.NoError(t, err) var firstResetTime int64 - sendHit := func(client guber.V1Client, status guber.Status, hits, remain int64) { + sendHit := func(client guber.Client, status guber.Status, hits, remain int64) { ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) defer cancel() - resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + var resp guber.CheckRateLimitsResponse + err := client.CheckRateLimits(ctx, &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { Name: name, UniqueKey: key, @@ -981,7 +994,7 @@ func TestGlobalRateLimits(t *testing.T) { Limit: 5, }, }, - }) + }, &resp) require.NoError(t, err) item := resp.Responses[0] assert.Equal(t, "", item.Error) @@ -1007,13 +1020,13 @@ func TestGlobalRateLimits(t *testing.T) { // Our second should be processed as if we own it since the async forward hasn't occurred yet sendHit(peers[0].MustClient(), guber.Status_UNDER_LIMIT, 2, 2) - testutil.UntilPass(t, 20, clock.Millisecond*200, func(t testutil.TestingT) { + assert.EventuallyWithT(t, func(collect *assert.CollectT) { // Inspect peers metrics, ensure the peer sent the global rate limit to the owner metricsURL := fmt.Sprintf("http://%s/metrics", peers[0].Config().HTTPListenAddress) m, err := getMetricRequest(metricsURL, "gubernator_global_send_duration_count") assert.NoError(t, err) assert.Equal(t, 1, int(m.Value)) - }) + }, 2*clock.Second, clock.Millisecond*200) require.NoError(t, waitForBroadcast(clock.Second*3, owner, 1)) @@ -1030,9 +1043,9 @@ func TestGlobalRateLimits(t *testing.T) { sendHit(peers[4].MustClient(), guber.Status_OVER_LIMIT, 1, 0) } -// Ensure global broadcast updates all peers when GetRateLimits is called on +// Ensure global broadcast updates all 
peers when CheckRateLimits is called on // either owner or non-owner peer. -func TestGlobalRateLimitsWithLoadBalancing(t *testing.T) { +func TestGlobalRateLimitsBroadcastUpdate(t *testing.T) { name := t.Name() key := guber.RandomString(10) @@ -1043,22 +1056,19 @@ func TestGlobalRateLimitsWithLoadBalancing(t *testing.T) { require.NoError(t, err) nonOwner := peers[0] - // Connect to owner and non-owner peers in round robin. - dialOpts := []grpc.DialOption{ - grpc.WithResolvers(guber.NewStaticBuilder()), - grpc.WithTransportCredentials(insecure.NewCredentials()), - grpc.WithDefaultServiceConfig(`{"loadBalancingConfig": [{"round_robin":{}}]}`), - } - address := fmt.Sprintf("static:///%s,%s", owner.PeerInfo.GRPCAddress, nonOwner.PeerInfo.GRPCAddress) - conn, err := grpc.NewClient(address, dialOpts...) + // Create a client for an owner and non-owner + client := owner.MustClient() require.NoError(t, err) - client := guber.NewV1Client(conn) - sendHit := func(client guber.V1Client, status guber.Status, i int) { + peerClient := nonOwner.MustClient() + require.NoError(t, err) + + sendHit := func(client guber.Client, status guber.Status, i int) { ctx, cancel := context.WithTimeout(context.Background(), 10*clock.Second) defer cancel() - resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + var resp guber.CheckRateLimitsResponse + err := client.CheckRateLimits(ctx, &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { Name: name, UniqueKey: key, @@ -1069,7 +1079,7 @@ func TestGlobalRateLimitsWithLoadBalancing(t *testing.T) { Limit: 2, }, }, - }) + }, &resp) require.NoError(t, err, i) item := resp.Responses[0] assert.Equal(t, "", item.Error, fmt.Sprintf("unexpected error, iteration %d", i)) @@ -1081,7 +1091,7 @@ func TestGlobalRateLimitsWithLoadBalancing(t *testing.T) { // Send two hits that should be processed by the owner and non-owner and // deplete the limit consistently. sendHit(client, guber.Status_UNDER_LIMIT, 1) - sendHit(client, guber.Status_UNDER_LIMIT, 2) + sendHit(peerClient, guber.Status_UNDER_LIMIT, 2) require.NoError(t, waitForBroadcast(3*clock.Second, owner, 1)) // All successive hits should return OVER_LIMIT. 
@@ -1101,8 +1111,9 @@ func TestGlobalRateLimitsPeerOverLimit(t *testing.T) { sendHit := func(expectedStatus guber.Status, hits, expectedRemaining int64) { ctx, cancel := context.WithTimeout(context.Background(), 10*clock.Second) defer cancel() - resp, err := peers[0].MustClient().GetRateLimits(ctx, &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + var resp guber.CheckRateLimitsResponse + err := peers[0].MustClient().CheckRateLimits(ctx, &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { Name: name, UniqueKey: key, @@ -1113,7 +1124,7 @@ func TestGlobalRateLimitsPeerOverLimit(t *testing.T) { Limit: 2, }, }, - }) + }, &resp) assert.NoError(t, err) item := resp.Responses[0] assert.Equal(t, "", item.Error, "unexpected error") @@ -1149,11 +1160,12 @@ func TestGlobalRequestMoreThanAvailable(t *testing.T) { peers, err := cluster.ListNonOwningDaemons(name, key) require.NoError(t, err) - sendHit := func(client guber.V1Client, expectedStatus guber.Status, hits int64, remaining int64) { + sendHit := func(client guber.Client, expectedStatus guber.Status, hits int64, remaining int64) { ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) defer cancel() - resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + var resp guber.CheckRateLimitsResponse + err := client.CheckRateLimits(ctx, &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { Name: name, UniqueKey: key, @@ -1164,13 +1176,13 @@ func TestGlobalRequestMoreThanAvailable(t *testing.T) { Limit: 100, }, }, - }) + }, &resp) assert.NoError(t, err) assert.Equal(t, "", resp.Responses[0].GetError()) assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus()) } - require.NoError(t, waitForIdle(1*time.Minute, cluster.GetDaemons()...)) + require.NoError(t, waitForIdle(1*clock.Minute, cluster.GetDaemons()...)) prev := getMetricValue(t, owner, "gubernator_broadcast_duration_count") // Ensure GRPC has connections to each peer before we start, as we want @@ -1209,11 +1221,12 @@ func TestGlobalNegativeHits(t *testing.T) { peers, err := cluster.ListNonOwningDaemons(name, key) require.NoError(t, err) - sendHit := func(client guber.V1Client, status guber.Status, hits int64, remaining int64) { + sendHit := func(client guber.Client, status guber.Status, hits int64, remaining int64) { ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) defer cancel() - resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + var resp guber.CheckRateLimitsResponse + err := client.CheckRateLimits(ctx, &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { Name: name, UniqueKey: key, @@ -1224,14 +1237,14 @@ func TestGlobalNegativeHits(t *testing.T) { Limit: 2, }, }, - }) + }, &resp) assert.NoError(t, err) assert.Equal(t, "", resp.Responses[0].GetError()) assert.Equal(t, status, resp.Responses[0].GetStatus()) assert.Equal(t, remaining, resp.Responses[0].Remaining) } - require.NoError(t, waitForIdle(1*time.Minute, cluster.GetDaemons()...)) + require.NoError(t, waitForIdle(1*clock.Minute, cluster.GetDaemons()...)) prev := getMetricValue(t, owner, "gubernator_broadcast_duration_count") require.NoError(t, err) @@ -1263,11 +1276,12 @@ func TestGlobalResetRemaining(t *testing.T) { peers, err := cluster.ListNonOwningDaemons(name, key) require.NoError(t, err) - sendHit := func(client guber.V1Client, expectedStatus guber.Status, hits int64, remaining int64) { + sendHit := func(client guber.Client, 
expectedStatus guber.Status, hits int64, remaining int64) { ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) defer cancel() - resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + var resp guber.CheckRateLimitsResponse + err := client.CheckRateLimits(ctx, &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { Name: name, UniqueKey: key, @@ -1278,14 +1292,14 @@ func TestGlobalResetRemaining(t *testing.T) { Limit: 100, }, }, - }) + }, &resp) assert.NoError(t, err) assert.Equal(t, "", resp.Responses[0].GetError()) assert.Equal(t, expectedStatus, resp.Responses[0].GetStatus()) assert.Equal(t, remaining, resp.Responses[0].Remaining) } - require.NoError(t, waitForIdle(1*time.Minute, cluster.GetDaemons()...)) + require.NoError(t, waitForIdle(1*clock.Minute, cluster.GetDaemons()...)) prev := getMetricValue(t, owner, "gubernator_broadcast_duration_count") require.NoError(t, err) @@ -1303,8 +1317,9 @@ func TestGlobalResetRemaining(t *testing.T) { // Now reset the remaining ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) defer cancel() - resp, err := peers[0].MustClient().GetRateLimits(ctx, &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + var resp guber.CheckRateLimitsResponse + err = peers[0].MustClient().CheckRateLimits(ctx, &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { Name: name, UniqueKey: key, @@ -1315,7 +1330,7 @@ func TestGlobalResetRemaining(t *testing.T) { Limit: 100, }, }, - }) + }, &resp) require.NoError(t, err) assert.NotEqual(t, 100, resp.Responses[0].Remaining) @@ -1323,8 +1338,8 @@ func TestGlobalResetRemaining(t *testing.T) { require.NoError(t, waitForBroadcast(clock.Second*10, owner, prev+2)) // Check a different peer to ensure remaining has been reset - resp, err = peers[1].MustClient().GetRateLimits(ctx, &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + err = peers[1].MustClient().CheckRateLimits(ctx, &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { Name: name, UniqueKey: key, @@ -1335,14 +1350,14 @@ func TestGlobalResetRemaining(t *testing.T) { Limit: 100, }, }, - }) + }, &resp) require.NoError(t, err) assert.NotEqual(t, 100, resp.Responses[0].Remaining) } func TestChangeLimit(t *testing.T) { - client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) - require.NoError(t, err) + client, errs := guber.NewClient(guber.WithNoTLS(cluster.GetRandomPeerInfo(cluster.DataCenterNone).HTTPAddress)) + require.Nil(t, errs) tests := []struct { Remaining int64 @@ -1411,8 +1426,9 @@ func TestChangeLimit(t *testing.T) { for _, tt := range tests { t.Run(tt.Name, func(t *testing.T) { - resp, err := client.GetRateLimits(context.Background(), &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + var resp guber.CheckRateLimitsResponse + err := client.CheckRateLimits(context.Background(), &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { Name: "test_change_limit", UniqueKey: "account:1234", @@ -1422,8 +1438,8 @@ func TestChangeLimit(t *testing.T) { Hits: 1, }, }, - }) - require.NoError(t, err) + }, &resp) + require.Nil(t, err) rl := resp.Responses[0] @@ -1436,8 +1452,8 @@ func TestChangeLimit(t *testing.T) { } func TestResetRemaining(t *testing.T) { - client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) - require.NoError(t, err) + client, errs := 
guber.NewClient(guber.WithNoTLS(cluster.GetRandomPeerInfo(cluster.DataCenterNone).HTTPAddress)) + require.Nil(t, errs) tests := []struct { Remaining int64 @@ -1483,8 +1499,9 @@ func TestResetRemaining(t *testing.T) { for _, tt := range tests { t.Run(tt.Name, func(t *testing.T) { - resp, err := client.GetRateLimits(context.Background(), &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + var resp guber.CheckRateLimitsResponse + err := client.CheckRateLimits(context.Background(), &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { Name: "test_reset_remaining", UniqueKey: "account:1234", @@ -1495,8 +1512,8 @@ func TestResetRemaining(t *testing.T) { Hits: 1, }, }, - }) - require.NoError(t, err) + }, &resp) + require.Nil(t, err) rl := resp.Responses[0] @@ -1508,39 +1525,91 @@ func TestResetRemaining(t *testing.T) { } func TestHealthCheck(t *testing.T) { - // Check that the cluster is healthy to start with. - for _, peer := range cluster.GetDaemons() { - healthResp, err := peer.MustClient().HealthCheck(context.Background(), &guber.HealthCheckReq{}) + d := cluster.DaemonAt(0) + client, err := guber.NewClient(guber.WithNoTLS(d.Listener.Addr().String())) + require.NoError(t, err) + + // Check that the cluster is healthy to start with + var resp guber.HealthCheckResponse + err = client.HealthCheck(context.Background(), &resp) + require.NoError(t, err) + + require.Equal(t, "healthy", resp.GetStatus()) + + // Create a global rate limit that will need to be sent to all peers in the cluster + { + var resp guber.CheckRateLimitsResponse + err = client.CheckRateLimits(context.Background(), &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ + { + Name: "test_health_check", + UniqueKey: "account:12345", + Algorithm: guber.Algorithm_TOKEN_BUCKET, + Behavior: guber.Behavior_BATCHING, + Duration: guber.Second * 3, + Hits: 1, + Limit: 5, + }, + }, + }, &resp) require.NoError(t, err) - assert.Equal(t, "healthy", healthResp.Status) } - // Stop the cluster to ensure errors occur on our instance. - cluster.Stop() + // Stop the rest of the cluster to ensure errors occur on our instance + for i := 1; i < cluster.NumOfDaemons(); i++ { + d := cluster.DaemonAt(i) + require.NotNil(t, d) + _ = d.Close(context.Background()) + } - // Check the health again to get back the connection error. - testutil.UntilPass(t, 20, 300*clock.Millisecond, func(t testutil.TestingT) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - for _, peer := range cluster.GetDaemons() { - _, err := peer.MustClient().HealthCheck(ctx, &guber.HealthCheckReq{}) - assert.Error(t, err, "connect: connection refused") + // Hit the global rate limit again this time causing a connection error + { + var resp guber.CheckRateLimitsResponse + err = client.CheckRateLimits(context.Background(), &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ + { + Name: "test_health_check", + UniqueKey: "account:12345", + Algorithm: guber.Algorithm_TOKEN_BUCKET, + Behavior: guber.Behavior_GLOBAL, + Duration: guber.Second * 3, + Hits: 1, + Limit: 5, + }, + }, + }, &resp) + require.NoError(t, err) + } + + assert.EventuallyWithT(t, func(collect *assert.CollectT) { + // Check the health again to get back the connection error + var resp guber.HealthCheckResponse + err = client.HealthCheck(context.Background(), &resp) + if assert.Nil(t, err) { + return } - }) - // Restart so cluster is ready for next test. 
- require.NoError(t, startGubernator()) + assert.Equal(t, "unhealthy", resp.GetStatus()) + assert.Contains(t, resp.GetMessage(), "connect: connection refused") + }, 2*clock.Second, clock.Millisecond*300) + + // Restart stopped instances + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*15) + defer cancel() + require.NoError(t, cluster.Restart(ctx)) } func TestLeakyBucketDivBug(t *testing.T) { - defer clock.Freeze(clock.Now()).Unfreeze() + defer clock.Freeze(clock.Now()).UnFreeze() name := t.Name() key := guber.RandomString(10) - client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) + d := cluster.DaemonAt(0) + client, err := guber.NewClient(guber.WithNoTLS(d.Listener.Addr().String())) require.NoError(t, err) - resp, err := client.GetRateLimits(context.Background(), &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + var resp guber.CheckRateLimitsResponse + err = client.CheckRateLimits(context.Background(), &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { Name: name, UniqueKey: key, @@ -1550,7 +1619,7 @@ func TestLeakyBucketDivBug(t *testing.T) { Limit: 2000, }, }, - }) + }, &resp) require.NoError(t, err) assert.Equal(t, "", resp.Responses[0].Error) assert.Equal(t, guber.Status_UNDER_LIMIT, resp.Responses[0].Status) @@ -1558,8 +1627,8 @@ func TestLeakyBucketDivBug(t *testing.T) { assert.Equal(t, int64(2000), resp.Responses[0].Limit) // Should result in a rate of 0.5 - resp, err = client.GetRateLimits(context.Background(), &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + err = client.CheckRateLimits(context.Background(), &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { Name: name, UniqueKey: key, @@ -1569,111 +1638,69 @@ func TestLeakyBucketDivBug(t *testing.T) { Limit: 2000, }, }, - }) + }, &resp) require.NoError(t, err) assert.Equal(t, int64(1899), resp.Responses[0].Remaining) assert.Equal(t, int64(2000), resp.Responses[0].Limit) } -func TestMultiRegion(t *testing.T) { - - // TODO: Queue a rate limit with multi region behavior on the DataCenterNone cluster - // TODO: Check the immediate response is correct - // TODO: Wait until the rate limit count shows up on the DataCenterOne and DataCenterTwo cluster - - // TODO: Increment the counts on the DataCenterTwo and DataCenterOne clusters - // TODO: Wait until both rate limit count show up on all datacenters -} - -func TestGRPCGateway(t *testing.T) { - name := t.Name() - key := guber.RandomString(10) - address := cluster.GetRandomPeer(cluster.DataCenterNone).HTTPAddress - resp, err := http.DefaultClient.Get("http://" + address + "/v1/HealthCheck") +func TestDefaultHealthZ(t *testing.T) { + address := cluster.GetRandomPeerInfo(cluster.DataCenterNone).HTTPAddress + resp, err := http.DefaultClient.Get("http://" + address + "/healthz") require.NoError(t, err) defer resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode) b, err := io.ReadAll(resp.Body) - // This test ensures future upgrades don't accidentally change `under_score` to `camelCase` again. 
assert.Contains(t, string(b), "peer_count") - var hc guber.HealthCheckResp + var hc guber.HealthCheckResponse require.NoError(t, json.Unmarshal(b, &hc)) assert.Equal(t, int32(10), hc.PeerCount) require.NoError(t, err) - - payload, err := json.Marshal(&guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ - { - Name: name, - UniqueKey: key, - Duration: guber.Millisecond * 1000, - Hits: 1, - Limit: 10, - }, - }, - }) - require.NoError(t, err) - - resp, err = http.DefaultClient.Post("http://"+address+"/v1/GetRateLimits", - "application/json", bytes.NewReader(payload)) - require.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, http.StatusOK, resp.StatusCode) - b, err = io.ReadAll(resp.Body) - require.NoError(t, err) - var r guber.GetRateLimitsResp - - // NOTE: It is important to use 'protojson' instead of the standard 'json' package - // else the enums will not be converted properly and json.Unmarshal() will return an - // error. - require.NoError(t, json.Unmarshal(b, &r)) - require.Equal(t, 1, len(r.Responses)) - assert.Equal(t, guber.Status_UNDER_LIMIT, r.Responses[0].Status) } func TestGetPeerRateLimits(t *testing.T) { name := t.Name() ctx := context.Background() - peerClient, err := guber.NewPeerClient(guber.PeerConfig{ - Info: cluster.GetRandomPeer(cluster.DataCenterNone), + info := cluster.GetRandomPeerInfo(cluster.DataCenterNone) + + peerClient, err := guber.NewPeer(guber.PeerConfig{ + PeerClient: guber.NewPeerClient(guber.WithNoTLS(info.HTTPAddress)), + Info: info, }) require.NoError(t, err) t.Run("Stable rate check request order", func(t *testing.T) { // Ensure response order matches rate check request order. // Try various batch sizes. - createdAt := epochMillis(clock.Now()) testCases := []int{1, 2, 5, 10, 100, 1000} for _, n := range testCases { t.Run(fmt.Sprintf("Batch size %d", n), func(t *testing.T) { // Build request. - req := &guber.GetPeerRateLimitsReq{ - Requests: make([]*guber.RateLimitReq, n), + req := &guber.ForwardRequest{ + Requests: make([]*guber.RateLimitRequest, n), } for i := 0; i < n; i++ { - req.Requests[i] = &guber.RateLimitReq{ + req.Requests[i] = &guber.RateLimitRequest{ Name: name, - UniqueKey: guber.RandomString(10), + UniqueKey: fmt.Sprintf("%08x", i), Hits: 0, Limit: 1000 + int64(i), Duration: 1000, Algorithm: guber.Algorithm_TOKEN_BUCKET, Behavior: guber.Behavior_BATCHING, - CreatedAt: &createdAt, } } // Send request. - resp, err := peerClient.GetPeerRateLimits(ctx, req) + var resp guber.ForwardResponse + err := peerClient.ForwardBatch(ctx, req, &resp) // Verify. 
require.NoError(t, err) - require.NotNil(t, resp) assert.Len(t, resp.RateLimits, n) for i, item := range resp.RateLimits { @@ -1685,15 +1712,22 @@ func TestGetPeerRateLimits(t *testing.T) { }) } -// TODO: Add a test for sending no rate limits RateLimitReqList.RateLimits = nil +func TestNoRateLimits(t *testing.T) { + client, errs := guber.NewClient(cluster.GetRandomPeerOptions(cluster.DataCenterNone)) + require.Nil(t, errs) + + var resp guber.CheckRateLimitsResponse + err := client.CheckRateLimits(context.Background(), &guber.CheckRateLimitsRequest{}, &resp) + require.Error(t, err) +} func TestGlobalBehavior(t *testing.T) { const limit = 1000 - broadcastTimeout := 400 * time.Millisecond + broadcastTimeout := 400 * clock.Millisecond createdAt := epochMillis(clock.Now()) - makeReq := func(name, key string, hits int64) *guber.RateLimitReq { - return &guber.RateLimitReq{ + makeReq := func(name, key string, hits int64) *guber.RateLimitRequest { + return &guber.RateLimitRequest{ Name: name, UniqueKey: key, Algorithm: guber.Algorithm_TOKEN_BUCKET, @@ -1724,12 +1758,12 @@ func TestGlobalBehavior(t *testing.T) { require.NoError(t, err) t.Logf("Owner peer: %s", owner.InstanceID) - require.NoError(t, waitForIdle(1*time.Minute, cluster.GetDaemons()...)) + require.NoError(t, waitForIdle(1*clock.Minute, cluster.GetDaemons()...)) broadcastCounters := getPeerCounters(t, cluster.GetDaemons(), "gubernator_broadcast_duration_count") updateCounters := getPeerCounters(t, cluster.GetDaemons(), "gubernator_global_send_duration_count") - upgCounters := getPeerCounters(t, cluster.GetDaemons(), "gubernator_grpc_request_duration_count{method=\"/pb.gubernator.PeersV1/UpdatePeerGlobals\"}") - gprlCounters := getPeerCounters(t, cluster.GetDaemons(), "gubernator_grpc_request_duration_count{method=\"/pb.gubernator.PeersV1/GetPeerRateLimits\"}") + peerUpdateCounters := getPeerCounters(t, cluster.GetDaemons(), "gubernator_http_handler_duration_count{path=\"/v1/peer.update\"}") + peerForwardCounters := getPeerCounters(t, cluster.GetDaemons(), "gubernator_http_handler_duration_count{path=\"/v1/peer.forward\"}") // When for i := int64(0); i < testCase.Hits; i++ { @@ -1795,20 +1829,20 @@ func TestGlobalBehavior(t *testing.T) { // Assert UpdatePeerGlobals endpoint called once on each peer except owner. // Used by global broadcast. - upgCounters2 := getPeerCounters(t, cluster.GetDaemons(), "gubernator_grpc_request_duration_count{method=\"/pb.gubernator.PeersV1/UpdatePeerGlobals\"}") + upgCounters2 := getPeerCounters(t, cluster.GetDaemons(), "gubernator_http_handler_duration_count{path=\"/v1/peer.update\"}") for _, peer := range cluster.GetDaemons() { - expected := upgCounters[peer.InstanceID] + expected := peerUpdateCounters[peer.InstanceID] if peer.PeerInfo.DataCenter == cluster.DataCenterNone && peer.InstanceID != owner.InstanceID { expected++ } assert.Equal(t, expected, upgCounters2[peer.InstanceID]) } - // Assert PeerGetRateLimits endpoint not called. + // Assert PeerCheckRateLimits endpoint not called. // Used by global hits update. 
- gprlCounters2 := getPeerCounters(t, cluster.GetDaemons(), "gubernator_grpc_request_duration_count{method=\"/pb.gubernator.PeersV1/GetPeerRateLimits\"}") + gprlCounters2 := getPeerCounters(t, cluster.GetDaemons(), "gubernator_http_handler_duration_count{path=\"/v1/peer.forward\"}") for _, peer := range cluster.GetDaemons() { - expected := gprlCounters[peer.InstanceID] + expected := peerForwardCounters[peer.InstanceID] assert.Equal(t, expected, gprlCounters2[peer.InstanceID]) } @@ -1846,8 +1880,8 @@ func TestGlobalBehavior(t *testing.T) { broadcastCounters := getPeerCounters(t, cluster.GetDaemons(), "gubernator_broadcast_duration_count") updateCounters := getPeerCounters(t, cluster.GetDaemons(), "gubernator_global_send_duration_count") - upgCounters := getPeerCounters(t, cluster.GetDaemons(), "gubernator_grpc_request_duration_count{method=\"/pb.gubernator.PeersV1/UpdatePeerGlobals\"}") - gprlCounters := getPeerCounters(t, cluster.GetDaemons(), "gubernator_grpc_request_duration_count{method=\"/pb.gubernator.PeersV1/GetPeerRateLimits\"}") + peerUpdateCounters := getPeerCounters(t, cluster.GetDaemons(), "gubernator_http_handler_duration_count{path=\"/v1/peer.update\"}") + peerForwardCounters := getPeerCounters(t, cluster.GetDaemons(), "gubernator_http_handler_duration_count{path=\"/v1/peer.forward\"}") // When for i := int64(0); i < testCase.Hits; i++ { @@ -1914,20 +1948,20 @@ func TestGlobalBehavior(t *testing.T) { // Assert UpdatePeerGlobals endpoint called once on each peer except owner. // Used by global broadcast. - upgCounters2 := getPeerCounters(t, cluster.GetDaemons(), "gubernator_grpc_request_duration_count{method=\"/pb.gubernator.PeersV1/UpdatePeerGlobals\"}") + upgCounters2 := getPeerCounters(t, cluster.GetDaemons(), "gubernator_http_handler_duration_count{path=\"/v1/peer.update\"}") for _, peer := range cluster.GetDaemons() { - expected := upgCounters[peer.InstanceID] + expected := peerUpdateCounters[peer.InstanceID] if peer.PeerInfo.DataCenter == cluster.DataCenterNone && peer.InstanceID != owner.InstanceID { expected++ } assert.Equal(t, expected, upgCounters2[peer.InstanceID], "upgCounter %s", peer.InstanceID) } - // Assert PeerGetRateLimits endpoint called once on owner. + // Assert PeerCheckRateLimits endpoint called once on owner. // Used by global hits update. 
- gprlCounters2 := getPeerCounters(t, cluster.GetDaemons(), "gubernator_grpc_request_duration_count{method=\"/pb.gubernator.PeersV1/GetPeerRateLimits\"}") + gprlCounters2 := getPeerCounters(t, cluster.GetDaemons(), "gubernator_http_handler_duration_count{path=\"/v1/peer.forward\"}") for _, peer := range cluster.GetDaemons() { - expected := gprlCounters[peer.InstanceID] + expected := peerForwardCounters[peer.InstanceID] if peer.InstanceID == owner.InstanceID { expected++ } @@ -1975,8 +2009,8 @@ func TestGlobalBehavior(t *testing.T) { broadcastCounters := getPeerCounters(t, cluster.GetDaemons(), "gubernator_broadcast_duration_count") updateCounters := getPeerCounters(t, cluster.GetDaemons(), "gubernator_global_send_duration_count") - upgCounters := getPeerCounters(t, cluster.GetDaemons(), "gubernator_grpc_request_duration_count{method=\"/pb.gubernator.PeersV1/UpdatePeerGlobals\"}") - gprlCounters := getPeerCounters(t, cluster.GetDaemons(), "gubernator_grpc_request_duration_count{method=\"/pb.gubernator.PeersV1/GetPeerRateLimits\"}") + upgCounters := getPeerCounters(t, cluster.GetDaemons(), "gubernator_http_handler_duration_count{path=\"/v1/peer.update\"}") + //gprlCounters := getPeerCounters(t, cluster.GetDaemons(), "gubernator_http_handler_duration_count{path=\"/v1/peer.forward\"}") expectUpdate := make(map[string]struct{}) var wg sync.WaitGroup var mutex sync.Mutex @@ -2063,7 +2097,7 @@ func TestGlobalBehavior(t *testing.T) { // Assert UpdatePeerGlobals endpoint called at least // once on each peer except owner. // Used by global broadcast. - upgCounters2 := getPeerCounters(t, cluster.GetDaemons(), "gubernator_grpc_request_duration_count{method=\"/pb.gubernator.PeersV1/UpdatePeerGlobals\"}") + upgCounters2 := getPeerCounters(t, cluster.GetDaemons(), "gubernator_http_handler_duration_count{path=\"/v1/peer.update\"}") for _, peer := range cluster.GetDaemons() { expected := upgCounters[peer.InstanceID] if peer.PeerInfo.DataCenter == cluster.DataCenterNone && peer.InstanceID != owner.InstanceID { @@ -2072,17 +2106,19 @@ func TestGlobalBehavior(t *testing.T) { assert.GreaterOrEqual(t, upgCounters2[peer.InstanceID], expected, "upgCounter %s", peer.InstanceID) } - // Assert PeerGetRateLimits endpoint called on owner + // Assert PeerCheckRateLimits endpoint called on owner // for each non-owner that received hits. // Used by global hits update. - gprlCounters2 := getPeerCounters(t, cluster.GetDaemons(), "gubernator_grpc_request_duration_count{method=\"/pb.gubernator.PeersV1/GetPeerRateLimits\"}") - for _, peer := range cluster.GetDaemons() { - expected := gprlCounters[peer.InstanceID] - if peer.InstanceID == owner.InstanceID { - expected += float64(len(expectUpdate)) - } - assert.Equal(t, expected, gprlCounters2[peer.InstanceID], "gprlCounter %s", peer.InstanceID) - } + // TODO(thrawn01): It is more important to verify the counts exist on each peer instead of how they got there. As the method of + // how they got there is an implementation detail and can/will change. Also, this test flaps occasionally. + //gprlCounters2 := getPeerCounters(t, cluster.GetDaemons(), "gubernator_http_handler_duration_count{path=\"/v1/peer.forward\"}") + //for _, peer := range cluster.GetDaemons() { + // expected := gprlCounters[peer.InstanceID] + // if peer.InstanceID == owner.InstanceID { + // expected += float64(len(expectUpdate)) + // } + // assert.Equal(t, expected, gprlCounters2[peer.InstanceID], "gprlCounter %s", peer.InstanceID) + //} // Verify all peers report consistent remaining value value. 
for _, peer := range cluster.GetDaemons() { @@ -2113,42 +2149,46 @@ func TestEventChannel(t *testing.T) { }() // Spawn specialized Gubernator cluster with EventChannel enabled. - cluster.Stop() + cluster.Stop(context.Background()) defer func() { err := startGubernator() require.NoError(t, err) }() peers := []guber.PeerInfo{ - {GRPCAddress: "127.0.0.1:10000", HTTPAddress: "127.0.0.1:10001", DataCenter: cluster.DataCenterNone}, - {GRPCAddress: "127.0.0.1:10002", HTTPAddress: "127.0.0.1:10003", DataCenter: cluster.DataCenterNone}, - {GRPCAddress: "127.0.0.1:10004", HTTPAddress: "127.0.0.1:10005", DataCenter: cluster.DataCenterNone}, + {HTTPAddress: "127.0.0.1:10001", DataCenter: cluster.DataCenterNone}, + {HTTPAddress: "127.0.0.1:10002", DataCenter: cluster.DataCenterNone}, + {HTTPAddress: "127.0.0.1:10003", DataCenter: cluster.DataCenterNone}, } err := cluster.StartWith(peers, cluster.WithEventChannel(eventChannel)) require.NoError(t, err) - defer cluster.Stop() + defer cluster.Stop(context.Background()) + + addr := cluster.GetRandomPeerInfo(cluster.DataCenterNone).HTTPAddress + client, err := guber.NewClient(guber.WithNoTLS(addr)) + require.Nil(t, err) - client, err := guber.DialV1Server(cluster.GetRandomPeer(cluster.DataCenterNone).GRPCAddress, nil) - require.NoError(t, err) sendHit := func(key string, behavior guber.Behavior) { ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) defer cancel() - _, err = client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{ + + var resp guber.CheckRateLimitsResponse + err = client.CheckRateLimits(ctx, &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{ { - Name: "test", - UniqueKey: key, Algorithm: guber.Algorithm_TOKEN_BUCKET, - Behavior: behavior, Duration: guber.Minute * 3, - Hits: 2, + Behavior: behavior, + Name: "test", Limit: 1000, + UniqueKey: key, + Hits: 2, }, }, - }) + }, &resp) require.NoError(t, err) select { case <-sem: - case <-time.After(3 * time.Second): + case <-time.After(3 * clock.Second): t.Fatal("Timeout waiting for EventChannel handler") } } @@ -2315,20 +2355,25 @@ func waitForUpdate(timeout clock.Duration, d *guber.Daemon, expect float64) erro // waitForIdle waits until both global broadcast and global hits queues are // empty. 
func waitForIdle(timeout clock.Duration, daemons ...*guber.Daemon) error { - var wg syncutil.WaitGroup + var wg wait.Group ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() for _, d := range daemons { - wg.Run(func(raw any) error { - d := raw.(*guber.Daemon) + wg.Run(func() error { for { metrics, err := getMetrics(d.Config().HTTPListenAddress, "gubernator_global_queue_length", "gubernator_global_send_queue_length") if err != nil { return err } - ggql := metrics["gubernator_global_queue_length"] - gsql := metrics["gubernator_global_send_queue_length"] + ggql, ok := metrics["gubernator_global_queue_length"] + if !ok { + return errors.New("gubernator_global_queue_length not found") + } + gsql, ok := metrics["gubernator_global_send_queue_length"] + if !ok { + return errors.New("gubernator_global_send_queue_length not found") + } if ggql.Value == 0 && gsql.Value == 0 { return nil @@ -2340,13 +2385,9 @@ func waitForIdle(timeout clock.Duration, daemons ...*guber.Daemon) error { return ctx.Err() } } - }, d) - } - errs := wg.Wait() - if len(errs) > 0 { - return errs[0] + }) } - return nil + return wg.Wait() } func getMetricValue(t *testing.T, d *guber.Daemon, name string) float64 { @@ -2368,16 +2409,20 @@ func getPeerCounters(t *testing.T, peers []*guber.Daemon, name string) map[strin return counters } -func sendHit(t *testing.T, d *guber.Daemon, req *guber.RateLimitReq, expectStatus guber.Status, expectRemaining int64) { +func sendHit(t *testing.T, d *guber.Daemon, req *guber.RateLimitRequest, expectStatus guber.Status, expectRemaining int64) { + t.Helper() + if req.Hits != 0 { t.Logf("Sending %d hits to peer %s", req.Hits, d.InstanceID) } client := d.MustClient() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) defer cancel() - resp, err := client.GetRateLimits(ctx, &guber.GetRateLimitsReq{ - Requests: []*guber.RateLimitReq{req}, - }) + + var resp guber.CheckRateLimitsResponse + err := client.CheckRateLimits(ctx, &guber.CheckRateLimitsRequest{ + Requests: []*guber.RateLimitRequest{req}, + }, &resp) require.NoError(t, err) item := resp.Responses[0] assert.Equal(t, "", item.Error) @@ -2394,18 +2439,18 @@ func epochMillis(t time.Time) int64 { func startGubernator() error { err := cluster.StartWith([]guber.PeerInfo{ - {GRPCAddress: "127.0.0.1:9990", HTTPAddress: "127.0.0.1:9980", DataCenter: cluster.DataCenterNone}, - {GRPCAddress: "127.0.0.1:9991", HTTPAddress: "127.0.0.1:9981", DataCenter: cluster.DataCenterNone}, - {GRPCAddress: "127.0.0.1:9992", HTTPAddress: "127.0.0.1:9982", DataCenter: cluster.DataCenterNone}, - {GRPCAddress: "127.0.0.1:9993", HTTPAddress: "127.0.0.1:9983", DataCenter: cluster.DataCenterNone}, - {GRPCAddress: "127.0.0.1:9994", HTTPAddress: "127.0.0.1:9984", DataCenter: cluster.DataCenterNone}, - {GRPCAddress: "127.0.0.1:9995", HTTPAddress: "127.0.0.1:9985", DataCenter: cluster.DataCenterNone}, + {HTTPAddress: "127.0.0.1:9980", DataCenter: cluster.DataCenterNone}, + {HTTPAddress: "127.0.0.1:9981", DataCenter: cluster.DataCenterNone}, + {HTTPAddress: "127.0.0.1:9982", DataCenter: cluster.DataCenterNone}, + {HTTPAddress: "127.0.0.1:9983", DataCenter: cluster.DataCenterNone}, + {HTTPAddress: "127.0.0.1:9984", DataCenter: cluster.DataCenterNone}, + {HTTPAddress: "127.0.0.1:9985", DataCenter: cluster.DataCenterNone}, // DataCenterOne - {GRPCAddress: "127.0.0.1:9890", HTTPAddress: "127.0.0.1:9880", DataCenter: cluster.DataCenterOne}, - {GRPCAddress: 
"127.0.0.1:9891", HTTPAddress: "127.0.0.1:9881", DataCenter: cluster.DataCenterOne}, - {GRPCAddress: "127.0.0.1:9892", HTTPAddress: "127.0.0.1:9882", DataCenter: cluster.DataCenterOne}, - {GRPCAddress: "127.0.0.1:9893", HTTPAddress: "127.0.0.1:9883", DataCenter: cluster.DataCenterOne}, + {HTTPAddress: "127.0.0.1:9880", DataCenter: cluster.DataCenterOne}, + {HTTPAddress: "127.0.0.1:9881", DataCenter: cluster.DataCenterOne}, + {HTTPAddress: "127.0.0.1:9882", DataCenter: cluster.DataCenterOne}, + {HTTPAddress: "127.0.0.1:9883", DataCenter: cluster.DataCenterOne}, }) if err != nil { return errors.Wrap(err, "while starting cluster") diff --git a/global.go b/global.go index c5fe167..788b1cf 100644 --- a/global.go +++ b/global.go @@ -18,8 +18,9 @@ package gubernator import ( "context" + "log/slog" - "github.com/mailgun/holster/v4/syncutil" + "github.com/kapetan-io/tackle/wait" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" "google.golang.org/protobuf/proto" @@ -28,23 +29,23 @@ import ( // globalManager manages async hit queue and updates peers in // the cluster periodically when a global rate limit we own updates. type globalManager struct { - hitsQueue chan *RateLimitReq - broadcastQueue chan *RateLimitReq - wg syncutil.WaitGroup + hitsQueue chan *RateLimitRequest + broadcastQueue chan *RateLimitRequest + wg wait.Group conf BehaviorConfig log FieldLogger - instance *V1Instance // TODO circular import? V1Instance also holds a reference to globalManager + instance *Service metricGlobalSendDuration prometheus.Summary metricGlobalSendQueueLength prometheus.Gauge metricBroadcastDuration prometheus.Summary metricGlobalQueueLength prometheus.Gauge } -func newGlobalManager(conf BehaviorConfig, instance *V1Instance) *globalManager { +func newGlobalManager(conf BehaviorConfig, instance *Service) *globalManager { gm := globalManager{ log: instance.log, - hitsQueue: make(chan *RateLimitReq, conf.GlobalBatchLimit), - broadcastQueue: make(chan *RateLimitReq, conf.GlobalBatchLimit), + hitsQueue: make(chan *RateLimitRequest, conf.GlobalBatchLimit), + broadcastQueue: make(chan *RateLimitRequest, conf.GlobalBatchLimit), instance: instance, conf: conf, metricGlobalSendDuration: prometheus.NewSummary(prometheus.SummaryOpts{ @@ -54,7 +55,7 @@ func newGlobalManager(conf BehaviorConfig, instance *V1Instance) *globalManager }), metricGlobalSendQueueLength: prometheus.NewGauge(prometheus.GaugeOpts{ Name: "gubernator_global_send_queue_length", - Help: "The count of requests queued up for global broadcast. This is only used for GetRateLimit requests using global behavior.", + Help: "The count of requests queued up for global broadcast. This is only used for CheckRateLimit requests using global behavior.", }), metricBroadcastDuration: prometheus.NewSummary(prometheus.SummaryOpts{ Name: "gubernator_broadcast_duration", @@ -63,7 +64,7 @@ func newGlobalManager(conf BehaviorConfig, instance *V1Instance) *globalManager }), metricGlobalQueueLength: prometheus.NewGauge(prometheus.GaugeOpts{ Name: "gubernator_global_queue_length", - Help: "The count of requests queued up for global broadcast. This is only used for GetRateLimit requests using global behavior.", + Help: "The count of requests queued up for global broadcast. 
This is only used for CheckRateLimit requests using global behavior.", }), } gm.runAsyncHits() @@ -71,13 +72,13 @@ func newGlobalManager(conf BehaviorConfig, instance *V1Instance) *globalManager return &gm } -func (gm *globalManager) QueueHit(r *RateLimitReq) { +func (gm *globalManager) QueueHit(r *RateLimitRequest) { if r.Hits != 0 { gm.hitsQueue <- r } } -func (gm *globalManager) QueueUpdate(req *RateLimitReq) { +func (gm *globalManager) QueueUpdate(req *RateLimitRequest) { if req.Hits != 0 { gm.broadcastQueue <- req } @@ -90,7 +91,7 @@ func (gm *globalManager) QueueUpdate(req *RateLimitReq) { // and in a periodic frequency determined by GlobalSyncWait. func (gm *globalManager) runAsyncHits() { var interval = NewInterval(gm.conf.GlobalSyncWait) - hits := make(map[string]*RateLimitReq) + hits := make(map[string]*RateLimitRequest) gm.wg.Until(func(done chan struct{}) bool { @@ -114,7 +115,7 @@ func (gm *globalManager) runAsyncHits() { // Send the hits if we reached our batch limit if len(hits) == gm.conf.GlobalBatchLimit { gm.sendHits(hits) - hits = make(map[string]*RateLimitReq) + hits = make(map[string]*RateLimitRequest) gm.metricGlobalSendQueueLength.Set(0) return true } @@ -128,7 +129,7 @@ func (gm *globalManager) runAsyncHits() { case <-interval.C: if len(hits) != 0 { gm.sendHits(hits) - hits = make(map[string]*RateLimitReq) + hits = make(map[string]*RateLimitRequest) gm.metricGlobalSendQueueLength.Set(0) } case <-done: @@ -141,10 +142,10 @@ func (gm *globalManager) runAsyncHits() { // sendHits takes the hits collected by runAsyncHits and sends them to their // owning peers -func (gm *globalManager) sendHits(hits map[string]*RateLimitReq) { +func (gm *globalManager) sendHits(hits map[string]*RateLimitRequest) { type pair struct { - client *PeerClient - req GetPeerRateLimitsReq + client *Peer + req ForwardRequest } defer prometheus.NewTimer(gm.metricGlobalSendDuration).ObserveDuration() peerRequests := make(map[string]*pair) @@ -153,37 +154,42 @@ func (gm *globalManager) sendHits(hits map[string]*RateLimitReq) { for _, r := range hits { peer, err := gm.instance.GetPeer(context.Background(), r.HashKey()) if err != nil { - gm.log.WithError(err).Errorf("while getting peer for hash key '%s'", r.HashKey()) + gm.log.LogAttrs(context.TODO(), slog.LevelError, "while getting peer for hash key", + ErrAttr(err), + slog.String("hash_key", r.HashKey()), + ) continue } - p, ok := peerRequests[peer.Info().GRPCAddress] + p, ok := peerRequests[peer.Info().HTTPAddress] if ok { p.req.Requests = append(p.req.Requests, r) } else { - peerRequests[peer.Info().GRPCAddress] = &pair{ + peerRequests[peer.Info().HTTPAddress] = &pair{ client: peer, - req: GetPeerRateLimitsReq{Requests: []*RateLimitReq{r}}, + req: ForwardRequest{Requests: []*RateLimitRequest{r}}, } } } - fan := syncutil.NewFanOut(gm.conf.GlobalPeerRequestsConcurrency) + fan := wait.NewFanOut(gm.conf.GlobalPeerRequestsConcurrency) // Send the rate limit requests to their respective owning peers. for _, p := range peerRequests { - fan.Run(func(in interface{}) error { - p := in.(*pair) + fan.Run(func() error { ctx, cancel := context.WithTimeout(context.Background(), gm.conf.GlobalTimeout) - _, err := p.client.GetPeerRateLimits(ctx, &p.req) + var resp ForwardResponse + err := p.client.ForwardBatch(ctx, &p.req, &resp) cancel() if err != nil { - gm.log.WithError(err). 
- Errorf("while sending global hits to '%s'", p.client.Info().GRPCAddress) + gm.log.LogAttrs(context.TODO(), slog.LevelError, "while sending global hits", + ErrAttr(err), + slog.String("address", p.client.Info().HTTPAddress), + ) } return nil - }, p) + }) } - fan.Wait() + _ = fan.Wait() } // runBroadcasts collects status changes for global rate limits in a forever loop, @@ -192,7 +198,7 @@ func (gm *globalManager) sendHits(hits map[string]*RateLimitReq) { // and in a periodic frequency determined by GlobalSyncWait. func (gm *globalManager) runBroadcasts() { var interval = NewInterval(gm.conf.GlobalSyncWait) - updates := make(map[string]*RateLimitReq) + updates := make(map[string]*RateLimitRequest) gm.wg.Until(func(done chan struct{}) bool { select { @@ -203,7 +209,7 @@ func (gm *globalManager) runBroadcasts() { // Send the hits if we reached our batch limit if len(updates) >= gm.conf.GlobalBatchLimit { gm.broadcastPeers(context.Background(), updates) - updates = make(map[string]*RateLimitReq) + updates = make(map[string]*RateLimitRequest) gm.metricGlobalQueueLength.Set(0) return true } @@ -219,7 +225,7 @@ func (gm *globalManager) runBroadcasts() { break } gm.broadcastPeers(context.Background(), updates) - updates = make(map[string]*RateLimitReq) + updates = make(map[string]*RateLimitRequest) gm.metricGlobalQueueLength.Set(0) case <-done: @@ -231,61 +237,64 @@ func (gm *globalManager) runBroadcasts() { } // broadcastPeers broadcasts global rate limit statuses to all other peers -func (gm *globalManager) broadcastPeers(ctx context.Context, updates map[string]*RateLimitReq) { +func (gm *globalManager) broadcastPeers(ctx context.Context, updates map[string]*RateLimitRequest) { defer prometheus.NewTimer(gm.metricBroadcastDuration).ObserveDuration() - var req UpdatePeerGlobalsReq - reqState := RateLimitReqState{IsOwner: false} + var req UpdateRequest + reqState := RateLimitContext{IsOwner: false} gm.metricGlobalQueueLength.Set(float64(len(updates))) for _, update := range updates { - // Get current rate limit state. 
- grlReq := proto.Clone(update).(*RateLimitReq) + grlReq := proto.Clone(update).(*RateLimitRequest) grlReq.Hits = 0 - status, err := gm.instance.workerPool.GetRateLimit(ctx, grlReq, reqState) + state, err := gm.instance.cache.CheckRateLimit(ctx, grlReq, reqState) if err != nil { - gm.log.WithError(err).Error("while retrieving rate limit status") + gm.log.LogAttrs(context.TODO(), slog.LevelError, "while retrieving rate limit status", + ErrAttr(err), + ) continue } - updateReq := &UpdatePeerGlobal{ + updateReq := &UpdateRateLimit{ Key: update.HashKey(), Algorithm: update.Algorithm, Duration: update.Duration, - Status: status, + State: state, CreatedAt: *update.CreatedAt, } req.Globals = append(req.Globals, updateReq) } - fan := syncutil.NewFanOut(gm.conf.GlobalPeerRequestsConcurrency) + fan := wait.NewFanOut(gm.conf.GlobalPeerRequestsConcurrency) for _, peer := range gm.instance.GetPeerList() { // Exclude ourselves from the update if peer.Info().IsOwner { continue } - fan.Run(func(in interface{}) error { - peer := in.(*PeerClient) + fan.Run(func() error { ctx, cancel := context.WithTimeout(ctx, gm.conf.GlobalTimeout) - _, err := peer.UpdatePeerGlobals(ctx, &req) + err := peer.Update(ctx, &req) cancel() if err != nil { // Only log if it's an unknown error if !errors.Is(err, context.Canceled) && errors.Is(err, context.DeadlineExceeded) { - gm.log.WithError(err).Errorf("while broadcasting global updates to '%s'", peer.Info().GRPCAddress) + gm.log.LogAttrs(context.TODO(), slog.LevelError, "while broadcasting global updates", + ErrAttr(err), + slog.String("address", peer.Info().HTTPAddress), + ) } } return nil - }, peer) + }) } - fan.Wait() + _ = fan.Wait() } // Close stops all goroutines and shuts down all the peers. func (gm *globalManager) Close() { gm.wg.Stop() for _, peer := range gm.instance.GetPeerList() { - _ = peer.Shutdown(context.Background()) + _ = peer.Close(context.Background()) } } diff --git a/go.mod b/go.mod index 482dfc0..c7d46cc 100644 --- a/go.mod +++ b/go.mod @@ -1,35 +1,36 @@ -module github.com/gubernator-io/gubernator/v2 +module github.com/gubernator-io/gubernator/v3 -go 1.21 +go 1.22.9 -toolchain go1.21.9 +toolchain go1.23.1 require ( - github.com/OneOfOne/xxhash v1.2.8 github.com/davecgh/go-spew v1.1.1 - github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.0 + github.com/duh-rpc/duh-go v0.1.0 github.com/hashicorp/memberlist v0.5.0 + github.com/kapetan-io/errors v0.5.0 + github.com/kapetan-io/tackle v0.11.0 github.com/mailgun/errors v0.1.5 github.com/mailgun/holster/v4 v4.19.0 + github.com/maypok86/otter v1.2.1 github.com/miekg/dns v1.1.50 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.13.0 github.com/prometheus/client_model v0.2.0 github.com/prometheus/common v0.37.0 github.com/segmentio/fasthash v1.0.2 - github.com/sirupsen/logrus v1.9.2 github.com/stretchr/testify v1.9.0 go.etcd.io/etcd/client/v3 v3.5.5 - go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.46.1 - go.opentelemetry.io/otel v1.25.0 - go.opentelemetry.io/otel/sdk v1.25.0 - go.opentelemetry.io/otel/trace v1.25.0 + go.opentelemetry.io/otel v1.26.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.25.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.25.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.25.0 + go.opentelemetry.io/otel/sdk v1.26.0 + go.opentelemetry.io/otel/trace v1.26.0 go.uber.org/goleak v1.3.0 golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 golang.org/x/net v0.23.0 - golang.org/x/sync v0.6.0 - 
golang.org/x/time v0.3.0 - google.golang.org/genproto/googleapis/api v0.0.0-20240227224415-6ceb2ff114de + golang.org/x/time v0.8.0 google.golang.org/grpc v1.63.0 google.golang.org/protobuf v1.33.0 k8s.io/api v0.23.3 @@ -45,7 +46,8 @@ require ( github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/coreos/go-semver v0.3.0 // indirect github.com/coreos/go-systemd/v22 v22.3.2 // indirect - github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/dolthub/maphash v0.1.0 // indirect + github.com/gammazero/deque v0.2.1 // indirect github.com/go-logr/logr v1.4.1 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect @@ -54,6 +56,7 @@ require ( github.com/google/go-cmp v0.6.0 // indirect github.com/google/gofuzz v1.1.0 // indirect github.com/googleapis/gnostic v0.5.5 // indirect + github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-immutable-radix v1.3.1 // indirect github.com/hashicorp/go-msgpack v1.1.5 // indirect @@ -68,30 +71,25 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/procfs v0.8.0 // indirect github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529 // indirect + github.com/sirupsen/logrus v1.9.2 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/stretchr/objx v0.5.2 // indirect - github.com/uptrace/opentelemetry-go-extra/otellogrus v0.2.1 // indirect - github.com/uptrace/opentelemetry-go-extra/otelutil v0.2.1 // indirect go.etcd.io/etcd/api/v3 v3.5.5 // indirect go.etcd.io/etcd/client/pkg/v3 v3.5.5 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1 // indirect - go.opentelemetry.io/otel/exporters/jaeger v1.17.0 // indirect - go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.25.0 // indirect - go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.25.0 // indirect - go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.25.0 // indirect - go.opentelemetry.io/otel/metric v1.25.0 // indirect + go.opentelemetry.io/otel/metric v1.26.0 // indirect go.opentelemetry.io/proto/otlp v1.1.0 // indirect go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.8.0 // indirect go.uber.org/zap v1.21.0 // indirect golang.org/x/mod v0.15.0 // indirect golang.org/x/oauth2 v0.17.0 // indirect - golang.org/x/sys v0.18.0 // indirect + golang.org/x/sys v0.19.0 // indirect golang.org/x/term v0.18.0 // indirect golang.org/x/text v0.14.0 // indirect golang.org/x/tools v0.18.0 // indirect google.golang.org/appengine v1.6.8 // indirect google.golang.org/genproto v0.0.0-20240227224415-6ceb2ff114de // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20240227224415-6ceb2ff114de // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240401170217-c3f982113cda // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect diff --git a/go.sum b/go.sum index 89711f6..c941368 100644 --- a/go.sum +++ b/go.sum @@ -18,17 +18,12 @@ cloud.google.com/go v0.74.0/go.mod h1:VV1xSbzvo+9QJOxLDaJfTjx5e+MePCpCWwvftOeQmW cloud.google.com/go v0.78.0/go.mod h1:QjdrLG0uq+YwhjoVOLsS1t7TW8fs36kLs4XO5R5ECHg= cloud.google.com/go v0.79.0/go.mod h1:3bzgcEeQlzbuEAYu4mrWhKqWjmpprinYgKJLgKHnbb8= cloud.google.com/go v0.81.0/go.mod h1:mk/AM35KwGk/Nm2YSeZbxXdrNK3KZOYHmLkOqC2V6E0= -cloud.google.com/go v0.112.0 h1:tpFCD7hpHFlQ8yPwT3x+QeXqc2T6+n6T+hmABHfDUSM= cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o= 
cloud.google.com/go/bigquery v1.3.0/go.mod h1:PjpwJnslEMmckchkHFfq+HTD2DmtT67aNFKH1/VBDHE= cloud.google.com/go/bigquery v1.4.0/go.mod h1:S8dzgnTigyfTmLBfrtrhyYhwRxG72rYxvftPBK2Dvzc= cloud.google.com/go/bigquery v1.5.0/go.mod h1:snEHRnqQbz117VIFhE8bmtwIDY80NLUZUMb4Nv6dBIg= cloud.google.com/go/bigquery v1.7.0/go.mod h1://okPTzCYNXSlb24MZs83e2Do+h+VXtc4gLoIoXIAPc= cloud.google.com/go/bigquery v1.8.0/go.mod h1:J5hqkt3O0uAFnINi6JXValWIb1v0goeZM77hZzJN/fQ= -cloud.google.com/go/compute v1.24.0 h1:phWcR2eWzRJaL/kOiJwfFsPs4BaKq1j6vnpZrc1YlVg= -cloud.google.com/go/compute v1.24.0/go.mod h1:kw1/T+h/+tK2LJK0wiPPx1intgdAM3j/g3hFDlscY40= -cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY= -cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA= cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= cloud.google.com/go/datastore v1.1.0/go.mod h1:umbIZjpQpHh4hmRpGhH4tLFup+FVzqBi1b3c64qFpCk= cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= @@ -52,12 +47,8 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= github.com/NYTimes/gziphandler v0.0.0-20170623195520-56545f4a5d46/go.mod h1:3wb06e3pkSAbeQ52E9H9iFoQsEEwGN64994WTCIhntQ= -github.com/OneOfOne/xxhash v1.2.8 h1:31czK/TI9sNkxIKfaUfGlU47BAxQ0ztGgd9vPyqimf8= -github.com/OneOfOne/xxhash v1.2.8/go.mod h1:eZbhyaAYD41SGSSsnmcpxVoRiQ/MPUTjUdIIOT9Um7Q= github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= -github.com/Shopify/toxiproxy v2.1.4+incompatible h1:TKdv8HiTLgE5wdJuEML90aBgNWsokNbMijUGhmcoBJc= -github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= github.com/ahmetb/go-linq v3.0.0+incompatible h1:qQkjjOXKrKOTy83X8OpRmnKflXKQIL/mC/gMVVDMhOA= github.com/ahmetb/go-linq v3.0.0+incompatible/go.mod h1:PFffvbdbtw+QTB0WKRP0cNht7vnCfnGlEpak/DVg5cY= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= @@ -95,8 +86,6 @@ github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGX github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/xds/go v0.0.0-20210805033703-aa0b78936158/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= -github.com/cncf/xds/go v0.0.0-20231128003011-0fa0005c9caa h1:jQCWAUqqlij9Pgj2i/PB79y4KOPYVyFYdROxgaCwdTQ= -github.com/cncf/xds/go v0.0.0-20231128003011-0fa0005c9caa/go.mod h1:x/1Gn8zydmfq8dk6e9PdstVsDgu9RuyIIJqAaF//0IM= github.com/coreos/go-semver v0.3.0 h1:wkHLiw0WNATZnSG7epLsujiMCgPAc9xhjJ4tgnAxmfM= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.2 h1:D9/bQk5vlXQFZ6Kwuu6zaiXJ9oTPe68++AzAJc1DzSI= @@ -106,6 +95,10 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE= +github.com/dolthub/maphash v0.1.0 h1:bsQ7JsF4FkkWyrP3oCnFJgrCUAFbFf3kOl4L/QxPDyQ= +github.com/dolthub/maphash v0.1.0/go.mod h1:gkg4Ch4CdCDu5h6PMriVLawB7koZ+5ijb9puGMV50a4= +github.com/duh-rpc/duh-go v0.1.0 h1:Ym7XvNhl1CD6dgy+YWiPfhkOQGNzFsBsIc5uvYdF08c= +github.com/duh-rpc/duh-go v0.1.0/go.mod h1:OoCoGsZkeED84v8TAE86m2NM5ZfNLNlqUUm7tYO+h+k= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/elazarl/goproxy v0.0.0-20180725130230-947c36da3153/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc= github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs= @@ -117,17 +110,15 @@ github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.m github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/go-control-plane v0.9.10-0.20210907150352-cf90f659a021/go.mod h1:AFq3mo9L8Lqqiid3OhADV3RfLJnjiw63cSpi+fDTRC0= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/envoyproxy/protoc-gen-validate v1.0.4 h1:gVPz/FMfvh57HdSJQyvBtF00j8JU4zdyUgIUNhlgg0A= -github.com/envoyproxy/protoc-gen-validate v1.0.4/go.mod h1:qys6tmnRsYrQqIhm2bvKZH4Blx/1gTIZ2UKVY1M+Yew= github.com/evanphx/json-patch v4.12.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= -github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= -github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= github.com/form3tech-oss/jwt-go v3.2.3+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/gammazero/deque v0.2.1 h1:qSdsbG6pgp6nL7A0+K/B7s12mcCY/5l5SIUpMOl+dC0= +github.com/gammazero/deque v0.2.1/go.mod h1:LFroj8x4cMYCukHJDbxFCkT+r9AndaJnFMuZDV34tuU= github.com/getkin/kin-openapi v0.76.0/go.mod h1:660oXbgy5JFMKreazJaQTw7o+X00qeSyhcnluiMv+Xg= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= @@ -234,7 +225,6 @@ github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5m github.com/googleapis/gnostic v0.5.1/go.mod h1:6U4PtQXGIEt/Z3h5MAT7FNofLnw9vXk2cUuW7uA/OeU= github.com/googleapis/gnostic v0.5.5 h1:9fHAtK0uDfpveeqqo1hkEZJcFvYXAiCN3UutL8F9xHw= github.com/googleapis/gnostic v0.5.5/go.mod h1:7+EbHbldMins07ALC74bsA81Ovc97DwqyJO1AENw9kA= -github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gregjones/httpcache 
v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= @@ -283,6 +273,10 @@ github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1 github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= +github.com/kapetan-io/errors v0.5.0 h1:VVzhQ8WIefPLMNkv94ifiKe6V0qimr53PC+Q5T3FciU= +github.com/kapetan-io/errors v0.5.0/go.mod h1:Rc59bpJaA+YHiAyTY702+SmiL+iRQgpZXjR8S6mzyOg= +github.com/kapetan-io/tackle v0.11.0 h1:xcQ2WgES8rjsd0ZMBfFTMuCs8YG4+1r2OAPY0+mHXjM= +github.com/kapetan-io/tackle v0.11.0/go.mod h1:94m0H3j8pm9JMsAuqBsC/Y08WpAUh01ugkFxABjjHd8= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= @@ -306,6 +300,8 @@ github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaO github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/maypok86/otter v1.2.1 h1:xyvMW+t0vE1sKt/++GTkznLitEl7D/msqXkAbLwiC1M= +github.com/maypok86/otter v1.2.1/go.mod h1:mKLfoI7v1HOmQMwFgX4QkRk23mX6ge3RDvjdHOWG4R4= github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso= github.com/miekg/dns v1.1.50 h1:DQUfb9uc6smULcREF09Uc+/Gd46YWqJd5DbpPE9xkcA= github.com/miekg/dns v1.1.50/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME= @@ -407,10 +403,6 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM= -github.com/uptrace/opentelemetry-go-extra/otellogrus v0.2.1 h1:klz16Hi1ydcI/AkFCbgxXvqwfKNPb+/EVHMlu1PdEKo= -github.com/uptrace/opentelemetry-go-extra/otellogrus v0.2.1/go.mod h1:CufwvpLoGqj/uJFKsxBy09MKEM/o9QqaWjkB4RnFVdI= -github.com/uptrace/opentelemetry-go-extra/otelutil v0.2.1 h1:qjljyY//UH064+gQDHh5U7M1Jh6b+iQpJUWVAuRJ04A= -github.com/uptrace/opentelemetry-go-extra/otelutil v0.2.1/go.mod h1:7YSrHCmYPHIXjTWnKSU7EGT0TFEcm3WwSeQquwCGg38= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -430,26 +422,20 @@ go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= -go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.46.1 h1:SpGay3w+nEwMpfVnbqOLH5gY52/foP8RE8UzTZ1pdSE= 
-go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.46.1/go.mod h1:4UoMYEZOC0yN/sPGH76KPkkU7zgiEWYWL9vwmbnTJPE= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1 h1:aFJWCqJMNjENlcleuuOkGAPH82y0yULBScfXcIEdS24= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1/go.mod h1:sEGXWArGqc3tVa+ekntsN65DmVbVeW+7lTKTjZF3/Fo= -go.opentelemetry.io/otel v1.25.0 h1:gldB5FfhRl7OJQbUHt/8s0a7cE8fbsPAtdpRaApKy4k= -go.opentelemetry.io/otel v1.25.0/go.mod h1:Wa2ds5NOXEMkCmUou1WA7ZBfLTHWIsp034OVD7AO+Vg= -go.opentelemetry.io/otel/exporters/jaeger v1.17.0 h1:D7UpUy2Xc2wsi1Ras6V40q806WM07rqoCWzXu7Sqy+4= -go.opentelemetry.io/otel/exporters/jaeger v1.17.0/go.mod h1:nPCqOnEH9rNLKqH/+rrUjiMzHJdV1BlpKcTwRTyKkKI= +go.opentelemetry.io/otel v1.26.0 h1:LQwgL5s/1W7YiiRwxf03QGnWLb2HW4pLiAhaA5cZXBs= +go.opentelemetry.io/otel v1.26.0/go.mod h1:UmLkJHUAidDval2EICqBMbnAd0/m2vmpf/dAM+fvFs4= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.25.0 h1:dT33yIHtmsqpixFsSQPwNeY5drM9wTcoL8h0FWF4oGM= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.25.0/go.mod h1:h95q0LBGh7hlAC08X2DhSeyIG02YQ0UyioTCVAqRPmc= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.25.0 h1:vOL89uRfOCCNIjkisd0r7SEdJF3ZJFyCNY34fdZs8eU= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.25.0/go.mod h1:8GlBGcDk8KKi7n+2S4BT/CPZQYH3erLu0/k64r1MYgo= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.25.0 h1:Mbi5PKN7u322woPa85d7ebZ+SOvEoPvoiBu+ryHWgfA= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.25.0/go.mod h1:e7ciERRhZaOZXVjx5MiL8TK5+Xv7G5Gv5PA2ZDEJdL8= -go.opentelemetry.io/otel/metric v1.25.0 h1:LUKbS7ArpFL/I2jJHdJcqMGxkRdxpPHE0VU/D4NuEwA= -go.opentelemetry.io/otel/metric v1.25.0/go.mod h1:rkDLUSd2lC5lq2dFNrX9LGAbINP5B7WBkC78RXCpH5s= -go.opentelemetry.io/otel/sdk v1.25.0 h1:PDryEJPC8YJZQSyLY5eqLeafHtG+X7FWnf3aXMtxbqo= -go.opentelemetry.io/otel/sdk v1.25.0/go.mod h1:oFgzCM2zdsxKzz6zwpTZYLLQsFwc+K0daArPdIhuxkw= -go.opentelemetry.io/otel/trace v1.25.0 h1:tqukZGLwQYRIFtSQM2u2+yfMVTgGVeqRLPUYx1Dq6RM= -go.opentelemetry.io/otel/trace v1.25.0/go.mod h1:hCCs70XM/ljO+BeQkyFnbK28SBIJ/Emuha+ccrCRT7I= +go.opentelemetry.io/otel/metric v1.26.0 h1:7S39CLuY5Jgg9CrnA9HHiEjGMF/X2VHvoXGgSllRz30= +go.opentelemetry.io/otel/metric v1.26.0/go.mod h1:SY+rHOI4cEawI9a7N1A4nIg/nTQXe1ccCNWYOJUrpX4= +go.opentelemetry.io/otel/sdk v1.26.0 h1:Y7bumHf5tAiDlRYFmGqetNcLaVUZmh4iYfmGxtmz7F8= +go.opentelemetry.io/otel/sdk v1.26.0/go.mod h1:0p8MXpqLeJ0pzcszQQN4F0S5FVjBLgypeGSngLsmirs= +go.opentelemetry.io/otel/trace v1.26.0 h1:1ieeAUb4y0TE26jUFrCIXKpTuVK7uJGN9/Z/2LP5sQA= +go.opentelemetry.io/otel/trace v1.26.0/go.mod h1:4iDxvGDQuUkHve82hJJ8UqrwswHYsZuWCBllGV2U2y0= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= go.opentelemetry.io/proto/otlp v1.1.0 h1:2Di21piLrCqJ3U3eXGCTPHE9R8Nh+0uglSnOyxikMeI= go.opentelemetry.io/proto/otlp v1.1.0/go.mod h1:GpBHCBWiqvVLDqmHZsoMM3C5ySeKTC7ej/RNTae6MdY= @@ -657,8 +643,8 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= -golang.org/x/sys v0.18.0/go.mod 
h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= +golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -680,8 +666,8 @@ golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxb golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= -golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg= +golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= diff --git a/grpc_stats.go b/grpc_stats.go deleted file mode 100644 index 39cc662..0000000 --- a/grpc_stats.go +++ /dev/null @@ -1,145 +0,0 @@ -/* -Copyright 2018-2022 Mailgun Technologies Inc - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package gubernator - -import ( - "context" - - "github.com/mailgun/holster/v4/clock" - "github.com/mailgun/holster/v4/syncutil" - "github.com/prometheus/client_golang/prometheus" - "google.golang.org/grpc/stats" -) - -type GRPCStats struct { - Duration clock.Duration - Method string - Failed float64 - Success float64 -} - -type contextKey struct{} - -var statsContextKey = contextKey{} - -// Implements the Prometheus collector interface. 
Such that when the /metrics handler is -// called this collector pulls all the stats from -type GRPCStatsHandler struct { - reqCh chan *GRPCStats - wg syncutil.WaitGroup - - grpcRequestCount *prometheus.CounterVec - grpcRequestDuration *prometheus.SummaryVec -} - -func NewGRPCStatsHandler() *GRPCStatsHandler { - c := &GRPCStatsHandler{ - grpcRequestCount: prometheus.NewCounterVec(prometheus.CounterOpts{ - Name: "gubernator_grpc_request_counts", - Help: "The count of gRPC requests.", - }, []string{"status", "method"}), - grpcRequestDuration: prometheus.NewSummaryVec(prometheus.SummaryOpts{ - Name: "gubernator_grpc_request_duration", - Help: "The timings of gRPC requests in seconds", - Objectives: map[float64]float64{ - 0.5: 0.05, - 0.99: 0.001, - }, - }, []string{"method"}), - } - c.run() - return c -} - -func (c *GRPCStatsHandler) run() { - c.reqCh = make(chan *GRPCStats, 10000) - - c.wg.Until(func(done chan struct{}) bool { - select { - case stat := <-c.reqCh: - c.grpcRequestCount.With(prometheus.Labels{"status": "failed", "method": stat.Method}).Add(stat.Failed) - c.grpcRequestCount.With(prometheus.Labels{"status": "success", "method": stat.Method}).Add(stat.Success) - c.grpcRequestDuration.With(prometheus.Labels{"method": stat.Method}).Observe(stat.Duration.Seconds()) - case <-done: - return false - } - return true - }) -} - -func (c *GRPCStatsHandler) Describe(ch chan<- *prometheus.Desc) { - c.grpcRequestCount.Describe(ch) - c.grpcRequestDuration.Describe(ch) -} - -func (c *GRPCStatsHandler) Collect(ch chan<- prometheus.Metric) { - c.grpcRequestCount.Collect(ch) - c.grpcRequestDuration.Collect(ch) -} - -func (c *GRPCStatsHandler) Close() { - c.wg.Stop() -} - -func (c *GRPCStatsHandler) HandleRPC(ctx context.Context, s stats.RPCStats) { - rs := StatsFromContext(ctx) - if rs == nil { - return - } - - switch t := s.(type) { - // case *stats.Begin: - // case *stats.InPayload: - // case *stats.InHeader: - // case *stats.InTrailer: - // case *stats.OutPayload: - // case *stats.OutHeader: - // case *stats.OutTrailer: - case *stats.End: - rs.Duration = t.EndTime.Sub(t.BeginTime) - if t.Error != nil { - rs.Failed = 1 - } else { - rs.Success = 1 - } - c.reqCh <- rs - } -} - -func (c *GRPCStatsHandler) HandleConn(ctx context.Context, s stats.ConnStats) {} - -func (c *GRPCStatsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context { - return ctx -} - -func (c *GRPCStatsHandler) TagRPC(ctx context.Context, tagInfo *stats.RPCTagInfo) context.Context { - return ContextWithStats(ctx, &GRPCStats{Method: tagInfo.FullMethodName}) -} - -// Returns a new `context.Context` that holds a reference to `GRPCStats`. -func ContextWithStats(ctx context.Context, stats *GRPCStats) context.Context { - return context.WithValue(ctx, statsContextKey, stats) -} - -// Returns the `GRPCStats` previously associated with `ctx`. 
-func StatsFromContext(ctx context.Context) *GRPCStats { - val := ctx.Value(statsContextKey) - if rs, ok := val.(*GRPCStats); ok { - return rs - } - return nil -} diff --git a/gubernator.go b/gubernator.go index b103b26..44384fe 100644 --- a/gubernator.go +++ b/gubernator.go @@ -18,21 +18,23 @@ package gubernator import ( "context" + "errors" "fmt" + "log/slog" "strings" "sync" "time" - "github.com/mailgun/errors" - "github.com/mailgun/holster/v4/clock" - "github.com/mailgun/holster/v4/syncutil" - "github.com/mailgun/holster/v4/tracing" + "github.com/kapetan-io/tackle/wait" + + "github.com/duh-rpc/duh-go" + v1 "github.com/duh-rpc/duh-go/proto/v1" + "github.com/gubernator-io/gubernator/v3/tracing" + "github.com/kapetan-io/tackle/clock" "github.com/prometheus/client_golang/prometheus" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/propagation" "go.opentelemetry.io/otel/trace" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" ) @@ -42,25 +44,25 @@ const ( UnHealthy = "unhealthy" ) -type V1Instance struct { - UnimplementedV1Server - UnimplementedPeersV1Server +type Service struct { + propagator propagation.TraceContext global *globalManager peerMutex sync.RWMutex + cache CacheManager log FieldLogger conf Config isClosed bool - workerPool *WorkerPool } -type RateLimitReqState struct { +// RateLimitContext is context that is not included in the RateLimitRequest but is needed by algorithms.go +type RateLimitContext struct { IsOwner bool } var ( metricGetRateLimitCounter = prometheus.NewCounterVec(prometheus.CounterOpts{ Name: "gubernator_getratelimit_counter", - Help: "The count of getLocalRateLimit() calls. Label \"calltype\" may be \"local\" for calls handled by the same peer, or \"global\" for global rate limits.", + Help: "The count of checkLocalRateLimit() calls. Label \"calltype\" may be \"local\" for calls handled by the same peer, or \"global\" for global rate limits.", }, []string{"calltype"}) metricFuncTimeDuration = prometheus.NewSummaryVec(prometheus.SummaryOpts{ Name: "gubernator_func_duration", @@ -83,14 +85,6 @@ var ( Name: "gubernator_check_error_counter", Help: "The number of errors while checking rate limits.", }, []string{"error"}) - metricCommandCounter = prometheus.NewCounterVec(prometheus.CounterOpts{ - Name: "gubernator_command_counter", - Help: "The count of commands processed by each worker in WorkerPool.", - }, []string{"worker", "method"}) - metricWorkerQueue = prometheus.NewGaugeVec(prometheus.GaugeOpts{ - Name: "gubernator_worker_queue_length", - Help: "The count of requests queued up in WorkerPool.", - }, []string{"method", "worker"}) // Batch behavior. metricBatchSendRetries = prometheus.NewCounterVec(prometheus.CounterOpts{ @@ -110,47 +104,39 @@ var ( }, []string{"peerAddr"}) ) -// NewV1Instance instantiate a single instance of a gubernator peer and register this -// instance with the provided GRPCServer. 
-func NewV1Instance(conf Config) (s *V1Instance, err error) { +// NewService instantiate a single instance of a gubernator service +func NewService(conf Config) (s *Service, err error) { ctx := context.Background() - if conf.GRPCServers == nil { - return nil, errors.New("at least one GRPCServer instance is required") - } + if err := conf.SetDefaults(); err != nil { return nil, err } - s = &V1Instance{ + s = &Service{ log: conf.Logger, conf: conf, } - s.workerPool = NewWorkerPool(&conf) - s.global = newGlobalManager(conf.Behaviors, s) - - // Register our instance with all GRPC servers - for _, srv := range conf.GRPCServers { - RegisterV1Server(srv, s) - RegisterPeersV1Server(srv, s) + s.cache, err = NewCacheManager(conf) + if err != nil { + return nil, fmt.Errorf("during NewCacheManager(): %w", err) } + s.global = newGlobalManager(conf.Behaviors, s) if s.conf.Loader == nil { return s, nil } // Load the cache. - err = s.workerPool.Load(ctx) + err = s.cache.Load(ctx) if err != nil { - return nil, errors.Wrap(err, "Error in workerPool.Load") + return nil, fmt.Errorf("error in CacheManager.Load: %w", err) } return s, nil } -func (s *V1Instance) Close() (err error) { - ctx := context.Background() - +func (s *Service) Close(ctx context.Context) (err error) { if s.isClosed { return nil } @@ -158,114 +144,124 @@ func (s *V1Instance) Close() (err error) { s.global.Close() if s.conf.Loader != nil { - err = s.workerPool.Store(ctx) + err = s.cache.Store(ctx) if err != nil { - s.log.WithError(err). - Error("Error in workerPool.Store") - return errors.Wrap(err, "Error in workerPool.Store") + s.log.LogAttrs(context.TODO(), slog.LevelError, "Error in workerPool.Store", + ErrAttr(err), + ) + return fmt.Errorf("error in CacheManager.Store: %w", err) } } - err = s.workerPool.Close() + err = s.cache.Close() if err != nil { - s.log.WithError(err). - Error("Error in workerPool.Close") - return errors.Wrap(err, "Error in workerPool.Close") + s.log.LogAttrs(context.TODO(), slog.LevelError, "Error in workerPool.Close", + ErrAttr(err), + ) + return fmt.Errorf("error in CacheManager.Close: %w", err) } + // Close all the peer clients + s.SetPeers([]PeerInfo{}) + s.isClosed = true return nil } -// GetRateLimits is the public interface used by clients to request rate limits from the system. If the +// CheckRateLimits is the public interface used by clients to request rate limits from the system. If the // rate limit `Name` and `UniqueKey` is not owned by this instance, then we forward the request to the // peer that does. 
-func (s *V1Instance) GetRateLimits(ctx context.Context, r *GetRateLimitsReq) (*GetRateLimitsResp, error) { - funcTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.GetRateLimits")) +func (s *Service) CheckRateLimits(ctx context.Context, req *CheckRateLimitsRequest, resp *CheckRateLimitsResponse) (err error) { + funcTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("Service.CheckRateLimits")) defer funcTimer.ObserveDuration() metricConcurrentChecks.Inc() defer metricConcurrentChecks.Dec() - if len(r.Requests) > maxBatchSize { - metricCheckErrorCounter.WithLabelValues("Request too large").Inc() - return nil, status.Errorf(codes.OutOfRange, - "Requests.RateLimits list too large; max size is '%d'", maxBatchSize) + if len(req.Requests) > maxBatchSize { + metricCheckErrorCounter.WithLabelValues("Request too large").Add(1) + return duh.NewServiceError(duh.CodeBadRequest, + fmt.Sprintf("CheckRateLimitsRequest.RateLimits list too large; max size is '%d'", maxBatchSize), nil, nil) } - createdAt := epochMillis(clock.Now()) - resp := GetRateLimitsResp{ - Responses: make([]*RateLimitResp, len(r.Requests)), + if len(req.Requests) == 0 { + return duh.NewServiceError(duh.CodeBadRequest, + "CheckRateLimitsRequest.RateLimits list is empty; provide at least one rate limit", nil, nil) } + + resp.Responses = make([]*RateLimitResponse, len(req.Requests)) + asyncCh := make(chan AsyncResp, len(req.Requests)) + createdAt := epochMillis(clock.Now()) var wg sync.WaitGroup - asyncCh := make(chan AsyncResp, len(r.Requests)) // For each item in the request body - for i, req := range r.Requests { - key := req.Name + "_" + req.UniqueKey - var peer *PeerClient + for i, r := range req.Requests { + key := r.Name + "_" + r.UniqueKey + var peer *Peer var err error - if req.UniqueKey == "" { + if r.UniqueKey == "" { metricCheckErrorCounter.WithLabelValues("Invalid request").Inc() - resp.Responses[i] = &RateLimitResp{Error: "field 'unique_key' cannot be empty"} + resp.Responses[i] = &RateLimitResponse{Error: "field 'unique_key' cannot be empty"} continue } - if req.Name == "" { + + if r.Name == "" { metricCheckErrorCounter.WithLabelValues("Invalid request").Inc() - resp.Responses[i] = &RateLimitResp{Error: "field 'namespace' cannot be empty"} + resp.Responses[i] = &RateLimitResponse{Error: "field 'namespace' cannot be empty"} continue } - if req.CreatedAt == nil || *req.CreatedAt == 0 { - req.CreatedAt = &createdAt + + if r.CreatedAt == nil || *r.CreatedAt == 0 { + r.CreatedAt = &createdAt } if ctx.Err() != nil { - err = errors.Wrap(ctx.Err(), "Error while iterating request items") + err = fmt.Errorf("error while iterating request items: %w", ctx.Err()) span := trace.SpanFromContext(ctx) span.RecordError(err) - resp.Responses[i] = &RateLimitResp{ + resp.Responses[i] = &RateLimitResponse{ Error: err.Error(), } continue } if s.conf.Behaviors.ForceGlobal { - SetBehavior(&req.Behavior, Behavior_GLOBAL, true) + SetBehavior(&r.Behavior, Behavior_GLOBAL, true) } peer, err = s.GetPeer(ctx, key) if err != nil { countError(err, "Error in GetPeer") - err = errors.Wrapf(err, "Error in GetPeer, looking up peer that owns rate limit '%s'", key) - resp.Responses[i] = &RateLimitResp{ + err = fmt.Errorf("error in GetPeer, looking up peer that owns rate limit '%s': %w", key, err) + resp.Responses[i] = &RateLimitResponse{ Error: err.Error(), } continue } // If our server instance is the owner of this rate limit - reqState := RateLimitReqState{IsOwner: peer.Info().IsOwner} + reqState := 
RateLimitContext{IsOwner: peer.Info().IsOwner} if reqState.IsOwner { // Apply our rate limit algorithm to the request - resp.Responses[i], err = s.getLocalRateLimit(ctx, req, reqState) + resp.Responses[i], err = s.checkLocalRateLimit(ctx, r, reqState) if err != nil { - err = errors.Wrapf(err, "Error while apply rate limit for '%s'", key) + err = fmt.Errorf("error while apply rate limit for '%s': %w", key, err) span := trace.SpanFromContext(ctx) span.RecordError(err) - resp.Responses[i] = &RateLimitResp{Error: err.Error()} + resp.Responses[i] = &RateLimitResponse{Error: err.Error()} } } else { - if HasBehavior(req.Behavior, Behavior_GLOBAL) { - resp.Responses[i], err = s.getGlobalRateLimit(ctx, req) + if HasBehavior(r.Behavior, Behavior_GLOBAL) { + resp.Responses[i], err = s.checkGlobalRateLimit(ctx, r) if err != nil { - err = errors.Wrap(err, "Error in getGlobalRateLimit") + err = fmt.Errorf("error in checkGlobalRateLimit: %w", err) span := trace.SpanFromContext(ctx) span.RecordError(err) - resp.Responses[i] = &RateLimitResp{Error: err.Error()} + resp.Responses[i] = &RateLimitResponse{Error: err.Error()} } // Inform the client of the owner key of the key - resp.Responses[i].Metadata = map[string]string{"owner": peer.Info().GRPCAddress} + resp.Responses[i].Metadata = map[string]string{"owner": peer.Info().HTTPAddress} continue } @@ -275,7 +271,7 @@ func (s *V1Instance) GetRateLimits(ctx context.Context, r *GetRateLimitsReq) (*G go s.asyncRequest(ctx, &AsyncReq{ AsyncCh: asyncCh, Peer: peer, - Req: req, + Req: r, WG: &wg, Key: key, Idx: i, @@ -291,95 +287,98 @@ func (s *V1Instance) GetRateLimits(ctx context.Context, r *GetRateLimitsReq) (*G resp.Responses[a.Idx] = a.Resp } - return &resp, nil + return nil } type AsyncResp struct { + Resp *RateLimitResponse Idx int - Resp *RateLimitResp } type AsyncReq struct { + Req *RateLimitRequest WG *sync.WaitGroup AsyncCh chan AsyncResp - Req *RateLimitReq - Peer *PeerClient + Peer *Peer Key string Idx int } -func (s *V1Instance) asyncRequest(ctx context.Context, req *AsyncReq) { +func (s *Service) asyncRequest(ctx context.Context, req *AsyncReq) { + ctx = tracing.StartScope(ctx, "Service.asyncRequest") + defer tracing.EndScope(ctx, nil) var attempts int var err error - ctx = tracing.StartNamedScope(ctx, "V1Instance.asyncRequest") - defer tracing.EndScope(ctx, nil) - - funcTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.asyncRequest")) + funcTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("Service.asyncRequest")) defer funcTimer.ObserveDuration() - reqState := RateLimitReqState{IsOwner: req.Peer.Info().IsOwner} + reqState := RateLimitContext{IsOwner: req.Peer.Info().IsOwner} resp := AsyncResp{ Idx: req.Idx, } for { if attempts > 5 { - s.log.WithContext(ctx). - WithError(err). - WithField("key", req.Key). 
- Error("GetPeer() returned peer that is not connected") - countError(err, "Peer not connected") - err = fmt.Errorf("GetPeer() keeps returning peers that are not connected for '%s': %w", req.Key, err) - resp.Resp = &RateLimitResp{Error: err.Error()} + err = fmt.Errorf("attempts exhausted while communicating with '%s' for '%s': %w", + req.Peer.Info().HTTPAddress, req.Key, err) + s.log.LogAttrs(ctx, slog.LevelError, "attempts exhausted while communicating with peer", + ErrAttr(err), + slog.String("key", req.Key), + ) + countError(err, "peer communication failed") + resp.Resp = &RateLimitResponse{Error: err.Error()} break } // If we are attempting again, the owner of this rate limit might have changed to us! if attempts != 0 { if reqState.IsOwner { - resp.Resp, err = s.getLocalRateLimit(ctx, req.Req, reqState) + resp.Resp, err = s.checkLocalRateLimit(ctx, req.Req, reqState) if err != nil { - s.log.WithContext(ctx). - WithError(err). - WithField("key", req.Key). - Error("Error applying rate limit") - err = fmt.Errorf("during getLocalRateLimit() for '%s': %w", req.Key, err) - resp.Resp = &RateLimitResp{Error: err.Error()} + err = fmt.Errorf("during checkLocalRateLimit() for '%s': %w", req.Key, err) + s.log.LogAttrs(ctx, slog.LevelError, "while applying rate limit", + ErrAttr(err), + slog.String("key", req.Key), + ) + resp.Resp = &RateLimitResponse{Error: err.Error()} } break } } // Make an RPC call to the peer that owns this rate limit - var r *RateLimitResp - r, err = req.Peer.GetPeerRateLimit(ctx, req.Req) + var r *RateLimitResponse + r, err = req.Peer.Forward(ctx, req.Req) if err != nil { - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || + errors.Is(err, ErrPeerShutdown) { + attempts++ metricBatchSendRetries.WithLabelValues(req.Req.Name).Inc() req.Peer, err = s.GetPeer(ctx, req.Key) if err != nil { - errPart := fmt.Sprintf("while finding peer that owns rate limit '%s'", req.Key) - s.log.WithContext(ctx).WithError(err).WithField("key", req.Key).Error(errPart) + err = fmt.Errorf("while finding peer that owns rate limit '%s': %w", req.Key, err) + s.log.LogAttrs(ctx, slog.LevelError, err.Error(), + ErrAttr(err), + slog.String("key", req.Key), + ) countError(err, "during GetPeer()") - err = fmt.Errorf("%s: %w", errPart, err) - resp.Resp = &RateLimitResp{Error: err.Error()} + resp.Resp = &RateLimitResponse{Error: err.Error()} break } continue } - // Not calling `countError()` because we expect the remote end to - // report this error. + // Not calling `countError()` because we expect the remote end to report this error. err = fmt.Errorf("while fetching rate limit '%s' from peer: %w", req.Key, err) - resp.Resp = &RateLimitResp{Error: err.Error()} + resp.Resp = &RateLimitResponse{Error: err.Error()} break } // Inform the client of the owner key of the key resp.Resp = r - resp.Resp.Metadata = map[string]string{"owner": req.Peer.Info().GRPCAddress} + resp.Resp.Metadata = map[string]string{"owner": req.Peer.Info().HTTPAddress} break } @@ -391,14 +390,14 @@ func (s *V1Instance) asyncRequest(ctx context.Context, req *AsyncReq) { } } -// getGlobalRateLimit handles rate limits that are marked as `Behavior = GLOBAL`. Rate limit responses +// checkGlobalRateLimit handles rate limits that are marked as `Behavior = GLOBAL`. Rate limit responses // are returned from the local cache and the hits are queued to be sent to the owning peer. 
-func (s *V1Instance) getGlobalRateLimit(ctx context.Context, req *RateLimitReq) (resp *RateLimitResp, err error) { - ctx = tracing.StartNamedScope(ctx, "V1Instance.getGlobalRateLimit", trace.WithAttributes( +func (s *Service) checkGlobalRateLimit(ctx context.Context, req *RateLimitRequest) (resp *RateLimitResponse, err error) { + ctx = tracing.StartScope(ctx, "Service.checkGlobalRateLimit", trace.WithAttributes( attribute.String("ratelimit.key", req.UniqueKey), attribute.String("ratelimit.name", req.Name), )) - defer prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.getGlobalRateLimit")).ObserveDuration() + defer prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("Service.checkGlobalRateLimit")).ObserveDuration() defer func() { if err == nil { s.global.QueueHit(req) @@ -406,85 +405,90 @@ func (s *V1Instance) getGlobalRateLimit(ctx context.Context, req *RateLimitReq) tracing.EndScope(ctx, err) }() - req2 := proto.Clone(req).(*RateLimitReq) + req2 := proto.Clone(req).(*RateLimitRequest) SetBehavior(&req2.Behavior, Behavior_NO_BATCHING, true) SetBehavior(&req2.Behavior, Behavior_GLOBAL, false) - reqState := RateLimitReqState{IsOwner: false} + reqState := RateLimitContext{IsOwner: false} // Process the rate limit like we own it - resp, err = s.getLocalRateLimit(ctx, req2, reqState) + resp, err = s.checkLocalRateLimit(ctx, req2, reqState) if err != nil { - return nil, errors.Wrap(err, "during in getLocalRateLimit") + return nil, fmt.Errorf("during in checkLocalRateLimit: %w", err) } metricGetRateLimitCounter.WithLabelValues("global").Inc() return resp, nil } -// UpdatePeerGlobals updates the local cache with a list of global rate limits. This method should only -// be called by a peer who is the owner of a global rate limit. -func (s *V1Instance) UpdatePeerGlobals(ctx context.Context, r *UpdatePeerGlobalsReq) (*UpdatePeerGlobalsResp, error) { - defer prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.UpdatePeerGlobals")).ObserveDuration() +// Update updates the local cache with a list of rate limit state from a peer +// This method should only be called by a peer. 
+func (s *Service) Update(ctx context.Context, r *UpdateRequest, _ *v1.Reply) (err error) { + ctx = tracing.StartScope(ctx, "Service.Update") + defer func() { tracing.EndScope(ctx, err) }() + now := MillisecondNow() + for _, g := range r.Globals { - item := &CacheItem{ - ExpireAt: g.Status.ResetTime, - Algorithm: g.Algorithm, - Key: g.Key, + item, _, err := s.cache.GetCacheItem(ctx, g.Key) + if err != nil { + return err } + + if item == nil { + item = &CacheItem{ + ExpireAt: g.State.ResetTime, + Algorithm: g.Algorithm, + Key: g.Key, + } + err := s.cache.AddCacheItem(ctx, g.Key, item) + if err != nil { + return fmt.Errorf("during CacheManager.AddCacheItem(): %w", err) + } + } + + item.mutex.Lock() switch g.Algorithm { case Algorithm_LEAKY_BUCKET: item.Value = &LeakyBucketItem{ - Remaining: float64(g.Status.Remaining), - Limit: g.Status.Limit, + Remaining: float64(g.State.Remaining), + Limit: g.State.Limit, Duration: g.Duration, - Burst: g.Status.Limit, + Burst: g.State.Limit, UpdatedAt: now, } case Algorithm_TOKEN_BUCKET: item.Value = &TokenBucketItem{ - Status: g.Status.Status, - Limit: g.Status.Limit, + Status: g.State.Status, + Limit: g.State.Limit, Duration: g.Duration, - Remaining: g.Status.Remaining, + Remaining: g.State.Remaining, CreatedAt: now, } } - err := s.workerPool.AddCacheItem(ctx, g.Key, item) - if err != nil { - return nil, errors.Wrap(err, "Error in workerPool.AddCacheItem") - } + item.mutex.Unlock() } - - return &UpdatePeerGlobalsResp{}, nil + return nil } -// GetPeerRateLimits is called by other peers to get the rate limits owned by this peer. -func (s *V1Instance) GetPeerRateLimits(ctx context.Context, r *GetPeerRateLimitsReq) (resp *GetPeerRateLimitsResp, err error) { - defer prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.GetPeerRateLimits")).ObserveDuration() - if len(r.Requests) > maxBatchSize { - err := fmt.Errorf("'PeerRequest.rate_limits' list too large; max size is '%d'", maxBatchSize) - metricCheckErrorCounter.WithLabelValues("Request too large").Inc() - return nil, status.Error(codes.OutOfRange, err.Error()) +// Forward is called by other peers when forwarding rate limits to this peer +func (s *Service) Forward(ctx context.Context, req *ForwardRequest, resp *ForwardResponse) (err error) { + if len(req.Requests) > maxBatchSize { + metricCheckErrorCounter.WithLabelValues("Request too large").Add(1) + return duh.NewServiceError(duh.CodeBadRequest, + fmt.Sprintf("'Forward.requests' list too large; max size is '%d'", maxBatchSize), nil, nil) } // Invoke each rate limit request. - type reqIn struct { - idx int - req *RateLimitReq - } type respOut struct { idx int - rl *RateLimitResp + rl *RateLimitResponse } - resp = &GetPeerRateLimitsResp{ - RateLimits: make([]*RateLimitResp, len(r.Requests)), - } + resp.RateLimits = make([]*RateLimitResponse, len(req.Requests)) + reqState := RateLimitContext{IsOwner: true} respChan := make(chan respOut) var respWg sync.WaitGroup respWg.Add(1) - reqState := RateLimitReqState{IsOwner: true} go func() { // Capture each response and return in the same order @@ -496,39 +500,35 @@ func (s *V1Instance) GetPeerRateLimits(ctx context.Context, r *GetPeerRateLimits }() // Fan out requests. 
- fan := syncutil.NewFanOut(s.conf.Workers) - for idx, req := range r.Requests { - fan.Run(func(in interface{}) error { - rin := in.(reqIn) + fan := wait.NewFanOut(s.conf.Workers) + for idx, req := range req.Requests { + fan.Run(func() error { // Extract the propagated context from the metadata in the request - prop := propagation.TraceContext{} - ctx := prop.Extract(ctx, &MetadataCarrier{Map: rin.req.Metadata}) + ctx := s.propagator.Extract(ctx, &MetadataCarrier{Map: req.Metadata}) // Forwarded global requests must have DRAIN_OVER_LIMIT set so token and leaky algorithms // drain the remaining in the event a peer asks for more than is remaining. // This is needed because with GLOBAL behavior peers will accumulate hits, which could // result in requesting more hits than is remaining. - if HasBehavior(rin.req.Behavior, Behavior_GLOBAL) { - SetBehavior(&rin.req.Behavior, Behavior_DRAIN_OVER_LIMIT, true) + if HasBehavior(req.Behavior, Behavior_GLOBAL) { + SetBehavior(&req.Behavior, Behavior_DRAIN_OVER_LIMIT, true) } // Assign default to CreatedAt for backwards compatibility. - if rin.req.CreatedAt == nil || *rin.req.CreatedAt == 0 { + if req.CreatedAt == nil || *req.CreatedAt == 0 { createdAt := epochMillis(clock.Now()) - rin.req.CreatedAt = &createdAt + req.CreatedAt = &createdAt } - rl, err := s.getLocalRateLimit(ctx, rin.req, reqState) + rl, err := s.checkLocalRateLimit(ctx, req, reqState) if err != nil { - // Return the error for this request - err = errors.Wrap(err, "Error in getLocalRateLimit") - rl = &RateLimitResp{Error: err.Error()} - // metricCheckErrorCounter is updated within getLocalRateLimit(), not in GetPeerRateLimits. + rl = &RateLimitResponse{Error: fmt.Errorf("error in checkLocalRateLimit: %w", err).Error()} + // metricCheckErrorCounter is updated within checkLocalRateLimit(), not in Forward(). } - respChan <- respOut{rin.idx, rl} + respChan <- respOut{idx, rl} return nil - }, reqIn{idx, req}) + }) } // Wait for all requests to be handled, then clean up. @@ -536,12 +536,11 @@ func (s *V1Instance) GetPeerRateLimits(ctx context.Context, r *GetPeerRateLimits close(respChan) respWg.Wait() - return resp, nil + return nil } // HealthCheck Returns the health of our instance. 
-func (s *V1Instance) HealthCheck(ctx context.Context, r *HealthCheckReq) (health *HealthCheckResp, err error) { - span := trace.SpanFromContext(ctx) +func (s *Service) HealthCheck(ctx context.Context, _ *HealthCheckRequest, resp *HealthCheckResponse) (err error) { var errs []string @@ -553,7 +552,6 @@ func (s *V1Instance) HealthCheck(ctx context.Context, r *HealthCheckReq) (health for _, peer := range localPeers { for _, errMsg := range peer.GetLastErr() { err := fmt.Errorf("error returned from local peer.GetLastErr: %s", errMsg) - span.RecordError(err) errs = append(errs, err.Error()) } } @@ -563,42 +561,34 @@ func (s *V1Instance) HealthCheck(ctx context.Context, r *HealthCheckReq) (health for _, peer := range regionPeers { for _, errMsg := range peer.GetLastErr() { err := fmt.Errorf("error returned from region peer.GetLastErr: %s", errMsg) - span.RecordError(err) errs = append(errs, err.Error()) } } - health = &HealthCheckResp{ - PeerCount: int32(len(localPeers) + len(regionPeers)), - Status: Healthy, - } + resp.PeerCount = int32(len(localPeers) + len(regionPeers)) + resp.Status = Healthy if len(errs) != 0 { - health.Status = UnHealthy - health.Message = strings.Join(errs, "|") + resp.Status = UnHealthy + resp.Message = strings.Join(errs, "|") } - span.SetAttributes( - attribute.Int64("health.peerCount", int64(health.PeerCount)), - attribute.String("health.status", health.Status), - ) - - return health, nil + return nil } -func (s *V1Instance) getLocalRateLimit(ctx context.Context, r *RateLimitReq, reqState RateLimitReqState) (_ *RateLimitResp, err error) { - ctx = tracing.StartNamedScope(ctx, "V1Instance.getLocalRateLimit", trace.WithAttributes( +func (s *Service) checkLocalRateLimit(ctx context.Context, r *RateLimitRequest, reqState RateLimitContext) (_ *RateLimitResponse, err error) { + ctx = tracing.StartScope(ctx, "Service.checkLocalRateLimit", trace.WithAttributes( attribute.String("ratelimit.key", r.UniqueKey), attribute.String("ratelimit.name", r.Name), attribute.Int64("ratelimit.limit", r.Limit), attribute.Int64("ratelimit.hits", r.Hits), )) defer func() { tracing.EndScope(ctx, err) }() - defer prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.getLocalRateLimit")).ObserveDuration() + defer prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("Service.checkLocalRateLimit")).ObserveDuration() - resp, err := s.workerPool.GetRateLimit(ctx, r, reqState) + resp, err := s.cache.CheckRateLimit(ctx, r, reqState) if err != nil { - return nil, errors.Wrap(err, "during workerPool.GetRateLimit") + return nil, fmt.Errorf("during CacheManager.CheckRateLimit: %w", err) } // If global behavior, then broadcast update to all peers. @@ -625,8 +615,7 @@ func (s *V1Instance) getLocalRateLimit(ctx context.Context, r *RateLimitReq, req } // SetPeers replaces the peers and shuts down all the previous peers. 
-// TODO this should return an error if we failed to connect to any of the new peers -func (s *V1Instance) SetPeers(peerInfo []PeerInfo) { +func (s *Service) SetPeers(peerInfo []PeerInfo) { localPicker := s.conf.LocalPicker.New() regionPicker := s.conf.RegionPicker.New() @@ -634,37 +623,41 @@ func (s *V1Instance) SetPeers(peerInfo []PeerInfo) { // Add peers that are not in our local DC to the RegionPicker if info.DataCenter != s.conf.DataCenter { peer := s.conf.RegionPicker.GetByPeerInfo(info) - // If we don't have an existing PeerClient create a new one + // If we don't have an existing Peer create a new one if peer == nil { var err error - peer, err = NewPeerClient(PeerConfig{ - TraceGRPC: s.conf.PeerTraceGRPC, - Behavior: s.conf.Behaviors, - TLS: s.conf.PeerTLS, - Log: s.log, - Info: info, + peer, err = NewPeer(PeerConfig{ + PeerClient: s.conf.PeerClientFactory(info), + Behavior: s.conf.Behaviors, + Log: s.log, + Info: info, }) if err != nil { - s.log.Errorf("error connecting to peer %s: %s", info.GRPCAddress, err) + s.log.LogAttrs(context.TODO(), slog.LevelError, "during NewPeer() call", + ErrAttr(err), + slog.String("http_address", info.HTTPAddress), + ) return } } regionPicker.Add(peer) continue } - // If we don't have an existing PeerClient create a new one + // If we don't have an existing Peer create a new one peer := s.conf.LocalPicker.GetByPeerInfo(info) if peer == nil { var err error - peer, err = NewPeerClient(PeerConfig{ - TraceGRPC: s.conf.PeerTraceGRPC, - Behavior: s.conf.Behaviors, - TLS: s.conf.PeerTLS, - Log: s.log, - Info: info, + peer, err = NewPeer(PeerConfig{ + PeerClient: s.conf.PeerClientFactory(info), + Behavior: s.conf.Behaviors, + Log: s.log, + Info: info, }) if err != nil { - s.log.Errorf("error connecting to peer %s: %s", info.GRPCAddress, err) + s.log.LogAttrs(context.TODO(), slog.LevelError, "during NewPeer() call", + ErrAttr(err), + slog.String("http_address", info.HTTPAddress), + ) return } } @@ -680,13 +673,15 @@ func (s *V1Instance) SetPeers(peerInfo []PeerInfo) { s.conf.RegionPicker = regionPicker s.peerMutex.Unlock() - s.log.WithField("peers", peerInfo).Debug("peers updated") + s.log.LogAttrs(context.TODO(), slog.LevelDebug, "peers updated", + slog.Any("peers", peerInfo), + ) // Shutdown any old peers we no longer need ctx, cancel := context.WithTimeout(context.Background(), s.conf.Behaviors.BatchTimeout) defer cancel() - var shutdownPeers []*PeerClient + var shutdownPeers []*Peer for _, peer := range oldLocalPicker.Peers() { if peerInfo := s.conf.LocalPicker.GetByPeerInfo(peer.Info()); peerInfo == nil { shutdownPeers = append(shutdownPeers, peer) @@ -701,66 +696,67 @@ func (s *V1Instance) SetPeers(peerInfo []PeerInfo) { } } - var wg syncutil.WaitGroup + var wg wait.Group for _, p := range shutdownPeers { - wg.Run(func(obj interface{}) error { - pc := obj.(*PeerClient) - err := pc.Shutdown(ctx) + wg.Run(func() error { + err := p.Close(ctx) if err != nil { - s.log.WithError(err).WithField("peer", pc).Error("while shutting down peer") + s.log.LogAttrs(context.TODO(), slog.LevelError, "while shutting down peer", + ErrAttr(err), + slog.Any("peer", p), + ) } return nil - }, p) + }) } - wg.Wait() + _ = wg.Wait() if len(shutdownPeers) > 0 { var peers []string for _, p := range shutdownPeers { - peers = append(peers, p.Info().GRPCAddress) + peers = append(peers, p.Info().HTTPAddress) } - s.log.WithField("peers", peers).Debug("peers shutdown") + s.log.LogAttrs(context.TODO(), slog.LevelDebug, "peers shutdown", + slog.Any("peers", peers), + ) } } // GetPeer 
returns a peer client for the hash key provided -func (s *V1Instance) GetPeer(ctx context.Context, key string) (p *PeerClient, err error) { - defer prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("V1Instance.GetPeer")).ObserveDuration() - +func (s *Service) GetPeer(_ context.Context, key string) (p *Peer, err error) { s.peerMutex.RLock() defer s.peerMutex.RUnlock() + p, err = s.conf.LocalPicker.Get(key) if err != nil { - return nil, errors.Wrap(err, "Error in conf.LocalPicker.Get") + return nil, fmt.Errorf("error in conf.LocalPicker.Get: %w", err) } return p, nil } -func (s *V1Instance) GetPeerList() []*PeerClient { +func (s *Service) GetPeerList() []*Peer { s.peerMutex.RLock() defer s.peerMutex.RUnlock() return s.conf.LocalPicker.Peers() } -func (s *V1Instance) GetRegionPickers() map[string]PeerPicker { +func (s *Service) GetRegionPickers() map[string]PeerPicker { s.peerMutex.RLock() defer s.peerMutex.RUnlock() return s.conf.RegionPicker.Pickers() } // Describe fetches prometheus metrics to be registered -func (s *V1Instance) Describe(ch chan<- *prometheus.Desc) { +func (s *Service) Describe(ch chan<- *prometheus.Desc) { metricBatchQueueLength.Describe(ch) metricBatchSendDuration.Describe(ch) metricBatchSendRetries.Describe(ch) metricCheckErrorCounter.Describe(ch) - metricCommandCounter.Describe(ch) metricConcurrentChecks.Describe(ch) metricFuncTimeDuration.Describe(ch) metricGetRateLimitCounter.Describe(ch) metricOverLimitCounter.Describe(ch) - metricWorkerQueue.Describe(ch) s.global.metricBroadcastDuration.Describe(ch) s.global.metricGlobalQueueLength.Describe(ch) s.global.metricGlobalSendDuration.Describe(ch) @@ -768,17 +764,15 @@ func (s *V1Instance) Describe(ch chan<- *prometheus.Desc) { } // Collect fetches metrics from the server for use by prometheus -func (s *V1Instance) Collect(ch chan<- prometheus.Metric) { +func (s *Service) Collect(ch chan<- prometheus.Metric) { metricBatchQueueLength.Collect(ch) metricBatchSendDuration.Collect(ch) metricBatchSendRetries.Collect(ch) metricCheckErrorCounter.Collect(ch) - metricCommandCounter.Collect(ch) metricConcurrentChecks.Collect(ch) metricFuncTimeDuration.Collect(ch) metricGetRateLimitCounter.Collect(ch) metricOverLimitCounter.Collect(ch) - metricWorkerQueue.Collect(ch) s.global.metricBroadcastDuration.Collect(ch) s.global.metricGlobalQueueLength.Collect(ch) s.global.metricGlobalSendDuration.Collect(ch) diff --git a/gubernator.pb.go b/gubernator.pb.go index 305cc2a..6aa9642 100644 --- a/gubernator.pb.go +++ b/gubernator.pb.go @@ -22,11 +22,11 @@ package gubernator import ( - _ "google.golang.org/genproto/googleapis/api/annotations" - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" sync "sync" + + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" ) const ( @@ -248,17 +248,17 @@ func (Status) EnumDescriptor() ([]byte, []int) { return file_gubernator_proto_rawDescGZIP(), []int{2} } -// Must specify at least one Request -type GetRateLimitsReq struct { +// Must specify at least one RateLimitRequest +type CheckRateLimitsRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Requests []*RateLimitReq `protobuf:"bytes,1,rep,name=requests,proto3" json:"requests,omitempty"` + Requests []*RateLimitRequest `protobuf:"bytes,1,rep,name=requests,proto3" json:"requests,omitempty"` } -func (x *GetRateLimitsReq) Reset() { - 
*x = GetRateLimitsReq{} +func (x *CheckRateLimitsRequest) Reset() { + *x = CheckRateLimitsRequest{} if protoimpl.UnsafeEnabled { mi := &file_gubernator_proto_msgTypes[0] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -266,13 +266,13 @@ func (x *GetRateLimitsReq) Reset() { } } -func (x *GetRateLimitsReq) String() string { +func (x *CheckRateLimitsRequest) String() string { return protoimpl.X.MessageStringOf(x) } -func (*GetRateLimitsReq) ProtoMessage() {} +func (*CheckRateLimitsRequest) ProtoMessage() {} -func (x *GetRateLimitsReq) ProtoReflect() protoreflect.Message { +func (x *CheckRateLimitsRequest) ProtoReflect() protoreflect.Message { mi := &file_gubernator_proto_msgTypes[0] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -284,29 +284,29 @@ func (x *GetRateLimitsReq) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use GetRateLimitsReq.ProtoReflect.Descriptor instead. -func (*GetRateLimitsReq) Descriptor() ([]byte, []int) { +// Deprecated: Use CheckRateLimitsRequest.ProtoReflect.Descriptor instead. +func (*CheckRateLimitsRequest) Descriptor() ([]byte, []int) { return file_gubernator_proto_rawDescGZIP(), []int{0} } -func (x *GetRateLimitsReq) GetRequests() []*RateLimitReq { +func (x *CheckRateLimitsRequest) GetRequests() []*RateLimitRequest { if x != nil { return x.Requests } return nil } -// RateLimits returned are in the same order as the Requests -type GetRateLimitsResp struct { +// RateLimits returned are in the same order provided in CheckRateLimitsRequest +type CheckRateLimitsResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Responses []*RateLimitResp `protobuf:"bytes,1,rep,name=responses,proto3" json:"responses,omitempty"` + Responses []*RateLimitResponse `protobuf:"bytes,1,rep,name=responses,proto3" json:"responses,omitempty"` } -func (x *GetRateLimitsResp) Reset() { - *x = GetRateLimitsResp{} +func (x *CheckRateLimitsResponse) Reset() { + *x = CheckRateLimitsResponse{} if protoimpl.UnsafeEnabled { mi := &file_gubernator_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -314,13 +314,13 @@ func (x *GetRateLimitsResp) Reset() { } } -func (x *GetRateLimitsResp) String() string { +func (x *CheckRateLimitsResponse) String() string { return protoimpl.X.MessageStringOf(x) } -func (*GetRateLimitsResp) ProtoMessage() {} +func (*CheckRateLimitsResponse) ProtoMessage() {} -func (x *GetRateLimitsResp) ProtoReflect() protoreflect.Message { +func (x *CheckRateLimitsResponse) ProtoReflect() protoreflect.Message { mi := &file_gubernator_proto_msgTypes[1] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -332,19 +332,19 @@ func (x *GetRateLimitsResp) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use GetRateLimitsResp.ProtoReflect.Descriptor instead. -func (*GetRateLimitsResp) Descriptor() ([]byte, []int) { +// Deprecated: Use CheckRateLimitsResponse.ProtoReflect.Descriptor instead. 
+func (*CheckRateLimitsResponse) Descriptor() ([]byte, []int) { return file_gubernator_proto_rawDescGZIP(), []int{1} } -func (x *GetRateLimitsResp) GetResponses() []*RateLimitResp { +func (x *CheckRateLimitsResponse) GetResponses() []*RateLimitResponse { if x != nil { return x.Responses } return nil } -type RateLimitReq struct { +type RateLimitRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields @@ -352,7 +352,7 @@ type RateLimitReq struct { // The name of the rate limit IE: 'requests_per_second', 'gets_per_minute` Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` // Uniquely identifies this rate limit IE: 'ip:10.2.10.7' or 'account:123445' - UniqueKey string `protobuf:"bytes,2,opt,name=unique_key,json=uniqueKey,proto3" json:"unique_key,omitempty"` + UniqueKey string `protobuf:"bytes,2,opt,name=unique_key,proto3" json:"unique_key,omitempty"` // Rate limit requests optionally specify the number of hits a request adds to the matched limit. If Hit // is zero, the request returns the current limit, but does not increment the hit count. Hits int64 `protobuf:"varint,3,opt,name=hits,proto3" json:"hits,omitempty"` @@ -365,9 +365,9 @@ type RateLimitReq struct { Duration int64 `protobuf:"varint,5,opt,name=duration,proto3" json:"duration,omitempty"` // The algorithm used to calculate the rate limit. The algorithm may change on // subsequent requests, when this occurs any previous rate limit hit counts are reset. - Algorithm Algorithm `protobuf:"varint,6,opt,name=algorithm,proto3,enum=pb.gubernator.Algorithm" json:"algorithm,omitempty"` + Algorithm Algorithm `protobuf:"varint,6,opt,name=algorithm,proto3,enum=gubernator.v3.Algorithm" json:"algorithm,omitempty"` // Behavior is a set of int32 flags that control the behavior of the rate limit in gubernator - Behavior Behavior `protobuf:"varint,7,opt,name=behavior,proto3,enum=pb.gubernator.Behavior" json:"behavior,omitempty"` + Behavior Behavior `protobuf:"varint,7,opt,name=behavior,proto3,enum=gubernator.v3.Behavior" json:"behavior,omitempty"` // Maximum burst size that the limit can accept. Burst int64 `protobuf:"varint,8,opt,name=burst,proto3" json:"burst,omitempty"` // This is metadata that is associated with this rate limit. Peer to Peer communication will use @@ -387,8 +387,8 @@ type RateLimitReq struct { CreatedAt *int64 `protobuf:"varint,10,opt,name=created_at,json=createdAt,proto3,oneof" json:"created_at,omitempty"` } -func (x *RateLimitReq) Reset() { - *x = RateLimitReq{} +func (x *RateLimitRequest) Reset() { + *x = RateLimitRequest{} if protoimpl.UnsafeEnabled { mi := &file_gubernator_proto_msgTypes[2] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -396,13 +396,13 @@ func (x *RateLimitReq) Reset() { } } -func (x *RateLimitReq) String() string { +func (x *RateLimitRequest) String() string { return protoimpl.X.MessageStringOf(x) } -func (*RateLimitReq) ProtoMessage() {} +func (*RateLimitRequest) ProtoMessage() {} -func (x *RateLimitReq) ProtoReflect() protoreflect.Message { +func (x *RateLimitRequest) ProtoReflect() protoreflect.Message { mi := &file_gubernator_proto_msgTypes[2] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -414,102 +414,102 @@ func (x *RateLimitReq) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use RateLimitReq.ProtoReflect.Descriptor instead. 
-func (*RateLimitReq) Descriptor() ([]byte, []int) { +// Deprecated: Use RateLimitRequest.ProtoReflect.Descriptor instead. +func (*RateLimitRequest) Descriptor() ([]byte, []int) { return file_gubernator_proto_rawDescGZIP(), []int{2} } -func (x *RateLimitReq) GetName() string { +func (x *RateLimitRequest) GetName() string { if x != nil { return x.Name } return "" } -func (x *RateLimitReq) GetUniqueKey() string { +func (x *RateLimitRequest) GetUniqueKey() string { if x != nil { return x.UniqueKey } return "" } -func (x *RateLimitReq) GetHits() int64 { +func (x *RateLimitRequest) GetHits() int64 { if x != nil { return x.Hits } return 0 } -func (x *RateLimitReq) GetLimit() int64 { +func (x *RateLimitRequest) GetLimit() int64 { if x != nil { return x.Limit } return 0 } -func (x *RateLimitReq) GetDuration() int64 { +func (x *RateLimitRequest) GetDuration() int64 { if x != nil { return x.Duration } return 0 } -func (x *RateLimitReq) GetAlgorithm() Algorithm { +func (x *RateLimitRequest) GetAlgorithm() Algorithm { if x != nil { return x.Algorithm } return Algorithm_TOKEN_BUCKET } -func (x *RateLimitReq) GetBehavior() Behavior { +func (x *RateLimitRequest) GetBehavior() Behavior { if x != nil { return x.Behavior } return Behavior_BATCHING } -func (x *RateLimitReq) GetBurst() int64 { +func (x *RateLimitRequest) GetBurst() int64 { if x != nil { return x.Burst } return 0 } -func (x *RateLimitReq) GetMetadata() map[string]string { +func (x *RateLimitRequest) GetMetadata() map[string]string { if x != nil { return x.Metadata } return nil } -func (x *RateLimitReq) GetCreatedAt() int64 { +func (x *RateLimitRequest) GetCreatedAt() int64 { if x != nil && x.CreatedAt != nil { return *x.CreatedAt } return 0 } -type RateLimitResp struct { +type RateLimitResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields // The status of the rate limit. - Status Status `protobuf:"varint,1,opt,name=status,proto3,enum=pb.gubernator.Status" json:"status,omitempty"` - // The currently configured request limit (Identical to [[RateLimitReq.limit]]). + Status Status `protobuf:"varint,1,opt,name=status,proto3,enum=gubernator.v3.Status" json:"status,omitempty"` + // The currently configured request limit (Identical to [[RateLimitRequest.limit]]). Limit int64 `protobuf:"varint,2,opt,name=limit,proto3" json:"limit,omitempty"` // This is the number of requests remaining before the rate limit is hit but after subtracting the hits from the current request Remaining int64 `protobuf:"varint,3,opt,name=remaining,proto3" json:"remaining,omitempty"` // This is the time when the rate limit span will be reset, provided as a unix timestamp in milliseconds. - ResetTime int64 `protobuf:"varint,4,opt,name=reset_time,json=resetTime,proto3" json:"reset_time,omitempty"` + ResetTime int64 `protobuf:"varint,4,opt,name=reset_time,proto3" json:"reset_time,omitempty"` // Contains the error; If set all other values should be ignored Error string `protobuf:"bytes,5,opt,name=error,proto3" json:"error,omitempty"` // This is additional metadata that a client might find useful. (IE: Additional headers, coordinator ownership, etc..) 
Metadata map[string]string `protobuf:"bytes,6,rep,name=metadata,proto3" json:"metadata,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` } -func (x *RateLimitResp) Reset() { - *x = RateLimitResp{} +func (x *RateLimitResponse) Reset() { + *x = RateLimitResponse{} if protoimpl.UnsafeEnabled { mi := &file_gubernator_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -517,13 +517,13 @@ func (x *RateLimitResp) Reset() { } } -func (x *RateLimitResp) String() string { +func (x *RateLimitResponse) String() string { return protoimpl.X.MessageStringOf(x) } -func (*RateLimitResp) ProtoMessage() {} +func (*RateLimitResponse) ProtoMessage() {} -func (x *RateLimitResp) ProtoReflect() protoreflect.Message { +func (x *RateLimitResponse) ProtoReflect() protoreflect.Message { mi := &file_gubernator_proto_msgTypes[3] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -535,61 +535,61 @@ func (x *RateLimitResp) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use RateLimitResp.ProtoReflect.Descriptor instead. -func (*RateLimitResp) Descriptor() ([]byte, []int) { +// Deprecated: Use RateLimitResponse.ProtoReflect.Descriptor instead. +func (*RateLimitResponse) Descriptor() ([]byte, []int) { return file_gubernator_proto_rawDescGZIP(), []int{3} } -func (x *RateLimitResp) GetStatus() Status { +func (x *RateLimitResponse) GetStatus() Status { if x != nil { return x.Status } return Status_UNDER_LIMIT } -func (x *RateLimitResp) GetLimit() int64 { +func (x *RateLimitResponse) GetLimit() int64 { if x != nil { return x.Limit } return 0 } -func (x *RateLimitResp) GetRemaining() int64 { +func (x *RateLimitResponse) GetRemaining() int64 { if x != nil { return x.Remaining } return 0 } -func (x *RateLimitResp) GetResetTime() int64 { +func (x *RateLimitResponse) GetResetTime() int64 { if x != nil { return x.ResetTime } return 0 } -func (x *RateLimitResp) GetError() string { +func (x *RateLimitResponse) GetError() string { if x != nil { return x.Error } return "" } -func (x *RateLimitResp) GetMetadata() map[string]string { +func (x *RateLimitResponse) GetMetadata() map[string]string { if x != nil { return x.Metadata } return nil } -type HealthCheckReq struct { +type HealthCheckRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields } -func (x *HealthCheckReq) Reset() { - *x = HealthCheckReq{} +func (x *HealthCheckRequest) Reset() { + *x = HealthCheckRequest{} if protoimpl.UnsafeEnabled { mi := &file_gubernator_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -597,13 +597,13 @@ func (x *HealthCheckReq) Reset() { } } -func (x *HealthCheckReq) String() string { +func (x *HealthCheckRequest) String() string { return protoimpl.X.MessageStringOf(x) } -func (*HealthCheckReq) ProtoMessage() {} +func (*HealthCheckRequest) ProtoMessage() {} -func (x *HealthCheckReq) ProtoReflect() protoreflect.Message { +func (x *HealthCheckRequest) ProtoReflect() protoreflect.Message { mi := &file_gubernator_proto_msgTypes[4] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -615,12 +615,12 @@ func (x *HealthCheckReq) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use HealthCheckReq.ProtoReflect.Descriptor instead. -func (*HealthCheckReq) Descriptor() ([]byte, []int) { +// Deprecated: Use HealthCheckRequest.ProtoReflect.Descriptor instead. 
+func (*HealthCheckRequest) Descriptor() ([]byte, []int) { return file_gubernator_proto_rawDescGZIP(), []int{4} } -type HealthCheckResp struct { +type HealthCheckResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields @@ -630,11 +630,11 @@ type HealthCheckResp struct { // If 'unhealthy', message indicates the problem Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` // The number of peers we know about - PeerCount int32 `protobuf:"varint,3,opt,name=peer_count,json=peerCount,proto3" json:"peer_count,omitempty"` + PeerCount int32 `protobuf:"varint,3,opt,name=peer_count,proto3" json:"peer_count,omitempty"` } -func (x *HealthCheckResp) Reset() { - *x = HealthCheckResp{} +func (x *HealthCheckResponse) Reset() { + *x = HealthCheckResponse{} if protoimpl.UnsafeEnabled { mi := &file_gubernator_proto_msgTypes[5] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -642,13 +642,13 @@ func (x *HealthCheckResp) Reset() { } } -func (x *HealthCheckResp) String() string { +func (x *HealthCheckResponse) String() string { return protoimpl.X.MessageStringOf(x) } -func (*HealthCheckResp) ProtoMessage() {} +func (*HealthCheckResponse) ProtoMessage() {} -func (x *HealthCheckResp) ProtoReflect() protoreflect.Message { +func (x *HealthCheckResponse) ProtoReflect() protoreflect.Message { mi := &file_gubernator_proto_msgTypes[5] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -660,26 +660,26 @@ func (x *HealthCheckResp) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use HealthCheckResp.ProtoReflect.Descriptor instead. -func (*HealthCheckResp) Descriptor() ([]byte, []int) { +// Deprecated: Use HealthCheckResponse.ProtoReflect.Descriptor instead. 
+func (*HealthCheckResponse) Descriptor() ([]byte, []int) { return file_gubernator_proto_rawDescGZIP(), []int{5} } -func (x *HealthCheckResp) GetStatus() string { +func (x *HealthCheckResponse) GetStatus() string { if x != nil { return x.Status } return "" } -func (x *HealthCheckResp) GetMessage() string { +func (x *HealthCheckResponse) GetMessage() string { if x != nil { return x.Message } return "" } -func (x *HealthCheckResp) GetPeerCount() int32 { +func (x *HealthCheckResponse) GetPeerCount() int32 { if x != nil { return x.PeerCount } @@ -690,106 +690,93 @@ var File_gubernator_proto protoreflect.FileDescriptor var file_gubernator_proto_rawDesc = []byte{ 0x0a, 0x10, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x12, 0x0d, 0x70, 0x62, 0x2e, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, - 0x72, 0x1a, 0x1c, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x61, 0x6e, - 0x6e, 0x6f, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, - 0x4b, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x52, 0x61, 0x74, 0x65, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x73, - 0x52, 0x65, 0x71, 0x12, 0x37, 0x0a, 0x08, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x73, 0x18, - 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x70, 0x62, 0x2e, 0x67, 0x75, 0x62, 0x65, 0x72, - 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x52, 0x61, 0x74, 0x65, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x52, - 0x65, 0x71, 0x52, 0x08, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x73, 0x22, 0x4f, 0x0a, 0x11, - 0x47, 0x65, 0x74, 0x52, 0x61, 0x74, 0x65, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x73, 0x52, 0x65, 0x73, - 0x70, 0x12, 0x3a, 0x0a, 0x09, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x73, 0x18, 0x01, - 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x70, 0x62, 0x2e, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, - 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x52, 0x61, 0x74, 0x65, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x52, 0x65, - 0x73, 0x70, 0x52, 0x09, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x73, 0x22, 0xc1, 0x03, - 0x0a, 0x0c, 0x52, 0x61, 0x74, 0x65, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x52, 0x65, 0x71, 0x12, 0x12, - 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, - 0x6d, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x75, 0x6e, 0x69, 0x71, 0x75, 0x65, 0x5f, 0x6b, 0x65, 0x79, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x75, 0x6e, 0x69, 0x71, 0x75, 0x65, 0x4b, 0x65, - 0x79, 0x12, 0x12, 0x0a, 0x04, 0x68, 0x69, 0x74, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, - 0x04, 0x68, 0x69, 0x74, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x03, 0x52, 0x05, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x64, - 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x08, 0x64, - 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x36, 0x0a, 0x09, 0x61, 0x6c, 0x67, 0x6f, 0x72, - 0x69, 0x74, 0x68, 0x6d, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x70, 0x62, 0x2e, - 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x41, 0x6c, 0x67, 0x6f, 0x72, - 0x69, 0x74, 0x68, 0x6d, 0x52, 0x09, 0x61, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x12, - 0x33, 0x0a, 0x08, 0x62, 0x65, 0x68, 0x61, 0x76, 0x69, 0x6f, 0x72, 0x18, 0x07, 0x20, 0x01, 0x28, - 0x0e, 0x32, 0x17, 0x2e, 0x70, 0x62, 0x2e, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, - 0x72, 0x2e, 0x42, 0x65, 0x68, 0x61, 0x76, 0x69, 0x6f, 0x72, 0x52, 0x08, 0x62, 0x65, 0x68, 0x61, - 0x76, 0x69, 0x6f, 0x72, 0x12, 0x14, 0x0a, 
0x05, 0x62, 0x75, 0x72, 0x73, 0x74, 0x18, 0x08, 0x20, - 0x01, 0x28, 0x03, 0x52, 0x05, 0x62, 0x75, 0x72, 0x73, 0x74, 0x12, 0x45, 0x0a, 0x08, 0x6d, 0x65, - 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x09, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x29, 0x2e, 0x70, - 0x62, 0x2e, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x52, 0x61, 0x74, - 0x65, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x52, 0x65, 0x71, 0x2e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, - 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, - 0x61, 0x12, 0x22, 0x0a, 0x0a, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, - 0x0a, 0x20, 0x01, 0x28, 0x03, 0x48, 0x00, 0x52, 0x09, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, - 0x41, 0x74, 0x88, 0x01, 0x01, 0x1a, 0x3b, 0x0a, 0x0d, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, - 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, - 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, - 0x38, 0x01, 0x42, 0x0d, 0x0a, 0x0b, 0x5f, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x5f, 0x61, - 0x74, 0x22, 0xac, 0x02, 0x0a, 0x0d, 0x52, 0x61, 0x74, 0x65, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x52, - 0x65, 0x73, 0x70, 0x12, 0x2d, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x0e, 0x32, 0x15, 0x2e, 0x70, 0x62, 0x2e, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, - 0x74, 0x6f, 0x72, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, - 0x75, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x03, 0x52, 0x05, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x72, 0x65, 0x6d, 0x61, - 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x09, 0x72, 0x65, 0x6d, - 0x61, 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x12, 0x1d, 0x0a, 0x0a, 0x72, 0x65, 0x73, 0x65, 0x74, 0x5f, - 0x74, 0x69, 0x6d, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x09, 0x72, 0x65, 0x73, 0x65, - 0x74, 0x54, 0x69, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x05, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x12, 0x46, 0x0a, 0x08, 0x6d, - 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2a, 0x2e, - 0x70, 0x62, 0x2e, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x52, 0x61, - 0x74, 0x65, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x52, 0x65, 0x73, 0x70, 0x2e, 0x4d, 0x65, 0x74, 0x61, - 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, - 0x61, 0x74, 0x61, 0x1a, 0x3b, 0x0a, 0x0d, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, - 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, - 0x22, 0x10, 0x0a, 0x0e, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x52, - 0x65, 0x71, 0x22, 0x62, 0x0a, 0x0f, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x43, 0x68, 0x65, 0x63, - 0x6b, 0x52, 0x65, 0x73, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, + 0x74, 0x6f, 0x12, 0x0d, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x76, + 0x33, 0x22, 0x55, 0x0a, 0x16, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x52, 
0x61, 0x74, 0x65, 0x4c, 0x69, + 0x6d, 0x69, 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x3b, 0x0a, 0x08, 0x72, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1f, 0x2e, + 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x76, 0x33, 0x2e, 0x52, 0x61, + 0x74, 0x65, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x52, 0x08, + 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x73, 0x22, 0x59, 0x0a, 0x17, 0x43, 0x68, 0x65, 0x63, + 0x6b, 0x52, 0x61, 0x74, 0x65, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x12, 0x3e, 0x0a, 0x09, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x73, + 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x20, 0x2e, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, + 0x74, 0x6f, 0x72, 0x2e, 0x76, 0x33, 0x2e, 0x52, 0x61, 0x74, 0x65, 0x4c, 0x69, 0x6d, 0x69, 0x74, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x52, 0x09, 0x72, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x73, 0x22, 0xca, 0x03, 0x0a, 0x10, 0x52, 0x61, 0x74, 0x65, 0x4c, 0x69, 0x6d, 0x69, + 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x1e, 0x0a, 0x0a, + 0x75, 0x6e, 0x69, 0x71, 0x75, 0x65, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0a, 0x75, 0x6e, 0x69, 0x71, 0x75, 0x65, 0x5f, 0x6b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, + 0x68, 0x69, 0x74, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x68, 0x69, 0x74, 0x73, + 0x12, 0x14, 0x0a, 0x05, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, + 0x05, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x08, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x12, 0x36, 0x0a, 0x09, 0x61, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, + 0x06, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, + 0x6f, 0x72, 0x2e, 0x76, 0x33, 0x2e, 0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x52, + 0x09, 0x61, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x12, 0x33, 0x0a, 0x08, 0x62, 0x65, + 0x68, 0x61, 0x76, 0x69, 0x6f, 0x72, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x17, 0x2e, 0x67, + 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x76, 0x33, 0x2e, 0x42, 0x65, 0x68, + 0x61, 0x76, 0x69, 0x6f, 0x72, 0x52, 0x08, 0x62, 0x65, 0x68, 0x61, 0x76, 0x69, 0x6f, 0x72, 0x12, + 0x14, 0x0a, 0x05, 0x62, 0x75, 0x72, 0x73, 0x74, 0x18, 0x08, 0x20, 0x01, 0x28, 0x03, 0x52, 0x05, + 0x62, 0x75, 0x72, 0x73, 0x74, 0x12, 0x49, 0x0a, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, + 0x61, 0x18, 0x09, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2d, 0x2e, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, + 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x76, 0x33, 0x2e, 0x52, 0x61, 0x74, 0x65, 0x4c, 0x69, 0x6d, 0x69, + 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x2e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, + 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, + 0x12, 0x22, 0x0a, 0x0a, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x0a, + 0x20, 0x01, 0x28, 0x03, 0x48, 0x00, 0x52, 0x09, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x41, + 0x74, 0x88, 0x01, 0x01, 0x1a, 0x3b, 0x0a, 0x0d, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, + 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 
0x01, + 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, + 0x01, 0x42, 0x0d, 0x0a, 0x0b, 0x5f, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, + 0x22, 0xb5, 0x02, 0x0a, 0x11, 0x52, 0x61, 0x74, 0x65, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x2d, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x15, 0x2e, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, + 0x74, 0x6f, 0x72, 0x2e, 0x76, 0x33, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x06, 0x73, + 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x03, 0x52, 0x05, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x72, + 0x65, 0x6d, 0x61, 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x09, + 0x72, 0x65, 0x6d, 0x61, 0x69, 0x6e, 0x69, 0x6e, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x72, 0x65, 0x73, + 0x65, 0x74, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0a, 0x72, + 0x65, 0x73, 0x65, 0x74, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, + 0x6f, 0x72, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x12, + 0x4a, 0x0a, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x18, 0x06, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x2e, 0x2e, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x76, + 0x33, 0x2e, 0x52, 0x61, 0x74, 0x65, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x2e, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, + 0x79, 0x52, 0x08, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x1a, 0x3b, 0x0a, 0x0d, 0x4d, + 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, + 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, + 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, + 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x14, 0x0a, 0x12, 0x48, 0x65, 0x61, 0x6c, + 0x74, 0x68, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x67, + 0x0a, 0x13, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, - 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x5f, - 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x09, 0x70, 0x65, 0x65, - 0x72, 0x43, 0x6f, 0x75, 0x6e, 0x74, 0x2a, 0x2f, 0x0a, 0x09, 0x41, 0x6c, 0x67, 0x6f, 0x72, 0x69, - 0x74, 0x68, 0x6d, 0x12, 0x10, 0x0a, 0x0c, 0x54, 0x4f, 0x4b, 0x45, 0x4e, 0x5f, 0x42, 0x55, 0x43, - 0x4b, 0x45, 0x54, 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x4c, 0x45, 0x41, 0x4b, 0x59, 0x5f, 0x42, - 0x55, 0x43, 0x4b, 0x45, 0x54, 0x10, 0x01, 0x2a, 0x8d, 0x01, 0x0a, 0x08, 0x42, 0x65, 0x68, 0x61, - 0x76, 0x69, 0x6f, 0x72, 0x12, 0x0c, 0x0a, 0x08, 0x42, 0x41, 0x54, 0x43, 0x48, 0x49, 0x4e, 0x47, - 0x10, 0x00, 0x12, 0x0f, 0x0a, 0x0b, 0x4e, 0x4f, 0x5f, 0x42, 0x41, 0x54, 0x43, 0x48, 0x49, 0x4e, - 0x47, 0x10, 0x01, 
0x12, 0x0a, 0x0a, 0x06, 0x47, 0x4c, 0x4f, 0x42, 0x41, 0x4c, 0x10, 0x02, 0x12, - 0x19, 0x0a, 0x15, 0x44, 0x55, 0x52, 0x41, 0x54, 0x49, 0x4f, 0x4e, 0x5f, 0x49, 0x53, 0x5f, 0x47, - 0x52, 0x45, 0x47, 0x4f, 0x52, 0x49, 0x41, 0x4e, 0x10, 0x04, 0x12, 0x13, 0x0a, 0x0f, 0x52, 0x45, - 0x53, 0x45, 0x54, 0x5f, 0x52, 0x45, 0x4d, 0x41, 0x49, 0x4e, 0x49, 0x4e, 0x47, 0x10, 0x08, 0x12, - 0x10, 0x0a, 0x0c, 0x4d, 0x55, 0x4c, 0x54, 0x49, 0x5f, 0x52, 0x45, 0x47, 0x49, 0x4f, 0x4e, 0x10, - 0x10, 0x12, 0x14, 0x0a, 0x10, 0x44, 0x52, 0x41, 0x49, 0x4e, 0x5f, 0x4f, 0x56, 0x45, 0x52, 0x5f, - 0x4c, 0x49, 0x4d, 0x49, 0x54, 0x10, 0x20, 0x2a, 0x29, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, - 0x73, 0x12, 0x0f, 0x0a, 0x0b, 0x55, 0x4e, 0x44, 0x45, 0x52, 0x5f, 0x4c, 0x49, 0x4d, 0x49, 0x54, - 0x10, 0x00, 0x12, 0x0e, 0x0a, 0x0a, 0x4f, 0x56, 0x45, 0x52, 0x5f, 0x4c, 0x49, 0x4d, 0x49, 0x54, - 0x10, 0x01, 0x32, 0xdd, 0x01, 0x0a, 0x02, 0x56, 0x31, 0x12, 0x70, 0x0a, 0x0d, 0x47, 0x65, 0x74, - 0x52, 0x61, 0x74, 0x65, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x73, 0x12, 0x1f, 0x2e, 0x70, 0x62, 0x2e, - 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x47, 0x65, 0x74, 0x52, 0x61, - 0x74, 0x65, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x73, 0x52, 0x65, 0x71, 0x1a, 0x20, 0x2e, 0x70, 0x62, - 0x2e, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x47, 0x65, 0x74, 0x52, - 0x61, 0x74, 0x65, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x22, 0x1c, 0x82, - 0xd3, 0xe4, 0x93, 0x02, 0x16, 0x3a, 0x01, 0x2a, 0x22, 0x11, 0x2f, 0x76, 0x31, 0x2f, 0x47, 0x65, - 0x74, 0x52, 0x61, 0x74, 0x65, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x73, 0x12, 0x65, 0x0a, 0x0b, 0x48, - 0x65, 0x61, 0x6c, 0x74, 0x68, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x12, 0x1d, 0x2e, 0x70, 0x62, 0x2e, - 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, - 0x68, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x52, 0x65, 0x71, 0x1a, 0x1e, 0x2e, 0x70, 0x62, 0x2e, 0x67, - 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, - 0x43, 0x68, 0x65, 0x63, 0x6b, 0x52, 0x65, 0x73, 0x70, 0x22, 0x17, 0x82, 0xd3, 0xe4, 0x93, 0x02, - 0x11, 0x12, 0x0f, 0x2f, 0x76, 0x31, 0x2f, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x43, 0x68, 0x65, - 0x63, 0x6b, 0x42, 0x28, 0x5a, 0x23, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, - 0x2f, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2d, 0x69, 0x6f, 0x2f, 0x67, - 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x80, 0x01, 0x01, 0x62, 0x06, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x33, + 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x5f, + 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, 0x70, 0x65, 0x65, + 0x72, 0x5f, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x2a, 0x2f, 0x0a, 0x09, 0x41, 0x6c, 0x67, 0x6f, 0x72, + 0x69, 0x74, 0x68, 0x6d, 0x12, 0x10, 0x0a, 0x0c, 0x54, 0x4f, 0x4b, 0x45, 0x4e, 0x5f, 0x42, 0x55, + 0x43, 0x4b, 0x45, 0x54, 0x10, 0x00, 0x12, 0x10, 0x0a, 0x0c, 0x4c, 0x45, 0x41, 0x4b, 0x59, 0x5f, + 0x42, 0x55, 0x43, 0x4b, 0x45, 0x54, 0x10, 0x01, 0x2a, 0x8d, 0x01, 0x0a, 0x08, 0x42, 0x65, 0x68, + 0x61, 0x76, 0x69, 0x6f, 0x72, 0x12, 0x0c, 0x0a, 0x08, 0x42, 0x41, 0x54, 0x43, 0x48, 0x49, 0x4e, + 0x47, 0x10, 0x00, 0x12, 0x0f, 0x0a, 0x0b, 0x4e, 0x4f, 0x5f, 0x42, 0x41, 0x54, 0x43, 0x48, 0x49, + 0x4e, 0x47, 0x10, 0x01, 0x12, 0x0a, 0x0a, 0x06, 0x47, 0x4c, 0x4f, 0x42, 0x41, 0x4c, 0x10, 0x02, + 0x12, 0x19, 0x0a, 0x15, 0x44, 0x55, 0x52, 0x41, 0x54, 0x49, 0x4f, 0x4e, 0x5f, 0x49, 0x53, 0x5f, + 0x47, 0x52, 0x45, 
0x47, 0x4f, 0x52, 0x49, 0x41, 0x4e, 0x10, 0x04, 0x12, 0x13, 0x0a, 0x0f, 0x52, + 0x45, 0x53, 0x45, 0x54, 0x5f, 0x52, 0x45, 0x4d, 0x41, 0x49, 0x4e, 0x49, 0x4e, 0x47, 0x10, 0x08, + 0x12, 0x10, 0x0a, 0x0c, 0x4d, 0x55, 0x4c, 0x54, 0x49, 0x5f, 0x52, 0x45, 0x47, 0x49, 0x4f, 0x4e, + 0x10, 0x10, 0x12, 0x14, 0x0a, 0x10, 0x44, 0x52, 0x41, 0x49, 0x4e, 0x5f, 0x4f, 0x56, 0x45, 0x52, + 0x5f, 0x4c, 0x49, 0x4d, 0x49, 0x54, 0x10, 0x20, 0x2a, 0x29, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, + 0x75, 0x73, 0x12, 0x0f, 0x0a, 0x0b, 0x55, 0x4e, 0x44, 0x45, 0x52, 0x5f, 0x4c, 0x49, 0x4d, 0x49, + 0x54, 0x10, 0x00, 0x12, 0x0e, 0x0a, 0x0a, 0x4f, 0x56, 0x45, 0x52, 0x5f, 0x4c, 0x49, 0x4d, 0x49, + 0x54, 0x10, 0x01, 0x42, 0x25, 0x5a, 0x23, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, + 0x6d, 0x2f, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2d, 0x69, 0x6f, 0x2f, + 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x33, } var ( @@ -807,32 +794,28 @@ func file_gubernator_proto_rawDescGZIP() []byte { var file_gubernator_proto_enumTypes = make([]protoimpl.EnumInfo, 3) var file_gubernator_proto_msgTypes = make([]protoimpl.MessageInfo, 8) var file_gubernator_proto_goTypes = []interface{}{ - (Algorithm)(0), // 0: pb.gubernator.Algorithm - (Behavior)(0), // 1: pb.gubernator.Behavior - (Status)(0), // 2: pb.gubernator.Status - (*GetRateLimitsReq)(nil), // 3: pb.gubernator.GetRateLimitsReq - (*GetRateLimitsResp)(nil), // 4: pb.gubernator.GetRateLimitsResp - (*RateLimitReq)(nil), // 5: pb.gubernator.RateLimitReq - (*RateLimitResp)(nil), // 6: pb.gubernator.RateLimitResp - (*HealthCheckReq)(nil), // 7: pb.gubernator.HealthCheckReq - (*HealthCheckResp)(nil), // 8: pb.gubernator.HealthCheckResp - nil, // 9: pb.gubernator.RateLimitReq.MetadataEntry - nil, // 10: pb.gubernator.RateLimitResp.MetadataEntry + (Algorithm)(0), // 0: gubernator.v3.Algorithm + (Behavior)(0), // 1: gubernator.v3.Behavior + (Status)(0), // 2: gubernator.v3.Status + (*CheckRateLimitsRequest)(nil), // 3: gubernator.v3.CheckRateLimitsRequest + (*CheckRateLimitsResponse)(nil), // 4: gubernator.v3.CheckRateLimitsResponse + (*RateLimitRequest)(nil), // 5: gubernator.v3.RateLimitRequest + (*RateLimitResponse)(nil), // 6: gubernator.v3.RateLimitResponse + (*HealthCheckRequest)(nil), // 7: gubernator.v3.HealthCheckRequest + (*HealthCheckResponse)(nil), // 8: gubernator.v3.HealthCheckResponse + nil, // 9: gubernator.v3.RateLimitRequest.MetadataEntry + nil, // 10: gubernator.v3.RateLimitResponse.MetadataEntry } var file_gubernator_proto_depIdxs = []int32{ - 5, // 0: pb.gubernator.GetRateLimitsReq.requests:type_name -> pb.gubernator.RateLimitReq - 6, // 1: pb.gubernator.GetRateLimitsResp.responses:type_name -> pb.gubernator.RateLimitResp - 0, // 2: pb.gubernator.RateLimitReq.algorithm:type_name -> pb.gubernator.Algorithm - 1, // 3: pb.gubernator.RateLimitReq.behavior:type_name -> pb.gubernator.Behavior - 9, // 4: pb.gubernator.RateLimitReq.metadata:type_name -> pb.gubernator.RateLimitReq.MetadataEntry - 2, // 5: pb.gubernator.RateLimitResp.status:type_name -> pb.gubernator.Status - 10, // 6: pb.gubernator.RateLimitResp.metadata:type_name -> pb.gubernator.RateLimitResp.MetadataEntry - 3, // 7: pb.gubernator.V1.GetRateLimits:input_type -> pb.gubernator.GetRateLimitsReq - 7, // 8: pb.gubernator.V1.HealthCheck:input_type -> pb.gubernator.HealthCheckReq - 4, // 9: pb.gubernator.V1.GetRateLimits:output_type -> pb.gubernator.GetRateLimitsResp - 8, // 10: pb.gubernator.V1.HealthCheck:output_type -> 
pb.gubernator.HealthCheckResp - 9, // [9:11] is the sub-list for method output_type - 7, // [7:9] is the sub-list for method input_type + 5, // 0: gubernator.v3.CheckRateLimitsRequest.requests:type_name -> gubernator.v3.RateLimitRequest + 6, // 1: gubernator.v3.CheckRateLimitsResponse.responses:type_name -> gubernator.v3.RateLimitResponse + 0, // 2: gubernator.v3.RateLimitRequest.algorithm:type_name -> gubernator.v3.Algorithm + 1, // 3: gubernator.v3.RateLimitRequest.behavior:type_name -> gubernator.v3.Behavior + 9, // 4: gubernator.v3.RateLimitRequest.metadata:type_name -> gubernator.v3.RateLimitRequest.MetadataEntry + 2, // 5: gubernator.v3.RateLimitResponse.status:type_name -> gubernator.v3.Status + 10, // 6: gubernator.v3.RateLimitResponse.metadata:type_name -> gubernator.v3.RateLimitResponse.MetadataEntry + 7, // [7:7] is the sub-list for method output_type + 7, // [7:7] is the sub-list for method input_type 7, // [7:7] is the sub-list for extension type_name 7, // [7:7] is the sub-list for extension extendee 0, // [0:7] is the sub-list for field type_name @@ -845,7 +828,7 @@ func file_gubernator_proto_init() { } if !protoimpl.UnsafeEnabled { file_gubernator_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetRateLimitsReq); i { + switch v := v.(*CheckRateLimitsRequest); i { case 0: return &v.state case 1: @@ -857,7 +840,7 @@ func file_gubernator_proto_init() { } } file_gubernator_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetRateLimitsResp); i { + switch v := v.(*CheckRateLimitsResponse); i { case 0: return &v.state case 1: @@ -869,7 +852,7 @@ func file_gubernator_proto_init() { } } file_gubernator_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RateLimitReq); i { + switch v := v.(*RateLimitRequest); i { case 0: return &v.state case 1: @@ -881,7 +864,7 @@ func file_gubernator_proto_init() { } } file_gubernator_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RateLimitResp); i { + switch v := v.(*RateLimitResponse); i { case 0: return &v.state case 1: @@ -893,7 +876,7 @@ func file_gubernator_proto_init() { } } file_gubernator_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*HealthCheckReq); i { + switch v := v.(*HealthCheckRequest); i { case 0: return &v.state case 1: @@ -905,7 +888,7 @@ func file_gubernator_proto_init() { } } file_gubernator_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*HealthCheckResp); i { + switch v := v.(*HealthCheckResponse); i { case 0: return &v.state case 1: @@ -926,7 +909,7 @@ func file_gubernator_proto_init() { NumEnums: 3, NumMessages: 8, NumExtensions: 0, - NumServices: 1, + NumServices: 0, }, GoTypes: file_gubernator_proto_goTypes, DependencyIndexes: file_gubernator_proto_depIdxs, diff --git a/gubernator.pb.gw.go b/gubernator.pb.gw.go deleted file mode 100644 index bb46059..0000000 --- a/gubernator.pb.gw.go +++ /dev/null @@ -1,240 +0,0 @@ -// Code generated by protoc-gen-grpc-gateway. DO NOT EDIT. -// source: gubernator.proto - -/* -Package gubernator is a reverse proxy. - -It translates gRPC into RESTful JSON APIs. 
-*/ -package gubernator - -import ( - "context" - "io" - "net/http" - - "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" - "github.com/grpc-ecosystem/grpc-gateway/v2/utilities" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/grpclog" - "google.golang.org/grpc/metadata" - "google.golang.org/grpc/status" - "google.golang.org/protobuf/proto" -) - -// Suppress "imported and not used" errors -var _ codes.Code -var _ io.Reader -var _ status.Status -var _ = runtime.String -var _ = utilities.NewDoubleArray -var _ = metadata.Join - -func request_V1_GetRateLimits_0(ctx context.Context, marshaler runtime.Marshaler, client V1Client, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq GetRateLimitsReq - var metadata runtime.ServerMetadata - - newReader, berr := utilities.IOReaderFactory(req.Body) - if berr != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) - } - if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } - - msg, err := client.GetRateLimits(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) - return msg, metadata, err - -} - -func local_request_V1_GetRateLimits_0(ctx context.Context, marshaler runtime.Marshaler, server V1Server, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq GetRateLimitsReq - var metadata runtime.ServerMetadata - - newReader, berr := utilities.IOReaderFactory(req.Body) - if berr != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) - } - if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } - - msg, err := server.GetRateLimits(ctx, &protoReq) - return msg, metadata, err - -} - -func request_V1_HealthCheck_0(ctx context.Context, marshaler runtime.Marshaler, client V1Client, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq HealthCheckReq - var metadata runtime.ServerMetadata - - msg, err := client.HealthCheck(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) - return msg, metadata, err - -} - -func local_request_V1_HealthCheck_0(ctx context.Context, marshaler runtime.Marshaler, server V1Server, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq HealthCheckReq - var metadata runtime.ServerMetadata - - msg, err := server.HealthCheck(ctx, &protoReq) - return msg, metadata, err - -} - -// RegisterV1HandlerServer registers the http handlers for service V1 to "mux". -// UnaryRPC :call V1Server directly. -// StreamingRPC :currently unsupported pending https://github.com/grpc/grpc-go/issues/906. -// Note that using this registration option will cause many gRPC library features to stop working. Consider using RegisterV1HandlerFromEndpoint instead. 
-func RegisterV1HandlerServer(ctx context.Context, mux *runtime.ServeMux, server V1Server) error { - - mux.Handle("POST", pattern_V1_GetRateLimits_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - var stream runtime.ServerTransportStream - ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/pb.gubernator.V1/GetRateLimits", runtime.WithHTTPPathPattern("/v1/GetRateLimits")) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := local_request_V1_GetRateLimits_0(annotatedContext, inboundMarshaler, server, req, pathParams) - md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) - annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) - if err != nil { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - return - } - - forward_V1_GetRateLimits_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - - }) - - mux.Handle("GET", pattern_V1_HealthCheck_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - var stream runtime.ServerTransportStream - ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/pb.gubernator.V1/HealthCheck", runtime.WithHTTPPathPattern("/v1/HealthCheck")) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := local_request_V1_HealthCheck_0(annotatedContext, inboundMarshaler, server, req, pathParams) - md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) - annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) - if err != nil { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - return - } - - forward_V1_HealthCheck_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - - }) - - return nil -} - -// RegisterV1HandlerFromEndpoint is same as RegisterV1Handler but -// automatically dials to "endpoint" and closes the connection when "ctx" gets done. -func RegisterV1HandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) { - conn, err := grpc.DialContext(ctx, endpoint, opts...) - if err != nil { - return err - } - defer func() { - if err != nil { - if cerr := conn.Close(); cerr != nil { - grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr) - } - return - } - go func() { - <-ctx.Done() - if cerr := conn.Close(); cerr != nil { - grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr) - } - }() - }() - - return RegisterV1Handler(ctx, mux, conn) -} - -// RegisterV1Handler registers the http handlers for service V1 to "mux". -// The handlers forward requests to the grpc endpoint over "conn". 
-func RegisterV1Handler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error { - return RegisterV1HandlerClient(ctx, mux, NewV1Client(conn)) -} - -// RegisterV1HandlerClient registers the http handlers for service V1 -// to "mux". The handlers forward requests to the grpc endpoint over the given implementation of "V1Client". -// Note: the gRPC framework executes interceptors within the gRPC handler. If the passed in "V1Client" -// doesn't go through the normal gRPC flow (creating a gRPC client etc.) then it will be up to the passed in -// "V1Client" to call the correct interceptors. -func RegisterV1HandlerClient(ctx context.Context, mux *runtime.ServeMux, client V1Client) error { - - mux.Handle("POST", pattern_V1_GetRateLimits_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/pb.gubernator.V1/GetRateLimits", runtime.WithHTTPPathPattern("/v1/GetRateLimits")) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := request_V1_GetRateLimits_0(annotatedContext, inboundMarshaler, client, req, pathParams) - annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) - if err != nil { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - return - } - - forward_V1_GetRateLimits_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - - }) - - mux.Handle("GET", pattern_V1_HealthCheck_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/pb.gubernator.V1/HealthCheck", runtime.WithHTTPPathPattern("/v1/HealthCheck")) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := request_V1_HealthCheck_0(annotatedContext, inboundMarshaler, client, req, pathParams) - annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) - if err != nil { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - return - } - - forward_V1_HealthCheck_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - - }) - - return nil -} - -var ( - pattern_V1_GetRateLimits_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1}, []string{"v1", "GetRateLimits"}, "")) - - pattern_V1_HealthCheck_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1}, []string{"v1", "HealthCheck"}, "")) -) - -var ( - forward_V1_GetRateLimits_0 = runtime.ForwardResponseMessage - - forward_V1_HealthCheck_0 = runtime.ForwardResponseMessage -) diff --git a/gubernator.proto b/gubernator.proto index 626eabc..975a3fc 100644 --- a/gubernator.proto +++ b/gubernator.proto @@ -17,40 +17,17 @@ limitations under the License. 
syntax = "proto3"; option go_package = "github.com/gubernator-io/gubernator"; +package gubernator.v3; -option cc_generic_services = true; -package pb.gubernator; - -import "google/api/annotations.proto"; - -service V1 { - - // Given a list of rate limit requests, return the rate limits of each. - rpc GetRateLimits (GetRateLimitsReq) returns (GetRateLimitsResp) { - option (google.api.http) = { - post: "/v1/GetRateLimits" - body: "*" - }; - } - - // This method is for round trip benchmarking and can be used by - // the client to determine connectivity to the server - rpc HealthCheck (HealthCheckReq) returns (HealthCheckResp) { - option (google.api.http) = { - get: "/v1/HealthCheck" - }; - } +// Must specify at least one RateLimitRequest +message CheckRateLimitsRequest { + repeated RateLimitRequest requests = 1; } -// Must specify at least one Request -message GetRateLimitsReq { - repeated RateLimitReq requests = 1; -} - -// RateLimits returned are in the same order as the Requests -message GetRateLimitsResp { - repeated RateLimitResp responses = 1; +// RateLimits returned are in the same order provided in CheckRateLimitsRequest +message CheckRateLimitsResponse { + repeated RateLimitResponse responses = 1; } enum Algorithm { @@ -134,12 +111,12 @@ enum Behavior { // TODO: Add support for LOCAL. Which would force the rate limit to be handled by the local instance } -message RateLimitReq { +message RateLimitRequest { // The name of the rate limit IE: 'requests_per_second', 'gets_per_minute` string name = 1; // Uniquely identifies this rate limit IE: 'ip:10.2.10.7' or 'account:123445' - string unique_key = 2; + string unique_key = 2 [json_name="unique_key"]; // Rate limit requests optionally specify the number of hits a request adds to the matched limit. If Hit // is zero, the request returns the current limit, but does not increment the hit count. @@ -187,27 +164,27 @@ enum Status { OVER_LIMIT = 1; } -message RateLimitResp { +message RateLimitResponse { // The status of the rate limit. Status status = 1; - // The currently configured request limit (Identical to [[RateLimitReq.limit]]). + // The currently configured request limit (Identical to [[RateLimitRequest.limit]]). int64 limit = 2; // This is the number of requests remaining before the rate limit is hit but after subtracting the hits from the current request int64 remaining = 3; // This is the time when the rate limit span will be reset, provided as a unix timestamp in milliseconds. - int64 reset_time = 4; + int64 reset_time = 4 [json_name="reset_time"]; // Contains the error; If set all other values should be ignored string error = 5; // This is additional metadata that a client might find useful. (IE: Additional headers, coordinator ownership, etc..) map metadata = 6; } -message HealthCheckReq {} -message HealthCheckResp { +message HealthCheckRequest {} +message HealthCheckResponse { // Valid entries are 'healthy' or 'unhealthy' string status = 1; // If 'unhealthy', message indicates the problem string message = 2; // The number of peers we know about - int32 peer_count = 3; -} + int32 peer_count = 3 [json_name="peer_count"]; +} \ No newline at end of file diff --git a/gubernator_grpc.pb.go b/gubernator_grpc.pb.go deleted file mode 100644 index 209b75a..0000000 --- a/gubernator_grpc.pb.go +++ /dev/null @@ -1,165 +0,0 @@ -// -//Copyright 2018-2022 Mailgun Technologies Inc -// -//Licensed under the Apache License, Version 2.0 (the "License"); -//you may not use this file except in compliance with the License. 
-//You may obtain a copy of the License at -// -//http://www.apache.org/licenses/LICENSE-2.0 -// -//Unless required by applicable law or agreed to in writing, software -//distributed under the License is distributed on an "AS IS" BASIS, -//WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -//See the License for the specific language governing permissions and -//limitations under the License. - -// Code generated by protoc-gen-go-grpc. DO NOT EDIT. -// versions: -// - protoc-gen-go-grpc v1.3.0 -// - protoc (unknown) -// source: gubernator.proto - -package gubernator - -import ( - context "context" - grpc "google.golang.org/grpc" - codes "google.golang.org/grpc/codes" - status "google.golang.org/grpc/status" -) - -// This is a compile-time assertion to ensure that this generated file -// is compatible with the grpc package it is being compiled against. -// Requires gRPC-Go v1.32.0 or later. -const _ = grpc.SupportPackageIsVersion7 - -const ( - V1_GetRateLimits_FullMethodName = "/pb.gubernator.V1/GetRateLimits" - V1_HealthCheck_FullMethodName = "/pb.gubernator.V1/HealthCheck" -) - -// V1Client is the client API for V1 service. -// -// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. -type V1Client interface { - // Given a list of rate limit requests, return the rate limits of each. - GetRateLimits(ctx context.Context, in *GetRateLimitsReq, opts ...grpc.CallOption) (*GetRateLimitsResp, error) - // This method is for round trip benchmarking and can be used by - // the client to determine connectivity to the server - HealthCheck(ctx context.Context, in *HealthCheckReq, opts ...grpc.CallOption) (*HealthCheckResp, error) -} - -type v1Client struct { - cc grpc.ClientConnInterface -} - -func NewV1Client(cc grpc.ClientConnInterface) V1Client { - return &v1Client{cc} -} - -func (c *v1Client) GetRateLimits(ctx context.Context, in *GetRateLimitsReq, opts ...grpc.CallOption) (*GetRateLimitsResp, error) { - out := new(GetRateLimitsResp) - err := c.cc.Invoke(ctx, V1_GetRateLimits_FullMethodName, in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -func (c *v1Client) HealthCheck(ctx context.Context, in *HealthCheckReq, opts ...grpc.CallOption) (*HealthCheckResp, error) { - out := new(HealthCheckResp) - err := c.cc.Invoke(ctx, V1_HealthCheck_FullMethodName, in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -// V1Server is the server API for V1 service. -// All implementations should embed UnimplementedV1Server -// for forward compatibility -type V1Server interface { - // Given a list of rate limit requests, return the rate limits of each. - GetRateLimits(context.Context, *GetRateLimitsReq) (*GetRateLimitsResp, error) - // This method is for round trip benchmarking and can be used by - // the client to determine connectivity to the server - HealthCheck(context.Context, *HealthCheckReq) (*HealthCheckResp, error) -} - -// UnimplementedV1Server should be embedded to have forward compatible implementations. 
-type UnimplementedV1Server struct { -} - -func (UnimplementedV1Server) GetRateLimits(context.Context, *GetRateLimitsReq) (*GetRateLimitsResp, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetRateLimits not implemented") -} -func (UnimplementedV1Server) HealthCheck(context.Context, *HealthCheckReq) (*HealthCheckResp, error) { - return nil, status.Errorf(codes.Unimplemented, "method HealthCheck not implemented") -} - -// UnsafeV1Server may be embedded to opt out of forward compatibility for this service. -// Use of this interface is not recommended, as added methods to V1Server will -// result in compilation errors. -type UnsafeV1Server interface { - mustEmbedUnimplementedV1Server() -} - -func RegisterV1Server(s grpc.ServiceRegistrar, srv V1Server) { - s.RegisterService(&V1_ServiceDesc, srv) -} - -func _V1_GetRateLimits_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(GetRateLimitsReq) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(V1Server).GetRateLimits(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: V1_GetRateLimits_FullMethodName, - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(V1Server).GetRateLimits(ctx, req.(*GetRateLimitsReq)) - } - return interceptor(ctx, in, info, handler) -} - -func _V1_HealthCheck_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(HealthCheckReq) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(V1Server).HealthCheck(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: V1_HealthCheck_FullMethodName, - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(V1Server).HealthCheck(ctx, req.(*HealthCheckReq)) - } - return interceptor(ctx, in, info, handler) -} - -// V1_ServiceDesc is the grpc.ServiceDesc for V1 service. 
-// It's only intended for direct use with grpc.RegisterService, -// and not to be introspected or modified (even as a copy) -var V1_ServiceDesc = grpc.ServiceDesc{ - ServiceName: "pb.gubernator.V1", - HandlerType: (*V1Server)(nil), - Methods: []grpc.MethodDesc{ - { - MethodName: "GetRateLimits", - Handler: _V1_GetRateLimits_Handler, - }, - { - MethodName: "HealthCheck", - Handler: _V1_HealthCheck_Handler, - }, - }, - Streams: []grpc.StreamDesc{}, - Metadata: "gubernator.proto", -} diff --git a/http.go b/http.go new file mode 100644 index 0000000..3f40bbf --- /dev/null +++ b/http.go @@ -0,0 +1,188 @@ +package gubernator + +import ( + "context" + "fmt" + "net/http" + + "github.com/duh-rpc/duh-go" + v1 "github.com/duh-rpc/duh-go/proto/v1" + "github.com/prometheus/client_golang/prometheus" + "go.opentelemetry.io/otel/propagation" +) + +const ( + RPCPeerForward = "/v1/peer.forward" + RPCPeerUpdate = "/v1/peer.update" + RPCRateLimitCheck = "/v1/rate-limit.check" + RPCHealthCheck = "/v1/health.check" +) + +type Handler struct { + prop propagation.TraceContext + duration *prometheus.SummaryVec + metrics http.Handler + service *Service +} + +func NewHandler(s *Service, metrics http.Handler) *Handler { + return &Handler{ + duration: prometheus.NewSummaryVec(prometheus.SummaryOpts{ + Name: "gubernator_http_handler_duration", + Help: "The timings of http requests handled by the service", + Objectives: map[float64]float64{ + 0.5: 0.05, + 0.99: 0.001, + }, + }, []string{"path"}), + metrics: metrics, + service: s, + } +} + +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer prometheus.NewTimer(h.duration.WithLabelValues(r.URL.Path)).ObserveDuration() + ctx := h.prop.Extract(r.Context(), propagation.HeaderCarrier(r.Header)) + + switch r.URL.Path { + case RPCPeerForward: + h.PeerForward(ctx, w, r) + return + case RPCPeerUpdate: + h.PeerUpdate(ctx, w, r) + return + case RPCRateLimitCheck: + h.CheckRateLimit(ctx, w, r) + return + case RPCHealthCheck: + h.HealthCheck(w, r) + return + case "/metrics": + h.metrics.ServeHTTP(w, r) + return + case "/healthz": + h.HealthZ(w, r) + return + } + duh.ReplyWithCode(w, r, duh.CodeNotImplemented, nil, "no such method; "+r.URL.Path) +} + +func (h *Handler) PeerForward(ctx context.Context, w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + duh.ReplyWithCode(w, r, duh.CodeBadRequest, nil, + fmt.Sprintf("http method '%s' not allowed; only POST", r.Method)) + return + } + + var req ForwardRequest + if err := duh.ReadRequest(r, &req); err != nil { + duh.ReplyError(w, r, err) + return + } + var resp ForwardResponse + if err := h.service.Forward(ctx, &req, &resp); err != nil { + duh.ReplyError(w, r, err) + return + } + duh.Reply(w, r, duh.CodeOK, &resp) +} + +func (h *Handler) PeerUpdate(ctx context.Context, w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + duh.ReplyWithCode(w, r, duh.CodeBadRequest, nil, + fmt.Sprintf("http method '%s' not allowed; only POST", r.Method)) + return + } + + var req UpdateRequest + if err := duh.ReadRequest(r, &req); err != nil { + duh.ReplyError(w, r, err) + return + } + var resp v1.Reply + if err := h.service.Update(ctx, &req, &resp); err != nil { + duh.ReplyError(w, r, err) + return + } + duh.Reply(w, r, duh.CodeOK, &resp) +} + +func (h *Handler) CheckRateLimit(ctx context.Context, w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + duh.ReplyWithCode(w, r, duh.CodeBadRequest, nil, + fmt.Sprintf("http method '%s' not allowed; only 
POST", r.Method)) + return + } + + var req CheckRateLimitsRequest + if err := duh.ReadRequest(r, &req); err != nil { + duh.ReplyError(w, r, err) + return + } + + var resp CheckRateLimitsResponse + if err := h.service.CheckRateLimits(ctx, &req, &resp); err != nil { + duh.ReplyError(w, r, err) + return + } + duh.Reply(w, r, duh.CodeOK, &resp) +} + +func (h *Handler) HealthCheck(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + duh.ReplyWithCode(w, r, duh.CodeBadRequest, nil, + fmt.Sprintf("http method '%s' not allowed; only POST", r.Method)) + return + } + + var req HealthCheckRequest + if err := duh.ReadRequest(r, &req); err != nil { + duh.ReplyError(w, r, err) + return + } + var resp HealthCheckResponse + if err := h.service.HealthCheck(r.Context(), &req, &resp); err != nil { + duh.ReplyError(w, r, err) + return + } + duh.Reply(w, r, duh.CodeOK, &resp) +} + +func (h *Handler) ResetMetrics(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + duh.ReplyWithCode(w, r, duh.CodeBadRequest, nil, + fmt.Sprintf("http method '%s' not allowed; only POST", r.Method)) + return + } + + var req HealthCheckRequest + if err := duh.ReadRequest(r, &req); err != nil { + duh.ReplyError(w, r, err) + return + } + var resp HealthCheckResponse + if err := h.service.HealthCheck(r.Context(), &req, &resp); err != nil { + duh.ReplyError(w, r, err) + return + } + duh.Reply(w, r, duh.CodeOK, &resp) +} + +func (h *Handler) HealthZ(w http.ResponseWriter, r *http.Request) { + var resp HealthCheckResponse + if err := h.service.HealthCheck(r.Context(), nil, &resp); err != nil { + duh.ReplyError(w, r, err) + return + } + duh.Reply(w, r, duh.CodeOK, &resp) +} + +// Describe fetches prometheus metrics to be registered +func (h *Handler) Describe(ch chan<- *prometheus.Desc) { + h.duration.Describe(ch) +} + +// Collect fetches metrics from the server for use by prometheus +func (h *Handler) Collect(ch chan<- prometheus.Metric) { + h.duration.Collect(ch) +} diff --git a/interval.go b/interval.go index 639e613..54384f5 100644 --- a/interval.go +++ b/interval.go @@ -20,8 +20,8 @@ import ( "errors" "time" - "github.com/mailgun/holster/v4/clock" - "github.com/mailgun/holster/v4/syncutil" + "github.com/kapetan-io/tackle/clock" + "github.com/kapetan-io/tackle/wait" ) // Interval is a one-shot ticker. 
Call `Next()` to trigger the start of an @@ -29,7 +29,7 @@ import ( type Interval struct { C chan struct{} in chan struct{} - wg syncutil.WaitGroup + wg wait.Group } // NewInterval creates a new ticker like object, however diff --git a/interval_test.go b/interval_test.go index e1d439c..c41c2d8 100644 --- a/interval_test.go +++ b/interval_test.go @@ -19,8 +19,8 @@ package gubernator_test import ( "testing" - "github.com/gubernator-io/gubernator/v2" - "github.com/mailgun/holster/v4/clock" + "github.com/gubernator-io/gubernator/v3" + "github.com/kapetan-io/tackle/clock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/kubernetes.go b/kubernetes.go index e25ebf7..01b3ece 100644 --- a/kubernetes.go +++ b/kubernetes.go @@ -19,11 +19,11 @@ package gubernator import ( "context" "fmt" + "log/slog" "reflect" - "github.com/mailgun/holster/v4/setter" + "github.com/kapetan-io/tackle/set" "github.com/pkg/errors" - "github.com/sirupsen/logrus" api_v1 "k8s.io/api/core/v1" meta_v1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" @@ -93,7 +93,7 @@ func NewK8sPool(conf K8sPoolConfig) (*K8sPool, error) { watchCtx: ctx, watchCancel: cancel, } - setter.SetDefault(&pool.log, logrus.WithField("category", "gubernator")) + set.Default(&pool.log, slog.Default().With("category", "gubernator")) return pool, pool.start() } @@ -120,27 +120,39 @@ func (e *K8sPool) startGenericWatch(objType runtime.Object, listWatch *cache.Lis e.informer.AddEventHandler(cache.ResourceEventHandlerFuncs{ AddFunc: func(obj interface{}) { key, err := cache.MetaNamespaceKeyFunc(obj) - e.log.Debugf("Queue (Add) '%s' - %v", key, err) + e.log.LogAttrs(context.TODO(), slog.LevelDebug, "Queue (Add)", + slog.String("key", key), + ) if err != nil { - e.log.Errorf("while calling MetaNamespaceKeyFunc(): %s", err) + e.log.LogAttrs(context.TODO(), slog.LevelError, "while calling MetaNamespaceKeyFunc()", + ErrAttr(err), + ) return } updateFunc() }, UpdateFunc: func(obj, new interface{}) { key, err := cache.MetaNamespaceKeyFunc(obj) - e.log.Debugf("Queue (Update) '%s' - %v", key, err) + e.log.LogAttrs(context.TODO(), slog.LevelDebug, "Queue (Update)", + slog.String("key", key), + ) if err != nil { - e.log.Errorf("while calling MetaNamespaceKeyFunc(): %s", err) + e.log.LogAttrs(context.TODO(), slog.LevelError, "while calling MetaNamespaceKeyFunc()", + ErrAttr(err), + ) return } updateFunc() }, DeleteFunc: func(obj interface{}) { key, err := cache.MetaNamespaceKeyFunc(obj) - e.log.Debugf("Queue (Delete) '%s' - %v", key, err) + e.log.LogAttrs(context.TODO(), slog.LevelDebug, "Queue (Delete)", + slog.String("key", key), + ) if err != nil { - e.log.Errorf("while calling MetaNamespaceKeyFunc(): %s", err) + e.log.LogAttrs(context.TODO(), slog.LevelError, "while calling MetaNamespaceKeyFunc()", + ErrAttr(err), + ) return } updateFunc() @@ -192,15 +204,19 @@ main: for _, obj := range e.informer.GetStore().List() { pod, ok := obj.(*api_v1.Pod) if !ok { - e.log.Errorf("expected type v1.Endpoints got '%s' instead", reflect.TypeOf(obj).String()) + e.log.LogAttrs(context.TODO(), slog.LevelError, "expected type v1.Endpoints", + slog.String("got_type", reflect.TypeOf(obj).String()), + ) } - peer := PeerInfo{GRPCAddress: fmt.Sprintf("%s:%s", pod.Status.PodIP, e.conf.PodPort)} + peer := PeerInfo{HTTPAddress: fmt.Sprintf("%s:%s", pod.Status.PodIP, e.conf.PodPort)} // if containers are not ready or not running then skip this peer for _, status := range pod.Status.ContainerStatuses { if !status.Ready || 
status.State.Running == nil { - e.log.Debugf("Skipping peer because it's not ready or not running: %+v\n", peer) + e.log.LogAttrs(context.TODO(), slog.LevelDebug, "Skipping peer because it's not ready or not running", + slog.Any("peer", peer), + ) continue main } } @@ -208,7 +224,9 @@ main: if pod.Status.PodIP == e.conf.PodIP { peer.IsOwner = true } - e.log.Debugf("Peer: %+v\n", peer) + e.log.LogAttrs(context.TODO(), slog.LevelDebug, "Peer", + slog.Any("peer", peer), + ) peers = append(peers, peer) } e.conf.OnUpdate(peers) @@ -220,7 +238,9 @@ func (e *K8sPool) updatePeersFromEndpoints() { for _, obj := range e.informer.GetStore().List() { endpoint, ok := obj.(*api_v1.Endpoints) if !ok { - e.log.Errorf("expected type v1.Endpoints got '%s' instead", reflect.TypeOf(obj).String()) + e.log.LogAttrs(context.TODO(), slog.LevelError, "expected type v1.Endpoints", + slog.String("got_type", reflect.TypeOf(obj).String()), + ) } for _, s := range endpoint.Subsets { @@ -228,13 +248,15 @@ func (e *K8sPool) updatePeersFromEndpoints() { // TODO(thrawn01): Might consider using the `namespace` as the `DataCenter`. We should // do what ever k8s convention is for identifying a k8s cluster within a federated multi-data // center setup. - peer := PeerInfo{GRPCAddress: fmt.Sprintf("%s:%s", addr.IP, e.conf.PodPort)} + peer := PeerInfo{HTTPAddress: fmt.Sprintf("%s:%s", addr.IP, e.conf.PodPort)} if addr.IP == e.conf.PodIP { peer.IsOwner = true } peers = append(peers, peer) - e.log.Debugf("Peer: %+v\n", peer) + e.log.LogAttrs(context.TODO(), slog.LevelDebug, "Peer", + slog.Any("peer", peer), + ) } } } diff --git a/log.go b/log.go index be044a5..5e5bfd5 100644 --- a/log.go +++ b/log.go @@ -1,34 +1,72 @@ package gubernator import ( + "bufio" "context" - "time" - - "github.com/sirupsen/logrus" + "io" + "log/slog" + "sync" ) -// The FieldLogger interface generalizes the Entry and Logger types type FieldLogger interface { - WithField(key string, value interface{}) *logrus.Entry - WithFields(fields logrus.Fields) *logrus.Entry - WithError(err error) *logrus.Entry - WithContext(ctx context.Context) *logrus.Entry - WithTime(t time.Time) *logrus.Entry - - Tracef(format string, args ...interface{}) - Debugf(format string, args ...interface{}) - Infof(format string, args ...interface{}) - Printf(format string, args ...interface{}) - Warnf(format string, args ...interface{}) - Warningf(format string, args ...interface{}) - Errorf(format string, args ...interface{}) - Fatalf(format string, args ...interface{}) - - Log(level logrus.Level, args ...interface{}) - Debug(args ...interface{}) - Info(args ...interface{}) - Print(args ...interface{}) - Warn(args ...interface{}) - Warning(args ...interface{}) - Error(args ...interface{}) + Handler() slog.Handler + With(args ...any) *slog.Logger + WithGroup(name string) *slog.Logger + Enabled(ctx context.Context, level slog.Level) bool + Log(ctx context.Context, level slog.Level, msg string, args ...any) + LogAttrs(ctx context.Context, level slog.Level, msg string, attrs ...slog.Attr) + Debug(msg string, args ...any) + DebugContext(ctx context.Context, msg string, args ...any) + Info(msg string, args ...any) + InfoContext(ctx context.Context, msg string, args ...any) + Warn(msg string, args ...any) + WarnContext(ctx context.Context, msg string, args ...any) + Error(msg string, args ...any) + ErrorContext(ctx context.Context, msg string, args ...any) +} + +func ErrAttr(err error) slog.Attr { + return slog.Any("error", err) +} + +type logAdaptor struct { + writer *io.PipeWriter + closer 
func() +} + +func (l *logAdaptor) Write(p []byte) (n int, err error) { + return l.writer.Write(p) +} + +func (l *logAdaptor) Close() error { + l.closer() + return nil +} + +func newLogAdaptor(log FieldLogger) *logAdaptor { + reader, writer := io.Pipe() + var wg sync.WaitGroup + wg.Add(1) + + go func() { + scanner := bufio.NewScanner(reader) + for scanner.Scan() { + log.Info(scanner.Text()) + } + if err := scanner.Err(); err != nil { + log.LogAttrs(context.TODO(), slog.LevelError, "Error while reading from Writer", + ErrAttr(err), + ) + } + _ = reader.Close() + wg.Done() + }() + + return &logAdaptor{ + writer: writer, + closer: func() { + _ = writer.Close() + wg.Wait() + }, + } } diff --git a/logging.go b/logging.go deleted file mode 100644 index 968b9d2..0000000 --- a/logging.go +++ /dev/null @@ -1,55 +0,0 @@ -/* -Copyright 2018-2023 Mailgun Technologies Inc - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package gubernator - -import ( - "encoding/json" - - "github.com/pkg/errors" - "github.com/sirupsen/logrus" -) - -type LogLevelJSON struct { - Level logrus.Level -} - -func (ll LogLevelJSON) MarshalJSON() ([]byte, error) { - return json.Marshal(ll.String()) -} - -func (ll *LogLevelJSON) UnmarshalJSON(b []byte) error { - var v interface{} - var err error - - if err = json.Unmarshal(b, &v); err != nil { - return err - } - - switch value := v.(type) { - case float64: - ll.Level = logrus.Level(int32(value)) - case string: - ll.Level, err = logrus.ParseLevel(value) - default: - return errors.New("invalid log level") - } - return err -} - -func (ll LogLevelJSON) String() string { - return ll.Level.String() -} diff --git a/lru_cache.go b/lru_cache.go new file mode 100644 index 0000000..c6a14cf --- /dev/null +++ b/lru_cache.go @@ -0,0 +1,158 @@ +/* +Modifications Copyright 2024 Derrick Wippler + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +This work is derived from github.com/golang/groupcache/lru +*/ + +package gubernator + +import ( + "container/list" + "sync" + "sync/atomic" + + "github.com/kapetan-io/tackle/set" +) + +// LRUCache is a mutex protected LRU cache that supports expiration and is thread-safe +type LRUCache struct { + cache map[string]*list.Element + ll *list.List + mu sync.Mutex + stats CacheStats + cacheSize int + cacheLen int64 +} + +var _ Cache = &LRUCache{} + +// NewLRUCache creates a new Cache with a maximum size. 
+func NewLRUCache(maxSize int) *LRUCache {
+	set.Default(&maxSize, 50_000)
+
+	return &LRUCache{
+		cache:     make(map[string]*list.Element),
+		ll:        list.New(),
+		cacheSize: maxSize,
+	}
+}
+
+// Each maintains a goroutine that iterates over every item in the cache.
+// Other goroutines operating on this cache will block until all items
+// are read from the returned channel.
+func (c *LRUCache) Each() chan *CacheItem {
+	out := make(chan *CacheItem)
+	go func() {
+		c.mu.Lock()
+		defer c.mu.Unlock()
+
+		for _, ele := range c.cache {
+			out <- ele.Value.(*CacheItem)
+		}
+		close(out)
+	}()
+	return out
+}
+
+// AddIfNotPresent adds the item to the cache if it doesn't already exist.
+// Returns true if the item was added, false if the item already exists.
+func (c *LRUCache) AddIfNotPresent(item *CacheItem) bool {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	// If the key already exists, do nothing
+	if _, ok := c.cache[item.Key]; ok {
+		return false
+	}
+
+	ele := c.ll.PushFront(item)
+	c.cache[item.Key] = ele
+	if c.cacheSize != 0 && c.ll.Len() > c.cacheSize {
+		c.removeOldest()
+	}
+	atomic.StoreInt64(&c.cacheLen, int64(c.ll.Len()))
+	return true
+}
+
+// GetItem returns the item stored in the cache
+func (c *LRUCache) GetItem(key string) (item *CacheItem, ok bool) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	if ele, hit := c.cache[key]; hit {
+		entry := ele.Value.(*CacheItem)
+
+		c.stats.Hit++
+		c.ll.MoveToFront(ele)
+		return entry, true
+	}
+
+	c.stats.Miss++
+	return
+}
+
+// Remove removes the provided key from the cache.
+func (c *LRUCache) Remove(key string) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if ele, hit := c.cache[key]; hit {
+		c.removeElement(ele)
+	}
+}
+
+// removeOldest removes the oldest item from the cache.
+func (c *LRUCache) removeOldest() {
+	ele := c.ll.Back()
+	if ele != nil {
+		entry := ele.Value.(*CacheItem)
+
+		if MillisecondNow() < entry.ExpireAt {
+			c.stats.UnexpiredEvictions++
+		}
+
+		c.removeElement(ele)
+	}
+}
+
+func (c *LRUCache) removeElement(e *list.Element) {
+	c.ll.Remove(e)
+	kv := e.Value.(*CacheItem)
+	delete(c.cache, kv.Key)
+	atomic.StoreInt64(&c.cacheLen, int64(c.ll.Len()))
+}
+
+// Size returns the number of items in the cache.
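+// Size does not take the mutex; it reads a counter kept up to date atomically
+// by AddIfNotPresent and removeElement, so it is safe to call concurrently.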
+func (c *LRUCache) Size() int64 { + return atomic.LoadInt64(&c.cacheLen) +} + +func (c *LRUCache) Close() error { + c.cache = nil + c.ll = nil + c.cacheLen = 0 + return nil +} + +// Stats returns the current status for the cache +func (c *LRUCache) Stats() CacheStats { + c.mu.Lock() + defer func() { + c.stats = CacheStats{} + c.mu.Unlock() + }() + + c.stats.Size = atomic.LoadInt64(&c.cacheLen) + return c.stats +} diff --git a/lrucache_test.go b/lru_cache_test.go similarity index 86% rename from lrucache_test.go rename to lru_cache_test.go index 51f33bc..11ad30e 100644 --- a/lrucache_test.go +++ b/lru_cache_test.go @@ -24,20 +24,18 @@ import ( "testing" "time" - "github.com/gubernator-io/gubernator/v2" - "github.com/mailgun/holster/v4/clock" + "github.com/gubernator-io/gubernator/v3" + "github.com/kapetan-io/tackle/clock" "github.com/prometheus/client_golang/prometheus" + dto "github.com/prometheus/client_model/go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - dto "github.com/prometheus/client_model/go" ) -func TestLRUCache(t *testing.T) { +func TestLRUMutexCache(t *testing.T) { const iterations = 1000 const concurrency = 100 expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() - var mutex sync.Mutex t.Run("Happy path", func(t *testing.T) { cache := gubernator.NewLRUCache(0) @@ -50,10 +48,7 @@ func TestLRUCache(t *testing.T) { Value: i, ExpireAt: expireAt, } - mutex.Lock() - exists := cache.Add(item) - mutex.Unlock() - assert.False(t, exists) + assert.True(t, cache.AddIfNotPresent(item)) } // Validate cache. @@ -61,9 +56,7 @@ func TestLRUCache(t *testing.T) { for i := 0; i < iterations; i++ { key := strconv.Itoa(i) - mutex.Lock() item, ok := cache.GetItem(key) - mutex.Unlock() require.True(t, ok) require.NotNil(t, item) assert.Equal(t, item.Value, i) @@ -72,9 +65,7 @@ func TestLRUCache(t *testing.T) { // Clear cache. for i := 0; i < iterations; i++ { key := strconv.Itoa(i) - mutex.Lock() cache.Remove(key) - mutex.Unlock() } assert.Zero(t, cache.Size()) @@ -90,8 +81,7 @@ func TestLRUCache(t *testing.T) { Value: "initial value", ExpireAt: expireAt, } - exists1 := cache.Add(item1) - require.False(t, exists1) + require.True(t, cache.AddIfNotPresent(item1)) // Update same key. item2 := &gubernator.CacheItem{ @@ -99,8 +89,11 @@ func TestLRUCache(t *testing.T) { Value: "new value", ExpireAt: expireAt, } - exists2 := cache.Add(item2) - require.True(t, exists2) + require.False(t, cache.AddIfNotPresent(item2)) + + updateItem, ok := cache.GetItem(item1.Key) + require.True(t, ok) + updateItem.Value = "new value" // Verify. 
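+		// GetItem hands back the stored *CacheItem, so the in-place update above
+		// is what the read below should observe.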
verifyItem, ok := cache.GetItem(key) @@ -119,8 +112,7 @@ func TestLRUCache(t *testing.T) { Value: i, ExpireAt: expireAt, } - exists := cache.Add(item) - assert.False(t, exists) + assert.True(t, cache.AddIfNotPresent(item)) } assert.Equal(t, int64(iterations), cache.Size()) @@ -136,9 +128,7 @@ func TestLRUCache(t *testing.T) { for i := 0; i < iterations; i++ { key := strconv.Itoa(i) - mutex.Lock() item, ok := cache.GetItem(key) - mutex.Unlock() assert.True(t, ok) require.NotNil(t, item) assert.Equal(t, item.Value, i) @@ -171,9 +161,7 @@ func TestLRUCache(t *testing.T) { Value: i, ExpireAt: expireAt, } - mutex.Lock() - cache.Add(item) - mutex.Unlock() + cache.AddIfNotPresent(item) } }() } @@ -194,10 +182,7 @@ func TestLRUCache(t *testing.T) { Value: i, ExpireAt: expireAt, } - mutex.Lock() - exists := cache.Add(item) - mutex.Unlock() - assert.False(t, exists) + assert.True(t, cache.AddIfNotPresent(item)) } assert.Equal(t, int64(iterations), cache.Size()) @@ -213,9 +198,7 @@ func TestLRUCache(t *testing.T) { for i := 0; i < iterations; i++ { key := strconv.Itoa(i) - mutex.Lock() item, ok := cache.GetItem(key) - mutex.Unlock() assert.True(t, ok) require.NotNil(t, item) assert.Equal(t, item.Value, i) @@ -233,9 +216,7 @@ func TestLRUCache(t *testing.T) { Value: i, ExpireAt: expireAt, } - mutex.Lock() - cache.Add(item) - mutex.Unlock() + cache.AddIfNotPresent(item) } }() } @@ -256,9 +237,7 @@ func TestLRUCache(t *testing.T) { Value: i, ExpireAt: expireAt, } - mutex.Lock() - cache.Add(item) - mutex.Unlock() + cache.AddIfNotPresent(item) } assert.Equal(t, int64(iterations), cache.Size()) @@ -275,15 +254,11 @@ func TestLRUCache(t *testing.T) { for i := 0; i < iterations; i++ { // Get, cache hit. key := strconv.Itoa(i) - mutex.Lock() _, _ = cache.GetItem(key) - mutex.Unlock() // Get, cache miss. key2 := strconv.Itoa(rand.Intn(1000) + 10000) - mutex.Lock() _, _ = cache.GetItem(key2) - mutex.Unlock() } }() @@ -299,9 +274,7 @@ func TestLRUCache(t *testing.T) { Value: i, ExpireAt: expireAt, } - mutex.Lock() - cache.Add(item) - mutex.Unlock() + cache.AddIfNotPresent(item) // Add new. key2 := strconv.Itoa(rand.Intn(1000) + 20000) @@ -310,13 +283,11 @@ func TestLRUCache(t *testing.T) { Value: i, ExpireAt: expireAt, } - mutex.Lock() - cache.Add(item2) - mutex.Unlock() + cache.AddIfNotPresent(item2) } }() - collector := gubernator.NewLRUCacheCollector() + collector := gubernator.NewCacheCollector() collector.AddCache(cache) go func() { @@ -337,12 +308,12 @@ func TestLRUCache(t *testing.T) { }) t.Run("Check gubernator_unexpired_evictions_count metric is not incremented when expired item is evicted", func(t *testing.T) { - defer clock.Freeze(clock.Now()).Unfreeze() + defer clock.Freeze(clock.Now()).UnFreeze() promRegister := prometheus.NewRegistry() // The LRU cache for storing rate limits. 
- cacheCollector := gubernator.NewLRUCacheCollector() + cacheCollector := gubernator.NewCacheCollector() err := promRegister.Register(cacheCollector) require.NoError(t, err) @@ -351,7 +322,7 @@ func TestLRUCache(t *testing.T) { // fill cache with short duration cache items for i := 0; i < 10; i++ { - cache.Add(&gubernator.CacheItem{ + cache.AddIfNotPresent(&gubernator.CacheItem{ Algorithm: gubernator.Algorithm_LEAKY_BUCKET, Key: fmt.Sprintf("short-expiry-%d", i), Value: "bar", @@ -363,7 +334,7 @@ func TestLRUCache(t *testing.T) { clock.Advance(6 * time.Minute) // add a new cache item to force eviction - cache.Add(&gubernator.CacheItem{ + cache.AddIfNotPresent(&gubernator.CacheItem{ Algorithm: gubernator.Algorithm_LEAKY_BUCKET, Key: "evict1", Value: "bar", @@ -384,12 +355,12 @@ func TestLRUCache(t *testing.T) { }) t.Run("Check gubernator_unexpired_evictions_count metric is incremented when unexpired item is evicted", func(t *testing.T) { - defer clock.Freeze(clock.Now()).Unfreeze() + defer clock.Freeze(clock.Now()).UnFreeze() promRegister := prometheus.NewRegistry() // The LRU cache for storing rate limits. - cacheCollector := gubernator.NewLRUCacheCollector() + cacheCollector := gubernator.NewCacheCollector() err := promRegister.Register(cacheCollector) require.NoError(t, err) @@ -398,7 +369,7 @@ func TestLRUCache(t *testing.T) { // fill cache with long duration cache items for i := 0; i < 10; i++ { - cache.Add(&gubernator.CacheItem{ + cache.AddIfNotPresent(&gubernator.CacheItem{ Algorithm: gubernator.Algorithm_LEAKY_BUCKET, Key: fmt.Sprintf("long-expiry-%d", i), Value: "bar", @@ -407,7 +378,7 @@ func TestLRUCache(t *testing.T) { } // add a new cache item to force eviction - cache.Add(&gubernator.CacheItem{ + cache.AddIfNotPresent(&gubernator.CacheItem{ Algorithm: gubernator.Algorithm_LEAKY_BUCKET, Key: "evict2", Value: "bar", @@ -428,8 +399,7 @@ func TestLRUCache(t *testing.T) { }) } -func BenchmarkLRUCache(b *testing.B) { - var mutex sync.Mutex +func BenchmarkLRUMutexCache(b *testing.B) { b.Run("Sequential reads", func(b *testing.B) { cache := gubernator.NewLRUCache(b.N) @@ -443,8 +413,7 @@ func BenchmarkLRUCache(b *testing.B) { Value: i, ExpireAt: expireAt, } - exists := cache.Add(item) - assert.False(b, exists) + assert.True(b, cache.AddIfNotPresent(item)) } b.ReportAllocs() @@ -452,9 +421,7 @@ func BenchmarkLRUCache(b *testing.B) { for i := 0; i < b.N; i++ { key := strconv.Itoa(i) - mutex.Lock() _, _ = cache.GetItem(key) - mutex.Unlock() } }) @@ -472,9 +439,7 @@ func BenchmarkLRUCache(b *testing.B) { Value: i, ExpireAt: expireAt, } - mutex.Lock() - cache.Add(item) - mutex.Unlock() + cache.AddIfNotPresent(item) } }) @@ -490,8 +455,7 @@ func BenchmarkLRUCache(b *testing.B) { Value: i, ExpireAt: expireAt, } - exists := cache.Add(item) - assert.False(b, exists) + assert.True(b, cache.AddIfNotPresent(item)) } var launchWg, doneWg sync.WaitGroup @@ -505,9 +469,7 @@ func BenchmarkLRUCache(b *testing.B) { defer doneWg.Done() launchWg.Wait() - mutex.Lock() _, _ = cache.GetItem(key) - mutex.Unlock() }() } @@ -536,9 +498,7 @@ func BenchmarkLRUCache(b *testing.B) { Value: i, ExpireAt: expireAt, } - mutex.Lock() - cache.Add(item) - mutex.Unlock() + cache.AddIfNotPresent(item) }(i) } @@ -562,8 +522,7 @@ func BenchmarkLRUCache(b *testing.B) { Value: i, ExpireAt: expireAt, } - exists := cache.Add(item) - assert.False(b, exists) + assert.True(b, cache.AddIfNotPresent(item)) } for i := 0; i < b.N; i++ { @@ -574,9 +533,7 @@ func BenchmarkLRUCache(b *testing.B) { defer doneWg.Done() launchWg.Wait() - 
mutex.Lock() _, _ = cache.GetItem(key) - mutex.Unlock() }() go func(i int) { @@ -588,9 +545,7 @@ func BenchmarkLRUCache(b *testing.B) { Value: i, ExpireAt: expireAt, } - mutex.Lock() - cache.Add(item) - mutex.Unlock() + cache.AddIfNotPresent(item) }(i) } @@ -614,9 +569,7 @@ func BenchmarkLRUCache(b *testing.B) { launchWg.Wait() key := strconv.Itoa(i) - mutex.Lock() _, _ = cache.GetItem(key) - mutex.Unlock() }(i) go func(i int) { @@ -629,9 +582,7 @@ func BenchmarkLRUCache(b *testing.B) { Value: i, ExpireAt: expireAt, } - mutex.Lock() - cache.Add(item) - mutex.Unlock() + _ = cache.AddIfNotPresent(item) }(i) } diff --git a/lrucache.go b/lrucache.go deleted file mode 100644 index 0386720..0000000 --- a/lrucache.go +++ /dev/null @@ -1,214 +0,0 @@ -/* -Modifications Copyright 2018-2022 Mailgun Technologies Inc - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. - -This work is derived from github.com/golang/groupcache/lru -*/ - -package gubernator - -import ( - "container/list" - "sync/atomic" - - "github.com/mailgun/holster/v4/clock" - "github.com/mailgun/holster/v4/setter" - "github.com/prometheus/client_golang/prometheus" -) - -// LRUCache is an LRU cache that supports expiration and is not thread-safe -// Be sure to use a mutex to prevent concurrent method calls. -type LRUCache struct { - cache map[string]*list.Element - ll *list.List - cacheSize int - cacheLen int64 -} - -// LRUCacheCollector provides prometheus metrics collector for LRUCache. -// Register only one collector, add one or more caches to this collector. -type LRUCacheCollector struct { - caches []Cache -} - -var _ Cache = &LRUCache{} -var _ prometheus.Collector = &LRUCacheCollector{} - -var metricCacheSize = prometheus.NewGauge(prometheus.GaugeOpts{ - Name: "gubernator_cache_size", - Help: "The number of items in LRU Cache which holds the rate limits.", -}) -var metricCacheAccess = prometheus.NewCounterVec(prometheus.CounterOpts{ - Name: "gubernator_cache_access_count", - Help: "Cache access counts. Label \"type\" = hit|miss.", -}, []string{"type"}) -var metricCacheUnexpiredEvictions = prometheus.NewCounter(prometheus.CounterOpts{ - Name: "gubernator_unexpired_evictions_count", - Help: "Count the number of cache items which were evicted while unexpired.", -}) - -// NewLRUCache creates a new Cache with a maximum size. -func NewLRUCache(maxSize int) *LRUCache { - setter.SetDefault(&maxSize, 50_000) - - return &LRUCache{ - cache: make(map[string]*list.Element), - ll: list.New(), - cacheSize: maxSize, - } -} - -// Each is not thread-safe. Each() maintains a goroutine that iterates. -// Other go routines cannot safely access the Cache while iterating. -// It would be safer if this were done using an iterator or delegate pattern -// that doesn't require a goroutine. May need to reassess functional requirements. -func (c *LRUCache) Each() chan *CacheItem { - out := make(chan *CacheItem) - go func() { - for _, ele := range c.cache { - out <- ele.Value.(*CacheItem) - } - close(out) - }() - return out -} - -// Add adds a value to the cache. 
-func (c *LRUCache) Add(item *CacheItem) bool { - // If the key already exist, set the new value - if ee, ok := c.cache[item.Key]; ok { - c.ll.MoveToFront(ee) - ee.Value = item - return true - } - - ele := c.ll.PushFront(item) - c.cache[item.Key] = ele - if c.cacheSize != 0 && c.ll.Len() > c.cacheSize { - c.removeOldest() - } - atomic.StoreInt64(&c.cacheLen, int64(c.ll.Len())) - return false -} - -// MillisecondNow returns unix epoch in milliseconds -func MillisecondNow() int64 { - return clock.Now().UnixNano() / 1000000 -} - -// GetItem returns the item stored in the cache -func (c *LRUCache) GetItem(key string) (item *CacheItem, ok bool) { - if ele, hit := c.cache[key]; hit { - entry := ele.Value.(*CacheItem) - - if entry.IsExpired() { - c.removeElement(ele) - metricCacheAccess.WithLabelValues("miss").Add(1) - return - } - - metricCacheAccess.WithLabelValues("hit").Add(1) - c.ll.MoveToFront(ele) - return entry, true - } - - metricCacheAccess.WithLabelValues("miss").Add(1) - return -} - -// Remove removes the provided key from the cache. -func (c *LRUCache) Remove(key string) { - if ele, hit := c.cache[key]; hit { - c.removeElement(ele) - } -} - -// RemoveOldest removes the oldest item from the cache. -func (c *LRUCache) removeOldest() { - ele := c.ll.Back() - if ele != nil { - entry := ele.Value.(*CacheItem) - - if MillisecondNow() < entry.ExpireAt { - metricCacheUnexpiredEvictions.Add(1) - } - - c.removeElement(ele) - } -} - -func (c *LRUCache) removeElement(e *list.Element) { - c.ll.Remove(e) - kv := e.Value.(*CacheItem) - delete(c.cache, kv.Key) - atomic.StoreInt64(&c.cacheLen, int64(c.ll.Len())) -} - -// Size returns the number of items in the cache. -func (c *LRUCache) Size() int64 { - return atomic.LoadInt64(&c.cacheLen) -} - -// UpdateExpiration updates the expiration time for the key -func (c *LRUCache) UpdateExpiration(key string, expireAt int64) bool { - if ele, hit := c.cache[key]; hit { - entry := ele.Value.(*CacheItem) - entry.ExpireAt = expireAt - return true - } - return false -} - -func (c *LRUCache) Close() error { - c.cache = nil - c.ll = nil - c.cacheLen = 0 - return nil -} - -func NewLRUCacheCollector() *LRUCacheCollector { - return &LRUCacheCollector{ - caches: []Cache{}, - } -} - -// AddCache adds a Cache object to be tracked by the collector. -func (collector *LRUCacheCollector) AddCache(cache Cache) { - collector.caches = append(collector.caches, cache) -} - -// Describe fetches prometheus metrics to be registered -func (collector *LRUCacheCollector) Describe(ch chan<- *prometheus.Desc) { - metricCacheSize.Describe(ch) - metricCacheAccess.Describe(ch) - metricCacheUnexpiredEvictions.Describe(ch) -} - -// Collect fetches metric counts and gauges from the cache -func (collector *LRUCacheCollector) Collect(ch chan<- prometheus.Metric) { - metricCacheSize.Set(collector.getSize()) - metricCacheSize.Collect(ch) - metricCacheAccess.Collect(ch) - metricCacheUnexpiredEvictions.Collect(ch) -} - -func (collector *LRUCacheCollector) getSize() float64 { - var size float64 - - for _, cache := range collector.caches { - size += float64(cache.Size()) - } - - return size -} diff --git a/memberlist.go b/memberlist.go index 0d67dfd..a5a7ccb 100644 --- a/memberlist.go +++ b/memberlist.go @@ -17,27 +17,25 @@ limitations under the License. 
package gubernator import ( - "bufio" "bytes" "context" "encoding/base64" "encoding/gob" "encoding/json" - "io" + "log/slog" "net" - "runtime" "strconv" "strings" ml "github.com/hashicorp/memberlist" - "github.com/mailgun/holster/v4/clock" + "github.com/kapetan-io/tackle/clock" + "github.com/kapetan-io/tackle/set" "github.com/mailgun/holster/v4/retry" - "github.com/mailgun/holster/v4/setter" "github.com/pkg/errors" - "github.com/sirupsen/logrus" ) type MemberListPool struct { + logAdaptor *logAdaptor log FieldLogger memberList *ml.Memberlist conf MemberListPoolConfig @@ -87,7 +85,7 @@ type MemberListEncryptionConfig struct { } func NewMemberListPool(ctx context.Context, conf MemberListPoolConfig) (*MemberListPool, error) { - setter.SetDefault(conf.Logger, logrus.WithField("category", "gubernator")) + set.Default(conf.Logger, slog.Default().With("category", "gubernator")) m := &MemberListPool{ log: conf.Logger, conf: conf, @@ -143,8 +141,8 @@ func NewMemberListPool(ctx context.Context, conf MemberListPoolConfig) (*MemberL if conf.NodeName != "" { config.Name = conf.NodeName } - - config.LogOutput = newLogWriter(m.log) + m.logAdaptor = newLogAdaptor(m.log) + config.LogOutput = m.logAdaptor // Create and set member list memberList, err := ml.Create(config) @@ -195,8 +193,11 @@ func (m *MemberListPool) joinPool(ctx context.Context, conf MemberListPoolConfig func (m *MemberListPool) Close() { err := m.memberList.Leave(clock.Second) if err != nil { - m.log.Warn(errors.Wrap(err, "while leaving member-list")) + m.log.LogAttrs(context.TODO(), slog.LevelWarn, "while leaving member-list", + ErrAttr(err), + ) } + _ = m.logAdaptor.Close() } type memberListEventHandler struct { @@ -219,7 +220,9 @@ func (e *memberListEventHandler) addPeer(node *ml.Node) { peer, err := unmarshallPeer(node.Meta, ip) if err != nil { - e.log.WithError(err).Warnf("while adding to peers") + e.log.LogAttrs(context.TODO(), slog.LevelWarn, "while adding to peers", + ErrAttr(err), + ) } else { e.peers[ip] = peer e.callOnUpdate() @@ -233,7 +236,9 @@ func (e *memberListEventHandler) NotifyJoin(node *ml.Node) { if err != nil { // This is called during member list initialization due to the fact that the local node // has no metadata yet - e.log.WithError(err).Warn("while deserialize member-list peer") + e.log.LogAttrs(context.TODO(), slog.LevelWarn, "while deserialize member-list peer", + ErrAttr(err), + ) return } peer.IsOwner = false @@ -255,7 +260,9 @@ func (e *memberListEventHandler) NotifyUpdate(node *ml.Node) { peer, err := unmarshallPeer(node.Meta, ip) if err != nil { - e.log.WithError(err).Warn("while unmarshalling peer info") + e.log.LogAttrs(context.TODO(), slog.LevelError, "while unmarshalling peer info", + ErrAttr(err), + ) } peer.IsOwner = false e.peers[ip] = peer @@ -266,7 +273,7 @@ func (e *memberListEventHandler) callOnUpdate() { var peers []PeerInfo for _, p := range e.peers { - if p.GRPCAddress == e.conf.Advertise.GRPCAddress { + if p.HTTPAddress == e.conf.Advertise.HTTPAddress { p.IsOwner = true } peers = append(peers, p) @@ -302,31 +309,11 @@ func unmarshallPeer(b []byte, ip string) (PeerInfo, error) { if metadata.AdvertiseAddress == "" { metadata.AdvertiseAddress = makeAddress(ip, metadata.GubernatorPort) } - return PeerInfo{GRPCAddress: metadata.AdvertiseAddress, DataCenter: metadata.DataCenter}, nil + return PeerInfo{HTTPAddress: metadata.AdvertiseAddress, DataCenter: metadata.DataCenter}, nil } return peer, nil } -func newLogWriter(log FieldLogger) *io.PipeWriter { - reader, writer := io.Pipe() - - go func() { 
- scanner := bufio.NewScanner(reader) - for scanner.Scan() { - log.Info(scanner.Text()) - } - if err := scanner.Err(); err != nil { - log.Errorf("Error while reading from Writer: %s", err) - } - reader.Close() - }() - runtime.SetFinalizer(writer, func(w *io.PipeWriter) { - writer.Close() - }) - - return writer -} - func splitAddress(addr string) (string, int, error) { host, port, err := net.SplitHostPort(addr) if err != nil { diff --git a/metadata_carrier_test.go b/metadata_carrier_test.go index 618a288..1b27191 100644 --- a/metadata_carrier_test.go +++ b/metadata_carrier_test.go @@ -19,7 +19,7 @@ package gubernator_test import ( "testing" - "github.com/gubernator-io/gubernator/v2" + "github.com/gubernator-io/gubernator/v3" "github.com/stretchr/testify/assert" ) diff --git a/mock_cache_test.go b/mock_cache_test.go index 3eea640..a2d8b5e 100644 --- a/mock_cache_test.go +++ b/mock_cache_test.go @@ -19,7 +19,7 @@ package gubernator_test // Mock implementation of Cache. import ( - guber "github.com/gubernator-io/gubernator/v2" + guber "github.com/gubernator-io/gubernator/v3" "github.com/stretchr/testify/mock" ) @@ -29,16 +29,11 @@ type MockCache struct { var _ guber.Cache = &MockCache{} -func (m *MockCache) Add(item *guber.CacheItem) bool { +func (m *MockCache) AddIfNotPresent(item *guber.CacheItem) bool { args := m.Called(item) return args.Bool(0) } -func (m *MockCache) UpdateExpiration(key string, expireAt int64) bool { - args := m.Called(key, expireAt) - return args.Bool(0) -} - func (m *MockCache) GetItem(key string) (value *guber.CacheItem, ok bool) { args := m.Called(key) retval, _ := args.Get(0).(*guber.CacheItem) @@ -60,6 +55,10 @@ func (m *MockCache) Size() int64 { return int64(args.Int(0)) } +func (m *MockCache) Stats() guber.CacheStats { + return guber.CacheStats{} +} + func (m *MockCache) Close() error { args := m.Called() return args.Error(0) diff --git a/mock_loader_test.go b/mock_loader_test.go index 4c58e84..8b22aa0 100644 --- a/mock_loader_test.go +++ b/mock_loader_test.go @@ -19,7 +19,7 @@ package gubernator_test // Mock implementation of Loader. 
import ( - guber "github.com/gubernator-io/gubernator/v2" + guber "github.com/gubernator-io/gubernator/v3" "github.com/stretchr/testify/mock" ) diff --git a/mock_store_test.go b/mock_store_test.go index 8a2f356..72c8d3f 100644 --- a/mock_store_test.go +++ b/mock_store_test.go @@ -21,7 +21,7 @@ package gubernator_test import ( "context" - guber "github.com/gubernator-io/gubernator/v2" + guber "github.com/gubernator-io/gubernator/v3" "github.com/stretchr/testify/mock" ) @@ -31,11 +31,11 @@ type MockStore2 struct { var _ guber.Store = &MockStore2{} -func (m *MockStore2) OnChange(ctx context.Context, r *guber.RateLimitReq, item *guber.CacheItem) { +func (m *MockStore2) OnChange(ctx context.Context, r *guber.RateLimitRequest, item *guber.CacheItem) { m.Called(ctx, r, item) } -func (m *MockStore2) Get(ctx context.Context, r *guber.RateLimitReq) (*guber.CacheItem, bool) { +func (m *MockStore2) Get(ctx context.Context, r *guber.RateLimitRequest) (*guber.CacheItem, bool) { args := m.Called(ctx, r) retval, _ := args.Get(0).(*guber.CacheItem) return retval, args.Bool(1) diff --git a/net.go b/net.go index 39495a8..1bb9765 100644 --- a/net.go +++ b/net.go @@ -19,14 +19,14 @@ package gubernator import ( "net" "os" + "slices" - "github.com/mailgun/holster/v4/slice" "github.com/pkg/errors" ) // ResolveHostIP attempts to discover the actual ip address of the host if the passed address is "0.0.0.0" or "::" func ResolveHostIP(addr string) (string, error) { - if slice.ContainsString(addr, []string{"0.0.0.0", "::", "0:0:0:0:0:0:0:0", ""}, nil) { + if slices.Contains([]string{"0.0.0.0", "::", "0:0:0:0:0:0:0:0", ""}, addr) { // Use the hostname as the advertise address as it's most likely to be the external interface domainName, err := os.Hostname() if err != nil { diff --git a/otter.go b/otter.go new file mode 100644 index 0000000..dc61c01 --- /dev/null +++ b/otter.go @@ -0,0 +1,108 @@ +package gubernator + +import ( + "fmt" + "sync/atomic" + + "github.com/kapetan-io/tackle/set" + "github.com/maypok86/otter" +) + +type OtterCache struct { + cache otter.Cache[string, *CacheItem] + stats CacheStats +} + +// NewOtterCache returns a new cache backed by otter. If size is 0, then +// the cache is created with a default cache size. +func NewOtterCache(size int) (*OtterCache, error) { + // Default is 500k bytes in size + set.Default(&size, 500_000) + b, err := otter.NewBuilder[string, *CacheItem](size) + if err != nil { + return nil, fmt.Errorf("during otter.NewBuilder(): %w", err) + } + + o := &OtterCache{} + + b.DeletionListener(func(key string, value *CacheItem, cause otter.DeletionCause) { + if cause == otter.Size { + atomic.AddInt64(&o.stats.UnexpiredEvictions, 1) + } + }) + + b.Cost(func(key string, value *CacheItem) uint32 { + // The total size of the CacheItem and Bucket item is 104 bytes. + // See cache.go:CacheItem definition for details. + return uint32(104 + len(value.Key)) + }) + + o.cache, err = b.Build() + if err != nil { + return nil, fmt.Errorf("during otter.Builder.Build(): %w", err) + } + return o, nil +} + +// AddIfNotPresent adds a new CacheItem to the cache. The key must be provided via CacheItem.Key +// returns true if the item was added to the cache; false if the item was too large +// for the cache or if the key already exists in the cache. 
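+//
+// A rough usage sketch (the key and TTL below are made-up values; error
+// handling elided):
+//
+//	cache, _ := NewOtterCache(0) // 0 selects the default capacity
+//	ok := cache.AddIfNotPresent(&CacheItem{
+//		Key:      "ip:10.2.10.7",
+//		Value:    1,
+//		ExpireAt: MillisecondNow() + 1000,
+//	})
+//	_ = ok // false if the key already existed or the item could not fit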
+func (o *OtterCache) AddIfNotPresent(item *CacheItem) bool {
+	return o.cache.SetIfAbsent(item.Key, item)
+}
+
+// GetItem returns an item in the cache that corresponds to the provided key
+func (o *OtterCache) GetItem(key string) (*CacheItem, bool) {
+	item, ok := o.cache.Get(key)
+	if !ok {
+		atomic.AddInt64(&o.stats.Miss, 1)
+		return nil, false
+	}
+
+	atomic.AddInt64(&o.stats.Hit, 1)
+	return item, true
+}
+
+// Each returns a channel which the caller can use to iterate through
+// all the items in the cache.
+func (o *OtterCache) Each() chan *CacheItem {
+	ch := make(chan *CacheItem)
+
+	go func() {
+		o.cache.Range(func(_ string, v *CacheItem) bool {
+			ch <- v
+			return true
+		})
+		close(ch)
+	}()
+	return ch
+}
+
+// Remove explicitly removes an item from the cache.
+// NOTE: A deletion call to otter requires a mutex to perform;
+// if possible, avoid performing explicit removal from the cache.
+// Instead, prefer the item to be evicted naturally.
+func (o *OtterCache) Remove(key string) {
+	o.cache.Delete(key)
+}
+
+// Size returns the current number of items in the cache
+func (o *OtterCache) Size() int64 {
+	return int64(o.cache.Size())
+}
+
+// Stats returns the current cache stats and resets the values to zero
+func (o *OtterCache) Stats() CacheStats {
+	var result CacheStats
+	result.UnexpiredEvictions = atomic.SwapInt64(&o.stats.UnexpiredEvictions, 0)
+	result.Miss = atomic.SwapInt64(&o.stats.Miss, 0)
+	result.Hit = atomic.SwapInt64(&o.stats.Hit, 0)
+	result.Size = int64(o.cache.Size())
+	return result
+}
+
+// Close closes the cache and all associated background processes
+func (o *OtterCache) Close() error {
+	o.cache.Close()
+	return nil
+}
diff --git a/otter_test.go b/otter_test.go
new file mode 100644
index 0000000..7cd9aff
--- /dev/null
+++ b/otter_test.go
@@ -0,0 +1,410 @@
+package gubernator_test
+
+import (
+	"strconv"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/gubernator-io/gubernator/v3"
+	"github.com/kapetan-io/tackle/clock"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestOtterCache(t *testing.T) {
+	const iterations = 1000
+	const concurrency = 100
+	expireAt := clock.Now().Add(1 * time.Hour).UnixMilli()
+
+	t.Run("Happy path", func(t *testing.T) {
+		cache, err := gubernator.NewOtterCache(0)
+		require.NoError(t, err)
+
+		// Populate cache.
+		for i := 0; i < iterations; i++ {
+			key := strconv.Itoa(i)
+			item := &gubernator.CacheItem{
+				Key:      key,
+				Value:    i,
+				ExpireAt: expireAt,
+			}
+			cache.AddIfNotPresent(item)
+		}
+
+		// Validate cache.
+		assert.Equal(t, int64(iterations), cache.Size())
+
+		for i := 0; i < iterations; i++ {
+			key := strconv.Itoa(i)
+			item, ok := cache.GetItem(key)
+			require.True(t, ok)
+			require.NotNil(t, item)
+			assert.Equal(t, item.Value, i)
+		}
+
+		// Clear cache.
+		for i := 0; i < iterations; i++ {
+			key := strconv.Itoa(i)
+			cache.Remove(key)
+		}
+
+		assert.Zero(t, cache.Size())
+	})
+
+	t.Run("Update an existing key", func(t *testing.T) {
+		cache, err := gubernator.NewOtterCache(0)
+		require.NoError(t, err)
+		const key = "foobar"
+
+		// Add key.
+		item1 := &gubernator.CacheItem{
+			Key:      key,
+			Value:    "initial value",
+			ExpireAt: expireAt,
+		}
+		cache.AddIfNotPresent(item1)
+
+		// Updating the same key is refused
+		item2 := &gubernator.CacheItem{
+			Key:      key,
+			Value:    "new value",
+			ExpireAt: expireAt,
+		}
+		assert.False(t, cache.AddIfNotPresent(item2))
+
+		// Fetch and update the CacheItem
+		update, ok := cache.GetItem(key)
+		assert.True(t, ok)
+		update.Value = "new value"
+
+		// Verify.
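+		// AddIfNotPresent never replaces an existing entry; the new value was
+		// applied by fetching the stored item and mutating it in place.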
+ verifyItem, ok := cache.GetItem(key) + require.True(t, ok) + assert.Equal(t, item2, verifyItem) + }) + + t.Run("Concurrent reads", func(t *testing.T) { + cache, err := gubernator.NewOtterCache(0) + require.NoError(t, err) + + // Populate cache. + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + } + + assert.Equal(t, int64(iterations), cache.Size()) + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for thread := 0; thread < concurrency; thread++ { + doneWg.Add(1) + + go func() { + defer doneWg.Done() + launchWg.Wait() + + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item, ok := cache.GetItem(key) + assert.True(t, ok) + require.NotNil(t, item) + assert.Equal(t, item.Value, i) + } + }() + } + + // Wait for goroutines to finish. + launchWg.Done() + doneWg.Wait() + }) + + t.Run("Concurrent writes", func(t *testing.T) { + cache, err := gubernator.NewOtterCache(0) + require.NoError(t, err) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for thread := 0; thread < concurrency; thread++ { + doneWg.Add(1) + + go func() { + defer doneWg.Done() + launchWg.Wait() + + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + } + }() + } + + // Wait for goroutines to finish. + launchWg.Done() + doneWg.Wait() + }) + + t.Run("Concurrent reads and writes", func(t *testing.T) { + cache, err := gubernator.NewOtterCache(0) + require.NoError(t, err) + + // Populate cache. + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + } + + assert.Equal(t, int64(iterations), cache.Size()) + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for thread := 0; thread < concurrency; thread++ { + doneWg.Add(2) + + go func() { + defer doneWg.Done() + launchWg.Wait() + + for i := 0; i < iterations; i++ { + key := strconv.Itoa(i) + item, ok := cache.GetItem(key) + assert.True(t, ok) + require.NotNil(t, item) + assert.Equal(t, item.Value, i) + } + }() + + go func() { + defer doneWg.Done() + launchWg.Wait() + + // Write different keys than the keys we are reading to avoid race on Add() / GetItem() + for i := iterations; i < iterations*2; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + } + }() + } + + // Wait for goroutines to finish. + launchWg.Done() + doneWg.Wait() + }) +} + +func BenchmarkOtterCache(b *testing.B) { + + b.Run("Sequential reads", func(b *testing.B) { + cache, _ := gubernator.NewOtterCache(0) + + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + + // Populate cache. 
+ for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + _, _ = cache.GetItem(key) + } + }) + + b.Run("Sequential writes", func(b *testing.B) { + cache, _ := gubernator.NewOtterCache(0) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + } + }) + + b.Run("Concurrent reads", func(b *testing.B) { + cache, _ := gubernator.NewOtterCache(0) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + + // Populate cache. + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + } + + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + doneWg.Add(1) + + go func() { + defer doneWg.Done() + launchWg.Wait() + + _, _ = cache.GetItem(key) + }() + } + + b.ReportAllocs() + b.ResetTimer() + launchWg.Done() + doneWg.Wait() + }) + + b.Run("Concurrent writes", func(b *testing.B) { + cache, _ := gubernator.NewOtterCache(0) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + doneWg.Add(1) + + go func(i int) { + defer doneWg.Done() + launchWg.Wait() + + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + }(i) + } + + b.ReportAllocs() + b.ResetTimer() + launchWg.Done() + doneWg.Wait() + }) + + b.Run("Concurrent reads and writes of existing keys", func(b *testing.B) { + cache, _ := gubernator.NewOtterCache(0) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + // Populate cache. 
+ for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + } + + for i := 0; i < b.N; i++ { + key := strconv.Itoa(i) + doneWg.Add(2) + + go func() { + defer doneWg.Done() + launchWg.Wait() + + _, _ = cache.GetItem(key) + }() + + go func(i int) { + defer doneWg.Done() + launchWg.Wait() + + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + }(i) + } + + b.ReportAllocs() + b.ResetTimer() + launchWg.Done() + doneWg.Wait() + }) + + b.Run("Concurrent reads and writes of non-existent keys", func(b *testing.B) { + cache, _ := gubernator.NewOtterCache(0) + expireAt := clock.Now().Add(1 * time.Hour).UnixMilli() + var launchWg, doneWg sync.WaitGroup + launchWg.Add(1) + + for i := 0; i < b.N; i++ { + doneWg.Add(2) + + go func(i int) { + defer doneWg.Done() + launchWg.Wait() + + key := strconv.Itoa(i) + _, _ = cache.GetItem(key) + }(i) + + go func(i int) { + defer doneWg.Done() + launchWg.Wait() + + key := "z" + strconv.Itoa(i) + item := &gubernator.CacheItem{ + Key: key, + Value: i, + ExpireAt: expireAt, + } + cache.AddIfNotPresent(item) + }(i) + } + + b.ReportAllocs() + b.ResetTimer() + launchWg.Done() + doneWg.Wait() + }) +} diff --git a/peer.go b/peer.go new file mode 100644 index 0000000..1b0bab6 --- /dev/null +++ b/peer.go @@ -0,0 +1,411 @@ +/* +Copyright 2018-2023 Mailgun Technologies Inc + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package gubernator + +import ( + "context" + "fmt" + "log/slog" + "sync" + "sync/atomic" + + "github.com/gubernator-io/gubernator/v3/tracing" + "github.com/kapetan-io/tackle/clock" + "github.com/kapetan-io/tackle/set" + "github.com/mailgun/errors" + "github.com/mailgun/holster/v4/collections" + "github.com/prometheus/client_golang/prometheus" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/propagation" + "go.opentelemetry.io/otel/trace" +) + +type Peer struct { + lastErrs *collections.LRUCache + wg sync.WaitGroup + queue chan *request + mutex sync.RWMutex + client PeerClient + Conf PeerConfig + inShutdown int64 +} + +type response struct { + rl *RateLimitResponse + err error +} + +type request struct { + request *RateLimitRequest + ctx context.Context + resp chan *response +} + +type PeerConfig struct { + PeerClient PeerClient + Behavior BehaviorConfig + Info PeerInfo + Log FieldLogger +} + +type PeerClient interface { + Forward(context.Context, *ForwardRequest, *ForwardResponse) error + Update(context.Context, *UpdateRequest) error + Close(ctx context.Context) error +} + +func NewPeer(conf PeerConfig) (*Peer, error) { + if len(conf.Info.HTTPAddress) == 0 { + return nil, errors.New("Peer.Info.HTTPAddress is empty; must provide an address") + } + + set.Default(&conf.PeerClient, NewPeerClient(WithNoTLS(conf.Info.HTTPAddress))) + set.Default(&conf.Log, slog.Default().With("category", "Peer")) + + p := &Peer{ + lastErrs: collections.NewLRUCache(100), + queue: make(chan *request, 1000), + client: conf.PeerClient, + Conf: conf, + } + go p.run() + return p, nil +} + +// Info returns PeerInfo struct that describes this Peer +func (p *Peer) Info() PeerInfo { + return p.Conf.Info +} + +var ( + // TODO: Should retry in this case + ErrPeerShutdown = errors.New("peer is in shutdown; try a different peer") +) + +// Forward forwards a rate limit request to a peer. +// If the rate limit has `behavior == BATCHING` configured, this method will attempt to batch the rate limits +func (p *Peer) Forward(ctx context.Context, r *RateLimitRequest) (resp *RateLimitResponse, err error) { + ctx = tracing.StartScope(ctx, "Peer.Forward") + defer func() { tracing.EndScope(ctx, err) }() + span := trace.SpanFromContext(ctx) + span.SetAttributes( + attribute.String("peer.HTTPAddress", p.Conf.Info.HTTPAddress), + attribute.String("peer.Datacenter", p.Conf.Info.DataCenter), + attribute.String("request.key", r.UniqueKey), + attribute.String("request.name", r.Name), + attribute.Int64("request.algorithm", int64(r.Algorithm)), + attribute.Int64("request.behavior", int64(r.Behavior)), + attribute.Int64("request.duration", r.Duration), + attribute.Int64("request.limit", r.Limit), + attribute.Int64("request.hits", r.Hits), + attribute.Int64("request.burst", r.Burst), + ) + + if atomic.LoadInt64(&p.inShutdown) == 1 { + return nil, ErrPeerShutdown + } + + // NOTE: Add() must be done within the RLock since we must ensure all in-flight Forward() + // requests are done before calls to Close() can complete. We can't just wg.Wait() for + // since there may be Forward() call that is executing at this very code spot when Close() + // is called. In that scenario wg.Add() and wg.Wait() are in a race. 
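+	// Roughly, the two sides order their operations as follows:
+	//
+	//	Forward(): mutex.RLock() -> wg.Add(1) -> forward work -> mutex.RUnlock() -> wg.Done()
+	//	Close():   mutex.Lock()  -> wg.Wait() -> close(queue)  -> mutex.Unlock()
+	//
+	// Because Close() acquires the write lock before wg.Wait(), any Forward() call
+	// that made it past RLock() has already done its wg.Add(1), and later Forward()
+	// calls block on RLock() until shutdown completes.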
+ p.mutex.RLock() + p.wg.Add(1) + defer func() { + p.mutex.RUnlock() + defer p.wg.Done() + }() + + // If config asked for no batching + if HasBehavior(r.Behavior, Behavior_NO_BATCHING) { + // If no metadata is provided + if r.Metadata == nil { + r.Metadata = make(map[string]string) + } + // Propagate the trace context along with the rate limit so + // peers can continue to report traces for this rate limit. + prop := propagation.TraceContext{} + prop.Inject(ctx, &MetadataCarrier{Map: r.Metadata}) + + // Forward a single rate limit + var fr ForwardResponse + err = p.ForwardBatch(ctx, &ForwardRequest{ + Requests: []*RateLimitRequest{r}, + }, &fr) + if err != nil { + err = errors.Wrap(err, "Error in forward") + return nil, p.setLastErr(err) + } + return fr.RateLimits[0], nil + } + + resp, err = p.forwardBatch(ctx, r) + if err != nil { + err = errors.Wrap(err, "Error in forwardBatch") + return nil, p.setLastErr(err) + } + + return resp, nil +} + +// ForwardBatch requests a list of rate limit statuses from a peer +func (p *Peer) ForwardBatch(ctx context.Context, req *ForwardRequest, resp *ForwardResponse) (err error) { + ctx = tracing.StartScope(ctx, "Peer.forward") + defer func() { tracing.EndScope(ctx, err) }() + + if err = p.client.Forward(ctx, req, resp); err != nil { + return p.setLastErr(errors.Wrap(err, "Error in client.Forward()")) + } + + // Unlikely, but this avoids a panic if something wonky happens + if len(resp.RateLimits) != len(req.Requests) { + return p.setLastErr( + errors.New("number of rate limits in peer response does not match request")) + } + return nil +} + +// Update sends rate limit status updates to a peer +func (p *Peer) Update(ctx context.Context, req *UpdateRequest) (err error) { + ctx = tracing.StartScope(ctx, "Peer.Update") + defer func() { tracing.EndScope(ctx, err) }() + + err = p.client.Update(ctx, req) + if err != nil { + _ = p.setLastErr(err) + } + return err +} + +func (p *Peer) GetLastErr() []string { + var errs []string + keys := p.lastErrs.Keys() + + // Get errors from each key in the cache + for _, key := range keys { + err, ok := p.lastErrs.Get(key) + if ok { + errs = append(errs, err.(error).Error()) + } + } + + return errs +} + +// Close will gracefully close all client connections, until the context is canceled +func (p *Peer) Close(ctx context.Context) error { + if atomic.LoadInt64(&p.inShutdown) == 1 { + return nil + } + + atomic.AddInt64(&p.inShutdown, 1) + + // This allows us to wait on the wait group, or until the context + // has been canceled. 
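+	// Closing p.queue under the write lock also signals run() to flush any
+	// requests still batched in memory before it exits.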
+ waitChan := make(chan struct{}) + go func() { + p.mutex.Lock() + p.wg.Wait() + close(p.queue) + p.mutex.Unlock() + close(waitChan) + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-waitChan: + } + return p.client.Close(ctx) +} + +func (p *Peer) forwardBatch(ctx context.Context, r *RateLimitRequest) (resp *RateLimitResponse, err error) { + ctx = tracing.StartScope(ctx, "Peer.forwardBatch") + defer func() { tracing.EndScope(ctx, err) }() + + funcTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("Peer.forwardBatch")) + defer funcTimer.ObserveDuration() + + if atomic.LoadInt64(&p.inShutdown) == 1 { + return nil, ErrPeerShutdown + } + + // Wait for a response or context cancel + ctx2 := tracing.StartScope(ctx, "Wait for response") + defer tracing.EndScope(ctx2, nil) + + req := request{ + resp: make(chan *response, 1), + ctx: ctx2, + request: r, + } + + // Enqueue the request to be sent + peerAddr := p.Info().HTTPAddress + metricBatchQueueLength.WithLabelValues(peerAddr).Set(float64(len(p.queue))) + + select { + case p.queue <- &req: + // Successfully enqueued request. + case <-ctx2.Done(): + return nil, errors.Wrap(ctx2.Err(), "Context error while enqueuing request") + } + + p.wg.Add(1) + defer func() { + p.wg.Done() + }() + + select { + case re := <-req.resp: + if re.err != nil { + return nil, fmt.Errorf("request error: %w", re.err) + } + return re.rl, nil + case <-ctx2.Done(): + return nil, errors.Wrap(ctx2.Err(), "Context error while waiting for response") + } +} + +// run waits for requests to be queued, when either c.batchWait time +// has elapsed or the queue reaches c.batchLimit. Send what is in the queue. +func (p *Peer) run() { + var interval = NewInterval(p.Conf.Behavior.BatchWait) + defer interval.Stop() + + var queue []*request + + for { + select { + case r, ok := <-p.queue: + // If the queue has closed, we need to send the rest of the queue + if !ok { + if len(queue) > 0 { + p.sendBatch(queue) + } + return + } + + queue = append(queue, r) + // Send the queue if we reached our batch limit + if len(queue) >= p.Conf.Behavior.BatchLimit { + p.Conf.Log.LogAttrs(context.TODO(), slog.LevelDebug, "run() reached batch limit", + slog.Int("queueLen", len(queue)), + slog.Int("batchLimit", p.Conf.Behavior.BatchLimit), + ) + ref := queue + queue = nil + go p.sendBatch(ref) + continue + } + + // If this is our first enqueued item since last + // sendBatch, reset interval timer. + if len(queue) == 1 { + interval.Next() + } + continue + + case <-interval.C: + queue2 := queue + + if len(queue2) > 0 { + queue = nil + go p.sendBatch(queue2) + } + } + } +} + +// sendBatch sends the queue provided and returns the responses to +// waiting go routines +func (p *Peer) sendBatch(queue []*request) { + ctx := tracing.StartScope(context.Background(), "Peer.sendBatch") + defer tracing.EndScope(ctx, nil) + + batchSendTimer := prometheus.NewTimer(metricBatchSendDuration.WithLabelValues(p.Conf.Info.HTTPAddress)) + defer batchSendTimer.ObserveDuration() + funcTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("Peer.sendBatch")) + defer funcTimer.ObserveDuration() + + var req ForwardRequest + for _, r := range queue { + // NOTE: This trace has the same name because it's in a separate trace than the one above. + // We link the two traces, so we can relate our rate limit trace back to the above trace. 
+ r.ctx = tracing.StartScope(r.ctx, "Peer.sendBatch", + trace.WithLinks(trace.LinkFromContext(ctx))) + // If no metadata is provided + if r.request.Metadata == nil { + r.request.Metadata = make(map[string]string) + } + // Propagate the trace context along with the batched rate limit so + // peers can continue to report traces for this rate limit. + prop := propagation.TraceContext{} + prop.Inject(r.ctx, &MetadataCarrier{Map: r.request.Metadata}) + req.Requests = append(req.Requests, r.request) + tracing.EndScope(r.ctx, nil) + + } + + ctx, cancel := context.WithTimeout(ctx, p.Conf.Behavior.BatchTimeout) + var resp ForwardResponse + err := p.client.Forward(ctx, &req, &resp) + cancel() + + // An error here indicates the entire request failed + if err != nil { + err = errors.Wrap(err, "Error in client.forward") + p.Conf.Log.LogAttrs(context.TODO(), slog.LevelError, "Error in client.forward", + ErrAttr(err), + slog.Int("queueLen", len(queue)), + slog.Duration("batchTimeout", p.Conf.Behavior.BatchTimeout), + ) + _ = p.setLastErr(err) + + for _, r := range queue { + r.resp <- &response{err: err} + } + return + } + + // Unlikely, but this avoids a panic if something wonky happens + if len(resp.RateLimits) != len(queue) { + for _, r := range queue { + r.resp <- &response{err: errors.New("server responded with incorrect rate limit list size")} + } + return + } + + // Provide responses to channels waiting in the queue + for i, r := range queue { + r.resp <- &response{rl: resp.RateLimits[i]} + } +} + +func (p *Peer) setLastErr(err error) error { + // If we get a nil error return without caching it + if err == nil { + return err + } + + // Add error to the cache with a TTL of 5 minutes + p.lastErrs.AddWithTTL(err.Error(), + errors.Wrap(err, fmt.Sprintf("from host %s", p.Conf.Info.HTTPAddress)), + clock.Minute*5) + + return err +} diff --git a/peer.pb.go b/peer.pb.go new file mode 100644 index 0000000..22fd4d3 --- /dev/null +++ b/peer.pb.go @@ -0,0 +1,425 @@ +// +//Copyright 2018-2022 Mailgun Technologies Inc +// +//Licensed under the Apache License, Version 2.0 (the "License"); +//you may not use this file except in compliance with the License. +//You may obtain a copy of the License at +// +//http://www.apache.org/licenses/LICENSE-2.0 +// +//Unless required by applicable law or agreed to in writing, software +//distributed under the License is distributed on an "AS IS" BASIS, +//WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +//See the License for the specific language governing permissions and +//limitations under the License. + +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.32.0 +// protoc (unknown) +// source: peer.proto + +package gubernator + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type ForwardRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Must specify at least one RateLimit. 
The peer that receives this request MUST be authoritative for + // each rate_limit[x].unique_key provided, as the peer will not forward the request to any other peers + Requests []*RateLimitRequest `protobuf:"bytes,1,rep,name=requests,proto3" json:"requests,omitempty"` +} + +func (x *ForwardRequest) Reset() { + *x = ForwardRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_peer_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ForwardRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ForwardRequest) ProtoMessage() {} + +func (x *ForwardRequest) ProtoReflect() protoreflect.Message { + mi := &file_peer_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ForwardRequest.ProtoReflect.Descriptor instead. +func (*ForwardRequest) Descriptor() ([]byte, []int) { + return file_peer_proto_rawDescGZIP(), []int{0} +} + +func (x *ForwardRequest) GetRequests() []*RateLimitRequest { + if x != nil { + return x.Requests + } + return nil +} + +type ForwardResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Responses are in the same order as they appeared in the PeerRateLimitRequestuests + RateLimits []*RateLimitResponse `protobuf:"bytes,1,rep,name=rate_limits,json=rateLimits,proto3" json:"rate_limits,omitempty"` +} + +func (x *ForwardResponse) Reset() { + *x = ForwardResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_peer_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ForwardResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ForwardResponse) ProtoMessage() {} + +func (x *ForwardResponse) ProtoReflect() protoreflect.Message { + mi := &file_peer_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ForwardResponse.ProtoReflect.Descriptor instead. 
+func (*ForwardResponse) Descriptor() ([]byte, []int) { + return file_peer_proto_rawDescGZIP(), []int{1} +} + +func (x *ForwardResponse) GetRateLimits() []*RateLimitResponse { + if x != nil { + return x.RateLimits + } + return nil +} + +type UpdateRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Must specify at least one RateLimit + Globals []*UpdateRateLimit `protobuf:"bytes,1,rep,name=globals,proto3" json:"globals,omitempty"` +} + +func (x *UpdateRequest) Reset() { + *x = UpdateRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_peer_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *UpdateRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*UpdateRequest) ProtoMessage() {} + +func (x *UpdateRequest) ProtoReflect() protoreflect.Message { + mi := &file_peer_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use UpdateRequest.ProtoReflect.Descriptor instead. +func (*UpdateRequest) Descriptor() ([]byte, []int) { + return file_peer_proto_rawDescGZIP(), []int{2} +} + +func (x *UpdateRequest) GetGlobals() []*UpdateRateLimit { + if x != nil { + return x.Globals + } + return nil +} + +type UpdateRateLimit struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Uniquely identifies this rate limit IE: 'ip:10.2.10.7' or 'account:123445' + Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` + // The rate limit state to update + State *RateLimitResponse `protobuf:"bytes,2,opt,name=state,proto3" json:"state,omitempty"` + // The algorithm used to calculate the rate limit. The algorithm may change on + // subsequent requests, when this occurs any previous rate limit hit counts are reset. + Algorithm Algorithm `protobuf:"varint,3,opt,name=algorithm,proto3,enum=gubernator.v3.Algorithm" json:"algorithm,omitempty"` + // The duration of the rate limit in milliseconds + Duration int64 `protobuf:"varint,4,opt,name=duration,proto3" json:"duration,omitempty"` + // The exact time the original request was created in Epoch milliseconds. + // Due to time drift between systems, it may be advantageous for a client to + // set the exact time the request was created. It possible the system clock + // for the client has drifted from the system clock where gubernator daemon + // is running. + // + // The created time is used by gubernator to calculate the reset time for + // both token and leaky algorithms. If it is not set by the client, + // gubernator will set the created time when it receives the rate limit + // request. 
+ CreatedAt int64 `protobuf:"varint,5,opt,name=created_at,json=createdAt,proto3" json:"created_at,omitempty"` +} + +func (x *UpdateRateLimit) Reset() { + *x = UpdateRateLimit{} + if protoimpl.UnsafeEnabled { + mi := &file_peer_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *UpdateRateLimit) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*UpdateRateLimit) ProtoMessage() {} + +func (x *UpdateRateLimit) ProtoReflect() protoreflect.Message { + mi := &file_peer_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use UpdateRateLimit.ProtoReflect.Descriptor instead. +func (*UpdateRateLimit) Descriptor() ([]byte, []int) { + return file_peer_proto_rawDescGZIP(), []int{3} +} + +func (x *UpdateRateLimit) GetKey() string { + if x != nil { + return x.Key + } + return "" +} + +func (x *UpdateRateLimit) GetState() *RateLimitResponse { + if x != nil { + return x.State + } + return nil +} + +func (x *UpdateRateLimit) GetAlgorithm() Algorithm { + if x != nil { + return x.Algorithm + } + return Algorithm_TOKEN_BUCKET +} + +func (x *UpdateRateLimit) GetDuration() int64 { + if x != nil { + return x.Duration + } + return 0 +} + +func (x *UpdateRateLimit) GetCreatedAt() int64 { + if x != nil { + return x.CreatedAt + } + return 0 +} + +var File_peer_proto protoreflect.FileDescriptor + +var file_peer_proto_rawDesc = []byte{ + 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0d, 0x67, 0x75, + 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x76, 0x33, 0x1a, 0x10, 0x67, 0x75, 0x62, + 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x4d, 0x0a, + 0x0e, 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, + 0x3b, 0x0a, 0x08, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x1f, 0x2e, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x76, + 0x33, 0x2e, 0x52, 0x61, 0x74, 0x65, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x52, 0x08, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x73, 0x22, 0x54, 0x0a, 0x0f, + 0x46, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, + 0x41, 0x0a, 0x0b, 0x72, 0x61, 0x74, 0x65, 0x5f, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x73, 0x18, 0x01, + 0x20, 0x03, 0x28, 0x0b, 0x32, 0x20, 0x2e, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, + 0x72, 0x2e, 0x76, 0x33, 0x2e, 0x52, 0x61, 0x74, 0x65, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x52, 0x0a, 0x72, 0x61, 0x74, 0x65, 0x4c, 0x69, 0x6d, 0x69, + 0x74, 0x73, 0x22, 0x49, 0x0a, 0x0d, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x12, 0x38, 0x0a, 0x07, 0x67, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x73, 0x18, 0x01, + 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, + 0x72, 0x2e, 0x76, 0x33, 0x2e, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x61, 0x74, 0x65, 0x4c, + 0x69, 0x6d, 0x69, 0x74, 0x52, 0x07, 0x67, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x73, 0x22, 0xce, 0x01, + 0x0a, 0x0f, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x52, 0x61, 0x74, 0x65, 0x4c, 0x69, 0x6d, 0x69, + 0x74, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 
0x09, 0x52, 0x03, + 0x6b, 0x65, 0x79, 0x12, 0x36, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x20, 0x2e, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2e, + 0x76, 0x33, 0x2e, 0x52, 0x61, 0x74, 0x65, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x52, 0x05, 0x73, 0x74, 0x61, 0x74, 0x65, 0x12, 0x36, 0x0a, 0x09, 0x61, + 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, + 0x2e, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x76, 0x33, 0x2e, 0x41, + 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x52, 0x09, 0x61, 0x6c, 0x67, 0x6f, 0x72, 0x69, + 0x74, 0x68, 0x6d, 0x12, 0x1a, 0x0a, 0x08, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x08, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, + 0x1d, 0x0a, 0x0a, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x05, 0x20, + 0x01, 0x28, 0x03, 0x52, 0x09, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x41, 0x74, 0x42, 0x25, + 0x5a, 0x23, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x75, 0x62, + 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2d, 0x69, 0x6f, 0x2f, 0x67, 0x75, 0x62, 0x65, 0x72, + 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_peer_proto_rawDescOnce sync.Once + file_peer_proto_rawDescData = file_peer_proto_rawDesc +) + +func file_peer_proto_rawDescGZIP() []byte { + file_peer_proto_rawDescOnce.Do(func() { + file_peer_proto_rawDescData = protoimpl.X.CompressGZIP(file_peer_proto_rawDescData) + }) + return file_peer_proto_rawDescData +} + +var file_peer_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_peer_proto_goTypes = []interface{}{ + (*ForwardRequest)(nil), // 0: gubernator.v3.ForwardRequest + (*ForwardResponse)(nil), // 1: gubernator.v3.ForwardResponse + (*UpdateRequest)(nil), // 2: gubernator.v3.UpdateRequest + (*UpdateRateLimit)(nil), // 3: gubernator.v3.UpdateRateLimit + (*RateLimitRequest)(nil), // 4: gubernator.v3.RateLimitRequest + (*RateLimitResponse)(nil), // 5: gubernator.v3.RateLimitResponse + (Algorithm)(0), // 6: gubernator.v3.Algorithm +} +var file_peer_proto_depIdxs = []int32{ + 4, // 0: gubernator.v3.ForwardRequest.requests:type_name -> gubernator.v3.RateLimitRequest + 5, // 1: gubernator.v3.ForwardResponse.rate_limits:type_name -> gubernator.v3.RateLimitResponse + 3, // 2: gubernator.v3.UpdateRequest.globals:type_name -> gubernator.v3.UpdateRateLimit + 5, // 3: gubernator.v3.UpdateRateLimit.state:type_name -> gubernator.v3.RateLimitResponse + 6, // 4: gubernator.v3.UpdateRateLimit.algorithm:type_name -> gubernator.v3.Algorithm + 5, // [5:5] is the sub-list for method output_type + 5, // [5:5] is the sub-list for method input_type + 5, // [5:5] is the sub-list for extension type_name + 5, // [5:5] is the sub-list for extension extendee + 0, // [0:5] is the sub-list for field type_name +} + +func init() { file_peer_proto_init() } +func file_peer_proto_init() { + if File_peer_proto != nil { + return + } + file_gubernator_proto_init() + if !protoimpl.UnsafeEnabled { + file_peer_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ForwardRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_peer_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := 
v.(*ForwardResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_peer_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*UpdateRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_peer_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*UpdateRateLimit); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_peer_proto_rawDesc, + NumEnums: 0, + NumMessages: 4, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_peer_proto_goTypes, + DependencyIndexes: file_peer_proto_depIdxs, + MessageInfos: file_peer_proto_msgTypes, + }.Build() + File_peer_proto = out.File + file_peer_proto_rawDesc = nil + file_peer_proto_goTypes = nil + file_peer_proto_depIdxs = nil +} diff --git a/peers.proto b/peer.proto similarity index 68% rename from peers.proto rename to peer.proto index 4976bc5..87858a6 100644 --- a/peers.proto +++ b/peer.proto @@ -18,41 +18,31 @@ syntax = "proto3"; option go_package = "github.com/gubernator-io/gubernator"; -option cc_generic_services = true; - -package pb.gubernator; +package gubernator.v3; import "gubernator.proto"; -// NOTE: For use by gubernator peers only -service PeersV1 { - // Used by peers to relay batches of requests to an owner peer - rpc GetPeerRateLimits (GetPeerRateLimitsReq) returns (GetPeerRateLimitsResp) {} - - // Used by owner peers to send global rate limit updates to non-owner peers - rpc UpdatePeerGlobals (UpdatePeerGlobalsReq) returns (UpdatePeerGlobalsResp) {} -} - -message GetPeerRateLimitsReq { - // Must specify at least one RateLimit. The peer that recives this request MUST be authoritative for +message ForwardRequest { + // Must specify at least one RateLimit. The peer that receives this request MUST be authoritative for // each rate_limit[x].unique_key provided, as the peer will not forward the request to any other peers - repeated RateLimitReq requests = 1; + repeated RateLimitRequest requests = 1; } -message GetPeerRateLimitsResp { - // Responses are in the same order as they appeared in the PeerRateLimitRequests - repeated RateLimitResp rate_limits = 1; +message ForwardResponse { + // Responses are in the same order as they appeared in the PeerRateLimitRequestuests + repeated RateLimitResponse rate_limits = 1; } -message UpdatePeerGlobalsReq { +message UpdateRequest { // Must specify at least one RateLimit - repeated UpdatePeerGlobal globals = 1; + repeated UpdateRateLimit globals = 1; } -message UpdatePeerGlobal { +message UpdateRateLimit { // Uniquely identifies this rate limit IE: 'ip:10.2.10.7' or 'account:123445' string key = 1; - RateLimitResp status = 2; + // The rate limit state to update + RateLimitResponse state = 2; // The algorithm used to calculate the rate limit. The algorithm may change on // subsequent requests, when this occurs any previous rate limit hit counts are reset. Algorithm algorithm = 3; @@ -69,5 +59,4 @@ message UpdatePeerGlobal { // gubernator will set the created time when it receives the rate limit // request. 
int64 created_at = 5; -} -message UpdatePeerGlobalsResp {} +} \ No newline at end of file diff --git a/peer_client.go b/peer_client.go deleted file mode 100644 index 03b29ff..0000000 --- a/peer_client.go +++ /dev/null @@ -1,435 +0,0 @@ -/* -Copyright 2018-2022 Mailgun Technologies Inc - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package gubernator - -import ( - "context" - "crypto/tls" - "fmt" - "sync" - "sync/atomic" - - "github.com/mailgun/holster/v4/clock" - "github.com/mailgun/holster/v4/collections" - "github.com/mailgun/holster/v4/errors" - "github.com/mailgun/holster/v4/tracing" - "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" - "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/propagation" - "go.opentelemetry.io/otel/trace" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/status" -) - -type PeerPicker interface { - GetByPeerInfo(PeerInfo) *PeerClient - Peers() []*PeerClient - Get(string) (*PeerClient, error) - New() PeerPicker - Add(*PeerClient) -} - -type PeerClient struct { - client PeersV1Client - conn *grpc.ClientConn - conf PeerConfig - queue chan *request - queueClosed atomic.Bool - lastErrs *collections.LRUCache - - wgMutex sync.RWMutex - wg sync.WaitGroup // Monitor the number of in-flight requests. GUARDED_BY(wgMutex) -} - -type response struct { - rl *RateLimitResp - err error -} - -type request struct { - request *RateLimitReq - reqState RateLimitReqState - resp chan *response - ctx context.Context -} - -type PeerConfig struct { - TLS *tls.Config - Behavior BehaviorConfig - Info PeerInfo - Log FieldLogger - TraceGRPC bool -} - -// NewPeerClient tries to establish a connection to a peer in a non-blocking fashion. -// If batching is enabled, it also starts a goroutine where batches will be processed. -func NewPeerClient(conf PeerConfig) (*PeerClient, error) { - peerClient := &PeerClient{ - queue: make(chan *request, 1000), - conf: conf, - lastErrs: collections.NewLRUCache(100), - } - var opts []grpc.DialOption - - if conf.TraceGRPC { - opts = []grpc.DialOption{ - grpc.WithStatsHandler(otelgrpc.NewClientHandler()), - } - } - - if conf.TLS != nil { - opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(conf.TLS))) - } else { - opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) - } - - var err error - peerClient.conn, err = grpc.NewClient(conf.Info.GRPCAddress, opts...) - if err != nil { - return nil, err - } - peerClient.client = NewPeersV1Client(peerClient.conn) - - if !conf.Behavior.DisableBatching { - go peerClient.runBatch() - } - return peerClient, nil -} - -// Info returns PeerInfo struct that describes this PeerClient -func (c *PeerClient) Info() PeerInfo { - return c.conf.Info -} - -// GetPeerRateLimit forwards a rate limit request to a peer. 
If the rate limit has `behavior == BATCHING` configured, -// this method will attempt to batch the rate limits -func (c *PeerClient) GetPeerRateLimit(ctx context.Context, r *RateLimitReq) (resp *RateLimitResp, err error) { - span := trace.SpanFromContext(ctx) - span.SetAttributes( - attribute.String("ratelimit.key", r.UniqueKey), - attribute.String("ratelimit.name", r.Name), - ) - - // If config asked for no batching - if c.conf.Behavior.DisableBatching || HasBehavior(r.Behavior, Behavior_NO_BATCHING) { - // If no metadata is provided - if r.Metadata == nil { - r.Metadata = make(map[string]string) - } - // Propagate the trace context along with the rate limit so - // peers can continue to report traces for this rate limit. - prop := propagation.TraceContext{} - prop.Inject(ctx, &MetadataCarrier{Map: r.Metadata}) - - // Send a single low latency rate limit request - resp, err := c.GetPeerRateLimits(ctx, &GetPeerRateLimitsReq{ - Requests: []*RateLimitReq{r}, - }) - if err != nil { - err = errors.Wrap(err, "Error in GetPeerRateLimits") - return nil, c.setLastErr(err) - } - return resp.RateLimits[0], nil - } - - resp, err = c.getPeerRateLimitsBatch(ctx, r) - if err != nil { - err = errors.Wrap(err, "Error in getPeerRateLimitsBatch") - return nil, c.setLastErr(err) - } - - return resp, nil -} - -// GetPeerRateLimits requests a list of rate limit statuses from a peer -func (c *PeerClient) GetPeerRateLimits(ctx context.Context, r *GetPeerRateLimitsReq) (resp *GetPeerRateLimitsResp, err error) { - // NOTE: This must be done within the Lock since calling Wait() in Shutdown() causes - // a race condition if called within a separate go routine if the internal wg is `0` - // when Wait() is called then Add(1) is called concurrently. - c.wgMutex.Lock() - c.wg.Add(1) - c.wgMutex.Unlock() - defer c.wg.Done() - - resp, err = c.client.GetPeerRateLimits(ctx, r) - if err != nil { - err = errors.Wrap(err, "Error in client.GetPeerRateLimits") - // metricCheckErrorCounter is updated within client.GetPeerRateLimits(). 
- return nil, c.setLastErr(err) - } - - // Unlikely, but this avoids a panic if something wonky happens - if len(resp.RateLimits) != len(r.Requests) { - err = errors.New("number of rate limits in peer response does not match request") - metricCheckErrorCounter.WithLabelValues("Item mismatch").Add(1) - return nil, c.setLastErr(err) - } - return resp, nil -} - -// UpdatePeerGlobals sends global rate limit status updates to a peer -func (c *PeerClient) UpdatePeerGlobals(ctx context.Context, r *UpdatePeerGlobalsReq) (resp *UpdatePeerGlobalsResp, err error) { - - // See NOTE above about RLock and wg.Add(1) - c.wgMutex.Lock() - c.wg.Add(1) - c.wgMutex.Unlock() - defer c.wg.Done() - - resp, err = c.client.UpdatePeerGlobals(ctx, r) - if err != nil { - _ = c.setLastErr(err) - } - - return resp, err -} - -func (c *PeerClient) setLastErr(err error) error { - // If we get a nil error return without caching it - if err == nil { - return err - } - - // Prepend client address to error - errWithHostname := errors.Wrap(err, fmt.Sprintf("from host %s", c.conf.Info.GRPCAddress)) - key := err.Error() - - // Add error to the cache with a TTL of 5 minutes - c.lastErrs.AddWithTTL(key, errWithHostname, clock.Minute*5) - - return err -} - -func (c *PeerClient) GetLastErr() []string { - var errs []string - keys := c.lastErrs.Keys() - - // Get errors from each key in the cache - for _, key := range keys { - err, ok := c.lastErrs.Get(key) - if ok { - errs = append(errs, err.(error).Error()) - } - } - - return errs -} - -func (c *PeerClient) getPeerRateLimitsBatch(ctx context.Context, r *RateLimitReq) (resp *RateLimitResp, err error) { - funcTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("PeerClient.getPeerRateLimitsBatch")) - defer funcTimer.ObserveDuration() - - req := request{ - resp: make(chan *response, 1), - ctx: ctx, - request: r, - } - - c.wgMutex.Lock() - c.wg.Add(1) - c.wgMutex.Unlock() - defer c.wg.Done() - - // Enqueue the request to be sent - peerAddr := c.Info().GRPCAddress - metricBatchQueueLength.WithLabelValues(peerAddr).Set(float64(len(c.queue))) - - if c.queueClosed.Load() { - // this check prevents "panic: send on close channel" - return nil, status.Error(codes.Canceled, "grpc: the client connection is closing") - } - - select { - case c.queue <- &req: - // Successfully enqueued request. - case <-ctx.Done(): - return nil, errors.Wrap(ctx.Err(), "Context error while enqueuing request") - } - - // Wait for a response or context cancel - select { - case re := <-req.resp: - if re.err != nil { - err := errors.Wrap(c.setLastErr(re.err), "Request error") - return nil, c.setLastErr(err) - } - return re.rl, nil - case <-ctx.Done(): - return nil, errors.Wrap(ctx.Err(), "Context error while waiting for response") - } -} - -// runBatch processes batching requests by waiting for requests to be queued. Send -// the queue as a batch when either c.batchWait time has elapsed or the queue -// reaches c.batchLimit. -func (c *PeerClient) runBatch() { - var interval = NewInterval(c.conf.Behavior.BatchWait) - defer interval.Stop() - - var queue []*request - - for { - ctx := context.Background() - - select { - case r, ok := <-c.queue: - if !ok { - // If the queue has shutdown, we need to send the rest of the queue - if len(queue) > 0 { - c.sendBatch(ctx, queue) - } - return - } - - queue = append(queue, r) - // Send the queue if we reached our batch limit - if len(queue) >= c.conf.Behavior.BatchLimit { - c.conf.Log.WithContext(ctx). 
- WithFields(logrus.Fields{ - "queueLen": len(queue), - "batchLimit": c.conf.Behavior.BatchLimit, - }). - Debug("runBatch() reached batch limit") - ref := queue - queue = nil - go c.sendBatch(ctx, ref) - continue - } - - // If this is our first enqueued item since last - // sendBatch, reset interval timer. - if len(queue) == 1 { - interval.Next() - } - continue - - case <-interval.C: - queue2 := queue - - if len(queue2) > 0 { - queue = nil - - go func() { - c.sendBatch(ctx, queue2) - }() - } - } - } -} - -// sendBatch sends the queue provided and returns the responses to -// waiting go routines -func (c *PeerClient) sendBatch(ctx context.Context, queue []*request) { - batchSendTimer := prometheus.NewTimer(metricBatchSendDuration.WithLabelValues(c.conf.Info.GRPCAddress)) - defer batchSendTimer.ObserveDuration() - funcTimer := prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("PeerClient.sendBatch")) - defer funcTimer.ObserveDuration() - - var req GetPeerRateLimitsReq - for _, r := range queue { - // NOTE: This trace has the same name because it's in a separate trace than the one above. - // We link the two traces, so we can relate our rate limit trace back to the above trace. - r.ctx = tracing.StartNamedScope(r.ctx, "PeerClient.sendBatch", - trace.WithLinks(trace.LinkFromContext(ctx))) - // If no metadata is provided - if r.request.Metadata == nil { - r.request.Metadata = make(map[string]string) - } - // Propagate the trace context along with the batched rate limit so - // peers can continue to report traces for this rate limit. - prop := propagation.TraceContext{} - prop.Inject(r.ctx, &MetadataCarrier{Map: r.request.Metadata}) - req.Requests = append(req.Requests, r.request) - tracing.EndScope(r.ctx, nil) - } - - timeoutCtx, timeoutCancel := context.WithTimeout(ctx, c.conf.Behavior.BatchTimeout) - resp, err := c.client.GetPeerRateLimits(timeoutCtx, &req) - timeoutCancel() - - // An error here indicates the entire request failed - if err != nil { - logPart := "Error in client.GetPeerRateLimits" - c.conf.Log.WithContext(ctx). - WithError(err). - WithFields(logrus.Fields{ - "queueLen": len(queue), - "batchTimeout": c.conf.Behavior.BatchTimeout.String(), - }). - Error(logPart) - err = errors.Wrap(err, logPart) - _ = c.setLastErr(err) - // metricCheckErrorCounter is updated within client.GetPeerRateLimits(). - - for _, r := range queue { - r.resp <- &response{err: err} - } - return - } - - // Unlikely, but this avoids a panic if something wonky happens - if len(resp.RateLimits) != len(queue) { - err = errors.New("server responded with incorrect rate limit list size") - - for _, r := range queue { - metricCheckErrorCounter.WithLabelValues("Item mismatch").Add(1) - r.resp <- &response{err: err} - } - return - } - - // Provide responses to channels waiting in the queue - for i, r := range queue { - r.resp <- &response{rl: resp.RateLimits[i]} - } -} - -// Shutdown waits until all outstanding requests have finished or the context is cancelled. -// Then it closes the grpc connection. 
-func (c *PeerClient) Shutdown(ctx context.Context) error { - // ensure we don't leak goroutines, even if the Shutdown times out - defer c.conn.Close() - - waitChan := make(chan struct{}) - go func() { - // drain in-flight requests - c.wgMutex.Lock() - defer c.wgMutex.Unlock() - c.wg.Wait() - - // clear errors - c.lastErrs = collections.NewLRUCache(100) - - // signal that no more items will be sent - c.queueClosed.Store(true) - close(c.queue) - - close(waitChan) - }() - - select { - case <-ctx.Done(): - return ctx.Err() - case <-waitChan: - return nil - } -} diff --git a/peer_client_test.go b/peer_test.go similarity index 57% rename from peer_client_test.go rename to peer_test.go index f445677..6eb5e09 100644 --- a/peer_client_test.go +++ b/peer_test.go @@ -19,27 +19,23 @@ package gubernator_test import ( "context" "runtime" - "strings" + "sync" "testing" - "github.com/gubernator-io/gubernator/v2" - "github.com/gubernator-io/gubernator/v2/cluster" - "github.com/mailgun/holster/v4/clock" - "github.com/pkg/errors" + "github.com/gubernator-io/gubernator/v3" + "github.com/gubernator-io/gubernator/v3/cluster" + "github.com/kapetan-io/tackle/clock" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/sync/errgroup" ) func TestPeerClientShutdown(t *testing.T) { - type test struct { - Name string - Behavior gubernator.Behavior - } - const threads = 10 - createdAt := epochMillis(clock.Now()) - cases := []test{ + cases := []struct { + Name string + Behavior gubernator.Behavior + }{ {"No batching", gubernator.Behavior_NO_BATCHING}, {"Batching", gubernator.Behavior_BATCHING}, {"Global", gubernator.Behavior_GLOBAL}, @@ -59,45 +55,45 @@ func TestPeerClientShutdown(t *testing.T) { c := cases[i] t.Run(c.Name, func(t *testing.T) { - client, err := gubernator.NewPeerClient(gubernator.PeerConfig{ - Info: cluster.GetRandomPeer(cluster.DataCenterNone), + client, err := gubernator.NewPeer(gubernator.PeerConfig{ + Info: cluster.GetRandomPeerInfo(cluster.DataCenterNone), Behavior: config, }) require.NoError(t, err) - wg := errgroup.Group{} - wg.SetLimit(threads) - // Spawn a whole bunch of concurrent requests to test shutdown in various states - for j := 0; j < threads; j++ { - wg.Go(func() error { + wg := sync.WaitGroup{} + wg.Add(threads) + // Spawn a bunch of concurrent requests to test shutdown in various states + for i := 0; i < threads; i++ { + go func(client *gubernator.Peer, behavior gubernator.Behavior) { + defer wg.Done() ctx := context.Background() - _, err := client.GetPeerRateLimit(ctx, &gubernator.RateLimitReq{ - Hits: 1, - Limit: 100, - Behavior: c.Behavior, - CreatedAt: &createdAt, + _, err := client.Forward(ctx, &gubernator.RateLimitRequest{ + Hits: 1, + Limit: 100, + Behavior: behavior, }) - if err != nil { - if !strings.Contains(err.Error(), "client connection is closing") { - return errors.Wrap(err, "unexpected error in test") - } + isExpectedErr := false + + switch err.(type) { + case nil: + isExpectedErr = true } - return nil - }) + + assert.True(t, true, isExpectedErr) + + }(client, c.Behavior) } // yield the processor that way we allow other goroutines to start their request runtime.Gosched() - shutDownErr := client.Shutdown(context.Background()) + err = client.Close(context.Background()) + assert.NoError(t, err) - err = wg.Wait() - if err != nil { - t.Error(err) - t.Fail() - } - require.NoError(t, shutDownErr) + wg.Wait() }) + } } diff --git a/peers.pb.go b/peers.pb.go deleted file mode 100644 index 85a5749..0000000 --- a/peers.pb.go +++ /dev/null 
@@ -1,495 +0,0 @@ -// -//Copyright 2018-2022 Mailgun Technologies Inc -// -//Licensed under the Apache License, Version 2.0 (the "License"); -//you may not use this file except in compliance with the License. -//You may obtain a copy of the License at -// -//http://www.apache.org/licenses/LICENSE-2.0 -// -//Unless required by applicable law or agreed to in writing, software -//distributed under the License is distributed on an "AS IS" BASIS, -//WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -//See the License for the specific language governing permissions and -//limitations under the License. - -// Code generated by protoc-gen-go. DO NOT EDIT. -// versions: -// protoc-gen-go v1.32.0 -// protoc (unknown) -// source: peers.proto - -package gubernator - -import ( - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" - reflect "reflect" - sync "sync" -) - -const ( - // Verify that this generated code is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) - // Verify that runtime/protoimpl is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) -) - -type GetPeerRateLimitsReq struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - // Must specify at least one RateLimit. The peer that recives this request MUST be authoritative for - // each rate_limit[x].unique_key provided, as the peer will not forward the request to any other peers - Requests []*RateLimitReq `protobuf:"bytes,1,rep,name=requests,proto3" json:"requests,omitempty"` -} - -func (x *GetPeerRateLimitsReq) Reset() { - *x = GetPeerRateLimitsReq{} - if protoimpl.UnsafeEnabled { - mi := &file_peers_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *GetPeerRateLimitsReq) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*GetPeerRateLimitsReq) ProtoMessage() {} - -func (x *GetPeerRateLimitsReq) ProtoReflect() protoreflect.Message { - mi := &file_peers_proto_msgTypes[0] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use GetPeerRateLimitsReq.ProtoReflect.Descriptor instead. 
-func (*GetPeerRateLimitsReq) Descriptor() ([]byte, []int) { - return file_peers_proto_rawDescGZIP(), []int{0} -} - -func (x *GetPeerRateLimitsReq) GetRequests() []*RateLimitReq { - if x != nil { - return x.Requests - } - return nil -} - -type GetPeerRateLimitsResp struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - // Responses are in the same order as they appeared in the PeerRateLimitRequests - RateLimits []*RateLimitResp `protobuf:"bytes,1,rep,name=rate_limits,json=rateLimits,proto3" json:"rate_limits,omitempty"` -} - -func (x *GetPeerRateLimitsResp) Reset() { - *x = GetPeerRateLimitsResp{} - if protoimpl.UnsafeEnabled { - mi := &file_peers_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *GetPeerRateLimitsResp) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*GetPeerRateLimitsResp) ProtoMessage() {} - -func (x *GetPeerRateLimitsResp) ProtoReflect() protoreflect.Message { - mi := &file_peers_proto_msgTypes[1] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use GetPeerRateLimitsResp.ProtoReflect.Descriptor instead. -func (*GetPeerRateLimitsResp) Descriptor() ([]byte, []int) { - return file_peers_proto_rawDescGZIP(), []int{1} -} - -func (x *GetPeerRateLimitsResp) GetRateLimits() []*RateLimitResp { - if x != nil { - return x.RateLimits - } - return nil -} - -type UpdatePeerGlobalsReq struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - // Must specify at least one RateLimit - Globals []*UpdatePeerGlobal `protobuf:"bytes,1,rep,name=globals,proto3" json:"globals,omitempty"` -} - -func (x *UpdatePeerGlobalsReq) Reset() { - *x = UpdatePeerGlobalsReq{} - if protoimpl.UnsafeEnabled { - mi := &file_peers_proto_msgTypes[2] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *UpdatePeerGlobalsReq) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*UpdatePeerGlobalsReq) ProtoMessage() {} - -func (x *UpdatePeerGlobalsReq) ProtoReflect() protoreflect.Message { - mi := &file_peers_proto_msgTypes[2] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use UpdatePeerGlobalsReq.ProtoReflect.Descriptor instead. -func (*UpdatePeerGlobalsReq) Descriptor() ([]byte, []int) { - return file_peers_proto_rawDescGZIP(), []int{2} -} - -func (x *UpdatePeerGlobalsReq) GetGlobals() []*UpdatePeerGlobal { - if x != nil { - return x.Globals - } - return nil -} - -type UpdatePeerGlobal struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - // Uniquely identifies this rate limit IE: 'ip:10.2.10.7' or 'account:123445' - Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` - Status *RateLimitResp `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"` - // The algorithm used to calculate the rate limit. The algorithm may change on - // subsequent requests, when this occurs any previous rate limit hit counts are reset. 
- Algorithm Algorithm `protobuf:"varint,3,opt,name=algorithm,proto3,enum=pb.gubernator.Algorithm" json:"algorithm,omitempty"` - // The duration of the rate limit in milliseconds - Duration int64 `protobuf:"varint,4,opt,name=duration,proto3" json:"duration,omitempty"` - // The exact time the original request was created in Epoch milliseconds. - // Due to time drift between systems, it may be advantageous for a client to - // set the exact time the request was created. It possible the system clock - // for the client has drifted from the system clock where gubernator daemon - // is running. - // - // The created time is used by gubernator to calculate the reset time for - // both token and leaky algorithms. If it is not set by the client, - // gubernator will set the created time when it receives the rate limit - // request. - CreatedAt int64 `protobuf:"varint,5,opt,name=created_at,json=createdAt,proto3" json:"created_at,omitempty"` -} - -func (x *UpdatePeerGlobal) Reset() { - *x = UpdatePeerGlobal{} - if protoimpl.UnsafeEnabled { - mi := &file_peers_proto_msgTypes[3] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *UpdatePeerGlobal) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*UpdatePeerGlobal) ProtoMessage() {} - -func (x *UpdatePeerGlobal) ProtoReflect() protoreflect.Message { - mi := &file_peers_proto_msgTypes[3] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use UpdatePeerGlobal.ProtoReflect.Descriptor instead. -func (*UpdatePeerGlobal) Descriptor() ([]byte, []int) { - return file_peers_proto_rawDescGZIP(), []int{3} -} - -func (x *UpdatePeerGlobal) GetKey() string { - if x != nil { - return x.Key - } - return "" -} - -func (x *UpdatePeerGlobal) GetStatus() *RateLimitResp { - if x != nil { - return x.Status - } - return nil -} - -func (x *UpdatePeerGlobal) GetAlgorithm() Algorithm { - if x != nil { - return x.Algorithm - } - return Algorithm_TOKEN_BUCKET -} - -func (x *UpdatePeerGlobal) GetDuration() int64 { - if x != nil { - return x.Duration - } - return 0 -} - -func (x *UpdatePeerGlobal) GetCreatedAt() int64 { - if x != nil { - return x.CreatedAt - } - return 0 -} - -type UpdatePeerGlobalsResp struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields -} - -func (x *UpdatePeerGlobalsResp) Reset() { - *x = UpdatePeerGlobalsResp{} - if protoimpl.UnsafeEnabled { - mi := &file_peers_proto_msgTypes[4] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *UpdatePeerGlobalsResp) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*UpdatePeerGlobalsResp) ProtoMessage() {} - -func (x *UpdatePeerGlobalsResp) ProtoReflect() protoreflect.Message { - mi := &file_peers_proto_msgTypes[4] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use UpdatePeerGlobalsResp.ProtoReflect.Descriptor instead. 
-func (*UpdatePeerGlobalsResp) Descriptor() ([]byte, []int) { - return file_peers_proto_rawDescGZIP(), []int{4} -} - -var File_peers_proto protoreflect.FileDescriptor - -var file_peers_proto_rawDesc = []byte{ - 0x0a, 0x0b, 0x70, 0x65, 0x65, 0x72, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0d, 0x70, - 0x62, 0x2e, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x1a, 0x10, 0x67, 0x75, - 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x4f, - 0x0a, 0x14, 0x47, 0x65, 0x74, 0x50, 0x65, 0x65, 0x72, 0x52, 0x61, 0x74, 0x65, 0x4c, 0x69, 0x6d, - 0x69, 0x74, 0x73, 0x52, 0x65, 0x71, 0x12, 0x37, 0x0a, 0x08, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x70, 0x62, 0x2e, 0x67, 0x75, - 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x52, 0x61, 0x74, 0x65, 0x4c, 0x69, 0x6d, - 0x69, 0x74, 0x52, 0x65, 0x71, 0x52, 0x08, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x73, 0x22, - 0x56, 0x0a, 0x15, 0x47, 0x65, 0x74, 0x50, 0x65, 0x65, 0x72, 0x52, 0x61, 0x74, 0x65, 0x4c, 0x69, - 0x6d, 0x69, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x12, 0x3d, 0x0a, 0x0b, 0x72, 0x61, 0x74, 0x65, - 0x5f, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, - 0x70, 0x62, 0x2e, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x52, 0x61, - 0x74, 0x65, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x52, 0x65, 0x73, 0x70, 0x52, 0x0a, 0x72, 0x61, 0x74, - 0x65, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x73, 0x22, 0x51, 0x0a, 0x14, 0x55, 0x70, 0x64, 0x61, 0x74, - 0x65, 0x50, 0x65, 0x65, 0x72, 0x47, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x73, 0x52, 0x65, 0x71, 0x12, - 0x39, 0x0a, 0x07, 0x67, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, - 0x32, 0x1f, 0x2e, 0x70, 0x62, 0x2e, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, - 0x2e, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x47, 0x6c, 0x6f, 0x62, 0x61, - 0x6c, 0x52, 0x07, 0x67, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x73, 0x22, 0xcd, 0x01, 0x0a, 0x10, 0x55, - 0x70, 0x64, 0x61, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x47, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x12, - 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, - 0x79, 0x12, 0x34, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x1c, 0x2e, 0x70, 0x62, 0x2e, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, - 0x72, 0x2e, 0x52, 0x61, 0x74, 0x65, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x52, 0x65, 0x73, 0x70, 0x52, - 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x36, 0x0a, 0x09, 0x61, 0x6c, 0x67, 0x6f, 0x72, - 0x69, 0x74, 0x68, 0x6d, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x70, 0x62, 0x2e, - 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x41, 0x6c, 0x67, 0x6f, 0x72, - 0x69, 0x74, 0x68, 0x6d, 0x52, 0x09, 0x61, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x12, - 0x1a, 0x0a, 0x08, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x03, 0x52, 0x08, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1d, 0x0a, 0x0a, 0x63, - 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, - 0x09, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x41, 0x74, 0x22, 0x17, 0x0a, 0x15, 0x55, 0x70, - 0x64, 0x61, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x47, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x73, 0x52, - 0x65, 0x73, 0x70, 0x32, 0xcd, 0x01, 0x0a, 0x07, 0x50, 0x65, 0x65, 0x72, 0x73, 0x56, 0x31, 0x12, - 0x60, 0x0a, 
0x11, 0x47, 0x65, 0x74, 0x50, 0x65, 0x65, 0x72, 0x52, 0x61, 0x74, 0x65, 0x4c, 0x69, - 0x6d, 0x69, 0x74, 0x73, 0x12, 0x23, 0x2e, 0x70, 0x62, 0x2e, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, - 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x47, 0x65, 0x74, 0x50, 0x65, 0x65, 0x72, 0x52, 0x61, 0x74, 0x65, - 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x73, 0x52, 0x65, 0x71, 0x1a, 0x24, 0x2e, 0x70, 0x62, 0x2e, 0x67, - 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x47, 0x65, 0x74, 0x50, 0x65, 0x65, - 0x72, 0x52, 0x61, 0x74, 0x65, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x22, - 0x00, 0x12, 0x60, 0x0a, 0x11, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x47, - 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x73, 0x12, 0x23, 0x2e, 0x70, 0x62, 0x2e, 0x67, 0x75, 0x62, 0x65, - 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x50, 0x65, 0x65, - 0x72, 0x47, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x73, 0x52, 0x65, 0x71, 0x1a, 0x24, 0x2e, 0x70, 0x62, - 0x2e, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2e, 0x55, 0x70, 0x64, 0x61, - 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x47, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x73, 0x52, 0x65, 0x73, - 0x70, 0x22, 0x00, 0x42, 0x28, 0x5a, 0x23, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, - 0x6d, 0x2f, 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x2d, 0x69, 0x6f, 0x2f, - 0x67, 0x75, 0x62, 0x65, 0x72, 0x6e, 0x61, 0x74, 0x6f, 0x72, 0x80, 0x01, 0x01, 0x62, 0x06, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x33, -} - -var ( - file_peers_proto_rawDescOnce sync.Once - file_peers_proto_rawDescData = file_peers_proto_rawDesc -) - -func file_peers_proto_rawDescGZIP() []byte { - file_peers_proto_rawDescOnce.Do(func() { - file_peers_proto_rawDescData = protoimpl.X.CompressGZIP(file_peers_proto_rawDescData) - }) - return file_peers_proto_rawDescData -} - -var file_peers_proto_msgTypes = make([]protoimpl.MessageInfo, 5) -var file_peers_proto_goTypes = []interface{}{ - (*GetPeerRateLimitsReq)(nil), // 0: pb.gubernator.GetPeerRateLimitsReq - (*GetPeerRateLimitsResp)(nil), // 1: pb.gubernator.GetPeerRateLimitsResp - (*UpdatePeerGlobalsReq)(nil), // 2: pb.gubernator.UpdatePeerGlobalsReq - (*UpdatePeerGlobal)(nil), // 3: pb.gubernator.UpdatePeerGlobal - (*UpdatePeerGlobalsResp)(nil), // 4: pb.gubernator.UpdatePeerGlobalsResp - (*RateLimitReq)(nil), // 5: pb.gubernator.RateLimitReq - (*RateLimitResp)(nil), // 6: pb.gubernator.RateLimitResp - (Algorithm)(0), // 7: pb.gubernator.Algorithm -} -var file_peers_proto_depIdxs = []int32{ - 5, // 0: pb.gubernator.GetPeerRateLimitsReq.requests:type_name -> pb.gubernator.RateLimitReq - 6, // 1: pb.gubernator.GetPeerRateLimitsResp.rate_limits:type_name -> pb.gubernator.RateLimitResp - 3, // 2: pb.gubernator.UpdatePeerGlobalsReq.globals:type_name -> pb.gubernator.UpdatePeerGlobal - 6, // 3: pb.gubernator.UpdatePeerGlobal.status:type_name -> pb.gubernator.RateLimitResp - 7, // 4: pb.gubernator.UpdatePeerGlobal.algorithm:type_name -> pb.gubernator.Algorithm - 0, // 5: pb.gubernator.PeersV1.GetPeerRateLimits:input_type -> pb.gubernator.GetPeerRateLimitsReq - 2, // 6: pb.gubernator.PeersV1.UpdatePeerGlobals:input_type -> pb.gubernator.UpdatePeerGlobalsReq - 1, // 7: pb.gubernator.PeersV1.GetPeerRateLimits:output_type -> pb.gubernator.GetPeerRateLimitsResp - 4, // 8: pb.gubernator.PeersV1.UpdatePeerGlobals:output_type -> pb.gubernator.UpdatePeerGlobalsResp - 7, // [7:9] is the sub-list for method output_type - 5, // [5:7] is the sub-list for method input_type - 5, // [5:5] is the sub-list for extension type_name - 5, // 
[5:5] is the sub-list for extension extendee - 0, // [0:5] is the sub-list for field type_name -} - -func init() { file_peers_proto_init() } -func file_peers_proto_init() { - if File_peers_proto != nil { - return - } - file_gubernator_proto_init() - if !protoimpl.UnsafeEnabled { - file_peers_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetPeerRateLimitsReq); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_peers_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetPeerRateLimitsResp); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_peers_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*UpdatePeerGlobalsReq); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_peers_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*UpdatePeerGlobal); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_peers_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*UpdatePeerGlobalsResp); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - } - type x struct{} - out := protoimpl.TypeBuilder{ - File: protoimpl.DescBuilder{ - GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_peers_proto_rawDesc, - NumEnums: 0, - NumMessages: 5, - NumExtensions: 0, - NumServices: 1, - }, - GoTypes: file_peers_proto_goTypes, - DependencyIndexes: file_peers_proto_depIdxs, - MessageInfos: file_peers_proto_msgTypes, - }.Build() - File_peers_proto = out.File - file_peers_proto_rawDesc = nil - file_peers_proto_goTypes = nil - file_peers_proto_depIdxs = nil -} diff --git a/peers.pb.gw.go b/peers.pb.gw.go deleted file mode 100644 index f092976..0000000 --- a/peers.pb.gw.go +++ /dev/null @@ -1,256 +0,0 @@ -// Code generated by protoc-gen-grpc-gateway. DO NOT EDIT. -// source: peers.proto - -/* -Package gubernator is a reverse proxy. - -It translates gRPC into RESTful JSON APIs. 
-*/ -package gubernator - -import ( - "context" - "io" - "net/http" - - "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" - "github.com/grpc-ecosystem/grpc-gateway/v2/utilities" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/grpclog" - "google.golang.org/grpc/metadata" - "google.golang.org/grpc/status" - "google.golang.org/protobuf/proto" -) - -// Suppress "imported and not used" errors -var _ codes.Code -var _ io.Reader -var _ status.Status -var _ = runtime.String -var _ = utilities.NewDoubleArray -var _ = metadata.Join - -func request_PeersV1_GetPeerRateLimits_0(ctx context.Context, marshaler runtime.Marshaler, client PeersV1Client, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq GetPeerRateLimitsReq - var metadata runtime.ServerMetadata - - newReader, berr := utilities.IOReaderFactory(req.Body) - if berr != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) - } - if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } - - msg, err := client.GetPeerRateLimits(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) - return msg, metadata, err - -} - -func local_request_PeersV1_GetPeerRateLimits_0(ctx context.Context, marshaler runtime.Marshaler, server PeersV1Server, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq GetPeerRateLimitsReq - var metadata runtime.ServerMetadata - - newReader, berr := utilities.IOReaderFactory(req.Body) - if berr != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) - } - if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } - - msg, err := server.GetPeerRateLimits(ctx, &protoReq) - return msg, metadata, err - -} - -func request_PeersV1_UpdatePeerGlobals_0(ctx context.Context, marshaler runtime.Marshaler, client PeersV1Client, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq UpdatePeerGlobalsReq - var metadata runtime.ServerMetadata - - newReader, berr := utilities.IOReaderFactory(req.Body) - if berr != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) - } - if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } - - msg, err := client.UpdatePeerGlobals(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) - return msg, metadata, err - -} - -func local_request_PeersV1_UpdatePeerGlobals_0(ctx context.Context, marshaler runtime.Marshaler, server PeersV1Server, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq UpdatePeerGlobalsReq - var metadata runtime.ServerMetadata - - newReader, berr := utilities.IOReaderFactory(req.Body) - if berr != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) - } - if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } - - msg, err := server.UpdatePeerGlobals(ctx, &protoReq) - return msg, metadata, err - -} - -// 
RegisterPeersV1HandlerServer registers the http handlers for service PeersV1 to "mux". -// UnaryRPC :call PeersV1Server directly. -// StreamingRPC :currently unsupported pending https://github.com/grpc/grpc-go/issues/906. -// Note that using this registration option will cause many gRPC library features to stop working. Consider using RegisterPeersV1HandlerFromEndpoint instead. -func RegisterPeersV1HandlerServer(ctx context.Context, mux *runtime.ServeMux, server PeersV1Server) error { - - mux.Handle("POST", pattern_PeersV1_GetPeerRateLimits_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - var stream runtime.ServerTransportStream - ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/pb.gubernator.PeersV1/GetPeerRateLimits", runtime.WithHTTPPathPattern("/pb.gubernator.PeersV1/GetPeerRateLimits")) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := local_request_PeersV1_GetPeerRateLimits_0(annotatedContext, inboundMarshaler, server, req, pathParams) - md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) - annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) - if err != nil { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - return - } - - forward_PeersV1_GetPeerRateLimits_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - - }) - - mux.Handle("POST", pattern_PeersV1_UpdatePeerGlobals_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - var stream runtime.ServerTransportStream - ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/pb.gubernator.PeersV1/UpdatePeerGlobals", runtime.WithHTTPPathPattern("/pb.gubernator.PeersV1/UpdatePeerGlobals")) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := local_request_PeersV1_UpdatePeerGlobals_0(annotatedContext, inboundMarshaler, server, req, pathParams) - md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) - annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) - if err != nil { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - return - } - - forward_PeersV1_UpdatePeerGlobals_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - - }) - - return nil -} - -// RegisterPeersV1HandlerFromEndpoint is same as RegisterPeersV1Handler but -// automatically dials to "endpoint" and closes the connection when "ctx" gets done. -func RegisterPeersV1HandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) { - conn, err := grpc.DialContext(ctx, endpoint, opts...) 
- if err != nil { - return err - } - defer func() { - if err != nil { - if cerr := conn.Close(); cerr != nil { - grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr) - } - return - } - go func() { - <-ctx.Done() - if cerr := conn.Close(); cerr != nil { - grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr) - } - }() - }() - - return RegisterPeersV1Handler(ctx, mux, conn) -} - -// RegisterPeersV1Handler registers the http handlers for service PeersV1 to "mux". -// The handlers forward requests to the grpc endpoint over "conn". -func RegisterPeersV1Handler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error { - return RegisterPeersV1HandlerClient(ctx, mux, NewPeersV1Client(conn)) -} - -// RegisterPeersV1HandlerClient registers the http handlers for service PeersV1 -// to "mux". The handlers forward requests to the grpc endpoint over the given implementation of "PeersV1Client". -// Note: the gRPC framework executes interceptors within the gRPC handler. If the passed in "PeersV1Client" -// doesn't go through the normal gRPC flow (creating a gRPC client etc.) then it will be up to the passed in -// "PeersV1Client" to call the correct interceptors. -func RegisterPeersV1HandlerClient(ctx context.Context, mux *runtime.ServeMux, client PeersV1Client) error { - - mux.Handle("POST", pattern_PeersV1_GetPeerRateLimits_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/pb.gubernator.PeersV1/GetPeerRateLimits", runtime.WithHTTPPathPattern("/pb.gubernator.PeersV1/GetPeerRateLimits")) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := request_PeersV1_GetPeerRateLimits_0(annotatedContext, inboundMarshaler, client, req, pathParams) - annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) - if err != nil { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - return - } - - forward_PeersV1_GetPeerRateLimits_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - - }) - - mux.Handle("POST", pattern_PeersV1_UpdatePeerGlobals_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/pb.gubernator.PeersV1/UpdatePeerGlobals", runtime.WithHTTPPathPattern("/pb.gubernator.PeersV1/UpdatePeerGlobals")) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := request_PeersV1_UpdatePeerGlobals_0(annotatedContext, inboundMarshaler, client, req, pathParams) - annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) - if err != nil { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - return - } - - forward_PeersV1_UpdatePeerGlobals_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
- - }) - - return nil -} - -var ( - pattern_PeersV1_GetPeerRateLimits_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1}, []string{"pb.gubernator.PeersV1", "GetPeerRateLimits"}, "")) - - pattern_PeersV1_UpdatePeerGlobals_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1}, []string{"pb.gubernator.PeersV1", "UpdatePeerGlobals"}, "")) -) - -var ( - forward_PeersV1_GetPeerRateLimits_0 = runtime.ForwardResponseMessage - - forward_PeersV1_UpdatePeerGlobals_0 = runtime.ForwardResponseMessage -) diff --git a/peers_grpc.pb.go b/peers_grpc.pb.go deleted file mode 100644 index e74a7d1..0000000 --- a/peers_grpc.pb.go +++ /dev/null @@ -1,163 +0,0 @@ -// -//Copyright 2018-2022 Mailgun Technologies Inc -// -//Licensed under the Apache License, Version 2.0 (the "License"); -//you may not use this file except in compliance with the License. -//You may obtain a copy of the License at -// -//http://www.apache.org/licenses/LICENSE-2.0 -// -//Unless required by applicable law or agreed to in writing, software -//distributed under the License is distributed on an "AS IS" BASIS, -//WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -//See the License for the specific language governing permissions and -//limitations under the License. - -// Code generated by protoc-gen-go-grpc. DO NOT EDIT. -// versions: -// - protoc-gen-go-grpc v1.3.0 -// - protoc (unknown) -// source: peers.proto - -package gubernator - -import ( - context "context" - grpc "google.golang.org/grpc" - codes "google.golang.org/grpc/codes" - status "google.golang.org/grpc/status" -) - -// This is a compile-time assertion to ensure that this generated file -// is compatible with the grpc package it is being compiled against. -// Requires gRPC-Go v1.32.0 or later. -const _ = grpc.SupportPackageIsVersion7 - -const ( - PeersV1_GetPeerRateLimits_FullMethodName = "/pb.gubernator.PeersV1/GetPeerRateLimits" - PeersV1_UpdatePeerGlobals_FullMethodName = "/pb.gubernator.PeersV1/UpdatePeerGlobals" -) - -// PeersV1Client is the client API for PeersV1 service. -// -// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. -type PeersV1Client interface { - // Used by peers to relay batches of requests to an owner peer - GetPeerRateLimits(ctx context.Context, in *GetPeerRateLimitsReq, opts ...grpc.CallOption) (*GetPeerRateLimitsResp, error) - // Used by owner peers to send global rate limit updates to non-owner peers - UpdatePeerGlobals(ctx context.Context, in *UpdatePeerGlobalsReq, opts ...grpc.CallOption) (*UpdatePeerGlobalsResp, error) -} - -type peersV1Client struct { - cc grpc.ClientConnInterface -} - -func NewPeersV1Client(cc grpc.ClientConnInterface) PeersV1Client { - return &peersV1Client{cc} -} - -func (c *peersV1Client) GetPeerRateLimits(ctx context.Context, in *GetPeerRateLimitsReq, opts ...grpc.CallOption) (*GetPeerRateLimitsResp, error) { - out := new(GetPeerRateLimitsResp) - err := c.cc.Invoke(ctx, PeersV1_GetPeerRateLimits_FullMethodName, in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -func (c *peersV1Client) UpdatePeerGlobals(ctx context.Context, in *UpdatePeerGlobalsReq, opts ...grpc.CallOption) (*UpdatePeerGlobalsResp, error) { - out := new(UpdatePeerGlobalsResp) - err := c.cc.Invoke(ctx, PeersV1_UpdatePeerGlobals_FullMethodName, in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -// PeersV1Server is the server API for PeersV1 service. 
-// All implementations should embed UnimplementedPeersV1Server -// for forward compatibility -type PeersV1Server interface { - // Used by peers to relay batches of requests to an owner peer - GetPeerRateLimits(context.Context, *GetPeerRateLimitsReq) (*GetPeerRateLimitsResp, error) - // Used by owner peers to send global rate limit updates to non-owner peers - UpdatePeerGlobals(context.Context, *UpdatePeerGlobalsReq) (*UpdatePeerGlobalsResp, error) -} - -// UnimplementedPeersV1Server should be embedded to have forward compatible implementations. -type UnimplementedPeersV1Server struct { -} - -func (UnimplementedPeersV1Server) GetPeerRateLimits(context.Context, *GetPeerRateLimitsReq) (*GetPeerRateLimitsResp, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetPeerRateLimits not implemented") -} -func (UnimplementedPeersV1Server) UpdatePeerGlobals(context.Context, *UpdatePeerGlobalsReq) (*UpdatePeerGlobalsResp, error) { - return nil, status.Errorf(codes.Unimplemented, "method UpdatePeerGlobals not implemented") -} - -// UnsafePeersV1Server may be embedded to opt out of forward compatibility for this service. -// Use of this interface is not recommended, as added methods to PeersV1Server will -// result in compilation errors. -type UnsafePeersV1Server interface { - mustEmbedUnimplementedPeersV1Server() -} - -func RegisterPeersV1Server(s grpc.ServiceRegistrar, srv PeersV1Server) { - s.RegisterService(&PeersV1_ServiceDesc, srv) -} - -func _PeersV1_GetPeerRateLimits_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(GetPeerRateLimitsReq) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(PeersV1Server).GetPeerRateLimits(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: PeersV1_GetPeerRateLimits_FullMethodName, - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(PeersV1Server).GetPeerRateLimits(ctx, req.(*GetPeerRateLimitsReq)) - } - return interceptor(ctx, in, info, handler) -} - -func _PeersV1_UpdatePeerGlobals_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(UpdatePeerGlobalsReq) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(PeersV1Server).UpdatePeerGlobals(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: PeersV1_UpdatePeerGlobals_FullMethodName, - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(PeersV1Server).UpdatePeerGlobals(ctx, req.(*UpdatePeerGlobalsReq)) - } - return interceptor(ctx, in, info, handler) -} - -// PeersV1_ServiceDesc is the grpc.ServiceDesc for PeersV1 service. 
-// It's only intended for direct use with grpc.RegisterService, -// and not to be introspected or modified (even as a copy) -var PeersV1_ServiceDesc = grpc.ServiceDesc{ - ServiceName: "pb.gubernator.PeersV1", - HandlerType: (*PeersV1Server)(nil), - Methods: []grpc.MethodDesc{ - { - MethodName: "GetPeerRateLimits", - Handler: _PeersV1_GetPeerRateLimits_Handler, - }, - { - MethodName: "UpdatePeerGlobals", - Handler: _PeersV1_UpdatePeerGlobals_Handler, - }, - }, - Streams: []grpc.StreamDesc{}, - Metadata: "peers.proto", -} diff --git a/python/gubernator/__init__.py b/python/gubernator/__init__.py deleted file mode 100644 index b90c1a3..0000000 --- a/python/gubernator/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# This code is py3.7 and py2.7 compatible - -import gubernator.ratelimit_pb2_grpc as pb_grpc -from datetime import datetime - -import time -import grpc - -MILLISECOND = 1 -SECOND = MILLISECOND * 1000 -MINUTE = SECOND * 60 - - -def sleep_until_reset(reset_time): - now = datetime.now() - time.sleep((reset_time - now).seconds) - - -def V1Client(endpoint='127.0.0.1:9090'): - channel = grpc.insecure_channel(endpoint) - return pb_grpc.RateLimitServiceV1Stub(channel) diff --git a/python/gubernator/gubernator_pb2.py b/python/gubernator/gubernator_pb2.py index fe20ddc..6fe6165 100644 --- a/python/gubernator/gubernator_pb2.py +++ b/python/gubernator/gubernator_pb2.py @@ -12,47 +12,40 @@ _sym_db = _symbol_database.Default() -from google.api import annotations_pb2 as google_dot_api_dot_annotations__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10gubernator.proto\x12\rpb.gubernator\x1a\x1cgoogle/api/annotations.proto\"K\n\x10GetRateLimitsReq\x12\x37\n\x08requests\x18\x01 \x03(\x0b\x32\x1b.pb.gubernator.RateLimitReqR\x08requests\"O\n\x11GetRateLimitsResp\x12:\n\tresponses\x18\x01 \x03(\x0b\x32\x1c.pb.gubernator.RateLimitRespR\tresponses\"\xc1\x03\n\x0cRateLimitReq\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1d\n\nunique_key\x18\x02 \x01(\tR\tuniqueKey\x12\x12\n\x04hits\x18\x03 \x01(\x03R\x04hits\x12\x14\n\x05limit\x18\x04 \x01(\x03R\x05limit\x12\x1a\n\x08\x64uration\x18\x05 \x01(\x03R\x08\x64uration\x12\x36\n\talgorithm\x18\x06 \x01(\x0e\x32\x18.pb.gubernator.AlgorithmR\talgorithm\x12\x33\n\x08\x62\x65havior\x18\x07 \x01(\x0e\x32\x17.pb.gubernator.BehaviorR\x08\x62\x65havior\x12\x14\n\x05\x62urst\x18\x08 \x01(\x03R\x05\x62urst\x12\x45\n\x08metadata\x18\t \x03(\x0b\x32).pb.gubernator.RateLimitReq.MetadataEntryR\x08metadata\x12\"\n\ncreated_at\x18\n \x01(\x03H\x00R\tcreatedAt\x88\x01\x01\x1a;\n\rMetadataEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\r\n\x0b_created_at\"\xac\x02\n\rRateLimitResp\x12-\n\x06status\x18\x01 \x01(\x0e\x32\x15.pb.gubernator.StatusR\x06status\x12\x14\n\x05limit\x18\x02 \x01(\x03R\x05limit\x12\x1c\n\tremaining\x18\x03 \x01(\x03R\tremaining\x12\x1d\n\nreset_time\x18\x04 \x01(\x03R\tresetTime\x12\x14\n\x05\x65rror\x18\x05 \x01(\tR\x05\x65rror\x12\x46\n\x08metadata\x18\x06 \x03(\x0b\x32*.pb.gubernator.RateLimitResp.MetadataEntryR\x08metadata\x1a;\n\rMetadataEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\x10\n\x0eHealthCheckReq\"b\n\x0fHealthCheckResp\x12\x16\n\x06status\x18\x01 \x01(\tR\x06status\x12\x18\n\x07message\x18\x02 \x01(\tR\x07message\x12\x1d\n\npeer_count\x18\x03 
\x01(\x05R\tpeerCount*/\n\tAlgorithm\x12\x10\n\x0cTOKEN_BUCKET\x10\x00\x12\x10\n\x0cLEAKY_BUCKET\x10\x01*\x8d\x01\n\x08\x42\x65havior\x12\x0c\n\x08\x42\x41TCHING\x10\x00\x12\x0f\n\x0bNO_BATCHING\x10\x01\x12\n\n\x06GLOBAL\x10\x02\x12\x19\n\x15\x44URATION_IS_GREGORIAN\x10\x04\x12\x13\n\x0fRESET_REMAINING\x10\x08\x12\x10\n\x0cMULTI_REGION\x10\x10\x12\x14\n\x10\x44RAIN_OVER_LIMIT\x10 *)\n\x06Status\x12\x0f\n\x0bUNDER_LIMIT\x10\x00\x12\x0e\n\nOVER_LIMIT\x10\x01\x32\xdd\x01\n\x02V1\x12p\n\rGetRateLimits\x12\x1f.pb.gubernator.GetRateLimitsReq\x1a .pb.gubernator.GetRateLimitsResp\"\x1c\x82\xd3\xe4\x93\x02\x16\"\x11/v1/GetRateLimits:\x01*\x12\x65\n\x0bHealthCheck\x12\x1d.pb.gubernator.HealthCheckReq\x1a\x1e.pb.gubernator.HealthCheckResp\"\x17\x82\xd3\xe4\x93\x02\x11\x12\x0f/v1/HealthCheckB(Z#github.com/gubernator-io/gubernator\x80\x01\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10gubernator.proto\x12\rgubernator.v3\"U\n\x16\x43heckRateLimitsRequest\x12;\n\x08requests\x18\x01 \x03(\x0b\x32\x1f.gubernator.v3.RateLimitRequestR\x08requests\"Y\n\x17\x43heckRateLimitsResponse\x12>\n\tresponses\x18\x01 \x03(\x0b\x32 .gubernator.v3.RateLimitResponseR\tresponses\"\xca\x03\n\x10RateLimitRequest\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1e\n\nunique_key\x18\x02 \x01(\tR\nunique_key\x12\x12\n\x04hits\x18\x03 \x01(\x03R\x04hits\x12\x14\n\x05limit\x18\x04 \x01(\x03R\x05limit\x12\x1a\n\x08\x64uration\x18\x05 \x01(\x03R\x08\x64uration\x12\x36\n\talgorithm\x18\x06 \x01(\x0e\x32\x18.gubernator.v3.AlgorithmR\talgorithm\x12\x33\n\x08\x62\x65havior\x18\x07 \x01(\x0e\x32\x17.gubernator.v3.BehaviorR\x08\x62\x65havior\x12\x14\n\x05\x62urst\x18\x08 \x01(\x03R\x05\x62urst\x12I\n\x08metadata\x18\t \x03(\x0b\x32-.gubernator.v3.RateLimitRequest.MetadataEntryR\x08metadata\x12\"\n\ncreated_at\x18\n \x01(\x03H\x00R\tcreatedAt\x88\x01\x01\x1a;\n\rMetadataEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\r\n\x0b_created_at\"\xb5\x02\n\x11RateLimitResponse\x12-\n\x06status\x18\x01 \x01(\x0e\x32\x15.gubernator.v3.StatusR\x06status\x12\x14\n\x05limit\x18\x02 \x01(\x03R\x05limit\x12\x1c\n\tremaining\x18\x03 \x01(\x03R\tremaining\x12\x1e\n\nreset_time\x18\x04 \x01(\x03R\nreset_time\x12\x14\n\x05\x65rror\x18\x05 \x01(\tR\x05\x65rror\x12J\n\x08metadata\x18\x06 \x03(\x0b\x32..gubernator.v3.RateLimitResponse.MetadataEntryR\x08metadata\x1a;\n\rMetadataEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\"\x14\n\x12HealthCheckRequest\"g\n\x13HealthCheckResponse\x12\x16\n\x06status\x18\x01 \x01(\tR\x06status\x12\x18\n\x07message\x18\x02 \x01(\tR\x07message\x12\x1e\n\npeer_count\x18\x03 \x01(\x05R\npeer_count*/\n\tAlgorithm\x12\x10\n\x0cTOKEN_BUCKET\x10\x00\x12\x10\n\x0cLEAKY_BUCKET\x10\x01*\x8d\x01\n\x08\x42\x65havior\x12\x0c\n\x08\x42\x41TCHING\x10\x00\x12\x0f\n\x0bNO_BATCHING\x10\x01\x12\n\n\x06GLOBAL\x10\x02\x12\x19\n\x15\x44URATION_IS_GREGORIAN\x10\x04\x12\x13\n\x0fRESET_REMAINING\x10\x08\x12\x10\n\x0cMULTI_REGION\x10\x10\x12\x14\n\x10\x44RAIN_OVER_LIMIT\x10 *)\n\x06Status\x12\x0f\n\x0bUNDER_LIMIT\x10\x00\x12\x0e\n\nOVER_LIMIT\x10\x01\x42%Z#github.com/gubernator-io/gubernatorb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'gubernator_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: _globals['DESCRIPTOR']._loaded_options = None - _globals['DESCRIPTOR']._serialized_options = 
b'Z#github.com/gubernator-io/gubernator\200\001\001' - _globals['_RATELIMITREQ_METADATAENTRY']._loaded_options = None - _globals['_RATELIMITREQ_METADATAENTRY']._serialized_options = b'8\001' - _globals['_RATELIMITRESP_METADATAENTRY']._loaded_options = None - _globals['_RATELIMITRESP_METADATAENTRY']._serialized_options = b'8\001' - _globals['_V1'].methods_by_name['GetRateLimits']._loaded_options = None - _globals['_V1'].methods_by_name['GetRateLimits']._serialized_options = b'\202\323\344\223\002\026\"\021/v1/GetRateLimits:\001*' - _globals['_V1'].methods_by_name['HealthCheck']._loaded_options = None - _globals['_V1'].methods_by_name['HealthCheck']._serialized_options = b'\202\323\344\223\002\021\022\017/v1/HealthCheck' - _globals['_ALGORITHM']._serialized_start=1096 - _globals['_ALGORITHM']._serialized_end=1143 - _globals['_BEHAVIOR']._serialized_start=1146 - _globals['_BEHAVIOR']._serialized_end=1287 - _globals['_STATUS']._serialized_start=1289 - _globals['_STATUS']._serialized_end=1330 - _globals['_GETRATELIMITSREQ']._serialized_start=65 - _globals['_GETRATELIMITSREQ']._serialized_end=140 - _globals['_GETRATELIMITSRESP']._serialized_start=142 - _globals['_GETRATELIMITSRESP']._serialized_end=221 - _globals['_RATELIMITREQ']._serialized_start=224 - _globals['_RATELIMITREQ']._serialized_end=673 - _globals['_RATELIMITREQ_METADATAENTRY']._serialized_start=599 - _globals['_RATELIMITREQ_METADATAENTRY']._serialized_end=658 - _globals['_RATELIMITRESP']._serialized_start=676 - _globals['_RATELIMITRESP']._serialized_end=976 - _globals['_RATELIMITRESP_METADATAENTRY']._serialized_start=599 - _globals['_RATELIMITRESP_METADATAENTRY']._serialized_end=658 - _globals['_HEALTHCHECKREQ']._serialized_start=978 - _globals['_HEALTHCHECKREQ']._serialized_end=994 - _globals['_HEALTHCHECKRESP']._serialized_start=996 - _globals['_HEALTHCHECKRESP']._serialized_end=1094 - _globals['_V1']._serialized_start=1333 - _globals['_V1']._serialized_end=1554 + _globals['DESCRIPTOR']._serialized_options = b'Z#github.com/gubernator-io/gubernator' + _globals['_RATELIMITREQUEST_METADATAENTRY']._loaded_options = None + _globals['_RATELIMITREQUEST_METADATAENTRY']._serialized_options = b'8\001' + _globals['_RATELIMITRESPONSE_METADATAENTRY']._loaded_options = None + _globals['_RATELIMITRESPONSE_METADATAENTRY']._serialized_options = b'8\001' + _globals['_ALGORITHM']._serialized_start=1113 + _globals['_ALGORITHM']._serialized_end=1160 + _globals['_BEHAVIOR']._serialized_start=1163 + _globals['_BEHAVIOR']._serialized_end=1304 + _globals['_STATUS']._serialized_start=1306 + _globals['_STATUS']._serialized_end=1347 + _globals['_CHECKRATELIMITSREQUEST']._serialized_start=35 + _globals['_CHECKRATELIMITSREQUEST']._serialized_end=120 + _globals['_CHECKRATELIMITSRESPONSE']._serialized_start=122 + _globals['_CHECKRATELIMITSRESPONSE']._serialized_end=211 + _globals['_RATELIMITREQUEST']._serialized_start=214 + _globals['_RATELIMITREQUEST']._serialized_end=672 + _globals['_RATELIMITREQUEST_METADATAENTRY']._serialized_start=598 + _globals['_RATELIMITREQUEST_METADATAENTRY']._serialized_end=657 + _globals['_RATELIMITRESPONSE']._serialized_start=675 + _globals['_RATELIMITRESPONSE']._serialized_end=984 + _globals['_RATELIMITRESPONSE_METADATAENTRY']._serialized_start=598 + _globals['_RATELIMITRESPONSE_METADATAENTRY']._serialized_end=657 + _globals['_HEALTHCHECKREQUEST']._serialized_start=986 + _globals['_HEALTHCHECKREQUEST']._serialized_end=1006 + _globals['_HEALTHCHECKRESPONSE']._serialized_start=1008 + 
_globals['_HEALTHCHECKRESPONSE']._serialized_end=1111 # @@protoc_insertion_point(module_scope) diff --git a/python/gubernator/gubernator_pb2_grpc.py b/python/gubernator/gubernator_pb2_grpc.py deleted file mode 100644 index 02dd779..0000000 --- a/python/gubernator/gubernator_pb2_grpc.py +++ /dev/null @@ -1,102 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc - -import gubernator_pb2 as gubernator__pb2 - - -class V1Stub(object): - """Missing associated documentation comment in .proto file.""" - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. - """ - self.GetRateLimits = channel.unary_unary( - '/pb.gubernator.V1/GetRateLimits', - request_serializer=gubernator__pb2.GetRateLimitsReq.SerializeToString, - response_deserializer=gubernator__pb2.GetRateLimitsResp.FromString, - ) - self.HealthCheck = channel.unary_unary( - '/pb.gubernator.V1/HealthCheck', - request_serializer=gubernator__pb2.HealthCheckReq.SerializeToString, - response_deserializer=gubernator__pb2.HealthCheckResp.FromString, - ) - - -class V1Servicer(object): - """Missing associated documentation comment in .proto file.""" - - def GetRateLimits(self, request, context): - """Given a list of rate limit requests, return the rate limits of each. - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def HealthCheck(self, request, context): - """This method is for round trip benchmarking and can be used by - the client to determine connectivity to the server - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - -def add_V1Servicer_to_server(servicer, server): - rpc_method_handlers = { - 'GetRateLimits': grpc.unary_unary_rpc_method_handler( - servicer.GetRateLimits, - request_deserializer=gubernator__pb2.GetRateLimitsReq.FromString, - response_serializer=gubernator__pb2.GetRateLimitsResp.SerializeToString, - ), - 'HealthCheck': grpc.unary_unary_rpc_method_handler( - servicer.HealthCheck, - request_deserializer=gubernator__pb2.HealthCheckReq.FromString, - response_serializer=gubernator__pb2.HealthCheckResp.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - 'pb.gubernator.V1', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) - - - # This class is part of an EXPERIMENTAL API. 
-class V1(object): - """Missing associated documentation comment in .proto file.""" - - @staticmethod - def GetRateLimits(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/pb.gubernator.V1/GetRateLimits', - gubernator__pb2.GetRateLimitsReq.SerializeToString, - gubernator__pb2.GetRateLimitsResp.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - - @staticmethod - def HealthCheck(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/pb.gubernator.V1/HealthCheck', - gubernator__pb2.HealthCheckReq.SerializeToString, - gubernator__pb2.HealthCheckResp.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/python/gubernator/peer_pb2.py b/python/gubernator/peer_pb2.py new file mode 100644 index 0000000..6d5ef2b --- /dev/null +++ b/python/gubernator/peer_pb2.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: peer.proto +# Protobuf Python Version: 5.26.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +import gubernator_pb2 as gubernator__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\npeer.proto\x12\rgubernator.v3\x1a\x10gubernator.proto\"M\n\x0e\x46orwardRequest\x12;\n\x08requests\x18\x01 \x03(\x0b\x32\x1f.gubernator.v3.RateLimitRequestR\x08requests\"T\n\x0f\x46orwardResponse\x12\x41\n\x0brate_limits\x18\x01 \x03(\x0b\x32 .gubernator.v3.RateLimitResponseR\nrateLimits\"I\n\rUpdateRequest\x12\x38\n\x07globals\x18\x01 \x03(\x0b\x32\x1e.gubernator.v3.UpdateRateLimitR\x07globals\"\xce\x01\n\x0fUpdateRateLimit\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x36\n\x05state\x18\x02 \x01(\x0b\x32 .gubernator.v3.RateLimitResponseR\x05state\x12\x36\n\talgorithm\x18\x03 \x01(\x0e\x32\x18.gubernator.v3.AlgorithmR\talgorithm\x12\x1a\n\x08\x64uration\x18\x04 \x01(\x03R\x08\x64uration\x12\x1d\n\ncreated_at\x18\x05 \x01(\x03R\tcreatedAtB%Z#github.com/gubernator-io/gubernatorb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'peer_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + _globals['DESCRIPTOR']._loaded_options = None + _globals['DESCRIPTOR']._serialized_options = b'Z#github.com/gubernator-io/gubernator' + _globals['_FORWARDREQUEST']._serialized_start=47 + _globals['_FORWARDREQUEST']._serialized_end=124 + _globals['_FORWARDRESPONSE']._serialized_start=126 + _globals['_FORWARDRESPONSE']._serialized_end=210 + _globals['_UPDATEREQUEST']._serialized_start=212 + _globals['_UPDATEREQUEST']._serialized_end=285 + _globals['_UPDATERATELIMIT']._serialized_start=288 + _globals['_UPDATERATELIMIT']._serialized_end=494 +# @@protoc_insertion_point(module_scope) diff --git a/python/gubernator/peers_pb2.py 
b/python/gubernator/peers_pb2.py deleted file mode 100644 index e7121e0..0000000 --- a/python/gubernator/peers_pb2.py +++ /dev/null @@ -1,38 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: peers.proto -# Protobuf Python Version: 5.26.1 -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -import gubernator_pb2 as gubernator__pb2 - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0bpeers.proto\x12\rpb.gubernator\x1a\x10gubernator.proto\"O\n\x14GetPeerRateLimitsReq\x12\x37\n\x08requests\x18\x01 \x03(\x0b\x32\x1b.pb.gubernator.RateLimitReqR\x08requests\"V\n\x15GetPeerRateLimitsResp\x12=\n\x0brate_limits\x18\x01 \x03(\x0b\x32\x1c.pb.gubernator.RateLimitRespR\nrateLimits\"Q\n\x14UpdatePeerGlobalsReq\x12\x39\n\x07globals\x18\x01 \x03(\x0b\x32\x1f.pb.gubernator.UpdatePeerGlobalR\x07globals\"\xcd\x01\n\x10UpdatePeerGlobal\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x34\n\x06status\x18\x02 \x01(\x0b\x32\x1c.pb.gubernator.RateLimitRespR\x06status\x12\x36\n\talgorithm\x18\x03 \x01(\x0e\x32\x18.pb.gubernator.AlgorithmR\talgorithm\x12\x1a\n\x08\x64uration\x18\x04 \x01(\x03R\x08\x64uration\x12\x1d\n\ncreated_at\x18\x05 \x01(\x03R\tcreatedAt\"\x17\n\x15UpdatePeerGlobalsResp2\xcd\x01\n\x07PeersV1\x12`\n\x11GetPeerRateLimits\x12#.pb.gubernator.GetPeerRateLimitsReq\x1a$.pb.gubernator.GetPeerRateLimitsResp\"\x00\x12`\n\x11UpdatePeerGlobals\x12#.pb.gubernator.UpdatePeerGlobalsReq\x1a$.pb.gubernator.UpdatePeerGlobalsResp\"\x00\x42(Z#github.com/gubernator-io/gubernator\x80\x01\x01\x62\x06proto3') - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'peers_pb2', _globals) -if not _descriptor._USE_C_DESCRIPTORS: - _globals['DESCRIPTOR']._loaded_options = None - _globals['DESCRIPTOR']._serialized_options = b'Z#github.com/gubernator-io/gubernator\200\001\001' - _globals['_GETPEERRATELIMITSREQ']._serialized_start=48 - _globals['_GETPEERRATELIMITSREQ']._serialized_end=127 - _globals['_GETPEERRATELIMITSRESP']._serialized_start=129 - _globals['_GETPEERRATELIMITSRESP']._serialized_end=215 - _globals['_UPDATEPEERGLOBALSREQ']._serialized_start=217 - _globals['_UPDATEPEERGLOBALSREQ']._serialized_end=298 - _globals['_UPDATEPEERGLOBAL']._serialized_start=301 - _globals['_UPDATEPEERGLOBAL']._serialized_end=506 - _globals['_UPDATEPEERGLOBALSRESP']._serialized_start=508 - _globals['_UPDATEPEERGLOBALSRESP']._serialized_end=531 - _globals['_PEERSV1']._serialized_start=534 - _globals['_PEERSV1']._serialized_end=739 -# @@protoc_insertion_point(module_scope) diff --git a/python/gubernator/peers_pb2_grpc.py b/python/gubernator/peers_pb2_grpc.py deleted file mode 100644 index 9ebb860..0000000 --- a/python/gubernator/peers_pb2_grpc.py +++ /dev/null @@ -1,104 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc - -import peers_pb2 as peers__pb2 - - -class PeersV1Stub(object): - """NOTE: For use by gubernator peers only - """ - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. 
- """ - self.GetPeerRateLimits = channel.unary_unary( - '/pb.gubernator.PeersV1/GetPeerRateLimits', - request_serializer=peers__pb2.GetPeerRateLimitsReq.SerializeToString, - response_deserializer=peers__pb2.GetPeerRateLimitsResp.FromString, - ) - self.UpdatePeerGlobals = channel.unary_unary( - '/pb.gubernator.PeersV1/UpdatePeerGlobals', - request_serializer=peers__pb2.UpdatePeerGlobalsReq.SerializeToString, - response_deserializer=peers__pb2.UpdatePeerGlobalsResp.FromString, - ) - - -class PeersV1Servicer(object): - """NOTE: For use by gubernator peers only - """ - - def GetPeerRateLimits(self, request, context): - """Used by peers to relay batches of requests to an owner peer - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def UpdatePeerGlobals(self, request, context): - """Used by owner peers to send global rate limit updates to non-owner peers - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - -def add_PeersV1Servicer_to_server(servicer, server): - rpc_method_handlers = { - 'GetPeerRateLimits': grpc.unary_unary_rpc_method_handler( - servicer.GetPeerRateLimits, - request_deserializer=peers__pb2.GetPeerRateLimitsReq.FromString, - response_serializer=peers__pb2.GetPeerRateLimitsResp.SerializeToString, - ), - 'UpdatePeerGlobals': grpc.unary_unary_rpc_method_handler( - servicer.UpdatePeerGlobals, - request_deserializer=peers__pb2.UpdatePeerGlobalsReq.FromString, - response_serializer=peers__pb2.UpdatePeerGlobalsResp.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - 'pb.gubernator.PeersV1', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) - - - # This class is part of an EXPERIMENTAL API. 
-class PeersV1(object): - """NOTE: For use by gubernator peers only - """ - - @staticmethod - def GetPeerRateLimits(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/pb.gubernator.PeersV1/GetPeerRateLimits', - peers__pb2.GetPeerRateLimitsReq.SerializeToString, - peers__pb2.GetPeerRateLimitsResp.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - - @staticmethod - def UpdatePeerGlobals(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/pb.gubernator.PeersV1/UpdatePeerGlobals', - peers__pb2.UpdatePeerGlobalsReq.SerializeToString, - peers__pb2.UpdatePeerGlobalsResp.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/python/requirements-py2.txt b/python/requirements-py2.txt deleted file mode 100644 index bd761a7..0000000 --- a/python/requirements-py2.txt +++ /dev/null @@ -1,15 +0,0 @@ -atomicwrites==1.3.0 -attrs==18.2.0 -enum34==1.1.6 -funcsigs==1.0.2 -futures==3.2.0 -googleapis-common-protos==1.5.8 -grpcio==1.53.2 -more-itertools==5.0.0 -pathlib2==2.3.3 -pluggy==0.8.1 -protobuf==3.18.3 -py==1.10.0 -pytest==7.2.0 -scandir==1.9.0 -six==1.12.0 diff --git a/python/requirements-py3.txt b/python/requirements-py3.txt deleted file mode 100644 index 00d54ad..0000000 --- a/python/requirements-py3.txt +++ /dev/null @@ -1,11 +0,0 @@ -atomicwrites==1.3.0 -attrs==18.2.0 -googleapis-common-protos==1.5.8 -grpcio==1.53.2 -grpcio-tools==1.19.0 -more-itertools==6.0.0 -pluggy==0.8.1 -protobuf==3.18.3 -py==1.10.0 -pytest==7.2.0 -six==1.12.0 diff --git a/python/setup.py b/python/setup.py deleted file mode 100755 index 6657f3f..0000000 --- a/python/setup.py +++ /dev/null @@ -1,57 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -# Copyright 2018-2022 Mailgun Technologies Inc -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -try: # for pip >= 10 - from pip._internal.req import parse_requirements -except ImportError: # for pip <= 9.0.3 - from pip.req import parse_requirements -from setuptools import setup, find_packages -import platform - -with open('version', 'r') as version_file: - version = version_file.readline().strip() - -if platform.python_version_tuple()[0] == '2': - reqs = parse_requirements('requirements-py2.txt', session='') -else: - reqs = parse_requirements('requirements-py3.txt', session='') - -requirements = [str(r.req) for r in reqs] - -setup( - name='gubernator', - version='0.1.0', - description="Python client for gubernator", - author="Derrick J. 
Wippler", - author_email='thrawn01@gmail.com', - url='https://github.com/gubernator-io/gubernator', - package_dir={'': '.'}, - packages=find_packages('.', exclude=['tests']), - install_requires=requirements, - license="Apache Software License 2.0", - python_requires='>=2.7', - classifiers=[ - 'Development Status :: 4 - Beta', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: Apache Software License', - 'Natural Language :: English', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - ], -) diff --git a/python/tests/__init__.py b/python/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/python/tests/test_client.py b/python/tests/test_client.py deleted file mode 100644 index 28efabf..0000000 --- a/python/tests/test_client.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright 2018-2022 Mailgun Technologies Inc -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -from gubernator import ratelimit_pb2 as pb - -import pytest -import subprocess -import os -import gubernator - - -@pytest.fixture(scope='module') -def cluster(): - args = ["/bin/sh", "-c", - "go run ./cmd/gubernator-cluster/main.go"] - - os.chdir("golang") - proc = subprocess.Popen(args, stdout=subprocess.PIPE) - os.chdir("..") - - while True: - line = proc.stdout.readline() - if b'Ready' in line: - break - yield proc - proc.kill() - - -def test_health_check(cluster): - client = gubernator.V1Client() - resp = client.health_check() - print("Health:", resp) - - -def test_get_rate_limit(cluster): - req = pb.Requests() - rate_limit = req.requests.add() - - rate_limit.algorithm = pb.TOKEN_BUCKET - rate_limit.duration = gubernator.SECOND * 2 - rate_limit.limit = 10 - rate_limit.namespace = 'test-ns' - rate_limit.unique_key = 'domain-id-0001' - rate_limit.hits = 1 - - client = gubernator.V1Client() - resp = client.GetRateLimits(req, timeout=0.5) - print("RateLimit: {}".format(resp)) diff --git a/region_picker.go b/region_picker.go index 4bef59d..6bd8d0b 100644 --- a/region_picker.go +++ b/region_picker.go @@ -17,11 +17,11 @@ limitations under the License. 
package gubernator type RegionPeerPicker interface { - GetClients(string) ([]*PeerClient, error) - GetByPeerInfo(PeerInfo) *PeerClient + GetClients(string) ([]*Peer, error) + GetByPeerInfo(PeerInfo) *Peer Pickers() map[string]PeerPicker - Peers() []*PeerClient - Add(*PeerClient) + Peers() []*Peer + Add(*Peer) New() RegionPeerPicker } @@ -32,14 +32,14 @@ type RegionPicker struct { // A map of all the pickers by region regions map[string]PeerPicker // The implementation of picker we will use for each region - reqQueue chan *RateLimitReq + reqQueue chan *RateLimitRequest } func NewRegionPicker(fn HashString64) *RegionPicker { rp := &RegionPicker{ regions: make(map[string]PeerPicker), - reqQueue: make(chan *RateLimitReq), - ReplicatedConsistentHash: NewReplicatedConsistentHash(fn, defaultReplicas), + reqQueue: make(chan *RateLimitRequest), + ReplicatedConsistentHash: NewReplicatedConsistentHash(fn, DefaultReplicas), } return rp } @@ -48,14 +48,14 @@ func (rp *RegionPicker) New() RegionPeerPicker { hash := rp.ReplicatedConsistentHash.New().(*ReplicatedConsistentHash) return &RegionPicker{ regions: make(map[string]PeerPicker), - reqQueue: make(chan *RateLimitReq), + reqQueue: make(chan *RateLimitRequest), ReplicatedConsistentHash: hash, } } // GetClients returns all the PeerClients that match this key in all regions -func (rp *RegionPicker) GetClients(key string) ([]*PeerClient, error) { - result := make([]*PeerClient, len(rp.regions)) +func (rp *RegionPicker) GetClients(key string) ([]*Peer, error) { + result := make([]*Peer, len(rp.regions)) var i int for _, picker := range rp.regions { peer, err := picker.Get(key) @@ -69,7 +69,7 @@ func (rp *RegionPicker) GetClients(key string) ([]*PeerClient, error) { } // GetByPeerInfo returns the first PeerClient the PeerInfo.HasKey() matches -func (rp *RegionPicker) GetByPeerInfo(info PeerInfo) *PeerClient { +func (rp *RegionPicker) GetByPeerInfo(info PeerInfo) *Peer { for _, picker := range rp.regions { if client := picker.GetByPeerInfo(info); client != nil { return client @@ -83,8 +83,8 @@ func (rp *RegionPicker) Pickers() map[string]PeerPicker { return rp.regions } -func (rp *RegionPicker) Peers() []*PeerClient { - var peers []*PeerClient +func (rp *RegionPicker) Peers() []*Peer { + var peers []*Peer for _, picker := range rp.regions { peers = append(peers, picker.Peers()...) 
@@ -93,7 +93,7 @@ func (rp *RegionPicker) Peers() []*PeerClient { return peers } -func (rp *RegionPicker) Add(peer *PeerClient) { +func (rp *RegionPicker) Add(peer *Peer) { picker, ok := rp.regions[peer.Info().DataCenter] if !ok { picker = rp.ReplicatedConsistentHash.New() diff --git a/replicated_hash.go b/replicated_hash.go index c53504e..c5c4586 100644 --- a/replicated_hash.go +++ b/replicated_hash.go @@ -26,29 +26,38 @@ import ( "github.com/segmentio/fasthash/fnv1" ) -const defaultReplicas = 512 +type PeerPicker interface { + GetByPeerInfo(PeerInfo) *Peer + Peers() []*Peer + Get(string) (*Peer, error) + New() PeerPicker + Add(*Peer) +} + +// DefaultReplicas is the number of replicas the hashmap will create by default +const DefaultReplicas = 512 type HashString64 func(data string) uint64 var defaultHashString64 HashString64 = fnv1.HashString64 -// Implements PeerPicker +// ReplicatedConsistentHash implements PeerPicker type ReplicatedConsistentHash struct { hashFunc HashString64 peerKeys []peerInfo - peers map[string]*PeerClient + peers map[string]*Peer replicas int } type peerInfo struct { hash uint64 - peer *PeerClient + peer *Peer } func NewReplicatedConsistentHash(fn HashString64, replicas int) *ReplicatedConsistentHash { ch := &ReplicatedConsistentHash{ hashFunc: fn, - peers: make(map[string]*PeerClient), + peers: make(map[string]*Peer), replicas: replicas, } @@ -61,24 +70,24 @@ func NewReplicatedConsistentHash(fn HashString64, replicas int) *ReplicatedConsi func (ch *ReplicatedConsistentHash) New() PeerPicker { return &ReplicatedConsistentHash{ hashFunc: ch.hashFunc, - peers: make(map[string]*PeerClient), + peers: make(map[string]*Peer), replicas: ch.replicas, } } -func (ch *ReplicatedConsistentHash) Peers() []*PeerClient { - var results []*PeerClient +func (ch *ReplicatedConsistentHash) Peers() []*Peer { + var results []*Peer for _, v := range ch.peers { results = append(results, v) } return results } -// Adds a peer to the hash -func (ch *ReplicatedConsistentHash) Add(peer *PeerClient) { - ch.peers[peer.Info().GRPCAddress] = peer +// Add a peer to the hash +func (ch *ReplicatedConsistentHash) Add(peer *Peer) { + ch.peers[peer.Info().HTTPAddress] = peer - key := fmt.Sprintf("%x", md5.Sum([]byte(peer.Info().GRPCAddress))) + key := fmt.Sprintf("%x", md5.Sum([]byte(peer.Info().HTTPAddress))) for i := 0; i < ch.replicas; i++ { hash := ch.hashFunc(strconv.Itoa(i) + key) ch.peerKeys = append(ch.peerKeys, peerInfo{ @@ -90,18 +99,18 @@ func (ch *ReplicatedConsistentHash) Add(peer *PeerClient) { sort.Slice(ch.peerKeys, func(i, j int) bool { return ch.peerKeys[i].hash < ch.peerKeys[j].hash }) } -// Returns number of peers in the picker +// Size returns number of peers in the picker func (ch *ReplicatedConsistentHash) Size() int { return len(ch.peers) } -// Returns the peer by hostname -func (ch *ReplicatedConsistentHash) GetByPeerInfo(peer PeerInfo) *PeerClient { - return ch.peers[peer.GRPCAddress] +// GetByPeerInfo returns the peer by hostname +func (ch *ReplicatedConsistentHash) GetByPeerInfo(peer PeerInfo) *Peer { + return ch.peers[peer.HTTPAddress] } -// Given a key, return the peer that key is assigned too -func (ch *ReplicatedConsistentHash) Get(key string) (*PeerClient, error) { +// Get returns the peer that key is assigned to +func (ch *ReplicatedConsistentHash) Get(key string) (*Peer, error) { if ch.Size() == 0 { return nil, errors.New("unable to pick a peer; pool is empty") } diff --git a/replicated_hash_test.go b/replicated_hash_test.go index 699808b..7cd9c84 100644 ---
a/replicated_hash_test.go +++ b/replicated_hash_test.go @@ -14,12 +14,13 @@ See the License for the specific language governing permissions and limitations under the License. */ -package gubernator +package gubernator_test import ( "net" "testing" + guber "github.com/gubernator-io/gubernator/v3" "github.com/segmentio/fasthash/fnv1" "github.com/segmentio/fasthash/fnv1a" "github.com/stretchr/testify/assert" @@ -29,27 +30,27 @@ func TestReplicatedConsistentHash(t *testing.T) { hosts := []string{"a.svc.local", "b.svc.local", "c.svc.local"} t.Run("Size", func(t *testing.T) { - hash := NewReplicatedConsistentHash(nil, defaultReplicas) + hash := guber.NewReplicatedConsistentHash(nil, guber.DefaultReplicas) for _, h := range hosts { - hash.Add(&PeerClient{conf: PeerConfig{Info: PeerInfo{GRPCAddress: h}}}) + hash.Add(&guber.Peer{Conf: guber.PeerConfig{Info: guber.PeerInfo{HTTPAddress: h}}}) } assert.Equal(t, len(hosts), hash.Size()) }) t.Run("Host", func(t *testing.T) { - hash := NewReplicatedConsistentHash(nil, defaultReplicas) - hostMap := map[string]*PeerClient{} + hash := guber.NewReplicatedConsistentHash(nil, guber.DefaultReplicas) + hostMap := make(map[string]*guber.Peer) for _, h := range hosts { - peer := &PeerClient{conf: PeerConfig{Info: PeerInfo{GRPCAddress: h}}} + peer := &guber.Peer{Conf: guber.PeerConfig{Info: guber.PeerInfo{HTTPAddress: h}}} hash.Add(peer) hostMap[h] = peer } for host, peer := range hostMap { - assert.Equal(t, peer, hash.GetByPeerInfo(PeerInfo{GRPCAddress: host})) + assert.Equal(t, peer, hash.GetByPeerInfo(guber.PeerInfo{HTTPAddress: host})) } }) @@ -62,7 +63,7 @@ func TestReplicatedConsistentHash(t *testing.T) { for _, tc := range []struct { name string - inHashFunc HashString64 + inHashFunc guber.HashString64 outDistribution map[string]int }{{ name: "default", @@ -83,17 +84,17 @@ func TestReplicatedConsistentHash(t *testing.T) { }, }} { t.Run(tc.name, func(t *testing.T) { - hash := NewReplicatedConsistentHash(tc.inHashFunc, defaultReplicas) + hash := guber.NewReplicatedConsistentHash(tc.inHashFunc, guber.DefaultReplicas) distribution := make(map[string]int) for _, h := range hosts { - hash.Add(&PeerClient{conf: PeerConfig{Info: PeerInfo{GRPCAddress: h}}}) + hash.Add(&guber.Peer{Conf: guber.PeerConfig{Info: guber.PeerInfo{HTTPAddress: h}}}) distribution[h] = 0 } for i := range strings { peer, _ := hash.Get(strings[i]) - distribution[peer.Info().GRPCAddress]++ + distribution[peer.Info().HTTPAddress]++ } assert.Equal(t, tc.outDistribution, distribution) }) @@ -103,7 +104,7 @@ func TestReplicatedConsistentHash(t *testing.T) { } func BenchmarkReplicatedConsistantHash(b *testing.B) { - hashFuncs := map[string]HashString64{ + hashFuncs := map[string]guber.HashString64{ "fasthash/fnv1a": fnv1a.HashString64, "fasthash/fnv1": fnv1.HashString64, } @@ -115,10 +116,10 @@ func BenchmarkReplicatedConsistantHash(b *testing.B) { ips[i] = net.IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)).String() } - hash := NewReplicatedConsistentHash(hashFunc, defaultReplicas) + hash := guber.NewReplicatedConsistentHash(hashFunc, guber.DefaultReplicas) hosts := []string{"a.svc.local", "b.svc.local", "c.svc.local"} for _, h := range hosts { - hash.Add(&PeerClient{conf: PeerConfig{Info: PeerInfo{GRPCAddress: h}}}) + hash.Add(&guber.Peer{Conf: guber.PeerConfig{Info: guber.PeerInfo{HTTPAddress: h}}}) } b.ResetTimer() diff --git a/staticbuilder.go b/staticbuilder.go deleted file mode 100644 index 9bbd832..0000000 --- a/staticbuilder.go +++ /dev/null @@ -1,45 +0,0 @@ -package gubernator - -import 
( - "strings" - - "google.golang.org/grpc/resolver" -) - -type staticBuilder struct{} - -var _ resolver.Builder = (*staticBuilder)(nil) - -func (sb *staticBuilder) Scheme() string { - return "static" -} - -func (sb *staticBuilder) Build(target resolver.Target, cc resolver.ClientConn, _ resolver.BuildOptions) (resolver.Resolver, error) { - var resolverAddrs []resolver.Address - for _, address := range strings.Split(target.Endpoint(), ",") { - resolverAddrs = append(resolverAddrs, resolver.Address{ - Addr: address, - ServerName: address, - }) - } - if err := cc.UpdateState(resolver.State{Addresses: resolverAddrs}); err != nil { - return nil, err - } - return &staticResolver{cc: cc}, nil -} - -// NewStaticBuilder returns a builder which returns a staticResolver that tells GRPC -// to connect a specific peer in the cluster. -func NewStaticBuilder() resolver.Builder { - return &staticBuilder{} -} - -type staticResolver struct { - cc resolver.ClientConn -} - -func (sr *staticResolver) ResolveNow(_ resolver.ResolveNowOptions) {} - -func (sr *staticResolver) Close() {} - -var _ resolver.Resolver = (*staticResolver)(nil) diff --git a/store.go b/store.go index 1c23461..e8c18a9 100644 --- a/store.go +++ b/store.go @@ -16,7 +16,9 @@ limitations under the License. package gubernator -import "context" +import ( + "context" +) // PERSISTENT STORE DETAILS @@ -26,20 +28,23 @@ import "context" // and `Get()` to keep the in memory cache and persistent store up to date with the latest ratelimit data. // Both interfaces can be implemented simultaneously to ensure data is always saved to persistent storage. +// LeakyBucketItem is 40 bytes aligned in size type LeakyBucketItem struct { - Limit int64 - Duration int64 - Remaining float64 - UpdatedAt int64 - Burst int64 + Limit int64 // 8 bytes + Duration int64 // 8 bytes + Remaining float64 // 8 bytes + UpdatedAt int64 // 8 bytes + Burst int64 // 8 bytes } +// TokenBucketItem is 40 bytes aligned in size type TokenBucketItem struct { - Status Status - Limit int64 - Duration int64 - Remaining int64 - CreatedAt int64 + Limit int64 // 8 bytes + Duration int64 // 8 bytes + Remaining int64 // 8 bytes + CreatedAt int64 // 8 bytes + Status Status // 4 bytes + // 4 bytes of padding } // Store interface allows implementors to off load storage of all or a subset of ratelimits to @@ -47,18 +52,18 @@ type TokenBucketItem struct { // to maximize performance of gubernator. // Implementations MUST be threadsafe. type Store interface { - // Called by gubernator *after* a rate limit item is updated. It's up to the store to + // OnChange is called by gubernator *after* a rate limit item is updated. It's up to the store to // decide if this rate limit item should be persisted in the store. It's up to the // store to expire old rate limit items. The CacheItem represents the current state of - // the rate limit item *after* the RateLimitReq has been applied. - OnChange(ctx context.Context, r *RateLimitReq, item *CacheItem) + // the rate limit item *after* the RateLimitRequest has been applied. + OnChange(ctx context.Context, r *RateLimitRequest, item *CacheItem) - // Called by gubernator when a rate limit is missing from the cache. It's up to the store + // Get is called by gubernator when a rate limit is missing from the cache. It's up to the store // to decide if this request is fulfilled. Should return true if the request is fulfilled // and false if the request is not fulfilled or doesn't exist in the store. 
- Get(ctx context.Context, r *RateLimitReq) (*CacheItem, bool) + Get(ctx context.Context, r *RateLimitRequest) (*CacheItem, bool) - // Called by gubernator when an existing rate limit should be removed from the store. + // Remove is called by gubernator when an existing rate limit should be removed from the store. // NOTE: This is NOT called when an rate limit expires from the cache, store implementors // must expire rate limits in the store. Remove(ctx context.Context, key string) @@ -95,12 +100,12 @@ type MockStore struct { var _ Store = &MockStore{} -func (ms *MockStore) OnChange(ctx context.Context, r *RateLimitReq, item *CacheItem) { +func (ms *MockStore) OnChange(ctx context.Context, r *RateLimitRequest, item *CacheItem) { ms.Called["OnChange()"] += 1 ms.CacheItems[item.Key] = item } -func (ms *MockStore) Get(ctx context.Context, r *RateLimitReq) (*CacheItem, bool) { +func (ms *MockStore) Get(ctx context.Context, r *RateLimitRequest) (*CacheItem, bool) { ms.Called["Get()"] += 1 item, ok := ms.CacheItems[r.HashKey()] return item, ok diff --git a/store_test.go b/store_test.go index e7c58f6..25efbae 100644 --- a/store_test.go +++ b/store_test.go @@ -19,80 +19,43 @@ package gubernator_test import ( "context" "fmt" - "net" + "sync" "testing" - "github.com/gubernator-io/gubernator/v2" - "github.com/mailgun/holster/v4/clock" + "github.com/gubernator-io/gubernator/v3" + "github.com/kapetan-io/tackle/clock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - "google.golang.org/grpc" ) -type v1Server struct { - conf gubernator.Config - listener net.Listener - srv *gubernator.V1Instance -} - -func (s *v1Server) Close() error { - s.conf.GRPCServers[0].GracefulStop() - return s.srv.Close() -} - -// Start a single instance of V1Server with the provided config and listening address.
-func newV1Server(t *testing.T, address string, conf gubernator.Config) *v1Server { - t.Helper() - conf.GRPCServers = append(conf.GRPCServers, grpc.NewServer()) - - srv, err := gubernator.NewV1Instance(conf) - require.NoError(t, err) - - listener, err := net.Listen("tcp", address) - require.NoError(t, err) - - go func() { - if err := conf.GRPCServers[0].Serve(listener); err != nil { - fmt.Printf("while serving: %s\n", err) - } - }() - - srv.SetPeers([]gubernator.PeerInfo{{GRPCAddress: listener.Addr().String(), IsOwner: true}}) - - ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) - - err = gubernator.WaitForConnect(ctx, []string{listener.Addr().String()}) - require.NoError(t, err) - cancel() - - return &v1Server{ - conf: conf, - listener: listener, - srv: srv, - } -} - func TestLoader(t *testing.T) { loader := gubernator.NewMockLoader() - srv := newV1Server(t, "localhost:0", gubernator.Config{ + d, err := gubernator.SpawnDaemon(context.Background(), gubernator.DaemonConfig{ + HTTPListenAddress: "localhost:0", Behaviors: gubernator.BehaviorConfig{ + // Suitable for testing but not production GlobalSyncWait: clock.Millisecond * 50, // Suitable for testing but not production GlobalTimeout: clock.Second, }, Loader: loader, }) + assert.NoError(t, err) + conf := d.Config() + d.SetPeers([]gubernator.PeerInfo{{HTTPAddress: conf.HTTPListenAddress, IsOwner: true}}) + // loader.Load() should have been called for gubernator startup assert.Equal(t, 1, loader.Called["Load()"]) assert.Equal(t, 0, loader.Called["Save()"]) - client, err := gubernator.DialV1Server(srv.listener.Addr().String(), nil) - assert.Nil(t, err) + client, err := gubernator.NewClient(gubernator.WithNoTLS(d.Listener.Addr().String())) + assert.NoError(t, err) - resp, err := client.GetRateLimits(context.Background(), &gubernator.GetRateLimitsReq{ - Requests: []*gubernator.RateLimitReq{ + var resp gubernator.CheckRateLimitsResponse + err = client.CheckRateLimits(context.Background(), &gubernator.CheckRateLimitsRequest{ + Requests: []*gubernator.RateLimitRequest{ { Name: "test_over_limit", UniqueKey: "account:1234", @@ -102,14 +65,12 @@ func TestLoader(t *testing.T) { Hits: 1, }, }, - }) - require.Nil(t, err) - require.NotNil(t, resp) + }, &resp) + require.NoError(t, err) require.Equal(t, 1, len(resp.Responses)) require.Equal(t, "", resp.Responses[0].Error) - err = srv.Close() - require.NoError(t, err, "Error in srv.Close") + d.Close(context.Background()) // Loader.Save() should been called during gubernator shutdown assert.Equal(t, 1, loader.Called["Load()"]) @@ -124,33 +85,138 @@ func TestLoader(t *testing.T) { assert.Equal(t, gubernator.Status_UNDER_LIMIT, item.Status) } +type NoOpStore struct{} + +func (ms *NoOpStore) Remove(ctx context.Context, key string) {} +func (ms *NoOpStore) OnChange(ctx context.Context, r *gubernator.RateLimitRequest, item *gubernator.CacheItem) { +} + +func (ms *NoOpStore) Get(ctx context.Context, r *gubernator.RateLimitRequest) (*gubernator.CacheItem, bool) { + return &gubernator.CacheItem{ + Algorithm: gubernator.Algorithm_TOKEN_BUCKET, + Key: r.HashKey(), + Value: gubernator.TokenBucketItem{ + CreatedAt: gubernator.MillisecondNow(), + Duration: gubernator.Minute * 60, + Limit: 1_000, + Remaining: 1_000, + Status: 0, + }, + ExpireAt: 0, + }, true +} + +// The goal of this test is to generate some race conditions where multiple routines load from the store and or +// add items to the cache in parallel thus creating a race condition the code must then handle. 
+func TestHighContentionFromStore(t *testing.T) { + const ( + // Increase these number to improve the chance of contention, but at the cost of test speed. + numGoroutines = 150 + numKeys = 100 + ) + store := &NoOpStore{} + d, err := gubernator.SpawnDaemon(context.Background(), gubernator.DaemonConfig{ + HTTPListenAddress: "localhost:0", + Behaviors: gubernator.BehaviorConfig{ + // Suitable for testing but not production + GlobalSyncWait: clock.Millisecond * 50, // Suitable for testing but not production + GlobalTimeout: clock.Second, + }, + Store: store, + }) + require.NoError(t, err) + d.SetPeers([]gubernator.PeerInfo{{HTTPAddress: d.Config().HTTPListenAddress, IsOwner: true}}) + + keys := GenerateRandomKeys(numKeys) + + var wg sync.WaitGroup + var ready sync.WaitGroup + wg.Add(numGoroutines) + ready.Add(numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func() { + // Create a client for each concurrent request to avoid contention in the client + client, err := gubernator.NewClient(gubernator.WithNoTLS(d.Listener.Addr().String())) + require.NoError(t, err) + ready.Wait() + for idx := 0; idx < numKeys; idx++ { + var resp gubernator.CheckRateLimitsResponse + err := client.CheckRateLimits(context.Background(), &gubernator.CheckRateLimitsRequest{ + Requests: []*gubernator.RateLimitRequest{ + { + Name: keys[idx], + UniqueKey: "high_contention_", + Algorithm: gubernator.Algorithm_TOKEN_BUCKET, + Duration: gubernator.Minute * 60, + Limit: numKeys, + Hits: 1, + }, + }, + }, &resp) + if err != nil { + // NOTE: you may see `connection reset by peer` if the server is overloaded + // and needs to forcibly drop some connections due to out of open file handlers etc... + fmt.Printf("%s\n", err) + } + } + wg.Done() + }() + ready.Done() + } + wg.Wait() + + for idx := 0; idx < numKeys; idx++ { + var resp gubernator.CheckRateLimitsResponse + err := d.MustClient().CheckRateLimits(context.Background(), &gubernator.CheckRateLimitsRequest{ + Requests: []*gubernator.RateLimitRequest{ + { + Name: keys[idx], + UniqueKey: "high_contention_", + Algorithm: gubernator.Algorithm_TOKEN_BUCKET, + Duration: gubernator.Minute * 60, + Limit: numKeys, + Hits: 0, + }, + }, + }, &resp) + require.NoError(t, err) + assert.Equal(t, int64(0), resp.Responses[0].Remaining) + } + + assert.NoError(t, d.Close(context.Background())) +} + func TestStore(t *testing.T) { ctx := context.Background() - setup := func() (*MockStore2, *v1Server, gubernator.V1Client) { + setup := func() (*MockStore2, *gubernator.Daemon, gubernator.Client) { store := &MockStore2{} - srv := newV1Server(t, "localhost:0", gubernator.Config{ + d, err := gubernator.SpawnDaemon(context.Background(), gubernator.DaemonConfig{ + HTTPListenAddress: "localhost:0", Behaviors: gubernator.BehaviorConfig{ - GlobalSyncWait: clock.Millisecond * 50, // Suitable for testing but not production + GlobalSyncWait: clock.Millisecond * 50, GlobalTimeout: clock.Second, }, Store: store, }) + assert.NoError(t, err) + conf := d.Config() + d.SetPeers([]gubernator.PeerInfo{{HTTPAddress: conf.HTTPListenAddress, IsOwner: true}}) - client, err := gubernator.DialV1Server(srv.listener.Addr().String(), nil) + client, err := gubernator.NewClient(gubernator.WithNoTLS(d.Listener.Addr().String())) require.NoError(t, err) - return store, srv, client + return store, d, client } - tearDown := func(srv *v1Server) { - err := srv.Close() - require.NoError(t, err) + tearDown := func(d *gubernator.Daemon) { + d.Close(context.Background()) } // Create a mock argument matcher for a request by name/key. 
- matchReq := func(req *gubernator.RateLimitReq) interface{} { - return mock.MatchedBy(func(req2 *gubernator.RateLimitReq) bool { + matchReq := func(req *gubernator.RateLimitRequest) interface{} { + return mock.MatchedBy(func(req2 *gubernator.RateLimitRequest) bool { return req2.Name == req.Name && req2.UniqueKey == req.UniqueKey }) @@ -158,7 +224,7 @@ func TestStore(t *testing.T) { // Create a mock argument matcher for CacheItem input. // Verify item matches expected algorithm, limit, and duration. - matchItem := func(req *gubernator.RateLimitReq) interface{} { + matchItem := func(req *gubernator.RateLimitRequest) interface{} { switch req.Algorithm { case gubernator.Algorithm_TOKEN_BUCKET: return mock.MatchedBy(func(item *gubernator.CacheItem) bool { @@ -193,7 +259,7 @@ func TestStore(t *testing.T) { } // Create a bucket item matching the request. - createBucketItem := func(req *gubernator.RateLimitReq) interface{} { + createBucketItem := func(req *gubernator.RateLimitRequest) interface{} { switch req.Algorithm { case gubernator.Algorithm_TOKEN_BUCKET: return &gubernator.TokenBucketItem{ @@ -230,7 +296,7 @@ func TestStore(t *testing.T) { store, srv, client := setup() defer tearDown(srv) - req := &gubernator.RateLimitReq{ + req := &gubernator.RateLimitRequest{ Name: "test_over_limit", UniqueKey: "account:1234", Algorithm: testCase.Algorithm, @@ -244,12 +310,13 @@ func TestStore(t *testing.T) { store.On("OnChange", mock.Anything, matchReq(req), matchItem(req)).Once() // Call code. - resp, err := client.GetRateLimits(ctx, &gubernator.GetRateLimitsReq{ - Requests: []*gubernator.RateLimitReq{req}, - }) + var resp gubernator.CheckRateLimitsResponse + err := client.CheckRateLimits(ctx, &gubernator.CheckRateLimitsRequest{ + Requests: []*gubernator.RateLimitRequest{req}, + }, &resp) require.NoError(t, err) - require.NotNil(t, resp) assert.Len(t, resp.Responses, 1) + assert.Equal(t, "", resp.Responses[0].Error) assert.Equal(t, req.Limit, resp.Responses[0].Limit) assert.Equal(t, gubernator.Status_UNDER_LIMIT, resp.Responses[0].Status) store.AssertExpectations(t) @@ -259,12 +326,13 @@ func TestStore(t *testing.T) { store.On("OnChange", mock.Anything, matchReq(req), matchItem(req)).Once() // Call code. - resp, err := client.GetRateLimits(ctx, &gubernator.GetRateLimitsReq{ - Requests: []*gubernator.RateLimitReq{req}, - }) + var resp gubernator.CheckRateLimitsResponse + err := client.CheckRateLimits(ctx, &gubernator.CheckRateLimitsRequest{ + Requests: []*gubernator.RateLimitRequest{req}, + }, &resp) require.NoError(t, err) - require.NotNil(t, resp) assert.Len(t, resp.Responses, 1) + assert.Equal(t, "", resp.Responses[0].Error) assert.Equal(t, req.Limit, resp.Responses[0].Limit) assert.Equal(t, gubernator.Status_UNDER_LIMIT, resp.Responses[0].Status) store.AssertExpectations(t) @@ -275,7 +343,7 @@ func TestStore(t *testing.T) { store, srv, client := setup() defer tearDown(srv) - req := &gubernator.RateLimitReq{ + req := &gubernator.RateLimitRequest{ Name: "test_over_limit", UniqueKey: "account:1234", Algorithm: testCase.Algorithm, @@ -298,12 +366,13 @@ func TestStore(t *testing.T) { store.On("OnChange", mock.Anything, matchReq(req), matchItem(req)).Once() // Call code. 
- resp, err := client.GetRateLimits(ctx, &gubernator.GetRateLimitsReq{ - Requests: []*gubernator.RateLimitReq{req}, - }) + var resp gubernator.CheckRateLimitsResponse + err := client.CheckRateLimits(ctx, &gubernator.CheckRateLimitsRequest{ + Requests: []*gubernator.RateLimitRequest{req}, + }, &resp) require.NoError(t, err) - require.NotNil(t, resp) assert.Len(t, resp.Responses, 1) + assert.Equal(t, "", resp.Responses[0].Error) assert.Equal(t, req.Limit, resp.Responses[0].Limit) assert.Equal(t, gubernator.Status_UNDER_LIMIT, resp.Responses[0].Status) store.AssertExpectations(t) @@ -314,7 +383,7 @@ func TestStore(t *testing.T) { store, srv, client := setup() defer tearDown(srv) - req := &gubernator.RateLimitReq{ + req := &gubernator.RateLimitRequest{ Name: "test_over_limit", UniqueKey: "account:1234", Algorithm: testCase.Algorithm, @@ -338,12 +407,13 @@ func TestStore(t *testing.T) { store.On("OnChange", mock.Anything, matchReq(req), matchItem(req)).Once() // Call code. - resp, err := client.GetRateLimits(ctx, &gubernator.GetRateLimitsReq{ - Requests: []*gubernator.RateLimitReq{req}, - }) + var resp gubernator.CheckRateLimitsResponse + err := client.CheckRateLimits(ctx, &gubernator.CheckRateLimitsRequest{ + Requests: []*gubernator.RateLimitRequest{req}, + }, &resp) require.NoError(t, err) - require.NotNil(t, resp) assert.Len(t, resp.Responses, 1) + assert.Equal(t, "", resp.Responses[0].Error) assert.Equal(t, req.Limit, resp.Responses[0].Limit) assert.Equal(t, gubernator.Status_UNDER_LIMIT, resp.Responses[0].Status) store.AssertExpectations(t) @@ -360,7 +430,7 @@ func TestStore(t *testing.T) { oldDuration := int64(5000) newDuration := int64(8000) - req := &gubernator.RateLimitReq{ + req := &gubernator.RateLimitRequest{ Name: "test_over_limit", UniqueKey: "account:1234", Algorithm: testCase.Algorithm, @@ -427,12 +497,13 @@ func TestStore(t *testing.T) { Once() // Call code. - resp, err := client.GetRateLimits(ctx, &gubernator.GetRateLimitsReq{ - Requests: []*gubernator.RateLimitReq{req}, - }) + var resp gubernator.CheckRateLimitsResponse + err := client.CheckRateLimits(ctx, &gubernator.CheckRateLimitsRequest{ + Requests: []*gubernator.RateLimitRequest{req}, + }, &resp) require.NoError(t, err) - require.NotNil(t, resp) assert.Len(t, resp.Responses, 1) + assert.Equal(t, "", resp.Responses[0].Error) assert.Equal(t, req.Limit, resp.Responses[0].Limit) assert.Equal(t, gubernator.Status_UNDER_LIMIT, resp.Responses[0].Status) store.AssertExpectations(t) @@ -447,7 +518,7 @@ func TestStore(t *testing.T) { oldDuration := int64(500000) newDuration := int64(8000) - req := &gubernator.RateLimitReq{ + req := &gubernator.RateLimitRequest{ Name: "test_over_limit", UniqueKey: "account:1234", Algorithm: testCase.Algorithm, @@ -517,12 +588,13 @@ func TestStore(t *testing.T) { Once() // Call code. 
- resp, err := client.GetRateLimits(ctx, &gubernator.GetRateLimitsReq{ - Requests: []*gubernator.RateLimitReq{req}, - }) + var resp gubernator.CheckRateLimitsResponse + err := client.CheckRateLimits(ctx, &gubernator.CheckRateLimitsRequest{ + Requests: []*gubernator.RateLimitRequest{req}, + }, &resp) require.NoError(t, err) - require.NotNil(t, resp) assert.Len(t, resp.Responses, 1) + assert.Equal(t, "", resp.Responses[0].Error) assert.Equal(t, req.Limit, resp.Responses[0].Limit) assert.Equal(t, gubernator.Status_UNDER_LIMIT, resp.Responses[0].Status) store.AssertExpectations(t) diff --git a/tls.go b/tls.go index 46cd2d7..6655cd6 100644 --- a/tls.go +++ b/tls.go @@ -18,6 +18,7 @@ package gubernator import ( "bytes" + "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" @@ -25,15 +26,15 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" + "log/slog" "math/big" "net" "os" "strings" "time" - "github.com/mailgun/holster/v4/setter" - "github.com/pkg/errors" - "github.com/sirupsen/logrus" + "github.com/kapetan-io/errors" + "github.com/kapetan-io/tackle/set" ) const ( @@ -132,7 +133,7 @@ func fromFile(name string) (*bytes.Buffer, error) { b, err := os.ReadFile(name) if err != nil { - return nil, errors.Wrapf(err, "while reading file '%s'", name) + return nil, errors.Errorf("while reading file '%s': %w", name, err) } return bytes.NewBuffer(b), nil } @@ -154,11 +155,11 @@ func SetupTLS(conf *TLSConfig) error { minServerTLSVersion = tls.VersionTLS13 } - setter.SetDefault(&conf.Logger, logrus.WithField("category", "gubernator")) + set.Default(&conf.Logger, slog.Default().With("category", "gubernator")) conf.Logger.Info("Detected TLS Configuration") // Basic config with reasonably secure defaults - setter.SetDefault(&conf.ServerTLS, &tls.Config{ + set.Default(&conf.ServerTLS, &tls.Config{ CipherSuites: []uint16{ tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, @@ -184,7 +185,7 @@ func SetupTLS(conf *TLSConfig) error { "h2", "http/1.1", // enable HTTP/2 }, }) - setter.SetDefault(&conf.ClientTLS, &tls.Config{}) + set.Default(&conf.ClientTLS, &tls.Config{}) // Attempt to load any files provided conf.CaPEM, err = fromFile(conf.CaFile) @@ -227,19 +228,21 @@ func SetupTLS(conf *TLSConfig) error { conf.Logger.Info("AutoTLS Enabled") // Generate CA Cert and Private Key if err := selfCA(conf); err != nil { - return errors.Wrap(err, "while generating self signed CA certs") + return errors.Errorf("while generating self signed CA certs: %w", err) } // Generate Server Cert and Private Key if err := selfCert(conf); err != nil { - return errors.Wrap(err, "while generating self signed server certs") + return errors.Errorf("while generating self signed server certs: %w", err) } } if conf.CaPEM != nil { rootPool, err := x509.SystemCertPool() if err != nil { - conf.Logger.Warnf("while loading system CA Certs '%s'; using provided pool instead", err) + conf.Logger.LogAttrs(context.TODO(), slog.LevelWarn, "while loading system CA Certs; using provided pool instead", + ErrAttr(err), + ) rootPool = x509.NewCertPool() } rootPool.AppendCertsFromPEM(conf.CaPEM.Bytes()) @@ -250,7 +253,7 @@ func SetupTLS(conf *TLSConfig) error { if conf.KeyPEM != nil && conf.CertPEM != nil { serverCert, err := tls.X509KeyPair(conf.CertPEM.Bytes(), conf.KeyPEM.Bytes()) if err != nil { - return errors.Wrap(err, "while parsing server certificate and private key") + return errors.Errorf("while parsing server certificate and private key: %w", err) } conf.ServerTLS.Certificates = 
[]tls.Certificate{serverCert} conf.ClientTLS.Certificates = []tls.Certificate{serverCert} @@ -259,17 +262,20 @@ func SetupTLS(conf *TLSConfig) error { // If user asked for client auth if conf.ClientAuth != tls.NoClientCert { clientPool := x509.NewCertPool() + var certProvided bool if conf.ClientAuthCaPEM != nil { // If client auth CA was provided clientPool.AppendCertsFromPEM(conf.ClientAuthCaPEM.Bytes()) + certProvided = true } else if conf.CaPEM != nil { // else use the servers CA clientPool.AppendCertsFromPEM(conf.CaPEM.Bytes()) + certProvided = true } - // error if neither was provided - if len(clientPool.Subjects()) == 0 { //nolint:all + // error if neither cert was provided + if !certProvided { return errors.New("client auth enabled, but no CA's provided") } @@ -279,7 +285,7 @@ func SetupTLS(conf *TLSConfig) error { if conf.ClientAuthKeyPEM != nil && conf.ClientAuthCertPEM != nil { clientCert, err := tls.X509KeyPair(conf.ClientAuthCertPEM.Bytes(), conf.ClientAuthKeyPEM.Bytes()) if err != nil { - return errors.Wrap(err, "while parsing client certificate and private key") + return errors.Errorf("while parsing client certificate and private key: %w", err) } conf.ClientTLS.Certificates = []tls.Certificate{clientCert} } @@ -297,7 +303,7 @@ func selfCert(conf *TLSConfig) error { network, err := discoverNetwork() if err != nil { - return errors.Wrap(err, "while detecting ip and host names") + return errors.Errorf("while detecting ip and host names: %w", err) } cert := x509.Certificate{ @@ -324,20 +330,21 @@ func selfCert(conf *TLSConfig) error { } } - conf.Logger.Info("Generating Server Private Key and Certificate....") - conf.Logger.Infof("Cert DNS names: (%s)", strings.Join(cert.DNSNames, ",")) - conf.Logger.Infof("Cert IPs: (%s)", func() string { - var r []string - for i := range cert.IPAddresses { - r = append(r, cert.IPAddresses[i].String()) - } - return strings.Join(r, ",") - }()) + conf.Logger.LogAttrs(context.TODO(), slog.LevelInfo, "Generating Server Private Key and Certificate....", + slog.String("dns_names", strings.Join(cert.DNSNames, ",")), + slog.String("cert_ips", func() string { + var r []string + for i := range cert.IPAddresses { + r = append(r, cert.IPAddresses[i].String()) + } + return strings.Join(r, ",") + }()), + ) // Generate a public / private key privKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) if err != nil { - return errors.Wrap(err, "while generating pubic/private key pair") + return errors.Errorf("while generating pubic/private key pair: %w", err) } // Attempt to sign the generated certs with the provided CaFile @@ -347,7 +354,7 @@ func selfCert(conf *TLSConfig) error { keyPair, err := tls.X509KeyPair(conf.CaPEM.Bytes(), conf.CaKeyPEM.Bytes()) if err != nil { - return errors.Wrap(err, "while reading generated PEMs") + return errors.Errorf("while reading generated PEMs: %w", err) } if len(keyPair.Certificate) == 0 { @@ -356,12 +363,12 @@ func selfCert(conf *TLSConfig) error { caCert, err := x509.ParseCertificate(keyPair.Certificate[0]) if err != nil { - return errors.Wrap(err, "while parsing CA Cert") + return errors.Errorf("while parsing CA Cert: %w", err) } signedBytes, err := x509.CreateCertificate(rand.Reader, &cert, caCert, &privKey.PublicKey, keyPair.PrivateKey) if err != nil { - return errors.Wrap(err, "while self signing server cert") + return errors.Errorf("while self signing server cert: %w", err) } conf.CertPEM = new(bytes.Buffer) @@ -369,12 +376,12 @@ func selfCert(conf *TLSConfig) error { Type: "CERTIFICATE", Bytes: signedBytes, }); 
err != nil { - return errors.Wrap(err, "while encoding CERTIFICATE PEM") + return errors.Errorf("while encoding CERTIFICATE PEM: %w", err) } b, err := x509.MarshalECPrivateKey(privKey) if err != nil { - return errors.Wrap(err, "while encoding EC Marshalling") + return errors.Errorf("while encoding EC Marshalling: %w", err) } conf.KeyPEM = new(bytes.Buffer) @@ -382,7 +389,7 @@ func selfCert(conf *TLSConfig) error { Type: blockTypeEC, Bytes: b, }); err != nil { - return errors.Wrap(err, "while encoding EC KEY PEM") + return errors.Errorf("while encoding EC KEY PEM: %w", err) } return nil } @@ -410,12 +417,12 @@ func selfCA(conf *TLSConfig) error { conf.Logger.Info("Generating CA Certificates....") privKey, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader) if err != nil { - return errors.Wrap(err, "while generating pubic/private key pair") + return errors.Errorf("while generating pubic/private key pair: %w", err) } b, err = x509.CreateCertificate(rand.Reader, &ca, &ca, &privKey.PublicKey, privKey) if err != nil { - return errors.Wrap(err, "while self signing CA certificate") + return errors.Errorf("while self signing CA certificate: %w", err) } conf.CaPEM = new(bytes.Buffer) @@ -423,12 +430,12 @@ func selfCA(conf *TLSConfig) error { Type: blockTypeCert, Bytes: b, }); err != nil { - return errors.Wrap(err, "while encoding CERTIFICATE PEM") + return errors.Errorf("while encoding CERTIFICATE PEM: %w", err) } b, err = x509.MarshalECPrivateKey(privKey) if err != nil { - return errors.Wrap(err, "while marshalling EC private key") + return errors.Errorf("while marshalling EC private key: %w", err) } conf.CaKeyPEM = new(bytes.Buffer) @@ -436,7 +443,7 @@ func selfCA(conf *TLSConfig) error { Type: blockTypeEC, Bytes: b, }); err != nil { - return errors.Wrap(err, "while encoding EC private key into PEM") + return errors.Errorf("while encoding EC private key into PEM: %w", err) } return nil } diff --git a/tls_test.go b/tls_test.go index f6b56ba..49f781e 100644 --- a/tls_test.go +++ b/tls_test.go @@ -21,15 +21,17 @@ import ( "crypto/tls" "fmt" "io" + "log/slog" + "math/rand" "net/http" "strings" "testing" - "github.com/gubernator-io/gubernator/v2" - "github.com/mailgun/holster/v4/clock" + "github.com/kapetan-io/tackle/clock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/net/http2" + + "github.com/gubernator-io/gubernator/v3" ) func spawnDaemon(t *testing.T, conf gubernator.DaemonConfig) *gubernator.Daemon { @@ -39,18 +41,19 @@ func spawnDaemon(t *testing.T, conf gubernator.DaemonConfig) *gubernator.Daemon d, err := gubernator.SpawnDaemon(ctx, conf) cancel() require.NoError(t, err) - d.SetPeers([]gubernator.PeerInfo{{GRPCAddress: conf.GRPCListenAddress, IsOwner: true}}) + d.SetPeers([]gubernator.PeerInfo{{HTTPAddress: conf.HTTPListenAddress, IsOwner: true}}) return d } func makeRequest(t *testing.T, conf gubernator.DaemonConfig) error { t.Helper() - client, err := gubernator.DialV1Server(conf.GRPCListenAddress, conf.TLS.ClientTLS) + client, err := gubernator.NewClient(gubernator.WithTLS(conf.ClientTLS(), conf.HTTPListenAddress)) require.NoError(t, err) - resp, err := client.GetRateLimits(context.Background(), &gubernator.GetRateLimitsReq{ - Requests: []*gubernator.RateLimitReq{ + var resp gubernator.CheckRateLimitsResponse + err = client.CheckRateLimits(context.Background(), &gubernator.CheckRateLimitsRequest{ + Requests: []*gubernator.RateLimitRequest{ { Name: "test_tls", UniqueKey: "account:995", @@ -60,7 +63,7 @@ func makeRequest(t *testing.T, conf 
gubernator.DaemonConfig) error { Hits: 1, }, }, - }) + }, &resp) if err != nil { return err @@ -120,18 +123,18 @@ func TestSetupTLS(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { conf := gubernator.DaemonConfig{ - GRPCListenAddress: "127.0.0.1:9695", HTTPListenAddress: "127.0.0.1:9685", TLS: tt.tls, } d := spawnDaemon(t, conf) - client, err := gubernator.DialV1Server(conf.GRPCListenAddress, tt.tls.ClientTLS) + client, err := gubernator.NewClient(gubernator.WithTLS(conf.ClientTLS(), conf.HTTPListenAddress)) require.NoError(t, err) - resp, err := client.GetRateLimits(context.Background(), &gubernator.GetRateLimitsReq{ - Requests: []*gubernator.RateLimitReq{ + var resp gubernator.CheckRateLimitsResponse + err = client.CheckRateLimits(context.Background(), &gubernator.CheckRateLimitsRequest{ + Requests: []*gubernator.RateLimitRequest{ { Name: "test_tls", UniqueKey: "account:995", @@ -141,21 +144,20 @@ func TestSetupTLS(t *testing.T) { Hits: 1, }, }, - }) + }, &resp) require.NoError(t, err) rl := resp.Responses[0] assert.Equal(t, "", rl.Error) assert.Equal(t, gubernator.Status_UNDER_LIMIT, rl.Status) assert.Equal(t, int64(99), rl.Remaining) - d.Close() + d.Close(context.Background()) }) } } func TestSetupTLSSkipVerify(t *testing.T) { conf := gubernator.DaemonConfig{ - GRPCListenAddress: "127.0.0.1:9695", HTTPListenAddress: "127.0.0.1:9685", TLS: &gubernator.TLSConfig{ CaFile: "contrib/certs/ca.cert", @@ -165,7 +167,7 @@ func TestSetupTLSSkipVerify(t *testing.T) { } d := spawnDaemon(t, conf) - defer d.Close() + defer d.Close(context.Background()) tls := &gubernator.TLSConfig{ AutoTLS: true, @@ -190,13 +192,12 @@ func TestSetupTLSClientAuth(t *testing.T) { } conf := gubernator.DaemonConfig{ - GRPCListenAddress: "127.0.0.1:9695", HTTPListenAddress: "127.0.0.1:9685", TLS: &serverTLS, } d := spawnDaemon(t, conf) - defer d.Close() + defer d.Close(context.Background()) // Given generated client certs tls := &gubernator.TLSConfig{ @@ -211,7 +212,8 @@ func TestSetupTLSClientAuth(t *testing.T) { // Should not be allowed without a cert signed by the client CA err = makeRequest(t, conf) require.Error(t, err) - assert.Contains(t, err.Error(), "code = Unavailable desc") + // Error is different depending on golang version + //assert.Contains(t, err.Error(), "tls: certificate required") // Given the client auth certs tls = &gubernator.TLSConfig{ @@ -238,27 +240,23 @@ func TestTLSClusterWithClientAuthentication(t *testing.T) { } d1 := spawnDaemon(t, gubernator.DaemonConfig{ - GRPCListenAddress: "127.0.0.1:9695", HTTPListenAddress: "127.0.0.1:9685", TLS: &serverTLS, }) - defer d1.Close() + defer d1.Close(context.Background()) d2 := spawnDaemon(t, gubernator.DaemonConfig{ - GRPCListenAddress: "127.0.0.1:9696", HTTPListenAddress: "127.0.0.1:9686", TLS: &serverTLS, }) - defer d2.Close() + defer d2.Close(context.Background()) peers := []gubernator.PeerInfo{ { - GRPCAddress: d1.GRPCListeners[0].Addr().String(), - HTTPAddress: d1.HTTPListener.Addr().String(), + HTTPAddress: d1.Listener.Addr().String(), }, { - GRPCAddress: d2.GRPCListeners[0].Addr().String(), - HTTPAddress: d2.HTTPListener.Addr().String(), + HTTPAddress: d2.Listener.Addr().String(), }, } d1.SetPeers(peers) @@ -270,7 +268,7 @@ func TestTLSClusterWithClientAuthentication(t *testing.T) { config := d2.Config() client := &http.Client{ - Transport: &http2.Transport{ + Transport: &http.Transport{ TLSClientConfig: config.ClientTLS(), }, } @@ -281,13 +279,12 @@ func TestTLSClusterWithClientAuthentication(t *testing.T) { b, err := 
io.ReadAll(resp.Body) require.NoError(t, err) - // Should have called GetPeerRateLimits on d2 - assert.Contains(t, string(b), `{method="/pb.gubernator.PeersV1/GetPeerRateLimits"} 1`) + // Should have called /v1/peer.forward on d2 + assert.Contains(t, string(b), `{path="`+gubernator.RPCPeerForward+`"} 1`) } func TestHTTPSClientAuth(t *testing.T) { conf := gubernator.DaemonConfig{ - GRPCListenAddress: "127.0.0.1:9695", HTTPListenAddress: "127.0.0.1:9685", HTTPStatusListenAddress: "127.0.0.1:9686", TLS: &gubernator.TLSConfig{ @@ -299,7 +296,7 @@ func TestHTTPSClientAuth(t *testing.T) { } d := spawnDaemon(t, conf) - defer d.Close() + defer d.Close(context.Background()) clientWithCert := &http.Client{ Transport: &http.Transport{ @@ -315,9 +312,9 @@ func TestHTTPSClientAuth(t *testing.T) { }, } - reqCertRequired, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://%s/v1/HealthCheck", conf.HTTPListenAddress), nil) + reqCertRequired, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://%s/healthz", conf.HTTPListenAddress), nil) require.NoError(t, err) - reqNoClientCertRequired, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://%s/v1/HealthCheck", conf.HTTPStatusListenAddress), nil) + reqNoClientCertRequired, err := http.NewRequest(http.MethodGet, fmt.Sprintf("https://%s/healthz", conf.HTTPStatusListenAddress), nil) require.NoError(t, err) // Test that a client without a cert can access /v1/HealthCheck at status address @@ -326,18 +323,90 @@ func TestHTTPSClientAuth(t *testing.T) { defer resp.Body.Close() b, err := io.ReadAll(resp.Body) require.NoError(t, err) - assert.Equal(t, `{"status":"healthy","message":"","peer_count":1}`, strings.ReplaceAll(string(b), " ", "")) + assert.Equal(t, `{"status":"healthy","peer_count":1}`, strings.ReplaceAll(string(b), " ", "")) // Verify we get an error when we try to access existing HTTPListenAddress without cert - //nolint:bodyclose // Expect error, no body to close. 
- _, err = clientWithoutCert.Do(reqCertRequired) + _, err = clientWithoutCert.Do(reqCertRequired) //nolint:all require.Error(t, err) + // The error message is different depending on what version of golang is being used + //assert.Contains(t, err.Error(), "remote error: tls: certificate required") - // Check that with a valid client cert we can access /v1/HealthCheck at existing HTTPListenAddress - resp2, err := clientWithCert.Do(reqCertRequired) + // Check that with a valid client cert we can access /v1/healthz at existing HTTPListenAddress + resp3, err := clientWithCert.Do(reqCertRequired) require.NoError(t, err) - defer resp2.Body.Close() - b, err = io.ReadAll(resp2.Body) + defer resp3.Body.Close() + b, err = io.ReadAll(resp3.Body) + require.NoError(t, err) + assert.Equal(t, `{"status":"healthy","peer_count":1}`, strings.ReplaceAll(string(b), " ", "")) +} + +// Ensure SpawnDaemon() setup peer TLS auth properly +func TestHTTPSDaemonPeerAuth(t *testing.T) { + var daemons []*gubernator.Daemon + var peers []gubernator.PeerInfo + name := t.Name() + + for _, peer := range []gubernator.PeerInfo{ + {HTTPAddress: "127.0.0.1:9780"}, + {HTTPAddress: "127.0.0.1:9781"}, + {HTTPAddress: "127.0.0.1:9782"}, + } { + ctx, cancel := context.WithTimeout(context.Background(), clock.Second*10) + d, err := gubernator.SpawnDaemon(ctx, gubernator.DaemonConfig{ + Logger: slog.Default().With("instance", peer.HTTPAddress), + InstanceID: peer.HTTPAddress, + HTTPListenAddress: peer.HTTPAddress, + AdvertiseAddress: peer.HTTPAddress, + DataCenter: peer.DataCenter, + Behaviors: gubernator.BehaviorConfig{ + // Suitable for testing but not production + GlobalSyncWait: clock.Millisecond * 50, + GlobalTimeout: clock.Second * 5, + BatchTimeout: clock.Second * 5, + }, + TLS: &gubernator.TLSConfig{ + CaFile: "contrib/certs/ca.cert", + CertFile: "contrib/certs/gubernator.pem", + KeyFile: "contrib/certs/gubernator.key", + ClientAuth: tls.RequireAndVerifyClientCert, + }, + }) + cancel() + require.NoError(t, err) + peers = append(peers, d.PeerInfo) + daemons = append(daemons, d) + } + + for _, d := range daemons { + d.SetPeers(peers) + } + + defer func() { + for _, d := range daemons { + _ = d.Close(context.Background()) + } + }() + + client, err := daemons[0].Client() require.NoError(t, err) - assert.Equal(t, `{"status":"healthy","message":"","peer_count":1}`, strings.ReplaceAll(string(b), " ", "")) + + for i := 0; i < 1_000; i++ { + key := fmt.Sprintf("account:%08x", rand.Int()) + var resp gubernator.CheckRateLimitsResponse + err := client.CheckRateLimits(context.Background(), &gubernator.CheckRateLimitsRequest{ + Requests: []*gubernator.RateLimitRequest{ + { + Duration: gubernator.Millisecond * 9000, + Name: name, + UniqueKey: key, + Limit: 20, + Hits: 1, + }, + }, + }, &resp) + + require.Nil(t, err) + require.Equal(t, "", resp.Responses[0].Error) + } + } diff --git a/tracing/tracing.go b/tracing/tracing.go new file mode 100644 index 0000000..1502932 --- /dev/null +++ b/tracing/tracing.go @@ -0,0 +1,176 @@ +package tracing + +import ( + "context" + "fmt" + "log/slog" + "os" + "runtime" + "strconv" + "strings" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/exporters/otlp/otlptrace" + "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc" + "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp" + "go.opentelemetry.io/otel/propagation" + "go.opentelemetry.io/otel/sdk/resource" + sdktrace "go.opentelemetry.io/otel/sdk/trace" 
+ semconv "go.opentelemetry.io/otel/semconv/v1.24.0" + "go.opentelemetry.io/otel/trace" +) + +// TODO: How we use OTEL needs an overhaul, as such some of these functions will likely go away, however that +// day is not today. We should probably create a single span for each incoming request +// and avoid child spans if possible. See https://jeremymorrell.dev/blog/a-practitioners-guide-to-wide-events/ + +// NewResource creates a resource with sensible defaults. Replaces common use case of verbose usage. +func NewResource(serviceName, version string, resources ...*resource.Resource) (*resource.Resource, error) { + res, err := resource.Merge( + resource.Default(), + resource.NewWithAttributes( + semconv.SchemaURL, + semconv.ServiceNameKey.String(serviceName), + semconv.ServiceVersionKey.String(version), + ), + ) + if err != nil { + return nil, fmt.Errorf("error in resource.Merge: %w", err) + } + + for i, res2 := range resources { + res, err = resource.Merge(res, res2) + if err != nil { + return nil, fmt.Errorf("error in resource.Merge on resources index %d: %w", i, err) + } + } + + return res, nil +} + +func StartScope(ctx context.Context, spanName string, opts ...trace.SpanStartOption) context.Context { + fileTag := getFileTag(1) + opts = append(opts, trace.WithAttributes( + attribute.String("file", fileTag), + )) + + ctx, _ = Tracer().Start(ctx, spanName, opts...) + return ctx +} + +// EndScope end scope created by `StartScope()`/`StartScope()`. +// Logs error return value and ends span. +func EndScope(ctx context.Context, err error) { + span := trace.SpanFromContext(ctx) + + // If scope returns an error, mark span with error. + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + } + + span.End() +} + +// Tracer returns a tracer object. +func Tracer(opts ...trace.TracerOption) trace.Tracer { + return otel.Tracer(globalLibraryName, opts...) +} + +var globalLibraryName string + +type ShutdownFunc func(ctx context.Context) error + +// InitTracing initializes a global OpenTelemetry tracer provider singleton. +func InitTracing(ctx context.Context, log *slog.Logger, libraryName string, opts ...sdktrace.TracerProviderOption) (ShutdownFunc, error) { + exporter, err := makeOtlpExporter(ctx, log) + if err != nil { + return nil, fmt.Errorf("error in makeOtlpExporter: %w", err) + } + + exportProcessor := sdktrace.NewBatchSpanProcessor(exporter) + opts = append(opts, sdktrace.WithSpanProcessor(exportProcessor)) + + tp := sdktrace.NewTracerProvider(opts...) + otel.SetTracerProvider(tp) + + if libraryName == "" { + libraryName = "github.com/gubernator-io/gubernator/v3" + + } + globalLibraryName = libraryName + + // Required for trace propagation between services. 
+ otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{})) + + return func(ctx context.Context) error { + return tp.Shutdown(ctx) + }, err +} + +// or returns the first string which is not "", returns "" if all strings provided are "" +func or(names ...string) string { + for _, name := range names { + if name != "" { + return name + } + } + + return "" +} + +func makeOtlpExporter(ctx context.Context, log *slog.Logger) (*otlptrace.Exporter, error) { + protocol := or( + os.Getenv("OTEL_EXPORTER_OTLP_PROTOCOL"), + os.Getenv("OTEL_EXPORTER_OTLP_TRACES_PROTOCOL"), + "grpc") + var client otlptrace.Client + + switch protocol { + case "grpc": + client = otlptracegrpc.NewClient() + case "http/protobuf": + client = otlptracehttp.NewClient() + default: + log.Error("unknown OTLP exporter protocol", "OTEL_EXPORTER_OTLP_PROTOCOL", protocol) + protocol = "grpc" + client = otlptracegrpc.NewClient() + } + + attrs := []slog.Attr{ + slog.String("exporter", "otlp"), + slog.String("protocol", protocol), + slog.String("endpoint", or(os.Getenv("OTEL_EXPORTER_OTLP_ENDPOINT"), + os.Getenv("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT"))), + } + + sampler := os.Getenv("OTEL_TRACES_SAMPLER") + attrs = append(attrs, slog.String("sampler", sampler)) + if strings.HasSuffix(sampler, "traceidratio") { + ratio, _ := strconv.ParseFloat(os.Getenv("OTEL_TRACES_SAMPLER_ARG"), 64) + attrs = append(attrs, slog.Float64("sampler.ratio", ratio)) + } + + log.LogAttrs(ctx, slog.LevelInfo, "Initializing OpenTelemetry", attrs...) + + return otlptrace.New(ctx, client) +} + +// getFileTag returns file name:line of the caller. +// +// Use skip=0 to get the caller of getFileTag. +// +// Use skip=1 to get the caller of the caller(a getFileTag() wrapper). +func getFileTag(skip int) string { + _, file, line, ok := runtime.Caller(skip + 1) + + // Determine source file and line number. + if !ok { + // Rare condition. Probably a bug in caller. + return "unknown" + } + + return file + ":" + strconv.Itoa(line) +} diff --git a/workers.go b/workers.go deleted file mode 100644 index 34d99d1..0000000 --- a/workers.go +++ /dev/null @@ -1,626 +0,0 @@ -/* -Copyright 2018-2022 Mailgun Technologies Inc - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package gubernator - -// Thread-safe worker pool for handling concurrent Gubernator requests. -// Ensures requests are synchronized to avoid caching conflicts. -// Handle concurrent requests by sharding cache key space across multiple -// workers. -// Uses hash ring design pattern to distribute requests to an assigned worker. -// No mutex locking necessary because each worker has its own data space and -// processes requests sequentially. -// -// Request workflow: -// - A 63-bit hash is generated from an incoming request by its Key/Name -// values. (Actually 64 bit, but we toss out one bit to properly calculate -// the next step.) -// - Workers are assigned equal size hash ranges. 
The worker is selected by -// choosing the worker index associated with that linear hash value range. -// - The worker has command channels for each method call. The request is -// enqueued to the appropriate channel. -// - The worker pulls the request from the appropriate channel and executes the -// business logic for that method. Then, it sends a response back using the -// requester's provided response channel. - -import ( - "context" - "io" - "strconv" - "sync" - "sync/atomic" - - "github.com/OneOfOne/xxhash" - "github.com/mailgun/holster/v4/errors" - "github.com/mailgun/holster/v4/setter" - "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" - "go.opentelemetry.io/otel/trace" -) - -type WorkerPool struct { - hasher workerHasher - workers []*Worker - workerCacheSize int - hashRingStep uint64 - conf *Config - done chan struct{} -} - -type Worker struct { - name string - conf *Config - cache Cache - getRateLimitRequest chan request - storeRequest chan workerStoreRequest - loadRequest chan workerLoadRequest - addCacheItemRequest chan workerAddCacheItemRequest - getCacheItemRequest chan workerGetCacheItemRequest -} - -type workerHasher interface { - // ComputeHash63 returns a 63-bit hash derived from input. - ComputeHash63(input string) uint64 -} - -// hasher is the default implementation of workerHasher. -type hasher struct{} - -// Method request/response structs. -type workerStoreRequest struct { - ctx context.Context - response chan workerStoreResponse - out chan<- *CacheItem -} - -type workerStoreResponse struct{} - -type workerLoadRequest struct { - ctx context.Context - response chan workerLoadResponse - in <-chan *CacheItem -} - -type workerLoadResponse struct{} - -type workerAddCacheItemRequest struct { - ctx context.Context - response chan workerAddCacheItemResponse - item *CacheItem -} - -type workerAddCacheItemResponse struct { - exists bool -} - -type workerGetCacheItemRequest struct { - ctx context.Context - response chan workerGetCacheItemResponse - key string -} - -type workerGetCacheItemResponse struct { - item *CacheItem - ok bool -} - -var _ io.Closer = &WorkerPool{} -var _ workerHasher = &hasher{} - -var workerCounter int64 - -func NewWorkerPool(conf *Config) *WorkerPool { - setter.SetDefault(&conf.CacheSize, 50_000) - - // Compute hashRingStep as interval between workers' 63-bit hash ranges. - // 64th bit is used here as a max value that is just out of range of 63-bit space to calculate the step. - chp := &WorkerPool{ - workers: make([]*Worker, conf.Workers), - workerCacheSize: conf.CacheSize / conf.Workers, - hasher: newHasher(), - hashRingStep: uint64(1<<63) / uint64(conf.Workers), - conf: conf, - done: make(chan struct{}), - } - - // Create workers. - conf.Logger.Infof("Starting %d Gubernator workers...", conf.Workers) - for i := 0; i < conf.Workers; i++ { - chp.workers[i] = chp.newWorker() - go chp.dispatch(chp.workers[i]) - } - - return chp -} - -func newHasher() *hasher { - return &hasher{} -} - -func (ph *hasher) ComputeHash63(input string) uint64 { - return xxhash.ChecksumString64S(input, 0) >> 1 -} - -func (p *WorkerPool) Close() error { - close(p.done) - return nil -} - -// Create a new pool worker instance. 
-func (p *WorkerPool) newWorker() *Worker { - worker := &Worker{ - conf: p.conf, - cache: p.conf.CacheFactory(p.workerCacheSize), - getRateLimitRequest: make(chan request), - storeRequest: make(chan workerStoreRequest), - loadRequest: make(chan workerLoadRequest), - addCacheItemRequest: make(chan workerAddCacheItemRequest), - getCacheItemRequest: make(chan workerGetCacheItemRequest), - } - workerNumber := atomic.AddInt64(&workerCounter, 1) - 1 - worker.name = strconv.FormatInt(workerNumber, 10) - return worker -} - -// getWorker Returns the request channel associated with the key. -// Hash the key, then lookup hash ring to find the worker. -func (p *WorkerPool) getWorker(key string) *Worker { - hash := p.hasher.ComputeHash63(key) - idx := hash / p.hashRingStep - return p.workers[idx] -} - -// Pool worker for processing Gubernator requests. -// Each worker maintains its own state. -// A hash ring will distribute requests to an assigned worker by key. -// See: getWorker() -func (p *WorkerPool) dispatch(worker *Worker) { - for { - // Dispatch requests from each channel. - select { - case req, ok := <-worker.getRateLimitRequest: - if !ok { - // Channel closed. Unexpected, but should be handled. - logrus.Error("workerPool worker stopped because channel closed") - return - } - - resp := new(response) - resp.rl, resp.err = worker.handleGetRateLimit(req.ctx, req.request, req.reqState, worker.cache) - select { - case req.resp <- resp: - // Success. - - case <-req.ctx.Done(): - // Context canceled. - trace.SpanFromContext(req.ctx).RecordError(resp.err) - } - metricCommandCounter.WithLabelValues(worker.name, "GetRateLimit").Inc() - - case req, ok := <-worker.storeRequest: - if !ok { - // Channel closed. Unexpected, but should be handled. - logrus.Error("workerPool worker stopped because channel closed") - return - } - - worker.handleStore(req, worker.cache) - metricCommandCounter.WithLabelValues(worker.name, "Store").Inc() - - case req, ok := <-worker.loadRequest: - if !ok { - // Channel closed. Unexpected, but should be handled. - logrus.Error("workerPool worker stopped because channel closed") - return - } - - worker.handleLoad(req, worker.cache) - metricCommandCounter.WithLabelValues(worker.name, "Load").Inc() - - case req, ok := <-worker.addCacheItemRequest: - if !ok { - // Channel closed. Unexpected, but should be handled. - logrus.Error("workerPool worker stopped because channel closed") - return - } - - worker.handleAddCacheItem(req, worker.cache) - metricCommandCounter.WithLabelValues(worker.name, "AddCacheItem").Inc() - - case req, ok := <-worker.getCacheItemRequest: - if !ok { - // Channel closed. Unexpected, but should be handled. - logrus.Error("workerPool worker stopped because channel closed") - return - } - - worker.handleGetCacheItem(req, worker.cache) - metricCommandCounter.WithLabelValues(worker.name, "GetCacheItem").Inc() - - case <-p.done: - // Clean up. - return - } - } -} - -// GetRateLimit sends a GetRateLimit request to worker pool. -func (p *WorkerPool) GetRateLimit(ctx context.Context, rlRequest *RateLimitReq, reqState RateLimitReqState) (*RateLimitResp, error) { - // Delegate request to assigned channel based on request key. - worker := p.getWorker(rlRequest.HashKey()) - queueGauge := metricWorkerQueue.WithLabelValues("GetRateLimit", worker.name) - queueGauge.Inc() - defer queueGauge.Dec() - handlerRequest := request{ - ctx: ctx, - resp: make(chan *response, 1), - request: rlRequest, - reqState: reqState, - } - - // Send request. 
- select { - case worker.getRateLimitRequest <- handlerRequest: - // Successfully sent request. - case <-ctx.Done(): - return nil, ctx.Err() - } - - // Wait for response. - select { - case handlerResponse := <-handlerRequest.resp: - // Successfully read response. - return handlerResponse.rl, handlerResponse.err - case <-ctx.Done(): - return nil, ctx.Err() - } -} - -// Handle request received by worker. -func (worker *Worker) handleGetRateLimit(ctx context.Context, req *RateLimitReq, reqState RateLimitReqState, cache Cache) (*RateLimitResp, error) { - defer prometheus.NewTimer(metricFuncTimeDuration.WithLabelValues("Worker.handleGetRateLimit")).ObserveDuration() - var rlResponse *RateLimitResp - var err error - - switch req.Algorithm { - case Algorithm_TOKEN_BUCKET: - rlResponse, err = tokenBucket(ctx, worker.conf.Store, cache, req, reqState) - if err != nil { - msg := "Error in tokenBucket" - countError(err, msg) - err = errors.Wrap(err, msg) - trace.SpanFromContext(ctx).RecordError(err) - } - - case Algorithm_LEAKY_BUCKET: - rlResponse, err = leakyBucket(ctx, worker.conf.Store, cache, req, reqState) - if err != nil { - msg := "Error in leakyBucket" - countError(err, msg) - err = errors.Wrap(err, msg) - trace.SpanFromContext(ctx).RecordError(err) - } - - default: - err = errors.Errorf("Invalid rate limit algorithm '%d'", req.Algorithm) - trace.SpanFromContext(ctx).RecordError(err) - metricCheckErrorCounter.WithLabelValues("Invalid algorithm").Add(1) - } - - return rlResponse, err -} - -// Load atomically loads cache from persistent storage. -// Read from persistent storage. Load into each appropriate worker's cache. -// Workers are locked during this load operation to prevent race conditions. -func (p *WorkerPool) Load(ctx context.Context) (err error) { - queueGauge := metricWorkerQueue.WithLabelValues("Load", "") - queueGauge.Inc() - defer queueGauge.Dec() - ch, err := p.conf.Loader.Load() - if err != nil { - return errors.Wrap(err, "Error in loader.Load") - } - - type loadChannel struct { - ch chan *CacheItem - worker *Worker - respChan chan workerLoadResponse - } - - // Map request channel hash to load channel. - loadChMap := map[*Worker]loadChannel{} - - // Send each item to the assigned channel's cache. -MAIN: - for { - var item *CacheItem - var ok bool - - select { - case item, ok = <-ch: - if !ok { - break MAIN - } - // Successfully received item. - - case <-ctx.Done(): - // Context canceled. - return ctx.Err() - } - - worker := p.getWorker(item.Key) - - // Initiate a load channel with each worker. - loadCh, exist := loadChMap[worker] - if !exist { - loadCh = loadChannel{ - ch: make(chan *CacheItem), - worker: worker, - respChan: make(chan workerLoadResponse), - } - loadChMap[worker] = loadCh - - // Tie up the worker while loading. - worker.loadRequest <- workerLoadRequest{ - ctx: ctx, - response: loadCh.respChan, - in: loadCh.ch, - } - } - - // Send item to worker's load channel. - select { - case loadCh.ch <- item: - // Successfully sent item. - - case <-ctx.Done(): - // Context canceled. - return ctx.Err() - } - } - - // Clean up. - for _, loadCh := range loadChMap { - close(loadCh.ch) - - // Load response confirms all items have been loaded and the worker - // resumes normal operation. - select { - case <-loadCh.respChan: - // Successfully received response. - - case <-ctx.Done(): - // Context canceled. 
- return ctx.Err() - } - } - - return nil -} - -func (worker *Worker) handleLoad(request workerLoadRequest, cache Cache) { -MAIN: - for { - var item *CacheItem - var ok bool - - select { - case item, ok = <-request.in: - if !ok { - break MAIN - } - // Successfully received item. - - case <-request.ctx.Done(): - // Context canceled. - return - } - - cache.Add(item) - } - - response := workerLoadResponse{} - - select { - case request.response <- response: - // Successfully sent response. - - case <-request.ctx.Done(): - // Context canceled. - trace.SpanFromContext(request.ctx).RecordError(request.ctx.Err()) - } -} - -// Store atomically stores cache to persistent storage. -// Save all workers' caches to persistent storage. -// Workers are locked during this store operation to prevent race conditions. -func (p *WorkerPool) Store(ctx context.Context) (err error) { - queueGauge := metricWorkerQueue.WithLabelValues("Store", "") - queueGauge.Inc() - defer queueGauge.Dec() - var wg sync.WaitGroup - out := make(chan *CacheItem, 500) - - // Iterate each worker's cache to `out` channel. - for _, worker := range p.workers { - wg.Add(1) - - go func(ctx context.Context, worker *Worker) { - defer wg.Done() - - respChan := make(chan workerStoreResponse) - req := workerStoreRequest{ - ctx: ctx, - response: respChan, - out: out, - } - - select { - case worker.storeRequest <- req: - // Successfully sent request. - select { - case <-respChan: - // Successfully received response. - return - - case <-ctx.Done(): - // Context canceled. - trace.SpanFromContext(ctx).RecordError(ctx.Err()) - return - } - - case <-ctx.Done(): - // Context canceled. - trace.SpanFromContext(ctx).RecordError(ctx.Err()) - return - } - }(ctx, worker) - } - - // When all iterators are done, close `out` channel. - go func() { - wg.Wait() - close(out) - }() - - if ctx.Err() != nil { - return ctx.Err() - } - - if err = p.conf.Loader.Save(out); err != nil { - return errors.Wrap(err, "while calling p.conf.Loader.Save()") - } - - return nil -} - -func (worker *Worker) handleStore(request workerStoreRequest, cache Cache) { - for item := range cache.Each() { - select { - case request.out <- item: - // Successfully sent item. - - case <-request.ctx.Done(): - // Context canceled. - trace.SpanFromContext(request.ctx).RecordError(request.ctx.Err()) - return - } - } - - response := workerStoreResponse{} - - select { - case request.response <- response: - // Successfully sent response. - - case <-request.ctx.Done(): - // Context canceled. - trace.SpanFromContext(request.ctx).RecordError(request.ctx.Err()) - } -} - -// AddCacheItem adds an item to the worker's cache. -func (p *WorkerPool) AddCacheItem(ctx context.Context, key string, item *CacheItem) (err error) { - worker := p.getWorker(key) - queueGauge := metricWorkerQueue.WithLabelValues("AddCacheItem", worker.name) - queueGauge.Inc() - defer queueGauge.Dec() - respChan := make(chan workerAddCacheItemResponse) - req := workerAddCacheItemRequest{ - ctx: ctx, - response: respChan, - item: item, - } - - select { - case worker.addCacheItemRequest <- req: - // Successfully sent request. - select { - case <-respChan: - // Successfully received response. - return nil - - case <-ctx.Done(): - // Context canceled. - return ctx.Err() - } - - case <-ctx.Done(): - // Context canceled. 
- return ctx.Err() - } -} - -func (worker *Worker) handleAddCacheItem(request workerAddCacheItemRequest, cache Cache) { - exists := cache.Add(request.item) - response := workerAddCacheItemResponse{exists} - - select { - case request.response <- response: - // Successfully sent response. - - case <-request.ctx.Done(): - // Context canceled. - trace.SpanFromContext(request.ctx).RecordError(request.ctx.Err()) - } -} - -// GetCacheItem gets item from worker's cache. -func (p *WorkerPool) GetCacheItem(ctx context.Context, key string) (item *CacheItem, found bool, err error) { - worker := p.getWorker(key) - queueGauge := metricWorkerQueue.WithLabelValues("GetCacheItem", worker.name) - queueGauge.Inc() - defer queueGauge.Dec() - respChan := make(chan workerGetCacheItemResponse) - req := workerGetCacheItemRequest{ - ctx: ctx, - response: respChan, - key: key, - } - - select { - case worker.getCacheItemRequest <- req: - // Successfully sent request. - select { - case resp := <-respChan: - // Successfully received response. - return resp.item, resp.ok, nil - - case <-ctx.Done(): - // Context canceled. - return nil, false, ctx.Err() - } - - case <-ctx.Done(): - // Context canceled. - return nil, false, ctx.Err() - } -} - -func (worker *Worker) handleGetCacheItem(request workerGetCacheItemRequest, cache Cache) { - item, ok := cache.GetItem(request.key) - response := workerGetCacheItemResponse{item, ok} - - select { - case request.response <- response: - // Successfully sent response. - - case <-request.ctx.Done(): - // Context canceled. - trace.SpanFromContext(request.ctx).RecordError(request.ctx.Err()) - } -} diff --git a/workers_internal_test.go b/workers_internal_test.go deleted file mode 100644 index 291971a..0000000 --- a/workers_internal_test.go +++ /dev/null @@ -1,84 +0,0 @@ -/* -Copyright 2024 Mailgun Technologies Inc - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package gubernator - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" -) - -type MockHasher struct { - mock.Mock -} - -func (m *MockHasher) ComputeHash63(input string) uint64 { - args := m.Called(input) - retval, _ := args.Get(0).(uint64) - return retval -} - -func TestWorkersInternal(t *testing.T) { - t.Run("getWorker()", func(t *testing.T) { - const concurrency = 32 - conf := &Config{ - Workers: concurrency, - } - require.NoError(t, conf.SetDefaults()) - - // Test that getWorker() interpolates the hash to find the expected worker. - testCases := []struct { - Name string - Hash uint64 - ExpectedIdx int - }{ - {"Hash 0%", 0, 0}, - {"Hash 50%", 0x3fff_ffff_ffff_ffff, (concurrency / 2) - 1}, - {"Hash 50% + 1", 0x4000_0000_0000_0000, concurrency / 2}, - {"Hash 100%", 0x7fff_ffff_ffff_ffff, concurrency - 1}, - } - - for _, testCase := range testCases { - t.Run(testCase.Name, func(t *testing.T) { - pool := NewWorkerPool(conf) - defer pool.Close() - mockHasher := &MockHasher{} - pool.hasher = mockHasher - - // Setup mocks. 
- mockHasher.On("ComputeHash63", mock.Anything).Once().Return(testCase.Hash) - - // Call code. - worker := pool.getWorker("Foobar") - - // Verify - require.NotNil(t, worker) - - var actualIdx int - for ; actualIdx < len(pool.workers); actualIdx++ { - if pool.workers[actualIdx] == worker { - break - } - } - assert.Equal(t, testCase.ExpectedIdx, actualIdx) - mockHasher.AssertExpectations(t) - }) - } - }) -}
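
For call sites being migrated alongside these tests, here is a minimal sketch of the v3 client pattern the hunks above exercise: build a client with `gubernator.NewClient(gubernator.WithTLS(...))`, then call `CheckRateLimits` with a caller-supplied `CheckRateLimitsResponse`. The listen address and the bare `tls.Config` below are placeholders for illustration only; the type, field, and function names are the ones that appear in the test hunks in this diff.

```go
package main

import (
	"context"
	"crypto/tls"
	"fmt"
	"log"

	"github.com/gubernator-io/gubernator/v3"
)

func main() {
	// Placeholder address and TLS config; in the tests above these come from
	// conf.HTTPListenAddress and conf.ClientTLS() respectively.
	client, err := gubernator.NewClient(
		gubernator.WithTLS(&tls.Config{InsecureSkipVerify: true}, "127.0.0.1:9080"))
	if err != nil {
		log.Fatal(err)
	}

	// v3 fills a caller-supplied response instead of returning one.
	var resp gubernator.CheckRateLimitsResponse
	err = client.CheckRateLimits(context.Background(), &gubernator.CheckRateLimitsRequest{
		Requests: []*gubernator.RateLimitRequest{{
			Name:      "requests_per_sec",
			UniqueKey: "account:12345",
			Duration:  gubernator.Millisecond * 9000,
			Limit:     20,
			Hits:      1,
		}},
	}, &resp)
	if err != nil {
		log.Fatal(err)
	}

	// Each response carries its own Error field, which the updated tests now
	// assert on in addition to checking the overall call error.
	rl := resp.Responses[0]
	if rl.Error != "" {
		log.Fatalf("rate limit error: %s", rl.Error)
	}
	fmt.Println(rl.Status == gubernator.Status_UNDER_LIMIT, rl.Remaining)
}
```

The per-response `Error` check mirrors the `assert.Equal(t, "", resp.Responses[0].Error)` assertions added throughout these tests: a transport-level `err` of nil does not imply every individual rate limit in the batch succeeded.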