Skip to content

Commit

Permalink
ARROW-13643: [C++][Compute] Implement outer join with support for res…
Browse files Browse the repository at this point in the history
…idual predicates

Implements residual predicates on hash join

Closes #11579 from save-buffer/sasha_join_filter

Authored-by: Sasha Krassovsky <[email protected]>
Signed-off-by: David Li <[email protected]>
  • Loading branch information
save-buffer authored and lidavidm committed Nov 29, 2021
1 parent 72f7fbe commit 4913352
Show file tree
Hide file tree
Showing 9 changed files with 562 additions and 85 deletions.
141 changes: 134 additions & 7 deletions cpp/src/arrow/compute/exec/hash_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ class HashJoinBasicImpl : public HashJoinImpl {

Status Init(ExecContext* ctx, JoinType join_type, bool use_sync_execution,
size_t num_threads, HashJoinSchema* schema_mgr,
std::vector<JoinKeyCmp> key_cmp, OutputBatchCallback output_batch_callback,
std::vector<JoinKeyCmp> key_cmp, Expression filter,
OutputBatchCallback output_batch_callback,
FinishedCallback finished_callback,
TaskScheduler::ScheduleImpl schedule_task_callback) override {
num_threads = std::max(num_threads, static_cast<size_t>(1));
Expand All @@ -90,6 +91,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
num_threads_ = num_threads;
schema_mgr_ = schema_mgr;
key_cmp_ = std::move(key_cmp);
filter_ = std::move(filter);
output_batch_callback_ = std::move(output_batch_callback);
finished_callback_ = std::move(finished_callback);
local_states_.resize(num_threads);
Expand Down Expand Up @@ -207,8 +209,6 @@ class HashJoinBasicImpl : public HashJoinImpl {
for (auto it = range.first; it != range.second; ++it) {
output_match_left->push_back(irow);
output_match_right->push_back(it->second);
// Mark row in hash table as having a match
BitUtil::SetBit(local_state->has_match.data(), it->second);
has_match = true;
}
if (!has_match) {
Expand All @@ -227,10 +227,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::OUTPUT);
int num_out_cols_right =
schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::OUTPUT);
ARROW_DCHECK((opt_left_payload == nullptr) ==
(schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) == 0));
ARROW_DCHECK((opt_right_payload == nullptr) ==
(schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) == 0));

result.values.resize(num_out_cols_left + num_out_cols_right);
auto from_key = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT,
HashJoinProjection::KEY);
Expand Down Expand Up @@ -282,6 +279,126 @@ class HashJoinBasicImpl : public HashJoinImpl {
num_batches_produced_++;
}

