Skip to content

Commit

Permalink
client: synchronously verify server preface in newClientTransport (#5731
Browse files Browse the repository at this point in the history
)
  • Loading branch information
dfawley authored Oct 20, 2022
1 parent f51d212 commit 9127159
Show file tree
Hide file tree
Showing 7 changed files with 266 additions and 223 deletions.
119 changes: 41 additions & 78 deletions clientconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1228,38 +1228,33 @@ func (ac *addrConn) tryAllAddrs(addrs []resolver.Address, connectDeadline time.T
// address was not successfully connected, or updates ac appropriately with the
// new transport.
func (ac *addrConn) createTransport(addr resolver.Address, copts transport.ConnectOptions, connectDeadline time.Time) error {
// TODO: Delete prefaceReceived and move the logic to wait for it into the
// transport.
prefaceReceived := grpcsync.NewEvent()
connClosed := grpcsync.NewEvent()

addr.ServerName = ac.cc.getServerName(addr)
hctx, hcancel := context.WithCancel(ac.ctx)
hcStarted := false // protected by ac.mu

onClose := func() {
onClose := grpcsync.OnceFunc(func() {
ac.mu.Lock()
defer ac.mu.Unlock()
defer connClosed.Fire()
defer hcancel()
if !hcStarted || hctx.Err() != nil {
// We didn't start the health check or set the state to READY, so
// no need to do anything else here.
//
// OR, we have already cancelled the health check context, meaning
// we have already called onClose once for this transport. In this
// case it would be dangerous to clear the transport and update the
// state, since there may be a new transport in this addrConn.
if ac.state == connectivity.Shutdown {
// Already shut down. tearDown() already cleared the transport and
// canceled hctx via ac.ctx, and we expected this connection to be
// closed, so do nothing here.
return
}
hcancel()
if ac.transport == nil {
// We're still connecting to this address, which could error. Do
// not update the connectivity state or resolve; these will happen
// at the end of the tryAllAddrs connection loop in the event of an
// error.
return
}
ac.transport = nil
// Refresh the name resolver
// Refresh the name resolver on any connection loss.
ac.cc.resolveNow(resolver.ResolveNowOptions{})
if ac.state != connectivity.Shutdown {
ac.updateConnectivityState(connectivity.Idle, nil)
}
}

// Always go idle and wait for the LB policy to initiate a new
// connection attempt.
ac.updateConnectivityState(connectivity.Idle, nil)
})
onGoAway := func(r transport.GoAwayReason) {
ac.mu.Lock()
ac.adjustParams(r)
Expand All @@ -1271,74 +1266,42 @@ func (ac *addrConn) createTransport(addr resolver.Address, copts transport.Conne
defer cancel()
copts.ChannelzParentID = ac.channelzID

newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, addr, copts, func() { prefaceReceived.Fire() }, onGoAway, onClose)
newTr, err := transport.NewClientTransport(connectCtx, ac.cc.ctx, addr, copts, onGoAway, onClose)
if err != nil {
// newTr is either nil, or closed.
hcancel()
channelz.Warningf(logger, ac.channelzID, "grpc: addrConn.createTransport failed to connect to %s. Err: %v", addr, err)
return err
}

select {
case <-connectCtx.Done():
// We didn't get the preface in time.
ac.mu.Lock()
defer ac.mu.Unlock()
if ac.state == connectivity.Shutdown {
// This can happen if the subConn was removed while in `Connecting`
// state. tearDown() would have set the state to `Shutdown`, but
// would not have closed the transport since ac.transport would not
// have been set at that point.
//
// We run this in a goroutine because newTr.Close() calls onClose()
// inline, which requires locking ac.mu.
//
// The error we pass to Close() is immaterial since there are no open
// streams at this point, so no trailers with error details will be sent
// out. We just need to pass a non-nil error.
newTr.Close(transport.ErrConnClosing)
if connectCtx.Err() == context.DeadlineExceeded {
err := errors.New("failed to receive server preface within timeout")
channelz.Warningf(logger, ac.channelzID, "grpc: addrConn.createTransport failed to connect to %s: %v", addr, err)
return err
}
go newTr.Close(transport.ErrConnClosing)
return nil
case <-prefaceReceived.Done():
// We got the preface - huzzah! things are good.
ac.mu.Lock()
defer ac.mu.Unlock()
if connClosed.HasFired() {
// onClose called first; go idle but do nothing else.
if ac.state != connectivity.Shutdown {
ac.updateConnectivityState(connectivity.Idle, nil)
}
return nil
}
if ac.state == connectivity.Shutdown {
// This can happen if the subConn was removed while in `Connecting`
// state. tearDown() would have set the state to `Shutdown`, but
// would not have closed the transport since ac.transport would not
// been set at that point.
//
// We run this in a goroutine because newTr.Close() calls onClose()
// inline, which requires locking ac.mu.
//
// The error we pass to Close() is immaterial since there are no open
// streams at this point, so no trailers with error details will be sent
// out. We just need to pass a non-nil error.
go newTr.Close(transport.ErrConnClosing)
return nil
}
ac.curAddr = addr
ac.transport = newTr
hcStarted = true
ac.startHealthCheck(hctx) // Will set state to READY if appropriate.
}
if hctx.Err() != nil {
// onClose was already called for this connection, but the connection
// was successfully established first. Consider it a success and set
// the new state to Idle.
ac.updateConnectivityState(connectivity.Idle, nil)
return nil
case <-connClosed.Done():
// The transport has already closed. If we received the preface, too,
// this is not an error and go idle.
select {
case <-prefaceReceived.Done():
ac.mu.Lock()
defer ac.mu.Unlock()

if ac.state != connectivity.Shutdown {
ac.updateConnectivityState(connectivity.Idle, nil)
}
return nil
default:
return errors.New("connection closed before server preface received")
}
}
ac.curAddr = addr
ac.transport = newTr
ac.startHealthCheck(hctx) // Will set state to READY if appropriate.
return nil
}

// startHealthCheck starts the health checking stream (RPC) to watch the health
Expand Down
32 changes: 32 additions & 0 deletions internal/grpcsync/oncefunc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
*
* Copyright 2022 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package grpcsync

import (
"sync"
)

// OnceFunc returns a function wrapping f which ensures f is only executed
// once even if the returned function is executed multiple times.
func OnceFunc(f func()) func() {
var once sync.Once
return func() {
once.Do(f)
}
}
53 changes: 53 additions & 0 deletions internal/grpcsync/oncefunc_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
*
* Copyright 2022 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package grpcsync

import (
"sync"
"sync/atomic"
"testing"
"time"
)

// TestOnceFunc tests that a OnceFunc is executed only once even with multiple
// simultaneous callers of it.
func (s) TestOnceFunc(t *testing.T) {
var v int32
inc := OnceFunc(func() { atomic.AddInt32(&v, 1) })

const numWorkers = 100
var wg sync.WaitGroup // Blocks until all workers have called inc.
wg.Add(numWorkers)

block := NewEvent() // Signal to worker goroutines to call inc

for i := 0; i < numWorkers; i++ {
go func() {
<-block.Done() // Wait for a signal.
inc() // Call the OnceFunc.
wg.Done()
}()
}
time.Sleep(time.Millisecond) // Allow goroutines to get to the block.
block.Fire() // Unblock them.
wg.Wait() // Wait for them to complete.
if v != 1 {
t.Fatalf("OnceFunc() called %v times; want 1", v)
}
}
Loading

0 comments on commit 9127159

Please sign in to comment.