diff --git a/pkg/sql/sem/tree/eval.go b/pkg/sql/sem/tree/eval.go index d2a19a3eb678..2862f93158c0 100644 --- a/pkg/sql/sem/tree/eval.go +++ b/pkg/sql/sem/tree/eval.go @@ -3213,6 +3213,18 @@ func PerformCast(ctx *EvalContext, d Datum, t *types.T) (Datum, error) { res = NewDInt(DInt(iv)) case *DOid: res = &v.DInt + case *DJSON: + if v.Type() == json.NumberJSONType { + // err is ignored as a more appropriate error + // will be generated later + dec, err := json.ToDecimal(v.JSON) + if err == nil { + asInt, err := dec.Int64() + if err == nil { + res = NewDInt(DInt(asInt)) + } + } + } } if res != nil { return res, nil @@ -3253,6 +3265,18 @@ func PerformCast(ctx *EvalContext, d Datum, t *types.T) (Datum, error) { return NewDFloat(DFloat(float64(v.UnixEpochDays()))), nil case *DInterval: return NewDFloat(DFloat(v.AsFloat64())), nil + case *DJSON: + if v.Type() == json.NumberJSONType { + // err is ignored as a more appropriate error + // will be generated later + dec, err := json.ToDecimal(v.JSON) + if err == nil { + asFloat, err := dec.Int64() + if err == nil { + return NewDFloat(DFloat(asFloat)), nil + } + } + } } case types.DecimalFamily: @@ -3301,6 +3325,15 @@ func PerformCast(ctx *EvalContext, d Datum, t *types.T) (Datum, error) { case *DInterval: v.AsBigInt(&dd.Coeff) dd.Exponent = -9 + case *DJSON: + if v.Type() == json.NumberJSONType { + var dec apd.Decimal + // err is ignored as a more appropriate error + // will be generated later + dec, err = json.ToDecimal(v.JSON) + unset = err != nil + dd = DDecimal{dec} + } default: unset = true } diff --git a/pkg/sql/sem/tree/expr.go b/pkg/sql/sem/tree/expr.go index 779502848149..85112e1be725 100644 --- a/pkg/sql/sem/tree/expr.go +++ b/pkg/sql/sem/tree/expr.go @@ -1520,11 +1520,11 @@ var ( bitArrayCastTypes = annotateCast(types.VarBit, []*types.T{types.Unknown, types.VarBit, types.Int, types.String, types.AnyCollatedString}) boolCastTypes = annotateCast(types.Bool, []*types.T{types.Unknown, types.Bool, types.Int, types.Float, types.Decimal, types.String, types.AnyCollatedString}) intCastTypes = annotateCast(types.Int, []*types.T{types.Unknown, types.Bool, types.Int, types.Float, types.Decimal, types.String, types.AnyCollatedString, - types.Timestamp, types.TimestampTZ, types.Date, types.Interval, types.Oid, types.VarBit}) + types.Timestamp, types.TimestampTZ, types.Date, types.Interval, types.Oid, types.VarBit, types.Jsonb}) floatCastTypes = annotateCast(types.Float, []*types.T{types.Unknown, types.Bool, types.Int, types.Float, types.Decimal, types.String, types.AnyCollatedString, - types.Timestamp, types.TimestampTZ, types.Date, types.Interval}) + types.Timestamp, types.TimestampTZ, types.Date, types.Interval, types.Jsonb}) decimalCastTypes = annotateCast(types.Decimal, []*types.T{types.Unknown, types.Bool, types.Int, types.Float, types.Decimal, types.String, types.AnyCollatedString, - types.Timestamp, types.TimestampTZ, types.Date, types.Interval}) + types.Timestamp, types.TimestampTZ, types.Date, types.Interval, types.Jsonb}) stringCastTypes = annotateCast(types.String, []*types.T{types.Unknown, types.Bool, types.Int, types.Float, types.Decimal, types.String, types.AnyCollatedString, types.VarBit, types.AnyArray, types.AnyTuple, diff --git a/pkg/sql/sem/tree/testdata/eval/cast b/pkg/sql/sem/tree/testdata/eval/cast index c5b2920b127a..adcf1aa67c29 100644 --- a/pkg/sql/sem/tree/testdata/eval/cast +++ b/pkg/sql/sem/tree/testdata/eval/cast @@ -977,3 +977,58 @@ eval ARRAY['hello','world']::char(2)[] ---- ARRAY['he','wo'] + +eval +'1'::jsonb::int +---- +1 + +eval +'1'::jsonb::float +---- +1.0 + +eval +'1'::jsonb::decimal +---- +1 + +eval +'1'::jsonb::string +---- +'1' + +eval +'2.0'::jsonb::int +---- +2 + +eval +'2.0'::jsonb::float +---- +2.0 + +eval +'2.0'::jsonb::decimal +---- +2.0 + +eval +'2.0'::jsonb::string +---- +'2.0' + +eval +'3.14'::jsonb::float +---- +3.14 + +eval +'3.14'::jsonb::decimal +---- +3.14 + +eval +'3.14'::jsonb::string +---- +'3.14' diff --git a/pkg/util/json/json.go b/pkg/util/json/json.go index 7ae16b478f23..f7268f31374b 100644 --- a/pkg/util/json/json.go +++ b/pkg/util/json/json.go @@ -972,6 +972,19 @@ func FromFloat64(v float64) (JSON, error) { return jsonNumber(dec), nil } +// ToDecimal returns a apd.Decimal given a JSON value +func ToDecimal(j JSON) (apd.Decimal, error) { + j, err := decodeIfNeeded(j) + if err != nil { + return apd.Decimal{}, err + } + num, ok := j.(jsonNumber) + if !ok { + return apd.Decimal{}, errors.AssertionFailedf("cannot convert JSON of type %T to decimal", j) + } + return apd.Decimal(num), nil +} + // MakeJSON returns a JSON value given a Go-style representation of JSON. // * JSON null is Go `nil`, // * JSON true is Go `true`, diff --git a/pkg/util/json/json_test.go b/pkg/util/json/json_test.go index 722573ea939f..e429aae00b58 100644 --- a/pkg/util/json/json_test.go +++ b/pkg/util/json/json_test.go @@ -2016,3 +2016,65 @@ func TestJSONRemovePath(t *testing.T) { } } } + +func TestToDecimal(t *testing.T) { + numericCases := []string{ + "1", + "1.0", + "3.14", + "-3.14", + "1.000", + "-0.0", + "-0.09", + "0.08", + } + + nonNumericCases := []struct { + input string + errMsg string + }{ + {"\"1\"", "cannot convert JSON of type json.jsonString to decimal"}, + {"{}", "cannot convert JSON of type json.jsonObject to decimal"}, + {"[]", "cannot convert JSON of type json.jsonArray to decimal"}, + {"true", "cannot convert JSON of type json.jsonTrue to decimal"}, + {"false", "cannot convert JSON of type json.jsonFalse to decimal"}, + {"null", "cannot convert JSON of type json.jsonNull to decimal"}, + } + + for _, tc := range numericCases { + t.Run(fmt.Sprintf("numeric - %s", tc), func(t *testing.T) { + dec1, _, err := apd.NewFromString(tc) + if err != nil { + t.Fatal(err) + } + + json, err := ParseJSON(tc) + if err != nil { + t.Fatal(err) + } + + dec2, err := ToDecimal(json) + if err != nil { + t.Fatal(err) + } + + if dec1.Cmp(&dec2) != 0 { + t.Fatalf("expected %s == %s", dec1.String(), dec2.String()) + } + }) + } + + for _, tc := range nonNumericCases { + t.Run(fmt.Sprintf("nonNumeric - %s", tc), func(t *testing.T) { + json, err := ParseJSON(tc.input) + if err != nil { + t.Fatalf("expected no error") + } + + _, err = ToDecimal(json) + if err.Error() != tc.errMsg { + t.Fatalf("expected %s, got %s", tc.errMsg, err.Error()) + } + }) + } +}