Skip to content

Commit

Permalink
apacheGH-45196: [C++][Acero] Small refinement to hash join (apache#45197)
Browse files Browse the repository at this point in the history

### Rationale for this change

See apache#45196

### What changes are included in this PR?

Refine/simplify the code mentioned in the issue.

### Are these changes tested?

Existing tests suffice.

### Are there any user-facing changes?

None.

* GitHub Issue: apache#45196

Authored-by: Rossi Sun <[email protected]>
Signed-off-by: Rossi Sun <[email protected]>
  • Loading branch information
zanmato1984 authored Jan 8, 2025
1 parent 438cf9b commit 0aa622c
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 20 deletions.
16 changes: 4 additions & 12 deletions cpp/src/arrow/acero/hash_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -568,20 +568,16 @@ class HashJoinBasicImpl : public HashJoinImpl {
if (has_payload) {
InitEncoder(1, HashJoinProjection::PAYLOAD, &hash_table_payloads_);
}
hash_table_empty_ = true;
hash_table_empty_ = batches.empty();
RETURN_NOT_OK(dict_build_.Init(*schema_[1], hash_table_empty_ ? nullptr : &batches[0],
ctx_->exec_context()));

for (size_t ibatch = 0; ibatch < batches.batch_count(); ++ibatch) {
if (cancelled_) {
return Status::Cancelled("Hash join cancelled");
}
const ExecBatch& batch = batches[ibatch];
if (batch.length == 0) {
continue;
} else if (hash_table_empty_) {
hash_table_empty_ = false;

RETURN_NOT_OK(dict_build_.Init(*schema_[1], &batch, ctx_->exec_context()));
}
DCHECK_GT(batch.length, 0);
int32_t num_rows_before = hash_table_keys_.num_rows();
RETURN_NOT_OK(dict_build_.EncodeBatch(thread_index, *schema_[1], batch,
&hash_table_keys_, ctx_->exec_context()));
Expand All @@ -595,10 +591,6 @@ class HashJoinBasicImpl : public HashJoinImpl {
}
}

if (hash_table_empty_) {
RETURN_NOT_OK(dict_build_.Init(*schema_[1], nullptr, ctx_->exec_context()));
}

return Status::OK();
}

Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/acero/hash_join_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,9 @@ class HashJoinNode : public ExecNode, public TracedNode {
const char* kind_name() const override { return "HashJoinNode"; }

Status OnBuildSideBatch(size_t thread_index, ExecBatch batch) {
if (batch.length == 0) {
return Status::OK();
}
std::lock_guard<std::mutex> guard(build_side_mutex_);
build_accumulator_.InsertBatch(std::move(batch));
return Status::OK();
Expand Down
10 changes: 2 additions & 8 deletions cpp/src/arrow/acero/swiss_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2598,17 +2598,15 @@ class SwissJoin : public HashJoinImpl {
return Status::OK();
}

DCHECK_GT(build_side_batches_[batch_id].length, 0);

const HashJoinProjectionMaps* schema = schema_[1];
bool no_payload = hash_table_build_.no_payload();

ExecBatch input_batch;
ARROW_ASSIGN_OR_RAISE(
input_batch, KeyPayloadFromInput(/*side=*/1, &build_side_batches_[batch_id]));

if (input_batch.length == 0) {
return Status::OK();
}

// Split batch into key batch and optional payload batch
//
// Input batch is key-payload batch (key columns followed by payload
Expand Down Expand Up @@ -2637,10 +2635,6 @@ class SwissJoin : public HashJoinImpl {
static_cast<int64_t>(thread_id), key_batch, no_payload ? nullptr : &payload_batch,
temp_stack)));

// Release input batch
//
input_batch.values.clear();

return Status::OK();
}

Expand Down

0 comments on commit 0aa622c

Please sign in to comment.