diff --git a/pkg/cli/BUILD.bazel b/pkg/cli/BUILD.bazel index 94075f7e5403..a1060ae379bf 100644 --- a/pkg/cli/BUILD.bazel +++ b/pkg/cli/BUILD.bazel @@ -29,6 +29,7 @@ go_library( "demo.go", "demo_telemetry.go", "doctor.go", + "env.go", "examples.go", "flags.go", "flags_util.go", @@ -230,8 +231,9 @@ go_library( "@com_github_cockroachdb_ttycolor//:ttycolor", "@com_github_dustin_go_humanize//:go-humanize", "@com_github_gogo_protobuf//jsonpb", + "@com_github_jackc_pgconn//:pgconn", + "@com_github_jackc_pgtype//:pgtype", "@com_github_kr_pretty//:pretty", - "@com_github_lib_pq//:pq", "@com_github_marusama_semaphore//:semaphore", "@com_github_mattn_go_isatty//:go-isatty", "@com_github_spf13_cobra//:cobra", diff --git a/pkg/cli/clierror/BUILD.bazel b/pkg/cli/clierror/BUILD.bazel index a623e1b2572b..7e01e27b3a95 100644 --- a/pkg/cli/clierror/BUILD.bazel +++ b/pkg/cli/clierror/BUILD.bazel @@ -18,7 +18,7 @@ go_library( "//pkg/util/log/logpb", "//pkg/util/log/severity", "@com_github_cockroachdb_errors//:errors", - "@com_github_lib_pq//:pq", + "@com_github_jackc_pgconn//:pgconn", ], ) @@ -47,7 +47,7 @@ go_test( "//pkg/util/log/channel", "//pkg/util/log/severity", "@com_github_cockroachdb_errors//:errors", - "@com_github_lib_pq//:pq", + "@com_github_jackc_pgconn//:pgconn", "@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//require", ], diff --git a/pkg/cli/clierror/error_test.go b/pkg/cli/clierror/error_test.go index 4e71a4943d96..0820462f8bc4 100644 --- a/pkg/cli/clierror/error_test.go +++ b/pkg/cli/clierror/error_test.go @@ -20,7 +20,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/errors" - "github.com/lib/pq" + "github.com/jackc/pgconn" "github.com/stretchr/testify/assert" ) @@ -43,13 +43,13 @@ func TestOutputError(t *testing.T) { // Check the verbose output. This includes the uncategorized sqlstate. {errBase, false, true, "woo\nSQLSTATE: " + pgcode.Uncategorized.String() + "\nLOCATION: " + refLoc}, {errBase, true, true, "ERROR: woo\nSQLSTATE: " + pgcode.Uncategorized.String() + "\nLOCATION: " + refLoc}, - // Check the same over pq.Error objects. - {&pq.Error{Message: "woo"}, false, false, "woo"}, - {&pq.Error{Message: "woo"}, true, false, "ERROR: woo"}, - {&pq.Error{Message: "woo"}, false, true, "woo"}, - {&pq.Error{Message: "woo"}, true, true, "ERROR: woo"}, - {&pq.Error{Severity: "W", Message: "woo"}, false, false, "woo"}, - {&pq.Error{Severity: "W", Message: "woo"}, true, false, "W: woo"}, + // Check the same over pgconn.PgError objects. + {&pgconn.PgError{Message: "woo"}, false, false, "woo"}, + {&pgconn.PgError{Message: "woo"}, true, false, "ERROR: woo"}, + {&pgconn.PgError{Message: "woo"}, false, true, "woo"}, + {&pgconn.PgError{Message: "woo"}, true, true, "ERROR: woo"}, + {&pgconn.PgError{Severity: "W", Message: "woo"}, false, false, "woo"}, + {&pgconn.PgError{Severity: "W", Message: "woo"}, true, false, "W: woo"}, // Check hint printed after message. {errors.WithHint(errBase, "hello"), false, false, "woo\nHINT: hello"}, // Check sqlstate printed before hint, location after hint. @@ -85,16 +85,17 @@ func TestFormatLocation(t *testing.T) { defer log.Scope(t).Close(t) testData := []struct { - file, line, fn string - exp string + file string + line int + fn, exp string }{ - {"", "", "", ""}, - {"a.b", "", "", "a.b"}, - {"", "123", "", ":123"}, - {"", "", "abc", "abc"}, - {"a.b", "", "abc", "abc, a.b"}, - {"a.b", "123", "", "a.b:123"}, - {"", "123", "abc", "abc, :123"}, + {"", 0, "", ""}, + {"a.b", 0, "", "a.b"}, + {"", 123, "", ":123"}, + {"", 0, "abc", "abc"}, + {"a.b", 0, "abc", "abc, a.b"}, + {"a.b", 123, "", "a.b:123"}, + {"", 123, "abc", "abc, :123"}, } for _, tc := range testData { diff --git a/pkg/cli/clierror/formatted_error.go b/pkg/cli/clierror/formatted_error.go index 35c285e16083..ad6baa8d09c9 100644 --- a/pkg/cli/clierror/formatted_error.go +++ b/pkg/cli/clierror/formatted_error.go @@ -19,7 +19,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/errors" - "github.com/lib/pq" + "github.com/jackc/pgconn" ) // OutputError prints out an error object on the given writer. @@ -77,15 +77,15 @@ func (f *formattedError) Error() string { // Extract the fields. var message, hint, detail, location, constraintName string var code pgcode.Code - if pqErr := (*pq.Error)(nil); errors.As(f.err, &pqErr) { - if pqErr.Severity != "" { - severity = pqErr.Severity + if pgErr := (*pgconn.PgError)(nil); errors.As(f.err, &pgErr) { + if pgErr.Severity != "" { + severity = pgErr.Severity } - constraintName = pqErr.Constraint - message = pqErr.Message - code = pgcode.MakeCode(string(pqErr.Code)) - hint, detail = pqErr.Hint, pqErr.Detail - location = formatLocation(pqErr.File, pqErr.Line, pqErr.Routine) + constraintName = pgErr.ConstraintName + message = pgErr.Message + code = pgcode.MakeCode(pgErr.Code) + hint, detail = pgErr.Hint, pgErr.Detail + location = formatLocation(pgErr.File, int(pgErr.Line), pgErr.Routine) } else { message = f.err.Error() code = pgerror.GetPGCode(f.err) @@ -93,7 +93,7 @@ func (f *formattedError) Error() string { hint = errors.FlattenHints(f.err) detail = errors.FlattenDetails(f.err) if file, line, fn, ok := errors.GetOneLineSource(f.err); ok { - location = formatLocation(file, strconv.FormatInt(int64(line), 10), fn) + location = formatLocation(file, line, fn) } } @@ -144,10 +144,10 @@ func (f *formattedError) Error() string { // formatLocation spells out the error's location in a format // similar to psql: routine then file:num. The routine part is // skipped if empty. -func formatLocation(file, line, fn string) string { +func formatLocation(file string, line int, fn string) string { var res strings.Builder res.WriteString(fn) - if file != "" || line != "" { + if file != "" || line != 0 { if fn != "" { res.WriteString(", ") } @@ -156,9 +156,9 @@ func formatLocation(file, line, fn string) string { } else { res.WriteString(file) } - if line != "" { + if line != 0 { res.WriteByte(':') - res.WriteString(line) + res.WriteString(strconv.Itoa(line)) } } return res.String() diff --git a/pkg/cli/clierror/syntax_error.go b/pkg/cli/clierror/syntax_error.go index 1d2a3615207d..20216361495f 100644 --- a/pkg/cli/clierror/syntax_error.go +++ b/pkg/cli/clierror/syntax_error.go @@ -13,15 +13,15 @@ package clierror import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/errors" - "github.com/lib/pq" + "github.com/jackc/pgconn" ) // IsSQLSyntaxError returns true iff the provided error is a SQL // syntax error. The function works for the queries executed via the // clisqlclient/clisqlexec packages. func IsSQLSyntaxError(err error) bool { - if pqErr := (*pq.Error)(nil); errors.As(err, &pqErr) { - return string(pqErr.Code) == pgcode.Syntax.String() + if pgErr := (*pgconn.PgError)(nil); errors.As(err, &pgErr) { + return pgErr.Code == pgcode.Syntax.String() } return false } diff --git a/pkg/cli/clierrorplus/BUILD.bazel b/pkg/cli/clierrorplus/BUILD.bazel index 7728e6b89074..4d2850ab9273 100644 --- a/pkg/cli/clierrorplus/BUILD.bazel +++ b/pkg/cli/clierrorplus/BUILD.bazel @@ -19,7 +19,7 @@ go_library( "//pkg/util/log", "//pkg/util/netutil", "@com_github_cockroachdb_errors//:errors", - "@com_github_lib_pq//:pq", + "@com_github_jackc_pgconn//:pgconn", "@com_github_spf13_cobra//:cobra", "@org_golang_google_grpc//codes", "@org_golang_google_grpc//status", diff --git a/pkg/cli/clierrorplus/decorate_error.go b/pkg/cli/clierrorplus/decorate_error.go index 6e141b723b96..00e7d44f7759 100644 --- a/pkg/cli/clierrorplus/decorate_error.go +++ b/pkg/cli/clierrorplus/decorate_error.go @@ -26,7 +26,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/grpcutil" "github.com/cockroachdb/cockroach/pkg/util/netutil" "github.com/cockroachdb/errors" - "github.com/lib/pq" + "github.com/jackc/pgconn" "github.com/spf13/cobra" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -107,7 +107,9 @@ func MaybeDecorateError( } // Is this an "unable to connect" type of error? - if errors.Is(err, pq.ErrSSLNotSupported) { + // TODO(rafi): patch jackc/pgconn to add an ErrSSLNotSupported constant. + // See https://github.com/jackc/pgconn/issues/118. + if strings.Contains(err.Error(), "server refused TLS connection") { // SQL command failed after establishing a TCP connection // successfully, but discovering that it cannot use TLS while it // expected the server supports TLS. @@ -134,13 +136,13 @@ func MaybeDecorateError( return connRefused() } - if wErr := (*pq.Error)(nil); errors.As(err, &wErr) { + if wErr := (*pgconn.PgError)(nil); errors.As(err, &wErr) { // SQL commands will fail with a pq error but only after // establishing a TCP connection successfully. So if we got // here, there was a TCP connection already. // Did we fail due to security settings? - if pgcode.MakeCode(string(wErr.Code)) == pgcode.ProtocolViolation { + if pgcode.MakeCode(wErr.Code) == pgcode.ProtocolViolation { return connSecurityHint() } @@ -170,6 +172,16 @@ func MaybeDecorateError( return connFailed() } + // pgconn sometimes returns a context deadline error wrapped in a private + // connectError struct, which begins with this message. + if strings.HasPrefix(err.Error(), "failed to connect to") && + errors.IsAny(err, + context.DeadlineExceeded, + context.Canceled, + ) { + return connFailed() + } + if wErr := (*netutil.InitialHeartbeatFailedError)(nil); errors.As(err, &wErr) { // A GRPC TCP connection was established but there was an early failure. // Try to distinguish the cases. diff --git a/pkg/cli/clisqlcfg/context.go b/pkg/cli/clisqlcfg/context.go index 5c23ca9df834..fb5e1b9c3cd0 100644 --- a/pkg/cli/clisqlcfg/context.go +++ b/pkg/cli/clisqlcfg/context.go @@ -170,7 +170,7 @@ func (c *Context) MakeConn(url string) (clisqlclient.Conn, error) { // ensures that if the server was not initialized or there is some // network issue, the client will not be left to hang forever. // - // This is a lib/pq feature. + // This is a pgx feature. if baseURL.GetOption("connect_timeout") == "" && c.ConnectTimeout != 0 { _ = baseURL.SetOption("connect_timeout", strconv.Itoa(c.ConnectTimeout)) } @@ -181,18 +181,28 @@ func (c *Context) MakeConn(url string) (clisqlclient.Conn, error) { conn := c.ConnCtx.MakeSQLConn(c.CmdOut, c.CmdErr, url) conn.SetMissingPassword(!usePw || !pwdSet) + // By default, all connections will use the underlying driver to infer + // result types. This should be set back to false for any use case where the + // results are only shown for textual display. + conn.SetAlwaysInferResultTypes(true) + return conn, nil } // Run executes the SQL shell. -func (c *Context) Run(conn clisqlclient.Conn) error { +func (c *Context) Run(ctx context.Context, conn clisqlclient.Conn) error { if !c.opened { return errors.AssertionFailedf("programming error: Open not called yet") } + // Anything using a SQL shell (e.g. `cockroach sql` or `demo`), only needs + // to show results in text format, so the underlying driver doesn't need to + // infer types. + conn.SetAlwaysInferResultTypes(false) + // Open the connection to make sure everything is OK before running any // statements. Performs authentication. - if err := conn.EnsureConn(); err != nil { + if err := conn.EnsureConn(ctx); err != nil { return err } diff --git a/pkg/cli/clisqlclient/BUILD.bazel b/pkg/cli/clisqlclient/BUILD.bazel index 8affcf1e103c..ca46c1016e22 100644 --- a/pkg/cli/clisqlclient/BUILD.bazel +++ b/pkg/cli/clisqlclient/BUILD.bazel @@ -11,7 +11,9 @@ go_library( "init_conn_error.go", "make_query.go", "parse_bool.go", + "row_type_helpers.go", "rows.go", + "rows_multi.go", "statement_diag.go", "string_to_duration.go", "txn_shim.go", @@ -23,12 +25,14 @@ go_library( "//pkg/cli/clicfg", "//pkg/cli/clierror", "//pkg/security/pprompt", + "//pkg/sql/pgwire/pgcode", "//pkg/sql/scanner", "//pkg/util/version", "@com_github_cockroachdb_cockroach_go_v2//crdb", "@com_github_cockroachdb_errors//:errors", - "@com_github_lib_pq//:pq", - "@com_github_lib_pq_auth_kerberos//:kerberos", + "@com_github_jackc_pgconn//:pgconn", + "@com_github_jackc_pgtype//:pgtype", + "@com_github_jackc_pgx_v4//:pgx", "@com_github_otan_gopgkrb5//:gopgkrb5", ], ) diff --git a/pkg/cli/clisqlclient/api.go b/pkg/cli/clisqlclient/api.go index 0c27bd21c9e7..60f1dc5d2802 100644 --- a/pkg/cli/clisqlclient/api.go +++ b/pkg/cli/clisqlclient/api.go @@ -13,6 +13,7 @@ package clisqlclient import ( "context" "database/sql/driver" + "io" "reflect" "time" ) @@ -24,7 +25,7 @@ type Conn interface { Close() error // EnsureConn (re-)establishes the connection to the server. - EnsureConn() error + EnsureConn(ctx context.Context) error // Exec executes a statement. Exec(ctx context.Context, query string, args ...interface{}) error @@ -68,7 +69,12 @@ type Conn interface { // server requests one. SetMissingPassword(missing bool) - // GetServerMetadata() returns details about the CockroachDB node + // SetAlwaysInferResultTypes configures the alwaysInferResultTypes flag, which + // determines if the client should use the underlying driver to infer result + // types. + SetAlwaysInferResultTypes(b bool) + + // GetServerMetadata returns details about the CockroachDB node // this connection is connected to. GetServerMetadata(ctx context.Context) ( nodeID int32, @@ -85,18 +91,22 @@ type Conn interface { // The sql argument is the SQL query to use to retrieve the value. GetServerValue(ctx context.Context, what, sql string) (driver.Value, string, bool) - // GetDriverConn exposes the underlying SQL driver connection object + // GetDriverConn exposes the underlying driver connection object // for use by the cli package. GetDriverConn() DriverConn } // Rows describes a result set. type Rows interface { + driver.Rows + // The caller must call Close() when done with the // result and check the error. Close() error // Columns returns the column labels of the current result set. + // The implementation of this method should cache the result so that the + // result does not need to be constructed on each invocation. Columns() []string // ColumnTypeScanType returns the natural Go type of values at the @@ -111,11 +121,8 @@ type Rows interface { // columns. ColumnTypeNames() []string - // Result retrieves the underlying driver result object. - Result() driver.Result - // Tag retrieves the statement tag for the current result set. - Tag() string + Tag() (CommandTag, error) // Next populates values with the next row of results. []byte values are copied // so that subsequent calls to Next and Close do not mutate values. This @@ -130,6 +137,12 @@ type Rows interface { NextResultSet() (bool, error) } +// CommandTag represents the result of a SQL command. +type CommandTag interface { + RowsAffected() int64 + String() string +} + // QueryStatsDuration represents a duration value retrieved by // GetLastQueryStatistics. type QueryStatsDuration struct { @@ -168,10 +181,32 @@ type TxBoundConn interface { } // DriverConn is the type of the connection object returned by -// (Conn).GetDriverConn(). It gives access to the underlying Go sql +// (Conn).GetDriverConn(). It gives access to the underlying sql // driver. type DriverConn interface { - driver.Conn - driver.ExecerContext - driver.QueryerContext + Query(ctx context.Context, query string, args ...interface{}) (driver.Rows, error) + Exec(ctx context.Context, query string, args ...interface{}) error + CopyFrom(ctx context.Context, reader io.Reader, query string) error +} + +type driverConnAdapter struct { + c *sqlConn +} + +var _ DriverConn = (*driverConnAdapter)(nil) + +func (d *driverConnAdapter) Query( + ctx context.Context, query string, args ...interface{}, +) (driver.Rows, error) { + rows, err := d.c.Query(ctx, query, args...) + return driver.Rows(rows), err +} + +func (d *driverConnAdapter) Exec(ctx context.Context, query string, args ...interface{}) error { + return d.c.Exec(ctx, query, args...) +} + +func (d *driverConnAdapter) CopyFrom(ctx context.Context, reader io.Reader, query string) error { + _, err := d.c.conn.PgConn().CopyFrom(ctx, reader, query) + return err } diff --git a/pkg/cli/clisqlclient/conn.go b/pkg/cli/clisqlclient/conn.go index 86221c1ebe4a..9f6f938137fc 100644 --- a/pkg/cli/clisqlclient/conn.go +++ b/pkg/cli/clisqlclient/conn.go @@ -15,6 +15,7 @@ import ( "database/sql/driver" "fmt" "io" + "net" "net/url" "strconv" "strings" @@ -23,16 +24,17 @@ import ( "github.com/cockroachdb/cockroach/pkg/build" "github.com/cockroachdb/cockroach/pkg/cli/clierror" "github.com/cockroachdb/cockroach/pkg/security/pprompt" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/util/version" "github.com/cockroachdb/errors" - "github.com/lib/pq" - "github.com/lib/pq/auth/kerberos" - _ "github.com/otan/gopgkrb5" // need a comment until the dependency is used + "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4" + "github.com/otan/gopgkrb5" ) func init() { // Ensure that the CLI client commands can use GSSAPI authentication. - pq.RegisterGSSProvider(func() (pq.GSS, error) { return kerberos.NewGSS() }) + pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() }) } type sqlConn struct { @@ -42,13 +44,18 @@ type sqlConn struct { connCtx *Context url string - conn DriverConn + conn *pgx.Conn reconnecting bool // passwordMissing is true iff the url is missing a password. passwordMissing bool - pendingNotices []*pq.Error + // alwaysInferResultTypes is true iff the client should always use the + // underlying driver to infer result types. If it is true, multiple statements + // in a single string cannot be executed. + alwaysInferResultTypes bool + + pendingNotices []*pgconn.Notice // delayNotices, if set, makes notices accumulate for printing // when the SQL execution completes. The default (false) @@ -82,6 +89,10 @@ type sqlConn struct { var _ Conn = (*sqlConn)(nil) +// ErrConnectionClosed is returned when an operation fails because the +// connection was closed. +var ErrConnectionClosed = errors.New("connection closed unexpectedly") + // wrapConnError detects TCP EOF errors during the initial SQL handshake. // These are translated to a message "perhaps this is not a CockroachDB node" // at the top level. @@ -90,7 +101,10 @@ var _ Conn = (*sqlConn)(nil) // server. func wrapConnError(err error) error { errMsg := err.Error() - if errMsg == "EOF" || errMsg == "unexpected EOF" { + // pgconn wraps some of these errors with the private connectError struct. + isPgconnConnectError := strings.HasPrefix(errMsg, "failed to connect to") + isEOF := errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) + if errMsg == "EOF" || errMsg == "unexpected EOF" || (isPgconnConnectError && isEOF) { return &InitialSQLConnectionError{err} } return err @@ -98,13 +112,13 @@ func wrapConnError(err error) error { func (c *sqlConn) flushNotices() { for _, notice := range c.pendingNotices { - clierror.OutputError(c.errw, notice, true /*showSeverity*/, false /*verbose*/) + clierror.OutputError(c.errw, (*pgconn.PgError)(notice), true /*showSeverity*/, false /*verbose*/) } c.pendingNotices = nil c.delayNotices = false } -func (c *sqlConn) handleNotice(notice *pq.Error) { +func (c *sqlConn) handleNotice(notice *pgconn.Notice) { c.pendingNotices = append(c.pendingNotices, notice) if !c.delayNotices { c.flushNotices() @@ -123,7 +137,7 @@ func (c *sqlConn) SetURL(url string) { // GetDriverConn implements the Conn interface. func (c *sqlConn) GetDriverConn() DriverConn { - return c.conn + return &driverConnAdapter{c} } // SetCurrentDatabase implements the Conn interface. @@ -136,39 +150,52 @@ func (c *sqlConn) SetMissingPassword(missing bool) { c.passwordMissing = missing } +// SetAlwaysInferResultTypes implements the Conn interface. +func (c *sqlConn) SetAlwaysInferResultTypes(b bool) { + c.alwaysInferResultTypes = b +} + +// The default pgx dialer uses a KeepAlive of 5 minutes, which we don't want. +var defaultDialer = &net.Dialer{} + // EnsureConn (re-)establishes the connection to the server. -func (c *sqlConn) EnsureConn() error { +func (c *sqlConn) EnsureConn(ctx context.Context) error { if c.conn != nil { return nil } - ctx := context.Background() if c.reconnecting && c.connCtx.IsInteractive() { fmt.Fprintf(c.errw, "warning: connection lost!\n"+ "opening new connection: all session settings will be lost\n") } - base, err := pq.NewConnector(c.url) + base, err := pgx.ParseConfig(c.url) if err != nil { return wrapConnError(err) } // Add a notice handler - re-use the cliOutputError function in this case. - connector := pq.ConnectorWithNoticeHandler(base, func(notice *pq.Error) { + base.OnNotice = func(_ *pgconn.PgConn, notice *pgconn.Notice) { c.handleNotice(notice) - }) - // TODO(cli): we can't thread ctx through ensureConn usages, as it needs - // to follow the gosql.DB interface. We should probably look at initializing - // connections only once instead. The context is only used for dialing. - conn, err := connector.Connect(ctx) + } + // By default, pgx uses a Dialer with a KeepAlive of 5 minutes, which is not + // desired here, so we override it. + defaultDialer.Timeout = 0 + if base.ConnectTimeout > 0 { + defaultDialer.Timeout = base.ConnectTimeout + } + base.DialFunc = defaultDialer.DialContext + // Override LookupFunc to be a no-op, so that the Dialer is responsible for + // resolving hostnames. This fixes an issue where pgx would error out too + // quickly if using TLS when an ipv6 address is resolved, but the networking + // stack does not support ipv6. + base.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { + return []string{host}, nil + } + conn, err := pgx.ConnectConfig(ctx, base) if err != nil { // Connection failed: if the failure is due to a missing // password, we're going to fill the password here. - // - // TODO(knz): CockroachDB servers do not properly fill SQLSTATE - // (28P01) for password auth errors, so we have to "make do" - // with a string match. This should be cleaned up by adding - // the missing code server-side. - errStr := strings.TrimPrefix(err.Error(), "pq: ") - if strings.HasPrefix(errStr, "password authentication failed") && c.passwordMissing { + pgErr := (*pgconn.PgError)(nil) + if errors.As(err, &pgErr) && pgErr.Code == pgcode.InvalidPassword.String() && c.passwordMissing { if pErr := c.fillPassword(); pErr != nil { return errors.CombineErrors(err, pErr) } @@ -177,19 +204,18 @@ func (c *sqlConn) EnsureConn() error { // The recursion only occurs once because fillPassword() // resets c.passwordMissing, so we cannot get into this // conditional a second time. - return c.EnsureConn() + return c.EnsureConn(ctx) } // Not a password auth error, or password already set. Simply fail. return wrapConnError(err) } if c.reconnecting && c.dbName != "" { // Attempt to reset the current database. - if _, err := conn.(DriverConn).ExecContext(ctx, `SET DATABASE = $1`, - []driver.NamedValue{{Value: c.dbName}}); err != nil { + if _, err := conn.Exec(ctx, `SET DATABASE = $1`, c.dbName); err != nil { fmt.Fprintf(c.errw, "warning: unable to restore current database: %v\n", err) } } - c.conn = conn.(DriverConn) + c.conn = conn if err := c.checkServerMetadata(ctx); err != nil { err = errors.CombineErrors(err, c.Close()) return wrapConnError(err) @@ -243,18 +269,18 @@ func (c *sqlConn) tryEnableServerExecutionTimings(ctx context.Context) error { func (c *sqlConn) GetServerMetadata( ctx context.Context, -) (nodeID int32, version, clusterID string, err error) { +) (nodeID int32, version, clusterID string, retErr error) { // Retrieve the node ID and server build info. // Be careful to query against the empty database string, which avoids taking // a lease against the current database (in case it's currently unavailable). rows, err := c.Query(ctx, `SELECT * FROM "".crdb_internal.node_build_info`) - if errors.Is(err, driver.ErrBadConn) { - return 0, "", "", err + if c.conn.IsClosed() { + return 0, "", "", MarkWithConnectionClosed(err) } if err != nil { return 0, "", "", err } - defer func() { _ = rows.Close() }() + defer func() { retErr = errors.CombineErrors(retErr, rows.Close()) }() // Read the node_build_info table as an array of strings. rowVals, err := getServerMetadataRows(rows) @@ -354,8 +380,8 @@ func (c *sqlConn) checkServerMetadata(ctx context.Context) error { } _, newServerVersion, newClusterID, err := c.GetServerMetadata(ctx) - if errors.Is(err, driver.ErrBadConn) { - return err + if c.conn.IsClosed() { + return MarkWithConnectionClosed(err) } if err != nil { // It is not an error that the server version cannot be retrieved. @@ -511,45 +537,64 @@ func (c *sqlConn) ExecTxn( } func (c *sqlConn) Exec(ctx context.Context, query string, args ...interface{}) error { - dVals, err := convertArgs(args) - if err != nil { - return err - } - if err := c.EnsureConn(); err != nil { + if err := c.EnsureConn(ctx); err != nil { return err } if c.connCtx.Echo { fmt.Fprintln(c.errw, ">", query) } - _, err = c.conn.ExecContext(ctx, query, dVals) + _, err := c.conn.Exec(ctx, query, args...) c.flushNotices() - if errors.Is(err, driver.ErrBadConn) { + if c.conn.IsClosed() { c.reconnecting = true c.silentClose() + return MarkWithConnectionClosed(err) } return err } func (c *sqlConn) Query(ctx context.Context, query string, args ...interface{}) (Rows, error) { - dVals, err := convertArgs(args) - if err != nil { - return nil, err - } - if err := c.EnsureConn(); err != nil { + if err := c.EnsureConn(ctx); err != nil { return nil, err } if c.connCtx.Echo { fmt.Fprintln(c.errw, ">", query) } - rows, err := c.conn.QueryContext(ctx, query, dVals) - if errors.Is(err, driver.ErrBadConn) { + + // If there are placeholder args, then we must use a prepared statement, + // which is only possible using pgx.Conn. + // Or, if alwaysInferResultTypes is set, then we must use pgx.Conn so the + // result types are automatically inferred. + if len(args) > 0 || c.alwaysInferResultTypes { + rows, err := c.conn.Query(ctx, query, args...) + if c.conn.IsClosed() { + c.reconnecting = true + c.silentClose() + return nil, MarkWithConnectionClosed(err) + } + if err != nil { + return nil, err + } + return &sqlRows{rows: rows, connInfo: c.conn.ConnInfo(), conn: c}, nil + } + + // Otherwise, we use pgconn. This allows us to add support for multiple + // queries in a single string, which wouldn't be possible at the pgx level. + multiResultReader := c.conn.PgConn().Exec(ctx, query) + if c.conn.IsClosed() { c.reconnecting = true c.silentClose() + return nil, MarkWithConnectionClosed(multiResultReader.Close()) } - if err != nil { + rs := &sqlRowsMultiResultSet{ + rows: multiResultReader, + connInfo: c.conn.ConnInfo(), + conn: c, + } + if _, err := rs.NextResultSet(); err != nil { return nil, err } - return &sqlRows{rows: rows.(sqlRowsI), conn: c}, nil + return rs, nil } func (c *sqlConn) QueryRow( @@ -589,7 +634,7 @@ func (c *sqlConn) queryRowInternal( func (c *sqlConn) Close() error { c.flushNotices() if c.conn != nil { - err := c.conn.Close() + err := c.conn.Close(context.Background()) if err != nil { return err } @@ -600,7 +645,7 @@ func (c *sqlConn) Close() error { func (c *sqlConn) silentClose() { if c.conn != nil { - _ = c.conn.Close() + _ = c.conn.Close(context.Background()) c.conn = nil } } @@ -645,3 +690,20 @@ func (c *sqlConn) fillPassword() error { c.passwordMissing = false return nil } + +// MarkWithConnectionClosed is a mix of errors.CombineErrors() and errors.Mark(). +// If err is nil, the result is like errors.CombineErrors(). +// If err is non-nil, the result is that of errors.Mark(), with the error +// message added as a prefix. +// +// In both cases, errors.Is(..., ErrConnectionClosed) returns true on the +// result. +func MarkWithConnectionClosed(err error) error { + if err == nil { + return ErrConnectionClosed + } + // Disable the linter since we are intentionally adding the error text. + // nolint:errwrap + errWithMsg := errors.WithMessagef(err, "%v", ErrConnectionClosed) + return errors.Mark(errWithMsg, ErrConnectionClosed) +} diff --git a/pkg/cli/clisqlclient/conn_test.go b/pkg/cli/clisqlclient/conn_test.go index e047fb5e684b..73bf2e9a9dc8 100644 --- a/pkg/cli/clisqlclient/conn_test.go +++ b/pkg/cli/clisqlclient/conn_test.go @@ -12,7 +12,6 @@ package clisqlclient_test import ( "context" - "database/sql/driver" "io/ioutil" "net/url" "testing" @@ -70,8 +69,8 @@ func TestConnRecover(t *testing.T) { if closeErr := sqlRows.Close(); closeErr != nil { t.Fatal(closeErr) } - } else if !errors.Is(err, driver.ErrBadConn) { - return errors.Newf("expected ErrBadConn, got %v", err) // nolint:errwrap + } else if !errors.Is(err, clisqlclient.ErrConnectionClosed) { + return errors.Newf("expected ErrConnectionClosed, got %v", err) // nolint:errwrap } return nil }) @@ -90,8 +89,8 @@ func TestConnRecover(t *testing.T) { // Ditto from Query(). testutils.SucceedsSoon(t, func() error { - if err := conn.Exec(ctx, `SELECT 1`); !errors.Is(err, driver.ErrBadConn) { - return errors.Newf("expected ErrBadConn, got %v", err) // nolint:errwrap + if err := conn.Exec(ctx, `SELECT 1`); !errors.Is(err, clisqlclient.ErrConnectionClosed) { + return errors.Newf("expected ErrConnectionClosed, got %v", err) // nolint:errwrap } return nil }) diff --git a/pkg/cli/clisqlclient/copy.go b/pkg/cli/clisqlclient/copy.go index 733aaaefeba6..ab93894525f3 100644 --- a/pkg/cli/clisqlclient/copy.go +++ b/pkg/cli/clisqlclient/copy.go @@ -11,43 +11,38 @@ package clisqlclient import ( + "bytes" "context" "database/sql/driver" "io" "reflect" - "strings" - "github.com/cockroachdb/errors" + "github.com/jackc/pgconn" ) -type copyFromer interface { - CopyData(ctx context.Context, line string) (r driver.Result, err error) - Exec(v []driver.Value) (r driver.Result, err error) - Close() error -} - // CopyFromState represents an in progress COPY FROM. type CopyFromState struct { - driver.Tx - copyFromer + conn *pgconn.PgConn + query string } // BeginCopyFrom starts a COPY FROM query. func BeginCopyFrom(ctx context.Context, conn Conn, query string) (*CopyFromState, error) { - txn, err := conn.(*sqlConn).conn.(driver.ConnBeginTx).BeginTx(ctx, driver.TxOptions{}) - if err != nil { + copyConn := conn.(*sqlConn).conn.PgConn() + // Run the initial query, but don't use the result so that we can get any + // errors early. + if _, err := copyConn.CopyFrom(ctx, bytes.NewReader([]byte{}), query); err != nil { return nil, err } - stmt, err := txn.(driver.Conn).Prepare(query) - if err != nil { - return nil, errors.CombineErrors(err, txn.Rollback()) - } - return &CopyFromState{Tx: txn, copyFromer: stmt.(copyFromer)}, nil + return &CopyFromState{ + conn: copyConn, + query: query, + }, nil } // copyFromRows is a mock Rows interface for COPY results. type copyFromRows struct { - r driver.Result + t pgconn.CommandTag } func (c copyFromRows) Close() error { @@ -70,12 +65,8 @@ func (c copyFromRows) ColumnTypeNames() []string { return nil } -func (c copyFromRows) Result() driver.Result { - return c.r -} - -func (c copyFromRows) Tag() string { - return "COPY" +func (c copyFromRows) Tag() (CommandTag, error) { + return c.t, nil } func (c copyFromRows) Next(values []driver.Value) error { @@ -88,7 +79,7 @@ func (c copyFromRows) NextResultSet() (bool, error) { // Cancel cancels a COPY FROM query from completing. func (c *CopyFromState) Cancel() error { - return errors.CombineErrors(c.copyFromer.Close(), c.Tx.Rollback()) + return nil } // Commit completes a COPY FROM query by committing lines to the database. @@ -96,21 +87,13 @@ func (c *CopyFromState) Commit(ctx context.Context, cleanupFunc func(), lines st return func(ctx context.Context, conn Conn) (Rows, bool, error) { defer cleanupFunc() rows, isMulti, err := func() (Rows, bool, error) { - for _, l := range strings.Split(lines, "\n") { - _, err := c.copyFromer.CopyData(ctx, l) - if err != nil { - return nil, false, err - } - } - r, err := c.copyFromer.Exec(nil) + r := bytes.NewReader([]byte(lines)) + tag, err := c.conn.CopyFrom(ctx, r, c.query) if err != nil { return nil, false, err } - return copyFromRows{r: r}, false, c.Tx.Commit() + return copyFromRows{tag}, false, nil }() - if err != nil { - return rows, isMulti, errors.CombineErrors(err, errors.CombineErrors(c.copyFromer.Close(), c.Tx.Rollback())) - } return rows, isMulti, err } } diff --git a/pkg/cli/clisqlclient/make_query.go b/pkg/cli/clisqlclient/make_query.go index e43a3f3e02a3..7097f474c166 100644 --- a/pkg/cli/clisqlclient/make_query.go +++ b/pkg/cli/clisqlclient/make_query.go @@ -12,7 +12,6 @@ package clisqlclient import ( "context" - "database/sql/driver" "github.com/cockroachdb/cockroach/pkg/sql/scanner" ) @@ -29,22 +28,3 @@ func MakeQuery(query string, parameters ...interface{}) QueryFn { return rows, isMultiStatementQuery, err } } - -func convertArgs(parameters []interface{}) ([]driver.NamedValue, error) { - dVals := make([]driver.NamedValue, len(parameters)) - for i := range parameters { - // driver.NamedValue.Value is an alias for interface{}, but must adhere to a restricted - // set of types when being passed to driver.Queryer.Query (see - // driver.IsValue). We use driver.DefaultParameterConverter to perform the - // necessary conversion. This is usually taken care of by the sql package, - // but we have to do so manually because we're talking directly to the - // driver. - var err error - dVals[i].Ordinal = i + 1 - dVals[i].Value, err = driver.DefaultParameterConverter.ConvertValue(parameters[i]) - if err != nil { - return nil, err - } - } - return dVals, nil -} diff --git a/pkg/cli/clisqlclient/row_type_helpers.go b/pkg/cli/clisqlclient/row_type_helpers.go new file mode 100644 index 000000000000..39e728b51867 --- /dev/null +++ b/pkg/cli/clisqlclient/row_type_helpers.go @@ -0,0 +1,72 @@ +// Copyright 2022 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package clisqlclient + +import ( + "reflect" + "strings" + "time" + + "github.com/jackc/pgtype" +) + +// scanType returns the Go type used for the given column type. +func scanType(typeOID uint32, typeName string) reflect.Type { + if typeName == "" || strings.HasPrefix(typeName, "_") { + // User-defined types and array types are scanned into []byte. + // These are handled separately since we can't easily include all the OIDs + // for user-defined types and arrays in the switch statement. + return reflect.TypeOf([]byte(nil)) + } + // This switch is copied from lib/pq, and modified with a few additional cases. + // https://github.com/lib/pq/blob/b2901c7946b69f1e7226214f9760e31620499595/rows.go#L24 + switch typeOID { + case pgtype.Int8OID: + return reflect.TypeOf(int64(0)) + case pgtype.Int4OID: + return reflect.TypeOf(int32(0)) + case pgtype.Int2OID: + return reflect.TypeOf(int16(0)) + case pgtype.VarcharOID, pgtype.TextOID, pgtype.NameOID, + pgtype.ByteaOID, pgtype.NumericOID, pgtype.RecordOID, + pgtype.QCharOID, pgtype.BPCharOID: + return reflect.TypeOf("") + case pgtype.BoolOID: + return reflect.TypeOf(false) + case pgtype.DateOID, pgtype.TimeOID, 1266, pgtype.TimestampOID, pgtype.TimestamptzOID: + // 1266 is the OID for TimeTZ. + // TODO(rafi): Add TimetzOID to pgtype. + return reflect.TypeOf(time.Time{}) + default: + return reflect.TypeOf(new(interface{})).Elem() + } +} + +// databaseTypeName returns the database type name for the given type OID. +func databaseTypeName(ci *pgtype.ConnInfo, typeOID uint32) string { + dataType, ok := ci.DataTypeForOID(typeOID) + if !ok { + // TODO(rafi): remove special logic once jackc/pgtype includes these types. + switch typeOID { + case 1002: + return "_CHAR" + case 1003: + return "_NAME" + case 1266: + return "TIMETZ" + case 1270: + return "_TIMETZ" + default: + return "" + } + } + return strings.ToUpper(dataType.Name) +} diff --git a/pkg/cli/clisqlclient/rows.go b/pkg/cli/clisqlclient/rows.go index 270098763b4e..0481d2763fa1 100644 --- a/pkg/cli/clisqlclient/rows.go +++ b/pkg/cli/clisqlclient/rows.go @@ -12,62 +12,68 @@ package clisqlclient import ( "database/sql/driver" + "io" "reflect" - "github.com/cockroachdb/errors" + "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4" ) type sqlRows struct { - rows sqlRowsI - conn *sqlConn + rows pgx.Rows + connInfo *pgtype.ConnInfo + conn *sqlConn + colNames []string } var _ Rows = (*sqlRows)(nil) -type sqlRowsI interface { - driver.RowsColumnTypeScanType - driver.RowsColumnTypeDatabaseTypeName - Result() driver.Result - Tag() string - - // Go 1.8 multiple result set interfaces. - // TODO(mjibson): clean this up after 1.8 is released. - HasNextResultSet() bool - NextResultSet() error -} - func (r *sqlRows) Columns() []string { - return r.rows.Columns() -} - -func (r *sqlRows) Result() driver.Result { - return r.rows.Result() + if r.colNames == nil { + fields := r.rows.FieldDescriptions() + r.colNames = make([]string, len(fields)) + for i, fd := range fields { + r.colNames[i] = string(fd.Name) + } + } + return r.colNames } -func (r *sqlRows) Tag() string { - return r.rows.Tag() +func (r *sqlRows) Tag() (CommandTag, error) { + return r.rows.CommandTag(), r.rows.Err() } func (r *sqlRows) Close() error { r.conn.flushNotices() - err := r.rows.Close() - if errors.Is(err, driver.ErrBadConn) { + r.rows.Close() + if r.conn.conn.IsClosed() { r.conn.reconnecting = true r.conn.silentClose() + return ErrConnectionClosed } - return err + return r.rows.Err() } // Next implements the Rows interface. func (r *sqlRows) Next(values []driver.Value) error { - err := r.rows.Next(values) - if errors.Is(err, driver.ErrBadConn) { + if r.conn.conn.IsClosed() { r.conn.reconnecting = true r.conn.silentClose() + return ErrConnectionClosed + } + if !r.rows.Next() { + return io.EOF } - for i, v := range values { - if b, ok := v.([]byte); ok { + rawVals, err := r.rows.Values() + if err != nil { + return err + } + for i, v := range rawVals { + if b, ok := (v).([]byte); ok { + // Copy byte slices as per the comment on Rows.Next. values[i] = append([]byte{}, b...) + } else { + values[i] = v } } // After the first row was received, we want to delay all @@ -78,18 +84,18 @@ func (r *sqlRows) Next(values []driver.Value) error { // NextResultSet prepares the next result set for reading. func (r *sqlRows) NextResultSet() (bool, error) { - if !r.rows.HasNextResultSet() { - return false, nil - } - return true, r.rows.NextResultSet() + return false, nil } func (r *sqlRows) ColumnTypeScanType(index int) reflect.Type { - return r.rows.ColumnTypeScanType(index) + o := r.rows.FieldDescriptions()[index].DataTypeOID + n := r.ColumnTypeDatabaseTypeName(index) + return scanType(o, n) } func (r *sqlRows) ColumnTypeDatabaseTypeName(index int) string { - return r.rows.ColumnTypeDatabaseTypeName(index) + fieldOID := r.rows.FieldDescriptions()[index].DataTypeOID + return databaseTypeName(r.connInfo, fieldOID) } func (r *sqlRows) ColumnTypeNames() []string { diff --git a/pkg/cli/clisqlclient/rows_multi.go b/pkg/cli/clisqlclient/rows_multi.go new file mode 100644 index 000000000000..55a452ef73d6 --- /dev/null +++ b/pkg/cli/clisqlclient/rows_multi.go @@ -0,0 +1,157 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package clisqlclient + +import ( + "database/sql/driver" + "io" + "reflect" + + "github.com/cockroachdb/errors" + "github.com/jackc/pgconn" + "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4" +) + +type sqlRowsMultiResultSet struct { + rows *pgconn.MultiResultReader + connInfo *pgtype.ConnInfo + conn *sqlConn + colNames []string +} + +var _ Rows = (*sqlRowsMultiResultSet)(nil) + +func (r *sqlRowsMultiResultSet) Columns() []string { + if r.colNames == nil { + rr := r.rows.ResultReader() + if rr == nil { + // ResultReader may be nil if an empty query was executed. + return nil + } + fields := rr.FieldDescriptions() + r.colNames = make([]string, len(fields)) + for i, fd := range fields { + r.colNames[i] = string(fd.Name) + } + } + return r.colNames +} + +func (r *sqlRowsMultiResultSet) Tag() (CommandTag, error) { + if rr := r.rows.ResultReader(); rr != nil { + // ResultReader may be nil if an empty query was executed. + return r.rows.ResultReader().Close() + } + return pgconn.CommandTag(""), nil +} + +func (r *sqlRowsMultiResultSet) Close() (retErr error) { + r.conn.flushNotices() + if rr := r.rows.ResultReader(); rr != nil { + // ResultReader may be nil if an empty query was executed. + _, retErr = r.rows.ResultReader().Close() + } + retErr = errors.CombineErrors(retErr, r.rows.Close()) + if r.conn.conn.IsClosed() { + r.conn.reconnecting = true + r.conn.silentClose() + return MarkWithConnectionClosed(retErr) + } + return retErr +} + +// Next implements the Rows interface. +func (r *sqlRowsMultiResultSet) Next(values []driver.Value) error { + if r.conn.conn.IsClosed() { + r.conn.reconnecting = true + r.conn.silentClose() + return ErrConnectionClosed + } + rd := r.rows.ResultReader() + if rd == nil { + // ResultReader may be nil if an empty query was executed. + return io.EOF + } + if !rd.NextRow() { + if _, err := rd.Close(); err != nil { + return err + } + return io.EOF + } + if len(rd.FieldDescriptions()) != len(values) { + return errors.AssertionFailedf( + "number of field descriptions must equal number of destinations, got %d and %d", + len(rd.FieldDescriptions()), + len(values), + ) + } + for i := range values { + rowVal := rd.Values()[i] + if rowVal == nil { + values[i] = "NULL" + continue + } + fieldOID := rd.FieldDescriptions()[i].DataTypeOID + fieldFormat := rd.FieldDescriptions()[i].Format + + // By scanning the value into a string, pgconn will use the pgwire + // text format to represent the value. + var s string + err := r.connInfo.Scan(fieldOID, fieldFormat, rowVal, &s) + if err != nil { + return pgx.ScanArgError{ColumnIndex: i, Err: err} + } + values[i] = s + } + // After the first row was received, we want to delay all + // further notices until the end of execution. + r.conn.delayNotices = true + return nil +} + +// NextResultSet prepares the next result set for reading. +func (r *sqlRowsMultiResultSet) NextResultSet() (bool, error) { + r.colNames = nil + next := r.rows.NextResult() + if !next { + if err := r.rows.Close(); err != nil { + return false, err + } + } + if r.conn.conn.IsClosed() { + r.conn.reconnecting = true + r.conn.silentClose() + return false, ErrConnectionClosed + } + return next, nil +} + +func (r *sqlRowsMultiResultSet) ColumnTypeScanType(index int) reflect.Type { + rd := r.rows.ResultReader() + o := rd.FieldDescriptions()[index].DataTypeOID + n := r.ColumnTypeDatabaseTypeName(index) + return scanType(o, n) +} + +func (r *sqlRowsMultiResultSet) ColumnTypeDatabaseTypeName(index int) string { + rd := r.rows.ResultReader() + fieldOID := rd.FieldDescriptions()[index].DataTypeOID + return databaseTypeName(r.connInfo, fieldOID) +} + +func (r *sqlRowsMultiResultSet) ColumnTypeNames() []string { + colTypes := make([]string, len(r.Columns())) + for i := range colTypes { + colTypes[i] = r.ColumnTypeDatabaseTypeName(i) + } + return colTypes +} diff --git a/pkg/cli/clisqlexec/BUILD.bazel b/pkg/cli/clisqlexec/BUILD.bazel index 720b621f0c6e..f4f2a6be749b 100644 --- a/pkg/cli/clisqlexec/BUILD.bazel +++ b/pkg/cli/clisqlexec/BUILD.bazel @@ -23,7 +23,7 @@ go_library( "//pkg/util/syncutil", "//pkg/util/timeutil", "@com_github_cockroachdb_errors//:errors", - "@com_github_lib_pq//:pq", + "@com_github_jackc_pgtype//:pgtype", "@com_github_olekukonko_tablewriter//:tablewriter", "@com_github_spf13_pflag//:pflag", ], diff --git a/pkg/cli/clisqlexec/format_table.go b/pkg/cli/clisqlexec/format_table.go index f2e8d20bec96..28a1f7fca6d7 100644 --- a/pkg/cli/clisqlexec/format_table.go +++ b/pkg/cli/clisqlexec/format_table.go @@ -181,13 +181,13 @@ func render( iter RowStrIter, completedHook func(), noRowsHook func() (bool, error), -) (err error) { +) (retErr error) { described := false nRows := 0 defer func() { // If the column headers are not printed yet, do it now. if !described { - err = errors.WithSecondaryError(err, r.describe(w, cols)) + retErr = errors.CombineErrors(retErr, r.describe(w, cols)) } // completedHook, if provided, is called unconditionally of error. @@ -198,18 +198,20 @@ func render( // We need to call doneNoRows/doneRows also unconditionally. var handled bool if nRows == 0 && noRowsHook != nil { - handled, err = noRowsHook() - if err != nil { + var noRowsErr error + handled, noRowsErr = noRowsHook() + if noRowsErr != nil { + retErr = errors.CombineErrors(retErr, noRowsErr) return } } if handled { - err = errors.WithSecondaryError(err, r.doneNoRows(w)) + retErr = errors.CombineErrors(retErr, r.doneNoRows(w)) } else { - err = errors.WithSecondaryError(err, r.doneRows(w, nRows)) + retErr = errors.CombineErrors(retErr, r.doneRows(w, nRows)) } - if err != nil && nRows > 0 { + if retErr != nil && nRows > 0 { fmt.Fprintf(ew, "(error encountered after some results were delivered)\n") } }() diff --git a/pkg/cli/clisqlexec/format_table_test.go b/pkg/cli/clisqlexec/format_table_test.go index 92a27b0a6c7b..0e22f0c5a962 100644 --- a/pkg/cli/clisqlexec/format_table_test.go +++ b/pkg/cli/clisqlexec/format_table_test.go @@ -64,7 +64,7 @@ thenshort`, // thenshort" int, "κόσμε" int, "a|b" int, ܈85 int) // CREATE TABLE // sql -e insert into t.u values (0, 0, 0, 0, 0, 0, 0, 0) - // INSERT 1 + // INSERT 0 1 // sql -e show columns from t.u // column_name data_type is_nullable column_default generation_expression indices // "f""oo" INT true NULL {} @@ -193,7 +193,11 @@ func Example_sql_empty_table() { // Output: // sql -e create database t;create table t.norows(x int);create table t.nocolsnorows();create table t.nocols(); insert into t.nocols(rowid) values (1),(2),(3); - // INSERT 3 + // CREATE DATABASE + // CREATE TABLE + // CREATE TABLE + // CREATE TABLE + // INSERT 0 3 // sql --format=tsv -e select * from t.norows // x // sql --format=csv -e select * from t.norows @@ -489,25 +493,26 @@ func Example_sql_table() { // Output: // sql -e create database t; create table t.t (s string, d string); + // CREATE DATABASE // CREATE TABLE // sql -e insert into t.t values (e'foo', 'printable ASCII') - // INSERT 1 + // INSERT 0 1 // sql -e insert into t.t values (e'"foo', 'printable ASCII with quotes') - // INSERT 1 + // INSERT 0 1 // sql -e insert into t.t values (e'\\foo', 'printable ASCII with backslash') - // INSERT 1 + // INSERT 0 1 // sql -e insert into t.t values (e'foo\x0abar', 'non-printable ASCII') - // INSERT 1 + // INSERT 0 1 // sql -e insert into t.t values ('κόσμε', 'printable UTF8') - // INSERT 1 + // INSERT 0 1 // sql -e insert into t.t values (e'\xc3\xb1', 'printable UTF8 using escapes') - // INSERT 1 + // INSERT 0 1 // sql -e insert into t.t values (e'\x01', 'non-printable UTF8 string') - // INSERT 1 + // INSERT 0 1 // sql -e insert into t.t values (e'\xdc\x88\x38\x35', 'UTF8 string with RTL char') - // INSERT 1 + // INSERT 0 1 // sql -e insert into t.t values (e'a\tb\tc\n12\t123123213\t12313', 'tabs') - // INSERT 1 + // INSERT 0 1 // sql -e insert into t.t values (e'\xc3\x28', 'non-UTF8 string') // ERROR: lexical error: invalid UTF-8 byte sequence // SQLSTATE: 42601 diff --git a/pkg/cli/clisqlexec/format_value.go b/pkg/cli/clisqlexec/format_value.go index 6697d8b48ad8..41449c6d09d9 100644 --- a/pkg/cli/clisqlexec/format_value.go +++ b/pkg/cli/clisqlexec/format_value.go @@ -12,7 +12,6 @@ package clisqlexec import ( "bytes" - gosql "database/sql" "database/sql/driver" "fmt" "math" @@ -26,7 +25,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/lexbase" "github.com/cockroachdb/cockroach/pkg/util/timeutil" - "github.com/lib/pq" + "github.com/jackc/pgtype" ) func isNotPrintableASCII(r rune) bool { return r < 0x20 || r > 0x7e || r == '"' || r == '\\' } @@ -44,7 +43,6 @@ func FormatVal( if strings.HasPrefix(colType, "_") && len(b) > 0 && b[0] == '{' { return formatArray(b, colType[1:], showPrintableUnicode, showNewLinesAndTabs) } - // Names, records, and user-defined types should all be displayed as strings. if colType == "NAME" || colType == "RECORD" || colType == "" { val = string(b) @@ -88,19 +86,13 @@ func FormatVal( return s[1 : len(s)-1] case []byte: - // Format the bytes as per bytea_output = escape. - // - // We use the "escape" format here because it enables printing - // readable strings as-is -- the default hex format would always - // render as hexadecimal digits. The escape format is also more - // compact. + // For other []byte types that weren't handled above, we use the "escape" + // format because it enables printing readable strings as-is -- the default + // hex format would always render as hexadecimal digits. The escape format + // is also more compact. // - // TODO(knz): this formatting is unfortunate/incorrect, and exists - // only because lib/pq incorrectly interprets the bytes received - // from the server. The proper behavior would be for the driver to - // not interpret the bytes and for us here to print that as-is, so - // that we can let the user see and control the result using - // `bytea_output`. + // Note that the BYTEA type is already a string at this point, so is not + // handled here. var buf bytes.Buffer lexbase.EncodeSQLBytesInner(&buf, string(t)) return buf.String() @@ -130,41 +122,45 @@ func formatArray( // backingArray is the array we're going to parse the server data // into. var backingArray interface{} - // parsingArray is a helper structure provided by lib/pq to parse + + // parsingArray is a helper structure provided by pgtype to parse // arrays. - var parsingArray gosql.Scanner + var parsingArray pgtype.Value - // lib.pq has different array parsers for special value types. - // - // TODO(knz): This would better use a general-purpose parser - // using the OID to look up an array parser in crdb's sql package. - // However, unfortunately the OID is hidden from us. + // pgx has different array parsers for special value types. switch colType { case "BOOL": boolArray := []bool{} backingArray = &boolArray - parsingArray = (*pq.BoolArray)(&boolArray) + parsingArray = &pgtype.BoolArray{} case "FLOAT4", "FLOAT8": floatArray := []float64{} backingArray = &floatArray - parsingArray = (*pq.Float64Array)(&floatArray) + parsingArray = &pgtype.Float8Array{} case "INT2", "INT4", "INT8", "OID": intArray := []int64{} backingArray = &intArray - parsingArray = (*pq.Int64Array)(&intArray) + parsingArray = &pgtype.Int8Array{} case "TEXT", "VARCHAR", "NAME", "CHAR", "BPCHAR", "RECORD": stringArray := []string{} backingArray = &stringArray - parsingArray = (*pq.StringArray)(&stringArray) - default: - genArray := [][]byte{} - backingArray = &genArray - parsingArray = &pq.GenericArray{A: &genArray} + parsingArray = &pgtype.TextArray{} } - // Now ask the pq array parser to convert the byte slice + // Now ask the pgx array parser to convert the byte slice // from the server into a Go array. - if err := parsingArray.Scan(b); err != nil { + var parseErr error + if parsingArray != nil { + parseErr = parsingArray.(pgtype.TextDecoder).DecodeText(nil, b) + if parseErr == nil { + parseErr = parsingArray.AssignTo(backingArray) + } + } else { + var untypedArray *pgtype.UntypedTextArray + untypedArray, parseErr = pgtype.ParseUntypedTextArray(string(b)) + backingArray = &untypedArray.Elements + } + if parseErr != nil { // A parsing failure is not a catastrophe; we can still print out // the array as a byte slice. This will do in many cases. return FormatVal(b, "BYTEA", showPrintableUnicode, showNewLinesAndTabs) diff --git a/pkg/cli/clisqlexec/format_value_test.go b/pkg/cli/clisqlexec/format_value_test.go index 5c93acef4561..d792f7f4a9cb 100644 --- a/pkg/cli/clisqlexec/format_value_test.go +++ b/pkg/cli/clisqlexec/format_value_test.go @@ -38,12 +38,22 @@ func Example_sql_format() { c.RunWithArgs([]string{"sql", "-e", `select '{"(n,1)","(\n,2)","(\\n,3)"}'::tup[]`}) c.RunWithArgs([]string{"sql", "-e", `create type e as enum('a', '\n')`}) c.RunWithArgs([]string{"sql", "-e", `select '\n'::e, '{a, "\\n"}'::e[]`}) + // Check that intervals are formatted correctly. + c.RunWithArgs([]string{"sql", "-e", `select '1 day 2 minutes'::interval`}) + c.RunWithArgs([]string{"sql", "-e", `select '{3 months 4 days 2 hours, 4 years 1 second}'::interval[]`}) + c.RunWithArgs([]string{"sql", "-e", `SET intervalstyle = sql_standard; select '1 day 2 minutes'::interval`}) + c.RunWithArgs([]string{"sql", "-e", `SET intervalstyle = sql_standard; select '{3 months 4 days 2 hours,4 years 1 second}'::interval[]`}) + c.RunWithArgs([]string{"sql", "-e", `SET intervalstyle = iso_8601; select '1 day 2 minutes'::interval`}) + c.RunWithArgs([]string{"sql", "-e", `SET intervalstyle = iso_8601; select '{3 months 4 days 2 hours,4 years 1 second}'::interval[]`}) + // Check that UUIDs are formatted correctly. + c.RunWithArgs([]string{"sql", "-e", `select 'f2b046d8-dc59-11ec-8b22-cbf9a9dd2f5f'::uuid`}) // Output: // sql -e create database t; create table t.times (bare timestamp, withtz timestamptz) + // CREATE DATABASE // CREATE TABLE // sql -e insert into t.times values ('2016-01-25 10:10:10', '2016-01-25 10:10:10-05:00') - // INSERT 1 + // INSERT 0 1 // sql -e select bare from t.times; select withtz from t.times // bare // 2016-01-25 10:10:10 @@ -81,7 +91,7 @@ func Example_sql_format() { // {Infinity,-Infinity} {Infinity} // sql -e select array[true, false], array['01:01'::time], array['2021-03-20'::date] // array array array - // {true,false} {01:01:00} {2021-03-20} + // {t,f} {01:01:00} {2021-03-20} // sql -e select array[123::int2], array[123::int4], array[123::int8] // array array array // {123} {123} {123} @@ -98,4 +108,29 @@ func Example_sql_format() { // sql -e select '\n'::e, '{a, "\\n"}'::e[] // e e // \n "{a,""\\n""}" + // sql -e select '1 day 2 minutes'::interval + // interval + // 1 day 00:02:00 + // sql -e select '{3 months 4 days 2 hours, 4 years 1 second}'::interval[] + // interval + // "{""3 mons 4 days 02:00:00"",""4 years 00:00:01""}" + // sql -e SET intervalstyle = sql_standard; select '1 day 2 minutes'::interval + // SET + // interval + // 1 0:02:00 + // sql -e SET intervalstyle = sql_standard; select '{3 months 4 days 2 hours,4 years 1 second}'::interval[] + // SET + // interval + // "{""+0-3 +4 +2:00:00"",""+4-0 +0 +0:00:01""}" + // sql -e SET intervalstyle = iso_8601; select '1 day 2 minutes'::interval + // SET + // interval + // P1DT2M + // sql -e SET intervalstyle = iso_8601; select '{3 months 4 days 2 hours,4 years 1 second}'::interval[] + // SET + // interval + // {P3M4DT2H,P4YT1S} + // sql -e select 'f2b046d8-dc59-11ec-8b22-cbf9a9dd2f5f'::uuid + // uuid + // f2b046d8-dc59-11ec-8b22-cbf9a9dd2f5f } diff --git a/pkg/cli/clisqlexec/run_query.go b/pkg/cli/clisqlexec/run_query.go index 3754a36d6169..9f52f3112f6f 100644 --- a/pkg/cli/clisqlexec/run_query.go +++ b/pkg/cli/clisqlexec/run_query.go @@ -12,7 +12,6 @@ package clisqlexec import ( "context" - "database/sql/driver" "fmt" "io" "strings" @@ -28,13 +27,13 @@ import ( // It runs the sql query and returns a list of columns names and a list of rows. func (sqlExecCtx *Context) RunQuery( ctx context.Context, conn clisqlclient.Conn, fn clisqlclient.QueryFn, showMoreChars bool, -) ([]string, [][]string, error) { +) (retCols []string, retRows [][]string, retErr error) { rows, _, err := fn(ctx, conn) if err != nil { return nil, nil, err } - defer func() { _ = rows.Close() }() + defer func() { retErr = errors.CombineErrors(retErr, rows.Close()) }() return sqlRowsToStrings(rows, showMoreChars) } @@ -54,7 +53,7 @@ func (sqlExecCtx *Context) RunQueryAndFormatResults( err = errors.CombineErrors(err, closeErr) }() for { - // lib/pq is not able to tell us before the first call to Next() + // pgx is not able to tell us before the first call to Next() // whether a statement returns either // - a rows result set with zero rows (e.g. SELECT on an empty table), or // - no rows result set, but a valid value for RowsAffected (e.g. INSERT), or @@ -65,57 +64,30 @@ func (sqlExecCtx *Context) RunQueryAndFormatResults( // when Next() has completed its work and no rows where observed, to decide // what to do. noRowsHook := func() (bool, error) { - res := rows.Result() - if ra, ok := res.(driver.RowsAffected); ok { - nRows, err := ra.RowsAffected() - if err != nil { - return false, err - } + tag, err := rows.Tag() + if err != nil { + return false, err + } + nRows := tag.RowsAffected() + tagString := tag.String() - // This may be either something like INSERT with a valid - // RowsAffected value, or a statement like SET. The pq driver - // uses both driver.RowsAffected for both. So we need to be a - // little more manual. - tag := rows.Tag() - if tag == "SELECT" && nRows == 0 { - // As explained above, the pq driver unhelpfully does not - // distinguish between a statement returning zero rows and a - // statement returning an affected row count of zero. - // noRowsHook is called non-discriminatingly for both - // situations. - // - // TODO(knz): meanwhile, there are rare, non-SELECT - // statements that have tag "SELECT" but are legitimately of - // type RowsAffected. CREATE TABLE AS is one. pq's inability - // to distinguish those two cases means that any non-SELECT - // statement that legitimately returns 0 rows affected, and - // for which the user would expect to see "SELECT 0", will - // be incorrectly displayed as an empty row result set - // instead. This needs to be addressed by ensuring pq can - // distinguish the two cases, or switching to an entirely - // different driver altogether. - // - return false, nil - } else if _, ok := tagsWithRowsAffected[tag]; ok { - // INSERT, DELETE, etc.: print the row count. - nRows, err := ra.RowsAffected() - if err != nil { - return false, err - } - fmt.Fprintf(w, "%s %d\n", tag, nRows) - } else { - // SET, etc.: just print the tag, or OK if there's no tag. - if tag == "" { - tag = "OK" - } - fmt.Fprintln(w, tag) - } - return true, nil + // This may be either something like INSERT with a valid + // RowsAffected value, or a statement like SET. + if strings.HasPrefix(tagString, "SELECT") && nRows == 0 { + // As explained above, the pgx driver unhelpfully does not + // distinguish between a statement returning zero rows and a + // statement returning an affected row count of zero. + // noRowsHook is called non-discriminatingly for both + // situations. + return false, nil + } + // SET, etc.: just print the tag, or OK if there's no tag. + // INSERT, DELETE, etc. all contain the row count in the tag itself. + if tagString == "" { + tagString = "OK" } - // Other cases: this is a statement with a rows result set, but - // zero rows (e.g. SELECT on empty table). Let the reporter - // handle it. - return false, nil + fmt.Fprintln(w, tagString) + return true, nil } cols := getColumnStrings(rows, true) @@ -136,11 +108,12 @@ func (sqlExecCtx *Context) RunQueryAndFormatResults( return err } - sqlExecCtx.maybeShowTimes(ctx, conn, w, ew, isMultiStatementQuery, startTime, queryCompleteTime) - if more, err := rows.NextResultSet(); err != nil { return err } else if !more { + // We must call maybeShowTimes after rows has been closed, which is after + // NextResultSet returns false. + sqlExecCtx.maybeShowTimes(ctx, conn, w, ew, isMultiStatementQuery, startTime, queryCompleteTime) return nil } } @@ -274,20 +247,6 @@ func (sqlExecCtx *Context) maybeShowTimes( fmt.Fprintln(w, stats.String()) } -// All tags where the RowsAffected value should be reported to -// the user. -var tagsWithRowsAffected = map[string]struct{}{ - "INSERT": {}, - "UPDATE": {}, - "DELETE": {}, - "MOVE": {}, - "DROP USER": {}, - "COPY": {}, - // This one is used with e.g. CREATE TABLE AS (other SELECT - // statements have type Rows, not RowsAffected). - "SELECT": {}, -} - // sqlRowsToStrings turns 'rows' into a list of rows, each of which // is a list of column values. // 'rows' should be closed by the caller. diff --git a/pkg/cli/clisqlexec/run_query_test.go b/pkg/cli/clisqlexec/run_query_test.go index 66bfc04a2fc0..35550ed54bb6 100644 --- a/pkg/cli/clisqlexec/run_query_test.go +++ b/pkg/cli/clisqlexec/run_query_test.go @@ -98,10 +98,10 @@ SET } expectedRows := [][]string{ - {`parentID`, `INT8`, `false`, `NULL`, ``, `{primary}`, `false`}, - {`parentSchemaID`, `INT8`, `false`, `NULL`, ``, `{primary}`, `false`}, - {`name`, `STRING`, `false`, `NULL`, ``, `{primary}`, `false`}, - {`id`, `INT8`, `true`, `NULL`, ``, `{primary}`, `false`}, + {`parentID`, `INT8`, `f`, `NULL`, ``, `{primary}`, `f`}, + {`parentSchemaID`, `INT8`, `f`, `NULL`, ``, `{primary}`, `f`}, + {`name`, `STRING`, `f`, `NULL`, ``, `{primary}`, `f`}, + {`id`, `INT8`, `t`, `NULL`, ``, `{primary}`, `f`}, } if !reflect.DeepEqual(expectedRows, rows) { t.Fatalf("expected:\n%v\ngot:\n%v", expectedRows, rows) @@ -115,10 +115,10 @@ SET expected = ` column_name | data_type | is_nullable | column_default | generation_expression | indices | is_hidden -----------------+-----------+-------------+----------------+-----------------------+-----------+------------ - parentID | INT8 | false | NULL | | {primary} | false - parentSchemaID | INT8 | false | NULL | | {primary} | false - name | STRING | false | NULL | | {primary} | false - id | INT8 | true | NULL | | {primary} | false + parentID | INT8 | f | NULL | | {primary} | f + parentSchemaID | INT8 | f | NULL | | {primary} | f + name | STRING | f | NULL | | {primary} | f + id | INT8 | t | NULL | | {primary} | f (4 rows) ` diff --git a/pkg/cli/clisqlshell/sql_test.go b/pkg/cli/clisqlshell/sql_test.go index 351ce7a8965f..c1ddf17a4569 100644 --- a/pkg/cli/clisqlshell/sql_test.go +++ b/pkg/cli/clisqlshell/sql_test.go @@ -67,7 +67,9 @@ func Example_sql() { // application_name // $ cockroach sql // sql -e create database t; create table t.f (x int, y int); insert into t.f values (42, 69) - // INSERT 1 + // CREATE DATABASE + // CREATE TABLE + // INSERT 0 1 // sql -e select 3 as "3" -e select * from t.f // 3 // 3 @@ -110,6 +112,8 @@ func Example_sql() { // sql -d nonexistent -e select count(*) from "".information_schema.tables limit 0 // count // sql -d nonexistent -e create database nonexistent; create table foo(x int); select * from foo + // CREATE DATABASE + // CREATE TABLE // x // sql -e copy t.f from stdin // sql -e select 1/(@1-2) from generate_series(1,3) @@ -122,6 +126,7 @@ func Example_sql() { // regression_65066 // 20:01:02+03:04:05 // sql -e CREATE USER my_user WITH CREATEDB; GRANT admin TO my_user; + // CREATE ROLE // GRANT // sql -e \du my_user // username options member_of @@ -186,7 +191,8 @@ func Example_sql_watch() { // Output: // sql -e create table d(x int); insert into d values(3) - // INSERT 1 + // CREATE TABLE + // INSERT 0 1 // sql --watch .1s -e update d set x=x-1 returning 1/x as dec // dec // 0.50000000000000000000 @@ -206,6 +212,7 @@ func Example_misc_table() { // Output: // sql -e create database t; create table t.t (s string, d string); + // CREATE DATABASE // CREATE TABLE // sql --format=table -e select ' hai' as x // x @@ -243,7 +250,9 @@ func Example_in_memory() { // Output: // sql -e create database t; create table t.f (x int, y int); insert into t.f values (42, 69) - // INSERT 1 + // CREATE DATABASE + // CREATE TABLE + // INSERT 0 1 // node ls // id // 1 @@ -264,15 +273,16 @@ func Example_pretty_print_numerical_strings() { // Output: // sql -e create database t; create table t.t (s string, d string); + // CREATE DATABASE // CREATE TABLE // sql -e insert into t.t values (e'0', 'positive numerical string') - // INSERT 1 + // INSERT 0 1 // sql -e insert into t.t values (e'-1', 'negative numerical string') - // INSERT 1 + // INSERT 0 1 // sql -e insert into t.t values (e'1.0', 'decimal numerical string') - // INSERT 1 + // INSERT 0 1 // sql -e insert into t.t values (e'aaaaa', 'non-numerical string') - // INSERT 1 + // INSERT 0 1 // sql --format=table -e select * from t.t // s | d // --------+---------------------------- @@ -300,7 +310,7 @@ func Example_read_from_file() { // SET // CREATE TABLE // > INSERT INTO test(s) VALUES ('hello'), ('world'); - // INSERT 2 + // INSERT 0 2 // > SELECT * FROM test; // s // hello diff --git a/pkg/cli/context.go b/pkg/cli/context.go index 2a2c6aa0678e..b33929d84c85 100644 --- a/pkg/cli/context.go +++ b/pkg/cli/context.go @@ -212,7 +212,7 @@ func setCliContextDefaults() { cliCtx.IsInteractive = false cliCtx.EmbeddedMode = false cliCtx.cmdTimeout = 0 // no timeout - cliCtx.clientOpts.ServerHost = "" + cliCtx.clientOpts.ServerHost = getDefaultHost() cliCtx.clientOpts.ServerPort = base.DefaultPort cliCtx.certPrincipalMap = nil cliCtx.clientOpts.ExplicitURL = nil diff --git a/pkg/cli/debug_job_trace.go b/pkg/cli/debug_job_trace.go index 25e90dffdaa2..fe6db60a65e0 100644 --- a/pkg/cli/debug_job_trace.go +++ b/pkg/cli/debug_job_trace.go @@ -94,7 +94,7 @@ func constructJobTraceZipBundle(ctx context.Context, sqlConn clisqlclient.Conn, return err } - zipper := tracezipper.MakeSQLConnInflightTraceZipper(sqlConn.GetDriverConn().(driver.QueryerContext)) + zipper := tracezipper.MakeSQLConnInflightTraceZipper(sqlConn.GetDriverConn()) zipBytes, err := zipper.Zip(ctx, traceID) if err != nil { return err diff --git a/pkg/cli/demo.go b/pkg/cli/demo.go index 0581d022325f..b9f6ec7eb7d7 100644 --- a/pkg/cli/demo.go +++ b/pkg/cli/demo.go @@ -349,5 +349,5 @@ func runDemo(cmd *cobra.Command, gen workload.Generator) (resErr error) { defer func() { resErr = errors.CombineErrors(resErr, conn.Close()) }() sqlCtx.ShellCtx.ParseURL = clienturl.MakeURLParserFn(cmd, cliCtx.clientOpts) - return sqlCtx.Run(conn) + return sqlCtx.Run(ctx, conn) } diff --git a/pkg/cli/doctor.go b/pkg/cli/doctor.go index bd29d0683a21..5893818b1bd5 100644 --- a/pkg/cli/doctor.go +++ b/pkg/cli/doctor.go @@ -40,7 +40,8 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/protoutil" "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/errors" - "github.com/lib/pq" + "github.com/jackc/pgconn" + "github.com/jackc/pgtype" "github.com/spf13/cobra" ) @@ -198,6 +199,10 @@ func fromCluster( retErr error, ) { ctx := context.Background() + if err := sqlConn.EnsureConn(ctx); err != nil { + return nil, nil, nil, err + } + if timeout != 0 { if err := sqlConn.Exec(ctx, `SET statement_timeout = $1`, timeout.String()); err != nil { @@ -211,8 +216,8 @@ FROM system.descriptor ORDER BY id` _, err := sqlConn.QueryRow(ctx, checkColumnExistsStmt) // On versions before 20.2, the system.descriptor won't have the builtin // crdb_internal_mvcc_timestamp. If we can't find it, use NULL instead. - if pqErr := (*pq.Error)(nil); errors.As(err, &pqErr) { - if pgcode.MakeCode(string(pqErr.Code)) == pgcode.UndefinedColumn { + if pgErr := (*pgconn.PgError)(nil); errors.As(err, &pgErr) { + if pgcode.MakeCode(pgErr.Code) == pgcode.UndefinedColumn { stmt = ` SELECT id, descriptor, NULL AS mod_time_logical FROM system.descriptor ORDER BY id` @@ -236,8 +241,12 @@ FROM system.descriptor ORDER BY id` } if vals[2] == nil { row.ModTime = hlc.Timestamp{WallTime: timeutil.Now().UnixNano()} - } else if mt, ok := vals[2].([]byte); ok { - decimal, _, err := apd.NewFromString(string(mt)) + } else if mt, ok := vals[2].(pgtype.Numeric); ok { + buf, err := mt.EncodeText(nil, nil) + if err != nil { + return err + } + decimal, _, err := apd.NewFromString(string(buf)) if err != nil { return err } @@ -261,8 +270,8 @@ FROM system.descriptor ORDER BY id` _, err = sqlConn.QueryRow(ctx, checkColumnExistsStmt) // On versions before 20.1, table system.namespace does not have this column. // In that case the ParentSchemaID for tables is 29 and for databases is 0. - if pqErr := (*pq.Error)(nil); errors.As(err, &pqErr) { - if pgcode.MakeCode(string(pqErr.Code)) == pgcode.UndefinedColumn { + if pgErr := (*pgconn.PgError)(nil); errors.As(err, &pgErr) { + if pgcode.MakeCode(pgErr.Code) == pgcode.UndefinedColumn { stmt = ` SELECT "parentID", CASE WHEN "parentID" = 0 THEN 0 ELSE 29 END AS "parentSchemaID", name, id FROM system.namespace` diff --git a/pkg/cli/env.go b/pkg/cli/env.go new file mode 100644 index 000000000000..c7f0f1b19c39 --- /dev/null +++ b/pkg/cli/env.go @@ -0,0 +1,24 @@ +// Copyright 2022 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package cli + +import "os" + +// getDefaultHost gets the default value for the host to connect to. pgx already +// has logic that would inspect PGHOST, but the problem is that if PGHOST not +// defined, pgx would default to using a unix socket. That is not desired, so +// here we make the CLI fallback to use "localhost" if PGHOST is not defined. +func getDefaultHost() string { + if h := os.Getenv("PGHOST"); h == "" { + return "localhost" + } + return "" +} diff --git a/pkg/cli/flags_test.go b/pkg/cli/flags_test.go index 4e95fdb48b85..ef9d1aca6e1c 100644 --- a/pkg/cli/flags_test.go +++ b/pkg/cli/flags_test.go @@ -972,7 +972,7 @@ func TestClientConnSettings(t *testing.T) { args []string expectedAddr string }{ - {[]string{"quit"}, ":" + base.DefaultPort}, + {[]string{"quit"}, "localhost:" + base.DefaultPort}, {[]string{"quit", "--host", "127.0.0.1"}, "127.0.0.1:" + base.DefaultPort}, {[]string{"quit", "--host", "192.168.0.111"}, "192.168.0.111:" + base.DefaultPort}, {[]string{"quit", "--host", ":12345"}, ":12345"}, @@ -986,7 +986,7 @@ func TestClientConnSettings(t *testing.T) { "[2622:6221:e663:4922:fc2b:788b:fadd:7b48]:" + base.DefaultPort}, // Deprecated syntax. - {[]string{"quit", "--port", "12345"}, ":12345"}, + {[]string{"quit", "--port", "12345"}, "localhost:12345"}, {[]string{"quit", "--host", "127.0.0.1", "--port", "12345"}, "127.0.0.1:12345"}, } diff --git a/pkg/cli/import.go b/pkg/cli/import.go index a9f0588f0375..a0fbfbf079c6 100644 --- a/pkg/cli/import.go +++ b/pkg/cli/import.go @@ -111,7 +111,7 @@ func runImport( importFormat, source, tableName string, mode importMode, ) error { - if err := conn.EnsureConn(); err != nil { + if err := conn.EnsureConn(ctx); err != nil { return err } @@ -216,7 +216,7 @@ func runImport( return nil } - _, err = ex.ExecContext(ctx, importQuery, nil) + err = ex.Exec(ctx, importQuery) if err != nil { return err } diff --git a/pkg/cli/interactive_tests/test_client_side_checking.tcl b/pkg/cli/interactive_tests/test_client_side_checking.tcl index 3f6556de09a4..8c61e80f75f5 100644 --- a/pkg/cli/interactive_tests/test_client_side_checking.tcl +++ b/pkg/cli/interactive_tests/test_client_side_checking.tcl @@ -28,7 +28,7 @@ end_test start_test "Check that the syntax checker does not get confused by empty inputs." # (issue #22441.) send ";\r" -eexpect "0 rows" +eexpect "OK" eexpect root@ end_test diff --git a/pkg/cli/interactive_tests/test_connect_cmd.tcl b/pkg/cli/interactive_tests/test_connect_cmd.tcl index f3b86742fc44..4268af2ae4dc 100644 --- a/pkg/cli/interactive_tests/test_connect_cmd.tcl +++ b/pkg/cli/interactive_tests/test_connect_cmd.tcl @@ -132,7 +132,7 @@ start_test "Check that the client-side connect cmd can change users with certs u # first test that it can recover from an invalid database send "\\c postgres://root@localhost:26257/invaliddb?sslmode=require&sslcert=$certs_dir%2Fclient.root.crt&sslkey=$certs_dir%2Fclient.root.key&sslrootcert=$certs_dir%2Fca.crt\r" eexpect "using new connection URL" -eexpect "error retrieving the database name: pq: database \"invaliddb\" does not exist" +eexpect "error retrieving the database name: ERROR: database \"invaliddb\" does not exist" eexpect root@ eexpect "?>" diff --git a/pkg/cli/interactive_tests/test_demo_node_cmds.tcl b/pkg/cli/interactive_tests/test_demo_node_cmds.tcl index f74640c24bb1..2f5f0b9d4794 100644 --- a/pkg/cli/interactive_tests/test_demo_node_cmds.tcl +++ b/pkg/cli/interactive_tests/test_demo_node_cmds.tcl @@ -42,11 +42,11 @@ eexpect "node 3 has been shutdown" eexpect "movr>" send "select node_id, draining, decommissioning, membership from crdb_internal.gossip_liveness ORDER BY node_id;\r" -eexpect "1 | false | false | active" -eexpect "2 | false | false | active" -eexpect "3 | true | false | active" -eexpect "4 | false | false | active" -eexpect "5 | false | false | active" +eexpect "1 | f | f | active" +eexpect "2 | f | f | active" +eexpect "3 | t | f | active" +eexpect "4 | f | f | active" +eexpect "5 | f | f | active" eexpect "movr>" # Cannot shut it down again. diff --git a/pkg/cli/interactive_tests/test_reconnect.tcl b/pkg/cli/interactive_tests/test_reconnect.tcl index d2ed73c1e795..2c63efee88cb 100644 --- a/pkg/cli/interactive_tests/test_reconnect.tcl +++ b/pkg/cli/interactive_tests/test_reconnect.tcl @@ -23,7 +23,7 @@ start_test "Check that the client properly detects the server went down" force_stop_server $argv send "select 1;\r" -eexpect "bad connection" +eexpect "connection closed unexpectedly" eexpect root@ eexpect " ?>" end_test @@ -56,7 +56,7 @@ stop_server $argv start_server $argv send "select 1;\r" -eexpect "bad connection" +eexpect "connection closed unexpectedly" eexpect root@ send "select 1;\r" diff --git a/pkg/cli/interactive_tests/test_sql_version_reporting.tcl b/pkg/cli/interactive_tests/test_sql_version_reporting.tcl index f82164fde477..199728d28430 100644 --- a/pkg/cli/interactive_tests/test_sql_version_reporting.tcl +++ b/pkg/cli/interactive_tests/test_sql_version_reporting.tcl @@ -31,7 +31,7 @@ start_test "Check that a reconnect without version change is quiet." force_stop_server $argv start_server $argv send "select 1;\r" -eexpect "driver: bad connection" +eexpect "connection closed unexpectedly" # Check that the prompt immediately succeeds the error message eexpect "connection lost" eexpect "opening new connection: all session settings will be lost" diff --git a/pkg/cli/interactive_tests/test_url_db_override.tcl b/pkg/cli/interactive_tests/test_url_db_override.tcl index 75e053d51daf..407cac6e3fd0 100644 --- a/pkg/cli/interactive_tests/test_url_db_override.tcl +++ b/pkg/cli/interactive_tests/test_url_db_override.tcl @@ -38,7 +38,7 @@ start_test "Check that the insecure flag overrides the sslmode if URL is already set ::env(COCKROACH_INSECURE) "false" spawn $argv sql --url "postgresql://test@localhost:26257?sslmode=verify-full" --certs-dir=$certs_dir -e "select 1" -eexpect "SSL is not enabled on the server" +eexpect "server refused TLS connection" eexpect eof spawn $argv sql --url "postgresql://test@localhost:26257?sslmode=verify-full" --certs-dir=$certs_dir --insecure -e "select 1" diff --git a/pkg/cli/node.go b/pkg/cli/node.go index 3fc671a2eaa1..d12abe2c8b0d 100644 --- a/pkg/cli/node.go +++ b/pkg/cli/node.go @@ -228,6 +228,9 @@ FROM crdb_internal.gossip_liveness LEFT JOIN crdb_internal.gossip_nodes USING (n } ctx := context.Background() + if err = conn.EnsureConn(ctx); err != nil { + return nil, nil, err + } // TODO(knz): This can use a context deadline instead, now that // query cancellation is supported. diff --git a/pkg/cli/nodelocal.go b/pkg/cli/nodelocal.go index 141aaf83dfe4..fa4f9ed14e42 100644 --- a/pkg/cli/nodelocal.go +++ b/pkg/cli/nodelocal.go @@ -11,8 +11,8 @@ package cli import ( + "bytes" "context" - "database/sql/driver" "fmt" "io" "net/url" @@ -27,9 +27,7 @@ import ( "github.com/spf13/cobra" ) -const ( - chunkSize = 4 * 1024 -) +const chunkSize = 4 * 1024 var nodeLocalUploadCmd = &cobra.Command{ Use: "upload ", @@ -74,17 +72,56 @@ func openSourceFile(source string) (io.ReadCloser, error) { return f, nil } +// appendEscapedText escapes the input text for processing by the pgwire COPY +// protocol. The result is appended to the []byte given by buf. +// This implementation is copied from lib/pq. +// https://github.com/lib/pq/blob/8c6de565f76fb5cd40a5c1b8ce583fbc3ba1bd0e/encode.go#L138 +func appendEscapedText(buf []byte, text string) []byte { + escapeNeeded := false + startPos := 0 + var c byte + + // check if we need to escape + for i := 0; i < len(text); i++ { + c = text[i] + if c == '\\' || c == '\n' || c == '\r' || c == '\t' { + escapeNeeded = true + startPos = i + break + } + } + if !escapeNeeded { + return append(buf, text...) + } + + // copy till first char to escape, iterate the rest + result := append(buf, text[:startPos]...) + for i := startPos; i < len(text); i++ { + c = text[i] + switch c { + case '\\': + result = append(result, '\\', '\\') + case '\n': + result = append(result, '\\', 'n') + case '\r': + result = append(result, '\\', 'r') + case '\t': + result = append(result, '\\', 't') + default: + result = append(result, c) + } + } + return result +} + func uploadFile( ctx context.Context, conn clisqlclient.Conn, reader io.Reader, destination string, ) error { - if err := conn.EnsureConn(); err != nil { + if err := conn.EnsureConn(ctx); err != nil { return err } ex := conn.GetDriverConn() - if _, err := ex.ExecContext(ctx, `BEGIN`, nil); err != nil { - return err - } // Construct the nodelocal URI as the destination for the CopyIn stmt. nodelocalURL := url.URL{ @@ -92,42 +129,22 @@ func uploadFile( Host: "self", Path: destination, } - stmt, err := conn.GetDriverConn().Prepare(sql.CopyInFileStmt(nodelocalURL.String(), sql.CrdbInternalName, - sql.NodelocalFileUploadTable)) - if err != nil { - return err - } + stmt := sql.CopyInFileStmt(nodelocalURL.String(), sql.CrdbInternalName, sql.NodelocalFileUploadTable) - defer func() { - if stmt != nil { - _ = stmt.Close() - _, _ = ex.ExecContext(ctx, `ROLLBACK`, nil) - } - }() - - send := make([]byte, chunkSize) + send := make([]byte, 0) + tmp := make([]byte, chunkSize) for { - n, err := reader.Read(send) + n, err := reader.Read(tmp) if n > 0 { - // TODO(adityamaru): Switch to StmtExecContext once the copyin driver - // supports it. - //lint:ignore SA1019 DriverConn doesn't support go 1.8 API - _, err = stmt.Exec([]driver.Value{string(send[:n])}) - if err != nil { - return err - } + send = appendEscapedText(send, string(tmp[:n])) } else if err == io.EOF { break } else if err != nil { return err } } - if err := stmt.Close(); err != nil { - return err - } - stmt = nil - if _, err := ex.ExecContext(ctx, `COMMIT`, nil); err != nil { + if err := ex.CopyFrom(ctx, bytes.NewReader(send), stmt); err != nil { return err } diff --git a/pkg/cli/sql_shell_cmd.go b/pkg/cli/sql_shell_cmd.go index edcc20fdf915..5b8ae110a2df 100644 --- a/pkg/cli/sql_shell_cmd.go +++ b/pkg/cli/sql_shell_cmd.go @@ -11,6 +11,7 @@ package cli import ( + "context" "fmt" "os" @@ -59,5 +60,5 @@ func runTerm(cmd *cobra.Command, args []string) (resErr error) { defer func() { resErr = errors.CombineErrors(resErr, conn.Close()) }() sqlCtx.ShellCtx.ParseURL = clienturl.MakeURLParserFn(cmd, cliCtx.clientOpts) - return sqlCtx.Run(conn) + return sqlCtx.Run(context.Background(), conn) } diff --git a/pkg/cli/statement_bundle.go b/pkg/cli/statement_bundle.go index 9bc5952d2c12..ac9e9a542a70 100644 --- a/pkg/cli/statement_bundle.go +++ b/pkg/cli/statement_bundle.go @@ -215,7 +215,7 @@ func runBundleRecreate(cmd *cobra.Command, args []string) error { } sqlCtx.ShellCtx.DemoCluster = c - return sqlCtx.Run(conn) + return sqlCtx.Run(ctx, conn) } // placeholderRe matches the placeholder format at the bottom of statement.txt diff --git a/pkg/cli/testdata/zip/partial1 b/pkg/cli/testdata/zip/partial1 index 1b5a40c1a3f2..6a388c4a4db9 100644 --- a/pkg/cli/testdata/zip/partial1 +++ b/pkg/cli/testdata/zip/partial1 @@ -155,64 +155,64 @@ debug zip --concurrency=1 --cpu-profile-duration=0s /dev/null [node 2] node status... converting to JSON... writing binary output: debug/nodes/2/status.json... done [node 2] using SQL connection URL: postgresql://... [node 2] retrieving SQL data for crdb_internal.feature_usage... writing output: debug/nodes/2/crdb_internal.feature_usage.txt... -[node 2] retrieving SQL data for crdb_internal.feature_usage: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.feature_usage: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.feature_usage: creating error output: debug/nodes/2/crdb_internal.feature_usage.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.gossip_alerts... writing output: debug/nodes/2/crdb_internal.gossip_alerts.txt... -[node 2] retrieving SQL data for crdb_internal.gossip_alerts: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.gossip_alerts: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.gossip_alerts: creating error output: debug/nodes/2/crdb_internal.gossip_alerts.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.gossip_liveness... writing output: debug/nodes/2/crdb_internal.gossip_liveness.txt... -[node 2] retrieving SQL data for crdb_internal.gossip_liveness: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.gossip_liveness: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.gossip_liveness: creating error output: debug/nodes/2/crdb_internal.gossip_liveness.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.gossip_network... writing output: debug/nodes/2/crdb_internal.gossip_network.txt... -[node 2] retrieving SQL data for crdb_internal.gossip_network: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.gossip_network: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.gossip_network: creating error output: debug/nodes/2/crdb_internal.gossip_network.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.gossip_nodes... writing output: debug/nodes/2/crdb_internal.gossip_nodes.txt... -[node 2] retrieving SQL data for crdb_internal.gossip_nodes: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.gossip_nodes: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.gossip_nodes: creating error output: debug/nodes/2/crdb_internal.gossip_nodes.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.leases... writing output: debug/nodes/2/crdb_internal.leases.txt... -[node 2] retrieving SQL data for crdb_internal.leases: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.leases: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.leases: creating error output: debug/nodes/2/crdb_internal.leases.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.node_build_info... writing output: debug/nodes/2/crdb_internal.node_build_info.txt... -[node 2] retrieving SQL data for crdb_internal.node_build_info: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.node_build_info: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.node_build_info: creating error output: debug/nodes/2/crdb_internal.node_build_info.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.node_contention_events... writing output: debug/nodes/2/crdb_internal.node_contention_events.txt... -[node 2] retrieving SQL data for crdb_internal.node_contention_events: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.node_contention_events: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.node_contention_events: creating error output: debug/nodes/2/crdb_internal.node_contention_events.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.node_distsql_flows... writing output: debug/nodes/2/crdb_internal.node_distsql_flows.txt... -[node 2] retrieving SQL data for crdb_internal.node_distsql_flows: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.node_distsql_flows: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.node_distsql_flows: creating error output: debug/nodes/2/crdb_internal.node_distsql_flows.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.node_execution_outliers... writing output: debug/nodes/2/crdb_internal.node_execution_outliers.txt... -[node 2] retrieving SQL data for crdb_internal.node_execution_outliers: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.node_execution_outliers: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.node_execution_outliers: creating error output: debug/nodes/2/crdb_internal.node_execution_outliers.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.node_inflight_trace_spans... writing output: debug/nodes/2/crdb_internal.node_inflight_trace_spans.txt... -[node 2] retrieving SQL data for crdb_internal.node_inflight_trace_spans: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.node_inflight_trace_spans: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.node_inflight_trace_spans: creating error output: debug/nodes/2/crdb_internal.node_inflight_trace_spans.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.node_metrics... writing output: debug/nodes/2/crdb_internal.node_metrics.txt... -[node 2] retrieving SQL data for crdb_internal.node_metrics: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.node_metrics: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.node_metrics: creating error output: debug/nodes/2/crdb_internal.node_metrics.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.node_queries... writing output: debug/nodes/2/crdb_internal.node_queries.txt... -[node 2] retrieving SQL data for crdb_internal.node_queries: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.node_queries: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.node_queries: creating error output: debug/nodes/2/crdb_internal.node_queries.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.node_runtime_info... writing output: debug/nodes/2/crdb_internal.node_runtime_info.txt... -[node 2] retrieving SQL data for crdb_internal.node_runtime_info: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.node_runtime_info: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.node_runtime_info: creating error output: debug/nodes/2/crdb_internal.node_runtime_info.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.node_sessions... writing output: debug/nodes/2/crdb_internal.node_sessions.txt... -[node 2] retrieving SQL data for crdb_internal.node_sessions: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.node_sessions: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.node_sessions: creating error output: debug/nodes/2/crdb_internal.node_sessions.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.node_statement_statistics... writing output: debug/nodes/2/crdb_internal.node_statement_statistics.txt... -[node 2] retrieving SQL data for crdb_internal.node_statement_statistics: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.node_statement_statistics: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.node_statement_statistics: creating error output: debug/nodes/2/crdb_internal.node_statement_statistics.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.node_transaction_statistics... writing output: debug/nodes/2/crdb_internal.node_transaction_statistics.txt... -[node 2] retrieving SQL data for crdb_internal.node_transaction_statistics: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.node_transaction_statistics: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.node_transaction_statistics: creating error output: debug/nodes/2/crdb_internal.node_transaction_statistics.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.node_transactions... writing output: debug/nodes/2/crdb_internal.node_transactions.txt... -[node 2] retrieving SQL data for crdb_internal.node_transactions: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.node_transactions: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.node_transactions: creating error output: debug/nodes/2/crdb_internal.node_transactions.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.node_txn_stats... writing output: debug/nodes/2/crdb_internal.node_txn_stats.txt... -[node 2] retrieving SQL data for crdb_internal.node_txn_stats: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.node_txn_stats: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.node_txn_stats: creating error output: debug/nodes/2/crdb_internal.node_txn_stats.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.active_range_feeds... writing output: debug/nodes/2/crdb_internal.active_range_feeds.txt... -[node 2] retrieving SQL data for crdb_internal.active_range_feeds: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.active_range_feeds: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.active_range_feeds: creating error output: debug/nodes/2/crdb_internal.active_range_feeds.txt.err.txt... done [node 2] requesting data for debug/nodes/2/details... received response... [node 2] requesting data for debug/nodes/2/details: last request failed: rpc error: ... diff --git a/pkg/cli/testdata/zip/testzip_tenant b/pkg/cli/testdata/zip/testzip_tenant index d267e30060ce..c81064f1e585 100644 --- a/pkg/cli/testdata/zip/testzip_tenant +++ b/pkg/cli/testdata/zip/testzip_tenant @@ -33,13 +33,13 @@ debug zip --concurrency=1 --cpu-profile-duration=1s /dev/null [cluster] retrieving SQL data for "".crdb_internal.create_statements... writing output: debug/crdb_internal.create_statements.txt... done [cluster] retrieving SQL data for "".crdb_internal.create_type_statements... writing output: debug/crdb_internal.create_type_statements.txt... done [cluster] retrieving SQL data for crdb_internal.kv_node_liveness... writing output: debug/crdb_internal.kv_node_liveness.txt... -[cluster] retrieving SQL data for crdb_internal.kv_node_liveness: last request failed: pq: unimplemented: operation is unsupported in multi-tenancy mode +[cluster] retrieving SQL data for crdb_internal.kv_node_liveness: last request failed: ERROR: unimplemented: operation is unsupported in multi-tenancy mode (SQLSTATE 0A000) [cluster] retrieving SQL data for crdb_internal.kv_node_liveness: creating error output: debug/crdb_internal.kv_node_liveness.txt.err.txt... done [cluster] retrieving SQL data for crdb_internal.kv_node_status... writing output: debug/crdb_internal.kv_node_status.txt... -[cluster] retrieving SQL data for crdb_internal.kv_node_status: last request failed: pq: unimplemented: operation is unsupported in multi-tenancy mode +[cluster] retrieving SQL data for crdb_internal.kv_node_status: last request failed: ERROR: unimplemented: operation is unsupported in multi-tenancy mode (SQLSTATE 0A000) [cluster] retrieving SQL data for crdb_internal.kv_node_status: creating error output: debug/crdb_internal.kv_node_status.txt.err.txt... done [cluster] retrieving SQL data for crdb_internal.kv_store_status... writing output: debug/crdb_internal.kv_store_status.txt... -[cluster] retrieving SQL data for crdb_internal.kv_store_status: last request failed: pq: unimplemented: operation is unsupported in multi-tenancy mode +[cluster] retrieving SQL data for crdb_internal.kv_store_status: last request failed: ERROR: unimplemented: operation is unsupported in multi-tenancy mode (SQLSTATE 0A000) [cluster] retrieving SQL data for crdb_internal.kv_store_status: creating error output: debug/crdb_internal.kv_store_status.txt.err.txt... done [cluster] retrieving SQL data for crdb_internal.regions... writing output: debug/crdb_internal.regions.txt... done [cluster] retrieving SQL data for crdb_internal.schema_changes... writing output: debug/crdb_internal.schema_changes.txt... done @@ -90,16 +90,16 @@ debug zip --concurrency=1 --cpu-profile-duration=1s /dev/null [node 1] using SQL connection URL: postgresql://... [node 1] retrieving SQL data for crdb_internal.feature_usage... writing output: debug/nodes/1/crdb_internal.feature_usage.txt... done [node 1] retrieving SQL data for crdb_internal.gossip_alerts... writing output: debug/nodes/1/crdb_internal.gossip_alerts.txt... -[node 1] retrieving SQL data for crdb_internal.gossip_alerts: last request failed: pq: unimplemented: operation is unsupported in multi-tenancy mode +[node 1] retrieving SQL data for crdb_internal.gossip_alerts: last request failed: ERROR: unimplemented: operation is unsupported in multi-tenancy mode (SQLSTATE 0A000) [node 1] retrieving SQL data for crdb_internal.gossip_alerts: creating error output: debug/nodes/1/crdb_internal.gossip_alerts.txt.err.txt... done [node 1] retrieving SQL data for crdb_internal.gossip_liveness... writing output: debug/nodes/1/crdb_internal.gossip_liveness.txt... -[node 1] retrieving SQL data for crdb_internal.gossip_liveness: last request failed: pq: unimplemented: operation is unsupported in multi-tenancy mode +[node 1] retrieving SQL data for crdb_internal.gossip_liveness: last request failed: ERROR: unimplemented: operation is unsupported in multi-tenancy mode (SQLSTATE 0A000) [node 1] retrieving SQL data for crdb_internal.gossip_liveness: creating error output: debug/nodes/1/crdb_internal.gossip_liveness.txt.err.txt... done [node 1] retrieving SQL data for crdb_internal.gossip_network... writing output: debug/nodes/1/crdb_internal.gossip_network.txt... -[node 1] retrieving SQL data for crdb_internal.gossip_network: last request failed: pq: unimplemented: operation is unsupported in multi-tenancy mode +[node 1] retrieving SQL data for crdb_internal.gossip_network: last request failed: ERROR: unimplemented: operation is unsupported in multi-tenancy mode (SQLSTATE 0A000) [node 1] retrieving SQL data for crdb_internal.gossip_network: creating error output: debug/nodes/1/crdb_internal.gossip_network.txt.err.txt... done [node 1] retrieving SQL data for crdb_internal.gossip_nodes... writing output: debug/nodes/1/crdb_internal.gossip_nodes.txt... -[node 1] retrieving SQL data for crdb_internal.gossip_nodes: last request failed: pq: unimplemented: operation is unsupported in multi-tenancy mode +[node 1] retrieving SQL data for crdb_internal.gossip_nodes: last request failed: ERROR: unimplemented: operation is unsupported in multi-tenancy mode (SQLSTATE 0A000) [node 1] retrieving SQL data for crdb_internal.gossip_nodes: creating error output: debug/nodes/1/crdb_internal.gossip_nodes.txt.err.txt... done [node 1] retrieving SQL data for crdb_internal.leases... writing output: debug/nodes/1/crdb_internal.leases.txt... done [node 1] retrieving SQL data for crdb_internal.node_build_info... writing output: debug/nodes/1/crdb_internal.node_build_info.txt... done diff --git a/pkg/cli/userfile.go b/pkg/cli/userfile.go index 1aafd9e10c56..f8c81a0da952 100644 --- a/pkg/cli/userfile.go +++ b/pkg/cli/userfile.go @@ -11,8 +11,8 @@ package cli import ( + "bytes" "context" - "database/sql/driver" "fmt" "io" "io/fs" @@ -241,7 +241,7 @@ func runUserFileGet(cmd *cobra.Command, args []string) (resErr error) { pattern := fullPath[len(conf.Path):] displayPath := strings.TrimPrefix(conf.Path, "/") - f, err := userfile.MakeSQLConnFileTableStorage(ctx, conf, conn.GetDriverConn().(cloud.SQLConnI)) + f, err := userfile.MakeSQLConnFileTableStorage(ctx, conf, conn.GetDriverConn()) if err != nil { return err } @@ -399,7 +399,7 @@ func constructUserfileListURI(glob string, user username.SQLUsername) string { func getUserfileConf( ctx context.Context, conn clisqlclient.Conn, glob string, ) (roachpb.ExternalStorage_FileTable, error) { - if err := conn.EnsureConn(); err != nil { + if err := conn.EnsureConn(ctx); err != nil { return roachpb.ExternalStorage_FileTable{}, err } @@ -434,7 +434,7 @@ func listUserFile(ctx context.Context, conn clisqlclient.Conn, glob string) ([]s conf.Path = cloud.GetPrefixBeforeWildcard(fullPath) pattern := fullPath[len(conf.Path):] - f, err := userfile.MakeSQLConnFileTableStorage(ctx, conf, conn.GetDriverConn().(cloud.SQLConnI)) + f, err := userfile.MakeSQLConnFileTableStorage(ctx, conf, conn.GetDriverConn()) if err != nil { return nil, err } @@ -482,7 +482,7 @@ func downloadUserfile( } func deleteUserFile(ctx context.Context, conn clisqlclient.Conn, glob string) ([]string, error) { - if err := conn.EnsureConn(); err != nil { + if err := conn.EnsureConn(ctx); err != nil { return nil, err } @@ -511,8 +511,7 @@ func deleteUserFile(ctx context.Context, conn clisqlclient.Conn, glob string) ([ userFileTableConf.FileTableConfig.Path = cloud.GetPrefixBeforeWildcard(fullPath) pattern := fullPath[len(userFileTableConf.FileTableConfig.Path):] - f, err := userfile.MakeSQLConnFileTableStorage(ctx, userFileTableConf.FileTableConfig, - conn.GetDriverConn().(cloud.SQLConnI)) + f, err := userfile.MakeSQLConnFileTableStorage(ctx, userFileTableConf.FileTableConfig, conn.GetDriverConn()) if err != nil { return nil, err } @@ -543,39 +542,16 @@ func renameUserFile( ctx context.Context, conn clisqlclient.Conn, oldFilename, newFilename, qualifiedTableName string, ) error { - if err := conn.EnsureConn(); err != nil { + if err := conn.EnsureConn(ctx); err != nil { return err } - ex := conn.GetDriverConn() - if _, err := ex.ExecContext(ctx, `BEGIN`, nil); err != nil { - return err - } - - stmt, err := conn.GetDriverConn().Prepare(fmt.Sprintf(`UPDATE %s SET filename=$1 WHERE filename=$2`, - qualifiedTableName+fileTableNameSuffix)) - if err != nil { - return err - } - - defer func() { - if stmt != nil { - _ = stmt.Close() - _, _ = ex.ExecContext(ctx, `ROLLBACK`, nil) - } - }() - //lint:ignore SA1019 DriverConn doesn't support go 1.8 API - _, err = stmt.Exec([]driver.Value{newFilename, oldFilename}) - if err != nil { - return err - } - - if err := stmt.Close(); err != nil { - return err - } - stmt = nil - if _, err := ex.ExecContext(ctx, `COMMIT`, nil); err != nil { + if err := ex.Exec( + ctx, + fmt.Sprintf(`UPDATE %s SET filename=$1 WHERE filename=$2`, qualifiedTableName+fileTableNameSuffix), + newFilename, oldFilename, + ); err != nil { return err } @@ -595,14 +571,11 @@ func uploadUserFile( } defer reader.Close() - if err := conn.EnsureConn(); err != nil { + if err := conn.EnsureConn(ctx); err != nil { return "", err } ex := conn.GetDriverConn() - if _, err := ex.ExecContext(ctx, `BEGIN`, nil); err != nil { - return "", err - } connURL, err := url.Parse(conn.GetURL()) if err != nil { @@ -634,42 +607,22 @@ func uploadUserFile( if err != nil { return "", err } - stmt, err := conn.GetDriverConn().Prepare(sql.CopyInFileStmt(unescapedUserfileURL, sql.CrdbInternalName, - sql.UserFileUploadTable)) - if err != nil { - return "", err - } + stmt := sql.CopyInFileStmt(unescapedUserfileURL, sql.CrdbInternalName, sql.UserFileUploadTable) - defer func() { - if stmt != nil { - _ = stmt.Close() - _, _ = ex.ExecContext(ctx, `ROLLBACK`, nil) - } - }() - - send := make([]byte, chunkSize) + send := make([]byte, 0) + tmp := make([]byte, chunkSize) for { - n, err := reader.Read(send) + n, err := reader.Read(tmp) if n > 0 { - // TODO(adityamaru): Switch to StmtExecContext once the copyin driver - // supports it. - //lint:ignore SA1019 DriverConn doesn't support go 1.8 API - _, err = stmt.Exec([]driver.Value{string(send[:n])}) - if err != nil { - return "", err - } + send = appendEscapedText(send, string(tmp[:n])) } else if err == io.EOF { break } else if err != nil { return "", err } } - if err := stmt.Close(); err != nil { - return "", err - } - stmt = nil - if _, err := ex.ExecContext(ctx, `COMMIT`, nil); err != nil { + if err := ex.CopyFrom(ctx, bytes.NewReader(send), stmt); err != nil { return "", err } diff --git a/pkg/cli/zip.go b/pkg/cli/zip.go index 29fbbf26b484..b9a8a336774c 100644 --- a/pkg/cli/zip.go +++ b/pkg/cli/zip.go @@ -29,7 +29,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/util/contextutil" "github.com/cockroachdb/errors" - "github.com/lib/pq" + "github.com/jackc/pgconn" "github.com/marusama/semaphore" "github.com/spf13/cobra" ) @@ -325,9 +325,7 @@ func maybeAddProfileSuffix(name string) string { func (zc *debugZipContext) dumpTableDataForZip( zr *zipReporter, conn clisqlclient.Conn, base, table, query string, ) error { - // TODO(knz): This can use context cancellation now that query - // cancellation is supported. - fullQuery := fmt.Sprintf(`SET statement_timeout = '%s'; %s`, zc.timeout, query) + ctx := context.Background() baseName := base + "/" + sanitizeFilename(table) s := zr.start("retrieving SQL data for %s", table) @@ -340,26 +338,33 @@ func (zc *debugZipContext) dumpTableDataForZip( zc.z.Lock() defer zc.z.Unlock() + // TODO(knz): This can use context cancellation now that query + // cancellation is supported in v22.1 and later. + // SET must be run separately from the query so that the command tag output + // doesn't get added to the debug file. + err := conn.Exec(ctx, fmt.Sprintf(`SET statement_timeout = '%s'`, zc.timeout)) + if err != nil { + return err + } + w, err := zc.z.createLocked(name, time.Time{}) if err != nil { return err } // Pump the SQL rows directly into the zip writer, to avoid // in-RAM buffering. - return sqlExecCtx.RunQueryAndFormatResults( - context.Background(), - conn, w, stderr, clisqlclient.MakeQuery(fullQuery)) + return sqlExecCtx.RunQueryAndFormatResults(ctx, conn, w, stderr, clisqlclient.MakeQuery(query)) }() if sqlErr != nil { if cErr := zc.z.createError(s, name, sqlErr); cErr != nil { return cErr } - var pqErr *pq.Error - if !errors.As(sqlErr, &pqErr) { + var pgErr = (*pgconn.PgError)(nil) + if !errors.As(sqlErr, &pgErr) { // Not a SQL error. Nothing to retry. break } - if pgcode.MakeCode(string(pqErr.Code)) != pgcode.SerializationFailure { + if pgcode.MakeCode(pgErr.Code) != pgcode.SerializationFailure { // A non-retry error. We've printed the error, and // there's nothing to retry. Stop here. break diff --git a/pkg/cli/zip_test.go b/pkg/cli/zip_test.go index 137b6c8d8b03..304a26722d50 100644 --- a/pkg/cli/zip_test.go +++ b/pkg/cli/zip_test.go @@ -336,6 +336,8 @@ func eraseNonDeterministicZipOutput(out string) string { out = re.ReplaceAllString(out, `dial tcp ...`) re = regexp.MustCompile(`(?m)rpc error: .*$`) out = re.ReplaceAllString(out, `rpc error: ...`) + re = regexp.MustCompile(`(?m)failed to connect to .*$`) + out = re.ReplaceAllString(out, `failed to connect to ...`) // The number of memory profiles previously collected is not deterministic. re = regexp.MustCompile(`(?m)^\[node \d+\] \d+ heap profiles found$`) diff --git a/pkg/cloud/external_storage.go b/pkg/cloud/external_storage.go index efc86ff44683..b2b06c903135 100644 --- a/pkg/cloud/external_storage.go +++ b/pkg/cloud/external_storage.go @@ -109,8 +109,8 @@ type ExternalStorageFromURIFactory func(ctx context.Context, uri string, // SQLConnI encapsulates the interfaces which will be implemented by the network // backed SQLConn which is used to interact with the userfile tables. type SQLConnI interface { - driver.QueryerContext - driver.ExecerContext + Query(ctx context.Context, query string, args ...interface{}) (driver.Rows, error) + Exec(ctx context.Context, query string, args ...interface{}) error } // ErrFileDoesNotExist is a sentinel error for indicating that a specified diff --git a/pkg/cloud/userfile/filetable/file_table_read_writer.go b/pkg/cloud/userfile/filetable/file_table_read_writer.go index a040ed03a9af..75b5bbfd1514 100644 --- a/pkg/cloud/userfile/filetable/file_table_read_writer.go +++ b/pkg/cloud/userfile/filetable/file_table_read_writer.go @@ -116,18 +116,8 @@ func (i *SQLConnFileToTableExecutor) Query( ) (*FileToTableExecutorRows, error) { result := FileToTableExecutorRows{} - argVals := make([]driver.NamedValue, len(qargs)) - for i, qarg := range qargs { - namedVal := driver.NamedValue{ - // Ordinal position is 1 indexed. - Ordinal: i + 1, - Value: qarg, - } - argVals[i] = namedVal - } - var err error - result.sqlConnExecResults, err = i.executor.QueryContext(ctx, query, argVals) + result.sqlConnExecResults, err = i.executor.Query(ctx, query, qargs...) if err != nil { return nil, err } @@ -138,17 +128,7 @@ func (i *SQLConnFileToTableExecutor) Query( func (i *SQLConnFileToTableExecutor) Exec( ctx context.Context, _, query string, _ username.SQLUsername, qargs ...interface{}, ) error { - argVals := make([]driver.NamedValue, len(qargs)) - for i, qarg := range qargs { - namedVal := driver.NamedValue{ - // Ordinal position is 1 indexed. - Ordinal: i + 1, - Value: qarg, - } - argVals[i] = namedVal - } - _, err := i.executor.ExecContext(ctx, query, argVals) - return err + return i.executor.Exec(ctx, query, qargs...) } // FileToTableSystem can be used to store, retrieve and delete the @@ -318,7 +298,9 @@ func (f *FileToTableSystem) FileSize(ctx context.Context, filename string) (int6 // ListFiles returns a list of all the files which are currently stored in the // user scoped tables. -func (f *FileToTableSystem) ListFiles(ctx context.Context, pattern string) ([]string, error) { +func (f *FileToTableSystem) ListFiles( + ctx context.Context, pattern string, +) (retFiles []string, retErr error) { var files []string listFilesQuery := fmt.Sprintf(`SELECT filename FROM %s WHERE filename LIKE $1 ORDER BY filename`, f.GetFQFileTableName()) @@ -334,6 +316,12 @@ filename`, f.GetFQFileTableName()) case *InternalFileToTableExecutor: // Verify that all the filenames are strings and aggregate them. it := rows.internalExecResultsIterator + defer func() { + if err := it.Close(); err != nil { + retErr = errors.CombineErrors(retErr, err) + log.Warningf(ctx, "failed to close %+v", err) + } + }() var ok bool for ok, err = it.Next(ctx); ok; ok, err = it.Next(ctx) { files = append(files, string(tree.MustBeDString(it.Cur()[0]))) @@ -342,6 +330,12 @@ filename`, f.GetFQFileTableName()) return nil, err } case *SQLConnFileToTableExecutor: + defer func() { + if err := rows.sqlConnExecResults.Close(); err != nil { + retErr = errors.CombineErrors(retErr, err) + log.Warningf(ctx, "failed to close %+v", err) + } + }() vals := make([]driver.Value, 1) for { if err := rows.sqlConnExecResults.Next(vals); err == io.EOF { @@ -352,10 +346,6 @@ filename`, f.GetFQFileTableName()) filename := vals[0].(string) files = append(files, filename) } - - if err = rows.sqlConnExecResults.Close(); err != nil { - return nil, err - } default: return []string{}, errors.New("unsupported executor type in FileSize") } @@ -649,7 +639,7 @@ func newFileTableReader( fileTableName, payloadTableName string, ie FileToTableSystemExecutor, offset int64, -) (ioctx.ReadCloserCtx, int64, error) { +) (_ ioctx.ReadCloserCtx, _ int64, retErr error) { // Get file_id from metadata entry in File table. var fileID []byte var sz int64 @@ -668,6 +658,7 @@ func newFileTableReader( it := metaRows.internalExecResultsIterator defer func() { if err := it.Close(); err != nil { + retErr = errors.CombineErrors(retErr, err) log.Warningf(ctx, "failed to close %+v", err) } }() @@ -685,6 +676,7 @@ func newFileTableReader( case *SQLConnFileToTableExecutor: defer func() { if err := metaRows.sqlConnExecResults.Close(); err != nil { + retErr = errors.CombineErrors(retErr, err) log.Warningf(ctx, "failed to close %+v", err) } }() @@ -695,7 +687,8 @@ func newFileTableReader( } else if err != nil { return nil, 0, errors.Wrap(err, "failed to read returned file metadata") } - fileID = vals[0].([]byte) + uuidBytes := vals[0].([16]byte) + fileID = uuidBytes[:] if vals[1] != nil { sz = vals[1].(int64) } @@ -709,7 +702,7 @@ func newFileTableReader( const bufSize = 256 << 10 - fn := func(p []byte, pos int64) (int, error) { + fn := func(p []byte, pos int64) (_ int, retErr error) { if pos >= sz { return 0, io.EOF } @@ -730,6 +723,7 @@ func newFileTableReader( it := rows.internalExecResultsIterator defer func() { if err := it.Close(); err != nil { + retErr = errors.CombineErrors(retErr, err) log.Warningf(ctx, "failed to close %+v", err) } }() @@ -744,6 +738,7 @@ func newFileTableReader( case *SQLConnFileToTableExecutor: defer func() { if err := rows.sqlConnExecResults.Close(); err != nil { + retErr = errors.CombineErrors(retErr, err) log.Warningf(ctx, "failed to close %+v", err) } }() diff --git a/pkg/cmd/cockroach-sql/main.go b/pkg/cmd/cockroach-sql/main.go index c96b77191ea4..206b71a3d7fb 100644 --- a/pkg/cmd/cockroach-sql/main.go +++ b/pkg/cmd/cockroach-sql/main.go @@ -14,6 +14,7 @@ package main import ( + "context" "fmt" "os" @@ -191,5 +192,5 @@ func runSQL(cmd *cobra.Command, args []string) (resErr error) { defer func() { resErr = errors.CombineErrors(resErr, conn.Close()) }() cfg.ShellCtx.ParseURL = clienturl.MakeURLParserFn(cmd, copts) - return cfg.Run(conn) + return cfg.Run(context.Background(), conn) } diff --git a/pkg/sql/copy_file_upload.go b/pkg/sql/copy_file_upload.go index 27df7093c166..bc6b17e540ec 100644 --- a/pkg/sql/copy_file_upload.go +++ b/pkg/sql/copy_file_upload.go @@ -168,11 +168,11 @@ func CopyInFileStmt(destination, schema, table string) string { } func (f *fileUploadMachine) run(ctx context.Context) error { - err := f.c.run(ctx) - if err != nil && f.cancel != nil { + runErr := f.c.run(ctx) + err := errors.CombineErrors(f.w.Close(), runErr) + if runErr != nil && f.cancel != nil { f.cancel() } - err = errors.CombineErrors(f.w.Close(), err) if err != nil { f.failureCleanup() diff --git a/pkg/testutils/lint/lint_test.go b/pkg/testutils/lint/lint_test.go index 770015a42ba2..4b0c30e8dc21 100644 --- a/pkg/testutils/lint/lint_test.go +++ b/pkg/testutils/lint/lint_test.go @@ -445,6 +445,7 @@ func TestLint(t *testing.T) { ":!util/sdnotify/sdnotify_unix.go", ":!util/grpcutil", // GRPC_GO_* variables ":!roachprod", // roachprod requires AWS environment variables + ":!cli/env.go", // The CLI needs the PGHOST variable. }, }, } { diff --git a/pkg/util/tracing/zipper/zipper.go b/pkg/util/tracing/zipper/zipper.go index 56365cbd84e8..7a3d08f18e3d 100644 --- a/pkg/util/tracing/zipper/zipper.go +++ b/pkg/util/tracing/zipper/zipper.go @@ -201,13 +201,17 @@ func MakeInternalExecutorInflightTraceZipper( var _ InflightTraceZipper = &InternalInflightTraceZipper{} +type queryI interface { + Query(ctx context.Context, query string, args ...interface{}) (driver.Rows, error) +} + // SQLConnInflightTraceZipper is the InflightTraceZipper which uses a network // backed SQL connection to collect cluster wide traces. type SQLConnInflightTraceZipper struct { traceStrBuf *bytes.Buffer nodeTraceCollection *tracing.TraceCollection z *memzipper.Zipper - sqlConn driver.QueryerContext + sqlConn queryI } func (s *SQLConnInflightTraceZipper) getNodeTraceCollection() *tracing.TraceCollection { @@ -234,7 +238,7 @@ func (s *SQLConnInflightTraceZipper) reset() { // into text, and jaegerJSON formats before creating a zip with per-node trace // files. func (s *SQLConnInflightTraceZipper) Zip(ctx context.Context, traceID int64) ([]byte, error) { - rows, err := s.sqlConn.QueryContext(ctx, fmt.Sprintf(inflightTracesQuery, traceID), nil /* args */) + rows, err := s.sqlConn.Query(ctx, fmt.Sprintf(inflightTracesQuery, traceID)) if err != nil { return nil, err } @@ -341,7 +345,7 @@ func (s *SQLConnInflightTraceZipper) populateInflightTraceRow( // MakeSQLConnInflightTraceZipper returns an instance of // SQLConnInflightTraceZipper. -func MakeSQLConnInflightTraceZipper(sqlConn driver.QueryerContext) *SQLConnInflightTraceZipper { +func MakeSQLConnInflightTraceZipper(sqlConn queryI) *SQLConnInflightTraceZipper { t := &SQLConnInflightTraceZipper{ traceStrBuf: &bytes.Buffer{}, nodeTraceCollection: nil,