Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

builtin: add format built-in function #2883

Merged
merged 15 commits into from
Mar 31, 2017
35 changes: 34 additions & 1 deletion expression/builtin_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -1660,7 +1660,40 @@ type builtinFormatSig struct {

// See https://dev.mysql.com/doc/refman/5.6/en/string-functions.html#function_format
func (b *builtinFormatSig) eval(row []types.Datum) (d types.Datum, err error) {
return d, errFunctionNotExists.GenByArgs("format")
args, err := b.evalArgs(row)
if err != nil {
return d, errors.Trace(err)
}
if args[0].IsNull() {
d.SetNull()
return
}
arg0, err := args[0].ToString()
if err != nil {
return d, errors.Trace(err)
}
arg1, err := args[1].ToString()
if err != nil {
return d, errors.Trace(err)
}
var arg2 string

if len(args) == 2 {
arg2 = "en_US"
} else if len(args) == 3 {
arg2, err = args[2].ToString()
if err != nil {
return d, errors.Trace(err)
}
}

formatString, err := mysql.GetLocaleFormatFunction(arg2)(arg0, arg1)
if err != nil {
return d, errors.Trace(err)
}

d.SetString(formatString)
return d, nil
}

type fromBase64FunctionClass struct {
Expand Down
87 changes: 87 additions & 0 deletions expression/builtin_string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1099,6 +1099,93 @@ func (s *testEvaluatorSuite) TestOct(c *C) {
c.Assert(r.IsNull(), IsTrue)
}

func (s *testEvaluatorSuite) TestFormat(c *C) {
defer testleak.AfterTest(c)()
formatCases := []struct {
number interface{}
precision interface{}
locale string
ret interface{}
}{
{12332.1234561111111111111111111111111111111111111, 4, "en_US", "12,332.1234"},
{nil, 22, "en_US", nil},
}
formatCases1 := []struct {
number interface{}
precision interface{}
ret interface{}
}{
{12332.123456, 4, "12,332.1234"},
{12332.123456, 0, "12,332"},
{12332.123456, -4, "12,332"},
{-12332.123456, 4, "-12,332.1234"},
{-12332.123456, 0, "-12,332"},
{-12332.123456, -4, "-12,332"},
{"12332.123456", "4", "12,332.1234"},
{"12332.123456A", "4", "12,332.1234"},
{"-12332.123456", "4", "-12,332.1234"},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add case that decimal part is very long
add case no decimal part

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and test the invalid input such as NULL or string and other type, check mysql behavior to confirm

{"-12332.123456A", "4", "-12,332.1234"},
{"A123345", "4", "0.0000"},
{"-A123345", "4", "0.0000"},
{"-12332.123456", "A", "-12,332"},
{"12332.123456", "A", "12,332"},
{"-12332.123456", "4A", "-12,332.1234"},
{"12332.123456", "4A", "12,332.1234"},
{"-A12332.123456", "A", "0"},
{"A12332.123456", "A", "0"},
{"-A12332.123456", "4A", "0.0000"},
{"A12332.123456", "4A", "0.0000"},
{"-.12332.123456", "4A", "-0.1233"},
{".12332.123456", "4A", "0.1233"},
{"12332.1234567890123456789012345678901", 22, "12,332.1234567890123456789012"},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, @ariesdevil

I mean you should also handle this case in your branch:

mysql> select format("abcd", 32);
ERROR 2013 (HY000): Lost connection to MySQL server during query

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

{nil, 22, nil},
}
formatCases2 := struct {
number interface{}
precision interface{}
locale string
ret interface{}
}{-12332.123456, -4, "zh_CN", nil}
formatCases3 := struct {
number interface{}
precision interface{}
locale string
ret interface{}
}{"-12332.123456", "4", "de_GE", nil}

for _, t := range formatCases {
fc := funcs[ast.Format]
f, err := fc.getFunction(datumsToConstants(types.MakeDatums(t.number, t.precision, t.locale)), s.ctx)
c.Assert(err, IsNil)
r, err := f.eval(nil)
c.Assert(err, IsNil)
c.Assert(r, testutil.DatumEquals, types.NewDatum(t.ret))
}

for _, t := range formatCases1 {
fc := funcs[ast.Format]
f, err := fc.getFunction(datumsToConstants(types.MakeDatums(t.number, t.precision)), s.ctx)
c.Assert(err, IsNil)
r, err := f.eval(nil)
c.Assert(err, IsNil)
c.Assert(r, testutil.DatumEquals, types.NewDatum(t.ret))
}

fc2 := funcs[ast.Format]
f2, err := fc2.getFunction(datumsToConstants(types.MakeDatums(formatCases2.number, formatCases2.precision, formatCases2.locale)), s.ctx)
c.Assert(err, IsNil)
r2, err := f2.eval(nil)
c.Assert(types.NewDatum(err), testutil.DatumEquals, types.NewDatum(errors.New("not implemented")))
c.Assert(r2, testutil.DatumEquals, types.NewDatum(formatCases2.ret))

fc3 := funcs[ast.Format]
f3, err := fc3.getFunction(datumsToConstants(types.MakeDatums(formatCases3.number, formatCases3.precision, formatCases3.locale)), s.ctx)
c.Assert(err, IsNil)
r3, err := f3.eval(nil)
c.Assert(types.NewDatum(err), testutil.DatumEquals, types.NewDatum(errors.New("not support for the specific locale")))
c.Assert(r3, testutil.DatumEquals, types.NewDatum(formatCases3.ret))
}

func (s *testEvaluatorSuite) TestInsert(c *C) {
tests := []struct {
args []interface{}
Expand Down
18 changes: 18 additions & 0 deletions mysql/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -398,3 +398,21 @@ var Str2SQLMode = map[string]SQLMode{
"NO_ENGINE_SUBSTITUTION": ModeNoEngineSubstitution,
"PAD_CHAR_TO_FULL_LENGTH": ModePadCharToFullLength,
}

// FormatFunc is the locale format function signature.
type FormatFunc func(string, string) (string, error)

// GetLocaleFormatFunction get the format function for sepcific locale.
func GetLocaleFormatFunction(loc string) FormatFunc {
locale, exist := locale2FormatFunction[loc]
if !exist {
return formatNotSupport
}
return locale
}

// locale2FormatFunction is the string represent of locale format function.
var locale2FormatFunction = map[string]FormatFunc{
"en_US": formatENUS,
"zh_CN": formatZHCN,
}
99 changes: 99 additions & 0 deletions mysql/locale_format.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package mysql

import (
"bytes"
"strconv"
"strings"
"unicode"

"github.com/juju/errors"
)

func formatENUS(number string, precision string) (string, error) {
var buffer bytes.Buffer
if unicode.IsDigit(rune(precision[0])) {
for i, v := range precision {
if unicode.IsDigit(v) {
continue
} else {
precision = precision[:i]
break
}
}
} else {
precision = "0"
}
if number[0] == '-' && number[1] == '.' {
number = strings.Replace(number, "-", "-0", 1)
} else if number[0] == '.' {
number = strings.Replace(number, ".", "0.", 1)
}

if (number[:1] == "-" && !unicode.IsDigit(rune(number[1]))) ||
(!unicode.IsDigit(rune(number[0])) && number[:1] != "-") {
buffer.Write([]byte{'0'})
position, err := strconv.ParseUint(precision, 10, 64)
if err == nil && position > 0 {
buffer.Write([]byte{'.'})
buffer.WriteString(strings.Repeat("0", int(position)))
}
return buffer.String(), nil
} else if number[:1] == "-" {
buffer.Write([]byte{'-'})
number = number[1:]
}

for i, v := range number {
if unicode.IsDigit(v) {
continue
} else if i == 1 && number[1] == '.' {
continue
} else if v == '.' && number[1] != '.' {
continue
} else {
number = number[:i]
break
}
}

comma := []byte{','}
parts := strings.Split(number, ".")
pos := 0
if len(parts[0])%3 != 0 {
pos += len(parts[0]) % 3
buffer.WriteString(parts[0][:pos])
buffer.Write(comma)
}
for ; pos < len(parts[0]); pos += 3 {
buffer.WriteString(parts[0][pos : pos+3])
buffer.Write(comma)
}
buffer.Truncate(buffer.Len() - 1)

position, err := strconv.ParseUint(precision, 10, 64)
if err == nil {
if position > 0 {
buffer.Write([]byte{'.'})
if len(parts) == 2 {
if uint64(len(parts[1])) >= position {
buffer.WriteString(parts[1][:position])
} else {
buffer.WriteString(parts[1])
buffer.WriteString(strings.Repeat("0", int(position)-len(parts[1])))
}
} else {
buffer.WriteString(strings.Repeat("0", int(position)))
}
}
}

return buffer.String(), nil
}

func formatZHCN(number string, precision string) (string, error) {
return "", errors.New("not implemented")
}

func formatNotSupport(number string, precision string) (string, error) {
return "", errors.New("not support for the specific locale")
}
2 changes: 1 addition & 1 deletion plan/typeinferer.go
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ func (v *typeInferrer) handleFuncCallExpr(x *ast.FuncCallExpr) {
"replace", "ucase", "upper", "convert", "substring", "elt",
"substring_index", "trim", "ltrim", "rtrim", "reverse", "hex", "unhex",
"date_format", "rpad", "lpad", "char_func", "conv", "make_set", "oct", "uuid",
"insert_func", "bin":
"insert_func", "bin", "format":
tp = types.NewFieldType(mysql.TypeVarString)
chs = v.defaultCharset
case "strcmp", "isnull", "bit_length", "char_length", "character_length", "crc32", "timestampdiff",
Expand Down
1 change: 1 addition & 0 deletions plan/typeinferer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ func (ts *testTypeInferrerSuite) TestInferType(c *C) {
{`exp('1.23')`, mysql.TypeDouble, charset.CharsetBin},
{`insert("Titanium", 3, 6, "DB")`, mysql.TypeVarString, charset.CharsetUTF8},
{`is_ipv6('FE80::AAAA:0000:00C2:0002')`, mysql.TypeLonglong, charset.CharsetBin},
{`format(12332.123456, 4)`, mysql.TypeVarString, charset.CharsetUTF8},
{"inet_ntoa(1)", mysql.TypeVarString, charset.CharsetUTF8},
{`ord('2')`, mysql.TypeLonglong, charset.CharsetBin},
{`ord(2)`, mysql.TypeLonglong, charset.CharsetBin},
Expand Down