diff --git a/pkg/ccl/sqlproxyccl/BUILD.bazel b/pkg/ccl/sqlproxyccl/BUILD.bazel index 8eec3dd6ba84..935dd831bdb1 100644 --- a/pkg/ccl/sqlproxyccl/BUILD.bazel +++ b/pkg/ccl/sqlproxyccl/BUILD.bazel @@ -6,6 +6,7 @@ go_library( srcs = [ "authentication.go", "backend_dialer.go", + "conn_migration.go", "connector.go", "error.go", "forwarder.go", @@ -28,7 +29,9 @@ go_library( "//pkg/security/certmgr", "//pkg/sql/pgwire", "//pkg/sql/pgwire/pgcode", + "//pkg/sql/pgwire/pgwirebase", "//pkg/util/contextutil", + "//pkg/util/encoding", "//pkg/util/grpcutil", "//pkg/util/httputil", "//pkg/util/log", @@ -52,6 +55,7 @@ go_test( size = "small", srcs = [ "authentication_test.go", + "conn_migration_test.go", "connector_test.go", "forwarder_test.go", "frontend_admitter_test.go", @@ -65,6 +69,7 @@ go_test( "//pkg/base", "//pkg/ccl/kvccl/kvtenantccl", "//pkg/ccl/sqlproxyccl/denylist", + "//pkg/ccl/sqlproxyccl/interceptor", "//pkg/ccl/sqlproxyccl/tenantdirsvr", "//pkg/ccl/sqlproxyccl/throttler", "//pkg/ccl/utilccl", @@ -75,6 +80,7 @@ go_test( "//pkg/sql", "//pkg/sql/pgwire", "//pkg/sql/pgwire/pgerror", + "//pkg/sql/pgwire/pgwirebase", "//pkg/testutils", "//pkg/testutils/serverutils", "//pkg/testutils/skip", diff --git a/pkg/ccl/sqlproxyccl/conn_migration.go b/pkg/ccl/sqlproxyccl/conn_migration.go new file mode 100644 index 000000000000..986c0036e253 --- /dev/null +++ b/pkg/ccl/sqlproxyccl/conn_migration.go @@ -0,0 +1,425 @@ +// Copyright 2022 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package sqlproxyccl + +import ( + "context" + "encoding/json" + "fmt" + "io" + "strings" + + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/interceptor" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase" + "github.com/cockroachdb/cockroach/pkg/util/encoding" + "github.com/cockroachdb/errors" + pgproto3 "github.com/jackc/pgproto3/v2" +) + +// runShowTransferState sends a SHOW TRANSFER STATE query with the input +// transferKey to the given writer. The transferKey will be used to uniquely +// identify the request when parsing the response messages in +// waitForShowTransferState. +// +// Unlike runAndWaitForDeserializeSession, we split the SHOW TRANSFER STATE +// operation into `run` and `wait` since they both will be invoked in different +// goroutines. If we combined them, we'll have to wait for at least one of the +// goroutines to pause, which can introduce a latency of about 1-2s per transfer +// while waiting for Read in readTimeoutConn to be unblocked. +func runShowTransferState(w io.Writer, transferKey string) error { + return writeQuery(w, "SHOW TRANSFER STATE WITH '%s'", transferKey) +} + +// waitForShowTransferState retrieves the transfer state from the SQL pod +// through SHOW TRANSFER STATE WITH 'key'. It is assumed that the last message +// from the server was ReadyForQuery, so the server is ready to accept a query. +// +// Since ReadyForQuery may be for a previous pipelined query, this handles the +// forwarding of messages back to the client in case we don't see our state yet. +// +// This also does not support transferring messages with a large state (> 4K +// bytes). This may occur if a user sets their cluster settings to large values. +// We can potentially add a TenantReadOnly cluster setting that restricts the +// maximum length of session variables for a better migration experience. 
For +// now, we'll just close the connection. +// +// WARNING: When using this, we assume that no other goroutines are using both +// serverConn and clientConn, as well as their respective interceptors. In the +// context of a transfer, the client-to-server processor must be blocked. +var waitForShowTransferState = func( + ctx context.Context, + serverInterceptor *interceptor.FrontendInterceptor, + clientConn io.Writer, + transferKey string, +) (state string, revivalToken string, retErr error) { + // Wait for a response that looks like the following: + // + // error | session_state_base64 | session_revival_token_base64 | transfer_key + // --------+----------------------+------------------------------+--------------- + // NULL | .................... | ............................ | + // (1 row) + // + // Postgres messages always come in the following order for the + // SHOW TRANSFER STATE WITH '' query: + // 1. RowDescription + // 2. DataRow + // 3. CommandComplete + // 4. ReadyForQuery + + // 1. Read RowDescription. Loop here since there could be pipelined queries + // that were sent before. + for { + if ctx.Err() != nil { + return "", "", ctx.Err() + } + + // We skip reads here so we could call ForwardMsg for messages that are + // not our concern. + _, err := expectNextServerMessage( + ctx, serverInterceptor, pgwirebase.ServerMsgRowDescription, true /* skipRead */) + + // We don't know if the ErrorResponse is for the client or proxy, so we + // will just close the connection. + if isErrorResponseError(err) { + return "", "", errors.Wrap(err, "ambiguous ErrorResponse") + } + + // Messages are intended for the client in two cases: + // 1. We have not seen a RowDescription message yet + // 2. Message was too lage. Connection migration doesn't care about + // large messages since we expected our header for SHOW TRANSFER + // STATE to fit 4K bytes. 
+ if isTypeMismatchError(err) || isLargeMessageError(err) { + if _, err := serverInterceptor.ForwardMsg(clientConn); err != nil { + return "", "", errors.Wrap(err, "forwarding message") + } + continue + } + + if err != nil { + return "", "", errors.Wrap(err, "waiting for RowDescription") + } + + msg, err := serverInterceptor.ReadMsg() + if err != nil { + return "", "", errors.Wrap(err, "reading RowDescription") + } + + // If pgMsg is nil, isValidStartTransferStateResponse will handle that. + pgMsg, _ := msg.(*pgproto3.RowDescription) + + // We found our intended header, so start expecting a DataRow. + if isValidStartTransferStateResponse(pgMsg) { + break + } + + // Column names do not match, so forward the message back to the client, + // and continue waiting. + if _, err := clientConn.Write(msg.Encode(nil)); err != nil { + return "", "", errors.Wrap(err, "writing RowDescription") + } + } + + // 2. Read DataRow. + var transferErr string + { + msg, err := expectNextServerMessage( + ctx, serverInterceptor, pgwirebase.ServerMsgDataRow, false /* skipRead */) + if err != nil { + return "", "", errors.Wrap(err, "waiting for DataRow") + } + + // If pgMsg is nil, parseTransferStateResponse will handle that. + pgMsg, _ := msg.(*pgproto3.DataRow) + transferErr, state, revivalToken, err = parseTransferStateResponse(pgMsg, transferKey) + if err != nil { + return "", "", errors.Wrapf(err, "invalid DataRow: %v", jsonOrRaw(msg)) + } + } + + // 3. Read CommandComplete. + { + msg, err := expectNextServerMessage( + ctx, serverInterceptor, pgwirebase.ServerMsgCommandComplete, false /* skipRead */) + if err != nil { + return "", "", errors.Wrap(err, "waiting for CommandComplete") + } + + // If pgMsg is nil, isValidEndTransferStateResponse will handle that. + pgMsg, _ := msg.(*pgproto3.CommandComplete) + if !isValidEndTransferStateResponse(pgMsg) { + return "", "", errors.Newf("invalid CommandComplete: %v", jsonOrRaw(msg)) + } + } + + // 4. Read ReadyForQuery. 
+ if _, err := expectNextServerMessage( + ctx, serverInterceptor, pgwirebase.ServerMsgReady, false /* skipRead */); err != nil { + return "", "", errors.Wrap(err, "waiting for ReadyForQuery") + } + + // If we managed to consume until ReadyForQuery without errors, but the + // transfer state response returns an error, we could still continue with + // the connection, but the transfer process will need to be aborted. + // + // This case may happen pretty frequently (e.g. open transactions, temporary + // tables, etc.). + if transferErr != "" { + return "", "", markAsConnRecoverableError(errors.Newf("%s", transferErr)) + } + + return state, revivalToken, nil +} + +// isValidStartTransferStateResponse returns true if m represents a valid +// column header for the SHOW TRANSFER STATE statement, or false otherwise. +func isValidStartTransferStateResponse(m *pgproto3.RowDescription) bool { + if m == nil { + return false + } + // Do we have the right number of columns? + if len(m.Fields) != 4 { + return false + } + // Do the names of the columns match? + var transferStateCols = []string{ + "error", + "session_state_base64", + "session_revival_token_base64", + "transfer_key", + } + for i, col := range transferStateCols { + // Prevent an allocation when converting byte slice to string. + if encoding.UnsafeString(m.Fields[i].Name) != col { + return false + } + } + return true +} + +// isValidEndTransferStateResponse returns true if this is a valid +// CommandComplete message that denotes the end of a transfer state response +// message, or false otherwise. +func isValidEndTransferStateResponse(m *pgproto3.CommandComplete) bool { + if m == nil { + return false + } + // We only expect 1 response row. + return encoding.UnsafeString(m.CommandTag) == "SHOW TRANSFER STATE 1" +} + +// parseTransferStateResponse parses the state in the DataRow message, and +// extracts the fields for the SHOW TRANSFER STATE query. 
If err != nil, then +// all other returned fields will be empty strings. +// +// If the input transferKey does not match the result for the transfer_key +// column within the DataRow message, this will return an error. +func parseTransferStateResponse( + m *pgproto3.DataRow, transferKey string, +) (transferErr string, state string, revivalToken string, err error) { + if m == nil { + return "", "", "", errors.New("DataRow message is nil") + } + + // Do we have the right number of columns? This has to be 4 since we have + // validated RowDescription earlier. + if len(m.Values) != 4 { + return "", "", "", errors.Newf( + "unexpected %d columns in DataRow", len(m.Values)) + } + + // Validate transfer key. It is possible that the end-user uses the SHOW + // TRANSFER STATE WITH 'transfer_key' statement, but that isn't designed for + // external usage, so it is fine to just terminate here if the transfer key + // does not match. + keyVal := encoding.UnsafeString(m.Values[3]) + if keyVal != transferKey { + return "", "", "", errors.Newf( + "expected '%s' as transfer key, found '%s'", transferKey, keyVal) + } + + // NOTE: We have to cast to string and copy here since the slice referenced + // in m will no longer be valid once we read the next pgwire message. + return string(m.Values[0]), string(m.Values[1]), string(m.Values[2]), nil +} + +// runAndWaitForDeserializeSession deserializes state into the SQL pod through +// crdb_internal.deserialize_session. It is assumed that the last message from +// the server was ReadyForQuery, so the server is ready to accept a query. +// +// This is meant to be used with a new connection, and nothing needs to be +// forwarded back to the client. +// +// WARNING: When using this, we assume that no other goroutines are using both +// serverConn and clientConn, and their respective interceptors. 
+var runAndWaitForDeserializeSession = func( + ctx context.Context, + serverConn io.Writer, + serverInterceptor *interceptor.FrontendInterceptor, + state string, +) error { + // Send deserialization query. + if err := writeQuery(serverConn, + "SELECT crdb_internal.deserialize_session(decode('%s', 'base64'))", state); err != nil { + return err + } + + // Wait for a response that looks like the following: + // + // crdb_internal.deserialize_session + // ------------------------------------- + // true + // (1 row) + // + // Postgres messages always come in the following order for the + // deserialize_session query: + // 1. RowDescription + // 2. DataRow + // 3. CommandComplete + // 4. ReadyForQuery + const skipRead = false + + // 1. Read RowDescription. + { + msg, err := expectNextServerMessage( + ctx, serverInterceptor, pgwirebase.ServerMsgRowDescription, skipRead) + if err != nil { + return errors.Wrap(err, "waiting for RowDescription") + } + pgMsg, ok := msg.(*pgproto3.RowDescription) + if !ok || len(pgMsg.Fields) != 1 || + encoding.UnsafeString(pgMsg.Fields[0].Name) != "crdb_internal.deserialize_session" { + return errors.Newf("invalid RowDescription: %v", jsonOrRaw(msg)) + } + } + + // 2. Read DataRow. + { + msg, err := expectNextServerMessage( + ctx, serverInterceptor, pgwirebase.ServerMsgDataRow, skipRead) + if err != nil { + return errors.Wrap(err, "waiting for DataRow") + } + // Expect just 1 column with "true" as value. + pgMsg, ok := msg.(*pgproto3.DataRow) + if !ok || len(pgMsg.Values) != 1 || + encoding.UnsafeString(pgMsg.Values[0]) != "t" { + return errors.Newf("invalid DataRow: %v", jsonOrRaw(msg)) + } + } + + // 3. Read CommandComplete. 
+ { + msg, err := expectNextServerMessage( + ctx, serverInterceptor, pgwirebase.ServerMsgCommandComplete, skipRead) + if err != nil { + return errors.Wrap(err, "waiting for CommandComplete") + } + pgMsg, ok := msg.(*pgproto3.CommandComplete) + if !ok || encoding.UnsafeString(pgMsg.CommandTag) != "SELECT 1" { + return errors.Newf("invalid CommandComplete: %v", jsonOrRaw(msg)) + } + } + + // 4. Read ReadyForQuery. + if _, err := expectNextServerMessage( + ctx, serverInterceptor, pgwirebase.ServerMsgReady, skipRead); err != nil { + return errors.Wrap(err, "waiting for ReadyForQuery") + } + + return nil +} + +// writeQuery writes a SimpleQuery to the given writer w. +func writeQuery(w io.Writer, format string, a ...interface{}) error { + query := &pgproto3.Query{String: fmt.Sprintf(format, a...)} + _, err := w.Write(query.Encode(nil)) + return err +} + +// expectNextServerMessage expects that the next message in the server's +// interceptor will match the input message type. This will block until one +// message can be peeked. On return, this reads the next message into memory, +// and returns that. To avoid this read behavior, set skipRead to true, so the +// caller can decide what to do with the next message (i.e. if skipRead=true, +// retMsg=nil). +// +// retMsg != nil if there is a type mismatch, or the message cannot fit within +// 4K bytes. Use isTypeMismatchError or isLargeMessageError to detect such +// errors. +func expectNextServerMessage( + ctx context.Context, + interceptor *interceptor.FrontendInterceptor, + msgType pgwirebase.ServerMessageType, + skipRead bool, +) (retMsg pgproto3.BackendMessage, retErr error) { + // Limit messages for connection transfers to 4K bytes. If we decide that + // we need more, we could lift this restriction. 
+ const maxBodySize = 1 << 12 // 4K + + if ctx.Err() != nil { + return nil, ctx.Err() + } + + typ, size, err := interceptor.PeekMsg() + if err != nil { + return nil, errors.Wrap(err, "peeking message") + } + + if msgType != typ { + return nil, errors.Newf("type mismatch: expected '%c', but found '%c'", msgType, typ) + } + + if size > maxBodySize { + return nil, errors.Newf("too many bytes: expected <= %d, but found %d", maxBodySize, size) + } + + if skipRead { + return nil, nil + } + + msg, err := interceptor.ReadMsg() + if err != nil { + return nil, errors.Wrap(err, "reading message") + } + return msg, nil +} + +// isErrorResponseError returns true if the error is a type mismatch due to +// matching an ErrorResponse pgwire message, and false otherwise. err must come +// from expectNextServerMessage. +func isErrorResponseError(err error) bool { + return err != nil && strings.Contains(err.Error(), "type mismatch") && + strings.Contains(err.Error(), fmt.Sprintf("found '%c'", pgwirebase.ServerMsgErrorResponse)) +} + +// isTypeMismatchError returns true if the error represents a type mismatch +// error, and false otherwise. err must come from expectNextServerMessage. +func isTypeMismatchError(err error) bool { + return err != nil && strings.Contains(err.Error(), "type mismatch") +} + +// isLargeMessageError returns true if the error stems from "too many bytes", +// and false otherwise. This error will be returned if the message has more +// than 4K bytes. Connection migration does not care about such large messages. +// err must come from expectNextServerMessage. +func isLargeMessageError(err error) bool { + return err != nil && strings.Contains(err.Error(), "too many bytes") +} + +// jsonOrRaw returns msg in a json string representation if it can be marshaled +// into one, or in a raw struct string representation otherwise. Only used for +// displaying better error messages. 
+func jsonOrRaw(msg pgproto3.BackendMessage) string { + m, err := json.Marshal(msg) + if err != nil { + return fmt.Sprintf("%v", msg) + } + return encoding.UnsafeString(m) +} diff --git a/pkg/ccl/sqlproxyccl/conn_migration_test.go b/pkg/ccl/sqlproxyccl/conn_migration_test.go new file mode 100644 index 000000000000..01802e5b7fa9 --- /dev/null +++ b/pkg/ccl/sqlproxyccl/conn_migration_test.go @@ -0,0 +1,791 @@ +// Copyright 2022 The Cockroach Authors. +// +// Licensed as a CockroachDB Enterprise file under the Cockroach Community +// License (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt + +package sqlproxyccl + +import ( + "bytes" + "context" + "io" + "net" + "strings" + "testing" + + "github.com/cockroachdb/cockroach/pkg/ccl/sqlproxyccl/interceptor" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgwirebase" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/errors" + "github.com/jackc/pgproto3/v2" + "github.com/stretchr/testify/require" +) + +func TestRunShowTransferState(t *testing.T) { + defer leaktest.AfterTest(t)() + + t.Run("successful", func(t *testing.T) { + buf := new(bytes.Buffer) + err := runShowTransferState(buf, "foo-bar-baz") + require.NoError(t, err) + + backend := pgproto3.NewBackend(pgproto3.NewChunkReader(buf), buf) + msg, err := backend.Receive() + require.NoError(t, err) + m, ok := msg.(*pgproto3.Query) + require.True(t, ok) + require.Equal(t, "SHOW TRANSFER STATE WITH 'foo-bar-baz'", m.String) + }) + + t.Run("error", func(t *testing.T) { + err := runShowTransferState(&errWriter{}, "foo") + require.Regexp(t, "unexpected Write call", err) + }) +} + +func TestWaitForShowTransferState(t *testing.T) { + defer leaktest.AfterTest(t)() + ctx := context.Background() + + t.Run("context_cancelled", func(t *testing.T) { + tCtx, cancel := context.WithCancel(ctx) + 
cancel() + + state, token, err := waitForShowTransferState(tCtx, nil, nil, "") + require.EqualError(t, err, context.Canceled.Error()) + require.Equal(t, "", state) + require.Equal(t, "", token) + }) + + expectMsg := func(fi *interceptor.FrontendInterceptor, match string) error { + msg, err := fi.ReadMsg() + if err != nil { + return err + } + j := jsonOrRaw(msg) + if !strings.Contains(j, match) { + return errors.Newf("require message includes '%s', found none in '%s'", match, j) + } + return nil + } + + for _, tc := range []struct { + name string + sendSequence []pgproto3.BackendMessage + postValidate func(*interceptor.FrontendInterceptor) error + err string + recoverableErr bool + }{ + { + // All irrelevant messages are forwarded to the client. This returns + // an error when we see ErrorResponse. + name: "RowDescription/candidate_mismatch", + sendSequence: []pgproto3.BackendMessage{ + &pgproto3.BackendKeyData{}, // not RowDescription + &pgproto3.RowDescription{ + Fields: []pgproto3.FieldDescription{ + {Name: []byte("foo1")}, + {Name: make([]byte, 4096)}, + }, + }, // too large + &pgproto3.RowDescription{ + Fields: []pgproto3.FieldDescription{ + {Name: []byte("foo2")}, + }, + }, // invalid + &pgproto3.ErrorResponse{}, + }, + err: "ambiguous ErrorResponse", + postValidate: func(fi *interceptor.FrontendInterceptor) error { + if err := expectMsg(fi, `"Type":"BackendKeyData"`); err != nil { + return err + } + if err := expectMsg(fi, `"Type":"RowDescription","Fields":[{"Name":"foo1"`); err != nil { + return err + } + if err := expectMsg(fi, `"Type":"RowDescription","Fields":[{"Name":"foo2"`); err != nil { + return err + } + return nil + }, + }, + { + name: "DataRow/type_mismatch", + sendSequence: []pgproto3.BackendMessage{ + &pgproto3.RowDescription{ + Fields: []pgproto3.FieldDescription{ + {Name: []byte("error")}, + {Name: []byte("session_state_base64")}, + {Name: []byte("session_revival_token_base64")}, + {Name: []byte("transfer_key")}, + }, + }, + 
&pgproto3.ReadyForQuery{}, + }, + err: "type mismatch: expected 'D', but found 'Z'", + }, + { + name: "DataRow/invalid_transfer_key", + sendSequence: []pgproto3.BackendMessage{ + &pgproto3.RowDescription{ + Fields: []pgproto3.FieldDescription{ + {Name: []byte("error")}, + {Name: []byte("session_state_base64")}, + {Name: []byte("session_revival_token_base64")}, + {Name: []byte("transfer_key")}, + }, + }, + &pgproto3.DataRow{ + Values: [][]byte{ + {}, + {}, + {}, + []byte("bar"), + }, + }, + }, + err: "expected 'foo-transfer-key' as transfer key, found 'bar'", + }, + { + // Large state should abort transfers. + name: "DataRow/large_state", + sendSequence: []pgproto3.BackendMessage{ + &pgproto3.RowDescription{ + Fields: []pgproto3.FieldDescription{ + {Name: []byte("error")}, + {Name: []byte("session_state_base64")}, + {Name: []byte("session_revival_token_base64")}, + {Name: []byte("transfer_key")}, + }, + }, + &pgproto3.DataRow{ + Values: [][]byte{ + {}, + make([]byte, 5000), + {}, + {}, + }, + }, + }, + err: "too many bytes", + }, + { + name: "CommandComplete/type_mismatch", + sendSequence: []pgproto3.BackendMessage{ + &pgproto3.RowDescription{ + Fields: []pgproto3.FieldDescription{ + {Name: []byte("error")}, + {Name: []byte("session_state_base64")}, + {Name: []byte("session_revival_token_base64")}, + {Name: []byte("transfer_key")}, + }, + }, + &pgproto3.DataRow{ + Values: [][]byte{ + {}, + {}, + {}, + []byte("foo-transfer-key"), + }, + }, + &pgproto3.ReadyForQuery{}, + }, + err: "type mismatch: expected 'C', but found 'Z'", + }, + { + name: "CommandComplete/value_mismatch", + sendSequence: []pgproto3.BackendMessage{ + &pgproto3.RowDescription{ + Fields: []pgproto3.FieldDescription{ + {Name: []byte("error")}, + {Name: []byte("session_state_base64")}, + {Name: []byte("session_revival_token_base64")}, + {Name: []byte("transfer_key")}, + }, + }, + &pgproto3.DataRow{ + Values: [][]byte{ + {}, + {}, + {}, + []byte("foo-transfer-key"), + }, + }, + 
&pgproto3.CommandComplete{CommandTag: []byte("SELECT 2")}, + }, + err: "invalid CommandComplete", + }, + { + name: "ReadyForQuery/type_mismatch", + sendSequence: []pgproto3.BackendMessage{ + &pgproto3.RowDescription{ + Fields: []pgproto3.FieldDescription{ + {Name: []byte("error")}, + {Name: []byte("session_state_base64")}, + {Name: []byte("session_revival_token_base64")}, + {Name: []byte("transfer_key")}, + }, + }, + &pgproto3.DataRow{ + Values: [][]byte{ + {}, + {}, + {}, + []byte("foo-transfer-key"), + }, + }, + &pgproto3.CommandComplete{CommandTag: []byte("SHOW TRANSFER STATE 1")}, + &pgproto3.CommandComplete{}, + }, + err: "type mismatch: expected 'Z', but found 'C'", + }, + { + // This should be a common case with open transactions. + name: "transfer_state_error", + sendSequence: []pgproto3.BackendMessage{ + &pgproto3.RowDescription{ + Fields: []pgproto3.FieldDescription{ + {Name: []byte("error")}, + {Name: []byte("session_state_base64")}, + {Name: []byte("session_revival_token_base64")}, + {Name: []byte("transfer_key")}, + }, + }, + &pgproto3.DataRow{ + Values: [][]byte{ + []byte("serialization error"), + {}, + {}, + []byte("foo-transfer-key"), + }, + }, + &pgproto3.CommandComplete{CommandTag: []byte("SHOW TRANSFER STATE 1")}, + &pgproto3.ReadyForQuery{}, + }, + err: "serialization error", + recoverableErr: true, + }, + { + name: "successful", + sendSequence: []pgproto3.BackendMessage{ + &pgproto3.BackendKeyData{}, + &pgproto3.RowDescription{}, + &pgproto3.CommandComplete{}, + &pgproto3.RowDescription{ + Fields: []pgproto3.FieldDescription{ + {Name: []byte("error")}, + {Name: []byte("session_state_base64")}, + {Name: []byte("session_revival_token_base64")}, + {Name: []byte("transfer_key")}, + }, + }, + &pgproto3.DataRow{ + Values: [][]byte{ + {}, + []byte("foo-state"), + []byte("foo-token"), + []byte("foo-transfer-key"), + }, + }, + &pgproto3.CommandComplete{CommandTag: []byte("SHOW TRANSFER STATE 1")}, + &pgproto3.ReadyForQuery{}, + }, + postValidate: 
func(fi *interceptor.FrontendInterceptor) error { + if err := expectMsg(fi, `"Type":"BackendKeyData"`); err != nil { + return err + } + if err := expectMsg(fi, `"Type":"RowDescription"`); err != nil { + return err + } + if err := expectMsg(fi, `"Type":"CommandComplete"`); err != nil { + return err + } + return nil + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + buf := new(bytes.Buffer) + for _, m := range tc.sendSequence { + writeServerMsg(buf, m) + } + toClient := new(bytes.Buffer) + + state, token, err := waitForShowTransferState(ctx, + interceptor.NewFrontendInterceptor(buf), toClient, "foo-transfer-key") + if tc.err == "" { + require.NoError(t, err) + require.Equal(t, "foo-state", state) + require.Equal(t, "foo-token", token) + } else { + require.Regexp(t, tc.err, err) + if tc.recoverableErr { + require.True(t, isConnRecoverableError(err)) + } else { + require.False(t, isConnRecoverableError(err)) + } + } + + // Verify that all messages were read, and forwarding was correct. + require.Equal(t, 0, buf.Len()) + if tc.postValidate != nil { + frontend := interceptor.NewFrontendInterceptor(toClient) + require.NoError(t, tc.postValidate(frontend)) + _, _, err := frontend.PeekMsg() + require.Regexp(t, "EOF", err) + } + require.Equal(t, 0, toClient.Len()) + }) + } +} + +func TestIsValidStartTransferStateResponse(t *testing.T) { + defer leaktest.AfterTest(t)() + + for _, tc := range []struct { + name string + msg *pgproto3.RowDescription + expected bool + }{ + { + name: "nil_message", + msg: nil, + expected: false, + }, + { + name: "invalid_number_of_columns", + msg: &pgproto3.RowDescription{ + // 3 columns instead of 4. 
+ Fields: []pgproto3.FieldDescription{ + {Name: []byte("foo")}, + {Name: []byte("bar")}, + {Name: []byte("baz")}, + }, + }, + expected: false, + }, + { + name: "invalid_column_names", + msg: &pgproto3.RowDescription{ + Fields: []pgproto3.FieldDescription{ + {Name: []byte("error")}, + {Name: []byte("session_state_foo")}, + {Name: []byte("session_revival_token_bar")}, + {Name: []byte("apple")}, + }, + }, + expected: false, + }, + { + name: "valid", + msg: &pgproto3.RowDescription{ + Fields: []pgproto3.FieldDescription{ + {Name: []byte("error")}, + {Name: []byte("session_state_base64")}, + {Name: []byte("session_revival_token_base64")}, + {Name: []byte("transfer_key")}, + }, + }, + expected: true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + valid := isValidStartTransferStateResponse(tc.msg) + if tc.expected { + require.True(t, valid) + } else { + require.False(t, valid) + } + }) + } +} + +func TestIsValidEndTransferStateResponse(t *testing.T) { + defer leaktest.AfterTest(t)() + + for _, tc := range []struct { + name string + msg *pgproto3.CommandComplete + expected bool + }{ + { + name: "nil_message", + msg: nil, + expected: false, + }, + { + name: "unsupported_command_tag", + msg: &pgproto3.CommandComplete{CommandTag: []byte("foobarbaz")}, + expected: false, + }, + { + name: "invalid_row_count", + msg: &pgproto3.CommandComplete{CommandTag: []byte("SHOW TRANSFER STATE 2")}, + expected: false, + }, + { + name: "valid", + msg: &pgproto3.CommandComplete{CommandTag: []byte("SHOW TRANSFER STATE 1")}, + expected: true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + valid := isValidEndTransferStateResponse(tc.msg) + if tc.expected { + require.True(t, valid) + } else { + require.False(t, valid) + } + }) + } +} + +func TestParseTransferStateResponse(t *testing.T) { + defer leaktest.AfterTest(t)() + + for _, tc := range []struct { + name string + msg *pgproto3.DataRow + transferKey string + expectedErr string + }{ + { + name: "nil_message", + msg: nil, + expectedErr: 
"DataRow message is nil", + }, + { + name: "invalid_response", + msg: &pgproto3.DataRow{ + // 3 columns instead of 4. + Values: [][]byte{[]byte("foo"), []byte("bar"), []byte("baz")}, + }, + expectedErr: "unexpected 3 columns in DataRow", + }, + { + name: "invalid_transfer_key", + msg: &pgproto3.DataRow{ + Values: [][]byte{[]byte("foo"), []byte("bar"), []byte("baz"), []byte("carl")}, + }, + transferKey: "car", + expectedErr: "expected 'car' as transfer key, found 'carl'", + }, + } { + t.Run(tc.name, func(t *testing.T) { + transferErr, state, revivalToken, err := parseTransferStateResponse(tc.msg, tc.transferKey) + require.Regexp(t, tc.expectedErr, err) + require.Equal(t, "", transferErr) + require.Equal(t, "", state) + require.Equal(t, "", revivalToken) + }) + } + + t.Run("valid", func(t *testing.T) { + msg := &pgproto3.DataRow{ + Values: [][]byte{[]byte("foo"), []byte("bar"), []byte("baz"), []byte("carl")}, + } + transferErr, state, revivalToken, err := parseTransferStateResponse(msg, "carl") + require.NoError(t, err) + require.Equal(t, "foo", transferErr) + require.Equal(t, "bar", state) + require.Equal(t, "baz", revivalToken) + + // Ensure that returned strings are copied. Alternatively, we could also + // check pointers using encoding.UnsafeConvertStringToBytes. 
+		msg.Values[0][1] = '-'
+		msg.Values[1][1] = '-'
+		msg.Values[2][1] = '-'
+		require.Equal(t, "foo", transferErr)
+		require.Equal(t, "bar", state)
+		require.Equal(t, "baz", revivalToken)
+	})
+}
+
+// TestRunAndWaitForDeserializeSession verifies that the deserialize-session
+// query is emitted with the given transfer key, and that each malformed server
+// response (wrong message type, wrong column count/name/value, wrong command
+// tag) produces the expected error, while a well-formed sequence succeeds.
+func TestRunAndWaitForDeserializeSession(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	ctx := context.Background()
+
+	t.Run("write_failed", func(t *testing.T) {
+		err := runAndWaitForDeserializeSession(ctx, &errWriter{}, nil, "foo")
+		require.Regexp(t, "unexpected Write call", err)
+	})
+
+	// Each case scripts the server side of the pipe with a fixed sequence of
+	// backend messages; err is the expected error regexp ("" means success).
+	for _, tc := range []struct {
+		name         string
+		sendSequence []pgproto3.BackendMessage
+		err          string
+	}{
+		{
+			name: "RowDescription/type_mismatch",
+			sendSequence: []pgproto3.BackendMessage{
+				&pgproto3.ErrorResponse{},
+			},
+			err: "type mismatch: expected 'T', but found 'E'",
+		},
+		{
+			name: "RowDescription/column_mismatch/length",
+			sendSequence: []pgproto3.BackendMessage{
+				&pgproto3.RowDescription{},
+			},
+			err: "invalid RowDescription",
+		},
+		{
+			name: "RowDescription/column_mismatch/name",
+			sendSequence: []pgproto3.BackendMessage{
+				&pgproto3.RowDescription{
+					Fields: []pgproto3.FieldDescription{{Name: []byte("bar")}}},
+			},
+			err: "invalid RowDescription",
+		},
+		{
+			name: "DataRow/type_mismatch",
+			sendSequence: []pgproto3.BackendMessage{
+				&pgproto3.RowDescription{
+					Fields: []pgproto3.FieldDescription{
+						{Name: []byte("crdb_internal.deserialize_session")},
+					},
+				},
+				&pgproto3.ReadyForQuery{},
+			},
+			err: "type mismatch: expected 'D', but found 'Z'",
+		},
+		{
+			name: "DataRow/column_mismatch/length",
+			sendSequence: []pgproto3.BackendMessage{
+				&pgproto3.RowDescription{
+					Fields: []pgproto3.FieldDescription{
+						{Name: []byte("crdb_internal.deserialize_session")},
+					},
+				},
+				&pgproto3.DataRow{},
+			},
+			err: "invalid DataRow",
+		},
+		{
+			name: "DataRow/column_mismatch/value",
+			sendSequence: []pgproto3.BackendMessage{
+				&pgproto3.RowDescription{
+					Fields: []pgproto3.FieldDescription{
+						{Name: []byte("crdb_internal.deserialize_session")},
+					},
+				},
+				&pgproto3.DataRow{Values: [][]byte{[]byte("temp")}},
+			},
+			err: "invalid DataRow",
+		},
+		{
+			name: "CommandComplete/type_mismatch",
+			sendSequence: []pgproto3.BackendMessage{
+				&pgproto3.RowDescription{
+					Fields: []pgproto3.FieldDescription{
+						{Name: []byte("crdb_internal.deserialize_session")},
+					},
+				},
+				&pgproto3.DataRow{Values: [][]byte{[]byte("t")}},
+				&pgproto3.ReadyForQuery{},
+			},
+			err: "type mismatch: expected 'C', but found 'Z'",
+		},
+		{
+			name: "CommandComplete/value_mismatch",
+			sendSequence: []pgproto3.BackendMessage{
+				&pgproto3.RowDescription{
+					Fields: []pgproto3.FieldDescription{
+						{Name: []byte("crdb_internal.deserialize_session")},
+					},
+				},
+				&pgproto3.DataRow{Values: [][]byte{[]byte("t")}},
+				&pgproto3.CommandComplete{CommandTag: []byte("SELECT 2")},
+			},
+			err: "invalid CommandComplete",
+		},
+		{
+			name: "ReadyForQuery/type_mismatch",
+			sendSequence: []pgproto3.BackendMessage{
+				&pgproto3.RowDescription{
+					Fields: []pgproto3.FieldDescription{
+						{Name: []byte("crdb_internal.deserialize_session")},
+					},
+				},
+				&pgproto3.DataRow{Values: [][]byte{[]byte("t")}},
+				&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")},
+				&pgproto3.CommandComplete{},
+			},
+			err: "type mismatch: expected 'Z', but found 'C'",
+		},
+		{
+			name: "successful",
+			sendSequence: []pgproto3.BackendMessage{
+				&pgproto3.RowDescription{
+					Fields: []pgproto3.FieldDescription{
+						{Name: []byte("crdb_internal.deserialize_session")},
+					},
+				},
+				&pgproto3.DataRow{Values: [][]byte{[]byte("t")}},
+				&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")},
+				&pgproto3.ReadyForQuery{},
+			},
+			err: "",
+		},
+	} {
+		t.Run(tc.name, func(t *testing.T) {
+			r, w := net.Pipe()
+			defer r.Close()
+			defer w.Close()
+
+			msgChan := make(chan pgproto3.FrontendMessage, 1)
+			go func() {
+				// Act as the SQL server: capture the query the proxy sends,
+				// then play back the scripted response sequence.
+				backend := interceptor.NewBackendInterceptor(w)
+				msg, _ := backend.ReadMsg()
+				msgChan <- msg
+
+				for _, m := range tc.sendSequence {
+					writeServerMsg(w, m)
+				}
+			}()
+
+			err := runAndWaitForDeserializeSession(ctx, r,
+				interceptor.NewFrontendInterceptor(r), "foo-transfer-key")
+			if tc.err == "" {
+				require.NoError(t, err)
+			} else {
+				require.Regexp(t, tc.err, err)
+			}
+
+			msg := <-msgChan
+			m, ok := msg.(*pgproto3.Query)
+			require.True(t, ok)
+			const queryStr = "SELECT crdb_internal.deserialize_session(decode('foo-transfer-key', 'base64'))"
+			require.Equal(t, queryStr, m.String)
+		})
+	}
+}
+
+// TestExpectNextServerMessage exercises the error and success paths of the
+// expectNextServerMessage helper against an in-memory pipe: cancellation,
+// peek failure, type mismatch, oversized message, and the skipRead toggle.
+func TestExpectNextServerMessage(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+	ctx := context.Background()
+
+	t.Run("context_cancelled", func(t *testing.T) {
+		tCtx, cancel := context.WithCancel(ctx)
+		cancel()
+
+		msg, err := expectNextServerMessage(tCtx, nil, pgwirebase.ServerMsgReady, false)
+		require.EqualError(t, err, context.Canceled.Error())
+		require.Nil(t, msg)
+	})
+
+	t.Run("peek_error", func(t *testing.T) {
+		r, w := net.Pipe()
+		r.Close()
+		w.Close()
+
+		msg, err := expectNextServerMessage(
+			ctx,
+			interceptor.NewFrontendInterceptor(r),
+			pgwirebase.ServerMsgReady,
+			false,
+		)
+		require.Regexp(t, "peeking message", err)
+		require.Nil(t, msg)
+	})
+
+	t.Run("type_mismatch", func(t *testing.T) {
+		r, w := net.Pipe()
+		defer r.Close()
+		defer w.Close()
+
+		go func() {
+			writeServerMsg(w, &pgproto3.ErrorResponse{})
+		}()
+
+		msg, err := expectNextServerMessage(
+			ctx,
+			interceptor.NewFrontendInterceptor(r),
+			pgwirebase.ServerMsgRowDescription,
+			false,
+		)
+		require.Regexp(t, "type mismatch", err)
+		require.Nil(t, msg)
+		require.True(t, isTypeMismatchError(err))
+		require.True(t, isErrorResponseError(err))
+	})
+
+	t.Run("too_many_bytes", func(t *testing.T) {
+		r, w := net.Pipe()
+		defer r.Close()
+		defer w.Close()
+
+		go func() {
+			// 2<<12+1 value slots encode to a message presumably above the
+			// helper's size limit — confirm against expectNextServerMessage.
+			writeServerMsg(w, &pgproto3.DataRow{Values: make([][]byte, 2<<12+1)})
+		}()
+
+		msg, err := expectNextServerMessage(
+			ctx,
+			interceptor.NewFrontendInterceptor(r),
+			pgwirebase.ServerMsgDataRow,
+			false,
+		)
+		require.Regexp(t, "too many bytes", err)
+		require.Nil(t, msg)
+		require.True(t, isLargeMessageError(err))
+	})
+
+	t.Run("skipRead=true", func(t *testing.T) {
+		r, w := net.Pipe()
+		defer r.Close()
+		defer w.Close()
+
+		go func() {
+			writeServerMsg(w, &pgproto3.ReadyForQuery{})
+		}()
+
+		msg, err := expectNextServerMessage(
+			ctx,
+			interceptor.NewFrontendInterceptor(r),
+			pgwirebase.ServerMsgReady,
+			true,
+		)
+		require.Nil(t, err)
+		require.Nil(t, msg)
+	})
+
+	t.Run("skipRead=false", func(t *testing.T) {
+		r, w := net.Pipe()
+		defer r.Close()
+		defer w.Close()
+
+		go func() {
+			writeServerMsg(w, &pgproto3.RowDescription{
+				Fields: []pgproto3.FieldDescription{{Name: []byte("foo")}},
+			})
+		}()
+
+		msg, err := expectNextServerMessage(
+			ctx,
+			interceptor.NewFrontendInterceptor(r),
+			pgwirebase.ServerMsgRowDescription,
+			false,
+		)
+		require.Nil(t, err)
+		pgMsg, ok := msg.(*pgproto3.RowDescription)
+		require.True(t, ok)
+		require.Equal(t, []byte("foo"), pgMsg.Fields[0].Name)
+	})
+}
+
+var _ io.Writer = &errWriter{}
+
+// errWriter is an io.Writer that fails whenever a Write call is made.
+type errWriter struct{}
+
+// Write implements the io.Writer interface.
+func (w *errWriter) Write(p []byte) (int, error) {
+	return 0, errors.AssertionFailedf("unexpected Write call")
+}
+
+// writeServerMsg encodes msg and writes it to w; write errors are
+// intentionally ignored since this is a best-effort test helper.
+func writeServerMsg(w io.Writer, msg pgproto3.BackendMessage) {
+	_, _ = w.Write(msg.Encode(nil))
+}
diff --git a/pkg/ccl/sqlproxyccl/forwarder.go b/pkg/ccl/sqlproxyccl/forwarder.go
index 9810ef695a7f..26af5b589fb4 100644
--- a/pkg/ccl/sqlproxyccl/forwarder.go
+++ b/pkg/ccl/sqlproxyccl/forwarder.go
@@ -204,3 +204,21 @@ func wrapServerToClientError(err error) error {
 	}
 	return newErrorf(codeBackendDisconnected, "copying from target server to client: %s", err)
 }
+
+// errConnRecoverableSentinel exists as a sentinel value to denote that errors
+// should not terminate the connection.
+var errConnRecoverableSentinel = errors.New("connection recoverable error")
+
+// markAsConnRecoverableError marks the given error with errConnRecoverableSentinel
+// to denote that the connection can continue despite having an error.
+func markAsConnRecoverableError(err error) error {
+	return errors.Mark(err, errConnRecoverableSentinel)
+}
+
+// isConnRecoverableError checks whether a given error denotes that a connection
+// is recoverable. If this is true, the caller should try to recover the
+// connection (e.g. continue the forwarding process instead of terminating the
+// forwarder).
+func isConnRecoverableError(err error) bool {
+	return errors.Is(err, errConnRecoverableSentinel)
+}
diff --git a/pkg/ccl/sqlproxyccl/forwarder_test.go b/pkg/ccl/sqlproxyccl/forwarder_test.go
index 0acdf12e5c00..2fdbd7ff33a4 100644
--- a/pkg/ccl/sqlproxyccl/forwarder_test.go
+++ b/pkg/ccl/sqlproxyccl/forwarder_test.go
@@ -299,3 +299,16 @@ func TestWrapServerToClientError(t *testing.T) {
 		}
 	}
 }
+
+// TestConnectionRecoverableError verifies that an arbitrary error is not
+// considered recoverable until it is marked, and that a marked error is
+// detectable both via isConnRecoverableError and via errors.Is directly.
+func TestConnectionRecoverableError(t *testing.T) {
+	defer leaktest.AfterTest(t)()
+
+	err := errors.New("foobar")
+	require.False(t, isConnRecoverableError(err))
+	err = markAsConnRecoverableError(err)
+	require.True(t, isConnRecoverableError(err))
+	require.True(t, errors.Is(err, errConnRecoverableSentinel))
+}