// Applies the residual filter expression (filter_) to the candidate row pairs
// produced by the hash-table probe and drops the pairs that do not pass.
//
// Inputs / outputs (all vectors are modified in place):
//   match_left / match_right - parallel vectors; entry i is one candidate
//       pair, given as a row id decoded from local_state.exec_batch_* and a
//       row id decoded from hash_table_*_ respectively. On return they are
//       compacted to contain only the pairs for which the filter is true.
//   match    - on return, the left row ids that kept at least one pair.
//   no_match - on return, the left row ids that kept no pair.
//
// match.size() + no_match.size() on entry is taken as the number of probed
// left rows, and left row ids are treated as the range [0, that count).
// The compaction scan relies on match_left being in non-decreasing order
// (presumably guaranteed by the probe step that filled it - confirm at caller).
//
// Returns Status::OK() on success, or the error produced by Decode() /
// ExecuteScalarExpression(). When filter_ equals literal(true) nothing is
// evaluated and the vectors are left untouched.
Status ProbeBatch_ResidualFilter(ThreadLocalState& local_state,
std::vector<int32_t>& match,
std::vector<int32_t>& no_match,
std::vector<int32_t>& match_left,
std::vector<int32_t>& match_right) {
// Trivial filter: every candidate pair passes, so there is nothing to do.
if (filter_ == literal(true)) {
return Status::OK();
}
ARROW_DCHECK_EQ(match_left.size(), match_right.size());

// Batch the filter is evaluated on: one row per candidate pair. Its columns
// are appended below by AppendFields.
ExecBatch concatenated({}, match_left.size());

// Materialize the key columns for the left and right row of every pair.
ARROW_ASSIGN_OR_RAISE(ExecBatch left_key, local_state.exec_batch_keys.Decode(
match_left.size(), match_left.data()));
ARROW_ASSIGN_OR_RAISE(
ExecBatch right_key,
hash_table_keys_.Decode(match_right.size(), match_right.data()));

// Payload columns are decoded only for a side that actually has some;
// otherwise the corresponding batch stays default-constructed.
ExecBatch left_payload;
if (!schema_mgr_->LeftPayloadIsEmpty()) {
ARROW_ASSIGN_OR_RAISE(left_payload, local_state.exec_batch_payloads.Decode(
match_left.size(), match_left.data()));
}

ExecBatch right_payload;
if (!schema_mgr_->RightPayloadIsEmpty()) {
ARROW_ASSIGN_OR_RAISE(right_payload, hash_table_payloads_.Decode(
match_right.size(), match_right.data()));
}

// For each column i of the two maps (both map from the FILTER projection, see
// below), append the column to `concatenated`: taken from the key batch when
// the column maps into KEY, otherwise from the payload batch when it maps into
// PAYLOAD. A column present in neither is skipped.
auto AppendFields = [&concatenated](const SchemaProjectionMap& to_key,
const SchemaProjectionMap& to_pay,
const ExecBatch& key, const ExecBatch& payload) {
ARROW_DCHECK(to_key.num_cols == to_pay.num_cols);
for (int i = 0; i < to_key.num_cols; i++) {
if (to_key.get(i) != SchemaProjectionMap::kMissingField) {
int key_idx = to_key.get(i);
concatenated.values.push_back(key.values[key_idx]);
} else if (to_pay.get(i) != SchemaProjectionMap::kMissingField) {
int pay_idx = to_pay.get(i);
concatenated.values.push_back(payload.values[pay_idx]);
}
}
};

// FILTER -> KEY and FILTER -> PAYLOAD mappings for the left (proj_maps[0])
// and right (proj_maps[1]) side.
SchemaProjectionMap left_to_key = schema_mgr_->proj_maps[0].map(
HashJoinProjection::FILTER, HashJoinProjection::KEY);
SchemaProjectionMap left_to_pay = schema_mgr_->proj_maps[0].map(
HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
SchemaProjectionMap right_to_key = schema_mgr_->proj_maps[1].map(
HashJoinProjection::FILTER, HashJoinProjection::KEY);
SchemaProjectionMap right_to_pay = schema_mgr_->proj_maps[1].map(
HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);

// Left filter columns come first, right filter columns follow.
AppendFields(left_to_key, left_to_pay, left_key, left_payload);
AppendFields(right_to_key, right_to_pay, right_key, right_payload);

// Evaluate the filter once over all candidate pairs.
ARROW_ASSIGN_OR_RAISE(Datum mask,
ExecuteScalarExpression(filter_, concatenated, ctx_));

// Total number of probed left rows, derived from the incoming partition of
// left row ids into match / no_match.
size_t num_probed_rows = match.size() + no_match.size();
// A scalar result applies uniformly to every candidate pair.
if (mask.is_scalar()) {
const auto& mask_scalar = mask.scalar_as<BooleanScalar>();
if (mask_scalar.is_valid && mask_scalar.value) {
// All rows passed, nothing left to do
return Status::OK();
} else {
// Nothing passed, no_match becomes everything
// (a null scalar also lands here, i.e. it is treated like false)
no_match.resize(num_probed_rows);
std::iota(no_match.begin(), no_match.end(), 0);
match_left.clear();
match_right.clear();
match.clear();
return Status::OK();
}
}
// Array result: the bitmaps are indexed directly by pair index below, which
// is only valid for a zero offset and one mask entry per candidate pair.
ARROW_DCHECK_EQ(mask.array()->offset, 0);
ARROW_DCHECK_EQ(mask.array()->length, static_cast<int64_t>(match_left.size()));
// buffers[0] is read as the validity bitmap (it may be absent, in which case
// every entry is considered valid); buffers[1] is read as the boolean values.
const uint8_t* validity =
mask.array()->buffers[0] ? mask.array()->buffers[0]->data() : nullptr;
const uint8_t* comparisons = mask.array()->buffers[1]->data();
size_t num_rows = match_left.size();

// match / no_match are rebuilt from scratch by the scan below.
match.clear();
no_match.clear();

int32_t match_idx = 0; // current size of new match_left
int32_t irow = 0; // index into match_left
// Walk the left row ids 0..num_probed_rows-1 in order while consuming the
// candidate pairs (irow) in lockstep. Note that curr_left is advanced both by
// the for-header and inside the body.
for (int32_t curr_left = 0; static_cast<size_t>(curr_left) < num_probed_rows;
curr_left++) {
// Next left row id that has a candidate pair, or one past the last row id
// once all pairs have been consumed.
int32_t advance_to = static_cast<size_t>(irow) < num_rows
? match_left[irow]
: static_cast<int32_t>(num_probed_rows);
// Left rows skipped over here had no candidate pair at all.
while (curr_left < advance_to) {
no_match.push_back(curr_left++);
}
bool passed = false;
// Consume the run of candidate pairs belonging to curr_left, keeping the
// ones that pass by copying them down to position match_idx (match_idx never
// exceeds irow, so entries not yet read are not overwritten).
for (; static_cast<size_t>(irow) < num_rows && match_left[irow] == curr_left;
irow++) {
bool is_valid = !validity || BitUtil::GetBit(validity, irow);
bool is_cmp_true = BitUtil::GetBit(comparisons, irow);
// We treat a null comparison result as false, like in SQL
if (is_valid && is_cmp_true) {
match_left[match_idx] = match_left[irow];
match_right[match_idx] = match_right[irow];
match_idx++;
passed = true;
}
}
if (passed) {
match.push_back(curr_left);
// The bound check is needed because the while loop above may have moved
// curr_left to num_probed_rows, which is not a real row id.
} else if (static_cast<size_t>(curr_left) < num_probed_rows) {
no_match.push_back(curr_left);
}
}
// Drop the tail left over after compaction.
match_left.resize(match_idx);
match_right.resize(match_idx);
return Status::OK();
}

Status ProbeBatch_OutputOne(size_t thread_index, int64_t batch_size_next,
const int32_t* opt_left_ids, const int32_t* opt_right_ids) {
if (batch_size_next == 0 || (!opt_left_ids && !opt_right_ids)) {
Expand Down Expand Up @@ -456,6 +573,15 @@ class HashJoinBasicImpl : public HashJoinImpl {
&local_state.no_match, &local_state.match_left,
&local_state.match_right);

RETURN_NOT_OK(ProbeBatch_ResidualFilter(local_state, local_state.match,
local_state.no_match, local_state.match_left,
local_state.match_right));

for (auto i : local_state.match_right) {
// Mark row in hash table as having a match
BitUtil::SetBit(local_state.has_match.data(), i);
}

RETURN_NOT_OK(ProbeBatch_OutputAll(thread_index, local_state.exec_batch_keys,
local_state.exec_batch_payloads, local_state.match,
local_state.no_match, local_state.match_left,
Expand Down Expand Up @@ -732,6 +858,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
size_t num_threads_;
HashJoinSchema* schema_mgr_;
std::vector<JoinKeyCmp> key_cmp_;
Expression filter_;
std::unique_ptr<TaskScheduler> scheduler_;
int task_group_build_;
int task_group_queued_;
Expand Down
35 changes: 29 additions & 6 deletions cpp/src/arrow/compute/exec/hash_join.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ class ARROW_EXPORT HashJoinSchema {
public:
Status Init(JoinType join_type, const Schema& left_schema,
const std::vector<FieldRef>& left_keys, const Schema& right_schema,
const std::vector<FieldRef>& right_keys,
const std::vector<FieldRef>& right_keys, const Expression& filter,
const std::string& left_field_name_prefix,
const std::string& right_field_name_prefix);

Status Init(JoinType join_type, const Schema& left_schema,
const std::vector<FieldRef>& left_keys,
const std::vector<FieldRef>& left_output, const Schema& right_schema,
const std::vector<FieldRef>& right_keys,
const std::vector<FieldRef>& right_output,
const std::vector<FieldRef>& right_output, const Expression& filter,
const std::string& left_field_name_prefix,
const std::string& right_field_name_prefix);

Expand All @@ -56,9 +56,15 @@ class ARROW_EXPORT HashJoinSchema {
const std::string& left_field_name_prefix,
const std::string& right_field_name_prefix);

Result<Expression> BindFilter(Expression filter, const Schema& left_schema,
const Schema& right_schema);
std::shared_ptr<Schema> MakeOutputSchema(const std::string& left_field_name_prefix,
const std::string& right_field_name_prefix);

// Returns true when side 0 (the left input) has no columns in its PAYLOAD
// projection; forwards to PayloadIsEmpty(0).
bool LeftPayloadIsEmpty() { return PayloadIsEmpty(0); }

// Returns true when side 1 (the right input) has no columns in its PAYLOAD
// projection; forwards to PayloadIsEmpty(1).
bool RightPayloadIsEmpty() { return PayloadIsEmpty(1); }

// Returns the value of SchemaProjectionMaps<HashJoinProjection>::kMissingField.
// NOTE(review): presumably the sentinel a projection map yields for a column
// that is absent from the target projection - confirm against schema_util.h.
static int kMissingField() {
return SchemaProjectionMaps<HashJoinProjection>::kMissingField;
}
Expand All @@ -67,9 +73,26 @@ class ARROW_EXPORT HashJoinSchema {

private:
static bool IsTypeSupported(const DataType& type);
static Result<std::vector<FieldRef>> VectorDiff(const Schema& schema,
const std::vector<FieldRef>& a,
const std::vector<FieldRef>& b);

Status CollectFilterColumns(std::vector<FieldRef>& left_filter,
std::vector<FieldRef>& right_filter,
const Expression& filter, const Schema& left_schema,
const Schema& right_schema);

Expression RewriteFilterToUseFilterSchema(int right_filter_offset,
const SchemaProjectionMap& left_to_filter,
const SchemaProjectionMap& right_to_filter,
const Expression& filter);

// Reports whether the PAYLOAD projection of the given join side contains no
// columns. `side` must be 0 or 1 (checked in debug builds only); it is used
// as the index into proj_maps.
bool PayloadIsEmpty(int side) {
  ARROW_DCHECK(side == 0 || side == 1);
  const auto payload_col_count = proj_maps[side].num_cols(HashJoinProjection::PAYLOAD);
  return payload_col_count == 0;
}

static Result<std::vector<FieldRef>> ComputePayload(const Schema& schema,
const std::vector<FieldRef>& output,
const std::vector<FieldRef>& filter,
const std::vector<FieldRef>& key);
};

class HashJoinImpl {
Expand All @@ -80,7 +103,7 @@ class HashJoinImpl {
virtual ~HashJoinImpl() = default;
virtual Status Init(ExecContext* ctx, JoinType join_type, bool use_sync_execution,
size_t num_threads, HashJoinSchema* schema_mgr,
std::vector<JoinKeyCmp> key_cmp,
std::vector<JoinKeyCmp> key_cmp, Expression filter,
OutputBatchCallback output_batch_callback,
FinishedCallback finished_callback,
TaskScheduler::ScheduleImpl schedule_task_callback) = 0;
Expand Down
Loading

0 comments on commit 4913352

Please sign in to comment.