Skip to content

Commit

Permalink
avoid repeated binds when simplifying is_in
Browse files Browse the repository at this point in the history
  • Loading branch information
larry98 committed Oct 1, 2024
1 parent ac6d7e8 commit 2bef6fd
Showing 1 changed file with 100 additions and 72 deletions.
172 changes: 100 additions & 72 deletions cpp/src/arrow/compute/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1243,72 +1243,6 @@ struct Inequality {
/*insert_implicit_casts=*/false, &exec_context);
}

/// Simplify an `is_in` call against an inequality guarantee.
///
/// We avoid the complexity of fully simplifying EQUAL comparisons to true
/// literals (e.g., 'x is_in [1, 2, 3]' given the guarantee 'x = 2') due to
/// potential complications with null matching behavior. This is ok for the
/// predicate pushdown use case because the overall aim is to simplify to an
/// unsatisfiable expression.
///
/// \pre `is_in_call` is a call to the `is_in` function
/// \return a simplified expression, or nullopt if no simplification occurred
static Result<std::optional<Expression>> SimplifyIsIn(
const Inequality& guarantee, const Expression::Call* is_in_call) {
DCHECK_EQ(is_in_call->function_name, "is_in");

auto options = checked_pointer_cast<SetLookupOptions>(is_in_call->options);

const auto& lhs = Comparison::StripOrderPreservingCasts(is_in_call->arguments[0]);
if (!lhs.field_ref()) return std::nullopt;
if (*lhs.field_ref() != guarantee.target) return std::nullopt;

FilterOptions::NullSelectionBehavior null_selection;
switch (options->null_matching_behavior) {
case SetLookupOptions::MATCH:
null_selection =
guarantee.nullable ? FilterOptions::EMIT_NULL : FilterOptions::DROP;
break;
case SetLookupOptions::SKIP:
null_selection = FilterOptions::DROP;
break;
case SetLookupOptions::EMIT_NULL:
if (guarantee.nullable) return std::nullopt;
null_selection = FilterOptions::DROP;
break;
case SetLookupOptions::INCONCLUSIVE:
if (guarantee.nullable) return std::nullopt;
ARROW_ASSIGN_OR_RAISE(Datum is_null, IsNull(options->value_set));
ARROW_ASSIGN_OR_RAISE(Datum any_null, Any(is_null));
if (any_null.scalar_as<BooleanScalar>().value) return std::nullopt;
null_selection = FilterOptions::DROP;
break;
}

std::string func_name = Comparison::GetName(guarantee.cmp);
DCHECK_NE(func_name, "na");
std::vector<Datum> args{options->value_set, guarantee.bound};
ARROW_ASSIGN_OR_RAISE(Datum filter_mask, CallFunction(func_name, args));
FilterOptions filter_options(null_selection);
ARROW_ASSIGN_OR_RAISE(Datum simplified_value_set,
Filter(options->value_set, filter_mask, filter_options));

if (simplified_value_set.length() == 0) return literal(false);
if (simplified_value_set.length() == options->value_set.length()) return std::nullopt;

ExecContext exec_context;
Expression::Call simplified_call;
simplified_call.function_name = "is_in";
simplified_call.arguments = is_in_call->arguments;
simplified_call.options = std::make_shared<SetLookupOptions>(
simplified_value_set, options->null_matching_behavior);
ARROW_ASSIGN_OR_RAISE(
Expression simplified_expr,
BindNonRecursive(std::move(simplified_call),
/*insert_implicit_casts=*/false, &exec_context));
return simplified_expr;
}

/// \brief Simplify the given expression given this inequality as a guarantee.
Result<Expression> Simplify(Expression expr) {
const auto& guarantee = *this;
Expand All @@ -1325,12 +1259,6 @@ struct Inequality {
return call->function_name == "is_valid" ? literal(true) : literal(false);
}

if (call->function_name == "is_in") {
ARROW_ASSIGN_OR_RAISE(std::optional<Expression> result,
SimplifyIsIn(guarantee, call));
return result.value_or(expr);
}

auto cmp = Comparison::Get(expr);
if (!cmp) return expr;

Expand Down Expand Up @@ -1416,6 +1344,88 @@ Result<Expression> SimplifyIsValidGuarantee(Expression expr,
});
}

