Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-41418: [C++] Add [Large]ListView and Map nested types for scalar_if_else's kernel functions #41419

Merged
merged 2 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 46 additions & 22 deletions cpp/src/arrow/compute/kernels/scalar_if_else.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1309,9 +1309,10 @@ void AddFixedWidthIfElseKernel(const std::shared_ptr<IfElseFunction>& scalar_fun
}

void AddNestedIfElseKernels(const std::shared_ptr<IfElseFunction>& scalar_function) {
for (const auto type_id : {Type::LIST, Type::LARGE_LIST, Type::LIST_VIEW,
Type::LARGE_LIST_VIEW, Type::FIXED_SIZE_LIST, Type::STRUCT,
Type::DENSE_UNION, Type::SPARSE_UNION, Type::DICTIONARY}) {
for (const auto type_id :
{Type::LIST, Type::LARGE_LIST, Type::LIST_VIEW, Type::LARGE_LIST_VIEW,
Type::FIXED_SIZE_LIST, Type::MAP, Type::STRUCT, Type::DENSE_UNION,
Type::SPARSE_UNION, Type::DICTIONARY}) {
ScalarKernel kernel({boolean(), InputType(type_id), InputType(type_id)}, LastType,
NestedIfElseExec::Exec);
kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
Expand Down Expand Up @@ -1807,7 +1808,8 @@ struct CaseWhenFunctor<Type, enable_if_base_binary<Type>> {
};

template <typename Type>
struct CaseWhenFunctor<Type, enable_if_var_size_list<Type>> {
struct CaseWhenFunctor<
Type, enable_if_t<is_base_list_type<Type>::value || is_list_view_type<Type>::value>> {
Copy link
Contributor

@felipecrv felipecrv Apr 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The list-view types should have their own specialization because they have a super-power that classic list types don't have: you can append the child values in any order and adjust the offset/size pairs of any random position to point to that area.

Feel free to copy and paste this class and enable_if it for is_list_view with a TODO(GH-<issue number>): a more efficient implementation for list-views is possible comment mentioning an issue you can create about this. Then we can tackle the optimization in a separate PR.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can append the child values in any order and adjust the offset/size pairs of any random position to point to that area.

Thanks for your suggestion, i will dig into the list-view types.

Feel free to copy and paste this class and enable_if it for is_list_view with a TODO(GH-<issue number>): a more efficient implementation for list-views is possible comment mentioning an issue you can create about this. Then we can tackle the optimization in a separate PR.

Agree!

Copy link
Collaborator Author

@ZhangHuiGui ZhangHuiGui Apr 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Optimization for list-view types in current thread will be tracked in [C++] A more efficient "case_when" specialization for list-view types #41453 for optimize
  2. The new commit includes case_when's list-view type benchmark for our performance regression in the future

using offset_type = typename Type::offset_type;
using BuilderType = typename TypeTraits<Type>::BuilderType;
static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
Expand Down Expand Up @@ -2712,6 +2714,25 @@ void AddBinaryCaseWhenKernels(const std::shared_ptr<CaseWhenFunction>& scalar_fu
}
}

template <typename ArrowNestedType>
void AddNestedCaseWhenKernel(const std::shared_ptr<CaseWhenFunction>& scalar_function) {
AddCaseWhenKernel(scalar_function, ArrowNestedType::type_id,
CaseWhenFunctor<ArrowNestedType>::Exec);
}

void AddNestedCaseWhenKernels(const std::shared_ptr<CaseWhenFunction>& scalar_function) {
AddNestedCaseWhenKernel<FixedSizeListType>(scalar_function);
AddNestedCaseWhenKernel<ListType>(scalar_function);
AddNestedCaseWhenKernel<LargeListType>(scalar_function);
AddNestedCaseWhenKernel<ListViewType>(scalar_function);
AddNestedCaseWhenKernel<LargeListViewType>(scalar_function);
AddNestedCaseWhenKernel<MapType>(scalar_function);
AddNestedCaseWhenKernel<StructType>(scalar_function);
AddNestedCaseWhenKernel<DenseUnionType>(scalar_function);
AddNestedCaseWhenKernel<SparseUnionType>(scalar_function);
AddNestedCaseWhenKernel<DictionaryType>(scalar_function);
}

void AddCoalesceKernel(const std::shared_ptr<ScalarFunction>& scalar_function,
detail::GetTypeId get_id, ArrayKernelExec exec) {
ScalarKernel kernel(KernelSignature::Make({InputType(get_id.id)}, FirstType,
Expand All @@ -2731,6 +2752,25 @@ void AddPrimitiveCoalesceKernels(const std::shared_ptr<ScalarFunction>& scalar_f
}
}

template <typename ArrowNestedType>
void AddNestedCoalesceKernel(const std::shared_ptr<ScalarFunction>& scalar_function) {
AddCoalesceKernel(scalar_function, ArrowNestedType::type_id,
CoalesceFunctor<ArrowNestedType>::Exec);
}

void AddNestedCoalesceKernels(const std::shared_ptr<ScalarFunction>& scalar_function) {
AddNestedCoalesceKernel<FixedSizeListType>(scalar_function);
AddNestedCoalesceKernel<ListType>(scalar_function);
AddNestedCoalesceKernel<LargeListType>(scalar_function);
AddNestedCoalesceKernel<ListViewType>(scalar_function);
AddNestedCoalesceKernel<LargeListViewType>(scalar_function);
AddNestedCoalesceKernel<MapType>(scalar_function);
AddNestedCoalesceKernel<StructType>(scalar_function);
AddNestedCoalesceKernel<DenseUnionType>(scalar_function);
AddNestedCoalesceKernel<SparseUnionType>(scalar_function);
AddNestedCoalesceKernel<DictionaryType>(scalar_function);
}

void AddChooseKernel(const std::shared_ptr<ScalarFunction>& scalar_function,
detail::GetTypeId get_id, ArrayKernelExec exec) {
ScalarKernel kernel(KernelSignature::Make({Type::INT64, InputType(get_id.id)}, LastType,
Expand Down Expand Up @@ -2822,15 +2862,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
AddCaseWhenKernel(func, Type::DECIMAL128, CaseWhenFunctor<FixedSizeBinaryType>::Exec);
AddCaseWhenKernel(func, Type::DECIMAL256, CaseWhenFunctor<FixedSizeBinaryType>::Exec);
AddBinaryCaseWhenKernels(func, BaseBinaryTypes());
AddCaseWhenKernel(func, Type::FIXED_SIZE_LIST,
CaseWhenFunctor<FixedSizeListType>::Exec);
AddCaseWhenKernel(func, Type::LIST, CaseWhenFunctor<ListType>::Exec);
AddCaseWhenKernel(func, Type::LARGE_LIST, CaseWhenFunctor<LargeListType>::Exec);
AddCaseWhenKernel(func, Type::MAP, CaseWhenFunctor<MapType>::Exec);
AddCaseWhenKernel(func, Type::STRUCT, CaseWhenFunctor<StructType>::Exec);
AddCaseWhenKernel(func, Type::DENSE_UNION, CaseWhenFunctor<DenseUnionType>::Exec);
AddCaseWhenKernel(func, Type::SPARSE_UNION, CaseWhenFunctor<SparseUnionType>::Exec);
AddCaseWhenKernel(func, Type::DICTIONARY, CaseWhenFunctor<DictionaryType>::Exec);
AddNestedCaseWhenKernels(func);
DCHECK_OK(registry->AddFunction(std::move(func)));
}
{
Expand All @@ -2848,15 +2880,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
for (const auto& ty : BaseBinaryTypes()) {
AddCoalesceKernel(func, ty, GenerateTypeAgnosticVarBinaryBase<CoalesceFunctor>(ty));
}
AddCoalesceKernel(func, Type::FIXED_SIZE_LIST,
CoalesceFunctor<FixedSizeListType>::Exec);
AddCoalesceKernel(func, Type::LIST, CoalesceFunctor<ListType>::Exec);
AddCoalesceKernel(func, Type::LARGE_LIST, CoalesceFunctor<LargeListType>::Exec);
AddCoalesceKernel(func, Type::MAP, CoalesceFunctor<MapType>::Exec);
AddCoalesceKernel(func, Type::STRUCT, CoalesceFunctor<StructType>::Exec);
AddCoalesceKernel(func, Type::DENSE_UNION, CoalesceFunctor<DenseUnionType>::Exec);
AddCoalesceKernel(func, Type::SPARSE_UNION, CoalesceFunctor<SparseUnionType>::Exec);
AddCoalesceKernel(func, Type::DICTIONARY, CoalesceFunctor<DictionaryType>::Exec);
AddNestedCoalesceKernels(func);
DCHECK_OK(registry->AddFunction(std::move(func)));
}
{
Expand Down
19 changes: 17 additions & 2 deletions cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,21 @@ TEST_F(TestIfElseKernel, ParameterizedTypes) {
{cond, ArrayFromJSON(type0, "[0]"), ArrayFromJSON(type1, "[1]")}));
}

TEST_F(TestIfElseKernel, MapNested) {
auto type = map(int64(), utf8());
CheckWithDifferentShapes(
ArrayFromJSON(boolean(), "[true, true, false, false]"),
ArrayFromJSON(type, R"([null, [[2, "foo"], [4, null]], [[3, "test"]], []])"),
ArrayFromJSON(type, R"([[[1, "b"]], [[2, "c"]], [[7, "abc"]], null])"),
ArrayFromJSON(type, R"([null, [[2, "foo"], [4, null]], [[7, "abc"]], null])"));

CheckWithDifferentShapes(
ArrayFromJSON(boolean(), "[null, null, null, null]"),
ArrayFromJSON(type, R"([null, [[1, "c"]], [[4, null]], [[6, "ok"]]])"),
ArrayFromJSON(type, R"([[[-1, null]], [[3, "c"]], null, [[6, "ok"]]])"),
ArrayFromJSON(type, R"([null, null, null, null])"));
}

template <typename Type>
class TestIfElseUnion : public ::testing::Test {};

Expand Down Expand Up @@ -1920,7 +1935,7 @@ TYPED_TEST(TestCaseWhenBinary, Random) {
template <typename Type>
class TestCaseWhenList : public ::testing::Test {};

TYPED_TEST_SUITE(TestCaseWhenList, ListArrowTypes);
TYPED_TEST_SUITE(TestCaseWhenList, ListAndListViewArrowTypes);

TYPED_TEST(TestCaseWhenList, ListOfString) {
auto type = std::make_shared<TypeParam>(utf8());
Expand Down Expand Up @@ -2555,7 +2570,7 @@ class TestCoalesceList : public ::testing::Test {};

TYPED_TEST_SUITE(TestCoalesceNumeric, IfElseNumericBasedTypes);
TYPED_TEST_SUITE(TestCoalesceBinary, BaseBinaryArrowTypes);
TYPED_TEST_SUITE(TestCoalesceList, ListArrowTypes);
TYPED_TEST_SUITE(TestCoalesceList, ListAndListViewArrowTypes);

TYPED_TEST(TestCoalesceNumeric, Basics) {
auto type = default_type_instance<TypeParam>();
Expand Down
Loading