From 6e8267931ed312aa8c60517ea2c917d45562a5be Mon Sep 17 00:00:00 2001
From: Jay
Date: Thu, 10 Feb 2022 20:47:10 -0500
Subject: [PATCH] ccl/sqlproxyccl: add connector component and support for
 session revival token

Informs #76000.

Previously, all the connection establishment logic was coupled with the
handler function within proxy_handler.go, which made connecting to a new
SQL pod during connection migration difficult. This commit refactors that
connection logic out of the proxy handler into a connector component, as
described in the connection migration RFC. At the same time, we also add
support for the session revival token within this connector component.
Note that the overall behavior of the SQL proxy is unchanged by this
commit.

Release note: None
---
 pkg/ccl/sqlproxyccl/BUILD.bazel       |   2 +
 pkg/ccl/sqlproxyccl/backend_dialer.go |   2 +
 pkg/ccl/sqlproxyccl/connector.go      | 267 ++++++++++++++
 pkg/ccl/sqlproxyccl/connector_test.go | 510 ++++++++++++++++++++++++++
 pkg/ccl/sqlproxyccl/proxy_handler.go  | 180 ++++-----
 5 files changed, 859 insertions(+), 102 deletions(-)
 create mode 100644 pkg/ccl/sqlproxyccl/connector.go
 create mode 100644 pkg/ccl/sqlproxyccl/connector_test.go

diff --git a/pkg/ccl/sqlproxyccl/BUILD.bazel b/pkg/ccl/sqlproxyccl/BUILD.bazel
index d546b57c2c83..a5c6f0a5aa4f 100644
--- a/pkg/ccl/sqlproxyccl/BUILD.bazel
+++ b/pkg/ccl/sqlproxyccl/BUILD.bazel
@@ -6,6 +6,7 @@ go_library(
     srcs = [
         "authentication.go",
         "backend_dialer.go",
+        "connector.go",
         "error.go",
         "frontend_admitter.go",
         "metrics.go",
@@ -48,6 +49,7 @@ go_test(
     size = "small",
     srcs = [
         "authentication_test.go",
+        "connector_test.go",
        "frontend_admitter_test.go",
        "main_test.go",
        "proxy_handler_test.go",
diff --git a/pkg/ccl/sqlproxyccl/backend_dialer.go b/pkg/ccl/sqlproxyccl/backend_dialer.go
index b0ce1c594c62..5d1f12d76c56 100644
--- a/pkg/ccl/sqlproxyccl/backend_dialer.go
+++ b/pkg/ccl/sqlproxyccl/backend_dialer.go
@@ -24,6 +24,8 @@ import (
 //
 // BackendDial uses a dial timeout of 5 seconds to mitigate network black
 // holes.
+//
+// TODO(jaylim-crl): Move dialer into connector in the future.
 var BackendDial = func(
 	msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config,
 ) (net.Conn, error) {
diff --git a/pkg/ccl/sqlproxyccl/connector.go b/pkg/ccl/sqlproxyccl/connector.go
new file mode 100644
index 000000000000..935c4a971c1c
--- /dev/null
+++ b/pkg/ccl/sqlproxyccl/connector.go
@@ -0,0 +1,267 @@
+// Copyright 2022 The Cockroach Authors.
+//
+// Licensed as a CockroachDB Enterprise file under the Cockroach Community
+// License (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+//     https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt
+
+package sqlproxyccl
+
+import (
+	"context"
+	"crypto/tls"
+	"net"
+	"time"
+
+	"github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/throttler"
+	"github.com/cockroachdb/cockroach/pkg/util/netutil/addr"
+	"github.com/cockroachdb/cockroach/pkg/util/retry"
+	"github.com/cockroachdb/errors"
+	pgproto3 "github.com/jackc/pgproto3/v2"
+)
+
+// sessionRevivalTokenStartupParam indicates the name of the parameter that
+// will activate token-based authentication if present in the startup message.
+const sessionRevivalTokenStartupParam = "crdb:session_revival_token_base64"
+
+// errRetryConnectorSentinel exists to allow more robust detection of retry
+// errors even if they are wrapped.
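+//
+// As a minimal illustration of the marking behavior (not part of this
+// change; it assumes the usual semantics of errors.Mark and errors.Is from
+// the cockroachdb/errors library):
+//
+//	err := MarkAsRetriableConnectorError(errors.New("pod is not ready"))
+//	IsRetriableConnectorError(err)                      // true
+//	IsRetriableConnectorError(errors.Wrap(err, "dial")) // true: marks survive wrapping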
+var errRetryConnectorSentinel = errors.New("retry connector error")
+
+// MarkAsRetriableConnectorError marks the given error with
+// errRetryConnectorSentinel, which will trigger the connector to retry when
+// such an error is returned.
+func MarkAsRetriableConnectorError(err error) error {
+	return errors.Mark(err, errRetryConnectorSentinel)
+}
+
+// IsRetriableConnectorError checks whether a given error is retriable. This
+// should be called on errors that are transient so that the connector can
+// retry on such errors.
+func IsRetriableConnectorError(err error) bool {
+	return errors.Is(err, errRetryConnectorSentinel)
+}
+
+// connector is a per-session, tenant-associated component that can be used to
+// obtain a connection to the tenant cluster. It also handles the
+// authentication phase. All connections returned by the connector should
+// already be ready to accept regular pgwire messages (e.g. SQL queries).
+type connector struct {
+	// StartupMsg represents the startup message associated with the client.
+	// This will be used when establishing a pgwire connection with the SQL
+	// pod.
+	//
+	// NOTE: This field is required.
+	StartupMsg *pgproto3.StartupMessage
+
+	// TLSConfig represents the client TLS config used by the connector when
+	// connecting with the SQL pod. If the ServerName field is set, this will
+	// be overridden during connection establishment. Set to nil if we are
+	// connecting to an insecure cluster.
+	//
+	// NOTE: This field is optional.
+	TLSConfig *tls.Config
+
+	// AddrLookupFn is used by the connector to return an address (that must
+	// include both host and port) pointing to one of the SQL pods for the
+	// tenant associated with this connector.
+	//
+	// This will be called within an infinite backoff loop. If an error is
+	// transient, this should return an error that has been marked with
+	// errRetryConnectorSentinel (i.e. MarkAsRetriableConnectorError).
+	//
+	// NOTE: This field is required.
+	AddrLookupFn func(ctx context.Context) (string, error)
+
+	// AuthenticateFn is used by the connector to authenticate the client
+	// against the server. This will only be used in non-token-based auth
+	// methods. This should block until the server has authenticated the
+	// client.
+	//
+	// NOTE: This field is required.
+	AuthenticateFn func(
+		client net.Conn,
+		server net.Conn,
+		throttleHook func(throttler.AttemptStatus) error,
+	) error
+
+	// IdleMonitorWrapperFn is used to wrap the connection to the SQL pod with
+	// an idle monitor. If not specified, the raw connection to the SQL pod
+	// will be returned.
+	//
+	// When connecting with an authentication phase, the connection will be
+	// wrapped before starting the authentication.
+	//
+	// NOTE: This field is optional.
+	IdleMonitorWrapperFn func(crdbConn net.Conn) net.Conn
+
+	// Event callback functions. OnLookupEvent and OnDialEvent will be called
+	// after the lookup and dial operations respectively, regardless of error.
+	//
+	// NOTE: These fields are optional.
+	//
+	// TODO(jaylim-crl): Look into removing event callback functions. This
+	// requires us to pass in some sort of directory into the connector
+	// component. Perhaps AddrLookupFn could be replaced with that. We can't
+	// do that today because AddrLookupFn also relies on a fallback mechanism
+	// for --routing-rule.
+	OnLookupEvent func(ctx context.Context, err error)
+	OnDialEvent   func(ctx context.Context, outgoingAddr string, err error)
+
+	// Testing knobs for internal connector calls. If specified, these will
+	// be called instead of the actual logic.
+	testingKnobs struct {
+		openClusterConnInternal func(ctx context.Context) (net.Conn, error)
+		dialOutgoingAddr        func(outgoingAddr string) (net.Conn, error)
+	}
+}
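+
+// A minimal usage sketch (illustrative only; the values below are
+// hypothetical and not part of this change):
+//
+//	c := &connector{
+//		StartupMsg:     startupMsg,    // from the frontend admitter
+//		TLSConfig:      tlsConf,       // nil when connecting insecurely
+//		AddrLookupFn:   lookupPodAddr, // returns a "host:port" address
+//		AuthenticateFn: authenticate,
+//	}
+//	crdbConn, sentToClient, err := c.OpenClusterConnWithAuth(ctx, clientConn, throttleHook)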
+
+// OpenClusterConnWithToken opens a connection to the tenant cluster using
+// token-based authentication.
+func (c *connector) OpenClusterConnWithToken(ctx context.Context, token string) (net.Conn, error) {
+	c.StartupMsg.Parameters[sessionRevivalTokenStartupParam] = token
+	defer func() {
+		// Delete token after return.
+		delete(c.StartupMsg.Parameters, sessionRevivalTokenStartupParam)
+	}()
+
+	crdbConn, err := c.openClusterConnInternal(ctx)
+	if err != nil {
+		return nil, err
+	}
+	if c.IdleMonitorWrapperFn != nil {
+		crdbConn = c.IdleMonitorWrapperFn(crdbConn)
+	}
+	return crdbConn, nil
+}
+
+// OpenClusterConnWithAuth opens a connection to the tenant cluster using
+// normal authentication methods (e.g. password, etc.). Once a connection to
+// one of the tenant's SQL pods has been established, the request/response
+// flow between clientConn and the new connection is handed off to the
+// authenticator, which implies that this call blocks until authentication
+// succeeds, or an error is returned.
+//
+// sentToClient will be set to true if an error occurred during the
+// authenticator phase since errors would have already been sent to the client.
+func (c *connector) OpenClusterConnWithAuth(
+	ctx context.Context, clientConn net.Conn, throttleHook func(throttler.AttemptStatus) error,
+) (serverConn net.Conn, sentToClient bool, retErr error) {
+	// Just a safety check, but this shouldn't happen since we will block the
+	// startup param in the frontend admitter. The only case where we actually
+	// need to delete this param is if OpenClusterConnWithToken was called
+	// previously, but that wouldn't happen based on the current proxy logic.
+	delete(c.StartupMsg.Parameters, sessionRevivalTokenStartupParam)
+
+	crdbConn, err := c.openClusterConnInternal(ctx)
+	if err != nil {
+		return nil, false, err
+	}
+	defer func() {
+		if retErr != nil {
+			crdbConn.Close()
+		}
+	}()
+
+	if c.IdleMonitorWrapperFn != nil {
+		crdbConn = c.IdleMonitorWrapperFn(crdbConn)
+	}
+
+	// Perform user authentication.
+	if err := c.AuthenticateFn(clientConn, crdbConn, throttleHook); err != nil {
+		return nil, true, err
+	}
+	return crdbConn, false, nil
+}
+
+// openClusterConnInternal returns a connection to the tenant cluster
+// associated with the connector. Once a connection has been established, the
+// pgwire startup message will be relayed to the server. Returned errors may
+// be marked as a lookup or dial error.
+func (c *connector) openClusterConnInternal(ctx context.Context) (net.Conn, error) {
+	if c.testingKnobs.openClusterConnInternal != nil {
+		return c.testingKnobs.openClusterConnInternal(ctx)
+	}
+
+	// Repeatedly try to make a connection until the context is canceled, or
+	// until we get a non-retriable error. This is preferable to terminating
+	// client connections, because in most cases those connections will simply
+	// be retried, further increasing load on the system.
+	retryOpts := retry.Options{
+		InitialBackoff: 10 * time.Millisecond,
+		MaxBackoff:     5 * time.Second,
+	}
+
+	var crdbConn net.Conn
+	var outgoingAddr string
+	var err error
+	for r := retry.StartWithCtx(ctx, retryOpts); r.Next(); {
+		// Retrieve a SQL pod address to connect to.
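+		// (AddrLookupFn returns a "host:port" address; errors marked as
+		// retriable restart this loop from the lookup step.)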
+		outgoingAddr, err = c.AddrLookupFn(ctx)
+		if c.OnLookupEvent != nil {
+			c.OnLookupEvent(ctx, err)
+		}
+		if err != nil {
+			if IsRetriableConnectorError(err) {
+				continue
+			}
+			return nil, err
+		}
+		// Make a connection to the SQL pod.
+		crdbConn, err = c.dialOutgoingAddr(outgoingAddr)
+		if c.OnDialEvent != nil {
+			c.OnDialEvent(ctx, outgoingAddr, err)
+		}
+		if err != nil {
+			if IsRetriableConnectorError(err) {
+				continue
+			}
+			return nil, err
+		}
+		return crdbConn, nil
+	}
+
+	// Since the retry loop above retries indefinitely, the only way we can
+	// exit the loop is when the context is canceled.
+	if errors.Is(err, context.Canceled) {
+		return nil, err
+	}
+	// Loop exited at the boundary, so mark the previous error with the
+	// context's cancellation error.
+	if ctxErr := ctx.Err(); err != nil && ctxErr != nil {
+		return nil, errors.Mark(err, ctxErr)
+	}
+	panic("unreachable")
+}
+
+// dialOutgoingAddr dials the given outgoing address for the SQL pod, and
+// forwards the startup message to it. If the connector specifies a TLS
+// connection, it will also attempt to upgrade the PG connection to use TLS.
+func (c *connector) dialOutgoingAddr(outgoingAddr string) (net.Conn, error) {
+	if c.testingKnobs.dialOutgoingAddr != nil {
+		return c.testingKnobs.dialOutgoingAddr(outgoingAddr)
+	}
+
+	// Use a TLS config if one was provided. If TLSConfig is nil, Clone will
+	// return nil.
+	tlsConf := c.TLSConfig.Clone()
+	if tlsConf != nil {
+		// outgoingAddr will always have a port. We use an empty string as the
+		// default port as we only care about extracting the host.
+		outgoingHost, _, err := addr.SplitHostPort(outgoingAddr, "" /* defaultPort */)
+		if err != nil {
+			return nil, err
+		}
+		// Always set ServerName. If InsecureSkipVerify is true, this will
+		// be ignored.
+		tlsConf.ServerName = outgoingHost
+	}
+
+	conn, err := BackendDial(c.StartupMsg, outgoingAddr, tlsConf)
+	if err != nil {
+		var codeErr *codeError
+		if errors.As(err, &codeErr) && codeErr.code == codeBackendDown {
+			return nil, MarkAsRetriableConnectorError(err)
+		}
+		return nil, err
+	}
+	return conn, nil
+}
diff --git a/pkg/ccl/sqlproxyccl/connector_test.go b/pkg/ccl/sqlproxyccl/connector_test.go
new file mode 100644
index 000000000000..1432db391577
--- /dev/null
+++ b/pkg/ccl/sqlproxyccl/connector_test.go
@@ -0,0 +1,510 @@
+// Copyright 2022 The Cockroach Authors.
+//
+// Licensed as a CockroachDB Enterprise file under the Cockroach Community
+// License (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+//     https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt
+
+package sqlproxyccl
+
+import (
+	"context"
+	"crypto/tls"
+	"net"
+	"reflect"
+	"testing"
+	"time"
+
+	"github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/throttler"
+	"github.com/cockroachdb/cockroach/pkg/testutils"
+	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
+	"github.com/cockroachdb/errors"
+	"github.com/jackc/pgproto3/v2"
+	"github.com/stretchr/testify/require"
+)
+
+func TestRetriableConnectorError(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+
+	err := errors.New("foobar")
+	require.False(t, IsRetriableConnectorError(err))
+	err = MarkAsRetriableConnectorError(err)
+	require.True(t, IsRetriableConnectorError(err))
+}
+
+func TestConnector_OpenClusterConnWithToken(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+
+	const token = "foobarbaz"
+	ctx := context.Background()
+
+	t.Run("error", func(t *testing.T) {
+		c := &connector{
+			StartupMsg: &pgproto3.StartupMessage{
+				Parameters: make(map[string]string),
+			},
+		}
+		c.testingKnobs.openClusterConnInternal = func(ctx context.Context) (net.Conn, error) {
+			return nil, errors.New("foo")
+		}
+
+		crdbConn, err := c.OpenClusterConnWithToken(ctx, token)
+		require.EqualError(t, err, "foo")
+		require.Nil(t, crdbConn)
+
+		// Ensure that token is deleted.
+		str, ok := c.StartupMsg.Parameters[sessionRevivalTokenStartupParam]
+		require.False(t, ok)
+		require.Equal(t, "", str)
+	})
+
+	t.Run("successful", func(t *testing.T) {
+		c := &connector{
+			StartupMsg: &pgproto3.StartupMessage{
+				Parameters: make(map[string]string),
+			},
+		}
+		conn, _ := net.Pipe()
+		defer conn.Close()
+
+		var openCalled bool
+		c.testingKnobs.openClusterConnInternal = func(ctx context.Context) (net.Conn, error) {
+			openCalled = true
+
+			// Validate that token is set.
+			str, ok := c.StartupMsg.Parameters[sessionRevivalTokenStartupParam]
+			require.True(t, ok)
+			require.Equal(t, token, str)
+
+			return conn, nil
+		}
+
+		crdbConn, err := c.OpenClusterConnWithToken(ctx, token)
+		require.True(t, openCalled)
+		require.NoError(t, err)
+		require.Equal(t, conn, crdbConn)
+
+		// Ensure that token is deleted.
+		_, ok := c.StartupMsg.Parameters[sessionRevivalTokenStartupParam]
+		require.False(t, ok)
+	})
+
+	t.Run("idle monitor wrapper is called", func(t *testing.T) {
+		var wrapperCalled bool
+		c := &connector{
+			StartupMsg: &pgproto3.StartupMessage{
+				Parameters: make(map[string]string),
+			},
+			IdleMonitorWrapperFn: func(crdbConn net.Conn) net.Conn {
+				wrapperCalled = true
+				return crdbConn
+			},
+		}
+
+		conn, _ := net.Pipe()
+		defer conn.Close()
+
+		var openCalled bool
+		c.testingKnobs.openClusterConnInternal = func(ctx context.Context) (net.Conn, error) {
+			openCalled = true
+
+			// Validate that token is set.
+			str, ok := c.StartupMsg.Parameters[sessionRevivalTokenStartupParam]
+			require.True(t, ok)
+			require.Equal(t, token, str)
+
+			return conn, nil
+		}
+
+		crdbConn, err := c.OpenClusterConnWithToken(ctx, token)
+		require.True(t, wrapperCalled)
+		require.True(t, openCalled)
+		require.NoError(t, err)
+		require.Equal(t, conn, crdbConn)
+
+		// Ensure that token is deleted.
+		_, ok := c.StartupMsg.Parameters[sessionRevivalTokenStartupParam]
+		require.False(t, ok)
+	})
+}
+
+func TestConnector_OpenClusterConnWithAuth(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+
+	ctx := context.Background()
+	dummyHook := func(throttler.AttemptStatus) error {
+		return nil
+	}
+
+	t.Run("error during open", func(t *testing.T) {
+		c := &connector{
+			StartupMsg: &pgproto3.StartupMessage{
+				Parameters: make(map[string]string),
+			},
+		}
+		c.testingKnobs.openClusterConnInternal = func(ctx context.Context) (net.Conn, error) {
+			return nil, errors.New("foo")
+		}
+
+		crdbConn, sentToClient, err := c.OpenClusterConnWithAuth(ctx,
+			nil /* clientConn */, nil /* throttleHook */)
+		require.EqualError(t, err, "foo")
+		require.False(t, sentToClient)
+		require.Nil(t, crdbConn)
+	})
+
+	t.Run("error during auth", func(t *testing.T) {
+		c := &connector{
+			StartupMsg: &pgproto3.StartupMessage{
+				Parameters: make(map[string]string),
+			},
+			AuthenticateFn: func(
+				client net.Conn,
+				server net.Conn,
+				throttleHook func(throttler.AttemptStatus) error,
+			) error {
+				return errors.New("bar")
+			},
+		}
+
+		conn, _ := net.Pipe()
+		defer conn.Close()
+
+		var openCalled bool
+		c.testingKnobs.openClusterConnInternal = func(ctx context.Context) (net.Conn, error) {
+			openCalled = true
+			return conn, nil
+		}
+
+		crdbConn, sentToClient, err := c.OpenClusterConnWithAuth(ctx,
+			nil /* clientConn */, nil /* throttleHook */)
+		require.True(t, openCalled)
+		require.EqualError(t, err, "bar")
+		require.True(t, sentToClient)
+		require.Nil(t, crdbConn)
+
+		// Connection should be closed.
+		_, err = conn.Write([]byte("foo"))
+		require.Regexp(t, "closed pipe", err)
+	})
+
+	t.Run("successful", func(t *testing.T) {
+		clientConn, _ := net.Pipe()
+		defer clientConn.Close()
+
+		var authCalled bool
+		c := &connector{
+			StartupMsg: &pgproto3.StartupMessage{
+				Parameters: map[string]string{
+					// Passing in a token should have no effect.
+					sessionRevivalTokenStartupParam: "foo",
+				},
+			},
+			AuthenticateFn: func(
+				client net.Conn,
+				server net.Conn,
+				throttleHook func(throttler.AttemptStatus) error,
+			) error {
+				authCalled = true
+				require.Equal(t, clientConn, client)
+				require.NotNil(t, server)
+				require.Equal(t, reflect.ValueOf(dummyHook).Pointer(),
+					reflect.ValueOf(throttleHook).Pointer())
+				return nil
+			},
+		}
+
+		conn, _ := net.Pipe()
+		defer conn.Close()
+
+		var openCalled bool
+		c.testingKnobs.openClusterConnInternal = func(ctx context.Context) (net.Conn, error) {
+			openCalled = true
+
+			// Validate that token is not set.
+			_, ok := c.StartupMsg.Parameters[sessionRevivalTokenStartupParam]
+			require.False(t, ok)
+
+			return conn, nil
+		}
+
+		crdbConn, sentToClient, err := c.OpenClusterConnWithAuth(ctx, clientConn, dummyHook)
+		require.True(t, openCalled)
+		require.True(t, authCalled)
+		require.NoError(t, err)
+		require.False(t, sentToClient)
+		require.Equal(t, conn, crdbConn)
+	})
+
+	t.Run("idle monitor wrapper is called", func(t *testing.T) {
+		clientConn, _ := net.Pipe()
+		defer clientConn.Close()
+
+		var authCalled, wrapperCalled bool
+		c := &connector{
+			StartupMsg: &pgproto3.StartupMessage{
+				Parameters: map[string]string{
+					// Passing in a token should have no effect.
+					sessionRevivalTokenStartupParam: "foo",
+				},
+			},
+			AuthenticateFn: func(
+				client net.Conn,
+				server net.Conn,
+				throttleHook func(throttler.AttemptStatus) error,
+			) error {
+				authCalled = true
+				require.Equal(t, clientConn, client)
+				require.NotNil(t, server)
+				require.Equal(t, reflect.ValueOf(dummyHook).Pointer(),
+					reflect.ValueOf(throttleHook).Pointer())
+				return nil
+			},
+			IdleMonitorWrapperFn: func(crdbConn net.Conn) net.Conn {
+				wrapperCalled = true
+				return crdbConn
+			},
+		}
+
+		conn, _ := net.Pipe()
+		defer conn.Close()
+
+		var openCalled bool
+		c.testingKnobs.openClusterConnInternal = func(ctx context.Context) (net.Conn, error) {
+			openCalled = true
+
+			// Validate that token is not set.
+			_, ok := c.StartupMsg.Parameters[sessionRevivalTokenStartupParam]
+			require.False(t, ok)
+
+			return conn, nil
+		}
+
+		crdbConn, sentToClient, err := c.OpenClusterConnWithAuth(ctx, clientConn, dummyHook)
+		require.True(t, openCalled)
+		require.True(t, wrapperCalled)
+		require.True(t, authCalled)
+		require.NoError(t, err)
+		require.False(t, sentToClient)
+		require.Equal(t, conn, crdbConn)
+	})
+}
+
+func TestConnector_openClusterConnInternal(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+
+	bgCtx := context.Background()
+
+	t.Run("context canceled at boundary", func(t *testing.T) {
+		// This is a short test, and is expected to finish within ms.
+		ctx, cancel := context.WithTimeout(bgCtx, 2*time.Second)
+		defer cancel()
+
+		var onLookupEventCalled bool
+		c := &connector{
+			AddrLookupFn: func(ctx context.Context) (string, error) {
+				return "", MarkAsRetriableConnectorError(errors.New("baz"))
+			},
+			OnLookupEvent: func(ctx context.Context, err error) {
+				onLookupEventCalled = true
+				require.EqualError(t, err, "baz")
+
+				// Cancel context to trigger loop exit.
+				cancel()
+			},
+		}
+		conn, err := c.openClusterConnInternal(ctx)
+		require.EqualError(t, err, "baz")
+		require.True(t, errors.Is(err, context.Canceled))
+		require.Nil(t, conn)
+
+		require.True(t, onLookupEventCalled)
+	})
+
+	t.Run("context canceled within loop", func(t *testing.T) {
+		// This is a short test, and is expected to finish within ms.
+		ctx, cancel := context.WithTimeout(bgCtx, 2*time.Second)
+		defer cancel()
+
+		var onLookupEventCalled bool
+		c := &connector{
+			AddrLookupFn: func(ctx context.Context) (string, error) {
+				return "", errors.Wrap(context.Canceled, "foobar")
+			},
+			OnLookupEvent: func(ctx context.Context, err error) {
+				onLookupEventCalled = true
+				require.EqualError(t, err, "foobar: context canceled")
+			},
+		}
+		conn, err := c.openClusterConnInternal(ctx)
+		require.EqualError(t, err, "foobar: context canceled")
+		require.True(t, errors.Is(err, context.Canceled))
+		require.Nil(t, conn)
+
+		require.True(t, onLookupEventCalled)
+	})
+
+	t.Run("non-transient error", func(t *testing.T) {
+		// This is a short test, and is expected to finish within ms.
+		ctx, cancel := context.WithTimeout(bgCtx, 2*time.Second)
+		defer cancel()
+
+		c := &connector{
+			AddrLookupFn: func(ctx context.Context) (string, error) {
+				return "", errors.New("baz")
+			},
+		}
+		conn, err := c.openClusterConnInternal(ctx)
+		require.EqualError(t, err, "baz")
+		require.Nil(t, conn)
+	})
+
+	t.Run("successful", func(t *testing.T) {
+		// This should not take more than 5 seconds.
+		ctx, cancel := context.WithTimeout(bgCtx, 5*time.Second)
+		defer cancel()
+
+		crdbConn, _ := net.Pipe()
+		defer crdbConn.Close()
+
+		// We will exercise the following events:
+		// 1. retriable error on Lookup.
+		// 2. retriable error on Dial.
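+		// Each retriable error restarts the loop from the lookup step, so we
+		// expect three lookup attempts and two dial attempts in total (see
+		// the count assertions below).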
+		var addrLookupFnCount, dialOutgoingAddrCount int
+		var onLookupEventCalled, onDialEventCalled bool
+		c := &connector{
+			AddrLookupFn: func(ctx context.Context) (string, error) {
+				addrLookupFnCount++
+				if addrLookupFnCount == 1 {
+					return "", MarkAsRetriableConnectorError(errors.New("foo"))
+				}
+				return "127.0.0.10:42", nil
+			},
+			OnLookupEvent: func(ctx context.Context, err error) {
+				onLookupEventCalled = true
+				if addrLookupFnCount == 1 {
+					require.EqualError(t, err, "foo")
+				} else {
+					require.NoError(t, err)
+				}
+			},
+			OnDialEvent: func(ctx context.Context, outgoingAddr string, err error) {
+				onDialEventCalled = true
+				require.Equal(t, "127.0.0.10:42", outgoingAddr)
+				if dialOutgoingAddrCount == 1 {
+					require.EqualError(t, err, "bar")
+				} else {
+					require.NoError(t, err)
+				}
+			},
+		}
+		c.testingKnobs.dialOutgoingAddr = func(outgoingAddr string) (net.Conn, error) {
+			dialOutgoingAddrCount++
+			if dialOutgoingAddrCount == 1 {
+				return nil, MarkAsRetriableConnectorError(errors.New("bar"))
+			}
+			return crdbConn, nil
+		}
+		conn, err := c.openClusterConnInternal(ctx)
+		require.NoError(t, err)
+		require.Equal(t, crdbConn, conn)
+
+		// Assert expected call counts.
+		require.Equal(t, 3, addrLookupFnCount)
+		require.Equal(t, 2, dialOutgoingAddrCount)
+		require.True(t, onLookupEventCalled)
+		require.True(t, onDialEventCalled)
+	})
+}
+
+func TestConnector_dialOutgoingAddr(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+
+	t.Run("with tlsConfig", func(t *testing.T) {
+		c := &connector{
+			StartupMsg: &pgproto3.StartupMessage{},
+			TLSConfig:  &tls.Config{InsecureSkipVerify: true},
+		}
+		crdbConn, _ := net.Pipe()
+		defer crdbConn.Close()
+
+		defer testutils.TestingHook(&BackendDial,
+			func(msg *pgproto3.StartupMessage, outgoingAddress string,
+				tlsConfig *tls.Config) (net.Conn, error) {
+				require.Equal(t, c.StartupMsg, msg)
+				require.Equal(t, "10.11.12.13:80", outgoingAddress)
+				require.Equal(t, "10.11.12.13", tlsConfig.ServerName)
+				return crdbConn, nil
+			},
+		)()
+		conn, err := c.dialOutgoingAddr("10.11.12.13:80")
+		require.NoError(t, err)
+		require.Equal(t, crdbConn, conn)
+	})
+
+	t.Run("invalid outgoingAddr with tlsConfig", func(t *testing.T) {
+		c := &connector{
+			StartupMsg: &pgproto3.StartupMessage{},
+			TLSConfig:  &tls.Config{InsecureSkipVerify: true},
+		}
+		conn, err := c.dialOutgoingAddr("!@#$::")
+		require.Error(t, err)
+		require.Regexp(t, "invalid address format", err)
+		require.False(t, IsRetriableConnectorError(err))
+		require.Nil(t, conn)
+	})
+
+	t.Run("without tlsConfig", func(t *testing.T) {
+		c := &connector{StartupMsg: &pgproto3.StartupMessage{}}
+		crdbConn, _ := net.Pipe()
+		defer crdbConn.Close()
+
+		defer testutils.TestingHook(&BackendDial,
+			func(msg *pgproto3.StartupMessage, outgoingAddress string,
+				tlsConfig *tls.Config) (net.Conn, error) {
+				require.Equal(t, c.StartupMsg, msg)
+				require.Equal(t, "10.11.12.13:1234", outgoingAddress)
+				require.Nil(t, tlsConfig)
+				return crdbConn, nil
+			},
+		)()
+		conn, err := c.dialOutgoingAddr("10.11.12.13:1234")
+		require.NoError(t, err)
+		require.Equal(t, crdbConn, conn)
+	})
+
+	t.Run("failed to dial with non-transient error", func(t *testing.T) {
+		c := &connector{StartupMsg: &pgproto3.StartupMessage{}}
+		defer testutils.TestingHook(&BackendDial,
+			func(msg *pgproto3.StartupMessage, outgoingAddress string,
+				tlsConfig *tls.Config) (net.Conn, error) {
+				require.Equal(t, c.StartupMsg, msg)
+				require.Equal(t, "127.0.0.1:1234", outgoingAddress)
+				require.Nil(t, tlsConfig)
+				return nil, errors.New("foo")
+			},
+		)()
+		conn, err := c.dialOutgoingAddr("127.0.0.1:1234")
+		require.EqualError(t, err, "foo")
+		require.False(t, IsRetriableConnectorError(err))
+		require.Nil(t, conn)
+	})
+
+	t.Run("failed to dial with transient error", func(t *testing.T) {
+		c := &connector{StartupMsg: &pgproto3.StartupMessage{}}
+		defer testutils.TestingHook(&BackendDial,
+			func(msg *pgproto3.StartupMessage, outgoingAddress string,
+				tlsConfig *tls.Config) (net.Conn, error) {
+				require.Equal(t, c.StartupMsg, msg)
+				require.Equal(t, "127.0.0.2:4567", outgoingAddress)
+				require.Nil(t, tlsConfig)
+				return nil, newErrorf(codeBackendDown, "bar")
+			},
+		)()
+		conn, err := c.dialOutgoingAddr("127.0.0.2:4567")
+		require.EqualError(t, err, "codeBackendDown: bar")
+		require.True(t, IsRetriableConnectorError(err))
+		require.Nil(t, conn)
+	})
+}
diff --git a/pkg/ccl/sqlproxyccl/proxy_handler.go b/pkg/ccl/sqlproxyccl/proxy_handler.go
index 7cc8d82c8709..ee2243a6edd2 100644
--- a/pkg/ccl/sqlproxyccl/proxy_handler.go
+++ b/pkg/ccl/sqlproxyccl/proxy_handler.go
@@ -26,7 +26,6 @@ import (
 	"github.com/cockroachdb/cockroach/pkg/security/certmgr"
 	"github.com/cockroachdb/cockroach/pkg/util/log"
 	"github.com/cockroachdb/cockroach/pkg/util/netutil/addr"
-	"github.com/cockroachdb/cockroach/pkg/util/retry"
 	"github.com/cockroachdb/cockroach/pkg/util/stop"
 	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
 	"github.com/cockroachdb/errors"
@@ -275,86 +274,62 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn
 		return err
 	}
 
-	var crdbConn net.Conn
-	var outgoingAddress string
-
-	// Repeatedly try to make a connection. Any failures are assumed to be
-	// transient unless the tenant cannot be found (e.g. because it was
-	// deleted). We will simply loop forever, or until the context is canceled
-	// (e.g. by client disconnect). This is preferable to terminating client
-	// connections, because in most cases those connections will simply be
-	// retried, further increasing load on the system.
-	retryOpts := retry.Options{
-		InitialBackoff: 10 * time.Millisecond,
-		MaxBackoff:     5 * time.Second,
-	}
-
 	outgoingAddressErr := log.Every(time.Minute)
 	backendDialErr := log.Every(time.Minute)
 	reportFailureErr := log.Every(time.Minute)
 	var outgoingAddressErrs, codeBackendDownErrs, reportFailureErrs int
 
-	for r := retry.StartWithCtx(ctx, retryOpts); r.Next(); {
-		// Get the DNS/IP address of the backend server to dial.
-		outgoingAddress, err = handler.outgoingAddress(ctx, clusterName, tenID)
-		if err != nil {
-			// Failure is assumed to be transient (and should be retried) except
-			// in case where the server was not found.
+	// TLS options for the proxy are split into Insecure and SkipVerify.
+	// In insecure mode, tlsConf is expected to be nil. This will cause the
+	// connector's dialer to skip TLS entirely. If SkipVerify is true,
+	// tlsConf will be set to a non-nil config with InsecureSkipVerify set
+	// to true. InsecureSkipVerify will provide an encrypted connection but
+	// not verify that the connection recipient is a trusted party.
+	var tlsConf *tls.Config
+	if !handler.Insecure {
+		tlsConf = &tls.Config{InsecureSkipVerify: handler.SkipVerify}
+	}
+	connector := &connector{
+		StartupMsg: backendStartupMsg,
+		TLSConfig:  tlsConf,
+		AddrLookupFn: func(ctx context.Context) (string, error) {
+			addr, err := handler.outgoingAddress(ctx, clusterName, tenID)
+			if err == nil {
+				return addr, nil
+			}
+			// Transient error when retrieving outgoing address.
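+			// (A gRPC NotFound error means the tenant was not found, e.g.
+			// because it was deleted, so retrying would not help.)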
 			if status.Code(err) != codes.NotFound {
-				outgoingAddressErrs++
-				if outgoingAddressErr.ShouldLog() {
-					log.Ops.Errorf(ctx,
-						"outgoing address (%d errors skipped): %v",
-						outgoingAddressErrs,
-						err,
-					)
-					outgoingAddressErrs = 0
-				}
-				continue
+				return "", MarkAsRetriableConnectorError(err)
 			}
-
-			// Remap error for external consumption.
-			log.Errorf(ctx, "could not retrieve outgoing address: %v", err.Error())
-			err = newErrorf(
-				codeParamsRoutingFailed, "cluster %s-%d not found", clusterName, tenID.ToUint64())
-			break
-		}
-
-		// NB: TLS options for the proxy are split into Insecure and
-		// SkipVerify. In insecure mode, tlsConf is expected to be nil. This
-		// will cause BackendDial to skip TLS entirely. If SkipVerify is true,
-		// tlsConf will be set to a non-nil config with InsecureSkipVerify set
-		// to true. InsecureSkipVerify will provide an encrypted connection but
-		// not verify that the connection recipient is a trusted party.
-		var tlsConf *tls.Config
-		if !handler.Insecure {
-			// Use an empty string as the default port as we only care about the
-			// correctly parsing the outgoingHost/IP here.
-			outgoingHost, _, err := addr.SplitHostPort(outgoingAddress, "")
-			if err != nil {
-				log.Errorf(ctx, "could not split outgoing address '%s' into host and port: %v", outgoingAddress, err.Error())
-				// Remap error for external consumption.
-				clientErr := newErrorf(
-					codeParamsRoutingFailed, "cluster %s-%d not found", clusterName, tenID.ToUint64())
-				updateMetricsAndSendErrToClient(clientErr, conn, handler.metrics)
-				return clientErr
+			// Don't retry if we get a NotFound error.
+			return "", newErrorf(codeParamsRoutingFailed,
+				"cluster %s-%d not found", clusterName, tenID.ToUint64())
+		},
+		AuthenticateFn: authenticate,
+		OnLookupEvent: func(ctx context.Context, err error) {
+			// We only care about retriable errors since we want to log them.
+			if !IsRetriableConnectorError(err) {
+				outgoingAddressErrs = 0
+				return
 			}
-
-			tlsConf = &tls.Config{
-				// Always set ServerName, if SkipVerify is true, it will be
-				// ignored. When SkipVerify is false, it is required to
-				// establish a TLS connection.
-				ServerName:         outgoingHost,
-				InsecureSkipVerify: handler.SkipVerify,
+			outgoingAddressErrs++
+			if outgoingAddressErr.ShouldLog() {
+				log.Ops.Errorf(ctx,
+					"outgoing address (%d errors skipped): %v",
+					outgoingAddressErrs,
+					err,
+				)
+				outgoingAddressErrs = 0
+			}
+		},
+		OnDialEvent: func(ctx context.Context, outgoingAddr string, err error) {
+			// We only care about retriable errors since we want to log them,
+			// and report errors to the tenant directory.
+			if !IsRetriableConnectorError(err) {
+				codeBackendDownErrs = 0
+				reportFailureErrs = 0
+				return
 			}
-		}
-
-		// Now actually dial the backend server.
-		crdbConn, err = BackendDial(backendStartupMsg, outgoingAddress, tlsConf)
-
-		// If we get a backend down error, retry the connection.
-		var codeErr *codeError
-		if err != nil && errors.As(err, &codeErr) && codeErr.code == codeBackendDown {
 			codeBackendDownErrs++
 			if backendDialErr.ShouldLog() {
 				log.Ops.Errorf(ctx,
@@ -364,12 +339,12 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn
 				)
 				codeBackendDownErrs = 0
 			}
-
 			if handler.directory != nil {
 				// Report the failure to the directory so that it can refresh any
 				// stale information that may have caused the problem.
-				err = reportFailureToDirectory(ctx, tenID, outgoingAddress, handler.directory)
-				if err != nil {
+				if err = reportFailureToDirectory(
+					ctx, tenID, outgoingAddr, handler.directory,
+				); err != nil {
 					reportFailureErrs++
 					if reportFailureErr.ShouldLog() {
 						log.Ops.Errorf(ctx,
@@ -381,42 +356,43 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn
 					}
 				}
 			}
-			continue
-		}
-		break
-	}
-
-	if err != nil {
-		updateMetricsAndSendErrToClient(err, conn, handler.metrics)
-		return err
+		},
 	}
 
 	// Monitor for idle connection, if requested.
 	if handler.idleMonitor != nil {
-		crdbConn = handler.idleMonitor.DetectIdle(crdbConn, func() {
-			err := newErrorf(codeIdleDisconnect, "idle connection closed")
-			select {
-			case errConnection <- err: /* error reported */
-			default: /* the channel already contains an error */
-			}
-		})
+		connector.IdleMonitorWrapperFn = func(crdbConn net.Conn) net.Conn {
+			return handler.idleMonitor.DetectIdle(crdbConn, func() {
+				err := newErrorf(codeIdleDisconnect, "idle connection closed")
+				select {
+				case errConnection <- err: /* error reported */
+				default: /* the channel already contains an error */
+				}
+			})
+		}
 	}
 
-	defer func() { _ = crdbConn.Close() }()
-
-	// Perform user authentication.
-	if err := authenticate(conn, crdbConn, func(status throttler.AttemptStatus) error {
-		err := handler.throttleService.ReportAttempt(ctx, throttleTags, throttleTime, status)
-		if err != nil {
-			log.Errorf(ctx, "throttler refused connection after authentication: %v", err.Error())
-			return throttledError
+	crdbConn, sentToClient, err := connector.OpenClusterConnWithAuth(ctx, conn,
+		func(status throttler.AttemptStatus) error {
+			if err := handler.throttleService.ReportAttempt(
+				ctx, throttleTags, throttleTime, status,
+			); err != nil {
+				log.Errorf(ctx, "throttler refused connection after authentication: %v", err.Error())
+				return throttledError
+			}
+			return nil
+		},
+	)
+	if err != nil {
+		log.Errorf(ctx, "could not connect to cluster: %v", err.Error())
+		if sentToClient {
+			handler.metrics.updateForError(err)
+		} else {
+			updateMetricsAndSendErrToClient(err, conn, handler.metrics)
 		}
-		return nil
-	}); err != nil {
-		handler.metrics.updateForError(err)
-		log.Ops.Errorf(ctx, "authenticate: %s", err)
 		return err
 	}
+	defer crdbConn.Close()
 
 	handler.metrics.SuccessfulConnCount.Inc(1)