Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-41418: [C++] Add [Large]ListView and Map nested types for scalar_if_else's kernel functions #41419

Merged
merged 2 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 86 additions & 21 deletions cpp/src/arrow/compute/kernels/scalar_if_else.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1309,9 +1309,10 @@ void AddFixedWidthIfElseKernel(const std::shared_ptr<IfElseFunction>& scalar_fun
}

void AddNestedIfElseKernels(const std::shared_ptr<IfElseFunction>& scalar_function) {
for (const auto type_id : {Type::LIST, Type::LARGE_LIST, Type::LIST_VIEW,
Type::LARGE_LIST_VIEW, Type::FIXED_SIZE_LIST, Type::STRUCT,
Type::DENSE_UNION, Type::SPARSE_UNION, Type::DICTIONARY}) {
for (const auto type_id :
{Type::LIST, Type::LARGE_LIST, Type::LIST_VIEW, Type::LARGE_LIST_VIEW,
Type::FIXED_SIZE_LIST, Type::MAP, Type::STRUCT, Type::DENSE_UNION,
Type::SPARSE_UNION, Type::DICTIONARY}) {
ScalarKernel kernel({boolean(), InputType(type_id), InputType(type_id)}, LastType,
NestedIfElseExec::Exec);
kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
Expand Down Expand Up @@ -1847,6 +1848,48 @@ struct CaseWhenFunctor<Type, enable_if_var_size_list<Type>> {
}
};

// TODO(GH-41453): a more efficient implementation for list-views is possible
template <typename Type>
struct CaseWhenFunctor<Type, enable_if_list_view<Type>> {
using offset_type = typename Type::offset_type;
using BuilderType = typename TypeTraits<Type>::BuilderType;
static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
/// TODO(wesm): should this be a DCHECK? Or checked elsewhere
if (batch[0].null_count() > 0) {
return Status::Invalid("cond struct must not have outer nulls");
}
if (batch[0].is_scalar()) {
return ExecVarWidthScalarCaseWhen(ctx, batch, out);
}
return ExecArray(ctx, batch, out);
}

static Status ExecArray(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
return ExecVarWidthArrayCaseWhen(
ctx, batch, out,
// ReserveData
[&](ArrayBuilder* raw_builder) {
auto builder = checked_cast<BuilderType*>(raw_builder);
auto child_builder = builder->value_builder();

int64_t reservation = 0;
for (int arg = 1; arg < batch.num_values(); arg++) {
const ExecValue& source = batch[arg];
if (!source.is_array()) {
const auto& scalar = checked_cast<const BaseListScalar&>(*source.scalar);
if (!scalar.value) continue;
reservation =
std::max<int64_t>(reservation, batch.length * scalar.value->length());
} else {
const ArraySpan& array = source.array;
reservation = std::max<int64_t>(reservation, array.child_data[0].length);
}
}
return child_builder->Reserve(reservation);
});
}
};

// No-op reserve function, pulled out to avoid apparent miscompilation on MinGW
Status ReserveNoData(ArrayBuilder*) { return Status::OK(); }

Expand Down Expand Up @@ -2712,6 +2755,25 @@ void AddBinaryCaseWhenKernels(const std::shared_ptr<CaseWhenFunction>& scalar_fu
}
}

template <typename ArrowNestedType>
void AddNestedCaseWhenKernel(const std::shared_ptr<CaseWhenFunction>& scalar_function) {
AddCaseWhenKernel(scalar_function, ArrowNestedType::type_id,
CaseWhenFunctor<ArrowNestedType>::Exec);
}

