Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: add redis token refresh support #3238

Closed
wants to merge 24 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion bindings/redis/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ func (r *Redis) Init(ctx context.Context, meta bindings.Metadata) (err error) {
if err != nil {
return err
}

_, err = r.client.PingResult(ctx)
if err != nil {
return fmt.Errorf("redis binding: error connecting to redis at %s: %s", r.clientSettings.Host, err)
Expand Down
12 changes: 6 additions & 6 deletions common/component/redis/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,9 @@ func ParseClientFromProperties(properties map[string]string, componentType metad

var c RedisClient
if settings.Failover {
c = newV8FailoverClient(settings)
c = newV8FailoverClient(settings, properties)
} else {
c = newV8Client(settings)
c = newV8Client(settings, properties)
}
version, versionErr := GetServerVersion(c)
c.Close() // close the client to avoid leaking connections
Expand All @@ -177,14 +177,14 @@ func ParseClientFromProperties(properties map[string]string, componentType metad
}
if useNewClient {
if settings.Failover {
return newV9FailoverClient(settings), settings, nil
return newV9FailoverClient(settings, properties), settings, nil
}
return newV9Client(settings), settings, nil
return newV9Client(settings, properties), settings, nil
} else {
if settings.Failover {
return newV8FailoverClient(settings), settings, nil
return newV8FailoverClient(settings, properties), settings, nil
}
return newV8Client(settings), settings, nil
return newV8Client(settings, properties), settings, nil
}
}

Expand Down
23 changes: 14 additions & 9 deletions common/component/redis/v8client.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ func (c v8Client) TTLResult(ctx context.Context, key string) (time.Duration, err
return c.client.TTL(writeCtx, key).Result()
}

func newV8FailoverClient(s *Settings) RedisClient {
func newV8FailoverClient(s *Settings, properties map[string]string) RedisClient {
if s == nil {
return nil
}
Expand Down Expand Up @@ -349,24 +349,26 @@ func newV8FailoverClient(s *Settings) RedisClient {

if s.RedisType == ClusterType {
opts.SentinelAddrs = strings.Split(s.Host, ",")

client := v8.NewFailoverClusterClient(opts)
go refreshTokenRoutineForRedis(context.Background(), ClientFromV8Client(client), properties)
return v8Client{
client: v8.NewFailoverClusterClient(opts),
client: client,
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
}
}

client := v8.NewFailoverClient(opts)
go refreshTokenRoutineForRedis(context.Background(), ClientFromV8Client(client), properties)
return v8Client{
client: v8.NewFailoverClient(opts),
client: client,
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
}
}

func newV8Client(s *Settings) RedisClient {
func newV8Client(s *Settings, properties map[string]string) RedisClient {
if s == nil {
return nil
}
Expand Down Expand Up @@ -394,9 +396,11 @@ func newV8Client(s *Settings) RedisClient {
InsecureSkipVerify: s.EnableTLS,
}
}
client := v8.NewClusterClient(options)
go refreshTokenRoutineForRedis(context.Background(), ClientFromV8Client(client), properties)

return v8Client{
client: v8.NewClusterClient(options),
client: client,
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
Expand Down Expand Up @@ -428,9 +432,10 @@ func newV8Client(s *Settings) RedisClient {
InsecureSkipVerify: s.EnableTLS,
}
}

client := v8.NewClient(options)
go refreshTokenRoutineForRedis(context.Background(), ClientFromV8Client(client), properties)
return v8Client{
client: v8.NewClient(options),
client: client,
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
Expand Down
61 changes: 51 additions & 10 deletions common/component/redis/v9client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@ package redis
import (
"context"
"crypto/tls"
"fmt"
"strings"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/dapr/components-contrib/common/authentication/azure"
v9 "github.com/redis/go-redis/v9"
)

Expand Down Expand Up @@ -317,7 +320,7 @@ func (c v9Client) TTLResult(ctx context.Context, key string) (time.Duration, err
return c.client.TTL(writeCtx, key).Result()
}

func newV9FailoverClient(s *Settings) RedisClient {
func newV9FailoverClient(s *Settings, properties map[string]string) RedisClient {
if s == nil {
return nil
}
Expand Down Expand Up @@ -350,24 +353,26 @@ func newV9FailoverClient(s *Settings) RedisClient {

if s.RedisType == ClusterType {
opts.SentinelAddrs = strings.Split(s.Host, ",")

client := v9.NewFailoverClusterClient(opts)
go refreshTokenRoutineForRedis(context.Background(), ClientFromV9Client(client), properties)
return v9Client{
client: v9.NewFailoverClusterClient(opts),
client: client,
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
}
}

client := v9.NewFailoverClient(opts)
go refreshTokenRoutineForRedis(context.Background(), ClientFromV9Client(client), properties)
return v9Client{
client: v9.NewFailoverClient(opts),
client: client,
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
}
}

func newV9Client(s *Settings) RedisClient {
func newV9Client(s *Settings, properties map[string]string) RedisClient {
if s == nil {
return nil
}
Expand Down Expand Up @@ -395,9 +400,10 @@ func newV9Client(s *Settings) RedisClient {
InsecureSkipVerify: s.EnableTLS,
}
}

client := v9.NewClusterClient(options)
go refreshTokenRoutineForRedis(context.Background(), ClientFromV9Client(client), properties)
sadath-12 marked this conversation as resolved.
Show resolved Hide resolved
return v9Client{
client: v9.NewClusterClient(options),
client: client,
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
Expand Down Expand Up @@ -429,11 +435,46 @@ func newV9Client(s *Settings) RedisClient {
InsecureSkipVerify: s.EnableTLS,
}
}

client := v9.NewClient(options)
go refreshTokenRoutineForRedis(context.Background(), ClientFromV9Client(client), properties)
sadath-12 marked this conversation as resolved.
Show resolved Hide resolved
return v9Client{
client: v9.NewClient(options),
client: client,
readTimeout: s.ReadTimeout,
writeTimeout: s.WriteTimeout,
dialTimeout: s.DialTimeout,
}
}

func ClientFromV9Client(client v9.UniversalClient) RedisClient {
return v9Client{client: client}
}

func refreshTokenRoutineForRedis(ctx context.Context, redisClient RedisClient, meta map[string]string) {
ticker := time.NewTicker(time.Hour)
sadath-12 marked this conversation as resolved.
Show resolved Hide resolved
defer ticker.Stop()

for {
select {
case <-ticker.C:
sadath-12 marked this conversation as resolved.
Show resolved Hide resolved
env, err := azure.NewEnvironmentSettings(meta)
sadath-12 marked this conversation as resolved.
Show resolved Hide resolved
tokenCred, err := env.GetTokenCredential()
if err != nil {
fmt.Println("Failed to get Azure AD token credential:", err)
sadath-12 marked this conversation as resolved.
Show resolved Hide resolved
continue
}
at, err := tokenCred.GetToken(ctx, policy.TokenRequestOptions{
sadath-12 marked this conversation as resolved.
Show resolved Hide resolved
Scopes: []string{
env.Cloud.Services[azure.ServiceOSSRDBMS].Audience + "/.default",
},
})

// Authenticate with Redis using the refreshed token
err = redisClient.(v9Client).client.Pipeline().Auth(ctx, at.Token).Err()
sadath-12 marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
fmt.Println("Failed to authenticate with Redis using refreshed Azure AD token:", err)
sadath-12 marked this conversation as resolved.
Show resolved Hide resolved
continue
}
fmt.Println("Successfully refreshed Azure AD token and re-authenticated Redis.")
sadath-12 marked this conversation as resolved.
Show resolved Hide resolved
}
}
}