Skip to content

Commit

Permalink
expression, executor, plan: rewrite builtin function substring_index. (
Browse files Browse the repository at this point in the history
  • Loading branch information
SteveZhangBit authored and hanfei1991 committed Jul 17, 2017
1 parent 735380c commit 8805688
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 83 deletions.
13 changes: 13 additions & 0 deletions executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -941,6 +941,19 @@ func (s *testSuite) TestStringBuiltin(c *C) {
tk.MustExec(`insert into t values(1, 1.1, "2017-01-01 12:01:01", "12:01:01", "abcdef", 0b10101, "g", "h")`)
result = tk.MustQuery("select bit_length(a), bit_length(b), bit_length(c), bit_length(d), bit_length(e), bit_length(f), bit_length(g), bit_length(h), bit_length(null) from t")
result.Check(testkit.Rows("8 24 152 64 48 16 160 8 <nil>"))

// for substring_index
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a char(20), b int, c double, d datetime, e time)")
tk.MustExec(`insert into t values('www.pingcap.com', 12345, 123.45, "2017-01-01 12:01:01", "12:01:01")`)
result = tk.MustQuery(`select substring_index(a, '.', 2), substring_index(b, '.', 2), substring_index(c, '.', -1), substring_index(d, '-', 1), substring_index(e, ':', -2) from t`)
result.Check(testkit.Rows("www.pingcap 12345 45 2017 01:01"))
result = tk.MustQuery(`select substring_index('www.pingcap.com', '.', 0), substring_index('www.pingcap.com', '.', 100), substring_index('www.pingcap.com', '.', -100)`)
result.Check(testkit.Rows(" www.pingcap.com www.pingcap.com"))
result = tk.MustQuery(`select substring_index('www.pingcap.com', 'd', 1), substring_index('www.pingcap.com', '', 1), substring_index('', '.', 1)`)
result.Check(testutil.RowsWithSep(",", "www.pingcap.com,,"))
result = tk.MustQuery(`select substring_index(null, '.', 1), substring_index('www.pingcap.com', null, 1), substring_index('www.pingcap.com', '.', null)`)
result.Check(testkit.Rows("<nil> <nil> <nil>"))
}

