Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARROW-13643: [C++][Compute] Implement outer join with support for residual predicates #11579

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 134 additions & 7 deletions cpp/src/arrow/compute/exec/hash_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ class HashJoinBasicImpl : public HashJoinImpl {

Status Init(ExecContext* ctx, JoinType join_type, bool use_sync_execution,
size_t num_threads, HashJoinSchema* schema_mgr,
std::vector<JoinKeyCmp> key_cmp, OutputBatchCallback output_batch_callback,
std::vector<JoinKeyCmp> key_cmp, Expression filter,
OutputBatchCallback output_batch_callback,
FinishedCallback finished_callback,
TaskScheduler::ScheduleImpl schedule_task_callback) override {
num_threads = std::max(num_threads, static_cast<size_t>(1));
Expand All @@ -90,6 +91,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
num_threads_ = num_threads;
schema_mgr_ = schema_mgr;
key_cmp_ = std::move(key_cmp);
filter_ = std::move(filter);
output_batch_callback_ = std::move(output_batch_callback);
finished_callback_ = std::move(finished_callback);
local_states_.resize(num_threads);
Expand Down Expand Up @@ -207,8 +209,6 @@ class HashJoinBasicImpl : public HashJoinImpl {
for (auto it = range.first; it != range.second; ++it) {
output_match_left->push_back(irow);
output_match_right->push_back(it->second);
// Mark row in hash table as having a match
BitUtil::SetBit(local_state->has_match.data(), it->second);
has_match = true;
}
if (!has_match) {
Expand All @@ -227,10 +227,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::OUTPUT);
int num_out_cols_right =
schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::OUTPUT);
ARROW_DCHECK((opt_left_payload == nullptr) ==
(schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) == 0));
ARROW_DCHECK((opt_right_payload == nullptr) ==
(schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) == 0));

result.values.resize(num_out_cols_left + num_out_cols_right);
auto from_key = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT,
HashJoinProjection::KEY);
Expand Down Expand Up @@ -282,6 +279,126 @@ class HashJoinBasicImpl : public HashJoinImpl {
num_batches_produced_++;
}

