diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index c911f0f4e9481..464e9320523fc 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -731,6 +731,7 @@ set(ARROW_COMPUTE_SRCS compute/light_array_internal.cc compute/ordering.cc compute/registry.cc + compute/kernels/chunked_internal.cc compute/kernels/codegen_internal.cc compute/kernels/ree_util_internal.cc compute/kernels/scalar_cast_boolean.cc diff --git a/cpp/src/arrow/chunk_resolver.cc b/cpp/src/arrow/chunk_resolver.cc index bda6b17810299..57cbf88a701f2 100644 --- a/cpp/src/arrow/chunk_resolver.cc +++ b/cpp/src/arrow/chunk_resolver.cc @@ -28,6 +28,8 @@ namespace arrow::internal { +using ::arrow::util::span; + namespace { template int64_t GetLength(const T& array) { @@ -42,7 +44,7 @@ int64_t GetLength>( } template -inline std::vector MakeChunksOffsets(const std::vector& chunks) { +inline std::vector MakeChunksOffsets(span chunks) { std::vector offsets(chunks.size() + 1); int64_t offset = 0; std::transform(chunks.begin(), chunks.end(), offsets.begin(), @@ -112,13 +114,13 @@ void ResolveManyInline(uint32_t num_offsets, const int64_t* signed_offsets, } // namespace ChunkResolver::ChunkResolver(const ArrayVector& chunks) noexcept - : offsets_(MakeChunksOffsets(chunks)), cached_chunk_(0) {} + : offsets_(MakeChunksOffsets(span(chunks))), cached_chunk_(0) {} -ChunkResolver::ChunkResolver(const std::vector& chunks) noexcept +ChunkResolver::ChunkResolver(span chunks) noexcept : offsets_(MakeChunksOffsets(chunks)), cached_chunk_(0) {} ChunkResolver::ChunkResolver(const RecordBatchVector& batches) noexcept - : offsets_(MakeChunksOffsets(batches)), cached_chunk_(0) {} + : offsets_(MakeChunksOffsets(span(batches))), cached_chunk_(0) {} ChunkResolver::ChunkResolver(ChunkResolver&& other) noexcept : offsets_(std::move(other.offsets_)), diff --git a/cpp/src/arrow/chunk_resolver.h b/cpp/src/arrow/chunk_resolver.h index 4a5e27c05361f..03879bb215b17 100644 --- a/cpp/src/arrow/chunk_resolver.h +++ b/cpp/src/arrow/chunk_resolver.h @@ -26,6 +26,7 @@ #include "arrow/type_fwd.h" #include "arrow/util/macros.h" +#include "arrow/util/span.h" namespace arrow::internal { @@ -76,7 +77,7 @@ struct ARROW_EXPORT ChunkResolver { public: explicit ChunkResolver(const ArrayVector& chunks) noexcept; - explicit ChunkResolver(const std::vector& chunks) noexcept; + explicit ChunkResolver(::arrow::util::span chunks) noexcept; explicit ChunkResolver(const RecordBatchVector& batches) noexcept; /// \brief Construct a ChunkResolver from a vector of chunks.size() + 1 offsets. diff --git a/cpp/src/arrow/compute/kernels/chunked_internal.cc b/cpp/src/arrow/compute/kernels/chunked_internal.cc new file mode 100644 index 0000000000000..1a52b055116fe --- /dev/null +++ b/cpp/src/arrow/compute/kernels/chunked_internal.cc @@ -0,0 +1,108 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/kernels/chunked_internal.h" + +#include + +#include "arrow/util/logging.h" + +namespace arrow::compute::internal { + +using ::arrow::internal::TypedChunkLocation; + +std::vector GetArrayPointers(const ArrayVector& arrays) { + std::vector pointers(arrays.size()); + std::transform(arrays.begin(), arrays.end(), pointers.begin(), + [&](const std::shared_ptr& array) { return array.get(); }); + return pointers; +} + +ChunkedIndexMapper::ChunkedIndexMapper(util::span chunks, + uint64_t* indices_begin, uint64_t* indices_end) + : resolver_(chunks), + chunks_(chunks), + indices_begin_(indices_begin), + indices_end_(indices_end) {} + +Result> +ChunkedIndexMapper::LogicalToPhysical() { + // Check that indices would fall in bounds for CompressedChunkLocation + if (ARROW_PREDICT_FALSE(static_cast(chunks_.size()) > + CompressedChunkLocation::kMaxChunkIndex + 1)) { + return Status::NotImplemented("Chunked array has more than ", + CompressedChunkLocation::kMaxChunkIndex + 1, " chunks"); + } + for (const Array* chunk : chunks_) { + if (ARROW_PREDICT_FALSE(chunk->length() > + CompressedChunkLocation::kMaxIndexInChunk + 1)) { + return Status::NotImplemented("Individual chunk in chunked array has more than ", + CompressedChunkLocation::kMaxIndexInChunk + 1, + " elements"); + } + } + + constexpr int64_t kMaxBatchSize = 512; + std::array, kMaxBatchSize> batch; + + const int64_t num_indices = static_cast(indices_end_ - indices_begin_); + CompressedChunkLocation* physical_begin = + reinterpret_cast(indices_begin_); + DCHECK_EQ(physical_begin + num_indices, + reinterpret_cast(indices_end_)); + + for (int64_t i = 0; i < num_indices; i += kMaxBatchSize) { + const int64_t batch_size = std::min(kMaxBatchSize, num_indices - i); + [[maybe_unused]] bool ok = + resolver_.ResolveMany(batch_size, indices_begin_ + i, batch.data()); + DCHECK(ok) << "ResolveMany unexpectedly failed (invalid logical index?)"; + for (int64_t j = 0; j < batch_size; ++j) { + const auto loc = batch[j]; + physical_begin[i + j] = CompressedChunkLocation{ + static_cast(loc.chunk_index), loc.index_in_chunk}; + } + } + + return std::pair{physical_begin, physical_begin + num_indices}; +} + +Status ChunkedIndexMapper::PhysicalToLogical() { + std::vector chunk_offsets(chunks_.size()); + { + int64_t offset = 0; + for (int64_t i = 0; i < static_cast(chunks_.size()); ++i) { + chunk_offsets[i] = offset; + offset += chunks_[i]->length(); + } + } + + const int64_t num_indices = static_cast(indices_end_ - indices_begin_); + CompressedChunkLocation* physical_begin = + reinterpret_cast(indices_begin_); + for (int64_t i = 0; i < num_indices; ++i) { + const auto loc = physical_begin[i]; + DCHECK_LT(loc.chunk_index, chunk_offsets.size()); + DCHECK_LT(static_cast(loc.index_in_chunk), + chunks_[loc.chunk_index]->length()); + indices_begin_[i] = + chunk_offsets[loc.chunk_index] + static_cast(loc.index_in_chunk); + } + + return Status::OK(); +} + +} // namespace arrow::compute::internal diff --git a/cpp/src/arrow/compute/kernels/chunked_internal.h b/cpp/src/arrow/compute/kernels/chunked_internal.h index 2b72e0ab3109e..66e731d37d169 100644 --- a/cpp/src/arrow/compute/kernels/chunked_internal.h +++ b/cpp/src/arrow/compute/kernels/chunked_internal.h @@ -20,26 +20,32 @@ #include #include #include +#include #include #include "arrow/array.h" #include "arrow/chunk_resolver.h" #include "arrow/compute/kernels/codegen_internal.h" +#include "arrow/util/span.h" -namespace arrow { -namespace compute { -namespace internal { +namespace arrow::compute::internal { // The target chunk in a chunked array. struct ResolvedChunk { // The target array in chunked array. const Array* array; // The index in the target array. - const int64_t index; + int64_t index; ResolvedChunk(const Array* array, int64_t index) : array(array), index(index) {} - public: + friend bool operator==(const ResolvedChunk& left, const ResolvedChunk& right) { + return left.array == right.array && left.index == right.index; + } + friend bool operator!=(const ResolvedChunk& left, const ResolvedChunk& right) { + return left.array != right.array || left.index != right.index; + } + bool IsNull() const { return array->IsNull(index); } template > @@ -50,20 +56,44 @@ struct ResolvedChunk { } }; +struct CompressedChunkLocation { + static constexpr int kChunkIndexBits = 24; + static constexpr int KIndexInChunkBits = 64 - kChunkIndexBits; + + static constexpr int64_t kMaxChunkIndex = (1LL << kChunkIndexBits) - 1; + static constexpr int64_t kMaxIndexInChunk = (1LL << KIndexInChunkBits) - 1; + + uint32_t chunk_index : kChunkIndexBits; + uint64_t index_in_chunk : KIndexInChunkBits; +}; + +static_assert(sizeof(uint64_t) == sizeof(CompressedChunkLocation)); + class ChunkedArrayResolver { private: ::arrow::internal::ChunkResolver resolver_; - std::vector chunks_; + util::span chunks_; + std::vector owned_chunks_; public: - explicit ChunkedArrayResolver(const std::vector& chunks) + explicit ChunkedArrayResolver(std::vector&& chunks) + : resolver_(chunks), chunks_(chunks), owned_chunks_(std::move(chunks)) {} + explicit ChunkedArrayResolver(util::span chunks) : resolver_(chunks), chunks_(chunks) {} - ChunkedArrayResolver(ChunkedArrayResolver&& other) = default; - ChunkedArrayResolver& operator=(ChunkedArrayResolver&& other) = default; + ARROW_DEFAULT_MOVE_AND_ASSIGN(ChunkedArrayResolver); - ChunkedArrayResolver(const ChunkedArrayResolver& other) = default; - ChunkedArrayResolver& operator=(const ChunkedArrayResolver& other) = default; + ChunkedArrayResolver(const ChunkedArrayResolver& other) + : resolver_(other.resolver_), owned_chunks_(other.owned_chunks_) { + // Rebind span to owned_chunks_ if necessary + chunks_ = owned_chunks_.empty() ? other.chunks_ : owned_chunks_; + } + ChunkedArrayResolver& operator=(const ChunkedArrayResolver& other) { + resolver_ = other.resolver_; + owned_chunks_ = other.owned_chunks_; + chunks_ = owned_chunks_.empty() ? other.chunks_ : owned_chunks_; + return *this; + } ResolvedChunk Resolve(int64_t index) const { const auto loc = resolver_.Resolve(index); @@ -71,13 +101,25 @@ class ChunkedArrayResolver { } }; -inline std::vector GetArrayPointers(const ArrayVector& arrays) { - std::vector pointers(arrays.size()); - std::transform(arrays.begin(), arrays.end(), pointers.begin(), - [&](const std::shared_ptr& array) { return array.get(); }); - return pointers; -} +std::vector GetArrayPointers(const ArrayVector& arrays); + +class ChunkedIndexMapper { + public: + ChunkedIndexMapper(util::span chunks, uint64_t* indices_begin, + uint64_t* indices_end); + ChunkedIndexMapper(const std::vector& chunks, uint64_t* indices_begin, + uint64_t* indices_end) + : ChunkedIndexMapper(util::span(chunks), indices_begin, indices_end) {} + + Result> + LogicalToPhysical(); + Status PhysicalToLogical(); + + private: + ::arrow::internal::ChunkResolver resolver_; + util::span chunks_; + uint64_t* indices_begin_; + uint64_t* indices_end_; +}; -} // namespace internal -} // namespace compute -} // namespace arrow +} // namespace arrow::compute::internal diff --git a/cpp/src/arrow/compute/kernels/vector_rank.cc b/cpp/src/arrow/compute/kernels/vector_rank.cc index c4e52701411fd..b374862fe6d2c 100644 --- a/cpp/src/arrow/compute/kernels/vector_rank.cc +++ b/cpp/src/arrow/compute/kernels/vector_rank.cc @@ -21,6 +21,8 @@ namespace arrow::compute::internal { +using ::arrow::util::span; + namespace { // ---------------------------------------------------------------------- @@ -237,7 +239,7 @@ class Ranker : public RankerMixin(); }; ARROW_ASSIGN_OR_RAISE(*output_, CreateRankings(ctx_, sorted, null_placement_, diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index 8766ca3baac96..00677ecdae7ab 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -25,6 +25,7 @@ namespace arrow { using internal::checked_cast; using internal::ChunkLocation; +using util::span; namespace compute { namespace internal { @@ -83,6 +84,7 @@ class ChunkedArraySorter : public TypeVisitor { *output_ = {indices_end_, indices_end_, indices_end_, indices_end_}; return Status::OK(); } + const int64_t num_indices = static_cast(indices_end_ - indices_begin_); const auto arrays = GetArrayPointers(physical_chunks_); // Sort each chunk independently and merge to sorted indices. @@ -102,45 +104,66 @@ class ChunkedArraySorter : public TypeVisitor { begin_offset, options, ctx_)); begin_offset = end_offset; } - DCHECK_EQ(end_offset, indices_end_ - indices_begin_); + DCHECK_EQ(end_offset, num_indices); // Then merge them by pairs, recursively if (sorted.size() > 1) { - auto merge_nulls = [&](uint64_t* nulls_begin, uint64_t* nulls_middle, - uint64_t* nulls_end, uint64_t* temp_indices, - int64_t null_count) { + ChunkedIndexMapper chunked_mapper(arrays, indices_begin_, indices_end_); + // TODO: s/LogicalToPhysical/LinearToChunked/ ? + ARROW_ASSIGN_OR_RAISE(auto chunked_indices_pair, + chunked_mapper.LogicalToPhysical()); + auto [chunked_indices_begin, chunked_indices_end] = chunked_indices_pair; + + std::vector chunk_sorted(num_chunks); + for (int i = 0; i < num_chunks; ++i) { + chunk_sorted[i] = ChunkedNullPartitionResult::TranslateFrom( + sorted[i], indices_begin_, chunked_indices_begin); + } + + auto merge_nulls = [&](CompressedChunkLocation* nulls_begin, + CompressedChunkLocation* nulls_middle, + CompressedChunkLocation* nulls_end, + CompressedChunkLocation* temp_indices, int64_t null_count) { if (has_null_like_values::value) { - PartitionNullsOnly(nulls_begin, nulls_end, - ChunkedArrayResolver(arrays), null_count, - null_placement_); + PartitionNullsOnly(nulls_begin, nulls_end, arrays, + null_count, null_placement_); } }; - auto merge_non_nulls = [&](uint64_t* range_begin, uint64_t* range_middle, - uint64_t* range_end, uint64_t* temp_indices) { - MergeNonNulls(range_begin, range_middle, range_end, arrays, - temp_indices); - }; - - MergeImpl merge_impl{null_placement_, std::move(merge_nulls), - std::move(merge_non_nulls)}; + auto merge_non_nulls = + [&](CompressedChunkLocation* range_begin, CompressedChunkLocation* range_middle, + CompressedChunkLocation* range_end, CompressedChunkLocation* temp_indices) { + MergeNonNulls(range_begin, range_middle, range_end, arrays, + temp_indices); + }; + + ChunkedMergeImpl merge_impl{null_placement_, std::move(merge_nulls), + std::move(merge_non_nulls)}; // std::merge is only called on non-null values, so size temp indices accordingly - RETURN_NOT_OK(merge_impl.Init(ctx_, indices_end_ - indices_begin_ - null_count)); + RETURN_NOT_OK(merge_impl.Init(ctx_, num_indices - null_count)); - while (sorted.size() > 1) { - auto out_it = sorted.begin(); - auto it = sorted.begin(); - while (it < sorted.end() - 1) { + while (chunk_sorted.size() > 1) { + // Merge all pairs of chunks + auto out_it = chunk_sorted.begin(); + auto it = chunk_sorted.begin(); + while (it < chunk_sorted.end() - 1) { const auto& left = *it++; const auto& right = *it++; DCHECK_EQ(left.overall_end(), right.overall_begin()); const auto merged = merge_impl.Merge(left, right, null_count); *out_it++ = merged; } - if (it < sorted.end()) { + if (it < chunk_sorted.end()) { *out_it++ = *it++; } - sorted.erase(out_it, sorted.end()); + chunk_sorted.erase(out_it, chunk_sorted.end()); } + + // Reverse everything + sorted.resize(1); + sorted[0] = NullPartitionResult::TranslateFrom( + chunk_sorted[0], chunked_indices_begin, indices_begin_); + + RETURN_NOT_OK(chunked_mapper.PhysicalToLogical()); } DCHECK_EQ(sorted.size(), 1); @@ -154,34 +177,39 @@ class ChunkedArraySorter : public TypeVisitor { } template - void MergeNonNulls(uint64_t* range_begin, uint64_t* range_middle, uint64_t* range_end, - const std::vector& arrays, uint64_t* temp_indices) { + void MergeNonNulls(CompressedChunkLocation* range_begin, + CompressedChunkLocation* range_middle, + CompressedChunkLocation* range_end, span arrays, + CompressedChunkLocation* temp_indices) { using ArrowType = typename ArrayType::TypeClass; - const ChunkedArrayResolver left_resolver(arrays); - const ChunkedArrayResolver right_resolver(arrays); if (order_ == SortOrder::Ascending) { std::merge(range_begin, range_middle, range_middle, range_end, temp_indices, - [&](uint64_t left, uint64_t right) { - const auto chunk_left = left_resolver.Resolve(left); - const auto chunk_right = right_resolver.Resolve(right); - return chunk_left.Value() < chunk_right.Value(); + [&](CompressedChunkLocation left, CompressedChunkLocation right) { + return ChunkValue(arrays, left) < + ChunkValue(arrays, right); }); } else { std::merge(range_begin, range_middle, range_middle, range_end, temp_indices, - [&](uint64_t left, uint64_t right) { - const auto chunk_left = left_resolver.Resolve(left); - const auto chunk_right = right_resolver.Resolve(right); + [&](CompressedChunkLocation left, CompressedChunkLocation right) { // We don't use 'left > right' here to reduce required // operator. If we use 'right < left' here, '<' is only // required. - return chunk_right.Value() < chunk_left.Value(); + return ChunkValue(arrays, right) < + ChunkValue(arrays, left); }); } // Copy back temp area into main buffer std::copy(temp_indices, temp_indices + (range_end - range_begin), range_begin); } + template + auto ChunkValue(span arrays, CompressedChunkLocation loc) const { + return ResolvedChunk(arrays[loc.chunk_index], + static_cast(loc.index_in_chunk)) + .template Value(); + } + uint64_t* indices_begin_; uint64_t* indices_end_; const std::shared_ptr& physical_type_; diff --git a/cpp/src/arrow/compute/kernels/vector_sort_internal.h b/cpp/src/arrow/compute/kernels/vector_sort_internal.h index 564afb8c087d2..cf8b0e6db37a5 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_internal.h +++ b/cpp/src/arrow/compute/kernels/vector_sort_internal.h @@ -55,15 +55,17 @@ namespace internal { // NOTE: std::partition is usually faster than std::stable_partition. struct NonStablePartitioner { - template - uint64_t* operator()(uint64_t* indices_begin, uint64_t* indices_end, Predicate&& pred) { + template + IndexType* operator()(IndexType* indices_begin, IndexType* indices_end, + Predicate&& pred) { return std::partition(indices_begin, indices_end, std::forward(pred)); } }; struct StablePartitioner { - template - uint64_t* operator()(uint64_t* indices_begin, uint64_t* indices_end, Predicate&& pred) { + template + IndexType* operator()(IndexType* indices_begin, IndexType* indices_end, + Predicate&& pred) { return std::stable_partition(indices_begin, indices_end, std::forward(pred)); } @@ -142,22 +144,24 @@ int CompareTypeValues(const Value& left, const Value& right, SortOrder order, return ValueComparator::Compare(left, right, order, null_placement); } -struct NullPartitionResult { - uint64_t* non_nulls_begin; - uint64_t* non_nulls_end; - uint64_t* nulls_begin; - uint64_t* nulls_end; +template +struct GenericNullPartitionResult { + IndexType* non_nulls_begin; + IndexType* non_nulls_end; + IndexType* nulls_begin; + IndexType* nulls_end; - uint64_t* overall_begin() const { return std::min(nulls_begin, non_nulls_begin); } + IndexType* overall_begin() const { return std::min(nulls_begin, non_nulls_begin); } - uint64_t* overall_end() const { return std::max(nulls_end, non_nulls_end); } + IndexType* overall_end() const { return std::max(nulls_end, non_nulls_end); } int64_t non_null_count() const { return non_nulls_end - non_nulls_begin; } int64_t null_count() const { return nulls_end - nulls_begin; } - static NullPartitionResult NoNulls(uint64_t* indices_begin, uint64_t* indices_end, - NullPlacement null_placement) { + static GenericNullPartitionResult NoNulls(IndexType* indices_begin, + IndexType* indices_end, + NullPlacement null_placement) { if (null_placement == NullPlacement::AtStart) { return {indices_begin, indices_end, indices_begin, indices_begin}; } else { @@ -165,8 +169,9 @@ struct NullPartitionResult { } } - static NullPartitionResult NullsOnly(uint64_t* indices_begin, uint64_t* indices_end, - NullPlacement null_placement) { + static GenericNullPartitionResult NullsOnly(IndexType* indices_begin, + IndexType* indices_end, + NullPlacement null_placement) { if (null_placement == NullPlacement::AtStart) { return {indices_end, indices_end, indices_begin, indices_end}; } else { @@ -174,21 +179,38 @@ struct NullPartitionResult { } } - static NullPartitionResult NullsAtEnd(uint64_t* indices_begin, uint64_t* indices_end, - uint64_t* midpoint) { + static GenericNullPartitionResult NullsAtEnd(IndexType* indices_begin, + IndexType* indices_end, + IndexType* midpoint) { DCHECK_GE(midpoint, indices_begin); DCHECK_LE(midpoint, indices_end); return {indices_begin, midpoint, midpoint, indices_end}; } - static NullPartitionResult NullsAtStart(uint64_t* indices_begin, uint64_t* indices_end, - uint64_t* midpoint) { + static GenericNullPartitionResult NullsAtStart(IndexType* indices_begin, + IndexType* indices_end, + IndexType* midpoint) { DCHECK_GE(midpoint, indices_begin); DCHECK_LE(midpoint, indices_end); return {midpoint, indices_end, indices_begin, midpoint}; } + + template + static GenericNullPartitionResult TranslateFrom( + GenericNullPartitionResult source, + SourceIndexType* source_indices_begin, IndexType* target_indices_begin) { + return { + (source.non_nulls_begin - source_indices_begin) + target_indices_begin, + (source.non_nulls_end - source_indices_begin) + target_indices_begin, + (source.nulls_begin - source_indices_begin) + target_indices_begin, + (source.nulls_end - source_indices_begin) + target_indices_begin, + }; + } }; +using NullPartitionResult = GenericNullPartitionResult; +using ChunkedNullPartitionResult = GenericNullPartitionResult; + // Move nulls (not null-like values) to end of array. // // `offset` is used when this is called on a chunk of a chunked array @@ -265,7 +287,9 @@ NullPartitionResult PartitionNulls(uint64_t* indices_begin, uint64_t* indices_en } // -// Null partitioning on chunked arrays +// Null partitioning on chunked arrays, in two flavors: +// 1) with uint64_t indices and ChunkedArrayResolver +// 2) with CompressedChunkLocation and span of chunks // template @@ -291,6 +315,36 @@ NullPartitionResult PartitionNullsOnly(uint64_t* indices_begin, uint64_t* indice } } +template +ChunkedNullPartitionResult PartitionNullsOnly(CompressedChunkLocation* locations_begin, + CompressedChunkLocation* locations_end, + util::span chunks, + int64_t null_count, + NullPlacement null_placement) { + if (null_count == 0) { + return ChunkedNullPartitionResult::NoNulls(locations_begin, locations_end, + null_placement); + } + Partitioner partitioner; + if (null_placement == NullPlacement::AtStart) { + auto nulls_end = + partitioner(locations_begin, locations_end, [&](CompressedChunkLocation loc) { + return chunks[loc.chunk_index]->IsNull( + static_cast(loc.index_in_chunk)); + }); + return ChunkedNullPartitionResult::NullsAtStart(locations_begin, locations_end, + nulls_end); + } else { + auto nulls_begin = + partitioner(locations_begin, locations_end, [&](CompressedChunkLocation loc) { + return !chunks[loc.chunk_index]->IsNull( + static_cast(loc.index_in_chunk)); + }); + return ChunkedNullPartitionResult::NullsAtEnd(locations_begin, locations_end, + nulls_begin); + } +} + template enable_if_t::value, NullPartitionResult> @@ -334,17 +388,18 @@ NullPartitionResult PartitionNulls(uint64_t* indices_begin, uint64_t* indices_en std::max(q.nulls_end, p.nulls_end)}; } -struct MergeImpl { - using MergeNullsFunc = std::function; +template +struct GenericMergeImpl { + using MergeNullsFunc = std::function; using MergeNonNullsFunc = - std::function; + std::function; - MergeImpl(NullPlacement null_placement, MergeNullsFunc&& merge_nulls, - MergeNonNullsFunc&& merge_non_nulls) + GenericMergeImpl(NullPlacement null_placement, MergeNullsFunc&& merge_nulls, + MergeNonNullsFunc&& merge_non_nulls) : null_placement_(null_placement), merge_nulls_(std::move(merge_nulls)), merge_non_nulls_(std::move(merge_non_nulls)) {} @@ -352,13 +407,14 @@ struct MergeImpl { Status Init(ExecContext* ctx, int64_t temp_indices_length) { ARROW_ASSIGN_OR_RAISE( temp_buffer_, - AllocateBuffer(sizeof(int64_t) * temp_indices_length, ctx->memory_pool())); - temp_indices_ = reinterpret_cast(temp_buffer_->mutable_data()); + AllocateBuffer(sizeof(IndexType) * temp_indices_length, ctx->memory_pool())); + temp_indices_ = reinterpret_cast(temp_buffer_->mutable_data()); return Status::OK(); } - NullPartitionResult Merge(const NullPartitionResult& left, - const NullPartitionResult& right, int64_t null_count) const { + NullPartitionResultType Merge(const NullPartitionResultType& left, + const NullPartitionResultType& right, + int64_t null_count) const { if (null_placement_ == NullPlacement::AtStart) { return MergeNullsAtStart(left, right, null_count); } else { @@ -366,9 +422,9 @@ struct MergeImpl { } } - NullPartitionResult MergeNullsAtStart(const NullPartitionResult& left, - const NullPartitionResult& right, - int64_t null_count) const { + NullPartitionResultType MergeNullsAtStart(const NullPartitionResultType& left, + const NullPartitionResultType& right, + int64_t null_count) const { // Input layout: // [left nulls .... left non-nulls .... right nulls .... right non-nulls] DCHECK_EQ(left.nulls_end, left.non_nulls_begin); @@ -379,7 +435,7 @@ struct MergeImpl { // [left nulls .... right nulls .... left non-nulls .... right non-nulls] std::rotate(left.non_nulls_begin, right.nulls_begin, right.nulls_end); - const auto p = NullPartitionResult::NullsAtStart( + const auto p = NullPartitionResultType::NullsAtStart( left.nulls_begin, right.non_nulls_end, left.nulls_begin + left.null_count() + right.null_count()); @@ -401,9 +457,9 @@ struct MergeImpl { return p; } - NullPartitionResult MergeNullsAtEnd(const NullPartitionResult& left, - const NullPartitionResult& right, - int64_t null_count) const { + NullPartitionResultType MergeNullsAtEnd(const NullPartitionResultType& left, + const NullPartitionResultType& right, + int64_t null_count) const { // Input layout: // [left non-nulls .... left nulls .... right non-nulls .... right nulls] DCHECK_EQ(left.non_nulls_end, left.nulls_begin); @@ -414,7 +470,7 @@ struct MergeImpl { // [left non-nulls .... right non-nulls .... left nulls .... right nulls] std::rotate(left.nulls_begin, right.non_nulls_begin, right.non_nulls_end); - const auto p = NullPartitionResult::NullsAtEnd( + const auto p = NullPartitionResultType::NullsAtEnd( left.non_nulls_begin, right.nulls_end, left.non_nulls_begin + left.non_null_count() + right.non_null_count()); @@ -441,9 +497,13 @@ struct MergeImpl { MergeNullsFunc merge_nulls_; MergeNonNullsFunc merge_non_nulls_; std::unique_ptr temp_buffer_; - uint64_t* temp_indices_ = nullptr; + IndexType* temp_indices_ = nullptr; }; +using MergeImpl = GenericMergeImpl; +using ChunkedMergeImpl = + GenericMergeImpl; + // TODO make this usable if indices are non trivial on input // (see ConcreteRecordBatchColumnSorter) // `offset` is used when this is called on a chunk of a chunked array