Skip to content

Commit

Permalink
norm: update prune cols to match PruneJoinLeftCols/PruneJoinRightCols…
Browse files Browse the repository at this point in the history
… logic

normalization rules to avoid pruning columns which might be needed when
deriving new predicates based on foreign key constraints for lookup join.

However, this caused a problem where rules might sometimes fire in an
infinite loop because the same columns to prune keep getting added as
PruneCols in calls to DerivePruneCols. The logic in prune_cols.opt and
DerivePruneCols must be kept in sync to avoid such problems, and this
PR brings it back in sync.

Epic: none
Fixes: #100478

Release note: None
  • Loading branch information
Mark Sirek committed Apr 6, 2023
1 parent 42abc0c commit 929ed4a
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 19 deletions.
40 changes: 25 additions & 15 deletions pkg/sql/opt/norm/prune_cols_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ func (c *CustomFuncs) NeededMutationFetchCols(
// needed columns from that. See the props.Relational.Rule.PruneCols comment for
// more details.
func (c *CustomFuncs) CanPruneCols(target memo.RelExpr, neededCols opt.ColSet) bool {
return !DerivePruneCols(target, c.f.disabledRules).SubsetOf(neededCols)
return !c.DerivePruneCols(target, c.f.disabledRules).SubsetOf(neededCols)
}

// CanPruneAggCols returns true if one or more of the target aggregations is not
Expand Down Expand Up @@ -310,7 +310,7 @@ func (c *CustomFuncs) PruneCols(target memo.RelExpr, neededCols opt.ColSet) memo
// Get the subset of the target expression's output columns that should
// not be pruned. Don't prune if the target output column is needed by a
// higher-level expression, or if it's not part of the PruneCols set.
pruneCols := DerivePruneCols(target, c.f.disabledRules).Difference(neededCols)
pruneCols := c.DerivePruneCols(target, c.f.disabledRules).Difference(neededCols)
colSet := c.OutputCols(target).Difference(pruneCols)
return c.f.ConstructProject(target, memo.EmptyProjectionsExpr, colSet)
}
Expand Down Expand Up @@ -505,7 +505,7 @@ func (c *CustomFuncs) PruneWindows(needed opt.ColSet, windows memo.WindowsExpr)
// are randomly disabled for testing. It is used to prevent propagating the
// PruneCols property when the corresponding column-pruning normalization rule
// is disabled. This prevents rule cycles during testing.
func DerivePruneCols(e memo.RelExpr, disabledRules intsets.Fast) opt.ColSet {
func (c *CustomFuncs) DerivePruneCols(e memo.RelExpr, disabledRules intsets.Fast) opt.ColSet {
relProps := e.Relational()
if relProps.IsAvailable(props.PruneCols) {
return relProps.Rule.PruneCols
Expand Down Expand Up @@ -546,7 +546,7 @@ func DerivePruneCols(e memo.RelExpr, disabledRules intsets.Fast) opt.ColSet {
// Any pruneable input columns can potentially be pruned, as long as they're
// not used by the filter.
sel := e.(*memo.SelectExpr)
relProps.Rule.PruneCols = DerivePruneCols(sel.Input, disabledRules).Copy()
relProps.Rule.PruneCols = c.DerivePruneCols(sel.Input, disabledRules).Copy()
usedCols := sel.Filters.OuterCols()
relProps.Rule.PruneCols.DifferenceWith(usedCols)

Expand All @@ -572,9 +572,9 @@ func DerivePruneCols(e memo.RelExpr, disabledRules intsets.Fast) opt.ColSet {
// long as they're not used by the right input (i.e. in Apply case) or by
// the join filter.
left := e.Child(0).(memo.RelExpr)
leftPruneCols := DerivePruneCols(left, disabledRules)
leftPruneCols := c.DerivePruneCols(left, disabledRules)
right := e.Child(1).(memo.RelExpr)
rightPruneCols := DerivePruneCols(right, disabledRules)
rightPruneCols := c.DerivePruneCols(right, disabledRules)

switch e.Op() {
case opt.SemiJoinOp, opt.SemiJoinApplyOp, opt.AntiJoinOp, opt.AntiJoinApplyOp:
Expand All @@ -584,8 +584,18 @@ func DerivePruneCols(e memo.RelExpr, disabledRules intsets.Fast) opt.ColSet {
relProps.Rule.PruneCols = leftPruneCols.Union(rightPruneCols)
}
relProps.Rule.PruneCols.DifferenceWith(right.Relational().OuterCols)
onCols := e.Child(2).(*memo.FiltersExpr).OuterCols()
relProps.Rule.PruneCols.DifferenceWith(onCols)
// Subtract out not just ON clause OuterCols, but also OuterCols of any
// conditions which might be subsequently derived. Ideally the derived
// conditions would always be placed directly in the ON clause, but as this
// may cause selectivity underestimation, until #91142 is fixed we'll
// derive these terms once here for pruning purposes, and once later when
// building a lookup join that could use them.
// TODO(msirek): Remove the AddDerivedOnClauseConditionsFromFKContraints
// call once #91142 is fixed.
onClause := e.Child(2).(*memo.FiltersExpr)
explicitPlusDerivedOnCols := c.AddDerivedOnClauseConditionsFromFKContraints(
*onClause, left, right).OuterCols()
relProps.Rule.PruneCols.DifferenceWith(explicitPlusDerivedOnCols)

case opt.GroupByOp, opt.ScalarGroupByOp, opt.DistinctOnOp, opt.EnsureDistinctOnOp:
if disabledRules.Contains(int(opt.PruneGroupByCols)) ||
Expand All @@ -610,7 +620,7 @@ func DerivePruneCols(e memo.RelExpr, disabledRules intsets.Fast) opt.ColSet {
}
// Any pruneable input columns can potentially be pruned, as long as
// they're not used as an ordering column.
inputPruneCols := DerivePruneCols(e.Child(0).(memo.RelExpr), disabledRules)
inputPruneCols := c.DerivePruneCols(e.Child(0).(memo.RelExpr), disabledRules)
ordering := e.Private().(*props.OrderingChoice).ColSet()
relProps.Rule.PruneCols = inputPruneCols.Difference(ordering)

Expand All @@ -624,7 +634,7 @@ func DerivePruneCols(e memo.RelExpr, disabledRules intsets.Fast) opt.ColSet {
// cannot be pruned without adding an additional Project operator, so
// don't add it to the set.
ord := e.(*memo.OrdinalityExpr)
inputPruneCols := DerivePruneCols(ord.Input, disabledRules)
inputPruneCols := c.DerivePruneCols(ord.Input, disabledRules)
relProps.Rule.PruneCols = inputPruneCols.Difference(ord.Ordering.ColSet())

case opt.IndexJoinOp, opt.LookupJoinOp, opt.MergeJoinOp:
Expand All @@ -644,7 +654,7 @@ func DerivePruneCols(e memo.RelExpr, disabledRules intsets.Fast) opt.ColSet {
// TODO(rytaft): It may be possible to prune Zip columns, but we need to
// make sure that we still get the correct number of rows in the output.
projectSet := e.(*memo.ProjectSetExpr)
relProps.Rule.PruneCols = DerivePruneCols(projectSet.Input, disabledRules).Copy()
relProps.Rule.PruneCols = c.DerivePruneCols(projectSet.Input, disabledRules).Copy()
usedCols := projectSet.Zip.OuterCols()
relProps.Rule.PruneCols.DifferenceWith(usedCols)

Expand All @@ -656,8 +666,8 @@ func DerivePruneCols(e memo.RelExpr, disabledRules intsets.Fast) opt.ColSet {
// Pruning can be beneficial as long as one of our inputs has advertised pruning,
// so that we can push down the project and eliminate the advertisement.
u := e.(*memo.UnionAllExpr)
pruneFromLeft := opt.TranslateColSet(DerivePruneCols(u.Left, disabledRules), u.LeftCols, u.OutCols)
pruneFromRight := opt.TranslateColSet(DerivePruneCols(u.Right, disabledRules), u.RightCols, u.OutCols)
pruneFromLeft := opt.TranslateColSet(c.DerivePruneCols(u.Left, disabledRules), u.LeftCols, u.OutCols)
pruneFromRight := opt.TranslateColSet(c.DerivePruneCols(u.Right, disabledRules), u.RightCols, u.OutCols)
relProps.Rule.PruneCols = pruneFromLeft.Union(pruneFromRight)

case opt.WindowOp:
Expand All @@ -667,7 +677,7 @@ func DerivePruneCols(e memo.RelExpr, disabledRules intsets.Fast) opt.ColSet {
break
}
win := e.(*memo.WindowExpr)
relProps.Rule.PruneCols = DerivePruneCols(win.Input, disabledRules).Copy()
relProps.Rule.PruneCols = c.DerivePruneCols(win.Input, disabledRules).Copy()
relProps.Rule.PruneCols.DifferenceWith(win.Partition)
relProps.Rule.PruneCols.DifferenceWith(win.Ordering.ColSet())
for _, w := range win.Windows {
Expand All @@ -682,7 +692,7 @@ func DerivePruneCols(e memo.RelExpr, disabledRules intsets.Fast) opt.ColSet {
}
// WithOp passes through its input unchanged, so it has the same pruning
// characteristics as its input.
relProps.Rule.PruneCols = DerivePruneCols(e.(*memo.WithExpr).Main, disabledRules)
relProps.Rule.PruneCols = c.DerivePruneCols(e.(*memo.WithExpr).Main, disabledRules)

default:
// Don't allow any columns to be pruned, since that would trigger the
Expand Down
40 changes: 40 additions & 0 deletions pkg/sql/opt/norm/testdata/rules/prune_cols
Original file line number Diff line number Diff line change
Expand Up @@ -5292,3 +5292,43 @@ project
└── aggregations
└── const-agg
└── concat_agg

exec-ddl
CREATE TABLE p100478 (
region STRING NOT NULL,
id INT,
PRIMARY KEY (region, id),
UNIQUE INDEX p_id_idx (id)
)
----

exec-ddl
CREATE TABLE c100478 (
region STRING NOT NULL,
id INT PRIMARY KEY,
p_id INT NOT NULL,
FOREIGN KEY (region, p_id) REFERENCES p100478 (region, id)
)
----

# A join which doesn't prune columns which could potentially appear in derived
# ON clause conditions should not result in infinite rule recursion.
norm expect=EliminateGroupByProject expect=PruneJoinLeftCols expect=PruneJoinRightCols
SELECT p.id FROM p100478 p JOIN c100478 ON p_id = p.id GROUP BY p.id
----
distinct-on
├── columns: id:2!null
├── grouping columns: p.id:2!null
├── key: (2)
└── inner-join (hash)
├── columns: p.region:1!null p.id:2!null c100478.region:5!null p_id:7!null
├── multiplicity: left-rows(zero-or-more), right-rows(exactly-one)
├── fd: (2)-->(1), (2)==(7), (7)==(2)
├── scan p100478 [as=p]
│ ├── columns: p.region:1!null p.id:2!null
│ ├── key: (2)
│ └── fd: (2)-->(1)
├── scan c100478
│ └── columns: c100478.region:5!null p_id:7!null
└── filters
└── p_id:7 = p.id:2 [outer=(2,7), constraints=(/2: (/NULL - ]; /7: (/NULL - ]), fd=(2)==(7), (7)==(2)]
14 changes: 10 additions & 4 deletions pkg/sql/opt/testutils/opttester/opt_tester.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ type OptTester struct {
appliedRules RuleSet

builder strings.Builder
f *norm.Factory
}

// Flags are control knobs for tests. Note that specific testcases can
Expand Down Expand Up @@ -270,6 +271,8 @@ func New(catalog cat.Catalog, sql string) *OptTester {
semaCtx: tree.MakeSemaContext(),
evalCtx: eval.MakeTestingEvalContext(cluster.MakeTestingClusterSettings()),
}
ot.f = &norm.Factory{}
ot.f.Init(ot.ctx, &ot.evalCtx, ot.catalog)
ot.evalCtx.SessionData().ReorderJoinsLimit = opt.DefaultJoinOrderLimit
ot.evalCtx.SessionData().OptimizerUseMultiColStats = true
ot.Flags.ctx = ot.ctx
Expand Down Expand Up @@ -854,7 +857,7 @@ func (ot *OptTester) checkExpectedRules(tb testing.TB, d *datadriven.TestData) {
}

func (ot *OptTester) postProcess(tb testing.TB, d *datadriven.TestData, e opt.Expr) {
fillInLazyProps(e)
ot.fillInLazyProps(e)

if rel, ok := e.(memo.RelExpr); ok {
for _, cols := range ot.Flags.ColStats {
Expand All @@ -865,13 +868,13 @@ func (ot *OptTester) postProcess(tb testing.TB, d *datadriven.TestData, e opt.Ex
}

// Fills in lazily-derived properties (for display).
func fillInLazyProps(e opt.Expr) {
func (ot *OptTester) fillInLazyProps(e opt.Expr) {
if rel, ok := e.(memo.RelExpr); ok {
// These properties are derived from the normalized expression.
rel = rel.FirstExpr()

// Derive columns that are candidates for pruning.
norm.DerivePruneCols(rel, intsets.Fast{} /* disabledRules */)
ot.f.CustomFuncs().DerivePruneCols(rel, intsets.Fast{} /* disabledRules */)

// Derive columns that are candidates for null rejection.
norm.DeriveRejectNullCols(rel, intsets.Fast{} /* disabledRules */)
Expand All @@ -881,7 +884,7 @@ func fillInLazyProps(e opt.Expr) {
}

for i, n := 0, e.ChildCount(); i < n; i++ {
fillInLazyProps(e.Child(i))
ot.fillInLazyProps(e.Child(i))
}
}

Expand Down Expand Up @@ -1283,6 +1286,7 @@ func (ot *OptTester) Memo() (string, error) {
func (ot *OptTester) Expr() (opt.Expr, error) {
var f norm.Factory
f.Init(ot.ctx, &ot.evalCtx, ot.catalog)
ot.f = &f
f.DisableOptimizations()

return exprgen.Build(ot.ctx, ot.catalog, &f, ot.sql)
Expand All @@ -1293,6 +1297,7 @@ func (ot *OptTester) Expr() (opt.Expr, error) {
func (ot *OptTester) ExprNorm() (opt.Expr, error) {
var f norm.Factory
f.Init(ot.ctx, &ot.evalCtx, ot.catalog)
ot.f = &f
f.SetDisabledRules(ot.Flags.DisableRules)

if !ot.Flags.NoStableFolds {
Expand Down Expand Up @@ -2245,6 +2250,7 @@ func (ot *OptTester) buildExpr(factory *norm.Factory) error {
func (ot *OptTester) makeOptimizer() *xform.Optimizer {
var o xform.Optimizer
o.Init(ot.ctx, &ot.evalCtx, ot.catalog)
ot.f = o.Factory()
o.Factory().SetDisabledRules(ot.Flags.DisableRules)
o.NotifyOnAppliedRule(func(ruleName opt.RuleName, source, target opt.Expr) {
// Exploration rules are marked as "applied" if they generate one or
Expand Down

0 comments on commit 929ed4a

Please sign in to comment.