diff --git a/docs/cudf/source/basics.rst b/docs/cudf/source/basics.rst index e270708df90..15b4b43662b 100644 --- a/docs/cudf/source/basics.rst +++ b/docs/cudf/source/basics.rst @@ -34,6 +34,8 @@ The following table lists all of cudf types. For methods requiring dtype argumen +------------------------+------------------+-------------------------------------------------------------------------------------+---------------------------------------------+ | Boolean | | np.bool_ | ``'bool'`` | +------------------------+------------------+-------------------------------------------------------------------------------------+---------------------------------------------+ +| Decimal | Decimal64Dtype | (none) | (none) | ++------------------------+------------------+-------------------------------------------------------------------------------------+---------------------------------------------+ **Note: All dtypes above are Nullable** diff --git a/python/cudf/cudf/_lib/cpp/wrappers/decimals.pxd b/python/cudf/cudf/_lib/cpp/wrappers/decimals.pxd index a73e6e0151d..9de23fb2595 100644 --- a/python/cudf/cudf/_lib/cpp/wrappers/decimals.pxd +++ b/python/cudf/cudf/_lib/cpp/wrappers/decimals.pxd @@ -2,6 +2,7 @@ from libc.stdint cimport int64_t, int32_t cdef extern from "cudf/fixed_point/fixed_point.hpp" namespace "numeric" nogil: + # cython type stub to help resolve to numeric::decimal64 ctypedef int64_t decimal64 cdef cppclass scale_type: diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py index 971d849d970..e93c5824817 100644 --- a/python/cudf/cudf/core/column/decimal.py +++ b/python/cudf/cudf/core/column/decimal.py @@ -1,24 +1,24 @@ # Copyright (c) 2021, NVIDIA CORPORATION. -import cudf +from decimal import Decimal +from typing import cast + import cupy as cp import numpy as np import pyarrow as pa -from typing import cast +from pandas.api.types import is_integer_dtype +import cudf from cudf import _lib as libcudf -from cudf.core.buffer import Buffer -from cudf.core.column import ColumnBase -from cudf.core.dtypes import Decimal64Dtype -from cudf.utils.utils import pa_mask_buffer_to_mask - -from cudf._typing import Dtype from cudf._lib.strings.convert.convert_fixed_point import ( from_decimal as cpp_from_decimal, ) -from cudf.core.column import as_column -from decimal import Decimal +from cudf._typing import Dtype +from cudf.core.buffer import Buffer +from cudf.core.column import ColumnBase, as_column +from cudf.core.dtypes import Decimal64Dtype from cudf.utils.dtypes import is_scalar +from cudf.utils.utils import pa_mask_buffer_to_mask class DecimalColumn(ColumnBase): @@ -65,17 +65,47 @@ def to_arrow(self): def binary_operator(self, op, other, reflect=False): if reflect: self, other = other, self - scale = _binop_scale(self.dtype, other.dtype, op) - output_type = Decimal64Dtype( - scale=scale, precision=Decimal64Dtype.MAX_PRECISION - ) # precision will be ignored, libcudf has no notion of precision - result = libcudf.binaryop.binaryop(self, other, op, output_type) - result.dtype.precision = _binop_precision(self.dtype, other.dtype, op) + + # Binary Arithmatics between decimal columns. `Scale` and `precision` + # are computed outside of libcudf + if op in ("add", "sub", "mul"): + scale = _binop_scale(self.dtype, other.dtype, op) + output_type = Decimal64Dtype( + scale=scale, precision=Decimal64Dtype.MAX_PRECISION + ) # precision will be ignored, libcudf has no notion of precision + result = libcudf.binaryop.binaryop(self, other, op, output_type) + result.dtype.precision = _binop_precision( + self.dtype, other.dtype, op + ) + elif op in ("eq", "lt", "gt", "le", "ge"): + if not isinstance( + other, + (DecimalColumn, cudf.core.column.NumericalColumn, cudf.Scalar), + ): + raise TypeError( + f"Operator {op} not supported between" + f"{str(type(self))} and {str(type(other))}" + ) + if isinstance( + other, cudf.core.column.NumericalColumn + ) and not is_integer_dtype(other.dtype): + raise TypeError( + f"Only decimal and integer column is supported for {op}." + ) + if isinstance(other, cudf.core.column.NumericalColumn): + other = other.as_decimal_column( + Decimal64Dtype(Decimal64Dtype.MAX_PRECISION, 0) + ) + result = libcudf.binaryop.binaryop(self, other, op, bool) return result def normalize_binop_value(self, other): if is_scalar(other) and isinstance(other, (int, np.int, Decimal)): return cudf.Scalar(Decimal(other)) + elif isinstance(other, cudf.Scalar) and isinstance( + other.dtype, cudf.Decimal64Dtype + ): + return other else: raise TypeError(f"cannot normalize {type(other)}") diff --git a/python/cudf/cudf/core/column/numerical.py b/python/cudf/cudf/core/column/numerical.py index f58a47a918c..10a9ffbfbae 100644 --- a/python/cudf/cudf/core/column/numerical.py +++ b/python/cudf/cudf/core/column/numerical.py @@ -22,6 +22,7 @@ column, string, ) +from cudf.core.dtypes import Decimal64Dtype from cudf.utils import cudautils, utils from cudf.utils.dtypes import ( min_column_type, @@ -103,11 +104,23 @@ def binary_operator( out_dtype = self.dtype else: if not ( - isinstance(rhs, (NumericalColumn, cudf.Scalar,),) + isinstance( + rhs, + ( + NumericalColumn, + cudf.Scalar, + cudf.core.column.DecimalColumn, + ), + ) or np.isscalar(rhs) ): msg = "{!r} operator not supported between {} and {}" raise TypeError(msg.format(binop, type(self), type(rhs))) + if isinstance(rhs, cudf.core.column.DecimalColumn): + lhs = self.as_decimal_column( + Decimal64Dtype(Decimal64Dtype.MAX_PRECISION, 0) + ) + return lhs.binary_operator(binop, rhs) out_dtype = np.result_type(self.dtype, rhs.dtype) if binop in ["mod", "floordiv"]: tmp = self if reflect else rhs diff --git a/python/cudf/cudf/tests/test_binops.py b/python/cudf/cudf/tests/test_binops.py index 2e2992fc524..ac80071c8e4 100644 --- a/python/cudf/cudf/tests/test_binops.py +++ b/python/cudf/cudf/tests/test_binops.py @@ -1615,6 +1615,12 @@ def test_binops_with_NA_consistent(dtype, op): assert result._column.null_count == len(data) +def _decimal_series(input, dtype): + return cudf.Series( + [x if x is None else decimal.Decimal(x) for x in input], dtype=dtype, + ) + + @pytest.mark.parametrize( "args", [ @@ -1753,26 +1759,311 @@ def test_binops_with_NA_consistent(dtype, op): ["10.0", None], cudf.Decimal64Dtype(scale=1, precision=8), ), + ( + operator.eq, + ["0.18", "0.42"], + cudf.Decimal64Dtype(scale=2, precision=3), + ["0.18", "0.21"], + cudf.Decimal64Dtype(scale=2, precision=3), + [True, False], + bool, + ), + ( + operator.eq, + ["0.18", "0.42"], + cudf.Decimal64Dtype(scale=2, precision=3), + ["0.1800", "0.2100"], + cudf.Decimal64Dtype(scale=4, precision=5), + [True, False], + bool, + ), + ( + operator.eq, + ["100", None], + cudf.Decimal64Dtype(scale=-2, precision=3), + ["100", "200"], + cudf.Decimal64Dtype(scale=-1, precision=4), + [True, None], + bool, + ), + ( + operator.lt, + ["0.18", "0.42", "1.00"], + cudf.Decimal64Dtype(scale=2, precision=3), + ["0.10", "0.87", "1.00"], + cudf.Decimal64Dtype(scale=2, precision=3), + [False, True, False], + bool, + ), + ( + operator.lt, + ["0.18", "0.42", "1.00"], + cudf.Decimal64Dtype(scale=2, precision=3), + ["0.1000", "0.8700", "1.0000"], + cudf.Decimal64Dtype(scale=4, precision=5), + [False, True, False], + bool, + ), + ( + operator.lt, + ["200", None, "100"], + cudf.Decimal64Dtype(scale=-2, precision=3), + ["100", "200", "100"], + cudf.Decimal64Dtype(scale=-1, precision=4), + [False, None, False], + bool, + ), + ( + operator.gt, + ["0.18", "0.42", "1.00"], + cudf.Decimal64Dtype(scale=2, precision=3), + ["0.10", "0.87", "1.00"], + cudf.Decimal64Dtype(scale=2, precision=3), + [True, False, False], + bool, + ), + ( + operator.gt, + ["0.18", "0.42", "1.00"], + cudf.Decimal64Dtype(scale=2, precision=3), + ["0.1000", "0.8700", "1.0000"], + cudf.Decimal64Dtype(scale=4, precision=5), + [True, False, False], + bool, + ), + ( + operator.gt, + ["300", None, "100"], + cudf.Decimal64Dtype(scale=-2, precision=3), + ["100", "200", "100"], + cudf.Decimal64Dtype(scale=-1, precision=4), + [True, None, False], + bool, + ), + ( + operator.le, + ["0.18", "0.42", "1.00"], + cudf.Decimal64Dtype(scale=2, precision=3), + ["0.10", "0.87", "1.00"], + cudf.Decimal64Dtype(scale=2, precision=3), + [False, True, True], + bool, + ), + ( + operator.le, + ["0.18", "0.42", "1.00"], + cudf.Decimal64Dtype(scale=2, precision=3), + ["0.1000", "0.8700", "1.0000"], + cudf.Decimal64Dtype(scale=4, precision=5), + [False, True, True], + bool, + ), + ( + operator.le, + ["300", None, "100"], + cudf.Decimal64Dtype(scale=-2, precision=3), + ["100", "200", "100"], + cudf.Decimal64Dtype(scale=-1, precision=4), + [False, None, True], + bool, + ), + ( + operator.ge, + ["0.18", "0.42", "1.00"], + cudf.Decimal64Dtype(scale=2, precision=3), + ["0.10", "0.87", "1.00"], + cudf.Decimal64Dtype(scale=2, precision=3), + [True, False, True], + bool, + ), + ( + operator.ge, + ["0.18", "0.42", "1.00"], + cudf.Decimal64Dtype(scale=2, precision=3), + ["0.1000", "0.8700", "1.0000"], + cudf.Decimal64Dtype(scale=4, precision=5), + [True, False, True], + bool, + ), + ( + operator.ge, + ["300", None, "100"], + cudf.Decimal64Dtype(scale=-2, precision=3), + ["100", "200", "100"], + cudf.Decimal64Dtype(scale=-1, precision=4), + [True, None, True], + bool, + ), ], ) def test_binops_decimal(args): op, lhs, l_dtype, rhs, r_dtype, expect, expect_dtype = args - def decimal_series(input, dtype): - return cudf.Series( - [x if x is None else decimal.Decimal(x) for x in input], - dtype=dtype, - ) - - a = decimal_series(lhs, l_dtype) - b = decimal_series(rhs, r_dtype) - expect = decimal_series(expect, expect_dtype) + a = _decimal_series(lhs, l_dtype) + b = _decimal_series(rhs, r_dtype) + expect = ( + _decimal_series(expect, expect_dtype) + if isinstance(expect_dtype, cudf.Decimal64Dtype) + else cudf.Series(expect, dtype=expect_dtype) + ) got = op(a, b) assert expect.dtype == got.dtype utils.assert_eq(expect, got) +@pytest.mark.parametrize( + "args", + [ + ( + operator.eq, + ["100", "41", None], + cudf.Decimal64Dtype(scale=0, precision=5), + [100, 42, 12], + cudf.Series([True, False, None], dtype=bool), + cudf.Series([True, False, None], dtype=bool), + ), + ( + operator.eq, + ["100.000", "42.001", None], + cudf.Decimal64Dtype(scale=3, precision=6), + [100, 42, 12], + cudf.Series([True, False, None], dtype=bool), + cudf.Series([True, False, None], dtype=bool), + ), + ( + operator.eq, + ["100", "40", None], + cudf.Decimal64Dtype(scale=-1, precision=3), + [100, 42, 12], + cudf.Series([True, False, None], dtype=bool), + cudf.Series([True, False, None], dtype=bool), + ), + ( + operator.lt, + ["100", "40", "28", None], + cudf.Decimal64Dtype(scale=0, precision=3), + [100, 42, 24, 12], + cudf.Series([False, True, False, None], dtype=bool), + cudf.Series([False, False, True, None], dtype=bool), + ), + ( + operator.lt, + ["100.000", "42.002", "23.999", None], + cudf.Decimal64Dtype(scale=3, precision=6), + [100, 42, 24, 12], + cudf.Series([False, False, True, None], dtype=bool), + cudf.Series([False, True, False, None], dtype=bool), + ), + ( + operator.lt, + ["100", "40", "10", None], + cudf.Decimal64Dtype(scale=-1, precision=3), + [100, 42, 8, 12], + cudf.Series([False, True, False, None], dtype=bool), + cudf.Series([False, False, True, None], dtype=bool), + ), + ( + operator.gt, + ["100", "42", "20", None], + cudf.Decimal64Dtype(scale=0, precision=3), + [100, 40, 24, 12], + cudf.Series([False, True, False, None], dtype=bool), + cudf.Series([False, False, True, None], dtype=bool), + ), + ( + operator.gt, + ["100.000", "42.002", "23.999", None], + cudf.Decimal64Dtype(scale=3, precision=6), + [100, 42, 24, 12], + cudf.Series([False, True, False, None], dtype=bool), + cudf.Series([False, False, True, None], dtype=bool), + ), + ( + operator.gt, + ["100", "40", "10", None], + cudf.Decimal64Dtype(scale=-1, precision=3), + [100, 42, 8, 12], + cudf.Series([False, False, True, None], dtype=bool), + cudf.Series([False, True, False, None], dtype=bool), + ), + ( + operator.le, + ["100", "40", "28", None], + cudf.Decimal64Dtype(scale=0, precision=3), + [100, 42, 24, 12], + cudf.Series([True, True, False, None], dtype=bool), + cudf.Series([True, False, True, None], dtype=bool), + ), + ( + operator.le, + ["100.000", "42.002", "23.999", None], + cudf.Decimal64Dtype(scale=3, precision=6), + [100, 42, 24, 12], + cudf.Series([True, False, True, None], dtype=bool), + cudf.Series([True, True, False, None], dtype=bool), + ), + ( + operator.le, + ["100", "40", "10", None], + cudf.Decimal64Dtype(scale=-1, precision=3), + [100, 42, 8, 12], + cudf.Series([True, True, False, None], dtype=bool), + cudf.Series([True, False, True, None], dtype=bool), + ), + ( + operator.ge, + ["100", "42", "20", None], + cudf.Decimal64Dtype(scale=0, precision=3), + [100, 40, 24, 12], + cudf.Series([True, True, False, None], dtype=bool), + cudf.Series([True, False, True, None], dtype=bool), + ), + ( + operator.ge, + ["100.000", "42.002", "23.999", None], + cudf.Decimal64Dtype(scale=3, precision=6), + [100, 42, 24, 12], + cudf.Series([True, True, False, None], dtype=bool), + cudf.Series([True, False, True, None], dtype=bool), + ), + ( + operator.ge, + ["100", "40", "10", None], + cudf.Decimal64Dtype(scale=-1, precision=3), + [100, 42, 8, 12], + cudf.Series([True, False, True, None], dtype=bool), + cudf.Series([True, True, False, None], dtype=bool), + ), + ], +) +@pytest.mark.parametrize("integer_dtype", cudf.tests.utils.INTEGER_TYPES) +@pytest.mark.parametrize("reflected", [True, False]) +def test_binops_decimal_comp_mixed_integer(args, integer_dtype, reflected): + """ + Tested compare operations: + eq, lt, gt, le, ge + Each operation has 3 decimal data setups, with scale from {==0, >0, <0}. + Decimal precisions are sufficient to hold the digits. + For each decimal data setup, there is at least one row that lead to one + of the following compare results: {True, False, None}. + """ + if not reflected: + op, ldata, ldtype, rdata, expected, _ = args + else: + op, ldata, ldtype, rdata, _, expected = args + + lhs = _decimal_series(ldata, ldtype) + rhs = cudf.Series(rdata, dtype=integer_dtype) + + if reflected: + rhs, lhs = lhs, rhs + + actual = op(lhs, rhs) + + utils.assert_eq(expected, actual) + + @pytest.mark.parametrize( "args", [ @@ -1803,6 +2094,15 @@ def decimal_series(input, dtype): cudf.Decimal64Dtype(scale=1, precision=7), False, ), + ( + operator.add, + ["100", "200"], + cudf.Decimal64Dtype(scale=-2, precision=3), + cudf.Scalar(decimal.Decimal("1.5")), + ["101.5", "201.5"], + cudf.Decimal64Dtype(scale=1, precision=7), + False, + ), ( operator.add, ["100", "200"], @@ -1830,6 +2130,15 @@ def decimal_series(input, dtype): cudf.Decimal64Dtype(scale=1, precision=7), True, ), + ( + operator.add, + ["100", "200"], + cudf.Decimal64Dtype(scale=-2, precision=3), + cudf.Scalar(decimal.Decimal("1.5")), + ["101.5", "201.5"], + cudf.Decimal64Dtype(scale=1, precision=7), + True, + ), ( operator.mul, ["100", "200"], @@ -1857,6 +2166,15 @@ def decimal_series(input, dtype): cudf.Decimal64Dtype(scale=-1, precision=6), False, ), + ( + operator.mul, + ["100", "200"], + cudf.Decimal64Dtype(scale=-2, precision=3), + cudf.Scalar(decimal.Decimal("1.5")), + ["150", "300"], + cudf.Decimal64Dtype(scale=-1, precision=6), + False, + ), ( operator.mul, ["100", "200"], @@ -1884,6 +2202,15 @@ def decimal_series(input, dtype): cudf.Decimal64Dtype(scale=-1, precision=6), True, ), + ( + operator.mul, + ["100", "200"], + cudf.Decimal64Dtype(scale=-2, precision=3), + cudf.Scalar(decimal.Decimal("1.5")), + ["150", "300"], + cudf.Decimal64Dtype(scale=-1, precision=6), + True, + ), ( operator.sub, ["100", "200"], @@ -1911,6 +2238,15 @@ def decimal_series(input, dtype): cudf.Decimal64Dtype(scale=0, precision=6), False, ), + ( + operator.sub, + ["100", "200"], + cudf.Decimal64Dtype(scale=-2, precision=3), + cudf.Scalar(decimal.Decimal("2.5")), + ["97.5", "197.5"], + cudf.Decimal64Dtype(scale=1, precision=7), + False, + ), ( operator.sub, ["100", "200"], @@ -1938,6 +2274,15 @@ def decimal_series(input, dtype): cudf.Decimal64Dtype(scale=1, precision=7), True, ), + ( + operator.sub, + ["100", "200"], + cudf.Decimal64Dtype(scale=-2, precision=3), + cudf.Scalar(decimal.Decimal("2.5")), + ["-97.5", "-197.5"], + cudf.Decimal64Dtype(scale=1, precision=7), + True, + ), ], ) def test_binops_decimal_scalar(args): @@ -1960,6 +2305,157 @@ def decimal_series(input, dtype): utils.assert_eq(expect, got) +@pytest.mark.parametrize( + "args", + [ + ( + operator.eq, + ["100.00", "41", None], + cudf.Decimal64Dtype(scale=0, precision=5), + 100, + cudf.Series([True, False, None], dtype=bool), + cudf.Series([True, False, None], dtype=bool), + ), + ( + operator.eq, + ["100.123", "41", None], + cudf.Decimal64Dtype(scale=3, precision=6), + decimal.Decimal("100.123"), + cudf.Series([True, False, None], dtype=bool), + cudf.Series([True, False, None], dtype=bool), + ), + ( + operator.eq, + ["100.123", "41", None], + cudf.Decimal64Dtype(scale=3, precision=6), + cudf.Scalar(decimal.Decimal("100.123")), + cudf.Series([True, False, None], dtype=bool), + cudf.Series([True, False, None], dtype=bool), + ), + ( + operator.gt, + ["100.00", "41", "120.21", None], + cudf.Decimal64Dtype(scale=2, precision=5), + 100, + cudf.Series([False, False, True, None], dtype=bool), + cudf.Series([False, True, False, None], dtype=bool), + ), + ( + operator.gt, + ["100.123", "41", "120.21", None], + cudf.Decimal64Dtype(scale=3, precision=6), + decimal.Decimal("100.123"), + cudf.Series([False, False, True, None], dtype=bool), + cudf.Series([False, True, False, None], dtype=bool), + ), + ( + operator.gt, + ["100.123", "41", "120.21", None], + cudf.Decimal64Dtype(scale=3, precision=6), + cudf.Scalar(decimal.Decimal("100.123")), + cudf.Series([False, False, True, None], dtype=bool), + cudf.Series([False, True, False, None], dtype=bool), + ), + ( + operator.ge, + ["100.00", "41", "120.21", None], + cudf.Decimal64Dtype(scale=2, precision=5), + 100, + cudf.Series([True, False, True, None], dtype=bool), + cudf.Series([True, True, False, None], dtype=bool), + ), + ( + operator.ge, + ["100.123", "41", "120.21", None], + cudf.Decimal64Dtype(scale=3, precision=6), + decimal.Decimal("100.123"), + cudf.Series([True, False, True, None], dtype=bool), + cudf.Series([True, True, False, None], dtype=bool), + ), + ( + operator.ge, + ["100.123", "41", "120.21", None], + cudf.Decimal64Dtype(scale=3, precision=6), + cudf.Scalar(decimal.Decimal("100.123")), + cudf.Series([True, False, True, None], dtype=bool), + cudf.Series([True, True, False, None], dtype=bool), + ), + ( + operator.lt, + ["100.00", "41", "120.21", None], + cudf.Decimal64Dtype(scale=2, precision=5), + 100, + cudf.Series([False, True, False, None], dtype=bool), + cudf.Series([False, False, True, None], dtype=bool), + ), + ( + operator.lt, + ["100.123", "41", "120.21", None], + cudf.Decimal64Dtype(scale=3, precision=6), + decimal.Decimal("100.123"), + cudf.Series([False, True, False, None], dtype=bool), + cudf.Series([False, False, True, None], dtype=bool), + ), + ( + operator.lt, + ["100.123", "41", "120.21", None], + cudf.Decimal64Dtype(scale=3, precision=6), + cudf.Scalar(decimal.Decimal("100.123")), + cudf.Series([False, True, False, None], dtype=bool), + cudf.Series([False, False, True, None], dtype=bool), + ), + ( + operator.le, + ["100.00", "41", "120.21", None], + cudf.Decimal64Dtype(scale=2, precision=5), + 100, + cudf.Series([True, True, False, None], dtype=bool), + cudf.Series([True, False, True, None], dtype=bool), + ), + ( + operator.le, + ["100.123", "41", "120.21", None], + cudf.Decimal64Dtype(scale=3, precision=6), + decimal.Decimal("100.123"), + cudf.Series([True, True, False, None], dtype=bool), + cudf.Series([True, False, True, None], dtype=bool), + ), + ( + operator.le, + ["100.123", "41", "120.21", None], + cudf.Decimal64Dtype(scale=3, precision=6), + cudf.Scalar(decimal.Decimal("100.123")), + cudf.Series([True, True, False, None], dtype=bool), + cudf.Series([True, False, True, None], dtype=bool), + ), + ], +) +@pytest.mark.parametrize("reflected", [True, False]) +def test_binops_decimal_scalar_compare(args, reflected): + """ + Tested compare operations: + eq, lt, gt, le, ge + Each operation has 3 data setups: pyints, Decimal, and + decimal cudf.Scalar + For each data setup, there is at least one row that lead to one of the + following compare results: {True, False, None}. + """ + if not reflected: + op, ldata, ldtype, rdata, expected, _ = args + else: + op, ldata, ldtype, rdata, _, expected = args + + lhs = _decimal_series(ldata, ldtype) + rhs = rdata + + if reflected: + rhs, lhs = lhs, rhs + + actual = op(lhs, rhs) + + utils.assert_eq(expected, actual) + + @pytest.mark.parametrize( "dtype", [