Simplify some segment code; add documentation; some refactor/renames
icexelloss committed Mar 2, 2023
1 parent 6240596 · commit db68040
Showing 6 changed files with 159 additions and 148 deletions.
92 changes: 49 additions & 43 deletions cpp/src/arrow/compute/exec/aggregate_node.cc
@@ -186,15 +186,19 @@ void AggregatesToString(std::stringstream* ss, const Schema& input_schema,
*ss << ']';
}

+// Handle the input batch.
+// If a segment is closed by this batch, output the aggregation for that segment.
+// If a segment is not closed by this batch, add the batch to the segment.
template <typename BatchHandler>
-Status HandleSegments(std::unique_ptr<GroupingSegmenter>& segmenter,
-                      const ExecBatch& batch, const std::vector<int>& ids,
-                      const BatchHandler& handle_batch) {
+Status HandleSegments(std::unique_ptr<RowSegmenter>& segmenter, const ExecBatch& batch,
+                      const std::vector<int>& ids, const BatchHandler& handle_batch) {
int64_t offset = 0;
ARROW_ASSIGN_OR_RAISE(auto segment_exec_batch, batch.SelectValues(ids));
ExecSpan segment_batch(segment_exec_batch);

while (true) {
-ARROW_ASSIGN_OR_RAISE(auto segment, segmenter->GetNextSegment(segment_batch, offset));
+ARROW_ASSIGN_OR_RAISE(compute::SegmentPiece segment,
+                      segmenter->GetNextSegmentPiece(segment_batch, offset));
if (segment.offset >= segment_batch.length) break;  // no next segment: reached the end of the batch
ARROW_RETURN_NOT_OK(handle_batch(batch, segment));
offset = segment.offset + segment.length;
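
For orientation, below is a minimal, self-contained C++ sketch of the consumption loop this hunk implements. SegmentPiece and GetNextPiece are hypothetical stand-ins for compute::SegmentPiece and RowSegmenter::GetNextSegmentPiece, whose real definitions are not shown in this diff; the field order and semantics are inferred from the test expectations further down, and a single int key column stands in for arbitrary segment-key types.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for compute::SegmentPiece; field order and meaning
// inferred from the assertions in hash_aggregate_test.cc below.
struct SegmentPiece {
  int64_t offset = 0;   // where the piece starts within the batch
  int64_t length = 0;   // number of rows in the piece
  bool is_open = true;  // segment may continue into the next batch
  bool extends = true;  // piece continues the segment of the previous piece/batch
};

// Returns the next run of equal keys starting at `offset`, mimicking
// RowSegmenter::GetNextSegmentPiece for a single integer key column.
SegmentPiece GetNextPiece(const std::vector<int>& keys, int64_t offset, int prev_key) {
  SegmentPiece piece;
  piece.offset = offset;
  int64_t n = static_cast<int64_t>(keys.size());
  if (offset >= n) return piece;  // empty tail piece: {n, 0, true, true}
  int64_t end = offset;
  while (end < n && keys[end] == keys[offset]) ++end;
  piece.length = end - offset;
  piece.extends =
      offset > 0 ? keys[offset] == keys[offset - 1] : keys[offset] == prev_key;
  piece.is_open = (end == n);  // run touches the batch end, may spill over
  return piece;
}

int main() {
  std::vector<int> keys = {1, 1, 2, 2, 2, 3};  // segment-key values of one batch
  int64_t offset = 0;
  while (offset < static_cast<int64_t>(keys.size())) {
    SegmentPiece piece = GetNextPiece(keys, offset, /*prev_key=*/1);
    std::cout << "offset=" << piece.offset << " length=" << piece.length
              << " extends=" << piece.extends << " is_open=" << piece.is_open << "\n";
    offset = piece.offset + piece.length;  // same advance rule as HandleSegments
  }
  return 0;
}

Run against {1, 1, 2, 2, 2, 3} with a previous key of 1, this prints pieces {0,2}, {2,3}, {5,1}; only the first extends the prior segment and only the last is open, which is exactly the shape HandleSegments hands to handle_batch.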
@@ -234,11 +238,12 @@ class ScalarAggregateNode : public ExecNode, public TracedNode {
public:
ScalarAggregateNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
std::shared_ptr<Schema> output_schema,
-std::unique_ptr<GroupingSegmenter> segmenter,
+std::unique_ptr<RowSegmenter> segmenter,
std::vector<int> segment_field_ids,
std::vector<std::vector<int>> target_fieldsets,
std::vector<Aggregate> aggs,
std::vector<const ScalarAggregateKernel*> kernels,
+std::vector<std::vector<TypeHolder>> kernel_intypes,
std::vector<std::vector<std::unique_ptr<KernelState>>> states)
: ExecNode(plan, std::move(inputs), {"target"},
/*output_schema=*/std::move(output_schema)),
@@ -248,16 +253,8 @@ class ScalarAggregateNode : public ExecNode, public TracedNode {
target_fieldsets_(std::move(target_fieldsets)),
aggs_(std::move(aggs)),
kernels_(std::move(kernels)),
-states_(std::move(states)) {
-const auto& input_schema = *this->inputs()[0]->output_schema();
-for (size_t i = 0; i < kernels_.size(); ++i) {
-std::vector<TypeHolder> in_types;
-for (const auto& target : target_fieldsets_[i]) {
-in_types.emplace_back(input_schema.field(target)->type().get());
-}
-in_typesets_.push_back(std::move(in_types));
-}
-}
+kernel_intypes_(std::move(kernel_intypes)),
+states_(std::move(states)) {}

static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
const ExecNodeOptions& options) {
@@ -282,7 +279,7 @@ class ScalarAggregateNode : public ExecNode, public TracedNode {
std::vector<int> segment_field_ids(segment_keys.size());
std::vector<TypeHolder> segment_key_types(segment_keys.size());
for (size_t i = 0; i < segment_keys.size(); i++) {
-ARROW_ASSIGN_OR_RAISE(auto match, segment_keys[i].FindOne(input_schema));
+ARROW_ASSIGN_OR_RAISE(FieldPath match, segment_keys[i].FindOne(input_schema));
if (match.indices().size() > 1) {
// ARROW-18369: Support nested references as segment ids
return Status::Invalid("Nested references cannot be used as segment ids");
@@ -291,9 +288,10 @@ class ScalarAggregateNode : public ExecNode, public TracedNode {
segment_key_types[i] = input_schema.field(match[0])->type().get();
}

-ARROW_ASSIGN_OR_RAISE(
-auto segmenter, GroupingSegmenter::Make(std::move(segment_key_types), exec_ctx));
+ARROW_ASSIGN_OR_RAISE(auto segmenter,
+                      RowSegmenter::Make(std::move(segment_key_types), exec_ctx));

+std::vector<std::vector<TypeHolder>> kernel_intypes(aggregates.size());
std::vector<const ScalarAggregateKernel*> kernels(aggregates.size());
std::vector<std::vector<std::unique_ptr<KernelState>>> states(kernels.size());
FieldVector fields(kernels.size() + segment_keys.size());
@@ -324,7 +322,9 @@ class ScalarAggregateNode : public ExecNode, public TracedNode {
for (const auto& target : target_fieldsets[i]) {
in_types.emplace_back(input_schema.field(target)->type().get());
}
-ARROW_ASSIGN_OR_RAISE(const Kernel* kernel, function->DispatchExact(in_types));
+kernel_intypes[i] = in_types;
+ARROW_ASSIGN_OR_RAISE(const Kernel* kernel,
+                      function->DispatchExact(kernel_intypes[i]));
kernels[i] = static_cast<const ScalarAggregateKernel*>(kernel);

if (aggregates[i].options == nullptr) {
@@ -338,13 +338,14 @@ class ScalarAggregateNode : public ExecNode, public TracedNode {
KernelContext kernel_ctx{exec_ctx};
states[i].resize(plan->query_context()->max_concurrency());
RETURN_NOT_OK(Kernel::InitAll(
-&kernel_ctx, KernelInitArgs{kernels[i], in_types, aggregates[i].options.get()},
+&kernel_ctx,
+KernelInitArgs{kernels[i], kernel_intypes[i], aggregates[i].options.get()},
&states[i]));

// pick one to resolve the kernel signature
kernel_ctx.SetState(states[i][0].get());
ARROW_ASSIGN_OR_RAISE(auto out_type, kernels[i]->signature->out_type().Resolve(
-&kernel_ctx, in_types));
+&kernel_ctx, kernel_intypes[i]));

fields[i] = field(aggregate_options.aggregates[i].name, out_type.GetSharedPtr());
}
@@ -356,7 +357,7 @@ class ScalarAggregateNode : public ExecNode, public TracedNode {
return plan->EmplaceNode<ScalarAggregateNode>(
plan, std::move(inputs), schema(std::move(fields)), std::move(segmenter),
std::move(segment_field_ids), std::move(target_fieldsets), std::move(aggregates),
-std::move(kernels), std::move(states));
+std::move(kernels), std::move(kernel_intypes), std::move(states));
}

const char* kind_name() const override { return "ScalarAggregateNode"; }
@@ -388,12 +389,21 @@ class ScalarAggregateNode : public ExecNode, public TracedNode {

auto thread_index = plan_->query_context()->GetThreadIndex();
auto handler = [this, thread_index](const ExecBatch& full_batch,
-const GroupingSegment& segment) {
-if (!segment.extends && segment.offset == 0) RETURN_NOT_OK(OutputResult());
+const SegmentPiece& segment) {
+// (1) If the segment piece opens a new segment and starts at the beginning of
+// the batch, then no data in the batch belongs to the current segment; output
+// the current aggregation and reset the kernel states.
+if (!segment.extends && segment.offset == 0) RETURN_NOT_OK(OutputResult(false));

+// (2) Add the segment piece to the current segment's aggregation.
auto exec_batch = full_batch.Slice(segment.offset, segment.length);
RETURN_NOT_OK(DoConsume(ExecSpan(exec_batch), thread_index));
RETURN_NOT_OK(GetScalarFields(&segmenter_values_, exec_batch, segment_field_ids_));
-if (!segment.is_open) RETURN_NOT_OK(OutputResult());

+// (3) If the segment piece closes the current segment, output the segment's
+// aggregation.
+if (!segment.is_open) RETURN_NOT_OK(OutputResult(false));

return Status::OK();
};
RETURN_NOT_OK(HandleSegments(segmenter_, batch, segment_field_ids_, handler));
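
The flush/consume/flush sequence in this handler is the heart of segmented aggregation, so here is a compressed, hedged sketch of just that control flow. SegmentPiece is the same hypothetical stand-in as in the earlier sketch, a plain int accumulator stands in for the kernel states, and flush() plays the role of OutputResult(false) plus the kernel-state reset.

#include <cstdint>
#include <iostream>
#include <vector>

struct SegmentPiece {  // hypothetical stand-in, as in the earlier sketch
  int64_t offset;
  int64_t length;
  bool is_open;
  bool extends;
};

void HandlePiece(const SegmentPiece& piece, const std::vector<int>& batch,
                 std::vector<int>& acc) {
  auto flush = [&acc] {  // stand-in for OutputResult(false) + state reset
    int sum = 0;
    for (int v : acc) sum += v;
    std::cout << "segment sum = " << sum << "\n";
    acc.clear();
  };
  // (1) A piece that does not extend the current segment and starts at the
  // beginning of the batch: the previous segment ended with the previous
  // batch, so emit it before consuming anything.
  if (!piece.extends && piece.offset == 0) flush();
  // (2) Consume this piece's slice of the batch.
  for (int64_t i = piece.offset; i < piece.offset + piece.length; ++i) {
    acc.push_back(batch[static_cast<size_t>(i)]);
  }
  // (3) A piece that closes its segment inside the batch: emit it now.
  if (!piece.is_open) flush();
}

int main() {
  std::vector<int> acc;
  std::vector<int> batch = {10, 11, 20};  // values whose keys are [1, 1, 2]
  HandlePiece({0, 2, /*is_open=*/false, /*extends=*/true}, batch, acc);  // prints 21
  HandlePiece({2, 1, /*is_open=*/true, /*extends=*/false}, batch, acc);  // stays open
  std::vector<int> batch2 = {30};  // a following batch whose key column is [3]
  HandlePiece({0, 1, /*is_open=*/false, /*extends=*/false}, batch2, acc);  // prints 20, then 30
  return 0;
}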
@@ -438,20 +448,20 @@ class ScalarAggregateNode : public ExecNode, public TracedNode {
}

private:
-Status ResetAggregates() {
+Status ResetKernelStates() {
auto exec_ctx = plan()->query_context()->exec_context();
for (size_t i = 0; i < kernels_.size(); ++i) {
-const std::vector<TypeHolder>& in_types = in_typesets_[i];
states_[i].resize(plan()->query_context()->max_concurrency());
KernelContext kernel_ctx{exec_ctx};
RETURN_NOT_OK(Kernel::InitAll(
-&kernel_ctx, KernelInitArgs{kernels_[i], in_types, aggs_[i].options.get()},
+&kernel_ctx,
+KernelInitArgs{kernels_[i], kernel_intypes_[i], aggs_[i].options.get()},
&states_[i]));
}
return Status::OK();
}

-Status OutputResult(bool is_last = false) {
+Status OutputResult(bool is_last) {
ExecBatch batch{{}, 1};
batch.values.resize(kernels_.size() + segment_field_ids_.size());

@@ -474,20 +484,21 @@ class ScalarAggregateNode : public ExecNode, public TracedNode {
if (is_last) {
ARROW_RETURN_NOT_OK(output_->InputFinished(this, total_output_batches_));
} else {
-ARROW_RETURN_NOT_OK(ResetAggregates());
+ARROW_RETURN_NOT_OK(ResetKernelStates());
}
return Status::OK();
}

-std::unique_ptr<GroupingSegmenter> segmenter_;
+std::unique_ptr<RowSegmenter> segmenter_;
const std::vector<int> segment_field_ids_;
std::vector<Datum> segmenter_values_;

const std::vector<std::vector<int>> target_fieldsets_;
const std::vector<Aggregate> aggs_;
const std::vector<const ScalarAggregateKernel*> kernels_;

-std::vector<std::vector<TypeHolder>> in_typesets_;
+// Input type holders for each kernel, used for state initialization.
+std::vector<std::vector<TypeHolder>> kernel_intypes_;
std::vector<std::vector<std::unique_ptr<KernelState>>> states_;

AtomicCounter input_counter_;
@@ -498,7 +509,7 @@ class GroupByNode : public ExecNode, public TracedNode {
public:
GroupByNode(ExecNode* input, std::shared_ptr<Schema> output_schema,
std::vector<int> key_field_ids, std::vector<int> segment_key_field_ids,
-std::unique_ptr<GroupingSegmenter> segmenter,
+std::unique_ptr<RowSegmenter> segmenter,
std::vector<std::vector<TypeHolder>> agg_src_types,
std::vector<std::vector<int>> agg_src_fieldsets,
std::vector<Aggregate> aggs,
@@ -591,7 +602,7 @@ class GroupByNode : public ExecNode, public TracedNode {
auto ctx = plan->query_context()->exec_context();

ARROW_ASSIGN_OR_RAISE(auto segmenter,
-                      GroupingSegmenter::Make(std::move(segment_key_types), ctx));
+                      RowSegmenter::Make(std::move(segment_key_types), ctx));

// Construct aggregates
ARROW_ASSIGN_OR_RAISE(auto agg_kernels, GetKernels(ctx, aggs, agg_src_types));
@@ -630,12 +641,7 @@ class GroupByNode : public ExecNode, public TracedNode {

Status ResetAggregates() {
auto ctx = plan()->query_context()->exec_context();

-ARROW_ASSIGN_OR_RAISE(agg_kernels_, GetKernels(ctx, aggs_, agg_src_types_));
-
-ARROW_ASSIGN_OR_RAISE(auto agg_states,
-                      InitKernels(agg_kernels_, ctx, aggs_, agg_src_types_));
-
+ARROW_RETURN_NOT_OK(InitKernels(agg_kernels_, ctx, aggs_, agg_src_types_));
return Status::OK();
}

@@ -797,7 +803,7 @@ class GroupByNode : public ExecNode, public TracedNode {

DCHECK_EQ(input, inputs_[0]);

-auto handler = [this](const ExecBatch& full_batch, const GroupingSegment& segment) {
+auto handler = [this](const ExecBatch& full_batch, const SegmentPiece& segment) {
if (!segment.extends && segment.offset == 0) RETURN_NOT_OK(OutputResult());
auto exec_batch = full_batch.Slice(segment.offset, segment.length);
auto batch = ExecSpan(exec_batch);
@@ -912,15 +918,15 @@ class GroupByNode : public ExecNode, public TracedNode {
}

int output_task_group_id_;
-std::unique_ptr<GroupingSegmenter> segmenter_;
+std::unique_ptr<RowSegmenter> segmenter_;
std::vector<Datum> segmenter_values_;

const std::vector<int> key_field_ids_;
const std::vector<int> segment_key_field_ids_;
const std::vector<std::vector<TypeHolder>> agg_src_types_;
const std::vector<std::vector<int>> agg_src_fieldsets_;
const std::vector<Aggregate> aggs_;
-std::vector<const HashAggregateKernel*> agg_kernels_;
+const std::vector<const HashAggregateKernel*> agg_kernels_;

AtomicCounter input_counter_;
int total_output_batches_ = 0;
36 changes: 17 additions & 19 deletions cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
@@ -480,23 +480,21 @@ void TestGroupClassSupportedKeys() {
ASSERT_RAISES(NotImplemented, GroupClass::Make({dense_union({field("", int32())})}));
}

-void TestSegments(std::unique_ptr<GroupingSegmenter>& segmenter, const ExecSpan& batch,
-                  std::vector<GroupingSegment> expected_segments) {
+void TestSegments(std::unique_ptr<RowSegmenter>& segmenter, const ExecSpan& batch,
+                  std::vector<SegmentPiece> expected_segments) {
int64_t offset = 0;
for (auto expected_segment : expected_segments) {
-ASSERT_OK_AND_ASSIGN(auto segment, segmenter->GetNextSegment(batch, offset));
+ASSERT_OK_AND_ASSIGN(auto segment, segmenter->GetNextSegmentPiece(batch, offset));
ASSERT_EQ(expected_segment, segment);
offset = segment.offset + segment.length;
}
}

} // namespace

-TEST(GroupingSegmenter, SupportedKeys) {
-TestGroupClassSupportedKeys<GroupingSegmenter>();
-}
+TEST(RowSegmenter, SupportedKeys) { TestGroupClassSupportedKeys<RowSegmenter>(); }

-TEST(GroupingSegmenter, Basics) {
+TEST(RowSegmenter, Basics) {
std::vector<TypeHolder> bad_types2 = {int32(), float32()};
std::vector<TypeHolder> types2 = {int32(), int32()};
std::vector<TypeHolder> bad_types1 = {float32()};
@@ -507,53 +505,53 @@ TEST(GroupingSegmenter, Basics) {
ExecBatch batch0({}, 3);
{
SCOPED_TRACE("offset");
-ASSERT_OK_AND_ASSIGN(auto segmenter, GroupingSegmenter::Make(types0));
+ASSERT_OK_AND_ASSIGN(auto segmenter, RowSegmenter::Make(types0));
ExecSpan span0(batch0);
for (int64_t offset : {-1, 4}) {
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
HasSubstr("invalid grouping segmenter offset"),
-segmenter->GetNextSegment(span0, offset));
+segmenter->GetNextSegmentPiece(span0, offset));
}
}
{
SCOPED_TRACE("types0 segmenting of batch2");
-ASSERT_OK_AND_ASSIGN(auto segmenter, GroupingSegmenter::Make(types0));
+ASSERT_OK_AND_ASSIGN(auto segmenter, RowSegmenter::Make(types0));
ExecSpan span2(batch2);
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("expected batch size 0 "),
-segmenter->GetNextSegment(span2, 0));
+segmenter->GetNextSegmentPiece(span2, 0));
ExecSpan span0(batch0);
TestSegments(segmenter, span0, {{0, 3, true, true}, {3, 0, true, true}});
}
{
SCOPED_TRACE("bad_types1 segmenting of batch1");
-ASSERT_OK_AND_ASSIGN(auto segmenter, GroupingSegmenter::Make(bad_types1));
+ASSERT_OK_AND_ASSIGN(auto segmenter, RowSegmenter::Make(bad_types1));
ExecSpan span1(batch1);
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("expected batch value 0 of type "),
-segmenter->GetNextSegment(span1, 0));
+segmenter->GetNextSegmentPiece(span1, 0));
}
{
SCOPED_TRACE("types1 segmenting of batch2");
-ASSERT_OK_AND_ASSIGN(auto segmenter, GroupingSegmenter::Make(types1));
+ASSERT_OK_AND_ASSIGN(auto segmenter, RowSegmenter::Make(types1));
ExecSpan span2(batch2);
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("expected batch size 1 "),
-segmenter->GetNextSegment(span2, 0));
+segmenter->GetNextSegmentPiece(span2, 0));
ExecSpan span1(batch1);
TestSegments(segmenter, span1,
{{0, 2, false, true}, {2, 1, true, false}, {3, 0, true, true}});
}
{
SCOPED_TRACE("bad_types2 segmenting of batch2");
-ASSERT_OK_AND_ASSIGN(auto segmenter, GroupingSegmenter::Make(bad_types2));
+ASSERT_OK_AND_ASSIGN(auto segmenter, RowSegmenter::Make(bad_types2));
ExecSpan span2(batch2);
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("expected batch value 1 of type "),
-segmenter->GetNextSegment(span2, 0));
+segmenter->GetNextSegmentPiece(span2, 0));
}
{
SCOPED_TRACE("types2 segmenting of batch1");
-ASSERT_OK_AND_ASSIGN(auto segmenter, GroupingSegmenter::Make(types2));
+ASSERT_OK_AND_ASSIGN(auto segmenter, RowSegmenter::Make(types2));
ExecSpan span1(batch1);
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("expected batch size 2 "),
-segmenter->GetNextSegment(span1, 0));
+segmenter->GetNextSegmentPiece(span1, 0));
ExecSpan span2(batch2);
TestSegments(segmenter, span2,
{{0, 1, false, true},
…
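
As a reading aid for the initializer lists in these expectations: judging from the assertions, SegmentPiece aggregate-initializes in the field order {offset, length, is_open, extends}. Under that assumption, and assuming batch1 holds a single int32 key column shaped like [1, 1, 2] (the diff does not show its construction), the batch1 expectation decodes as the hypothetical sketch below; the real field names live in the Arrow compute headers.

// Hypothetical decoding of the batch1 expectation, assuming the field order
// {offset, length, is_open, extends} and a key column of [1, 1, 2]:
std::vector<SegmentPiece> expected_batch1 = {
    {0, 2, /*is_open=*/false, /*extends=*/true},  // run of key 1, closed within the
                                                  // batch; a segmenter's first piece
                                                  // counts as extending
    {2, 1, /*is_open=*/true, /*extends=*/false},  // run of key 2 reaches the batch
                                                  // end, so its segment stays open
    {3, 0, /*is_open=*/true, /*extends=*/true},   // empty tail piece: offset equals
                                                  // the batch length, nothing left
};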
