Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Copy nested types upon construction #8244

Merged
merged 8 commits into from
May 20, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 61 additions & 3 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@
from cudf._typing import BinaryOperand, ColumnLike, Dtype, ScalarLike
from cudf.core.abc import Serializable
from cudf.core.buffer import Buffer
from cudf.core.dtypes import CategoricalDtype, IntervalDtype
from cudf.core.dtypes import (
CategoricalDtype,
IntervalDtype,
ListDtype,
StructDtype,
)
from cudf.utils import ioutils, utils
from cudf.utils.dtypes import (
check_cast_unsupported_dtype,
Expand Down Expand Up @@ -399,8 +404,7 @@ def from_arrow(cls, array: pa.Array) -> ColumnBase:
"None"
]

if isinstance(result.dtype, cudf.Decimal64Dtype):
result.dtype.precision = array.type.precision
result = _copy_type_metadata_from_arrow(array, result)
return result

def _get_mask_as_column(self) -> ColumnBase:
Expand Down Expand Up @@ -2346,3 +2350,57 @@ def full(size: int, fill_value: ScalarLike, dtype: Dtype = None) -> ColumnBase:
dtype: int8
"""
return ColumnBase.from_scalar(cudf.Scalar(fill_value, dtype), size)


def _copy_type_metadata_from_arrow(
arrow_array: pa.array, cudf_column: ColumnBase
) -> ColumnBase:
"""
Similar to `Column._copy_type_metadata`, except copies type metadata
from arrow array into a cudf column. Recursive for every level.
* When `arrow_array` is struct type and `cudf_column` is StructDtype, copy
field names.
* When `arrow_array` is decimal type and `cudf_column` is
Decimal64Dtype, copy precisions.
"""
if pa.types.is_decimal(arrow_array.type) and isinstance(
cudf_column, cudf.core.column.DecimalColumn
):
cudf_column.dtype.precision = arrow_array.type.precision
elif pa.types.is_struct(arrow_array.type) and isinstance(
cudf_column, cudf.core.column.StructColumn
):
base_children = tuple(
_copy_type_metadata_from_arrow(arrow_array.field(i), col_child)
for i, col_child in enumerate(cudf_column.base_children)
)
cudf_column.set_base_children(base_children)
return cudf.core.column.StructColumn(
data=None,
size=cudf_column.base_size,
dtype=StructDtype.from_arrow(arrow_array.type),
mask=cudf_column.base_mask,
offset=cudf_column.offset,
null_count=cudf_column.null_count,
children=base_children,
)
elif pa.types.is_list(arrow_array.type) and isinstance(
cudf_column, cudf.core.column.ListColumn
):
if arrow_array.values and cudf_column.base_children:
base_children = (
cudf_column.base_children[0],
_copy_type_metadata_from_arrow(
arrow_array.values, cudf_column.base_children[1]
),
)
return cudf.core.column.ListColumn(
size=cudf_column.base_size,
dtype=ListDtype.from_arrow(arrow_array.type),
mask=cudf_column.base_mask,
offset=cudf_column.offset,
null_count=cudf_column.null_count,
children=base_children,
)

return cudf_column
8 changes: 5 additions & 3 deletions python/cudf/cudf/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ def __init__(self, element_type: Any) -> None:
def element_type(self) -> Dtype:
if isinstance(self._typ.value_type, pa.ListType):
return ListDtype.from_arrow(self._typ.value_type)
elif isinstance(self._typ.value_type, pa.StructType):
return StructDtype.from_arrow(self._typ.value_type)
else:
return np.dtype(self._typ.value_type.to_pandas_dtype()).name

Expand Down Expand Up @@ -176,10 +178,10 @@ def __eq__(self, other):
return self._typ.equals(other._typ)

def __repr__(self):
if isinstance(self.element_type, ListDtype):
return f"ListDtype({self.element_type.__repr__()})"
if isinstance(self.element_type, (ListDtype, StructDtype)):
return f"{type(self).__name__}({self.element_type.__repr__()})"
else:
return f"ListDtype({self.element_type})"
return f"{type(self).__name__}({self.element_type})"

def __hash__(self):
return hash(self._typ)
Expand Down
106 changes: 105 additions & 1 deletion python/cudf/cudf/tests/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
import pytest

import cudf
from cudf.core.column import ColumnBase
from cudf.core.dtypes import (
CategoricalDtype,
Decimal64Dtype,
IntervalDtype,
ListDtype,
StructDtype,
IntervalDtype,
)
from cudf.tests.utils import assert_eq
from cudf.utils.dtypes import np_to_pa_dtype


def test_cdt_basic():
Expand Down Expand Up @@ -155,3 +157,105 @@ def test_interval_dtype_pyarrow_round_trip(fields, closed):
expect = pa_array
got = IntervalDtype.from_arrow(expect).to_arrow()
assert expect.equals(got)


def assert_column_array_dtype_equal(column: ColumnBase, array: pa.array):
"""
In cudf, each column holds its dtype. And since column may have child
columns, child columns also holds there datatype. This method tests
that every level of `column` matches the type of the given `array`
recursively.
"""

if isinstance(column.dtype, ListDtype):
return array.type.equals(
column.dtype.to_arrow()
) and assert_column_array_dtype_equal(
column.base_children[1], array.values
)
elif isinstance(column.dtype, StructDtype):
return array.type.equals(column.dtype.to_arrow()) and all(
[
assert_column_array_dtype_equal(child, array.field(i))
for i, child in enumerate(column.base_children)
]
)
elif isinstance(column.dtype, Decimal64Dtype):
return array.type.equals(column.dtype.to_arrow())
elif isinstance(column.dtype, CategoricalDtype):
raise NotImplementedError()
else:
return array.type.equals(np_to_pa_dtype(column.dtype))


@pytest.mark.parametrize(
"data",
[
[[{"name": 123}]],
[
[
{
"IsLeapYear": False,
"data": {"Year": 1999, "Month": 7},
"names": ["Mike", None],
},
{
"IsLeapYear": True,
"data": {"Year": 2004, "Month": 12},
"names": None,
},
{
"IsLeapYear": False,
"data": {"Year": 1996, "Month": 2},
"names": ["Rose", "Richard"],
},
]
],
[
[None, {"human?": True, "deets": {"weight": 2.4, "age": 27}}],
[
{"human?": None, "deets": {"weight": 5.3, "age": 25}},
{"human?": False, "deets": {"weight": 8.0, "age": 31}},
{"human?": False, "deets": None},
],
[],
None,
[{"human?": None, "deets": {"weight": 6.9, "age": None}}],
],
[
{
"name": "var0",
"val": [
{"name": "var1", "val": None, "type": "optional<struct>"}
],
"type": "list",
},
{},
{
"name": "var2",
"val": [
{
"name": "var3",
"val": {"field": 42},
"type": "optional<struct>",
},
{
"name": "var4",
"val": {"field": 3.14},
"type": "optional<struct>",
},
],
"type": "list",
},
None,
],
],
)
def test_lists_of_structs_dtype(data):
got = cudf.Series(data)
expected = pa.array(data)

assert_column_array_dtype_equal(got._column, expected)
# TODO: use assert_eq after
# https://github.com/pandas-dev/pandas/issues/41466 resolved
assert expected.equals(got._column.to_arrow())