Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-43187 [C++] Support is_in predicates for SimplifyWithGuarantee #43256

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
300 changes: 228 additions & 72 deletions cpp/src/arrow/compute/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@
#include <unordered_map>
#include <unordered_set>

#include "arrow/array/concatenate.h"
#include "arrow/chunked_array.h"
#include "arrow/compute/api_aggregate.h"
#include "arrow/compute/api_vector.h"
#include "arrow/compute/exec_internal.h"
#include "arrow/compute/expression_internal.h"
#include "arrow/compute/function_internal.h"
#include "arrow/compute/kernels/set_lookup_internal.h"
#include "arrow/compute/util.h"
#include "arrow/io/memory.h"
#include "arrow/ipc/reader.h"
Expand All @@ -38,6 +40,7 @@
#include "arrow/util/string.h"
#include "arrow/util/value_parsing.h"
#include "arrow/util/vector.h"
#include "arrow/visit_array_inline.h"

namespace arrow {

Expand Down Expand Up @@ -1243,72 +1246,6 @@ struct Inequality {
/*insert_implicit_casts=*/false, &exec_context);
}

/// Narrow the value set of an `is_in` call using a single inequality guarantee.
///
/// Fully simplifying EQUAL comparisons to true literals (e.g.,
/// 'x is_in [1, 2, 3]' given the guarantee 'x = 2') is deliberately not
/// attempted because of potential complications with null matching behavior.
/// This is acceptable for the predicate pushdown use case, where the overall
/// aim is to simplify to an unsatisfiable expression.
///
/// \pre `is_in_call` is a call to the `is_in` function
/// \return a simplified expression, or nullopt if no simplification occurred
static Result<std::optional<Expression>> SimplifyIsIn(
    const Inequality& guarantee, const Expression::Call* is_in_call) {
  DCHECK_EQ(is_in_call->function_name, "is_in");

  const auto set_lookup_options =
      checked_pointer_cast<SetLookupOptions>(is_in_call->options);
  const Datum& original_value_set = set_lookup_options->value_set;
  const auto null_matching = set_lookup_options->null_matching_behavior;

  // Only a direct field reference (modulo order-preserving casts) that matches
  // the guarantee's target can be simplified.
  const auto& probe = Comparison::StripOrderPreservingCasts(is_in_call->arguments[0]);
  auto probe_ref = probe.field_ref();
  if (!probe_ref || *probe_ref != guarantee.target) return std::nullopt;

  // Decide how nulls produced by the comparison mask are treated when
  // filtering the value set. Every handled case drops them except MATCH
  // combined with a nullable guarantee.
  FilterOptions::NullSelectionBehavior null_selection = FilterOptions::DROP;
  switch (null_matching) {
    case SetLookupOptions::MATCH:
      if (guarantee.nullable) null_selection = FilterOptions::EMIT_NULL;
      break;
    case SetLookupOptions::SKIP:
      break;
    case SetLookupOptions::EMIT_NULL:
      if (guarantee.nullable) return std::nullopt;
      break;
    case SetLookupOptions::INCONCLUSIVE: {
      if (guarantee.nullable) return std::nullopt;
      // Give up if the value set itself contains a null.
      ARROW_ASSIGN_OR_RAISE(Datum null_mask, IsNull(original_value_set));
      ARROW_ASSIGN_OR_RAISE(Datum has_null, Any(null_mask));
      if (has_null.scalar_as<BooleanScalar>().value) return std::nullopt;
      break;
    }
  }

  // Keep only the value set entries that satisfy the guarantee's comparison
  // against its bound.
  const std::string comparison_name = Comparison::GetName(guarantee.cmp);
  DCHECK_NE(comparison_name, "na");
  std::vector<Datum> comparison_args{original_value_set, guarantee.bound};
  ARROW_ASSIGN_OR_RAISE(Datum keep_mask, CallFunction(comparison_name, comparison_args));
  ARROW_ASSIGN_OR_RAISE(
      Datum simplified_value_set,
      Filter(original_value_set, keep_mask, FilterOptions(null_selection)));

  // Nothing survived: the predicate is unsatisfiable under the guarantee.
  if (simplified_value_set.length() == 0) return literal(false);
  // Nothing was removed: report that no simplification occurred.
  if (simplified_value_set.length() == original_value_set.length()) {
    return std::nullopt;
  }

  // Rebuild and rebind the call around the reduced value set.
  Expression::Call reduced_call;
  reduced_call.function_name = "is_in";
  reduced_call.arguments = is_in_call->arguments;
  reduced_call.options =
      std::make_shared<SetLookupOptions>(simplified_value_set, null_matching);

  ExecContext exec_context;
  ARROW_ASSIGN_OR_RAISE(
      Expression simplified_expr,
      BindNonRecursive(std::move(reduced_call),
                       /*insert_implicit_casts=*/false, &exec_context));
  return simplified_expr;
}

