Skip to content

Commit

Permalink
norm: update prune cols to match PruneJoinLeftCols/PruneJoinRightCols…
Browse files Browse the repository at this point in the history
… logic

In cockroachdb#90599 adjustments were made to the PruneJoinLeftCols and PruneJoinRightCols
normalization rules to avoid pruning columns which might be needed when
deriving new predicates based on foreign key constraints for lookup join.

However, this caused a problem where rules might sometimes fire in an
infinite loop, because the same columns kept being added as PruneCols
in calls to DerivePruneCols. The logic in prune_cols.opt and
DerivePruneCols must be kept in sync to avoid such problems, and this
PR brings it back in sync.

Epic: none
Fixes: cockroachdb#100478

Release note: None
  • Loading branch information
Mark Sirek committed Apr 6, 2023
1 parent 927bf58 commit 66ef4ee
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 20 deletions.
40 changes: 25 additions & 15 deletions pkg/sql/opt/norm/prune_cols_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ func (c *CustomFuncs) NeededMutationFetchCols(
// needed columns from that. See the props.Relational.Rule.PruneCols comment for
// more details.
func (c *CustomFuncs) CanPruneCols(target memo.RelExpr, neededCols opt.ColSet) bool {
return !DerivePruneCols(target, c.f.disabledRules).SubsetOf(neededCols)
return !c.DerivePruneCols(target, c.f.disabledRules).SubsetOf(neededCols)
}

// CanPruneAggCols returns true if one or more of the target aggregations is not
Expand Down Expand Up @@ -310,7 +310,7 @@ func (c *CustomFuncs) PruneCols(target memo.RelExpr, neededCols opt.ColSet) memo
// Get the subset of the target expression's output columns that should
// not be pruned. Don't prune if the target output column is needed by a
// higher-level expression, or if it's not part of the PruneCols set.
pruneCols := DerivePruneCols(target, c.f.disabledRules).Difference(neededCols)
pruneCols := c.DerivePruneCols(target, c.f.disabledRules).Difference(neededCols)
colSet := c.OutputCols(target).Difference(pruneCols)
return c.f.ConstructProject(target, memo.EmptyProjectionsExpr, colSet)
}
Expand Down Expand Up @@ -505,7 +505,7 @@ func (c *CustomFuncs) PruneWindows(needed opt.ColSet, windows memo.WindowsExpr)
// are randomly disabled for testing. It is used to prevent propagating the
// PruneCols property when the corresponding column-pruning normalization rule
// is disabled. This prevents rule cycles during testing.
func DerivePruneCols(e memo.RelExpr, disabledRules intsets.Fast) opt.ColSet {
func (c *CustomFuncs) DerivePruneCols(e memo.RelExpr, disabledRules intsets.Fast) opt.ColSet {
relProps := e.Relational()
if relProps.IsAvailable(props.PruneCols) {
return relProps.Rule.PruneCols
Expand Down Expand Up @@ -546,7 +546,7 @@ func DerivePruneCols(e memo.RelExpr, disabledRules intsets.Fast) opt.ColSet {
// Any pruneable input columns can potentially be pruned, as long as they're
// not used by the filter.
sel := e.(*memo.SelectExpr)
relProps.Rule.PruneCols = DerivePruneCols(sel.Input, disabledRules).Copy()
relProps.Rule.PruneCols = c.DerivePruneCols(sel.Input, disabledRules).Copy()
usedCols := sel.Filters.OuterCols()
relProps.Rule.PruneCols.DifferenceWith(usedCols)

Expand All @@ -572,9 +572,9 @@ func DerivePruneCols(e memo.RelExpr, disabledRules intsets.Fast) opt.ColSet {
// long as they're not used by the right input (i.e. in Apply case) or by
// the join filter.
left := e.Child(0).(memo.RelExpr)
leftPruneCols := DerivePruneCols(left, disabledRules)
leftPruneCols := c.DerivePruneCols(left, disabledRules)
right := e.Child(1).(memo.RelExpr)
rightPruneCols := DerivePruneCols(right, disabledRules)
rightPruneCols := c.DerivePruneCols(right, disabledRules)

switch e.Op() {
case opt.SemiJoinOp, opt.SemiJoinApplyOp, opt.AntiJoinOp, opt.AntiJoinApplyOp:
Expand All @@ -584,8 +584,18 @@ func DerivePruneCols(e memo.RelExpr, disabledRules intsets.Fast) opt.ColSet {
relProps.Rule.PruneCols = leftPruneCols.Union(rightPruneCols)
}
relProps.Rule.PruneCols.DifferenceWith(right.Relational().OuterCols)
onCols := e.Child(2).(*memo.FiltersExpr).OuterCols()
relProps.Rule.PruneCols.DifferenceWith(onCols)
// Subtract out not just ON clause OuterCols, but also OuterCols of any
// conditions which might be subsequently derived. Ideally the derived
// conditions would always be placed directly in the ON clause, but as this
// may cause selectivity underestimation, until #91142 is fixed we'll
// derive these terms once here for pruning purposes, and once later when
// building a lookup join that could use them.
// TODO(msirek): Remove the AddDerivedOnClauseConditionsFromFKContraints
// call once #91142 is fixed.
onClause := e.Child(2).(*memo.FiltersExpr)
explicitPlusDerivedOnCols := c.AddDerivedOnClauseConditionsFromFKContraints(
*onClause, left, right).OuterCols()
relProps.Rule.PruneCols.DifferenceWith(explicitPlusDerivedOnCols)

case opt.GroupByOp, opt.ScalarGroupByOp, opt.DistinctOnOp, opt.EnsureDistinctOnOp:
if disabledRules.Contains(int(opt.PruneGroupByCols)) ||
Expand All @@ -610,7 +620,7 @@ func DerivePruneCols(e memo.RelExpr, disabledRules intsets.Fast) opt.ColSet {
}
// Any pruneable input columns can potentially be pruned, as long as
// they're not used as an ordering column.
inputPruneCols := DerivePruneCols(e.Child(0).(memo.RelExpr), disabledRules)
inputPruneCols := c.DerivePruneCols(e.Child(0).(memo.RelExpr), disabledRules)
ordering := e.Private().(*props.OrderingChoice).ColSet()
relProps.Rule.PruneCols = inputPruneCols.Difference(ordering)

Expand All @@ -624,7 +634,7 @@ func DerivePruneCols(e memo.RelExpr, disabledRules intsets.Fast) opt.ColSet {
// cannot be pruned without adding an additional Project operator, so
// don't add it to the set.
ord := e.(*memo.OrdinalityExpr)
inputPruneCols := DerivePruneCols(ord.Input, disabledRules)
inputPruneCols := c.DerivePruneCols(ord.Input, disabledRules)
relProps.Rule.PruneCols = inputPruneCols.Difference(ord.Ordering.ColSet())

case opt.IndexJoinOp, opt.LookupJoinOp, opt.MergeJoinOp:
Expand All @@ -644,7 +654,7 @@ func DerivePruneCols(e memo.RelExpr, disabledRules intsets.Fast) opt.ColSet {
// TODO(rytaft): It may be possible to prune Zip columns, but we need to
// make sure that we still get the correct number of rows in the output.
projectSet := e.(*memo.ProjectSetExpr)
relProps.Rule.PruneCols = DerivePruneCols(projectSet.Input, disabledRules).Copy()
relProps.Rule.PruneCols = c.DerivePruneCols(projectSet.Input, disabledRules).Copy()
usedCols := projectSet.Zip.OuterCols()
relProps.Rule.PruneCols.DifferenceWith(usedCols)

Expand All @@ -656,8 +666,8 @@ func DerivePruneCols(e memo.RelExpr, disabledRules intsets.Fast) opt.ColSet {
// Pruning can be beneficial as long as one of our inputs has advertised pruning,
// so that we can push down the project and eliminate the advertisement.
u := e.(*memo.UnionAllExpr)
pruneFromLeft := opt.TranslateColSet(DerivePruneCols(u.Left, disabledRules), u.LeftCols, u.OutCols)
pruneFromRight := opt.TranslateColSet(DerivePruneCols(u.Right, disabledRules), u.RightCols, u.OutCols)
pruneFromLeft := opt.TranslateColSet(c.DerivePruneCols(u.Left, disabledRules), u.LeftCols, u.OutCols)
pruneFromRight := opt.TranslateColSet(c.DerivePruneCols(u.Right, disabledRules), u.RightCols, u.OutCols)
relProps.Rule.PruneCols = pruneFromLeft.Union(pruneFromRight)

case opt.WindowOp:
Expand All @@ -667,7 +677,7 @@ func DerivePruneCols(e memo.RelExpr, disabledRules intsets.Fast) opt.ColSet {
break
}
win := e.(*memo.WindowExpr)
relProps.Rule.PruneCols = DerivePruneCols(win.Input, disabledRules).Copy()
relProps.Rule.PruneCols = c.DerivePruneCols(win.Input, disabledRules).Copy()
relProps.Rule.PruneCols.DifferenceWith(win.Partition)
relProps.Rule.PruneCols.DifferenceWith(win.Ordering.ColSet())
for _, w := range win.Windows {
Expand All @@ -682,7 +692,7 @@ func DerivePruneCols(e memo.RelExpr, disabledRules intsets.Fast) opt.ColSet {
}
// WithOp passes through its input unchanged, so it has the same pruning
// characteristics as its input.
relProps.Rule.PruneCols = DerivePruneCols(e.(*memo.WithExpr).Main, disabledRules)
relProps.Rule.PruneCols = c.DerivePruneCols(e.(*memo.WithExpr).Main, disabledRules)

default:
// Don't allow any columns to be pruned, since that would trigger the
Expand Down
40 changes: 40 additions & 0 deletions pkg/sql/opt/norm/testdata/rules/prune_cols
Original file line number Diff line number Diff line change
Expand Up @@ -5292,3 +5292,43 @@ project
└── aggregations
└── const-agg
└── concat_agg

exec-ddl
CREATE TABLE p100478 (
region STRING NOT NULL,
id INT,
PRIMARY KEY (region, id),
UNIQUE INDEX p_id_idx (id)
)
----

exec-ddl
CREATE TABLE c100478 (
region STRING NOT NULL,
id INT PRIMARY KEY,
p_id INT NOT NULL,
FOREIGN KEY (region, p_id) REFERENCES p100478 (region, id)
)
----

# A join which doesn't prune columns which could potentially appear in derived
# ON clause conditions should not result in infinite rule recursion.
norm expect=(EliminateGroupByProject,PruneJoinLeftCols,PruneJoinRightCols)
SELECT p.id FROM p100478 p JOIN c100478 ON p_id = p.id GROUP BY p.id
----
distinct-on
├── columns: id:2!null
├── grouping columns: p.id:2!null
├── key: (2)
└── inner-join (hash)
├── columns: p.region:1!null p.id:2!null c100478.region:5!null p_id:7!null
├── multiplicity: left-rows(zero-or-more), right-rows(exactly-one)
├── fd: (2)-->(1), (2)==(7), (7)==(2)
├── scan p100478 [as=p]
│ ├── columns: p.region:1!null p.id:2!null
│ ├── key: (2)
│ └── fd: (2)-->(1)
├── scan c100478
│ └── columns: c100478.region:5!null p_id:7!null
└── filters
└── p_id:7 = p.id:2 [outer=(2,7), constraints=(/2: (/NULL - ]; /7: (/NULL - ]), fd=(2)==(7), (7)==(2)]
16 changes: 11 additions & 5 deletions pkg/sql/opt/testutils/opttester/opt_tester.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ type OptTester struct {
appliedRules RuleSet

builder strings.Builder
f *norm.Factory
}

// Flags are control knobs for tests. Note that specific testcases can
Expand Down Expand Up @@ -270,6 +271,8 @@ func New(catalog cat.Catalog, sql string) *OptTester {
semaCtx: tree.MakeSemaContext(),
evalCtx: eval.MakeTestingEvalContext(cluster.MakeTestingClusterSettings()),
}
ot.f = &norm.Factory{}
ot.f.Init(ot.ctx, &ot.evalCtx, ot.catalog)
ot.evalCtx.SessionData().ReorderJoinsLimit = opt.DefaultJoinOrderLimit
ot.evalCtx.SessionData().OptimizerUseMultiColStats = true
ot.Flags.ctx = ot.ctx
Expand Down Expand Up @@ -854,7 +857,7 @@ func (ot *OptTester) checkExpectedRules(tb testing.TB, d *datadriven.TestData) {
}

func (ot *OptTester) postProcess(tb testing.TB, d *datadriven.TestData, e opt.Expr) {
fillInLazyProps(e)
ot.fillInLazyProps(e)

if rel, ok := e.(memo.RelExpr); ok {
for _, cols := range ot.Flags.ColStats {
Expand All @@ -865,13 +868,13 @@ func (ot *OptTester) postProcess(tb testing.TB, d *datadriven.TestData, e opt.Ex
}

// Fills in lazily-derived properties (for display).
func fillInLazyProps(e opt.Expr) {
func (ot *OptTester) fillInLazyProps(e opt.Expr) {
if rel, ok := e.(memo.RelExpr); ok {
// These properties are derived from the normalized expression.
rel = rel.FirstExpr()

// Derive columns that are candidates for pruning.
norm.DerivePruneCols(rel, intsets.Fast{} /* disabledRules */)
ot.f.CustomFuncs().DerivePruneCols(rel, intsets.Fast{} /* disabledRules */)

// Derive columns that are candidates for null rejection.
norm.DeriveRejectNullCols(rel, intsets.Fast{} /* disabledRules */)
Expand All @@ -881,7 +884,7 @@ func fillInLazyProps(e opt.Expr) {
}

for i, n := 0, e.ChildCount(); i < n; i++ {
fillInLazyProps(e.Child(i))
ot.fillInLazyProps(e.Child(i))
}
}

Expand Down Expand Up @@ -1283,6 +1286,7 @@ func (ot *OptTester) Memo() (string, error) {
func (ot *OptTester) Expr() (opt.Expr, error) {
var f norm.Factory
f.Init(ot.ctx, &ot.evalCtx, ot.catalog)
ot.f = &f
f.DisableOptimizations()

return exprgen.Build(ot.ctx, ot.catalog, &f, ot.sql)
Expand All @@ -1293,6 +1297,7 @@ func (ot *OptTester) Expr() (opt.Expr, error) {
func (ot *OptTester) ExprNorm() (opt.Expr, error) {
var f norm.Factory
f.Init(ot.ctx, &ot.evalCtx, ot.catalog)
ot.f = &f
f.SetDisabledRules(ot.Flags.DisableRules)

if !ot.Flags.NoStableFolds {
Expand Down Expand Up @@ -2245,6 +2250,7 @@ func (ot *OptTester) buildExpr(factory *norm.Factory) error {
func (ot *OptTester) makeOptimizer() *xform.Optimizer {
var o xform.Optimizer
o.Init(ot.ctx, &ot.evalCtx, ot.catalog)
ot.f = o.Factory()
o.Factory().SetDisabledRules(ot.Flags.DisableRules)
o.NotifyOnAppliedRule(func(ruleName opt.RuleName, source, target opt.Expr) {
// Exploration rules are marked as "applied" if they generate one or
Expand Down Expand Up @@ -2272,8 +2278,8 @@ func (ot *OptTester) optimizeExpr(
if err != nil {
return nil, err
}
o.Memo().ResetLogProps(ot.ctx, &ot.evalCtx)
if ot.Flags.PerturbCost != 0 {
o.Memo().ResetLogProps(ot.ctx, &ot.evalCtx)
o.RecomputeCost()
}
return root, nil
Expand Down

0 comments on commit 66ef4ee

Please sign in to comment.