diff --git a/pkg/ccl/sqlproxyccl/BUILD.bazel b/pkg/ccl/sqlproxyccl/BUILD.bazel
index d546b57c2c83..a5c6f0a5aa4f 100644
--- a/pkg/ccl/sqlproxyccl/BUILD.bazel
+++ b/pkg/ccl/sqlproxyccl/BUILD.bazel
@@ -6,6 +6,7 @@ go_library(
     srcs = [
         "authentication.go",
         "backend_dialer.go",
+        "connector.go",
         "error.go",
         "frontend_admitter.go",
         "metrics.go",
@@ -48,6 +49,7 @@ go_test(
     size = "small",
     srcs = [
         "authentication_test.go",
+        "connector_test.go",
         "frontend_admitter_test.go",
         "main_test.go",
         "proxy_handler_test.go",
diff --git a/pkg/ccl/sqlproxyccl/backend_dialer.go b/pkg/ccl/sqlproxyccl/backend_dialer.go
index b0ce1c594c62..5d1f12d76c56 100644
--- a/pkg/ccl/sqlproxyccl/backend_dialer.go
+++ b/pkg/ccl/sqlproxyccl/backend_dialer.go
@@ -24,6 +24,8 @@ import (
 //
 // BackendDial uses a dial timeout of 5 seconds to mitigate network black
 // holes.
+//
+// TODO(jaylim-crl): Move dialer into connector in the future.
 var BackendDial = func(
 	msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config,
 ) (net.Conn, error) {
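Note: because BackendDial is declared as a package-level variable (see the context lines above), tests can swap it out temporarily and restore it afterwards; the new connector_test.go below does exactly that via testutils.TestingHook. A minimal sketch of that pattern, assuming a stub that simply hands back an injected connection (the helper name stubBackendDial is illustrative and not part of this patch):

// stubBackendDial swaps BackendDial for a stub that returns the given
// connection, and returns a closure that restores the original value.
// Illustrative only; not part of this patch.
func stubBackendDial(conn net.Conn) (restore func()) {
	return testutils.TestingHook(&BackendDial, func(
		msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config,
	) (net.Conn, error) {
		// Skip the real dial and hand back the injected connection.
		return conn, nil
	})
}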
diff --git a/pkg/ccl/sqlproxyccl/connector.go b/pkg/ccl/sqlproxyccl/connector.go
new file mode 100644
index 000000000000..935c4a971c1c
--- /dev/null
+++ b/pkg/ccl/sqlproxyccl/connector.go
@@ -0,0 +1,267 @@
+// Copyright 2022 The Cockroach Authors.
+//
+// Licensed as a CockroachDB Enterprise file under the Cockroach Community
+// License (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+//     https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt
+
+package sqlproxyccl
+
+import (
+	"context"
+	"crypto/tls"
+	"net"
+	"time"
+
+	"github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/throttler"
+	"github.com/cockroachdb/cockroach/pkg/util/netutil/addr"
+	"github.com/cockroachdb/cockroach/pkg/util/retry"
+	"github.com/cockroachdb/errors"
+	pgproto3 "github.com/jackc/pgproto3/v2"
+)
+
+// sessionRevivalTokenStartupParam indicates the name of the parameter that
+// will activate token-based authentication if present in the startup message.
+const sessionRevivalTokenStartupParam = "crdb:session_revival_token_base64"
+
+// errRetryConnectorSentinel exists to allow more robust detection of retry
+// errors even if they are wrapped.
+var errRetryConnectorSentinel = errors.New("retry connector error")
+
+// MarkAsRetriableConnectorError marks the given error with
+// errRetryConnectorSentinel, which will trigger the connector to retry if such
+// an error is returned.
+func MarkAsRetriableConnectorError(err error) error {
+	return errors.Mark(err, errRetryConnectorSentinel)
+}
+
+// IsRetriableConnectorError checks whether a given error is retriable. This
+// should be called on errors which are transient so that the connector can
+// retry on such errors.
+func IsRetriableConnectorError(err error) bool {
+	return errors.Is(err, errRetryConnectorSentinel)
+}
+
+// connector is a per-session tenant-associated component that can be used to
+// obtain a connection to the tenant cluster. This will also handle the
+// authentication phase. All connections returned by the connector should
+// already be ready to accept regular pgwire messages (e.g. SQL queries).
+type connector struct {
+	// StartupMsg represents the startup message associated with the client.
+	// This will be used when establishing a pgwire connection with the SQL pod.
+	//
+	// NOTE: This field is required.
+	StartupMsg *pgproto3.StartupMessage
+
+	// TLSConfig represents the client TLS config used by the connector when
+	// connecting with the SQL pod. If the ServerName field is set, this will
+	// be overridden during connection establishment. Set to nil if we are
+	// connecting to an insecure cluster.
+	//
+	// NOTE: This field is optional.
+	TLSConfig *tls.Config
+
+	// AddrLookupFn is used by the connector to return an address (that must
+	// include both host and port) pointing to one of the SQL pods for the
+	// tenant associated with this connector.
+	//
+	// This will be called within an infinite backoff loop. If an error is
+	// transient, this should return an error that has been marked with
+	// errRetryConnectorSentinel (i.e. MarkAsRetriableConnectorError).
+	//
+	// NOTE: This field is required.
+	AddrLookupFn func(ctx context.Context) (string, error)
+
+	// AuthenticateFn is used by the connector to authenticate the client
+	// against the server. This will only be used in non-token-based auth
+	// methods. This should block until the server has authenticated the
+	// client.
+	//
+	// NOTE: This field is required.
+	AuthenticateFn func(
+		client net.Conn,
+		server net.Conn,
+		throttleHook func(throttler.AttemptStatus) error,
+	) error
+
+	// IdleMonitorWrapperFn is used to wrap the connection to the SQL pod with
+	// an idle monitor. If not specified, the raw connection to the SQL pod
+	// will be returned.
+	//
+	// In the case of connecting with an authentication phase, the connection
+	// will be wrapped before starting the authentication.
+	//
+	// NOTE: This field is optional.
+	IdleMonitorWrapperFn func(crdbConn net.Conn) net.Conn
+
+	// Event callback functions. OnLookupEvent and OnDialEvent will be called
+	// after the lookup and dial operations respectively, regardless of error.
+	//
+	// NOTE: These fields are optional.
+	//
+	// TODO(jaylim-crl): Look into removing event callback functions. This
+	// requires us to pass in some sort of directory into the connector
+	// component. Perhaps addrLookupFn could be replaced with that. We can't
+	// do that today because addrLookupFn also relies on a fallback mechanism
+	// for --routing-rule.
+	OnLookupEvent func(ctx context.Context, err error)
+	OnDialEvent   func(ctx context.Context, outgoingAddr string, err error)
+
+	// Testing knobs for internal connector calls. If specified, these will
+	// be called instead of the actual logic.
+	testingKnobs struct {
+		openClusterConnInternal func(ctx context.Context) (net.Conn, error)
+		dialOutgoingAddr        func(outgoingAddr string) (net.Conn, error)
+	}
+}
+
+// OpenClusterConnWithToken opens a connection to the tenant cluster using
+// token-based authentication.
+func (c *connector) OpenClusterConnWithToken(ctx context.Context, token string) (net.Conn, error) {
+	c.StartupMsg.Parameters[sessionRevivalTokenStartupParam] = token
+	defer func() {
+		// Delete token after return.
+		delete(c.StartupMsg.Parameters, sessionRevivalTokenStartupParam)
+	}()
+
+	crdbConn, err := c.openClusterConnInternal(ctx)
+	if err != nil {
+		return nil, err
+	}
+	if c.IdleMonitorWrapperFn != nil {
+		crdbConn = c.IdleMonitorWrapperFn(crdbConn)
+	}
+	return crdbConn, nil
+}
+
+// OpenClusterConnWithAuth opens a connection to the tenant cluster using
+// normal authentication methods (e.g. password, etc.). Once a connection to
+// one of the tenant's SQL pods has been established, we will transfer
+// request/response flow between clientConn and the new connection to the
+// authenticator, which implies that this will be blocked until authentication
+// succeeds, or until an error is returned.
+//
+// sentToClient will be set to true if an error occurred during the
+// authenticator phase since errors would have already been sent to the client.
+func (c *connector) OpenClusterConnWithAuth(
+	ctx context.Context, clientConn net.Conn, throttleHook func(throttler.AttemptStatus) error,
+) (serverConn net.Conn, sentToClient bool, retErr error) {
+	// Just a safety check, but this shouldn't happen since we will block the
+	// startup param in the frontend admitter. The only case where we actually
+	// need to delete this param is if OpenClusterConnWithToken was called
+	// previously, but that wouldn't happen based on the current proxy logic.
+	delete(c.StartupMsg.Parameters, sessionRevivalTokenStartupParam)
+
+	crdbConn, err := c.openClusterConnInternal(ctx)
+	if err != nil {
+		return nil, false, err
+	}
+	defer func() {
+		if retErr != nil {
+			crdbConn.Close()
+		}
+	}()
+
+	if c.IdleMonitorWrapperFn != nil {
+		crdbConn = c.IdleMonitorWrapperFn(crdbConn)
+	}
+
+	// Perform user authentication.
+	if err := c.AuthenticateFn(clientConn, crdbConn, throttleHook); err != nil {
+		return nil, true, err
+	}
+	return crdbConn, false, nil
+}
+
+// openClusterConnInternal returns a connection to the tenant cluster associated
+// with the connector. Once a connection has been established, the pgwire
+// startup message will be relayed to the server. Returned errors may be marked
+// as a lookup or dial error.
+func (c *connector) openClusterConnInternal(ctx context.Context) (net.Conn, error) {
+	if c.testingKnobs.openClusterConnInternal != nil {
+		return c.testingKnobs.openClusterConnInternal(ctx)
+	}
+
+	// Repeatedly try to make a connection until the context is canceled, or
+	// until we get a non-retriable error. This is preferable to terminating
+	// client connections, because in most cases those connections will simply
+	// be retried, further increasing load on the system.
+	retryOpts := retry.Options{
+		InitialBackoff: 10 * time.Millisecond,
+		MaxBackoff:     5 * time.Second,
+	}
+
+	var crdbConn net.Conn
+	var outgoingAddr string
+	var err error
+	for r := retry.StartWithCtx(ctx, retryOpts); r.Next(); {
+		// Retrieve a SQL pod address to connect to.
+		outgoingAddr, err = c.AddrLookupFn(ctx)
+		if c.OnLookupEvent != nil {
+			c.OnLookupEvent(ctx, err)
+		}
+		if err != nil {
+			if IsRetriableConnectorError(err) {
+				continue
+			}
+			return nil, err
+		}
+		// Make a connection to the SQL pod.
+		crdbConn, err = c.dialOutgoingAddr(outgoingAddr)
+		if c.OnDialEvent != nil {
+			c.OnDialEvent(ctx, outgoingAddr, err)
+		}
+		if err != nil {
+			if IsRetriableConnectorError(err) {
+				continue
+			}
+			return nil, err
+		}
+		return crdbConn, nil
+	}
+
+	// Since the retry loop above retries infinitely, the only way we can exit
+	// the loop is when the context is canceled.
+	if errors.Is(err, context.Canceled) {
+		return nil, err
+	}
+	// Loop exited at a retry boundary, so mark the previous error with the
+	// context cancellation.
+	if ctxErr := ctx.Err(); err != nil && ctxErr != nil {
+		return nil, errors.Mark(err, ctxErr)
+	}
+	panic("unreachable")
+}
+
+// dialOutgoingAddr dials the given outgoing address for the SQL pod, and
+// forwards the startup message to it. If the connector specifies a TLS
+// connection, it will also attempt to upgrade the PG connection to use TLS.
+func (c *connector) dialOutgoingAddr(outgoingAddr string) (net.Conn, error) {
+	if c.testingKnobs.dialOutgoingAddr != nil {
+		return c.testingKnobs.dialOutgoingAddr(outgoingAddr)
+	}
+
+	// Use a TLS config if one was provided. If TLSConfig is nil, Clone will
+	// return nil.
+	tlsConf := c.TLSConfig.Clone()
+	if tlsConf != nil {
+		// outgoingAddr will always have a port. We use an empty string as the
+		// default port as we only care about extracting the host.
+		outgoingHost, _, err := addr.SplitHostPort(outgoingAddr, "" /* defaultPort */)
+		if err != nil {
+			return nil, err
+		}
+		// Always set ServerName. If InsecureSkipVerify is true, this will
+		// be ignored.
+		tlsConf.ServerName = outgoingHost
+	}
+
+	conn, err := BackendDial(c.StartupMsg, outgoingAddr, tlsConf)
+	if err != nil {
+		var codeErr *codeError
+		if errors.As(err, &codeErr) && codeErr.code == codeBackendDown {
+			return nil, MarkAsRetriableConnectorError(err)
+		}
+		return nil, err
+	}
+	return conn, nil
+}
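For orientation, this is roughly how a caller inside this package would wire the connector together; a hedged sketch only, with a hard-coded address standing in for a real tenant-directory lookup (the function name exampleOpenConn and the address are illustrative, not part of this patch):

// exampleOpenConn is an illustrative sketch (not part of this patch) of how
// the connector above might be assembled and used by a caller in this package.
func exampleOpenConn(ctx context.Context, clientConn net.Conn) (net.Conn, error) {
	c := &connector{
		StartupMsg: &pgproto3.StartupMessage{Parameters: map[string]string{}},
		// A nil TLSConfig would mean a non-TLS connection to the SQL pod.
		TLSConfig: &tls.Config{InsecureSkipVerify: true},
		AddrLookupFn: func(ctx context.Context) (string, error) {
			// A real caller would consult the tenant directory here and mark
			// transient failures with MarkAsRetriableConnectorError.
			return "127.0.0.1:26257", nil
		},
		AuthenticateFn: authenticate,
	}
	// Blocks until authentication succeeds or fails; the returned connection
	// is ready for regular pgwire traffic.
	crdbConn, _, err := c.OpenClusterConnWithAuth(ctx, clientConn,
		func(throttler.AttemptStatus) error { return nil })
	if err != nil {
		return nil, err
	}
	return crdbConn, nil
}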
diff --git a/pkg/ccl/sqlproxyccl/connector_test.go b/pkg/ccl/sqlproxyccl/connector_test.go
new file mode 100644
index 000000000000..1432db391577
--- /dev/null
+++ b/pkg/ccl/sqlproxyccl/connector_test.go
@@ -0,0 +1,510 @@
+// Copyright 2022 The Cockroach Authors.
+//
+// Licensed as a CockroachDB Enterprise file under the Cockroach Community
+// License (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+//     https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt
+
+package sqlproxyccl
+
+import (
+	"context"
+	"crypto/tls"
+	"net"
+	"reflect"
+	"testing"
+	"time"
+
+	"github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/throttler"
+	"github.com/cockroachdb/cockroach/pkg/testutils"
+	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
+	"github.com/cockroachdb/errors"
+	"github.com/jackc/pgproto3/v2"
+	"github.com/stretchr/testify/require"
+)
+
+func TestRetriableConnectorError(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+
+	err := errors.New("foobar")
+	require.False(t, IsRetriableConnectorError(err))
+	err = MarkAsRetriableConnectorError(err)
+	require.True(t, IsRetriableConnectorError(err))
+}
+
+func TestConnector_OpenClusterConnWithToken(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+
+	const token = "foobarbaz"
+	ctx := context.Background()
+
+	t.Run("error", func(t *testing.T) {
+		c := &connector{
+			StartupMsg: &pgproto3.StartupMessage{
+				Parameters: make(map[string]string),
+			},
+		}
+		c.testingKnobs.openClusterConnInternal = func(ctx context.Context) (net.Conn, error) {
+			return nil, errors.New("foo")
+		}
+
+		crdbConn, err := c.OpenClusterConnWithToken(ctx, token)
+		require.EqualError(t, err, "foo")
+		require.Nil(t, crdbConn)
+
+		// Ensure that token is deleted.
+		str, ok := c.StartupMsg.Parameters[sessionRevivalTokenStartupParam]
+		require.False(t, ok)
+		require.Equal(t, "", str)
+	})
+
+	t.Run("successful", func(t *testing.T) {
+		c := &connector{
+			StartupMsg: &pgproto3.StartupMessage{
+				Parameters: make(map[string]string),
+			},
+		}
+		conn, _ := net.Pipe()
+		defer conn.Close()
+
+		var openCalled bool
+		c.testingKnobs.openClusterConnInternal = func(ctx context.Context) (net.Conn, error) {
+			openCalled = true
+
+			// Validate that token is set.
+			str, ok := c.StartupMsg.Parameters[sessionRevivalTokenStartupParam]
+			require.True(t, ok)
+			require.Equal(t, token, str)
+
+			return conn, nil
+		}
+
+		crdbConn, err := c.OpenClusterConnWithToken(ctx, token)
+		require.True(t, openCalled)
+		require.NoError(t, err)
+		require.Equal(t, conn, crdbConn)
+
+		// Ensure that token is deleted.
+		_, ok := c.StartupMsg.Parameters[sessionRevivalTokenStartupParam]
+		require.False(t, ok)
+	})
+
+	t.Run("idle monitor wrapper is called", func(t *testing.T) {
+		var wrapperCalled bool
+		c := &connector{
+			StartupMsg: &pgproto3.StartupMessage{
+				Parameters: make(map[string]string),
+			},
+			IdleMonitorWrapperFn: func(crdbConn net.Conn) net.Conn {
+				wrapperCalled = true
+				return crdbConn
+			},
+		}
+
+		conn, _ := net.Pipe()
+		defer conn.Close()
+
+		var openCalled bool
+		c.testingKnobs.openClusterConnInternal = func(ctx context.Context) (net.Conn, error) {
+			openCalled = true
+
+			// Validate that token is set.
+			str, ok := c.StartupMsg.Parameters[sessionRevivalTokenStartupParam]
+			require.True(t, ok)
+			require.Equal(t, token, str)
+
+			return conn, nil
+		}
+
+		crdbConn, err := c.OpenClusterConnWithToken(ctx, token)
+		require.True(t, wrapperCalled)
+		require.True(t, openCalled)
+		require.NoError(t, err)
+		require.Equal(t, conn, crdbConn)
+
+		// Ensure that token is deleted.
+		_, ok := c.StartupMsg.Parameters[sessionRevivalTokenStartupParam]
+		require.False(t, ok)
+	})
+}
+
+func TestConnector_OpenClusterConnWithAuth(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+
+	ctx := context.Background()
+	dummyHook := func(throttler.AttemptStatus) error {
+		return nil
+	}
+
+	t.Run("error during open", func(t *testing.T) {
+		c := &connector{
+			StartupMsg: &pgproto3.StartupMessage{
+				Parameters: make(map[string]string),
+			},
+		}
+		c.testingKnobs.openClusterConnInternal = func(ctx context.Context) (net.Conn, error) {
+			return nil, errors.New("foo")
+		}
+
+		crdbConn, sentToClient, err := c.OpenClusterConnWithAuth(ctx,
+			nil /* clientConn */, nil /* throttleHook */)
+		require.EqualError(t, err, "foo")
+		require.False(t, sentToClient)
+		require.Nil(t, crdbConn)
+	})
+
+	t.Run("error during auth", func(t *testing.T) {
+		c := &connector{
+			StartupMsg: &pgproto3.StartupMessage{
+				Parameters: make(map[string]string),
+			},
+			AuthenticateFn: func(
+				client net.Conn,
+				server net.Conn,
+				throttleHook func(throttler.AttemptStatus) error,
+			) error {
+				return errors.New("bar")
+			},
+		}
+
+		conn, _ := net.Pipe()
+		defer conn.Close()
+
+		var openCalled bool
+		c.testingKnobs.openClusterConnInternal = func(ctx context.Context) (net.Conn, error) {
+			openCalled = true
+			return conn, nil
+		}
+
+		crdbConn, sentToClient, err := c.OpenClusterConnWithAuth(ctx,
+			nil /* clientConn */, nil /* throttleHook */)
+		require.True(t, openCalled)
+		require.EqualError(t, err, "bar")
+		require.True(t, sentToClient)
+		require.Nil(t, crdbConn)
+
+		// Connection should be closed.
+		_, err = conn.Write([]byte("foo"))
+		require.Regexp(t, "closed pipe", err)
+	})
+
+	t.Run("successful", func(t *testing.T) {
+		clientConn, _ := net.Pipe()
+		defer clientConn.Close()
+
+		var authCalled bool
+		c := &connector{
+			StartupMsg: &pgproto3.StartupMessage{
+				Parameters: map[string]string{
+					// Passing in a token should have no effect.
+					sessionRevivalTokenStartupParam: "foo",
+				},
+			},
+			AuthenticateFn: func(
+				client net.Conn,
+				server net.Conn,
+				throttleHook func(throttler.AttemptStatus) error,
+			) error {
+				authCalled = true
+				require.Equal(t, clientConn, client)
+				require.NotNil(t, server)
+				require.Equal(t, reflect.ValueOf(dummyHook).Pointer(),
+					reflect.ValueOf(throttleHook).Pointer())
+				return nil
+			},
+		}
+
+		conn, _ := net.Pipe()
+		defer conn.Close()
+
+		var openCalled bool
+		c.testingKnobs.openClusterConnInternal = func(ctx context.Context) (net.Conn, error) {
+			openCalled = true
+
+			// Validate that token is not set.
+			_, ok := c.StartupMsg.Parameters[sessionRevivalTokenStartupParam]
+			require.False(t, ok)
+
+			return conn, nil
+		}
+
+		crdbConn, sentToClient, err := c.OpenClusterConnWithAuth(ctx, clientConn, dummyHook)
+		require.True(t, openCalled)
+		require.True(t, authCalled)
+		require.NoError(t, err)
+		require.False(t, sentToClient)
+		require.Equal(t, conn, crdbConn)
+	})
+
+	t.Run("idle monitor wrapper is called", func(t *testing.T) {
+		clientConn, _ := net.Pipe()
+		defer clientConn.Close()
+
+		var authCalled, wrapperCalled bool
+		c := &connector{
+			StartupMsg: &pgproto3.StartupMessage{
+				Parameters: map[string]string{
+					// Passing in a token should have no effect.
+					sessionRevivalTokenStartupParam: "foo",
+				},
+			},
+			AuthenticateFn: func(
+				client net.Conn,
+				server net.Conn,
+				throttleHook func(throttler.AttemptStatus) error,
+			) error {
+				authCalled = true
+				require.Equal(t, clientConn, client)
+				require.NotNil(t, server)
+				require.Equal(t, reflect.ValueOf(dummyHook).Pointer(),
+					reflect.ValueOf(throttleHook).Pointer())
+				return nil
+			},
+			IdleMonitorWrapperFn: func(crdbConn net.Conn) net.Conn {
+				wrapperCalled = true
+				return crdbConn
+			},
+		}
+
+		conn, _ := net.Pipe()
+		defer conn.Close()
+
+		var openCalled bool
+		c.testingKnobs.openClusterConnInternal = func(ctx context.Context) (net.Conn, error) {
+			openCalled = true
+
+			// Validate that token is not set.
+			_, ok := c.StartupMsg.Parameters[sessionRevivalTokenStartupParam]
+			require.False(t, ok)
+
+			return conn, nil
+		}
+
+		crdbConn, sentToClient, err := c.OpenClusterConnWithAuth(ctx, clientConn, dummyHook)
+		require.True(t, openCalled)
+		require.True(t, wrapperCalled)
+		require.True(t, authCalled)
+		require.NoError(t, err)
+		require.False(t, sentToClient)
+		require.Equal(t, conn, crdbConn)
+	})
+}
+
+func TestConnector_openClusterConnInternal(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+
+	bgCtx := context.Background()
+
+	t.Run("context canceled at boundary", func(t *testing.T) {
+		// This is a short test, and is expected to finish within ms.
+		ctx, cancel := context.WithTimeout(bgCtx, 2*time.Second)
+		defer cancel()
+
+		var onLookupEventCalled bool
+		c := &connector{
+			AddrLookupFn: func(ctx context.Context) (string, error) {
+				return "", MarkAsRetriableConnectorError(errors.New("baz"))
+			},
+			OnLookupEvent: func(ctx context.Context, err error) {
+				onLookupEventCalled = true
+				require.EqualError(t, err, "baz")
+
+				// Cancel context to trigger loop exit.
+				cancel()
+			},
+		}
+		conn, err := c.openClusterConnInternal(ctx)
+		require.EqualError(t, err, "baz")
+		require.True(t, errors.Is(err, context.Canceled))
+		require.Nil(t, conn)
+
+		require.True(t, onLookupEventCalled)
+	})
+
+	t.Run("context canceled within loop", func(t *testing.T) {
+		// This is a short test, and is expected to finish within ms.
+		ctx, cancel := context.WithTimeout(bgCtx, 2*time.Second)
+		defer cancel()
+
+		var onLookupEventCalled bool
+		c := &connector{
+			AddrLookupFn: func(ctx context.Context) (string, error) {
+				return "", errors.Wrap(context.Canceled, "foobar")
+			},
+			OnLookupEvent: func(ctx context.Context, err error) {
+				onLookupEventCalled = true
+				require.EqualError(t, err, "foobar: context canceled")
+			},
+		}
+		conn, err := c.openClusterConnInternal(ctx)
+		require.EqualError(t, err, "foobar: context canceled")
+		require.True(t, errors.Is(err, context.Canceled))
+		require.Nil(t, conn)
+
+		require.True(t, onLookupEventCalled)
+	})
+
+	t.Run("non-transient error", func(t *testing.T) {
+		// This is a short test, and is expected to finish within ms.
+		ctx, cancel := context.WithTimeout(bgCtx, 2*time.Second)
+		defer cancel()
+
+		c := &connector{
+			AddrLookupFn: func(ctx context.Context) (string, error) {
+				return "", errors.New("baz")
+			},
+		}
+		conn, err := c.openClusterConnInternal(ctx)
+		require.EqualError(t, err, "baz")
+		require.Nil(t, conn)
+	})
+
+	t.Run("successful", func(t *testing.T) {
+		// This should not take more than 5 seconds.
+		ctx, cancel := context.WithTimeout(bgCtx, 5*time.Second)
+		defer cancel()
+
+		crdbConn, _ := net.Pipe()
+		defer crdbConn.Close()
+
+		// We will exercise the following events:
+		// 1. retriable error on Lookup.
+		// 2. retriable error on Dial.
+		var addrLookupFnCount, dialOutgoingAddrCount int
+		var onLookupEventCalled, onDialEventCalled bool
+		c := &connector{
+			AddrLookupFn: func(ctx context.Context) (string, error) {
+				addrLookupFnCount++
+				if addrLookupFnCount == 1 {
+					return "", MarkAsRetriableConnectorError(errors.New("foo"))
+				}
+				return "127.0.0.10:42", nil
+			},
+			OnLookupEvent: func(ctx context.Context, err error) {
+				onLookupEventCalled = true
+				if addrLookupFnCount == 1 {
+					require.EqualError(t, err, "foo")
+				} else {
+					require.NoError(t, err)
+				}
+			},
+			OnDialEvent: func(ctx context.Context, outgoingAddr string, err error) {
+				onDialEventCalled = true
+				require.Equal(t, "127.0.0.10:42", outgoingAddr)
+				if dialOutgoingAddrCount == 1 {
+					require.EqualError(t, err, "bar")
+				} else {
+					require.NoError(t, err)
+				}
+			},
+		}
+		c.testingKnobs.dialOutgoingAddr = func(outgoingAddr string) (net.Conn, error) {
+			dialOutgoingAddrCount++
+			if dialOutgoingAddrCount == 1 {
+				return nil, MarkAsRetriableConnectorError(errors.New("bar"))
+			}
+			return crdbConn, nil
+		}
+		conn, err := c.openClusterConnInternal(ctx)
+		require.NoError(t, err)
+		require.Equal(t, crdbConn, conn)
+
+		// Assert that the expected number of calls were made.
+		require.Equal(t, 3, addrLookupFnCount)
+		require.Equal(t, 2, dialOutgoingAddrCount)
+		require.True(t, onLookupEventCalled)
+		require.True(t, onDialEventCalled)
+	})
+}
+
+func TestConnector_dialOutgoingAddr(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+
+	t.Run("with tlsConfig", func(t *testing.T) {
+		c := &connector{
+			StartupMsg: &pgproto3.StartupMessage{},
+			TLSConfig:  &tls.Config{InsecureSkipVerify: true},
+		}
+		crdbConn, _ := net.Pipe()
+		defer crdbConn.Close()
+
+		defer testutils.TestingHook(&BackendDial,
+			func(msg *pgproto3.StartupMessage, outgoingAddress string,
+				tlsConfig *tls.Config) (net.Conn, error) {
+				require.Equal(t, c.StartupMsg, msg)
+				require.Equal(t, "10.11.12.13:80", outgoingAddress)
+				require.Equal(t, "10.11.12.13", tlsConfig.ServerName)
+				return crdbConn, nil
+			},
+		)()
+		conn, err := c.dialOutgoingAddr("10.11.12.13:80")
+		require.NoError(t, err)
+		require.Equal(t, crdbConn, conn)
+	})
+
+	t.Run("invalid outgoingAddr with tlsConfig", func(t *testing.T) {
+		c := &connector{
+			StartupMsg: &pgproto3.StartupMessage{},
+			TLSConfig:  &tls.Config{InsecureSkipVerify: true},
+		}
+		conn, err := c.dialOutgoingAddr("!@#$::")
+		require.Error(t, err)
+		require.Regexp(t, "invalid address format", err)
+		require.False(t, IsRetriableConnectorError(err))
+		require.Nil(t, conn)
+	})
+
+	t.Run("without tlsConfig", func(t *testing.T) {
+		c := &connector{StartupMsg: &pgproto3.StartupMessage{}}
+		crdbConn, _ := net.Pipe()
+		defer crdbConn.Close()
+
+		defer testutils.TestingHook(&BackendDial,
+			func(msg *pgproto3.StartupMessage, outgoingAddress string,
+				tlsConfig *tls.Config) (net.Conn, error) {
+				require.Equal(t, c.StartupMsg, msg)
+				require.Equal(t, "10.11.12.13:1234", outgoingAddress)
+				require.Nil(t, tlsConfig)
+				return crdbConn, nil
+			},
+		)()
+		conn, err := c.dialOutgoingAddr("10.11.12.13:1234")
+		require.NoError(t, err)
+		require.Equal(t, crdbConn, conn)
+	})
+
+	t.Run("failed to dial with non-transient error", func(t *testing.T) {
+		c := &connector{StartupMsg: &pgproto3.StartupMessage{}}
+		defer testutils.TestingHook(&BackendDial,
+			func(msg *pgproto3.StartupMessage, outgoingAddress string,
+				tlsConfig *tls.Config) (net.Conn, error) {
+				require.Equal(t, c.StartupMsg, msg)
+				require.Equal(t, "127.0.0.1:1234", outgoingAddress)
+				require.Nil(t, tlsConfig)
+				return nil, errors.New("foo")
+			},
+		)()
+		conn, err := c.dialOutgoingAddr("127.0.0.1:1234")
+		require.EqualError(t, err, "foo")
+		require.False(t, IsRetriableConnectorError(err))
+		require.Nil(t, conn)
+	})
+
+	t.Run("failed to dial with transient error", func(t *testing.T) {
+		c := &connector{StartupMsg: &pgproto3.StartupMessage{}}
+		defer testutils.TestingHook(&BackendDial,
+			func(msg *pgproto3.StartupMessage, outgoingAddress string,
+				tlsConfig *tls.Config) (net.Conn, error) {
+				require.Equal(t, c.StartupMsg, msg)
+				require.Equal(t, "127.0.0.2:4567", outgoingAddress)
+				require.Nil(t, tlsConfig)
+				return nil, newErrorf(codeBackendDown, "bar")
+			},
+		)()
+		conn, err := c.dialOutgoingAddr("127.0.0.2:4567")
+		require.EqualError(t, err, "codeBackendDown: bar")
+		require.True(t, IsRetriableConnectorError(err))
+		require.Nil(t, conn)
+	})
+}
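The proxy handler changes below rely on the sentinel marking added in connector.go: transient directory-lookup and dial failures are marked so the connector's retry loop can still recognize them after wrapping. A standalone sketch of that cockroachdb/errors pattern (illustrative only; the names errTransient, lookupOnce, and shouldRetry are not part of this patch):

// Illustrative only: errors.Mark attaches a sentinel that errors.Is can still
// find after the error is wrapped, which is what MarkAsRetriableConnectorError
// and IsRetriableConnectorError do above.
var errTransient = errors.New("transient")

func lookupOnce() error {
	err := errors.Mark(errors.New("pod draining"), errTransient)
	// Wrapping does not hide the mark.
	return errors.Wrap(err, "lookup failed")
}

func shouldRetry(err error) bool {
	return errors.Is(err, errTransient)
}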
"github.com/cockroachdb/cockroach/pkg/util/netutil/addr" - "github.com/cockroachdb/cockroach/pkg/util/retry" "github.com/cockroachdb/cockroach/pkg/util/stop" "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/errors" @@ -275,86 +274,62 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn return err } - var crdbConn net.Conn - var outgoingAddress string - - // Repeatedly try to make a connection. Any failures are assumed to be - // transient unless the tenant cannot be found (e.g. because it was - // deleted). We will simply loop forever, or until the context is canceled - // (e.g. by client disconnect). This is preferable to terminating client - // connections, because in most cases those connections will simply be - // retried, further increasing load on the system. - retryOpts := retry.Options{ - InitialBackoff: 10 * time.Millisecond, - MaxBackoff: 5 * time.Second, - } - outgoingAddressErr := log.Every(time.Minute) backendDialErr := log.Every(time.Minute) reportFailureErr := log.Every(time.Minute) var outgoingAddressErrs, codeBackendDownErrs, reportFailureErrs int - for r := retry.StartWithCtx(ctx, retryOpts); r.Next(); { - // Get the DNS/IP address of the backend server to dial. - outgoingAddress, err = handler.outgoingAddress(ctx, clusterName, tenID) - if err != nil { - // Failure is assumed to be transient (and should be retried) except - // in case where the server was not found. + // TLS options for the proxy are split into Insecure and SkipVerify. + // In insecure mode, tlsConf is expected to be nil. This will cause the + // connector's dialer to skip TLS entirely. If SkipVerify is true, + // tlsConf will be set to a non-nil config with InsecureSkipVerify set + // to true. InsecureSkipVerify will provide an encrypted connection but + // not verify that the connection recipient is a trusted party. + var tlsConf *tls.Config + if !handler.Insecure { + tlsConf = &tls.Config{InsecureSkipVerify: handler.SkipVerify} + } + connector := &connector{ + StartupMsg: backendStartupMsg, + TLSConfig: tlsConf, + AddrLookupFn: func(ctx context.Context) (string, error) { + addr, err := handler.outgoingAddress(ctx, clusterName, tenID) + if err == nil { + return addr, nil + } + // Transient error when retrieving outgoing address. if status.Code(err) != codes.NotFound { - outgoingAddressErrs++ - if outgoingAddressErr.ShouldLog() { - log.Ops.Errorf(ctx, - "outgoing address (%d errors skipped): %v", - outgoingAddressErrs, - err, - ) - outgoingAddressErrs = 0 - } - continue + return "", MarkAsRetriableConnectorError(err) } - - // Remap error for external consumption. - log.Errorf(ctx, "could not retrieve outgoing address: %v", err.Error()) - err = newErrorf( - codeParamsRoutingFailed, "cluster %s-%d not found", clusterName, tenID.ToUint64()) - break - } - - // NB: TLS options for the proxy are split into Insecure and - // SkipVerify. In insecure mode, tlsConf is expected to be nil. This - // will cause BackendDial to skip TLS entirely. If SkipVerify is true, - // tlsConf will be set to a non-nil config with InsecureSkipVerify set - // to true. InsecureSkipVerify will provide an encrypted connection but - // not verify that the connection recipient is a trusted party. - var tlsConf *tls.Config - if !handler.Insecure { - // Use an empty string as the default port as we only care about the - // correctly parsing the outgoingHost/IP here. 
-			outgoingHost, _, err := addr.SplitHostPort(outgoingAddress, "")
-			if err != nil {
-				log.Errorf(ctx, "could not split outgoing address '%s' into host and port: %v", outgoingAddress, err.Error())
-				// Remap error for external consumption.
-				clientErr := newErrorf(
-					codeParamsRoutingFailed, "cluster %s-%d not found", clusterName, tenID.ToUint64())
-				updateMetricsAndSendErrToClient(clientErr, conn, handler.metrics)
-				return clientErr
+			// Don't retry if we get a NotFound error.
+			return "", newErrorf(codeParamsRoutingFailed,
+				"cluster %s-%d not found", clusterName, tenID.ToUint64())
+		},
+		AuthenticateFn: authenticate,
+		OnLookupEvent: func(ctx context.Context, err error) {
+			// We only care about retriable errors since we want to log them.
+			if !IsRetriableConnectorError(err) {
+				outgoingAddressErrs = 0
+				return
 			}
-
-			tlsConf = &tls.Config{
-				// Always set ServerName, if SkipVerify is true, it will be
-				// ignored. When SkipVerify is false, it is required to
-				// establish a TLS connection.
-				ServerName:         outgoingHost,
-				InsecureSkipVerify: handler.SkipVerify,
+			outgoingAddressErrs++
+			if outgoingAddressErr.ShouldLog() {
+				log.Ops.Errorf(ctx,
+					"outgoing address (%d errors skipped): %v",
+					outgoingAddressErrs,
+					err,
+				)
+				outgoingAddressErrs = 0
+			}
+		},
+		OnDialEvent: func(ctx context.Context, outgoingAddr string, err error) {
+			// We only care about retriable errors since we want to log them,
+			// and report errors to the tenant directory.
+			if !IsRetriableConnectorError(err) {
+				codeBackendDownErrs = 0
+				reportFailureErrs = 0
+				return
 			}
-		}
-
-		// Now actually dial the backend server.
-		crdbConn, err = BackendDial(backendStartupMsg, outgoingAddress, tlsConf)
-
-		// If we get a backend down error, retry the connection.
-		var codeErr *codeError
-		if err != nil && errors.As(err, &codeErr) && codeErr.code == codeBackendDown {
 			codeBackendDownErrs++
 			if backendDialErr.ShouldLog() {
 				log.Ops.Errorf(ctx,
@@ -364,12 +339,12 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn
 					)
 					codeBackendDownErrs = 0
 				}
-
 			if handler.directory != nil {
 				// Report the failure to the directory so that it can refresh any
 				// stale information that may have caused the problem.
-				err = reportFailureToDirectory(ctx, tenID, outgoingAddress, handler.directory)
-				if err != nil {
+				if err = reportFailureToDirectory(
+					ctx, tenID, outgoingAddr, handler.directory,
+				); err != nil {
 					reportFailureErrs++
 					if reportFailureErr.ShouldLog() {
 						log.Ops.Errorf(ctx,
@@ -381,42 +356,43 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn
 					}
 				}
 			}
-			continue
-		}
-		break
-	}
-
-	if err != nil {
-		updateMetricsAndSendErrToClient(err, conn, handler.metrics)
-		return err
+		},
 	}
 
 	// Monitor for idle connection, if requested.
 	if handler.idleMonitor != nil {
-		crdbConn = handler.idleMonitor.DetectIdle(crdbConn, func() {
-			err := newErrorf(codeIdleDisconnect, "idle connection closed")
-			select {
-			case errConnection <- err: /* error reported */
-			default: /* the channel already contains an error */
-			}
-		})
+		connector.IdleMonitorWrapperFn = func(crdbConn net.Conn) net.Conn {
+			return handler.idleMonitor.DetectIdle(crdbConn, func() {
+				err := newErrorf(codeIdleDisconnect, "idle connection closed")
+				select {
+				case errConnection <- err: /* error reported */
+				default: /* the channel already contains an error */
+				}
+			})
+		}
 	}
-	defer func() { _ = crdbConn.Close() }()
-
-	// Perform user authentication.
-	if err := authenticate(conn, crdbConn, func(status throttler.AttemptStatus) error {
-		err := handler.throttleService.ReportAttempt(ctx, throttleTags, throttleTime, status)
-		if err != nil {
-			log.Errorf(ctx, "throttler refused connection after authentication: %v", err.Error())
-			return throttledError
+	crdbConn, sentToClient, err := connector.OpenClusterConnWithAuth(ctx, conn,
+		func(status throttler.AttemptStatus) error {
+			if err := handler.throttleService.ReportAttempt(
+				ctx, throttleTags, throttleTime, status,
+			); err != nil {
+				log.Errorf(ctx, "throttler refused connection after authentication: %v", err.Error())
+				return throttledError
+			}
+			return nil
+		},
+	)
+	if err != nil {
+		log.Errorf(ctx, "could not connect to cluster: %v", err.Error())
+		if sentToClient {
+			handler.metrics.updateForError(err)
+		} else {
+			updateMetricsAndSendErrToClient(err, conn, handler.metrics)
 		}
-		return nil
-	}); err != nil {
-		handler.metrics.updateForError(err)
-		log.Ops.Errorf(ctx, "authenticate: %s", err)
 		return err
 	}
+	defer crdbConn.Close()
 
 	handler.metrics.SuccessfulConnCount.Inc(1)