Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor conditional joins #8815

Merged
merged 10 commits into from
Jul 28, 2021
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ add_library(cudf
src/jit/cache.cpp
src/jit/parser.cpp
src/jit/type.cpp
src/join/conditional_join.cu
src/join/cross_join.cu
src/join/hash_join.cu
src/join/join.cu
Expand Down
6 changes: 2 additions & 4 deletions cpp/include/cudf/ast/detail/linearizer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@
#pragma once

#include <cudf/ast/operators.hpp>
#include <cudf/scalar/scalar.hpp>
#include <cudf/scalar/scalar_device_view.cuh>
#include <cudf/table/table_view.hpp>
#include <cudf/types.hpp>
#include <cudf/utilities/error.hpp>

namespace cudf {
namespace ast {
Expand Down Expand Up @@ -108,7 +106,7 @@ class linearizer {
* @param right The right table used for evaluating the abstract syntax tree.
*/
linearizer(detail::node const& expr, cudf::table_view left, cudf::table_view right)
: _left(left), _right(right), _node_count(0), _intermediate_counter()
: _left{left}, _right{right}, _node_count{0}, _intermediate_counter{}
{
expr.accept(*this);
}
Expand All @@ -120,7 +118,7 @@ class linearizer {
* @param table The table used for evaluating the abstract syntax tree.
*/
linearizer(detail::node const& expr, cudf::table_view table)
: _left(table), _right(table), _node_count(0), _intermediate_counter()
: _left{table}, _right{table}, _node_count{0}, _intermediate_counter{}
{
expr.accept(*this);
}
Expand Down
3 changes: 0 additions & 3 deletions cpp/include/cudf/ast/detail/transform.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@

#include <thrust/optional.h>

#include <cstring>
#include <numeric>

namespace cudf {

namespace ast {
Expand Down
20 changes: 20 additions & 0 deletions cpp/include/cudf/join.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,9 @@ class hash_join {
* The corresponding values in the second returned vector are
* the matched row indices from the right table.
*
* If the provided predicate returns NULL for a pair of rows
* (left, right), that pair is not included in the output.
*
* @code{.pseudo}
* Left: {{0, 1, 2}}
* Right: {{1, 2, 3}}
Expand All @@ -672,6 +675,7 @@ class hash_join {
*
* @throw cudf::logic_error if number of elements in `left_keys` or `right_keys`
* mismatch.
* @throw cudf::logic_error if the binary predicate outputs a non-boolean result.
*
* @param left The left table
* @param right The right table
Expand Down Expand Up @@ -702,6 +706,9 @@ conditional_inner_join(
* from the right table, if there is a match or (2) an unspecified
* out-of-bounds value.
*
* If the provided predicate returns NULL for a pair of rows
* (left, right), that pair is not included in the output.
*
* @code{.pseudo}
* Left: {{0, 1, 2}}
* Right: {{1, 2, 3}}
Expand All @@ -716,6 +723,7 @@ conditional_inner_join(
*
* @throw cudf::logic_error if number of elements in `left_keys` or `right_keys`
* mismatch.
* @throw cudf::logic_error if the binary predicate outputs a non-boolean result.
*
* @param left The left table
* @param right The right table
Expand Down Expand Up @@ -744,6 +752,9 @@ conditional_left_join(table_view left,
* right tables, (2) a row index and an unspecified out-of-bounds value,
* representing a row from one table without a match in the other.
*
* If the provided predicate returns NULL for a pair of rows
* (left, right), that pair is not included in the output.
*
* @code{.pseudo}
* Left: {{0, 1, 2}}
* Right: {{1, 2, 3}}
Expand All @@ -758,6 +769,7 @@ conditional_left_join(table_view left,
*
* @throw cudf::logic_error if number of elements in `left_keys` or `right_keys`
* mismatch.
* @throw cudf::logic_error if the binary predicate outputs a non-boolean result.
*
* @param left The left table
* @param right The right table
Expand All @@ -781,6 +793,9 @@ conditional_full_join(table_view left,
* for which there exists some row in the right table where the predicate
* evaluates to true.
*
* If the provided predicate returns NULL for a pair of rows
* (left, right), that pair is not included in the output.
*
* @code{.pseudo}
* Left: {{0, 1, 2}}
* Right: {{1, 2, 3}}
Expand All @@ -795,6 +810,7 @@ conditional_full_join(table_view left,
*
* @throw cudf::logic_error if number of elements in `left_keys` or `right_keys`
* mismatch.
* @throw cudf::logic_error if the binary predicate outputs a non-boolean result.
*
* @param left The left table
* @param right The right table
Expand All @@ -818,6 +834,9 @@ std::unique_ptr<rmm::device_uvector<size_type>> conditional_left_semi_join(
* for which there does not exist any row in the right table where the
* predicate evaluates to true.
*
* If the provided predicate returns NULL for a pair of rows
* (left, right), that pair is not included in the output.
*
* @code{.pseudo}
* Left: {{0, 1, 2}}
* Right: {{1, 2, 3}}
Expand All @@ -832,6 +851,7 @@ std::unique_ptr<rmm::device_uvector<size_type>> conditional_left_semi_join(
*
* @throw cudf::logic_error if number of elements in `left_keys` or `right_keys`
* mismatch.
* @throw cudf::logic_error if the binary predicate outputs a non-boolean result.
*
* @param left The left table
* @param right The right table
Expand Down
5 changes: 5 additions & 0 deletions cpp/include/cudf/table/table_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,11 @@ class mutable_table_view : public detail::table_view_base<mutable_column_view> {
mutable_table_view(std::vector<mutable_table_view> const& views);
};

/**
 * @brief Checks whether any column of a table view is nullable.
 *
 * @param view The table view whose columns are inspected
 * @return true if at least one column in `view` reports `nullable()`,
 * false otherwise (including when `view` has no columns)
 */
inline bool nullable(table_view const& view)
{
  auto const is_nullable = [](auto const& column) { return column.nullable(); };
  return std::any_of(view.begin(), view.end(), is_nullable);
}

inline bool has_nulls(table_view const& view)
{
return std::any_of(view.begin(), view.end(), [](auto const& col) { return col.has_nulls(); });
Expand Down
14 changes: 2 additions & 12 deletions cpp/src/ast/transform.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,11 @@
#include <cudf/table/table_view.hpp>
#include <cudf/types.hpp>
#include <cudf/utilities/error.hpp>
#include <cudf/utilities/traits.hpp>
#include <cudf/utilities/type_dispatcher.hpp>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_buffer.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>

#include <algorithm>
#include <functional>
#include <iterator>
#include <type_traits>

namespace cudf {
namespace ast {
namespace detail {
Expand Down Expand Up @@ -93,11 +86,8 @@ std::unique_ptr<column> compute_column(table_view const table,
// If none of the input columns actually contain nulls, we can still use the
// non-nullable version of the expression evaluation code path for
// performance, so we capture that information as well.
auto const nullable =
std::any_of(table.begin(), table.end(), [](column_view c) { return c.nullable(); });
auto const has_nulls = nullable && std::any_of(table.begin(), table.end(), [](column_view c) {
return c.nullable() && c.has_nulls();
});
auto const nullable = cudf::nullable(table);
auto const has_nulls = nullable && cudf::has_nulls(table);

auto const plan = ast_plan{expr, table, has_nulls, stream, mr};

Expand Down
130 changes: 130 additions & 0 deletions cpp/src/join/conditional_join.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <join/conditional_join.cuh>
#include <join/join_common_utils.hpp>

#include <cudf/join.hpp>
#include <cudf/table/table.hpp>
#include <cudf/table/table_view.hpp>

#include <rmm/cuda_stream_view.hpp>

namespace cudf {
namespace detail {

/**
 * @brief Common implementation behind the public conditional join entry points.
 *
 * Invokes `CUDF_FUNC_RANGE()` and then delegates all of the work to
 * `get_conditional_join_indices`, returning its result unchanged.
 *
 * @param left The left table
 * @param right The right table
 * @param binary_predicate The expression forwarded as the join condition
 * @param compare_nulls Null-equality setting, forwarded as-is
 * @param JoinKind The kind of join to perform
 * @param stream CUDA stream view, forwarded as-is
 * @param mr Device memory resource pointer, forwarded as-is
 * @return The pair of index vectors returned by `get_conditional_join_indices`
 */
std::pair<std::unique_ptr<rmm::device_uvector<size_type>>,
          std::unique_ptr<rmm::device_uvector<size_type>>>
conditional_join(table_view left,
                 table_view right,
                 ast::expression binary_predicate,
                 null_equality compare_nulls,
                 join_kind JoinKind,
                 rmm::cuda_stream_view stream,
                 rmm::mr::device_memory_resource* mr)
{
  CUDF_FUNC_RANGE();

  // Note: the callee takes the join kind *before* the predicate, which is
  // a different order from this function's own parameter list.
  return get_conditional_join_indices(left,
                                      right,
                                      JoinKind,
                                      binary_predicate,
                                      compare_nulls,
                                      stream,
                                      mr);
}

} // namespace detail

/**
 * @brief Public conditional inner join.
 *
 * Forwards to `detail::conditional_join` with `join_kind::INNER_JOIN`,
 * running on `rmm::cuda_stream_default`.
 *
 * @param left The left table
 * @param right The right table
 * @param binary_predicate The expression forwarded as the join condition
 * @param compare_nulls Null-equality setting, forwarded as-is
 * @param mr Device memory resource pointer, forwarded as-is
 * @return The pair of index vectors returned by `detail::conditional_join`
 */
std::pair<std::unique_ptr<rmm::device_uvector<size_type>>,
          std::unique_ptr<rmm::device_uvector<size_type>>>
conditional_inner_join(table_view left,
                       table_view right,
                       ast::expression binary_predicate,
                       null_equality compare_nulls,
                       rmm::mr::device_memory_resource* mr)
{
  auto constexpr kind = detail::join_kind::INNER_JOIN;
  return detail::conditional_join(
    left, right, binary_predicate, compare_nulls, kind, rmm::cuda_stream_default, mr);
}

/**
 * @brief Public conditional left join.
 *
 * Forwards to `detail::conditional_join` with `join_kind::LEFT_JOIN`,
 * running on `rmm::cuda_stream_default`.
 *
 * @param left The left table
 * @param right The right table
 * @param binary_predicate The expression forwarded as the join condition
 * @param compare_nulls Null-equality setting, forwarded as-is
 * @param mr Device memory resource pointer, forwarded as-is
 * @return The pair of index vectors returned by `detail::conditional_join`
 */
std::pair<std::unique_ptr<rmm::device_uvector<size_type>>,
          std::unique_ptr<rmm::device_uvector<size_type>>>
conditional_left_join(table_view left,
                      table_view right,
                      ast::expression binary_predicate,
                      null_equality compare_nulls,
                      rmm::mr::device_memory_resource* mr)
{
  auto constexpr kind = detail::join_kind::LEFT_JOIN;
  return detail::conditional_join(
    left, right, binary_predicate, compare_nulls, kind, rmm::cuda_stream_default, mr);
}

/**
 * @brief Public conditional full join.
 *
 * Forwards to `detail::conditional_join` with `join_kind::FULL_JOIN`,
 * running on `rmm::cuda_stream_default`.
 *
 * @param left The left table
 * @param right The right table
 * @param binary_predicate The expression forwarded as the join condition
 * @param compare_nulls Null-equality setting, forwarded as-is
 * @param mr Device memory resource pointer, forwarded as-is
 * @return The pair of index vectors returned by `detail::conditional_join`
 */
std::pair<std::unique_ptr<rmm::device_uvector<size_type>>,
          std::unique_ptr<rmm::device_uvector<size_type>>>
conditional_full_join(table_view left,
                      table_view right,
                      ast::expression binary_predicate,
                      null_equality compare_nulls,
                      rmm::mr::device_memory_resource* mr)
{
  auto constexpr kind = detail::join_kind::FULL_JOIN;
  return detail::conditional_join(
    left, right, binary_predicate, compare_nulls, kind, rmm::cuda_stream_default, mr);
}

/**
 * @brief Public conditional left semi join.
 *
 * Forwards to `detail::conditional_join` with `join_kind::LEFT_SEMI_JOIN`,
 * running on `rmm::cuda_stream_default`. Only the first vector of the
 * returned pair is handed back to the caller; the second is discarded.
 *
 * @param left The left table
 * @param right The right table
 * @param binary_predicate The expression forwarded as the join condition
 * @param compare_nulls Null-equality setting, forwarded as-is
 * @param mr Device memory resource pointer, forwarded as-is
 * @return The first index vector of the pair returned by `detail::conditional_join`
 */
std::unique_ptr<rmm::device_uvector<size_type>> conditional_left_semi_join(
  table_view left,
  table_view right,
  ast::expression binary_predicate,
  null_equality compare_nulls,
  rmm::mr::device_memory_resource* mr)
{
  auto join_indices = detail::conditional_join(left,
                                               right,
                                               binary_predicate,
                                               compare_nulls,
                                               detail::join_kind::LEFT_SEMI_JOIN,
                                               rmm::cuda_stream_default,
                                               mr);
  // Transfer ownership of the first vector out of the pair.
  return std::move(join_indices.first);
}

/**
 * @brief Public conditional left anti join.
 *
 * Forwards to `detail::conditional_join` with `join_kind::LEFT_ANTI_JOIN`,
 * running on `rmm::cuda_stream_default`. Only the first vector of the
 * returned pair is handed back to the caller; the second is discarded.
 *
 * @param left The left table
 * @param right The right table
 * @param binary_predicate The expression forwarded as the join condition
 * @param compare_nulls Null-equality setting, forwarded as-is
 * @param mr Device memory resource pointer, forwarded as-is
 * @return The first index vector of the pair returned by `detail::conditional_join`
 */
std::unique_ptr<rmm::device_uvector<size_type>> conditional_left_anti_join(
  table_view left,
  table_view right,
  ast::expression binary_predicate,
  null_equality compare_nulls,
  rmm::mr::device_memory_resource* mr)
{
  auto join_indices = detail::conditional_join(left,
                                               right,
                                               binary_predicate,
                                               compare_nulls,
                                               detail::join_kind::LEFT_ANTI_JOIN,
                                               rmm::cuda_stream_default,
                                               mr);
  // Transfer ownership of the first vector out of the pair.
  return std::move(join_indices.first);
}

} // namespace cudf
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@
*/
#pragma once

#include "hash_join.cuh"
#include "join_common_utils.hpp"
#include "join_kernels.cuh"
#include <join/conditional_join_kernels.cuh>
#include <join/join_common_utils.cuh>
#include <join/join_common_utils.hpp>

#include <cudf/ast/detail/transform.cuh>
#include <cudf/ast/nodes.hpp>
#include <cudf/scalar/scalar.hpp>
#include <cudf/scalar/scalar_device_view.cuh>
#include <cudf/detail/utilities/cuda.cuh>
#include <cudf/table/table.hpp>
#include <cudf/table/table_device_view.cuh>
#include <cudf/table/table_view.hpp>
Expand All @@ -31,10 +30,6 @@
#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/per_device_resource.hpp>

#include <thrust/optional.h>

#include <algorithm>

namespace cudf {
namespace detail {

Expand Down Expand Up @@ -85,15 +80,8 @@ get_conditional_join_indices(table_view const& left,
// If none of the input columns actually contain nulls, we can still use the
// non-nullable version of the expression evaluation code path for
// performance, so we capture that information as well.
auto const nullable =
std::any_of(left.begin(), left.end(), [](column_view c) { return c.nullable(); }) ||
std::any_of(right.begin(), right.end(), [](column_view c) { return c.nullable(); });
auto const has_nulls =
nullable &&
(std::any_of(
left.begin(), left.end(), [](column_view c) { return c.nullable() && c.has_nulls(); }) ||
std::any_of(
right.begin(), right.end(), [](column_view c) { return c.nullable() && c.has_nulls(); }));
auto const nullable = cudf::nullable(left) || cudf::nullable(right);
auto const has_nulls = nullable && (cudf::has_nulls(left) || cudf::has_nulls(right));

auto const plan = ast::detail::ast_plan{binary_pred, left, right, has_nulls, stream, mr};
CUDF_EXPECTS(plan.output_type().id() == type_id::BOOL8,
Expand Down
Loading