diff --git a/cpp/src/arrow/compute/expression.cc b/cpp/src/arrow/compute/expression.cc index 33e5928c2865d..6478139883da8 100644 --- a/cpp/src/arrow/compute/expression.cc +++ b/cpp/src/arrow/compute/expression.cc @@ -1242,22 +1242,95 @@ struct Inequality { /*insert_implicit_casts=*/false, &exec_context); } + /// Simplify an is_in predicate against this inequality as a guarantee. + Result SimplifyIsIn(Expression expr) { + const auto& guarantee = *this; + auto call = expr.call(); + auto options = checked_pointer_cast(call->options); + + auto value_set = options->value_set.make_array(); + if (!value_set) return expr; + if (value_set->length() == 0) return literal(false); + + // For now, only simplify when the guarantee is non-nullable. + if (guarantee.nullable) return expr; + + auto compare = [&value_set, &guarantee](size_t i) -> Result { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr scalar, value_set->GetScalar(i)); + // Nulls compare greater than any non-null value. + if (!scalar->is_valid) { + return Comparison::GREATER; + } + ARROW_ASSIGN_OR_RAISE(Comparison::type cmp, + Comparison::Execute(scalar, guarantee.bound)); + return cmp; + }; + + size_t lo = 0; + size_t hi = value_set->length(); + while (lo + 1 < hi) { + size_t mid = (lo + hi) / 2; + ARROW_ASSIGN_OR_RAISE(Comparison::type cmp, compare(mid)); + if (cmp & Comparison::LESS_EQUAL) { + lo = mid; + } else { + hi = mid; + } + } + + ARROW_ASSIGN_OR_RAISE(Comparison::type cmp, compare(lo)); + size_t pivot = lo + (cmp == Comparison::LESS ? 1 : 0); + bool found = cmp == Comparison::EQUAL; + + std::shared_ptr simplified_value_set; + if (guarantee.cmp == Comparison::EQUAL) { + return literal(found); + } else if (guarantee.cmp == Comparison::LESS) { + simplified_value_set = value_set->Slice(0, pivot); + } else if (guarantee.cmp == Comparison::LESS_EQUAL) { + simplified_value_set = value_set->Slice(0, pivot + (found ? 1 : 0)); + } else if (guarantee.cmp == Comparison::GREATER) { + simplified_value_set = value_set->Slice(pivot + (found ? 1 : 0)); + } else if (guarantee.cmp == Comparison::GREATER_EQUAL) { + simplified_value_set = value_set->Slice(pivot); + } else { + // We should never reach here. + return expr; + } + + if (simplified_value_set->length() == 0) return literal(false); + if (simplified_value_set->length() == value_set->length()) return expr; + + Expression::Call simplified_call; + simplified_call.function_name = "is_in"; + simplified_call.arguments = call->arguments; + simplified_call.options = std::make_shared( + std::move(simplified_value_set), options->null_matching_behavior); + ExecContext exec_context; + return BindNonRecursive(std::move(simplified_call), + /*insert_implicit_casts=*/false, &exec_context); + } + /// \brief Simplify the given expression given this inequality as a guarantee. - Result Simplify(Expression expr) { + Result Simplify(Expression expr, bool is_in_value_set_sorted) { const auto& guarantee = *this; auto call = expr.call(); if (!call) return expr; + const auto& lhs = Comparison::StripOrderPreservingCasts(call->arguments[0]); + if (!lhs.field_ref()) return expr; + if (*lhs.field_ref() != guarantee.target) return expr; + if (call->function_name == "is_valid" || call->function_name == "is_null") { if (guarantee.nullable) return expr; - const auto& lhs = Comparison::StripOrderPreservingCasts(call->arguments[0]); - if (!lhs.field_ref()) return expr; - if (*lhs.field_ref() != guarantee.target) return expr; - return call->function_name == "is_valid" ? literal(true) : literal(false); } + if (call->function_name == "is_in" && is_in_value_set_sorted) { + return SimplifyIsIn(expr); + } + auto cmp = Comparison::Get(expr); if (!cmp) return expr; @@ -1265,10 +1338,6 @@ struct Inequality { if (!rhs) return expr; if (!rhs->is_scalar()) return expr; - const auto& lhs = Comparison::StripOrderPreservingCasts(call->arguments[0]); - if (!lhs.field_ref()) return expr; - if (*lhs.field_ref() != guarantee.target) return expr; - // Whether the RHS of the expression is EQUAL, LESS, or GREATER than the // RHS of the guarantee. N.B. Comparison::type is a bitmask ARROW_ASSIGN_OR_RAISE(const Comparison::type cmp_rhs_bound, @@ -1346,7 +1415,8 @@ Result SimplifyIsValidGuarantee(Expression expr, } // namespace Result SimplifyWithGuarantee(Expression expr, - const Expression& guaranteed_true_predicate) { + const Expression& guaranteed_true_predicate, + bool is_in_value_set_sorted) { KnownFieldValues known_values; auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); @@ -1366,12 +1436,13 @@ Result SimplifyWithGuarantee(Expression expr, if (!guarantee.call()) continue; if (auto inequality = Inequality::ExtractOne(guarantee)) { - ARROW_ASSIGN_OR_RAISE(auto simplified, - ModifyExpression( - std::move(expr), [](Expression expr) { return expr; }, - [&](Expression expr, ...) -> Result { - return inequality->Simplify(std::move(expr)); - })); + ARROW_ASSIGN_OR_RAISE( + auto simplified, + ModifyExpression( + std::move(expr), [is_in_value_set_sorted](Expression expr) { return expr; }, + [&](Expression expr, ...) -> Result { + return inequality->Simplify(std::move(expr), is_in_value_set_sorted); + })); if (Identical(simplified, expr)) continue; diff --git a/cpp/src/arrow/compute/expression.h b/cpp/src/arrow/compute/expression.h index 9a36a6d3368fb..f9ca8841702cc 100644 --- a/cpp/src/arrow/compute/expression.h +++ b/cpp/src/arrow/compute/expression.h @@ -220,9 +220,15 @@ Result ReplaceFieldsWithKnownValues(const KnownFieldValues& known_va /// a boolean expression which is guaranteed to evaluate to `true`. For example, this is /// used to remove redundant function calls from a filter expression or to replace a /// reference to a constant-value field with a literal. +/// +/// An is_in predicate can be simplified in certain cases if the value set is +/// sorted and does not contain deuplicates. Passing true for is_in_value_set_sorted +/// let's us assume that both of these conditions hold for any is_in call found +/// in the expression. ARROW_EXPORT Result SimplifyWithGuarantee(Expression, - const Expression& guaranteed_true_predicate); + const Expression& guaranteed_true_predicate, + bool is_in_value_set_sorted = false); /// Replace all named field refs (e.g. "x" or "x.y") with field paths (e.g. [0] or [1,3]) /// diff --git a/cpp/src/arrow/compute/expression_test.cc b/cpp/src/arrow/compute/expression_test.cc index d94a17b6ffadf..09733800e1b5b 100644 --- a/cpp/src/arrow/compute/expression_test.cc +++ b/cpp/src/arrow/compute/expression_test.cc @@ -1289,14 +1289,18 @@ TEST(Expression, CanonicalizeComparison) { struct Simplify { Expression expr; + bool is_in_value_set_sorted = false; struct Expectable { Expression expr, guarantee; + bool is_in_value_set_sorted = false; void Expect(Expression unbound_expected) { ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*kBoringSchema)); - ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(bound, guarantee)); + ASSERT_OK_AND_ASSIGN( + auto simplified, + SimplifyWithGuarantee(bound, guarantee, is_in_value_set_sorted)); ASSERT_OK_AND_ASSIGN(auto expected, unbound_expected.Bind(*kBoringSchema)); EXPECT_EQ(simplified, expected) << " original: " << expr.ToString() << "\n" @@ -1309,7 +1313,9 @@ struct Simplify { void Expect(bool constant) { Expect(literal(constant)); } }; - Expectable WithGuarantee(Expression guarantee) { return {expr, guarantee}; } + Expectable WithGuarantee(Expression guarantee) { + return {expr, guarantee, is_in_value_set_sorted}; + } }; TEST(Expression, SingleComparisonGuarantees) { @@ -1616,6 +1622,70 @@ TEST(Expression, SimplifyWithComparisonAndNullableCaveat) { true_unless_null(field_ref("i32")))); // not satisfiable, will drop row group } +TEST(Expression, SimplifyIsIn) { + auto is_in = [](Expression field, std::string json_array) { + SetLookupOptions options{ArrayFromJSON(int32(), json_array)}; + return call("is_in", {field}, options); + }; + + Simplify{ + is_in(field_ref("i32"), "[]"), + /*is_in_value_set_sorted=*/true, + } + .WithGuarantee(greater(field_ref("i32"), literal(2))) + .Expect(false); + + Simplify{ + is_in(field_ref("i32"), "[1,3,5,7,9]"), + /*is_in_value_set_sorted=*/true, + } + .WithGuarantee(equal(field_ref("i32"), literal(7))) + .Expect(true); + + Simplify{ + is_in(field_ref("i32"), "[1,3,5,7,9]"), + /*is_in_value_set_sorted=*/true, + } + .WithGuarantee(greater(field_ref("i32"), literal(3))) + .Expect(is_in(field_ref("i32"), "[5,7,9]")); + + Simplify{ + is_in(field_ref("i32"), "[1,3,5,7,9]"), + /*is_in_value_set_sorted=*/true, + } + .WithGuarantee(greater(field_ref("i32"), literal(9))) + .Expect(false); + + Simplify{ + is_in(field_ref("i32"), "[1,3,5,7,9]"), + /*is_in_value_set_sorted=*/true, + } + .WithGuarantee(less_equal(field_ref("i32"), literal(0))) + .Expect(false); + + Simplify{ + is_in(field_ref("i32"), "[1,3,5,7,9]"), + /*is_in_value_set_sorted=*/true, + } + .WithGuarantee(greater(field_ref("i32"), literal(0))) + .ExpectUnchanged(); + + Simplify{ + is_in(field_ref("i32"), "[1,3,5,7,9]"), + /*is_in_value_set_sorted=*/true, + } + .WithGuarantee( + or_(equal(field_ref("i32"), literal(3)), is_null(field_ref("i32")))) + .ExpectUnchanged(); + + Simplify{ + is_in(field_ref("i32"), "[1,3,5,7,9]"), + /*is_in_value_set_sorted=*/false, + } + .WithGuarantee(less_equal(field_ref("i32"), literal(7))) + .ExpectUnchanged(); +} + TEST(Expression, SimplifyThenExecute) { auto filter = or_({equal(field_ref("f32"), literal(0)),