/// Simplify an `is_in` value set against a single inequality guarantee.
///
/// We avoid the complexity of fully simplifying EQUAL comparisons to true
/// literals (e.g., 'x is_in [1, 2, 3]' given the guarantee 'x = 2') due to
/// potential complications with null matching behavior. This is ok for the
/// predicate pushdown use case because the overall aim is to simplify to an
/// unsatisfiable expression.
Result<Datum> SimplifyIsInValueSet(Datum value_set, const Inequality& guarantee,
SetLookupOptions::NullMatchingBehavior null_matching) {
FilterOptions::NullSelectionBehavior null_selection;
switch (null_matching) {
case SetLookupOptions::MATCH:
null_selection =
guarantee.nullable ? FilterOptions::EMIT_NULL : FilterOptions::DROP;
break;
case SetLookupOptions::SKIP:
null_selection = FilterOptions::DROP;
break;
case SetLookupOptions::EMIT_NULL:
if (guarantee.nullable) return value_set;
null_selection = FilterOptions::DROP;
break;
case SetLookupOptions::INCONCLUSIVE:
if (guarantee.nullable) return value_set;
ARROW_ASSIGN_OR_RAISE(Datum is_null, IsNull(value_set));
ARROW_ASSIGN_OR_RAISE(Datum any_null, Any(is_null));
if (any_null.scalar_as<BooleanScalar>().value) return value_set;
null_selection = FilterOptions::DROP;
break;
}

std::string func_name = Comparison::GetName(guarantee.cmp);
DCHECK_NE(func_name, "na");
std::vector<Datum> args{value_set, guarantee.bound};
ARROW_ASSIGN_OR_RAISE(Datum filter_mask, CallFunction(func_name, args));
FilterOptions filter_options(null_selection);
ARROW_ASSIGN_OR_RAISE(Datum simplified_value_set,
Filter(value_set, filter_mask, filter_options));
return simplified_value_set;
}

/// Simplify an `is_in` call against a list of inequality guarantees.
///
/// Simplification is done across all guarantee conjunction members at once to
/// avoid the cost of repeatedly binding the simplified expression, which is
/// linear in the size of the `is_in` value set.
Result<std::optional<Expression>> SimplifyIsInWithGuarantees(
const Expression::Call* is_in_call,
const std::vector<Expression>& guarantee_conjunction_members) {
DCHECK_EQ(is_in_call->function_name, "is_in");
DCHECK_EQ(is_in_call->arguments.size(), 1);

auto options = checked_pointer_cast<SetLookupOptions>(is_in_call->options);

const auto& lhs = Comparison::StripOrderPreservingCasts(is_in_call->arguments[0]);
if (!lhs.field_ref()) return std::nullopt;

Datum simplified_value_set = options->value_set;
for (const Expression& guarantee : guarantee_conjunction_members) {
std::optional<Inequality> inequality = Inequality::ExtractOne(guarantee);
if (!inequality) continue;
if (inequality->target != *lhs.field_ref()) continue;
ARROW_ASSIGN_OR_RAISE(simplified_value_set,
SimplifyIsInValueSet(simplified_value_set, *inequality,
options->null_matching_behavior));
}

if (simplified_value_set.length() == 0) return literal(false);
if (simplified_value_set.length() == options->value_set.length()) return std::nullopt;

ExecContext exec_context;
Expression::Call simplified_call;
simplified_call.function_name = "is_in";
simplified_call.arguments = is_in_call->arguments;
simplified_call.options = std::make_shared<SetLookupOptions>(
simplified_value_set, options->null_matching_behavior);
ARROW_ASSIGN_OR_RAISE(Expression simplified_expr,
BindNonRecursive(std::move(simplified_call),
/*insert_implicit_casts=*/false, &exec_context));
return simplified_expr;
}

} // namespace

Result<Expression> SimplifyWithGuarantee(Expression expr,
Expand Down Expand Up @@ -1464,6 +1474,24 @@ Result<Expression> SimplifyWithGuarantee(Expression expr,
}
}

auto simplify_is_in = [&](Expression expr, ...) -> Result<Expression> {
if (expr.call() && expr.call()->function_name == "is_in") {
ARROW_ASSIGN_OR_RAISE(auto simplified,
SimplifyIsInWithGuarantees(expr.call(), conjunction_members));
return simplified.value_or(expr);
} else {
return expr;
}
};
ARROW_ASSIGN_OR_RAISE(
auto simplified,
ModifyExpression(
std::move(expr), [](Expression expr) { return expr; }, simplify_is_in));
if (!Identical(simplified, expr)) {
expr = std::move(simplified);
RETURN_NOT_OK(CanonicalizeAndFoldConstants());
}

return expr;
}

Expand Down

0 comments on commit 2bef6fd

Please sign in to comment.