diff --git a/python/cudf/cudf/_lib/pylibcudf/libcudf/lists/gather.pxd b/python/cudf/cudf/_lib/pylibcudf/libcudf/lists/gather.pxd
index 17b4c1877a6..ab7ed141365 100644
--- a/python/cudf/cudf/_lib/pylibcudf/libcudf/lists/gather.pxd
+++ b/python/cudf/cudf/_lib/pylibcudf/libcudf/lists/gather.pxd
@@ -10,6 +10,6 @@ from cudf._lib.pylibcudf.libcudf.lists.lists_column_view cimport (
 
 cdef extern from "cudf/lists/gather.hpp" namespace "cudf::lists" nogil:
     cdef unique_ptr[column] segmented_gather(
-        const lists_column_view source_column,
-        const lists_column_view gather_map_list
+        const lists_column_view& source_column,
+        const lists_column_view& gather_map_list
     ) except +
diff --git a/python/cudf/cudf/_lib/pylibcudf/lists.pxd b/python/cudf/cudf/_lib/pylibcudf/lists.pxd
index c9d0a84e8ac..c9c43751a43 100644
--- a/python/cudf/cudf/_lib/pylibcudf/lists.pxd
+++ b/python/cudf/cudf/_lib/pylibcudf/lists.pxd
@@ -25,3 +25,5 @@ cpdef Column contains_nulls(Column)
 cpdef Column index_of(Column, ColumnOrScalar, bool)
 
 cpdef Column reverse(Column)
+
+cpdef Column segmented_gather(Column, Column)
diff --git a/python/cudf/cudf/_lib/pylibcudf/lists.pyx b/python/cudf/cudf/_lib/pylibcudf/lists.pyx
index 651f1346f88..9c56f1139c6 100644
--- a/python/cudf/cudf/_lib/pylibcudf/lists.pyx
+++ b/python/cudf/cudf/_lib/pylibcudf/lists.pyx
@@ -9,6 +9,7 @@ from cudf._lib.pylibcudf.libcudf.column.column cimport column
 from cudf._lib.pylibcudf.libcudf.lists cimport (
     contains as cpp_contains,
     explode as cpp_explode,
+    gather as cpp_gather,
     reverse as cpp_reverse,
 )
 from cudf._lib.pylibcudf.libcudf.lists.combine cimport (
@@ -232,3 +233,34 @@ cpdef Column reverse(Column input):
         list_view.view(),
     ))
     return Column.from_libcudf(move(c_result))
+
+
+cpdef Column segmented_gather(Column input, Column gather_map_list):
+    """Create a column with elements gathered based on the indices in gather_map_list
+
+    For details, see :cpp:func:`segmented_gather`.
+
+    Parameters
+    ----------
+    input : Column
+        The lists column to gather elements from.
+    gather_map_list : Column
+        The lists column of indices selecting the elements to gather.
+
+    Returns
+    -------
+    Column
+        A new lists Column whose rows contain the elements of ``input``
+        selected by ``gather_map_list``.
+    """
+
+    cdef unique_ptr[column] c_result
+    cdef ListColumnView list_view1 = input.list_view()
+    cdef ListColumnView list_view2 = gather_map_list.list_view()
+
+    with nogil:
+        c_result = move(cpp_gather.segmented_gather(
+            list_view1.view(),
+            list_view2.view(),
+        ))
+    return Column.from_libcudf(move(c_result))
diff --git a/python/cudf/cudf/pylibcudf_tests/test_lists.py b/python/cudf/cudf/pylibcudf_tests/test_lists.py
index 58a1dcf8d56..0d95579acb3 100644
--- a/python/cudf/cudf/pylibcudf_tests/test_lists.py
+++ b/python/cudf/cudf/pylibcudf_tests/test_lists.py
@@ -146,3 +146,17 @@ def test_reverse(test_data):
     expect = pa.array([lst[::-1] for lst in list_column])
 
     assert_column_eq(expect, res)
+
+
+def test_segmented_gather(test_data):
+    list_column1 = test_data[0][0]
+    list_column2 = test_data[0][1]
+
+    plc_column1 = plc.interop.from_arrow(pa.array(list_column1))
+    plc_column2 = plc.interop.from_arrow(pa.array(list_column2))
+
+    res = plc.lists.segmented_gather(plc_column2, plc_column1)
+
+    expect = pa.array([[8, 9], [14], [0], [0, 0]])
+
+    assert_column_eq(expect, res)