diff --git a/pkg/internal/sqlsmith/BUILD.bazel b/pkg/internal/sqlsmith/BUILD.bazel index f26ca7d06539..766d36a17806 100644 --- a/pkg/internal/sqlsmith/BUILD.bazel +++ b/pkg/internal/sqlsmith/BUILD.bazel @@ -58,6 +58,7 @@ go_test( "//pkg/testutils/sqlutils", "//pkg/testutils/testcluster", "//pkg/util/leaktest", + "//pkg/util/log", "//pkg/util/randutil", ], ) diff --git a/pkg/internal/sqlsmith/sqlsmith_test.go b/pkg/internal/sqlsmith/sqlsmith_test.go index a6548040835e..88e1ff073a50 100644 --- a/pkg/internal/sqlsmith/sqlsmith_test.go +++ b/pkg/internal/sqlsmith/sqlsmith_test.go @@ -23,6 +23,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/testutils/serverutils" "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/randutil" ) @@ -36,6 +37,7 @@ var ( // TestSetups verifies that all setups generate executable SQL. func TestSetups(t *testing.T) { defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) for name, setup := range Setups { t.Run(name, func(t *testing.T) { @@ -80,6 +82,7 @@ func TestSetups(t *testing.T) { // false-negative. func TestRandTableInserts(t *testing.T) { defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) ctx := context.Background() s, sqlDB, _ := serverutils.StartServer(t, base.TestServerArgs{}) @@ -142,6 +145,7 @@ func TestRandTableInserts(t *testing.T) { // sometimes put them into bad states that the parser would never do. func TestGenerateParse(t *testing.T) { defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) defer utilccl.TestingEnableEnterprise()() ctx := context.Background() diff --git a/pkg/sql/parser/BUILD.bazel b/pkg/sql/parser/BUILD.bazel index 37c4db40dabc..a4eb9a486c01 100644 --- a/pkg/sql/parser/BUILD.bazel +++ b/pkg/sql/parser/BUILD.bazel @@ -57,6 +57,7 @@ go_test( deps = [ "//pkg/sql/lexbase", "//pkg/sql/pgwire/pgerror", + "//pkg/sql/randgen", "//pkg/sql/sem/builtins", "//pkg/sql/sem/tree", "//pkg/sql/sem/tree/treebin", @@ -65,9 +66,11 @@ go_test( "//pkg/testutils/sqlutils", "//pkg/util/leaktest", "//pkg/util/log", + "//pkg/util/randutil", "@com_github_cockroachdb_datadriven//:datadriven", "@com_github_cockroachdb_errors//:errors", "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", ], ) diff --git a/pkg/sql/parser/parse.go b/pkg/sql/parser/parse.go index 094313889e9a..8e94556d9873 100644 --- a/pkg/sql/parser/parse.go +++ b/pkg/sql/parser/parse.go @@ -382,6 +382,19 @@ func GetTypeFromValidSQLSyntax(sql string) (tree.ResolvableTypeReference, error) if err != nil { return nil, err } + return GetTypeFromCastOrCollate(expr) +} + +// GetTypeFromCastOrCollate returns the type of the given tree.Expr. The method +// assumes that the expression is either tree.CastExpr or tree.CollateExpr +// (which wraps the tree.CastExpr). +func GetTypeFromCastOrCollate(expr tree.Expr) (tree.ResolvableTypeReference, error) { + // COLLATE clause has lower precedence than the cast, so if we have + // something like `1::STRING COLLATE en`, it'll be parsed as + // CollateExpr(CastExpr). + if collate, ok := expr.(*tree.CollateExpr); ok { + return types.MakeCollatedString(types.String, collate.Locale), nil + } cast, ok := expr.(*tree.CastExpr) if !ok { diff --git a/pkg/sql/parser/parse_test.go b/pkg/sql/parser/parse_test.go index 39117955d34c..f24c5cea0261 100644 --- a/pkg/sql/parser/parse_test.go +++ b/pkg/sql/parser/parse_test.go @@ -20,6 +20,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/parser" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" + "github.com/cockroachdb/cockroach/pkg/sql/randgen" _ "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/sem/tree/treebin" @@ -27,9 +28,11 @@ import ( "github.com/cockroachdb/cockroach/pkg/testutils" "github.com/cockroachdb/cockroach/pkg/testutils/sqlutils" _ "github.com/cockroachdb/cockroach/pkg/util/log" // for flags + "github.com/cockroachdb/cockroach/pkg/util/randutil" "github.com/cockroachdb/datadriven" "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // TestParseDataDriven verifies that we can parse the supplied SQL and regenerate the SQL @@ -761,3 +764,21 @@ func BenchmarkParse(b *testing.B) { }) } } + +func TestGetTypeFromValidSQLSyntax(t *testing.T) { + rng, _ := randutil.NewTestRand() + + const numRuns = 1000 + for i := 0; i < numRuns; i++ { + orig := randgen.RandType(rng) + typeRef, err := parser.GetTypeFromValidSQLSyntax(orig.SQLString()) + require.NoError(t, err) + actual, ok := tree.GetStaticallyKnownType(typeRef) + require.True(t, ok) + // TODO(yuzefovich): ideally, we'd assert that the returned type is + // equal to the original one; however, there are some subtle differences + // at the moment (like the width might only be set on the returned + // type), so we simply assert that the OIDs are the same. + require.Equal(t, orig.Oid(), actual.Oid()) + } +} diff --git a/pkg/sql/randgen/BUILD.bazel b/pkg/sql/randgen/BUILD.bazel index 411d1c94de0b..6af9af39891b 100644 --- a/pkg/sql/randgen/BUILD.bazel +++ b/pkg/sql/randgen/BUILD.bazel @@ -55,6 +55,7 @@ go_test( "main_test.go", "mutator_test.go", "schema_test.go", + "types_test.go", ], embed = [":randgen"], deps = [ @@ -63,11 +64,13 @@ go_test( "//pkg/security/securitytest", "//pkg/server", "//pkg/sql/sem/tree", + "//pkg/sql/types", "//pkg/testutils/serverutils", "//pkg/testutils/sqlutils", "//pkg/testutils/testcluster", "//pkg/util/leaktest", "//pkg/util/randutil", + "@com_github_cockroachdb_errors//:errors", "@com_github_stretchr_testify//require", ], ) diff --git a/pkg/sql/randgen/mutator.go b/pkg/sql/randgen/mutator.go index 940f849671bf..db7935a8cb50 100644 --- a/pkg/sql/randgen/mutator.go +++ b/pkg/sql/randgen/mutator.go @@ -218,6 +218,12 @@ func statisticsMutator( return } colType := tree.MustBeStaticallyKnownType(col.Type) + if colType.Family() == types.CollatedStringFamily { + // Collated strings are not roundtrippable during + // encoding/decoding, so we cannot always make a valid + // histogram. + return + } h := randHistogram(rng, colType) stat := colStats[col.Name] if err := stat.SetHistogram(&h); err != nil { diff --git a/pkg/sql/randgen/type.go b/pkg/sql/randgen/type.go index f290feca46f3..d05a964e8bd9 100644 --- a/pkg/sql/randgen/type.go +++ b/pkg/sql/randgen/type.go @@ -57,6 +57,14 @@ func init() { } } + // Add a collated string separately (since it shares the oid with the STRING + // type and, thus, wasn't included above). + collatedStringType := types.MakeCollatedString(types.String, "en" /* locale */) + SeedTypes = append(SeedTypes, collatedStringType) + if IsAllowedForArray(collatedStringType) { + arrayContentsTypes = append(arrayContentsTypes, collatedStringType) + } + // Sort these so randomly chosen indexes always point to the same element. sort.Slice(SeedTypes, func(i, j int) bool { return SeedTypes[i].String() < SeedTypes[j].String() diff --git a/pkg/sql/randgen/types_test.go b/pkg/sql/randgen/types_test.go new file mode 100644 index 000000000000..97eea216631a --- /dev/null +++ b/pkg/sql/randgen/types_test.go @@ -0,0 +1,51 @@ +// Copyright 2023 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package randgen + +import ( + "fmt" + "testing" + + "github.com/cockroachdb/cockroach/pkg/sql/types" + "github.com/cockroachdb/cockroach/pkg/util/leaktest" + "github.com/cockroachdb/errors" +) + +// TestSeedTypes verifies that at least one representative type is included into +// SeedTypes for all (with a few exceptions) type families. +func TestSeedTypes(t *testing.T) { + defer leaktest.AfterTest(t)() + + noFamilyRepresentative := make(map[types.Family]struct{}) +loop: + for id := range types.Family_name { + familyID := types.Family(id) + switch familyID { + case types.EnumFamily: + // Enums need to created separately. + continue loop + case types.UnknownFamily, types.AnyFamily: + // These are not included on purpose. + continue loop + } + noFamilyRepresentative[familyID] = struct{}{} + } + for _, typ := range SeedTypes { + delete(noFamilyRepresentative, typ.Family()) + } + if len(noFamilyRepresentative) > 0 { + s := "no representative for " + for f := range noFamilyRepresentative { + s += fmt.Sprintf("%s (%d) ", types.Family_name[int32(f)], f) + } + t.Fatal(errors.Errorf("%s", s)) + } +} diff --git a/pkg/sql/sem/tree/datum.go b/pkg/sql/sem/tree/datum.go index 3b313c032f4b..581ab0ea4555 100644 --- a/pkg/sql/sem/tree/datum.go +++ b/pkg/sql/sem/tree/datum.go @@ -5796,7 +5796,7 @@ func InferTypes(vals []string) []types.Family { // with a reasonable previous datum that is smaller than the given one. // // The return value is undefined if Datum.IsMin returns true or if the value is -// NaN of an infinity (for floats and decimals). +// NaN or an infinity (for floats and decimals). func DatumPrev( datum Datum, evalCtx *EvalContext, collationEnv *CollationEnvironment, ) (Datum, bool) { @@ -5833,16 +5833,6 @@ func DatumPrev( return nil, false } return NewDString(prev), true - case *DCollatedString: - prev, ok := prevString(d.Contents) - if !ok { - return nil, false - } - c, err := NewDCollatedString(prev, d.Locale, collationEnv) - if err != nil { - return nil, false - } - return c, true case *DBytes: prev, ok := prevString(string(*d)) if !ok { @@ -5855,8 +5845,8 @@ func DatumPrev( return NewDInterval(prev, types.DefaultIntervalTypeMetadata), true default: // TODO(yuzefovich): consider adding support for other datums that don't - // have Datum.Prev implementation (DBitArray, DGeography, DGeometry, - // DBox2D, DJSON, DArray). + // have Datum.Prev implementation (DCollatedString, DBitArray, + // DGeography, DGeometry, DBox2D, DJSON, DArray). return datum.Prev(evalCtx) } } @@ -5867,7 +5857,7 @@ func DatumPrev( // with a reasonable next datum that is greater than the given one. // // The return value is undefined if Datum.IsMax returns true or if the value is -// NaN of an infinity (for floats and decimals). +// NaN or an infinity (for floats and decimals). func DatumNext( datum Datum, evalCtx *EvalContext, collationEnv *CollationEnvironment, ) (Datum, bool) { @@ -5885,24 +5875,13 @@ func DatumNext( return nil, false } return &next, true - case *DCollatedString: - s := NewDString(d.Contents) - next, ok := s.Next(evalCtx) - if !ok { - return nil, false - } - c, err := NewDCollatedString(string(*next.(*DString)), d.Locale, collationEnv) - if err != nil { - return nil, false - } - return c, true case *DInterval: next := d.Add(duration.MakeDuration(1000000 /* nanos */, 0 /* days */, 0 /* months */)) return NewDInterval(next, types.DefaultIntervalTypeMetadata), true default: // TODO(yuzefovich): consider adding support for other datums that don't - // have Datum.Next implementation (DGeography, DGeometry, DBox2D, - // DJSON). + // have Datum.Next implementation (DCollatedString, DGeography, + // DGeometry, DBox2D, DJSON). return datum.Next(evalCtx) } } diff --git a/pkg/workload/schemachange/operation_generator.go b/pkg/workload/schemachange/operation_generator.go index 12344f603076..4b883f1580ec 100644 --- a/pkg/workload/schemachange/operation_generator.go +++ b/pkg/workload/schemachange/operation_generator.go @@ -3049,11 +3049,11 @@ func (og *operationGenerator) typeFromTypeName( if err != nil { return nil, errors.Wrapf(err, "typeFromTypeName: %s", typeName) } - typ, err := tree.ResolveType( - context.Background(), - stmt.AST.(*tree.Select).Select.(*tree.SelectClause).Exprs[0].Expr.(*tree.CastExpr).Type, - &txTypeResolver{tx: tx}, - ) + typRef, err := parser.GetTypeFromCastOrCollate(stmt.AST.(*tree.Select).Select.(*tree.SelectClause).Exprs[0].Expr) + if err != nil { + return nil, errors.Wrapf(err, "GetTypeFromCastOrCollate: %s", typeName) + } + typ, err := tree.ResolveType(ctx, typRef, &txTypeResolver{tx: tx}) if err != nil { return nil, errors.Wrapf(err, "ResolveType: %v", typeName) }