Skip to content

Commit

Permalink
Use cudf::lists::distinct in Python binding (#11234)
Browse files Browse the repository at this point in the history
The Python binding has a `lists.unique()` API that extracts the unique list elements from the input lists column. Previously, it was implemented by calling `cudf::lists::drop_list_duplicates`, which performs a segmented sort on the input lists and then extracts the unique list elements.

This PR changes the implementation of `lists.unique()` to use `cudf::lists::distinct`, which can improve performance by using a hash table to find the distinct elements instead of performing a segmented sort.

Authors:
  - Nghia Truong (https://github.com/ttnghia)

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: #11234
  • Loading branch information
ttnghia authored Jul 13, 2022
1 parent 78f0754 commit 62c0ae8
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 16 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021, NVIDIA CORPORATION.
# Copyright (c) 2021-2022, NVIDIA CORPORATION.

from libcpp.memory cimport unique_ptr

Expand All @@ -7,9 +7,9 @@ from cudf._lib.cpp.lists.lists_column_view cimport lists_column_view
from cudf._lib.cpp.types cimport nan_equality, null_equality


cdef extern from "cudf/lists/drop_list_duplicates.hpp" \
cdef extern from "cudf/lists/stream_compaction.hpp" \
namespace "cudf::lists" nogil:
cdef unique_ptr[column] drop_list_duplicates(
cdef unique_ptr[column] distinct(
const lists_column_view lists_column,
null_equality nulls_equal,
nan_equality nans_equal
Expand Down
16 changes: 7 additions & 9 deletions python/cudf/cudf/_lib/lists.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@ from cudf._lib.cpp.lists.combine cimport (
from cudf._lib.cpp.lists.count_elements cimport (
count_elements as cpp_count_elements,
)
from cudf._lib.cpp.lists.drop_list_duplicates cimport (
drop_list_duplicates as cpp_drop_list_duplicates,
)
from cudf._lib.cpp.lists.explode cimport explode_outer as cpp_explode_outer
from cudf._lib.cpp.lists.lists_column_view cimport lists_column_view
from cudf._lib.cpp.lists.sorting cimport sort_lists as cpp_sort_lists
from cudf._lib.cpp.lists.stream_compaction cimport distinct as cpp_distinct
from cudf._lib.cpp.scalar.scalar cimport scalar
from cudf._lib.cpp.table.table cimport table
from cudf._lib.cpp.table.table_view cimport table_view
Expand Down Expand Up @@ -75,12 +73,12 @@ def explode_outer(
return columns_from_unique_ptr(move(c_result))


def drop_list_duplicates(Column col, bool nulls_equal, bool nans_all_equal):
def distinct(Column col, bool nulls_equal, bool nans_all_equal):
"""
nans_all_equal == True indicates that libcudf should treat any two elements
from {+nan, -nan} as equal, and as unequal otherwise.
nulls_equal == True indicates that libcudf should treat any two nulls as
equal, and as unequal otherwise.
nans_all_equal == True indicates that libcudf should treat any two
elements from {-nan, +nan} as equal, and as unequal otherwise.
"""
cdef shared_ptr[lists_column_view] list_view = (
make_shared[lists_column_view](col.view())
Expand All @@ -96,9 +94,9 @@ def drop_list_duplicates(Column col, bool nulls_equal, bool nans_all_equal):

with nogil:
c_result = move(
cpp_drop_list_duplicates(list_view.get()[0],
c_nulls_equal,
c_nans_equal)
cpp_distinct(list_view.get()[0],
c_nulls_equal,
c_nans_equal)
)
return Column.from_unique_ptr(move(c_result))

Expand Down
6 changes: 2 additions & 4 deletions python/cudf/cudf/core/column/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
concatenate_rows,
contains_scalar,
count_elements,
drop_list_duplicates,
distinct,
extract_element_column,
extract_element_scalar,
index_of_column,
Expand Down Expand Up @@ -603,9 +603,7 @@ def unique(self) -> ParentType:
raise NotImplementedError("Nested lists unique is not supported.")

return self._return_or_inplace(
drop_list_duplicates(
self._column, nulls_equal=True, nans_all_equal=True
)
distinct(self._column, nulls_equal=True, nans_all_equal=True)
)

def sort_values(
Expand Down

0 comments on commit 62c0ae8

Please sign in to comment.