Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-1065] Adding hashagg w/ filter support #1066

Merged
merged 4 commits into from
Aug 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -504,10 +504,6 @@ case class ColumnarHashAggregateExec(
var res_index = 0
for (expIdx <- aggregateExpressions.indices) {
val exp: AggregateExpression = aggregateExpressions(expIdx)
if (exp.filter.isDefined) {
throw new UnsupportedOperationException(
"filter is not supported in AggregateExpression")
}
val mode = exp.mode
val aggregateFunc = exp.aggregateFunction
aggregateFunc match {
Expand Down Expand Up @@ -692,6 +688,9 @@ case class ColumnarHashAggregateExec(
override def supportColumnarCodegen: Boolean = {
for (expr <- aggregateExpressions) {
// TODO: close the gap in supporting code gen.
if (expr.filter.isDefined) {
return false
}
expr.aggregateFunction match {
case _: First =>
return false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class ColumnarHashAggregation(
var inputAttrQueue: scala.collection.mutable.Queue[Attribute] = _
val resultType = CodeGeneration.getResultType()
val NaN_check : Boolean = GazellePluginConfig.getConf.enableColumnarNaNCheck
var distIndex = 0

def getColumnarFuncNode(expr: Expression): TreeNode = {
try {
Expand Down Expand Up @@ -155,14 +156,25 @@ class ColumnarHashAggregation(
case Partial =>
val childrenColumnarFuncNodeList =
aggregateFunc.children.toList.map(expr => getColumnarFuncNode(expr))
if (aggregateFunc.children(0).isInstanceOf[Literal]) {
if (aggregateFunc.children(0).isInstanceOf[Literal] && !aggregateExpression.filter.isDefined) {
TreeBuilder.makeFunction(
s"action_countLiteral_${aggregateFunc.children(0)}",
Lists.newArrayList(),
resultType)
} else {
TreeBuilder
.makeFunction("action_count", childrenColumnarFuncNodeList.asJava, resultType)
if (aggregateExpression.filter.isDefined) {
val filterColumnarFuncNodeList = List(getColumnarFuncNode(aggregateExpression.filter.get))
distIndex += 1
// TODO(): rename this to coundFilter?
TreeBuilder
.makeFunction(s"action_countDistinct_${distIndex}",
(childrenColumnarFuncNodeList ::: filterColumnarFuncNodeList).asJava,
resultType)
} else {
TreeBuilder
.makeFunction("action_count", childrenColumnarFuncNodeList.asJava, resultType)
}

}
case Final | PartialMerge =>
val childrenColumnarFuncNodeList =
Expand Down
167 changes: 167 additions & 0 deletions native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ arrow::Status ActionBase::EvaluateCountLiteral(const int& len) {
return arrow::Status::NotImplemented("ActionBase EvaluateCountLiteral is abstract.");
}

arrow::Status ActionBase::EvaluateCountDistinct(const arrow::ArrayVector& in) {
return arrow::Status::NotImplemented("ActionBase EvaluateCountDistinct is abstract.");
}

arrow::Status ActionBase::Evaluate(int dest_group_id) {
return arrow::Status::NotImplemented("ActionBase Evaluate is abstract.");
}
Expand Down Expand Up @@ -118,6 +122,8 @@ arrow::Status ActionBase::FinishAndReset(ArrayList* out) {

uint64_t ActionBase::GetResultLength() { return 0; }

std::string ActionBase::getName() { return ""; }

//////////////// UniqueAction ///////////////
template <typename DataType, typename CType>
class UniqueAction : public ActionBase {
Expand Down Expand Up @@ -438,6 +444,158 @@ class CountAction : public ActionBase {
uint64_t length_ = 0;
};

//////////////// CountDistinctAction ///////////////
template <typename DataType>
class CountDistinctAction : public ActionBase {
public:
CountDistinctAction(arrow::compute::ExecContext* ctx, int arg)
: ctx_(ctx), localGid_(arg) {
#ifdef DEBUG
std::cout << "Construct CountDistinctAction" << std::endl;
#endif
std::unique_ptr<arrow::ArrayBuilder> array_builder;
arrow::MakeBuilder(ctx_->memory_pool(), arrow::TypeTraits<DataType>::type_singleton(),
&array_builder);
builder_.reset(
arrow::internal::checked_cast<ResBuilderType*>(array_builder.release()));
}
~CountDistinctAction() {
#ifdef DEBUG
std::cout << "Destruct CountDistinctAction" << std::endl;
#endif
}
std::string getName() { return "CountDistinctAction"; }
arrow::Status Submit(ArrayList in_list, int max_group_id,
std::function<arrow::Status(int)>* on_valid,
std::function<arrow::Status()>* on_null) override {
// resize result data
if (cache_.size() <= max_group_id) {
cache_.resize(max_group_id + 1, 0);
length_ = cache_.size();
}

// prepare evaluate lambda
*on_valid = [this](int dest_group_id) {
cache_[dest_group_id] += 1;
return arrow::Status::OK();
};

*on_null = [this]() { return arrow::Status::OK(); };
return arrow::Status::OK();
}

arrow::Status GrowByFactor(int dest_group_id) {
int max_group_id;
if (cache_.size() < 128) {
max_group_id = 128;
} else {
max_group_id = cache_.size() * 2;
}
cache_.resize(max_group_id, 0);
return arrow::Status::OK();
}

arrow::Status EvaluateCountLiteral(const int& len) {
if (cache_.empty()) {
cache_.resize(1, 0);
length_ = 1;
}
cache_[0] += len;
return arrow::Status::OK();
}

arrow::Status EvaluateCountDistinct(const arrow::ArrayVector& in) {
if (cache_.empty()) {
cache_.resize(1, 0);
length_ = 1;
}
// at least two arrays, count attrs and gid
assert(in.size() > 1);
int gid = in.size() - 1;
std::shared_ptr<arrow::BooleanArray> typed_key_in =
std::dynamic_pointer_cast<arrow::BooleanArray>(in[gid]);
int length = in[0]->length();
int count_non_null = 0;
int count_null = 0;
for (size_t id = 0; id < length; id++) {
if (typed_key_in->GetView(id) == 0) {
count_null++;
continue;
}
for (int colId = 0; colId < in.size() - 1; colId++) {
if (in[colId]->IsNull(id)) {
count_null++;
break;
}
}
}
count_non_null = length - count_null;
cache_[0] += count_non_null;
return arrow::Status::OK();
}

arrow::Status Evaluate(const arrow::ArrayVector& in) {
return arrow::Status::NotImplemented(
"CountDistinctAction Non-Groupby Evaluate is unsupported.");
}

arrow::Status Evaluate(int dest_group_id) {
auto target_group_size = dest_group_id + 1;
if (cache_.size() <= target_group_size) GrowByFactor(target_group_size);
if (length_ < target_group_size) length_ = target_group_size;
cache_[dest_group_id] += 1;
return arrow::Status::OK();
}

arrow::Status EvaluateNull(int dest_group_id) {
auto target_group_size = dest_group_id + 1;
if (cache_.size() <= target_group_size) GrowByFactor(target_group_size);
if (length_ < target_group_size) length_ = target_group_size;
return arrow::Status::OK();
}

arrow::Status Finish(ArrayList* out) override {
std::shared_ptr<arrow::Array> arr_out;
builder_->Reset();
auto length = GetResultLength();
for (uint64_t i = 0; i < length; i++) {
builder_->Append(cache_[i]);
}
RETURN_NOT_OK(builder_->Finish(&arr_out));
out->push_back(arr_out);

return arrow::Status::OK();
}

uint64_t GetResultLength() { return length_; }

arrow::Status Finish(uint64_t offset, uint64_t length, ArrayList* out) override {
std::shared_ptr<arrow::Array> arr_out;
builder_->Reset();
auto res_length = (offset + length) > length_ ? (length_ - offset) : length;
for (uint64_t i = 0; i < res_length; i++) {
builder_->Append(cache_[offset + i]);
}

RETURN_NOT_OK(builder_->Finish(&arr_out));
out->push_back(arr_out);
return arrow::Status::OK();
}

private:
using ResArrayType = typename arrow::TypeTraits<DataType>::ArrayType;
using ResBuilderType = typename arrow::TypeTraits<DataType>::BuilderType;
// input
arrow::compute::ExecContext* ctx_;
// for debug only
int32_t localGid_ = -1;
// result
using CType = typename arrow::TypeTraits<DataType>::CType;
std::vector<CType> cache_;
std::unique_ptr<ResBuilderType> builder_;
uint64_t length_ = 0;
};

//////////////// CountLiteralAction ///////////////
template <typename DataType>
class CountLiteralAction : public ActionBase {
Expand Down Expand Up @@ -5868,6 +6026,15 @@ arrow::Status MakeCountLiteralAction(
return arrow::Status::OK();
}

arrow::Status MakeCountDistinctAction(
arrow::compute::ExecContext* ctx, int arg,
std::vector<std::shared_ptr<arrow::DataType>> res_type_list,
std::shared_ptr<ActionBase>* out) {
auto action_ptr = std::make_shared<CountDistinctAction<arrow::Int64Type>>(ctx, arg);
*out = std::dynamic_pointer_cast<ActionBase>(action_ptr);
return arrow::Status::OK();
}

arrow::Status MakeMinAction(arrow::compute::ExecContext* ctx,
std::shared_ptr<arrow::DataType> type,
std::vector<std::shared_ptr<arrow::DataType>> res_type_list,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class ActionBase {
std::function<arrow::Status()>* on_null);
virtual arrow::Status EvaluateCountLiteral(const int& len);
virtual arrow::Status Evaluate(const arrow::ArrayVector& in);
virtual arrow::Status EvaluateCountDistinct(const arrow::ArrayVector& in);
virtual arrow::Status Evaluate(int dest_group_id);
virtual arrow::Status Evaluate(int dest_group_id, void* data);
virtual arrow::Status Evaluate(int dest_group_id, void* data1, void* data2);
Expand All @@ -64,6 +65,7 @@ class ActionBase {
virtual arrow::Status Finish(uint64_t offset, uint64_t length, ArrayList* out);
virtual arrow::Status FinishAndReset(ArrayList* out);
virtual uint64_t GetResultLength();
virtual std::string getName();
};

arrow::Status MakeUniqueAction(
Expand All @@ -80,6 +82,11 @@ arrow::Status MakeCountLiteralAction(
std::vector<std::shared_ptr<arrow::DataType>> res_type_list,
std::shared_ptr<ActionBase>* out);

arrow::Status MakeCountDistinctAction(
arrow::compute::ExecContext* ctx, int arg,
std::vector<std::shared_ptr<arrow::DataType>> res_type_list,
std::shared_ptr<ActionBase>* out);

arrow::Status MakeSumAction(arrow::compute::ExecContext* ctx,
std::shared_ptr<arrow::DataType> type,
std::vector<std::shared_ptr<arrow::DataType>> res_type_list,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ class HashAggregateKernel::Impl {
std::vector<std::shared_ptr<gandiva::Node>> result_field_node_list,
std::vector<std::shared_ptr<gandiva::Node>> result_expr_node_list)
: ctx_(ctx), action_list_(action_list) {
#ifdef DEBUG
std::cout << "============ make hashagg kernel ============ " << std::endl;
#endif
// if there is projection inside aggregate, we need to extract them into
// projector_list
for (auto node : input_field_list) {
Expand Down Expand Up @@ -147,6 +150,10 @@ class HashAggregateKernel::Impl {
// 1. create pre project
std::shared_ptr<GandivaProjector> pre_process_projector;
if (!prepare_function_list_.empty()) {
#ifdef DEBUG
std::cout << "gandiva schema: " << arrow::schema(input_field_list_)->ToString()
<< std::endl;
#endif
auto pre_process_expr_list = GetGandivaKernel(prepare_function_list_);
pre_process_projector = std::make_shared<GandivaProjector>(
ctx_, arrow::schema(input_field_list_), pre_process_expr_list);
Expand Down Expand Up @@ -752,6 +759,11 @@ class HashAggregateKernel::Impl {
result_id += 1;
RETURN_NOT_OK(
MakeFirstFinalAction(ctx_, action_input_type, res_type_list, &action));
} else if (action_name.compare(0, 21, "action_countDistinct_") == 0) {
auto res_type_list = {result_field_list[result_id]};
result_id += 1;
int arg = std::stol(action_name.substr(21));
RETURN_NOT_OK(MakeCountDistinctAction(ctx_, arg, res_type_list, &action));
} else {
return arrow::Status::NotImplemented(action_name, " is not implementetd.");
}
Expand Down Expand Up @@ -783,6 +795,9 @@ class HashAggregateKernel::Impl {
pre_process_projector_(pre_process_projector),
post_process_projector_(post_process_projector),
action_impl_list_(action_impl_list) {
#ifdef DEBUG
std::cout << "using numberic hashagg res" << std::endl;
#endif
batch_size_ = GetBatchSize();
aggr_hash_table_ = std::make_shared<SparseHashMap<T>>(ctx->memory_pool());
}
Expand Down Expand Up @@ -813,6 +828,8 @@ class HashAggregateKernel::Impl {
// literal
RETURN_NOT_OK(action->EvaluateCountLiteral(in[0]->length()));

} else if (action->getName() == "CountDistinctAction") {
RETURN_NOT_OK(action->EvaluateCountDistinct(cols));
} else {
RETURN_NOT_OK(action->Evaluate(cols));
}
Expand Down