Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: aggregate the collation only if the function is needed. #27789

Merged
merged 8 commits into from
Sep 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 12 additions & 41 deletions expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,11 @@ func newBaseBuiltinFunc(ctx sessionctx.Context, funcName string, args []Expressi
if ctx == nil {
return baseBuiltinFunc{}, errors.New("unexpected nil session ctx")
}
if err := CheckIllegalMixCollation(funcName, args, retType); err != nil {
derivedCharset, derivedCollate, coer, err := deriveCollation(ctx, funcName, args, retType, retType)
if err != nil {
return baseBuiltinFunc{}, err
}
derivedCharset, derivedCollate := DeriveCollationFromExprs(ctx, args...)

bf := baseBuiltinFunc{
bufAllocator: newLocalColumnPool(),
childrenVectorizedOnce: new(sync.Once),
Expand All @@ -106,41 +107,10 @@ func newBaseBuiltinFunc(ctx sessionctx.Context, funcName string, args []Expressi
}
bf.SetCharsetAndCollation(derivedCharset, derivedCollate)
bf.setCollator(collate.GetCollator(derivedCollate))
bf.SetCoercibility(coer)
return bf, nil
}

var (
// coerString holds the display name of each Coercibility value, indexed by that value.
// illegalMixCollationErr uses it to render an argument's coercibility in error messages.
coerString = []string{"EXPLICIT", "NONE", "IMPLICIT", "SYSCONST", "COERCIBLE", "NUMERIC", "IGNORABLE"}
)

// CheckIllegalMixCollation reports an illegal-mix-of-collations error for the
// arguments of funcName, or nil when the arguments are accepted.
//
// Fewer than two arguments are accepted without inspection. Otherwise the
// arguments are rejected when inferCollation marks them as not legal, or when
// the inferred coercibility is CoercibilityNone and evalType is not
// types.ETString.
func CheckIllegalMixCollation(funcName string, args []Expression, evalType types.EvalType) error {
	if len(args) >= 2 {
		_, _, coer, ok := inferCollation(args...)
		if !ok || (coer == CoercibilityNone && evalType != types.ETString) {
			return illegalMixCollationErr(funcName, args)
		}
	}
	return nil
}

// illegalMixCollationErr builds the "illegal mix of collations" error for funcName.
// funcName is passed through GetDisplayName before being placed in the error.
// With exactly two or three arguments, the error lists each argument's collation
// together with the coerString name of its coercibility; for any other argument
// count only the function name is reported.
func illegalMixCollationErr(funcName string, args []Expression) error {
funcName = GetDisplayName(funcName)

switch len(args) {
case 2:
return collate.ErrIllegalMix2Collation.GenWithStackByArgs(args[0].GetType().Collate, coerString[args[0].Coercibility()], args[1].GetType().Collate, coerString[args[1].Coercibility()], funcName)
case 3:
return collate.ErrIllegalMix3Collation.GenWithStackByArgs(args[0].GetType().Collate, coerString[args[0].Coercibility()], args[1].GetType().Collate, coerString[args[1].Coercibility()], args[2].GetType().Collate, coerString[args[2].Coercibility()], funcName)
default:
return collate.ErrIllegalMixCollation.GenWithStackByArgs(funcName)
}
}

// newBaseBuiltinFuncWithTp creates a built-in function signature with specified types of arguments and the return type of the function.
// argTps indicates the types of the args, retType indicates the return type of the built-in function.
// Every built-in function needs determined argTps and retType when we create it.
Expand All @@ -152,6 +122,13 @@ func newBaseBuiltinFuncWithTp(ctx sessionctx.Context, funcName string, args []Ex
return baseBuiltinFunc{}, errors.New("unexpected nil session ctx")
}

// derive collation information for string function, and we must do it
// before doing implicit cast.
derivedCharset, derivedCollate, coer, err := deriveCollation(ctx, funcName, args, retType, argTps...)
if err != nil {
return
}

for i := range args {
switch argTps[i] {
case types.ETInt:
Expand All @@ -173,13 +150,6 @@ func newBaseBuiltinFuncWithTp(ctx sessionctx.Context, funcName string, args []Ex
}
}

if err = CheckIllegalMixCollation(funcName, args, retType); err != nil {
return
}

// derive collation information for string function, and we must do it
// before doing implicit cast.
derivedCharset, derivedCollate := DeriveCollationFromExprs(ctx, args...)
var fieldType *types.FieldType
switch retType {
case types.ETInt:
Expand Down Expand Up @@ -259,6 +229,7 @@ func newBaseBuiltinFuncWithTp(ctx sessionctx.Context, funcName string, args []Ex
}
bf.SetCharsetAndCollation(derivedCharset, derivedCollate)
bf.setCollator(collate.GetCollator(derivedCollate))
bf.SetCoercibility(coer)
return bf, nil
}

Expand Down
2 changes: 1 addition & 1 deletion expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -1210,7 +1210,7 @@ func GetCmpFunction(ctx sessionctx.Context, lhs, rhs Expression) CompareFunc {
case types.ETDecimal:
return CompareDecimal
case types.ETString:
_, dstCollation := DeriveCollationFromExprs(ctx, lhs, rhs)
_, dstCollation, _ := CheckAndDeriveCollationFromExprs(ctx, "", types.ETInt, lhs, rhs)
return genCompareString(dstCollation)
case types.ETDuration:
return CompareDuration
Expand Down
6 changes: 0 additions & 6 deletions expression/builtin_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ func (c *databaseFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
if err != nil {
return nil, err
}
bf.tp.Charset, bf.tp.Collate = ctx.GetSessionVars().GetCharsetInfo()
xiongjiwei marked this conversation as resolved.
Show resolved Hide resolved
bf.tp.Flen = 64
sig := &builtinDatabaseSig{bf}
return sig, nil
Expand Down Expand Up @@ -169,7 +168,6 @@ func (c *currentUserFunctionClass) getFunction(ctx sessionctx.Context, args []Ex
if err != nil {
return nil, err
}
bf.tp.Charset, bf.tp.Collate = ctx.GetSessionVars().GetCharsetInfo()
bf.tp.Flen = 64
sig := &builtinCurrentUserSig{bf}
return sig, nil
Expand Down Expand Up @@ -207,7 +205,6 @@ func (c *currentRoleFunctionClass) getFunction(ctx sessionctx.Context, args []Ex
if err != nil {
return nil, err
}
bf.tp.Charset, bf.tp.Collate = ctx.GetSessionVars().GetCharsetInfo()
bf.tp.Flen = 64
sig := &builtinCurrentRoleSig{bf}
return sig, nil
Expand Down Expand Up @@ -259,7 +256,6 @@ func (c *userFunctionClass) getFunction(ctx sessionctx.Context, args []Expressio
if err != nil {
return nil, err
}
bf.tp.Charset, bf.tp.Collate = ctx.GetSessionVars().GetCharsetInfo()
bf.tp.Flen = 64
sig := &builtinUserSig{bf}
return sig, nil
Expand Down Expand Up @@ -401,7 +397,6 @@ func (c *versionFunctionClass) getFunction(ctx sessionctx.Context, args []Expres
if err != nil {
return nil, err
}
bf.tp.Charset, bf.tp.Collate = ctx.GetSessionVars().GetCharsetInfo()
bf.tp.Flen = 64
sig := &builtinVersionSig{bf}
return sig, nil
Expand Down Expand Up @@ -435,7 +430,6 @@ func (c *tidbVersionFunctionClass) getFunction(ctx sessionctx.Context, args []Ex
if err != nil {
return nil, err
}
bf.tp.Charset, bf.tp.Collate = ctx.GetSessionVars().GetCharsetInfo()
bf.tp.Flen = len(printer.GetTiDBInfo())
sig := &builtinTiDBVersionSig{bf}
return sig, nil
Expand Down
132 changes: 109 additions & 23 deletions expression/collation.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tidb/util/logutil"
)

Expand Down Expand Up @@ -105,14 +106,6 @@ const (
)

var (
sysConstFuncs = map[string]struct{}{
ast.User: {},
ast.Version: {},
ast.Database: {},
ast.CurrentRole: {},
ast.CurrentUser: {},
}

// collationPriority is the priority used when inferring the result collation: collation a has higher priority than b iff collationPriority[a] > collationPriority[b].
// Collations a and b are incompatible if collationPriority[a] == collationPriority[b].
collationPriority = map[string]int{
Expand Down Expand Up @@ -152,21 +145,7 @@ var (
)

func deriveCoercibilityForScarlarFunc(sf *ScalarFunction) Coercibility {
if _, ok := sysConstFuncs[sf.FuncName.L]; ok {
return CoercibilitySysconst
}
if sf.RetType.EvalType() != types.ETString {
return CoercibilityNumeric
}

_, _, coer, _ := inferCollation(sf.GetArgs()...)

// it is weird if a ScalarFunction is CoercibilityNumeric but return string type
if coer == CoercibilityNumeric {
return CoercibilityCoercible
}

return coer
panic("this function should never be called")
xiongjiwei marked this conversation as resolved.
Show resolved Hide resolved
}

func deriveCoercibilityForConstant(c *Constant) Coercibility {
Expand All @@ -189,12 +168,102 @@ func deriveCoercibilityForColumn(c *Column) Coercibility {
return CoercibilityImplicit
}

func deriveCollation(ctx sessionctx.Context, funcName string, args []Expression, retType types.EvalType, argTps ...types.EvalType) (dstCharset, dstCollation string, coercibility Coercibility, err error) {
xiongjiwei marked this conversation as resolved.
Show resolved Hide resolved
switch funcName {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please write a unit test (UT) for deriveCollation(); we need more tests for it.

Copy link
Contributor Author

@xiongjiwei xiongjiwei Sep 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will add more unit tests in the next PR, because that PR will make collation aggregation behave the same as in MySQL.

case ast.Concat, ast.ConcatWS, ast.Lower, ast.Reverse, ast.Upper, ast.Quote, ast.Coalesce:
return CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, retType, args...)
case ast.Left, ast.Right, ast.Repeat, ast.Trim, ast.LTrim, ast.RTrim, ast.Substr, ast.SubstringIndex, ast.Replace:
xiongjiwei marked this conversation as resolved.
Show resolved Hide resolved
return CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, retType, args[0])
case ast.InsertFunc:
return CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, retType, args[0], args[3])
case ast.Lpad, ast.Rpad:
return CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, retType, args[0], args[2])
case ast.Elt, ast.ExportSet, ast.MakeSet:
return CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, retType, args[1:]...)
case ast.FindInSet, ast.Regexp:
return CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, types.ETInt, args...)
case ast.Field:
if argTps[0] == types.ETString {
return CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, retType, args...)
}
case ast.Locate, ast.Instr:
xiongjiwei marked this conversation as resolved.
Show resolved Hide resolved
return CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, retType, args[0], args[1])
case ast.GE, ast.LE, ast.GT, ast.LT, ast.EQ, ast.NE, ast.NullEQ:
// if compare type is string, we should determine which collation should be used.
if argTps[0] == types.ETString {
dstCharset, dstCollation, _, err = CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, types.ETInt, args...)
return dstCharset, dstCollation, CoercibilityNumeric, err
}
case ast.If:
return CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, retType, args[1], args[2])
case ast.Like:
dstCharset, dstCollation, _, err = CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, types.ETInt, args[0], args[1])
if err != nil {
return
}
return dstCharset, dstCollation, CoercibilityCoercible, err
case ast.In:
if args[0].GetType().EvalType() == types.ETString {
return CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, types.ETInt, args...)
}
case ast.DateFormat, ast.TimeFormat:
charsetInfo, collation := ctx.GetSessionVars().GetCharsetInfo()
return charsetInfo, collation, args[1].Coercibility(), nil
case ast.Cast:
// we assume all the cast are implicit.
if retType == types.ETString {
charsetInfo, collation := ctx.GetSessionVars().GetCharsetInfo()
return charsetInfo, collation, args[0].Coercibility(), nil
}
return charset.CharsetBin, charset.CollationBin, args[0].Coercibility(), nil
case ast.Case:
// FIXME: case function aggregate collation is not correct.
return CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, retType, args...)
case ast.Database, ast.User, ast.CurrentUser, ast.Version, ast.CurrentRole, ast.TiDBVersion:
chs, coll := charset.GetDefaultCharsetAndCollate()
return chs, coll, CoercibilitySysconst, nil
}

