Skip to content

Commit

Permalink
expression: ConstItem => ConstLevel to provide more clear semantics (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
lcwangchao authored Dec 28, 2023
1 parent 7c69949 commit 87f8355
Show file tree
Hide file tree
Showing 22 changed files with 106 additions and 78 deletions.
2 changes: 1 addition & 1 deletion pkg/expression/aggregation/base_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (a *baseFuncDesc) typeInfer4ApproxPercentile(ctx sessionctx.Context) error
return errors.New("APPROX_PERCENTILE should take 2 arguments")
}

if !a.Args[1].ConstItem(false) {
if a.Args[1].ConstLevel() == expression.ConstNone {
return errors.New("APPROX_PERCENTILE should take a constant expression as percentage argument")
}
percent, isNull, err := a.Args[1].EvalInt(ctx, chunk.Row{})
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -2233,7 +2233,7 @@ func WrapWithCastAsDecimal(ctx sessionctx.Context, expr Expression) Expression {
tp.AddFlag(expr.GetType().GetFlag() & (mysql.UnsignedFlag | mysql.NotNullFlag))
castExpr := BuildCastFunction(ctx, expr, tp)
// For const item, we can use find-grained precision and scale by the result.
if castExpr.ConstItem(true) {
if castExpr.ConstLevel() == ConstStrict {
val, isnull, err := castExpr.EvalDecimal(ctx, chunk.Row{})
if !isnull && err == nil {
precision, frac := val.PrecisionAndFrac()
Expand Down
8 changes: 4 additions & 4 deletions pkg/expression/builtin_encryption_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func (b *builtinAesDecryptSig) vecEvalString(ctx EvalContext, input *chunk.Chunk
}

isWarning := !b.ivRequired && len(b.args) == 3
isConstKey := b.args[1].ConstItem(false)
isConstKey := b.args[1].ConstLevel() >= ConstOnlyInContext

var key []byte
if isConstKey {
Expand Down Expand Up @@ -158,7 +158,7 @@ func (b *builtinAesEncryptIVSig) vecEvalString(ctx EvalContext, input *chunk.Chu
return errors.Errorf("unsupported block encryption mode - %v", b.modeName)
}

isConst := b.args[1].ConstItem(false)
isConst := b.args[1].ConstLevel() >= ConstOnlyInContext
var key []byte
if isConst {
key = encrypt.DeriveKeyMySQL(keyBuf.GetBytes(0), b.keySize)
Expand Down Expand Up @@ -331,7 +331,7 @@ func (b *builtinAesDecryptIVSig) vecEvalString(ctx EvalContext, input *chunk.Chu
return errors.Errorf("unsupported block encryption mode - %v", b.modeName)
}

isConst := b.args[1].ConstItem(false)
isConst := b.args[1].ConstLevel() >= ConstOnlyInContext
var key []byte
if isConst {
key = encrypt.DeriveKeyMySQL(keyBuf.GetBytes(0), b.keySize)
Expand Down Expand Up @@ -672,7 +672,7 @@ func (b *builtinAesEncryptSig) vecEvalString(ctx EvalContext, input *chunk.Chunk
}

isWarning := !b.ivRequired && len(b.args) == 3
isConst := b.args[1].ConstItem(false)
isConst := b.args[1].ConstLevel() >= ConstOnlyInContext
var key []byte
if isConst {
key = encrypt.DeriveKeyMySQL(keyBuf.GetBytes(0), b.keySize)
Expand Down
4 changes: 2 additions & 2 deletions pkg/expression/builtin_func_param.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func buildStringParam(ctx EvalContext, bf *baseBuiltinFunc, idx int, input *chun

// Check if this is a const value.
// funcParam will not be shared between evaluations, so we just need it to be const in one ctx.
if bf.args[idx].ConstItem(false) {
if bf.args[idx].ConstLevel() >= ConstOnlyInContext {
// Initialize the const
var isConstNull bool
pa.defaultStrVal, isConstNull, err = bf.args[idx].EvalString(ctx, chunk.Row{})
Expand Down Expand Up @@ -113,7 +113,7 @@ func buildIntParam(ctx EvalContext, bf *baseBuiltinFunc, idx int, input *chunk.C

// Check if this is a const value
// funcParam will not be shared between evaluations, so we just need it to be const in one ctx.
if bf.args[idx].ConstItem(false) {
if bf.args[idx].ConstLevel() >= ConstOnlyInContext {
// Initialize the const
var isConstNull bool
pa.defaultIntVal, isConstNull, err = bf.args[idx].EvalInt(ctx, chunk.Row{})
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/builtin_ilike.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func (b *builtinIlikeSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool,
patternStr = string(patternStrBytes)

var pattern collate.WildcardPattern
if b.args[1].ConstItem(false) && b.args[2].ConstItem(false) {
if b.args[1].ConstLevel() >= ConstOnlyInContext && b.args[2].ConstLevel() >= ConstOnlyInContext {
pattern, err = b.patternCache.getOrInitCache(ctx, func() (collate.WildcardPattern, error) {
ret := collate.ConvertAndGetBinCollation(b.collation).Pattern()
ret.Compile(patternStr, byte(escape))
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/builtin_ilike_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (b *builtinIlikeSig) getEscape(ctx EvalContext, input *chunk.Chunk, result
rowNum := input.NumRows()
escape := int64('\\')

if !b.args[2].ConstItem(true) {
if b.args[2].ConstLevel() != ConstStrict {
return escape, true, errors.Errorf("escape should be const")
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/builtin_like.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (b *builtinLikeSig) evalInt(ctx EvalContext, row chunk.Row) (int64, bool, e
return 0, isNull, err
}
var pattern collate.WildcardPattern
if b.args[1].ConstItem(false) && b.args[2].ConstItem(false) {
if b.args[1].ConstLevel() >= ConstOnlyInContext && b.args[2].ConstLevel() >= ConstOnlyInContext {
pattern, err = b.patternCache.getOrInitCache(ctx, func() (collate.WildcardPattern, error) {
ret := b.collator().Pattern()
ret.Compile(patternStr, byte(escape))
Expand Down
12 changes: 6 additions & 6 deletions pkg/expression/builtin_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func (b *builtinInIntSig) buildHashMapForConstArgs(ctx sessionctx.Context) error
b.nonConstArgsIdx = make([]int, 0)
b.hashSet = make(map[int64]bool, len(b.args)-1)
for i := 1; i < len(b.args); i++ {
if b.args[i].ConstItem(true) {
if b.args[i].ConstLevel() == ConstStrict {
val, isNull, err := b.args[i].EvalInt(ctx, chunk.Row{})
if err != nil {
return err
Expand Down Expand Up @@ -290,7 +290,7 @@ func (b *builtinInStringSig) buildHashMapForConstArgs(ctx sessionctx.Context) er
b.hashSet = set.NewStringSet()
collator := collate.GetCollator(b.collation)
for i := 1; i < len(b.args); i++ {
if b.args[i].ConstItem(true) {
if b.args[i].ConstLevel() == ConstStrict {
val, isNull, err := b.args[i].EvalString(ctx, chunk.Row{})
if err != nil {
return err
Expand Down Expand Up @@ -363,7 +363,7 @@ func (b *builtinInRealSig) buildHashMapForConstArgs(ctx sessionctx.Context) erro
b.nonConstArgsIdx = make([]int, 0)
b.hashSet = set.NewFloat64Set()
for i := 1; i < len(b.args); i++ {
if b.args[i].ConstItem(true) {
if b.args[i].ConstLevel() == ConstStrict {
val, isNull, err := b.args[i].EvalReal(ctx, chunk.Row{})
if err != nil {
return err
Expand Down Expand Up @@ -434,7 +434,7 @@ func (b *builtinInDecimalSig) buildHashMapForConstArgs(ctx sessionctx.Context) e
b.nonConstArgsIdx = make([]int, 0)
b.hashSet = set.NewStringSet()
for i := 1; i < len(b.args); i++ {
if b.args[i].ConstItem(true) {
if b.args[i].ConstLevel() == ConstStrict {
val, isNull, err := b.args[i].EvalDecimal(ctx, chunk.Row{})
if err != nil {
return err
Expand Down Expand Up @@ -514,7 +514,7 @@ func (b *builtinInTimeSig) buildHashMapForConstArgs(ctx sessionctx.Context) erro
b.nonConstArgsIdx = make([]int, 0)
b.hashSet = make(map[types.CoreTime]struct{}, len(b.args)-1)
for i := 1; i < len(b.args); i++ {
if b.args[i].ConstItem(true) {
if b.args[i].ConstLevel() == ConstStrict {
val, isNull, err := b.args[i].EvalTime(ctx, chunk.Row{})
if err != nil {
return err
Expand Down Expand Up @@ -585,7 +585,7 @@ func (b *builtinInDurationSig) buildHashMapForConstArgs(ctx sessionctx.Context)
b.nonConstArgsIdx = make([]int, 0)
b.hashSet = make(map[time.Duration]struct{}, len(b.args)-1)
for i := 1; i < len(b.args); i++ {
if b.args[i].ConstItem(true) {
if b.args[i].ConstLevel() == ConstStrict {
val, isNull, err := b.args[i].EvalDuration(ctx, chunk.Row{})
if err != nil {
return err
Expand Down
6 changes: 3 additions & 3 deletions pkg/expression/builtin_regexp.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ func (re *regexpBaseFuncSig) canMemorizeRegexp(matchTypeIdx int) bool {
// If the pattern and match type are both constants, we can cache the regexp into memory.
// Notice that the above two arguments are not required to be constant across contexts because the cache is only
// valid when the two context ids are the same.
return re.args[patternIdx].ConstItem(false) &&
(len(re.args) <= matchTypeIdx || re.args[matchTypeIdx].ConstItem(false))
return re.args[patternIdx].ConstLevel() >= ConstOnlyInContext &&
(len(re.args) <= matchTypeIdx || re.args[matchTypeIdx].ConstLevel() >= ConstOnlyInContext)
}

// buildRegexp builds a new `*regexp.Regexp` from the pattern and matchType
Expand Down Expand Up @@ -1156,7 +1156,7 @@ func getInstructions(repl []byte) []Instruction {
}

func (re *builtinRegexpReplaceFuncSig) canInstructionsMemorized() bool {
return re.args[replacementIdx].ConstItem(false)
return re.args[replacementIdx].ConstLevel() >= ConstOnlyInContext
}

func (re *builtinRegexpReplaceFuncSig) getInstructions(ctx EvalContext, repl string) ([]Instruction, error) {
Expand Down
12 changes: 6 additions & 6 deletions pkg/expression/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,9 @@ func (col *CorrelatedColumn) IsCorrelated() bool {
return true
}

// ConstItem implements Expression interface.
func (col *CorrelatedColumn) ConstItem(_ bool) bool {
return false
// ConstLevel returns the const level for the expression
func (col *CorrelatedColumn) ConstLevel() ConstLevel {
return ConstNone
}

// Decorrelate implements Expression interface.
Expand Down Expand Up @@ -514,9 +514,9 @@ func (col *Column) IsCorrelated() bool {
return false
}

// ConstItem implements Expression interface.
func (col *Column) ConstItem(_ bool) bool {
return false
// ConstLevel returns the const level for the expression
func (col *Column) ConstLevel() ConstLevel {
return ConstNone
}

// Decorrelate implements Expression interface.
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/column_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func TestColumn(t *testing.T) {
require.True(t, corCol.EqualColumn(corCol))
require.False(t, corCol.EqualColumn(invalidCorCol))
require.True(t, corCol.IsCorrelated())
require.False(t, corCol.ConstItem(false))
require.Equal(t, ConstNone, corCol.ConstLevel())
require.True(t, col.EqualColumn(corCol.Decorrelate(schema)))
require.True(t, invalidCorCol.EqualColumn(invalidCorCol.Decorrelate(schema)))

Expand Down
9 changes: 6 additions & 3 deletions pkg/expression/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -421,9 +421,12 @@ func (c *Constant) IsCorrelated() bool {
return false
}

// ConstItem implements Expression interface.
func (c *Constant) ConstItem(acrossCtx bool) bool {
return !acrossCtx || (c.DeferredExpr == nil && c.ParamMarker == nil)
// ConstLevel returns the const level for the expression
func (c *Constant) ConstLevel() ConstLevel {
if c.DeferredExpr != nil || c.ParamMarker != nil {
return ConstOnlyInContext
}
return ConstStrict
}

// Decorrelate implements Expression interface.
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/constant_propagation.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (s *basePropConstSolver) insertCol(col *Column) {
// tryToUpdateEQList tries to update the eqList. When the eqList has store this column with a different constant, like
// a = 1 and a = 2, we set the second return value to false.
func (s *basePropConstSolver) tryToUpdateEQList(col *Column, con *Constant) (bool, bool) {
if con.ConstItem(s.ctx.GetSessionVars().StmtCtx.UseCache) && con.Value.IsNull() {
if con.Value.IsNull() && ConstExprConsiderPlanCache(con, s.ctx.GetSessionVars().StmtCtx.UseCache) {
return false, true
}
id := s.getColID(col)
Expand Down
31 changes: 20 additions & 11 deletions pkg/expression/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,24 @@ type TraverseAction interface {
Transform(Expression) Expression
}

// ConstLevel indicates the const level for an expression
type ConstLevel uint

const (
// ConstNone indicates the expression is not a constant expression.
// The evaluation result may be different for different input rows.
// e.g. `col_a * 2`, `substring(col_b, 5, 3)`.
ConstNone ConstLevel = iota
// ConstOnlyInContext indicates the expression is only a constant for a same context.
// This is mainly for Plan Cache, e.g. `prepare st from 'select * from t where a<1+?'`, where
// the value of `?` may change between different Contexts (executions).
ConstOnlyInContext
// ConstStrict indicates the expression is a constant expression.
// The evaluation result is always the same no matter the input context or rows.
// e.g. `1 + 2`, `substring("TiDB SQL Tutorial", 5, 3) + 'abcde'`
ConstStrict
)

// Expression represents all scalar expression in SQL.
type Expression interface {
fmt.Stringer
Expand Down Expand Up @@ -129,17 +147,8 @@ type Expression interface {
// IsCorrelated checks if this expression has correlated key.
IsCorrelated() bool

// ConstItem checks if this expression is constant item, regardless of query evaluation state.
// If the argument `acrossCtxs` is true,
// it will check if this expression returns a constant value even across multiple contexts.
// An expression is constant item if it:
// refers no tables.
// refers no correlated column.
// refers no subqueries that refers any tables.
// refers no non-deterministic functions.
// refers no statement parameters.
// refers no param markers when prepare plan cache is enabled.
ConstItem(acrossCtx bool) bool
// ConstLevel returns the const level of the expression.
ConstLevel() ConstLevel

// Decorrelate try to decorrelate the expression by schema.
Decorrelate(schema *Schema) Expression
Expand Down
38 changes: 15 additions & 23 deletions pkg/expression/expression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@ package expression

import (
"testing"
"time"

"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/model"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/sessionctx/stmtctx"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -106,9 +104,8 @@ func TestEvaluateExprWithNullNoChangeRetType(t *testing.T) {

func TestConstant(t *testing.T) {
ctx := createContext(t)
sc := stmtctx.NewStmtCtxWithTimeZone(time.Local)
require.False(t, NewZero().IsCorrelated())
require.True(t, NewZero().ConstItem(sc.UseCache))
require.Equal(t, ConstStrict, NewZero().ConstLevel())
require.True(t, NewZero().Decorrelate(nil).Equal(ctx, NewZero()))
require.Equal(t, []byte{0x0, 0x8, 0x0}, NewZero().HashCode())
require.False(t, NewZero().Equal(ctx, NewOne()))
Expand All @@ -133,30 +130,25 @@ func TestIsBinaryLiteral(t *testing.T) {
require.False(t, IsBinaryLiteral(col))
}

func TestConstItem(t *testing.T) {
const noConst int = 0
const constInCtx int = 1
const constStrict int = 2

func TestConstLevel(t *testing.T) {
ctxConst := NewZero()
ctxConst.DeferredExpr = newFunctionWithMockCtx(ast.UnixTimestamp)
for _, c := range []struct {
exp Expression
constItem int
exp Expression
level ConstLevel
}{
{newFunctionWithMockCtx(ast.Rand), noConst},
{newFunctionWithMockCtx(ast.UUID), noConst},
{newFunctionWithMockCtx(ast.GetParam, NewOne()), noConst},
{newFunctionWithMockCtx(ast.Abs, NewOne()), constStrict},
{newFunctionWithMockCtx(ast.Abs, newColumn(1)), noConst},
{newFunctionWithMockCtx(ast.Plus, NewOne(), NewOne()), constStrict},
{newFunctionWithMockCtx(ast.Plus, newColumn(1), NewOne()), noConst},
{newFunctionWithMockCtx(ast.Plus, NewOne(), newColumn(1)), noConst},
{newFunctionWithMockCtx(ast.Plus, NewOne(), newColumn(1)), noConst},
{newFunctionWithMockCtx(ast.Plus, NewOne(), ctxConst), constInCtx},
{newFunctionWithMockCtx(ast.Rand), ConstNone},
{newFunctionWithMockCtx(ast.UUID), ConstNone},
{newFunctionWithMockCtx(ast.GetParam, NewOne()), ConstNone},
{newFunctionWithMockCtx(ast.Abs, NewOne()), ConstStrict},
{newFunctionWithMockCtx(ast.Abs, newColumn(1)), ConstNone},
{newFunctionWithMockCtx(ast.Plus, NewOne(), NewOne()), ConstStrict},
{newFunctionWithMockCtx(ast.Plus, newColumn(1), NewOne()), ConstNone},
{newFunctionWithMockCtx(ast.Plus, NewOne(), newColumn(1)), ConstNone},
{newFunctionWithMockCtx(ast.Plus, NewOne(), newColumn(1)), ConstNone},
{newFunctionWithMockCtx(ast.Plus, NewOne(), ctxConst), ConstOnlyInContext},
} {
require.Equal(t, c.constItem >= constInCtx, c.exp.ConstItem(false), c.exp.String())
require.Equal(t, c.constItem >= constStrict, c.exp.ConstItem(true), c.exp.String())
require.Equal(t, c.level, c.exp.ConstLevel(), c.exp.String())
}
}

Expand Down
20 changes: 14 additions & 6 deletions pkg/expression/scalar_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,18 +352,26 @@ func (sf *ScalarFunction) IsCorrelated() bool {
return false
}

// ConstItem implements Expression interface.
func (sf *ScalarFunction) ConstItem(acrossCtx bool) bool {
// ConstLevel returns the const level for the expression
func (sf *ScalarFunction) ConstLevel() ConstLevel {
// Note: some unfoldable functions are deterministic, we use unFoldableFunctions here for simplification.
if _, ok := unFoldableFunctions[sf.FuncName.L]; ok {
return false
return ConstNone
}

level := ConstStrict
for _, arg := range sf.GetArgs() {
if !arg.ConstItem(acrossCtx) {
return false
argLevel := arg.ConstLevel()
if argLevel == ConstNone {
return ConstNone
}

if argLevel < level {
level = argLevel
}
}
return true

return level
}

// Decorrelate implements Expression interface.
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/scalar_function_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func TestScalarFunction(t *testing.T) {
require.NoError(t, err)
require.EqualValues(t, []byte{0x22, 0x6c, 0x74, 0x28, 0x43, 0x6f, 0x6c, 0x75, 0x6d, 0x6e, 0x23, 0x31, 0x2c, 0x20, 0x31, 0x29, 0x22}, res)
require.False(t, sf.IsCorrelated())
require.False(t, sf.ConstItem(false))
require.Equal(t, ConstNone, sf.ConstLevel())
require.True(t, sf.Decorrelate(nil).Equal(ctx, sf))
require.EqualValues(t, []byte{0x3, 0x4, 0x6c, 0x74, 0x1, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0x0, 0x5, 0xbf, 0xf0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, sf.HashCode())

Expand Down
Loading

0 comments on commit 87f8355

Please sign in to comment.