diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index 1ad80ebe009..934aa11cc97 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -30,7 +30,6 @@ #include #include #include -#include #include #include #include @@ -463,9 +462,8 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_dropListDuplicates(JNIEnv JNI_NULL_CHECK(env, column_view, "column is null", 0); try { cudf::jni::auto_set_device(env); - cudf::column_view const *cv = reinterpret_cast(column_view); - cudf::lists_column_view lcv(*cv); - return release_as_jlong(cudf::lists::drop_list_duplicates(lcv)); + auto const input_cv = reinterpret_cast(column_view); + return release_as_jlong(cudf::lists::distinct(cudf::lists_column_view{*input_cv})); } CATCH_STD(env, 0); } @@ -476,59 +474,18 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_dropListDuplicatesWithKey try { cudf::jni::auto_set_device(env); auto const input_cv = reinterpret_cast(keys_vals_handle); - CUDF_EXPECTS(input_cv->offset() == 0, "Input column has non-zero offset."); - CUDF_EXPECTS(input_cv->type().id() == cudf::type_id::LIST, - "Input column is not a lists column."); + JNI_ARG_CHECK(env, input_cv->type().id() == cudf::type_id::LIST, + "Input column is not a lists column.", 0); - // Extract list offsets and a column of struct from the input lists column. auto const lists_keys_vals = cudf::lists_column_view(*input_cv); - auto const keys_vals = lists_keys_vals.get_sliced_child(cudf::default_stream_value); - CUDF_EXPECTS(keys_vals.type().id() == cudf::type_id::STRUCT, - "Input column has child that is not a structs column."); - CUDF_EXPECTS(keys_vals.num_children() == 2, - "Input column has child that does not have 2 children."); - - auto const lists_offsets = lists_keys_vals.offsets(); - auto const structs_keys_vals = cudf::structs_column_view(keys_vals); - - // Assemble a lists_column_view from the existing data (offsets + child). - // This will not copy any data, just create a view, for performance reason. - auto const make_lists_view = [&input_cv](auto const &offsets, auto const &child) { - return cudf::lists_column_view( - cudf::column_view(cudf::data_type{input_cv->type()}, input_cv->size(), nullptr, - input_cv->null_mask(), input_cv->null_count(), 0, {offsets, child})); - }; - - // Extract keys and values lists columns from the input lists of structs column. - auto const keys = make_lists_view(lists_offsets, structs_keys_vals.child(0)); - auto const vals = make_lists_view(lists_offsets, structs_keys_vals.child(1)); - - // Apache Spark desires to keep the last duplicate element. - auto [out_keys, out_vals] = - cudf::lists::drop_list_duplicates(keys, vals, cudf::duplicate_keep_option::KEEP_LAST); - - // Release the contents of the outputs. - auto out_keys_content = out_keys->release(); - auto out_vals_content = out_vals->release(); - - // Total number of elements in the child column. - // This should be the same for the out_vals column. - auto const out_child_size = - out_keys_content.children[cudf::lists_column_view::child_column_index]->size(); - - // Assemble a lists column of struct for the final output. - auto out_structs_members = std::vector>(); - out_structs_members.emplace_back( - std::move(out_keys_content.children[cudf::lists_column_view::child_column_index])); - out_structs_members.emplace_back( - std::move(out_vals_content.children[cudf::lists_column_view::child_column_index])); - auto &out_offsets = out_keys_content.children[cudf::lists_column_view::offsets_column_index]; - - auto out_structs = - cudf::make_structs_column(out_child_size, std::move(out_structs_members), 0, {}); - return release_as_jlong(cudf::make_lists_column(input_cv->size(), std::move(out_offsets), - std::move(out_structs), input_cv->null_count(), - cudf::copy_bitmask(*input_cv))); + auto const keys_vals = lists_keys_vals.child(); + JNI_ARG_CHECK(env, keys_vals.type().id() == cudf::type_id::STRUCT, + "Input column has child that is not a structs column.", 0); + JNI_ARG_CHECK(env, keys_vals.num_children() == 2, + "Input column has child that does not have 2 children.", 0); + + return release_as_jlong( + cudf::jni::lists_distinct_by_key(lists_keys_vals, cudf::default_stream_value)); } CATCH_STD(env, 0); } diff --git a/java/src/main/native/src/ColumnViewJni.cu b/java/src/main/native/src/ColumnViewJni.cu index 6b4db39eb34..aa21b508040 100644 --- a/java/src/main/native/src/ColumnViewJni.cu +++ b/java/src/main/native/src/ColumnViewJni.cu @@ -16,8 +16,16 @@ #include #include +#include #include +#include +#include +#include #include +#include +#include +#include +#include #include #include @@ -72,4 +80,50 @@ std::unique_ptr generate_list_offsets(cudf::column_view const &lis return offsets_column; } + +std::unique_ptr lists_distinct_by_key(cudf::lists_column_view const &input, + rmm::cuda_stream_view stream) { + if (input.is_empty()) { + return empty_like(input.parent()); + } + + auto const child = input.get_sliced_child(stream); + + // Genereate labels for the input list elements. + auto labels = rmm::device_uvector(child.size(), stream); + cudf::detail::label_segments(input.offsets_begin(), input.offsets_end(), labels.begin(), + labels.end(), stream); + + // Use `cudf::duplicate_keep_option::KEEP_LAST` so this will produce the desired behavior when + // being called in `create_map` in spark-rapids. + // Other options comparing nulls and NaNs are set as all-equal. + auto out_columns = cudf::detail::stable_distinct( + table_view{{column_view{cudf::device_span{labels}}, + child.child(0), child.child(1)}}, // input table + std::vector{0, 1}, // key columns + cudf::duplicate_keep_option::KEEP_LAST, cudf::null_equality::EQUAL, + cudf::nan_equality::ALL_EQUAL, stream) + ->release(); + auto const out_labels = out_columns.front()->view(); + + // Assemble a structs column of . + auto out_structs_members = std::vector>(); + out_structs_members.emplace_back(std::move(out_columns[1])); + out_structs_members.emplace_back(std::move(out_columns[2])); + auto out_structs = + cudf::make_structs_column(out_labels.size(), std::move(out_structs_members), 0, {}); + + // Assemble a lists column of structs. + auto out_offsets = make_numeric_column(data_type{type_to_id()}, input.size() + 1, + mask_state::UNALLOCATED, stream); + auto const offsets_begin = out_offsets->mutable_view().template begin(); + auto const labels_begin = out_labels.template begin(); + cudf::detail::labels_to_offsets(labels_begin, labels_begin + out_labels.size(), offsets_begin, + offsets_begin + out_offsets->size(), stream); + + return cudf::make_lists_column(input.size(), std::move(out_offsets), std::move(out_structs), + input.null_count(), + cudf::detail::copy_bitmask(input.parent(), stream), stream); +} + } // namespace cudf::jni diff --git a/java/src/main/native/src/ColumnViewJni.hpp b/java/src/main/native/src/ColumnViewJni.hpp index f9ad01d82d7..1ad8923d5b3 100644 --- a/java/src/main/native/src/ColumnViewJni.hpp +++ b/java/src/main/native/src/ColumnViewJni.hpp @@ -15,6 +15,7 @@ */ #include +#include #include #include @@ -52,4 +53,20 @@ std::unique_ptr generate_list_offsets(cudf::column_view const &list_length, rmm::cuda_stream_view stream = cudf::default_stream_value); +/** + * @brief Generates lists column by copying elements that are distinct by key from each input list + * row to the corresponding output row. + * + * The input lists column must be given such that each list element is a struct of + * pair. With such input, a list containing distinct by key elements are defined such that the keys + * of all elements in the list are distinct (i.e., any two keys are always compared unequal). + * + * There will not be any validity check for the input. The caller is responsible to make sure that + * the input lists column has the right structure. + * + * @return A new list columns in which the elements in each list are distinct by key. + */ +std::unique_ptr lists_distinct_by_key(cudf::lists_column_view const &input, + rmm::cuda_stream_view stream); + } // namespace cudf::jni diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 21ae0d427e2..d9d8044b0ad 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -4129,7 +4129,7 @@ void testExtractAllRecord() { null, null, Arrays.asList("a", "1", "b", "1", "a", "2")); - + ColumnVector resultIdx0 = v.extractAllRecord(pattern, 0); ColumnVector resultIdx1 = v.extractAllRecord(pattern, 1); ColumnVector resultIdx2 = v.extractAllRecord(pattern, 2); @@ -4466,31 +4466,52 @@ void testDropListDuplicatesWithKeysValues() { ); ColumnVector inputStructsKeysVals = ColumnVector.makeStruct(inputChildKeys, inputChildVals); ColumnVector inputOffsets = ColumnVector.fromInts(0, 2, 5, 10, 15, 15); - ColumnVector inputListsKeysVals = inputStructsKeysVals.makeListFromOffsets(5, - inputOffsets); - - ColumnVector expectedChildKeys = ColumnVector.fromBoxedInts( - 1, 2, - 3, 4, 5, - 0, 6, null, - 6, 7, null - ); - ColumnVector expectedChildVals = ColumnVector.fromBoxedInts( - 10, 20, - 30, 40, 50, - 100, 90, 60, - 120, 150, 140 - ); - ColumnVector expectedStructsKeysVals = ColumnVector.makeStruct(expectedChildKeys, - expectedChildVals); - ColumnVector expectedOffsets = ColumnVector.fromInts(0, 2, 5, 8, 11, 11); - ColumnVector expectedListsKeysVals = expectedStructsKeysVals.makeListFromOffsets(5, - expectedOffsets); - - ColumnVector output = inputListsKeysVals.dropListDuplicatesWithKeysValues(); - ColumnVector sortedOutput = output.listSortRows(false, false); + ColumnVector inputListsKeysVals = inputStructsKeysVals.makeListFromOffsets(5, inputOffsets) ) { - assertColumnsAreEqual(expectedListsKeysVals, sortedOutput); + // Test full input: + try(ColumnVector expectedChildKeys = ColumnVector.fromBoxedInts( + 1, 2, // list1 + 3, 4, 5, // list2 + 0, 6, null, // list3 + 6, 7, null // list4 + // list5 (empty) + ); + ColumnVector expectedChildVals = ColumnVector.fromBoxedInts( + 10, 20, // list1 + 30, 40, 50, // list2 + 100, 90, 60, // list3 + 120, 150, 140 // list4 + // list5 (empty) + ); + ColumnVector expectedStructsKeysVals = ColumnVector.makeStruct(expectedChildKeys, expectedChildVals); + ColumnVector expectedOffsets = ColumnVector.fromInts(0, 2, 5, 8, 11, 11); + ColumnVector expectedListsKeysVals = expectedStructsKeysVals.makeListFromOffsets(5, expectedOffsets); + + ColumnVector output = inputListsKeysVals.dropListDuplicatesWithKeysValues(); + ColumnVector sortedOutput = output.listSortRows(false, false) + ) { + assertColumnsAreEqual(expectedListsKeysVals, sortedOutput); + } + + // Test sliced input: + try(ColumnVector expectedChildKeys = ColumnVector.fromBoxedInts( + 3, 4, 5, // list1 + 0, 6, null // list2 + ); + ColumnVector expectedChildVals = ColumnVector.fromBoxedInts( + 30, 40, 50, // list1 + 100, 90, 60 // list2 + ); + ColumnVector expectedStructsKeysVals = ColumnVector.makeStruct(expectedChildKeys, expectedChildVals); + ColumnVector expectedOffsets = ColumnVector.fromInts(0, 3, 6); + ColumnVector expectedListsKeysVals = expectedStructsKeysVals.makeListFromOffsets(2, expectedOffsets); + + ColumnVector inputSliced = inputListsKeysVals.subVector(1, 3); + ColumnVector output = inputSliced.dropListDuplicatesWithKeysValues(); + ColumnVector sortedOutput = output.listSortRows(false, false) + ) { + assertColumnsAreEqual(expectedListsKeysVals, sortedOutput); + } } } @@ -4516,35 +4537,59 @@ void testDropListDuplicatesWithKeysValuesNullable() { ColumnVector inputOffsets = ColumnVector.fromInts(0, 2, 2, 5, 10, 15, 15); ColumnVector tmpInputListsKeysVals = inputStructsKeysVals.makeListFromOffsets(6,inputOffsets); ColumnVector templateBitmask = ColumnVector.fromBoxedInts(1, null, 1, 1, 1, null); - ColumnVector inputListsKeysVals = tmpInputListsKeysVals.mergeAndSetValidity(BinaryOp.BITWISE_AND, templateBitmask); - - ColumnVector expectedChildKeys = ColumnVector.fromBoxedInts( - 1, 2, // list1 - // list2 (null) - 3, 4, 5, // list3 - 0, 6, null, // list4 - 6, 7, null // list5 - // list6 (null) - ); - ColumnVector expectedChildVals = ColumnVector.fromBoxedInts( - 10, 20, // list1 - // list2 (null) - 30, 40, 50, // list3 - 100, 90, 60, // list4 - 120, 150, 140 // list5 - // list6 (null) - ); - ColumnVector expectedStructsKeysVals = ColumnVector.makeStruct(expectedChildKeys, - expectedChildVals); - ColumnVector expectedOffsets = ColumnVector.fromInts(0, 2, 2, 5, 8, 11, 11); - ColumnVector tmpExpectedListsKeysVals = expectedStructsKeysVals.makeListFromOffsets(6, - expectedOffsets); - ColumnVector expectedListsKeysVals = tmpExpectedListsKeysVals.mergeAndSetValidity(BinaryOp.BITWISE_AND, templateBitmask); - - ColumnVector output = inputListsKeysVals.dropListDuplicatesWithKeysValues(); - ColumnVector sortedOutput = output.listSortRows(false, false); + ColumnVector inputListsKeysVals = tmpInputListsKeysVals.mergeAndSetValidity(BinaryOp.BITWISE_AND, templateBitmask) ) { - assertColumnsAreEqual(expectedListsKeysVals, sortedOutput); + // Test full input: + try(ColumnVector expectedChildKeys = ColumnVector.fromBoxedInts( + 1, 2, // list1 + // list2 (null) + 3, 4, 5, // list3 + 0, 6, null, // list4 + 6, 7, null // list5 + // list6 (null) + ); + ColumnVector expectedChildVals = ColumnVector.fromBoxedInts( + 10, 20, // list1 + // list2 (null) + 30, 40, 50, // list3 + 100, 90, 60, // list4 + 120, 150, 140 // list5 + // list6 (null) + ); + ColumnVector expectedStructsKeysVals = ColumnVector.makeStruct(expectedChildKeys, expectedChildVals); + ColumnVector expectedOffsets = ColumnVector.fromInts(0, 2, 2, 5, 8, 11, 11); + ColumnVector tmpExpectedListsKeysVals = expectedStructsKeysVals.makeListFromOffsets(6, expectedOffsets); + ColumnVector expectedListsKeysVals = tmpExpectedListsKeysVals.mergeAndSetValidity(BinaryOp.BITWISE_AND, templateBitmask); + + ColumnVector output = inputListsKeysVals.dropListDuplicatesWithKeysValues(); + ColumnVector sortedOutput = output.listSortRows(false, false) + ) { + assertColumnsAreEqual(expectedListsKeysVals, sortedOutput); + } + + // Test sliced input: + try(ColumnVector expectedChildKeys = ColumnVector.fromBoxedInts( + // list1 (null) + 3, 4, 5, // list2 + 0, 6, null // list3 + ); + ColumnVector expectedChildVals = ColumnVector.fromBoxedInts( + // list1 (null) + 30, 40, 50, // list2 + 100, 90, 60 // list3 + ); + ColumnVector expectedStructsKeysVals = ColumnVector.makeStruct(expectedChildKeys, expectedChildVals); + ColumnVector expectedOffsets = ColumnVector.fromInts(0, 0, 3, 6); + ColumnVector tmpExpectedListsKeysVals = expectedStructsKeysVals.makeListFromOffsets(3, expectedOffsets); + ColumnVector slicedTemplateBitmask = ColumnVector.fromBoxedInts(null, 1, 1); + ColumnVector expectedListsKeysVals = tmpExpectedListsKeysVals.mergeAndSetValidity(BinaryOp.BITWISE_AND, slicedTemplateBitmask); + + ColumnVector inputSliced = inputListsKeysVals.subVector(1, 4); + ColumnVector output = inputSliced.dropListDuplicatesWithKeysValues(); + ColumnVector sortedOutput = output.listSortRows(false, false) + ) { + assertColumnsAreEqual(expectedListsKeysVals, sortedOutput); + } } }