From cd6e8dd554132e0eb46cb27fe50344dc16cfbb2f Mon Sep 17 00:00:00 2001 From: Jay Date: Mon, 28 Feb 2022 01:13:03 -0500 Subject: [PATCH] ccl/sqlproxyccl: complete connection migration support in the forwarder Informs cockroachdb#76000. Builds on top of #77109 and #77111. This commit completes the connection migration feature in the the forwarder within sqlproxy. The idea is as described in the RFC. A couple of new sqlproxy metrics have been added as well: - proxy.conn_migration.requested - proxy.conn_migration.accepted - proxy.conn_migration.rejected - proxy.conn_migration.success - proxy.conn_migration.error - proxy.conn_migration.fail - proxy.conn_migration.timeout_closed - proxy.conn_migration.timeout_recoverable For more details, see metrics.go in the sqlproxyccl package. Release note: None --- pkg/ccl/sqlproxyccl/BUILD.bazel | 3 + pkg/ccl/sqlproxyccl/forwarder.go | 642 ++++++++++++++++++++-- pkg/ccl/sqlproxyccl/forwarder_test.go | 556 ++++++++++++++++++- pkg/ccl/sqlproxyccl/metrics.go | 96 +++- pkg/ccl/sqlproxyccl/proxy_handler.go | 17 +- pkg/ccl/sqlproxyccl/proxy_handler_test.go | 365 ++++++++++++ 6 files changed, 1622 insertions(+), 57 deletions(-) diff --git a/pkg/ccl/sqlproxyccl/BUILD.bazel b/pkg/ccl/sqlproxyccl/BUILD.bazel index 935dd831bdb1..af8d1643c11d 100644 --- a/pkg/ccl/sqlproxyccl/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/BUILD.bazel @@ -41,6 +41,7 @@ go_library( "//pkg/util/stop", "//pkg/util/syncutil", "//pkg/util/timeutil", + "//pkg/util/uuid", "@com_github_cockroachdb_errors//:errors", "@com_github_cockroachdb_logtags//:logtags", "@com_github_jackc_pgproto3_v2//:pgproto3", @@ -81,6 +82,7 @@ go_test( "//pkg/sql/pgwire", "//pkg/sql/pgwire/pgerror", "//pkg/sql/pgwire/pgwirebase", + "//pkg/sql/tests", "//pkg/testutils", "//pkg/testutils/serverutils", "//pkg/testutils/skip", @@ -93,6 +95,7 @@ go_test( "//pkg/util/stop", "//pkg/util/syncutil", "//pkg/util/timeutil", + "@com_github_cockroachdb_cockroach_go_v2//crdb", "@com_github_cockroachdb_errors//:errors", "@com_github_jackc_pgconn//:pgconn", "@com_github_jackc_pgproto3_v2//:pgproto3", diff --git a/pkg/ccl/sqlproxyccl/forwarder.go b/pkg/ccl/sqlproxyccl/forwarder.go index 26af5b589fb4..6120286edf01 100644 --- a/pkg/ccl/sqlproxyccl/forwarder.go +++ b/pkg/ccl/sqlproxyccl/forwarder.go @@ -11,30 +11,86 @@ package sqlproxyccl import ( "context" "net" + "time" "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/interceptor" "github.com/cockroachdb/cockroach/pkg/sql/pgwire" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase" + "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/cockroach/pkg/util/uuid" "github.com/cockroachdb/errors" ) +const ( + // stateReady represents the state where the forwarder is ready to forward + // packets from the client to the server (and vice-versa). + stateReady int = iota + // stateTransferRequested represents the state where a session transfer was + // requested. + stateTransferRequested + // stateTransferInProgress represents the state where the session transfer + // is in-progress, and all incoming pgwire messages are buffered in the + // kernel's socket buffer. + stateTransferInProgress +) + +// transferTimeout corresponds to the timeout while waiting for the transfer +// state response. If this gets triggered, the transfer is aborted, and the +// connection will be terminated. +const defaultTransferTimeout = 15 * time.Second + +// clientMsgAny is used to denote a wildcard client message type. 
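+// The zero byte is not a valid pgwire client message type (real types are
+// printable ASCII tags such as 'Q' or 'S'), so it can safely act as the
+// sentinel for "no message sent yet".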
+var clientMsgAny = pgwirebase.ClientMessageType(0) + +var ( + // errReadAbortedDueToTransfer is returned whenever a Read call exits due to + // a session transfer. + errReadAbortedDueToTransfer = errors.New("read aborted due to transfer") + + // errTransferTimeout denotes that a transfer process has timed out where + // the forwarder wasn't able to locate the right transfer state response in + // time + errTransferTimeout = errors.New("transfer timeout") + + // errTransferProtocol indicates that an invariant has failed. + errTransferProtocol = errors.New("transfer protocol error") +) + // forwarder is used to forward pgwire messages from the client to the server, -// and vice-versa. At the moment, this does a direct proxying, and there is -// no intercepting. Once https://github.com/cockroachdb/cockroach/issues/76000 -// has been addressed, we will start intercepting pgwire messages at their -// boundaries here. +// and vice-versa. The forwarder instance should always be constructed through +// the forward function, which also starts the forwarder. // -// The forwarder instance should always be constructed through the forward -// function, which also starts the forwarder. +// The forwarder always starts with the ready state, which means that all +// messages from the client are forwarded to the server, and vice-versa. When +// a connection migration is requested through RequestTransfer, the forwarder +// transitions to the transferRequested state. If we are safe to transfer, +// the forwarder will transition to the transferInProgress state. If we are not, +// the forwarder aborts the transfer and transitions back to the ready state. +// Once the transfer process completes, the forwarder goes back to the ready +// state. At any point during the transfer process, we may also transition back +// to the ready state if the connection is deemed recoverable. type forwarder struct { // ctx is a single context used to control all goroutines spawned by the // forwarder. ctx context.Context ctxCancel context.CancelFunc + // connector is an instance of the connector, which will be used to open a + // new connection to a SQL pod. This connector instance must be associated + // to the same tenant as the forwarder. + connector *connector + + // metrics contains various counters reflecting proxy operations. This is + // the same as the metrics field in the proxyHandler instance. + metrics *metrics + // serverConn is only set after the authentication phase for the initial // connection. In the context of a connection migration, serverConn is only // replaced once the session has successfully been deserialized, and the - // old connection will be closed. + // old connection will be closed. Whenever serverConn gets updated, both + // clientMessageTypeSent and isServerMsgReadyReceived fields have to reset + // to their initial values. // // All reads from these connections must go through the interceptors. It is // not safe to read from these directly as the interceptors may have @@ -48,18 +104,85 @@ type forwarder struct { // // These interceptors have to match clientConn and serverConn. See comment // above on when those fields will be updated. - // - // TODO(jaylim-crl): Add updater functions that sets both conn and - // interceptor fields at the same time. At the moment, there's no use case - // besides the forward function. When connection migration happens, we - // will need to create a new serverInterceptor. We should remember to close - // old serverConn as well. 
- clientInterceptor *interceptor.BackendInterceptor // clientConn -> serverConn - serverInterceptor *interceptor.FrontendInterceptor // serverConn -> clientConn - - // errChan is a buffered channel that contains the first forwarder error. + clientInterceptor *interceptor.BackendInterceptor // clientConn's reader + serverInterceptor *interceptor.FrontendInterceptor // serverConn's reader + + // errCh is a buffered channel that contains the first forwarder error. // This channel may receive nil errors. - errChan chan error + errCh chan error + + // mu contains state protected by the forwarder's mutex. This is necessary + // since fields will be read and write from different goroutines. + mu struct { + syncutil.Mutex + + // state represents the forwarder's state. Most of the time, this will + // be stateReady. + state int + + // isServerMsgReadyReceived denotes whether a ReadyForQuery message has + // been received by the server-to-client processor *after* a message has + // been sent to the server through a Write on serverConn, either directly + // or through ForwardMsg. + // + // This will be initialized to true to implicitly denote that the server + // is ready to accept queries. + isServerMsgReadyReceived bool + + // transferKey is a unique string used to identify the transfer request, + // and will be passed into the SHOW TRANSFER STATE statement. This will + // be set to a randomly generated UUID whenever the transfer is + // requested through the RequestTransfer API, and back to an empty + // string whenever the transfer completes successfully or with a + // recoverable error. + transferKey string + + // transferCloserCh is a channel that must be set **before** + // transitioning to the transferRequested state, and this must be closed + // whenever the forwarder transitions back to the ready state, which + // signifies that the transfer process has completed successfully. + // Closing this will unblock the client-to-server processor and stop the + // timeout handler. + transferCloserCh chan struct{} + + // transferCtx has to be derived from ctx, and is created by the + // timeout handler. All transfer related operations will use this so + // that they can react to the timeout handler when that gets triggered. + transferCtx context.Context + + // transferConnRecoverable denotes whether the connection is recoverable + // during the transfer phase. This will be used by the timeout handler + // to determine whether it should close the forwarder. + transferConnRecoverable bool + } + + // ------------------------------------------------------------------------ + // The following fields are used for connection migration. + // + // For details on how connection migration works, read the following RFC: + // https://github.com/cockroachdb/cockroach/pull/75707. + // ------------------------------------------------------------------------ + + // disableClientInterrupts denotes that clientConn should not be interrupted + // by the custom readTimeoutConn that is wrapping the original clientConn. + // This is false by default. + disableClientInterrupts bool + + // clientMessageTypeSent indicates the message type for the last pgwire + // message sent to serverConn. This is used to determine a safe transfer + // point. + // + // If no message has been sent to serverConn by this forwarder, this will be + // clientMsgAny. + clientMessageTypeSent pgwirebase.ClientMessageType + + // Knobs used for testing. 
+ testingKnobs struct { + isSafeTransferPoint func() bool + transferTimeoutDuration func() time.Duration + onTransferTimeoutHandlerStart func() + onTransferTimeoutHandlerFinish func() + } } // forward returns a new instance of forwarder, and starts forwarding messages @@ -69,41 +192,64 @@ type forwarder struct { // be nil in all cases except for testing. // // Note that callers MUST call Close in all cases, even if ctx was cancelled. -func forward(ctx context.Context, clientConn, serverConn net.Conn) *forwarder { +// +// TODO(jaylim-crl): Convert this to return a Forwarder interface. +func forward( + ctx context.Context, + connector *connector, + metrics *metrics, + clientConn net.Conn, + serverConn net.Conn, +) *forwarder { ctx, cancelFn := context.WithCancel(ctx) + // The forwarder starts with a state where connections migration can occur. f := &forwarder{ ctx: ctx, ctxCancel: cancelFn, - errChan: make(chan error, 1), + errCh: make(chan error, 1), + connector: connector, + metrics: metrics, } // The net.Conn object for the client is switched to a net.Conn that - // unblocks Read every second on idle to check for exit conditions. - // This is mainly used to unblock the request processor whenever the + // unblocks Read every second on idle to check for exit conditions. This is + // mainly used to unblock the client-to-server processor whenever the // forwarder has stopped, or a transfer has been requested. clientConn = pgwire.NewReadTimeoutConn(clientConn, func() error { // Context was cancelled. if f.ctx.Err() != nil { return f.ctx.Err() } - // TODO(jaylim-crl): Check for transfer state here. + + // Client interrupts are disabled. + if f.disableClientInterrupts { + return nil + } + + // We want to unblock idle clients whenever a transfer has been + // requested. This allows the client-to-server processor to be freed up + // to start the transfer. + f.mu.Lock() + defer f.mu.Unlock() + if f.mu.state != stateReady { + return errReadAbortedDueToTransfer + } return nil }) f.setClientConn(clientConn) f.setServerConn(serverConn) - // Start request (client to server) and response (server to client) - // processors. We will copy all pgwire messages/ from client to server - // (and vice-versa) until we encounter an error or a shutdown signal - // (i.e. context cancellation). + // Start client-to-server and server-to-client processors. We will copy all + // pgwire messages from client to server (and vice-versa) until we encounter + // an error, or a shutdown signal (i.e. context cancellation). go func() { defer f.Close() err := wrapClientToServerError(f.handleClientToServer()) select { - case f.errChan <- err: /* error reported */ + case f.errCh <- err: /* error reported */ default: /* the channel already contains an error */ } }() @@ -112,7 +258,7 @@ func forward(ctx context.Context, clientConn, serverConn net.Conn) *forwarder { err := wrapServerToClientError(f.handleServerToClient()) select { - case f.errChan <- err: /* error reported */ + case f.errCh <- err: /* error reported */ default: /* the channel already contains an error */ } }() @@ -130,37 +276,428 @@ func (f *forwarder) Close() { f.serverConn.Close() } +// RequestTransfer requests that the forwarder performs a best-effort connection +// migration whenever it can. It is best-effort because this will be a no-op if +// the forwarder is not in a state that is eligible for a connection migration. +// If a transfer is already in progress, or has been requested, this is a no-op. 
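+//
+// A rough usage sketch (hypothetical caller; the connector and metrics here
+// are placeholders for whatever the proxy handler already holds):
+//
+//	f := forward(ctx, connector, metrics, clientConn, serverConn)
+//	defer f.Close()
+//	f.RequestTransfer() // best-effort; safe to call repeatedly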
+func (f *forwarder) RequestTransfer() { + // We'll get an error if the forwarder is already in one of the transfer + // states. In that case, just ignore it since we want RequestTransfer to + // be idempotent. + if err := f.prepareTransfer(); err != nil { + // Transfer in progress. + return + } + f.metrics.ConnMigrationRequestedCount.Inc(1) +} + // handleClientToServer handles the communication from the client to the server. // This returns a context cancellation error whenever the forwarder's context -// is cancelled, or whenever forwarding fails. When ForwardMsg gets blocked on -// Read, we will unblock that through our custom readTimeoutConn wrapper, which -// gets triggered when context is cancelled. +// is cancelled, or whenever forwarding fails. func (f *forwarder) handleClientToServer() error { for f.ctx.Err() == nil { - if _, err := f.clientInterceptor.ForwardMsg(f.serverConn); err != nil { - return err + // Always peek the message to ensure that we're blocked on reading the + // header, rather than when forwarding. + typ, _, err := f.clientInterceptor.PeekMsg() + if err != nil && !errors.Is(err, errReadAbortedDueToTransfer) { + return errors.Wrap(err, "peeking message in client-to-server") + } + + // Note that if state changes the moment we unlock mu, that's fine. + // The fact that we got here signifies that there was already a message + // in the interceptor's buffer, which is valid for the state that was + // stale. Since this can only happen for the ready->transferRequested + // case, it follows that when a transfer gets requested the moment the + // message was read, we'll finish forwarding that last message before + // starting the transfer in the next iteration. + f.mu.Lock() + localState := f.mu.state + f.mu.Unlock() + + switch localState { + case stateReady: + // If we exit PeekMsg due to a transfer, the state must be in + // stateTransferRequested unless there's a bug. Be defensive here + // and peek again so that we don't end up blocking on the peek + // call within ForwardMsg because client interrupts will be + // disabled. + if errors.Is(err, errReadAbortedDueToTransfer) { + log.Error(f.ctx, "read aborted in client-to-server, but state is ready") + continue + } + + if forwardErr := func() error { + // We may be blocked waiting for more packets when reading the + // message's body. If a transfer was requested, there's no point + // interrupting Reads since we're not at a message boundary, and + // we cannot start a transfer, so don't interrupt at all. + f.disableClientInterrupts = true + defer func() { f.disableClientInterrupts = false }() + + f.clientMessageTypeSent = typ + + f.mu.Lock() + f.mu.isServerMsgReadyReceived = false + f.mu.Unlock() + + // When ForwardMsg gets blocked on Read, we will unblock that + // through our custom readTimeoutConn wrapper. + _, err := f.clientInterceptor.ForwardMsg(f.serverConn) + return err + }(); forwardErr != nil { + return errors.Wrap(forwardErr, "forwarding message in server-to-client") + } + + case stateTransferRequested: + // Can we perform the transfer? + if !f.isSafeTransferPoint() { + f.metrics.ConnMigrationRejectedCount.Inc(1) + + // Abort the transfer safely. + if err := f.finishTransfer(); err != nil { + return errors.Wrap(errTransferProtocol, + "aborting transfer due to unsafe transfer point") + } + continue + } + + f.metrics.ConnMigrationAcceptedCount.Inc(1) + + // Update the state first so that the server-to-client processor + // could start processing. 
If we update the state after sending the + // request, we may miss response messages. + f.mu.Lock() + f.mu.state = stateTransferInProgress + key, closer := f.mu.transferKey, f.mu.transferCloserCh + f.mu.Unlock() + + // Timeout handler begins when we send a transfer state request + // message to the server. + timeout := defaultTransferTimeout + if f.testingKnobs.transferTimeoutDuration != nil { + timeout = f.testingKnobs.transferTimeoutDuration() + } + f.runTransferTimeoutHandler(timeout) + + // Once we send the request, the forwarder should not send any + // further messages to the server. Since requests and responses + // are in a FIFO order, we can guarantee that the server will no + // longer return messages intended for the client once we receive + // responses for the SHOW TRANSFER STATE query. + if err := runShowTransferState(f.serverConn, key); err != nil { + return errors.Wrap(err, "writing transfer state request") + } + + // Wait until transfer is completed. Client-to-server processor is + // blocked to ensure that we don't send more client messagess to + // the server. + select { + case <-f.ctx.Done(): + return f.ctx.Err() + case <-closer: + // Channel is closed whenever transfer completes, so we are done. + } + + case stateTransferInProgress: + // This cannot happen unless there is a bug. While the transfer is + // in progress, the client-to-server processor has to be blocked, + // and the only way to transition into this state is to go through + // the stateTransferRequested state. + // + // Return an error to close the connection, rather than letting it + // continue silently. + return errors.Wrap(errTransferProtocol, + "transferInProgress state in client-to-server processor") } } return f.ctx.Err() } // handleServerToClient handles the communication from the server to the client. -// This returns a context cancellation error whenever the forwarder's context -// is cancelled, or whenever forwarding fails. When ForwardMsg gets blocked on -// Read, we will unblock that by closing serverConn through f.Close(). +// This returns an error whenever the forwarder's context is cancelled, or the +// connection can no longer be used due to the state of the server (e.g. failed +// forwarding, or non-recoverable transfers). func (f *forwarder) handleServerToClient() error { for f.ctx.Err() == nil { - if _, err := f.serverInterceptor.ForwardMsg(f.clientConn); err != nil { - return err + // Always peek the message to ensure that we're blocked on reading the + // header, rather than when forwarding or reading the entire message. + typ, _, err := f.serverInterceptor.PeekMsg() + if err != nil { + return errors.Wrap(err, "peeking message in server-to-client") + } + + // When we unlock mu, localState may be stale when transitioning from + // ready->transferRequested, or transferRequested->transferInProgress. + // This is fine because the moment we got here, we know that there must + // be a message in the interceptor's buffer, and that is valid for the + // previous state, so finish up the current message first. + localState := func() int { + f.mu.Lock() + defer f.mu.Unlock() + + // Have we seen a ReadyForQuery message? + // + // It doesn't matter which state we're in. Even if the transfer + // message has already been sent, the first message that we're going + // to be looking for isn't ReadyForQuery. This is only used to + // determine a safe transfer point. 
+ if typ == pgwirebase.ServerMsgReady { + f.mu.isServerMsgReadyReceived = true + } + + return f.mu.state + }() + + switch localState { + case stateReady, stateTransferRequested: + // When ForwardMsg gets blocked on Read, we will unblock that by + // closing serverConn through f.Close(). + if _, err := f.serverInterceptor.ForwardMsg(f.clientConn); err != nil { + return errors.Wrap(err, "forwarding message in server-to-client") + } + + case stateTransferInProgress: + if err := f.processTransfer(); err != nil { + // Connection is not recoverable; terminate it right away. + if !isConnRecoverableError(err) { + f.metrics.ConnMigrationErrorCount.Inc(1) + return errors.Wrap(err, + "terminating due to non-recoverable connection during transfer") + } + log.Infof(f.ctx, "transfer failed, but connection is recoverable: %s", err) + f.metrics.ConnMigrationFailCount.Inc(1) + } else { + log.Infof(f.ctx, "transfer successful") + f.metrics.ConnMigrationSuccessCount.Inc(1) + } + if err := f.finishTransfer(); err != nil { + return errors.Wrap(errTransferProtocol, "wrapping up transfer process") + } } } return f.ctx.Err() } +// isSafeTransferPoint returns true if we're at a point where we're safe to +// transfer, and false otherwise. This should only be called during the +// transferRequested state. +func (f *forwarder) isSafeTransferPoint() bool { + if f.testingKnobs.isSafeTransferPoint != nil { + return f.testingKnobs.isSafeTransferPoint() + } + // Three conditions when evaluating a safe transfer point: + // 1. The last message sent to the SQL pod was a Sync(S) or + // SimpleQuery(Q), and a ReadyForQuery(Z) has already been + // received at the time of evaluation. + // 2. The last message sent to the SQL pod was a CopyDone(c), and + // a ReadyForQuery(Z) has already been received at the time of + // evaluation. + // 3. The last message sent to the SQL pod was a CopyFail(f), and + // a ReadyForQuery(Z) has already been received at the time of + // evaluation. + // + // NOTE: clientMessageTypeSent does not require a mutex because it is only + // set in the transferInProgress state, and this method should only be + // called in the transferRequested state. + switch f.clientMessageTypeSent { + case clientMsgAny, + pgwirebase.ClientMsgSync, + pgwirebase.ClientMsgSimpleQuery, + pgwirebase.ClientMsgCopyDone, + pgwirebase.ClientMsgCopyFail: + f.mu.Lock() + defer f.mu.Unlock() + + return f.mu.isServerMsgReadyReceived + default: + return false + } +} + +// prepareTransfer sets up the transfer metadata. This moves the forwarder into +// the transferRequested state, and generates a unique transfer key for the +// forwarder. If the forwarder's state is not ready, this will return an error. +func (f *forwarder) prepareTransfer() error { + f.mu.Lock() + defer f.mu.Unlock() + + if f.mu.state != stateReady { + return errors.New("transfer is already in-progress") + } + f.mu.transferKey = uuid.MakeV4().String() + f.mu.transferCloserCh = make(chan struct{}) + f.mu.state = stateTransferRequested + f.mu.transferCtx = nil + f.mu.transferConnRecoverable = false + return nil +} + +// finishTransfer moves the forwarder back to the ready state, and closes the +// transferCloser channel (which unblocks the client-to-server processor). This +// returns an error if it is called during the steady state. +// +// NOTE: This should only be called if the connection is safe to continue +// because this unblocks the client-to-server processor, which may result in +// more packets being sent to the server. 
If the connection is unsafe to +// proceed, we should just call Close(). +func (f *forwarder) finishTransfer() error { + f.mu.Lock() + defer f.mu.Unlock() + + if f.mu.state == stateReady { + return errors.New("no transfer in-progress") + } + f.mu.transferKey = "" + // This nil case should not happen, but we'll check to avoid closing nil + // channels, which causes panics. + if f.mu.transferCloserCh != nil { + close(f.mu.transferCloserCh) + } + f.mu.transferCloserCh = nil + f.mu.state = stateReady + f.mu.transferCtx = nil + f.mu.transferConnRecoverable = false + return nil +} + +// processTransfer attempts to perform the connection migration, and blocks until +// the connection has been migrated, or an error has occurred. If the connection +// has been migrated successfully, retErr == nil. +// +// If retErr != nil, the forwarder has to be closed by the caller to prevent any +// data corruption, with one exception: the caller may choose to abort the +// transfer process and continue with the fowarding if and only if the error +// has been marked with errConnRecoverableSentinel, which can be verified with +// isConnRecoverableError(). +// +// NOTE: f.mu.transferCtx has to be set before calling this. We use transferCtx +// instead of ctx here to ensure that we can recover if we fail to connect, or +// deserialize the session. Binding to ctx means the only way to abort is to +// close the forwarder. +func (f *forwarder) processTransfer() (retErr error) { + f.mu.Lock() + transferKey, transferCtx := f.mu.transferKey, f.mu.transferCtx + f.mu.Unlock() + + if transferCtx == nil { + return errors.Wrap(errTransferProtocol, "transferCtx is nil") + } + + state, revivalToken, err := waitForShowTransferState(transferCtx, + f.serverInterceptor, f.clientConn, transferKey) + if err != nil { + // Some errors may be recoverable, but those are handled in + // awaitTransferStateResponse, and marked accordingly. + return err + } + + f.mu.Lock() + f.mu.transferConnRecoverable = true + f.mu.Unlock() + + // Connect to a new SQL pod. + // + // TODO(jaylim-crl): There is a possibility where the same pod will get + // selected. Some ideas to solve this: pass in the remote address of + // serverConn to avoid choosing that pod, or maybe a filter callback? + // Will handle this later. + newServerConn, err := f.connector.OpenTenantConnWithToken(transferCtx, revivalToken) + if err != nil { + return markAsConnRecoverableError(err) + } + defer func() { + if retErr != nil { + newServerConn.Close() + } + }() + newServerInterceptor := interceptor.NewFrontendInterceptor(newServerConn) + + // Deserialize session state within the new SQL pod. + err = runAndWaitForDeserializeSession( + transferCtx, newServerConn, newServerInterceptor, state) + if err != nil { + return markAsConnRecoverableError(err) + } + + // Transfer was successful - use the new server connections. + f.serverConn.Close() + f.setServerConnAndInterceptor(newServerConn, newServerInterceptor) + return nil +} + +// runTransferTimeoutHandler starts a timeout handler in the background for a +// duration of waitTimeout until the transfer completes; this happens whenever +// the transferCloserCh channel is closed. If the transfer doesn't complete by +// the given duration, the forwarder will be closed if we're in a non-recoverable +// state. +// +// NOTE: This should only be called during a transfer process. We assume that +// transferCloserCh has already been initialized. 
+func (f *forwarder) runTransferTimeoutHandler(waitTimeout time.Duration) { + f.mu.Lock() + closer := f.mu.transferCloserCh + // This lint rule is intended; transferCtx isn't used here. + transferCtx, cancel := context.WithTimeout(f.ctx, waitTimeout) // nolint:context + f.mu.transferCtx = transferCtx + f.mu.Unlock() + + // We use a goroutine instead of the return value in the processors to + // allow us to unblock writeTransferStateRequest if the write to the server + // took a long time. + go func() { + defer cancel() + + if f.testingKnobs.onTransferTimeoutHandlerStart != nil { + f.testingKnobs.onTransferTimeoutHandlerStart() + } + select { + case <-f.ctx.Done(): + // Forwarder's context was cancelled. Do nothing. + case <-closer: + // Transfer has completed. + case <-transferCtx.Done(): + f.mu.Lock() + // If transferCtx is nil, that means the connection is recoverable + // because we called finishTransfer. + recoverable := f.mu.transferConnRecoverable || f.mu.transferCtx == nil + f.mu.Unlock() + + // Connection is recoverable, don't close the connection. Context + // cancellation will be propagated up accordingly. + if recoverable { + f.metrics.ConnMigrationTimeoutRecoverableCount.Inc(1) + break + } + + f.metrics.ConnMigrationTimeoutClosedCount.Inc(1) + + // If we're waiting for a message through the server's interceptor, + // this will unblock that call with a closed pipe. If we're busy + // processing other messages, the cancelled context will eventually + // be read. + // + // We send a message to f.errCh first before closing the forwarder + // to ensure that we don't get a context cancellation in errCh + // when we unblock the server interceptor. + select { + case f.errCh <- errTransferTimeout: /* error reported */ + default: /* the channel already contains an error */ + } + f.Close() + } + if f.testingKnobs.onTransferTimeoutHandlerFinish != nil { + f.testingKnobs.onTransferTimeoutHandlerFinish() + } + }() +} + // setClientConn is a convenient helper to update clientConn, and will also // create a matching interceptor for the given connection. It is the caller's // responsibility to close the old connection before calling this, or there // may be a leak. +// +// It is the responsibility of the caller to know when this is safe to call +// since this updates clientConn and clientInterceptor, and is not thread-safe. func (f *forwarder) setClientConn(clientConn net.Conn) { f.clientConn = clientConn f.clientInterceptor = interceptor.NewBackendInterceptor(f.clientConn) @@ -170,9 +707,34 @@ func (f *forwarder) setClientConn(clientConn net.Conn) { // create a matching interceptor for the given connection. It is the caller's // responsibility to close the old connection before calling this, or there // may be a leak. +// +// It is the responsibility of the caller to know when this is safe to call +// since this updates serverConn and serverInterceptor, and is not thread-safe. func (f *forwarder) setServerConn(serverConn net.Conn) { + f.setServerConnAndInterceptor(serverConn, nil /* serverInterceptor */) +} + +// setServerConnAndInterceptor, is similar to setServerConn, but takes in a +// serverInterceptor as well. That way, an existing interceptor can be used. +// If serverInterceptor is nil, an interceptor will be created for the given +// serverConn. +// +// See setServerConn for more information. 
+func (f *forwarder) setServerConnAndInterceptor( + serverConn net.Conn, serverInterceptor *interceptor.FrontendInterceptor, +) { f.serverConn = serverConn - f.serverInterceptor = interceptor.NewFrontendInterceptor(f.serverConn) + if serverInterceptor == nil { + f.serverInterceptor = interceptor.NewFrontendInterceptor(f.serverConn) + } else { + f.serverInterceptor = serverInterceptor + } + f.clientMessageTypeSent = clientMsgAny + + // This method will only be called during initialization, or whenever the + // transfer is being processed, which in this case, there are no reads on + // this variable, so there won't be a race. + f.mu.isServerMsgReadyReceived = true } // wrapClientToServerError overrides client to server errors for external diff --git a/pkg/ccl/sqlproxyccl/forwarder_test.go b/pkg/ccl/sqlproxyccl/forwarder_test.go index 2fdbd7ff33a4..f6861fd79a9f 100644 --- a/pkg/ccl/sqlproxyccl/forwarder_test.go +++ b/pkg/ccl/sqlproxyccl/forwarder_test.go @@ -11,17 +11,23 @@ package sqlproxyccl import ( "bytes" "context" + "io" "net" + "sync/atomic" "testing" "time" + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/interceptor" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase" "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/errors" - "github.com/jackc/pgproto3/v2" + pgproto3 "github.com/jackc/pgproto3/v2" "github.com/stretchr/testify/require" ) +// TestForward is a simple test for message forwarding without connection +// migration. For in-depth tests, see proxy_handler_test.go. func TestForward(t *testing.T) { defer leaktest.AfterTest(t)() @@ -32,7 +38,7 @@ func TestForward(t *testing.T) { // Close the connection right away. p2 is owned by the forwarder. p1.Close() - f := forward(bgCtx, p1, p2) + f := forward(bgCtx, nil /* connector */, nil /* metrics */, p1, p2) defer f.Close() // We have to wait for the goroutine to run. Once the forwarder stops, @@ -56,7 +62,7 @@ func TestForward(t *testing.T) { // for that. defer clientR.Close() - f := forward(ctx, clientR, serverW) + f := forward(ctx, nil /* connector */, nil /* metrics */, clientR, serverW) defer f.Close() require.Nil(t, f.ctx.Err()) @@ -127,7 +133,7 @@ func TestForward(t *testing.T) { // for that. defer clientR.Close() - f := forward(ctx, clientR, serverW) + f := forward(ctx, nil /* connector */, nil /* metrics */, clientR, serverW) defer f.Close() require.Nil(t, f.ctx.Err()) @@ -180,7 +186,7 @@ func TestForwarder_Close(t *testing.T) { p1, p2 := net.Pipe() defer p1.Close() // p2 is owned by the forwarder. - f := forward(context.Background(), p1, p2) + f := forward(context.Background(), nil /* connector */, nil /* metrics */, p1, p2) defer f.Close() require.Nil(t, f.ctx.Err()) @@ -188,6 +194,506 @@ func TestForwarder_Close(t *testing.T) { require.EqualError(t, f.ctx.Err(), context.Canceled.Error()) } +func TestForwarder_RequestTransfer(t *testing.T) { + defer leaktest.AfterTest(t)() + + metrics := makeProxyMetrics() + f := &forwarder{metrics: &metrics} + require.Equal(t, "", f.mu.transferKey) + require.Nil(t, f.mu.transferCloserCh) + require.Equal(t, stateReady, f.mu.state) + + f.RequestTransfer() + require.Equal(t, int64(1), f.metrics.ConnMigrationRequestedCount.Count()) + + key, closer, state := f.mu.transferKey, f.mu.transferCloserCh, f.mu.state + + // Call again to test idempotency. 
+ f.RequestTransfer() + require.Equal(t, key, f.mu.transferKey) + require.Equal(t, closer, f.mu.transferCloserCh) + require.Equal(t, state, f.mu.state) + require.Equal(t, int64(1), f.metrics.ConnMigrationRequestedCount.Count()) +} + +func TestForwarder_isSafeTransferPoint(t *testing.T) { + defer leaktest.AfterTest(t)() + + for _, tc := range []struct { + name string + sent pgwirebase.ClientMessageType + ready bool + expected bool + }{ + // Case 1. + {"sync_with_ready", pgwirebase.ClientMsgSync, true, true}, + {"sync_without_ready", pgwirebase.ClientMsgSync, false, false}, + // Case 1. + {"query_with_ready", pgwirebase.ClientMsgSimpleQuery, true, true}, + {"query_without_ready", pgwirebase.ClientMsgSimpleQuery, false, false}, + // Case 2. + {"copy_done_with_ready", pgwirebase.ClientMsgCopyDone, true, true}, + {"copy_done_without_ready", pgwirebase.ClientMsgCopyDone, false, false}, + // Case 3. + {"copy_fail_with_ready", pgwirebase.ClientMsgCopyFail, true, true}, + {"copy_fail_without_ready", pgwirebase.ClientMsgCopyFail, false, false}, + // Other. + {"random_sent_with_ready", pgwirebase.ClientMsgExecute, true, false}, + {"random_sent_without_ready", pgwirebase.ClientMsgExecute, false, false}, + {"initial_state", clientMsgAny, true, true}, + } { + t.Run(tc.name, func(t *testing.T) { + f := &forwarder{} + f.clientMessageTypeSent = tc.sent + f.mu.isServerMsgReadyReceived = tc.ready + if tc.expected { + require.True(t, f.isSafeTransferPoint()) + } else { + require.False(t, f.isSafeTransferPoint()) + } + }) + } +} + +func TestForwarder_prepareAndFinishTransfer(t *testing.T) { + defer leaktest.AfterTest(t)() + + t.Run("successful", func(t *testing.T) { + f := &forwarder{} + require.Equal(t, "", f.mu.transferKey) + require.Nil(t, f.mu.transferCloserCh) + require.Equal(t, stateReady, f.mu.state) + require.Nil(t, f.mu.transferCtx) + require.False(t, f.mu.transferConnRecoverable) + + require.NoError(t, f.prepareTransfer()) + + require.NotEqual(t, "", f.mu.transferKey) + require.NotNil(t, f.mu.transferCloserCh) + require.Equal(t, stateTransferRequested, f.mu.state) + require.Nil(t, f.mu.transferCtx) + require.False(t, f.mu.transferConnRecoverable) + + ch := f.mu.transferCloserCh + select { + case <-ch: + t.Fatalf("transferCh is closed, which should not happen") + default: + } + + require.NoError(t, f.finishTransfer()) + + require.Equal(t, "", f.mu.transferKey) + require.Nil(t, f.mu.transferCloserCh) + require.Equal(t, stateReady, f.mu.state) + require.Nil(t, f.mu.transferCtx) + require.False(t, f.mu.transferConnRecoverable) + + select { + case <-ch: + default: + t.Fatalf("transferCh is still open") + } + }) + + t.Run("error", func(t *testing.T) { + f := &forwarder{} + + f.mu.state = stateTransferRequested + require.EqualError(t, f.prepareTransfer(), "transfer is already in-progress") + + f.mu.state = stateReady + require.EqualError(t, f.finishTransfer(), "no transfer in-progress") + }) +} + +func TestForwarder_processTransfer(t *testing.T) { + defer leaktest.AfterTest(t)() + ctx := context.Background() + + t.Run("transfer_response_error", func(t *testing.T) { + f := &forwarder{} + f.mu.transferKey = "foo-bar-baz" + f.mu.transferCtx = ctx + + defer testutils.TestingHook(&waitForShowTransferState, func( + fnCtx context.Context, + serverInterceptor *interceptor.FrontendInterceptor, + clientConn io.Writer, + transferKey string, + ) (string, string, error) { + require.Equal(t, ctx, fnCtx) + require.Nil(t, serverInterceptor) + require.Nil(t, clientConn) + require.Equal(t, "foo-bar-baz", transferKey) + 
return "", "", errors.New("bar") + })() + + err := f.processTransfer() + require.EqualError(t, err, "bar") + require.False(t, isConnRecoverableError(err)) + require.False(t, f.mu.transferConnRecoverable) + }) + + t.Run("connection_error", func(t *testing.T) { + client, server := net.Pipe() + defer client.Close() + defer server.Close() + + f := &forwarder{ + connector: &connector{ + StartupMsg: &pgproto3.StartupMessage{ + Parameters: make(map[string]string), + }, + }, + } + f.setClientConn(client) + f.setServerConn(server) + + f.mu.transferKey = "foo-bar-baz" + f.mu.transferCtx = ctx + + defer testutils.TestingHook(&waitForShowTransferState, func( + fnCtx context.Context, + serverInterceptor *interceptor.FrontendInterceptor, + clientConn io.Writer, + transferKey string, + ) (string, string, error) { + require.Equal(t, ctx, fnCtx) + require.NotNil(t, serverInterceptor) + require.Equal(t, client, clientConn) + require.Equal(t, "foo-bar-baz", transferKey) + return "state-string", "token-string", nil + })() + f.connector.testingKnobs.dialTenantCluster = func(ctx context.Context) (net.Conn, error) { + return nil, errors.New("foo") + } + + err := f.processTransfer() + require.EqualError(t, err, "foo") + require.True(t, isConnRecoverableError(err)) + require.True(t, f.mu.transferConnRecoverable) + }) + + t.Run("deserialization_error", func(t *testing.T) { + client, server := net.Pipe() + defer client.Close() + defer server.Close() + + f := &forwarder{ + connector: &connector{ + StartupMsg: &pgproto3.StartupMessage{ + Parameters: make(map[string]string), + }, + }, + } + f.setClientConn(client) + f.setServerConn(server) + + f.mu.transferKey = "foo-bar-baz" + f.mu.transferCtx = ctx + + defer testutils.TestingHook(&waitForShowTransferState, func( + fnCtx context.Context, + serverInterceptor *interceptor.FrontendInterceptor, + clientConn io.Writer, + transferKey string, + ) (string, string, error) { + require.Equal(t, ctx, fnCtx) + require.NotNil(t, serverInterceptor) + require.Equal(t, client, clientConn) + require.Equal(t, "foo-bar-baz", transferKey) + return "state-string", "token-string", nil + })() + f.connector.testingKnobs.dialTenantCluster = func(ctx context.Context) (net.Conn, error) { + str, ok := f.connector.StartupMsg.Parameters[sessionRevivalTokenStartupParam] + require.True(t, ok) + require.Equal(t, "token-string", str) + return server, nil + } + defer testutils.TestingHook( + &implicitAuthenticate, + func(serverConn net.Conn) error { + return nil + }, + )() + defer testutils.TestingHook(&runAndWaitForDeserializeSession, func( + fnCtx context.Context, + serverConn io.Writer, + serverInterceptor *interceptor.FrontendInterceptor, + state string, + ) error { + require.Equal(t, ctx, fnCtx) + require.Equal(t, server, serverConn) + require.NotNil(t, serverInterceptor) + require.Equal(t, "state-string", state) + return errors.New("bar") + })() + + err := f.processTransfer() + require.EqualError(t, err, "bar") + require.True(t, isConnRecoverableError(err)) + require.True(t, f.mu.transferConnRecoverable) + + // Ensure that conn gets closed when we fail on deserialization so that + // there won't be leaks. 
+ _, err = server.Write([]byte("foo")) + require.Regexp(t, "closed pipe", err) + }) + + t.Run("successful", func(t *testing.T) { + initClient, initServer := net.Pipe() + defer initClient.Close() + defer initServer.Close() + + newServer, _ := net.Pipe() + defer newServer.Close() + + f := &forwarder{ + connector: &connector{ + StartupMsg: &pgproto3.StartupMessage{ + Parameters: make(map[string]string), + }, + }, + } + f.setClientConn(initClient) + f.setServerConn(initServer) + + f.mu.transferKey = "foo-bar-baz" + f.mu.transferCtx = ctx + + defer testutils.TestingHook(&waitForShowTransferState, func( + fnCtx context.Context, + serverInterceptor *interceptor.FrontendInterceptor, + clientConn io.Writer, + transferKey string, + ) (string, string, error) { + require.Equal(t, ctx, fnCtx) + require.NotNil(t, serverInterceptor) + require.Equal(t, initClient, clientConn) + require.Equal(t, "foo-bar-baz", transferKey) + return "state-string", "token-string", nil + })() + f.connector.testingKnobs.dialTenantCluster = func(ctx context.Context) (net.Conn, error) { + str, ok := f.connector.StartupMsg.Parameters[sessionRevivalTokenStartupParam] + require.True(t, ok) + require.Equal(t, "token-string", str) + return newServer, nil + } + defer testutils.TestingHook( + &implicitAuthenticate, + func(serverConn net.Conn) error { + return nil + }, + )() + defer testutils.TestingHook(&runAndWaitForDeserializeSession, func( + fnCtx context.Context, + serverConn io.Writer, + serverInterceptor *interceptor.FrontendInterceptor, + state string, + ) error { + require.Equal(t, ctx, fnCtx) + require.Equal(t, newServer, serverConn) + require.NotNil(t, serverInterceptor) + require.Equal(t, "state-string", state) + return nil + })() + + err := f.processTransfer() + require.NoError(t, err) + require.True(t, f.mu.transferConnRecoverable) + + // Ensure that old serverConn is closed. + _, err = initServer.Write([]byte("foo")) + require.Regexp(t, "closed pipe", err) + require.Equal(t, initClient, f.clientConn) + require.Equal(t, newServer, f.serverConn) + require.NotNil(t, f.serverInterceptor) + }) +} + +func TestForwarder_runTransferTimeoutHandler(t *testing.T) { + defer leaktest.AfterTest(t)() + bgCtx := context.Background() + + t.Run("cancelled_externally", func(t *testing.T) { + var started, finished int32 + ctx, cancel := context.WithCancel(bgCtx) + defer cancel() + errCh := make(chan error, 1) + transferCh := make(chan struct{}) + + f := &forwarder{ + ctx: ctx, + ctxCancel: cancel, + errCh: errCh, + } + f.mu.transferCloserCh = transferCh + f.testingKnobs.onTransferTimeoutHandlerStart = func() { + atomic.StoreInt32(&started, 1) + } + f.testingKnobs.onTransferTimeoutHandlerFinish = func() { + atomic.StoreInt32(&finished, 1) + } + + f.runTransferTimeoutHandler(2 * time.Second) + f.mu.Lock() + require.NotNil(t, f.mu.transferCtx) + f.mu.Unlock() + + // Wait until the handler starts before cancelling. 
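+		// Cancelling f.ctx should make the handler exit through its
+		// <-f.ctx.Done() branch without reporting anything on errCh.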
+ require.Eventually(t, func() bool { + return atomic.LoadInt32(&started) == 1 + }, 2*time.Second, 50*time.Millisecond, "timed out waiting for timeout handler to run") + + cancel() + + require.Eventually(t, func() bool { + return atomic.LoadInt32(&finished) == 1 + }, 2*time.Second, 50*time.Millisecond, "timed out waiting for timeout handler to return") + + select { + case err := <-errCh: + t.Fatalf("require no error, but got %v", err) + default: + } + }) + + t.Run("transfer_completed", func(t *testing.T) { + var started, finished int32 + ctx, cancel := context.WithCancel(bgCtx) + defer cancel() + errCh := make(chan error, 1) + transferCh := make(chan struct{}) + + f := &forwarder{ + ctx: ctx, + ctxCancel: cancel, + errCh: errCh, + } + f.mu.transferCloserCh = transferCh + f.testingKnobs.onTransferTimeoutHandlerStart = func() { + atomic.StoreInt32(&started, 1) + } + f.testingKnobs.onTransferTimeoutHandlerFinish = func() { + atomic.StoreInt32(&finished, 1) + } + + f.runTransferTimeoutHandler(2 * time.Second) + f.mu.Lock() + require.NotNil(t, f.mu.transferCtx) + f.mu.Unlock() + + // Wait until the handler starts before closing transferCh. + require.Eventually(t, func() bool { + return atomic.LoadInt32(&started) == 1 + }, 2*time.Second, 50*time.Millisecond, "timed out waiting for timeout handler to run") + + close(transferCh) + + require.Eventually(t, func() bool { + return atomic.LoadInt32(&finished) == 1 + }, 2*time.Second, 50*time.Millisecond, "timed out waiting for timeout handler to return") + + select { + case err := <-errCh: + t.Fatalf("require no error, but got %v", err) + default: + } + require.Nil(t, f.ctx.Err()) + }) + + t.Run("timeout_with_non_recoverable_conn", func(t *testing.T) { + w, _ := net.Pipe() + defer w.Close() + + var finished int32 + ctx, cancel := context.WithCancel(bgCtx) + defer cancel() + errCh := make(chan error, 1) + transferCh := make(chan struct{}) + + metrics := makeProxyMetrics() + f := &forwarder{ + ctx: ctx, + ctxCancel: cancel, + errCh: errCh, + serverConn: w, + metrics: &metrics, + } + f.mu.transferCloserCh = transferCh + f.testingKnobs.onTransferTimeoutHandlerFinish = func() { + atomic.StoreInt32(&finished, 1) + } + + f.runTransferTimeoutHandler(500 * time.Millisecond) + f.mu.Lock() + require.NotNil(t, f.mu.transferCtx) + f.mu.Unlock() + + // Wait until handler finishes, which will be triggered after 500ms. + require.Eventually(t, func() bool { + return atomic.LoadInt32(&finished) == 1 + }, 2*time.Second, 50*time.Millisecond, "timed out waiting for timeout handler to return") + + select { + case err := <-errCh: + require.EqualError(t, err, errTransferTimeout.Error()) + default: + t.Fatalf("require error, but got none") + } + // Parent context is cancelled as well because we closed the forwarder. 
+ require.NotNil(t, f.ctx.Err()) + require.NotNil(t, f.mu.transferCtx.Err()) + require.Equal(t, int64(1), f.metrics.ConnMigrationTimeoutClosedCount.Count()) + }) + + t.Run("timeout_with_recoverable_conn", func(t *testing.T) { + w, _ := net.Pipe() + defer w.Close() + + var finished int32 + ctx, cancel := context.WithCancel(bgCtx) + defer cancel() + errCh := make(chan error, 1) + transferCh := make(chan struct{}) + + metrics := makeProxyMetrics() + f := &forwarder{ + ctx: ctx, + ctxCancel: cancel, + errCh: errCh, + serverConn: w, + metrics: &metrics, + } + f.mu.transferCloserCh = transferCh + f.mu.transferConnRecoverable = true + f.testingKnobs.onTransferTimeoutHandlerFinish = func() { + atomic.StoreInt32(&finished, 1) + } + + f.runTransferTimeoutHandler(500 * time.Millisecond) + f.mu.Lock() + require.NotNil(t, f.mu.transferCtx) + f.mu.Unlock() + + // Wait until handler finishes, which will be triggered after 500ms. + require.Eventually(t, func() bool { + return atomic.LoadInt32(&finished) == 1 + }, 2*time.Second, 50*time.Millisecond, "timed out waiting for timeout handler to return") + + select { + case err := <-errCh: + t.Fatalf("require no error, but got %v", err) + default: + } + // Parent context should not be cancelled in a recoverable conn case. + require.Nil(t, f.ctx.Err()) + require.NotNil(t, f.mu.transferCtx.Err()) + require.Equal(t, int64(1), f.metrics.ConnMigrationTimeoutRecoverableCount.Count()) + }) +} + func TestForwarder_setClientConn(t *testing.T) { defer leaktest.AfterTest(t)() f := &forwarder{serverConn: nil, serverInterceptor: nil} @@ -226,6 +732,9 @@ func TestForwarder_setServerConn(t *testing.T) { f.setServerConn(r) require.Equal(t, r, f.serverConn) + require.NotNil(t, f.serverInterceptor) + require.Equal(t, clientMsgAny, f.clientMessageTypeSent) + require.True(t, f.mu.isServerMsgReadyReceived) dst := new(bytes.Buffer) errChan := make(chan error, 1) @@ -244,6 +753,43 @@ func TestForwarder_setServerConn(t *testing.T) { require.Equal(t, 6, dst.Len()) } +func TestForwarder_setServerConnAndInterceptor(t *testing.T) { + defer leaktest.AfterTest(t)() + f := &forwarder{serverConn: nil, serverInterceptor: nil} + + _, r1 := net.Pipe() + defer r1.Close() + + w2, r2 := net.Pipe() + defer w2.Close() + defer r2.Close() + + // Use a different interceptor from the one for r1. + fi := interceptor.NewFrontendInterceptor(r2) + + f.setServerConnAndInterceptor(r1, fi) + require.Equal(t, r1, f.serverConn) + require.Equal(t, fi, f.serverInterceptor) + require.Equal(t, clientMsgAny, f.clientMessageTypeSent) + require.True(t, f.mu.isServerMsgReadyReceived) + + dst := new(bytes.Buffer) + errChan := make(chan error, 1) + go func() { + _, err := f.serverInterceptor.ForwardMsg(dst) + errChan <- err + }() + + _, err := w2.Write((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil)) + require.NoError(t, err) + + // Block until message has been forwarded. This checks that we are using + // the right interceptor. + err = <-errChan + require.NoError(t, err) + require.Equal(t, 6, dst.Len()) +} + func TestWrapClientToServerError(t *testing.T) { defer leaktest.AfterTest(t)() diff --git a/pkg/ccl/sqlproxyccl/metrics.go b/pkg/ccl/sqlproxyccl/metrics.go index 7b5afdf2c012..f46f3d2ce842 100644 --- a/pkg/ccl/sqlproxyccl/metrics.go +++ b/pkg/ccl/sqlproxyccl/metrics.go @@ -15,16 +15,24 @@ import ( // metrics contains pointers to the metrics for monitoring proxy operations. 
type metrics struct { - BackendDisconnectCount *metric.Counter - IdleDisconnectCount *metric.Counter - BackendDownCount *metric.Counter - ClientDisconnectCount *metric.Counter - CurConnCount *metric.Gauge - RoutingErrCount *metric.Counter - RefusedConnCount *metric.Counter - SuccessfulConnCount *metric.Counter - AuthFailedCount *metric.Counter - ExpiredClientConnCount *metric.Counter + BackendDisconnectCount *metric.Counter + IdleDisconnectCount *metric.Counter + BackendDownCount *metric.Counter + ClientDisconnectCount *metric.Counter + CurConnCount *metric.Gauge + RoutingErrCount *metric.Counter + RefusedConnCount *metric.Counter + SuccessfulConnCount *metric.Counter + AuthFailedCount *metric.Counter + ExpiredClientConnCount *metric.Counter + ConnMigrationRequestedCount *metric.Counter + ConnMigrationAcceptedCount *metric.Counter + ConnMigrationRejectedCount *metric.Counter + ConnMigrationSuccessCount *metric.Counter + ConnMigrationErrorCount *metric.Counter + ConnMigrationFailCount *metric.Counter + ConnMigrationTimeoutClosedCount *metric.Counter + ConnMigrationTimeoutRecoverableCount *metric.Counter } // MetricStruct implements the metrics.Struct interface. @@ -93,6 +101,65 @@ var ( Measurement: "Expired Client Connections", Unit: metric.Unit_COUNT, } + // Connection migration metrics. + // + // requested = accepted + rejected + // - Connections are rejected if we're not in a safe transfer point. + // + // accepted >= success + error + fail + // - Note that it's >= because some timeout errors are excluded. For + // example, if we timed out in writing the SHOW TRANSFER STATE query to + // the server, we'll see a timeout_closed metric, but not the error metric. + metaConnMigrationRequestedCount = metric.Metadata{ + Name: "proxy.conn_migration.requested", + Help: "Number of requested connection migrations", + Measurement: "Connection Migrations", + Unit: metric.Unit_COUNT, + } + metaConnMigrationAcceptedCount = metric.Metadata{ + Name: "proxy.conn_migration.accepted", + Help: "Number of accepted connection migrations", + Measurement: "Connection Migrations", + Unit: metric.Unit_COUNT, + } + metaConnMigrationRejectedCount = metric.Metadata{ + Name: "proxy.conn_migration.rejected", + Help: "Number of rejected connection migrations", + Measurement: "Connection Migrations", + Unit: metric.Unit_COUNT, + } + metaConnMigrationSuccessCount = metric.Metadata{ + Name: "proxy.conn_migration.success", + Help: "Number of connection migrations which are successful", + Measurement: "Connection Migrations", + Unit: metric.Unit_COUNT, + } + metaConnMigrationErrorCount = metric.Metadata{ + // When connection migrations errored out, connections will be closed. + Name: "proxy.conn_migration.error", + Help: "Number of connection migrations which errored", + Measurement: "Connection Migrations", + Unit: metric.Unit_COUNT, + } + metaConnMigrationFailCount = metric.Metadata{ + // Connections are recoverable, so they won't be closed. 
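+		// This is the counterpart of proxy.conn_migration.error, which counts
+		// migrations that end with the connection being terminated.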
+ Name: "proxy.conn_migration.fail", + Help: "Number of connection migrations which failed", + Measurement: "Connection Migrations", + Unit: metric.Unit_COUNT, + } + metaConnMigrationTimeoutClosedCount = metric.Metadata{ + Name: "proxy.conn_migration.timeout_closed", + Help: "Number of expired connection migrations that closes connections", + Measurement: "Connection Migrations", + Unit: metric.Unit_COUNT, + } + metaConnMigrationTimeoutRecoverableCount = metric.Metadata{ + Name: "proxy.conn_migration.timeout_recoverable", + Help: "Number of expired connection migrations that recovers connections", + Measurement: "Connection Migrations", + Unit: metric.Unit_COUNT, + } ) // makeProxyMetrics instantiates the metrics holder for proxy monitoring. @@ -108,6 +175,15 @@ func makeProxyMetrics() metrics { SuccessfulConnCount: metric.NewCounter(metaSuccessfulConnCount), AuthFailedCount: metric.NewCounter(metaAuthFailedCount), ExpiredClientConnCount: metric.NewCounter(metaExpiredClientConnCount), + // Connection migration metrics. + ConnMigrationRequestedCount: metric.NewCounter(metaConnMigrationRequestedCount), + ConnMigrationAcceptedCount: metric.NewCounter(metaConnMigrationAcceptedCount), + ConnMigrationRejectedCount: metric.NewCounter(metaConnMigrationRejectedCount), + ConnMigrationSuccessCount: metric.NewCounter(metaConnMigrationSuccessCount), + ConnMigrationErrorCount: metric.NewCounter(metaConnMigrationErrorCount), + ConnMigrationFailCount: metric.NewCounter(metaConnMigrationFailCount), + ConnMigrationTimeoutClosedCount: metric.NewCounter(metaConnMigrationTimeoutClosedCount), + ConnMigrationTimeoutRecoverableCount: metric.NewCounter(metaConnMigrationTimeoutRecoverableCount), } } diff --git a/pkg/ccl/sqlproxyccl/proxy_handler.go b/pkg/ccl/sqlproxyccl/proxy_handler.go index 2610553079e9..a6d6b74cedf5 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler.go @@ -96,6 +96,11 @@ type ProxyOptions struct { // ThrottleBaseDelay is the initial exponential backoff triggered in // response to the first connection failure. ThrottleBaseDelay time.Duration + + // Used for testing. + testingKnobs struct { + afterForward func(*forwarder) error + } } // proxyHandler is the default implementation of a proxy handler. @@ -342,8 +347,16 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn }() // Pass ownership of crdbConn to the forwarder. - f = forward(ctx, conn, crdbConn) + f = forward(ctx, connector, handler.metrics, conn, crdbConn) defer f.Close() + if handler.testingKnobs.afterForward != nil { + if err := handler.testingKnobs.afterForward(f); err != nil { + select { + case errConnection <- err: /* error reported */ + default: /* the channel already contains an error */ + } + } + } // Block until an error is received, or when the stopper starts quiescing, // whichever that happens first. @@ -357,7 +370,7 @@ func (handler *proxyHandler) handle(ctx context.Context, incomingConn *proxyConn // TODO(jaylim-crl): It would be nice to have more consistency in how we // manage background goroutines, communicate errors, etc. select { - case err := <-f.errChan: // From forwarder. + case err := <-f.errCh: // From forwarder. handler.metrics.updateForError(err) return err case err := <-errConnection: // From denyListWatcher or idleMonitor. 
diff --git a/pkg/ccl/sqlproxyccl/proxy_handler_test.go b/pkg/ccl/sqlproxyccl/proxy_handler_test.go index 3de0c858c81c..7a47f9ac1510 100644 --- a/pkg/ccl/sqlproxyccl/proxy_handler_test.go +++ b/pkg/ccl/sqlproxyccl/proxy_handler_test.go @@ -11,6 +11,7 @@ package sqlproxyccl import ( "context" "crypto/tls" + gosql "database/sql" "fmt" "io/ioutil" "net" @@ -20,6 +21,7 @@ import ( "testing" "time" + "github.com/cockroachdb/cockroach-go/v2/crdb" "github.com/cockroachdb/cockroach/pkg/base" "github.com/cockroachdb/cockroach/pkg/ccl/kvccl/kvtenantccl" "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/denylist" @@ -30,6 +32,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql" "github.com/cockroachdb/cockroach/pkg/sql/pgwire" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" + "github.com/cockroachdb/cockroach/pkg/sql/tests" "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/testutils/skip" @@ -752,6 +755,368 @@ func TestDirectoryConnect(t *testing.T) { }) } +func TestConnectionMigration(t *testing.T) { + defer leaktest.AfterTest(t)() + ctx := context.Background() + defer log.Scope(t).Close(t) + + params, _ := tests.CreateTestServerParams() + s, _, _ := serverutils.StartServer(t, params) + defer s.Stopper().Stop(ctx) + tenantID := serverutils.TestTenantID() + + // Start first SQL pod. + tenant1, mainDB1 := serverutils.StartTenant(t, s, tests.CreateTestTenantParams(tenantID)) + tenant1.PGServer().(*pgwire.Server).TestingSetTrustClientProvidedRemoteAddr(true) + defer tenant1.Stopper().Stop(ctx) + defer mainDB1.Close() + + // Start second SQL pod. + params2 := tests.CreateTestTenantParams(tenantID) + params2.Existing = true + tenant2, mainDB2 := serverutils.StartTenant(t, s, params2) + tenant2.PGServer().(*pgwire.Server).TestingSetTrustClientProvidedRemoteAddr(true) + defer tenant2.Stopper().Stop(ctx) + defer mainDB2.Close() + + _, err := mainDB1.Exec("CREATE USER testuser WITH PASSWORD 'hunter2'") + require.NoError(t, err) + _, err = mainDB1.Exec("GRANT admin TO testuser") + require.NoError(t, err) + _, err = mainDB1.Exec("SET CLUSTER SETTING server.user_login.session_revival_token.enabled = true") + require.NoError(t, err) + + // Create a proxy server without using a directory. The directory is very + // difficult to work with, and there isn't a way to easily stub out fake + // loads. For this test, we will stub out lookupAddr in the connector. We + // will alternate between tenant1 and tenant2, starting with tenant1. + forwarderCh := make(chan *forwarder) + opts := &ProxyOptions{SkipVerify: true, RoutingRule: tenant1.SQLAddr()} + opts.testingKnobs.afterForward = func(f *forwarder) error { + select { + case forwarderCh <- f: + case <-time.After(10 * time.Second): + return errors.New("no receivers for forwarder") + } + return nil + } + _, addr := newSecureProxyServer(ctx, t, s.Stopper(), opts) + + // The tenant ID does not matter here since we stubbed RoutingRule. + connectionString := fmt.Sprintf("postgres://testuser:hunter2@%s/?sslmode=require&options=--cluster=tenant-cluster-28", addr) + + type queryer interface { + QueryRowContext(context.Context, string, ...interface{}) *gosql.Row + } + // queryAddr queries the SQL node that `db` is connected to for its address. 
+ queryAddr := func(t *testing.T, ctx context.Context, db queryer) string {
+ t.Helper()
+ var host, port string
+ require.NoError(t, db.QueryRowContext(ctx, `
+ SELECT
+ a.value AS "host", b.value AS "port"
+ FROM crdb_internal.node_runtime_info a, crdb_internal.node_runtime_info b
+ WHERE a.component = 'DB' AND a.field = 'Host'
+ AND b.component = 'DB' AND b.field = 'Port'
+ `).Scan(&host, &port))
+ return fmt.Sprintf("%s:%s", host, port)
+ }
+
+ // Test that connection transfers are successful. Note that if one sub-test
+ // fails, the remaining will fail as well since they all use the same
+ // forwarder instance.
+ t.Run("successful", func(t *testing.T) {
+ tCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ db, err := gosql.Open("postgres", connectionString)
+ require.NoError(t, err)
+ db.SetMaxOpenConns(1)
+ defer db.Close()
+
+ // Spin up a goroutine to trigger the initial connection.
+ go func() {
+ _ = db.PingContext(tCtx)
+ }()
+
+ var f *forwarder
+ select {
+ case f = <-forwarderCh:
+ case <-time.After(10 * time.Second):
+ t.Fatal("no connection")
+ }
+
+ // Set up forwarder hooks.
+ prevTenant1 := true
+ var lookupAddrDelayDuration time.Duration
+ f.connector.testingKnobs.lookupAddr = func(ctx context.Context) (string, error) {
+ if lookupAddrDelayDuration != 0 {
+ select {
+ case <-ctx.Done():
+ return "", errors.Wrap(ctx.Err(), "injected delays")
+ case <-time.After(lookupAddrDelayDuration):
+ }
+ }
+ if prevTenant1 {
+ prevTenant1 = false
+ return tenant2.SQLAddr(), nil
+ }
+ prevTenant1 = true
+ return tenant1.SQLAddr(), nil
+ }
+
+ t.Run("normal_transfer", func(t *testing.T) {
+ require.Equal(t, tenant1.SQLAddr(), queryAddr(t, tCtx, db))
+
+ _, err = db.Exec("SET application_name = 'foo'")
+ require.NoError(t, err)
+
+ // Show that we get alternating SQL pods when we transfer.
+ f.RequestTransfer()
+ require.Eventually(t, func() bool {
+ return f.metrics.ConnMigrationSuccessCount.Count() == 1
+ }, 10*time.Second, 25*time.Millisecond)
+ require.Equal(t, tenant2.SQLAddr(), queryAddr(t, tCtx, db))
+
+ var name string
+ require.NoError(t, db.QueryRow("SHOW application_name").Scan(&name))
+ require.Equal(t, "foo", name)
+
+ _, err = db.Exec("SET application_name = 'bar'")
+ require.NoError(t, err)
+
+ f.RequestTransfer()
+ require.Eventually(t, func() bool {
+ return f.metrics.ConnMigrationSuccessCount.Count() == 2
+ }, 10*time.Second, 25*time.Millisecond)
+ require.Equal(t, tenant1.SQLAddr(), queryAddr(t, tCtx, db))
+
+ require.NoError(t, db.QueryRow("SHOW application_name").Scan(&name))
+ require.Equal(t, "bar", name)
+
+ // Now trigger transfers concurrently with the queries below.
+ closerCh := make(chan struct{})
+ go func() {
+ for i := 0; i < 10 && tCtx.Err() == nil; i++ {
+ f.RequestTransfer()
+ time.Sleep(500 * time.Millisecond)
+ }
+ closerCh <- struct{}{}
+ }()
+
+ // This loop runs for about 5 seconds (100 iterations at 50ms each).
+ var tenant1Addr, tenant2Addr int
+ for i := 0; i < 100; i++ {
+ addr := queryAddr(t, tCtx, db)
+ if addr == tenant1.SQLAddr() {
+ tenant1Addr++
+ } else {
+ require.Equal(t, tenant2.SQLAddr(), addr)
+ tenant2Addr++
+ }
+ time.Sleep(50 * time.Millisecond)
+ }
+
+ // In 5s, we should have requested about 10 transfers. Just ensure that at
+ // least half of them succeeded in total.
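+ // Note that ConnMigrationSuccessCount is cumulative, so the two successful
+ // transfers from earlier in this subtest also count towards the threshold.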
+ require.Eventually(t, func() bool {
+ return f.metrics.ConnMigrationSuccessCount.Count() >= 5
+ }, 10*time.Second, 25*time.Millisecond)
+ require.True(t, tenant1Addr > 2)
+ require.True(t, tenant2Addr > 2)
+ require.Equal(t, int64(0), f.metrics.ConnMigrationFailCount.Count())
+ require.Equal(t, int64(0), f.metrics.ConnMigrationErrorCount.Count())
+
+ // Ensure that the goroutine terminates so other subtests are not
+ // affected.
+ <-closerCh
+
+ // There's a chance that we still have an in-progress transfer, so
+ // attempt to wait.
+ require.Eventually(t, func() bool {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ return f.mu.state == stateReady
+ }, 5*time.Second, 25*time.Millisecond)
+ })
+
+ // Transfers should fail if there is an open transaction. These failed
+ // transfers should not close the connection.
+ t.Run("failed_transfers_with_tx", func(t *testing.T) {
+ initSuccessCount := f.metrics.ConnMigrationSuccessCount.Count()
+ initAddr := queryAddr(t, tCtx, db)
+
+ err = crdb.ExecuteTx(tCtx, db, nil /* txopts */, func(tx *gosql.Tx) error {
+ for i := 0; i < 10; i++ {
+ f.RequestTransfer()
+ addr := queryAddr(t, tCtx, tx)
+ if initAddr != addr {
+ return errors.Newf(
+ "address does not match, expected %s, found %s",
+ initAddr,
+ addr,
+ )
+ }
+ time.Sleep(50 * time.Millisecond)
+ }
+ return nil
+ })
+ require.NoError(t, err)
+
+ // Make sure there are no pending transfers.
+ func() {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ require.Equal(t, stateReady, f.mu.state)
+ }()
+
+ // Just check that at least half of the requested transfers failed, since
+ // we cannot guarantee that every transfer runs within the 50ms window.
+ require.True(t, f.metrics.ConnMigrationFailCount.Count() >= 5)
+ require.Equal(t, int64(0), f.metrics.ConnMigrationErrorCount.Count())
+ require.Equal(t, initSuccessCount, f.metrics.ConnMigrationSuccessCount.Count())
+ prevFailCount := f.metrics.ConnMigrationFailCount.Count()
+
+ // Once the transaction has been committed, transfers should work.
+ f.RequestTransfer()
+ require.Eventually(t, func() bool {
+ return f.metrics.ConnMigrationSuccessCount.Count() == initSuccessCount+1
+ }, 10*time.Second, 25*time.Millisecond)
+ require.NotEqual(t, initAddr, queryAddr(t, tCtx, db))
+ require.Equal(t, prevFailCount, f.metrics.ConnMigrationFailCount.Count())
+ require.Equal(t, int64(0), f.metrics.ConnMigrationErrorCount.Count())
+
+ // We have already asserted the metrics above, so the transfer must have
+ // completed.
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ require.Equal(t, stateReady, f.mu.state)
+ })
+
+ // A transfer timeout caused by dial issues should not close the session.
+ // We will test this by introducing delays when connecting to the SQL
+ // pod.
+ t.Run("failed_transfers_with_dial_issues", func(t *testing.T) {
+ initSuccessCount := f.metrics.ConnMigrationSuccessCount.Count()
+ initFailCount := f.metrics.ConnMigrationFailCount.Count()
+ initAddr := queryAddr(t, tCtx, db)
+
+ // Set the delay longer than the timeout.
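+ // lookupAddr will block for 10 seconds while the transfer timeout fires
+ // after 3 seconds, so the transfer is abandoned before the forwarder ever
+ // connects to the new SQL pod, leaving the existing session intact (hence
+ // timeout_recoverable rather than timeout_closed below).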
+ lookupAddrDelayDuration = 10 * time.Second
+ f.testingKnobs.transferTimeoutDuration = func() time.Duration {
+ return 3 * time.Second
+ }
+
+ f.RequestTransfer()
+ require.Eventually(t, func() bool {
+ return f.metrics.ConnMigrationFailCount.Count() == initFailCount+1
+ }, 10*time.Second, 25*time.Millisecond)
+ require.Equal(t, initAddr, queryAddr(t, tCtx, db))
+ require.Equal(t, initSuccessCount, f.metrics.ConnMigrationSuccessCount.Count())
+ require.Equal(t, int64(0), f.metrics.ConnMigrationErrorCount.Count())
+ require.Equal(t, int64(1), f.metrics.ConnMigrationTimeoutRecoverableCount.Count())
+ })
+ })
+
+ // Test transfer timeouts caused by waiting for a transfer state response.
+ // In reality, this can only be caused by pipelined queries. Consider the
+ // following:
+ // 1. short-running simple query
+ // 2. long-running simple query
+ // 3. SHOW TRANSFER STATE
+ // When (1) returns a response, the forwarder will see that we're at a
+ // safe transfer point, and initiate (3). But (2) may block until we hit
+ // a timeout.
+ //
+ // There's no easy way to simulate pipelined queries. pgtest (which allows
+ // us to send individual pgwire messages) does not support authentication,
+ // which is what the proxy needs, so we will stub isSafeTransferPoint
+ // instead.
+ t.Run("transfer_timeout_in_response", func(t *testing.T) {
+ tCtx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
+ db, err := gosql.Open("postgres", connectionString)
+ require.NoError(t, err)
+ db.SetMaxOpenConns(1)
+ defer db.Close()
+
+ // Use a single connection so that we don't reopen when the connection
+ // is closed.
+ conn, err := db.Conn(tCtx)
+ require.NoError(t, err)
+
+ // Spin up a goroutine to trigger the initial connection.
+ go func() {
+ _ = conn.PingContext(tCtx)
+ }()
+
+ var f *forwarder
+ select {
+ case f = <-forwarderCh:
+ case <-time.After(10 * time.Second):
+ t.Fatal("no connection")
+ }
+ require.Equal(t, int64(0), f.metrics.ConnMigrationTimeoutClosedCount.Count())
+
+ // Set up forwarder hooks.
+ prevTenant1 := true
+ f.connector.testingKnobs.lookupAddr = func(ctx context.Context) (string, error) {
+ if prevTenant1 {
+ prevTenant1 = false
+ return tenant2.SQLAddr(), nil
+ }
+ prevTenant1 = true
+ return tenant1.SQLAddr(), nil
+ }
+ f.testingKnobs.isSafeTransferPoint = func() bool {
+ return true
+ }
+ f.testingKnobs.transferTimeoutDuration = func() time.Duration {
+ // Transfer timeout is 3s, and we'll run pg_sleep for 10s.
+ return 3 * time.Second
+ }
+
+ goCh := make(chan struct{}, 1)
+ errCh := make(chan error, 1)
+ go func() {
+ goCh <- struct{}{}
+ _, execErr := conn.ExecContext(tCtx, "SELECT pg_sleep(10)")
+ errCh <- execErr
+ }()
+
+ // Block until the goroutine has started. We want to make sure we run the
+ // transfer request *after* sending the query. This doesn't guarantee that
+ // the query is in flight before the transfer request, but it is the best
+ // that we can do, hence the extra sleep below.
+ //
+ // Alternatively, we could open another connection, and query the server
+ // to make sure pg_sleep is running, but that seems unnecessary for just
+ // one test.
+ <-goCh
+ time.Sleep(250 * time.Millisecond)
+ f.RequestTransfer()
+
+ // The connection should be closed because this is a non-recoverable error,
+ // i.e. a timeout after sending the request, but before fully receiving
+ // its response.
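+ // database/sql surfaces the severed connection as driver.ErrBadConn
+ // ("driver: bad connection"), which is what the checks below match on.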
+ require.Eventually(t, func() bool {
+ err := conn.PingContext(tCtx)
+ return err != nil && strings.Contains(err.Error(), "bad connection")
+ }, 10*time.Second, 25*time.Millisecond)
+
+ select {
+ case <-time.After(10 * time.Second):
+ t.Fatal("expected the pg_sleep query to terminate")
+ case err = <-errCh:
+ require.Error(t, err)
+ require.Regexp(t, "bad connection", err.Error())
+ }
+ require.Eventually(t, func() bool {
+ return f.metrics.ConnMigrationTimeoutClosedCount.Count() == 1
+ }, 20*time.Second, 25*time.Millisecond)
+ })
+}
+
 func TestClusterNameAndTenantFromParams(t *testing.T) {
 defer leaktest.AfterTest(t)()
 defer log.Scope(t).Close(t)