diff --git a/cpp/include/cudf/table/experimental/row_operators.cuh b/cpp/include/cudf/table/experimental/row_operators.cuh index 0dc0f4e5315..f9ffbfcdf7b 100644 --- a/cpp/include/cudf/table/experimental/row_operators.cuh +++ b/cpp/include/cudf/table/experimental/row_operators.cuh @@ -245,6 +245,16 @@ using optional_dremel_view = thrust::optional; * second letter in both words is the first non-equal letter, and `a < b`, thus * `aac < abb`. * + * @note The operator overloads in sub-class `element_comparator` are templated via the + * `type_dispatcher` to help select an overload instance for each column in a table. + * So, `cudf::is_nested` will return `true` if the table has nested-type columns, + * but it will be a runtime error if template parameter `has_nested_columns != true`. + * + * @tparam has_nested_columns compile-time optimization for primitive types. + * This template parameter is to be used by the developer by querying + * `cudf::detail::has_nested_columns(input)`. `true` compiles operator + * overloads for nested types, while `false` only compiles operator + * overloads for primitive types. * @tparam Nullate A cudf::nullate type describing whether to check for nulls. * @tparam PhysicalElementComparator A relational comparator functor that compares individual values * rather than logical elements, defaults to `NaN` aware relational comparator that evaluates `NaN` @@ -857,6 +867,16 @@ class self_comparator { * * `F(i,j)` returns true if and only if row `i` compares lexicographically less than row `j`. * + * @note The operator overloads in sub-class `element_comparator` are templated via the + * `type_dispatcher` to help select an overload instance for each column in a table. + * So, `cudf::is_nested` will return `true` if the table has nested-type columns, + * but it will be a runtime error if template parameter `has_nested_columns != true`. + * + * @tparam has_nested_columns compile-time optimization for primitive types. + * This template parameter is to be used by the developer by querying + * `cudf::detail::has_nested_columns(input)`. `true` compiles operator + * overloads for nested types, while `false` only compiles operator + * overloads for primitive types. * @tparam Nullate A cudf::nullate type describing whether to check for nulls. * @tparam PhysicalElementComparator A relational comparator functor that compares individual * values rather than logical elements, defaults to `NaN` aware relational comparator that @@ -1009,6 +1029,16 @@ class two_table_comparator { * only if row `i` of the right table compares lexicographically less than row * `j` of the left table. * + * @note The operator overloads in sub-class `element_comparator` are templated via the + * `type_dispatcher` to help select an overload instance for each column in a table. + * So, `cudf::is_nested` will return `true` if the table has nested-type columns, + * but it will be a runtime error if template parameter `has_nested_columns != true`. + * + * @tparam has_nested_columns compile-time optimization for primitive types. + * This template parameter is to be used by the developer by querying + * `cudf::detail::has_nested_columns(input)`. `true` compiles operator + * overloads for nested types, while `false` only compiles operator + * overloads for primitive types. * @tparam Nullate A cudf::nullate type describing whether to check for nulls. * @tparam PhysicalElementComparator A relational comparator functor that compares individual * values rather than logical elements, defaults to `NaN` aware relational comparator that @@ -1131,11 +1161,22 @@ struct nan_equal_physical_equality_comparator { * returns false, representing unequal rows. If the rows are compared without mismatched elements, * the rows are equal. * + * @note The operator overloads in sub-class `element_comparator` are templated via the + * `type_dispatcher` to help select an overload instance for each column in a table. + * So, `cudf::is_nested` will return `true` if the table has nested-type columns, + * but it will be a runtime error if template parameter `has_nested_columns != true`. + * + * @tparam has_nested_columns compile-time optimization for primitive types. + * This template parameter is to be used by the developer by querying + * `cudf::detail::has_nested_columns(input)`. `true` compiles operator + * overloads for nested types, while `false` only compiles operator + * overloads for primitive types. * @tparam Nullate A cudf::nullate type describing whether to check for nulls. * @tparam PhysicalEqualityComparator A equality comparator functor that compares individual values * rather than logical elements, defaults to a comparator for which `NaN == NaN`. */ -template class device_row_comparator { friend class self_comparator; ///< Allow self_comparator to access private members @@ -1246,14 +1287,14 @@ class device_row_comparator { template () and - not cudf::is_nested()), + (not has_nested_columns or not cudf::is_nested())), typename... Args> __device__ bool operator()(Args...) { CUDF_UNREACHABLE("Attempted to compare elements of uncomparable types."); } - template ())> + template ())> __device__ bool operator()(size_type const lhs_element_index, size_type const rhs_element_index) const noexcept { @@ -1437,6 +1478,16 @@ class self_comparator { * * `F(i,j)` returns true if and only if row `i` compares equal to row `j`. * + * @note The operator overloads in sub-class `element_comparator` are templated via the + * `type_dispatcher` to help select an overload instance for each column in a table. + * So, `cudf::is_nested` will return `true` if the table has nested-type columns, + * but it will be a runtime error if template parameter `has_nested_columns != true`. + * + * @tparam has_nested_columns compile-time optimization for primitive types. + * This template parameter is to be used by the developer by querying + * `cudf::detail::has_nested_columns(input)`. `true` compiles operator + * overloads for nested types, while `false` only compiles operator + * overloads for primitive types. * @tparam Nullate A cudf::nullate type describing whether to check for nulls. * @tparam PhysicalEqualityComparator A equality comparator functor that compares individual * values rather than logical elements, defaults to a comparator for which `NaN == NaN`. @@ -1445,13 +1496,15 @@ class self_comparator { * @param comparator Physical element equality comparison functor. * @return A binary callable object */ - template auto equal_to(Nullate nullate = {}, null_equality nulls_are_equal = null_equality::EQUAL, PhysicalEqualityComparator comparator = {}) const noexcept { - return device_row_comparator{nullate, *d_t, *d_t, nulls_are_equal, comparator}; + return device_row_comparator{ + nullate, *d_t, *d_t, nulls_are_equal, comparator}; } private: @@ -1539,6 +1592,16 @@ class two_table_comparator { * Similarly, `F(rhs_index_type i, lhs_index_type j)` returns true if and only if row `i` of the * right table compares equal to row `j` of the left table. * + * @note The operator overloads in sub-class `element_comparator` are templated via the + * `type_dispatcher` to help select an overload instance for each column in a table. + * So, `cudf::is_nested` will return `true` if the table has nested-type columns, + * but it will be a runtime error if template parameter `has_nested_columns != true`. + * + * @tparam has_nested_columns compile-time optimization for primitive types. + * This template parameter is to be used by the developer by querying + * `cudf::detail::has_nested_columns(input)`. `true` compiles operator + * overloads for nested types, while `false` only compiles operator + * overloads for primitive types. * @tparam Nullate A cudf::nullate type describing whether to check for nulls. * @tparam PhysicalEqualityComparator A equality comparator functor that compares individual * values rather than logical elements, defaults to a `NaN == NaN` equality comparator. @@ -1547,14 +1610,16 @@ class two_table_comparator { * @param comparator Physical element equality comparison functor. * @return A binary callable object */ - template auto equal_to(Nullate nullate = {}, null_equality nulls_are_equal = null_equality::EQUAL, PhysicalEqualityComparator comparator = {}) const noexcept { return strong_index_comparator_adapter{ - device_row_comparator(nullate, *d_left_table, *d_right_table, nulls_are_equal, comparator)}; + device_row_comparator( + nullate, *d_left_table, *d_right_table, nulls_are_equal, comparator)}; } private: diff --git a/cpp/src/binaryop/compiled/struct_binary_ops.cuh b/cpp/src/binaryop/compiled/struct_binary_ops.cuh index 2fcf1ce4e32..d167f0fe3c5 100644 --- a/cpp/src/binaryop/compiled/struct_binary_ops.cuh +++ b/cpp/src/binaryop/compiled/struct_binary_ops.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -106,6 +106,36 @@ void apply_struct_binary_op(mutable_column_view& out, } } +template +struct struct_equality_functor { + struct_equality_functor(OptionalIteratorType optional_iter, + DeviceComparatorType device_comparator, + bool is_lhs_scalar, + bool is_rhs_scalar, + bool preserve_output) + : _optional_iter(optional_iter), + _device_comparator(device_comparator), + _is_lhs_scalar(is_lhs_scalar), + _is_rhs_scalar(is_rhs_scalar), + _preserve_output(preserve_output) + { + } + + auto __device__ operator()(size_type i) const noexcept + { + auto const lhs = cudf::experimental::row::lhs_index_type{_is_lhs_scalar ? 0 : i}; + auto const rhs = cudf::experimental::row::rhs_index_type{_is_rhs_scalar ? 0 : i}; + return _optional_iter[i].has_value() and (_device_comparator(lhs, rhs) == _preserve_output); + } + + private: + OptionalIteratorType _optional_iter; + DeviceComparatorType _device_comparator; + bool _is_lhs_scalar; + bool _is_rhs_scalar; + bool _preserve_output; +}; + template void apply_struct_equality_op(mutable_column_view& out, @@ -125,26 +155,37 @@ void apply_struct_equality_op(mutable_column_view& out, auto trhs = table_view{{rhs}}; auto table_comparator = cudf::experimental::row::equality::two_table_comparator{tlhs, trhs, stream}; - auto device_comparator = - table_comparator.equal_to(nullate::DYNAMIC{has_nested_nulls(tlhs) || has_nested_nulls(trhs)}, - null_equality::EQUAL, - comparator); auto outd = column_device_view::create(out, stream); auto optional_iter = cudf::detail::make_optional_iterator(*outd, nullate::DYNAMIC{out.has_nulls()}); - thrust::tabulate(rmm::exec_policy(stream), - out.begin(), - out.end(), - [optional_iter, - is_lhs_scalar, - is_rhs_scalar, - preserve_output = (op != binary_operator::NOT_EQUAL), - device_comparator] __device__(size_type i) { - auto lhs = cudf::experimental::row::lhs_index_type{is_lhs_scalar ? 0 : i}; - auto rhs = cudf::experimental::row::rhs_index_type{is_rhs_scalar ? 0 : i}; - return optional_iter[i].has_value() and - (device_comparator(lhs, rhs) == preserve_output); - }); + + auto const comparator_helper = [&](auto const device_comparator) { + thrust::tabulate(rmm::exec_policy(stream), + out.begin(), + out.end(), + struct_equality_functor( + optional_iter, + device_comparator, + is_lhs_scalar, + is_rhs_scalar, + op != binary_operator::NOT_EQUAL)); + }; + + if (cudf::detail::has_nested_columns(tlhs) or cudf::detail::has_nested_columns(trhs)) { + auto device_comparator = table_comparator.equal_to( + nullate::DYNAMIC{has_nested_nulls(tlhs) || has_nested_nulls(trhs)}, + null_equality::EQUAL, + comparator); + + comparator_helper(device_comparator); + } else { + auto device_comparator = table_comparator.equal_to( + nullate::DYNAMIC{has_nested_nulls(tlhs) || has_nested_nulls(trhs)}, + null_equality::EQUAL, + comparator); + + comparator_helper(device_comparator); + } } } // namespace cudf::binops::compiled::detail diff --git a/cpp/src/groupby/hash/groupby.cu b/cpp/src/groupby/hash/groupby.cu index 50173d6a987..72ac6255549 100644 --- a/cpp/src/groupby/hash/groupby.cu +++ b/cpp/src/groupby/hash/groupby.cu @@ -68,12 +68,13 @@ namespace { // TODO: replace it with `cuco::static_map` // https://github.com/rapidsai/cudf/issues/10401 -using map_type = concurrent_unordered_map< - cudf::size_type, - cudf::size_type, - cudf::experimental::row::hash::device_row_hasher, - cudf::experimental::row::equality::device_row_comparator>; +template +using map_type = + concurrent_unordered_map, + ComparatorType>; /** * @brief List of aggregation operations that can be computed with a hash-based @@ -189,13 +190,14 @@ class groupby_simple_aggregations_collector final } }; +template class hash_compound_agg_finalizer final : public cudf::detail::aggregation_finalizer { column_view col; data_type result_type; cudf::detail::result_cache* sparse_results; cudf::detail::result_cache* dense_results; device_span gather_map; - map_type const& map; + map_type const& map; bitmask_type const* __restrict__ row_bitmask; rmm::cuda_stream_view stream; rmm::mr::device_memory_resource* mr; @@ -207,7 +209,7 @@ class hash_compound_agg_finalizer final : public cudf::detail::aggregation_final cudf::detail::result_cache* sparse_results, cudf::detail::result_cache* dense_results, device_span gather_map, - map_type const& map, + map_type const& map, bitmask_type const* row_bitmask, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) @@ -336,7 +338,7 @@ class hash_compound_agg_finalizer final : public cudf::detail::aggregation_final rmm::exec_policy(stream), thrust::make_counting_iterator(0), col.size(), - ::cudf::detail::var_hash_functor{ + ::cudf::detail::var_hash_functor>{ map, row_bitmask, *var_result_view, *values_view, *sum_view, *count_view, agg._ddof}); sparse_results->add_result(col, agg, std::move(var_result)); dense_results->add_result(col, agg, to_dense_agg_result(agg)); @@ -394,12 +396,13 @@ flatten_single_pass_aggs(host_span requests) * * @see groupby_null_templated() */ +template void sparse_to_dense_results(table_view const& keys, host_span requests, cudf::detail::result_cache* sparse_results, cudf::detail::result_cache* dense_results, device_span gather_map, - map_type const& map, + map_type const& map, bool keys_have_nulls, null_policy include_null_keys, rmm::cuda_stream_view stream, @@ -461,10 +464,11 @@ auto create_sparse_results_table(table_view const& flattened_values, * @brief Computes all aggregations from `requests` that require a single pass * over the data and stores the results in `sparse_results` */ +template void compute_single_pass_aggs(table_view const& keys, host_span requests, cudf::detail::result_cache* sparse_results, - map_type& map, + map_type& map, bool keys_have_nulls, null_policy include_null_keys, rmm::cuda_stream_view stream) @@ -484,16 +488,16 @@ void compute_single_pass_aggs(table_view const& keys, auto row_bitmask = skip_key_rows_with_nulls ? cudf::detail::bitmask_and(keys, stream).first : rmm::device_buffer{}; - thrust::for_each_n( - rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - keys.num_rows(), - hash::compute_single_pass_aggs_fn{map, - *d_values, - *d_sparse_table, - d_aggs.data(), - static_cast(row_bitmask.data()), - skip_key_rows_with_nulls}); + thrust::for_each_n(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + keys.num_rows(), + hash::compute_single_pass_aggs_fn>{ + map, + *d_values, + *d_sparse_table, + d_aggs.data(), + static_cast(row_bitmask.data()), + skip_key_rows_with_nulls}); // Add results back to sparse_results cache auto sparse_result_cols = sparse_table.release(); for (size_t i = 0; i < aggs.size(); i++) { @@ -507,7 +511,8 @@ void compute_single_pass_aggs(table_view const& keys, * @brief Computes and returns a device vector containing all populated keys in * `map`. */ -rmm::device_uvector extract_populated_keys(map_type const& map, +template +rmm::device_uvector extract_populated_keys(map_type const& map, size_type num_keys, rmm::cuda_stream_view stream) { @@ -566,52 +571,60 @@ std::unique_ptr groupby(table_view const& keys, auto preprocessed_keys = cudf::experimental::row::hash::preprocessed_table::create(keys, stream); auto const comparator = cudf::experimental::row::equality::self_comparator{preprocessed_keys}; auto const row_hash = cudf::experimental::row::hash::row_hasher{std::move(preprocessed_keys)}; - auto const d_key_equal = comparator.equal_to(has_null, null_keys_are_equal); auto const d_row_hash = row_hash.device_hasher(has_null); size_type constexpr unused_key{std::numeric_limits::max()}; size_type constexpr unused_value{std::numeric_limits::max()}; - using allocator_type = typename map_type::allocator_type; - - auto map = map_type::create(compute_hash_table_size(num_keys), - stream, - unused_key, - unused_value, - d_row_hash, - d_key_equal, - allocator_type()); - // Cache of sparse results where the location of aggregate value in each // column is indexed by the hash map cudf::detail::result_cache sparse_results(requests.size()); - // Compute all single pass aggs first - compute_single_pass_aggs( - keys, requests, &sparse_results, *map, keys_have_nulls, include_null_keys, stream); - - // Extract the populated indices from the hash map and create a gather map. - // Gathering using this map from sparse results will give dense results. - auto gather_map = extract_populated_keys(*map, keys.num_rows(), stream); - - // Compact all results from sparse_results and insert into cache - sparse_to_dense_results(keys, - requests, - &sparse_results, - cache, - gather_map, - *map, - keys_have_nulls, - include_null_keys, - stream, - mr); - - return cudf::detail::gather(keys, - gather_map, - out_of_bounds_policy::DONT_CHECK, - cudf::detail::negative_index_policy::NOT_ALLOWED, - stream, - mr); + auto const comparator_helper = [&](auto const d_key_equal) { + using allocator_type = typename map_type::allocator_type; + + auto const map = map_type::create(compute_hash_table_size(num_keys), + stream, + unused_key, + unused_value, + d_row_hash, + d_key_equal, + allocator_type()); + // Compute all single pass aggs first + compute_single_pass_aggs( + keys, requests, &sparse_results, *map, keys_have_nulls, include_null_keys, stream); + + // Extract the populated indices from the hash map and create a gather map. + // Gathering using this map from sparse results will give dense results. + auto gather_map = extract_populated_keys(*map, keys.num_rows(), stream); + + // Compact all results from sparse_results and insert into cache + sparse_to_dense_results(keys, + requests, + &sparse_results, + cache, + gather_map, + *map, + keys_have_nulls, + include_null_keys, + stream, + mr); + + return cudf::detail::gather(keys, + gather_map, + out_of_bounds_policy::DONT_CHECK, + cudf::detail::negative_index_policy::NOT_ALLOWED, + stream, + mr); + }; + + if (cudf::detail::has_nested_columns(keys)) { + auto const d_key_equal = comparator.equal_to(has_null, null_keys_are_equal); + return comparator_helper(d_key_equal); + } else { + auto const d_key_equal = comparator.equal_to(has_null, null_keys_are_equal); + return comparator_helper(d_key_equal); + } } } // namespace diff --git a/cpp/src/groupby/sort/group_nunique.cu b/cpp/src/groupby/sort/group_nunique.cu index c411e654913..cf81253483e 100644 --- a/cpp/src/groupby/sort/group_nunique.cu +++ b/cpp/src/groupby/sort/group_nunique.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,10 +33,10 @@ namespace groupby { namespace detail { namespace { -template +template struct is_unique_iterator_fn { using comparator_type = - typename cudf::experimental::row::equality::device_row_comparator; + typename cudf::experimental::row::equality::device_row_comparator; Nullate nulls; column_device_view const v; @@ -91,24 +91,35 @@ std::unique_ptr group_nunique(column_view const& values, auto const values_view = table_view{{values}}; auto const comparator = cudf::experimental::row::equality::self_comparator{values_view, stream}; - auto const d_equal = comparator.equal_to( - cudf::nullate::DYNAMIC{cudf::has_nested_nulls(values_view)}, null_equality::EQUAL); auto const d_values_view = column_device_view::create(values, stream); - auto const is_unique_iterator = - thrust::make_transform_iterator(thrust::counting_iterator(0), - is_unique_iterator_fn{nullate::DYNAMIC{values.has_nulls()}, - *d_values_view, - d_equal, - null_handling, - group_offsets.data(), - group_labels.data()}); - thrust::reduce_by_key(rmm::exec_policy(stream), - group_labels.begin(), - group_labels.end(), - is_unique_iterator, - thrust::make_discard_iterator(), - result->mutable_view().begin()); + + auto const comparator_helper = [&](auto const d_equal) { + auto const is_unique_iterator = + thrust::make_transform_iterator(thrust::counting_iterator(0), + is_unique_iterator_fn{nullate::DYNAMIC{values.has_nulls()}, + *d_values_view, + d_equal, + null_handling, + group_offsets.data(), + group_labels.data()}); + thrust::reduce_by_key(rmm::exec_policy(stream), + group_labels.begin(), + group_labels.end(), + is_unique_iterator, + thrust::make_discard_iterator(), + result->mutable_view().begin()); + }; + + if (cudf::detail::has_nested_columns(values_view)) { + auto const d_equal = comparator.equal_to( + cudf::nullate::DYNAMIC{cudf::has_nested_nulls(values_view)}, null_equality::EQUAL); + comparator_helper(d_equal); + } else { + auto const d_equal = comparator.equal_to( + cudf::nullate::DYNAMIC{cudf::has_nested_nulls(values_view)}, null_equality::EQUAL); + comparator_helper(d_equal); + } return result; } diff --git a/cpp/src/groupby/sort/group_rank_scan.cu b/cpp/src/groupby/sort/group_rank_scan.cu index 149f026ffe6..479ce166724 100644 --- a/cpp/src/groupby/sort/group_rank_scan.cu +++ b/cpp/src/groupby/sort/group_rank_scan.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -41,6 +41,38 @@ namespace groupby { namespace detail { namespace { +template +struct unique_identifier { + unique_identifier(size_type const* labels, + size_type const* offsets, + permuted_equal_t permuted_equal, + value_resolver resolver) + : _labels(labels), _offsets(offsets), _permuted_equal(permuted_equal), _resolver(resolver) + { + } + + auto __device__ operator()(size_type row_index) const noexcept + { + auto const group_start = _offsets[_labels[row_index]]; + if constexpr (forward) { + // First value of equal values is 1. + return _resolver(row_index == group_start || !_permuted_equal(row_index, row_index - 1), + row_index - group_start); + } else { + auto const group_end = _offsets[_labels[row_index] + 1]; + // Last value of equal values is 1. + return _resolver(row_index + 1 == group_end || !_permuted_equal(row_index, row_index + 1), + row_index - group_start); + } + } + + private: + size_type const* _labels; + size_type const* _offsets; + permuted_equal_t _permuted_equal; + value_resolver _resolver; +}; + /** * @brief generate grouped row ranks or dense ranks using a row comparison then scan the results * @@ -71,36 +103,34 @@ std::unique_ptr rank_generator(column_view const& grouped_values, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { + auto const grouped_values_view = table_view{{grouped_values}}; auto const comparator = - cudf::experimental::row::equality::self_comparator{table_view{{grouped_values}}, stream}; - auto const d_equal = comparator.equal_to(cudf::nullate::DYNAMIC{has_nulls}, null_equality::EQUAL); - auto const permuted_equal = - permuted_row_equality_comparator(d_equal, value_order.begin()); + cudf::experimental::row::equality::self_comparator{grouped_values_view, stream}; auto ranks = make_fixed_width_column( data_type{type_to_id()}, grouped_values.size(), mask_state::UNALLOCATED, stream, mr); auto mutable_ranks = ranks->mutable_view(); - auto unique_identifier = [labels = group_labels.begin(), - offsets = group_offsets.begin(), - permuted_equal, - resolver] __device__(size_type row_index) { - auto const group_start = offsets[labels[row_index]]; - if constexpr (forward) { - // First value of equal values is 1. - return resolver(row_index == group_start || !permuted_equal(row_index, row_index - 1), - row_index - group_start); - } else { - auto const group_end = offsets[labels[row_index] + 1]; - // Last value of equal values is 1. - return resolver(row_index + 1 == group_end || !permuted_equal(row_index, row_index + 1), - row_index - group_start); - } + auto const comparator_helper = [&](auto const d_equal) { + auto const permuted_equal = + permuted_row_equality_comparator(d_equal, value_order.begin()); + + thrust::tabulate(rmm::exec_policy(stream), + mutable_ranks.begin(), + mutable_ranks.end(), + unique_identifier( + group_labels.begin(), group_offsets.begin(), permuted_equal, resolver)); }; - thrust::tabulate(rmm::exec_policy(stream), - mutable_ranks.begin(), - mutable_ranks.end(), - unique_identifier); + + if (cudf::detail::has_nested_columns(grouped_values_view)) { + auto const d_equal = + comparator.equal_to(cudf::nullate::DYNAMIC{has_nulls}, null_equality::EQUAL); + comparator_helper(d_equal); + } else { + auto const d_equal = + comparator.equal_to(cudf::nullate::DYNAMIC{has_nulls}, null_equality::EQUAL); + comparator_helper(d_equal); + } auto [group_labels_begin, mutable_rank_begin] = [&]() { if constexpr (forward) { diff --git a/cpp/src/groupby/sort/sort_helper.cu b/cpp/src/groupby/sort/sort_helper.cu index 3be090159a7..b53955472b1 100644 --- a/cpp/src/groupby/sort/sort_helper.cu +++ b/cpp/src/groupby/sort/sort_helper.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -149,17 +149,28 @@ sort_groupby_helper::index_vector const& sort_groupby_helper::group_offsets( _group_offsets = std::make_unique(num_keys(stream) + 1, stream); - auto const comparator = cudf::experimental::row::equality::self_comparator{_keys, stream}; - auto const d_key_equal = comparator.equal_to( - cudf::nullate::DYNAMIC{cudf::has_nested_nulls(_keys)}, null_equality::EQUAL); + auto const comparator = cudf::experimental::row::equality::self_comparator{_keys, stream}; + auto const sorted_order = key_sort_order(stream).data(); decltype(_group_offsets->begin()) result_end; - result_end = thrust::unique_copy(rmm::exec_policy(stream), - thrust::counting_iterator(0), - thrust::counting_iterator(num_keys(stream)), - _group_offsets->begin(), - permuted_row_equality_comparator(d_key_equal, sorted_order)); + if (cudf::detail::has_nested_columns(_keys)) { + auto const d_key_equal = comparator.equal_to( + cudf::nullate::DYNAMIC{cudf::has_nested_nulls(_keys)}, null_equality::EQUAL); + result_end = thrust::unique_copy(rmm::exec_policy(stream), + thrust::counting_iterator(0), + thrust::counting_iterator(num_keys(stream)), + _group_offsets->begin(), + permuted_row_equality_comparator(d_key_equal, sorted_order)); + } else { + auto const d_key_equal = comparator.equal_to( + cudf::nullate::DYNAMIC{cudf::has_nested_nulls(_keys)}, null_equality::EQUAL); + result_end = thrust::unique_copy(rmm::exec_policy(stream), + thrust::counting_iterator(0), + thrust::counting_iterator(num_keys(stream)), + _group_offsets->begin(), + permuted_row_equality_comparator(d_key_equal, sorted_order)); + } size_type num_groups = thrust::distance(_group_offsets->begin(), result_end); _group_offsets->set_element(num_groups, num_keys(stream), stream); diff --git a/cpp/src/lists/contains.cu b/cpp/src/lists/contains.cu index 0142e736fd0..05fe82d1713 100644 --- a/cpp/src/lists/contains.cu +++ b/cpp/src/lists/contains.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -267,7 +267,7 @@ void index_of_nested_types(InputIterator input_it, auto const has_nulls = has_nested_nulls(child_tview) || has_nested_nulls(keys_tview); auto const comparator = cudf::experimental::row::equality::two_table_comparator(child_tview, keys_tview, stream); - auto const d_comp = comparator.equal_to(nullate::DYNAMIC{has_nulls}); + auto const d_comp = comparator.equal_to(nullate::DYNAMIC{has_nulls}); auto const do_search = [=](auto const key_validity_iter) { thrust::transform( diff --git a/cpp/src/reductions/scan/rank_scan.cu b/cpp/src/reductions/scan/rank_scan.cu index c6909bfd601..538763099d3 100644 --- a/cpp/src/reductions/scan/rank_scan.cu +++ b/cpp/src/reductions/scan/rank_scan.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,6 +32,23 @@ namespace cudf { namespace detail { namespace { +template +struct rank_equality_functor { + rank_equality_functor(device_comparator_type comparator, value_resolver resolver) + : _comparator(comparator), _resolver(resolver) + { + } + + auto __device__ operator()(size_type row_index) const noexcept + { + return _resolver(row_index == 0 || !_comparator(row_index, row_index - 1), row_index); + } + + private: + device_comparator_type _comparator; + value_resolver _resolver; +}; + /** * @brief generate row ranks or dense ranks using a row comparison then scan the results * @@ -51,20 +68,30 @@ std::unique_ptr rank_generator(column_view const& order_by, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { - auto comp = cudf::experimental::row::equality::self_comparator(table_view{{order_by}}, stream); - auto const device_comparator = - comp.equal_to(nullate::DYNAMIC{has_nested_nulls(table_view({order_by}))}); + auto const order_by_tview = table_view{{order_by}}; + auto comp = cudf::experimental::row::equality::self_comparator(order_by_tview, stream); + auto ranks = make_fixed_width_column( data_type{type_to_id()}, order_by.size(), mask_state::UNALLOCATED, stream, mr); auto mutable_ranks = ranks->mutable_view(); - thrust::tabulate(rmm::exec_policy(stream), - mutable_ranks.begin(), - mutable_ranks.end(), - [comparator = device_comparator, resolver] __device__(size_type row_index) { - return resolver(row_index == 0 || !comparator(row_index, row_index - 1), - row_index); - }); + auto const comparator_helper = [&](auto const device_comparator) { + thrust::tabulate(rmm::exec_policy(stream), + mutable_ranks.begin(), + mutable_ranks.end(), + rank_equality_functor( + device_comparator, resolver)); + }; + + if (cudf::detail::has_nested_columns(order_by_tview)) { + auto const device_comparator = + comp.equal_to(nullate::DYNAMIC{has_nested_nulls(table_view({order_by}))}); + comparator_helper(device_comparator); + } else { + auto const device_comparator = + comp.equal_to(nullate::DYNAMIC{has_nested_nulls(table_view({order_by}))}); + comparator_helper(device_comparator); + } thrust::inclusive_scan(rmm::exec_policy(stream), mutable_ranks.begin(), diff --git a/cpp/src/search/contains_scalar.cu b/cpp/src/search/contains_scalar.cu index 8c500e1e757..093a1f8f1ed 100644 --- a/cpp/src/search/contains_scalar.cu +++ b/cpp/src/search/contains_scalar.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -99,7 +99,6 @@ struct contains_scalar_dispatch { auto const comparator = cudf::experimental::row::equality::two_table_comparator(haystack_tv, needle_tv, stream); - auto const d_comp = comparator.equal_to(nullate::DYNAMIC{has_nulls}); auto const begin = cudf::experimental::row::lhs_iterator(0); auto const end = begin + haystack.size(); @@ -108,6 +107,7 @@ struct contains_scalar_dispatch { auto const check_nulls = haystack.has_nulls(); auto const haystack_cdv_ptr = column_device_view::create(haystack, stream); + auto const d_comp = comparator.equal_to(nullate::DYNAMIC{has_nulls}); return thrust::count_if( rmm::exec_policy(stream), begin, diff --git a/cpp/src/search/contains_table.cu b/cpp/src/search/contains_table.cu index 639dc503ce4..c1cc4659a19 100644 --- a/cpp/src/search/contains_table.cu +++ b/cpp/src/search/contains_table.cu @@ -204,27 +204,45 @@ rmm::device_uvector contains_with_lists_or_nans(table_view const& haystack auto const bitmask_buffer_and_ptr = build_row_bitmask(haystack, stream); auto const row_bitmask_ptr = bitmask_buffer_and_ptr.second; - // Insert only rows that do not have any null at any level. auto const insert_map = [&](auto const value_comp) { - auto const d_eqcomp = strong_index_comparator_adapter{ - comparator.equal_to(nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, value_comp)}; - map.insert_if(haystack_it, - haystack_it + haystack.num_rows(), - thrust::counting_iterator(0), // stencil - row_is_valid{row_bitmask_ptr}, - d_hasher, - d_eqcomp, - stream.value()); + if (cudf::detail::has_nested_columns(haystack)) { + auto const d_eqcomp = strong_index_comparator_adapter{comparator.equal_to( + nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, value_comp)}; + map.insert_if(haystack_it, + haystack_it + haystack.num_rows(), + thrust::counting_iterator(0), // stencil + row_is_valid{row_bitmask_ptr}, + d_hasher, + d_eqcomp, + stream.value()); + } else { + auto const d_eqcomp = strong_index_comparator_adapter{comparator.equal_to( + nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, value_comp)}; + map.insert_if(haystack_it, + haystack_it + haystack.num_rows(), + thrust::counting_iterator(0), // stencil + row_is_valid{row_bitmask_ptr}, + d_hasher, + d_eqcomp, + stream.value()); + } }; + // Insert only rows that do not have any null at any level. dispatch_nan_comparator(compare_nans, insert_map); - } else { // haystack_doesn't_have_nulls || compare_nulls == null_equality::EQUAL auto const insert_map = [&](auto const value_comp) { - auto const d_eqcomp = strong_index_comparator_adapter{ - comparator.equal_to(nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, value_comp)}; - map.insert( - haystack_it, haystack_it + haystack.num_rows(), d_hasher, d_eqcomp, stream.value()); + if (cudf::detail::has_nested_columns(haystack)) { + auto const d_eqcomp = strong_index_comparator_adapter{comparator.equal_to( + nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, value_comp)}; + map.insert( + haystack_it, haystack_it + haystack.num_rows(), d_hasher, d_eqcomp, stream.value()); + } else { + auto const d_eqcomp = strong_index_comparator_adapter{comparator.equal_to( + nullate::DYNAMIC{haystack_has_nulls}, compare_nulls, value_comp)}; + map.insert( + haystack_it, haystack_it + haystack.num_rows(), d_hasher, d_eqcomp, stream.value()); + } }; dispatch_nan_comparator(compare_nans, insert_map); @@ -247,14 +265,25 @@ rmm::device_uvector contains_with_lists_or_nans(table_view const& haystack cudf::experimental::row::equality::two_table_comparator(haystack, needles, stream); auto const check_contains = [&](auto const value_comp) { - auto const d_eqcomp = - comparator.equal_to(nullate::DYNAMIC{has_any_nulls}, compare_nulls, value_comp); - map.contains(needles_it, - needles_it + needles.num_rows(), - contained.begin(), - d_hasher, - d_eqcomp, - stream.value()); + if (cudf::detail::has_nested_columns(haystack) or cudf::detail::has_nested_columns(needles)) { + auto const d_eqcomp = + comparator.equal_to(nullate::DYNAMIC{has_any_nulls}, compare_nulls, value_comp); + map.contains(needles_it, + needles_it + needles.num_rows(), + contained.begin(), + d_hasher, + d_eqcomp, + stream.value()); + } else { + auto const d_eqcomp = + comparator.equal_to(nullate::DYNAMIC{has_any_nulls}, compare_nulls, value_comp); + map.contains(needles_it, + needles_it + needles.num_rows(), + contained.begin(), + d_hasher, + d_eqcomp, + stream.value()); + } }; dispatch_nan_comparator(compare_nans, check_contains); diff --git a/cpp/src/sort/rank.cu b/cpp/src/sort/rank.cu index 461e978643f..fd65e38d467 100644 --- a/cpp/src/sort/rank.cu +++ b/cpp/src/sort/rank.cu @@ -48,6 +48,24 @@ namespace cudf { namespace detail { namespace { +template +struct unique_functor { + unique_functor(PermutationIteratorType permute, DeviceComparatorType device_comparator) + : _permute(permute), _device_comparator(device_comparator) + { + } + + auto __device__ operator()(size_type index) const noexcept + { + return static_cast(index == 0 || + not _device_comparator(_permute[index], _permute[index - 1])); + } + + private: + PermutationIteratorType _permute; + DeviceComparatorType _device_comparator; +}; + // Assign rank from 1 to n unique values. Equal values get same rank value. rmm::device_uvector sorted_dense_rank(column_view input_col, column_view sorted_order_view, @@ -55,21 +73,37 @@ rmm::device_uvector sorted_dense_rank(column_view input_col, { auto const t_input = table_view{{input_col}}; auto const comparator = cudf::experimental::row::equality::self_comparator{t_input, stream}; - auto const device_comparator = comparator.equal_to(nullate::DYNAMIC{has_nested_nulls(t_input)}); auto const sorted_index_order = thrust::make_permutation_iterator( sorted_order_view.begin(), thrust::make_counting_iterator(0)); - auto conv = [permute = sorted_index_order, device_comparator] __device__(size_type index) { - return static_cast(index == 0 || - not device_comparator(permute[index], permute[index - 1])); - }; - auto const unique_it = cudf::detail::make_counting_transform_iterator(0, conv); auto const input_size = input_col.size(); rmm::device_uvector dense_rank_sorted(input_size, stream); - thrust::inclusive_scan( - rmm::exec_policy(stream), unique_it, unique_it + input_size, dense_rank_sorted.data()); + auto const comparator_helper = [&](auto const device_comparator) { + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(input_size), + dense_rank_sorted.data(), + unique_functor{ + sorted_index_order, device_comparator}); + }; + + if (cudf::detail::has_nested_columns(t_input)) { + auto const device_comparator = + comparator.equal_to(nullate::DYNAMIC{has_nested_nulls(t_input)}); + comparator_helper(device_comparator); + } else { + auto const device_comparator = + comparator.equal_to(nullate::DYNAMIC{has_nested_nulls(t_input)}); + comparator_helper(device_comparator); + } + + thrust::inclusive_scan(rmm::exec_policy(stream), + dense_rank_sorted.begin(), + dense_rank_sorted.end(), + dense_rank_sorted.data()); + return dense_rank_sorted; } diff --git a/cpp/src/stream_compaction/distinct.cu b/cpp/src/stream_compaction/distinct.cu index 8f462f58e4e..e15d54b4251 100644 --- a/cpp/src/stream_compaction/distinct.cu +++ b/cpp/src/stream_compaction/distinct.cu @@ -55,7 +55,8 @@ rmm::device_uvector get_distinct_indices(table_view const& input, auto const preprocessed_input = cudf::experimental::row::hash::preprocessed_table::create(input, stream); - auto const has_nulls = nullate::DYNAMIC{cudf::has_nested_nulls(input)}; + auto const has_nulls = nullate::DYNAMIC{cudf::has_nested_nulls(input)}; + auto const has_nested_columns = cudf::detail::has_nested_columns(input); auto const row_hasher = cudf::experimental::row::hash::row_hasher(preprocessed_input); auto const key_hasher = experimental::compaction_hash(row_hasher.device_hasher(has_nulls)); @@ -66,8 +67,13 @@ rmm::device_uvector get_distinct_indices(table_view const& input, size_type{0}, [] __device__(size_type const i) { return cuco::make_pair(i, i); }); auto const insert_keys = [&](auto const value_comp) { - auto const key_equal = row_comp.equal_to(has_nulls, nulls_equal, value_comp); - map.insert(pair_iter, pair_iter + input.num_rows(), key_hasher, key_equal, stream.value()); + if (has_nested_columns) { + auto const key_equal = row_comp.equal_to(has_nulls, nulls_equal, value_comp); + map.insert(pair_iter, pair_iter + input.num_rows(), key_hasher, key_equal, stream.value()); + } else { + auto const key_equal = row_comp.equal_to(has_nulls, nulls_equal, value_comp); + map.insert(pair_iter, pair_iter + input.num_rows(), key_hasher, key_equal, stream.value()); + } }; if (nans_equal == nan_equality::ALL_EQUAL) { @@ -92,6 +98,7 @@ rmm::device_uvector get_distinct_indices(table_view const& input, std::move(preprocessed_input), input.num_rows(), has_nulls, + has_nested_columns, keep, nulls_equal, nans_equal, diff --git a/cpp/src/stream_compaction/distinct_reduce.cu b/cpp/src/stream_compaction/distinct_reduce.cu index 468561273b3..020e6a495bc 100644 --- a/cpp/src/stream_compaction/distinct_reduce.cu +++ b/cpp/src/stream_compaction/distinct_reduce.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -93,6 +93,7 @@ rmm::device_uvector hash_reduce_by_row( std::shared_ptr const preprocessed_input, size_type num_rows, cudf::nullate::DYNAMIC has_nulls, + bool has_nested_columns, duplicate_keep_option keep, null_equality nulls_equal, nan_equality nans_equal, @@ -115,13 +116,23 @@ rmm::device_uvector hash_reduce_by_row( auto const row_comp = cudf::experimental::row::equality::self_comparator(preprocessed_input); auto const reduce_by_row = [&](auto const value_comp) { - auto const key_equal = row_comp.equal_to(has_nulls, nulls_equal, value_comp); - thrust::for_each( - rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(num_rows), - reduce_by_row_fn{ - map.get_device_view(), key_hasher, key_equal, keep, reduction_results.begin()}); + if (has_nested_columns) { + auto const key_equal = row_comp.equal_to(has_nulls, nulls_equal, value_comp); + thrust::for_each( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), + reduce_by_row_fn{ + map.get_device_view(), key_hasher, key_equal, keep, reduction_results.begin()}); + } else { + auto const key_equal = row_comp.equal_to(has_nulls, nulls_equal, value_comp); + thrust::for_each( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), + reduce_by_row_fn{ + map.get_device_view(), key_hasher, key_equal, keep, reduction_results.begin()}); + } }; if (nans_equal == nan_equality::ALL_EQUAL) { diff --git a/cpp/src/stream_compaction/distinct_reduce.cuh b/cpp/src/stream_compaction/distinct_reduce.cuh index c8a0c2869c8..e360d03280a 100644 --- a/cpp/src/stream_compaction/distinct_reduce.cuh +++ b/cpp/src/stream_compaction/distinct_reduce.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -65,6 +65,7 @@ auto constexpr reduction_init_value(duplicate_keep_option keep) * comparisons * @param num_rows The number of all input rows * @param has_nulls Indicate whether the input rows has any nulls at any nested levels + * @param has_nested_columns Indicates whether the input table has any nested columns * @param keep The parameter to determine what type of reduction to perform * @param nulls_equal Flag to specify whether null elements should be considered as equal * @param stream CUDA stream used for device memory operations and kernel launches @@ -76,6 +77,7 @@ rmm::device_uvector hash_reduce_by_row( std::shared_ptr const preprocessed_input, size_type num_rows, cudf::nullate::DYNAMIC has_nulls, + bool has_nested_columns, duplicate_keep_option keep, null_equality nulls_equal, nan_equality nans_equal, diff --git a/cpp/src/stream_compaction/unique.cu b/cpp/src/stream_compaction/unique.cu index 369b63995e3..511a7b7ae1c 100644 --- a/cpp/src/stream_compaction/unique.cu +++ b/cpp/src/stream_compaction/unique.cu @@ -65,28 +65,40 @@ std::unique_ptr
unique(table_view const& input, auto mutable_view = mutable_column_device_view::create(*unique_indices, stream); auto keys_view = input.select(keys); - auto comp = cudf::experimental::row::equality::self_comparator(keys_view, stream); - auto row_equal = comp.equal_to(nullate::DYNAMIC{has_nested_nulls(keys_view)}, nulls_equal); + auto comp = cudf::experimental::row::equality::self_comparator(keys_view, stream); - // get indices of unique rows - auto result_end = unique_copy(thrust::counting_iterator(0), - thrust::counting_iterator(num_rows), - mutable_view->begin(), - row_equal, - keep, - stream); - auto indices_view = - cudf::detail::slice(column_view(*unique_indices), - 0, - thrust::distance(mutable_view->begin(), result_end)); + auto const comparator_helper = [&](auto const row_equal) { + // get indices of unique rows + auto result_end = unique_copy(thrust::counting_iterator(0), + thrust::counting_iterator(num_rows), + mutable_view->begin(), + row_equal, + keep, + stream); - // gather unique rows and return - return detail::gather(input, - indices_view, - out_of_bounds_policy::DONT_CHECK, - detail::negative_index_policy::NOT_ALLOWED, - stream, - mr); + auto indices_view = + cudf::detail::slice(column_view(*unique_indices), + 0, + thrust::distance(mutable_view->begin(), result_end)); + + // gather unique rows and return + return detail::gather(input, + indices_view, + out_of_bounds_policy::DONT_CHECK, + detail::negative_index_policy::NOT_ALLOWED, + stream, + mr); + }; + + if (cudf::detail::has_nested_columns(keys_view)) { + auto row_equal = + comp.equal_to(nullate::DYNAMIC{has_nested_nulls(keys_view)}, nulls_equal); + return comparator_helper(row_equal); + } else { + auto row_equal = + comp.equal_to(nullate::DYNAMIC{has_nested_nulls(keys_view)}, nulls_equal); + return comparator_helper(row_equal); + } } } // namespace detail diff --git a/cpp/src/transform/one_hot_encode.cu b/cpp/src/transform/one_hot_encode.cu index 8f0a44585bf..3f3dd422f9d 100644 --- a/cpp/src/transform/one_hot_encode.cu +++ b/cpp/src/transform/one_hot_encode.cu @@ -36,6 +36,25 @@ namespace cudf { namespace detail { +template +struct ohe_equality_functor { + ohe_equality_functor(size_type input_size, DeviceComparatorType d_equal) + : _input_size(input_size), _d_equal(d_equal) + { + } + + auto __device__ operator()(size_type i) const noexcept + { + auto const element_index = cudf::experimental::row::lhs_index_type{i % _input_size}; + auto const category_index = cudf::experimental::row::rhs_index_type{i / _input_size}; + return _d_equal(element_index, category_index); + } + + private: + size_type _input_size; + DeviceComparatorType _d_equal; +}; + std::pair, table_view> one_hot_encode(column_view const& input, column_view const& categories, rmm::cuda_stream_view stream, @@ -59,19 +78,24 @@ std::pair, table_view> one_hot_encode(column_view const& auto const t_rhs = table_view{{categories}}; auto const comparator = cudf::experimental::row::equality::two_table_comparator{t_lhs, t_rhs, stream}; - auto const d_equal = - comparator.equal_to(nullate::DYNAMIC{has_nested_nulls(t_lhs) || has_nested_nulls(t_rhs)}); - - thrust::transform( - rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(total_size), - all_encodings->mutable_view().begin(), - [input_size = input.size(), d_equal] __device__(size_type i) { - auto const element_index = cudf::experimental::row::lhs_index_type{i % input_size}; - auto const category_index = cudf::experimental::row::rhs_index_type{i / input_size}; - return d_equal(element_index, category_index); - }); + + auto const comparator_helper = [&](auto const d_equal) { + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(total_size), + all_encodings->mutable_view().begin(), + ohe_equality_functor(input.size(), d_equal)); + }; + + if (cudf::detail::has_nested_columns(t_lhs) or cudf::detail::has_nested_columns(t_rhs)) { + auto const d_equal = comparator.equal_to( + nullate::DYNAMIC{has_nested_nulls(t_lhs) || has_nested_nulls(t_rhs)}); + comparator_helper(d_equal); + } else { + auto const d_equal = comparator.equal_to( + nullate::DYNAMIC{has_nested_nulls(t_lhs) || has_nested_nulls(t_rhs)}); + comparator_helper(d_equal); + } auto const split_iter = make_counting_transform_iterator(1, [width = input.size()](auto i) { return i * width; }); diff --git a/cpp/tests/table/experimental_row_operator_tests.cu b/cpp/tests/table/experimental_row_operator_tests.cu index ae55275aaec..1f3f7eefe79 100644 --- a/cpp/tests/table/experimental_row_operator_tests.cu +++ b/cpp/tests/table/experimental_row_operator_tests.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -115,18 +115,32 @@ auto self_equality(cudf::table_view input, rmm::cuda_stream_view stream{cudf::get_default_stream()}; auto const table_comparator = cudf::experimental::row::equality::self_comparator{input, stream}; - auto const equal_comparator = - table_comparator.equal_to(cudf::nullate::NO{}, cudf::null_equality::EQUAL, comparator); auto output = cudf::make_numeric_column( cudf::data_type(cudf::type_id::BOOL8), input.num_rows(), cudf::mask_state::UNALLOCATED); - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(input.num_rows()), - thrust::make_counting_iterator(0), - output->mutable_view().data(), - equal_comparator); + if (cudf::detail::has_nested_columns(input)) { + auto const equal_comparator = + table_comparator.equal_to(cudf::nullate::NO{}, cudf::null_equality::EQUAL, comparator); + + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(input.num_rows()), + thrust::make_counting_iterator(0), + output->mutable_view().data(), + equal_comparator); + } else { + auto const equal_comparator = + table_comparator.equal_to(cudf::nullate::NO{}, cudf::null_equality::EQUAL, comparator); + + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(input.num_rows()), + thrust::make_counting_iterator(0), + output->mutable_view().data(), + equal_comparator); + } + return output; } @@ -140,20 +154,34 @@ auto two_table_equality(cudf::table_view lhs, auto const table_comparator = cudf::experimental::row::equality::two_table_comparator{lhs, rhs, stream}; - auto const equal_comparator = - table_comparator.equal_to(cudf::nullate::NO{}, cudf::null_equality::EQUAL, comparator); + auto const lhs_it = cudf::experimental::row::lhs_iterator(0); auto const rhs_it = cudf::experimental::row::rhs_iterator(0); auto output = cudf::make_numeric_column( cudf::data_type(cudf::type_id::BOOL8), lhs.num_rows(), cudf::mask_state::UNALLOCATED); - thrust::transform(rmm::exec_policy(stream), - lhs_it, - lhs_it + lhs.num_rows(), - rhs_it, - output->mutable_view().data(), - equal_comparator); + if (cudf::detail::has_nested_columns(lhs) or cudf::detail::has_nested_columns(rhs)) { + auto const equal_comparator = + table_comparator.equal_to(cudf::nullate::NO{}, cudf::null_equality::EQUAL, comparator); + + thrust::transform(rmm::exec_policy(stream), + lhs_it, + lhs_it + lhs.num_rows(), + rhs_it, + output->mutable_view().data(), + equal_comparator); + } else { + auto const equal_comparator = + table_comparator.equal_to(cudf::nullate::NO{}, cudf::null_equality::EQUAL, comparator); + + thrust::transform(rmm::exec_policy(stream), + lhs_it, + lhs_it + lhs.num_rows(), + rhs_it, + output->mutable_view().data(), + equal_comparator); + } return output; }