Skip to content

Commit

Permalink
expression: aggregate the collation only if the function is needed. (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
xiongjiwei authored Sep 22, 2021
1 parent 38e90ad commit b339ca2
Show file tree
Hide file tree
Showing 14 changed files with 198 additions and 102 deletions.
53 changes: 12 additions & 41 deletions expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,11 @@ func newBaseBuiltinFunc(ctx sessionctx.Context, funcName string, args []Expressi
if ctx == nil {
return baseBuiltinFunc{}, errors.New("unexpected nil session ctx")
}
if err := CheckIllegalMixCollation(funcName, args, retType); err != nil {
derivedCharset, derivedCollate, coer, err := deriveCollation(ctx, funcName, args, retType, retType)
if err != nil {
return baseBuiltinFunc{}, err
}
derivedCharset, derivedCollate := DeriveCollationFromExprs(ctx, args...)

bf := baseBuiltinFunc{
bufAllocator: newLocalColumnPool(),
childrenVectorizedOnce: new(sync.Once),
Expand All @@ -106,41 +107,10 @@ func newBaseBuiltinFunc(ctx sessionctx.Context, funcName string, args []Expressi
}
bf.SetCharsetAndCollation(derivedCharset, derivedCollate)
bf.setCollator(collate.GetCollator(derivedCollate))
bf.SetCoercibility(coer)
return bf, nil
}

var (
coerString = []string{"EXPLICIT", "NONE", "IMPLICIT", "SYSCONST", "COERCIBLE", "NUMERIC", "IGNORABLE"}
)

// CheckIllegalMixCollation checks illegal mix collation with expressions
func CheckIllegalMixCollation(funcName string, args []Expression, evalType types.EvalType) error {
if len(args) < 2 {
return nil
}
_, _, coercibility, legal := inferCollation(args...)
if !legal {
return illegalMixCollationErr(funcName, args)
}
if coercibility == CoercibilityNone && evalType != types.ETString {
return illegalMixCollationErr(funcName, args)
}
return nil
}

func illegalMixCollationErr(funcName string, args []Expression) error {
funcName = GetDisplayName(funcName)

switch len(args) {
case 2:
return collate.ErrIllegalMix2Collation.GenWithStackByArgs(args[0].GetType().Collate, coerString[args[0].Coercibility()], args[1].GetType().Collate, coerString[args[1].Coercibility()], funcName)
case 3:
return collate.ErrIllegalMix3Collation.GenWithStackByArgs(args[0].GetType().Collate, coerString[args[0].Coercibility()], args[1].GetType().Collate, coerString[args[1].Coercibility()], args[2].GetType().Collate, coerString[args[2].Coercibility()], funcName)
default:
return collate.ErrIllegalMixCollation.GenWithStackByArgs(funcName)
}
}

// newBaseBuiltinFuncWithTp creates a built-in function signature with specified types of arguments and the return type of the function.
// argTps indicates the types of the args, retType indicates the return type of the built-in function.
// Every built-in function needs determined argTps and retType when we create it.
Expand All @@ -152,6 +122,13 @@ func newBaseBuiltinFuncWithTp(ctx sessionctx.Context, funcName string, args []Ex
return baseBuiltinFunc{}, errors.New("unexpected nil session ctx")
}

// derive collation information for string function, and we must do it
// before doing implicit cast.
derivedCharset, derivedCollate, coer, err := deriveCollation(ctx, funcName, args, retType, argTps...)
if err != nil {
return
}

for i := range args {
switch argTps[i] {
case types.ETInt:
Expand All @@ -173,13 +150,6 @@ func newBaseBuiltinFuncWithTp(ctx sessionctx.Context, funcName string, args []Ex
}
}

if err = CheckIllegalMixCollation(funcName, args, retType); err != nil {
return
}

// derive collation information for string function, and we must do it
// before doing implicit cast.
derivedCharset, derivedCollate := DeriveCollationFromExprs(ctx, args...)
var fieldType *types.FieldType
switch retType {
case types.ETInt:
Expand Down Expand Up @@ -259,6 +229,7 @@ func newBaseBuiltinFuncWithTp(ctx sessionctx.Context, funcName string, args []Ex
}
bf.SetCharsetAndCollation(derivedCharset, derivedCollate)
bf.setCollator(collate.GetCollator(derivedCollate))
bf.SetCoercibility(coer)
return bf, nil
}

Expand Down
2 changes: 1 addition & 1 deletion expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -1210,7 +1210,7 @@ func GetCmpFunction(ctx sessionctx.Context, lhs, rhs Expression) CompareFunc {
case types.ETDecimal:
return CompareDecimal
case types.ETString:
_, dstCollation := DeriveCollationFromExprs(ctx, lhs, rhs)
_, dstCollation, _ := CheckAndDeriveCollationFromExprs(ctx, "", types.ETInt, lhs, rhs)
return genCompareString(dstCollation)
case types.ETDuration:
return CompareDuration
Expand Down
6 changes: 0 additions & 6 deletions expression/builtin_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ func (c *databaseFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
if err != nil {
return nil, err
}
bf.tp.Charset, bf.tp.Collate = ctx.GetSessionVars().GetCharsetInfo()
bf.tp.Flen = 64
sig := &builtinDatabaseSig{bf}
return sig, nil
Expand Down Expand Up @@ -169,7 +168,6 @@ func (c *currentUserFunctionClass) getFunction(ctx sessionctx.Context, args []Ex
if err != nil {
return nil, err
}
bf.tp.Charset, bf.tp.Collate = ctx.GetSessionVars().GetCharsetInfo()
bf.tp.Flen = 64
sig := &builtinCurrentUserSig{bf}
return sig, nil
Expand Down Expand Up @@ -207,7 +205,6 @@ func (c *currentRoleFunctionClass) getFunction(ctx sessionctx.Context, args []Ex
if err != nil {
return nil, err
}
bf.tp.Charset, bf.tp.Collate = ctx.GetSessionVars().GetCharsetInfo()
bf.tp.Flen = 64
sig := &builtinCurrentRoleSig{bf}
return sig, nil
Expand Down Expand Up @@ -259,7 +256,6 @@ func (c *userFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio
if err != nil {
return nil, err
}
bf.tp.Charset, bf.tp.Collate = ctx.GetSessionVars().GetCharsetInfo()
bf.tp.Flen = 64
sig := &builtinUserSig{bf}
return sig, nil
Expand Down Expand Up @@ -401,7 +397,6 @@ func (c *versionFunctionClass) getFunction(ctx sessionctx.Context, args []Expres
if err != nil {
return nil, err
}
bf.tp.Charset, bf.tp.Collate = ctx.GetSessionVars().GetCharsetInfo()
bf.tp.Flen = 64
sig := &builtinVersionSig{bf}
return sig, nil
Expand Down Expand Up @@ -435,7 +430,6 @@ func (c *tidbVersionFunctionClass) getFunction(ctx sessionctx.Context, args []Ex
if err != nil {
return nil, err
}
bf.tp.Charset, bf.tp.Collate = ctx.GetSessionVars().GetCharsetInfo()
bf.tp.Flen = len(printer.GetTiDBInfo())
sig := &builtinTiDBVersionSig{bf}
return sig, nil
Expand Down
132 changes: 109 additions & 23 deletions expression/collation.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tidb/util/logutil"
)

Expand Down Expand Up @@ -105,14 +106,6 @@ const (
)

var (
sysConstFuncs = map[string]struct{}{
ast.User: {},
ast.Version: {},
ast.Database: {},
ast.CurrentRole: {},
ast.CurrentUser: {},
}

// collationPriority is the priority when infer the result collation, the priority of collation a > b iff collationPriority[a] > collationPriority[b]
// collation a and b are incompatible if collationPriority[a] = collationPriority[b]
collationPriority = map[string]int{
Expand Down Expand Up @@ -152,21 +145,7 @@ var (
)

func deriveCoercibilityForScarlarFunc(sf *ScalarFunction) Coercibility {
if _, ok := sysConstFuncs[sf.FuncName.L]; ok {
return CoercibilitySysconst
}
if sf.RetType.EvalType() != types.ETString {
return CoercibilityNumeric
}

_, _, coer, _ := inferCollation(sf.GetArgs()...)

// it is weird if a ScalarFunction is CoercibilityNumeric but return string type
if coer == CoercibilityNumeric {
return CoercibilityCoercible
}

return coer
panic("this function should never be called")
}

func deriveCoercibilityForConstant(c *Constant) Coercibility {
Expand All @@ -189,12 +168,102 @@ func deriveCoercibilityForColumn(c *Column) Coercibility {
return CoercibilityImplicit
}

func deriveCollation(ctx sessionctx.Context, funcName string, args []Expression, retType types.EvalType, argTps ...types.EvalType) (dstCharset, dstCollation string, coercibility Coercibility, err error) {
switch funcName {
case ast.Concat, ast.ConcatWS, ast.Lower, ast.Reverse, ast.Upper, ast.Quote, ast.Coalesce:
return CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, retType, args...)
case ast.Left, ast.Right, ast.Repeat, ast.Trim, ast.LTrim, ast.RTrim, ast.Substr, ast.SubstringIndex, ast.Replace:
return CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, retType, args[0])
case ast.InsertFunc:
return CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, retType, args[0], args[3])
case ast.Lpad, ast.Rpad:
return CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, retType, args[0], args[2])
case ast.Elt, ast.ExportSet, ast.MakeSet:
return CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, retType, args[1:]...)
case ast.FindInSet, ast.Regexp:
return CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, types.ETInt, args...)
case ast.Field:
if argTps[0] == types.ETString {
return CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, retType, args...)
}
case ast.Locate, ast.Instr:
return CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, retType, args[0], args[1])
case ast.GE, ast.LE, ast.GT, ast.LT, ast.EQ, ast.NE, ast.NullEQ:
// if compare type is string, we should determine which collation should be used.
if argTps[0] == types.ETString {
dstCharset, dstCollation, _, err = CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, types.ETInt, args...)
return dstCharset, dstCollation, CoercibilityNumeric, err
}
case ast.If:
return CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, retType, args[1], args[2])
case ast.Like:
dstCharset, dstCollation, _, err = CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, types.ETInt, args[0], args[1])
if err != nil {
return
}
return dstCharset, dstCollation, CoercibilityCoercible, err
case ast.In:
if args[0].GetType().EvalType() == types.ETString {
return CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, types.ETInt, args...)
}
case ast.DateFormat, ast.TimeFormat:
charsetInfo, collation := ctx.GetSessionVars().GetCharsetInfo()
return charsetInfo, collation, args[1].Coercibility(), nil
case ast.Cast:
// we assume all the cast are implicit.
if retType == types.ETString {
charsetInfo, collation := ctx.GetSessionVars().GetCharsetInfo()
return charsetInfo, collation, args[0].Coercibility(), nil
}
return charset.CharsetBin, charset.CollationBin, args[0].Coercibility(), nil
case ast.Case:
// FIXME: case function aggregate collation is not correct.
return CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, retType, args...)
case ast.Database, ast.User, ast.CurrentUser, ast.Version, ast.CurrentRole, ast.TiDBVersion:
chs, coll := charset.GetDefaultCharsetAndCollate()
return chs, coll, CoercibilitySysconst, nil
}

if retType == types.ETString {
charsetInfo, collation := ctx.GetSessionVars().GetCharsetInfo()
return charsetInfo, collation, CoercibilityCoercible, nil
}
return charset.CharsetBin, charset.CollationBin, CoercibilityNumeric, nil
}

