Skip to content

Commit

Permalink
GH-41418: [C++] Add [Large]ListView and Map nested types for scalar_i…
Browse files Browse the repository at this point in the history
…f_else's kernel functions (#41419)

### Rationale for this change
Add [Large]ListView and Map nested types for scalar_if_else's kernel functions

### What changes are included in this PR?
1. Add the list-view related types to `case_when`, `coalesce`'s kernel function and move the nested-types's added
   logic to a unified function for better management.
2. Add the `MapType` and related test for `if_else`

### Are these changes tested?
Yes

### Are there any user-facing changes?
No

* GitHub Issue: #41418

Authored-by: ZhangHuiGui <[email protected]>
Signed-off-by: Felipe Oliveira Carvalho <[email protected]>
  • Loading branch information
ZhangHuiGui authored Apr 30, 2024
1 parent 5e986be commit 0d7fac0
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 38 deletions.
107 changes: 86 additions & 21 deletions cpp/src/arrow/compute/kernels/scalar_if_else.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1309,9 +1309,10 @@ void AddFixedWidthIfElseKernel(const std::shared_ptr<IfElseFunction>& scalar_fun
}

void AddNestedIfElseKernels(const std::shared_ptr<IfElseFunction>& scalar_function) {
for (const auto type_id : {Type::LIST, Type::LARGE_LIST, Type::LIST_VIEW,
Type::LARGE_LIST_VIEW, Type::FIXED_SIZE_LIST, Type::STRUCT,
Type::DENSE_UNION, Type::SPARSE_UNION, Type::DICTIONARY}) {
for (const auto type_id :
{Type::LIST, Type::LARGE_LIST, Type::LIST_VIEW, Type::LARGE_LIST_VIEW,
Type::FIXED_SIZE_LIST, Type::MAP, Type::STRUCT, Type::DENSE_UNION,
Type::SPARSE_UNION, Type::DICTIONARY}) {
ScalarKernel kernel({boolean(), InputType(type_id), InputType(type_id)}, LastType,
NestedIfElseExec::Exec);
kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
Expand Down Expand Up @@ -1847,6 +1848,48 @@ struct CaseWhenFunctor<Type, enable_if_var_size_list<Type>> {
}
};

// TODO(GH-41453): a more efficient implementation for list-views is possible
template <typename Type>
struct CaseWhenFunctor<Type, enable_if_list_view<Type>> {
using offset_type = typename Type::offset_type;
using BuilderType = typename TypeTraits<Type>::BuilderType;
static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
/// TODO(wesm): should this be a DCHECK? Or checked elsewhere
if (batch[0].null_count() > 0) {
return Status::Invalid("cond struct must not have outer nulls");
}
if (batch[0].is_scalar()) {
return ExecVarWidthScalarCaseWhen(ctx, batch, out);
}
return ExecArray(ctx, batch, out);
}

static Status ExecArray(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
return ExecVarWidthArrayCaseWhen(
ctx, batch, out,
// ReserveData
[&](ArrayBuilder* raw_builder) {
auto builder = checked_cast<BuilderType*>(raw_builder);
auto child_builder = builder->value_builder();

int64_t reservation = 0;
for (int arg = 1; arg < batch.num_values(); arg++) {
const ExecValue& source = batch[arg];
if (!source.is_array()) {
const auto& scalar = checked_cast<const BaseListScalar&>(*source.scalar);
if (!scalar.value) continue;
reservation =
std::max<int64_t>(reservation, batch.length * scalar.value->length());
} else {
const ArraySpan& array = source.array;
reservation = std::max<int64_t>(reservation, array.child_data[0].length);
}
}
return child_builder->Reserve(reservation);
});
}
};

// No-op reserve function, pulled out to avoid apparent miscompilation on MinGW
Status ReserveNoData(ArrayBuilder*) { return Status::OK(); }

Expand Down Expand Up @@ -2712,6 +2755,25 @@ void AddBinaryCaseWhenKernels(const std::shared_ptr<CaseWhenFunction>& scalar_fu
}
}

