diff --git a/pkg/cli/context.go b/pkg/cli/context.go index 554dea0fd0c3..262d8fab7b0c 100644 --- a/pkg/cli/context.go +++ b/pkg/cli/context.go @@ -195,7 +195,7 @@ func setCliContextDefaults() { cliCtx.allowUnencryptedClientPassword = false } -// sqlCtx captures the command-line parameters of the `sql` command. +// sqlCtx captures the configuration of the `sql` command. // See below for defaults. var sqlCtx = struct { *cliContext @@ -238,6 +238,19 @@ var sqlCtx = struct { // Determine whether to show raw durations. verboseTimings bool + + // Determines whether to stop the client upon encountering an error. + errExit bool + + // Determines whether to perform client-side syntax checking. + checkSyntax bool + + // autoTrace, when non-empty, encloses the executed statements + // by suitable SET TRACING and SHOW TRACE FOR SESSION statements. + autoTrace string + + // The string used to produce the value of fullPrompt. + customPromptPattern string }{cliContext: &cliCtx} // setSQLContextDefaults set the default values in sqlCtx. This @@ -254,6 +267,10 @@ func setSQLContextDefaults() { sqlCtx.echo = false sqlCtx.enableServerExecutionTimings = false sqlCtx.verboseTimings = false + sqlCtx.errExit = true + sqlCtx.checkSyntax = true + sqlCtx.autoTrace = "" + sqlCtx.customPromptPattern = defaultPromptPattern } // zipCtx captures the command-line parameters of the `zip` command. diff --git a/pkg/cli/sql.go b/pkg/cli/sql.go index 2264dc125e92..381933a83d74 100644 --- a/pkg/cli/sql.go +++ b/pkg/cli/sql.go @@ -111,6 +111,11 @@ Open a sql shell running against a cockroach database. // cliState defines the current state of the CLI during // command-line processing. +// +// Note: options customizable via \set and \unset should be defined in +// sqlCtx or cliCtx instead, so that the configuration remains globals +// across multiple instances of cliState (e.g. across file inclusion +// with \i). type cliState struct { conn *sqlConn // ins is used to read lines if isInteractive is true. @@ -118,13 +123,6 @@ type cliState struct { // buf is used to read lines if isInteractive is false. buf *bufio.Reader - // Options - // - // Determines whether to stop the client upon encountering an error. - errExit bool - // Determines whether to perform client-side syntax checking. - checkSyntax bool - // The prompt at the beginning of a multi-line entry. fullPrompt string // The prompt on a continuation line in a multi-line entry. @@ -133,8 +131,6 @@ type cliState struct { useContinuePrompt bool // The current prompt, either fullPrompt or continuePrompt. currentPrompt string - // The string used to produce the value of fullPrompt. - customPromptPattern string // State // @@ -178,10 +174,6 @@ type cliState struct { // by Ctrl+D, causes the shell to terminate with an error -- // reporting the status of the last valid SQL statement executed. exitErr error - - // autoTrace, when non-empty, encloses the executed statements - // by suitable SET TRACING and SHOW TRACE FOR SESSION statements. - autoTrace string } // cliStateEnum drives the CLI state machine in runInteractive(). @@ -292,48 +284,48 @@ var options = map[string]struct { description string isBoolean bool validDuringMultilineEntry bool - set func(c *cliState, val string) error - reset func(c *cliState) error + set func(val string) error + reset func() error // display is used to retrieve the current value. - display func(c *cliState) string + display func() string deprecated bool }{ `auto_trace`: { description: "automatically run statement tracing on each executed statement", isBoolean: false, validDuringMultilineEntry: false, - set: func(c *cliState, val string) error { + set: func(val string) error { val = strings.ToLower(strings.TrimSpace(val)) switch val { case "false", "0", "off": - c.autoTrace = "" + sqlCtx.autoTrace = "" case "true", "1": val = "on" fallthrough default: - c.autoTrace = "on, " + val + sqlCtx.autoTrace = "on, " + val } return nil }, - reset: func(c *cliState) error { - c.autoTrace = "" + reset: func() error { + sqlCtx.autoTrace = "" return nil }, - display: func(c *cliState) string { - if c.autoTrace == "" { + display: func() string { + if sqlCtx.autoTrace == "" { return "off" } - return c.autoTrace + return sqlCtx.autoTrace }, }, `display_format`: { description: "the output format for tabular data (table, csv, tsv, html, sql, records, raw)", isBoolean: false, validDuringMultilineEntry: true, - set: func(_ *cliState, val string) error { + set: func(val string) error { return cliCtx.tableDisplayFormat.Set(val) }, - reset: func(_ *cliState) error { + reset: func() error { displayFormat := tableDisplayTSV if cliCtx.terminalOutput { displayFormat = tableDisplayTable @@ -341,78 +333,78 @@ var options = map[string]struct { cliCtx.tableDisplayFormat = displayFormat return nil }, - display: func(_ *cliState) string { return cliCtx.tableDisplayFormat.String() }, + display: func() string { return cliCtx.tableDisplayFormat.String() }, }, `echo`: { description: "show SQL queries before they are sent to the server", isBoolean: true, validDuringMultilineEntry: false, - set: func(_ *cliState, _ string) error { sqlCtx.echo = true; return nil }, - reset: func(_ *cliState) error { sqlCtx.echo = false; return nil }, - display: func(_ *cliState) string { return strconv.FormatBool(sqlCtx.echo) }, + set: func(_ string) error { sqlCtx.echo = true; return nil }, + reset: func() error { sqlCtx.echo = false; return nil }, + display: func() string { return strconv.FormatBool(sqlCtx.echo) }, }, `errexit`: { description: "exit the shell upon a query error", isBoolean: true, validDuringMultilineEntry: true, - set: func(c *cliState, _ string) error { c.errExit = true; return nil }, - reset: func(c *cliState) error { c.errExit = false; return nil }, - display: func(c *cliState) string { return strconv.FormatBool(c.errExit) }, + set: func(_ string) error { sqlCtx.errExit = true; return nil }, + reset: func() error { sqlCtx.errExit = false; return nil }, + display: func() string { return strconv.FormatBool(sqlCtx.errExit) }, }, `check_syntax`: { description: "check the SQL syntax before running a query", isBoolean: true, validDuringMultilineEntry: false, - set: func(c *cliState, _ string) error { c.checkSyntax = true; return nil }, - reset: func(c *cliState) error { c.checkSyntax = false; return nil }, - display: func(c *cliState) string { return strconv.FormatBool(c.checkSyntax) }, + set: func(_ string) error { sqlCtx.checkSyntax = true; return nil }, + reset: func() error { sqlCtx.checkSyntax = false; return nil }, + display: func() string { return strconv.FormatBool(sqlCtx.checkSyntax) }, }, `show_times`: { description: "display the execution time after each query", isBoolean: true, validDuringMultilineEntry: true, - set: func(_ *cliState, _ string) error { sqlCtx.showTimes = true; return nil }, - reset: func(_ *cliState) error { sqlCtx.showTimes = false; return nil }, - display: func(_ *cliState) string { return strconv.FormatBool(sqlCtx.showTimes) }, + set: func(_ string) error { sqlCtx.showTimes = true; return nil }, + reset: func() error { sqlCtx.showTimes = false; return nil }, + display: func() string { return strconv.FormatBool(sqlCtx.showTimes) }, }, `show_server_times`: { description: "display the server execution times for queries (requires show_times to be set)", isBoolean: true, validDuringMultilineEntry: true, - set: func(_ *cliState, _ string) error { sqlCtx.enableServerExecutionTimings = true; return nil }, - reset: func(_ *cliState) error { sqlCtx.enableServerExecutionTimings = false; return nil }, - display: func(_ *cliState) string { return strconv.FormatBool(sqlCtx.enableServerExecutionTimings) }, + set: func(_ string) error { sqlCtx.enableServerExecutionTimings = true; return nil }, + reset: func() error { sqlCtx.enableServerExecutionTimings = false; return nil }, + display: func() string { return strconv.FormatBool(sqlCtx.enableServerExecutionTimings) }, }, `verbose_times`: { description: "display execution times with more precision (requires show_times to be set)", isBoolean: true, validDuringMultilineEntry: true, - set: func(_ *cliState, _ string) error { sqlCtx.verboseTimings = true; return nil }, - reset: func(_ *cliState) error { sqlCtx.verboseTimings = false; return nil }, - display: func(_ *cliState) string { return strconv.FormatBool(sqlCtx.verboseTimings) }, + set: func(_ string) error { sqlCtx.verboseTimings = true; return nil }, + reset: func() error { sqlCtx.verboseTimings = false; return nil }, + display: func() string { return strconv.FormatBool(sqlCtx.verboseTimings) }, }, `smart_prompt`: { description: "deprecated", isBoolean: true, validDuringMultilineEntry: false, - set: func(c *cliState, _ string) error { return nil }, - reset: func(c *cliState) error { return nil }, - display: func(c *cliState) string { return "false" }, + set: func(_ string) error { return nil }, + reset: func() error { return nil }, + display: func() string { return "false" }, deprecated: true, }, `prompt1`: { description: "prompt string to use before each command (the following are expanded: %M full host, %m host, %> port number, %n user, %/ database, %x txn status)", isBoolean: false, validDuringMultilineEntry: true, - set: func(c *cliState, val string) error { - c.customPromptPattern = val + set: func(val string) error { + sqlCtx.customPromptPattern = val return nil }, - reset: func(c *cliState) error { - c.customPromptPattern = defaultPromptPattern + reset: func() error { + sqlCtx.customPromptPattern = defaultPromptPattern return nil }, - display: func(c *cliState) string { return c.customPromptPattern }, + display: func() string { return sqlCtx.customPromptPattern }, }, } @@ -435,7 +427,7 @@ func (c *cliState) handleSet(args []string, nextState, errState cliStateEnum) cl if options[n].deprecated { continue } - optData = append(optData, []string{n, options[n].display(c), options[n].description}) + optData = append(optData, []string{n, options[n].display(), options[n].description}) } err := printQueryOutput(os.Stdout, []string{"Option", "Value", "Description"}, @@ -474,13 +466,13 @@ func (c *cliState) handleSet(args []string, nextState, errState cliStateEnum) cl // Run the command. var err error if !opt.isBoolean { - err = opt.set(c, val) + err = opt.set(val) } else { switch val { case "true", "1", "on": - err = opt.set(c, "true") + err = opt.set("true") case "false", "0", "off": - err = opt.reset(c) + err = opt.reset() default: return c.invalidOptSet(errState, args) } @@ -506,7 +498,7 @@ func (c *cliState) handleUnset(args []string, nextState, errState cliStateEnum) if len(c.partialLines) > 0 && !opt.validDuringMultilineEntry { return c.invalidOptionChange(errState, args[0]) } - if err := opt.reset(c); err != nil { + if err := opt.reset(); err != nil { fmt.Fprintf(stderr, "\\unset %s: %v\n", args[0], err) return errState } @@ -714,7 +706,7 @@ func (c *cliState) doRefreshPrompts(nextState cliStateEnum) cliStateEnum { dbName := unknownDbName c.lastKnownTxnStatus = unknownTxnStatus - wantDbStateInPrompt := rePromptDbState.MatchString(c.customPromptPattern) + wantDbStateInPrompt := rePromptDbState.MatchString(sqlCtx.customPromptPattern) if wantDbStateInPrompt { c.refreshTransactionStatus() // refreshDatabaseName() must be called *after* refreshTransactionStatus(), @@ -724,7 +716,7 @@ func (c *cliState) doRefreshPrompts(nextState cliStateEnum) cliStateEnum { dbName = c.refreshDatabaseName() } - c.fullPrompt = rePromptFmt.ReplaceAllStringFunc(c.customPromptPattern, func(m string) string { + c.fullPrompt = rePromptFmt.ReplaceAllStringFunc(sqlCtx.customPromptPattern, func(m string) string { switch m { case "%M": return parsedURL.Host // full host name. @@ -1015,7 +1007,7 @@ func (c *cliState) doHandleCliCmd(loopState, nextState cliStateEnum) cliStateEnu } errState := loopState - if c.errExit { + if sqlCtx.errExit { // If exiterr is set, an error in a client-side command also // terminates the shell. errState = cliStop @@ -1164,7 +1156,7 @@ func (c *cliState) doPrepareStatementLine( // Complete input. Remember it in the history. c.addHistory(c.concatLines) - if !c.checkSyntax { + if !sqlCtx.checkSyntax { return execState } @@ -1184,7 +1176,7 @@ func (c *cliState) doCheckStatement(startState, contState, execState cliStateEnu &formattedError{err: err, showSeverity: false, verbose: false}) // Stop here if exiterr is set. - if c.errExit { + if sqlCtx.errExit { return cliStop } @@ -1216,13 +1208,13 @@ func (c *cliState) doRunStatements(nextState cliStateEnum) cliStateEnum { c.lastKnownTxnStatus = " ?" // Are we tracing? - if c.autoTrace != "" { + if sqlCtx.autoTrace != "" { // Clear the trace by disabling tracing, then restart tracing // with the specified options. - c.exitErr = c.conn.Exec("SET tracing = off; SET tracing = "+c.autoTrace, nil) + c.exitErr = c.conn.Exec("SET tracing = off; SET tracing = "+sqlCtx.autoTrace, nil) if c.exitErr != nil { cliOutputError(stderr, c.exitErr, true /*showSeverity*/, false /*verbose*/) - if c.errExit { + if sqlCtx.errExit { return cliStop } return nextState @@ -1237,7 +1229,7 @@ func (c *cliState) doRunStatements(nextState cliStateEnum) cliStateEnum { // If we are tracing, stop tracing and display the trace. We do // this even if there was an error: a trace on errors is useful. - if c.autoTrace != "" { + if sqlCtx.autoTrace != "" { // First, disable tracing. if err := c.conn.Exec("SET tracing = off", nil); err != nil { // Print the error for the SET tracing statement. This will @@ -1254,7 +1246,7 @@ func (c *cliState) doRunStatements(nextState cliStateEnum) cliStateEnum { // shell. } else { traceType := "" - if strings.Contains(c.autoTrace, "kv") { + if strings.Contains(sqlCtx.autoTrace, "kv") { traceType = "kv" } if err := runQueryAndFormatResults(c.conn, os.Stdout, @@ -1273,7 +1265,7 @@ func (c *cliState) doRunStatements(nextState cliStateEnum) cliStateEnum { } } - if c.exitErr != nil && c.errExit { + if c.exitErr != nil && sqlCtx.errExit { return cliStop } @@ -1373,18 +1365,18 @@ func (c *cliState) configurePreShellDefaults(cmdIn *os.File) (cleanupFn func(), if cliCtx.isInteractive { // If a human user is providing the input, we want to help them with // what they are entering: - c.errExit = false // let the user retry failing commands + sqlCtx.errExit = false // let the user retry failing commands if !sqlCtx.debugMode { // Also, try to enable syntax checking if supported by the server. // This is a form of client-side error checking to help with large txns. - c.checkSyntax = true + sqlCtx.checkSyntax = true } } else { // When running non-interactive, by default we want errors to stop // further processing and we can just let syntax checking to be // done server-side to avoid client-side churn. - c.errExit = true - c.checkSyntax = false + sqlCtx.errExit = true + sqlCtx.checkSyntax = false // We also don't need (smart) prompts at all. } @@ -1425,9 +1417,9 @@ func (c *cliState) configurePreShellDefaults(cmdIn *os.File) (cleanupFn func(), // command-line client. // Default prompt is part of the connection URL. eg: "marc@localhost:26257>". - c.customPromptPattern = defaultPromptPattern + sqlCtx.customPromptPattern = defaultPromptPattern if sqlCtx.debugMode { - c.customPromptPattern = debugPromptPattern + sqlCtx.customPromptPattern = debugPromptPattern } c.ins.SetCompleter(c) @@ -1471,11 +1463,11 @@ func (c *cliState) runStatements(stmts []string) error { // we are returning directly. c.exitErr = runQueryAndFormatResults(c.conn, os.Stdout, makeQuery(stmt)) if c.exitErr != nil { - if !c.errExit && i < len(stmts)-1 { + if !sqlCtx.errExit && i < len(stmts)-1 { // Print the error now because we don't get a chance later. cliOutputError(stderr, c.exitErr, true /*showSeverity*/, false /*verbose*/) } - if c.errExit { + if sqlCtx.errExit { break } }