From b6410f5ad3553c639756e57dbeb317b28d4e14da Mon Sep 17 00:00:00 2001 From: Oliver Tan Date: Wed, 13 Apr 2022 06:11:45 +1000 Subject: [PATCH] clisqlshell: implement `COPY ... FROM STDIN` for CLI Steps: * Add a lower level API to lib/pq for use. * Add some abstraction boundary breakers in `clisqlclient` that allow a lower level handling of the COPY protocol. * Altered the state machine in `clisqlshell` to account for copy. Release note (cli change): COPY ... FROM STDIN now works from the cockroach CLI. It is not supported inside transactions. --- DEPS.bzl | 6 +- go.mod | 2 +- go.sum | 3 +- pkg/cli/clisqlclient/BUILD.bazel | 1 + pkg/cli/clisqlclient/copy.go | 116 ++++++++++++ pkg/cli/clisqlclient/make_query.go | 22 --- pkg/cli/clisqlexec/run_query.go | 1 + pkg/cli/clisqlshell/BUILD.bazel | 1 + pkg/cli/clisqlshell/sql.go | 226 +++++++++++++++++------- pkg/cli/clisqlshell/sql_test.go | 1 - pkg/cli/interactive_tests/test_copy.tcl | 150 ++++++++++++++++ pkg/sql/scanner/BUILD.bazel | 5 +- pkg/sql/scanner/scan.go | 11 ++ pkg/sql/scanner/scan_test.go | 36 ++++ vendor | 2 +- 15 files changed, 494 insertions(+), 89 deletions(-) create mode 100644 pkg/cli/clisqlclient/copy.go create mode 100644 pkg/cli/interactive_tests/test_copy.tcl diff --git a/DEPS.bzl b/DEPS.bzl index dedb7547fe2d..cbc2e4b09fb0 100644 --- a/DEPS.bzl +++ b/DEPS.bzl @@ -5059,10 +5059,10 @@ def go_deps(): name = "com_github_lib_pq", build_file_proto_mode = "disable_global", importpath = "github.com/lib/pq", - sha256 = "0f50cfc8d4ed4bbb39767aacc04d6b23e1105d2fa50dcb8e4ae204b2c90018f0", - strip_prefix = "github.com/lib/pq@v1.10.2", + sha256 = "5bca281c55dd8918e49a7e68d562eefb37f2cf17f7d45e1f3bd77e8eae49eb6e", + strip_prefix = "github.com/lib/pq@v1.10.6-0.20220412200556-b3b833258663", urls = [ - "https://storage.googleapis.com/cockroach-godeps/gomod/github.com/lib/pq/com_github_lib_pq-v1.10.2.zip", + "https://storage.googleapis.com/cockroach-godeps/gomod/github.com/lib/pq/com_github_lib_pq-v1.10.6-0.20220412200556-b3b833258663.zip", ], ) go_repository( diff --git a/go.mod b/go.mod index f39d3307425c..c8d76e47a75d 100644 --- a/go.mod +++ b/go.mod @@ -105,7 +105,7 @@ require ( github.com/kr/text v0.2.0 github.com/kylelemons/godebug v1.1.0 github.com/leanovate/gopter v0.2.5-0.20190402064358-634a59d12406 - github.com/lib/pq v1.10.2 + github.com/lib/pq v1.10.6-0.20220412200556-b3b833258663 github.com/lib/pq/auth/kerberos v0.0.0-20200720160335-984a6aa1ca46 github.com/linkedin/goavro/v2 v2.10.0 github.com/lufia/iostat v1.2.1 diff --git a/go.sum b/go.sum index a82300ddca97..fda947fa5554 100644 --- a/go.sum +++ b/go.sum @@ -1482,8 +1482,9 @@ github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.3.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.8.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.10.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8= github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lib/pq v1.10.6-0.20220412200556-b3b833258663 h1:zQ4V0s+y+YrjamtmcRoERUVfsN/4jEb/pVdxrpL6zjU= +github.com/lib/pq v1.10.6-0.20220412200556-b3b833258663/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq/auth/kerberos v0.0.0-20200720160335-984a6aa1ca46 h1:q7hY+WNJTcSqJNGwJzXZYL++nWBaoKlKdgZOyY6jxz4= github.com/lib/pq/auth/kerberos v0.0.0-20200720160335-984a6aa1ca46/go.mod h1:jydegJvs5JvVcuFD/YAT8JRmRVeOoRhtnGEgRnAoPpE= github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM= diff --git a/pkg/cli/clisqlclient/BUILD.bazel b/pkg/cli/clisqlclient/BUILD.bazel index 41fe97e0f784..6b2577a039ef 100644 --- a/pkg/cli/clisqlclient/BUILD.bazel +++ b/pkg/cli/clisqlclient/BUILD.bazel @@ -6,6 +6,7 @@ go_library( "api.go", "conn.go", "context.go", + "copy.go", "doc.go", "init_conn_error.go", "make_query.go", diff --git a/pkg/cli/clisqlclient/copy.go b/pkg/cli/clisqlclient/copy.go new file mode 100644 index 000000000000..733aaaefeba6 --- /dev/null +++ b/pkg/cli/clisqlclient/copy.go @@ -0,0 +1,116 @@ +// Copyright 2022 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package clisqlclient + +import ( + "context" + "database/sql/driver" + "io" + "reflect" + "strings" + + "github.com/cockroachdb/errors" +) + +type copyFromer interface { + CopyData(ctx context.Context, line string) (r driver.Result, err error) + Exec(v []driver.Value) (r driver.Result, err error) + Close() error +} + +// CopyFromState represents an in progress COPY FROM. +type CopyFromState struct { + driver.Tx + copyFromer +} + +// BeginCopyFrom starts a COPY FROM query. +func BeginCopyFrom(ctx context.Context, conn Conn, query string) (*CopyFromState, error) { + txn, err := conn.(*sqlConn).conn.(driver.ConnBeginTx).BeginTx(ctx, driver.TxOptions{}) + if err != nil { + return nil, err + } + stmt, err := txn.(driver.Conn).Prepare(query) + if err != nil { + return nil, errors.CombineErrors(err, txn.Rollback()) + } + return &CopyFromState{Tx: txn, copyFromer: stmt.(copyFromer)}, nil +} + +// copyFromRows is a mock Rows interface for COPY results. +type copyFromRows struct { + r driver.Result +} + +func (c copyFromRows) Close() error { + return nil +} + +func (c copyFromRows) Columns() []string { + return nil +} + +func (c copyFromRows) ColumnTypeScanType(index int) reflect.Type { + return nil +} + +func (c copyFromRows) ColumnTypeDatabaseTypeName(index int) string { + return "" +} + +func (c copyFromRows) ColumnTypeNames() []string { + return nil +} + +func (c copyFromRows) Result() driver.Result { + return c.r +} + +func (c copyFromRows) Tag() string { + return "COPY" +} + +func (c copyFromRows) Next(values []driver.Value) error { + return io.EOF +} + +func (c copyFromRows) NextResultSet() (bool, error) { + return false, nil +} + +// Cancel cancels a COPY FROM query from completing. +func (c *CopyFromState) Cancel() error { + return errors.CombineErrors(c.copyFromer.Close(), c.Tx.Rollback()) +} + +// Commit completes a COPY FROM query by committing lines to the database. +func (c *CopyFromState) Commit(ctx context.Context, cleanupFunc func(), lines string) QueryFn { + return func(ctx context.Context, conn Conn) (Rows, bool, error) { + defer cleanupFunc() + rows, isMulti, err := func() (Rows, bool, error) { + for _, l := range strings.Split(lines, "\n") { + _, err := c.copyFromer.CopyData(ctx, l) + if err != nil { + return nil, false, err + } + } + r, err := c.copyFromer.Exec(nil) + if err != nil { + return nil, false, err + } + return copyFromRows{r: r}, false, c.Tx.Commit() + }() + if err != nil { + return rows, isMulti, errors.CombineErrors(err, errors.CombineErrors(c.copyFromer.Close(), c.Tx.Rollback())) + } + return rows, isMulti, err + } +} diff --git a/pkg/cli/clisqlclient/make_query.go b/pkg/cli/clisqlclient/make_query.go index 6fbbaf944ee1..e43a3f3e02a3 100644 --- a/pkg/cli/clisqlclient/make_query.go +++ b/pkg/cli/clisqlclient/make_query.go @@ -13,10 +13,8 @@ package clisqlclient import ( "context" "database/sql/driver" - "strings" "github.com/cockroachdb/cockroach/pkg/sql/scanner" - "github.com/cockroachdb/errors" ) // QueryFn is the type of functions produced by MakeQuery. @@ -28,7 +26,6 @@ func MakeQuery(query string, parameters ...interface{}) QueryFn { return func(ctx context.Context, conn Conn) (Rows, bool, error) { isMultiStatementQuery, _ := scanner.HasMultipleStatements(query) rows, err := conn.Query(ctx, query, parameters...) - err = handleCopyError(conn.(*sqlConn), err) return rows, isMultiStatementQuery, err } } @@ -51,22 +48,3 @@ func convertArgs(parameters []interface{}) ([]driver.NamedValue, error) { } return dVals, nil } - -// handleCopyError ensures the user is properly informed when they issue -// a COPY statement somewhere in their input. -func handleCopyError(conn *sqlConn, err error) error { - if err == nil { - return nil - } - - if !strings.HasPrefix(err.Error(), "pq: unknown response for simple query: 'G'") { - return err - } - - // The COPY statement has hosed the connection by putting the - // protocol in a state that lib/pq cannot understand any more. Reset - // it. - _ = conn.Close() - conn.reconnecting = true - return errors.New("woops! COPY has confused this client! Suggestion: use 'psql' for COPY") -} diff --git a/pkg/cli/clisqlexec/run_query.go b/pkg/cli/clisqlexec/run_query.go index 75ce1f37dfe4..3754a36d6169 100644 --- a/pkg/cli/clisqlexec/run_query.go +++ b/pkg/cli/clisqlexec/run_query.go @@ -282,6 +282,7 @@ var tagsWithRowsAffected = map[string]struct{}{ "DELETE": {}, "MOVE": {}, "DROP USER": {}, + "COPY": {}, // This one is used with e.g. CREATE TABLE AS (other SELECT // statements have type Rows, not RowsAffected). "SELECT": {}, diff --git a/pkg/cli/clisqlshell/BUILD.bazel b/pkg/cli/clisqlshell/BUILD.bazel index 8eaaaf14da88..5e24c087f2d7 100644 --- a/pkg/cli/clisqlshell/BUILD.bazel +++ b/pkg/cli/clisqlshell/BUILD.bazel @@ -26,6 +26,7 @@ go_library( "//pkg/sql/scanner", "//pkg/sql/sqlfsm", "//pkg/util/envutil", + "//pkg/util/errorutil/unimplemented", "//pkg/util/syncutil", "@com_github_cockroachdb_errors//:errors", "@com_github_knz_go_libedit//:go-libedit", diff --git a/pkg/cli/clisqlshell/sql.go b/pkg/cli/clisqlshell/sql.go index 714b1686e3a5..9765e806dc44 100644 --- a/pkg/cli/clisqlshell/sql.go +++ b/pkg/cli/clisqlshell/sql.go @@ -39,6 +39,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/scanner" "github.com/cockroachdb/cockroach/pkg/sql/sqlfsm" "github.com/cockroachdb/cockroach/pkg/util/envutil" + "github.com/cockroachdb/cockroach/pkg/util/errorutil/unimplemented" "github.com/cockroachdb/errors" readline "github.com/knz/go-libedit" ) @@ -157,6 +158,9 @@ type cliState struct { // The current prompt, either fullPrompt or continuePrompt. currentPrompt string + // State of COPY FROM on the client. + copyFromState *clisqlclient.CopyFromState + // State // // lastInputLine is the last valid line obtained from readline. @@ -283,6 +287,14 @@ func (c *cliState) invalidSyntax(nextState cliStateEnum) cliStateEnum { return c.invalidSyntaxf(nextState, `%s. Try \? for help.`, c.lastInputLine) } +func (c *cliState) inCopy() bool { + return c.copyFromState != nil +} + +func (c *cliState) resetCopy() { + c.copyFromState = nil +} + func (c *cliState) invalidSyntaxf( nextState cliStateEnum, format string, args ...interface{}, ) cliStateEnum { @@ -756,7 +768,9 @@ func (c *cliState) doRefreshPrompts(nextState cliStateEnum) cliStateEnum { } if c.useContinuePrompt { - if len(c.fullPrompt) < 3 { + if c.inCopy() { + c.continuePrompt = ">> " + } else if len(c.fullPrompt) < 3 { c.continuePrompt = "> " } else { // continued statement prompt is: " -> ". @@ -767,61 +781,64 @@ func (c *cliState) doRefreshPrompts(nextState cliStateEnum) cliStateEnum { return nextState } - // Configure the editor to use the new prompt. - - parsedURL, err := url.Parse(c.conn.GetURL()) - if err != nil { - // If parsing fails, we'll keep the entire URL. The Open call succeeded, and that - // is the important part. - c.fullPrompt = c.conn.GetURL() + "> " - c.continuePrompt = strings.Repeat(" ", len(c.fullPrompt)-3) + "-> " - return nextState - } + if c.inCopy() { + c.fullPrompt = ">>" + } else { + // Configure the editor to use the new prompt. - userName := "" - if parsedURL.User != nil { - userName = parsedURL.User.Username() - } + parsedURL, err := url.Parse(c.conn.GetURL()) + if err != nil { + // If parsing fails, we'll keep the entire URL. The Open call succeeded, and that + // is the important part. + c.fullPrompt = c.conn.GetURL() + "> " + c.continuePrompt = strings.Repeat(" ", len(c.fullPrompt)-3) + "-> " + return nextState + } - dbName := unknownDbName - c.lastKnownTxnStatus = unknownTxnStatus + userName := "" + if parsedURL.User != nil { + userName = parsedURL.User.Username() + } - wantDbStateInPrompt := rePromptDbState.MatchString(c.iCtx.customPromptPattern) - if wantDbStateInPrompt { - c.refreshTransactionStatus() - // refreshDatabaseName() must be called *after* refreshTransactionStatus(), - // even when %/ appears before %x in the prompt format. - // This is because the database name should not be queried during - // some transaction phases. - dbName = c.refreshDatabaseName() - } - - c.fullPrompt = rePromptFmt.ReplaceAllStringFunc(c.iCtx.customPromptPattern, func(m string) string { - switch m { - case "%M": - return parsedURL.Host // full host name. - case "%m": - return parsedURL.Hostname() // host name. - case "%>": - return parsedURL.Port() // port. - case "%n": // user name. - return userName - case "%/": // database name. - return dbName - case "%x": // txn status. - return c.lastKnownTxnStatus - case "%%": - return "%" - default: - err = fmt.Errorf("unrecognized format code in prompt: %q", m) - return "" + dbName := unknownDbName + c.lastKnownTxnStatus = unknownTxnStatus + + wantDbStateInPrompt := rePromptDbState.MatchString(c.iCtx.customPromptPattern) + if wantDbStateInPrompt { + c.refreshTransactionStatus() + // refreshDatabaseName() must be called *after* refreshTransactionStatus(), + // even when %/ appears before %x in the prompt format. + // This is because the database name should not be queried during + // some transaction phases. + dbName = c.refreshDatabaseName() } - }) - if err != nil { - c.fullPrompt = err.Error() - } + c.fullPrompt = rePromptFmt.ReplaceAllStringFunc(c.iCtx.customPromptPattern, func(m string) string { + switch m { + case "%M": + return parsedURL.Host // full host name. + case "%m": + return parsedURL.Hostname() // host name. + case "%>": + return parsedURL.Port() // port. + case "%n": // user name. + return userName + case "%/": // database name. + return dbName + case "%x": // txn status. + return c.lastKnownTxnStatus + case "%%": + return "%" + default: + err = fmt.Errorf("unrecognized format code in prompt: %q", m) + return "" + } + }) + if err != nil { + c.fullPrompt = err.Error() + } + } c.fullPrompt += " " c.currentPrompt = c.fullPrompt @@ -896,9 +913,14 @@ func (c *cliState) refreshDatabaseName() string { var cmdHistFile = envutil.EnvOrDefaultString("COCKROACH_SQL_CLI_HISTORY", ".cockroachsql_history") // GetCompletions implements the readline.CompletionGenerator interface. -func (c *cliState) GetCompletions(_ string) []string { +func (c *cliState) GetCompletions(s string) []string { sql, _ := c.ins.GetLineInfo() + // In COPY mode, just add a tab character. + if c.inCopy() { + return []string{s + "\t"} + } + if !strings.HasSuffix(sql, "??") { query := fmt.Sprintf(`SHOW COMPLETIONS AT OFFSET %d FOR %s`, len(sql), lexbase.EscapeSQLString(sql)) var rows [][]string @@ -1027,6 +1049,26 @@ func (c *cliState) doReadLine(nextState cliStateEnum) cliStateEnum { return cliStop } + if c.inCopy() { + // CTRL+C in COPY cancels the copy. + defer func() { + c.resetCopy() + c.partialLines = c.partialLines[:0] + c.partialStmtsLen = 0 + c.useContinuePrompt = false + }() + c.exitErr = errors.CombineErrors( + pgerror.Newf(pgcode.QueryCanceled, "COPY canceled by user"), + c.copyFromState.Cancel(), + ) + if c.exitErr != nil { + if !c.singleStatement { + clierror.OutputError(c.iCtx.stderr, c.exitErr, true /*showSeverity*/, false /*verbose*/) + } + } + return cliRefreshPrompt + } + if l != "" { // Ctrl+C after the beginning of a line cancels the current // line. @@ -1046,6 +1088,11 @@ func (c *cliState) doReadLine(nextState cliStateEnum) cliStateEnum { return cliStartLine case errors.Is(err, io.EOF): + // If we're in COPY and we're interactive, this signifies the copy is complete. + if c.inCopy() && c.cliCtx.IsInteractive { + return cliRunStatement + } + c.atEOF = true if c.cliCtx.IsInteractive { @@ -1179,6 +1226,25 @@ func (c *cliState) doHandleCliCmd(loopState, nextState cliStateEnum) cliStateEnu c.concatLines = `SHOW TABLES` return cliRunStatement + case `\copy`: + c.exitErr = c.runWithInterruptableCtx(func(ctx context.Context) error { + return c.beginCopyFrom(ctx, c.concatLines) + }) + if !c.singleStatement { + clierror.OutputError(c.iCtx.stderr, c.exitErr, true /*showSeverity*/, false /*verbose*/) + } + if c.exitErr != nil && c.iCtx.errExit { + return cliStop + } + return cliStartLine + + case `\.`: + if c.inCopy() { + c.concatLines += "\n" + `\.` + return cliRunStatement + } + return c.invalidSyntax(errState) + case `\dT`: c.concatLines = `SHOW TYPES` return cliRunStatement @@ -1527,12 +1593,13 @@ func (c *cliState) doPrepareStatementLine( } return startState } - endOfStmt := isEndOfStatement(lastTok) - if c.singleStatement && c.atEOF { + endOfStmt := (!c.inCopy() && isEndOfStatement(lastTok)) || + // We're always at the end of a statement if we're in COPY and encounter + // the \. or EOF character. + (c.inCopy() && (strings.HasSuffix(c.concatLines, "\n"+`\.`) || c.atEOF)) || // We're always at the end of a statement if EOF is reached in the // single statement mode. - endOfStmt = true - } + c.singleStatement && c.atEOF if c.atEOF { // Definitely no more input expected. if !endOfStmt { @@ -1551,7 +1618,9 @@ func (c *cliState) doPrepareStatementLine( } // Complete input. Remember it in the history. - c.addHistory(c.concatLines) + if !c.inCopy() { + c.addHistory(c.concatLines) + } if !c.iCtx.checkSyntax { return execState @@ -1561,6 +1630,10 @@ func (c *cliState) doPrepareStatementLine( } func (c *cliState) doCheckStatement(startState, contState, execState cliStateEnum) cliStateEnum { + // If we are in COPY, we have no valid SQL, so skip directly to the next state. + if c.inCopy() { + return execState + } // From here on, client-side syntax checking is enabled. helpText, err := c.serverSideParse(c.concatLines) if err != nil { @@ -1626,9 +1699,24 @@ func (c *cliState) doRunStatements(nextState cliStateEnum) cliStateEnum { // Now run the statement/query. c.exitErr = c.runWithInterruptableCtx(func(ctx context.Context) error { - return c.sqlExecCtx.RunQueryAndFormatResults(ctx, - c.conn, c.iCtx.stdout, c.iCtx.stderr, - clisqlclient.MakeQuery(c.concatLines)) + if scanner.FirstLexicalToken(c.concatLines) == lexbase.COPY { + return c.beginCopyFrom(ctx, c.concatLines) + } + q := clisqlclient.MakeQuery(c.concatLines) + if c.inCopy() { + q = c.copyFromState.Commit( + ctx, + c.resetCopy, + c.concatLines, + ) + } + return c.sqlExecCtx.RunQueryAndFormatResults( + ctx, + c.conn, + c.iCtx.stdout, + c.iCtx.stderr, + q, + ) }) if c.exitErr != nil { if !c.singleStatement { @@ -1685,6 +1773,26 @@ func (c *cliState) doRunStatements(nextState cliStateEnum) cliStateEnum { return nextState } +func (c *cliState) beginCopyFrom(ctx context.Context, sql string) error { + c.refreshTransactionStatus() + if c.lastKnownTxnStatus != "" { + return unimplemented.Newf( + "cli_copy_in_txn", + "cannot use COPY inside a transaction", + ) + } + copyFromState, err := clisqlclient.BeginCopyFrom(ctx, c.conn, sql) + if err != nil { + return err + } + c.copyFromState = copyFromState + if c.cliCtx.IsInteractive { + fmt.Fprintln(c.iCtx.stdout, `Enter data to be copied followed by a newline.`) + fmt.Fprintln(c.iCtx.stdout, `End with a backslash and a period on a line by itself, or an EOF signal.`) + } + return nil +} + func (c *cliState) doDecidePath() cliStateEnum { if len(c.partialLines) == 0 { return cliProcessFirstLine diff --git a/pkg/cli/clisqlshell/sql_test.go b/pkg/cli/clisqlshell/sql_test.go index beec09430d94..ecf7b4b61dfb 100644 --- a/pkg/cli/clisqlshell/sql_test.go +++ b/pkg/cli/clisqlshell/sql_test.go @@ -112,7 +112,6 @@ func Example_sql() { // sql -d nonexistent -e create database nonexistent; create table foo(x int); select * from foo // x // sql -e copy t.f from stdin - // ERROR: -e: woops! COPY has confused this client! Suggestion: use 'psql' for COPY // sql -e select 1/(@1-2) from generate_series(1,3) // ?column? // -1.0000000000000000000 diff --git a/pkg/cli/interactive_tests/test_copy.tcl b/pkg/cli/interactive_tests/test_copy.tcl new file mode 100644 index 000000000000..8084c58ce703 --- /dev/null +++ b/pkg/cli/interactive_tests/test_copy.tcl @@ -0,0 +1,150 @@ +#! /usr/bin/env expect -f + +source [file join [file dirname $argv0] common.tcl] + +start_server $argv + +spawn $argv sql +eexpect root@ + +send "DROP TABLE IF EXISTS t;\r" +send "CREATE TABLE t (id INT PRIMARY KEY, t TEXT);\r" + +start_test "Check that errors are reported as appropriate." + +send "COPY invalid_table FROM STDIN;\r" +eexpect "ERROR: relation \"invalid_table\" does not exist" +eexpect root@ + +send "COPY t FROM STDIN;\r" +eexpect "Enter data to be copied followed by a newline." +eexpect "End with a backslash and a period on a line by itself, or an EOF signal." +eexpect ">>" + +send "invalid text field\ttext with semicolon;\r" +send "\\.\r" + +eexpect "could not parse" + +end_test + +start_test "multi statement with COPY" +send "SELECT 1; COPY t FROM STDIN CSV;\r" +eexpect "COPY together with other statements in a query string is not supported" +eexpect root@ +send "COPY t FROM STDIN CSV;SELECT 1;\r" +eexpect "COPY together with other statements in a query string is not supported" +eexpect root@ +end_test + +start_test "Copy in transaction" +send "BEGIN;\r" +eexpect root@ +send "COPY t FROM STDIN CSV;\r" +eexpect "cannot use COPY inside a transaction" +eexpect root@ +send "ROLLBACK;\r" +eexpect root@ +end_test + +start_test "Check EOF and \. works as appropriate during COPY" + +send "COPY t FROM STDIN CSV;\r" +eexpect ">>" +send "1,text with semicolon;\r" +send "2,beat chef@;\r" +send "3,more&text\r" +send "\\.\r" + +eexpect "COPY 3" +eexpect root@ + +# Try \copy as well. +send "\copy t FROM STDIN CSV;\r" +eexpect ">>" +send "4,epa! epa!\r" +send_eof + +eexpect "COPY 1" +eexpect root@ + +send "SELECT * FROM t ORDER BY id ASC;\r" + +eexpect "1 | text with semicolon;" +eexpect "2 | beat chef@;" +eexpect "3 | more&text" +eexpect "4 | epa! epa!" +eexpect "(4 rows)" + +eexpect root@ + +end_test + +start_test "check CTRL+C during COPY exits the COPY mode as appropriate" + +send "COPY t FROM STDIN CSV;\r" +eexpect ">>" +send "5,cancel me\r" + +interrupt + +eexpect "ERROR: COPY canceled by user" +eexpect root@ + +send "SELECT * FROM t ORDER BY id ASC;\r" +eexpect "(4 rows)" +eexpect root@ + +send "TRUNCATE TABLE t;\r" +eexpect root@ + +end_test + +send_eof +eexpect eof + + +spawn /bin/bash +send "PS1=':''/# '\r" +eexpect ":/# " + +start_test "Test file input invalid" + +send "cat >/tmp/test_copy.sql </tmp/test_copy.sql <> /tmp/test_copy.sql\r" +eexpect ":/# " +send "$argv sql --insecure -f /tmp/test_copy.sql\r" +eexpect ":/# " +send "$argv sql --insecure -e 'SELECT * FROM t ORDER BY id'\r" +eexpect "1 | a" +eexpect "2 | b" +eexpect "(2 rows)" +eexpect ":/# " + +send "$argv sql --insecure -e 'TRUNCATE TABLE t'\r" +eexpect ":/# " +send "$argv sql --insecure < /tmp/test_copy.sql\r" +eexpect ":/# " +send "$argv sql --insecure -e 'SELECT * FROM t ORDER BY id'\r" +eexpect "1 | a" +eexpect "2 | b" +eexpect "(2 rows)" +eexpect ":/# " + +end_test + +stop_server $argv diff --git a/pkg/sql/scanner/BUILD.bazel b/pkg/sql/scanner/BUILD.bazel index 90f424dc2447..c74c89a8952c 100644 --- a/pkg/sql/scanner/BUILD.bazel +++ b/pkg/sql/scanner/BUILD.bazel @@ -12,5 +12,8 @@ go_test( name = "scanner_test", srcs = ["scan_test.go"], embed = [":scanner"], - deps = ["//pkg/sql/lexbase"], + deps = [ + "//pkg/sql/lexbase", + "@com_github_stretchr_testify//require", + ], ) diff --git a/pkg/sql/scanner/scan.go b/pkg/sql/scanner/scan.go index b5bb12d4ef1b..ed303af6c22f 100644 --- a/pkg/sql/scanner/scan.go +++ b/pkg/sql/scanner/scan.go @@ -1064,6 +1064,17 @@ func LastLexicalToken(sql string) (lastTok int, ok bool) { } } +// FirstLexicalToken returns the first lexical token. +// Returns 0 if there is no token. +func FirstLexicalToken(sql string) (tok int) { + var s Scanner + var lval fakeSym + s.Init(sql) + s.Scan(&lval) + id := lval.ID() + return int(id) +} + // fakeSym is a simplified symbol type for use by // HasMultipleStatements. type fakeSym struct { diff --git a/pkg/sql/scanner/scan_test.go b/pkg/sql/scanner/scan_test.go index f9f0cfb516d5..0cd26f7ae752 100644 --- a/pkg/sql/scanner/scan_test.go +++ b/pkg/sql/scanner/scan_test.go @@ -15,6 +15,7 @@ import ( "testing" "github.com/cockroachdb/cockroach/pkg/sql/lexbase" + "github.com/stretchr/testify/require" ) func TestHasMultipleStatements(t *testing.T) { @@ -43,6 +44,41 @@ func TestHasMultipleStatements(t *testing.T) { } } +func TestFirstLexicalToken(t *testing.T) { + tests := []struct { + s string + res int + }{ + { + s: "", + res: 0, + }, + { + s: " /* comment */ ", + res: 0, + }, + { + s: "SELECT", + res: lexbase.SELECT, + }, + { + s: "SELECT 1", + res: lexbase.SELECT, + }, + { + s: "SELECT 1;", + res: lexbase.SELECT, + }, + } + + for i, tc := range tests { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + tok := FirstLexicalToken(tc.s) + require.Equal(t, tc.res, tok) + }) + } +} + func TestLastLexicalToken(t *testing.T) { tests := []struct { s string diff --git a/vendor b/vendor index b46e3ef11309..1bbc6bb97b35 160000 --- a/vendor +++ b/vendor @@ -1 +1 @@ -Subproject commit b46e3ef1130903d6aca62cdff3a0079eda1ae7f7 +Subproject commit 1bbc6bb97b351af886b3f806f66a40057fefa3ae