Skip to content

Commit

Permalink
[CC-3025] sqlproxy: add hook to rate limit connections
Browse files Browse the repository at this point in the history
We need a mechanism for sqlproxy users to provide admission control to
clients that is enforced in the proxy. To that end we expand
`BackendConfigFromParams` to provide a signal whether a connection
should be refused and a callback to record the success of a connection.
The callback can implement caching of known connections for example so
that they are not rate limited on further connection requests.

Release note: none.
  • Loading branch information
Spas Bojanov committed Nov 11, 2020
1 parent 6a1d8c7 commit d505572
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 49 deletions.
14 changes: 10 additions & 4 deletions pkg/ccl/cliccl/mtproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,16 +148,22 @@ Uuwb2FVdh76ZK0AVd3Jh3KJs4+hr2u9syHaa7UPKXTcZsFWlGwZuu6X5A+0SO0S2
IncomingTLSConfig: &tls.Config{
Certificates: []tls.Certificate{cer},
},
BackendFromParams: func(params map[string]string) (addr string, conf *tls.Config, clientErr error) {
BackendConfigFromParams: func(
params map[string]string, ipAddress string,
) (config *sqlproxyccl.BackendConfig, clientErr error) {
const magic = "prancing-pony"
cfg := &sqlproxyccl.BackendConfig{
OutgoingAddress: sqlProxyTargetAddr,
TLSConf: outgoingConf,
}
if strings.HasPrefix(params["database"], magic+".") {
params["database"] = params["database"][len(magic)+1:]
return sqlProxyTargetAddr, outgoingConf, nil
return cfg, nil
}
if params["options"] == "--cluster="+magic {
return sqlProxyTargetAddr, outgoingConf, nil
return cfg, nil
}
return "", nil, errors.Errorf("client failed to pass '%s' via database or options", magic)
return nil, errors.Errorf("client failed to pass '%s' via database or options", magic)
},
})

Expand Down
1 change: 0 additions & 1 deletion pkg/ccl/sqlproxyccl/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ go_test(
"//pkg/security/securitytest",
"//pkg/server",
"//pkg/testutils/serverutils",
"//pkg/testutils/skip",
"//pkg/testutils/testcluster",
"//pkg/util/leaktest",
"//pkg/util/randutil",
Expand Down
4 changes: 4 additions & 0 deletions pkg/ccl/sqlproxyccl/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ const (
// CodeClientDisconnected indicates that the client disconnected unexpectedly
// (with a connection error) while in a session with backend SQL server.
CodeClientDisconnected

// CodeProxyRefusedConnection indicates that the proxy refused the connection
// request due to high load or too many connection attempts.
CodeProxyRefusedConnection
)

type codeError struct {
Expand Down
5 changes: 3 additions & 2 deletions pkg/ccl/sqlproxyccl/errorcode_string.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions pkg/ccl/sqlproxyccl/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type Metrics struct {
ClientDisconnectCount *metric.Counter
CurConnCount *metric.Gauge
RoutingErrCount *metric.Counter
RefusedConnCount *metric.Counter
}

// MetricStruct implements the metrics.Struct interface.
Expand Down Expand Up @@ -56,6 +57,12 @@ var (
Measurement: "Disconnects",
Unit: metric.Unit_COUNT,
}
metaRefusedConnCount = metric.Metadata{
Name: "proxy.err.refused_conn",
Help: "Number of refused connections initiated by a given IP",
Measurement: "Refused",
Unit: metric.Unit_COUNT,
}
)

// MakeProxyMetrics instantiates the metrics holder for proxy monitoring.
Expand All @@ -66,5 +73,6 @@ func MakeProxyMetrics() Metrics {
ClientDisconnectCount: metric.NewCounter(metaClientDisconnectCount),
CurConnCount: metric.NewGauge(metaCurConnCount),
RoutingErrCount: metric.NewCounter(metaRoutingErrCount),
RefusedConnCount: metric.NewCounter(metaBackendDisconnectCount),
}
}
74 changes: 56 additions & 18 deletions pkg/ccl/sqlproxyccl/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,28 @@ const pgAcceptSSLRequest = 'S'
// See https://www.postgresql.org/docs/9.1/protocol-message-formats.html.
var pgSSLRequest = []int32{8, 80877103}

