Skip to content

Commit

Permalink
Merge #63647
Browse files Browse the repository at this point in the history
63647: opt: use optgen Let expressions r=mgartner a=mgartner

#### optfmt: fix Let formatting

Previously, `optfmt` could awkwardly format the entire expression on one
line except for the closing `)`. For example:

    (Let ($result $ok):(FoldBinary (OpName) $left $right) $ok
    )

This commit fixes the formatting so that the closing `)` will not be the
only element of the expression placed on a separate line. The example
above is now formatted as:

    (Let
        ($result $ok):(FoldBinary (OpName) $left $right) $ok
    )

Release note: None

#### opt: replace Succeeded and OrderingSucceeded with Let expression

The `Succeeded` and `OrderingSucceeded` custom functions have been
removed. Normalization and exploration rules which used these functions
now use `Let` expressions instead. Custom functions called in these
rules return an additional `ok bool` value which is bound in a `Let`
expression and used to determine if the rule matches an expression.

Release note: None

#### opt: simplify GenerateLocalityOptimizedAntiJoin with Let expression

The commit simplifies `GenerateLocalityOptimizedAntiJoin` by using a
`Let` expression. The `GetLocalityOptimizedAntiJoinLookupExprs` custom
function now returns the local expression, remote expression, and an
`ok` boolean, rather than a single `LocalAndRemoteLookupExprs` struct.
These values are bound to variables in the rule using a `Let`
expression.

Release note: None


Co-authored-by: Marcus Gartner <[email protected]>
  • Loading branch information
