Skip to content

Commit

Permalink
[C++] Support is_in predicates for SimplifyWithGuarantee
Browse files Browse the repository at this point in the history
  • Loading branch information
larry98 committed Jul 15, 2024
1 parent 7184150 commit 67c9d4c
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 19 deletions.
105 changes: 89 additions & 16 deletions cpp/src/arrow/compute/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1242,33 +1242,104 @@ struct Inequality {
/*insert_implicit_casts=*/false, &exec_context);
}

/// Simplify an is_in predicate against this inequality as a guarantee.
Result<Expression> SimplifyIsIn(Expression expr) {
const auto& guarantee = *this;
auto call = expr.call();
auto options = checked_pointer_cast<SetLookupOptions>(call->options);

auto value_set = options->value_set.make_array();
if (!value_set) return expr;
if (value_set->length() == 0) return literal(false);

// For now, only simplify when the guarantee is non-nullable.
if (guarantee.nullable) return expr;

auto compare = [&value_set, &guarantee](size_t i) -> Result<Comparison::type> {
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar, value_set->GetScalar(i));
// Nulls compare greater than any non-null value.
if (!scalar->is_valid) {
return Comparison::GREATER;
}
ARROW_ASSIGN_OR_RAISE(Comparison::type cmp,
Comparison::Execute(scalar, guarantee.bound));
return cmp;
};

size_t lo = 0;
size_t hi = value_set->length();
while (lo + 1 < hi) {
size_t mid = (lo + hi) / 2;
ARROW_ASSIGN_OR_RAISE(Comparison::type cmp, compare(mid));
if (cmp & Comparison::LESS_EQUAL) {
lo = mid;
} else {
hi = mid;
}
}

ARROW_ASSIGN_OR_RAISE(Comparison::type cmp, compare(lo));
size_t pivot = lo + (cmp == Comparison::LESS ? 1 : 0);
bool found = cmp == Comparison::EQUAL;

if (found && (guarantee.cmp & Comparison::EQUAL)) return literal(true);

std::shared_ptr<Array> simplified_value_set;
if (guarantee.cmp == Comparison::EQUAL) {
return literal(found);
} else if (guarantee.cmp == Comparison::LESS) {
simplified_value_set = value_set->Slice(0, pivot);
} else if (guarantee.cmp == Comparison::LESS_EQUAL) {
simplified_value_set = value_set->Slice(0, pivot + (found ? 1 : 0));
} else if (guarantee.cmp == Comparison::GREATER) {
simplified_value_set = value_set->Slice(pivot + (found ? 1 : 0));
} else if (guarantee.cmp == Comparison::GREATER_EQUAL) {
simplified_value_set = value_set->Slice(pivot);
} else {
// We should never reach here.
return expr;
}

if (simplified_value_set->length() == 0) return literal(false);
if (simplified_value_set->length() == value_set->length()) return expr;

Expression::Call simplified_call;
simplified_call.function_name = "is_in";
simplified_call.arguments = call->arguments;
simplified_call.options = std::make_shared<SetLookupOptions>(
std::move(simplified_value_set), options->null_matching_behavior);
ExecContext exec_context;
return BindNonRecursive(std::move(simplified_call),
/*insert_implicit_casts=*/false, &exec_context);
}

