Skip to content

Commit

Permalink
sqlsmith: extend PLpgSQL support
Browse files Browse the repository at this point in the history
This commit adds support for WHILE and FOR (int) loops as well as SELECT
INTO / RETURNING INTO variant of statements.

Release note: None
  • Loading branch information
yuzefovich committed Dec 11, 2024
1 parent 17e7e5e commit a3fd8ba
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 43 deletions.
83 changes: 76 additions & 7 deletions pkg/internal/sqlsmith/plpgsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ func (s *Smither) makePLpgSQLBlock(scope plpgsqlBlockScope) *ast.Block {
}
}

func (s *Smither) makePLpgSQLVarName(prefix string, scope plpgsqlBlockScope) tree.Name {
varName := s.name(prefix)
for scope.hasVariable(string(varName)) {
varName = s.name(prefix)
}
return varName
}

func (s *Smither) makePLpgSQLDeclarations(
scope plpgsqlBlockScope,
) ([]ast.Statement, plpgsqlBlockScope) {
Expand All @@ -50,10 +58,7 @@ func (s *Smither) makePLpgSQLDeclarations(
// TODO(#106368): add support for cursor declarations.
decls := make([]ast.Statement, numDecls)
for i := 0; i < numDecls; i++ {
varName := s.name("decl")
for newScope.hasVariable(string(varName)) {
varName = s.name("decl")
}
varName := s.makePLpgSQLVarName("decl", newScope)
varTyp := s.randType()
for varTyp.Identical(types.AnyTuple) || varTyp.Family() == types.CollatedStringFamily {
// TODO(#114874): allow record types here when they are supported.
Expand Down Expand Up @@ -134,6 +139,8 @@ var (
{1, makePLpgSQLBlock},
{2, makePLpgSQLReturn},
{2, makePLpgSQLIf},
{2, makePLpgSQLWhile},
{2, makePLpgSQLForLoop},
{5, makePLpgSQLNull},
{10, makePLpgSQLAssign},
{10, makePLpgSQLExecSQL},
Expand Down Expand Up @@ -172,13 +179,44 @@ func makePLpgSQLAssign(s *Smither, scope plpgsqlBlockScope) (stmt ast.Statement,
}

func makePLpgSQLExecSQL(s *Smither, scope plpgsqlBlockScope) (stmt ast.Statement, ok bool) {
// TODO(#106368): add support for SELECT/RETURNING INTO statements.
const maxRetries = 5
var sqlStmt tree.Statement
for i := 0; i < maxRetries; i++ {
sqlStmt, ok = s.makeSQLStmtForRoutine(scope.vol, scope.refs)
var desiredTypes []*types.T
var targets []ast.Variable
if s.coin() {
// Support INTO syntax. Pick a subset of variables to assign into.
usedVars := make(map[string]struct{})
numNonConstVars := len(scope.vars) - len(scope.constants)
for len(usedVars) < numNonConstVars {
// Pick non-constant variable that hasn't been used yet.
var varName string
for {
varName = scope.vars[s.rnd.Intn(len(scope.vars))]
if scope.variableIsConstant(varName) {
continue
}
if _, used := usedVars[varName]; used {
continue
}
usedVars[varName] = struct{}{}
desiredTypes = append(desiredTypes, scope.varTypes[varName])
targets = append(targets, tree.Name(varName))
break
}
if s.coin() {
break
}
}
}
sqlStmt, ok = s.makeSQLStmtForRoutine(scope.vol, scope.refs, desiredTypes)
if ok {
return &ast.Execute{SqlStmt: sqlStmt}, true
return &ast.Execute{
SqlStmt: sqlStmt,
// Strict option won't matter if targets is empty.
Strict: s.d6() == 1,
Target: targets,
}, true
}
}
return nil, false
Expand All @@ -188,6 +226,37 @@ func makePLpgSQLNull(_ *Smither, _ plpgsqlBlockScope) (stmt ast.Statement, ok bo
return &ast.Null{}, true
}

func makePLpgSQLForLoop(s *Smither, scope plpgsqlBlockScope) (stmt ast.Statement, ok bool) {
// TODO(#105246): add support for other query and cursor FOR loops.
control := ast.IntForLoopControl{
Reverse: s.coin(),
Lower: s.makePLpgSQLExpr(scope, types.Int),
Upper: s.makePLpgSQLExpr(scope, types.Int),
}
if s.coin() {
control.Step = s.makePLpgSQLExpr(scope, types.Int)
}
newScope := scope.makeChild(1 /* numNewVars */)
loopVarName := s.makePLpgSQLVarName("loop", newScope)
newScope.addVariable(string(loopVarName), types.Int, false /* constant */)
const maxLoopStmts = 3
return &ast.ForLoop{
// TODO(#106368): optionally add a label.
Target: []ast.Variable{loopVarName},
Control: &control,
Body: s.makePLpgSQLStatements(newScope, maxLoopStmts),
}, true
}

func makePLpgSQLWhile(s *Smither, scope plpgsqlBlockScope) (stmt ast.Statement, ok bool) {
const maxLoopStmts = 3
return &ast.While{
// TODO(#106368): optionally add a label.
Condition: s.makePLpgSQLCond(scope),
Body: s.makePLpgSQLStatements(scope, maxLoopStmts),
}, true
}

// plpgsqlBlockScope holds the information needed to ensure that generated
// statements obey PL/pgSQL syntax and scoping rules.
type plpgsqlBlockScope struct {
Expand Down
100 changes: 64 additions & 36 deletions pkg/internal/sqlsmith/relational.go
Original file line number Diff line number Diff line change
Expand Up @@ -1194,7 +1194,7 @@ func (s *Smither) makeRoutineBodySQL(
stmts := make([]string, 0, stmtCnt)
var stmt tree.Statement
for i := 0; i < stmtCnt-1; i++ {
stmt, ok = s.makeSQLStmtForRoutine(vol, refs)
stmt, ok = s.makeSQLStmtForRoutine(vol, refs, nil /* desiredTypes */)
if !ok {
continue
}
Expand All @@ -1203,45 +1203,57 @@ func (s *Smither) makeRoutineBodySQL(
// The return type of the last statement should match the function return
// type.
// If mutations are enabled, also use anything from mutatingTableExprs -- needs returning
desiredTypes := []*types.T{rTyp}
if s.disableMutations || vol != tree.RoutineVolatile || s.coin() {
stmt, lastStmtRefs, ok = s.makeSelect([]*types.T{rTyp}, refs)
stmt, lastStmtRefs, ok = s.makeSelect(desiredTypes, refs)
if !ok {
return "", nil, false
}
} else {
var expr tree.TableExpr
switch s.d6() {
case 1, 2:
expr, lastStmtRefs, ok = s.makeInsertReturning(refs)
stmt, lastStmtRefs, ok = s.makeInsertReturning(desiredTypes, refs)
case 3, 4:
expr, lastStmtRefs, ok = s.makeDeleteReturning(refs)
stmt, lastStmtRefs, ok = s.makeDeleteReturning(desiredTypes, refs)
case 5, 6:
expr, lastStmtRefs, ok = s.makeUpdateReturning(refs)
stmt, lastStmtRefs, ok = s.makeUpdateReturning(desiredTypes, refs)
}
if !ok {
return "", nil, false
}
stmt = expr.(*tree.StatementSource).Statement
}
stmts = append(stmts, tree.AsStringWithFlags(stmt, tree.FmtParsable))
return "\n" + strings.Join(stmts, ";\n") + "\n", lastStmtRefs, true
}

func (s *Smither) makeSQLStmtForRoutine(
vol tree.RoutineVolatility, refs colRefs,
vol tree.RoutineVolatility, refs colRefs, desiredTypes []*types.T,
) (stmt tree.Statement, ok bool) {
const numRetries = 5
for i := 0; i < numRetries; i++ {
if s.disableMutations || vol != tree.RoutineVolatile || s.coin() {
stmt, _, ok = s.makeSelect(nil /* desiredTypes */, refs)
stmt, _, ok = s.makeSelect(desiredTypes, refs)
} else {
switch s.d6() {
case 1, 2:
stmt, _, ok = s.makeInsert(refs)
case 3, 4:
stmt, _, ok = s.makeDelete(refs)
case 5, 6:
stmt, _, ok = s.makeUpdate(refs)
if len(desiredTypes) == 0 && s.coin() {
// If the caller didn't request particular result types, in 50%
// cases use the "vanilla" mutation stmts.
switch s.d6() {
case 1, 2:
stmt, _, ok = s.makeInsert(refs)
case 3, 4:
stmt, _, ok = s.makeDelete(refs)
case 5, 6:
stmt, _, ok = s.makeUpdate(refs)
}
} else {
switch s.d6() {
case 1, 2:
stmt, _, ok = s.makeInsertReturning(desiredTypes, refs)
case 3, 4:
stmt, _, ok = s.makeDeleteReturning(desiredTypes, refs)
case 5, 6:
stmt, _, ok = s.makeUpdateReturning(desiredTypes, refs)
}
}
}
if ok {
Expand Down Expand Up @@ -1309,19 +1321,23 @@ func makeDeleteReturning(s *Smither, refs colRefs, forJoin bool) (tree.TableExpr
if forJoin {
return nil, nil, false
}
return s.makeDeleteReturning(refs)
del, returningRefs, ok := s.makeDeleteReturning(nil /* desiredTypes */, refs)
if !ok {
return nil, nil, false
}
return &tree.StatementSource{Statement: del}, returningRefs, true
}

func (s *Smither) makeDeleteReturning(refs colRefs) (tree.TableExpr, colRefs, bool) {
func (s *Smither) makeDeleteReturning(
desiredTypes []*types.T, refs colRefs,
) (*tree.Delete, colRefs, bool) {
del, delRef, ok := s.makeDelete(refs)
if !ok {
return nil, nil, false
}
var returningRefs colRefs
del.Returning, returningRefs = s.makeReturning(delRef)
return &tree.StatementSource{
Statement: del,
}, returningRefs, true
del.Returning, returningRefs = s.makeReturning(desiredTypes, delRef)
return del, returningRefs, true
}

func makeUpdate(s *Smither) (tree.Statement, bool) {
Expand Down Expand Up @@ -1417,19 +1433,23 @@ func makeUpdateReturning(s *Smither, refs colRefs, forJoin bool) (tree.TableExpr
if forJoin {
return nil, nil, false
}
return s.makeUpdateReturning(refs)
update, returningRefs, ok := s.makeUpdateReturning(nil /* desiredTypes */, refs)
if !ok {
return nil, nil, false
}
return &tree.StatementSource{Statement: update}, returningRefs, true
}

func (s *Smither) makeUpdateReturning(refs colRefs) (tree.TableExpr, colRefs, bool) {
func (s *Smither) makeUpdateReturning(
desiredTypes []*types.T, refs colRefs,
) (*tree.Update, colRefs, bool) {
update, updateRef, ok := s.makeUpdate(refs)
if !ok {
return nil, nil, false
}
var returningRefs colRefs
update.Returning, returningRefs = s.makeReturning(updateRef)
return &tree.StatementSource{
Statement: update,
}, returningRefs, true
update.Returning, returningRefs = s.makeReturning(desiredTypes, updateRef)
return update, returningRefs, true
}

func makeInsert(s *Smither) (tree.Statement, bool) {
Expand Down Expand Up @@ -1590,19 +1610,23 @@ func makeInsertReturning(s *Smither, refs colRefs, forJoin bool) (tree.TableExpr
if forJoin {
return nil, nil, false
}
return s.makeInsertReturning(refs)
insert, returningRefs, ok := s.makeInsertReturning(nil /* desiredTypes */, refs)
if !ok {
return nil, nil, false
}
return &tree.StatementSource{Statement: insert}, returningRefs, true
}

func (s *Smither) makeInsertReturning(refs colRefs) (tree.TableExpr, colRefs, bool) {
func (s *Smither) makeInsertReturning(
desiredTypes []*types.T, refs colRefs,
) (*tree.Insert, colRefs, bool) {
insert, insertRef, ok := s.makeInsert(refs)
if !ok {
return nil, nil, false
}
var returningRefs colRefs
insert.Returning, returningRefs = s.makeReturning([]*tableRef{insertRef})
return &tree.StatementSource{
Statement: insert,
}, returningRefs, true
insert.Returning, returningRefs = s.makeReturning(desiredTypes, []*tableRef{insertRef})
return insert, returningRefs, true
}

func makeValuesTable(s *Smither, refs colRefs, forJoin bool) (tree.TableExpr, colRefs, bool) {
Expand Down Expand Up @@ -1800,8 +1824,12 @@ func makeLimit(s *Smither) *tree.Limit {
return nil
}

func (s *Smither) makeReturning(tables []*tableRef) (*tree.ReturningExprs, colRefs) {
desiredTypes := s.makeDesiredTypes()
func (s *Smither) makeReturning(
desiredTypes []*types.T, tables []*tableRef,
) (*tree.ReturningExprs, colRefs) {
if len(desiredTypes) == 0 {
desiredTypes = s.makeDesiredTypes()
}

var refs colRefs
for _, table := range tables {
Expand Down

0 comments on commit a3fd8ba

Please sign in to comment.