Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cli,sql: do not exit interactive shells with Ctrl+C #76427

Merged
merged 2 commits into from
Feb 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions pkg/cli/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
package cli

import (
"database/sql/driver"
"context"
"fmt"
"net/http"
"os"
Expand Down Expand Up @@ -98,8 +98,11 @@ func createAuthSessionToken(
}
defer func() { resErr = errors.CombineErrors(resErr, sqlConn.Close()) }()

ctx := context.Background()

// First things first. Does the user exist?
_, rows, err := sqlExecCtx.RunQuery(sqlConn,
_, rows, err := sqlExecCtx.RunQuery(ctx,
sqlConn,
clisqlclient.MakeQuery(`SELECT count(username) FROM system.users WHERE username = $1 AND NOT "isRole"`, username), false)
if err != nil {
return -1, nil, err
Expand All @@ -122,13 +125,11 @@ VALUES($1, $2, $3)
RETURNING id
`
var id int64
row, err := sqlConn.QueryRow(
row, err := sqlConn.QueryRow(ctx,
insertSessionStmt,
[]driver.Value{
hashedSecret,
username,
expiration,
},
hashedSecret,
username,
expiration,
)
if err != nil {
return -1, nil, err
Expand Down Expand Up @@ -176,7 +177,9 @@ func runLogout(cmd *cobra.Command, args []string) (resErr error) {
id AS "session ID",
"revokedAt" AS "revoked"`,
username)
return sqlExecCtx.RunQueryAndFormatResults(sqlConn, os.Stdout, stderr, logoutQuery)
return sqlExecCtx.RunQueryAndFormatResults(
context.Background(),
sqlConn, os.Stdout, stderr, logoutQuery)
}

var authListCmd = &cobra.Command{
Expand Down Expand Up @@ -206,7 +209,9 @@ SELECT username,
"revokedAt" as "revoked",
"lastUsedAt" as "last used"
FROM system.web_sessions`)
return sqlExecCtx.RunQueryAndFormatResults(sqlConn, os.Stdout, stderr, logoutQuery)
return sqlExecCtx.RunQueryAndFormatResults(
context.Background(),
sqlConn, os.Stdout, stderr, logoutQuery)
}

var authCmds = []*cobra.Command{
Expand Down
4 changes: 2 additions & 2 deletions pkg/cli/cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ func TestCLITimeout(t *testing.T) {
// the timeout a chance to have an effect. We specify --all to include some
// slower to access virtual tables in the query.
testutils.SucceedsSoon(t, func() error {
out, err := c.RunWithCapture("node status 1 --all --timeout 1ns")
out, err := c.RunWithCapture("node status 1 --all --timeout 1ms")
if err != nil {
t.Fatal(err)
}

const exp = `node status 1 --all --timeout 1ns
const exp = `node status 1 --all --timeout 1ms
ERROR: query execution canceled due to statement timeout
SQLSTATE: 57014
`
Expand Down
3 changes: 2 additions & 1 deletion pkg/cli/clierror/syntax_error_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
package clierror_test

import (
"context"
"io/ioutil"
"net/url"
"testing"
Expand Down Expand Up @@ -44,7 +45,7 @@ func TestIsSQLSyntaxError(t *testing.T) {
}
}()

_, err := conn.QueryRow(`INVALID SYNTAX`, nil)
_, err := conn.QueryRow(context.Background(), `INVALID SYNTAX`)
if !clierror.IsSQLSyntaxError(err) {
t.Fatalf("expected error to be recognized as syntax error: %+v", err)
}
Expand Down
7 changes: 5 additions & 2 deletions pkg/cli/clisqlcfg/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package clisqlcfg

import (
"context"
"fmt"
"os"
"strconv"
Expand Down Expand Up @@ -213,7 +214,8 @@ func (c *Context) maybeSetSafeUpdates(conn clisqlclient.Conn) {
safeUpdates = c.CliCtx.IsInteractive
}
if safeUpdates {
if err := conn.Exec("SET sql_safe_updates = TRUE", nil); err != nil {
if err := conn.Exec(context.Background(),
"SET sql_safe_updates = TRUE"); err != nil {
// We only enable the setting in interactive sessions. Ignoring
// the error with a warning is acceptable, because the user is
// there to decide what they want to do if it doesn't work.
Expand All @@ -228,5 +230,6 @@ func (c *Context) maybeSetReadOnly(conn clisqlclient.Conn) error {
if !c.ReadOnly {
return nil
}
return conn.Exec("SET default_transaction_read_only = TRUE", nil)
return conn.Exec(context.Background(),
"SET default_transaction_read_only = TRUE")
}
24 changes: 10 additions & 14 deletions pkg/cli/clisqlclient/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
package clisqlclient

import (
"context"
"database/sql/driver"
"reflect"
"time"
Expand All @@ -26,24 +27,24 @@ type Conn interface {
EnsureConn() error

// Exec executes a statement.
Exec(query string, args []driver.Value) error
Exec(ctx context.Context, query string, args ...interface{}) error

// Query returns one or more SQL statements and returns the
// corresponding result set(s).
Query(query string, args []driver.Value) (Rows, error)
Query(ctx context.Context, query string, args ...interface{}) (Rows, error)

// QueryRow execute a SQL query returning exactly one row
// and retrieves the returned values. An error is returned
// if the query returns zero or more than one row.
QueryRow(query string, args []driver.Value) ([]driver.Value, error)
QueryRow(ctx context.Context, query string, args ...interface{}) ([]driver.Value, error)

// ExecTxn runs fn inside a transaction and retries it as needed.
ExecTxn(fn func(TxBoundConn) error) error
ExecTxn(ctx context.Context, fn func(context.Context, TxBoundConn) error) error

// GetLastQueryStatistics returns the detailed latency stats for the
// last executed query, if supported by the server and enabled by
// configuration.
GetLastQueryStatistics() (result QueryStats, err error)
GetLastQueryStatistics(ctx context.Context) (result QueryStats, err error)

// SetURL changes the URL field in the connection object, so that the
// next connection (re-)establishment will use the new URL.
Expand All @@ -69,7 +70,7 @@ type Conn interface {

// GetServerMetadata() returns details about the CockroachDB node
// this connection is connected to.
GetServerMetadata() (
GetServerMetadata(ctx context.Context) (
nodeID int32,
version, clusterID string,
err error,
Expand All @@ -82,7 +83,7 @@ type Conn interface {
// The what argument is a descriptive label for the value being
// retrieved, for inclusion inside warning or error message.
// The sql argument is the SQL query to use to retrieve the value.
GetServerValue(what, sql string) (driver.Value, string, bool)
GetServerValue(ctx context.Context, what, sql string) (driver.Value, string, bool)

// GetDriverConn exposes the underlying SQL driver connection object
// for use by the cli package.
Expand Down Expand Up @@ -159,11 +160,11 @@ type QueryStats struct {
// visible to the closure passed to (Conn).ExecTxn.
type TxBoundConn interface {
// Exec executes a statement inside the transaction.
Exec(query string, args []driver.Value) error
Exec(ctx context.Context, query string, args ...interface{}) error

// Query returns one or more SQL statements and returns the
// corresponding result set(s).
Query(query string, args []driver.Value) (Rows, error)
Query(ctx context.Context, query string, args ...interface{}) (Rows, error)
}

// DriverConn is the type of the connection object returned by
Expand All @@ -173,9 +174,4 @@ type DriverConn interface {
driver.Conn
driver.ExecerContext
driver.QueryerContext

//lint:ignore SA1019 TODO(mjibson): clean this up to use go1.8 APIs
driver.Execer
//lint:ignore SA1019 TODO(mjibson): clean this up to use go1.8 APIs
driver.Queryer
}
Loading