From a6b210d5e8afdd0326ac9fa62dc5f4c0580a59cf Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 19 Jun 2019 21:19:48 -0500 Subject: [PATCH] ARROW-1558: [C++] Implement boolean filter (selection) kernel, rename comparison kernel-related functions Materializes an array masked by a selection array (for example one produced by the filter kernel) Author: Benjamin Kietzman Closes #4366 from bkietz/1558-Implement-boolean-selection-kernels and squashes the following commits: 032d341bc fix doc error 3d92b6e12 Make FilterKernel public e8465e5d2 iwyu: vector 030ac57ea filter benchmarks += MinTime(1.0) nanoseconds 770205535 use expanded bitmap for FixedSizeList and List 060313c47 refactor FilterImpl to own child kernels 24f2e852f add larger benchmarks to test for O(N^2) perf e4d9d85eb refactor FilterKernel::Make to use a switch f833e02fc add benchmark for fixed_size_list(int64(), 1) f424f34c5 fix nits and typos 3387f21b9 use new path for concatenate.h 495e5217b Add support for filtering MapArray a8cb993da fix lint error e3b402281 add filter impls for nested types a21638817 add explicit qualification for MSVC ccd32a532 add a basic filter benchmark 8a9f37934 add a test integrating with arrow::compute::Compare (array-array) 7c5002715 add a test integrating with arrow::compute::Compare 6efc4f55b add filter tests with large, random arrays 0f29ab27d rename Mask -> Filter edf2eb15d rename FilterFunction -> CompareFunction 4b24ca330 revert removal of TakeOptions 4c8ce6d9f revert submodule a54741ea2 add some tests with empty masks/take indices d5c9c14f6 use checked_cast 223a8605e fix typo c953dca1b remove empty TakeOptions db444242e remove empty MaskOptions 13a1969d2 initial mask kernel impl --- cpp/src/arrow/CMakeLists.txt | 2 +- cpp/src/arrow/array.cc | 169 +++++ cpp/src/arrow/array.h | 18 +- cpp/src/arrow/array/builder_primitive.cc | 7 + cpp/src/arrow/array/builder_primitive.h | 2 + cpp/src/arrow/compute/benchmark-util.h | 23 + cpp/src/arrow/compute/kernels/CMakeLists.txt | 8 +- .../compute/kernels/compare-benchmark.cc | 84 +++ cpp/src/arrow/compute/kernels/compare-test.cc | 390 +++++++++++ cpp/src/arrow/compute/kernels/compare.cc | 91 ++- cpp/src/arrow/compute/kernels/compare.h | 70 +- .../arrow/compute/kernels/filter-benchmark.cc | 70 +- cpp/src/arrow/compute/kernels/filter-test.cc | 606 +++++++++--------- cpp/src/arrow/compute/kernels/filter.cc | 455 ++++++++++++- cpp/src/arrow/compute/kernels/filter.h | 104 ++- cpp/src/arrow/compute/kernels/take-test.cc | 1 + cpp/src/arrow/compute/kernels/take.cc | 10 +- cpp/src/arrow/extension_type.h | 4 + 18 files changed, 1644 insertions(+), 470 deletions(-) create mode 100644 cpp/src/arrow/compute/kernels/compare-benchmark.cc create mode 100644 cpp/src/arrow/compute/kernels/compare-test.cc diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index c989f855a5b0b..1666d178fac7b 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -172,8 +172,8 @@ if(ARROW_COMPUTE) compute/kernels/cast.cc compute/kernels/compare.cc compute/kernels/count.cc - compute/kernels/filter.cc compute/kernels/hash.cc + compute/kernels/filter.cc compute/kernels/mean.cc compute/kernels/sum.cc compute/kernels/take.cc diff --git a/cpp/src/arrow/array.cc b/cpp/src/arrow/array.cc index 9d37b45914bd0..b41d6dda6816d 100644 --- a/cpp/src/arrow/array.cc +++ b/cpp/src/arrow/array.cc @@ -1369,6 +1369,175 @@ std::shared_ptr MakeArray(const std::shared_ptr& data) { namespace internal { +// get the maximum buffer length required, then allocate a single zeroed buffer +// to use anywhere a buffer is required +class NullArrayFactory { + public: + struct GetBufferLength { + GetBufferLength(const std::shared_ptr& type, int64_t length) + : type_(*type), length_(length), buffer_length_(BitUtil::BytesForBits(length)) {} + + operator int64_t() && { + DCHECK_OK(VisitTypeInline(type_, this)); + return buffer_length_; + } + + template ::bytes_required(0))> + Status Visit(const T&) { + return MaxOf(TypeTraits::bytes_required(length_)); + } + + Status Visit(const ListType& type) { + // list's values array may be empty, but there must be at least one offset of 0 + return MaxOf(sizeof(int32_t)); + } + + Status Visit(const FixedSizeListType& type) { + return MaxOf(GetBufferLength(type.value_type(), type.list_size() * length_)); + } + + Status Visit(const StructType& type) { + for (const auto& child : type.children()) { + DCHECK_OK(MaxOf(GetBufferLength(child->type(), length_))); + } + return Status::OK(); + } + + Status Visit(const UnionType& type) { + // type codes + DCHECK_OK(MaxOf(length_)); + if (type.mode() == UnionMode::DENSE) { + // offsets + DCHECK_OK(MaxOf(sizeof(int32_t) * length_)); + } + for (const auto& child : type.children()) { + DCHECK_OK(MaxOf(GetBufferLength(child->type(), length_))); + } + return Status::OK(); + } + + Status Visit(const DictionaryType& type) { + DCHECK_OK(MaxOf(GetBufferLength(type.value_type(), length_))); + return MaxOf(GetBufferLength(type.index_type(), length_)); + } + + Status Visit(const ExtensionType& type) { + // XXX is an extension array's length always == storage length + return MaxOf(GetBufferLength(type.storage_type(), length_)); + } + + Status Visit(const DataType& type) { + return Status::NotImplemented("construction of all-null ", type); + } + + private: + Status MaxOf(int64_t buffer_length) { + if (buffer_length > buffer_length_) { + buffer_length_ = buffer_length; + } + return Status::OK(); + } + + const DataType& type_; + int64_t length_, buffer_length_; + }; + + NullArrayFactory(const std::shared_ptr& type, int64_t length, + std::shared_ptr* out) + : type_(type), length_(length), out_(out) {} + + Status CreateBuffer() { + int64_t buffer_length = GetBufferLength(type_, length_); + RETURN_NOT_OK(AllocateBuffer(buffer_length, &buffer_)); + std::memset(buffer_->mutable_data(), 0, buffer_->size()); + return Status::OK(); + } + + Status Create() { + if (buffer_ == nullptr) { + RETURN_NOT_OK(CreateBuffer()); + } + std::vector> child_data(type_->num_children()); + *out_ = ArrayData::Make(type_, length_, {buffer_}, child_data, length_, 0); + return VisitTypeInline(*type_, this); + } + + Status Visit(const NullType&) { return Status::OK(); } + + Status Visit(const FixedWidthType&) { + (*out_)->buffers.resize(2, buffer_); + return Status::OK(); + } + + Status Visit(const BinaryType&) { + (*out_)->buffers.resize(3, buffer_); + return Status::OK(); + } + + Status Visit(const ListType& type) { + (*out_)->buffers.resize(2, buffer_); + return CreateChild(0, length_, &(*out_)->child_data[0]); + } + + Status Visit(const FixedSizeListType& type) { + return CreateChild(0, length_ * type.list_size(), &(*out_)->child_data[0]); + } + + Status Visit(const StructType& type) { + for (int i = 0; i < type_->num_children(); ++i) { + DCHECK_OK(CreateChild(i, length_, &(*out_)->child_data[i])); + } + return Status::OK(); + } + + Status Visit(const UnionType& type) { + if (type.mode() == UnionMode::DENSE) { + (*out_)->buffers.resize(3, buffer_); + } else { + (*out_)->buffers.resize(2, buffer_); + } + + for (int i = 0; i < type_->num_children(); ++i) { + DCHECK_OK(CreateChild(i, length_, &(*out_)->child_data[i])); + } + return Status::OK(); + } + + Status Visit(const DictionaryType& type) { + (*out_)->buffers.resize(2, buffer_); + std::shared_ptr dictionary_data; + return MakeArrayOfNull(type.value_type(), 0, &(*out_)->dictionary); + } + + Status Visit(const DataType& type) { + return Status::NotImplemented("construction of all-null ", type); + } + + Status CreateChild(int i, int64_t length, std::shared_ptr* out) { + NullArrayFactory child_factory(type_->child(i)->type(), length, + &(*out_)->child_data[i]); + child_factory.buffer_ = buffer_; + return child_factory.Create(); + } + + std::shared_ptr type_; + int64_t length_; + std::shared_ptr* out_; + std::shared_ptr buffer_; +}; + +} // namespace internal + +Status MakeArrayOfNull(const std::shared_ptr& type, int64_t length, + std::shared_ptr* out) { + std::shared_ptr out_data; + RETURN_NOT_OK(internal::NullArrayFactory(type, length, &out_data).Create()); + *out = MakeArray(out_data); + return Status::OK(); +} + +namespace internal { + std::vector RechunkArraysConsistently( const std::vector& groups) { if (groups.size() <= 1) { diff --git a/cpp/src/arrow/array.h b/cpp/src/arrow/array.h index a655422730238..78542c6b5bec4 100644 --- a/cpp/src/arrow/array.h +++ b/cpp/src/arrow/array.h @@ -220,6 +220,14 @@ struct ARROW_EXPORT ArrayData { ARROW_EXPORT std::shared_ptr MakeArray(const std::shared_ptr& data); +/// \brief Create a strongly-typed Array instance with all elements null +/// \param[in] type the array type +/// \param[in] length the array length +/// \param[out] out resulting Array instance +ARROW_EXPORT +Status MakeArrayOfNull(const std::shared_ptr& type, int64_t length, + std::shared_ptr* out); + // ---------------------------------------------------------------------- // User array accessor types @@ -521,12 +529,15 @@ class ARROW_EXPORT ListArray : public Array { /// Return pointer to raw value offsets accounting for any slice offset const int32_t* raw_value_offsets() const { return raw_value_offsets_ + data_->offset; } - // Neither of these functions will perform boundschecking + // The following functions will not perform boundschecking int32_t value_offset(int64_t i) const { return raw_value_offsets_[i + data_->offset]; } int32_t value_length(int64_t i) const { i += data_->offset; return raw_value_offsets_[i + 1] - raw_value_offsets_[i]; } + std::shared_ptr value_slice(int64_t i) const { + return values_->Slice(value_offset(i), value_length(i)); + } protected: // This constructor defers SetData to a derived array class @@ -596,12 +607,15 @@ class ARROW_EXPORT FixedSizeListArray : public Array { std::shared_ptr value_type() const; - // Neither of these functions will perform boundschecking + // The following functions will not perform boundschecking int32_t value_offset(int64_t i) const { i += data_->offset; return static_cast(list_size_ * i); } int32_t value_length(int64_t i = 0) const { return list_size_; } + std::shared_ptr value_slice(int64_t i) const { + return values_->Slice(value_offset(i), value_length(i)); + } protected: void SetData(const std::shared_ptr& data); diff --git a/cpp/src/arrow/array/builder_primitive.cc b/cpp/src/arrow/array/builder_primitive.cc index d4def92760027..3c899c068cb84 100644 --- a/cpp/src/arrow/array/builder_primitive.cc +++ b/cpp/src/arrow/array/builder_primitive.cc @@ -128,4 +128,11 @@ Status BooleanBuilder::AppendValues(const std::vector& values) { return Status::OK(); } +Status BooleanBuilder::AppendValues(int64_t length, bool value) { + RETURN_NOT_OK(Reserve(length)); + data_builder_.UnsafeAppend(length, value); + ArrayBuilder::UnsafeSetNotNull(length); + return Status::OK(); +} + } // namespace arrow diff --git a/cpp/src/arrow/array/builder_primitive.h b/cpp/src/arrow/array/builder_primitive.h index 3d566846d1947..8abbe029e1341 100644 --- a/cpp/src/arrow/array/builder_primitive.h +++ b/cpp/src/arrow/array/builder_primitive.h @@ -409,6 +409,8 @@ class ARROW_EXPORT BooleanBuilder : public ArrayBuilder { return Status::OK(); } + Status AppendValues(int64_t length, bool value); + Status FinishInternal(std::shared_ptr* out) override; /// \cond FALSE diff --git a/cpp/src/arrow/compute/benchmark-util.h b/cpp/src/arrow/compute/benchmark-util.h index ee9cb9504a3d1..113fdd7031281 100644 --- a/cpp/src/arrow/compute/benchmark-util.h +++ b/cpp/src/arrow/compute/benchmark-util.h @@ -70,5 +70,28 @@ void RegressionSetArgs(benchmark::internal::Benchmark* bench) { BenchmarkSetArgsWithSizes(bench, {kL1Size}); } +// RAII struct to handle some of the boilerplate in regression benchmarks +struct RegressionArgs { + // size of memory tested (per iteration) in bytes + const int64_t size; + + // proportion of nulls in generated arrays + const double null_proportion; + + explicit RegressionArgs(benchmark::State& state) + : size(state.range(0)), + null_proportion(static_cast(state.range(1)) / 100.0), + state_(state) {} + + ~RegressionArgs() { + state_.counters["size"] = static_cast(size); + state_.counters["null_percent"] = static_cast(state_.range(1)); + state_.SetBytesProcessed(state_.iterations() * size); + } + + private: + benchmark::State& state_; +}; + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index 6c386c9b39c0b..1bbb5bc5f321e 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -20,13 +20,17 @@ arrow_install_all_headers("arrow/compute/kernels") add_arrow_test(boolean-test PREFIX "arrow-compute") add_arrow_test(cast-test PREFIX "arrow-compute") add_arrow_test(hash-test PREFIX "arrow-compute") -add_arrow_test(take-test PREFIX "arrow-compute") add_arrow_test(util-internal-test PREFIX "arrow-compute") # Aggregates add_arrow_test(aggregate-test PREFIX "arrow-compute") add_arrow_benchmark(aggregate-benchmark PREFIX "arrow-compute") -# Filters +# Comparison +add_arrow_test(compare-test PREFIX "arrow-compute") +add_arrow_benchmark(compare-benchmark PREFIX "arrow-compute") + +# Selection +add_arrow_test(take-test PREFIX "arrow-compute") add_arrow_test(filter-test PREFIX "arrow-compute") add_arrow_benchmark(filter-benchmark PREFIX "arrow-compute") diff --git a/cpp/src/arrow/compute/kernels/compare-benchmark.cc b/cpp/src/arrow/compute/kernels/compare-benchmark.cc new file mode 100644 index 0000000000000..6983fe5cc1bff --- /dev/null +++ b/cpp/src/arrow/compute/kernels/compare-benchmark.cc @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "benchmark/benchmark.h" + +#include + +#include "arrow/compute/benchmark-util.h" +#include "arrow/compute/kernel.h" +#include "arrow/compute/kernels/compare.h" +#include "arrow/compute/test-util.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" + +namespace arrow { +namespace compute { + +constexpr auto kSeed = 0x94378165; + +static void CompareArrayScalarKernel(benchmark::State& state) { + const int64_t memory_size = state.range(0); + const int64_t array_size = memory_size / sizeof(int64_t); + const double null_percent = static_cast(state.range(1)) / 100.0; + auto rand = random::RandomArrayGenerator(kSeed); + auto array = std::static_pointer_cast>( + rand.Int64(array_size, -100, 100, null_percent)); + + CompareOptions ge{GREATER_EQUAL}; + + FunctionContext ctx; + for (auto _ : state) { + Datum out; + ABORT_NOT_OK(Compare(&ctx, Datum(array), Datum(int64_t(0)), ge, &out)); + benchmark::DoNotOptimize(out); + } + + state.counters["size"] = static_cast(memory_size); + state.counters["null_percent"] = static_cast(state.range(1)); + state.SetBytesProcessed(state.iterations() * array_size * sizeof(int64_t)); +} + +static void CompareArrayArrayKernel(benchmark::State& state) { + const int64_t memory_size = state.range(0); + const int64_t array_size = memory_size / sizeof(int64_t); + const double null_percent = static_cast(state.range(1)) / 100.0; + auto rand = random::RandomArrayGenerator(kSeed); + auto lhs = std::static_pointer_cast>( + rand.Int64(array_size, -100, 100, null_percent)); + auto rhs = std::static_pointer_cast>( + rand.Int64(array_size, -100, 100, null_percent)); + + CompareOptions ge(GREATER_EQUAL); + + FunctionContext ctx; + for (auto _ : state) { + Datum out; + ABORT_NOT_OK(Compare(&ctx, Datum(lhs), Datum(rhs), ge, &out)); + benchmark::DoNotOptimize(out); + } + + state.counters["size"] = static_cast(memory_size); + state.counters["null_percent"] = static_cast(state.range(1)); + state.SetBytesProcessed(state.iterations() * array_size * sizeof(int64_t) * 2); +} + +BENCHMARK(CompareArrayScalarKernel)->Apply(RegressionSetArgs); +BENCHMARK(CompareArrayArrayKernel)->Apply(RegressionSetArgs); + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/compare-test.cc b/cpp/src/arrow/compute/kernels/compare-test.cc new file mode 100644 index 0000000000000..1a6339c57d878 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/compare-test.cc @@ -0,0 +1,390 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include +#include + +#include + +#include "arrow/array.h" +#include "arrow/compute/kernel.h" +#include "arrow/compute/kernels/compare.h" +#include "arrow/compute/test-util.h" +#include "arrow/type.h" +#include "arrow/type_traits.h" +#include "arrow/util/checked_cast.h" + +#include "arrow/testing/gtest_common.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" + +namespace arrow { +namespace compute { + +TEST(TestComparatorOperator, BasicOperator) { + using T = int32_t; + std::vector vals{0, 1, 2, 3, 4, 5, 6}; + + for (int32_t i : vals) { + for (int32_t j : vals) { + EXPECT_EQ((Comparator::Compare(i, j)), i == j); + EXPECT_EQ((Comparator::Compare(i, j)), i != j); + EXPECT_EQ((Comparator::Compare(i, j)), i > j); + EXPECT_EQ((Comparator::Compare(i, j)), i >= j); + EXPECT_EQ((Comparator::Compare(i, j)), i < j); + EXPECT_EQ((Comparator::Compare(i, j)), i <= j); + } + } +} + +template +static void ValidateCompare(FunctionContext* ctx, CompareOptions options, + const Datum& lhs, const Datum& rhs, const Datum& expected) { + Datum result; + + ASSERT_OK(Compare(ctx, lhs, rhs, options, &result)); + AssertArraysEqual(*expected.make_array(), *result.make_array()); +} + +template +static void ValidateCompare(FunctionContext* ctx, CompareOptions options, + const char* lhs_str, const Datum& rhs, + const char* expected_str) { + auto lhs = ArrayFromJSON(TypeTraits::type_singleton(), lhs_str); + auto expected = ArrayFromJSON(TypeTraits::type_singleton(), expected_str); + ValidateCompare(ctx, options, lhs, rhs, expected); +} + +template +static void ValidateCompare(FunctionContext* ctx, CompareOptions options, + const Datum& lhs, const char* rhs_str, + const char* expected_str) { + auto rhs = ArrayFromJSON(TypeTraits::type_singleton(), rhs_str); + auto expected = ArrayFromJSON(TypeTraits::type_singleton(), expected_str); + ValidateCompare(ctx, options, lhs, rhs, expected); +} + +template +static void ValidateCompare(FunctionContext* ctx, CompareOptions options, + const char* lhs_str, const char* rhs_str, + const char* expected_str) { + auto lhs = ArrayFromJSON(TypeTraits::type_singleton(), lhs_str); + auto rhs = ArrayFromJSON(TypeTraits::type_singleton(), rhs_str); + auto expected = ArrayFromJSON(TypeTraits::type_singleton(), expected_str); + ValidateCompare(ctx, options, lhs, rhs, expected); +} + +template +static inline bool SlowCompare(CompareOperator op, const T& lhs, const T& rhs) { + switch (op) { + case EQUAL: + return lhs == rhs; + case NOT_EQUAL: + return lhs != rhs; + case GREATER: + return lhs > rhs; + case GREATER_EQUAL: + return lhs >= rhs; + case LESS: + return lhs < rhs; + case LESS_EQUAL: + return lhs <= rhs; + default: + return false; + } +} + +template +static Datum SimpleScalarArrayCompare(CompareOptions options, const Datum& lhs, + const Datum& rhs) { + using ArrayType = typename TypeTraits::ArrayType; + using ScalarType = typename TypeTraits::ScalarType; + using T = typename TypeTraits::CType; + + bool swap = lhs.is_array(); + auto array = std::static_pointer_cast((swap ? lhs : rhs).make_array()); + T value = std::static_pointer_cast((swap ? rhs : lhs).scalar())->value; + + std::vector bitmap(array->length()); + for (int64_t i = 0; i < array->length(); i++) { + bitmap[i] = swap ? SlowCompare(options.op, array->Value(i), value) + : SlowCompare(options.op, value, array->Value(i)); + } + + std::shared_ptr result; + + if (array->null_count() == 0) { + ArrayFromVector(bitmap, &result); + } else { + std::vector null_bitmap(array->length()); + auto reader = internal::BitmapReader(array->null_bitmap_data(), array->offset(), + array->length()); + for (int64_t i = 0; i < array->length(); i++, reader.Next()) { + null_bitmap[i] = reader.IsSet(); + } + ArrayFromVector(null_bitmap, bitmap, &result); + } + + return Datum(result); +} + +template ::ArrayType> +static std::vector NullBitmapFromArrays(const ArrayType& lhs, + const ArrayType& rhs) { + auto left_lambda = [&lhs](int64_t i) { + return lhs.null_count() == 0 ? true : lhs.IsValid(i); + }; + + auto right_lambda = [&rhs](int64_t i) { + return rhs.null_count() == 0 ? true : rhs.IsValid(i); + }; + + const int64_t length = lhs.length(); + std::vector null_bitmap(length); + + for (int64_t i = 0; i < length; i++) { + null_bitmap[i] = left_lambda(i) && right_lambda(i); + } + + return null_bitmap; +} + +template +static Datum SimpleArrayArrayCompare(CompareOptions options, const Datum& lhs, + const Datum& rhs) { + using ArrayType = typename TypeTraits::ArrayType; + using T = typename TypeTraits::CType; + + auto l_array = std::static_pointer_cast(lhs.make_array()); + auto r_array = std::static_pointer_cast(rhs.make_array()); + const int64_t length = l_array->length(); + + std::vector bitmap(length); + for (int64_t i = 0; i < length; i++) { + bitmap[i] = SlowCompare(options.op, l_array->Value(i), r_array->Value(i)); + } + + std::shared_ptr result; + + if (l_array->null_count() == 0 && r_array->null_count() == 0) { + ArrayFromVector(bitmap, &result); + } else { + std::vector null_bitmap = NullBitmapFromArrays(*l_array, *r_array); + ArrayFromVector(null_bitmap, bitmap, &result); + } + + return Datum(result); +} + +template +static void ValidateCompare(FunctionContext* ctx, CompareOptions options, + const Datum& lhs, const Datum& rhs) { + Datum result; + + bool has_scalar = lhs.is_scalar() || rhs.is_scalar(); + Datum expected = has_scalar ? SimpleScalarArrayCompare(options, lhs, rhs) + : SimpleArrayArrayCompare(options, lhs, rhs); + + ValidateCompare(ctx, options, lhs, rhs, expected); +} + +template +class TestNumericCompareKernel : public ComputeFixture, public TestBase {}; + +TYPED_TEST_CASE(TestNumericCompareKernel, NumericArrowTypes); +TYPED_TEST(TestNumericCompareKernel, SimpleCompareArrayScalar) { + using ScalarType = typename TypeTraits::ScalarType; + using CType = typename TypeTraits::CType; + + Datum one(std::make_shared(CType(1))); + + CompareOptions eq(CompareOperator::EQUAL); + ValidateCompare(&this->ctx_, eq, "[]", one, "[]"); + ValidateCompare(&this->ctx_, eq, "[null]", one, "[null]"); + ValidateCompare(&this->ctx_, eq, "[0,0,1,1,2,2]", one, "[0,0,1,1,0,0]"); + ValidateCompare(&this->ctx_, eq, "[0,1,2,3,4,5]", one, "[0,1,0,0,0,0]"); + ValidateCompare(&this->ctx_, eq, "[5,4,3,2,1,0]", one, "[0,0,0,0,1,0]"); + ValidateCompare(&this->ctx_, eq, "[null,0,1,1]", one, "[null,0,1,1]"); + + CompareOptions neq(CompareOperator::NOT_EQUAL); + ValidateCompare(&this->ctx_, neq, "[]", one, "[]"); + ValidateCompare(&this->ctx_, neq, "[null]", one, "[null]"); + ValidateCompare(&this->ctx_, neq, "[0,0,1,1,2,2]", one, "[1,1,0,0,1,1]"); + ValidateCompare(&this->ctx_, neq, "[0,1,2,3,4,5]", one, "[1,0,1,1,1,1]"); + ValidateCompare(&this->ctx_, neq, "[5,4,3,2,1,0]", one, "[1,1,1,1,0,1]"); + ValidateCompare(&this->ctx_, neq, "[null,0,1,1]", one, "[null,1,0,0]"); + + CompareOptions gt(CompareOperator::GREATER); + ValidateCompare(&this->ctx_, gt, "[]", one, "[]"); + ValidateCompare(&this->ctx_, gt, "[null]", one, "[null]"); + ValidateCompare(&this->ctx_, gt, "[0,0,1,1,2,2]", one, "[0,0,0,0,1,1]"); + ValidateCompare(&this->ctx_, gt, "[0,1,2,3,4,5]", one, "[0,0,1,1,1,1]"); + ValidateCompare(&this->ctx_, gt, "[4,5,6,7,8,9]", one, "[1,1,1,1,1,1]"); + ValidateCompare(&this->ctx_, gt, "[null,0,1,1]", one, "[null,0,0,0]"); + + CompareOptions gte(CompareOperator::GREATER_EQUAL); + ValidateCompare(&this->ctx_, gte, "[]", one, "[]"); + ValidateCompare(&this->ctx_, gte, "[null]", one, "[null]"); + ValidateCompare(&this->ctx_, gte, "[0,0,1,1,2,2]", one, "[0,0,1,1,1,1]"); + ValidateCompare(&this->ctx_, gte, "[0,1,2,3,4,5]", one, "[0,1,1,1,1,1]"); + ValidateCompare(&this->ctx_, gte, "[4,5,6,7,8,9]", one, "[1,1,1,1,1,1]"); + ValidateCompare(&this->ctx_, gte, "[null,0,1,1]", one, "[null,0,1,1]"); + + CompareOptions lt(CompareOperator::LESS); + ValidateCompare(&this->ctx_, lt, "[]", one, "[]"); + ValidateCompare(&this->ctx_, lt, "[null]", one, "[null]"); + ValidateCompare(&this->ctx_, lt, "[0,0,1,1,2,2]", one, "[1,1,0,0,0,0]"); + ValidateCompare(&this->ctx_, lt, "[0,1,2,3,4,5]", one, "[1,0,0,0,0,0]"); + ValidateCompare(&this->ctx_, lt, "[4,5,6,7,8,9]", one, "[0,0,0,0,0,0]"); + ValidateCompare(&this->ctx_, lt, "[null,0,1,1]", one, "[null,1,0,0]"); + + CompareOptions lte(CompareOperator::LESS_EQUAL); + ValidateCompare(&this->ctx_, lte, "[]", one, "[]"); + ValidateCompare(&this->ctx_, lte, "[null]", one, "[null]"); + ValidateCompare(&this->ctx_, lte, "[0,0,1,1,2,2]", one, "[1,1,1,1,0,0]"); + ValidateCompare(&this->ctx_, lte, "[0,1,2,3,4,5]", one, "[1,1,0,0,0,0]"); + ValidateCompare(&this->ctx_, lte, "[4,5,6,7,8,9]", one, "[0,0,0,0,0,0]"); + ValidateCompare(&this->ctx_, lte, "[null,0,1,1]", one, "[null,1,1,1]"); +} + +TYPED_TEST(TestNumericCompareKernel, SimpleCompareScalarArray) { + using ScalarType = typename TypeTraits::ScalarType; + using CType = typename TypeTraits::CType; + + Datum one(std::make_shared(CType(1))); + + CompareOptions eq(CompareOperator::EQUAL); + ValidateCompare(&this->ctx_, eq, one, "[]", "[]"); + ValidateCompare(&this->ctx_, eq, one, "[null]", "[null]"); + ValidateCompare(&this->ctx_, eq, one, "[0,0,1,1,2,2]", "[0,0,1,1,0,0]"); + ValidateCompare(&this->ctx_, eq, one, "[0,1,2,3,4,5]", "[0,1,0,0,0,0]"); + ValidateCompare(&this->ctx_, eq, one, "[5,4,3,2,1,0]", "[0,0,0,0,1,0]"); + ValidateCompare(&this->ctx_, eq, one, "[null,0,1,1]", "[null,0,1,1]"); + + CompareOptions neq(CompareOperator::NOT_EQUAL); + ValidateCompare(&this->ctx_, neq, one, "[]", "[]"); + ValidateCompare(&this->ctx_, neq, one, "[null]", "[null]"); + ValidateCompare(&this->ctx_, neq, one, "[0,0,1,1,2,2]", "[1,1,0,0,1,1]"); + ValidateCompare(&this->ctx_, neq, one, "[0,1,2,3,4,5]", "[1,0,1,1,1,1]"); + ValidateCompare(&this->ctx_, neq, one, "[5,4,3,2,1,0]", "[1,1,1,1,0,1]"); + ValidateCompare(&this->ctx_, neq, one, "[null,0,1,1]", "[null,1,0,0]"); + + CompareOptions gt(CompareOperator::GREATER); + ValidateCompare(&this->ctx_, gt, one, "[]", "[]"); + ValidateCompare(&this->ctx_, gt, one, "[null]", "[null]"); + ValidateCompare(&this->ctx_, gt, one, "[0,0,1,1,2,2]", "[1,1,0,0,0,0]"); + ValidateCompare(&this->ctx_, gt, one, "[0,1,2,3,4,5]", "[1,0,0,0,0,0]"); + ValidateCompare(&this->ctx_, gt, one, "[4,5,6,7,8,9]", "[0,0,0,0,0,0]"); + ValidateCompare(&this->ctx_, gt, one, "[null,0,1,1]", "[null,1,0,0]"); + + CompareOptions gte(CompareOperator::GREATER_EQUAL); + ValidateCompare(&this->ctx_, gte, one, "[]", "[]"); + ValidateCompare(&this->ctx_, gte, one, "[null]", "[null]"); + ValidateCompare(&this->ctx_, gte, one, "[0,0,1,1,2,2]", "[1,1,1,1,0,0]"); + ValidateCompare(&this->ctx_, gte, one, "[0,1,2,3,4,5]", "[1,1,0,0,0,0]"); + ValidateCompare(&this->ctx_, gte, one, "[4,5,6,7,8,9]", "[0,0,0,0,0,0]"); + ValidateCompare(&this->ctx_, gte, one, "[null,0,1,1]", "[null,1,1,1]"); + + CompareOptions lt(CompareOperator::LESS); + ValidateCompare(&this->ctx_, lt, one, "[]", "[]"); + ValidateCompare(&this->ctx_, lt, one, "[null]", "[null]"); + ValidateCompare(&this->ctx_, lt, one, "[0,0,1,1,2,2]", "[0,0,0,0,1,1]"); + ValidateCompare(&this->ctx_, lt, one, "[0,1,2,3,4,5]", "[0,0,1,1,1,1]"); + ValidateCompare(&this->ctx_, lt, one, "[4,5,6,7,8,9]", "[1,1,1,1,1,1]"); + ValidateCompare(&this->ctx_, lt, one, "[null,0,1,1]", "[null,0,0,0]"); + + CompareOptions lte(CompareOperator::LESS_EQUAL); + ValidateCompare(&this->ctx_, lte, one, "[]", "[]"); + ValidateCompare(&this->ctx_, lte, one, "[null]", "[null]"); + ValidateCompare(&this->ctx_, lte, one, "[0,0,1,1,2,2]", "[0,0,1,1,1,1]"); + ValidateCompare(&this->ctx_, lte, one, "[0,1,2,3,4,5]", "[0,1,1,1,1,1]"); + ValidateCompare(&this->ctx_, lte, one, "[4,5,6,7,8,9]", "[1,1,1,1,1,1]"); + ValidateCompare(&this->ctx_, lte, one, "[null,0,1,1]", "[null,0,1,1]"); +} + +TYPED_TEST(TestNumericCompareKernel, TestNullScalar) { + /* Ensure that null scalar broadcast to all null results. */ + using ScalarType = typename TypeTraits::ScalarType; + using CType = typename TypeTraits::CType; + + Datum null(std::make_shared(CType(0), false)); + EXPECT_FALSE(null.scalar()->is_valid); + + CompareOptions eq(CompareOperator::EQUAL); + ValidateCompare(&this->ctx_, eq, "[]", null, "[]"); + ValidateCompare(&this->ctx_, eq, null, "[]", "[]"); + ValidateCompare(&this->ctx_, eq, "[null]", null, "[null]"); + ValidateCompare(&this->ctx_, eq, null, "[null]", "[null]"); + ValidateCompare(&this->ctx_, eq, null, "[1,2,3]", "[null, null, null]"); +} + +TYPED_TEST_CASE(TestNumericCompareKernel, NumericArrowTypes); +TYPED_TEST(TestNumericCompareKernel, RandomCompareArrayScalar) { + using ScalarType = typename TypeTraits::ScalarType; + using CType = typename TypeTraits::CType; + + auto rand = random::RandomArrayGenerator(0x5416447); + for (size_t i = 3; i < 13; i++) { + for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) { + for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) { + const int64_t length = static_cast(1ULL << i); + auto array = Datum(rand.Numeric(length, 0, 100, null_probability)); + auto fifty = Datum(std::make_shared(CType(50))); + auto options = CompareOptions(op); + ValidateCompare(&this->ctx_, options, array, fifty); + ValidateCompare(&this->ctx_, options, fifty, array); + } + } + } +} + +TYPED_TEST(TestNumericCompareKernel, SimpleCompareArrayArray) { + /* Ensure that null scalar broadcast to all null results. */ + CompareOptions eq(CompareOperator::EQUAL); + ValidateCompare(&this->ctx_, eq, "[]", "[]", "[]"); + ValidateCompare(&this->ctx_, eq, "[null]", "[null]", "[null]"); + ValidateCompare(&this->ctx_, eq, "[1]", "[1]", "[1]"); + ValidateCompare(&this->ctx_, eq, "[1]", "[2]", "[0]"); + ValidateCompare(&this->ctx_, eq, "[null]", "[1]", "[null]"); + ValidateCompare(&this->ctx_, eq, "[1]", "[null]", "[null]"); + + CompareOptions lte(CompareOperator::LESS_EQUAL); + ValidateCompare(&this->ctx_, lte, "[1,2,3,4,5]", "[2,3,4,5,6]", + "[1,1,1,1,1]"); +} + +TYPED_TEST(TestNumericCompareKernel, RandomCompareArrayArray) { + auto rand = random::RandomArrayGenerator(0x5416447); + for (size_t i = 3; i < 5; i++) { + for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) { + for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) { + const int64_t length = static_cast(1ULL << i); + auto lhs = Datum(rand.Numeric(length << i, 0, 100, null_probability)); + auto rhs = Datum(rand.Numeric(length << i, 0, 100, null_probability)); + auto options = CompareOptions(op); + ValidateCompare(&this->ctx_, options, lhs, rhs); + } + } + } +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/compare.cc b/cpp/src/arrow/compute/kernels/compare.cc index 040793f4e6569..f5fab7d0fe2a3 100644 --- a/cpp/src/arrow/compute/kernels/compare.cc +++ b/cpp/src/arrow/compute/kernels/compare.cc @@ -19,7 +19,6 @@ #include "arrow/compute/context.h" #include "arrow/compute/kernel.h" -#include "arrow/compute/kernels/filter.h" #include "arrow/compute/kernels/util-internal.h" #include "arrow/util/bit-util.h" #include "arrow/util/logging.h" @@ -28,8 +27,35 @@ namespace arrow { namespace compute { -class FunctionContext; -struct Datum; +std::shared_ptr CompareBinaryKernel::out_type() const { + return compare_function_->out_type(); +} + +Status CompareBinaryKernel::Call(FunctionContext* ctx, const Datum& left, + const Datum& right, Datum* out) { + DCHECK(left.type()->Equals(right.type())); + + auto lk = left.kind(); + auto rk = right.kind(); + auto out_array = out->array(); + + if (lk == Datum::ARRAY && rk == Datum::SCALAR) { + auto array = left.array(); + auto scalar = right.scalar(); + return compare_function_->Compare(*array, *scalar, &out_array); + } else if (lk == Datum::SCALAR && rk == Datum::ARRAY) { + auto scalar = left.scalar(); + auto array = right.array(); + auto out_array = out->array(); + return compare_function_->Compare(*scalar, *array, &out_array); + } else if (lk == Datum::ARRAY && rk == Datum::ARRAY) { + auto lhs = left.array(); + auto rhs = right.array(); + return compare_function_->Compare(*lhs, *rhs, &out_array); + } + + return Status::Invalid("Invalid datum signature for CompareBinaryKernel"); +} template ::ScalarType, @@ -76,13 +102,14 @@ static Status CompareArrayArray(const ArrayData& lhs, const ArrayData& rhs, } template -class CompareFunction final : public FilterFunction { +class CompareFunctionImpl final : public CompareFunction { + using ArrayType = typename TypeTraits::ArrayType; using ScalarType = typename TypeTraits::ScalarType; public: - explicit CompareFunction(FunctionContext* ctx) : ctx_(ctx) {} + explicit CompareFunctionImpl(FunctionContext* ctx) : ctx_(ctx) {} - Status Filter(const ArrayData& array, const Scalar& scalar, ArrayData* output) const { + Status Compare(const ArrayData& array, const Scalar& scalar, ArrayData* output) const { // Caller must cast DCHECK(array.type->Equals(scalar.type)); // Output must be a boolean array @@ -103,7 +130,7 @@ class CompareFunction final : public FilterFunction { array, static_cast(scalar), bitmap_result); } - Status Filter(const Scalar& scalar, const ArrayData& array, ArrayData* output) const { + Status Compare(const Scalar& scalar, const ArrayData& array, ArrayData* output) const { // Caller must cast DCHECK(array.type->Equals(scalar.type)); // Output must be a boolean array @@ -124,7 +151,7 @@ class CompareFunction final : public FilterFunction { array, bitmap_result); } - Status Filter(const ArrayData& lhs, const ArrayData& rhs, ArrayData* output) const { + Status Compare(const ArrayData& lhs, const ArrayData& rhs, ArrayData* output) const { // Caller must cast DCHECK(lhs.type->Equals(rhs.type)); // Output must be a boolean array @@ -146,13 +173,13 @@ class CompareFunction final : public FilterFunction { }; template -static inline std::shared_ptr MakeCompareFunctionTypeOp( +static inline std::shared_ptr MakeCompareFunctionTypeOp( FunctionContext* ctx) { - return std::make_shared>(ctx); + return std::make_shared>(ctx); } template -static inline std::shared_ptr MakeCompareFilterFunctionType( +static inline std::shared_ptr MakeCompareFunctionType( FunctionContext* ctx, struct CompareOptions options) { switch (options.op) { case CompareOperator::EQUAL: @@ -172,40 +199,40 @@ static inline std::shared_ptr MakeCompareFilterFunctionType( return nullptr; } -std::shared_ptr MakeCompareFilterFunction(FunctionContext* ctx, - const DataType& type, - struct CompareOptions options) { +std::shared_ptr MakeCompareFunction(FunctionContext* ctx, + const DataType& type, + struct CompareOptions options) { switch (type.id()) { case UInt8Type::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case Int8Type::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case UInt16Type::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case Int16Type::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case UInt32Type::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case Int32Type::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case UInt64Type::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case Int64Type::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case FloatType::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case DoubleType::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case Date32Type::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case Date64Type::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case TimestampType::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case Time32Type::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case Time64Type::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); default: return nullptr; } @@ -219,15 +246,15 @@ Status Compare(FunctionContext* context, const Datum& left, const Datum& right, auto type = left.type(); DCHECK(type->Equals(right.type())); // Requires that both types are equal. - auto fn = MakeCompareFilterFunction(context, *type, options); + auto fn = MakeCompareFunction(context, *type, options); if (fn == nullptr) { return Status::NotImplemented("Compare not implemented for type ", type->ToString()); } - FilterBinaryKernel filter_kernel(fn); + CompareBinaryKernel filter_kernel(fn); detail::PrimitiveAllocatingBinaryKernel kernel(&filter_kernel); - const int64_t length = FilterBinaryKernel::out_length(left, right); + const int64_t length = CompareBinaryKernel::out_length(left, right); out->value = ArrayData::Make(filter_kernel.out_type(), length); return kernel.Call(context, left, right, out); diff --git a/cpp/src/arrow/compute/kernels/compare.h b/cpp/src/arrow/compute/kernels/compare.h index a192451291677..b4c9612ae8caa 100644 --- a/cpp/src/arrow/compute/kernels/compare.h +++ b/cpp/src/arrow/compute/kernels/compare.h @@ -19,6 +19,7 @@ #include +#include "arrow/compute/kernel.h" #include "arrow/util/visibility.h" namespace arrow { @@ -31,9 +32,68 @@ class Status; namespace compute { struct Datum; -class FilterFunction; class FunctionContext; +/// CompareFunction is an interface for Comparisons +/// +/// Comparisons take an array and emits a selection vector. The selection vector +/// is given in the form of a bitmask as a BooleanArray result. +class ARROW_EXPORT CompareFunction { + public: + /// Compare an array with a scalar argument. + virtual Status Compare(const ArrayData& array, const Scalar& scalar, + ArrayData* output) const = 0; + + Status Compare(const ArrayData& array, const Scalar& scalar, + std::shared_ptr* output) { + return Compare(array, scalar, output->get()); + } + + virtual Status Compare(const Scalar& scalar, const ArrayData& array, + ArrayData* output) const = 0; + + Status Compare(const Scalar& scalar, const ArrayData& array, + std::shared_ptr* output) { + return Compare(scalar, array, output->get()); + } + + /// Compare an array with an array argument. + virtual Status Compare(const ArrayData& lhs, const ArrayData& rhs, + ArrayData* output) const = 0; + + Status Compare(const ArrayData& lhs, const ArrayData& rhs, + std::shared_ptr* output) { + return Compare(lhs, rhs, output->get()); + } + + /// By default, CompareFunction emits a result bitmap. + virtual std::shared_ptr out_type() const { return boolean(); } + + virtual ~CompareFunction() {} +}; + +/// \brief BinaryKernel bound to a select function +class ARROW_EXPORT CompareBinaryKernel : public BinaryKernel { + public: + explicit CompareBinaryKernel(std::shared_ptr& select) + : compare_function_(select) {} + + Status Call(FunctionContext* ctx, const Datum& left, const Datum& right, + Datum* out) override; + + static int64_t out_length(const Datum& left, const Datum& right) { + if (left.kind() == Datum::ARRAY) return left.length(); + if (right.kind() == Datum::ARRAY) return right.length(); + + return 0; + } + + std::shared_ptr out_type() const override; + + private: + std::shared_ptr compare_function_; +}; + enum CompareOperator { EQUAL, NOT_EQUAL, @@ -82,7 +142,7 @@ struct CompareOptions { enum CompareOperator op; }; -/// \brief Return a Compare FilterFunction +/// \brief Return a Compare CompareFunction /// /// \param[in] context FunctionContext passing context information /// \param[in] type required to specialize the kernel @@ -91,9 +151,9 @@ struct CompareOptions { /// \since 0.14.0 /// \note API not yet finalized ARROW_EXPORT -std::shared_ptr MakeCompareFilterFunction(FunctionContext* context, - const DataType& type, - struct CompareOptions options); +std::shared_ptr MakeCompareFunction(FunctionContext* context, + const DataType& type, + struct CompareOptions options); /// \brief Compare a numeric array with a scalar. /// diff --git a/cpp/src/arrow/compute/kernels/filter-benchmark.cc b/cpp/src/arrow/compute/kernels/filter-benchmark.cc index 00de199bfe90d..3eb460adc02d4 100644 --- a/cpp/src/arrow/compute/kernels/filter-benchmark.cc +++ b/cpp/src/arrow/compute/kernels/filter-benchmark.cc @@ -17,11 +17,9 @@ #include "benchmark/benchmark.h" -#include +#include "arrow/compute/kernels/filter.h" #include "arrow/compute/benchmark-util.h" -#include "arrow/compute/kernel.h" -#include "arrow/compute/kernels/compare.h" #include "arrow/compute/test-util.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" @@ -29,54 +27,60 @@ namespace arrow { namespace compute { -static void CompareArrayScalarKernel(benchmark::State& state) { - const int64_t memory_size = state.range(0) / 4; - const int64_t array_size = memory_size / sizeof(int64_t); - const double null_percent = static_cast(state.range(1)) / 100.0; - auto rand = random::RandomArrayGenerator(0x94378165); - auto array = std::static_pointer_cast>( - rand.Int64(array_size, -100, 100, null_percent)); +constexpr auto kSeed = 0x0ff1ce; + +static void FilterInt64(benchmark::State& state) { + RegressionArgs args(state); - CompareOptions ge{GREATER_EQUAL}; + const int64_t array_size = args.size / sizeof(int64_t); + auto rand = random::RandomArrayGenerator(kSeed); + auto array = std::static_pointer_cast>( + rand.Int64(array_size, -100, 100, args.null_proportion)); + auto filter = std::static_pointer_cast( + rand.Boolean(array_size, 0.75, args.null_proportion)); FunctionContext ctx; for (auto _ : state) { Datum out; - ABORT_NOT_OK(Compare(&ctx, Datum(array), Datum(int64_t(0)), ge, &out)); + ABORT_NOT_OK(Filter(&ctx, Datum(array), Datum(filter), &out)); benchmark::DoNotOptimize(out); } - - state.counters["size"] = static_cast(memory_size); - state.counters["null_percent"] = static_cast(state.range(1)); - state.SetBytesProcessed(state.iterations() * array_size * sizeof(int64_t)); } -static void CompareArrayArrayKernel(benchmark::State& state) { - const int64_t memory_size = state.range(0) / 4; - const int64_t array_size = memory_size / sizeof(int64_t); - const double null_percent = static_cast(state.range(1)) / 100.0; - auto rand = random::RandomArrayGenerator(0x94378165); - auto lhs = std::static_pointer_cast>( - rand.Int64(array_size, -100, 100, null_percent)); - auto rhs = std::static_pointer_cast>( - rand.Int64(array_size, -100, 100, null_percent)); +static void FilterFixedSizeList1Int64(benchmark::State& state) { + RegressionArgs args(state); - CompareOptions ge(GREATER_EQUAL); + const int64_t array_size = args.size / sizeof(int64_t); + auto rand = random::RandomArrayGenerator(kSeed); + auto int_array = std::static_pointer_cast>( + rand.Int64(array_size, -100, 100, args.null_proportion)); + auto array = std::make_shared( + fixed_size_list(int64(), 1), array_size, int_array, int_array->null_bitmap(), + int_array->null_count()); + auto filter = std::static_pointer_cast( + rand.Boolean(array_size, 0.75, args.null_proportion)); FunctionContext ctx; for (auto _ : state) { Datum out; - ABORT_NOT_OK(Compare(&ctx, Datum(lhs), Datum(rhs), ge, &out)); + ABORT_NOT_OK(Filter(&ctx, Datum(array), Datum(filter), &out)); benchmark::DoNotOptimize(out); } - - state.counters["size"] = static_cast(memory_size); - state.counters["null_percent"] = static_cast(state.range(1)); - state.SetBytesProcessed(state.iterations() * array_size * sizeof(int64_t) * 2); } -BENCHMARK(CompareArrayScalarKernel)->Apply(RegressionSetArgs); -BENCHMARK(CompareArrayArrayKernel)->Apply(RegressionSetArgs); +BENCHMARK(FilterInt64) + ->Apply(RegressionSetArgs) + ->Args({1 << 20, 1}) + ->Args({1 << 23, 1}) + ->MinTime(1.0) + ->Unit(benchmark::TimeUnit::kNanosecond); + +BENCHMARK(FilterFixedSizeList1Int64) + ->Apply(RegressionSetArgs) + ->Args({1 << 20, 1}) + ->Args({1 << 23, 1}) + ->MinTime(1.0) + ->Unit(benchmark::TimeUnit::kNanosecond); } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/filter-test.cc b/cpp/src/arrow/compute/kernels/filter-test.cc index 1c8967cf7aca7..7b349492b1daa 100644 --- a/cpp/src/arrow/compute/kernels/filter-test.cc +++ b/cpp/src/arrow/compute/kernels/filter-test.cc @@ -15,376 +15,358 @@ // specific language governing permissions and limitations // under the License. -#include #include -#include -#include -#include +#include -#include - -#include "arrow/array.h" -#include "arrow/compute/kernel.h" +#include "arrow/compute/context.h" +#include "arrow/compute/kernels/boolean.h" #include "arrow/compute/kernels/compare.h" #include "arrow/compute/kernels/filter.h" #include "arrow/compute/test-util.h" -#include "arrow/type.h" -#include "arrow/type_traits.h" -#include "arrow/util/checked_cast.h" - #include "arrow/testing/gtest_common.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" +#include "arrow/testing/util.h" namespace arrow { namespace compute { -TEST(TestComparatorOperator, BasicOperator) { - using T = int32_t; - std::vector vals{0, 1, 2, 3, 4, 5, 6}; - - for (int32_t i : vals) { - for (int32_t j : vals) { - EXPECT_EQ((Comparator::Compare(i, j)), i == j); - EXPECT_EQ((Comparator::Compare(i, j)), i != j); - EXPECT_EQ((Comparator::Compare(i, j)), i > j); - EXPECT_EQ((Comparator::Compare(i, j)), i >= j); - EXPECT_EQ((Comparator::Compare(i, j)), i < j); - EXPECT_EQ((Comparator::Compare(i, j)), i <= j); +using internal::checked_pointer_cast; +using util::string_view; + +template +class TestFilterKernel : public ComputeFixture, public TestBase { + protected: + void AssertFilterArrays(const std::shared_ptr& values, + const std::shared_ptr& filter, + const std::shared_ptr& expected) { + std::shared_ptr actual; + ASSERT_OK(arrow::compute::Filter(&this->ctx_, *values, *filter, &actual)); + AssertArraysEqual(*expected, *actual); + } + void AssertFilter(const std::shared_ptr& type, const std::string& values, + const std::string& filter, const std::string& expected) { + std::shared_ptr actual; + ASSERT_OK(this->Filter(type, values, filter, &actual)); + AssertArraysEqual(*ArrayFromJSON(type, expected), *actual); + } + Status Filter(const std::shared_ptr& type, const std::string& values, + const std::string& filter, std::shared_ptr* out) { + return arrow::compute::Filter(&this->ctx_, *ArrayFromJSON(type, values), + *ArrayFromJSON(boolean(), filter), out); + } + void ValidateFilter(const std::shared_ptr& values, + const std::shared_ptr& filter_boxed) { + std::shared_ptr filtered; + ASSERT_OK(arrow::compute::Filter(&this->ctx_, *values, *filter_boxed, &filtered)); + + auto filter = checked_pointer_cast(filter_boxed); + int64_t values_i = 0, filtered_i = 0; + for (; values_i < values->length(); ++values_i, ++filtered_i) { + if (filter->IsNull(values_i)) { + ASSERT_LT(filtered_i, filtered->length()); + ASSERT_TRUE(filtered->IsNull(filtered_i)); + continue; + } + if (!filter->Value(values_i)) { + // this element was filtered out; don't examine filtered + --filtered_i; + continue; + } + ASSERT_LT(filtered_i, filtered->length()); + ASSERT_TRUE(values->RangeEquals(values_i, values_i + 1, filtered_i, filtered)); } + ASSERT_EQ(filtered_i, filtered->length()); } -} +}; -template -static void ValidateCompare(FunctionContext* ctx, CompareOptions options, - const Datum& lhs, const Datum& rhs, const Datum& expected) { - Datum result; +class TestFilterKernelWithNull : public TestFilterKernel { + protected: + void AssertFilter(const std::string& values, const std::string& filter, + const std::string& expected) { + TestFilterKernel::AssertFilter(utf8(), values, filter, expected); + } +}; - ASSERT_OK(Compare(ctx, lhs, rhs, options, &result)); - AssertArraysEqual(*expected.make_array(), *result.make_array()); +TEST_F(TestFilterKernelWithNull, FilterNull) { + this->AssertFilter("[null, null, null]", "[0, 1, 0]", "[null]"); + this->AssertFilter("[null, null, null]", "[1, 1, 0]", "[null, null]"); } -template -static void ValidateCompare(FunctionContext* ctx, CompareOptions options, - const char* lhs_str, const Datum& rhs, - const char* expected_str) { - auto lhs = ArrayFromJSON(TypeTraits::type_singleton(), lhs_str); - auto expected = ArrayFromJSON(TypeTraits::type_singleton(), expected_str); - ValidateCompare(ctx, options, lhs, rhs, expected); -} +class TestFilterKernelWithBoolean : public TestFilterKernel { + protected: + void AssertFilter(const std::string& values, const std::string& filter, + const std::string& expected) { + TestFilterKernel::AssertFilter(boolean(), values, filter, expected); + } +}; -template -static void ValidateCompare(FunctionContext* ctx, CompareOptions options, - const Datum& lhs, const char* rhs_str, - const char* expected_str) { - auto rhs = ArrayFromJSON(TypeTraits::type_singleton(), rhs_str); - auto expected = ArrayFromJSON(TypeTraits::type_singleton(), expected_str); - ValidateCompare(ctx, options, lhs, rhs, expected); +TEST_F(TestFilterKernelWithBoolean, FilterBoolean) { + this->AssertFilter("[true, false, true]", "[0, 1, 0]", "[false]"); + this->AssertFilter("[null, false, true]", "[0, 1, 0]", "[false]"); + this->AssertFilter("[true, false, true]", "[null, 1, 0]", "[null, false]"); } template -static void ValidateCompare(FunctionContext* ctx, CompareOptions options, - const char* lhs_str, const char* rhs_str, - const char* expected_str) { - auto lhs = ArrayFromJSON(TypeTraits::type_singleton(), lhs_str); - auto rhs = ArrayFromJSON(TypeTraits::type_singleton(), rhs_str); - auto expected = ArrayFromJSON(TypeTraits::type_singleton(), expected_str); - ValidateCompare(ctx, options, lhs, rhs, expected); -} - -template -static inline bool SlowCompare(CompareOperator op, const T& lhs, const T& rhs) { - switch (op) { - case EQUAL: - return lhs == rhs; - case NOT_EQUAL: - return lhs != rhs; - case GREATER: - return lhs > rhs; - case GREATER_EQUAL: - return lhs >= rhs; - case LESS: - return lhs < rhs; - case LESS_EQUAL: - return lhs <= rhs; - default: - return false; +class TestFilterKernelWithNumeric : public TestFilterKernel { + protected: + void AssertFilter(const std::string& values, const std::string& filter, + const std::string& expected) { + TestFilterKernel::AssertFilter(type_singleton(), values, filter, expected); } -} - -template -static Datum SimpleScalarArrayCompare(CompareOptions options, const Datum& lhs, - const Datum& rhs) { - using ArrayType = typename TypeTraits::ArrayType; - using ScalarType = typename TypeTraits::ScalarType; - using T = typename TypeTraits::CType; - - bool swap = lhs.is_array(); - auto array = std::static_pointer_cast((swap ? lhs : rhs).make_array()); - T value = std::static_pointer_cast((swap ? rhs : lhs).scalar())->value; - - std::vector bitmap(array->length()); - for (int64_t i = 0; i < array->length(); i++) { - bitmap[i] = swap ? SlowCompare(options.op, array->Value(i), value) - : SlowCompare(options.op, value, array->Value(i)); + std::shared_ptr type_singleton() { + return TypeTraits::type_singleton(); } +}; + +TYPED_TEST_CASE(TestFilterKernelWithNumeric, NumericArrowTypes); +TYPED_TEST(TestFilterKernelWithNumeric, FilterNumeric) { + this->AssertFilter("[]", "[]", "[]"); + + this->AssertFilter("[9]", "[0]", "[]"); + this->AssertFilter("[9]", "[1]", "[9]"); + this->AssertFilter("[9]", "[null]", "[null]"); + this->AssertFilter("[null]", "[0]", "[]"); + this->AssertFilter("[null]", "[1]", "[null]"); + this->AssertFilter("[null]", "[null]", "[null]"); + + this->AssertFilter("[7, 8, 9]", "[0, 1, 0]", "[8]"); + this->AssertFilter("[7, 8, 9]", "[1, 0, 1]", "[7, 9]"); + this->AssertFilter("[null, 8, 9]", "[0, 1, 0]", "[8]"); + this->AssertFilter("[7, 8, 9]", "[null, 1, 0]", "[null, 8]"); + this->AssertFilter("[7, 8, 9]", "[1, null, 1]", "[7, null, 9]"); +} - std::shared_ptr result; - - if (array->null_count() == 0) { - ArrayFromVector(bitmap, &result); - } else { - std::vector null_bitmap(array->length()); - auto reader = internal::BitmapReader(array->null_bitmap_data(), array->offset(), - array->length()); - for (int64_t i = 0; i < array->length(); i++, reader.Next()) { - null_bitmap[i] = reader.IsSet(); +TYPED_TEST(TestFilterKernelWithNumeric, FilterRandomNumeric) { + auto rand = random::RandomArrayGenerator(0x5416447); + for (size_t i = 3; i < 13; i++) { + const int64_t length = static_cast(1ULL << i); + for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) { + for (auto filter_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) { + auto values = rand.Numeric(length, 0, 127, null_probability); + auto filter = rand.Boolean(length, filter_probability, null_probability); + this->ValidateFilter(values, filter); + } } - ArrayFromVector(null_bitmap, bitmap, &result); } - - return Datum(result); } -template ::ArrayType> -static std::vector NullBitmapFromArrays(const ArrayType& lhs, - const ArrayType& rhs) { - auto left_lambda = [&lhs](int64_t i) { - return lhs.null_count() == 0 ? true : lhs.IsValid(i); +template +decltype(Comparator::Compare)* GetComparator(CompareOperator op) { + using cmp_t = decltype(Comparator::Compare); + static cmp_t* cmp[] = { + Comparator::Compare, Comparator::Compare, + Comparator::Compare, Comparator::Compare, + Comparator::Compare, Comparator::Compare, }; + return cmp[op]; +} - auto right_lambda = [&rhs](int64_t i) { - return rhs.null_count() == 0 ? true : rhs.IsValid(i); - }; - - const int64_t length = lhs.length(); - std::vector null_bitmap(length); - - for (int64_t i = 0; i < length; i++) { - null_bitmap[i] = left_lambda(i) && right_lambda(i); - } +template ::CType> +std::shared_ptr CompareAndFilter(const CType* data, int64_t length, Fn&& fn) { + std::vector filtered; + filtered.reserve(length); + std::copy_if(data, data + length, std::back_inserter(filtered), std::forward(fn)); + std::shared_ptr filtered_array; + ArrayFromVector(filtered, &filtered_array); + return filtered_array; +} - return null_bitmap; +template ::CType> +std::shared_ptr CompareAndFilter(const CType* data, int64_t length, CType val, + CompareOperator op) { + auto cmp = GetComparator(op); + return CompareAndFilter(data, length, [&](CType e) { return cmp(e, val); }); } -template -static Datum SimpleArrayArrayCompare(CompareOptions options, const Datum& lhs, - const Datum& rhs) { - using ArrayType = typename TypeTraits::ArrayType; - using T = typename TypeTraits::CType; - - auto l_array = std::static_pointer_cast(lhs.make_array()); - auto r_array = std::static_pointer_cast(rhs.make_array()); - const int64_t length = l_array->length(); - - std::vector bitmap(length); - for (int64_t i = 0; i < length; i++) { - bitmap[i] = SlowCompare(options.op, l_array->Value(i), r_array->Value(i)); - } +template ::CType> +std::shared_ptr CompareAndFilter(const CType* data, int64_t length, + const CType* other, CompareOperator op) { + auto cmp = GetComparator(op); + return CompareAndFilter(data, length, [&](CType e) { return cmp(e, *other++); }); +} - std::shared_ptr result; +TYPED_TEST(TestFilterKernelWithNumeric, CompareScalarAndFilterRandomNumeric) { + using ScalarType = typename TypeTraits::ScalarType; + using ArrayType = typename TypeTraits::ArrayType; + using CType = typename TypeTraits::CType; - if (l_array->null_count() == 0 && r_array->null_count() == 0) { - ArrayFromVector(bitmap, &result); - } else { - std::vector null_bitmap = NullBitmapFromArrays(*l_array, *r_array); - ArrayFromVector(null_bitmap, bitmap, &result); + auto rand = random::RandomArrayGenerator(0x5416447); + for (size_t i = 3; i < 13; i++) { + const int64_t length = static_cast(1ULL << i); + // TODO(bkietz) rewrite with some nulls + auto array = + checked_pointer_cast(rand.Numeric(length, 0, 100, 0)); + CType c_fifty = 50; + auto fifty = std::make_shared(c_fifty); + for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) { + auto options = CompareOptions(op); + Datum selection, filtered; + ASSERT_OK(arrow::compute::Compare(&this->ctx_, Datum(array), Datum(fifty), options, + &selection)); + ASSERT_OK(arrow::compute::Filter(&this->ctx_, Datum(array), selection, &filtered)); + auto filtered_array = filtered.make_array(); + auto expected = + CompareAndFilter(array->raw_values(), array->length(), c_fifty, op); + ASSERT_ARRAYS_EQUAL(*filtered_array, *expected); + } } - - return Datum(result); } -template -static void ValidateCompare(FunctionContext* ctx, CompareOptions options, - const Datum& lhs, const Datum& rhs) { - Datum result; - - bool has_scalar = lhs.is_scalar() || rhs.is_scalar(); - Datum expected = has_scalar ? SimpleScalarArrayCompare(options, lhs, rhs) - : SimpleArrayArrayCompare(options, lhs, rhs); +TYPED_TEST(TestFilterKernelWithNumeric, CompareArrayAndFilterRandomNumeric) { + using ArrayType = typename TypeTraits::ArrayType; - ValidateCompare(ctx, options, lhs, rhs, expected); + auto rand = random::RandomArrayGenerator(0x5416447); + for (size_t i = 3; i < 13; i++) { + const int64_t length = static_cast(1ULL << i); + auto lhs = + checked_pointer_cast(rand.Numeric(length, 0, 100, 0)); + auto rhs = + checked_pointer_cast(rand.Numeric(length, 0, 100, 0)); + for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) { + auto options = CompareOptions(op); + Datum selection, filtered; + ASSERT_OK(arrow::compute::Compare(&this->ctx_, Datum(lhs), Datum(rhs), options, + &selection)); + ASSERT_OK(arrow::compute::Filter(&this->ctx_, Datum(lhs), selection, &filtered)); + auto filtered_array = filtered.make_array(); + auto expected = CompareAndFilter(lhs->raw_values(), lhs->length(), + rhs->raw_values(), op); + ASSERT_ARRAYS_EQUAL(*filtered_array, *expected); + } + } } -template -class TestNumericCompareKernel : public ComputeFixture, public TestBase {}; - -TYPED_TEST_CASE(TestNumericCompareKernel, NumericArrowTypes); -TYPED_TEST(TestNumericCompareKernel, SimpleCompareArrayScalar) { +TYPED_TEST(TestFilterKernelWithNumeric, ScalarInRangeAndFilterRandomNumeric) { using ScalarType = typename TypeTraits::ScalarType; + using ArrayType = typename TypeTraits::ArrayType; using CType = typename TypeTraits::CType; - Datum one(std::make_shared(CType(1))); - - CompareOptions eq(CompareOperator::EQUAL); - ValidateCompare(&this->ctx_, eq, "[]", one, "[]"); - ValidateCompare(&this->ctx_, eq, "[null]", one, "[null]"); - ValidateCompare(&this->ctx_, eq, "[0,0,1,1,2,2]", one, "[0,0,1,1,0,0]"); - ValidateCompare(&this->ctx_, eq, "[0,1,2,3,4,5]", one, "[0,1,0,0,0,0]"); - ValidateCompare(&this->ctx_, eq, "[5,4,3,2,1,0]", one, "[0,0,0,0,1,0]"); - ValidateCompare(&this->ctx_, eq, "[null,0,1,1]", one, "[null,0,1,1]"); - - CompareOptions neq(CompareOperator::NOT_EQUAL); - ValidateCompare(&this->ctx_, neq, "[]", one, "[]"); - ValidateCompare(&this->ctx_, neq, "[null]", one, "[null]"); - ValidateCompare(&this->ctx_, neq, "[0,0,1,1,2,2]", one, "[1,1,0,0,1,1]"); - ValidateCompare(&this->ctx_, neq, "[0,1,2,3,4,5]", one, "[1,0,1,1,1,1]"); - ValidateCompare(&this->ctx_, neq, "[5,4,3,2,1,0]", one, "[1,1,1,1,0,1]"); - ValidateCompare(&this->ctx_, neq, "[null,0,1,1]", one, "[null,1,0,0]"); - - CompareOptions gt(CompareOperator::GREATER); - ValidateCompare(&this->ctx_, gt, "[]", one, "[]"); - ValidateCompare(&this->ctx_, gt, "[null]", one, "[null]"); - ValidateCompare(&this->ctx_, gt, "[0,0,1,1,2,2]", one, "[0,0,0,0,1,1]"); - ValidateCompare(&this->ctx_, gt, "[0,1,2,3,4,5]", one, "[0,0,1,1,1,1]"); - ValidateCompare(&this->ctx_, gt, "[4,5,6,7,8,9]", one, "[1,1,1,1,1,1]"); - ValidateCompare(&this->ctx_, gt, "[null,0,1,1]", one, "[null,0,0,0]"); - - CompareOptions gte(CompareOperator::GREATER_EQUAL); - ValidateCompare(&this->ctx_, gte, "[]", one, "[]"); - ValidateCompare(&this->ctx_, gte, "[null]", one, "[null]"); - ValidateCompare(&this->ctx_, gte, "[0,0,1,1,2,2]", one, "[0,0,1,1,1,1]"); - ValidateCompare(&this->ctx_, gte, "[0,1,2,3,4,5]", one, "[0,1,1,1,1,1]"); - ValidateCompare(&this->ctx_, gte, "[4,5,6,7,8,9]", one, "[1,1,1,1,1,1]"); - ValidateCompare(&this->ctx_, gte, "[null,0,1,1]", one, "[null,0,1,1]"); - - CompareOptions lt(CompareOperator::LESS); - ValidateCompare(&this->ctx_, lt, "[]", one, "[]"); - ValidateCompare(&this->ctx_, lt, "[null]", one, "[null]"); - ValidateCompare(&this->ctx_, lt, "[0,0,1,1,2,2]", one, "[1,1,0,0,0,0]"); - ValidateCompare(&this->ctx_, lt, "[0,1,2,3,4,5]", one, "[1,0,0,0,0,0]"); - ValidateCompare(&this->ctx_, lt, "[4,5,6,7,8,9]", one, "[0,0,0,0,0,0]"); - ValidateCompare(&this->ctx_, lt, "[null,0,1,1]", one, "[null,1,0,0]"); - - CompareOptions lte(CompareOperator::LESS_EQUAL); - ValidateCompare(&this->ctx_, lte, "[]", one, "[]"); - ValidateCompare(&this->ctx_, lte, "[null]", one, "[null]"); - ValidateCompare(&this->ctx_, lte, "[0,0,1,1,2,2]", one, "[1,1,1,1,0,0]"); - ValidateCompare(&this->ctx_, lte, "[0,1,2,3,4,5]", one, "[1,1,0,0,0,0]"); - ValidateCompare(&this->ctx_, lte, "[4,5,6,7,8,9]", one, "[0,0,0,0,0,0]"); - ValidateCompare(&this->ctx_, lte, "[null,0,1,1]", one, "[null,1,1,1]"); + auto rand = random::RandomArrayGenerator(0x5416447); + for (size_t i = 3; i < 13; i++) { + const int64_t length = static_cast(1ULL << i); + auto array = + checked_pointer_cast(rand.Numeric(length, 0, 100, 0)); + CType c_fifty = 50, c_hundred = 100; + auto fifty = std::make_shared(c_fifty); + auto hundred = std::make_shared(c_hundred); + Datum greater_than_fifty, less_than_hundred, selection, filtered; + ASSERT_OK(arrow::compute::Compare(&this->ctx_, Datum(array), Datum(fifty), + CompareOptions(GREATER), &greater_than_fifty)); + ASSERT_OK(arrow::compute::Compare(&this->ctx_, Datum(array), Datum(hundred), + CompareOptions(LESS), &less_than_hundred)); + ASSERT_OK(arrow::compute::And(&this->ctx_, greater_than_fifty, less_than_hundred, + &selection)); + ASSERT_OK(arrow::compute::Filter(&this->ctx_, Datum(array), selection, &filtered)); + auto filtered_array = filtered.make_array(); + auto expected = CompareAndFilter( + array->raw_values(), array->length(), + [&](CType e) { return (e > c_fifty) && (e < c_hundred); }); + ASSERT_ARRAYS_EQUAL(*filtered_array, *expected); + } } -TYPED_TEST(TestNumericCompareKernel, SimpleCompareScalarArray) { - using ScalarType = typename TypeTraits::ScalarType; - using CType = typename TypeTraits::CType; +class TestFilterKernelWithString : public TestFilterKernel { + protected: + void AssertFilter(const std::string& values, const std::string& filter, + const std::string& expected) { + TestFilterKernel::AssertFilter(utf8(), values, filter, expected); + } + void AssertFilterDictionary(const std::string& dictionary_values, + const std::string& dictionary_filter, + const std::string& filter, + const std::string& expected_filter) { + auto dict = ArrayFromJSON(utf8(), dictionary_values); + auto type = dictionary(int8(), utf8()); + std::shared_ptr values, actual, expected; + ASSERT_OK(DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), dictionary_filter), + dict, &values)); + ASSERT_OK(DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), expected_filter), + dict, &expected)); + auto take_filter = ArrayFromJSON(boolean(), filter); + this->AssertFilterArrays(values, take_filter, expected); + } +}; - Datum one(std::make_shared(CType(1))); - - CompareOptions eq(CompareOperator::EQUAL); - ValidateCompare(&this->ctx_, eq, one, "[]", "[]"); - ValidateCompare(&this->ctx_, eq, one, "[null]", "[null]"); - ValidateCompare(&this->ctx_, eq, one, "[0,0,1,1,2,2]", "[0,0,1,1,0,0]"); - ValidateCompare(&this->ctx_, eq, one, "[0,1,2,3,4,5]", "[0,1,0,0,0,0]"); - ValidateCompare(&this->ctx_, eq, one, "[5,4,3,2,1,0]", "[0,0,0,0,1,0]"); - ValidateCompare(&this->ctx_, eq, one, "[null,0,1,1]", "[null,0,1,1]"); - - CompareOptions neq(CompareOperator::NOT_EQUAL); - ValidateCompare(&this->ctx_, neq, one, "[]", "[]"); - ValidateCompare(&this->ctx_, neq, one, "[null]", "[null]"); - ValidateCompare(&this->ctx_, neq, one, "[0,0,1,1,2,2]", "[1,1,0,0,1,1]"); - ValidateCompare(&this->ctx_, neq, one, "[0,1,2,3,4,5]", "[1,0,1,1,1,1]"); - ValidateCompare(&this->ctx_, neq, one, "[5,4,3,2,1,0]", "[1,1,1,1,0,1]"); - ValidateCompare(&this->ctx_, neq, one, "[null,0,1,1]", "[null,1,0,0]"); - - CompareOptions gt(CompareOperator::GREATER); - ValidateCompare(&this->ctx_, gt, one, "[]", "[]"); - ValidateCompare(&this->ctx_, gt, one, "[null]", "[null]"); - ValidateCompare(&this->ctx_, gt, one, "[0,0,1,1,2,2]", "[1,1,0,0,0,0]"); - ValidateCompare(&this->ctx_, gt, one, "[0,1,2,3,4,5]", "[1,0,0,0,0,0]"); - ValidateCompare(&this->ctx_, gt, one, "[4,5,6,7,8,9]", "[0,0,0,0,0,0]"); - ValidateCompare(&this->ctx_, gt, one, "[null,0,1,1]", "[null,1,0,0]"); - - CompareOptions gte(CompareOperator::GREATER_EQUAL); - ValidateCompare(&this->ctx_, gte, one, "[]", "[]"); - ValidateCompare(&this->ctx_, gte, one, "[null]", "[null]"); - ValidateCompare(&this->ctx_, gte, one, "[0,0,1,1,2,2]", "[1,1,1,1,0,0]"); - ValidateCompare(&this->ctx_, gte, one, "[0,1,2,3,4,5]", "[1,1,0,0,0,0]"); - ValidateCompare(&this->ctx_, gte, one, "[4,5,6,7,8,9]", "[0,0,0,0,0,0]"); - ValidateCompare(&this->ctx_, gte, one, "[null,0,1,1]", "[null,1,1,1]"); - - CompareOptions lt(CompareOperator::LESS); - ValidateCompare(&this->ctx_, lt, one, "[]", "[]"); - ValidateCompare(&this->ctx_, lt, one, "[null]", "[null]"); - ValidateCompare(&this->ctx_, lt, one, "[0,0,1,1,2,2]", "[0,0,0,0,1,1]"); - ValidateCompare(&this->ctx_, lt, one, "[0,1,2,3,4,5]", "[0,0,1,1,1,1]"); - ValidateCompare(&this->ctx_, lt, one, "[4,5,6,7,8,9]", "[1,1,1,1,1,1]"); - ValidateCompare(&this->ctx_, lt, one, "[null,0,1,1]", "[null,0,0,0]"); - - CompareOptions lte(CompareOperator::LESS_EQUAL); - ValidateCompare(&this->ctx_, lte, one, "[]", "[]"); - ValidateCompare(&this->ctx_, lte, one, "[null]", "[null]"); - ValidateCompare(&this->ctx_, lte, one, "[0,0,1,1,2,2]", "[0,0,1,1,1,1]"); - ValidateCompare(&this->ctx_, lte, one, "[0,1,2,3,4,5]", "[0,1,1,1,1,1]"); - ValidateCompare(&this->ctx_, lte, one, "[4,5,6,7,8,9]", "[1,1,1,1,1,1]"); - ValidateCompare(&this->ctx_, lte, one, "[null,0,1,1]", "[null,0,1,1]"); +TEST_F(TestFilterKernelWithString, FilterString) { + this->AssertFilter(R"(["a", "b", "c"])", "[0, 1, 0]", R"(["b"])"); + this->AssertFilter(R"([null, "b", "c"])", "[0, 1, 0]", R"(["b"])"); + this->AssertFilter(R"(["a", "b", "c"])", "[null, 1, 0]", R"([null, "b"])"); } -TYPED_TEST(TestNumericCompareKernel, TestNullScalar) { - /* Ensure that null scalar broadcast to all null results. */ - using ScalarType = typename TypeTraits::ScalarType; - using CType = typename TypeTraits::CType; +TEST_F(TestFilterKernelWithString, FilterDictionary) { + auto dict = R"(["a", "b", "c", "d", "e"])"; + this->AssertFilterDictionary(dict, "[3, 4, 2]", "[0, 1, 0]", "[4]"); + this->AssertFilterDictionary(dict, "[null, 4, 2]", "[0, 1, 0]", "[4]"); + this->AssertFilterDictionary(dict, "[3, 4, 2]", "[null, 1, 0]", "[null, 4]"); +} - Datum null(std::make_shared(CType(0), false)); - EXPECT_FALSE(null.scalar()->is_valid); +class TestFilterKernelWithList : public TestFilterKernel {}; - CompareOptions eq(CompareOperator::EQUAL); - ValidateCompare(&this->ctx_, eq, "[]", null, "[]"); - ValidateCompare(&this->ctx_, eq, null, "[]", "[]"); - ValidateCompare(&this->ctx_, eq, "[null]", null, "[null]"); - ValidateCompare(&this->ctx_, eq, null, "[null]", "[null]"); - ValidateCompare(&this->ctx_, eq, null, "[1,2,3]", "[null, null, null]"); +TEST_F(TestFilterKernelWithList, FilterListInt32) { + std::string list_json = "[[], [1,2], null, [3]]"; + this->AssertFilter(list(int32()), list_json, "[0, 0, 0, 0]", "[]"); + this->AssertFilter(list(int32()), list_json, "[0, 1, 1, null]", "[[1,2], null, null]"); + this->AssertFilter(list(int32()), list_json, "[0, 0, 1, null]", "[null, null]"); + this->AssertFilter(list(int32()), list_json, "[1, 0, 0, 1]", "[[], [3]]"); + this->AssertFilter(list(int32()), list_json, "[1, 1, 1, 1]", list_json); + this->AssertFilter(list(int32()), list_json, "[0, 1, 0, 1]", "[[1,2], [3]]"); } -TYPED_TEST_CASE(TestNumericCompareKernel, NumericArrowTypes); -TYPED_TEST(TestNumericCompareKernel, RandomCompareArrayScalar) { - using ScalarType = typename TypeTraits::ScalarType; - using CType = typename TypeTraits::CType; - - auto rand = random::RandomArrayGenerator(0x5416447); - for (size_t i = 3; i < 13; i++) { - for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) { - for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) { - const int64_t length = static_cast(1ULL << i); - auto array = Datum(rand.Numeric(length, 0, 100, null_probability)); - auto fifty = Datum(std::make_shared(CType(50))); - auto options = CompareOptions(op); - ValidateCompare(&this->ctx_, options, array, fifty); - ValidateCompare(&this->ctx_, options, fifty, array); - } - } - } +class TestFilterKernelWithFixedSizeList : public TestFilterKernel {}; + +TEST_F(TestFilterKernelWithFixedSizeList, FilterFixedSizeListInt32) { + std::string list_json = "[null, [1, null, 3], [4, 5, 6], [7, 8, null]]"; + this->AssertFilter(fixed_size_list(int32(), 3), list_json, "[0, 0, 0, 0]", "[]"); + this->AssertFilter(fixed_size_list(int32(), 3), list_json, "[0, 1, 1, null]", + "[[1, null, 3], [4, 5, 6], null]"); + this->AssertFilter(fixed_size_list(int32(), 3), list_json, "[0, 0, 1, null]", + "[[4, 5, 6], null]"); + this->AssertFilter(fixed_size_list(int32(), 3), list_json, "[1, 1, 1, 1]", list_json); + this->AssertFilter(fixed_size_list(int32(), 3), list_json, "[0, 1, 0, 1]", + "[[1, null, 3], [7, 8, null]]"); } -TYPED_TEST(TestNumericCompareKernel, SimpleCompareArrayArray) { - /* Ensure that null scalar broadcast to all null results. */ - CompareOptions eq(CompareOperator::EQUAL); - ValidateCompare(&this->ctx_, eq, "[]", "[]", "[]"); - ValidateCompare(&this->ctx_, eq, "[null]", "[null]", "[null]"); - ValidateCompare(&this->ctx_, eq, "[1]", "[1]", "[1]"); - ValidateCompare(&this->ctx_, eq, "[1]", "[2]", "[0]"); - ValidateCompare(&this->ctx_, eq, "[null]", "[1]", "[null]"); - ValidateCompare(&this->ctx_, eq, "[1]", "[null]", "[null]"); - - CompareOptions lte(CompareOperator::LESS_EQUAL); - ValidateCompare(&this->ctx_, lte, "[1,2,3,4,5]", "[2,3,4,5,6]", - "[1,1,1,1,1]"); +class TestFilterKernelWithMap : public TestFilterKernel {}; + +TEST_F(TestFilterKernelWithMap, FilterMapStringToInt32) { + std::string map_json = R"([ + [["joe", 0], ["mark", null]], + null, + [["cap", 8]], + [] + ])"; + this->AssertFilter(map(utf8(), int32()), map_json, "[0, 0, 0, 0]", "[]"); + this->AssertFilter(map(utf8(), int32()), map_json, "[0, 1, 1, null]", R"([ + null, + [["cap", 8]], + null + ])"); + this->AssertFilter(map(utf8(), int32()), map_json, "[1, 1, 1, 1]", map_json); + this->AssertFilter(map(utf8(), int32()), map_json, "[0, 1, 0, 1]", "[null, []]"); } -TYPED_TEST(TestNumericCompareKernel, RandomCompareArrayArray) { - auto rand = random::RandomArrayGenerator(0x5416447); - for (size_t i = 3; i < 5; i++) { - for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) { - for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) { - const int64_t length = static_cast(1ULL << i); - auto lhs = Datum(rand.Numeric(length << i, 0, 100, null_probability)); - auto rhs = Datum(rand.Numeric(length << i, 0, 100, null_probability)); - auto options = CompareOptions(op); - ValidateCompare(&this->ctx_, options, lhs, rhs); - } - } - } +class TestFilterKernelWithStruct : public TestFilterKernel {}; + +TEST_F(TestFilterKernelWithStruct, FilterStruct) { + auto struct_type = struct_({field("a", int32()), field("b", utf8())}); + auto struct_json = R"([ + null, + {"a": 1, "b": ""}, + {"a": 2, "b": "hello"}, + {"a": 4, "b": "eh"} + ])"; + this->AssertFilter(struct_type, struct_json, "[0, 0, 0, 0]", "[]"); + this->AssertFilter(struct_type, struct_json, "[0, 1, 1, null]", R"([ + {"a": 1, "b": ""}, + {"a": 2, "b": "hello"}, + null + ])"); + this->AssertFilter(struct_type, struct_json, "[1, 1, 1, 1]", struct_json); + this->AssertFilter(struct_type, struct_json, "[1, 0, 1, 0]", R"([ + null, + {"a": 2, "b": "hello"} + ])"); } } // namespace compute diff --git a/cpp/src/arrow/compute/kernels/filter.cc b/cpp/src/arrow/compute/kernels/filter.cc index 1cbf0dc0992d1..654ec610352ce 100644 --- a/cpp/src/arrow/compute/kernels/filter.cc +++ b/cpp/src/arrow/compute/kernels/filter.cc @@ -15,44 +15,447 @@ // specific language governing permissions and limitations // under the License. -#include "arrow/compute/kernels/filter.h" +#include +#include +#include +#include -#include "arrow/array.h" -#include "arrow/compute/kernel.h" +#include "arrow/builder.h" +#include "arrow/compute/context.h" +#include "arrow/compute/kernels/filter.h" +#include "arrow/util/bit-util.h" +#include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" +#include "arrow/util/stl.h" +#include "arrow/visitor_inline.h" namespace arrow { - namespace compute { -std::shared_ptr FilterBinaryKernel::out_type() const { - return filter_function_->out_type(); +using internal::checked_cast; +using internal::checked_pointer_cast; + +template +Status MakeBuilder(MemoryPool* pool, const std::shared_ptr& type, + std::unique_ptr* out) { + std::unique_ptr builder; + RETURN_NOT_OK(MakeBuilder(pool, type, &builder)); + out->reset(checked_cast(builder.release())); + return Status::OK(); +} + +template +static Status UnsafeAppend(Builder* builder, Scalar&& value) { + builder->UnsafeAppend(std::forward(value)); + return Status::OK(); +} + +static Status UnsafeAppend(BinaryBuilder* builder, util::string_view value) { + RETURN_NOT_OK(builder->ReserveData(static_cast(value.size()))); + builder->UnsafeAppend(value); + return Status::OK(); +} + +static Status UnsafeAppend(StringBuilder* builder, util::string_view value) { + RETURN_NOT_OK(builder->ReserveData(static_cast(value.size()))); + builder->UnsafeAppend(value); + return Status::OK(); } -Status FilterBinaryKernel::Call(FunctionContext* ctx, const Datum& left, - const Datum& right, Datum* out) { - DCHECK(left.type()->Equals(right.type())); +// TODO(bkietz) this can be optimized +static int64_t OutputSize(const BooleanArray& filter) { + auto offset = filter.offset(); + auto length = filter.length(); + int64_t size = 0; + for (auto i = offset; i < offset + length; ++i) { + if (filter.IsNull(i) || filter.Value(i)) { + ++size; + } + } + return size; +} + +template +class FilterImpl; + +template <> +class FilterImpl : public FilterKernel { + public: + using FilterKernel::FilterKernel; + + Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter, + int64_t length, std::shared_ptr* out) override { + out->reset(new NullArray(length)); + return Status::OK(); + } +}; + +template +class FilterImpl : public FilterKernel { + public: + using ValueArray = typename TypeTraits::ArrayType; + using OutBuilder = typename TypeTraits::BuilderType; + + using FilterKernel::FilterKernel; + + Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter, + int64_t length, std::shared_ptr* out) override { + std::unique_ptr builder; + RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), type_, &builder)); + RETURN_NOT_OK(builder->Resize(OutputSize(filter))); + RETURN_NOT_OK(UnpackValuesNullCount(checked_cast(values), filter, + builder.get())); + return builder->Finish(out); + } + + private: + Status UnpackValuesNullCount(const ValueArray& values, const BooleanArray& filter, + OutBuilder* builder) { + if (values.null_count() == 0) { + return UnpackIndicesNullCount(values, filter, builder); + } + return UnpackIndicesNullCount(values, filter, builder); + } + + template + Status UnpackIndicesNullCount(const ValueArray& values, const BooleanArray& filter, + OutBuilder* builder) { + if (filter.null_count() == 0) { + return Filter(values, filter, builder); + } + return Filter(values, filter, builder); + } + + template + Status Filter(const ValueArray& values, const BooleanArray& filter, + OutBuilder* builder) { + for (int64_t i = 0; i < filter.length(); ++i) { + if (!AllIndicesValid && filter.IsNull(i)) { + builder->UnsafeAppendNull(); + continue; + } + if (!filter.Value(i)) { + continue; + } + if (!AllValuesValid && values.IsNull(i)) { + builder->UnsafeAppendNull(); + continue; + } + RETURN_NOT_OK(UnsafeAppend(builder, values.GetView(i))); + } + return Status::OK(); + } +}; + +template <> +class FilterImpl : public FilterKernel { + public: + FilterImpl(const std::shared_ptr& type, + std::vector> child_kernels) + : FilterKernel(type), child_kernels_(std::move(child_kernels)) {} + + Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter, + int64_t length, std::shared_ptr* out) override { + const auto& struct_array = checked_cast(values); + + TypedBufferBuilder null_bitmap_builder(ctx->memory_pool()); + RETURN_NOT_OK(null_bitmap_builder.Resize(length)); + + ArrayVector fields(type_->num_children()); + for (int i = 0; i < type_->num_children(); ++i) { + RETURN_NOT_OK(child_kernels_[i]->Filter(ctx, *struct_array.field(i), filter, length, + &fields[i])); + } + + for (int64_t i = 0; i < filter.length(); ++i) { + if (filter.IsNull(i)) { + null_bitmap_builder.UnsafeAppend(false); + continue; + } + if (!filter.Value(i)) { + continue; + } + if (struct_array.IsNull(i)) { + null_bitmap_builder.UnsafeAppend(false); + continue; + } + null_bitmap_builder.UnsafeAppend(true); + } + + auto null_count = null_bitmap_builder.false_count(); + std::shared_ptr null_bitmap; + RETURN_NOT_OK(null_bitmap_builder.Finish(&null_bitmap)); + + out->reset(new StructArray(type_, length, fields, null_bitmap, null_count)); + return Status::OK(); + } + + private: + std::vector> child_kernels_; +}; + +template <> +class FilterImpl : public FilterKernel { + public: + using FilterKernel::FilterKernel; + + Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter, + int64_t length, std::shared_ptr* out) override { + const auto& list_array = checked_cast(values); + + TypedBufferBuilder null_bitmap_builder(ctx->memory_pool()); + RETURN_NOT_OK(null_bitmap_builder.Resize(length)); - auto lk = left.kind(); - auto rk = right.kind(); - auto out_array = out->array(); + BooleanBuilder value_filter_builder(ctx->memory_pool()); + auto list_size = list_array.list_type()->list_size(); + RETURN_NOT_OK(value_filter_builder.Resize(list_size * length)); - if (lk == Datum::ARRAY && rk == Datum::SCALAR) { - auto array = left.array(); - auto scalar = right.scalar(); - return filter_function_->Filter(*array, *scalar, &out_array); - } else if (lk == Datum::SCALAR && rk == Datum::ARRAY) { - auto scalar = left.scalar(); - auto array = right.array(); - auto out_array = out->array(); - return filter_function_->Filter(*scalar, *array, &out_array); - } else if (lk == Datum::ARRAY && rk == Datum::ARRAY) { - auto lhs = left.array(); - auto rhs = right.array(); - return filter_function_->Filter(*lhs, *rhs, &out_array); + for (int64_t i = 0; i < filter.length(); ++i) { + if (filter.IsNull(i)) { + null_bitmap_builder.UnsafeAppend(false); + for (int64_t j = 0; j < list_size; ++j) { + value_filter_builder.UnsafeAppendNull(); + } + continue; + } + if (!filter.Value(i)) { + for (int64_t j = 0; j < list_size; ++j) { + value_filter_builder.UnsafeAppend(false); + } + continue; + } + if (values.IsNull(i)) { + null_bitmap_builder.UnsafeAppend(false); + for (int64_t j = 0; j < list_size; ++j) { + value_filter_builder.UnsafeAppendNull(); + } + continue; + } + for (int64_t j = 0; j < list_size; ++j) { + value_filter_builder.UnsafeAppend(true); + } + null_bitmap_builder.UnsafeAppend(true); + } + + std::shared_ptr value_filter; + RETURN_NOT_OK(value_filter_builder.Finish(&value_filter)); + std::shared_ptr out_values; + RETURN_NOT_OK( + arrow::compute::Filter(ctx, *list_array.values(), *value_filter, &out_values)); + + auto null_count = null_bitmap_builder.false_count(); + std::shared_ptr null_bitmap; + RETURN_NOT_OK(null_bitmap_builder.Finish(&null_bitmap)); + + out->reset( + new FixedSizeListArray(type_, length, out_values, null_bitmap, null_count)); + return Status::OK(); + } +}; + +template <> +class FilterImpl : public FilterKernel { + public: + using FilterKernel::FilterKernel; + + Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter, + int64_t length, std::shared_ptr* out) override { + const auto& list_array = checked_cast(values); + + TypedBufferBuilder null_bitmap_builder(ctx->memory_pool()); + RETURN_NOT_OK(null_bitmap_builder.Resize(length)); + + BooleanBuilder value_filter_builder(ctx->memory_pool()); + + TypedBufferBuilder offset_builder(ctx->memory_pool()); + RETURN_NOT_OK(offset_builder.Resize(length + 1)); + int32_t offset = 0; + offset_builder.UnsafeAppend(offset); + + for (int64_t i = 0; i < filter.length(); ++i) { + if (filter.IsNull(i)) { + null_bitmap_builder.UnsafeAppend(false); + offset_builder.UnsafeAppend(offset); + RETURN_NOT_OK( + value_filter_builder.AppendValues(list_array.value_length(i), false)); + continue; + } + if (!filter.Value(i)) { + RETURN_NOT_OK( + value_filter_builder.AppendValues(list_array.value_length(i), false)); + continue; + } + if (values.IsNull(i)) { + null_bitmap_builder.UnsafeAppend(false); + offset_builder.UnsafeAppend(offset); + RETURN_NOT_OK( + value_filter_builder.AppendValues(list_array.value_length(i), false)); + continue; + } + null_bitmap_builder.UnsafeAppend(true); + offset += list_array.value_length(i); + offset_builder.UnsafeAppend(offset); + RETURN_NOT_OK(value_filter_builder.AppendValues(list_array.value_length(i), true)); + } + + std::shared_ptr value_filter; + RETURN_NOT_OK(value_filter_builder.Finish(&value_filter)); + std::shared_ptr out_values; + RETURN_NOT_OK( + arrow::compute::Filter(ctx, *list_array.values(), *value_filter, &out_values)); + + auto null_count = null_bitmap_builder.false_count(); + std::shared_ptr offsets, null_bitmap; + RETURN_NOT_OK(offset_builder.Finish(&offsets)); + RETURN_NOT_OK(null_bitmap_builder.Finish(&null_bitmap)); + + *out = MakeArray(ArrayData::Make(type_, length, {null_bitmap, offsets}, + {out_values->data()}, null_count)); + return Status::OK(); + } +}; + +template <> +class FilterImpl : public FilterImpl { + using FilterImpl::FilterImpl; +}; + +template <> +class FilterImpl : public FilterKernel { + public: + FilterImpl(const std::shared_ptr& type, std::unique_ptr impl) + : FilterKernel(type), impl_(std::move(impl)) {} + + Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter, + int64_t length, std::shared_ptr* out) override { + auto dict_array = checked_cast(&values); + // To filter a dictionary, apply the current kernel to the dictionary's indices. + std::shared_ptr taken_indices; + RETURN_NOT_OK( + impl_->Filter(ctx, *dict_array->indices(), filter, length, &taken_indices)); + return DictionaryArray::FromArrays(values.type(), taken_indices, + dict_array->dictionary(), out); + } + + private: + std::unique_ptr impl_; +}; + +template <> +class FilterImpl : public FilterKernel { + public: + FilterImpl(const std::shared_ptr& type, std::unique_ptr impl) + : FilterKernel(type), impl_(std::move(impl)) {} + + Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter, + int64_t length, std::shared_ptr* out) override { + auto ext_array = checked_cast(&values); + // To take from an extension array, apply the current kernel to storage. + std::shared_ptr taken_storage; + RETURN_NOT_OK( + impl_->Filter(ctx, *ext_array->storage(), filter, length, &taken_storage)); + *out = ext_array->extension_type()->MakeArray(taken_storage->data()); + return Status::OK(); + } + + private: + std::unique_ptr impl_; +}; + +Status FilterKernel::Make(const std::shared_ptr& value_type, + std::unique_ptr* out) { + switch (value_type->id()) { +#define NO_CHILD_CASE(T) \ + case T##Type::type_id: \ + *out = internal::make_unique>(value_type); \ + return Status::OK() + +#define SINGLE_CHILD_CASE(T, CHILD_TYPE) \ + case T##Type::type_id: { \ + auto t = checked_pointer_cast(value_type); \ + std::unique_ptr child_filter_impl; \ + RETURN_NOT_OK(FilterKernel::Make(t->CHILD_TYPE(), &child_filter_impl)); \ + *out = internal::make_unique>(t, std::move(child_filter_impl)); \ + return Status::OK(); \ } - return Status::Invalid("Invalid datum signature for FilterBinaryKernel"); + NO_CHILD_CASE(Null); + NO_CHILD_CASE(Boolean); + NO_CHILD_CASE(Int8); + NO_CHILD_CASE(Int16); + NO_CHILD_CASE(Int32); + NO_CHILD_CASE(Int64); + NO_CHILD_CASE(UInt8); + NO_CHILD_CASE(UInt16); + NO_CHILD_CASE(UInt32); + NO_CHILD_CASE(UInt64); + NO_CHILD_CASE(Date32); + NO_CHILD_CASE(Date64); + NO_CHILD_CASE(Time32); + NO_CHILD_CASE(Time64); + NO_CHILD_CASE(Timestamp); + NO_CHILD_CASE(Duration); + NO_CHILD_CASE(HalfFloat); + NO_CHILD_CASE(Float); + NO_CHILD_CASE(Double); + NO_CHILD_CASE(String); + NO_CHILD_CASE(Binary); + NO_CHILD_CASE(FixedSizeBinary); + NO_CHILD_CASE(Decimal128); + + SINGLE_CHILD_CASE(Dictionary, index_type); + SINGLE_CHILD_CASE(Extension, storage_type); + + NO_CHILD_CASE(List); + NO_CHILD_CASE(FixedSizeList); + NO_CHILD_CASE(Map); + + case Type::STRUCT: { + std::vector> child_kernels; + for (auto child : value_type->children()) { + child_kernels.emplace_back(); + RETURN_NOT_OK(FilterKernel::Make(child->type(), &child_kernels.back())); + } + *out = internal::make_unique>(value_type, + std::move(child_kernels)); + return Status::OK(); + } + +#undef NO_CHILD_CASE +#undef SINGLE_CHILD_CASE + + default: + return Status::NotImplemented("gathering values of type ", *value_type); + } +} + +Status FilterKernel::Call(FunctionContext* ctx, const Datum& values, const Datum& filter, + Datum* out) { + if (!values.is_array() || !filter.is_array()) { + return Status::Invalid("FilterKernel expects array values and filter"); + } + auto values_array = values.make_array(); + auto filter_array = checked_pointer_cast(filter.make_array()); + const auto length = OutputSize(*filter_array); + std::shared_ptr out_array; + RETURN_NOT_OK(this->Filter(ctx, *values_array, *filter_array, length, &out_array)); + *out = out_array; + return Status::OK(); +} + +Status Filter(FunctionContext* context, const Array& values, const Array& filter, + std::shared_ptr* out) { + Datum out_datum; + RETURN_NOT_OK(Filter(context, Datum(values.data()), Datum(filter.data()), &out_datum)); + *out = out_datum.make_array(); + return Status::OK(); +} + +Status Filter(FunctionContext* context, const Datum& values, const Datum& filter, + Datum* out) { + std::unique_ptr kernel; + RETURN_NOT_OK(FilterKernel::Make(values.type(), &kernel)); + return kernel->Call(context, values, filter, out); } } // namespace compute diff --git a/cpp/src/arrow/compute/kernels/filter.h b/cpp/src/arrow/compute/kernels/filter.h index 3b28bc9391a6a..46ad3d42b87f1 100644 --- a/cpp/src/arrow/compute/kernels/filter.h +++ b/cpp/src/arrow/compute/kernels/filter.h @@ -20,76 +20,74 @@ #include #include "arrow/compute/kernel.h" +#include "arrow/status.h" +#include "arrow/util/visibility.h" namespace arrow { class Array; -struct Scalar; -class Status; namespace compute { class FunctionContext; -struct Datum; -/// FilterFunction is an interface for Filters +/// \brief Filter an array with a boolean selection filter /// -/// Filters takes an array and emits a selection vector. The selection vector -/// is given in the form of a bitmask as a BooleanArray result. -class ARROW_EXPORT FilterFunction { - public: - /// Filter an array with a scalar argument. - virtual Status Filter(const ArrayData& array, const Scalar& scalar, - ArrayData* output) const = 0; - - Status Filter(const ArrayData& array, const Scalar& scalar, - std::shared_ptr* output) { - return Filter(array, scalar, output->get()); - } - - virtual Status Filter(const Scalar& scalar, const ArrayData& array, - ArrayData* output) const = 0; - - Status Filter(const Scalar& scalar, const ArrayData& array, - std::shared_ptr* output) { - return Filter(scalar, array, output->get()); - } - - /// Filter an array with an array argument. - virtual Status Filter(const ArrayData& lhs, const ArrayData& rhs, - ArrayData* output) const = 0; - - Status Filter(const ArrayData& lhs, const ArrayData& rhs, - std::shared_ptr* output) { - return Filter(lhs, rhs, output->get()); - } - - /// By default, FilterFunction emits a result bitmap. - virtual std::shared_ptr out_type() const { return boolean(); } - - virtual ~FilterFunction() {} -}; - -/// \brief BinaryKernel bound to a filter function -class ARROW_EXPORT FilterBinaryKernel : public BinaryKernel { +/// The output array will be populated with values from the input at positions +/// where the selection filter is not 0. Nulls in the filter will result in nulls +/// in the output. +/// +/// For example given values = ["a", "b", "c", null, "e", "f"] and +/// filter = [0, 1, 1, 0, null, 1], the output will be +/// = ["b", "c", null, "f"] +/// +/// \param[in] context the FunctionContext +/// \param[in] values array to filter +/// \param[in] filter indicates which values should be filtered out +/// \param[out] out resulting array +ARROW_EXPORT +Status Filter(FunctionContext* context, const Array& values, const Array& filter, + std::shared_ptr* out); + +/// \brief Filter an array with a boolean selection filter +/// +/// \param[in] context the FunctionContext +/// \param[in] values datum to filter +/// \param[in] filter indicates which values should be filtered out +/// \param[out] out resulting datum +ARROW_EXPORT +Status Filter(FunctionContext* context, const Datum& values, const Datum& filter, + Datum* out); + +/// \brief BinaryKernel implementing Filter operation +class ARROW_EXPORT FilterKernel : public BinaryKernel { public: - explicit FilterBinaryKernel(std::shared_ptr& filter) - : filter_function_(filter) {} + explicit FilterKernel(const std::shared_ptr& type) : type_(type) {} - Status Call(FunctionContext* ctx, const Datum& left, const Datum& right, + /// \brief BinaryKernel interface + /// + /// delegates to subclasses via Filter() + Status Call(FunctionContext* ctx, const Datum& values, const Datum& filter, Datum* out) override; - static int64_t out_length(const Datum& left, const Datum& right) { - if (left.kind() == Datum::ARRAY) return left.length(); - if (right.kind() == Datum::ARRAY) return right.length(); + /// \brief output type of this kernel (identical to type of values filtered) + std::shared_ptr out_type() const override { return type_; } - return 0; - } + /// \brief factory for FilterKernels + /// + /// \param[in] value_type constructed FilterKernel will support filtering + /// values of this type + /// \param[out] out created kernel + static Status Make(const std::shared_ptr& value_type, + std::unique_ptr* out); - std::shared_ptr out_type() const override; + /// \brief single-array implementation + virtual Status Filter(FunctionContext* ctx, const Array& values, + const BooleanArray& filter, int64_t length, + std::shared_ptr* out) = 0; - private: - std::shared_ptr filter_function_; + protected: + std::shared_ptr type_; }; } // namespace compute diff --git a/cpp/src/arrow/compute/kernels/take-test.cc b/cpp/src/arrow/compute/kernels/take-test.cc index b3de04d8cc762..c61aedab5b958 100644 --- a/cpp/src/arrow/compute/kernels/take-test.cc +++ b/cpp/src/arrow/compute/kernels/take-test.cc @@ -122,6 +122,7 @@ TYPED_TEST(TestTakeKernelWithNumeric, TakeNumeric) { this->AssertTake("[7, 8, 9]", "[0, 1, 0]", options, "[7, 8, 7]"); this->AssertTake("[null, 8, 9]", "[0, 1, 0]", options, "[null, 8, null]"); this->AssertTake("[7, 8, 9]", "[null, 1, 0]", options, "[null, 8, 7]"); + this->AssertTake("[null, 8, 9]", "[]", options, "[]"); std::shared_ptr arr; ASSERT_RAISES(IndexError, this->Take(this->type_singleton(), "[7, 8, 9]", int8(), diff --git a/cpp/src/arrow/compute/kernels/take.cc b/cpp/src/arrow/compute/kernels/take.cc index 9af2c0cab1152..17b054099ea32 100644 --- a/cpp/src/arrow/compute/kernels/take.cc +++ b/cpp/src/arrow/compute/kernels/take.cc @@ -28,6 +28,8 @@ namespace arrow { namespace compute { +using internal::checked_cast; + Status Take(FunctionContext* context, const Array& values, const Array& indices, const TakeOptions& options, std::shared_ptr* out) { Datum out_datum; @@ -119,20 +121,20 @@ struct UnpackValues { Status Visit(const ValueType&) { using ValueArrayRef = const typename TypeTraits::ArrayType&; using OutBuilder = typename TypeTraits::BuilderType; - IndexArrayRef indices = static_cast(*params_.indices); - ValueArrayRef values = static_cast(*params_.values); + IndexArrayRef indices = checked_cast(*params_.indices); + ValueArrayRef values = checked_cast(*params_.values); std::unique_ptr builder; RETURN_NOT_OK(MakeBuilder(params_.context->memory_pool(), values.type(), &builder)); RETURN_NOT_OK(builder->Reserve(indices.length())); RETURN_NOT_OK(UnpackValuesNullCount(params_.context, values, indices, - static_cast(builder.get()))); + checked_cast(builder.get()))); return builder->Finish(params_.out); } Status Visit(const NullType& t) { auto indices_length = params_.indices->length(); if (indices_length != 0) { - auto indices = static_cast(*params_.indices).raw_values(); + auto indices = checked_cast(*params_.indices).raw_values(); auto minmax = std::minmax_element(indices, indices + indices_length); auto min = static_cast(*minmax.first); auto max = static_cast(*minmax.second); diff --git a/cpp/src/arrow/extension_type.h b/cpp/src/arrow/extension_type.h index 6a1ca0b71553d..8bf4639bd1272 100644 --- a/cpp/src/arrow/extension_type.h +++ b/cpp/src/arrow/extension_type.h @@ -93,6 +93,10 @@ class ARROW_EXPORT ExtensionArray : public Array { ExtensionArray(const std::shared_ptr& type, const std::shared_ptr& storage); + const ExtensionType* extension_type() const { + return internal::checked_cast(data_->type.get()); + } + /// \brief The physical storage for the extension array std::shared_ptr storage() const { return storage_; }