diff --git a/pkg/ccl/sqlproxyccl/authentication.go b/pkg/ccl/sqlproxyccl/authentication.go index e3493be8515a..4f09a93e5c69 100644 --- a/pkg/ccl/sqlproxyccl/authentication.go +++ b/pkg/ccl/sqlproxyccl/authentication.go @@ -49,6 +49,9 @@ func authenticate(clientConn, crdbConn net.Conn) error { case *pgproto3.ParameterStatus: // Server sent status message; keep reading messages until // `pgproto3.ReadyForQuery` is encountered. + case *pgproto3.BackendKeyData: + // Server sent backend key data; keep reading messages until + // `pgproto3.ReadyForQuery` is encountered. case *pgproto3.ErrorResponse: // Server has rejected the authentication response from the client and // has closed the connection. diff --git a/pkg/ccl/sqlproxyccl/authentication_test.go b/pkg/ccl/sqlproxyccl/authentication_test.go index a0a9aceb9a0a..1719f5ec9297 100644 --- a/pkg/ccl/sqlproxyccl/authentication_test.go +++ b/pkg/ccl/sqlproxyccl/authentication_test.go @@ -108,11 +108,11 @@ func TestAuthenticateUnexpectedMessage(t *testing.T) { fe := pgproto3.NewFrontend(pgproto3.NewChunkReader(cli), cli) go func() { - err := be.Send(&pgproto3.BackendKeyData{}) + err := be.Send(&pgproto3.BindComplete{}) require.NoError(t, err) beMsg, err := fe.Receive() require.NoError(t, err) - require.Equal(t, beMsg, &pgproto3.BackendKeyData{}) + require.Equal(t, beMsg, &pgproto3.BindComplete{}) }() err := authenticate(srv, cli) diff --git a/pkg/sql/pgwire/conn.go b/pkg/sql/pgwire/conn.go index 6c80ce14abd9..0bc2d7349088 100644 --- a/pkg/sql/pgwire/conn.go +++ b/pkg/sql/pgwire/conn.go @@ -302,10 +302,8 @@ func (c *conn) serveImpl( dummyCh := make(chan error) close(dummyCh) procCh = dummyCh - // An initial readyForQuery message is part of the handshake. - c.msgBuilder.initMsg(pgwirebase.ServerMsgReady) - c.msgBuilder.writeByte(byte(sql.IdleTxnBlock)) - if err := c.msgBuilder.finishMsg(c.conn); err != nil { + + if err := c.sendReadyForQuery(); err != nil { reserved.Close(ctx) return } @@ -688,14 +686,35 @@ func (c *conn) sendInitialConnData( if err := c.sendParamStatus("is_superuser", superUserVal); err != nil { return sql.ConnectionHandler{}, err } + if err := c.sendReadyForQuery(); err != nil { + return sql.ConnectionHandler{}, err + } + return connHandler, nil +} - // An initial readyForQuery message is part of the handshake. +// sendReadyForQuery sends the final messages of the connection handshake. +// This includes a placeholder BackendKeyData message and a ServerMsgReady +// message indicating that there is no active transaction. +func (c *conn) sendReadyForQuery() error { + // Send the client a dummy BackendKeyData message. This is necessary for + // compatibility with tools that require this message. This information is + // normally used by clients to send a CancelRequest message: + // https://www.postgresql.org/docs/9.6/static/protocol-flow.html#AEN112861 + // CockroachDB currently ignores all CancelRequests. + c.msgBuilder.initMsg(pgwirebase.ServerMsgBackendKeyData) + c.msgBuilder.putInt32(0) + c.msgBuilder.putInt32(0) + if err := c.msgBuilder.finishMsg(c.conn); err != nil { + return err + } + + // An initial ServerMsgReady message is part of the handshake. c.msgBuilder.initMsg(pgwirebase.ServerMsgReady) c.msgBuilder.writeByte(byte(sql.IdleTxnBlock)) if err := c.msgBuilder.finishMsg(c.conn); err != nil { - return sql.ConnectionHandler{}, err + return err } - return connHandler, nil + return nil } // An error is returned iff the statement buffer has been closed. In that case, diff --git a/pkg/sql/pgwire/conn_test.go b/pkg/sql/pgwire/conn_test.go index acb14f50ea9c..02de1473a93a 100644 --- a/pkg/sql/pgwire/conn_test.go +++ b/pkg/sql/pgwire/conn_test.go @@ -1521,3 +1521,47 @@ func TestSetSessionArguments(t *testing.T) { t.Fatal(err) } } + +// TestCancelQuery uses the pgwire-level query cancellation protocol provided +// by lib/pq to make sure that canceling a query has no effect, and makes sure +// the dummy BackendKeyData does not cause problems. +func TestCancelQuery(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + cancelCtx, cancel := context.WithCancel(context.Background()) + args := base.TestServerArgs{ + Knobs: base.TestingKnobs{ + SQLExecutor: &sql.ExecutorTestingKnobs{ + BeforeExecute: func(ctx context.Context, stmt string) { + if strings.Contains(stmt, "pg_sleep") { + cancel() + } + }, + }, + }, + } + s, _, _ := serverutils.StartServer(t, args) + defer s.Stopper().Stop(cancelCtx) + + pgURL, cleanupFunc := sqlutils.PGUrl( + t, s.ServingSQLAddr(), "TestCancelQuery" /* prefix */, url.User(security.RootUser), + ) + defer cleanupFunc() + + db, err := gosql.Open("postgres", pgURL.String()) + require.NoError(t, err) + defer db.Close() + + // Cancellation has no effect on ongoing query. + if _, err := db.QueryContext(cancelCtx, "select pg_sleep(0)"); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + // Context is already canceled, so error should come before execution. + if _, err := db.QueryContext(cancelCtx, "select 1"); err == nil { + t.Fatal("expected error") + } else if err.Error() != "context canceled" { + t.Fatalf("unexpected error: %s", err) + } +} diff --git a/pkg/sql/pgwire/pgwirebase/msg.go b/pkg/sql/pgwire/pgwirebase/msg.go index 0eb8f478c80a..fa50ba48e4f9 100644 --- a/pkg/sql/pgwire/pgwirebase/msg.go +++ b/pkg/sql/pgwire/pgwirebase/msg.go @@ -37,6 +37,7 @@ const ( ClientMsgTerminate ClientMessageType = 'X' ServerMsgAuth ServerMessageType = 'R' + ServerMsgBackendKeyData ServerMessageType = 'K' ServerMsgBindComplete ServerMessageType = '2' ServerMsgCommandComplete ServerMessageType = 'C' ServerMsgCloseComplete ServerMessageType = '3' diff --git a/pkg/sql/pgwire/pgwirebase/servermessagetype_string.go b/pkg/sql/pgwire/pgwirebase/servermessagetype_string.go index a0e9999506b2..96b8b7583265 100644 --- a/pkg/sql/pgwire/pgwirebase/servermessagetype_string.go +++ b/pkg/sql/pgwire/pgwirebase/servermessagetype_string.go @@ -9,6 +9,7 @@ func _() { // Re-run the stringer command to generate them again. var x [1]struct{} _ = x[ServerMsgAuth-82] + _ = x[ServerMsgBackendKeyData-75] _ = x[ServerMsgBindComplete-50] _ = x[ServerMsgCommandComplete-67] _ = x[ServerMsgCloseComplete-51] @@ -31,18 +32,19 @@ const ( _ServerMessageType_name_1 = "ServerMsgCommandCompleteServerMsgDataRowServerMsgErrorResponse" _ServerMessageType_name_2 = "ServerMsgCopyInResponse" _ServerMessageType_name_3 = "ServerMsgEmptyQuery" - _ServerMessageType_name_4 = "ServerMsgNoticeResponse" - _ServerMessageType_name_5 = "ServerMsgAuthServerMsgParameterStatusServerMsgRowDescription" - _ServerMessageType_name_6 = "ServerMsgReady" - _ServerMessageType_name_7 = "ServerMsgNoData" - _ServerMessageType_name_8 = "ServerMsgPortalSuspendedServerMsgParameterDescription" + _ServerMessageType_name_4 = "ServerMsgBackendKeyData" + _ServerMessageType_name_5 = "ServerMsgNoticeResponse" + _ServerMessageType_name_6 = "ServerMsgAuthServerMsgParameterStatusServerMsgRowDescription" + _ServerMessageType_name_7 = "ServerMsgReady" + _ServerMessageType_name_8 = "ServerMsgNoData" + _ServerMessageType_name_9 = "ServerMsgPortalSuspendedServerMsgParameterDescription" ) var ( _ServerMessageType_index_0 = [...]uint8{0, 22, 43, 65} _ServerMessageType_index_1 = [...]uint8{0, 24, 40, 62} - _ServerMessageType_index_5 = [...]uint8{0, 13, 37, 60} - _ServerMessageType_index_8 = [...]uint8{0, 24, 53} + _ServerMessageType_index_6 = [...]uint8{0, 13, 37, 60} + _ServerMessageType_index_9 = [...]uint8{0, 24, 53} ) func (i ServerMessageType) String() string { @@ -57,18 +59,20 @@ func (i ServerMessageType) String() string { return _ServerMessageType_name_2 case i == 73: return _ServerMessageType_name_3 - case i == 78: + case i == 75: return _ServerMessageType_name_4 + case i == 78: + return _ServerMessageType_name_5 case 82 <= i && i <= 84: i -= 82 - return _ServerMessageType_name_5[_ServerMessageType_index_5[i]:_ServerMessageType_index_5[i+1]] + return _ServerMessageType_name_6[_ServerMessageType_index_6[i]:_ServerMessageType_index_6[i+1]] case i == 90: - return _ServerMessageType_name_6 - case i == 110: return _ServerMessageType_name_7 + case i == 110: + return _ServerMessageType_name_8 case 115 <= i && i <= 116: i -= 115 - return _ServerMessageType_name_8[_ServerMessageType_index_8[i]:_ServerMessageType_index_8[i+1]] + return _ServerMessageType_name_9[_ServerMessageType_index_9[i]:_ServerMessageType_index_9[i+1]] default: return "ServerMessageType(" + strconv.FormatInt(int64(i), 10) + ")" } diff --git a/pkg/testutils/pgtest/pgtest.go b/pkg/testutils/pgtest/pgtest.go index 7f3c23bed893..513ffd8dc747 100644 --- a/pkg/testutils/pgtest/pgtest.go +++ b/pkg/testutils/pgtest/pgtest.go @@ -64,10 +64,22 @@ func NewPGTest(ctx context.Context, addr, user string) (*PGTest, error) { } msgs, err := p.Until(false /* keepErrMsg */, &pgproto3.ReadyForQuery{}) foundCrdb := false + var backendKeyData *pgproto3.BackendKeyData for _, msg := range msgs { if s, ok := msg.(*pgproto3.ParameterStatus); ok && s.Name == "crdb_version" { foundCrdb = true } + if d, ok := msg.(*pgproto3.BackendKeyData); ok { + // We inspect the BackendKeyData outside of the loop since we only + // want to do the assertions if foundCrdb==true. + backendKeyData = d + } + } + if backendKeyData == nil { + return nil, errors.Errorf("did not receive BackendKeyData") + } + if foundCrdb && (backendKeyData.ProcessID != 0 || backendKeyData.SecretKey != 0) { + return nil, errors.Errorf("unexpected BackendKeyData: %+v", d) } p.isCockroachDB = foundCrdb success = err == nil