Skip to content

Commit

Permalink
GH-20339: [C++] Add residual filter support to swiss join (#39487)
Browse files Browse the repository at this point in the history
### Rationale for this change

Add residual filter support to swiss join.

### What changes are included in this PR?

1. Added class `JoinResidualFilter` as a centralized structure to evaluate residual filter in swiss join. It has various flavors of filtering for various join types. Zero-overhead is guaranteed for trivial filters (literal true and sometimes literal false/null). More detailed explanation in code comments.
2. Tuned the structure of swiss join main body (`JoinProbeProcessor::OnNextBatch`) to better cope with `JoinResidualFilter` calls.

### Are these changes tested?

Legacy UTs (`HashJoin.Random`, `HashJoin.ResidualFilter` and `HashJoin.TrivialResidualFilter`) cover part of this change. New fine-grained residual filter cases added as well.

### Are there any user-facing changes?

No.

* Closes: #20339

Lead-authored-by: zanmato <[email protected]>
Co-authored-by: zanmato1984 <[email protected]>
Co-authored-by: Ruoxi Sun <[email protected]>
Co-authored-by: Rossi Sun <[email protected]>
Signed-off-by: Weston Pace <[email protected]>
  • Loading branch information
zanmato1984 authored Mar 12, 2024
1 parent 6121b3f commit 0ce7267
Show file tree
Hide file tree
Showing 5 changed files with 1,975 additions and 178 deletions.
197 changes: 192 additions & 5 deletions cpp/src/arrow/acero/hash_join_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ struct BenchmarkSettings {
double null_percentage = 0.0;
double cardinality = 1.0; // Proportion of distinct keys in build side
double selectivity = 1.0; // Probability of a match for a given row
int var_length_min = 2; // Minimal length of any var length types
int var_length_max = 20; // Maximum length of any var length types

Expression residual_filter = literal(true);
};

class JoinBenchmark {
Expand Down Expand Up @@ -79,8 +83,8 @@ class JoinBenchmark {
build_metadata["null_probability"] = std::to_string(settings.null_percentage);
build_metadata["min"] = std::to_string(min_build_value);
build_metadata["max"] = std::to_string(max_build_value);
build_metadata["min_length"] = "2";
build_metadata["max_length"] = "20";
build_metadata["min_length"] = settings.var_length_min;
build_metadata["max_length"] = settings.var_length_max;

std::unordered_map<std::string, std::string> probe_metadata;
probe_metadata["null_probability"] = std::to_string(settings.null_percentage);
Expand Down Expand Up @@ -126,10 +130,9 @@ class JoinBenchmark {
stats_.num_probe_rows = settings.num_probe_batches * settings.batch_size;

schema_mgr_ = std::make_unique<HashJoinSchema>();
Expression filter = literal(true);
DCHECK_OK(schema_mgr_->Init(settings.join_type, *l_batches_with_schema.schema,
left_keys, *r_batches_with_schema.schema, right_keys,
filter, "l_", "r_"));
settings.residual_filter, "l_", "r_"));

if (settings.use_basic_implementation) {
join_ = *HashJoinImpl::MakeBasic();
Expand Down Expand Up @@ -158,7 +161,7 @@ class JoinBenchmark {

DCHECK_OK(join_->Init(
&ctx_, settings.join_type, settings.num_threads, &(schema_mgr_->proj_maps[0]),
&(schema_mgr_->proj_maps[1]), std::move(key_cmp), std::move(filter),
&(schema_mgr_->proj_maps[1]), std::move(key_cmp), settings.residual_filter,
std::move(register_task_group_callback), std::move(start_task_group_callback),
[](int64_t, ExecBatch) { return Status::OK(); },
[](int64_t) { return Status::OK(); }));
Expand Down Expand Up @@ -308,6 +311,60 @@ static void BM_HashJoinBasic_NullPercentage(benchmark::State& st) {

HashJoinBasicBenchmarkImpl(st, settings);
}

template <typename... Args>
static void BM_HashJoinBasic_TrivialResidualFilter(benchmark::State& st,
JoinType join_type,
Expression residual_filter,
Args&&...) {
BenchmarkSettings settings;
settings.join_type = join_type;
settings.build_payload_types = {binary()};
settings.probe_payload_types = {binary()};

settings.use_basic_implementation = st.range(0);

settings.num_build_batches = 1024;
settings.num_probe_batches = 1024;

// Let payload column length from 1 to 100.
settings.var_length_min = 1;
settings.var_length_max = 100;

settings.residual_filter = std::move(residual_filter);

HashJoinBasicBenchmarkImpl(st, settings);
}

template <typename... Args>
static void BM_HashJoinBasic_ComplexResidualFilter(benchmark::State& st,
JoinType join_type, Args&&...) {
BenchmarkSettings settings;
settings.join_type = join_type;
settings.build_payload_types = {binary()};
settings.probe_payload_types = {binary()};

settings.use_basic_implementation = st.range(0);

settings.num_build_batches = 1024;
settings.num_probe_batches = 1024;

// Let payload column length from 1 to 100.
settings.var_length_min = 1;
settings.var_length_max = 100;

// Create filter referring payload columns from both sides.
// binary_length(probe_payload) + binary_length(build_payload) <= 2 * selectivity
settings.selectivity = static_cast<double>(st.range(1)) / 100.0;
using arrow::compute::call;
using arrow::compute::field_ref;
settings.residual_filter =
call("less_equal", {call("plus", {call("binary_length", {field_ref("lp0")}),
call("binary_length", {field_ref("rp0")})}),
literal(2 * settings.selectivity)});

HashJoinBasicBenchmarkImpl(st, settings);
}
#endif

std::vector<int64_t> hashtable_krows = benchmark::CreateRange(1, 4096, 8);
Expand Down Expand Up @@ -435,6 +492,136 @@ BENCHMARK(BM_HashJoinBasic_BuildParallelism)
BENCHMARK(BM_HashJoinBasic_NullPercentage)
->ArgNames({"Null Percentage"})
->DenseRange(0, 100, 10);

const char* use_basic_argname = "Use basic";
std::vector<int64_t> use_basic_arg = benchmark::CreateDenseRange(0, 1, 1);

std::vector<std::string> trivial_residual_filter_argnames = {use_basic_argname};
std::vector<std::vector<int64_t>> trivial_residual_filter_args = {use_basic_arg};

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Inner/Literal(true)",
JoinType::INNER, literal(true))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Left Semi/Literal(true)",
JoinType::LEFT_SEMI, literal(true))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Right Semi/Literal(true)",
JoinType::RIGHT_SEMI, literal(true))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Left Anti/Literal(true)",
JoinType::LEFT_ANTI, literal(true))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Right Anti/Literal(true)",
JoinType::RIGHT_ANTI, literal(true))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Left Outer/Literal(true)",
JoinType::LEFT_OUTER, literal(true))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Right Outer/Literal(true)",
JoinType::RIGHT_OUTER, literal(true))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Full Outer/Literal(true)",
JoinType::FULL_OUTER, literal(true))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Inner/Literal(false)",
JoinType::INNER, literal(false))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Left Semi/Literal(false)",
JoinType::LEFT_SEMI, literal(false))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Right Semi/Literal(false)",
JoinType::RIGHT_SEMI, literal(false))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Left Anti/Literal(false)",
JoinType::LEFT_ANTI, literal(false))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Right Anti/Literal(false)",
JoinType::RIGHT_ANTI, literal(false))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Left Outer/Literal(false)",
JoinType::LEFT_OUTER, literal(false))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Right Outer/Literal(false)",
JoinType::RIGHT_OUTER, literal(false))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Full Outer/Literal(false)",
JoinType::FULL_OUTER, literal(false))
->ArgNames(trivial_residual_filter_argnames)
->ArgsProduct(trivial_residual_filter_args);

