diff --git a/expression/builtin_string.go b/expression/builtin_string.go index 7a56a68819279..3f94a43231490 100644 --- a/expression/builtin_string.go +++ b/expression/builtin_string.go @@ -3145,21 +3145,30 @@ func (c *insertFunctionClass) getFunction(ctx sessionctx.Context, args []Express bf.tp.Flen = mysql.MaxBlobWidth SetBinFlagOrBinStr(args[0].GetType(), bf.tp) SetBinFlagOrBinStr(args[3].GetType(), bf.tp) + + valStr, _ := ctx.GetSessionVars().GetSystemVar(variable.MaxAllowedPacket) + maxAllowedPacket, err := strconv.ParseUint(valStr, 10, 64) + if err != nil { + return nil, errors.Trace(err) + } + if types.IsBinaryStr(args[0].GetType()) { - sig = &builtinInsertBinarySig{bf} + sig = &builtinInsertBinarySig{bf, maxAllowedPacket} } else { - sig = &builtinInsertSig{bf} + sig = &builtinInsertSig{bf, maxAllowedPacket} } return sig, nil } type builtinInsertBinarySig struct { baseBuiltinFunc + maxAllowedPacket uint64 } func (b *builtinInsertBinarySig) Clone() builtinFunc { newSig := &builtinInsertBinarySig{} newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.maxAllowedPacket = b.maxAllowedPacket return newSig } @@ -3191,18 +3200,26 @@ func (b *builtinInsertBinarySig) evalString(row types.Row) (string, bool, error) } if length > strLength-pos+1 || length < 0 { - return str[0:pos-1] + newstr, false, nil + length = strLength - pos + 1 } + + if uint64(strLength-length+int64(len(newstr))) > b.maxAllowedPacket { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("insert", b.maxAllowedPacket)) + return "", true, nil + } + return str[0:pos-1] + newstr + str[pos+length-1:], false, nil } type builtinInsertSig struct { baseBuiltinFunc + maxAllowedPacket uint64 } func (b *builtinInsertSig) Clone() builtinFunc { newSig := &builtinInsertSig{} newSig.cloneFrom(&b.baseBuiltinFunc) + newSig.maxAllowedPacket = b.maxAllowedPacket return newSig } @@ -3235,9 +3252,16 @@ func (b *builtinInsertSig) evalString(row types.Row) (string, bool, error) { } if length > runeLength-pos+1 || length < 0 { - return string(runes[0:pos-1]) + newstr, false, nil + length = runeLength - pos + 1 + } + + strHead := string(runes[0 : pos-1]) + strTail := string(runes[pos+length-1:]) + if uint64(len(strHead)+len(newstr)+len(strTail)) > b.maxAllowedPacket { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnAllowedPacketOverflowed.GenByArgs("insert", b.maxAllowedPacket)) + return "", true, nil } - return string(runes[0:pos-1]) + newstr + string(runes[pos+length-1:]), false, nil + return strHead + newstr + strTail, false, nil } type instrFunctionClass struct { diff --git a/expression/builtin_string_test.go b/expression/builtin_string_test.go index 41cda60c571a5..de11960ae6ac6 100644 --- a/expression/builtin_string_test.go +++ b/expression/builtin_string_test.go @@ -1381,6 +1381,51 @@ func (s *testEvaluatorSuite) TestRpadSig(c *C) { c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue) } +func (s *testEvaluatorSuite) TestInsertBinarySig(c *C) { + colTypes := []*types.FieldType{ + {Tp: mysql.TypeVarchar}, + {Tp: mysql.TypeLonglong}, + {Tp: mysql.TypeLonglong}, + {Tp: mysql.TypeVarchar}, + } + resultType := &types.FieldType{Tp: mysql.TypeVarchar, Flen: 3} + + args := []Expression{ + &Column{Index: 0, RetType: colTypes[0]}, + &Column{Index: 1, RetType: colTypes[1]}, + &Column{Index: 2, RetType: colTypes[2]}, + &Column{Index: 3, RetType: colTypes[3]}, + } + + base := baseBuiltinFunc{args: args, ctx: s.ctx, tp: resultType} + insert := &builtinInsertBinarySig{base, 3} + + input := chunk.NewChunkWithCapacity(colTypes, 2) + input.AppendString(0, "abc") + input.AppendString(0, "abc") + input.AppendInt64(1, 3) + input.AppendInt64(1, 3) + input.AppendInt64(2, -1) + input.AppendInt64(2, -1) + input.AppendString(3, "d") + input.AppendString(3, "de") + + res, isNull, err := insert.evalString(input.GetRow(0)) + c.Assert(res, Equals, "abd") + c.Assert(isNull, IsFalse) + c.Assert(err, IsNil) + + res, isNull, err = insert.evalString(input.GetRow(1)) + c.Assert(res, Equals, "") + c.Assert(isNull, IsTrue) + c.Assert(err, IsNil) + + warnings := s.ctx.GetSessionVars().StmtCtx.GetWarnings() + c.Assert(len(warnings), Equals, 1) + lastWarn := warnings[len(warnings)-1] + c.Assert(terror.ErrorEqual(errWarnAllowedPacketOverflowed, lastWarn.Err), IsTrue, Commentf("err %v", lastWarn.Err)) +} + func (s *testEvaluatorSuite) TestInstr(c *C) { defer testleak.AfterTest(c)() tbl := []struct {