diff --git a/go/vt/vtgate/evalengine/func.go b/go/vt/vtgate/evalengine/func.go index cbe530956a5..65c5e93f266 100644 --- a/go/vt/vtgate/evalengine/func.go +++ b/go/vt/vtgate/evalengine/func.go @@ -49,6 +49,7 @@ var builtinFunctions = map[string]builtin{ "bit_length": builtinBitLength{}, "ascii": builtinASCII{}, "repeat": builtinRepeat{}, + "conv": builtinConv{}, } var builtinFunctionsRewrite = map[string]builtinRewrite{ diff --git a/go/vt/vtgate/evalengine/integration/string_fun_test.go b/go/vt/vtgate/evalengine/integration/string_fun_test.go index af37ce03d24..e347632ff80 100644 --- a/go/vt/vtgate/evalengine/integration/string_fun_test.go +++ b/go/vt/vtgate/evalengine/integration/string_fun_test.go @@ -204,3 +204,67 @@ func TestBuiltinRepeat(t *testing.T) { } } + +func TestBuiltinConv(t *testing.T) { + var conn = mysqlconn(t) + defer conn.Close() + cases := []string{ + "++5", + "--4", + "-5.1", + "-5.9", + "0xa21 + '1'", + "0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", + "-0xa21 + '1'", + "'+10'", + "'10-9+10'", + "'+10-9+10'", + "10 + '10' + 10", + "10 + '10' - 10", + "-10", + "'10'", + "10+'10'+'10a'+X'0a'", + "10 / 10", + "X'0FFFFFFFFFFFFFF'", + "99999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", + "-99999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999999", + "1000000000000000000000000000000", + } + bases := []string{ + "-1", + "1", + "2", + "-2", + "4", + "-4", + "8", + "-8", + "10", + "-10", + "16", + "-16", + "32", + "-32", + "64", + "-64", + "128", + "-128", + "0xa", + } + + for _, num := range cases { + for _, fromBase := range bases { + for i := range bases { + toBase := bases[i] + query := fmt.Sprintf("CONV(%s, %s, %s)", num, fromBase, toBase) + compareRemoteExpr(t, conn, query) + + toBase = bases[len(bases)-1-i] + query = fmt.Sprintf("CONV(%s, %s, %s)", num, fromBase, toBase) + compareRemoteExpr(t, conn, query) + } + } + + } + +} diff --git a/go/vt/vtgate/evalengine/string.go b/go/vt/vtgate/evalengine/string.go index 668ec11cf6d..2e0bb577ec8 100644 --- a/go/vt/vtgate/evalengine/string.go +++ b/go/vt/vtgate/evalengine/string.go @@ -18,6 +18,8 @@ package evalengine import ( "bytes" + "strconv" + "strings" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" @@ -265,3 +267,132 @@ func (builtinRepeat) typeof(env *ExpressionEnv, args []Expr) (sqltypes.Type, fla return sqltypes.VarChar, f1 } + +type builtinConv struct { +} + +func (builtinConv) call(env *ExpressionEnv, args []EvalResult, result *EvalResult) { + const MaxUint = 18446744073709551615 + const MAXINT = 9223372036854775807 + const MININT = -9223372036854775808 + var fromNum uint64 + var isNeg bool + var rawString string + inarg := &args[0] + inarg2 := &args[1] + inarg3 := &args[2] + + if sqltypes.IsBinary(inarg.typeof()) { + inarg.makeUnsignedIntegral() + } + + if sqltypes.IsBinary(inarg2.typeof()) { + inarg2.makeUnsignedIntegral() + } + + if sqltypes.IsBinary(inarg3.typeof()) { + inarg3.makeUnsignedIntegral() + } + + fromBase, _ := strconv.Atoi(string(inarg2.toRawBytes())) + toBase, _ := strconv.Atoi(string(inarg3.toRawBytes())) + fromNum = 0 + + if inarg.isNull() || + (fromBase > -2 && fromBase < 2) || (toBase > -2 && toBase < 2) || + fromBase < -36 || fromBase > 36 || toBase < -36 || toBase > 36 { + result.setNull() + return + } + + rawString = string(inarg.toRawBytes()) + rawString = strings.ToLower(rawString) + + trimStr := func(s string, isNeg *bool) string { + var base uint64 + if fromBase > 0 { + base = uint64(fromBase) + } else { + base = -uint64(fromBase) + } + start := 0 + for i, c := range s { + if (c == '+' || c == '-') && i == 0 { + start++ + *isNeg = (c == '-') + continue + } + if (base <= 9 && c >= '0' && c <= rune('0'+base)) || + (base > 9 && ((c >= '0' && c <= '9') || (c >= 'a' && c <= rune('a'+base-9)))) { + continue + } else { + return s[start:i] + } + } + return s[start:] + } + + num := trimStr(rawString, &isNeg) + + if fromBase < 0 { + if isNeg { + num = "-" + num + } + if transNum, err := strconv.ParseInt(num, -fromBase, 64); err == nil { + if isNeg { + fromNum = uint64(-transNum) + } else { + fromNum = uint64(transNum) + } + } else if strings.Contains(err.Error(), "value out of range") { + if isNeg { + fromNum = uint64(-MININT) + } else { + fromNum = uint64(MAXINT) + } + } + } else { + if transNum, err := strconv.ParseUint(num, int(fromBase), 64); err == nil { + fromNum = transNum + } else if strings.Contains(err.Error(), "value out of range") { + if isNeg { + fromNum = 0 + } else { + fromNum = MaxUint + } + } + } + + var toNum string + var temp string + if toBase > 0 { + if isNeg { + temp = strconv.FormatUint(uint64(-fromNum), toBase) + } else { + temp = strconv.FormatUint(fromNum, toBase) + } + } else { + toBase = -toBase + if isNeg { + temp = strconv.FormatInt(int64(-fromNum), toBase) + } else { + temp = strconv.FormatInt(int64(fromNum), toBase) + } + } + toNum = strings.ToUpper(temp) + + inarg.makeTextualAndConvert(env.DefaultCollation) + result.setString(toNum, inarg.collation()) +} + +func (builtinConv) typeof(env *ExpressionEnv, args []Expr) (sqltypes.Type, flag) { + if len(args) != 3 { + throwArgError("CONV") + } + _, f1 := args[0].typeof(env) + _, f2 := args[1].typeof(env) + args[1].typeof(env) + args[2].typeof(env) + + return sqltypes.VarChar, f1 & f2 +}