Skip to content

Commit

Permalink
Fix scalar tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wesm committed Jul 29, 2022
1 parent 6537e05 commit c6d8411
Show file tree
Hide file tree
Showing 16 changed files with 145 additions and 119 deletions.
44 changes: 44 additions & 0 deletions cpp/src/arrow/compute/kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,42 @@ std::shared_ptr<TypeMatcher> SameTypeId(Type::type type_id) {
return std::make_shared<SameTypeIdMatcher>(type_id);
}

class ListOfMatcher : public TypeMatcher {
public:
explicit ListOfMatcher(Type::type accepted_id, Type::type accepted_list_id)
: accepted_id_(accepted_id), accepted_list_id_(accepted_list_id) {}

bool Matches(const DataType& type) const override {
if (type.id() != accepted_list_id_) return false;
return checked_cast<const BaseListType&>(type).value_type()->id() == accepted_id_;
}

std::string ToString() const override {
std::stringstream ss;
ss << "list of Type::" << ::arrow::internal::ToString(accepted_id_);
return ss.str();
}

bool Equals(const TypeMatcher& other) const override {
if (this == &other) {
return true;
}
auto casted = dynamic_cast<const ListOfMatcher*>(&other);
if (casted == nullptr) {
return false;
}
return this->accepted_id_ == casted->accepted_id_;
}

private:
Type::type accepted_id_;
Type::type accepted_list_id_;
};

std::shared_ptr<TypeMatcher> ListOf(Type::type type_id, Type::type list_type_id) {
return std::make_shared<ListOfMatcher>(type_id, list_type_id);
}

