Skip to content

Commit

Permalink
clisqlclient: drop the dependency on parser
Browse files Browse the repository at this point in the history
This moves `HasMultipleStatements` to package `scanner` which is much
smaller, and uses that.

Release note: None
  • Loading branch information
knz committed Jul 15, 2021
1 parent 034c54b commit cb87794
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 24 deletions.
1 change: 1 addition & 0 deletions pkg/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ ALL_TESTS = [
"//pkg/sql/rowenc:rowenc_test",
"//pkg/sql/rowexec:rowexec_test",
"//pkg/sql/rowflow:rowflow_test",
"//pkg/sql/scanner:scanner_test",
"//pkg/sql/schemachange:schemachange_test",
"//pkg/sql/schemachanger/scbuild:scbuild_test",
"//pkg/sql/schemachanger/scexec:scexec_test",
Expand Down
2 changes: 1 addition & 1 deletion pkg/cli/clisqlclient/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ go_library(
"//pkg/cli/clicfg",
"//pkg/cli/clierror",
"//pkg/security/pprompt",
"//pkg/sql/parser",
"//pkg/sql/scanner",
"//pkg/util/version",
"@com_github_cockroachdb_cockroach_go//crdb",
"@com_github_cockroachdb_errors//:errors",
Expand Down
4 changes: 2 additions & 2 deletions pkg/cli/clisqlclient/make_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"database/sql/driver"
"strings"

"github.com/cockroachdb/cockroach/pkg/sql/parser"
"github.com/cockroachdb/cockroach/pkg/sql/scanner"
"github.com/cockroachdb/errors"
)

Expand All @@ -25,7 +25,7 @@ type QueryFn func(conn Conn) (rows Rows, isMultiStatementQuery bool, err error)
// function that can be applied to a connection object.
func MakeQuery(query string, parameters ...driver.Value) QueryFn {
return func(conn Conn) (Rows, bool, error) {
isMultiStatementQuery := parser.HasMultipleStatements(query)
isMultiStatementQuery, _ := scanner.HasMultipleStatements(query)
// driver.Value is an alias for interface{}, but must adhere to a restricted
// set of types when being passed to driver.Queryer.Query (see
// driver.IsValue). We use driver.DefaultParameterConverter to perform the
Expand Down
20 changes: 0 additions & 20 deletions pkg/sql/parser/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,26 +264,6 @@ func ParseOneWithInt(sql string, nakedIntType *types.T) (Statement, error) {
return p.parseOneWithInt(sql, nakedIntType)
}

// HasMultipleStatements returns true if the sql string contains more than one
// statements.
func HasMultipleStatements(sql string) bool {
var p Parser
p.scanner.Init(sql)
defer p.scanner.Cleanup()
count := 0
for {
_, _, done := p.scanOneStmt()
if done {
break
}
count++
if count > 1 {
return true
}
}
return false
}

// ParseQualifiedTableName parses a possibly qualified table name. The
// table name must contain one or more name parts, using the full
// input SQL syntax: each name part containing special characters, or
Expand Down
8 changes: 7 additions & 1 deletion pkg/sql/scanner/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")

go_library(
name = "scanner",
Expand All @@ -7,3 +7,9 @@ go_library(
visibility = ["//visibility:public"],
deps = ["//pkg/sql/lexbase"],
)

go_test(
name = "scanner_test",
srcs = ["scan_test.go"],
embed = [":scanner"],
)
66 changes: 66 additions & 0 deletions pkg/sql/scanner/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -998,3 +998,69 @@ outer:
lval.SetStr(s.finishString(buf))
return true
}

// HasMultipleStatements returns true if the sql string contains more than one
// statements. An error is returned if an invalid token was encountered.
func HasMultipleStatements(sql string) (multipleStmt bool, err error) {
var s Scanner
var lval fakeSym
s.Init(sql)
count := 0
for {
done, hasToks, err := s.scanOne(&lval)
if err != nil {
return false, err
}
if hasToks {
count++
}
if done || count > 1 {
break
}
}
return count > 1, nil
}

// scanOne is a simplified version of (*Parser).scanOneStmt() for use
// by HasMultipleStatements().
func (s *Scanner) scanOne(lval *fakeSym) (done, hasToks bool, err error) {
// Scan the first token.
for {
s.Scan(lval)
if lval.id == 0 {
return true, false, nil
}
if lval.id != ';' {
break
}
}

for {
if lval.id == lexbase.ERROR {
return true, true, fmt.Errorf("scan error: %s", lval.s)
}
s.Scan(lval)
if lval.id == 0 || lval.id == ';' {
return (lval.id == 0), true, nil
}
}
}

// fakeSym is a simplified symbol type for use by
// HasMultipleStatements.
type fakeSym struct {
id int32
pos int32
s string
}

var _ ScanSymType = (*fakeSym)(nil)

func (s fakeSym) ID() int32 { return s.id }
func (s *fakeSym) SetID(id int32) { s.id = id }
func (s fakeSym) Pos() int32 { return s.pos }
func (s *fakeSym) SetPos(p int32) { s.pos = p }
func (s fakeSym) Str() string { return s.s }
func (s *fakeSym) SetStr(v string) { s.s = v }
func (s fakeSym) UnionVal() interface{} { return nil }
func (s fakeSym) SetUnionVal(v interface{}) {}
39 changes: 39 additions & 0 deletions pkg/sql/scanner/scan_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright 2021 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package scanner

import "testing"

func TestHasMultipleStatements(t *testing.T) {
testCases := []struct {
in string
expected bool
}{
{`a b c`, false},
{`a; b c`, true},
{`a b; b c`, true},
{`a b; b c;`, true},
{`a b;`, false},
{`SELECT 123; SELECT 123`, true},
{`SELECT 123; SELECT 123;`, true},
}

for _, tc := range testCases {
actual, err := HasMultipleStatements(tc.in)
if err != nil {
t.Error(err)
}

if actual != tc.expected {
t.Errorf("%q: expected %v, got %v", tc.in, tc.expected, actual)
}
}
}

0 comments on commit cb87794

Please sign in to comment.