Skip to content

Commit

Permalink
expression: refactor cache logic for builtinRegexpSubstrFuncSig (#4…
Browse files Browse the repository at this point in the history
  • Loading branch information
lcwangchao authored Dec 20, 2023
1 parent 91fad01 commit 97acf71
Show file tree
Hide file tree
Showing 10 changed files with 419 additions and 52 deletions.
45 changes: 45 additions & 0 deletions pkg/expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"slices"
"strings"
"sync"
"sync/atomic"
"unsafe"

"github.com/gogo/protobuf/proto"
Expand All @@ -40,6 +41,7 @@ import (
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/collate"
"github.com/pingcap/tidb/pkg/util/intest"
"github.com/pingcap/tidb/pkg/util/set"
"github.com/pingcap/tipb/go-tipb"
)
Expand Down Expand Up @@ -1017,3 +1019,46 @@ func (b *baseBuiltinFunc) MemoryUsage() (sum int64) {
}
return
}

type builtinFuncCacheItem[T any] struct {
ctxID uint64
item T
}

type builtinFuncCache[T any] struct {
sync.Mutex
cached atomic.Pointer[builtinFuncCacheItem[T]]
}

func (c *builtinFuncCache[T]) getCache(ctxID uint64) (v T, ok bool) {
if p := c.cached.Load(); p != nil && p.ctxID == ctxID {
return p.item, true
}
return v, false
}

func (c *builtinFuncCache[T]) getOrInitCache(ctx EvalContext, constructCache func() (T, error)) (T, error) {
intest.Assert(constructCache != nil)
ctxID := ctx.GetSessionVars().StmtCtx.CtxID()
if item, ok := c.getCache(ctxID); ok {
return item, nil
}

c.Lock()
defer c.Unlock()
if item, ok := c.getCache(ctxID); ok {
return item, nil
}

item, err := constructCache()
if err != nil {
var def T
return def, err
}

c.cached.Store(&builtinFuncCacheItem[T]{
ctxID: ctxID,
item: item,
})
return item, nil
}
158 changes: 112 additions & 46 deletions pkg/expression/builtin_regexp.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,99 @@ var validMatchType = set.NewStringSet(
flagS, // The . character matches line terminators
)

type regexpNewBaseFuncSig struct {
baseBuiltinFunc
memorizedRegexp builtinFuncCache[regexpMemorizedSig]
}

// check binary collation, not xxx_bin collation!
func (re *regexpNewBaseFuncSig) isBinaryCollation() bool {
return re.collation == charset.CollationBin && re.charset == charset.CharsetBin
}

func (re *regexpNewBaseFuncSig) clone() *regexpNewBaseFuncSig {
newSig := &regexpNewBaseFuncSig{}
newSig.cloneFrom(&re.baseBuiltinFunc)
return newSig
}

// we can memorize the regexp when:
// 1. pattern and match type are constant
// 2. pattern is const and there is no match type argument
//
// return true: need, false: needless
func (re *regexpNewBaseFuncSig) canMemorizeRegexp(matchTypeIdx int) bool {
// If the pattern and match type are both constants, we can cache the regexp into memory.
// Notice that the above two arguments are not required to be constant across contexts because the cache is only
// valid when the two context ids are the same.
return re.args[patternIdx].ConstItem(false) &&
(len(re.args) <= matchTypeIdx || re.args[matchTypeIdx].ConstItem(false))
}

// buildRegexp builds a new `*regexp.Regexp` from the pattern and matchType
func (re *regexpNewBaseFuncSig) buildRegexp(pattern string, matchType string) (reg *regexp.Regexp, err error) {
matchType, err = getRegexpMatchType(matchType, re.collation)
if err != nil {
return nil, err
}

if len(matchType) == 0 {
reg, err = regexp.Compile(pattern)
} else {
reg, err = regexp.Compile(fmt.Sprintf("(?%s)%s", matchType, pattern))
}

if err != nil {
return nil, ErrRegexp.GenWithStackByArgs(err)
}

return reg, nil
}

// getRegexp returns the Regexp which can be used by the current function.
// If the pattern and matchType arguments are both constant, the `*regexp.Regexp` object will be cached in memory.
// The next call of `getRegexp` will return the cached regexp if it is present and the context id is equal
func (re *regexpNewBaseFuncSig) getRegexp(ctx EvalContext, pattern string, matchType string, matchTypeIdx int) (*regexp.Regexp, error) {
if !re.canMemorizeRegexp(matchTypeIdx) {
return re.buildRegexp(pattern, matchType)
}

sig, err := re.memorizedRegexp.getOrInitCache(ctx, func() (ret regexpMemorizedSig, err error) {
ret.memorizedRegexp, ret.memorizedErr = re.buildRegexp(pattern, matchType)
return
})

if err != nil {
return nil, err
}

return sig.memorizedRegexp, sig.memorizedErr
}

