Skip to content

Commit

Permalink
fix(spanner/spansql): fix parsing of unary minus and plus (#2997)
Browse files Browse the repository at this point in the history
We cannot tokenise numbers with leading signs, because otherwise
`ID+100` is parsed as the identifier `ID` followed by the literal `+100`
instead of an arithmetic operation. This necessitates some rearrangement
of the parsing of int64 literals to accommodate that while also
preserving the ability to parse -INT_MAX.
  • Loading branch information
dsymonds authored Oct 14, 2020
1 parent 7df1879 commit 3fb19a5
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 18 deletions.
65 changes: 56 additions & 9 deletions spanner/spansql/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,13 @@ type token struct {
line, offset int

typ tokenType
int64 int64
float64 float64
string string // unquoted form for stringToken/bytesToken/quotedID

// int64Token is parsed as a number only when it is known to be a literal.
// This permits correct handling of operators preceding such a token,
// which cannot be identified as part of the int64 until later.
int64Base int
}

type tokenType int
Expand Down Expand Up @@ -415,7 +419,9 @@ digitLoop:
p.cur.float64, err = strconv.ParseFloat(sign+p.cur.value[d0:], 64)
} else {
p.cur.typ = int64Token
p.cur.int64, err = strconv.ParseInt(sign+p.cur.value[d0:], base, 64)
p.cur.value = sign + p.cur.value[d0:]
p.cur.int64Base = base
// This is parsed on demand.
}
if err != nil {
p.errorf("bad numeric literal %q: %v", p.cur.value, err)
Expand Down Expand Up @@ -776,7 +782,7 @@ func (p *parser) advance() {
p.cur.typ = unknownToken
// TODO: array, struct, date, timestamp literals
switch p.s[0] {
case ',', ';', '(', ')', '{', '}', '*':
case ',', ';', '(', ')', '{', '}', '*', '+', '-':
// Single character symbol.
p.cur.value, p.s = p.s[:1], p.s[1:]
p.offset++
Expand Down Expand Up @@ -825,8 +831,8 @@ func (p *parser) advance() {
p.offset += i
return
}
if len(p.s) >= 2 && (p.s[0] == '+' || p.s[0] == '-' || p.s[0] == '.') && ('0' <= p.s[1] && p.s[1] <= '9') {
// [-+.] followed by a digit.
if len(p.s) >= 2 && p.s[0] == '.' && ('0' <= p.s[1] && p.s[1] <= '9') {
// dot followed by a digit.
p.consumeNumber()
return
}
Expand Down Expand Up @@ -1588,7 +1594,11 @@ func (p *parser) parseType() (Type, *parseError) {
if tok.value == "MAX" {
t.Len = MaxLen
} else if tok.typ == int64Token {
t.Len = tok.int64
n, err := strconv.ParseInt(tok.value, tok.int64Base, 64)
if err != nil {
return Type{}, p.errorf("%v", err)
}
t.Len = n
} else {
return Type{}, p.errorf("got %q, want MAX or int64", tok.value)
}
Expand Down Expand Up @@ -2025,7 +2035,11 @@ func (p *parser) parseLiteralOrParam() (LiteralOrParam, *parseError) {
return nil, tok.err
}
if tok.typ == int64Token {
return IntegerLiteral(tok.int64), nil
n, err := strconv.ParseInt(tok.value, tok.int64Base, 64)
if err != nil {
return nil, p.errorf("%v", err)
}
return IntegerLiteral(n), nil
}
// TODO: check character sets.
if strings.HasPrefix(tok.value, "@") {
Expand Down Expand Up @@ -2343,14 +2357,43 @@ func (p *parser) parseArithOp() (Expr, *parseError) {
var unaryArithOperators = map[string]ArithOperator{
"-": Neg,
"~": BitNot,
"+": Plus,
}

func (p *parser) parseUnaryArithOp() (Expr, *parseError) {
tok := p.next()
if tok.err != nil {
return nil, tok.err
}
if op, ok := unaryArithOperators[tok.value]; ok {

op := tok.value

if op == "-" || op == "+" {
// If the next token is a numeric token, combine and parse as a literal.
ntok := p.next()
if ntok.err == nil {
switch ntok.typ {
case int64Token:
comb := op + ntok.value
n, err := strconv.ParseInt(comb, ntok.int64Base, 64)
if err != nil {
return nil, p.errorf("%v", err)
}
return IntegerLiteral(n), nil
case float64Token:
f := ntok.float64
if op == "-" {
f = -f
}
return FloatLiteral(f), nil
}
}
// It is not possible for the p.back() lower down to fire
// because - and + are in unaryArithOperators.
p.back()
}

if op, ok := unaryArithOperators[op]; ok {
e, err := p.parseLit()
if err != nil {
return nil, err
Expand All @@ -2370,7 +2413,11 @@ func (p *parser) parseLit() (Expr, *parseError) {

switch tok.typ {
case int64Token:
return IntegerLiteral(tok.int64), nil
n, err := strconv.ParseInt(tok.value, tok.int64Base, 64)
if err != nil {
return nil, p.errorf("%v", err)
}
return IntegerLiteral(n), nil
case float64Token:
return FloatLiteral(tok.float64), nil
case stringToken:
Expand Down
21 changes: 14 additions & 7 deletions spanner/spansql/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,10 @@ func TestParseExpr(t *testing.T) {
{`4e2`, FloatLiteral(4e2)},
{`X + Y * Z`, ArithOp{LHS: ID("X"), Op: Add, RHS: ArithOp{LHS: ID("Y"), Op: Mul, RHS: ID("Z")}}},
{`X + Y + Z`, ArithOp{LHS: ArithOp{LHS: ID("X"), Op: Add, RHS: ID("Y")}, Op: Add, RHS: ID("Z")}},
{`X * -Y`, ArithOp{LHS: ID("X"), Op: Mul, RHS: ArithOp{Op: Neg, RHS: ID("Y")}}},
{`+X * -Y`, ArithOp{LHS: ArithOp{Op: Plus, RHS: ID("X")}, Op: Mul, RHS: ArithOp{Op: Neg, RHS: ID("Y")}}},
// Don't require space around +/- operators.
{`ID+100`, ArithOp{LHS: ID("ID"), Op: Add, RHS: IntegerLiteral(100)}},
{`ID-100`, ArithOp{LHS: ID("ID"), Op: Sub, RHS: IntegerLiteral(100)}},
{`ID&0x3fff`, ArithOp{LHS: ID("ID"), Op: BitAnd, RHS: IntegerLiteral(0x3fff)}},
{`SHA1("Hello" || " " || "World")`, Func{Name: "SHA1", Args: []Expr{ArithOp{LHS: ArithOp{LHS: StringLiteral("Hello"), Op: Concat, RHS: StringLiteral(" ")}, Op: Concat, RHS: StringLiteral("World")}}}},
{`Count > 0`, ComparisonOp{LHS: ID("Count"), Op: Gt, RHS: IntegerLiteral(0)}},
Expand Down Expand Up @@ -322,8 +325,8 @@ func TestParseExpr(t *testing.T) {
if !reflect.DeepEqual(got, test.want) {
t.Errorf("[%s]: incorrect parse\n got <%T> %#v\nwant <%T> %#v", test.in, got, got, test.want, test.want)
}
if p.s != "" {
t.Errorf("[%s]: Unparsed [%s]", test.in, p.s)
if rem := p.Rem(); rem != "" {
t.Errorf("[%s]: Unparsed [%s]", test.in, rem)
}
}
}
Expand Down Expand Up @@ -613,12 +616,16 @@ func tableByName(t *testing.T, ddl *DDL, name ID) *CreateTable {

func TestParseFailures(t *testing.T) {
expr := func(p *parser) error {
_, err := p.parseExpr()
return err
if _, pe := p.parseExpr(); pe != nil {
return pe
}
return nil
}
query := func(p *parser) error {
_, err := p.parseQuery()
return err
if _, pe := p.parseQuery(); pe != nil {
return pe
}
return nil
}

tests := []struct {
Expand Down
5 changes: 5 additions & 0 deletions spanner/spansql/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,11 @@ func (ao ArithOp) addSQL(sb *strings.Builder) {
ao.RHS.addSQL(sb)
sb.WriteString(")")
return
case Plus:
sb.WriteString("+(")
ao.RHS.addSQL(sb)
sb.WriteString(")")
return
case BitNot:
sb.WriteString("~(")
ao.RHS.addSQL(sb)
Expand Down
5 changes: 3 additions & 2 deletions spanner/spansql/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ type LiteralOrParam interface {

type ArithOp struct {
Op ArithOperator
LHS, RHS Expr // only RHS is set for Neg, BitNot
LHS, RHS Expr // only RHS is set for Neg, Plus, BitNot
}

func (ArithOp) isExpr() {}
Expand All @@ -400,6 +400,7 @@ type ArithOperator int

const (
Neg ArithOperator = iota // unary -
Plus // unary +
BitNot // unary ~
Mul // *
Div // /
Expand All @@ -415,7 +416,7 @@ const (

type LogicalOp struct {
Op LogicalOperator
LHS, RHS BoolExpr // only RHS is set for Neg, BitNot
LHS, RHS BoolExpr // only RHS is set for Not
}

func (LogicalOp) isBoolExpr() {}
Expand Down

0 comments on commit 3fb19a5

Please sign in to comment.