diff --git a/go/vt/vtgate/evalengine/arithmetic_test.go b/go/vt/vtgate/evalengine/arithmetic_test.go index 781e5d5de82..a23ee04a8c4 100644 --- a/go/vt/vtgate/evalengine/arithmetic_test.go +++ b/go/vt/vtgate/evalengine/arithmetic_test.go @@ -24,6 +24,8 @@ import ( "strconv" "testing" + "github.com/stretchr/testify/require" + "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" @@ -31,456 +33,401 @@ import ( "vitess.io/vitess/go/vt/vterrors" ) -func TestDivide(t *testing.T) { - tcases := []struct { - v1, v2 sqltypes.Value - out sqltypes.Value - err error - }{{ - //All Nulls - v1: sqltypes.NULL, - v2: sqltypes.NULL, - out: sqltypes.NULL, - }, { - // First value null. - v1: sqltypes.NULL, - v2: sqltypes.NewInt32(1), - out: sqltypes.NULL, - }, { - // Second value null. - v1: sqltypes.NewInt32(1), - v2: sqltypes.NULL, - out: sqltypes.NULL, - }, { - // Second arg 0 - v1: sqltypes.NewInt32(5), - v2: sqltypes.NewInt32(0), - out: sqltypes.NULL, - }, { - // Both arguments zero - v1: sqltypes.NewInt32(0), - v2: sqltypes.NewInt32(0), - out: sqltypes.NULL, - }, { - // case with negative value - v1: sqltypes.NewInt64(-1), - v2: sqltypes.NewInt64(-2), - out: sqltypes.NewFloat64(0.5000), - }, { - // float64 division by zero - v1: sqltypes.NewFloat64(2), - v2: sqltypes.NewFloat64(0), - out: sqltypes.NULL, - }, { - // Lower bound for int64 - v1: sqltypes.NewInt64(math.MinInt64), - v2: sqltypes.NewInt64(1), - out: sqltypes.NewFloat64(math.MinInt64), - }, { - // upper bound for uint64 - v1: sqltypes.NewUint64(math.MaxUint64), - v2: sqltypes.NewUint64(1), - out: sqltypes.NewFloat64(math.MaxUint64), - }, { - // testing for error in types - v1: sqltypes.TestValue(querypb.Type_INT64, "1.2"), - v2: sqltypes.NewInt64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - // testing for error in types - v1: sqltypes.NewInt64(2), - v2: sqltypes.TestValue(querypb.Type_INT64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - // testing for uint/int - v1: sqltypes.NewUint64(4), - v2: sqltypes.NewInt64(5), - out: sqltypes.NewFloat64(0.8), - }, { - // testing for uint/uint - v1: sqltypes.NewUint64(1), - v2: sqltypes.NewUint64(2), - out: sqltypes.NewFloat64(0.5), - }, { - // testing for float64/int64 - v1: sqltypes.TestValue(querypb.Type_FLOAT64, "1.2"), - v2: sqltypes.NewInt64(-2), - out: sqltypes.NewFloat64(-0.6), - }, { - // testing for float64/uint64 - v1: sqltypes.TestValue(querypb.Type_FLOAT64, "1.2"), - v2: sqltypes.NewUint64(2), - out: sqltypes.NewFloat64(0.6), - }, { - // testing for overflow of float64 - v1: sqltypes.NewFloat64(math.MaxFloat64), - v2: sqltypes.NewFloat64(0.5), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT is out of range in 1.7976931348623157e+308 / 0.5"), - }} - - for _, tcase := range tcases { - got, err := Divide(tcase.v1, tcase.v2) - - if !vterrors.Equals(err, tcase.err) { - t.Errorf("%v %v %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err)) - t.Errorf("Divide(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) - } - - if tcase.err != nil { - continue - } - - if !reflect.DeepEqual(got, tcase.out) { - t.Errorf("Divide(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) - } +func TestArithmetics(t *testing.T) { + type tcase struct { + v1, v2, out sqltypes.Value + err string } -} - -func TestMultiply(t *testing.T) { - tcases := []struct { - v1, v2 sqltypes.Value - out sqltypes.Value - err error + tests := []struct { + operator string + f func(a, b sqltypes.Value) (sqltypes.Value, error) + cases []tcase }{{ - //All Nulls - v1: sqltypes.NULL, - v2: sqltypes.NULL, - out: sqltypes.NULL, - }, { - // First value null. - v1: sqltypes.NewInt32(1), - v2: sqltypes.NULL, - out: sqltypes.NULL, - }, { - // Second value null. - v1: sqltypes.NULL, - v2: sqltypes.NewInt32(1), - out: sqltypes.NULL, - }, { - // case with negative value - v1: sqltypes.NewInt64(-1), - v2: sqltypes.NewInt64(-2), - out: sqltypes.NewInt64(2), - }, { - // testing for int64 overflow with min negative value - v1: sqltypes.NewInt64(math.MinInt64), - v2: sqltypes.NewInt64(1), - out: sqltypes.NewInt64(math.MinInt64), - }, { - // testing for error in types - v1: sqltypes.TestValue(querypb.Type_INT64, "1.2"), - v2: sqltypes.NewInt64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - // testing for error in types - v1: sqltypes.NewInt64(2), - v2: sqltypes.TestValue(querypb.Type_INT64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - // testing for uint*int - v1: sqltypes.NewUint64(4), - v2: sqltypes.NewInt64(5), - out: sqltypes.NewUint64(20), - }, { - // testing for uint*uint - v1: sqltypes.NewUint64(1), - v2: sqltypes.NewUint64(2), - out: sqltypes.NewUint64(2), - }, { - // testing for float64*int64 - v1: sqltypes.TestValue(querypb.Type_FLOAT64, "1.2"), - v2: sqltypes.NewInt64(-2), - out: sqltypes.NewFloat64(-2.4), - }, { - // testing for float64*uint64 - v1: sqltypes.TestValue(querypb.Type_FLOAT64, "1.2"), - v2: sqltypes.NewUint64(2), - out: sqltypes.NewFloat64(2.4), - }, { - // testing for overflow of int64 - v1: sqltypes.NewInt64(math.MaxInt64), - v2: sqltypes.NewInt64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in 9223372036854775807 * 2"), - }, { - // testing for underflow of uint64*max.uint64 - v1: sqltypes.NewInt64(2), - v2: sqltypes.NewUint64(math.MaxUint64), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 18446744073709551615 * 2"), - }, { - v1: sqltypes.NewUint64(math.MaxUint64), - v2: sqltypes.NewUint64(1), - out: sqltypes.NewUint64(math.MaxUint64), - }, { - //Checking whether maxInt value can be passed as uint value - v1: sqltypes.NewUint64(math.MaxInt64), - v2: sqltypes.NewInt64(3), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 9223372036854775807 * 3"), + operator: "-", + f: Subtract, + cases: []tcase{{ + // All Nulls + v1: sqltypes.NULL, + v2: sqltypes.NULL, + out: sqltypes.NULL, + }, { + // First value null. + v1: sqltypes.NewInt32(1), + v2: sqltypes.NULL, + out: sqltypes.NULL, + }, { + // Second value null. + v1: sqltypes.NULL, + v2: sqltypes.NewInt32(1), + out: sqltypes.NULL, + }, { + // case with negative value + v1: sqltypes.NewInt64(-1), + v2: sqltypes.NewInt64(-2), + out: sqltypes.NewInt64(1), + }, { + // testing for int64 overflow with min negative value + v1: sqltypes.NewInt64(math.MinInt64), + v2: sqltypes.NewInt64(1), + err: "BIGINT value is out of range in -9223372036854775808 - 1", + }, { + v1: sqltypes.NewUint64(4), + v2: sqltypes.NewInt64(5), + err: "BIGINT UNSIGNED value is out of range in 4 - 5", + }, { + // testing uint - int + v1: sqltypes.NewUint64(7), + v2: sqltypes.NewInt64(5), + out: sqltypes.NewUint64(2), + }, { + v1: sqltypes.NewUint64(math.MaxUint64), + v2: sqltypes.NewInt64(0), + out: sqltypes.NewUint64(math.MaxUint64), + }, { + // testing for int64 overflow + v1: sqltypes.NewInt64(math.MinInt64), + v2: sqltypes.NewUint64(0), + err: "BIGINT UNSIGNED value is out of range in -9223372036854775808 - 0", + }, { + v1: sqltypes.TestValue(querypb.Type_VARCHAR, "c"), + v2: sqltypes.NewInt64(1), + out: sqltypes.NewInt64(-1), + }, { + v1: sqltypes.NewUint64(1), + v2: sqltypes.TestValue(querypb.Type_VARCHAR, "c"), + out: sqltypes.NewUint64(1), + }, { + // testing for error for parsing float value to uint64 + v1: sqltypes.TestValue(querypb.Type_UINT64, "1.2"), + v2: sqltypes.NewInt64(2), + err: "strconv.ParseUint: parsing \"1.2\": invalid syntax", + }, { + // testing for error for parsing float value to uint64 + v1: sqltypes.NewUint64(2), + v2: sqltypes.TestValue(querypb.Type_UINT64, "1.2"), + err: "strconv.ParseUint: parsing \"1.2\": invalid syntax", + }, { + // uint64 - uint64 + v1: sqltypes.NewUint64(8), + v2: sqltypes.NewUint64(4), + out: sqltypes.NewUint64(4), + }, { + // testing for float subtraction: float - int + v1: sqltypes.NewFloat64(1.2), + v2: sqltypes.NewInt64(2), + out: sqltypes.NewFloat64(-0.8), + }, { + // testing for float subtraction: float - uint + v1: sqltypes.NewFloat64(1.2), + v2: sqltypes.NewUint64(2), + out: sqltypes.NewFloat64(-0.8), + }, { + v1: sqltypes.NewInt64(-1), + v2: sqltypes.NewUint64(2), + err: "BIGINT UNSIGNED value is out of range in -1 - 2", + }, { + v1: sqltypes.NewInt64(2), + v2: sqltypes.NewUint64(1), + out: sqltypes.NewUint64(1), + }, { + // testing int64 - float64 method + v1: sqltypes.NewInt64(-2), + v2: sqltypes.NewFloat64(1.0), + out: sqltypes.NewFloat64(-3.0), + }, { + // testing uint64 - float64 method + v1: sqltypes.NewUint64(1), + v2: sqltypes.NewFloat64(-2.0), + out: sqltypes.NewFloat64(3.0), + }, { + // testing uint - int to return uintplusint + v1: sqltypes.NewUint64(1), + v2: sqltypes.NewInt64(-2), + out: sqltypes.NewUint64(3), + }, { + // testing for float - float + v1: sqltypes.NewFloat64(1.2), + v2: sqltypes.NewFloat64(3.2), + out: sqltypes.NewFloat64(-2), + }, { + // testing uint - uint if v2 > v1 + v1: sqltypes.NewUint64(2), + v2: sqltypes.NewUint64(4), + err: "BIGINT UNSIGNED value is out of range in 2 - 4", + }, { + // testing uint - (- int) + v1: sqltypes.NewUint64(1), + v2: sqltypes.NewInt64(-2), + out: sqltypes.NewUint64(3), + }}, + }, { + operator: "+", + f: Add, + cases: []tcase{{ + // All Nulls + v1: sqltypes.NULL, + v2: sqltypes.NULL, + out: sqltypes.NULL, + }, { + // First value null. + v1: sqltypes.NewInt32(1), + v2: sqltypes.NULL, + out: sqltypes.NULL, + }, { + // Second value null. + v1: sqltypes.NULL, + v2: sqltypes.NewInt32(1), + out: sqltypes.NULL, + }, { + // case with negatives + v1: sqltypes.NewInt64(-1), + v2: sqltypes.NewInt64(-2), + out: sqltypes.NewInt64(-3), + }, { + // testing for overflow int64, result will be unsigned int + v1: sqltypes.NewInt64(math.MaxInt64), + v2: sqltypes.NewUint64(2), + out: sqltypes.NewUint64(9223372036854775809), + }, { + v1: sqltypes.NewInt64(-2), + v2: sqltypes.NewUint64(1), + err: "BIGINT UNSIGNED value is out of range in 1 + -2", + }, { + v1: sqltypes.NewInt64(math.MaxInt64), + v2: sqltypes.NewInt64(-2), + out: sqltypes.NewInt64(9223372036854775805), + }, { + // Normal case + v1: sqltypes.NewUint64(1), + v2: sqltypes.NewUint64(2), + out: sqltypes.NewUint64(3), + }, { + // testing for overflow uint64 + v1: sqltypes.NewUint64(math.MaxUint64), + v2: sqltypes.NewUint64(2), + err: "BIGINT UNSIGNED value is out of range in 18446744073709551615 + 2", + }, { + // int64 underflow + v1: sqltypes.NewInt64(math.MinInt64), + v2: sqltypes.NewInt64(-2), + err: "BIGINT value is out of range in -9223372036854775808 + -2", + }, { + // checking int64 max value can be returned + v1: sqltypes.NewInt64(math.MaxInt64), + v2: sqltypes.NewUint64(0), + out: sqltypes.NewUint64(9223372036854775807), + }, { + // testing whether uint64 max value can be returned + v1: sqltypes.NewUint64(math.MaxUint64), + v2: sqltypes.NewInt64(0), + out: sqltypes.NewUint64(math.MaxUint64), + }, { + v1: sqltypes.NewUint64(math.MaxInt64), + v2: sqltypes.NewInt64(1), + out: sqltypes.NewUint64(9223372036854775808), + }, { + v1: sqltypes.NewUint64(1), + v2: sqltypes.TestValue(querypb.Type_VARCHAR, "c"), + out: sqltypes.NewUint64(1), + }, { + v1: sqltypes.NewUint64(1), + v2: sqltypes.TestValue(querypb.Type_VARCHAR, "1.2"), + out: sqltypes.NewFloat64(2.2), + }, { + v1: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + v2: sqltypes.NewInt64(2), + err: "strconv.ParseInt: parsing \"1.2\": invalid syntax", + }, { + v1: sqltypes.NewInt64(2), + v2: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + err: "strconv.ParseInt: parsing \"1.2\": invalid syntax", + }, { + // testing for uint64 overflow with max uint64 + int value + v1: sqltypes.NewUint64(math.MaxUint64), + v2: sqltypes.NewInt64(2), + err: "BIGINT UNSIGNED value is out of range in 18446744073709551615 + 2", + }}, + }, { + operator: "/", + f: Divide, + cases: []tcase{{ + //All Nulls + v1: sqltypes.NULL, + v2: sqltypes.NULL, + out: sqltypes.NULL, + }, { + // First value null. + v1: sqltypes.NULL, + v2: sqltypes.NewInt32(1), + out: sqltypes.NULL, + }, { + // Second value null. + v1: sqltypes.NewInt32(1), + v2: sqltypes.NULL, + out: sqltypes.NULL, + }, { + // Second arg 0 + v1: sqltypes.NewInt32(5), + v2: sqltypes.NewInt32(0), + out: sqltypes.NULL, + }, { + // Both arguments zero + v1: sqltypes.NewInt32(0), + v2: sqltypes.NewInt32(0), + out: sqltypes.NULL, + }, { + // case with negative value + v1: sqltypes.NewInt64(-1), + v2: sqltypes.NewInt64(-2), + out: sqltypes.NewFloat64(0.5000), + }, { + // float64 division by zero + v1: sqltypes.NewFloat64(2), + v2: sqltypes.NewFloat64(0), + out: sqltypes.NULL, + }, { + // Lower bound for int64 + v1: sqltypes.NewInt64(math.MinInt64), + v2: sqltypes.NewInt64(1), + out: sqltypes.NewFloat64(math.MinInt64), + }, { + // upper bound for uint64 + v1: sqltypes.NewUint64(math.MaxUint64), + v2: sqltypes.NewUint64(1), + out: sqltypes.NewFloat64(math.MaxUint64), + }, { + // testing for error in types + v1: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + v2: sqltypes.NewInt64(2), + err: "strconv.ParseInt: parsing \"1.2\": invalid syntax", + }, { + // testing for error in types + v1: sqltypes.NewInt64(2), + v2: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + err: "strconv.ParseInt: parsing \"1.2\": invalid syntax", + }, { + // testing for uint/int + v1: sqltypes.NewUint64(4), + v2: sqltypes.NewInt64(5), + out: sqltypes.NewFloat64(0.8), + }, { + // testing for uint/uint + v1: sqltypes.NewUint64(1), + v2: sqltypes.NewUint64(2), + out: sqltypes.NewFloat64(0.5), + }, { + // testing for float64/int64 + v1: sqltypes.TestValue(querypb.Type_FLOAT64, "1.2"), + v2: sqltypes.NewInt64(-2), + out: sqltypes.NewFloat64(-0.6), + }, { + // testing for float64/uint64 + v1: sqltypes.TestValue(querypb.Type_FLOAT64, "1.2"), + v2: sqltypes.NewUint64(2), + out: sqltypes.NewFloat64(0.6), + }, { + // testing for overflow of float64 + v1: sqltypes.NewFloat64(math.MaxFloat64), + v2: sqltypes.NewFloat64(0.5), + err: "BIGINT is out of range in 1.7976931348623157e+308 / 0.5", + }}, + }, { + operator: "*", + f: Multiply, + cases: []tcase{{ + //All Nulls + v1: sqltypes.NULL, + v2: sqltypes.NULL, + out: sqltypes.NULL, + }, { + // First value null. + v1: sqltypes.NewInt32(1), + v2: sqltypes.NULL, + out: sqltypes.NULL, + }, { + // Second value null. + v1: sqltypes.NULL, + v2: sqltypes.NewInt32(1), + out: sqltypes.NULL, + }, { + // case with negative value + v1: sqltypes.NewInt64(-1), + v2: sqltypes.NewInt64(-2), + out: sqltypes.NewInt64(2), + }, { + // testing for int64 overflow with min negative value + v1: sqltypes.NewInt64(math.MinInt64), + v2: sqltypes.NewInt64(1), + out: sqltypes.NewInt64(math.MinInt64), + }, { + // testing for error in types + v1: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + v2: sqltypes.NewInt64(2), + err: "strconv.ParseInt: parsing \"1.2\": invalid syntax", + }, { + // testing for error in types + v1: sqltypes.NewInt64(2), + v2: sqltypes.TestValue(querypb.Type_INT64, "1.2"), + err: "strconv.ParseInt: parsing \"1.2\": invalid syntax", + }, { + // testing for uint*int + v1: sqltypes.NewUint64(4), + v2: sqltypes.NewInt64(5), + out: sqltypes.NewUint64(20), + }, { + // testing for uint*uint + v1: sqltypes.NewUint64(1), + v2: sqltypes.NewUint64(2), + out: sqltypes.NewUint64(2), + }, { + // testing for float64*int64 + v1: sqltypes.TestValue(querypb.Type_FLOAT64, "1.2"), + v2: sqltypes.NewInt64(-2), + out: sqltypes.NewFloat64(-2.4), + }, { + // testing for float64*uint64 + v1: sqltypes.TestValue(querypb.Type_FLOAT64, "1.2"), + v2: sqltypes.NewUint64(2), + out: sqltypes.NewFloat64(2.4), + }, { + // testing for overflow of int64 + v1: sqltypes.NewInt64(math.MaxInt64), + v2: sqltypes.NewInt64(2), + err: "BIGINT value is out of range in 9223372036854775807 * 2", + }, { + // testing for underflow of uint64*max.uint64 + v1: sqltypes.NewInt64(2), + v2: sqltypes.NewUint64(math.MaxUint64), + err: "BIGINT UNSIGNED value is out of range in 18446744073709551615 * 2", + }, { + v1: sqltypes.NewUint64(math.MaxUint64), + v2: sqltypes.NewUint64(1), + out: sqltypes.NewUint64(math.MaxUint64), + }, { + //Checking whether maxInt value can be passed as uint value + v1: sqltypes.NewUint64(math.MaxInt64), + v2: sqltypes.NewInt64(3), + err: "BIGINT UNSIGNED value is out of range in 9223372036854775807 * 3", + }}, }} - for _, tcase := range tcases { - - got, err := Multiply(tcase.v1, tcase.v2) - - if !vterrors.Equals(err, tcase.err) { - t.Errorf("Multiply(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } - - if !reflect.DeepEqual(got, tcase.out) { - t.Errorf("Multiply(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) - } + for _, test := range tests { + t.Run(test.operator, func(t *testing.T) { + for _, tcase := range test.cases { + name := fmt.Sprintf("%s%s%s", tcase.v1.String(), test.operator, tcase.v2.String()) + t.Run(name, func(t *testing.T) { + got, err := test.f(tcase.v1, tcase.v2) + if tcase.err == "" { + require.NoError(t, err) + require.Equal(t, tcase.out, got) + } else { + require.EqualError(t, err, tcase.err) + } + }) + } + }) } - -} - -func TestSubtract(t *testing.T) { - tcases := []struct { - v1, v2 sqltypes.Value - out sqltypes.Value - err error - }{{ - // All Nulls - v1: sqltypes.NULL, - v2: sqltypes.NULL, - out: sqltypes.NULL, - }, { - // First value null. - v1: sqltypes.NewInt32(1), - v2: sqltypes.NULL, - out: sqltypes.NULL, - }, { - // Second value null. - v1: sqltypes.NULL, - v2: sqltypes.NewInt32(1), - out: sqltypes.NULL, - }, { - // case with negative value - v1: sqltypes.NewInt64(-1), - v2: sqltypes.NewInt64(-2), - out: sqltypes.NewInt64(1), - }, { - // testing for int64 overflow with min negative value - v1: sqltypes.NewInt64(math.MinInt64), - v2: sqltypes.NewInt64(1), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in -9223372036854775808 - 1"), - }, { - v1: sqltypes.NewUint64(4), - v2: sqltypes.NewInt64(5), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 4 - 5"), - }, { - // testing uint - int - v1: sqltypes.NewUint64(7), - v2: sqltypes.NewInt64(5), - out: sqltypes.NewUint64(2), - }, { - v1: sqltypes.NewUint64(math.MaxUint64), - v2: sqltypes.NewInt64(0), - out: sqltypes.NewUint64(math.MaxUint64), - }, { - // testing for int64 overflow - v1: sqltypes.NewInt64(math.MinInt64), - v2: sqltypes.NewUint64(0), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in -9223372036854775808 - 0"), - }, { - v1: sqltypes.TestValue(querypb.Type_VARCHAR, "c"), - v2: sqltypes.NewInt64(1), - out: sqltypes.NewInt64(-1), - }, { - v1: sqltypes.NewUint64(1), - v2: sqltypes.TestValue(querypb.Type_VARCHAR, "c"), - out: sqltypes.NewUint64(1), - }, { - // testing for error for parsing float value to uint64 - v1: sqltypes.TestValue(querypb.Type_UINT64, "1.2"), - v2: sqltypes.NewInt64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseUint: parsing \"1.2\": invalid syntax"), - }, { - // testing for error for parsing float value to uint64 - v1: sqltypes.NewUint64(2), - v2: sqltypes.TestValue(querypb.Type_UINT64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseUint: parsing \"1.2\": invalid syntax"), - }, { - // uint64 - uint64 - v1: sqltypes.NewUint64(8), - v2: sqltypes.NewUint64(4), - out: sqltypes.NewUint64(4), - }, { - // testing for float subtraction: float - int - v1: sqltypes.NewFloat64(1.2), - v2: sqltypes.NewInt64(2), - out: sqltypes.NewFloat64(-0.8), - }, { - // testing for float subtraction: float - uint - v1: sqltypes.NewFloat64(1.2), - v2: sqltypes.NewUint64(2), - out: sqltypes.NewFloat64(-0.8), - }, { - v1: sqltypes.NewInt64(-1), - v2: sqltypes.NewUint64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in -1 - 2"), - }, { - v1: sqltypes.NewInt64(2), - v2: sqltypes.NewUint64(1), - out: sqltypes.NewUint64(1), - }, { - // testing int64 - float64 method - v1: sqltypes.NewInt64(-2), - v2: sqltypes.NewFloat64(1.0), - out: sqltypes.NewFloat64(-3.0), - }, { - // testing uint64 - float64 method - v1: sqltypes.NewUint64(1), - v2: sqltypes.NewFloat64(-2.0), - out: sqltypes.NewFloat64(3.0), - }, { - // testing uint - int to return uintplusint - v1: sqltypes.NewUint64(1), - v2: sqltypes.NewInt64(-2), - out: sqltypes.NewUint64(3), - }, { - // testing for float - float - v1: sqltypes.NewFloat64(1.2), - v2: sqltypes.NewFloat64(3.2), - out: sqltypes.NewFloat64(-2), - }, { - // testing uint - uint if v2 > v1 - v1: sqltypes.NewUint64(2), - v2: sqltypes.NewUint64(4), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 2 - 4"), - }, { - // testing uint - (- int) - v1: sqltypes.NewUint64(1), - v2: sqltypes.NewInt64(-2), - out: sqltypes.NewUint64(3), - }} - - for _, tcase := range tcases { - - got, err := Subtract(tcase.v1, tcase.v2) - - if !vterrors.Equals(err, tcase.err) { - t.Errorf("Subtract(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } - - if !reflect.DeepEqual(got, tcase.out) { - t.Errorf("Subtract(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) - } - } - -} - -func TestAdd(t *testing.T) { - tcases := []struct { - v1, v2 sqltypes.Value - out sqltypes.Value - err error - }{{ - // All Nulls - v1: sqltypes.NULL, - v2: sqltypes.NULL, - out: sqltypes.NULL, - }, { - // First value null. - v1: sqltypes.NewInt32(1), - v2: sqltypes.NULL, - out: sqltypes.NULL, - }, { - // Second value null. - v1: sqltypes.NULL, - v2: sqltypes.NewInt32(1), - out: sqltypes.NULL, - }, { - // case with negatives - v1: sqltypes.NewInt64(-1), - v2: sqltypes.NewInt64(-2), - out: sqltypes.NewInt64(-3), - }, { - // testing for overflow int64, result will be unsigned int - v1: sqltypes.NewInt64(math.MaxInt64), - v2: sqltypes.NewUint64(2), - out: sqltypes.NewUint64(9223372036854775809), - }, { - v1: sqltypes.NewInt64(-2), - v2: sqltypes.NewUint64(1), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 1 + -2"), - }, { - v1: sqltypes.NewInt64(math.MaxInt64), - v2: sqltypes.NewInt64(-2), - out: sqltypes.NewInt64(9223372036854775805), - }, { - // Normal case - v1: sqltypes.NewUint64(1), - v2: sqltypes.NewUint64(2), - out: sqltypes.NewUint64(3), - }, { - // testing for overflow uint64 - v1: sqltypes.NewUint64(math.MaxUint64), - v2: sqltypes.NewUint64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 18446744073709551615 + 2"), - }, { - // int64 underflow - v1: sqltypes.NewInt64(math.MinInt64), - v2: sqltypes.NewInt64(-2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT value is out of range in -9223372036854775808 + -2"), - }, { - // checking int64 max value can be returned - v1: sqltypes.NewInt64(math.MaxInt64), - v2: sqltypes.NewUint64(0), - out: sqltypes.NewUint64(9223372036854775807), - }, { - // testing whether uint64 max value can be returned - v1: sqltypes.NewUint64(math.MaxUint64), - v2: sqltypes.NewInt64(0), - out: sqltypes.NewUint64(math.MaxUint64), - }, { - v1: sqltypes.NewUint64(math.MaxInt64), - v2: sqltypes.NewInt64(1), - out: sqltypes.NewUint64(9223372036854775808), - }, { - v1: sqltypes.NewUint64(1), - v2: sqltypes.TestValue(querypb.Type_VARCHAR, "c"), - out: sqltypes.NewUint64(1), - }, { - v1: sqltypes.NewUint64(1), - v2: sqltypes.TestValue(querypb.Type_VARCHAR, "1.2"), - out: sqltypes.NewFloat64(2.2), - }, { - v1: sqltypes.TestValue(querypb.Type_INT64, "1.2"), - v2: sqltypes.NewInt64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - v1: sqltypes.NewInt64(2), - v2: sqltypes.TestValue(querypb.Type_INT64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "strconv.ParseInt: parsing \"1.2\": invalid syntax"), - }, { - // testing for uint64 overflow with max uint64 + int value - v1: sqltypes.NewUint64(math.MaxUint64), - v2: sqltypes.NewInt64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "BIGINT UNSIGNED value is out of range in 18446744073709551615 + 2"), - }} - - for _, tcase := range tcases { - - got, err := Add(tcase.v1, tcase.v2) - - if !vterrors.Equals(err, tcase.err) { - t.Errorf("Add(%v, %v) error: %v, want %v", printValue(tcase.v1), printValue(tcase.v2), vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - continue - } - - if !reflect.DeepEqual(got, tcase.out) { - t.Errorf("Add(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) - } - } - } func TestNullsafeAdd(t *testing.T) {