From 8b2c82d2c32714663e0b8317eae3a8ab2851a7e7 Mon Sep 17 00:00:00 2001 From: Jay Date: Sun, 27 Feb 2022 17:18:08 -0500 Subject: [PATCH] ccl/sqlproxyccl: handle implicit auth in OpenTenantConnWithToken Informs #76000. Extracted from #76805. Previously, we assumed that with token-based authentication, the server was ready to accept queries the moment we connected to it. This assumption was proven wrong during the integration tests with the forwarder: there is an implicit AuthenticationOk step that happens after connecting to the server. During that time, initial connection data such as ParameterStatus and BackendKeyData messages are sent to the client (here, the proxy) as well. For now, we will ignore those messages. Once we start implementing query cancellation within the proxy, this handling will have to be updated to cache the new BackendKeyData. Release note: None Release justification: This fixes a bug that was introduced when we added token-based authentication support to the proxy in #76417. This is low risk, as the code is guarded behind the connection migration feature, which is currently not being used in production. In addition, CockroachCloud is the only user of sqlproxy. 
--- pkg/ccl/sqlproxyccl/authentication.go | 40 +++++++++++++ pkg/ccl/sqlproxyccl/authentication_test.go | 57 ++++++++++++++++++ pkg/ccl/sqlproxyccl/connector.go | 25 +++++++- pkg/ccl/sqlproxyccl/connector_test.go | 67 +++++++++++++++++++++- 4 files changed, 186 insertions(+), 3 deletions(-) diff --git a/pkg/ccl/sqlproxyccl/authentication.go b/pkg/ccl/sqlproxyccl/authentication.go index 61aebb6012cd..ac5a495d2f92 100644 --- a/pkg/ccl/sqlproxyccl/authentication.go +++ b/pkg/ccl/sqlproxyccl/authentication.go @@ -133,3 +133,43 @@ var authenticate = func(clientConn, crdbConn net.Conn, throttleHook func(throttl } return newErrorf(codeBackendDisconnected, "authentication took more than %d iterations", i) } + +// implicitAuthenticate assumes that the connection credentials have already +// been transmitted to the server, and discards all messages until we get a +// ReadyForQuery message. If authentication fails, this will return an error. +// +// Discarding messages is fine because this will only be used during connection +// migration with the token-based authentication, and the proxy is the client. +var implicitAuthenticate = func(crdbConn net.Conn) error { + // Use pgproto3 directly for now even though there is an internal buffer + // within the chunkreader. This is fine since there won't be any other + // messages from the server once we receive the ReadyForQuery message. This + // is the same approach as the one used in the authenticate function above. + be := pgproto3.NewFrontend(pgproto3.NewChunkReader(crdbConn), crdbConn) + + // The auth step should require only a few back and forths so 20 iterations + // should be enough. + var i int + for ; i < 20; i++ { + backendMsg, err := be.Receive() + if err != nil { + return newErrorf(codeBackendReadFailed, "unable to receive message from backend: %v", err) + } + + switch tp := backendMsg.(type) { + case *pgproto3.AuthenticationOk, *pgproto3.ParameterStatus, *pgproto3.BackendKeyData: + // Do nothing. 
+ + case *pgproto3.ErrorResponse: + return newErrorf(codeAuthFailed, "authentication failed: %s", tp.Message) + + case *pgproto3.ReadyForQuery: + return nil + + default: + return newErrorf(codeBackendDisconnected, "received unexpected backend message type: %v", tp) + } + } + + return newErrorf(codeBackendDisconnected, "authentication took more than %d iterations", i) +} diff --git a/pkg/ccl/sqlproxyccl/authentication_test.go b/pkg/ccl/sqlproxyccl/authentication_test.go index ab791402375f..7c09dd19a6eb 100644 --- a/pkg/ccl/sqlproxyccl/authentication_test.go +++ b/pkg/ccl/sqlproxyccl/authentication_test.go @@ -202,3 +202,60 @@ func TestAuthenticateUnexpectedMessage(t *testing.T) { require.True(t, errors.As(err, &codeErr)) require.Equal(t, codeBackendDisconnected, codeErr.code) } + +func TestImplicitAuthenticate(t *testing.T) { + defer leaktest.AfterTest(t)() + + t.Run("unexpected message", func(t *testing.T) { + cli, srv := net.Pipe() + be := pgproto3.NewBackend(pgproto3.NewChunkReader(srv), srv) + + go func() { + err := be.Send(&pgproto3.BindComplete{}) + require.NoError(t, err) + }() + + err := implicitAuthenticate(cli) + require.Error(t, err) + codeErr := (*codeError)(nil) + require.True(t, errors.As(err, &codeErr)) + require.Equal(t, codeBackendDisconnected, codeErr.code) + }) + + t.Run("error_response", func(t *testing.T) { + cli, srv := net.Pipe() + be := pgproto3.NewBackend(pgproto3.NewChunkReader(srv), srv) + + go func() { + err := be.Send(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "foo"}) + require.NoError(t, err) + }() + + err := implicitAuthenticate(cli) + require.Error(t, err) + codeErr := (*codeError)(nil) + require.True(t, errors.As(err, &codeErr)) + require.Equal(t, codeAuthFailed, codeErr.code) + }) + + t.Run("successful", func(t *testing.T) { + cli, srv := net.Pipe() + be := pgproto3.NewBackend(pgproto3.NewChunkReader(srv), srv) + + go func() { + err := be.Send(&pgproto3.AuthenticationOk{}) + require.NoError(t, err) + + err = 
be.Send(&pgproto3.ParameterStatus{Name: "Server Version", Value: "1.3"}) + require.NoError(t, err) + + err = be.Send(&pgproto3.BackendKeyData{ProcessID: uint32(42)}) + require.NoError(t, err) + + err = be.Send(&pgproto3.ReadyForQuery{}) + require.NoError(t, err) + }() + + require.NoError(t, implicitAuthenticate(cli)) + }) +} diff --git a/pkg/ccl/sqlproxyccl/connector.go b/pkg/ccl/sqlproxyccl/connector.go index 32a9f5c0c8de..377383477a9c 100644 --- a/pkg/ccl/sqlproxyccl/connector.go +++ b/pkg/ccl/sqlproxyccl/connector.go @@ -133,8 +133,10 @@ type connector struct { } // OpenTenantConnWithToken opens a connection to the tenant cluster using the -// token-based authentication. -func (c *connector) OpenTenantConnWithToken(ctx context.Context, token string) (net.Conn, error) { +// token-based authentication during connection migration. +func (c *connector) OpenTenantConnWithToken( + ctx context.Context, token string, +) (retServerConn net.Conn, retErr error) { c.StartupMsg.Parameters[sessionRevivalTokenStartupParam] = token defer func() { // Delete token after return. @@ -145,9 +147,27 @@ func (c *connector) OpenTenantConnWithToken(ctx context.Context, token string) ( if err != nil { return nil, err } + defer func() { + if retErr != nil { + serverConn.Close() + } + }() + if c.IdleMonitorWrapperFn != nil { serverConn = c.IdleMonitorWrapperFn(serverConn) } + + // When we use token-based authentication, we will still get the initial + // connection data messages (e.g. ParameterStatus and BackendKeyData). + // Since this method is only used during connection migration (i.e. proxy + // is connecting to the SQL pod), we'll discard all of the messages, and + // only return once we've seen a ReadyForQuery message. + // + // NOTE: This will need to be updated when we implement query cancellation. 
+ if err := implicitAuthenticate(serverConn); err != nil { + return nil, err + } + log.Infof(ctx, "connected to %s through token-based auth", serverConn.RemoteAddr()) return serverConn, nil } @@ -188,6 +208,7 @@ func (c *connector) OpenTenantConnWithAuth( if err := authenticate(clientConn, serverConn, throttleHook); err != nil { return nil, true, err } + log.Infof(ctx, "connected to %s through normal auth", serverConn.RemoteAddr()) return serverConn, false, nil } diff --git a/pkg/ccl/sqlproxyccl/connector_test.go b/pkg/ccl/sqlproxyccl/connector_test.go index 5fb2aab568f3..d7e2719fefb3 100644 --- a/pkg/ccl/sqlproxyccl/connector_test.go +++ b/pkg/ccl/sqlproxyccl/connector_test.go @@ -33,7 +33,7 @@ func TestConnector_OpenTenantConnWithToken(t *testing.T) { const token = "foobarbaz" ctx := context.Background() - t.Run("error", func(t *testing.T) { + t.Run("error during open", func(t *testing.T) { c := &connector{ StartupMsg: &pgproto3.StartupMessage{ Parameters: make(map[string]string), @@ -53,6 +53,48 @@ func TestConnector_OpenTenantConnWithToken(t *testing.T) { require.Equal(t, "", str) }) + t.Run("error during auth", func(t *testing.T) { + c := &connector{ + StartupMsg: &pgproto3.StartupMessage{ + Parameters: make(map[string]string), + }, + } + conn, _ := net.Pipe() + defer conn.Close() + + var openCalled bool + c.testingKnobs.dialTenantCluster = func(ctx context.Context) (net.Conn, error) { + openCalled = true + + // Validate that token is set. + str, ok := c.StartupMsg.Parameters[sessionRevivalTokenStartupParam] + require.True(t, ok) + require.Equal(t, token, str) + + return conn, nil + } + + defer testutils.TestingHook( + &implicitAuthenticate, + func(serverConn net.Conn) error { + return errors.New("bar") + }, + )() + + crdbConn, err := c.OpenTenantConnWithToken(ctx, token) + require.True(t, openCalled) + require.EqualError(t, err, "bar") + require.Nil(t, crdbConn) + + // Ensure that token is deleted. 
+ _, ok := c.StartupMsg.Parameters[sessionRevivalTokenStartupParam] + require.False(t, ok) + + // Connection should be closed. + _, err = conn.Write([]byte("foo")) + require.Regexp(t, "closed pipe", err) + }) + t.Run("successful", func(t *testing.T) { c := &connector{ StartupMsg: &pgproto3.StartupMessage{ @@ -74,8 +116,19 @@ func TestConnector_OpenTenantConnWithToken(t *testing.T) { return conn, nil } + var authCalled bool + defer testutils.TestingHook( + &implicitAuthenticate, + func(serverConn net.Conn) error { + authCalled = true + require.Equal(t, conn, serverConn) + return nil + }, + )() + crdbConn, err := c.OpenTenantConnWithToken(ctx, token) require.True(t, openCalled) + require.True(t, authCalled) require.NoError(t, err) require.Equal(t, conn, crdbConn) @@ -111,9 +164,20 @@ func TestConnector_OpenTenantConnWithToken(t *testing.T) { return conn, nil } + var authCalled bool + defer testutils.TestingHook( + &implicitAuthenticate, + func(serverConn net.Conn) error { + authCalled = true + require.Equal(t, conn, serverConn) + return nil + }, + )() + crdbConn, err := c.OpenTenantConnWithToken(ctx, token) require.True(t, wrapperCalled) require.True(t, openCalled) + require.True(t, authCalled) require.NoError(t, err) require.Equal(t, conn, crdbConn) @@ -641,6 +705,7 @@ func TestRetriableConnectorError(t *testing.T) { require.False(t, isRetriableConnectorError(err)) err = markAsRetriableConnectorError(err) require.True(t, isRetriableConnectorError(err)) + require.True(t, errors.Is(err, errRetryConnectorSentinel)) } var _ TenantResolver = &testTenantResolver{}