diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel index 619b6d40fd48..64e01f3a9015 100644 --- a/pkg/BUILD.bazel +++ b/pkg/BUILD.bazel @@ -265,6 +265,7 @@ ALL_TESTS = [ "//pkg/sql/rowenc:rowenc_test", "//pkg/sql/rowexec:rowexec_test", "//pkg/sql/rowflow:rowflow_test", + "//pkg/sql/scanner:scanner_test", "//pkg/sql/schemachange:schemachange_test", "//pkg/sql/schemachanger/scbuild:scbuild_test", "//pkg/sql/schemachanger/scexec:scexec_test", diff --git a/pkg/cli/clisqlclient/BUILD.bazel b/pkg/cli/clisqlclient/BUILD.bazel index a058b5bd2ac6..b28dbda7b142 100644 --- a/pkg/cli/clisqlclient/BUILD.bazel +++ b/pkg/cli/clisqlclient/BUILD.bazel @@ -21,7 +21,7 @@ go_library( "//pkg/cli/clicfg", "//pkg/cli/clierror", "//pkg/security/pprompt", - "//pkg/sql/parser", + "//pkg/sql/scanner", "//pkg/util/version", "@com_github_cockroachdb_cockroach_go//crdb", "@com_github_cockroachdb_errors//:errors", diff --git a/pkg/cli/clisqlclient/make_query.go b/pkg/cli/clisqlclient/make_query.go index cb7a13cc6f37..9b7435da0fd4 100644 --- a/pkg/cli/clisqlclient/make_query.go +++ b/pkg/cli/clisqlclient/make_query.go @@ -14,7 +14,7 @@ import ( "database/sql/driver" "strings" - "github.com/cockroachdb/cockroach/pkg/sql/parser" + "github.com/cockroachdb/cockroach/pkg/sql/scanner" "github.com/cockroachdb/errors" ) @@ -25,7 +25,7 @@ type QueryFn func(conn Conn) (rows Rows, isMultiStatementQuery bool, err error) // function that can be applied to a connection object. func MakeQuery(query string, parameters ...driver.Value) QueryFn { return func(conn Conn) (Rows, bool, error) { - isMultiStatementQuery := parser.HasMultipleStatements(query) + isMultiStatementQuery, _ := scanner.HasMultipleStatements(query) // driver.Value is an alias for interface{}, but must adhere to a restricted // set of types when being passed to driver.Queryer.Query (see // driver.IsValue). We use driver.DefaultParameterConverter to perform the diff --git a/pkg/sql/parser/parse.go b/pkg/sql/parser/parse.go index 75eeb3992505..e07e656cfe98 100644 --- a/pkg/sql/parser/parse.go +++ b/pkg/sql/parser/parse.go @@ -264,26 +264,6 @@ func ParseOneWithInt(sql string, nakedIntType *types.T) (Statement, error) { return p.parseOneWithInt(sql, nakedIntType) } -// HasMultipleStatements returns true if the sql string contains more than one -// statements. -func HasMultipleStatements(sql string) bool { - var p Parser - p.scanner.Init(sql) - defer p.scanner.Cleanup() - count := 0 - for { - _, _, done := p.scanOneStmt() - if done { - break - } - count++ - if count > 1 { - return true - } - } - return false -} - // ParseQualifiedTableName parses a possibly qualified table name. The // table name must contain one or more name parts, using the full // input SQL syntax: each name part containing special characters, or diff --git a/pkg/sql/scanner/BUILD.bazel b/pkg/sql/scanner/BUILD.bazel index 85c6fbc81a9c..e93034fab94d 100644 --- a/pkg/sql/scanner/BUILD.bazel +++ b/pkg/sql/scanner/BUILD.bazel @@ -1,4 +1,4 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "scanner", @@ -7,3 +7,9 @@ go_library( visibility = ["//visibility:public"], deps = ["//pkg/sql/lexbase"], ) + +go_test( + name = "scanner_test", + srcs = ["scan_test.go"], + embed = [":scanner"], +) diff --git a/pkg/sql/scanner/scan.go b/pkg/sql/scanner/scan.go index c9505339d5ae..f2418c027723 100644 --- a/pkg/sql/scanner/scan.go +++ b/pkg/sql/scanner/scan.go @@ -998,3 +998,69 @@ outer: lval.SetStr(s.finishString(buf)) return true } + +// HasMultipleStatements returns true if the sql string contains more than one +// statements. An error is returned if an invalid token was encountered. +func HasMultipleStatements(sql string) (multipleStmt bool, err error) { + var s Scanner + var lval fakeSym + s.Init(sql) + count := 0 + for { + done, hasToks, err := s.scanOne(&lval) + if err != nil { + return false, err + } + if hasToks { + count++ + } + if done || count > 1 { + break + } + } + return count > 1, nil +} + +// scanOne is a simplified version of (*Parser).scanOneStmt() for use +// by HasMultipleStatements(). +func (s *Scanner) scanOne(lval *fakeSym) (done, hasToks bool, err error) { + // Scan the first token. + for { + s.Scan(lval) + if lval.id == 0 { + return true, false, nil + } + if lval.id != ';' { + break + } + } + + for { + if lval.id == lexbase.ERROR { + return true, true, fmt.Errorf("scan error: %s", lval.s) + } + s.Scan(lval) + if lval.id == 0 || lval.id == ';' { + return (lval.id == 0), true, nil + } + } +} + +// fakeSym is a simplified symbol type for use by +// HasMultipleStatements. +type fakeSym struct { + id int32 + pos int32 + s string +} + +var _ ScanSymType = (*fakeSym)(nil) + +func (s fakeSym) ID() int32 { return s.id } +func (s *fakeSym) SetID(id int32) { s.id = id } +func (s fakeSym) Pos() int32 { return s.pos } +func (s *fakeSym) SetPos(p int32) { s.pos = p } +func (s fakeSym) Str() string { return s.s } +func (s *fakeSym) SetStr(v string) { s.s = v } +func (s fakeSym) UnionVal() interface{} { return nil } +func (s fakeSym) SetUnionVal(v interface{}) {} diff --git a/pkg/sql/scanner/scan_test.go b/pkg/sql/scanner/scan_test.go new file mode 100644 index 000000000000..8612ec49050c --- /dev/null +++ b/pkg/sql/scanner/scan_test.go @@ -0,0 +1,39 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package scanner + +import "testing" + +func TestHasMultipleStatements(t *testing.T) { + testCases := []struct { + in string + expected bool + }{ + {`a b c`, false}, + {`a; b c`, true}, + {`a b; b c`, true}, + {`a b; b c;`, true}, + {`a b;`, false}, + {`SELECT 123; SELECT 123`, true}, + {`SELECT 123; SELECT 123;`, true}, + } + + for _, tc := range testCases { + actual, err := HasMultipleStatements(tc.in) + if err != nil { + t.Error(err) + } + + if actual != tc.expected { + t.Errorf("%q: expected %v, got %v", tc.in, tc.expected, actual) + } + } +}