/// \brief Simplify the given expression given this inequality as a guarantee.
Result<Expression> Simplify(Expression expr) {
const auto& guarantee = *this;
Expand All @@ -1325,12 +1262,6 @@ struct Inequality {
return call->function_name == "is_valid" ? literal(true) : literal(false);
}

if (call->function_name == "is_in") {
ARROW_ASSIGN_OR_RAISE(std::optional<Expression> result,
SimplifyIsIn(guarantee, call));
return result.value_or(expr);
}

auto cmp = Comparison::Get(expr);
if (!cmp) return expr;

Expand Down Expand Up @@ -1416,6 +1347,213 @@ Result<Expression> SimplifyIsValidGuarantee(Expression expr,
});
}

/// Simplify an `is_in` value set against a single inequality guarantee.
///
/// We avoid the complexity of fully simplifying EQUAL comparisons to true
/// literals (e.g., 'x is_in [1, 2, 3]' given the guarantee 'x = 2') due to
/// potential complications with null matching behavior. This is ok for the
/// predicate pushdown use case because the overall aim is to simplify to an
/// unsatisfiable expression.
///
/// The struct is an array visitor: `Simplify()` dispatches on the concrete
/// array type of `value_set` via `VisitArrayInline` and stores the narrowed
/// value set in `result`.
struct IsInValueSetSimplifier {
  /// Visitor entry point for every array type.
  ///
  /// A single template with `if constexpr` is used on purpose: two overloads
  /// with identical parameter lists that differ only in an `enable_if_t`
  /// return type are ambiguous whenever the constrained one is viable.
  ///
  /// The binary search path is attempted only for eligible array types and
  /// only when `enable_fast_simplification` is set; any failure on that path
  /// falls back to the linear scan.
  template <typename T>
  Status Visit(const T&) {
    // NOTE(review): the eligibility condition is reproduced unchanged. Confirm
    // that `BaseBinaryArray` is usable as a plain class name here; if it is a
    // class template this trait needs to be spelled differently.
    if constexpr (std::is_base_of_v<FlatArray, T> ||
                  std::is_base_of_v<BaseBinaryArray, T>) {
      if (enable_fast_simplification) {
        auto simplified = SimplifyOptimized<T>();
        if (simplified.ok()) {
          result = simplified.ValueUnsafe();
          return Status::OK();
        }
      }
    }
    ARROW_ASSIGN_OR_RAISE(result, SimplifyBasic());
    return Status::OK();
  }

  /// Simplify the value set using a linear scan filter.
  ///
  /// \return the filtered value set, or `value_set` unchanged when the null
  /// matching behavior makes simplification unsafe
  Result<std::shared_ptr<Array>> SimplifyBasic() {
    FilterOptions::NullSelectionBehavior null_selection;
    switch (null_matching) {
      case SetLookupOptions::MATCH:
        null_selection =
            guarantee.nullable ? FilterOptions::EMIT_NULL : FilterOptions::DROP;
        break;
      case SetLookupOptions::SKIP:
        null_selection = FilterOptions::DROP;
        break;
      case SetLookupOptions::EMIT_NULL:
        if (guarantee.nullable) return value_set;
        null_selection = FilterOptions::DROP;
        break;
      case SetLookupOptions::INCONCLUSIVE:
        if (guarantee.nullable) return value_set;
        ARROW_ASSIGN_OR_RAISE(Datum is_null, IsNull(value_set));
        ARROW_ASSIGN_OR_RAISE(Datum any_null, Any(is_null));
        if (any_null.scalar_as<BooleanScalar>().value) return value_set;
        null_selection = FilterOptions::DROP;
        break;
    }

    // Keep only the entries that satisfy `entry <cmp> bound`.
    std::string func_name = Comparison::GetName(guarantee.cmp);
    DCHECK_NE(func_name, "na");
    std::vector<Datum> args{value_set, guarantee.bound};
    ARROW_ASSIGN_OR_RAISE(Datum filter_mask, CallFunction(func_name, args));
    FilterOptions filter_options(null_selection);
    ARROW_ASSIGN_OR_RAISE(Datum simplified_value_set,
                          Filter(value_set, filter_mask, filter_options));
    return simplified_value_set.make_array();
  }

