Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: add redis token refresh support #3238

Closed
wants to merge 24 commits into from
Closed
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions common/component/redis/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,9 @@ func ParseClientFromProperties(properties map[string]string, componentType metad

var c RedisClient
if settings.Failover {
c = newV8FailoverClient(settings)
c = newV8FailoverClient(settings, properties)
} else {
c = newV8Client(settings)
c = newV8Client(settings, properties)
}
version, versionErr := GetServerVersion(c)
c.Close() // close the client to avoid leaking connections
Expand All @@ -177,14 +177,14 @@ func ParseClientFromProperties(properties map[string]string, componentType metad
}
if useNewClient {
if settings.Failover {
return newV9FailoverClient(settings), settings, nil
return newV9FailoverClient(settings, properties), settings, nil
}
return newV9Client(settings), settings, nil
return newV9Client(settings, properties), settings, nil
} else {
if settings.Failover {
return newV8FailoverClient(settings), settings, nil
return newV8FailoverClient(settings, properties), settings, nil
}
return newV8Client(settings), settings, nil
return newV8Client(settings, properties), settings, nil
}
}

Expand Down
25 changes: 18 additions & 7 deletions common/component/redis/v8client.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"strings"
"time"

"github.com/dapr/kit/logger"

Check failure on line 22 in common/component/redis/v8client.go

View workflow job for this annotation

GitHub Actions / Build linux_amd64 binaries

File is not `goimports`-ed with -local github.com/dapr/ (goimports)
v8 "github.com/go-redis/redis/v8"
)

Expand All @@ -27,6 +28,8 @@
writeTimeout Duration
}

var v8logger = logger.NewLogger("dapr.components.redisv8")

func (p v8Pipeliner) Exec(ctx context.Context) error {
_, err := p.pipeliner.Exec(ctx)
return err
Expand All @@ -48,6 +51,7 @@
readTimeout Duration
writeTimeout Duration
dialTimeout Duration
closeCh chan struct{}
}

func (c v8Client) GetDel(ctx context.Context, key string) (string, error) {
Expand Down Expand Up @@ -316,10 +320,11 @@
return c.client.TTL(writeCtx, key).Result()
}

func newV8FailoverClient(s *Settings) RedisClient {
func newV8FailoverClient(s *Settings, properties map[string]string) RedisClient {
if s == nil {
return nil
}
closeCh := make(chan struct{})
opts := &v8.FailoverOptions{
DB: s.DB,
MasterName: s.SentinelMasterName,
Expand Down Expand Up @@ -349,27 +354,28 @@

if s.RedisType == ClusterType {
opts.SentinelAddrs = strings.Split(s.Host, ",")

return v8Client{
client: v8.NewFailoverClusterClient(opts),
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
closeCh: closeCh,
}
}

return v8Client{
client: v8.NewFailoverClient(opts),
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
closeCh: closeCh,
}
}

func newV8Client(s *Settings) RedisClient {
func newV8Client(s *Settings, properties map[string]string) RedisClient {
if s == nil {
return nil
}
closeCh := make(chan struct{})
if s.RedisType == ClusterType {
options := &v8.ClusterOptions{
Addrs: strings.Split(s.Host, ","),
Expand All @@ -394,12 +400,15 @@
InsecureSkipVerify: s.EnableTLS,
}
}
client := v8.NewClusterClient(options)
go refreshTokenRoutineForRedis(context.Background(), ClientFromV8Client(client), "v8", properties, v8logger, closeCh)
sadath-12 marked this conversation as resolved.
Show resolved Hide resolved

return v8Client{
client: v8.NewClusterClient(options),
client: client,
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
closeCh: closeCh,
}
}

Expand Down Expand Up @@ -428,12 +437,14 @@
InsecureSkipVerify: s.EnableTLS,
}
}

client := v8.NewClient(options)
go refreshTokenRoutineForRedis(context.Background(), ClientFromV8Client(client), "v8", properties, v8logger, closeCh)
return v8Client{
client: v8.NewClient(options),
client: client,
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
closeCh: closeCh,
}
}

Expand Down
76 changes: 68 additions & 8 deletions common/component/redis/v9client.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
"strings"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"

Check failure on line 22 in common/component/redis/v9client.go

View workflow job for this annotation

GitHub Actions / Build linux_amd64 binaries

File is not `goimports`-ed with -local github.com/dapr/ (goimports)
"github.com/dapr/components-contrib/common/authentication/azure"
"github.com/dapr/kit/logger"
v9 "github.com/redis/go-redis/v9"
)

Expand All @@ -27,6 +30,12 @@
writeTimeout Duration
}

var v9logger = logger.NewLogger("dapr.components.redisv9")