template <typename ArrowNestedType>
void AddNestedCaseWhenKernel(const std::shared_ptr<CaseWhenFunction>& scalar_function) {
AddCaseWhenKernel(scalar_function, ArrowNestedType::type_id,
CaseWhenFunctor<ArrowNestedType>::Exec);
}

void AddNestedCaseWhenKernels(const std::shared_ptr<CaseWhenFunction>& scalar_function) {
AddNestedCaseWhenKernel<FixedSizeListType>(scalar_function);
AddNestedCaseWhenKernel<ListType>(scalar_function);
AddNestedCaseWhenKernel<LargeListType>(scalar_function);
AddNestedCaseWhenKernel<ListViewType>(scalar_function);
AddNestedCaseWhenKernel<LargeListViewType>(scalar_function);
AddNestedCaseWhenKernel<MapType>(scalar_function);
AddNestedCaseWhenKernel<StructType>(scalar_function);
AddNestedCaseWhenKernel<DenseUnionType>(scalar_function);
AddNestedCaseWhenKernel<SparseUnionType>(scalar_function);
AddNestedCaseWhenKernel<DictionaryType>(scalar_function);
}

void AddCoalesceKernel(const std::shared_ptr<ScalarFunction>& scalar_function,
detail::GetTypeId get_id, ArrayKernelExec exec) {
ScalarKernel kernel(KernelSignature::Make({InputType(get_id.id)}, FirstType,
Expand All @@ -2731,6 +2793,25 @@ void AddPrimitiveCoalesceKernels(const std::shared_ptr<ScalarFunction>& scalar_f
}
}

template <typename ArrowNestedType>
void AddNestedCoalesceKernel(const std::shared_ptr<ScalarFunction>& scalar_function) {
AddCoalesceKernel(scalar_function, ArrowNestedType::type_id,
CoalesceFunctor<ArrowNestedType>::Exec);
}

void AddNestedCoalesceKernels(const std::shared_ptr<ScalarFunction>& scalar_function) {
AddNestedCoalesceKernel<FixedSizeListType>(scalar_function);
AddNestedCoalesceKernel<ListType>(scalar_function);
AddNestedCoalesceKernel<LargeListType>(scalar_function);
AddNestedCoalesceKernel<ListViewType>(scalar_function);
AddNestedCoalesceKernel<LargeListViewType>(scalar_function);
AddNestedCoalesceKernel<MapType>(scalar_function);
AddNestedCoalesceKernel<StructType>(scalar_function);
AddNestedCoalesceKernel<DenseUnionType>(scalar_function);
AddNestedCoalesceKernel<SparseUnionType>(scalar_function);
AddNestedCoalesceKernel<DictionaryType>(scalar_function);
}

void AddChooseKernel(const std::shared_ptr<ScalarFunction>& scalar_function,
detail::GetTypeId get_id, ArrayKernelExec exec) {
ScalarKernel kernel(KernelSignature::Make({Type::INT64, InputType(get_id.id)}, LastType,
Expand Down Expand Up @@ -2822,15 +2903,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
AddCaseWhenKernel(func, Type::DECIMAL128, CaseWhenFunctor<FixedSizeBinaryType>::Exec);
AddCaseWhenKernel(func, Type::DECIMAL256, CaseWhenFunctor<FixedSizeBinaryType>::Exec);
AddBinaryCaseWhenKernels(func, BaseBinaryTypes());
AddCaseWhenKernel(func, Type::FIXED_SIZE_LIST,
CaseWhenFunctor<FixedSizeListType>::Exec);
AddCaseWhenKernel(func, Type::LIST, CaseWhenFunctor<ListType>::Exec);
AddCaseWhenKernel(func, Type::LARGE_LIST, CaseWhenFunctor<LargeListType>::Exec);
AddCaseWhenKernel(func, Type::MAP, CaseWhenFunctor<MapType>::Exec);
AddCaseWhenKernel(func, Type::STRUCT, CaseWhenFunctor<StructType>::Exec);
AddCaseWhenKernel(func, Type::DENSE_UNION, CaseWhenFunctor<DenseUnionType>::Exec);
AddCaseWhenKernel(func, Type::SPARSE_UNION, CaseWhenFunctor<SparseUnionType>::Exec);
AddCaseWhenKernel(func, Type::DICTIONARY, CaseWhenFunctor<DictionaryType>::Exec);
AddNestedCaseWhenKernels(func);
DCHECK_OK(registry->AddFunction(std::move(func)));
}
{
Expand All @@ -2848,15 +2921,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
for (const auto& ty : BaseBinaryTypes()) {
AddCoalesceKernel(func, ty, GenerateTypeAgnosticVarBinaryBase<CoalesceFunctor>(ty));
}
AddCoalesceKernel(func, Type::FIXED_SIZE_LIST,
CoalesceFunctor<FixedSizeListType>::Exec);
AddCoalesceKernel(func, Type::LIST, CoalesceFunctor<ListType>::Exec);
AddCoalesceKernel(func, Type::LARGE_LIST, CoalesceFunctor<LargeListType>::Exec);
AddCoalesceKernel(func, Type::MAP, CoalesceFunctor<MapType>::Exec);
AddCoalesceKernel(func, Type::STRUCT, CoalesceFunctor<StructType>::Exec);
AddCoalesceKernel(func, Type::DENSE_UNION, CoalesceFunctor<DenseUnionType>::Exec);
AddCoalesceKernel(func, Type::SPARSE_UNION, CoalesceFunctor<SparseUnionType>::Exec);
AddCoalesceKernel(func, Type::DICTIONARY, CoalesceFunctor<DictionaryType>::Exec);
AddNestedCoalesceKernels(func);
DCHECK_OK(registry->AddFunction(std::move(func)));
}
{
Expand Down
50 changes: 35 additions & 15 deletions cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,11 @@ static void CaseWhenBench(benchmark::State& state) {
state.SetItemsProcessed(state.iterations() * (len - offset));
}

static void CaseWhenBenchList(benchmark::State& state) {
auto type = list(int64());
template <typename Type>
static void CaseWhenBenchList(benchmark::State& state,
const std::shared_ptr<DataType>& type) {
using ArrayType = typename TypeTraits<Type>::ArrayType;

auto fld = field("", type);

int64_t len = state.range(0);
Expand All @@ -295,17 +298,17 @@ static void CaseWhenBenchList(benchmark::State& state) {

auto cond_field =
field("cond", boolean(), key_value_metadata({{"null_probability", "0.01"}}));
auto cond = rand.ArrayOf(*field("", struct_({cond_field, cond_field, cond_field}),
key_value_metadata({{"null_probability", "0.0"}})),
len);
auto val1 = rand.ArrayOf(*fld, len);
auto val2 = rand.ArrayOf(*fld, len);
auto val3 = rand.ArrayOf(*fld, len);
auto val4 = rand.ArrayOf(*fld, len);
auto cond = std::static_pointer_cast<BooleanArray>(
rand.ArrayOf(*field("", struct_({cond_field, cond_field, cond_field}),
key_value_metadata({{"null_probability", "0.0"}})),
len))
->Slice(offset);
auto val1 = std::static_pointer_cast<ArrayType>(rand.ArrayOf(*fld, len))->Slice(offset);
auto val2 = std::static_pointer_cast<ArrayType>(rand.ArrayOf(*fld, len))->Slice(offset);
auto val3 = std::static_pointer_cast<ArrayType>(rand.ArrayOf(*fld, len))->Slice(offset);
auto val4 = std::static_pointer_cast<ArrayType>(rand.ArrayOf(*fld, len))->Slice(offset);
for (auto _ : state) {
ABORT_NOT_OK(
CaseWhen(cond->Slice(offset), {val1->Slice(offset), val2->Slice(offset),
val3->Slice(offset), val4->Slice(offset)}));
ABORT_NOT_OK(CaseWhen(cond, {val1, val2, val3, val4}));
}

// Set bytes processed to ~length of output
Expand Down Expand Up @@ -372,6 +375,21 @@ static void CaseWhenBenchStringContiguous(benchmark::State& state) {
return CaseWhenBenchContiguous<StringType>(state);
}

template <typename ListType, typename ValueType>
static void CaseWhenBenchVarLengthListLike(benchmark::State& state) {
auto value_type = TypeTraits<ValueType>::type_singleton();
auto list_type = std::make_shared<ListType>(value_type);
return CaseWhenBenchList<ListType>(state, list_type);
}

static void CaseWhenBenchListInt64(benchmark::State& state) {
return CaseWhenBenchVarLengthListLike<ListType, Int64Type>(state);
}

static void CaseWhenBenchListViewInt64(benchmark::State& state) {
CaseWhenBenchVarLengthListLike<ListViewType, Int64Type>(state);
}

struct CoalesceParams {
int64_t length;
int64_t num_arguments;
Expand Down Expand Up @@ -533,9 +551,11 @@ BENCHMARK(CaseWhenBench64)->Args({kNumItems, 99});
BENCHMARK(CaseWhenBench64Contiguous)->Args({kNumItems, 0});
BENCHMARK(CaseWhenBench64Contiguous)->Args({kNumItems, 99});

// CaseWhen: Lists
BENCHMARK(CaseWhenBenchList)->Args({kFewItems, 0});
BENCHMARK(CaseWhenBenchList)->Args({kFewItems, 99});
// CaseWhen: List-like types
BENCHMARK(CaseWhenBenchListInt64)->Args({kFewItems, 0});
BENCHMARK(CaseWhenBenchListInt64)->Args({kFewItems, 99});
BENCHMARK(CaseWhenBenchListViewInt64)->Args({kFewItems, 0});
BENCHMARK(CaseWhenBenchListViewInt64)->Args({kFewItems, 99});

// CaseWhen: Strings
BENCHMARK(CaseWhenBenchString)->Args({kFewItems, 0});
Expand Down
19 changes: 17 additions & 2 deletions cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,21 @@ TEST_F(TestIfElseKernel, ParameterizedTypes) {
{cond, ArrayFromJSON(type0, "[0]"), ArrayFromJSON(type1, "[1]")}));
}

TEST_F(TestIfElseKernel, MapNested) {
auto type = map(int64(), utf8());
CheckWithDifferentShapes(
ArrayFromJSON(boolean(), "[true, true, false, false]"),
ArrayFromJSON(type, R"([null, [[2, "foo"], [4, null]], [[3, "test"]], []])"),
ArrayFromJSON(type, R"([[[1, "b"]], [[2, "c"]], [[7, "abc"]], null])"),
ArrayFromJSON(type, R"([null, [[2, "foo"], [4, null]], [[7, "abc"]], null])"));

CheckWithDifferentShapes(
ArrayFromJSON(boolean(), "[null, null, null, null]"),
ArrayFromJSON(type, R"([null, [[1, "c"]], [[4, null]], [[6, "ok"]]])"),
ArrayFromJSON(type, R"([[[-1, null]], [[3, "c"]], null, [[6, "ok"]]])"),
ArrayFromJSON(type, R"([null, null, null, null])"));
}

template <typename Type>
class TestIfElseUnion : public ::testing::Test {};

Expand Down Expand Up @@ -1920,7 +1935,7 @@ TYPED_TEST(TestCaseWhenBinary, Random) {
template <typename Type>
class TestCaseWhenList : public ::testing::Test {};

TYPED_TEST_SUITE(TestCaseWhenList, ListArrowTypes);
TYPED_TEST_SUITE(TestCaseWhenList, ListAndListViewArrowTypes);

TYPED_TEST(TestCaseWhenList, ListOfString) {
auto type = std::make_shared<TypeParam>(utf8());
Expand Down Expand Up @@ -2555,7 +2570,7 @@ class TestCoalesceList : public ::testing::Test {};

TYPED_TEST_SUITE(TestCoalesceNumeric, IfElseNumericBasedTypes);
TYPED_TEST_SUITE(TestCoalesceBinary, BaseBinaryArrowTypes);
TYPED_TEST_SUITE(TestCoalesceList, ListArrowTypes);
TYPED_TEST_SUITE(TestCoalesceList, ListAndListViewArrowTypes);

TYPED_TEST(TestCoalesceNumeric, Basics) {
auto type = default_type_instance<TypeParam>();
Expand Down

0 comments on commit 0d7fac0

Please sign in to comment.