Skip to content

Commit

Permalink
ccl/sqlproxyccl: handle implicit auth in OpenTenantConnWithToken
Browse files Browse the repository at this point in the history
Informs cockroachdb#76000. Extracted from cockroachdb#76805.

Previously, we assumed that with the token-based authentication, the server is
ready to accept queries the moment we connect to it. This assumption has been
proved wrong during the integration tests with the forwarder, and there's
an implicit AuthenticateOK step that happens after connecting to the server.
During that time, initial connection data such as ParameterStatus and
BackendKeyData messages will be sent to the client as well. For now, we will
ignore those messages. Once we start implementing query cancellation within
the proxy, that has to be updated to cache the new BackendKeyData.

This commit also fixes a buglet to handle pgwire messages with no body.
pgproto3's Receive methods will still call Next if the body size is 0, and
previously, we were returning an io.EOF error. This commit changes that
behavior to return an empty slice.

Release note: None

Release justification: This fixes two buglets: one that was introduced when we
added token-based authentication support to the proxy in cockroachdb#76417, and another
when we added the interceptors. This is low risk as part of the code is
guarded behind the connection migration feature, which is currently not being
used in production. To add on, CockroachCloud is the only user of sqlproxy.
  • Loading branch information
jaylim-crl authored and RajivTS committed Mar 6, 2022
1 parent 846eea5 commit b61d605
Show file tree
Hide file tree
Showing 6 changed files with 212 additions and 3 deletions.
54 changes: 54 additions & 0 deletions pkg/ccl/sqlproxyccl/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ package sqlproxyccl
import (
"net"

"github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/interceptor"
"github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/throttler"
pgproto3 "github.com/jackc/pgproto3/v2"
)
Expand Down Expand Up @@ -133,3 +134,56 @@ var authenticate = func(clientConn, crdbConn net.Conn, throttleHook func(throttl
}
return newErrorf(codeBackendDisconnected, "authentication took more than %d iterations", i)
}

// readTokenAuthResult reads the result for the token-based authentication, and
// assumes that the connection credentials have already been transmitted to the
// server (as part of the startup message). If authentication fails, this will
// return an error.
//
// NOTE: For now, this also reads the initial connection data
// (i.e. ParameterStatus and BackendKeyData) until we see a ReadyForQuery
// message. All messages will be discarded, and this is fine because we only
// call this during connection migration, and the proxy is the client. Once we
// address the TODO below, we could generalize it such that the client here is
// a no-op client.
//
// TODO(jaylim-crl): We should extract out the initial connection data stuff
// into a readInitialConnData function that `authenticate` can also use. It was
// a mistake to split reader (interceptor) and writer (net.Conn), and I think
// we should merge them back in the future. Instead of having the writer as the
// other end, the writer should be the same connection. That way, a
// sqlproxyccl.Conn can be used to read-from, or write-to the same component.
var readTokenAuthResult = func(serverConn net.Conn) error {
// This interceptor is discarded once this function returns. Just like
// pgproto3.NewFrontend, this interceptor has an internal buffer.
// Discarding the buffer is fine since there won't be any other messages
// from the server once we receive the ReadyForQuery message because the
// caller (i.e. proxy) does not forward client messages until then.
serverInterceptor := interceptor.NewFrontendInterceptor(serverConn)

// The auth step should require only a few back and forths so 20 iterations
// should be enough.
var i int
for ; i < 20; i++ {
backendMsg, err := serverInterceptor.ReadMsg()
if err != nil {
return newErrorf(codeBackendReadFailed, "unable to receive message from backend: %v", err)
}

switch tp := backendMsg.(type) {
case *pgproto3.AuthenticationOk, *pgproto3.ParameterStatus, *pgproto3.BackendKeyData:
// Do nothing.

case *pgproto3.ErrorResponse:
return newErrorf(codeAuthFailed, "authentication failed: %s", tp.Message)

case *pgproto3.ReadyForQuery:
return nil

default:
return newErrorf(codeBackendDisconnected, "received unexpected backend message type: %v", tp)
}
}

return newErrorf(codeBackendDisconnected, "authentication took more than %d iterations", i)
}
54 changes: 54 additions & 0 deletions pkg/ccl/sqlproxyccl/authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,57 @@ func TestAuthenticateUnexpectedMessage(t *testing.T) {
require.True(t, errors.As(err, &codeErr))
require.Equal(t, codeBackendDisconnected, codeErr.code)
}