// DeriveCollationFromExprs derives collation information from these expressions.
// Deprecated, use CheckAndDeriveCollationFromExprs instead.
// TODO: remove this function after the all usage is replaced by CheckAndDeriveCollationFromExprs
func DeriveCollationFromExprs(ctx sessionctx.Context, exprs ...Expression) (dstCharset, dstCollation string) {
dstCollation, dstCharset, _, _ = inferCollation(exprs...)
return
}

// CheckAndDeriveCollationFromExprs derives collation information from these expressions, return error if derives collation error.
func CheckAndDeriveCollationFromExprs(ctx sessionctx.Context, funcName string, evalType types.EvalType, args ...Expression) (dstCharset, dstCollation string, err error) {
dstCharset, dstCollation, _, err = CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, evalType, args...)
return
}

// CheckAndDeriveCollationFromExprsWithCoer is the same with CheckAndDeriveCollationFromExprs, but also return Coercibility.
func CheckAndDeriveCollationFromExprsWithCoer(ctx sessionctx.Context, funcName string, evalType types.EvalType, args ...Expression) (dstCharset, dstCollation string, coercibility Coercibility, err error) {
coll, chs, coercibility, legal := inferCollation(args...)
if !legal {
return "", "", CoercibilityIgnorable, illegalMixCollationErr(funcName, args)
}

if evalType != types.ETString && coercibility == CoercibilityNone {
return "", "", CoercibilityIgnorable, illegalMixCollationErr(funcName, args)
}

if evalType == types.ETString && coercibility == CoercibilityNumeric {
charsetInfo, collation := ctx.GetSessionVars().GetCharsetInfo()
return charsetInfo, collation, CoercibilityCoercible, nil
}

return chs, coll, coercibility, nil
}

