-
Notifications
You must be signed in to change notification settings - Fork 490
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feature: add redis token refresh support #3238
Changes from 22 commits
20e6f07
5fb1c79
a2f7680
0ff785c
9cd23cb
a5c961e
c88df66
b9fb8e2
287851d
607cacf
2ebeaf0
c043100
93ff135
2c800c1
0eb9e36
82c1b54
faeab92
09674e2
e9f7064
24645ce
b81216b
5bb9248
fbf5a15
7f43880
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,10 +20,13 @@ import ( | |
"strings" | ||
"time" | ||
|
||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" | ||
"golang.org/x/mod/semver" | ||
|
||
"github.com/dapr/components-contrib/common/authentication/azure" | ||
"github.com/dapr/components-contrib/configuration" | ||
"github.com/dapr/components-contrib/metadata" | ||
"github.com/dapr/kit/logger" | ||
) | ||
|
||
const ( | ||
|
@@ -160,9 +163,9 @@ func ParseClientFromProperties(properties map[string]string, componentType metad | |
|
||
var c RedisClient | ||
if settings.Failover { | ||
c = newV8FailoverClient(settings) | ||
c = newV8FailoverClient(settings, properties) | ||
} else { | ||
c = newV8Client(settings) | ||
c = newV8Client(settings, properties) | ||
} | ||
version, versionErr := GetServerVersion(c) | ||
c.Close() // close the client to avoid leaking connections | ||
|
@@ -177,14 +180,14 @@ func ParseClientFromProperties(properties map[string]string, componentType metad | |
} | ||
if useNewClient { | ||
if settings.Failover { | ||
return newV9FailoverClient(settings), settings, nil | ||
return newV9FailoverClient(settings, properties), settings, nil | ||
} | ||
return newV9Client(settings), settings, nil | ||
return newV9Client(settings, properties), settings, nil | ||
} else { | ||
if settings.Failover { | ||
return newV8FailoverClient(settings), settings, nil | ||
return newV8FailoverClient(settings, properties), settings, nil | ||
} | ||
return newV8Client(settings), settings, nil | ||
return newV8Client(settings, properties), settings, nil | ||
} | ||
} | ||
|
||
|
@@ -252,3 +255,65 @@ type RedisError string | |
func (e RedisError) Error() string { return string(e) } | ||
|
||
func (RedisError) RedisError() {} | ||
|
||
func (s *Settings) refreshTokenRoutineForRedis(ctx context.Context, redisClient RedisClient, version string, meta map[string]string, logger logger.Logger) { | ||
ticker := time.NewTicker(tokenRefreshInterval) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The ticker value must be reset every time you get a new token. The value must be token expiration time minus 10 minutes (so that we refresh 10 minutes before expiration). |
||
defer ticker.Stop() | ||
|
||
if !s.useAzureAD { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would like to see this exit condition instead be used in the v8client and v9client only. if s.useAzureAd {
authenticateWithAzureRedis()
} or something like that |
||
return | ||
} | ||
|
||
maxRetries := 3 | ||
|
||
for { | ||
select { | ||
case <-ctx.Done(): | ||
return | ||
case <-ticker.C: | ||
retry := 0 | ||
for retry < maxRetries { | ||
env, err := azure.NewEnvironmentSettings(meta) | ||
if err != nil { | ||
logger.Error("Failed to get Azure environment settings:", err) | ||
retry++ | ||
continue | ||
} | ||
tokenCred, err := env.GetTokenCredential() | ||
if err != nil { | ||
logger.Error("Failed to get Azure AD token credential:", err) | ||
retry++ | ||
time.Sleep(5 * time.Second) | ||
continue | ||
} | ||
at, err := tokenCred.GetToken(ctx, policy.TokenRequestOptions{ | ||
Scopes: []string{ | ||
env.Cloud.Services[azure.ServiceOSSRDBMS].Audience + "/.default", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @ItalyPaleAle does this |
||
}, | ||
}) | ||
if err != nil { | ||
logger.Error("Failed to get Azure AD token:", err) | ||
retry++ | ||
time.Sleep(5 * time.Second) | ||
continue | ||
} | ||
|
||
// Authenticate with Redis using the refreshed token | ||
var authErr error | ||
if version == "v8" { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, this is bad. Instead the That way you can then simply call References: |
||
authErr = redisClient.(v8Client).client.Pipeline().Auth(ctx, at.Token).Err() | ||
} else if version == "v9" { | ||
authErr = redisClient.(v9Client).client.Pipeline().Auth(ctx, at.Token).Err() | ||
} | ||
if authErr != nil { | ||
logger.Error("Failed to authenticate with Redis using refreshed Azure AD token:", authErr) | ||
berndverst marked this conversation as resolved.
Show resolved
Hide resolved
|
||
retry++ | ||
time.Sleep(5 * time.Second) | ||
continue | ||
} | ||
logger.Info("Successfully refreshed Azure AD token and re-authenticated Redis.") | ||
break | ||
} | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,13 +20,21 @@ import ( | |
"time" | ||
|
||
v9 "github.com/redis/go-redis/v9" | ||
|
||
"github.com/dapr/kit/logger" | ||
) | ||
|
||
type v9Pipeliner struct { | ||
pipeliner v9.Pipeliner | ||
writeTimeout Duration | ||
} | ||
|
||
var v9logger = logger.NewLogger("dapr.components.redisv9") | ||
|
||
const ( | ||
tokenRefreshInterval = 50 * time.Minute | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Apparently we do not have to do this every hour -- instead we should examine the most recent token expiration / duration, and then refresh the token say 10 minutes before it expires. The token validity is variable and can be longer than one hour. There should be a function you can use to examine the validity of the token. |
||
) | ||
|
||
func (p v9Pipeliner) Exec(ctx context.Context) error { | ||
_, err := p.pipeliner.Exec(ctx) | ||
return err | ||
|
@@ -317,7 +325,7 @@ func (c v9Client) TTLResult(ctx context.Context, key string) (time.Duration, err | |
return c.client.TTL(writeCtx, key).Result() | ||
} | ||
|
||
func newV9FailoverClient(s *Settings) RedisClient { | ||
func newV9FailoverClient(s *Settings, properties map[string]string) RedisClient { | ||
if s == nil { | ||
return nil | ||
} | ||
|
@@ -350,24 +358,26 @@ func newV9FailoverClient(s *Settings) RedisClient { | |
|
||
if s.RedisType == ClusterType { | ||
opts.SentinelAddrs = strings.Split(s.Host, ",") | ||
|
||
client := v9.NewFailoverClusterClient(opts) | ||
go s.refreshTokenRoutineForRedis(context.Background(), ClientFromV9Client(client), "v9", properties, v9logger) | ||
return v9Client{ | ||
client: v9.NewFailoverClusterClient(opts), | ||
client: client, | ||
readTimeout: s.ReadTimeout, | ||
writeTimeout: s.WriteTimeout, | ||
dialTimeout: s.DialTimeout, | ||
} | ||
} | ||
|
||
client := v9.NewFailoverClient(opts) | ||
go s.refreshTokenRoutineForRedis(context.Background(), ClientFromV9Client(client), "v9", properties, v9logger) | ||
return v9Client{ | ||
client: v9.NewFailoverClient(opts), | ||
client: client, | ||
readTimeout: s.ReadTimeout, | ||
writeTimeout: s.WriteTimeout, | ||
dialTimeout: s.DialTimeout, | ||
} | ||
} | ||
|
||
func newV9Client(s *Settings) RedisClient { | ||
func newV9Client(s *Settings, properties map[string]string) RedisClient { | ||
if s == nil { | ||
return nil | ||
} | ||
|
@@ -395,9 +405,10 @@ func newV9Client(s *Settings) RedisClient { | |
InsecureSkipVerify: s.EnableTLS, | ||
} | ||
} | ||
|
||
client := v9.NewClusterClient(options) | ||
go s.refreshTokenRoutineForRedis(context.Background(), ClientFromV9Client(client), "v9", properties, v9logger) | ||
return v9Client{ | ||
client: v9.NewClusterClient(options), | ||
client: client, | ||
readTimeout: s.ReadTimeout, | ||
writeTimeout: s.WriteTimeout, | ||
dialTimeout: s.DialTimeout, | ||
|
@@ -429,11 +440,16 @@ func newV9Client(s *Settings) RedisClient { | |
InsecureSkipVerify: s.EnableTLS, | ||
} | ||
} | ||
|
||
client := v9.NewClient(options) | ||
go s.refreshTokenRoutineForRedis(context.Background(), ClientFromV9Client(client), "v9", properties, v9logger) | ||
return v9Client{ | ||
client: v9.NewClient(options), | ||
client: client, | ||
readTimeout: s.ReadTimeout, | ||
writeTimeout: s.WriteTimeout, | ||
dialTimeout: s.DialTimeout, | ||
} | ||
} | ||
|
||
func ClientFromV9Client(client v9.UniversalClient) RedisClient { | ||
return v9Client{client: client} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd rename this to