From 0a61e774794d8cd538d56cb9ecda301058b7c825 Mon Sep 17 00:00:00 2001 From: Darin Peshev Date: Wed, 9 Dec 2020 17:22:36 -0800 Subject: [PATCH] ccl/sqlproxyccl: idle connection timeout support Previously, the connection from the end user to the backend could be idle for any period of time without disconnecting. In certain cases, we want the ability for idle connections to disconnect when the idle time exceeds a predefined timeout. To address this, this patch adds tracking of how long the connection to the backend has been idle. If there is a timeout specified, the connection will be disconnected and an error message will be sent back to the end user. This PR also refactors the proxy configuration. Previously, the proxy was configured via the BackendConfigFromParams func passed by the proxy user. This is now deprecated and instead the proxy user has to pass a BackendDialer that does that. Release note: None --- pkg/ccl/cliccl/BUILD.bazel | 1 + pkg/ccl/cliccl/mtproxy.go | 20 +- pkg/ccl/sqlproxyccl/BUILD.bazel | 4 + pkg/ccl/sqlproxyccl/backend_dialer.go | 79 ++++++++ pkg/ccl/sqlproxyccl/error.go | 14 +- pkg/ccl/sqlproxyccl/errorcode_string.go | 5 +- .../sqlproxyccl/idle_disconnect_connection.go | 97 ++++++++++ .../idle_disconnect_connection_test.go | 137 ++++++++++++++ pkg/ccl/sqlproxyccl/metrics.go | 10 +- pkg/ccl/sqlproxyccl/proxy.go | 174 ++++++++++-------- pkg/ccl/sqlproxyccl/proxy_test.go | 134 +++++++++----- 11 files changed, 536 insertions(+), 139 deletions(-) create mode 100644 pkg/ccl/sqlproxyccl/backend_dialer.go create mode 100644 pkg/ccl/sqlproxyccl/idle_disconnect_connection.go create mode 100644 pkg/ccl/sqlproxyccl/idle_disconnect_connection_test.go diff --git a/pkg/ccl/cliccl/BUILD.bazel b/pkg/ccl/cliccl/BUILD.bazel index 217382fb7922..068d930478fa 100644 --- a/pkg/ccl/cliccl/BUILD.bazel +++ b/pkg/ccl/cliccl/BUILD.bazel @@ -40,6 +40,7 @@ go_library( "//vendor/github.com/cockroachdb/cmux", "//vendor/github.com/cockroachdb/errors", 
"//vendor/github.com/cockroachdb/errors/oserror", + "//vendor/github.com/jackc/pgproto3/v2:pgproto3", "//vendor/github.com/spf13/cobra", "//vendor/golang.org/x/sync/errgroup", ], diff --git a/pkg/ccl/cliccl/mtproxy.go b/pkg/ccl/cliccl/mtproxy.go index 6c13f2617195..db2ca4efbc9f 100644 --- a/pkg/ccl/cliccl/mtproxy.go +++ b/pkg/ccl/cliccl/mtproxy.go @@ -19,6 +19,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl" "github.com/cockroachdb/cockroach/pkg/cli" "github.com/cockroachdb/errors" + "github.com/jackc/pgproto3/v2" "github.com/spf13/cobra" "golang.org/x/sync/errgroup" ) @@ -148,22 +149,19 @@ Uuwb2FVdh76ZK0AVd3Jh3KJs4+hr2u9syHaa7UPKXTcZsFWlGwZuu6X5A+0SO0S2 IncomingTLSConfig: &tls.Config{ Certificates: []tls.Certificate{cer}, }, - BackendConfigFromParams: func( - params map[string]string, _ *sqlproxyccl.Conn, - ) (config *sqlproxyccl.BackendConfig, clientErr error) { + BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { + params := msg.Parameters const magic = "prancing-pony" - cfg := &sqlproxyccl.BackendConfig{ - OutgoingAddress: sqlProxyTargetAddr, - TLSConf: outgoingConf, - } if strings.HasPrefix(params["database"], magic+".") { params["database"] = params["database"][len(magic)+1:] - return cfg, nil + } else if params["options"] != "--cluster="+magic { + return nil, errors.Errorf("client failed to pass '%s' via database or options", magic) } - if params["options"] == "--cluster="+magic { - return cfg, nil + conn, err := sqlproxyccl.BackendDial(msg, sqlProxyTargetAddr, outgoingConf) + if err != nil { + return nil, err } - return nil, errors.Errorf("client failed to pass '%s' via database or options", magic) + return conn, nil }, }) diff --git a/pkg/ccl/sqlproxyccl/BUILD.bazel b/pkg/ccl/sqlproxyccl/BUILD.bazel index 54a6c51de926..53853c234cec 100644 --- a/pkg/ccl/sqlproxyccl/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/BUILD.bazel @@ -3,8 +3,10 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = 
"sqlproxyccl", srcs = [ + "backend_dialer.go", "error.go", "errorcode_string.go", + "idle_disconnect_connection.go", "metrics.go", "proxy.go", "server.go", @@ -26,6 +28,7 @@ go_library( go_test( name = "sqlproxyccl_test", srcs = [ + "idle_disconnect_connection_test.go", "main_test.go", "proxy_test.go", "server_test.go", @@ -43,6 +46,7 @@ go_test( "//pkg/util/randutil", "//pkg/util/timeutil", "//vendor/github.com/cockroachdb/errors", + "//vendor/github.com/jackc/pgproto3/v2:pgproto3", "//vendor/github.com/jackc/pgx/v4:pgx", "//vendor/github.com/stretchr/testify/require", ], diff --git a/pkg/ccl/sqlproxyccl/backend_dialer.go b/pkg/ccl/sqlproxyccl/backend_dialer.go new file mode 100644 index 000000000000..0e3aff5e4cbb --- /dev/null +++ b/pkg/ccl/sqlproxyccl/backend_dialer.go @@ -0,0 +1,79 @@ +// Copyright 2020 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package sqlproxyccl + +import ( + "crypto/tls" + "encoding/binary" + "io" + "net" + + "github.com/jackc/pgproto3/v2" +) + +// BackendDial is an example backend dialer that does a TCP/IP connection +// to a backend, SSL and forwards the start message. 
+func BackendDial( + msg *pgproto3.StartupMessage, outgoingAddress string, tlsConfig *tls.Config, +) (net.Conn, error) { + conn, err := net.Dial("tcp", outgoingAddress) + if err != nil { + return nil, NewErrorf( + CodeBackendDown, "unable to reach backend SQL server: %v", err, + ) + } + conn, err = SSLOverlay(conn, tlsConfig) + if err != nil { + return nil, err + } + err = RelayStartupMsg(conn, msg) + if err != nil { + return nil, NewErrorf( + CodeBackendDown, "relaying StartupMessage to target server %v: %v", + outgoingAddress, err) + } + return conn, nil +} + +// SSLOverlay attempts to upgrade the PG connection to use SSL +// if a tls.Config is specified.. +func SSLOverlay(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) { + if tlsConfig == nil { + return conn, nil + } + + var err error + // Send SSLRequest. + if err := binary.Write(conn, binary.BigEndian, pgSSLRequest); err != nil { + return nil, NewErrorf( + CodeBackendDown, "sending SSLRequest to target server: %v", err, + ) + } + + response := make([]byte, 1) + if _, err = io.ReadFull(conn, response); err != nil { + return nil, + NewErrorf(CodeBackendDown, "reading response to SSLRequest") + } + + if response[0] != pgAcceptSSLRequest { + return nil, NewErrorf( + CodeBackendRefusedTLS, "target server refused TLS connection", + ) + } + + outCfg := tlsConfig.Clone() + return tls.Client(conn, outCfg), nil +} + +// RelayStartupMsg forwards the start message on the backend connection. +func RelayStartupMsg(conn net.Conn, msg *pgproto3.StartupMessage) (err error) { + _, err = conn.Write(msg.Encode(nil)) + return +} diff --git a/pkg/ccl/sqlproxyccl/error.go b/pkg/ccl/sqlproxyccl/error.go index df4fc4d4237b..2c1c88e0f8c4 100644 --- a/pkg/ccl/sqlproxyccl/error.go +++ b/pkg/ccl/sqlproxyccl/error.go @@ -66,20 +66,26 @@ const ( // CodeExpiredClientConnection indicates that proxy connection to the client // has expired and should be closed. 
CodeExpiredClientConnection + + // CodeIdleDisconnect indicates that the connection was disconnected for + // being idle for longer than the specified timeout. + CodeIdleDisconnect ) -type codeError struct { +// CodeError combines an error with one of the above codes to ease +// the processing of the errors. +type CodeError struct { code ErrorCode err error } -func (e *codeError) Error() string { +func (e *CodeError) Error() string { return fmt.Sprintf("%s: %s", e.code, e.err) } -// NewErrorf returns a new codeError out of the supplied args. +// NewErrorf returns a new CodeError out of the supplied args. func NewErrorf(code ErrorCode, format string, args ...interface{}) error { - return &codeError{ + return &CodeError{ code: code, err: errors.Errorf(format, args...), } diff --git a/pkg/ccl/sqlproxyccl/errorcode_string.go b/pkg/ccl/sqlproxyccl/errorcode_string.go index feb7c15b6d65..91eedaeef236 100644 --- a/pkg/ccl/sqlproxyccl/errorcode_string.go +++ b/pkg/ccl/sqlproxyccl/errorcode_string.go @@ -20,11 +20,12 @@ func _() { _ = x[CodeClientDisconnected-10] _ = x[CodeProxyRefusedConnection-11] _ = x[CodeExpiredClientConnection-12] + _ = x[CodeIdleDisconnect-13] } -const _ErrorCode_name = "CodeClientReadFailedCodeClientWriteFailedCodeUnexpectedInsecureStartupMessageCodeSNIRoutingFailedCodeUnexpectedStartupMessageCodeParamsRoutingFailedCodeBackendDownCodeBackendRefusedTLSCodeBackendDisconnectedCodeClientDisconnectedCodeProxyRefusedConnectionCodeExpiredClientConnection" +const _ErrorCode_name = "CodeClientReadFailedCodeClientWriteFailedCodeUnexpectedInsecureStartupMessageCodeSNIRoutingFailedCodeUnexpectedStartupMessageCodeParamsRoutingFailedCodeBackendDownCodeBackendRefusedTLSCodeBackendDisconnectedCodeClientDisconnectedCodeProxyRefusedConnectionCodeExpiredClientConnectionCodeIdleDisconnect" -var _ErrorCode_index = [...]uint16{0, 20, 41, 77, 97, 125, 148, 163, 184, 207, 229, 255, 282} +var _ErrorCode_index = [...]uint16{0, 20, 41, 77, 97, 125, 148, 163, 184, 207, 229, 
255, 282, 300} func (i ErrorCode) String() string { i -= 1 diff --git a/pkg/ccl/sqlproxyccl/idle_disconnect_connection.go b/pkg/ccl/sqlproxyccl/idle_disconnect_connection.go new file mode 100644 index 000000000000..bc5fd1be9e41 --- /dev/null +++ b/pkg/ccl/sqlproxyccl/idle_disconnect_connection.go @@ -0,0 +1,97 @@ +// Copyright 2020 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package sqlproxyccl + +import ( + "net" + "time" + + "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" + "github.com/cockroachdb/errors" +) + +// IdleDisconnectConnection is a wrapper around net.Conn that disconnects if +// connection is idle. The idle time is only counted while the client is +// waiting, blocked on Read. +type IdleDisconnectConnection struct { + net.Conn + timeout time.Duration + mu struct { + syncutil.Mutex + lastDeadlineSetAt time.Time + } +} + +var errNotSupported = errors.Errorf( + "Not supported for IdleDisconnectConnection", +) + +func (c *IdleDisconnectConnection) updateDeadline() error { + now := timeutil.Now() + // If it has been more than 1% of the timeout duration - advance the deadline. + c.mu.Lock() + defer c.mu.Unlock() + if now.Sub(c.mu.lastDeadlineSetAt) > c.timeout/100 { + c.mu.lastDeadlineSetAt = now + + if err := c.Conn.SetReadDeadline(now.Add(c.timeout)); err != nil { + return err + } + } + return nil +} + +// Read reads data from the connection with timeout. +func (c *IdleDisconnectConnection) Read(b []byte) (n int, err error) { + if err := c.updateDeadline(); err != nil { + return 0, err + } + return c.Conn.Read(b) +} + +// Write writes data to the connection and sets the read timeout. 
+func (c *IdleDisconnectConnection) Write(b []byte) (n int, err error) { + // The Write for the connection is not blocking (or can block only temporarily + // in case of flow control). For idle connections, the Read will be the call + // that will block and stay blocked until the backend sends something. + // However, it is theoretically possible that the traffic is only going in + // one direction - from the proxy to the backend, in which case we will + // repeatedly call Write but stay blocked on the Read. For that specific case - the + // write pushes further out the read deadline so the read doesn't time out. + if err := c.updateDeadline(); err != nil { + return 0, err + } + return c.Conn.Write(b) +} + +// SetDeadline is unsupported as it will interfere with the reads. +func (c *IdleDisconnectConnection) SetDeadline(t time.Time) error { + return errNotSupported +} + +// SetReadDeadline is unsupported as it will interfere with the reads. +func (c *IdleDisconnectConnection) SetReadDeadline(t time.Time) error { + return errNotSupported +} + +// SetWriteDeadline is unsupported as it will interfere with the reads. +func (c *IdleDisconnectConnection) SetWriteDeadline(t time.Time) error { + return errNotSupported +} + +// IdleDisconnectOverlay upgrades the connection to one that closes when +// idle for more than timeout duration. Timeout of zero will turn off +// the idle disconnect code. +func IdleDisconnectOverlay(conn net.Conn, timeout time.Duration) net.Conn { + if timeout != 0 { + return &IdleDisconnectConnection{Conn: conn, timeout: timeout} + } + return conn +} diff --git a/pkg/ccl/sqlproxyccl/idle_disconnect_connection_test.go b/pkg/ccl/sqlproxyccl/idle_disconnect_connection_test.go new file mode 100644 index 000000000000..002901450914 --- /dev/null +++ b/pkg/ccl/sqlproxyccl/idle_disconnect_connection_test.go @@ -0,0 +1,137 @@ +// Copyright 2020 The Cockroach Authors. 
+// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package sqlproxyccl + +import ( + "fmt" + "io" + "net" + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/require" +) + +func setupServerWithIdleDisconnect(t testing.TB, timeout time.Duration) net.Addr { + server, err := net.Listen("tcp", "") + require.NoError(t, err) + + // Server is echoing back bytes in an infinite loop. + go func() { + cServ, err := server.Accept() + if err != nil { + t.Errorf("Error during accept: %v", err) + } + defer cServ.Close() + cServ = IdleDisconnectOverlay(cServ, timeout) + _, _ = io.Copy(cServ, cServ) + }() + return server.Addr() +} + +func ping(conn net.Conn) error { + n, err := conn.Write([]byte{1}) + if err != nil { + return err + } + if n != 1 { + return errors.Newf("Expected 1 written byte but got %d", n) + } + n, err = conn.Read([]byte{1}) + if err != nil { + return err + } + if n != 1 { + return errors.Newf("Expected 1 read byte but got %d", n) + } + return nil +} + +func benchmarkSocketRead(timeout time.Duration, b *testing.B) { + addr := setupServerWithIdleDisconnect(b, timeout) + + cCli, err := net.Dial("tcp", addr.String()) + if err != nil { + b.Errorf("Error during dial: %v", err) + } + defer cCli.Close() + + bCli := []byte{1} + for i := 0; i < b.N; i++ { + _, err = cCli.Write(bCli) + if err != nil { + b.Errorf("Error during write: %v", err) + } + _, err = cCli.Read(bCli) + if err != nil { + b.Errorf("Error during read: %v", err) + } + } + _, err = cCli.Write([]byte{}) // This serves as EOF and shuts down the echo server + if err != nil { + b.Errorf("Error during read: %v", err) + } +} + +// No statistically significant difference in a 
single roundtrip time between +// using and not using deadline as implemented above. Both show the same value in my tests. +// SocketReadWithDeadline-32 11.1µs ± 1% +// SocketReadWithoutDeadline-32 11.0µs ± 3% +func BenchmarkSocketReadWithoutDeadline(b *testing.B) { + benchmarkSocketRead(0, b) +} +func BenchmarkSocketReadWithDeadline(b *testing.B) { + benchmarkSocketRead(1e8, b) +} + +func TestIdleDisconnectOverlay(t *testing.T) { + defer leaktest.AfterTest(t)() + tests := []struct { + timeout time.Duration + willCloseAtCheck int + }{ + // The disconnect checks are done at 0.2s, 0.7s and 1.5s marks. + {0, 0}, + {0.1e9, 1}, + {0.3e9, 2}, + {0.6e9, 3}, + {1e9, 0}, + } + + for _, tt := range tests { + name := fmt.Sprintf( + "timeout(%s)-willCloseAt(%d)", tt.timeout, tt.willCloseAtCheck, + ) + t.Run(name, func(t *testing.T) { + addr := setupServerWithIdleDisconnect(t, tt.timeout) + conn, err := net.Dial("tcp", addr.String()) + require.NoError(t, err, "Unable to dial server") + time.Sleep(.2e9) + if tt.willCloseAtCheck == 1 { + require.Error(t, ping(conn)) + return + } + require.NoError(t, ping(conn)) + time.Sleep(.5e9) + if tt.willCloseAtCheck == 2 { + require.Error(t, ping(conn)) + return + } + require.NoError(t, ping(conn)) + time.Sleep(.8e9) + if tt.willCloseAtCheck == 3 { + require.Error(t, ping(conn)) + return + } + require.NoError(t, ping(conn)) + }) + } +} diff --git a/pkg/ccl/sqlproxyccl/metrics.go b/pkg/ccl/sqlproxyccl/metrics.go index e8115bf66301..be6159f0b3e7 100644 --- a/pkg/ccl/sqlproxyccl/metrics.go +++ b/pkg/ccl/sqlproxyccl/metrics.go @@ -14,6 +14,7 @@ import "github.com/cockroachdb/cockroach/pkg/util/metric" // operations. 
type Metrics struct { BackendDisconnectCount *metric.Counter + IdleDisconnectCount *metric.Counter BackendDownCount *metric.Counter ClientDisconnectCount *metric.Counter CurConnCount *metric.Gauge @@ -53,10 +54,16 @@ var ( Measurement: "Disconnects", Unit: metric.Unit_COUNT, } + metaIdleDisconnectCount = metric.Metadata{ + Name: "proxy.err.idle_disconnect", + Help: "Number of disconnects due to idle timeout", + Measurement: "Idle Disconnects", + Unit: metric.Unit_COUNT, + } metaClientDisconnectCount = metric.Metadata{ Name: "proxy.err.client_disconnect", Help: "Number of disconnects initiated by clients", - Measurement: "Disconnects", + Measurement: "Client Disconnects", Unit: metric.Unit_COUNT, } metaRefusedConnCount = metric.Metadata{ @@ -83,6 +90,7 @@ var ( func MakeProxyMetrics() Metrics { return Metrics{ BackendDisconnectCount: metric.NewCounter(metaBackendDisconnectCount), + IdleDisconnectCount: metric.NewCounter(metaIdleDisconnectCount), BackendDownCount: metric.NewCounter(metaBackendDownCount), ClientDisconnectCount: metric.NewCounter(metaClientDisconnectCount), CurConnCount: metric.NewGauge(metaCurConnCount), diff --git a/pkg/ccl/sqlproxyccl/proxy.go b/pkg/ccl/sqlproxyccl/proxy.go index 26fbf60dd543..d7db698dada1 100644 --- a/pkg/ccl/sqlproxyccl/proxy.go +++ b/pkg/ccl/sqlproxyccl/proxy.go @@ -11,9 +11,9 @@ package sqlproxyccl import ( "context" "crypto/tls" - "encoding/binary" "io" "net" + "os" "github.com/cockroachdb/errors" "github.com/jackc/pgproto3/v2" @@ -26,6 +26,7 @@ var pgSSLRequest = []int32{8, 80877103} // BackendConfig contains the configuration of a backend connection that is // being proxied. +// To be removed once all clients are migrated to use backend dialer. type BackendConfig struct { // The address to which the connection is forwarded. 
OutgoingAddress string @@ -50,23 +51,30 @@ type BackendConfig struct { type Options struct { IncomingTLSConfig *tls.Config // config used for client -> proxy connection - // TODO(tbg): this is unimplemented and exists only to check which clients - // allow use of SNI. Should always return ("", nil). - BackendConfigFromSNI func(serverName string) (config *BackendConfig, clientErr error) // BackendFromParams returns the config to use for the proxy -> backend // connection. The TLS config is in it and it must have an appropriate // ServerName for the remote backend. + // Deprecated: processing of the params now happens in the BackendDialer. + // This is only here to support OnSuccess and KeepAlive. BackendConfigFromParams func( params map[string]string, incomingConn *Conn, ) (config *BackendConfig, clientErr error) // If set, consulted to modify the parameters set by the frontend before // forwarding them to the backend during startup. + // Deprecated: include the code that modifies the request params + // in the backend dialer. ModifyRequestParams func(map[string]string) // If set, consulted to decorate an error message to be sent to the client. // The error passed to this method will contain no internal information. OnSendErrToClient func(code ErrorCode, msg string) string + + // If set, will be used to establish and return connection to the backend. + // If not set, the old logic will be used. + // The argument is the startup message received from the frontend. It + // contains the protocol version and params sent by the client. 
+ BackendDialer func(msg *pgproto3.StartupMessage) (net.Conn, error) } // Proxy takes an incoming client connection and relays it to a backend SQL @@ -76,9 +84,16 @@ func (s *Server) Proxy(proxyConn *Conn) error { if s.opts.OnSendErrToClient != nil { msg = s.opts.OnSendErrToClient(code, msg) } + + var pgCode string + if code == CodeIdleDisconnect { + pgCode = "57P01" // admin shutdown + } else { + pgCode = "08004" // rejected connection + } _, _ = conn.Write((&pgproto3.ErrorResponse{ Severity: "FATAL", - Code: "08004", // rejected connection + Code: pgCode, Message: msg, }).Encode(nil)) } @@ -88,6 +103,7 @@ func (s *Server) Proxy(proxyConn *Conn) error { // hence it's important to close `conn` rather than `proxyConn` since closing // the latter will not call `Close` method of `tls.Conn`. defer func() { _ = conn.Close() }() + var sniServerName string // If we have an incoming TLS Config, require that the client initiates // with a TLS connection. if s.opts.IncomingTLSConfig != nil { @@ -114,22 +130,11 @@ func (s *Server) Proxy(proxyConn *Conn) error { } cfg := s.opts.IncomingTLSConfig.Clone() - var sniServerName string + cfg.GetConfigForClient = func(h *tls.ClientHelloInfo) (*tls.Config, error) { sniServerName = h.ServerName return nil, nil } - if s.opts.BackendConfigFromSNI != nil { - cfg, clientErr := s.opts.BackendConfigFromSNI(sniServerName) - if clientErr != nil { - code := CodeSNIRoutingFailed - sendErrToClient(conn, code, clientErr.Error()) // won't actually be shown by most clients - return NewErrorf(code, "rejected by OutgoingAddrFromSNI") - } - if cfg.OutgoingAddress != "" { - return NewErrorf(CodeSNIRoutingFailed, "BackendConfigFromSNI is unimplemented") - } - } conn = tls.Server(conn, cfg) } @@ -142,73 +147,65 @@ func (s *Server) Proxy(proxyConn *Conn) error { return NewErrorf(CodeUnexpectedStartupMessage, "unsupported post-TLS startup message: %T", m) } - var backendConfig *BackendConfig - { - var clientErr error - backendConfig, clientErr = 
s.opts.BackendConfigFromParams(msg.Parameters, proxyConn) - if clientErr != nil { - var codeErr *codeError - if !errors.As(clientErr, &codeErr) { - codeErr = &codeError{ - code: CodeParamsRoutingFailed, - err: errors.Errorf("rejected by BackendConfigFromParams: %v", clientErr), + // Add the sniServerName (if used) as a parameter + if sniServerName != "" { + msg.Parameters["sni-server"] = sniServerName + } + + backendDialer := s.opts.BackendDialer + if backendDialer == nil { + // We need to keep this until all the clients are switched to provide BackendDialer. + // It constructs a backend dialer from the information provided via the + // BackendConfigFromParams function. + backendDialer = func(msg *pgproto3.StartupMessage) (net.Conn, error) { + backendConfig, clientErr := s.opts.BackendConfigFromParams(msg.Parameters, proxyConn) + if clientErr != nil { + var codeErr *CodeError + if !errors.As(clientErr, &codeErr) { + codeErr = &CodeError{ + code: CodeParamsRoutingFailed, + err: errors.Errorf("rejected by BackendConfigFromParams: %v", clientErr), + } } + return nil, codeErr } - if codeErr.code == CodeProxyRefusedConnection { - s.metrics.RefusedConnCount.Inc(1) - } else { - s.metrics.RoutingErrCount.Inc(1) + + // We should be able to remove this when all the clients switch to the + // backend dialer. 
+ if s.opts.ModifyRequestParams != nil { + s.opts.ModifyRequestParams(msg.Parameters) } - sendErrToClient(conn, codeErr.code, clientErr.Error()) - return codeErr + + crdbConn, err := BackendDial(msg, backendConfig.OutgoingAddress, backendConfig.TLSConf) + if err != nil { + return nil, err + } + + return crdbConn, nil } } - crdbConn, err := net.Dial("tcp", backendConfig.OutgoingAddress) + crdbConn, err := backendDialer(msg) if err != nil { s.metrics.BackendDownCount.Inc(1) - code := CodeBackendDown - sendErrToClient(conn, code, "unable to reach backend SQL server") - return NewErrorf(code, "dialing backend server: %v", err) - } - defer func() { _ = crdbConn.Close() }() - - if backendConfig.TLSConf != nil { - // Send SSLRequest. - if err := binary.Write(crdbConn, binary.BigEndian, pgSSLRequest); err != nil { - s.metrics.BackendDownCount.Inc(1) - return NewErrorf(CodeBackendDown, "sending SSLRequest to target server: %v", err) - } - - response := make([]byte, 1) - if _, err = io.ReadFull(crdbConn, response); err != nil { - s.metrics.BackendDownCount.Inc(1) - return NewErrorf(CodeBackendDown, "reading response to SSLRequest") + var codeErr *CodeError + if !errors.As(err, &codeErr) { + codeErr = &CodeError{ + code: CodeBackendDown, + err: errors.Errorf("unable to reach backend SQL server: %v", err), + } } - - if response[0] != pgAcceptSSLRequest { - s.metrics.BackendDownCount.Inc(1) - return NewErrorf(CodeBackendRefusedTLS, "target server refused TLS connection") + if codeErr.code == CodeProxyRefusedConnection { + s.metrics.RefusedConnCount.Inc(1) + } else if codeErr.code == CodeParamsRoutingFailed { + s.metrics.RoutingErrCount.Inc(1) } - - outCfg := backendConfig.TLSConf.Clone() - crdbConn = tls.Client(crdbConn, outCfg) - } - - if s.opts.ModifyRequestParams != nil { - s.opts.ModifyRequestParams(msg.Parameters) - } - - if _, err := crdbConn.Write(msg.Encode(nil)); err != nil { - s.metrics.BackendDownCount.Inc(1) - return NewErrorf(CodeBackendDown, "relaying 
StartupMessage to target server %v: %v", - backendConfig.OutgoingAddress, err) + sendErrToClient(conn, codeErr.code, codeErr.Error()) + return codeErr } + defer func() { _ = crdbConn.Close() }() s.metrics.SuccessfulConnCount.Inc(1) - if backendConfig.OnConnectionSuccess != nil { - backendConfig.OnConnectionSuccess() - } // These channels are buffered because we'll only consume one of them. errOutgoing := make(chan error, 1) @@ -217,6 +214,22 @@ func (s *Server) Proxy(proxyConn *Conn) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + if s.opts.BackendConfigFromParams != nil { + // Ignore the next error as we already did all checks in the BackendDialer + // so there shouldn't be any errors here. + // This is temporary until Spas moves processing of OnConnectionSuccess and + // KeepAliveLoop outside of the proxy. + backendConfig, _ := s.opts.BackendConfigFromParams(msg.Parameters, proxyConn) + if backendConfig.OnConnectionSuccess != nil { + backendConfig.OnConnectionSuccess() + } + if backendConfig.KeepAliveLoop != nil { + go func() { + errExpired <- backendConfig.KeepAliveLoop(ctx) + }() + } + } + go func() { _, err := io.Copy(crdbConn, conn) errOutgoing <- err @@ -225,11 +238,6 @@ func (s *Server) Proxy(proxyConn *Conn) error { _, err := io.Copy(conn, crdbConn) errIncoming <- err }() - if backendConfig.KeepAliveLoop != nil { - go func() { - errExpired <- backendConfig.KeepAliveLoop(ctx) - }() - } select { // NB: when using pgx, we see a nil errIncoming first on clean connection @@ -239,11 +247,21 @@ func (s *Server) Proxy(proxyConn *Conn) error { // client gets to close the connection once it's sent that message, // meaning either case is possible. 
case err := <-errIncoming: - if err != nil { + if err == nil { + return nil + } else if codeErr := (*CodeError)(nil); errors.As(err, &codeErr) && + codeErr.code == CodeExpiredClientConnection { + s.metrics.ExpiredClientConnCount.Inc(1) + sendErrToClient(conn, codeErr.code, codeErr.Error()) + return codeErr + } else if errors.Is(err, os.ErrDeadlineExceeded) { + s.metrics.IdleDisconnectCount.Inc(1) + sendErrToClient(conn, CodeIdleDisconnect, "terminating connection due to idle timeout") + return NewErrorf(CodeIdleDisconnect, "terminating connection due to idle timeout: %v", err) + } else { s.metrics.BackendDisconnectCount.Inc(1) return NewErrorf(CodeBackendDisconnected, "copying from target server to client: %s", err) } - return nil case err := <-errOutgoing: // The incoming connection got closed. if err != nil { diff --git a/pkg/ccl/sqlproxyccl/proxy_test.go b/pkg/ccl/sqlproxyccl/proxy_test.go index 63bafaf4c8dc..1c8c56b2a2ef 100644 --- a/pkg/ccl/sqlproxyccl/proxy_test.go +++ b/pkg/ccl/sqlproxyccl/proxy_test.go @@ -24,6 +24,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/errors" + "github.com/jackc/pgproto3/v2" "github.com/jackc/pgx/v4" "github.com/stretchr/testify/require" ) @@ -69,31 +70,33 @@ openssl req -new -x509 -sha256 -key testserver.key -out testserver.crt \ func testingTenantIDFromDatabaseForAddr( addr string, validTenant string, -) func(map[string]string, *Conn) (config *BackendConfig, clientErr error) { - return func(p map[string]string, _ *Conn) (config *BackendConfig, clientErr error) { +) func(msg *pgproto3.StartupMessage) (net.Conn, error) { + return func(msg *pgproto3.StartupMessage) (net.Conn, error) { const dbKey = "database" + p := msg.Parameters db, ok := p[dbKey] if !ok { - return nil, errors.Newf("need to specify database") + return nil, NewErrorf( + CodeParamsRoutingFailed, "need to specify database", + ) } sl := strings.SplitN(db, "_", 2) if 
len(sl) != 2 { - return nil, errors.Newf("malformed database name") + return nil, NewErrorf( + CodeParamsRoutingFailed, "malformed database name", + ) } db, tenantID := sl[0], sl[1] if tenantID != validTenant { - return nil, errors.Newf("invalid tenantID") + return nil, NewErrorf(CodeParamsRoutingFailed, "invalid tenantID") } p[dbKey] = db - return &BackendConfig{ - OutgoingAddress: addr, - TLSConf: &tls.Config{ - // NB: this would be false in production. - InsecureSkipVerify: true, - }, - }, nil + return BackendDial(msg, addr, &tls.Config{ + // NB: this would be false in production. + InsecureSkipVerify: true, + }) } } @@ -135,12 +138,9 @@ func TestLongDBName(t *testing.T) { ac := makeAssertCtx() - var m map[string]string opts := Options{ - BackendConfigFromParams: func( - mm map[string]string, _ *Conn) (config *BackendConfig, clientErr error) { - m = mm - return nil, errors.New("boom") + BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { + return nil, NewErrorf(CodeParamsRoutingFailed, "boom") }, OnSendErrToClient: ac.onSendErrToClient, } @@ -150,7 +150,6 @@ func TestLongDBName(t *testing.T) { longDB := strings.Repeat("x", 70) // 63 is limit pgurl := fmt.Sprintf("postgres://unused:unused@%s/%s", addr, longDB) ac.assertConnectErr(t, pgurl, "" /* suffix */, CodeParamsRoutingFailed, "boom") - require.Equal(t, longDB, m["database"]) require.Equal(t, int64(1), s.metrics.RoutingErrCount.Count()) } @@ -162,8 +161,10 @@ func TestFailedConnection(t *testing.T) { ac := makeAssertCtx() opts := Options{ - BackendConfigFromParams: testingTenantIDFromDatabaseForAddr("undialable%$!@$", "29"), - OnSendErrToClient: ac.onSendErrToClient, + BackendDialer: testingTenantIDFromDatabaseForAddr( + "undialable%$!@$", "29", + ), + OnSendErrToClient: ac.onSendErrToClient, } s, addr, done := setupTestProxyWithCerts(t, &opts) defer done() @@ -230,11 +231,14 @@ func TestProxyAgainstSecureCRDB(t *testing.T) { opts := Options{ BackendConfigFromParams: func(params 
map[string]string, _ *Conn) (*BackendConfig, error) { return &BackendConfig{ - OutgoingAddress: tc.Server(0).ServingSQLAddr(), - TLSConf: outgoingTLSConfig, OnConnectionSuccess: func() { connSuccess = true }, }, nil }, + BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { + return BackendDial( + msg, tc.Server(0).ServingSQLAddr(), outgoingTLSConfig, + ) + }, } s, addr, done := setupTestProxyWithCerts(t, &opts) defer done() @@ -275,11 +279,14 @@ func TestProxyTLSClose(t *testing.T) { BackendConfigFromParams: func(params map[string]string, conn *Conn) (*BackendConfig, error) { proxyIncomingConn.Store(conn) return &BackendConfig{ - OutgoingAddress: tc.Server(0).ServingSQLAddr(), - TLSConf: outgoingTLSConfig, OnConnectionSuccess: func() { connSuccess = true }, }, nil }, + BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { + return BackendDial( + msg, tc.Server(0).ServingSQLAddr(), outgoingTLSConfig, + ) + }, } s, addr, done := setupTestProxyWithCerts(t, &opts) defer done() @@ -315,13 +322,8 @@ func TestProxyModifyRequestParams(t *testing.T) { outgoingTLSConfig.InsecureSkipVerify = true opts := Options{ - BackendConfigFromParams: func(params map[string]string, _ *Conn) (*BackendConfig, error) { - return &BackendConfig{ - OutgoingAddress: tc.Server(0).ServingSQLAddr(), - TLSConf: outgoingTLSConfig, - }, nil - }, - ModifyRequestParams: func(params map[string]string) { + BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { + params := msg.Parameters require.EqualValues(t, map[string]string{ "authToken": "abc123", "user": "bogususer", @@ -331,6 +333,8 @@ func TestProxyModifyRequestParams(t *testing.T) { // and the backend is changed to a user that actually exists. 
delete(params, "authToken") params["user"] = "root" + + return BackendDial(msg, tc.Server(0).ServingSQLAddr(), outgoingTLSConfig) }, } s, proxyAddr, done := setupTestProxyWithCerts(t, &opts) @@ -354,11 +358,8 @@ func newInsecureProxyServer( t *testing.T, outgoingAddr string, outgoingTLSConfig *tls.Config, ) (addr string, cleanup func()) { s := NewServer(Options{ - BackendConfigFromParams: func(params map[string]string, _ *Conn) (*BackendConfig, error) { - return &BackendConfig{ - OutgoingAddress: outgoingAddr, - TLSConf: outgoingTLSConfig, - }, nil + BackendDialer: func(message *pgproto3.StartupMessage) (net.Conn, error) { + return BackendDial(message, outgoingAddr, outgoingTLSConfig) }, }) const listenAddress = "127.0.0.1:0" @@ -444,11 +445,8 @@ func TestProxyRefuseConn(t *testing.T) { ac := makeAssertCtx() opts := Options{ - BackendConfigFromParams: func(params map[string]string, _ *Conn) (*BackendConfig, error) { - return &BackendConfig{ - OutgoingAddress: tc.Server(0).ServingSQLAddr(), - TLSConf: outgoingTLSConfig, - }, NewErrorf(CodeProxyRefusedConnection, "too many attempts") + BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { + return nil, NewErrorf(CodeProxyRefusedConnection, "too many attempts") }, OnSendErrToClient: ac.onSendErrToClient, } @@ -477,8 +475,6 @@ func TestProxyKeepAlive(t *testing.T) { opts := Options{ BackendConfigFromParams: func(params map[string]string, _ *Conn) (*BackendConfig, error) { return &BackendConfig{ - OutgoingAddress: tc.Server(0).ServingSQLAddr(), - TLSConf: outgoingTLSConfig, // Don't let connections last more than 100ms. 
KeepAliveLoop: func(ctx context.Context) error { t := timeutil.NewTimer() @@ -495,6 +491,9 @@ func TestProxyKeepAlive(t *testing.T) { }, }, nil }, + BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { + return BackendDial(msg, tc.Server(0).ServingSQLAddr(), outgoingTLSConfig) + }, } s, addr, done := setupTestProxyWithCerts(t, &opts) defer done() @@ -517,3 +516,52 @@ func TestProxyKeepAlive(t *testing.T) { "unexpected error received: %v", err, ) } + +func TestProxyAgainstSecureCRDBWithIdleTimeout(t *testing.T) { + defer leaktest.AfterTest(t)() + + ctx := context.Background() + tc := serverutils.StartNewTestCluster(t, 1, base.TestClusterArgs{}) + defer tc.Stopper().Stop(ctx) + + outgoingTLSConfig, err := tc.Server(0).RPCContext().GetClientTLSConfig() + require.NoError(t, err) + outgoingTLSConfig.InsecureSkipVerify = true + + idleTimeout, _ := time.ParseDuration("0.5s") + var connSuccess bool + opts := Options{ + BackendConfigFromParams: func(params map[string]string, _ *Conn) (*BackendConfig, error) { + return &BackendConfig{ + OnConnectionSuccess: func() { connSuccess = true }, + }, nil + }, + BackendDialer: func(msg *pgproto3.StartupMessage) (net.Conn, error) { + conn, err := BackendDial(msg, tc.Server(0).ServingSQLAddr(), outgoingTLSConfig) + if err != nil { + return nil, err + } + return IdleDisconnectOverlay(conn, idleTimeout), nil + }, + } + s, addr, done := setupTestProxyWithCerts(t, &opts) + defer done() + + url := fmt.Sprintf("postgres://root:admin@%s/defaultdb_29?sslmode=require", addr) + conn, err := pgx.Connect(context.Background(), url) + require.NoError(t, err) + require.Equal(t, int64(1), s.metrics.CurConnCount.Value()) + defer func() { + require.NoError(t, conn.Close(ctx)) + require.True(t, connSuccess) + require.Equal(t, int64(1), s.metrics.SuccessfulConnCount.Count()) + }() + + var n int + err = conn.QueryRow(context.Background(), "SELECT $1::int", 1).Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 1, n) + 
time.Sleep(idleTimeout * 2) + err = conn.QueryRow(context.Background(), "SELECT $1::int", 1).Scan(&n) + require.EqualError(t, err, "FATAL: terminating connection due to idle timeout (SQLSTATE 57P01)") +}