Skip to content

Commit

Permalink
Switch ExecPlan to use a QueryContext that it owns
Browse files Browse the repository at this point in the history
  • Loading branch information
save-buffer committed Sep 23, 2022
1 parent 549d212 commit a7744c5
Show file tree
Hide file tree
Showing 24 changed files with 514 additions and 294 deletions.
1 change: 1 addition & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ if(ARROW_COMPUTE)
compute/exec/partition_util.cc
compute/exec/options.cc
compute/exec/project_node.cc
compute/exec/query_context.cc
compute/exec/sink_node.cc
compute/exec/source_node.cc
compute/exec/swiss_join.cc
Expand Down
43 changes: 22 additions & 21 deletions cpp/src/arrow/compute/exec/aggregate_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class ScalarAggregateNode : public ExecNode {
auto aggregates = aggregate_options.aggregates;

const auto& input_schema = *inputs[0]->output_schema();
auto exec_ctx = plan->exec_context();
auto exec_ctx = plan->query_context()->exec_context();

std::vector<const ScalarAggregateKernel*> kernels(aggregates.size());
std::vector<std::vector<std::unique_ptr<KernelState>>> states(kernels.size());
Expand Down Expand Up @@ -113,7 +113,7 @@ class ScalarAggregateNode : public ExecNode {
}

KernelContext kernel_ctx{exec_ctx};
states[i].resize(plan->max_concurrency());
states[i].resize(plan->query_context()->max_concurrency());
RETURN_NOT_OK(Kernel::InitAll(&kernel_ctx,
KernelInitArgs{kernels[i],
{
Expand Down Expand Up @@ -150,7 +150,7 @@ class ScalarAggregateNode : public ExecNode {
{"function.options",
aggs_[i].options ? aggs_[i].options->ToString() : "<NULLPTR>"},
{"function.kind", std::string(kind_name()) + "::Consume"}});
KernelContext batch_ctx{plan()->exec_context()};
KernelContext batch_ctx{plan()->query_context()->exec_context()};
batch_ctx.SetState(states_[i][thread_index].get());

ExecSpan single_column_batch{{batch.values[target_field_ids_[i]]}, batch.length};
Expand All @@ -168,7 +168,7 @@ class ScalarAggregateNode : public ExecNode {
{"batch.length", batch.length}});
DCHECK_EQ(input, inputs_[0]);

auto thread_index = plan_->GetThreadIndex();
auto thread_index = plan_->query_context()->GetThreadIndex();

if (ErrorIfNotOk(DoConsume(ExecSpan(batch), thread_index))) return;

Expand Down Expand Up @@ -245,7 +245,7 @@ class ScalarAggregateNode : public ExecNode {
{"function.options",
aggs_[i].options ? aggs_[i].options->ToString() : "<NULLPTR>"},
{"function.kind", std::string(kind_name()) + "::Finalize"}});
KernelContext ctx{plan()->exec_context()};
KernelContext ctx{plan()->query_context()->exec_context()};
ARROW_ASSIGN_OR_RAISE(auto merged, ScalarAggregateKernel::MergeAll(
kernels_[i], &ctx, std::move(states_[i])));
RETURN_NOT_OK(kernels_[i]->finalize(&ctx, &batch.values[i]));
Expand All @@ -267,20 +267,19 @@ class ScalarAggregateNode : public ExecNode {

class GroupByNode : public ExecNode {
public:
GroupByNode(ExecNode* input, std::shared_ptr<Schema> output_schema, ExecContext* ctx,
GroupByNode(ExecNode* input, std::shared_ptr<Schema> output_schema,
std::vector<int> key_field_ids, std::vector<int> agg_src_field_ids,
std::vector<Aggregate> aggs,
std::vector<const HashAggregateKernel*> agg_kernels)
: ExecNode(input->plan(), {input}, {"groupby"}, std::move(output_schema),
/*num_outputs=*/1),
ctx_(ctx),
key_field_ids_(std::move(key_field_ids)),
agg_src_field_ids_(std::move(agg_src_field_ids)),
aggs_(std::move(aggs)),
agg_kernels_(std::move(agg_kernels)) {}

Status Init() override {
output_task_group_id_ = plan_->RegisterTaskGroup(
output_task_group_id_ = plan_->query_context()->RegisterTaskGroup(
[this](size_t, int64_t task_id) {
OutputNthBatch(task_id);
return Status::OK();
Expand Down Expand Up @@ -326,7 +325,7 @@ class GroupByNode : public ExecNode {
agg_src_types[i] = input_schema->field(agg_src_field_id)->type().get();
}

auto ctx = input->plan()->exec_context();
auto ctx = plan->query_context()->exec_context();

// Construct aggregates
ARROW_ASSIGN_OR_RAISE(auto agg_kernels,
Expand Down Expand Up @@ -354,7 +353,7 @@ class GroupByNode : public ExecNode {
}

return input->plan()->EmplaceNode<GroupByNode>(
input, schema(std::move(output_fields)), ctx, std::move(key_field_ids),
input, schema(std::move(output_fields)), std::move(key_field_ids),
std::move(agg_src_field_ids), std::move(aggs), std::move(agg_kernels));
}

Expand All @@ -366,7 +365,7 @@ class GroupByNode : public ExecNode {
{{"group_by", ToStringExtra()},
{"node.label", label()},
{"batch.length", batch.length}});
size_t thread_index = plan_->GetThreadIndex();
size_t thread_index = plan_->query_context()->GetThreadIndex();
if (thread_index >= local_states_.size()) {
return Status::IndexError("thread index ", thread_index, " is out of range [0, ",
local_states_.size(), ")");
Expand All @@ -393,7 +392,8 @@ class GroupByNode : public ExecNode {
{"function.options",
aggs_[i].options ? aggs_[i].options->ToString() : "<NULLPTR>"},
{"function.kind", std::string(kind_name()) + "::Consume"}});
KernelContext kernel_ctx{ctx_};
auto ctx = plan_->query_context()->exec_context();
KernelContext kernel_ctx{ctx};
kernel_ctx.SetState(state->agg_states[i].get());

ExecSpan agg_batch({batch[agg_src_field_ids_[i]], ExecValue(*id_batch.array())},
Expand Down Expand Up @@ -429,7 +429,9 @@ class GroupByNode : public ExecNode {
{"function.options",
aggs_[i].options ? aggs_[i].options->ToString() : "<NULLPTR>"},
{"function.kind", std::string(kind_name()) + "::Merge"}});
KernelContext batch_ctx{ctx_};

auto ctx = plan_->query_context()->exec_context();
KernelContext batch_ctx{ctx};
DCHECK(state0->agg_states[i]);
batch_ctx.SetState(state0->agg_states[i].get());

Expand Down Expand Up @@ -462,7 +464,7 @@ class GroupByNode : public ExecNode {
{"function.options",
aggs_[i].options ? aggs_[i].options->ToString() : "<NULLPTR>"},
{"function.kind", std::string(kind_name()) + "::Finalize"}});
KernelContext batch_ctx{ctx_};
KernelContext batch_ctx{plan_->query_context()->exec_context()};
batch_ctx.SetState(state->agg_states[i].get());
RETURN_NOT_OK(agg_kernels_[i]->finalize(&batch_ctx, &out_data.values[i]));
state->agg_states[i].reset();
Expand Down Expand Up @@ -497,7 +499,7 @@ class GroupByNode : public ExecNode {

int64_t num_output_batches = bit_util::CeilDiv(out_data_.length, output_batch_size());
outputs_[0]->InputFinished(this, static_cast<int>(num_output_batches));
RETURN_NOT_OK(plan_->StartTaskGroup(output_task_group_id_, num_output_batches));
RETURN_NOT_OK(plan_->query_context()->StartTaskGroup(output_task_group_id_, num_output_batches));
return Status::OK();
}

Expand Down Expand Up @@ -548,7 +550,7 @@ class GroupByNode : public ExecNode {
{"node.detail", ToString()},
{"node.kind", kind_name()}});

local_states_.resize(plan_->max_concurrency());
local_states_.resize(plan_->query_context()->max_concurrency());
return Status::OK();
}

Expand Down Expand Up @@ -593,7 +595,7 @@ class GroupByNode : public ExecNode {
};

ThreadLocalState* GetLocalState() {
size_t thread_index = plan_->GetThreadIndex();
size_t thread_index = plan_->query_context()->GetThreadIndex();
return &local_states_[thread_index];
}

Expand All @@ -611,7 +613,7 @@ class GroupByNode : public ExecNode {
}

// Construct grouper
ARROW_ASSIGN_OR_RAISE(state->grouper, Grouper::Make(key_types, ctx_));
ARROW_ASSIGN_OR_RAISE(state->grouper, Grouper::Make(key_types, plan_->query_context()->exec_context()));

// Build vector of aggregate source field data types
std::vector<TypeHolder> agg_src_types(agg_kernels_.size());
Expand All @@ -620,21 +622,20 @@ class GroupByNode : public ExecNode {
agg_src_types[i] = input_schema->field(agg_src_field_id)->type().get();
}

ARROW_ASSIGN_OR_RAISE(state->agg_states, internal::InitKernels(agg_kernels_, ctx_,
ARROW_ASSIGN_OR_RAISE(state->agg_states, internal::InitKernels(agg_kernels_, plan_->query_context()->exec_context(),
aggs_, agg_src_types));

return Status::OK();
}

int output_batch_size() const {
int result = static_cast<int>(ctx_->exec_chunksize());
int result = static_cast<int>(plan_->query_context()->exec_context()->exec_chunksize());
if (result < 0) {
result = 32 * 1024;
}
return result;
}

ExecContext* ctx_;
int output_task_group_id_;

const std::vector<int> key_field_ids_;
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/compute/exec/asof_join_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,7 @@ class AsofJoinNode : public ExecNode {
if (dst.empty()) {
return NULLPTR;
} else {
return dst.Materialize(plan()->exec_context()->memory_pool(), output_schema(),
return dst.Materialize(plan()->query_context()->memory_pool(), output_schema(),
state_);
}
}
Expand Down Expand Up @@ -849,7 +849,7 @@ class AsofJoinNode : public ExecNode {
Status Init() override {
auto inputs = this->inputs();
for (size_t i = 0; i < inputs.size(); i++) {
RETURN_NOT_OK(key_hashers_[i]->Init(plan()->exec_context(), output_schema()));
RETURN_NOT_OK(key_hashers_[i]->Init(plan()->query_context()->exec_context(), output_schema()));
state_.push_back(std::make_unique<InputState>(
must_hash_, may_rehash_, key_hashers_[i].get(), inputs[i]->output_schema(),
indices_of_on_key_[i], indices_of_by_key_[i]));
Expand Down
108 changes: 25 additions & 83 deletions cpp/src/arrow/compute/exec/exec_plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,11 @@ namespace compute {
namespace {

struct ExecPlanImpl : public ExecPlan {
explicit ExecPlanImpl(ExecContext* exec_context,
std::shared_ptr<const KeyValueMetadata> metadata = NULLPTR)
: ExecPlan(exec_context), metadata_(std::move(metadata)) {}
explicit ExecPlanImpl(
QueryOptions options,
ExecContext *exec_context,
std::shared_ptr<const KeyValueMetadata> metadata = NULLPTR)
: ExecPlan(options, exec_context), metadata_(std::move(metadata)) {}

~ExecPlanImpl() override {
if (started_ && !finished_.is_finished()) {
Expand All @@ -59,9 +61,6 @@ struct ExecPlanImpl : public ExecPlan {
}
}

size_t GetThreadIndex() { return thread_indexer_(); }
size_t max_concurrency() const { return thread_indexer_.Capacity(); }

ExecNode* AddNode(std::unique_ptr<ExecNode> node) {
if (node->label().empty()) {
node->SetLabel(std::to_string(auto_label_counter_++));
Expand All @@ -76,44 +75,6 @@ struct ExecPlanImpl : public ExecPlan {
return nodes_.back().get();
}

Result<Future<>> BeginExternalTask() {
Future<> completion_future = Future<>::Make();
if (async_scheduler_->AddSimpleTask(
[completion_future] { return completion_future; })) {
return completion_future;
}
return Future<>{};
}

Status ScheduleTask(std::function<Status()> fn) {
auto executor = exec_context_->executor();
if (!executor) return fn();
// Adds a task which submits fn to the executor and tracks its progress. If we're
// aborted then the task is ignored and fn is not executed.
async_scheduler_->AddSimpleTask(
[executor, fn]() { return executor->Submit(std::move(fn)); });
return Status::OK();
}

Status ScheduleTask(std::function<Status(size_t)> fn) {
std::function<Status()> indexed_fn = [this, fn]() {
size_t thread_index = GetThreadIndex();
return fn(thread_index);
};
return ScheduleTask(std::move(indexed_fn));
}

int RegisterTaskGroup(std::function<Status(size_t, int64_t)> task,
std::function<Status(size_t)> on_finished) {
return task_scheduler_->RegisterTaskGroup(std::move(task), std::move(on_finished));
}

Status StartTaskGroup(int task_group_id, int64_t num_tasks) {
return task_scheduler_->StartTaskGroup(GetThreadIndex(), task_group_id, num_tasks);
}

util::AsyncTaskScheduler* async_scheduler() { return async_scheduler_.get(); }

Status Validate() const {
if (nodes_.empty()) {
return Status::Invalid("ExecPlan has no node");
Expand Down Expand Up @@ -141,6 +102,9 @@ struct ExecPlanImpl : public ExecPlan {
return Status::Invalid("restarted ExecPlan");
}

QueryContext *ctx = query_context();
RETURN_NOT_OK(ctx->Init(ctx->max_concurrency()));

std::vector<Future<>> futures;
for (auto& n : nodes_) {
RETURN_NOT_OK(n->Init());
Expand All @@ -152,17 +116,17 @@ struct ExecPlanImpl : public ExecPlan {
EndTaskGroup();
});

task_scheduler_->RegisterEnd();
ctx->scheduler()->RegisterEnd();
int num_threads = 1;
bool sync_execution = true;
if (auto executor = exec_context()->executor()) {
if (auto executor = query_context()->executor()) {
num_threads = executor->GetCapacity();
sync_execution = false;
}
RETURN_NOT_OK(task_scheduler_->StartScheduling(
RETURN_NOT_OK(ctx->scheduler()->StartScheduling(
0 /* thread_index */,
[this](std::function<Status(size_t)> fn) -> Status {
return this->ScheduleTask(std::move(fn));
[ctx](std::function<Status(size_t)> fn) -> Status {
return ctx->ScheduleTask(std::move(fn));
},
/*concurrent_tasks=*/2 * num_threads, sync_execution));

Expand Down Expand Up @@ -198,8 +162,8 @@ struct ExecPlanImpl : public ExecPlan {
void EndTaskGroup() {
bool expected = false;
if (group_ended_.compare_exchange_strong(expected, true)) {
async_scheduler_->End();
async_scheduler_->OnFinished().AddCallback([this](const Status& st) {
query_context()->async_scheduler()->End();
query_context()->async_scheduler()->OnFinished().AddCallback([this](const Status &st) {
MARK_SPAN(span_, error_st_ & st);
END_SPAN(span_);
finished_.MarkFinished(error_st_ & st);
Expand All @@ -211,7 +175,7 @@ struct ExecPlanImpl : public ExecPlan {
DCHECK(started_) << "stopped an ExecPlan which never started";
EVENT(span_, "StopProducing");
stopped_ = true;
task_scheduler_->Abort(
query_context()->scheduler()->Abort(
[this]() { StopProducingImpl(sorted_nodes_.begin(), sorted_nodes_.end()); });
}

Expand Down Expand Up @@ -328,11 +292,7 @@ struct ExecPlanImpl : public ExecPlan {
util::tracing::Span span_;
std::shared_ptr<const KeyValueMetadata> metadata_;

ThreadIndexer thread_indexer_;
std::atomic<bool> group_ended_{false};
std::unique_ptr<util::AsyncTaskScheduler> async_scheduler_ =
util::AsyncTaskScheduler::Make();
std::unique_ptr<TaskScheduler> task_scheduler_ = TaskScheduler::Make();
};

ExecPlanImpl* ToDerived(ExecPlan* ptr) { return checked_cast<ExecPlanImpl*>(ptr); }
Expand All @@ -354,8 +314,15 @@ std::optional<int> GetNodeIndex(const std::vector<ExecNode*>& nodes,
const uint32_t ExecPlan::kMaxBatchSize;

Result<std::shared_ptr<ExecPlan>> ExecPlan::Make(
ExecContext* ctx, std::shared_ptr<const KeyValueMetadata> metadata) {
return std::shared_ptr<ExecPlan>(new ExecPlanImpl{ctx, metadata});
QueryOptions opts,
ExecContext *ctx,
std::shared_ptr<const KeyValueMetadata> metadata) {
return std::shared_ptr<ExecPlan>(new ExecPlanImpl{opts, ctx, std::move(metadata)});
}

Result<std::shared_ptr<ExecPlan>> ExecPlan::Make(
ExecContext *ctx, std::shared_ptr<const KeyValueMetadata> metadata) {
return Make({}, ctx, std::move(metadata));
}

ExecNode* ExecPlan::AddNode(std::unique_ptr<ExecNode> node) {
Expand All @@ -368,31 +335,6 @@ const ExecPlan::NodeVector& ExecPlan::sources() const {

const ExecPlan::NodeVector& ExecPlan::sinks() const { return ToDerived(this)->sinks_; }

size_t ExecPlan::GetThreadIndex() { return ToDerived(this)->GetThreadIndex(); }
size_t ExecPlan::max_concurrency() const { return ToDerived(this)->max_concurrency(); }

Result<Future<>> ExecPlan::BeginExternalTask() {
return ToDerived(this)->BeginExternalTask();
}

Status ExecPlan::ScheduleTask(std::function<Status()> fn) {
return ToDerived(this)->ScheduleTask(std::move(fn));
}
Status ExecPlan::ScheduleTask(std::function<Status(size_t)> fn) {
return ToDerived(this)->ScheduleTask(std::move(fn));
}
int ExecPlan::RegisterTaskGroup(std::function<Status(size_t, int64_t)> task,
std::function<Status(size_t)> on_finished) {
return ToDerived(this)->RegisterTaskGroup(std::move(task), std::move(on_finished));
}
Status ExecPlan::StartTaskGroup(int task_group_id, int64_t num_tasks) {
return ToDerived(this)->StartTaskGroup(task_group_id, num_tasks);
}

util::AsyncTaskScheduler* ExecPlan::async_scheduler() {
return ToDerived(this)->async_scheduler();
}

Status ExecPlan::Validate() { return ToDerived(this)->Validate(); }

Status ExecPlan::StartProducing() { return ToDerived(this)->StartProducing(); }
Expand Down
Loading

0 comments on commit a7744c5

Please sign in to comment.