const (
tokenRefreshInterval = 50 * time.Minute
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apparently we do not have to do this every hour -- instead we should examine the most recent token expiration / duration, and then refresh the token say 10 minutes before it expires.

The token validity is variable and can be longer than one hour.

There should be a function you can use to examine the validity of the token.

)

func (p v9Pipeliner) Exec(ctx context.Context) error {
_, err := p.pipeliner.Exec(ctx)
return err
Expand All @@ -48,6 +57,7 @@
readTimeout Duration
writeTimeout Duration
dialTimeout Duration
closeCh chan struct{}
}

func (c v9Client) GetDel(ctx context.Context, key string) (string, error) {
Expand Down Expand Up @@ -117,6 +127,7 @@
}

func (c v9Client) Close() error {
close(c.closeCh)
return c.client.Close()
}

Expand Down Expand Up @@ -317,10 +328,11 @@
return c.client.TTL(writeCtx, key).Result()
}

func newV9FailoverClient(s *Settings) RedisClient {
func newV9FailoverClient(s *Settings, properties map[string]string) RedisClient {
if s == nil {
return nil
}
closeCh := make(chan struct{})
opts := &v9.FailoverOptions{
DB: s.DB,
MasterName: s.SentinelMasterName,
Expand Down Expand Up @@ -350,15 +362,14 @@

if s.RedisType == ClusterType {
opts.SentinelAddrs = strings.Split(s.Host, ",")

return v9Client{
client: v9.NewFailoverClusterClient(opts),
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
closeCh: closeCh,
}
}

return v9Client{
client: v9.NewFailoverClient(opts),
readTimeout: s.ReadTimeout,
Expand All @@ -367,10 +378,11 @@
}
}

func newV9Client(s *Settings) RedisClient {
func newV9Client(s *Settings, properties map[string]string) RedisClient {
if s == nil {
return nil
}
closeCh := make(chan struct{})
if s.RedisType == ClusterType {
options := &v9.ClusterOptions{
Addrs: strings.Split(s.Host, ","),
Expand All @@ -395,12 +407,14 @@
InsecureSkipVerify: s.EnableTLS,
}
}

client := v9.NewClusterClient(options)
go refreshTokenRoutineForRedis(context.Background(), ClientFromV9Client(client), "v9", properties, v9logger, closeCh)
return v9Client{
client: v9.NewClusterClient(options),
client: client,
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
closeCh: closeCh,
}
}

Expand Down Expand Up @@ -429,11 +443,57 @@
InsecureSkipVerify: s.EnableTLS,
}
}

client := v9.NewClient(options)
go refreshTokenRoutineForRedis(context.Background(), ClientFromV9Client(client), "v9", properties, v9logger, closeCh)
return v9Client{
client: v9.NewClient(options),
client: client,
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
closeCh: closeCh,
}
}

func ClientFromV9Client(client v9.UniversalClient) RedisClient {
return v9Client{client: client}
}

func refreshTokenRoutineForRedis(ctx context.Context, redisClient RedisClient, version string, meta map[string]string, logger logger.Logger, closeCh chan struct{}) {
sadath-12 marked this conversation as resolved.
Show resolved Hide resolved
ticker := time.NewTicker(tokenRefreshInterval)
defer ticker.Stop()

for {
select {
case <-closeCh:
return
case <-ticker.C:
env, err := azure.NewEnvironmentSettings(meta)

Check failure on line 470 in common/component/redis/v9client.go

View workflow job for this annotation

GitHub Actions / Build linux_amd64 binaries

ineffectual assignment to err (ineffassign)
sadath-12 marked this conversation as resolved.
Show resolved Hide resolved
tokenCred, err := env.GetTokenCredential()
if err != nil {
logger.Error("Failed to get Azure AD token credential:", err)
continue
}
at, err := tokenCred.GetToken(ctx, policy.TokenRequestOptions{
sadath-12 marked this conversation as resolved.
Show resolved Hide resolved
Scopes: []string{
env.Cloud.Services[azure.ServiceOSSRDBMS].Audience + "/.default",
},
})
if err != nil {
logger.Debug("Failed to get Azure AD token:", err)
sadath-12 marked this conversation as resolved.
Show resolved Hide resolved
continue
}

// Authenticate with Redis using the refreshed token
if version == "v8" {
err = redisClient.(v8Client).client.Pipeline().Auth(ctx, at.Token).Err()
} else if version == "v9" {
err = redisClient.(v9Client).client.Pipeline().Auth(ctx, at.Token).Err()
}
if err != nil {
logger.Error("Failed to authenticate with Redis using refreshed Azure AD token:", err)
continue
}
logger.Info("Successfully refreshed Azure AD token and re-authenticated Redis.")
}
}
}
Loading