From 8d19e0eebb4fd7e12a30901852395219a5aaa159 Mon Sep 17 00:00:00 2001 From: Mark Sirek Date: Sat, 9 Sep 2023 14:02:56 -0700 Subject: [PATCH] tree: make typeCheckSameTypedExprs order invariant for tuples Type checking of CASE expressions (and other expressions handled by typeCheckSameTypedExprs) does not type check tuples properly because typeCheckSameTypedTupleExprs checking is only done when the tuple is the first expression in `exprs`. This is fixed by finding the first tuple in `exprs`, searching the slice starting at index 0, and using that to drive typeCheckSameTypedExprs. Fixes: #109105 Release note: None --- pkg/sql/logictest/testdata/logic_test/tuple | 34 +++++++++++++++++--- pkg/sql/sem/tree/type_check.go | 32 ++++++++++++++++-- pkg/sql/sem/tree/type_check_internal_test.go | 10 ++++++ pkg/sql/sem/tree/type_check_test.go | 5 ++- 4 files changed, 72 insertions(+), 9 deletions(-) diff --git a/pkg/sql/logictest/testdata/logic_test/tuple b/pkg/sql/logictest/testdata/logic_test/tuple index 64c6b6b40f0b..317236a44a37 100644 --- a/pkg/sql/logictest/testdata/logic_test/tuple +++ b/pkg/sql/logictest/testdata/logic_test/tuple @@ -1201,11 +1201,10 @@ SELECT (CASE WHEN b THEN NULL ELSE ((ROW(1) AS a)) END).a from t78159 ---- 1 -# TODO(rytaft): Uncomment this case once #109105 is fixed. -# query B -# SELECT (CASE WHEN b THEN ((ROW(1) AS a)) ELSE NULL END).a from t78159 -# ---- -# NULL +query B +SELECT (CASE WHEN b THEN ((ROW(1) AS a)) ELSE NULL END).a from t78159 +---- +NULL # Regression test for #78515. Propagate tuple labels when type-checking # expressions with multiple matching tuple types. @@ -1259,3 +1258,28 @@ SELECT (ROW() AS a) IS NOT UNKNOWN statement error pgcode 42601 mismatch in tuple definition: 0 expressions, 1 labels SELECT CASE WHEN False THEN ROW() ELSE (ROW() AS a) END + +subtest 109105 + +statement ok +CREATE TABLE t109105 (a int); + +statement ok +INSERT INTO t109105 VALUES (1),(2),(3),(4),(5),(6); + +# This should CAST the nulls in the rows to the types of the constant values +# instead of erroring out. +query T +SELECT (CASE WHEN a = 1 THEN NULL + WHEN a = 2 THEN ROW(NULL, NULL, 1.1e3::FLOAT) + WHEN a = 3 THEN ROW(NULL, 1.1::DECIMAL, NULL) + WHEN a = 4 THEN ROW(1::INT, null, NULL) + WHEN a = 5 THEN NULL + ELSE NULL END) FROM t109105 ORDER BY 1; +---- +NULL +NULL +NULL +(,,1100) +(,1.1,) +(1,,) diff --git a/pkg/sql/sem/tree/type_check.go b/pkg/sql/sem/tree/type_check.go index 631181d05552..0a01e31c40f2 100644 --- a/pkg/sql/sem/tree/type_check.go +++ b/pkg/sql/sem/tree/type_check.go @@ -2595,6 +2595,15 @@ type typeCheckExprsState struct { resolvableIdxs intsets.Fast // index into exprs/typedExprs } +func findFirstTupleIndex(exprs ...Expr) (index int, ok bool) { + for i, expr := range exprs { + if _, ok := expr.(*Tuple); ok { + return i, true + } + } + return 0, false +} + // typeCheckSameTypedExprs type checks a list of expressions, asserting that all // resolved TypeExprs have the same type. An optional desired type can be provided, // which will hint that type which the expressions should resolve to, if possible. @@ -2618,9 +2627,26 @@ func typeCheckSameTypedExprs( return []TypedExpr{typedExpr}, typ, nil } - // Handle tuples, which will in turn call into this function recursively for each element. - if _, ok := exprs[0].(*Tuple); ok { - return typeCheckSameTypedTupleExprs(ctx, semaCtx, desired, exprs...) + // Handle tuples, which will in turn call into this function recursively for + // each element. + // TODO(msirek): Rewrite typeCheckSameTypedTupleExprs to handle all types of + // expressions which could resolve to a type family of `TupleFamily`, like a + // VALUES clause. Logic in `typeCheckSameTypedTupleExprs` states that the call + // to `TypeCheck` should be deferred until the common type is determined. So, + // we would need a way to determine which expressions are in the tuple family + // without inspecting the AST node and without calling `TypeCheck`. Does the + // call to `TypeCheck` really need to be deferred? + if idx, ok := findFirstTupleIndex(exprs...); ok { + if _, ok := exprs[idx].(*Tuple); ok { + // typeCheckSameTypedTupleExprs expects the first expression in the slice + // to be a tuple. + exprs[0], exprs[idx] = exprs[idx], exprs[0] + typedExprs, commonType, err := typeCheckSameTypedTupleExprs(ctx, semaCtx, desired, exprs...) + if err == nil { + typedExprs[0], typedExprs[idx] = typedExprs[idx], typedExprs[0] + } + return typedExprs, commonType, err + } } // Hold the resolved type expressions of the provided exprs, in order. diff --git a/pkg/sql/sem/tree/type_check_internal_test.go b/pkg/sql/sem/tree/type_check_internal_test.go index f358354f14ff..6787e0a46d64 100644 --- a/pkg/sql/sem/tree/type_check_internal_test.go +++ b/pkg/sql/sem/tree/type_check_internal_test.go @@ -136,6 +136,11 @@ func ddecimal(f float64) copyableExpr { return dd } } +func dfloat(f float64) copyableExpr { + return func() tree.Expr { + return tree.NewDFloat(tree.DFloat(f)) + } +} func placeholder(id tree.PlaceholderIdx) copyableExpr { return func() tree.Expr { return newPlaceholder(id) @@ -299,6 +304,11 @@ func TestTypeCheckSameTypedTupleExprs(t *testing.T) { {nil, ttuple(types.Int, types.Decimal), exprs(tuple(intConst("1"), intConst("1")), tuple(intConst("1"), intConst("1"))), ttuple(types.Int, types.Decimal), nil}, // Verify desired type when possible with unresolved constants. {ptypesNone, ttuple(types.Int, types.Decimal), exprs(tuple(placeholder(0), intConst("1")), tuple(intConst("1"), placeholder(1))), ttuple(types.Int, types.Decimal), ptypesIntAndDecimal}, + // Verify CASTs are in the direction of the non-null expression. + {nil, nil, exprs(tuple(dnull), dnull, tuple(dint(1)), dnull), ttuple(types.Int), nil}, + {nil, nil, exprs(tuple(dint(1)), dnull, tuple(dnull), dnull), ttuple(types.Int), nil}, + {nil, nil, exprs(dnull, tuple(dint(1), dnull), tuple(dnull, dfloat(1)), dnull), ttuple(types.Int, types.Float), nil}, + {nil, nil, exprs(tuple(dint(1), dnull, dnull), tuple(dnull, ddecimal(1), dnull), dnull, tuple(dnull, dnull, dfloat(1)), dnull), ttuple(types.Int, types.Decimal, types.Float), nil}, } { t.Run(fmt.Sprintf("test_%d", i), func(t *testing.T) { attemptTypeCheckSameTypedExprs(t, i, d) diff --git a/pkg/sql/sem/tree/type_check_test.go b/pkg/sql/sem/tree/type_check_test.go index cff12080de07..c258a5331590 100644 --- a/pkg/sql/sem/tree/type_check_test.go +++ b/pkg/sql/sem/tree/type_check_test.go @@ -171,7 +171,10 @@ func TestTypeCheck(t *testing.T) { `CASE WHEN true THEN ('a', 2) ELSE NULL:::RECORD END`, `CASE WHEN true THEN ('a':::STRING, 2:::INT8) ELSE NULL END`, }, - + { + `CASE WHEN true THEN NULL:::RECORD ELSE ('a', 2) END`, + `CASE WHEN true THEN NULL ELSE ('a':::STRING, 2:::INT8) END`, + }, {`((ROW (1) AS a)).a`, `1:::INT8`}, {`((('1', 2) AS a, b)).a`, `'1':::STRING`}, {`((('1', 2) AS a, b)).b`, `2:::INT8`},