Skip to content

Commit

Permalink
optbuilder: do not create invalid casts when building COALESCE and IF
Browse files Browse the repository at this point in the history
The optbuilder no longer creates invalid casts when building COALESCE
and IF expressions that have children with different types. Expressions
that previously caused internal errors now result in user-facing errors.
Both UNION and CASE expressions had similar bugs that were recently
fixed in cockroachdb#75219 and cockroachdb#76193.

This commit also updates the `tree.ReType` function to return `ok=false`
if there is no valid cast to re-type the expression to the given type.
This forces callers to explicitly deal with situations where re-typing
is not possible and it ensures that the function never creates invalid
casts. This will make it easier to track down future related bugs
because internal errors should originate from the call site of
`tree.ReType` rather than from logic further along in the optimization
process (in the case of cockroachdb#76807 the internal error originated from the
logical props builder when it attempted to look up the volatility of the
invalid cast).

This commit also adds special logic to make casts from any tuple type to
`types.AnyTuple` valid, immutable, implicit casts. Evaluation of these
casts is a no-op. Users cannot construct these casts, but they are built
by optbuilder in some cases.

Fixes cockroachdb#76807

Release justification: This is a low-risk change that fixes a minor bug.

Release note (bug fix): A bug has been fixed that caused internal errors
when COALESCE and IF expressions had inner expressions with different
types that could not be cast to a common type.
  • Loading branch information
mgartner committed Mar 11, 2022
1 parent 5e1b8d6 commit e55e6b6
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 56 deletions.
6 changes: 5 additions & 1 deletion pkg/sql/opt/exec/execbuilder/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,11 @@ func (b *Builder) buildTypedExpr(
}

func (b *Builder) buildNull(ctx *buildScalarCtx, scalar opt.ScalarExpr) (tree.TypedExpr, error) {
return tree.ReType(tree.DNull, scalar.DataType()), nil
retypedNull, ok := tree.ReType(tree.DNull, scalar.DataType())
if !ok {
return nil, errors.AssertionFailedf("failed to retype NULL to %s", scalar.DataType())
}
return retypedNull, nil
}

