Skip to content

Commit

Permalink
sqlproxyccl: rework sqlproxy connection throttler
Browse files Browse the repository at this point in the history
This change switches the sqlproxy connection throttling logic back to
exponential backoff. The tokenbucket approach was introduced by
PR #69041. There are a few behavior differences between this and the
original exponential backoff implementation.

1. The throttling logic is maintained per (ip, tenant) instead of per
   (ip). Some platform as a service provides share a single outbound ip
   address between multiple clients. These users would occasionaly see
   throttling caused by a second user sharing their IP.
2. The throttling logic was triggered before there was an authentication
   failure. It takes ~100ms-1000ms to authenticate with the tenant
   process.  Any requests that arrived after the first request, but
   before it was processed, would trigger the throttle. Now, we only
   trigger the throttle in response to an explict authorization error.

Release note: None
  • Loading branch information
jeffswenson committed Oct 13, 2021
1 parent 41ee2e9 commit 2559b4b
Show file tree
Hide file tree
Showing 17 changed files with 526 additions and 459 deletions.
1 change: 1 addition & 0 deletions pkg/ccl/sqlproxyccl/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ go_test(
"//pkg/ccl/sqlproxyccl/denylist",
"//pkg/ccl/sqlproxyccl/tenant",
"//pkg/ccl/sqlproxyccl/tenantdirsvr",
"//pkg/ccl/sqlproxyccl/throttler",
"//pkg/ccl/utilccl",
"//pkg/roachpb:with-mocks",
"//pkg/security",
Expand Down
93 changes: 64 additions & 29 deletions pkg/ccl/sqlproxyccl/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,25 @@ package sqlproxyccl
import (
"net"

"github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/throttler"
"github.com/jackc/pgproto3/v2"
)

// authenticate handles the startup of the pgwire protocol to the point where
// the connections is considered authenticated. If that doesn't happen, it
// returns an error.
var authenticate = func(clientConn, crdbConn net.Conn) error {
var authenticate = func(clientConn, crdbConn net.Conn, throttleHook func(throttler.AttemptStatus) *pgproto3.ErrorResponse) error {
fe := pgproto3.NewBackend(pgproto3.NewChunkReader(clientConn), clientConn)
be := pgproto3.NewFrontend(pgproto3.NewChunkReader(crdbConn), crdbConn)

feSend := func(msg pgproto3.BackendMessage) error {
err := fe.Send(msg)
if err != nil {
return newErrorf(codeClientWriteFailed, "unable to send message %v to client: %v", msg, err)
}
return nil
}

// The auth step should require only a few back and forths so 20 iterations
// should be enough.
var i int
Expand All @@ -32,39 +41,18 @@ var authenticate = func(clientConn, crdbConn net.Conn) error {
return newErrorf(codeBackendReadFailed, "unable to receive message from backend: %v", err)
}

err = fe.Send(backendMsg)
if err != nil {
return newErrorf(
codeClientWriteFailed, "unable to send message %v to client: %v", backendMsg, err,
)
}

// Decide what to do based on the type of the server response.
// The cases in this switch are roughly sorted in the order the server will send them.
switch tp := backendMsg.(type) {
case *pgproto3.ReadyForQuery:
// Server has authenticated the connection successfully and is ready to
// serve queries.
return nil
case *pgproto3.AuthenticationOk:
// Server has authenticated the connection; keep reading messages until
// `pgproto3.ReadyForQuery` is encountered which signifies that server
// is ready to serve queries.
case *pgproto3.ParameterStatus:
// Server sent status message; keep reading messages until
// `pgproto3.ReadyForQuery` is encountered.
case *pgproto3.BackendKeyData:
// Server sent backend key data; keep reading messages until
// `pgproto3.ReadyForQuery` is encountered.
case *pgproto3.ErrorResponse:
// Server has rejected the authentication response from the client and
// has closed the connection.
return newErrorf(codeAuthFailed, "authentication failed: %s", tp.Message)

// The backend is requesting the user to authenticate.
// Read the client response and forward it to server.
case
*pgproto3.AuthenticationCleartextPassword,
*pgproto3.AuthenticationMD5Password,
*pgproto3.AuthenticationSASL:
// The backend is requesting the user to authenticate.
// Read the client response and forward it to server.
if err = feSend(backendMsg); err != nil {
return err
}
fntMsg, err := fe.Receive()
if err != nil {
return newErrorf(codeClientReadFailed, "unable to receive message from client: %v", err)
Expand All @@ -75,6 +63,53 @@ var authenticate = func(clientConn, crdbConn net.Conn) error {
codeBackendWriteFailed, "unable to send message %v to backend: %v", fntMsg, err,
)
}

// Server has authenticated the connection; keep reading messages until
// `pgproto3.ReadyForQuery` is encountered which signifies that server
// is ready to serve queries.
case *pgproto3.AuthenticationOk:
throttleError := throttleHook(throttler.AttemptOK)
if throttleError != nil {
if err = feSend(throttleError); err != nil {
return err
}
return newErrorf(codeProxyRefusedConnection, "connection attempt throttled")
}
if err = feSend(backendMsg); err != nil {
return err
}

// Server has rejected the authentication response from the client and
// has closed the connection.
case *pgproto3.ErrorResponse:
throttleError := throttleHook(throttler.AttemptInvalidCredentials)
if throttleError != nil {
if err = feSend(throttleError); err != nil {
return err
}
return newErrorf(codeProxyRefusedConnection, "connection attempt throttled")
}
if err = feSend(backendMsg); err != nil {
return err
}
return newErrorf(codeAuthFailed, "authentication failed: %s", tp.Message)

// Information provided by the server to the client before the connection is ready
// to accept queries. These are typically returned after AuthenticationOk and before
// ReadyForQuery.
case *pgproto3.ParameterStatus, *pgproto3.BackendKeyData:
if err = feSend(backendMsg); err != nil {
return err
}

// Server has authenticated the connection successfully and is ready to
// serve queries.
case *pgproto3.ReadyForQuery:
if err = feSend(backendMsg); err != nil {
return err
}
return nil

default:
return newErrorf(codeBackendDisconnected, "received unexpected backend message type: %v", tp)
}
Expand Down
90 changes: 83 additions & 7 deletions pkg/ccl/sqlproxyccl/authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,17 @@ import (
"net"
"testing"

"github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/throttler"
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
"github.com/cockroachdb/errors"
"github.com/jackc/pgproto3/v2"
"github.com/stretchr/testify/require"
)

var nilThrottleHook = func(state throttler.AttemptStatus) *pgproto3.ErrorResponse {
return nil
}

func TestAuthenticateOK(t *testing.T) {
defer leaktest.AfterTest(t)()

Expand All @@ -33,7 +38,7 @@ func TestAuthenticateOK(t *testing.T) {
require.Equal(t, beMsg, &pgproto3.ReadyForQuery{})
}()

require.NoError(t, authenticate(srv, cli))
require.NoError(t, authenticate(srv, cli, nilThrottleHook))
}

func TestAuthenticateClearText(t *testing.T) {
Expand Down Expand Up @@ -75,7 +80,76 @@ func TestAuthenticateClearText(t *testing.T) {
require.Equal(t, beMsg, &pgproto3.ReadyForQuery{})
}()

require.NoError(t, authenticate(srv, cli))
require.NoError(t, authenticate(srv, cli, nilThrottleHook))
}

func TestAuthenticateThrottled(t *testing.T) {
defer leaktest.AfterTest(t)()

server := func(t *testing.T, be *pgproto3.Backend, authResponse pgproto3.BackendMessage) {
require.NoError(t, be.Send(&pgproto3.AuthenticationCleartextPassword{}))

msg, err := be.Receive()
require.NoError(t, err)
require.Equal(t, msg, &pgproto3.PasswordMessage{Password: "password"})

require.NoError(t, be.Send(authResponse))
}

client := func(t *testing.T, fe *pgproto3.Frontend) {
msg, err := fe.Receive()
require.NoError(t, err)
require.Equal(t, msg, &pgproto3.AuthenticationCleartextPassword{})

require.NoError(t, fe.Send(&pgproto3.PasswordMessage{Password: "password"}))

msg, err = fe.Receive()
require.NoError(t, err)
require.Equal(t, msg, &pgproto3.ErrorResponse{Message: "throttled"})

// Try reading from the connection. This check ensures authorize
// swallowed the OK/Error response from the sql server.
_, err = fe.Receive()
require.Error(t, err)
}

type testCase struct {
name string
result pgproto3.BackendMessage
expectedStatus throttler.AttemptStatus
}
for _, tc := range []testCase{
{
name: "AuthenticationOkay",
result: &pgproto3.AuthenticationOk{},
expectedStatus: throttler.AttemptOK,
},
{
name: "AuthenticationError",
result: &pgproto3.ErrorResponse{Message: "wrong password"},
expectedStatus: throttler.AttemptInvalidCredentials,
},
} {
t.Run(tc.name, func(t *testing.T) {
proxyToServer, serverToProxy := net.Pipe()
proxyToClient, clientToProxy := net.Pipe()
sqlServer := pgproto3.NewBackend(pgproto3.NewChunkReader(serverToProxy), serverToProxy)
sqlClient := pgproto3.NewFrontend(pgproto3.NewChunkReader(clientToProxy), clientToProxy)

go server(t, sqlServer, &pgproto3.AuthenticationOk{})
go client(t, sqlClient)

err := authenticate(proxyToClient, proxyToServer, func(status throttler.AttemptStatus) *pgproto3.ErrorResponse {
require.Equal(t, throttler.AttemptOK, status)
return &pgproto3.ErrorResponse{Message: "throttled"}
})
require.Error(t, err)
require.Contains(t, err.Error(), "connection attempt throttled")

proxyToServer.Close()
proxyToClient.Close()
})
}
}

func TestAuthenticateError(t *testing.T) {
Expand All @@ -93,7 +167,7 @@ func TestAuthenticateError(t *testing.T) {
require.Equal(t, beMsg, &pgproto3.ErrorResponse{Severity: "FATAL", Code: "foo"})
}()

err := authenticate(srv, cli)
err := authenticate(srv, cli, nilThrottleHook)
require.Error(t, err)
codeErr := (*codeError)(nil)
require.True(t, errors.As(err, &codeErr))
Expand All @@ -110,12 +184,14 @@ func TestAuthenticateUnexpectedMessage(t *testing.T) {
go func() {
err := be.Send(&pgproto3.BindComplete{})
require.NoError(t, err)
beMsg, err := fe.Receive()
require.NoError(t, err)
require.Equal(t, beMsg, &pgproto3.BindComplete{})
_, err = fe.Receive()
require.Error(t, err)
}()

err := authenticate(srv, cli)
err := authenticate(srv, cli, nilThrottleHook)

srv.Close()

require.Error(t, err)
codeErr := (*codeError)(nil)
require.True(t, errors.As(err, &codeErr))
Expand Down
35 changes: 19 additions & 16 deletions pkg/ccl/sqlproxyccl/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,7 @@ func updateMetricsAndSendErrToClient(err error, conn net.Conn, metrics *metrics)
SendErrToClient(conn, err)
}

// SendErrToClient will encode and pass back to the SQL client an error message.
// It can be called by the implementors of proxyHandler to give more
// information to the end user in case of a problem.
var SendErrToClient = func(conn net.Conn, err error) {
if err == nil || conn == nil {
return
}
func toPgError(err error) *pgproto3.ErrorResponse {
codeErr := (*codeError)(nil)
if errors.As(err, &codeErr) {
var msg string
Expand All @@ -60,19 +54,28 @@ var SendErrToClient = func(conn net.Conn, err error) {
} else {
pgCode = "08004" // rejected connection
}
_, _ = conn.Write((&pgproto3.ErrorResponse{
return &pgproto3.ErrorResponse{
Severity: "FATAL",
Code: pgCode,
Message: msg,
}).Encode(nil))
} else {
// Return a generic "internal server error" message.
_, _ = conn.Write((&pgproto3.ErrorResponse{
Severity: "FATAL",
Code: "08004", // rejected connection
Message: "internal server error",
}).Encode(nil))
}
}
// Return a generic "internal server error" message.
return &pgproto3.ErrorResponse{
Severity: "FATAL",
Code: "08004", // rejected connection
Message: "internal server error",
}
}

// SendErrToClient will encode and pass back to the SQL client an error message.
// It can be called by the implementors of proxyHandler to give more
// information to the end user in case of a problem.
var SendErrToClient = func(conn net.Conn, err error) {
if err == nil || conn == nil {
return
}
_, _ = conn.Write(toPgError(err).Encode(nil))
}

// ConnectionCopy does a bi-directional copy between the backend and frontend
Expand Down
26 changes: 17 additions & 9 deletions pkg/ccl/sqlproxyccl/proxy_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,9 @@ type ProxyOptions struct {
// DrainTimeout if set, will close DRAINING connections that have been idle
// for this duration.
DrainTimeout time.Duration

// Token bucket policy used to throttle (IP, TenantID) connection pairs that
// have no history of successful authentication.
ThrottlePolicy throttler.BucketPolicy
// ThrottleBaseDelay is the initial exponential backoff triggered in
// response to the first connection failure.
ThrottleBaseDelay time.Duration
}

// proxyHandler is the default implementation of a proxy handler.
Expand Down Expand Up @@ -135,6 +134,8 @@ type proxyHandler struct {
certManager *certmgr.CertManager
}

var throttledError = newErrorf(codeProxyRefusedConnection, "connection attempt throttled")

// newProxyHandler will create a new proxy handler with configuration based on
// the provided options.
func newProxyHandler(
Expand Down Expand Up @@ -163,7 +164,7 @@ func newProxyHandler(
}

handler.throttleService = throttler.NewLocalService(
throttler.WithPolicy(handler.ThrottlePolicy),
throttler.WithBaseDelay(handler.ThrottleBaseDelay),
)

if handler.DirectoryAddr != "" {
Expand Down Expand Up @@ -260,9 +261,10 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn
defer removeListener()

throttleTags := throttler.ConnectionTags{IP: ipAddr, TenantID: tenID.String()}
if err := handler.throttleService.LoginCheck(throttleTags); err != nil {
throttleTime, err := handler.throttleService.LoginCheck(throttleTags)
if err != nil {
log.Errorf(ctx, "throttler refused connection: %v", err.Error())
err = newErrorf(codeProxyRefusedConnection, "connection attempt throttled")
err = throttledError
updateMetricsAndSendErrToClient(err, conn, handler.metrics)
return err
}
Expand Down Expand Up @@ -397,14 +399,20 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn
defer func() { _ = crdbConn.Close() }()

// Perform user authentication.
if err := authenticate(conn, crdbConn); err != nil {
if err := authenticate(conn, crdbConn, func(status throttler.AttemptStatus) *pgproto3.ErrorResponse {
err := handler.throttleService.ReportAttempt(ctx, throttleTags, throttleTime, status)
if err != nil {
log.Errorf(ctx, "throttler refused connection after authentication: %v", err.Error())
return toPgError(throttledError)
}
return nil
}); err != nil {
handler.metrics.updateForError(err)
log.Ops.Errorf(ctx, "authenticate: %s", err)
return err
}

handler.metrics.SuccessfulConnCount.Inc(1)
handler.throttleService.ReportSuccess(throttleTags)

ctx, cancel := context.WithCancel(ctx)
defer cancel()
Expand Down
Loading

0 comments on commit 2559b4b

Please sign in to comment.