diff --git a/cpp/src/arrow/compute/expression.cc b/cpp/src/arrow/compute/expression.cc index 33e5928c2865d..12fda5d58f3bf 100644 --- a/cpp/src/arrow/compute/expression.cc +++ b/cpp/src/arrow/compute/expression.cc @@ -23,6 +23,7 @@ #include #include "arrow/chunked_array.h" +#include "arrow/compute/api_aggregate.h" #include "arrow/compute/api_vector.h" #include "arrow/compute/exec_internal.h" #include "arrow/compute/expression_internal.h" @@ -1242,6 +1243,72 @@ struct Inequality { /*insert_implicit_casts=*/false, &exec_context); } + /// Simplify an `is_in` call against an inequality guarantee. + /// + /// We avoid the complexity of fully simplifying EQUAL comparisons to true + /// literals (e.g., 'x is_in [1, 2, 3]' given the guarantee 'x = 2') due to + /// potential complications with null matching behavior. This is ok for the + /// predicate pushdown use case because the overall aim is to simplify to an + /// unsatisfiable expression. + /// + /// \pre `is_in_call` is a call to the `is_in` function + /// \return a simplified expression, or nullopt if no simplification occurred + static Result> SimplifyIsIn( + const Inequality& guarantee, const Expression::Call* is_in_call) { + DCHECK_EQ(is_in_call->function_name, "is_in"); + + auto options = checked_pointer_cast(is_in_call->options); + + const auto& lhs = Comparison::StripOrderPreservingCasts(is_in_call->arguments[0]); + if (!lhs.field_ref()) return std::nullopt; + if (*lhs.field_ref() != guarantee.target) return std::nullopt; + + FilterOptions::NullSelectionBehavior null_selection; + switch (options->null_matching_behavior) { + case SetLookupOptions::MATCH: + null_selection = + guarantee.nullable ? FilterOptions::EMIT_NULL : FilterOptions::DROP; + break; + case SetLookupOptions::SKIP: + null_selection = FilterOptions::DROP; + break; + case SetLookupOptions::EMIT_NULL: + if (guarantee.nullable) return std::nullopt; + null_selection = FilterOptions::DROP; + break; + case SetLookupOptions::INCONCLUSIVE: + if (guarantee.nullable) return std::nullopt; + ARROW_ASSIGN_OR_RAISE(Datum is_null, IsNull(options->value_set)); + ARROW_ASSIGN_OR_RAISE(Datum any_null, Any(is_null)); + if (any_null.scalar_as().value) return std::nullopt; + null_selection = FilterOptions::DROP; + break; + } + + std::string func_name = Comparison::GetName(guarantee.cmp); + DCHECK_NE(func_name, "na"); + std::vector args{options->value_set, guarantee.bound}; + ARROW_ASSIGN_OR_RAISE(Datum filter_mask, CallFunction(func_name, args)); + FilterOptions filter_options(null_selection); + ARROW_ASSIGN_OR_RAISE(Datum simplified_value_set, + Filter(options->value_set, filter_mask, filter_options)); + + if (simplified_value_set.length() == 0) return literal(false); + if (simplified_value_set.length() == options->value_set.length()) return std::nullopt; + + ExecContext exec_context; + Expression::Call simplified_call; + simplified_call.function_name = "is_in"; + simplified_call.arguments = is_in_call->arguments; + simplified_call.options = std::make_shared( + simplified_value_set, options->null_matching_behavior); + ARROW_ASSIGN_OR_RAISE( + Expression simplified_expr, + BindNonRecursive(std::move(simplified_call), + /*insert_implicit_casts=*/false, &exec_context)); + return simplified_expr; + } + /// \brief Simplify the given expression given this inequality as a guarantee. Result Simplify(Expression expr) { const auto& guarantee = *this; @@ -1258,6 +1325,12 @@ struct Inequality { return call->function_name == "is_valid" ? literal(true) : literal(false); } + if (call->function_name == "is_in") { + ARROW_ASSIGN_OR_RAISE(std::optional result, + SimplifyIsIn(guarantee, call)); + return result.value_or(expr); + } + auto cmp = Comparison::Get(expr); if (!cmp) return expr; diff --git a/cpp/src/arrow/compute/expression_test.cc b/cpp/src/arrow/compute/expression_test.cc index d94a17b6ffadf..0b7e8a9c23b13 100644 --- a/cpp/src/arrow/compute/expression_test.cc +++ b/cpp/src/arrow/compute/expression_test.cc @@ -27,6 +27,7 @@ #include #include +#include "arrow/array/builder_primitive.h" #include "arrow/compute/expression_internal.h" #include "arrow/compute/function_internal.h" #include "arrow/compute/registry.h" @@ -1616,6 +1617,144 @@ TEST(Expression, SimplifyWithComparisonAndNullableCaveat) { true_unless_null(field_ref("i32")))); // not satisfiable, will drop row group } +TEST(Expression, SimplifyIsIn) { + auto is_in = [](Expression field, std::shared_ptr value_set_type, + std::string json_array, + SetLookupOptions::NullMatchingBehavior null_matching_behavior) { + SetLookupOptions options{ArrayFromJSON(value_set_type, json_array), + null_matching_behavior}; + return call("is_in", {field}, options); + }; + + for (SetLookupOptions::NullMatchingBehavior null_matching : { + SetLookupOptions::MATCH, + SetLookupOptions::SKIP, + SetLookupOptions::EMIT_NULL, + SetLookupOptions::INCONCLUSIVE, + }) { + Simplify{is_in(field_ref("i32"), int32(), "[]", null_matching)} + .WithGuarantee(greater(field_ref("i32"), literal(2))) + .Expect(false); + + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)} + .WithGuarantee(equal(field_ref("i32"), literal(6))) + .Expect(false); + + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)} + .WithGuarantee(greater(field_ref("i32"), literal(3))) + .Expect(is_in(field_ref("i32"), int32(), "[5,7,9]", null_matching)); + + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)} + .WithGuarantee(greater(field_ref("i32"), literal(9))) + .Expect(false); + + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)} + .WithGuarantee(less_equal(field_ref("i32"), literal(0))) + .Expect(false); + + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)} + .WithGuarantee(greater(field_ref("i32"), literal(0))) + .ExpectUnchanged(); + + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)} + .WithGuarantee(less_equal(field_ref("i32"), literal(9))) + .ExpectUnchanged(); + + Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]", null_matching)} + .WithGuarantee(and_(less_equal(field_ref("i32"), literal(7)), + greater(field_ref("i32"), literal(4)))) + .Expect(is_in(field_ref("i32"), int32(), "[5,7]", null_matching)); + + Simplify{is_in(field_ref("u32"), int8(), "[1,3,5,7,9]", null_matching)} + .WithGuarantee(greater(field_ref("u32"), literal(3))) + .Expect(is_in(field_ref("u32"), int8(), "[5,7,9]", null_matching)); + + Simplify{is_in(field_ref("u32"), int64(), "[1,3,5,7,9]", null_matching)} + .WithGuarantee(greater(field_ref("u32"), literal(3))) + .Expect(is_in(field_ref("u32"), int64(), "[5,7,9]", null_matching)); + } + + Simplify{ + is_in(field_ref("i32"), int32(), "[1,2,3]", SetLookupOptions::MATCH), + } + .WithGuarantee( + or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32")))) + .Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::MATCH)); + + Simplify{ + is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::MATCH), + } + .WithGuarantee(greater(field_ref("i32"), literal(2))) + .Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::MATCH)); + + Simplify{ + is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::MATCH), + } + .WithGuarantee( + or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32")))) + .Expect(is_in(field_ref("i32"), int32(), "[3,null]", SetLookupOptions::MATCH)); + + Simplify{ + is_in(field_ref("i32"), int32(), "[1,2,3]", SetLookupOptions::SKIP), + } + .WithGuarantee( + or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32")))) + .Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::SKIP)); + + Simplify{ + is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::SKIP), + } + .WithGuarantee(greater(field_ref("i32"), literal(2))) + .Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::SKIP)); + + Simplify{ + is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::SKIP), + } + .WithGuarantee( + or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32")))) + .Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::SKIP)); + + Simplify{ + is_in(field_ref("i32"), int32(), "[1,2,3]", SetLookupOptions::EMIT_NULL), + } + .WithGuarantee( + or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32")))) + .ExpectUnchanged(); + + Simplify{ + is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::EMIT_NULL), + } + .WithGuarantee(greater(field_ref("i32"), literal(2))) + .Expect(is_in(field_ref("i32"), int32(), "[3]", SetLookupOptions::EMIT_NULL)); + + Simplify{ + is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::EMIT_NULL), + } + .WithGuarantee( + or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32")))) + .ExpectUnchanged(); + + Simplify{ + is_in(field_ref("i32"), int32(), "[1,2,3]", SetLookupOptions::INCONCLUSIVE), + } + .WithGuarantee( + or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32")))) + .ExpectUnchanged(); + + Simplify{ + is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::INCONCLUSIVE), + } + .WithGuarantee(greater(field_ref("i32"), literal(2))) + .ExpectUnchanged(); + + Simplify{ + is_in(field_ref("i32"), int32(), "[1,2,3,null]", SetLookupOptions::INCONCLUSIVE), + } + .WithGuarantee( + or_(greater(field_ref("i32"), literal(2)), is_null(field_ref("i32")))) + .ExpectUnchanged(); +} + TEST(Expression, SimplifyThenExecute) { auto filter = or_({equal(field_ref("f32"), literal(0)), @@ -1643,6 +1782,40 @@ TEST(Expression, SimplifyThenExecute) { AssertDatumsEqual(evaluated, simplified_evaluated, /*verbose=*/true); } +TEST(Expression, SimplifyIsInThenExecute) { + auto input = RecordBatchFromJSON(kBoringSchema, R"([ + {"i64": 2, "i32": 5}, + {"i64": 5, "i32": 6}, + {"i64": 3, "i32": 6}, + {"i64": 3, "i32": 5}, + {"i64": 4, "i32": 5}, + {"i64": 2, "i32": 7}, + {"i64": 5, "i32": 5} + ])"); + + std::vector guarantees{greater(field_ref("i64"), literal(1)), + greater_equal(field_ref("i32"), literal(5)), + less_equal(field_ref("i64"), literal(5))}; + + for (const Expression& guarantee : guarantees) { + auto filter = + call("is_in", {guarantee.call()->arguments[0]}, + compute::SetLookupOptions{ArrayFromJSON(int32(), "[1,2,3]"), true}); + ASSERT_OK_AND_ASSIGN(filter, filter.Bind(*kBoringSchema)); + ASSERT_OK_AND_ASSIGN(auto simplified, SimplifyWithGuarantee(filter, guarantee)); + + Datum evaluated, simplified_evaluated; + ExpectExecute(filter, input, &evaluated); + ExpectExecute(simplified, input, &simplified_evaluated); + if (simplified_evaluated.is_scalar()) { + ASSERT_OK_AND_ASSIGN( + simplified_evaluated, + MakeArrayFromScalar(*simplified_evaluated.scalar(), evaluated.length())); + } + AssertDatumsEqual(evaluated, simplified_evaluated, /*verbose=*/true); + } +} + TEST(Expression, Filter) { auto ExpectFilter = [](Expression filter, std::string batch_json) { ASSERT_OK_AND_ASSIGN(auto s, kBoringSchema->AddField(0, field("in", boolean())));