diff --git a/drivers/store/redis/store.go b/drivers/store/redis/store.go index cbf981c..af97931 100644 --- a/drivers/store/redis/store.go +++ b/drivers/store/redis/store.go @@ -259,11 +259,10 @@ func updateValue(rtx *libredis.Tx, key string, expiration time.Duration) (int64, return 0, 0, err } - // If ttl is -1ms, we have to define key expiration. - // PTTL return values changed as of Redis 2.8 - // Now the command returns -2ms if the key does not exist, and -1ms if the key exists, but there is no expiry set - // We shouldn't try to set an expiry on a key that doesn't exist - if ttl == (-1 * time.Millisecond) { + // If ttl is less than zero, we have to define key expiration. + // The PTTL command returns -2 if the key does not exist, and -1 if the key exists, but there is no expiry set. + // We shouldn't try to set an expiry on a key that doesn't exist. + if isExpirationRequired(ttl) { expire := rtx.Expire(key, expiration) ok, err := expire.Result() @@ -313,3 +312,16 @@ func (store *Store) ping() (bool, error) { return (pong == "PONG"), nil } + +// isExpirationRequired returns if we should set an expiration on a key, using (error) result from PTTL command. +// The error code is -2 if the key does not exist, and -1 if the key exists. +// Usually, it should be returned in nanosecond, but some users have reported that it could be in millisecond as well. +// Better safe than sorry: we handle both. +func isExpirationRequired(ttl time.Duration) bool { + switch ttl { + case -1 * time.Nanosecond, -1 * time.Millisecond: + return true + default: + return false + } +} diff --git a/drivers/store/redis/store_test.go b/drivers/store/redis/store_test.go index be58c31..7b0c103 100644 --- a/drivers/store/redis/store_test.go +++ b/drivers/store/redis/store_test.go @@ -3,6 +3,7 @@ package redis_test import ( "os" "testing" + "time" libredis "github.com/go-redis/redis/v7" "github.com/stretchr/testify/require" @@ -46,6 +47,51 @@ func TestRedisStoreConcurrentAccess(t *testing.T) { tests.TestStoreConcurrentAccess(t, store) } +func TestRedisClientExpiration(t *testing.T) { + is := require.New(t) + + client, err := newRedisClient() + is.NoError(err) + is.NotNil(client) + + key := "foobar" + value := 642 + keyNoExpiration := -1 * time.Nanosecond + keyNotExist := -2 * time.Nanosecond + + delCmd := client.Del(key) + _, err = delCmd.Result() + is.NoError(err) + + expCmd := client.PTTL(key) + ttl, err := expCmd.Result() + is.NoError(err) + is.Equal(keyNotExist, ttl) + + setCmd := client.Set(key, value, 0) + _, err = setCmd.Result() + is.NoError(err) + + expCmd = client.PTTL(key) + ttl, err = expCmd.Result() + is.NoError(err) + is.Equal(keyNoExpiration, ttl) + + setCmd = client.Set(key, value, time.Second) + _, err = setCmd.Result() + is.NoError(err) + + time.Sleep(100 * time.Millisecond) + + expCmd = client.PTTL(key) + ttl, err = expCmd.Result() + is.NoError(err) + + expected := int64(0) + actual := int64(ttl) + is.Greater(actual, expected) +} + func newRedisClient() (*libredis.Client, error) { uri := "redis://localhost:6379/0" if os.Getenv("REDIS_URI") != "" {