diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index a89e23796d4b9..72226dda3a747 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -80,7 +80,8 @@ class HashJoinBasicImpl : public HashJoinImpl { Status Init(ExecContext* ctx, JoinType join_type, bool use_sync_execution, size_t num_threads, HashJoinSchema* schema_mgr, - std::vector key_cmp, OutputBatchCallback output_batch_callback, + std::vector key_cmp, Expression filter, + OutputBatchCallback output_batch_callback, FinishedCallback finished_callback, TaskScheduler::ScheduleImpl schedule_task_callback) override { num_threads = std::max(num_threads, static_cast(1)); @@ -90,6 +91,7 @@ class HashJoinBasicImpl : public HashJoinImpl { num_threads_ = num_threads; schema_mgr_ = schema_mgr; key_cmp_ = std::move(key_cmp); + filter_ = std::move(filter); output_batch_callback_ = std::move(output_batch_callback); finished_callback_ = std::move(finished_callback); local_states_.resize(num_threads); @@ -207,8 +209,6 @@ class HashJoinBasicImpl : public HashJoinImpl { for (auto it = range.first; it != range.second; ++it) { output_match_left->push_back(irow); output_match_right->push_back(it->second); - // Mark row in hash table as having a match - BitUtil::SetBit(local_state->has_match.data(), it->second); has_match = true; } if (!has_match) { @@ -227,10 +227,7 @@ class HashJoinBasicImpl : public HashJoinImpl { schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::OUTPUT); int num_out_cols_right = schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::OUTPUT); - ARROW_DCHECK((opt_left_payload == nullptr) == - (schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) == 0)); - ARROW_DCHECK((opt_right_payload == nullptr) == - (schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) == 0)); + result.values.resize(num_out_cols_left + num_out_cols_right); auto from_key = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT, HashJoinProjection::KEY); @@ -282,6 +279,126 @@ class HashJoinBasicImpl : public HashJoinImpl { num_batches_produced_++; } + Status ProbeBatch_ResidualFilter(ThreadLocalState& local_state, + std::vector& match, + std::vector& no_match, + std::vector& match_left, + std::vector& match_right) { + if (filter_ == literal(true)) { + return Status::OK(); + } + ARROW_DCHECK_EQ(match_left.size(), match_right.size()); + + ExecBatch concatenated({}, match_left.size()); + + ARROW_ASSIGN_OR_RAISE(ExecBatch left_key, local_state.exec_batch_keys.Decode( + match_left.size(), match_left.data())); + ARROW_ASSIGN_OR_RAISE( + ExecBatch right_key, + hash_table_keys_.Decode(match_right.size(), match_right.data())); + + ExecBatch left_payload; + if (!schema_mgr_->LeftPayloadIsEmpty()) { + ARROW_ASSIGN_OR_RAISE(left_payload, local_state.exec_batch_payloads.Decode( + match_left.size(), match_left.data())); + } + + ExecBatch right_payload; + if (!schema_mgr_->RightPayloadIsEmpty()) { + ARROW_ASSIGN_OR_RAISE(right_payload, hash_table_payloads_.Decode( + match_right.size(), match_right.data())); + } + + auto AppendFields = [&concatenated](const SchemaProjectionMap& to_key, + const SchemaProjectionMap& to_pay, + const ExecBatch& key, const ExecBatch& payload) { + ARROW_DCHECK(to_key.num_cols == to_pay.num_cols); + for (int i = 0; i < to_key.num_cols; i++) { + if (to_key.get(i) != SchemaProjectionMap::kMissingField) { + int key_idx = to_key.get(i); + concatenated.values.push_back(key.values[key_idx]); + } else if (to_pay.get(i) != SchemaProjectionMap::kMissingField) { + int pay_idx = to_pay.get(i); + concatenated.values.push_back(payload.values[pay_idx]); + } + } + }; + + SchemaProjectionMap left_to_key = schema_mgr_->proj_maps[0].map( + HashJoinProjection::FILTER, HashJoinProjection::KEY); + SchemaProjectionMap left_to_pay = schema_mgr_->proj_maps[0].map( + HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD); + SchemaProjectionMap right_to_key = schema_mgr_->proj_maps[1].map( + HashJoinProjection::FILTER, HashJoinProjection::KEY); + SchemaProjectionMap right_to_pay = schema_mgr_->proj_maps[1].map( + HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD); + + AppendFields(left_to_key, left_to_pay, left_key, left_payload); + AppendFields(right_to_key, right_to_pay, right_key, right_payload); + + ARROW_ASSIGN_OR_RAISE(Datum mask, + ExecuteScalarExpression(filter_, concatenated, ctx_)); + + size_t num_probed_rows = match.size() + no_match.size(); + if (mask.is_scalar()) { + const auto& mask_scalar = mask.scalar_as(); + if (mask_scalar.is_valid && mask_scalar.value) { + // All rows passed, nothing left to do + return Status::OK(); + } else { + // Nothing passed, no_match becomes everything + no_match.resize(num_probed_rows); + std::iota(no_match.begin(), no_match.end(), 0); + match_left.clear(); + match_right.clear(); + match.clear(); + return Status::OK(); + } + } + ARROW_DCHECK_EQ(mask.array()->offset, 0); + ARROW_DCHECK_EQ(mask.array()->length, static_cast(match_left.size())); + const uint8_t* validity = + mask.array()->buffers[0] ? mask.array()->buffers[0]->data() : nullptr; + const uint8_t* comparisons = mask.array()->buffers[1]->data(); + size_t num_rows = match_left.size(); + + match.clear(); + no_match.clear(); + + int32_t match_idx = 0; // current size of new match_left + int32_t irow = 0; // index into match_left + for (int32_t curr_left = 0; static_cast(curr_left) < num_probed_rows; + curr_left++) { + int32_t advance_to = static_cast(irow) < num_rows + ? match_left[irow] + : static_cast(num_probed_rows); + while (curr_left < advance_to) { + no_match.push_back(curr_left++); + } + bool passed = false; + for (; static_cast(irow) < num_rows && match_left[irow] == curr_left; + irow++) { + bool is_valid = !validity || BitUtil::GetBit(validity, irow); + bool is_cmp_true = BitUtil::GetBit(comparisons, irow); + // We treat a null comparison result as false, like in SQL + if (is_valid && is_cmp_true) { + match_left[match_idx] = match_left[irow]; + match_right[match_idx] = match_right[irow]; + match_idx++; + passed = true; + } + } + if (passed) { + match.push_back(curr_left); + } else if (static_cast(curr_left) < num_probed_rows) { + no_match.push_back(curr_left); + } + } + match_left.resize(match_idx); + match_right.resize(match_idx); + return Status::OK(); + } + Status ProbeBatch_OutputOne(size_t thread_index, int64_t batch_size_next, const int32_t* opt_left_ids, const int32_t* opt_right_ids) { if (batch_size_next == 0 || (!opt_left_ids && !opt_right_ids)) { @@ -456,6 +573,15 @@ class HashJoinBasicImpl : public HashJoinImpl { &local_state.no_match, &local_state.match_left, &local_state.match_right); + RETURN_NOT_OK(ProbeBatch_ResidualFilter(local_state, local_state.match, + local_state.no_match, local_state.match_left, + local_state.match_right)); + + for (auto i : local_state.match_right) { + // Mark row in hash table as having a match + BitUtil::SetBit(local_state.has_match.data(), i); + } + RETURN_NOT_OK(ProbeBatch_OutputAll(thread_index, local_state.exec_batch_keys, local_state.exec_batch_payloads, local_state.match, local_state.no_match, local_state.match_left, @@ -732,6 +858,7 @@ class HashJoinBasicImpl : public HashJoinImpl { size_t num_threads_; HashJoinSchema* schema_mgr_; std::vector key_cmp_; + Expression filter_; std::unique_ptr scheduler_; int task_group_build_; int task_group_queued_; diff --git a/cpp/src/arrow/compute/exec/hash_join.h b/cpp/src/arrow/compute/exec/hash_join.h index 6520e4ae4a3f3..d52d7d980f96a 100644 --- a/cpp/src/arrow/compute/exec/hash_join.h +++ b/cpp/src/arrow/compute/exec/hash_join.h @@ -35,7 +35,7 @@ class ARROW_EXPORT HashJoinSchema { public: Status Init(JoinType join_type, const Schema& left_schema, const std::vector& left_keys, const Schema& right_schema, - const std::vector& right_keys, + const std::vector& right_keys, const Expression& filter, const std::string& left_field_name_prefix, const std::string& right_field_name_prefix); @@ -43,7 +43,7 @@ class ARROW_EXPORT HashJoinSchema { const std::vector& left_keys, const std::vector& left_output, const Schema& right_schema, const std::vector& right_keys, - const std::vector& right_output, + const std::vector& right_output, const Expression& filter, const std::string& left_field_name_prefix, const std::string& right_field_name_prefix); @@ -56,9 +56,15 @@ class ARROW_EXPORT HashJoinSchema { const std::string& left_field_name_prefix, const std::string& right_field_name_prefix); + Result BindFilter(Expression filter, const Schema& left_schema, + const Schema& right_schema); std::shared_ptr MakeOutputSchema(const std::string& left_field_name_prefix, const std::string& right_field_name_prefix); + bool LeftPayloadIsEmpty() { return PayloadIsEmpty(0); } + + bool RightPayloadIsEmpty() { return PayloadIsEmpty(1); } + static int kMissingField() { return SchemaProjectionMaps::kMissingField; } @@ -67,9 +73,26 @@ class ARROW_EXPORT HashJoinSchema { private: static bool IsTypeSupported(const DataType& type); - static Result> VectorDiff(const Schema& schema, - const std::vector& a, - const std::vector& b); + + Status CollectFilterColumns(std::vector& left_filter, + std::vector& right_filter, + const Expression& filter, const Schema& left_schema, + const Schema& right_schema); + + Expression RewriteFilterToUseFilterSchema(int right_filter_offset, + const SchemaProjectionMap& left_to_filter, + const SchemaProjectionMap& right_to_filter, + const Expression& filter); + + bool PayloadIsEmpty(int side) { + ARROW_DCHECK(side == 0 || side == 1); + return proj_maps[side].num_cols(HashJoinProjection::PAYLOAD) == 0; + } + + static Result> ComputePayload(const Schema& schema, + const std::vector& output, + const std::vector& filter, + const std::vector& key); }; class HashJoinImpl { @@ -80,7 +103,7 @@ class HashJoinImpl { virtual ~HashJoinImpl() = default; virtual Status Init(ExecContext* ctx, JoinType join_type, bool use_sync_execution, size_t num_threads, HashJoinSchema* schema_mgr, - std::vector key_cmp, + std::vector key_cmp, Expression filter, OutputBatchCallback output_batch_callback, FinishedCallback finished_callback, TaskScheduler::ScheduleImpl schedule_task_callback) = 0; diff --git a/cpp/src/arrow/compute/exec/hash_join_node.cc b/cpp/src/arrow/compute/exec/hash_join_node.cc index 4bccb761070f4..51e2e97cb1ac8 100644 --- a/cpp/src/arrow/compute/exec/hash_join_node.cc +++ b/cpp/src/arrow/compute/exec/hash_join_node.cc @@ -43,32 +43,49 @@ bool HashJoinSchema::IsTypeSupported(const DataType& type) { return is_fixed_width(id) || is_binary_like(id) || is_large_binary_like(id); } -Result> HashJoinSchema::VectorDiff(const Schema& schema, - const std::vector& a, - const std::vector& b) { - std::unordered_set b_paths; - for (size_t i = 0; i < b.size(); ++i) { - ARROW_ASSIGN_OR_RAISE(auto match, b[i].FindOne(schema)); - b_paths.insert(match[0]); - } - - std::vector result; - - for (size_t i = 0; i < a.size(); ++i) { - ARROW_ASSIGN_OR_RAISE(auto match, a[i].FindOne(schema)); - bool is_found = (b_paths.find(match[0]) != b_paths.end()); - if (!is_found) { - result.push_back(a[i]); - } +Result> HashJoinSchema::ComputePayload( + const Schema& schema, const std::vector& output, + const std::vector& filter, const std::vector& keys) { + // payload = (output + filter) - keys, with no duplicates + std::unordered_set payload_fields; + for (auto ref : output) { + ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema)); + payload_fields.insert(match[0]); + } + + for (auto ref : filter) { + ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema)); + payload_fields.insert(match[0]); } - return result; + for (auto ref : keys) { + ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema)); + payload_fields.erase(match[0]); + } + + std::vector payload_refs; + for (auto ref : output) { + ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema)); + if (payload_fields.find(match[0]) != payload_fields.end()) { + payload_refs.push_back(ref); + payload_fields.erase(match[0]); + } + } + for (auto ref : filter) { + ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(schema)); + if (payload_fields.find(match[0]) != payload_fields.end()) { + payload_refs.push_back(ref); + payload_fields.erase(match[0]); + } + } + return payload_refs; } Status HashJoinSchema::Init(JoinType join_type, const Schema& left_schema, const std::vector& left_keys, const Schema& right_schema, const std::vector& right_keys, + const Expression& filter, const std::string& left_field_name_prefix, const std::string& right_field_name_prefix) { std::vector left_output; @@ -89,17 +106,15 @@ Status HashJoinSchema::Init(JoinType join_type, const Schema& left_schema, } } return Init(join_type, left_schema, left_keys, left_output, right_schema, right_keys, - right_output, left_field_name_prefix, right_field_name_prefix); + right_output, filter, left_field_name_prefix, right_field_name_prefix); } -Status HashJoinSchema::Init(JoinType join_type, const Schema& left_schema, - const std::vector& left_keys, - const std::vector& left_output, - const Schema& right_schema, - const std::vector& right_keys, - const std::vector& right_output, - const std::string& left_field_name_prefix, - const std::string& right_field_name_prefix) { +Status HashJoinSchema::Init( + JoinType join_type, const Schema& left_schema, const std::vector& left_keys, + const std::vector& left_output, const Schema& right_schema, + const std::vector& right_keys, const std::vector& right_output, + const Expression& filter, const std::string& left_field_name_prefix, + const std::string& right_field_name_prefix) { RETURN_NOT_OK(ValidateSchemas(join_type, left_schema, left_keys, left_output, right_schema, right_keys, right_output, left_field_name_prefix, right_field_name_prefix)); @@ -107,12 +122,21 @@ Status HashJoinSchema::Init(JoinType join_type, const Schema& left_schema, std::vector handles; std::vector*> field_refs; + std::vector left_filter, right_filter; + RETURN_NOT_OK( + CollectFilterColumns(left_filter, right_filter, filter, left_schema, right_schema)); + handles.push_back(HashJoinProjection::KEY); field_refs.push_back(&left_keys); + ARROW_ASSIGN_OR_RAISE(auto left_payload, - VectorDiff(left_schema, left_output, left_keys)); + ComputePayload(left_schema, left_output, left_filter, left_keys)); handles.push_back(HashJoinProjection::PAYLOAD); field_refs.push_back(&left_payload); + + handles.push_back(HashJoinProjection::FILTER); + field_refs.push_back(&left_filter); + handles.push_back(HashJoinProjection::OUTPUT); field_refs.push_back(&left_output); @@ -124,10 +148,15 @@ Status HashJoinSchema::Init(JoinType join_type, const Schema& left_schema, handles.push_back(HashJoinProjection::KEY); field_refs.push_back(&right_keys); - ARROW_ASSIGN_OR_RAISE(auto right_payload, - VectorDiff(right_schema, right_output, right_keys)); + + ARROW_ASSIGN_OR_RAISE(auto right_payload, ComputePayload(right_schema, right_output, + right_filter, right_keys)); handles.push_back(HashJoinProjection::PAYLOAD); field_refs.push_back(&right_payload); + + handles.push_back(HashJoinProjection::FILTER); + field_refs.push_back(&right_filter); + handles.push_back(HashJoinProjection::OUTPUT); field_refs.push_back(&right_output); @@ -274,17 +303,138 @@ std::shared_ptr HashJoinSchema::MakeOutputSchema( return std::make_shared(std::move(fields)); } +Result HashJoinSchema::BindFilter(Expression filter, + const Schema& left_schema, + const Schema& right_schema) { + if (filter.IsBound() || filter == literal(true)) { + return std::move(filter); + } + // Step 1: Construct filter schema + FieldVector fields; + auto left_f_to_i = + proj_maps[0].map(HashJoinProjection::FILTER, HashJoinProjection::INPUT); + auto right_f_to_i = + proj_maps[1].map(HashJoinProjection::FILTER, HashJoinProjection::INPUT); + + auto AppendFieldsInMap = [&fields](const SchemaProjectionMap& map, + const Schema& schema) { + for (int i = 0; i < map.num_cols; i++) { + int input_idx = map.get(i); + fields.push_back(schema.fields()[input_idx]); + } + }; + AppendFieldsInMap(left_f_to_i, left_schema); + AppendFieldsInMap(right_f_to_i, right_schema); + Schema filter_schema(fields); + + // Step 2: Rewrite expression to use filter schema + auto left_i_to_f = + proj_maps[0].map(HashJoinProjection::INPUT, HashJoinProjection::FILTER); + auto right_i_to_f = + proj_maps[1].map(HashJoinProjection::INPUT, HashJoinProjection::FILTER); + filter = RewriteFilterToUseFilterSchema(left_f_to_i.num_cols, left_i_to_f, right_i_to_f, + filter); + + // Step 3: Bind + ARROW_ASSIGN_OR_RAISE(filter, filter.Bind(filter_schema)); + if (filter.type()->id() != Type::BOOL) { + return Status::TypeError("Filter expression must evaluate to bool, but ", + filter.ToString(), " evaluates to ", + filter.type()->ToString()); + } + return std::move(filter); +} + +Expression HashJoinSchema::RewriteFilterToUseFilterSchema( + const int right_filter_offset, const SchemaProjectionMap& left_to_filter, + const SchemaProjectionMap& right_to_filter, const Expression& filter) { + if (const Expression::Call* c = filter.call()) { + std::vector args = c->arguments; + for (size_t i = 0; i < args.size(); i++) + args[i] = RewriteFilterToUseFilterSchema(right_filter_offset, left_to_filter, + right_to_filter, args[i]); + return call(c->function_name, args, c->options); + } else if (const FieldRef* r = filter.field_ref()) { + if (const FieldPath* path = r->field_path()) { + auto indices = path->indices(); + if (indices[0] >= left_to_filter.num_cols) { + indices[0] -= left_to_filter.num_cols; // Convert to index into right schema + indices[0] = + right_to_filter.get(indices[0]) + + right_filter_offset; // Convert right schema index to filter schema index + } else { + indices[0] = left_to_filter.get( + indices[0]); // Convert left schema index to filter schema index + } + return field_ref({std::move(indices)}); + } + } + return filter; +} + +Status HashJoinSchema::CollectFilterColumns(std::vector& left_filter, + std::vector& right_filter, + const Expression& filter, + const Schema& left_schema, + const Schema& right_schema) { + std::vector nonunique_refs = FieldsInExpression(filter); + + std::unordered_set left_seen_paths; + std::unordered_set right_seen_paths; + for (const FieldRef& ref : nonunique_refs) { + if (const FieldPath* path = ref.field_path()) { + std::vector indices = path->indices(); + if (indices[0] >= left_schema.num_fields()) { + indices[0] -= left_schema.num_fields(); + FieldPath corrected_path(std::move(indices)); + if (right_seen_paths.find(*path) == right_seen_paths.end()) { + right_filter.push_back(corrected_path); + right_seen_paths.emplace(std::move(corrected_path)); + } + } else if (left_seen_paths.find(*path) == left_seen_paths.end()) { + left_filter.push_back(ref); + left_seen_paths.emplace(std::move(indices)); + } + } else { + ARROW_DCHECK(ref.IsName()); + ARROW_ASSIGN_OR_RAISE(auto left_match, ref.FindOneOrNone(left_schema)); + ARROW_ASSIGN_OR_RAISE(auto right_match, ref.FindOneOrNone(right_schema)); + bool in_left = !left_match.empty(); + bool in_right = !right_match.empty(); + if (in_left && in_right) { + return Status::Invalid("FieldRef", ref.ToString(), + "was found in both left and right schemas"); + } else if (!in_left && !in_right) { + return Status::Invalid("FieldRef", ref.ToString(), + "was not found in either left or right schema"); + } + + ARROW_DCHECK(in_left != in_right); + auto& target_array = in_left ? left_filter : right_filter; + auto& target_set = in_left ? left_seen_paths : right_seen_paths; + auto& target_match = in_left ? left_match : right_match; + + if (target_set.find(target_match) == target_set.end()) { + target_array.push_back(ref); + target_set.emplace(std::move(target_match)); + } + } + } + return Status::OK(); +} + class HashJoinNode : public ExecNode { public: HashJoinNode(ExecPlan* plan, NodeVector inputs, const HashJoinNodeOptions& join_options, std::shared_ptr output_schema, - std::unique_ptr schema_mgr, + std::unique_ptr schema_mgr, Expression filter, std::unique_ptr impl) : ExecNode(plan, inputs, {"left", "right"}, /*output_schema=*/std::move(output_schema), /*num_outputs=*/1), join_type_(join_options.join_type), key_cmp_(join_options.key_cmp), + filter_(std::move(filter)), schema_mgr_(std::move(schema_mgr)), impl_(std::move(impl)) { complete_.store(false); @@ -300,20 +450,26 @@ class HashJoinNode : public ExecNode { const auto& join_options = checked_cast(options); + const auto& left_schema = *(inputs[0]->output_schema()); + const auto& right_schema = *(inputs[1]->output_schema()); // This will also validate input schemas if (join_options.output_all) { RETURN_NOT_OK(schema_mgr->Init( - join_options.join_type, *(inputs[0]->output_schema()), join_options.left_keys, - *(inputs[1]->output_schema()), join_options.right_keys, + join_options.join_type, left_schema, join_options.left_keys, right_schema, + join_options.right_keys, join_options.filter, join_options.output_prefix_for_left, join_options.output_prefix_for_right)); } else { RETURN_NOT_OK(schema_mgr->Init( - join_options.join_type, *(inputs[0]->output_schema()), join_options.left_keys, - join_options.left_output, *(inputs[1]->output_schema()), - join_options.right_keys, join_options.right_output, + join_options.join_type, left_schema, join_options.left_keys, + join_options.left_output, right_schema, join_options.right_keys, + join_options.right_output, join_options.filter, join_options.output_prefix_for_left, join_options.output_prefix_for_right)); } + ARROW_ASSIGN_OR_RAISE( + Expression filter, + schema_mgr->BindFilter(join_options.filter, left_schema, right_schema)); + // Generate output schema std::shared_ptr output_schema = schema_mgr->MakeOutputSchema( join_options.output_prefix_for_left, join_options.output_prefix_for_right); @@ -321,9 +477,9 @@ class HashJoinNode : public ExecNode { // Create hash join implementation object ARROW_ASSIGN_OR_RAISE(std::unique_ptr impl, HashJoinImpl::MakeBasic()); - return plan->EmplaceNode(plan, inputs, join_options, - std::move(output_schema), - std::move(schema_mgr), std::move(impl)); + return plan->EmplaceNode( + plan, inputs, join_options, std::move(output_schema), std::move(schema_mgr), + std::move(filter), std::move(impl)); } const char* kind_name() const override { return "HashJoinNode"; } @@ -385,7 +541,7 @@ class HashJoinNode : public ExecNode { RETURN_NOT_OK(impl_->Init( plan_->exec_context(), join_type_, use_sync_execution, num_threads, - schema_mgr_.get(), key_cmp_, + schema_mgr_.get(), key_cmp_, filter_, [this](ExecBatch batch) { this->OutputBatchCallback(batch); }, [this](int64_t total_num_batches) { this->FinishedCallback(total_num_batches); }, [this](std::function func) -> Status { @@ -453,6 +609,7 @@ class HashJoinNode : public ExecNode { Future<> finished_ = Future<>::MakeFinished(); JoinType join_type_; std::vector key_cmp_; + Expression filter_; ThreadIndexer thread_indexer_; std::unique_ptr schema_mgr_; std::unique_ptr impl_; diff --git a/cpp/src/arrow/compute/exec/hash_join_node_test.cc b/cpp/src/arrow/compute/exec/hash_join_node_test.cc index 40738d1e229be..481cd94d5c98b 100644 --- a/cpp/src/arrow/compute/exec/hash_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/hash_join_node_test.cc @@ -381,21 +381,6 @@ struct RandomDataTypeVector { void AddRandom(Random64Bit& rng, const RandomDataTypeConstraints& constraints) { data_types.push_back(RandomDataType::Random(rng, constraints)); } - - void Print() { - for (size_t i = 0; i < data_types.size(); ++i) { - if (!data_types[i].is_fixed_length) { - std::cout << "str[" << data_types[i].min_string_length << ".." - << data_types[i].max_string_length << "]"; - SCOPED_TRACE("str[" + std::to_string(data_types[i].min_string_length) + ".." + - std::to_string(data_types[i].max_string_length) + "]"); - } else { - std::cout << "int[" << data_types[i].fixed_length << "]"; - SCOPED_TRACE("int[" + std::to_string(data_types[i].fixed_length) + "]"); - } - } - std::cout << std::endl; - } }; std::vector> GenRandomRecords( @@ -951,7 +936,6 @@ void HashJoinWithExecPlan(Random64Bit& rng, bool parallel, TEST(HashJoin, Random) { Random64Bit rng(42); - int num_tests = 100; for (int test_id = 0; test_id < num_tests; ++test_id) { bool parallel = (rng.from_range(0, 1) == 1); @@ -1078,9 +1062,36 @@ TEST(HashJoin, Random) { continue; } + // Turn the last key comparison into a residual filter expression + Expression filter = literal(true); + if (key_cmp.size() > 1 && rng.from_range(0, 1) == 0) { + for (size_t i = 0; i < key_cmp.size(); i++) { + FieldRef left = key_fields[0][i]; + FieldRef right = key_fields[1][i]; + + if (key_cmp[i] == JoinKeyCmp::EQ) { + key_fields[0].erase(key_fields[0].begin() + i); + key_fields[1].erase(key_fields[1].begin() + i); + key_cmp.erase(key_cmp.begin() + i); + if (right.IsFieldPath()) { + auto indices = right.field_path()->indices(); + indices[0] += static_cast(shuffled_input_arrays[0].size()); + right = FieldRef{indices}; + } + + Expression left_expr(field_ref(left)); + Expression right_expr(field_ref(right)); + + filter = equal(left_expr, right_expr); + break; + } + } + } + // Run tested join implementation - HashJoinNodeOptions join_options{join_type, key_fields[0], key_fields[1], - output_fields[0], output_fields[1], key_cmp}; + HashJoinNodeOptions join_options{ + join_type, key_fields[0], key_fields[1], output_fields[0], + output_fields[1], key_cmp, filter}; std::vector> output_schema_fields; for (int i = 0; i < 2; ++i) { for (size_t col = 0; col < output_fields[i].size(); ++col) { @@ -1679,5 +1690,147 @@ TEST(HashJoin, UnsupportedTypes) { } } +TEST(HashJoin, ResidualFilter) { + for (bool parallel : {false, true}) { + SCOPED_TRACE(parallel ? "parallel/merged" : "serial"); + + BatchesWithSchema input_left; + input_left.batches = {ExecBatchFromJSON({int32(), int32(), utf8()}, R"([ + [1, 6, "alpha"], + [2, 5, "beta"], + [3, 4, "alpha"] + ])")}; + input_left.schema = + schema({field("l1", int32()), field("l2", int32()), field("l_str", utf8())}); + + BatchesWithSchema input_right; + input_right.batches = {ExecBatchFromJSON({int32(), int32(), utf8()}, R"([ + [5, 11, "alpha"], + [2, 12, "beta"], + [4, 16, "alpha"] + ])")}; + input_right.schema = + schema({field("r1", int32()), field("r2", int32()), field("r_str", utf8())}); + + auto exec_ctx = arrow::internal::make_unique( + default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr); + + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); + AsyncGenerator> sink_gen; + + ExecNode* left_source; + ExecNode* right_source; + ASSERT_OK_AND_ASSIGN( + left_source, + MakeExecNode("source", plan.get(), {}, + SourceNodeOptions{input_left.schema, + input_left.gen(parallel, /*slow=*/false)})); + + ASSERT_OK_AND_ASSIGN( + right_source, + MakeExecNode("source", plan.get(), {}, + SourceNodeOptions{input_right.schema, + input_right.gen(parallel, /*slow=*/false)})) + + Expression mul = call("multiply", {field_ref("l1"), field_ref("l2")}); + Expression combination = call("add", {mul, field_ref("r1")}); + Expression residual_filter = less_equal(combination, field_ref("r2")); + + HashJoinNodeOptions join_opts{ + JoinType::FULL_OUTER, + /*left_keys=*/{"l_str"}, + /*right_keys=*/{"r_str"}, std::move(residual_filter), "l_", "r_"}; + + ASSERT_OK_AND_ASSIGN( + auto hashjoin, + MakeExecNode("hashjoin", plan.get(), {left_source, right_source}, join_opts)); + + ASSERT_OK_AND_ASSIGN(std::ignore, MakeExecNode("sink", plan.get(), {hashjoin}, + SinkNodeOptions{&sink_gen})); + + ASSERT_FINISHES_OK_AND_ASSIGN(auto result, StartAndCollect(plan.get(), sink_gen)); + + std::vector expected = { + ExecBatchFromJSON({int32(), int32(), utf8(), int32(), int32(), utf8()}, R"([ + [1, 6, "alpha", 4, 16, "alpha"], + [1, 6, "alpha", 5, 11, "alpha"], + [2, 5, "beta", 2, 12, "beta"], + [3, 4, "alpha", 4, 16, "alpha"]])")}; + + AssertExecBatchesEqual(hashjoin->output_schema(), result, expected); + } +} + +TEST(HashJoin, TrivialResidualFilter) { + Expression always_true = + equal(call("add", {field_ref("l1"), field_ref("r1")}), literal(2)); // 1 + 1 == 2 + Expression always_false = + equal(call("add", {field_ref("l1"), field_ref("r1")}), literal(3)); // 1 + 1 == 3 + + std::string expected_true = R"([[1, "alpha", 1, "alpha"]])"; + std::string expected_false = R"([])"; + + std::vector expected_strings = {expected_true, expected_false}; + std::vector filters = {always_true, always_false}; + + for (size_t test_id = 0; test_id < 2; test_id++) { + for (bool parallel : {false, true}) { + SCOPED_TRACE(parallel ? "parallel/merged" : "serial"); + + BatchesWithSchema input_left; + input_left.batches = {ExecBatchFromJSON({int32(), utf8()}, R"([ + [1, "alpha"] + ])")}; + input_left.schema = schema({field("l1", int32()), field("l_str", utf8())}); + + BatchesWithSchema input_right; + input_right.batches = {ExecBatchFromJSON({int32(), utf8()}, R"([ + [1, "alpha"] + ])")}; + input_right.schema = schema({field("r1", int32()), field("r_str", utf8())}); + + auto exec_ctx = arrow::internal::make_unique( + default_memory_pool(), + parallel ? arrow::internal::GetCpuThreadPool() : nullptr); + + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); + AsyncGenerator> sink_gen; + + ExecNode* left_source; + ExecNode* right_source; + ASSERT_OK_AND_ASSIGN( + left_source, + MakeExecNode("source", plan.get(), {}, + SourceNodeOptions{input_left.schema, + input_left.gen(parallel, /*slow=*/false)})); + + ASSERT_OK_AND_ASSIGN( + right_source, + MakeExecNode("source", plan.get(), {}, + SourceNodeOptions{input_right.schema, + input_right.gen(parallel, /*slow=*/false)})) + + HashJoinNodeOptions join_opts{ + JoinType::INNER, + /*left_keys=*/{"l_str"}, + /*right_keys=*/{"r_str"}, filters[test_id], "l_", "r_"}; + + ASSERT_OK_AND_ASSIGN( + auto hashjoin, + MakeExecNode("hashjoin", plan.get(), {left_source, right_source}, join_opts)); + + ASSERT_OK_AND_ASSIGN(std::ignore, MakeExecNode("sink", plan.get(), {hashjoin}, + SinkNodeOptions{&sink_gen})); + + ASSERT_FINISHES_OK_AND_ASSIGN(auto result, StartAndCollect(plan.get(), sink_gen)); + + std::vector expected = {ExecBatchFromJSON( + {int32(), utf8(), int32(), utf8()}, expected_strings[test_id])}; + + AssertExecBatchesEqual(hashjoin->output_schema(), result, expected); + } + } +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index 87349191e90da..2723c4454c061 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -173,7 +173,7 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { static constexpr const char* default_output_prefix_for_right = ""; HashJoinNodeOptions( JoinType in_join_type, std::vector in_left_keys, - std::vector in_right_keys, + std::vector in_right_keys, Expression filter = literal(true), std::string output_prefix_for_left = default_output_prefix_for_left, std::string output_prefix_for_right = default_output_prefix_for_right) : join_type(in_join_type), @@ -181,7 +181,8 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { right_keys(std::move(in_right_keys)), output_all(true), output_prefix_for_left(std::move(output_prefix_for_left)), - output_prefix_for_right(std::move(output_prefix_for_right)) { + output_prefix_for_right(std::move(output_prefix_for_right)), + filter(std::move(filter)) { this->key_cmp.resize(this->left_keys.size()); for (size_t i = 0; i < this->left_keys.size(); ++i) { this->key_cmp[i] = JoinKeyCmp::EQ; @@ -190,7 +191,7 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { HashJoinNodeOptions( JoinType join_type, std::vector left_keys, std::vector right_keys, std::vector left_output, - std::vector right_output, + std::vector right_output, Expression filter = literal(true), std::string output_prefix_for_left = default_output_prefix_for_left, std::string output_prefix_for_right = default_output_prefix_for_right) : join_type(join_type), @@ -200,7 +201,8 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { left_output(std::move(left_output)), right_output(std::move(right_output)), output_prefix_for_left(std::move(output_prefix_for_left)), - output_prefix_for_right(std::move(output_prefix_for_right)) { + output_prefix_for_right(std::move(output_prefix_for_right)), + filter(std::move(filter)) { this->key_cmp.resize(this->left_keys.size()); for (size_t i = 0; i < this->left_keys.size(); ++i) { this->key_cmp[i] = JoinKeyCmp::EQ; @@ -210,6 +212,7 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { JoinType join_type, std::vector left_keys, std::vector right_keys, std::vector left_output, std::vector right_output, std::vector key_cmp, + Expression filter = literal(true), std::string output_prefix_for_left = default_output_prefix_for_left, std::string output_prefix_for_right = default_output_prefix_for_right) : join_type(join_type), @@ -220,7 +223,8 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { right_output(std::move(right_output)), key_cmp(std::move(key_cmp)), output_prefix_for_left(std::move(output_prefix_for_left)), - output_prefix_for_right(std::move(output_prefix_for_right)) {} + output_prefix_for_right(std::move(output_prefix_for_right)), + filter(std::move(filter)) {} // type of join (inner, left, semi...) JoinType join_type; @@ -244,6 +248,11 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { std::string output_prefix_for_left; // prefix added to names of output fields coming from right input std::string output_prefix_for_right; + // residual filter which is applied to matching rows. Rows that do not match + // the filter are not included. The filter is applied against the + // concatenated input schema (left fields then right fields) and can reference + // fields that are not included in the output. + Expression filter; }; /// \brief Make a node which select top_k/bottom_k rows passed through it diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index 7d5bfe7d959a1..56e71b06e2c6b 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -1112,7 +1112,7 @@ TEST(ExecPlanExecution, SelfInnerHashJoinSink) { HashJoinNodeOptions join_opts{JoinType::INNER, /*left_keys=*/{"str"}, - /*right_keys=*/{"str"}, "l_", "r_"}; + /*right_keys=*/{"str"}, literal(true), "l_", "r_"}; ASSERT_OK_AND_ASSIGN( auto hashjoin, @@ -1169,7 +1169,7 @@ TEST(ExecPlanExecution, SelfOuterHashJoinSink) { HashJoinNodeOptions join_opts{JoinType::FULL_OUTER, /*left_keys=*/{"str"}, - /*right_keys=*/{"str"}, "l_", "r_"}; + /*right_keys=*/{"str"}, literal(true), "l_", "r_"}; ASSERT_OK_AND_ASSIGN( auto hashjoin, diff --git a/cpp/src/arrow/compute/exec/schema_util.h b/cpp/src/arrow/compute/exec/schema_util.h index 279cbb806db32..4e307e238072e 100644 --- a/cpp/src/arrow/compute/exec/schema_util.h +++ b/cpp/src/arrow/compute/exec/schema_util.h @@ -34,7 +34,13 @@ namespace compute { // Identifiers for all different row schemas that are used in a join // -enum class HashJoinProjection : int { INPUT = 0, KEY = 1, PAYLOAD = 2, OUTPUT = 3 }; +enum class HashJoinProjection : int { + INPUT = 0, + KEY = 1, + PAYLOAD = 2, + FILTER = 3, + OUTPUT = 4 +}; struct SchemaProjectionMap { static constexpr int kMissingField = -1; diff --git a/cpp/src/arrow/compute/exec/util_test.cc b/cpp/src/arrow/compute/exec/util_test.cc index 7acf8228d6a08..6f4b5315fff5b 100644 --- a/cpp/src/arrow/compute/exec/util_test.cc +++ b/cpp/src/arrow/compute/exec/util_test.cc @@ -34,8 +34,8 @@ TEST(FieldMap, Trivial) { auto left = schema({field("i32", int32())}); auto right = schema({field("i32", int32())}); - ASSERT_OK(schema_mgr.Init(JoinType::INNER, *left, {"i32"}, *right, {"i32"}, kLeftPrefix, - kRightPrefix)); + ASSERT_OK(schema_mgr.Init(JoinType::INNER, *left, {"i32"}, *right, {"i32"}, + literal(true), kLeftPrefix, kRightPrefix)); auto output = schema_mgr.MakeOutputSchema(kLeftPrefix, kRightPrefix); EXPECT_THAT(*output, Eq(Schema({ @@ -54,7 +54,8 @@ TEST(FieldMap, TrivialDuplicates) { auto left = schema({field("i32", int32())}); auto right = schema({field("i32", int32())}); - ASSERT_OK(schema_mgr.Init(JoinType::INNER, *left, {"i32"}, *right, {"i32"}, "", "")); + ASSERT_OK(schema_mgr.Init(JoinType::INNER, *left, {"i32"}, *right, {"i32"}, + literal(true), "", "")); auto output = schema_mgr.MakeOutputSchema("", ""); EXPECT_THAT(*output, Eq(Schema({ @@ -73,8 +74,8 @@ TEST(FieldMap, SingleKeyField) { auto left = schema({field("i32", int32()), field("str", utf8())}); auto right = schema({field("f32", float32()), field("i32", int32())}); - ASSERT_OK(schema_mgr.Init(JoinType::INNER, *left, {"i32"}, *right, {"i32"}, kLeftPrefix, - kRightPrefix)); + ASSERT_OK(schema_mgr.Init(JoinType::INNER, *left, {"i32"}, *right, {"i32"}, + literal(true), kLeftPrefix, kRightPrefix)); EXPECT_EQ(schema_mgr.proj_maps[0].num_cols(HashJoinProjection::INPUT), 2); EXPECT_EQ(schema_mgr.proj_maps[1].num_cols(HashJoinProjection::INPUT), 2); @@ -112,7 +113,7 @@ TEST(FieldMap, TwoKeyFields) { }); ASSERT_OK(schema_mgr.Init(JoinType::INNER, *left, {"i32", "str"}, *right, - {"i32", "str"}, kLeftPrefix, kRightPrefix)); + {"i32", "str"}, literal(true), kLeftPrefix, kRightPrefix)); auto output = schema_mgr.MakeOutputSchema(kLeftPrefix, kRightPrefix); EXPECT_THAT(*output, Eq(Schema({ diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 136aff5bb715f..fd461b7b14e36 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -1632,6 +1632,7 @@ class ARROW_EXPORT FieldRef { bool Equals(const FieldRef& other) const { return impl_ == other.impl_; } bool operator==(const FieldRef& other) const { return Equals(other); } + bool operator!=(const FieldRef& other) const { return !(*this == other); } std::string ToString() const;