diff --git a/pkg/ccl/cliccl/BUILD.bazel b/pkg/ccl/cliccl/BUILD.bazel index 217382fb7922..068d930478fa 100644 --- a/pkg/ccl/cliccl/BUILD.bazel +++ b/pkg/ccl/cliccl/BUILD.bazel @@ -40,6 +40,7 @@ go_library( "//vendor/github.com/cockroachdb/cmux", "//vendor/github.com/cockroachdb/errors", "//vendor/github.com/cockroachdb/errors/oserror", + "//vendor/github.com/jackc/pgproto3/v2:pgproto3", "//vendor/github.com/spf13/cobra", "//vendor/golang.org/x/sync/errgroup", ], diff --git a/pkg/ccl/cliccl/mtproxy.go b/pkg/ccl/cliccl/mtproxy.go index 6c13f2617195..db2ca4efbc9f 100644 --- a/pkg/ccl/cliccl/mtproxy.go +++ b/pkg/ccl/cliccl/mtproxy.go @@ -19,6 +19,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl" "github.com/cockroachdb/cockroach/pkg/cli" "github.com/cockroachdb/errors" + "github.com/jackc/pgproto3/v2" "github.com/spf13/cobra" "golang.org/x/sync/errgroup" ) @@ -148,22 +149,19 @@ Uuwb2FVdh76ZK0AVd3Jh3KJs4+hr2u9syHaa7UPKXTcZsFWlGwZuu6X5A+0SO0S2 IncomingTLSConfig: &tls.Config{ Certificates: []tls.Certificate{cer}, }, - BackendConfigFromParams: func( - params map[string]string, _ *sqlproxyccl.Conn, - ) (config *sqlproxyccl.BackendConfig, clientErr error) { + BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { + params := msg.Parameters const magic = "prancing-pony" - cfg := &sqlproxyccl.BackendConfig{ - OutgoingAddress: sqlProxyTargetAddr, - TLSConf: outgoingConf, - } if strings.HasPrefix(params["database"], magic+".") { params["database"] = params["database"][len(magic)+1:] - return cfg, nil + } else if params["options"] != "--cluster="+magic { + return nil, errors.Errorf("client failed to pass '%s' via database or options", magic) } - if params["options"] == "--cluster="+magic { - return cfg, nil + conn, err := sqlproxyccl.BackendDial(msg, sqlProxyTargetAddr, outgoingConf) + if err != nil { + return nil, err } - return nil, errors.Errorf("client failed to pass '%s' via database or options", magic) + return conn, nil }, }) diff --git a/pkg/ccl/sqlproxyccl/BUILD.bazel b/pkg/ccl/sqlproxyccl/BUILD.bazel index 54a6c51de926..53853c234cec 100644 --- a/pkg/ccl/sqlproxyccl/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/BUILD.bazel @@ -3,8 +3,10 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "sqlproxyccl", srcs = [ + "backend_dialer.go", "error.go", "errorcode_string.go", + "idle_disconnect_connection.go", "metrics.go", "proxy.go", "server.go", @@ -26,6 +28,7 @@ go_library( go_test( name = "sqlproxyccl_test", srcs = [ + "idle_disconnect_connection_test.go", "main_test.go", "proxy_test.go", "server_test.go", @@ -43,6 +46,7 @@ go_test( "//pkg/util/randutil", "//pkg/util/timeutil", "//vendor/github.com/cockroachdb/errors", + "//vendor/github.com/jackc/pgproto3/v2:pgproto3", "//vendor/github.com/jackc/pgx/v4:pgx", "//vendor/github.com/stretchr/testify/require", ], diff --git a/pkg/ccl/sqlproxyccl/backend_dialer.go b/pkg/ccl/sqlproxyccl/backend_dialer.go new file mode 100644 index 000000000000..0e3aff5e4cbb --- /dev/null +++ b/pkg/ccl/sqlproxyccl/backend_dialer.go @@ -0,0 +1,79 @@ +// Copyright 2020 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package sqlproxyccl + +import ( + "crypto/tls" + "encoding/binary" + "io" + "net" + + "github.com/jackc/pgproto3/v2" +) + +// BackendDial is an example backend dialer that does a TCP/IP connection +// to a backend, SSL and forwards the start message. +func BackendDial( + msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config, +) (net.Conn, error) { + conn, err := net.Dial("tcp", outgoingAddress) + if err != nil { + return nil, NewErrorf( + CodeBackendDown, "unable to reach backend SQL server: %v", err, + ) + } + conn, err = SSLOverlay(conn, tlsConfig) + if err != nil { + return nil, err + } + err = RelayStartupMsg(conn, msg) + if err != nil { + return nil, NewErrorf( + CodeBackendDown, "relaying StartupMessage to target server %v: %v", + outgoingAddress, err) + } + return conn, nil +} + +// SSLOverlay attempts to upgrade the PG connection to use SSL +// if a tls.Config is specified.. +func SSLOverlay(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { + if tlsConfig == nil { + return conn, nil + } + + var err error + // Send SSLRequest. + if err := binary.Write(conn, binary.BigEndian, pgSSLRequest); err != nil { + return nil, NewErrorf( + CodeBackendDown, "sending SSLRequest to target server: %v", err, + ) + } + + response := make([]byte, 1) + if _, err = io.ReadFull(conn, response); err != nil { + return nil, + NewErrorf(CodeBackendDown, "reading response to SSLRequest") + } + + if response[0] != pgAcceptSSLRequest { + return nil, NewErrorf( + CodeBackendRefusedTLS, "target server refused TLS connection", + ) + } + + outCfg := tlsConfig.Clone() + return tls.Client(conn, outCfg), nil +} + +// RelayStartupMsg forwards the start message on the backend connection. +func RelayStartupMsg(conn net.Conn, msg *pgproto3.StartupMessage) (err error) { + _, err = conn.Write(msg.Encode(nil)) + return +} diff --git a/pkg/ccl/sqlproxyccl/error.go b/pkg/ccl/sqlproxyccl/error.go index df4fc4d4237b..2c1c88e0f8c4 100644 --- a/pkg/ccl/sqlproxyccl/error.go +++ b/pkg/ccl/sqlproxyccl/error.go @@ -66,20 +66,26 @@ const ( // CodeExpiredClientConnection indicates that proxy connection to the client // has expired and should be closed. CodeExpiredClientConnection + + // CodeIdleDisconnect indicates that the connection was disconnected for + // being idle for longer than the specified timeout. + CodeIdleDisconnect ) -type codeError struct { +// CodeError is combines an error with one of the above codes to ease +// the processing of the errors. +type CodeError struct { code ErrorCode err error } -func (e *codeError) Error() string { +func (e *CodeError) Error() string { return fmt.Sprintf("%s: %s", e.code, e.err) } -// NewErrorf returns a new codeError out of the supplied args. +// NewErrorf returns a new CodeError out of the supplied args. func NewErrorf(code ErrorCode, format string, args ...interface{}) error { - return &codeError{ + return &CodeError{ code: code, err: errors.Errorf(format, args...), } diff --git a/pkg/ccl/sqlproxyccl/errorcode_string.go b/pkg/ccl/sqlproxyccl/errorcode_string.go index feb7c15b6d65..91eedaeef236 100644 --- a/pkg/ccl/sqlproxyccl/errorcode_string.go +++ b/pkg/ccl/sqlproxyccl/errorcode_string.go @@ -20,11 +20,12 @@ func _() { _ = x[CodeClientDisconnected-10] _ = x[CodeProxyRefusedConnection-11] _ = x[CodeExpiredClientConnection-12] + _ = x[CodeIdleDisconnect-13] } -const _ErrorCode_name = "CodeClientReadFailedCodeClientWriteFailedCodeUnexpectedInsecureStartupMessageCodeSNIRoutingFailedCodeUnexpectedStartupMessageCodeParamsRoutingFailedCodeBackendDownCodeBackendRefusedTLSCodeBackendDisconnectedCodeClientDisconnectedCodeProxyRefusedConnectionCodeExpiredClientConnection" +const _ErrorCode_name = "CodeClientReadFailedCodeClientWriteFailedCodeUnexpectedInsecureStartupMessageCodeSNIRoutingFailedCodeUnexpectedStartupMessageCodeParamsRoutingFailedCodeBackendDownCodeBackendRefusedTLSCodeBackendDisconnectedCodeClientDisconnectedCodeProxyRefusedConnectionCodeExpiredClientConnectionCodeIdleDisconnect" -var _ErrorCode_index = [...]uint16{0, 20, 41, 77, 97, 125, 148, 163, 184, 207, 229, 255, 282} +var _ErrorCode_index = [...]uint16{0, 20, 41, 77, 97, 125, 148, 163, 184, 207, 229, 255, 282, 300} func (i ErrorCode) String() string { i -= 1 diff --git a/pkg/ccl/sqlproxyccl/idle_disconnect_connection.go b/pkg/ccl/sqlproxyccl/idle_disconnect_connection.go new file mode 100644 index 000000000000..bc5fd1be9e41 --- /dev/null +++ b/pkg/ccl/sqlproxyccl/idle_disconnect_connection.go @@ -0,0 +1,97 @@ +// Copyright 2020 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package sqlproxyccl + +import ( + "net" + "time" + + "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/cockroachdb/errors" +) + +// IdleDisconnectConnection is a wrapper around net.Conn that disconnects if +// connection is idle. The idle time is only counted while the client is +// waiting, blocked on Read. +type IdleDisconnectConnection struct { + net.Conn + timeout time.Duration + mu struct { + syncutil.Mutex + lastDeadlineSetAt time.Time + } +} + +var errNotSupported = errors.Errorf( + "Not supported for IdleDisconnectConnection", +) + +func (c *IdleDisconnectConnection) updateDeadline() error { + now := timeutil.Now() + // If it has been more than 1% of the timeout duration - advance the deadline. + c.mu.Lock() + defer c.mu.Unlock() + if now.Sub(c.mu.lastDeadlineSetAt) > c.timeout/100 { + c.mu.lastDeadlineSetAt = now + + if err := c.Conn.SetReadDeadline(now.Add(c.timeout)); err != nil { + return err + } + } + return nil +} + +// Read reads data from the connection with timeout. +func (c *IdleDisconnectConnection) Read(b []byte) (n int, err error) { + if err := c.updateDeadline(); err != nil { + return 0, err + } + return c.Conn.Read(b) +} + +// Write writes data to the connection and sets the read timeout. +func (c *IdleDisconnectConnection) Write(b []byte) (n int, err error) { + // The Write for the connection is not blocking (or can block only temporary + // in case of flow control). For idle connections, the Read will be the call + // that will block and stay blocked until the backend doesn't send something. + // However, it is theoretically possible, that the traffic is only going in + // one direction - from the proxy to the backend, in which case we will call + // repeatedly Write but stay blocked on the Read. For that specific case - the + // write pushes further out the read deadline so the read doesn't timeout. + if err := c.updateDeadline(); err != nil { + return 0, err + } + return c.Conn.Write(b) +} + +// SetDeadline is unsupported as it will interfere with the reads. +func (c *IdleDisconnectConnection) SetDeadline(t time.Time) error { + return errNotSupported +} + +// SetReadDeadline is unsupported as it will interfere with the reads. +func (c *IdleDisconnectConnection) SetReadDeadline(t time.Time) error { + return errNotSupported +} + +// SetWriteDeadline is unsupported as it will interfere with the reads. +func (c *IdleDisconnectConnection) SetWriteDeadline(t time.Time) error { + return errNotSupported +} + +// IdleDisconnectOverlay upgrades the connection to one that closes when +// idle for more than timeout duration. Timeout of zero will turn off +// the idle disconnect code. +func IdleDisconnectOverlay(conn net.Conn, timeout time.Duration) net.Conn { + if timeout != 0 { + return &IdleDisconnectConnection{Conn: conn, timeout: timeout} + } + return conn +} diff --git a/pkg/ccl/sqlproxyccl/idle_disconnect_connection_test.go b/pkg/ccl/sqlproxyccl/idle_disconnect_connection_test.go new file mode 100644 index 000000000000..002901450914 --- /dev/null +++ b/pkg/ccl/sqlproxyccl/idle_disconnect_connection_test.go @@ -0,0 +1,137 @@ +// Copyright 2020 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package sqlproxyccl + +import ( + "fmt" + "io" + "net" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/require" +) + +func setupServerWithIdleDisconnect(t testing.TB, timeout time.Duration) net.Addr { + server, err := net.Listen("tcp", "") + require.NoError(t, err) + + // Server is echoing back bytes in an infinite loop. + go func() { + cServ, err := server.Accept() + if err != nil { + t.Errorf("Error during accept: %v", err) + } + defer cServ.Close() + cServ = IdleDisconnectOverlay(cServ, timeout) + _, _ = io.Copy(cServ, cServ) + }() + return server.Addr() +} + +func ping(conn net.Conn) error { + n, err := conn.Write([]byte{1}) + if err != nil { + return err + } + if n != 1 { + return errors.Newf("Expected 1 written byte but got %d", n) + } + n, err = conn.Read([]byte{1}) + if err != nil { + return err + } + if n != 1 { + return errors.Newf("Expected 1 read byte but got %d", n) + } + return nil +} + +func benchmarkSocketRead(timeout time.Duration, b *testing.B) { + addr := setupServerWithIdleDisconnect(b, timeout) + + cCli, err := net.Dial("tcp", addr.String()) + if err != nil { + b.Errorf("Error during dial: %v", err) + } + defer cCli.Close() + + bCli := []byte{1} + for i := 0; i < b.N; i++ { + _, err = cCli.Write(bCli) + if err != nil { + b.Errorf("Error during write: %v", err) + } + _, err = cCli.Read(bCli) + if err != nil { + b.Errorf("Error during read: %v", err) + } + } + _, err = cCli.Write([]byte{}) // This serves as EOF and shuts down the echo server + if err != nil { + b.Errorf("Error during read: %v", err) + } +} + +// No statistically significant difference in a single roundtrip time between +// using and not using deadline as implemented above. Both show the same value in my tests. +// SocketReadWithDeadline-32 11.1µs ± 1% +// SocketReadWithoutDeadline-32 11.0µs ± 3% +func BenchmarkSocketReadWithoutDeadline(b *testing.B) { + benchmarkSocketRead(0, b) +} +func BenchmarkSocketReadWithDeadline(b *testing.B) { + benchmarkSocketRead(1e8, b) +} + +func TestIdleDisconnectOverlay(t *testing.T) { + defer leaktest.AfterTest(t)() + tests := []struct { + timeout time.Duration + willCloseAtCheck int + }{ + // The disconnect checks are done at 0.2s, 0.7s and 1.5s marks. + {0, 0}, + {0.1e9, 1}, + {0.3e9, 2}, + {0.6e9, 3}, + {1e9, 0}, + } + + for _, tt := range tests { + name := fmt.Sprintf( + "timeout(%s)-willCloseAt(%d)", tt.timeout, tt.willCloseAtCheck, + ) + t.Run(name, func(t *testing.T) { + addr := setupServerWithIdleDisconnect(t, tt.timeout) + conn, err := net.Dial("tcp", addr.String()) + require.NoError(t, err, "Unable to dial server") + time.Sleep(.2e9) + if tt.willCloseAtCheck == 1 { + require.Error(t, ping(conn)) + return + } + require.NoError(t, ping(conn)) + time.Sleep(.5e9) + if tt.willCloseAtCheck == 2 { + require.Error(t, ping(conn)) + return + } + require.NoError(t, ping(conn)) + time.Sleep(.8e9) + if tt.willCloseAtCheck == 3 { + require.Error(t, ping(conn)) + return + } + require.NoError(t, ping(conn)) + }) + } +} diff --git a/pkg/ccl/sqlproxyccl/metrics.go b/pkg/ccl/sqlproxyccl/metrics.go index e8115bf66301..be6159f0b3e7 100644 --- a/pkg/ccl/sqlproxyccl/metrics.go +++ b/pkg/ccl/sqlproxyccl/metrics.go @@ -14,6 +14,7 @@ import "github.com/cockroachdb/cockroach/pkg/util/metric" // operations. type Metrics struct { BackendDisconnectCount *metric.Counter + IdleDisconnectCount *metric.Counter BackendDownCount *metric.Counter ClientDisconnectCount *metric.Counter CurConnCount *metric.Gauge @@ -53,10 +54,16 @@ var ( Measurement: "Disconnects", Unit: metric.Unit_COUNT, } + metaIdleDisconnectCount = metric.Metadata{ + Name: "proxy.err.idle_disconnect", + Help: "Number of disconnects due to idle timeout", + Measurement: "Idle Disconnects", + Unit: metric.Unit_COUNT, + } metaClientDisconnectCount = metric.Metadata{ Name: "proxy.err.client_disconnect", Help: "Number of disconnects initiated by clients", - Measurement: "Disconnects", + Measurement: "Client Disconnects", Unit: metric.Unit_COUNT, } metaRefusedConnCount = metric.Metadata{ @@ -83,6 +90,7 @@ var ( func MakeProxyMetrics() Metrics { return Metrics{ BackendDisconnectCount: metric.NewCounter(metaBackendDisconnectCount), + IdleDisconnectCount: metric.NewCounter(metaIdleDisconnectCount), BackendDownCount: metric.NewCounter(metaBackendDownCount), ClientDisconnectCount: metric.NewCounter(metaClientDisconnectCount), CurConnCount: metric.NewGauge(metaCurConnCount), diff --git a/pkg/ccl/sqlproxyccl/proxy.go b/pkg/ccl/sqlproxyccl/proxy.go index 26fbf60dd543..d7db698dada1 100644 --- a/pkg/ccl/sqlproxyccl/proxy.go +++ b/pkg/ccl/sqlproxyccl/proxy.go @@ -11,9 +11,9 @@ package sqlproxyccl import ( "context" "crypto/tls" - "encoding/binary" "io" "net" + "os" "github.com/cockroachdb/errors" "github.com/jackc/pgproto3/v2" @@ -26,6 +26,7 @@ var pgSSLRequest = []int32{8, 80877103} // BackendConfig contains the configuration of a backend connection that is // being proxied. +// To be removed once all clients are migrated to use backend dialer. type BackendConfig struct { // The address to which the connection is forwarded. OutgoingAddress string @@ -50,23 +51,30 @@ type BackendConfig struct { type Options struct { IncomingTLSConfig *tls.Config // config used for client -> proxy connection - // TODO(tbg): this is unimplemented and exists only to check which clients - // allow use of SNI. Should always return ("", nil). - BackendConfigFromSNI func(serverName string) (config *BackendConfig, clientErr error) // BackendFromParams returns the config to use for the proxy -> backend // connection. The TLS config is in it and it must have an appropriate // ServerName for the remote backend. + // Deprecated: processing of the params now happens in the BackendDialer. + // This is only here to support OnSuccess and KeepAlive. BackendConfigFromParams func( params map[string]string, incomingConn *Conn, ) (config *BackendConfig, clientErr error) // If set, consulted to modify the parameters set by the frontend before // forwarding them to the backend during startup. + // Deprecated: include the code that modifies the request params + // in the backend dialer. ModifyRequestParams func(map[string]string) // If set, consulted to decorate an error message to be sent to the client. // The error passed to this method will contain no internal information. OnSendErrToClient func(code ErrorCode, msg string) string + + // If set, will be used to establish and return connection to the backend. + // If not set, the old logic will be used. + // The argument is the startup message received from the frontend. It + // contains the protocol version and params sent by the client. + BackendDialer func(msg *pgproto3.StartupMessage) (net.Conn, error) } // Proxy takes an incoming client connection and relays it to a backend SQL @@ -76,9 +84,16 @@ func (s *Server) Proxy(proxyConn *Conn) error { if s.opts.OnSendErrToClient != nil { msg = s.opts.OnSendErrToClient(code, msg) } + + var pgCode string + if code == CodeIdleDisconnect { + pgCode = "57P01" // admin shutdown + } else { + pgCode = "08004" // rejected connection + } _, _ = conn.Write((&pgproto3.ErrorResponse{ Severity: "FATAL", - Code: "08004", // rejected connection + Code: pgCode, Message: msg, }).Encode(nil)) } @@ -88,6 +103,7 @@ func (s *Server) Proxy(proxyConn *Conn) error { // hence it's important to close `conn` rather than `proxyConn` since closing // the latter will not call `Close` method of `tls.Conn`. defer func() { _ = conn.Close() }() + var sniServerName string // If we have an incoming TLS Config, require that the client initiates // with a TLS connection. if s.opts.IncomingTLSConfig != nil { @@ -114,22 +130,11 @@ func (s *Server) Proxy(proxyConn *Conn) error { } cfg := s.opts.IncomingTLSConfig.Clone() - var sniServerName string + cfg.GetConfigForClient = func(h *tls.ClientHelloInfo) (*tls.Config, error) { sniServerName = h.ServerName return nil, nil } - if s.opts.BackendConfigFromSNI != nil { - cfg, clientErr := s.opts.BackendConfigFromSNI(sniServerName) - if clientErr != nil { - code := CodeSNIRoutingFailed - sendErrToClient(conn, code, clientErr.Error()) // won't actually be shown by most clients - return NewErrorf(code, "rejected by OutgoingAddrFromSNI") - } - if cfg.OutgoingAddress != "" { - return NewErrorf(CodeSNIRoutingFailed, "BackendConfigFromSNI is unimplemented") - } - } conn = tls.Server(conn, cfg) } @@ -142,73 +147,65 @@ func (s *Server) Proxy(proxyConn *Conn) error { return NewErrorf(CodeUnexpectedStartupMessage, "unsupported post-TLS startup message: %T", m) } - var backendConfig *BackendConfig - { - var clientErr error - backendConfig, clientErr = s.opts.BackendConfigFromParams(msg.Parameters, proxyConn) - if clientErr != nil { - var codeErr *codeError - if !errors.As(clientErr, &codeErr) { - codeErr = &codeError{ - code: CodeParamsRoutingFailed, - err: errors.Errorf("rejected by BackendConfigFromParams: %v", clientErr), + // Add the sniServerName (if used) as parameter + if sniServerName != "" { + msg.Parameters["sni-server"] = sniServerName + } + + backendDialer := s.opts.BackendDialer + if backendDialer == nil { + // This we need to keep until all the clients are switched to provide BackendDialer. + // It constructs a backend dialer from the information provided via + // BackendConfigFromParams function. + backendDialer = func(msg *pgproto3.StartupMessage) (net.Conn, error) { + backendConfig, clientErr := s.opts.BackendConfigFromParams(msg.Parameters, proxyConn) + if clientErr != nil { + var codeErr *CodeError + if !errors.As(clientErr, &codeErr) { + codeErr = &CodeError{ + code: CodeParamsRoutingFailed, + err: errors.Errorf("rejected by BackendConfigFromParams: %v", clientErr), + } } + return nil, codeErr } - if codeErr.code == CodeProxyRefusedConnection { - s.metrics.RefusedConnCount.Inc(1) - } else { - s.metrics.RoutingErrCount.Inc(1) + + // We should be able to remove this when the all clients switch to + // backend dialer. + if s.opts.ModifyRequestParams != nil { + s.opts.ModifyRequestParams(msg.Parameters) } - sendErrToClient(conn, codeErr.code, clientErr.Error()) - return codeErr + + crdbConn, err := BackendDial(msg, backendConfig.OutgoingAddress, backendConfig.TLSConf) + if err != nil { + return nil, err + } + + return crdbConn, nil } } - crdbConn, err := net.Dial("tcp", backendConfig.OutgoingAddress) + crdbConn, err := backendDialer(msg) if err != nil { s.metrics.BackendDownCount.Inc(1) - code := CodeBackendDown - sendErrToClient(conn, code, "unable to reach backend SQL server") - return NewErrorf(code, "dialing backend server: %v", err) - } - defer func() { _ = crdbConn.Close() }() - - if backendConfig.TLSConf != nil { - // Send SSLRequest. - if err := binary.Write(crdbConn, binary.BigEndian, pgSSLRequest); err != nil { - s.metrics.BackendDownCount.Inc(1) - return NewErrorf(CodeBackendDown, "sending SSLRequest to target server: %v", err) - } - - response := make([]byte, 1) - if _, err = io.ReadFull(crdbConn, response); err != nil { - s.metrics.BackendDownCount.Inc(1) - return NewErrorf(CodeBackendDown, "reading response to SSLRequest") + var codeErr *CodeError + if !errors.As(err, &codeErr) { + codeErr = &CodeError{ + code: CodeBackendDown, + err: errors.Errorf("unable to reach backend SQL server: %v", err), + } } - - if response[0] != pgAcceptSSLRequest { - s.metrics.BackendDownCount.Inc(1) - return NewErrorf(CodeBackendRefusedTLS, "target server refused TLS connection") + if codeErr.code == CodeProxyRefusedConnection { + s.metrics.RefusedConnCount.Inc(1) + } else if codeErr.code == CodeParamsRoutingFailed { + s.metrics.RoutingErrCount.Inc(1) } - - outCfg := backendConfig.TLSConf.Clone() - crdbConn = tls.Client(crdbConn, outCfg) - } - - if s.opts.ModifyRequestParams != nil { - s.opts.ModifyRequestParams(msg.Parameters) - } - - if _, err := crdbConn.Write(msg.Encode(nil)); err != nil { - s.metrics.BackendDownCount.Inc(1) - return NewErrorf(CodeBackendDown, "relaying StartupMessage to target server %v: %v", - backendConfig.OutgoingAddress, err) + sendErrToClient(conn, codeErr.code, codeErr.Error()) + return codeErr } + defer func() { _ = crdbConn.Close() }() s.metrics.SuccessfulConnCount.Inc(1) - if backendConfig.OnConnectionSuccess != nil { - backendConfig.OnConnectionSuccess() - } // These channels are buffered because we'll only consume one of them. errOutgoing := make(chan error, 1) @@ -217,6 +214,22 @@ func (s *Server) Proxy(proxyConn *Conn) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + if s.opts.BackendConfigFromParams != nil { + // Ignore the next error as we already did all checks in the BackendDialer + // so there shouldn't be any errors here. + // This is temporary until Spas moves processing of OnConnectionSuccess and + // KeepAliveLoop outside of the proxy. + backendConfig, _ := s.opts.BackendConfigFromParams(msg.Parameters, proxyConn) + if backendConfig.OnConnectionSuccess != nil { + backendConfig.OnConnectionSuccess() + } + if backendConfig.KeepAliveLoop != nil { + go func() { + errExpired <- backendConfig.KeepAliveLoop(ctx) + }() + } + } + go func() { _, err := io.Copy(crdbConn, conn) errOutgoing <- err @@ -225,11 +238,6 @@ func (s *Server) Proxy(proxyConn *Conn) error { _, err := io.Copy(conn, crdbConn) errIncoming <- err }() - if backendConfig.KeepAliveLoop != nil { - go func() { - errExpired <- backendConfig.KeepAliveLoop(ctx) - }() - } select { // NB: when using pgx, we see a nil errIncoming first on clean connection @@ -239,11 +247,21 @@ func (s *Server) Proxy(proxyConn *Conn) error { // client gets to close the connection once it's sent that message, // meaning either case is possible. case err := <-errIncoming: - if err != nil { + if err == nil { + return nil + } else if codeErr := (*CodeError)(nil); errors.As(err, &codeErr) && + codeErr.code == CodeExpiredClientConnection { + s.metrics.ExpiredClientConnCount.Inc(1) + sendErrToClient(conn, codeErr.code, codeErr.Error()) + return codeErr + } else if errors.Is(err, os.ErrDeadlineExceeded) { + s.metrics.IdleDisconnectCount.Inc(1) + sendErrToClient(conn, CodeIdleDisconnect, "terminating connection due to idle timeout") + return NewErrorf(CodeIdleDisconnect, "terminating connection due to idle timeout: %v", err) + } else { s.metrics.BackendDisconnectCount.Inc(1) return NewErrorf(CodeBackendDisconnected, "copying from target server to client: %s", err) } - return nil case err := <-errOutgoing: // The incoming connection got closed. if err != nil { diff --git a/pkg/ccl/sqlproxyccl/proxy_test.go b/pkg/ccl/sqlproxyccl/proxy_test.go index 63bafaf4c8dc..1c8c56b2a2ef 100644 --- a/pkg/ccl/sqlproxyccl/proxy_test.go +++ b/pkg/ccl/sqlproxyccl/proxy_test.go @@ -24,6 +24,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/errors" + "github.com/jackc/pgproto3/v2" "github.com/jackc/pgx/v4" "github.com/stretchr/testify/require" ) @@ -69,31 +70,33 @@ openssl req -new -x509 -sha256 -key testserver.key -out testserver.crt \ func testingTenantIDFromDatabaseForAddr( addr string, validTenant string, -) func(map[string]string, *Conn) (config *BackendConfig, clientErr error) { - return func(p map[string]string, _ *Conn) (config *BackendConfig, clientErr error) { +) func(msg *pgproto3.StartupMessage) (net.Conn, error) { + return func(msg *pgproto3.StartupMessage) (net.Conn, error) { const dbKey = "database" + p := msg.Parameters db, ok := p[dbKey] if !ok { - return nil, errors.Newf("need to specify database") + return nil, NewErrorf( + CodeParamsRoutingFailed, "need to specify database", + ) } sl := strings.SplitN(db, "_", 2) if len(sl) != 2 { - return nil, errors.Newf("malformed database name") + return nil, NewErrorf( + CodeParamsRoutingFailed, "malformed database name", + ) } db, tenantID := sl[0], sl[1] if tenantID != validTenant { - return nil, errors.Newf("invalid tenantID") + return nil, NewErrorf(CodeParamsRoutingFailed, "invalid tenantID") } p[dbKey] = db - return &BackendConfig{ - OutgoingAddress: addr, - TLSConf: &tls.Config{ - // NB: this would be false in production. - InsecureSkipVerify: true, - }, - }, nil + return BackendDial(msg, addr, &tls.Config{ + // NB: this would be false in production. + InsecureSkipVerify: true, + }) } } @@ -135,12 +138,9 @@ func TestLongDBName(t *testing.T) { ac := makeAssertCtx() - var m map[string]string opts := Options{ - BackendConfigFromParams: func( - mm map[string]string, _ *Conn) (config *BackendConfig, clientErr error) { - m = mm - return nil, errors.New("boom") + BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { + return nil, NewErrorf(CodeParamsRoutingFailed, "boom") }, OnSendErrToClient: ac.onSendErrToClient, } @@ -150,7 +150,6 @@ func TestLongDBName(t *testing.T) { longDB := strings.Repeat("x", 70) // 63 is limit pgurl := fmt.Sprintf("postgres://unused:unused@%s/%s", addr, longDB) ac.assertConnectErr(t, pgurl, "" /* suffix */, CodeParamsRoutingFailed, "boom") - require.Equal(t, longDB, m["database"]) require.Equal(t, int64(1), s.metrics.RoutingErrCount.Count()) } @@ -162,8 +161,10 @@ func TestFailedConnection(t *testing.T) { ac := makeAssertCtx() opts := Options{ - BackendConfigFromParams: testingTenantIDFromDatabaseForAddr("undialable%$!@$", "29"), - OnSendErrToClient: ac.onSendErrToClient, + BackendDialer: testingTenantIDFromDatabaseForAddr( + "undialable%$!@$", "29", + ), + OnSendErrToClient: ac.onSendErrToClient, } s, addr, done := setupTestProxyWithCerts(t, &opts) defer done() @@ -230,11 +231,14 @@ func TestProxyAgainstSecureCRDB(t *testing.T) { opts := Options{ BackendConfigFromParams: func(params map[string]string, _ *Conn) (*BackendConfig, error) { return &BackendConfig{ - OutgoingAddress: tc.Server(0).ServingSQLAddr(), - TLSConf: outgoingTLSConfig, OnConnectionSuccess: func() { connSuccess = true }, }, nil }, + BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { + return BackendDial( + msg, tc.Server(0).ServingSQLAddr(), outgoingTLSConfig, + ) + }, } s, addr, done := setupTestProxyWithCerts(t, &opts) defer done() @@ -275,11 +279,14 @@ func TestProxyTLSClose(t *testing.T) { BackendConfigFromParams: func(params map[string]string, conn *Conn) (*BackendConfig, error) { proxyIncomingConn.Store(conn) return &BackendConfig{ - OutgoingAddress: tc.Server(0).ServingSQLAddr(), - TLSConf: outgoingTLSConfig, OnConnectionSuccess: func() { connSuccess = true }, }, nil }, + BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { + return BackendDial( + msg, tc.Server(0).ServingSQLAddr(), outgoingTLSConfig, + ) + }, } s, addr, done := setupTestProxyWithCerts(t, &opts) defer done() @@ -315,13 +322,8 @@ func TestProxyModifyRequestParams(t *testing.T) { outgoingTLSConfig.InsecureSkipVerify = true opts := Options{ - BackendConfigFromParams: func(params map[string]string, _ *Conn) (*BackendConfig, error) { - return &BackendConfig{ - OutgoingAddress: tc.Server(0).ServingSQLAddr(), - TLSConf: outgoingTLSConfig, - }, nil - }, - ModifyRequestParams: func(params map[string]string) { + BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { + params := msg.Parameters require.EqualValues(t, map[string]string{ "authToken": "abc123", "user": "bogususer", @@ -331,6 +333,8 @@ func TestProxyModifyRequestParams(t *testing.T) { // and the backend is changed to a user that actually exists. delete(params, "authToken") params["user"] = "root" + + return BackendDial(msg, tc.Server(0).ServingSQLAddr(), outgoingTLSConfig) }, } s, proxyAddr, done := setupTestProxyWithCerts(t, &opts) @@ -354,11 +358,8 @@ func newInsecureProxyServer( t *testing.T, outgoingAddr string, outgoingTLSConfig *tls.Config, ) (addr string, cleanup func()) { s := NewServer(Options{ - BackendConfigFromParams: func(params map[string]string, _ *Conn) (*BackendConfig, error) { - return &BackendConfig{ - OutgoingAddress: outgoingAddr, - TLSConf: outgoingTLSConfig, - }, nil + BackendDialer: func(message *pgproto3.StartupMessage) (net.Conn, error) { + return BackendDial(message, outgoingAddr, outgoingTLSConfig) }, }) const listenAddress = "127.0.0.1:0" @@ -444,11 +445,8 @@ func TestProxyRefuseConn(t *testing.T) { ac := makeAssertCtx() opts := Options{ - BackendConfigFromParams: func(params map[string]string, _ *Conn) (*BackendConfig, error) { - return &BackendConfig{ - OutgoingAddress: tc.Server(0).ServingSQLAddr(), - TLSConf: outgoingTLSConfig, - }, NewErrorf(CodeProxyRefusedConnection, "too many attempts") + BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { + return nil, NewErrorf(CodeProxyRefusedConnection, "too many attempts") }, OnSendErrToClient: ac.onSendErrToClient, } @@ -477,8 +475,6 @@ func TestProxyKeepAlive(t *testing.T) { opts := Options{ BackendConfigFromParams: func(params map[string]string, _ *Conn) (*BackendConfig, error) { return &BackendConfig{ - OutgoingAddress: tc.Server(0).ServingSQLAddr(), - TLSConf: outgoingTLSConfig, // Don't let connections last more than 100ms. KeepAliveLoop: func(ctx context.Context) error { t := timeutil.NewTimer() @@ -495,6 +491,9 @@ func TestProxyKeepAlive(t *testing.T) { }, }, nil }, + BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { + return BackendDial(msg, tc.Server(0).ServingSQLAddr(), outgoingTLSConfig) + }, } s, addr, done := setupTestProxyWithCerts(t, &opts) defer done() @@ -517,3 +516,52 @@ func TestProxyKeepAlive(t *testing.T) { "unexpected error received: %v", err, ) } + +func TestProxyAgainstSecureCRDBWithIdleTimeout(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) + defer tc.Stopper().Stop(ctx) + + outgoingTLSConfig, err := tc.Server(0).RPCContext().GetClientTLSConfig() + require.NoError(t, err) + outgoingTLSConfig.InsecureSkipVerify = true + + idleTimeout, _ := time.ParseDuration("0.5s") + var connSuccess bool + opts := Options{ + BackendConfigFromParams: func(params map[string]string, _ *Conn) (*BackendConfig, error) { + return &BackendConfig{ + OnConnectionSuccess: func() { connSuccess = true }, + }, nil + }, + BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { + conn, err := BackendDial(msg, tc.Server(0).ServingSQLAddr(), outgoingTLSConfig) + if err != nil { + return nil, err + } + return IdleDisconnectOverlay(conn, idleTimeout), nil + }, + } + s, addr, done := setupTestProxyWithCerts(t, &opts) + defer done() + + url := fmt.Sprintf("postgres://root:admin@%s/defaultdb_29?sslmode=require", addr) + conn, err := pgx.Connect(context.Background(), url) + require.NoError(t, err) + require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) + defer func() { + require.NoError(t, conn.Close(ctx)) + require.True(t, connSuccess) + require.Equal(t, int64(1), s.metrics.SuccessfulConnCount.Count()) + }() + + var n int + err = conn.QueryRow(context.Background(), "SELECT $1::int", 1).Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 1, n) + time.Sleep(idleTimeout * 2) + err = conn.QueryRow(context.Background(), "SELECT $1::int", 1).Scan(&n) + require.EqualError(t, err, "FATAL: terminating connection due to idle timeout (SQLSTATE 57P01)") +}