From 438a95fb96ccdce775a4d09334c5ec3f6b724b98 Mon Sep 17 00:00:00 2001 From: rharding6373 Date: Mon, 10 Jul 2023 13:59:17 -0700 Subject: [PATCH] tests: support float approximation in roachtest query comparison utils Before this change unoptimized query oracle tests would compare results using simple string comparison. However, due to floating point precision limitations, it's possible for results with floating point to diverge during the course of normal computation. This results in test failures that are difficult to reproduce or determine whether they are expected behavior. This change utilizes existing floating point comparison functions used by logic tests to match float values only to a specific precision. Like the logic tests, we also have special handling for floats and decimals under the s390x architecture (see #63244). In order to avoid costly comparisons, we only check floating point precision if the naiive string comparison approach fails and there are float or decimal types in the result. Epic: None Fixes: #95665 Release note: None --- pkg/cmd/roachtest/tests/BUILD.bazel | 1 + pkg/cmd/roachtest/tests/costfuzz.go | 6 +- .../roachtest/tests/query_comparison_util.go | 99 +++++++++++++ .../tests/query_comparison_util_test.go | 140 ++++++++++++++++++ .../tests/unoptimized_query_oracle.go | 6 +- pkg/testutils/floatcmp/floatcmp.go | 18 +-- 6 files changed, 252 insertions(+), 18 deletions(-) create mode 100644 pkg/cmd/roachtest/tests/query_comparison_util_test.go diff --git a/pkg/cmd/roachtest/tests/BUILD.bazel b/pkg/cmd/roachtest/tests/BUILD.bazel index d25c811df89a..c3def922cc58 100644 --- a/pkg/cmd/roachtest/tests/BUILD.bazel +++ b/pkg/cmd/roachtest/tests/BUILD.bazel @@ -288,6 +288,7 @@ go_test( srcs = [ "blocklist_test.go", "drt_test.go", + "query_comparison_util_test.go", "tpcc_test.go", "util_load_group_test.go", ":mocks_drt", # keep diff --git a/pkg/cmd/roachtest/tests/costfuzz.go b/pkg/cmd/roachtest/tests/costfuzz.go index 1a1291826f4f..4b82f0812b7c 100644 --- a/pkg/cmd/roachtest/tests/costfuzz.go +++ b/pkg/cmd/roachtest/tests/costfuzz.go @@ -107,7 +107,11 @@ func runCostFuzzQuery(smither *sqlsmith.Smither, rnd *rand.Rand, h queryComparis return nil } - if diff := unsortedMatricesDiff(controlRows, perturbRows); diff != "" { + diff, err := unsortedMatricesDiffWithFloatComp(controlRows, perturbRows, h.colTypes) + if err != nil { + return err + } + if diff != "" { // We have a mismatch in the perturbed vs control query outputs. h.logStatements() h.logVerboseOutput() diff --git a/pkg/cmd/roachtest/tests/query_comparison_util.go b/pkg/cmd/roachtest/tests/query_comparison_util.go index 949b18f6144e..3f91f24524b0 100644 --- a/pkg/cmd/roachtest/tests/query_comparison_util.go +++ b/pkg/cmd/roachtest/tests/query_comparison_util.go @@ -303,6 +303,7 @@ type queryComparisonHelper struct { statements []string statementsAndExplains []sqlAndOutput + colTypes []string } // runQuery runs the given query and returns the output. If the stmt doesn't @@ -327,6 +328,14 @@ func (h *queryComparisonHelper) runQuery(stmt string) ([][]string, error) { return nil, err } defer rows.Close() + cts, err := rows.ColumnTypes() + if err != nil { + return nil, err + } + h.colTypes = make([]string, len(cts)) + for i, ct := range cts { + h.colTypes[i] = ct.DatabaseTypeName() + } return sqlutils.RowsToStrMatrix(rows) } @@ -384,6 +393,96 @@ func (h *queryComparisonHelper) makeError(err error, msg string) error { return errors.Wrapf(err, "%s. %d statements run", msg, h.stmtNo) } +func joinAndSortRows(rowMatrix1, rowMatrix2 [][]string, sep string, +) (rows1, rows2 []string) { + for _, row := range rowMatrix1 { + rows1 = append(rows1, strings.Join(row[:], sep)) + } + for _, row := range rowMatrix2 { + rows2 = append(rows2, strings.Join(row[:], sep)) + } + sort.Strings(rows1) + sort.Strings(rows2) + return rows1, rows2 +} + +// unsortedMatricesDiffWithFloatComp sorts and compares the rows in rowMatrix1 +// to rowMatrix2 and outputs a diff or message related to the comparison. If a +// string comparison of the rows fails, and they contain floats or decimals, it +// performs an approximate comparison of the values. +func unsortedMatricesDiffWithFloatComp( + rowMatrix1, rowMatrix2 [][]string, colTypes []string, +) (string, error) { + rows1, rows2 := joinAndSortRows(rowMatrix1, rowMatrix2, ",") + result := cmp.Diff(rows1, rows2) + if result == "" { + return result, nil + } + if len(rows1) != len(rows2) || len(colTypes) != len(rowMatrix1[0]) || len(colTypes) != len(rowMatrix2[0]) { + return result, nil + } + var needApproxMatch bool + for i := range colTypes { + // On s390x, check that values for both float and decimal coltypes are + // approximately equal to take into account platform differences in floating + // point calculations. On other architectures, check float values only. + if (runtime.GOARCH == "s390x" && colTypes[i] == "DECIMAL") || + colTypes[i] == "FLOAT4" || colTypes[i] == "FLOAT8" { + needApproxMatch = true + break + } + } + if !needApproxMatch { + return result, nil + } + // Use an unlikely string as a separator so that we can make a comparison + // using sorted rows. We don't use the rows sorted above because splitting + // the rows could be ambiguous. + sep := ",unsortedMatricesDiffWithFloatComp separator," + rows1, rows2 = joinAndSortRows(rowMatrix1, rowMatrix2, sep) + for i := range rows1 { + // Split the sorted rows. + row1 := strings.Split(rows1[i], sep) + row2 := strings.Split(rows2[i], sep) + + for j := range row1 { + if runtime.GOARCH == "s390x" && colTypes[j] == "DECIMAL" { + // On s390x, check that values for both float and decimal coltypes are + // approximately equal to take into account platform differences in floating + // point calculations. On other architectures, check float values only. + match, err := floatcmp.FloatsMatchApprox(row1[j], row2[j]) + if err != nil { + return "", err + } + if !match { + return result, nil + } + } else if colTypes[j] == "FLOAT4" || colTypes[j] == "FLOAT8" { + // Check that float values are approximately equal. + var err error + var match bool + if runtime.GOARCH == "s390x" { + match, err = floatcmp.FloatsMatchApprox(row1[j], row2[j]) + } else { + match, err = floatcmp.FloatsMatch(row1[j], row2[j]) + } + if err != nil { + return "", err + } + if !match { + return result, nil + } + } else { + // Check that other columns are equal with a string comparison. + if row1[j] != row2[j] { + return result, nil + } + } + } + } + return "", nil +} + // unsortedMatricesDiff sorts and compares rows of data. func unsortedMatricesDiff(rowMatrix1, rowMatrix2 [][]string) string { var rows1 []string diff --git a/pkg/cmd/roachtest/tests/query_comparison_util_test.go b/pkg/cmd/roachtest/tests/query_comparison_util_test.go new file mode 100644 index 000000000000..ac1d5452eefe --- /dev/null +++ b/pkg/cmd/roachtest/tests/query_comparison_util_test.go @@ -0,0 +1,140 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package tests + +import ( + "testing" + + "github.com/cockroachdb/cockroach/pkg/util/leaktest" +) + +// TestUnsortedMatricesDiff is a unit test for the +// unsortedMatricesDiffWithFloatComp() and unsortedMatricesDiff() utility +// functions. +func TestUnsortedMatricesDiff(t *testing.T) { + defer leaktest.AfterTest(t)() + tcs := []struct { + name string + colTypes []string + t1, t2 [][]string + exactMatch bool + approxMatch bool + }{ + { + name: "float exact match", + colTypes: []string{"FLOAT8"}, + t1: [][]string{{"1.2345678901234567"}}, + t2: [][]string{{"1.2345678901234567"}}, + exactMatch: true, + }, + { + name: "float approx match", + colTypes: []string{"FLOAT8"}, + t1: [][]string{{"1.2345678901234563"}}, + t2: [][]string{{"1.2345678901234564"}}, + exactMatch: false, + approxMatch: true, + }, + { + name: "float no match", + colTypes: []string{"FLOAT8"}, + t1: [][]string{{"1.234567890123"}}, + t2: [][]string{{"1.234567890124"}}, + exactMatch: false, + approxMatch: false, + }, + { + name: "multi float approx match", + colTypes: []string{"FLOAT8", "FLOAT8"}, + t1: [][]string{{"1.2345678901234567", "1.2345678901234567"}}, + t2: [][]string{{"1.2345678901234567", "1.2345678901234568"}}, + exactMatch: false, + approxMatch: true, + }, + { + name: "string no match", + colTypes: []string{"STRING"}, + t1: [][]string{{"hello"}}, + t2: [][]string{{"world"}}, + exactMatch: false, + approxMatch: false, + }, + { + name: "mixed types match", + colTypes: []string{"STRING", "FLOAT8"}, + t1: [][]string{{"hello", "1.2345678901234567"}}, + t2: [][]string{{"hello", "1.2345678901234567"}}, + exactMatch: true, + }, + { + name: "mixed types float approx match", + colTypes: []string{"STRING", "FLOAT8"}, + t1: [][]string{{"hello", "1.23456789012345678"}}, + t2: [][]string{{"hello", "1.23456789012345679"}}, + exactMatch: false, + approxMatch: true, + }, + { + name: "mixed types no match", + colTypes: []string{"STRING", "FLOAT8"}, + t1: [][]string{{"hello", "1.2345678901234567"}}, + t2: [][]string{{"world", "1.2345678901234567"}}, + exactMatch: false, + approxMatch: false, + }, + { + name: "different col count", + colTypes: []string{"STRING"}, + t1: [][]string{{"hello", "1.2345678901234567"}}, + t2: [][]string{{"world", "1.2345678901234567"}}, + exactMatch: false, + approxMatch: false, + }, + { + name: "different row count", + colTypes: []string{"STRING", "FLOAT8"}, + t1: [][]string{{"hello", "1.2345678901234567"}, {"aloha", "2.345"}}, + t2: [][]string{{"world", "1.2345678901234567"}}, + exactMatch: false, + approxMatch: false, + }, + { + name: "multi row unsorted", + colTypes: []string{"STRING", "FLOAT8"}, + t1: [][]string{{"hello", "1.2345678901234567"}, {"world", "1.2345678901234560"}}, + t2: [][]string{{"world", "1.2345678901234560"}, {"hello", "1.2345678901234567"}}, + exactMatch: true, + }, + } + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + match := unsortedMatricesDiff(tc.t1, tc.t2) + if tc.exactMatch && match != "" { + t.Fatalf("unsortedMatricesDiff: expected exact match, got diff: %s", match) + } else if !tc.exactMatch && match == "" { + t.Fatalf("unsortedMatricesDiff: expected no exact match, got no diff") + } + + var err error + match, err = unsortedMatricesDiffWithFloatComp(tc.t1, tc.t2, tc.colTypes) + if err != nil { + t.Fatal(err) + } + if tc.exactMatch && match != "" { + t.Fatalf("unsortedMatricesDiffWithFloatComp: expected exact match, got diff: %s", match) + } else if !tc.exactMatch && tc.approxMatch && match != "" { + t.Fatalf("unsortedMatricesDiffWithFloatComp: expected approx match, got diff: %s", match) + } else if !tc.exactMatch && !tc.approxMatch && match == "" { + t.Fatalf("unsortedMatricesDiffWithFloatComp: expected no approx match, got no diff") + } + }) + } +} diff --git a/pkg/cmd/roachtest/tests/unoptimized_query_oracle.go b/pkg/cmd/roachtest/tests/unoptimized_query_oracle.go index 66f1398d9b6d..ba1eb6058524 100644 --- a/pkg/cmd/roachtest/tests/unoptimized_query_oracle.go +++ b/pkg/cmd/roachtest/tests/unoptimized_query_oracle.go @@ -174,7 +174,11 @@ func runUnoptimizedQueryOracleImpl( //nolint:returnerrcheck return nil } - if diff := unsortedMatricesDiff(unoptimizedRows, optimizedRows); diff != "" { + diff, err := unsortedMatricesDiffWithFloatComp(unoptimizedRows, optimizedRows, h.colTypes) + if err != nil { + return err + } + if diff != "" { // We have a mismatch in the unoptimized vs optimized query outputs. verboseLogging = true return h.makeError(errors.Newf( diff --git a/pkg/testutils/floatcmp/floatcmp.go b/pkg/testutils/floatcmp/floatcmp.go index 8c7f0d9f298e..6b87db4820fb 100644 --- a/pkg/testutils/floatcmp/floatcmp.go +++ b/pkg/testutils/floatcmp/floatcmp.go @@ -121,23 +121,9 @@ func FloatsMatch(expectedString, actualString string) (bool, error) { actual = math.Abs(actual) // Check that 15 significant digits match. We do so by normalizing the // numbers and then checking one digit at a time. - // - // normalize converts f to base * 10**power representation where base is in - // [1.0, 10.0) range. - normalize := func(f float64) (base float64, power int) { - for f >= 10 { - f = f / 10 - power++ - } - for f < 1 { - f *= 10 - power-- - } - return f, power - } var expPower, actPower int - expected, expPower = normalize(expected) - actual, actPower = normalize(actual) + expected, expPower = math.Frexp(expected) + actual, actPower = math.Frexp(actual) if expPower != actPower { return false, nil }