Skip to content

Commit

Permalink
tests: support float approximation in roachtest query comparison utils
Browse files Browse the repository at this point in the history
Before this change unoptimized query oracle tests would compare results
using simple string comparison. However, due to floating point precision
limitations, it's possible for results with floating point to diverge
during the course of normal computation. This results in test failures
that are difficult to reproduce or determine whether they are expected
behavior.

This change utilizes existing floating point comparison functions used
by logic tests to match float values only to a specific precision. Like
the logic tests, we also have special handling for floats and decimals
under the s390x architecture (see cockroachdb#63244). In order to avoid costly
comparisons, we only check floating point precision if the naiive string
comparison approach fails and there are float or decimal types in the
result.

Epic: None
Fixes: cockroachdb#95665

Release note: None
  • Loading branch information
rharding6373 committed Jul 10, 2023
1 parent 35fb0d3 commit 0c36d5d
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 1 deletion.
1 change: 1 addition & 0 deletions pkg/cmd/roachtest/tests/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ go_test(
srcs = [
"blocklist_test.go",
"drt_test.go",
"query_comparison_util_test.go",
"tpcc_test.go",
"util_load_group_test.go",
":mocks_drt", # keep
Expand Down
89 changes: 89 additions & 0 deletions pkg/cmd/roachtest/tests/query_comparison_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ type queryComparisonHelper struct {

statements []string
statementsAndExplains []sqlAndOutput
colTypes []string
}

// runQuery runs the given query and returns the output. If the stmt doesn't
Expand All @@ -327,6 +328,14 @@ func (h *queryComparisonHelper) runQuery(stmt string) ([][]string, error) {
return nil, err
}
defer rows.Close()
cts, err := rows.ColumnTypes()
if err != nil {
return nil, err
}
h.colTypes = make([]string, len(cts))
for i, ct := range cts {
h.colTypes[i] = ct.DatabaseTypeName()
}
return sqlutils.RowsToStrMatrix(rows)
}

Expand Down Expand Up @@ -384,6 +393,86 @@ func (h *queryComparisonHelper) makeError(err error, msg string) error {
return errors.Wrapf(err, "%s. %d statements run", msg, h.stmtNo)
}

// unsortedMatricesDiffWithFloatComp sorts and compares the rows in rowMatrix1
// to rowMatrix2 and outputs a diff or message related to the comparison. If a
// string comparison of the rows fails, and they contain floats or decimals, it
// performs an approximate comparison of the values.
func unsortedMatricesDiffWithFloatComp(
rowMatrix1, rowMatrix2 [][]string, colTypes []string,
) (string, error) {
// Use an unlikely string as a separator when joining rows for initial
// comparison, so that we can split the sorted rows if we need to make
// additional comparisons.
sep := ",unsortedMatricesDiffWithFloatComp separator,"
var rows1 []string
for _, row := range rowMatrix1 {
rows1 = append(rows1, strings.Join(row[:], sep))
}
var rows2 []string
for _, row := range rowMatrix2 {
rows2 = append(rows2, strings.Join(row[:], sep))
}
sort.Strings(rows1)
sort.Strings(rows2)
result := cmp.Diff(rows1, rows2)
if result == "" {
return result, nil
}
if len(rows1) != len(rows2) {
return fmt.Sprintf("row counts are not the same (%d vs %d)", len(rows1), len(rows2)), nil
}
var decimalIdxs []int
var floatIdxs []int
for i := range colTypes {
// On s390x, check that values for both float and decimal coltypes are
// approximately equal to take into account platform differences in floating
// point calculations. On other architectures, check float values only.
if runtime.GOARCH == "s390x" && colTypes[i] == "DECIMAL" {
decimalIdxs = append(decimalIdxs, i)
}
if colTypes[i] == "FLOAT4" || colTypes[i] == "FLOAT8" {
floatIdxs = append(floatIdxs, i)
}
}
if len(decimalIdxs) == 0 && len(floatIdxs) == 0 {
return result, nil
}
for i := range rows1 {
// Split the sorted rows.
row1 := strings.Split(rows1[i], sep)
row2 := strings.Split(rows2[i], sep)

// Check that decimal values are approximately equal.
for j := range decimalIdxs {
match, err := floatcmp.FloatsMatchApprox(row1[j], row2[j])
if err != nil {
return "", err
}
if !match {
return result, nil
}
}

// Check that float values are approximately equal.
for j := range floatIdxs {
var err error
var match bool
if runtime.GOARCH == "s390x" {
match, err = floatcmp.FloatsMatchApprox(row1[j], row2[j])
} else {
match, err = floatcmp.FloatsMatch(row1[j], row2[j])
}
if err != nil {
return "", err
}
if !match {
return result, nil
}
}
}
return "", nil
}

// unsortedMatricesDiff sorts and compares rows of data.
func unsortedMatricesDiff(rowMatrix1, rowMatrix2 [][]string) string {
var rows1 []string
Expand Down
94 changes: 94 additions & 0 deletions pkg/cmd/roachtest/tests/query_comparison_util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Copyright 2023 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package tests

import (
"testing"

"github.com/cockroachdb/cockroach/pkg/util/leaktest"
)

// TestUnsortedMatricesDiff is a unit test for the
// unsortedMatricesDiffWithFloatComp() and unsortedMatricesDiff() utility
// functions.
func TestUnsortedMatricesDiff(t *testing.T) {
defer leaktest.AfterTest(t)()
tcs := []struct {
name string
colTypes []string
t1, t2 [][]string
exactMatch bool
approxMatch bool
}{
{
name: "float exact match",
colTypes: []string{"FLOAT8"},
t1: [][]string{{"1.2345678901234567"}},
t2: [][]string{{"1.2345678901234567"}},
exactMatch: true,
},
{
name: "float approx match",
colTypes: []string{"FLOAT8"},
t1: [][]string{{"1.2345678901234563"}},
t2: [][]string{{"1.2345678901234564"}},
exactMatch: false,
approxMatch: true,
},
{
name: "float no match",
colTypes: []string{"FLOAT8"},
t1: [][]string{{"1.2345678901234"}},
t2: [][]string{{"1.2345678901235"}},
exactMatch: false,
approxMatch: false,
},
{
name: "multi float approx match",
colTypes: []string{"FLOAT8"},
t1: [][]string{{"1.2345678901234567", "1.2345678901234567"}},
t2: [][]string{{"1.2345678901234567", "1.2345678901234568"}},
exactMatch: false,
approxMatch: true,
},
{
name: "string no match",
colTypes: []string{"STRING"},
t1: [][]string{{"hello"}},
t2: [][]string{{"world"}},
exactMatch: false,
approxMatch: false,
},
}
for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) {
match := unsortedMatricesDiff(tc.t1, tc.t2)
if tc.exactMatch && match != "" {
t.Fatalf("unsortedMatricesDiff: expected exact match, got diff: %s", match)
} else if !tc.exactMatch && match == "" {
t.Fatalf("unsortedMatricesDiff: expected no exact match, got no diff")
}

var err error
match, err = unsortedMatricesDiffWithFloatComp(tc.t1, tc.t2, tc.colTypes)
if err != nil {
t.Fatal(err)
}
if tc.exactMatch && match != "" {
t.Fatalf("unsortedMatricesDiffWithFloatComp: expected exact match, got diff: %s", match)
} else if !tc.exactMatch && tc.approxMatch && match != "" {
t.Fatalf("unsortedMatricesDiffWithFloatComp: expected approx match, got diff: %s", match)
} else if !tc.exactMatch && !tc.approxMatch && match == "" {
t.Fatalf("unsortedMatricesDiffWithFloatComp: expected no approx match, got no diff")
}
})
}
}
6 changes: 5 additions & 1 deletion pkg/cmd/roachtest/tests/unoptimized_query_oracle.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,11 @@ func runUnoptimizedQueryOracleImpl(
//nolint:returnerrcheck
return nil
}
if diff := unsortedMatricesDiff(unoptimizedRows, optimizedRows); diff != "" {
diff, err := unsortedMatricesDiffWithFloatComp(unoptimizedRows, optimizedRows, h.colTypes)
if err != nil {
return err
}
if diff != "" {
// We have a mismatch in the unoptimized vs optimized query outputs.
verboseLogging = true
return h.makeError(errors.Newf(
Expand Down

0 comments on commit 0c36d5d

Please sign in to comment.