diff --git a/internal/db/kvs/redis/option.go b/internal/db/kvs/redis/option.go index e917ef28cba..5b30a427b54 100644 --- a/internal/db/kvs/redis/option.go +++ b/internal/db/kvs/redis/option.go @@ -33,7 +33,6 @@ var ( defaultOpts = []Option{ WithInitialPingDuration("30ms"), WithInitialPingTimeLimit("5m"), - WithPing(true), } ) @@ -323,10 +322,3 @@ func WithInitialPingDuration(dur string) Option { return nil } } - -func WithPing(enabled bool) Option { - return func(r *redisClient) error { - r.pingEnabled = enabled - return nil - } -} diff --git a/internal/db/kvs/redis/option_test.go b/internal/db/kvs/redis/option_test.go index d8c009ce914..0be610e31f6 100644 --- a/internal/db/kvs/redis/option_test.go +++ b/internal/db/kvs/redis/option_test.go @@ -3077,117 +3077,3 @@ func TestWithInitialPingDuration(t *testing.T) { }) } } - -func TestWithPingFlag(t *testing.T) { - // Change interface type to the type of object you are testing - type T = interface{} - type args struct { - flag bool - } - type want struct { - obj *T - // Uncomment this line if the option returns an error, otherwise delete it - // err error - } - type test struct { - name string - args args - want want - // Use the first line if the option returns an error. otherwise use the second line - // checkFunc func(want, *T, error) error - // checkFunc func(want, *T) error - beforeFunc func(args) - afterFunc func(args) - } - - // Uncomment this block if the option returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T, err error) error { - if !errors.Is(err, w.err) { - return errors.Errorf("got error = %v, want %v", err, w.err) - } - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.obj) - } - return nil - } - */ - - // Uncomment this block if the option do not returns an error, otherwise delete it - /* - defaultCheckFunc := func(w want, obj *T) error { - if !reflect.DeepEqual(obj, w.obj) { - return errors.Errorf("got = %v, want %v", obj, w.obj) - } - return nil - } - */ - - tests := []test{ - // TODO test cases - /* - { - name: "test_case_1", - args: args { - flag: false, - }, - want: want { - obj: new(T), - }, - }, - */ - - // TODO test cases - /* - func() test { - return test { - name: "test_case_2", - args: args { - flag: false, - }, - want: want { - obj: new(T), - }, - } - }(), - */ - } - - for _, test := range tests { - t.Run(test.name, func(tt *testing.T) { - defer goleak.VerifyNone(tt) - if test.beforeFunc != nil { - test.beforeFunc(test.args) - } - if test.afterFunc != nil { - defer test.afterFunc(test.args) - } - - // Uncomment this block if the option returns an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - - got := WithPingFlag(test.args.flag) - obj := new(T) - if err := test.checkFunc(test.want, obj, got(obj)); err != nil { - tt.Errorf("error = %v", err) - } - */ - - // Uncomment this block if the option do not return an error, otherwise delete it - /* - if test.checkFunc == nil { - test.checkFunc = defaultCheckFunc - } - got := WithPingFlag(test.args.flag) - obj := new(T) - got(obj) - if err := test.checkFunc(test.want, obj); err != nil { - tt.Errorf("error = %v", err) - } - */ - }) - } -} diff --git a/internal/db/kvs/redis/redis.go b/internal/db/kvs/redis/redis.go index 44a5e254633..3828b924554 100644 --- a/internal/db/kvs/redis/redis.go +++ b/internal/db/kvs/redis/redis.go @@ -83,9 +83,7 @@ type redisClient struct { routeRandomly bool tlsConfig *tls.Config writeTimeout time.Duration - - client Redis - pingEnabled bool + client Redis } // New returns Redis implementation if no error occurs. @@ -97,82 +95,85 @@ func New(ctx context.Context, opts ...Option) (rc Redis, err error) { } } - switch len(r.addrs) { + r, err = r.newRedisClient(ctx) + if err != nil { + return nil, err + } + + return r.ping(ctx) +} + +func (rc *redisClient) newRedisClient(ctx context.Context) (*redisClient, error) { + switch len(rc.addrs) { case 0: return nil, errors.ErrRedisAddrsNotFound case 1: - if len(r.addrs[0]) == 0 { + if len(rc.addrs[0]) == 0 { return nil, errors.ErrRedisAddrsNotFound } - r.client = redis.NewClient(&redis.Options{ - Addr: r.addrs[0], - Password: r.password, - Dialer: r.dialer, - OnConnect: r.onConnect, - DB: r.db, - MaxRetries: r.maxRetries, - MinRetryBackoff: r.minRetryBackoff, - MaxRetryBackoff: r.maxRetryBackoff, - DialTimeout: r.dialTimeout, - ReadTimeout: r.readTimeout, - WriteTimeout: r.writeTimeout, - PoolSize: r.poolSize, - MinIdleConns: r.minIdleConns, - MaxConnAge: r.maxConnAge, - PoolTimeout: r.poolTimeout, - IdleTimeout: r.idleTimeout, - IdleCheckFrequency: r.idleCheckFrequency, - TLSConfig: r.tlsConfig, + rc.client = redis.NewClient(&redis.Options{ + Addr: rc.addrs[0], + Password: rc.password, + Dialer: rc.dialer, + OnConnect: rc.onConnect, + DB: rc.db, + MaxRetries: rc.maxRetries, + MinRetryBackoff: rc.minRetryBackoff, + MaxRetryBackoff: rc.maxRetryBackoff, + DialTimeout: rc.dialTimeout, + ReadTimeout: rc.readTimeout, + WriteTimeout: rc.writeTimeout, + PoolSize: rc.poolSize, + MinIdleConns: rc.minIdleConns, + MaxConnAge: rc.maxConnAge, + PoolTimeout: rc.poolTimeout, + IdleTimeout: rc.idleTimeout, + IdleCheckFrequency: rc.idleCheckFrequency, + TLSConfig: rc.tlsConfig, }) default: - r.client = redis.NewClusterClient(&redis.ClusterOptions{ - Addrs: r.addrs, - Dialer: r.dialer, - MaxRedirects: r.maxRedirects, - ReadOnly: r.readOnly, - RouteByLatency: r.routeByLatency, - RouteRandomly: r.routeRandomly, - ClusterSlots: r.clusterSlots, - OnNewNode: r.onNewNode, - OnConnect: r.onConnect, - Password: r.password, - MaxRetries: r.maxRetries, - MinRetryBackoff: r.minRetryBackoff, - MaxRetryBackoff: r.maxRetryBackoff, - DialTimeout: r.dialTimeout, - ReadTimeout: r.readTimeout, - WriteTimeout: r.writeTimeout, - PoolSize: r.poolSize, - MinIdleConns: r.minIdleConns, - MaxConnAge: r.maxConnAge, - PoolTimeout: r.poolTimeout, - IdleTimeout: r.idleTimeout, - IdleCheckFrequency: r.idleCheckFrequency, - TLSConfig: r.tlsConfig, + rc.client = redis.NewClusterClient(&redis.ClusterOptions{ + Addrs: rc.addrs, + Dialer: rc.dialer, + MaxRedirects: rc.maxRedirects, + ReadOnly: rc.readOnly, + RouteByLatency: rc.routeByLatency, + RouteRandomly: rc.routeRandomly, + ClusterSlots: rc.clusterSlots, + OnNewNode: rc.onNewNode, + OnConnect: rc.onConnect, + Password: rc.password, + MaxRetries: rc.maxRetries, + MinRetryBackoff: rc.minRetryBackoff, + MaxRetryBackoff: rc.maxRetryBackoff, + DialTimeout: rc.dialTimeout, + ReadTimeout: rc.readTimeout, + WriteTimeout: rc.writeTimeout, + PoolSize: rc.poolSize, + MinIdleConns: rc.minIdleConns, + MaxConnAge: rc.maxConnAge, + PoolTimeout: rc.poolTimeout, + IdleTimeout: rc.idleTimeout, + IdleCheckFrequency: rc.idleCheckFrequency, + TLSConfig: rc.tlsConfig, }).WithContext(ctx) } - if r.pingEnabled { - if err = r.ping(ctx); err != nil { - return nil, err - } - } - - return r.client, nil + return rc, nil } -func (rc *redisClient) ping(ctx context.Context) (err error) { +func (rc *redisClient) ping(ctx context.Context) (r Redis, err error) { pctx, cancel := context.WithTimeout(ctx, rc.initialPingTimeLimit) defer cancel() tick := time.NewTicker(rc.initialPingDuration) for { select { case <-pctx.Done(): - return errors.Wrap(errors.Wrap(err, errors.ErrRedisConnectionPingFailed.Error()), pctx.Err().Error()) + return nil, errors.Wrap(errors.Wrap(err, errors.ErrRedisConnectionPingFailed.Error()), pctx.Err().Error()) case <-tick.C: err = rc.client.Ping().Err() if err == nil { - return nil + return rc.client, nil } log.Error(err) } diff --git a/internal/db/kvs/redis/redis_test.go b/internal/db/kvs/redis/redis_test.go index 234df92ab03..1504f76aa01 100644 --- a/internal/db/kvs/redis/redis_test.go +++ b/internal/db/kvs/redis/redis_test.go @@ -25,7 +25,7 @@ import ( "testing" "time" - "github.com/go-redis/redis/v7" + redis "github.com/go-redis/redis/v7" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/vdaas/vald/internal/errors" @@ -77,12 +77,117 @@ func TestNew(t *testing.T) { return nil } + tests := []test{ + { + name: "returns address not found error when options is nil", + args: args{ + ctx: context.Background(), + }, + want: want{ + wantRc: nil, + err: errors.ErrRedisAddrsNotFound, + }, + }, + + { + name: "returns ping failed error when options is not nil", + args: args{ + ctx: context.Background(), + opts: []Option{ + WithAddrs("127.0.0.0.1"), + WithInitialPingTimeLimit("1µs"), + WithInitialPingDuration("10ms"), + }, + }, + want: want{ + wantRc: nil, + err: errors.Wrap(errors.Wrap(nil, errors.ErrRedisConnectionPingFailed.Error()), context.DeadlineExceeded.Error()), + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(tt *testing.T) { + defer goleak.VerifyNone(tt, goleakIgnoreOptions...) + if test.beforeFunc != nil { + test.beforeFunc(test.args) + } + if test.afterFunc != nil { + defer test.afterFunc(test.args) + } + if test.checkFunc == nil { + test.checkFunc = defaultCheckFunc + } + + gotRc, err := New(test.args.ctx, test.args.opts...) + if err := test.checkFunc(test.want, gotRc, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +} + +func Test_redisClient_newRedisClient(t *testing.T) { + type args struct { + ctx context.Context + } + type fields struct { + addrs []string + clusterSlots func() ([]redis.ClusterSlot, error) + db int + dialTimeout time.Duration + dialer func(ctx context.Context, network, addr string) (net.Conn, error) + idleCheckFrequency time.Duration + idleTimeout time.Duration + initialPingDuration time.Duration + initialPingTimeLimit time.Duration + keyPref string + maxConnAge time.Duration + maxRedirects int + maxRetries int + maxRetryBackoff time.Duration + minIdleConns int + minRetryBackoff time.Duration + onConnect func(*redis.Conn) error + onNewNode func(*redis.Client) + password string + poolSize int + poolTimeout time.Duration + readOnly bool + readTimeout time.Duration + routeByLatency bool + routeRandomly bool + tlsConfig *tls.Config + writeTimeout time.Duration + } + type want struct { + want *redisClient + err error + } + type test struct { + name string + args args + fields fields + want want + checkFunc func(want, *redisClient, error) error + beforeFunc func(args) + afterFunc func(args) + } + defaultCheckFunc := func(w want, got *redisClient, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got error = %v, want %v", err, w.err) + } + if !reflect.DeepEqual(got, w.want) { + return errors.Errorf("got = %v, want %v", got, w.want) + } + return nil + } tests := []test{ func() test { - dialer := func(ctx context.Context, addr, port string) (net.Conn, error) { + dialer := func(ctx context.Context, _, _ string) (net.Conn, error) { return nil, nil } - connFn := func(*redis.Conn) error { + connFn := func(c *redis.Conn) error { return nil } cfg := new(tls.Config) @@ -91,52 +196,53 @@ func TestNew(t *testing.T) { name: "returns Redis implementation when address length is 1", args: args{ ctx: context.Background(), - opts: []Option{ - WithAddrs("127.0.0.1"), - WithPassword("pass"), - WithDialer(dialer), - WithOnConnectFunction(connFn), - WithDB(1), - WithRetryLimit(2), - WithMinimumRetryBackoff("3s"), - WithMaximumRetryBackoff("4s"), - WithDialTimeout("5s"), - WithReadTimeout("6s"), - WithWriteTimeout("7s"), - WithPoolSize(8), - WithMinimumIdleConnection(9), - WithMaximumConnectionAge("10s"), - WithPoolTimeout("11s"), - WithIdleTimeout("12s"), - WithIdleCheckFrequency("13s"), - WithTLSConfig(cfg), - WithPing(false), - }, + }, + fields: fields{ + addrs: []string{"127.0.0.1"}, + password: "pass", + dialer: dialer, + onConnect: connFn, + db: 1, + maxRetries: 2, + minRetryBackoff: 3 * time.Second, + maxRetryBackoff: 4 * time.Second, + dialTimeout: 5 * time.Second, + readTimeout: 6 * time.Second, + writeTimeout: 7 * time.Second, + poolSize: 8, + minIdleConns: 9, + maxConnAge: 10 * time.Second, + poolTimeout: 11 * time.Second, + idleTimeout: 12 * time.Second, + idleCheckFrequency: 13 * time.Second, + tlsConfig: cfg, }, want: want{ - wantRc: redis.NewClient(&redis.Options{ - Addr: "127.0.0.1", - Password: "pass", - Dialer: dialer, - OnConnect: connFn, - DB: 1, - MaxRetries: 2, - MinRetryBackoff: 3 * time.Second, - MaxRetryBackoff: 4 * time.Second, - DialTimeout: 5 * time.Second, - ReadTimeout: 6 * time.Second, - WriteTimeout: 7 * time.Second, - PoolSize: 8, - MinIdleConns: 9, - MaxConnAge: 10 * time.Second, - PoolTimeout: 11 * time.Second, - IdleTimeout: 12 * time.Second, - IdleCheckFrequency: 13 * time.Second, - TLSConfig: cfg, - }), + want: &redisClient{ + client: redis.NewClient(&redis.Options{ + Addr: "127.0.0.1", + Password: "pass", + Dialer: dialer, + OnConnect: connFn, + DB: 1, + MaxRetries: 2, + MinRetryBackoff: 3 * time.Second, + MaxRetryBackoff: 4 * time.Second, + DialTimeout: 5 * time.Second, + ReadTimeout: 6 * time.Second, + WriteTimeout: 7 * time.Second, + PoolSize: 8, + MinIdleConns: 9, + MaxConnAge: 10 * time.Second, + PoolTimeout: 11 * time.Second, + IdleTimeout: 12 * time.Second, + IdleCheckFrequency: 13 * time.Second, + TLSConfig: cfg, + }), + }, err: nil, }, - checkFunc: func(w want, gotRc Redis, err error) error { + checkFunc: func(w want, gotRc *redisClient, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got error = %v, want %v", err, w.err) } @@ -145,8 +251,8 @@ func TestNew(t *testing.T) { } var ( - want = w.wantRc.(*redis.Client).Options() - got = gotRc.(*redis.Client).Options() + want = w.want.client.(*redis.Client).Options() + got = gotRc.client.(*redis.Client).Options() ) opts := []cmp.Option{ @@ -163,7 +269,7 @@ func TestNew(t *testing.T) { }), } if diff := cmp.Diff(want, got, opts...); diff != "" { - return errors.Errorf("got = %v, want = %v", got, want) + return errors.Errorf("client options diff: %s", diff) } return nil @@ -172,14 +278,14 @@ func TestNew(t *testing.T) { }(), func() test { - dialer := func(ctx context.Context, addr, port string) (net.Conn, error) { + dialer := func(ctx context.Context, _, _ string) (net.Conn, error) { return nil, nil } - closterSlots := func() ([]redis.ClusterSlot, error) { + cslots := func() ([]redis.ClusterSlot, error) { return nil, nil } onNewNode := func(*redis.Client) {} - onConnect := func(*redis.Conn) error { + onConnect := func(c *redis.Conn) error { return nil } cfg := new(tls.Config) @@ -188,64 +294,61 @@ func TestNew(t *testing.T) { name: "returns Redis implementation when address length is 2", args: args{ ctx: context.Background(), - opts: []Option{ - WithAddrs("127.0.0.1", "192.168.33.10"), - WithDialer(dialer), - WithRedirectLimit(1), - WithReadOnlyFlag(true), - WithRouteByLatencyFlag(true), - WithRouteRandomlyFlag(true), - WithClusterSlots(closterSlots), - WithOnNewNodeFunction(onNewNode), - WithOnConnectFunction(onConnect), - WithPassword("pass"), - WithRetryLimit(2), - WithMinimumRetryBackoff("3s"), - WithMaximumRetryBackoff("4s"), - WithDialTimeout("5s"), - WithReadTimeout("6s"), - WithWriteTimeout("7s"), - WithPoolSize(8), - WithMinimumIdleConnection(9), - WithMaximumConnectionAge("10s"), - WithPoolTimeout("11s"), - WithIdleTimeout("12s"), - WithIdleCheckFrequency("13s"), - WithTLSConfig(cfg), - WithPing(false), - }, + }, + fields: fields{ + addrs: []string{"127.0.0.1", "127.0.0.2"}, + dialer: dialer, + maxRedirects: 1, + readOnly: true, + routeByLatency: true, + routeRandomly: true, + clusterSlots: cslots, + onNewNode: onNewNode, + onConnect: onConnect, + password: "pass", + maxRetries: 2, + minRetryBackoff: 3 * time.Second, + maxRetryBackoff: 4 * time.Second, + dialTimeout: 5 * time.Second, + readTimeout: 6 * time.Second, + writeTimeout: 7 * time.Second, + poolSize: 8, + maxConnAge: 9 * time.Second, + idleTimeout: 10 * time.Second, + idleCheckFrequency: 11 * time.Second, + tlsConfig: cfg, }, want: want{ - wantRc: redis.NewClusterClient(&redis.ClusterOptions{ - Addrs: []string{ - "127.0.0.1", "192.168.33.10", - }, - Dialer: dialer, - MaxRedirects: 1, - ReadOnly: true, - RouteByLatency: true, - RouteRandomly: true, - ClusterSlots: closterSlots, - OnNewNode: onNewNode, - OnConnect: onConnect, - Password: "pass", - MaxRetries: 2, - MinRetryBackoff: 3 * time.Second, - MaxRetryBackoff: 4 * time.Second, - DialTimeout: 5 * time.Second, - ReadTimeout: 6 * time.Second, - WriteTimeout: 7 * time.Second, - PoolSize: 8, - MinIdleConns: 9, - MaxConnAge: 10 * time.Second, - PoolTimeout: 11 * time.Second, - IdleTimeout: 12 * time.Second, - IdleCheckFrequency: 13 * time.Second, - TLSConfig: cfg, - }), + want: &redisClient{ + client: redis.NewClusterClient(&redis.ClusterOptions{ + Addrs: []string{ + "127.0.0.1", "127.0.0.2", + }, + Dialer: dialer, + MaxRedirects: 1, + ReadOnly: true, + RouteByLatency: true, + RouteRandomly: true, + ClusterSlots: cslots, + OnNewNode: onNewNode, + OnConnect: onConnect, + Password: "pass", + MaxRetries: 2, + MinRetryBackoff: 3 * time.Second, + MaxRetryBackoff: 4 * time.Second, + DialTimeout: 5 * time.Second, + ReadTimeout: 6 * time.Second, + WriteTimeout: 7 * time.Second, + PoolSize: 8, + MaxConnAge: 9 * time.Second, + IdleTimeout: 10 * time.Second, + IdleCheckFrequency: 11 * time.Second, + TLSConfig: cfg, + }), + }, err: nil, }, - checkFunc: func(w want, gotRc Redis, err error) error { + checkFunc: func(w want, gotRc *redisClient, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got error = %v, want %v", err, w.err) } @@ -254,8 +357,8 @@ func TestNew(t *testing.T) { } var ( - want = w.wantRc.(*redis.ClusterClient).Options() - got = gotRc.(*redis.ClusterClient).Options() + want = w.want.client.(*redis.ClusterClient).Options() + got = gotRc.client.(*redis.ClusterClient).Options() ) opts := []cmp.Option{ @@ -292,54 +395,29 @@ func TestNew(t *testing.T) { func() test { return test{ - name: "returns redis address not found error when address length is 0", - args: args{ - ctx: context.Background(), - opts: nil, - }, - want: want{ - wantRc: nil, - err: errors.ErrRedisAddrsNotFound, - }, - } - }(), - - func() test { - return test{ - name: "returns redis address not found error when address length is 1 but address contains empty string", + name: "returns address not found error when address length is 0", args: args{ ctx: context.Background(), - opts: []Option{ - WithAddrs(""), - }, }, want: want{ - wantRc: nil, - err: errors.ErrRedisAddrsNotFound, + want: nil, + err: errors.ErrRedisAddrsNotFound, }, } }(), func() test { - err := errors.New("err") return test{ - name: "returns ping error when address length is 1 and ping fails", + name: "returns address not found error when address length is 1 but contains empty string", + fields: fields{ + addrs: []string{""}, + }, args: args{ ctx: context.Background(), - opts: []Option{ - WithAddrs("127.0.0.01"), - WithInitialPingDuration("1ms"), - WithInitialPingTimeLimit("2ms"), - WithDialer(func(ctx context.Context, addr string, port string) (net.Conn, error) { - return nil, err - }), - }, }, want: want{ - wantRc: nil, - err: errors.Wrap(errors.Wrap( - err, - errors.ErrRedisConnectionPingFailed.Error()), context.DeadlineExceeded.Error()), + want: nil, + err: errors.ErrRedisAddrsNotFound, }, } }(), @@ -357,11 +435,41 @@ func TestNew(t *testing.T) { if test.checkFunc == nil { test.checkFunc = defaultCheckFunc } + rc := &redisClient{ + addrs: test.fields.addrs, + clusterSlots: test.fields.clusterSlots, + db: test.fields.db, + dialTimeout: test.fields.dialTimeout, + dialer: test.fields.dialer, + idleCheckFrequency: test.fields.idleCheckFrequency, + idleTimeout: test.fields.idleTimeout, + initialPingDuration: test.fields.initialPingDuration, + initialPingTimeLimit: test.fields.initialPingTimeLimit, + keyPref: test.fields.keyPref, + maxConnAge: test.fields.maxConnAge, + maxRedirects: test.fields.maxRedirects, + maxRetries: test.fields.maxRetries, + maxRetryBackoff: test.fields.maxRetryBackoff, + minIdleConns: test.fields.minIdleConns, + minRetryBackoff: test.fields.minRetryBackoff, + onConnect: test.fields.onConnect, + onNewNode: test.fields.onNewNode, + password: test.fields.password, + poolSize: test.fields.poolSize, + poolTimeout: test.fields.poolTimeout, + readOnly: test.fields.readOnly, + readTimeout: test.fields.readTimeout, + routeByLatency: test.fields.routeByLatency, + routeRandomly: test.fields.routeRandomly, + tlsConfig: test.fields.tlsConfig, + writeTimeout: test.fields.writeTimeout, + } - gotRc, err := New(test.args.ctx, test.args.opts...) - if err := test.checkFunc(test.want, gotRc, err); err != nil { + got, err := rc.newRedisClient(test.args.ctx) + if err := test.checkFunc(test.want, got, err); err != nil { tt.Errorf("error = %v", err) } + }) } } @@ -376,49 +484,55 @@ func Test_redisClient_ping(t *testing.T) { client Redis } type want struct { - err error + wantR Redis + err error } type test struct { name string args args fields fields want want - checkFunc func(want, error) error + checkFunc func(want, Redis, error) error beforeFunc func(args) afterFunc func(args) } - defaultCheckFunc := func(w want, err error) error { + defaultCheckFunc := func(w want, gotR Redis, err error) error { if !errors.Is(err, w.err) { return errors.Errorf("got error = %v, want %v", err, w.err) } + if !reflect.DeepEqual(gotR, w.wantR) { + return errors.Errorf("got = %v, want %v", gotR, w.wantR) + } return nil } tests := []test{ func() test { + r := &MockRedis{ + PingFunc: func() *StatusCmd { + return new(StatusCmd) + }, + } + return test{ name: "returns nil when the ping success", args: args{ ctx: context.Background(), }, fields: fields{ - initialPingDuration: time.Microsecond, + initialPingDuration: time.Millisecond, initialPingTimeLimit: time.Second, - client: func() Redis { - return &MockRedis{ - PingFunc: func() *StatusCmd { - return new(StatusCmd) - }, - } - }(), + client: r, }, want: want{ - err: nil, + wantR: r, + err: nil, }, } }(), func() test { err := errors.New("err") + return test{ name: "returns ping failed error when the ping fails and reached the ping time limit", args: args{ @@ -426,7 +540,7 @@ func Test_redisClient_ping(t *testing.T) { }, fields: fields{ initialPingDuration: time.Millisecond, - initialPingTimeLimit: 5 * time.Millisecond, + initialPingTimeLimit: 3 * time.Millisecond, client: func() Redis { return &MockRedis{ PingFunc: func() (cmd *StatusCmd) { @@ -438,7 +552,8 @@ func Test_redisClient_ping(t *testing.T) { }(), }, want: want{ - err: errors.Wrap(errors.Wrap(err, errors.ErrRedisConnectionPingFailed.Error()), context.DeadlineExceeded.Error()), + wantR: nil, + err: errors.Wrap(errors.Wrap(err, errors.ErrRedisConnectionPingFailed.Error()), context.DeadlineExceeded.Error()), }, } }(), @@ -462,10 +577,11 @@ func Test_redisClient_ping(t *testing.T) { client: test.fields.client, } - err := rc.ping(test.args.ctx) - if err := test.checkFunc(test.want, err); err != nil { + gotR, err := rc.ping(test.args.ctx) + if err := test.checkFunc(test.want, gotR, err); err != nil { tt.Errorf("error = %v", err) } + }) } }