From 2bef6fd71cea79afc121cf5e3687e6375df705af Mon Sep 17 00:00:00 2001
From: Larry Wang
Date: Tue, 1 Oct 2024 07:22:20 -0400
Subject: [PATCH] avoid repeated binds when simplifying is_in

---
Review notes (placed after the "---" marker so they stay out of the commit
message):

* SimplifyWithGuarantee: `expr` is now copied into ModifyExpression() instead
  of being moved, because `expr` is read again by the Identical() check in the
  very next statement. With the move, that check compared against a moved-from
  object, so the "nothing changed" fast path could not be relied upon.
* Template argument lists that were missing from this patch text (for example
  `Result<std::optional<Expression>>`, `std::vector<Datum>`,
  `checked_pointer_cast<SetLookupOptions>`) were reconstructed from how the
  values are used in the surrounding code. Verify against the target file.
* Line breaks and indentation were reconstructed as well; hunk line counts
  match the `@@` headers, but apply with whitespace tolerance if needed.

 cpp/src/arrow/compute/expression.cc | 172 ++++++++++++++++------------
 1 file changed, 100 insertions(+), 72 deletions(-)

diff --git a/cpp/src/arrow/compute/expression.cc b/cpp/src/arrow/compute/expression.cc
index 12fda5d58f3bf..8f8ee7ee49f1a 100644
--- a/cpp/src/arrow/compute/expression.cc
+++ b/cpp/src/arrow/compute/expression.cc
@@ -1243,72 +1243,6 @@ struct Inequality {
                             /*insert_implicit_casts=*/false, &exec_context);
   }
 
-  /// Simplify an `is_in` call against an inequality guarantee.
-  ///
-  /// We avoid the complexity of fully simplifying EQUAL comparisons to true
-  /// literals (e.g., 'x is_in [1, 2, 3]' given the guarantee 'x = 2') due to
-  /// potential complications with null matching behavior. This is ok for the
-  /// predicate pushdown use case because the overall aim is to simplify to an
-  /// unsatisfiable expression.
-  ///
-  /// \pre `is_in_call` is a call to the `is_in` function
-  /// \return a simplified expression, or nullopt if no simplification occurred
-  static Result<std::optional<Expression>> SimplifyIsIn(
-      const Inequality& guarantee, const Expression::Call* is_in_call) {
-    DCHECK_EQ(is_in_call->function_name, "is_in");
-
-    auto options = checked_pointer_cast<SetLookupOptions>(is_in_call->options);
-
-    const auto& lhs = Comparison::StripOrderPreservingCasts(is_in_call->arguments[0]);
-    if (!lhs.field_ref()) return std::nullopt;
-    if (*lhs.field_ref() != guarantee.target) return std::nullopt;
-
-    FilterOptions::NullSelectionBehavior null_selection;
-    switch (options->null_matching_behavior) {
-      case SetLookupOptions::MATCH:
-        null_selection =
-            guarantee.nullable ? FilterOptions::EMIT_NULL : FilterOptions::DROP;
-        break;
-      case SetLookupOptions::SKIP:
-        null_selection = FilterOptions::DROP;
-        break;
-      case SetLookupOptions::EMIT_NULL:
-        if (guarantee.nullable) return std::nullopt;
-        null_selection = FilterOptions::DROP;
-        break;
-      case SetLookupOptions::INCONCLUSIVE:
-        if (guarantee.nullable) return std::nullopt;
-        ARROW_ASSIGN_OR_RAISE(Datum is_null, IsNull(options->value_set));
-        ARROW_ASSIGN_OR_RAISE(Datum any_null, Any(is_null));
-        if (any_null.scalar_as<BooleanScalar>().value) return std::nullopt;
-        null_selection = FilterOptions::DROP;
-        break;
-    }
-
-    std::string func_name = Comparison::GetName(guarantee.cmp);
-    DCHECK_NE(func_name, "na");
-    std::vector<Datum> args{options->value_set, guarantee.bound};
-    ARROW_ASSIGN_OR_RAISE(Datum filter_mask, CallFunction(func_name, args));
-    FilterOptions filter_options(null_selection);
-    ARROW_ASSIGN_OR_RAISE(Datum simplified_value_set,
-                          Filter(options->value_set, filter_mask, filter_options));
-
-    if (simplified_value_set.length() == 0) return literal(false);
-    if (simplified_value_set.length() == options->value_set.length()) return std::nullopt;
-
-    ExecContext exec_context;
-    Expression::Call simplified_call;
-    simplified_call.function_name = "is_in";
-    simplified_call.arguments = is_in_call->arguments;
-    simplified_call.options = std::make_shared<SetLookupOptions>(
-        simplified_value_set, options->null_matching_behavior);
-    ARROW_ASSIGN_OR_RAISE(
-        Expression simplified_expr,
-        BindNonRecursive(std::move(simplified_call),
-                         /*insert_implicit_casts=*/false, &exec_context));
-    return simplified_expr;
-  }
-
   /// \brief Simplify the given expression given this inequality as a guarantee.
   Result<Expression> Simplify(Expression expr) {
     const auto& guarantee = *this;
@@ -1325,12 +1259,6 @@ struct Inequality {
       return call->function_name == "is_valid" ? literal(true) : literal(false);
     }
 
-    if (call->function_name == "is_in") {
-      ARROW_ASSIGN_OR_RAISE(std::optional<Expression> result,
-                            SimplifyIsIn(guarantee, call));
-      return result.value_or(expr);
-    }
-
     auto cmp = Comparison::Get(expr);
     if (!cmp) return expr;
 
@@ -1416,6 +1344,88 @@ Result<Expression> SimplifyIsValidGuarantee(Expression expr,
       });
 }
 
