diff --git a/README.md b/README.md index a584edc0..6e4cdee5 100644 --- a/README.md +++ b/README.md @@ -501,7 +501,7 @@ $ curl 0:6070/ /stats: print out stats ``` -You can specify the debug port with the `DEBUG_PORT` environment variable. It defaults to `6070`. +You can specify the debug server address with the `DEBUG_HOST` and `DEBUG_PORT` environment variables. They currently default to `0.0.0.0` and `6070` respectively. # Local Cache @@ -582,9 +582,12 @@ Experimental Memcache support has been added as an alternative to Redis in v1.5. To configure a Memcache instance use the following environment variables instead of the Redis variables: -1. `MEMCACHE_HOST_PORT=`: a comma separated list of hostname:port pairs for memcache nodes. +1. `MEMCACHE_HOST_PORT=`: a comma separated list of hostname:port pairs for memcache nodes (mutually exclusive with `MEMCACHE_SRV`) +1. `MEMCACHE_SRV=`: an SRV record to lookup hosts from (mutually exclusive with `MEMCACHE_HOST_PORT`) +1. `MEMCACHE_SRV_REFRESH=0`: refresh the list of hosts every n seconds, if 0 no refreshing will happen, supports duration suffixes: "ns", "us" (or "µs"), "ms", "s", "m", "h". 1. `BACKEND_TYPE=memcache` 1. `CACHE_KEY_PREFIX`: a string to prepend to all cache keys +1. `MEMCACHE_MAX_IDLE_CONNS=2`: the maximum number of idle TCP connections per memcache node, `2` is the default of the underlying library With memcache mode increments will happen asynchronously, so it's technically possible for a client to exceed quota briefly if multiple requests happen at exactly the same time. 
diff --git a/src/memcached/cache_impl.go b/src/memcached/cache_impl.go index 536c2fc2..4b21af33 100644 --- a/src/memcached/cache_impl.go +++ b/src/memcached/cache_impl.go @@ -21,6 +21,7 @@ import ( "math/rand" "strconv" "sync" + "time" "github.com/coocood/freecache" gostats "github.com/lyft/gostats" @@ -34,6 +35,7 @@ import ( "github.com/envoyproxy/ratelimit/src/config" "github.com/envoyproxy/ratelimit/src/limiter" "github.com/envoyproxy/ratelimit/src/settings" + "github.com/envoyproxy/ratelimit/src/srv" "github.com/envoyproxy/ratelimit/src/utils" ) @@ -122,7 +124,7 @@ func (this *rateLimitMemcacheImpl) DoLimit( } this.waitGroup.Add(1) - go this.increaseAsync(cacheKeys, isOverLimitWithLocalCache, limits, uint64(hitsAddend)) + runAsync(func() { this.increaseAsync(cacheKeys, isOverLimitWithLocalCache, limits, uint64(hitsAddend)) }) if AutoFlushForIntegrationTests { this.Flush() } @@ -174,6 +176,104 @@ func (this *rateLimitMemcacheImpl) Flush() { this.waitGroup.Wait() } +func refreshServersPeriodically(serverList memcache.ServerList, srv string, d time.Duration, finish <-chan struct{}) { + t := time.NewTicker(d) + defer t.Stop() + for { + select { + case <-t.C: + err := refreshServers(serverList, srv) + if err != nil { + logger.Warn("failed to refresh memcache hosts") + } else { + logger.Debug("refreshed memcache hosts") + } + case <-finish: + return + } + } +} + +func refreshServers(serverList memcache.ServerList, srv_ string) error { + servers, err := srv.ServerStringsFromSrv(srv_) + if err != nil { + return err + } + err = serverList.SetServers(servers...) 
+ if err != nil { + return err + } + return nil +} + +func newMemcachedFromSrv(srv_ string, d time.Duration) Client { + serverList := new(memcache.ServerList) + err := refreshServers(*serverList, srv_) + if err != nil { + errorText := "Unable to fetch servers from SRV" + logger.Errorf(errorText) + panic(MemcacheError(errorText)) + } + + if d > 0 { + logger.Infof("refreshing memcache hosts every: %v milliseconds", d.Milliseconds()) + finish := make(chan struct{}) + go refreshServersPeriodically(*serverList, srv_, d, finish) + } else { + logger.Debugf("not periodically refreshing memcached hosts") + } + + return memcache.NewFromSelector(serverList) +} + +func newMemcacheFromSettings(s settings.Settings) Client { + if s.MemcacheSrv != "" && len(s.MemcacheHostPort) > 0 { + panic(MemcacheError("Both MEMCACHE_HOST_PORT and MEMCACHE_SRV are set")) + } + if s.MemcacheSrv != "" { + logger.Debugf("Using MEMCACHE_SRV: %v", s.MemcacheSrv) + return newMemcachedFromSrv(s.MemcacheSrv, s.MemcacheSrvRefresh) + } + logger.Debugf("Using MEMCACHE_HOST_PORT: %v", s.MemcacheHostPort) + client := memcache.New(s.MemcacheHostPort...) 
+ client.MaxIdleConns = s.MemcacheMaxIdleConns + return client +} + +var taskQueue = make(chan func()) + +func runAsync(task func()) { + select { + case taskQueue <- task: + // submitted, everything is ok + + default: + go func() { + // do the given task + task() + + tasksProcessedWithinOnePeriod := 0 + const tickDuration = 10 * time.Second + tick := time.NewTicker(tickDuration) + defer tick.Stop() + + for { + select { + case t := <-taskQueue: + t() + tasksProcessedWithinOnePeriod++ + case <-tick.C: + if tasksProcessedWithinOnePeriod > 0 { + tasksProcessedWithinOnePeriod = 0 + continue + } + return + } + } + }() + } +} + func NewRateLimitCacheImpl(client Client, timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64, localCache *freecache.Cache, statsManager stats.Manager, nearLimitRatio float32, cacheKeyPrefix string) limiter.RateLimitCache { return &rateLimitMemcacheImpl{ @@ -190,7 +290,7 @@ func NewRateLimitCacheImpl(client Client, timeSource utils.TimeSource, jitterRan func NewRateLimitCacheImplFromSettings(s settings.Settings, timeSource utils.TimeSource, jitterRand *rand.Rand, localCache *freecache.Cache, scope gostats.Scope, statsManager stats.Manager) limiter.RateLimitCache { return NewRateLimitCacheImpl( - CollectStats(memcache.New(s.MemcacheHostPort...), scope.Scope("memcache")), + CollectStats(newMemcacheFromSettings(s), scope.Scope("memcache")), timeSource, jitterRand, s.ExpirationJitterMaxSeconds, diff --git a/src/memcached/client.go b/src/memcached/client.go index 55c0ec31..e8090269 100644 --- a/src/memcached/client.go +++ b/src/memcached/client.go @@ -4,6 +4,13 @@ import ( "github.com/bradfitz/gomemcache/memcache" ) +// Errors that may be raised during config parsing. +type MemcacheError string + +func (e MemcacheError) Error() string { + return string(e) +} + var _ Client = (*memcache.Client)(nil) // Interface for memcached, used for mocking. 
diff --git a/src/redis/driver_impl.go b/src/redis/driver_impl.go index 18e213f1..f6449ea5 100644 --- a/src/redis/driver_impl.go +++ b/src/redis/driver_impl.go @@ -59,14 +59,10 @@ func NewClientImpl(scope stats.Scope, useTls bool, auth string, redisType string df := func(network, addr string) (radix.Conn, error) { var dialOpts []radix.DialOpt - var err error if useTls { dialOpts = append(dialOpts, radix.DialUseTLS(&tls.Config{})) } - if err != nil { - return nil, err - } if auth != "" { logger.Warnf("enabling authentication to redis on %s", url) diff --git a/src/server/server_impl.go b/src/server/server_impl.go index e96d4e72..15305a32 100644 --- a/src/server/server_impl.go +++ b/src/server/server_impl.go @@ -10,6 +10,7 @@ import ( "net/http/pprof" "path/filepath" "sort" + "strconv" "sync" "os" @@ -40,9 +41,9 @@ type serverDebugListener struct { } type server struct { - port int - grpcPort int - debugPort int + httpAddress string + grpcAddress string + debugAddress string router *mux.Router grpcServer *grpc.Server store gostats.Store @@ -115,11 +116,10 @@ func (server *server) GrpcServer() *grpc.Server { func (server *server) Start() { go func() { - addr := fmt.Sprintf(":%d", server.debugPort) - logger.Warnf("Listening for debug on '%s'", addr) + logger.Warnf("Listening for debug on '%s'", server.debugAddress) var err error server.listenerMu.Lock() - server.debugListener.listener, err = reuseport.Listen("tcp", addr) + server.debugListener.listener, err = reuseport.Listen("tcp", server.debugAddress) server.listenerMu.Unlock() if err != nil { @@ -134,9 +134,8 @@ func (server *server) Start() { server.handleGracefulShutdown() - addr := fmt.Sprintf(":%d", server.port) - logger.Warnf("Listening for HTTP on '%s'", addr) - list, err := reuseport.Listen("tcp", addr) + logger.Warnf("Listening for HTTP on '%s'", server.httpAddress) + list, err := reuseport.Listen("tcp", server.httpAddress) if err != nil { logger.Fatalf("Failed to open HTTP listener: '%+v'", err) } @@ -152,9 
+151,8 @@ func (server *server) Start() { } func (server *server) startGrpc() { - addr := fmt.Sprintf(":%d", server.grpcPort) - logger.Warnf("Listening for gRPC on '%s'", addr) - lis, err := reuseport.Listen("tcp", addr) + logger.Warnf("Listening for gRPC on '%s'", server.grpcAddress) + lis, err := reuseport.Listen("tcp", server.grpcAddress) if err != nil { logger.Fatalf("Failed to listen for gRPC: %v", err) } @@ -181,10 +179,10 @@ func newServer(s settings.Settings, name string, statsManager stats.Manager, loc ret := new(server) ret.grpcServer = grpc.NewServer(s.GrpcUnaryInterceptor) - // setup ports - ret.port = s.Port - ret.grpcPort = s.GrpcPort - ret.debugPort = s.DebugPort + // setup listen addresses + ret.httpAddress = net.JoinHostPort(s.Host, strconv.Itoa(s.Port)) + ret.grpcAddress = net.JoinHostPort(s.GrpcHost, strconv.Itoa(s.GrpcPort)) + ret.debugAddress = net.JoinHostPort(s.DebugHost, strconv.Itoa(s.DebugPort)) // setup stats ret.store = statsManager.GetStatsStore() @@ -255,6 +253,14 @@ func newServer(s settings.Settings, name string, statsManager stats.Manager, loc }) }) + // setup trace endpoint + ret.AddDebugHttpEndpoint( + "/debug/pprof/trace", + "trace endpoint", + func(writer http.ResponseWriter, request *http.Request) { + pprof.Trace(writer, request) + }) + // setup debug root ret.debugListener.debugMux.HandleFunc( "/", diff --git a/src/settings/settings.go b/src/settings/settings.go index 82b8ebe5..2646b5b2 100644 --- a/src/settings/settings.go +++ b/src/settings/settings.go @@ -10,10 +10,13 @@ import ( type Settings struct { // runtime options GrpcUnaryInterceptor grpc.ServerOption - // env config - Port int `envconfig:"PORT" default:"8080"` - GrpcPort int `envconfig:"GRPC_PORT" default:"8081"` - DebugPort int `envconfig:"DEBUG_PORT" default:"6070"` + // Server listen address config + Host string `envconfig:"HOST" default:"0.0.0.0"` + Port int `envconfig:"PORT" default:"8080"` + GrpcHost string `envconfig:"GRPC_HOST" default:"0.0.0.0"` + GrpcPort 
int `envconfig:"GRPC_PORT" default:"8081"` + DebugHost string `envconfig:"DEBUG_HOST" default:"0.0.0.0"` + DebugPort int `envconfig:"DEBUG_PORT" default:"6070"` // Logging settings LogLevel string `envconfig:"LOG_LEVEL" default:"WARN"` @@ -68,6 +71,14 @@ type Settings struct { // Memcache settings MemcacheHostPort []string `envconfig:"MEMCACHE_HOST_PORT" default:""` + // MemcacheMaxIdleConns sets the maximum number of idle TCP connections per memcached node. + // The default is 2 as that is the default of the underlying library. This is the maximum + // number of connections to memcache kept idle in pool, if a connection is needed but none + // are idle a new connection is opened, used and closed and can be left in a time-wait state + // which can result in high CPU usage. + MemcacheMaxIdleConns int `envconfig:"MEMCACHE_MAX_IDLE_CONNS" default:"2"` + MemcacheSrv string `envconfig:"MEMCACHE_SRV" default:""` + MemcacheSrvRefresh time.Duration `envconfig:"MEMCACHE_SRV_REFRESH" default:"0"` } type Option func(*Settings) diff --git a/src/srv/srv.go b/src/srv/srv.go new file mode 100644 index 00000000..041ceb95 --- /dev/null +++ b/src/srv/srv.go @@ -0,0 +1,49 @@ +package srv + +import ( + "errors" + "fmt" + "net" + "regexp" + + logger "github.com/sirupsen/logrus" +) + +var srvRegex = regexp.MustCompile(`^_(.+?)\._(.+?)\.(.+)$`) + +func ParseSrv(srv string) (string, string, string, error) { + matches := srvRegex.FindStringSubmatch(srv) + if matches == nil { + errorText := fmt.Sprintf("could not parse %s to SRV parts", srv) + logger.Errorf(errorText) + return "", "", "", errors.New(errorText) + } + return matches[1], matches[2], matches[3], nil +} + +func ServerStringsFromSrv(srv string) ([]string, error) { + service, proto, name, err := ParseSrv(srv) + + if err != nil { + logger.Errorf("failed to parse SRV: %s", err) + return nil, err + } + + _, srvs, err := net.LookupSRV(service, proto, name) + + if err != nil { + logger.Errorf("failed to lookup SRV: %s", err) + return 
nil, err + } + + logger.Debugf("found %v server(s) from SRV", len(srvs)) + + serversFromSrv := make([]string, len(srvs)) + for i, srv := range srvs { + server := fmt.Sprintf("%s:%v", srv.Target, srv.Port) + logger.Debugf("server from srv[%v]: %s", i, server) + serversFromSrv[i] = fmt.Sprintf("%s:%v", srv.Target, srv.Port) + } + + return serversFromSrv, nil +} diff --git a/test/integration/integration_test.go b/test/integration/integration_test.go index da7cfa6d..249d9b2c 100644 --- a/test/integration/integration_test.go +++ b/test/integration/integration_test.go @@ -209,6 +209,21 @@ func TestBasicConfigMemcache(t *testing.T) { }) } +func TestConfigMemcacheWithMaxIdleConns(t *testing.T) { + singleNodePort := []int{6394} + assert := assert.New(t) + common.WithMultiMemcache(t, []common.MemcacheConfig{ + {Port: 6394}, + }, func() { + withDefaultMaxIdleConns := makeSimpleMemcacheSettings(singleNodePort, 0) + assert.Equal(2, withDefaultMaxIdleConns.MemcacheMaxIdleConns) + t.Run("MemcacheWithDefaultMaxIdleConns", testBasicConfig(withDefaultMaxIdleConns)) + withSpecifiedMaxIdleConns := makeSimpleMemcacheSettings(singleNodePort, 0) + withSpecifiedMaxIdleConns.MemcacheMaxIdleConns = 100 + t.Run("MemcacheWithSpecifiedMaxIdleConns", testBasicConfig(withSpecifiedMaxIdleConns)) + }) +} + func TestMultiNodeMemcache(t *testing.T) { multiNodePorts := []int{6494, 6495} common.WithMultiMemcache(t, []common.MemcacheConfig{ diff --git a/test/memcached/cache_impl_test.go b/test/memcached/cache_impl_test.go index 1358af44..0244d3bf 100644 --- a/test/memcached/cache_impl_test.go +++ b/test/memcached/cache_impl_test.go @@ -17,6 +17,7 @@ import ( "github.com/envoyproxy/ratelimit/src/config" "github.com/envoyproxy/ratelimit/src/limiter" "github.com/envoyproxy/ratelimit/src/memcached" + "github.com/envoyproxy/ratelimit/src/settings" "github.com/envoyproxy/ratelimit/src/utils" stats "github.com/lyft/gostats" @@ -590,6 +591,41 @@ func TestMemcacheAdd(t *testing.T) { cache.Flush() } +func 
TestNewRateLimitCacheImplFromSettingsWhenSrvCannotBeResolved(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + timeSource := mock_utils.NewMockTimeSource(controller) + statsStore := stats.NewStore(stats.NewNullSink(), false) + + var s settings.Settings + s.NearLimitRatio = 0.8 + s.CacheKeyPrefix = "" + s.ExpirationJitterMaxSeconds = 300 + s.MemcacheSrv = "_something._tcp.example.invalid" + + assert.Panics(func() { memcached.NewRateLimitCacheImplFromSettings(s, timeSource, nil, nil, statsStore) }) +} + +func TestNewRateLimitCacheImplFromSettingsWhenHostAndPortAndSrvAreBothSet(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + timeSource := mock_utils.NewMockTimeSource(controller) + statsStore := stats.NewStore(stats.NewNullSink(), false) + + var s settings.Settings + s.NearLimitRatio = 0.8 + s.CacheKeyPrefix = "" + s.ExpirationJitterMaxSeconds = 300 + s.MemcacheSrv = "_something._tcp.example.invalid" + s.MemcacheHostPort = []string{"example.org:11211"} + + assert.Panics(func() { memcached.NewRateLimitCacheImplFromSettings(s, timeSource, nil, nil, statsStore) }) +} + func getMultiResult(vals map[string]int) map[string]*memcache.Item { result := make(map[string]*memcache.Item, len(vals)) for k, v := range vals { diff --git a/test/service/ratelimit_legacy_test.go b/test/service/ratelimit_legacy_test.go index 2b13ac5e..0c0a17aa 100644 --- a/test/service/ratelimit_legacy_test.go +++ b/test/service/ratelimit_legacy_test.go @@ -121,7 +121,7 @@ func TestServiceLegacy(test *testing.T) { t.configLoader.EXPECT().Load( []config.RateLimitConfigToLoad{{"config.basic_config", "fake_yaml"}}, gomock.Any()).Do( func([]config.RateLimitConfigToLoad, stats.Manager) { - barrier.signal() + defer barrier.signal() panic(config.RateLimitConfigError("load error")) }) t.runtimeUpdateCallback <- 1 diff --git a/test/srv/srv_test.go b/test/srv/srv_test.go new file mode 
100644 index 00000000..5e3e8f79 --- /dev/null +++ b/test/srv/srv_test.go @@ -0,0 +1,56 @@ +package srv + +import ( + "errors" + "net" + "testing" + + "github.com/envoyproxy/ratelimit/src/srv" + "github.com/stretchr/testify/assert" +) + +func TestParseSrv(t *testing.T) { + service, proto, name, err := srv.ParseSrv("_something._tcp.example.org.") + assert.Equal(t, service, "something") + assert.Equal(t, proto, "tcp") + assert.Equal(t, name, "example.org.") + assert.Nil(t, err) + + service, proto, name, err = srv.ParseSrv("_something-else._udp.example.org") + assert.Equal(t, service, "something-else") + assert.Equal(t, proto, "udp") + assert.Equal(t, name, "example.org") + assert.Nil(t, err) + + _, _, _, err = srv.ParseSrv("example.org") + assert.Equal(t, err, errors.New("could not parse example.org to SRV parts")) +} + +func TestServerStringsFromSrvWhenSrvIsNotWellFormed(t *testing.T) { + _, err := srv.ServerStringsFromSrv("example.org") + assert.Equal(t, err, errors.New("could not parse example.org to SRV parts")) +} + +func TestServerStringsFromSrvWhenSrvIsWellFormedButNotLookupable(t *testing.T) { + _, err := srv.ServerStringsFromSrv("_something._tcp.example.invalid") + var e *net.DNSError + if errors.As(err, &e) { + assert.Equal(t, e.Err, "no such host") + assert.Equal(t, e.Name, "_something._tcp.example.invalid") + assert.False(t, e.IsTimeout) + assert.False(t, e.IsTemporary) + assert.True(t, e.IsNotFound) + } else { + t.Fail() + } +} + +func TestServerStrings(t *testing.T) { + // it seems reasonable to think _xmpp-server._tcp.gmail.com will be available for a long time! + servers, err := srv.ServerStringsFromSrv("_xmpp-server._tcp.gmail.com.") + assert.True(t, len(servers) > 0) + for _, s := range servers { + assert.Regexp(t, `^.*xmpp-server.*google.com.:\d+$`, s) + } + assert.Nil(t, err) +}