/// \brief Simplify the given expression given this inequality as a guarantee.
Result<Expression> Simplify(Expression expr) {
Result<Expression> Simplify(Expression expr, bool is_in_value_set_sorted) {
const auto& guarantee = *this;

auto call = expr.call();
if (!call) return expr;

const auto& lhs = Comparison::StripOrderPreservingCasts(call->arguments[0]);
if (!lhs.field_ref()) return expr;
if (*lhs.field_ref() != guarantee.target) return expr;

if (call->function_name == "is_valid" || call->function_name == "is_null") {
if (guarantee.nullable) return expr;
const auto& lhs = Comparison::StripOrderPreservingCasts(call->arguments[0]);
if (!lhs.field_ref()) return expr;
if (*lhs.field_ref() != guarantee.target) return expr;

return call->function_name == "is_valid" ? literal(true) : literal(false);
}

if (call->function_name == "is_in" && is_in_value_set_sorted) {
return SimplifyIsIn(expr);
}

auto cmp = Comparison::Get(expr);
if (!cmp) return expr;

auto rhs = call->arguments[1].literal();
if (!rhs) return expr;
if (!rhs->is_scalar()) return expr;

const auto& lhs = Comparison::StripOrderPreservingCasts(call->arguments[0]);
if (!lhs.field_ref()) return expr;
if (*lhs.field_ref() != guarantee.target) return expr;

// Whether the RHS of the expression is EQUAL, LESS, or GREATER than the
// RHS of the guarantee. N.B. Comparison::type is a bitmask
ARROW_ASSIGN_OR_RAISE(const Comparison::type cmp_rhs_bound,
Expand Down Expand Up @@ -1346,7 +1417,8 @@ Result<Expression> SimplifyIsValidGuarantee(Expression expr,
} // namespace

Result<Expression> SimplifyWithGuarantee(Expression expr,
const Expression& guaranteed_true_predicate) {
const Expression& guaranteed_true_predicate,
bool is_in_value_set_sorted) {
KnownFieldValues known_values;
auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate);

Expand All @@ -1366,12 +1438,13 @@ Result<Expression> SimplifyWithGuarantee(Expression expr,
if (!guarantee.call()) continue;

if (auto inequality = Inequality::ExtractOne(guarantee)) {
ARROW_ASSIGN_OR_RAISE(auto simplified,
ModifyExpression(
std::move(expr), [](Expression expr) { return expr; },
[&](Expression expr, ...) -> Result<Expression> {
return inequality->Simplify(std::move(expr));
}));
ARROW_ASSIGN_OR_RAISE(
auto simplified,
ModifyExpression(
std::move(expr), [is_in_value_set_sorted](Expression expr) { return expr; },
[&](Expression expr, ...) -> Result<Expression> {
return inequality->Simplify(std::move(expr), is_in_value_set_sorted);
}));

if (Identical(simplified, expr)) continue;

Expand Down
8 changes: 7 additions & 1 deletion cpp/src/arrow/compute/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,15 @@ Result<Expression> ReplaceFieldsWithKnownValues(const KnownFieldValues& known_va
/// a boolean expression which is guaranteed to evaluate to `true`. For example, this is
/// used to remove redundant function calls from a filter expression or to replace a
/// reference to a constant-value field with a literal.
///
/// An is_in predicate can be simplified in certain cases if the value set is
/// sorted and does not contain deuplicates. Passing true for is_in_value_set_sorted
/// let's us assume that both of these conditions hold for any is_in call found
/// in the expression.
ARROW_EXPORT
Result<Expression> SimplifyWithGuarantee(Expression,
const Expression& guaranteed_true_predicate);
const Expression& guaranteed_true_predicate,
bool is_in_value_set_sorted = false);

/// Replace all named field refs (e.g. "x" or "x.y") with field paths (e.g. [0] or [1,3])
///
Expand Down
74 changes: 72 additions & 2 deletions cpp/src/arrow/compute/expression_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1289,14 +1289,18 @@ TEST(Expression, CanonicalizeComparison) {

struct Simplify {
Expression expr;
bool is_in_value_set_sorted = false;

struct Expectable {
Expression expr, guarantee;
bool is_in_value_set_sorted = false;

void Expect(Expression unbound_expected) {
ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema));

ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(bound, guarantee));
ASSERT_OK_AND_ASSIGN(
auto simplified,
SimplifyWithGuarantee(bound, guarantee, is_in_value_set_sorted));

ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema));
EXPECT_EQ(simplified, expected) << " original: " << expr.ToString() << "\n"
Expand All @@ -1309,7 +1313,9 @@ struct Simplify {
void Expect(bool constant) { Expect(literal(constant)); }
};

Expectable WithGuarantee(Expression guarantee) { return {expr, guarantee}; }
Expectable WithGuarantee(Expression guarantee) {
return {expr, guarantee, is_in_value_set_sorted};
}
};

TEST(Expression, SingleComparisonGuarantees) {
Expand Down Expand Up @@ -1616,6 +1622,70 @@ TEST(Expression, SimplifyWithComparisonAndNullableCaveat) {
true_unless_null(field_ref("i32")))); // not satisfiable, will drop row group
}

TEST(Expression, SimplifyIsIn) {
auto is_in = [](Expression field, std::string json_array) {
SetLookupOptions options{ArrayFromJSON(int32(), json_array)};
return call("is_in", {field}, options);
};

Simplify{
is_in(field_ref("i32"), "[]"),
/*is_in_value_set_sorted=*/true,
}
.WithGuarantee(greater(field_ref("i32"), literal(2)))
.Expect(false);

Simplify{
is_in(field_ref("i32"), "[1,3,5,7,9]"),
/*is_in_value_set_sorted=*/true,
}
.WithGuarantee(less_equal(field_ref("i32"), literal(7)))
.Expect(true);

Simplify{
is_in(field_ref("i32"), "[1,3,5,7,9]"),
/*is_in_value_set_sorted=*/true,
}
.WithGuarantee(greater(field_ref("i32"), literal(3)))
.Expect(is_in(field_ref("i32"), "[5,7,9]"));

Simplify{
is_in(field_ref("i32"), "[1,3,5,7,9]"),
/*is_in_value_set_sorted=*/true,
}
.WithGuarantee(greater(field_ref("i32"), literal(9)))
.Expect(false);

Simplify{
is_in(field_ref("i32"), "[1,3,5,7,9]"),
/*is_in_value_set_sorted=*/true,
}
.WithGuarantee(less_equal(field_ref("i32"), literal(0)))
.Expect(false);

Simplify{
is_in(field_ref("i32"), "[1,3,5,7,9]"),
/*is_in_value_set_sorted=*/true,
}
.WithGuarantee(greater(field_ref("i32"), literal(0)))
.ExpectUnchanged();

Simplify{
is_in(field_ref("i32"), "[1,3,5,7,9]"),
/*is_in_value_set_sorted=*/true,
}
.WithGuarantee(
or_(equal(field_ref("i32"), literal(3)), is_null(field_ref("i32"))))
.ExpectUnchanged();

Simplify{
is_in(field_ref("i32"), "[1,3,5,7,9]"),
/*is_in_value_set_sorted=*/false,
}
.WithGuarantee(less_equal(field_ref("i32"), literal(7)))
.ExpectUnchanged();
}

TEST(Expression, SimplifyThenExecute) {
auto filter =
or_({equal(field_ref("f32"), literal(0)),
Expand Down

0 comments on commit 67c9d4c

Please sign in to comment.