Skip to content

Commit

Permalink
Take: Create a signature for Take kernels to support AAA and CAA calls
Browse files Browse the repository at this point in the history
  • Loading branch information
felipecrv committed Sep 4, 2024
1 parent d5825fa commit 72101ab
Show file tree
Hide file tree
Showing 4 changed files with 234 additions and 177 deletions.
15 changes: 10 additions & 5 deletions cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -894,18 +894,23 @@ Status ExtensionFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult
}

// Transform filter to selection indices and then use Take.
Status FilterWithTakeExec(const ArrayKernelExec& take_exec, KernelContext* ctx,
Status FilterWithTakeExec(TakeKernelExec take_aaa_exec, KernelContext* ctx,
const ExecSpan& batch, ExecResult* out) {
std::shared_ptr<ArrayData> indices;
std::shared_ptr<ArrayData> indices_data;
RETURN_NOT_OK(GetTakeIndices(batch[1].array,
FilterState::Get(ctx).null_selection_behavior,
ctx->memory_pool())
.Value(&indices));
.Value(&indices_data));

KernelContext take_ctx(*ctx);
TakeState state{TakeOptions::NoBoundsCheck()};
take_ctx.SetState(&state);
ExecSpan take_batch({batch[0], ArraySpan(*indices)}, batch.length);
return take_exec(&take_ctx, take_batch, out);

ValuesSpan values(batch[0].array);
std::shared_ptr<ArrayData> out_data = out->array_data();
RETURN_NOT_OK(take_aaa_exec(&take_ctx, values, *indices_data, &out_data));
out->value = std::move(out_data);
return Status::OK();
}

// Due to the special treatment with their Take kernels, we filter Struct and SparseUnion
Expand Down
73 changes: 42 additions & 31 deletions cpp/src/arrow/compute/kernels/vector_selection_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -968,69 +968,80 @@ Status MapFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out)

namespace {

template <typename Impl>
Status TakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
template <typename SelectionImpl>
Status TakeAAAExec(KernelContext* ctx, const ValuesSpan& values, const ArraySpan& indices,
std::shared_ptr<ArrayData>* out) {
DCHECK(!values.is_chunked())
<< "TakeAAAExec kernels can't be called with chunked array values";
if (TakeState::Get(ctx).boundscheck) {
RETURN_NOT_OK(CheckIndexBounds(batch[1].array, batch[0].length()));
RETURN_NOT_OK(CheckIndexBounds(indices, values.length()));
}
Impl kernel(ctx, batch, /*output_length=*/batch[1].length(), out);
SelectionImpl kernel(ctx, values, indices, out);
return kernel.ExecTake();
}

} // namespace

Status VarBinaryTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
return TakeExec<VarBinarySelectionImpl<BinaryType>>(ctx, batch, out);
Status VarBinaryTakeExec(KernelContext* ctx, const ValuesSpan& values,
const ArraySpan& indices, std::shared_ptr<ArrayData>* out) {
return TakeAAAExec<VarBinarySelectionImpl<BinaryType>>(ctx, values, indices, out);
}