template <typename ArrowType>
class TimeUnitMatcher : public TypeMatcher {
using ThisType = TimeUnitMatcher<ArrowType>;
Expand Down Expand Up @@ -280,6 +316,10 @@ std::shared_ptr<TypeMatcher> FixedSizeBinaryLike() {
// ----------------------------------------------------------------------
// InputType

InputType::InputType(const std::shared_ptr<DataType>& type) : InputType(type.get()) {
DCHECK(is_parameter_free(type->id()));
}

size_t InputType::Hash() const {
size_t result = kHashSeed;
hash_combine(result, static_cast<int>(kind_));
Expand Down Expand Up @@ -369,6 +409,10 @@ const TypeMatcher& InputType::type_matcher() const {
// ----------------------------------------------------------------------
// OutputType

OutputType::OutputType(const std::shared_ptr<DataType>& type) : OutputType(type.get()) {
DCHECK(is_parameter_free(type->id()));
}

Result<TypeHolder> OutputType::Resolve(KernelContext* ctx,
const std::vector<TypeHolder>& types) const {
if (kind_ == OutputType::FIXED) {
Expand Down
9 changes: 5 additions & 4 deletions cpp/src/arrow/compute/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ ARROW_EXPORT std::shared_ptr<TypeMatcher> FixedSizeBinaryLike();
// Type)
ARROW_EXPORT std::shared_ptr<TypeMatcher> Primitive();

ARROW_EXPORT std::shared_ptr<TypeMatcher> ListOf(Type::type type_id,
Type::type list_type = Type::LIST);

} // namespace match

/// \brief An object used for type-checking arguments to be passed to a kernel
Expand Down Expand Up @@ -169,8 +172,7 @@ class ARROW_EXPORT InputType {
InputType(const DataType* type) // NOLINT implicit construction
: kind_(EXACT_TYPE), type_(type) {}

InputType(const std::shared_ptr<DataType>& type) // NOLINT implicit construction
: InputType(type.get()) {}
InputType(const std::shared_ptr<DataType>& type); // NOLINT implicit construction

/// \brief Use the passed TypeMatcher to type check.
InputType(std::shared_ptr<TypeMatcher> type_matcher) // NOLINT implicit construction
Expand Down Expand Up @@ -268,8 +270,7 @@ class ARROW_EXPORT OutputType {
: kind_(FIXED), type_(type) {}

/// \brief Output an exact type
OutputType(const std::shared_ptr<DataType>& type) // NOLINT implicit construction
: OutputType(type.get()) {}
OutputType(const std::shared_ptr<DataType>& type); // NOLINT implicit construction

/// \brief Output a computed type depending on actual input types
OutputType(Resolver resolver) // NOLINT implicit construction
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/compute/kernels/aggregate_mode.cc
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ VectorKernel NewModeKernel(Type::type in_type, ArrayKernelExec exec,
kernel.can_execute_chunkwise = false;
kernel.output_chunked = false;
kernel.signature = KernelSignature::Make({InputType(in_type)}, ModeType);
kernel.exec = std::move(exec);
kernel.exec = exec;
kernel.exec_chunked = exec_chunked;
return kernel;
}
Expand Down
40 changes: 16 additions & 24 deletions cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1984,7 +1984,7 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
for (auto unit : TimeUnit::values()) {
InputType in_type(match::DurationTypeUnit(unit));
auto exec = ArithmeticExecFromOp<ScalarBinaryEqualTypes, Add>(Type::DURATION);
DCHECK_OK(add->AddKernel({in_type, in_type}, duration(unit), std::move(exec)));
DCHECK_OK(add->AddKernel({in_type, in_type}, duration(unit), exec));
}

AddArithmeticFunctionTimeDuration<AddTimeDuration>(add);
Expand All @@ -2011,8 +2011,7 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
for (auto unit : TimeUnit::values()) {
InputType in_type(match::DurationTypeUnit(unit));
auto exec = ArithmeticExecFromOp<ScalarBinaryEqualTypes, AddChecked>(Type::DURATION);
DCHECK_OK(
add_checked->AddKernel({in_type, in_type}, duration(unit), std::move(exec)));
DCHECK_OK(add_checked->AddKernel({in_type, in_type}, duration(unit), exec));
}

AddArithmeticFunctionTimeDuration<AddTimeDurationChecked>(add_checked);
Expand All @@ -2029,37 +2028,36 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
InputType in_type(match::TimestampTypeUnit(unit));
auto exec = ArithmeticExecFromOp<ScalarBinaryEqualTypes, Subtract>(Type::TIMESTAMP);
DCHECK_OK(subtract->AddKernel({in_type, in_type},
OutputType::Resolver(ResolveTemporalOutput),
std::move(exec)));
OutputType::Resolver(ResolveTemporalOutput), exec));
}

// Add subtract(timestamp, duration) -> timestamp
for (auto unit : TimeUnit::values()) {
InputType in_type(match::TimestampTypeUnit(unit));
auto exec = ScalarBinary<Int64Type, Int64Type, Int64Type, Subtract>::Exec;
DCHECK_OK(subtract->AddKernel({in_type, duration(unit)}, OutputType(FirstType),
std::move(exec)));
DCHECK_OK(
subtract->AddKernel({in_type, duration(unit)}, OutputType(FirstType), exec));
}

// Add subtract(duration, duration) -> duration
for (auto unit : TimeUnit::values()) {
InputType in_type(match::DurationTypeUnit(unit));
auto exec = ArithmeticExecFromOp<ScalarBinaryEqualTypes, Subtract>(Type::DURATION);
DCHECK_OK(subtract->AddKernel({in_type, in_type}, duration(unit), std::move(exec)));
DCHECK_OK(subtract->AddKernel({in_type, in_type}, duration(unit), exec));
}

// Add subtract(time32, time32) -> duration
for (auto unit : {TimeUnit::SECOND, TimeUnit::MILLI}) {
InputType in_type(match::Time32TypeUnit(unit));
auto exec = ScalarBinaryEqualTypes<Int64Type, Int32Type, Subtract>::Exec;
DCHECK_OK(subtract->AddKernel({in_type, in_type}, duration(unit), std::move(exec)));
DCHECK_OK(subtract->AddKernel({in_type, in_type}, duration(unit), exec));
}

// Add subtract(time64, time64) -> duration
for (auto unit : {TimeUnit::MICRO, TimeUnit::NANO}) {
InputType in_type(match::Time64TypeUnit(unit));
auto exec = ScalarBinaryEqualTypes<Int64Type, Int64Type, Subtract>::Exec;
DCHECK_OK(subtract->AddKernel({in_type, in_type}, duration(unit), std::move(exec)));
DCHECK_OK(subtract->AddKernel({in_type, in_type}, duration(unit), exec));
}

// Add subtract(date32, date32) -> duration(TimeUnit::SECOND)
Expand Down Expand Up @@ -2088,26 +2086,24 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
InputType in_type(match::TimestampTypeUnit(unit));
auto exec =
ArithmeticExecFromOp<ScalarBinaryEqualTypes, SubtractChecked>(Type::TIMESTAMP);
DCHECK_OK(subtract_checked->AddKernel({in_type, in_type},
OutputType::Resolver(ResolveTemporalOutput),
std::move(exec)));
DCHECK_OK(subtract_checked->AddKernel(
{in_type, in_type}, OutputType::Resolver(ResolveTemporalOutput), exec));
}

// Add subtract_checked(timestamp, duration) -> timestamp
for (auto unit : TimeUnit::values()) {
InputType in_type(match::TimestampTypeUnit(unit));
auto exec = ScalarBinary<Int64Type, Int64Type, Int64Type, SubtractChecked>::Exec;
DCHECK_OK(subtract_checked->AddKernel({in_type, duration(unit)},
OutputType(FirstType), std::move(exec)));
OutputType(FirstType), exec));
}

// Add subtract_checked(duration, duration) -> duration
for (auto unit : TimeUnit::values()) {
InputType in_type(match::DurationTypeUnit(unit));
auto exec =
ArithmeticExecFromOp<ScalarBinaryEqualTypes, SubtractChecked>(Type::DURATION);
DCHECK_OK(
subtract_checked->AddKernel({in_type, in_type}, duration(unit), std::move(exec)));
DCHECK_OK(subtract_checked->AddKernel({in_type, in_type}, duration(unit), exec));
}

// Add subtract_checked(date32, date32) -> duration(TimeUnit::SECOND)
Expand All @@ -2128,16 +2124,14 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
for (auto unit : {TimeUnit::SECOND, TimeUnit::MILLI}) {
InputType in_type(match::Time32TypeUnit(unit));
auto exec = ScalarBinaryEqualTypes<Int64Type, Int32Type, SubtractChecked>::Exec;
DCHECK_OK(
subtract_checked->AddKernel({in_type, in_type}, duration(unit), std::move(exec)));
DCHECK_OK(subtract_checked->AddKernel({in_type, in_type}, duration(unit), exec));
}

// Add subtract_checked(time64, time64) -> duration
for (auto unit : {TimeUnit::MICRO, TimeUnit::NANO}) {
InputType in_type(match::Time64TypeUnit(unit));
auto exec = ScalarBinaryEqualTypes<Int64Type, Int64Type, SubtractChecked>::Exec;
DCHECK_OK(
subtract_checked->AddKernel({in_type, in_type}, duration(unit), std::move(exec)));
DCHECK_OK(subtract_checked->AddKernel({in_type, in_type}, duration(unit), exec));
}

AddArithmeticFunctionTimeDuration<SubtractTimeDurationChecked>(subtract_checked);
Expand Down Expand Up @@ -2181,8 +2175,7 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
// Add divide(duration, int64) -> duration
for (auto unit : TimeUnit::values()) {
auto exec = ScalarBinaryNotNull<Int64Type, Int64Type, Int64Type, Divide>::Exec;
DCHECK_OK(
divide->AddKernel({duration(unit), int64()}, duration(unit), std::move(exec)));
DCHECK_OK(divide->AddKernel({duration(unit), int64()}, duration(unit), exec));
}
DCHECK_OK(registry->AddFunction(std::move(divide)));

Expand All @@ -2194,8 +2187,7 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
// Add divide_checked(duration, int64) -> duration
for (auto unit : TimeUnit::values()) {
auto exec = ScalarBinaryNotNull<Int64Type, Int64Type, Int64Type, DivideChecked>::Exec;
DCHECK_OK(divide_checked->AddKernel({duration(unit), int64()}, duration(unit),
std::move(exec)));
DCHECK_OK(divide_checked->AddKernel({duration(unit), int64()}, duration(unit), exec));
}

DCHECK_OK(registry->AddFunction(std::move(divide_checked)));
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ std::shared_ptr<CastFunction> GetCastToDecimal128() {
// Cast from integer
for (const DataType* in_ty : IntTypes()) {
auto exec = GenerateInteger<CastFunctor, Decimal128Type>(in_ty->id());
DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, sig_out_ty, std::move(exec)));
DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, sig_out_ty, exec));
}

