diff --git a/pkg/internal/sqlsmith/tlp.go b/pkg/internal/sqlsmith/tlp.go index b54ce67dc5f2..db627d20e004 100644 --- a/pkg/internal/sqlsmith/tlp.go +++ b/pkg/internal/sqlsmith/tlp.go @@ -13,6 +13,7 @@ package sqlsmith import ( "fmt" "math/rand" + "strings" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/errors" @@ -44,20 +45,22 @@ func (s *Smither) GenerateTLP() (unpartitioned, partitioned string) { s.disableImpureFns = originalDisableImpureFns }() - switch tlpType := rand.Intn(3); tlpType { + switch tlpType := rand.Intn(4); tlpType { case 0: return s.generateWhereTLP() case 1: return s.generateOuterJoinTLP() - default: + case 2: return s.generateInnerJoinTLP() + default: + return s.generateAggregationTLP() } } // generateWhereTLP returns two SQL queries as strings that can be used by the // GenerateTLP function. These queries make use of the WHERE clause to partition // the original query into three. - +// // The first query returned is an unpartitioned query of the form: // // SELECT count(*) FROM table @@ -106,7 +109,7 @@ func (s *Smither) generateWhereTLP() (unpartitioned, partitioned string) { // GenerateTLP function. These queries make use of LEFT JOIN to partition the // original query in two ways. The latter query is partitioned by a predicate p, // while the former is not. - +// // The first query returned is an unpartitioned query of the form: // // SELECT count(*) FROM ( @@ -199,7 +202,7 @@ func (s *Smither) generateOuterJoinTLP() (unpartitioned, partitioned string) { // the GenerateTLP function. These queries make use of INNER JOIN to partition // the original query in two ways. The latter query is partitioned by a // predicate p, while the former is not. - +// // The first query returned is an unpartitioned query of the form: // // SELECT count(*) FROM table1 JOIN table2 ON TRUE @@ -266,3 +269,83 @@ func (s *Smither) generateInnerJoinTLP() (unpartitioned, partitioned string) { return unpartitioned, partitioned } + +// generateAggregationTLP returns two SQL queries as strings that can be used by +// the GenerateTLP function. These queries make use of the WHERE clause and a +// predicate p to partition the original query into three. The two aggregations +// that are supported are MAX() and MIN(), although SUM(), AVG(), and COUNT() +// are also valid TLP aggregations. +// +// The first query returned is an unpartitioned query of the form: +// +// SELECT MAX(first) FROM (SELECT * FROM table) table(first) +// +// The second query returned is a partitioned query of the form: +// +// SELECT MAX(agg) FROM ( +// SELECT MAX(first) AS agg FROM ( +// SELECT * FROM table WHERE p +// ) table(first) +// UNION ALL +// SELECT MAX(first) AS agg FROM ( +// SELECT * FROM table WHERE NOT (p) +// ) table(first) +// UNION ALL +// SELECT MAX(first) AS agg FROM ( +// SELECT * FROM table WHERE (p) IS NULL +// ) table(first) +// ) +// +// Note that all instances of MAX can be replaced with MIN to get the +// corresponding MIN version of the queries. +// +// If the resulting values of the two queries are not equal, there is a logical +// bug. +func (s *Smither) generateAggregationTLP() (unpartitioned, partitioned string) { + f := tree.NewFmtCtx(tree.FmtParsable) + + table, _, _, cols, ok := s.getSchemaTable() + if !ok { + panic(errors.AssertionFailedf("failed to find random table")) + } + table.Format(f) + tableName := f.CloseAndGetString() + tableNameAlias := strings.TrimSpace(strings.Split(tableName, "AS")[1]) + + var agg string + switch aggType := rand.Intn(2); aggType { + case 0: + agg = "MAX" + default: + agg = "MIN" + } + + unpartitioned = fmt.Sprintf( + "SELECT %s(first) FROM (SELECT * FROM %s) %s(first)", + agg, tableName, tableNameAlias, + ) + + pred := makeBoolExpr(s, cols) + pred.Format(f) + predicate := f.CloseAndGetString() + + part1 := fmt.Sprintf( + "SELECT %s(first) AS agg FROM (SELECT * FROM %s WHERE %s) %s(first)", + agg, tableName, predicate, tableNameAlias, + ) + part2 := fmt.Sprintf( + "SELECT %s(first) AS agg FROM (SELECT * FROM %s WHERE NOT (%s)) %s(first)", + agg, tableName, predicate, tableNameAlias, + ) + part3 := fmt.Sprintf( + "SELECT %s(first) AS agg FROM (SELECT * FROM %s WHERE (%s) IS NULL) %s(first)", + agg, tableName, predicate, tableNameAlias, + ) + + partitioned = fmt.Sprintf( + "SELECT %s(agg) FROM (%s UNION ALL %s UNION ALL %s)", + agg, part1, part2, part3, + ) + + return unpartitioned, partitioned +}