Status ProbeBatch_ResidualFilter(ThreadLocalState& local_state,
std::vector<int32_t>& match,
std::vector<int32_t>& no_match,
std::vector<int32_t>& match_left,
std::vector<int32_t>& match_right) {
if (filter_ == literal(true)) {
return Status::OK();
}
ARROW_DCHECK_EQ(match_left.size(), match_right.size());

ExecBatch concatenated({}, match_left.size());

ARROW_ASSIGN_OR_RAISE(ExecBatch left_key, local_state.exec_batch_keys.Decode(
match_left.size(), match_left.data()));
ARROW_ASSIGN_OR_RAISE(
ExecBatch right_key,
hash_table_keys_.Decode(match_right.size(), match_right.data()));

ExecBatch left_payload;
if (!schema_mgr_->LeftPayloadIsEmpty()) {
ARROW_ASSIGN_OR_RAISE(left_payload, local_state.exec_batch_payloads.Decode(
match_left.size(), match_left.data()));
}

ExecBatch right_payload;
if (!schema_mgr_->RightPayloadIsEmpty()) {
ARROW_ASSIGN_OR_RAISE(right_payload, hash_table_payloads_.Decode(
match_right.size(), match_right.data()));
}
save-buffer marked this conversation as resolved.
Show resolved Hide resolved

auto AppendFields = [&concatenated](const SchemaProjectionMap& to_key,
const SchemaProjectionMap& to_pay,
const ExecBatch& key, const ExecBatch& payload) {
ARROW_DCHECK(to_key.num_cols == to_pay.num_cols);
for (int i = 0; i < to_key.num_cols; i++) {
if (to_key.get(i) != SchemaProjectionMap::kMissingField) {
int key_idx = to_key.get(i);
concatenated.values.push_back(key.values[key_idx]);
} else if (to_pay.get(i) != SchemaProjectionMap::kMissingField) {
int pay_idx = to_pay.get(i);
concatenated.values.push_back(payload.values[pay_idx]);
}
}
};

SchemaProjectionMap left_to_key = schema_mgr_->proj_maps[0].map(
HashJoinProjection::FILTER, HashJoinProjection::KEY);
SchemaProjectionMap left_to_pay = schema_mgr_->proj_maps[0].map(
HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
SchemaProjectionMap right_to_key = schema_mgr_->proj_maps[1].map(
HashJoinProjection::FILTER, HashJoinProjection::KEY);
SchemaProjectionMap right_to_pay = schema_mgr_->proj_maps[1].map(
HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);

AppendFields(left_to_key, left_to_pay, left_key, left_payload);
AppendFields(right_to_key, right_to_pay, right_key, right_payload);

ARROW_ASSIGN_OR_RAISE(Datum mask,
ExecuteScalarExpression(filter_, concatenated, ctx_));

size_t num_probed_rows = match.size() + no_match.size();
if (mask.is_scalar()) {
const auto& mask_scalar = mask.scalar_as<BooleanScalar>();
save-buffer marked this conversation as resolved.
Show resolved Hide resolved
if (mask_scalar.is_valid && mask_scalar.value) {
// All rows passed, nothing left to do
return Status::OK();
} else {
// Nothing passed, no_match becomes everything
no_match.resize(num_probed_rows);
std::iota(no_match.begin(), no_match.end(), 0);
match_left.clear();
match_right.clear();
match.clear();
return Status::OK();
}
}
ARROW_DCHECK_EQ(mask.array()->offset, 0);
ARROW_DCHECK_EQ(mask.array()->length, static_cast<int64_t>(match_left.size()));
const uint8_t* validity =
mask.array()->buffers[0] ? mask.array()->buffers[0]->data() : nullptr;
const uint8_t* comparisons = mask.array()->buffers[1]->data();
size_t num_rows = match_left.size();

match.clear();
no_match.clear();

int32_t match_idx = 0; // current size of new match_left
int32_t irow = 0; // index into match_left
for (int32_t curr_left = 0; static_cast<size_t>(curr_left) < num_probed_rows;
curr_left++) {
int32_t advance_to = static_cast<size_t>(irow) < num_rows
? match_left[irow]
: static_cast<int32_t>(num_probed_rows);
while (curr_left < advance_to) {
no_match.push_back(curr_left++);
}
bool passed = false;
for (; static_cast<size_t>(irow) < num_rows && match_left[irow] == curr_left;
irow++) {
bool is_valid = !validity || BitUtil::GetBit(validity, irow);
bool is_cmp_true = BitUtil::GetBit(comparisons, irow);
// We treat a null comparison result as false, like in SQL
if (is_valid && is_cmp_true) {
match_left[match_idx] = match_left[irow];
match_right[match_idx] = match_right[irow];
match_idx++;
passed = true;
}
}
if (passed) {
match.push_back(curr_left);
} else if (static_cast<size_t>(curr_left) < num_probed_rows) {
no_match.push_back(curr_left);
}
}
match_left.resize(match_idx);
match_right.resize(match_idx);
return Status::OK();
}

Status ProbeBatch_OutputOne(size_t thread_index, int64_t batch_size_next,
const int32_t* opt_left_ids, const int32_t* opt_right_ids) {
if (batch_size_next == 0 || (!opt_left_ids && !opt_right_ids)) {
Expand Down Expand Up @@ -456,6 +573,15 @@ class HashJoinBasicImpl : public HashJoinImpl {
&local_state.no_match, &local_state.match_left,
&local_state.match_right);

RETURN_NOT_OK(ProbeBatch_ResidualFilter(local_state, local_state.match,
local_state.no_match, local_state.match_left,
local_state.match_right));

for (auto i : local_state.match_right) {
// Mark row in hash table as having a match
BitUtil::SetBit(local_state.has_match.data(), i);
}

RETURN_NOT_OK(ProbeBatch_OutputAll(thread_index, local_state.exec_batch_keys,
local_state.exec_batch_payloads, local_state.match,
local_state.no_match, local_state.match_left,
Expand Down Expand Up @@ -732,6 +858,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
size_t num_threads_;
HashJoinSchema* schema_mgr_;
std::vector<JoinKeyCmp> key_cmp_;
Expression filter_;
std::unique_ptr<TaskScheduler> scheduler_;
int task_group_build_;
int task_group_queued_;
Expand Down
35 changes: 29 additions & 6 deletions cpp/src/arrow/compute/exec/hash_join.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ class ARROW_EXPORT HashJoinSchema {
public:
Status Init(JoinType join_type, const Schema& left_schema,
const std::vector<FieldRef>& left_keys, const Schema& right_schema,
const std::vector<FieldRef>& right_keys,
const std::vector<FieldRef>& right_keys, const Expression& filter,
const std::string& left_field_name_prefix,
const std::string& right_field_name_prefix);

Status Init(JoinType join_type, const Schema& left_schema,
const std::vector<FieldRef>& left_keys,
const std::vector<FieldRef>& left_output, const Schema& right_schema,
const std::vector<FieldRef>& right_keys,
const std::vector<FieldRef>& right_output,
const std::vector<FieldRef>& right_output, const Expression& filter,
const std::string& left_field_name_prefix,
const std::string& right_field_name_prefix);

Expand All @@ -56,9 +56,15 @@ class ARROW_EXPORT HashJoinSchema {
const std::string& left_field_name_prefix,
const std::string& right_field_name_prefix);

Result<Expression> BindFilter(Expression filter, const Schema& left_schema,
const Schema& right_schema);
std::shared_ptr<Schema> MakeOutputSchema(const std::string& left_field_name_prefix,
const std::string& right_field_name_prefix);

bool LeftPayloadIsEmpty() { return PayloadIsEmpty(0); }

bool RightPayloadIsEmpty() { return PayloadIsEmpty(1); }

static int kMissingField() {
return SchemaProjectionMaps<HashJoinProjection>::kMissingField;
}
Expand All @@ -67,9 +73,26 @@ class ARROW_EXPORT HashJoinSchema {

private:
static bool IsTypeSupported(const DataType& type);
static Result<std::vector<FieldRef>> VectorDiff(const Schema& schema,
const std::vector<FieldRef>& a,
const std::vector<FieldRef>& b);

Status CollectFilterColumns(std::vector<FieldRef>& left_filter,
std::vector<FieldRef>& right_filter,
const Expression& filter, const Schema& left_schema,
const Schema& right_schema);

Expression RewriteFilterToUseFilterSchema(int right_filter_offset,
const SchemaProjectionMap& left_to_filter,
const SchemaProjectionMap& right_to_filter,
const Expression& filter);

bool PayloadIsEmpty(int side) {
ARROW_DCHECK(side == 0 || side == 1);
return proj_maps[side].num_cols(HashJoinProjection::PAYLOAD) == 0;
}

static Result<std::vector<FieldRef>> ComputePayload(const Schema& schema,
const std::vector<FieldRef>& output,
const std::vector<FieldRef>& filter,
const std::vector<FieldRef>& key);
};

class HashJoinImpl {
Expand All @@ -80,7 +103,7 @@ class HashJoinImpl {
virtual ~HashJoinImpl() = default;
virtual Status Init(ExecContext* ctx, JoinType join_type, bool use_sync_execution,
size_t num_threads, HashJoinSchema* schema_mgr,
std::vector<JoinKeyCmp> key_cmp,
std::vector<JoinKeyCmp> key_cmp, Expression filter,
OutputBatchCallback output_batch_callback,
FinishedCallback finished_callback,
TaskScheduler::ScheduleImpl schedule_task_callback) = 0;
Expand Down
Loading