From 6becfe9bb3e08ec7ba0e1b3f4446d36189926a9e Mon Sep 17 00:00:00 2001 From: Peter Marsh Date: Wed, 5 May 2021 16:08:57 +0200 Subject: [PATCH] Add support for SRV records for memcache clients This allows MEMCAHE_SRV to be specified as an SRV record from which multiple memcache hosts can be resolved. For example: MEMCACHE_SRV=_memcache._tcp.mylovelydomain.com This can be used instead of MEMCACHE_HOST_PORT. This will then be resolved and whatever set of servers it represents will be used as the set of memcache servers to connect to. At this stage neither priority or weight is supported, though weight could be fairly straightforwardly in future. The SRV can be polled periodically for new servers by setting the following env var (with 0 meaning "never check"): MEMCACHE_SRV_REFRESH=600s # supports standard go durations --- README.md | 4 +- src/memcached/cache_impl.go | 66 ++++++++++++++++++++++++++++++- src/memcached/client.go | 7 ++++ src/settings/settings.go | 6 ++- src/srv/srv.go | 49 +++++++++++++++++++++++ test/memcached/cache_impl_test.go | 36 +++++++++++++++++ test/srv/srv_test.go | 56 ++++++++++++++++++++++++++ 7 files changed, 220 insertions(+), 4 deletions(-) create mode 100644 src/srv/srv.go create mode 100644 test/srv/srv_test.go diff --git a/README.md b/README.md index a584edc0d..c47f3be4e 100644 --- a/README.md +++ b/README.md @@ -582,7 +582,9 @@ Experimental Memcache support has been added as an alternative to Redis in v1.5. To configure a Memcache instance use the following environment variables instead of the Redis variables: -1. `MEMCACHE_HOST_PORT=`: a comma separated list of hostname:port pairs for memcache nodes. +1. `MEMCACHE_HOST_PORT=`: a comma separated list of hostname:port pairs for memcache nodes (mutually exclusive with `MEMCACHE_SRV`) +1. `MEMCACHE_SRV=`: an SRV record to lookup hosts from (mutually exclusive with `MEMCACHE_HOST_PORT`) +1. `MEMCACHE_SRV_REFRESH=0`: refresh the list of hosts every n seconds, if 0 no refreshing will happen, supports duration suffixes: "ns", "us" (or "µs"), "ms", "s", "m", "h". 1. `BACKEND_TYPE=memcache` 1. `CACHE_KEY_PREFIX`: a string to prepend to all cache keys diff --git a/src/memcached/cache_impl.go b/src/memcached/cache_impl.go index 1e7b0b69a..17a71261c 100644 --- a/src/memcached/cache_impl.go +++ b/src/memcached/cache_impl.go @@ -20,6 +20,7 @@ import ( "math/rand" "strconv" "sync" + "time" "github.com/coocood/freecache" stats "github.com/lyft/gostats" @@ -33,6 +34,7 @@ import ( "github.com/envoyproxy/ratelimit/src/config" "github.com/envoyproxy/ratelimit/src/limiter" "github.com/envoyproxy/ratelimit/src/settings" + "github.com/envoyproxy/ratelimit/src/srv" "github.com/envoyproxy/ratelimit/src/utils" ) @@ -173,6 +175,68 @@ func (this *rateLimitMemcacheImpl) Flush() { this.waitGroup.Wait() } +func refreshServersPeriodically(serverList memcache.ServerList, srv string, d time.Duration, finish <-chan struct{}) { + t := time.NewTicker(d) + defer t.Stop() + for { + select { + case <-t.C: + err := refreshServers(serverList, srv) + if err != nil { + logger.Warn("failed to refresh memcahce hosts") + } else { + logger.Debug("refreshed memcache hosts") + } + case <-finish: + return + } + } +} + +func refreshServers(serverList memcache.ServerList, srv_ string) error { + servers, err := srv.ServerStringsFromSrv(srv_) + if err != nil { + return err + } + err = serverList.SetServers(servers...) + if err != nil { + return err + } + return nil +} + +func newMemcachedFromSrv(srv_ string, d time.Duration) Client { + serverList := new(memcache.ServerList) + err := refreshServers(*serverList, srv_) + if err != nil { + errorText := "Unable to fetch servers from SRV" + logger.Errorf(errorText) + panic(MemcacheError(errorText)) + } + + if d > 0 { + logger.Infof("refreshing memcache hosts every: %v milliseconds", d.Milliseconds()) + finish := make(chan struct{}) + go refreshServersPeriodically(*serverList, srv_, d, finish) + } else { + logger.Debugf("not periodically refreshing memcached hosts") + } + + return memcache.NewFromSelector(serverList) +} + +func newMemcacheFromSettings(s settings.Settings) Client { + if s.MemcacheSrv != "" && len(s.MemcacheHostPort) > 0 { + panic(MemcacheError("Both MEMCADHE_HOST_PORT and MEMCACHE_SRV are set")) + } + if s.MemcacheSrv != "" { + logger.Debugf("Using MEMCACHE_SRV: %v", s.MemcacheSrv) + return newMemcachedFromSrv(s.MemcacheSrv, s.MemcacheSrvSeconds) + } + logger.Debugf("Usng MEMCACHE_HOST_PORT:: %v", s.MemcacheHostPort) + return memcache.New(s.MemcacheHostPort...) +} + func NewRateLimitCacheImpl(client Client, timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64, localCache *freecache.Cache, scope stats.Scope, nearLimitRatio float32, cacheKeyPrefix string) limiter.RateLimitCache { return &rateLimitMemcacheImpl{ @@ -189,7 +253,7 @@ func NewRateLimitCacheImpl(client Client, timeSource utils.TimeSource, jitterRan func NewRateLimitCacheImplFromSettings(s settings.Settings, timeSource utils.TimeSource, jitterRand *rand.Rand, localCache *freecache.Cache, scope stats.Scope) limiter.RateLimitCache { return NewRateLimitCacheImpl( - CollectStats(memcache.New(s.MemcacheHostPort...), scope.Scope("memcache")), + CollectStats(newMemcacheFromSettings(s), scope.Scope("memcache")), timeSource, jitterRand, s.ExpirationJitterMaxSeconds, diff --git a/src/memcached/client.go b/src/memcached/client.go index 55c0ec318..e80902692 100644 --- a/src/memcached/client.go +++ b/src/memcached/client.go @@ -4,6 +4,13 @@ import ( "github.com/bradfitz/gomemcache/memcache" ) +// Errors that may be raised during config parsing. +type MemcacheError string + +func (e MemcacheError) Error() string { + return string(e) +} + var _ Client = (*memcache.Client)(nil) // Interface for memcached, used for mocking. diff --git a/src/settings/settings.go b/src/settings/settings.go index 82b8ebe59..b7edf7ce1 100644 --- a/src/settings/settings.go +++ b/src/settings/settings.go @@ -59,7 +59,7 @@ type Settings struct { RedisPerSecondPoolSize int `envconfig:"REDIS_PERSECOND_POOL_SIZE" default:"10"` RedisPerSecondAuth string `envconfig:"REDIS_PERSECOND_AUTH" default:""` RedisPerSecondTls bool `envconfig:"REDIS_PERSECOND_TLS" default:"false"` - // RedisPerSecondPipelineWindow sets the duration after which internal pipelines will be flushed for per second redis. + // RedisPerSec ondPipelineWindow sets the duration after which internal pipelines will be flushed for per second redis. // See comments of RedisPipelineWindow for details. RedisPerSecondPipelineWindow time.Duration `envconfig:"REDIS_PERSECOND_PIPELINE_WINDOW" default:"0"` // RedisPerSecondPipelineLimit sets maximum number of commands that can be pipelined before flushing for per second redis. @@ -67,7 +67,9 @@ type Settings struct { RedisPerSecondPipelineLimit int `envconfig:"REDIS_PERSECOND_PIPELINE_LIMIT" default:"0"` // Memcache settings - MemcacheHostPort []string `envconfig:"MEMCACHE_HOST_PORT" default:""` + MemcacheHostPort []string `envconfig:"MEMCACHE_HOST_PORT" default:""` + MemcacheSrv string `envconfig:"MEMCACHE_SRV" default:""` + MemcacheSrvRefresh time.Duration `envconfig:"MEMCACHE_SRV_REFRESH" default:"0"` } type Option func(*Settings) diff --git a/src/srv/srv.go b/src/srv/srv.go new file mode 100644 index 000000000..041ceb950 --- /dev/null +++ b/src/srv/srv.go @@ -0,0 +1,49 @@ +package srv + +import ( + "errors" + "fmt" + "net" + "regexp" + + logger "github.com/sirupsen/logrus" +) + +var srvRegex = regexp.MustCompile(`^_(.+?)\._(.+?)\.(.+)$`) + +func ParseSrv(srv string) (string, string, string, error) { + matches := srvRegex.FindStringSubmatch(srv) + if matches == nil { + errorText := fmt.Sprintf("could not parse %s to SRV parts", srv) + logger.Errorf(errorText) + return "", "", "", errors.New(errorText) + } + return matches[1], matches[2], matches[3], nil +} + +func ServerStringsFromSrv(srv string) ([]string, error) { + service, proto, name, err := ParseSrv(srv) + + if err != nil { + logger.Errorf("failed to parse SRV: %s", err) + return nil, err + } + + _, srvs, err := net.LookupSRV(service, proto, name) + + if err != nil { + logger.Errorf("failed to lookup SRV: %s", err) + return nil, err + } + + logger.Debugf("found %v servers(s) from SRV", len(srvs)) + + serversFromSrv := make([]string, len(srvs)) + for i, srv := range srvs { + server := fmt.Sprintf("%s:%v", srv.Target, srv.Port) + logger.Debugf("server from srv[%v]: %s", i, server) + serversFromSrv[i] = fmt.Sprintf("%s:%v", srv.Target, srv.Port) + } + + return serversFromSrv, nil +} diff --git a/test/memcached/cache_impl_test.go b/test/memcached/cache_impl_test.go index 1e2ba8d77..652ac7c7a 100644 --- a/test/memcached/cache_impl_test.go +++ b/test/memcached/cache_impl_test.go @@ -16,6 +16,7 @@ import ( "github.com/envoyproxy/ratelimit/src/config" "github.com/envoyproxy/ratelimit/src/limiter" "github.com/envoyproxy/ratelimit/src/memcached" + "github.com/envoyproxy/ratelimit/src/settings" "github.com/envoyproxy/ratelimit/src/utils" stats "github.com/lyft/gostats" @@ -583,6 +584,41 @@ func TestMemcacheAdd(t *testing.T) { cache.Flush() } +func TestNewRateLimitCacheImplFromSettingsWhenSrvCannotBeResolved(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + timeSource := mock_utils.NewMockTimeSource(controller) + statsStore := stats.NewStore(stats.NewNullSink(), false) + + var s settings.Settings + s.NearLimitRatio = 0.8 + s.CacheKeyPrefix = "" + s.ExpirationJitterMaxSeconds = 300 + s.MemcacheSrv = "_something._tcp.example.invalid" + + assert.Panics(func() { memcached.NewRateLimitCacheImplFromSettings(s, timeSource, nil, nil, statsStore) }) +} + +func TestNewRateLimitCacheImplFromSettingsWhenHostAndPortAndSrvAreBothSet(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + + timeSource := mock_utils.NewMockTimeSource(controller) + statsStore := stats.NewStore(stats.NewNullSink(), false) + + var s settings.Settings + s.NearLimitRatio = 0.8 + s.CacheKeyPrefix = "" + s.ExpirationJitterMaxSeconds = 300 + s.MemcacheSrv = "_something._tcp.example.invalid" + s.MemcacheHostPort = []string{"example.org:11211"} + + assert.Panics(func() { memcached.NewRateLimitCacheImplFromSettings(s, timeSource, nil, nil, statsStore) }) +} + func getMultiResult(vals map[string]int) map[string]*memcache.Item { result := make(map[string]*memcache.Item, len(vals)) for k, v := range vals { diff --git a/test/srv/srv_test.go b/test/srv/srv_test.go new file mode 100644 index 000000000..5e3e8f79f --- /dev/null +++ b/test/srv/srv_test.go @@ -0,0 +1,56 @@ +package srv + +import ( + "errors" + "net" + "testing" + + "github.com/envoyproxy/ratelimit/src/srv" + "github.com/stretchr/testify/assert" +) + +func TestParseSrv(t *testing.T) { + service, proto, name, err := srv.ParseSrv("_something._tcp.example.org.") + assert.Equal(t, service, "something") + assert.Equal(t, proto, "tcp") + assert.Equal(t, name, "example.org.") + assert.Nil(t, err) + + service, proto, name, err = srv.ParseSrv("_something-else._udp.example.org") + assert.Equal(t, service, "something-else") + assert.Equal(t, proto, "udp") + assert.Equal(t, name, "example.org") + assert.Nil(t, err) + + _, _, _, err = srv.ParseSrv("example.org") + assert.Equal(t, err, errors.New("could not parse example.org to SRV parts")) +} + +func TestServerStringsFromSrvWhenSrvIsNotWellFormed(t *testing.T) { + _, err := srv.ServerStringsFromSrv("example.org") + assert.Equal(t, err, errors.New("could not parse example.org to SRV parts")) +} + +func TestServerStringsFromSevWhenSrvIsWellFormedButNotLookupable(t *testing.T) { + _, err := srv.ServerStringsFromSrv("_something._tcp.example.invalid") + var e *net.DNSError + if errors.As(err, &e) { + assert.Equal(t, e.Err, "no such host") + assert.Equal(t, e.Name, "_something._tcp.example.invalid") + assert.False(t, e.IsTimeout) + assert.False(t, e.IsTemporary) + assert.True(t, e.IsNotFound) + } else { + t.Fail() + } +} + +func TestServerStrings(t *testing.T) { + // it seems reasonable to think _xmpp-server._tcp.gmail.com will be available for a long time! + servers, err := srv.ServerStringsFromSrv("_xmpp-server._tcp.gmail.com.") + assert.True(t, len(servers) > 0) + for _, s := range servers { + assert.Regexp(t, `^.*xmpp-server.*google.com.:\d+$`, s) + } + assert.Nil(t, err) +}