Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-43187: [C++] Support basic is_in predicate simplification #43761

Merged
merged 4 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions cpp/src/arrow/compute/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1242,6 +1242,52 @@ struct Inequality {
/*insert_implicit_casts=*/false, &exec_context);
}

/// Simplify an `is_in` call against an inequality guarantee.
/// \pre `is_in_call` is a call to the `is_in` function
/// \return a simplified expression, or nullopt if no simplification occurred
static Result<std::optional<Expression>> SimplifyIsIn(
const Inequality& guarantee, const Expression::Call* is_in_call) {
DCHECK_EQ(is_in_call->function_name, "is_in");

auto options = checked_pointer_cast<SetLookupOptions>(is_in_call->options);

// Null-matching behavior is complex and reduces the chances of reduction
// of `is_in` calls to a single literal for every possible input, so we
// abort the simplification if nulls are possible in the input or output.
if (guarantee.nullable ||
options->null_matching_behavior == SetLookupOptions::INCONCLUSIVE) {
return std::nullopt;
}

const auto& lhs = Comparison::StripOrderPreservingCasts(is_in_call->arguments[0]);
if (!lhs.field_ref()) return std::nullopt;
if (*lhs.field_ref() != guarantee.target) return std::nullopt;

std::string func_name = Comparison::GetName(guarantee.cmp);
DCHECK_NE(func_name, "na");
std::vector<Datum> args{options->value_set, guarantee.bound};
ARROW_ASSIGN_OR_RAISE(Datum filter_mask, CallFunction(func_name, args));
FilterOptions filter_options(FilterOptions::DROP);
ARROW_ASSIGN_OR_RAISE(Datum simplified_value_set,
Filter(options->value_set, filter_mask, filter_options));

if (simplified_value_set.length() == 0) return literal(false);
if (guarantee.cmp == Comparison::EQUAL) return literal(true);
if (simplified_value_set.length() == options->value_set.length()) return std::nullopt;

ExecContext exec_context;
Expression::Call simplified_call;
simplified_call.function_name = "is_in";
simplified_call.arguments = is_in_call->arguments;
simplified_call.options = std::make_shared<SetLookupOptions>(
simplified_value_set, options->null_matching_behavior);
ARROW_ASSIGN_OR_RAISE(
Expression simplified_expr,
BindNonRecursive(std::move(simplified_call),
/*insert_implicit_casts=*/false, &exec_context));
return simplified_expr;
}

/// \brief Simplify the given expression given this inequality as a guarantee.
Result<Expression> Simplify(Expression expr) {
const auto& guarantee = *this;
Expand All @@ -1258,6 +1304,12 @@ struct Inequality {
return call->function_name == "is_valid" ? literal(true) : literal(false);
}

if (call->function_name == "is_in") {
ARROW_ASSIGN_OR_RAISE(std::optional<Expression> result,
SimplifyIsIn(guarantee, call));
return result.value_or(expr);
}

auto cmp = Comparison::Get(expr);
if (!cmp) return expr;

Expand Down
114 changes: 114 additions & 0 deletions cpp/src/arrow/compute/expression_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "arrow/array/builder_primitive.h"
#include "arrow/compute/expression_internal.h"
#include "arrow/compute/function_internal.h"
#include "arrow/compute/registry.h"
Expand Down Expand Up @@ -1616,6 +1617,81 @@ TEST(Expression, SimplifyWithComparisonAndNullableCaveat) {
true_unless_null(field_ref("i32")))); // not satisfiable, will drop row group
}

TEST(Expression, SimplifyIsIn) {
auto is_in = [](Expression field, std::shared_ptr<DataType> value_set_type,
std::string json_array,
SetLookupOptions::NullMatchingBehavior null_matching_behavior) {
SetLookupOptions options{ArrayFromJSON(value_set_type, json_array),
null_matching_behavior};
return call("is_in", {field}, options);
};

for (SetLookupOptions::NullMatchingBehavior null_matching_behavior :
{SetLookupOptions::MATCH, SetLookupOptions::SKIP, SetLookupOptions::EMIT_NULL}) {
Simplify{is_in(field_ref("i32"), int32(), "[]", null_matching_behavior)}
.WithGuarantee(greater(field_ref("i32"), literal(2)))
.Expect(false);

Simplify{is_in(field_ref("i32"), int32(), "[null]", null_matching_behavior)}
.WithGuarantee(greater(field_ref("i32"), literal(2)))
.Expect(false);

Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching_behavior)}
.WithGuarantee(equal(field_ref("i32"), literal(7)))
.Expect(true);

Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching_behavior)}
.WithGuarantee(equal(field_ref("i32"), literal(6)))
.Expect(false);

Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching_behavior)}
.WithGuarantee(greater(field_ref("i32"), literal(3)))
.Expect(is_in(field_ref("i32"), int32(), "[5,7,9]", null_matching_behavior));

