diff --git a/credentials/credentials.go b/credentials/credentials.go index 8ea3d4a1dc28..dd01237000cb 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -30,9 +30,9 @@ import ( "fmt" "io/ioutil" "net" - "strings" "github.com/golang/protobuf/proto" + "google.golang.org/grpc/credentials/internal" ) @@ -166,11 +166,12 @@ func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawCon // use local cfg to avoid clobbering ServerName if using multiple endpoints cfg := cloneTLSConfig(c.config) if cfg.ServerName == "" { - colonPos := strings.LastIndex(authority, ":") - if colonPos == -1 { - colonPos = len(authority) + serverName, _, err := net.SplitHostPort(authority) + if err != nil { + // If the authority had no host port or if the authority cannot be parsed, use it as-is. + serverName = authority } - cfg.ServerName = authority[:colonPos] + cfg.ServerName = serverName } conn := tls.Client(rawConn, cfg) errChannel := make(chan error, 1) diff --git a/credentials/credentials_test.go b/credentials/credentials_test.go index b15458636f1a..cc3aba0e3c96 100644 --- a/credentials/credentials_test.go +++ b/credentials/credentials_test.go @@ -23,6 +23,7 @@ import ( "crypto/tls" "net" "reflect" + "strings" "testing" "google.golang.org/grpc/testdata" @@ -55,18 +56,40 @@ func TestTLSClone(t *testing.T) { type serverHandshake func(net.Conn) (AuthInfo, error) func TestClientHandshakeReturnsAuthInfo(t *testing.T) { - done := make(chan AuthInfo, 1) - lis := launchServer(t, tlsServerHandshake, done) - defer lis.Close() - lisAddr := lis.Addr().String() - clientAuthInfo := clientHandle(t, gRPCClientHandshake, lisAddr) - // wait until server sends serverAuthInfo or fails. - serverAuthInfo, ok := <-done - if !ok { - t.Fatalf("Error at server-side") + tcs := []struct { + name string + address string + }{ + { + name: "localhost", + address: "localhost:0", + }, + { + name: "ipv4", + address: "127.0.0.1:0", + }, + { + name: "ipv6", + address: "[::1]:0", + }, } - if !compare(clientAuthInfo, serverAuthInfo) { - t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, clientAuthInfo, serverAuthInfo) + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + done := make(chan AuthInfo, 1) + lis := launchServerOnListenAddress(t, tlsServerHandshake, done, tc.address) + defer lis.Close() + lisAddr := lis.Addr().String() + clientAuthInfo := clientHandle(t, gRPCClientHandshake, lisAddr) + // wait until server sends serverAuthInfo or fails. + serverAuthInfo, ok := <-done + if !ok { + t.Fatalf("Error at server-side") + } + if !compare(clientAuthInfo, serverAuthInfo) { + t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, clientAuthInfo, serverAuthInfo) + } + }) } } @@ -121,8 +144,15 @@ func compare(a1, a2 AuthInfo) bool { } func launchServer(t *testing.T, hs serverHandshake, done chan AuthInfo) net.Listener { - lis, err := net.Listen("tcp", "localhost:0") + return launchServerOnListenAddress(t, hs, done, "localhost:0") +} + +func launchServerOnListenAddress(t *testing.T, hs serverHandshake, done chan AuthInfo, address string) net.Listener { + lis, err := net.Listen("tcp", address) if err != nil { + if strings.Contains(err.Error(), "bind: cannot assign requested address") { + t.Skip("missing IPv6 support") + } t.Fatalf("Failed to listen: %v", err) } go serverHandle(t, hs, done, lis)