diff --git a/pkg/sql/logictest/testdata/logic_test/union b/pkg/sql/logictest/testdata/logic_test/union index 9a6172e49b26..4997fdddcb3b 100644 --- a/pkg/sql/logictest/testdata/logic_test/union +++ b/pkg/sql/logictest/testdata/logic_test/union @@ -348,3 +348,21 @@ NULL statement ok CREATE TABLE ab (a INT, b INT); SELECT a, b, rowid FROM ab UNION VALUES (1, 2, 3); +DROP TABLE ab; + +# Regression test for #59148. +statement ok +CREATE TABLE ab (a INT4, b INT8); +INSERT INTO ab VALUES (1, 1), (1, 2), (2, 1), (2, 2); + +query I rowsort +SELECT a FROM ab UNION SELECT b FROM ab +---- +1 +2 + +query I rowsort +SELECT b FROM ab UNION SELECT a FROM ab +---- +1 +2 diff --git a/pkg/sql/opt/optbuilder/testdata/union b/pkg/sql/opt/optbuilder/testdata/union index cadb97740dfa..f2ffbdce4f6f 100644 --- a/pkg/sql/opt/optbuilder/testdata/union +++ b/pkg/sql/opt/optbuilder/testdata/union @@ -874,3 +874,128 @@ except │ │ └── 1 [as="?column?":6] │ └── 1 └── (1,) + +# Verify that we add casts for equivalent, but not identical types. +exec-ddl +CREATE TABLE ab (i8 INT8, i4 INT4, f8 FLOAT, f4 FLOAT4, d DECIMAL) +---- + +build +SELECT i4 FROM ab UNION SELECT i8 FROM ab +---- +union + ├── columns: i4:16 + ├── left columns: i4:15 + ├── right columns: i8:8 + ├── project + │ ├── columns: i4:15 + │ ├── project + │ │ ├── columns: ab.i4:2 + │ │ └── scan ab + │ │ └── columns: i8:1 ab.i4:2 f8:3 f4:4 d:5 rowid:6!null crdb_internal_mvcc_timestamp:7 + │ └── projections + │ └── ab.i4:2::INT8 [as=i4:15] + └── project + ├── columns: i8:8 + └── scan ab + └── columns: i8:8 ab.i4:9 f8:10 f4:11 d:12 rowid:13!null crdb_internal_mvcc_timestamp:14 + +build +SELECT i8 FROM ab UNION SELECT i4 FROM ab +---- +union + ├── columns: i8:16 + ├── left columns: ab.i8:1 + ├── right columns: i4:15 + ├── project + │ ├── columns: ab.i8:1 + │ └── scan ab + │ └── columns: ab.i8:1 ab.i4:2 f8:3 f4:4 d:5 rowid:6!null crdb_internal_mvcc_timestamp:7 + └── project + ├── columns: i4:15 + ├── project + │ ├── columns: ab.i4:9 + │ └── scan ab + │ └── columns: ab.i8:8 ab.i4:9 f8:10 f4:11 d:12 rowid:13!null crdb_internal_mvcc_timestamp:14 + └── projections + └── ab.i4:9::INT8 [as=i4:15] + +build +SELECT f4 FROM ab UNION SELECT f8 FROM ab +---- +union + ├── columns: f4:16 + ├── left columns: f4:15 + ├── right columns: f8:10 + ├── project + │ ├── columns: f4:15 + │ ├── project + │ │ ├── columns: ab.f4:4 + │ │ └── scan ab + │ │ └── columns: i8:1 i4:2 f8:3 ab.f4:4 d:5 rowid:6!null crdb_internal_mvcc_timestamp:7 + │ └── projections + │ └── ab.f4:4::FLOAT8 [as=f4:15] + └── project + ├── columns: f8:10 + └── scan ab + └── columns: i8:8 i4:9 f8:10 ab.f4:11 d:12 rowid:13!null crdb_internal_mvcc_timestamp:14 + +build +SELECT i8 FROM ab UNION SELECT f4 FROM ab +---- +union + ├── columns: i8:16 + ├── left columns: i8:15 + ├── right columns: f4:11 + ├── project + │ ├── columns: i8:15 + │ ├── project + │ │ ├── columns: ab.i8:1 + │ │ └── scan ab + │ │ └── columns: ab.i8:1 i4:2 f8:3 f4:4 d:5 rowid:6!null crdb_internal_mvcc_timestamp:7 + │ └── projections + │ └── ab.i8:1::FLOAT4 [as=i8:15] + └── project + ├── columns: f4:11 + └── scan ab + └── columns: ab.i8:8 i4:9 f8:10 f4:11 d:12 rowid:13!null crdb_internal_mvcc_timestamp:14 + +build +SELECT i8 FROM ab UNION SELECT d FROM ab +---- +union + ├── columns: i8:16 + ├── left columns: i8:15 + ├── right columns: d:12 + ├── project + │ ├── columns: i8:15 + │ ├── project + │ │ ├── columns: ab.i8:1 + │ │ └── scan ab + │ │ └── columns: ab.i8:1 i4:2 f8:3 f4:4 d:5 rowid:6!null crdb_internal_mvcc_timestamp:7 + │ └── projections + │ └── ab.i8:1::DECIMAL [as=i8:15] + └── project + ├── columns: d:12 + └── scan ab + └── columns: ab.i8:8 i4:9 f8:10 f4:11 d:12 rowid:13!null crdb_internal_mvcc_timestamp:14 + +build +SELECT d FROM ab UNION SELECT f8 FROM ab +---- +union + ├── columns: d:16 + ├── left columns: ab.d:5 + ├── right columns: f8:15 + ├── project + │ ├── columns: ab.d:5 + │ └── scan ab + │ └── columns: i8:1 i4:2 ab.f8:3 f4:4 ab.d:5 rowid:6!null crdb_internal_mvcc_timestamp:7 + └── project + ├── columns: f8:15 + ├── project + │ ├── columns: ab.f8:10 + │ └── scan ab + │ └── columns: i8:8 i4:9 ab.f8:10 f4:11 ab.d:12 rowid:13!null crdb_internal_mvcc_timestamp:14 + └── projections + └── ab.f8:10::DECIMAL [as=f8:15] diff --git a/pkg/sql/opt/optbuilder/union.go b/pkg/sql/opt/optbuilder/union.go index ab6d1b3284fc..742818cc6ecf 100644 --- a/pkg/sql/opt/optbuilder/union.go +++ b/pkg/sql/opt/optbuilder/union.go @@ -37,7 +37,7 @@ func (b *Builder) buildUnionClause( } func (b *Builder) buildSetOp( - typ tree.UnionType, all bool, inScope, leftScope, rightScope *scope, + unionType tree.UnionType, all bool, inScope, leftScope, rightScope *scope, ) (outScope *scope) { // Remove any hidden columns, as they are not included in the Union. leftScope.removeHiddenCols() @@ -45,33 +45,22 @@ func (b *Builder) buildSetOp( outScope = inScope.push() - // propagateTypesLeft/propagateTypesRight indicate whether we need to wrap - // the left/right side in a projection to cast some of the columns to the - // correct type. - // For example: - // SELECT NULL UNION SELECT 1 - // The type of NULL is unknown, and the type of 1 is int. We need to - // wrap the left side in a project operation with a Cast expression so the - // output column will have the correct type. - propagateTypesLeft, propagateTypesRight := b.checkTypesMatch( - leftScope, rightScope, - true, /* tolerateUnknownLeft */ - true, /* tolerateUnknownRight */ - typ.String(), + setOpTypes, leftCastsNeeded, rightCastsNeeded := b.typeCheckSetOp( + leftScope, rightScope, unionType.String(), ) - if propagateTypesLeft { - leftScope = b.propagateTypes(leftScope /* dst */, rightScope /* src */) + if leftCastsNeeded { + leftScope = b.addCasts(leftScope /* dst */, setOpTypes) } - if propagateTypesRight { - rightScope = b.propagateTypes(rightScope /* dst */, leftScope /* src */) + if rightCastsNeeded { + rightScope = b.addCasts(rightScope /* dst */, setOpTypes) } // For UNION, we have to synthesize new output columns (because they contain // values from both the left and right relations). This is not necessary for // INTERSECT or EXCEPT, since these operations are basically filters on the // left relation. - if typ == tree.UnionOp { + if unionType == tree.UnionOp { outScope.cols = make([]scopeColumn, 0, len(leftScope.cols)) for i := range leftScope.cols { c := &leftScope.cols[i] @@ -92,7 +81,7 @@ func (b *Builder) buildSetOp( private := memo.SetPrivate{LeftCols: leftCols, RightCols: rightCols, OutCols: newCols} if all { - switch typ { + switch unionType { case tree.UnionOp: outScope.expr = b.factory.ConstructUnionAll(left, right, &private) case tree.IntersectOp: @@ -101,7 +90,7 @@ func (b *Builder) buildSetOp( outScope.expr = b.factory.ConstructExceptAll(left, right, &private) } } else { - switch typ { + switch unionType { case tree.UnionOp: outScope.expr = b.factory.ConstructUnion(left, right, &private) case tree.IntersectOp: @@ -114,25 +103,17 @@ func (b *Builder) buildSetOp( return outScope } -// checkTypesMatch is used when the columns must match between two scopes (e.g. -// for a UNION). Throws an error if the scopes don't have the same number of -// columns, or when column types don't match 1-1, except: -// - if tolerateUnknownLeft is set and the left column has Unknown type while -// the right has a known type (in this case it returns propagateToLeft=true). -// - if tolerateUnknownRight is set and the right column has Unknown type while -// the right has a known type (in this case it returns propagateToRight=true). +// typeCheckSetOp cross-checks the types between the left and right sides of a +// set operation and determines the output types. Either side (or both) might +// need casts (as indicated in the return values). // -// clauseTag is used only in error messages. +// Throws an error if the scopes don't have the same number of columns, or when +// column types don't match 1-1 or can't be cast to a single output type. The +// error messages use clauseTag. // -// TODO(dan): This currently checks whether the types are exactly the same, -// but Postgres is more lenient: -// http://www.postgresql.org/docs/9.5/static/typeconv-union-case.html. -func (b *Builder) checkTypesMatch( - leftScope, rightScope *scope, - tolerateUnknownLeft bool, - tolerateUnknownRight bool, - clauseTag string, -) (propagateToLeft, propagateToRight bool) { +func (b *Builder) typeCheckSetOp( + leftScope, rightScope *scope, clauseTag string, +) (setOpTypes []*types.T, leftCastsNeeded, rightCastsNeeded bool) { if len(leftScope.cols) != len(rightScope.cols) { panic(pgerror.Newf( pgcode.Syntax, @@ -141,38 +122,79 @@ func (b *Builder) checkTypesMatch( )) } + setOpTypes = make([]*types.T, len(leftScope.cols)) for i := range leftScope.cols { l := &leftScope.cols[i] r := &rightScope.cols[i] - if l.typ.Equivalent(r.typ) { - continue - } + typ := determineUnionType(l.typ, r.typ, clauseTag) + setOpTypes[i] = typ + leftCastsNeeded = leftCastsNeeded || !l.typ.Identical(typ) + rightCastsNeeded = rightCastsNeeded || !r.typ.Identical(typ) + } + return setOpTypes, leftCastsNeeded, rightCastsNeeded +} - // Note that Unknown types are equivalent so at this point at most one of - // the types can be Unknown. - if l.typ.Family() == types.UnknownFamily && tolerateUnknownLeft { - propagateToLeft = true - continue +// determineUnionType determines the resulting type of a set operation on a +// column with the given left and right types. +// +// We allow implicit up-casts between types of the same numeric family with +// different widths; between int and float; and between int/float and decimal. +// +// Throws an error if we don't support a set operation between the two types. +func determineUnionType(left, right *types.T, clauseTag string) *types.T { + if left.Identical(right) { + return left + } + + if left.Equivalent(right) { + // Do a best-effort attempt to determine which type is "larger". + if left.Width() > right.Width() { + return left } - if r.typ.Family() == types.UnknownFamily && tolerateUnknownRight { - propagateToRight = true - continue + if left.Width() < right.Width() { + return right } + // In other cases, use the left type. + return left + } + leftFam, rightFam := left.Family(), right.Family() - panic(pgerror.Newf( - pgcode.DatatypeMismatch, - "%v types %s and %s cannot be matched", clauseTag, l.typ, r.typ, - )) + if rightFam == types.UnknownFamily { + return left + } + if leftFam == types.UnknownFamily { + return right + } + + // Allow implicit upcast from int to float. Converting an int to float can be + // lossy (especially INT8 to FLOAT4), but this is what Postgres does. + if leftFam == types.FloatFamily && rightFam == types.IntFamily { + return left + } + if leftFam == types.IntFamily && rightFam == types.FloatFamily { + return right } - return propagateToLeft, propagateToRight + + // Allow implicit upcasts to decimal. + if leftFam == types.DecimalFamily && (rightFam == types.IntFamily || rightFam == types.FloatFamily) { + return left + } + if (leftFam == types.IntFamily || leftFam == types.FloatFamily) && rightFam == types.DecimalFamily { + return right + } + + // TODO(radu): Postgres has more encompassing rules: + // http://www.postgresql.org/docs/12/static/typeconv-union-case.html + panic(pgerror.Newf( + pgcode.DatatypeMismatch, + "%v types %s and %s cannot be matched", clauseTag, left, right, + )) } -// propagateTypes propagates the types of the source columns to the destination -// columns by wrapping the destination in a Project operation. The Project -// operation passes through columns that already have the correct type, and -// creates cast expressions for those that don't. -func (b *Builder) propagateTypes(dst, src *scope) *scope { +// addCasts adds a projection to a scope, adding casts as necessary so that the +// resulting columns have the given types. +func (b *Builder) addCasts(dst *scope, outTypes []*types.T) *scope { expr := dst.expr.(memo.RelExpr) dstCols := dst.cols @@ -180,12 +202,10 @@ func (b *Builder) propagateTypes(dst, src *scope) *scope { dst.cols = make([]scopeColumn, 0, len(dstCols)) for i := 0; i < len(dstCols); i++ { - dstType := dstCols[i].typ - srcType := src.cols[i].typ - if dstType.Family() == types.UnknownFamily && srcType.Family() != types.UnknownFamily { + if !dstCols[i].typ.Identical(outTypes[i]) { // Create a new column which casts the old column to the correct type. - castExpr := b.factory.ConstructCast(b.factory.ConstructVariable(dstCols[i].id), srcType) - b.synthesizeColumn(dst, string(dstCols[i].name), srcType, nil /* expr */, castExpr) + castExpr := b.factory.ConstructCast(b.factory.ConstructVariable(dstCols[i].id), outTypes[i]) + b.synthesizeColumn(dst, string(dstCols[i].name), outTypes[i], nil /* expr */, castExpr) } else { // The column is already the correct type, so add it as a passthrough // column. diff --git a/pkg/sql/opt/optbuilder/union_test.go b/pkg/sql/opt/optbuilder/union_test.go new file mode 100644 index 000000000000..d422898cfb6c --- /dev/null +++ b/pkg/sql/opt/optbuilder/union_test.go @@ -0,0 +1,99 @@ +// Copyright 2021 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package optbuilder + +import ( + "testing" + + "github.com/cockroachdb/cockroach/pkg/sql/types" +) + +func TestUnionType(t *testing.T) { + testCases := []struct { + left, right, expected *types.T + }{ + { + left: types.Unknown, + right: types.Int, + expected: types.Int, + }, + { + left: types.Int, + right: types.Unknown, + expected: types.Int, + }, + { + left: types.Int4, + right: types.Int, + expected: types.Int, + }, + { + left: types.Int4, + right: types.Int2, + expected: types.Int4, + }, + { + left: types.Float4, + right: types.Float, + expected: types.Float, + }, + { + left: types.MakeDecimal(12 /* precision */, 5 /* scale */), + right: types.MakeDecimal(10 /* precision */, 7 /* scale */), + expected: types.MakeDecimal(10 /* precision */, 7 /* scale */), + }, + { + // At the same scale, we use the left type. + left: types.MakeDecimal(10 /* precision */, 1 /* scale */), + right: types.MakeDecimal(12 /* precision */, 1 /* scale */), + expected: types.MakeDecimal(10 /* precision */, 1 /* scale */), + }, + { + left: types.Int4, + right: types.Decimal, + expected: types.Decimal, + }, + { + left: types.Decimal, + right: types.Float, + expected: types.Decimal, + }, + { + // Error. + left: types.Float, + right: types.String, + expected: nil, + }, + } + + for _, tc := range testCases { + result := func() *types.T { + defer func() { + if r := recover(); r != nil { + // Swallow any error and return nil. + } + }() + return determineUnionType(tc.left, tc.right, "test") + }() + toStr := func(t *types.T) string { + if t == nil { + return "" + } + return t.SQLString() + } + if toStr(result) != toStr(tc.expected) { + t.Errorf( + "left: %s right: %s expected: %s got: %s", + toStr(tc.left), toStr(tc.right), toStr(tc.expected), toStr(result), + ) + } + } +} diff --git a/pkg/sql/opt/optbuilder/with.go b/pkg/sql/opt/optbuilder/with.go index e53e569c85f7..217ea641e9a8 100644 --- a/pkg/sql/opt/optbuilder/with.go +++ b/pkg/sql/opt/optbuilder/with.go @@ -186,13 +186,21 @@ func (b *Builder) buildCTE( // We allow propagation of types from the initial query to the recursive // query. - _, propagateToRight := b.checkTypesMatch(initialScope, recursiveScope, - false, /* tolerateUnknownLeft */ - true, /* tolerateUnknownRight */ - "UNION", - ) - if propagateToRight { - recursiveScope = b.propagateTypes(recursiveScope /* dst */, initialScope /* src */) + outTypes, leftCastsNeeded, rightCastsNeeded := b.typeCheckSetOp(initialScope, recursiveScope, "UNION") + if leftCastsNeeded { + // We don't support casts on the initial expression; error out. + for i := range outTypes { + if !outTypes[i].Identical(initialScope.cols[i].typ) { + panic(pgerror.Newf( + pgcode.DatatypeMismatch, + "UNION types %s and %s cannot be matched", + initialScope.cols[i].typ, recursiveScope.cols[i].typ, + )) + } + } + } + if rightCastsNeeded { + recursiveScope = b.addCasts(recursiveScope, outTypes) } private := memo.RecursiveCTEPrivate{