diff --git a/README.md b/README.md index a584edc0d..c47f3be4e 100644 --- a/README.md +++ b/README.md @@ -582,7 +582,9 @@ Experimental Memcache support has been added as an alternative to Redis in v1.5. To configure a Memcache instance use the following environment variables instead of the Redis variables: -1. `MEMCACHE_HOST_PORT=`: a comma separated list of hostname:port pairs for memcache nodes. +1. `MEMCACHE_HOST_PORT=`: a comma separated list of hostname:port pairs for memcache nodes (mutually exclusive with `MEMCACHE_SRV`) +1. `MEMCACHE_SRV=`: an SRV record to lookup hosts from (mutually exclusive with `MEMCACHE_HOST_PORT`) +1. `MEMCACHE_SRV_REFRESH=0`: refresh the list of hosts every n seconds, if 0 no refreshing will happen, supports duration suffixes: "ns", "us" (or "µs"), "ms", "s", "m", "h". 1. `BACKEND_TYPE=memcache` 1. `CACHE_KEY_PREFIX`: a string to prepend to all cache keys diff --git a/src/memcached/cache_impl.go b/src/memcached/cache_impl.go index 1e7b0b69a..17a71261c 100644 --- a/src/memcached/cache_impl.go +++ b/src/memcached/cache_impl.go @@ -20,6 +20,7 @@ import ( "math/rand" "strconv" "sync" + "time" "github.com/coocood/freecache" stats "github.com/lyft/gostats" @@ -33,6 +34,7 @@ import ( "github.com/envoyproxy/ratelimit/src/config" "github.com/envoyproxy/ratelimit/src/limiter" "github.com/envoyproxy/ratelimit/src/settings" + "github.com/envoyproxy/ratelimit/src/srv" "github.com/envoyproxy/ratelimit/src/utils" ) @@ -173,6 +175,68 @@ func (this *rateLimitMemcacheImpl) Flush() { this.waitGroup.Wait() } +func refreshServersPeriodically(serverList memcache.ServerList, srv string, d time.Duration, finish <-chan struct{}) { + t := time.NewTicker(d) + defer t.Stop() + for { + select { + case <-t.C: + err := refreshServers(serverList, srv) + if err != nil { + logger.Warn("failed to refresh memcahce hosts") + } else { + logger.Debug("refreshed memcache hosts") + } + case <-finish: + return + } + } +} + +func refreshServers(serverList memcache.ServerList, srv_ string) error { + servers, err := srv.ServerStringsFromSrv(srv_) + if err != nil { + return err + } + err = serverList.SetServers(servers...) + if err != nil { + return err + } + return nil +} + +func newMemcachedFromSrv(srv_ string, d time.Duration) Client { + serverList := new(memcache.ServerList) + err := refreshServers(*serverList, srv_) + if err != nil { + errorText := "Unable to fetch servers from SRV" + logger.Errorf(errorText) + panic(MemcacheError(errorText)) + } + + if d > 0 { + logger.Infof("refreshing memcache hosts every: %v milliseconds", d.Milliseconds()) + finish := make(chan struct{}) + go refreshServersPeriodically(*serverList, srv_, d, finish) + } else { + logger.Debugf("not periodically refreshing memcached hosts") + } + + return memcache.NewFromSelector(serverList) +} + +func newMemcacheFromSettings(s settings.Settings) Client { + if s.MemcacheSrv != "" && len(s.MemcacheHostPort) > 0 { + panic(MemcacheError("Both MEMCADHE_HOST_PORT and MEMCACHE_SRV are set")) + } + if s.MemcacheSrv != "" { + logger.Debugf("Using MEMCACHE_SRV: %v", s.MemcacheSrv) + return newMemcachedFromSrv(s.MemcacheSrv, s.MemcacheSrvSeconds) + } + logger.Debugf("Usng MEMCACHE_HOST_PORT:: %v", s.MemcacheHostPort) + return memcache.New(s.MemcacheHostPort...) +} + func NewRateLimitCacheImpl(client Client, timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64, localCache *freecache.Cache, scope stats.Scope, nearLimitRatio float32, cacheKeyPrefix string) limiter.RateLimitCache { return &rateLimitMemcacheImpl{ @@ -189,7 +253,7 @@ func NewRateLimitCacheImpl(client Client, timeSource utils.TimeSource, jitterRan func NewRateLimitCacheImplFromSettings(s settings.Settings, timeSource utils.TimeSource, jitterRand *rand.Rand, localCache *freecache.Cache, scope stats.Scope) limiter.RateLimitCache { return NewRateLimitCacheImpl( - CollectStats(memcache.New(s.MemcacheHostPort...), scope.Scope("memcache")), + CollectStats(newMemcacheFromSettings(s), scope.Scope("memcache")), timeSource, jitterRand, s.ExpirationJitterMaxSeconds, diff --git a/src/memcached/client.go b/src/memcached/client.go index 55c0ec318..e80902692 100644 --- a/src/memcached/client.go +++ b/src/memcached/client.go @@ -4,6 +4,13 @@ import ( "github.com/bradfitz/gomemcache/memcache" ) +// Errors that may be raised during config parsing. +type MemcacheError string + +func (e MemcacheError) Error() string { + return string(e) +} + var _ Client = (*memcache.Client)(nil) // Interface for memcached, used for mocking. diff --git a/src/settings/settings.go b/src/settings/settings.go index 82b8ebe59..b7edf7ce1 100644 --- a/src/settings/settings.go +++ b/src/settings/settings.go @@ -59,7 +59,7 @@ type Settings struct { RedisPerSecondPoolSize int `envconfig:"REDIS_PERSECOND_POOL_SIZE" default:"10"` RedisPerSecondAuth string `envconfig:"REDIS_PERSECOND_AUTH" default:""` RedisPerSecondTls bool `envconfig:"REDIS_PERSECOND_TLS" default:"false"` - // RedisPerSecondPipelineWindow sets the duration after which internal pipelines will be flushed for per second redis. + // RedisPerSec ondPipelineWindow sets the duration after which internal pipelines will be flushed for per second redis. // See comments of RedisPipelineWindow for details. RedisPerSecondPipelineWindow time.Duration `envconfig:"REDIS_PERSECOND_PIPELINE_WINDOW" default:"0"` // RedisPerSecondPipelineLimit sets maximum number of commands that can be pipelined before flushing for per second redis. @@ -67,7 +67,9 @@ type Settings struct { RedisPerSecondPipelineLimit int `envconfig:"REDIS_PERSECOND_PIPELINE_LIMIT" default:"0"` // Memcache settings - MemcacheHostPort []string `envconfig:"MEMCACHE_HOST_PORT" default:""` + MemcacheHostPort []string `envconfig:"MEMCACHE_HOST_PORT" default:""` + MemcacheSrv string `envconfig:"MEMCACHE_SRV" default:""` + MemcacheSrvRefresh time.Duration `envconfig:"MEMCACHE_SRV_REFRESH" default:"0"` } type Option func(*Settings) diff --git a/src/srv/srv.go b/src/srv/srv.go new file mode 100644 index 000000000..041ceb950 --- /dev/null +++ b/src/srv/srv.go @@ -0,0 +1,49 @@ +package srv + +import ( + "errors" + "fmt" + "net" + "regexp" + + logger "github.com/sirupsen/logrus" +) + +var srvRegex = regexp.MustCompile(`^_(.+?)\._(.+?)\.(.+)$`) + +func ParseSrv(srv string) (string, string, string, error) { + matches := srvRegex.FindStringSubmatch(srv) + if matches == nil { + errorText := fmt.Sprintf("could not parse %s to SRV parts", srv) + logger.Errorf(errorText) + return "", "", "", errors.New(errorText) + } + return matches[1], matches[2], matches[3], nil +} + +func ServerStringsFromSrv(srv string) ([]string, error) { + service, proto, name, err := ParseSrv(srv) + + if err != nil { + logger.Errorf("failed to parse SRV: %s", err) + return nil, err + } + + _, srvs, err := net.LookupSRV(service, proto, name) + + if err != nil { + logger.Errorf("failed to lookup SRV: %s", err) + return nil, err + } + + logger.Debugf("found %v servers(s) from SRV", len(srvs)) + + serversFromSrv := make([]string, len(srvs)) + for i, srv := range srvs { + server := fmt.Sprintf("%s:%v", srv.Target, srv.Port) + logger.Debugf("server from srv[%v]: %s", i, server) + serversFromSrv[i] = fmt.Sprintf("%s:%v", srv.Target, srv.Port) + } + + return serversFromSrv, nil +} diff --git a/test/memcached/cache_impl_test.go b/test/memcached/cache_impl_test.go index 1e2ba8d77..652ac7c7a 100644 --- a/test/memcached/cache_impl_test.go +++ b/test/memcached/cache_impl_test.go @@ -16,6 +16,7 @@ import ( "github.com/envoyproxy/ratelimit/src/config" "github.com/envoyproxy/ratelimit/src/limiter" "github.com/envoyproxy/ratelimit/src/memcached" + "github.com/envoyproxy/ratelimit/src/settings" "github.com/envoyproxy/ratelimit/src/utils" stats "github.com/lyft/gostats" @@ -583,6 +584,41 @@ func TestMemcacheAdd(t *testing.T) { cache.Flush() } +func TestNewRateLimitCacheImplFromSettingsWhenSrvCannotBeResolved(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + timeSource := mock_utils.NewMockTimeSource(controller) + statsStore := stats.NewStore(stats.NewNullSink(), false) + + var s settings.Settings + s.NearLimitRatio = 0.8 + s.CacheKeyPrefix = "" + s.ExpirationJitterMaxSeconds = 300 + s.MemcacheSrv = "_something._tcp.example.invalid" + + assert.Panics(func() { memcached.NewRateLimitCacheImplFromSettings(s, timeSource, nil, nil, statsStore) }) +} + +func TestNewRateLimitCacheImplFromSettingsWhenHostAndPortAndSrvAreBothSet(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + timeSource := mock_utils.NewMockTimeSource(controller) + statsStore := stats.NewStore(stats.NewNullSink(), false) + + var s settings.Settings + s.NearLimitRatio = 0.8 + s.CacheKeyPrefix = "" + s.ExpirationJitterMaxSeconds = 300 + s.MemcacheSrv = "_something._tcp.example.invalid" + s.MemcacheHostPort = []string{"example.org:11211"} + + assert.Panics(func() { memcached.NewRateLimitCacheImplFromSettings(s, timeSource, nil, nil, statsStore) }) +} + func getMultiResult(vals map[string]int) map[string]*memcache.Item { result := make(map[string]*memcache.Item, len(vals)) for k, v := range vals { diff --git a/test/srv/srv_test.go b/test/srv/srv_test.go new file mode 100644 index 000000000..5e3e8f79f --- /dev/null +++ b/test/srv/srv_test.go @@ -0,0 +1,56 @@ +package srv + +import ( + "errors" + "net" + "testing" + + "github.com/envoyproxy/ratelimit/src/srv" + "github.com/stretchr/testify/assert" +) + +func TestParseSrv(t *testing.T) { + service, proto, name, err := srv.ParseSrv("_something._tcp.example.org.") + assert.Equal(t, service, "something") + assert.Equal(t, proto, "tcp") + assert.Equal(t, name, "example.org.") + assert.Nil(t, err) + + service, proto, name, err = srv.ParseSrv("_something-else._udp.example.org") + assert.Equal(t, service, "something-else") + assert.Equal(t, proto, "udp") + assert.Equal(t, name, "example.org") + assert.Nil(t, err) + + _, _, _, err = srv.ParseSrv("example.org") + assert.Equal(t, err, errors.New("could not parse example.org to SRV parts")) +} + +func TestServerStringsFromSrvWhenSrvIsNotWellFormed(t *testing.T) { + _, err := srv.ServerStringsFromSrv("example.org") + assert.Equal(t, err, errors.New("could not parse example.org to SRV parts")) +} + +func TestServerStringsFromSevWhenSrvIsWellFormedButNotLookupable(t *testing.T) { + _, err := srv.ServerStringsFromSrv("_something._tcp.example.invalid") + var e *net.DNSError + if errors.As(err, &e) { + assert.Equal(t, e.Err, "no such host") + assert.Equal(t, e.Name, "_something._tcp.example.invalid") + assert.False(t, e.IsTimeout) + assert.False(t, e.IsTemporary) + assert.True(t, e.IsNotFound) + } else { + t.Fail() + } +} + +func TestServerStrings(t *testing.T) { + // it seems reasonable to think _xmpp-server._tcp.gmail.com will be available for a long time! + servers, err := srv.ServerStringsFromSrv("_xmpp-server._tcp.gmail.com.") + assert.True(t, len(servers) > 0) + for _, s := range servers { + assert.Regexp(t, `^.*xmpp-server.*google.com.:\d+$`, s) + } + assert.Nil(t, err) +}