diff --git a/charts/vald-readreplica/templates/deployment.yaml b/charts/vald-readreplica/templates/deployment.yaml index fea4a53f29..5cb2a9a926 100644 --- a/charts/vald-readreplica/templates/deployment.yaml +++ b/charts/vald-readreplica/templates/deployment.yaml @@ -74,10 +74,17 @@ spec: - podAffinityTerm: labelSelector: matchExpressions: + # to avoid being deployed to the same node as the agent pods - key: app operator: In values: - - {{ $readreplica.name }}-{{ $id }} + - {{ $agent.name }} + # to avoid being deployed to the same node as the other readreplica pods + # and Deployment replicas of myself + - key: app.kubernetes.io/component + operator: In + values: + - {{ $readreplica.component_name }} topologyKey: kubernetes.io/hostname weight: 100 {{- if $agent.topologySpreadConstraints }} diff --git a/charts/vald/templates/gateway/lb/configmap.yaml b/charts/vald/templates/gateway/lb/configmap.yaml index 0facd73355..7c840cb0c2 100644 --- a/charts/vald/templates/gateway/lb/configmap.yaml +++ b/charts/vald/templates/gateway/lb/configmap.yaml @@ -16,6 +16,8 @@ {{- $gateway := .Values.gateway.lb -}} {{- $agent := .Values.agent -}} {{- $discoverer := .Values.discoverer -}} +{{- $readreplica := .Values.agent.readreplica -}} +{{- $release := .Release -}} {{- if $gateway.enabled }} apiVersion: v1 kind: ConfigMap @@ -49,6 +51,7 @@ data: agent_namespace: {{ $gateway.gateway_config.agent_namespace | quote }} node_name: {{ $gateway.gateway_config.node_name | quote }} index_replica: {{ $gateway.gateway_config.index_replica }} + read_replica_replicas: {{ $readreplica.replica }} discoverer: duration: {{ $gateway.gateway_config.discoverer.duration }} client: @@ -64,4 +67,20 @@ data: agent_client_options: {{- include "vald.grpc.client.addrs" (dict "Valued" $gateway.gateway_config.discoverer.agent_client_options.addrs) | nindent 10 }} {{- include "vald.grpc.client" (dict "Values" $gateway.gateway_config.discoverer.agent_client_options "default" .Values.defaults.grpc.client) | nindent 10 }} + {{- if $readreplica.enabled }} + read_replica_client: + client: + {{- $discovererClient := $gateway.gateway_config.discoverer.client }} + {{- $readReplicaPort := $agent.server_config.servers.grpc.port }} + {{- $defaultReadReplicaPort := default .Values.defaults.server_config.servers.grpc.port $readReplicaPort }} + {{- $readReplicaAddrs := list }} + {{- range $i := until (int $agent.minReplicas) }} + {{- $addr := printf "%s-%d.%s.svc.cluster.local:%d" $readreplica.name $i $release.Namespace (int64 $defaultReadReplicaPort) }} + {{- $readReplicaAddrs = append $readReplicaAddrs $addr }} + {{- end }} + {{- $readReplicaAddrs := dict "Values" $discovererClient.addrs "default" $readReplicaAddrs }} + {{- include "vald.grpc.client.addrs" $readReplicaAddrs | nindent 10 }} + {{- $readReplicaGRPCclient := dict "Values" $discovererClient "default" .Values.defaults.grpc.client }} + {{- include "vald.grpc.client" $readReplicaGRPCclient | nindent 10 }} + {{- end }} {{- end }} diff --git a/internal/client/v1/client/discoverer/discover.go b/internal/client/v1/client/discoverer/discover.go index 2c2c5c398a..32b0d86ebb 100644 --- a/internal/client/v1/client/discoverer/discover.go +++ b/internal/client/v1/client/discoverer/discover.go @@ -37,7 +37,15 @@ import ( type Client interface { Start(ctx context.Context) (<-chan error, error) GetAddrs(ctx context.Context) []string + + // GetClient returns the grpc.Client for both read and write. GetClient() grpc.Client + + // GetReadClient returns the grpc.Client only for read. If there's no readreplica, this returns the grpc.Client for the primary agent. + // Use this API only for getting client for agent. For other use cases, use GetClient() instead. + // Internally, this API round robin between c.client and c.readClient with the ratio of + // agent replicas and read replica agent replicas. + GetReadClient() grpc.Client } type client struct { @@ -56,6 +64,10 @@ type client struct { name string namespace string nodeName string + // read replica related members below + readClient grpc.Client + readReplicaReplicas uint64 + roundRobin atomic.Uint64 } func New(opts ...Option) (d Client, err error) { @@ -68,12 +80,22 @@ func New(opts ...Option) (d Client, err error) { return c, nil } +// Start starts the discoverer client. +// skipcq: GO-R1005 func (c *client) Start(ctx context.Context) (<-chan error, error) { dech, err := c.dscClient.StartConnectionMonitor(ctx) if err != nil { return nil, err } + var rrech <-chan error + if c.readClient != nil { + rrech, err = c.readClient.StartConnectionMonitor(ctx) + if err != nil { + return nil, err + } + } + ech := make(chan error, 100) addrs, err := c.dnsDiscovery(ctx, ech) if err != nil { @@ -134,6 +156,7 @@ func (c *client) Start(ctx context.Context) (<-chan error, error) { return finalize() case err = <-dech: case err = <-aech: + case err = <-rrech: case <-dt.C: err = c.discover(ctx, ech) } @@ -172,6 +195,26 @@ func (c *client) GetClient() grpc.Client { return c.client } +func (c *client) GetReadClient() grpc.Client { + // just return write client when there is no read replica + if c.readClient == nil { + return c.client + } + + var next uint64 + for { + cur := c.roundRobin.Load() + next = (cur + 1) % (c.readReplicaReplicas + 1) + if c.roundRobin.CompareAndSwap(cur, next) { + break + } + } + if next == 0 { + return c.client + } + return c.readClient +} + func (c *client) connect(ctx context.Context, addr string) (err error) { if c.autoconn && c.client != nil { _, err = c.client.Connect(ctx, addr) diff --git a/internal/client/v1/client/discoverer/discover_test.go b/internal/client/v1/client/discoverer/discover_test.go index c1e55bd732..fb4b403227 100644 --- a/internal/client/v1/client/discoverer/discover_test.go +++ b/internal/client/v1/client/discoverer/discover_test.go @@ -17,6 +17,133 @@ // Package discoverer package discoverer +import ( + "context" + "reflect" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/require" + "github.com/vdaas/vald/internal/net/grpc" + "github.com/vdaas/vald/internal/sync/errgroup" + "github.com/vdaas/vald/internal/test/mock" +) + +func Test_client_GetReadClient(t *testing.T) { + type fields struct { + client grpc.Client + readClient grpc.Client + readReplicaReplicas uint64 + roundRobin atomic.Uint64 + } + type test struct { + name string + fields fields + want grpc.Client + } + + mockClient := mock.ClientInternal{} + mockClient.On("GetAddrs").Return([]string{"read write client"}) + mockReadClient := mock.ClientInternal{} + mockReadClient.On("GetAddrs").Return([]string{"read replica client"}) + + tests := []test{ + { + name: "returns primary client when there is no read replica", + fields: fields{ + client: &mockClient, + readClient: nil, + readReplicaReplicas: 1, + }, + want: &mockClient, + }, + func() test { + var counter atomic.Uint64 + counter.Store(0) + return test{ + name: "returns read client when there is read replica and the counter increments to anything other than 0", + fields: fields{ + client: &mockClient, + readClient: &mockReadClient, + readReplicaReplicas: 1, + roundRobin: counter, + }, + want: &mockReadClient, + } + }(), + func() test { + var counter atomic.Uint64 + counter.Store(1) + return test{ + name: "returns primary client when there is read replica and the counter increments to 0", + fields: fields{ + client: &mockClient, + readClient: &mockReadClient, + readReplicaReplicas: 1, + roundRobin: counter, + }, + want: &mockClient, + } + }(), + func() test { + var counter atomic.Uint64 + counter.Store(3) + return test{ + name: "returns primary client when there is read replica and the counter increments to 0(replicas: 3)", + fields: fields{ + client: &mockClient, + readClient: &mockReadClient, + readReplicaReplicas: 3, + roundRobin: counter, + }, + want: &mockClient, + } + }(), + } + for _, tc := range tests { + test := tc + t.Run(test.name, func(t *testing.T) { + c := &client{ + client: test.fields.client, + readClient: test.fields.readClient, + readReplicaReplicas: test.fields.readReplicaReplicas, + roundRobin: test.fields.roundRobin, + } + got := c.GetReadClient() + if !reflect.DeepEqual(got, test.want) { + t.Errorf("GetReadClient() = %v, want %v", got, test.want) + } + }) + } +} + +func Test_client_GetReadClient_concurrent(t *testing.T) { + mockClient := mock.ClientInternal{} + mockClient.On("GetAddrs").Return([]string{"read write client"}) + mockReadClient := mock.ClientInternal{} + mockReadClient.On("GetAddrs").Return([]string{"read replica client"}) + + c := &client{ + client: &mockClient, + readClient: &mockReadClient, + readReplicaReplicas: 100, + roundRobin: atomic.Uint64{}, + } + + eg, _ := errgroup.New(context.Background()) + for i := 0; i < 150; i++ { + eg.Go(func() error { + c.GetReadClient() + return nil + }) + } + + err := eg.Wait() + require.NoError(t, err) + + require.EqualValues(t, uint64(49), c.roundRobin.Load(), "atomic operation did not happen in the concurrent calls") +} + // NOT IMPLEMENTED BELOW // // func TestNew(t *testing.T) { diff --git a/internal/client/v1/client/discoverer/option.go b/internal/client/v1/client/discoverer/option.go index 6159d8b72e..632bfc2d29 100644 --- a/internal/client/v1/client/discoverer/option.go +++ b/internal/client/v1/client/discoverer/option.go @@ -68,6 +68,13 @@ func WithDiscovererClient(gc grpc.Client) Option { } } +func WithReadReplicaClient(gc grpc.Client) Option { + return func(c *client) error { + c.readClient = gc + return nil + } +} + func WithDiscoverDuration(dur string) Option { return func(c *client) error { d, err := timeutil.Parse(dur) @@ -142,3 +149,10 @@ func WithErrGroup(eg errgroup.Group) Option { return nil } } + +func WithReadReplicaReplicas(num uint64) Option { + return func(c *client) error { + c.readReplicaReplicas = num + return nil + } +} diff --git a/internal/config/lb.go b/internal/config/lb.go index b479400b2d..cfc2a6fb2a 100644 --- a/internal/config/lb.go +++ b/internal/config/lb.go @@ -37,6 +37,12 @@ type LB struct { // IndexReplica represents index replication count IndexReplica int `json:"index_replica" yaml:"index_replica"` + // ReadReplicaReplicas represents replica count of read replica Deployment + ReadReplicaReplicas uint64 `json:"read_replica_replicas" yaml:"read_replica_replicas"` + + // ReadReplicaClient represents read replica client configuration + ReadReplicaClient ReadReplicaClient `json:"read_replica_client" yaml:"read_replica_client"` + // Discoverer represent agent discoverer service configuration Discoverer *DiscovererClient `json:"discoverer" yaml:"discoverer"` @@ -56,3 +62,26 @@ func (g *LB) Bind() *LB { } return g } + +// ReadReplicaClient represents a configuration of grpc client for read replica +type ReadReplicaClient struct { + Duration string `json:"duration" yaml:"duration"` + Client *GRPCClient `json:"client" yaml:"client"` + AgentClientOptions *GRPCClient `json:"agent_client_options" yaml:"agent_client_options"` +} + +// Bind binds the actual data from the ReadReplicaClient receiver field. +func (d *ReadReplicaClient) Bind() *ReadReplicaClient { + d.Duration = GetActualValue(d.Duration) + if d.Client != nil { + d.Client.Bind() + } else { + d.Client = newGRPCClientConfig() + } + if d.AgentClientOptions != nil { + d.AgentClientOptions.Bind() + } else { + d.AgentClientOptions = newGRPCClientConfig() + } + return d +} diff --git a/pkg/gateway/lb/handler/grpc/aggregation.go b/pkg/gateway/lb/handler/grpc/aggregation.go index 4c7403767e..99f9d45fcc 100644 --- a/pkg/gateway/lb/handler/grpc/aggregation.go +++ b/pkg/gateway/lb/handler/grpc/aggregation.go @@ -33,6 +33,7 @@ import ( "github.com/vdaas/vald/internal/net/grpc/status" "github.com/vdaas/vald/internal/observability/trace" "github.com/vdaas/vald/internal/sync" + "github.com/vdaas/vald/pkg/gateway/lb/service" ) type Aggregator interface { @@ -68,7 +69,7 @@ func (s *server) aggregationSearch(ctx context.Context, aggr Aggregator, cfg *pa ctx, cancel := context.WithTimeout(ctx, timeout) aggr.Start(ctx) - err = s.gateway.BroadCast(ctx, func(ctx context.Context, target string, vc vald.Client, copts ...grpc.CallOption) error { + err = s.gateway.BroadCast(ctx, service.READ, func(ctx context.Context, target string, vc vald.Client, copts ...grpc.CallOption) error { sctx, sspan := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BroadCast/"+target), apiName+"/aggregationSearch/"+target) defer func() { if sspan != nil { diff --git a/pkg/gateway/lb/handler/grpc/handler.go b/pkg/gateway/lb/handler/grpc/handler.go index 27b48274f4..7f0678c3ba 100644 --- a/pkg/gateway/lb/handler/grpc/handler.go +++ b/pkg/gateway/lb/handler/grpc/handler.go @@ -95,7 +95,7 @@ func (s *server) exists(ctx context.Context, uuid string) (id *payload.Object_ID defer close(ich) defer close(ech) var once sync.Once - ech <- s.gateway.BroadCast(ctx, func(ctx context.Context, target string, vc vald.Client, copts ...grpc.CallOption) error { + ech <- s.gateway.BroadCast(ctx, service.READ, func(ctx context.Context, target string, vc vald.Client, copts ...grpc.CallOption) error { sctx, sspan := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BroadCast/"+target), apiName+"/exists/BroadCast/"+target) defer func() { if sspan != nil { @@ -1653,7 +1653,7 @@ func (s *server) Update(ctx context.Context, req *payload.Update_Request) (res * Ips: make([]string, 0, s.replica), } ) - err = s.gateway.BroadCast(ctx, func(ctx context.Context, target string, vc vald.Client, copts ...grpc.CallOption) (err error) { + err = s.gateway.BroadCast(ctx, service.WRITE, func(ctx context.Context, target string, vc vald.Client, copts ...grpc.CallOption) (err error) { ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BroadCast/"+target), apiName+"/"+vald.UpdateRPCName+"/"+target) defer func() { if span != nil { @@ -2562,7 +2562,7 @@ func (s *server) Remove(ctx context.Context, req *payload.Remove_Request) (locs Ips: make([]string, 0, s.replica), } ls := make([]string, 0, s.replica) - err = s.gateway.BroadCast(ctx, func(ctx context.Context, target string, vc vald.Client, copts ...grpc.CallOption) (err error) { + err = s.gateway.BroadCast(ctx, service.WRITE, func(ctx context.Context, target string, vc vald.Client, copts ...grpc.CallOption) (err error) { ctx, span := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BroadCast/"+target), apiName+"/"+vald.RemoveRPCName+"/"+target) defer func() { if span != nil { @@ -2778,7 +2778,7 @@ func (s *server) RemoveByTimestamp(ctx context.Context, req *payload.Remove_Time visited := make(map[string]int) // map[uuid: position of locs] locs = new(payload.Object_Locations) - err := s.gateway.BroadCast(ctx, func(ctx context.Context, target string, vc vald.Client, copts ...grpc.CallOption) (err error) { + err := s.gateway.BroadCast(ctx, service.WRITE, func(ctx context.Context, target string, vc vald.Client, copts ...grpc.CallOption) (err error) { sctx, sspan := trace.StartSpan(grpc.WithGRPCMethod(ctx, "BroadCast/"+target), apiName+"/removeByTimestamp/BroadCast/"+target) defer func() { if sspan != nil { @@ -2905,7 +2905,7 @@ func (s *server) getObject(ctx context.Context, uuid string) (vec *payload.Objec defer close(vch) defer close(ech) var once sync.Once - ech <- s.gateway.BroadCast(ctx, func(ctx context.Context, target string, vc vald.Client, copts ...grpc.CallOption) error { + ech <- s.gateway.BroadCast(ctx, service.READ, func(ctx context.Context, target string, vc vald.Client, copts ...grpc.CallOption) error { sctx, sspan := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BroadCast/"+target), apiName+"/getObject/BroadCast/"+target) defer func() { if sspan != nil { @@ -3143,7 +3143,7 @@ func (s *server) StreamListObject(req *payload.Object_List_Request, stream vald. defer cancel() var rmu, smu sync.Mutex - err := s.gateway.BroadCast(ctx, func(ctx context.Context, target string, vc vald.Client, copts ...grpc.CallOption) error { + err := s.gateway.BroadCast(ctx, service.READ, func(ctx context.Context, target string, vc vald.Client, copts ...grpc.CallOption) error { ctx, sspan := trace.StartSpan(grpc.WrapGRPCMethod(ctx, "BroadCast/"+target), apiName+"/"+vald.StreamListObjectRPCName+"/"+target) defer func() { if sspan != nil { diff --git a/pkg/gateway/lb/service/gateway.go b/pkg/gateway/lb/service/gateway.go index 64f9f0dc9b..8ae522c9ff 100644 --- a/pkg/gateway/lb/service/gateway.go +++ b/pkg/gateway/lb/service/gateway.go @@ -37,10 +37,17 @@ type Gateway interface { Addrs(ctx context.Context) []string DoMulti(ctx context.Context, num int, f func(ctx context.Context, target string, ac vald.Client, copts ...grpc.CallOption) error) error - BroadCast(ctx context.Context, + BroadCast(ctx context.Context, kind BroadCastKind, f func(ctx context.Context, target string, ac vald.Client, copts ...grpc.CallOption) error) error } +type BroadCastKind int + +const ( + READ BroadCastKind = iota + WRITE +) + type gateway struct { client discoverer.Client eg errgroup.Group @@ -60,7 +67,7 @@ func (g *gateway) Start(ctx context.Context) (<-chan error, error) { return g.client.Start(ctx) } -func (g *gateway) BroadCast(ctx context.Context, +func (g *gateway) BroadCast(ctx context.Context, kind BroadCastKind, f func(ctx context.Context, target string, ac vald.Client, copts ...grpc.CallOption) error, ) (err error) { fctx, span := trace.StartSpan(ctx, "vald/gateway-lb/service/Gateway.BroadCast") @@ -69,7 +76,16 @@ func (g *gateway) BroadCast(ctx context.Context, span.End() } }() - return g.client.GetClient().RangeConcurrent(fctx, -1, func(ictx context.Context, + + var client grpc.Client + switch kind { + case READ: + client = g.client.GetReadClient() + case WRITE: + client = g.client.GetClient() + } + + return client.RangeConcurrent(fctx, -1, func(ictx context.Context, addr string, conn *grpc.ClientConn, copts ...grpc.CallOption, ) (err error) { select { diff --git a/pkg/gateway/lb/usecase/vald.go b/pkg/gateway/lb/usecase/vald.go index b66d9068cf..b8dc5e3948 100644 --- a/pkg/gateway/lb/usecase/vald.go +++ b/pkg/gateway/lb/usecase/vald.go @@ -46,6 +46,37 @@ type run struct { gateway service.Gateway } +func discovererClient(cfg *config.Data, dopts []grpc.Option, aopts []grpc.Option, eg errgroup.Group) (discoverer.Client, error) { + var discovererOpts []discoverer.Option + discovererOpts = append(discovererOpts, + discoverer.WithAutoConnect(true), + discoverer.WithName(cfg.Gateway.AgentName), + discoverer.WithNamespace(cfg.Gateway.AgentNamespace), + discoverer.WithPort(cfg.Gateway.AgentPort), + discoverer.WithServiceDNSARecord(cfg.Gateway.AgentDNS), + discoverer.WithDiscovererClient(grpc.New(dopts...)), + discoverer.WithDiscoverDuration(cfg.Gateway.Discoverer.Duration), + discoverer.WithOptions(aopts...), + discoverer.WithNodeName(cfg.Gateway.NodeName), + discoverer.WithReadReplicaReplicas(cfg.Gateway.ReadReplicaReplicas), + ) + + rrOpts, err := cfg.Gateway.ReadReplicaClient.Client.Opts() + if err != nil { + return nil, err + } + // only append when read replica is enabled + if rrOpts != nil { + rrOpts = append(rrOpts, + grpc.WithErrGroup(eg), + grpc.WithConnectionPoolSize(int(cfg.Gateway.ReadReplicaReplicas)), + ) + discovererOpts = append(discovererOpts, discoverer.WithReadReplicaClient(grpc.New(rrOpts...))) + } + + return discoverer.New(discovererOpts...) +} + func New(cfg *config.Data) (r runner.Runner, err error) { eg := errgroup.Get() @@ -55,6 +86,7 @@ func New(cfg *config.Data) (r runner.Runner, err error) { if err != nil { return nil, err } + // skipcq: CRT-D0001 dopts := append( cOpts, @@ -68,20 +100,11 @@ func New(cfg *config.Data) (r runner.Runner, err error) { acOpts, grpc.WithErrGroup(eg)) - client, err := discoverer.New( - discoverer.WithAutoConnect(true), - discoverer.WithName(cfg.Gateway.AgentName), - discoverer.WithNamespace(cfg.Gateway.AgentNamespace), - discoverer.WithPort(cfg.Gateway.AgentPort), - discoverer.WithServiceDNSARecord(cfg.Gateway.AgentDNS), - discoverer.WithDiscovererClient(grpc.New(dopts...)), - discoverer.WithDiscoverDuration(cfg.Gateway.Discoverer.Duration), - discoverer.WithOptions(aopts...), - discoverer.WithNodeName(cfg.Gateway.NodeName), - ) + client, err := discovererClient(cfg, dopts, aopts, eg) if err != nil { return nil, err } + gateway, err = service.NewGateway( service.WithErrGroup(eg), service.WithDiscoverer(client), diff --git a/pkg/gateway/lb/usecase/vald_test.go b/pkg/gateway/lb/usecase/vald_test.go index 9444c474a6..c6e32773f0 100644 --- a/pkg/gateway/lb/usecase/vald_test.go +++ b/pkg/gateway/lb/usecase/vald_test.go @@ -17,6 +17,97 @@ // Package usecase represents gateways usecase layer package usecase +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/vdaas/vald/internal/client/v1/client/discoverer" + iconfig "github.com/vdaas/vald/internal/config" + "github.com/vdaas/vald/internal/sync/errgroup" + "github.com/vdaas/vald/pkg/gateway/lb/config" + + "github.com/vdaas/vald/internal/net/grpc" +) + +func Test_discovererClient(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *config.Data + dopts []grpc.Option + aopts []grpc.Option + assert func(*testing.T, discoverer.Client, error) + }{ + { + name: "Not create read replica client when read replica client option is not set", + cfg: &config.Data{ + Gateway: &iconfig.LB{ + AgentName: "agent", + AgentNamespace: "agent-ns", + AgentPort: 8081, + AgentDNS: "agent-dns", + Discoverer: &iconfig.DiscovererClient{ + Duration: "1m", + }, + NodeName: "node", + }, + }, + dopts: []grpc.Option{}, + aopts: []grpc.Option{}, + assert: func(t *testing.T, client discoverer.Client, err error) { + require.NoError(t, err) + + // check multiple times to ensure that the client is not a read replica client + require.Equal(t, client.GetClient(), client.GetReadClient()) + require.Equal(t, client.GetClient(), client.GetReadClient()) + require.Equal(t, client.GetClient(), client.GetReadClient()) + }, + }, + { + name: "create read replica client when read replica client option is set", + cfg: &config.Data{ + Gateway: &iconfig.LB{ + AgentName: "agent", + AgentNamespace: "agent-ns", + AgentPort: 8081, + AgentDNS: "agent-dns", + Discoverer: &iconfig.DiscovererClient{ + Duration: "1m", + }, + NodeName: "node", + ReadReplicaClient: iconfig.ReadReplicaClient{ + Client: &iconfig.GRPCClient{}, + }, + // set this to big enough value to ensure that the round robin counter won't reset to 0 + ReadReplicaReplicas: 100, + }, + }, + dopts: []grpc.Option{}, + aopts: []grpc.Option{}, + assert: func(t *testing.T, client discoverer.Client, err error) { + require.NoError(t, err) + + // ensure that GetReadClient() returns a read replica client by calling it multiple times beforehand + // and increments the round robin counter + client.GetReadClient() + client.GetReadClient() + client.GetReadClient() + + require.NotEqual(t, client.GetClient(), client.GetReadClient()) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + client, err := discovererClient(tt.cfg, tt.dopts, tt.aopts, errgroup.Get()) + tt.assert(t, client, err) + }) + } +} + // NOT IMPLEMENTED BELOW // // func TestNew(t *testing.T) { diff --git a/pkg/index/job/correction/service/corrector_test.go b/pkg/index/job/correction/service/corrector_test.go index 43ad7517f1..2329cbc2e8 100644 --- a/pkg/index/job/correction/service/corrector_test.go +++ b/pkg/index/job/correction/service/corrector_test.go @@ -41,6 +41,10 @@ func (m *mockDiscovererClient) GetClient() grpc.Client { return &m.client } +func (m *mockDiscovererClient) GetReadClient() grpc.Client { + return &m.client +} + func Test_correct_correctTimestamp(t *testing.T) { t.Parallel() diff --git a/pkg/index/job/readreplica/rotate/service/rotator.go b/pkg/index/job/readreplica/rotate/service/rotator.go index 9a845f4c28..b5bd90b67e 100644 --- a/pkg/index/job/readreplica/rotate/service/rotator.go +++ b/pkg/index/job/readreplica/rotate/service/rotator.go @@ -241,6 +241,7 @@ func (r *rotator) createPVC(ctx context.Context, newSnapShot string, deployment Kind: cur.Spec.DataSource.Kind, APIGroup: cur.Spec.DataSource.APIGroup, }, + StorageClassName: cur.Spec.StorageClassName, }, }