Skip to content

Commit

Permalink
Merge pull request #97492 from yuzefovich/backport22.1-96695-97435
Browse files Browse the repository at this point in the history
release-22.1: parser: fix GetTypeFromValidSQLSyntax for collated strings
  • Loading branch information
yuzefovich authored Feb 27, 2023
2 parents cf84e7e + 01e7612 commit e842323
Show file tree
Hide file tree
Showing 11 changed files with 121 additions and 32 deletions.
1 change: 1 addition & 0 deletions pkg/internal/sqlsmith/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ go_test(
"//pkg/testutils/sqlutils",
"//pkg/testutils/testcluster",
"//pkg/util/leaktest",
"//pkg/util/log",
"//pkg/util/randutil",
],
)
4 changes: 4 additions & 0 deletions pkg/internal/sqlsmith/sqlsmith_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/testutils/serverutils"
"github.com/cockroachdb/cockroach/pkg/testutils/sqlutils"
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/cockroach/pkg/util/randutil"
)

Expand All @@ -36,6 +37,7 @@ var (
// TestSetups verifies that all setups generate executable SQL.
func TestSetups(t *testing.T) {
defer leaktest.AfterTest(t)()
defer log.Scope(t).Close(t)

for name, setup := range Setups {
t.Run(name, func(t *testing.T) {
Expand Down Expand Up @@ -80,6 +82,7 @@ func TestSetups(t *testing.T) {
// false-negative.
func TestRandTableInserts(t *testing.T) {
defer leaktest.AfterTest(t)()
defer log.Scope(t).Close(t)

ctx := context.Background()
s, sqlDB, _ := serverutils.StartServer(t, base.TestServerArgs{})
Expand Down Expand Up @@ -142,6 +145,7 @@ func TestRandTableInserts(t *testing.T) {
// sometimes put them into bad states that the parser would never do.
func TestGenerateParse(t *testing.T) {
defer leaktest.AfterTest(t)()
defer log.Scope(t).Close(t)
defer utilccl.TestingEnableEnterprise()()

ctx := context.Background()
Expand Down
3 changes: 3 additions & 0 deletions pkg/sql/parser/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ go_test(
deps = [
"//pkg/sql/lexbase",
"//pkg/sql/pgwire/pgerror",
"//pkg/sql/randgen",
"//pkg/sql/sem/builtins",
"//pkg/sql/sem/tree",
"//pkg/sql/sem/tree/treebin",
Expand All @@ -65,9 +66,11 @@ go_test(
"//pkg/testutils/sqlutils",
"//pkg/util/leaktest",
"//pkg/util/log",
"//pkg/util/randutil",
"@com_github_cockroachdb_datadriven//:datadriven",
"@com_github_cockroachdb_errors//:errors",
"@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//require",
],
)

Expand Down
13 changes: 13 additions & 0 deletions pkg/sql/parser/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,19 @@ func GetTypeFromValidSQLSyntax(sql string) (tree.ResolvableTypeReference, error)
if err != nil {
return nil, err
}
return GetTypeFromCastOrCollate(expr)
}

// GetTypeFromCastOrCollate returns the type of the given tree.Expr. The method
// assumes that the expression is either tree.CastExpr or tree.CollateExpr
// (which wraps the tree.CastExpr).
func GetTypeFromCastOrCollate(expr tree.Expr) (tree.ResolvableTypeReference, error) {
// COLLATE clause has lower precedence than the cast, so if we have
// something like `1::STRING COLLATE en`, it'll be parsed as
// CollateExpr(CastExpr).
if collate, ok := expr.(*tree.CollateExpr); ok {
return types.MakeCollatedString(types.String, collate.Locale), nil
}

cast, ok := expr.(*tree.CastExpr)
if !ok {
Expand Down
21 changes: 21 additions & 0 deletions pkg/sql/parser/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,19 @@ import (

"github.com/cockroachdb/cockroach/pkg/sql/parser"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
"github.com/cockroachdb/cockroach/pkg/sql/randgen"
_ "github.com/cockroachdb/cockroach/pkg/sql/sem/builtins"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree/treebin"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree/treecmp"
"github.com/cockroachdb/cockroach/pkg/testutils"
"github.com/cockroachdb/cockroach/pkg/testutils/sqlutils"
_ "github.com/cockroachdb/cockroach/pkg/util/log" // for flags
"github.com/cockroachdb/cockroach/pkg/util/randutil"
"github.com/cockroachdb/datadriven"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// TestParseDataDriven verifies that we can parse the supplied SQL and regenerate the SQL
Expand Down Expand Up @@ -761,3 +764,21 @@ func BenchmarkParse(b *testing.B) {
})
}
}

func TestGetTypeFromValidSQLSyntax(t *testing.T) {
rng, _ := randutil.NewTestRand()

const numRuns = 1000
for i := 0; i < numRuns; i++ {
orig := randgen.RandType(rng)
typeRef, err := parser.GetTypeFromValidSQLSyntax(orig.SQLString())
require.NoError(t, err)
actual, ok := tree.GetStaticallyKnownType(typeRef)
require.True(t, ok)
// TODO(yuzefovich): ideally, we'd assert that the returned type is
// equal to the original one; however, there are some subtle differences
// at the moment (like the width might only be set on the returned
// type), so we simply assert that the OIDs are the same.
require.Equal(t, orig.Oid(), actual.Oid())
}
}
3 changes: 3 additions & 0 deletions pkg/sql/randgen/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ go_test(
"main_test.go",
"mutator_test.go",
"schema_test.go",
"types_test.go",
],
embed = [":randgen"],
deps = [
Expand All @@ -63,11 +64,13 @@ go_test(
"//pkg/security/securitytest",
"//pkg/server",
"//pkg/sql/sem/tree",
"//pkg/sql/types",
"//pkg/testutils/serverutils",
"//pkg/testutils/sqlutils",
"//pkg/testutils/testcluster",
"//pkg/util/leaktest",
"//pkg/util/randutil",
"@com_github_cockroachdb_errors//:errors",
"@com_github_stretchr_testify//require",
],
)
6 changes: 6 additions & 0 deletions pkg/sql/randgen/mutator.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,12 @@ func statisticsMutator(
return
}
colType := tree.MustBeStaticallyKnownType(col.Type)
if colType.Family() == types.CollatedStringFamily {
// Collated strings are not roundtrippable during
// encoding/decoding, so we cannot always make a valid
// histogram.
return
}
h := randHistogram(rng, colType)
stat := colStats[col.Name]
if err := stat.SetHistogram(&h); err != nil {
Expand Down
8 changes: 8 additions & 0 deletions pkg/sql/randgen/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ func init() {
}
}

// Add a collated string separately (since it shares the oid with the STRING
// type and, thus, wasn't included above).
collatedStringType := types.MakeCollatedString(types.String, "en" /* locale */)
SeedTypes = append(SeedTypes, collatedStringType)
if IsAllowedForArray(collatedStringType) {
arrayContentsTypes = append(arrayContentsTypes, collatedStringType)
}

// Sort these so randomly chosen indexes always point to the same element.
sort.Slice(SeedTypes, func(i, j int) bool {
return SeedTypes[i].String() < SeedTypes[j].String()
Expand Down
51 changes: 51 additions & 0 deletions pkg/sql/randgen/types_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright 2023 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package randgen

import (
"fmt"
"testing"

"github.com/cockroachdb/cockroach/pkg/sql/types"
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
"github.com/cockroachdb/errors"
)

// TestSeedTypes verifies that at least one representative type is included into
// SeedTypes for all (with a few exceptions) type families.
func TestSeedTypes(t *testing.T) {
defer leaktest.AfterTest(t)()

noFamilyRepresentative := make(map[types.Family]struct{})
loop:
for id := range types.Family_name {
familyID := types.Family(id)
switch familyID {
case types.EnumFamily:
// Enums need to created separately.
continue loop
case types.UnknownFamily, types.AnyFamily:
// These are not included on purpose.
continue loop
}
noFamilyRepresentative[familyID] = struct{}{}
}
for _, typ := range SeedTypes {
delete(noFamilyRepresentative, typ.Family())
}
if len(noFamilyRepresentative) > 0 {
s := "no representative for "
for f := range noFamilyRepresentative {
s += fmt.Sprintf("%s (%d) ", types.Family_name[int32(f)], f)
}
t.Fatal(errors.Errorf("%s", s))
}
}
33 changes: 6 additions & 27 deletions pkg/sql/sem/tree/datum.go
Original file line number Diff line number Diff line change
Expand Up @@ -5796,7 +5796,7 @@ func InferTypes(vals []string) []types.Family {
// with a reasonable previous datum that is smaller than the given one.
//
// The return value is undefined if Datum.IsMin returns true or if the value is
// NaN of an infinity (for floats and decimals).
// NaN or an infinity (for floats and decimals).
func DatumPrev(
datum Datum, evalCtx *EvalContext, collationEnv *CollationEnvironment,
) (Datum, bool) {
Expand Down Expand Up @@ -5833,16 +5833,6 @@ func DatumPrev(
return nil, false
}
return NewDString(prev), true
case *DCollatedString:
prev, ok := prevString(d.Contents)
if !ok {
return nil, false
}
c, err := NewDCollatedString(prev, d.Locale, collationEnv)
if err != nil {
return nil, false
}
return c, true
case *DBytes:
prev, ok := prevString(string(*d))
if !ok {
Expand All @@ -5855,8 +5845,8 @@ func DatumPrev(
return NewDInterval(prev, types.DefaultIntervalTypeMetadata), true
default:
// TODO(yuzefovich): consider adding support for other datums that don't
// have Datum.Prev implementation (DBitArray, DGeography, DGeometry,
// DBox2D, DJSON, DArray).
// have Datum.Prev implementation (DCollatedString, DBitArray,
// DGeography, DGeometry, DBox2D, DJSON, DArray).
return datum.Prev(evalCtx)
}
}
Expand All @@ -5867,7 +5857,7 @@ func DatumPrev(
// with a reasonable next datum that is greater than the given one.
//
// The return value is undefined if Datum.IsMax returns true or if the value is
// NaN of an infinity (for floats and decimals).
// NaN or an infinity (for floats and decimals).
func DatumNext(
datum Datum, evalCtx *EvalContext, collationEnv *CollationEnvironment,
) (Datum, bool) {
Expand All @@ -5885,24 +5875,13 @@ func DatumNext(
return nil, false
}
return &next, true
case *DCollatedString:
s := NewDString(d.Contents)
next, ok := s.Next(evalCtx)
if !ok {
return nil, false
}
c, err := NewDCollatedString(string(*next.(*DString)), d.Locale, collationEnv)
if err != nil {
return nil, false
}
return c, true
case *DInterval:
next := d.Add(duration.MakeDuration(1000000 /* nanos */, 0 /* days */, 0 /* months */))
return NewDInterval(next, types.DefaultIntervalTypeMetadata), true
default:
// TODO(yuzefovich): consider adding support for other datums that don't
// have Datum.Next implementation (DGeography, DGeometry, DBox2D,
// DJSON).
// have Datum.Next implementation (DCollatedString, DGeography,
// DGeometry, DBox2D, DJSON).
return datum.Next(evalCtx)
}
}
10 changes: 5 additions & 5 deletions pkg/workload/schemachange/operation_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -3049,11 +3049,11 @@ func (og *operationGenerator) typeFromTypeName(
if err != nil {
return nil, errors.Wrapf(err, "typeFromTypeName: %s", typeName)
}
typ, err := tree.ResolveType(
context.Background(),
stmt.AST.(*tree.Select).Select.(*tree.SelectClause).Exprs[0].Expr.(*tree.CastExpr).Type,
&txTypeResolver{tx: tx},
)
typRef, err := parser.GetTypeFromCastOrCollate(stmt.AST.(*tree.Select).Select.(*tree.SelectClause).Exprs[0].Expr)
if err != nil {
return nil, errors.Wrapf(err, "GetTypeFromCastOrCollate: %s", typeName)
}
typ, err := tree.ResolveType(ctx, typRef, &txTypeResolver{tx: tx})
if err != nil {
return nil, errors.Wrapf(err, "ResolveType: %v", typeName)
}
Expand Down

0 comments on commit e842323

Please sign in to comment.