Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

release-21.2: optbuilder: do not create invalid casts when building COALESCE and IF #78342

Merged
merged 1 commit into from
Mar 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pkg/sql/opt/exec/execbuilder/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,11 @@ func (b *Builder) buildNull(ctx *buildScalarCtx, scalar opt.ScalarExpr) (tree.Ty
// See comment in buildCast.
return tree.DNull, nil
}
return tree.ReType(tree.DNull, scalar.DataType()), nil
retypedNull, ok := tree.ReType(tree.DNull, scalar.DataType())
if !ok {
return nil, errors.AssertionFailedf("failed to retype NULL to %s", scalar.DataType())
}
return retypedNull, nil
}

func (b *Builder) buildVariable(
Expand Down
94 changes: 60 additions & 34 deletions pkg/sql/opt/optbuilder/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ func (b *Builder) buildScalar(
return b.finishBuildScalarRef(t.col, inScope, outScope, outCol, colRefs)

case *tree.AndExpr:
left := b.buildScalar(tree.ReType(t.TypedLeft(), types.Bool), inScope, nil, nil, colRefs)
right := b.buildScalar(tree.ReType(t.TypedRight(), types.Bool), inScope, nil, nil, colRefs)
left := b.buildScalar(reType(t.TypedLeft(), types.Bool), inScope, nil, nil, colRefs)
right := b.buildScalar(reType(t.TypedRight(), types.Bool), inScope, nil, nil, colRefs)
out = b.factory.ConstructAnd(left, right)

case *tree.Array:
Expand Down Expand Up @@ -214,8 +214,8 @@ func (b *Builder) buildScalar(
// select the right overload. The solution is to wrap any mismatched
// arguments with a CastExpr that preserves the static type.

left := tree.ReType(t.TypedLeft(), t.ResolvedBinOp().LeftType)
right := tree.ReType(t.TypedRight(), t.ResolvedBinOp().RightType)
left := reType(t.TypedLeft(), t.ResolvedBinOp().LeftType)
right := reType(t.TypedRight(), t.ResolvedBinOp().RightType)
out = b.constructBinary(
tree.MakeBinaryOperator(t.Operator.Symbol),
b.buildScalar(left, inScope, nil, nil, colRefs),
Expand All @@ -233,39 +233,32 @@ func (b *Builder) buildScalar(
input = memo.TrueSingleton
}

// validateCastToValType panics if tree.ReType with the given source
// type would create an invalid cast to valType.
validateCastToValType := func(src *types.T) {
if valType.Family() == types.AnyFamily || src.Identical(valType) {
// If valType's family is AnyFamily or src is identical to
// valType, then tree.Retype will not create a cast expression.
return
}
if _, ok := tree.LookupCastVolatility(src, valType, nil /* sessionData */); ok {
return
}
panic(pgerror.Newf(
pgcode.DatatypeMismatch,
"CASE types %s and %s cannot be matched", src, valType,
))
}

whens := make(memo.ScalarListExpr, 0, len(t.Whens)+1)
for i := range t.Whens {
condExpr := t.Whens[i].Cond.(tree.TypedExpr)
cond := b.buildScalar(condExpr, inScope, nil, nil, colRefs)
valExpr := t.Whens[i].Val.(tree.TypedExpr)
validateCastToValType(valExpr.ResolvedType())
valExpr = tree.ReType(valExpr, valType)
valExpr, ok := tree.ReType(t.Whens[i].Val.(tree.TypedExpr), valType)
if !ok {
panic(pgerror.Newf(
pgcode.DatatypeMismatch,
"CASE WHEN types %s and %s cannot be matched",
t.Whens[i].Val.(tree.TypedExpr).ResolvedType(), valType,
))
}
val := b.buildScalar(valExpr, inScope, nil, nil, colRefs)
whens = append(whens, b.factory.ConstructWhen(cond, val))
}
// Add the ELSE expression to the end of whens as a raw scalar expression.
var orElse opt.ScalarExpr
if t.Else != nil {
elseExpr := t.Else.(tree.TypedExpr)
validateCastToValType(elseExpr.ResolvedType())
elseExpr = tree.ReType(elseExpr, valType)
elseExpr, ok := tree.ReType(t.Else.(tree.TypedExpr), valType)
if !ok {
panic(pgerror.Newf(
pgcode.DatatypeMismatch,
"CASE ELSE type %s cannot be matched to WHEN type %s",
t.Else.(tree.TypedExpr).ResolvedType(), valType,
))
}
orElse = b.buildScalar(elseExpr, inScope, nil, nil, colRefs)
} else {
orElse = b.factory.ConstructNull(valType)
Expand All @@ -284,7 +277,14 @@ func (b *Builder) buildScalar(
// The type of the CoalesceExpr might be different than the inputs (e.g.
// when they are NULL). Force all inputs to be the same type, so that we
// build coalesce operator with the correct type.
expr := tree.ReType(t.TypedExprAt(i), typ)
expr, ok := tree.ReType(t.TypedExprAt(i), typ)
if !ok {
panic(pgerror.Newf(
pgcode.DatatypeMismatch,
"COALESCE types %s and %s cannot be matched",
t.TypedExprAt(i).ResolvedType(), typ,
))
}
args[i] = b.buildScalar(expr, inScope, nil, nil, colRefs)
}
out = b.factory.ConstructCoalesce(args)
Expand Down Expand Up @@ -322,10 +322,19 @@ func (b *Builder) buildScalar(
case *tree.IfExpr:
valType := t.ResolvedType()
input := b.buildScalar(t.Cond.(tree.TypedExpr), inScope, nil, nil, colRefs)
ifTrueExpr := tree.ReType(t.True.(tree.TypedExpr), valType)
// Re-typing the True expression should always succeed because they
// are given the same type during type-checking.
ifTrueExpr := reType(t.True.(tree.TypedExpr), valType)
ifTrue := b.buildScalar(ifTrueExpr, inScope, nil, nil, colRefs)
whens := memo.ScalarListExpr{b.factory.ConstructWhen(memo.TrueSingleton, ifTrue)}
orElseExpr := tree.ReType(t.Else.(tree.TypedExpr), valType)
orElseExpr, ok := tree.ReType(t.Else.(tree.TypedExpr), valType)
if !ok {
panic(pgerror.Newf(
pgcode.DatatypeMismatch,
"IF types %s and %s cannot be matched",
t.Else.(tree.TypedExpr).ResolvedType(), valType,
))
}
orElse := b.buildScalar(orElseExpr, inScope, nil, nil, colRefs)
out = b.factory.ConstructCase(input, whens, orElse)

Expand All @@ -337,7 +346,7 @@ func (b *Builder) buildScalar(
out = b.factory.ConstructVariable(inScope.cols[t.Idx].id)

case *tree.NotExpr:
input := b.buildScalar(tree.ReType(t.TypedInnerExpr(), types.Bool), inScope, nil, nil, colRefs)
input := b.buildScalar(reType(t.TypedInnerExpr(), types.Bool), inScope, nil, nil, colRefs)
out = b.factory.ConstructNot(input)

case *tree.IsNullExpr:
Expand All @@ -362,7 +371,7 @@ func (b *Builder) buildScalar(
// of the NULLIF expression so that type inference will be correct in the
// CASE expression constructed below. For example, the type of
// NULLIF(NULL, 0) should be int.
expr1 := tree.ReType(t.Expr1.(tree.TypedExpr), valType)
expr1 := reType(t.Expr1.(tree.TypedExpr), valType)
input := b.buildScalar(expr1, inScope, nil, nil, colRefs)
cond := b.buildScalar(t.Expr2.(tree.TypedExpr), inScope, nil, nil, colRefs)
whens := memo.ScalarListExpr{
Expand All @@ -371,8 +380,8 @@ func (b *Builder) buildScalar(
out = b.factory.ConstructCase(input, whens, input)

case *tree.OrExpr:
left := b.buildScalar(tree.ReType(t.TypedLeft(), types.Bool), inScope, nil, nil, colRefs)
right := b.buildScalar(tree.ReType(t.TypedRight(), types.Bool), inScope, nil, nil, colRefs)
left := b.buildScalar(reType(t.TypedLeft(), types.Bool), inScope, nil, nil, colRefs)
right := b.buildScalar(reType(t.TypedRight(), types.Bool), inScope, nil, nil, colRefs)
out = b.factory.ConstructOr(left, right)

case *tree.ParenExpr:
Expand Down Expand Up @@ -869,3 +878,20 @@ func (sb *ScalarBuilder) Build(expr tree.Expr) (err error) {
sb.factory.Memo().SetScalarRoot(scalar)
return nil
}

// reType is similar to tree.ReType, except that it panics with an internal
// error if the expression cannot be re-typed. This should only be used when
// re-typing is expected to always be successful. For example, it is used to
// re-type the left and right children of an OrExpr to booleans, which should
// always succeed during the optbuild phase because type-checking has already
// validated the types of the children.
func reType(expr tree.TypedExpr, typ *types.T) tree.TypedExpr {
retypedExpr, ok := tree.ReType(expr, typ)
if !ok {
panic(errors.AssertionFailedf(
"expected successful retype from %s to %s",
expr.ResolvedType(), typ,
))
}
return retypedExpr
}
2 changes: 1 addition & 1 deletion pkg/sql/opt/optbuilder/scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ func (s *scope) resolveAndRequireType(expr tree.Expr, desired *types.T) tree.Typ
if err != nil {
panic(err)
}
return tree.ReType(s.ensureNullType(texpr, desired), desired)
return reType(s.ensureNullType(texpr, desired), desired)
}

// ensureNullType tests the type of the given expression. If types.Unknown, then
Expand Down
42 changes: 40 additions & 2 deletions pkg/sql/opt/optbuilder/testdata/scalar
Original file line number Diff line number Diff line change
Expand Up @@ -1470,9 +1470,47 @@ is [type=bool]
build
SELECT CASE WHEN false THEN ARRAY[('', 0)] ELSE ARRAY[]::RECORD[] END
----
error (42804): CASE types tuple[] and tuple{string, int}[] cannot be matched
error (42804): CASE ELSE type tuple[] cannot be matched to WHEN type tuple{string, int}[]

build
SELECT CASE WHEN false THEN ARRAY[('', 0)] WHEN true THEN ARRAY[]::RECORD[] ELSE ARRAY[('', 0)] END
----
error (42804): CASE types tuple[] and tuple{string, int}[] cannot be matched
error (42804): CASE WHEN types tuple[] and tuple{string, int}[] cannot be matched

# Regression test for #76807. Do not create invalid casts when building COALESCE
# and IF expressions.
build
SELECT COALESCE(t.v, ARRAY[]:::RECORD[])
FROM (VALUES (ARRAY[(1, 'foo')])) AS t(v)
----
error (42804): COALESCE types tuple[] and tuple{int, string}[] cannot be matched

build
SELECT COALESCE(ARRAY[]:::RECORD[], t.v)
FROM (VALUES (ARRAY[(1, 'foo')])) AS t(v)
----
project
├── columns: coalesce:2
├── values
│ ├── columns: column1:1
│ └── (ARRAY[(1, 'foo')],)
└── projections
└── COALESCE(ARRAY[], column1:1::RECORD[]) [as=coalesce:2]

build
SELECT IF(true, t.v, ARRAY[]:::RECORD[])
FROM (VALUES (ARRAY[(1, 'foo')])) AS t(v)
----
error (42804): IF types tuple[] and tuple{int, string}[] cannot be matched

build
SELECT IF(true, ARRAY[]:::RECORD[], t.v)
FROM (VALUES (ARRAY[(1, 'foo')])) AS t(v)
----
project
├── columns: if:2
├── values
│ ├── columns: column1:1
│ └── (ARRAY[(1, 'foo')],)
└── projections
└── CASE WHEN true THEN ARRAY[] ELSE column1:1::RECORD[] END [as=if:2]
2 changes: 1 addition & 1 deletion pkg/sql/opt/optbuilder/window.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ func (b *Builder) getTypedWindowArgs(w *windowInfo) []tree.TypedExpr {
argExprs = append(argExprs, tree.NewDInt(1))
}
if len(argExprs) < 3 {
null := tree.ReType(tree.DNull, argExprs[0].ResolvedType())
null := reType(tree.DNull, argExprs[0].ResolvedType())
argExprs = append(argExprs, null)
}
}
Expand Down
9 changes: 8 additions & 1 deletion pkg/sql/sem/tree/constant_eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

package tree

import "github.com/cockroachdb/errors"

// ConstantEvalVisitor replaces constant TypedExprs with the result of Eval.
type ConstantEvalVisitor struct {
ctx *EvalContext
Expand Down Expand Up @@ -58,7 +60,12 @@ func (v *ConstantEvalVisitor) VisitPost(expr Expr) Expr {
if value == DNull {
// We don't want to return an expression that has a different type; cast
// the NULL if necessary.
return ReType(DNull, typedExpr.ResolvedType())
retypedNull, ok := ReType(DNull, typedExpr.ResolvedType())
if !ok {
v.err = errors.AssertionFailedf("failed to retype NULL to %s", typedExpr.ResolvedType())
return expr
}
return retypedNull
}
return value
}
Expand Down
36 changes: 24 additions & 12 deletions pkg/sql/sem/tree/normalize.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,37 +134,39 @@ func (expr *BinaryExpr) normalize(v *NormalizeVisitor) TypedExpr {
switch expr.Operator.Symbol {
case Plus:
if v.isNumericZero(right) {
final = ReType(left, expectedType)
final, _ = ReType(left, expectedType)
break
}
if v.isNumericZero(left) {
final = ReType(right, expectedType)
final, _ = ReType(right, expectedType)
break
}
case Minus:
if types.IsAdditiveType(left.ResolvedType()) && v.isNumericZero(right) {
final = ReType(left, expectedType)
final, _ = ReType(left, expectedType)
break
}
case Mult:
if v.isNumericOne(right) {
final = ReType(left, expectedType)
final, _ = ReType(left, expectedType)
break
}
if v.isNumericOne(left) {
final = ReType(right, expectedType)
final, _ = ReType(right, expectedType)
break
}
// We can't simplify multiplication by zero to zero,
// because if the other operand is NULL during evaluation
// the result must be NULL.
case Div, FloorDiv:
if v.isNumericOne(right) {
final = ReType(left, expectedType)
final, _ = ReType(left, expectedType)
break
}
}

// final is nil when the binary expression did not match the cases above,
// or when ReType was unsuccessful.
if final == nil {
return expr
}
Expand Down Expand Up @@ -759,7 +761,12 @@ func (v *NormalizeVisitor) VisitPost(expr Expr) Expr {
if value == DNull {
// We don't want to return an expression that has a different type; cast
// the NULL if necessary.
return ReType(DNull, expr.(TypedExpr).ResolvedType())
retypedNull, ok := ReType(DNull, expr.(TypedExpr).ResolvedType())
if !ok {
v.err = errors.AssertionFailedf("failed to retype NULL to %s", expr.(TypedExpr).ResolvedType())
return expr
}
return retypedNull
}
return value
}
Expand Down Expand Up @@ -991,14 +998,19 @@ func init() {
DecimalOne.SetInt64(1)
}

// ReType ensures that the given expression evaluates
// to the requested type, inserting a cast if necessary.
func ReType(expr TypedExpr, wantedType *types.T) TypedExpr {
// ReType ensures that the given expression evaluates to the requested type,
// wrapping the expression in a cast if necessary. Returns ok=false if a cast
// cannot wrap the expression because no valid cast from the expression's type
// to the wanted type exists.
func ReType(expr TypedExpr, wantedType *types.T) (_ TypedExpr, ok bool) {
resolvedType := expr.ResolvedType()
if wantedType.Family() == types.AnyFamily || resolvedType.Identical(wantedType) {
return expr
return expr, true
}
if _, ok := LookupCastVolatility(resolvedType, wantedType, nil /* sessionData */); !ok {
return nil, false
}
res := &CastExpr{Expr: expr, Type: wantedType}
res.typ = wantedType
return res
return res, true
}