diff --git a/clientconn.go b/clientconn.go index 9209eda8884f..c41e7825c680 100644 --- a/clientconn.go +++ b/clientconn.go @@ -1210,10 +1210,16 @@ func (ac *addrConn) createTransport(addr resolver.Address, copts transport.Conne onCloseCalled := make(chan struct{}) reconnect := grpcsync.NewEvent() + authority := ac.cc.authority + // addr.ServerName takes precedent over ClientConn authority, if present. + if addr.ServerName != "" { + authority = addr.ServerName + } + target := transport.TargetInfo{ Addr: addr.Addr, Metadata: addr.Metadata, - Authority: ac.cc.authority, + Authority: authority, } once := sync.Once{} diff --git a/resolver/resolver.go b/resolver/resolver.go index afc8e6a55c0f..147e2314c334 100644 --- a/resolver/resolver.go +++ b/resolver/resolver.go @@ -86,9 +86,16 @@ type Address struct { // Type is the type of this address. Type AddressType // ServerName is the name of this address. + // If non-empty, the ServerName is used as the transport certification authority for + // the address, instead of the hostname from the Dial target string. In most cases, + // this should not be set. // - // e.g. if Type is GRPCLB, ServerName should be the name of the remote load + // If Type is GRPCLB, ServerName should be the name of the remote load // balancer, not the name of the backend. + // + // WARNING: ServerName must only be populated with trusted values. It + // is insecure to populate it with data from untrusted inputs since untrusted + // values could be used to bypass the authority checks performed by TLS. ServerName string // Metadata is the information associated with Addr, which may be used // to make load balancing decision. diff --git a/test/end2end_test.go b/test/end2end_test.go index 22fa389215bb..03a4f5c4f4b8 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -4969,6 +4969,49 @@ func (s) TestCredsHandshakeAuthority(t *testing.T) { } } +// This test makes sure that the authority client handshake gets is the endpoint +// of the ServerName of the address when it is set. +func (s) TestCredsHandshakeServerNameAuthority(t *testing.T) { + const testAuthority = "test.auth.ori.ty" + const testServerName = "test.server.name" + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + cred := &authorityCheckCreds{} + s := grpc.NewServer() + go s.Serve(lis) + defer s.Stop() + + r, rcleanup := manual.GenerateAndRegisterManualResolver() + defer rcleanup() + + cc, err := grpc.Dial(r.Scheme()+":///"+testAuthority, grpc.WithTransportCredentials(cred)) + if err != nil { + t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err) + } + defer cc.Close() + r.UpdateState(resolver.State{Addresses: []resolver.Address{{Addr: lis.Addr().String(), ServerName: testServerName}}}) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + for { + s := cc.GetState() + if s == connectivity.Ready { + break + } + if !cc.WaitForStateChange(ctx, s) { + // ctx got timeout or canceled. + t.Fatalf("ClientConn is not ready after 100 ms") + } + } + + if cred.got != testServerName { + t.Fatalf("client creds got authority: %q, want: %q", cred.got, testAuthority) + } +} + type clientFailCreds struct{} func (c *clientFailCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {