From 1a048517915f2e389cab06af5cf839a296d22a2d Mon Sep 17 00:00:00 2001
From: Jay
Date: Mon, 7 Feb 2022 11:12:16 -0500
Subject: [PATCH] ccl/sqlproxyccl: add basic forwarder component

Informs #76000.

This commit refactors the ConnectionCopy call in proxy_handler.go into a new
forwarder component, which was described in the connection migration RFC. At
the moment, this forwarder component does basic forwarding through
ConnectionCopy, just like before, so there should be no behavioral changes to
the proxy. This will serve as a building block for subsequent PRs.

Release note: None
---
 pkg/ccl/sqlproxyccl/BUILD.bazel       |   2 +
 pkg/ccl/sqlproxyccl/forwarder.go      | 158 ++++++++++++++++++
 pkg/ccl/sqlproxyccl/forwarder_test.go | 228 ++++++++++++++++++++++++++
 pkg/ccl/sqlproxyccl/proxy_handler.go  |  48 +++---
 4 files changed, 411 insertions(+), 25 deletions(-)
 create mode 100644 pkg/ccl/sqlproxyccl/forwarder.go
 create mode 100644 pkg/ccl/sqlproxyccl/forwarder_test.go

diff --git a/pkg/ccl/sqlproxyccl/BUILD.bazel b/pkg/ccl/sqlproxyccl/BUILD.bazel
index 4bbbfda183d2..b73644954b23 100644
--- a/pkg/ccl/sqlproxyccl/BUILD.bazel
+++ b/pkg/ccl/sqlproxyccl/BUILD.bazel
@@ -7,6 +7,7 @@ go_library(
         "authentication.go",
         "backend_dialer.go",
         "error.go",
+        "forwarder.go",
         "frontend_admitter.go",
         "metrics.go",
         "proxy.go",
@@ -48,6 +49,7 @@ go_test(
     size = "small",
     srcs = [
         "authentication_test.go",
+        "forwarder_test.go",
         "frontend_admitter_test.go",
         "main_test.go",
         "proxy_handler_test.go",
diff --git a/pkg/ccl/sqlproxyccl/forwarder.go b/pkg/ccl/sqlproxyccl/forwarder.go
new file mode 100644
index 000000000000..3975f51a1cbd
--- /dev/null
+++ b/pkg/ccl/sqlproxyccl/forwarder.go
@@ -0,0 +1,158 @@
+// Copyright 2022 The Cockroach Authors.
+//
+// Licensed as a CockroachDB Enterprise file under the Cockroach Community
+// License (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt
+
+package sqlproxyccl
+
+import (
+	"context"
+	"net"
+
+	"github.com/cockroachdb/cockroach/pkg/util/syncutil"
+	"github.com/cockroachdb/errors"
+)
+
+// ErrForwarderClosed indicates that the forwarder has been closed.
+var ErrForwarderClosed = errors.New("forwarder has been closed")
+
+// ErrForwarderStarted indicates that the forwarder has already started.
+var ErrForwarderStarted = errors.New("forwarder has already started")
+
+// forwarder is used to forward pgwire messages from the client to the server,
+// and vice-versa. At the moment, this does direct proxying, and there is
+// no intercepting. Once https://github.com/cockroachdb/cockroach/issues/76000
+// has been addressed, we will start intercepting pgwire messages at their
+// boundaries here.
+type forwarder struct {
+	// ctx is a single context used to control all goroutines spawned by the
+	// forwarder.
+	ctx       context.Context
+	ctxCancel context.CancelFunc
+
+	// crdbConn is only set after the authentication phase for the initial
+	// connection. In the context of a connection migration, crdbConn is only
+	// replaced once the session has successfully been deserialized, and the
+	// old connection will be closed.
+	conn     net.Conn // client <-> proxy
+	crdbConn net.Conn // proxy <-> server
+
+	// errChan is a buffered channel that contains the first forwarder error.
+	// This channel may receive nil errors.
+	errChan chan error
+
+	mu struct {
+		syncutil.Mutex
+
+		// closed is set to true whenever the forwarder is closed explicitly
+		// through Close, or when any of its main goroutines has terminated,
+		// whichever happens first.
+		//
+		// A new forwarder instance will have to be recreated if one wants to
+		// reuse the same pair of connections.
+		closed bool
+
+		// started is set to true once Run has been invoked on the forwarder.
+		// This is a safeguard to prevent callers from starting the forwarding
+		// process twice. This will never be set back to false.
+		started bool
+	}
+}
+
+// newForwarder returns a new instance of forwarder. When this is called, it
+// is expected that the caller passes ownership of conn and crdbConn to the
+// forwarder, which implies that these connections may be closed by the
+// forwarder if necessary, but this is not guaranteed. The caller should still
+// attempt to clean up the connections.
+func newForwarder(ctx context.Context, conn net.Conn, crdbConn net.Conn) *forwarder {
+	ctx, cancelFn := context.WithCancel(ctx)
+
+	return &forwarder{
+		ctx:       ctx,
+		ctxCancel: cancelFn,
+		conn:      conn,
+		crdbConn:  crdbConn,
+		errChan:   make(chan error, 1),
+	}
+}
+
+// Run starts the forwarding process for the associated connections, and can
+// only be called once throughout the lifetime of the forwarder instance.
+//
+// All goroutines spun up must check on f.ctx to prevent leaks, if possible. If
+// there was an error within the goroutines, the forwarder will be closed, and
+// the first error can be found in f.errChan.
+//
+// This returns ErrForwarderClosed if the forwarder has been closed, and
+// ErrForwarderStarted if the forwarder has already been started.
+func (f *forwarder) Run() error {
+	f.mu.Lock()
+	defer f.mu.Unlock()
+
+	if f.mu.closed {
+		return ErrForwarderClosed
+	}
+	if f.mu.started {
+		return ErrForwarderStarted
+	}
+
+	go func() {
+		defer f.Close()
+
+		// Block until context is done.
+		<-f.ctx.Done()
+
+		// Closing the connection will terminate the ConnectionCopy goroutine
+		// below. This is here temporarily because the only way to unblock
+		// io.Copy is to close one of the ends. Once we replace io.Copy
+		// with the interceptors, we could use ctx directly, and no longer
+		// need this goroutine.
+		f.crdbConn.Close()
+	}()
+
+	// Copy all pgwire messages from frontend to backend connection until we
+	// encounter an error or shutdown signal.
+	go func() {
+		defer f.Close()
+
+		err := ConnectionCopy(f.crdbConn, f.conn)
+		select {
+		case f.errChan <- err: /* error reported */
+		default: /* the channel already contains an error */
+		}
+	}()
+
+	f.mu.started = true
+	return nil
+}
+
+// Close closes the forwarder, and stops the forwarding process. This is
+// idempotent.
+func (f *forwarder) Close() {
+	f.mu.Lock()
+	defer f.mu.Unlock()
+
+	f.mu.closed = true
+	f.ctxCancel()
+}
+
+// IsClosed returns a boolean indicating whether the forwarder is closed.
+func (f *forwarder) IsClosed() bool {
+	f.mu.Lock()
+	defer f.mu.Unlock()
+
+	return f.mu.closed
+}
+
+// IsStarted returns a boolean indicating whether the forwarder has been
+// started during its lifetime. This says nothing about whether the forwarder
+// has since been closed; use IsClosed for that.
+func (f *forwarder) IsStarted() bool {
+	f.mu.Lock()
+	defer f.mu.Unlock()
+
+	return f.mu.started
+}
diff --git a/pkg/ccl/sqlproxyccl/forwarder_test.go b/pkg/ccl/sqlproxyccl/forwarder_test.go
new file mode 100644
index 000000000000..c6616e1442b3
--- /dev/null
+++ b/pkg/ccl/sqlproxyccl/forwarder_test.go
@@ -0,0 +1,228 @@
+// Copyright 2022 The Cockroach Authors.
+//
+// Licensed as a CockroachDB Enterprise file under the Cockroach Community
+// License (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt
+
+package sqlproxyccl
+
+import (
+	"context"
+	"net"
+	"testing"
+	"time"
+
+	"github.com/cockroachdb/cockroach/pkg/testutils"
+	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
+	"github.com/cockroachdb/errors"
+	"github.com/jackc/pgproto3/v2"
+	"github.com/stretchr/testify/require"
+)
+
+func TestForwarder_Run(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+
+	bgCtx := context.Background()
+
+	t.Run("early_errors", func(t *testing.T) {
+		f := newForwarder(bgCtx, nil /* conn */, nil /* crdbConn */)
+
+		// Forwarder started.
+		f.mu.started = true
+		require.EqualError(t, f.Run(), ErrForwarderStarted.Error())
+
+		// Forwarder closed.
+		f.mu.closed = true
+		require.EqualError(t, f.Run(), ErrForwarderClosed.Error())
+	})
+
+	t.Run("closed_when_processors_error", func(t *testing.T) {
+		p1, p2 := net.Pipe()
+		f := newForwarder(bgCtx, p1, p2)
+
+		// Close the connections right away.
+		p1.Close()
+		p2.Close()
+		require.NoError(t, f.Run())
+
+		// We have to wait for the goroutine to run.
+		testutils.SucceedsSoon(t, func() error {
+			if f.IsClosed() {
+				return nil
+			}
+			return errors.New("forwarder is not closed yet")
+		})
+
+		select {
+		case err := <-f.errChan:
+			require.Error(t, err)
+		default:
+			t.Fatalf("require error, but got none")
+		}
+	})
+
+	t.Run("client_to_server", func(t *testing.T) {
+		ctx, cancel := context.WithTimeout(bgCtx, 5*time.Second)
+		defer cancel()
+
+		clientW, clientR := net.Pipe()
+		serverW, serverR := net.Pipe()
+		// We don't close clientW and serverR here since we have no control
+		// over those.
+		defer clientR.Close()
+		defer serverW.Close()
+
+		f := newForwarder(ctx, clientR, serverW)
+		defer f.Close()
+		require.NoError(t, f.Run())
+		require.True(t, f.mu.started)
+		require.False(t, f.mu.closed)
+
+		// Client writes some pgwire messages.
+		errChan := make(chan error, 1)
+		go func() {
+			_, err := clientW.Write((&pgproto3.Query{
+				String: "SELECT 1",
+			}).Encode(nil))
+			if err != nil {
+				errChan <- err
+				return
+			}
+
+			if _, err := clientW.Write((&pgproto3.Execute{
+				Portal:  "foobar",
+				MaxRows: 42,
+			}).Encode(nil)); err != nil {
+				errChan <- err
+				return
+			}
+
+			if _, err := clientW.Write((&pgproto3.Close{
+				ObjectType: 'P',
+			}).Encode(nil)); err != nil {
+				errChan <- err
+				return
+			}
+		}()
+
+		// Server should receive messages in order.
+		backend := pgproto3.NewBackend(pgproto3.NewChunkReader(serverR), serverR)
+
+		msg, err := backend.Receive()
+		require.NoError(t, err)
+		m1, ok := msg.(*pgproto3.Query)
+		require.True(t, ok)
+		require.Equal(t, "SELECT 1", m1.String)
+
+		msg, err = backend.Receive()
+		require.NoError(t, err)
+		m2, ok := msg.(*pgproto3.Execute)
+		require.True(t, ok)
+		require.Equal(t, "foobar", m2.Portal)
+		require.Equal(t, uint32(42), m2.MaxRows)
+
+		msg, err = backend.Receive()
+		require.NoError(t, err)
+		m3, ok := msg.(*pgproto3.Close)
+		require.True(t, ok)
+		require.Equal(t, byte('P'), m3.ObjectType)
+
+		select {
+		case err = <-errChan:
+			t.Fatalf("require no error, but got %v", err)
+		default:
+		}
+	})
+
+	t.Run("server_to_client", func(t *testing.T) {
+		ctx, cancel := context.WithTimeout(bgCtx, 5*time.Second)
+		defer cancel()
+
+		clientW, clientR := net.Pipe()
+		serverW, serverR := net.Pipe()
+		// We don't close clientW and serverR here since we have no control
+		// over those.
+		defer clientR.Close()
+		defer serverW.Close()
+
+		f := newForwarder(ctx, clientR, serverW)
+		defer f.Close()
+		require.NoError(t, f.Run())
+		require.True(t, f.mu.started)
+		require.False(t, f.mu.closed)
+
+		// Server writes some pgwire messages.
+		errChan := make(chan error, 1)
+		go func() {
+			if _, err := serverR.Write((&pgproto3.ErrorResponse{
+				Code:    "100",
+				Message: "foobarbaz",
+			}).Encode(nil)); err != nil {
+				errChan <- err
+				return
+			}
+
+			if _, err := serverR.Write((&pgproto3.ReadyForQuery{
+				TxStatus: 'I',
+			}).Encode(nil)); err != nil {
+				errChan <- err
+				return
+			}
+		}()
+
+		// Client should receive messages in order.
+		frontend := pgproto3.NewFrontend(pgproto3.NewChunkReader(clientW), clientW)
+
+		msg, err := frontend.Receive()
+		require.NoError(t, err)
+		m1, ok := msg.(*pgproto3.ErrorResponse)
+		require.True(t, ok)
+		require.Equal(t, "100", m1.Code)
+		require.Equal(t, "foobarbaz", m1.Message)
+
+		msg, err = frontend.Receive()
+		require.NoError(t, err)
+		m2, ok := msg.(*pgproto3.ReadyForQuery)
+		require.True(t, ok)
+		require.Equal(t, byte('I'), m2.TxStatus)
+
+		select {
+		case err = <-errChan:
+			t.Fatalf("require no error, but got %v", err)
+		default:
+		}
+	})
+}
+
+func TestForwarder_Close(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+
+	f := newForwarder(context.Background(), nil /* conn */, nil /* crdbConn */)
+	require.False(t, f.mu.closed)
+
+	f.Close()
+	require.True(t, f.mu.closed)
+	require.EqualError(t, f.ctx.Err(), context.Canceled.Error())
+}
+
+func TestForwarder_IsClosed(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+
+	f := newForwarder(context.Background(), nil /* conn */, nil /* crdbConn */)
+	require.False(t, f.mu.closed)
+	require.False(t, f.IsClosed())
+	f.mu.closed = true
+	require.True(t, f.IsClosed())
+}
+
+func TestForwarder_IsStarted(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+
+	f := newForwarder(context.Background(), nil /* conn */, nil /* crdbConn */)
+	require.False(t, f.mu.started)
+	require.False(t, f.IsStarted())
+	f.mu.started = true
+	require.True(t, f.IsStarted())
+}
diff --git a/pkg/ccl/sqlproxyccl/proxy_handler.go b/pkg/ccl/sqlproxyccl/proxy_handler.go
index 7cc8d82c8709..00489ccae8fb 100644
--- a/pkg/ccl/sqlproxyccl/proxy_handler.go
+++ b/pkg/ccl/sqlproxyccl/proxy_handler.go
@@ -392,6 +392,11 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn
 	}
 
 	// Monitor for idle connection, if requested.
+	//
+	// TODO(jaylim-crl): Wrap conn instead of crdbConn for idle connections.
+	// Once we have connection migration, this mechanism can be removed
+	// entirely since we will be migrating connections instead of closing them
+	// whenever a pod is draining. Also note that this isn't used in CC today.
 	if handler.idleMonitor != nil {
 		crdbConn = handler.idleMonitor.DetectIdle(crdbConn, func() {
 			err := newErrorf(codeIdleDisconnect, "idle connection closed")
@@ -420,42 +425,35 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn
 	handler.metrics.SuccessfulConnCount.Inc(1)
 
-	ctx, cancel := context.WithCancel(ctx)
-	defer cancel()
-
 	log.Infof(ctx, "new connection")
 	connBegin := timeutil.Now()
 	defer func() {
 		log.Infof(ctx, "closing after %.2fs", timeutil.Since(connBegin).Seconds())
 	}()
 
-	// Copy all pgwire messages from frontend to backend connection until we
-	// encounter an error or shutdown signal.
-	go func() {
-		err := ConnectionCopy(crdbConn, conn)
-		select {
-		case errConnection <- err: /* error reported */
-		default: /* the channel already contains an error */
-		}
-	}()
+	// Pass ownership of conn and crdbConn to the forwarder.
+	f := newForwarder(ctx, conn, crdbConn)
+	defer f.Close()
+
+	// Start forwarding messages.
+	if err := f.Run(); err != nil {
+		// This should not happen.
+		panic(errors.Wrap(err, "forwarder.Run failed"))
+	}
 
+	// Block until an error is received, or when the stopper starts quiescing,
+	// whichever happens first.
 	select {
-	case err := <-errConnection:
+	case err := <-f.errChan: // From forwarder.
+		handler.metrics.updateForError(err)
+		return err
+	case err := <-errConnection: // From denyListWatcher or idleMonitor.
 		handler.metrics.updateForError(err)
 		return err
-	case <-ctx.Done():
-		err := ctx.Err()
-		if err != nil {
-			// The client connection expired.
-			codeErr := newErrorf(
-				codeExpiredClientConnection, "expired client conn: %v", err,
-			)
-			handler.metrics.updateForError(codeErr)
-			return codeErr
-		}
-		return nil
 	case <-handler.stopper.ShouldQuiesce():
-		return nil
+		err := context.Canceled
+		handler.metrics.updateForError(err)
+		return err
 	}
 }
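
A minimal usage sketch, not part of the patch above: it assumes the snippet
lives in the sqlproxyccl package and exercises only the API introduced by this
change (newForwarder, Run, errChan, and Close). The forwardUntilError helper,
its name, and its error handling are illustrative only.

package sqlproxyccl

import (
	"context"
	"net"
)

// forwardUntilError is a hypothetical helper showing the intended forwarder
// lifecycle: construct it, call Run once, wait for the first error on errChan,
// and let the deferred Close tear everything down (Close is idempotent).
func forwardUntilError(ctx context.Context, clientConn, crdbConn net.Conn) error {
	// Ownership of both connections is passed to the forwarder here.
	f := newForwarder(ctx, clientConn, crdbConn)
	defer f.Close()

	// Run only fails if the forwarder was already started or closed.
	if err := f.Run(); err != nil {
		return err
	}

	// Block until one of the forwarder's goroutines reports the first error
	// (which may be nil), or until the caller's context is canceled.
	select {
	case err := <-f.errChan:
		return err
	case <-ctx.Done():
		return ctx.Err()
	}
}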