func TestReadTokenAuthResult(t *testing.T) {
defer leaktest.AfterTest(t)()

t.Run("unexpected message", func(t *testing.T) {
cli, srv := net.Pipe()

go func() {
_, err := srv.Write((&pgproto3.BindComplete{}).Encode(nil))
require.NoError(t, err)
}()

err := readTokenAuthResult(cli)
require.Error(t, err)
codeErr := (*codeError)(nil)
require.True(t, errors.As(err, &codeErr))
require.Equal(t, codeBackendDisconnected, codeErr.code)
})

t.Run("error_response", func(t *testing.T) {
cli, srv := net.Pipe()

go func() {
_, err := srv.Write((&pgproto3.ErrorResponse{Severity: "FATAL", Code: "foo"}).Encode(nil))
require.NoError(t, err)
}()

err := readTokenAuthResult(cli)
require.Error(t, err)
codeErr := (*codeError)(nil)
require.True(t, errors.As(err, &codeErr))
require.Equal(t, codeAuthFailed, codeErr.code)
})

t.Run("successful", func(t *testing.T) {
cli, srv := net.Pipe()

go func() {
_, err := srv.Write((&pgproto3.AuthenticationOk{}).Encode(nil))
require.NoError(t, err)

_, err = srv.Write((&pgproto3.ParameterStatus{Name: "Server Version", Value: "1.3"}).Encode(nil))
require.NoError(t, err)

_, err = srv.Write((&pgproto3.BackendKeyData{ProcessID: uint32(42)}).Encode(nil))
require.NoError(t, err)

_, err = srv.Write((&pgproto3.ReadyForQuery{}).Encode(nil))
require.NoError(t, err)
}()

require.NoError(t, readTokenAuthResult(cli))
})
}
25 changes: 23 additions & 2 deletions pkg/ccl/sqlproxyccl/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,10 @@ type connector struct {
}

