Skip to content

Commit

Permalink
Merge #112200
Browse files Browse the repository at this point in the history
112200: plpgsql: correctly handle parsing errors r=DrewKimball a=DrewKimball

#### plpgsql: correctly handle parsing errors

This patch ensures that PLpgSQL parsing errors are correctly propagated
in all cases. Previously, there were a few cases (like variable declaration
type parsing) where an error didn't halt parsing. The contract for
`GetTypeFromValidSQLSyntax` is also clarified: it is OK to call it with
an invalid type name as long as the caller properly handles the resulting error.

Informs #105254

Release note: None

#### plpgsql: handle multiple expressions when one expression is expected

Previously, the PLpgSQL parser could panic when the user supplied more
than one expression in a location where only one was expected, for example,
in a return statement. This was because the PLpgSQL parser delegated to
the SQL parser's `ParseExpr` function, which expects exactly one input
expression. This commit returns a syntax error instead of panicking by
switching to `ParseExprs`, which can handle multiple input expressions.

Informs #109342

Release note: None

#### plpgsql: return correct error for invalid parentheses and missing expression

This patch fixes error messages in the PLpgSQL parser for the case when
the parenthesis nesting is invalid, and for the case when no expression
(or statement) is supplied. Previously, invalid parentheses would cause
a panic without an error code, and a missing expression produced an incorrect
error message, since it wasn't checked until the SQL parser attempted to read
an empty string. Now, both cases are checked immediately by the PLpgSQL
parser and the correct error is propagated.

Fixes #109342

Release note: None

Co-authored-by: Drew Kimball <[email protected]>
  • Loading branch information
