diff --git a/WORKSPACE b/WORKSPACE index 43e2bad103ac..93f4da3b02f9 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -27,8 +27,9 @@ http_archive( # repo. git_repository( name = "bazel_gazelle", - commit = "d038863ba2e096792c6bb6afca31f6514f1aeecd", + commit = "0ac66c98675a24d58f89a614b84dcd920a7e1762", remote = "https://github.com/bazelbuild/bazel-gazelle", + shallow_since = "1626107853 -0400", ) # Override the location of some libraries; otherwise, rules_go will pull its own diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel index 619b6d40fd48..64e01f3a9015 100644 --- a/pkg/BUILD.bazel +++ b/pkg/BUILD.bazel @@ -265,6 +265,7 @@ ALL_TESTS = [ "//pkg/sql/rowenc:rowenc_test", "//pkg/sql/rowexec:rowexec_test", "//pkg/sql/rowflow:rowflow_test", + "//pkg/sql/scanner:scanner_test", "//pkg/sql/schemachange:schemachange_test", "//pkg/sql/schemachanger/scbuild:scbuild_test", "//pkg/sql/schemachanger/scexec:scexec_test", diff --git a/pkg/cli/clisqlclient/BUILD.bazel b/pkg/cli/clisqlclient/BUILD.bazel index e0b498ff3333..b28dbda7b142 100644 --- a/pkg/cli/clisqlclient/BUILD.bazel +++ b/pkg/cli/clisqlclient/BUILD.bazel @@ -11,6 +11,7 @@ go_library( "make_query.go", "parse_bool.go", "rows.go", + "string_to_duration.go", "txn_shim.go", ], importpath = "github.com/cockroachdb/cockroach/pkg/cli/clisqlclient", @@ -19,11 +20,8 @@ go_library( "//pkg/build", "//pkg/cli/clicfg", "//pkg/cli/clierror", - "//pkg/roachpb", "//pkg/security/pprompt", - "//pkg/sql/parser", - "//pkg/sql/sem/tree", - "//pkg/util/duration", + "//pkg/sql/scanner", "//pkg/util/version", "@com_github_cockroachdb_cockroach_go//crdb", "@com_github_cockroachdb_errors//:errors", @@ -38,6 +36,7 @@ go_test( "conn_test.go", "main_test.go", "parse_bool_test.go", + "string_to_duration_test.go", ], embed = [":clisqlclient"], deps = [ diff --git a/pkg/cli/clisqlclient/api.go b/pkg/cli/clisqlclient/api.go index 87a7e75a9065..cc1027d2bdb8 100644 --- a/pkg/cli/clisqlclient/api.go +++ b/pkg/cli/clisqlclient/api.go @@ -14,8 +14,6 @@ import ( "database/sql/driver" "reflect" "time" - - "github.com/cockroachdb/cockroach/pkg/roachpb" ) // Conn represents a connection to a SQL server. @@ -88,7 +86,7 @@ type Conn interface { // GetServerMetadata() returns details about the CockroachDB node // this connection is connected to. GetServerMetadata() ( - nodeID roachpb.NodeID, + nodeID int32, version, clusterID string, err error, ) diff --git a/pkg/cli/clisqlclient/conn.go b/pkg/cli/clisqlclient/conn.go index fd96c1e48003..0e34cfb821d7 100644 --- a/pkg/cli/clisqlclient/conn.go +++ b/pkg/cli/clisqlclient/conn.go @@ -23,10 +23,7 @@ import ( "github.com/cockroachdb/cockroach-go/crdb" "github.com/cockroachdb/cockroach/pkg/build" "github.com/cockroachdb/cockroach/pkg/cli/clierror" - "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/security/pprompt" - "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" - "github.com/cockroachdb/cockroach/pkg/util/duration" "github.com/cockroachdb/cockroach/pkg/util/version" "github.com/cockroachdb/errors" "github.com/lib/pq" @@ -180,9 +177,7 @@ func (c *sqlConn) EnsureConn() error { } if c.reconnecting && c.dbName != "" { // Attempt to reset the current database. - if _, err := conn.(DriverConn).Exec( - `SET DATABASE = `+tree.NameStringP(&c.dbName), nil, - ); err != nil { + if _, err := conn.(DriverConn).Exec(`SET DATABASE = $1`, []driver.Value{c.dbName}); err != nil { fmt.Fprintf(c.errw, "warning: unable to restore current database: %v\n", err) } } @@ -209,11 +204,7 @@ func (c *sqlConn) tryEnableServerExecutionTimings() { } } -func (c *sqlConn) GetServerMetadata() ( - nodeID roachpb.NodeID, - version, clusterID string, - err error, -) { +func (c *sqlConn) GetServerMetadata() (nodeID int32, version, clusterID string, err error) { // Retrieve the node ID and server build info. rows, err := c.Query("SELECT * FROM crdb_internal.node_build_info", nil) if errors.Is(err, driver.ErrBadConn) { @@ -246,7 +237,7 @@ func (c *sqlConn) GetServerMetadata() ( if err != nil { return 0, "", "", errors.Newf("incorrect data while retrieving node id: %v", err) } - nodeID = roachpb.NodeID(id) + nodeID = int32(id) // Fields for v1.0 compatibility. case "Distribution": @@ -457,14 +448,14 @@ func (c *sqlConn) getLastQueryStatisticsInternal() ( jobsLatencyRaw = toString(vals[4]) } - parsedExecLatency, e1 := tree.ParseDInterval(duration.IntervalStyle_POSTGRES, execLatencyRaw) - parsedServiceLatency, e2 := tree.ParseDInterval(duration.IntervalStyle_POSTGRES, serviceLatencyRaw) - parsedPlanLatency, e3 := tree.ParseDInterval(duration.IntervalStyle_POSTGRES, planLatencyRaw) - parsedParseLatency, e4 := tree.ParseDInterval(duration.IntervalStyle_POSTGRES, parseLatencyRaw) + parsedExecLatency, e1 := stringToDuration(execLatencyRaw) + parsedServiceLatency, e2 := stringToDuration(serviceLatencyRaw) + parsedPlanLatency, e3 := stringToDuration(planLatencyRaw) + parsedParseLatency, e4 := stringToDuration(parseLatencyRaw) var e5 error - var parsedJobsLatency *tree.DInterval + var parsedJobsLatency time.Duration if containsJobLat { - parsedJobsLatency, e5 = tree.ParseDInterval(duration.IntervalStyle_POSTGRES, jobsLatencyRaw) + parsedJobsLatency, e5 = stringToDuration(jobsLatencyRaw) } if err := errors.CombineErrors(e1, errors.CombineErrors(e2, @@ -474,15 +465,11 @@ func (c *sqlConn) getLastQueryStatisticsInternal() ( errors.Wrap(err, "invalid interval value in SHOW LAST QUERY STATISTICS") } - if containsJobLat { - jobsLat = time.Duration(parsedJobsLatency.Duration.Nanos()) - } - - return time.Duration(parsedParseLatency.Duration.Nanos()), - time.Duration(parsedPlanLatency.Duration.Nanos()), - time.Duration(parsedExecLatency.Duration.Nanos()), - time.Duration(parsedServiceLatency.Duration.Nanos()), - jobsLat, + return parsedParseLatency, + parsedPlanLatency, + parsedExecLatency, + parsedServiceLatency, + parsedJobsLatency, containsJobLat, nil } diff --git a/pkg/cli/clisqlclient/make_query.go b/pkg/cli/clisqlclient/make_query.go index cb7a13cc6f37..9b7435da0fd4 100644 --- a/pkg/cli/clisqlclient/make_query.go +++ b/pkg/cli/clisqlclient/make_query.go @@ -14,7 +14,7 @@ import ( "database/sql/driver" "strings" - "github.com/cockroachdb/cockroach/pkg/sql/parser" + "github.com/cockroachdb/cockroach/pkg/sql/scanner" "github.com/cockroachdb/errors" ) @@ -25,7 +25,7 @@ type QueryFn func(conn Conn) (rows Rows, isMultiStatementQuery bool, err error) // function that can be applied to a connection object. func MakeQuery(query string, parameters ...driver.Value) QueryFn { return func(conn Conn) (Rows, bool, error) { - isMultiStatementQuery := parser.HasMultipleStatements(query) + isMultiStatementQuery, _ := scanner.HasMultipleStatements(query) // driver.Value is an alias for interface{}, but must adhere to a restricted // set of types when being passed to driver.Queryer.Query (see // driver.IsValue). We use driver.DefaultParameterConverter to perform the diff --git a/pkg/cli/clisqlclient/string_to_duration.go b/pkg/cli/clisqlclient/string_to_duration.go new file mode 100644 index 000000000000..685e118cd2cb --- /dev/null +++ b/pkg/cli/clisqlclient/string_to_duration.go @@ -0,0 +1,58 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package clisqlclient + +import ( + "regexp" + "strconv" + "time" + + "github.com/cockroachdb/errors" +) + +// stringToDuration converts a server-side interval value returned by +// SHOW LAST QUERY STATISTICS. We use a custom parser here to avoid +// depending on package `tree` or `duration`, which makes the SQL +// shell executable significantly larger. +// +// Note: this parser only supports the 'postgres' encoding for +// IntervalStyle. This code breaks if the server-side +// IntervalStyle is set to another value e.g. 'iso_8601'. +// See: https://github.com/cockroachdb/cockroach/issues/67618 +func stringToDuration(s string) (time.Duration, error) { + m := intervalRe.FindStringSubmatch(s) + if m == nil { + return 0, errors.Newf("invalid format: %q", s) + } + th, e1 := strconv.Atoi(m[1]) + tm, e2 := strconv.Atoi(m[2]) + ts, e3 := strconv.Atoi(m[3]) + us := m[4] + "000000"[:6-len(m[4])] + tus, e4 := strconv.Atoi(us) + return (time.Duration(th)*time.Hour + + time.Duration(tm)*time.Minute + + time.Duration(ts)*time.Second + + time.Duration(tus)*time.Microsecond), + errors.CombineErrors(e1, + errors.CombineErrors(e2, + errors.CombineErrors(e3, e4))) +} + +// intervalRe indicates how to parse the interval value. +// The format is HHHH:MM:SS[.ffffff] +// +// Note: we do not need to support a day prefix, because SHOW LAST +// QUERY STATISTICS always reports intervals computed from a number +// of seconds, and these never contain a "days" components. +// +// For example, a query that ran for 3 days will have its interval +// displayed as 72:00:00, not "3 days 00:00:00". +var intervalRe = regexp.MustCompile(`^(\d{2,}):(\d{2}):(\d{2})(?:\.(\d{1,6}))?$`) diff --git a/pkg/cli/clisqlclient/string_to_duration_test.go b/pkg/cli/clisqlclient/string_to_duration_test.go new file mode 100644 index 000000000000..75595ae268a7 --- /dev/null +++ b/pkg/cli/clisqlclient/string_to_duration_test.go @@ -0,0 +1,55 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package clisqlclient + +import ( + "testing" + "time" + + "github.com/cockroachdb/cockroach/pkg/testutils" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" +) + +func TestStringToDuration(t *testing.T) { + defer leaktest.AfterTest(t)() + + testCases := []struct { + input string + output time.Duration + expectedErr string + }{ + {"00:00:00", 0, ""}, + {"01:02:03", time.Hour + 2*time.Minute + 3*time.Second, ""}, + {"11:22:33", 11*time.Hour + 22*time.Minute + 33*time.Second, ""}, + {"1234:22:33", 1234*time.Hour + 22*time.Minute + 33*time.Second, ""}, + {"01:02:03.4", time.Hour + 2*time.Minute + 3*time.Second + 400*time.Millisecond, ""}, + {"01:02:03.004", time.Hour + 2*time.Minute + 3*time.Second + 4*time.Millisecond, ""}, + {"01:02:03.123456", time.Hour + 2*time.Minute + 3*time.Second + 123456*time.Microsecond, ""}, + {"1001:02:03.123456", 1001*time.Hour + 2*time.Minute + 3*time.Second + 123456*time.Microsecond, ""}, + {"00:00", 0, "invalid format"}, + {"00.00.00", 0, "invalid format"}, + {"00:00:00:000000000", 0, "invalid format"}, + {"00:00:00.000000000", 0, "invalid format"}, + {"123 00:00:00.000000000", 0, "invalid format"}, + } + + for _, tc := range testCases { + v, err := stringToDuration(tc.input) + if !testutils.IsError(err, tc.expectedErr) { + t.Errorf("%s: expected error %q, got: %v", tc.input, tc.expectedErr, err) + } + if err == nil { + if v != tc.output { + t.Errorf("%s: expected %v, got %v", tc.input, tc.output, v) + } + } + } +} diff --git a/pkg/cli/nodelocal.go b/pkg/cli/nodelocal.go index 547d1f282bea..d2707cce5770 100644 --- a/pkg/cli/nodelocal.go +++ b/pkg/cli/nodelocal.go @@ -20,6 +20,7 @@ import ( "path/filepath" "github.com/cockroachdb/cockroach/pkg/cli/clisqlclient" + "github.com/cockroachdb/cockroach/pkg/roachpb" "github.com/cockroachdb/cockroach/pkg/sql" "github.com/cockroachdb/errors" "github.com/spf13/cobra" @@ -132,7 +133,7 @@ func uploadFile( if err != nil { return errors.Wrap(err, "unable to get node id") } - fmt.Printf("successfully uploaded to nodelocal://%s\n", filepath.Join(nodeID.String(), destination)) + fmt.Printf("successfully uploaded to nodelocal://%s\n", filepath.Join(roachpb.NodeID(nodeID).String(), destination)) return nil } diff --git a/pkg/sql/parser/parse.go b/pkg/sql/parser/parse.go index 75eeb3992505..e07e656cfe98 100644 --- a/pkg/sql/parser/parse.go +++ b/pkg/sql/parser/parse.go @@ -264,26 +264,6 @@ func ParseOneWithInt(sql string, nakedIntType *types.T) (Statement, error) { return p.parseOneWithInt(sql, nakedIntType) } -// HasMultipleStatements returns true if the sql string contains more than one -// statements. -func HasMultipleStatements(sql string) bool { - var p Parser - p.scanner.Init(sql) - defer p.scanner.Cleanup() - count := 0 - for { - _, _, done := p.scanOneStmt() - if done { - break - } - count++ - if count > 1 { - return true - } - } - return false -} - // ParseQualifiedTableName parses a possibly qualified table name. The // table name must contain one or more name parts, using the full // input SQL syntax: each name part containing special characters, or diff --git a/pkg/sql/scanner/BUILD.bazel b/pkg/sql/scanner/BUILD.bazel index 85c6fbc81a9c..e93034fab94d 100644 --- a/pkg/sql/scanner/BUILD.bazel +++ b/pkg/sql/scanner/BUILD.bazel @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "scanner", @@ -7,3 +7,9 @@ go_library( visibility = ["//visibility:public"], deps = ["//pkg/sql/lexbase"], ) + +go_test( + name = "scanner_test", + srcs = ["scan_test.go"], + embed = [":scanner"], +) diff --git a/pkg/sql/scanner/scan.go b/pkg/sql/scanner/scan.go index c9505339d5ae..f2418c027723 100644 --- a/pkg/sql/scanner/scan.go +++ b/pkg/sql/scanner/scan.go @@ -998,3 +998,69 @@ outer: lval.SetStr(s.finishString(buf)) return true } + +// HasMultipleStatements returns true if the sql string contains more than one +// statements. An error is returned if an invalid token was encountered. +func HasMultipleStatements(sql string) (multipleStmt bool, err error) { + var s Scanner + var lval fakeSym + s.Init(sql) + count := 0 + for { + done, hasToks, err := s.scanOne(&lval) + if err != nil { + return false, err + } + if hasToks { + count++ + } + if done || count > 1 { + break + } + } + return count > 1, nil +} + +// scanOne is a simplified version of (*Parser).scanOneStmt() for use +// by HasMultipleStatements(). +func (s *Scanner) scanOne(lval *fakeSym) (done, hasToks bool, err error) { + // Scan the first token. + for { + s.Scan(lval) + if lval.id == 0 { + return true, false, nil + } + if lval.id != ';' { + break + } + } + + for { + if lval.id == lexbase.ERROR { + return true, true, fmt.Errorf("scan error: %s", lval.s) + } + s.Scan(lval) + if lval.id == 0 || lval.id == ';' { + return (lval.id == 0), true, nil + } + } +} + +// fakeSym is a simplified symbol type for use by +// HasMultipleStatements. +type fakeSym struct { + id int32 + pos int32 + s string +} + +var _ ScanSymType = (*fakeSym)(nil) + +func (s fakeSym) ID() int32 { return s.id } +func (s *fakeSym) SetID(id int32) { s.id = id } +func (s fakeSym) Pos() int32 { return s.pos } +func (s *fakeSym) SetPos(p int32) { s.pos = p } +func (s fakeSym) Str() string { return s.s } +func (s *fakeSym) SetStr(v string) { s.s = v } +func (s fakeSym) UnionVal() interface{} { return nil } +func (s fakeSym) SetUnionVal(v interface{}) {} diff --git a/pkg/sql/scanner/scan_test.go b/pkg/sql/scanner/scan_test.go new file mode 100644 index 000000000000..8612ec49050c --- /dev/null +++ b/pkg/sql/scanner/scan_test.go @@ -0,0 +1,39 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package scanner + +import "testing" + +func TestHasMultipleStatements(t *testing.T) { + testCases := []struct { + in string + expected bool + }{ + {`a b c`, false}, + {`a; b c`, true}, + {`a b; b c`, true}, + {`a b; b c;`, true}, + {`a b;`, false}, + {`SELECT 123; SELECT 123`, true}, + {`SELECT 123; SELECT 123;`, true}, + } + + for _, tc := range testCases { + actual, err := HasMultipleStatements(tc.in) + if err != nil { + t.Error(err) + } + + if actual != tc.expected { + t.Errorf("%q: expected %v, got %v", tc.in, tc.expected, actual) + } + } +}