diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index c989f855a5b0b..1666d178fac7b 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -172,8 +172,8 @@ if(ARROW_COMPUTE) compute/kernels/cast.cc compute/kernels/compare.cc compute/kernels/count.cc - compute/kernels/filter.cc compute/kernels/hash.cc + compute/kernels/filter.cc compute/kernels/mean.cc compute/kernels/sum.cc compute/kernels/take.cc diff --git a/cpp/src/arrow/array.cc b/cpp/src/arrow/array.cc index 9d37b45914bd0..b41d6dda6816d 100644 --- a/cpp/src/arrow/array.cc +++ b/cpp/src/arrow/array.cc @@ -1369,6 +1369,175 @@ std::shared_ptr MakeArray(const std::shared_ptr& data) { namespace internal { +// get the maximum buffer length required, then allocate a single zeroed buffer +// to use anywhere a buffer is required +class NullArrayFactory { + public: + struct GetBufferLength { + GetBufferLength(const std::shared_ptr& type, int64_t length) + : type_(*type), length_(length), buffer_length_(BitUtil::BytesForBits(length)) {} + + operator int64_t() && { + DCHECK_OK(VisitTypeInline(type_, this)); + return buffer_length_; + } + + template ::bytes_required(0))> + Status Visit(const T&) { + return MaxOf(TypeTraits::bytes_required(length_)); + } + + Status Visit(const ListType& type) { + // list's values array may be empty, but there must be at least one offset of 0 + return MaxOf(sizeof(int32_t)); + } + + Status Visit(const FixedSizeListType& type) { + return MaxOf(GetBufferLength(type.value_type(), type.list_size() * length_)); + } + + Status Visit(const StructType& type) { + for (const auto& child : type.children()) { + DCHECK_OK(MaxOf(GetBufferLength(child->type(), length_))); + } + return Status::OK(); + } + + Status Visit(const UnionType& type) { + // type codes + DCHECK_OK(MaxOf(length_)); + if (type.mode() == UnionMode::DENSE) { + // offsets + DCHECK_OK(MaxOf(sizeof(int32_t) * length_)); + } + for (const auto& child : type.children()) { + DCHECK_OK(MaxOf(GetBufferLength(child->type(), length_))); + } + return Status::OK(); + } + + Status Visit(const DictionaryType& type) { + DCHECK_OK(MaxOf(GetBufferLength(type.value_type(), length_))); + return MaxOf(GetBufferLength(type.index_type(), length_)); + } + + Status Visit(const ExtensionType& type) { + // XXX is an extension array's length always == storage length + return MaxOf(GetBufferLength(type.storage_type(), length_)); + } + + Status Visit(const DataType& type) { + return Status::NotImplemented("construction of all-null ", type); + } + + private: + Status MaxOf(int64_t buffer_length) { + if (buffer_length > buffer_length_) { + buffer_length_ = buffer_length; + } + return Status::OK(); + } + + const DataType& type_; + int64_t length_, buffer_length_; + }; + + NullArrayFactory(const std::shared_ptr& type, int64_t length, + std::shared_ptr* out) + : type_(type), length_(length), out_(out) {} + + Status CreateBuffer() { + int64_t buffer_length = GetBufferLength(type_, length_); + RETURN_NOT_OK(AllocateBuffer(buffer_length, &buffer_)); + std::memset(buffer_->mutable_data(), 0, buffer_->size()); + return Status::OK(); + } + + Status Create() { + if (buffer_ == nullptr) { + RETURN_NOT_OK(CreateBuffer()); + } + std::vector> child_data(type_->num_children()); + *out_ = ArrayData::Make(type_, length_, {buffer_}, child_data, length_, 0); + return VisitTypeInline(*type_, this); + } + + Status Visit(const NullType&) { return Status::OK(); } + + Status Visit(const FixedWidthType&) { + (*out_)->buffers.resize(2, buffer_); + return Status::OK(); + } + + Status Visit(const BinaryType&) { + (*out_)->buffers.resize(3, buffer_); + return Status::OK(); + } + + Status Visit(const ListType& type) { + (*out_)->buffers.resize(2, buffer_); + return CreateChild(0, length_, &(*out_)->child_data[0]); + } + + Status Visit(const FixedSizeListType& type) { + return CreateChild(0, length_ * type.list_size(), &(*out_)->child_data[0]); + } + + Status Visit(const StructType& type) { + for (int i = 0; i < type_->num_children(); ++i) { + DCHECK_OK(CreateChild(i, length_, &(*out_)->child_data[i])); + } + return Status::OK(); + } + + Status Visit(const UnionType& type) { + if (type.mode() == UnionMode::DENSE) { + (*out_)->buffers.resize(3, buffer_); + } else { + (*out_)->buffers.resize(2, buffer_); + } + + for (int i = 0; i < type_->num_children(); ++i) { + DCHECK_OK(CreateChild(i, length_, &(*out_)->child_data[i])); + } + return Status::OK(); + } + + Status Visit(const DictionaryType& type) { + (*out_)->buffers.resize(2, buffer_); + std::shared_ptr dictionary_data; + return MakeArrayOfNull(type.value_type(), 0, &(*out_)->dictionary); + } + + Status Visit(const DataType& type) { + return Status::NotImplemented("construction of all-null ", type); + } + + Status CreateChild(int i, int64_t length, std::shared_ptr* out) { + NullArrayFactory child_factory(type_->child(i)->type(), length, + &(*out_)->child_data[i]); + child_factory.buffer_ = buffer_; + return child_factory.Create(); + } + + std::shared_ptr type_; + int64_t length_; + std::shared_ptr* out_; + std::shared_ptr buffer_; +}; + +} // namespace internal + +Status MakeArrayOfNull(const std::shared_ptr& type, int64_t length, + std::shared_ptr* out) { + std::shared_ptr out_data; + RETURN_NOT_OK(internal::NullArrayFactory(type, length, &out_data).Create()); + *out = MakeArray(out_data); + return Status::OK(); +} + +namespace internal { + std::vector RechunkArraysConsistently( const std::vector& groups) { if (groups.size() <= 1) { diff --git a/cpp/src/arrow/array.h b/cpp/src/arrow/array.h index a655422730238..78542c6b5bec4 100644 --- a/cpp/src/arrow/array.h +++ b/cpp/src/arrow/array.h @@ -220,6 +220,14 @@ struct ARROW_EXPORT ArrayData { ARROW_EXPORT std::shared_ptr MakeArray(const std::shared_ptr& data); +/// \brief Create a strongly-typed Array instance with all elements null +/// \param[in] type the array type +/// \param[in] length the array length +/// \param[out] out resulting Array instance +ARROW_EXPORT +Status MakeArrayOfNull(const std::shared_ptr& type, int64_t length, + std::shared_ptr* out); + // ---------------------------------------------------------------------- // User array accessor types @@ -521,12 +529,15 @@ class ARROW_EXPORT ListArray : public Array { /// Return pointer to raw value offsets accounting for any slice offset const int32_t* raw_value_offsets() const { return raw_value_offsets_ + data_->offset; } - // Neither of these functions will perform boundschecking + // The following functions will not perform boundschecking int32_t value_offset(int64_t i) const { return raw_value_offsets_[i + data_->offset]; } int32_t value_length(int64_t i) const { i += data_->offset; return raw_value_offsets_[i + 1] - raw_value_offsets_[i]; } + std::shared_ptr value_slice(int64_t i) const { + return values_->Slice(value_offset(i), value_length(i)); + } protected: // This constructor defers SetData to a derived array class @@ -596,12 +607,15 @@ class ARROW_EXPORT FixedSizeListArray : public Array { std::shared_ptr value_type() const; - // Neither of these functions will perform boundschecking + // The following functions will not perform boundschecking int32_t value_offset(int64_t i) const { i += data_->offset; return static_cast(list_size_ * i); } int32_t value_length(int64_t i = 0) const { return list_size_; } + std::shared_ptr value_slice(int64_t i) const { + return values_->Slice(value_offset(i), value_length(i)); + } protected: void SetData(const std::shared_ptr& data); diff --git a/cpp/src/arrow/array/builder_primitive.cc b/cpp/src/arrow/array/builder_primitive.cc index d4def92760027..3c899c068cb84 100644 --- a/cpp/src/arrow/array/builder_primitive.cc +++ b/cpp/src/arrow/array/builder_primitive.cc @@ -128,4 +128,11 @@ Status BooleanBuilder::AppendValues(const std::vector& values) { return Status::OK(); } +Status BooleanBuilder::AppendValues(int64_t length, bool value) { + RETURN_NOT_OK(Reserve(length)); + data_builder_.UnsafeAppend(length, value); + ArrayBuilder::UnsafeSetNotNull(length); + return Status::OK(); +} + } // namespace arrow diff --git a/cpp/src/arrow/array/builder_primitive.h b/cpp/src/arrow/array/builder_primitive.h index 3d566846d1947..8abbe029e1341 100644 --- a/cpp/src/arrow/array/builder_primitive.h +++ b/cpp/src/arrow/array/builder_primitive.h @@ -409,6 +409,8 @@ class ARROW_EXPORT BooleanBuilder : public ArrayBuilder { return Status::OK(); } + Status AppendValues(int64_t length, bool value); + Status FinishInternal(std::shared_ptr* out) override; /// \cond FALSE diff --git a/cpp/src/arrow/compute/benchmark-util.h b/cpp/src/arrow/compute/benchmark-util.h index ee9cb9504a3d1..113fdd7031281 100644 --- a/cpp/src/arrow/compute/benchmark-util.h +++ b/cpp/src/arrow/compute/benchmark-util.h @@ -70,5 +70,28 @@ void RegressionSetArgs(benchmark::internal::Benchmark* bench) { BenchmarkSetArgsWithSizes(bench, {kL1Size}); } +// RAII struct to handle some of the boilerplate in regression benchmarks +struct RegressionArgs { + // size of memory tested (per iteration) in bytes + const int64_t size; + + // proportion of nulls in generated arrays + const double null_proportion; + + explicit RegressionArgs(benchmark::State& state) + : size(state.range(0)), + null_proportion(static_cast(state.range(1)) / 100.0), + state_(state) {} + + ~RegressionArgs() { + state_.counters["size"] = static_cast(size); + state_.counters["null_percent"] = static_cast(state_.range(1)); + state_.SetBytesProcessed(state_.iterations() * size); + } + + private: + benchmark::State& state_; +}; + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index 6c386c9b39c0b..1bbb5bc5f321e 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -20,13 +20,17 @@ arrow_install_all_headers("arrow/compute/kernels") add_arrow_test(boolean-test PREFIX "arrow-compute") add_arrow_test(cast-test PREFIX "arrow-compute") add_arrow_test(hash-test PREFIX "arrow-compute") -add_arrow_test(take-test PREFIX "arrow-compute") add_arrow_test(util-internal-test PREFIX "arrow-compute") # Aggregates add_arrow_test(aggregate-test PREFIX "arrow-compute") add_arrow_benchmark(aggregate-benchmark PREFIX "arrow-compute") -# Filters +# Comparison +add_arrow_test(compare-test PREFIX "arrow-compute") +add_arrow_benchmark(compare-benchmark PREFIX "arrow-compute") + +# Selection +add_arrow_test(take-test PREFIX "arrow-compute") add_arrow_test(filter-test PREFIX "arrow-compute") add_arrow_benchmark(filter-benchmark PREFIX "arrow-compute") diff --git a/cpp/src/arrow/compute/kernels/compare-benchmark.cc b/cpp/src/arrow/compute/kernels/compare-benchmark.cc new file mode 100644 index 0000000000000..6983fe5cc1bff --- /dev/null +++ b/cpp/src/arrow/compute/kernels/compare-benchmark.cc @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "benchmark/benchmark.h" + +#include + +#include "arrow/compute/benchmark-util.h" +#include "arrow/compute/kernel.h" +#include "arrow/compute/kernels/compare.h" +#include "arrow/compute/test-util.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" + +namespace arrow { +namespace compute { + +constexpr auto kSeed = 0x94378165; + +static void CompareArrayScalarKernel(benchmark::State& state) { + const int64_t memory_size = state.range(0); + const int64_t array_size = memory_size / sizeof(int64_t); + const double null_percent = static_cast(state.range(1)) / 100.0; + auto rand = random::RandomArrayGenerator(kSeed); + auto array = std::static_pointer_cast>( + rand.Int64(array_size, -100, 100, null_percent)); + + CompareOptions ge{GREATER_EQUAL}; + + FunctionContext ctx; + for (auto _ : state) { + Datum out; + ABORT_NOT_OK(Compare(&ctx, Datum(array), Datum(int64_t(0)), ge, &out)); + benchmark::DoNotOptimize(out); + } + + state.counters["size"] = static_cast(memory_size); + state.counters["null_percent"] = static_cast(state.range(1)); + state.SetBytesProcessed(state.iterations() * array_size * sizeof(int64_t)); +} + +static void CompareArrayArrayKernel(benchmark::State& state) { + const int64_t memory_size = state.range(0); + const int64_t array_size = memory_size / sizeof(int64_t); + const double null_percent = static_cast(state.range(1)) / 100.0; + auto rand = random::RandomArrayGenerator(kSeed); + auto lhs = std::static_pointer_cast>( + rand.Int64(array_size, -100, 100, null_percent)); + auto rhs = std::static_pointer_cast>( + rand.Int64(array_size, -100, 100, null_percent)); + + CompareOptions ge(GREATER_EQUAL); + + FunctionContext ctx; + for (auto _ : state) { + Datum out; + ABORT_NOT_OK(Compare(&ctx, Datum(lhs), Datum(rhs), ge, &out)); + benchmark::DoNotOptimize(out); + } + + state.counters["size"] = static_cast(memory_size); + state.counters["null_percent"] = static_cast(state.range(1)); + state.SetBytesProcessed(state.iterations() * array_size * sizeof(int64_t) * 2); +} + +BENCHMARK(CompareArrayScalarKernel)->Apply(RegressionSetArgs); +BENCHMARK(CompareArrayArrayKernel)->Apply(RegressionSetArgs); + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/compare-test.cc b/cpp/src/arrow/compute/kernels/compare-test.cc new file mode 100644 index 0000000000000..1a6339c57d878 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/compare-test.cc @@ -0,0 +1,390 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include +#include + +#include + +#include "arrow/array.h" +#include "arrow/compute/kernel.h" +#include "arrow/compute/kernels/compare.h" +#include "arrow/compute/test-util.h" +#include "arrow/type.h" +#include "arrow/type_traits.h" +#include "arrow/util/checked_cast.h" + +#include "arrow/testing/gtest_common.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" + +namespace arrow { +namespace compute { + +TEST(TestComparatorOperator, BasicOperator) { + using T = int32_t; + std::vector vals{0, 1, 2, 3, 4, 5, 6}; + + for (int32_t i : vals) { + for (int32_t j : vals) { + EXPECT_EQ((Comparator::Compare(i, j)), i == j); + EXPECT_EQ((Comparator::Compare(i, j)), i != j); + EXPECT_EQ((Comparator::Compare(i, j)), i > j); + EXPECT_EQ((Comparator::Compare(i, j)), i >= j); + EXPECT_EQ((Comparator::Compare(i, j)), i < j); + EXPECT_EQ((Comparator::Compare(i, j)), i <= j); + } + } +} + +template +static void ValidateCompare(FunctionContext* ctx, CompareOptions options, + const Datum& lhs, const Datum& rhs, const Datum& expected) { + Datum result; + + ASSERT_OK(Compare(ctx, lhs, rhs, options, &result)); + AssertArraysEqual(*expected.make_array(), *result.make_array()); +} + +template +static void ValidateCompare(FunctionContext* ctx, CompareOptions options, + const char* lhs_str, const Datum& rhs, + const char* expected_str) { + auto lhs = ArrayFromJSON(TypeTraits::type_singleton(), lhs_str); + auto expected = ArrayFromJSON(TypeTraits::type_singleton(), expected_str); + ValidateCompare(ctx, options, lhs, rhs, expected); +} + +template +static void ValidateCompare(FunctionContext* ctx, CompareOptions options, + const Datum& lhs, const char* rhs_str, + const char* expected_str) { + auto rhs = ArrayFromJSON(TypeTraits::type_singleton(), rhs_str); + auto expected = ArrayFromJSON(TypeTraits::type_singleton(), expected_str); + ValidateCompare(ctx, options, lhs, rhs, expected); +} + +template +static void ValidateCompare(FunctionContext* ctx, CompareOptions options, + const char* lhs_str, const char* rhs_str, + const char* expected_str) { + auto lhs = ArrayFromJSON(TypeTraits::type_singleton(), lhs_str); + auto rhs = ArrayFromJSON(TypeTraits::type_singleton(), rhs_str); + auto expected = ArrayFromJSON(TypeTraits::type_singleton(), expected_str); + ValidateCompare(ctx, options, lhs, rhs, expected); +} + +template +static inline bool SlowCompare(CompareOperator op, const T& lhs, const T& rhs) { + switch (op) { + case EQUAL: + return lhs == rhs; + case NOT_EQUAL: + return lhs != rhs; + case GREATER: + return lhs > rhs; + case GREATER_EQUAL: + return lhs >= rhs; + case LESS: + return lhs < rhs; + case LESS_EQUAL: + return lhs <= rhs; + default: + return false; + } +} + +template +static Datum SimpleScalarArrayCompare(CompareOptions options, const Datum& lhs, + const Datum& rhs) { + using ArrayType = typename TypeTraits::ArrayType; + using ScalarType = typename TypeTraits::ScalarType; + using T = typename TypeTraits::CType; + + bool swap = lhs.is_array(); + auto array = std::static_pointer_cast((swap ? lhs : rhs).make_array()); + T value = std::static_pointer_cast((swap ? rhs : lhs).scalar())->value; + + std::vector bitmap(array->length()); + for (int64_t i = 0; i < array->length(); i++) { + bitmap[i] = swap ? SlowCompare(options.op, array->Value(i), value) + : SlowCompare(options.op, value, array->Value(i)); + } + + std::shared_ptr result; + + if (array->null_count() == 0) { + ArrayFromVector(bitmap, &result); + } else { + std::vector null_bitmap(array->length()); + auto reader = internal::BitmapReader(array->null_bitmap_data(), array->offset(), + array->length()); + for (int64_t i = 0; i < array->length(); i++, reader.Next()) { + null_bitmap[i] = reader.IsSet(); + } + ArrayFromVector(null_bitmap, bitmap, &result); + } + + return Datum(result); +} + +template ::ArrayType> +static std::vector NullBitmapFromArrays(const ArrayType& lhs, + const ArrayType& rhs) { + auto left_lambda = [&lhs](int64_t i) { + return lhs.null_count() == 0 ? true : lhs.IsValid(i); + }; + + auto right_lambda = [&rhs](int64_t i) { + return rhs.null_count() == 0 ? true : rhs.IsValid(i); + }; + + const int64_t length = lhs.length(); + std::vector null_bitmap(length); + + for (int64_t i = 0; i < length; i++) { + null_bitmap[i] = left_lambda(i) && right_lambda(i); + } + + return null_bitmap; +} + +template +static Datum SimpleArrayArrayCompare(CompareOptions options, const Datum& lhs, + const Datum& rhs) { + using ArrayType = typename TypeTraits::ArrayType; + using T = typename TypeTraits::CType; + + auto l_array = std::static_pointer_cast(lhs.make_array()); + auto r_array = std::static_pointer_cast(rhs.make_array()); + const int64_t length = l_array->length(); + + std::vector bitmap(length); + for (int64_t i = 0; i < length; i++) { + bitmap[i] = SlowCompare(options.op, l_array->Value(i), r_array->Value(i)); + } + + std::shared_ptr result; + + if (l_array->null_count() == 0 && r_array->null_count() == 0) { + ArrayFromVector(bitmap, &result); + } else { + std::vector null_bitmap = NullBitmapFromArrays(*l_array, *r_array); + ArrayFromVector(null_bitmap, bitmap, &result); + } + + return Datum(result); +} + +template +static void ValidateCompare(FunctionContext* ctx, CompareOptions options, + const Datum& lhs, const Datum& rhs) { + Datum result; + + bool has_scalar = lhs.is_scalar() || rhs.is_scalar(); + Datum expected = has_scalar ? SimpleScalarArrayCompare(options, lhs, rhs) + : SimpleArrayArrayCompare(options, lhs, rhs); + + ValidateCompare(ctx, options, lhs, rhs, expected); +} + +template +class TestNumericCompareKernel : public ComputeFixture, public TestBase {}; + +TYPED_TEST_CASE(TestNumericCompareKernel, NumericArrowTypes); +TYPED_TEST(TestNumericCompareKernel, SimpleCompareArrayScalar) { + using ScalarType = typename TypeTraits::ScalarType; + using CType = typename TypeTraits::CType; + + Datum one(std::make_shared(CType(1))); + + CompareOptions eq(CompareOperator::EQUAL); + ValidateCompare(&this->ctx_, eq, "[]", one, "[]"); + ValidateCompare(&this->ctx_, eq, "[null]", one, "[null]"); + ValidateCompare(&this->ctx_, eq, "[0,0,1,1,2,2]", one, "[0,0,1,1,0,0]"); + ValidateCompare(&this->ctx_, eq, "[0,1,2,3,4,5]", one, "[0,1,0,0,0,0]"); + ValidateCompare(&this->ctx_, eq, "[5,4,3,2,1,0]", one, "[0,0,0,0,1,0]"); + ValidateCompare(&this->ctx_, eq, "[null,0,1,1]", one, "[null,0,1,1]"); + + CompareOptions neq(CompareOperator::NOT_EQUAL); + ValidateCompare(&this->ctx_, neq, "[]", one, "[]"); + ValidateCompare(&this->ctx_, neq, "[null]", one, "[null]"); + ValidateCompare(&this->ctx_, neq, "[0,0,1,1,2,2]", one, "[1,1,0,0,1,1]"); + ValidateCompare(&this->ctx_, neq, "[0,1,2,3,4,5]", one, "[1,0,1,1,1,1]"); + ValidateCompare(&this->ctx_, neq, "[5,4,3,2,1,0]", one, "[1,1,1,1,0,1]"); + ValidateCompare(&this->ctx_, neq, "[null,0,1,1]", one, "[null,1,0,0]"); + + CompareOptions gt(CompareOperator::GREATER); + ValidateCompare(&this->ctx_, gt, "[]", one, "[]"); + ValidateCompare(&this->ctx_, gt, "[null]", one, "[null]"); + ValidateCompare(&this->ctx_, gt, "[0,0,1,1,2,2]", one, "[0,0,0,0,1,1]"); + ValidateCompare(&this->ctx_, gt, "[0,1,2,3,4,5]", one, "[0,0,1,1,1,1]"); + ValidateCompare(&this->ctx_, gt, "[4,5,6,7,8,9]", one, "[1,1,1,1,1,1]"); + ValidateCompare(&this->ctx_, gt, "[null,0,1,1]", one, "[null,0,0,0]"); + + CompareOptions gte(CompareOperator::GREATER_EQUAL); + ValidateCompare(&this->ctx_, gte, "[]", one, "[]"); + ValidateCompare(&this->ctx_, gte, "[null]", one, "[null]"); + ValidateCompare(&this->ctx_, gte, "[0,0,1,1,2,2]", one, "[0,0,1,1,1,1]"); + ValidateCompare(&this->ctx_, gte, "[0,1,2,3,4,5]", one, "[0,1,1,1,1,1]"); + ValidateCompare(&this->ctx_, gte, "[4,5,6,7,8,9]", one, "[1,1,1,1,1,1]"); + ValidateCompare(&this->ctx_, gte, "[null,0,1,1]", one, "[null,0,1,1]"); + + CompareOptions lt(CompareOperator::LESS); + ValidateCompare(&this->ctx_, lt, "[]", one, "[]"); + ValidateCompare(&this->ctx_, lt, "[null]", one, "[null]"); + ValidateCompare(&this->ctx_, lt, "[0,0,1,1,2,2]", one, "[1,1,0,0,0,0]"); + ValidateCompare(&this->ctx_, lt, "[0,1,2,3,4,5]", one, "[1,0,0,0,0,0]"); + ValidateCompare(&this->ctx_, lt, "[4,5,6,7,8,9]", one, "[0,0,0,0,0,0]"); + ValidateCompare(&this->ctx_, lt, "[null,0,1,1]", one, "[null,1,0,0]"); + + CompareOptions lte(CompareOperator::LESS_EQUAL); + ValidateCompare(&this->ctx_, lte, "[]", one, "[]"); + ValidateCompare(&this->ctx_, lte, "[null]", one, "[null]"); + ValidateCompare(&this->ctx_, lte, "[0,0,1,1,2,2]", one, "[1,1,1,1,0,0]"); + ValidateCompare(&this->ctx_, lte, "[0,1,2,3,4,5]", one, "[1,1,0,0,0,0]"); + ValidateCompare(&this->ctx_, lte, "[4,5,6,7,8,9]", one, "[0,0,0,0,0,0]"); + ValidateCompare(&this->ctx_, lte, "[null,0,1,1]", one, "[null,1,1,1]"); +} + +TYPED_TEST(TestNumericCompareKernel, SimpleCompareScalarArray) { + using ScalarType = typename TypeTraits::ScalarType; + using CType = typename TypeTraits::CType; + + Datum one(std::make_shared(CType(1))); + + CompareOptions eq(CompareOperator::EQUAL); + ValidateCompare(&this->ctx_, eq, one, "[]", "[]"); + ValidateCompare(&this->ctx_, eq, one, "[null]", "[null]"); + ValidateCompare(&this->ctx_, eq, one, "[0,0,1,1,2,2]", "[0,0,1,1,0,0]"); + ValidateCompare(&this->ctx_, eq, one, "[0,1,2,3,4,5]", "[0,1,0,0,0,0]"); + ValidateCompare(&this->ctx_, eq, one, "[5,4,3,2,1,0]", "[0,0,0,0,1,0]"); + ValidateCompare(&this->ctx_, eq, one, "[null,0,1,1]", "[null,0,1,1]"); + + CompareOptions neq(CompareOperator::NOT_EQUAL); + ValidateCompare(&this->ctx_, neq, one, "[]", "[]"); + ValidateCompare(&this->ctx_, neq, one, "[null]", "[null]"); + ValidateCompare(&this->ctx_, neq, one, "[0,0,1,1,2,2]", "[1,1,0,0,1,1]"); + ValidateCompare(&this->ctx_, neq, one, "[0,1,2,3,4,5]", "[1,0,1,1,1,1]"); + ValidateCompare(&this->ctx_, neq, one, "[5,4,3,2,1,0]", "[1,1,1,1,0,1]"); + ValidateCompare(&this->ctx_, neq, one, "[null,0,1,1]", "[null,1,0,0]"); + + CompareOptions gt(CompareOperator::GREATER); + ValidateCompare(&this->ctx_, gt, one, "[]", "[]"); + ValidateCompare(&this->ctx_, gt, one, "[null]", "[null]"); + ValidateCompare(&this->ctx_, gt, one, "[0,0,1,1,2,2]", "[1,1,0,0,0,0]"); + ValidateCompare(&this->ctx_, gt, one, "[0,1,2,3,4,5]", "[1,0,0,0,0,0]"); + ValidateCompare(&this->ctx_, gt, one, "[4,5,6,7,8,9]", "[0,0,0,0,0,0]"); + ValidateCompare(&this->ctx_, gt, one, "[null,0,1,1]", "[null,1,0,0]"); + + CompareOptions gte(CompareOperator::GREATER_EQUAL); + ValidateCompare(&this->ctx_, gte, one, "[]", "[]"); + ValidateCompare(&this->ctx_, gte, one, "[null]", "[null]"); + ValidateCompare(&this->ctx_, gte, one, "[0,0,1,1,2,2]", "[1,1,1,1,0,0]"); + ValidateCompare(&this->ctx_, gte, one, "[0,1,2,3,4,5]", "[1,1,0,0,0,0]"); + ValidateCompare(&this->ctx_, gte, one, "[4,5,6,7,8,9]", "[0,0,0,0,0,0]"); + ValidateCompare(&this->ctx_, gte, one, "[null,0,1,1]", "[null,1,1,1]"); + + CompareOptions lt(CompareOperator::LESS); + ValidateCompare(&this->ctx_, lt, one, "[]", "[]"); + ValidateCompare(&this->ctx_, lt, one, "[null]", "[null]"); + ValidateCompare(&this->ctx_, lt, one, "[0,0,1,1,2,2]", "[0,0,0,0,1,1]"); + ValidateCompare(&this->ctx_, lt, one, "[0,1,2,3,4,5]", "[0,0,1,1,1,1]"); + ValidateCompare(&this->ctx_, lt, one, "[4,5,6,7,8,9]", "[1,1,1,1,1,1]"); + ValidateCompare(&this->ctx_, lt, one, "[null,0,1,1]", "[null,0,0,0]"); + + CompareOptions lte(CompareOperator::LESS_EQUAL); + ValidateCompare(&this->ctx_, lte, one, "[]", "[]"); + ValidateCompare(&this->ctx_, lte, one, "[null]", "[null]"); + ValidateCompare(&this->ctx_, lte, one, "[0,0,1,1,2,2]", "[0,0,1,1,1,1]"); + ValidateCompare(&this->ctx_, lte, one, "[0,1,2,3,4,5]", "[0,1,1,1,1,1]"); + ValidateCompare(&this->ctx_, lte, one, "[4,5,6,7,8,9]", "[1,1,1,1,1,1]"); + ValidateCompare(&this->ctx_, lte, one, "[null,0,1,1]", "[null,0,1,1]"); +} + +TYPED_TEST(TestNumericCompareKernel, TestNullScalar) { + /* Ensure that null scalar broadcast to all null results. */ + using ScalarType = typename TypeTraits::ScalarType; + using CType = typename TypeTraits::CType; + + Datum null(std::make_shared(CType(0), false)); + EXPECT_FALSE(null.scalar()->is_valid); + + CompareOptions eq(CompareOperator::EQUAL); + ValidateCompare(&this->ctx_, eq, "[]", null, "[]"); + ValidateCompare(&this->ctx_, eq, null, "[]", "[]"); + ValidateCompare(&this->ctx_, eq, "[null]", null, "[null]"); + ValidateCompare(&this->ctx_, eq, null, "[null]", "[null]"); + ValidateCompare(&this->ctx_, eq, null, "[1,2,3]", "[null, null, null]"); +} + +TYPED_TEST_CASE(TestNumericCompareKernel, NumericArrowTypes); +TYPED_TEST(TestNumericCompareKernel, RandomCompareArrayScalar) { + using ScalarType = typename TypeTraits::ScalarType; + using CType = typename TypeTraits::CType; + + auto rand = random::RandomArrayGenerator(0x5416447); + for (size_t i = 3; i < 13; i++) { + for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) { + for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) { + const int64_t length = static_cast(1ULL << i); + auto array = Datum(rand.Numeric(length, 0, 100, null_probability)); + auto fifty = Datum(std::make_shared(CType(50))); + auto options = CompareOptions(op); + ValidateCompare(&this->ctx_, options, array, fifty); + ValidateCompare(&this->ctx_, options, fifty, array); + } + } + } +} + +TYPED_TEST(TestNumericCompareKernel, SimpleCompareArrayArray) { + /* Ensure that null scalar broadcast to all null results. */ + CompareOptions eq(CompareOperator::EQUAL); + ValidateCompare(&this->ctx_, eq, "[]", "[]", "[]"); + ValidateCompare(&this->ctx_, eq, "[null]", "[null]", "[null]"); + ValidateCompare(&this->ctx_, eq, "[1]", "[1]", "[1]"); + ValidateCompare(&this->ctx_, eq, "[1]", "[2]", "[0]"); + ValidateCompare(&this->ctx_, eq, "[null]", "[1]", "[null]"); + ValidateCompare(&this->ctx_, eq, "[1]", "[null]", "[null]"); + + CompareOptions lte(CompareOperator::LESS_EQUAL); + ValidateCompare(&this->ctx_, lte, "[1,2,3,4,5]", "[2,3,4,5,6]", + "[1,1,1,1,1]"); +} + +TYPED_TEST(TestNumericCompareKernel, RandomCompareArrayArray) { + auto rand = random::RandomArrayGenerator(0x5416447); + for (size_t i = 3; i < 5; i++) { + for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) { + for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) { + const int64_t length = static_cast(1ULL << i); + auto lhs = Datum(rand.Numeric(length << i, 0, 100, null_probability)); + auto rhs = Datum(rand.Numeric(length << i, 0, 100, null_probability)); + auto options = CompareOptions(op); + ValidateCompare(&this->ctx_, options, lhs, rhs); + } + } + } +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/compare.cc b/cpp/src/arrow/compute/kernels/compare.cc index 040793f4e6569..f5fab7d0fe2a3 100644 --- a/cpp/src/arrow/compute/kernels/compare.cc +++ b/cpp/src/arrow/compute/kernels/compare.cc @@ -19,7 +19,6 @@ #include "arrow/compute/context.h" #include "arrow/compute/kernel.h" -#include "arrow/compute/kernels/filter.h" #include "arrow/compute/kernels/util-internal.h" #include "arrow/util/bit-util.h" #include "arrow/util/logging.h" @@ -28,8 +27,35 @@ namespace arrow { namespace compute { -class FunctionContext; -struct Datum; +std::shared_ptr CompareBinaryKernel::out_type() const { + return compare_function_->out_type(); +} + +Status CompareBinaryKernel::Call(FunctionContext* ctx, const Datum& left, + const Datum& right, Datum* out) { + DCHECK(left.type()->Equals(right.type())); + + auto lk = left.kind(); + auto rk = right.kind(); + auto out_array = out->array(); + + if (lk == Datum::ARRAY && rk == Datum::SCALAR) { + auto array = left.array(); + auto scalar = right.scalar(); + return compare_function_->Compare(*array, *scalar, &out_array); + } else if (lk == Datum::SCALAR && rk == Datum::ARRAY) { + auto scalar = left.scalar(); + auto array = right.array(); + auto out_array = out->array(); + return compare_function_->Compare(*scalar, *array, &out_array); + } else if (lk == Datum::ARRAY && rk == Datum::ARRAY) { + auto lhs = left.array(); + auto rhs = right.array(); + return compare_function_->Compare(*lhs, *rhs, &out_array); + } + + return Status::Invalid("Invalid datum signature for CompareBinaryKernel"); +} template ::ScalarType, @@ -76,13 +102,14 @@ static Status CompareArrayArray(const ArrayData& lhs, const ArrayData& rhs, } template -class CompareFunction final : public FilterFunction { +class CompareFunctionImpl final : public CompareFunction { + using ArrayType = typename TypeTraits::ArrayType; using ScalarType = typename TypeTraits::ScalarType; public: - explicit CompareFunction(FunctionContext* ctx) : ctx_(ctx) {} + explicit CompareFunctionImpl(FunctionContext* ctx) : ctx_(ctx) {} - Status Filter(const ArrayData& array, const Scalar& scalar, ArrayData* output) const { + Status Compare(const ArrayData& array, const Scalar& scalar, ArrayData* output) const { // Caller must cast DCHECK(array.type->Equals(scalar.type)); // Output must be a boolean array @@ -103,7 +130,7 @@ class CompareFunction final : public FilterFunction { array, static_cast(scalar), bitmap_result); } - Status Filter(const Scalar& scalar, const ArrayData& array, ArrayData* output) const { + Status Compare(const Scalar& scalar, const ArrayData& array, ArrayData* output) const { // Caller must cast DCHECK(array.type->Equals(scalar.type)); // Output must be a boolean array @@ -124,7 +151,7 @@ class CompareFunction final : public FilterFunction { array, bitmap_result); } - Status Filter(const ArrayData& lhs, const ArrayData& rhs, ArrayData* output) const { + Status Compare(const ArrayData& lhs, const ArrayData& rhs, ArrayData* output) const { // Caller must cast DCHECK(lhs.type->Equals(rhs.type)); // Output must be a boolean array @@ -146,13 +173,13 @@ class CompareFunction final : public FilterFunction { }; template -static inline std::shared_ptr MakeCompareFunctionTypeOp( +static inline std::shared_ptr MakeCompareFunctionTypeOp( FunctionContext* ctx) { - return std::make_shared>(ctx); + return std::make_shared>(ctx); } template -static inline std::shared_ptr MakeCompareFilterFunctionType( +static inline std::shared_ptr MakeCompareFunctionType( FunctionContext* ctx, struct CompareOptions options) { switch (options.op) { case CompareOperator::EQUAL: @@ -172,40 +199,40 @@ static inline std::shared_ptr MakeCompareFilterFunctionType( return nullptr; } -std::shared_ptr MakeCompareFilterFunction(FunctionContext* ctx, - const DataType& type, - struct CompareOptions options) { +std::shared_ptr MakeCompareFunction(FunctionContext* ctx, + const DataType& type, + struct CompareOptions options) { switch (type.id()) { case UInt8Type::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case Int8Type::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case UInt16Type::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case Int16Type::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case UInt32Type::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case Int32Type::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case UInt64Type::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case Int64Type::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case FloatType::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case DoubleType::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case Date32Type::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case Date64Type::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case TimestampType::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case Time32Type::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); case Time64Type::type_id: - return MakeCompareFilterFunctionType(ctx, options); + return MakeCompareFunctionType(ctx, options); default: return nullptr; } @@ -219,15 +246,15 @@ Status Compare(FunctionContext* context, const Datum& left, const Datum& right, auto type = left.type(); DCHECK(type->Equals(right.type())); // Requires that both types are equal. - auto fn = MakeCompareFilterFunction(context, *type, options); + auto fn = MakeCompareFunction(context, *type, options); if (fn == nullptr) { return Status::NotImplemented("Compare not implemented for type ", type->ToString()); } - FilterBinaryKernel filter_kernel(fn); + CompareBinaryKernel filter_kernel(fn); detail::PrimitiveAllocatingBinaryKernel kernel(&filter_kernel); - const int64_t length = FilterBinaryKernel::out_length(left, right); + const int64_t length = CompareBinaryKernel::out_length(left, right); out->value = ArrayData::Make(filter_kernel.out_type(), length); return kernel.Call(context, left, right, out); diff --git a/cpp/src/arrow/compute/kernels/compare.h b/cpp/src/arrow/compute/kernels/compare.h index a192451291677..b4c9612ae8caa 100644 --- a/cpp/src/arrow/compute/kernels/compare.h +++ b/cpp/src/arrow/compute/kernels/compare.h @@ -19,6 +19,7 @@ #include +#include "arrow/compute/kernel.h" #include "arrow/util/visibility.h" namespace arrow { @@ -31,9 +32,68 @@ class Status; namespace compute { struct Datum; -class FilterFunction; class FunctionContext; +/// CompareFunction is an interface for Comparisons +/// +/// Comparisons take an array and emits a selection vector. The selection vector +/// is given in the form of a bitmask as a BooleanArray result. +class ARROW_EXPORT CompareFunction { + public: + /// Compare an array with a scalar argument. + virtual Status Compare(const ArrayData& array, const Scalar& scalar, + ArrayData* output) const = 0; + + Status Compare(const ArrayData& array, const Scalar& scalar, + std::shared_ptr* output) { + return Compare(array, scalar, output->get()); + } + + virtual Status Compare(const Scalar& scalar, const ArrayData& array, + ArrayData* output) const = 0; + + Status Compare(const Scalar& scalar, const ArrayData& array, + std::shared_ptr* output) { + return Compare(scalar, array, output->get()); + } + + /// Compare an array with an array argument. + virtual Status Compare(const ArrayData& lhs, const ArrayData& rhs, + ArrayData* output) const = 0; + + Status Compare(const ArrayData& lhs, const ArrayData& rhs, + std::shared_ptr* output) { + return Compare(lhs, rhs, output->get()); + } + + /// By default, CompareFunction emits a result bitmap. + virtual std::shared_ptr out_type() const { return boolean(); } + + virtual ~CompareFunction() {} +}; + +/// \brief BinaryKernel bound to a select function +class ARROW_EXPORT CompareBinaryKernel : public BinaryKernel { + public: + explicit CompareBinaryKernel(std::shared_ptr& select) + : compare_function_(select) {} + + Status Call(FunctionContext* ctx, const Datum& left, const Datum& right, + Datum* out) override; + + static int64_t out_length(const Datum& left, const Datum& right) { + if (left.kind() == Datum::ARRAY) return left.length(); + if (right.kind() == Datum::ARRAY) return right.length(); + + return 0; + } + + std::shared_ptr out_type() const override; + + private: + std::shared_ptr compare_function_; +}; + enum CompareOperator { EQUAL, NOT_EQUAL, @@ -82,7 +142,7 @@ struct CompareOptions { enum CompareOperator op; }; -/// \brief Return a Compare FilterFunction +/// \brief Return a Compare CompareFunction /// /// \param[in] context FunctionContext passing context information /// \param[in] type required to specialize the kernel @@ -91,9 +151,9 @@ struct CompareOptions { /// \since 0.14.0 /// \note API not yet finalized ARROW_EXPORT -std::shared_ptr MakeCompareFilterFunction(FunctionContext* context, - const DataType& type, - struct CompareOptions options); +std::shared_ptr MakeCompareFunction(FunctionContext* context, + const DataType& type, + struct CompareOptions options); /// \brief Compare a numeric array with a scalar. /// diff --git a/cpp/src/arrow/compute/kernels/filter-benchmark.cc b/cpp/src/arrow/compute/kernels/filter-benchmark.cc index 00de199bfe90d..3eb460adc02d4 100644 --- a/cpp/src/arrow/compute/kernels/filter-benchmark.cc +++ b/cpp/src/arrow/compute/kernels/filter-benchmark.cc @@ -17,11 +17,9 @@ #include "benchmark/benchmark.h" -#include +#include "arrow/compute/kernels/filter.h" #include "arrow/compute/benchmark-util.h" -#include "arrow/compute/kernel.h" -#include "arrow/compute/kernels/compare.h" #include "arrow/compute/test-util.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" @@ -29,54 +27,60 @@ namespace arrow { namespace compute { -static void CompareArrayScalarKernel(benchmark::State& state) { - const int64_t memory_size = state.range(0) / 4; - const int64_t array_size = memory_size / sizeof(int64_t); - const double null_percent = static_cast(state.range(1)) / 100.0; - auto rand = random::RandomArrayGenerator(0x94378165); - auto array = std::static_pointer_cast>( - rand.Int64(array_size, -100, 100, null_percent)); +constexpr auto kSeed = 0x0ff1ce; + +static void FilterInt64(benchmark::State& state) { + RegressionArgs args(state); - CompareOptions ge{GREATER_EQUAL}; + const int64_t array_size = args.size / sizeof(int64_t); + auto rand = random::RandomArrayGenerator(kSeed); + auto array = std::static_pointer_cast>( + rand.Int64(array_size, -100, 100, args.null_proportion)); + auto filter = std::static_pointer_cast( + rand.Boolean(array_size, 0.75, args.null_proportion)); FunctionContext ctx; for (auto _ : state) { Datum out; - ABORT_NOT_OK(Compare(&ctx, Datum(array), Datum(int64_t(0)), ge, &out)); + ABORT_NOT_OK(Filter(&ctx, Datum(array), Datum(filter), &out)); benchmark::DoNotOptimize(out); } - - state.counters["size"] = static_cast(memory_size); - state.counters["null_percent"] = static_cast(state.range(1)); - state.SetBytesProcessed(state.iterations() * array_size * sizeof(int64_t)); } -static void CompareArrayArrayKernel(benchmark::State& state) { - const int64_t memory_size = state.range(0) / 4; - const int64_t array_size = memory_size / sizeof(int64_t); - const double null_percent = static_cast(state.range(1)) / 100.0; - auto rand = random::RandomArrayGenerator(0x94378165); - auto lhs = std::static_pointer_cast>( - rand.Int64(array_size, -100, 100, null_percent)); - auto rhs = std::static_pointer_cast>( - rand.Int64(array_size, -100, 100, null_percent)); +static void FilterFixedSizeList1Int64(benchmark::State& state) { + RegressionArgs args(state); - CompareOptions ge(GREATER_EQUAL); + const int64_t array_size = args.size / sizeof(int64_t); + auto rand = random::RandomArrayGenerator(kSeed); + auto int_array = std::static_pointer_cast>( + rand.Int64(array_size, -100, 100, args.null_proportion)); + auto array = std::make_shared( + fixed_size_list(int64(), 1), array_size, int_array, int_array->null_bitmap(), + int_array->null_count()); + auto filter = std::static_pointer_cast( + rand.Boolean(array_size, 0.75, args.null_proportion)); FunctionContext ctx; for (auto _ : state) { Datum out; - ABORT_NOT_OK(Compare(&ctx, Datum(lhs), Datum(rhs), ge, &out)); + ABORT_NOT_OK(Filter(&ctx, Datum(array), Datum(filter), &out)); benchmark::DoNotOptimize(out); } - - state.counters["size"] = static_cast(memory_size); - state.counters["null_percent"] = static_cast(state.range(1)); - state.SetBytesProcessed(state.iterations() * array_size * sizeof(int64_t) * 2); } -BENCHMARK(CompareArrayScalarKernel)->Apply(RegressionSetArgs); -BENCHMARK(CompareArrayArrayKernel)->Apply(RegressionSetArgs); +BENCHMARK(FilterInt64) + ->Apply(RegressionSetArgs) + ->Args({1 << 20, 1}) + ->Args({1 << 23, 1}) + ->MinTime(1.0) + ->Unit(benchmark::TimeUnit::kNanosecond); + +BENCHMARK(FilterFixedSizeList1Int64) + ->Apply(RegressionSetArgs) + ->Args({1 << 20, 1}) + ->Args({1 << 23, 1}) + ->MinTime(1.0) + ->Unit(benchmark::TimeUnit::kNanosecond); } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/filter-test.cc b/cpp/src/arrow/compute/kernels/filter-test.cc index 1c8967cf7aca7..7b349492b1daa 100644 --- a/cpp/src/arrow/compute/kernels/filter-test.cc +++ b/cpp/src/arrow/compute/kernels/filter-test.cc @@ -15,376 +15,358 @@ // specific language governing permissions and limitations // under the License. -#include #include -#include -#include -#include +#include -#include - -#include "arrow/array.h" -#include "arrow/compute/kernel.h" +#include "arrow/compute/context.h" +#include "arrow/compute/kernels/boolean.h" #include "arrow/compute/kernels/compare.h" #include "arrow/compute/kernels/filter.h" #include "arrow/compute/test-util.h" -#include "arrow/type.h" -#include "arrow/type_traits.h" -#include "arrow/util/checked_cast.h" - #include "arrow/testing/gtest_common.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" +#include "arrow/testing/util.h" namespace arrow { namespace compute { -TEST(TestComparatorOperator, BasicOperator) { - using T = int32_t; - std::vector vals{0, 1, 2, 3, 4, 5, 6}; - - for (int32_t i : vals) { - for (int32_t j : vals) { - EXPECT_EQ((Comparator::Compare(i, j)), i == j); - EXPECT_EQ((Comparator::Compare(i, j)), i != j); - EXPECT_EQ((Comparator::Compare(i, j)), i > j); - EXPECT_EQ((Comparator::Compare(i, j)), i >= j); - EXPECT_EQ((Comparator::Compare(i, j)), i < j); - EXPECT_EQ((Comparator::Compare(i, j)), i <= j); +using internal::checked_pointer_cast; +using util::string_view; + +template +class TestFilterKernel : public ComputeFixture, public TestBase { + protected: + void AssertFilterArrays(const std::shared_ptr& values, + const std::shared_ptr& filter, + const std::shared_ptr& expected) { + std::shared_ptr actual; + ASSERT_OK(arrow::compute::Filter(&this->ctx_, *values, *filter, &actual)); + AssertArraysEqual(*expected, *actual); + } + void AssertFilter(const std::shared_ptr& type, const std::string& values, + const std::string& filter, const std::string& expected) { + std::shared_ptr actual; + ASSERT_OK(this->Filter(type, values, filter, &actual)); + AssertArraysEqual(*ArrayFromJSON(type, expected), *actual); + } + Status Filter(const std::shared_ptr& type, const std::string& values, + const std::string& filter, std::shared_ptr* out) { + return arrow::compute::Filter(&this->ctx_, *ArrayFromJSON(type, values), + *ArrayFromJSON(boolean(), filter), out); + } + void ValidateFilter(const std::shared_ptr& values, + const std::shared_ptr& filter_boxed) { + std::shared_ptr filtered; + ASSERT_OK(arrow::compute::Filter(&this->ctx_, *values, *filter_boxed, &filtered)); + + auto filter = checked_pointer_cast(filter_boxed); + int64_t values_i = 0, filtered_i = 0; + for (; values_i < values->length(); ++values_i, ++filtered_i) { + if (filter->IsNull(values_i)) { + ASSERT_LT(filtered_i, filtered->length()); + ASSERT_TRUE(filtered->IsNull(filtered_i)); + continue; + } + if (!filter->Value(values_i)) { + // this element was filtered out; don't examine filtered + --filtered_i; + continue; + } + ASSERT_LT(filtered_i, filtered->length()); + ASSERT_TRUE(values->RangeEquals(values_i, values_i + 1, filtered_i, filtered)); } + ASSERT_EQ(filtered_i, filtered->length()); } -} +}; -template -static void ValidateCompare(FunctionContext* ctx, CompareOptions options, - const Datum& lhs, const Datum& rhs, const Datum& expected) { - Datum result; +class TestFilterKernelWithNull : public TestFilterKernel { + protected: + void AssertFilter(const std::string& values, const std::string& filter, + const std::string& expected) { + TestFilterKernel::AssertFilter(utf8(), values, filter, expected); + } +}; - ASSERT_OK(Compare(ctx, lhs, rhs, options, &result)); - AssertArraysEqual(*expected.make_array(), *result.make_array()); +TEST_F(TestFilterKernelWithNull, FilterNull) { + this->AssertFilter("[null, null, null]", "[0, 1, 0]", "[null]"); + this->AssertFilter("[null, null, null]", "[1, 1, 0]", "[null, null]"); } -template -static void ValidateCompare(FunctionContext* ctx, CompareOptions options, - const char* lhs_str, const Datum& rhs, - const char* expected_str) { - auto lhs = ArrayFromJSON(TypeTraits::type_singleton(), lhs_str); - auto expected = ArrayFromJSON(TypeTraits::type_singleton(), expected_str); - ValidateCompare(ctx, options, lhs, rhs, expected); -} +class TestFilterKernelWithBoolean : public TestFilterKernel { + protected: + void AssertFilter(const std::string& values, const std::string& filter, + const std::string& expected) { + TestFilterKernel::AssertFilter(boolean(), values, filter, expected); + } +}; -template -static void ValidateCompare(FunctionContext* ctx, CompareOptions options, - const Datum& lhs, const char* rhs_str, - const char* expected_str) { - auto rhs = ArrayFromJSON(TypeTraits::type_singleton(), rhs_str); - auto expected = ArrayFromJSON(TypeTraits::type_singleton(), expected_str); - ValidateCompare(ctx, options, lhs, rhs, expected); +TEST_F(TestFilterKernelWithBoolean, FilterBoolean) { + this->AssertFilter("[true, false, true]", "[0, 1, 0]", "[false]"); + this->AssertFilter("[null, false, true]", "[0, 1, 0]", "[false]"); + this->AssertFilter("[true, false, true]", "[null, 1, 0]", "[null, false]"); } template -static void ValidateCompare(FunctionContext* ctx, CompareOptions options, - const char* lhs_str, const char* rhs_str, - const char* expected_str) { - auto lhs = ArrayFromJSON(TypeTraits::type_singleton(), lhs_str); - auto rhs = ArrayFromJSON(TypeTraits::type_singleton(), rhs_str); - auto expected = ArrayFromJSON(TypeTraits::type_singleton(), expected_str); - ValidateCompare(ctx, options, lhs, rhs, expected); -} - -template -static inline bool SlowCompare(CompareOperator op, const T& lhs, const T& rhs) { - switch (op) { - case EQUAL: - return lhs == rhs; - case NOT_EQUAL: - return lhs != rhs; - case GREATER: - return lhs > rhs; - case GREATER_EQUAL: - return lhs >= rhs; - case LESS: - return lhs < rhs; - case LESS_EQUAL: - return lhs <= rhs; - default: - return false; +class TestFilterKernelWithNumeric : public TestFilterKernel { + protected: + void AssertFilter(const std::string& values, const std::string& filter, + const std::string& expected) { + TestFilterKernel::AssertFilter(type_singleton(), values, filter, expected); } -} - -template -static Datum SimpleScalarArrayCompare(CompareOptions options, const Datum& lhs, - const Datum& rhs) { - using ArrayType = typename TypeTraits::ArrayType; - using ScalarType = typename TypeTraits::ScalarType; - using T = typename TypeTraits::CType; - - bool swap = lhs.is_array(); - auto array = std::static_pointer_cast((swap ? lhs : rhs).make_array()); - T value = std::static_pointer_cast((swap ? rhs : lhs).scalar())->value; - - std::vector bitmap(array->length()); - for (int64_t i = 0; i < array->length(); i++) { - bitmap[i] = swap ? SlowCompare(options.op, array->Value(i), value) - : SlowCompare(options.op, value, array->Value(i)); + std::shared_ptr type_singleton() { + return TypeTraits::type_singleton(); } +}; + +TYPED_TEST_CASE(TestFilterKernelWithNumeric, NumericArrowTypes); +TYPED_TEST(TestFilterKernelWithNumeric, FilterNumeric) { + this->AssertFilter("[]", "[]", "[]"); + + this->AssertFilter("[9]", "[0]", "[]"); + this->AssertFilter("[9]", "[1]", "[9]"); + this->AssertFilter("[9]", "[null]", "[null]"); + this->AssertFilter("[null]", "[0]", "[]"); + this->AssertFilter("[null]", "[1]", "[null]"); + this->AssertFilter("[null]", "[null]", "[null]"); + + this->AssertFilter("[7, 8, 9]", "[0, 1, 0]", "[8]"); + this->AssertFilter("[7, 8, 9]", "[1, 0, 1]", "[7, 9]"); + this->AssertFilter("[null, 8, 9]", "[0, 1, 0]", "[8]"); + this->AssertFilter("[7, 8, 9]", "[null, 1, 0]", "[null, 8]"); + this->AssertFilter("[7, 8, 9]", "[1, null, 1]", "[7, null, 9]"); +} - std::shared_ptr result; - - if (array->null_count() == 0) { - ArrayFromVector(bitmap, &result); - } else { - std::vector null_bitmap(array->length()); - auto reader = internal::BitmapReader(array->null_bitmap_data(), array->offset(), - array->length()); - for (int64_t i = 0; i < array->length(); i++, reader.Next()) { - null_bitmap[i] = reader.IsSet(); +TYPED_TEST(TestFilterKernelWithNumeric, FilterRandomNumeric) { + auto rand = random::RandomArrayGenerator(0x5416447); + for (size_t i = 3; i < 13; i++) { + const int64_t length = static_cast(1ULL << i); + for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) { + for (auto filter_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) { + auto values = rand.Numeric(length, 0, 127, null_probability); + auto filter = rand.Boolean(length, filter_probability, null_probability); + this->ValidateFilter(values, filter); + } } - ArrayFromVector(null_bitmap, bitmap, &result); } - - return Datum(result); } -template ::ArrayType> -static std::vector NullBitmapFromArrays(const ArrayType& lhs, - const ArrayType& rhs) { - auto left_lambda = [&lhs](int64_t i) { - return lhs.null_count() == 0 ? true : lhs.IsValid(i); +template +decltype(Comparator::Compare)* GetComparator(CompareOperator op) { + using cmp_t = decltype(Comparator::Compare); + static cmp_t* cmp[] = { + Comparator::Compare, Comparator::Compare, + Comparator::Compare, Comparator::Compare, + Comparator::Compare, Comparator::Compare, }; + return cmp[op]; +} - auto right_lambda = [&rhs](int64_t i) { - return rhs.null_count() == 0 ? true : rhs.IsValid(i); - }; - - const int64_t length = lhs.length(); - std::vector null_bitmap(length); - - for (int64_t i = 0; i < length; i++) { - null_bitmap[i] = left_lambda(i) && right_lambda(i); - } +template ::CType> +std::shared_ptr CompareAndFilter(const CType* data, int64_t length, Fn&& fn) { + std::vector filtered; + filtered.reserve(length); + std::copy_if(data, data + length, std::back_inserter(filtered), std::forward(fn)); + std::shared_ptr filtered_array; + ArrayFromVector(filtered, &filtered_array); + return filtered_array; +} - return null_bitmap; +template ::CType> +std::shared_ptr CompareAndFilter(const CType* data, int64_t length, CType val, + CompareOperator op) { + auto cmp = GetComparator(op); + return CompareAndFilter(data, length, [&](CType e) { return cmp(e, val); }); } -template -static Datum SimpleArrayArrayCompare(CompareOptions options, const Datum& lhs, - const Datum& rhs) { - using ArrayType = typename TypeTraits::ArrayType; - using T = typename TypeTraits::CType; - - auto l_array = std::static_pointer_cast(lhs.make_array()); - auto r_array = std::static_pointer_cast(rhs.make_array()); - const int64_t length = l_array->length(); - - std::vector bitmap(length); - for (int64_t i = 0; i < length; i++) { - bitmap[i] = SlowCompare(options.op, l_array->Value(i), r_array->Value(i)); - } +template ::CType> +std::shared_ptr CompareAndFilter(const CType* data, int64_t length, + const CType* other, CompareOperator op) { + auto cmp = GetComparator(op); + return CompareAndFilter(data, length, [&](CType e) { return cmp(e, *other++); }); +} - std::shared_ptr result; +TYPED_TEST(TestFilterKernelWithNumeric, CompareScalarAndFilterRandomNumeric) { + using ScalarType = typename TypeTraits::ScalarType; + using ArrayType = typename TypeTraits::ArrayType; + using CType = typename TypeTraits::CType; - if (l_array->null_count() == 0 && r_array->null_count() == 0) { - ArrayFromVector(bitmap, &result); - } else { - std::vector null_bitmap = NullBitmapFromArrays(*l_array, *r_array); - ArrayFromVector(null_bitmap, bitmap, &result); + auto rand = random::RandomArrayGenerator(0x5416447); + for (size_t i = 3; i < 13; i++) { + const int64_t length = static_cast(1ULL << i); + // TODO(bkietz) rewrite with some nulls + auto array = + checked_pointer_cast(rand.Numeric(length, 0, 100, 0)); + CType c_fifty = 50; + auto fifty = std::make_shared(c_fifty); + for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) { + auto options = CompareOptions(op); + Datum selection, filtered; + ASSERT_OK(arrow::compute::Compare(&this->ctx_, Datum(array), Datum(fifty), options, + &selection)); + ASSERT_OK(arrow::compute::Filter(&this->ctx_, Datum(array), selection, &filtered)); + auto filtered_array = filtered.make_array(); + auto expected = + CompareAndFilter(array->raw_values(), array->length(), c_fifty, op); + ASSERT_ARRAYS_EQUAL(*filtered_array, *expected); + } } - - return Datum(result); } -template -static void ValidateCompare(FunctionContext* ctx, CompareOptions options, - const Datum& lhs, const Datum& rhs) { - Datum result; - - bool has_scalar = lhs.is_scalar() || rhs.is_scalar(); - Datum expected = has_scalar ? SimpleScalarArrayCompare(options, lhs, rhs) - : SimpleArrayArrayCompare(options, lhs, rhs); +TYPED_TEST(TestFilterKernelWithNumeric, CompareArrayAndFilterRandomNumeric) { + using ArrayType = typename TypeTraits::ArrayType; - ValidateCompare(ctx, options, lhs, rhs, expected); + auto rand = random::RandomArrayGenerator(0x5416447); + for (size_t i = 3; i < 13; i++) { + const int64_t length = static_cast(1ULL << i); + auto lhs = + checked_pointer_cast(rand.Numeric(length, 0, 100, 0)); + auto rhs = + checked_pointer_cast(rand.Numeric(length, 0, 100, 0)); + for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) { + auto options = CompareOptions(op); + Datum selection, filtered; + ASSERT_OK(arrow::compute::Compare(&this->ctx_, Datum(lhs), Datum(rhs), options, + &selection)); + ASSERT_OK(arrow::compute::Filter(&this->ctx_, Datum(lhs), selection, &filtered)); + auto filtered_array = filtered.make_array(); + auto expected = CompareAndFilter(lhs->raw_values(), lhs->length(), + rhs->raw_values(), op); + ASSERT_ARRAYS_EQUAL(*filtered_array, *expected); + } + } } -template -class TestNumericCompareKernel : public ComputeFixture, public TestBase {}; - -TYPED_TEST_CASE(TestNumericCompareKernel, NumericArrowTypes); -TYPED_TEST(TestNumericCompareKernel, SimpleCompareArrayScalar) { +TYPED_TEST(TestFilterKernelWithNumeric, ScalarInRangeAndFilterRandomNumeric) { using ScalarType = typename TypeTraits::ScalarType; + using ArrayType = typename TypeTraits::ArrayType; using CType = typename TypeTraits::CType; - Datum one(std::make_shared(CType(1))); - - CompareOptions eq(CompareOperator::EQUAL); - ValidateCompare(&this->ctx_, eq, "[]", one, "[]"); - ValidateCompare(&this->ctx_, eq, "[null]", one, "[null]"); - ValidateCompare(&this->ctx_, eq, "[0,0,1,1,2,2]", one, "[0,0,1,1,0,0]"); - ValidateCompare(&this->ctx_, eq, "[0,1,2,3,4,5]", one, "[0,1,0,0,0,0]"); - ValidateCompare(&this->ctx_, eq, "[5,4,3,2,1,0]", one, "[0,0,0,0,1,0]"); - ValidateCompare(&this->ctx_, eq, "[null,0,1,1]", one, "[null,0,1,1]"); - - CompareOptions neq(CompareOperator::NOT_EQUAL); - ValidateCompare(&this->ctx_, neq, "[]", one, "[]"); - ValidateCompare(&this->ctx_, neq, "[null]", one, "[null]"); - ValidateCompare(&this->ctx_, neq, "[0,0,1,1,2,2]", one, "[1,1,0,0,1,1]"); - ValidateCompare(&this->ctx_, neq, "[0,1,2,3,4,5]", one, "[1,0,1,1,1,1]"); - ValidateCompare(&this->ctx_, neq, "[5,4,3,2,1,0]", one, "[1,1,1,1,0,1]"); - ValidateCompare(&this->ctx_, neq, "[null,0,1,1]", one, "[null,1,0,0]"); - - CompareOptions gt(CompareOperator::GREATER); - ValidateCompare(&this->ctx_, gt, "[]", one, "[]"); - ValidateCompare(&this->ctx_, gt, "[null]", one, "[null]"); - ValidateCompare(&this->ctx_, gt, "[0,0,1,1,2,2]", one, "[0,0,0,0,1,1]"); - ValidateCompare(&this->ctx_, gt, "[0,1,2,3,4,5]", one, "[0,0,1,1,1,1]"); - ValidateCompare(&this->ctx_, gt, "[4,5,6,7,8,9]", one, "[1,1,1,1,1,1]"); - ValidateCompare(&this->ctx_, gt, "[null,0,1,1]", one, "[null,0,0,0]"); - - CompareOptions gte(CompareOperator::GREATER_EQUAL); - ValidateCompare(&this->ctx_, gte, "[]", one, "[]"); - ValidateCompare(&this->ctx_, gte, "[null]", one, "[null]"); - ValidateCompare(&this->ctx_, gte, "[0,0,1,1,2,2]", one, "[0,0,1,1,1,1]"); - ValidateCompare(&this->ctx_, gte, "[0,1,2,3,4,5]", one, "[0,1,1,1,1,1]"); - ValidateCompare(&this->ctx_, gte, "[4,5,6,7,8,9]", one, "[1,1,1,1,1,1]"); - ValidateCompare(&this->ctx_, gte, "[null,0,1,1]", one, "[null,0,1,1]"); - - CompareOptions lt(CompareOperator::LESS); - ValidateCompare(&this->ctx_, lt, "[]", one, "[]"); - ValidateCompare(&this->ctx_, lt, "[null]", one, "[null]"); - ValidateCompare(&this->ctx_, lt, "[0,0,1,1,2,2]", one, "[1,1,0,0,0,0]"); - ValidateCompare(&this->ctx_, lt, "[0,1,2,3,4,5]", one, "[1,0,0,0,0,0]"); - ValidateCompare(&this->ctx_, lt, "[4,5,6,7,8,9]", one, "[0,0,0,0,0,0]"); - ValidateCompare(&this->ctx_, lt, "[null,0,1,1]", one, "[null,1,0,0]"); - - CompareOptions lte(CompareOperator::LESS_EQUAL); - ValidateCompare(&this->ctx_, lte, "[]", one, "[]"); - ValidateCompare(&this->ctx_, lte, "[null]", one, "[null]"); - ValidateCompare(&this->ctx_, lte, "[0,0,1,1,2,2]", one, "[1,1,1,1,0,0]"); - ValidateCompare(&this->ctx_, lte, "[0,1,2,3,4,5]", one, "[1,1,0,0,0,0]"); - ValidateCompare(&this->ctx_, lte, "[4,5,6,7,8,9]", one, "[0,0,0,0,0,0]"); - ValidateCompare(&this->ctx_, lte, "[null,0,1,1]", one, "[null,1,1,1]"); + auto rand = random::RandomArrayGenerator(0x5416447); + for (size_t i = 3; i < 13; i++) { + const int64_t length = static_cast(1ULL << i); + auto array = + checked_pointer_cast(rand.Numeric(length, 0, 100, 0)); + CType c_fifty = 50, c_hundred = 100; + auto fifty = std::make_shared(c_fifty); + auto hundred = std::make_shared(c_hundred); + Datum greater_than_fifty, less_than_hundred, selection, filtered; + ASSERT_OK(arrow::compute::Compare(&this->ctx_, Datum(array), Datum(fifty), + CompareOptions(GREATER), &greater_than_fifty)); + ASSERT_OK(arrow::compute::Compare(&this->ctx_, Datum(array), Datum(hundred), + CompareOptions(LESS), &less_than_hundred)); + ASSERT_OK(arrow::compute::And(&this->ctx_, greater_than_fifty, less_than_hundred, + &selection)); + ASSERT_OK(arrow::compute::Filter(&this->ctx_, Datum(array), selection, &filtered)); + auto filtered_array = filtered.make_array(); + auto expected = CompareAndFilter( + array->raw_values(), array->length(), + [&](CType e) { return (e > c_fifty) && (e < c_hundred); }); + ASSERT_ARRAYS_EQUAL(*filtered_array, *expected); + } } -TYPED_TEST(TestNumericCompareKernel, SimpleCompareScalarArray) { - using ScalarType = typename TypeTraits::ScalarType; - using CType = typename TypeTraits::CType; +class TestFilterKernelWithString : public TestFilterKernel { + protected: + void AssertFilter(const std::string& values, const std::string& filter, + const std::string& expected) { + TestFilterKernel::AssertFilter(utf8(), values, filter, expected); + } + void AssertFilterDictionary(const std::string& dictionary_values, + const std::string& dictionary_filter, + const std::string& filter, + const std::string& expected_filter) { + auto dict = ArrayFromJSON(utf8(), dictionary_values); + auto type = dictionary(int8(), utf8()); + std::shared_ptr values, actual, expected; + ASSERT_OK(DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), dictionary_filter), + dict, &values)); + ASSERT_OK(DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), expected_filter), + dict, &expected)); + auto take_filter = ArrayFromJSON(boolean(), filter); + this->AssertFilterArrays(values, take_filter, expected); + } +}; - Datum one(std::make_shared(CType(1))); - - CompareOptions eq(CompareOperator::EQUAL); - ValidateCompare(&this->ctx_, eq, one, "[]", "[]"); - ValidateCompare(&this->ctx_, eq, one, "[null]", "[null]"); - ValidateCompare(&this->ctx_, eq, one, "[0,0,1,1,2,2]", "[0,0,1,1,0,0]"); - ValidateCompare(&this->ctx_, eq, one, "[0,1,2,3,4,5]", "[0,1,0,0,0,0]"); - ValidateCompare(&this->ctx_, eq, one, "[5,4,3,2,1,0]", "[0,0,0,0,1,0]"); - ValidateCompare(&this->ctx_, eq, one, "[null,0,1,1]", "[null,0,1,1]"); - - CompareOptions neq(CompareOperator::NOT_EQUAL); - ValidateCompare(&this->ctx_, neq, one, "[]", "[]"); - ValidateCompare(&this->ctx_, neq, one, "[null]", "[null]"); - ValidateCompare(&this->ctx_, neq, one, "[0,0,1,1,2,2]", "[1,1,0,0,1,1]"); - ValidateCompare(&this->ctx_, neq, one, "[0,1,2,3,4,5]", "[1,0,1,1,1,1]"); - ValidateCompare(&this->ctx_, neq, one, "[5,4,3,2,1,0]", "[1,1,1,1,0,1]"); - ValidateCompare(&this->ctx_, neq, one, "[null,0,1,1]", "[null,1,0,0]"); - - CompareOptions gt(CompareOperator::GREATER); - ValidateCompare(&this->ctx_, gt, one, "[]", "[]"); - ValidateCompare(&this->ctx_, gt, one, "[null]", "[null]"); - ValidateCompare(&this->ctx_, gt, one, "[0,0,1,1,2,2]", "[1,1,0,0,0,0]"); - ValidateCompare(&this->ctx_, gt, one, "[0,1,2,3,4,5]", "[1,0,0,0,0,0]"); - ValidateCompare(&this->ctx_, gt, one, "[4,5,6,7,8,9]", "[0,0,0,0,0,0]"); - ValidateCompare(&this->ctx_, gt, one, "[null,0,1,1]", "[null,1,0,0]"); - - CompareOptions gte(CompareOperator::GREATER_EQUAL); - ValidateCompare(&this->ctx_, gte, one, "[]", "[]"); - ValidateCompare(&this->ctx_, gte, one, "[null]", "[null]"); - ValidateCompare(&this->ctx_, gte, one, "[0,0,1,1,2,2]", "[1,1,1,1,0,0]"); - ValidateCompare(&this->ctx_, gte, one, "[0,1,2,3,4,5]", "[1,1,0,0,0,0]"); - ValidateCompare(&this->ctx_, gte, one, "[4,5,6,7,8,9]", "[0,0,0,0,0,0]"); - ValidateCompare(&this->ctx_, gte, one, "[null,0,1,1]", "[null,1,1,1]"); - - CompareOptions lt(CompareOperator::LESS); - ValidateCompare(&this->ctx_, lt, one, "[]", "[]"); - ValidateCompare(&this->ctx_, lt, one, "[null]", "[null]"); - ValidateCompare(&this->ctx_, lt, one, "[0,0,1,1,2,2]", "[0,0,0,0,1,1]"); - ValidateCompare(&this->ctx_, lt, one, "[0,1,2,3,4,5]", "[0,0,1,1,1,1]"); - ValidateCompare(&this->ctx_, lt, one, "[4,5,6,7,8,9]", "[1,1,1,1,1,1]"); - ValidateCompare(&this->ctx_, lt, one, "[null,0,1,1]", "[null,0,0,0]"); - - CompareOptions lte(CompareOperator::LESS_EQUAL); - ValidateCompare(&this->ctx_, lte, one, "[]", "[]"); - ValidateCompare(&this->ctx_, lte, one, "[null]", "[null]"); - ValidateCompare(&this->ctx_, lte, one, "[0,0,1,1,2,2]", "[0,0,1,1,1,1]"); - ValidateCompare(&this->ctx_, lte, one, "[0,1,2,3,4,5]", "[0,1,1,1,1,1]"); - ValidateCompare(&this->ctx_, lte, one, "[4,5,6,7,8,9]", "[1,1,1,1,1,1]"); - ValidateCompare(&this->ctx_, lte, one, "[null,0,1,1]", "[null,0,1,1]"); +TEST_F(TestFilterKernelWithString, FilterString) { + this->AssertFilter(R"(["a", "b", "c"])", "[0, 1, 0]", R"(["b"])"); + this->AssertFilter(R"([null, "b", "c"])", "[0, 1, 0]", R"(["b"])"); + this->AssertFilter(R"(["a", "b", "c"])", "[null, 1, 0]", R"([null, "b"])"); } -TYPED_TEST(TestNumericCompareKernel, TestNullScalar) { - /* Ensure that null scalar broadcast to all null results. */ - using ScalarType = typename TypeTraits::ScalarType; - using CType = typename TypeTraits::CType; +TEST_F(TestFilterKernelWithString, FilterDictionary) { + auto dict = R"(["a", "b", "c", "d", "e"])"; + this->AssertFilterDictionary(dict, "[3, 4, 2]", "[0, 1, 0]", "[4]"); + this->AssertFilterDictionary(dict, "[null, 4, 2]", "[0, 1, 0]", "[4]"); + this->AssertFilterDictionary(dict, "[3, 4, 2]", "[null, 1, 0]", "[null, 4]"); +} - Datum null(std::make_shared(CType(0), false)); - EXPECT_FALSE(null.scalar()->is_valid); +class TestFilterKernelWithList : public TestFilterKernel {}; - CompareOptions eq(CompareOperator::EQUAL); - ValidateCompare(&this->ctx_, eq, "[]", null, "[]"); - ValidateCompare(&this->ctx_, eq, null, "[]", "[]"); - ValidateCompare(&this->ctx_, eq, "[null]", null, "[null]"); - ValidateCompare(&this->ctx_, eq, null, "[null]", "[null]"); - ValidateCompare(&this->ctx_, eq, null, "[1,2,3]", "[null, null, null]"); +TEST_F(TestFilterKernelWithList, FilterListInt32) { + std::string list_json = "[[], [1,2], null, [3]]"; + this->AssertFilter(list(int32()), list_json, "[0, 0, 0, 0]", "[]"); + this->AssertFilter(list(int32()), list_json, "[0, 1, 1, null]", "[[1,2], null, null]"); + this->AssertFilter(list(int32()), list_json, "[0, 0, 1, null]", "[null, null]"); + this->AssertFilter(list(int32()), list_json, "[1, 0, 0, 1]", "[[], [3]]"); + this->AssertFilter(list(int32()), list_json, "[1, 1, 1, 1]", list_json); + this->AssertFilter(list(int32()), list_json, "[0, 1, 0, 1]", "[[1,2], [3]]"); } -TYPED_TEST_CASE(TestNumericCompareKernel, NumericArrowTypes); -TYPED_TEST(TestNumericCompareKernel, RandomCompareArrayScalar) { - using ScalarType = typename TypeTraits::ScalarType; - using CType = typename TypeTraits::CType; - - auto rand = random::RandomArrayGenerator(0x5416447); - for (size_t i = 3; i < 13; i++) { - for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) { - for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) { - const int64_t length = static_cast(1ULL << i); - auto array = Datum(rand.Numeric(length, 0, 100, null_probability)); - auto fifty = Datum(std::make_shared(CType(50))); - auto options = CompareOptions(op); - ValidateCompare(&this->ctx_, options, array, fifty); - ValidateCompare(&this->ctx_, options, fifty, array); - } - } - } +class TestFilterKernelWithFixedSizeList : public TestFilterKernel {}; + +TEST_F(TestFilterKernelWithFixedSizeList, FilterFixedSizeListInt32) { + std::string list_json = "[null, [1, null, 3], [4, 5, 6], [7, 8, null]]"; + this->AssertFilter(fixed_size_list(int32(), 3), list_json, "[0, 0, 0, 0]", "[]"); + this->AssertFilter(fixed_size_list(int32(), 3), list_json, "[0, 1, 1, null]", + "[[1, null, 3], [4, 5, 6], null]"); + this->AssertFilter(fixed_size_list(int32(), 3), list_json, "[0, 0, 1, null]", + "[[4, 5, 6], null]"); + this->AssertFilter(fixed_size_list(int32(), 3), list_json, "[1, 1, 1, 1]", list_json); + this->AssertFilter(fixed_size_list(int32(), 3), list_json, "[0, 1, 0, 1]", + "[[1, null, 3], [7, 8, null]]"); } -TYPED_TEST(TestNumericCompareKernel, SimpleCompareArrayArray) { - /* Ensure that null scalar broadcast to all null results. */ - CompareOptions eq(CompareOperator::EQUAL); - ValidateCompare(&this->ctx_, eq, "[]", "[]", "[]"); - ValidateCompare(&this->ctx_, eq, "[null]", "[null]", "[null]"); - ValidateCompare(&this->ctx_, eq, "[1]", "[1]", "[1]"); - ValidateCompare(&this->ctx_, eq, "[1]", "[2]", "[0]"); - ValidateCompare(&this->ctx_, eq, "[null]", "[1]", "[null]"); - ValidateCompare(&this->ctx_, eq, "[1]", "[null]", "[null]"); - - CompareOptions lte(CompareOperator::LESS_EQUAL); - ValidateCompare(&this->ctx_, lte, "[1,2,3,4,5]", "[2,3,4,5,6]", - "[1,1,1,1,1]"); +class TestFilterKernelWithMap : public TestFilterKernel {}; + +TEST_F(TestFilterKernelWithMap, FilterMapStringToInt32) { + std::string map_json = R"([ + [["joe", 0], ["mark", null]], + null, + [["cap", 8]], + [] + ])"; + this->AssertFilter(map(utf8(), int32()), map_json, "[0, 0, 0, 0]", "[]"); + this->AssertFilter(map(utf8(), int32()), map_json, "[0, 1, 1, null]", R"([ + null, + [["cap", 8]], + null + ])"); + this->AssertFilter(map(utf8(), int32()), map_json, "[1, 1, 1, 1]", map_json); + this->AssertFilter(map(utf8(), int32()), map_json, "[0, 1, 0, 1]", "[null, []]"); } -TYPED_TEST(TestNumericCompareKernel, RandomCompareArrayArray) { - auto rand = random::RandomArrayGenerator(0x5416447); - for (size_t i = 3; i < 5; i++) { - for (auto null_probability : {0.0, 0.01, 0.1, 0.25, 0.5, 1.0}) { - for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) { - const int64_t length = static_cast(1ULL << i); - auto lhs = Datum(rand.Numeric(length << i, 0, 100, null_probability)); - auto rhs = Datum(rand.Numeric(length << i, 0, 100, null_probability)); - auto options = CompareOptions(op); - ValidateCompare(&this->ctx_, options, lhs, rhs); - } - } - } +class TestFilterKernelWithStruct : public TestFilterKernel {}; + +TEST_F(TestFilterKernelWithStruct, FilterStruct) { + auto struct_type = struct_({field("a", int32()), field("b", utf8())}); + auto struct_json = R"([ + null, + {"a": 1, "b": ""}, + {"a": 2, "b": "hello"}, + {"a": 4, "b": "eh"} + ])"; + this->AssertFilter(struct_type, struct_json, "[0, 0, 0, 0]", "[]"); + this->AssertFilter(struct_type, struct_json, "[0, 1, 1, null]", R"([ + {"a": 1, "b": ""}, + {"a": 2, "b": "hello"}, + null + ])"); + this->AssertFilter(struct_type, struct_json, "[1, 1, 1, 1]", struct_json); + this->AssertFilter(struct_type, struct_json, "[1, 0, 1, 0]", R"([ + null, + {"a": 2, "b": "hello"} + ])"); } } // namespace compute diff --git a/cpp/src/arrow/compute/kernels/filter.cc b/cpp/src/arrow/compute/kernels/filter.cc index 1cbf0dc0992d1..654ec610352ce 100644 --- a/cpp/src/arrow/compute/kernels/filter.cc +++ b/cpp/src/arrow/compute/kernels/filter.cc @@ -15,44 +15,447 @@ // specific language governing permissions and limitations // under the License. -#include "arrow/compute/kernels/filter.h" +#include +#include +#include +#include -#include "arrow/array.h" -#include "arrow/compute/kernel.h" +#include "arrow/builder.h" +#include "arrow/compute/context.h" +#include "arrow/compute/kernels/filter.h" +#include "arrow/util/bit-util.h" +#include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" +#include "arrow/util/stl.h" +#include "arrow/visitor_inline.h" namespace arrow { - namespace compute { -std::shared_ptr FilterBinaryKernel::out_type() const { - return filter_function_->out_type(); +using internal::checked_cast; +using internal::checked_pointer_cast; + +template +Status MakeBuilder(MemoryPool* pool, const std::shared_ptr& type, + std::unique_ptr* out) { + std::unique_ptr builder; + RETURN_NOT_OK(MakeBuilder(pool, type, &builder)); + out->reset(checked_cast(builder.release())); + return Status::OK(); +} + +template +static Status UnsafeAppend(Builder* builder, Scalar&& value) { + builder->UnsafeAppend(std::forward(value)); + return Status::OK(); +} + +static Status UnsafeAppend(BinaryBuilder* builder, util::string_view value) { + RETURN_NOT_OK(builder->ReserveData(static_cast(value.size()))); + builder->UnsafeAppend(value); + return Status::OK(); +} + +static Status UnsafeAppend(StringBuilder* builder, util::string_view value) { + RETURN_NOT_OK(builder->ReserveData(static_cast(value.size()))); + builder->UnsafeAppend(value); + return Status::OK(); } -Status FilterBinaryKernel::Call(FunctionContext* ctx, const Datum& left, - const Datum& right, Datum* out) { - DCHECK(left.type()->Equals(right.type())); +// TODO(bkietz) this can be optimized +static int64_t OutputSize(const BooleanArray& filter) { + auto offset = filter.offset(); + auto length = filter.length(); + int64_t size = 0; + for (auto i = offset; i < offset + length; ++i) { + if (filter.IsNull(i) || filter.Value(i)) { + ++size; + } + } + return size; +} + +template +class FilterImpl; + +template <> +class FilterImpl : public FilterKernel { + public: + using FilterKernel::FilterKernel; + + Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter, + int64_t length, std::shared_ptr* out) override { + out->reset(new NullArray(length)); + return Status::OK(); + } +}; + +template +class FilterImpl : public FilterKernel { + public: + using ValueArray = typename TypeTraits::ArrayType; + using OutBuilder = typename TypeTraits::BuilderType; + + using FilterKernel::FilterKernel; + + Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter, + int64_t length, std::shared_ptr* out) override { + std::unique_ptr builder; + RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), type_, &builder)); + RETURN_NOT_OK(builder->Resize(OutputSize(filter))); + RETURN_NOT_OK(UnpackValuesNullCount(checked_cast(values), filter, + builder.get())); + return builder->Finish(out); + } + + private: + Status UnpackValuesNullCount(const ValueArray& values, const BooleanArray& filter, + OutBuilder* builder) { + if (values.null_count() == 0) { + return UnpackIndicesNullCount(values, filter, builder); + } + return UnpackIndicesNullCount(values, filter, builder); + } + + template + Status UnpackIndicesNullCount(const ValueArray& values, const BooleanArray& filter, + OutBuilder* builder) { + if (filter.null_count() == 0) { + return Filter(values, filter, builder); + } + return Filter(values, filter, builder); + } + + template + Status Filter(const ValueArray& values, const BooleanArray& filter, + OutBuilder* builder) { + for (int64_t i = 0; i < filter.length(); ++i) { + if (!AllIndicesValid && filter.IsNull(i)) { + builder->UnsafeAppendNull(); + continue; + } + if (!filter.Value(i)) { + continue; + } + if (!AllValuesValid && values.IsNull(i)) { + builder->UnsafeAppendNull(); + continue; + } + RETURN_NOT_OK(UnsafeAppend(builder, values.GetView(i))); + } + return Status::OK(); + } +}; + +template <> +class FilterImpl : public FilterKernel { + public: + FilterImpl(const std::shared_ptr& type, + std::vector> child_kernels) + : FilterKernel(type), child_kernels_(std::move(child_kernels)) {} + + Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter, + int64_t length, std::shared_ptr* out) override { + const auto& struct_array = checked_cast(values); + + TypedBufferBuilder null_bitmap_builder(ctx->memory_pool()); + RETURN_NOT_OK(null_bitmap_builder.Resize(length)); + + ArrayVector fields(type_->num_children()); + for (int i = 0; i < type_->num_children(); ++i) { + RETURN_NOT_OK(child_kernels_[i]->Filter(ctx, *struct_array.field(i), filter, length, + &fields[i])); + } + + for (int64_t i = 0; i < filter.length(); ++i) { + if (filter.IsNull(i)) { + null_bitmap_builder.UnsafeAppend(false); + continue; + } + if (!filter.Value(i)) { + continue; + } + if (struct_array.IsNull(i)) { + null_bitmap_builder.UnsafeAppend(false); + continue; + } + null_bitmap_builder.UnsafeAppend(true); + } + + auto null_count = null_bitmap_builder.false_count(); + std::shared_ptr null_bitmap; + RETURN_NOT_OK(null_bitmap_builder.Finish(&null_bitmap)); + + out->reset(new StructArray(type_, length, fields, null_bitmap, null_count)); + return Status::OK(); + } + + private: + std::vector> child_kernels_; +}; + +template <> +class FilterImpl : public FilterKernel { + public: + using FilterKernel::FilterKernel; + + Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter, + int64_t length, std::shared_ptr* out) override { + const auto& list_array = checked_cast(values); + + TypedBufferBuilder null_bitmap_builder(ctx->memory_pool()); + RETURN_NOT_OK(null_bitmap_builder.Resize(length)); - auto lk = left.kind(); - auto rk = right.kind(); - auto out_array = out->array(); + BooleanBuilder value_filter_builder(ctx->memory_pool()); + auto list_size = list_array.list_type()->list_size(); + RETURN_NOT_OK(value_filter_builder.Resize(list_size * length)); - if (lk == Datum::ARRAY && rk == Datum::SCALAR) { - auto array = left.array(); - auto scalar = right.scalar(); - return filter_function_->Filter(*array, *scalar, &out_array); - } else if (lk == Datum::SCALAR && rk == Datum::ARRAY) { - auto scalar = left.scalar(); - auto array = right.array(); - auto out_array = out->array(); - return filter_function_->Filter(*scalar, *array, &out_array); - } else if (lk == Datum::ARRAY && rk == Datum::ARRAY) { - auto lhs = left.array(); - auto rhs = right.array(); - return filter_function_->Filter(*lhs, *rhs, &out_array); + for (int64_t i = 0; i < filter.length(); ++i) { + if (filter.IsNull(i)) { + null_bitmap_builder.UnsafeAppend(false); + for (int64_t j = 0; j < list_size; ++j) { + value_filter_builder.UnsafeAppendNull(); + } + continue; + } + if (!filter.Value(i)) { + for (int64_t j = 0; j < list_size; ++j) { + value_filter_builder.UnsafeAppend(false); + } + continue; + } + if (values.IsNull(i)) { + null_bitmap_builder.UnsafeAppend(false); + for (int64_t j = 0; j < list_size; ++j) { + value_filter_builder.UnsafeAppendNull(); + } + continue; + } + for (int64_t j = 0; j < list_size; ++j) { + value_filter_builder.UnsafeAppend(true); + } + null_bitmap_builder.UnsafeAppend(true); + } + + std::shared_ptr value_filter; + RETURN_NOT_OK(value_filter_builder.Finish(&value_filter)); + std::shared_ptr out_values; + RETURN_NOT_OK( + arrow::compute::Filter(ctx, *list_array.values(), *value_filter, &out_values)); + + auto null_count = null_bitmap_builder.false_count(); + std::shared_ptr null_bitmap; + RETURN_NOT_OK(null_bitmap_builder.Finish(&null_bitmap)); + + out->reset( + new FixedSizeListArray(type_, length, out_values, null_bitmap, null_count)); + return Status::OK(); + } +}; + +template <> +class FilterImpl : public FilterKernel { + public: + using FilterKernel::FilterKernel; + + Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter, + int64_t length, std::shared_ptr* out) override { + const auto& list_array = checked_cast(values); + + TypedBufferBuilder null_bitmap_builder(ctx->memory_pool()); + RETURN_NOT_OK(null_bitmap_builder.Resize(length)); + + BooleanBuilder value_filter_builder(ctx->memory_pool()); + + TypedBufferBuilder offset_builder(ctx->memory_pool()); + RETURN_NOT_OK(offset_builder.Resize(length + 1)); + int32_t offset = 0; + offset_builder.UnsafeAppend(offset); + + for (int64_t i = 0; i < filter.length(); ++i) { + if (filter.IsNull(i)) { + null_bitmap_builder.UnsafeAppend(false); + offset_builder.UnsafeAppend(offset); + RETURN_NOT_OK( + value_filter_builder.AppendValues(list_array.value_length(i), false)); + continue; + } + if (!filter.Value(i)) { + RETURN_NOT_OK( + value_filter_builder.AppendValues(list_array.value_length(i), false)); + continue; + } + if (values.IsNull(i)) { + null_bitmap_builder.UnsafeAppend(false); + offset_builder.UnsafeAppend(offset); + RETURN_NOT_OK( + value_filter_builder.AppendValues(list_array.value_length(i), false)); + continue; + } + null_bitmap_builder.UnsafeAppend(true); + offset += list_array.value_length(i); + offset_builder.UnsafeAppend(offset); + RETURN_NOT_OK(value_filter_builder.AppendValues(list_array.value_length(i), true)); + } + + std::shared_ptr value_filter; + RETURN_NOT_OK(value_filter_builder.Finish(&value_filter)); + std::shared_ptr out_values; + RETURN_NOT_OK( + arrow::compute::Filter(ctx, *list_array.values(), *value_filter, &out_values)); + + auto null_count = null_bitmap_builder.false_count(); + std::shared_ptr offsets, null_bitmap; + RETURN_NOT_OK(offset_builder.Finish(&offsets)); + RETURN_NOT_OK(null_bitmap_builder.Finish(&null_bitmap)); + + *out = MakeArray(ArrayData::Make(type_, length, {null_bitmap, offsets}, + {out_values->data()}, null_count)); + return Status::OK(); + } +}; + +template <> +class FilterImpl : public FilterImpl { + using FilterImpl::FilterImpl; +}; + +template <> +class FilterImpl : public FilterKernel { + public: + FilterImpl(const std::shared_ptr& type, std::unique_ptr impl) + : FilterKernel(type), impl_(std::move(impl)) {} + + Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter, + int64_t length, std::shared_ptr* out) override { + auto dict_array = checked_cast(&values); + // To filter a dictionary, apply the current kernel to the dictionary's indices. + std::shared_ptr taken_indices; + RETURN_NOT_OK( + impl_->Filter(ctx, *dict_array->indices(), filter, length, &taken_indices)); + return DictionaryArray::FromArrays(values.type(), taken_indices, + dict_array->dictionary(), out); + } + + private: + std::unique_ptr impl_; +}; + +template <> +class FilterImpl : public FilterKernel { + public: + FilterImpl(const std::shared_ptr& type, std::unique_ptr impl) + : FilterKernel(type), impl_(std::move(impl)) {} + + Status Filter(FunctionContext* ctx, const Array& values, const BooleanArray& filter, + int64_t length, std::shared_ptr* out) override { + auto ext_array = checked_cast(&values); + // To take from an extension array, apply the current kernel to storage. + std::shared_ptr taken_storage; + RETURN_NOT_OK( + impl_->Filter(ctx, *ext_array->storage(), filter, length, &taken_storage)); + *out = ext_array->extension_type()->MakeArray(taken_storage->data()); + return Status::OK(); + } + + private: + std::unique_ptr impl_; +}; + +Status FilterKernel::Make(const std::shared_ptr& value_type, + std::unique_ptr* out) { + switch (value_type->id()) { +#define NO_CHILD_CASE(T) \ + case T##Type::type_id: \ + *out = internal::make_unique>(value_type); \ + return Status::OK() + +#define SINGLE_CHILD_CASE(T, CHILD_TYPE) \ + case T##Type::type_id: { \ + auto t = checked_pointer_cast(value_type); \ + std::unique_ptr child_filter_impl; \ + RETURN_NOT_OK(FilterKernel::Make(t->CHILD_TYPE(), &child_filter_impl)); \ + *out = internal::make_unique>(t, std::move(child_filter_impl)); \ + return Status::OK(); \ } - return Status::Invalid("Invalid datum signature for FilterBinaryKernel"); + NO_CHILD_CASE(Null); + NO_CHILD_CASE(Boolean); + NO_CHILD_CASE(Int8); + NO_CHILD_CASE(Int16); + NO_CHILD_CASE(Int32); + NO_CHILD_CASE(Int64); + NO_CHILD_CASE(UInt8); + NO_CHILD_CASE(UInt16); + NO_CHILD_CASE(UInt32); + NO_CHILD_CASE(UInt64); + NO_CHILD_CASE(Date32); + NO_CHILD_CASE(Date64); + NO_CHILD_CASE(Time32); + NO_CHILD_CASE(Time64); + NO_CHILD_CASE(Timestamp); + NO_CHILD_CASE(Duration); + NO_CHILD_CASE(HalfFloat); + NO_CHILD_CASE(Float); + NO_CHILD_CASE(Double); + NO_CHILD_CASE(String); + NO_CHILD_CASE(Binary); + NO_CHILD_CASE(FixedSizeBinary); + NO_CHILD_CASE(Decimal128); + + SINGLE_CHILD_CASE(Dictionary, index_type); + SINGLE_CHILD_CASE(Extension, storage_type); + + NO_CHILD_CASE(List); + NO_CHILD_CASE(FixedSizeList); + NO_CHILD_CASE(Map); + + case Type::STRUCT: { + std::vector> child_kernels; + for (auto child : value_type->children()) { + child_kernels.emplace_back(); + RETURN_NOT_OK(FilterKernel::Make(child->type(), &child_kernels.back())); + } + *out = internal::make_unique>(value_type, + std::move(child_kernels)); + return Status::OK(); + } + +#undef NO_CHILD_CASE +#undef SINGLE_CHILD_CASE + + default: + return Status::NotImplemented("gathering values of type ", *value_type); + } +} + +Status FilterKernel::Call(FunctionContext* ctx, const Datum& values, const Datum& filter, + Datum* out) { + if (!values.is_array() || !filter.is_array()) { + return Status::Invalid("FilterKernel expects array values and filter"); + } + auto values_array = values.make_array(); + auto filter_array = checked_pointer_cast(filter.make_array()); + const auto length = OutputSize(*filter_array); + std::shared_ptr out_array; + RETURN_NOT_OK(this->Filter(ctx, *values_array, *filter_array, length, &out_array)); + *out = out_array; + return Status::OK(); +} + +Status Filter(FunctionContext* context, const Array& values, const Array& filter, + std::shared_ptr* out) { + Datum out_datum; + RETURN_NOT_OK(Filter(context, Datum(values.data()), Datum(filter.data()), &out_datum)); + *out = out_datum.make_array(); + return Status::OK(); +} + +Status Filter(FunctionContext* context, const Datum& values, const Datum& filter, + Datum* out) { + std::unique_ptr kernel; + RETURN_NOT_OK(FilterKernel::Make(values.type(), &kernel)); + return kernel->Call(context, values, filter, out); } } // namespace compute diff --git a/cpp/src/arrow/compute/kernels/filter.h b/cpp/src/arrow/compute/kernels/filter.h index 3b28bc9391a6a..46ad3d42b87f1 100644 --- a/cpp/src/arrow/compute/kernels/filter.h +++ b/cpp/src/arrow/compute/kernels/filter.h @@ -20,76 +20,74 @@ #include #include "arrow/compute/kernel.h" +#include "arrow/status.h" +#include "arrow/util/visibility.h" namespace arrow { class Array; -struct Scalar; -class Status; namespace compute { class FunctionContext; -struct Datum; -/// FilterFunction is an interface for Filters +/// \brief Filter an array with a boolean selection filter /// -/// Filters takes an array and emits a selection vector. The selection vector -/// is given in the form of a bitmask as a BooleanArray result. -class ARROW_EXPORT FilterFunction { - public: - /// Filter an array with a scalar argument. - virtual Status Filter(const ArrayData& array, const Scalar& scalar, - ArrayData* output) const = 0; - - Status Filter(const ArrayData& array, const Scalar& scalar, - std::shared_ptr* output) { - return Filter(array, scalar, output->get()); - } - - virtual Status Filter(const Scalar& scalar, const ArrayData& array, - ArrayData* output) const = 0; - - Status Filter(const Scalar& scalar, const ArrayData& array, - std::shared_ptr* output) { - return Filter(scalar, array, output->get()); - } - - /// Filter an array with an array argument. - virtual Status Filter(const ArrayData& lhs, const ArrayData& rhs, - ArrayData* output) const = 0; - - Status Filter(const ArrayData& lhs, const ArrayData& rhs, - std::shared_ptr* output) { - return Filter(lhs, rhs, output->get()); - } - - /// By default, FilterFunction emits a result bitmap. - virtual std::shared_ptr out_type() const { return boolean(); } - - virtual ~FilterFunction() {} -}; - -/// \brief BinaryKernel bound to a filter function -class ARROW_EXPORT FilterBinaryKernel : public BinaryKernel { +/// The output array will be populated with values from the input at positions +/// where the selection filter is not 0. Nulls in the filter will result in nulls +/// in the output. +/// +/// For example given values = ["a", "b", "c", null, "e", "f"] and +/// filter = [0, 1, 1, 0, null, 1], the output will be +/// = ["b", "c", null, "f"] +/// +/// \param[in] context the FunctionContext +/// \param[in] values array to filter +/// \param[in] filter indicates which values should be filtered out +/// \param[out] out resulting array +ARROW_EXPORT +Status Filter(FunctionContext* context, const Array& values, const Array& filter, + std::shared_ptr* out); + +/// \brief Filter an array with a boolean selection filter +/// +/// \param[in] context the FunctionContext +/// \param[in] values datum to filter +/// \param[in] filter indicates which values should be filtered out +/// \param[out] out resulting datum +ARROW_EXPORT +Status Filter(FunctionContext* context, const Datum& values, const Datum& filter, + Datum* out); + +/// \brief BinaryKernel implementing Filter operation +class ARROW_EXPORT FilterKernel : public BinaryKernel { public: - explicit FilterBinaryKernel(std::shared_ptr& filter) - : filter_function_(filter) {} + explicit FilterKernel(const std::shared_ptr& type) : type_(type) {} - Status Call(FunctionContext* ctx, const Datum& left, const Datum& right, + /// \brief BinaryKernel interface + /// + /// delegates to subclasses via Filter() + Status Call(FunctionContext* ctx, const Datum& values, const Datum& filter, Datum* out) override; - static int64_t out_length(const Datum& left, const Datum& right) { - if (left.kind() == Datum::ARRAY) return left.length(); - if (right.kind() == Datum::ARRAY) return right.length(); + /// \brief output type of this kernel (identical to type of values filtered) + std::shared_ptr out_type() const override { return type_; } - return 0; - } + /// \brief factory for FilterKernels + /// + /// \param[in] value_type constructed FilterKernel will support filtering + /// values of this type + /// \param[out] out created kernel + static Status Make(const std::shared_ptr& value_type, + std::unique_ptr* out); - std::shared_ptr out_type() const override; + /// \brief single-array implementation + virtual Status Filter(FunctionContext* ctx, const Array& values, + const BooleanArray& filter, int64_t length, + std::shared_ptr* out) = 0; - private: - std::shared_ptr filter_function_; + protected: + std::shared_ptr type_; }; } // namespace compute diff --git a/cpp/src/arrow/compute/kernels/take-test.cc b/cpp/src/arrow/compute/kernels/take-test.cc index b3de04d8cc762..c61aedab5b958 100644 --- a/cpp/src/arrow/compute/kernels/take-test.cc +++ b/cpp/src/arrow/compute/kernels/take-test.cc @@ -122,6 +122,7 @@ TYPED_TEST(TestTakeKernelWithNumeric, TakeNumeric) { this->AssertTake("[7, 8, 9]", "[0, 1, 0]", options, "[7, 8, 7]"); this->AssertTake("[null, 8, 9]", "[0, 1, 0]", options, "[null, 8, null]"); this->AssertTake("[7, 8, 9]", "[null, 1, 0]", options, "[null, 8, 7]"); + this->AssertTake("[null, 8, 9]", "[]", options, "[]"); std::shared_ptr arr; ASSERT_RAISES(IndexError, this->Take(this->type_singleton(), "[7, 8, 9]", int8(), diff --git a/cpp/src/arrow/compute/kernels/take.cc b/cpp/src/arrow/compute/kernels/take.cc index 9af2c0cab1152..17b054099ea32 100644 --- a/cpp/src/arrow/compute/kernels/take.cc +++ b/cpp/src/arrow/compute/kernels/take.cc @@ -28,6 +28,8 @@ namespace arrow { namespace compute { +using internal::checked_cast; + Status Take(FunctionContext* context, const Array& values, const Array& indices, const TakeOptions& options, std::shared_ptr* out) { Datum out_datum; @@ -119,20 +121,20 @@ struct UnpackValues { Status Visit(const ValueType&) { using ValueArrayRef = const typename TypeTraits::ArrayType&; using OutBuilder = typename TypeTraits::BuilderType; - IndexArrayRef indices = static_cast(*params_.indices); - ValueArrayRef values = static_cast(*params_.values); + IndexArrayRef indices = checked_cast(*params_.indices); + ValueArrayRef values = checked_cast(*params_.values); std::unique_ptr builder; RETURN_NOT_OK(MakeBuilder(params_.context->memory_pool(), values.type(), &builder)); RETURN_NOT_OK(builder->Reserve(indices.length())); RETURN_NOT_OK(UnpackValuesNullCount(params_.context, values, indices, - static_cast(builder.get()))); + checked_cast(builder.get()))); return builder->Finish(params_.out); } Status Visit(const NullType& t) { auto indices_length = params_.indices->length(); if (indices_length != 0) { - auto indices = static_cast(*params_.indices).raw_values(); + auto indices = checked_cast(*params_.indices).raw_values(); auto minmax = std::minmax_element(indices, indices + indices_length); auto min = static_cast(*minmax.first); auto max = static_cast(*minmax.second); diff --git a/cpp/src/arrow/extension_type.h b/cpp/src/arrow/extension_type.h index 6a1ca0b71553d..8bf4639bd1272 100644 --- a/cpp/src/arrow/extension_type.h +++ b/cpp/src/arrow/extension_type.h @@ -93,6 +93,10 @@ class ARROW_EXPORT ExtensionArray : public Array { ExtensionArray(const std::shared_ptr& type, const std::shared_ptr& storage); + const ExtensionType* extension_type() const { + return internal::checked_cast(data_->type.get()); + } + /// \brief The physical storage for the extension array std::shared_ptr storage() const { return storage_; }