  /// Simplify the value set using binary search.
  ///
  /// A non-OK status means "this path does not apply"; the caller is expected
  /// to fall back to `SimplifyBasic()`.
  ///
  /// \pre `value_set` is sorted
  /// \pre `value_set` contains no duplicates
  /// \pre `value_set` contains no nulls
  template <typename T>
  Result<std::shared_ptr<Array>> SimplifyOptimized() {
    if (guarantee.nullable) return Status::Invalid();
    if (null_matching == SetLookupOptions::INCONCLUSIVE) return Status::Invalid();
    // `x != bound` does not select a contiguous range of a sorted set, so it
    // cannot be expressed as a slice; let the caller use the linear scan.
    if (guarantee.cmp == Comparison::NOT_EQUAL) return Status::Invalid();
    // An empty set cannot be narrowed further. Returning early also keeps the
    // `compare(lo)` probe below from reading index 0 of an empty array.
    if (value_set->length() == 0) return value_set;

    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar_bound,
                          guarantee.bound.scalar()->CastTo(value_set->type()));
    auto bound = internal::UnboxScalar<typename T::TypeClass>::Unbox(*scalar_bound);
    // Downcast once instead of on every probe.
    const auto typed_values = checked_pointer_cast<T>(value_set);
    auto compare = [&](size_t i) -> Comparison::type {
      DCHECK(typed_values->IsValid(i));
      auto value = typed_values->GetView(i);
      if (value == bound) return Comparison::EQUAL;
      return value < bound ? Comparison::LESS : Comparison::GREATER;
    };

    // Find the last index whose value is <= bound (or 0 if there is none).
    size_t lo = 0;
    size_t hi = value_set->length();
    while (lo + 1 < hi) {
      size_t mid = (lo + hi) / 2;
      Comparison::type cmp = compare(mid);
      if (cmp & Comparison::LESS_EQUAL) {
        lo = mid;
      } else {
        hi = mid;
      }
    }

    // `pivot` is the index of the first value >= bound; `found` tells whether
    // that value equals the bound.
    Comparison::type cmp = compare(lo);
    size_t pivot = lo + (cmp == Comparison::LESS ? 1 : 0);
    bool found = cmp == Comparison::EQUAL;

    switch (guarantee.cmp) {
      case Comparison::EQUAL:
        return value_set->Slice(pivot, found ? 1 : 0);
      case Comparison::LESS:
        return value_set->Slice(0, pivot);
      case Comparison::LESS_EQUAL:
        return value_set->Slice(0, pivot + (found ? 1 : 0));
      case Comparison::GREATER:
        return value_set->Slice(pivot + (found ? 1 : 0));
      case Comparison::GREATER_EQUAL:
        return value_set->Slice(pivot);
      case Comparison::NOT_EQUAL:
      case Comparison::NA:
        break;
    }
    // NOT_EQUAL was rejected above, so only NA or an out-of-range value gets
    // here. The explicit return also avoids falling off a non-void function.
    DCHECK(false);
    return Status::Invalid("Invalid comparison");
  }

  /// Narrow `value_set` so that it only contains entries compatible with
  /// `guarantee`.
  ///
  /// \param[in] value_set the `is_in` value set to narrow
  /// \param[in] guarantee the inequality known to hold for the looked-up value
  /// \param[in] null_matching the null matching behavior of the `is_in` call
  /// \param[in] enable_fast_simplification whether the binary search path may
  /// be used; its preconditions on `value_set` must hold when this is true
  /// \return the narrowed value set
  static Result<std::shared_ptr<Array>> Simplify(
      std::shared_ptr<Array> value_set, const Inequality& guarantee,
      SetLookupOptions::NullMatchingBehavior null_matching,
      bool enable_fast_simplification) {
    IsInValueSetSimplifier simplifier{value_set, guarantee, null_matching,
                                      enable_fast_simplification, nullptr};
    RETURN_NOT_OK(VisitArrayInline(*value_set, &simplifier));
    return simplifier.result;
  }

  // Value set being narrowed.
  std::shared_ptr<Array> value_set;
  // Guarantee to narrow against; must outlive the simplifier.
  const Inequality& guarantee;
  // Null matching behavior of the originating `is_in` call.
  SetLookupOptions::NullMatchingBehavior null_matching;
  // Whether the binary search path may be attempted.
  bool enable_fast_simplification;
  // Output of the visit.
  std::shared_ptr<Array> result;
};