func (re *regexpNewBaseFuncSig) tryVecMemorizedRegexp(ctx EvalContext, params []*funcParam, matchTypeIdx int, nRows int) (*regexp.Regexp, bool, error) {
// Check memorization
if nRows == 0 || !re.canMemorizeRegexp(matchTypeIdx) {
return nil, false, nil
}

pattern := params[patternIdx].getStringVal(0)
if len(pattern) == 0 {
return nil, false, ErrRegexp.GenWithStackByArgs(emptyPatternErr)
}

matchType := params[matchTypeIdx].getStringVal(0)
sig, err := re.memorizedRegexp.getOrInitCache(ctx, func() (ret regexpMemorizedSig, err error) {
ret.memorizedRegexp, ret.memorizedErr = re.buildRegexp(pattern, matchType)
return
})

if err != nil {
return nil, false, err
}

return sig.memorizedRegexp, true, sig.memorizedErr
}

type regexpBaseFuncSig struct {
baseBuiltinFunc
regexpMemorizedSig
Expand All @@ -89,11 +182,11 @@ func (re *regexpBaseFuncSig) clone() *regexpBaseFuncSig {

// If characters specifying contradictory options are specified
// within match_type, the rightmost one takes precedence.
func (re *regexpBaseFuncSig) getMatchType(userInputMatchType string) (string, error) {
func getRegexpMatchType(userInputMatchType string, collation string) (string, error) {
flag := ""
matchTypeSet := set.NewStringSet()

if collate.IsCICollation(re.baseBuiltinFunc.collation) {
if collate.IsCICollation(collation) {
matchTypeSet.Insert(flagI)
}

Expand Down Expand Up @@ -126,7 +219,7 @@ func (re *regexpBaseFuncSig) getMatchType(userInputMatchType string) (string, er

// To get a unified compile interface in initMemoizedRegexp, we need to process many things in genCompile
func (re *regexpBaseFuncSig) genCompile(matchType string) (func(string) (*regexp.Regexp, error), error) {
matchType, err := re.getMatchType(matchType)
matchType, err := getRegexpMatchType(matchType, re.baseBuiltinFunc.collation)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -389,7 +482,7 @@ func (c *regexpSubstrFunctionClass) getFunction(ctx sessionctx.Context, args []E
argType := args[0].GetType()
bf.tp.SetFlen(argType.GetFlen())
sig := builtinRegexpSubstrFuncSig{
regexpBaseFuncSig: regexpBaseFuncSig{baseBuiltinFunc: bf},
regexpNewBaseFuncSig: regexpNewBaseFuncSig{baseBuiltinFunc: bf},
}
sig.setPbCode(tipb.ScalarFuncSig_RegexpSubstrSig)

Expand All @@ -401,7 +494,7 @@ func (c *regexpSubstrFunctionClass) getFunction(ctx sessionctx.Context, args []E
}

type builtinRegexpSubstrFuncSig struct {
regexpBaseFuncSig
regexpNewBaseFuncSig
}

func (re *builtinRegexpSubstrFuncSig) vectorized() bool {
Expand All @@ -410,7 +503,7 @@ func (re *builtinRegexpSubstrFuncSig) vectorized() bool {

func (re *builtinRegexpSubstrFuncSig) Clone() builtinFunc {
newSig := &builtinRegexpSubstrFuncSig{}
newSig.regexpBaseFuncSig = *re.regexpBaseFuncSig.clone()
newSig.regexpNewBaseFuncSig = *re.regexpNewBaseFuncSig.clone()
return newSig
}

Expand Down Expand Up @@ -499,44 +592,15 @@ func (re *builtinRegexpSubstrFuncSig) evalString(ctx EvalContext, row chunk.Row)
}
}

memorize := func() {
compile, err := re.genCompile(matchType)
if err != nil {
re.memorizedErr = err
return
}
re.memorize(compile, pat)
}

if re.canMemorize(ctx, regexpSubstrMatchTypeIdx) {
re.once.Do(memorize) // Avoid data race
}

if !re.isMemorizedRegexpInitialized() {
compile, err := re.genCompile(matchType)
if err != nil {
return "", true, ErrRegexp.GenWithStackByArgs(err)
}
reg, err := compile(pat)
if err != nil {
return "", true, ErrRegexp.GenWithStackByArgs(err)
}

if re.isBinaryCollation() {
return re.findBinString(reg, bexpr, occurrence)
}
return re.findString(reg, expr, occurrence)
}

if re.memorizedErr != nil {
return "", true, ErrRegexp.GenWithStackByArgs(re.memorizedErr)
reg, err := re.getRegexp(ctx, pat, matchType, regexpSubstrMatchTypeIdx)
if err != nil {
return "", true, err
}

if re.isBinaryCollation() {
return re.findBinString(re.memorizedRegexp, bexpr, occurrence)
return re.findBinString(reg, bexpr, occurrence)
}

return re.findString(re.memorizedRegexp, expr, occurrence)
return re.findString(reg, expr, occurrence)
}

// REGEXP_SUBSTR(expr, pat[, pos[, occurrence[, match_type]]])
Expand Down Expand Up @@ -599,7 +663,7 @@ func (re *builtinRegexpSubstrFuncSig) vecEvalString(ctx EvalContext, input *chun
}

// Check memorization
err = re.tryToMemorize(ctx, params, regexpSubstrMatchTypeIdx, n)
reg, memorized, err := re.tryVecMemorizedRegexp(ctx, params, regexpSubstrMatchTypeIdx, n)
if err != nil {
return err
}
Expand Down Expand Up @@ -647,11 +711,13 @@ func (re *builtinRegexpSubstrFuncSig) vecEvalString(ctx EvalContext, input *chun
occurrence = 1
}

// Get match type and generate regexp
matchType := params[4].getStringVal(i)
reg, err := re.genRegexp(params[1].getStringVal(i), matchType)
if err != nil {
return err
if !memorized {
// Get pattern and match type and then generate regexp
pattern := params[1].getStringVal(i)
matchType := params[4].getStringVal(i)
if reg, err = re.buildRegexp(pattern, matchType); err != nil {
return err
}
}

// Find string
Expand Down
77 changes: 77 additions & 0 deletions pkg/expression/builtin_regexp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1224,3 +1224,80 @@ func TestRegexpReplaceVec(t *testing.T) {

testVectorizedBuiltinFunc(t, vecBuiltinRegexpReplaceCases)
}

func TestRegexpCache(t *testing.T) {
ctx := createContext(t)

// if the pattern or match type is not constant, it should not be cached
sig := regexpNewBaseFuncSig{}
sig.args = []Expression{&Column{}, &Column{}, &Constant{}}
reg, err := sig.getRegexp(ctx, "abc", "", 2)
require.NoError(t, err)
require.Equal(t, "abc", reg.String())

reg, err = sig.getRegexp(ctx, "def", "", 2)
require.NoError(t, err)
require.Equal(t, "def", reg.String())

reg, ok, err := sig.tryVecMemorizedRegexp(ctx, []*funcParam{
{defaultStrVal: "x"},
{defaultStrVal: "aaa"},
{defaultStrVal: ""},
}, 2, 1)
require.Nil(t, reg)
require.False(t, ok)
require.NoError(t, err)

_, ok = sig.memorizedRegexp.getCache(ctx.GetSessionVars().StmtCtx.CtxID())
require.False(t, ok)

sig.args = []Expression{&Column{}, &Constant{}, &Column{}}
reg, err = sig.getRegexp(ctx, "bbb", "", 2)
require.NoError(t, err)
require.Equal(t, "bbb", reg.String())

reg, ok, err = sig.tryVecMemorizedRegexp(ctx, []*funcParam{
{defaultStrVal: "x"},
{defaultStrVal: "aaa"},
{defaultStrVal: ""},
}, 2, 1)
require.Nil(t, reg)
require.False(t, ok)
require.NoError(t, err)

_, ok = sig.memorizedRegexp.getCache(ctx.GetSessionVars().StmtCtx.CtxID())
require.False(t, ok)

// if pattern and match type are both constant, it should be cached
sig = regexpNewBaseFuncSig{}
sig.args = []Expression{&Column{}, &Constant{ParamMarker: &ParamMarker{}}, &Constant{ParamMarker: &ParamMarker{}}}
reg, err = sig.getRegexp(ctx, "ccc", "", 2)
require.NoError(t, err)
require.Equal(t, "ccc", reg.String())

reg2, err := sig.getRegexp(ctx, "ddd", "", 2)
require.NoError(t, err)
require.Same(t, reg, reg2)
require.Equal(t, "ccc", reg2.String())

sig = regexpNewBaseFuncSig{}
sig.args = []Expression{&Column{}, &Constant{ParamMarker: &ParamMarker{}}, &Constant{ParamMarker: &ParamMarker{}}}
reg, ok, err = sig.tryVecMemorizedRegexp(ctx, []*funcParam{
{defaultStrVal: "x"},
{defaultStrVal: "ddd"},
{defaultStrVal: ""},
}, 2, 1)
require.Equal(t, "ddd", reg.String())
require.True(t, ok)
require.NoError(t, err)

reg2, ok, err = sig.tryVecMemorizedRegexp(ctx, []*funcParam{
{defaultStrVal: "x"},
{defaultStrVal: "eee"},
{defaultStrVal: ""},
}, 2, 1)
require.Same(t, reg, reg2)
require.Equal(t, "ddd", reg2.String())
require.True(t, ok)
require.NoError(t, err)
}
Loading

0 comments on commit 97acf71

Please sign in to comment.