void AddNestedCaseWhenKernels(const std::shared_ptr<CaseWhenFunction>& scalar_function) {
AddNestedCaseWhenKernel<FixedSizeListType>(scalar_function);
AddNestedCaseWhenKernel<ListType>(scalar_function);
AddNestedCaseWhenKernel<LargeListType>(scalar_function);
AddNestedCaseWhenKernel<ListViewType>(scalar_function);
AddNestedCaseWhenKernel<LargeListViewType>(scalar_function);
AddNestedCaseWhenKernel<MapType>(scalar_function);
AddNestedCaseWhenKernel<StructType>(scalar_function);
AddNestedCaseWhenKernel<DenseUnionType>(scalar_function);
AddNestedCaseWhenKernel<SparseUnionType>(scalar_function);
AddNestedCaseWhenKernel<DictionaryType>(scalar_function);
}

void AddCoalesceKernel(const std::shared_ptr<ScalarFunction>& scalar_function,
detail::GetTypeId get_id, ArrayKernelExec exec) {
ScalarKernel kernel(KernelSignature::Make({InputType(get_id.id)}, FirstType,
Expand All @@ -2731,6 +2793,25 @@ void AddPrimitiveCoalesceKernels(const std::shared_ptr<ScalarFunction>& scalar_f
}
}

template <typename ArrowNestedType>
void AddNestedCoalesceKernel(const std::shared_ptr<ScalarFunction>& scalar_function) {
AddCoalesceKernel(scalar_function, ArrowNestedType::type_id,
CoalesceFunctor<ArrowNestedType>::Exec);
}

void AddNestedCoalesceKernels(const std::shared_ptr<ScalarFunction>& scalar_function) {
AddNestedCoalesceKernel<FixedSizeListType>(scalar_function);
AddNestedCoalesceKernel<ListType>(scalar_function);
AddNestedCoalesceKernel<LargeListType>(scalar_function);
AddNestedCoalesceKernel<ListViewType>(scalar_function);
AddNestedCoalesceKernel<LargeListViewType>(scalar_function);
AddNestedCoalesceKernel<MapType>(scalar_function);
AddNestedCoalesceKernel<StructType>(scalar_function);
AddNestedCoalesceKernel<DenseUnionType>(scalar_function);
AddNestedCoalesceKernel<SparseUnionType>(scalar_function);
AddNestedCoalesceKernel<DictionaryType>(scalar_function);
}

void AddChooseKernel(const std::shared_ptr<ScalarFunction>& scalar_function,
detail::GetTypeId get_id, ArrayKernelExec exec) {
ScalarKernel kernel(KernelSignature::Make({Type::INT64, InputType(get_id.id)}, LastType,
Expand Down Expand Up @@ -2822,15 +2903,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
AddCaseWhenKernel(func, Type::DECIMAL128, CaseWhenFunctor<FixedSizeBinaryType>::Exec);
AddCaseWhenKernel(func, Type::DECIMAL256, CaseWhenFunctor<FixedSizeBinaryType>::Exec);
AddBinaryCaseWhenKernels(func, BaseBinaryTypes());
AddCaseWhenKernel(func, Type::FIXED_SIZE_LIST,
CaseWhenFunctor<FixedSizeListType>::Exec);
AddCaseWhenKernel(func, Type::LIST, CaseWhenFunctor<ListType>::Exec);
AddCaseWhenKernel(func, Type::LARGE_LIST, CaseWhenFunctor<LargeListType>::Exec);
AddCaseWhenKernel(func, Type::MAP, CaseWhenFunctor<MapType>::Exec);
AddCaseWhenKernel(func, Type::STRUCT, CaseWhenFunctor<StructType>::Exec);
AddCaseWhenKernel(func, Type::DENSE_UNION, CaseWhenFunctor<DenseUnionType>::Exec);
AddCaseWhenKernel(func, Type::SPARSE_UNION, CaseWhenFunctor<SparseUnionType>::Exec);
AddCaseWhenKernel(func, Type::DICTIONARY, CaseWhenFunctor<DictionaryType>::Exec);
AddNestedCaseWhenKernels(func);
DCHECK_OK(registry->AddFunction(std::move(func)));
}
{
Expand All @@ -2848,15 +2921,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
for (const auto& ty : BaseBinaryTypes()) {
AddCoalesceKernel(func, ty, GenerateTypeAgnosticVarBinaryBase<CoalesceFunctor>(ty));
}
AddCoalesceKernel(func, Type::FIXED_SIZE_LIST,
CoalesceFunctor<FixedSizeListType>::Exec);
AddCoalesceKernel(func, Type::LIST, CoalesceFunctor<ListType>::Exec);
AddCoalesceKernel(func, Type::LARGE_LIST, CoalesceFunctor<LargeListType>::Exec);
AddCoalesceKernel(func, Type::MAP, CoalesceFunctor<MapType>::Exec);
AddCoalesceKernel(func, Type::STRUCT, CoalesceFunctor<StructType>::Exec);
AddCoalesceKernel(func, Type::DENSE_UNION, CoalesceFunctor<DenseUnionType>::Exec);
AddCoalesceKernel(func, Type::SPARSE_UNION, CoalesceFunctor<SparseUnionType>::Exec);
AddCoalesceKernel(func, Type::DICTIONARY, CoalesceFunctor<DictionaryType>::Exec);
AddNestedCoalesceKernels(func);
DCHECK_OK(registry->AddFunction(std::move(func)));
}
{
Expand Down
50 changes: 35 additions & 15 deletions cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,11 @@ static void CaseWhenBench(benchmark::State& state) {
state.SetItemsProcessed(state.iterations() * (len - offset));
}

static void CaseWhenBenchList(benchmark::State& state) {
auto type = list(int64());
template <typename Type>
static void CaseWhenBenchList(benchmark::State& state,
const std::shared_ptr<DataType>& type) {
using ArrayType = typename TypeTraits<Type>::ArrayType;

auto fld = field("", type);

int64_t len = state.range(0);
Expand All @@ -295,17 +298,17 @@ static void CaseWhenBenchList(benchmark::State& state) {

auto cond_field =
field("cond", boolean(), key_value_metadata({{"null_probability", "0.01"}}));
auto cond = rand.ArrayOf(*field("", struct_({cond_field, cond_field, cond_field}),
key_value_metadata({{"null_probability", "0.0"}})),
len);
auto val1 = rand.ArrayOf(*fld, len);
auto val2 = rand.ArrayOf(*fld, len);
auto val3 = rand.ArrayOf(*fld, len);
auto val4 = rand.ArrayOf(*fld, len);
auto cond = std::static_pointer_cast<BooleanArray>(
rand.ArrayOf(*field("", struct_({cond_field, cond_field, cond_field}),
key_value_metadata({{"null_probability", "0.0"}})),
len))
->Slice(offset);
auto val1 = std::static_pointer_cast<ArrayType>(rand.ArrayOf(*fld, len))->Slice(offset);
auto val2 = std::static_pointer_cast<ArrayType>(rand.ArrayOf(*fld, len))->Slice(offset);
auto val3 = std::static_pointer_cast<ArrayType>(rand.ArrayOf(*fld, len))->Slice(offset);
auto val4 = std::static_pointer_cast<ArrayType>(rand.ArrayOf(*fld, len))->Slice(offset);
for (auto _ : state) {
ABORT_NOT_OK(
CaseWhen(cond->Slice(offset), {val1->Slice(offset), val2->Slice(offset),
val3->Slice(offset), val4->Slice(offset)}));
ABORT_NOT_OK(CaseWhen(cond, {val1, val2, val3, val4}));
}

// Set bytes processed to ~length of output
Expand Down Expand Up @@ -372,6 +375,21 @@ static void CaseWhenBenchStringContiguous(benchmark::State& state) {
return CaseWhenBenchContiguous<StringType>(state);
}

template <typename ListType, typename ValueType>
static void CaseWhenBenchVarLengthListLike(benchmark::State& state) {
auto value_type = TypeTraits<ValueType>::type_singleton();
auto list_type = std::make_shared<ListType>(value_type);
return CaseWhenBenchList<ListType>(state, list_type);
}

static void CaseWhenBenchListInt64(benchmark::State& state) {
return CaseWhenBenchVarLengthListLike<ListType, Int64Type>(state);
}

static void CaseWhenBenchListViewInt64(benchmark::State& state) {
CaseWhenBenchVarLengthListLike<ListViewType, Int64Type>(state);
}

struct CoalesceParams {
int64_t length;
int64_t num_arguments;
Expand Down Expand Up @@ -533,9 +551,11 @@ BENCHMARK(CaseWhenBench64)->Args({kNumItems, 99});
BENCHMARK(CaseWhenBench64Contiguous)->Args({kNumItems, 0});
BENCHMARK(CaseWhenBench64Contiguous)->Args({kNumItems, 99});

// CaseWhen: Lists
BENCHMARK(CaseWhenBenchList)->Args({kFewItems, 0});
BENCHMARK(CaseWhenBenchList)->Args({kFewItems, 99});
// CaseWhen: List-like types
BENCHMARK(CaseWhenBenchListInt64)->Args({kFewItems, 0});
BENCHMARK(CaseWhenBenchListInt64)->Args({kFewItems, 99});
BENCHMARK(CaseWhenBenchListViewInt64)->Args({kFewItems, 0});
BENCHMARK(CaseWhenBenchListViewInt64)->Args({kFewItems, 99});

// CaseWhen: Strings
BENCHMARK(CaseWhenBenchString)->Args({kFewItems, 0});
Expand Down
19 changes: 17 additions & 2 deletions cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,21 @@ TEST_F(TestIfElseKernel, ParameterizedTypes) {
{cond, ArrayFromJSON(type0, "[0]"), ArrayFromJSON(type1, "[1]")}));
}

TEST_F(TestIfElseKernel, MapNested) {
auto type = map(int64(), utf8());
CheckWithDifferentShapes(
ArrayFromJSON(boolean(), "[true, true, false, false]"),
ArrayFromJSON(type, R"([null, [[2, "foo"], [4, null]], [[3, "test"]], []])"),
ArrayFromJSON(type, R"([[[1, "b"]], [[2, "c"]], [[7, "abc"]], null])"),
ArrayFromJSON(type, R"([null, [[2, "foo"], [4, null]], [[7, "abc"]], null])"));

CheckWithDifferentShapes(
ArrayFromJSON(boolean(), "[null, null, null, null]"),
ArrayFromJSON(type, R"([null, [[1, "c"]], [[4, null]], [[6, "ok"]]])"),
ArrayFromJSON(type, R"([[[-1, null]], [[3, "c"]], null, [[6, "ok"]]])"),
ArrayFromJSON(type, R"([null, null, null, null])"));
}

template <typename Type>
class TestIfElseUnion : public ::testing::Test {};

Expand Down Expand Up @@ -1920,7 +1935,7 @@ TYPED_TEST(TestCaseWhenBinary, Random) {
template <typename Type>
class TestCaseWhenList : public ::testing::Test {};

TYPED_TEST_SUITE(TestCaseWhenList, ListArrowTypes);
TYPED_TEST_SUITE(TestCaseWhenList, ListAndListViewArrowTypes);

TYPED_TEST(TestCaseWhenList, ListOfString) {
auto type = std::make_shared<TypeParam>(utf8());
Expand Down Expand Up @@ -2555,7 +2570,7 @@ class TestCoalesceList : public ::testing::Test {};

TYPED_TEST_SUITE(TestCoalesceNumeric, IfElseNumericBasedTypes);
TYPED_TEST_SUITE(TestCoalesceBinary, BaseBinaryArrowTypes);
TYPED_TEST_SUITE(TestCoalesceList, ListArrowTypes);
TYPED_TEST_SUITE(TestCoalesceList, ListAndListViewArrowTypes);

TYPED_TEST(TestCoalesceNumeric, Basics) {
auto type = default_type_instance<TypeParam>();
Expand Down
Loading