Skip to content

Commit

Permalink
Add support for residual filters in join
Browse files Browse the repository at this point in the history
  • Loading branch information
save-buffer committed Nov 5, 2021
1 parent da1868b commit 557084e
Show file tree
Hide file tree
Showing 11 changed files with 453 additions and 82 deletions.
2 changes: 2 additions & 0 deletions cpp/src/arrow/compute/exec/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,8 @@ bool Expression::IsSatisfiable() const {
return true;
}

// Reports whether this Expression is empty, i.e. whether its impl_ pointer is
// null.
bool Expression::IsEmpty() const {
  const bool has_impl = (impl_ != nullptr);
  return !has_impl;
}

namespace {

// Produce a bound Expression from unbound Call and bound arguments.
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/compute/exec/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ class ARROW_EXPORT Expression {
/// Return true if this expression could evaluate to true.
bool IsSatisfiable() const;

/// Return true if this expression has no clauses.
bool IsEmpty() const;

// XXX someday
// Result<PipelineGraph> GetPipelines();

Expand Down
140 changes: 133 additions & 7 deletions cpp/src/arrow/compute/exec/hash_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ class HashJoinBasicImpl : public HashJoinImpl {

Status Init(ExecContext* ctx, JoinType join_type, bool use_sync_execution,
size_t num_threads, HashJoinSchema* schema_mgr,
std::vector<JoinKeyCmp> key_cmp, OutputBatchCallback output_batch_callback,
std::vector<JoinKeyCmp> key_cmp, Expression filter,
OutputBatchCallback output_batch_callback,
FinishedCallback finished_callback,
TaskScheduler::ScheduleImpl schedule_task_callback) override {
num_threads = std::max(num_threads, static_cast<size_t>(1));
Expand All @@ -90,6 +91,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
num_threads_ = num_threads;
schema_mgr_ = schema_mgr;
key_cmp_ = std::move(key_cmp);
filter_ = std::move(filter);
output_batch_callback_ = std::move(output_batch_callback);
finished_callback_ = std::move(finished_callback);
local_states_.resize(num_threads);
Expand Down Expand Up @@ -207,8 +209,6 @@ class HashJoinBasicImpl : public HashJoinImpl {
for (auto it = range.first; it != range.second; ++it) {
output_match_left->push_back(irow);
output_match_right->push_back(it->second);
// Mark row in hash table as having a match
BitUtil::SetBit(local_state->has_match.data(), it->second);
has_match = true;
}
if (!has_match) {
Expand All @@ -227,10 +227,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::OUTPUT);
int num_out_cols_right =
schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::OUTPUT);
ARROW_DCHECK((opt_left_payload == nullptr) ==
(schema_mgr_->proj_maps[0].num_cols(HashJoinProjection::PAYLOAD) == 0));
ARROW_DCHECK((opt_right_payload == nullptr) ==
(schema_mgr_->proj_maps[1].num_cols(HashJoinProjection::PAYLOAD) == 0));

result.values.resize(num_out_cols_left + num_out_cols_right);
auto from_key = schema_mgr_->proj_maps[0].map(HashJoinProjection::OUTPUT,
HashJoinProjection::KEY);
Expand Down Expand Up @@ -282,6 +279,125 @@ class HashJoinBasicImpl : public HashJoinImpl {
num_batches_produced_++;
}

// Evaluate the residual filter (filter_) on every candidate row pair produced
// by the key lookup and drop the pairs that do not satisfy it. All four index
// vectors are rewritten in place.
//
// \param local_state per-thread state; the left-side rows are decoded from its
//        exec_batch_keys / exec_batch_payloads
// \param match on entry, match.size() + no_match.size() is taken as the number
//        of probed left rows; on exit, the left row ids that have at least one
//        pair passing the filter
// \param no_match on exit, the left row ids with no pair passing the filter
// \param match_left left row id of each candidate pair; compacted in place so
//        only passing pairs remain
// \param match_right hash table row id of each candidate pair, parallel to
//        match_left and compacted the same way
// \return the error raised by decoding or by evaluating the filter, otherwise OK
//
// If filter_ is empty, or the filter evaluates to a valid `true` scalar, the
// vectors are left untouched.
//
// NOTE(review): the merge loop below walks match_left sequentially, so it
// assumes match_left is in non-decreasing order - confirm against the lookup
// step that fills it.
Status ProbeBatch_ResidualFilter(ThreadLocalState& local_state,
                                 std::vector<int32_t>& match,
                                 std::vector<int32_t>& no_match,
                                 std::vector<int32_t>& match_left,
                                 std::vector<int32_t>& match_right) {
  if (filter_.IsEmpty()) {
    return Status::OK();
  }
  ARROW_DCHECK_EQ(match_left.size(), match_right.size());

  // One row per candidate pair; columns are appended below.
  ExecBatch concatenated({}, match_left.size());

  ARROW_ASSIGN_OR_RAISE(ExecBatch left_key, local_state.exec_batch_keys.Decode(
                                                match_left.size(), match_left.data()));
  ARROW_ASSIGN_OR_RAISE(
      ExecBatch right_key,
      hash_table_keys_.Decode(match_right.size(), match_right.data()));

  ExecBatch left_payload;
  if (schema_mgr_->HasLeftPayload()) {
    ARROW_ASSIGN_OR_RAISE(left_payload, local_state.exec_batch_payloads.Decode(
                                            match_left.size(), match_left.data()));
  }

  ExecBatch right_payload;
  if (schema_mgr_->HasRightPayload()) {
    ARROW_ASSIGN_OR_RAISE(right_payload, hash_table_payloads_.Decode(
                                             match_right.size(), match_right.data()));
  }

  // Append, in FILTER projection order, each filter column of one side, taking
  // it from the decoded key batch if it maps to a key and from the decoded
  // payload batch otherwise.
  auto AppendFields = [&concatenated](const SchemaProjectionMap& to_key,
                                      const SchemaProjectionMap& to_pay,
                                      const ExecBatch& key, const ExecBatch& payload) {
    ARROW_DCHECK(to_key.num_cols == to_pay.num_cols);
    for (int i = 0; i < to_key.num_cols; i++) {
      if (to_key.get(i) != SchemaProjectionMap::kMissingField) {
        int key_idx = to_key.get(i);
        concatenated.values.push_back(key.values[key_idx]);
      } else if (to_pay.get(i) != SchemaProjectionMap::kMissingField) {
        int pay_idx = to_pay.get(i);
        concatenated.values.push_back(payload.values[pay_idx]);
      }
    }
  };

  SchemaProjectionMap left_to_key = schema_mgr_->proj_maps[0].map(
      HashJoinProjection::FILTER, HashJoinProjection::KEY);
  SchemaProjectionMap left_to_pay = schema_mgr_->proj_maps[0].map(
      HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
  SchemaProjectionMap right_to_key = schema_mgr_->proj_maps[1].map(
      HashJoinProjection::FILTER, HashJoinProjection::KEY);
  SchemaProjectionMap right_to_pay = schema_mgr_->proj_maps[1].map(
      HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);

  // Left columns first, then right columns.
  AppendFields(left_to_key, left_to_pay, left_key, left_payload);
  AppendFields(right_to_key, right_to_pay, right_key, right_payload);

  ARROW_ASSIGN_OR_RAISE(Datum mask,
                        ExecuteScalarExpression(filter_, concatenated, ctx_));

  size_t num_probed_rows = match.size() + no_match.size();
  if (mask.is_scalar()) {
    const auto& mask_scalar = mask.scalar_as<BooleanScalar>();
    if (mask_scalar.is_valid && mask_scalar.value) {
      // All rows passed, nothing left to do
      return Status::OK();
    } else {
      // Nothing passed (false or null scalar), no_match becomes everything
      no_match.resize(num_probed_rows);
      std::iota(no_match.begin(), no_match.end(), 0);
      match_left.clear();
      match_right.clear();
      match.clear();
      return Status::OK();
    }
  }
  ARROW_DCHECK_EQ(mask.array()->offset, 0);
  ARROW_DCHECK_EQ(mask.array()->length, static_cast<int64_t>(match_left.size()));
  // The validity bitmap is optional: an array without nulls may carry no
  // validity buffer at all, in which case every slot is valid. Dereferencing
  // it unconditionally would crash on such a mask.
  const auto& validity_buffer = mask.array()->buffers[0];
  const uint8_t* validity = validity_buffer ? validity_buffer->data() : nullptr;
  const uint8_t* comparisons = mask.array()->buffers[1]->data();
  size_t num_rows = match_left.size();

  match.clear();
  no_match.clear();

  int32_t match_idx = 0;  // current size of new match_left
  int32_t irow = 0;       // index into match_left
  for (int32_t curr_left = 0; static_cast<size_t>(curr_left) < num_probed_rows;
       curr_left++) {
    // Every left row before the next candidate pair's left row had no key
    // match at all, so it goes straight to no_match.
    int32_t advance_to = static_cast<size_t>(irow) < num_rows
                             ? match_left[irow]
                             : static_cast<int32_t>(num_probed_rows);
    while (curr_left < advance_to) {
      no_match.push_back(curr_left++);
    }
    // Consume all candidate pairs belonging to curr_left, keeping the ones
    // that pass the filter.
    bool passed = false;
    for (; static_cast<size_t>(irow) < num_rows && match_left[irow] == curr_left;
         irow++) {
      bool is_valid = (validity == nullptr) || BitUtil::GetBit(validity, irow);
      bool is_cmp_true = BitUtil::GetBit(comparisons, irow);
      // We treat a null comparison result as false, like in SQL
      if (is_valid && is_cmp_true) {
        match_left[match_idx] = match_left[irow];
        match_right[match_idx] = match_right[irow];
        match_idx++;
        passed = true;
      }
    }
    if (passed) {
      match.push_back(curr_left);
    } else if (static_cast<size_t>(curr_left) < num_probed_rows) {
      // The bound check is needed because the while loop above may have
      // advanced curr_left past the last probed row.
      no_match.push_back(curr_left);
    }
  }
  match_left.resize(match_idx);
  match_right.resize(match_idx);
  return Status::OK();
}

Status ProbeBatch_OutputOne(size_t thread_index, int64_t batch_size_next,
const int32_t* opt_left_ids, const int32_t* opt_right_ids) {
if (batch_size_next == 0 || (!opt_left_ids && !opt_right_ids)) {
Expand Down Expand Up @@ -456,6 +572,15 @@ class HashJoinBasicImpl : public HashJoinImpl {
&local_state.no_match, &local_state.match_left,
&local_state.match_right);

RETURN_NOT_OK(ProbeBatch_ResidualFilter(local_state, local_state.match,
local_state.no_match, local_state.match_left,
local_state.match_right));

for (auto i : local_state.match_right) {
// Mark row in hash table as having a match
BitUtil::SetBit(local_state.has_match.data(), i);
}

RETURN_NOT_OK(ProbeBatch_OutputAll(thread_index, local_state.exec_batch_keys,
local_state.exec_batch_payloads, local_state.match,
local_state.no_match, local_state.match_left,
Expand Down Expand Up @@ -732,6 +857,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
size_t num_threads_;
HashJoinSchema* schema_mgr_;
std::vector<JoinKeyCmp> key_cmp_;
Expression filter_;
std::unique_ptr<TaskScheduler> scheduler_;
int task_group_build_;
int task_group_queued_;
Expand Down
29 changes: 23 additions & 6 deletions cpp/src/arrow/compute/exec/hash_join.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ class ARROW_EXPORT HashJoinSchema {
public:
Status Init(JoinType join_type, const Schema& left_schema,
const std::vector<FieldRef>& left_keys, const Schema& right_schema,
const std::vector<FieldRef>& right_keys,
const std::vector<FieldRef>& right_keys, const Expression& filter,
const std::string& left_field_name_prefix,
const std::string& right_field_name_prefix);

Status Init(JoinType join_type, const Schema& left_schema,
const std::vector<FieldRef>& left_keys,
const std::vector<FieldRef>& left_output, const Schema& right_schema,
const std::vector<FieldRef>& right_keys,
const std::vector<FieldRef>& right_output,
const std::vector<FieldRef>& right_output, const Expression& filter,
const std::string& left_field_name_prefix,
const std::string& right_field_name_prefix);

Expand All @@ -56,19 +56,36 @@ class ARROW_EXPORT HashJoinSchema {
const std::string& left_field_name_prefix,
const std::string& right_field_name_prefix);

Result<Expression> BindFilter(Expression filter, const Schema& left_schema,
const Schema& right_schema);
std::shared_ptr<Schema> MakeOutputSchema(const std::string& left_field_name_prefix,
const std::string& right_field_name_prefix);

// Whether side 0 (the left input) has any payload columns; see HasPayload.
bool HasLeftPayload() {
  constexpr int kLeftSide = 0;
  return HasPayload(kLeftSide);
}

// Whether side 1 (the right input) has any payload columns; see HasPayload.
bool HasRightPayload() {
  constexpr int kRightSide = 1;
  return HasPayload(kRightSide);
}

static int kMissingField() {
return SchemaProjectionMaps<HashJoinProjection>::kMissingField;
}

SchemaProjectionMaps<HashJoinProjection> proj_maps[2];

private:
static Result<std::vector<FieldRef>> VectorDiff(const Schema& schema,
const std::vector<FieldRef>& a,
const std::vector<FieldRef>& b);
Result<std::vector<FieldRef>> CollectFilterColumns(const Expression& filter,
const Schema& schema);
Status TraverseExpression(std::vector<FieldRef>& refs, const Expression& filter,
const Schema& schema);

// Reports whether the projection maps for `side` (0 or 1, checked in debug
// builds) contain at least one column in the PAYLOAD projection.
bool HasPayload(int side) {
  ARROW_DCHECK(side == 0 || side == 1);
  const auto num_payload_cols =
      proj_maps[side].num_cols(HashJoinProjection::PAYLOAD);
  return num_payload_cols > 0;
}

static Result<std::vector<FieldRef>> ComputePayload(const Schema& schema,
const std::vector<FieldRef>& output,
const std::vector<FieldRef>& filter,
const std::vector<FieldRef>& key);
};

class HashJoinImpl {
Expand All @@ -79,7 +96,7 @@ class HashJoinImpl {
virtual ~HashJoinImpl() = default;
virtual Status Init(ExecContext* ctx, JoinType join_type, bool use_sync_execution,
size_t num_threads, HashJoinSchema* schema_mgr,
std::vector<JoinKeyCmp> key_cmp,
std::vector<JoinKeyCmp> key_cmp, Expression filter,
OutputBatchCallback output_batch_callback,
FinishedCallback finished_callback,
TaskScheduler::ScheduleImpl schedule_task_callback) = 0;
Expand Down
Loading

0 comments on commit 557084e

Please sign in to comment.