From a65fea782a5fc3803289b05827e4c8f12980c3c6 Mon Sep 17 00:00:00 2001 From: Rafi Shamim Date: Sat, 18 May 2024 03:51:19 -0400 Subject: [PATCH] pgwire: remove readTimeoutConn in favor of a channel Rather than using a connection that polls for the context being done every second, we now spin up an additional goroutine that blocks until the connection context is done, or the drain signal was received. Release note: None --- pkg/sql/pgwire/auth.go | 2 +- pkg/sql/pgwire/conn.go | 38 --------- pkg/sql/pgwire/conn_test.go | 161 +++++++++++++++++++----------------- pkg/sql/pgwire/server.go | 90 ++++++++++++++------ 4 files changed, 150 insertions(+), 141 deletions(-) diff --git a/pkg/sql/pgwire/auth.go b/pkg/sql/pgwire/auth.go index a6e7b454e32f..8a7bff74b4e3 100644 --- a/pkg/sql/pgwire/auth.go +++ b/pkg/sql/pgwire/auth.go @@ -312,7 +312,7 @@ func (c *conn) findAuthenticationMethod( // If the client is using SSL, retrieve the TLS state to provide as // input to the method. if authOpt.connType == hba.ConnHostSSL { - tlsConn, ok := c.conn.(*readTimeoutConn).Conn.(*tls.Conn) + tlsConn, ok := c.conn.(*tls.Conn) if !ok { err = errors.AssertionFailedf("server reports hostssl conn without TLS state") return diff --git a/pkg/sql/pgwire/conn.go b/pkg/sql/pgwire/conn.go index 6c3a90bfc735..0ea7f3d598cf 100644 --- a/pkg/sql/pgwire/conn.go +++ b/pkg/sql/pgwire/conn.go @@ -1473,41 +1473,3 @@ var testingStatusReportParams = map[string]string{ "client_encoding": "UTF8", "standard_conforming_strings": "on", } - -// readTimeoutConn overloads net.Conn.Read by periodically calling -// checkExitConds() and aborting the read if an error is returned. -type readTimeoutConn struct { - net.Conn - // checkExitConds is called periodically by Read(). If it returns an error, - // the Read() returns that error. Future calls to Read() are allowed, in which - // case checkExitConds() will be called again. - checkExitConds func() error -} - -func (c *readTimeoutConn) Read(b []byte) (int, error) { - // readTimeout is the amount of time ReadTimeoutConn should wait on a - // read before checking for exit conditions. The tradeoff is between the - // time it takes to react to session context cancellation and the overhead - // of waking up and checking for exit conditions. - const readTimeout = 1 * time.Second - - // Remove the read deadline when returning from this function to avoid - // unexpected behavior. - defer func() { _ = c.SetReadDeadline(time.Time{}) }() - for { - if err := c.checkExitConds(); err != nil { - return 0, err - } - if err := c.SetReadDeadline(timeutil.Now().Add(readTimeout)); err != nil { - return 0, err - } - n, err := c.Conn.Read(b) - if err != nil { - // Continue if the error is due to timing out. - if ne := (net.Error)(nil); errors.As(err, &ne) && ne.Timeout() { - continue - } - } - return n, err - } -} diff --git a/pkg/sql/pgwire/conn_test.go b/pkg/sql/pgwire/conn_test.go index 6896e57c8c72..ef18183791ef 100644 --- a/pkg/sql/pgwire/conn_test.go +++ b/pkg/sql/pgwire/conn_test.go @@ -11,7 +11,6 @@ package pgwire import ( - "bytes" "context" gosql "database/sql" "database/sql/driver" @@ -335,6 +334,92 @@ func TestPipelineMetric(t *testing.T) { require.NoError(t, err) } +// BenchmarkIdleConn monitors the cpu usage of a single connection when it is +// idle. +func BenchmarkIdleConn(b *testing.B) { + defer leaktest.AfterTest(b)() + defer log.Scope(b).Close(b) + + // Start a pgwire "server". We use a fake server here since we don't want + // to measure any background resource usage done by a normal CRDB server. + addr := util.TestAddr + ln, err := net.Listen(addr.Network(), addr.String()) + if err != nil { + b.Fatal(err) + } + serverAddr := ln.Addr() + log.Infof(context.Background(), "started listener on %s", serverAddr) + + ctx, cancelConn := context.WithCancel(context.Background()) + defer cancelConn() + connectGroup := ctxgroup.WithContext(ctx) + + var pgxConns []*pgx5.Conn + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + connectGroup.GoCtx(func(ctx context.Context) error { + host, ports, err := net.SplitHostPort(serverAddr.String()) + if err != nil { + return err + } + port, err := strconv.Atoi(ports) + if err != nil { + return err + } + pgxConn, err := pgx5.Connect( + ctx, + fmt.Sprintf("postgresql://%s@%s:%d/system?sslmode=disable", username.RootUser, host, port), + ) + if err != nil { + return err + } + pgxConns = append(pgxConns, pgxConn) + return nil + }) + + server := newTestServer() + // Wait for the client to connect. + netConn, err := waitForClientConn(ln) + require.NoError(b, err) + + // Run the conn's loop in the background - it will push commands to the + // buffer. serveImpl also performs the server handshake. + expectedTimeoutErr := errors.New("normal timeout") + serveCtx, stopServe := context.WithTimeoutCause(context.Background(), 15*time.Second, expectedTimeoutErr) + defer stopServe() + serveGroup := ctxgroup.WithContext(serveCtx) + serverSideConn := server.newConn( + serveCtx, + cancelConn, + netConn, + sql.SessionArgs{ConnResultsBufferSize: 16 << 10}, + timeutil.Now(), + ) + + b.StartTimer() + serveGroup.GoCtx(func(ctx context.Context) error { + server.serveImpl( + ctx, + serverSideConn, + &mon.BoundAccount{}, /* reserved */ + authOptions{testingSkipAuth: true, connType: hba.ConnHostAny}, + clusterunique.ID{}, + ) + return nil + }) + err = serveGroup.Wait() + b.StopTimer() + require.NoError(b, err) + require.ErrorIs(b, context.Cause(serveCtx), expectedTimeoutErr) + } + + _ = connectGroup.Wait() + for _, c := range pgxConns { + _ = c.Close(ctx) + } +} + func TestConnMessageTooBig(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) @@ -1295,80 +1380,6 @@ func TestMaliciousInputs(t *testing.T) { } } -// TestReadTimeoutConn asserts that a readTimeoutConn performs reads normally -// and exits with an appropriate error when exit conditions are satisfied. -func TestReadTimeoutConnExits(t *testing.T) { - defer leaktest.AfterTest(t)() - defer log.Scope(t).Close(t) - // Cannot use net.Pipe because deadlines are not supported. - ln, err := net.Listen(util.TestAddr.Network(), util.TestAddr.String()) - if err != nil { - t.Fatal(err) - } - log.Infof(context.Background(), "started listener on %s", ln.Addr()) - defer func() { - if err := ln.Close(); err != nil { - t.Fatal(err) - } - }() - - ctx, cancel := context.WithCancel(context.Background()) - expectedRead := []byte("expectedRead") - - // Start a goroutine that performs reads using a readTimeoutConn. - errChan := make(chan error) - go func() { - defer close(errChan) - errChan <- func() error { - c, err := ln.Accept() - if err != nil { - return err - } - defer c.Close() - - readTimeoutConn := &readTimeoutConn{ - Conn: c, - checkExitConds: func() error { - return ctx.Err() - }, - } - // Assert that reads are performed normally. - readBytes := make([]byte, len(expectedRead)) - if _, err := readTimeoutConn.Read(readBytes); err != nil { - return err - } - if !bytes.Equal(readBytes, expectedRead) { - return errors.Errorf("expected %v got %v", expectedRead, readBytes) - } - - // The main goroutine will cancel the context, which should abort - // this read with an appropriate error. - _, err = readTimeoutConn.Read(make([]byte, 1)) - return err - }() - }() - - c, err := net.Dial(ln.Addr().Network(), ln.Addr().String()) - if err != nil { - t.Fatal(err) - } - defer c.Close() - - if _, err := c.Write(expectedRead); err != nil { - t.Fatal(err) - } - - select { - case err := <-errChan: - t.Fatalf("goroutine unexpectedly returned: %v", err) - default: - } - cancel() - if err := <-errChan; !errors.Is(err, context.Canceled) { - t.Fatalf("unexpected error: %v", err) - } -} - func TestConnResultsBufferSize(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) diff --git a/pkg/sql/pgwire/server.go b/pkg/sql/pgwire/server.go index 325acadc8333..19d65d4a5c11 100644 --- a/pkg/sql/pgwire/server.go +++ b/pkg/sql/pgwire/server.go @@ -13,6 +13,7 @@ package pgwire import ( "bufio" "context" + "crypto/tls" "io" "net" "sync" @@ -239,6 +240,9 @@ type Server struct { // After the timeout set by server.shutdown.transactions.timeout, // all connections will be closed regardless any txns in flight. draining bool + // drainCh is closed when draining is set to true. If Undrain is called, + // this channel is reset. + drainCh chan struct{} // rejectNewConnections is set true when the server does not accept new // SQL connections, e.g. when the draining process enters the phase whose // duration is specified by the server.shutdown.connections.timeout. @@ -355,6 +359,7 @@ func MakeServer( server.mu.Lock() server.mu.connCancelMap = make(cancelChanMap) + server.mu.drainCh = make(chan struct{}) server.mu.Unlock() connAuthConf.SetOnChange(&st.SV, func(ctx context.Context) { @@ -449,6 +454,14 @@ func (s *Server) Undrain() { s.setDrainingLocked(false) } +// DrainCh returns a channel that can be watched to detect when the server +// enters the draining state. +func (s *Server) DrainCh() <-chan struct{} { + s.mu.Lock() + defer s.mu.Unlock() + return s.mu.drainCh +} + // setDrainingLocked sets the server's draining state and returns whether the // state changed (i.e. drain != s.mu.draining). s.mu must be locked when // setDrainingLocked is called. @@ -457,6 +470,11 @@ func (s *Server) setDrainingLocked(drain bool) bool { return false } s.mu.draining = drain + if drain { + close(s.mu.drainCh) + } else { + s.mu.drainCh = make(chan struct{}) + } return true } @@ -840,18 +858,14 @@ func (s *Server) newConn( sArgs sql.SessionArgs, connStart time.Time, ) *conn { - // The net.Conn is switched to a conn that exits if the ctx is canceled. - rtc := &readTimeoutConn{ - Conn: netConn, - } sv := &s.execCfg.Settings.SV c := &conn{ - conn: rtc, + conn: netConn, cancelConn: cancelConn, sessionArgs: sArgs, metrics: s.tenantMetrics, startTime: connStart, - rd: *bufio.NewReader(rtc), + rd: *bufio.NewReader(netConn), sv: sv, readBuf: pgwirebase.MakeReadBuffer(pgwirebase.ReadBufferOptionWithClusterSettings(sv)), alwaysLogAuthActivity: s.testingAuthLogEnabled.Get(), @@ -864,23 +878,6 @@ func (s *Server) newConn( c.msgBuilder.init(s.tenantMetrics.BytesOutCount) c.errWriter.sv = sv c.errWriter.msgBuilder = &c.msgBuilder - - var sentDrainSignal bool - rtc.checkExitConds = func() error { - // If the context was canceled, it's time to stop reading. Either a - // higher-level server or the command processor have canceled us. - if err := ctx.Err(); err != nil { - return err - } - // If the server is draining, we'll let the processor know by pushing a - // DrainRequest. This will make the processor quit whenever it finds a good - // time. - if !sentDrainSignal && s.IsDraining() { - _ /* err */ = c.stmtBuf.Push(ctx, sql.DrainRequest{}) - sentDrainSignal = true - } - return nil - } return c } @@ -897,7 +894,8 @@ const maxRepeatedErrorCount = 1 << 15 // canceled (which also happens when draining (but not from the get-go), and // when the processor encounters a fatal error). // -// serveImpl always closes the network connection before returning. +// serveImpl closes the stmtBuf before returning, which is a signal to the +// processor goroutine to exit. // // sqlServer is used to create the command processor. As a special facility for // tests, sqlServer can be nil, in which case the command processor and the @@ -929,15 +927,18 @@ const maxRepeatedErrorCount = 1 << 15 // // Draining notes: // -// The reader notices that the server is draining by polling the IsDraining -// closure passed to serveImpl. At that point, the reader delegates the +// The reader notices that the server is draining by watching the channel +// returned by Server.DrainCh. At that point, the reader delegates the // responsibility of closing the connection to the statement processor: it will // push a DrainRequest to the stmtBuf which signals the processor to quit ASAP. // The processor will quit immediately upon seeing that command if it's not // currently in a transaction. If it is in a transaction, it will wait until the // first time a Sync command is processed outside of a transaction - the logic // being that we want to stop when we're both outside transactions and outside -// batches. +// batches. If the processor does not process the DrainRequest quickly enough +// (based on the server.shutdown.transactions.timeout setting), the server +// will cancel the context, which will close the connection and make the +// reader goroutine exit. func (s *Server) serveImpl( ctx context.Context, c *conn, @@ -955,6 +956,41 @@ func (s *Server) serveImpl( sqlServer := s.SQLServer inTestWithoutSQL := sqlServer == nil + go func() { + select { + case <-ctx.Done(): + // If the context was canceled, it's time to stop reading. Either a + // higher-level server or the command processor have canceled us. + case <-s.DrainCh(): + // If the server is draining, we'll let the processor know by pushing a + // DrainRequest. This will make the processor quit whenever it finds a + // good time (i.e., outside of a transaction). The context will be + // cancelled by the server after all processors have quit or after the + // server.shutdown.transactions.timeout duration. + _ /* err */ = c.stmtBuf.Push(ctx, sql.DrainRequest{}) + <-ctx.Done() + } + // If possible, we try to only close the read side of the connection. This will cause the + // ReadTypedMsg call in the reader goroutine to return an error, which will + // cause the reader to exit and signal the processor to quit also, and still + // be able to write an error message to the client. If we're unable to only + // close the read side, we fallback to setting a read deadline that will + // make all reads timeout. + var tcpConn *net.TCPConn + switch c := c.conn.(type) { + case *net.TCPConn: + tcpConn = c + case *tls.Conn: + underConn := c.NetConn() + tcpConn, _ = underConn.(*net.TCPConn) + } + if tcpConn == nil { + _ = c.conn.SetReadDeadline(timeutil.Now()) + } else { + _ = tcpConn.CloseRead() + } + }() + // NOTE: We're going to write a few messages to the connection in this method, // for the handshake. After that, all writes are done async, in the // startWriter() goroutine.