From fcc7de4c4864f4763cd8d496150c4fd88a183339 Mon Sep 17 00:00:00 2001
From: Rafi Shamim
Date: Wed, 27 Apr 2022 10:03:14 -0400
Subject: [PATCH] cli: switch to jackc/pgx

- Address the TODO about RowsAffected. Unlike lib/pq, pgx doesn't use
  incorrect command tags for queries that return 0 rows, so we don't need
  this special logic.
- BYTEA values are now formatted correctly. Along with this, we need some
  special handling for formatting NAME[], "char", TIME, INTERVAL, NUMERIC,
  and UUID types.
- We needed an additional "multi result set" rows object. lib/pq could
  handle this with one interface, but that isn't possible with pgx/pgconn.

Release note (cli change): BYTEA values are now formatted according to the
bytea_output session setting.

Release note (cli change): The statement tag displayed for INSERT statements
now contains the full information returned by the server: the string "INSERT",
followed by the OID of the inserted row (currently always 0 in CockroachDB),
followed by the number of rows inserted.

Release note (backward-incompatible change): In the cockroach CLI, BOOLEAN
values are now formatted as `t` or `f` instead of `True` or `False`.
---
 pkg/cli/BUILD.bazel | 4 +-
 pkg/cli/clierror/BUILD.bazel | 4 +-
 pkg/cli/clierror/error_test.go | 35 ++--
 pkg/cli/clierror/formatted_error.go | 28 +--
 pkg/cli/clierror/syntax_error.go | 6 +-
 pkg/cli/clierrorplus/BUILD.bazel | 2 +-
 pkg/cli/clierrorplus/decorate_error.go | 20 ++-
 pkg/cli/clisqlcfg/context.go | 16 +-
 pkg/cli/clisqlclient/BUILD.bazel | 8 +-
 pkg/cli/clisqlclient/api.go | 57 ++++--
 pkg/cli/clisqlclient/conn.go | 170 ++++++++++++------
 pkg/cli/clisqlclient/conn_test.go | 9 +-
 pkg/cli/clisqlclient/copy.go | 55 ++----
 pkg/cli/clisqlclient/make_query.go | 20 ---
 pkg/cli/clisqlclient/row_type_helpers.go | 72 ++++++++
 pkg/cli/clisqlclient/rows.go | 76 ++++----
 pkg/cli/clisqlclient/rows_multi.go | 157 ++++++++++++++++
 pkg/cli/clisqlexec/BUILD.bazel | 2 +-
 pkg/cli/clisqlexec/format_table.go | 16 +-
 pkg/cli/clisqlexec/format_table_test.go | 27 +--
 pkg/cli/clisqlexec/format_value.go | 60 +++----
 pkg/cli/clisqlexec/format_value_test.go | 39 +++-
 pkg/cli/clisqlexec/run_query.go | 97 +++-------
 pkg/cli/clisqlexec/run_query_test.go | 16 +-
 pkg/cli/clisqlshell/sql_test.go | 26 ++-
 pkg/cli/context.go | 2 +-
 pkg/cli/debug_job_trace.go | 2 +-
 pkg/cli/demo.go | 2 +-
 pkg/cli/doctor.go | 23 ++-
 pkg/cli/env.go | 24 +++
 pkg/cli/flags_test.go | 4 +-
 pkg/cli/import.go | 4 +-
 .../test_client_side_checking.tcl | 2 +-
 .../interactive_tests/test_connect_cmd.tcl | 2 +-
 .../interactive_tests/test_demo_node_cmds.tcl | 10 +-
 pkg/cli/interactive_tests/test_reconnect.tcl | 4 +-
 .../test_sql_version_reporting.tcl | 2 +-
 .../test_url_db_override.tcl | 2 +-
 pkg/cli/node.go | 3 +
 pkg/cli/nodelocal.go | 85 +++++----
 pkg/cli/sql_shell_cmd.go | 3 +-
 pkg/cli/statement_bundle.go | 2 +-
 pkg/cli/testdata/zip/partial1 | 40 ++---
 pkg/cli/testdata/zip/testzip_tenant | 14 +-
 pkg/cli/userfile.go | 85 ++-------
 pkg/cli/zip.go | 25 +--
 pkg/cli/zip_test.go | 2 +
 pkg/cloud/external_storage.go | 4 +-
 .../filetable/file_table_read_writer.go | 55 +++---
 pkg/cmd/cockroach-sql/main.go | 3 +-
 pkg/sql/copy_file_upload.go | 6 +-
 pkg/testutils/lint/lint_test.go | 1 +
 pkg/util/tracing/zipper/zipper.go | 10 +-
 53 files changed, 894 insertions(+), 549 deletions(-)
 create mode 100644 pkg/cli/clisqlclient/row_type_helpers.go
 create mode 100644 pkg/cli/clisqlclient/rows_multi.go
 create mode 100644 pkg/cli/env.go

diff --git a/pkg/cli/BUILD.bazel b/pkg/cli/BUILD.bazel index 
94075f7e5403..a1060ae379bf 100644 --- a/pkg/cli/BUILD.bazel +++ b/pkg/cli/BUILD.bazel @@ -29,6 +29,7 @@ go_library( "demo.go", "demo_telemetry.go", "doctor.go", + "env.go", "examples.go", "flags.go", "flags_util.go", @@ -230,8 +231,9 @@ go_library( "@com_github_cockroachdb_ttycolor//:ttycolor", "@com_github_dustin_go_humanize//:go-humanize", "@com_github_gogo_protobuf//jsonpb", + "@com_github_jackc_pgconn//:pgconn", + "@com_github_jackc_pgtype//:pgtype", "@com_github_kr_pretty//:pretty", - "@com_github_lib_pq//:pq", "@com_github_marusama_semaphore//:semaphore", "@com_github_mattn_go_isatty//:go-isatty", "@com_github_spf13_cobra//:cobra", diff --git a/pkg/cli/clierror/BUILD.bazel b/pkg/cli/clierror/BUILD.bazel index a623e1b2572b..7e01e27b3a95 100644 --- a/pkg/cli/clierror/BUILD.bazel +++ b/pkg/cli/clierror/BUILD.bazel @@ -18,7 +18,7 @@ go_library( "//pkg/util/log/logpb", "//pkg/util/log/severity", "@com_github_cockroachdb_errors//:errors", - "@com_github_lib_pq//:pq", + "@com_github_jackc_pgconn//:pgconn", ], ) @@ -47,7 +47,7 @@ go_test( "//pkg/util/log/channel", "//pkg/util/log/severity", "@com_github_cockroachdb_errors//:errors", - "@com_github_lib_pq//:pq", + "@com_github_jackc_pgconn//:pgconn", "@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//require", ], diff --git a/pkg/cli/clierror/error_test.go b/pkg/cli/clierror/error_test.go index 4e71a4943d96..0820462f8bc4 100644 --- a/pkg/cli/clierror/error_test.go +++ b/pkg/cli/clierror/error_test.go @@ -20,7 +20,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/errors" - "github.com/lib/pq" + "github.com/jackc/pgconn" "github.com/stretchr/testify/assert" ) @@ -43,13 +43,13 @@ func TestOutputError(t *testing.T) { // Check the verbose output. This includes the uncategorized sqlstate. {errBase, false, true, "woo\nSQLSTATE: " + pgcode.Uncategorized.String() + "\nLOCATION: " + refLoc}, {errBase, true, true, "ERROR: woo\nSQLSTATE: " + pgcode.Uncategorized.String() + "\nLOCATION: " + refLoc}, - // Check the same over pq.Error objects. - {&pq.Error{Message: "woo"}, false, false, "woo"}, - {&pq.Error{Message: "woo"}, true, false, "ERROR: woo"}, - {&pq.Error{Message: "woo"}, false, true, "woo"}, - {&pq.Error{Message: "woo"}, true, true, "ERROR: woo"}, - {&pq.Error{Severity: "W", Message: "woo"}, false, false, "woo"}, - {&pq.Error{Severity: "W", Message: "woo"}, true, false, "W: woo"}, + // Check the same over pgconn.PgError objects. + {&pgconn.PgError{Message: "woo"}, false, false, "woo"}, + {&pgconn.PgError{Message: "woo"}, true, false, "ERROR: woo"}, + {&pgconn.PgError{Message: "woo"}, false, true, "woo"}, + {&pgconn.PgError{Message: "woo"}, true, true, "ERROR: woo"}, + {&pgconn.PgError{Severity: "W", Message: "woo"}, false, false, "woo"}, + {&pgconn.PgError{Severity: "W", Message: "woo"}, true, false, "W: woo"}, // Check hint printed after message. {errors.WithHint(errBase, "hello"), false, false, "woo\nHINT: hello"}, // Check sqlstate printed before hint, location after hint. 
@@ -85,16 +85,17 @@ func TestFormatLocation(t *testing.T) { defer log.Scope(t).Close(t) testData := []struct { - file, line, fn string - exp string + file string + line int + fn, exp string }{ - {"", "", "", ""}, - {"a.b", "", "", "a.b"}, - {"", "123", "", ":123"}, - {"", "", "abc", "abc"}, - {"a.b", "", "abc", "abc, a.b"}, - {"a.b", "123", "", "a.b:123"}, - {"", "123", "abc", "abc, :123"}, + {"", 0, "", ""}, + {"a.b", 0, "", "a.b"}, + {"", 123, "", ":123"}, + {"", 0, "abc", "abc"}, + {"a.b", 0, "abc", "abc, a.b"}, + {"a.b", 123, "", "a.b:123"}, + {"", 123, "abc", "abc, :123"}, } for _, tc := range testData { diff --git a/pkg/cli/clierror/formatted_error.go b/pkg/cli/clierror/formatted_error.go index 35c285e16083..ad6baa8d09c9 100644 --- a/pkg/cli/clierror/formatted_error.go +++ b/pkg/cli/clierror/formatted_error.go @@ -19,7 +19,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" "github.com/cockroachdb/errors" - "github.com/lib/pq" + "github.com/jackc/pgconn" ) // OutputError prints out an error object on the given writer. @@ -77,15 +77,15 @@ func (f *formattedError) Error() string { // Extract the fields. var message, hint, detail, location, constraintName string var code pgcode.Code - if pqErr := (*pq.Error)(nil); errors.As(f.err, &pqErr) { - if pqErr.Severity != "" { - severity = pqErr.Severity + if pgErr := (*pgconn.PgError)(nil); errors.As(f.err, &pgErr) { + if pgErr.Severity != "" { + severity = pgErr.Severity } - constraintName = pqErr.Constraint - message = pqErr.Message - code = pgcode.MakeCode(string(pqErr.Code)) - hint, detail = pqErr.Hint, pqErr.Detail - location = formatLocation(pqErr.File, pqErr.Line, pqErr.Routine) + constraintName = pgErr.ConstraintName + message = pgErr.Message + code = pgcode.MakeCode(pgErr.Code) + hint, detail = pgErr.Hint, pgErr.Detail + location = formatLocation(pgErr.File, int(pgErr.Line), pgErr.Routine) } else { message = f.err.Error() code = pgerror.GetPGCode(f.err) @@ -93,7 +93,7 @@ func (f *formattedError) Error() string { hint = errors.FlattenHints(f.err) detail = errors.FlattenDetails(f.err) if file, line, fn, ok := errors.GetOneLineSource(f.err); ok { - location = formatLocation(file, strconv.FormatInt(int64(line), 10), fn) + location = formatLocation(file, line, fn) } } @@ -144,10 +144,10 @@ func (f *formattedError) Error() string { // formatLocation spells out the error's location in a format // similar to psql: routine then file:num. The routine part is // skipped if empty. -func formatLocation(file, line, fn string) string { +func formatLocation(file string, line int, fn string) string { var res strings.Builder res.WriteString(fn) - if file != "" || line != "" { + if file != "" || line != 0 { if fn != "" { res.WriteString(", ") } @@ -156,9 +156,9 @@ func formatLocation(file, line, fn string) string { } else { res.WriteString(file) } - if line != "" { + if line != 0 { res.WriteByte(':') - res.WriteString(line) + res.WriteString(strconv.Itoa(line)) } } return res.String() diff --git a/pkg/cli/clierror/syntax_error.go b/pkg/cli/clierror/syntax_error.go index 1d2a3615207d..20216361495f 100644 --- a/pkg/cli/clierror/syntax_error.go +++ b/pkg/cli/clierror/syntax_error.go @@ -13,15 +13,15 @@ package clierror import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/errors" - "github.com/lib/pq" + "github.com/jackc/pgconn" ) // IsSQLSyntaxError returns true iff the provided error is a SQL // syntax error. 
The function works for the queries executed via the // clisqlclient/clisqlexec packages. func IsSQLSyntaxError(err error) bool { - if pqErr := (*pq.Error)(nil); errors.As(err, &pqErr) { - return string(pqErr.Code) == pgcode.Syntax.String() + if pgErr := (*pgconn.PgError)(nil); errors.As(err, &pgErr) { + return pgErr.Code == pgcode.Syntax.String() } return false } diff --git a/pkg/cli/clierrorplus/BUILD.bazel b/pkg/cli/clierrorplus/BUILD.bazel index 7728e6b89074..4d2850ab9273 100644 --- a/pkg/cli/clierrorplus/BUILD.bazel +++ b/pkg/cli/clierrorplus/BUILD.bazel @@ -19,7 +19,7 @@ go_library( "//pkg/util/log", "//pkg/util/netutil", "@com_github_cockroachdb_errors//:errors", - "@com_github_lib_pq//:pq", + "@com_github_jackc_pgconn//:pgconn", "@com_github_spf13_cobra//:cobra", "@org_golang_google_grpc//codes", "@org_golang_google_grpc//status", diff --git a/pkg/cli/clierrorplus/decorate_error.go b/pkg/cli/clierrorplus/decorate_error.go index 6e141b723b96..00e7d44f7759 100644 --- a/pkg/cli/clierrorplus/decorate_error.go +++ b/pkg/cli/clierrorplus/decorate_error.go @@ -26,7 +26,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/grpcutil" "github.com/cockroachdb/cockroach/pkg/util/netutil" "github.com/cockroachdb/errors" - "github.com/lib/pq" + "github.com/jackc/pgconn" "github.com/spf13/cobra" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -107,7 +107,9 @@ func MaybeDecorateError( } // Is this an "unable to connect" type of error? - if errors.Is(err, pq.ErrSSLNotSupported) { + // TODO(rafi): patch jackc/pgconn to add an ErrSSLNotSupported constant. + // See https://github.com/jackc/pgconn/issues/118. + if strings.Contains(err.Error(), "server refused TLS connection") { // SQL command failed after establishing a TCP connection // successfully, but discovering that it cannot use TLS while it // expected the server supports TLS. @@ -134,13 +136,13 @@ func MaybeDecorateError( return connRefused() } - if wErr := (*pq.Error)(nil); errors.As(err, &wErr) { + if wErr := (*pgconn.PgError)(nil); errors.As(err, &wErr) { // SQL commands will fail with a pq error but only after // establishing a TCP connection successfully. So if we got // here, there was a TCP connection already. // Did we fail due to security settings? - if pgcode.MakeCode(string(wErr.Code)) == pgcode.ProtocolViolation { + if pgcode.MakeCode(wErr.Code) == pgcode.ProtocolViolation { return connSecurityHint() } @@ -170,6 +172,16 @@ func MaybeDecorateError( return connFailed() } + // pgconn sometimes returns a context deadline error wrapped in a private + // connectError struct, which begins with this message. + if strings.HasPrefix(err.Error(), "failed to connect to") && + errors.IsAny(err, + context.DeadlineExceeded, + context.Canceled, + ) { + return connFailed() + } + if wErr := (*netutil.InitialHeartbeatFailedError)(nil); errors.As(err, &wErr) { // A GRPC TCP connection was established but there was an early failure. // Try to distinguish the cases. diff --git a/pkg/cli/clisqlcfg/context.go b/pkg/cli/clisqlcfg/context.go index 5c23ca9df834..fb5e1b9c3cd0 100644 --- a/pkg/cli/clisqlcfg/context.go +++ b/pkg/cli/clisqlcfg/context.go @@ -170,7 +170,7 @@ func (c *Context) MakeConn(url string) (clisqlclient.Conn, error) { // ensures that if the server was not initialized or there is some // network issue, the client will not be left to hang forever. // - // This is a lib/pq feature. + // This is a pgx feature. 
if baseURL.GetOption("connect_timeout") == "" && c.ConnectTimeout != 0 { _ = baseURL.SetOption("connect_timeout", strconv.Itoa(c.ConnectTimeout)) } @@ -181,18 +181,28 @@ func (c *Context) MakeConn(url string) (clisqlclient.Conn, error) { conn := c.ConnCtx.MakeSQLConn(c.CmdOut, c.CmdErr, url) conn.SetMissingPassword(!usePw || !pwdSet) + // By default, all connections will use the underlying driver to infer + // result types. This should be set back to false for any use case where the + // results are only shown for textual display. + conn.SetAlwaysInferResultTypes(true) + return conn, nil } // Run executes the SQL shell. -func (c *Context) Run(conn clisqlclient.Conn) error { +func (c *Context) Run(ctx context.Context, conn clisqlclient.Conn) error { if !c.opened { return errors.AssertionFailedf("programming error: Open not called yet") } + // Anything using a SQL shell (e.g. `cockroach sql` or `demo`), only needs + // to show results in text format, so the underlying driver doesn't need to + // infer types. + conn.SetAlwaysInferResultTypes(false) + // Open the connection to make sure everything is OK before running any // statements. Performs authentication. - if err := conn.EnsureConn(); err != nil { + if err := conn.EnsureConn(ctx); err != nil { return err } diff --git a/pkg/cli/clisqlclient/BUILD.bazel b/pkg/cli/clisqlclient/BUILD.bazel index 8affcf1e103c..ca46c1016e22 100644 --- a/pkg/cli/clisqlclient/BUILD.bazel +++ b/pkg/cli/clisqlclient/BUILD.bazel @@ -11,7 +11,9 @@ go_library( "init_conn_error.go", "make_query.go", "parse_bool.go", + "row_type_helpers.go", "rows.go", + "rows_multi.go", "statement_diag.go", "string_to_duration.go", "txn_shim.go", @@ -23,12 +25,14 @@ go_library( "//pkg/cli/clicfg", "//pkg/cli/clierror", "//pkg/security/pprompt", + "//pkg/sql/pgwire/pgcode", "//pkg/sql/scanner", "//pkg/util/version", "@com_github_cockroachdb_cockroach_go_v2//crdb", "@com_github_cockroachdb_errors//:errors", - "@com_github_lib_pq//:pq", - "@com_github_lib_pq_auth_kerberos//:kerberos", + "@com_github_jackc_pgconn//:pgconn", + "@com_github_jackc_pgtype//:pgtype", + "@com_github_jackc_pgx_v4//:pgx", "@com_github_otan_gopgkrb5//:gopgkrb5", ], ) diff --git a/pkg/cli/clisqlclient/api.go b/pkg/cli/clisqlclient/api.go index 0c27bd21c9e7..60f1dc5d2802 100644 --- a/pkg/cli/clisqlclient/api.go +++ b/pkg/cli/clisqlclient/api.go @@ -13,6 +13,7 @@ package clisqlclient import ( "context" "database/sql/driver" + "io" "reflect" "time" ) @@ -24,7 +25,7 @@ type Conn interface { Close() error // EnsureConn (re-)establishes the connection to the server. - EnsureConn() error + EnsureConn(ctx context.Context) error // Exec executes a statement. Exec(ctx context.Context, query string, args ...interface{}) error @@ -68,7 +69,12 @@ type Conn interface { // server requests one. SetMissingPassword(missing bool) - // GetServerMetadata() returns details about the CockroachDB node + // SetAlwaysInferResultTypes configures the alwaysInferResultTypes flag, which + // determines if the client should use the underlying driver to infer result + // types. + SetAlwaysInferResultTypes(b bool) + + // GetServerMetadata returns details about the CockroachDB node // this connection is connected to. GetServerMetadata(ctx context.Context) ( nodeID int32, @@ -85,18 +91,22 @@ type Conn interface { // The sql argument is the SQL query to use to retrieve the value. 
GetServerValue(ctx context.Context, what, sql string) (driver.Value, string, bool) - // GetDriverConn exposes the underlying SQL driver connection object + // GetDriverConn exposes the underlying driver connection object // for use by the cli package. GetDriverConn() DriverConn } // Rows describes a result set. type Rows interface { + driver.Rows + // The caller must call Close() when done with the // result and check the error. Close() error // Columns returns the column labels of the current result set. + // The implementation of this method should cache the result so that the + // result does not need to be constructed on each invocation. Columns() []string // ColumnTypeScanType returns the natural Go type of values at the @@ -111,11 +121,8 @@ type Rows interface { // columns. ColumnTypeNames() []string - // Result retrieves the underlying driver result object. - Result() driver.Result - // Tag retrieves the statement tag for the current result set. - Tag() string + Tag() (CommandTag, error) // Next populates values with the next row of results. []byte values are copied // so that subsequent calls to Next and Close do not mutate values. This @@ -130,6 +137,12 @@ type Rows interface { NextResultSet() (bool, error) } +// CommandTag represents the result of a SQL command. +type CommandTag interface { + RowsAffected() int64 + String() string +} + // QueryStatsDuration represents a duration value retrieved by // GetLastQueryStatistics. type QueryStatsDuration struct { @@ -168,10 +181,32 @@ type TxBoundConn interface { } // DriverConn is the type of the connection object returned by -// (Conn).GetDriverConn(). It gives access to the underlying Go sql +// (Conn).GetDriverConn(). It gives access to the underlying sql // driver. type DriverConn interface { - driver.Conn - driver.ExecerContext - driver.QueryerContext + Query(ctx context.Context, query string, args ...interface{}) (driver.Rows, error) + Exec(ctx context.Context, query string, args ...interface{}) error + CopyFrom(ctx context.Context, reader io.Reader, query string) error +} + +type driverConnAdapter struct { + c *sqlConn +} + +var _ DriverConn = (*driverConnAdapter)(nil) + +func (d *driverConnAdapter) Query( + ctx context.Context, query string, args ...interface{}, +) (driver.Rows, error) { + rows, err := d.c.Query(ctx, query, args...) + return driver.Rows(rows), err +} + +func (d *driverConnAdapter) Exec(ctx context.Context, query string, args ...interface{}) error { + return d.c.Exec(ctx, query, args...) 
+} + +func (d *driverConnAdapter) CopyFrom(ctx context.Context, reader io.Reader, query string) error { + _, err := d.c.conn.PgConn().CopyFrom(ctx, reader, query) + return err } diff --git a/pkg/cli/clisqlclient/conn.go b/pkg/cli/clisqlclient/conn.go index 86221c1ebe4a..9f6f938137fc 100644 --- a/pkg/cli/clisqlclient/conn.go +++ b/pkg/cli/clisqlclient/conn.go @@ -15,6 +15,7 @@ import ( "database/sql/driver" "fmt" "io" + "net" "net/url" "strconv" "strings" @@ -23,16 +24,17 @@ import ( "github.com/cockroachdb/cockroach/pkg/build" "github.com/cockroachdb/cockroach/pkg/cli/clierror" "github.com/cockroachdb/cockroach/pkg/security/pprompt" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/util/version" "github.com/cockroachdb/errors" - "github.com/lib/pq" - "github.com/lib/pq/auth/kerberos" - _ "github.com/otan/gopgkrb5" // need a comment until the dependency is used + "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4" + "github.com/otan/gopgkrb5" ) func init() { // Ensure that the CLI client commands can use GSSAPI authentication. - pq.RegisterGSSProvider(func() (pq.GSS, error) { return kerberos.NewGSS() }) + pgconn.RegisterGSSProvider(func() (pgconn.GSS, error) { return gopgkrb5.NewGSS() }) } type sqlConn struct { @@ -42,13 +44,18 @@ type sqlConn struct { connCtx *Context url string - conn DriverConn + conn *pgx.Conn reconnecting bool // passwordMissing is true iff the url is missing a password. passwordMissing bool - pendingNotices []*pq.Error + // alwaysInferResultTypes is true iff the client should always use the + // underlying driver to infer result types. If it is true, multiple statements + // in a single string cannot be executed. + alwaysInferResultTypes bool + + pendingNotices []*pgconn.Notice // delayNotices, if set, makes notices accumulate for printing // when the SQL execution completes. The default (false) @@ -82,6 +89,10 @@ type sqlConn struct { var _ Conn = (*sqlConn)(nil) +// ErrConnectionClosed is returned when an operation fails because the +// connection was closed. +var ErrConnectionClosed = errors.New("connection closed unexpectedly") + // wrapConnError detects TCP EOF errors during the initial SQL handshake. // These are translated to a message "perhaps this is not a CockroachDB node" // at the top level. @@ -90,7 +101,10 @@ var _ Conn = (*sqlConn)(nil) // server. func wrapConnError(err error) error { errMsg := err.Error() - if errMsg == "EOF" || errMsg == "unexpected EOF" { + // pgconn wraps some of these errors with the private connectError struct. + isPgconnConnectError := strings.HasPrefix(errMsg, "failed to connect to") + isEOF := errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) + if errMsg == "EOF" || errMsg == "unexpected EOF" || (isPgconnConnectError && isEOF) { return &InitialSQLConnectionError{err} } return err @@ -98,13 +112,13 @@ func wrapConnError(err error) error { func (c *sqlConn) flushNotices() { for _, notice := range c.pendingNotices { - clierror.OutputError(c.errw, notice, true /*showSeverity*/, false /*verbose*/) + clierror.OutputError(c.errw, (*pgconn.PgError)(notice), true /*showSeverity*/, false /*verbose*/) } c.pendingNotices = nil c.delayNotices = false } -func (c *sqlConn) handleNotice(notice *pq.Error) { +func (c *sqlConn) handleNotice(notice *pgconn.Notice) { c.pendingNotices = append(c.pendingNotices, notice) if !c.delayNotices { c.flushNotices() @@ -123,7 +137,7 @@ func (c *sqlConn) SetURL(url string) { // GetDriverConn implements the Conn interface. 
func (c *sqlConn) GetDriverConn() DriverConn { - return c.conn + return &driverConnAdapter{c} } // SetCurrentDatabase implements the Conn interface. @@ -136,39 +150,52 @@ func (c *sqlConn) SetMissingPassword(missing bool) { c.passwordMissing = missing } +// SetAlwaysInferResultTypes implements the Conn interface. +func (c *sqlConn) SetAlwaysInferResultTypes(b bool) { + c.alwaysInferResultTypes = b +} + +// The default pgx dialer uses a KeepAlive of 5 minutes, which we don't want. +var defaultDialer = &net.Dialer{} + // EnsureConn (re-)establishes the connection to the server. -func (c *sqlConn) EnsureConn() error { +func (c *sqlConn) EnsureConn(ctx context.Context) error { if c.conn != nil { return nil } - ctx := context.Background() if c.reconnecting && c.connCtx.IsInteractive() { fmt.Fprintf(c.errw, "warning: connection lost!\n"+ "opening new connection: all session settings will be lost\n") } - base, err := pq.NewConnector(c.url) + base, err := pgx.ParseConfig(c.url) if err != nil { return wrapConnError(err) } // Add a notice handler - re-use the cliOutputError function in this case. - connector := pq.ConnectorWithNoticeHandler(base, func(notice *pq.Error) { + base.OnNotice = func(_ *pgconn.PgConn, notice *pgconn.Notice) { c.handleNotice(notice) - }) - // TODO(cli): we can't thread ctx through ensureConn usages, as it needs - // to follow the gosql.DB interface. We should probably look at initializing - // connections only once instead. The context is only used for dialing. - conn, err := connector.Connect(ctx) + } + // By default, pgx uses a Dialer with a KeepAlive of 5 minutes, which is not + // desired here, so we override it. + defaultDialer.Timeout = 0 + if base.ConnectTimeout > 0 { + defaultDialer.Timeout = base.ConnectTimeout + } + base.DialFunc = defaultDialer.DialContext + // Override LookupFunc to be a no-op, so that the Dialer is responsible for + // resolving hostnames. This fixes an issue where pgx would error out too + // quickly if using TLS when an ipv6 address is resolved, but the networking + // stack does not support ipv6. + base.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) { + return []string{host}, nil + } + conn, err := pgx.ConnectConfig(ctx, base) if err != nil { // Connection failed: if the failure is due to a missing // password, we're going to fill the password here. - // - // TODO(knz): CockroachDB servers do not properly fill SQLSTATE - // (28P01) for password auth errors, so we have to "make do" - // with a string match. This should be cleaned up by adding - // the missing code server-side. - errStr := strings.TrimPrefix(err.Error(), "pq: ") - if strings.HasPrefix(errStr, "password authentication failed") && c.passwordMissing { + pgErr := (*pgconn.PgError)(nil) + if errors.As(err, &pgErr) && pgErr.Code == pgcode.InvalidPassword.String() && c.passwordMissing { if pErr := c.fillPassword(); pErr != nil { return errors.CombineErrors(err, pErr) } @@ -177,19 +204,18 @@ func (c *sqlConn) EnsureConn() error { // The recursion only occurs once because fillPassword() // resets c.passwordMissing, so we cannot get into this // conditional a second time. - return c.EnsureConn() + return c.EnsureConn(ctx) } // Not a password auth error, or password already set. Simply fail. return wrapConnError(err) } if c.reconnecting && c.dbName != "" { // Attempt to reset the current database. 
- if _, err := conn.(DriverConn).ExecContext(ctx, `SET DATABASE = $1`, - []driver.NamedValue{{Value: c.dbName}}); err != nil { + if _, err := conn.Exec(ctx, `SET DATABASE = $1`, c.dbName); err != nil { fmt.Fprintf(c.errw, "warning: unable to restore current database: %v\n", err) } } - c.conn = conn.(DriverConn) + c.conn = conn if err := c.checkServerMetadata(ctx); err != nil { err = errors.CombineErrors(err, c.Close()) return wrapConnError(err) @@ -243,18 +269,18 @@ func (c *sqlConn) tryEnableServerExecutionTimings(ctx context.Context) error { func (c *sqlConn) GetServerMetadata( ctx context.Context, -) (nodeID int32, version, clusterID string, err error) { +) (nodeID int32, version, clusterID string, retErr error) { // Retrieve the node ID and server build info. // Be careful to query against the empty database string, which avoids taking // a lease against the current database (in case it's currently unavailable). rows, err := c.Query(ctx, `SELECT * FROM "".crdb_internal.node_build_info`) - if errors.Is(err, driver.ErrBadConn) { - return 0, "", "", err + if c.conn.IsClosed() { + return 0, "", "", MarkWithConnectionClosed(err) } if err != nil { return 0, "", "", err } - defer func() { _ = rows.Close() }() + defer func() { retErr = errors.CombineErrors(retErr, rows.Close()) }() // Read the node_build_info table as an array of strings. rowVals, err := getServerMetadataRows(rows) @@ -354,8 +380,8 @@ func (c *sqlConn) checkServerMetadata(ctx context.Context) error { } _, newServerVersion, newClusterID, err := c.GetServerMetadata(ctx) - if errors.Is(err, driver.ErrBadConn) { - return err + if c.conn.IsClosed() { + return MarkWithConnectionClosed(err) } if err != nil { // It is not an error that the server version cannot be retrieved. @@ -511,45 +537,64 @@ func (c *sqlConn) ExecTxn( } func (c *sqlConn) Exec(ctx context.Context, query string, args ...interface{}) error { - dVals, err := convertArgs(args) - if err != nil { - return err - } - if err := c.EnsureConn(); err != nil { + if err := c.EnsureConn(ctx); err != nil { return err } if c.connCtx.Echo { fmt.Fprintln(c.errw, ">", query) } - _, err = c.conn.ExecContext(ctx, query, dVals) + _, err := c.conn.Exec(ctx, query, args...) c.flushNotices() - if errors.Is(err, driver.ErrBadConn) { + if c.conn.IsClosed() { c.reconnecting = true c.silentClose() + return MarkWithConnectionClosed(err) } return err } func (c *sqlConn) Query(ctx context.Context, query string, args ...interface{}) (Rows, error) { - dVals, err := convertArgs(args) - if err != nil { - return nil, err - } - if err := c.EnsureConn(); err != nil { + if err := c.EnsureConn(ctx); err != nil { return nil, err } if c.connCtx.Echo { fmt.Fprintln(c.errw, ">", query) } - rows, err := c.conn.QueryContext(ctx, query, dVals) - if errors.Is(err, driver.ErrBadConn) { + + // If there are placeholder args, then we must use a prepared statement, + // which is only possible using pgx.Conn. + // Or, if alwaysInferResultTypes is set, then we must use pgx.Conn so the + // result types are automatically inferred. + if len(args) > 0 || c.alwaysInferResultTypes { + rows, err := c.conn.Query(ctx, query, args...) + if c.conn.IsClosed() { + c.reconnecting = true + c.silentClose() + return nil, MarkWithConnectionClosed(err) + } + if err != nil { + return nil, err + } + return &sqlRows{rows: rows, connInfo: c.conn.ConnInfo(), conn: c}, nil + } + + // Otherwise, we use pgconn. This allows us to add support for multiple + // queries in a single string, which wouldn't be possible at the pgx level. 
+ multiResultReader := c.conn.PgConn().Exec(ctx, query) + if c.conn.IsClosed() { c.reconnecting = true c.silentClose() + return nil, MarkWithConnectionClosed(multiResultReader.Close()) } - if err != nil { + rs := &sqlRowsMultiResultSet{ + rows: multiResultReader, + connInfo: c.conn.ConnInfo(), + conn: c, + } + if _, err := rs.NextResultSet(); err != nil { return nil, err } - return &sqlRows{rows: rows.(sqlRowsI), conn: c}, nil + return rs, nil } func (c *sqlConn) QueryRow( @@ -589,7 +634,7 @@ func (c *sqlConn) queryRowInternal( func (c *sqlConn) Close() error { c.flushNotices() if c.conn != nil { - err := c.conn.Close() + err := c.conn.Close(context.Background()) if err != nil { return err } @@ -600,7 +645,7 @@ func (c *sqlConn) Close() error { func (c *sqlConn) silentClose() { if c.conn != nil { - _ = c.conn.Close() + _ = c.conn.Close(context.Background()) c.conn = nil } } @@ -645,3 +690,20 @@ func (c *sqlConn) fillPassword() error { c.passwordMissing = false return nil } + +// MarkWithConnectionClosed is a mix of errors.CombineErrors() and errors.Mark(). +// If err is nil, the result is like errors.CombineErrors(). +// If err is non-nil, the result is that of errors.Mark(), with the error +// message added as a prefix. +// +// In both cases, errors.Is(..., ErrConnectionClosed) returns true on the +// result. +func MarkWithConnectionClosed(err error) error { + if err == nil { + return ErrConnectionClosed + } + // Disable the linter since we are intentionally adding the error text. + // nolint:errwrap + errWithMsg := errors.WithMessagef(err, "%v", ErrConnectionClosed) + return errors.Mark(errWithMsg, ErrConnectionClosed) +} diff --git a/pkg/cli/clisqlclient/conn_test.go b/pkg/cli/clisqlclient/conn_test.go index e047fb5e684b..73bf2e9a9dc8 100644 --- a/pkg/cli/clisqlclient/conn_test.go +++ b/pkg/cli/clisqlclient/conn_test.go @@ -12,7 +12,6 @@ package clisqlclient_test import ( "context" - "database/sql/driver" "io/ioutil" "net/url" "testing" @@ -70,8 +69,8 @@ func TestConnRecover(t *testing.T) { if closeErr := sqlRows.Close(); closeErr != nil { t.Fatal(closeErr) } - } else if !errors.Is(err, driver.ErrBadConn) { - return errors.Newf("expected ErrBadConn, got %v", err) // nolint:errwrap + } else if !errors.Is(err, clisqlclient.ErrConnectionClosed) { + return errors.Newf("expected ErrConnectionClosed, got %v", err) // nolint:errwrap } return nil }) @@ -90,8 +89,8 @@ func TestConnRecover(t *testing.T) { // Ditto from Query(). testutils.SucceedsSoon(t, func() error { - if err := conn.Exec(ctx, `SELECT 1`); !errors.Is(err, driver.ErrBadConn) { - return errors.Newf("expected ErrBadConn, got %v", err) // nolint:errwrap + if err := conn.Exec(ctx, `SELECT 1`); !errors.Is(err, clisqlclient.ErrConnectionClosed) { + return errors.Newf("expected ErrConnectionClosed, got %v", err) // nolint:errwrap } return nil }) diff --git a/pkg/cli/clisqlclient/copy.go b/pkg/cli/clisqlclient/copy.go index 733aaaefeba6..ab93894525f3 100644 --- a/pkg/cli/clisqlclient/copy.go +++ b/pkg/cli/clisqlclient/copy.go @@ -11,43 +11,38 @@ package clisqlclient import ( + "bytes" "context" "database/sql/driver" "io" "reflect" - "strings" - "github.com/cockroachdb/errors" + "github.com/jackc/pgconn" ) -type copyFromer interface { - CopyData(ctx context.Context, line string) (r driver.Result, err error) - Exec(v []driver.Value) (r driver.Result, err error) - Close() error -} - // CopyFromState represents an in progress COPY FROM. 
type CopyFromState struct { - driver.Tx - copyFromer + conn *pgconn.PgConn + query string } // BeginCopyFrom starts a COPY FROM query. func BeginCopyFrom(ctx context.Context, conn Conn, query string) (*CopyFromState, error) { - txn, err := conn.(*sqlConn).conn.(driver.ConnBeginTx).BeginTx(ctx, driver.TxOptions{}) - if err != nil { + copyConn := conn.(*sqlConn).conn.PgConn() + // Run the initial query, but don't use the result so that we can get any + // errors early. + if _, err := copyConn.CopyFrom(ctx, bytes.NewReader([]byte{}), query); err != nil { return nil, err } - stmt, err := txn.(driver.Conn).Prepare(query) - if err != nil { - return nil, errors.CombineErrors(err, txn.Rollback()) - } - return &CopyFromState{Tx: txn, copyFromer: stmt.(copyFromer)}, nil + return &CopyFromState{ + conn: copyConn, + query: query, + }, nil } // copyFromRows is a mock Rows interface for COPY results. type copyFromRows struct { - r driver.Result + t pgconn.CommandTag } func (c copyFromRows) Close() error { @@ -70,12 +65,8 @@ func (c copyFromRows) ColumnTypeNames() []string { return nil } -func (c copyFromRows) Result() driver.Result { - return c.r -} - -func (c copyFromRows) Tag() string { - return "COPY" +func (c copyFromRows) Tag() (CommandTag, error) { + return c.t, nil } func (c copyFromRows) Next(values []driver.Value) error { @@ -88,7 +79,7 @@ func (c copyFromRows) NextResultSet() (bool, error) { // Cancel cancels a COPY FROM query from completing. func (c *CopyFromState) Cancel() error { - return errors.CombineErrors(c.copyFromer.Close(), c.Tx.Rollback()) + return nil } // Commit completes a COPY FROM query by committing lines to the database. @@ -96,21 +87,13 @@ func (c *CopyFromState) Commit(ctx context.Context, cleanupFunc func(), lines st return func(ctx context.Context, conn Conn) (Rows, bool, error) { defer cleanupFunc() rows, isMulti, err := func() (Rows, bool, error) { - for _, l := range strings.Split(lines, "\n") { - _, err := c.copyFromer.CopyData(ctx, l) - if err != nil { - return nil, false, err - } - } - r, err := c.copyFromer.Exec(nil) + r := bytes.NewReader([]byte(lines)) + tag, err := c.conn.CopyFrom(ctx, r, c.query) if err != nil { return nil, false, err } - return copyFromRows{r: r}, false, c.Tx.Commit() + return copyFromRows{tag}, false, nil }() - if err != nil { - return rows, isMulti, errors.CombineErrors(err, errors.CombineErrors(c.copyFromer.Close(), c.Tx.Rollback())) - } return rows, isMulti, err } } diff --git a/pkg/cli/clisqlclient/make_query.go b/pkg/cli/clisqlclient/make_query.go index e43a3f3e02a3..7097f474c166 100644 --- a/pkg/cli/clisqlclient/make_query.go +++ b/pkg/cli/clisqlclient/make_query.go @@ -12,7 +12,6 @@ package clisqlclient import ( "context" - "database/sql/driver" "github.com/cockroachdb/cockroach/pkg/sql/scanner" ) @@ -29,22 +28,3 @@ func MakeQuery(query string, parameters ...interface{}) QueryFn { return rows, isMultiStatementQuery, err } } - -func convertArgs(parameters []interface{}) ([]driver.NamedValue, error) { - dVals := make([]driver.NamedValue, len(parameters)) - for i := range parameters { - // driver.NamedValue.Value is an alias for interface{}, but must adhere to a restricted - // set of types when being passed to driver.Queryer.Query (see - // driver.IsValue). We use driver.DefaultParameterConverter to perform the - // necessary conversion. This is usually taken care of by the sql package, - // but we have to do so manually because we're talking directly to the - // driver. 
- var err error - dVals[i].Ordinal = i + 1 - dVals[i].Value, err = driver.DefaultParameterConverter.ConvertValue(parameters[i]) - if err != nil { - return nil, err - } - } - return dVals, nil -} diff --git a/pkg/cli/clisqlclient/row_type_helpers.go b/pkg/cli/clisqlclient/row_type_helpers.go new file mode 100644 index 000000000000..39e728b51867 --- /dev/null +++ b/pkg/cli/clisqlclient/row_type_helpers.go @@ -0,0 +1,72 @@ +// Copyright 2022 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package clisqlclient + +import ( + "reflect" + "strings" + "time" + + "github.com/jackc/pgtype" +) + +// scanType returns the Go type used for the given column type. +func scanType(typeOID uint32, typeName string) reflect.Type { + if typeName == "" || strings.HasPrefix(typeName, "_") { + // User-defined types and array types are scanned into []byte. + // These are handled separately since we can't easily include all the OIDs + // for user-defined types and arrays in the switch statement. + return reflect.TypeOf([]byte(nil)) + } + // This switch is copied from lib/pq, and modified with a few additional cases. + // https://github.com/lib/pq/blob/b2901c7946b69f1e7226214f9760e31620499595/rows.go#L24 + switch typeOID { + case pgtype.Int8OID: + return reflect.TypeOf(int64(0)) + case pgtype.Int4OID: + return reflect.TypeOf(int32(0)) + case pgtype.Int2OID: + return reflect.TypeOf(int16(0)) + case pgtype.VarcharOID, pgtype.TextOID, pgtype.NameOID, + pgtype.ByteaOID, pgtype.NumericOID, pgtype.RecordOID, + pgtype.QCharOID, pgtype.BPCharOID: + return reflect.TypeOf("") + case pgtype.BoolOID: + return reflect.TypeOf(false) + case pgtype.DateOID, pgtype.TimeOID, 1266, pgtype.TimestampOID, pgtype.TimestamptzOID: + // 1266 is the OID for TimeTZ. + // TODO(rafi): Add TimetzOID to pgtype. + return reflect.TypeOf(time.Time{}) + default: + return reflect.TypeOf(new(interface{})).Elem() + } +} + +// databaseTypeName returns the database type name for the given type OID. +func databaseTypeName(ci *pgtype.ConnInfo, typeOID uint32) string { + dataType, ok := ci.DataTypeForOID(typeOID) + if !ok { + // TODO(rafi): remove special logic once jackc/pgtype includes these types. + switch typeOID { + case 1002: + return "_CHAR" + case 1003: + return "_NAME" + case 1266: + return "TIMETZ" + case 1270: + return "_TIMETZ" + default: + return "" + } + } + return strings.ToUpper(dataType.Name) +} diff --git a/pkg/cli/clisqlclient/rows.go b/pkg/cli/clisqlclient/rows.go index 270098763b4e..0481d2763fa1 100644 --- a/pkg/cli/clisqlclient/rows.go +++ b/pkg/cli/clisqlclient/rows.go @@ -12,62 +12,68 @@ package clisqlclient import ( "database/sql/driver" + "io" "reflect" - "github.com/cockroachdb/errors" + "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4" ) type sqlRows struct { - rows sqlRowsI - conn *sqlConn + rows pgx.Rows + connInfo *pgtype.ConnInfo + conn *sqlConn + colNames []string } var _ Rows = (*sqlRows)(nil) -type sqlRowsI interface { - driver.RowsColumnTypeScanType - driver.RowsColumnTypeDatabaseTypeName - Result() driver.Result - Tag() string - - // Go 1.8 multiple result set interfaces. - // TODO(mjibson): clean this up after 1.8 is released. 
- HasNextResultSet() bool - NextResultSet() error -} - func (r *sqlRows) Columns() []string { - return r.rows.Columns() -} - -func (r *sqlRows) Result() driver.Result { - return r.rows.Result() + if r.colNames == nil { + fields := r.rows.FieldDescriptions() + r.colNames = make([]string, len(fields)) + for i, fd := range fields { + r.colNames[i] = string(fd.Name) + } + } + return r.colNames } -func (r *sqlRows) Tag() string { - return r.rows.Tag() +func (r *sqlRows) Tag() (CommandTag, error) { + return r.rows.CommandTag(), r.rows.Err() } func (r *sqlRows) Close() error { r.conn.flushNotices() - err := r.rows.Close() - if errors.Is(err, driver.ErrBadConn) { + r.rows.Close() + if r.conn.conn.IsClosed() { r.conn.reconnecting = true r.conn.silentClose() + return ErrConnectionClosed } - return err + return r.rows.Err() } // Next implements the Rows interface. func (r *sqlRows) Next(values []driver.Value) error { - err := r.rows.Next(values) - if errors.Is(err, driver.ErrBadConn) { + if r.conn.conn.IsClosed() { r.conn.reconnecting = true r.conn.silentClose() + return ErrConnectionClosed + } + if !r.rows.Next() { + return io.EOF } - for i, v := range values { - if b, ok := v.([]byte); ok { + rawVals, err := r.rows.Values() + if err != nil { + return err + } + for i, v := range rawVals { + if b, ok := (v).([]byte); ok { + // Copy byte slices as per the comment on Rows.Next. values[i] = append([]byte{}, b...) + } else { + values[i] = v } } // After the first row was received, we want to delay all @@ -78,18 +84,18 @@ func (r *sqlRows) Next(values []driver.Value) error { // NextResultSet prepares the next result set for reading. func (r *sqlRows) NextResultSet() (bool, error) { - if !r.rows.HasNextResultSet() { - return false, nil - } - return true, r.rows.NextResultSet() + return false, nil } func (r *sqlRows) ColumnTypeScanType(index int) reflect.Type { - return r.rows.ColumnTypeScanType(index) + o := r.rows.FieldDescriptions()[index].DataTypeOID + n := r.ColumnTypeDatabaseTypeName(index) + return scanType(o, n) } func (r *sqlRows) ColumnTypeDatabaseTypeName(index int) string { - return r.rows.ColumnTypeDatabaseTypeName(index) + fieldOID := r.rows.FieldDescriptions()[index].DataTypeOID + return databaseTypeName(r.connInfo, fieldOID) } func (r *sqlRows) ColumnTypeNames() []string { diff --git a/pkg/cli/clisqlclient/rows_multi.go b/pkg/cli/clisqlclient/rows_multi.go new file mode 100644 index 000000000000..55a452ef73d6 --- /dev/null +++ b/pkg/cli/clisqlclient/rows_multi.go @@ -0,0 +1,157 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package clisqlclient + +import ( + "database/sql/driver" + "io" + "reflect" + + "github.com/cockroachdb/errors" + "github.com/jackc/pgconn" + "github.com/jackc/pgtype" + "github.com/jackc/pgx/v4" +) + +type sqlRowsMultiResultSet struct { + rows *pgconn.MultiResultReader + connInfo *pgtype.ConnInfo + conn *sqlConn + colNames []string +} + +var _ Rows = (*sqlRowsMultiResultSet)(nil) + +func (r *sqlRowsMultiResultSet) Columns() []string { + if r.colNames == nil { + rr := r.rows.ResultReader() + if rr == nil { + // ResultReader may be nil if an empty query was executed. 
+ return nil + } + fields := rr.FieldDescriptions() + r.colNames = make([]string, len(fields)) + for i, fd := range fields { + r.colNames[i] = string(fd.Name) + } + } + return r.colNames +} + +func (r *sqlRowsMultiResultSet) Tag() (CommandTag, error) { + if rr := r.rows.ResultReader(); rr != nil { + // ResultReader may be nil if an empty query was executed. + return r.rows.ResultReader().Close() + } + return pgconn.CommandTag(""), nil +} + +func (r *sqlRowsMultiResultSet) Close() (retErr error) { + r.conn.flushNotices() + if rr := r.rows.ResultReader(); rr != nil { + // ResultReader may be nil if an empty query was executed. + _, retErr = r.rows.ResultReader().Close() + } + retErr = errors.CombineErrors(retErr, r.rows.Close()) + if r.conn.conn.IsClosed() { + r.conn.reconnecting = true + r.conn.silentClose() + return MarkWithConnectionClosed(retErr) + } + return retErr +} + +// Next implements the Rows interface. +func (r *sqlRowsMultiResultSet) Next(values []driver.Value) error { + if r.conn.conn.IsClosed() { + r.conn.reconnecting = true + r.conn.silentClose() + return ErrConnectionClosed + } + rd := r.rows.ResultReader() + if rd == nil { + // ResultReader may be nil if an empty query was executed. + return io.EOF + } + if !rd.NextRow() { + if _, err := rd.Close(); err != nil { + return err + } + return io.EOF + } + if len(rd.FieldDescriptions()) != len(values) { + return errors.AssertionFailedf( + "number of field descriptions must equal number of destinations, got %d and %d", + len(rd.FieldDescriptions()), + len(values), + ) + } + for i := range values { + rowVal := rd.Values()[i] + if rowVal == nil { + values[i] = "NULL" + continue + } + fieldOID := rd.FieldDescriptions()[i].DataTypeOID + fieldFormat := rd.FieldDescriptions()[i].Format + + // By scanning the value into a string, pgconn will use the pgwire + // text format to represent the value. + var s string + err := r.connInfo.Scan(fieldOID, fieldFormat, rowVal, &s) + if err != nil { + return pgx.ScanArgError{ColumnIndex: i, Err: err} + } + values[i] = s + } + // After the first row was received, we want to delay all + // further notices until the end of execution. + r.conn.delayNotices = true + return nil +} + +// NextResultSet prepares the next result set for reading. 
+func (r *sqlRowsMultiResultSet) NextResultSet() (bool, error) { + r.colNames = nil + next := r.rows.NextResult() + if !next { + if err := r.rows.Close(); err != nil { + return false, err + } + } + if r.conn.conn.IsClosed() { + r.conn.reconnecting = true + r.conn.silentClose() + return false, ErrConnectionClosed + } + return next, nil +} + +func (r *sqlRowsMultiResultSet) ColumnTypeScanType(index int) reflect.Type { + rd := r.rows.ResultReader() + o := rd.FieldDescriptions()[index].DataTypeOID + n := r.ColumnTypeDatabaseTypeName(index) + return scanType(o, n) +} + +func (r *sqlRowsMultiResultSet) ColumnTypeDatabaseTypeName(index int) string { + rd := r.rows.ResultReader() + fieldOID := rd.FieldDescriptions()[index].DataTypeOID + return databaseTypeName(r.connInfo, fieldOID) +} + +func (r *sqlRowsMultiResultSet) ColumnTypeNames() []string { + colTypes := make([]string, len(r.Columns())) + for i := range colTypes { + colTypes[i] = r.ColumnTypeDatabaseTypeName(i) + } + return colTypes +} diff --git a/pkg/cli/clisqlexec/BUILD.bazel b/pkg/cli/clisqlexec/BUILD.bazel index 720b621f0c6e..f4f2a6be749b 100644 --- a/pkg/cli/clisqlexec/BUILD.bazel +++ b/pkg/cli/clisqlexec/BUILD.bazel @@ -23,7 +23,7 @@ go_library( "//pkg/util/syncutil", "//pkg/util/timeutil", "@com_github_cockroachdb_errors//:errors", - "@com_github_lib_pq//:pq", + "@com_github_jackc_pgtype//:pgtype", "@com_github_olekukonko_tablewriter//:tablewriter", "@com_github_spf13_pflag//:pflag", ], diff --git a/pkg/cli/clisqlexec/format_table.go b/pkg/cli/clisqlexec/format_table.go index f2e8d20bec96..28a1f7fca6d7 100644 --- a/pkg/cli/clisqlexec/format_table.go +++ b/pkg/cli/clisqlexec/format_table.go @@ -181,13 +181,13 @@ func render( iter RowStrIter, completedHook func(), noRowsHook func() (bool, error), -) (err error) { +) (retErr error) { described := false nRows := 0 defer func() { // If the column headers are not printed yet, do it now. if !described { - err = errors.WithSecondaryError(err, r.describe(w, cols)) + retErr = errors.CombineErrors(retErr, r.describe(w, cols)) } // completedHook, if provided, is called unconditionally of error. @@ -198,18 +198,20 @@ func render( // We need to call doneNoRows/doneRows also unconditionally. 
var handled bool if nRows == 0 && noRowsHook != nil { - handled, err = noRowsHook() - if err != nil { + var noRowsErr error + handled, noRowsErr = noRowsHook() + if noRowsErr != nil { + retErr = errors.CombineErrors(retErr, noRowsErr) return } } if handled { - err = errors.WithSecondaryError(err, r.doneNoRows(w)) + retErr = errors.CombineErrors(retErr, r.doneNoRows(w)) } else { - err = errors.WithSecondaryError(err, r.doneRows(w, nRows)) + retErr = errors.CombineErrors(retErr, r.doneRows(w, nRows)) } - if err != nil && nRows > 0 { + if retErr != nil && nRows > 0 { fmt.Fprintf(ew, "(error encountered after some results were delivered)\n") } }() diff --git a/pkg/cli/clisqlexec/format_table_test.go b/pkg/cli/clisqlexec/format_table_test.go index 92a27b0a6c7b..0e22f0c5a962 100644 --- a/pkg/cli/clisqlexec/format_table_test.go +++ b/pkg/cli/clisqlexec/format_table_test.go @@ -64,7 +64,7 @@ thenshort`, // thenshort" int, "κόσμε" int, "a|b" int, ܈85 int) // CREATE TABLE // sql -e insert into t.u values (0, 0, 0, 0, 0, 0, 0, 0) - // INSERT 1 + // INSERT 0 1 // sql -e show columns from t.u // column_name data_type is_nullable column_default generation_expression indices // "f""oo" INT true NULL {} @@ -193,7 +193,11 @@ func Example_sql_empty_table() { // Output: // sql -e create database t;create table t.norows(x int);create table t.nocolsnorows();create table t.nocols(); insert into t.nocols(rowid) values (1),(2),(3); - // INSERT 3 + // CREATE DATABASE + // CREATE TABLE + // CREATE TABLE + // CREATE TABLE + // INSERT 0 3 // sql --format=tsv -e select * from t.norows // x // sql --format=csv -e select * from t.norows @@ -489,25 +493,26 @@ func Example_sql_table() { // Output: // sql -e create database t; create table t.t (s string, d string); + // CREATE DATABASE // CREATE TABLE // sql -e insert into t.t values (e'foo', 'printable ASCII') - // INSERT 1 + // INSERT 0 1 // sql -e insert into t.t values (e'"foo', 'printable ASCII with quotes') - // INSERT 1 + // INSERT 0 1 // sql -e insert into t.t values (e'\\foo', 'printable ASCII with backslash') - // INSERT 1 + // INSERT 0 1 // sql -e insert into t.t values (e'foo\x0abar', 'non-printable ASCII') - // INSERT 1 + // INSERT 0 1 // sql -e insert into t.t values ('κόσμε', 'printable UTF8') - // INSERT 1 + // INSERT 0 1 // sql -e insert into t.t values (e'\xc3\xb1', 'printable UTF8 using escapes') - // INSERT 1 + // INSERT 0 1 // sql -e insert into t.t values (e'\x01', 'non-printable UTF8 string') - // INSERT 1 + // INSERT 0 1 // sql -e insert into t.t values (e'\xdc\x88\x38\x35', 'UTF8 string with RTL char') - // INSERT 1 + // INSERT 0 1 // sql -e insert into t.t values (e'a\tb\tc\n12\t123123213\t12313', 'tabs') - // INSERT 1 + // INSERT 0 1 // sql -e insert into t.t values (e'\xc3\x28', 'non-UTF8 string') // ERROR: lexical error: invalid UTF-8 byte sequence // SQLSTATE: 42601 diff --git a/pkg/cli/clisqlexec/format_value.go b/pkg/cli/clisqlexec/format_value.go index 6697d8b48ad8..41449c6d09d9 100644 --- a/pkg/cli/clisqlexec/format_value.go +++ b/pkg/cli/clisqlexec/format_value.go @@ -12,7 +12,6 @@ package clisqlexec import ( "bytes" - gosql "database/sql" "database/sql/driver" "fmt" "math" @@ -26,7 +25,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/lexbase" "github.com/cockroachdb/cockroach/pkg/util/timeutil" - "github.com/lib/pq" + "github.com/jackc/pgtype" ) func isNotPrintableASCII(r rune) bool { return r < 0x20 || r > 0x7e || r == '"' || r == '\\' } @@ -44,7 +43,6 @@ func FormatVal( if strings.HasPrefix(colType, "_") && len(b) > 0 && b[0] == 
'{' { return formatArray(b, colType[1:], showPrintableUnicode, showNewLinesAndTabs) } - // Names, records, and user-defined types should all be displayed as strings. if colType == "NAME" || colType == "RECORD" || colType == "" { val = string(b) @@ -88,19 +86,13 @@ func FormatVal( return s[1 : len(s)-1] case []byte: - // Format the bytes as per bytea_output = escape. - // - // We use the "escape" format here because it enables printing - // readable strings as-is -- the default hex format would always - // render as hexadecimal digits. The escape format is also more - // compact. + // For other []byte types that weren't handled above, we use the "escape" + // format because it enables printing readable strings as-is -- the default + // hex format would always render as hexadecimal digits. The escape format + // is also more compact. // - // TODO(knz): this formatting is unfortunate/incorrect, and exists - // only because lib/pq incorrectly interprets the bytes received - // from the server. The proper behavior would be for the driver to - // not interpret the bytes and for us here to print that as-is, so - // that we can let the user see and control the result using - // `bytea_output`. + // Note that the BYTEA type is already a string at this point, so is not + // handled here. var buf bytes.Buffer lexbase.EncodeSQLBytesInner(&buf, string(t)) return buf.String() @@ -130,41 +122,45 @@ func formatArray( // backingArray is the array we're going to parse the server data // into. var backingArray interface{} - // parsingArray is a helper structure provided by lib/pq to parse + + // parsingArray is a helper structure provided by pgtype to parse // arrays. - var parsingArray gosql.Scanner + var parsingArray pgtype.Value - // lib.pq has different array parsers for special value types. - // - // TODO(knz): This would better use a general-purpose parser - // using the OID to look up an array parser in crdb's sql package. - // However, unfortunately the OID is hidden from us. + // pgx has different array parsers for special value types. switch colType { case "BOOL": boolArray := []bool{} backingArray = &boolArray - parsingArray = (*pq.BoolArray)(&boolArray) + parsingArray = &pgtype.BoolArray{} case "FLOAT4", "FLOAT8": floatArray := []float64{} backingArray = &floatArray - parsingArray = (*pq.Float64Array)(&floatArray) + parsingArray = &pgtype.Float8Array{} case "INT2", "INT4", "INT8", "OID": intArray := []int64{} backingArray = &intArray - parsingArray = (*pq.Int64Array)(&intArray) + parsingArray = &pgtype.Int8Array{} case "TEXT", "VARCHAR", "NAME", "CHAR", "BPCHAR", "RECORD": stringArray := []string{} backingArray = &stringArray - parsingArray = (*pq.StringArray)(&stringArray) - default: - genArray := [][]byte{} - backingArray = &genArray - parsingArray = &pq.GenericArray{A: &genArray} + parsingArray = &pgtype.TextArray{} } - // Now ask the pq array parser to convert the byte slice + // Now ask the pgx array parser to convert the byte slice // from the server into a Go array. - if err := parsingArray.Scan(b); err != nil { + var parseErr error + if parsingArray != nil { + parseErr = parsingArray.(pgtype.TextDecoder).DecodeText(nil, b) + if parseErr == nil { + parseErr = parsingArray.AssignTo(backingArray) + } + } else { + var untypedArray *pgtype.UntypedTextArray + untypedArray, parseErr = pgtype.ParseUntypedTextArray(string(b)) + backingArray = &untypedArray.Elements + } + if parseErr != nil { // A parsing failure is not a catastrophe; we can still print out // the array as a byte slice. 
This will do in many cases. return FormatVal(b, "BYTEA", showPrintableUnicode, showNewLinesAndTabs) diff --git a/pkg/cli/clisqlexec/format_value_test.go b/pkg/cli/clisqlexec/format_value_test.go index 5c93acef4561..d792f7f4a9cb 100644 --- a/pkg/cli/clisqlexec/format_value_test.go +++ b/pkg/cli/clisqlexec/format_value_test.go @@ -38,12 +38,22 @@ func Example_sql_format() { c.RunWithArgs([]string{"sql", "-e", `select '{"(n,1)","(\n,2)","(\\n,3)"}'::tup[]`}) c.RunWithArgs([]string{"sql", "-e", `create type e as enum('a', '\n')`}) c.RunWithArgs([]string{"sql", "-e", `select '\n'::e, '{a, "\\n"}'::e[]`}) + // Check that intervals are formatted correctly. + c.RunWithArgs([]string{"sql", "-e", `select '1 day 2 minutes'::interval`}) + c.RunWithArgs([]string{"sql", "-e", `select '{3 months 4 days 2 hours, 4 years 1 second}'::interval[]`}) + c.RunWithArgs([]string{"sql", "-e", `SET intervalstyle = sql_standard; select '1 day 2 minutes'::interval`}) + c.RunWithArgs([]string{"sql", "-e", `SET intervalstyle = sql_standard; select '{3 months 4 days 2 hours,4 years 1 second}'::interval[]`}) + c.RunWithArgs([]string{"sql", "-e", `SET intervalstyle = iso_8601; select '1 day 2 minutes'::interval`}) + c.RunWithArgs([]string{"sql", "-e", `SET intervalstyle = iso_8601; select '{3 months 4 days 2 hours,4 years 1 second}'::interval[]`}) + // Check that UUIDs are formatted correctly. + c.RunWithArgs([]string{"sql", "-e", `select 'f2b046d8-dc59-11ec-8b22-cbf9a9dd2f5f'::uuid`}) // Output: // sql -e create database t; create table t.times (bare timestamp, withtz timestamptz) + // CREATE DATABASE // CREATE TABLE // sql -e insert into t.times values ('2016-01-25 10:10:10', '2016-01-25 10:10:10-05:00') - // INSERT 1 + // INSERT 0 1 // sql -e select bare from t.times; select withtz from t.times // bare // 2016-01-25 10:10:10 @@ -81,7 +91,7 @@ func Example_sql_format() { // {Infinity,-Infinity} {Infinity} // sql -e select array[true, false], array['01:01'::time], array['2021-03-20'::date] // array array array - // {true,false} {01:01:00} {2021-03-20} + // {t,f} {01:01:00} {2021-03-20} // sql -e select array[123::int2], array[123::int4], array[123::int8] // array array array // {123} {123} {123} @@ -98,4 +108,29 @@ func Example_sql_format() { // sql -e select '\n'::e, '{a, "\\n"}'::e[] // e e // \n "{a,""\\n""}" + // sql -e select '1 day 2 minutes'::interval + // interval + // 1 day 00:02:00 + // sql -e select '{3 months 4 days 2 hours, 4 years 1 second}'::interval[] + // interval + // "{""3 mons 4 days 02:00:00"",""4 years 00:00:01""}" + // sql -e SET intervalstyle = sql_standard; select '1 day 2 minutes'::interval + // SET + // interval + // 1 0:02:00 + // sql -e SET intervalstyle = sql_standard; select '{3 months 4 days 2 hours,4 years 1 second}'::interval[] + // SET + // interval + // "{""+0-3 +4 +2:00:00"",""+4-0 +0 +0:00:01""}" + // sql -e SET intervalstyle = iso_8601; select '1 day 2 minutes'::interval + // SET + // interval + // P1DT2M + // sql -e SET intervalstyle = iso_8601; select '{3 months 4 days 2 hours,4 years 1 second}'::interval[] + // SET + // interval + // {P3M4DT2H,P4YT1S} + // sql -e select 'f2b046d8-dc59-11ec-8b22-cbf9a9dd2f5f'::uuid + // uuid + // f2b046d8-dc59-11ec-8b22-cbf9a9dd2f5f } diff --git a/pkg/cli/clisqlexec/run_query.go b/pkg/cli/clisqlexec/run_query.go index 3754a36d6169..9f52f3112f6f 100644 --- a/pkg/cli/clisqlexec/run_query.go +++ b/pkg/cli/clisqlexec/run_query.go @@ -12,7 +12,6 @@ package clisqlexec import ( "context" - "database/sql/driver" "fmt" "io" "strings" @@ -28,13 +27,13 @@ 
import ( // It runs the sql query and returns a list of columns names and a list of rows. func (sqlExecCtx *Context) RunQuery( ctx context.Context, conn clisqlclient.Conn, fn clisqlclient.QueryFn, showMoreChars bool, -) ([]string, [][]string, error) { +) (retCols []string, retRows [][]string, retErr error) { rows, _, err := fn(ctx, conn) if err != nil { return nil, nil, err } - defer func() { _ = rows.Close() }() + defer func() { retErr = errors.CombineErrors(retErr, rows.Close()) }() return sqlRowsToStrings(rows, showMoreChars) } @@ -54,7 +53,7 @@ func (sqlExecCtx *Context) RunQueryAndFormatResults( err = errors.CombineErrors(err, closeErr) }() for { - // lib/pq is not able to tell us before the first call to Next() + // pgx is not able to tell us before the first call to Next() // whether a statement returns either // - a rows result set with zero rows (e.g. SELECT on an empty table), or // - no rows result set, but a valid value for RowsAffected (e.g. INSERT), or @@ -65,57 +64,30 @@ func (sqlExecCtx *Context) RunQueryAndFormatResults( // when Next() has completed its work and no rows where observed, to decide // what to do. noRowsHook := func() (bool, error) { - res := rows.Result() - if ra, ok := res.(driver.RowsAffected); ok { - nRows, err := ra.RowsAffected() - if err != nil { - return false, err - } + tag, err := rows.Tag() + if err != nil { + return false, err + } + nRows := tag.RowsAffected() + tagString := tag.String() - // This may be either something like INSERT with a valid - // RowsAffected value, or a statement like SET. The pq driver - // uses both driver.RowsAffected for both. So we need to be a - // little more manual. - tag := rows.Tag() - if tag == "SELECT" && nRows == 0 { - // As explained above, the pq driver unhelpfully does not - // distinguish between a statement returning zero rows and a - // statement returning an affected row count of zero. - // noRowsHook is called non-discriminatingly for both - // situations. - // - // TODO(knz): meanwhile, there are rare, non-SELECT - // statements that have tag "SELECT" but are legitimately of - // type RowsAffected. CREATE TABLE AS is one. pq's inability - // to distinguish those two cases means that any non-SELECT - // statement that legitimately returns 0 rows affected, and - // for which the user would expect to see "SELECT 0", will - // be incorrectly displayed as an empty row result set - // instead. This needs to be addressed by ensuring pq can - // distinguish the two cases, or switching to an entirely - // different driver altogether. - // - return false, nil - } else if _, ok := tagsWithRowsAffected[tag]; ok { - // INSERT, DELETE, etc.: print the row count. - nRows, err := ra.RowsAffected() - if err != nil { - return false, err - } - fmt.Fprintf(w, "%s %d\n", tag, nRows) - } else { - // SET, etc.: just print the tag, or OK if there's no tag. - if tag == "" { - tag = "OK" - } - fmt.Fprintln(w, tag) - } - return true, nil + // This may be either something like INSERT with a valid + // RowsAffected value, or a statement like SET. + if strings.HasPrefix(tagString, "SELECT") && nRows == 0 { + // As explained above, the pgx driver unhelpfully does not + // distinguish between a statement returning zero rows and a + // statement returning an affected row count of zero. + // noRowsHook is called non-discriminatingly for both + // situations. + return false, nil + } + // SET, etc.: just print the tag, or OK if there's no tag. + // INSERT, DELETE, etc. all contain the row count in the tag itself. 
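For context on the tag handling above: with pgconn the row count travels inside the command tag text itself, so there is no separate RowsAffected result to consult. A rough sketch of the assumed behavior, illustrative only, using jackc/pgconn v1 where CommandTag is a byte slice:

package main

import (
	"fmt"

	"github.com/jackc/pgconn"
)

func main() {
	// Tags as a server would report them for a few statement kinds.
	insert := pgconn.CommandTag("INSERT 0 2")
	sel := pgconn.CommandTag("SELECT 0")
	set := pgconn.CommandTag("SET")

	// RowsAffected parses the trailing digits of the tag; String returns the
	// full tag text, which is what the CLI now prints.
	fmt.Println(insert.String(), insert.RowsAffected()) // INSERT 0 2 2
	fmt.Println(sel.String(), sel.RowsAffected())       // SELECT 0 0
	fmt.Println(set.String(), set.RowsAffected())       // SET 0
}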
+ if tagString == "" { + tagString = "OK" } - // Other cases: this is a statement with a rows result set, but - // zero rows (e.g. SELECT on empty table). Let the reporter - // handle it. - return false, nil + fmt.Fprintln(w, tagString) + return true, nil } cols := getColumnStrings(rows, true) @@ -136,11 +108,12 @@ func (sqlExecCtx *Context) RunQueryAndFormatResults( return err } - sqlExecCtx.maybeShowTimes(ctx, conn, w, ew, isMultiStatementQuery, startTime, queryCompleteTime) - if more, err := rows.NextResultSet(); err != nil { return err } else if !more { + // We must call maybeShowTimes after rows has been closed, which is after + // NextResultSet returns false. + sqlExecCtx.maybeShowTimes(ctx, conn, w, ew, isMultiStatementQuery, startTime, queryCompleteTime) return nil } } @@ -274,20 +247,6 @@ func (sqlExecCtx *Context) maybeShowTimes( fmt.Fprintln(w, stats.String()) } -// All tags where the RowsAffected value should be reported to -// the user. -var tagsWithRowsAffected = map[string]struct{}{ - "INSERT": {}, - "UPDATE": {}, - "DELETE": {}, - "MOVE": {}, - "DROP USER": {}, - "COPY": {}, - // This one is used with e.g. CREATE TABLE AS (other SELECT - // statements have type Rows, not RowsAffected). - "SELECT": {}, -} - // sqlRowsToStrings turns 'rows' into a list of rows, each of which // is a list of column values. // 'rows' should be closed by the caller. diff --git a/pkg/cli/clisqlexec/run_query_test.go b/pkg/cli/clisqlexec/run_query_test.go index 66bfc04a2fc0..35550ed54bb6 100644 --- a/pkg/cli/clisqlexec/run_query_test.go +++ b/pkg/cli/clisqlexec/run_query_test.go @@ -98,10 +98,10 @@ SET } expectedRows := [][]string{ - {`parentID`, `INT8`, `false`, `NULL`, ``, `{primary}`, `false`}, - {`parentSchemaID`, `INT8`, `false`, `NULL`, ``, `{primary}`, `false`}, - {`name`, `STRING`, `false`, `NULL`, ``, `{primary}`, `false`}, - {`id`, `INT8`, `true`, `NULL`, ``, `{primary}`, `false`}, + {`parentID`, `INT8`, `f`, `NULL`, ``, `{primary}`, `f`}, + {`parentSchemaID`, `INT8`, `f`, `NULL`, ``, `{primary}`, `f`}, + {`name`, `STRING`, `f`, `NULL`, ``, `{primary}`, `f`}, + {`id`, `INT8`, `t`, `NULL`, ``, `{primary}`, `f`}, } if !reflect.DeepEqual(expectedRows, rows) { t.Fatalf("expected:\n%v\ngot:\n%v", expectedRows, rows) @@ -115,10 +115,10 @@ SET expected = ` column_name | data_type | is_nullable | column_default | generation_expression | indices | is_hidden -----------------+-----------+-------------+----------------+-----------------------+-----------+------------ - parentID | INT8 | false | NULL | | {primary} | false - parentSchemaID | INT8 | false | NULL | | {primary} | false - name | STRING | false | NULL | | {primary} | false - id | INT8 | true | NULL | | {primary} | false + parentID | INT8 | f | NULL | | {primary} | f + parentSchemaID | INT8 | f | NULL | | {primary} | f + name | STRING | f | NULL | | {primary} | f + id | INT8 | t | NULL | | {primary} | f (4 rows) ` diff --git a/pkg/cli/clisqlshell/sql_test.go b/pkg/cli/clisqlshell/sql_test.go index 351ce7a8965f..c1ddf17a4569 100644 --- a/pkg/cli/clisqlshell/sql_test.go +++ b/pkg/cli/clisqlshell/sql_test.go @@ -67,7 +67,9 @@ func Example_sql() { // application_name // $ cockroach sql // sql -e create database t; create table t.f (x int, y int); insert into t.f values (42, 69) - // INSERT 1 + // CREATE DATABASE + // CREATE TABLE + // INSERT 0 1 // sql -e select 3 as "3" -e select * from t.f // 3 // 3 @@ -110,6 +112,8 @@ func Example_sql() { // sql -d nonexistent -e select count(*) from "".information_schema.tables limit 0 // count // sql -d 
nonexistent -e create database nonexistent; create table foo(x int); select * from foo + // CREATE DATABASE + // CREATE TABLE // x // sql -e copy t.f from stdin // sql -e select 1/(@1-2) from generate_series(1,3) @@ -122,6 +126,7 @@ func Example_sql() { // regression_65066 // 20:01:02+03:04:05 // sql -e CREATE USER my_user WITH CREATEDB; GRANT admin TO my_user; + // CREATE ROLE // GRANT // sql -e \du my_user // username options member_of @@ -186,7 +191,8 @@ func Example_sql_watch() { // Output: // sql -e create table d(x int); insert into d values(3) - // INSERT 1 + // CREATE TABLE + // INSERT 0 1 // sql --watch .1s -e update d set x=x-1 returning 1/x as dec // dec // 0.50000000000000000000 @@ -206,6 +212,7 @@ func Example_misc_table() { // Output: // sql -e create database t; create table t.t (s string, d string); + // CREATE DATABASE // CREATE TABLE // sql --format=table -e select ' hai' as x // x @@ -243,7 +250,9 @@ func Example_in_memory() { // Output: // sql -e create database t; create table t.f (x int, y int); insert into t.f values (42, 69) - // INSERT 1 + // CREATE DATABASE + // CREATE TABLE + // INSERT 0 1 // node ls // id // 1 @@ -264,15 +273,16 @@ func Example_pretty_print_numerical_strings() { // Output: // sql -e create database t; create table t.t (s string, d string); + // CREATE DATABASE // CREATE TABLE // sql -e insert into t.t values (e'0', 'positive numerical string') - // INSERT 1 + // INSERT 0 1 // sql -e insert into t.t values (e'-1', 'negative numerical string') - // INSERT 1 + // INSERT 0 1 // sql -e insert into t.t values (e'1.0', 'decimal numerical string') - // INSERT 1 + // INSERT 0 1 // sql -e insert into t.t values (e'aaaaa', 'non-numerical string') - // INSERT 1 + // INSERT 0 1 // sql --format=table -e select * from t.t // s | d // --------+---------------------------- @@ -300,7 +310,7 @@ func Example_read_from_file() { // SET // CREATE TABLE // > INSERT INTO test(s) VALUES ('hello'), ('world'); - // INSERT 2 + // INSERT 0 2 // > SELECT * FROM test; // s // hello diff --git a/pkg/cli/context.go b/pkg/cli/context.go index 2a2c6aa0678e..b33929d84c85 100644 --- a/pkg/cli/context.go +++ b/pkg/cli/context.go @@ -212,7 +212,7 @@ func setCliContextDefaults() { cliCtx.IsInteractive = false cliCtx.EmbeddedMode = false cliCtx.cmdTimeout = 0 // no timeout - cliCtx.clientOpts.ServerHost = "" + cliCtx.clientOpts.ServerHost = getDefaultHost() cliCtx.clientOpts.ServerPort = base.DefaultPort cliCtx.certPrincipalMap = nil cliCtx.clientOpts.ExplicitURL = nil diff --git a/pkg/cli/debug_job_trace.go b/pkg/cli/debug_job_trace.go index 25e90dffdaa2..fe6db60a65e0 100644 --- a/pkg/cli/debug_job_trace.go +++ b/pkg/cli/debug_job_trace.go @@ -94,7 +94,7 @@ func constructJobTraceZipBundle(ctx context.Context, sqlConn clisqlclient.Conn, return err } - zipper := tracezipper.MakeSQLConnInflightTraceZipper(sqlConn.GetDriverConn().(driver.QueryerContext)) + zipper := tracezipper.MakeSQLConnInflightTraceZipper(sqlConn.GetDriverConn()) zipBytes, err := zipper.Zip(ctx, traceID) if err != nil { return err diff --git a/pkg/cli/demo.go b/pkg/cli/demo.go index 0581d022325f..b9f6ec7eb7d7 100644 --- a/pkg/cli/demo.go +++ b/pkg/cli/demo.go @@ -349,5 +349,5 @@ func runDemo(cmd *cobra.Command, gen workload.Generator) (resErr error) { defer func() { resErr = errors.CombineErrors(resErr, conn.Close()) }() sqlCtx.ShellCtx.ParseURL = clienturl.MakeURLParserFn(cmd, cliCtx.clientOpts) - return sqlCtx.Run(conn) + return sqlCtx.Run(ctx, conn) } diff --git a/pkg/cli/doctor.go b/pkg/cli/doctor.go index 
bd29d0683a21..5893818b1bd5 100644 --- a/pkg/cli/doctor.go +++ b/pkg/cli/doctor.go @@ -40,7 +40,8 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/protoutil" "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/errors" - "github.com/lib/pq" + "github.com/jackc/pgconn" + "github.com/jackc/pgtype" "github.com/spf13/cobra" ) @@ -198,6 +199,10 @@ func fromCluster( retErr error, ) { ctx := context.Background() + if err := sqlConn.EnsureConn(ctx); err != nil { + return nil, nil, nil, err + } + if timeout != 0 { if err := sqlConn.Exec(ctx, `SET statement_timeout = $1`, timeout.String()); err != nil { @@ -211,8 +216,8 @@ FROM system.descriptor ORDER BY id` _, err := sqlConn.QueryRow(ctx, checkColumnExistsStmt) // On versions before 20.2, the system.descriptor won't have the builtin // crdb_internal_mvcc_timestamp. If we can't find it, use NULL instead. - if pqErr := (*pq.Error)(nil); errors.As(err, &pqErr) { - if pgcode.MakeCode(string(pqErr.Code)) == pgcode.UndefinedColumn { + if pgErr := (*pgconn.PgError)(nil); errors.As(err, &pgErr) { + if pgcode.MakeCode(pgErr.Code) == pgcode.UndefinedColumn { stmt = ` SELECT id, descriptor, NULL AS mod_time_logical FROM system.descriptor ORDER BY id` @@ -236,8 +241,12 @@ FROM system.descriptor ORDER BY id` } if vals[2] == nil { row.ModTime = hlc.Timestamp{WallTime: timeutil.Now().UnixNano()} - } else if mt, ok := vals[2].([]byte); ok { - decimal, _, err := apd.NewFromString(string(mt)) + } else if mt, ok := vals[2].(pgtype.Numeric); ok { + buf, err := mt.EncodeText(nil, nil) + if err != nil { + return err + } + decimal, _, err := apd.NewFromString(string(buf)) if err != nil { return err } @@ -261,8 +270,8 @@ FROM system.descriptor ORDER BY id` _, err = sqlConn.QueryRow(ctx, checkColumnExistsStmt) // On versions before 20.1, table system.namespace does not have this column. // In that case the ParentSchemaID for tables is 29 and for databases is 0. - if pqErr := (*pq.Error)(nil); errors.As(err, &pqErr) { - if pgcode.MakeCode(string(pqErr.Code)) == pgcode.UndefinedColumn { + if pgErr := (*pgconn.PgError)(nil); errors.As(err, &pgErr) { + if pgcode.MakeCode(pgErr.Code) == pgcode.UndefinedColumn { stmt = ` SELECT "parentID", CASE WHEN "parentID" = 0 THEN 0 ELSE 29 END AS "parentSchemaID", name, id FROM system.namespace` diff --git a/pkg/cli/env.go b/pkg/cli/env.go new file mode 100644 index 000000000000..c7f0f1b19c39 --- /dev/null +++ b/pkg/cli/env.go @@ -0,0 +1,24 @@ +// Copyright 2022 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package cli + +import "os" + +// getDefaultHost gets the default value for the host to connect to. pgx already +// has logic that would inspect PGHOST, but the problem is that if PGHOST not +// defined, pgx would default to using a unix socket. That is not desired, so +// here we make the CLI fallback to use "localhost" if PGHOST is not defined. 
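The unix-socket default described in this comment can be observed directly from pgconn's config parsing. A small illustrative sketch, assuming jackc/pgconn v1; the printed host depends on the platform and on which socket directories exist, so treat the values as examples only:

package main

import (
	"fmt"

	"github.com/jackc/pgconn"
)

func main() {
	// ParseConfig("") falls back to libpq-style environment variables and
	// built-in defaults. With PGHOST unset, Host typically ends up being a
	// unix socket directory such as /var/run/postgresql rather than localhost.
	cfg, err := pgconn.ParseConfig("")
	if err != nil {
		panic(err)
	}
	fmt.Println(cfg.Host, cfg.Port)
}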
+func getDefaultHost() string { + if h := os.Getenv("PGHOST"); h == "" { + return "localhost" + } + return "" +} diff --git a/pkg/cli/flags_test.go b/pkg/cli/flags_test.go index 4e95fdb48b85..ef9d1aca6e1c 100644 --- a/pkg/cli/flags_test.go +++ b/pkg/cli/flags_test.go @@ -972,7 +972,7 @@ func TestClientConnSettings(t *testing.T) { args []string expectedAddr string }{ - {[]string{"quit"}, ":" + base.DefaultPort}, + {[]string{"quit"}, "localhost:" + base.DefaultPort}, {[]string{"quit", "--host", "127.0.0.1"}, "127.0.0.1:" + base.DefaultPort}, {[]string{"quit", "--host", "192.168.0.111"}, "192.168.0.111:" + base.DefaultPort}, {[]string{"quit", "--host", ":12345"}, ":12345"}, @@ -986,7 +986,7 @@ func TestClientConnSettings(t *testing.T) { "[2622:6221:e663:4922:fc2b:788b:fadd:7b48]:" + base.DefaultPort}, // Deprecated syntax. - {[]string{"quit", "--port", "12345"}, ":12345"}, + {[]string{"quit", "--port", "12345"}, "localhost:12345"}, {[]string{"quit", "--host", "127.0.0.1", "--port", "12345"}, "127.0.0.1:12345"}, } diff --git a/pkg/cli/import.go b/pkg/cli/import.go index a9f0588f0375..a0fbfbf079c6 100644 --- a/pkg/cli/import.go +++ b/pkg/cli/import.go @@ -111,7 +111,7 @@ func runImport( importFormat, source, tableName string, mode importMode, ) error { - if err := conn.EnsureConn(); err != nil { + if err := conn.EnsureConn(ctx); err != nil { return err } @@ -216,7 +216,7 @@ func runImport( return nil } - _, err = ex.ExecContext(ctx, importQuery, nil) + err = ex.Exec(ctx, importQuery) if err != nil { return err } diff --git a/pkg/cli/interactive_tests/test_client_side_checking.tcl b/pkg/cli/interactive_tests/test_client_side_checking.tcl index 3f6556de09a4..8c61e80f75f5 100644 --- a/pkg/cli/interactive_tests/test_client_side_checking.tcl +++ b/pkg/cli/interactive_tests/test_client_side_checking.tcl @@ -28,7 +28,7 @@ end_test start_test "Check that the syntax checker does not get confused by empty inputs." # (issue #22441.) 
send ";\r" -eexpect "0 rows" +eexpect "OK" eexpect root@ end_test diff --git a/pkg/cli/interactive_tests/test_connect_cmd.tcl b/pkg/cli/interactive_tests/test_connect_cmd.tcl index f3b86742fc44..4268af2ae4dc 100644 --- a/pkg/cli/interactive_tests/test_connect_cmd.tcl +++ b/pkg/cli/interactive_tests/test_connect_cmd.tcl @@ -132,7 +132,7 @@ start_test "Check that the client-side connect cmd can change users with certs u # first test that it can recover from an invalid database send "\\c postgres://root@localhost:26257/invaliddb?sslmode=require&sslcert=$certs_dir%2Fclient.root.crt&sslkey=$certs_dir%2Fclient.root.key&sslrootcert=$certs_dir%2Fca.crt\r" eexpect "using new connection URL" -eexpect "error retrieving the database name: pq: database \"invaliddb\" does not exist" +eexpect "error retrieving the database name: ERROR: database \"invaliddb\" does not exist" eexpect root@ eexpect "?>" diff --git a/pkg/cli/interactive_tests/test_demo_node_cmds.tcl b/pkg/cli/interactive_tests/test_demo_node_cmds.tcl index f74640c24bb1..2f5f0b9d4794 100644 --- a/pkg/cli/interactive_tests/test_demo_node_cmds.tcl +++ b/pkg/cli/interactive_tests/test_demo_node_cmds.tcl @@ -42,11 +42,11 @@ eexpect "node 3 has been shutdown" eexpect "movr>" send "select node_id, draining, decommissioning, membership from crdb_internal.gossip_liveness ORDER BY node_id;\r" -eexpect "1 | false | false | active" -eexpect "2 | false | false | active" -eexpect "3 | true | false | active" -eexpect "4 | false | false | active" -eexpect "5 | false | false | active" +eexpect "1 | f | f | active" +eexpect "2 | f | f | active" +eexpect "3 | t | f | active" +eexpect "4 | f | f | active" +eexpect "5 | f | f | active" eexpect "movr>" # Cannot shut it down again. diff --git a/pkg/cli/interactive_tests/test_reconnect.tcl b/pkg/cli/interactive_tests/test_reconnect.tcl index d2ed73c1e795..2c63efee88cb 100644 --- a/pkg/cli/interactive_tests/test_reconnect.tcl +++ b/pkg/cli/interactive_tests/test_reconnect.tcl @@ -23,7 +23,7 @@ start_test "Check that the client properly detects the server went down" force_stop_server $argv send "select 1;\r" -eexpect "bad connection" +eexpect "connection closed unexpectedly" eexpect root@ eexpect " ?>" end_test @@ -56,7 +56,7 @@ stop_server $argv start_server $argv send "select 1;\r" -eexpect "bad connection" +eexpect "connection closed unexpectedly" eexpect root@ send "select 1;\r" diff --git a/pkg/cli/interactive_tests/test_sql_version_reporting.tcl b/pkg/cli/interactive_tests/test_sql_version_reporting.tcl index f82164fde477..199728d28430 100644 --- a/pkg/cli/interactive_tests/test_sql_version_reporting.tcl +++ b/pkg/cli/interactive_tests/test_sql_version_reporting.tcl @@ -31,7 +31,7 @@ start_test "Check that a reconnect without version change is quiet." 
force_stop_server $argv start_server $argv send "select 1;\r" -eexpect "driver: bad connection" +eexpect "connection closed unexpectedly" # Check that the prompt immediately succeeds the error message eexpect "connection lost" eexpect "opening new connection: all session settings will be lost" diff --git a/pkg/cli/interactive_tests/test_url_db_override.tcl b/pkg/cli/interactive_tests/test_url_db_override.tcl index 75e053d51daf..407cac6e3fd0 100644 --- a/pkg/cli/interactive_tests/test_url_db_override.tcl +++ b/pkg/cli/interactive_tests/test_url_db_override.tcl @@ -38,7 +38,7 @@ start_test "Check that the insecure flag overrides the sslmode if URL is already set ::env(COCKROACH_INSECURE) "false" spawn $argv sql --url "postgresql://test@localhost:26257?sslmode=verify-full" --certs-dir=$certs_dir -e "select 1" -eexpect "SSL is not enabled on the server" +eexpect "server refused TLS connection" eexpect eof spawn $argv sql --url "postgresql://test@localhost:26257?sslmode=verify-full" --certs-dir=$certs_dir --insecure -e "select 1" diff --git a/pkg/cli/node.go b/pkg/cli/node.go index 3fc671a2eaa1..d12abe2c8b0d 100644 --- a/pkg/cli/node.go +++ b/pkg/cli/node.go @@ -228,6 +228,9 @@ FROM crdb_internal.gossip_liveness LEFT JOIN crdb_internal.gossip_nodes USING (n } ctx := context.Background() + if err = conn.EnsureConn(ctx); err != nil { + return nil, nil, err + } // TODO(knz): This can use a context deadline instead, now that // query cancellation is supported. diff --git a/pkg/cli/nodelocal.go b/pkg/cli/nodelocal.go index 141aaf83dfe4..fa4f9ed14e42 100644 --- a/pkg/cli/nodelocal.go +++ b/pkg/cli/nodelocal.go @@ -11,8 +11,8 @@ package cli import ( + "bytes" "context" - "database/sql/driver" "fmt" "io" "net/url" @@ -27,9 +27,7 @@ import ( "github.com/spf13/cobra" ) -const ( - chunkSize = 4 * 1024 -) +const chunkSize = 4 * 1024 var nodeLocalUploadCmd = &cobra.Command{ Use: "upload ", @@ -74,17 +72,56 @@ func openSourceFile(source string) (io.ReadCloser, error) { return f, nil } +// appendEscapedText escapes the input text for processing by the pgwire COPY +// protocol. The result is appended to the []byte given by buf. +// This implementation is copied from lib/pq. +// https://github.com/lib/pq/blob/8c6de565f76fb5cd40a5c1b8ce583fbc3ba1bd0e/encode.go#L138 +func appendEscapedText(buf []byte, text string) []byte { + escapeNeeded := false + startPos := 0 + var c byte + + // check if we need to escape + for i := 0; i < len(text); i++ { + c = text[i] + if c == '\\' || c == '\n' || c == '\r' || c == '\t' { + escapeNeeded = true + startPos = i + break + } + } + if !escapeNeeded { + return append(buf, text...) + } + + // copy till first char to escape, iterate the rest + result := append(buf, text[:startPos]...) + for i := startPos; i < len(text); i++ { + c = text[i] + switch c { + case '\\': + result = append(result, '\\', '\\') + case '\n': + result = append(result, '\\', 'n') + case '\r': + result = append(result, '\\', 'r') + case '\t': + result = append(result, '\\', 't') + default: + result = append(result, c) + } + } + return result +} + func uploadFile( ctx context.Context, conn clisqlclient.Conn, reader io.Reader, destination string, ) error { - if err := conn.EnsureConn(); err != nil { + if err := conn.EnsureConn(ctx); err != nil { return err } ex := conn.GetDriverConn() - if _, err := ex.ExecContext(ctx, `BEGIN`, nil); err != nil { - return err - } // Construct the nodelocal URI as the destination for the CopyIn stmt. 
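Each chunk read from the source file is passed through appendEscapedText above before the COPY below is issued, so only backslash, newline, carriage return, and tab are rewritten. A hypothetical example test, assuming it sits in a test file of package cli next to the helper with fmt imported, showing the mapping:

func Example_appendEscapedText() {
	// Plain text passes through untouched; the four special bytes become
	// two-character escape sequences understood by the text COPY format.
	fmt.Printf("%s\n", appendEscapedText(nil, "plain text"))
	fmt.Printf("%s\n", appendEscapedText(nil, "a\tb\nc\\d"))
	// Output:
	// plain text
	// a\tb\nc\\d
}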
nodelocalURL := url.URL{ @@ -92,42 +129,22 @@ func uploadFile( Host: "self", Path: destination, } - stmt, err := conn.GetDriverConn().Prepare(sql.CopyInFileStmt(nodelocalURL.String(), sql.CrdbInternalName, - sql.NodelocalFileUploadTable)) - if err != nil { - return err - } + stmt := sql.CopyInFileStmt(nodelocalURL.String(), sql.CrdbInternalName, sql.NodelocalFileUploadTable) - defer func() { - if stmt != nil { - _ = stmt.Close() - _, _ = ex.ExecContext(ctx, `ROLLBACK`, nil) - } - }() - - send := make([]byte, chunkSize) + send := make([]byte, 0) + tmp := make([]byte, chunkSize) for { - n, err := reader.Read(send) + n, err := reader.Read(tmp) if n > 0 { - // TODO(adityamaru): Switch to StmtExecContext once the copyin driver - // supports it. - //lint:ignore SA1019 DriverConn doesn't support go 1.8 API - _, err = stmt.Exec([]driver.Value{string(send[:n])}) - if err != nil { - return err - } + send = appendEscapedText(send, string(tmp[:n])) } else if err == io.EOF { break } else if err != nil { return err } } - if err := stmt.Close(); err != nil { - return err - } - stmt = nil - if _, err := ex.ExecContext(ctx, `COMMIT`, nil); err != nil { + if err := ex.CopyFrom(ctx, bytes.NewReader(send), stmt); err != nil { return err } diff --git a/pkg/cli/sql_shell_cmd.go b/pkg/cli/sql_shell_cmd.go index edcc20fdf915..5b8ae110a2df 100644 --- a/pkg/cli/sql_shell_cmd.go +++ b/pkg/cli/sql_shell_cmd.go @@ -11,6 +11,7 @@ package cli import ( + "context" "fmt" "os" @@ -59,5 +60,5 @@ func runTerm(cmd *cobra.Command, args []string) (resErr error) { defer func() { resErr = errors.CombineErrors(resErr, conn.Close()) }() sqlCtx.ShellCtx.ParseURL = clienturl.MakeURLParserFn(cmd, cliCtx.clientOpts) - return sqlCtx.Run(conn) + return sqlCtx.Run(context.Background(), conn) } diff --git a/pkg/cli/statement_bundle.go b/pkg/cli/statement_bundle.go index 9bc5952d2c12..ac9e9a542a70 100644 --- a/pkg/cli/statement_bundle.go +++ b/pkg/cli/statement_bundle.go @@ -215,7 +215,7 @@ func runBundleRecreate(cmd *cobra.Command, args []string) error { } sqlCtx.ShellCtx.DemoCluster = c - return sqlCtx.Run(conn) + return sqlCtx.Run(ctx, conn) } // placeholderRe matches the placeholder format at the bottom of statement.txt diff --git a/pkg/cli/testdata/zip/partial1 b/pkg/cli/testdata/zip/partial1 index 1b5a40c1a3f2..6a388c4a4db9 100644 --- a/pkg/cli/testdata/zip/partial1 +++ b/pkg/cli/testdata/zip/partial1 @@ -155,64 +155,64 @@ debug zip --concurrency=1 --cpu-profile-duration=0s /dev/null [node 2] node status... converting to JSON... writing binary output: debug/nodes/2/status.json... done [node 2] using SQL connection URL: postgresql://... [node 2] retrieving SQL data for crdb_internal.feature_usage... writing output: debug/nodes/2/crdb_internal.feature_usage.txt... -[node 2] retrieving SQL data for crdb_internal.feature_usage: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.feature_usage: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.feature_usage: creating error output: debug/nodes/2/crdb_internal.feature_usage.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.gossip_alerts... writing output: debug/nodes/2/crdb_internal.gossip_alerts.txt... -[node 2] retrieving SQL data for crdb_internal.gossip_alerts: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.gossip_alerts: last request failed: failed to connect to ... 
[node 2] retrieving SQL data for crdb_internal.gossip_alerts: creating error output: debug/nodes/2/crdb_internal.gossip_alerts.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.gossip_liveness... writing output: debug/nodes/2/crdb_internal.gossip_liveness.txt... -[node 2] retrieving SQL data for crdb_internal.gossip_liveness: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.gossip_liveness: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.gossip_liveness: creating error output: debug/nodes/2/crdb_internal.gossip_liveness.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.gossip_network... writing output: debug/nodes/2/crdb_internal.gossip_network.txt... -[node 2] retrieving SQL data for crdb_internal.gossip_network: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.gossip_network: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.gossip_network: creating error output: debug/nodes/2/crdb_internal.gossip_network.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.gossip_nodes... writing output: debug/nodes/2/crdb_internal.gossip_nodes.txt... -[node 2] retrieving SQL data for crdb_internal.gossip_nodes: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.gossip_nodes: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.gossip_nodes: creating error output: debug/nodes/2/crdb_internal.gossip_nodes.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.leases... writing output: debug/nodes/2/crdb_internal.leases.txt... -[node 2] retrieving SQL data for crdb_internal.leases: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.leases: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.leases: creating error output: debug/nodes/2/crdb_internal.leases.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.node_build_info... writing output: debug/nodes/2/crdb_internal.node_build_info.txt... -[node 2] retrieving SQL data for crdb_internal.node_build_info: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.node_build_info: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.node_build_info: creating error output: debug/nodes/2/crdb_internal.node_build_info.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.node_contention_events... writing output: debug/nodes/2/crdb_internal.node_contention_events.txt... -[node 2] retrieving SQL data for crdb_internal.node_contention_events: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.node_contention_events: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.node_contention_events: creating error output: debug/nodes/2/crdb_internal.node_contention_events.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.node_distsql_flows... writing output: debug/nodes/2/crdb_internal.node_distsql_flows.txt... -[node 2] retrieving SQL data for crdb_internal.node_distsql_flows: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.node_distsql_flows: last request failed: failed to connect to ... 
[node 2] retrieving SQL data for crdb_internal.node_distsql_flows: creating error output: debug/nodes/2/crdb_internal.node_distsql_flows.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.node_execution_outliers... writing output: debug/nodes/2/crdb_internal.node_execution_outliers.txt... -[node 2] retrieving SQL data for crdb_internal.node_execution_outliers: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.node_execution_outliers: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.node_execution_outliers: creating error output: debug/nodes/2/crdb_internal.node_execution_outliers.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.node_inflight_trace_spans... writing output: debug/nodes/2/crdb_internal.node_inflight_trace_spans.txt... -[node 2] retrieving SQL data for crdb_internal.node_inflight_trace_spans: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.node_inflight_trace_spans: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.node_inflight_trace_spans: creating error output: debug/nodes/2/crdb_internal.node_inflight_trace_spans.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.node_metrics... writing output: debug/nodes/2/crdb_internal.node_metrics.txt... -[node 2] retrieving SQL data for crdb_internal.node_metrics: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.node_metrics: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.node_metrics: creating error output: debug/nodes/2/crdb_internal.node_metrics.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.node_queries... writing output: debug/nodes/2/crdb_internal.node_queries.txt... -[node 2] retrieving SQL data for crdb_internal.node_queries: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.node_queries: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.node_queries: creating error output: debug/nodes/2/crdb_internal.node_queries.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.node_runtime_info... writing output: debug/nodes/2/crdb_internal.node_runtime_info.txt... -[node 2] retrieving SQL data for crdb_internal.node_runtime_info: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.node_runtime_info: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.node_runtime_info: creating error output: debug/nodes/2/crdb_internal.node_runtime_info.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.node_sessions... writing output: debug/nodes/2/crdb_internal.node_sessions.txt... -[node 2] retrieving SQL data for crdb_internal.node_sessions: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.node_sessions: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.node_sessions: creating error output: debug/nodes/2/crdb_internal.node_sessions.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.node_statement_statistics... writing output: debug/nodes/2/crdb_internal.node_statement_statistics.txt... -[node 2] retrieving SQL data for crdb_internal.node_statement_statistics: last request failed: dial tcp ... 
+[node 2] retrieving SQL data for crdb_internal.node_statement_statistics: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.node_statement_statistics: creating error output: debug/nodes/2/crdb_internal.node_statement_statistics.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.node_transaction_statistics... writing output: debug/nodes/2/crdb_internal.node_transaction_statistics.txt... -[node 2] retrieving SQL data for crdb_internal.node_transaction_statistics: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.node_transaction_statistics: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.node_transaction_statistics: creating error output: debug/nodes/2/crdb_internal.node_transaction_statistics.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.node_transactions... writing output: debug/nodes/2/crdb_internal.node_transactions.txt... -[node 2] retrieving SQL data for crdb_internal.node_transactions: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.node_transactions: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.node_transactions: creating error output: debug/nodes/2/crdb_internal.node_transactions.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.node_txn_stats... writing output: debug/nodes/2/crdb_internal.node_txn_stats.txt... -[node 2] retrieving SQL data for crdb_internal.node_txn_stats: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.node_txn_stats: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.node_txn_stats: creating error output: debug/nodes/2/crdb_internal.node_txn_stats.txt.err.txt... done [node 2] retrieving SQL data for crdb_internal.active_range_feeds... writing output: debug/nodes/2/crdb_internal.active_range_feeds.txt... -[node 2] retrieving SQL data for crdb_internal.active_range_feeds: last request failed: dial tcp ... +[node 2] retrieving SQL data for crdb_internal.active_range_feeds: last request failed: failed to connect to ... [node 2] retrieving SQL data for crdb_internal.active_range_feeds: creating error output: debug/nodes/2/crdb_internal.active_range_feeds.txt.err.txt... done [node 2] requesting data for debug/nodes/2/details... received response... [node 2] requesting data for debug/nodes/2/details: last request failed: rpc error: ... diff --git a/pkg/cli/testdata/zip/testzip_tenant b/pkg/cli/testdata/zip/testzip_tenant index d267e30060ce..c81064f1e585 100644 --- a/pkg/cli/testdata/zip/testzip_tenant +++ b/pkg/cli/testdata/zip/testzip_tenant @@ -33,13 +33,13 @@ debug zip --concurrency=1 --cpu-profile-duration=1s /dev/null [cluster] retrieving SQL data for "".crdb_internal.create_statements... writing output: debug/crdb_internal.create_statements.txt... done [cluster] retrieving SQL data for "".crdb_internal.create_type_statements... writing output: debug/crdb_internal.create_type_statements.txt... done [cluster] retrieving SQL data for crdb_internal.kv_node_liveness... writing output: debug/crdb_internal.kv_node_liveness.txt... 
-[cluster] retrieving SQL data for crdb_internal.kv_node_liveness: last request failed: pq: unimplemented: operation is unsupported in multi-tenancy mode +[cluster] retrieving SQL data for crdb_internal.kv_node_liveness: last request failed: ERROR: unimplemented: operation is unsupported in multi-tenancy mode (SQLSTATE 0A000) [cluster] retrieving SQL data for crdb_internal.kv_node_liveness: creating error output: debug/crdb_internal.kv_node_liveness.txt.err.txt... done [cluster] retrieving SQL data for crdb_internal.kv_node_status... writing output: debug/crdb_internal.kv_node_status.txt... -[cluster] retrieving SQL data for crdb_internal.kv_node_status: last request failed: pq: unimplemented: operation is unsupported in multi-tenancy mode +[cluster] retrieving SQL data for crdb_internal.kv_node_status: last request failed: ERROR: unimplemented: operation is unsupported in multi-tenancy mode (SQLSTATE 0A000) [cluster] retrieving SQL data for crdb_internal.kv_node_status: creating error output: debug/crdb_internal.kv_node_status.txt.err.txt... done [cluster] retrieving SQL data for crdb_internal.kv_store_status... writing output: debug/crdb_internal.kv_store_status.txt... -[cluster] retrieving SQL data for crdb_internal.kv_store_status: last request failed: pq: unimplemented: operation is unsupported in multi-tenancy mode +[cluster] retrieving SQL data for crdb_internal.kv_store_status: last request failed: ERROR: unimplemented: operation is unsupported in multi-tenancy mode (SQLSTATE 0A000) [cluster] retrieving SQL data for crdb_internal.kv_store_status: creating error output: debug/crdb_internal.kv_store_status.txt.err.txt... done [cluster] retrieving SQL data for crdb_internal.regions... writing output: debug/crdb_internal.regions.txt... done [cluster] retrieving SQL data for crdb_internal.schema_changes... writing output: debug/crdb_internal.schema_changes.txt... done @@ -90,16 +90,16 @@ debug zip --concurrency=1 --cpu-profile-duration=1s /dev/null [node 1] using SQL connection URL: postgresql://... [node 1] retrieving SQL data for crdb_internal.feature_usage... writing output: debug/nodes/1/crdb_internal.feature_usage.txt... done [node 1] retrieving SQL data for crdb_internal.gossip_alerts... writing output: debug/nodes/1/crdb_internal.gossip_alerts.txt... -[node 1] retrieving SQL data for crdb_internal.gossip_alerts: last request failed: pq: unimplemented: operation is unsupported in multi-tenancy mode +[node 1] retrieving SQL data for crdb_internal.gossip_alerts: last request failed: ERROR: unimplemented: operation is unsupported in multi-tenancy mode (SQLSTATE 0A000) [node 1] retrieving SQL data for crdb_internal.gossip_alerts: creating error output: debug/nodes/1/crdb_internal.gossip_alerts.txt.err.txt... done [node 1] retrieving SQL data for crdb_internal.gossip_liveness... writing output: debug/nodes/1/crdb_internal.gossip_liveness.txt... -[node 1] retrieving SQL data for crdb_internal.gossip_liveness: last request failed: pq: unimplemented: operation is unsupported in multi-tenancy mode +[node 1] retrieving SQL data for crdb_internal.gossip_liveness: last request failed: ERROR: unimplemented: operation is unsupported in multi-tenancy mode (SQLSTATE 0A000) [node 1] retrieving SQL data for crdb_internal.gossip_liveness: creating error output: debug/nodes/1/crdb_internal.gossip_liveness.txt.err.txt... done [node 1] retrieving SQL data for crdb_internal.gossip_network... writing output: debug/nodes/1/crdb_internal.gossip_network.txt... 
-[node 1] retrieving SQL data for crdb_internal.gossip_network: last request failed: pq: unimplemented: operation is unsupported in multi-tenancy mode +[node 1] retrieving SQL data for crdb_internal.gossip_network: last request failed: ERROR: unimplemented: operation is unsupported in multi-tenancy mode (SQLSTATE 0A000) [node 1] retrieving SQL data for crdb_internal.gossip_network: creating error output: debug/nodes/1/crdb_internal.gossip_network.txt.err.txt... done [node 1] retrieving SQL data for crdb_internal.gossip_nodes... writing output: debug/nodes/1/crdb_internal.gossip_nodes.txt... -[node 1] retrieving SQL data for crdb_internal.gossip_nodes: last request failed: pq: unimplemented: operation is unsupported in multi-tenancy mode +[node 1] retrieving SQL data for crdb_internal.gossip_nodes: last request failed: ERROR: unimplemented: operation is unsupported in multi-tenancy mode (SQLSTATE 0A000) [node 1] retrieving SQL data for crdb_internal.gossip_nodes: creating error output: debug/nodes/1/crdb_internal.gossip_nodes.txt.err.txt... done [node 1] retrieving SQL data for crdb_internal.leases... writing output: debug/nodes/1/crdb_internal.leases.txt... done [node 1] retrieving SQL data for crdb_internal.node_build_info... writing output: debug/nodes/1/crdb_internal.node_build_info.txt... done diff --git a/pkg/cli/userfile.go b/pkg/cli/userfile.go index 1aafd9e10c56..f8c81a0da952 100644 --- a/pkg/cli/userfile.go +++ b/pkg/cli/userfile.go @@ -11,8 +11,8 @@ package cli import ( + "bytes" "context" - "database/sql/driver" "fmt" "io" "io/fs" @@ -241,7 +241,7 @@ func runUserFileGet(cmd *cobra.Command, args []string) (resErr error) { pattern := fullPath[len(conf.Path):] displayPath := strings.TrimPrefix(conf.Path, "/") - f, err := userfile.MakeSQLConnFileTableStorage(ctx, conf, conn.GetDriverConn().(cloud.SQLConnI)) + f, err := userfile.MakeSQLConnFileTableStorage(ctx, conf, conn.GetDriverConn()) if err != nil { return err } @@ -399,7 +399,7 @@ func constructUserfileListURI(glob string, user username.SQLUsername) string { func getUserfileConf( ctx context.Context, conn clisqlclient.Conn, glob string, ) (roachpb.ExternalStorage_FileTable, error) { - if err := conn.EnsureConn(); err != nil { + if err := conn.EnsureConn(ctx); err != nil { return roachpb.ExternalStorage_FileTable{}, err } @@ -434,7 +434,7 @@ func listUserFile(ctx context.Context, conn clisqlclient.Conn, glob string) ([]s conf.Path = cloud.GetPrefixBeforeWildcard(fullPath) pattern := fullPath[len(conf.Path):] - f, err := userfile.MakeSQLConnFileTableStorage(ctx, conf, conn.GetDriverConn().(cloud.SQLConnI)) + f, err := userfile.MakeSQLConnFileTableStorage(ctx, conf, conn.GetDriverConn()) if err != nil { return nil, err } @@ -482,7 +482,7 @@ func downloadUserfile( } func deleteUserFile(ctx context.Context, conn clisqlclient.Conn, glob string) ([]string, error) { - if err := conn.EnsureConn(); err != nil { + if err := conn.EnsureConn(ctx); err != nil { return nil, err } @@ -511,8 +511,7 @@ func deleteUserFile(ctx context.Context, conn clisqlclient.Conn, glob string) ([ userFileTableConf.FileTableConfig.Path = cloud.GetPrefixBeforeWildcard(fullPath) pattern := fullPath[len(userFileTableConf.FileTableConfig.Path):] - f, err := userfile.MakeSQLConnFileTableStorage(ctx, userFileTableConf.FileTableConfig, - conn.GetDriverConn().(cloud.SQLConnI)) + f, err := userfile.MakeSQLConnFileTableStorage(ctx, userFileTableConf.FileTableConfig, conn.GetDriverConn()) if err != nil { return nil, err } @@ -543,39 +542,16 @@ func renameUserFile( ctx 
context.Context, conn clisqlclient.Conn, oldFilename, newFilename, qualifiedTableName string, ) error { - if err := conn.EnsureConn(); err != nil { + if err := conn.EnsureConn(ctx); err != nil { return err } - ex := conn.GetDriverConn() - if _, err := ex.ExecContext(ctx, `BEGIN`, nil); err != nil { - return err - } - - stmt, err := conn.GetDriverConn().Prepare(fmt.Sprintf(`UPDATE %s SET filename=$1 WHERE filename=$2`, - qualifiedTableName+fileTableNameSuffix)) - if err != nil { - return err - } - - defer func() { - if stmt != nil { - _ = stmt.Close() - _, _ = ex.ExecContext(ctx, `ROLLBACK`, nil) - } - }() - //lint:ignore SA1019 DriverConn doesn't support go 1.8 API - _, err = stmt.Exec([]driver.Value{newFilename, oldFilename}) - if err != nil { - return err - } - - if err := stmt.Close(); err != nil { - return err - } - stmt = nil - if _, err := ex.ExecContext(ctx, `COMMIT`, nil); err != nil { + if err := ex.Exec( + ctx, + fmt.Sprintf(`UPDATE %s SET filename=$1 WHERE filename=$2`, qualifiedTableName+fileTableNameSuffix), + newFilename, oldFilename, + ); err != nil { return err } @@ -595,14 +571,11 @@ func uploadUserFile( } defer reader.Close() - if err := conn.EnsureConn(); err != nil { + if err := conn.EnsureConn(ctx); err != nil { return "", err } ex := conn.GetDriverConn() - if _, err := ex.ExecContext(ctx, `BEGIN`, nil); err != nil { - return "", err - } connURL, err := url.Parse(conn.GetURL()) if err != nil { @@ -634,42 +607,22 @@ func uploadUserFile( if err != nil { return "", err } - stmt, err := conn.GetDriverConn().Prepare(sql.CopyInFileStmt(unescapedUserfileURL, sql.CrdbInternalName, - sql.UserFileUploadTable)) - if err != nil { - return "", err - } + stmt := sql.CopyInFileStmt(unescapedUserfileURL, sql.CrdbInternalName, sql.UserFileUploadTable) - defer func() { - if stmt != nil { - _ = stmt.Close() - _, _ = ex.ExecContext(ctx, `ROLLBACK`, nil) - } - }() - - send := make([]byte, chunkSize) + send := make([]byte, 0) + tmp := make([]byte, chunkSize) for { - n, err := reader.Read(send) + n, err := reader.Read(tmp) if n > 0 { - // TODO(adityamaru): Switch to StmtExecContext once the copyin driver - // supports it. - //lint:ignore SA1019 DriverConn doesn't support go 1.8 API - _, err = stmt.Exec([]driver.Value{string(send[:n])}) - if err != nil { - return "", err - } + send = appendEscapedText(send, string(tmp[:n])) } else if err == io.EOF { break } else if err != nil { return "", err } } - if err := stmt.Close(); err != nil { - return "", err - } - stmt = nil - if _, err := ex.ExecContext(ctx, `COMMIT`, nil); err != nil { + if err := ex.CopyFrom(ctx, bytes.NewReader(send), stmt); err != nil { return "", err } diff --git a/pkg/cli/zip.go b/pkg/cli/zip.go index 29fbbf26b484..b9a8a336774c 100644 --- a/pkg/cli/zip.go +++ b/pkg/cli/zip.go @@ -29,7 +29,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/util/contextutil" "github.com/cockroachdb/errors" - "github.com/lib/pq" + "github.com/jackc/pgconn" "github.com/marusama/semaphore" "github.com/spf13/cobra" ) @@ -325,9 +325,7 @@ func maybeAddProfileSuffix(name string) string { func (zc *debugZipContext) dumpTableDataForZip( zr *zipReporter, conn clisqlclient.Conn, base, table, query string, ) error { - // TODO(knz): This can use context cancellation now that query - // cancellation is supported. 
- fullQuery := fmt.Sprintf(`SET statement_timeout = '%s'; %s`, zc.timeout, query) + ctx := context.Background() baseName := base + "/" + sanitizeFilename(table) s := zr.start("retrieving SQL data for %s", table) @@ -340,26 +338,33 @@ func (zc *debugZipContext) dumpTableDataForZip( zc.z.Lock() defer zc.z.Unlock() + // TODO(knz): This can use context cancellation now that query + // cancellation is supported in v22.1 and later. + // SET must be run separately from the query so that the command tag output + // doesn't get added to the debug file. + err := conn.Exec(ctx, fmt.Sprintf(`SET statement_timeout = '%s'`, zc.timeout)) + if err != nil { + return err + } + w, err := zc.z.createLocked(name, time.Time{}) if err != nil { return err } // Pump the SQL rows directly into the zip writer, to avoid // in-RAM buffering. - return sqlExecCtx.RunQueryAndFormatResults( - context.Background(), - conn, w, stderr, clisqlclient.MakeQuery(fullQuery)) + return sqlExecCtx.RunQueryAndFormatResults(ctx, conn, w, stderr, clisqlclient.MakeQuery(query)) }() if sqlErr != nil { if cErr := zc.z.createError(s, name, sqlErr); cErr != nil { return cErr } - var pqErr *pq.Error - if !errors.As(sqlErr, &pqErr) { + var pgErr = (*pgconn.PgError)(nil) + if !errors.As(sqlErr, &pgErr) { // Not a SQL error. Nothing to retry. break } - if pgcode.MakeCode(string(pqErr.Code)) != pgcode.SerializationFailure { + if pgcode.MakeCode(pgErr.Code) != pgcode.SerializationFailure { // A non-retry error. We've printed the error, and // there's nothing to retry. Stop here. break diff --git a/pkg/cli/zip_test.go b/pkg/cli/zip_test.go index 137b6c8d8b03..304a26722d50 100644 --- a/pkg/cli/zip_test.go +++ b/pkg/cli/zip_test.go @@ -336,6 +336,8 @@ func eraseNonDeterministicZipOutput(out string) string { out = re.ReplaceAllString(out, `dial tcp ...`) re = regexp.MustCompile(`(?m)rpc error: .*$`) out = re.ReplaceAllString(out, `rpc error: ...`) + re = regexp.MustCompile(`(?m)failed to connect to .*$`) + out = re.ReplaceAllString(out, `failed to connect to ...`) // The number of memory profiles previously collected is not deterministic. re = regexp.MustCompile(`(?m)^\[node \d+\] \d+ heap profiles found$`) diff --git a/pkg/cloud/external_storage.go b/pkg/cloud/external_storage.go index efc86ff44683..b2b06c903135 100644 --- a/pkg/cloud/external_storage.go +++ b/pkg/cloud/external_storage.go @@ -109,8 +109,8 @@ type ExternalStorageFromURIFactory func(ctx context.Context, uri string, // SQLConnI encapsulates the interfaces which will be implemented by the network // backed SQLConn which is used to interact with the userfile tables. type SQLConnI interface { - driver.QueryerContext - driver.ExecerContext + Query(ctx context.Context, query string, args ...interface{}) (driver.Rows, error) + Exec(ctx context.Context, query string, args ...interface{}) error } // ErrFileDoesNotExist is a sentinel error for indicating that a specified diff --git a/pkg/cloud/userfile/filetable/file_table_read_writer.go b/pkg/cloud/userfile/filetable/file_table_read_writer.go index a040ed03a9af..75b5bbfd1514 100644 --- a/pkg/cloud/userfile/filetable/file_table_read_writer.go +++ b/pkg/cloud/userfile/filetable/file_table_read_writer.go @@ -116,18 +116,8 @@ func (i *SQLConnFileToTableExecutor) Query( ) (*FileToTableExecutorRows, error) { result := FileToTableExecutorRows{} - argVals := make([]driver.NamedValue, len(qargs)) - for i, qarg := range qargs { - namedVal := driver.NamedValue{ - // Ordinal position is 1 indexed. 
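Relatedly, the retry logic in zip.go above keys off the SQLSTATE carried by pgconn errors rather than lib/pq's error type. A hedged sketch of that pattern, not from this patch; it uses the standard library errors package, which behaves the same as the cockroachdb/errors calls in the diff for this purpose, and the SQLSTATE values shown are standard PostgreSQL codes:

package main

import (
	"errors"
	"fmt"

	"github.com/jackc/pgconn"
)

func describe(err error) {
	var pgErr *pgconn.PgError
	if errors.As(err, &pgErr) {
		// Code is the five-character SQLSTATE, e.g. "40001" for
		// serialization_failure (retryable) or "42703" for undefined_column.
		fmt.Println(pgErr.Severity, pgErr.Code, pgErr.Message)
		return
	}
	fmt.Println("not a server error:", err)
}

func main() {
	describe(&pgconn.PgError{Severity: "ERROR", Code: "40001", Message: "restart transaction"})
	describe(errors.New("dial failed"))
}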
- Ordinal: i + 1, - Value: qarg, - } - argVals[i] = namedVal - } - var err error - result.sqlConnExecResults, err = i.executor.QueryContext(ctx, query, argVals) + result.sqlConnExecResults, err = i.executor.Query(ctx, query, qargs...) if err != nil { return nil, err } @@ -138,17 +128,7 @@ func (i *SQLConnFileToTableExecutor) Query( func (i *SQLConnFileToTableExecutor) Exec( ctx context.Context, _, query string, _ username.SQLUsername, qargs ...interface{}, ) error { - argVals := make([]driver.NamedValue, len(qargs)) - for i, qarg := range qargs { - namedVal := driver.NamedValue{ - // Ordinal position is 1 indexed. - Ordinal: i + 1, - Value: qarg, - } - argVals[i] = namedVal - } - _, err := i.executor.ExecContext(ctx, query, argVals) - return err + return i.executor.Exec(ctx, query, qargs...) } // FileToTableSystem can be used to store, retrieve and delete the @@ -318,7 +298,9 @@ func (f *FileToTableSystem) FileSize(ctx context.Context, filename string) (int6 // ListFiles returns a list of all the files which are currently stored in the // user scoped tables. -func (f *FileToTableSystem) ListFiles(ctx context.Context, pattern string) ([]string, error) { +func (f *FileToTableSystem) ListFiles( + ctx context.Context, pattern string, +) (retFiles []string, retErr error) { var files []string listFilesQuery := fmt.Sprintf(`SELECT filename FROM %s WHERE filename LIKE $1 ORDER BY filename`, f.GetFQFileTableName()) @@ -334,6 +316,12 @@ filename`, f.GetFQFileTableName()) case *InternalFileToTableExecutor: // Verify that all the filenames are strings and aggregate them. it := rows.internalExecResultsIterator + defer func() { + if err := it.Close(); err != nil { + retErr = errors.CombineErrors(retErr, err) + log.Warningf(ctx, "failed to close %+v", err) + } + }() var ok bool for ok, err = it.Next(ctx); ok; ok, err = it.Next(ctx) { files = append(files, string(tree.MustBeDString(it.Cur()[0]))) @@ -342,6 +330,12 @@ filename`, f.GetFQFileTableName()) return nil, err } case *SQLConnFileToTableExecutor: + defer func() { + if err := rows.sqlConnExecResults.Close(); err != nil { + retErr = errors.CombineErrors(retErr, err) + log.Warningf(ctx, "failed to close %+v", err) + } + }() vals := make([]driver.Value, 1) for { if err := rows.sqlConnExecResults.Next(vals); err == io.EOF { @@ -352,10 +346,6 @@ filename`, f.GetFQFileTableName()) filename := vals[0].(string) files = append(files, filename) } - - if err = rows.sqlConnExecResults.Close(); err != nil { - return nil, err - } default: return []string{}, errors.New("unsupported executor type in FileSize") } @@ -649,7 +639,7 @@ func newFileTableReader( fileTableName, payloadTableName string, ie FileToTableSystemExecutor, offset int64, -) (ioctx.ReadCloserCtx, int64, error) { +) (_ ioctx.ReadCloserCtx, _ int64, retErr error) { // Get file_id from metadata entry in File table. 
var fileID []byte var sz int64 @@ -668,6 +658,7 @@ func newFileTableReader( it := metaRows.internalExecResultsIterator defer func() { if err := it.Close(); err != nil { + retErr = errors.CombineErrors(retErr, err) log.Warningf(ctx, "failed to close %+v", err) } }() @@ -685,6 +676,7 @@ func newFileTableReader( case *SQLConnFileToTableExecutor: defer func() { if err := metaRows.sqlConnExecResults.Close(); err != nil { + retErr = errors.CombineErrors(retErr, err) log.Warningf(ctx, "failed to close %+v", err) } }() @@ -695,7 +687,8 @@ func newFileTableReader( } else if err != nil { return nil, 0, errors.Wrap(err, "failed to read returned file metadata") } - fileID = vals[0].([]byte) + uuidBytes := vals[0].([16]byte) + fileID = uuidBytes[:] if vals[1] != nil { sz = vals[1].(int64) } @@ -709,7 +702,7 @@ func newFileTableReader( const bufSize = 256 << 10 - fn := func(p []byte, pos int64) (int, error) { + fn := func(p []byte, pos int64) (_ int, retErr error) { if pos >= sz { return 0, io.EOF } @@ -730,6 +723,7 @@ func newFileTableReader( it := rows.internalExecResultsIterator defer func() { if err := it.Close(); err != nil { + retErr = errors.CombineErrors(retErr, err) log.Warningf(ctx, "failed to close %+v", err) } }() @@ -744,6 +738,7 @@ func newFileTableReader( case *SQLConnFileToTableExecutor: defer func() { if err := rows.sqlConnExecResults.Close(); err != nil { + retErr = errors.CombineErrors(retErr, err) log.Warningf(ctx, "failed to close %+v", err) } }() diff --git a/pkg/cmd/cockroach-sql/main.go b/pkg/cmd/cockroach-sql/main.go index c96b77191ea4..206b71a3d7fb 100644 --- a/pkg/cmd/cockroach-sql/main.go +++ b/pkg/cmd/cockroach-sql/main.go @@ -14,6 +14,7 @@ package main import ( + "context" "fmt" "os" @@ -191,5 +192,5 @@ func runSQL(cmd *cobra.Command, args []string) (resErr error) { defer func() { resErr = errors.CombineErrors(resErr, conn.Close()) }() cfg.ShellCtx.ParseURL = clienturl.MakeURLParserFn(cmd, copts) - return cfg.Run(conn) + return cfg.Run(context.Background(), conn) } diff --git a/pkg/sql/copy_file_upload.go b/pkg/sql/copy_file_upload.go index 27df7093c166..bc6b17e540ec 100644 --- a/pkg/sql/copy_file_upload.go +++ b/pkg/sql/copy_file_upload.go @@ -168,11 +168,11 @@ func CopyInFileStmt(destination, schema, table string) string { } func (f *fileUploadMachine) run(ctx context.Context) error { - err := f.c.run(ctx) - if err != nil && f.cancel != nil { + runErr := f.c.run(ctx) + err := errors.CombineErrors(f.w.Close(), runErr) + if runErr != nil && f.cancel != nil { f.cancel() } - err = errors.CombineErrors(f.w.Close(), err) if err != nil { f.failureCleanup() diff --git a/pkg/testutils/lint/lint_test.go b/pkg/testutils/lint/lint_test.go index 770015a42ba2..4b0c30e8dc21 100644 --- a/pkg/testutils/lint/lint_test.go +++ b/pkg/testutils/lint/lint_test.go @@ -445,6 +445,7 @@ func TestLint(t *testing.T) { ":!util/sdnotify/sdnotify_unix.go", ":!util/grpcutil", // GRPC_GO_* variables ":!roachprod", // roachprod requires AWS environment variables + ":!cli/env.go", // The CLI needs the PGHOST variable. 
}, }, } { diff --git a/pkg/util/tracing/zipper/zipper.go b/pkg/util/tracing/zipper/zipper.go index 56365cbd84e8..7a3d08f18e3d 100644 --- a/pkg/util/tracing/zipper/zipper.go +++ b/pkg/util/tracing/zipper/zipper.go @@ -201,13 +201,17 @@ func MakeInternalExecutorInflightTraceZipper( var _ InflightTraceZipper = &InternalInflightTraceZipper{} +type queryI interface { + Query(ctx context.Context, query string, args ...interface{}) (driver.Rows, error) +} + // SQLConnInflightTraceZipper is the InflightTraceZipper which uses a network // backed SQL connection to collect cluster wide traces. type SQLConnInflightTraceZipper struct { traceStrBuf *bytes.Buffer nodeTraceCollection *tracing.TraceCollection z *memzipper.Zipper - sqlConn driver.QueryerContext + sqlConn queryI } func (s *SQLConnInflightTraceZipper) getNodeTraceCollection() *tracing.TraceCollection { @@ -234,7 +238,7 @@ func (s *SQLConnInflightTraceZipper) reset() { // into text, and jaegerJSON formats before creating a zip with per-node trace // files. func (s *SQLConnInflightTraceZipper) Zip(ctx context.Context, traceID int64) ([]byte, error) { - rows, err := s.sqlConn.QueryContext(ctx, fmt.Sprintf(inflightTracesQuery, traceID), nil /* args */) + rows, err := s.sqlConn.Query(ctx, fmt.Sprintf(inflightTracesQuery, traceID)) if err != nil { return nil, err } @@ -341,7 +345,7 @@ func (s *SQLConnInflightTraceZipper) populateInflightTraceRow( // MakeSQLConnInflightTraceZipper returns an instance of // SQLConnInflightTraceZipper. -func MakeSQLConnInflightTraceZipper(sqlConn driver.QueryerContext) *SQLConnInflightTraceZipper { +func MakeSQLConnInflightTraceZipper(sqlConn queryI) *SQLConnInflightTraceZipper { t := &SQLConnInflightTraceZipper{ traceStrBuf: &bytes.Buffer{}, nodeTraceCollection: nil,