Skip to content

Commit

Permalink
opt: cast to identical types for set operations
Browse files Browse the repository at this point in the history
This change makes the optbuilder more strict when building set
operations. Previously, it could build expressions which have
corresponding left/right types which are `Equivalent()`, but not
`Identical()`. This leads to errors in vectorized execution, when we
e.g. try to union a INT8 with an INT4.

We now make the types on both sides `Identical()`, adding casts as
necessary. We try to do a best-effort attempt to use the larger
numeric type when possible (e.g. int4->int8, int->float, float->decimal).

Fixes cockroachdb#59148.

Release note (bug fix): fixed execution errors for some queries that
use set operations (UNION / EXCEPT / INTERSECT) where a column has
types of different widths on the two sides (e.g. INT4 vs INT8).
  • Loading branch information
RaduBerinde committed Feb 17, 2021
1 parent 673a257 commit 6195a7f
Show file tree
Hide file tree
Showing 6 changed files with 347 additions and 71 deletions.
18 changes: 18 additions & 0 deletions pkg/sql/logictest/testdata/logic_test/union
Original file line number Diff line number Diff line change
Expand Up @@ -348,3 +348,21 @@ NULL
statement ok
CREATE TABLE ab (a INT, b INT);
SELECT a, b, rowid FROM ab UNION VALUES (1, 2, 3);
DROP TABLE ab;

# Regression test for #59148.
statement ok
CREATE TABLE ab (a INT4, b INT8);
INSERT INTO ab VALUES (1, 1), (1, 2), (2, 1), (2, 2);

query I rowsort
SELECT a FROM ab UNION SELECT b FROM ab
----
1
2

query I rowsort
SELECT b FROM ab UNION SELECT a FROM ab
----
1
2
125 changes: 125 additions & 0 deletions pkg/sql/opt/optbuilder/testdata/union
Original file line number Diff line number Diff line change
Expand Up @@ -874,3 +874,128 @@ except
│ │ └── 1 [as="?column?":6]
│ └── 1
└── (1,)

# Verify that we add casts for equivalent, but not identical types.
exec-ddl
CREATE TABLE ab (i8 INT8, i4 INT4, f8 FLOAT, f4 FLOAT4, d DECIMAL)
----

build
SELECT i4 FROM ab UNION SELECT i8 FROM ab
----
union
├── columns: i4:16
├── left columns: i4:15
├── right columns: i8:8
├── project
│ ├── columns: i4:15
│ ├── project
│ │ ├── columns: ab.i4:2
│ │ └── scan ab
│ │ └── columns: i8:1 ab.i4:2 f8:3 f4:4 d:5 rowid:6!null crdb_internal_mvcc_timestamp:7
│ └── projections
│ └── ab.i4:2::INT8 [as=i4:15]
└── project
├── columns: i8:8
└── scan ab
└── columns: i8:8 ab.i4:9 f8:10 f4:11 d:12 rowid:13!null crdb_internal_mvcc_timestamp:14

build
SELECT i8 FROM ab UNION SELECT i4 FROM ab
----
union
├── columns: i8:16
├── left columns: ab.i8:1
├── right columns: i4:15
├── project
│ ├── columns: ab.i8:1
│ └── scan ab
│ └── columns: ab.i8:1 ab.i4:2 f8:3 f4:4 d:5 rowid:6!null crdb_internal_mvcc_timestamp:7
└── project
├── columns: i4:15
├── project
│ ├── columns: ab.i4:9
│ └── scan ab
│ └── columns: ab.i8:8 ab.i4:9 f8:10 f4:11 d:12 rowid:13!null crdb_internal_mvcc_timestamp:14
└── projections
└── ab.i4:9::INT8 [as=i4:15]

