-
Notifications
You must be signed in to change notification settings - Fork 3.8k
/
union.go
222 lines (198 loc) · 7.48 KB
/
union.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
// Copyright 2018 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
package optbuilder
import (
"github.com/cockroachdb/cockroach/pkg/sql/opt/memo"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode"
"github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror"
"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
"github.com/cockroachdb/cockroach/pkg/sql/types"
)
// buildUnionClause builds the memo expression for a set-operation clause
// (clause.Type combined with clause.All) by building the left and right
// statements and handing both resulting scopes to buildSetOp.
//
// When the caller supplies no desired types, the column types produced by the
// left side are used as the desired types for the right side, so that type
// hints flow left-to-right.
//
// See Builder.buildStmt for a description of the remaining input and
// return values.
func (b *Builder) buildUnionClause(
	clause *tree.UnionClause, desiredTypes []*types.T, inScope *scope,
) (outScope *scope) {
	leftScope := b.buildStmt(clause.Left, desiredTypes, inScope)

	// The left side always sees the caller's desired types; only the right
	// side falls back to the left side's column types.
	rightDesired := desiredTypes
	if len(rightDesired) == 0 {
		rightDesired = leftScope.makeColumnTypes()
	}
	rightScope := b.buildStmt(clause.Right, rightDesired, inScope)

	return b.buildSetOp(clause.Type, clause.All, inScope, leftScope, rightScope)
}
// buildSetOp builds the set operation given by unionType and all over the
// already-built left and right scopes, and returns a new scope pushed from
// inScope whose expr is the constructed set-operation expression.
//
// Hidden columns are removed from both inputs, the column types of the two
// sides are reconciled via typeCheckSetOp, and a casting projection is added
// (via addCasts) to whichever side needs one.
func (b *Builder) buildSetOp(
	unionType tree.UnionType, all bool, inScope, leftScope, rightScope *scope,
) (outScope *scope) {
	// Hidden columns never take part in a set operation.
	leftScope.removeHiddenCols()
	rightScope.removeHiddenCols()

	outScope = inScope.push()

	setOpTypes, leftCastsNeeded, rightCastsNeeded := b.typeCheckSetOp(
		leftScope, rightScope, unionType.String(),
	)
	if leftCastsNeeded {
		leftScope = b.addCasts(leftScope /* dst */, setOpTypes)
	}
	if rightCastsNeeded {
		rightScope = b.addCasts(rightScope /* dst */, setOpTypes)
	}

	// INTERSECT and EXCEPT are essentially filters on the left relation, so
	// they can normally reuse the left input's columns as their output.
	//
	// UNION cannot: its output holds values from both inputs, so fresh output
	// columns must be synthesized. The same is required when the left input
	// projects one column more than once, because duplicate column IDs in the
	// output would make the merge ordering for streaming operations too
	// complicated to represent (that ordering must cover every output column).
	if unionType != tree.UnionOp && leftScope.colSet().Len() >= len(leftScope.cols) {
		outScope.appendColumnsFromScope(leftScope)
	} else {
		outScope.cols = make([]scopeColumn, 0, len(leftScope.cols))
		for idx := range leftScope.cols {
			leftCol := &leftScope.cols[idx]
			b.synthesizeColumn(outScope, leftCol.name, leftCol.typ, nil, nil /* scalar */)
		}
	}

	// Record how left-side and right-side columns map onto the output columns.
	private := memo.SetPrivate{
		LeftCols:  colsToColList(leftScope.cols),
		RightCols: colsToColList(rightScope.cols),
		OutCols:   colsToColList(outScope.cols),
	}
	leftExpr, rightExpr := leftScope.expr, rightScope.expr

	// Pick the constructor matching the operator and its ALL flag. Any other
	// operator value leaves outScope.expr unset.
	switch unionType {
	case tree.UnionOp:
		if all {
			outScope.expr = b.factory.ConstructUnionAll(leftExpr, rightExpr, &private)
		} else {
			outScope.expr = b.factory.ConstructUnion(leftExpr, rightExpr, &private)
		}
	case tree.IntersectOp:
		if all {
			outScope.expr = b.factory.ConstructIntersectAll(leftExpr, rightExpr, &private)
		} else {
			outScope.expr = b.factory.ConstructIntersect(leftExpr, rightExpr, &private)
		}
	case tree.ExceptOp:
		if all {
			outScope.expr = b.factory.ConstructExceptAll(leftExpr, rightExpr, &private)
		} else {
			outScope.expr = b.factory.ConstructExcept(leftExpr, rightExpr, &private)
		}
	}

	return outScope
}
// typeCheckSetOp compares the column types of the two inputs of a set
// operation position by position and computes the output type of each column
// using determineUnionType.
//
// It returns the output types along with two flags reporting whether the left
// and/or right input has at least one column whose type is not identical to
// the corresponding output type (and therefore needs a cast).
//
// Panics with a pgcode.Syntax error if the two scopes have a different number
// of columns. determineUnionType panics if a pair of column types cannot be
// resolved to a single output type. Error messages include clauseTag.
func (b *Builder) typeCheckSetOp(
	leftScope, rightScope *scope, clauseTag string,
) (setOpTypes []*types.T, leftCastsNeeded, rightCastsNeeded bool) {
	numLeft, numRight := len(leftScope.cols), len(rightScope.cols)
	if numLeft != numRight {
		panic(pgerror.Newf(
			pgcode.Syntax,
			"each %s query must have the same number of columns: %d vs %d",
			clauseTag, numLeft, numRight,
		))
	}

	setOpTypes = make([]*types.T, numLeft)
	for pos := range setOpTypes {
		leftTyp := leftScope.cols[pos].typ
		rightTyp := rightScope.cols[pos].typ

		resolved := determineUnionType(leftTyp, rightTyp, clauseTag)
		setOpTypes[pos] = resolved

		// Once a side is known to need casts there is no need to keep
		// comparing its column types.
		if !leftCastsNeeded && !leftTyp.Identical(resolved) {
			leftCastsNeeded = true
		}
		if !rightCastsNeeded && !rightTyp.Identical(resolved) {
			rightCastsNeeded = true
		}
	}
	return setOpTypes, leftCastsNeeded, rightCastsNeeded
}
// determineUnionType returns the output type of a set-operation column whose
// left and right inputs have the given types.
//
// Resolution rules, applied in order:
//   - identical types resolve to the left type;
//   - equivalent types resolve to whichever has the strictly larger width,
//     with the left type winning ties;
//   - if one side is of the unknown family, the other side's type is used;
//   - int and float resolve to the float type;
//   - int or float paired with decimal resolves to the decimal type.
//
// Panics with a pgcode.DatatypeMismatch error (mentioning clauseTag) when none
// of the rules apply.
func determineUnionType(left, right *types.T, clauseTag string) *types.T {
	if left.Identical(right) {
		return left
	}

	if left.Equivalent(right) {
		// Best-effort guess at the "larger" type: prefer the right type only
		// when it is strictly wider; otherwise stick with the left type.
		if leftWidth, rightWidth := left.Width(), right.Width(); leftWidth < rightWidth {
			return right
		}
		return left
	}

	leftFam, rightFam := left.Family(), right.Family()
	isIntOrFloat := func(fam types.Family) bool {
		return fam == types.IntFamily || fam == types.FloatFamily
	}

	switch {
	case rightFam == types.UnknownFamily:
		return left

	case leftFam == types.UnknownFamily:
		return right

	// Implicit upcast from int to float. This conversion can be lossy
	// (especially INT8 to FLOAT4), but this is what Postgres does.
	case leftFam == types.FloatFamily && rightFam == types.IntFamily:
		return left
	case leftFam == types.IntFamily && rightFam == types.FloatFamily:
		return right

	// Implicit upcasts to decimal.
	case leftFam == types.DecimalFamily && isIntOrFloat(rightFam):
		return left
	case isIntOrFloat(leftFam) && rightFam == types.DecimalFamily:
		return right
	}

	// TODO(radu): Postgres has more encompassing rules:
	// http://www.postgresql.org/docs/12/static/typeconv-union-case.html
	panic(pgerror.Newf(
		pgcode.DatatypeMismatch,
		"%v types %s and %s cannot be matched", clauseTag, left, right,
	))
}
// addCasts wraps the expression of dst in a projection whose columns have the
// types listed in outTypes (matched to dst's columns by position), and returns
// the new scope pushed from dst that holds that projection.
//
// Columns whose type is already identical to the requested type are passed
// through unchanged; every other column is replaced by a newly synthesized
// column that casts the original column to the requested type.
func (b *Builder) addCasts(dst *scope, outTypes []*types.T) *scope {
	inputExpr := dst.expr
	inputCols := dst.cols

	projScope := dst.push()
	projScope.cols = make([]scopeColumn, 0, len(inputCols))

	for pos := range inputCols {
		col := &inputCols[pos]
		wantTyp := outTypes[pos]

		if col.typ.Identical(wantTyp) {
			// Already the right type: keep it as a passthrough column.
			projScope.appendColumn(col)
			continue
		}

		// Wrong type: project a new column that casts the old one.
		variable := b.factory.ConstructVariable(col.id)
		castExpr := b.factory.ConstructCast(variable, wantTyp)
		b.synthesizeColumn(projScope, col.name, wantTyp, nil /* expr */, castExpr)
	}

	projScope.expr = b.constructProject(inputExpr, projScope.cols)
	return projScope
}