diff --git a/pkg/cmd/roachtest/tests/BUILD.bazel b/pkg/cmd/roachtest/tests/BUILD.bazel index 864439acc653..c550f41cca3f 100644 --- a/pkg/cmd/roachtest/tests/BUILD.bazel +++ b/pkg/cmd/roachtest/tests/BUILD.bazel @@ -184,6 +184,7 @@ go_library( "@com_github_codahale_hdrhistogram//:hdrhistogram", "@com_github_dustin_go_humanize//:go-humanize", "@com_github_golang_mock//gomock", + "@com_github_google_go_cmp//cmp", "@com_github_jackc_pgtype//:pgtype", "@com_github_kr_pretty//:pretty", "@com_github_lib_pq//:pq", diff --git a/pkg/cmd/roachtest/tests/tlp.go b/pkg/cmd/roachtest/tests/tlp.go index e0464ead749e..180b8f5a6074 100644 --- a/pkg/cmd/roachtest/tests/tlp.go +++ b/pkg/cmd/roachtest/tests/tlp.go @@ -16,6 +16,7 @@ import ( "fmt" "os" "path/filepath" + "sort" "strings" "time" @@ -23,8 +24,10 @@ import ( "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/registry" "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/test" "github.com/cockroachdb/cockroach/pkg/internal/sqlsmith" + "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" "github.com/cockroachdb/cockroach/pkg/util/randutil" "github.com/cockroachdb/errors" + "github.com/google/go-cmp/cmp" ) const statementTimeout = time.Minute @@ -169,34 +172,58 @@ func runTLPQuery(conn *gosql.DB, smither *sqlsmith.Smither, logStmt func(string) unpartitioned, partitioned := smither.GenerateTLP() return runWithTimeout(func() error { - var unpartitionedCount int - row := conn.QueryRow(unpartitioned) - if err := row.Scan(&unpartitionedCount); err != nil { + rows1, err := conn.Query(unpartitioned) + if err != nil { // Ignore errors. //nolint:returnerrcheck return nil } - - var partitionedCount int - row = conn.QueryRow(partitioned) - if err := row.Scan(&partitionedCount); err != nil { + defer rows1.Close() + unpartitionedRows, err := sqlutils.RowsToStrMatrix(rows1) + if err != nil { + // Ignore errors. + //nolint:returnerrcheck + return nil + } + rows2, err := conn.Query(partitioned) + if err != nil { + // Ignore errors. + //nolint:returnerrcheck + return nil + } + defer rows2.Close() + partitionedRows, err := sqlutils.RowsToStrMatrix(rows2) + if err != nil { // Ignore errors. //nolint:returnerrcheck return nil } - if unpartitionedCount != partitionedCount { + if diff := unsortedMatricesDiff(unpartitionedRows, partitionedRows); diff != "" { logStmt(unpartitioned) logStmt(partitioned) return errors.Newf( - "expected unpartitioned count %d to equal partitioned count %d\nsql: %s\n%s", - unpartitionedCount, partitionedCount, unpartitioned, partitioned) + "expected unpartitioned results to equal partitioned results\n%s\nsql: %s\n%s", + diff, unpartitioned, partitioned) } - return nil }) } +func unsortedMatricesDiff(rowMatrix1, rowMatrix2 [][]string) string { + var rows1 []string + for _, row := range rowMatrix1 { + rows1 = append(rows1, strings.Join(row[:], ",")) + } + var rows2 []string + for _, row := range rowMatrix2 { + rows2 = append(rows2, strings.Join(row[:], ",")) + } + sort.Strings(rows1) + sort.Strings(rows2) + return cmp.Diff(rows1, rows2) +} + func runWithTimeout(f func() error) error { done := make(chan error, 1) go func() { diff --git a/pkg/internal/sqlsmith/tlp.go b/pkg/internal/sqlsmith/tlp.go index db627d20e004..e0c23835d8cb 100644 --- a/pkg/internal/sqlsmith/tlp.go +++ b/pkg/internal/sqlsmith/tlp.go @@ -28,14 +28,10 @@ import ( // // More information on TLP: https://www.manuelrigger.at/preprints/TLP.pdf. // -// We currently implement a limited form of TLP that can only verify that the -// number of rows returned by the unpartitioned and the partitioned queries are -// equal. -// -// This TLP implementation is also limited in the types of queries that are -// tested. We currently only test basic WHERE and JOIN query filters. It is -// possible to use TLP to test aggregations, GROUP BY, and HAVING, which have -// all been implemented in SQLancer. See: +// This TLP implementation is limited in the types of queries that are tested. +// We currently only test basic WHERE, JOIN, and MAX/MIN query filters. It is +// possible to use TLP to test other aggregations, GROUP BY, and HAVING, which +// have all been implemented in SQLancer. See: // https://github.com/sqlancer/sqlancer/tree/1.1.0/src/sqlancer/cockroachdb/oracle/tlp. func (s *Smither) GenerateTLP() (unpartitioned, partitioned string) { // Set disableImpureFns to true so that generated predicates are immutable. @@ -63,19 +59,17 @@ func (s *Smither) GenerateTLP() (unpartitioned, partitioned string) { // // The first query returned is an unpartitioned query of the form: // -// SELECT count(*) FROM table +// SELECT * FROM table // // The second query returned is a partitioned query of the form: // -// SELECT count(*) FROM ( -// SELECT * FROM table WHERE (p) -// UNION ALL -// SELECT * FROM table WHERE NOT (p) -// UNION ALL -// SELECT * FROM table WHERE (p) IS NULL -// ) +// SELECT * FROM table WHERE (p) +// UNION ALL +// SELECT * FROM table WHERE NOT (p) +// UNION ALL +// SELECT * FROM table WHERE (p) IS NULL // -// If the resulting counts of the two queries are not equal, there is a logical +// If the resulting values of the two queries are not equal, there is a logical // bug. func (s *Smither) generateWhereTLP() (unpartitioned, partitioned string) { f := tree.NewFmtCtx(tree.FmtParsable) @@ -87,7 +81,7 @@ func (s *Smither) generateWhereTLP() (unpartitioned, partitioned string) { table.Format(f) tableName := f.CloseAndGetString() - unpartitioned = fmt.Sprintf("SELECT count(*) FROM %s", tableName) + unpartitioned = fmt.Sprintf("SELECT * FROM %s", tableName) pred := makeBoolExpr(s, cols) pred.Format(f) @@ -98,7 +92,7 @@ func (s *Smither) generateWhereTLP() (unpartitioned, partitioned string) { part3 := fmt.Sprintf("SELECT * FROM %s WHERE (%s) IS NULL", tableName, predicate) partitioned = fmt.Sprintf( - "SELECT count(*) FROM (%s UNION ALL %s UNION ALL %s)", + "(%s) UNION ALL (%s) UNION ALL (%s)", part1, part2, part3, ) @@ -112,23 +106,19 @@ func (s *Smither) generateWhereTLP() (unpartitioned, partitioned string) { // // The first query returned is an unpartitioned query of the form: // -// SELECT count(*) FROM ( -// SELECT * FROM table1 LEFT JOIN table2 ON TRUE -// UNION ALL -// SELECT * FROM table1 LEFT JOIN table2 ON FALSE -// UNION ALL -// SELECT * FROM table1 LEFT JOIN table2 ON FALSE -// ) +// SELECT * FROM table1 LEFT JOIN table2 ON TRUE +// UNION ALL +// SELECT * FROM table1 LEFT JOIN table2 ON FALSE +// UNION ALL +// SELECT * FROM table1 LEFT JOIN table2 ON FALSE // // The second query returned is a partitioned query of the form: // -// SELECT count(*) FROM ( -// SELECT * FROM table1 LEFT JOIN table2 ON (p) -// UNION ALL -// SELECT * FROM table1 LEFT JOIN table2 ON NOT (p) -// UNION ALL -// SELECT * FROM table1 LEFT JOIN table2 ON (p) IS NULL -// ) +// SELECT * FROM table1 LEFT JOIN table2 ON (p) +// UNION ALL +// SELECT * FROM table1 LEFT JOIN table2 ON NOT (p) +// UNION ALL +// SELECT * FROM table1 LEFT JOIN table2 ON (p) IS NULL // // From the first query, we have a CROSS JOIN of the two tables (JOIN ON TRUE) // and then all rows concatenated with NULL values for the second and third @@ -144,7 +134,7 @@ func (s *Smither) generateWhereTLP() (unpartitioned, partitioned string) { // Note that this implementation is restricted in that it only uses columns from // the left table in the predicate p. -// If the resulting counts of the two queries are not equal, there is a logical +// If the resulting values of the two queries are not equal, there is a logical // bug. func (s *Smither) generateOuterJoinTLP() (unpartitioned, partitioned string) { f := tree.NewFmtCtx(tree.FmtParsable) @@ -169,7 +159,7 @@ func (s *Smither) generateOuterJoinTLP() (unpartitioned, partitioned string) { ) unpartitioned = fmt.Sprintf( - "SELECT count(*) FROM (%s UNION ALL %s UNION ALL %s)", + "(%s) UNION ALL (%s) UNION ALL (%s)", leftJoinTrue, leftJoinFalse, leftJoinFalse, ) @@ -191,7 +181,7 @@ func (s *Smither) generateOuterJoinTLP() (unpartitioned, partitioned string) { ) partitioned = fmt.Sprintf( - "SELECT count(*) FROM (%s UNION ALL %s UNION ALL %s)", + "(%s) UNION ALL (%s) UNION ALL (%s)", part1, part2, part3, ) @@ -205,17 +195,15 @@ func (s *Smither) generateOuterJoinTLP() (unpartitioned, partitioned string) { // // The first query returned is an unpartitioned query of the form: // -// SELECT count(*) FROM table1 JOIN table2 ON TRUE +// SELECT * FROM table1 JOIN table2 ON TRUE // // The second query returned is a partitioned query of the form: // -// SELECT count(*) FROM ( -// SELECT * FROM table1 JOIN table2 ON (p) -// UNION ALL -// SELECT * FROM table1 JOIN table2 ON NOT (p) -// UNION ALL -// SELECT * FROM table1 JOIN table2 ON (p) IS NULL -// ) +// SELECT * FROM table1 JOIN table2 ON (p) +// UNION ALL +// SELECT * FROM table1 JOIN table2 ON NOT (p) +// UNION ALL +// SELECT * FROM table1 JOIN table2 ON (p) IS NULL // // From the first query, we have a CROSS JOIN of the two tables (JOIN ON TRUE). // Recall our TLP logical guarantee that a given predicate p always evaluates to @@ -224,7 +212,7 @@ func (s *Smither) generateOuterJoinTLP() (unpartitioned, partitioned string) { // resolve to TRUE. So the partitioned query accounts for each row in the // CROSS JOIN exactly once. // -// If the resulting counts of the two queries are not equal, there is a logical +// If the resulting values of the two queries are not equal, there is a logical // bug. func (s *Smither) generateInnerJoinTLP() (unpartitioned, partitioned string) { f := tree.NewFmtCtx(tree.FmtParsable) @@ -240,7 +228,7 @@ func (s *Smither) generateInnerJoinTLP() (unpartitioned, partitioned string) { tableName2 := f.CloseAndGetString() unpartitioned = fmt.Sprintf( - "SELECT count(*) FROM %s JOIN %s ON true", + "SELECT * FROM %s JOIN %s ON true", tableName1, tableName2, ) @@ -263,7 +251,7 @@ func (s *Smither) generateInnerJoinTLP() (unpartitioned, partitioned string) { ) partitioned = fmt.Sprintf( - "SELECT count(*) FROM (%s UNION ALL %s UNION ALL %s)", + "(%s) UNION ALL (%s) UNION ALL (%s)", part1, part2, part3, )