// Cast from other decimal
Expand Down Expand Up @@ -706,7 +706,7 @@ std::shared_ptr<CastFunction> GetCastToDecimal256() {
// Cast from integer
for (const DataType* in_ty : IntTypes()) {
auto exec = GenerateInteger<CastFunctor, Decimal256Type>(in_ty->id());
DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, sig_out_ty, std::move(exec)));
DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, sig_out_ty, exec));
}

// Cast from other decimal
Expand Down
52 changes: 24 additions & 28 deletions cpp/src/arrow/compute/kernels/scalar_cast_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2214,10 +2214,10 @@ static void CheckStructToStruct(const std::vector<const DataType*>& value_types)
for (const DataType* dest_value_type : value_types) {
std::vector<std::string> field_names = {"a", "b"};
std::shared_ptr<Array> a1, b1, a2, b2;
a1 = ArrayFromJSON(src_value_type->GetSharedPtr(), "[1, 2, 3, 4, null]");
b1 = ArrayFromJSON(src_value_type->GetSharedPtr(), "[null, 7, 8, 9, 0]");
a2 = ArrayFromJSON(dest_value_type->GetSharedPtr(), "[1, 2, 3, 4, null]");
b2 = ArrayFromJSON(dest_value_type->GetSharedPtr(), "[null, 7, 8, 9, 0]");
a1 = ArrayFromJSON(src_value_type, "[1, 2, 3, 4, null]");
b1 = ArrayFromJSON(src_value_type, "[null, 7, 8, 9, 0]");
a2 = ArrayFromJSON(dest_value_type, "[1, 2, 3, 4, null]");
b2 = ArrayFromJSON(dest_value_type, "[null, 7, 8, 9, 0]");
ASSERT_OK_AND_ASSIGN(auto src, StructArray::Make({a1, b1}, field_names));
ASSERT_OK_AND_ASSIGN(auto dest, StructArray::Make({a2, b2}, field_names));

Expand All @@ -2244,20 +2244,18 @@ static void CheckStructToStructSubset(const std::vector<const DataType*>& value_
std::vector<std::string> field_names = {"a", "b", "c", "d", "e"};

std::shared_ptr<Array> a1, b1, c1, d1, e1;
auto sp_src_type = src_value_type->GetSharedPtr();
auto sp_dst_type = dest_value_type->GetSharedPtr();
a1 = ArrayFromJSON(sp_src_type, "[1, 2, 5]");
b1 = ArrayFromJSON(sp_src_type, "[3, 4, 7]");
c1 = ArrayFromJSON(sp_src_type, "[9, 11, 44]");
d1 = ArrayFromJSON(sp_src_type, "[6, 51, 49]");
e1 = ArrayFromJSON(sp_src_type, "[19, 17, 74]");
a1 = ArrayFromJSON(src_value_type, "[1, 2, 5]");
b1 = ArrayFromJSON(src_value_type, "[3, 4, 7]");
c1 = ArrayFromJSON(src_value_type, "[9, 11, 44]");
d1 = ArrayFromJSON(src_value_type, "[6, 51, 49]");
e1 = ArrayFromJSON(src_value_type, "[19, 17, 74]");

std::shared_ptr<Array> a2, b2, c2, d2, e2;
a2 = ArrayFromJSON(sp_dst_type, "[1, 2, 5]");
b2 = ArrayFromJSON(sp_dst_type, "[3, 4, 7]");
c2 = ArrayFromJSON(sp_dst_type, "[9, 11, 44]");
d2 = ArrayFromJSON(sp_dst_type, "[6, 51, 49]");
e2 = ArrayFromJSON(sp_dst_type, "[19, 17, 74]");
a2 = ArrayFromJSON(dest_value_type, "[1, 2, 5]");
b2 = ArrayFromJSON(dest_value_type, "[3, 4, 7]");
c2 = ArrayFromJSON(dest_value_type, "[9, 11, 44]");
d2 = ArrayFromJSON(dest_value_type, "[6, 51, 49]");
e2 = ArrayFromJSON(dest_value_type, "[19, 17, 74]");

ASSERT_OK_AND_ASSIGN(auto src,
StructArray::Make({a1, b1, c1, d1, e1}, field_names));
Expand Down Expand Up @@ -2345,20 +2343,18 @@ static void CheckStructToStructSubsetWithNulls(
std::vector<std::string> field_names = {"a", "b", "c", "d", "e"};

std::shared_ptr<Array> a1, b1, c1, d1, e1;
auto sp_src_type = src_value_type->GetSharedPtr();
auto sp_dst_type = dest_value_type->GetSharedPtr();
a1 = ArrayFromJSON(sp_src_type, "[1, 2, 5]");
b1 = ArrayFromJSON(sp_src_type, "[3, null, 7]");
c1 = ArrayFromJSON(sp_src_type, "[9, 11, 44]");
d1 = ArrayFromJSON(sp_src_type, "[6, 51, null]");
e1 = ArrayFromJSON(sp_src_type, "[null, 17, 74]");
a1 = ArrayFromJSON(src_value_type, "[1, 2, 5]");
b1 = ArrayFromJSON(src_value_type, "[3, null, 7]");
c1 = ArrayFromJSON(src_value_type, "[9, 11, 44]");
d1 = ArrayFromJSON(src_value_type, "[6, 51, null]");
e1 = ArrayFromJSON(src_value_type, "[null, 17, 74]");

std::shared_ptr<Array> a2, b2, c2, d2, e2;
a2 = ArrayFromJSON(sp_dst_type, "[1, 2, 5]");
b2 = ArrayFromJSON(sp_dst_type, "[3, null, 7]");
c2 = ArrayFromJSON(sp_dst_type, "[9, 11, 44]");
d2 = ArrayFromJSON(sp_dst_type, "[6, 51, null]");
e2 = ArrayFromJSON(sp_dst_type, "[null, 17, 74]");
a2 = ArrayFromJSON(dest_value_type, "[1, 2, 5]");
b2 = ArrayFromJSON(dest_value_type, "[3, null, 7]");
c2 = ArrayFromJSON(dest_value_type, "[9, 11, 44]");
d2 = ArrayFromJSON(dest_value_type, "[6, 51, null]");
e2 = ArrayFromJSON(dest_value_type, "[null, 17, 74]");

std::shared_ptr<Buffer> null_bitmap;
BitmapFromVector<int>({0, 1, 0}, &null_bitmap);
Expand Down
7 changes: 3 additions & 4 deletions cpp/src/arrow/compute/kernels/scalar_compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -428,20 +428,19 @@ std::shared_ptr<ScalarFunction> MakeCompareFunction(std::string name, FunctionDo
for (const DataType* ty : BaseBinaryTypes()) {
auto exec =
GenerateVarBinaryBase<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(*ty);
DCHECK_OK(func->AddKernel({ty, ty}, boolean(), std::move(exec)));
DCHECK_OK(func->AddKernel({ty, ty}, boolean(), exec));
}

for (const auto id : {Type::DECIMAL128, Type::DECIMAL256}) {
auto exec = GenerateDecimal<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(id);
DCHECK_OK(
func->AddKernel({InputType(id), InputType(id)}, boolean(), std::move(exec)));
DCHECK_OK(func->AddKernel({InputType(id), InputType(id)}, boolean(), exec));
}

{
auto exec =
applicator::ScalarBinaryEqualTypes<BooleanType, FixedSizeBinaryType, Op>::Exec;
auto ty = InputType(Type::FIXED_SIZE_BINARY);
DCHECK_OK(func->AddKernel({ty, ty}, boolean(), std::move(exec)));
DCHECK_OK(func->AddKernel({ty, ty}, boolean(), exec));
}

return func;
Expand Down
8 changes: 4 additions & 4 deletions cpp/src/arrow/compute/kernels/scalar_if_else.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2663,15 +2663,15 @@ void AddPrimitiveCaseWhenKernels(const std::shared_ptr<CaseWhenFunction>& scalar
const std::vector<const DataType*>& types) {
for (auto&& type : types) {
auto exec = GenerateTypeAgnosticPrimitive<CaseWhenFunctor>(*type);
AddCaseWhenKernel(scalar_function, type, std::move(exec));
AddCaseWhenKernel(scalar_function, type, exec);
}
}

void AddBinaryCaseWhenKernels(const std::shared_ptr<CaseWhenFunction>& scalar_function,
const std::vector<const DataType*>& types) {
for (auto&& type : types) {
auto exec = GenerateTypeAgnosticVarBinaryBase<CaseWhenFunctor>(*type);
AddCaseWhenKernel(scalar_function, type, std::move(exec));
AddCaseWhenKernel(scalar_function, type, exec);
}
}

Expand All @@ -2690,7 +2690,7 @@ void AddPrimitiveCoalesceKernels(const std::shared_ptr<ScalarFunction>& scalar_f
const std::vector<const DataType*>& types) {
for (auto&& type : types) {
auto exec = GenerateTypeAgnosticPrimitive<CoalesceFunctor>(*type);
AddCoalesceKernel(scalar_function, type, std::move(exec));
AddCoalesceKernel(scalar_function, type, exec);
}
}

Expand All @@ -2709,7 +2709,7 @@ void AddPrimitiveChooseKernels(const std::shared_ptr<ScalarFunction>& scalar_fun
const std::vector<const DataType*>& types) {
for (auto&& type : types) {
auto exec = GenerateTypeAgnosticPrimitive<ChooseFunctor>(*type);
AddChooseKernel(scalar_function, type, std::move(exec));
AddChooseKernel(scalar_function, type, exec);
}
}

Expand Down
Loading

0 comments on commit c6d8411

Please sign in to comment.