diff --git a/pkg/sql/opt/norm/general_funcs.go b/pkg/sql/opt/norm/general_funcs.go index 8bc9611fb469..e31f2e0f657c 100644 --- a/pkg/sql/opt/norm/general_funcs.go +++ b/pkg/sql/opt/norm/general_funcs.go @@ -274,13 +274,12 @@ func (c *CustomFuncs) DuplicateColumnIDs( table opt.TableID, cols opt.ColSet, ) (opt.TableID, opt.ColSet) { md := c.mem.Metadata() - tabMeta := md.TableMeta(table) newTableID := md.DuplicateTable(table, c.RemapCols) // Build a new set of column IDs from the new TableMeta. var newColIDs opt.ColSet for col, ok := cols.Next(0); ok; col, ok = cols.Next(col + 1) { - ord := tabMeta.MetaID.ColumnOrdinal(col) + ord := table.ColumnOrdinal(col) newColID := newTableID.ColumnID(ord) newColIDs.Add(newColID) } diff --git a/pkg/sql/opt/xform/general_funcs.go b/pkg/sql/opt/xform/general_funcs.go index 6690e8764801..e65e09df645b 100644 --- a/pkg/sql/opt/xform/general_funcs.go +++ b/pkg/sql/opt/xform/general_funcs.go @@ -68,41 +68,33 @@ func (c *CustomFuncs) HasInvertedIndexes(scanPrivate *memo.ScanPrivate) bool { return false } -// MapFilterCols returns a new FiltersExpr with all the src column IDs in -// the input expression replaced with column IDs in dst. -// -// NOTE: Every ColumnID in src must map to the a ColumnID in dst with the same -// relative position in the ColSets. For example, if src and dst are (1, 5, 6) -// and (7, 12, 15), then the following mapping would be applied: -// -// 1 => 7 -// 5 => 12 -// 6 => 15 -func (c *CustomFuncs) MapFilterCols( - filters memo.FiltersExpr, src, dst opt.ColSet, +// RemapScanColsInFilter returns a new FiltersExpr where columns in src's table +// are replaced with columns of the same ordinal in dst's table. src and dst +// must scan the same base table. +func (c *CustomFuncs) RemapScanColsInFilter( + filters memo.FiltersExpr, src, dst *memo.ScanPrivate, ) memo.FiltersExpr { - newFilters := c.mapScalarExprCols(&filters, src, dst).(*memo.FiltersExpr) + newFilters := c.remapScanColsInScalarExpr(&filters, src, dst).(*memo.FiltersExpr) return *newFilters } -func (c *CustomFuncs) mapScalarExprCols(scalar opt.ScalarExpr, src, dst opt.ColSet) opt.ScalarExpr { - if src.Len() != dst.Len() { - panic(errors.AssertionFailedf( - "src and dst must have the same number of columns, src: %v, dst: %v", - src, - dst, - )) +func (c *CustomFuncs) remapScanColsInScalarExpr( + scalar opt.ScalarExpr, src, dst *memo.ScanPrivate, +) opt.ScalarExpr { + md := c.e.mem.Metadata() + if md.Table(src.Table).ID() != md.Table(dst.Table).ID() { + panic(errors.AssertionFailedf("scans must have the same base table")) } - - // Map each column in src to a column in dst based on the relative position - // of both the src and dst ColumnIDs in the ColSet. + if src.Cols.Len() != dst.Cols.Len() { + panic(errors.AssertionFailedf("scans must have the same number of columns")) + } + // Remap each column in src to a column in dst. var colMap opt.ColMap - dstCol, _ := dst.Next(0) - for srcCol, ok := src.Next(0); ok; srcCol, ok = src.Next(srcCol + 1) { + for srcCol, ok := src.Cols.Next(0); ok; srcCol, ok = src.Cols.Next(srcCol + 1) { + ord := src.Table.ColumnOrdinal(srcCol) + dstCol := dst.Table.ColumnID(ord) colMap.Set(int(srcCol), int(dstCol)) - dstCol, _ = dst.Next(dstCol + 1) } - return c.RemapCols(scalar, colMap) } @@ -523,9 +515,21 @@ func (c *CustomFuncs) splitScanIntoUnionScansOrSelects( } for j, m := 0, singleKeySpans.Count(); j < m; j++ { // Construct a new Scan for each span. - newScanOrSelect := c.makeNewScan(sp, cons.Columns, newHardLimit, singleKeySpans.Get(j)) + newScanPrivate := c.makeNewScanPrivate( + sp, + cons.Columns, + newHardLimit, + singleKeySpans.Get(j), + ) + newScanOrSelect := c.e.f.ConstructScan(newScanPrivate) if !filters.IsTrue() { - newScanOrSelect = c.wrapScanInLimitedSelect(newScanOrSelect, sp, filters, limit) + newScanOrSelect = c.wrapScanInLimitedSelect( + newScanOrSelect, + sp, + newScanPrivate, + filters, + limit, + ) } queue.PushBack(newScanOrSelect) queueLength++ @@ -561,6 +565,10 @@ func (c *CustomFuncs) splitScanIntoUnionScansOrSelects( // Not necessary, but keep spans with lower values in the left subtree. left, right = right, left } + // TODO(mgartner/msirek): Converting ColSets to ColLists here is only safe + // because column IDs are always allocated in a consistent, ascending order + // for each duplicated table in the metadata. If column ID allocation + // changes, this could break. if noLimitSpans.Count() == 0 && queue.Len() == 0 { outCols = sp.Cols.ToList() } else { @@ -594,8 +602,12 @@ func (c *CustomFuncs) splitScanIntoUnionScansOrSelects( }) newScanOrSelect := c.e.f.ConstructScan(newScanPrivate) if !filters.IsTrue() { - newScanOrSelect = c.wrapScanInLimitedSelect(newScanOrSelect, sp, filters, limit) + newScanOrSelect = c.wrapScanInLimitedSelect(newScanOrSelect, sp, newScanPrivate, filters, limit) } + // TODO(mgartner/msirek): Converting ColSets to ColLists here is only safe + // because column IDs are always allocated in a consistent, ascending order + // for each duplicated table in the metadata. If column ID allocation + // changes, this could break. return makeNewUnion(last, newScanOrSelect, sp.Cols.ToList()), true } @@ -651,12 +663,15 @@ func (c *CustomFuncs) numAllowedValues( // the originalScanPrivate columns to the columns in scan. If limit is non-zero, // the SelectExpr is wrapped in a LimitExpr with that limit. func (c *CustomFuncs) wrapScanInLimitedSelect( - scan memo.RelExpr, originalScanPrivate *memo.ScanPrivate, filters memo.FiltersExpr, limit int, + scan memo.RelExpr, + originalScanPrivate, newScanPrivate *memo.ScanPrivate, + filters memo.FiltersExpr, + limit int, ) (limitedSelect memo.RelExpr) { - limitedSelect = - c.e.f.ConstructSelect(scan, - c.MapFilterCols(filters, originalScanPrivate.Cols, - c.OutputCols(scan))) + limitedSelect = c.e.f.ConstructSelect( + scan, + c.RemapScanColsInFilter(filters, originalScanPrivate, newScanPrivate), + ) if limit != 0 { limitedSelect = c.e.f.ConstructLimit( limitedSelect, @@ -729,16 +744,16 @@ func indexHasOrderingSequence( return ordering.ScanPrivateCanProvide(md, sp, &requiredOrdering) } -// makeNewScan constructs a new Scan operator with a new TableID and the given -// limit and span. All ColumnIDs and references to those ColumnIDs are -// replaced with new ones from the new TableID. All other fields are simply -// copied from the old ScanPrivate. -func (c *CustomFuncs) makeNewScan( +// makeNewScanPrivate returns a new ScanPrivate with a new TableID and the given +// limit and span. All ColumnIDs and references to those ColumnIDs are replaced +// with new ones from the new TableID. All other fields are simply copied from +// the old ScanPrivate. +func (c *CustomFuncs) makeNewScanPrivate( sp *memo.ScanPrivate, columns constraint.Columns, newHardLimit memo.ScanLimit, span *constraint.Span, -) memo.RelExpr { +) *memo.ScanPrivate { newScanPrivate := c.DuplicateScanPrivate(sp) // duplicateScanPrivate does not initialize the Constraint or HardLimit @@ -755,7 +770,7 @@ func (c *CustomFuncs) makeNewScan( } newScanPrivate.SetConstraint(c.e.evalCtx, newConstraint) - return c.e.f.ConstructScan(newScanPrivate) + return newScanPrivate } // getKnownScanConstraint returns a Constraint that is known to hold true for diff --git a/pkg/sql/opt/xform/groupby_funcs.go b/pkg/sql/opt/xform/groupby_funcs.go index 4076e590b0ff..a85f47cb4639 100644 --- a/pkg/sql/opt/xform/groupby_funcs.go +++ b/pkg/sql/opt/xform/groupby_funcs.go @@ -59,7 +59,7 @@ func (c *CustomFuncs) MakeMinMaxScalarSubqueriesWithFilter( // If the input to the scalar group by is a Select with filters, remap the // column IDs in the filters and use that to build a new Select. if len(filters) > 0 { - newFilters := c.MapFilterCols(filters, scanPrivate.Cols, newScanPrivate.Cols) + newFilters := c.RemapScanColsInFilter(filters, scanPrivate, newScanPrivate) inputExpr = c.e.f.ConstructSelect(inputExpr, newFilters) } @@ -70,7 +70,7 @@ func (c *CustomFuncs) MakeMinMaxScalarSubqueriesWithFilter( if !ok { panic(errors.AssertionFailedf("expected a variable as input to the aggregate, but found %T", aggs[i].Agg.Child(0))) } - newVarExpr := c.mapScalarExprCols(variable, scanPrivate.Cols, newScanPrivate.Cols) + newVarExpr := c.remapScanColsInScalarExpr(variable, scanPrivate, newScanPrivate) var newAggrFunc opt.ScalarExpr switch aggs[i].Agg.(type) { case *memo.MaxExpr: diff --git a/pkg/sql/opt/xform/join_funcs.go b/pkg/sql/opt/xform/join_funcs.go index f065c26e9478..995c39eec552 100644 --- a/pkg/sql/opt/xform/join_funcs.go +++ b/pkg/sql/opt/xform/join_funcs.go @@ -1215,21 +1215,27 @@ func (c *CustomFuncs) mapInvertedJoin( newIndexCols.Add(prefixCol) } - // Get the source and destination ColSets, including the inverted source - // columns, which will be used in the invertedExpr. - srcCols := indexCols.Copy() - dstCols := newIndexCols.Copy() + // Create a map from the source columns to the destination columns, + // including the inverted source columns which will be used in the + // invertedExpr. + var srcColsToDstCols opt.ColMap + for srcCol, ok := indexCols.Next(0); ok; srcCol, ok = indexCols.Next(srcCol + 1) { + ord := tabID.ColumnOrdinal(srcCol) + dstCol := newTabID.ColumnID(ord) + srcColsToDstCols.Set(int(srcCol), int(dstCol)) + } ord := index.InvertedColumn().InvertedSourceColumnOrdinal() invertedSourceCol := tabID.ColumnID(ord) newInvertedSourceCol := newTabID.ColumnID(ord) - srcCols.Add(invertedSourceCol) - dstCols.Add(newInvertedSourceCol) + srcColsToDstCols.Set(int(invertedSourceCol), int(newInvertedSourceCol)) invertedJoin.Table = newTabID - invertedJoin.InvertedExpr = c.mapScalarExprCols(invertedJoin.InvertedExpr, srcCols, dstCols) + invertedJoin.InvertedExpr = c.RemapCols(invertedJoin.InvertedExpr, srcColsToDstCols) invertedJoin.Cols = invertedJoin.Cols.Difference(indexCols).Union(newIndexCols) - invertedJoin.ConstFilters = c.MapFilterCols(invertedJoin.ConstFilters, srcCols, dstCols) - invertedJoin.On = c.MapFilterCols(invertedJoin.On, srcCols, dstCols) + constFilters := c.RemapCols(&invertedJoin.ConstFilters, srcColsToDstCols).(*memo.FiltersExpr) + invertedJoin.ConstFilters = *constFilters + on := c.RemapCols(&invertedJoin.On, srcColsToDstCols).(*memo.FiltersExpr) + invertedJoin.On = *on } // findComputedColJoinEquality returns the computed column expression of col and diff --git a/pkg/sql/opt/xform/rules/select.opt b/pkg/sql/opt/xform/rules/select.opt index 8f2d30c949a4..c1a81538357d 100644 --- a/pkg/sql/opt/xform/rules/select.opt +++ b/pkg/sql/opt/xform/rules/select.opt @@ -132,27 +132,35 @@ (DistinctOn (UnionAll (Select - $leftScan:(Scan (DuplicateScanPrivate $scanPrivate)) - (MapFilterCols + $leftScan:(Scan + $leftScanPrivate:(DuplicateScanPrivate + $scanPrivate + ) + ) + (RemapScanColsInFilter (ReplaceFiltersItem $filters $itemToReplace $leftFilter ) - (OutputCols $input) - (OutputCols $leftScan) + $scanPrivate + $leftScanPrivate ) ) (Select - $rightScan:(Scan (DuplicateScanPrivate $scanPrivate)) - (MapFilterCols + $rightScan:(Scan + $rightScanPrivate:(DuplicateScanPrivate + $scanPrivate + ) + ) + (RemapScanColsInFilter (ReplaceFiltersItem $filters $itemToReplace $rightFilter ) - (OutputCols $input) - (OutputCols $rightScan) + $scanPrivate + $rightScanPrivate ) ) (MakeSetPrivate @@ -218,44 +226,48 @@ (Select $leftScan:(Scan (AddPrimaryKeyColsToScanPrivate - (DuplicateScanPrivate $scanPrivate) + $leftScanPrivate:(DuplicateScanPrivate + $scanPrivate + ) ) ) - (MapFilterCols + (RemapScanColsInFilter (ReplaceFiltersItem $filters $itemToReplace $leftFilter ) - $outCols:(UnionCols - (OutputCols $input) - (PrimaryKeyCols - (TableIDFromScanPrivate $scanPrivate) - ) - ) - (OutputCols $leftScan) + $scanPrivate + $leftScanPrivate ) ) (Select $rightScan:(Scan (AddPrimaryKeyColsToScanPrivate - (DuplicateScanPrivate $scanPrivate) + $rightScanPrivate:(DuplicateScanPrivate + $scanPrivate + ) ) ) - (MapFilterCols + (RemapScanColsInFilter (ReplaceFiltersItem $filters $itemToReplace $rightFilter ) - $outCols - (OutputCols $rightScan) + $scanPrivate + $rightScanPrivate ) ) (MakeSetPrivate (OutputCols $leftScan) (OutputCols $rightScan) - $outCols + (UnionCols + (OutputCols $input) + (PrimaryKeyCols + (TableIDFromScanPrivate $scanPrivate) + ) + ) ) ) []