diff --git a/clientv3/balancer.go b/clientv3/balancer.go index 7e8d957f9cf..284a22297f1 100644 --- a/clientv3/balancer.go +++ b/clientv3/balancer.go @@ -42,6 +42,11 @@ type simpleBalancer struct { // upc closes when upEps transitions from empty to non-zero or the balancer closes. upc chan struct{} + // grpc issues TLS cert checks using the string passed into dial so + // that string must be the host. To recover the full scheme://host URL, + // have a map from hosts to the original endpoint. + host2ep map[string]string + // pinAddr is the currently pinned address; set to the empty string on // intialization and shutdown. pinAddr string @@ -62,6 +67,7 @@ func newSimpleBalancer(eps []string) *simpleBalancer { readyc: make(chan struct{}), upEps: make(map[string]struct{}), upc: make(chan struct{}), + host2ep: getHost2ep(eps), } return sb } @@ -74,6 +80,35 @@ func (b *simpleBalancer) ConnectNotify() <-chan struct{} { return b.upc } +func (b *simpleBalancer) getEndpoint(host string) string { + b.mu.Lock() + defer b.mu.Unlock() + return b.host2ep[host] +} + +func getHost2ep(eps []string) map[string]string { + hm := make(map[string]string, len(eps)) + for i := range eps { + _, host, _ := parseEndpoint(eps[i]) + hm[host] = eps[i] + } + return hm +} + +func (b *simpleBalancer) updateAddrs(eps []string) { + b.mu.Lock() + defer b.mu.Unlock() + + b.host2ep = getHost2ep(eps) + + addrs := make([]grpc.Address, 0, len(eps)) + for i := range eps { + addrs = append(addrs, grpc.Address{Addr: getHost(eps[i])}) + } + b.addrs = addrs + b.notifyCh <- addrs +} + func (b *simpleBalancer) Up(addr grpc.Address) func(error) { b.mu.Lock() defer b.mu.Unlock() diff --git a/clientv3/client.go b/clientv3/client.go index 1ca94eb6c9e..d36105ba867 100644 --- a/clientv3/client.go +++ b/clientv3/client.go @@ -99,6 +99,12 @@ func (c *Client) Ctx() context.Context { return c.ctx } // Endpoints lists the registered endpoints for the client. func (c *Client) Endpoints() []string { return c.cfg.Endpoints } +// SetEndpoints updates client's endpoints. +func (c *Client) SetEndpoints(eps ...string) { + c.cfg.Endpoints = eps + c.balancer.updateAddrs(eps) +} + type authTokenCredential struct { token string } @@ -113,19 +119,31 @@ func (cred authTokenCredential) GetRequestMetadata(ctx context.Context, s ...str }, nil } -func (c *Client) dialTarget(endpoint string) (proto string, host string, creds *credentials.TransportCredentials) { +func parseEndpoint(endpoint string) (proto string, host string, scheme bool) { proto = "tcp" host = endpoint - creds = c.creds url, uerr := url.Parse(endpoint) if uerr != nil || !strings.Contains(endpoint, "://") { return } + scheme = true + // strip scheme:// prefix since grpc dials by host host = url.Host switch url.Scheme { + case "http", "https": case "unix": proto = "unix" + default: + proto, host = "", "" + } + return +} + +func (c *Client) processCreds(protocol string) (creds *credentials.TransportCredentials) { + creds = c.creds + switch protocol { + case "unix": case "http": creds = nil case "https": @@ -136,7 +154,7 @@ func (c *Client) dialTarget(endpoint string) (proto string, host string, creds * emptyCreds := credentials.NewTLS(tlsconfig) creds = &emptyCreds default: - return "", "", nil + creds = nil } return } @@ -148,17 +166,8 @@ func (c *Client) dialSetupOpts(endpoint string, dopts ...grpc.DialOption) (opts } opts = append(opts, dopts...) - // grpc issues TLS cert checks using the string passed into dial so - // that string must be the host. To recover the full scheme://host URL, - // have a map from hosts to the original endpoint. - host2ep := make(map[string]string) - for i := range c.cfg.Endpoints { - _, host, _ := c.dialTarget(c.cfg.Endpoints[i]) - host2ep[host] = c.cfg.Endpoints[i] - } - f := func(host string, t time.Duration) (net.Conn, error) { - proto, host, _ := c.dialTarget(host2ep[host]) + proto, host, _ := parseEndpoint(c.balancer.getEndpoint(host)) if proto == "" { return nil, fmt.Errorf("unknown scheme for %q", host) } @@ -171,7 +180,10 @@ func (c *Client) dialSetupOpts(endpoint string, dopts ...grpc.DialOption) (opts } opts = append(opts, grpc.WithDialer(f)) - _, _, creds := c.dialTarget(endpoint) + creds := c.creds + if proto, _, scheme := parseEndpoint(endpoint); scheme { + creds = c.processCreds(proto) + } if creds != nil { opts = append(opts, grpc.WithTransportCredentials(*creds)) } else { diff --git a/clientv3/integration/dial_test.go b/clientv3/integration/dial_test.go new file mode 100644 index 00000000000..9d9e6b47f27 --- /dev/null +++ b/clientv3/integration/dial_test.go @@ -0,0 +1,60 @@ +// Copyright 2016 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration + +import ( + "math/rand" + "testing" + "time" + + "github.com/coreos/etcd/clientv3" + "github.com/coreos/etcd/integration" + "github.com/coreos/etcd/pkg/testutil" + "golang.org/x/net/context" +) + +// TestDialSetEndpoints ensures SetEndpoints can replace unavailable endpoints with available ones. +func TestDialSetEndpoints(t *testing.T) { + defer testutil.AfterTest(t) + clus := integration.NewClusterV3(t, &integration.ClusterConfig{Size: 3}) + defer clus.Terminate(t) + + // get endpoint list + eps := make([]string, 3) + for i := range eps { + eps[i] = clus.Members[i].GRPCAddr() + } + toKill := rand.Intn(len(eps)) + + cfg := clientv3.Config{Endpoints: []string{eps[toKill]}, DialTimeout: 1 * time.Second} + cli, err := clientv3.New(cfg) + if err != nil { + t.Fatal(err) + } + defer cli.Close() + + // make a dead node + clus.Members[toKill].Stop(t) + clus.WaitLeader(t) + + // update client with available endpoints + cli.SetEndpoints(eps[(toKill+1)%3]) + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + if _, err = cli.Get(ctx, "foo", clientv3.WithSerializable()); err != nil { + t.Fatal(err) + } + cancel() +}