Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit

Permalink
[NSE-1065] Adding hashagg w/ filter support (#1066)
Browse files Browse the repository at this point in the history
* Adding hashagg w/ filter support

This patch adds support for hashagg w/ filter

Signed-off-by: Yuan Zhou <[email protected]>

* fix tests

Signed-off-by: Yuan Zhou <[email protected]>

* remove comment

Signed-off-by: Yuan Zhou <[email protected]>

* disable codegen for hashagg w/ filter first

Signed-off-by: Yuan Zhou <[email protected]>

Signed-off-by: Yuan Zhou <[email protected]>
  • Loading branch information
zhouyuan authored Aug 16, 2022
1 parent cce191a commit 9eeddca
Show file tree
Hide file tree
Showing 5 changed files with 209 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -504,10 +504,6 @@ case class ColumnarHashAggregateExec(
var res_index = 0
for (expIdx <- aggregateExpressions.indices) {
val exp: AggregateExpression = aggregateExpressions(expIdx)
if (exp.filter.isDefined) {
throw new UnsupportedOperationException(
"filter is not supported in AggregateExpression")
}
val mode = exp.mode
val aggregateFunc = exp.aggregateFunction
aggregateFunc match {
Expand Down Expand Up @@ -692,6 +688,9 @@ case class ColumnarHashAggregateExec(
override def supportColumnarCodegen: Boolean = {
for (expr <- aggregateExpressions) {
// TODO: close the gap in supporting code gen.
if (expr.filter.isDefined) {
return false
}
expr.aggregateFunction match {
case _: First =>
return false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class ColumnarHashAggregation(
var inputAttrQueue: scala.collection.mutable.Queue[Attribute] = _
val resultType = CodeGeneration.getResultType()
val NaN_check : Boolean = GazellePluginConfig.getConf.enableColumnarNaNCheck
var distIndex = 0

def getColumnarFuncNode(expr: Expression): TreeNode = {
try {
Expand Down Expand Up @@ -155,14 +156,25 @@ class ColumnarHashAggregation(
case Partial =>
val childrenColumnarFuncNodeList =
aggregateFunc.children.toList.map(expr => getColumnarFuncNode(expr))
if (aggregateFunc.children(0).isInstanceOf[Literal]) {
if (aggregateFunc.children(0).isInstanceOf[Literal] && !aggregateExpression.filter.isDefined) {
TreeBuilder.makeFunction(
s"action_countLiteral_${aggregateFunc.children(0)}",
Lists.newArrayList(),
resultType)
} else {
TreeBuilder
.makeFunction("action_count", childrenColumnarFuncNodeList.asJava, resultType)
if (aggregateExpression.filter.isDefined) {
val filterColumnarFuncNodeList = List(getColumnarFuncNode(aggregateExpression.filter.get))
distIndex += 1
// TODO(): rename this to coundFilter?
TreeBuilder
.makeFunction(s"action_countDistinct_${distIndex}",
(childrenColumnarFuncNodeList ::: filterColumnarFuncNodeList).asJava,
resultType)
} else {
TreeBuilder
.makeFunction("action_count", childrenColumnarFuncNodeList.asJava, resultType)
}

}
case Final | PartialMerge =>
val childrenColumnarFuncNodeList =
Expand Down
167 changes: 167 additions & 0 deletions native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ arrow::Status ActionBase::EvaluateCountLiteral(const int& len) {
return arrow::Status::NotImplemented("ActionBase EvaluateCountLiteral is abstract.");
}

arrow::Status ActionBase::EvaluateCountDistinct(const arrow::ArrayVector& in) {
return arrow::Status::NotImplemented("ActionBase EvaluateCountDistinct is abstract.");
}

arrow::Status ActionBase::Evaluate(int dest_group_id) {
return arrow::Status::NotImplemented("ActionBase Evaluate is abstract.");
}
Expand Down Expand Up @@ -118,6 +122,8 @@ arrow::Status ActionBase::FinishAndReset(ArrayList* out) {

uint64_t ActionBase::GetResultLength() { return 0; }

std::string ActionBase::getName() { return ""; }

//////////////// UniqueAction ///////////////
template <typename DataType, typename CType>
class UniqueAction : public ActionBase {
Expand Down Expand Up @@ -438,6 +444,158 @@ class CountAction : public ActionBase {
uint64_t length_ = 0;
};

//////////////// CountDistinctAction ///////////////
template <typename DataType>
class CountDistinctAction : public ActionBase {
public:
CountDistinctAction(arrow::compute::ExecContext* ctx, int arg)
: ctx_(ctx), localGid_(arg) {
#ifdef DEBUG
std::cout << "Construct CountDistinctAction" << std::endl;
#endif
std::unique_ptr<arrow::ArrayBuilder> array_builder;
arrow::MakeBuilder(ctx_->memory_pool(), arrow::TypeTraits<DataType>::type_singleton(),
&array_builder);
builder_.reset(
arrow::internal::checked_cast<ResBuilderType*>(array_builder.release()));
}
~CountDistinctAction() {
#ifdef DEBUG
std::cout << "Destruct CountDistinctAction" << std::endl;
#endif
}
std::string getName() { return "CountDistinctAction"; }
arrow::Status Submit(ArrayList in_list, int max_group_id,
std::function<arrow::Status(int)>* on_valid,
std::function<arrow::Status()>* on_null) override {
// resize result data
if (cache_.size() <= max_group_id) {
cache_.resize(max_group_id + 1, 0);
length_ = cache_.size();
}

// prepare evaluate lambda
*on_valid = [this](int dest_group_id) {
cache_[dest_group_id] += 1;
return arrow::Status::OK();
};

*on_null = [this]() { return arrow::Status::OK(); };
return arrow::Status::OK();
}

arrow::Status GrowByFactor(int dest_group_id) {
int max_group_id;
if (cache_.size() < 128) {
max_group_id = 128;
} else {
max_group_id = cache_.size() * 2;
}
cache_.resize(max_group_id, 0);
return arrow::Status::OK();
}

arrow::Status EvaluateCountLiteral(const int& len) {
if (cache_.empty()) {
cache_.resize(1, 0);
length_ = 1;
}
cache_[0] += len;
return arrow::Status::OK();
}

arrow::Status EvaluateCountDistinct(const arrow::ArrayVector& in) {
if (cache_.empty()) {
cache_.resize(1, 0);
length_ = 1;
}
// at least two arrays, count attrs and gid
assert(in.size() > 1);
int gid = in.size() - 1;
std::shared_ptr<arrow::BooleanArray> typed_key_in =
std::dynamic_pointer_cast<arrow::BooleanArray>(in[gid]);
int length = in[0]->length();
int count_non_null = 0;
int count_null = 0;
for (size_t id = 0; id < length; id++) {
if (typed_key_in->GetView(id) == 0) {
count_null++;
continue;
}
for (int colId = 0; colId < in.size() - 1; colId++) {
if (in[colId]->IsNull(id)) {
count_null++;
break;
}
}
}
count_non_null = length - count_null;
cache_[0] += count_non_null;
return arrow::Status::OK();
}

arrow::Status Evaluate(const arrow::ArrayVector& in) {
return arrow::Status::NotImplemented(
"CountDistinctAction Non-Groupby Evaluate is unsupported.");
}

arrow::Status Evaluate(int dest_group_id) {
auto target_group_size = dest_group_id + 1;
if (cache_.size() <= target_group_size) GrowByFactor(target_group_size);
if (length_ < target_group_size) length_ = target_group_size;
cache_[dest_group_id] += 1;
return arrow::Status::OK();
}

arrow::Status EvaluateNull(int dest_group_id) {
auto target_group_size = dest_group_id + 1;
if (cache_.size() <= target_group_size) GrowByFactor(target_group_size);
if (length_ < target_group_size) length_ = target_group_size;
return arrow::Status::OK();
}

arrow::Status Finish(ArrayList* out) override {
std::shared_ptr<arrow::Array> arr_out;
builder_->Reset();
auto length = GetResultLength();
for (uint64_t i = 0; i < length; i++) {
builder_->Append(cache_[i]);
}
RETURN_NOT_OK(builder_->Finish(&arr_out));
out->push_back(arr_out);

return arrow::Status::OK();
}

uint64_t GetResultLength() { return length_; }

arrow::Status Finish(uint64_t offset, uint64_t length, ArrayList* out) override {
std::shared_ptr<arrow::Array> arr_out;
builder_->Reset();
auto res_length = (offset + length) > length_ ? (length_ - offset) : length;
for (uint64_t i = 0; i < res_length; i++) {
builder_->Append(cache_[offset + i]);
}

RETURN_NOT_OK(builder_->Finish(&arr_out));
out->push_back(arr_out);
return arrow::Status::OK();
}

private:
using ResArrayType = typename arrow::TypeTraits<DataType>::ArrayType;
using ResBuilderType = typename arrow::TypeTraits<DataType>::BuilderType;
// input
arrow::compute::ExecContext* ctx_;
// for debug only
int32_t localGid_ = -1;
// result
using CType = typename arrow::TypeTraits<DataType>::CType;
std::vector<CType> cache_;
std::unique_ptr<ResBuilderType> builder_;
uint64_t length_ = 0;
};

//////////////// CountLiteralAction ///////////////
template <typename DataType>
class CountLiteralAction : public ActionBase {
Expand Down Expand Up @@ -5868,6 +6026,15 @@ arrow::Status MakeCountLiteralAction(
return arrow::Status::OK();
}

arrow::Status MakeCountDistinctAction(
arrow::compute::ExecContext* ctx, int arg,
std::vector<std::shared_ptr<arrow::DataType>> res_type_list,
std::shared_ptr<ActionBase>* out) {
auto action_ptr = std::make_shared<CountDistinctAction<arrow::Int64Type>>(ctx, arg);
*out = std::dynamic_pointer_cast<ActionBase>(action_ptr);
return arrow::Status::OK();
}

arrow::Status MakeMinAction(arrow::compute::ExecContext* ctx,
std::shared_ptr<arrow::DataType> type,
std::vector<std::shared_ptr<arrow::DataType>> res_type_list,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class ActionBase {
std::function<arrow::Status()>* on_null);
virtual arrow::Status EvaluateCountLiteral(const int& len);
virtual arrow::Status Evaluate(const arrow::ArrayVector& in);
virtual arrow::Status EvaluateCountDistinct(const arrow::ArrayVector& in);
virtual arrow::Status Evaluate(int dest_group_id);
virtual arrow::Status Evaluate(int dest_group_id, void* data);
virtual arrow::Status Evaluate(int dest_group_id, void* data1, void* data2);
Expand All @@ -64,6 +65,7 @@ class ActionBase {
virtual arrow::Status Finish(uint64_t offset, uint64_t length, ArrayList* out);
virtual arrow::Status FinishAndReset(ArrayList* out);
virtual uint64_t GetResultLength();
virtual std::string getName();
};

arrow::Status MakeUniqueAction(
Expand All @@ -80,6 +82,11 @@ arrow::Status MakeCountLiteralAction(
std::vector<std::shared_ptr<arrow::DataType>> res_type_list,
std::shared_ptr<ActionBase>* out);

arrow::Status MakeCountDistinctAction(
arrow::compute::ExecContext* ctx, int arg,
std::vector<std::shared_ptr<arrow::DataType>> res_type_list,
std::shared_ptr<ActionBase>* out);

arrow::Status MakeSumAction(arrow::compute::ExecContext* ctx,
std::shared_ptr<arrow::DataType> type,
std::vector<std::shared_ptr<arrow::DataType>> res_type_list,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ class HashAggregateKernel::Impl {
std::vector<std::shared_ptr<gandiva::Node>> result_field_node_list,
std::vector<std::shared_ptr<gandiva::Node>> result_expr_node_list)
: ctx_(ctx), action_list_(action_list) {
#ifdef DEBUG
std::cout << "============ make hashagg kernel ============ " << std::endl;
#endif
// if there is projection inside aggregate, we need to extract them into
// projector_list
for (auto node : input_field_list) {
Expand Down Expand Up @@ -147,6 +150,10 @@ class HashAggregateKernel::Impl {
// 1. create pre project
std::shared_ptr<GandivaProjector> pre_process_projector;
if (!prepare_function_list_.empty()) {
#ifdef DEBUG
std::cout << "gandiva schema: " << arrow::schema(input_field_list_)->ToString()
<< std::endl;
#endif
auto pre_process_expr_list = GetGandivaKernel(prepare_function_list_);
pre_process_projector = std::make_shared<GandivaProjector>(
ctx_, arrow::schema(input_field_list_), pre_process_expr_list);
Expand Down Expand Up @@ -752,6 +759,11 @@ class HashAggregateKernel::Impl {
result_id += 1;
RETURN_NOT_OK(
MakeFirstFinalAction(ctx_, action_input_type, res_type_list, &action));
} else if (action_name.compare(0, 21, "action_countDistinct_") == 0) {
auto res_type_list = {result_field_list[result_id]};
result_id += 1;
int arg = std::stol(action_name.substr(21));
RETURN_NOT_OK(MakeCountDistinctAction(ctx_, arg, res_type_list, &action));
} else {
return arrow::Status::NotImplemented(action_name, " is not implementetd.");
}
Expand Down Expand Up @@ -783,6 +795,9 @@ class HashAggregateKernel::Impl {
pre_process_projector_(pre_process_projector),
post_process_projector_(post_process_projector),
action_impl_list_(action_impl_list) {
#ifdef DEBUG
std::cout << "using numberic hashagg res" << std::endl;
#endif
batch_size_ = GetBatchSize();
aggr_hash_table_ = std::make_shared<SparseHashMap<T>>(ctx->memory_pool());
}
Expand Down Expand Up @@ -813,6 +828,8 @@ class HashAggregateKernel::Impl {
// literal
RETURN_NOT_OK(action->EvaluateCountLiteral(in[0]->length()));

} else if (action->getName() == "CountDistinctAction") {
RETURN_NOT_OK(action->EvaluateCountDistinct(cols));
} else {
RETURN_NOT_OK(action->Evaluate(cols));
}
Expand Down

0 comments on commit 9eeddca

Please sign in to comment.