diff --git a/python/cudf/cudf/_lib/lists.pyx b/python/cudf/cudf/_lib/lists.pyx index e93cba20f65..9bc0550bdf0 100644 --- a/python/cudf/cudf/_lib/lists.pyx +++ b/python/cudf/cudf/_lib/lists.pyx @@ -145,7 +145,10 @@ def extract_element(Column col, size_type index): return result -def contains_scalar(Column col, DeviceScalar search_key): +def contains_scalar(Column col, object py_search_key): + + cdef DeviceScalar search_key = py_search_key.device_value + cdef shared_ptr[lists_column_view] list_view = ( make_shared[lists_column_view](col.view()) ) diff --git a/python/cudf/cudf/core/column/lists.py b/python/cudf/cudf/core/column/lists.py index 364675cd035..da953df5478 100644 --- a/python/cudf/cudf/core/column/lists.py +++ b/python/cudf/cudf/core/column/lists.py @@ -237,9 +237,10 @@ def contains(self, search_key): Series([False, True, True]) dtype: bool """ + search_key = cudf.Scalar(search_key) try: res = self._return_or_inplace( - contains_scalar(self._column, search_key.device_value) + contains_scalar(self._column, search_key) ) except RuntimeError as e: if ( diff --git a/python/cudf/cudf/tests/test_contains.py b/python/cudf/cudf/tests/test_contains.py index 4737faf65a4..b669c40022e 100644 --- a/python/cudf/cudf/tests/test_contains.py +++ b/python/cudf/cudf/tests/test_contains.py @@ -1,11 +1,17 @@ from datetime import datetime as dt +import numpy as np import pandas as pd import pytest from cudf import Series from cudf.core.index import RangeIndex, as_index -from cudf.tests.utils import assert_eq +from cudf.tests.utils import ( + DATETIME_TYPES, + NUMERIC_TYPES, + TIMEDELTA_TYPES, + assert_eq, +) def cudf_date_series(start, stop, freq): @@ -72,3 +78,43 @@ def test_index_contains(values, item, expected): def test_rangeindex_contains(): assert_eq(True, 9 in RangeIndex(start=0, stop=10, name="Index")) assert_eq(False, 10 in RangeIndex(start=0, stop=10, name="Index")) + + +@pytest.mark.parametrize("dtype", NUMERIC_TYPES) +def test_lists_contains(dtype): + dtype = np.dtype(dtype) + inner_data = np.array([1, 2, 3], dtype=dtype) + + data = Series([inner_data]) + + contained_scalar = inner_data.dtype.type(2) + not_contained_scalar = inner_data.dtype.type(42) + + assert data.list.contains(contained_scalar)[0] + assert not data.list.contains(not_contained_scalar)[0] + + +@pytest.mark.parametrize("dtype", DATETIME_TYPES + TIMEDELTA_TYPES) +def test_lists_contains_datetime(dtype): + dtype = np.dtype(dtype) + inner_data = np.array([1, 2, 3]) + + unit, _ = np.datetime_data(dtype) + + data = Series([inner_data]) + + contained_scalar = inner_data.dtype.type(2) + not_contained_scalar = inner_data.dtype.type(42) + + assert data.list.contains(contained_scalar)[0] + assert not data.list.contains(not_contained_scalar)[0] + + +def test_lists_contains_bool(): + data = Series([[True, True, True]]) + + contained_scalar = True + not_contained_scalar = False + + assert data.list.contains(contained_scalar)[0] + assert not data.list.contains(not_contained_scalar)[0]