Skip to content

Commit

Permalink
implement inexact dispatch on guarantee and value set types
Browse files: browse the repository at this point in the history
  • Loading branch information
larry98 committed Aug 1, 2024
1 parent b60414b commit 862aafa
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 48 deletions.
67 changes: 35 additions & 32 deletions cpp/src/arrow/compute/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "arrow/compute/expression_internal.h"
#include "arrow/compute/function_internal.h"
#include "arrow/compute/util.h"
#include "arrow/compute/kernels/codegen_internal.h"
#include "arrow/compute/kernels/set_lookup_internal.h"
#include "arrow/io/memory.h"
#include "arrow/ipc/reader.h"
Expand Down Expand Up @@ -1185,9 +1186,11 @@ Result<std::shared_ptr<Array>> PrepareIsInValueSet(std::shared_ptr<Array> value_
///
/// \pre `is_in_call` is a call to the `is_in` function
/// \return the value set to be simplified, guaranteed to be sorted with no
/// duplicate or null values
/// duplicate or null values and cast to the given type
Result<std::shared_ptr<Array>> GetIsInValueSetForSimplification(
const Expression::Call* is_in_call, SimplificationContext& context) {
const Expression::Call* is_in_call,
const TypeHolder& type,
SimplificationContext& context) {
DCHECK_EQ(is_in_call->function_name, "is_in");
std::shared_ptr<Array>& value_set = context.is_in_value_sets[is_in_call];
if (!value_set) {
Expand All @@ -1202,6 +1205,11 @@ Result<std::shared_ptr<Array>> GetIsInValueSetForSimplification(
ARROW_ASSIGN_OR_RAISE(state->sorted_and_unique_value_set,
PrepareIsInValueSet(unprepared_value_set));
}
if (!state->sorted_and_unique_value_set->type()->Equals(*type)) {
ARROW_ASSIGN_OR_RAISE(
state->sorted_and_unique_value_set,
Cast(*state->sorted_and_unique_value_set, type, CastOptions::Safe()));
}
value_set = state->sorted_and_unique_value_set;
}
return value_set;
Expand Down Expand Up @@ -1317,25 +1325,17 @@ struct Inequality {
/// \return a simplified value set, or a bool if the simplification of the value set
/// means the whole is_in expr can become a boolean literal.
template <typename ArrowType>
static std::variant<std::shared_ptr<Array>, bool> SimplifyIsInValueSet(
static Result<std::variant<std::shared_ptr<Array>, bool>> SimplifyIsInValueSet(
const Inequality& guarantee, std::shared_ptr<Array> value_set) {
using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
using ScalarType = typename TypeTraits<ArrowType>::ScalarType;
using CType = decltype(checked_pointer_cast<ArrayType>(value_set)->GetView(0));

DCHECK(guarantee.bound.is_scalar());
DCHECK_EQ(guarantee.bound.type()->id(), value_set->type_id());

if (value_set->length() == 0) return false;

CType bound;
if constexpr (std::is_same_v<std::shared_ptr<Buffer>,
typename ScalarType::ValueType>) {
bound = guarantee.bound.scalar_as<ScalarType>().view();
} else {
bound = guarantee.bound.scalar_as<ScalarType>().value;
}

ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar_bound,
guarantee.bound.scalar()->CastTo(value_set->type()));
auto bound = internal::UnboxScalar<ArrowType>::Unbox(*scalar_bound);
auto compare = [&bound, &value_set](size_t i) -> Comparison::type {
DCHECK(value_set->IsValid(i));
auto value = checked_pointer_cast<ArrayType>(value_set)->GetView(i);
Expand Down Expand Up @@ -1378,7 +1378,7 @@ struct Inequality {
case Comparison::NOT_EQUAL:
case Comparison::NA:
DCHECK(false);
break;
return Status::Invalid("Invalid comparison");
}

if (value_set->length() == 0) return false;
Expand Down Expand Up @@ -1412,27 +1412,29 @@ struct Inequality {
if (*lhs.field_ref() != guarantee.target) return std::nullopt;

auto options = checked_pointer_cast<SetLookupOptions>(is_in_call->options);
Type::type type = options->value_set.type()->id();

// For now, we abort simplification if the guarantee bound's type does not
// exactly match the value set's type.
if (guarantee.bound.type()->id() != type) return std::nullopt;
std::array<TypeHolder, 2> types{guarantee.bound.type().get(),
options->value_set.type().get()};
TypeHolder cmp_type;
if (types[0] == types[1]) cmp_type = types[0];
if (!cmp_type) cmp_type = internal::CommonNumeric(types.data(), types.size());
if (!cmp_type) cmp_type = internal::CommonTemporal(types.data(), types.size());
if (!cmp_type) cmp_type = internal::CommonBinary(types.data(), types.size());
if (!cmp_type) return std::nullopt;

std::variant<std::shared_ptr<Array>, bool> result;
auto simplify_value_set = [&](auto type) -> Status {
using T = decltype(type);
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> value_set,
GetIsInValueSetForSimplification(is_in_call, context));
result = SimplifyIsInValueSet<T>(guarantee, value_set);
return Status::OK();
};

#define CASE(TYPE_CLASS) \
case TYPE_CLASS##Type::type_id: \
RETURN_NOT_OK(simplify_value_set(TYPE_CLASS##Type{})); \
break;
#define CASE(TYPE_CLASS) \
case TYPE_CLASS##Type::type_id: { \
ARROW_ASSIGN_OR_RAISE( \
std::shared_ptr<Array> value_set, \
GetIsInValueSetForSimplification(is_in_call, cmp_type, context)); \
ARROW_ASSIGN_OR_RAISE( \
result, \
SimplifyIsInValueSet<TYPE_CLASS##Type>(guarantee, value_set)); \
break; \
}

switch (type) {
switch (cmp_type.id()) {
CASE(UInt8)
CASE(Int8)
CASE(UInt16)
Expand All @@ -1452,6 +1454,7 @@ struct Inequality {
CASE(String)
CASE(LargeString)
CASE(StringView)
CASE(FixedSizeBinary)
default:
return std::nullopt;
}
Expand Down
41 changes: 25 additions & 16 deletions cpp/src/arrow/compute/expression_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1617,58 +1617,67 @@ TEST(Expression, SimplifyWithComparisonAndNullableCaveat) {
}

TEST(Expression, SimplifyIsIn) {
auto is_in = [](Expression field, std::string json_array) {
SetLookupOptions options{ArrayFromJSON(int32(), json_array),
auto is_in = [](Expression field, std::shared_ptr<DataType> value_set_type,
std::string json_array) {
SetLookupOptions options{ArrayFromJSON(value_set_type, json_array),
SetLookupOptions::MATCH};
return call("is_in", {field}, options);
};

Simplify{is_in(field_ref("i32"), "[]")}
Simplify{is_in(field_ref("i32"), int32(), "[]")}
.WithGuarantee(greater(field_ref("i32"), literal(2)))
.Expect(false);

Simplify{is_in(field_ref("i32"), "[null]")}
Simplify{is_in(field_ref("i32"), int32(), "[null]")}
.WithGuarantee(greater(field_ref("i32"), literal(2)))
.Expect(false);

Simplify{is_in(field_ref("i32"), "[1,3,5,7,9]")}
Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]")}
.WithGuarantee(equal(field_ref("i32"), literal(7)))
.Expect(true);

Simplify{is_in(field_ref("i32"), "[1,3,5,7,9]")}
Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]")}
.WithGuarantee(equal(field_ref("i32"), literal(6)))
.Expect(false);

Simplify{is_in(field_ref("i32"), "[1,3,5,7,9]")}
Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]")}
.WithGuarantee(greater(field_ref("i32"), literal(3)))
.Expect(is_in(field_ref("i32"), "[5,7,9]"));
.Expect(is_in(field_ref("i32"), int32(), "[5,7,9]"));

Simplify{is_in(field_ref("i32"), "[1,null,3,5,null,7,9]")}
Simplify{is_in(field_ref("i32"), int32(), "[1,null,3,5,null,7,9]")}
.WithGuarantee(greater(field_ref("i32"), literal(3)))
.Expect(is_in(field_ref("i32"), "[5,7,9]"));
.Expect(is_in(field_ref("i32"), int32(), "[5,7,9]"));

Simplify{is_in(field_ref("i32"), "[1,3,5,7,9]")}
Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]")}
.WithGuarantee(greater(field_ref("i32"), literal(9)))
.Expect(false);

Simplify{is_in(field_ref("i32"), "[1,3,5,7,9]")}
Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]")}
.WithGuarantee(less_equal(field_ref("i32"), literal(0)))
.Expect(false);

Simplify{is_in(field_ref("i32"), "[1,3,5,7,9]")}
Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]")}
.WithGuarantee(greater(field_ref("i32"), literal(0)))
.ExpectUnchanged();

Simplify{is_in(field_ref("i32"), "[1,3,5,7,9]")}
Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]")}
.WithGuarantee(
or_(equal(field_ref("i32"), literal(3)), is_null(field_ref("i32"))))
.ExpectUnchanged();

Simplify{is_in(field_ref("i32"), "[1,3,5,7,9]")}
Simplify{is_in(field_ref("i32"), int32(), "[1,3,5,7,9]")}
.WithGuarantee(
and_(less_equal(field_ref("i32"), literal(7)),
greater(field_ref("i32"), literal(4))))
.Expect(is_in(field_ref("i32"), "[5,7]"));
.Expect(is_in(field_ref("i32"), int32(), "[5,7]"));

Simplify{is_in(field_ref("u32"), int8(), "[1,3,5,7,9]")}
.WithGuarantee(greater(field_ref("u32"), literal(3)))
.Expect(is_in(field_ref("u32"), int32(), "[5,7,9]"));

Simplify{is_in(field_ref("u32"), int64(), "[1,3,5,7,9]")}
.WithGuarantee(greater(field_ref("u32"), literal(3)))
.Expect(is_in(field_ref("u32"), int64(), "[5,7,9]"));
}

TEST(Expression, SimplifyThenExecute) {
Expand Down

0 comments on commit 862aafa

Please sign in to comment.