func (s *testSuite) TestEncryptionBuiltin(c *C) {
Expand Down
61 changes: 30 additions & 31 deletions expression/builtin_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -852,49 +852,49 @@ type substringIndexFunctionClass struct {
}

func (c *substringIndexFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
sig := &builtinSubstringIndexSig{newBaseBuiltinFunc(args, ctx)}
bf, err := newBaseBuiltinFuncWithTp(args, ctx, tpString, tpString, tpString, tpInt)
if err != nil {
return nil, errors.Trace(err)
}
argType := args[0].GetType()
bf.tp.Flen = argType.Flen
if mysql.HasBinaryFlag(argType.Flag) {
types.SetBinChsClnFlag(bf.tp)
}
sig := &builtinSubstringIndexSig{baseStringBuiltinFunc{bf}}
return sig.setSelf(sig), errors.Trace(c.verifyArgs(args))
}

type builtinSubstringIndexSig struct {
baseBuiltinFunc
baseStringBuiltinFunc
}

// eval evals a builtinSubstringIndexSig.
// evalString evals a builtinSubstringIndexSig.
// See https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_substring-index
func (b *builtinSubstringIndexSig) eval(row []types.Datum) (d types.Datum, err error) {
args, err := b.evalArgs(row)
if err != nil {
return types.Datum{}, errors.Trace(err)
func (b *builtinSubstringIndexSig) evalString(row []types.Datum) (d string, isNull bool, err error) {
var (
str, delim string
count int64
)
sc := b.ctx.GetSessionVars().StmtCtx
str, isNull, err = b.args[0].EvalString(row, sc)
if isNull || err != nil {
return d, isNull, errors.Trace(err)
}
// The meaning of the elements of args.
// args[0] -> StrExpr
// args[1] -> Delim
// args[2] -> Count
str, err := args[0].ToString()
if err != nil {
return d, errors.Errorf("Substring_Index invalid args, need string but get %T", args[0].GetValue())
delim, isNull, err = b.args[1].EvalString(row, sc)
if isNull || err != nil {
return d, isNull, errors.Trace(err)
}

delim, err := args[1].ToString()
if err != nil {
return d, errors.Errorf("Substring_Index invalid delim, need string but get %T", args[1].GetValue())
count, isNull, err = b.args[2].EvalInt(row, sc)
if isNull || err != nil {
return d, isNull, errors.Trace(err)
}
if len(delim) == 0 {
d.SetString("")
return d, nil
return "", false, nil
}

c, err := args[2].ToInt64(b.ctx.GetSessionVars().StmtCtx)
if err != nil {
return d, errors.Trace(err)
}
count := int(c)
strs := strings.Split(str, delim)
var (
start = 0
end = len(strs)
)
start, end := int64(0), int64(len(strs))
if count > 0 {
// If count is positive, everything to the left of the final delimiter (counting from the left) is returned.
if count < end {
Expand All @@ -908,8 +908,7 @@ func (b *builtinSubstringIndexSig) eval(row []types.Datum) (d types.Datum, err e
}
}
substrs := strs[start:end]
d.SetString(strings.Join(substrs, delim))
return d, nil
return strings.Join(substrs, delim), false, nil
}

type locateFunctionClass struct {
Expand Down
93 changes: 41 additions & 52 deletions expression/builtin_string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -673,62 +673,51 @@ func (s *testEvaluatorSuite) TestConvert(c *C) {

func (s *testEvaluatorSuite) TestSubstringIndex(c *C) {
defer testleak.AfterTest(c)()
tbl := []struct {
str string
delim string
count int64
result string
}{
{"www.mysql.com", ".", 2, "www.mysql"},
{"www.mysql.com", ".", -2, "mysql.com"},
{"www.mysql.com", ".", 0, ""},
{"www.mysql.com", ".", 3, "www.mysql.com"},
{"www.mysql.com", ".", 4, "www.mysql.com"},
{"www.mysql.com", ".", -3, "www.mysql.com"},
{"www.mysql.com", ".", -4, "www.mysql.com"},

{"www.mysql.com", "d", 1, "www.mysql.com"},
{"www.mysql.com", "d", 0, ""},
{"www.mysql.com", "d", -1, "www.mysql.com"},

{"", ".", 2, ""},
{"", ".", -2, ""},
{"", ".", 0, ""},

{"www.mysql.com", "", 1, ""},
{"www.mysql.com", "", -1, ""},
{"www.mysql.com", "", 0, ""},
}
for _, v := range tbl {
fc := funcs[ast.SubstringIndex]
f, err := fc.getFunction(datumsToConstants(types.MakeDatums(v.str, v.delim, v.count)), s.ctx)
c.Assert(err, IsNil)
r, err := f.eval(nil)
c.Assert(err, IsNil)
c.Assert(r.Kind(), Equals, types.KindString)
c.Assert(r.GetString(), Equals, v.result)
}
errTbl := []struct {
str interface{}
delim interface{}
count interface{}

cases := []struct {
args []interface{}
isNil bool
getErr bool
res string
}{
{nil, ".", 2},
{nil, ".", -2},
{nil, ".", 0},
{"asdf", nil, 2},
{"asdf", nil, -2},
{"asdf", nil, 0},
{"www.mysql.com", ".", nil},
{[]interface{}{"www.pingcap.com", ".", 2}, false, false, "www.pingcap"},
{[]interface{}{"www.pingcap.com", ".", -2}, false, false, "pingcap.com"},
{[]interface{}{"www.pingcap.com", ".", 0}, false, false, ""},
{[]interface{}{"www.pingcap.com", ".", 100}, false, false, "www.pingcap.com"},
{[]interface{}{"www.pingcap.com", ".", -100}, false, false, "www.pingcap.com"},
{[]interface{}{"www.pingcap.com", "d", 0}, false, false, ""},
{[]interface{}{"www.pingcap.com", "d", 1}, false, false, "www.pingcap.com"},
{[]interface{}{"www.pingcap.com", "d", -1}, false, false, "www.pingcap.com"},
{[]interface{}{"www.pingcap.com", "", 0}, false, false, ""},
{[]interface{}{"www.pingcap.com", "", 1}, false, false, ""},
{[]interface{}{"www.pingcap.com", "", -1}, false, false, ""},
{[]interface{}{"", ".", 0}, false, false, ""},
{[]interface{}{"", ".", 1}, false, false, ""},
{[]interface{}{"", ".", -1}, false, false, ""},
{[]interface{}{nil, ".", 1}, true, false, ""},
{[]interface{}{"www.pingcap.com", nil, 1}, true, false, ""},
{[]interface{}{"www.pingcap.com", ".", nil}, true, false, ""},
{[]interface{}{errors.New("must error"), ".", 1}, false, true, ""},
}
for _, v := range errTbl {
fc := funcs[ast.SubstringIndex]
f, err := fc.getFunction(datumsToConstants(types.MakeDatums(v.str, v.delim, v.count)), s.ctx)
for _, t := range cases {
f, err := newFunctionForTest(s.ctx, ast.SubstringIndex, primitiveValsToConstants(t.args)...)
c.Assert(err, IsNil)
r, err := f.eval(nil)
c.Assert(err, NotNil)
c.Assert(r.Kind(), Equals, types.KindNull)
d, err := f.Eval(nil)
if t.getErr {
c.Assert(err, NotNil)
} else {
c.Assert(err, IsNil)
if t.isNil {
c.Assert(d.Kind(), Equals, types.KindNull)
} else {
c.Assert(d.GetString(), Equals, t.res)
}
}
}

f, err := funcs[ast.SubstringIndex].getFunction([]Expression{Zero, Zero, Zero}, s.ctx)
c.Assert(err, IsNil)
c.Assert(f.isDeterministic(), IsTrue)
}

func (s *testEvaluatorSuite) TestSpace(c *C) {
Expand Down
2 changes: 2 additions & 0 deletions plan/typeinfer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ func (s *testPlanSuite) TestInferType(c *C) {
{"substr(c_binary, c_int)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength},
{"uuid()", mysql.TypeVarString, charset.CharsetUTF8, 0, 36, types.UnspecifiedLength},
{"bit_length(c_char)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 10, 0},
{"substring_index(c_int, '.', 1)", mysql.TypeVarString, charset.CharsetUTF8, 0, 11, types.UnspecifiedLength},
{"substring_index(c_binary, '.', 1)", mysql.TypeVarString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength},
{"asin(c_double)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, types.UnspecifiedLength},
{"tidb_version()", mysql.TypeVarString, charset.CharsetUTF8, 0, len(printer.GetTiDBInfo()), types.UnspecifiedLength},
}
Expand Down

0 comments on commit 8805688

Please sign in to comment.