build
SELECT f4 FROM ab UNION SELECT f8 FROM ab
----
union
├── columns: f4:16
├── left columns: f4:15
├── right columns: f8:10
├── project
│ ├── columns: f4:15
│ ├── project
│ │ ├── columns: ab.f4:4
│ │ └── scan ab
│ │ └── columns: i8:1 i4:2 f8:3 ab.f4:4 d:5 rowid:6!null crdb_internal_mvcc_timestamp:7
│ └── projections
│ └── ab.f4:4::FLOAT8 [as=f4:15]
└── project
├── columns: f8:10
└── scan ab
└── columns: i8:8 i4:9 f8:10 ab.f4:11 d:12 rowid:13!null crdb_internal_mvcc_timestamp:14

build
SELECT i8 FROM ab UNION SELECT f4 FROM ab
----
union
├── columns: i8:16
├── left columns: i8:15
├── right columns: f4:11
├── project
│ ├── columns: i8:15
│ ├── project
│ │ ├── columns: ab.i8:1
│ │ └── scan ab
│ │ └── columns: ab.i8:1 i4:2 f8:3 f4:4 d:5 rowid:6!null crdb_internal_mvcc_timestamp:7
│ └── projections
│ └── ab.i8:1::FLOAT4 [as=i8:15]
└── project
├── columns: f4:11
└── scan ab
└── columns: ab.i8:8 i4:9 f8:10 f4:11 d:12 rowid:13!null crdb_internal_mvcc_timestamp:14

build
SELECT i8 FROM ab UNION SELECT d FROM ab
----
union
├── columns: i8:16
├── left columns: i8:15
├── right columns: d:12
├── project
│ ├── columns: i8:15
│ ├── project
│ │ ├── columns: ab.i8:1
│ │ └── scan ab
│ │ └── columns: ab.i8:1 i4:2 f8:3 f4:4 d:5 rowid:6!null crdb_internal_mvcc_timestamp:7
│ └── projections
│ └── ab.i8:1::DECIMAL [as=i8:15]
└── project
├── columns: d:12
└── scan ab
└── columns: ab.i8:8 i4:9 f8:10 f4:11 d:12 rowid:13!null crdb_internal_mvcc_timestamp:14

build
SELECT d FROM ab UNION SELECT f8 FROM ab
----
union
├── columns: d:16
├── left columns: ab.d:5
├── right columns: f8:15
├── project
│ ├── columns: ab.d:5
│ └── scan ab
│ └── columns: i8:1 i4:2 ab.f8:3 f4:4 ab.d:5 rowid:6!null crdb_internal_mvcc_timestamp:7
└── project
├── columns: f8:15
├── project
│ ├── columns: ab.f8:10
│ └── scan ab
│ └── columns: i8:8 i4:9 ab.f8:10 f4:11 ab.d:12 rowid:13!null crdb_internal_mvcc_timestamp:14
└── projections
└── ab.f8:10::DECIMAL [as=f8:15]
6 changes: 6 additions & 0 deletions pkg/sql/opt/optbuilder/testdata/with
Original file line number Diff line number Diff line change
Expand Up @@ -1327,6 +1327,12 @@ with &3 (cte)
└── mapping:
└── a:5 => a:8

# We don't support upcasting the "initial" query.
build
WITH RECURSIVE cte(x) AS (SELECT a FROM x UNION ALL SELECT x::FLOAT FROM cte WHERE x < 10) SELECT * FROM cte;
----
error (42804): UNION types int and float cannot be matched for WITH RECURSIVE

# Mutating WITHs not allowed at non-root positions.
build
SELECT * FROM (WITH foo AS (INSERT INTO y VALUES (1) RETURNING *) SELECT * FROM foo)
Expand Down
148 changes: 84 additions & 64 deletions pkg/sql/opt/optbuilder/union.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,41 +37,30 @@ func (b *Builder) buildUnionClause(
}

