Skip to content

Commit

Permalink
Push DeviceScalar construction into cython for list.contains (#7864)
Browse files Browse the repository at this point in the history
  • Loading branch information
brandon-b-miller authored Apr 6, 2021
1 parent b17ed17 commit c8c00f1
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 3 deletions.
5 changes: 4 additions & 1 deletion python/cudf/cudf/_lib/lists.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,10 @@ def extract_element(Column col, size_type index):
return result


def contains_scalar(Column col, DeviceScalar search_key):
def contains_scalar(Column col, object py_search_key):

cdef DeviceScalar search_key = py_search_key.device_value

cdef shared_ptr[lists_column_view] list_view = (
make_shared[lists_column_view](col.view())
)
Expand Down
3 changes: 2 additions & 1 deletion python/cudf/cudf/core/column/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,10 @@ def contains(self, search_key):
Series([False, True, True])
dtype: bool
"""
search_key = cudf.Scalar(search_key)
try:
res = self._return_or_inplace(
contains_scalar(self._column, search_key.device_value)
contains_scalar(self._column, search_key)
)
except RuntimeError as e:
if (
Expand Down
48 changes: 47 additions & 1 deletion python/cudf/cudf/tests/test_contains.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from datetime import datetime as dt

import numpy as np
import pandas as pd
import pytest

from cudf import Series
from cudf.core.index import RangeIndex, as_index
from cudf.tests.utils import assert_eq
from cudf.tests.utils import (
DATETIME_TYPES,
NUMERIC_TYPES,
TIMEDELTA_TYPES,
assert_eq,
)


def cudf_date_series(start, stop, freq):
Expand Down Expand Up @@ -72,3 +78,43 @@ def test_index_contains(values, item, expected):
def test_rangeindex_contains():
assert_eq(True, 9 in RangeIndex(start=0, stop=10, name="Index"))
assert_eq(False, 10 in RangeIndex(start=0, stop=10, name="Index"))


@pytest.mark.parametrize("dtype", NUMERIC_TYPES)
def test_lists_contains(dtype):
dtype = np.dtype(dtype)
inner_data = np.array([1, 2, 3], dtype=dtype)

data = Series([inner_data])

contained_scalar = inner_data.dtype.type(2)
not_contained_scalar = inner_data.dtype.type(42)

assert data.list.contains(contained_scalar)[0]
assert not data.list.contains(not_contained_scalar)[0]


@pytest.mark.parametrize("dtype", DATETIME_TYPES + TIMEDELTA_TYPES)
def test_lists_contains_datetime(dtype):
dtype = np.dtype(dtype)
inner_data = np.array([1, 2, 3])

unit, _ = np.datetime_data(dtype)

data = Series([inner_data])

contained_scalar = inner_data.dtype.type(2)
not_contained_scalar = inner_data.dtype.type(42)

assert data.list.contains(contained_scalar)[0]
assert not data.list.contains(not_contained_scalar)[0]


def test_lists_contains_bool():
data = Series([[True, True, True]])

contained_scalar = True
not_contained_scalar = False

assert data.list.contains(contained_scalar)[0]
assert not data.list.contains(not_contained_scalar)[0]

0 comments on commit c8c00f1

Please sign in to comment.