if retType == types.ETString {
charsetInfo, collation := ctx.GetSessionVars().GetCharsetInfo()
return charsetInfo, collation, CoercibilityCoercible, nil
}
return charset.CharsetBin, charset.CollationBin, CoercibilityNumeric, nil
}

// DeriveCollationFromExprs derives collation information from these expressions.
// Deprecated: use CheckAndDeriveCollationFromExprs instead.
xiongjiwei marked this conversation as resolved.
Show resolved Hide resolved
// TODO: remove this function once all usages have been replaced by CheckAndDeriveCollationFromExprs.
func DeriveCollationFromExprs(ctx sessionctx.Context, exprs ...Expression) (dstCharset, dstCollation string) {
// inferCollation returns the collation first and the charset second, so the
// results are assigned in the opposite order of this function's named returns.
// Its coercibility and legality results are discarded, and ctx is not used here.
dstCollation, dstCharset, _, _ = inferCollation(exprs...)
return
}

// CheckAndDeriveCollationFromExprs derives the charset and collation from the given expressions.
// It returns an error if the collation cannot be derived. It forwards to
// CheckAndDeriveCollationFromExprsWithCoer and drops the coercibility result.
func CheckAndDeriveCollationFromExprs(ctx sessionctx.Context, funcName string, evalType types.EvalType, args ...Expression) (dstCharset, dstCollation string, err error) {
dstCharset, dstCollation, _, err = CheckAndDeriveCollationFromExprsWithCoer(ctx, funcName, evalType, args...)
return
}

