Skip to content

Commit

Permalink
expression: handle max_allowed_packet warnings for to_base64 function. (
Browse files Browse the repository at this point in the history
  • Loading branch information
supernan1994 authored and zz-jason committed Aug 17, 2018
1 parent 1fb887f commit 6d55c90
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 6 deletions.
30 changes: 24 additions & 6 deletions expression/builtin_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ func (b *builtinASCIISig) Clone() builtinFunc {
return newSig
}

// eval evals a builtinASCIISig.
// evalInt evals a builtinASCIISig.
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_ascii
func (b *builtinASCIISig) evalInt(row types.Row) (int64, bool, error) {
val, isNull, err := b.args[0].EvalString(b.ctx, row)
Expand Down Expand Up @@ -284,6 +284,7 @@ func (b *builtinConcatSig) Clone() builtinFunc {
return newSig
}

// evalString evals a builtinConcatSig
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_concat
func (b *builtinConcatSig) evalString(row types.Row) (d string, isNull bool, err error) {
var s []byte
Expand Down Expand Up @@ -567,7 +568,7 @@ func (b *builtinRepeatSig) Clone() builtinFunc {
return newSig
}

// eval evals a builtinRepeatSig.
// evalString evals a builtinRepeatSig.
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_repeat
func (b *builtinRepeatSig) evalString(row types.Row) (d string, isNull bool, err error) {
str, isNull, err := b.args[0].EvalString(b.ctx, row)
Expand Down Expand Up @@ -1514,6 +1515,7 @@ type trimFunctionClass struct {
baseFunctionClass
}

// getFunction sets trim built-in function signature.
// The syntax of trim in mysql is 'TRIM([{BOTH | LEADING | TRAILING} [remstr] FROM] str), TRIM([remstr FROM] str)',
// but we wil convert it into trim(str), trim(str, remstr) and trim(str, remstr, direction) in AST.
func (c *trimFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
Expand Down Expand Up @@ -2481,8 +2483,8 @@ func (b *builtinOctStringSig) Clone() builtinFunc {
return newSig
}

// // evalString evals OCT(N).
// // See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_oct
// evalString evals OCT(N).
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_oct
func (b *builtinOctStringSig) evalString(row types.Row) (string, bool, error) {
val, isNull, err := b.args[0].EvalString(b.ctx, row)
if isNull || err != nil {
Expand Down Expand Up @@ -2998,17 +3000,26 @@ func (c *toBase64FunctionClass) getFunction(ctx sessionctx.Context, args []Expre
}
bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETString, types.ETString)
bf.tp.Flen = base64NeededEncodedLength(bf.args[0].GetType().Flen)
sig := &builtinToBase64Sig{bf}

valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket)
maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64)
if err != nil {
return nil, errors.Trace(err)
}

sig := &builtinToBase64Sig{bf, maxAllowedPacket}
return sig, nil
}

type builtinToBase64Sig struct {
baseBuiltinFunc
maxAllowedPacket uint64
}

func (b *builtinToBase64Sig) Clone() builtinFunc {
newSig := &builtinToBase64Sig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
newSig.maxAllowedPacket = b.maxAllowedPacket
return newSig
}

Expand Down Expand Up @@ -3042,7 +3053,14 @@ func (b *builtinToBase64Sig) evalString(row types.Row) (d string, isNull bool, e
if isNull || err != nil {
return "", isNull, errors.Trace(err)
}

needEncodeLen := base64NeededEncodedLength(len(str))
if needEncodeLen == -1 {
return "", true, nil
}
if needEncodeLen > int(b.maxAllowedPacket) {
b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("to_base64", b.maxAllowedPacket))
return "", true, nil
}
if b.tp.Flen == -1 || b.tp.Flen > mysql.MaxBlobWidth {
return "", true, nil
}
Expand Down
69 changes: 69 additions & 0 deletions expression/builtin_string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx/stmtctx"
"github.com/pingcap/tidb/terror"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/charset"
Expand Down Expand Up @@ -1862,6 +1863,74 @@ func (s *testEvaluatorSuite) TestToBase64(c *C) {
c.Assert(err, IsNil)
}

func (s *testEvaluatorSuite) TestToBase64Sig(c *C) {
colTypes := []*types.FieldType{
{Tp: mysql.TypeVarchar},
}

tests := []struct {
args string
expect string
isNil bool
maxAllowPacket uint64
}{
{"abc", "YWJj", false, 4},
{"abc", "", true, 3},
{
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
"QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVphYmNkZWZnaGlqa2xtbm9wcXJzdHV2d3h5ejAxMjM0\nNTY3ODkrLw==",
false,
89,
},
{
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
"",
true,
88,
},
{
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
"QUJDREVGR0hJSktMTU5PUFFSU1RVVldYWVphYmNkZWZnaGlqa2xtbm9wcXJzdHV2d3h5ejAxMjM0\nNTY3ODkrL0FCQ0RFRkdISUpLTE1OT1BRUlNUVVZXWFlaYWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4\neXowMTIzNDU2Nzg5Ky9BQkNERUZHSElKS0xNTk9QUVJTVFVWV1hZWmFiY2RlZmdoaWprbG1ub3Bx\ncnN0dXZ3eHl6MDEyMzQ1Njc4OSsv",
false,
259,
},
{
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/",
"",
true,
258,
},
}

args := []Expression{
&Column{Index: 0, RetType: colTypes[0]},
}

for _, test := range tests {
resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: base64NeededEncodedLength(len(test.args))}
base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType}
toBase64 := &builtinToBase64Sig{base, test.maxAllowPacket}

input := chunk.NewChunkWithCapacity(colTypes, 1)
input.AppendString(0, test.args)
res, isNull, err := toBase64.evalString(input.GetRow(0))
c.Assert(err, IsNil)
if test.isNil {
c.Assert(isNull, IsTrue)

warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings()
c.Assert(len(warnings), Equals, 1)
lastWarn := warnings[len(warnings)-1]
c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue)
s.ctx.GetSessionVars().StmtCtx.SetWarnings([]stmtctx.SQLWarn{})

} else {
c.Assert(isNull, IsFalse)
}
c.Assert(res, Equals, test.expect)
}
}

func (s *testEvaluatorSuite) TestStringRight(c *C) {
defer testleak.AfterTest(c)()
fc := funcs[ast.Right]
Expand Down

0 comments on commit 6d55c90

Please sign in to comment.