Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: add redis token refresh support #3238

Closed
wants to merge 24 commits into from
Closed
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 71 additions & 6 deletions common/component/redis/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@ import (
"strings"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"golang.org/x/mod/semver"

"github.com/dapr/components-contrib/common/authentication/azure"
"github.com/dapr/components-contrib/configuration"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/kit/logger"
)

const (
Expand Down Expand Up @@ -160,9 +163,9 @@ func ParseClientFromProperties(properties map[string]string, componentType metad

var c RedisClient
if settings.Failover {
c = newV8FailoverClient(settings)
c = newV8FailoverClient(settings, properties)
} else {
c = newV8Client(settings)
c = newV8Client(settings, properties)
}
version, versionErr := GetServerVersion(c)
c.Close() // close the client to avoid leaking connections
Expand All @@ -177,14 +180,14 @@ func ParseClientFromProperties(properties map[string]string, componentType metad
}
if useNewClient {
if settings.Failover {
return newV9FailoverClient(settings), settings, nil
return newV9FailoverClient(settings, properties), settings, nil
}
return newV9Client(settings), settings, nil
return newV9Client(settings, properties), settings, nil
} else {
if settings.Failover {
return newV8FailoverClient(settings), settings, nil
return newV8FailoverClient(settings, properties), settings, nil
}
return newV8Client(settings), settings, nil
return newV8Client(settings, properties), settings, nil
}
}

Expand Down Expand Up @@ -252,3 +255,65 @@ type RedisError string
func (e RedisError) Error() string { return string(e) }

func (RedisError) RedisError() {}

func (s *Settings) refreshTokenRoutineForRedis(ctx context.Context, redisClient RedisClient, version string, meta map[string]string, logger logger.Logger) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rename this to

Suggested change
func (s *Settings) refreshTokenRoutineForRedis(ctx context.Context, redisClient RedisClient, version string, meta map[string]string, logger logger.Logger) {
func (s *Settings) authenticateWithAzureRedis(ctx context.Context, redisClient RedisClient, version string, meta map[string]string, logger logger.Logger) {

ticker := time.NewTicker(tokenRefreshInterval)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ticker value must be reset every time you get a new token. The value must be token expiration time minus 10 minutes (so that we refresh 10 minutes before expiration).

defer ticker.Stop()

if !s.useAzureAD {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to see this exit condition instead be used in the v8client and v9client only.

if s.useAzureAd {
  authenticateWithAzureRedis()
}

or something like that

return
}

maxRetries := 3

for {
select {
case <-ctx.Done():
return
case <-ticker.C:
retry := 0
for retry < maxRetries {
env, err := azure.NewEnvironmentSettings(meta)
if err != nil {
logger.Error("Failed to get Azure environment settings:", err)
retry++
continue
}
tokenCred, err := env.GetTokenCredential()
if err != nil {
logger.Error("Failed to get Azure AD token credential:", err)
retry++
time.Sleep(5 * time.Second)
continue
}
at, err := tokenCred.GetToken(ctx, policy.TokenRequestOptions{
Scopes: []string{
env.Cloud.Services[azure.ServiceOSSRDBMS].Audience + "/.default",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ItalyPaleAle does this dapr.io/oss-rdbms environment setting always exist? I don't think so?

},
})
if err != nil {
logger.Error("Failed to get Azure AD token:", err)
retry++
time.Sleep(5 * time.Second)
continue
}

// Authenticate with Redis using the refreshed token
var authErr error
if version == "v8" {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this is bad.

Instead the Auth method should be added to the interface in redis.go, then an implementation should be added for the v8client.go and v9client.go

That way you can then simply call client.Auth without using the Pipeline client and without using version specific clients in this code. The version should not be required here.

References:
https://pkg.go.dev/github.com/go-redis/redis/v8#Conn.Auth
https://pkg.go.dev/github.com/go-redis/redis/v9#Conn.Auth

authErr = redisClient.(v8Client).client.Pipeline().Auth(ctx, at.Token).Err()
} else if version == "v9" {
authErr = redisClient.(v9Client).client.Pipeline().Auth(ctx, at.Token).Err()
}
if authErr != nil {
logger.Error("Failed to authenticate with Redis using refreshed Azure AD token:", authErr)
berndverst marked this conversation as resolved.
Show resolved Hide resolved
retry++
time.Sleep(5 * time.Second)
continue
}
logger.Info("Successfully refreshed Azure AD token and re-authenticated Redis.")
break
}
}
}
}
2 changes: 2 additions & 0 deletions common/component/redis/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ type Settings struct {

// The max len of stream
MaxLenApprox int64 `mapstructure:"maxLenApprox" mdonly:"pubsub"`
// azureAD for authentication
useAzureAD bool `mapstructure:"useAzureAD"`
}

func (s *Settings) Decode(in interface{}) error {
Expand Down
27 changes: 18 additions & 9 deletions common/component/redis/v8client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,17 @@ import (
"time"

v8 "github.com/go-redis/redis/v8"

"github.com/dapr/kit/logger"
)

type v8Pipeliner struct {
pipeliner v8.Pipeliner
writeTimeout Duration
}

var v8logger = logger.NewLogger("dapr.components.redisv8")

func (p v8Pipeliner) Exec(ctx context.Context) error {
_, err := p.pipeliner.Exec(ctx)
return err
Expand Down Expand Up @@ -316,7 +320,7 @@ func (c v8Client) TTLResult(ctx context.Context, key string) (time.Duration, err
return c.client.TTL(writeCtx, key).Result()
}

func newV8FailoverClient(s *Settings) RedisClient {
func newV8FailoverClient(s *Settings, properties map[string]string) RedisClient {
if s == nil {
return nil
}
Expand Down Expand Up @@ -349,24 +353,26 @@ func newV8FailoverClient(s *Settings) RedisClient {

if s.RedisType == ClusterType {
opts.SentinelAddrs = strings.Split(s.Host, ",")

client := v8.NewFailoverClusterClient(opts)
go s.refreshTokenRoutineForRedis(context.Background(), ClientFromV8Client(client), "v8", properties, v8logger)
return v8Client{
client: v8.NewFailoverClusterClient(opts),
client: client,
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
}
}

client := v8.NewFailoverClient(opts)
go s.refreshTokenRoutineForRedis(context.Background(), ClientFromV8Client(client), "v8", properties, v8logger)
return v8Client{
client: v8.NewFailoverClient(opts),
client: client,
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
}
}

func newV8Client(s *Settings) RedisClient {
func newV8Client(s *Settings, properties map[string]string) RedisClient {
if s == nil {
return nil
}
Expand Down Expand Up @@ -394,9 +400,11 @@ func newV8Client(s *Settings) RedisClient {
InsecureSkipVerify: s.EnableTLS,
}
}
client := v8.NewClusterClient(options)
go s.refreshTokenRoutineForRedis(context.Background(), ClientFromV8Client(client), "v8", properties, v8logger)

return v8Client{
client: v8.NewClusterClient(options),
client: client,
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
Expand Down Expand Up @@ -428,9 +436,10 @@ func newV8Client(s *Settings) RedisClient {
InsecureSkipVerify: s.EnableTLS,
}
}

client := v8.NewClient(options)
go s.refreshTokenRoutineForRedis(context.Background(), ClientFromV8Client(client), "v8", properties, v8logger)
return v8Client{
client: v8.NewClient(options),
client: client,
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
Expand Down
36 changes: 26 additions & 10 deletions common/component/redis/v9client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,21 @@ import (
"time"

v9 "github.com/redis/go-redis/v9"

"github.com/dapr/kit/logger"
)

type v9Pipeliner struct {
pipeliner v9.Pipeliner
writeTimeout Duration
}

var v9logger = logger.NewLogger("dapr.components.redisv9")

const (
tokenRefreshInterval = 50 * time.Minute
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apparently we do not have to do this every hour -- instead we should examine the most recent token expiration / duration, and then refresh the token say 10 minutes before it expires.

The token validity is variable and can be longer than one hour.

There should be a function you can use to examine the validity of the token.

)

func (p v9Pipeliner) Exec(ctx context.Context) error {
_, err := p.pipeliner.Exec(ctx)
return err
Expand Down Expand Up @@ -317,7 +325,7 @@ func (c v9Client) TTLResult(ctx context.Context, key string) (time.Duration, err
return c.client.TTL(writeCtx, key).Result()
}

func newV9FailoverClient(s *Settings) RedisClient {
func newV9FailoverClient(s *Settings, properties map[string]string) RedisClient {
if s == nil {
return nil
}
Expand Down Expand Up @@ -350,24 +358,26 @@ func newV9FailoverClient(s *Settings) RedisClient {

if s.RedisType == ClusterType {
opts.SentinelAddrs = strings.Split(s.Host, ",")

client := v9.NewFailoverClusterClient(opts)
go s.refreshTokenRoutineForRedis(context.Background(), ClientFromV9Client(client), "v9", properties, v9logger)
return v9Client{
client: v9.NewFailoverClusterClient(opts),
client: client,
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
}
}

client := v9.NewFailoverClient(opts)
go s.refreshTokenRoutineForRedis(context.Background(), ClientFromV9Client(client), "v9", properties, v9logger)
return v9Client{
client: v9.NewFailoverClient(opts),
client: client,
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
}
}

func newV9Client(s *Settings) RedisClient {
func newV9Client(s *Settings, properties map[string]string) RedisClient {
if s == nil {
return nil
}
Expand Down Expand Up @@ -395,9 +405,10 @@ func newV9Client(s *Settings) RedisClient {
InsecureSkipVerify: s.EnableTLS,
}
}

client := v9.NewClusterClient(options)
go s.refreshTokenRoutineForRedis(context.Background(), ClientFromV9Client(client), "v9", properties, v9logger)
return v9Client{
client: v9.NewClusterClient(options),
client: client,
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
Expand Down Expand Up @@ -429,11 +440,16 @@ func newV9Client(s *Settings) RedisClient {
InsecureSkipVerify: s.EnableTLS,
}
}

client := v9.NewClient(options)
go s.refreshTokenRoutineForRedis(context.Background(), ClientFromV9Client(client), "v9", properties, v9logger)
return v9Client{
client: v9.NewClient(options),
client: client,
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
}
}

func ClientFromV9Client(client v9.UniversalClient) RedisClient {
return v9Client{client: client}
}
Loading