Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for binops between two differently sized DecimalDtypes #16638

Merged
merged 6 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions python/cudf/cudf/core/column/decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,15 @@ def _binaryop(self, other: ColumnBinaryOperand, op: str):
# are computed outside of libcudf
if op in {"__add__", "__sub__", "__mul__", "__div__"}:
output_type = _get_decimal_type(lhs.dtype, rhs.dtype, op)
lhs = lhs.astype(
type(output_type)(lhs.dtype.precision, lhs.dtype.scale)
)
rhs = rhs.astype(
type(output_type)(rhs.dtype.precision, rhs.dtype.scale)
)
result = libcudf.binaryop.binaryop(lhs, rhs, op, output_type)
# TODO: Why is this necessary? Why isn't the result's
# precision already set correctly based on output_type?
# libcudf doesn't support precision, so result.dtype doesn't
# maintain output_type.precision
result.dtype.precision = output_type.precision
elif op in {
"__eq__",
Expand Down Expand Up @@ -430,7 +436,11 @@ def _with_type_metadata(
return self


def _get_decimal_type(lhs_dtype, rhs_dtype, op):
def _get_decimal_type(
lhs_dtype: DecimalDtype,
rhs_dtype: DecimalDtype,
op: str,
) -> DecimalDtype:
"""
Returns the resulting decimal type after calculating
precision & scale when performing the binary operation
Expand All @@ -441,6 +451,7 @@ def _get_decimal_type(lhs_dtype, rhs_dtype, op):

# This should at some point be hooked up to libcudf's
# binary_operation_fixed_point_scale
# Note: libcudf decimal types don't have a concept of precision

p1, p2 = lhs_dtype.precision, rhs_dtype.precision
s1, s2 = lhs_dtype.scale, rhs_dtype.scale
Expand Down
10 changes: 10 additions & 0 deletions python/cudf/cudf/tests/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,3 +398,13 @@ def test_decimal_overflow():
s = cudf.Series([1, 2], dtype=cudf.Decimal128Dtype(precision=38, scale=0))
result = s * Decimal("1.0")
assert_eq(cudf.Decimal128Dtype(precision=38, scale=1), result.dtype)


def test_decimal_binop_upcast_operands():
ser1 = cudf.Series([0.51, 1.51, 2.51]).astype(cudf.Decimal64Dtype(18, 2))
ser2 = cudf.Series([0.90, 0.96, 0.99]).astype(cudf.Decimal128Dtype(19, 2))
result = ser1 + ser2
expected = cudf.Series([1.41, 2.47, 3.50]).astype(
cudf.Decimal128Dtype(20, 2)
)
assert_eq(result, expected)
Loading