From adf71039c9ed060c287384a49347b59f8a56bf21 Mon Sep 17 00:00:00 2001 From: Mark Sirek Date: Sun, 10 Sep 2023 13:05:26 -0700 Subject: [PATCH] tree: apply casts in typeCheckSameTypedExprs for non-equivalent types Function typeCheckSameTypedExprs is updated by #108387 and #109635 to apply implicit CASTs to the different expressions input to an array, tuple, case expression or other similar expression, to coerce them all to a common data type, instead of erroring out. This is not done, however, if the data types of these expressions are not equivalent with each other, causing some cases to error out. The fix is to match Postgres behavior and apply the casts even for non-equivalent types. Informs: #109105 Release note (sql change): This patch modifies type checking of arrays, tuples, and case statements to allow implicit casting of scalar expressions referenced in these constructs to a common data type, even for types in different type families, as long at the implicit cast is legal. --- pkg/sql/logictest/testdata/logic_test/typing | 36 ++++++++++++ pkg/sql/sem/tree/type_check.go | 9 +-- pkg/sql/sem/tree/type_check_internal_test.go | 58 +++++++++++++++----- 3 files changed, 82 insertions(+), 21 deletions(-) diff --git a/pkg/sql/logictest/testdata/logic_test/typing b/pkg/sql/logictest/testdata/logic_test/typing index 266b9d21d533..931086d9c508 100644 --- a/pkg/sql/logictest/testdata/logic_test/typing +++ b/pkg/sql/logictest/testdata/logic_test/typing @@ -313,3 +313,39 @@ SELECT (CASE WHEN t108360_1.t > t108360_2.c THEN t108360_1.t ELSE t108360_2.c EN FROM t108360_1, t108360_2 WHERE t108360_1.t = (CASE WHEN t108360_1.t > t108360_2.c THEN t108360_1.t ELSE t108360_2.c END); ---- + +subtest implicit_cast_both_directions + +# The following results match Postgres. + +# The DATE should be implicitly CAST to a timestamp. +query T +SELECT ARRAY['1989-01-15':::DATE,'2005-11-14 21:51:05.000581':::TIMESTAMP]; +---- +{"1989-01-15 00:00:00","2005-11-14 21:51:05.000581"} + +# The DATE should be implicitly CAST to a timestamp. +query T +SELECT ARRAY['2005-11-14 21:51:05.000581':::TIMESTAMP, '1989-01-15':::DATE]; +---- +{"2005-11-14 21:51:05.000581","1989-01-15 00:00:00"} + +# The DATE should be implicitly CAST to a timestamp. +query T +SELECT COALESCE('2005-11-14 21:51:05.000581'::TIMESTAMP, '1989-01-15'::DATE); +---- +2005-11-14 21:51:05.000581 +0000 +0000 + +# The DATE should be implicitly CAST to a timestamp. +query T +SELECT COALESCE('1989-01-15'::DATE, '2005-11-14 21:51:05.000581'::TIMESTAMP); +---- +1989-01-15 00:00:00 +0000 +0000 + +# We should be able to implicitly cast to float. +query O +SELECT array_cat_agg( + ARRAY[(41664367676:::INT8,),(NULL,),((-0.12116245180368423):::FLOAT8,),((-0.42116245180368423):::DECIMAL,)] +); +---- +{(4.1664367676e+10),(),(-0.12116245180368423),(-0.4211624518036842)} diff --git a/pkg/sql/sem/tree/type_check.go b/pkg/sql/sem/tree/type_check.go index a56a2225ac5c..0fc55e495fc5 100644 --- a/pkg/sql/sem/tree/type_check.go +++ b/pkg/sql/sem/tree/type_check.go @@ -2685,11 +2685,6 @@ func typeCheckSameTypedExprs( candidateType = typ } } - // TODO(mgartner): Remove this check now that we check the types - // below. - if typ := typedExpr.ResolvedType(); !(typ.Equivalent(candidateType) || typ.Family() == types.UnknownFamily) { - return nil, nil, unexpectedTypeError(exprs[i], candidateType, typ) - } typedExprs[i] = typedExpr } if !constIdxs.Empty() { @@ -2710,9 +2705,7 @@ func typeCheckSameTypedExprs( // https://www.postgresql.org/docs/15/typeconv-union-case.html for i, e := range typedExprs { typ := e.ResolvedType() - // TODO(mgartner): There should probably be a cast if the types are - // not identical, not just if the types are not equivalent. - if typ.Equivalent(candidateType) || typ.Family() == types.UnknownFamily { + if typ.Identical(candidateType) || typ.Family() == types.UnknownFamily { continue } if !cast.ValidCast(typ, candidateType, cast.ContextImplicit) { diff --git a/pkg/sql/sem/tree/type_check_internal_test.go b/pkg/sql/sem/tree/type_check_internal_test.go index 6787e0a46d64..a8c2c7ee76c2 100644 --- a/pkg/sql/sem/tree/type_check_internal_test.go +++ b/pkg/sql/sem/tree/type_check_internal_test.go @@ -219,9 +219,11 @@ func TestTypeCheckSameTypedExprs(t *testing.T) { {nil, nil, exprs(decConst("1.1")), types.Decimal, nil}, {nil, nil, exprs(intConst("1"), decConst("1.0")), types.Decimal, nil}, {nil, nil, exprs(intConst("1"), decConst("1.1")), types.Decimal, nil}, + {nil, nil, exprs(decConst("1.1"), intConst("1")), types.Decimal, nil}, // Resolved exprs. {nil, nil, exprs(dint(1)), types.Int, nil}, {nil, nil, exprs(ddecimal(1)), types.Decimal, nil}, + {nil, nil, exprs(ddecimal(1), dfloat(1), dint(1)), types.Float, nil}, // Mixing constants and resolved exprs. {nil, nil, exprs(dint(1), intConst("1")), types.Int, nil}, {nil, nil, exprs(dint(1), decConst("1.0")), types.Int, nil}, // This is what the AST would look like after folding (0.6 + 0.4). @@ -229,6 +231,7 @@ func TestTypeCheckSameTypedExprs(t *testing.T) { {nil, nil, exprs(ddecimal(1), intConst("1")), types.Decimal, nil}, {nil, nil, exprs(ddecimal(1), decConst("1.1")), types.Decimal, nil}, {nil, nil, exprs(ddecimal(1), ddecimal(1)), types.Decimal, nil}, + {nil, nil, exprs(decConst("1.0"), dfloat(1), dint(1)), types.Float, nil}, // Mixing resolved placeholders with constants and resolved exprs. {ptypesDecimal, nil, exprs(ddecimal(1), placeholder(0)), types.Decimal, ptypesDecimal}, {ptypesDecimal, nil, exprs(intConst("1"), placeholder(0)), types.Decimal, ptypesDecimal}, @@ -251,6 +254,8 @@ func TestTypeCheckSameTypedExprs(t *testing.T) { {nil, nil, exprs(dnull, ddecimal(1), decConst("1.1")), types.Decimal, nil}, {nil, nil, exprs(dnull, ddecimal(1), decConst("1.1")), types.Decimal, nil}, {nil, nil, exprs(dnull, intConst("1"), decConst("1.1")), types.Decimal, nil}, + {nil, nil, exprs(dnull, intConst("1"), dfloat(1), decConst("1.1")), types.Float, nil}, + {nil, nil, exprs(dnull, intConst("1"), decConst("1.1"), dfloat(1)), types.Float, nil}, // Verify desired type when possible. {nil, types.Int, exprs(intConst("1")), types.Int, nil}, {nil, types.Int, exprs(dint(1)), types.Int, nil}, @@ -262,6 +267,10 @@ func TestTypeCheckSameTypedExprs(t *testing.T) { {nil, types.Int, exprs(intConst("1"), decConst("1.0")), types.Int, nil}, {nil, types.Int, exprs(intConst("1"), decConst("1.1")), types.Decimal, nil}, {nil, types.Decimal, exprs(intConst("1"), decConst("1.1")), types.Decimal, nil}, + {nil, types.Int, exprs(intConst("1"), dfloat(1), decConst("1.1")), types.Float, nil}, + {nil, types.Int, exprs(intConst("1"), decConst("1.1"), dfloat(1)), types.Float, nil}, + {nil, types.Decimal, exprs(intConst("1"), dfloat(1), decConst("1.1")), types.Float, nil}, + {nil, types.Decimal, exprs(intConst("1"), decConst("1.1"), dfloat(1)), types.Float, nil}, // Verify desired type when possible with unresolved placeholders. {ptypesNone, types.Decimal, exprs(placeholder(0)), types.Decimal, ptypesDecimal}, {ptypesNone, types.Decimal, exprs(intConst("1"), placeholder(0)), types.Decimal, ptypesDecimal}, @@ -283,12 +292,16 @@ func TestTypeCheckSameTypedTupleExprs(t *testing.T) { {nil, nil, exprs(tuple(intConst("1")), tuple(intConst("1"))), ttuple(types.Int), nil}, {nil, nil, exprs(tuple(intConst("1")), tuple(decConst("1.0"))), ttuple(types.Decimal), nil}, {nil, nil, exprs(tuple(intConst("1")), tuple(decConst("1.1"))), ttuple(types.Decimal), nil}, + {nil, nil, exprs(tuple(decConst("1.1")), tuple(intConst("1"))), ttuple(types.Decimal), nil}, // Resolved exprs. {nil, nil, exprs(tuple(dint(1)), tuple(dint(1))), ttuple(types.Int), nil}, {nil, nil, exprs(tuple(dint(1), ddecimal(1)), tuple(dint(1), ddecimal(1))), ttuple(types.Int, types.Decimal), nil}, + {nil, nil, exprs(tuple(dint(1), ddecimal(1)), tuple(dint(1), dint(1))), ttuple(types.Int, types.Decimal), nil}, + {nil, nil, exprs(tuple(dint(1), dint(1)), tuple(dint(1), ddecimal(1))), ttuple(types.Int, types.Decimal), nil}, // Mixing constants and resolved exprs. {nil, nil, exprs(tuple(dint(1), decConst("1.1")), tuple(intConst("1"), ddecimal(1))), ttuple(types.Int, types.Decimal), nil}, {nil, nil, exprs(tuple(dint(1), decConst("1.0")), tuple(intConst("1"), dint(1))), ttuple(types.Int, types.Int), nil}, + {nil, nil, exprs(tuple(dfloat(1), decConst("1.1")), tuple(intConst("1"), dfloat(1))), ttuple(types.Float, types.Float), nil}, // Mixing resolved placeholders with constants and resolved exprs. {ptypesDecimal, nil, exprs(tuple(ddecimal(1), intConst("1")), tuple(placeholder(0), placeholder(0))), ttuple(types.Decimal, types.Decimal), ptypesDecimal}, {ptypesDecimalAndDecimal, nil, exprs(tuple(placeholder(1), intConst("1")), tuple(placeholder(0), placeholder(0))), ttuple(types.Decimal, types.Decimal), ptypesDecimalAndDecimal}, @@ -300,6 +313,8 @@ func TestTypeCheckSameTypedTupleExprs(t *testing.T) { {nil, nil, exprs(tuple(intConst("1"), dnull), tuple(dnull, decConst("1"))), ttuple(types.Int, types.Decimal), nil}, {nil, nil, exprs(tuple(dint(1), dnull), tuple(dnull, ddecimal(1))), ttuple(types.Int, types.Decimal), nil}, {nil, nil, exprs(tuple(dint(1), dnull), dnull, tuple(dint(1), dnull), dnull), ttuple(types.Int, types.Unknown), nil}, + {nil, nil, exprs(tuple(dint(1), dnull), dnull, tuple(dfloat(1), dnull), dnull), ttuple(types.Float, types.Unknown), nil}, + {nil, nil, exprs(tuple(dfloat(1), dnull), dnull, tuple(dint(1), dnull), dnull), ttuple(types.Float, types.Unknown), nil}, // Verify desired type when possible. {nil, ttuple(types.Int, types.Decimal), exprs(tuple(intConst("1"), intConst("1")), tuple(intConst("1"), intConst("1"))), ttuple(types.Int, types.Decimal), nil}, // Verify desired type when possible with unresolved constants. @@ -309,6 +324,9 @@ func TestTypeCheckSameTypedTupleExprs(t *testing.T) { {nil, nil, exprs(tuple(dint(1)), dnull, tuple(dnull), dnull), ttuple(types.Int), nil}, {nil, nil, exprs(dnull, tuple(dint(1), dnull), tuple(dnull, dfloat(1)), dnull), ttuple(types.Int, types.Float), nil}, {nil, nil, exprs(tuple(dint(1), dnull, dnull), tuple(dnull, ddecimal(1), dnull), dnull, tuple(dnull, dnull, dfloat(1)), dnull), ttuple(types.Int, types.Decimal, types.Float), nil}, + // Verify CASTs can be applied to non-equivalent types. + {nil, nil, exprs(tuple(dnull, ddecimal(1), dnull), tuple(dnull, dint(1), dnull), dnull, dnull), ttuple(types.Unknown, types.Decimal, types.Unknown), nil}, + {nil, nil, exprs(tuple(dnull, dint(1), dnull), tuple(dnull, ddecimal(1), dnull), dnull, tuple(dnull, dfloat(1), dnull), dnull), ttuple(types.Unknown, types.Float, types.Unknown), nil}, } { t.Run(fmt.Sprintf("test_%d", i), func(t *testing.T) { attemptTypeCheckSameTypedExprs(t, i, d) @@ -363,25 +381,29 @@ func TestTypeCheckSameTypedExprsError(t *testing.T) { } } -func TestTypeCheckSameTypedExprsImplicitCastOneWay(t *testing.T) { +// TestTypeCheckSameTypedExprsOrderInvariant tests that typing of expressions in +// a slice is the same no matter the order in which they are processed. +func TestTypeCheckSameTypedExprsOrderInvariant(t *testing.T) { defer leaktest.AfterTest(t)() defer log.Scope(t).Close(t) - decimalIntMismatchErr := `expected .* to be of type (decimal|int), found type (decimal|int)` - tupleDecimalIntMismatchErr := `tuples .* are not the same type: ` + decimalIntMismatchErr testData := []struct { ptypes tree.PlaceholderTypes desired *types.T exprs []copyableExpr - expectedErr string + expectedTypeFamily types.Family }{ - // For each of these test cases, it should be possible to implicitly cast - // from left to right but not vice-versa. + // For each of these test cases, a CAST should be applied in the direction + // which can be done implicitly, no matter the input order of expressions. + // Verify that swapping the order of expressions causes the resulting typed + // expressions to be of the proper type. // Single type mismatches. - {nil, nil, exprs(dint(1), ddecimal(1)), decimalIntMismatchErr}, + {nil, nil, exprs(dint(1), ddecimal(1)), types.DecimalFamily}, + {nil, nil, exprs(dint(1), dfloat(1.1), ddecimal(1)), types.FloatFamily}, + {nil, nil, exprs(dint(1), dfloat(1.1), ddecimal(1)), types.FloatFamily}, // Tuple type mismatches. - {nil, nil, exprs(tuple(dint(1)), tuple(ddecimal(1))), tupleDecimalIntMismatchErr}, + {nil, nil, exprs(tuple(dint(1)), tuple(ddecimal(1))), types.TupleFamily}, } ctx := context.Background() for i, d := range testData { @@ -400,13 +422,23 @@ func TestTypeCheckSameTypedExprsImplicitCastOneWay(t *testing.T) { ); err != nil { t.Errorf("%d: unexpected error returned from TypeCheckSameTypedExprs: %v", i, err) } - // Right to left fails. + // Swapping expression order causes the CAST to be applied in the + // opposite direction. exprs := make([]copyableExpr, len(d.exprs)) - exprs[0], exprs[1] = d.exprs[1], d.exprs[0] - if _, _, err := tree.TypeCheckSameTypedExprs( + if len(exprs) != 2 { + copy(exprs, d.exprs) + } + exprs[0], exprs[len(d.exprs)-1] = d.exprs[len(d.exprs)-1], d.exprs[0] + typedExprs, _, err := tree.TypeCheckSameTypedExprs( ctx, &semaCtx, desired, buildExprs(exprs)..., - ); !testutils.IsError(err, d.expectedErr) { - t.Errorf("%d: expected %s, but found %v", i, d.expectedErr, err) + ) + if err != nil { + t.Errorf("Expected no error, but found %v", err) + } + for _, typedExpr := range typedExprs { + if typedExpr.ResolvedType().Family() != d.expectedTypeFamily { + t.Errorf("Expected type family %v, but found %v", d.expectedTypeFamily, typedExpr.ResolvedType().Family()) + } } }) }