From 40a621b2ff236445686a54f6b4d4ab5658355fe0 Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Fri, 17 Dec 2021 13:52:18 -0700 Subject: [PATCH 01/12] Finish implementation for scalar search --- cpp/src/search/search.cu | 52 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 49 insertions(+), 3 deletions(-) diff --git a/cpp/src/search/search.cu b/cpp/src/search/search.cu index 241b3c595f1..d7332d9ab46 100644 --- a/cpp/src/search/search.cu +++ b/cpp/src/search/search.cu @@ -137,10 +137,9 @@ std::unique_ptr search_ordered(table_view const& t, struct contains_scalar_dispatch { template - bool operator()(column_view const& col, scalar const& value, rmm::cuda_stream_view stream) + std::enable_if_t, bool> operator()( + column_view const& col, scalar const& value, rmm::cuda_stream_view stream) { - CUDF_EXPECTS(col.type() == value.type(), "scalar and column types must match"); - using Type = device_storage_type_t; using ScalarType = cudf::scalar_type_t; auto d_col = column_device_view::create(col, stream); @@ -162,6 +161,51 @@ struct contains_scalar_dispatch { return found_iter != d_col->end(); } } + + template + std::enable_if_t, bool> operator()( + column_view const& col, scalar const& value, rmm::cuda_stream_view stream) + { + static_assert(std::is_same_v, struct_scalar>); + auto const scalar_table = static_cast(&value)->view(); + CUDF_EXPECTS(col.num_children() == scalar_table.num_columns(), + "scalar and column must have the same number of children"); + + // Prepare to flatten the structs column and scalar. + auto const has_null_elements = + has_nested_nulls(table_view{std::vector{col.child_begin(), col.child_end()}}) || + has_nested_nulls(scalar_table); + auto const flatten_nullability = has_null_elements + ? structs::detail::column_nullability::FORCE + : structs::detail::column_nullability::MATCH_INCOMING; + + // Flatten the input structs column, only materialize the bitmask if there is null in the input. + auto const col_flattened = + structs::detail::flatten_nested_columns(table_view{{col}}, {}, {}, flatten_nullability); + auto const val_flattened = + structs::detail::flatten_nested_columns(scalar_table, {}, {}, flatten_nullability); + + // The struct scalar only contains the struct member columns. + // Thus, if there is any null in the input, we must exclude the first column in the flattenned + // table of the input column from searching (because that column is the materialized bitmask of + // the input structs column). + auto const col_flattened_content = col_flattened.flattened_columns(); + auto const col_flattened_children = table_view{std::vector{ + col_flattened_content.begin() + has_null_elements, col_flattened_content.end()}}; + + auto const d_col_children_ptr = table_device_view::create(col_flattened_children, stream); + auto const d_val_ptr = table_device_view::create(val_flattened, stream); + + auto const count_it = thrust::make_counting_iterator(0); + auto const comp = row_equality_comparator( + nullate::DYNAMIC{has_null_elements}, *d_col_children_ptr, *d_val_ptr, null_equality::EQUAL); + auto const found_iter = thrust::find_if( + rmm::exec_policy(stream), count_it, count_it + col.size(), [comp] __device__(auto idx) { + return comp(idx, 0); // compare col[idx] == val[0]. + }); + + return found_iter != count_it + col.size(); + } }; template <> @@ -202,6 +246,8 @@ bool contains_scalar_dispatch::operator()(column_view const& namespace detail { bool contains(column_view const& col, scalar const& value, rmm::cuda_stream_view stream) { + CUDF_EXPECTS(col.type() == value.type(), "scalar and column types must match"); + if (col.is_empty()) { return false; } if (not value.is_valid(stream)) { return col.has_nulls(); } From 2bdbc9f72f48395a05056f9a79b792cd6dcf4d0e Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Fri, 17 Dec 2021 14:45:39 -0700 Subject: [PATCH 02/12] Fix specialization --- cpp/src/search/search.cu | 93 ++++++++++++++++++---------------------- 1 file changed, 42 insertions(+), 51 deletions(-) diff --git a/cpp/src/search/search.cu b/cpp/src/search/search.cu index d7332d9ab46..a25fd1becd0 100644 --- a/cpp/src/search/search.cu +++ b/cpp/src/search/search.cu @@ -137,8 +137,7 @@ std::unique_ptr search_ordered(table_view const& t, struct contains_scalar_dispatch { template - std::enable_if_t, bool> operator()( - column_view const& col, scalar const& value, rmm::cuda_stream_view stream) + bool operator()(column_view const& col, scalar const& value, rmm::cuda_stream_view stream) { using Type = device_storage_type_t; using ScalarType = cudf::scalar_type_t; @@ -161,51 +160,6 @@ struct contains_scalar_dispatch { return found_iter != d_col->end(); } } - - template - std::enable_if_t, bool> operator()( - column_view const& col, scalar const& value, rmm::cuda_stream_view stream) - { - static_assert(std::is_same_v, struct_scalar>); - auto const scalar_table = static_cast(&value)->view(); - CUDF_EXPECTS(col.num_children() == scalar_table.num_columns(), - "scalar and column must have the same number of children"); - - // Prepare to flatten the structs column and scalar. - auto const has_null_elements = - has_nested_nulls(table_view{std::vector{col.child_begin(), col.child_end()}}) || - has_nested_nulls(scalar_table); - auto const flatten_nullability = has_null_elements - ? structs::detail::column_nullability::FORCE - : structs::detail::column_nullability::MATCH_INCOMING; - - // Flatten the input structs column, only materialize the bitmask if there is null in the input. - auto const col_flattened = - structs::detail::flatten_nested_columns(table_view{{col}}, {}, {}, flatten_nullability); - auto const val_flattened = - structs::detail::flatten_nested_columns(scalar_table, {}, {}, flatten_nullability); - - // The struct scalar only contains the struct member columns. - // Thus, if there is any null in the input, we must exclude the first column in the flattenned - // table of the input column from searching (because that column is the materialized bitmask of - // the input structs column). - auto const col_flattened_content = col_flattened.flattened_columns(); - auto const col_flattened_children = table_view{std::vector{ - col_flattened_content.begin() + has_null_elements, col_flattened_content.end()}}; - - auto const d_col_children_ptr = table_device_view::create(col_flattened_children, stream); - auto const d_val_ptr = table_device_view::create(val_flattened, stream); - - auto const count_it = thrust::make_counting_iterator(0); - auto const comp = row_equality_comparator( - nullate::DYNAMIC{has_null_elements}, *d_col_children_ptr, *d_val_ptr, null_equality::EQUAL); - auto const found_iter = thrust::find_if( - rmm::exec_policy(stream), count_it, count_it + col.size(), [comp] __device__(auto idx) { - return comp(idx, 0); // compare col[idx] == val[0]. - }); - - return found_iter != count_it + col.size(); - } }; template <> @@ -217,11 +171,48 @@ bool contains_scalar_dispatch::operator()(column_view const&, } template <> -bool contains_scalar_dispatch::operator()(column_view const&, - scalar const&, - rmm::cuda_stream_view) +bool contains_scalar_dispatch::operator()(column_view const& col, + scalar const& value, + rmm::cuda_stream_view stream) { - CUDF_FAIL("struct_view type not supported yet"); + auto const scalar_table = static_cast(&value)->view(); + CUDF_EXPECTS(col.num_children() == scalar_table.num_columns(), + "struct scalar and structs column must have the same number of children"); + + // Prepare to flatten the structs column and scalar. + auto const has_null_elements = + has_nested_nulls(table_view{std::vector{col.child_begin(), col.child_end()}}) || + has_nested_nulls(scalar_table); + auto const flatten_nullability = has_null_elements + ? structs::detail::column_nullability::FORCE + : structs::detail::column_nullability::MATCH_INCOMING; + + // Flatten the input structs column, only materialize the bitmask if there is null in the input. + auto const col_flattened = + structs::detail::flatten_nested_columns(table_view{{col}}, {}, {}, flatten_nullability); + auto const val_flattened = + structs::detail::flatten_nested_columns(scalar_table, {}, {}, flatten_nullability); + + // The struct scalar only contains the struct member columns. + // Thus, if there is any null in the input, we must exclude the first column in the flattenned + // table of the input column from searching (because that column is the materialized bitmask of + // the input structs column). + auto const col_flattened_content = col_flattened.flattened_columns(); + auto const col_flattened_children = table_view{std::vector{ + col_flattened_content.begin() + has_null_elements, col_flattened_content.end()}}; + + auto const d_col_children_ptr = table_device_view::create(col_flattened_children, stream); + auto const d_val_ptr = table_device_view::create(val_flattened, stream); + + auto const count_it = thrust::make_counting_iterator(0); + auto const comp = row_equality_comparator( + nullate::DYNAMIC{has_null_elements}, *d_col_children_ptr, *d_val_ptr, null_equality::EQUAL); + auto const found_iter = thrust::find_if( + rmm::exec_policy(stream), count_it, count_it + col.size(), [comp] __device__(auto idx) { + return comp(idx, 0); // compare col[idx] == val[0]. + }); + + return found_iter != count_it + col.size(); } template <> From 6802ba40cb354701637c00e93650fc7e35fa6187 Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Thu, 6 Jan 2022 13:28:15 -0700 Subject: [PATCH 03/12] Remove unused variables --- cpp/src/search/search.cu | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/cpp/src/search/search.cu b/cpp/src/search/search.cu index a25fd1becd0..b7d6f17d74b 100644 --- a/cpp/src/search/search.cu +++ b/cpp/src/search/search.cu @@ -240,7 +240,6 @@ bool contains(column_view const& col, scalar const& value, rmm::cuda_stream_view CUDF_EXPECTS(col.type() == value.type(), "scalar and column types must match"); if (col.is_empty()) { return false; } - if (not value.is_valid(stream)) { return col.has_nulls(); } return cudf::type_dispatcher(col.type(), contains_scalar_dispatch{}, col, value, stream); @@ -301,20 +300,14 @@ struct multi_contains_dispatch { template <> std::unique_ptr multi_contains_dispatch::operator()( - column_view const& haystack, - column_view const& needles, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) + column_view const&, column_view const&, rmm::cuda_stream_view, rmm::mr::device_memory_resource*) { CUDF_FAIL("list_view type not supported"); } template <> std::unique_ptr multi_contains_dispatch::operator()( - column_view const& haystack, - column_view const& needles, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr) + column_view const&, column_view const&, rmm::cuda_stream_view, rmm::mr::device_memory_resource*) { CUDF_FAIL("struct_view type not supported"); } From d902aa02191dc4d220c4d28708702a8d5068281b Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Thu, 6 Jan 2022 13:32:22 -0700 Subject: [PATCH 04/12] Update comments --- cpp/src/search/search.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/search/search.cu b/cpp/src/search/search.cu index b7d6f17d74b..40b85b67468 100644 --- a/cpp/src/search/search.cu +++ b/cpp/src/search/search.cu @@ -195,8 +195,8 @@ bool contains_scalar_dispatch::operator()(column_view const& // The struct scalar only contains the struct member columns. // Thus, if there is any null in the input, we must exclude the first column in the flattenned - // table of the input column from searching (because that column is the materialized bitmask of - // the input structs column). + // table of the input column from searching because that column is the materialized bitmask of + // the input structs column. auto const col_flattened_content = col_flattened.flattened_columns(); auto const col_flattened_children = table_view{std::vector{ col_flattened_content.begin() + has_null_elements, col_flattened_content.end()}}; From aa2b5f847235115d4b5779a16843868b6912847c Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Fri, 7 Jan 2022 16:08:46 -0700 Subject: [PATCH 05/12] Add new tests --- cpp/tests/search/search_struct_test.cpp | 214 +++++++++++++++++++++++- 1 file changed, 210 insertions(+), 4 deletions(-) diff --git a/cpp/tests/search/search_struct_test.cpp b/cpp/tests/search/search_struct_test.cpp index db2ecb89d6a..ef55eb02af5 100644 --- a/cpp/tests/search/search_struct_test.cpp +++ b/cpp/tests/search/search_struct_test.cpp @@ -21,6 +21,7 @@ #include #include +#include #include #include @@ -35,15 +36,15 @@ constexpr cudf::test::debug_output_level verbosity{cudf::test::debug_output_leve constexpr int32_t null{0}; // Mark for null child elements constexpr int32_t XXX{0}; // Mark for null struct elements -template -struct TypedStructSearchTest : public cudf::test::BaseFixture { -}; - using TestTypes = cudf::test::Concat; +template +struct TypedStructSearchTest : public cudf::test::BaseFixture { +}; TYPED_TEST_SUITE(TypedStructSearchTest, TestTypes); namespace { @@ -353,3 +354,208 @@ TYPED_TEST(TypedStructSearchTest, ComplexStructTest) CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_lower_bound, results.first->view(), verbosity); CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_upper_bound, results.second->view(), verbosity); } + +template +struct TypedScalarStructContainTest : public cudf::test::BaseFixture { +}; +TYPED_TEST_SUITE(TypedScalarStructContainTest, TestTypes); + +TYPED_TEST(TypedScalarStructContainTest, EmptyInputTest) +{ + using col_wrapper = cudf::test::fixed_width_column_wrapper; + + auto const col = [] { + auto child = col_wrapper{}; + return structs_col{{child}}; + }(); + + auto const val = [] { + auto child = col_wrapper{1}; + return cudf::struct_scalar(std::vector{child}); + }(); + + EXPECT_EQ(false, cudf::contains(col, val)); +} + +TYPED_TEST(TypedScalarStructContainTest, TrivialInputTests) +{ + using col_wrapper = cudf::test::fixed_width_column_wrapper; + + auto const col = [] { + auto child1 = col_wrapper{1, 2, 3}; + auto child2 = col_wrapper{4, 5, 6}; + auto child3 = strings_col{"x", "y", "z"}; + return structs_col{{child1, child2, child3}}; + }(); + + auto const val1 = [] { + auto child1 = col_wrapper{1}; + auto child2 = col_wrapper{4}; + auto child3 = strings_col{"x"}; + return cudf::struct_scalar(std::vector{child1, child2, child3}); + }(); + auto const val2 = [] { + auto child1 = col_wrapper{1}; + auto child2 = col_wrapper{4}; + auto child3 = strings_col{"a"}; + return cudf::struct_scalar(std::vector{child1, child2, child3}); + }(); + + EXPECT_EQ(true, cudf::contains(col, val1)); + EXPECT_EQ(false, cudf::contains(col, val2)); +} + +TYPED_TEST(TypedScalarStructContainTest, SlicedColumnInputTests) +{ + using col_wrapper = cudf::test::fixed_width_column_wrapper; + + constexpr int32_t dont_care{0}; + + auto const col_original = [] { + auto child1 = col_wrapper{dont_care, dont_care, 1, 2, 3, dont_care}; + auto child2 = col_wrapper{dont_care, dont_care, 4, 5, 6, dont_care}; + auto child3 = strings_col{"dont_care", "dont_care", "x", "y", "z", "dont_care"}; + return structs_col{{child1, child2, child3}}; + }(); + auto const col = cudf::slice(col_original, {2, 5})[0]; + + auto const val1 = [] { + auto child1 = col_wrapper{1}; + auto child2 = col_wrapper{4}; + auto child3 = strings_col{"x"}; + return cudf::struct_scalar(std::vector{child1, child2, child3}); + }(); + auto const val2 = [] { + auto child1 = col_wrapper{dont_care}; + auto child2 = col_wrapper{dont_care}; + auto child3 = strings_col{"dont_care"}; + return cudf::struct_scalar(std::vector{child1, child2, child3}); + }(); + + EXPECT_EQ(true, cudf::contains(col, val1)); + EXPECT_EQ(false, cudf::contains(col, val2)); +} + +TYPED_TEST(TypedScalarStructContainTest, SimpleInputWithNullsTests) +{ + using col_wrapper = cudf::test::fixed_width_column_wrapper; + + constexpr int32_t null{0}; + + // Test with nulls at the top level. + { + auto const col = [] { + auto child1 = col_wrapper{1, null, 3}; + auto child2 = col_wrapper{4, null, 6}; + auto child3 = strings_col{"x", "" /*NULL*/, "z"}; + return structs_col{{child1, child2, child3}, null_at(1)}; + }(); + + auto const val1 = [] { + auto child1 = col_wrapper{1}; + auto child2 = col_wrapper{4}; + auto child3 = strings_col{"x"}; + return cudf::struct_scalar(std::vector{child1, child2, child3}); + }(); + auto const val2 = [] { + auto child1 = col_wrapper{1}; + auto child2 = col_wrapper{4}; + auto child3 = strings_col{"a"}; + return cudf::struct_scalar(std::vector{child1, child2, child3}); + }(); + + EXPECT_EQ(true, cudf::contains(col, val1)); + EXPECT_EQ(false, cudf::contains(col, val2)); + } + + // Test with nulls at the children level. + { + auto const col = [] { + auto child1 = col_wrapper{{1, null, 3}, null_at(1)}; + auto child2 = col_wrapper{{4, null, 6}, null_at(1)}; + auto child3 = strings_col{{"" /*NULL*/, "y", "z"}, null_at(0)}; + return structs_col{{child1, child2, child3}}; + }(); + + auto const val1 = [] { + auto child1 = col_wrapper{1}; + auto child2 = col_wrapper{4}; + auto child3 = strings_col{{"x"}, null_at(0)}; + return cudf::struct_scalar(std::vector{child1, child2, child3}); + }(); + auto const val2 = [] { + auto child1 = col_wrapper{1}; + auto child2 = col_wrapper{4}; + auto child3 = strings_col{""}; + return cudf::struct_scalar(std::vector{child1, child2, child3}); + }(); + + EXPECT_EQ(true, cudf::contains(col, val1)); + EXPECT_EQ(false, cudf::contains(col, val2)); + } +} + +TYPED_TEST(TypedScalarStructContainTest, SlicedInputWithNullsTests) +{ + using col_wrapper = cudf::test::fixed_width_column_wrapper; + + constexpr int32_t dont_care{0}; + constexpr int32_t null{0}; + + // Test with nulls at the top level. + { + auto const col_original = [] { + auto child1 = col_wrapper{dont_care, dont_care, 1, null, 3, dont_care}; + auto child2 = col_wrapper{dont_care, dont_care, 4, null, 6, dont_care}; + auto child3 = strings_col{"dont_care", "dont_care", "x", "" /*NULL*/, "z", "dont_care"}; + return structs_col{{child1, child2, child3}, null_at(3)}; + }(); + auto const col = cudf::slice(col_original, {2, 5})[0]; + + auto const val1 = [] { + auto child1 = col_wrapper{1}; + auto child2 = col_wrapper{4}; + auto child3 = strings_col{"x"}; + return cudf::struct_scalar(std::vector{child1, child2, child3}); + }(); + auto const val2 = [] { + auto child1 = col_wrapper{1}; + auto child2 = col_wrapper{4}; + auto child3 = strings_col{"a"}; + return cudf::struct_scalar(std::vector{child1, child2, child3}); + }(); + + EXPECT_EQ(true, cudf::contains(col, val1)); + EXPECT_EQ(false, cudf::contains(col, val2)); + } + + // Test with nulls at the children level. + { + auto const col_original = [] { + auto child1 = + col_wrapper{{dont_care, dont_care /*also NULL*/, 1, null, 3, dont_care}, null_at(3)}; + auto child2 = + col_wrapper{{dont_care, dont_care /*also NULL*/, 4, null, 6, dont_care}, null_at(3)}; + auto child3 = strings_col{ + {"dont_care", "dont_care" /*also NULL*/, "" /*NULL*/, "y", "z", "dont_care"}, null_at(2)}; + return structs_col{{child1, child2, child3}, null_at(1)}; + }(); + auto const col = cudf::slice(col_original, {2, 5})[0]; + + auto const val1 = [] { + auto child1 = col_wrapper{1}; + auto child2 = col_wrapper{4}; + auto child3 = strings_col{{"x"}, null_at(0)}; + return cudf::struct_scalar(std::vector{child1, child2, child3}); + }(); + auto const val2 = [] { + auto child1 = col_wrapper{dont_care}; + auto child2 = col_wrapper{dont_care}; + auto child3 = strings_col{"dont_care"}; + return cudf::struct_scalar(std::vector{child1, child2, child3}); + }(); + + EXPECT_EQ(true, cudf::contains(col, val1)); + EXPECT_EQ(false, cudf::contains(col, val2)); + } +} From 8da3b26a51c2fae741464f8bd7078b0d0f24655e Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Fri, 7 Jan 2022 16:34:18 -0700 Subject: [PATCH 06/12] Fix type check --- cpp/src/search/search.cu | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cpp/src/search/search.cu b/cpp/src/search/search.cu index 40b85b67468..3210cba7962 100644 --- a/cpp/src/search/search.cu +++ b/cpp/src/search/search.cu @@ -175,6 +175,7 @@ bool contains_scalar_dispatch::operator()(column_view const& scalar const& value, rmm::cuda_stream_view stream) { + CUDF_EXPECTS(col.type() == value.type(), "scalar and column types must match"); auto const scalar_table = static_cast(&value)->view(); CUDF_EXPECTS(col.num_children() == scalar_table.num_columns(), "struct scalar and structs column must have the same number of children"); @@ -237,8 +238,6 @@ bool contains_scalar_dispatch::operator()(column_view const& namespace detail { bool contains(column_view const& col, scalar const& value, rmm::cuda_stream_view stream) { - CUDF_EXPECTS(col.type() == value.type(), "scalar and column types must match"); - if (col.is_empty()) { return false; } if (not value.is_valid(stream)) { return col.has_nulls(); } From 9a2999d3338b28a24855b843d4aaa2df6065491f Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Tue, 11 Jan 2022 05:55:55 -0700 Subject: [PATCH 07/12] Update copyright year --- cpp/src/search/search.cu | 2 +- cpp/tests/search/search_struct_test.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/search/search.cu b/cpp/src/search/search.cu index 3210cba7962..fa74839472c 100644 --- a/cpp/src/search/search.cu +++ b/cpp/src/search/search.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/tests/search/search_struct_test.cpp b/cpp/tests/search/search_struct_test.cpp index ef55eb02af5..9920eb025a7 100644 --- a/cpp/tests/search/search_struct_test.cpp +++ b/cpp/tests/search/search_struct_test.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. From fe40316c7d52f72993459fc463644879f3267da3 Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Tue, 11 Jan 2022 08:00:32 -0700 Subject: [PATCH 08/12] Add back type check --- cpp/src/search/search.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/src/search/search.cu b/cpp/src/search/search.cu index fa74839472c..76973249403 100644 --- a/cpp/src/search/search.cu +++ b/cpp/src/search/search.cu @@ -139,6 +139,8 @@ struct contains_scalar_dispatch { template bool operator()(column_view const& col, scalar const& value, rmm::cuda_stream_view stream) { + CUDF_EXPECTS(col.type() == value.type(), "scalar and column types must match"); + using Type = device_storage_type_t; using ScalarType = cudf::scalar_type_t; auto d_col = column_device_view::create(col, stream); From c0d1f03c0f1136b88fccf8bec96d061be64317f7 Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Tue, 11 Jan 2022 08:08:05 -0700 Subject: [PATCH 09/12] Rename variables --- cpp/src/search/search.cu | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/cpp/src/search/search.cu b/cpp/src/search/search.cu index 76973249403..9048dde1b98 100644 --- a/cpp/src/search/search.cu +++ b/cpp/src/search/search.cu @@ -207,15 +207,16 @@ bool contains_scalar_dispatch::operator()(column_view const& auto const d_col_children_ptr = table_device_view::create(col_flattened_children, stream); auto const d_val_ptr = table_device_view::create(val_flattened, stream); - auto const count_it = thrust::make_counting_iterator(0); - auto const comp = row_equality_comparator( + auto const start_iter = thrust::make_counting_iterator(0); + auto const end_iter = start_iter + col.size(); + auto const comp = row_equality_comparator( nullate::DYNAMIC{has_null_elements}, *d_col_children_ptr, *d_val_ptr, null_equality::EQUAL); auto const found_iter = thrust::find_if( - rmm::exec_policy(stream), count_it, count_it + col.size(), [comp] __device__(auto idx) { + rmm::exec_policy(stream), start_iter, end_iter, [comp] __device__(auto const idx) { return comp(idx, 0); // compare col[idx] == val[0]. }); - return found_iter != count_it + col.size(); + return found_iter != end_iter; } template <> From e2b5a49d0f69a843205dc3598cac0176d4224ebe Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Tue, 11 Jan 2022 08:21:47 -0700 Subject: [PATCH 10/12] Fix comment in test --- cpp/tests/search/search_struct_test.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cpp/tests/search/search_struct_test.cpp b/cpp/tests/search/search_struct_test.cpp index 9920eb025a7..ea96362014c 100644 --- a/cpp/tests/search/search_struct_test.cpp +++ b/cpp/tests/search/search_struct_test.cpp @@ -37,7 +37,6 @@ constexpr int32_t null{0}; // Mark for null child elements constexpr int32_t XXX{0}; // Mark for null struct elements using TestTypes = cudf::test::Concat; @@ -480,7 +479,7 @@ TYPED_TEST(TypedScalarStructContainTest, SimpleInputWithNullsTests) auto const val1 = [] { auto child1 = col_wrapper{1}; auto child2 = col_wrapper{4}; - auto child3 = strings_col{{"x"}, null_at(0)}; + auto child3 = strings_col{{"" /*NULL*/}, null_at(0)}; return cudf::struct_scalar(std::vector{child1, child2, child3}); }(); auto const val2 = [] { From a9824dea579f1bf1c71db1102f31d877395a5883 Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Tue, 11 Jan 2022 08:25:01 -0700 Subject: [PATCH 11/12] Add a test --- cpp/tests/search/search_struct_test.cpp | 26 +++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/cpp/tests/search/search_struct_test.cpp b/cpp/tests/search/search_struct_test.cpp index ea96362014c..a1f0b1d81cf 100644 --- a/cpp/tests/search/search_struct_test.cpp +++ b/cpp/tests/search/search_struct_test.cpp @@ -492,6 +492,32 @@ TYPED_TEST(TypedScalarStructContainTest, SimpleInputWithNullsTests) EXPECT_EQ(true, cudf::contains(col, val1)); EXPECT_EQ(false, cudf::contains(col, val2)); } + + // Test with nulls in the input scalar. + { + auto const col = [] { + auto child1 = col_wrapper{1, 2, 3}; + auto child2 = col_wrapper{4, 5, 6}; + auto child3 = strings_col{"x", "y", "z"}; + return structs_col{{child1, child2, child3}}; + }(); + + auto const val1 = [] { + auto child1 = col_wrapper{1}; + auto child2 = col_wrapper{4}; + auto child3 = strings_col{"x"}; + return cudf::struct_scalar(std::vector{child1, child2, child3}); + }(); + auto const val2 = [] { + auto child1 = col_wrapper{1}; + auto child2 = col_wrapper{4}; + auto child3 = strings_col{{"" /*NULL*/}, null_at(0)}; + return cudf::struct_scalar(std::vector{child1, child2, child3}); + }(); + + EXPECT_EQ(true, cudf::contains(col, val1)); + EXPECT_EQ(false, cudf::contains(col, val2)); + } } TYPED_TEST(TypedScalarStructContainTest, SlicedInputWithNullsTests) From 1e7e13bb551aed701bfb3d5b45f116b4748b5711 Mon Sep 17 00:00:00 2001 From: Nghia Truong Date: Thu, 13 Jan 2022 08:32:50 -0700 Subject: [PATCH 12/12] Address review comments --- cpp/src/search/search.cu | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/cpp/src/search/search.cu b/cpp/src/search/search.cu index 9048dde1b98..81ed3cfbd51 100644 --- a/cpp/src/search/search.cu +++ b/cpp/src/search/search.cu @@ -178,9 +178,14 @@ bool contains_scalar_dispatch::operator()(column_view const& rmm::cuda_stream_view stream) { CUDF_EXPECTS(col.type() == value.type(), "scalar and column types must match"); + auto const scalar_table = static_cast(&value)->view(); CUDF_EXPECTS(col.num_children() == scalar_table.num_columns(), "struct scalar and structs column must have the same number of children"); + for (size_type i = 0; i < col.num_children(); ++i) { + CUDF_EXPECTS(col.child(i).type() == scalar_table.column(i).type(), + "scalar and column children types must match"); + } // Prepare to flatten the structs column and scalar. auto const has_null_elements = @@ -197,12 +202,13 @@ bool contains_scalar_dispatch::operator()(column_view const& structs::detail::flatten_nested_columns(scalar_table, {}, {}, flatten_nullability); // The struct scalar only contains the struct member columns. - // Thus, if there is any null in the input, we must exclude the first column in the flattenned + // Thus, if there is any null in the input, we must exclude the first column in the flattened // table of the input column from searching because that column is the materialized bitmask of // the input structs column. auto const col_flattened_content = col_flattened.flattened_columns(); auto const col_flattened_children = table_view{std::vector{ - col_flattened_content.begin() + has_null_elements, col_flattened_content.end()}}; + col_flattened_content.begin() + static_cast(has_null_elements), + col_flattened_content.end()}}; auto const d_col_children_ptr = table_device_view::create(col_flattened_children, stream); auto const d_val_ptr = table_device_view::create(val_flattened, stream);