craig[bot] and DrewKimball committed Oct 17, 2023
2 parents 0cc9c3e + 6e331e6 commit 45f6344
Show file tree
Hide file tree
Showing 9 changed files with 323 additions and 81 deletions.
4 changes: 2 additions & 2 deletions pkg/sql/parser/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,8 @@ func GetTypeReferenceFromName(typeName tree.Name) (tree.ResolvableTypeReference,

// GetTypeFromValidSQLSyntax retrieves a type from its SQL syntax. The caller is
// responsible for guaranteeing that the type expression is valid
// SQL. This includes verifying that complex identifiers are enclosed
// in double quotes, etc.
// SQL (or handling the resulting error). This includes verifying that complex
// identifiers are enclosed in double quotes, etc.
func GetTypeFromValidSQLSyntax(sql string) (tree.ResolvableTypeReference, error) {
expr, err := ParseExpr(fmt.Sprintf("1::%s", sql))
if err != nil {
Expand Down
110 changes: 66 additions & 44 deletions pkg/sql/plpgsql/parser/lexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,9 @@ func (l *lexer) MakeExecSqlStmt() (*plpgsqltree.Execute, error) {
}
// Push back the first token so that it's included in the SQL string.
l.PushBack(1)
startPos, endPos, _ := l.readSQLConstruct(';')
if endPos <= startPos || startPos <= 0 {
return nil, errors.New("expected SQL statement")
startPos, endPos, _, err := l.readSQLConstruct(false /* isExpr */, ';')
if err != nil {
return nil, err
}
// Move past the semicolon.
l.lastPos++
Expand Down Expand Up @@ -200,8 +200,11 @@ func (l *lexer) MakeExecSqlStmt() (*plpgsqltree.Execute, error) {
}, nil
}

func (l *lexer) MakeDynamicExecuteStmt() *plpgsqltree.DynamicExecute {
cmdStr, _ := l.ReadSqlConstruct(INTO, USING, ';')
func (l *lexer) MakeDynamicExecuteStmt() (*plpgsqltree.DynamicExecute, error) {
cmdStr, _, err := l.ReadSqlStatement(INTO, USING, ';')
if err != nil {
return nil, err
}
ret := &plpgsqltree.DynamicExecute{
Query: cmdStr,
}
Expand All @@ -211,7 +214,7 @@ func (l *lexer) MakeDynamicExecuteStmt() *plpgsqltree.DynamicExecute {
for {
if lval.id == INTO {
if ret.Into {
l.setErr(errors.AssertionFailedf("seen multiple INTO"))
return nil, errors.New("multiple INTO keywords")
}
ret.Into = true
nextTok := l.Peek()
Expand All @@ -221,15 +224,21 @@ func (l *lexer) MakeDynamicExecuteStmt() *plpgsqltree.DynamicExecute {
}
// TODO we need to read each "INTO" variable name instead of just a
// string.
l.ReadSqlExpressionStr2(USING, ';')
_, _, err = l.ReadSqlExpr(USING, ';')
if err != nil {
return nil, err
}
l.Lex(&lval)
} else if lval.id == USING {
if ret.Params != nil {
l.setErr(errors.AssertionFailedf("seen multiple USINGs"))
return nil, errors.New("multiple USING keywords")
}
ret.Params = make([]plpgsqltree.Expr, 0)
for {
l.ReadSqlConstruct(',', ';', INTO)
_, _, err = l.ReadSqlExpr(',', ';', INTO)
if err != nil {
return nil, err
}
ret.Params = append(ret.Params, nil)
l.Lex(&lval)
if lval.id == ';' {
Expand All @@ -239,32 +248,16 @@ func (l *lexer) MakeDynamicExecuteStmt() *plpgsqltree.DynamicExecute {
} else if lval.id == ';' {
break
} else {
l.setErr(errors.AssertionFailedf("syntax error"))
return nil, errors.Newf("unexpected token: %s", lval.id)
}
}

return ret
}

// ReadSqlExpressionStr returns the string from the l.lastPos till it sees
// the terminator for the first time. The returned string is made by tokens
// between the starting index (included) to the terminator (not included).
// TODO(plpgsql-team): pass the output to the sql parser
// (i.e. sqlParserImpl.Parse()).
func (l *lexer) ReadSqlExpressionStr(terminator int) (sqlStr string) {
sqlStr, _ = l.ReadSqlConstruct(terminator, 0, 0)
return sqlStr
}

func (l *lexer) ReadSqlExpressionStr2(
terminator1 int, terminator2 int,
) (sqlStr string, terminatorMet int) {
return l.ReadSqlConstruct(terminator1, terminator2, 0)
return ret, nil
}

func (l *lexer) readSQLConstruct(
terminator1 int, terminators ...int,
) (startPos, endPos, terminatorMet int) {
isExpr bool, terminator1 int, terminators ...int,
) (startPos, endPos, terminatorMet int, err error) {
if l.parser.Lookahead() != -1 {
// Push back the lookahead token so that it can be included.
l.PushBack(1)
Expand All @@ -290,24 +283,26 @@ func (l *lexer) readSQLConstruct(
} else if tok.id == ')' || tok.id == ']' {
parenLevel--
if parenLevel < 0 {
panic(errors.AssertionFailedf("wrongly nested parentheses"))
return 0, 0, 0, errors.New("mismatched parentheses")
}
}
l.lastPos++
}
if parenLevel != 0 {
panic(errors.AssertionFailedf("parentheses is badly nested"))
}
if startPos > l.lastPos {
//TODO(jane): show the terminator in the panic message.
l.setErr(errors.New("missing SQL expression"))
return 0, 0, 0
return 0, 0, 0, errors.New("mismatched parentheses")
}
endPos = l.lastPos + 1
if endPos > len(l.tokens) {
endPos = len(l.tokens)
}
return startPos, endPos, terminatorMet
if endPos <= startPos {
if isExpr {
return 0, 0, 0, errors.New("missing expression")
} else {
return 0, 0, 0, errors.New("missing SQL statement")
}
}
return startPos, endPos, terminatorMet, nil
}

func (l *lexer) MakeFetchOrMoveStmt(isMove bool) (plpgsqltree.Statement, error) {
Expand All @@ -319,7 +314,10 @@ func (l *lexer) MakeFetchOrMoveStmt(isMove bool) (plpgsqltree.Statement, error)
if isMove {
prefix = "MOVE "
}
sqlStr, terminator := l.ReadSqlConstruct(INTO, ';')
sqlStr, terminator, err := l.ReadSqlStatement(INTO, ';')
if err != nil {
return nil, err
}
sqlStr = prefix + sqlStr
sqlStmt, err := parser.ParseOne(sqlStr)
if err != nil {
Expand All @@ -341,7 +339,10 @@ func (l *lexer) MakeFetchOrMoveStmt(isMove bool) (plpgsqltree.Statement, error)
}
// Read past the INTO.
l.lastPos++
startPos, endPos, _ := l.readSQLConstruct(';')
startPos, endPos, _, err := l.readSQLConstruct(true /* isExpr */, ';')
if err != nil {
return nil, err
}
for pos := startPos; pos < endPos; pos += 2 {
tok := l.tokens[pos]
if tok.id != IDENT {
Expand All @@ -366,12 +367,24 @@ func (l *lexer) MakeFetchOrMoveStmt(isMove bool) (plpgsqltree.Statement, error)
}, nil
}

func (l *lexer) ReadSqlConstruct(
func (l *lexer) ReadSqlExpr(
terminator1 int, terminators ...int,
) (sqlStr string, terminatorMet int, err error) {
var startPos, endPos int
startPos, endPos, terminatorMet, err = l.readSQLConstruct(
true /* isExpr */, terminator1, terminators...,
)
return l.getStr(startPos, endPos), terminatorMet, err
}

func (l *lexer) ReadSqlStatement(
terminator1 int, terminators ...int,
) (sqlStr string, terminatorMet int) {
) (sqlStr string, terminatorMet int, err error) {
var startPos, endPos int
startPos, endPos, terminatorMet = l.readSQLConstruct(terminator1, terminators...)
return l.getStr(startPos, endPos), terminatorMet
startPos, endPos, terminatorMet, err = l.readSQLConstruct(
false /* isExpr */, terminator1, terminators...,
)
return l.getStr(startPos, endPos), terminatorMet, err
}

func (l *lexer) getStr(startPos, endPos int) string {
Expand Down Expand Up @@ -463,5 +476,14 @@ func (l *lexer) GetTypeFromValidSQLSyntax(sqlStr string) (tree.ResolvableTypeRef
}

func (l *lexer) ParseExpr(sqlStr string) (plpgsqltree.Expr, error) {
return parser.ParseExpr(sqlStr)
// Use ParseExprs instead of ParseExpr in order to correctly handle the case
// when multiple expressions are incorrectly passed.
exprs, err := parser.ParseExprs([]string{sqlStr})
if err != nil {
return nil, err
}
if len(exprs) != 1 {
return nil, pgerror.Newf(pgcode.Syntax, "query returned %d columns", len(exprs))
}
return exprs[0], nil
}
Loading

0 comments on commit 45f6344

Please sign in to comment.