diff --git a/cpp/src/arrow/compute/kernel.cc b/cpp/src/arrow/compute/kernel.cc index 8d14575bb8e1a..bb8a1689c5f70 100644 --- a/cpp/src/arrow/compute/kernel.cc +++ b/cpp/src/arrow/compute/kernel.cc @@ -114,6 +114,42 @@ std::shared_ptr SameTypeId(Type::type type_id) { return std::make_shared(type_id); } +class ListOfMatcher : public TypeMatcher { + public: + explicit ListOfMatcher(Type::type accepted_id, Type::type accepted_list_id) + : accepted_id_(accepted_id), accepted_list_id_(accepted_list_id) {} + + bool Matches(const DataType& type) const override { + if (type.id() != accepted_list_id_) return false; + return checked_cast(type).value_type()->id() == accepted_id_; + } + + std::string ToString() const override { + std::stringstream ss; + ss << "list of Type::" << ::arrow::internal::ToString(accepted_id_); + return ss.str(); + } + + bool Equals(const TypeMatcher& other) const override { + if (this == &other) { + return true; + } + auto casted = dynamic_cast(&other); + if (casted == nullptr) { + return false; + } + return this->accepted_id_ == casted->accepted_id_; + } + + private: + Type::type accepted_id_; + Type::type accepted_list_id_; +}; + +std::shared_ptr ListOf(Type::type type_id, Type::type list_type_id) { + return std::make_shared(type_id, list_type_id); +} + template class TimeUnitMatcher : public TypeMatcher { using ThisType = TimeUnitMatcher; @@ -280,6 +316,10 @@ std::shared_ptr FixedSizeBinaryLike() { // ---------------------------------------------------------------------- // InputType +InputType::InputType(const std::shared_ptr& type) : InputType(type.get()) { + DCHECK(is_parameter_free(type->id())); +} + size_t InputType::Hash() const { size_t result = kHashSeed; hash_combine(result, static_cast(kind_)); @@ -369,6 +409,10 @@ const TypeMatcher& InputType::type_matcher() const { // ---------------------------------------------------------------------- // OutputType +OutputType::OutputType(const std::shared_ptr& type) : OutputType(type.get()) { + DCHECK(is_parameter_free(type->id())); +} + Result OutputType::Resolve(KernelContext* ctx, const std::vector& types) const { if (kind_ == OutputType::FIXED) { diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h index ed40bff606940..f3dd6984932ba 100644 --- a/cpp/src/arrow/compute/kernel.h +++ b/cpp/src/arrow/compute/kernel.h @@ -141,6 +141,9 @@ ARROW_EXPORT std::shared_ptr FixedSizeBinaryLike(); // Type) ARROW_EXPORT std::shared_ptr Primitive(); +ARROW_EXPORT std::shared_ptr ListOf(Type::type type_id, + Type::type list_type = Type::LIST); + } // namespace match /// \brief An object used for type-checking arguments to be passed to a kernel @@ -169,8 +172,7 @@ class ARROW_EXPORT InputType { InputType(const DataType* type) // NOLINT implicit construction : kind_(EXACT_TYPE), type_(type) {} - InputType(const std::shared_ptr& type) // NOLINT implicit construction - : InputType(type.get()) {} + InputType(const std::shared_ptr& type); // NOLINT implicit construction /// \brief Use the passed TypeMatcher to type check. InputType(std::shared_ptr type_matcher) // NOLINT implicit construction @@ -268,8 +270,7 @@ class ARROW_EXPORT OutputType { : kind_(FIXED), type_(type) {} /// \brief Output an exact type - OutputType(const std::shared_ptr& type) // NOLINT implicit construction - : OutputType(type.get()) {} + OutputType(const std::shared_ptr& type); // NOLINT implicit construction /// \brief Output a computed type depending on actual input types OutputType(Resolver resolver) // NOLINT implicit construction diff --git a/cpp/src/arrow/compute/kernels/aggregate_mode.cc b/cpp/src/arrow/compute/kernels/aggregate_mode.cc index 1e03e62fc3139..022f7e3520a27 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_mode.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_mode.cc @@ -454,7 +454,7 @@ VectorKernel NewModeKernel(Type::type in_type, ArrayKernelExec exec, kernel.can_execute_chunkwise = false; kernel.output_chunked = false; kernel.signature = KernelSignature::Make({InputType(in_type)}, ModeType); - kernel.exec = std::move(exec); + kernel.exec = exec; kernel.exec_chunked = exec_chunked; return kernel; } diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index e32ff1e7d403b..698234aac6a06 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -1984,7 +1984,7 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { for (auto unit : TimeUnit::values()) { InputType in_type(match::DurationTypeUnit(unit)); auto exec = ArithmeticExecFromOp(Type::DURATION); - DCHECK_OK(add->AddKernel({in_type, in_type}, duration(unit), std::move(exec))); + DCHECK_OK(add->AddKernel({in_type, in_type}, duration(unit), exec)); } AddArithmeticFunctionTimeDuration(add); @@ -2011,8 +2011,7 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { for (auto unit : TimeUnit::values()) { InputType in_type(match::DurationTypeUnit(unit)); auto exec = ArithmeticExecFromOp(Type::DURATION); - DCHECK_OK( - add_checked->AddKernel({in_type, in_type}, duration(unit), std::move(exec))); + DCHECK_OK(add_checked->AddKernel({in_type, in_type}, duration(unit), exec)); } AddArithmeticFunctionTimeDuration(add_checked); @@ -2029,37 +2028,36 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { InputType in_type(match::TimestampTypeUnit(unit)); auto exec = ArithmeticExecFromOp(Type::TIMESTAMP); DCHECK_OK(subtract->AddKernel({in_type, in_type}, - OutputType::Resolver(ResolveTemporalOutput), - std::move(exec))); + OutputType::Resolver(ResolveTemporalOutput), exec)); } // Add subtract(timestamp, duration) -> timestamp for (auto unit : TimeUnit::values()) { InputType in_type(match::TimestampTypeUnit(unit)); auto exec = ScalarBinary::Exec; - DCHECK_OK(subtract->AddKernel({in_type, duration(unit)}, OutputType(FirstType), - std::move(exec))); + DCHECK_OK( + subtract->AddKernel({in_type, duration(unit)}, OutputType(FirstType), exec)); } // Add subtract(duration, duration) -> duration for (auto unit : TimeUnit::values()) { InputType in_type(match::DurationTypeUnit(unit)); auto exec = ArithmeticExecFromOp(Type::DURATION); - DCHECK_OK(subtract->AddKernel({in_type, in_type}, duration(unit), std::move(exec))); + DCHECK_OK(subtract->AddKernel({in_type, in_type}, duration(unit), exec)); } // Add subtract(time32, time32) -> duration for (auto unit : {TimeUnit::SECOND, TimeUnit::MILLI}) { InputType in_type(match::Time32TypeUnit(unit)); auto exec = ScalarBinaryEqualTypes::Exec; - DCHECK_OK(subtract->AddKernel({in_type, in_type}, duration(unit), std::move(exec))); + DCHECK_OK(subtract->AddKernel({in_type, in_type}, duration(unit), exec)); } // Add subtract(time64, time64) -> duration for (auto unit : {TimeUnit::MICRO, TimeUnit::NANO}) { InputType in_type(match::Time64TypeUnit(unit)); auto exec = ScalarBinaryEqualTypes::Exec; - DCHECK_OK(subtract->AddKernel({in_type, in_type}, duration(unit), std::move(exec))); + DCHECK_OK(subtract->AddKernel({in_type, in_type}, duration(unit), exec)); } // Add subtract(date32, date32) -> duration(TimeUnit::SECOND) @@ -2088,9 +2086,8 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { InputType in_type(match::TimestampTypeUnit(unit)); auto exec = ArithmeticExecFromOp(Type::TIMESTAMP); - DCHECK_OK(subtract_checked->AddKernel({in_type, in_type}, - OutputType::Resolver(ResolveTemporalOutput), - std::move(exec))); + DCHECK_OK(subtract_checked->AddKernel( + {in_type, in_type}, OutputType::Resolver(ResolveTemporalOutput), exec)); } // Add subtract_checked(timestamp, duration) -> timestamp @@ -2098,7 +2095,7 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { InputType in_type(match::TimestampTypeUnit(unit)); auto exec = ScalarBinary::Exec; DCHECK_OK(subtract_checked->AddKernel({in_type, duration(unit)}, - OutputType(FirstType), std::move(exec))); + OutputType(FirstType), exec)); } // Add subtract_checked(duration, duration) -> duration @@ -2106,8 +2103,7 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { InputType in_type(match::DurationTypeUnit(unit)); auto exec = ArithmeticExecFromOp(Type::DURATION); - DCHECK_OK( - subtract_checked->AddKernel({in_type, in_type}, duration(unit), std::move(exec))); + DCHECK_OK(subtract_checked->AddKernel({in_type, in_type}, duration(unit), exec)); } // Add subtract_checked(date32, date32) -> duration(TimeUnit::SECOND) @@ -2128,16 +2124,14 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { for (auto unit : {TimeUnit::SECOND, TimeUnit::MILLI}) { InputType in_type(match::Time32TypeUnit(unit)); auto exec = ScalarBinaryEqualTypes::Exec; - DCHECK_OK( - subtract_checked->AddKernel({in_type, in_type}, duration(unit), std::move(exec))); + DCHECK_OK(subtract_checked->AddKernel({in_type, in_type}, duration(unit), exec)); } // Add subtract_checked(time64, time64) -> duration for (auto unit : {TimeUnit::MICRO, TimeUnit::NANO}) { InputType in_type(match::Time64TypeUnit(unit)); auto exec = ScalarBinaryEqualTypes::Exec; - DCHECK_OK( - subtract_checked->AddKernel({in_type, in_type}, duration(unit), std::move(exec))); + DCHECK_OK(subtract_checked->AddKernel({in_type, in_type}, duration(unit), exec)); } AddArithmeticFunctionTimeDuration(subtract_checked); @@ -2181,8 +2175,7 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { // Add divide(duration, int64) -> duration for (auto unit : TimeUnit::values()) { auto exec = ScalarBinaryNotNull::Exec; - DCHECK_OK( - divide->AddKernel({duration(unit), int64()}, duration(unit), std::move(exec))); + DCHECK_OK(divide->AddKernel({duration(unit), int64()}, duration(unit), exec)); } DCHECK_OK(registry->AddFunction(std::move(divide))); @@ -2194,8 +2187,7 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) { // Add divide_checked(duration, int64) -> duration for (auto unit : TimeUnit::values()) { auto exec = ScalarBinaryNotNull::Exec; - DCHECK_OK(divide_checked->AddKernel({duration(unit), int64()}, duration(unit), - std::move(exec))); + DCHECK_OK(divide_checked->AddKernel({duration(unit), int64()}, duration(unit), exec)); } DCHECK_OK(registry->AddFunction(std::move(divide_checked))); diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc index 1de64dc28e56a..f59f6caf350c6 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc @@ -677,7 +677,7 @@ std::shared_ptr GetCastToDecimal128() { // Cast from integer for (const DataType* in_ty : IntTypes()) { auto exec = GenerateInteger(in_ty->id()); - DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, sig_out_ty, std::move(exec))); + DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, sig_out_ty, exec)); } // Cast from other decimal @@ -706,7 +706,7 @@ std::shared_ptr GetCastToDecimal256() { // Cast from integer for (const DataType* in_ty : IntTypes()) { auto exec = GenerateInteger(in_ty->id()); - DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, sig_out_ty, std::move(exec))); + DCHECK_OK(func->AddKernel(in_ty->id(), {in_ty}, sig_out_ty, exec)); } // Cast from other decimal diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index 718e007c4fc30..8a4310b58a176 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -2214,10 +2214,10 @@ static void CheckStructToStruct(const std::vector& value_types) for (const DataType* dest_value_type : value_types) { std::vector field_names = {"a", "b"}; std::shared_ptr a1, b1, a2, b2; - a1 = ArrayFromJSON(src_value_type->GetSharedPtr(), "[1, 2, 3, 4, null]"); - b1 = ArrayFromJSON(src_value_type->GetSharedPtr(), "[null, 7, 8, 9, 0]"); - a2 = ArrayFromJSON(dest_value_type->GetSharedPtr(), "[1, 2, 3, 4, null]"); - b2 = ArrayFromJSON(dest_value_type->GetSharedPtr(), "[null, 7, 8, 9, 0]"); + a1 = ArrayFromJSON(src_value_type, "[1, 2, 3, 4, null]"); + b1 = ArrayFromJSON(src_value_type, "[null, 7, 8, 9, 0]"); + a2 = ArrayFromJSON(dest_value_type, "[1, 2, 3, 4, null]"); + b2 = ArrayFromJSON(dest_value_type, "[null, 7, 8, 9, 0]"); ASSERT_OK_AND_ASSIGN(auto src, StructArray::Make({a1, b1}, field_names)); ASSERT_OK_AND_ASSIGN(auto dest, StructArray::Make({a2, b2}, field_names)); @@ -2244,20 +2244,18 @@ static void CheckStructToStructSubset(const std::vector& value_ std::vector field_names = {"a", "b", "c", "d", "e"}; std::shared_ptr a1, b1, c1, d1, e1; - auto sp_src_type = src_value_type->GetSharedPtr(); - auto sp_dst_type = dest_value_type->GetSharedPtr(); - a1 = ArrayFromJSON(sp_src_type, "[1, 2, 5]"); - b1 = ArrayFromJSON(sp_src_type, "[3, 4, 7]"); - c1 = ArrayFromJSON(sp_src_type, "[9, 11, 44]"); - d1 = ArrayFromJSON(sp_src_type, "[6, 51, 49]"); - e1 = ArrayFromJSON(sp_src_type, "[19, 17, 74]"); + a1 = ArrayFromJSON(src_value_type, "[1, 2, 5]"); + b1 = ArrayFromJSON(src_value_type, "[3, 4, 7]"); + c1 = ArrayFromJSON(src_value_type, "[9, 11, 44]"); + d1 = ArrayFromJSON(src_value_type, "[6, 51, 49]"); + e1 = ArrayFromJSON(src_value_type, "[19, 17, 74]"); std::shared_ptr a2, b2, c2, d2, e2; - a2 = ArrayFromJSON(sp_dst_type, "[1, 2, 5]"); - b2 = ArrayFromJSON(sp_dst_type, "[3, 4, 7]"); - c2 = ArrayFromJSON(sp_dst_type, "[9, 11, 44]"); - d2 = ArrayFromJSON(sp_dst_type, "[6, 51, 49]"); - e2 = ArrayFromJSON(sp_dst_type, "[19, 17, 74]"); + a2 = ArrayFromJSON(dest_value_type, "[1, 2, 5]"); + b2 = ArrayFromJSON(dest_value_type, "[3, 4, 7]"); + c2 = ArrayFromJSON(dest_value_type, "[9, 11, 44]"); + d2 = ArrayFromJSON(dest_value_type, "[6, 51, 49]"); + e2 = ArrayFromJSON(dest_value_type, "[19, 17, 74]"); ASSERT_OK_AND_ASSIGN(auto src, StructArray::Make({a1, b1, c1, d1, e1}, field_names)); @@ -2345,20 +2343,18 @@ static void CheckStructToStructSubsetWithNulls( std::vector field_names = {"a", "b", "c", "d", "e"}; std::shared_ptr a1, b1, c1, d1, e1; - auto sp_src_type = src_value_type->GetSharedPtr(); - auto sp_dst_type = dest_value_type->GetSharedPtr(); - a1 = ArrayFromJSON(sp_src_type, "[1, 2, 5]"); - b1 = ArrayFromJSON(sp_src_type, "[3, null, 7]"); - c1 = ArrayFromJSON(sp_src_type, "[9, 11, 44]"); - d1 = ArrayFromJSON(sp_src_type, "[6, 51, null]"); - e1 = ArrayFromJSON(sp_src_type, "[null, 17, 74]"); + a1 = ArrayFromJSON(src_value_type, "[1, 2, 5]"); + b1 = ArrayFromJSON(src_value_type, "[3, null, 7]"); + c1 = ArrayFromJSON(src_value_type, "[9, 11, 44]"); + d1 = ArrayFromJSON(src_value_type, "[6, 51, null]"); + e1 = ArrayFromJSON(src_value_type, "[null, 17, 74]"); std::shared_ptr a2, b2, c2, d2, e2; - a2 = ArrayFromJSON(sp_dst_type, "[1, 2, 5]"); - b2 = ArrayFromJSON(sp_dst_type, "[3, null, 7]"); - c2 = ArrayFromJSON(sp_dst_type, "[9, 11, 44]"); - d2 = ArrayFromJSON(sp_dst_type, "[6, 51, null]"); - e2 = ArrayFromJSON(sp_dst_type, "[null, 17, 74]"); + a2 = ArrayFromJSON(dest_value_type, "[1, 2, 5]"); + b2 = ArrayFromJSON(dest_value_type, "[3, null, 7]"); + c2 = ArrayFromJSON(dest_value_type, "[9, 11, 44]"); + d2 = ArrayFromJSON(dest_value_type, "[6, 51, null]"); + e2 = ArrayFromJSON(dest_value_type, "[null, 17, 74]"); std::shared_ptr null_bitmap; BitmapFromVector({0, 1, 0}, &null_bitmap); diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc index 20dbd8c728181..4d0cc863994dc 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc @@ -428,20 +428,19 @@ std::shared_ptr MakeCompareFunction(std::string name, FunctionDo for (const DataType* ty : BaseBinaryTypes()) { auto exec = GenerateVarBinaryBase(*ty); - DCHECK_OK(func->AddKernel({ty, ty}, boolean(), std::move(exec))); + DCHECK_OK(func->AddKernel({ty, ty}, boolean(), exec)); } for (const auto id : {Type::DECIMAL128, Type::DECIMAL256}) { auto exec = GenerateDecimal(id); - DCHECK_OK( - func->AddKernel({InputType(id), InputType(id)}, boolean(), std::move(exec))); + DCHECK_OK(func->AddKernel({InputType(id), InputType(id)}, boolean(), exec)); } { auto exec = applicator::ScalarBinaryEqualTypes::Exec; auto ty = InputType(Type::FIXED_SIZE_BINARY); - DCHECK_OK(func->AddKernel({ty, ty}, boolean(), std::move(exec))); + DCHECK_OK(func->AddKernel({ty, ty}, boolean(), exec)); } return func; diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 0c0ef6fcab6b5..2c481c726414b 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -2663,7 +2663,7 @@ void AddPrimitiveCaseWhenKernels(const std::shared_ptr& scalar const std::vector& types) { for (auto&& type : types) { auto exec = GenerateTypeAgnosticPrimitive(*type); - AddCaseWhenKernel(scalar_function, type, std::move(exec)); + AddCaseWhenKernel(scalar_function, type, exec); } } @@ -2671,7 +2671,7 @@ void AddBinaryCaseWhenKernels(const std::shared_ptr& scalar_fu const std::vector& types) { for (auto&& type : types) { auto exec = GenerateTypeAgnosticVarBinaryBase(*type); - AddCaseWhenKernel(scalar_function, type, std::move(exec)); + AddCaseWhenKernel(scalar_function, type, exec); } } @@ -2690,7 +2690,7 @@ void AddPrimitiveCoalesceKernels(const std::shared_ptr& scalar_f const std::vector& types) { for (auto&& type : types) { auto exec = GenerateTypeAgnosticPrimitive(*type); - AddCoalesceKernel(scalar_function, type, std::move(exec)); + AddCoalesceKernel(scalar_function, type, exec); } } @@ -2709,7 +2709,7 @@ void AddPrimitiveChooseKernels(const std::shared_ptr& scalar_fun const std::vector& types) { for (auto&& type : types) { auto exec = GenerateTypeAgnosticPrimitive(*type); - AddChooseKernel(scalar_function, type, std::move(exec)); + AddChooseKernel(scalar_function, type, exec); } } diff --git a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc index f89995b75cb58..17f236bc9cee2 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc @@ -32,6 +32,10 @@ namespace arrow { namespace compute { namespace internal { +Result OutputListOf(KernelContext*, const std::vector& types) { + return list(types[0].GetSharedPtr()); +} + namespace { // ---------------------------------------------------------------------- @@ -871,13 +875,13 @@ void AddAsciiStringLength(FunctionRegistry* registry) { auto exec = GenerateVarBinaryBase( ty); - DCHECK_OK(func->AddKernel({ty}, int32(), std::move(exec))); + DCHECK_OK(func->AddKernel({ty}, int32(), exec)); } for (const auto& ty : {large_binary(), large_utf8()}) { auto exec = GenerateVarBinaryBase( ty); - DCHECK_OK(func->AddKernel({ty}, int64(), std::move(exec))); + DCHECK_OK(func->AddKernel({ty}, int64(), exec)); } DCHECK_OK(func->AddKernel({InputType(Type::FIXED_SIZE_BINARY)}, int32(), BinaryLength::FixedSizeExec)); @@ -1601,8 +1605,7 @@ void AddAsciiStringMatchSubstring(FunctionRegistry* registry) { match_substring_doc); for (const auto& ty : BaseBinaryTypes()) { auto exec = GenerateVarBinaryToVarBinary(ty); - DCHECK_OK( - func->AddKernel({ty}, boolean(), std::move(exec), MatchSubstringState::Init)); + DCHECK_OK(func->AddKernel({ty}, boolean(), exec, MatchSubstringState::Init)); } DCHECK_OK(registry->AddFunction(std::move(func))); } @@ -1612,8 +1615,7 @@ void AddAsciiStringMatchSubstring(FunctionRegistry* registry) { for (const auto& ty : BaseBinaryTypes()) { auto exec = GenerateVarBinaryToVarBinary(ty); - DCHECK_OK( - func->AddKernel({ty}, boolean(), std::move(exec), MatchSubstringState::Init)); + DCHECK_OK(func->AddKernel({ty}, boolean(), exec, MatchSubstringState::Init)); } DCHECK_OK(registry->AddFunction(std::move(func))); } @@ -1622,8 +1624,7 @@ void AddAsciiStringMatchSubstring(FunctionRegistry* registry) { std::make_shared("ends_with", Arity::Unary(), ends_with_doc); for (const auto& ty : BaseBinaryTypes()) { auto exec = GenerateVarBinaryToVarBinary(ty); - DCHECK_OK( - func->AddKernel({ty}, boolean(), std::move(exec), MatchSubstringState::Init)); + DCHECK_OK(func->AddKernel({ty}, boolean(), exec, MatchSubstringState::Init)); } DCHECK_OK(registry->AddFunction(std::move(func))); } @@ -1633,8 +1634,7 @@ void AddAsciiStringMatchSubstring(FunctionRegistry* registry) { match_substring_regex_doc); for (const auto& ty : BaseBinaryTypes()) { auto exec = GenerateVarBinaryToVarBinary(ty); - DCHECK_OK( - func->AddKernel({ty}, boolean(), std::move(exec), MatchSubstringState::Init)); + DCHECK_OK(func->AddKernel({ty}, boolean(), exec, MatchSubstringState::Init)); } DCHECK_OK(registry->AddFunction(std::move(func))); } @@ -1643,8 +1643,7 @@ void AddAsciiStringMatchSubstring(FunctionRegistry* registry) { std::make_shared("match_like", Arity::Unary(), match_like_doc); for (const auto& ty : BaseBinaryTypes()) { auto exec = GenerateVarBinaryToVarBinary(ty); - DCHECK_OK( - func->AddKernel({ty}, boolean(), std::move(exec), MatchSubstringState::Init)); + DCHECK_OK(func->AddKernel({ty}, boolean(), exec, MatchSubstringState::Init)); } DCHECK_OK(registry->AddFunction(std::move(func))); } @@ -2120,7 +2119,7 @@ void AddAsciiStringReplaceSubstring(FunctionRegistry* registry) { replace_substring_doc); for (const auto& ty : BaseBinaryTypes()) { auto exec = GenerateVarBinaryToVarBinary(ty); - ScalarKernel kernel{{ty}, ty, std::move(exec), ReplaceState::Init}; + ScalarKernel kernel{{ty}, ty, exec, ReplaceState::Init}; kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; DCHECK_OK(func->AddKernel(std::move(kernel))); } @@ -2132,7 +2131,7 @@ void AddAsciiStringReplaceSubstring(FunctionRegistry* registry) { "replace_substring_regex", Arity::Unary(), replace_substring_regex_doc); for (const auto& ty : BaseBinaryTypes()) { auto exec = GenerateVarBinaryToVarBinary(ty); - ScalarKernel kernel{{ty}, ty, std::move(exec), ReplaceState::Init}; + ScalarKernel kernel{{ty}, ty, exec, ReplaceState::Init}; kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; DCHECK_OK(func->AddKernel(std::move(kernel))); } @@ -2476,8 +2475,7 @@ void AddAsciiStringSplitPattern(FunctionRegistry* registry) { split_pattern_doc); for (const auto& ty : BaseBinaryTypes()) { auto exec = GenerateVarBinaryToVarBinary(ty); - DCHECK_OK(func->AddKernel({ty}, {list(ty->GetSharedPtr())}, std::move(exec), - SplitPatternState::Init)); + DCHECK_OK(func->AddKernel({ty}, OutputListOf, exec, SplitPatternState::Init)); } DCHECK_OK(registry->AddFunction(std::move(func))); } @@ -2547,8 +2545,7 @@ void AddAsciiStringSplitWhitespace(FunctionRegistry* registry) { for (const auto& ty : StringTypes()) { auto exec = GenerateVarBinaryToVarBinary(ty); - DCHECK_OK(func->AddKernel({ty}, {list(ty->GetSharedPtr())}, std::move(exec), - StringSplitState::Init)); + DCHECK_OK(func->AddKernel({ty}, OutputListOf, exec, StringSplitState::Init)); } DCHECK_OK(registry->AddFunction(std::move(func))); } @@ -2617,8 +2614,7 @@ void AddAsciiStringSplitRegex(FunctionRegistry* registry) { split_pattern_regex_doc); for (const DataType* ty : BaseBinaryTypes()) { auto exec = GenerateVarBinaryToVarBinary(ty); - DCHECK_OK(func->AddKernel({ty}, {list(ty->GetSharedPtr())}, std::move(exec), - SplitPatternState::Init)); + DCHECK_OK(func->AddKernel({ty}, OutputListOf, exec, SplitPatternState::Init)); } DCHECK_OK(registry->AddFunction(std::move(func))); } @@ -3020,8 +3016,8 @@ void AddBinaryJoinForListType(ScalarFunction* func) { for (const DataType* ty : BaseBinaryTypes()) { auto exec = GenerateTypeAgnosticVarBinaryBase(*ty); - auto list_ty = std::make_shared(ty->GetSharedPtr()); - DCHECK_OK(func->AddKernel({InputType(list_ty), InputType(ty)}, ty, std::move(exec))); + DCHECK_OK(func->AddKernel({match::ListOf(ty->id(), ListType::type_id), InputType(ty)}, + ty, exec)); } } diff --git a/cpp/src/arrow/compute/kernels/scalar_string_internal.h b/cpp/src/arrow/compute/kernels/scalar_string_internal.h index 32731414e089b..14eae3e2c8d36 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_internal.h +++ b/cpp/src/arrow/compute/kernels/scalar_string_internal.h @@ -152,7 +152,7 @@ void MakeUnaryStringBatchKernel( auto func = std::make_shared(name, Arity::Unary(), std::move(doc)); for (const auto& ty : StringTypes()) { auto exec = GenerateVarBinaryToVarBinary(ty); - ScalarKernel kernel{{ty}, ty, std::move(exec)}; + ScalarKernel kernel{{ty}, ty, exec}; kernel.mem_allocation = mem_allocation; DCHECK_OK(func->AddKernel(std::move(kernel))); } @@ -238,7 +238,7 @@ void AddUnaryStringPredicate(std::string name, FunctionRegistry* registry, auto func = std::make_shared(name, Arity::Unary(), std::move(doc)); for (const auto& ty : StringTypes()) { auto exec = GenerateVarBinaryToVarBinary(ty); - DCHECK_OK(func->AddKernel({ty}, boolean(), std::move(exec))); + DCHECK_OK(func->AddKernel({ty}, boolean(), exec)); } DCHECK_OK(registry->AddFunction(std::move(func))); } @@ -410,6 +410,8 @@ struct StringSplitExec { using StringSplitState = OptionsWrapper; +Result OutputListOf(KernelContext*, const std::vector& types); + } // namespace internal } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_string_utf8.cc b/cpp/src/arrow/compute/kernels/scalar_string_utf8.cc index 1238b88026693..f301f307c9124 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_utf8.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_utf8.cc @@ -43,7 +43,7 @@ void MakeUnaryStringUTF8TransformKernel(std::string name, FunctionRegistry* regi auto func = std::make_shared(name, Arity::Unary(), std::move(doc)); for (const DataType* ty : StringTypes()) { auto exec = GenerateVarBinaryToVarBinary(ty); - DCHECK_OK(func->AddKernel({ty}, ty, std::move(exec))); + DCHECK_OK(func->AddKernel({ty}, ty, exec)); } DCHECK_OK(registry->AddFunction(std::move(func))); } @@ -675,12 +675,12 @@ void AddUtf8StringLength(FunctionRegistry* registry) { std::make_shared("utf8_length", Arity::Unary(), utf8_length_doc); { auto exec = applicator::ScalarUnaryNotNull::Exec; - DCHECK_OK(func->AddKernel({utf8()}, int32(), std::move(exec))); + DCHECK_OK(func->AddKernel({utf8()}, int32(), exec)); } { auto exec = applicator::ScalarUnaryNotNull::Exec; - DCHECK_OK(func->AddKernel({large_utf8()}, int64(), std::move(exec))); + DCHECK_OK(func->AddKernel({large_utf8()}, int64(), exec)); } DCHECK_OK(registry->AddFunction(std::move(func))); } @@ -1073,8 +1073,8 @@ void AddUtf8StringReplaceSlice(FunctionRegistry* registry) { for (const DataType* ty : StringTypes()) { auto exec = GenerateVarBinaryToVarBinary(ty); - DCHECK_OK(func->AddKernel({ty}, ty, std::move(exec), - ReplaceStringSliceTransformBase::State::Init)); + DCHECK_OK( + func->AddKernel({ty}, ty, exec, ReplaceStringSliceTransformBase::State::Init)); } DCHECK_OK(registry->AddFunction(std::move(func))); } @@ -1261,8 +1261,7 @@ void AddUtf8StringSlice(FunctionRegistry* registry) { utf8_slice_codeunits_doc); for (const DataType* ty : StringTypes()) { auto exec = GenerateVarBinaryToVarBinary(ty); - DCHECK_OK( - func->AddKernel({ty}, ty, std::move(exec), SliceCodeunitsTransform::State::Init)); + DCHECK_OK(func->AddKernel({ty}, ty, exec, SliceCodeunitsTransform::State::Init)); } DCHECK_OK(registry->AddFunction(std::move(func))); } @@ -1347,8 +1346,7 @@ void AddUtf8StringSplitWhitespace(FunctionRegistry* registry) { utf8_split_whitespace_doc, &default_options); for (const DataType* ty : StringTypes()) { auto exec = GenerateVarBinaryToVarBinary(ty); - DCHECK_OK(func->AddKernel({ty}, {list(ty->GetSharedPtr())}, std::move(exec), - StringSplitState::Init)); + DCHECK_OK(func->AddKernel({ty}, OutputListOf, exec, StringSplitState::Init)); } DCHECK_OK(registry->AddFunction(std::move(func))); } diff --git a/cpp/src/arrow/compute/kernels/scalar_temporal_binary.cc b/cpp/src/arrow/compute/kernels/scalar_temporal_binary.cc index 8a00aff3c28c4..b1df5149db8fc 100644 --- a/cpp/src/arrow/compute/kernels/scalar_temporal_binary.cc +++ b/cpp/src/arrow/compute/kernels/scalar_temporal_binary.cc @@ -340,7 +340,7 @@ struct BinaryTemporalFactory { template void AddKernel(InputType in_type) { auto exec = ExecTemplate::Exec; - DCHECK_OK(func->AddKernel({in_type, in_type}, out_type, std::move(exec), init)); + DCHECK_OK(func->AddKernel({in_type, in_type}, out_type, exec, init)); } }; diff --git a/cpp/src/arrow/compute/kernels/scalar_temporal_unary.cc b/cpp/src/arrow/compute/kernels/scalar_temporal_unary.cc index d7c045d84b079..2be0e2eb8c254 100644 --- a/cpp/src/arrow/compute/kernels/scalar_temporal_unary.cc +++ b/cpp/src/arrow/compute/kernels/scalar_temporal_unary.cc @@ -76,18 +76,12 @@ using StrptimeState = OptionsWrapper; using AssumeTimezoneState = OptionsWrapper; using RoundTemporalState = OptionsWrapper; -const std::shared_ptr& IsoCalendarType() { - static auto type = struct_({field("iso_year", int64()), field("iso_week", int64()), - field("iso_day_of_week", int64())}); - return type; -} - -const std::shared_ptr& YearMonthDayType() { - static auto type = - struct_({field("year", int64()), field("month", int64()), field("day", int64())}); +static auto kIsoCalendarType = + struct_({field("iso_year", int64()), field("iso_week", int64()), + field("iso_day_of_week", int64())}); - return type; -} +static auto kYearMonthDayType = + struct_({field("year", int64()), field("month", int64()), field("day", int64())}); Status ValidateDayOfWeekOptions(const DayOfWeekOptions& options) { if (options.week_start < 1 || 7 < options.week_start) { @@ -338,7 +332,7 @@ struct YearMonthDay { using BuilderType = typename TypeTraits::BuilderType; std::unique_ptr array_builder; - RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), YearMonthDayType(), &array_builder)); + RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), kYearMonthDayType, &array_builder)); StructBuilder* struct_builder = checked_cast(array_builder.get()); RETURN_NOT_OK(struct_builder->Reserve(in.length)); @@ -1479,7 +1473,7 @@ struct ISOCalendar { using BuilderType = typename TypeTraits::BuilderType; std::unique_ptr array_builder; - RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), IsoCalendarType(), &array_builder)); + RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), kIsoCalendarType, &array_builder)); StructBuilder* struct_builder = checked_cast(array_builder.get()); RETURN_NOT_OK(struct_builder->Reserve(in.length)); @@ -1532,7 +1526,7 @@ struct UnaryTemporalFactory { template void AddKernel(InputType in_type) { auto exec = ExecTemplate::Exec; - ScalarKernel kernel({std::move(in_type)}, out_type, std::move(exec), init); + ScalarKernel kernel({std::move(in_type)}, out_type, exec, init); DCHECK_OK(func->AddKernel(kernel)); } }; @@ -1835,7 +1829,7 @@ void RegisterScalarTemporalUnary(FunctionRegistry* registry) { auto year_month_day = SimpleUnaryTemporalFactory::Make( - "year_month_day", YearMonthDayType(), year_month_day_doc); + "year_month_day", kYearMonthDayType.get(), year_month_day_doc); DCHECK_OK(registry->AddFunction(std::move(year_month_day))); static const auto default_day_of_week_options = DayOfWeekOptions::Defaults(); @@ -1887,7 +1881,7 @@ void RegisterScalarTemporalUnary(FunctionRegistry* registry) { auto iso_calendar = SimpleUnaryTemporalFactory::Make( - "iso_calendar", IsoCalendarType(), iso_calendar_doc); + "iso_calendar", kIsoCalendarType.get(), iso_calendar_doc); DCHECK_OK(registry->AddFunction(std::move(iso_calendar))); auto quarter = diff --git a/cpp/src/arrow/compute/kernels/vector_replace.cc b/cpp/src/arrow/compute/kernels/vector_replace.cc index 8bfa22a0d42c7..8a7241596be1b 100644 --- a/cpp/src/arrow/compute/kernels/vector_replace.cc +++ b/cpp/src/arrow/compute/kernels/vector_replace.cc @@ -812,7 +812,7 @@ void AddKernel(Type::type type_id, std::shared_ptr signature, } kernel.mem_allocation = MemAllocation::type::PREALLOCATE; kernel.signature = std::move(signature); - kernel.exec = std::move(exec); + kernel.exec = exec; kernel.exec_chunked = exec_chunked; kernel.can_execute_chunkwise = false; kernel.output_chunked = false; diff --git a/cpp/src/arrow/ipc/json_simple.cc b/cpp/src/arrow/ipc/json_simple.cc index 5342aac532f2f..1dd83ad2ecd43 100644 --- a/cpp/src/arrow/ipc/json_simple.cc +++ b/cpp/src/arrow/ipc/json_simple.cc @@ -560,7 +560,7 @@ class MapConverter final : public ConcreteConverter { auto key_builder = key_converter_->builder(); auto item_builder = item_converter_->builder(); builder_ = std::make_shared(default_memory_pool(), key_builder, - item_builder, type_); + item_builder, type_->GetSharedPtr()); return Status::OK(); } diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index 221b35ce57323..3a2c2072a164b 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -955,6 +955,10 @@ static inline bool is_base_binary_like(Type::type type_id) { return false; } +static inline bool is_parameter_free(Type::type type_id) { + return is_primitive(type_id) || is_base_binary_like(type_id) || type_id == Type::NA; +} + static inline bool is_binary_like(Type::type type_id) { switch (type_id) { case Type::BINARY: