Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add decimal column handling in copy_type_metadata #7788

Merged
merged 12 commits into from
Apr 1, 2021
Merged
7 changes: 6 additions & 1 deletion python/cudf/cudf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@
merge,
)
from cudf.core.algorithms import factorize
from cudf.core.dtypes import CategoricalDtype, Decimal64Dtype
from cudf.core.dtypes import (
CategoricalDtype,
Decimal64Dtype,
ListDtype,
StructDtype,
)
from cudf.core.groupby import Grouper
from cudf.core.ops import (
add,
Expand Down
7 changes: 7 additions & 0 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -1402,6 +1402,8 @@ def _copy_type_metadata(self: T, other: ColumnBase) -> ColumnBase:
of `other` and the categories of `self`.
* when both `self` and `other` are StructColumns, rename the fields
of `other` to the field names of `self`.
* when both `self` and `other` are DecimalColumns, copy the precision
from self.dtype to other.dtype
* when `self` and `other` are nested columns of the same type,
recursively apply this function on the children of `self` to the
and the children of `other`.
Expand All @@ -1425,6 +1427,11 @@ def _copy_type_metadata(self: T, other: ColumnBase) -> ColumnBase:
):
other = other._rename_fields(self.dtype.fields.keys())

if isinstance(other, cudf.core.column.DecimalColumn) and isinstance(
self, cudf.core.column.DecimalColumn
):
other.dtype.precision = self.dtype.precision

if type(self) is type(other):
if self.base_children and other.base_children:
base_children = tuple(
Expand Down
19 changes: 10 additions & 9 deletions python/cudf/cudf/core/column/decimal.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
# Copyright (c) 2021, NVIDIA CORPORATION.

import cudf
from typing import cast

import cupy as cp
import numpy as np
import pyarrow as pa
from pandas.api.types import is_integer_dtype
from typing import cast

import cudf
from cudf import _lib as libcudf
from cudf.core.buffer import Buffer
from cudf.core.column import ColumnBase
from cudf.core.dtypes import Decimal64Dtype
from cudf.utils.utils import pa_mask_buffer_to_mask

from cudf._typing import Dtype
from cudf._lib.strings.convert.convert_fixed_point import (
from_decimal as cpp_from_decimal,
)
from cudf.core.column import as_column
from cudf._typing import Dtype
from cudf.core.buffer import Buffer
from cudf.core.column import ColumnBase, as_column
from cudf.core.dtypes import Decimal64Dtype
from cudf.utils.utils import pa_mask_buffer_to_mask


class DecimalColumn(ColumnBase):
dtype: Decimal64Dtype

@classmethod
def from_arrow(cls, data: pa.Array):
dtype = Decimal64Dtype.from_arrow(data.type)
Expand Down
12 changes: 8 additions & 4 deletions python/cudf/cudf/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
from cudf._typing import Dtype


class CategoricalDtype(ExtensionDtype):
class _BaseDtype(ExtensionDtype):
pass


class CategoricalDtype(_BaseDtype):

ordered: Optional[bool]

Expand Down Expand Up @@ -121,7 +125,7 @@ def deserialize(cls, header, frames):
return cls(categories=categories, ordered=ordered)


class ListDtype(ExtensionDtype):
class ListDtype(_BaseDtype):
_typ: pa.ListType
name: str = "list"

Expand Down Expand Up @@ -180,7 +184,7 @@ def __hash__(self):
return hash(self._typ)


class StructDtype(ExtensionDtype):
class StructDtype(_BaseDtype):

name = "struct"

Expand Down Expand Up @@ -231,7 +235,7 @@ def __hash__(self):
return hash(self._typ)


class Decimal64Dtype(ExtensionDtype):
class Decimal64Dtype(_BaseDtype):

name = "decimal"
_metadata = ("precision", "scale")
Expand Down
11 changes: 11 additions & 0 deletions python/cudf/cudf/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1401,3 +1401,14 @@ def test_iloc_before_zero_terminate(arg, pobj):
gobj = cudf.from_pandas(pobj)

assert_eq(pobj.iloc[arg], gobj.iloc[arg])


def test_iloc_decimal():
sr = cudf.Series(["1.00", "2.00", "3.00", "4.00"]).astype(
cudf.Decimal64Dtype(scale=2, precision=3)
)
got = sr.iloc[[3, 2, 1, 0]]
expect = cudf.Series(["4.00", "3.00", "2.00", "1.00"],).astype(
cudf.Decimal64Dtype(scale=2, precision=3)
)
assert_eq(expect.reset_index(drop=True), got.reset_index(drop=True))
7 changes: 7 additions & 0 deletions python/cudf/cudf/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,13 @@ def assert_eq(left, right, **kwargs):
without switching between assert_frame_equal/assert_series_equal/...
functions.
"""
# dtypes that we support but Pandas doesn't will convert to
# `object`. Check equality before that happens:
if kwargs.get("check_dtype", True):
if hasattr(left, "dtype") and hasattr(right, "dtype"):
if isinstance(left.dtype, cudf.core.dtypes._BaseDtype):
assert_eq(left.dtype, right.dtype)

if hasattr(left, "to_pandas"):
left = left.to_pandas()
if hasattr(right, "to_pandas"):
Expand Down