Skip to content

Commit

Permalink
tree: apply casts in typeCheckSameTypedExprs for non-equivalent types
Browse files Browse the repository at this point in the history
Function typeCheckSameTypedExprs is updated by cockroachdb#108387 and cockroachdb#109635 to
apply implicit CASTs to the different expressions input to an array,
tuple, case expression or other similar expression, to coerce them all
to a common data type, instead of erroring out. This is not done,
however, if the data types of these expressions are not equivalent with
each other, causing some cases to error out.

The fix is to match Postgres behavior and apply the casts even for
non-equivalent types.

Informs: cockroachdb#109105

Release note (sql change): This patch modifies type checking of arrays,
tuples, and case statements to allow implicit casting of scalar
expressions referenced in these constructs to a common data type, even
for types in different type families, as long at the implicit cast is
legal.
  • Loading branch information
Mark Sirek committed Sep 10, 2023
1 parent 9e526be commit adf7103
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 21 deletions.
36 changes: 36 additions & 0 deletions pkg/sql/logictest/testdata/logic_test/typing
Original file line number Diff line number Diff line change
Expand Up @@ -313,3 +313,39 @@ SELECT (CASE WHEN t108360_1.t > t108360_2.c THEN t108360_1.t ELSE t108360_2.c EN
FROM t108360_1, t108360_2
WHERE t108360_1.t = (CASE WHEN t108360_1.t > t108360_2.c THEN t108360_1.t ELSE t108360_2.c END);
----

subtest implicit_cast_both_directions

# The following results match Postgres.

# The DATE should be implicitly CAST to a timestamp.
query T
SELECT ARRAY['1989-01-15':::DATE,'2005-11-14 21:51:05.000581':::TIMESTAMP];
----
{"1989-01-15 00:00:00","2005-11-14 21:51:05.000581"}

# The DATE should be implicitly CAST to a timestamp.
query T
SELECT ARRAY['2005-11-14 21:51:05.000581':::TIMESTAMP, '1989-01-15':::DATE];
----
{"2005-11-14 21:51:05.000581","1989-01-15 00:00:00"}

# The DATE should be implicitly CAST to a timestamp.
query T
SELECT COALESCE('2005-11-14 21:51:05.000581'::TIMESTAMP, '1989-01-15'::DATE);
----
2005-11-14 21:51:05.000581 +0000 +0000

# The DATE should be implicitly CAST to a timestamp.
query T
SELECT COALESCE('1989-01-15'::DATE, '2005-11-14 21:51:05.000581'::TIMESTAMP);
----
1989-01-15 00:00:00 +0000 +0000

# We should be able to implicitly cast to float.
query O
SELECT array_cat_agg(
ARRAY[(41664367676:::INT8,),(NULL,),((-0.12116245180368423):::FLOAT8,),((-0.42116245180368423):::DECIMAL,)]
);
----
{(4.1664367676e+10),(),(-0.12116245180368423),(-0.4211624518036842)}
9 changes: 1 addition & 8 deletions pkg/sql/sem/tree/type_check.go
Original file line number Diff line number Diff line change
Expand Up @@ -2685,11 +2685,6 @@ func typeCheckSameTypedExprs(
candidateType = typ
}
}
// TODO(mgartner): Remove this check now that we check the types
// below.
if typ := typedExpr.ResolvedType(); !(typ.Equivalent(candidateType) || typ.Family() == types.UnknownFamily) {
return nil, nil, unexpectedTypeError(exprs[i], candidateType, typ)
}
typedExprs[i] = typedExpr
}
if !constIdxs.Empty() {
Expand All @@ -2710,9 +2705,7 @@ func typeCheckSameTypedExprs(
// https://www.postgresql.org/docs/15/typeconv-union-case.html
for i, e := range typedExprs {
typ := e.ResolvedType()
// TODO(mgartner): There should probably be a cast if the types are
// not identical, not just if the types are not equivalent.
if typ.Equivalent(candidateType) || typ.Family() == types.UnknownFamily {
if typ.Identical(candidateType) || typ.Family() == types.UnknownFamily {
continue
}
if !cast.ValidCast(typ, candidateType, cast.ContextImplicit) {
Expand Down
58 changes: 45 additions & 13 deletions pkg/sql/sem/tree/type_check_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,16 +219,19 @@ func TestTypeCheckSameTypedExprs(t *testing.T) {
{nil, nil, exprs(decConst("1.1")), types.Decimal, nil},
{nil, nil, exprs(intConst("1"), decConst("1.0")), types.Decimal, nil},
{nil, nil, exprs(intConst("1"), decConst("1.1")), types.Decimal, nil},
{nil, nil, exprs(decConst("1.1"), intConst("1")), types.Decimal, nil},
// Resolved exprs.
{nil, nil, exprs(dint(1)), types.Int, nil},
{nil, nil, exprs(ddecimal(1)), types.Decimal, nil},
{nil, nil, exprs(ddecimal(1), dfloat(1), dint(1)), types.Float, nil},
// Mixing constants and resolved exprs.
{nil, nil, exprs(dint(1), intConst("1")), types.Int, nil},
{nil, nil, exprs(dint(1), decConst("1.0")), types.Int, nil}, // This is what the AST would look like after folding (0.6 + 0.4).
{nil, nil, exprs(dint(1), dint(1)), types.Int, nil},
{nil, nil, exprs(ddecimal(1), intConst("1")), types.Decimal, nil},
{nil, nil, exprs(ddecimal(1), decConst("1.1")), types.Decimal, nil},
{nil, nil, exprs(ddecimal(1), ddecimal(1)), types.Decimal, nil},
{nil, nil, exprs(decConst("1.0"), dfloat(1), dint(1)), types.Float, nil},
// Mixing resolved placeholders with constants and resolved exprs.
{ptypesDecimal, nil, exprs(ddecimal(1), placeholder(0)), types.Decimal, ptypesDecimal},
{ptypesDecimal, nil, exprs(intConst("1"), placeholder(0)), types.Decimal, ptypesDecimal},
Expand All @@ -251,6 +254,8 @@ func TestTypeCheckSameTypedExprs(t *testing.T) {
{nil, nil, exprs(dnull, ddecimal(1), decConst("1.1")), types.Decimal, nil},
{nil, nil, exprs(dnull, ddecimal(1), decConst("1.1")), types.Decimal, nil},
{nil, nil, exprs(dnull, intConst("1"), decConst("1.1")), types.Decimal, nil},
{nil, nil, exprs(dnull, intConst("1"), dfloat(1), decConst("1.1")), types.Float, nil},
{nil, nil, exprs(dnull, intConst("1"), decConst("1.1"), dfloat(1)), types.Float, nil},
// Verify desired type when possible.
{nil, types.Int, exprs(intConst("1")), types.Int, nil},
{nil, types.Int, exprs(dint(1)), types.Int, nil},
Expand All @@ -262,6 +267,10 @@ func TestTypeCheckSameTypedExprs(t *testing.T) {
{nil, types.Int, exprs(intConst("1"), decConst("1.0")), types.Int, nil},
{nil, types.Int, exprs(intConst("1"), decConst("1.1")), types.Decimal, nil},
{nil, types.Decimal, exprs(intConst("1"), decConst("1.1")), types.Decimal, nil},
{nil, types.Int, exprs(intConst("1"), dfloat(1), decConst("1.1")), types.Float, nil},
{nil, types.Int, exprs(intConst("1"), decConst("1.1"), dfloat(1)), types.Float, nil},
{nil, types.Decimal, exprs(intConst("1"), dfloat(1), decConst("1.1")), types.Float, nil},
{nil, types.Decimal, exprs(intConst("1"), decConst("1.1"), dfloat(1)), types.Float, nil},
// Verify desired type when possible with unresolved placeholders.
{ptypesNone, types.Decimal, exprs(placeholder(0)), types.Decimal, ptypesDecimal},
{ptypesNone, types.Decimal, exprs(intConst("1"), placeholder(0)), types.Decimal, ptypesDecimal},
Expand All @@ -283,12 +292,16 @@ func TestTypeCheckSameTypedTupleExprs(t *testing.T) {
{nil, nil, exprs(tuple(intConst("1")), tuple(intConst("1"))), ttuple(types.Int), nil},
{nil, nil, exprs(tuple(intConst("1")), tuple(decConst("1.0"))), ttuple(types.Decimal), nil},
{nil, nil, exprs(tuple(intConst("1")), tuple(decConst("1.1"))), ttuple(types.Decimal), nil},
{nil, nil, exprs(tuple(decConst("1.1")), tuple(intConst("1"))), ttuple(types.Decimal), nil},
// Resolved exprs.
{nil, nil, exprs(tuple(dint(1)), tuple(dint(1))), ttuple(types.Int), nil},
{nil, nil, exprs(tuple(dint(1), ddecimal(1)), tuple(dint(1), ddecimal(1))), ttuple(types.Int, types.Decimal), nil},
{nil, nil, exprs(tuple(dint(1), ddecimal(1)), tuple(dint(1), dint(1))), ttuple(types.Int, types.Decimal), nil},
{nil, nil, exprs(tuple(dint(1), dint(1)), tuple(dint(1), ddecimal(1))), ttuple(types.Int, types.Decimal), nil},
// Mixing constants and resolved exprs.
{nil, nil, exprs(tuple(dint(1), decConst("1.1")), tuple(intConst("1"), ddecimal(1))), ttuple(types.Int, types.Decimal), nil},
{nil, nil, exprs(tuple(dint(1), decConst("1.0")), tuple(intConst("1"), dint(1))), ttuple(types.Int, types.Int), nil},
{nil, nil, exprs(tuple(dfloat(1), decConst("1.1")), tuple(intConst("1"), dfloat(1))), ttuple(types.Float, types.Float), nil},
// Mixing resolved placeholders with constants and resolved exprs.
{ptypesDecimal, nil, exprs(tuple(ddecimal(1), intConst("1")), tuple(placeholder(0), placeholder(0))), ttuple(types.Decimal, types.Decimal), ptypesDecimal},
{ptypesDecimalAndDecimal, nil, exprs(tuple(placeholder(1), intConst("1")), tuple(placeholder(0), placeholder(0))), ttuple(types.Decimal, types.Decimal), ptypesDecimalAndDecimal},
Expand All @@ -300,6 +313,8 @@ func TestTypeCheckSameTypedTupleExprs(t *testing.T) {
{nil, nil, exprs(tuple(intConst("1"), dnull), tuple(dnull, decConst("1"))), ttuple(types.Int, types.Decimal), nil},
{nil, nil, exprs(tuple(dint(1), dnull), tuple(dnull, ddecimal(1))), ttuple(types.Int, types.Decimal), nil},
{nil, nil, exprs(tuple(dint(1), dnull), dnull, tuple(dint(1), dnull), dnull), ttuple(types.Int, types.Unknown), nil},
{nil, nil, exprs(tuple(dint(1), dnull), dnull, tuple(dfloat(1), dnull), dnull), ttuple(types.Float, types.Unknown), nil},
{nil, nil, exprs(tuple(dfloat(1), dnull), dnull, tuple(dint(1), dnull), dnull), ttuple(types.Float, types.Unknown), nil},
// Verify desired type when possible.
{nil, ttuple(types.Int, types.Decimal), exprs(tuple(intConst("1"), intConst("1")), tuple(intConst("1"), intConst("1"))), ttuple(types.Int, types.Decimal), nil},
// Verify desired type when possible with unresolved constants.
Expand All @@ -309,6 +324,9 @@ func TestTypeCheckSameTypedTupleExprs(t *testing.T) {
{nil, nil, exprs(tuple(dint(1)), dnull, tuple(dnull), dnull), ttuple(types.Int), nil},
{nil, nil, exprs(dnull, tuple(dint(1), dnull), tuple(dnull, dfloat(1)), dnull), ttuple(types.Int, types.Float), nil},
{nil, nil, exprs(tuple(dint(1), dnull, dnull), tuple(dnull, ddecimal(1), dnull), dnull, tuple(dnull, dnull, dfloat(1)), dnull), ttuple(types.Int, types.Decimal, types.Float), nil},
// Verify CASTs can be applied to non-equivalent types.
{nil, nil, exprs(tuple(dnull, ddecimal(1), dnull), tuple(dnull, dint(1), dnull), dnull, dnull), ttuple(types.Unknown, types.Decimal, types.Unknown), nil},
{nil, nil, exprs(tuple(dnull, dint(1), dnull), tuple(dnull, ddecimal(1), dnull), dnull, tuple(dnull, dfloat(1), dnull), dnull), ttuple(types.Unknown, types.Float, types.Unknown), nil},
} {
t.Run(fmt.Sprintf("test_%d", i), func(t *testing.T) {
attemptTypeCheckSameTypedExprs(t, i, d)
Expand Down Expand Up @@ -363,25 +381,29 @@ func TestTypeCheckSameTypedExprsError(t *testing.T) {
}
}

func TestTypeCheckSameTypedExprsImplicitCastOneWay(t *testing.T) {
// TestTypeCheckSameTypedExprsOrderInvariant tests that typing of expressions in
// a slice is the same no matter the order in which they are processed.
func TestTypeCheckSameTypedExprsOrderInvariant(t *testing.T) {
defer leaktest.AfterTest(t)()
defer log.Scope(t).Close(t)
decimalIntMismatchErr := `expected .* to be of type (decimal|int), found type (decimal|int)`
tupleDecimalIntMismatchErr := `tuples .* are not the same type: ` + decimalIntMismatchErr

testData := []struct {
ptypes tree.PlaceholderTypes
desired *types.T
exprs []copyableExpr

expectedErr string
expectedTypeFamily types.Family
}{
// For each of these test cases, it should be possible to implicitly cast
// from left to right but not vice-versa.
// For each of these test cases, a CAST should be applied in the direction
// which can be done implicitly, no matter the input order of expressions.
// Verify that swapping the order of expressions causes the resulting typed
// expressions to be of the proper type.
// Single type mismatches.
{nil, nil, exprs(dint(1), ddecimal(1)), decimalIntMismatchErr},
{nil, nil, exprs(dint(1), ddecimal(1)), types.DecimalFamily},
{nil, nil, exprs(dint(1), dfloat(1.1), ddecimal(1)), types.FloatFamily},
{nil, nil, exprs(dint(1), dfloat(1.1), ddecimal(1)), types.FloatFamily},
// Tuple type mismatches.
{nil, nil, exprs(tuple(dint(1)), tuple(ddecimal(1))), tupleDecimalIntMismatchErr},
{nil, nil, exprs(tuple(dint(1)), tuple(ddecimal(1))), types.TupleFamily},
}
ctx := context.Background()
for i, d := range testData {
Expand All @@ -400,13 +422,23 @@ func TestTypeCheckSameTypedExprsImplicitCastOneWay(t *testing.T) {
); err != nil {
t.Errorf("%d: unexpected error returned from TypeCheckSameTypedExprs: %v", i, err)
}
// Right to left fails.
// Swapping expression order causes the CAST to be applied in the
// opposite direction.
exprs := make([]copyableExpr, len(d.exprs))
exprs[0], exprs[1] = d.exprs[1], d.exprs[0]
if _, _, err := tree.TypeCheckSameTypedExprs(
if len(exprs) != 2 {
copy(exprs, d.exprs)
}
exprs[0], exprs[len(d.exprs)-1] = d.exprs[len(d.exprs)-1], d.exprs[0]
typedExprs, _, err := tree.TypeCheckSameTypedExprs(
ctx, &semaCtx, desired, buildExprs(exprs)...,
); !testutils.IsError(err, d.expectedErr) {
t.Errorf("%d: expected %s, but found %v", i, d.expectedErr, err)
)
if err != nil {
t.Errorf("Expected no error, but found %v", err)
}
for _, typedExpr := range typedExprs {
if typedExpr.ResolvedType().Family() != d.expectedTypeFamily {
t.Errorf("Expected type family %v, but found %v", d.expectedTypeFamily, typedExpr.ResolvedType().Family())
}
}
})
}
Expand Down

0 comments on commit adf7103

Please sign in to comment.