diff --git a/pkg/sql/opt/xform/coster.go b/pkg/sql/opt/xform/coster.go index 04cfa53f01be..b20c29924c73 100644 --- a/pkg/sql/opt/xform/coster.go +++ b/pkg/sql/opt/xform/coster.go @@ -772,7 +772,7 @@ func (c *coster) computeSelectCost(sel *memo.SelectExpr, required *physical.Requ inputRowCount = math.Min(inputRowCount, required.LimitHint/selectivity) } - filterSetup, filterPerRow := c.computeFiltersCost(sel.Filters, util.FastIntMap{}) + filterSetup, filterPerRow := c.computeFiltersCost(sel.Filters, util.FastIntSet{}) cost := memo.Cost(inputRowCount) * filterPerRow cost += filterSetup return cost @@ -827,25 +827,17 @@ func (c *coster) computeHashJoinCost(join memo.RelExpr) memo.Cost { // pressure and the possibility of spilling to disk. cost += c.rowBufferCost(rightRowCount) - // Compute filter cost. Fetch the equality columns so they can be - // ignored later. + // Compute filter cost. Fetch the indices of the filters that will be used in + // the join, since they will not add to the cost and should be skipped. on := join.Child(2).(*memo.FiltersExpr) - leftEq, rightEq, _ := memo.ExtractJoinEqualityColumns( - join.Child(0).(memo.RelExpr).Relational().OutputCols, - join.Child(1).(memo.RelExpr).Relational().OutputCols, - *on, - ) - // Generate a quick way to lookup if two columns are join equality - // columns. We add in both directions because we don't know which way - // the equality filters will be defined. - eqMap := util.FastIntMap{} - for i := range leftEq { - left := int(leftEq[i]) - right := int(rightEq[i]) - eqMap.Set(left, right) - eqMap.Set(right, left) - } - filterSetup, filterPerRow := c.computeFiltersCost(*on, eqMap) + leftCols := join.Child(0).(memo.RelExpr).Relational().OutputCols + rightCols := join.Child(1).(memo.RelExpr).Relational().OutputCols + var filtersToSkip util.FastIntSet + _, _, toSkip := memo.ExtractJoinEqualityColumns(leftCols, rightCols, *on) + for _, idx := range toSkip { + filtersToSkip.Add(idx) + } + filterSetup, filterPerRow := c.computeFiltersCost(*on, filtersToSkip) cost += filterSetup // Add the CPU cost of emitting the rows. @@ -882,7 +874,7 @@ func (c *coster) computeMergeJoinCost(join *memo.MergeJoinExpr) memo.Cost { // smaller right side is preferred to the symmetric join. cost := memo.Cost(0.9*leftRowCount+1.1*rightRowCount) * cpuCostFactor - filterSetup, filterPerRow := c.computeFiltersCost(join.On, util.FastIntMap{}) + filterSetup, filterPerRow := c.computeFiltersCost(join.On, util.FastIntSet{}) cost += filterSetup // Add the CPU cost of emitting the rows. @@ -996,7 +988,7 @@ func (c *coster) computeIndexLookupJoinCost( perLookupCost += lookupExprCost(join) cost := memo.Cost(lookupCount) * perLookupCost - filterSetup, filterPerRow := c.computeFiltersCost(on, util.FastIntMap{}) + filterSetup, filterPerRow := c.computeFiltersCost(on, util.FastIntSet{}) cost += filterSetup // Each lookup might retrieve many rows; add the IO cost of retrieving the @@ -1075,7 +1067,7 @@ func (c *coster) computeInvertedJoinCost( perLookupCost *= 5 cost := memo.Cost(lookupCount) * perLookupCost - filterSetup, filterPerRow := c.computeFiltersCost(join.On, util.FastIntMap{}) + filterSetup, filterPerRow := c.computeFiltersCost(join.On, util.FastIntSet{}) cost += filterSetup // Each lookup might retrieve many rows; add the IO cost of retrieving the @@ -1108,28 +1100,21 @@ func (c *coster) computeExprCost(expr opt.Expr) memo.Cost { // computeFiltersCost returns the setup and per-row cost of executing // a filter. Callers of this function should add setupCost and multiply // perRowCost by the number of rows expected to be filtered. +// +// filtersToSkip identifies the indices of filters that should be skipped, +// because they do not add to the cost. This can happen when a condition still +// exists in the filters even though it is handled by the join. func (c *coster) computeFiltersCost( - filters memo.FiltersExpr, eqMap util.FastIntMap, + filters memo.FiltersExpr, filtersToSkip util.FastIntSet, ) (setupCost, perRowCost memo.Cost) { // Add a base perRowCost so that callers do not need to have their own // base per-row cost. perRowCost += cpuCostFactor for i := range filters { - f := &filters[i] - if f.Condition.Op() == opt.EqOp { - eq := f.Condition.(*memo.EqExpr) - leftVar, lOk := eq.Left.(*memo.VariableExpr) - rightVar, rOk := eq.Right.(*memo.VariableExpr) - if lOk && rOk { - val, ok := eqMap.Get(int(leftVar.Col)) - if ok && val == int(rightVar.Col) { - // Equality filters on some joins are still in - // filters, while others have already removed - // them. They do not cost anything. - continue - } - } + if filtersToSkip.Contains(i) { + continue } + f := &filters[i] perRowCost += c.computeExprCost(f.Condition) // Add a constant "setup" cost per ON condition to account for the fact that // the rowsProcessed estimate alone cannot effectively discriminate between @@ -1156,7 +1141,7 @@ func (c *coster) computeZigzagJoinCost(join *memo.ZigzagJoinExpr) memo.Cost { scanCost := c.rowScanCost(join, join.LeftTable, join.LeftIndex, leftCols, join.Relational().Stats) scanCost += c.rowScanCost(join, join.RightTable, join.RightIndex, rightCols, join.Relational().Stats) - filterSetup, filterPerRow := c.computeFiltersCost(join.On, util.FastIntMap{}) + filterSetup, filterPerRow := c.computeFiltersCost(join.On, util.FastIntSet{}) // Double the cost of emitting rows as well as the cost of seeking rows, // given two indexes will be accessed. diff --git a/pkg/sql/opt/xform/join_funcs.go b/pkg/sql/opt/xform/join_funcs.go index 5f8bc938439d..012bf688c63e 100644 --- a/pkg/sql/opt/xform/join_funcs.go +++ b/pkg/sql/opt/xform/join_funcs.go @@ -93,10 +93,14 @@ func (c *CustomFuncs) GenerateMergeJoins( return } - var colToEq util.FastIntMap - for i := range leftEq { - colToEq.Set(int(leftEq[i]), i) - colToEq.Set(int(rightEq[i]), i) + getEqCols := func(col opt.ColumnID) (left, right opt.ColumnID) { + // Assume that col is in either leftEq or rightEq. + for eqIdx := 0; eqIdx < len(leftEq); eqIdx++ { + if leftEq[eqIdx] == col || rightEq[eqIdx] == col { + return leftEq[eqIdx], rightEq[eqIdx] + } + } + panic(errors.AssertionFailedf("failed to find eqIdx for merge join")) } var remainingFilters memo.FiltersExpr @@ -115,8 +119,7 @@ func (c *CustomFuncs) GenerateMergeJoins( merge.RightOrdering.Columns = make([]props.OrderingColumnChoice, 0, n) addCol := func(col opt.ColumnID, descending bool) { - eqIdx, _ := colToEq.Get(int(col)) - l, r := leftEq[eqIdx], rightEq[eqIdx] + l, r := getEqCols(col) merge.LeftEq = append(merge.LeftEq, opt.MakeOrderingColumn(l, descending)) merge.RightEq = append(merge.RightEq, opt.MakeOrderingColumn(r, descending)) merge.LeftOrdering.AppendCol(l, descending)