diff --git a/pkg/sql/opt/norm/bool_funcs.go b/pkg/sql/opt/norm/bool_funcs.go index 2f0c60d7f96a..72f6c3a4aa02 100644 --- a/pkg/sql/opt/norm/bool_funcs.go +++ b/pkg/sql/opt/norm/bool_funcs.go @@ -40,7 +40,7 @@ func (c *CustomFuncs) NegateComparison( // FindRedundantConjunct takes the left and right operands of an Or operator as // input. It examines each conjunct from the left expression and determines // whether it appears as a conjunct in the right expression. If so, it returns -// the matching conjunct. Otherwise, it returns nil. For example: +// the matching conjunct. Otherwise, it returns ok=false. For example: // // A OR A => A // B OR A => nil @@ -51,21 +51,23 @@ func (c *CustomFuncs) NegateComparison( // Once a redundant conjunct has been found, it is extracted via a call to the // ExtractRedundantConjunct function. Redundant conjuncts are extracted from // multiple nested Or operators by repeated application of these functions. -func (c *CustomFuncs) FindRedundantConjunct(left, right opt.ScalarExpr) opt.ScalarExpr { +func (c *CustomFuncs) FindRedundantConjunct( + left, right opt.ScalarExpr, +) (_ opt.ScalarExpr, ok bool) { // Recurse over each conjunct from the left expression and determine whether // it's redundant. for { // Assume a left-deep And expression tree normalized by NormalizeNestedAnds. if and, ok := left.(*memo.AndExpr); ok { if c.isConjunct(and.Right, right) { - return and.Right + return and.Right, true } left = and.Left } else { if c.isConjunct(left, right) { - return left + return left, true } - return nil + return nil, false } } } diff --git a/pkg/sql/opt/norm/fold_constants_funcs.go b/pkg/sql/opt/norm/fold_constants_funcs.go index a9a1afa2d496..d2c3b6ffa54c 100644 --- a/pkg/sql/opt/norm/fold_constants_funcs.go +++ b/pkg/sql/opt/norm/fold_constants_funcs.go @@ -277,36 +277,40 @@ func (c *CustomFuncs) HasAllNonNullElements(tup *memo.TupleExpr) bool { // FoldBinary evaluates a binary expression with constant inputs. It returns // a constant expression as long as it finds an appropriate overload function // for the given operator and input types, and the evaluation causes no error. -func (c *CustomFuncs) FoldBinary(op opt.Operator, left, right opt.ScalarExpr) opt.ScalarExpr { +// Otherwise, it returns ok=false. +func (c *CustomFuncs) FoldBinary( + op opt.Operator, left, right opt.ScalarExpr, +) (_ opt.ScalarExpr, ok bool) { o, ok := memo.FindBinaryOverload(op, left.DataType(), right.DataType()) if !ok || !c.CanFoldOperator(o.Volatility) { - return nil + return nil, false } lDatum, rDatum := memo.ExtractConstDatum(left), memo.ExtractConstDatum(right) result, err := o.Fn(c.f.evalCtx, lDatum, rDatum) if err != nil { - return nil + return nil, false } - return c.f.ConstructConstVal(result, o.ReturnType) + return c.f.ConstructConstVal(result, o.ReturnType), true } // FoldUnary evaluates a unary expression with a constant input. It returns // a constant expression as long as it finds an appropriate overload function // for the given operator and input type, and the evaluation causes no error. -func (c *CustomFuncs) FoldUnary(op opt.Operator, input opt.ScalarExpr) opt.ScalarExpr { +// Otherwise, it returns ok=false. +func (c *CustomFuncs) FoldUnary(op opt.Operator, input opt.ScalarExpr) (_ opt.ScalarExpr, ok bool) { datum := memo.ExtractConstDatum(input) o, ok := memo.FindUnaryOverload(op, input.DataType()) if !ok { - return nil + return nil, false } result, err := o.Fn(c.f.evalCtx, datum) if err != nil { - return nil + return nil, false } - return c.f.ConstructConstVal(result, o.ReturnType) + return c.f.ConstructConstVal(result, o.ReturnType), true } // foldStringToRegclassCast resolves a string that is a table name into an OID @@ -337,23 +341,24 @@ func (c *CustomFuncs) foldStringToRegclassCast( } -// FoldCast evaluates a cast expression with a constant input. It returns -// a constant expression as long as the evaluation causes no error. -func (c *CustomFuncs) FoldCast(input opt.ScalarExpr, typ *types.T) opt.ScalarExpr { +// FoldCast evaluates a cast expression with a constant input. It returns a +// constant expression as long as the evaluation causes no error. Otherwise, it +// returns ok=false. +func (c *CustomFuncs) FoldCast(input opt.ScalarExpr, typ *types.T) (_ opt.ScalarExpr, ok bool) { if typ.Family() == types.OidFamily { if typ.Oid() == types.RegClass.Oid() && input.DataType().Family() == types.StringFamily { expr, err := c.foldStringToRegclassCast(input, typ) if err == nil { - return expr + return expr, true } } // Save this cast for the execbuilder. - return nil + return nil, false } volatility, ok := tree.LookupCastVolatility(input.DataType(), typ) if !ok || !c.CanFoldOperator(volatility) { - return nil + return nil, false } datum := memo.ExtractConstDatum(input) @@ -361,10 +366,10 @@ func (c *CustomFuncs) FoldCast(input opt.ScalarExpr, typ *types.T) opt.ScalarExp result, err := texpr.Eval(c.f.evalCtx) if err != nil { - return nil + return nil, false } - return c.f.ConstructConstVal(result, typ) + return c.f.ConstructConstVal(result, typ), true } // isMonotonicConversion returns true if conversion of a value from FROM to @@ -413,51 +418,17 @@ func isMonotonicConversion(from, to *types.T) bool { return false } -// UnifyComparison attempts to convert a constant expression to the type of the -// variable expression, if that conversion can round-trip and is monotonic. -func (c *CustomFuncs) UnifyComparison(v *memo.VariableExpr, cnst *memo.ConstExpr) opt.ScalarExpr { - desiredType := v.DataType() - originalType := cnst.DataType() - - // Don't bother if they're already the same. - if desiredType.Equivalent(originalType) { - return nil - } - - if !isMonotonicConversion(originalType, desiredType) { - return nil - } - - // Check that the datum can round-trip between the types. If this is true, it - // means we don't lose any information needed to generate spans, and combined - // with monotonicity means that it's safe to convert the RHS to the type of - // the LHS. - convertedDatum, err := tree.PerformCast(c.f.evalCtx, cnst.Value, desiredType) - if err != nil { - return nil - } - - convertedBack, err := tree.PerformCast(c.f.evalCtx, convertedDatum, originalType) - if err != nil { - return nil - } - - if convertedBack.Compare(c.f.evalCtx, cnst.Value) != 0 { - return nil - } - - return c.f.ConstructConst(convertedDatum, desiredType) -} - // FoldComparison evaluates a comparison expression with constant inputs. It // returns a constant expression as long as it finds an appropriate overload // function for the given operator and input types, and the evaluation causes -// no error. -func (c *CustomFuncs) FoldComparison(op opt.Operator, left, right opt.ScalarExpr) opt.ScalarExpr { +// no error. Otherwise, it returns ok=false. +func (c *CustomFuncs) FoldComparison( + op opt.Operator, left, right opt.ScalarExpr, +) (_ opt.ScalarExpr, ok bool) { var flipped, not bool o, flipped, not, ok := memo.FindComparisonOverload(op, left.DataType(), right.DataType()) if !ok || !c.CanFoldOperator(o.Volatility) { - return nil + return nil, false } lDatum, rDatum := memo.ExtractConstDatum(left), memo.ExtractConstDatum(right) @@ -467,18 +438,18 @@ func (c *CustomFuncs) FoldComparison(op opt.Operator, left, right opt.ScalarExpr result, err := o.Fn(c.f.evalCtx, lDatum, rDatum) if err != nil { - return nil + return nil, false } if b, ok := result.(*tree.DBool); ok && not { result = tree.MakeDBool(!*b) } - return c.f.ConstructConstVal(result, types.Bool) + return c.f.ConstructConstVal(result, types.Bool), true } // FoldIndirection evaluates an array indirection operator with constant inputs. -// It returns the referenced array element as a constant value, or nil if the -// evaluation results in an error. -func (c *CustomFuncs) FoldIndirection(input, index opt.ScalarExpr) opt.ScalarExpr { +// It returns the referenced array element as a constant value, or ok=false if +// the evaluation results in an error. +func (c *CustomFuncs) FoldIndirection(input, index opt.ScalarExpr) (_ opt.ScalarExpr, ok bool) { // Index is 1-based, so convert to 0-based. indexD := memo.ExtractConstDatum(index) @@ -487,14 +458,14 @@ func (c *CustomFuncs) FoldIndirection(input, index opt.ScalarExpr) opt.ScalarExp if indexInt, ok := indexD.(*tree.DInt); ok { indexI := int(*indexInt) - 1 if indexI >= 0 && indexI < len(arr.Elems) { - return arr.Elems[indexI] + return arr.Elems[indexI], true } - return c.f.ConstructNull(arr.Typ.ArrayContents()) + return c.f.ConstructNull(arr.Typ.ArrayContents()), true } if indexD == tree.DNull { - return c.f.ConstructNull(arr.Typ.ArrayContents()) + return c.f.ConstructNull(arr.Typ.ArrayContents()), true } - return nil + return nil, false } // Case 2: The input is a constant DArray. @@ -503,28 +474,30 @@ func (c *CustomFuncs) FoldIndirection(input, index opt.ScalarExpr) opt.ScalarExp texpr := tree.NewTypedIndirectionExpr(inputD, indexD, input.DataType().ArrayContents()) result, err := texpr.Eval(c.f.evalCtx) if err == nil { - return c.f.ConstructConstVal(result, texpr.ResolvedType()) + return c.f.ConstructConstVal(result, texpr.ResolvedType()), true } } - return nil + return nil, false } // FoldColumnAccess tries to evaluate a tuple column access operator with a // constant tuple input (though tuple field values do not need to be constant). -// It returns the referenced tuple field value, or nil if folding is not +// It returns the referenced tuple field value, or ok=false if folding is not // possible or results in an error. -func (c *CustomFuncs) FoldColumnAccess(input opt.ScalarExpr, idx memo.TupleOrdinal) opt.ScalarExpr { +func (c *CustomFuncs) FoldColumnAccess( + input opt.ScalarExpr, idx memo.TupleOrdinal, +) (_ opt.ScalarExpr, ok bool) { // Case 1: The input is NULL. This is possible when FoldIndirection has // already folded an Indirection expression with an out-of-bounds index to // Null. if n, ok := input.(*memo.NullExpr); ok { - return c.f.ConstructNull(n.Typ.TupleContents()[idx]) + return c.f.ConstructNull(n.Typ.TupleContents()[idx]), true } // Case 2: The input is a static tuple constructor. if tup, ok := input.(*memo.TupleExpr); ok { - return tup.Elems[idx] + return tup.Elems[idx], true } // Case 3: The input is a constant DTuple. @@ -534,11 +507,11 @@ func (c *CustomFuncs) FoldColumnAccess(input opt.ScalarExpr, idx memo.TupleOrdin texpr := tree.NewTypedColumnAccessExpr(datum, "" /* by-index access */, int(idx)) result, err := texpr.Eval(c.f.evalCtx) if err == nil { - return c.f.ConstructConstVal(result, texpr.ResolvedType()) + return c.f.ConstructConstVal(result, texpr.ResolvedType()), true } } - return nil + return nil, false } // CanFoldFunctionWithNullArg returns true if the given function can be folded @@ -569,20 +542,21 @@ func (c *CustomFuncs) FunctionReturnType(private *memo.FunctionPrivate) *types.T return private.Typ } -// FoldFunction evaluates a function expression with constant inputs. It -// returns a constant expression as long as the function is contained in the -// FoldFunctionAllowlist, and the evaluation causes no error. +// FoldFunction evaluates a function expression with constant inputs. It returns +// a constant expression as long as the function is contained in the +// FoldFunctionAllowlist, and the evaluation causes no error. Otherwise, it +// returns ok=false. func (c *CustomFuncs) FoldFunction( args memo.ScalarListExpr, private *memo.FunctionPrivate, -) opt.ScalarExpr { +) (_ opt.ScalarExpr, ok bool) { // Non-normal function classes (aggregate, window, generator) cannot be // folded into a single constant. if private.Properties.Class != tree.NormalClass { - return nil + return nil, false } if !c.CanFoldOperator(private.Overload.Volatility) { - return nil + return nil, false } exprs := make(tree.TypedExprs, len(args)) @@ -603,7 +577,7 @@ func (c *CustomFuncs) FoldFunction( result, err := fn.Eval(c.f.evalCtx) if err != nil { - return nil + return nil, false } - return c.f.ConstructConstVal(result, private.Typ) + return c.f.ConstructConstVal(result, private.Typ), true } diff --git a/pkg/sql/opt/norm/general_funcs.go b/pkg/sql/opt/norm/general_funcs.go index 84747c1039e5..b38390fc4958 100644 --- a/pkg/sql/opt/norm/general_funcs.go +++ b/pkg/sql/opt/norm/general_funcs.go @@ -40,11 +40,6 @@ func (c *CustomFuncs) Init(f *Factory) { } } -// Succeeded returns true if a result expression is not nil. -func (c *CustomFuncs) Succeeded(result opt.Expr) bool { - return result != nil -} - // ---------------------------------------------------------------------- // // Typing functions @@ -165,22 +160,6 @@ func (c *CustomFuncs) NotNullCols(input memo.RelExpr) opt.ColSet { return input.Relational().NotNullCols } -// SingleRegressionCountArgument checks if either arg is non-null and returns -// the other one (or nil if neither is non-null). -func (c *CustomFuncs) SingleRegressionCountArgument( - y, x opt.ScalarExpr, input memo.RelExpr, -) opt.ScalarExpr { - notNullCols := c.NotNullCols(input) - if c.ExprIsNeverNull(y, notNullCols) { - return x - } - if c.ExprIsNeverNull(x, notNullCols) { - return y - } - - return nil -} - // IsColNotNull returns true if the given input column is never null. func (c *CustomFuncs) IsColNotNull(col opt.ColumnID, input memo.RelExpr) bool { return input.Relational().NotNullCols.Contains(col) diff --git a/pkg/sql/opt/norm/groupby_funcs.go b/pkg/sql/opt/norm/groupby_funcs.go index d3bef0d28588..ce8b632d03a1 100644 --- a/pkg/sql/opt/norm/groupby_funcs.go +++ b/pkg/sql/opt/norm/groupby_funcs.go @@ -280,6 +280,22 @@ func (c *CustomFuncs) areRowsDistinct( return true } +// SingleRegressionCountArgument checks if either arg is non-null and returns +// the other one. If neither is non-null it returns ok=false. +func (c *CustomFuncs) SingleRegressionCountArgument( + y, x opt.ScalarExpr, input memo.RelExpr, +) (_ opt.ScalarExpr, ok bool) { + notNullCols := c.NotNullCols(input) + if c.ExprIsNeverNull(y, notNullCols) { + return x, true + } + if c.ExprIsNeverNull(x, notNullCols) { + return y, true + } + + return nil, false +} + // CanMergeAggs returns true if one of the following applies to each of the // given outer aggregation expressions: // 1. The aggregation can be merged with a single inner aggregation. diff --git a/pkg/sql/opt/norm/rules/bool.opt b/pkg/sql/opt/norm/rules/bool.opt index 959d859f4680..f7be09f63098 100644 --- a/pkg/sql/opt/norm/rules/bool.opt +++ b/pkg/sql/opt/norm/rules/bool.opt @@ -156,8 +156,9 @@ $input (Or $left:^(Or) $right:^(Or) & - (Succeeded - $conjunct:(FindRedundantConjunct $left $right) + (Let + ($conjunct $ok):(FindRedundantConjunct $left $right) + $ok ) ) => diff --git a/pkg/sql/opt/norm/rules/fold_constants.opt b/pkg/sql/opt/norm/rules/fold_constants.opt index f680741648a8..43aa9151fae2 100644 --- a/pkg/sql/opt/norm/rules/fold_constants.opt +++ b/pkg/sql/opt/norm/rules/fold_constants.opt @@ -82,7 +82,9 @@ $left:* & (IsConstValueOrGroupOfConstValues $left) $right:* & (IsConstValueOrGroupOfConstValues $right) & - (Succeeded $result:(FoldBinary (OpName) $left $right)) + (Let + ($result $ok):(FoldBinary (OpName) $left $right) $ok + ) ) => $result @@ -94,7 +96,7 @@ $result (Unary $input:* & (IsConstValueOrGroupOfConstValues $input) & - (Succeeded $result:(FoldUnary (OpName) $input)) + (Let ($result $ok):(FoldUnary (OpName) $input) $ok) ) => $result @@ -107,8 +109,9 @@ $result $left:* & (IsConstValueOrGroupOfConstValues $left) $right:* & (IsConstValueOrGroupOfConstValues $right) & - (Succeeded - $result:(FoldComparison (OpName) $left $right) + (Let + ($result $ok):(FoldComparison (OpName) $left $right) + $ok ) ) => @@ -122,7 +125,7 @@ $result $input:* $typ:* & (IsConstValueOrGroupOfConstValues $input) & - (Succeeded $result:(FoldCast $input $typ)) + (Let ($result $ok):(FoldCast $input $typ) $ok) ) => $result @@ -139,7 +142,7 @@ $result $input:* $index:* & (IsConstValueOrGroupOfConstValues $index) & - (Succeeded $result:(FoldIndirection $input $index)) + (Let ($result $ok):(FoldIndirection $input $index) $ok) ) => $result @@ -155,7 +158,8 @@ $result [FoldColumnAccess, Normalize] (ColumnAccess $input:* - $idx:* & (Succeeded $result:(FoldColumnAccess $input $idx)) + $idx:* & + (Let ($result $ok):(FoldColumnAccess $input $idx) $ok) ) => $result @@ -203,7 +207,7 @@ $result (Function $args:* & (IsListOfConstants $args) $private:* & - (Succeeded $result:(FoldFunction $args $private)) + (Let ($result $ok):(FoldFunction $args $private) $ok) ) => $result diff --git a/pkg/sql/opt/norm/rules/groupby.opt b/pkg/sql/opt/norm/rules/groupby.opt index a62017c52642..196aceabdc93 100644 --- a/pkg/sql/opt/norm/rules/groupby.opt +++ b/pkg/sql/opt/norm/rules/groupby.opt @@ -415,12 +415,13 @@ $item:(AggregationsItem (RegressionCount $arg1:* $arg2:*) ) & - (Succeeded - $newArg:(SingleRegressionCountArgument + (Let + ($newArg $ok):(SingleRegressionCountArgument $arg1 $arg2 $input ) + $ok ) ... ] diff --git a/pkg/sql/opt/norm/rules/scalar.opt b/pkg/sql/opt/norm/rules/scalar.opt index 71b012b776fd..f83e672aceb1 100644 --- a/pkg/sql/opt/norm/rules/scalar.opt +++ b/pkg/sql/opt/norm/rules/scalar.opt @@ -103,7 +103,7 @@ $input (Comparison $left:(Variable) $right:(Const) & - (Succeeded $result:(UnifyComparison $left $right)) + (Let ($result $ok):(UnifyComparison $left $right) $ok) ) => ((OpName) $left $result) diff --git a/pkg/sql/opt/norm/rules/window.opt b/pkg/sql/opt/norm/rules/window.opt index ca87aaa6d1a8..8bc11eb2db51 100644 --- a/pkg/sql/opt/norm/rules/window.opt +++ b/pkg/sql/opt/norm/rules/window.opt @@ -158,13 +158,14 @@ $input (Window $input:* $fns:* & (AllArePrefixSafe $fns) $private:*) $limit:* $ordering:* & - (OrderingSucceeded - $newOrdering:(MakeSegmentedOrdering + (Let + ($newOrdering $ok):(MakeSegmentedOrdering $input (WindowPartition $private) (WindowOrdering $private) $ordering ) + $ok ) ) => diff --git a/pkg/sql/opt/norm/scalar_funcs.go b/pkg/sql/opt/norm/scalar_funcs.go index 79999e22ed5d..28fe50019d5a 100644 --- a/pkg/sql/opt/norm/scalar_funcs.go +++ b/pkg/sql/opt/norm/scalar_funcs.go @@ -116,6 +116,45 @@ func (c *CustomFuncs) IsConstValueEqual(const1, const2 opt.ScalarExpr) bool { } } +// UnifyComparison attempts to convert a constant expression to the type of the +// variable expression, if that conversion can round-trip and is monotonic. +// Otherwise it returns ok=false. +func (c *CustomFuncs) UnifyComparison( + v *memo.VariableExpr, cnst *memo.ConstExpr, +) (_ opt.ScalarExpr, ok bool) { + desiredType := v.DataType() + originalType := cnst.DataType() + + // Don't bother if they're already the same. + if desiredType.Equivalent(originalType) { + return nil, false + } + + if !isMonotonicConversion(originalType, desiredType) { + return nil, false + } + + // Check that the datum can round-trip between the types. If this is true, it + // means we don't lose any information needed to generate spans, and combined + // with monotonicity means that it's safe to convert the RHS to the type of + // the LHS. + convertedDatum, err := tree.PerformCast(c.f.evalCtx, cnst.Value, desiredType) + if err != nil { + return nil, false + } + + convertedBack, err := tree.PerformCast(c.f.evalCtx, convertedDatum, originalType) + if err != nil { + return nil, false + } + + if convertedBack.Compare(c.f.evalCtx, cnst.Value) != 0 { + return nil, false + } + + return c.f.ConstructConst(convertedDatum, desiredType), true +} + // SimplifyWhens removes known unreachable WHEN cases and constructs a new CASE // statement. Any known true condition is converted to the ELSE. If only the // ELSE remains, its expression is returned. condition must be a ConstValue. diff --git a/pkg/sql/opt/norm/window_funcs.go b/pkg/sql/opt/norm/window_funcs.go index e69269424ce4..b24a42d3de08 100644 --- a/pkg/sql/opt/norm/window_funcs.go +++ b/pkg/sql/opt/norm/window_funcs.go @@ -18,15 +18,15 @@ import ( ) // MakeSegmentedOrdering returns an ordering choice which satisfies both -// limitOrdering and the ordering required by a window function. Returns nil if -// no such ordering exists. See OrderingChoice.PrefixIntersection for more -// details. +// limitOrdering and the ordering required by a window function. Returns +// ok=false if no such ordering exists. See OrderingChoice.PrefixIntersection +// for more details. func (c *CustomFuncs) MakeSegmentedOrdering( input memo.RelExpr, prefix opt.ColSet, ordering props.OrderingChoice, limitOrdering props.OrderingChoice, -) *props.OrderingChoice { +) (_ *props.OrderingChoice, ok bool) { // The columns in the closure of the prefix may be included in it. It's // beneficial to do so for a given column iff that column appears in the @@ -38,9 +38,9 @@ func (c *CustomFuncs) MakeSegmentedOrdering( oc, ok := limitOrdering.PrefixIntersection(prefix, ordering.Columns) if !ok { - return nil + return nil, false } - return &oc + return &oc, true } // AllArePrefixSafe returns whether every window function in the list satisfies @@ -140,11 +140,6 @@ func (c *CustomFuncs) ExtractUndeterminedConditions( return newFilters } -// OrderingSucceeded returns true if an OrderingChoice is not nil. -func (c *CustomFuncs) OrderingSucceeded(result *props.OrderingChoice) bool { - return result != nil -} - // DerefOrderingChoice returns an OrderingChoice from a pointer. func (c *CustomFuncs) DerefOrderingChoice(result *props.OrderingChoice) props.OrderingChoice { return *result diff --git a/pkg/sql/opt/optgen/cmd/optfmt/main.go b/pkg/sql/opt/optgen/cmd/optfmt/main.go index c6da21ea4d86..f2ad4f3b1ca3 100644 --- a/pkg/sql/opt/optgen/cmd/optfmt/main.go +++ b/pkg/sql/opt/optgen/cmd/optfmt/main.go @@ -395,7 +395,6 @@ func (p *pp) docOnlyExpr(e lang.Expr) pretty.Doc { )) binding := pretty.Group(pretty.Fold(pretty.Concat, - pretty.Line, labels, pretty.Text("):"), p.docExpr(e.Target), @@ -407,12 +406,11 @@ func (p *pp) docOnlyExpr(e lang.Expr) pretty.Doc { p.docExpr(e.Result), )) - return pretty.Group(pretty.Fold(pretty.Concat, - pretty.Text("(Let"), - pretty.NestT(inner), - pretty.SoftBreak, + return pretty.BracketDoc( + pretty.Text("(Let "), + inner, pretty.Text(")"), - )) + ) case *lang.AnyExpr: return pretty.Text("*") case *lang.ListAnyExpr: diff --git a/pkg/sql/opt/optgen/cmd/optfmt/testdata/test b/pkg/sql/opt/optgen/cmd/optfmt/testdata/test index c5c794c220e7..750bc6415b91 100644 --- a/pkg/sql/opt/optgen/cmd/optfmt/testdata/test +++ b/pkg/sql/opt/optgen/cmd/optfmt/testdata/test @@ -219,7 +219,8 @@ define FiltersItem { ) & (OuterFunc (InnerFunc - (Let ($foo $bar):(SplitFilters $input $filters) + (Let + ($foo $bar):(SplitFilters $input $filters) $foo ) ) @@ -244,3 +245,29 @@ pretty (R) => (O) + +# The closing ")" should not be printed on it's own line if the result, "$ok", +# is not printed on it's own line. +pretty +[FoldBinary, Normalize] +(Binary + $left:* & (IsConstValueOrGroupOfConstValues $left) + $right:* & + (IsConstValueOrGroupOfConstValues $right) & + (Let ($result $ok):(FoldBinary (OpName) $left $right) $ok + ) +) +=> +$result +---- +[FoldBinary, Normalize] +(Binary + $left:* & (IsConstValueOrGroupOfConstValues $left) + $right:* & + (IsConstValueOrGroupOfConstValues $right) & + (Let + ($result $ok):(FoldBinary (OpName) $left $right) $ok + ) +) +=> +$result diff --git a/pkg/sql/opt/xform/join_funcs.go b/pkg/sql/opt/xform/join_funcs.go index af1e11462486..93ae7704c924 100644 --- a/pkg/sql/opt/xform/join_funcs.go +++ b/pkg/sql/opt/xform/join_funcs.go @@ -1155,20 +1155,6 @@ func (c *CustomFuncs) MakeProjectionsForOuterJoin( return result } -// LocalAndRemoteLookupExprs is used by the GenerateLocalityOptimizedAntiJoin -// rule to hold two sets of filters: one targeting local partitions and one -// targeting remote partitions. -type LocalAndRemoteLookupExprs struct { - Local memo.FiltersExpr - Remote memo.FiltersExpr -} - -// LocalAndRemoteLookupExprsSucceeded returns true if the -// LocalAndRemoteLookupExprs is not empty. -func (c *CustomFuncs) LocalAndRemoteLookupExprsSucceeded(le LocalAndRemoteLookupExprs) bool { - return len(le.Local) != 0 && len(le.Remote) != 0 -} - // CreateLocalityOptimizedAntiLookupJoinPrivate creates a new lookup join // private from the given private and replaces the LookupExpr with the given // filters. It also marks the private as locality optimized. @@ -1181,57 +1167,45 @@ func (c *CustomFuncs) CreateLocalityOptimizedAntiLookupJoinPrivate( return &newPrivate } -// LocalLookupExpr extracts the Local filters expr from the given -// LocalAndRemoteLookupExprs. -func (c *CustomFuncs) LocalLookupExpr(le LocalAndRemoteLookupExprs) memo.FiltersExpr { - return le.Local -} - -// RemoteLookupExpr extracts the Remote filters expr from the given -// LocalAndRemoteLookupExprs. -func (c *CustomFuncs) RemoteLookupExpr(le LocalAndRemoteLookupExprs) memo.FiltersExpr { - return le.Remote -} - -// GetLocalityOptimizedAntiJoinLookupExprs gets the lookup expressions needed to -// build a locality optimized anti join if possible from the given lookup join -// private. See the comment above the GenerateLocalityOptimizedAntiJoin rule for -// more details. +// GetLocalityOptimizedAntiJoinLookupExprs returns the local and remote lookup +// expressions needed to build a locality optimized anti join from the given +// lookup join private, if possible. Otherwise, it returns ok=false. See the +// comment above the GenerateLocalityOptimizedAntiJoin rule for more details. func (c *CustomFuncs) GetLocalityOptimizedAntiJoinLookupExprs( input memo.RelExpr, private *memo.LookupJoinPrivate, -) LocalAndRemoteLookupExprs { +) (localExpr memo.FiltersExpr, remoteExpr memo.FiltersExpr, ok bool) { // Respect the session setting LocalityOptimizedSearch. if !c.e.evalCtx.SessionData.LocalityOptimizedSearch { - return LocalAndRemoteLookupExprs{} + return nil, nil, false } // Check whether this lookup join has already been locality optimized. if private.LocalityOptimized { - return LocalAndRemoteLookupExprs{} + return nil, nil, false } // We can only apply this optimization to anti-joins. if private.JoinType != opt.AntiJoinOp { - return LocalAndRemoteLookupExprs{} + return nil, nil, false } // This lookup join cannot not be part of a paired join. if private.IsSecondJoinInPairedJoiner { - return LocalAndRemoteLookupExprs{} + return nil, nil, false } // This lookup join should have the LookupExpr filled in, indicating that one // or more of the join filters constrain an index column to multiple constant // values. if private.LookupExpr == nil { - return LocalAndRemoteLookupExprs{} + return nil, nil, false } // The local region must be set, or we won't be able to determine which // partitions are local. localRegion, found := c.e.evalCtx.Locality.Find(regionKey) if !found { - return LocalAndRemoteLookupExprs{} + return nil, nil, false } // There should be at least two partitions, or we won't be able to @@ -1239,7 +1213,7 @@ func (c *CustomFuncs) GetLocalityOptimizedAntiJoinLookupExprs( tabMeta := c.e.mem.Metadata().TableMeta(private.Table) index := tabMeta.Table.Index(private.Index) if index.PartitionCount() < 2 { - return LocalAndRemoteLookupExprs{} + return nil, nil, false } // Determine whether the index has both local and remote partitions. @@ -1252,13 +1226,13 @@ func (c *CustomFuncs) GetLocalityOptimizedAntiJoinLookupExprs( } if localPartitions.Len() == 0 || localPartitions.Len() == index.PartitionCount() { // The partitions are either all local or all remote. - return LocalAndRemoteLookupExprs{} + return nil, nil, false } // Find a filter that constrains the first column of the index. filterIdx, ok := c.getConstPrefixFilter(index, private.Table, private.LookupExpr) if !ok { - return LocalAndRemoteLookupExprs{} + return nil, nil, false } filter := private.LookupExpr[filterIdx] @@ -1267,14 +1241,14 @@ func (c *CustomFuncs) GetLocalityOptimizedAntiJoinLookupExprs( // can target a local partition and one can target a remote partition. col, vals, ok := filter.ScalarProps().Constraints.HasSingleColumnConstValues(c.e.evalCtx) if !ok || len(vals) < 2 { - return LocalAndRemoteLookupExprs{} + return nil, nil, false } // Determine whether the values target both local and remote partitions. localValOrds := c.getLocalValues(index, localPartitions, vals) if localValOrds.Len() == 0 || localValOrds.Len() == len(vals) { // The values target all local or all remote partitions. - return LocalAndRemoteLookupExprs{} + return nil, nil, false } // Split the values into local and remote sets. @@ -1283,20 +1257,17 @@ func (c *CustomFuncs) GetLocalityOptimizedAntiJoinLookupExprs( // Copy all of the filters from the LookupExpr, and replace the filter that // constrains the first index column with a filter targeting only local // partitions or only remote partitions. - localExpr := make(memo.FiltersExpr, len(private.LookupExpr)) + localExpr = make(memo.FiltersExpr, len(private.LookupExpr)) copy(localExpr, private.LookupExpr) localExpr[filterIdx] = c.makeConstFilter(col, localValues) - remoteExpr := make(memo.FiltersExpr, len(private.LookupExpr)) + remoteExpr = make(memo.FiltersExpr, len(private.LookupExpr)) copy(remoteExpr, private.LookupExpr) remoteExpr[filterIdx] = c.makeConstFilter(col, remoteValues) // Return the two sets of lookup expressions. They will be used to construct // two nested anti joins. - return LocalAndRemoteLookupExprs{ - Local: localExpr, - Remote: remoteExpr, - } + return localExpr, remoteExpr, true } // getConstPrefixFilter finds the position of the filter in the given slice of diff --git a/pkg/sql/opt/xform/limit_funcs.go b/pkg/sql/opt/xform/limit_funcs.go index e23abd094814..81c2166d7d88 100644 --- a/pkg/sql/opt/xform/limit_funcs.go +++ b/pkg/sql/opt/xform/limit_funcs.go @@ -170,20 +170,21 @@ func (c *CustomFuncs) ScanIsInverted(sp *memo.ScanPrivate) bool { } // SplitScanIntoUnionScans returns a Union of Scan operators with hard limits -// that each scan over a single key from the original Scan's constraints. This -// is beneficial in cases where the original Scan had to scan over many rows but -// had relatively few keys to scan over. +// that each scan over a single key from the original Scan's constraints. If no +// such Union of Scans can be found, ok=false is returned. This is beneficial in +// cases where the original Scan had to scan over many rows but had relatively +// few keys to scan over. // TODO(drewk): handle inverted scans. func (c *CustomFuncs) SplitScanIntoUnionScans( limitOrdering props.OrderingChoice, scan memo.RelExpr, sp *memo.ScanPrivate, limit tree.Datum, -) memo.RelExpr { +) (_ memo.RelExpr, ok bool) { const maxScanCount = 16 const threshold = 4 cons, ok := c.getKnownScanConstraint(sp) if !ok { // No valid constraint was found. - return nil + return nil, false } // Find the length of the prefix of index columns preceding the first limit @@ -194,7 +195,7 @@ func (c *CustomFuncs) SplitScanIntoUnionScans( // if len(limitOrdering.Columns) == 0 { // This case can be handled by GenerateLimitedScans. - return nil + return nil, false } keyPrefixLength := cons.Columns.Count() for i := 0; i < cons.Columns.Count(); i++ { @@ -205,7 +206,7 @@ func (c *CustomFuncs) SplitScanIntoUnionScans( } if keyPrefixLength == 0 { // This case can be handled by GenerateLimitedScans. - return nil + return nil, false } keyCtx := constraint.MakeKeyContext(&cons.Columns, c.e.evalCtx) @@ -230,7 +231,7 @@ func (c *CustomFuncs) SplitScanIntoUnionScans( } if keyCount <= 0 || (keyCount == 1 && spans.Count() == 1) || budgetExceededIndex == 0 { // Ensure that at least one new Scan will be constructed. - return nil + return nil, false } scanCount := keyCount @@ -246,7 +247,7 @@ func (c *CustomFuncs) SplitScanIntoUnionScans( // Splitting the Scan may not be worth the overhead. Creating a sequence of // Scans and Unions is expensive, so we only want to create the plan if it // is likely to be used. - return nil + return nil, false } // The index ordering must have a prefix of columns of length keyLength @@ -254,7 +255,7 @@ func (c *CustomFuncs) SplitScanIntoUnionScans( hasLimitOrderingSeq, reverse := indexHasOrderingSequence( c.e.mem.Metadata(), scan, sp, limitOrdering, keyPrefixLength) if !hasLimitOrderingSeq { - return nil + return nil, false } newHardLimit := memo.MakeScanLimit(int64(limitVal), reverse) @@ -313,11 +314,11 @@ func (c *CustomFuncs) SplitScanIntoUnionScans( // Expect to generate at least one new limited single-key Scan. This could // happen if a valid key count could be obtained for at least span, but no // span could be split into single-key spans. - return nil + return nil, false } if noLimitSpans.Count() == 0 { // All spans could be used to generate limited Scans. - return last + return last, true } // If any spans could not be used to generate limited Scans, use them to @@ -328,7 +329,7 @@ func (c *CustomFuncs) SplitScanIntoUnionScans( Spans: noLimitSpans, }) newScan := c.e.f.ConstructScan(newScanPrivate) - return makeNewUnion(last, newScan, sp.Cols.ToList()) + return makeNewUnion(last, newScan, sp.Cols.ToList()), true } // indexHasOrderingSequence returns whether the Scan can provide a given diff --git a/pkg/sql/opt/xform/rules/join.opt b/pkg/sql/opt/xform/rules/join.opt index c239c385e5c7..47c0b90cd204 100644 --- a/pkg/sql/opt/xform/rules/join.opt +++ b/pkg/sql/opt/xform/rules/join.opt @@ -404,11 +404,16 @@ $input:* $on:* $private:* & - (LocalAndRemoteLookupExprsSucceeded - $localAndRemoteLookupExprs:(GetLocalityOptimizedAntiJoinLookupExprs + (Let + ( + $localExpr + $remoteExpr + $ok + ):(GetLocalityOptimizedAntiJoinLookupExprs $input $private ) + $ok ) ) => @@ -417,13 +422,13 @@ $input $on (CreateLocalityOptimizedAntiLookupJoinPrivate - (LocalLookupExpr $localAndRemoteLookupExprs) + $localExpr $private ) ) $on (CreateLocalityOptimizedAntiLookupJoinPrivate - (RemoteLookupExpr $localAndRemoteLookupExprs) + $remoteExpr $private ) ) diff --git a/pkg/sql/opt/xform/rules/limit.opt b/pkg/sql/opt/xform/rules/limit.opt index f2bd00bc71fe..0bf0aa3721c8 100644 --- a/pkg/sql/opt/xform/rules/limit.opt +++ b/pkg/sql/opt/xform/rules/limit.opt @@ -82,13 +82,14 @@ ^(ScanIsInverted $scanPrivate) $limitExpr:(Const $limit:*) & (IsPositiveInt $limit) $ordering:* & - (Succeeded - $unionScans:(SplitScanIntoUnionScans + (Let + ($unionScans $ok):(SplitScanIntoUnionScans $ordering $scan $scanPrivate $limit ) + $ok ) ) =>