Simplify{
is_in(field_ref("i32"), int32(), "[1,null,3,5,null,7,9]", null_matching_behavior),
}
.WithGuarantee(greater(field_ref("i32"), literal(3)))
.Expect(is_in(field_ref("i32"), int32(), "[5,7,9]", null_matching_behavior));

Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching_behavior)}
.WithGuarantee(greater(field_ref("i32"), literal(9)))
.Expect(false);

Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching_behavior)}
.WithGuarantee(less_equal(field_ref("i32"), literal(0)))
.Expect(false);

Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching_behavior)}
.WithGuarantee(greater(field_ref("i32"), literal(0)))
.ExpectUnchanged();

Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching_behavior)}
.WithGuarantee(
or_(equal(field_ref("i32"), literal(3)), is_null(field_ref("i32"))))
.ExpectUnchanged();

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd expect this to pass:

Suggested change
Simplify{is_in(field_ref("i32"), int32(), "[1,3,null]", SetLookupOptions::MATCH)}
.WithGuarantee(
or_(equal(field_ref("i32"), literal(3)), is_null(field_ref("i32"))))
.Expect(is_in(field_ref("i32"), int32(), "[3,null]", SetLookupOptions::MATCH));

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original intention was to not support simplification if nulls in the value set cannot be dropped (either the guarantee is nullable or the null matching behavior is INCONCLUSIVE). This is because in the optimized implementation where we binary search and slice the value set array, slicing the front would drop nulls (assuming they are placed at the end) so we would have to reallocate a new array for the simplified value set.

Do you think we ought to support nulls in the value set, and if so any thoughts on how we'd continue to support this with the binary search/slice implementation?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The filtering approach should be able to support arbitrary value sets, so it could serve as a fallback for the binary search/slice implementation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, that sounds reasonable to me. I added new tests for nullable guarantees and nulls in the value set for all of the different null matching behaviors.

Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching_behavior)}
.WithGuarantee(and_(less_equal(field_ref("i32"), literal(7)),
greater(field_ref("i32"), literal(4))))
.Expect(is_in(field_ref("i32"), int32(), "[5,7]", null_matching_behavior));

Simplify{is_in(field_ref("u32"), int8(), "[1,3,5,7,9]", null_matching_behavior)}
.WithGuarantee(greater(field_ref("u32"), literal(3)))
.Expect(is_in(field_ref("u32"), int8(), "[5,7,9]", null_matching_behavior));

Simplify{is_in(field_ref("u32"), int64(), "[1,3,5,7,9]", null_matching_behavior)}
.WithGuarantee(greater(field_ref("u32"), literal(3)))
.Expect(is_in(field_ref("u32"), int64(), "[5,7,9]", null_matching_behavior));
}

Simplify{
is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", SetLookupOptions::INCONCLUSIVE),
}
.WithGuarantee(greater(field_ref("i32"), literal(3)))
.ExpectUnchanged();
}

TEST(Expression, SimplifyThenExecute) {
larry98 marked this conversation as resolved.
Show resolved Hide resolved
auto filter =
or_({equal(field_ref("f32"), literal(0)),
Expand Down Expand Up @@ -1643,6 +1719,44 @@ TEST(Expression, SimplifyThenExecute) {
AssertDatumsEqual(evaluated, simplified_evaluated, /*verbose=*/true);
}

TEST(Expression, SimplifyIsInThenExecute) {
larry98 marked this conversation as resolved.
Show resolved Hide resolved
auto input = RecordBatchFromJSON(kBoringSchema, R"([
{"i64": 2, "i32": 5},
{"i64": 5, "i32": 6},
{"i64": 3, "i32": 6},
{"i64": 3, "i32": 5},
{"i64": 4, "i32": 5},
{"i64": 2, "i32": 7},
{"i64": 5, "i32": 5}
])");

std::vector<Expression> guarantees{
greater(field_ref("i64"), literal(1)),
greater_equal(field_ref("i32"), literal(5)),
less_equal(field_ref("i64"), literal(5))};

for (const Expression& guarantee : guarantees) {
auto filter = call(
"is_in", {guarantee.call()->arguments[0]},
compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2,3]"), true});
ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*kBoringSchema));
ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(filter, guarantee));

Datum evaluated, simplified_evaluated;
ExpectExecute(filter, input, &evaluated);
ExpectExecute(simplified, input, &simplified_evaluated);
if (simplified_evaluated.is_scalar()) {
ASSERT_EQ(evaluated.kind(), Datum::ARRAY);
ASSERT_EQ(simplified_evaluated.type()->id(), Type::BOOL);
BooleanBuilder builder;
ASSERT_OK(builder.AppendValues(
evaluated.length(), simplified_evaluated.scalar_as<BooleanScalar>().value));
ASSERT_OK_AND_ASSIGN(simplified_evaluated, builder.Finish());
larry98 marked this conversation as resolved.
Show resolved Hide resolved
}
AssertDatumsEqual(evaluated, simplified_evaluated, /*verbose=*/true);
}
}

TEST(Expression, Filter) {
auto ExpectFilter = [](Expression filter, std::string batch_json) {
ASSERT_OK_AND_ASSIGN(auto s, kBoringSchema->AddField(0, field("in", boolean())));
Expand Down
Loading