Skip to content

Commit

Permalink
Merge pull request #112508 from cockroachdb/blathers/backport-release…
Browse files Browse the repository at this point in the history
…-23.2-112200

release-23.2: plpgsql: correctly handle parsing errors
  • Loading branch information
DrewKimball authored Oct 17, 2023
2 parents f151df1 + 16e9564 commit 4124413
Show file tree
Hide file tree
Showing 9 changed files with 323 additions and 81 deletions.
4 changes: 2 additions & 2 deletions pkg/sql/parser/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,8 @@ func GetTypeReferenceFromName(typeName tree.Name) (tree.ResolvableTypeReference,

// GetTypeFromValidSQLSyntax retrieves a type from its SQL syntax. The caller is
// responsible for guaranteeing that the type expression is valid
// SQL (or handling the resulting error). This includes verifying that complex
// identifiers are enclosed in double quotes, etc.
func GetTypeFromValidSQLSyntax(sql string) (tree.ResolvableTypeReference, error) {
expr, err := ParseExpr(fmt.Sprintf("1::%s", sql))
if err != nil {
Expand Down
110 changes: 66 additions & 44 deletions pkg/sql/plpgsql/parser/lexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,9 @@ func (l *lexer) MakeExecSqlStmt() (*plpgsqltree.Execute, error) {
}
// Push back the first token so that it's included in the SQL string.
l.PushBack(1)
startPos, endPos, _ := l.readSQLConstruct(';')
if endPos <= startPos || startPos <= 0 {
return nil, errors.New("expected SQL statement")
startPos, endPos, _, err := l.readSQLConstruct(false /* isExpr */, ';')
if err != nil {
return nil, err
}
// Move past the semicolon.
l.lastPos++
Expand Down Expand Up @@ -200,8 +200,11 @@ func (l *lexer) MakeExecSqlStmt() (*plpgsqltree.Execute, error) {
}, nil
}

func (l *lexer) MakeDynamicExecuteStmt() *plpgsqltree.DynamicExecute {
cmdStr, _ := l.ReadSqlConstruct(INTO, USING, ';')
func (l *lexer) MakeDynamicExecuteStmt() (*plpgsqltree.DynamicExecute, error) {
cmdStr, _, err := l.ReadSqlStatement(INTO, USING, ';')
if err != nil {
return nil, err
}
ret := &plpgsqltree.DynamicExecute{
Query: cmdStr,
}
Expand All @@ -211,7 +214,7 @@ func (l *lexer) MakeDynamicExecuteStmt() *plpgsqltree.DynamicExecute {
for {
if lval.id == INTO {
if ret.Into {
l.setErr(errors.AssertionFailedf("seen multiple INTO"))
return nil, errors.New("multiple INTO keywords")
}
ret.Into = true
nextTok := l.Peek()
Expand All @@ -221,15 +224,21 @@ func (l *lexer) MakeDynamicExecuteStmt() *plpgsqltree.DynamicExecute {
}
// TODO we need to read each "INTO" variable name instead of just a
// string.
l.ReadSqlExpressionStr2(USING, ';')
_, _, err = l.ReadSqlExpr(USING, ';')
if err != nil {
return nil, err
}
l.Lex(&lval)
} else if lval.id == USING {
if ret.Params != nil {
l.setErr(errors.AssertionFailedf("seen multiple USINGs"))
return nil, errors.New("multiple USING keywords")
}
ret.Params = make([]plpgsqltree.Expr, 0)
for {
l.ReadSqlConstruct(',', ';', INTO)
_, _, err = l.ReadSqlExpr(',', ';', INTO)
if err != nil {
return nil, err
}
ret.Params = append(ret.Params, nil)
l.Lex(&lval)
if lval.id == ';' {
Expand All @@ -239,32 +248,16 @@ func (l *lexer) MakeDynamicExecuteStmt() *plpgsqltree.DynamicExecute {
} else if lval.id == ';' {
break
} else {
l.setErr(errors.AssertionFailedf("syntax error"))
return nil, errors.Newf("unexpected token: %s", lval.id)
}
}

return ret
}

// ReadSqlExpressionStr returns the string from the l.lastPos till it sees
// the terminator for the first time. The returned string is made by tokens
// between the starting index (included) to the terminator (not included).
// TODO(plpgsql-team): pass the output to the sql parser
// (i.e. sqlParserImpl.Parse()).
func (l *lexer) ReadSqlExpressionStr(terminator int) (sqlStr string) {
sqlStr, _ = l.ReadSqlConstruct(terminator, 0, 0)
return sqlStr
}

func (l *lexer) ReadSqlExpressionStr2(
terminator1 int, terminator2 int,
) (sqlStr string, terminatorMet int) {
return l.ReadSqlConstruct(terminator1, terminator2, 0)
return ret, nil
}

func (l *lexer) readSQLConstruct(
terminator1 int, terminators ...int,
) (startPos, endPos, terminatorMet int) {
isExpr bool, terminator1 int, terminators ...int,
) (startPos, endPos, terminatorMet int, err error) {
if l.parser.Lookahead() != -1 {
// Push back the lookahead token so that it can be included.
l.PushBack(1)
Expand All @@ -290,24 +283,26 @@ func (l *lexer) readSQLConstruct(
} else if tok.id == ')' || tok.id == ']' {
parenLevel--
if parenLevel < 0 {
panic(errors.AssertionFailedf("wrongly nested parentheses"))
return 0, 0, 0, errors.New("mismatched parentheses")
}
}
l.lastPos++
}
if parenLevel != 0 {
panic(errors.AssertionFailedf("parentheses is badly nested"))
}
if startPos > l.lastPos {
//TODO(jane): show the terminator in the panic message.
l.setErr(errors.New("missing SQL expression"))
return 0, 0, 0
return 0, 0, 0, errors.New("mismatched parentheses")
}
endPos = l.lastPos + 1
if endPos > len(l.tokens) {
endPos = len(l.tokens)
}
return startPos, endPos, terminatorMet
if endPos <= startPos {
if isExpr {
return 0, 0, 0, errors.New("missing expression")
} else {
return 0, 0, 0, errors.New("missing SQL statement")
}
}
return startPos, endPos, terminatorMet, nil
}

func (l *lexer) MakeFetchOrMoveStmt(isMove bool) (plpgsqltree.Statement, error) {
Expand All @@ -319,7 +314,10 @@ func (l *lexer) MakeFetchOrMoveStmt(isMove bool) (plpgsqltree.Statement, error)
if isMove {
prefix = "MOVE "
}
sqlStr, terminator := l.ReadSqlConstruct(INTO, ';')
sqlStr, terminator, err := l.ReadSqlStatement(INTO, ';')
if err != nil {
return nil, err
}
sqlStr = prefix + sqlStr
sqlStmt, err := parser.ParseOne(sqlStr)
if err != nil {
Expand All @@ -341,7 +339,10 @@ func (l *lexer) MakeFetchOrMoveStmt(isMove bool) (plpgsqltree.Statement, error)
}
// Read past the INTO.
l.lastPos++
startPos, endPos, _ := l.readSQLConstruct(';')
startPos, endPos, _, err := l.readSQLConstruct(true /* isExpr */, ';')
if err != nil {
return nil, err
}
for pos := startPos; pos < endPos; pos += 2 {
tok := l.tokens[pos]
if tok.id != IDENT {
Expand All @@ -366,12 +367,24 @@ func (l *lexer) MakeFetchOrMoveStmt(isMove bool) (plpgsqltree.Statement, error)
}, nil
}

func (l *lexer) ReadSqlConstruct(
// ReadSqlExpr reads a SQL expression from the token stream, stopping at the
// first of the supplied terminator tokens. It delegates to readSQLConstruct
// with isExpr set to true, so an empty token range is reported with the
// expression-specific error message.
//
// It returns the string produced by getStr for the token range, the terminator
// reported by readSQLConstruct, and any error. When err is non-nil the
// positions handed to getStr are both zero, so sqlStr should not be relied on.
func (l *lexer) ReadSqlExpr(
	terminator1 int, terminators ...int,
) (sqlStr string, terminatorMet int, err error) {
	begin, end, term, readErr := l.readSQLConstruct(true /* isExpr */, terminator1, terminators...)
	// getStr is invoked unconditionally, matching the behavior callers already
	// see on the error path.
	return l.getStr(begin, end), term, readErr
}

func (l *lexer) ReadSqlStatement(
terminator1 int, terminators ...int,
) (sqlStr string, terminatorMet int) {
) (sqlStr string, terminatorMet int, err error) {
var startPos, endPos int
startPos, endPos, terminatorMet = l.readSQLConstruct(terminator1, terminators...)
return l.getStr(startPos, endPos), terminatorMet
startPos, endPos, terminatorMet, err = l.readSQLConstruct(
false /* isExpr */, terminator1, terminators...,
)
return l.getStr(startPos, endPos), terminatorMet, err
}

func (l *lexer) getStr(startPos, endPos int) string {
Expand Down Expand Up @@ -463,5 +476,14 @@ func (l *lexer) GetTypeFromValidSQLSyntax(sqlStr string) (tree.ResolvableTypeRef
}

// ParseExpr parses sqlStr as exactly one SQL expression and returns it.
//
// An error from the underlying parser is returned unchanged. If parsing
// succeeds but does not yield exactly one expression, a pgcode.Syntax error is
// returned instead.
func (l *lexer) ParseExpr(sqlStr string) (plpgsqltree.Expr, error) {
	// Use ParseExprs instead of ParseExpr in order to correctly handle the case
	// when multiple expressions are incorrectly passed.
	exprs, err := parser.ParseExprs([]string{sqlStr})
	if err != nil {
		return nil, err
	}
	if len(exprs) != 1 {
		return nil, pgerror.Newf(pgcode.Syntax, "query returned %d columns", len(exprs))
	}
	return exprs[0], nil
}
Loading

0 comments on commit 4124413

Please sign in to comment.