Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: handle max_allowed_packet warnings for to_base64 function. #7266

Merged
merged 17 commits into from
Aug 15, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions expression/builtin_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ func (b *builtinASCIISig) Clone() builtinFunc {
return newSig
}

// eval evals a builtinASCIISig.
// evalInt evals a builtinASCIISig.
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_ascii
func (b *builtinASCIISig) evalInt(row chunk.Row) (int64, bool, error) {
val, isNull, err := b.args[0].EvalString(b.ctx, row)
Expand Down Expand Up @@ -285,6 +285,7 @@ func (b *builtinConcatSig) Clone() builtinFunc {
return newSig
}

// evalString evals a builtinConcatSig
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_concat
func (b *builtinConcatSig) evalString(row chunk.Row) (d string, isNull bool, err error) {
var s []byte
Expand Down Expand Up @@ -568,7 +569,7 @@ func (b *builtinRepeatSig) Clone() builtinFunc {
return newSig
}

// eval evals a builtinRepeatSig.
// evalString evals a builtinRepeatSig.
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_repeat
func (b *builtinRepeatSig) evalString(row chunk.Row) (d string, isNull bool, err error) {
str, isNull, err := b.args[0].EvalString(b.ctx, row)
Expand Down Expand Up @@ -1515,6 +1516,7 @@ type trimFunctionClass struct {
baseFunctionClass
}

// getFunction sets trim built-in function signature.
// The syntax of trim in mysql is 'TRIM([{BOTH | LEADING | TRAILING} [remstr] FROM] str), TRIM([remstr FROM] str)',
// but we wil convert it into trim(str), trim(str, remstr) and trim(str, remstr, direction) in AST.
func (c *trimFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
Expand Down Expand Up @@ -2482,8 +2484,8 @@ func (b *builtinOctStringSig) Clone() builtinFunc {
return newSig
}

// // evalString evals OCT(N).
// // See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_oct
// evalString evals OCT(N).
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_oct
func (b *builtinOctStringSig) evalString(row chunk.Row) (string, bool, error) {
val, isNull, err := b.args[0].EvalString(b.ctx, row)
if isNull || err != nil {
Expand Down Expand Up @@ -2999,17 +3001,26 @@ func (c *toBase64FunctionClass) getFunction(ctx sessionctx.Context, args []Expre
}
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETString)
bf.tp.Flen = base64NeededEncodedLength(bf.args[0].GetType().Flen)
sig := &builtinToBase64Sig{bf}

valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, errors.Trace(err)
}

sig := &builtinToBase64Sig{bf, maxAllowedPacket}
return sig, nil
}

type builtinToBase64Sig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinToBase64Sig) Clone() builtinFunc {
newSig := &builtinToBase64Sig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

Expand Down Expand Up @@ -3043,7 +3054,14 @@ func (b *builtinToBase64Sig) evalString(row chunk.Row) (d string, isNull bool, e
if isNull || err != nil {
return "", isNull, errors.Trace(err)
}

needEncodeLen := base64NeededEncodedLength(len(str))
if needEncodeLen == -1 {
return "", true, nil
}
if needEncodeLen > int(b.maxAllowedPacket) {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("to_base64", b.maxAllowedPacket))
return "", true, nil
}
if b.tp.Flen == -1 || b.tp.Flen > mysql.MaxBlobWidth {
return "", true, nil
}
Expand Down
69 changes: 69 additions & 0 deletions expression/builtin_string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/terror"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/charset"
Expand Down Expand Up @@ -1879,6 +1880,74 @@ func (s *testEvaluatorSuite) TestToBase64(c *C) {
c.Assert(err, IsNil)
}

func (s *testEvaluatorSuite) TestToBase64Sig(c *C) {
colTypes := []*types.FieldType{
{Tp: mysql.TypeVarchar},
}

tests := []struct {
args string
expect string
isNil bool
maxAllowPacket uint64
}{
{"abc", "YWJj", false, 4},
{"abc", "", true, 3},
{
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
"QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVphYmNkZWZnaGlqa2xtbm9wcXJzdHV2d3h5ejAxMjM0\nNTY3ODkrLw==",
false,
89,
},
{
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
"",
true,
88,
},
{
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
"QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVphYmNkZWZnaGlqa2xtbm9wcXJzdHV2d3h5ejAxMjM0\nNTY3ODkrL0FCQ0RFRkdISUpLTE1OT1BRUlNUVVZXWFlaYWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4\neXowMTIzNDU2Nzg5Ky9BQkNERUZHSElKS0xNTk9QUVJTVFVWV1hZWmFiY2RlZmdoaWprbG1ub3Bx\ncnN0dXZ3eHl6MDEyMzQ1Njc4OSsv",
false,
259,
},
{
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
"",
true,
258,
},
}

args := []Expression{
&Column{Index: 0, RetType: colTypes[0]},
}

for _, test := range tests {
resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: base64NeededEncodedLength(len(test.args))}
base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType}
toBase64 := &builtinToBase64Sig{base, test.maxAllowPacket}

input := chunk.NewChunkWithCapacity(colTypes, 1)
input.AppendString(0, test.args)
res, isNull, err := toBase64.evalString(input.GetRow(0))
c.Assert(err, IsNil)
if test.isNil {
c.Assert(isNull, IsTrue)

warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings()
c.Assert(len(warnings), Equals, 1)
lastWarn := warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue)
s.ctx.GetSessionVars().StmtCtx.SetWarnings([]stmtctx.SQLWarn{})

} else {
c.Assert(isNull, IsFalse)
}
c.Assert(res, Equals, test.expect)
}
}

func (s *testEvaluatorSuite) TestStringRight(c *C) {
defer testleak.AfterTest(c)()
fc := funcs[ast.Right]
Expand Down