diff --git a/executor/executor_test.go b/executor/executor_test.go index 1b62eb653c644..97fb2c8e73015 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -941,6 +941,19 @@ func (s *testSuite) TestStringBuiltin(c *C) { tk.MustExec(`insert into t values(1, 1.1, "2017-01-01 12:01:01", "12:01:01", "abcdef", 0b10101, "g", "h")`) result = tk.MustQuery("select bit_length(a), bit_length(b), bit_length(c), bit_length(d), bit_length(e), bit_length(f), bit_length(g), bit_length(h), bit_length(null) from t") result.Check(testkit.Rows("8 24 152 64 48 16 160 8 ")) + + // for substring_index + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a char(20), b int, c double, d datetime, e time)") + tk.MustExec(`insert into t values('www.pingcap.com', 12345, 123.45, "2017-01-01 12:01:01", "12:01:01")`) + result = tk.MustQuery(`select substring_index(a, '.', 2), substring_index(b, '.', 2), substring_index(c, '.', -1), substring_index(d, '-', 1), substring_index(e, ':', -2) from t`) + result.Check(testkit.Rows("www.pingcap 12345 45 2017 01:01")) + result = tk.MustQuery(`select substring_index('www.pingcap.com', '.', 0), substring_index('www.pingcap.com', '.', 100), substring_index('www.pingcap.com', '.', -100)`) + result.Check(testkit.Rows(" www.pingcap.com www.pingcap.com")) + result = tk.MustQuery(`select substring_index('www.pingcap.com', 'd', 1), substring_index('www.pingcap.com', '', 1), substring_index('', '.', 1)`) + result.Check(testutil.RowsWithSep(",", "www.pingcap.com,,")) + result = tk.MustQuery(`select substring_index(null, '.', 1), substring_index('www.pingcap.com', null, 1), substring_index('www.pingcap.com', '.', null)`) + result.Check(testkit.Rows(" ")) } func (s *testSuite) TestEncryptionBuiltin(c *C) { diff --git a/expression/builtin_string.go b/expression/builtin_string.go index 506e7c9ac9084..2c92fee696c52 100644 --- a/expression/builtin_string.go +++ b/expression/builtin_string.go @@ -852,49 +852,49 @@ type substringIndexFunctionClass struct { } func (c *substringIndexFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) { - sig := &builtinSubstringIndexSig{newBaseBuiltinFunc(args, ctx)} + bf, err := newBaseBuiltinFuncWithTp(args, ctx, tpString, tpString, tpString, tpInt) + if err != nil { + return nil, errors.Trace(err) + } + argType := args[0].GetType() + bf.tp.Flen = argType.Flen + if mysql.HasBinaryFlag(argType.Flag) { + types.SetBinChsClnFlag(bf.tp) + } + sig := &builtinSubstringIndexSig{baseStringBuiltinFunc{bf}} return sig.setSelf(sig), errors.Trace(c.verifyArgs(args)) } type builtinSubstringIndexSig struct { - baseBuiltinFunc + baseStringBuiltinFunc } -// eval evals a builtinSubstringIndexSig. +// evalString evals a builtinSubstringIndexSig. // See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_substring-index -func (b *builtinSubstringIndexSig) eval(row []types.Datum) (d types.Datum, err error) { - args, err := b.evalArgs(row) - if err != nil { - return types.Datum{}, errors.Trace(err) +func (b *builtinSubstringIndexSig) evalString(row []types.Datum) (d string, isNull bool, err error) { + var ( + str, delim string + count int64 + ) + sc := b.ctx.GetSessionVars().StmtCtx + str, isNull, err = b.args[0].EvalString(row, sc) + if isNull || err != nil { + return d, isNull, errors.Trace(err) } - // The meaning of the elements of args. - // args[0] -> StrExpr - // args[1] -> Delim - // args[2] -> Count - str, err := args[0].ToString() - if err != nil { - return d, errors.Errorf("Substring_Index invalid args, need string but get %T", args[0].GetValue()) + delim, isNull, err = b.args[1].EvalString(row, sc) + if isNull || err != nil { + return d, isNull, errors.Trace(err) } - - delim, err := args[1].ToString() - if err != nil { - return d, errors.Errorf("Substring_Index invalid delim, need string but get %T", args[1].GetValue()) + count, isNull, err = b.args[2].EvalInt(row, sc) + if isNull || err != nil { + return d, isNull, errors.Trace(err) } if len(delim) == 0 { - d.SetString("") - return d, nil + return "", false, nil } - c, err := args[2].ToInt64(b.ctx.GetSessionVars().StmtCtx) - if err != nil { - return d, errors.Trace(err) - } - count := int(c) strs := strings.Split(str, delim) - var ( - start = 0 - end = len(strs) - ) + start, end := int64(0), int64(len(strs)) if count > 0 { // If count is positive, everything to the left of the final delimiter (counting from the left) is returned. if count < end { @@ -908,8 +908,7 @@ func (b *builtinSubstringIndexSig) eval(row []types.Datum) (d types.Datum, err e } } substrs := strs[start:end] - d.SetString(strings.Join(substrs, delim)) - return d, nil + return strings.Join(substrs, delim), false, nil } type locateFunctionClass struct { diff --git a/expression/builtin_string_test.go b/expression/builtin_string_test.go index 7103d6e7c0246..8231110b730e7 100644 --- a/expression/builtin_string_test.go +++ b/expression/builtin_string_test.go @@ -673,62 +673,51 @@ func (s *testEvaluatorSuite) TestConvert(c *C) { func (s *testEvaluatorSuite) TestSubstringIndex(c *C) { defer testleak.AfterTest(c)() - tbl := []struct { - str string - delim string - count int64 - result string - }{ - {"www.mysql.com", ".", 2, "www.mysql"}, - {"www.mysql.com", ".", -2, "mysql.com"}, - {"www.mysql.com", ".", 0, ""}, - {"www.mysql.com", ".", 3, "www.mysql.com"}, - {"www.mysql.com", ".", 4, "www.mysql.com"}, - {"www.mysql.com", ".", -3, "www.mysql.com"}, - {"www.mysql.com", ".", -4, "www.mysql.com"}, - - {"www.mysql.com", "d", 1, "www.mysql.com"}, - {"www.mysql.com", "d", 0, ""}, - {"www.mysql.com", "d", -1, "www.mysql.com"}, - - {"", ".", 2, ""}, - {"", ".", -2, ""}, - {"", ".", 0, ""}, - - {"www.mysql.com", "", 1, ""}, - {"www.mysql.com", "", -1, ""}, - {"www.mysql.com", "", 0, ""}, - } - for _, v := range tbl { - fc := funcs[ast.SubstringIndex] - f, err := fc.getFunction(datumsToConstants(types.MakeDatums(v.str, v.delim, v.count)), s.ctx) - c.Assert(err, IsNil) - r, err := f.eval(nil) - c.Assert(err, IsNil) - c.Assert(r.Kind(), Equals, types.KindString) - c.Assert(r.GetString(), Equals, v.result) - } - errTbl := []struct { - str interface{} - delim interface{} - count interface{} + + cases := []struct { + args []interface{} + isNil bool + getErr bool + res string }{ - {nil, ".", 2}, - {nil, ".", -2}, - {nil, ".", 0}, - {"asdf", nil, 2}, - {"asdf", nil, -2}, - {"asdf", nil, 0}, - {"www.mysql.com", ".", nil}, + {[]interface{}{"www.pingcap.com", ".", 2}, false, false, "www.pingcap"}, + {[]interface{}{"www.pingcap.com", ".", -2}, false, false, "pingcap.com"}, + {[]interface{}{"www.pingcap.com", ".", 0}, false, false, ""}, + {[]interface{}{"www.pingcap.com", ".", 100}, false, false, "www.pingcap.com"}, + {[]interface{}{"www.pingcap.com", ".", -100}, false, false, "www.pingcap.com"}, + {[]interface{}{"www.pingcap.com", "d", 0}, false, false, ""}, + {[]interface{}{"www.pingcap.com", "d", 1}, false, false, "www.pingcap.com"}, + {[]interface{}{"www.pingcap.com", "d", -1}, false, false, "www.pingcap.com"}, + {[]interface{}{"www.pingcap.com", "", 0}, false, false, ""}, + {[]interface{}{"www.pingcap.com", "", 1}, false, false, ""}, + {[]interface{}{"www.pingcap.com", "", -1}, false, false, ""}, + {[]interface{}{"", ".", 0}, false, false, ""}, + {[]interface{}{"", ".", 1}, false, false, ""}, + {[]interface{}{"", ".", -1}, false, false, ""}, + {[]interface{}{nil, ".", 1}, true, false, ""}, + {[]interface{}{"www.pingcap.com", nil, 1}, true, false, ""}, + {[]interface{}{"www.pingcap.com", ".", nil}, true, false, ""}, + {[]interface{}{errors.New("must error"), ".", 1}, false, true, ""}, } - for _, v := range errTbl { - fc := funcs[ast.SubstringIndex] - f, err := fc.getFunction(datumsToConstants(types.MakeDatums(v.str, v.delim, v.count)), s.ctx) + for _, t := range cases { + f, err := newFunctionForTest(s.ctx, ast.SubstringIndex, primitiveValsToConstants(t.args)...) c.Assert(err, IsNil) - r, err := f.eval(nil) - c.Assert(err, NotNil) - c.Assert(r.Kind(), Equals, types.KindNull) + d, err := f.Eval(nil) + if t.getErr { + c.Assert(err, NotNil) + } else { + c.Assert(err, IsNil) + if t.isNil { + c.Assert(d.Kind(), Equals, types.KindNull) + } else { + c.Assert(d.GetString(), Equals, t.res) + } + } } + + f, err := funcs[ast.SubstringIndex].getFunction([]Expression{Zero, Zero, Zero}, s.ctx) + c.Assert(err, IsNil) + c.Assert(f.isDeterministic(), IsTrue) } func (s *testEvaluatorSuite) TestSpace(c *C) { diff --git a/plan/typeinfer_test.go b/plan/typeinfer_test.go index a9cf07836e266..32f6ba7641d74 100644 --- a/plan/typeinfer_test.go +++ b/plan/typeinfer_test.go @@ -106,6 +106,8 @@ func (s *testPlanSuite) TestInferType(c *C) { {"substr(c_binary, c_int)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength}, {"uuid()", mysql.TypeVarString, charset.CharsetUTF8, 0, 36, types.UnspecifiedLength}, {"bit_length(c_char)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 10, 0}, + {"substring_index(c_int, '.', 1)", mysql.TypeVarString, charset.CharsetUTF8, 0, 11, types.UnspecifiedLength}, + {"substring_index(c_binary, '.', 1)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength}, {"asin(c_double)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength}, {"tidb_version()", mysql.TypeVarString, charset.CharsetUTF8, 0, len(printer.GetTiDBInfo()), types.UnspecifiedLength}, }