Merge #77463

77463: opt: fix dangerous MapFilterCols function r=mgartner a=mgartner

The `MapFilterCols` custom function was dangerous because it relied on
the ordering of columns in two unordered sets to map columns referenced
in a filter. As one example, the implementation of the function made it
impossible to map column `1` to column `5` and column `2` to column `4`
in the same filter.
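
To make the limitation concrete, here is a minimal, self-contained sketch; plain sorted int slices stand in for `opt.ColSet` (this is not the real API), and `mapByPosition` mimics how the old function paired columns purely by ascending position:

```go
package main

import (
	"fmt"
	"sort"
)

// mapByPosition pairs the i-th smallest element of src with the i-th smallest
// element of dst, mimicking how the old MapFilterCols paired column IDs from
// two unordered sets.
func mapByPosition(src, dst []int) map[int]int {
	sort.Ints(src)
	sort.Ints(dst)
	m := make(map[int]int, len(src))
	for i, c := range src {
		m[c] = dst[i]
	}
	return m
}

func main() {
	// Ascending pairing forces 1 => 4 and 2 => 5. No pair of sets can produce
	// the crossing mapping 1 => 5, 2 => 4.
	fmt.Println(mapByPosition([]int{1, 2}, []int{4, 5})) // map[1:4 2:5]
}
```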

As far as I know this has not caused bugs. It was only used to remap
columns from a scan to columns in a duplicate scan. Column IDs of scans
are allocated in ascending order, and column sets are iterated over in
ascending order, so columns were always correctly mapped.

This commit replaces `MapFilterCols` with `RemapScanColsInFilter` which
does not rely on the ordering of columns in unordered sets. This will
prevent future bugs that would occur if either column ID allocation or
column set iteration changes. The more specific name and arguments of
`RemapScanColsInFilter` should also prevent misuse.
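
For illustration only, here is a simplified sketch of the ordinal-based idea using hypothetical stand-in types (not the actual `opt` package): each source column ID is resolved to its ordinal in the source table and replaced by the destination table's column ID at the same ordinal, so the result does not depend on any iteration order:

```go
package main

import "fmt"

// table lists a table's column IDs by ordinal: colIDs[ordinal] = column ID.
// It is a stand-in for the metadata's per-table column allocation.
type table struct {
	colIDs []int
}

// ordinal returns the ordinal position of colID within the table.
func (t table) ordinal(colID int) int {
	for ord, id := range t.colIDs {
		if id == colID {
			return ord
		}
	}
	panic("column not in table")
}

// remapByOrdinal maps each source column ID to the destination column ID with
// the same ordinal. The mapping is independent of iteration order.
func remapByOrdinal(src, dst table, srcCols []int) map[int]int {
	m := make(map[int]int, len(srcCols))
	for _, c := range srcCols {
		m[c] = dst.colIDs[src.ordinal(c)]
	}
	return m
}

func main() {
	src := table{colIDs: []int{1, 2}} // ordinals 0 and 1
	dst := table{colIDs: []int{5, 4}} // same ordinals, different ID order
	// The crossing mapping 1 => 5, 2 => 4 is now expressible.
	fmt.Println(remapByOrdinal(src, dst, []int{1, 2})) // map[1:5 2:4]
}
```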

Release justification: This is a minor change that does not affect
behavior and decreases the risk of future bugs in the optimizer.

Release note: None

Co-authored-by: Marcus Gartner <[email protected]>
craig[bot] and mgartner committed Mar 9, 2022
2 parents 70a7826 + ef588e7 commit 5e1b8d6
Showing 5 changed files with 109 additions and 77 deletions.
3 changes: 1 addition & 2 deletions pkg/sql/opt/norm/general_funcs.go
@@ -274,13 +274,12 @@ func (c *CustomFuncs) DuplicateColumnIDs(
 	table opt.TableID, cols opt.ColSet,
 ) (opt.TableID, opt.ColSet) {
 	md := c.mem.Metadata()
-	tabMeta := md.TableMeta(table)
 	newTableID := md.DuplicateTable(table, c.RemapCols)
 
 	// Build a new set of column IDs from the new TableMeta.
 	var newColIDs opt.ColSet
 	for col, ok := cols.Next(0); ok; col, ok = cols.Next(col + 1) {
-		ord := tabMeta.MetaID.ColumnOrdinal(col)
+		ord := table.ColumnOrdinal(col)
 		newColID := newTableID.ColumnID(ord)
 		newColIDs.Add(newColID)
 	}
99 changes: 57 additions & 42 deletions pkg/sql/opt/xform/general_funcs.go
@@ -68,41 +68,33 @@ func (c *CustomFuncs) HasInvertedIndexes(scanPrivate *memo.ScanPrivate) bool {
 	return false
 }
 
-// MapFilterCols returns a new FiltersExpr with all the src column IDs in
-// the input expression replaced with column IDs in dst.
-//
-// NOTE: Every ColumnID in src must map to the a ColumnID in dst with the same
-// relative position in the ColSets. For example, if src and dst are (1, 5, 6)
-// and (7, 12, 15), then the following mapping would be applied:
-//
-//   1 => 7
-//   5 => 12
-//   6 => 15
-func (c *CustomFuncs) MapFilterCols(
-	filters memo.FiltersExpr, src, dst opt.ColSet,
+// RemapScanColsInFilter returns a new FiltersExpr where columns in src's table
+// are replaced with columns of the same ordinal in dst's table. src and dst
+// must scan the same base table.
+func (c *CustomFuncs) RemapScanColsInFilter(
+	filters memo.FiltersExpr, src, dst *memo.ScanPrivate,
 ) memo.FiltersExpr {
-	newFilters := c.mapScalarExprCols(&filters, src, dst).(*memo.FiltersExpr)
+	newFilters := c.remapScanColsInScalarExpr(&filters, src, dst).(*memo.FiltersExpr)
 	return *newFilters
 }
 
-func (c *CustomFuncs) mapScalarExprCols(scalar opt.ScalarExpr, src, dst opt.ColSet) opt.ScalarExpr {
-	if src.Len() != dst.Len() {
-		panic(errors.AssertionFailedf(
-			"src and dst must have the same number of columns, src: %v, dst: %v",
-			src,
-			dst,
-		))
+func (c *CustomFuncs) remapScanColsInScalarExpr(
+	scalar opt.ScalarExpr, src, dst *memo.ScanPrivate,
+) opt.ScalarExpr {
+	md := c.e.mem.Metadata()
+	if md.Table(src.Table).ID() != md.Table(dst.Table).ID() {
+		panic(errors.AssertionFailedf("scans must have the same base table"))
 	}
-
-	// Map each column in src to a column in dst based on the relative position
-	// of both the src and dst ColumnIDs in the ColSet.
+	if src.Cols.Len() != dst.Cols.Len() {
+		panic(errors.AssertionFailedf("scans must have the same number of columns"))
+	}
+	// Remap each column in src to a column in dst.
 	var colMap opt.ColMap
-	dstCol, _ := dst.Next(0)
-	for srcCol, ok := src.Next(0); ok; srcCol, ok = src.Next(srcCol + 1) {
+	for srcCol, ok := src.Cols.Next(0); ok; srcCol, ok = src.Cols.Next(srcCol + 1) {
+		ord := src.Table.ColumnOrdinal(srcCol)
+		dstCol := dst.Table.ColumnID(ord)
 		colMap.Set(int(srcCol), int(dstCol))
-		dstCol, _ = dst.Next(dstCol + 1)
 	}
 
 	return c.RemapCols(scalar, colMap)
 }
 
@@ -523,9 +515,21 @@ func (c *CustomFuncs) splitScanIntoUnionScansOrSelects(
 	}
 	for j, m := 0, singleKeySpans.Count(); j < m; j++ {
 		// Construct a new Scan for each span.
-		newScanOrSelect := c.makeNewScan(sp, cons.Columns, newHardLimit, singleKeySpans.Get(j))
+		newScanPrivate := c.makeNewScanPrivate(
+			sp,
+			cons.Columns,
+			newHardLimit,
+			singleKeySpans.Get(j),
+		)
+		newScanOrSelect := c.e.f.ConstructScan(newScanPrivate)
		if !filters.IsTrue() {
-			newScanOrSelect = c.wrapScanInLimitedSelect(newScanOrSelect, sp, filters, limit)
+			newScanOrSelect = c.wrapScanInLimitedSelect(
+				newScanOrSelect,
+				sp,
+				newScanPrivate,
+				filters,
+				limit,
+			)
 		}
 		queue.PushBack(newScanOrSelect)
 		queueLength++
@@ -561,6 +565,10 @@ func (c *CustomFuncs) splitScanIntoUnionScansOrSelects(
 		// Not necessary, but keep spans with lower values in the left subtree.
 		left, right = right, left
 	}
+	// TODO(mgartner/msirek): Converting ColSets to ColLists here is only safe
+	// because column IDs are always allocated in a consistent, ascending order
+	// for each duplicated table in the metadata. If column ID allocation
+	// changes, this could break.
 	if noLimitSpans.Count() == 0 && queue.Len() == 0 {
 		outCols = sp.Cols.ToList()
 	} else {
@@ -594,8 +602,12 @@ func (c *CustomFuncs) splitScanIntoUnionScansOrSelects(
 	})
 	newScanOrSelect := c.e.f.ConstructScan(newScanPrivate)
 	if !filters.IsTrue() {
-		newScanOrSelect = c.wrapScanInLimitedSelect(newScanOrSelect, sp, filters, limit)
+		newScanOrSelect = c.wrapScanInLimitedSelect(newScanOrSelect, sp, newScanPrivate, filters, limit)
 	}
+	// TODO(mgartner/msirek): Converting ColSets to ColLists here is only safe
+	// because column IDs are always allocated in a consistent, ascending order
+	// for each duplicated table in the metadata. If column ID allocation
+	// changes, this could break.
 	return makeNewUnion(last, newScanOrSelect, sp.Cols.ToList()), true
 }
 
@@ -651,12 +663,15 @@ func (c *CustomFuncs) numAllowedValues(
 // the originalScanPrivate columns to the columns in scan. If limit is non-zero,
 // the SelectExpr is wrapped in a LimitExpr with that limit.
 func (c *CustomFuncs) wrapScanInLimitedSelect(
-	scan memo.RelExpr, originalScanPrivate *memo.ScanPrivate, filters memo.FiltersExpr, limit int,
+	scan memo.RelExpr,
+	originalScanPrivate, newScanPrivate *memo.ScanPrivate,
+	filters memo.FiltersExpr,
+	limit int,
 ) (limitedSelect memo.RelExpr) {
-	limitedSelect =
-		c.e.f.ConstructSelect(scan,
-			c.MapFilterCols(filters, originalScanPrivate.Cols,
-				c.OutputCols(scan)))
+	limitedSelect = c.e.f.ConstructSelect(
+		scan,
+		c.RemapScanColsInFilter(filters, originalScanPrivate, newScanPrivate),
+	)
 	if limit != 0 {
 		limitedSelect = c.e.f.ConstructLimit(
 			limitedSelect,
@@ -729,16 +744,16 @@ func indexHasOrderingSequence(
 	return ordering.ScanPrivateCanProvide(md, sp, &requiredOrdering)
 }
 
-// makeNewScan constructs a new Scan operator with a new TableID and the given
-// limit and span. All ColumnIDs and references to those ColumnIDs are
-// replaced with new ones from the new TableID. All other fields are simply
-// copied from the old ScanPrivate.
-func (c *CustomFuncs) makeNewScan(
+// makeNewScanPrivate returns a new ScanPrivate with a new TableID and the given
+// limit and span. All ColumnIDs and references to those ColumnIDs are replaced
+// with new ones from the new TableID. All other fields are simply copied from
+// the old ScanPrivate.
+func (c *CustomFuncs) makeNewScanPrivate(
 	sp *memo.ScanPrivate,
 	columns constraint.Columns,
 	newHardLimit memo.ScanLimit,
 	span *constraint.Span,
-) memo.RelExpr {
+) *memo.ScanPrivate {
 	newScanPrivate := c.DuplicateScanPrivate(sp)
 
 	// duplicateScanPrivate does not initialize the Constraint or HardLimit
@@ -755,7 +770,7 @@
 	}
 	newScanPrivate.SetConstraint(c.e.evalCtx, newConstraint)
 
-	return c.e.f.ConstructScan(newScanPrivate)
+	return newScanPrivate
 }
 
 // getKnownScanConstraint returns a Constraint that is known to hold true for
4 changes: 2 additions & 2 deletions pkg/sql/opt/xform/groupby_funcs.go
@@ -59,7 +59,7 @@ func (c *CustomFuncs) MakeMinMaxScalarSubqueriesWithFilter(
 	// If the input to the scalar group by is a Select with filters, remap the
 	// column IDs in the filters and use that to build a new Select.
 	if len(filters) > 0 {
-		newFilters := c.MapFilterCols(filters, scanPrivate.Cols, newScanPrivate.Cols)
+		newFilters := c.RemapScanColsInFilter(filters, scanPrivate, newScanPrivate)
 		inputExpr = c.e.f.ConstructSelect(inputExpr, newFilters)
 	}
 
@@ -70,7 +70,7 @@ func (c *CustomFuncs) MakeMinMaxScalarSubqueriesWithFilter(
 		if !ok {
 			panic(errors.AssertionFailedf("expected a variable as input to the aggregate, but found %T", aggs[i].Agg.Child(0)))
 		}
-		newVarExpr := c.mapScalarExprCols(variable, scanPrivate.Cols, newScanPrivate.Cols)
+		newVarExpr := c.remapScanColsInScalarExpr(variable, scanPrivate, newScanPrivate)
 		var newAggrFunc opt.ScalarExpr
 		switch aggs[i].Agg.(type) {
 		case *memo.MaxExpr:
24 changes: 15 additions & 9 deletions pkg/sql/opt/xform/join_funcs.go
@@ -1215,21 +1215,27 @@ func (c *CustomFuncs) mapInvertedJoin(
 		newIndexCols.Add(prefixCol)
 	}
 
-	// Get the source and destination ColSets, including the inverted source
-	// columns, which will be used in the invertedExpr.
-	srcCols := indexCols.Copy()
-	dstCols := newIndexCols.Copy()
+	// Create a map from the source columns to the destination columns,
+	// including the inverted source columns which will be used in the
+	// invertedExpr.
+	var srcColsToDstCols opt.ColMap
+	for srcCol, ok := indexCols.Next(0); ok; srcCol, ok = indexCols.Next(srcCol + 1) {
+		ord := tabID.ColumnOrdinal(srcCol)
+		dstCol := newTabID.ColumnID(ord)
+		srcColsToDstCols.Set(int(srcCol), int(dstCol))
+	}
 	ord := index.InvertedColumn().InvertedSourceColumnOrdinal()
 	invertedSourceCol := tabID.ColumnID(ord)
 	newInvertedSourceCol := newTabID.ColumnID(ord)
-	srcCols.Add(invertedSourceCol)
-	dstCols.Add(newInvertedSourceCol)
+	srcColsToDstCols.Set(int(invertedSourceCol), int(newInvertedSourceCol))
 
 	invertedJoin.Table = newTabID
-	invertedJoin.InvertedExpr = c.mapScalarExprCols(invertedJoin.InvertedExpr, srcCols, dstCols)
+	invertedJoin.InvertedExpr = c.RemapCols(invertedJoin.InvertedExpr, srcColsToDstCols)
 	invertedJoin.Cols = invertedJoin.Cols.Difference(indexCols).Union(newIndexCols)
-	invertedJoin.ConstFilters = c.MapFilterCols(invertedJoin.ConstFilters, srcCols, dstCols)
-	invertedJoin.On = c.MapFilterCols(invertedJoin.On, srcCols, dstCols)
+	constFilters := c.RemapCols(&invertedJoin.ConstFilters, srcColsToDstCols).(*memo.FiltersExpr)
+	invertedJoin.ConstFilters = *constFilters
+	on := c.RemapCols(&invertedJoin.On, srcColsToDstCols).(*memo.FiltersExpr)
+	invertedJoin.On = *on
 }
 
 // findComputedColJoinEquality returns the computed column expression of col and
56 changes: 34 additions & 22 deletions pkg/sql/opt/xform/rules/select.opt
@@ -132,27 +132,35 @@
 (DistinctOn
     (UnionAll
         (Select
-            $leftScan:(Scan (DuplicateScanPrivate $scanPrivate))
-            (MapFilterCols
+            $leftScan:(Scan
+                $leftScanPrivate:(DuplicateScanPrivate
+                    $scanPrivate
+                )
+            )
+            (RemapScanColsInFilter
                 (ReplaceFiltersItem
                     $filters
                     $itemToReplace
                     $leftFilter
                 )
-                (OutputCols $input)
-                (OutputCols $leftScan)
+                $scanPrivate
+                $leftScanPrivate
             )
        )
        (Select
-            $rightScan:(Scan (DuplicateScanPrivate $scanPrivate))
-            (MapFilterCols
+            $rightScan:(Scan
+                $rightScanPrivate:(DuplicateScanPrivate
+                    $scanPrivate
+                )
+            )
+            (RemapScanColsInFilter
                 (ReplaceFiltersItem
                     $filters
                     $itemToReplace
                     $rightFilter
                 )
-                (OutputCols $input)
-                (OutputCols $rightScan)
+                $scanPrivate
+                $rightScanPrivate
            )
        )
        (MakeSetPrivate
@@ -218,44 +226,48 @@
        (Select
            $leftScan:(Scan
                (AddPrimaryKeyColsToScanPrivate
-                    (DuplicateScanPrivate $scanPrivate)
+                    $leftScanPrivate:(DuplicateScanPrivate
+                        $scanPrivate
+                    )
                )
            )
-            (MapFilterCols
+            (RemapScanColsInFilter
                (ReplaceFiltersItem
                    $filters
                    $itemToReplace
                    $leftFilter
                )
-                $outCols:(UnionCols
-                    (OutputCols $input)
-                    (PrimaryKeyCols
-                        (TableIDFromScanPrivate $scanPrivate)
-                    )
-                )
-                (OutputCols $leftScan)
+                $scanPrivate
+                $leftScanPrivate
            )
        )
        (Select
            $rightScan:(Scan
                (AddPrimaryKeyColsToScanPrivate
-                    (DuplicateScanPrivate $scanPrivate)
+                    $rightScanPrivate:(DuplicateScanPrivate
+                        $scanPrivate
+                    )
                )
            )
-            (MapFilterCols
+            (RemapScanColsInFilter
                (ReplaceFiltersItem
                    $filters
                    $itemToReplace
                    $rightFilter
                )
-                $outCols
-                (OutputCols $rightScan)
+                $scanPrivate
+                $rightScanPrivate
            )
        )
        (MakeSetPrivate
            (OutputCols $leftScan)
            (OutputCols $rightScan)
-            $outCols
+            (UnionCols
+                (OutputCols $input)
+                (PrimaryKeyCols
+                    (TableIDFromScanPrivate $scanPrivate)
+                )
+            )
        )
    )
    []