craig[bot] and mgartner committed May 4, 2021
2 parents 853a6bc + 4c4560f commit 7e654dc
Show file tree
Hide file tree
Showing 17 changed files with 221 additions and 206 deletions.
12 changes: 7 additions & 5 deletions pkg/sql/opt/norm/bool_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (c *CustomFuncs) NegateComparison(
// FindRedundantConjunct takes the left and right operands of an Or operator as
// input. It examines each conjunct from the left expression and determines
// whether it appears as a conjunct in the right expression. If so, it returns
// the matching conjunct. Otherwise, it returns nil. For example:
// the matching conjunct. Otherwise, it returns ok=false. For example:
//
// A OR A => A
// B OR A => nil
Expand All @@ -51,21 +51,23 @@ func (c *CustomFuncs) NegateComparison(
// Once a redundant conjunct has been found, it is extracted via a call to the
// ExtractRedundantConjunct function. Redundant conjuncts are extracted from
// multiple nested Or operators by repeated application of these functions.
func (c *CustomFuncs) FindRedundantConjunct(left, right opt.ScalarExpr) opt.ScalarExpr {
func (c *CustomFuncs) FindRedundantConjunct(
left, right opt.ScalarExpr,
) (_ opt.ScalarExpr, ok bool) {
// Recurse over each conjunct from the left expression and determine whether
// it's redundant.
for {
// Assume a left-deep And expression tree normalized by NormalizeNestedAnds.
if and, ok := left.(*memo.AndExpr); ok {
if c.isConjunct(and.Right, right) {
return and.Right
return and.Right, true
}
left = and.Left
} else {
if c.isConjunct(left, right) {
return left
return left, true
}
return nil
return nil, false
}
}
}
Expand Down
134 changes: 54 additions & 80 deletions pkg/sql/opt/norm/fold_constants_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,36 +277,40 @@ func (c *CustomFuncs) HasAllNonNullElements(tup *memo.TupleExpr) bool {
// FoldBinary evaluates a binary expression with constant inputs. It returns
// a constant expression as long as it finds an appropriate overload function
// for the given operator and input types, and the evaluation causes no error.
func (c *CustomFuncs) FoldBinary(op opt.Operator, left, right opt.ScalarExpr) opt.ScalarExpr {
// Otherwise, it returns ok=false.
func (c *CustomFuncs) FoldBinary(
op opt.Operator, left, right opt.ScalarExpr,
) (_ opt.ScalarExpr, ok bool) {
o, ok := memo.FindBinaryOverload(op, left.DataType(), right.DataType())
if !ok || !c.CanFoldOperator(o.Volatility) {
return nil
return nil, false
}

lDatum, rDatum := memo.ExtractConstDatum(left), memo.ExtractConstDatum(right)
result, err := o.Fn(c.f.evalCtx, lDatum, rDatum)
if err != nil {
return nil
return nil, false
}
return c.f.ConstructConstVal(result, o.ReturnType)
return c.f.ConstructConstVal(result, o.ReturnType), true
}

// FoldUnary evaluates a unary expression with a constant input. It returns
// a constant expression as long as it finds an appropriate overload function
// for the given operator and input type, and the evaluation causes no error.
func (c *CustomFuncs) FoldUnary(op opt.Operator, input opt.ScalarExpr) opt.ScalarExpr {
// Otherwise, it returns ok=false.
func (c *CustomFuncs) FoldUnary(op opt.Operator, input opt.ScalarExpr) (_ opt.ScalarExpr, ok bool) {
datum := memo.ExtractConstDatum(input)

o, ok := memo.FindUnaryOverload(op, input.DataType())
if !ok {
return nil
return nil, false
}

result, err := o.Fn(c.f.evalCtx, datum)
if err != nil {
return nil
return nil, false
}
return c.f.ConstructConstVal(result, o.ReturnType)
return c.f.ConstructConstVal(result, o.ReturnType), true
}

// foldStringToRegclassCast resolves a string that is a table name into an OID
Expand Down Expand Up @@ -337,34 +341,35 @@ func (c *CustomFuncs) foldStringToRegclassCast(

}

// FoldCast evaluates a cast expression with a constant input. It returns
// a constant expression as long as the evaluation causes no error.
func (c *CustomFuncs) FoldCast(input opt.ScalarExpr, typ *types.T) opt.ScalarExpr {
// FoldCast evaluates a cast expression with a constant input. It returns a
// constant expression as long as the evaluation causes no error. Otherwise, it
// returns ok=false.
func (c *CustomFuncs) FoldCast(input opt.ScalarExpr, typ *types.T) (_ opt.ScalarExpr, ok bool) {
if typ.Family() == types.OidFamily {
if typ.Oid() == types.RegClass.Oid() && input.DataType().Family() == types.StringFamily {
expr, err := c.foldStringToRegclassCast(input, typ)
if err == nil {
return expr
return expr, true
}
}
// Save this cast for the execbuilder.
return nil
return nil, false
}

volatility, ok := tree.LookupCastVolatility(input.DataType(), typ)
if !ok || !c.CanFoldOperator(volatility) {
return nil
return nil, false
}

datum := memo.ExtractConstDatum(input)
texpr := tree.NewTypedCastExpr(datum, typ)

result, err := texpr.Eval(c.f.evalCtx)
if err != nil {
return nil
return nil, false
}

return c.f.ConstructConstVal(result, typ)
return c.f.ConstructConstVal(result, typ), true
}

// isMonotonicConversion returns true if conversion of a value from FROM to
Expand Down Expand Up @@ -413,51 +418,17 @@ func isMonotonicConversion(from, to *types.T) bool {
return false
}

// UnifyComparison attempts to convert a constant expression to the type of the
// variable expression, if that conversion can round-trip and is monotonic.
func (c *CustomFuncs) UnifyComparison(v *memo.VariableExpr, cnst *memo.ConstExpr) opt.ScalarExpr {
desiredType := v.DataType()
originalType := cnst.DataType()

// Don't bother if they're already the same.
if desiredType.Equivalent(originalType) {
return nil
}

if !isMonotonicConversion(originalType, desiredType) {
return nil
}

// Check that the datum can round-trip between the types. If this is true, it
// means we don't lose any information needed to generate spans, and combined
// with monotonicity means that it's safe to convert the RHS to the type of
// the LHS.
convertedDatum, err := tree.PerformCast(c.f.evalCtx, cnst.Value, desiredType)
if err != nil {
return nil
}

convertedBack, err := tree.PerformCast(c.f.evalCtx, convertedDatum, originalType)
if err != nil {
return nil
}

if convertedBack.Compare(c.f.evalCtx, cnst.Value) != 0 {
return nil
}

return c.f.ConstructConst(convertedDatum, desiredType)
}

// FoldComparison evaluates a comparison expression with constant inputs. It
// returns a constant expression as long as it finds an appropriate overload
// function for the given operator and input types, and the evaluation causes
// no error.
func (c *CustomFuncs) FoldComparison(op opt.Operator, left, right opt.ScalarExpr) opt.ScalarExpr {
// no error. Otherwise, it returns ok=false.
func (c *CustomFuncs) FoldComparison(
op opt.Operator, left, right opt.ScalarExpr,
) (_ opt.ScalarExpr, ok bool) {
var flipped, not bool
o, flipped, not, ok := memo.FindComparisonOverload(op, left.DataType(), right.DataType())
if !ok || !c.CanFoldOperator(o.Volatility) {
return nil
return nil, false
}

lDatum, rDatum := memo.ExtractConstDatum(left), memo.ExtractConstDatum(right)
Expand All @@ -467,18 +438,18 @@ func (c *CustomFuncs) FoldComparison(op opt.Operator, left, right opt.ScalarExpr

result, err := o.Fn(c.f.evalCtx, lDatum, rDatum)
if err != nil {
return nil
return nil, false
}
if b, ok := result.(*tree.DBool); ok && not {
result = tree.MakeDBool(!*b)
}
return c.f.ConstructConstVal(result, types.Bool)
return c.f.ConstructConstVal(result, types.Bool), true
}

// FoldIndirection evaluates an array indirection operator with constant inputs.
// It returns the referenced array element as a constant value, or nil if the
// evaluation results in an error.
func (c *CustomFuncs) FoldIndirection(input, index opt.ScalarExpr) opt.ScalarExpr {
// It returns the referenced array element as a constant value, or ok=false if
// the evaluation results in an error.
func (c *CustomFuncs) FoldIndirection(input, index opt.ScalarExpr) (_ opt.ScalarExpr, ok bool) {
// Index is 1-based, so convert to 0-based.
indexD := memo.ExtractConstDatum(index)

Expand All @@ -487,14 +458,14 @@ func (c *CustomFuncs) FoldIndirection(input, index opt.ScalarExpr) opt.ScalarExp
if indexInt, ok := indexD.(*tree.DInt); ok {
indexI := int(*indexInt) - 1
if indexI >= 0 && indexI < len(arr.Elems) {
return arr.Elems[indexI]
return arr.Elems[indexI], true
}
return c.f.ConstructNull(arr.Typ.ArrayContents())
return c.f.ConstructNull(arr.Typ.ArrayContents()), true
}
if indexD == tree.DNull {
return c.f.ConstructNull(arr.Typ.ArrayContents())
return c.f.ConstructNull(arr.Typ.ArrayContents()), true
}
return nil
return nil, false
}

// Case 2: The input is a constant DArray.
Expand All @@ -503,28 +474,30 @@ func (c *CustomFuncs) FoldIndirection(input, index opt.ScalarExpr) opt.ScalarExp
texpr := tree.NewTypedIndirectionExpr(inputD, indexD, input.DataType().ArrayContents())
result, err := texpr.Eval(c.f.evalCtx)
if err == nil {
return c.f.ConstructConstVal(result, texpr.ResolvedType())
return c.f.ConstructConstVal(result, texpr.ResolvedType()), true
}
}

return nil
return nil, false
}

// FoldColumnAccess tries to evaluate a tuple column access operator with a
// constant tuple input (though tuple field values do not need to be constant).
// It returns the referenced tuple field value, or nil if folding is not
// It returns the referenced tuple field value, or ok=false if folding is not
// possible or results in an error.
func (c *CustomFuncs) FoldColumnAccess(input opt.ScalarExpr, idx memo.TupleOrdinal) opt.ScalarExpr {
func (c *CustomFuncs) FoldColumnAccess(
input opt.ScalarExpr, idx memo.TupleOrdinal,
) (_ opt.ScalarExpr, ok bool) {
// Case 1: The input is NULL. This is possible when FoldIndirection has
// already folded an Indirection expression with an out-of-bounds index to
// Null.
if n, ok := input.(*memo.NullExpr); ok {
return c.f.ConstructNull(n.Typ.TupleContents()[idx])
return c.f.ConstructNull(n.Typ.TupleContents()[idx]), true
}

// Case 2: The input is a static tuple constructor.
if tup, ok := input.(*memo.TupleExpr); ok {
return tup.Elems[idx]
return tup.Elems[idx], true
}

// Case 3: The input is a constant DTuple.
Expand All @@ -534,11 +507,11 @@ func (c *CustomFuncs) FoldColumnAccess(input opt.ScalarExpr, idx memo.TupleOrdin
texpr := tree.NewTypedColumnAccessExpr(datum, "" /* by-index access */, int(idx))
result, err := texpr.Eval(c.f.evalCtx)
if err == nil {
return c.f.ConstructConstVal(result, texpr.ResolvedType())
return c.f.ConstructConstVal(result, texpr.ResolvedType()), true
}
}

return nil
return nil, false
}

// CanFoldFunctionWithNullArg returns true if the given function can be folded
Expand Down Expand Up @@ -569,20 +542,21 @@ func (c *CustomFuncs) FunctionReturnType(private *memo.FunctionPrivate) *types.T
return private.Typ
}

// FoldFunction evaluates a function expression with constant inputs. It
// returns a constant expression as long as the function is contained in the
// FoldFunctionAllowlist, and the evaluation causes no error.
// FoldFunction evaluates a function expression with constant inputs. It returns
// a constant expression as long as the function is contained in the
// FoldFunctionAllowlist, and the evaluation causes no error. Otherwise, it
// returns ok=false.
func (c *CustomFuncs) FoldFunction(
args memo.ScalarListExpr, private *memo.FunctionPrivate,
) opt.ScalarExpr {
) (_ opt.ScalarExpr, ok bool) {
// Non-normal function classes (aggregate, window, generator) cannot be
// folded into a single constant.
if private.Properties.Class != tree.NormalClass {
return nil
return nil, false
}

if !c.CanFoldOperator(private.Overload.Volatility) {
return nil
return nil, false
}

exprs := make(tree.TypedExprs, len(args))
Expand All @@ -603,7 +577,7 @@ func (c *CustomFuncs) FoldFunction(

result, err := fn.Eval(c.f.evalCtx)
if err != nil {
return nil
return nil, false
}
return c.f.ConstructConstVal(result, private.Typ)
return c.f.ConstructConstVal(result, private.Typ), true
}
21 changes: 0 additions & 21 deletions pkg/sql/opt/norm/general_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,6 @@ func (c *CustomFuncs) Init(f *Factory) {
}
}

// Succeeded returns true if a result expression is not nil.
func (c *CustomFuncs) Succeeded(result opt.Expr) bool {
return result != nil
}

// ----------------------------------------------------------------------
//
// Typing functions
Expand Down Expand Up @@ -165,22 +160,6 @@ func (c *CustomFuncs) NotNullCols(input memo.RelExpr) opt.ColSet {
return input.Relational().NotNullCols
}

// SingleRegressionCountArgument checks if either arg is non-null and returns
// the other one (or nil if neither is non-null).
func (c *CustomFuncs) SingleRegressionCountArgument(
y, x opt.ScalarExpr, input memo.RelExpr,
) opt.ScalarExpr {
notNullCols := c.NotNullCols(input)
if c.ExprIsNeverNull(y, notNullCols) {
return x
}
if c.ExprIsNeverNull(x, notNullCols) {
return y
}

return nil
}

// IsColNotNull returns true if the given input column is never null.
func (c *CustomFuncs) IsColNotNull(col opt.ColumnID, input memo.RelExpr) bool {
return input.Relational().NotNullCols.Contains(col)
Expand Down
16 changes: 16 additions & 0 deletions pkg/sql/opt/norm/groupby_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,22 @@ func (c *CustomFuncs) areRowsDistinct(
return true
}

// SingleRegressionCountArgument checks if either arg is non-null and returns
// the other one. If neither is non-null it returns ok=false.
func (c *CustomFuncs) SingleRegressionCountArgument(
y, x opt.ScalarExpr, input memo.RelExpr,
) (_ opt.ScalarExpr, ok bool) {
notNullCols := c.NotNullCols(input)
if c.ExprIsNeverNull(y, notNullCols) {
return x, true
}
if c.ExprIsNeverNull(x, notNullCols) {
return y, true
}

return nil, false
}

// CanMergeAggs returns true if one of the following applies to each of the
// given outer aggregation expressions:
// 1. The aggregation can be merged with a single inner aggregation.
Expand Down
5 changes: 3 additions & 2 deletions pkg/sql/opt/norm/rules/bool.opt
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,9 @@ $input
(Or
$left:^(Or)
$right:^(Or) &
(Succeeded
$conjunct:(FindRedundantConjunct $left $right)
(Let
($conjunct $ok):(FindRedundantConjunct $left $right)
$ok
)
)
=>
Expand Down
Loading

0 comments on commit 7e654dc

Please sign in to comment.