Skip to content

Commit

Permalink
pgwire: remove readTimeoutConn in favor of a channel
Browse files Browse the repository at this point in the history
Rather than using a connection that polls for the context being done
every second, we now spin up an additional goroutine that blocks until
the connection context is done, or the drain signal was received.

Release note: None
  • Loading branch information
rafiss committed May 20, 2024
1 parent e60b492 commit a65fea7
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 141 deletions.
2 changes: 1 addition & 1 deletion pkg/sql/pgwire/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ func (c *conn) findAuthenticationMethod(
// If the client is using SSL, retrieve the TLS state to provide as
// input to the method.
if authOpt.connType == hba.ConnHostSSL {
tlsConn, ok := c.conn.(*readTimeoutConn).Conn.(*tls.Conn)
tlsConn, ok := c.conn.(*tls.Conn)
if !ok {
err = errors.AssertionFailedf("server reports hostssl conn without TLS state")
return
Expand Down
38 changes: 0 additions & 38 deletions pkg/sql/pgwire/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1473,41 +1473,3 @@ var testingStatusReportParams = map[string]string{
"client_encoding": "UTF8",
"standard_conforming_strings": "on",
}

// readTimeoutConn overloads net.Conn.Read by periodically calling
// checkExitConds() and aborting the read if an error is returned.
type readTimeoutConn struct {
net.Conn
// checkExitConds is called periodically by Read(). If it returns an error,
// the Read() returns that error. Future calls to Read() are allowed, in which
// case checkExitConds() will be called again.
checkExitConds func() error
}

func (c *readTimeoutConn) Read(b []byte) (int, error) {
// readTimeout is the amount of time ReadTimeoutConn should wait on a
// read before checking for exit conditions. The tradeoff is between the
// time it takes to react to session context cancellation and the overhead
// of waking up and checking for exit conditions.
const readTimeout = 1 * time.Second

// Remove the read deadline when returning from this function to avoid
// unexpected behavior.
defer func() { _ = c.SetReadDeadline(time.Time{}) }()
for {
if err := c.checkExitConds(); err != nil {
return 0, err
}
if err := c.SetReadDeadline(timeutil.Now().Add(readTimeout)); err != nil {
return 0, err
}
n, err := c.Conn.Read(b)
if err != nil {
// Continue if the error is due to timing out.
if ne := (net.Error)(nil); errors.As(err, &ne) && ne.Timeout() {
continue
}
}
return n, err
}
}
161 changes: 86 additions & 75 deletions pkg/sql/pgwire/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
package pgwire

import (
"bytes"
"context"
gosql "database/sql"
"database/sql/driver"
Expand Down Expand Up @@ -335,6 +334,92 @@ func TestPipelineMetric(t *testing.T) {
require.NoError(t, err)
}

// BenchmarkIdleConn monitors the cpu usage of a single connection when it is
// idle.
func BenchmarkIdleConn(b *testing.B) {
defer leaktest.AfterTest(b)()
defer log.Scope(b).Close(b)

// Start a pgwire "server". We use a fake server here since we don't want
// to measure any background resource usage done by a normal CRDB server.
addr := util.TestAddr
ln, err := net.Listen(addr.Network(), addr.String())
if err != nil {
b.Fatal(err)
}
serverAddr := ln.Addr()
log.Infof(context.Background(), "started listener on %s", serverAddr)

ctx, cancelConn := context.WithCancel(context.Background())
defer cancelConn()
connectGroup := ctxgroup.WithContext(ctx)

var pgxConns []*pgx5.Conn
b.ResetTimer()
for i := 0; i < b.N; i++ {
b.StopTimer()
connectGroup.GoCtx(func(ctx context.Context) error {
host, ports, err := net.SplitHostPort(serverAddr.String())
if err != nil {
return err
}
port, err := strconv.Atoi(ports)
if err != nil {
return err
}
pgxConn, err := pgx5.Connect(
ctx,
fmt.Sprintf("postgresql://%s@%s:%d/system?sslmode=disable", username.RootUser, host, port),
)
if err != nil {
return err
}
pgxConns = append(pgxConns, pgxConn)
return nil
})

server := newTestServer()
// Wait for the client to connect.
netConn, err := waitForClientConn(ln)
require.NoError(b, err)

// Run the conn's loop in the background - it will push commands to the
// buffer. serveImpl also performs the server handshake.
expectedTimeoutErr := errors.New("normal timeout")
serveCtx, stopServe := context.WithTimeoutCause(context.Background(), 15*time.Second, expectedTimeoutErr)
defer stopServe()
serveGroup := ctxgroup.WithContext(serveCtx)
serverSideConn := server.newConn(
serveCtx,
cancelConn,
netConn,
sql.SessionArgs{ConnResultsBufferSize: 16 << 10},
timeutil.Now(),
)

b.StartTimer()
serveGroup.GoCtx(func(ctx context.Context) error {
server.serveImpl(
ctx,
serverSideConn,
&mon.BoundAccount{}, /* reserved */
authOptions{testingSkipAuth: true, connType: hba.ConnHostAny},
clusterunique.ID{},
)
return nil
})
err = serveGroup.Wait()
b.StopTimer()
require.NoError(b, err)
require.ErrorIs(b, context.Cause(serveCtx), expectedTimeoutErr)
}

_ = connectGroup.Wait()
for _, c := range pgxConns {
_ = c.Close(ctx)
}
}

func TestConnMessageTooBig(t *testing.T) {
defer leaktest.AfterTest(t)()
defer log.Scope(t).Close(t)
Expand Down Expand Up @@ -1295,80 +1380,6 @@ func TestMaliciousInputs(t *testing.T) {
}
}

// TestReadTimeoutConn asserts that a readTimeoutConn performs reads normally
// and exits with an appropriate error when exit conditions are satisfied.
func TestReadTimeoutConnExits(t *testing.T) {
defer leaktest.AfterTest(t)()
defer log.Scope(t).Close(t)
// Cannot use net.Pipe because deadlines are not supported.
ln, err := net.Listen(util.TestAddr.Network(), util.TestAddr.String())
if err != nil {
t.Fatal(err)
}
log.Infof(context.Background(), "started listener on %s", ln.Addr())
defer func() {
if err := ln.Close(); err != nil {
t.Fatal(err)
}
}()

ctx, cancel := context.WithCancel(context.Background())
expectedRead := []byte("expectedRead")

// Start a goroutine that performs reads using a readTimeoutConn.
errChan := make(chan error)
go func() {
defer close(errChan)
errChan <- func() error {
c, err := ln.Accept()
if err != nil {
return err
}
defer c.Close()

readTimeoutConn := &readTimeoutConn{
Conn: c,
checkExitConds: func() error {
return ctx.Err()
},
}
// Assert that reads are performed normally.
readBytes := make([]byte, len(expectedRead))
if _, err := readTimeoutConn.Read(readBytes); err != nil {
return err
}
if !bytes.Equal(readBytes, expectedRead) {
return errors.Errorf("expected %v got %v", expectedRead, readBytes)
}

// The main goroutine will cancel the context, which should abort
// this read with an appropriate error.
_, err = readTimeoutConn.Read(make([]byte, 1))
return err
}()
}()

c, err := net.Dial(ln.Addr().Network(), ln.Addr().String())
if err != nil {
t.Fatal(err)
}
defer c.Close()

if _, err := c.Write(expectedRead); err != nil {
t.Fatal(err)
}

select {
case err := <-errChan:
t.Fatalf("goroutine unexpectedly returned: %v", err)
default:
}
cancel()
if err := <-errChan; !errors.Is(err, context.Canceled) {
t.Fatalf("unexpected error: %v", err)
}
}

func TestConnResultsBufferSize(t *testing.T) {
defer leaktest.AfterTest(t)()
defer log.Scope(t).Close(t)
Expand Down
Loading

0 comments on commit a65fea7

Please sign in to comment.