From aa308396b168d86a87b7d31590ecb51550663180 Mon Sep 17 00:00:00 2001 From: Raphael 'kena' Poss Date: Fri, 11 Feb 2022 16:30:11 +0100 Subject: [PATCH 1/2] cli: modernize the SQL API calls This brings the cli (sub-)packages kicking and screaming into the post-go1.8 era, by adding the missing context arguments. Release note (bug fix): Some of the `cockroach node` subcommands did not handle `--timeout` properly. This is now fixed. --- pkg/cli/auth.go | 25 ++-- pkg/cli/cli_test.go | 4 +- pkg/cli/clierror/syntax_error_test.go | 3 +- pkg/cli/clisqlcfg/context.go | 7 +- pkg/cli/clisqlclient/api.go | 24 ++-- pkg/cli/clisqlclient/conn.go | 166 +++++++++++++---------- pkg/cli/clisqlclient/conn_test.go | 19 +-- pkg/cli/clisqlclient/make_query.go | 33 +++-- pkg/cli/clisqlclient/statement_diag.go | 91 +++++++------ pkg/cli/clisqlclient/txn_shim.go | 12 +- pkg/cli/clisqlexec/run_query.go | 14 +- pkg/cli/clisqlexec/run_query_test.go | 9 +- pkg/cli/clisqlshell/sql.go | 27 +++- pkg/cli/clisqlshell/statement_diag.go | 6 +- pkg/cli/debug_job_trace.go | 14 +- pkg/cli/democluster/demo_cluster_test.go | 3 +- pkg/cli/doctor.go | 25 ++-- pkg/cli/node.go | 20 ++- pkg/cli/nodelocal.go | 2 +- pkg/cli/statement_bundle.go | 16 +-- pkg/cli/statement_bundle_test.go | 7 +- pkg/cli/statement_diag.go | 22 ++- pkg/cli/userfiletable_test.go | 4 +- pkg/cli/zip.go | 6 +- 24 files changed, 316 insertions(+), 243 deletions(-) diff --git a/pkg/cli/auth.go b/pkg/cli/auth.go index 5acc9edf7846..a0a9d15484d3 100644 --- a/pkg/cli/auth.go +++ b/pkg/cli/auth.go @@ -11,7 +11,7 @@ package cli import ( - "database/sql/driver" + "context" "fmt" "net/http" "os" @@ -98,8 +98,11 @@ func createAuthSessionToken( } defer func() { resErr = errors.CombineErrors(resErr, sqlConn.Close()) }() + ctx := context.Background() + // First things first. Does the user exist? - _, rows, err := sqlExecCtx.RunQuery(sqlConn, + _, rows, err := sqlExecCtx.RunQuery(ctx, + sqlConn, clisqlclient.MakeQuery(`SELECT count(username) FROM system.users WHERE username = $1 AND NOT "isRole"`, username), false) if err != nil { return -1, nil, err @@ -122,13 +125,11 @@ VALUES($1, $2, $3) RETURNING id ` var id int64 - row, err := sqlConn.QueryRow( + row, err := sqlConn.QueryRow(ctx, insertSessionStmt, - []driver.Value{ - hashedSecret, - username, - expiration, - }, + hashedSecret, + username, + expiration, ) if err != nil { return -1, nil, err @@ -176,7 +177,9 @@ func runLogout(cmd *cobra.Command, args []string) (resErr error) { id AS "session ID", "revokedAt" AS "revoked"`, username) - return sqlExecCtx.RunQueryAndFormatResults(sqlConn, os.Stdout, stderr, logoutQuery) + return sqlExecCtx.RunQueryAndFormatResults( + context.Background(), + sqlConn, os.Stdout, stderr, logoutQuery) } var authListCmd = &cobra.Command{ @@ -206,7 +209,9 @@ SELECT username, "revokedAt" as "revoked", "lastUsedAt" as "last used" FROM system.web_sessions`) - return sqlExecCtx.RunQueryAndFormatResults(sqlConn, os.Stdout, stderr, logoutQuery) + return sqlExecCtx.RunQueryAndFormatResults( + context.Background(), + sqlConn, os.Stdout, stderr, logoutQuery) } var authCmds = []*cobra.Command{ diff --git a/pkg/cli/cli_test.go b/pkg/cli/cli_test.go index 84ff70ad23fd..57eed819d5bc 100644 --- a/pkg/cli/cli_test.go +++ b/pkg/cli/cli_test.go @@ -34,12 +34,12 @@ func TestCLITimeout(t *testing.T) { // the timeout a chance to have an effect. We specify --all to include some // slower to access virtual tables in the query. testutils.SucceedsSoon(t, func() error { - out, err := c.RunWithCapture("node status 1 --all --timeout 1ns") + out, err := c.RunWithCapture("node status 1 --all --timeout 1ms") if err != nil { t.Fatal(err) } - const exp = `node status 1 --all --timeout 1ns + const exp = `node status 1 --all --timeout 1ms ERROR: query execution canceled due to statement timeout SQLSTATE: 57014 ` diff --git a/pkg/cli/clierror/syntax_error_test.go b/pkg/cli/clierror/syntax_error_test.go index 2cc0334b1a25..95cc9e1ffdeb 100644 --- a/pkg/cli/clierror/syntax_error_test.go +++ b/pkg/cli/clierror/syntax_error_test.go @@ -11,6 +11,7 @@ package clierror_test import ( + "context" "io/ioutil" "net/url" "testing" @@ -44,7 +45,7 @@ func TestIsSQLSyntaxError(t *testing.T) { } }() - _, err := conn.QueryRow(`INVALID SYNTAX`, nil) + _, err := conn.QueryRow(context.Background(), `INVALID SYNTAX`) if !clierror.IsSQLSyntaxError(err) { t.Fatalf("expected error to be recognized as syntax error: %+v", err) } diff --git a/pkg/cli/clisqlcfg/context.go b/pkg/cli/clisqlcfg/context.go index ca2d0e54efac..7fae0dd709c4 100644 --- a/pkg/cli/clisqlcfg/context.go +++ b/pkg/cli/clisqlcfg/context.go @@ -13,6 +13,7 @@ package clisqlcfg import ( + "context" "fmt" "os" "strconv" @@ -213,7 +214,8 @@ func (c *Context) maybeSetSafeUpdates(conn clisqlclient.Conn) { safeUpdates = c.CliCtx.IsInteractive } if safeUpdates { - if err := conn.Exec("SET sql_safe_updates = TRUE", nil); err != nil { + if err := conn.Exec(context.Background(), + "SET sql_safe_updates = TRUE"); err != nil { // We only enable the setting in interactive sessions. Ignoring // the error with a warning is acceptable, because the user is // there to decide what they want to do if it doesn't work. @@ -228,5 +230,6 @@ func (c *Context) maybeSetReadOnly(conn clisqlclient.Conn) error { if !c.ReadOnly { return nil } - return conn.Exec("SET default_transaction_read_only = TRUE", nil) + return conn.Exec(context.Background(), + "SET default_transaction_read_only = TRUE") } diff --git a/pkg/cli/clisqlclient/api.go b/pkg/cli/clisqlclient/api.go index 8e30dfd8c4d0..0c27bd21c9e7 100644 --- a/pkg/cli/clisqlclient/api.go +++ b/pkg/cli/clisqlclient/api.go @@ -11,6 +11,7 @@ package clisqlclient import ( + "context" "database/sql/driver" "reflect" "time" @@ -26,24 +27,24 @@ type Conn interface { EnsureConn() error // Exec executes a statement. - Exec(query string, args []driver.Value) error + Exec(ctx context.Context, query string, args ...interface{}) error // Query returns one or more SQL statements and returns the // corresponding result set(s). - Query(query string, args []driver.Value) (Rows, error) + Query(ctx context.Context, query string, args ...interface{}) (Rows, error) // QueryRow execute a SQL query returning exactly one row // and retrieves the returned values. An error is returned // if the query returns zero or more than one row. - QueryRow(query string, args []driver.Value) ([]driver.Value, error) + QueryRow(ctx context.Context, query string, args ...interface{}) ([]driver.Value, error) // ExecTxn runs fn inside a transaction and retries it as needed. - ExecTxn(fn func(TxBoundConn) error) error + ExecTxn(ctx context.Context, fn func(context.Context, TxBoundConn) error) error // GetLastQueryStatistics returns the detailed latency stats for the // last executed query, if supported by the server and enabled by // configuration. - GetLastQueryStatistics() (result QueryStats, err error) + GetLastQueryStatistics(ctx context.Context) (result QueryStats, err error) // SetURL changes the URL field in the connection object, so that the // next connection (re-)establishment will use the new URL. @@ -69,7 +70,7 @@ type Conn interface { // GetServerMetadata() returns details about the CockroachDB node // this connection is connected to. - GetServerMetadata() ( + GetServerMetadata(ctx context.Context) ( nodeID int32, version, clusterID string, err error, @@ -82,7 +83,7 @@ type Conn interface { // The what argument is a descriptive label for the value being // retrieved, for inclusion inside warning or error message. // The sql argument is the SQL query to use to retrieve the value. - GetServerValue(what, sql string) (driver.Value, string, bool) + GetServerValue(ctx context.Context, what, sql string) (driver.Value, string, bool) // GetDriverConn exposes the underlying SQL driver connection object // for use by the cli package. @@ -159,11 +160,11 @@ type QueryStats struct { // visible to the closure passed to (Conn).ExecTxn. type TxBoundConn interface { // Exec executes a statement inside the transaction. - Exec(query string, args []driver.Value) error + Exec(ctx context.Context, query string, args ...interface{}) error // Query returns one or more SQL statements and returns the // corresponding result set(s). - Query(query string, args []driver.Value) (Rows, error) + Query(ctx context.Context, query string, args ...interface{}) (Rows, error) } // DriverConn is the type of the connection object returned by @@ -173,9 +174,4 @@ type DriverConn interface { driver.Conn driver.ExecerContext driver.QueryerContext - - //lint:ignore SA1019 TODO(mjibson): clean this up to use go1.8 APIs - driver.Execer - //lint:ignore SA1019 TODO(mjibson): clean this up to use go1.8 APIs - driver.Queryer } diff --git a/pkg/cli/clisqlclient/conn.go b/pkg/cli/clisqlclient/conn.go index efd4ab2aae94..c7a9cbc33de9 100644 --- a/pkg/cli/clisqlclient/conn.go +++ b/pkg/cli/clisqlclient/conn.go @@ -137,59 +137,63 @@ func (c *sqlConn) SetMissingPassword(missing bool) { // EnsureConn (re-)establishes the connection to the server. func (c *sqlConn) EnsureConn() error { - if c.conn == nil { - if c.reconnecting && c.connCtx.IsInteractive() { - fmt.Fprintf(c.errw, "warning: connection lost!\n"+ - "opening new connection: all session settings will be lost\n") - } - base, err := pq.NewConnector(c.url) - if err != nil { - return wrapConnError(err) - } - // Add a notice handler - re-use the cliOutputError function in this case. - connector := pq.ConnectorWithNoticeHandler(base, func(notice *pq.Error) { - c.handleNotice(notice) - }) - // TODO(cli): we can't thread ctx through ensureConn usages, as it needs - // to follow the gosql.DB interface. We should probably look at initializing - // connections only once instead. The context is only used for dialing. - conn, err := connector.Connect(context.TODO()) - if err != nil { - // Connection failed: if the failure is due to a missing - // password, we're going to fill the password here. - // - // TODO(knz): CockroachDB servers do not properly fill SQLSTATE - // (28P01) for password auth errors, so we have to "make do" - // with a string match. This should be cleaned up by adding - // the missing code server-side. - errStr := strings.TrimPrefix(err.Error(), "pq: ") - if strings.HasPrefix(errStr, "password authentication failed") && c.passwordMissing { - if pErr := c.fillPassword(); pErr != nil { - return errors.CombineErrors(err, pErr) - } - // Recurse, once. We recurse to ensure that pq.NewConnector - // and ConnectorWithNoticeHandler get called with the new URL. - // The recursion only occurs once because fillPassword() - // resets c.passwordMissing, so we cannot get into this - // conditional a second time. - return c.EnsureConn() - } - // Not a password auth error, or password already set. Simply fail. - return wrapConnError(err) - } - if c.reconnecting && c.dbName != "" { - // Attempt to reset the current database. - if _, err := conn.(DriverConn).Exec(`SET DATABASE = $1`, []driver.Value{c.dbName}); err != nil { - fmt.Fprintf(c.errw, "warning: unable to restore current database: %v\n", err) + if c.conn != nil { + return nil + } + ctx := context.Background() + + if c.reconnecting && c.connCtx.IsInteractive() { + fmt.Fprintf(c.errw, "warning: connection lost!\n"+ + "opening new connection: all session settings will be lost\n") + } + base, err := pq.NewConnector(c.url) + if err != nil { + return wrapConnError(err) + } + // Add a notice handler - re-use the cliOutputError function in this case. + connector := pq.ConnectorWithNoticeHandler(base, func(notice *pq.Error) { + c.handleNotice(notice) + }) + // TODO(cli): we can't thread ctx through ensureConn usages, as it needs + // to follow the gosql.DB interface. We should probably look at initializing + // connections only once instead. The context is only used for dialing. + conn, err := connector.Connect(ctx) + if err != nil { + // Connection failed: if the failure is due to a missing + // password, we're going to fill the password here. + // + // TODO(knz): CockroachDB servers do not properly fill SQLSTATE + // (28P01) for password auth errors, so we have to "make do" + // with a string match. This should be cleaned up by adding + // the missing code server-side. + errStr := strings.TrimPrefix(err.Error(), "pq: ") + if strings.HasPrefix(errStr, "password authentication failed") && c.passwordMissing { + if pErr := c.fillPassword(); pErr != nil { + return errors.CombineErrors(err, pErr) } + // Recurse, once. We recurse to ensure that pq.NewConnector + // and ConnectorWithNoticeHandler get called with the new URL. + // The recursion only occurs once because fillPassword() + // resets c.passwordMissing, so we cannot get into this + // conditional a second time. + return c.EnsureConn() } - c.conn = conn.(DriverConn) - if err := c.checkServerMetadata(); err != nil { - err = errors.CombineErrors(err, c.Close()) - return wrapConnError(err) + // Not a password auth error, or password already set. Simply fail. + return wrapConnError(err) + } + if c.reconnecting && c.dbName != "" { + // Attempt to reset the current database. + if _, err := conn.(DriverConn).ExecContext(ctx, `SET DATABASE = $1`, + []driver.NamedValue{{Value: c.dbName}}); err != nil { + fmt.Fprintf(c.errw, "warning: unable to restore current database: %v\n", err) } - c.reconnecting = false } + c.conn = conn.(DriverConn) + if err := c.checkServerMetadata(ctx); err != nil { + err = errors.CombineErrors(err, c.Close()) + return wrapConnError(err) + } + c.reconnecting = false return nil } @@ -204,11 +208,11 @@ const ( // tryEnableServerExecutionTimings attempts to check if the server supports the // SHOW LAST QUERY STATISTICS statements. This allows the CLI client to report // server side execution timings instead of timing on the client. -func (c *sqlConn) tryEnableServerExecutionTimings() error { +func (c *sqlConn) tryEnableServerExecutionTimings(ctx context.Context) error { // Starting in v21.2 servers, clients can request an explicit set of // values which makes them compatible with any post-21.2 column // additions. - _, err := c.QueryRow("SHOW LAST QUERY STATISTICS RETURNING x", nil) + _, err := c.QueryRow(ctx, "SHOW LAST QUERY STATISTICS RETURNING x") if err != nil && !clierror.IsSQLSyntaxError(err) { return err } @@ -220,7 +224,7 @@ func (c *sqlConn) tryEnableServerExecutionTimings() error { // Pre-21.2 servers may have SHOW LAST QUERY STATISTICS. // Note: this branch is obsolete, remove it when compatibility // with pre-21.2 servers is not required any more. - _, err = c.QueryRow("SHOW LAST QUERY STATISTICS", nil) + _, err = c.QueryRow(ctx, "SHOW LAST QUERY STATISTICS") if err != nil && !clierror.IsSQLSyntaxError(err) { return err } @@ -236,9 +240,11 @@ func (c *sqlConn) tryEnableServerExecutionTimings() error { return nil } -func (c *sqlConn) GetServerMetadata() (nodeID int32, version, clusterID string, err error) { +func (c *sqlConn) GetServerMetadata( + ctx context.Context, +) (nodeID int32, version, clusterID string, err error) { // Retrieve the node ID and server build info. - rows, err := c.Query("SELECT * FROM crdb_internal.node_build_info", nil) + rows, err := c.Query(ctx, "SELECT * FROM crdb_internal.node_build_info") if errors.Is(err, driver.ErrBadConn) { return 0, "", "", err } @@ -332,7 +338,7 @@ func toString(v driver.Value) string { // upon the initial connection or if either has changed since // the last connection, based on the last known values in the sqlConn // struct. -func (c *sqlConn) checkServerMetadata() error { +func (c *sqlConn) checkServerMetadata(ctx context.Context) error { if !c.connCtx.IsInteractive() { // Version reporting is just noise if the user is not present to // change their mind upon seeing the information. @@ -344,7 +350,7 @@ func (c *sqlConn) checkServerMetadata() error { return nil } - _, newServerVersion, newClusterID, err := c.GetServerMetadata() + _, newServerVersion, newClusterID, err := c.GetServerMetadata(ctx) if errors.Is(err, driver.ErrBadConn) { return err } @@ -401,14 +407,16 @@ func (c *sqlConn) checkServerMetadata() error { } // Try to enable server execution timings for the CLI to display if // supported by the server. - return c.tryEnableServerExecutionTimings() + return c.tryEnableServerExecutionTimings(ctx) } // GetServerValue retrieves the first driverValue returned by the // given sql query. If the query fails or does not return a single // column, `false` is returned in the second result. -func (c *sqlConn) GetServerValue(what, sql string) (driver.Value, string, bool) { - rows, err := c.Query(sql, nil) +func (c *sqlConn) GetServerValue( + ctx context.Context, what, sql string, +) (driver.Value, string, bool) { + rows, err := c.Query(ctx, sql) if err != nil { fmt.Fprintf(c.errw, "warning: error retrieving the %s: %v\n", what, err) return nil, "", false @@ -432,7 +440,7 @@ func (c *sqlConn) GetServerValue(what, sql string) (driver.Value, string, bool) return dbVals[0], dbColType, true } -func (c *sqlConn) GetLastQueryStatistics() (results QueryStats, resErr error) { +func (c *sqlConn) GetLastQueryStatistics(ctx context.Context) (results QueryStats, resErr error) { if !c.connCtx.EnableServerExecutionTimings || c.lastQueryStatsMode == modeDisabled { return results, nil } @@ -444,7 +452,7 @@ func (c *sqlConn) GetLastQueryStatistics() (results QueryStats, resErr error) { stmt = `SHOW LAST QUERY STATISTICS` } - vals, cols, err := c.queryRowInternal(stmt, nil) + vals, cols, err := c.queryRowInternal(ctx, stmt, nil) if err != nil { return results, err } @@ -488,23 +496,29 @@ func (c *sqlConn) GetLastQueryStatistics() (results QueryStats, resErr error) { // // NOTE: the supplied closure should not have external side // effects beyond changes to the database. -func (c *sqlConn) ExecTxn(fn func(TxBoundConn) error) (err error) { - if err := c.Exec(`BEGIN`, nil); err != nil { +func (c *sqlConn) ExecTxn( + ctx context.Context, fn func(context.Context, TxBoundConn) error, +) (err error) { + if err := c.Exec(ctx, `BEGIN`); err != nil { return err } - return crdb.ExecuteInTx(context.TODO(), sqlTxnShim{c}, func() error { - return fn(c) + return crdb.ExecuteInTx(ctx, sqlTxnShim{c}, func() error { + return fn(ctx, c) }) } -func (c *sqlConn) Exec(query string, args []driver.Value) error { +func (c *sqlConn) Exec(ctx context.Context, query string, args ...interface{}) error { + dVals, err := convertArgs(args) + if err != nil { + return err + } if err := c.EnsureConn(); err != nil { return err } if c.connCtx.Echo { fmt.Fprintln(c.errw, ">", query) } - _, err := c.conn.Exec(query, args) + _, err = c.conn.ExecContext(ctx, query, dVals) c.flushNotices() if errors.Is(err, driver.ErrBadConn) { c.reconnecting = true @@ -513,14 +527,18 @@ func (c *sqlConn) Exec(query string, args []driver.Value) error { return err } -func (c *sqlConn) Query(query string, args []driver.Value) (Rows, error) { +func (c *sqlConn) Query(ctx context.Context, query string, args ...interface{}) (Rows, error) { + dVals, err := convertArgs(args) + if err != nil { + return nil, err + } if err := c.EnsureConn(); err != nil { return nil, err } if c.connCtx.Echo { fmt.Fprintln(c.errw, ">", query) } - rows, err := c.conn.Query(query, args) + rows, err := c.conn.QueryContext(ctx, query, dVals) if errors.Is(err, driver.ErrBadConn) { c.reconnecting = true c.silentClose() @@ -531,15 +549,17 @@ func (c *sqlConn) Query(query string, args []driver.Value) (Rows, error) { return &sqlRows{rows: rows.(sqlRowsI), conn: c}, nil } -func (c *sqlConn) QueryRow(query string, args []driver.Value) ([]driver.Value, error) { - results, _, err := c.queryRowInternal(query, args) +func (c *sqlConn) QueryRow( + ctx context.Context, query string, args ...interface{}, +) ([]driver.Value, error) { + results, _, err := c.queryRowInternal(ctx, query, args) return results, err } func (c *sqlConn) queryRowInternal( - query string, args []driver.Value, + ctx context.Context, query string, args []interface{}, ) (vals []driver.Value, colNames []string, resErr error) { - rows, _, err := MakeQuery(query, args...)(c) + rows, _, err := MakeQuery(query, args...)(ctx, c) if err != nil { return nil, nil, err } diff --git a/pkg/cli/clisqlclient/conn_test.go b/pkg/cli/clisqlclient/conn_test.go index 6788d7197ce6..71eaadeeb5ed 100644 --- a/pkg/cli/clisqlclient/conn_test.go +++ b/pkg/cli/clisqlclient/conn_test.go @@ -11,6 +11,7 @@ package clisqlclient_test import ( + "context" "database/sql/driver" "io/ioutil" "net/url" @@ -36,6 +37,7 @@ func TestConnRecover(t *testing.T) { p := cli.TestCLIParams{T: t} c := cli.NewCLITest(p) defer c.Cleanup() + ctx := context.Background() url, cleanup := sqlutils.PGUrl(t, c.ServingSQLAddr(), t.Name(), url.User(security.RootUser)) defer cleanup() @@ -48,7 +50,7 @@ func TestConnRecover(t *testing.T) { }() // Sanity check to establish baseline. - rows, err := conn.Query(`SELECT 1`, nil) + rows, err := conn.Query(ctx, `SELECT 1`) if err != nil { t.Fatal(err) } @@ -64,7 +66,7 @@ func TestConnRecover(t *testing.T) { // and starts delivering ErrBadConn. We don't know the timing of // this however. testutils.SucceedsSoon(t, func() error { - if sqlRows, err := conn.Query(`SELECT 1`, nil); err == nil { + if sqlRows, err := conn.Query(ctx, `SELECT 1`); err == nil { if closeErr := sqlRows.Close(); closeErr != nil { t.Fatal(closeErr) } @@ -75,7 +77,7 @@ func TestConnRecover(t *testing.T) { }) // Check that Query recovers from a connection close by re-connecting. - rows, err = conn.Query(`SELECT 1`, nil) + rows, err = conn.Query(ctx, `SELECT 1`) if err != nil { t.Fatalf("conn.Query(): expected no error after reconnect, got %v", err) } @@ -88,14 +90,14 @@ func TestConnRecover(t *testing.T) { // Ditto from Query(). testutils.SucceedsSoon(t, func() error { - if err := conn.Exec(`SELECT 1`, nil); !errors.Is(err, driver.ErrBadConn) { + if err := conn.Exec(ctx, `SELECT 1`); !errors.Is(err, driver.ErrBadConn) { return errors.Newf("expected ErrBadConn, got %v", err) // nolint:errwrap } return nil }) // Check that Exec recovers from a connection close by re-connecting. - if err := conn.Exec(`SELECT 1`, nil); err != nil { + if err := conn.Exec(ctx, `SELECT 1`); err != nil { t.Fatalf("conn.Exec(): expected no error after reconnect, got %v", err) } } @@ -118,6 +120,7 @@ func TestTransactionRetry(t *testing.T) { p := cli.TestCLIParams{T: t} c := cli.NewCLITest(p) defer c.Cleanup() + ctx := context.Background() url, cleanup := sqlutils.PGUrl(t, c.ServingSQLAddr(), t.Name(), url.User(security.RootUser)) defer cleanup() @@ -130,14 +133,14 @@ func TestTransactionRetry(t *testing.T) { }() var tries int - err := conn.ExecTxn(func(conn clisqlclient.TxBoundConn) error { + err := conn.ExecTxn(ctx, func(ctx context.Context, conn clisqlclient.TxBoundConn) error { tries++ if tries > 2 { return nil } // Prevent automatic server-side retries. - rows, err := conn.Query(`SELECT now()`, nil) + rows, err := conn.Query(ctx, `SELECT now()`) if err != nil { return err } @@ -146,7 +149,7 @@ func TestTransactionRetry(t *testing.T) { } // Force a client-side retry. - rows, err = conn.Query(`SELECT crdb_internal.force_retry('1h')`, nil) + rows, err = conn.Query(ctx, `SELECT crdb_internal.force_retry('1h')`) if err != nil { return err } diff --git a/pkg/cli/clisqlclient/make_query.go b/pkg/cli/clisqlclient/make_query.go index 9b7435da0fd4..6fbbaf944ee1 100644 --- a/pkg/cli/clisqlclient/make_query.go +++ b/pkg/cli/clisqlclient/make_query.go @@ -11,6 +11,7 @@ package clisqlclient import ( + "context" "database/sql/driver" "strings" @@ -19,30 +20,36 @@ import ( ) // QueryFn is the type of functions produced by MakeQuery. -type QueryFn func(conn Conn) (rows Rows, isMultiStatementQuery bool, err error) +type QueryFn func(ctx context.Context, conn Conn) (rows Rows, isMultiStatementQuery bool, err error) // MakeQuery encapsulates a SQL query and its parameter into a // function that can be applied to a connection object. -func MakeQuery(query string, parameters ...driver.Value) QueryFn { - return func(conn Conn) (Rows, bool, error) { +func MakeQuery(query string, parameters ...interface{}) QueryFn { + return func(ctx context.Context, conn Conn) (Rows, bool, error) { isMultiStatementQuery, _ := scanner.HasMultipleStatements(query) - // driver.Value is an alias for interface{}, but must adhere to a restricted + rows, err := conn.Query(ctx, query, parameters...) + err = handleCopyError(conn.(*sqlConn), err) + return rows, isMultiStatementQuery, err + } +} + +func convertArgs(parameters []interface{}) ([]driver.NamedValue, error) { + dVals := make([]driver.NamedValue, len(parameters)) + for i := range parameters { + // driver.NamedValue.Value is an alias for interface{}, but must adhere to a restricted // set of types when being passed to driver.Queryer.Query (see // driver.IsValue). We use driver.DefaultParameterConverter to perform the // necessary conversion. This is usually taken care of by the sql package, // but we have to do so manually because we're talking directly to the // driver. - for i := range parameters { - var err error - parameters[i], err = driver.DefaultParameterConverter.ConvertValue(parameters[i]) - if err != nil { - return nil, isMultiStatementQuery, err - } + var err error + dVals[i].Ordinal = i + 1 + dVals[i].Value, err = driver.DefaultParameterConverter.ConvertValue(parameters[i]) + if err != nil { + return nil, err } - rows, err := conn.Query(query, parameters) - err = handleCopyError(conn.(*sqlConn), err) - return rows, isMultiStatementQuery, err } + return dVals, nil } // handleCopyError ensures the user is properly informed when they issue diff --git a/pkg/cli/clisqlclient/statement_diag.go b/pkg/cli/clisqlclient/statement_diag.go index 45516c1fbbe6..110cf718e304 100644 --- a/pkg/cli/clisqlclient/statement_diag.go +++ b/pkg/cli/clisqlclient/statement_diag.go @@ -11,6 +11,7 @@ package clisqlclient import ( + "context" "database/sql/driver" "fmt" "io" @@ -31,8 +32,8 @@ type StmtDiagBundleInfo struct { // StmtDiagListBundles retrieves information about all available statement // diagnostics bundles. -func StmtDiagListBundles(conn Conn) ([]StmtDiagBundleInfo, error) { - result, err := stmtDiagListBundlesInternal(conn) +func StmtDiagListBundles(ctx context.Context, conn Conn) ([]StmtDiagBundleInfo, error) { + result, err := stmtDiagListBundlesInternal(ctx, conn) if err != nil { return nil, errors.Wrap( err, "failed to retrieve statement diagnostics bundles", @@ -41,13 +42,12 @@ func StmtDiagListBundles(conn Conn) ([]StmtDiagBundleInfo, error) { return result, nil } -func stmtDiagListBundlesInternal(conn Conn) ([]StmtDiagBundleInfo, error) { - rows, err := conn.Query( +func stmtDiagListBundlesInternal(ctx context.Context, conn Conn) ([]StmtDiagBundleInfo, error) { + rows, err := conn.Query(ctx, `SELECT id, statement_fingerprint, collected_at FROM system.statement_diagnostics WHERE error IS NULL ORDER BY collected_at DESC`, - nil, /* args */ ) if err != nil { return nil, err @@ -88,8 +88,10 @@ type StmtDiagActivationRequest struct { // StmtDiagListOutstandingRequests retrieves outstanding statement diagnostics // activation requests. -func StmtDiagListOutstandingRequests(conn Conn) ([]StmtDiagActivationRequest, error) { - result, err := stmtDiagListOutstandingRequestsInternal(conn) +func StmtDiagListOutstandingRequests( + ctx context.Context, conn Conn, +) ([]StmtDiagActivationRequest, error) { + result, err := stmtDiagListOutstandingRequestsInternal(ctx, conn) if err != nil { return nil, errors.Wrap( err, "failed to retrieve outstanding statement diagnostics activation requests", @@ -99,16 +101,16 @@ func StmtDiagListOutstandingRequests(conn Conn) ([]StmtDiagActivationRequest, er } // TODO(yuzefovich): remove this in 22.2. -func isAtLeast22dot1ClusterVersion(conn Conn) (bool, error) { +func isAtLeast22dot1ClusterVersion(ctx context.Context, conn Conn) (bool, error) { // Check whether the migration to add the conditional diagnostics columns to // the statement_diagnostics_requests system table has already been run. - row, err := conn.QueryRow(` + row, err := conn.QueryRow(ctx, ` SELECT count(*) FROM [SHOW COLUMNS FROM system.statement_diagnostics_requests] WHERE - column_name = 'min_execution_latency';`, nil /* args */) + column_name = 'min_execution_latency';`) if err != nil { return false, err } @@ -119,9 +121,11 @@ WHERE return c == 1, nil } -func stmtDiagListOutstandingRequestsInternal(conn Conn) ([]StmtDiagActivationRequest, error) { +func stmtDiagListOutstandingRequestsInternal( + ctx context.Context, conn Conn, +) ([]StmtDiagActivationRequest, error) { var extraColumns string - atLeast22dot1, err := isAtLeast22dot1ClusterVersion(conn) + atLeast22dot1, err := isAtLeast22dot1ClusterVersion(ctx, conn) if err != nil { return nil, err } @@ -136,12 +140,11 @@ func stmtDiagListOutstandingRequestsInternal(conn Conn) ([]StmtDiagActivationReq EXTRACT(second FROM min_execution_latency)::INT8 * 1000` extraColumns = ", " + getMilliseconds + ", expires_at" } - rows, err := conn.Query( + rows, err := conn.Query(ctx, fmt.Sprintf(`SELECT id, statement_fingerprint, requested_at%s FROM system.statement_diagnostics_requests WHERE NOT completed ORDER BY requested_at DESC`, extraColumns), - nil, /* args */ ) if err != nil { return nil, err @@ -180,8 +183,8 @@ func stmtDiagListOutstandingRequestsInternal(conn Conn) ([]StmtDiagActivationReq } // StmtDiagDownloadBundle downloads the bundle with the given ID to a file. -func StmtDiagDownloadBundle(conn Conn, id int64, filename string) error { - if err := stmtDiagDownloadBundleInternal(conn, id, filename); err != nil { +func StmtDiagDownloadBundle(ctx context.Context, conn Conn, id int64, filename string) error { + if err := stmtDiagDownloadBundleInternal(ctx, conn, id, filename); err != nil { return errors.Wrapf( err, "failed to download statement diagnostics bundle %d to '%s'", id, filename, ) @@ -189,11 +192,13 @@ func StmtDiagDownloadBundle(conn Conn, id int64, filename string) error { return nil } -func stmtDiagDownloadBundleInternal(conn Conn, id int64, filename string) error { +func stmtDiagDownloadBundleInternal( + ctx context.Context, conn Conn, id int64, filename string, +) error { // Retrieve the chunk IDs; these are stored in an INT ARRAY column. - rows, err := conn.Query( + rows, err := conn.Query(ctx, "SELECT unnest(bundle_chunks) FROM system.statement_diagnostics WHERE id = $1", - []driver.Value{id}, + id, ) if err != nil { return err @@ -223,9 +228,9 @@ func stmtDiagDownloadBundleInternal(conn Conn, id int64, filename string) error } for _, chunkID := range chunkIDs { - data, err := conn.QueryRow( + data, err := conn.QueryRow(ctx, "SELECT data FROM system.statement_bundle_chunks WHERE id = $1", - []driver.Value{chunkID}, + chunkID, ) if err != nil { _ = out.Close() @@ -241,10 +246,10 @@ func stmtDiagDownloadBundleInternal(conn Conn, id int64, filename string) error } // StmtDiagDeleteBundle deletes a statement diagnostics bundle. -func StmtDiagDeleteBundle(conn Conn, id int64) error { - _, err := conn.QueryRow( +func StmtDiagDeleteBundle(ctx context.Context, conn Conn, id int64) error { + _, err := conn.QueryRow(ctx, "SELECT 1 FROM system.statement_diagnostics WHERE id = $1", - []driver.Value{id}, + id, ) if err != nil { if err == io.EOF { @@ -252,63 +257,60 @@ func StmtDiagDeleteBundle(conn Conn, id int64) error { } return err } - return conn.ExecTxn(func(conn TxBoundConn) error { + return conn.ExecTxn(ctx, func(ctx context.Context, conn TxBoundConn) error { // Delete the request metadata. - if err := conn.Exec( + if err := conn.Exec(ctx, "DELETE FROM system.statement_diagnostics_requests WHERE statement_diagnostics_id = $1", - []driver.Value{id}, + id, ); err != nil { return err } // Delete the bundle chunks. - if err := conn.Exec( + if err := conn.Exec(ctx, `DELETE FROM system.statement_bundle_chunks WHERE id IN ( SELECT unnest(bundle_chunks) FROM system.statement_diagnostics WHERE id = $1 )`, - []driver.Value{id}, + id, ); err != nil { return err } // Finally, delete the diagnostics entry. - return conn.Exec( + return conn.Exec(ctx, "DELETE FROM system.statement_diagnostics WHERE id = $1", - []driver.Value{id}, + id, ) }) } // StmtDiagDeleteAllBundles deletes all statement diagnostics bundles. -func StmtDiagDeleteAllBundles(conn Conn) error { - return conn.ExecTxn(func(conn TxBoundConn) error { +func StmtDiagDeleteAllBundles(ctx context.Context, conn Conn) error { + return conn.ExecTxn(ctx, func(ctx context.Context, conn TxBoundConn) error { // Delete the request metadata. - if err := conn.Exec( + if err := conn.Exec(ctx, "DELETE FROM system.statement_diagnostics_requests WHERE completed", - nil, ); err != nil { return err } // Delete all bundle chunks. - if err := conn.Exec( + if err := conn.Exec(ctx, `DELETE FROM system.statement_bundle_chunks WHERE true`, - nil, ); err != nil { return err } // Finally, delete the diagnostics entry. - return conn.Exec( + return conn.Exec(ctx, "DELETE FROM system.statement_diagnostics WHERE true", - nil, ) }) } // StmtDiagCancelOutstandingRequest deletes an outstanding statement diagnostics // activation request. -func StmtDiagCancelOutstandingRequest(conn Conn, id int64) error { - _, err := conn.QueryRow( +func StmtDiagCancelOutstandingRequest(ctx context.Context, conn Conn, id int64) error { + _, err := conn.QueryRow(ctx, "DELETE FROM system.statement_diagnostics_requests WHERE id = $1 RETURNING id", - []driver.Value{id}, + id, ) if err != nil { if err == io.EOF { @@ -321,9 +323,8 @@ func StmtDiagCancelOutstandingRequest(conn Conn, id int64) error { // StmtDiagCancelAllOutstandingRequests deletes all outstanding statement // diagnostics activation requests. -func StmtDiagCancelAllOutstandingRequests(conn Conn) error { - return conn.Exec( +func StmtDiagCancelAllOutstandingRequests(ctx context.Context, conn Conn) error { + return conn.Exec(ctx, "DELETE FROM system.statement_diagnostics_requests WHERE NOT completed", - nil, ) } diff --git a/pkg/cli/clisqlclient/txn_shim.go b/pkg/cli/clisqlclient/txn_shim.go index 6331db832e76..d8d5119af570 100644 --- a/pkg/cli/clisqlclient/txn_shim.go +++ b/pkg/cli/clisqlclient/txn_shim.go @@ -30,17 +30,17 @@ type sqlTxnShim struct { var _ crdb.Tx = sqlTxnShim{} -func (t sqlTxnShim) Commit(context.Context) error { - return t.conn.Exec(`COMMIT`, nil) +func (t sqlTxnShim) Commit(ctx context.Context) error { + return t.conn.Exec(ctx, `COMMIT`) } -func (t sqlTxnShim) Rollback(context.Context) error { - return t.conn.Exec(`ROLLBACK`, nil) +func (t sqlTxnShim) Rollback(ctx context.Context) error { + return t.conn.Exec(ctx, `ROLLBACK`) } -func (t sqlTxnShim) Exec(_ context.Context, query string, values ...interface{}) error { +func (t sqlTxnShim) Exec(ctx context.Context, query string, values ...interface{}) error { if len(values) != 0 { panic("sqlTxnShim.ExecContext must not be called with values") } - return t.conn.Exec(query, nil) + return t.conn.Exec(ctx, query) } diff --git a/pkg/cli/clisqlexec/run_query.go b/pkg/cli/clisqlexec/run_query.go index 6151f861eec6..9abe0f63dafd 100644 --- a/pkg/cli/clisqlexec/run_query.go +++ b/pkg/cli/clisqlexec/run_query.go @@ -11,6 +11,7 @@ package clisqlexec import ( + "context" "database/sql/driver" "fmt" "io" @@ -26,9 +27,9 @@ import ( // RunQuery takes a 'query' with optional 'parameters'. // It runs the sql query and returns a list of columns names and a list of rows. func (sqlExecCtx *Context) RunQuery( - conn clisqlclient.Conn, fn clisqlclient.QueryFn, showMoreChars bool, + ctx context.Context, conn clisqlclient.Conn, fn clisqlclient.QueryFn, showMoreChars bool, ) ([]string, [][]string, error) { - rows, _, err := fn(conn) + rows, _, err := fn(ctx, conn) if err != nil { return nil, nil, err } @@ -41,10 +42,10 @@ func (sqlExecCtx *Context) RunQuery( // It runs the sql query and writes output to 'w'. // Errors and warnings, if any, are printed to 'ew'. func (sqlExecCtx *Context) RunQueryAndFormatResults( - conn clisqlclient.Conn, w, ew io.Writer, fn clisqlclient.QueryFn, + ctx context.Context, conn clisqlclient.Conn, w, ew io.Writer, fn clisqlclient.QueryFn, ) (err error) { startTime := timeutil.Now() - rows, isMultiStatementQuery, err := fn(conn) + rows, isMultiStatementQuery, err := fn(ctx, conn) if err != nil { return err } @@ -135,7 +136,7 @@ func (sqlExecCtx *Context) RunQueryAndFormatResults( return err } - sqlExecCtx.maybeShowTimes(conn, w, ew, isMultiStatementQuery, startTime, queryCompleteTime) + sqlExecCtx.maybeShowTimes(ctx, conn, w, ew, isMultiStatementQuery, startTime, queryCompleteTime) if more, err := rows.NextResultSet(); err != nil { return err @@ -147,6 +148,7 @@ func (sqlExecCtx *Context) RunQueryAndFormatResults( // maybeShowTimes displays the execution time if show_times has been set. func (sqlExecCtx *Context) maybeShowTimes( + ctx context.Context, conn clisqlclient.Conn, w, ew io.Writer, isMultiStatementQuery bool, @@ -215,7 +217,7 @@ func (sqlExecCtx *Context) maybeShowTimes( } // If discrete server/network timings are available, also print them. - detailedStats, err := conn.GetLastQueryStatistics() + detailedStats, err := conn.GetLastQueryStatistics(ctx) if err != nil { fmt.Fprintln(w, stats.String()) fmt.Fprintf(ew, "\nwarning: %v", err) diff --git a/pkg/cli/clisqlexec/run_query_test.go b/pkg/cli/clisqlexec/run_query_test.go index 8c0218ee6be4..f53855959429 100644 --- a/pkg/cli/clisqlexec/run_query_test.go +++ b/pkg/cli/clisqlexec/run_query_test.go @@ -12,6 +12,7 @@ package clisqlexec_test import ( "bytes" + "context" "io" "io/ioutil" "net/url" @@ -38,7 +39,9 @@ func makeSQLConn(url string) clisqlclient.Conn { func runQueryAndFormatResults( conn clisqlclient.Conn, w io.Writer, fn clisqlclient.QueryFn, ) (err error) { - return testExecCtx.RunQueryAndFormatResults(conn, w, ioutil.Discard, fn) + return testExecCtx.RunQueryAndFormatResults( + context.Background(), + conn, w, ioutil.Discard, fn) } func TestRunQuery(t *testing.T) { @@ -74,7 +77,9 @@ SET b.Reset() // Use system database for sample query/output as they are fairly fixed. - cols, rows, err := testExecCtx.RunQuery(conn, clisqlclient.MakeQuery(`SHOW COLUMNS FROM system.namespace`), false) + cols, rows, err := testExecCtx.RunQuery( + context.Background(), + conn, clisqlclient.MakeQuery(`SHOW COLUMNS FROM system.namespace`), false) if err != nil { t.Fatal(err) } diff --git a/pkg/cli/clisqlshell/sql.go b/pkg/cli/clisqlshell/sql.go index 353a2c620734..43a2ebc9f950 100644 --- a/pkg/cli/clisqlshell/sql.go +++ b/pkg/cli/clisqlshell/sql.go @@ -834,7 +834,9 @@ func (c *cliState) doRefreshPrompts(nextState cliStateEnum) cliStateEnum { func (c *cliState) refreshTransactionStatus() { c.lastKnownTxnStatus = unknownTxnStatus - dbVal, dbColType, hasVal := c.conn.GetServerValue("transaction status", `SHOW TRANSACTION STATUS`) + dbVal, dbColType, hasVal := c.conn.GetServerValue( + context.Background(), + "transaction status", `SHOW TRANSACTION STATUS`) if !hasVal { return } @@ -867,7 +869,9 @@ func (c *cliState) refreshDatabaseName() string { return unknownDbName } - dbVal, dbColType, hasVal := c.conn.GetServerValue("database name", `SHOW DATABASE`) + dbVal, dbColType, hasVal := c.conn.GetServerValue( + context.Background(), + "database name", `SHOW DATABASE`) if !hasVal { return unknownDbName } @@ -1586,7 +1590,9 @@ func (c *cliState) doRunStatements(nextState cliStateEnum) cliStateEnum { if c.iCtx.autoTrace != "" { // Clear the trace by disabling tracing, then restart tracing // with the specified options. - c.exitErr = c.conn.Exec("SET tracing = off; SET tracing = "+c.iCtx.autoTrace, nil) + c.exitErr = c.conn.Exec( + context.Background(), + "SET tracing = off; SET tracing = "+c.iCtx.autoTrace) if c.exitErr != nil { if !c.singleStatement { clierror.OutputError(c.iCtx.stderr, c.exitErr, true /*showSeverity*/, false /*verbose*/) @@ -1599,7 +1605,9 @@ func (c *cliState) doRunStatements(nextState cliStateEnum) cliStateEnum { } // Now run the statement/query. - c.exitErr = c.sqlExecCtx.RunQueryAndFormatResults(c.conn, c.iCtx.stdout, c.iCtx.stderr, + c.exitErr = c.sqlExecCtx.RunQueryAndFormatResults( + context.Background(), + c.conn, c.iCtx.stdout, c.iCtx.stderr, clisqlclient.MakeQuery(c.concatLines)) if c.exitErr != nil { if !c.singleStatement { @@ -1611,7 +1619,8 @@ func (c *cliState) doRunStatements(nextState cliStateEnum) cliStateEnum { // this even if there was an error: a trace on errors is useful. if c.iCtx.autoTrace != "" { // First, disable tracing. - if err := c.conn.Exec("SET tracing = off", nil); err != nil { + if err := c.conn.Exec(context.Background(), + "SET tracing = off"); err != nil { // Print the error for the SET tracing statement. This will // appear below the error for the main query above, if any, clierror.OutputError(c.iCtx.stderr, err, true /*showSeverity*/, false /*verbose*/) @@ -1629,7 +1638,9 @@ func (c *cliState) doRunStatements(nextState cliStateEnum) cliStateEnum { if strings.Contains(c.iCtx.autoTrace, "kv") { traceType = "kv" } - if err := c.sqlExecCtx.RunQueryAndFormatResults(c.conn, c.iCtx.stdout, c.iCtx.stderr, + if err := c.sqlExecCtx.RunQueryAndFormatResults( + context.Background(), + c.conn, c.iCtx.stdout, c.iCtx.stderr, clisqlclient.MakeQuery(fmt.Sprintf("SHOW %s TRACE FOR SESSION", traceType))); err != nil { clierror.OutputError(c.iCtx.stderr, err, true /*showSeverity*/, false /*verbose*/) if c.exitErr == nil { @@ -1915,7 +1926,9 @@ func (c *cliState) runStatements(stmts []string) error { // decomposition in the first return value. If it is not, the function // extracts a help string if available. func (c *cliState) serverSideParse(sql string) (helpText string, err error) { - cols, rows, err := c.sqlExecCtx.RunQuery(c.conn, + cols, rows, err := c.sqlExecCtx.RunQuery( + context.Background(), + c.conn, clisqlclient.MakeQuery("SHOW SYNTAX "+lexbase.EscapeSQLString(sql)), true) if err != nil { diff --git a/pkg/cli/clisqlshell/statement_diag.go b/pkg/cli/clisqlshell/statement_diag.go index 949c56dfdae8..28157674e5ad 100644 --- a/pkg/cli/clisqlshell/statement_diag.go +++ b/pkg/cli/clisqlshell/statement_diag.go @@ -12,6 +12,7 @@ package clisqlshell import ( "bytes" + "context" "fmt" "strconv" "text/tabwriter" @@ -54,7 +55,8 @@ func (c *cliState) handleStatementDiag( } else { filename = fmt.Sprintf("stmt-bundle-%d.zip", id) } - cmdErr = clisqlclient.StmtDiagDownloadBundle(c.conn, id, filename) + cmdErr = clisqlclient.StmtDiagDownloadBundle( + context.Background(), c.conn, id, filename) if cmdErr == nil { fmt.Fprintf(c.iCtx.stdout, "Bundle saved to %q\n", filename) } @@ -75,7 +77,7 @@ func (c *cliState) statementDiagList() error { const timeFmt = "2006-01-02 15:04:05 MST" // -- List bundles -- - bundles, err := clisqlclient.StmtDiagListBundles(c.conn) + bundles, err := clisqlclient.StmtDiagListBundles(context.Background(), c.conn) if err != nil { return err } diff --git a/pkg/cli/debug_job_trace.go b/pkg/cli/debug_job_trace.go index 3f1370b57dab..25e90dffdaa2 100644 --- a/pkg/cli/debug_job_trace.go +++ b/pkg/cli/debug_job_trace.go @@ -51,7 +51,8 @@ func runDebugJobTrace(_ *cobra.Command, args []string) (resErr error) { func getJobTraceID(sqlConn clisqlclient.Conn, jobID int64) (int64, error) { var traceID int64 - rows, err := sqlConn.Query(`SELECT trace_id FROM crdb_internal.jobs WHERE job_id=$1`, []driver.Value{jobID}) + rows, err := sqlConn.Query(context.Background(), + `SELECT trace_id FROM crdb_internal.jobs WHERE job_id=$1`, jobID) if err != nil { return traceID, err } @@ -80,17 +81,10 @@ func getJobTraceID(sqlConn clisqlclient.Conn, jobID int64) (int64, error) { } func constructJobTraceZipBundle(ctx context.Context, sqlConn clisqlclient.Conn, jobID int64) error { - maybePrint := func(stmt string) string { - if debugCtx.verbose { - fmt.Println("querying " + stmt) - } - return stmt - } - // Check if a timeout has been set for this command. if cliCtx.cmdTimeout != 0 { - stmt := fmt.Sprintf(`SET statement_timeout = '%s'`, cliCtx.cmdTimeout) - if err := sqlConn.Exec(maybePrint(stmt), nil); err != nil { + if err := sqlConn.Exec(context.Background(), + `SET statement_timeout = $1`, cliCtx.cmdTimeout.String()); err != nil { return err } } diff --git a/pkg/cli/democluster/demo_cluster_test.go b/pkg/cli/democluster/demo_cluster_test.go index f59438e21c8a..364e42adfca1 100644 --- a/pkg/cli/democluster/demo_cluster_test.go +++ b/pkg/cli/democluster/demo_cluster_test.go @@ -237,6 +237,7 @@ func TestTransientClusterSimulateLatencies(t *testing.T) { startTime := timeutil.Now() sqlExecCtx := clisqlexec.Context{} _, _, err = sqlExecCtx.RunQuery( + context.Background(), conn, clisqlclient.MakeQuery(`SHOW ALL CLUSTER QUERIES`), false, @@ -326,6 +327,6 @@ func TestTransientClusterMultitenant(t *testing.T) { }() // Create a table on each tenant to make sure that the tenants are separate. - require.NoError(t, conn.Exec("CREATE TABLE a (a int PRIMARY KEY)", nil)) + require.NoError(t, conn.Exec(context.Background(), "CREATE TABLE a (a int PRIMARY KEY)")) } } diff --git a/pkg/cli/doctor.go b/pkg/cli/doctor.go index 3247518a8e6d..fbe994862835 100644 --- a/pkg/cli/doctor.go +++ b/pkg/cli/doctor.go @@ -23,7 +23,7 @@ import ( "strings" "time" - "github.com/cockroachdb/apd/v3" + apd "github.com/cockroachdb/apd/v3" "github.com/cockroachdb/cockroach/pkg/cli/clierror" "github.com/cockroachdb/cockroach/pkg/cli/clierrorplus" "github.com/cockroachdb/cockroach/pkg/cli/clisqlclient" @@ -198,15 +198,10 @@ func fromCluster( jobsTable doctor.JobsTable, retErr error, ) { - maybePrint := func(stmt string) string { - if debugCtx.verbose { - fmt.Println("querying " + stmt) - } - return stmt - } + ctx := context.Background() if timeout != 0 { - stmt := fmt.Sprintf(`SET statement_timeout = '%s'`, timeout) - if err := sqlConn.Exec(maybePrint(stmt), nil); err != nil { + if err := sqlConn.Exec(ctx, + `SET statement_timeout = $1`, timeout.String()); err != nil { return nil, nil, nil, err } } @@ -214,7 +209,7 @@ func fromCluster( SELECT id, descriptor, crdb_internal_mvcc_timestamp AS mod_time_logical FROM system.descriptor ORDER BY id` checkColumnExistsStmt := "SELECT crdb_internal_mvcc_timestamp FROM system.descriptor LIMIT 1" - _, err := sqlConn.QueryRow(maybePrint(checkColumnExistsStmt), nil) + _, err := sqlConn.QueryRow(ctx, checkColumnExistsStmt) // On versions before 20.2, the system.descriptor won't have the builtin // crdb_internal_mvcc_timestamp. If we can't find it, use NULL instead. if pqErr := (*pq.Error)(nil); errors.As(err, &pqErr) { @@ -228,7 +223,7 @@ FROM system.descriptor ORDER BY id` } descTable = make([]doctor.DescriptorTableRow, 0) - if err := selectRowsMap(sqlConn, maybePrint(stmt), make([]driver.Value, 3), func(vals []driver.Value) error { + if err := selectRowsMap(sqlConn, stmt, make([]driver.Value, 3), func(vals []driver.Value) error { var row doctor.DescriptorTableRow if id, ok := vals[0].(int64); ok { row.ID = id @@ -264,7 +259,7 @@ FROM system.descriptor ORDER BY id` stmt = `SELECT "parentID", "parentSchemaID", name, id FROM system.namespace` checkColumnExistsStmt = `SELECT "parentSchemaID" FROM system.namespace LIMIT 1` - _, err = sqlConn.QueryRow(maybePrint(checkColumnExistsStmt), nil) + _, err = sqlConn.QueryRow(ctx, checkColumnExistsStmt) // On versions before 20.1, table system.namespace does not have this column. // In that case the ParentSchemaID for tables is 29 and for databases is 0. if pqErr := (*pq.Error)(nil); errors.As(err, &pqErr) { @@ -278,7 +273,7 @@ FROM system.namespace` } namespaceTable = make([]doctor.NamespaceTableRow, 0) - if err := selectRowsMap(sqlConn, maybePrint(stmt), make([]driver.Value, 4), func(vals []driver.Value) error { + if err := selectRowsMap(sqlConn, stmt, make([]driver.Value, 4), func(vals []driver.Value) error { var row doctor.NamespaceTableRow if parentID, ok := vals[0].(int64); ok { row.ParentID = descpb.ID(parentID) @@ -309,7 +304,7 @@ FROM system.namespace` stmt = `SELECT id, status, payload, progress FROM system.jobs` jobsTable = make(doctor.JobsTable, 0) - if err := selectRowsMap(sqlConn, maybePrint(stmt), make([]driver.Value, 4), func(vals []driver.Value) error { + if err := selectRowsMap(sqlConn, stmt, make([]driver.Value, 4), func(vals []driver.Value) error { md := jobs.JobMetadata{} md.ID = jobspb.JobID(vals[0].(int64)) md.Status = jobs.Status(vals[1].(string)) @@ -494,7 +489,7 @@ func tableMap(in io.Reader, fn func(string) error) error { func selectRowsMap( conn clisqlclient.Conn, stmt string, vals []driver.Value, fn func([]driver.Value) error, ) error { - rows, err := conn.Query(stmt, nil) + rows, err := conn.Query(context.Background(), stmt) if err != nil { return errors.Wrapf(err, "query '%s'", stmt) } diff --git a/pkg/cli/node.go b/pkg/cli/node.go index bc194ccc9899..0b8f72f62e51 100644 --- a/pkg/cli/node.go +++ b/pkg/cli/node.go @@ -60,13 +60,18 @@ func runLsNodes(cmd *cobra.Command, args []string) (resErr error) { } defer func() { resErr = errors.CombineErrors(resErr, conn.Close()) }() + ctx := context.Background() + + // TODO(knz): This can use a context deadline instead, now that + // query cancellation is supported. if cliCtx.cmdTimeout != 0 { - if err := conn.Exec(fmt.Sprintf("SET statement_timeout=%d", cliCtx.cmdTimeout), nil); err != nil { + if err := conn.Exec(ctx, + "SET statement_timeout = $1", cliCtx.cmdTimeout.String()); err != nil { return err } } - _, rows, err := sqlExecCtx.RunQuery( + _, rows, err := sqlExecCtx.RunQuery(ctx, conn, clisqlclient.MakeQuery(`SELECT node_id FROM crdb_internal.gossip_liveness WHERE membership = 'active' OR split_part(expiration,',',1)::decimal > now()::decimal`), @@ -221,8 +226,13 @@ FROM crdb_internal.gossip_liveness LEFT JOIN crdb_internal.gossip_nodes USING (n queriesToJoin = append(queriesToJoin, decommissionQuery) } + ctx := context.Background() + + // TODO(knz): This can use a context deadline instead, now that + // query cancellation is supported. if cliCtx.cmdTimeout != 0 { - if err := conn.Exec(fmt.Sprintf("SET statement_timeout=%d", cliCtx.cmdTimeout), nil); err != nil { + if err := conn.Exec(ctx, + "SET statement_timeout = $1", cliCtx.cmdTimeout.String()); err != nil { return nil, nil, err } } @@ -232,14 +242,14 @@ FROM crdb_internal.gossip_liveness LEFT JOIN crdb_internal.gossip_nodes USING (n switch len(args) { case 0: query := clisqlclient.MakeQuery(queryString + " ORDER BY id") - return sqlExecCtx.RunQuery(conn, query, false) + return sqlExecCtx.RunQuery(ctx, conn, query, false) case 1: nodeID, err := strconv.Atoi(args[0]) if err != nil { return nil, nil, errors.Errorf("could not parse node_id %s", args[0]) } query := clisqlclient.MakeQuery(queryString+" WHERE id = $1", nodeID) - headers, rows, err := sqlExecCtx.RunQuery(conn, query, false) + headers, rows, err := sqlExecCtx.RunQuery(ctx, conn, query, false) if err != nil { return nil, nil, err } diff --git a/pkg/cli/nodelocal.go b/pkg/cli/nodelocal.go index 33fa7a6e6f2c..141aaf83dfe4 100644 --- a/pkg/cli/nodelocal.go +++ b/pkg/cli/nodelocal.go @@ -131,7 +131,7 @@ func uploadFile( return err } - nodeID, _, _, err := conn.GetServerMetadata() + nodeID, _, _, err := conn.GetServerMetadata(ctx) if err != nil { return errors.Wrap(err, "unable to get node id") } diff --git a/pkg/cli/statement_bundle.go b/pkg/cli/statement_bundle.go index cec6ce526538..9a9e2c3b8681 100644 --- a/pkg/cli/statement_bundle.go +++ b/pkg/cli/statement_bundle.go @@ -162,13 +162,14 @@ func runBundleRecreate(cmd *cobra.Command, args []string) error { return err } // Disable autostats collection, which will override the injected stats. - if err := conn.Exec(`SET CLUSTER SETTING sql.stats.automatic_collection.enabled = false`, nil); err != nil { + if err := conn.Exec(ctx, + `SET CLUSTER SETTING sql.stats.automatic_collection.enabled = false`); err != nil { return err } var initStmts = [][]byte{bundle.env, bundle.schema} initStmts = append(initStmts, bundle.stats...) for _, a := range initStmts { - if err := conn.Exec(string(a), nil); err != nil { + if err := conn.Exec(ctx, string(a)); err != nil { return errors.Wrapf(err, "failed to run %s", a) } } @@ -416,13 +417,12 @@ func getExplainOutputs( ) (explainStrings []string, err error) { for _, values := range inputs { // Run an explain for each possible input. - dvals := make([]driver.Value, len(values)) - for i := range values { - dvals[i] = values[i] - } - query := fmt.Sprintf("%s %s", explainPrefix, statement) - rows, err := conn.Query(query, dvals) + args := make([]interface{}, len(values)) + for i, s := range values { + args[i] = s + } + rows, err := conn.Query(context.Background(), query, args...) if err != nil { return nil, err } diff --git a/pkg/cli/statement_bundle_test.go b/pkg/cli/statement_bundle_test.go index 794b3ea39051..350dabd58117 100644 --- a/pkg/cli/statement_bundle_test.go +++ b/pkg/cli/statement_bundle_test.go @@ -63,18 +63,21 @@ func TestRunExplainCombinations(t *testing.T) { c.LoadDefaults(os.Stdout, os.Stderr) pgURL, cleanupFn := sqlutils.PGUrl(t, tc.Server(0).ServingSQLAddr(), t.Name(), url.User(security.RootUser)) defer cleanupFn() + + ctx := context.Background() + conn := c.ConnCtx.MakeSQLConn(os.Stdout, os.Stdout, pgURL.String()) for _, test := range tests { bundle, err := loadStatementBundle(testutils.TestDataPath(t, "explain-bundle", test.bundlePath)) assert.NoError(t, err) // Disable autostats collection, which will override the injected stats. - if err := conn.Exec(`SET CLUSTER SETTING sql.stats.automatic_collection.enabled = false`, nil); err != nil { + if err := conn.Exec(ctx, `SET CLUSTER SETTING sql.stats.automatic_collection.enabled = false`); err != nil { t.Fatal(err) } var initStmts = [][]byte{bundle.env, bundle.schema} initStmts = append(initStmts, bundle.stats...) for _, a := range initStmts { - if err := conn.Exec(string(a), nil); err != nil { + if err := conn.Exec(ctx, string(a)); err != nil { t.Fatal(err) } } diff --git a/pkg/cli/statement_diag.go b/pkg/cli/statement_diag.go index ab53bb6f86b0..a55586b9c071 100644 --- a/pkg/cli/statement_diag.go +++ b/pkg/cli/statement_diag.go @@ -12,6 +12,7 @@ package cli import ( "bytes" + "context" "fmt" "strconv" "text/tabwriter" @@ -49,8 +50,10 @@ func runStmtDiagList(cmd *cobra.Command, args []string) (resErr error) { } defer func() { resErr = errors.CombineErrors(resErr, conn.Close()) }() + ctx := context.Background() + // -- List bundles -- - bundles, err := clisqlclient.StmtDiagListBundles(conn) + bundles, err := clisqlclient.StmtDiagListBundles(ctx, conn) if err != nil { return err } @@ -71,7 +74,7 @@ func runStmtDiagList(cmd *cobra.Command, args []string) (resErr error) { } // -- List outstanding activation requests -- - reqs, err := clisqlclient.StmtDiagListOutstandingRequests(conn) + reqs, err := clisqlclient.StmtDiagListOutstandingRequests(ctx, conn) if err != nil { return err } @@ -131,7 +134,8 @@ func runStmtDiagDownload(cmd *cobra.Command, args []string) (resErr error) { } defer func() { resErr = errors.CombineErrors(resErr, conn.Close()) }() - if err := clisqlclient.StmtDiagDownloadBundle(conn, id, filename); err != nil { + if err := clisqlclient.StmtDiagDownloadBundle( + context.Background(), conn, id, filename); err != nil { return err } fmt.Printf("Bundle saved to %q\n", filename) @@ -154,11 +158,13 @@ func runStmtDiagDelete(cmd *cobra.Command, args []string) (resErr error) { } defer func() { resErr = errors.CombineErrors(resErr, conn.Close()) }() + ctx := context.Background() + if stmtDiagCtx.all { if len(args) > 0 { return errors.New("extra arguments with --all") } - return clisqlclient.StmtDiagDeleteAllBundles(conn) + return clisqlclient.StmtDiagDeleteAllBundles(ctx, conn) } if len(args) != 1 { return fmt.Errorf("accepts 1 arg, received %d", len(args)) @@ -169,7 +175,7 @@ func runStmtDiagDelete(cmd *cobra.Command, args []string) (resErr error) { return errors.New("invalid ID") } - return clisqlclient.StmtDiagDeleteBundle(conn, id) + return clisqlclient.StmtDiagDeleteBundle(ctx, conn, id) } var stmtDiagCancelCmd = &cobra.Command{ @@ -188,11 +194,13 @@ func runStmtDiagCancel(cmd *cobra.Command, args []string) (resErr error) { } defer func() { resErr = errors.CombineErrors(resErr, conn.Close()) }() + ctx := context.Background() + if stmtDiagCtx.all { if len(args) > 0 { return errors.New("extra arguments with --all") } - return clisqlclient.StmtDiagCancelAllOutstandingRequests(conn) + return clisqlclient.StmtDiagCancelAllOutstandingRequests(ctx, conn) } if len(args) != 1 { return fmt.Errorf("accepts 1 arg, received %d", len(args)) @@ -203,7 +211,7 @@ func runStmtDiagCancel(cmd *cobra.Command, args []string) (resErr error) { return errors.New("invalid ID") } - return clisqlclient.StmtDiagCancelOutstandingRequest(conn, id) + return clisqlclient.StmtDiagCancelOutstandingRequest(ctx, conn, id) } var stmtDiagCmds = []*cobra.Command{ diff --git a/pkg/cli/userfiletable_test.go b/pkg/cli/userfiletable_test.go index b52a9b727768..941ff07f1336 100644 --- a/pkg/cli/userfiletable_test.go +++ b/pkg/cli/userfiletable_test.go @@ -815,11 +815,11 @@ func TestUsernameUserfileInteraction(t *testing.T) { }, } { createUserQuery := fmt.Sprintf(`CREATE USER "%s" WITH PASSWORD 'a'`, tc.username) - err = conn.Exec(createUserQuery, nil) + err = conn.Exec(ctx, createUserQuery) require.NoError(t, err) privsUserQuery := fmt.Sprintf(`GRANT CREATE ON DATABASE defaultdb TO "%s"`, tc.username) - err = conn.Exec(privsUserQuery, nil) + err = conn.Exec(ctx, privsUserQuery) require.NoError(t, err) userURL, cleanup2 := sqlutils.PGUrlWithOptionalClientCerts(t, c.ServingSQLAddr(), t.Name(), diff --git a/pkg/cli/zip.go b/pkg/cli/zip.go index 4dc08d1fc1dd..6df2eb958e8b 100644 --- a/pkg/cli/zip.go +++ b/pkg/cli/zip.go @@ -300,6 +300,8 @@ func maybeAddProfileSuffix(name string) string { func (zc *debugZipContext) dumpTableDataForZip( zr *zipReporter, conn clisqlclient.Conn, base, table, query string, ) error { + // TODO(knz): This can use context cancellation now that query + // cancellation is supported. fullQuery := fmt.Sprintf(`SET statement_timeout = '%s'; %s`, zc.timeout, query) baseName := base + "/" + sanitizeFilename(table) @@ -319,7 +321,9 @@ func (zc *debugZipContext) dumpTableDataForZip( } // Pump the SQL rows directly into the zip writer, to avoid // in-RAM buffering. - return sqlExecCtx.RunQueryAndFormatResults(conn, w, stderr, clisqlclient.MakeQuery(fullQuery)) + return sqlExecCtx.RunQueryAndFormatResults( + context.Background(), + conn, w, stderr, clisqlclient.MakeQuery(fullQuery)) }() if sqlErr != nil { if cErr := zc.z.createError(s, name, sqlErr); cErr != nil { From 7064116c335b9de62ca3cc3a790d02d339a000f4 Mon Sep 17 00:00:00 2001 From: Raphael 'kena' Poss Date: Fri, 11 Feb 2022 14:29:11 +0100 Subject: [PATCH 2/2] cli,sql: do not exit interactive shells with Ctrl+C This change is a preface to making the interactive shell support Ctrl+C to cancel a currently executing query. When that happens, we need to ensure that Ctrl+C *only* cancels the current query, and not the entire shell. This is because there is no deterministic boundary in time for a user to decide whether the key press will cancel only the query, or stop the shell. If they care about keeping their shell alive, they would never "dare" to press Ctrl+C while a query is executing. Incidentally, this is also what `psql` does, for pretty much the same reason. Note that the new behavior is restricted to *interactive* shells, when stdin is a terminal. The behavior for non-interactive shells remains unchanged, where Ctrl+C will terminate everything. This is also what other SQL shells do. Release note (cli change): `cockroach sql` (and `demo`) now continue to accept user input when Ctrl+C is pressed at the interactive prompt and the current input line is empty. Previously, it would terminate the shell. To terminate the shell, the client-side command `\q` is still supported. The user can also terminate the input altogether via EOF (Ctrl+D). The behavior in non-interactive use remains unchanged. --- pkg/cli/clisqlshell/sql.go | 8 +++++--- pkg/cli/interactive_tests/test_audit_log.tcl | 4 ++-- .../test_client_side_checking.tcl | 2 +- .../test_contextual_help.tcl | 2 +- pkg/cli/interactive_tests/test_demo.tcl | 18 ++++++++--------- .../interactive_tests/test_demo_global.tcl | 2 +- .../test_demo_multitenant.tcl | 2 +- .../interactive_tests/test_demo_node_cmds.tcl | 20 ++++++++++++++++++- .../test_demo_partitioning.tcl | 6 +++--- .../interactive_tests/test_demo_telemetry.tcl | 2 +- .../interactive_tests/test_demo_workload.tcl | 6 ++++-- pkg/cli/interactive_tests/test_dump_sig.tcl | 2 +- .../test_explain_analyze_debug.tcl | 4 ++-- pkg/cli/interactive_tests/test_flags.tcl | 6 +++--- pkg/cli/interactive_tests/test_history.tcl | 4 ++-- .../interactive_tests/test_init_command.tcl | 2 +- pkg/cli/interactive_tests/test_local_cmds.tcl | 8 ++++---- .../test_multiline_statements.tcl | 3 ++- pkg/cli/interactive_tests/test_notice.tcl | 2 +- pkg/cli/interactive_tests/test_secure.tcl | 9 ++++----- .../test_sql_demo_node_cmds.tcl | 2 +- .../test_sql_mem_monitor.tcl | 2 +- .../test_sql_version_reporting.tcl | 4 ++-- .../interactive_tests/test_style_enabled.tcl | 4 ++-- 24 files changed, 73 insertions(+), 51 deletions(-) diff --git a/pkg/cli/clisqlshell/sql.go b/pkg/cli/clisqlshell/sql.go index 43a2ebc9f950..d6bbd8385c71 100644 --- a/pkg/cli/clisqlshell/sql.go +++ b/pkg/cli/clisqlshell/sql.go @@ -1021,9 +1021,11 @@ func (c *cliState) doReadLine(nextState cliStateEnum) cliStateEnum { return cliStartLine } - // Otherwise, also terminate with an interrupt error. - c.exitErr = err - return cliStop + // If a human is looking, tell them that quitting is done in another way. + if c.sqlExecCtx.TerminalOutput { + fmt.Fprintf(c.iCtx.stdout, "^C\nUse \\q or terminate input to exit.\n") + } + return cliStartLine case errors.Is(err, io.EOF): c.atEOF = true diff --git a/pkg/cli/interactive_tests/test_audit_log.tcl b/pkg/cli/interactive_tests/test_audit_log.tcl index a6da43a27435..48e649ce3226 100644 --- a/pkg/cli/interactive_tests/test_audit_log.tcl +++ b/pkg/cli/interactive_tests/test_audit_log.tcl @@ -58,7 +58,7 @@ eexpect root@ system "grep -q 'sensitive_table_access.*ALTER TABLE.*helloworld.*SET OFF.*AccessMode\":\"rw\"' $logfile" end_test -interrupt +send_eof eexpect eof stop_server $argv @@ -81,7 +81,7 @@ eexpect "ALTER TABLE" eexpect root@ send "select x from d.helloworld;\r" eexpect root@ -interrupt +send_eof eexpect eof # Check the file was created and populated properly. diff --git a/pkg/cli/interactive_tests/test_client_side_checking.tcl b/pkg/cli/interactive_tests/test_client_side_checking.tcl index 3f1e58b0483a..a46ebcc2f458 100644 --- a/pkg/cli/interactive_tests/test_client_side_checking.tcl +++ b/pkg/cli/interactive_tests/test_client_side_checking.tcl @@ -52,7 +52,7 @@ send "commit;\r" eexpect "ROLLBACK" eexpect root@ -interrupt +send_eof eexpect eof end_test diff --git a/pkg/cli/interactive_tests/test_contextual_help.tcl b/pkg/cli/interactive_tests/test_contextual_help.tcl index 9e89db78d03e..49aa5d42b744 100644 --- a/pkg/cli/interactive_tests/test_contextual_help.tcl +++ b/pkg/cli/interactive_tests/test_contextual_help.tcl @@ -128,7 +128,7 @@ eexpect root@ end_test # Finally terminate with Ctrl+C. -interrupt +send_eof eexpect eof start_test "Check that the hint for a single ? is also printed in non-interactive sessions." diff --git a/pkg/cli/interactive_tests/test_demo.tcl b/pkg/cli/interactive_tests/test_demo.tcl index 7319569bafc7..26faff34a0f8 100644 --- a/pkg/cli/interactive_tests/test_demo.tcl +++ b/pkg/cli/interactive_tests/test_demo.tcl @@ -35,7 +35,7 @@ eexpect "brief introduction" eexpect root@ # Ensure db is movr. eexpect "movr>" -interrupt +send_eof eexpect eof end_test @@ -68,7 +68,7 @@ eexpect ":26257" eexpect "sslmode=disable" eexpect "defaultdb>" -interrupt +send_eof eexpect eof # With command-line override. @@ -88,7 +88,7 @@ eexpect "(sql/unix)" eexpect "root:unused@" eexpect "defaultdb>" -interrupt +send_eof eexpect eof end_test @@ -125,7 +125,7 @@ eexpect ":26257" eexpect "sslmode=require" eexpect "defaultdb>" -interrupt +send_eof eexpect eof # With command-line override. @@ -146,7 +146,7 @@ eexpect "(sql/unix)" eexpect "demo:" eexpect "defaultdb>" -interrupt +send_eof eexpect eof end_test @@ -213,7 +213,7 @@ eexpect "http://" eexpect ":8005" eexpect "defaultdb>" -interrupt +send_eof eexpect eof spawn $argv demo --no-example-database --nodes 3 --sql-port 23000 @@ -249,7 +249,7 @@ eexpect "(sql)" eexpect ":23002" eexpect "defaultdb>" -interrupt +send_eof eexpect eof @@ -264,7 +264,7 @@ eexpect "defaultdb>" # Check the URL is valid. If the connection fails, the system command will fail too. system "$argv sql --url `cat test.url` -e 'select 1'" -interrupt +send_eof eexpect eof # Ditto, insecure @@ -275,7 +275,7 @@ eexpect "defaultdb>" # Check the URL is valid. If the connection fails, the system command will fail too. system "$argv sql --url `cat test.url` -e 'select 1'" -interrupt +send_eof eexpect eof diff --git a/pkg/cli/interactive_tests/test_demo_global.tcl b/pkg/cli/interactive_tests/test_demo_global.tcl index 958431e703d2..18050d08ea80 100644 --- a/pkg/cli/interactive_tests/test_demo_global.tcl +++ b/pkg/cli/interactive_tests/test_demo_global.tcl @@ -28,6 +28,6 @@ send "\\demo shutdown 3\r" eexpect "shutting down nodes is not supported in --global configurations" eexpect "defaultdb>" -interrupt +send_eof eexpect eof end_test diff --git a/pkg/cli/interactive_tests/test_demo_multitenant.tcl b/pkg/cli/interactive_tests/test_demo_multitenant.tcl index 1bfd333aebcf..ae8b5485003a 100644 --- a/pkg/cli/interactive_tests/test_demo_multitenant.tcl +++ b/pkg/cli/interactive_tests/test_demo_multitenant.tcl @@ -25,6 +25,6 @@ send "SELECT gateway_region();\n" eexpect "us-east1" eexpect "defaultdb>" -interrupt +send_eof eexpect eof end_test diff --git a/pkg/cli/interactive_tests/test_demo_node_cmds.tcl b/pkg/cli/interactive_tests/test_demo_node_cmds.tcl index d9aab7d102e6..37e58c03d6e1 100644 --- a/pkg/cli/interactive_tests/test_demo_node_cmds.tcl +++ b/pkg/cli/interactive_tests/test_demo_node_cmds.tcl @@ -13,10 +13,12 @@ eexpect "movr>" # Wrong number of args send "\\demo node\r" eexpect "invalid syntax: \\\\demo node. Try \\\\? for help." +eexpect "movr>" # Cannot shutdown node 1 send "\\demo shutdown 1\r" eexpect "cannot shutdown node 1" +eexpect "movr>" # Cannot operate on a node which does not exist. send "\\demo shutdown 8\r" @@ -27,14 +29,17 @@ send "\\demo decommission 8\r" eexpect "node 8 does not exist" send "\\demo recommission 8\r" eexpect "node 8 does not exist" +eexpect "movr>" # Cannot restart a node that is not shut down. send "\\demo restart 2\r" eexpect "node 2 is already running" +eexpect "movr>" # Shut down a separate node. send "\\demo shutdown 3\r" eexpect "node 3 has been shutdown" +eexpect "movr>" send "select node_id, draining, decommissioning, membership from crdb_internal.gossip_liveness ORDER BY node_id;\r" eexpect "1 | false | false | active" @@ -42,10 +47,12 @@ eexpect "2 | false | false | active" eexpect "3 | true | false | active" eexpect "4 | false | false | active" eexpect "5 | false | false | active" +eexpect "movr>" # Cannot shut it down again. send "\\demo shutdown 3\r" eexpect "node 3 is already shut down" +eexpect "movr>" # Expect queries to still work with just one node down. send "SELECT count(*) FROM movr.rides;\r" @@ -55,6 +62,7 @@ eexpect "movr>" # Now restart the node. send "\\demo restart 3\r" eexpect "node 3 has been restarted" +eexpect "movr>" send "select node_id, draining, decommissioning, membership from crdb_internal.gossip_liveness ORDER BY node_id;\r" eexpect "1 | false | false | active" @@ -62,10 +70,12 @@ eexpect "2 | false | false | active" eexpect "3 | false | false | active" eexpect "4 | false | false | active" eexpect "5 | false | false | active" +eexpect "movr>" # Try commissioning commands send "\\demo decommission 4\r" eexpect "node 4 has been decommissioned" +eexpect "movr>" send "select node_id, draining, decommissioning, membership from crdb_internal.gossip_liveness ORDER BY node_id;\r" eexpect "1 | false | false | active" @@ -73,20 +83,25 @@ eexpect "2 | false | false | active" eexpect "3 | false | false | active" eexpect "4 | false | true | decommissioned" eexpect "5 | false | false | active" +eexpect "movr>" send "\\demo recommission 4\r" eexpect "can only recommission a decommissioning node" +eexpect "movr>" send "\\demo add blah\r" eexpect "internal server error: tier must be in the form \"key=value\" not \"blah\"" +eexpect "movr>" send "\\demo add region=ca-central,zone=a\r" eexpect "node 6 has been added with locality \"region=ca-central,zone=a\"" +eexpect "movr>" send "show regions from cluster;\r" eexpect "ca-central | \{a\}" eexpect "us-east1 | \{b,c,d\}" eexpect "us-west1 | \{b\}" +eexpect "movr>" # We use kv_node_status here because gossip_liveness is timing dependant. # Node 4's status entry should have been removed by now. @@ -96,10 +111,12 @@ eexpect "2 | region=us-east1,az=c" eexpect "3 | region=us-east1,az=d" eexpect "5 | region=us-west1,az=b" eexpect "6 | region=ca-central,zone=a" +eexpect "movr>" # Shut down the newly created node. send "\\demo shutdown 6\r" eexpect "node 6 has been shutdown" +eexpect "movr>" # By now the node should have stabilized in gossip which allows us to query the more detailed information there. send "select node_id, draining, decommissioning, membership from crdb_internal.gossip_liveness ORDER BY node_id;\r" @@ -109,7 +126,8 @@ eexpect "3 | false | false | active" eexpect "4 | false | true | decommissioned" eexpect "5 | false | false | active" eexpect "6 | true | false | active" +eexpect "movr>" -interrupt +send_eof eexpect eof end_test diff --git a/pkg/cli/interactive_tests/test_demo_partitioning.tcl b/pkg/cli/interactive_tests/test_demo_partitioning.tcl index 5344826a9f35..e4f52c5b70ba 100644 --- a/pkg/cli/interactive_tests/test_demo_partitioning.tcl +++ b/pkg/cli/interactive_tests/test_demo_partitioning.tcl @@ -95,7 +95,7 @@ eexpect "8" eexpect "(1 row)" eexpect "movr>" -interrupt +send_eof eexpect $prompt end_test @@ -119,7 +119,7 @@ eexpect " survival_goal" eexpect "zone" eexpect "movr>" -interrupt +send_eof eexpect $prompt send "$argv demo movr --geo-partitioned-replicas --multi-region --survive=region\r" @@ -131,7 +131,7 @@ eexpect " survival_goal" eexpect "region" eexpect "movr>" -interrupt +send_eof eexpect $prompt end_test diff --git a/pkg/cli/interactive_tests/test_demo_telemetry.tcl b/pkg/cli/interactive_tests/test_demo_telemetry.tcl index 43ad6a3a8689..a35152610e2e 100644 --- a/pkg/cli/interactive_tests/test_demo_telemetry.tcl +++ b/pkg/cli/interactive_tests/test_demo_telemetry.tcl @@ -18,7 +18,7 @@ send "alter table vehicles partition by list (city) (partition p1 values in ('ny # expect that it failed, as no license was requested. eexpect "use of partitions requires an enterprise license" # clean up after the test -interrupt +send_eof eexpect eof end_test diff --git a/pkg/cli/interactive_tests/test_demo_workload.tcl b/pkg/cli/interactive_tests/test_demo_workload.tcl index 667fc104c588..3e2a62fd9041 100644 --- a/pkg/cli/interactive_tests/test_demo_workload.tcl +++ b/pkg/cli/interactive_tests/test_demo_workload.tcl @@ -30,8 +30,9 @@ if {!$workloadRunning} { report "Workload is not running" exit 1 } +eexpect "movr>" -interrupt +send_eof eexpect eof end_test @@ -48,7 +49,8 @@ eexpect "movr>" send "SELECT count(*) FROM \[SHOW RANGES FROM TABLE USERS\];\r" eexpect "6" eexpect "(1 row)" +eexpect "movr>" -interrupt +send_eof eexpect eof end_test diff --git a/pkg/cli/interactive_tests/test_dump_sig.tcl b/pkg/cli/interactive_tests/test_dump_sig.tcl index 0cb854d1cfa1..a77b065e3e8c 100644 --- a/pkg/cli/interactive_tests/test_dump_sig.tcl +++ b/pkg/cli/interactive_tests/test_dump_sig.tcl @@ -37,7 +37,7 @@ send "\r" eexpect root@ # Finish the test. -interrupt +send_eof eexpect ":/# " end_test diff --git a/pkg/cli/interactive_tests/test_explain_analyze_debug.tcl b/pkg/cli/interactive_tests/test_explain_analyze_debug.tcl index 33de502e65ca..cf467a2fa1b7 100644 --- a/pkg/cli/interactive_tests/test_explain_analyze_debug.tcl +++ b/pkg/cli/interactive_tests/test_explain_analyze_debug.tcl @@ -42,7 +42,7 @@ eexpect "Bundle saved to" file_exists "stmt-bundle-$id.zip" -interrupt +send_eof eexpect eof end_test @@ -81,7 +81,7 @@ eexpect "Bundle saved to" file_exists "stmt-bundle-$id.zip" -interrupt +send_eof eexpect eof stop_tenant 5 $argv diff --git a/pkg/cli/interactive_tests/test_flags.tcl b/pkg/cli/interactive_tests/test_flags.tcl index dc08100d45b1..a30b43798f29 100644 --- a/pkg/cli/interactive_tests/test_flags.tcl +++ b/pkg/cli/interactive_tests/test_flags.tcl @@ -89,7 +89,7 @@ eexpect "cli.demo.explicitflags.logtostderr" eexpect "cli.demo.explicitflags.no-example-database" eexpect "cli.demo.runs" eexpect "defaultdb>" -interrupt +send_eof eexpect ":/# " end_test @@ -112,7 +112,7 @@ eexpect "cli.start-single-node.explicitflags.listening-url-file" eexpect "cli.start-single-node.explicitflags.max-sql-memory" eexpect "cli.start-single-node.runs" eexpect "defaultdb>" -interrupt +send_eof eexpect ":/# " end_test @@ -121,7 +121,7 @@ send "export COCKROACH_URL=`cat server_url`;\r" eexpect ":/# " send "$argv sql\r" eexpect "defaultdb>" -interrupt +send_eof eexpect ":/# " end_test diff --git a/pkg/cli/interactive_tests/test_history.tcl b/pkg/cli/interactive_tests/test_history.tcl index c92991b9bebe..26c9e5a0d35c 100644 --- a/pkg/cli/interactive_tests/test_history.tcl +++ b/pkg/cli/interactive_tests/test_history.tcl @@ -81,8 +81,8 @@ eexpect "1 row" eexpect root@ end_test -# Finally terminate with Ctrl+C -interrupt +# Finally terminate with Ctrl+D +send_eof eexpect eof stop_server $argv diff --git a/pkg/cli/interactive_tests/test_init_command.tcl b/pkg/cli/interactive_tests/test_init_command.tcl index 090131b52e43..1ceae03e8b9d 100644 --- a/pkg/cli/interactive_tests/test_init_command.tcl +++ b/pkg/cli/interactive_tests/test_init_command.tcl @@ -61,7 +61,7 @@ expect { } interrupt -interrupt +send_eof eexpect eof end_test diff --git a/pkg/cli/interactive_tests/test_local_cmds.tcl b/pkg/cli/interactive_tests/test_local_cmds.tcl index f1cb909bfd11..84f55f346130 100755 --- a/pkg/cli/interactive_tests/test_local_cmds.tcl +++ b/pkg/cli/interactive_tests/test_local_cmds.tcl @@ -266,8 +266,8 @@ eexpect Description eexpect root@ end_test -# Finally terminate with Ctrl+C. -interrupt +# Finally terminate with Ctrl+D. +send_eof eexpect eof spawn /bin/bash @@ -301,7 +301,7 @@ eexpect "errexit,false" eexpect "prompt1,%n@" eexpect "show_times,true" eexpect root@ -interrupt +send_eof eexpect ":/# " # Then verify that the defaults can be overridden. @@ -316,7 +316,7 @@ eexpect "errexit,true" eexpect "prompt1,%n@haa" eexpect "show_times,false" eexpect root@ -interrupt +send_eof eexpect ":/# " end_test diff --git a/pkg/cli/interactive_tests/test_multiline_statements.tcl b/pkg/cli/interactive_tests/test_multiline_statements.tcl index 2cade7edf5b7..31036ada9eb9 100644 --- a/pkg/cli/interactive_tests/test_multiline_statements.tcl +++ b/pkg/cli/interactive_tests/test_multiline_statements.tcl @@ -75,6 +75,7 @@ eexpect " ->" send "\\p\r" eexpect "select\r\n*->" interrupt +eexpect root@ end_test start_test "Test that a dangling table creation can be committed, and that other non-DDL, non-DML statements can be issued in the same txn. (#15283)" @@ -111,7 +112,7 @@ end_test -interrupt +send_eof eexpect eof stop_server $argv diff --git a/pkg/cli/interactive_tests/test_notice.tcl b/pkg/cli/interactive_tests/test_notice.tcl index 103832ad193f..f3160e088a7e 100644 --- a/pkg/cli/interactive_tests/test_notice.tcl +++ b/pkg/cli/interactive_tests/test_notice.tcl @@ -26,6 +26,6 @@ eexpect "NOTICE: hello" eexpect "NOTICE: world" eexpect "WARNING: stay indoors" eexpect root@ -interrupt +send_eof eexpect eof end_test diff --git a/pkg/cli/interactive_tests/test_secure.tcl b/pkg/cli/interactive_tests/test_secure.tcl index 2628f1f2be71..710058a21cb8 100644 --- a/pkg/cli/interactive_tests/test_secure.tcl +++ b/pkg/cli/interactive_tests/test_secure.tcl @@ -98,11 +98,10 @@ eexpect "Failed running \"sql\"" # Check that history is scrubbed. send "$argv sql --certs-dir=$certs_dir\r" eexpect "root@" -interrupt end_test -# Terminate the shell with Ctrl+C. -interrupt +# Terminate the shell with Ctrl+D. +send_eof eexpect $prompt start_test "Check that an auth cookie cannot be created for a user that does not exist." @@ -119,12 +118,12 @@ send "$argv sql --url 'postgres://eisen@?host=$mywd&port=26257'\r" eexpect "Enter password:" send "hunter2\r" eexpect "eisen@" -interrupt +send_eof eexpect $prompt send "$argv sql --url 'postgres://eisen:hunter2@?host=$mywd&port=26257'\r" eexpect "eisen@" -interrupt +send_eof eexpect $prompt end_test diff --git a/pkg/cli/interactive_tests/test_sql_demo_node_cmds.tcl b/pkg/cli/interactive_tests/test_sql_demo_node_cmds.tcl index 2cd2dac31bdf..94763dff5e18 100644 --- a/pkg/cli/interactive_tests/test_sql_demo_node_cmds.tcl +++ b/pkg/cli/interactive_tests/test_sql_demo_node_cmds.tcl @@ -17,7 +17,7 @@ send "\\demo shutdown 2\n" eexpect "\\demo can only be run with cockroach demo" # Exit the shell. -interrupt +send_eof eexpect eof # Have good manners and clean up. diff --git a/pkg/cli/interactive_tests/test_sql_mem_monitor.tcl b/pkg/cli/interactive_tests/test_sql_mem_monitor.tcl index 4249dffa73d0..feb5e35fe266 100644 --- a/pkg/cli/interactive_tests/test_sql_mem_monitor.tcl +++ b/pkg/cli/interactive_tests/test_sql_mem_monitor.tcl @@ -117,7 +117,7 @@ eexpect "1 row" eexpect root@ end_test -interrupt +send_eof eexpect eof set spawn_id $shell_spawn_id diff --git a/pkg/cli/interactive_tests/test_sql_version_reporting.tcl b/pkg/cli/interactive_tests/test_sql_version_reporting.tcl index 482a89d8dda3..f82164fde477 100644 --- a/pkg/cli/interactive_tests/test_sql_version_reporting.tcl +++ b/pkg/cli/interactive_tests/test_sql_version_reporting.tcl @@ -62,7 +62,7 @@ eexpect "New ID:" eexpect root@ end_test -interrupt +send_eof eexpect eof stop_server $argv @@ -80,7 +80,7 @@ eexpect "warning: server version older than client" eexpect root@ end_test -interrupt +send_eof eexpect eof stop_server $argv diff --git a/pkg/cli/interactive_tests/test_style_enabled.tcl b/pkg/cli/interactive_tests/test_style_enabled.tcl index 3d25660b0d46..44feb0ee402e 100644 --- a/pkg/cli/interactive_tests/test_style_enabled.tcl +++ b/pkg/cli/interactive_tests/test_style_enabled.tcl @@ -13,7 +13,7 @@ eexpect root@ send "SET CLUSTER SETTING sql.defaults.datestyle.enabled = true;\r" eexpect "SET CLUSTER SETTING" eexpect root@ -interrupt +send_eof eexpect eof @@ -24,7 +24,7 @@ eexpect root@ send "SHOW intervalstyle;\r" eexpect "iso_8601" eexpect root@ -interrupt +send_eof eexpect eof # TODO(#72065): uncomment