diff --git a/cpp/src/arrow/compute/kernels/filter.cc b/cpp/src/arrow/compute/kernels/filter.cc index a9efda8b18ea2..6a4e346511677 100644 --- a/cpp/src/arrow/compute/kernels/filter.cc +++ b/cpp/src/arrow/compute/kernels/filter.cc @@ -172,7 +172,9 @@ class FilterImpl : public FilterKernel { template <> class FilterImpl : public FilterKernel { public: - using FilterKernel::FilterKernel; + FilterImpl(const std::shared_ptr& type, + std::vector> child_kernels) + : FilterKernel(type), child_kernels_(std::move(child_kernels)) {} Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter, std::shared_ptr* out) override { @@ -185,7 +187,7 @@ class FilterImpl : public FilterKernel { ArrayVector fields(type_->num_children()); for (int i = 0; i < type_->num_children(); ++i) { RETURN_NOT_OK( - arrow::compute::Filter(ctx, *struct_array.field(i), filter, &fields[i])); + child_kernels_[i]->Filter(ctx, *struct_array.field(i), filter, &fields[i])); } for (int64_t i = 0; i < filter.length(); ++i) { @@ -210,6 +212,9 @@ class FilterImpl : public FilterKernel { out->reset(new StructArray(type_, length, fields, null_bitmap, null_count)); return Status::OK(); } + + private: + std::vector> child_kernels_; }; template <> @@ -331,10 +336,10 @@ class FilterImpl : public FilterImpl { using FilterImpl::FilterImpl; }; -class DictionaryFilterImpl : public FilterKernel { +template <> +class FilterImpl : public FilterKernel { public: - DictionaryFilterImpl(const std::shared_ptr& type, - std::unique_ptr impl) + FilterImpl(const std::shared_ptr& type, std::unique_ptr impl) : FilterKernel(type), impl_(std::move(impl)) {} Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter, @@ -351,10 +356,10 @@ class DictionaryFilterImpl : public FilterKernel { std::unique_ptr impl_; }; -class ExtensionFilterImpl : public FilterKernel { +template <> +class FilterImpl : public FilterKernel { public: - ExtensionFilterImpl(const std::shared_ptr& type, - std::unique_ptr impl) + FilterImpl(const std::shared_ptr& type, std::unique_ptr impl) : FilterKernel(type), impl_(std::move(impl)) {} Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter, @@ -379,13 +384,13 @@ Status FilterKernel::Make(const std::shared_ptr& value_type, *out = internal::make_unique>(value_type); \ return Status::OK() -#define SINGLE_CHILD_CASE(T, IMPL, CHILD_TYPE) \ - case T##Type::type_id: { \ - auto t = checked_pointer_cast(value_type); \ - std::unique_ptr child_filter_impl; \ - RETURN_NOT_OK(FilterKernel::Make(t->CHILD_TYPE(), &child_filter_impl)); \ - *out = internal::make_unique(t, std::move(child_filter_impl)); \ - return Status::OK(); \ +#define SINGLE_CHILD_CASE(T, CHILD_TYPE) \ + case T##Type::type_id: { \ + auto t = checked_pointer_cast(value_type); \ + std::unique_ptr child_filter_impl; \ + RETURN_NOT_OK(FilterKernel::Make(t->CHILD_TYPE(), &child_filter_impl)); \ + *out = internal::make_unique>(t, std::move(child_filter_impl)); \ + return Status::OK(); \ } NO_CHILD_CASE(Null); @@ -412,14 +417,23 @@ Status FilterKernel::Make(const std::shared_ptr& value_type, NO_CHILD_CASE(FixedSizeBinary); NO_CHILD_CASE(Decimal128); - SINGLE_CHILD_CASE(Dictionary, DictionaryFilterImpl, index_type); - SINGLE_CHILD_CASE(Extension, ExtensionFilterImpl, storage_type); + SINGLE_CHILD_CASE(Dictionary, index_type); + SINGLE_CHILD_CASE(Extension, storage_type); NO_CHILD_CASE(List); NO_CHILD_CASE(FixedSizeList); NO_CHILD_CASE(Map); - NO_CHILD_CASE(Struct); + case Type::STRUCT: { + std::vector> child_kernels; + for (auto child : value_type->children()) { + child_kernels.emplace_back(); + RETURN_NOT_OK(FilterKernel::Make(child->type(), &child_kernels.back())); + } + *out = internal::make_unique>(value_type, + std::move(child_kernels)); + return Status::OK(); + } #undef NO_CHILD_CASE #undef SINGLE_CHILD_CASE