Skip to content

Commit

Permalink
Fix errors with reload lua scripts (#124)
Browse files Browse the repository at this point in the history
* refactor(store/redis): Fix errors with 'SCRIPT FLUSH' command

* refactor(store/redis): Fix benchmark names
  • Loading branch information
novln authored Nov 9, 2020
1 parent aa44c77 commit b298ea8
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 100 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ _Dead simple rate limit middleware for Go._
Using [Go Modules](https://github.com/golang/go/wiki/Modules)

```bash
$ go get github.com/ulule/limiter/v3@v3.5.0
$ go get github.com/ulule/limiter/v3@v3.7.1
```

## Usage
Expand Down Expand Up @@ -79,7 +79,6 @@ import "github.com/ulule/limiter/v3/drivers/store/redis"

store, err := redis.NewStoreWithOptions(pool, limiter.StoreOptions{
Prefix: "your_own_prefix",
MaxRetry: 4,
})
if err != nil {
panic(err)
Expand Down
178 changes: 114 additions & 64 deletions drivers/store/redis/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"strings"
"sync"
"sync/atomic"
"time"

libredis "github.com/go-redis/redis/v8"
Expand Down Expand Up @@ -55,18 +56,19 @@ type Client interface {
type Store struct {
// Prefix used for the key.
Prefix string
// deprecated, this option make no sense when all operations were atomic
// MaxRetry is the maximum number of retry under race conditions.
// Deprecated: this option is no longer required since all operations are atomic now.
MaxRetry int
// client used to communicate with redis server.
client Client
// luaIncrSHA is the SHA of increase and expire key script
// luaMutex is a mutex used to avoid concurrent access on luaIncrSHA and luaPeekSHA.
luaMutex sync.RWMutex
// luaLoaded is used for CAS and reduce pressure on luaMutex.
luaLoaded uint32
// luaIncrSHA is the SHA of increase and expire key script.
luaIncrSHA string
// luaPeekSHA is the SHA of peek and expire key script
// luaPeekSHA is the SHA of peek and expire key script.
luaPeekSHA string
// hasLuaScriptLoaded was used to check whether the lua script was loaded or not
hasLuaScriptLoaded bool
mu sync.Mutex
}

// NewStore returns an instance of redis store with defaults.
Expand All @@ -81,125 +83,173 @@ func NewStore(client Client) (limiter.Store, error) {
// NewStoreWithOptions returns an instance of redis store with options.
func NewStoreWithOptions(client Client, options limiter.StoreOptions) (limiter.Store, error) {
store := &Store{
client: client,
Prefix: options.Prefix,
MaxRetry: options.MaxRetry,
hasLuaScriptLoaded: false,
client: client,
Prefix: options.Prefix,
MaxRetry: options.MaxRetry,
}

if store.MaxRetry <= 0 {
store.MaxRetry = 1
}
if err := store.preloadLuaScripts(context.Background()); err != nil {
err := store.preloadLuaScripts(context.Background())
if err != nil {
return nil, err
}
return store, nil
}

// preloadLuaScripts would preload the incr and peek lua script
func (store *Store) preloadLuaScripts(ctx context.Context) error {
store.mu.Lock()
defer store.mu.Unlock()
if store.hasLuaScriptLoaded {
return nil
}
incrLuaSHA, err := store.client.ScriptLoad(ctx, luaIncrScript).Result()
if err != nil {
return errors.Wrap(err, "failed to load incr lua script")
}
peekLuaSHA, err := store.client.ScriptLoad(ctx, luaPeekScript).Result()
if err != nil {
return errors.Wrap(err, "failed to load peek lua script")
}
store.luaIncrSHA = incrLuaSHA
store.luaPeekSHA = peekLuaSHA
store.hasLuaScriptLoaded = true
return nil
return store, nil
}

// Get returns the limit for given identifier.
func (store *Store) Get(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) {
key = fmt.Sprintf("%s:%s", store.Prefix, key)
cmd := store.evalSHA(ctx, store.luaIncrSHA, []string{key}, 1, rate.Period.Milliseconds())
cmd := store.evalSHA(ctx, store.getLuaIncrSHA, []string{key}, 1, rate.Period.Milliseconds())
count, ttl, err := parseCountAndTTL(cmd)
if err != nil {
return limiter.Context{}, err
}

now := time.Now()
expiration := now.Add(rate.Period)
if ttl > 0 {
expiration = now.Add(time.Duration(ttl) * time.Millisecond)
}

return common.GetContextFromState(now, rate, expiration, count), nil
}

// Peek returns the limit for given identifier, without modification on current values.
func (store *Store) Peek(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) {
key = fmt.Sprintf("%s:%s", store.Prefix, key)
cmd := store.evalSHA(ctx, store.luaPeekSHA, []string{key})
cmd := store.evalSHA(ctx, store.getLuaPeekSHA, []string{key})
count, ttl, err := parseCountAndTTL(cmd)
if err != nil {
return limiter.Context{}, err
}

now := time.Now()
expiration := now.Add(rate.Period)
if ttl > 0 {
expiration = now.Add(time.Duration(ttl) * time.Millisecond)
}

return common.GetContextFromState(now, rate, expiration, count), nil
}

// Reset returns the limit for given identifier which is set to zero.
func (store *Store) Reset(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) {
key = fmt.Sprintf("%s:%s", store.Prefix, key)
if _, err := store.client.Del(ctx, key).Result(); err != nil {
_, err := store.client.Del(ctx, key).Result()
if err != nil {
return limiter.Context{}, err
}

count := int64(0)
now := time.Now()
expiration := now.Add(rate.Period)

return common.GetContextFromState(now, rate, expiration, count), nil
}

// evalSHA eval the redis lua sha and load the script if missing
func (store *Store) evalSHA(ctx context.Context, sha string, keys []string, args ...interface{}) *libredis.Cmd {
cmd := store.client.EvalSha(ctx, sha, keys, args...)
if err := cmd.Err(); err != nil {
if !isLuaScriptGone(err) {
return cmd
}
store.mu.Lock()
store.hasLuaScriptLoaded = false
store.mu.Unlock()
if err := store.preloadLuaScripts(ctx); err != nil {
cmd = libredis.NewCmd(ctx)
cmd.SetErr(err)
return cmd
}
cmd = store.client.EvalSha(ctx, sha, keys)
}
return cmd
// preloadLuaScripts preloads the "incr" and "peek" lua scripts.
func (store *Store) preloadLuaScripts(ctx context.Context) error {
// Verify if we need to load lua scripts.
// Inspired by sync.Once.
if atomic.LoadUint32(&store.luaLoaded) == 0 {
return store.loadLuaScripts(ctx)
}
return nil
}

// reloadLuaScripts forces a reload of "incr" and "peek" lua scripts.
func (store *Store) reloadLuaScripts(ctx context.Context) error {
// Reset lua scripts loaded state.
// Inspired by sync.Once.
atomic.StoreUint32(&store.luaLoaded, 0)
return store.loadLuaScripts(ctx)
}

// isLuaScriptGone check whether the error was no script or no
// loadLuaScripts load "incr" and "peek" lua scripts.
// WARNING: Please use preloadLuaScripts or reloadLuaScripts, instead of this one.
func (store *Store) loadLuaScripts(ctx context.Context) error {
store.luaMutex.Lock()
defer store.luaMutex.Unlock()

// Check if scripts are already loaded.
if atomic.LoadUint32(&store.luaLoaded) != 0 {
return nil
}

luaIncrSHA, err := store.client.ScriptLoad(ctx, luaIncrScript).Result()
if err != nil {
return errors.Wrap(err, `failed to load "incr" lua script`)
}

luaPeekSHA, err := store.client.ScriptLoad(ctx, luaPeekScript).Result()
if err != nil {
return errors.Wrap(err, `failed to load "peek" lua script`)
}

store.luaIncrSHA = luaIncrSHA
store.luaPeekSHA = luaPeekSHA

atomic.StoreUint32(&store.luaLoaded, 1)

return nil
}

// getLuaIncrSHA returns a "thread-safe" value for luaIncrSHA.
func (store *Store) getLuaIncrSHA() string {
store.luaMutex.RLock()
defer store.luaMutex.RUnlock()
return store.luaIncrSHA
}

// getLuaPeekSHA returns a "thread-safe" value for luaPeekSHA.
func (store *Store) getLuaPeekSHA() string {
store.luaMutex.RLock()
defer store.luaMutex.RUnlock()
return store.luaPeekSHA
}

// evalSHA eval the redis lua sha and load the scripts if missing.
func (store *Store) evalSHA(ctx context.Context, getSha func() string,
keys []string, args ...interface{}) *libredis.Cmd {

cmd := store.client.EvalSha(ctx, getSha(), keys, args...)
err := cmd.Err()
if err == nil || !isLuaScriptGone(err) {
return cmd
}

err = store.reloadLuaScripts(ctx)
if err != nil {
cmd = libredis.NewCmd(ctx)
cmd.SetErr(err)
return cmd
}

return store.client.EvalSha(ctx, getSha(), keys, args...)
}

// isLuaScriptGone returns if the error is a missing lua script from redis server.
func isLuaScriptGone(err error) bool {
return strings.HasPrefix(err.Error(), "NOSCRIPT")
}

// parseCountAndTTL parse count and ttl from lua script output
// parseCountAndTTL parse count and ttl from lua script output.
func parseCountAndTTL(cmd *libredis.Cmd) (int64, int64, error) {
ret, err := cmd.Result()
result, err := cmd.Result()
if err != nil {
return 0, 0, err
return 0, 0, errors.Wrap(err, "an error has occurred with redis command")
}
if fields, ok := ret.([]interface{}); !ok || len(fields) != 2 {
return 0, 0, errors.New("two elements in array was expected")

fields, ok := result.([]interface{})
if !ok || len(fields) != 2 {
return 0, 0, errors.New("two elements in result were expected")
}
fields := ret.([]interface{})

count, ok1 := fields[0].(int64)
ttl, ok2 := fields[1].(int64)
if !ok1 || !ok2 {
return 0, 0, errors.New("type of the count and ttl should be number")
return 0, 0, errors.New("type of the count and/or ttl should be number")
}

return count, ttl, nil
}
59 changes: 33 additions & 26 deletions drivers/store/redis/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ func TestRedisStoreSequentialAccess(t *testing.T) {
is.NotNil(client)

store, err := redis.NewStoreWithOptions(client, limiter.StoreOptions{
Prefix: "limiter:redis:sequential",
MaxRetry: 3,
Prefix: "limiter:redis:sequential-test",
})
is.NoError(err)
is.NotNil(store)
Expand All @@ -39,8 +38,7 @@ func TestRedisStoreConcurrentAccess(t *testing.T) {
is.NotNil(client)

store, err := redis.NewStoreWithOptions(client, limiter.StoreOptions{
Prefix: "limiter:redis:concurrent",
MaxRetry: 7,
Prefix: "limiter:redis:concurrent-test",
})
is.NoError(err)
is.NotNil(store)
Expand Down Expand Up @@ -94,40 +92,49 @@ func TestRedisClientExpiration(t *testing.T) {
is.Greater(actual, expected)
}

func newRedisClient() (*libredis.Client, error) {
uri := "redis://localhost:6379/0"
if os.Getenv("REDIS_URI") != "" {
uri = os.Getenv("REDIS_URI")
}
func BenchmarkRedisStoreSequentialAccess(b *testing.B) {
is := require.New(b)

opt, err := libredis.ParseURL(uri)
if err != nil {
return nil, err
}
client, err := newRedisClient()
is.NoError(err)
is.NotNil(client)

client := libredis.NewClient(opt)
return client, nil
store, err := redis.NewStoreWithOptions(client, limiter.StoreOptions{
Prefix: "limiter:redis:sequential-benchmark",
})
is.NoError(err)
is.NotNil(store)

tests.BenchmarkStoreSequentialAccess(b, store)
}

func BenchmarkGet(b *testing.B) {
func BenchmarkRedisStoreConcurrentAccess(b *testing.B) {
is := require.New(b)

client, err := newRedisClient()
is.NoError(err)
is.NotNil(client)

store, err := redis.NewStoreWithOptions(client, limiter.StoreOptions{
Prefix: "limiter:redis:benchmark",
MaxRetry: 3,
Prefix: "limiter:redis:concurrent-benchmark",
})
is.NoError(err)
is.NotNil(store)
limiter := limiter.New(store, limiter.Rate{
Limit: 100000,
Period: 10 * time.Second,
})

for i := 0; i < b.N; i++ {
lctx, err := limiter.Get(context.TODO(), "foo")
is.NoError(err)
is.NotZero(lctx)
tests.BenchmarkStoreConcurrentAccess(b, store)
}

func newRedisClient() (*libredis.Client, error) {
uri := "redis://localhost:6379/0"
if os.Getenv("REDIS_URI") != "" {
uri = os.Getenv("REDIS_URI")
}

opt, err := libredis.ParseURL(uri)
if err != nil {
return nil, err
}

client := libredis.NewClient(opt)
return client, nil
}
Loading

0 comments on commit b298ea8

Please sign in to comment.