// BackendConfig contains the configuration of a backend connection that is
// being proxied.
type BackendConfig struct {
OutgoingAddress string
TLSConf *tls.Config
RefuseConn bool
OnConnectionSuccess func() error
}

// Options are the options to the Proxy method.
type Options struct {
IncomingTLSConfig *tls.Config // config used for client -> proxy connection

// TODO(tbg): this is unimplemented and exists only to check which clients
// allow use of SNI. Should always return ("", nil).
BackendFromSNI func(serverName string) (addr string, conf *tls.Config, clientErr error)
// BackendFromParams returns the address and TLS config to use for
// the proxy -> backend connection. The returned config must have
// an appropriate ServerName for the remote backend.
BackendFromParams func(map[string]string) (addr string, conf *tls.Config, clientErr error)
BackendConfigFromSNI func(serverName string) (config *BackendConfig, clientErr error)
// BackendFromParams returns the config to use for the proxy -> backend
// connection. The TLS config is in it and it must have an appropriate
// ServerName for the remote backend.
BackendConfigFromParams func(
params map[string]string, ipAddress string,
) (config *BackendConfig, clientErr error)

// If set, consulted to modify the parameters set by the frontend before
// forwarding them to the backend during startup.
Expand Down Expand Up @@ -86,15 +97,15 @@ func (s *Server) Proxy(conn net.Conn) error {
sniServerName = h.ServerName
return nil, nil
}
if s.opts.BackendFromSNI != nil {
addr, _, clientErr := s.opts.BackendFromSNI(sniServerName)
if s.opts.BackendConfigFromSNI != nil {
cfg, clientErr := s.opts.BackendConfigFromSNI(sniServerName)
if clientErr != nil {
code := CodeSNIRoutingFailed
sendErrToClient(conn, code, clientErr.Error()) // won't actually be shown by most clients
return newErrorf(code, "rejected by OutgoingAddrFromSNI")
}
if addr != "" {
return newErrorf(CodeSNIRoutingFailed, "BackendFromSNI is unimplemented")
if cfg.OutgoingAddress != "" {
return newErrorf(CodeSNIRoutingFailed, "BackendConfigFromSNI is unimplemented")
}
}
conn = tls.Server(conn, cfg)
Expand All @@ -109,15 +120,32 @@ func (s *Server) Proxy(conn net.Conn) error {
return newErrorf(CodeUnexpectedStartupMessage, "unsupported post-TLS startup message: %T", m)
}

outgoingAddr, outgoingTLS, clientErr := s.opts.BackendFromParams(msg.Parameters)
if clientErr != nil {
s.metrics.RoutingErrCount.Inc(1)
code := CodeParamsRoutingFailed
sendErrToClient(conn, code, clientErr.Error())
return newErrorf(code, "rejected by OutgoingAddrFromParams: %v", clientErr)
var backendConfig *BackendConfig
{
ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
if err != nil {
return newErrorf(
CodeParamsRoutingFailed, "could not parse address %s: %v",
conn.RemoteAddr().String(), err)
}
var clientErr error
backendConfig, clientErr = s.opts.BackendConfigFromParams(msg.Parameters, ip)
if clientErr != nil {
s.metrics.RoutingErrCount.Inc(1)
code := CodeParamsRoutingFailed
sendErrToClient(conn, code, clientErr.Error())
return newErrorf(code, "rejected by BackendConfigFromParams: %v", clientErr)
}
}

crdbConn, err := net.Dial("tcp", outgoingAddr)
if backendConfig.RefuseConn {
s.metrics.RefusedConnCount.Inc(1)
code := CodeProxyRefusedConnection
sendErrToClient(conn, code, "backend refused to admit")
return newErrorf(code, "backend refused to admit")
}

crdbConn, err := net.Dial("tcp", backendConfig.OutgoingAddress)
if err != nil {
s.metrics.BackendDownCount.Inc(1)
code := CodeBackendDown
Expand All @@ -142,7 +170,7 @@ func (s *Server) Proxy(conn net.Conn) error {
return newErrorf(CodeBackendRefusedTLS, "target server refused TLS connection")
}

outCfg := outgoingTLS.Clone()
outCfg := backendConfig.TLSConf.Clone()
crdbConn = tls.Client(crdbConn, outCfg)

if s.opts.ModifyRequestParams != nil {
Expand All @@ -151,7 +179,17 @@ func (s *Server) Proxy(conn net.Conn) error {

if _, err := crdbConn.Write(msg.Encode(nil)); err != nil {
s.metrics.BackendDownCount.Inc(1)
return newErrorf(CodeBackendDown, "relaying StartupMessage to target server %v: %v", outgoingAddr, err)
return newErrorf(CodeBackendDown, "relaying StartupMessage to target server %v: %v",
backendConfig.OutgoingAddress, err)
}

if backendConfig.OnConnectionSuccess != nil {
if err := backendConfig.OnConnectionSuccess(); err != nil {
code := CodeBackendDown
sendErrToClient(conn, code, err.Error())
s.metrics.BackendDownCount.Inc(1)
return newErrorf(code, "recording connection success: %v", err)
}
}

// These channels are buffered because we'll only consume one of them.
Expand Down
96 changes: 72 additions & 24 deletions pkg/ccl/sqlproxyccl/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (

"github.com/cockroachdb/cockroach/pkg/base"
"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
"github.com/cockroachdb/cockroach/pkg/testutils/skip"
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
"github.com/cockroachdb/errors"
"github.com/jackc/pgx/v4"
Expand Down Expand Up @@ -65,29 +64,31 @@ openssl req -new -x509 -sha256 -key testserver.key -out testserver.crt -days 365

func testingTenantIDFromDatabaseForAddr(
addr string, validTenant string,
) func(map[string]string) (string, *tls.Config, error) {
return func(p map[string]string) (_ string, config *tls.Config, clientErr error) {
) func(map[string]string, string) (config *BackendConfig, clientErr error) {
return func(p map[string]string, _ string) (config *BackendConfig, clientErr error) {
const dbKey = "database"
db, ok := p[dbKey]
if !ok {
return "", nil, errors.Newf("need to specify database")
return nil, errors.Newf("need to specify database")
}
sl := strings.SplitN(db, "_", 2)
if len(sl) != 2 {
return "", nil, errors.Newf("malformed database name")
return nil, errors.Newf("malformed database name")
}
db, tenantID := sl[0], sl[1]

if tenantID != validTenant {
return "", nil, errors.Newf("invalid tenantID")
return nil, errors.Newf("invalid tenantID")
}

p[dbKey] = db
config = &tls.Config{
// NB: this would be false in production.
InsecureSkipVerify: true,
}
return addr, config, nil
return &BackendConfig{
OutgoingAddress: addr,
TLSConf: &tls.Config{
// NB: this would be false in production.
InsecureSkipVerify: true,
},
}, nil
}
}

Expand Down Expand Up @@ -131,9 +132,10 @@ func TestLongDBName(t *testing.T) {

var m map[string]string
opts := Options{
BackendFromParams: func(mm map[string]string) (string, *tls.Config, error) {
BackendConfigFromParams: func(
mm map[string]string, _ string) (config *BackendConfig, clientErr error) {
m = mm
return "", nil, errors.New("boom")
return nil, errors.New("boom")
},
OnSendErrToClient: ac.onSendErrToClient,
}
Expand All @@ -153,8 +155,8 @@ func TestFailedConnection(t *testing.T) {

ac := makeAssertCtx()
opts := Options{
BackendFromParams: testingTenantIDFromDatabaseForAddr("undialable%$!@$", "29"),
OnSendErrToClient: ac.onSendErrToClient,
BackendConfigFromParams: testingTenantIDFromDatabaseForAddr("undialable%$!@$", "29"),
OnSendErrToClient: ac.onSendErrToClient,
}
addr, done := setupTestProxyWithCerts(t, &opts)
defer done()
Expand Down Expand Up @@ -203,19 +205,30 @@ func TestProxyAgainstSecureCRDB(t *testing.T) {
defer leaktest.AfterTest(t)()

ctx := context.Background()
tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{})
defer tc.Stopper().Stop(ctx)

skip.IgnoreLint(t, "this test needs a running (secure) CockroachDB instance at the given address")
const crdbSQL = "127.0.0.1:52966"
// TODO(asubiotto): use an in-mem test server once this code lives in the CRDB
// repo.
//
// TODO(tbg): if I use the https (!) port of ./cockroach demo, the
// connection hangs instead of failing. Why? Probably both ends end up waiting
// for the other side due to protocol mismatch. Should set deadlines on all
// the read/write ops to avoid this failure mode.

outgoingTLSConfig, err := tc.Server(0).RPCContext().GetClientTLSConfig()
require.NoError(t, err)
outgoingTLSConfig.InsecureSkipVerify = true

var connSuccess bool
opts := Options{
BackendFromParams: testingTenantIDFromDatabaseForAddr(crdbSQL, "29"),
BackendConfigFromParams: func(params map[string]string, _ string) (*BackendConfig, error) {
return &BackendConfig{
OutgoingAddress: tc.Server(0).ServingSQLAddr(),
TLSConf: outgoingTLSConfig,
OnConnectionSuccess: func() error {
connSuccess = true
return nil
},
}, nil
},
}
addr, done := setupTestProxyWithCerts(t, &opts)
defer done()
Expand All @@ -225,6 +238,7 @@ func TestProxyAgainstSecureCRDB(t *testing.T) {
require.NoError(t, err)
defer func() {
require.NoError(t, conn.Close(ctx))
require.True(t, connSuccess)
}()

var n int
Expand All @@ -241,12 +255,15 @@ func TestProxyModifyRequestParams(t *testing.T) {
defer tc.Stopper().Stop(ctx)

outgoingTLSConfig, err := tc.Server(0).RPCContext().GetClientTLSConfig()
outgoingTLSConfig.InsecureSkipVerify = true
require.NoError(t, err)
outgoingTLSConfig.InsecureSkipVerify = true

opts := Options{
BackendFromParams: func(params map[string]string) (string, *tls.Config, error) {
return tc.Server(0).ServingSQLAddr(), outgoingTLSConfig, nil
BackendConfigFromParams: func(params map[string]string, _ string) (*BackendConfig, error) {
return &BackendConfig{
OutgoingAddress: tc.Server(0).ServingSQLAddr(),
TLSConf: outgoingTLSConfig,
}, nil
},
ModifyRequestParams: func(params map[string]string) {
require.EqualValues(t, map[string]string{
Expand Down Expand Up @@ -275,3 +292,34 @@ func TestProxyModifyRequestParams(t *testing.T) {
require.NoError(t, err)
require.EqualValues(t, 1, n)
}

func TestProxyRefuseConn(t *testing.T) {
defer leaktest.AfterTest(t)()

ctx := context.Background()
tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{})
defer tc.Stopper().Stop(ctx)

outgoingTLSConfig, err := tc.Server(0).RPCContext().GetClientTLSConfig()
require.NoError(t, err)
outgoingTLSConfig.InsecureSkipVerify = true

ac := makeAssertCtx()
opts := Options{
BackendConfigFromParams: func(params map[string]string, _ string) (*BackendConfig, error) {
return &BackendConfig{
OutgoingAddress: tc.Server(0).ServingSQLAddr(),
TLSConf: outgoingTLSConfig,
RefuseConn: true,
}, nil
},
OnSendErrToClient: ac.onSendErrToClient,
}
addr, done := setupTestProxyWithCerts(t, &opts)
defer done()

ac.assertConnectErr(
t, fmt.Sprintf("postgres://root:admin@%s/", addr), "defaultdb_29?sslmode=require",
CodeProxyRefusedConnection, "backend refused to admit",
)
}

0 comments on commit d505572

Please sign in to comment.