Skip to content

Commit

Permalink
sqlsmith, tests: compare result sets in TLP testing
Browse files Browse the repository at this point in the history
Previously, only the result counts of the unpartitioned
and partitioned TLP queries were compared. This was inadequate
because it allowed for different rows values in the two result sets
to go undetected, as long as the row counts were the same.
To address this, the rows in the result sets are compared
directly to ensure they are the same, as expected.

Release note: None
  • Loading branch information
Neha George committed Sep 30, 2021
1 parent f4a1848 commit b9b70c1
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 58 deletions.
1 change: 1 addition & 0 deletions pkg/cmd/roachtest/tests/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ go_library(
"@com_github_codahale_hdrhistogram//:hdrhistogram",
"@com_github_dustin_go_humanize//:go-humanize",
"@com_github_golang_mock//gomock",
"@com_github_google_go_cmp//cmp",
"@com_github_jackc_pgtype//:pgtype",
"@com_github_kr_pretty//:pretty",
"@com_github_lib_pq//:pq",
Expand Down
49 changes: 38 additions & 11 deletions pkg/cmd/roachtest/tests/tlp.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,18 @@ import (
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"time"

"github.com/cockroachdb/cockroach/pkg/cmd/roachtest/cluster"
"github.com/cockroachdb/cockroach/pkg/cmd/roachtest/registry"
"github.com/cockroachdb/cockroach/pkg/cmd/roachtest/test"
"github.com/cockroachdb/cockroach/pkg/internal/sqlsmith"
"github.com/cockroachdb/cockroach/pkg/testutils/sqlutils"
"github.com/cockroachdb/cockroach/pkg/util/randutil"
"github.com/cockroachdb/errors"
"github.com/google/go-cmp/cmp"
)

const statementTimeout = time.Minute
Expand Down Expand Up @@ -169,34 +172,58 @@ func runTLPQuery(conn *gosql.DB, smither *sqlsmith.Smither, logStmt func(string)
unpartitioned, partitioned := smither.GenerateTLP()

return runWithTimeout(func() error {
var unpartitionedCount int
row := conn.QueryRow(unpartitioned)
if err := row.Scan(&unpartitionedCount); err != nil {
rows1, err := conn.Query(unpartitioned)
if err != nil {
// Ignore errors.
//nolint:returnerrcheck
return nil
}

var partitionedCount int
row = conn.QueryRow(partitioned)
if err := row.Scan(&partitionedCount); err != nil {
defer rows1.Close()
unpartitionedRows, err := sqlutils.RowsToStrMatrix(rows1)
if err != nil {
// Ignore errors.
//nolint:returnerrcheck
return nil
}
rows2, err := conn.Query(partitioned)
if err != nil {
// Ignore errors.
//nolint:returnerrcheck
return nil
}
defer rows2.Close()
partitionedRows, err := sqlutils.RowsToStrMatrix(rows2)
if err != nil {
// Ignore errors.
//nolint:returnerrcheck
return nil
}

if unpartitionedCount != partitionedCount {
if diff := unsortedMatricesDiff(unpartitionedRows, partitionedRows); diff != "" {
logStmt(unpartitioned)
logStmt(partitioned)
return errors.Newf(
"expected unpartitioned count %d to equal partitioned count %d\nsql: %s\n%s",
unpartitionedCount, partitionedCount, unpartitioned, partitioned)
"expected unpartitioned results to equal partitioned results\n%s\nsql: %s\n%s",
diff, unpartitioned, partitioned)
}

return nil
})
}

func unsortedMatricesDiff(rowMatrix1, rowMatrix2 [][]string) string {
var rows1 []string
for _, row := range rowMatrix1 {
rows1 = append(rows1, strings.Join(row[:], ","))
}
var rows2 []string
for _, row := range rowMatrix2 {
rows2 = append(rows2, strings.Join(row[:], ","))
}
sort.Strings(rows1)
sort.Strings(rows2)
return cmp.Diff(rows1, rows2)
}

func runWithTimeout(f func() error) error {
done := make(chan error, 1)
go func() {
Expand Down
82 changes: 35 additions & 47 deletions pkg/internal/sqlsmith/tlp.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,10 @@ import (
//
// More information on TLP: https://www.manuelrigger.at/preprints/TLP.pdf.
//
// We currently implement a limited form of TLP that can only verify that the
// number of rows returned by the unpartitioned and the partitioned queries are
// equal.
//
// This TLP implementation is also limited in the types of queries that are
// tested. We currently only test basic WHERE and JOIN query filters. It is
// possible to use TLP to test aggregations, GROUP BY, and HAVING, which have
// all been implemented in SQLancer. See:
// This TLP implementation is limited in the types of queries that are tested.
// We currently only test basic WHERE, JOIN, and MAX/MIN query filters. It is
// possible to use TLP to test other aggregations, GROUP BY, and HAVING, which
// have all been implemented in SQLancer. See:
// https://github.com/sqlancer/sqlancer/tree/1.1.0/src/sqlancer/cockroachdb/oracle/tlp.
func (s *Smither) GenerateTLP() (unpartitioned, partitioned string) {
// Set disableImpureFns to true so that generated predicates are immutable.
Expand Down Expand Up @@ -63,19 +59,17 @@ func (s *Smither) GenerateTLP() (unpartitioned, partitioned string) {
//
// The first query returned is an unpartitioned query of the form:
//
// SELECT count(*) FROM table
// SELECT * FROM table
//
// The second query returned is a partitioned query of the form:
//
// SELECT count(*) FROM (
// SELECT * FROM table WHERE (p)
// UNION ALL
// SELECT * FROM table WHERE NOT (p)
// UNION ALL
// SELECT * FROM table WHERE (p) IS NULL
// )
// SELECT * FROM table WHERE (p)
// UNION ALL
// SELECT * FROM table WHERE NOT (p)
// UNION ALL
// SELECT * FROM table WHERE (p) IS NULL
//
// If the resulting counts of the two queries are not equal, there is a logical
// If the resulting values of the two queries are not equal, there is a logical
// bug.
func (s *Smither) generateWhereTLP() (unpartitioned, partitioned string) {
f := tree.NewFmtCtx(tree.FmtParsable)
Expand All @@ -87,7 +81,7 @@ func (s *Smither) generateWhereTLP() (unpartitioned, partitioned string) {
table.Format(f)
tableName := f.CloseAndGetString()

unpartitioned = fmt.Sprintf("SELECT count(*) FROM %s", tableName)
unpartitioned = fmt.Sprintf("SELECT * FROM %s", tableName)

pred := makeBoolExpr(s, cols)
pred.Format(f)
Expand All @@ -98,7 +92,7 @@ func (s *Smither) generateWhereTLP() (unpartitioned, partitioned string) {
part3 := fmt.Sprintf("SELECT * FROM %s WHERE (%s) IS NULL", tableName, predicate)

partitioned = fmt.Sprintf(
"SELECT count(*) FROM (%s UNION ALL %s UNION ALL %s)",
"(%s) UNION ALL (%s) UNION ALL (%s)",
part1, part2, part3,
)

Expand All @@ -112,23 +106,19 @@ func (s *Smither) generateWhereTLP() (unpartitioned, partitioned string) {
//
// The first query returned is an unpartitioned query of the form:
//
// SELECT count(*) FROM (
// SELECT * FROM table1 LEFT JOIN table2 ON TRUE
// UNION ALL
// SELECT * FROM table1 LEFT JOIN table2 ON FALSE
// UNION ALL
// SELECT * FROM table1 LEFT JOIN table2 ON FALSE
// )
// SELECT * FROM table1 LEFT JOIN table2 ON TRUE
// UNION ALL
// SELECT * FROM table1 LEFT JOIN table2 ON FALSE
// UNION ALL
// SELECT * FROM table1 LEFT JOIN table2 ON FALSE
//
// The second query returned is a partitioned query of the form:
//
// SELECT count(*) FROM (
// SELECT * FROM table1 LEFT JOIN table2 ON (p)
// UNION ALL
// SELECT * FROM table1 LEFT JOIN table2 ON NOT (p)
// UNION ALL
// SELECT * FROM table1 LEFT JOIN table2 ON (p) IS NULL
// )
// SELECT * FROM table1 LEFT JOIN table2 ON (p)
// UNION ALL
// SELECT * FROM table1 LEFT JOIN table2 ON NOT (p)
// UNION ALL
// SELECT * FROM table1 LEFT JOIN table2 ON (p) IS NULL
//
// From the first query, we have a CROSS JOIN of the two tables (JOIN ON TRUE)
// and then all rows concatenated with NULL values for the second and third
Expand All @@ -144,7 +134,7 @@ func (s *Smither) generateWhereTLP() (unpartitioned, partitioned string) {
// Note that this implementation is restricted in that it only uses columns from
// the left table in the predicate p.

// If the resulting counts of the two queries are not equal, there is a logical
// If the resulting values of the two queries are not equal, there is a logical
// bug.
func (s *Smither) generateOuterJoinTLP() (unpartitioned, partitioned string) {
f := tree.NewFmtCtx(tree.FmtParsable)
Expand All @@ -169,7 +159,7 @@ func (s *Smither) generateOuterJoinTLP() (unpartitioned, partitioned string) {
)

unpartitioned = fmt.Sprintf(
"SELECT count(*) FROM (%s UNION ALL %s UNION ALL %s)",
"(%s) UNION ALL (%s) UNION ALL (%s)",
leftJoinTrue, leftJoinFalse, leftJoinFalse,
)

Expand All @@ -191,7 +181,7 @@ func (s *Smither) generateOuterJoinTLP() (unpartitioned, partitioned string) {
)

partitioned = fmt.Sprintf(
"SELECT count(*) FROM (%s UNION ALL %s UNION ALL %s)",
"(%s) UNION ALL (%s) UNION ALL (%s)",
part1, part2, part3,
)

Expand All @@ -205,17 +195,15 @@ func (s *Smither) generateOuterJoinTLP() (unpartitioned, partitioned string) {
//
// The first query returned is an unpartitioned query of the form:
//
// SELECT count(*) FROM table1 JOIN table2 ON TRUE
// SELECT * FROM table1 JOIN table2 ON TRUE
//
// The second query returned is a partitioned query of the form:
//
// SELECT count(*) FROM (
// SELECT * FROM table1 JOIN table2 ON (p)
// UNION ALL
// SELECT * FROM table1 JOIN table2 ON NOT (p)
// UNION ALL
// SELECT * FROM table1 JOIN table2 ON (p) IS NULL
// )
// SELECT * FROM table1 JOIN table2 ON (p)
// UNION ALL
// SELECT * FROM table1 JOIN table2 ON NOT (p)
// UNION ALL
// SELECT * FROM table1 JOIN table2 ON (p) IS NULL
//
// From the first query, we have a CROSS JOIN of the two tables (JOIN ON TRUE).
// Recall our TLP logical guarantee that a given predicate p always evaluates to
Expand All @@ -224,7 +212,7 @@ func (s *Smither) generateOuterJoinTLP() (unpartitioned, partitioned string) {
// resolve to TRUE. So the partitioned query accounts for each row in the
// CROSS JOIN exactly once.
//
// If the resulting counts of the two queries are not equal, there is a logical
// If the resulting values of the two queries are not equal, there is a logical
// bug.
func (s *Smither) generateInnerJoinTLP() (unpartitioned, partitioned string) {
f := tree.NewFmtCtx(tree.FmtParsable)
Expand All @@ -240,7 +228,7 @@ func (s *Smither) generateInnerJoinTLP() (unpartitioned, partitioned string) {
tableName2 := f.CloseAndGetString()

unpartitioned = fmt.Sprintf(
"SELECT count(*) FROM %s JOIN %s ON true",
"SELECT * FROM %s JOIN %s ON true",
tableName1, tableName2,
)

Expand All @@ -263,7 +251,7 @@ func (s *Smither) generateInnerJoinTLP() (unpartitioned, partitioned string) {
)

partitioned = fmt.Sprintf(
"SELECT count(*) FROM (%s UNION ALL %s UNION ALL %s)",
"(%s) UNION ALL (%s) UNION ALL (%s)",
part1, part2, part3,
)

Expand Down

0 comments on commit b9b70c1

Please sign in to comment.