// OpenTenantConnWithToken opens a connection to the tenant cluster using the
// token-based authentication.
func (c *connector) OpenTenantConnWithToken(ctx context.Context, token string) (net.Conn, error) {
// token-based authentication during connection migration.
func (c *connector) OpenTenantConnWithToken(
ctx context.Context, token string,
) (retServerConn net.Conn, retErr error) {
c.StartupMsg.Parameters[sessionRevivalTokenStartupParam] = token
defer func() {
// Delete token after return.
Expand All @@ -145,9 +147,27 @@ func (c *connector) OpenTenantConnWithToken(ctx context.Context, token string) (
if err != nil {
return nil, err
}
defer func() {
if retErr != nil {
serverConn.Close()
}
}()

if c.IdleMonitorWrapperFn != nil {
serverConn = c.IdleMonitorWrapperFn(serverConn)
}

// When we use token-based authentication, we will still get the initial
// connection data messages (e.g. ParameterStatus and BackendKeyData).
// Since this method is only used during connection migration (i.e. proxy
// is connecting to the SQL pod), we'll discard all of the messages, and
// only return once we've seen a ReadyForQuery message.
//
// NOTE: This will need to be updated when we implement query cancellation.
if err := readTokenAuthResult(serverConn); err != nil {
return nil, err
}
log.Infof(ctx, "connected to %s through token-based auth", serverConn.RemoteAddr())
return serverConn, nil
}

Expand Down Expand Up @@ -188,6 +208,7 @@ func (c *connector) OpenTenantConnWithAuth(
if err := authenticate(clientConn, serverConn, throttleHook); err != nil {
return nil, true, err
}
log.Infof(ctx, "connected to %s through normal auth", serverConn.RemoteAddr())
return serverConn, false, nil
}

Expand Down
67 changes: 66 additions & 1 deletion pkg/ccl/sqlproxyccl/connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func TestConnector_OpenTenantConnWithToken(t *testing.T) {
const token = "foobarbaz"
ctx := context.Background()

t.Run("error", func(t *testing.T) {
t.Run("error during open", func(t *testing.T) {
c := &connector{
StartupMsg: &pgproto3.StartupMessage{
Parameters: make(map[string]string),
Expand All @@ -53,6 +53,48 @@ func TestConnector_OpenTenantConnWithToken(t *testing.T) {
require.Equal(t, "", str)
})

t.Run("error during auth", func(t *testing.T) {
c := &connector{
StartupMsg: &pgproto3.StartupMessage{
Parameters: make(map[string]string),
},
}
conn, _ := net.Pipe()
defer conn.Close()

var openCalled bool
c.testingKnobs.dialTenantCluster = func(ctx context.Context) (net.Conn, error) {
openCalled = true

// Validate that token is set.
str, ok := c.StartupMsg.Parameters[sessionRevivalTokenStartupParam]
require.True(t, ok)
require.Equal(t, token, str)

return conn, nil
}

defer testutils.TestingHook(
&readTokenAuthResult,
func(serverConn net.Conn) error {
return errors.New("bar")
},
)()

crdbConn, err := c.OpenTenantConnWithToken(ctx, token)
require.True(t, openCalled)
require.EqualError(t, err, "bar")
require.Nil(t, crdbConn)

// Ensure that token is deleted.
_, ok := c.StartupMsg.Parameters[sessionRevivalTokenStartupParam]
require.False(t, ok)

// Connection should be closed.
_, err = conn.Write([]byte("foo"))
require.Regexp(t, "closed pipe", err)
})

t.Run("successful", func(t *testing.T) {
c := &connector{
StartupMsg: &pgproto3.StartupMessage{
Expand All @@ -74,8 +116,19 @@ func TestConnector_OpenTenantConnWithToken(t *testing.T) {
return conn, nil
}

var authCalled bool
defer testutils.TestingHook(
&readTokenAuthResult,
func(serverConn net.Conn) error {
authCalled = true
require.Equal(t, conn, serverConn)
return nil
},
)()

crdbConn, err := c.OpenTenantConnWithToken(ctx, token)
require.True(t, openCalled)
require.True(t, authCalled)
require.NoError(t, err)
require.Equal(t, conn, crdbConn)

Expand Down Expand Up @@ -111,9 +164,20 @@ func TestConnector_OpenTenantConnWithToken(t *testing.T) {
return conn, nil
}

var authCalled bool
defer testutils.TestingHook(
&readTokenAuthResult,
func(serverConn net.Conn) error {
authCalled = true
require.Equal(t, conn, serverConn)
return nil
},
)()

crdbConn, err := c.OpenTenantConnWithToken(ctx, token)
require.True(t, wrapperCalled)
require.True(t, openCalled)
require.True(t, authCalled)
require.NoError(t, err)
require.Equal(t, conn, crdbConn)

Expand Down Expand Up @@ -641,6 +705,7 @@ func TestRetriableConnectorError(t *testing.T) {
require.False(t, isRetriableConnectorError(err))
err = markAsRetriableConnectorError(err)
require.True(t, isRetriableConnectorError(err))
require.True(t, errors.Is(err, errRetryConnectorSentinel))
}

var _ TenantResolver = &testTenantResolver{}
Expand Down
5 changes: 5 additions & 0 deletions pkg/ccl/sqlproxyccl/interceptor/chunkreader.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ func newChunkReader(msg []byte) pgproto3.ChunkReader {
// returned once the entire message has been read. If the caller tries to read
// more bytes than it could, an errInvalidRead will be returned.
func (cr *chunkReader) Next(n int) (buf []byte, err error) {
// pgproto3's Receive methods will still invoke Next even if the body size
// is 0. We shouldn't return an EOF in that case.
if n == 0 {
return []byte{}, nil
}
if cr.pos == len(cr.msg) {
return nil, io.EOF
}
Expand Down
10 changes: 10 additions & 0 deletions pkg/ccl/sqlproxyccl/interceptor/chunkreader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,21 @@ func TestChunkReader(t *testing.T) {
require.EqualError(t, err, errInvalidRead.Error())
require.Nil(t, buf)

// Attempt n = 0 before EOF.
buf, err = cr.Next(0)
require.NoError(t, err)
require.Len(t, buf, 0)

buf, err = cr.Next(11)
require.NoError(t, err)
require.Equal(t, "hello world", string(buf))

buf, err = cr.Next(1)
require.EqualError(t, err, io.EOF.Error())
require.Nil(t, buf)

// Attempting n = 0 after EOF returns nothing instead of an error.
buf, err = cr.Next(0)
require.NoError(t, err)
require.Len(t, buf, 0)
}

0 comments on commit b61d605

Please sign in to comment.