diff --git a/pkg/ccl/cliccl/mtproxy.go b/pkg/ccl/cliccl/mtproxy.go index 6b09f706cfe4..81b2f6be2944 100644 --- a/pkg/ccl/cliccl/mtproxy.go +++ b/pkg/ccl/cliccl/mtproxy.go @@ -148,16 +148,22 @@ Uuwb2FVdh76ZK0AVd3Jh3KJs4+hr2u9syHaa7UPKXTcZsFWlGwZuu6X5A+0SO0S2 IncomingTLSConfig: &tls.Config{ Certificates: []tls.Certificate{cer}, }, - BackendFromParams: func(params map[string]string) (addr string, conf *tls.Config, clientErr error) { + BackendConfigFromParams: func( + params map[string]string, ipAddress string, + ) (config *sqlproxyccl.BackendConfig, clientErr error) { const magic = "prancing-pony" + cfg := &sqlproxyccl.BackendConfig{ + OutgoingAddress: sqlProxyTargetAddr, + TLSConf: outgoingConf, + } if strings.HasPrefix(params["database"], magic+".") { params["database"] = params["database"][len(magic)+1:] - return sqlProxyTargetAddr, outgoingConf, nil + return cfg, nil } if params["options"] == "--cluster="+magic { - return sqlProxyTargetAddr, outgoingConf, nil + return cfg, nil } - return "", nil, errors.Errorf("client failed to pass '%s' via database or options", magic) + return nil, errors.Errorf("client failed to pass '%s' via database or options", magic) }, }) diff --git a/pkg/ccl/sqlproxyccl/BUILD.bazel b/pkg/ccl/sqlproxyccl/BUILD.bazel index 5d9454efef59..63677fba25d2 100644 --- a/pkg/ccl/sqlproxyccl/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/BUILD.bazel @@ -38,7 +38,6 @@ go_test( "//pkg/security/securitytest", "//pkg/server", "//pkg/testutils/serverutils", - "//pkg/testutils/skip", "//pkg/testutils/testcluster", "//pkg/util/leaktest", "//pkg/util/randutil", diff --git a/pkg/ccl/sqlproxyccl/error.go b/pkg/ccl/sqlproxyccl/error.go index e1bf9c6d9288..e0b04562ae21 100644 --- a/pkg/ccl/sqlproxyccl/error.go +++ b/pkg/ccl/sqlproxyccl/error.go @@ -58,6 +58,10 @@ const ( // CodeClientDisconnected indicates that the client disconnected unexpectedly // (with a connection error) while in a session with backend SQL server. 
CodeClientDisconnected + + // CodeProxyRefusedConnection indicates that the proxy refused the connection + // request due to high load or too many connection attempts. + CodeProxyRefusedConnection ) type codeError struct { diff --git a/pkg/ccl/sqlproxyccl/errorcode_string.go b/pkg/ccl/sqlproxyccl/errorcode_string.go index 45292b2181f3..81f980eb5fa4 100644 --- a/pkg/ccl/sqlproxyccl/errorcode_string.go +++ b/pkg/ccl/sqlproxyccl/errorcode_string.go @@ -18,11 +18,12 @@ func _() { _ = x[CodeBackendRefusedTLS-8] _ = x[CodeBackendDisconnected-9] _ = x[CodeClientDisconnected-10] + _ = x[CodeProxyRefusedConnection-11] } -const _ErrorCode_name = "CodeClientReadFailedCodeClientWriteFailedCodeUnexpectedInsecureStartupMessageCodeSNIRoutingFailedCodeUnexpectedStartupMessageCodeParamsRoutingFailedCodeBackendDownCodeBackendRefusedTLSCodeBackendDisconnectedCodeClientDisconnected" +const _ErrorCode_name = "CodeClientReadFailedCodeClientWriteFailedCodeUnexpectedInsecureStartupMessageCodeSNIRoutingFailedCodeUnexpectedStartupMessageCodeParamsRoutingFailedCodeBackendDownCodeBackendRefusedTLSCodeBackendDisconnectedCodeClientDisconnectedCodeProxyRefusedConnection" -var _ErrorCode_index = [...]uint8{0, 20, 41, 77, 97, 125, 148, 163, 184, 207, 229} +var _ErrorCode_index = [...]uint8{0, 20, 41, 77, 97, 125, 148, 163, 184, 207, 229, 255} func (i ErrorCode) String() string { i -= 1 diff --git a/pkg/ccl/sqlproxyccl/metrics.go b/pkg/ccl/sqlproxyccl/metrics.go index 7f7c87c72ea0..e68b6d70c7ec 100644 --- a/pkg/ccl/sqlproxyccl/metrics.go +++ b/pkg/ccl/sqlproxyccl/metrics.go @@ -18,6 +18,7 @@ type Metrics struct { ClientDisconnectCount *metric.Counter CurConnCount *metric.Gauge RoutingErrCount *metric.Counter + RefusedConnCount *metric.Counter } // MetricStruct implements the metrics.Struct interface. 
@@ -56,6 +57,12 @@ var ( Measurement: "Disconnects", Unit: metric.Unit_COUNT, } + metaRefusedConnCount = metric.Metadata{ + Name: "proxy.err.refused_conn", + Help: "Number of refused connections initiated by a given IP", + Measurement: "Refused", + Unit: metric.Unit_COUNT, + } ) // MakeProxyMetrics instantiates the metrics holder for proxy monitoring. @@ -66,5 +73,6 @@ func MakeProxyMetrics() Metrics { ClientDisconnectCount: metric.NewCounter(metaClientDisconnectCount), CurConnCount: metric.NewGauge(metaCurConnCount), RoutingErrCount: metric.NewCounter(metaRoutingErrCount), + RefusedConnCount: metric.NewCounter(metaRefusedConnCount), } } diff --git a/pkg/ccl/sqlproxyccl/proxy.go b/pkg/ccl/sqlproxyccl/proxy.go index 01efd134a541..ed7eb84b5ace 100644 --- a/pkg/ccl/sqlproxyccl/proxy.go +++ b/pkg/ccl/sqlproxyccl/proxy.go @@ -22,17 +22,28 @@ const pgAcceptSSLRequest = 'S' // See https://www.postgresql.org/docs/9.1/protocol-message-formats.html. var pgSSLRequest = []int32{8, 80877103} +// BackendConfig contains the configuration of a backend connection that is +// being proxied. +type BackendConfig struct { + OutgoingAddress string + TLSConf *tls.Config + RefuseConn bool + OnConnectionSuccess func() error +} + // Options are the options to the Proxy method. type Options struct { IncomingTLSConfig *tls.Config // config used for client -> proxy connection // TODO(tbg): this is unimplemented and exists only to check which clients // allow use of SNI. Should always return ("", nil). - BackendFromSNI func(serverName string) (addr string, conf *tls.Config, clientErr error) - // BackendFromParams returns the address and TLS config to use for - // the proxy -> backend connection. The returned config must have - // an appropriate ServerName for the remote backend. 
- BackendFromParams func(map[string]string) (addr string, conf *tls.Config, clientErr error) + BackendConfigFromSNI func(serverName string) (config *BackendConfig, clientErr error) + // BackendConfigFromParams returns the config to use for the proxy -> + // backend connection. The TLS config is in it and it must have an + // appropriate ServerName for the remote backend. + BackendConfigFromParams func( + params map[string]string, ipAddress string, + ) (config *BackendConfig, clientErr error) // If set, consulted to modify the parameters set by the frontend before // forwarding them to the backend during startup. @@ -86,15 +97,15 @@ func (s *Server) Proxy(conn net.Conn) error { sniServerName = h.ServerName return nil, nil } - if s.opts.BackendFromSNI != nil { - addr, _, clientErr := s.opts.BackendFromSNI(sniServerName) + if s.opts.BackendConfigFromSNI != nil { + backendCfg, clientErr := s.opts.BackendConfigFromSNI(sniServerName) if clientErr != nil { code := CodeSNIRoutingFailed sendErrToClient(conn, code, clientErr.Error()) // won't actually be shown by most clients return newErrorf(code, "rejected by OutgoingAddrFromSNI") } - if addr != "" { - return newErrorf(CodeSNIRoutingFailed, "BackendFromSNI is unimplemented") + if backendCfg != nil && backendCfg.OutgoingAddress != "" { + return newErrorf(CodeSNIRoutingFailed, "BackendConfigFromSNI is unimplemented") } } conn = tls.Server(conn, cfg) @@ -109,15 +120,32 @@ func (s *Server) Proxy(conn net.Conn) error { return newErrorf(CodeUnexpectedStartupMessage, "unsupported post-TLS startup message: %T", m) } - outgoingAddr, outgoingTLS, clientErr := s.opts.BackendFromParams(msg.Parameters) - if clientErr != nil { - s.metrics.RoutingErrCount.Inc(1) - code := CodeParamsRoutingFailed - sendErrToClient(conn, code, clientErr.Error()) - return newErrorf(code, "rejected by OutgoingAddrFromParams: %v", clientErr) + var backendConfig *BackendConfig + { + ip, _, err := net.SplitHostPort(conn.RemoteAddr().String()) + if err != nil { + return newErrorf( 
CodeParamsRoutingFailed, "could not parse address %s: %v", + conn.RemoteAddr().String(), err) + } + var clientErr error + backendConfig, clientErr = s.opts.BackendConfigFromParams(msg.Parameters, ip) + if clientErr != nil { + s.metrics.RoutingErrCount.Inc(1) + code := CodeParamsRoutingFailed + sendErrToClient(conn, code, clientErr.Error()) + return newErrorf(code, "rejected by BackendConfigFromParams: %v", clientErr) + } } - crdbConn, err := net.Dial("tcp", outgoingAddr) + if backendConfig.RefuseConn { + s.metrics.RefusedConnCount.Inc(1) + code := CodeProxyRefusedConnection + sendErrToClient(conn, code, "backend refused to admit") + return newErrorf(code, "backend refused to admit") + } + + crdbConn, err := net.Dial("tcp", backendConfig.OutgoingAddress) if err != nil { s.metrics.BackendDownCount.Inc(1) code := CodeBackendDown @@ -142,7 +170,7 @@ func (s *Server) Proxy(conn net.Conn) error { return newErrorf(CodeBackendRefusedTLS, "target server refused TLS connection") } - outCfg := outgoingTLS.Clone() + outCfg := backendConfig.TLSConf.Clone() crdbConn = tls.Client(crdbConn, outCfg) if s.opts.ModifyRequestParams != nil { @@ -151,7 +179,17 @@ func (s *Server) Proxy(conn net.Conn) error { if _, err := crdbConn.Write(msg.Encode(nil)); err != nil { s.metrics.BackendDownCount.Inc(1) - return newErrorf(CodeBackendDown, "relaying StartupMessage to target server %v: %v", outgoingAddr, err) + return newErrorf(CodeBackendDown, "relaying StartupMessage to target server %v: %v", + backendConfig.OutgoingAddress, err) + } + + if backendConfig.OnConnectionSuccess != nil { + if err := backendConfig.OnConnectionSuccess(); err != nil { + code := CodeBackendDown + sendErrToClient(conn, code, err.Error()) + s.metrics.BackendDownCount.Inc(1) + return newErrorf(code, "recording connection success: %v", err) + } } // These channels are buffered because we'll only consume one of them. 
diff --git a/pkg/ccl/sqlproxyccl/proxy_test.go b/pkg/ccl/sqlproxyccl/proxy_test.go index b042969a7323..aaa55de691f3 100644 --- a/pkg/ccl/sqlproxyccl/proxy_test.go +++ b/pkg/ccl/sqlproxyccl/proxy_test.go @@ -19,7 +19,6 @@ import ( "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" - "github.com/cockroachdb/cockroach/pkg/testutils/skip" "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/errors" "github.com/jackc/pgx/v4" @@ -65,29 +64,31 @@ openssl req -new -x509 -sha256 -key testserver.key -out testserver.crt -days 365 func testingTenantIDFromDatabaseForAddr( addr string, validTenant string, -) func(map[string]string) (string, *tls.Config, error) { - return func(p map[string]string) (_ string, config *tls.Config, clientErr error) { +) func(map[string]string, string) (config *BackendConfig, clientErr error) { + return func(p map[string]string, _ string) (config *BackendConfig, clientErr error) { const dbKey = "database" db, ok := p[dbKey] if !ok { - return "", nil, errors.Newf("need to specify database") + return nil, errors.Newf("need to specify database") } sl := strings.SplitN(db, "_", 2) if len(sl) != 2 { - return "", nil, errors.Newf("malformed database name") + return nil, errors.Newf("malformed database name") } db, tenantID := sl[0], sl[1] if tenantID != validTenant { - return "", nil, errors.Newf("invalid tenantID") + return nil, errors.Newf("invalid tenantID") } p[dbKey] = db - config = &tls.Config{ - // NB: this would be false in production. - InsecureSkipVerify: true, - } - return addr, config, nil + return &BackendConfig{ + OutgoingAddress: addr, + TLSConf: &tls.Config{ + // NB: this would be false in production. 
+ InsecureSkipVerify: true, + }, + }, nil } } @@ -131,9 +132,10 @@ func TestLongDBName(t *testing.T) { var m map[string]string opts := Options{ - BackendFromParams: func(mm map[string]string) (string, *tls.Config, error) { + BackendConfigFromParams: func( + mm map[string]string, _ string) (config *BackendConfig, clientErr error) { m = mm - return "", nil, errors.New("boom") + return nil, errors.New("boom") }, OnSendErrToClient: ac.onSendErrToClient, } @@ -153,8 +155,8 @@ func TestFailedConnection(t *testing.T) { ac := makeAssertCtx() opts := Options{ - BackendFromParams: testingTenantIDFromDatabaseForAddr("undialable%$!@$", "29"), - OnSendErrToClient: ac.onSendErrToClient, + BackendConfigFromParams: testingTenantIDFromDatabaseForAddr("undialable%$!@$", "29"), + OnSendErrToClient: ac.onSendErrToClient, } addr, done := setupTestProxyWithCerts(t, &opts) defer done() @@ -203,19 +205,30 @@ func TestProxyAgainstSecureCRDB(t *testing.T) { defer leaktest.AfterTest(t)() ctx := context.Background() + tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) + defer tc.Stopper().Stop(ctx) - skip.IgnoreLint(t, "this test needs a running (secure) CockroachDB instance at the given address") - const crdbSQL = "127.0.0.1:52966" - // TODO(asubiotto): use an in-mem test server once this code lives in the CRDB - // repo. - // // TODO(tbg): if I use the https (!) port of ./cockroach demo, the // connection hangs instead of failing. Why? Probably both ends end up waiting // for the other side due to protocol mismatch. Should set deadlines on all // the read/write ops to avoid this failure mode. 
+ outgoingTLSConfig, err := tc.Server(0).RPCContext().GetClientTLSConfig() + require.NoError(t, err) + outgoingTLSConfig.InsecureSkipVerify = true + + var connSuccess bool opts := Options{ - BackendFromParams: testingTenantIDFromDatabaseForAddr(crdbSQL, "29"), + BackendConfigFromParams: func(params map[string]string, _ string) (*BackendConfig, error) { + return &BackendConfig{ + OutgoingAddress: tc.Server(0).ServingSQLAddr(), + TLSConf: outgoingTLSConfig, + OnConnectionSuccess: func() error { + connSuccess = true + return nil + }, + }, nil + }, } addr, done := setupTestProxyWithCerts(t, &opts) defer done() @@ -225,6 +238,7 @@ func TestProxyAgainstSecureCRDB(t *testing.T) { require.NoError(t, err) defer func() { require.NoError(t, conn.Close(ctx)) + require.True(t, connSuccess) }() var n int @@ -241,12 +255,15 @@ func TestProxyModifyRequestParams(t *testing.T) { defer tc.Stopper().Stop(ctx) outgoingTLSConfig, err := tc.Server(0).RPCContext().GetClientTLSConfig() - outgoingTLSConfig.InsecureSkipVerify = true require.NoError(t, err) + outgoingTLSConfig.InsecureSkipVerify = true opts := Options{ - BackendFromParams: func(params map[string]string) (string, *tls.Config, error) { - return tc.Server(0).ServingSQLAddr(), outgoingTLSConfig, nil + BackendConfigFromParams: func(params map[string]string, _ string) (*BackendConfig, error) { + return &BackendConfig{ + OutgoingAddress: tc.Server(0).ServingSQLAddr(), + TLSConf: outgoingTLSConfig, + }, nil }, ModifyRequestParams: func(params map[string]string) { require.EqualValues(t, map[string]string{ @@ -275,3 +292,34 @@ func TestProxyModifyRequestParams(t *testing.T) { require.NoError(t, err) require.EqualValues(t, 1, n) } + +func TestProxyRefuseConn(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) + defer tc.Stopper().Stop(ctx) + + outgoingTLSConfig, err := tc.Server(0).RPCContext().GetClientTLSConfig() + require.NoError(t, err) 
+ outgoingTLSConfig.InsecureSkipVerify = true + + ac := makeAssertCtx() + opts := Options{ + BackendConfigFromParams: func(params map[string]string, _ string) (*BackendConfig, error) { + return &BackendConfig{ + OutgoingAddress: tc.Server(0).ServingSQLAddr(), + TLSConf: outgoingTLSConfig, + RefuseConn: true, + }, nil + }, + OnSendErrToClient: ac.onSendErrToClient, + } + addr, done := setupTestProxyWithCerts(t, &opts) + defer done() + + ac.assertConnectErr( + t, fmt.Sprintf("postgres://root:admin@%s/", addr), "defaultdb_29?sslmode=require", + CodeProxyRefusedConnection, "backend refused to admit", + ) +}