std::vector<std::string> complex_residual_filter_argnames = {use_basic_argname,
"Selectivity"};
std::vector<std::vector<int64_t>> complex_residual_filter_args = {
use_basic_arg, benchmark::CreateDenseRange(0, 100, 20)};

BENCHMARK_CAPTURE(BM_HashJoinBasic_ComplexResidualFilter, "Inner", JoinType::INNER)
->ArgNames(complex_residual_filter_argnames)
->ArgsProduct(complex_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_ComplexResidualFilter, "Left Semi",
JoinType::LEFT_SEMI)
->ArgNames(complex_residual_filter_argnames)
->ArgsProduct(complex_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_ComplexResidualFilter, "Right Semi",
JoinType::RIGHT_SEMI)
->ArgNames(complex_residual_filter_argnames)
->ArgsProduct(complex_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_ComplexResidualFilter, "Left Anti",
JoinType::LEFT_ANTI)
->ArgNames(complex_residual_filter_argnames)
->ArgsProduct(complex_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_ComplexResidualFilter, "Right Anti",
JoinType::RIGHT_ANTI)
->ArgNames(complex_residual_filter_argnames)
->ArgsProduct(complex_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_ComplexResidualFilter, "Left Outer",
JoinType::LEFT_OUTER)
->ArgNames(complex_residual_filter_argnames)
->ArgsProduct(complex_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_ComplexResidualFilter, "Right Outer",
JoinType::RIGHT_OUTER)
->ArgNames(complex_residual_filter_argnames)
->ArgsProduct(complex_residual_filter_args);

BENCHMARK_CAPTURE(BM_HashJoinBasic_ComplexResidualFilter, "Full Outer",
JoinType::FULL_OUTER)
->ArgNames(complex_residual_filter_argnames)
->ArgsProduct(complex_residual_filter_args);
#else

BENCHMARK_CAPTURE(BM_HashJoinBasic_KeyTypes, "{int32}", {int32()})
Expand Down
6 changes: 2 additions & 4 deletions cpp/src/arrow/acero/hash_join_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -740,13 +740,11 @@ class HashJoinNode : public ExecNode, public TracedNode {
// Create hash join implementation object
// SwissJoin does not support:
// a) 64-bit string offsets
// b) residual predicates
// c) dictionaries
// b) dictionaries
//
bool use_swiss_join;
#if ARROW_LITTLE_ENDIAN
use_swiss_join = (filter == literal(true)) && !schema_mgr->HasDictionaries() &&
!schema_mgr->HasLargeBinary();
use_swiss_join = !schema_mgr->HasDictionaries() && !schema_mgr->HasLargeBinary();
#else
use_swiss_join = false;
#endif
Expand Down
Loading

0 comments on commit 0ce7267

Please sign in to comment.