Skip to content

Commit

Permalink
util/ranger: move cut prefix logic for prefix index to ealier stage (p…
Browse files Browse the repository at this point in the history
  • Loading branch information
time-and-fate committed Dec 5, 2023
1 parent 283f07b commit ae33dd5
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 100 deletions.
12 changes: 4 additions & 8 deletions util/ranger/detacher.go
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ func (d *rangeDetacher) detachCNFCondAndBuildRangeForIndex(conditions []expressi
// Therefore, we need to calculate pointRanges separately so that it can be used to append tail ranges in considerDNF branch.
// See https://github.com/pingcap/tidb/issues/26029 for details.
var pointRanges Ranges
if hasPrefix(d.lengths) && fixPrefixColRange(ranges, d.lengths, tpSlice) {
if hasPrefix(d.lengths) {
if d.mergeConsecutive {
pointRanges = make(Ranges, 0, len(ranges))
for _, ran := range ranges {
Expand Down Expand Up @@ -635,9 +635,9 @@ func ExtractEqAndInCondition(sctx sessionctx.Context, conditions []expression.Ex
collator := collate.GetCollator(cols[offset].GetType().GetCollate())
if mergedAccesses[offset] == nil {
mergedAccesses[offset] = accesses[offset]
points[offset] = rb.build(accesses[offset], collator)
points[offset] = rb.build(accesses[offset], collator, lengths[offset])
}
points[offset] = rb.intersection(points[offset], rb.build(cond, collator), collator)
points[offset] = rb.intersection(points[offset], rb.build(cond, collator, lengths[offset]), collator)
if len(points[offset]) == 0 { // Early termination if false expression found
if expression.MaybeOverOptimized4PlanCache(sctx, conditions) {
// `a>@x and a<@y` --> `invalid-range if @x>=@y`
Expand Down Expand Up @@ -778,7 +778,7 @@ func (d *rangeDetacher) detachDNFCondAndBuildRangeForIndex(condition *expression
if shouldReserve {
hasResidual = true
}
points := rb.build(item, collate.GetCollator(newTpSlice[0].GetCollate()))
points := rb.build(item, collate.GetCollator(newTpSlice[0].GetCollate()), d.lengths[0])
// TODO: restrict the mem usage of ranges
ranges, rangeFallback, err := points2Ranges(d.sctx, points, newTpSlice[0], d.rangeMaxSize)
if err != nil {
Expand Down Expand Up @@ -810,10 +810,6 @@ func (d *rangeDetacher) detachDNFCondAndBuildRangeForIndex(condition *expression
}
}

// Take prefix index into consideration.
if hasPrefix(d.lengths) {
fixPrefixColRange(totalRanges, d.lengths, newTpSlice)
}
totalRanges, err := UnionRanges(d.sctx, totalRanges, d.mergeConsecutive)
if err != nil {
return nil, nil, nil, false, errors.Trace(err)
Expand Down
90 changes: 55 additions & 35 deletions util/ranger/points.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
Expand Down Expand Up @@ -182,15 +183,21 @@ func NullRange() Ranges {
// builder is the range builder struct.
type builder struct {
err error
sc *stmtctx.StatementContext
ctx sessionctx.Context
}

func (r *builder) build(expr expression.Expression, collator collate.Collator) []*point {
// build converts Expression on one column into point, which can be further built into Range.
// The input collator is used for intersection/union between points, which corresponds to AND/OR in the expression. Since
// our (*Datum).Compare(), which is used there, needs an explicit collator input to handle comparison for string and bytes,
// we pass it down from here.
// If the input prefixLen is not types.UnspecifiedLength, it means it's for a prefix column in a prefix index. In such
// cases, we should cut the prefix and adjust the exclusiveness. Ref: cutPrefixForPoints().
func (r *builder) build(expr expression.Expression, collator collate.Collator, prefixLen int) []*point {
switch x := expr.(type) {
case *expression.Column:
return r.buildFromColumn()
case *expression.ScalarFunction:
return r.buildFromScalarFunc(x, collator)
return r.buildFromScalarFunc(x, collator, prefixLen)
case *expression.Constant:
return r.buildFromConstant(x)
}
Expand All @@ -208,7 +215,7 @@ func (r *builder) buildFromConstant(expr *expression.Constant) []*point {
return nil
}

val, err := dt.ToBool(r.sc)
val, err := dt.ToBool(r.ctx.GetSessionVars().StmtCtx)
if err != nil {
r.err = err
return nil
Expand All @@ -231,7 +238,7 @@ func (*builder) buildFromColumn() []*point {
return []*point{startPoint1, endPoint1, startPoint2, endPoint2}
}

func (r *builder) buildFromBinOp(expr *expression.ScalarFunction) []*point {
func (r *builder) buildFromBinOp(expr *expression.ScalarFunction, prefixLen int) []*point {
// This has been checked that the binary operation is comparison operation, and one of
// the operand is column name expression.
var (
Expand All @@ -253,11 +260,11 @@ func (r *builder) buildFromBinOp(expr *expression.ScalarFunction) []*point {
// If the original value is adjusted, we need to change the condition.
// For example, col < 2156. Since the max year is 2155, 2156 is changed to 2155.
// col < 2155 is wrong. It should be col <= 2155.
preValue, err1 := value.ToInt64(r.sc)
preValue, err1 := value.ToInt64(r.ctx.GetSessionVars().StmtCtx)
if err1 != nil {
return err1
}
*value, err = value.ConvertToMysqlYear(r.sc, col.RetType)
*value, err = value.ConvertToMysqlYear(r.ctx.GetSessionVars().StmtCtx, col.RetType)
if errors.ErrorEqual(err, types.ErrWarnDataOutOfRange) {
// Keep err for EQ and NE.
switch *op {
Expand Down Expand Up @@ -334,43 +341,46 @@ func (r *builder) buildFromBinOp(expr *expression.ScalarFunction) []*point {
}

if ft.GetType() == mysql.TypeEnum && ft.EvalType() == types.ETString {
return handleEnumFromBinOp(r.sc, ft, value, op)
return handleEnumFromBinOp(r.ctx.GetSessionVars().StmtCtx, ft, value, op)
}

var res []*point
switch op {
case ast.NullEQ:
if value.IsNull() {
return []*point{{start: true}, {}} // [null, null]
res = []*point{{start: true}, {}} // [null, null]
break
}
fallthrough
case ast.EQ:
startPoint := &point{value: value, start: true}
endPoint := &point{value: value}
return []*point{startPoint, endPoint}
res = []*point{startPoint, endPoint}
case ast.NE:
startPoint1 := &point{value: types.MinNotNullDatum(), start: true}
endPoint1 := &point{value: value, excl: true}
startPoint2 := &point{value: value, start: true, excl: true}
endPoint2 := &point{value: types.MaxValueDatum()}
return []*point{startPoint1, endPoint1, startPoint2, endPoint2}
res = []*point{startPoint1, endPoint1, startPoint2, endPoint2}
case ast.LT:
startPoint := &point{value: types.MinNotNullDatum(), start: true}
endPoint := &point{value: value, excl: true}
return []*point{startPoint, endPoint}
res = []*point{startPoint, endPoint}
case ast.LE:
startPoint := &point{value: types.MinNotNullDatum(), start: true}
endPoint := &point{value: value}
return []*point{startPoint, endPoint}
res = []*point{startPoint, endPoint}
case ast.GT:
startPoint := &point{value: value, start: true, excl: true}
endPoint := &point{value: types.MaxValueDatum()}
return []*point{startPoint, endPoint}
res = []*point{startPoint, endPoint}
case ast.GE:
startPoint := &point{value: value, start: true}
endPoint := &point{value: types.MaxValueDatum()}
return []*point{startPoint, endPoint}
res = []*point{startPoint, endPoint}
}
return nil
cutPrefixForPoints(res, prefixLen, ft)
return res
}

// handleUnsignedCol handles the case when unsigned column meets negative value.
Expand Down Expand Up @@ -552,11 +562,12 @@ func (*builder) buildFromIsFalse(_ *expression.ScalarFunction, isNot int) []*poi
return []*point{startPoint, endPoint}
}

func (r *builder) buildFromIn(expr *expression.ScalarFunction) ([]*point, bool) {
func (r *builder) buildFromIn(expr *expression.ScalarFunction, prefixLen int) ([]*point, bool) {
list := expr.GetArgs()[1:]
rangePoints := make([]*point, 0, len(list)*2)
hasNull := false
colCollate := expr.GetArgs()[0].GetType().GetCollate()
ft := expr.GetArgs()[0].GetType()
colCollate := ft.GetCollate()
for _, e := range list {
v, ok := e.(*expression.Constant)
if !ok {
Expand Down Expand Up @@ -584,7 +595,7 @@ func (r *builder) buildFromIn(expr *expression.ScalarFunction) ([]*point, bool)
err = parseErr
}
default:
dt, err = dt.ConvertTo(r.sc, expr.GetArgs()[0].GetType())
dt, err = dt.ConvertTo(r.ctx.GetSessionVars().StmtCtx, expr.GetArgs()[0].GetType())
}

if err != nil {
Expand All @@ -593,7 +604,7 @@ func (r *builder) buildFromIn(expr *expression.ScalarFunction) ([]*point, bool)
}
}
if expr.GetArgs()[0].GetType().GetType() == mysql.TypeYear {
dt, err = dt.ConvertToMysqlYear(r.sc, expr.GetArgs()[0].GetType())
dt, err = dt.ConvertToMysqlYear(r.ctx.GetSessionVars().StmtCtx, expr.GetArgs()[0].GetType())
if err != nil {
// in (..., an impossible value (not valid year), ...), the range is empty, so skip it.
continue
Expand All @@ -609,7 +620,7 @@ func (r *builder) buildFromIn(expr *expression.ScalarFunction) ([]*point, bool)
endPoint := &point{value: endValue}
rangePoints = append(rangePoints, startPoint, endPoint)
}
sorter := pointSorter{points: rangePoints, sc: r.sc, collator: collate.GetCollator(colCollate)}
sorter := pointSorter{points: rangePoints, sc: r.ctx.GetSessionVars().StmtCtx, collator: collate.GetCollator(colCollate)}
sort.Sort(&sorter)
if sorter.err != nil {
r.err = sorter.err
Expand All @@ -628,10 +639,12 @@ func (r *builder) buildFromIn(expr *expression.ScalarFunction) ([]*point, bool)
if curPos > 0 {
curPos++
}
return rangePoints[:curPos], hasNull
rangePoints = rangePoints[:curPos]
cutPrefixForPoints(rangePoints, prefixLen, ft)
return rangePoints, hasNull
}

func (r *builder) newBuildFromPatternLike(expr *expression.ScalarFunction) []*point {
func (r *builder) newBuildFromPatternLike(expr *expression.ScalarFunction, prefixLen int) []*point {
_, collation := expr.CharsetAndCollation()
if !collate.CompatibleCollate(expr.GetArgs()[0].GetType().GetCollate(), collation) {
return getFullRange()
Expand All @@ -650,7 +663,8 @@ func (r *builder) newBuildFromPatternLike(expr *expression.ScalarFunction) []*po
if pattern == "" {
startPoint := &point{value: types.NewStringDatum(""), start: true}
endPoint := &point{value: types.NewStringDatum("")}
return []*point{startPoint, endPoint}
res := []*point{startPoint, endPoint}
return res
}
lowValue := make([]byte, 0, len(pattern))
edt, err := expr.GetArgs()[2].(*expression.Constant).Eval(chunk.Row{})
Expand Down Expand Up @@ -690,10 +704,15 @@ func (r *builder) newBuildFromPatternLike(expr *expression.ScalarFunction) []*po
}
if isExactMatch {
val := types.NewCollationStringDatum(string(lowValue), tpOfPattern.GetCollate())
return []*point{{value: val, start: true}, {value: val}}
startPoint := &point{value: val, start: true}
endPoint := &point{value: val}
res := []*point{startPoint, endPoint}
cutPrefixForPoints(res, prefixLen, tpOfPattern)
return res
}
startPoint := &point{start: true, excl: exclude}
startPoint.value.SetBytesAsString(lowValue, tpOfPattern.GetCollate(), uint32(tpOfPattern.GetFlen()))
cutPrefixForPoints([]*point{startPoint}, prefixLen, tpOfPattern)
highValue := make([]byte, len(lowValue))
copy(highValue, lowValue)
endPoint := &point{excl: true}
Expand All @@ -714,7 +733,7 @@ func (r *builder) newBuildFromPatternLike(expr *expression.ScalarFunction) []*po
return []*point{startPoint, endPoint}
}

func (r *builder) buildFromNot(expr *expression.ScalarFunction) []*point {
func (r *builder) buildFromNot(expr *expression.ScalarFunction, prefixLen int) []*point {
switch n := expr.FuncName.L; n {
case ast.IsTruthWithoutNull:
return r.buildFromIsTrue(expr, 1, false)
Expand All @@ -727,7 +746,7 @@ func (r *builder) buildFromNot(expr *expression.ScalarFunction) []*point {
isUnsignedIntCol bool
nonNegativePos int
)
rangePoints, hasNull := r.buildFromIn(expr)
rangePoints, hasNull := r.buildFromIn(expr, types.UnspecifiedLength)
if hasNull {
return nil
}
Expand All @@ -753,6 +772,7 @@ func (r *builder) buildFromNot(expr *expression.ScalarFunction) []*point {
// Append the interval (last element, max value].
retRangePoints = append(retRangePoints, &point{value: previousValue, start: true, excl: true})
retRangePoints = append(retRangePoints, &point{value: types.MaxValueDatum()})
cutPrefixForPoints(retRangePoints, prefixLen, expr.GetArgs()[0].GetType())
return retRangePoints
case ast.Like:
// Pattern not like is not supported.
Expand All @@ -769,31 +789,31 @@ func (r *builder) buildFromNot(expr *expression.ScalarFunction) []*point {
return getFullRange()
}

func (r *builder) buildFromScalarFunc(expr *expression.ScalarFunction, collator collate.Collator) []*point {
func (r *builder) buildFromScalarFunc(expr *expression.ScalarFunction, collator collate.Collator, prefixLen int) []*point {
switch op := expr.FuncName.L; op {
case ast.GE, ast.GT, ast.LT, ast.LE, ast.EQ, ast.NE, ast.NullEQ:
return r.buildFromBinOp(expr)
return r.buildFromBinOp(expr, prefixLen)
case ast.LogicAnd:
return r.intersection(r.build(expr.GetArgs()[0], collator), r.build(expr.GetArgs()[1], collator), collator)
return r.intersection(r.build(expr.GetArgs()[0], collator, prefixLen), r.build(expr.GetArgs()[1], collator, prefixLen), collator)
case ast.LogicOr:
return r.union(r.build(expr.GetArgs()[0], collator), r.build(expr.GetArgs()[1], collator), collator)
return r.union(r.build(expr.GetArgs()[0], collator, prefixLen), r.build(expr.GetArgs()[1], collator, prefixLen), collator)
case ast.IsTruthWithoutNull:
return r.buildFromIsTrue(expr, 0, false)
case ast.IsTruthWithNull:
return r.buildFromIsTrue(expr, 0, true)
case ast.IsFalsity:
return r.buildFromIsFalse(expr, 0)
case ast.In:
retPoints, _ := r.buildFromIn(expr)
retPoints, _ := r.buildFromIn(expr, prefixLen)
return retPoints
case ast.Like:
return r.newBuildFromPatternLike(expr)
return r.newBuildFromPatternLike(expr, prefixLen)
case ast.IsNull:
startPoint := &point{start: true}
endPoint := &point{}
return []*point{startPoint, endPoint}
case ast.UnaryNot:
return r.buildFromNot(expr.GetArgs()[0].(*expression.ScalarFunction))
return r.buildFromNot(expr.GetArgs()[0].(*expression.ScalarFunction), prefixLen)
}

return nil
Expand All @@ -811,7 +831,7 @@ func (r *builder) mergeSorted(a, b []*point, collator collate.Collator) []*point
ret := make([]*point, 0, len(a)+len(b))
i, j := 0, 0
for i < len(a) && j < len(b) {
less, err := rangePointLess(r.sc, a[i], b[j], collator)
less, err := rangePointLess(r.ctx.GetSessionVars().StmtCtx, a[i], b[j], collator)
if err != nil {
r.err = err
return nil
Expand Down
Loading

0 comments on commit ae33dd5

Please sign in to comment.