// inferCollation infers collation, charset, coercibility and check the legitimacy.
func inferCollation(exprs ...Expression) (dstCollation, dstCharset string, coercibility Coercibility, legal bool) {
firstExplicitCollation := ""
Expand Down Expand Up @@ -239,3 +308,20 @@ func getBinCollation(cs string) string {
// it must return something, never reachable
return charset.CollationUTF8MB4
}

var (
coerString = []string{"EXPLICIT", "NONE", "IMPLICIT", "SYSCONST", "COERCIBLE", "NUMERIC", "IGNORABLE"}
)

func illegalMixCollationErr(funcName string, args []Expression) error {
funcName = GetDisplayName(funcName)

switch len(args) {
case 2:
return collate.ErrIllegalMix2Collation.GenWithStackByArgs(args[0].GetType().Collate, coerString[args[0].Coercibility()], args[1].GetType().Collate, coerString[args[1].Coercibility()], funcName)
case 3:
return collate.ErrIllegalMix3Collation.GenWithStackByArgs(args[0].GetType().Collate, coerString[args[0].Coercibility()], args[1].GetType().Collate, coerString[args[1].Coercibility()], args[2].GetType().Collate, coerString[args[2].Coercibility()], funcName)
default:
return collate.ErrIllegalMixCollation.GenWithStackByArgs(funcName)
}
}
5 changes: 2 additions & 3 deletions expression/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -653,10 +653,9 @@ func (col *Column) ReverseEval(sc *stmtctx.StatementContext, res types.Datum, rT

// Coercibility returns the coercibility value which is used to check collations.
func (col *Column) Coercibility() Coercibility {
if col.HasCoercibility() {
return col.collationInfo.Coercibility()
if !col.HasCoercibility() {
col.SetCoercibility(deriveCoercibilityForColumn(col))
}
col.SetCoercibility(deriveCoercibilityForColumn(col))
return col.collationInfo.Coercibility()
}

Expand Down
6 changes: 2 additions & 4 deletions expression/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,10 +425,8 @@ func (c *Constant) ReverseEval(sc *stmtctx.StatementContext, res types.Datum, rT

// Coercibility returns the coercibility value which is used to check collations.
func (c *Constant) Coercibility() Coercibility {
if c.HasCoercibility() {
return c.collationInfo.Coercibility()
if !c.HasCoercibility() {
c.SetCoercibility(deriveCoercibilityForConstant(c))
}

c.SetCoercibility(deriveCoercibilityForConstant(c))
return c.collationInfo.Coercibility()
}
15 changes: 12 additions & 3 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6412,6 +6412,15 @@ func (s *testIntegrationSuite) TestCollation(c *C) {
tk.MustQuery("select collation(@test_collate_var)").Check(testkit.Rows("utf8mb4_general_ci"))
tk.MustExec("set @test_collate_var = concat(\"a\", \"b\" collate utf8mb4_bin)")
tk.MustQuery("select collation(@test_collate_var)").Check(testkit.Rows("utf8mb4_bin"))

tk.MustQuery("select locate('1', '123' collate utf8mb4_bin, 2 collate `binary`);").Check(testkit.Rows("0"))
tk.MustQuery("select 1 in ('a' collate utf8mb4_bin, 'b' collate utf8mb4_general_ci);").Check(testkit.Rows("0"))
tk.MustQuery("select left('abc' collate utf8mb4_bin, 2 collate `binary`);").Check(testkit.Rows("ab"))
tk.MustQuery("select right('abc' collate utf8mb4_bin, 2 collate `binary`);").Check(testkit.Rows("bc"))
tk.MustQuery("select repeat('abc' collate utf8mb4_bin, 2 collate `binary`);").Check(testkit.Rows("abcabc"))
tk.MustQuery("select trim(both 'abc' collate utf8mb4_bin from 'c' collate utf8mb4_general_ci);").Check(testkit.Rows("c"))
tk.MustQuery("select substr('abc' collate utf8mb4_bin, 2 collate `binary`);").Check(testkit.Rows("bc"))
tk.MustQuery("select replace('abc' collate utf8mb4_bin, 'b' collate utf8mb4_general_ci, 'd' collate utf8mb4_unicode_ci);").Check(testkit.Rows("adc"))
}

func (s *testIntegrationSuite) TestCoercibility(c *C) {
Expand Down Expand Up @@ -6558,7 +6567,7 @@ func (s *testIntegrationSerialSuite) TestCollationBasic(c *C) {
tk.MustQuery("select c from t where c = 'b';").Check(testkit.Rows("B"))
tk.MustQuery("select c from t where c = 'B';").Check(testkit.Rows("B"))

tk.MustExec("drop table if exists t")
tk.MustExec("drop table if exists t1")
tk.MustExec("CREATE TABLE `t1` (" +
" `COL1` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL," +
" PRIMARY KEY (`COL1`(5)) clustered" +
Expand Down Expand Up @@ -6829,7 +6838,7 @@ func (s *testIntegrationSerialSuite) TestMixCollation(c *C) {
tk.MustQuery("select coercibility(concat(concat(mb4unicode, mb4general), concat(unicode, general))) from t;").Check(testkit.Rows("1"))
tk.MustQuery("select collation(concat(i, 1)) from t;").Check(testkit.Rows("utf8mb4_general_ci"))
tk.MustQuery("select coercibility(concat(i, 1)) from t;").Check(testkit.Rows("4"))
tk.MustQuery("select collation(concat(i, user())) from t;").Check(testkit.Rows("utf8mb4_general_ci"))
tk.MustQuery("select collation(concat(i, user())) from t;").Check(testkit.Rows("utf8mb4_bin"))
tk.MustQuery("select coercibility(concat(i, user())) from t;").Check(testkit.Rows("3"))
tk.MustGetErrMsg("select * from t where mb4unicode = mb4general;", "[expression:1267]Illegal mix of collations (utf8mb4_unicode_ci,IMPLICIT) and (utf8mb4_general_ci,IMPLICIT) for operation '='")
tk.MustGetErrMsg("select * from t where unicode = general;", "[expression:1267]Illegal mix of collations (utf8_unicode_ci,IMPLICIT) and (utf8_general_ci,IMPLICIT) for operation '='")
Expand Down Expand Up @@ -8987,7 +8996,7 @@ func (s *testIntegrationSerialSuite) TestLikeWithCollation(c *C) {
defer collate.SetNewCollationEnabledForTest(false)

tk.MustQuery(`select 'a' like 'A' collate utf8mb4_unicode_ci;`).Check(testkit.Rows("1"))
tk.MustGetErrMsg(`select 'a' collate utf8mb4_bin like 'A' collate utf8mb4_unicode_ci;`, "[expression:1270]Illegal mix of collations (utf8mb4_bin,EXPLICIT), (utf8mb4_unicode_ci,EXPLICIT), (binary,NUMERIC) for operation 'like'")
tk.MustGetErrMsg(`select 'a' collate utf8mb4_bin like 'A' collate utf8mb4_unicode_ci;`, "[expression:1267]Illegal mix of collations (utf8mb4_bin,EXPLICIT) and (utf8mb4_unicode_ci,EXPLICIT) for operation 'like'")
tk.MustQuery(`select '😛' collate utf8mb4_general_ci like '😋';`).Check(testkit.Rows("1"))
tk.MustQuery(`select '😛' collate utf8mb4_general_ci = '😋';`).Check(testkit.Rows("1"))
tk.MustQuery(`select '😛' collate utf8mb4_unicode_ci like '😋';`).Check(testkit.Rows("0"))
Expand Down
Loading

0 comments on commit b339ca2

Please sign in to comment.