Status LargeVarBinaryTakeExec(KernelContext* ctx, const ExecSpan& batch,
ExecResult* out) {
return TakeExec<VarBinarySelectionImpl<LargeBinaryType>>(ctx, batch, out);
Status LargeVarBinaryTakeExec(KernelContext* ctx, const ValuesSpan& values,
const ArraySpan& indices, std::shared_ptr<ArrayData>* out) {
return TakeAAAExec<VarBinarySelectionImpl<LargeBinaryType>>(ctx, values, indices, out);
}

Status ListTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
return TakeExec<ListSelectionImpl<ListType>>(ctx, batch, out);
Status ListTakeExec(KernelContext* ctx, const ValuesSpan& values,
const ArraySpan& indices, std::shared_ptr<ArrayData>* out) {
return TakeAAAExec<ListSelectionImpl<ListType>>(ctx, values, indices, out);
}

Status LargeListTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
return TakeExec<ListSelectionImpl<LargeListType>>(ctx, batch, out);
Status LargeListTakeExec(KernelContext* ctx, const ValuesSpan& values,
const ArraySpan& indices, std::shared_ptr<ArrayData>* out) {
return TakeAAAExec<ListSelectionImpl<LargeListType>>(ctx, values, indices, out);
}

Status ListViewTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
return TakeExec<ListViewSelectionImpl<ListViewType>>(ctx, batch, out);
Status ListViewTakeExec(KernelContext* ctx, const ValuesSpan& values,
const ArraySpan& indices, std::shared_ptr<ArrayData>* out) {
return TakeAAAExec<ListViewSelectionImpl<ListViewType>>(ctx, values, indices, out);
}

Status LargeListViewTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
return TakeExec<ListViewSelectionImpl<LargeListViewType>>(ctx, batch, out);
Status LargeListViewTakeExec(KernelContext* ctx, const ValuesSpan& values,
const ArraySpan& indices, std::shared_ptr<ArrayData>* out) {
return TakeAAAExec<ListViewSelectionImpl<LargeListViewType>>(ctx, values, indices, out);
}

Status FSLTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
const ArraySpan& values = batch[0].array;

Status FSLTakeExec(KernelContext* ctx, const ValuesSpan& values, const ArraySpan& indices,
std::shared_ptr<ArrayData>* out) {
// If a FixedSizeList wraps a fixed-width type we can, in some cases, use
// FixedWidthTakeExec for a fixed-size list array.
if (util::IsFixedWidthLike(values,
if (util::IsFixedWidthLike(values.array(),
/*force_null_count=*/true,
/*exclude_bool_and_dictionary=*/true)) {
return FixedWidthTakeExec(ctx, batch, out);
return FixedWidthTakeExec(ctx, values, indices, out);
}
return TakeExec<FSLSelectionImpl>(ctx, batch, out);
return TakeAAAExec<FSLSelectionImpl>(ctx, values, indices, out);
}

Status DenseUnionTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
return TakeExec<DenseUnionSelectionImpl>(ctx, batch, out);
Status DenseUnionTakeExec(KernelContext* ctx, const ValuesSpan& values,
const ArraySpan& indices, std::shared_ptr<ArrayData>* out) {
return TakeAAAExec<DenseUnionSelectionImpl>(ctx, values, indices, out);
}

Status SparseUnionTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
return TakeExec<SparseUnionSelectionImpl>(ctx, batch, out);
Status SparseUnionTakeExec(KernelContext* ctx, const ValuesSpan& values,
const ArraySpan& indices, std::shared_ptr<ArrayData>* out) {
return TakeAAAExec<SparseUnionSelectionImpl>(ctx, values, indices, out);
}

Status StructTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
return TakeExec<StructSelectionImpl>(ctx, batch, out);
Status StructTakeExec(KernelContext* ctx, const ValuesSpan& values,
const ArraySpan& indices, std::shared_ptr<ArrayData>* out) {
return TakeAAAExec<StructSelectionImpl>(ctx, values, indices, out);
}

Status MapTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
return TakeExec<ListSelectionImpl<MapType>>(ctx, batch, out);
Status MapTakeExec(KernelContext* ctx, const ValuesSpan& values, const ArraySpan& indices,
std::shared_ptr<ArrayData>* out) {
return TakeAAAExec<ListSelectionImpl<MapType>>(ctx, values, indices, out);
}

} // namespace compute::internal
Expand Down
54 changes: 41 additions & 13 deletions cpp/src/arrow/compute/kernels/vector_selection_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,20 @@ class ValuesSpan {
}
};

/// \brief Type for a single "array_take" kernel function.
///
/// Instead of implementing both `ArrayKernelExec` and `ChunkedExec` typed
/// functions for each configuration of `array_take` parameters, we use
/// templates wrapping `TakeKernelExec` functions to expose exec functions
/// that can be registered in the kernel registry.
///
/// A `TakeKernelExec` always returns a single array, which is the result of
/// taking values from a single array (AA->A) or multiple arrays (CA->A). The
/// wrappers take care of converting the output of a CA call to C or calling
/// the kernel multiple times to process a CC call.
using TakeKernelExec = Status (*)(KernelContext*, const ValuesSpan&, const ArraySpan&,
std::shared_ptr<ArrayData>*);

struct SelectionKernelData {
SelectionKernelData(InputType value_type, InputType selection_type,
ArrayKernelExec exec,
Expand Down Expand Up @@ -149,19 +163,33 @@ Status FSLFilterExec(KernelContext*, const ExecSpan&, ExecResult*);
Status DenseUnionFilterExec(KernelContext*, const ExecSpan&, ExecResult*);
Status MapFilterExec(KernelContext*, const ExecSpan&, ExecResult*);

Status VarBinaryTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status LargeVarBinaryTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status FixedWidthTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status FixedWidthTakeChunkedExec(KernelContext*, const ExecBatch&, Datum*);
Status ListTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status LargeListTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status ListViewTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status LargeListViewTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status FSLTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status DenseUnionTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status SparseUnionTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status StructTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status MapTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
// Take kernels compatible with the TakeKernelExec signature
Status VarBinaryTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&,
std::shared_ptr<ArrayData>*);
Status LargeVarBinaryTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&,
std::shared_ptr<ArrayData>*);
Status FixedWidthTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&,
std::shared_ptr<ArrayData>*);
Status FixedWidthTakeChunkedExec(KernelContext*, const ValuesSpan&, const ArraySpan&,
std::shared_ptr<ArrayData>*);
Status ListTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&,
std::shared_ptr<ArrayData>*);
Status LargeListTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&,
std::shared_ptr<ArrayData>*);
Status ListViewTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&,
std::shared_ptr<ArrayData>*);
Status LargeListViewTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&,
std::shared_ptr<ArrayData>*);
Status FSLTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&,
std::shared_ptr<ArrayData>*);
Status DenseUnionTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&,
std::shared_ptr<ArrayData>*);
Status SparseUnionTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&,
std::shared_ptr<ArrayData>*);
Status StructTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&,
std::shared_ptr<ArrayData>*);
Status MapTakeExec(KernelContext*, const ValuesSpan&, const ArraySpan&,
std::shared_ptr<ArrayData>*);

} // namespace compute::internal
} // namespace arrow
Loading

0 comments on commit 72101ab

Please sign in to comment.