Skip to content

Commit

Permalink
Add option to disable dns resolve (vdaas#2634)
Browse files Browse the repository at this point in the history
* fix: add option to disable dns resolve

Signed-off-by: hlts2 <[email protected]>

* fix: add mock function

Signed-off-by: hlts2 <[email protected]>

* fix: unimplemented error

Signed-off-by: hlts2 <[email protected]>

* fix: change ForwardedContext method to private

Signed-off-by: hlts2 <[email protected]>

---------

Signed-off-by: hlts2 <[email protected]>
  • Loading branch information
hlts2 authored and takuyaymd committed Dec 2, 2024
1 parent 1298adb commit 1052167
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 30 deletions.
1 change: 1 addition & 0 deletions apis/grpc/v1/payload/payload.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions apis/grpc/v1/rpc/errdetails/error_details.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

54 changes: 34 additions & 20 deletions internal/net/grpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,30 +87,32 @@ type Client interface {
GetDialOption() []DialOption
GetCallOption() []CallOption
GetBackoff() backoff.Backoff
SetDisableResolveDNSAddr(addr string, disabled bool)
ConnectedAddrs() []string
Close(ctx context.Context) error
}

type gRPCClient struct {
addrs map[string]struct{}
poolSize uint64
clientCount uint64
conns sync.Map[string, pool.Conn]
hcDur time.Duration
prDur time.Duration
dialer net.Dialer
enablePoolRebalance bool
resolveDNS bool
dopts []DialOption
copts []CallOption
roccd string // reconnection old connection closing duration
eg errgroup.Group
bo backoff.Backoff
cb circuitbreaker.CircuitBreaker
gbo gbackoff.Config // grpc's original backoff configuration
mcd time.Duration // minimum connection timeout duration
group singleflight.Group[pool.Conn]
crl sync.Map[string, bool] // connection request list
addrs map[string]struct{}
poolSize uint64
clientCount uint64
conns sync.Map[string, pool.Conn]
hcDur time.Duration
prDur time.Duration
dialer net.Dialer
enablePoolRebalance bool
disableResolveDNSAddrs sync.Map[string, bool]
resolveDNS bool
dopts []DialOption
copts []CallOption
roccd string // reconnection old connection closing duration
eg errgroup.Group
bo backoff.Backoff
cb circuitbreaker.CircuitBreaker
gbo gbackoff.Config // grpc's original backoff configuration
mcd time.Duration // minimum connection timeout duration
group singleflight.Group[pool.Conn]
crl sync.Map[string, bool] // connection request list

ech <-chan error
monitorRunning atomic.Bool
Expand Down Expand Up @@ -946,6 +948,12 @@ func (g *gRPCClient) GetBackoff() backoff.Backoff {
return g.bo
}

func (g *gRPCClient) SetDisableResolveDNSAddr(addr string, disabled bool) {
// NOTE: When connecting to multiple locations, it was necessary to switch dynamically, so implementation was added.
// There is no setting for disable on the helm chart side, so I used this implementation.
g.disableResolveDNSAddrs.Store(addr, disabled)
}

func (g *gRPCClient) Connect(
ctx context.Context, addr string, dopts ...DialOption,
) (conn pool.Conn, err error) {
Expand Down Expand Up @@ -975,7 +983,13 @@ func (g *gRPCClient) Connect(
pool.WithAddr(addr),
pool.WithSize(g.poolSize),
pool.WithDialOptions(append(g.dopts, dopts...)...),
pool.WithResolveDNS(g.resolveDNS),
pool.WithResolveDNS(func() bool {
disabled, ok := g.disableResolveDNSAddrs.Load(addr)
if ok && disabled {
return false
}
return g.resolveDNS
}()),
}
if g.bo != nil {
opts = append(opts, pool.WithBackoff(g.bo))
Expand Down
14 changes: 10 additions & 4 deletions internal/test/mock/grpc/grpc_client_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@ type GRPCClientMock struct {
addr string,
conn *grpc.ClientConn,
copts ...grpc.CallOption) error) error
ConnectFunc func(ctx context.Context, addr string, dopts ...grpc.DialOption) (pool.Conn, error)
DisconnectFunc func(ctx context.Context, addr string) error
IsConnectedFunc func(ctx context.Context, addr string) bool
ConnectedAddrsFunc func() []string
ConnectFunc func(ctx context.Context, addr string, dopts ...grpc.DialOption) (pool.Conn, error)
DisconnectFunc func(ctx context.Context, addr string) error
IsConnectedFunc func(ctx context.Context, addr string) bool
ConnectedAddrsFunc func() []string
SetDisableResolveDNSAddrFunc func(addr string, disabled bool)
}

// OrderedRangeConcurrent calls the OrderedRangeConcurrentFunc object.
Expand Down Expand Up @@ -70,3 +71,8 @@ func (gc *GRPCClientMock) Disconnect(ctx context.Context, addr string) error {
func (gc *GRPCClientMock) IsConnected(ctx context.Context, addr string) bool {
return gc.IsConnectedFunc(ctx, addr)
}

// SetDisableResolveDNSAddr calls the SetDisableResolveDNSAddr object.
func (gc *GRPCClientMock) SetDisableResolveDNSAddr(addr string, disabled bool) {
gc.SetDisableResolveDNSAddrFunc(addr, disabled)
}
2 changes: 2 additions & 0 deletions internal/test/mock/grpc_testify_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,5 @@ func (c *ClientInternal) Close(ctx context.Context) error {
args := c.Called(ctx)
return args.Error(0)
}

func (c *ClientInternal) SetDisableResolveDNSAddr(addr string, distributed bool) {}
22 changes: 16 additions & 6 deletions pkg/gateway/mirror/service/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package service
import (
"context"
"reflect"
"strings"

"github.com/vdaas/vald/internal/client/v1/client/mirror"
"github.com/vdaas/vald/internal/errors"
Expand All @@ -32,7 +33,6 @@ const (

// Gateway represents an interface for interacting with gRPC clients.
type Gateway interface {
ForwardedContext(ctx context.Context, podName string) context.Context
FromForwardedContext(ctx context.Context) string
BroadCast(ctx context.Context,
f func(ctx context.Context, target string, vc MirrorClient, copts ...grpc.CallOption) error) error
Expand Down Expand Up @@ -73,9 +73,9 @@ func (g *gateway) GRPCClient() grpc.Client {
return g.client.GRPCClient()
}

// ForwardedContext takes a context and a podName, returning a new context
// forwardedContext takes a context and a podName, returning a new context
// with additional information related to forwarding.
func (*gateway) ForwardedContext(ctx context.Context, podName string) context.Context {
func (*gateway) forwardedContext(ctx context.Context, podName string) context.Context {
return grpc.NewOutgoingContext(ctx, grpc.MD{
forwardedContextKey: []string{
podName,
Expand Down Expand Up @@ -113,7 +113,7 @@ func (g *gateway) BroadCast(
span.End()
}
}()
return g.client.GRPCClient().RangeConcurrent(g.ForwardedContext(ctx, g.podName), -1, func(ictx context.Context,
return g.client.GRPCClient().RangeConcurrent(g.forwardedContext(ctx, g.podName), -1, func(ictx context.Context,
addr string, conn *grpc.ClientConn, copts ...grpc.CallOption,
) (err error) {
select {
Expand Down Expand Up @@ -143,11 +143,21 @@ func (g *gateway) Do(
if target == "" {
return nil, errors.ErrTargetNotFound
}
return g.client.GRPCClient().Do(g.ForwardedContext(ctx, g.podName), target,
fctx := g.forwardedContext(ctx, g.podName)
res, err = g.client.GRPCClient().Do(fctx, target,
func(ictx context.Context, conn *grpc.ClientConn, copts ...grpc.CallOption) (any, error) {
return f(ictx, target, NewMirrorClient(conn), copts...)
},
)
if err != nil {
return g.client.GRPCClient().RoundRobin(fctx, func(ictx context.Context, conn *grpc.ClientConn, copts ...grpc.CallOption) (any, error) {
if strings.EqualFold(conn.Target(), target) {
return nil, errors.ErrTargetNotFound
}
return f(ictx, conn.Target(), NewMirrorClient(conn), copts...)
})
}
return res, nil
}

// DoMulti performs a gRPC operation on multiple targets using the provided function.
Expand All @@ -168,7 +178,7 @@ func (g *gateway) DoMulti(
if len(targets) == 0 {
return errors.ErrTargetNotFound
}
return g.client.GRPCClient().OrderedRangeConcurrent(g.ForwardedContext(ctx, g.podName), targets, -1,
return g.client.GRPCClient().OrderedRangeConcurrent(g.forwardedContext(ctx, g.podName), targets, -1,
func(ictx context.Context, addr string, conn *grpc.ClientConn, copts ...grpc.CallOption) (err error) {
select {
case <-ictx.Done():
Expand Down
1 change: 1 addition & 0 deletions pkg/gateway/mirror/service/mirror.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,7 @@ func (m *mirr) Connect(ctx context.Context, targets ...*payload.Mirror_Target) e
if !m.isSelfMirrorAddr(addr) && !m.isGatewayAddr(addr) {
_, ok := m.addrs.Load(addr)
if !ok || !m.IsConnected(ctx, addr) {
m.gateway.GRPCClient().SetDisableResolveDNSAddr(addr, true)
_, err := m.gateway.GRPCClient().Connect(ctx, addr)
if err != nil {
m.addrs.Delete(addr)
Expand Down
6 changes: 6 additions & 0 deletions pkg/gateway/mirror/service/mirror_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ func Test_mirr_Connect(t *testing.T) {
ConnectFunc: func(_ context.Context, _ string, _ ...grpc.DialOption) (conn pool.Conn, err error) {
return conn, err
},
SetDisableResolveDNSAddrFunc: func(addr string, disabled bool) {},
}
},
},
Expand Down Expand Up @@ -118,6 +119,7 @@ func Test_mirr_Connect(t *testing.T) {
ConnectFunc: func(_ context.Context, _ string, _ ...grpc.DialOption) (pool.Conn, error) {
return nil, errors.New("missing port in address")
},
SetDisableResolveDNSAddrFunc: func(addr string, disabled bool) {},
}
},
},
Expand Down Expand Up @@ -221,6 +223,7 @@ func Test_mirr_Disconnect(t *testing.T) {
DisconnectFunc: func(_ context.Context, _ string) error {
return nil
},
SetDisableResolveDNSAddrFunc: func(addr string, disabled bool) {},
}
},
},
Expand Down Expand Up @@ -252,6 +255,7 @@ func Test_mirr_Disconnect(t *testing.T) {
DisconnectFunc: func(_ context.Context, _ string) error {
return errors.New("missing port in address")
},
SetDisableResolveDNSAddrFunc: func(addr string, disabled bool) {},
}
},
},
Expand Down Expand Up @@ -373,6 +377,7 @@ func Test_mirr_MirrorTargets(t *testing.T) {
IsConnectedFunc: func(_ context.Context, addr string) bool {
return connected[addr]
},
SetDisableResolveDNSAddrFunc: func(addr string, disabled bool) {},
}
},
},
Expand Down Expand Up @@ -498,6 +503,7 @@ func Test_mirr_connectedOtherMirrorAddrs(t *testing.T) {
IsConnectedFunc: func(_ context.Context, addr string) bool {
return connected[addr]
},
SetDisableResolveDNSAddrFunc: func(addr string, disabled bool) {},
}
},
},
Expand Down

0 comments on commit 1052167

Please sign in to comment.