Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: handle max_allowed_packet warnings for to_base64 function. #7266

Merged
merged 17 commits into from
Aug 15, 2018
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions expression/builtin_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ func (b *builtinASCIISig) Clone() builtinFunc {
return newSig
}

// eval evals a builtinASCIISig.
// evalInt evals a builtinASCIISig.
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_ascii
func (b *builtinASCIISig) evalInt(row chunk.Row) (int64, bool, error) {
val, isNull, err := b.args[0].EvalString(b.ctx, row)
Expand Down Expand Up @@ -285,6 +285,7 @@ func (b *builtinConcatSig) Clone() builtinFunc {
return newSig
}

// evalString evals a builtinConcatSig
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_concat
func (b *builtinConcatSig) evalString(row chunk.Row) (d string, isNull bool, err error) {
var s []byte
Expand Down Expand Up @@ -568,7 +569,7 @@ func (b *builtinRepeatSig) Clone() builtinFunc {
return newSig
}

// eval evals a builtinRepeatSig.
// evalString evals a builtinRepeatSig.
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_repeat
func (b *builtinRepeatSig) evalString(row chunk.Row) (d string, isNull bool, err error) {
str, isNull, err := b.args[0].EvalString(b.ctx, row)
Expand Down Expand Up @@ -1515,6 +1516,7 @@ type trimFunctionClass struct {
baseFunctionClass
}

// getFunction sets trim built-in function signature.
// The syntax of trim in mysql is 'TRIM([{BOTH | LEADING | TRAILING} [remstr] FROM] str), TRIM([remstr FROM] str)',
// but we wil convert it into trim(str), trim(str, remstr) and trim(str, remstr, direction) in AST.
func (c *trimFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
Expand Down Expand Up @@ -2482,8 +2484,8 @@ func (b *builtinOctStringSig) Clone() builtinFunc {
return newSig
}

// // evalString evals OCT(N).
// // See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_oct
// evalString evals OCT(N).
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_oct
func (b *builtinOctStringSig) evalString(row chunk.Row) (string, bool, error) {
val, isNull, err := b.args[0].EvalString(b.ctx, row)
if isNull || err != nil {
Expand Down Expand Up @@ -2999,17 +3001,26 @@ func (c *toBase64FunctionClass) getFunction(ctx sessionctx.Context, args []Expre
}
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETString)
bf.tp.Flen = base64NeededEncodedLength(bf.args[0].GetType().Flen)
sig := &builtinToBase64Sig{bf}

valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, errors.Trace(err)
}

sig := &builtinToBase64Sig{bf, maxAllowedPacket}
return sig, nil
}

type builtinToBase64Sig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinToBase64Sig) Clone() builtinFunc {
newSig := &builtinToBase64Sig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

Expand Down Expand Up @@ -3043,7 +3054,14 @@ func (b *builtinToBase64Sig) evalString(row chunk.Row) (d string, isNull bool, e
if isNull || err != nil {
return "", isNull, errors.Trace(err)
}

needEncodeLen := base64NeededEncodedLength(len(str))
if needEncodeLen == -1 {
return "", true, nil
}
if needEncodeLen > int(b.maxAllowedPacket) {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("to_base64", b.maxAllowedPacket))
return "", true, nil
}
if b.tp.Flen == -1 || b.tp.Flen > mysql.MaxBlobWidth {
return "", true, nil
}
Expand Down
69 changes: 69 additions & 0 deletions expression/builtin_string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1879,6 +1879,75 @@ func (s *testEvaluatorSuite) TestToBase64(c *C) {
c.Assert(err, IsNil)
}

func (s *testEvaluatorSuite) TestToBase64Sig(c *C) {
colTypes := []*types.FieldType{
{Tp: mysql.TypeVarchar},
}

tests := []struct {
args string
expect string
isNil bool
maxAllowPacket uint64
}{
{"abc", "YWJj", false, 4},
{"abc", "", true, 3},
{
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
"QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVphYmNkZWZnaGlqa2xtbm9wcXJzdHV2d3h5ejAxMjM0\nNTY3ODkrLw==",
false,
89,
},
{
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
"",
true,
88,
},
{
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
"QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVphYmNkZWZnaGlqa2xtbm9wcXJzdHV2d3h5ejAxMjM0\nNTY3ODkrL0FCQ0RFRkdISUpLTE1OT1BRUlNUVVZXWFlaYWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4\neXowMTIzNDU2Nzg5Ky9BQkNERUZHSElKS0xNTk9QUVJTVFVWV1hZWmFiY2RlZmdoaWprbG1ub3Bx\ncnN0dXZ3eHl6MDEyMzQ1Njc4OSsv",
false,
259,
},
{
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
"",
true,
258,
},
}

args := []Expression{
&Column{Index: 0, RetType: colTypes[0]},
}

warningCount := 0

for _, test := range tests {
resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: base64NeededEncodedLength(len(test.args))}
base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType}
toBase64 := &builtinToBase64Sig{base, test.maxAllowPacket}

input := chunk.NewChunkWithCapacity(colTypes, 1)
input.AppendString(0, test.args)
res, isNull, err := toBase64.evalString(input.GetRow(0))
c.Assert(err, IsNil)
if test.isNil {
c.Assert(isNull, IsTrue)
warningCount += 1
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warningCount will add 1 when to_base64 result is nil

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we reset the warning appended to s.ctx.GetSessionVars().StmtCtx and check the exactly warning count and warning content in each test case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

} else {
c.Assert(isNull, IsFalse)
}
c.Assert(res, Equals, test.expect)
}
warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings()
c.Assert(len(warnings), Equals, warningCount)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warningCount is always zero, seems the test of max_allowed_packets is not working

for _, warn := range warnings {
c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, warn.Err), IsTrue)
}
}

func (s *testEvaluatorSuite) TestStringRight(c *C) {
defer testleak.AfterTest(c)()
fc := funcs[ast.Right]
Expand Down