diff --git a/cpp/src/arrow/compute/exec.cc b/cpp/src/arrow/compute/exec.cc
index 15f8b263ed87f..c18dfa0952245 100644
--- a/cpp/src/arrow/compute/exec.cc
+++ b/cpp/src/arrow/compute/exec.cc
@@ -147,6 +147,18 @@ ExecBatch ExecBatch::Slice(int64_t offset, int64_t length) const {
   return out;
 }
 
+Result<ExecBatch> ExecBatch::SelectValues(const std::vector<int>& ids) const {
+  std::vector<Datum> selected_values;
+  selected_values.reserve(ids.size());
+  for (int id : ids) {
+    if (id < 0 || static_cast<size_t>(id) >= values.size()) {
+      return Status::Invalid("ExecBatch invalid value selection: ", id);
+    }
+    selected_values.push_back(values[id]);
+  }
+  return ExecBatch(std::move(selected_values), length);
+}
+
 namespace {
 
 enum LengthInferenceError {
diff --git a/cpp/src/arrow/compute/exec.h b/cpp/src/arrow/compute/exec.h
index 8128d84a12b15..338740f066eed 100644
--- a/cpp/src/arrow/compute/exec.h
+++ b/cpp/src/arrow/compute/exec.h
@@ -181,6 +181,12 @@ struct ARROW_EXPORT ExecBatch {
   /// \brief Infer the ExecBatch length from values.
   static Result<int64_t> InferLength(const std::vector<Datum>& values);
 
+  /// Creates an ExecBatch with length-validation.
+  ///
+  /// If any value is given, then all values must have a common length. If the given
+  /// length is negative, then the length of the ExecBatch is set to this common length,
+  /// or to 1 if no values are given. Otherwise, the given length must equal the common
+  /// length, if any value is given.
   static Result<ExecBatch> Make(std::vector<Datum> values, int64_t length = -1);
 
   Result<std::shared_ptr<RecordBatch>> ToRecordBatch(
@@ -240,6 +246,8 @@ struct ARROW_EXPORT ExecBatch {
 
   ExecBatch Slice(int64_t offset, int64_t length) const;
 
+  Result<ExecBatch> SelectValues(const std::vector<int>& ids) const;
+
   /// \brief A convenience for returning the types from the batch.
   std::vector<TypeHolder> GetTypes() const {
     std::vector<TypeHolder> result;
diff --git a/cpp/src/arrow/compute/exec/aggregate_node.cc b/cpp/src/arrow/compute/exec/aggregate_node.cc
index aa9d832f90a49..62d4ac81d7056 100644
--- a/cpp/src/arrow/compute/exec/aggregate_node.cc
+++ b/cpp/src/arrow/compute/exec/aggregate_node.cc
@@ -19,6 +19,7 @@
 #include
 #include
 #include
+#include <unordered_set>
 
 #include "arrow/compute/exec.h"
 #include "arrow/compute/exec/exec_plan.h"
@@ -35,6 +36,25 @@
 #include "arrow/util/thread_pool.h"
 #include "arrow/util/tracing_internal.h"
 
+// This file implements both regular and segmented group-by aggregation, which is a
+// generalization of ordered aggregation in which the key columns are not required to be
+// ordered.
+//
+// In (regular) group-by aggregation, the input rows are partitioned into groups using a
+// set of columns called keys, where in a given group each row has the same values for
+// these columns. In segmented group-by aggregation, a second set of columns called
+// segment-keys is used to refine the partitioning. However, segment-keys are different
+// in that they partition only consecutive rows into a single group. Such a partition of
+// consecutive rows is called a segment group. For example, consider a column X with
+// values [A, A, B, A] at row-indices [0, 1, 2, 3]. A regular group-by aggregation with
+// keys [X] yields a row-index partitioning [[0, 1, 3], [2]], whereas a segmented
+// group-by aggregation with segment-keys [X] yields [[0, 1], [2], [3]].
+//
+// The implementation first segments the input using the segment-keys, then groups by the
+// keys. When a segment group end is reached while scanning the input, output is pushed
+// and the accumulating state is cleared. If no segment-keys are given, then the entire
+// input is taken as one segment group.
One batch per segment group is sent to output. + namespace arrow { using internal::checked_cast; @@ -43,8 +63,6 @@ namespace compute { namespace { -namespace { - std::vector ExtendWithGroupIdType(const std::vector& in_types) { std::vector aggr_in_types; aggr_in_types.reserve(in_types.size() + 1); @@ -141,8 +159,6 @@ Result ResolveKernels( return fields; } -} // namespace - void AggregatesToString(std::stringstream* ss, const Schema& input_schema, const std::vector& aggs, const std::vector>& target_fieldsets, @@ -169,20 +185,79 @@ void AggregatesToString(std::stringstream* ss, const Schema& input_schema, *ss << ']'; } +// Extract segments from a batch and run the given handler on them. Note that the +// handle may be called on open segments which are not yet finished. Typically a +// handler should accumulate those open segments until a closed segment is reached. +template +Status HandleSegments(RowSegmenter* segmenter, const ExecBatch& batch, + const std::vector& ids, const BatchHandler& handle_batch) { + int64_t offset = 0; + ARROW_ASSIGN_OR_RAISE(auto segment_exec_batch, batch.SelectValues(ids)); + ExecSpan segment_batch(segment_exec_batch); + + while (true) { + ARROW_ASSIGN_OR_RAISE(compute::Segment segment, + segmenter->GetNextSegment(segment_batch, offset)); + if (segment.offset >= segment_batch.length) break; // condition of no-next-segment + ARROW_RETURN_NOT_OK(handle_batch(batch, segment)); + offset = segment.offset + segment.length; + } + return Status::OK(); +} + +/// @brief Extract values of segment keys from a segment batch +/// @param[out] values_ptr Vector to store the extracted segment key values +/// @param[in] input_batch Segment batch. Must have the a constant value for segment key +/// @param[in] field_ids Segment key field ids +Status ExtractSegmenterValues(std::vector* values_ptr, + const ExecBatch& input_batch, + const std::vector& field_ids) { + DCHECK_GT(input_batch.length, 0); + std::vector& values = *values_ptr; + int64_t row = input_batch.length - 1; + values.clear(); + values.resize(field_ids.size()); + for (size_t i = 0; i < field_ids.size(); i++) { + const Datum& value = input_batch.values[field_ids[i]]; + if (value.is_scalar()) { + values[i] = value; + } else if (value.is_array()) { + ARROW_ASSIGN_OR_RAISE(auto scalar, value.make_array()->GetScalar(row)); + values[i] = scalar; + } else { + DCHECK(false); + } + } + return Status::OK(); +} + +void PlaceFields(ExecBatch& batch, size_t base, std::vector& values) { + DCHECK_LE(base + values.size(), batch.values.size()); + for (size_t i = 0; i < values.size(); i++) { + batch.values[base + i] = values[i]; + } +} + class ScalarAggregateNode : public ExecNode, public TracedNode { public: ScalarAggregateNode(ExecPlan* plan, std::vector inputs, std::shared_ptr output_schema, + std::unique_ptr segmenter, + std::vector segment_field_ids, std::vector> target_fieldsets, std::vector aggs, std::vector kernels, + std::vector> kernel_intypes, std::vector>> states) : ExecNode(plan, std::move(inputs), {"target"}, /*output_schema=*/std::move(output_schema)), TracedNode(this), + segmenter_(std::move(segmenter)), + segment_field_ids_(std::move(segment_field_ids)), target_fieldsets_(std::move(target_fieldsets)), aggs_(std::move(aggs)), kernels_(std::move(kernels)), + kernel_intypes_(std::move(kernel_intypes)), states_(std::move(states)) {} static Result Make(ExecPlan* plan, std::vector inputs, @@ -191,13 +266,40 @@ class ScalarAggregateNode : public ExecNode, public TracedNode { const auto& aggregate_options = checked_cast(options); 
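The helpers introduced above (HandleSegments, ExtractSegmenterValues, PlaceFields) wrap the RowSegmenter loop that both aggregation nodes share. As a minimal sketch of that loop, assuming only the RowSegmenter/Segment API shown in this patch, with an invented function name, made-up data, and ExecBatchFromJSON taken from Arrow's test utilities:

    // Sketch: split one key column into segments (runs of equal key values).
    #include "arrow/compute/exec.h"
    #include "arrow/compute/exec/test_util.h"  // ExecBatchFromJSON (test utility)
    #include "arrow/compute/row/grouper.h"     // RowSegmenter, Segment

    arrow::Status SegmentExample() {
      using namespace arrow::compute;
      std::vector<TypeHolder> key_types = {arrow::int32()};
      ARROW_ASSIGN_OR_RAISE(
          auto segmenter,
          RowSegmenter::Make(key_types, /*nullable_keys=*/false, default_exec_context()));
      // Key values [1, 1, 2, 1] yield three segments: rows [0,2), [2,3), [3,4).
      ExecBatch batch = ExecBatchFromJSON({arrow::int32()}, "[[1], [1], [2], [1]]");
      ExecSpan span(batch);
      int64_t offset = 0;
      while (true) {
        ARROW_ASSIGN_OR_RAISE(Segment segment, segmenter->GetNextSegment(span, offset));
        if (segment.offset >= span.length) break;  // no further segment in this batch
        // segment.is_open: the run touches the end of the batch and may continue;
        // segment.extends: the run carries on the key values of the previous segment.
        offset = segment.offset + segment.length;
      }
      return arrow::Status::OK();
    }

In HandleSegments the body of this loop is the handler lambda, which consumes the sliced rows and, once a segment is closed, pushes one output batch and resets the kernel states.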
auto aggregates = aggregate_options.aggregates; + const auto& keys = aggregate_options.keys; + const auto& segment_keys = aggregate_options.segment_keys; + + if (keys.size() > 0) { + return Status::Invalid("Scalar aggregation with some key"); + } + if (plan->query_context()->exec_context()->executor()->GetCapacity() > 1 && + segment_keys.size() > 0) { + return Status::NotImplemented("Segmented aggregation in a multi-threaded plan"); + } const auto& input_schema = *inputs[0]->output_schema(); auto exec_ctx = plan->query_context()->exec_context(); + std::vector segment_field_ids(segment_keys.size()); + std::vector segment_key_types(segment_keys.size()); + for (size_t i = 0; i < segment_keys.size(); i++) { + ARROW_ASSIGN_OR_RAISE(FieldPath match, segment_keys[i].FindOne(input_schema)); + if (match.indices().size() > 1) { + // ARROW-18369: Support nested references as segment ids + return Status::Invalid("Nested references cannot be used as segment ids"); + } + segment_field_ids[i] = match[0]; + segment_key_types[i] = input_schema.field(match[0])->type().get(); + } + + ARROW_ASSIGN_OR_RAISE(auto segmenter, + RowSegmenter::Make(std::move(segment_key_types), + /*nullable_keys=*/false, exec_ctx)); + + std::vector> kernel_intypes(aggregates.size()); std::vector kernels(aggregates.size()); std::vector>> states(kernels.size()); - FieldVector fields(kernels.size()); + FieldVector fields(kernels.size() + segment_keys.size()); std::vector> target_fieldsets(kernels.size()); for (size_t i = 0; i < kernels.size(); ++i) { @@ -225,7 +327,9 @@ class ScalarAggregateNode : public ExecNode, public TracedNode { for (const auto& target : target_fieldsets[i]) { in_types.emplace_back(input_schema.field(target)->type().get()); } - ARROW_ASSIGN_OR_RAISE(const Kernel* kernel, function->DispatchExact(in_types)); + kernel_intypes[i] = in_types; + ARROW_ASSIGN_OR_RAISE(const Kernel* kernel, + function->DispatchExact(kernel_intypes[i])); kernels[i] = static_cast(kernel); if (aggregates[i].options == nullptr) { @@ -239,20 +343,26 @@ class ScalarAggregateNode : public ExecNode, public TracedNode { KernelContext kernel_ctx{exec_ctx}; states[i].resize(plan->query_context()->max_concurrency()); RETURN_NOT_OK(Kernel::InitAll( - &kernel_ctx, KernelInitArgs{kernels[i], in_types, aggregates[i].options.get()}, + &kernel_ctx, + KernelInitArgs{kernels[i], kernel_intypes[i], aggregates[i].options.get()}, &states[i])); // pick one to resolve the kernel signature kernel_ctx.SetState(states[i][0].get()); ARROW_ASSIGN_OR_RAISE(auto out_type, kernels[i]->signature->out_type().Resolve( - &kernel_ctx, in_types)); + &kernel_ctx, kernel_intypes[i])); fields[i] = field(aggregate_options.aggregates[i].name, out_type.GetSharedPtr()); } + for (size_t i = 0; i < segment_keys.size(); ++i) { + ARROW_ASSIGN_OR_RAISE(fields[kernels.size() + i], + segment_keys[i].GetOne(*inputs[0]->output_schema())); + } return plan->EmplaceNode( - plan, std::move(inputs), schema(std::move(fields)), std::move(target_fieldsets), - std::move(aggregates), std::move(kernels), std::move(states)); + plan, std::move(inputs), schema(std::move(fields)), std::move(segmenter), + std::move(segment_field_ids), std::move(target_fieldsets), std::move(aggregates), + std::move(kernels), std::move(kernel_intypes), std::move(states)); } const char* kind_name() const override { return "ScalarAggregateNode"; } @@ -283,28 +393,46 @@ class ScalarAggregateNode : public ExecNode, public TracedNode { DCHECK_EQ(input, inputs_[0]); auto thread_index = plan_->query_context()->GetThreadIndex(); - - 
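ScalarAggregateNode::Make above rejects hash keys outright and, for now, rejects segment keys in multi-threaded plans. As a hedged sketch of what it does accept — the field names, source data, and helper name are invented here; BatchesWithSchema comes from the test utilities — a keyless aggregation segmented by a "day" column could be declared like this:

    // Sketch: scalar (no-key) aggregation, segmented by a "day" column that is
    // expected to arrive in non-decreasing order. Runs single-threaded only.
    #include "arrow/compute/exec/exec_plan.h"
    #include "arrow/compute/exec/options.h"
    #include "arrow/compute/exec/test_util.h"  // BatchesWithSchema (test utility)

    arrow::compute::Declaration MakeSegmentedCountPlan(
        arrow::compute::BatchesWithSchema data) {
      using namespace arrow::compute;
      return Declaration::Sequence(
          {{"source", SourceNodeOptions{data.schema,
                                        data.gen(/*parallel=*/false, /*slow=*/false)}},
           {"aggregate",
            AggregateNodeOptions{/*aggregates=*/{{"count", nullptr, "x", "count(x)"}},
                                 /*keys=*/{},  // must be empty for scalar kernels
                                 /*segment_keys=*/{"day"}}}});
    }

The node then emits one batch per segment of "day", each holding the aggregate results followed by the segment-key value.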
ARROW_RETURN_NOT_OK(DoConsume(ExecSpan(batch), thread_index)); + auto handler = [this, thread_index](const ExecBatch& full_batch, + const Segment& segment) { + // (1) The segment is starting of a new segment group and points to + // the beginning of the batch, then it means no data in the batch belongs + // to the current segment group. We can output and reset kernel states. + if (!segment.extends && segment.offset == 0) RETURN_NOT_OK(OutputResult(false)); + + // We add segment to the current segment group aggregation + auto exec_batch = full_batch.Slice(segment.offset, segment.length); + RETURN_NOT_OK(DoConsume(ExecSpan(exec_batch), thread_index)); + RETURN_NOT_OK( + ExtractSegmenterValues(&segmenter_values_, exec_batch, segment_field_ids_)); + + // If the segment closes the current segment group, we can output segment group + // aggregation. + if (!segment.is_open) RETURN_NOT_OK(OutputResult(false)); + + return Status::OK(); + }; + RETURN_NOT_OK(HandleSegments(segmenter_.get(), batch, segment_field_ids_, handler)); if (input_counter_.Increment()) { - return Finish(); + RETURN_NOT_OK(OutputResult(/*is_last=*/true)); } return Status::OK(); } Status InputFinished(ExecNode* input, int total_batches) override { + auto scope = TraceFinish(); EVENT_ON_CURRENT_SPAN("InputFinished", {{"batches.length", total_batches}}); DCHECK_EQ(input, inputs_[0]); if (input_counter_.SetTotal(total_batches)) { - return Finish(); + RETURN_NOT_OK(OutputResult(/*is_last=*/true)); } return Status::OK(); } Status StartProducing() override { NoteStartProducing(ToStringExtra()); - // Scalar aggregates will only output a single batch - return output_->InputFinished(this, 1); + return Status::OK(); } void PauseProducing(ExecNode* output, int32_t counter) override { @@ -326,10 +454,22 @@ class ScalarAggregateNode : public ExecNode, public TracedNode { } private: - Status Finish() { - auto scope = TraceFinish(); + Status ResetKernelStates() { + auto exec_ctx = plan()->query_context()->exec_context(); + for (size_t i = 0; i < kernels_.size(); ++i) { + states_[i].resize(plan()->query_context()->max_concurrency()); + KernelContext kernel_ctx{exec_ctx}; + RETURN_NOT_OK(Kernel::InitAll( + &kernel_ctx, + KernelInitArgs{kernels_[i], kernel_intypes_[i], aggs_[i].options.get()}, + &states_[i])); + } + return Status::OK(); + } + + Status OutputResult(bool is_last) { ExecBatch batch{{}, 1}; - batch.values.resize(kernels_.size()); + batch.values.resize(kernels_.size() + segment_field_ids_.size()); for (size_t i = 0; i < kernels_.size(); ++i) { util::tracing::Span span; @@ -343,29 +483,54 @@ class ScalarAggregateNode : public ExecNode, public TracedNode { kernels_[i], &ctx, std::move(states_[i]))); RETURN_NOT_OK(kernels_[i]->finalize(&ctx, &batch.values[i])); } + PlaceFields(batch, kernels_.size(), segmenter_values_); - return output_->InputReceived(this, std::move(batch)); + ARROW_RETURN_NOT_OK(output_->InputReceived(this, std::move(batch))); + total_output_batches_++; + if (is_last) { + ARROW_RETURN_NOT_OK(output_->InputFinished(this, total_output_batches_)); + } else { + ARROW_RETURN_NOT_OK(ResetKernelStates()); + } + return Status::OK(); } + // A segmenter for the segment-keys + std::unique_ptr segmenter_; + // Field indices corresponding to the segment-keys + const std::vector segment_field_ids_; + // Holds the value of segment keys of the most recent input batch + // The values are updated everytime an input batch is processed + std::vector segmenter_values_; + const std::vector> target_fieldsets_; const std::vector aggs_; const 
std::vector kernels_; + // Input type holders for each kernel, used for state initialization + std::vector> kernel_intypes_; std::vector>> states_; AtomicCounter input_counter_; + /// \brief Total number of output batches produced + int total_output_batches_ = 0; }; class GroupByNode : public ExecNode, public TracedNode { public: GroupByNode(ExecNode* input, std::shared_ptr output_schema, - std::vector key_field_ids, + std::vector key_field_ids, std::vector segment_key_field_ids, + std::unique_ptr segmenter, + std::vector> agg_src_types, std::vector> agg_src_fieldsets, std::vector aggs, std::vector agg_kernels) : ExecNode(input->plan(), {input}, {"groupby"}, std::move(output_schema)), TracedNode(this), + segmenter_(std::move(segmenter)), key_field_ids_(std::move(key_field_ids)), + segment_key_field_ids_(std::move(segment_key_field_ids)), + agg_src_types_(std::move(agg_src_types)), agg_src_fieldsets_(std::move(agg_src_fieldsets)), aggs_(std::move(aggs)), agg_kernels_(std::move(agg_kernels)) {} @@ -384,9 +549,15 @@ class GroupByNode : public ExecNode, public TracedNode { auto input = inputs[0]; const auto& aggregate_options = checked_cast(options); const auto& keys = aggregate_options.keys; + const auto& segment_keys = aggregate_options.segment_keys; // Copy (need to modify options pointer below) auto aggs = aggregate_options.aggregates; + if (plan->query_context()->exec_context()->executor()->GetCapacity() > 1 && + segment_keys.size() > 0) { + return Status::NotImplemented("Segmented aggregation in a multi-threaded plan"); + } + // Get input schema auto input_schema = input->output_schema(); @@ -397,6 +568,23 @@ class GroupByNode : public ExecNode, public TracedNode { key_field_ids[i] = match[0]; } + // Find input field indices for segment key fields + std::vector segment_key_field_ids(segment_keys.size()); + for (size_t i = 0; i < segment_keys.size(); ++i) { + ARROW_ASSIGN_OR_RAISE(auto match, segment_keys[i].FindOne(*input_schema)); + segment_key_field_ids[i] = match[0]; + } + + // Check key fields and segment key fields are disjoint + std::unordered_set key_field_id_set(key_field_ids.begin(), key_field_ids.end()); + for (const auto& segment_key_field_id : segment_key_field_ids) { + if (key_field_id_set.find(segment_key_field_id) != key_field_id_set.end()) { + return Status::Invalid("Group-by aggregation with field '", + input_schema->field(segment_key_field_id)->name(), + "' as both key and segment key"); + } + } + // Find input field indices for aggregates std::vector> agg_src_fieldsets(aggs.size()); for (size_t i = 0; i < aggs.size(); ++i) { @@ -415,8 +603,19 @@ class GroupByNode : public ExecNode, public TracedNode { } } + // Build vector of segment key field data types + std::vector segment_key_types(segment_keys.size()); + for (size_t i = 0; i < segment_keys.size(); ++i) { + auto segment_key_field_id = segment_key_field_ids[i]; + segment_key_types[i] = input_schema->field(segment_key_field_id)->type().get(); + } + auto ctx = plan->query_context()->exec_context(); + ARROW_ASSIGN_OR_RAISE(auto segmenter, + RowSegmenter::Make(std::move(segment_key_types), + /*nullable_keys=*/false, ctx)); + // Construct aggregates ARROW_ASSIGN_OR_RAISE(auto agg_kernels, GetKernels(ctx, aggs, agg_src_types)); @@ -428,7 +627,7 @@ class GroupByNode : public ExecNode, public TracedNode { ResolveKernels(aggs, agg_kernels, agg_states, ctx, agg_src_types)); // Build field vector for output schema - FieldVector output_fields{keys.size() + aggs.size()}; + FieldVector output_fields{keys.size() + 
segment_keys.size() + aggs.size()}; // Aggregate fields come before key fields to match the behavior of GroupBy function for (size_t i = 0; i < aggs.size(); ++i) { @@ -440,12 +639,24 @@ class GroupByNode : public ExecNode, public TracedNode { int key_field_id = key_field_ids[i]; output_fields[base + i] = input_schema->field(key_field_id); } + base += keys.size(); + for (size_t i = 0; i < segment_keys.size(); ++i) { + int segment_key_field_id = segment_key_field_ids[i]; + output_fields[base + i] = input_schema->field(segment_key_field_id); + } return input->plan()->EmplaceNode( input, schema(std::move(output_fields)), std::move(key_field_ids), + std::move(segment_key_field_ids), std::move(segmenter), std::move(agg_src_types), std::move(agg_src_fieldsets), std::move(aggs), std::move(agg_kernels)); } + Status ResetKernelStates() { + auto ctx = plan()->query_context()->exec_context(); + ARROW_RETURN_NOT_OK(InitKernels(agg_kernels_, ctx, aggs_, agg_src_types_)); + return Status::OK(); + } + const char* kind_name() const override { return "GroupByNode"; } Status Consume(ExecSpan batch) { @@ -542,7 +753,8 @@ class GroupByNode : public ExecNode, public TracedNode { RETURN_NOT_OK(InitLocalStateIfNeeded(state)); ExecBatch out_data{{}, state->grouper->num_groups()}; - out_data.values.resize(agg_kernels_.size() + key_field_ids_.size()); + out_data.values.resize(agg_kernels_.size() + key_field_ids_.size() + + segment_key_field_ids_.size()); // Aggregate fields come before key fields to match the behavior of GroupBy function for (size_t i = 0; i < agg_kernels_.size(); ++i) { @@ -561,6 +773,7 @@ class GroupByNode : public ExecNode, public TracedNode { ARROW_ASSIGN_OR_RAISE(ExecBatch out_keys, state->grouper->GetUniques()); std::move(out_keys.values.begin(), out_keys.values.end(), out_data.values.begin() + agg_kernels_.size()); + PlaceFields(out_data, agg_kernels_.size() + key_field_ids_.size(), segmenter_values_); state->grouper.reset(); return out_data; } @@ -570,8 +783,7 @@ class GroupByNode : public ExecNode, public TracedNode { return output_->InputReceived(this, out_data_.Slice(batch_size * n, batch_size)); } - Status OutputResult() { - auto scope = TraceFinish(); + Status OutputResult(bool is_last) { // To simplify merging, ensure that the first grouper is nonempty for (size_t i = 0; i < local_states_.size(); i++) { if (local_states_[i].grouper) { @@ -584,9 +796,18 @@ class GroupByNode : public ExecNode, public TracedNode { ARROW_ASSIGN_OR_RAISE(out_data_, Finalize()); int64_t num_output_batches = bit_util::CeilDiv(out_data_.length, output_batch_size()); - RETURN_NOT_OK(output_->InputFinished(this, static_cast(num_output_batches))); - return plan_->query_context()->StartTaskGroup(output_task_group_id_, - num_output_batches); + total_output_batches_ += static_cast(num_output_batches); + if (is_last) { + ARROW_RETURN_NOT_OK(output_->InputFinished(this, total_output_batches_)); + RETURN_NOT_OK(plan_->query_context()->StartTaskGroup(output_task_group_id_, + num_output_batches)); + } else { + for (int64_t i = 0; i < num_output_batches; i++) { + ARROW_RETURN_NOT_OK(OutputNthBatch(i)); + } + ARROW_RETURN_NOT_OK(ResetKernelStates()); + } + return Status::OK(); } Status InputReceived(ExecNode* input, ExecBatch batch) override { @@ -594,19 +815,31 @@ class GroupByNode : public ExecNode, public TracedNode { DCHECK_EQ(input, inputs_[0]); - ARROW_RETURN_NOT_OK(Consume(ExecSpan(batch))); + auto handler = [this](const ExecBatch& full_batch, const Segment& segment) { + if (!segment.extends && segment.offset == 0) 
RETURN_NOT_OK(OutputResult(false)); + auto exec_batch = full_batch.Slice(segment.offset, segment.length); + auto batch = ExecSpan(exec_batch); + RETURN_NOT_OK(Consume(batch)); + RETURN_NOT_OK( + ExtractSegmenterValues(&segmenter_values_, exec_batch, segment_key_field_ids_)); + if (!segment.is_open) RETURN_NOT_OK(OutputResult(false)); + return Status::OK(); + }; + ARROW_RETURN_NOT_OK( + HandleSegments(segmenter_.get(), batch, segment_key_field_ids_, handler)); if (input_counter_.Increment()) { - return OutputResult(); + ARROW_RETURN_NOT_OK(OutputResult(/*is_last=*/true)); } return Status::OK(); } Status InputFinished(ExecNode* input, int total_batches) override { + auto scope = TraceFinish(); DCHECK_EQ(input, inputs_[0]); if (input_counter_.SetTotal(total_batches)) { - return OutputResult(); + RETURN_NOT_OK(OutputResult(/*is_last=*/true)); } return Status::OK(); } @@ -619,12 +852,12 @@ class GroupByNode : public ExecNode, public TracedNode { void PauseProducing(ExecNode* output, int32_t counter) override { // TODO(ARROW-16260) - // Without spillover there is way to handle backpressure in this node + // Without spillover there is no way to handle backpressure in this node } void ResumeProducing(ExecNode* output, int32_t counter) override { // TODO(ARROW-16260) - // Without spillover there is way to handle backpressure in this node + // Without spillover there is no way to handle backpressure in this node } Status StopProducingImpl() override { return Status::OK(); } @@ -697,13 +930,23 @@ class GroupByNode : public ExecNode, public TracedNode { } int output_task_group_id_; + /// \brief A segmenter for the segment-keys + std::unique_ptr segmenter_; + /// \brief Holds values of the current batch that were selected for the segment-keys + std::vector segmenter_values_; const std::vector key_field_ids_; + /// \brief Field indices corresponding to the segment-keys + const std::vector segment_key_field_ids_; + /// \brief Types of input fields per aggregate + const std::vector> agg_src_types_; const std::vector> agg_src_fieldsets_; const std::vector aggs_; const std::vector agg_kernels_; AtomicCounter input_counter_; + /// \brief Total number of output batches produced + int total_output_batches_ = 0; std::vector local_states_; ExecBatch out_data_; diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index c2738945c27cb..c63e861a6154a 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -196,8 +196,7 @@ class ARROW_EXPORT ExecNode { /// concurrently, potentially even before the call to StartProducing /// has finished. /// - PauseProducing(), ResumeProducing(), StopProducing() may be called - /// by the downstream nodes' InputReceived(), ErrorReceived(), InputFinished() - /// methods + /// by the downstream nodes' InputReceived(), InputFinished() methods /// /// StopProducing may be called due to an error, by the user (e.g. cancel), or /// because a node has all the data it needs (e.g. limit, top-k on sorted data). diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index bd2bbcb8e64a0..419990407d029 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -199,21 +199,37 @@ class ARROW_EXPORT ProjectNodeOptions : public ExecNodeOptions { std::vector names; }; -/// \brief Make a node which aggregates input batches, optionally grouped by keys. 
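Before the revised AggregateNodeOptions documentation below, a hedged sketch of the GroupByNode path above (field names and the "hash_sum" target are invented, not taken from the patch). Keys and segment keys must be disjoint, and the output columns are ordered aggregates first, then keys, then segment keys:

    // Sketch: hash aggregation grouped by "b" within consecutive segments of "a".
    // Resulting output columns, in order: "sum(c)", "b", "a".
    arrow::compute::AggregateNodeOptions MakeSegmentedHashAggOptions() {
      using namespace arrow::compute;
      // Passing the same field as both a key and a segment key is rejected by
      // GroupByNode::Make with Status::Invalid.
      return AggregateNodeOptions{/*aggregates=*/{{"hash_sum", nullptr, "c", "sum(c)"}},
                                  /*keys=*/{"b"},
                                  /*segment_keys=*/{"a"}};
    }

One or more output batches are produced per segment of "a" (via OutputResult(false)), and the final InputFinished reports the accumulated total_output_batches_ once the input is exhausted.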
+/// \brief Make a node which aggregates input batches, optionally grouped by keys and +/// optionally segmented by segment-keys. Both keys and segment-keys determine the group. +/// However segment-keys are also used for determining grouping segments, which should be +/// large, and allow streaming a partial aggregation result after processing each segment. +/// One common use-case for segment-keys is ordered aggregation, in which the segment-key +/// attribute specifies a column with non-decreasing values or a lexicographically-ordered +/// set of such columns. /// /// If the keys attribute is a non-empty vector, then each aggregate in `aggregates` is /// expected to be a HashAggregate function. If the keys attribute is an empty vector, /// then each aggregate is assumed to be a ScalarAggregate function. +/// +/// If the segment_keys attribute is a non-empty vector, then segmented aggregation, as +/// described above, applies. +/// +/// The keys and segment_keys vectors must be disjoint. class ARROW_EXPORT AggregateNodeOptions : public ExecNodeOptions { public: explicit AggregateNodeOptions(std::vector aggregates, - std::vector keys = {}) - : aggregates(std::move(aggregates)), keys(std::move(keys)) {} + std::vector keys = {}, + std::vector segment_keys = {}) + : aggregates(std::move(aggregates)), + keys(std::move(keys)), + segment_keys(std::move(segment_keys)) {} // aggregations which will be applied to the targetted fields std::vector aggregates; - // keys by which aggregations will be grouped + // keys by which aggregations will be grouped (optional) std::vector keys; + // keys by which aggregations will be segmented (optional) + std::vector segment_keys; }; constexpr int32_t kDefaultBackpressureHighBytes = 1 << 30; // 1GiB diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index 5b2af718df73b..66cfa2563b623 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -1476,5 +1476,108 @@ TEST(ExecPlan, SourceEnforcesBatchLimit) { } } +TEST(ExecPlanExecution, SegmentedAggregationWithMultiThreading) { + BatchesWithSchema data; + data.batches = {ExecBatchFromJSON({int32()}, "[[1]]")}; + data.schema = schema({field("i32", int32())}); + Declaration plan = Declaration::Sequence( + {{"source", + SourceNodeOptions{data.schema, data.gen(/*parallel=*/false, /*slow=*/false)}}, + {"aggregate", AggregateNodeOptions{/*aggregates=*/{ + {"count", nullptr, "i32", "count(i32)"}, + }, + /*keys=*/{"i32"}, /*segment_leys=*/{"i32"}}}}); + EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, HasSubstr("multi-threaded"), + DeclarationToExecBatches(std::move(plan))); +} + +TEST(ExecPlanExecution, SegmentedAggregationWithOneSegment) { + BatchesWithSchema data; + data.batches = { + ExecBatchFromJSON({int32(), int32(), int32()}, "[[1, 1, 1], [1, 2, 1], [1, 1, 2]]"), + ExecBatchFromJSON({int32(), int32(), int32()}, + "[[1, 2, 2], [1, 1, 3], [1, 2, 3]]")}; + data.schema = schema({ + field("a", int32()), + field("b", int32()), + field("c", int32()), + }); + + Declaration plan = Declaration::Sequence( + {{"source", + SourceNodeOptions{data.schema, data.gen(/*parallel=*/false, /*slow=*/false)}}, + {"aggregate", AggregateNodeOptions{/*aggregates=*/{ + {"hash_sum", nullptr, "c", "sum(c)"}, + {"hash_mean", nullptr, "c", "mean(c)"}, + }, + /*keys=*/{"b"}, /*segment_leys=*/{"a"}}}}); + ASSERT_OK_AND_ASSIGN(BatchesWithCommonSchema actual_batches, + DeclarationToExecBatches(std::move(plan), /*use_threads=*/false)); + + auto expected = 
ExecBatchFromJSON({int64(), float64(), int32(), int32()}, + R"([[6, 2, 1, 1], [6, 2, 2, 1]])"); + AssertExecBatchesEqualIgnoringOrder(actual_batches.schema, actual_batches.batches, + {expected}); +} + +TEST(ExecPlanExecution, SegmentedAggregationWithTwoSegments) { + BatchesWithSchema data; + data.batches = { + ExecBatchFromJSON({int32(), int32(), int32()}, "[[1, 1, 1], [1, 2, 1], [1, 1, 2]]"), + ExecBatchFromJSON({int32(), int32(), int32()}, + "[[2, 2, 2], [2, 1, 3], [2, 2, 3]]")}; + data.schema = schema({ + field("a", int32()), + field("b", int32()), + field("c", int32()), + }); + + Declaration plan = Declaration::Sequence( + {{"source", + SourceNodeOptions{data.schema, data.gen(/*parallel=*/false, /*slow=*/false)}}, + {"aggregate", AggregateNodeOptions{/*aggregates=*/{ + {"hash_sum", nullptr, "c", "sum(c)"}, + {"hash_mean", nullptr, "c", "mean(c)"}, + }, + /*keys=*/{"b"}, /*segment_leys=*/{"a"}}}}); + ASSERT_OK_AND_ASSIGN(BatchesWithCommonSchema actual_batches, + DeclarationToExecBatches(std::move(plan), /*use_threads=*/false)); + + auto expected = ExecBatchFromJSON( + {int64(), float64(), int32(), int32()}, + R"([[3, 1.5, 1, 1], [1, 1, 2, 1], [3, 3, 1, 2], [5, 2.5, 2, 2]])"); + AssertExecBatchesEqualIgnoringOrder(actual_batches.schema, actual_batches.batches, + {expected}); +} + +TEST(ExecPlanExecution, SegmentedAggregationWithBatchCrossingSegment) { + BatchesWithSchema data; + data.batches = { + ExecBatchFromJSON({int32(), int32(), int32()}, "[[1, 1, 1], [1, 1, 1], [2, 2, 2]]"), + ExecBatchFromJSON({int32(), int32(), int32()}, + "[[2, 2, 2], [3, 3, 3], [3, 3, 3]]")}; + data.schema = schema({ + field("a", int32()), + field("b", int32()), + field("c", int32()), + }); + + Declaration plan = Declaration::Sequence( + {{"source", + SourceNodeOptions{data.schema, data.gen(/*parallel=*/false, /*slow=*/false)}}, + {"aggregate", AggregateNodeOptions{/*aggregates=*/{ + {"hash_sum", nullptr, "c", "sum(c)"}, + {"hash_mean", nullptr, "c", "mean(c)"}, + }, + /*keys=*/{"b"}, /*segment_leys=*/{"a"}}}}); + ASSERT_OK_AND_ASSIGN(BatchesWithCommonSchema actual_batches, + DeclarationToExecBatches(std::move(plan), /*use_threads=*/false)); + + auto expected = ExecBatchFromJSON({int64(), float64(), int32(), int32()}, + R"([[2, 1, 1, 1], [4, 2, 2, 2], [6, 3, 3, 3]])"); + AssertExecBatchesEqualIgnoringOrder(actual_batches.schema, actual_batches.batches, + {expected}); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc index 3c3476d62de08..fd631e0dc513a 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc @@ -42,6 +42,7 @@ #include "arrow/compute/kernels/test_util.h" #include "arrow/compute/registry.h" #include "arrow/compute/row/grouper.h" +#include "arrow/compute/row/grouper_internal.h" #include "arrow/table.h" #include "arrow/testing/generator.h" #include "arrow/testing/gtest_util.h" @@ -72,6 +73,10 @@ using internal::ToChars; namespace compute { namespace { +using GroupByFunction = std::function( + const std::vector&, const std::vector&, const std::vector&, + const std::vector&, bool, bool)>; + Result NaiveGroupBy(std::vector arguments, std::vector keys, const std::vector& aggregates) { ARROW_ASSIGN_OR_RAISE(auto key_batch, ExecBatch::Make(std::move(keys))); @@ -135,22 +140,99 @@ Result NaiveGroupBy(std::vector arguments, std::vector keys return Take(struct_arr, sorted_indices); } +Result MakeGroupByOutput(const 
std::vector& output_batches, + const std::shared_ptr output_schema, + size_t num_aggregates, size_t num_keys, bool naive) { + ArrayVector out_arrays(num_aggregates + num_keys); + for (size_t i = 0; i < out_arrays.size(); ++i) { + std::vector> arrays(output_batches.size()); + for (size_t j = 0; j < output_batches.size(); ++j) { + arrays[j] = output_batches[j].values[i].make_array(); + } + if (arrays.empty()) { + ARROW_ASSIGN_OR_RAISE( + out_arrays[i], + MakeArrayOfNull(output_schema->field(static_cast(i))->type(), + /*length=*/0)); + } else { + ARROW_ASSIGN_OR_RAISE(out_arrays[i], Concatenate(arrays)); + } + } + + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr struct_arr, + StructArray::Make(std::move(out_arrays), output_schema->fields())); + + bool need_sort = !naive; + for (size_t i = num_aggregates; need_sort && i < out_arrays.size(); i++) { + if (output_schema->field(static_cast(i))->type()->id() == Type::DICTIONARY) { + need_sort = false; + } + } + if (!need_sort) { + return struct_arr; + } + + // The exec plan may reorder the output rows. The tests are all setup to expect ouptut + // in ascending order of keys. So we need to sort the result by the key columns. To do + // that we create a table using the key columns, calculate the sort indices from that + // table (sorting on all fields) and then use those indices to calculate our result. + std::vector> key_fields; + std::vector> key_columns; + std::vector sort_keys; + for (std::size_t i = 0; i < num_keys; i++) { + const std::shared_ptr& arr = out_arrays[i + num_aggregates]; + key_columns.push_back(arr); + key_fields.push_back(field("name_does_not_matter", arr->type())); + sort_keys.emplace_back(static_cast(i)); + } + std::shared_ptr key_schema = schema(std::move(key_fields)); + std::shared_ptr key_table = Table::Make(std::move(key_schema), key_columns); + SortOptions sort_options(std::move(sort_keys)); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr sort_indices, + SortIndices(key_table, sort_options)); + + return Take(struct_arr, sort_indices); +} + Result RunGroupBy(const BatchesWithSchema& input, const std::vector& key_names, - const std::vector& aggregates, bool use_threads) { + const std::vector& segment_key_names, + const std::vector& aggregates, ExecContext* ctx, + bool use_threads, bool segmented = false, bool naive = false) { + // The `use_threads` flag determines whether threads are used in generating the input to + // the group-by. + // + // When segment_keys is non-empty the `segmented` flag is always true; otherwise (when + // empty), it may still be set to true. In this case, the tester restructures (without + // changing the data of) the result of RunGroupBy from `std::vector` + // (output_batches) to `std::vector` (out_arrays), which have the structure + // typical of the case of a non-empty segment_keys (with multiple arrays per column, one + // array per segment) but only one array per column (because, technically, there is only + // one segment in this case). Thus, this case focuses on the structure of the result. + // + // The `naive` flag means that the output is expected to be like that of `NaiveGroupBy`, + // which in particular doesn't require sorting. The reason for the naive flag is that + // the expected output of some test-cases is naive and of some others it is not. The + // current `RunGroupBy` function deals with both kinds of expected output. 
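For concreteness, a hypothetical invocation of this helper with one grouping key and one segment key — `input` is assumed to be a BatchesWithSchema prepared elsewhere, and the column names follow the "agg_N"/"key_N" convention used by the Datum-based overload further down:

    // Sketch: run the test helper with a grouping key and a segment key.
    std::vector<Aggregate> aggregates = {{"hash_sum", nullptr, "agg_0", "sum"}};
    ASSERT_OK_AND_ASSIGN(
        Datum out,
        RunGroupBy(input, /*key_names=*/{"key_0"}, /*segment_key_names=*/{"key_1"},
                   aggregates, /*use_threads=*/false, /*segmented=*/true));
    // With segment keys present, the result is a ChunkedArray of struct arrays,
    // one chunk per segment, rather than a single concatenated StructArray.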
std::vector keys(key_names.size()); for (size_t i = 0; i < key_names.size(); ++i) { keys[i] = FieldRef(key_names[i]); } + std::vector segment_keys(segment_key_names.size()); + for (size_t i = 0; i < segment_key_names.size(); ++i) { + segment_keys[i] = FieldRef(segment_key_names[i]); + } - ARROW_ASSIGN_OR_RAISE(auto plan, ExecPlan::Make(*threaded_exec_context())); + ARROW_ASSIGN_OR_RAISE(auto plan, ExecPlan::Make(*ctx)); AsyncGenerator> sink_gen; RETURN_NOT_OK( Declaration::Sequence( { {"source", SourceNodeOptions{input.schema, input.gen(use_threads, /*slow=*/false)}}, - {"aggregate", AggregateNodeOptions{std::move(aggregates), std::move(keys)}}, + {"aggregate", AggregateNodeOptions{std::move(aggregates), std::move(keys), + std::move(segment_keys)}}, {"sink", SinkNodeOptions{&sink_gen}}, }) .AddToPlan(plan.get())); @@ -174,81 +256,117 @@ Result RunGroupBy(const BatchesWithSchema& input, ARROW_ASSIGN_OR_RAISE(std::vector output_batches, start_and_collect.MoveResult()); - ArrayVector out_arrays(aggregates.size() + key_names.size()); const auto& output_schema = plan->nodes()[0]->output()->output_schema(); + if (!segmented) { + return MakeGroupByOutput(output_batches, output_schema, aggregates.size(), + key_names.size(), naive); + } + + std::vector out_arrays(aggregates.size() + key_names.size() + + segment_key_names.size()); for (size_t i = 0; i < out_arrays.size(); ++i) { std::vector> arrays(output_batches.size()); for (size_t j = 0; j < output_batches.size(); ++j) { - arrays[j] = output_batches[j].values[i].make_array(); + auto& value = output_batches[j].values[i]; + if (value.is_scalar()) { + ARROW_ASSIGN_OR_RAISE( + arrays[j], MakeArrayFromScalar(*value.scalar(), output_batches[j].length)); + } else if (value.is_array()) { + arrays[j] = value.make_array(); + } else { + return Status::Invalid("GroupByUsingExecPlan unsupported value kind ", + ToString(value.kind())); + } } if (arrays.empty()) { + arrays.resize(1); ARROW_ASSIGN_OR_RAISE( - out_arrays[i], - MakeArrayOfNull(output_schema->field(static_cast(i))->type(), - /*length=*/0)); - } else { - ARROW_ASSIGN_OR_RAISE(out_arrays[i], Concatenate(arrays)); + arrays[0], MakeArrayOfNull(output_schema->field(static_cast(i))->type(), + /*length=*/0)); } + out_arrays[i] = {std::move(arrays)}; } - // The exec plan may reorder the output rows. The tests are all setup to expect ouptut - // in ascending order of keys. So we need to sort the result by the key columns. To do - // that we create a table using the key columns, calculate the sort indices from that - // table (sorting on all fields) and then use those indices to calculate our result. 
- std::vector> key_fields; - std::vector> key_columns; - std::vector sort_keys; - for (std::size_t i = 0; i < key_names.size(); i++) { - const std::shared_ptr& arr = out_arrays[i + aggregates.size()]; - if (arr->type_id() == Type::DICTIONARY) { - // Can't sort dictionary columns so need to decode - auto dict_arr = checked_pointer_cast(arr); - ARROW_ASSIGN_OR_RAISE(auto decoded_arr, - Take(*dict_arr->dictionary(), *dict_arr->indices())); - key_columns.push_back(decoded_arr); - key_fields.push_back( - field("name_does_not_matter", dict_arr->dict_type()->value_type())); - } else { - key_columns.push_back(arr); - key_fields.push_back(field("name_does_not_matter", arr->type())); + if (segmented && segment_key_names.size() > 0) { + ArrayVector struct_arrays; + struct_arrays.reserve(output_batches.size()); + for (size_t j = 0; j < output_batches.size(); ++j) { + ArrayVector struct_fields; + struct_fields.reserve(out_arrays.size()); + for (auto out_array : out_arrays) { + struct_fields.push_back(out_array[j]); + } + ARROW_ASSIGN_OR_RAISE(auto struct_array, + StructArray::Make(struct_fields, output_schema->fields())); + struct_arrays.push_back(struct_array); } - sort_keys.emplace_back(static_cast(i)); + return ChunkedArray::Make(struct_arrays); + } else { + ArrayVector struct_fields(out_arrays.size()); + for (size_t i = 0; i < out_arrays.size(); ++i) { + ARROW_ASSIGN_OR_RAISE(struct_fields[i], Concatenate(out_arrays[i])); + } + return StructArray::Make(std::move(struct_fields), output_schema->fields()); } - std::shared_ptr key_schema = schema(std::move(key_fields)); - std::shared_ptr
key_table = Table::Make(std::move(key_schema), key_columns); - SortOptions sort_options(std::move(sort_keys)); - ARROW_ASSIGN_OR_RAISE(std::shared_ptr sort_indices, - SortIndices(key_table, sort_options)); +} - ARROW_ASSIGN_OR_RAISE( - std::shared_ptr struct_arr, - StructArray::Make(std::move(out_arrays), output_schema->fields())); +Result RunGroupBy(const BatchesWithSchema& input, + const std::vector& key_names, + const std::vector& segment_key_names, + const std::vector& aggregates, bool use_threads, + bool segmented = false, bool naive = false) { + if (segment_key_names.size() > 0) { + ARROW_ASSIGN_OR_RAISE(auto thread_pool, arrow::internal::ThreadPool::Make(1)); + ExecContext seq_ctx(default_memory_pool(), thread_pool.get()); + return RunGroupBy(input, key_names, segment_key_names, aggregates, &seq_ctx, + use_threads, segmented, naive); + } else { + return RunGroupBy(input, key_names, segment_key_names, aggregates, + threaded_exec_context(), use_threads, segmented, naive); + } +} - return Take(struct_arr, sort_indices); +Result RunGroupBy(const BatchesWithSchema& input, + const std::vector& key_names, + const std::vector& aggregates, bool use_threads, + bool segmented = false, bool naive = false) { + return RunGroupBy(input, key_names, {}, aggregates, use_threads, segmented); } /// Simpler overload where you can give the columns as datums Result RunGroupBy(const std::vector& arguments, const std::vector& keys, - const std::vector& aggregates, - bool use_threads = false) { + const std::vector& segment_keys, + const std::vector& aggregates, bool use_threads, + bool segmented = false, bool naive = false) { using arrow::compute::detail::ExecSpanIterator; - FieldVector scan_fields(arguments.size() + keys.size()); + FieldVector scan_fields(arguments.size() + keys.size() + segment_keys.size()); std::vector key_names(keys.size()); + std::vector segment_key_names(segment_keys.size()); for (size_t i = 0; i < arguments.size(); ++i) { auto name = std::string("agg_") + ToChars(i); scan_fields[i] = field(name, arguments[i].type()); } + size_t base = arguments.size(); for (size_t i = 0; i < keys.size(); ++i) { auto name = std::string("key_") + ToChars(i); - scan_fields[arguments.size() + i] = field(name, keys[i].type()); + scan_fields[base + i] = field(name, keys[i].type()); key_names[i] = std::move(name); } + base += keys.size(); + size_t j = keys.size(); + std::string prefix("key_"); + for (size_t i = 0; i < segment_keys.size(); ++i) { + auto name = prefix + std::to_string(j++); + scan_fields[base + i] = field(name, segment_keys[i].type()); + segment_key_names[i] = std::move(name); + } std::vector inputs = arguments; - inputs.reserve(inputs.size() + keys.size()); + inputs.reserve(inputs.size() + keys.size() + segment_keys.size()); inputs.insert(inputs.end(), keys.begin(), keys.end()); + inputs.insert(inputs.end(), segment_keys.begin(), segment_keys.end()); ExecSpanIterator span_iterator; ARROW_ASSIGN_OR_RAISE(auto batch, ExecBatch::Make(inputs)); @@ -261,15 +379,35 @@ Result RunGroupBy(const std::vector& arguments, input.batches.push_back(span.ToExecBatch()); } - return RunGroupBy(input, key_names, aggregates, use_threads); + return RunGroupBy(input, key_names, segment_key_names, aggregates, use_threads, + segmented, naive); +} + +Result RunGroupByImpl(const std::vector& arguments, + const std::vector& keys, + const std::vector& segment_keys, + const std::vector& aggregates, bool use_threads, + bool naive = false) { + return RunGroupBy(arguments, keys, segment_keys, aggregates, use_threads, + 
/*segmented=*/false, naive); } -void ValidateGroupBy(const std::vector& aggregates, - std::vector arguments, std::vector keys) { +Result RunSegmentedGroupByImpl(const std::vector& arguments, + const std::vector& keys, + const std::vector& segment_keys, + const std::vector& aggregates, + bool use_threads, bool naive = false) { + return RunGroupBy(arguments, keys, segment_keys, aggregates, use_threads, + /*segmented=*/true, naive); +} + +void ValidateGroupBy(GroupByFunction group_by, const std::vector& aggregates, + std::vector arguments, std::vector keys, + bool naive = true) { ASSERT_OK_AND_ASSIGN(Datum expected, NaiveGroupBy(arguments, keys, aggregates)); - ASSERT_OK_AND_ASSIGN(Datum actual, RunGroupBy(arguments, keys, aggregates, - /*use_threads=*/false)); + ASSERT_OK_AND_ASSIGN(Datum actual, group_by(arguments, keys, {}, aggregates, + /*use_threads=*/false, naive)); ASSERT_OK(expected.make_array()->ValidateFull()); ValidateOutput(actual); @@ -290,8 +428,9 @@ struct TestAggregate { std::shared_ptr options; }; -Result GroupByTest(const std::vector& arguments, +Result GroupByTest(GroupByFunction group_by, const std::vector& arguments, const std::vector& keys, + const std::vector& segment_keys, const std::vector& aggregates, bool use_threads) { std::vector internal_aggregates; @@ -301,27 +440,36 @@ Result GroupByTest(const std::vector& arguments, {t_agg.function, t_agg.options, "agg_" + ToChars(idx), t_agg.function}); idx = idx + 1; } - return RunGroupBy(arguments, keys, internal_aggregates, use_threads); + return group_by(arguments, keys, segment_keys, internal_aggregates, use_threads, + /*naive=*/false); } -} // namespace +Result GroupByTest(GroupByFunction group_by, const std::vector& arguments, + const std::vector& keys, + const std::vector& aggregates, + bool use_threads) { + return GroupByTest(group_by, arguments, keys, {}, aggregates, use_threads); +} -TEST(Grouper, SupportedKeys) { - ASSERT_OK(Grouper::Make({boolean()})); +template +void TestGroupClassSupportedKeys( + std::function>(const std::vector&)> + make_func) { + ASSERT_OK(make_func({boolean()})); - ASSERT_OK(Grouper::Make({int8(), uint16(), int32(), uint64()})); + ASSERT_OK(make_func({int8(), uint16(), int32(), uint64()})); - ASSERT_OK(Grouper::Make({dictionary(int64(), utf8())})); + ASSERT_OK(make_func({dictionary(int64(), utf8())})); - ASSERT_OK(Grouper::Make({float16(), float32(), float64()})); + ASSERT_OK(make_func({float16(), float32(), float64()})); - ASSERT_OK(Grouper::Make({utf8(), binary(), large_utf8(), large_binary()})); + ASSERT_OK(make_func({utf8(), binary(), large_utf8(), large_binary()})); - ASSERT_OK(Grouper::Make({fixed_size_binary(16), fixed_size_binary(32)})); + ASSERT_OK(make_func({fixed_size_binary(16), fixed_size_binary(32)})); - ASSERT_OK(Grouper::Make({decimal128(32, 10), decimal256(76, 20)})); + ASSERT_OK(make_func({decimal128(32, 10), decimal256(76, 20)})); - ASSERT_OK(Grouper::Make({date32(), date64()})); + ASSERT_OK(make_func({date32(), date64()})); for (auto unit : { TimeUnit::SECOND, @@ -329,25 +477,257 @@ TEST(Grouper, SupportedKeys) { TimeUnit::MICRO, TimeUnit::NANO, }) { - ASSERT_OK(Grouper::Make({timestamp(unit), duration(unit)})); + ASSERT_OK(make_func({timestamp(unit), duration(unit)})); } ASSERT_OK( - Grouper::Make({day_time_interval(), month_interval(), month_day_nano_interval()})); + make_func({day_time_interval(), month_interval(), month_day_nano_interval()})); + + ASSERT_OK(make_func({null()})); - ASSERT_OK(Grouper::Make({null()})); + ASSERT_RAISES(NotImplemented, 
make_func({struct_({field("", int64())})})); - ASSERT_RAISES(NotImplemented, Grouper::Make({struct_({field("", int64())})})); + ASSERT_RAISES(NotImplemented, make_func({struct_({})})); - ASSERT_RAISES(NotImplemented, Grouper::Make({struct_({})})); + ASSERT_RAISES(NotImplemented, make_func({list(int32())})); - ASSERT_RAISES(NotImplemented, Grouper::Make({list(int32())})); + ASSERT_RAISES(NotImplemented, make_func({fixed_size_list(int32(), 5)})); - ASSERT_RAISES(NotImplemented, Grouper::Make({fixed_size_list(int32(), 5)})); + ASSERT_RAISES(NotImplemented, make_func({dense_union({field("", int32())})})); +} + +void TestSegments(std::unique_ptr& segmenter, const ExecSpan& batch, + std::vector expected_segments) { + int64_t offset = 0, segment_num = 0; + for (auto expected_segment : expected_segments) { + SCOPED_TRACE("segment #" + ToChars(segment_num++)); + ASSERT_OK_AND_ASSIGN(auto segment, segmenter->GetNextSegment(batch, offset)); + ASSERT_EQ(expected_segment, segment); + offset = segment.offset + segment.length; + } +} - ASSERT_RAISES(NotImplemented, Grouper::Make({dense_union({field("", int32())})})); +Result> MakeGrouper(const std::vector& key_types) { + return Grouper::Make(key_types, default_exec_context()); +} + +Result> MakeRowSegmenter( + const std::vector& key_types) { + return RowSegmenter::Make(key_types, /*nullable_leys=*/false, default_exec_context()); +} + +Result> MakeGenericSegmenter( + const std::vector& key_types) { + return MakeAnyKeysSegmenter(key_types, default_exec_context()); +} + +} // namespace + +TEST(RowSegmenter, SupportedKeys) { + TestGroupClassSupportedKeys(MakeRowSegmenter); +} + +TEST(RowSegmenter, Basics) { + std::vector bad_types2 = {int32(), float32()}; + std::vector types2 = {int32(), int32()}; + std::vector bad_types1 = {float32()}; + std::vector types1 = {int32()}; + std::vector types0 = {}; + auto batch2 = ExecBatchFromJSON(types2, "[[1, 1], [1, 2], [2, 2]]"); + auto batch1 = ExecBatchFromJSON(types1, "[[1], [1], [2]]"); + ExecBatch batch0({}, 3); + { + SCOPED_TRACE("offset"); + ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(types0)); + ExecSpan span0(batch0); + for (int64_t offset : {-1, 4}) { + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + HasSubstr("invalid grouping segmenter offset"), + segmenter->GetNextSegment(span0, offset)); + } + } + { + SCOPED_TRACE("types0 segmenting of batch2"); + ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(types0)); + ExecSpan span2(batch2); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("expected batch size 0 "), + segmenter->GetNextSegment(span2, 0)); + ExecSpan span0(batch0); + TestSegments(segmenter, span0, {{0, 3, true, true}, {3, 0, true, true}}); + } + { + SCOPED_TRACE("bad_types1 segmenting of batch1"); + ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(bad_types1)); + ExecSpan span1(batch1); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("expected batch value 0 of type "), + segmenter->GetNextSegment(span1, 0)); + } + { + SCOPED_TRACE("types1 segmenting of batch2"); + ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(types1)); + ExecSpan span2(batch2); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("expected batch size 1 "), + segmenter->GetNextSegment(span2, 0)); + ExecSpan span1(batch1); + TestSegments(segmenter, span1, + {{0, 2, false, true}, {2, 1, true, false}, {3, 0, true, true}}); + } + { + SCOPED_TRACE("bad_types2 segmenting of batch2"); + ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(bad_types2)); + ExecSpan span2(batch2); + 
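The expectations in the remaining cases here, and in the EmptyBatches test further down, hinge on how segments continue across batch boundaries. A brief hedged aside using the helpers defined above with invented data:

    // Sketch: a segment that stays open across batches.
    ASSERT_OK_AND_ASSIGN(auto seg, MakeRowSegmenter({int32()}));
    auto first = ExecBatchFromJSON({int32()}, "[[1], [1]]");
    auto second = ExecBatchFromJSON({int32()}, "[[1], [2]]");
    // The first batch ends while key 1 is still running, so its only segment is open.
    TestSegments(seg, ExecSpan(first), {{0, 2, true, true}});
    // The next batch starts with the same key value, so its first segment extends the
    // open one; the aggregation nodes keep accumulating while segments extend each
    // other and emit/reset only when a segment closes or a non-extending segment starts.
    TestSegments(seg, ExecSpan(second), {{0, 1, false, true}, {1, 1, true, false}});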
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("expected batch value 1 of type "), + segmenter->GetNextSegment(span2, 0)); + } + { + SCOPED_TRACE("types2 segmenting of batch1"); + ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(types2)); + ExecSpan span1(batch1); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("expected batch size 2 "), + segmenter->GetNextSegment(span1, 0)); + ExecSpan span2(batch2); + TestSegments(segmenter, span2, + {{0, 1, false, true}, + {1, 1, false, false}, + {2, 1, true, false}, + {3, 0, true, true}}); + } +} + +TEST(RowSegmenter, NonOrdered) { + std::vector types = {int32()}; + auto batch = ExecBatchFromJSON(types, "[[1], [1], [2], [1], [2]]"); + ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(types)); + TestSegments(segmenter, ExecSpan(batch), + {{0, 2, false, true}, + {2, 1, false, false}, + {3, 1, false, false}, + {4, 1, true, false}, + {5, 0, true, true}}); +} + +TEST(RowSegmenter, EmptyBatches) { + std::vector types = {int32()}; + std::vector batches = { + ExecBatchFromJSON(types, "[]"), ExecBatchFromJSON(types, "[]"), + ExecBatchFromJSON(types, "[[1]]"), ExecBatchFromJSON(types, "[]"), + ExecBatchFromJSON(types, "[[1]]"), ExecBatchFromJSON(types, "[]"), + ExecBatchFromJSON(types, "[[2], [2]]"), ExecBatchFromJSON(types, "[]"), + }; + ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(types)); + TestSegments(segmenter, ExecSpan(batches[0]), {}); + TestSegments(segmenter, ExecSpan(batches[1]), {}); + TestSegments(segmenter, ExecSpan(batches[2]), {{0, 1, true, true}}); + TestSegments(segmenter, ExecSpan(batches[3]), {}); + TestSegments(segmenter, ExecSpan(batches[4]), {{0, 1, true, true}}); + TestSegments(segmenter, ExecSpan(batches[5]), {}); + TestSegments(segmenter, ExecSpan(batches[6]), {{0, 2, true, false}}); + TestSegments(segmenter, ExecSpan(batches[7]), {}); +} + +TEST(RowSegmenter, MultipleSegments) { + std::vector types = {int32()}; + auto batch = ExecBatchFromJSON(types, "[[1], [1], [2], [5], [3], [3], [5], [5], [4]]"); + ASSERT_OK_AND_ASSIGN(auto segmenter, MakeRowSegmenter(types)); + TestSegments(segmenter, ExecSpan(batch), + {{0, 2, false, true}, + {2, 1, false, false}, + {3, 1, false, false}, + {4, 2, false, false}, + {6, 2, false, false}, + {8, 1, true, false}, + {9, 0, true, true}}); +} + +namespace { + +void TestRowSegmenterConstantBatch( + std::function shape_func, + std::function>(const std::vector&)> + make_segmenter) { + constexpr size_t n = 3, repetitions = 3; + std::vector types = {int32(), int32(), int32()}; + std::vector shapes(n); + for (size_t i = 0; i < n; i++) shapes[i] = shape_func(i); + auto full_batch = ExecBatchFromJSON(types, shapes, "[[1, 1, 1], [1, 1, 1], [1, 1, 1]]"); + auto test_by_size = [&](size_t size) -> Status { + SCOPED_TRACE("constant-batch with " + ToChars(size) + " key(s)"); + std::vector values(full_batch.values.begin(), + full_batch.values.begin() + size); + ExecBatch batch(values, full_batch.length); + std::vector key_types(types.begin(), types.begin() + size); + ARROW_ASSIGN_OR_RAISE(auto segmenter, make_segmenter(key_types)); + for (size_t i = 0; i < repetitions; i++) { + TestSegments(segmenter, ExecSpan(batch), {{0, 3, true, true}, {3, 0, true, true}}); + ARROW_RETURN_NOT_OK(segmenter->Reset()); + } + return Status::OK(); + }; + for (size_t i = 0; i <= 3; i++) { + ASSERT_OK(test_by_size(i)); + } } +} // namespace + +TEST(RowSegmenter, ConstantArrayBatch) { + TestRowSegmenterConstantBatch([](size_t i) { return ArgShape::ARRAY; }, + MakeRowSegmenter); +} + +TEST(RowSegmenter, 
ConstantScalarBatch) { + TestRowSegmenterConstantBatch([](size_t i) { return ArgShape::SCALAR; }, + MakeRowSegmenter); +} + +TEST(RowSegmenter, ConstantMixedBatch) { + TestRowSegmenterConstantBatch( + [](size_t i) { return i % 2 == 0 ? ArgShape::SCALAR : ArgShape::ARRAY; }, + MakeRowSegmenter); +} + +TEST(RowSegmenter, ConstantArrayBatchWithAnyKeysSegmenter) { + TestRowSegmenterConstantBatch([](size_t i) { return ArgShape::ARRAY; }, + MakeGenericSegmenter); +} + +TEST(RowSegmenter, ConstantScalarBatchWithAnyKeysSegmenter) { + TestRowSegmenterConstantBatch([](size_t i) { return ArgShape::SCALAR; }, + MakeGenericSegmenter); +} + +TEST(RowSegmenter, ConstantMixedBatchWithAnyKeysSegmenter) { + TestRowSegmenterConstantBatch( + [](size_t i) { return i % 2 == 0 ? ArgShape::SCALAR : ArgShape::ARRAY; }, + MakeGenericSegmenter); +} + +TEST(RowSegmenter, RowConstantBatch) { + constexpr size_t n = 3; + std::vector types = {int32(), int32(), int32()}; + auto full_batch = ExecBatchFromJSON(types, "[[1, 1, 1], [2, 2, 2], [3, 3, 3]]"); + std::vector expected_segments_for_size_0 = {{0, 3, true, true}, + {3, 0, true, true}}; + std::vector expected_segments = { + {0, 1, false, true}, {1, 1, false, false}, {2, 1, true, false}, {3, 0, true, true}}; + auto test_by_size = [&](size_t size) -> Status { + SCOPED_TRACE("constant-batch with " + ToChars(size) + " key(s)"); + std::vector values(full_batch.values.begin(), + full_batch.values.begin() + size); + ExecBatch batch(values, full_batch.length); + std::vector key_types(types.begin(), types.begin() + size); + ARROW_ASSIGN_OR_RAISE(auto segmenter, MakeRowSegmenter(key_types)); + TestSegments(segmenter, ExecSpan(batch), + size == 0 ? expected_segments_for_size_0 : expected_segments); + return Status::OK(); + }; + for (size_t i = 0; i <= n; i++) { + ASSERT_OK(test_by_size(i)); + } +} + +TEST(Grouper, SupportedKeys) { TestGroupClassSupportedKeys(MakeGrouper); } + struct TestGrouper { explicit TestGrouper(std::vector types, std::vector shapes = {}) : types_(std::move(types)), shapes_(std::move(shapes)) { @@ -783,7 +1163,49 @@ TEST(Grouper, ScalarValues) { } } -TEST(GroupBy, Errors) { +void TestSegmentKey(GroupByFunction group_by, const std::shared_ptr
& table, + Datum output, const std::vector& segment_keys); + +class GroupBy : public ::testing::TestWithParam { + public: + void ValidateGroupBy(const std::vector& aggregates, + std::vector arguments, std::vector keys, + bool naive = true) { + compute::ValidateGroupBy(GetParam(), aggregates, arguments, keys, naive); + } + + Result GroupByTest(const std::vector& arguments, + const std::vector& keys, + const std::vector& segment_keys, + const std::vector& aggregates, + bool use_threads) { + return compute::GroupByTest(GetParam(), arguments, keys, segment_keys, aggregates, + use_threads); + } + + Result GroupByTest(const std::vector& arguments, + const std::vector& keys, + const std::vector& aggregates, + bool use_threads) { + return compute::GroupByTest(GetParam(), arguments, keys, aggregates, use_threads); + } + + Result AltGroupBy(const std::vector& arguments, + const std::vector& keys, + const std::vector& segment_keys, + const std::vector& aggregates, + bool use_threads = false) { + return GetParam()(arguments, keys, segment_keys, aggregates, use_threads, + /*naive=*/false); + } + + void TestSegmentKey(const std::shared_ptr
& table, Datum output, + const std::vector& segment_keys) { + return compute::TestSegmentKey(GetParam(), table, output, segment_keys); + } +}; + +TEST_P(GroupBy, Errors) { auto batch = RecordBatchFromJSON( schema({field("argument", float64()), field("group_id", uint32())}), R"([ [1.0, 1], @@ -804,7 +1226,7 @@ TEST(GroupBy, Errors) { HasSubstr("Direct execution of HASH_AGGREGATE functions"))); } -TEST(GroupBy, NoBatches) { +TEST_P(GroupBy, NoBatches) { // Regression test for ARROW-14583: handle when no batches are // passed to the group by node before finalizing auto table = @@ -851,7 +1273,7 @@ void SortBy(std::vector names, Datum* aggregated_and_grouped) { } } // namespace -TEST(GroupBy, CountOnly) { +TEST_P(GroupBy, CountOnly) { for (bool use_threads : {true, false}) { SCOPED_TRACE(use_threads ? "parallel/merged" : "serial"); @@ -897,7 +1319,7 @@ TEST(GroupBy, CountOnly) { } } -TEST(GroupBy, CountScalar) { +TEST_P(GroupBy, CountScalar) { BatchesWithSchema input; input.batches = { ExecBatchFromJSON({int32(), int64()}, {ArgShape::SCALAR, ArgShape::ARRAY}, @@ -937,7 +1359,7 @@ TEST(GroupBy, CountScalar) { } } -TEST(GroupBy, SumOnly) { +TEST_P(GroupBy, SumOnly) { for (bool use_threads : {true, false}) { SCOPED_TRACE(use_threads ? "parallel/merged" : "serial"); @@ -983,7 +1405,7 @@ TEST(GroupBy, SumOnly) { } } -TEST(GroupBy, SumMeanProductDecimal) { +TEST_P(GroupBy, SumMeanProductDecimal) { auto in_schema = schema({ field("argument0", decimal128(3, 2)), field("argument1", decimal256(3, 2)), @@ -1057,7 +1479,7 @@ TEST(GroupBy, SumMeanProductDecimal) { } } -TEST(GroupBy, MeanOnly) { +TEST_P(GroupBy, MeanOnly) { for (bool use_threads : {true, false}) { SCOPED_TRACE(use_threads ? "parallel/merged" : "serial"); @@ -1108,7 +1530,7 @@ TEST(GroupBy, MeanOnly) { } } -TEST(GroupBy, SumMeanProductScalar) { +TEST_P(GroupBy, SumMeanProductScalar) { BatchesWithSchema input; input.batches = { ExecBatchFromJSON({int32(), int64()}, {ArgShape::SCALAR, ArgShape::ARRAY}, @@ -1146,7 +1568,7 @@ TEST(GroupBy, SumMeanProductScalar) { } } -TEST(GroupBy, VarianceAndStddev) { +TEST_P(GroupBy, VarianceAndStddev) { auto batch = RecordBatchFromJSON( schema({field("argument", int32()), field("key", int64())}), R"([ [1, 1], @@ -1170,6 +1592,7 @@ TEST(GroupBy, VarianceAndStddev) { { batch->GetColumnByName("key"), }, + {}, { {"hash_variance", nullptr}, {"hash_stddev", nullptr}, @@ -1212,6 +1635,7 @@ TEST(GroupBy, VarianceAndStddev) { { batch->GetColumnByName("key"), }, + {}, { {"hash_variance", nullptr}, {"hash_stddev", nullptr}, @@ -1243,6 +1667,7 @@ TEST(GroupBy, VarianceAndStddev) { { batch->GetColumnByName("key"), }, + {}, { {"hash_variance", variance_options}, {"hash_stddev", variance_options}, @@ -1264,7 +1689,7 @@ TEST(GroupBy, VarianceAndStddev) { /*verbose=*/true); } -TEST(GroupBy, VarianceAndStddevDecimal) { +TEST_P(GroupBy, VarianceAndStddevDecimal) { auto batch = RecordBatchFromJSON( schema({field("argument0", decimal128(3, 2)), field("argument1", decimal128(3, 2)), field("key", int64())}), @@ -1290,6 +1715,7 @@ TEST(GroupBy, VarianceAndStddevDecimal) { { batch->GetColumnByName("key"), }, + {}, { {"hash_variance", nullptr}, {"hash_stddev", nullptr}, @@ -1314,7 +1740,7 @@ TEST(GroupBy, VarianceAndStddevDecimal) { /*verbose=*/true); } -TEST(GroupBy, TDigest) { +TEST_P(GroupBy, TDigest) { auto batch = RecordBatchFromJSON( schema({field("argument", float64()), field("key", int64())}), R"([ [1, 1], @@ -1359,6 +1785,7 @@ TEST(GroupBy, TDigest) { { batch->GetColumnByName("key"), }, + {}, { {"hash_tdigest", nullptr}, 
{"hash_tdigest", options1}, @@ -1390,7 +1817,7 @@ TEST(GroupBy, TDigest) { /*verbose=*/true); } -TEST(GroupBy, TDigestDecimal) { +TEST_P(GroupBy, TDigestDecimal) { auto batch = RecordBatchFromJSON( schema({field("argument0", decimal128(3, 2)), field("argument1", decimal256(3, 2)), field("key", int64())}), @@ -1433,7 +1860,7 @@ TEST(GroupBy, TDigestDecimal) { /*verbose=*/true); } -TEST(GroupBy, ApproximateMedian) { +TEST_P(GroupBy, ApproximateMedian) { for (const auto& type : {float64(), int8()}) { auto batch = RecordBatchFromJSON(schema({field("argument", type), field("key", int64())}), R"([ @@ -1471,6 +1898,7 @@ TEST(GroupBy, ApproximateMedian) { { batch->GetColumnByName("key"), }, + {}, { {"hash_approximate_median", options}, {"hash_approximate_median", keep_nulls}, @@ -1498,7 +1926,7 @@ TEST(GroupBy, ApproximateMedian) { } } -TEST(GroupBy, StddevVarianceTDigestScalar) { +TEST_P(GroupBy, StddevVarianceTDigestScalar) { BatchesWithSchema input; input.batches = { ExecBatchFromJSON({int32(), float32(), int64()}, @@ -1547,7 +1975,7 @@ TEST(GroupBy, StddevVarianceTDigestScalar) { } } -TEST(GroupBy, VarianceOptions) { +TEST_P(GroupBy, VarianceOptions) { BatchesWithSchema input; input.batches = { ExecBatchFromJSON( @@ -1641,7 +2069,7 @@ TEST(GroupBy, VarianceOptions) { } } -TEST(GroupBy, MinMaxOnly) { +TEST_P(GroupBy, MinMaxOnly) { auto in_schema = schema({ field("argument", float64()), field("argument1", null()), @@ -1711,7 +2139,7 @@ TEST(GroupBy, MinMaxOnly) { } } -TEST(GroupBy, MinMaxTypes) { +TEST_P(GroupBy, MinMaxTypes) { std::vector> types; types.insert(types.end(), NumericTypes().begin(), NumericTypes().end()); types.insert(types.end(), TemporalTypes().begin(), TemporalTypes().end()); @@ -1799,7 +2227,7 @@ TEST(GroupBy, MinMaxTypes) { } } -TEST(GroupBy, MinMaxDecimal) { +TEST_P(GroupBy, MinMaxDecimal) { auto in_schema = schema({ field("argument0", decimal128(3, 2)), field("argument1", decimal256(3, 2)), @@ -1866,7 +2294,7 @@ TEST(GroupBy, MinMaxDecimal) { } } -TEST(GroupBy, MinMaxBinary) { +TEST_P(GroupBy, MinMaxBinary) { for (bool use_threads : {true, false}) { for (const auto& ty : BaseBinaryTypes()) { SCOPED_TRACE(use_threads ? "parallel/merged" : "serial"); @@ -1917,7 +2345,7 @@ TEST(GroupBy, MinMaxBinary) { } } -TEST(GroupBy, MinMaxFixedSizeBinary) { +TEST_P(GroupBy, MinMaxFixedSizeBinary) { const auto ty = fixed_size_binary(3); for (bool use_threads : {true, false}) { SCOPED_TRACE(use_threads ? "parallel/merged" : "serial"); @@ -1967,7 +2395,7 @@ TEST(GroupBy, MinMaxFixedSizeBinary) { } } -TEST(GroupBy, MinOrMax) { +TEST_P(GroupBy, MinOrMax) { auto table = TableFromJSON(schema({field("argument", float64()), field("key", int64())}), {R"([ [1.0, 1], @@ -2020,7 +2448,7 @@ TEST(GroupBy, MinOrMax) { /*verbose=*/true); } -TEST(GroupBy, MinMaxScalar) { +TEST_P(GroupBy, MinMaxScalar) { BatchesWithSchema input; input.batches = { ExecBatchFromJSON({int32(), int64()}, {ArgShape::SCALAR, ArgShape::ARRAY}, @@ -2053,7 +2481,7 @@ TEST(GroupBy, MinMaxScalar) { } } -TEST(GroupBy, AnyAndAll) { +TEST_P(GroupBy, AnyAndAll) { for (bool use_threads : {true, false}) { SCOPED_TRACE(use_threads ? 
"parallel/merged" : "serial"); @@ -2087,7 +2515,7 @@ TEST(GroupBy, AnyAndAll) { auto keep_nulls_min_count = std::make_shared(/*skip_nulls=*/false, /*min_count=*/3); ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped, - RunGroupBy( + AltGroupBy( { table->GetColumnByName("argument"), table->GetColumnByName("argument"), @@ -2098,7 +2526,7 @@ TEST(GroupBy, AnyAndAll) { table->GetColumnByName("argument"), table->GetColumnByName("argument"), }, - {table->GetColumnByName("key")}, + {table->GetColumnByName("key")}, {}, { {"hash_any", no_min, "agg_0", "hash_any"}, {"hash_any", min_count, "agg_1", "hash_any"}, @@ -2142,7 +2570,7 @@ TEST(GroupBy, AnyAndAll) { } } -TEST(GroupBy, AnyAllScalar) { +TEST_P(GroupBy, AnyAllScalar) { BatchesWithSchema input; input.batches = { ExecBatchFromJSON({boolean(), int64()}, {ArgShape::SCALAR, ArgShape::ARRAY}, @@ -2183,7 +2611,7 @@ TEST(GroupBy, AnyAllScalar) { } } -TEST(GroupBy, CountDistinct) { +TEST_P(GroupBy, CountDistinct) { auto all = std::make_shared(CountOptions::ALL); auto only_valid = std::make_shared(CountOptions::ONLY_VALID); auto only_null = std::make_shared(CountOptions::ONLY_NULL); @@ -2223,7 +2651,7 @@ TEST(GroupBy, CountDistinct) { ASSERT_OK_AND_ASSIGN( Datum aggregated_and_grouped, - RunGroupBy( + AltGroupBy( { table->GetColumnByName("argument"), table->GetColumnByName("argument"), @@ -2232,6 +2660,7 @@ TEST(GroupBy, CountDistinct) { { table->GetColumnByName("key"), }, + {}, { {"hash_count_distinct", all, "agg_0", "hash_count_distinct"}, {"hash_count_distinct", only_valid, "agg_1", "hash_count_distinct"}, @@ -2290,7 +2719,7 @@ TEST(GroupBy, CountDistinct) { ASSERT_OK_AND_ASSIGN( aggregated_and_grouped, - RunGroupBy( + AltGroupBy( { table->GetColumnByName("argument"), table->GetColumnByName("argument"), @@ -2299,6 +2728,7 @@ TEST(GroupBy, CountDistinct) { { table->GetColumnByName("key"), }, + {}, { {"hash_count_distinct", all, "agg_0", "hash_count_distinct"}, {"hash_count_distinct", only_valid, "agg_1", "hash_count_distinct"}, @@ -2337,7 +2767,7 @@ TEST(GroupBy, CountDistinct) { ASSERT_OK_AND_ASSIGN( aggregated_and_grouped, - RunGroupBy( + AltGroupBy( { table->GetColumnByName("argument"), table->GetColumnByName("argument"), @@ -2346,6 +2776,7 @@ TEST(GroupBy, CountDistinct) { { table->GetColumnByName("key"), }, + {}, { {"hash_count_distinct", all, "agg_0", "hash_count_distinct"}, {"hash_count_distinct", only_valid, "agg_1", "hash_count_distinct"}, @@ -2370,7 +2801,7 @@ TEST(GroupBy, CountDistinct) { } } -TEST(GroupBy, Distinct) { +TEST_P(GroupBy, Distinct) { auto all = std::make_shared(CountOptions::ALL); auto only_valid = std::make_shared(CountOptions::ONLY_VALID); auto only_null = std::make_shared(CountOptions::ONLY_NULL); @@ -2409,7 +2840,7 @@ TEST(GroupBy, Distinct) { ])"}); ASSERT_OK_AND_ASSIGN(auto aggregated_and_grouped, - RunGroupBy( + AltGroupBy( { table->GetColumnByName("argument"), table->GetColumnByName("argument"), @@ -2418,6 +2849,7 @@ TEST(GroupBy, Distinct) { { table->GetColumnByName("key"), }, + {}, { {"hash_distinct", all, "agg_0", "hash_distinct"}, {"hash_distinct", only_valid, "agg_1", "hash_distinct"}, @@ -2482,7 +2914,7 @@ TEST(GroupBy, Distinct) { ])", }); ASSERT_OK_AND_ASSIGN(aggregated_and_grouped, - RunGroupBy( + AltGroupBy( { table->GetColumnByName("argument"), table->GetColumnByName("argument"), @@ -2491,6 +2923,7 @@ TEST(GroupBy, Distinct) { { table->GetColumnByName("key"), }, + {}, { {"hash_distinct", all, "agg_0", "hash_distinct"}, {"hash_distinct", only_valid, "agg_1", "hash_distinct"}, @@ -2513,7 +2946,7 @@ 
TEST(GroupBy, Distinct) { } } -TEST(GroupBy, OneMiscTypes) { +TEST_P(GroupBy, OneMiscTypes) { auto in_schema = schema({ field("floats", float64()), field("nulls", null()), @@ -2628,7 +3061,7 @@ TEST(GroupBy, OneMiscTypes) { } } -TEST(GroupBy, OneNumericTypes) { +TEST_P(GroupBy, OneNumericTypes) { std::vector> types; types.insert(types.end(), NumericTypes().begin(), NumericTypes().end()); types.insert(types.end(), TemporalTypes().begin(), TemporalTypes().end()); @@ -2713,7 +3146,7 @@ TEST(GroupBy, OneNumericTypes) { } } -TEST(GroupBy, OneBinaryTypes) { +TEST_P(GroupBy, OneBinaryTypes) { for (bool use_threads : {true, false}) { for (const auto& type : BaseBinaryTypes()) { SCOPED_TRACE(use_threads ? "parallel/merged" : "serial"); @@ -2761,7 +3194,7 @@ TEST(GroupBy, OneBinaryTypes) { } } -TEST(GroupBy, OneScalar) { +TEST_P(GroupBy, OneScalar) { BatchesWithSchema input; input.batches = { ExecBatchFromJSON({int32(), int64()}, {ArgShape::SCALAR, ArgShape::ARRAY}, @@ -2791,7 +3224,7 @@ TEST(GroupBy, OneScalar) { } } -TEST(GroupBy, ListNumeric) { +TEST_P(GroupBy, ListNumeric) { for (const auto& type : NumericTypes()) { for (auto use_threads : {true, false}) { SCOPED_TRACE(use_threads ? "parallel/merged" : "serial"); @@ -2829,13 +3262,14 @@ TEST(GroupBy, ListNumeric) { ])"}); ASSERT_OK_AND_ASSIGN(auto aggregated_and_grouped, - RunGroupBy( + AltGroupBy( { table->GetColumnByName("argument"), }, { table->GetColumnByName("key"), }, + {}, { {"hash_list", nullptr, "agg_0", "hash_list"}, }, @@ -2900,13 +3334,14 @@ TEST(GroupBy, ListNumeric) { ])"}); ASSERT_OK_AND_ASSIGN(auto aggregated_and_grouped, - RunGroupBy( + AltGroupBy( { table->GetColumnByName("argument"), }, { table->GetColumnByName("key"), }, + {}, { {"hash_list", nullptr, "agg_0", "hash_list"}, }, @@ -2941,7 +3376,7 @@ TEST(GroupBy, ListNumeric) { } } -TEST(GroupBy, ListBinaryTypes) { +TEST_P(GroupBy, ListBinaryTypes) { for (bool use_threads : {true, false}) { for (const auto& type : BaseBinaryTypes()) { SCOPED_TRACE(use_threads ? 
"parallel/merged" : "serial"); @@ -2969,13 +3404,14 @@ TEST(GroupBy, ListBinaryTypes) { ])"}); ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped, - RunGroupBy( + AltGroupBy( { table->GetColumnByName("argument0"), }, { table->GetColumnByName("key"), }, + {}, { {"hash_list", nullptr, "agg_0", "hash_list"}, }, @@ -3031,13 +3467,14 @@ TEST(GroupBy, ListBinaryTypes) { ])"}); ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped, - RunGroupBy( + AltGroupBy( { table->GetColumnByName("argument0"), }, { table->GetColumnByName("key"), }, + {}, { {"hash_list", nullptr, "agg_0", "hash_list"}, }, @@ -3073,7 +3510,7 @@ TEST(GroupBy, ListBinaryTypes) { } } -TEST(GroupBy, ListMiscTypes) { +TEST_P(GroupBy, ListMiscTypes) { auto in_schema = schema({ field("floats", float64()), field("nulls", null()), @@ -3231,7 +3668,7 @@ TEST(GroupBy, ListMiscTypes) { } } -TEST(GroupBy, CountAndSum) { +TEST_P(GroupBy, CountAndSum) { auto batch = RecordBatchFromJSON( schema({field("argument", float64()), field("key", int64())}), R"([ [1.0, 1], @@ -3253,7 +3690,7 @@ TEST(GroupBy, CountAndSum) { std::make_shared(/*skip_nulls=*/true, /*min_count=*/3); ASSERT_OK_AND_ASSIGN( Datum aggregated_and_grouped, - RunGroupBy( + AltGroupBy( { // NB: passing an argument twice or also using it as a key is legal batch->GetColumnByName("argument"), @@ -3266,6 +3703,7 @@ TEST(GroupBy, CountAndSum) { { batch->GetColumnByName("key"), }, + {}, { {"hash_count", count_opts, "agg_0", "hash_count"}, {"hash_count", count_nulls_opts, "agg_1", "hash_count"}, @@ -3298,7 +3736,7 @@ TEST(GroupBy, CountAndSum) { /*verbose=*/true); } -TEST(GroupBy, StandAloneNullaryCount) { +TEST_P(GroupBy, StandAloneNullaryCount) { auto batch = RecordBatchFromJSON( schema({field("argument", float64()), field("key", int64())}), R"([ [1.0, 1], @@ -3314,13 +3752,14 @@ TEST(GroupBy, StandAloneNullaryCount) { ])"); ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped, - RunGroupBy( + AltGroupBy( // zero arguments for aggregations because only the // nullary hash_count_all aggregation is present {}, { batch->GetColumnByName("key"), }, + {}, { {"hash_count_all", "hash_count_all"}, })); @@ -3339,7 +3778,7 @@ TEST(GroupBy, StandAloneNullaryCount) { /*verbose=*/true); } -TEST(GroupBy, Product) { +TEST_P(GroupBy, Product) { auto batch = RecordBatchFromJSON( schema({field("argument", float64()), field("key", int64())}), R"([ [-1.0, 1], @@ -3357,7 +3796,7 @@ TEST(GroupBy, Product) { auto min_count = std::make_shared(/*skip_nulls=*/true, /*min_count=*/3); ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped, - RunGroupBy( + AltGroupBy( { batch->GetColumnByName("argument"), batch->GetColumnByName("key"), @@ -3366,6 +3805,7 @@ TEST(GroupBy, Product) { { batch->GetColumnByName("key"), }, + {}, { {"hash_product", nullptr, "agg_0", "hash_product"}, {"hash_product", nullptr, "agg_1", "hash_product"}, @@ -3395,13 +3835,14 @@ TEST(GroupBy, Product) { ])"); ASSERT_OK_AND_ASSIGN(aggregated_and_grouped, - RunGroupBy( + AltGroupBy( { batch->GetColumnByName("argument"), }, { batch->GetColumnByName("key"), }, + {}, { {"hash_product", nullptr, "agg_0", "hash_product"}, })); @@ -3415,7 +3856,7 @@ TEST(GroupBy, Product) { /*verbose=*/true); } -TEST(GroupBy, SumMeanProductKeepNulls) { +TEST_P(GroupBy, SumMeanProductKeepNulls) { auto batch = RecordBatchFromJSON( schema({field("argument", float64()), field("key", int64())}), R"([ [-1.0, 1], @@ -3434,7 +3875,7 @@ TEST(GroupBy, SumMeanProductKeepNulls) { auto min_count = std::make_shared(/*skip_nulls=*/false, /*min_count=*/3); ASSERT_OK_AND_ASSIGN(Datum 
aggregated_and_grouped, - RunGroupBy( + AltGroupBy( { batch->GetColumnByName("argument"), batch->GetColumnByName("argument"), @@ -3446,6 +3887,7 @@ TEST(GroupBy, SumMeanProductKeepNulls) { { batch->GetColumnByName("key"), }, + {}, { {"hash_sum", keep_nulls, "agg_0", "hash_sum"}, {"hash_sum", min_count, "agg_1", "hash_sum"}, @@ -3474,7 +3916,7 @@ TEST(GroupBy, SumMeanProductKeepNulls) { /*verbose=*/true); } -TEST(GroupBy, SumOnlyStringAndDictKeys) { +TEST_P(GroupBy, SumOnlyStringAndDictKeys) { for (auto key_type : {utf8(), dictionary(int32(), utf8())}) { SCOPED_TRACE("key type: " + key_type->ToString()); @@ -3494,7 +3936,7 @@ TEST(GroupBy, SumOnlyStringAndDictKeys) { ASSERT_OK_AND_ASSIGN( Datum aggregated_and_grouped, - RunGroupBy({batch->GetColumnByName("agg_0")}, {batch->GetColumnByName("key")}, + AltGroupBy({batch->GetColumnByName("agg_0")}, {batch->GetColumnByName("key")}, {}, { {"hash_sum", nullptr, "agg_0", "hash_sum"}, })); @@ -3515,7 +3957,7 @@ TEST(GroupBy, SumOnlyStringAndDictKeys) { } } -TEST(GroupBy, ConcreteCaseWithValidateGroupBy) { +TEST_P(GroupBy, ConcreteCaseWithValidateGroupBy) { auto batch = RecordBatchFromJSON(schema({field("agg_0", float64()), field("key", utf8())}), R"([ [1.0, "alfa"], @@ -3551,7 +3993,7 @@ TEST(GroupBy, ConcreteCaseWithValidateGroupBy) { } // Count nulls/non_nulls from record batch with no nulls -TEST(GroupBy, CountNull) { +TEST_P(GroupBy, CountNull) { auto batch = RecordBatchFromJSON(schema({field("agg_0", float64()), field("key", utf8())}), R"([ [1.0, "alfa"], @@ -3574,7 +4016,7 @@ TEST(GroupBy, CountNull) { } } -TEST(GroupBy, RandomArraySum) { +TEST_P(GroupBy, RandomArraySum) { std::shared_ptr options = std::make_shared(/*skip_nulls=*/true, /*min_count=*/0); for (int64_t length : {1 << 10, 1 << 12, 1 << 15}) { @@ -3592,12 +4034,13 @@ TEST(GroupBy, RandomArraySum) { { {"hash_sum", options, "agg_0", "hash_sum"}, }, - {batch->GetColumnByName("agg_0")}, {batch->GetColumnByName("key")}); + {batch->GetColumnByName("agg_0")}, {batch->GetColumnByName("key")}, + /*naive=*/false); } } } -TEST(GroupBy, WithChunkedArray) { +TEST_P(GroupBy, WithChunkedArray) { auto table = TableFromJSON(schema({field("argument", float64()), field("key", int64())}), {R"([{"argument": 1.0, "key": 1}, @@ -3613,7 +4056,7 @@ TEST(GroupBy, WithChunkedArray) { {"argument": null, "key": 3} ])"}); ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped, - RunGroupBy( + AltGroupBy( { table->GetColumnByName("argument"), table->GetColumnByName("argument"), @@ -3622,6 +4065,7 @@ TEST(GroupBy, WithChunkedArray) { { table->GetColumnByName("key"), }, + {}, { {"hash_count", nullptr, "agg_0", "hash_count"}, {"hash_sum", nullptr, "agg_1", "hash_sum"}, @@ -3647,19 +4091,20 @@ TEST(GroupBy, WithChunkedArray) { /*verbose=*/true); } -TEST(GroupBy, MinMaxWithNewGroupsInChunkedArray) { +TEST_P(GroupBy, MinMaxWithNewGroupsInChunkedArray) { auto table = TableFromJSON( schema({field("argument", int64()), field("key", int64())}), {R"([{"argument": 1, "key": 0}])", R"([{"argument": 0, "key": 1}])"}); ScalarAggregateOptions count_options; ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped, - RunGroupBy( + AltGroupBy( { table->GetColumnByName("argument"), }, { table->GetColumnByName("key"), }, + {}, { {"hash_min_max", nullptr, "agg_0", "hash_min_max"}, })); @@ -3679,7 +4124,7 @@ TEST(GroupBy, MinMaxWithNewGroupsInChunkedArray) { /*verbose=*/true); } -TEST(GroupBy, SmallChunkSizeSumOnly) { +TEST_P(GroupBy, SmallChunkSizeSumOnly) { auto batch = RecordBatchFromJSON( schema({field("argument", float64()), field("key", 
int64())}), R"([ [1.0, 1], @@ -3693,13 +4138,13 @@ TEST(GroupBy, SmallChunkSizeSumOnly) { [0.75, null], [null, 3] ])"); - ASSERT_OK_AND_ASSIGN( - Datum aggregated_and_grouped, - RunGroupBy({batch->GetColumnByName("argument")}, {batch->GetColumnByName("key")}, - { - {"hash_sum", nullptr, "agg_0", "hash_sum"}, - }, - small_chunksize_context())); + ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped, + AltGroupBy({batch->GetColumnByName("argument")}, + {batch->GetColumnByName("key")}, {}, + { + {"hash_sum", nullptr, "agg_0", "hash_sum"}, + }, + small_chunksize_context())); AssertDatumsEqual(ArrayFromJSON(struct_({ field("hash_sum", float64()), field("key_0", int64()), @@ -3714,7 +4159,7 @@ TEST(GroupBy, SmallChunkSizeSumOnly) { /*verbose=*/true); } -TEST(GroupBy, CountWithNullType) { +TEST_P(GroupBy, CountWithNullType) { auto table = TableFromJSON(schema({field("argument", null()), field("key", int64())}), {R"([ [null, 1], @@ -3772,7 +4217,7 @@ TEST(GroupBy, CountWithNullType) { } } -TEST(GroupBy, CountWithNullTypeEmptyTable) { +TEST_P(GroupBy, CountWithNullTypeEmptyTable) { auto table = TableFromJSON(schema({field("argument", null()), field("key", int64())}), {R"([])"}); @@ -3803,7 +4248,7 @@ TEST(GroupBy, CountWithNullTypeEmptyTable) { } } -TEST(GroupBy, SingleNullTypeKey) { +TEST_P(GroupBy, SingleNullTypeKey) { auto table = TableFromJSON(schema({field("argument", int64()), field("key", null())}), {R"([ [1, null], @@ -3860,7 +4305,7 @@ TEST(GroupBy, SingleNullTypeKey) { } } -TEST(GroupBy, MultipleKeysIncludesNullType) { +TEST_P(GroupBy, MultipleKeysIncludesNullType) { auto table = TableFromJSON(schema({field("argument", float64()), field("key_0", utf8()), field("key_1", null())}), {R"([ @@ -3920,7 +4365,7 @@ TEST(GroupBy, MultipleKeysIncludesNullType) { } } -TEST(GroupBy, SumNullType) { +TEST_P(GroupBy, SumNullType) { auto table = TableFromJSON(schema({field("argument", null()), field("key", int64())}), {R"([ [null, 1], @@ -3986,7 +4431,7 @@ TEST(GroupBy, SumNullType) { } } -TEST(GroupBy, ProductNullType) { +TEST_P(GroupBy, ProductNullType) { auto table = TableFromJSON(schema({field("argument", null()), field("key", int64())}), {R"([ [null, 1], @@ -4052,7 +4497,7 @@ TEST(GroupBy, ProductNullType) { } } -TEST(GroupBy, MeanNullType) { +TEST_P(GroupBy, MeanNullType) { auto table = TableFromJSON(schema({field("argument", null()), field("key", int64())}), {R"([ [null, 1], @@ -4118,7 +4563,7 @@ TEST(GroupBy, MeanNullType) { } } -TEST(GroupBy, NullTypeEmptyTable) { +TEST_P(GroupBy, NullTypeEmptyTable) { auto table = TableFromJSON(schema({field("argument", null()), field("key", int64())}), {R"([])"}); @@ -4157,7 +4602,7 @@ TEST(GroupBy, NullTypeEmptyTable) { } } -TEST(GroupBy, OnlyKeys) { +TEST_P(GroupBy, OnlyKeys) { auto table = TableFromJSON(schema({field("key_0", int64()), field("key_1", utf8())}), {R"([ [1, "a"], @@ -4202,5 +4647,262 @@ TEST(GroupBy, OnlyKeys) { /*verbose=*/true); } } + +INSTANTIATE_TEST_SUITE_P(GroupBy, GroupBy, ::testing::Values(RunGroupByImpl)); + +class SegmentedScalarGroupBy : public GroupBy {}; + +class SegmentedKeyGroupBy : public GroupBy {}; + +void TestSegment(GroupByFunction group_by, const std::shared_ptr
& table, + Datum output, const std::vector& keys, + const std::vector& segment_keys, bool is_scalar_aggregate) { + const char* names[] = { + is_scalar_aggregate ? "count" : "hash_count", + is_scalar_aggregate ? "sum" : "hash_sum", + is_scalar_aggregate ? "min_max" : "hash_min_max", + }; + ASSERT_OK_AND_ASSIGN(Datum aggregated_and_grouped, + group_by( + { + table->GetColumnByName("argument"), + table->GetColumnByName("argument"), + table->GetColumnByName("argument"), + }, + keys, segment_keys, + { + {names[0], nullptr, "agg_0", names[0]}, + {names[1], nullptr, "agg_1", names[1]}, + {names[2], nullptr, "agg_2", names[2]}, + }, + /*use_threads=*/false, /*naive=*/false)); + + AssertDatumsEqual(output, aggregated_and_grouped, /*verbose=*/true); +} + +// test with empty keys, covering code in ScalarAggregateNode +void TestSegmentScalar(GroupByFunction group_by, const std::shared_ptr
& table, + Datum output, const std::vector& segment_keys) { + TestSegment(group_by, table, output, {}, segment_keys, /*scalar=*/true); +} + +// test with given segment-keys and keys set to `{"key"}`, covering code in GroupByNode +void TestSegmentKey(GroupByFunction group_by, const std::shared_ptr
& table, + Datum output, const std::vector& segment_keys) { + TestSegment(group_by, table, output, {table->GetColumnByName("key")}, segment_keys, + /*scalar=*/false); +} + +Result> GetSingleSegmentInputAsChunked() { + auto table = TableFromJSON(schema({field("argument", float64()), field("key", int64()), + field("segment_key", int64())}), + {R"([{"argument": 1.0, "key": 1, "segment_key": 1}, + {"argument": null, "key": 1, "segment_key": 1} + ])", + R"([{"argument": 0.0, "key": 2, "segment_key": 1}, + {"argument": null, "key": 3, "segment_key": 1}, + {"argument": 4.0, "key": null, "segment_key": 1}, + {"argument": 3.25, "key": 1, "segment_key": 1}, + {"argument": 0.125, "key": 2, "segment_key": 1}, + {"argument": -0.25, "key": 2, "segment_key": 1}, + {"argument": 0.75, "key": null, "segment_key": 1}, + {"argument": null, "key": 3, "segment_key": 1} + ])", + R"([{"argument": 1.0, "key": 1, "segment_key": 0}, + {"argument": null, "key": 1, "segment_key": 0} + ])", + R"([{"argument": 0.0, "key": 2, "segment_key": 0}, + {"argument": null, "key": 3, "segment_key": 0}, + {"argument": 4.0, "key": null, "segment_key": 0}, + {"argument": 3.25, "key": 1, "segment_key": 0}, + {"argument": 0.125, "key": 2, "segment_key": 0}, + {"argument": -0.25, "key": 2, "segment_key": 0}, + {"argument": 0.75, "key": null, "segment_key": 0}, + {"argument": null, "key": 3, "segment_key": 0} + ])"}); + return table; +} + +Result> GetSingleSegmentInputAsCombined() { + ARROW_ASSIGN_OR_RAISE(auto table, GetSingleSegmentInputAsChunked()); + return table->CombineChunks(); +} + +Result> GetSingleSegmentScalarOutput() { + return ChunkedArrayFromJSON(struct_({ + field("count", int64()), + field("sum", float64()), + field("min_max", struct_({ + field("min", float64()), + field("max", float64()), + })), + field("key_0", int64()), + }), + {R"([ + [7, 8.875, {"min": -0.25, "max": 4.0}, 1] + ])", + R"([ + [7, 8.875, {"min": -0.25, "max": 4.0}, 0] + ])"}); +} + +Result> GetSingleSegmentKeyOutput() { + return ChunkedArrayFromJSON(struct_({ + field("hash_count", int64()), + field("hash_sum", float64()), + field("hash_min_max", struct_({ + field("min", float64()), + field("max", float64()), + })), + field("key_0", int64()), + field("key_1", int64()), + }), + {R"([ + [2, 4.25, {"min": 1.0, "max": 3.25}, 1, 1], + [3, -0.125, {"min": -0.25, "max": 0.125}, 2, 1], + [0, null, {"min": null, "max": null}, 3, 1], + [2, 4.75, {"min": 0.75, "max": 4.0}, null, 1] + ])", + R"([ + [2, 4.25, {"min": 1.0, "max": 3.25}, 1, 0], + [3, -0.125, {"min": -0.25, "max": 0.125}, 2, 0], + [0, null, {"min": null, "max": null}, 3, 0], + [2, 4.75, {"min": 0.75, "max": 4.0}, null, 0] + ])"}); +} + +void TestSingleSegmentScalar(GroupByFunction group_by, + std::function>()> get_table) { + ASSERT_OK_AND_ASSIGN(auto table, get_table()); + ASSERT_OK_AND_ASSIGN(auto output, GetSingleSegmentScalarOutput()); + TestSegmentScalar(group_by, table, output, {table->GetColumnByName("segment_key")}); +} + +void TestSingleSegmentKey(GroupByFunction group_by, + std::function>()> get_table) { + ASSERT_OK_AND_ASSIGN(auto table, get_table()); + ASSERT_OK_AND_ASSIGN(auto output, GetSingleSegmentKeyOutput()); + TestSegmentKey(group_by, table, output, {table->GetColumnByName("segment_key")}); +} + +TEST_P(SegmentedScalarGroupBy, SingleSegmentScalarChunked) { + TestSingleSegmentScalar(GetParam(), GetSingleSegmentInputAsChunked); +} + +TEST_P(SegmentedScalarGroupBy, SingleSegmentScalarCombined) { + TestSingleSegmentScalar(GetParam(), GetSingleSegmentInputAsCombined); +} + 
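[Editorial note, not part of the patch] The expected outputs above contain exactly one aggregated row per distinct segment_key value because the segmented paths accumulate state only until the segmenter reports that a segment group has closed, then push one output batch and reset. The sketch below is illustrative only: it drives the RowSegmenter API added in grouper.h over a shortened stand-in for the combined segment_key column, and reduces error handling to a Result return; the function name CountSegmentGroups is hypothetical.

#include "arrow/compute/exec.h"
#include "arrow/compute/row/grouper.h"
#include "arrow/result.h"

namespace arrow {
namespace compute {

// Counts how many output batches a segmented aggregation would emit for a
// segment-key column: one per segment closed inside the batch, plus one for the
// trailing open segment when the input ends.
Result<int> CountSegmentGroups(const ExecBatch& segment_keys_batch,
                               const std::vector<TypeHolder>& key_types) {
  ARROW_ASSIGN_OR_RAISE(auto segmenter,
                        RowSegmenter::Make(key_types, /*nullable_keys=*/false,
                                           default_exec_context()));
  int num_groups = 0;
  bool open_at_end = false;
  for (int64_t offset = 0; offset < segment_keys_batch.length;) {
    ARROW_ASSIGN_OR_RAISE(
        Segment seg, segmenter->GetNextSegment(ExecSpan(segment_keys_batch), offset));
    if (seg.is_open) {
      open_at_end = true;  // the group may continue; it is flushed when input ends
    } else {
      ++num_groups;  // the group ended inside the batch; output is pushed here
    }
    offset = seg.offset + seg.length;
  }
  return num_groups + (open_at_end ? 1 : 0);
}

// With a stand-in for the combined segment_key column above (a run of 1s followed
// by 0s), e.g. an ExecBatch built from "[[1], [1], [1], [0], [0]]" with an int64
// key, this returns 2, matching the two rows in GetSingleSegmentScalarOutput().

}  // namespace compute
}  // namespace arrow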
+TEST_P(SegmentedKeyGroupBy, SingleSegmentKeyChunked) { + TestSingleSegmentKey(GetParam(), GetSingleSegmentInputAsChunked); +} + +TEST_P(SegmentedKeyGroupBy, SingleSegmentKeyCombined) { + TestSingleSegmentKey(GetParam(), GetSingleSegmentInputAsCombined); +} + +// extracts one segment of the obtained (single-segment-key) table +Result> GetEmptySegmentKeysInput( + std::function>()> get_table) { + ARROW_ASSIGN_OR_RAISE(auto table, get_table()); + auto sliced = table->Slice(0, 10); + ARROW_ASSIGN_OR_RAISE(auto batch, sliced->CombineChunksToBatch()); + ARROW_ASSIGN_OR_RAISE(auto array, batch->ToStructArray()); + ARROW_ASSIGN_OR_RAISE(auto chunked, ChunkedArray::Make({array}, array->type())); + return Table::FromChunkedStructArray(chunked); +} + +Result> GetEmptySegmentKeysInputAsChunked() { + return GetEmptySegmentKeysInput(GetSingleSegmentInputAsChunked); +} + +Result> GetEmptySegmentKeysInputAsCombined() { + return GetEmptySegmentKeysInput(GetSingleSegmentInputAsCombined); +} + +// extracts the expected output for one segment +Result> GetEmptySegmentKeyOutput() { + ARROW_ASSIGN_OR_RAISE(auto chunked, GetSingleSegmentKeyOutput()); + ARROW_ASSIGN_OR_RAISE(auto table, Table::FromChunkedStructArray(chunked)); + ARROW_ASSIGN_OR_RAISE(auto removed, table->RemoveColumn(table->num_columns() - 1)); + auto sliced = removed->Slice(0, 4); + ARROW_ASSIGN_OR_RAISE(auto batch, sliced->CombineChunksToBatch()); + return batch->ToStructArray(); +} + +void TestEmptySegmentKey(GroupByFunction group_by, + std::function>()> get_table) { + ASSERT_OK_AND_ASSIGN(auto table, get_table()); + ASSERT_OK_AND_ASSIGN(auto output, GetEmptySegmentKeyOutput()); + TestSegmentKey(group_by, table, output, {}); +} + +TEST_P(SegmentedKeyGroupBy, EmptySegmentKeyChunked) { + TestEmptySegmentKey(GetParam(), GetEmptySegmentKeysInputAsChunked); +} + +TEST_P(SegmentedKeyGroupBy, EmptySegmentKeyCombined) { + TestEmptySegmentKey(GetParam(), GetEmptySegmentKeysInputAsCombined); +} + +// adds a named copy of the last (single-segment-key) column to the obtained table +Result> GetMultiSegmentInput( + std::function>()> get_table, + const std::string& add_name) { + ARROW_ASSIGN_OR_RAISE(auto table, get_table()); + int last = table->num_columns() - 1; + auto add_field = field(add_name, table->schema()->field(last)->type()); + return table->AddColumn(table->num_columns(), add_field, table->column(last)); +} + +Result> GetMultiSegmentInputAsChunked( + const std::string& add_name) { + return GetMultiSegmentInput(GetSingleSegmentInputAsChunked, add_name); +} + +Result> GetMultiSegmentInputAsCombined( + const std::string& add_name) { + return GetMultiSegmentInput(GetSingleSegmentInputAsCombined, add_name); +} + +// adds a named copy of the last (single-segment-key) column to the expected output table +Result> GetMultiSegmentKeyOutput( + const std::string& add_name) { + ARROW_ASSIGN_OR_RAISE(auto chunked, GetSingleSegmentKeyOutput()); + ARROW_ASSIGN_OR_RAISE(auto table, Table::FromChunkedStructArray(chunked)); + int last = table->num_columns() - 1; + auto add_field = field(add_name, table->schema()->field(last)->type()); + ARROW_ASSIGN_OR_RAISE(auto added, + table->AddColumn(last + 1, add_field, table->column(last))); + ARROW_ASSIGN_OR_RAISE(auto batch, added->CombineChunksToBatch()); + ARROW_ASSIGN_OR_RAISE(auto array, batch->ToStructArray()); + return ChunkedArray::Make({array->Slice(0, 4), array->Slice(4, 4)}, array->type()); +} + +void TestMultiSegmentKey( + GroupByFunction group_by, + std::function>(const std::string&)> get_table) { + std::string 
add_name = "segment_key2"; + ASSERT_OK_AND_ASSIGN(auto table, get_table(add_name)); + ASSERT_OK_AND_ASSIGN(auto output, GetMultiSegmentKeyOutput("key_2")); + TestSegmentKey( + group_by, table, output, + {table->GetColumnByName("segment_key"), table->GetColumnByName(add_name)}); +} + +TEST_P(SegmentedKeyGroupBy, MultiSegmentKeyChunked) { + TestMultiSegmentKey(GetParam(), GetMultiSegmentInputAsChunked); +} + +TEST_P(SegmentedKeyGroupBy, MultiSegmentKeyCombined) { + TestMultiSegmentKey(GetParam(), GetMultiSegmentInputAsCombined); +} + +INSTANTIATE_TEST_SUITE_P(SegmentedScalarGroupBy, SegmentedScalarGroupBy, + ::testing::Values(RunSegmentedGroupByImpl)); + +INSTANTIATE_TEST_SUITE_P(SegmentedKeyGroupBy, SegmentedKeyGroupBy, + ::testing::Values(RunSegmentedGroupByImpl)); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/row/grouper.cc b/cpp/src/arrow/compute/row/grouper.cc index d003137d3e5b2..75df42abd0f46 100644 --- a/cpp/src/arrow/compute/row/grouper.cc +++ b/cpp/src/arrow/compute/row/grouper.cc @@ -19,6 +19,9 @@ #include #include +#include + +#include "arrow/array/builder_primitive.h" #include "arrow/compute/exec/key_hash.h" #include "arrow/compute/exec/key_map.h" @@ -29,7 +32,9 @@ #include "arrow/compute/light_array.h" #include "arrow/compute/registry.h" #include "arrow/compute/row/compare_internal.h" +#include "arrow/compute/row/grouper_internal.h" #include "arrow/type.h" +#include "arrow/type_traits.h" #include "arrow/util/bitmap_ops.h" #include "arrow/util/checked_cast.h" #include "arrow/util/cpu_info.h" @@ -39,12 +44,333 @@ namespace arrow { using internal::checked_cast; +using internal::PrimitiveScalarBase; namespace compute { namespace { -struct GrouperImpl : Grouper { +constexpr uint32_t kNoGroupId = std::numeric_limits::max(); + +using group_id_t = std::remove_const::type; +using GroupIdType = CTypeTraits::ArrowType; +auto g_group_id_type = std::make_shared(); + +inline const uint8_t* GetValuesAsBytes(const ArraySpan& data, int64_t offset = 0) { + DCHECK_GT(data.type->byte_width(), 0); + int64_t absolute_byte_offset = (data.offset + offset) * data.type->byte_width(); + return data.GetValues(1, absolute_byte_offset); +} + +template +Status CheckForGetNextSegment(const std::vector& values, int64_t length, + int64_t offset, const std::vector& key_types) { + if (offset < 0 || offset > length) { + return Status::Invalid("invalid grouping segmenter offset: ", offset); + } + if (values.size() != key_types.size()) { + return Status::Invalid("expected batch size ", key_types.size(), " but got ", + values.size()); + } + for (size_t i = 0; i < key_types.size(); i++) { + const auto& value = values[i]; + const auto& key_type = key_types[i]; + if (*value.type() != *key_type.type) { + return Status::Invalid("expected batch value ", i, " of type ", *key_type.type, + " but got ", *value.type()); + } + } + return Status::OK(); +} + +template +enable_if_t::value || std::is_same::value, + Status> +CheckForGetNextSegment(const Batch& batch, int64_t offset, + const std::vector& key_types) { + return CheckForGetNextSegment(batch.values, batch.length, offset, key_types); +} + +struct BaseRowSegmenter : public RowSegmenter { + explicit BaseRowSegmenter(const std::vector& key_types) + : key_types_(key_types) {} + + const std::vector& key_types() const override { return key_types_; } + + std::vector key_types_; +}; + +Segment MakeSegment(int64_t batch_length, int64_t offset, int64_t length, bool extends) { + return Segment{offset, length, offset + length >= batch_length, 
extends}; +} + +// Used by SimpleKeySegmenter::GetNextSegment to find the match-length of a value within a +// fixed-width buffer +int64_t GetMatchLength(const uint8_t* match_bytes, int64_t match_width, + const uint8_t* array_bytes, int64_t offset, int64_t length) { + int64_t cursor, byte_cursor; + for (cursor = offset, byte_cursor = match_width * cursor; cursor < length; + cursor++, byte_cursor += match_width) { + if (memcmp(match_bytes, array_bytes + byte_cursor, + static_cast(match_width)) != 0) { + break; + } + } + return std::min(cursor, length) - offset; +} + +using ExtendFunc = std::function; +constexpr bool kDefaultExtends = true; // by default, the first segment extends +constexpr bool kEmptyExtends = true; // an empty segment extends too + +struct NoKeysSegmenter : public BaseRowSegmenter { + static std::unique_ptr Make() { + return std::make_unique(); + } + + NoKeysSegmenter() : BaseRowSegmenter({}) {} + + Status Reset() override { return Status::OK(); } + + Result GetNextSegment(const ExecSpan& batch, int64_t offset) override { + ARROW_RETURN_NOT_OK(CheckForGetNextSegment(batch, offset, {})); + return MakeSegment(batch.length, offset, batch.length - offset, kDefaultExtends); + } +}; + +struct SimpleKeySegmenter : public BaseRowSegmenter { + static Result> Make(TypeHolder key_type) { + return std::make_unique(key_type); + } + + explicit SimpleKeySegmenter(TypeHolder key_type) + : BaseRowSegmenter({key_type}), + key_type_(key_types_[0]), + save_key_data_(static_cast(key_type_.type->byte_width())), + extend_was_called_(false) {} + + Status CheckType(const DataType& type) { + if (!is_fixed_width(type)) { + return Status::Invalid("SimpleKeySegmenter does not support type ", type); + } + return Status::OK(); + } + + Status Reset() override { + extend_was_called_ = false; + return Status::OK(); + } + + // Checks whether the given grouping data extends the current segment, i.e., is equal to + // previously seen grouping data, which is updated with each invocation. + bool Extend(const void* data) { + bool extends = !extend_was_called_ + ? kDefaultExtends + : 0 == memcmp(save_key_data_.data(), data, save_key_data_.size()); + extend_was_called_ = true; + memcpy(save_key_data_.data(), data, save_key_data_.size()); + return extends; + } + + Result GetNextSegment(const Scalar& scalar, int64_t offset, int64_t length) { + ARROW_RETURN_NOT_OK(CheckType(*scalar.type)); + if (!scalar.is_valid) { + return Status::Invalid("segmenting an invalid scalar"); + } + auto data = checked_cast(scalar).data(); + bool extends = length > 0 ? Extend(data) : kEmptyExtends; + return MakeSegment(length, offset, length, extends); + } + + Result GetNextSegment(const DataType& array_type, const uint8_t* array_bytes, + int64_t offset, int64_t length) { + RETURN_NOT_OK(CheckType(array_type)); + DCHECK_LE(offset, length); + int64_t byte_width = array_type.byte_width(); + int64_t match_length = GetMatchLength(array_bytes + offset * byte_width, byte_width, + array_bytes, offset, length); + bool extends = length > 0 ? 
Extend(array_bytes + offset * byte_width) : kEmptyExtends; + return MakeSegment(length, offset, match_length, extends); + } + + Result GetNextSegment(const ExecSpan& batch, int64_t offset) override { + ARROW_RETURN_NOT_OK(CheckForGetNextSegment(batch, offset, {key_type_})); + if (offset == batch.length) { + return MakeSegment(batch.length, offset, 0, kEmptyExtends); + } + const auto& value = batch.values[0]; + if (value.is_scalar()) { + return GetNextSegment(*value.scalar, offset, batch.length); + } + ARROW_DCHECK(value.is_array()); + const auto& array = value.array; + if (array.GetNullCount() > 0) { + return Status::NotImplemented("segmenting a nullable array"); + } + return GetNextSegment(*array.type, GetValuesAsBytes(array), offset, batch.length); + } + + private: + TypeHolder key_type_; + std::vector save_key_data_; // previously seen segment-key grouping data + bool extend_was_called_; +}; + +struct AnyKeysSegmenter : public BaseRowSegmenter { + static Result> Make( + const std::vector& key_types, ExecContext* ctx) { + ARROW_RETURN_NOT_OK(Grouper::Make(key_types, ctx)); // check types + return std::make_unique(key_types, ctx); + } + + AnyKeysSegmenter(const std::vector& key_types, ExecContext* ctx) + : BaseRowSegmenter(key_types), + ctx_(ctx), + grouper_(nullptr), + save_group_id_(kNoGroupId) {} + + Status Reset() override { + grouper_ = nullptr; + save_group_id_ = kNoGroupId; + return Status::OK(); + } + + bool Extend(const void* data) { + auto group_id = *static_cast(data); + bool extends = + save_group_id_ == kNoGroupId ? kDefaultExtends : save_group_id_ == group_id; + save_group_id_ = group_id; + return extends; + } + + // Runs the grouper on a single row. This is used to determine the group id of the + // first row of a new segment to see if it extends the previous segment. 
+ template + Result MapGroupIdAt(const Batch& batch, int64_t offset) { + if (!grouper_) return kNoGroupId; + ARROW_ASSIGN_OR_RAISE(auto datum, grouper_->Consume(batch, offset, + /*length=*/1)); + if (!datum.is_array()) { + return Status::Invalid("accessing unsupported datum kind ", datum.kind()); + } + const std::shared_ptr& data = datum.array(); + ARROW_DCHECK(data->GetNullCount() == 0); + DCHECK_EQ(data->type->id(), GroupIdType::type_id); + DCHECK_EQ(1, data->length); + const group_id_t* values = data->GetValues(1); + return values[0]; + } + + Result GetNextSegment(const ExecSpan& batch, int64_t offset) override { + ARROW_RETURN_NOT_OK(CheckForGetNextSegment(batch, offset, key_types_)); + if (offset == batch.length) { + return MakeSegment(batch.length, offset, 0, kEmptyExtends); + } + // ARROW-18311: make Grouper support Reset() + // so it can be reset instead of recreated below + // + // the group id must be computed prior to resetting the grouper, since it is compared + // to save_group_id_, and after resetting the grouper produces incomparable group ids + ARROW_ASSIGN_OR_RAISE(auto group_id, MapGroupIdAt(batch, offset)); + ExtendFunc bound_extend = [this, group_id](const void* data) { + bool extends = Extend(&group_id); + save_group_id_ = *static_cast(data); + return extends; + }; + // resetting drops grouper's group-ids, freeing-up memory for the next segment + ARROW_ASSIGN_OR_RAISE(grouper_, Grouper::Make(key_types_, ctx_)); // TODO: reset it + // GH-34475: cache the grouper-consume result across invocations of GetNextSegment + ARROW_ASSIGN_OR_RAISE(auto datum, grouper_->Consume(batch, offset)); + if (datum.is_array()) { + // `data` is an array whose index-0 corresponds to index `offset` of `batch` + const std::shared_ptr& data = datum.array(); + DCHECK_EQ(data->length, batch.length - offset); + ARROW_DCHECK(data->GetNullCount() == 0); + DCHECK_EQ(data->type->id(), GroupIdType::type_id); + const group_id_t* values = data->GetValues(1); + int64_t cursor; + for (cursor = 1; cursor < data->length; cursor++) { + if (values[0] != values[cursor]) break; + } + int64_t length = cursor; + bool extends = length > 0 ? 
bound_extend(values) : kEmptyExtends; + return MakeSegment(batch.length, offset, length, extends); + } else { + return Status::Invalid("segmenting unsupported datum kind ", datum.kind()); + } + } + + private: + ExecContext* const ctx_; + std::unique_ptr grouper_; + group_id_t save_group_id_; +}; + +Status CheckAndCapLengthForConsume(int64_t batch_length, int64_t& consume_offset, + int64_t* consume_length) { + if (consume_offset < 0) { + return Status::Invalid("invalid grouper consume offset: ", consume_offset); + } + if (*consume_length < 0) { + *consume_length = batch_length - consume_offset; + } + return Status::OK(); +} + +} // namespace + +Result> MakeAnyKeysSegmenter( + const std::vector& key_types, ExecContext* ctx) { + return AnyKeysSegmenter::Make(key_types, ctx); +} + +Result> RowSegmenter::Make( + const std::vector& key_types, bool nullable_keys, ExecContext* ctx) { + if (key_types.size() == 0) { + return NoKeysSegmenter::Make(); + } else if (!nullable_keys && key_types.size() == 1) { + const DataType* type = key_types[0].type; + if (type != NULLPTR && is_fixed_width(*type)) { + return SimpleKeySegmenter::Make(key_types[0]); + } + } + return AnyKeysSegmenter::Make(key_types, ctx); +} + +namespace { + +struct GrouperNoKeysImpl : Grouper { + Result> MakeConstantGroupIdArray(int64_t length, + group_id_t value) { + std::unique_ptr a_builder; + RETURN_NOT_OK(MakeBuilder(default_memory_pool(), g_group_id_type, &a_builder)); + using GroupIdBuilder = typename TypeTraits::BuilderType; + auto builder = checked_cast(a_builder.get()); + if (length != 0) { + RETURN_NOT_OK(builder->Resize(length)); + } + for (int64_t i = 0; i < length; i++) { + builder->UnsafeAppend(value); + } + std::shared_ptr array; + RETURN_NOT_OK(builder->Finish(&array)); + return std::move(array); + } + Result Consume(const ExecSpan& batch, int64_t offset, int64_t length) override { + ARROW_ASSIGN_OR_RAISE(auto array, MakeConstantGroupIdArray(length, 0)); + return Datum(array); + } + Result GetUniques() override { + auto data = ArrayData::Make(uint32(), 1, 0); + auto values = data->GetMutableValues(0); + values[0] = 0; + ExecBatch out({Datum(data)}, 1); + return std::move(out); + } + uint32_t num_groups() const override { return 1; } +}; + +struct GrouperImpl : public Grouper { static Result> Make( const std::vector& key_types, ExecContext* ctx) { auto impl = std::make_unique(); @@ -95,7 +421,12 @@ struct GrouperImpl : Grouper { return std::move(impl); } - Result Consume(const ExecSpan& batch) override { + Result Consume(const ExecSpan& batch, int64_t offset, int64_t length) override { + ARROW_RETURN_NOT_OK(CheckAndCapLengthForConsume(batch.length, offset, &length)); + if (offset != 0 || length != batch.length) { + auto batch_slice = batch.ToExecBatch().Slice(offset, length); + return Consume(ExecSpan(batch_slice), 0, -1); + } std::vector offsets_batch(batch.length + 1); for (int i = 0; i < batch.num_values(); ++i) { encoders_[i]->AddLength(batch[i], batch.length, offsets_batch.data()); @@ -179,11 +510,14 @@ struct GrouperImpl : Grouper { std::vector> encoders_; }; -struct GrouperFastImpl : Grouper { +struct GrouperFastImpl : public Grouper { static constexpr int kBitmapPaddingForSIMD = 64; // bits static constexpr int kPaddingForSIMD = 32; // bytes static bool CanUse(const std::vector& key_types) { + if (key_types.size() == 0) { + return false; + } #if ARROW_LITTLE_ENDIAN for (size_t i = 0; i < key_types.size(); ++i) { if (is_large_binary_like(key_types[i].id())) { @@ -265,7 +599,12 @@ struct GrouperFastImpl : Grouper { 
~GrouperFastImpl() { map_.cleanup(); } - Result Consume(const ExecSpan& batch) override { + Result Consume(const ExecSpan& batch, int64_t offset, int64_t length) override { + ARROW_RETURN_NOT_OK(CheckAndCapLengthForConsume(batch.length, offset, &length)); + if (offset != 0 || length != batch.length) { + auto batch_slice = batch.ToExecBatch().Slice(offset, length); + return Consume(ExecSpan(batch_slice), 0, -1); + } // ARROW-14027: broadcast scalar arguments for now for (int i = 0; i < batch.num_values(); i++) { if (batch[i].is_scalar()) { diff --git a/cpp/src/arrow/compute/row/grouper.h b/cpp/src/arrow/compute/row/grouper.h index ce09adf09b3af..f9e7e2e97e7ed 100644 --- a/cpp/src/arrow/compute/row/grouper.h +++ b/cpp/src/arrow/compute/row/grouper.h @@ -30,6 +30,78 @@ namespace arrow { namespace compute { +/// \brief A segment +/// A segment group is a chunk of contiguous rows that have the same segment key. (For +/// example, in ordered time series processing, the segment key can be "date", and a segment +/// group can be all the rows that belong to the same date.) A segment group can span +/// across multiple exec batches. A segment is a chunk of contiguous rows that has the same +/// segment key within a given batch. When a segment group spans across batches, it will +/// have multiple segments. A segment never spans across batches. The segment data +/// structure only makes sense when used along with an exec batch. +struct ARROW_EXPORT Segment { + /// \brief the offset into the batch where the segment starts + int64_t offset; + /// \brief the length of the segment + int64_t length; + /// \brief whether the segment may be extended by a next one + bool is_open; + /// \brief whether the segment extends a preceding one + bool extends; +}; + +inline bool operator==(const Segment& segment1, const Segment& segment2) { + return segment1.offset == segment2.offset && segment1.length == segment2.length && + segment1.is_open == segment2.is_open && segment1.extends == segment2.extends; +} +inline bool operator!=(const Segment& segment1, const Segment& segment2) { + return !(segment1 == segment2); +} + +/// \brief a helper class to divide a batch into segments of equal values +/// +/// For example, given a batch with two key columns and five rows: +/// +/// A A +/// A A +/// A B +/// A B +/// A A +/// +/// Then the batch could be divided into 3 segments. The first would be rows 0 & 1, +/// the second would be rows 2 & 3, and the third would be row 4. +/// +/// Further, a segmenter keeps track of the last value seen. This allows it to calculate +/// segments which span batches. In our above example the last segment we emit would set +/// the "open" flag, which indicates whether the segment may extend into the next batch. +/// +/// If the next call to the segmenter starts with `A A` then that segment would set the +/// "extends" flag, which indicates whether the segment continues the last open segment. 
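[Editorial note, not part of the patch] The doc comment above and the RowSegmenter interface declared just below are easiest to follow with a short driver loop. The sketch here is illustrative only: it assumes a single non-nullable int32 key, the ExecBatchFromJSON helper from "arrow/compute/exec/test_util.h", and a Status-returning context for the error-handling macros.

#include "arrow/compute/exec.h"
#include "arrow/compute/exec/test_util.h"
#include "arrow/compute/row/grouper.h"

namespace arrow {
namespace compute {

// Walks a RowSegmenter over two batches of a single int32 key column and shows how
// the is_open/extends flags behave at the batch boundary.
Status SegmenterWalkthrough() {
  std::vector<TypeHolder> key_types = {int32()};
  ARROW_ASSIGN_OR_RAISE(auto segmenter,
                        RowSegmenter::Make(key_types, /*nullable_keys=*/false,
                                           default_exec_context()));

  // First batch: runs [1 1] [2 2] [1]; the last segment reports is_open = true
  // because it touches the end of the batch and may continue in the next one.
  ExecBatch batch1 = ExecBatchFromJSON(key_types, "[[1], [1], [2], [2], [1]]");
  for (int64_t offset = 0; offset < batch1.length;) {
    ARROW_ASSIGN_OR_RAISE(Segment seg,
                          segmenter->GetNextSegment(ExecSpan(batch1), offset));
    // seg.extends reports whether this run continues the previous segment's key
    // (the very first segment a segmenter sees defaults to extends = true).
    offset = seg.offset + seg.length;
  }

  // The second batch starts with the same key value (1), so its first segment comes
  // back with extends = true: it belongs to the same segment group as the open
  // segment at the end of batch1.
  ExecBatch batch2 = ExecBatchFromJSON(key_types, "[[1], [3]]");
  ARROW_ASSIGN_OR_RAISE(Segment first, segmenter->GetNextSegment(ExecSpan(batch2), 0));
  // first == Segment{0, 1, false, true}
  if (!first.extends) return Status::Invalid("expected the segment to extend");
  return Status::OK();
}

}  // namespace compute
}  // namespace arrow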
+class ARROW_EXPORT RowSegmenter { + public: + virtual ~RowSegmenter() = default; + + /// \brief Construct a Segmenter which segments on the specified key types + /// + /// \param[in] key_types the specified key types + /// \param[in] nullable_keys whether values of the specified keys may be null + /// \param[in] ctx the execution context to use + static Result> Make( + const std::vector& key_types, bool nullable_keys, ExecContext* ctx); + + /// \brief Return the key types of this segmenter + virtual const std::vector& key_types() const = 0; + + /// \brief Reset this segmenter + /// + /// A segmenter normally extends (see `Segment`) a segment from one batch to the next. + /// If segment-extension is undesirable, for example when each batch is processed + /// independently, then `Reset` should be invoked before processing the next batch. + virtual Status Reset() = 0; + + /// \brief Get the next segment for the given batch starting from the given offset + virtual Result GetNextSegment(const ExecSpan& batch, int64_t offset) = 0; +}; + /// Consumes batches of keys and yields batches of the group ids. class ARROW_EXPORT Grouper { public: @@ -39,10 +111,12 @@ class ARROW_EXPORT Grouper { static Result> Make(const std::vector& key_types, ExecContext* ctx = default_exec_context()); - /// Consume a batch of keys, producing the corresponding group ids as an integer array. + /// Consume a batch of keys, producing the corresponding group ids as an integer array, + /// over a slice defined by an offset and length, where the length defaults to the + /// remainder of the batch. /// Currently only uint32 indices will be produced, eventually the bit width will only /// be as wide as necessary. - virtual Result Consume(const ExecSpan& batch) = 0; + virtual Result Consume(const ExecSpan& batch, int64_t offset = 0, + int64_t length = -1) = 0; /// Get current unique keys. May be called multiple times. virtual Result GetUniques() = 0; diff --git a/cpp/src/arrow/compute/row/grouper_internal.h b/cpp/src/arrow/compute/row/grouper_internal.h new file mode 100644 index 0000000000000..eb3dfe8ba1654 --- /dev/null +++ b/cpp/src/arrow/compute/row/grouper_internal.h @@ -0,0 +1,27 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +namespace arrow { +namespace compute { + +ARROW_EXPORT Result> MakeAnyKeysSegmenter( + const std::vector& key_types, ExecContext* ctx); + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index 31dfdcbc84f72..d23b33e28f75c 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -136,6 +136,8 @@ struct ARROW_EXPORT PrimitiveScalarBase : public Scalar { : Scalar(std::move(type), false) {} using Scalar::Scalar; + /// \brief Get a const pointer to the value of this scalar. May be null. 
+ virtual const void* data() const = 0; /// \brief Get a mutable pointer to the value of this scalar. May be null. virtual void* mutable_data() = 0; /// \brief Get an immutable view of the value of this scalar as bytes. @@ -157,6 +159,7 @@ struct ARROW_EXPORT PrimitiveScalar : public PrimitiveScalarBase { ValueType value{}; + const void* data() const override { return &value; } void* mutable_data() override { return &value; } std::string_view view() const override { return std::string_view(reinterpret_cast(&value), sizeof(ValueType)); @@ -241,6 +244,9 @@ struct ARROW_EXPORT BaseBinaryScalar : public internal::PrimitiveScalarBase { std::shared_ptr value; + const void* data() const override { + return value ? reinterpret_cast(value->data()) : NULLPTR; + } void* mutable_data() override { return value ? reinterpret_cast(value->mutable_data()) : NULLPTR; } @@ -434,6 +440,10 @@ struct ARROW_EXPORT DecimalScalar : public internal::PrimitiveScalarBase { DecimalScalar(ValueType value, std::shared_ptr type) : internal::PrimitiveScalarBase(std::move(type), true), value(value) {} + const void* data() const override { + return reinterpret_cast(value.native_endian_bytes()); + } + void* mutable_data() override { return reinterpret_cast(value.mutable_native_endian_bytes()); } @@ -603,6 +613,9 @@ struct ARROW_EXPORT DictionaryScalar : public internal::PrimitiveScalarBase { Result> GetEncodedValue() const; + const void* data() const override { + return internal::checked_cast(*value.index).data(); + } void* mutable_data() override { return internal::checked_cast(*value.index) .mutable_data();