Skip to content

Commit

Permalink
Merge pull request #78342 from mgartner/backport21.2-77608
Browse files Browse the repository at this point in the history
release-21.2: optbuilder: do not create invalid casts when building COALESCE and IF
  • Loading branch information
mgartner authored Mar 23, 2022
2 parents 596b352 + bb4e381 commit 2a28bdd
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 52 deletions.
6 changes: 5 additions & 1 deletion pkg/sql/opt/exec/execbuilder/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,11 @@ func (b *Builder) buildNull(ctx *buildScalarCtx, scalar opt.ScalarExpr) (tree.Ty
// See comment in buildCast.
return tree.DNull, nil
}
return tree.ReType(tree.DNull, scalar.DataType()), nil
retypedNull, ok := tree.ReType(tree.DNull, scalar.DataType())
if !ok {
return nil, errors.AssertionFailedf("failed to retype NULL to %s", scalar.DataType())
}
return retypedNull, nil
}

func (b *Builder) buildVariable(
Expand Down
94 changes: 60 additions & 34 deletions pkg/sql/opt/optbuilder/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ func (b *Builder) buildScalar(
return b.finishBuildScalarRef(t.col, inScope, outScope, outCol, colRefs)

case *tree.AndExpr:
left := b.buildScalar(tree.ReType(t.TypedLeft(), types.Bool), inScope, nil, nil, colRefs)
right := b.buildScalar(tree.ReType(t.TypedRight(), types.Bool), inScope, nil, nil, colRefs)
left := b.buildScalar(reType(t.TypedLeft(), types.Bool), inScope, nil, nil, colRefs)
right := b.buildScalar(reType(t.TypedRight(), types.Bool), inScope, nil, nil, colRefs)
out = b.factory.ConstructAnd(left, right)

case *tree.Array:
Expand Down Expand Up @@ -214,8 +214,8 @@ func (b *Builder) buildScalar(
// select the right overload. The solution is to wrap any mismatched
// arguments with a CastExpr that preserves the static type.

left := tree.ReType(t.TypedLeft(), t.ResolvedBinOp().LeftType)
right := tree.ReType(t.TypedRight(), t.ResolvedBinOp().RightType)
left := reType(t.TypedLeft(), t.ResolvedBinOp().LeftType)
right := reType(t.TypedRight(), t.ResolvedBinOp().RightType)
out = b.constructBinary(
tree.MakeBinaryOperator(t.Operator.Symbol),
b.buildScalar(left, inScope, nil, nil, colRefs),
Expand All @@ -233,39 +233,32 @@ func (b *Builder) buildScalar(
input = memo.TrueSingleton
}

// validateCastToValType panics if tree.ReType with the given source
// type would create an invalid cast to valType.
validateCastToValType := func(src *types.T) {
if valType.Family() == types.AnyFamily || src.Identical(valType) {
// If valType's family is AnyFamily or src is identical to
// valType, then tree.Retype will not create a cast expression.
return
}
if _, ok := tree.LookupCastVolatility(src, valType, nil /* sessionData */); ok {
return
}
panic(pgerror.Newf(
pgcode.DatatypeMismatch,
"CASE types %s and %s cannot be matched", src, valType,
))
}

whens := make(memo.ScalarListExpr, 0, len(t.Whens)+1)
for i := range t.Whens {
condExpr := t.Whens[i].Cond.(tree.TypedExpr)
cond := b.buildScalar(condExpr, inScope, nil, nil, colRefs)
valExpr := t.Whens[i].Val.(tree.TypedExpr)
validateCastToValType(valExpr.ResolvedType())
valExpr = tree.ReType(valExpr, valType)
valExpr, ok := tree.ReType(t.Whens[i].Val.(tree.TypedExpr), valType)
if !ok {
panic(pgerror.Newf(
pgcode.DatatypeMismatch,
"CASE WHEN types %s and %s cannot be matched",
t.Whens[i].Val.(tree.TypedExpr).ResolvedType(), valType,
))
}
val := b.buildScalar(valExpr, inScope, nil, nil, colRefs)
whens = append(whens, b.factory.ConstructWhen(cond, val))
}
// Add the ELSE expression to the end of whens as a raw scalar expression.
var orElse opt.ScalarExpr
if t.Else != nil {
elseExpr := t.Else.(tree.TypedExpr)
validateCastToValType(elseExpr.ResolvedType())
elseExpr = tree.ReType(elseExpr, valType)
elseExpr, ok := tree.ReType(t.Else.(tree.TypedExpr), valType)
if !ok {
panic(pgerror.Newf(
pgcode.DatatypeMismatch,
"CASE ELSE type %s cannot be matched to WHEN type %s",
t.Else.(tree.TypedExpr).ResolvedType(), valType,
))
}
orElse = b.buildScalar(elseExpr, inScope, nil, nil, colRefs)
} else {
orElse = b.factory.ConstructNull(valType)
Expand All @@ -284,7 +277,14 @@ func (b *Builder) buildScalar(
// The type of the CoalesceExpr might be different than the inputs (e.g.
// when they are NULL). Force all inputs to be the same type, so that we
// build coalesce operator with the correct type.
expr := tree.ReType(t.TypedExprAt(i), typ)
expr, ok := tree.ReType(t.TypedExprAt(i), typ)
if !ok {
panic(pgerror.Newf(
pgcode.DatatypeMismatch,
"COALESCE types %s and %s cannot be matched",
t.TypedExprAt(i).ResolvedType(), typ,
))
}
args[i] = b.buildScalar(expr, inScope, nil, nil, colRefs)
}
out = b.factory.ConstructCoalesce(args)
Expand Down Expand Up @@ -322,10 +322,19 @@ func (b *Builder) buildScalar(
case *tree.IfExpr:
valType := t.ResolvedType()
input := b.buildScalar(t.Cond.(tree.TypedExpr), inScope, nil, nil, colRefs)
ifTrueExpr := tree.ReType(t.True.(tree.TypedExpr), valType)
// Re-typing the True expression should always succeed because they
// are given the same type during type-checking.
ifTrueExpr := reType(t.True.(tree.TypedExpr), valType)
ifTrue := b.buildScalar(ifTrueExpr, inScope, nil, nil, colRefs)
whens := memo.ScalarListExpr{b.factory.ConstructWhen(memo.TrueSingleton, ifTrue)}
orElseExpr := tree.ReType(t.Else.(tree.TypedExpr), valType)
orElseExpr, ok := tree.ReType(t.Else.(tree.TypedExpr), valType)
if !ok {
panic(pgerror.Newf(
pgcode.DatatypeMismatch,
"IF types %s and %s cannot be matched",
t.Else.(tree.TypedExpr).ResolvedType(), valType,
))
}
orElse := b.buildScalar(orElseExpr, inScope, nil, nil, colRefs)
out = b.factory.ConstructCase(input, whens, orElse)

Expand All @@ -337,7 +346,7 @@ func (b *Builder) buildScalar(
out = b.factory.ConstructVariable(inScope.cols[t.Idx].id)

case *tree.NotExpr:
input := b.buildScalar(tree.ReType(t.TypedInnerExpr(), types.Bool), inScope, nil, nil, colRefs)
input := b.buildScalar(reType(t.TypedInnerExpr(), types.Bool), inScope, nil, nil, colRefs)
out = b.factory.ConstructNot(input)

case *tree.IsNullExpr:
Expand All @@ -362,7 +371,7 @@ func (b *Builder) buildScalar(
// of the NULLIF expression so that type inference will be correct in the
// CASE expression constructed below. For example, the type of
// NULLIF(NULL, 0) should be int.
expr1 := tree.ReType(t.Expr1.(tree.TypedExpr), valType)
expr1 := reType(t.Expr1.(tree.TypedExpr), valType)
input := b.buildScalar(expr1, inScope, nil, nil, colRefs)
cond := b.buildScalar(t.Expr2.(tree.TypedExpr), inScope, nil, nil, colRefs)
whens := memo.ScalarListExpr{
Expand All @@ -371,8 +380,8 @@ func (b *Builder) buildScalar(
out = b.factory.ConstructCase(input, whens, input)

case *tree.OrExpr:
left := b.buildScalar(tree.ReType(t.TypedLeft(), types.Bool), inScope, nil, nil, colRefs)
right := b.buildScalar(tree.ReType(t.TypedRight(), types.Bool), inScope, nil, nil, colRefs)
left := b.buildScalar(reType(t.TypedLeft(), types.Bool), inScope, nil, nil, colRefs)
right := b.buildScalar(reType(t.TypedRight(), types.Bool), inScope, nil, nil, colRefs)
out = b.factory.ConstructOr(left, right)

case *tree.ParenExpr:
Expand Down Expand Up @@ -869,3 +878,20 @@ func (sb *ScalarBuilder) Build(expr tree.Expr) (err error) {
sb.factory.Memo().SetScalarRoot(scalar)
return nil
}

// reType is similar to tree.ReType, except that it panics with an internal
// error if the expression cannot be re-typed. This should only be used when
// re-typing is expected to always be successful. For example, it is used to
// re-type the left and right children of an OrExpr to booleans, which should
// always succeed during the optbuild phase because type-checking has already
// validated the types of the children.
func reType(expr tree.TypedExpr, typ *types.T) tree.TypedExpr {
retypedExpr, ok := tree.ReType(expr, typ)
if !ok {
panic(errors.AssertionFailedf(
"expected successful retype from %s to %s",
expr.ResolvedType(), typ,
))
}
return retypedExpr
}
2 changes: 1 addition & 1 deletion pkg/sql/opt/optbuilder/scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ func (s *scope) resolveAndRequireType(expr tree.Expr, desired *types.T) tree.Typ
if err != nil {
panic(err)
}
return tree.ReType(s.ensureNullType(texpr, desired), desired)
return reType(s.ensureNullType(texpr, desired), desired)
}

// ensureNullType tests the type of the given expression. If types.Unknown, then
Expand Down
42 changes: 40 additions & 2 deletions pkg/sql/opt/optbuilder/testdata/scalar
Original file line number Diff line number Diff line change
Expand Up @@ -1470,9 +1470,47 @@ is [type=bool]
build
SELECT CASE WHEN false THEN ARRAY[('', 0)] ELSE ARRAY[]::RECORD[] END
----
error (42804): CASE types tuple[] and tuple{string, int}[] cannot be matched
error (42804): CASE ELSE type tuple[] cannot be matched to WHEN type tuple{string, int}[]

build
SELECT CASE WHEN false THEN ARRAY[('', 0)] WHEN true THEN ARRAY[]::RECORD[] ELSE ARRAY[('', 0)] END
----
error (42804): CASE types tuple[] and tuple{string, int}[] cannot be matched
error (42804): CASE WHEN types tuple[] and tuple{string, int}[] cannot be matched

# Regression test for #76807. Do not create invalid casts when building COALESCE
# and IF expressions.
build
SELECT COALESCE(t.v, ARRAY[]:::RECORD[])
FROM (VALUES (ARRAY[(1, 'foo')])) AS t(v)
----
error (42804): COALESCE types tuple[] and tuple{int, string}[] cannot be matched

build
SELECT COALESCE(ARRAY[]:::RECORD[], t.v)
FROM (VALUES (ARRAY[(1, 'foo')])) AS t(v)
----
project
├── columns: coalesce:2
├── values
│ ├── columns: column1:1
│ └── (ARRAY[(1, 'foo')],)
└── projections
└── COALESCE(ARRAY[], column1:1::RECORD[]) [as=coalesce:2]

build
SELECT IF(true, t.v, ARRAY[]:::RECORD[])
FROM (VALUES (ARRAY[(1, 'foo')])) AS t(v)
----
error (42804): IF types tuple[] and tuple{int, string}[] cannot be matched

build
SELECT IF(true, ARRAY[]:::RECORD[], t.v)
FROM (VALUES (ARRAY[(1, 'foo')])) AS t(v)
----
project
├── columns: if:2
├── values
│ ├── columns: column1:1
│ └── (ARRAY[(1, 'foo')],)
└── projections
└── CASE WHEN true THEN ARRAY[] ELSE column1:1::RECORD[] END [as=if:2]
2 changes: 1 addition & 1 deletion pkg/sql/opt/optbuilder/window.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ func (b *Builder) getTypedWindowArgs(w *windowInfo) []tree.TypedExpr {
argExprs = append(argExprs, tree.NewDInt(1))
}
if len(argExprs) < 3 {
null := tree.ReType(tree.DNull, argExprs[0].ResolvedType())
null := reType(tree.DNull, argExprs[0].ResolvedType())
argExprs = append(argExprs, null)
}
}
Expand Down
9 changes: 8 additions & 1 deletion pkg/sql/sem/tree/constant_eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

package tree

import "github.com/cockroachdb/errors"

// ConstantEvalVisitor replaces constant TypedExprs with the result of Eval.
type ConstantEvalVisitor struct {
ctx *EvalContext
Expand Down Expand Up @@ -58,7 +60,12 @@ func (v *ConstantEvalVisitor) VisitPost(expr Expr) Expr {
if value == DNull {
// We don't want to return an expression that has a different type; cast
// the NULL if necessary.
return ReType(DNull, typedExpr.ResolvedType())
retypedNull, ok := ReType(DNull, typedExpr.ResolvedType())
if !ok {
v.err = errors.AssertionFailedf("failed to retype NULL to %s", typedExpr.ResolvedType())
return expr
}
return retypedNull
}
return value
}
Expand Down
36 changes: 24 additions & 12 deletions pkg/sql/sem/tree/normalize.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,37 +134,39 @@ func (expr *BinaryExpr) normalize(v *NormalizeVisitor) TypedExpr {
switch expr.Operator.Symbol {
case Plus:
if v.isNumericZero(right) {
final = ReType(left, expectedType)
final, _ = ReType(left, expectedType)
break
}
if v.isNumericZero(left) {
final = ReType(right, expectedType)
final, _ = ReType(right, expectedType)
break
}
case Minus:
if types.IsAdditiveType(left.ResolvedType()) && v.isNumericZero(right) {
final = ReType(left, expectedType)
final, _ = ReType(left, expectedType)
break
}
case Mult:
if v.isNumericOne(right) {
final = ReType(left, expectedType)
final, _ = ReType(left, expectedType)
break
}
if v.isNumericOne(left) {
final = ReType(right, expectedType)
final, _ = ReType(right, expectedType)
break
}
// We can't simplify multiplication by zero to zero,
// because if the other operand is NULL during evaluation
// the result must be NULL.
case Div, FloorDiv:
if v.isNumericOne(right) {
final = ReType(left, expectedType)
final, _ = ReType(left, expectedType)
break
}
}

// final is nil when the binary expression did not match the cases above,
// or when ReType was unsuccessful.
if final == nil {
return expr
}
Expand Down Expand Up @@ -759,7 +761,12 @@ func (v *NormalizeVisitor) VisitPost(expr Expr) Expr {
if value == DNull {
// We don't want to return an expression that has a different type; cast
// the NULL if necessary.
return ReType(DNull, expr.(TypedExpr).ResolvedType())
retypedNull, ok := ReType(DNull, expr.(TypedExpr).ResolvedType())
if !ok {
v.err = errors.AssertionFailedf("failed to retype NULL to %s", expr.(TypedExpr).ResolvedType())
return expr
}
return retypedNull
}
return value
}
Expand Down Expand Up @@ -991,14 +998,19 @@ func init() {
DecimalOne.SetInt64(1)
}

// ReType ensures that the given expression evaluates
// to the requested type, inserting a cast if necessary.
func ReType(expr TypedExpr, wantedType *types.T) TypedExpr {
// ReType ensures that the given expression evaluates to the requested type,
// wrapping the expression in a cast if necessary. Returns ok=false if a cast
// cannot wrap the expression because no valid cast from the expression's type
// to the wanted type exists.
func ReType(expr TypedExpr, wantedType *types.T) (_ TypedExpr, ok bool) {
resolvedType := expr.ResolvedType()
if wantedType.Family() == types.AnyFamily || resolvedType.Identical(wantedType) {
return expr
return expr, true
}
if _, ok := LookupCastVolatility(resolvedType, wantedType, nil /* sessionData */); !ok {
return nil, false
}
res := &CastExpr{Expr: expr, Type: wantedType}
res.typ = wantedType
return res
return res, true
}

0 comments on commit 2a28bdd

Please sign in to comment.