Skip to content

Commit

Permalink
sql: Adds support for stmt_case
Browse files Browse the repository at this point in the history
This commit adds portions of the logic to parse `stmt_case`.
Does not yet support ELSE, sql statements in THEN body, or END CASE (it
uses an ENDCASE hack).

Release note: None
  • Loading branch information
rharding6373 authored and e-mbrown committed May 10, 2023
1 parent 7b245cb commit 3b22ee1
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 20 deletions.
61 changes: 45 additions & 16 deletions pkg/sql/plpgsql/parser/plpgsql.y
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ func (u *plpgsqlSymUnion) plpgsqlStmtBlock() *plpgsqltree.PLpgSQLStmtBlock {
return u.val.(*plpgsqltree.PLpgSQLStmtBlock)
}


func (u *plpgsqlSymUnion) plpgsqlStmtCaseWhenArm() *plpgsqltree.PLpgSQLStmtCaseWhenArm {
return u.val.(*plpgsqltree.PLpgSQLStmtCaseWhenArm)
}
Expand Down Expand Up @@ -245,6 +244,7 @@ func (u *plpgsqlSymUnion) pLpgSQLStmtOpen() *plpgsqltree.PLpgSQLStmtOpen {
%token <str> WHEN
%token <str> WHILE
%token <str> ENDIF
%token <str> ENDCASE

%union {
id int32
Expand All @@ -266,6 +266,7 @@ func (u *plpgsqlSymUnion) pLpgSQLStmtOpen() *plpgsqltree.PLpgSQLStmtOpen {
//%type <nsitem> decl_aliasitem // TODO what is nsitem? looks like namespace item, not sure if we need it.

%type <str> expr_until_semi
//%type <plpgsqltree.PLpgSQLExpr> expr_until_then expr_until_loop opt_expr_until_when
%type <str> expr_until_then expr_until_loop opt_expr_until_when
%type <plpgsqltree.PLpgSQLExpr> opt_exitcond

Expand Down Expand Up @@ -296,7 +297,8 @@ func (u *plpgsqlSymUnion) pLpgSQLStmtOpen() *plpgsqltree.PLpgSQLStmtOpen {
%type <*plpgsqltree.PLpgSQLCondition> proc_conditions proc_condition

%type <*plpgsqltree.PLpgSQLStmtCaseWhenArm> case_when
//%type <list> case_when_list opt_case_else // TODO is this a list of case when arms?
%type <[]*plpgsqltree.PLpgSQLStmtCaseWhenArm> case_when_list
%type <[]plpgsqltree.PLpgSQLStatement> opt_case_else

%type <bool> getdiag_area_opt
%type <plpgsqltree.PLpgSQLStmtGetDiagItemList> getdiag_list // TODO don't know what this is
Expand Down Expand Up @@ -542,7 +544,9 @@ proc_stmt:
| stmt_if
{ }
| stmt_case
{ }
{
$$.val = $1.plpgsqlStatement()
}
| stmt_loop
{ }
| stmt_while
Expand Down Expand Up @@ -743,26 +747,56 @@ stmt_else :
}
;

stmt_case : CASE opt_expr_until_when case_when_list opt_case_else END CASE ';'
// TODO: ENDCASE should be END CASE
stmt_case : CASE opt_expr_until_when case_when_list opt_case_else ENDCASE ';'
{
expr := &plpgsqltree.PLpgSQLStmtCase {
TestExpr: $2,
CaseWhenList: $3.plpgsqlStmtCaseWhenArms(),
}
// TODO: Add support for ELSE
/*
if $4.val != nil {
expr.HaveElse = true
expr.ElseStmts = $4.plpgsqlStatements()
}
*/
$$.val = expr
}
;

opt_expr_until_when :
{
expr := ""
tok := plpgsqllex.(*lexer).Peek()
if tok.id != WHEN {
expr = plpgsqllex.(*lexer).ReadSqlExpressionStr(WHEN)
}
$$ = expr
}
;

case_when_list : case_when_list case_when
{
stmts := $1.plpgsqlStmtCaseWhenArms()
stmts = append(stmts, $2.plpgsqlStmtCaseWhenArm())
$$.val = stmts
}
| case_when
{
stmts := []*plpgsqltree.PLpgSQLStmtCaseWhenArm{}
stmts = append(stmts, $1.plpgsqlStmtCaseWhenArm())
$$.val = stmts
}
;

case_when : WHEN expr_until_then proc_sect
case_when : WHEN expr_until_then THEN proc_sect
{
expr := &plpgsqltree.PLpgSQLStmtCaseWhenArm{
Expr: $2,
Stmts: $4.plpgsqlStatements(),
}
$$.val = expr
}
;

Expand All @@ -771,6 +805,7 @@ opt_case_else :
}
| ELSE proc_sect
{
$$.val = $2.plpgsqlStatements()
}
;

Expand Down Expand Up @@ -1019,10 +1054,10 @@ expr_until_semi :
;

expr_until_then :
{
$$ = plpgsqllex.(*lexer).ReadSqlExpressionStr(THEN)
}
;
{
$$ = plpgsqllex.(*lexer).ReadSqlExpressionStr(THEN)
}
;

expr_until_loop :
{ }
Expand Down Expand Up @@ -1072,13 +1107,6 @@ any_identifier:
}
;

d_expr: ICONST {}
| FCONST {}
| SCONST {}
| USCONST {}
| BCONST {}
;

unreserved_keyword:
ABSOLUTE
| ALIAS
Expand Down Expand Up @@ -1172,6 +1200,7 @@ reserved_keyword:
| DECLARE
| ELSE
| END
| ENDCASE
| ENDIF
| EXECUTE
| FOR
Expand Down
103 changes: 103 additions & 0 deletions pkg/sql/plpgsql/parser/testdata/stmt_case
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
parse
DECLARE
BEGIN
CASE
hello
WHEN
world
THEN
ENDCASE;
END
----
DECLARE
BEGIN
CASE hello
WHEN world
THEN
ENDCASE
END

parse
DECLARE
BEGIN
CASE
order_cnt
WHEN
1, 2, 3
THEN
ENDCASE;
END
----
DECLARE
BEGIN
CASE order_cnt
WHEN 1 , 2 , 3
THEN
ENDCASE
END

parse
DECLARE
BEGIN
CASE
order_cnt
WHEN
1, 2, 3
THEN
WHEN
5
THEN
ENDCASE;
END
----
DECLARE
BEGIN
CASE order_cnt
WHEN 1 , 2 , 3
THEN
WHEN 5
THEN
ENDCASE
END

parse
DECLARE
BEGIN
CASE
WHEN
true
THEN
ENDCASE;
END
----
DECLARE
BEGIN
CASE
WHEN true
THEN
ENDCASE
END

parse
DECLARE
order_cnt integer := 10;
BEGIN
CASE
WHEN
order_cnt BETWEEN 0 AND 100
THEN
WHEN
order_cnt > 100
THEN
ENDCASE;
END
----
DECLARE
BEGIN
CASE
WHEN order_cnt between 0 and 100
THEN
WHEN order_cnt > 100
THEN
ENDCASE
END
31 changes: 27 additions & 4 deletions pkg/sql/sem/plpgsqltree/statements.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,23 +108,46 @@ func (s *PLpgSQLStmtIfElseIfArm) Format(ctx *tree.FmtCtx) {
// stmt_case
type PLpgSQLStmtCase struct {
PLpgSQLStatementImpl
TestExpr PLpgSQLExpr
// TODO: Change to PLpgSQLExpr
TestExpr string
Var PLpgSQLVariable
CaseWhenList []PLpgSQLStmtCaseWhenArm
CaseWhenList []*PLpgSQLStmtCaseWhenArm
HaveElse bool
ElseStmts []PLpgSQLStatement
}

func (s *PLpgSQLStmtCase) Format(ctx *tree.FmtCtx) {
ctx.WriteString("CASE")
if len(s.TestExpr) > 0 {
ctx.WriteString(fmt.Sprintf(" %s", s.TestExpr))
}
ctx.WriteString("\n")
for _, when := range s.CaseWhenList {
when.Format(ctx)
}
if s.HaveElse {
ctx.WriteString("ELSE\n")
for _, stmt := range s.ElseStmts {
stmt.Format(ctx)
}
}
ctx.WriteString("ENDCASE\n")
}

type PLpgSQLStmtCaseWhenArm struct {
LineNo int
Expr PLpgSQLExpr
Stmts []PLpgSQLStatement
// TODO: Change to PLpgSQLExpr
Expr string
Stmts []PLpgSQLStatement
}

func (s *PLpgSQLStmtCaseWhenArm) Format(ctx *tree.FmtCtx) {
ctx.WriteString(fmt.Sprintf("WHEN %s\n", s.Expr))
ctx.WriteString("THEN")
for _, stmt := range s.Stmts {
stmt.Format(ctx)
}
ctx.WriteString("\n")
}

// stmt_loop
Expand Down

0 comments on commit 3b22ee1

Please sign in to comment.