+/// Simplify an `is_in` value set against a single inequality guarantee.
+///
+/// We avoid the complexity of fully simplifying EQUAL comparisons to true
+/// literals (e.g., 'x is_in [1, 2, 3]' given the guarantee 'x = 2') due to
+/// potential complications with null matching behavior. This is ok for the
+/// predicate pushdown use case because the overall aim is to simplify to an
+/// unsatisfiable expression.
+Result<Datum> SimplifyIsInValueSet(Datum value_set, const Inequality& guarantee,
+                                   SetLookupOptions::NullMatchingBehavior null_matching) {
+  FilterOptions::NullSelectionBehavior null_selection;
+  switch (null_matching) {
+    case SetLookupOptions::MATCH:
+      null_selection =
+          guarantee.nullable ? FilterOptions::EMIT_NULL : FilterOptions::DROP;
+      break;
+    case SetLookupOptions::SKIP:
+      null_selection = FilterOptions::DROP;
+      break;
+    case SetLookupOptions::EMIT_NULL:
+      if (guarantee.nullable) return value_set;
+      null_selection = FilterOptions::DROP;
+      break;
+    case SetLookupOptions::INCONCLUSIVE:
+      if (guarantee.nullable) return value_set;
+      ARROW_ASSIGN_OR_RAISE(Datum is_null, IsNull(value_set));
+      ARROW_ASSIGN_OR_RAISE(Datum any_null, Any(is_null));
+      if (any_null.scalar_as<BooleanScalar>().value) return value_set;
+      null_selection = FilterOptions::DROP;
+      break;
+  }
+
+  std::string func_name = Comparison::GetName(guarantee.cmp);
+  DCHECK_NE(func_name, "na");
+  std::vector<Datum> args{value_set, guarantee.bound};
+  ARROW_ASSIGN_OR_RAISE(Datum filter_mask, CallFunction(func_name, args));
+  FilterOptions filter_options(null_selection);
+  ARROW_ASSIGN_OR_RAISE(Datum simplified_value_set,
+                        Filter(value_set, filter_mask, filter_options));
+  return simplified_value_set;
+}
+
+/// Simplify an `is_in` call against a list of inequality guarantees.
+///
+/// Simplification is done across all guarantee conjunction members at once to
+/// avoid the cost of repeatedly binding the simplified expression, which is
+/// linear in the size of the `is_in` value set.
+Result<std::optional<Expression>> SimplifyIsInWithGuarantees(
+    const Expression::Call* is_in_call,
+    const std::vector<Expression>& guarantee_conjunction_members) {
+  DCHECK_EQ(is_in_call->function_name, "is_in");
+  DCHECK_EQ(is_in_call->arguments.size(), 1);
+
+  auto options = checked_pointer_cast<SetLookupOptions>(is_in_call->options);
+
+  const auto& lhs = Comparison::StripOrderPreservingCasts(is_in_call->arguments[0]);
+  if (!lhs.field_ref()) return std::nullopt;
+
+  Datum simplified_value_set = options->value_set;
+  for (const Expression& guarantee : guarantee_conjunction_members) {
+    std::optional<Inequality> inequality = Inequality::ExtractOne(guarantee);
+    if (!inequality) continue;
+    if (inequality->target != *lhs.field_ref()) continue;
+    ARROW_ASSIGN_OR_RAISE(simplified_value_set,
+                          SimplifyIsInValueSet(simplified_value_set, *inequality,
+                                               options->null_matching_behavior));
+  }
+
+  if (simplified_value_set.length() == 0) return literal(false);
+  if (simplified_value_set.length() == options->value_set.length()) return std::nullopt;
+
+  ExecContext exec_context;
+  Expression::Call simplified_call;
+  simplified_call.function_name = "is_in";
+  simplified_call.arguments = is_in_call->arguments;
+  simplified_call.options = std::make_shared<SetLookupOptions>(
+      simplified_value_set, options->null_matching_behavior);
+  ARROW_ASSIGN_OR_RAISE(Expression simplified_expr,
+                        BindNonRecursive(std::move(simplified_call),
+                                         /*insert_implicit_casts=*/false, &exec_context));
+  return simplified_expr;
+}
+
 }  // namespace
 
 Result<Expression> SimplifyWithGuarantee(Expression expr,
@@ -1464,6 +1474,24 @@ Result<Expression> SimplifyWithGuarantee(Expression expr,
     }
   }
 
+  auto simplify_is_in = [&](Expression expr, ...) -> Result<Expression> {
+    if (expr.call() && expr.call()->function_name == "is_in") {
+      ARROW_ASSIGN_OR_RAISE(auto simplified,
+                            SimplifyIsInWithGuarantees(expr.call(), conjunction_members));
+      return simplified.value_or(expr);
+    } else {
+      return expr;
+    }
+  };
+  ARROW_ASSIGN_OR_RAISE(
+      auto simplified,
+      ModifyExpression(
+          expr, [](Expression expr) { return expr; }, simplify_is_in));
+  if (!Identical(simplified, expr)) {
+    expr = std::move(simplified);
+    RETURN_NOT_OK(CanonicalizeAndFoldConstants());
+  }
+
   return expr;
 }
 