// CheckAndDeriveCollationFromExprsWithCoer is the same with CheckAndDeriveCollationFromExprs, but also return Coercibility.
func CheckAndDeriveCollationFromExprsWithCoer(ctx sessionctx.Context, funcName string, evalType types.EvalType, args ...Expression) (dstCharset, dstCollation string, coercibility Coercibility, err error) {
coll, chs, coercibility, legal := inferCollation(args...)
if !legal {
return "", "", CoercibilityIgnorable, illegalMixCollationErr(funcName, args)
}

if evalType != types.ETString && coercibility == CoercibilityNone {
xiongjiwei marked this conversation as resolved.
Show resolved Hide resolved
return "", "", CoercibilityIgnorable, illegalMixCollationErr(funcName, args)
}

if evalType == types.ETString && coercibility == CoercibilityNumeric {
charsetInfo, collation := ctx.GetSessionVars().GetCharsetInfo()
return charsetInfo, collation, CoercibilityCoercible, nil
}

return chs, coll, coercibility, nil
}

// inferCollation infers collation, charset, coercibility and check the legitimacy.
func inferCollation(exprs ...Expression) (dstCollation, dstCharset string, coercibility Coercibility, legal bool) {
firstExplicitCollation := ""
Expand Down Expand Up @@ -239,3 +308,20 @@ func getBinCollation(cs string) string {
// it must return something, never reachable
return charset.CollationUTF8MB4
}

var (
// coerString holds the display name of each Coercibility value, indexed by that value.
// illegalMixCollationErr uses it to render an argument's coercibility in error messages.
coerString = []string{"EXPLICIT", "NONE", "IMPLICIT", "SYSCONST", "COERCIBLE", "NUMERIC", "IGNORABLE"}
)

// illegalMixCollationErr builds the "illegal mix of collations" error for funcName.
// funcName is passed through GetDisplayName before being placed in the error.
// With exactly two or three arguments, the error lists each argument's collation
// together with the coerString name of its coercibility; for any other argument
// count only the function name is reported.
func illegalMixCollationErr(funcName string, args []Expression) error {
	displayName := GetDisplayName(funcName)

	if len(args) == 2 {
		first, second := args[0], args[1]
		return collate.ErrIllegalMix2Collation.GenWithStackByArgs(
			first.GetType().Collate, coerString[first.Coercibility()],
			second.GetType().Collate, coerString[second.Coercibility()],
			displayName,
		)
	}

	if len(args) == 3 {
		first, second, third := args[0], args[1], args[2]
		return collate.ErrIllegalMix3Collation.GenWithStackByArgs(
			first.GetType().Collate, coerString[first.Coercibility()],
			second.GetType().Collate, coerString[second.Coercibility()],
			third.GetType().Collate, coerString[third.Coercibility()],
			displayName,
		)
	}

	// Any other argument count: only the function name is reported.
	return collate.ErrIllegalMixCollation.GenWithStackByArgs(displayName)
}
5 changes: 2 additions & 3 deletions expression/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -653,10 +653,9 @@ func (col *Column) ReverseEval(sc *stmtctx.StatementContext, res types.Datum, rT

// Coercibility returns the coercibility value which is used to check collations.
func (col *Column) Coercibility() Coercibility {
if col.HasCoercibility() {
return col.collationInfo.Coercibility()
if !col.HasCoercibility() {
col.SetCoercibility(deriveCoercibilityForColumn(col))
}
col.SetCoercibility(deriveCoercibilityForColumn(col))
return col.collationInfo.Coercibility()
}

Expand Down
6 changes: 2 additions & 4 deletions expression/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,10 +425,8 @@ func (c *Constant) ReverseEval(sc *stmtctx.StatementContext, res types.Datum, rT

// Coercibility returns the coercibility value which is used to check collations.
func (c *Constant) Coercibility() Coercibility {
if c.HasCoercibility() {
return c.collationInfo.Coercibility()
if !c.HasCoercibility() {
c.SetCoercibility(deriveCoercibilityForConstant(c))
}

c.SetCoercibility(deriveCoercibilityForConstant(c))
return c.collationInfo.Coercibility()
}
15 changes: 12 additions & 3 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6412,6 +6412,15 @@ func (s *testIntegrationSuite) TestCollation(c *C) {
tk.MustQuery("select collation(@test_collate_var)").Check(testkit.Rows("utf8mb4_general_ci"))
tk.MustExec("set @test_collate_var = concat(\"a\", \"b\" collate utf8mb4_bin)")
tk.MustQuery("select collation(@test_collate_var)").Check(testkit.Rows("utf8mb4_bin"))

tk.MustQuery("select locate('1', '123' collate utf8mb4_bin, 2 collate `binary`);").Check(testkit.Rows("0"))
tk.MustQuery("select 1 in ('a' collate utf8mb4_bin, 'b' collate utf8mb4_general_ci);").Check(testkit.Rows("0"))
tk.MustQuery("select left('abc' collate utf8mb4_bin, 2 collate `binary`);").Check(testkit.Rows("ab"))
tk.MustQuery("select right('abc' collate utf8mb4_bin, 2 collate `binary`);").Check(testkit.Rows("bc"))
tk.MustQuery("select repeat('abc' collate utf8mb4_bin, 2 collate `binary`);").Check(testkit.Rows("abcabc"))
tk.MustQuery("select trim(both 'abc' collate utf8mb4_bin from 'c' collate utf8mb4_general_ci);").Check(testkit.Rows("c"))
tk.MustQuery("select substr('abc' collate utf8mb4_bin, 2 collate `binary`);").Check(testkit.Rows("bc"))
tk.MustQuery("select replace('abc' collate utf8mb4_bin, 'b' collate utf8mb4_general_ci, 'd' collate utf8mb4_unicode_ci);").Check(testkit.Rows("adc"))
}

func (s *testIntegrationSuite) TestCoercibility(c *C) {
Expand Down Expand Up @@ -6558,7 +6567,7 @@ func (s *testIntegrationSerialSuite) TestCollationBasic(c *C) {
tk.MustQuery("select c from t where c = 'b';").Check(testkit.Rows("B"))
tk.MustQuery("select c from t where c = 'B';").Check(testkit.Rows("B"))

tk.MustExec("drop table if exists t")
tk.MustExec("drop table if exists t1")
tk.MustExec("CREATE TABLE `t1` (" +
" `COL1` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL," +
" PRIMARY KEY (`COL1`(5)) clustered" +
Expand Down Expand Up @@ -6829,7 +6838,7 @@ func (s *testIntegrationSerialSuite) TestMixCollation(c *C) {
tk.MustQuery("select coercibility(concat(concat(mb4unicode, mb4general), concat(unicode, general))) from t;").Check(testkit.Rows("1"))
tk.MustQuery("select collation(concat(i, 1)) from t;").Check(testkit.Rows("utf8mb4_general_ci"))
tk.MustQuery("select coercibility(concat(i, 1)) from t;").Check(testkit.Rows("4"))
tk.MustQuery("select collation(concat(i, user())) from t;").Check(testkit.Rows("utf8mb4_general_ci"))
tk.MustQuery("select collation(concat(i, user())) from t;").Check(testkit.Rows("utf8mb4_bin"))
xiongjiwei marked this conversation as resolved.
Show resolved Hide resolved
tk.MustQuery("select coercibility(concat(i, user())) from t;").Check(testkit.Rows("3"))
tk.MustGetErrMsg("select * from t where mb4unicode = mb4general;", "[expression:1267]Illegal mix of collations (utf8mb4_unicode_ci,IMPLICIT) and (utf8mb4_general_ci,IMPLICIT) for operation '='")
tk.MustGetErrMsg("select * from t where unicode = general;", "[expression:1267]Illegal mix of collations (utf8_unicode_ci,IMPLICIT) and (utf8_general_ci,IMPLICIT) for operation '='")
Expand Down Expand Up @@ -8987,7 +8996,7 @@ func (s *testIntegrationSerialSuite) TestLikeWithCollation(c *C) {
defer collate.SetNewCollationEnabledForTest(false)

tk.MustQuery(`select 'a' like 'A' collate utf8mb4_unicode_ci;`).Check(testkit.Rows("1"))
tk.MustGetErrMsg(`select 'a' collate utf8mb4_bin like 'A' collate utf8mb4_unicode_ci;`, "[expression:1270]Illegal mix of collations (utf8mb4_bin,EXPLICIT), (utf8mb4_unicode_ci,EXPLICIT), (binary,NUMERIC) for operation 'like'")
tk.MustGetErrMsg(`select 'a' collate utf8mb4_bin like 'A' collate utf8mb4_unicode_ci;`, "[expression:1267]Illegal mix of collations (utf8mb4_bin,EXPLICIT) and (utf8mb4_unicode_ci,EXPLICIT) for operation 'like'")
zimulala marked this conversation as resolved.
Show resolved Hide resolved
tk.MustQuery(`select '😛' collate utf8mb4_general_ci like '😋';`).Check(testkit.Rows("1"))
tk.MustQuery(`select '😛' collate utf8mb4_general_ci = '😋';`).Check(testkit.Rows("1"))
tk.MustQuery(`select '😛' collate utf8mb4_unicode_ci like '😋';`).Check(testkit.Rows("0"))
Expand Down
Loading