diff --git a/pkg/cli/sql.go b/pkg/cli/sql.go index bb29655a8bc8..973bbda0c28f 100644 --- a/pkg/cli/sql.go +++ b/pkg/cli/sql.go @@ -154,7 +154,8 @@ type cliState struct { partialStmtsLen int // concatLines is the concatenation of partialLines, computed during - // doCheckStatement and then reused in doRunStatement(). + // doPrepareStatementLine and then reused in doRunStatements() and + // doCheckStatement(). concatLines string // exitErr defines the error to report to the user upon termination. @@ -1135,7 +1136,7 @@ func (c *cliState) doCheckStatement(startState, contState, execState cliStateEnu fmt.Println(helpText) } - _ = c.invalidSyntax(0, "statement ignored: %v", + _ = c.invalidSyntax(cliStart, "statement ignored: %v", &formattedError{err: err, showSeverity: false, verbose: false}) // Stop here if exiterr is set. @@ -1163,7 +1164,28 @@ func (c *cliState) doCheckStatement(startState, contState, execState cliStateEnu return nextState } -func (c *cliState) doRunStatement(nextState cliStateEnum) cliStateEnum { +// doRunStatements runs all the statements that have been accumulated by +// concatLines. +func (c *cliState) doRunStatements(nextState cliStateEnum) cliStateEnum { + stmts, err := parser.Parse(c.concatLines) + if err != nil { + c.exitErr = err + return cliStop + } + for _, stmt := range stmts { + c.doRunStatement(stmt.SQL) + if c.exitErr != nil { + if c.errExit { + return cliStop + } + return nextState + } + } + return nextState +} + +// doRunStatement runs a single sql statement. +func (c *cliState) doRunStatement(stmt string) { // Once we send something to the server, the txn status may change arbitrarily. // Clear the known state so that further entries do not assume anything. c.lastKnownTxnStatus = " ?" @@ -1175,15 +1197,12 @@ func (c *cliState) doRunStatement(nextState cliStateEnum) cliStateEnum { c.exitErr = c.conn.Exec("SET tracing = off; SET tracing = "+c.autoTrace, nil) if c.exitErr != nil { cliOutputError(stderr, c.exitErr, true /*showSeverity*/, false /*verbose*/) - if c.errExit { - return cliStop - } - return nextState + return } } // Now run the statement/query. - c.exitErr = runQueryAndFormatResults(c.conn, os.Stdout, makeQuery(c.concatLines)) + c.exitErr = runQueryAndFormatResults(c.conn, os.Stdout, makeQuery(stmt)) if c.exitErr != nil { cliOutputError(stderr, c.exitErr, true /*showSeverity*/, false /*verbose*/) } @@ -1225,12 +1244,6 @@ func (c *cliState) doRunStatement(nextState cliStateEnum) cliStateEnum { } } } - - if c.exitErr != nil && c.errExit { - return cliStop - } - - return nextState } func (c *cliState) doDecidePath() cliStateEnum { @@ -1299,7 +1312,7 @@ func runInteractive(conn *sqlConn) (exitErr error) { state = c.doCheckStatement(cliStartLine, cliContinueLine, cliRunStatement) case cliRunStatement: - state = c.doRunStatement(cliStartLine) + state = c.doRunStatements(cliStartLine) default: panic(fmt.Sprintf("unknown state: %d", state)) @@ -1413,7 +1426,7 @@ func (c *cliState) configurePreShellDefaults() (cleanupFn func(), err error) { return cleanupFn, nil } -// runOneStatement executes one statement and terminates +// runStatements executes the given statements and terminates // on error. func (c *cliState) runStatements(stmts []string) error { for { diff --git a/pkg/cli/sql_util.go b/pkg/cli/sql_util.go index 9ba7bd0b4aa4..95723ff68097 100644 --- a/pkg/cli/sql_util.go +++ b/pkg/cli/sql_util.go @@ -320,8 +320,14 @@ func (c *sqlConn) requireServerVersion(required *version.Version) error { // given sql query. If the query fails or does not return a single // column, `false` is returned in the second result. func (c *sqlConn) getServerValue(what, sql string) (driver.Value, bool) { - var dbVals [1]driver.Value + return c.getServerValueFromColumnIndex(what, sql, 0) +} +// getServerValueFromColumnIndex retrieves the driverValue at a particular +// column index from the result of the given sql query. If the query fails or +// does not return at least as many columns as colIdx - 1, `false` is returned +// in the second result. +func (c *sqlConn) getServerValueFromColumnIndex(what, sql string, colIdx int) (driver.Value, bool) { rows, err := c.Query(sql, nil) if err != nil { fmt.Fprintf(stderr, "warning: error retrieving the %s: %v\n", what, err) @@ -329,18 +335,20 @@ func (c *sqlConn) getServerValue(what, sql string) (driver.Value, bool) { } defer func() { _ = rows.Close() }() - if len(rows.Columns()) == 0 { + if len(rows.Columns()) <= colIdx { fmt.Fprintf(stderr, "warning: cannot get the %s\n", what) return nil, false } + dbVals := make([]driver.Value, len(rows.Columns())) + err = rows.Next(dbVals[:]) if err != nil { fmt.Fprintf(stderr, "warning: invalid %s: %v\n", what, err) return nil, false } - return dbVals[0], true + return dbVals[colIdx], true } // sqlTxnShim implements the crdb.Tx interface. @@ -692,7 +700,6 @@ var tagsWithRowsAffected = map[string]struct{}{ // runQueryAndFormatResults takes a 'query' with optional 'parameters'. // It runs the sql query and writes output to 'w'. func runQueryAndFormatResults(conn *sqlConn, w io.Writer, fn queryFunc) error { - startTime := timeutil.Now() rows, err := fn(conn) if err != nil { return handleCopyError(conn, err) @@ -784,13 +791,14 @@ func runQueryAndFormatResults(conn *sqlConn, w io.Writer, fn queryFunc) error { } if sqlCtx.showTimes { - // Present the time since the last result, or since the - // beginning of execution. Currently the execution engine makes - // all the work upfront so most of the time is accounted for by - // the 1st result; this is subject to change once CockroachDB - // evolves to stream results as statements are executed. - fmt.Fprintf(w, "\nTime: %s\n", queryCompleteTime.Sub(startTime)) - // Make users better understand any discrepancy they observe. + // TODO(arul): This 2 is kind of ugly and unfortunate. + execLatencyRaw, hasVal := conn.getServerValueFromColumnIndex( + "last query statistics", `SHOW LAST QUERY STATISTICS`, 2 /* exec_latency */) + if hasVal { + execLatency := formatVal(execLatencyRaw, false, false) + parsed, _ := tree.ParseDInterval(execLatency) + fmt.Fprintf(w, "\nServer Execution Time: %s\n", time.Duration(parsed.Duration.Nanos())) + } renderDelay := timeutil.Now().Sub(queryCompleteTime) if renderDelay >= 1*time.Second { fmt.Fprintf(w, @@ -799,8 +807,6 @@ func runQueryAndFormatResults(conn *sqlConn, w io.Writer, fn queryFunc) error { renderDelay) } fmt.Fprintln(w) - // Reset the clock. We ignore the rendering time. - startTime = timeutil.Now() } if more, err := rows.NextResultSet(); err != nil {