func (b *Builder) buildSetOp(
typ tree.UnionType, all bool, inScope, leftScope, rightScope *scope,
unionType tree.UnionType, all bool, inScope, leftScope, rightScope *scope,
) (outScope *scope) {
// Remove any hidden columns, as they are not included in the Union.
leftScope.removeHiddenCols()
rightScope.removeHiddenCols()

outScope = inScope.push()

// propagateTypesLeft/propagateTypesRight indicate whether we need to wrap
// the left/right side in a projection to cast some of the columns to the
// correct type.
// For example:
// SELECT NULL UNION SELECT 1
// The type of NULL is unknown, and the type of 1 is int. We need to
// wrap the left side in a project operation with a Cast expression so the
// output column will have the correct type.
propagateTypesLeft, propagateTypesRight := b.checkTypesMatch(
leftScope, rightScope,
true, /* tolerateUnknownLeft */
true, /* tolerateUnknownRight */
typ.String(),
setOpTypes, leftCastsNeeded, rightCastsNeeded := b.typeCheckSetOp(
leftScope, rightScope, unionType.String(),
)

if propagateTypesLeft {
leftScope = b.propagateTypes(leftScope /* dst */, rightScope /* src */)
if leftCastsNeeded {
leftScope = b.addCasts(leftScope /* dst */, setOpTypes)
}
if propagateTypesRight {
rightScope = b.propagateTypes(rightScope /* dst */, leftScope /* src */)
if rightCastsNeeded {
rightScope = b.addCasts(rightScope /* dst */, setOpTypes)
}

// For UNION, we have to synthesize new output columns (because they contain
// values from both the left and right relations). This is not necessary for
// INTERSECT or EXCEPT, since these operations are basically filters on the
// left relation.
if typ == tree.UnionOp {
if unionType == tree.UnionOp {
outScope.cols = make([]scopeColumn, 0, len(leftScope.cols))
for i := range leftScope.cols {
c := &leftScope.cols[i]
Expand All @@ -92,7 +81,7 @@ func (b *Builder) buildSetOp(
private := memo.SetPrivate{LeftCols: leftCols, RightCols: rightCols, OutCols: newCols}

if all {
switch typ {
switch unionType {
case tree.UnionOp:
outScope.expr = b.factory.ConstructUnionAll(left, right, &private)
case tree.IntersectOp:
Expand All @@ -101,7 +90,7 @@ func (b *Builder) buildSetOp(
outScope.expr = b.factory.ConstructExceptAll(left, right, &private)
}
} else {
switch typ {
switch unionType {
case tree.UnionOp:
outScope.expr = b.factory.ConstructUnion(left, right, &private)
case tree.IntersectOp:
Expand All @@ -114,25 +103,17 @@ func (b *Builder) buildSetOp(
return outScope
}

// checkTypesMatch is used when the columns must match between two scopes (e.g.
// for a UNION). Throws an error if the scopes don't have the same number of
// columns, or when column types don't match 1-1, except:
// - if tolerateUnknownLeft is set and the left column has Unknown type while
// the right has a known type (in this case it returns propagateToLeft=true).
// - if tolerateUnknownRight is set and the right column has Unknown type while
// the right has a known type (in this case it returns propagateToRight=true).
// typeCheckSetOp cross-checks the types between the left and right sides of a
// set operation and determines the output types. Either side (or both) might
// need casts (as indicated in the return values).
//
// clauseTag is used only in error messages.
// Throws an error if the scopes don't have the same number of columns, or when
// column types don't match 1-1 or can't be cast to a single output type. The
// error messages use clauseTag.
//
// TODO(dan): This currently checks whether the types are exactly the same,
// but Postgres is more lenient:
// http://www.postgresql.org/docs/9.5/static/typeconv-union-case.html.
func (b *Builder) checkTypesMatch(
leftScope, rightScope *scope,
tolerateUnknownLeft bool,
tolerateUnknownRight bool,
clauseTag string,
) (propagateToLeft, propagateToRight bool) {
func (b *Builder) typeCheckSetOp(
leftScope, rightScope *scope, clauseTag string,
) (setOpTypes []*types.T, leftCastsNeeded, rightCastsNeeded bool) {
if len(leftScope.cols) != len(rightScope.cols) {
panic(pgerror.Newf(
pgcode.Syntax,
Expand All @@ -141,51 +122,90 @@ func (b *Builder) checkTypesMatch(
))
}

setOpTypes = make([]*types.T, len(leftScope.cols))
for i := range leftScope.cols {
l := &leftScope.cols[i]
r := &rightScope.cols[i]

if l.typ.Equivalent(r.typ) {
continue
}
typ := determineUnionType(l.typ, r.typ, clauseTag)
setOpTypes[i] = typ
leftCastsNeeded = leftCastsNeeded || !l.typ.Identical(typ)
rightCastsNeeded = rightCastsNeeded || !r.typ.Identical(typ)
}
return setOpTypes, leftCastsNeeded, rightCastsNeeded
}

// Note that Unknown types are equivalent so at this point at most one of
// the types can be Unknown.
if l.typ.Family() == types.UnknownFamily && tolerateUnknownLeft {
propagateToLeft = true
continue
// determineUnionType determines the resulting type of a set operation on a
// column with the given left and right types.
//
// We allow implicit up-casts between types of the same numeric family with
// different widths; between int and float; and between int/float and decimal.
//
// Throws an error if we don't support a set operation between the two types.
func determineUnionType(left, right *types.T, clauseTag string) *types.T {
if left.Identical(right) {
return left
}

if left.Equivalent(right) {
// Do a best-effort attempt to determine which type is "larger".
if left.Width() > right.Width() {
return left
}
if r.typ.Family() == types.UnknownFamily && tolerateUnknownRight {
propagateToRight = true
continue
if left.Width() < right.Width() {
return right
}
// In other cases, use the left type.
return left
}
leftFam, rightFam := left.Family(), right.Family()

panic(pgerror.Newf(
pgcode.DatatypeMismatch,
"%v types %s and %s cannot be matched", clauseTag, l.typ, r.typ,
))
if rightFam == types.UnknownFamily {
return left
}
if leftFam == types.UnknownFamily {
return right
}

// Allow implicit upcast from int to float. Converting an int to float can be
// lossy (especially INT8 to FLOAT4), but this is what Postgres does.
if leftFam == types.FloatFamily && rightFam == types.IntFamily {
return left
}
if leftFam == types.IntFamily && rightFam == types.FloatFamily {
return right
}
return propagateToLeft, propagateToRight

// Allow implicit upcasts to decimal.
if leftFam == types.DecimalFamily && (rightFam == types.IntFamily || rightFam == types.FloatFamily) {
return left
}
if (leftFam == types.IntFamily || leftFam == types.FloatFamily) && rightFam == types.DecimalFamily {
return right
}

// TODO(radu): Postgres has more encompassing rules:
// http://www.postgresql.org/docs/12/static/typeconv-union-case.html
panic(pgerror.Newf(
pgcode.DatatypeMismatch,
"%v types %s and %s cannot be matched", clauseTag, left, right,
))
}

// propagateTypes propagates the types of the source columns to the destination
// columns by wrapping the destination in a Project operation. The Project
// operation passes through columns that already have the correct type, and
// creates cast expressions for those that don't.
func (b *Builder) propagateTypes(dst, src *scope) *scope {
// addCasts adds a projection to a scope, adding casts as necessary so that the
// resulting columns have the given types.
func (b *Builder) addCasts(dst *scope, outTypes []*types.T) *scope {
expr := dst.expr.(memo.RelExpr)
dstCols := dst.cols

dst = dst.push()
dst.cols = make([]scopeColumn, 0, len(dstCols))

for i := 0; i < len(dstCols); i++ {
dstType := dstCols[i].typ
srcType := src.cols[i].typ
if dstType.Family() == types.UnknownFamily && srcType.Family() != types.UnknownFamily {
if !dstCols[i].typ.Identical(outTypes[i]) {
// Create a new column which casts the old column to the correct type.
castExpr := b.factory.ConstructCast(b.factory.ConstructVariable(dstCols[i].id), srcType)
b.synthesizeColumn(dst, string(dstCols[i].name), srcType, nil /* expr */, castExpr)
castExpr := b.factory.ConstructCast(b.factory.ConstructVariable(dstCols[i].id), outTypes[i])
b.synthesizeColumn(dst, string(dstCols[i].name), outTypes[i], nil /* expr */, castExpr)
} else {
// The column is already the correct type, so add it as a passthrough
// column.
Expand Down
Loading

0 comments on commit 6195a7f

Please sign in to comment.