From 8795571a0cc789a6a936389648cfdd57ddf8ecc0 Mon Sep 17 00:00:00 2001 From: Radu Berinde Date: Sat, 13 Feb 2021 22:25:11 -0500 Subject: [PATCH] opt: cast to identical types for set operations This change makes the optbuilder more strict with respect to set operations. Previously, it would only require types to be `Equivalent` across the two sides. This leads to errors in vectorized execution, when we e.g. try to union an INT8 with an INT4. We now require the types to be `Identical`, and we add casts as necessary. We use the type from the left side in this case. This can lead to questionable behavior when the right side is a "wider" type, but we don't have any facility to robustly determine what the type should be. Fixes #59148. Release note (bug fix): fixed execution errors for some queries that use set operations (UNION / EXCEPT / INTERSECT) where a column has types of different widths on the two sides (e.g. INT4 vs INT8). --- pkg/sql/logictest/testdata/logic_test/union | 18 +++ pkg/sql/opt/optbuilder/testdata/union | 45 +++++++ pkg/sql/opt/optbuilder/union.go | 136 ++++++++++---------- pkg/sql/opt/optbuilder/with.go | 18 ++- 4 files changed, 143 insertions(+), 74 deletions(-) diff --git a/pkg/sql/logictest/testdata/logic_test/union b/pkg/sql/logictest/testdata/logic_test/union index 9a6172e49b26..4997fdddcb3b 100644 --- a/pkg/sql/logictest/testdata/logic_test/union +++ b/pkg/sql/logictest/testdata/logic_test/union @@ -348,3 +348,21 @@ NULL statement ok CREATE TABLE ab (a INT, b INT); SELECT a, b, rowid FROM ab UNION VALUES (1, 2, 3); +DROP TABLE ab; + +# Regression test for #59148. 
+statement ok +CREATE TABLE ab (a INT4, b INT8); +INSERT INTO ab VALUES (1, 1), (1, 2), (2, 1), (2, 2); + +query I rowsort +SELECT a FROM ab UNION SELECT b FROM ab +---- +1 +2 + +query I rowsort +SELECT b FROM ab UNION SELECT a FROM ab +---- +1 +2 diff --git a/pkg/sql/opt/optbuilder/testdata/union b/pkg/sql/opt/optbuilder/testdata/union index cadb97740dfa..ed19dc75c01a 100644 --- a/pkg/sql/opt/optbuilder/testdata/union +++ b/pkg/sql/opt/optbuilder/testdata/union @@ -874,3 +874,48 @@ except │ │ └── 1 [as="?column?":6] │ └── 1 └── (1,) + +# Verify that we add casts for equivalent, but not identical types. +exec-ddl +CREATE TABLE ab (a INT8, b INT4) +---- + +build +SELECT a FROM ab UNION SELECT b FROM ab +---- +union + ├── columns: a:10 + ├── left columns: ab.a:1 + ├── right columns: b:9 + ├── project + │ ├── columns: ab.a:1 + │ └── scan ab + │ └── columns: ab.a:1 ab.b:2 rowid:3!null crdb_internal_mvcc_timestamp:4 + └── project + ├── columns: b:9 + ├── project + │ ├── columns: ab.b:6 + │ └── scan ab + │ └── columns: ab.a:5 ab.b:6 rowid:7!null crdb_internal_mvcc_timestamp:8 + └── projections + └── ab.b:6::INT8 [as=b:9] + +build +SELECT b FROM ab UNION SELECT a FROM ab +---- +union + ├── columns: b:10 + ├── left columns: ab.b:2 + ├── right columns: a:9 + ├── project + │ ├── columns: ab.b:2 + │ └── scan ab + │ └── columns: ab.a:1 ab.b:2 rowid:3!null crdb_internal_mvcc_timestamp:4 + └── project + ├── columns: a:9 + ├── project + │ ├── columns: ab.a:5 + │ └── scan ab + │ └── columns: ab.a:5 ab.b:6 rowid:7!null crdb_internal_mvcc_timestamp:8 + └── projections + └── ab.a:5::INT4 [as=a:9] diff --git a/pkg/sql/opt/optbuilder/union.go b/pkg/sql/opt/optbuilder/union.go index ab6d1b3284fc..4cb32c861da0 100644 --- a/pkg/sql/opt/optbuilder/union.go +++ b/pkg/sql/opt/optbuilder/union.go @@ -37,7 +37,7 @@ func (b *Builder) buildUnionClause( } func (b *Builder) buildSetOp( - typ tree.UnionType, all bool, inScope, leftScope, rightScope *scope, + unionType tree.UnionType, all bool, 
inScope, leftScope, rightScope *scope, ) (outScope *scope) { // Remove any hidden columns, as they are not included in the Union. leftScope.removeHiddenCols() @@ -45,33 +45,37 @@ func (b *Builder) buildSetOp( outScope = inScope.push() - // propagateTypesLeft/propagateTypesRight indicate whether we need to wrap - // the left/right side in a projection to cast some of the columns to the - // correct type. - // For example: - // SELECT NULL UNION SELECT 1 - // The type of NULL is unknown, and the type of 1 is int. We need to - // wrap the left side in a project operation with a Cast expression so the - // output column will have the correct type. - propagateTypesLeft, propagateTypesRight := b.checkTypesMatch( - leftScope, rightScope, - true, /* tolerateUnknownLeft */ - true, /* tolerateUnknownRight */ - typ.String(), + //// propagateTypesLeft/propagateTypesRight indicate whether we need to wrap + //// the left/right side in a projection to cast some of the columns to the + //// correct type. + //// For example: + //// SELECT NULL UNION SELECT 1 + //// The type of NULL is unknown, and the type of 1 is int. We need to + //// wrap the left side in a project operation with a Cast expression so the + //// output column will have the correct type. 
+ //propagateTypesLeft, propagateTypesRight := b.checkTypesMatch( + // leftScope, rightScope, + // true, /* tolerateUnknownLeft */ + // true, /* tolerateUnknownRight */ + // unionType.String(), + //) + + setOpTypes, leftCastsNeeded, rightCastsNeeded := b.typeCheckSetOp( + leftScope, rightScope, unionType.String(), ) - if propagateTypesLeft { - leftScope = b.propagateTypes(leftScope /* dst */, rightScope /* src */) + if leftCastsNeeded { + leftScope = b.addCasts(leftScope /* dst */, setOpTypes) } - if propagateTypesRight { - rightScope = b.propagateTypes(rightScope /* dst */, leftScope /* src */) + if rightCastsNeeded { + rightScope = b.addCasts(rightScope /* dst */, setOpTypes) } // For UNION, we have to synthesize new output columns (because they contain // values from both the left and right relations). This is not necessary for // INTERSECT or EXCEPT, since these operations are basically filters on the // left relation. - if typ == tree.UnionOp { + if unionType == tree.UnionOp { outScope.cols = make([]scopeColumn, 0, len(leftScope.cols)) for i := range leftScope.cols { c := &leftScope.cols[i] @@ -92,7 +96,7 @@ func (b *Builder) buildSetOp( private := memo.SetPrivate{LeftCols: leftCols, RightCols: rightCols, OutCols: newCols} if all { - switch typ { + switch unionType { case tree.UnionOp: outScope.expr = b.factory.ConstructUnionAll(left, right, &private) case tree.IntersectOp: @@ -101,7 +105,7 @@ func (b *Builder) buildSetOp( outScope.expr = b.factory.ConstructExceptAll(left, right, &private) } } else { - switch typ { + switch unionType { case tree.UnionOp: outScope.expr = b.factory.ConstructUnion(left, right, &private) case tree.IntersectOp: @@ -114,25 +118,17 @@ func (b *Builder) buildSetOp( return outScope } -// checkTypesMatch is used when the columns must match between two scopes (e.g. -// for a UNION). 
Throws an error if the scopes don't have the same number of -// columns, or when column types don't match 1-1, except: -// - if tolerateUnknownLeft is set and the left column has Unknown type while -// the right has a known type (in this case it returns propagateToLeft=true). -// - if tolerateUnknownRight is set and the right column has Unknown type while -// the right has a known type (in this case it returns propagateToRight=true). +// typeCheckSetOp cross-checks the types between the left and right sides of a +// set operation and determines the output types. Either side (or both) might +// need casts (as indicated in the return values). // -// clauseTag is used only in error messages. +// Throws an error if the scopes don't have the same number of columns, or when +// column types don't match 1-1 or can't be cast to a single output type. The +// error messages use clauseTag. // -// TODO(dan): This currently checks whether the types are exactly the same, -// but Postgres is more lenient: -// http://www.postgresql.org/docs/9.5/static/typeconv-union-case.html. -func (b *Builder) checkTypesMatch( - leftScope, rightScope *scope, - tolerateUnknownLeft bool, - tolerateUnknownRight bool, - clauseTag string, -) (propagateToLeft, propagateToRight bool) { +func (b *Builder) typeCheckSetOp( + leftScope, rightScope *scope, clauseTag string, +) (setOpTypes []*types.T, leftCastsNeeded, rightCastsNeeded bool) { if len(leftScope.cols) != len(rightScope.cols) { panic(pgerror.Newf( pgcode.Syntax, @@ -141,38 +137,46 @@ func (b *Builder) checkTypesMatch( )) } + setOpTypes = make([]*types.T, len(leftScope.cols)) for i := range leftScope.cols { l := &leftScope.cols[i] r := &rightScope.cols[i] - if l.typ.Equivalent(r.typ) { - continue + switch { + case l.typ.Identical(r.typ): + setOpTypes[i] = l.typ + + case l.typ.Equivalent(r.typ): + // Equivalent but not identical types. Use the type from the left and add + // a cast on the right-hand side. 
+ // TODO(radu): perhaps we should do a best effort attempt to choose the + // "wider type". + setOpTypes[i] = l.typ + rightCastsNeeded = true + + case l.typ.Family() == types.UnknownFamily: + setOpTypes[i] = r.typ + leftCastsNeeded = true + + case r.typ.Family() == types.UnknownFamily: + setOpTypes[i] = l.typ + rightCastsNeeded = true + + default: + // TODO(dan): Postgres is more lenient: + // http://www.postgresql.org/docs/9.5/static/typeconv-union-case.html + panic(pgerror.Newf( + pgcode.DatatypeMismatch, + "%v types %s and %s cannot be matched", clauseTag, l.typ, r.typ, + )) } - - // Note that Unknown types are equivalent so at this point at most one of - // the types can be Unknown. - if l.typ.Family() == types.UnknownFamily && tolerateUnknownLeft { - propagateToLeft = true - continue - } - if r.typ.Family() == types.UnknownFamily && tolerateUnknownRight { - propagateToRight = true - continue - } - - panic(pgerror.Newf( - pgcode.DatatypeMismatch, - "%v types %s and %s cannot be matched", clauseTag, l.typ, r.typ, - )) } - return propagateToLeft, propagateToRight + return setOpTypes, leftCastsNeeded, rightCastsNeeded } -// propagateTypes propagates the types of the source columns to the destination -// columns by wrapping the destination in a Project operation. The Project -// operation passes through columns that already have the correct type, and -// creates cast expressions for those that don't. -func (b *Builder) propagateTypes(dst, src *scope) *scope { +// addCasts adds a projection to a scope, adding casts as necessary so that the +// resulting columns have the given types. 
+func (b *Builder) addCasts(dst *scope, outTypes []*types.T) *scope { expr := dst.expr.(memo.RelExpr) dstCols := dst.cols @@ -180,12 +184,10 @@ func (b *Builder) propagateTypes(dst, src *scope) *scope { dst.cols = make([]scopeColumn, 0, len(dstCols)) for i := 0; i < len(dstCols); i++ { - dstType := dstCols[i].typ - srcType := src.cols[i].typ - if dstType.Family() == types.UnknownFamily && srcType.Family() != types.UnknownFamily { + if !dstCols[i].typ.Identical(outTypes[i]) { // Create a new column which casts the old column to the correct type. - castExpr := b.factory.ConstructCast(b.factory.ConstructVariable(dstCols[i].id), srcType) - b.synthesizeColumn(dst, string(dstCols[i].name), srcType, nil /* expr */, castExpr) + castExpr := b.factory.ConstructCast(b.factory.ConstructVariable(dstCols[i].id), outTypes[i]) + b.synthesizeColumn(dst, string(dstCols[i].name), outTypes[i], nil /* expr */, castExpr) } else { // The column is already the correct type, so add it as a passthrough // column. diff --git a/pkg/sql/opt/optbuilder/with.go b/pkg/sql/opt/optbuilder/with.go index e53e569c85f7..85fc88c2dc61 100644 --- a/pkg/sql/opt/optbuilder/with.go +++ b/pkg/sql/opt/optbuilder/with.go @@ -186,13 +186,17 @@ func (b *Builder) buildCTE( // We allow propagation of types from the initial query to the recursive // query. - _, propagateToRight := b.checkTypesMatch(initialScope, recursiveScope, - false, /* tolerateUnknownLeft */ - true, /* tolerateUnknownRight */ - "UNION", - ) - if propagateToRight { - recursiveScope = b.propagateTypes(recursiveScope /* dst */, initialScope /* src */) + outTypes, leftCastsNeeded, rightCastsNeeded := b.typeCheckSetOp(initialScope, recursiveScope, "UNION") + if leftCastsNeeded { + // This can only happen if a column has null type in the initial query + // and non-null type in the recursive query. 
+ panic(pgerror.Newf( + pgcode.DatatypeMismatch, + "unknown type in WITH RECURSIVE initial query cannot be matched with a definite type in the recursive query", + )) + } + if rightCastsNeeded { + recursiveScope = b.addCasts(recursiveScope, outTypes) } private := memo.RecursiveCTEPrivate{