func (b *Builder) buildVariable(
Expand Down
98 changes: 60 additions & 38 deletions pkg/sql/opt/optbuilder/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ func (b *Builder) buildScalar(
return b.finishBuildScalarRef(t.col, inScope, outScope, outCol, colRefs)

case *tree.AndExpr:
left := b.buildScalar(tree.ReType(t.TypedLeft(), types.Bool), inScope, nil, nil, colRefs)
right := b.buildScalar(tree.ReType(t.TypedRight(), types.Bool), inScope, nil, nil, colRefs)
left := b.buildScalar(reType(t.TypedLeft(), types.Bool), inScope, nil, nil, colRefs)
right := b.buildScalar(reType(t.TypedRight(), types.Bool), inScope, nil, nil, colRefs)
out = b.factory.ConstructAnd(left, right)

case *tree.Array:
Expand Down Expand Up @@ -216,8 +216,8 @@ func (b *Builder) buildScalar(
// select the right overload. The solution is to wrap any mismatched
// arguments with a CastExpr that preserves the static type.

left := tree.ReType(t.TypedLeft(), t.ResolvedBinOp().LeftType)
right := tree.ReType(t.TypedRight(), t.ResolvedBinOp().RightType)
left := reType(t.TypedLeft(), t.ResolvedBinOp().LeftType)
right := reType(t.TypedRight(), t.ResolvedBinOp().RightType)
out = b.constructBinary(
treebin.MakeBinaryOperator(t.Operator.Symbol),
b.buildScalar(left, inScope, nil, nil, colRefs),
Expand All @@ -235,43 +235,32 @@ func (b *Builder) buildScalar(
input = memo.TrueSingleton
}

// validateCastToValType panics if tree.ReType with the given source
// type would create an invalid cast to valType.
validateCastToValType := func(src *types.T) {
if valType.Family() == types.AnyFamily || src.Identical(valType) {
// If valType's family is AnyFamily or src is identical to
// valType, then tree.Retype will not create a cast expression.
return
}
if tree.ValidCast(src, valType, tree.CastContextExplicit) {
// TODO(#75103): For legacy reasons, we check for a valid cast
// in the most permissive context, CastContextExplicit. To be
// consistent with Postgres, we should check for a valid cast in
// the most restrictive context, CastContextImplicit.
return
}
panic(pgerror.Newf(
pgcode.DatatypeMismatch,
"CASE types %s and %s cannot be matched", src, valType,
))
}

whens := make(memo.ScalarListExpr, 0, len(t.Whens)+1)
for i := range t.Whens {
condExpr := t.Whens[i].Cond.(tree.TypedExpr)
cond := b.buildScalar(condExpr, inScope, nil, nil, colRefs)
valExpr := t.Whens[i].Val.(tree.TypedExpr)
validateCastToValType(valExpr.ResolvedType())
valExpr = tree.ReType(valExpr, valType)
valExpr, ok := tree.ReType(t.Whens[i].Val.(tree.TypedExpr), valType)
if !ok {
panic(pgerror.Newf(
pgcode.DatatypeMismatch,
"CASE WHEN types %s and %s cannot be matched",
t.Whens[i].Val.(tree.TypedExpr).ResolvedType(), valType,
))
}
val := b.buildScalar(valExpr, inScope, nil, nil, colRefs)
whens = append(whens, b.factory.ConstructWhen(cond, val))
}
// Add the ELSE expression to the end of whens as a raw scalar expression.
var orElse opt.ScalarExpr
if t.Else != nil {
elseExpr := t.Else.(tree.TypedExpr)
validateCastToValType(elseExpr.ResolvedType())
elseExpr = tree.ReType(elseExpr, valType)
elseExpr, ok := tree.ReType(t.Else.(tree.TypedExpr), valType)
if !ok {
panic(pgerror.Newf(
pgcode.DatatypeMismatch,
"CASE ELSE type %s cannot be matched to WHEN type %s",
t.Else.(tree.TypedExpr).ResolvedType(), valType,
))
}
orElse = b.buildScalar(elseExpr, inScope, nil, nil, colRefs)
} else {
orElse = b.factory.ConstructNull(valType)
Expand All @@ -290,7 +279,14 @@ func (b *Builder) buildScalar(
// The type of the CoalesceExpr might be different than the inputs (e.g.
// when they are NULL). Force all inputs to be the same type, so that we
// build coalesce operator with the correct type.
expr := tree.ReType(t.TypedExprAt(i), typ)
expr, ok := tree.ReType(t.TypedExprAt(i), typ)
if !ok {
panic(pgerror.Newf(
pgcode.DatatypeMismatch,
"COALESCE types %s and %s cannot be matched",
t.TypedExprAt(i).ResolvedType(), typ,
))
}
args[i] = b.buildScalar(expr, inScope, nil, nil, colRefs)
}
out = b.factory.ConstructCoalesce(args)
Expand Down Expand Up @@ -328,10 +324,19 @@ func (b *Builder) buildScalar(
case *tree.IfExpr:
valType := t.ResolvedType()
input := b.buildScalar(t.Cond.(tree.TypedExpr), inScope, nil, nil, colRefs)
ifTrueExpr := tree.ReType(t.True.(tree.TypedExpr), valType)
// Re-typing the True expression should always succeed because they
// are given the same type during type-checking.
ifTrueExpr := reType(t.True.(tree.TypedExpr), valType)
ifTrue := b.buildScalar(ifTrueExpr, inScope, nil, nil, colRefs)
whens := memo.ScalarListExpr{b.factory.ConstructWhen(memo.TrueSingleton, ifTrue)}
orElseExpr := tree.ReType(t.Else.(tree.TypedExpr), valType)
orElseExpr, ok := tree.ReType(t.Else.(tree.TypedExpr), valType)
if !ok {
panic(pgerror.Newf(
pgcode.DatatypeMismatch,
"IF types %s and %s cannot be matched",
t.Else.(tree.TypedExpr).ResolvedType(), valType,
))
}
orElse := b.buildScalar(orElseExpr, inScope, nil, nil, colRefs)
out = b.factory.ConstructCase(input, whens, orElse)

Expand All @@ -343,7 +348,7 @@ func (b *Builder) buildScalar(
out = b.factory.ConstructVariable(inScope.cols[t.Idx].id)

case *tree.NotExpr:
input := b.buildScalar(tree.ReType(t.TypedInnerExpr(), types.Bool), inScope, nil, nil, colRefs)
input := b.buildScalar(reType(t.TypedInnerExpr(), types.Bool), inScope, nil, nil, colRefs)
out = b.factory.ConstructNot(input)

case *tree.IsNullExpr:
Expand All @@ -368,7 +373,7 @@ func (b *Builder) buildScalar(
// of the NULLIF expression so that type inference will be correct in the
// CASE expression constructed below. For example, the type of
// NULLIF(NULL, 0) should be int.
expr1 := tree.ReType(t.Expr1.(tree.TypedExpr), valType)
expr1 := reType(t.Expr1.(tree.TypedExpr), valType)
input := b.buildScalar(expr1, inScope, nil, nil, colRefs)
cond := b.buildScalar(t.Expr2.(tree.TypedExpr), inScope, nil, nil, colRefs)
whens := memo.ScalarListExpr{
Expand All @@ -377,8 +382,8 @@ func (b *Builder) buildScalar(
out = b.factory.ConstructCase(input, whens, input)

case *tree.OrExpr:
left := b.buildScalar(tree.ReType(t.TypedLeft(), types.Bool), inScope, nil, nil, colRefs)
right := b.buildScalar(tree.ReType(t.TypedRight(), types.Bool), inScope, nil, nil, colRefs)
left := b.buildScalar(reType(t.TypedLeft(), types.Bool), inScope, nil, nil, colRefs)
right := b.buildScalar(reType(t.TypedRight(), types.Bool), inScope, nil, nil, colRefs)
out = b.factory.ConstructOr(left, right)

case *tree.ParenExpr:
Expand Down Expand Up @@ -875,3 +880,20 @@ func (sb *ScalarBuilder) Build(expr tree.Expr) (err error) {
sb.factory.Memo().SetScalarRoot(scalar)
return nil
}

// reType wraps tree.ReType for call sites where re-typing is expected to
// always succeed. It returns the re-typed expression on success and panics
// with an assertion failure (an internal error) when tree.ReType reports that
// the expression cannot be re-typed to typ.
//
// A typical use is re-typing the left and right children of an OrExpr to
// booleans: by the optbuild phase, type-checking has already validated the
// types of the children, so a failure there indicates a bug rather than a
// user error.
func reType(expr tree.TypedExpr, typ *types.T) tree.TypedExpr {
	if retyped, ok := tree.ReType(expr, typ); ok {
		return retyped
	}
	// Reaching this point means the caller's assumption was violated.
	panic(errors.AssertionFailedf(
		"expected successful retype from %s to %s",
		expr.ResolvedType(), typ,
	))
}
42 changes: 40 additions & 2 deletions pkg/sql/opt/optbuilder/testdata/scalar
Original file line number Diff line number Diff line change
Expand Up @@ -1494,9 +1494,47 @@ is [type=bool]
build
SELECT CASE WHEN false THEN ARRAY[('', 0)] ELSE ARRAY[]::RECORD[] END
----
error (42804): CASE types tuple[] and tuple{string, int}[] cannot be matched
error (42804): CASE ELSE type tuple[] cannot be matched to WHEN type tuple{string, int}[]

build
SELECT CASE WHEN false THEN ARRAY[('', 0)] WHEN true THEN ARRAY[]::RECORD[] ELSE ARRAY[('', 0)] END
----
error (42804): CASE types tuple[] and tuple{string, int}[] cannot be matched
error (42804): CASE WHEN types tuple[] and tuple{string, int}[] cannot be matched

# Regression test for #76807. Do not create invalid casts when building COALESCE
# and IF expressions.
build
SELECT COALESCE(t.v, ARRAY[]:::RECORD[])
FROM (VALUES (ARRAY[(1, 'foo')])) AS t(v)
----
error (42804): COALESCE types tuple[] and tuple{int, string}[] cannot be matched

build
SELECT COALESCE(ARRAY[]:::RECORD[], t.v)
FROM (VALUES (ARRAY[(1, 'foo')])) AS t(v)
----
project
├── columns: coalesce:2
├── values
│ ├── columns: column1:1
│ └── (ARRAY[(1, 'foo')],)
└── projections
└── COALESCE(ARRAY[], column1:1::RECORD[]) [as=coalesce:2]

build
SELECT IF(true, t.v, ARRAY[]:::RECORD[])
FROM (VALUES (ARRAY[(1, 'foo')])) AS t(v)
----
error (42804): IF types tuple[] and tuple{int, string}[] cannot be matched

build
SELECT IF(true, ARRAY[]:::RECORD[], t.v)
FROM (VALUES (ARRAY[(1, 'foo')])) AS t(v)
----
project
├── columns: if:2
├── values
│ ├── columns: column1:1
│ └── (ARRAY[(1, 'foo')],)
└── projections
└── CASE WHEN true THEN ARRAY[] ELSE column1:1::RECORD[] END [as=if:2]
2 changes: 1 addition & 1 deletion pkg/sql/opt/optbuilder/window.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ func (b *Builder) getTypedWindowArgs(w *windowInfo) []tree.TypedExpr {
argExprs = append(argExprs, tree.NewDInt(1))
}
if len(argExprs) < 3 {
null := tree.ReType(tree.DNull, argExprs[0].ResolvedType())
null := reType(tree.DNull, argExprs[0].ResolvedType())
argExprs = append(argExprs, null)
}
}
Expand Down
15 changes: 14 additions & 1 deletion pkg/sql/sem/tree/cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -1283,7 +1283,11 @@ func ValidCast(src, tgt *types.T, ctx CastContext) bool {

// If src and tgt are tuple types, check for a valid cast between each
// corresponding tuple element.
if srcFamily == types.TupleFamily && tgtFamily == types.TupleFamily {
//
// Casts from a tuple type to AnyTuple are a no-op so they are always valid.
// If tgt is AnyTuple, we continue to lookupCast below which contains a
// special case for these casts.
if srcFamily == types.TupleFamily && tgtFamily == types.TupleFamily && tgt != types.AnyTuple {
srcTypes := src.TupleContents()
tgtTypes := tgt.TupleContents()
// The tuple types must have the same number of elements.
Expand Down Expand Up @@ -1375,6 +1379,15 @@ func lookupCast(src, tgt *types.T, intervalStyleEnabled, dateStyleEnabled bool)
}, true
}

// Casts from any tuple type to AnyTuple are no-ops, so they are implicit
// and immutable.
if srcFamily == types.TupleFamily && tgt == types.AnyTuple {
return cast{
maxContext: CastContextImplicit,
volatility: VolatilityImmutable,
}, true
}

// Casts from string types to array and tuple types are stable and allowed
// in explicit contexts.
if srcFamily == types.StringFamily &&
Expand Down
9 changes: 8 additions & 1 deletion pkg/sql/sem/tree/constant_eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

package tree

import "github.com/cockroachdb/errors"

// ConstantEvalVisitor replaces constant TypedExprs with the result of Eval.
type ConstantEvalVisitor struct {
ctx *EvalContext
Expand Down Expand Up @@ -58,7 +60,12 @@ func (v *ConstantEvalVisitor) VisitPost(expr Expr) Expr {
if value == DNull {
// We don't want to return an expression that has a different type; cast
// the NULL if necessary.
return ReType(DNull, typedExpr.ResolvedType())
retypedNull, ok := ReType(DNull, typedExpr.ResolvedType())
if !ok {
v.err = errors.AssertionFailedf("failed to retype NULL to %s", typedExpr.ResolvedType())
return expr
}
return retypedNull
}
return value
}
Expand Down
40 changes: 28 additions & 12 deletions pkg/sql/sem/tree/normalize.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,37 +135,39 @@ func (expr *BinaryExpr) normalize(v *NormalizeVisitor) TypedExpr {
switch expr.Operator.Symbol {
case treebin.Plus:
if v.isNumericZero(right) {
final = ReType(left, expectedType)
final, _ = ReType(left, expectedType)
break
}
if v.isNumericZero(left) {
final = ReType(right, expectedType)
final, _ = ReType(right, expectedType)
break
}
case treebin.Minus:
if types.IsAdditiveType(left.ResolvedType()) && v.isNumericZero(right) {
final = ReType(left, expectedType)
final, _ = ReType(left, expectedType)
break
}
case treebin.Mult:
if v.isNumericOne(right) {
final = ReType(left, expectedType)
final, _ = ReType(left, expectedType)
break
}
if v.isNumericOne(left) {
final = ReType(right, expectedType)
final, _ = ReType(right, expectedType)
break
}
// We can't simplify multiplication by zero to zero,
// because if the other operand is NULL during evaluation
// the result must be NULL.
case treebin.Div, treebin.FloorDiv:
if v.isNumericOne(right) {
final = ReType(left, expectedType)
final, _ = ReType(left, expectedType)
break
}
}

// final is nil when the binary expression did not match the cases above,
// or when ReType was unsuccessful.
if final == nil {
return expr
}
Expand Down Expand Up @@ -710,7 +712,12 @@ func (v *NormalizeVisitor) VisitPost(expr Expr) Expr {
if value == DNull {
// We don't want to return an expression that has a different type; cast
// the NULL if necessary.
return ReType(DNull, expr.(TypedExpr).ResolvedType())
retypedNull, ok := ReType(DNull, expr.(TypedExpr).ResolvedType())
if !ok {
v.err = errors.AssertionFailedf("failed to retype NULL to %s", expr.(TypedExpr).ResolvedType())
return expr
}
return retypedNull
}
return value
}
Expand Down Expand Up @@ -942,14 +949,23 @@ func init() {
DecimalOne.SetInt64(1)
}

// ReType ensures that the given expression evaluates
// to the requested type, inserting a cast if necessary.
func ReType(expr TypedExpr, wantedType *types.T) TypedExpr {
// ReType ensures that the given expression evaluates to the requested type,
// wrapping the expression in a cast if necessary. Returns ok=false if a cast
// cannot wrap the expression because no valid cast from the expression's type
// to the wanted type exists.
func ReType(expr TypedExpr, wantedType *types.T) (_ TypedExpr, ok bool) {
resolvedType := expr.ResolvedType()
if wantedType.Family() == types.AnyFamily || resolvedType.Identical(wantedType) {
return expr
return expr, true
}
// TODO(#75103): For legacy reasons, we check for a valid cast in the most
// permissive context, CastContextExplicit. To be consistent with Postgres,
// we should check for a valid cast in the most restrictive context,
// CastContextImplicit.
if !ValidCast(resolvedType, wantedType, CastContextExplicit) {
return nil, false
}
res := &CastExpr{Expr: expr, Type: wantedType}
res.typ = wantedType
return res
return res, true
}

0 comments on commit e55e6b6

Please sign in to comment.