Skip to content

Commit

Permalink
Switch to literal(true)
Browse files Browse the repository at this point in the history
  • Loading branch information
save-buffer committed Nov 6, 2021
1 parent 557084e commit 71f5bd6
Show file tree
Hide file tree
Showing 9 changed files with 76 additions and 88 deletions.
2 changes: 0 additions & 2 deletions cpp/src/arrow/compute/exec/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,6 @@ bool Expression::IsSatisfiable() const {
return true;
}

bool Expression::IsEmpty() const { return impl_ == nullptr; }

namespace {

// Produce a bound Expression from unbound Call and bound arguments.
Expand Down
3 changes: 0 additions & 3 deletions cpp/src/arrow/compute/exec/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,6 @@ class ARROW_EXPORT Expression {
/// Return true if this expression could evaluate to true.
bool IsSatisfiable() const;

/// Return true if this expression has no clauses.
bool IsEmpty() const;

// XXX someday
// Result<PipelineGraph> GetPipelines();

Expand Down
6 changes: 3 additions & 3 deletions cpp/src/arrow/compute/exec/hash_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ class HashJoinBasicImpl : public HashJoinImpl {
std::vector<int32_t>& no_match,
std::vector<int32_t>& match_left,
std::vector<int32_t>& match_right) {
if (filter_.IsEmpty()) {
if (filter_ == literal(true)) {
return Status::OK();
}
ARROW_DCHECK_EQ(match_left.size(), match_right.size());
Expand All @@ -298,13 +298,13 @@ class HashJoinBasicImpl : public HashJoinImpl {
hash_table_keys_.Decode(match_right.size(), match_right.data()));

ExecBatch left_payload;
if (schema_mgr_->HasLeftPayload()) {
if (!schema_mgr_->LeftPayloadIsEmpty()) {
ARROW_ASSIGN_OR_RAISE(left_payload, local_state.exec_batch_payloads.Decode(
match_left.size(), match_left.data()));
}

ExecBatch right_payload;
if (schema_mgr_->HasRightPayload()) {
if (!schema_mgr_->RightPayloadIsEmpty()) {
ARROW_ASSIGN_OR_RAISE(right_payload, hash_table_payloads_.Decode(
match_right.size(), match_right.data()));
}
Expand Down
8 changes: 4 additions & 4 deletions cpp/src/arrow/compute/exec/hash_join.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ class ARROW_EXPORT HashJoinSchema {
std::shared_ptr<Schema> MakeOutputSchema(const std::string& left_field_name_prefix,
const std::string& right_field_name_prefix);

bool HasLeftPayload() { return HasPayload(0); }
bool LeftPayloadIsEmpty() { return PayloadIsEmpty(0); }

bool HasRightPayload() { return HasPayload(1); }
bool RightPayloadIsEmpty() { return PayloadIsEmpty(1); }

static int kMissingField() {
return SchemaProjectionMaps<HashJoinProjection>::kMissingField;
Expand All @@ -77,9 +77,9 @@ class ARROW_EXPORT HashJoinSchema {
Status TraverseExpression(std::vector<FieldRef>& refs, const Expression& filter,
const Schema& schema);

bool HasPayload(int side) {
bool PayloadIsEmpty(int side) {
ARROW_DCHECK(side == 0 || side == 1);
return proj_maps[side].num_cols(HashJoinProjection::PAYLOAD) > 0;
return proj_maps[side].num_cols(HashJoinProjection::PAYLOAD) == 0;
}

static Result<std::vector<FieldRef>> ComputePayload(const Schema& schema,
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/arrow/compute/exec/hash_join_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ Result<Expression> HashJoinSchema::BindFilter(Expression filter,
if (filter.IsBound()) {
return std::move(filter);
}
if (!filter.IsEmpty()) {
if (filter != literal(true)) {
FieldVector fields;
auto left = proj_maps[0].map(HashJoinProjection::FILTER, HashJoinProjection::INPUT);
auto right = proj_maps[1].map(HashJoinProjection::FILTER, HashJoinProjection::INPUT);
Expand All @@ -313,7 +313,7 @@ Result<Expression> HashJoinSchema::BindFilter(Expression filter,
}
return std::move(filter);
}
return Expression();
return literal(true);
}

Result<std::vector<FieldRef>> HashJoinSchema::CollectFilterColumns(
Expand All @@ -336,7 +336,7 @@ Result<std::vector<FieldRef>> HashJoinSchema::CollectFilterColumns(
Status HashJoinSchema::TraverseExpression(std::vector<FieldRef>& refs,
const Expression& filter,
const Schema& schema) {
if (filter.IsEmpty()) return Status::OK();
if (filter == literal(true)) return Status::OK();
if (auto* call = filter.call()) {
for (const Expression& arg : call->arguments)
RETURN_NOT_OK(TraverseExpression(refs, arg, schema));
Expand Down
121 changes: 57 additions & 64 deletions cpp/src/arrow/compute/exec/hash_join_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,7 @@ TEST(HashJoin, Random) {
}

// Turn the last key comparison into a residual filter expression
Expression filter;
Expression filter = literal(true);
if (key_cmp.size() > 1) {
for (size_t i = 0; i < key_cmp.size(); i++) {
FieldRef left = key_fields[0][i];
Expand All @@ -1114,12 +1114,7 @@ TEST(HashJoin, Random) {
}
}
}

if (!filter.IsEmpty()) {
std::cout << " Filter: " << filter.ToString() << "\n";
} else {
std::cout << " Filter: <empty>\n";
}
std::cout << " Filter: " << filter.ToString() << "\n";

// Run tested join implementation
HashJoinNodeOptions join_options{
Expand Down Expand Up @@ -1690,78 +1685,76 @@ TEST(HashJoin, DictNegative) {
}
}

TEST(ExecPlanExecution, ResidualFilter) {
for (bool parallel : {false, true}) {
SCOPED_TRACE(parallel ? "parallel/merged" : "serial");
TEST(ExecPlanExecution, ResidualFilter) {
for (bool parallel : {false, true}) {
SCOPED_TRACE(parallel ? "parallel/merged" : "serial");

BatchesWithSchema input_left;
input_left.batches = {ExecBatchFromJSON({int32(), int32(), utf8()}, R"([
BatchesWithSchema input_left;
input_left.batches = {ExecBatchFromJSON({int32(), int32(), utf8()}, R"([
[1, 6, "alpha"],
[2, 5, "beta"],
[3, 4, "alpha"]
])")};
input_left.schema =
schema({field("l1", int32()), field("l2", int32()), field("l_str", utf8())});
input_left.schema =
schema({field("l1", int32()), field("l2", int32()), field("l_str", utf8())});

BatchesWithSchema input_right;
input_right.batches = {ExecBatchFromJSON({int32(), int32(), utf8()}, R"([
BatchesWithSchema input_right;
input_right.batches = {ExecBatchFromJSON({int32(), int32(), utf8()}, R"([
[5, 11, "alpha"],
[2, 12, "beta"],
[4, 16, "alpha"]
])")};
input_right.schema =
schema({field("r1", int32()), field("r2", int32()), field("r_str", utf8())});

auto exec_ctx = arrow::internal::make_unique<ExecContext>(
default_memory_pool(),
parallel ? arrow::internal::GetCpuThreadPool() : nullptr);

ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));
AsyncGenerator<util::optional<ExecBatch>> sink_gen;

ExecNode* left_source;
ExecNode* right_source;
ASSERT_OK_AND_ASSIGN(
left_source,
MakeExecNode("source", plan.get(), {},
SourceNodeOptions{input_left.schema,
input_left.gen(parallel, /*slow=*/false)}));

ASSERT_OK_AND_ASSIGN(
right_source,
MakeExecNode("source", plan.get(), {},
SourceNodeOptions{input_right.schema,
input_right.gen(parallel, /*slow=*/false)}))

Expression mul = call("multiply", {field_ref("l1"), field_ref("l2")});
Expression combination = call("add", {mul, field_ref("r1")});
Expression residual_filter = less_equal(combination, field_ref("r2"));

HashJoinNodeOptions join_opts{
JoinType::FULL_OUTER,
/*left_keys=*/{"l_str"},
/*right_keys=*/{"r_str"}, std::move(residual_filter), "l_", "r_"};

ASSERT_OK_AND_ASSIGN(
auto hashjoin,
MakeExecNode("hashjoin", plan.get(), {left_source, right_source}, join_opts));

ASSERT_OK_AND_ASSIGN(std::ignore, MakeExecNode("sink", plan.get(), {hashjoin},
SinkNodeOptions{&sink_gen}));

ASSERT_FINISHES_OK_AND_ASSIGN(auto result, StartAndCollect(plan.get(), sink_gen));

std::vector<ExecBatch> expected = {
ExecBatchFromJSON({int32(), int32(), utf8(), int32(), int32(), utf8()}, R"([
input_right.schema =
schema({field("r1", int32()), field("r2", int32()), field("r_str", utf8())});

auto exec_ctx = arrow::internal::make_unique<ExecContext>(
default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr);

ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));
AsyncGenerator<util::optional<ExecBatch>> sink_gen;

ExecNode* left_source;
ExecNode* right_source;
ASSERT_OK_AND_ASSIGN(
left_source,
MakeExecNode("source", plan.get(), {},
SourceNodeOptions{input_left.schema,
input_left.gen(parallel, /*slow=*/false)}));

ASSERT_OK_AND_ASSIGN(
right_source,
MakeExecNode("source", plan.get(), {},
SourceNodeOptions{input_right.schema,
input_right.gen(parallel, /*slow=*/false)}))

Expression mul = call("multiply", {field_ref("l1"), field_ref("l2")});
Expression combination = call("add", {mul, field_ref("r1")});
Expression residual_filter = less_equal(combination, field_ref("r2"));

HashJoinNodeOptions join_opts{
JoinType::FULL_OUTER,
/*left_keys=*/{"l_str"},
/*right_keys=*/{"r_str"}, std::move(residual_filter), "l_", "r_"};

ASSERT_OK_AND_ASSIGN(
auto hashjoin,
MakeExecNode("hashjoin", plan.get(), {left_source, right_source}, join_opts));

ASSERT_OK_AND_ASSIGN(std::ignore, MakeExecNode("sink", plan.get(), {hashjoin},
SinkNodeOptions{&sink_gen}));

ASSERT_FINISHES_OK_AND_ASSIGN(auto result, StartAndCollect(plan.get(), sink_gen));

std::vector<ExecBatch> expected = {
ExecBatchFromJSON({int32(), int32(), utf8(), int32(), int32(), utf8()}, R"([
[1, 6, "alpha", 4, 16, "alpha"],
[1, 6, "alpha", 5, 11, "alpha"],
[2, 5, "beta", 2, 12, "beta"],
[3, 4, "alpha", 4, 16, "alpha"]])")};
std::cout << result[0].ToString() << std::endl;

AssertExecBatchesEqual(hashjoin->output_schema(), result, expected);
}
}
AssertExecBatchesEqual(hashjoin->output_schema(), result, expected);
}
}

} // namespace compute
} // namespace arrow
6 changes: 3 additions & 3 deletions cpp/src/arrow/compute/exec/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions {
static constexpr const char* default_output_prefix_for_right = "";
HashJoinNodeOptions(
JoinType in_join_type, std::vector<FieldRef> in_left_keys,
std::vector<FieldRef> in_right_keys, Expression filter = Expression(),
std::vector<FieldRef> in_right_keys, Expression filter = literal(true),
std::string output_prefix_for_left = default_output_prefix_for_left,
std::string output_prefix_for_right = default_output_prefix_for_right)
: join_type(in_join_type),
Expand All @@ -191,7 +191,7 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions {
HashJoinNodeOptions(
JoinType join_type, std::vector<FieldRef> left_keys,
std::vector<FieldRef> right_keys, std::vector<FieldRef> left_output,
std::vector<FieldRef> right_output, Expression filter = Expression(),
std::vector<FieldRef> right_output, Expression filter = literal(true),
std::string output_prefix_for_left = default_output_prefix_for_left,
std::string output_prefix_for_right = default_output_prefix_for_right)
: join_type(join_type),
Expand All @@ -212,7 +212,7 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions {
JoinType join_type, std::vector<FieldRef> left_keys,
std::vector<FieldRef> right_keys, std::vector<FieldRef> left_output,
std::vector<FieldRef> right_output, std::vector<JoinKeyCmp> key_cmp,
Expression filter = Expression(),
Expression filter = literal(true),
std::string output_prefix_for_left = default_output_prefix_for_left,
std::string output_prefix_for_right = default_output_prefix_for_right)
: join_type(join_type),
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/compute/exec/plan_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1042,7 +1042,7 @@ TEST(ExecPlanExecution, SelfInnerHashJoinSink) {

HashJoinNodeOptions join_opts{JoinType::INNER,
/*left_keys=*/{"str"},
/*right_keys=*/{"str"}, Expression(), "l_", "r_"};
/*right_keys=*/{"str"}, literal(true), "l_", "r_"};

ASSERT_OK_AND_ASSIGN(
auto hashjoin,
Expand Down Expand Up @@ -1099,7 +1099,7 @@ TEST(ExecPlanExecution, SelfOuterHashJoinSink) {

HashJoinNodeOptions join_opts{JoinType::FULL_OUTER,
/*left_keys=*/{"str"},
/*right_keys=*/{"str"}, Expression(), "l_", "r_"};
/*right_keys=*/{"str"}, literal(true), "l_", "r_"};

ASSERT_OK_AND_ASSIGN(
auto hashjoin,
Expand Down
8 changes: 4 additions & 4 deletions cpp/src/arrow/compute/exec/util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ TEST(FieldMap, Trivial) {
auto right = schema({field("i32", int32())});

ASSERT_OK(schema_mgr.Init(JoinType::INNER, *left, {"i32"}, *right, {"i32"},
Expression(), kLeftPrefix, kRightPrefix));
literal(true), kLeftPrefix, kRightPrefix));

auto output = schema_mgr.MakeOutputSchema(kLeftPrefix, kRightPrefix);
EXPECT_THAT(*output, Eq(Schema({
Expand All @@ -55,7 +55,7 @@ TEST(FieldMap, TrivialDuplicates) {
auto right = schema({field("i32", int32())});

ASSERT_OK(schema_mgr.Init(JoinType::INNER, *left, {"i32"}, *right, {"i32"},
Expression(), "", ""));
literal(true), "", ""));

auto output = schema_mgr.MakeOutputSchema("", "");
EXPECT_THAT(*output, Eq(Schema({
Expand All @@ -75,7 +75,7 @@ TEST(FieldMap, SingleKeyField) {
auto right = schema({field("f32", float32()), field("i32", int32())});

ASSERT_OK(schema_mgr.Init(JoinType::INNER, *left, {"i32"}, *right, {"i32"},
Expression(), kLeftPrefix, kRightPrefix));
literal(true), kLeftPrefix, kRightPrefix));

EXPECT_EQ(schema_mgr.proj_maps[0].num_cols(HashJoinProjection::INPUT), 2);
EXPECT_EQ(schema_mgr.proj_maps[1].num_cols(HashJoinProjection::INPUT), 2);
Expand Down Expand Up @@ -113,7 +113,7 @@ TEST(FieldMap, TwoKeyFields) {
});

ASSERT_OK(schema_mgr.Init(JoinType::INNER, *left, {"i32", "str"}, *right,
{"i32", "str"}, Expression(), kLeftPrefix, kRightPrefix));
{"i32", "str"}, literal(true), kLeftPrefix, kRightPrefix));

auto output = schema_mgr.MakeOutputSchema(kLeftPrefix, kRightPrefix);
EXPECT_THAT(*output, Eq(Schema({
Expand Down

0 comments on commit 71f5bd6

Please sign in to comment.