Skip to content

Commit

Permalink
Optimized left_semi_join
Browse files Browse the repository at this point in the history
Up to 20x faster. Separated the hash table lookup from copy_if because
performing the lookup inside the copy_if predicate increased register
usage, which significantly limited the occupancy of that kernel.
  • Loading branch information
Xavier Simmons committed Apr 4, 2022
1 parent 291fbcf commit 52df19b
Showing 1 changed file with 21 additions and 10 deletions.
31 changes: 21 additions & 10 deletions cpp/src/join/semi_join.cu
Original file line number Diff line number Diff line change
Expand Up @@ -137,18 +137,29 @@ std::unique_ptr<rmm::device_uvector<cudf::size_type>> left_semi_anti_join(
auto gather_map =
std::make_unique<rmm::device_uvector<cudf::size_type>>(left_num_rows, stream, mr);

rmm::device_uvector<bool> flagged(left_num_rows, stream, mr);
auto flagged_d = flagged.data();

auto counting_iter = thrust::counting_iterator<size_type>(0);
thrust::for_each(
rmm::exec_policy(stream),
counting_iter,
counting_iter + left_num_rows,
[flagged_d, hash_table_view, join_type_boolean, hash_probe, equality_probe] __device__(
const size_type idx) {
flagged_d[idx] =
hash_table_view.contains(idx, hash_probe, equality_probe) == join_type_boolean;
});

// gather_map_end will be the end of valid data in gather_map
auto gather_map_end = thrust::copy_if(
rmm::exec_policy(stream),
thrust::make_counting_iterator(0),
thrust::make_counting_iterator(left_num_rows),
gather_map->begin(),
[hash_table_view, join_type_boolean, hash_probe, equality_probe] __device__(
size_type const idx) {
// Look up this row. The hash function used here needs to map a (left) row index to the hash
// of the row, so it's a row hash. The equality check needs to verify
return hash_table_view.contains(idx, hash_probe, equality_probe) == join_type_boolean;
});
rmm::exec_policy(stream),
thrust::make_counting_iterator(0),
thrust::make_counting_iterator(left_num_rows),
gather_map->begin(),
[flagged_d]__device__(size_type const idx) {
return flagged_d[idx];
});

auto join_size = thrust::distance(gather_map->begin(), gather_map_end);
gather_map->resize(join_size, stream);
Expand Down

0 comments on commit 52df19b

Please sign in to comment.