/// Simplify an `is_in` call against a list of inequality guarantees.
///
/// Simplification is done across all guarantee conjunction members at once to
/// avoid the cost of repeatedly binding the simplified expression, which is
/// linear in the size of the `is_in` value set.
///
/// \pre `is_in_call` is a call to the `is_in` function
/// \param[in] is_in_call the call to simplify
/// \param[in] guarantee_conjunction_members members of the guarantee
/// conjunction; members that are not inequalities on the looked-up field are
/// ignored
/// \return a simplified expression, or nullopt if no simplification occurred.
Result<std::optional<Expression>> SimplifyIsInWithGuarantees(
    const Expression::Call* is_in_call,
    const std::vector<Expression>& guarantee_conjunction_members) {
  DCHECK_EQ(is_in_call->function_name, "is_in");
  DCHECK_EQ(is_in_call->arguments.size(), 1);

  auto options = checked_pointer_cast<SetLookupOptions>(is_in_call->options);

  const auto& lhs = Comparison::StripOrderPreservingCasts(is_in_call->arguments[0]);
  if (!lhs.field_ref()) return std::nullopt;

  // Collect the inequalities that constrain the field being looked up.
  std::vector<Inequality> guarantees;
  for (const Expression& guarantee : guarantee_conjunction_members) {
    std::optional<Inequality> inequality = Inequality::ExtractOne(guarantee);
    if (!inequality) continue;
    if (inequality->target != *lhs.field_ref()) continue;
    guarantees.emplace_back(std::move(*inequality));
  }
  // Without an applicable guarantee there is nothing to simplify; bail out
  // before materializing (and possibly concatenating) the value set.
  if (guarantees.empty()) return std::nullopt;

  // A single non-nullable guarantee makes the whole conjunction non-nullable.
  bool guaranteed_non_nullable =
      std::any_of(guarantees.begin(), guarantees.end(),
                  [](const Inequality& guarantee) { return !guarantee.nullable; });

  std::shared_ptr<Array> simplified_value_set;
  bool enable_fast_simplification = false;
  if (guaranteed_non_nullable &&
      options->null_matching_behavior != SetLookupOptions::INCONCLUSIVE) {
    auto state =
        checked_pointer_cast<internal::SetLookupStateBase>(is_in_call->kernel_state);
    // The kernel state may be absent; only use the pre-sorted value set when
    // it is actually available.
    if (state) {
      simplified_value_set = state->sorted_and_unique_value_set;
      enable_fast_simplification = static_cast<bool>(simplified_value_set);
    }
  }
  if (!simplified_value_set) {
    if (options->value_set.is_array()) {
      simplified_value_set = options->value_set.make_array();
    } else if (options->value_set.is_chunked_array()) {
      ARROW_ASSIGN_OR_RAISE(simplified_value_set,
                            Concatenate(options->value_set.chunked_array()->chunks()));
    } else {
      return Status::Invalid("`is_in` value set must be an array or chunked array");
    }
  }

  for (Inequality& guarantee : guarantees) {
    // An empty value set cannot shrink any further, so stop early. This also
    // avoids handing an empty array to the simplifier.
    if (simplified_value_set->length() == 0) break;
    if (guaranteed_non_nullable) guarantee.nullable = false;
    ARROW_ASSIGN_OR_RAISE(simplified_value_set, IsInValueSetSimplifier::Simplify(
                                                    simplified_value_set, guarantee,
                                                    options->null_matching_behavior,
                                                    enable_fast_simplification));
  }

  // Nothing survived: the predicate is unsatisfiable under the guarantees.
  if (simplified_value_set->length() == 0) return literal(false);
  // Nothing was removed: report that no simplification occurred.
  if (simplified_value_set->length() == options->value_set.length()) return std::nullopt;

  // Rebuild and rebind the call around the reduced value set.
  ExecContext exec_context;
  Expression::Call simplified_call;
  simplified_call.function_name = "is_in";
  simplified_call.arguments = is_in_call->arguments;
  simplified_call.options = std::make_shared<SetLookupOptions>(
      simplified_value_set, options->null_matching_behavior);
  ARROW_ASSIGN_OR_RAISE(Expression simplified_expr,
                        BindNonRecursive(std::move(simplified_call),
                                         /*insert_implicit_casts=*/false, &exec_context));
  return simplified_expr;
}

} // namespace

Result<Expression> SimplifyWithGuarantee(Expression expr,
Expand Down Expand Up @@ -1464,6 +1602,24 @@ Result<Expression> SimplifyWithGuarantee(Expression expr,
}
}

auto simplify_is_in = [&](Expression expr, ...) -> Result<Expression> {
if (expr.call() && expr.call()->function_name == "is_in") {
ARROW_ASSIGN_OR_RAISE(auto simplified,
SimplifyIsInWithGuarantees(expr.call(), conjunction_members));
return simplified.value_or(expr);
} else {
return expr;
}
};
ARROW_ASSIGN_OR_RAISE(
auto simplified,
ModifyExpression(
std::move(expr), [](Expression expr) { return expr; }, simplify_is_in));
if (!Identical(simplified, expr)) {
expr = std::move(simplified);
RETURN_NOT_OK(CanonicalizeAndFoldConstants());
}

return expr;
}

Expand Down
Loading
Loading