From 9eeddcab4d7b4e81a7609a8e2ca17127e8c7bb55 Mon Sep 17 00:00:00 2001 From: Yuan Date: Tue, 16 Aug 2022 22:59:07 +0800 Subject: [PATCH] [NSE-1065] Adding hashagg w/ filter support (#1066) * Adding hashagg w/ filter support This patch adds support for hashagg w/ filter Signed-off-by: Yuan Zhou * fix tests Signed-off-by: Yuan Zhou * remove comment Signed-off-by: Yuan Zhou * disable codegen for hashagg w/ filter first Signed-off-by: Yuan Zhou Signed-off-by: Yuan Zhou --- .../execution/ColumnarHashAggregateExec.scala | 7 +- .../expression/ColumnarHashAggregation.scala | 18 +- .../codegen/arrow_compute/ext/actions_impl.cc | 167 ++++++++++++++++++ .../codegen/arrow_compute/ext/actions_impl.h | 7 + .../ext/hash_aggregate_kernel.cc | 17 ++ 5 files changed, 209 insertions(+), 7 deletions(-) diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarHashAggregateExec.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarHashAggregateExec.scala index 3146a5933..56ec8bc1d 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarHashAggregateExec.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarHashAggregateExec.scala @@ -504,10 +504,6 @@ case class ColumnarHashAggregateExec( var res_index = 0 for (expIdx <- aggregateExpressions.indices) { val exp: AggregateExpression = aggregateExpressions(expIdx) - if (exp.filter.isDefined) { - throw new UnsupportedOperationException( - "filter is not supported in AggregateExpression") - } val mode = exp.mode val aggregateFunc = exp.aggregateFunction aggregateFunc match { @@ -692,6 +688,9 @@ case class ColumnarHashAggregateExec( override def supportColumnarCodegen: Boolean = { for (expr <- aggregateExpressions) { // TODO: close the gap in supporting code gen. 
+ if (expr.filter.isDefined) { + return false + } expr.aggregateFunction match { case _: First => return false diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarHashAggregation.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarHashAggregation.scala index 14ad8d370..70868ee3f 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarHashAggregation.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarHashAggregation.scala @@ -71,6 +71,7 @@ class ColumnarHashAggregation( var inputAttrQueue: scala.collection.mutable.Queue[Attribute] = _ val resultType = CodeGeneration.getResultType() val NaN_check : Boolean = GazellePluginConfig.getConf.enableColumnarNaNCheck + var distIndex = 0 def getColumnarFuncNode(expr: Expression): TreeNode = { try { @@ -155,14 +156,25 @@ class ColumnarHashAggregation( case Partial => val childrenColumnarFuncNodeList = aggregateFunc.children.toList.map(expr => getColumnarFuncNode(expr)) - if (aggregateFunc.children(0).isInstanceOf[Literal]) { + if (aggregateFunc.children(0).isInstanceOf[Literal] && !aggregateExpression.filter.isDefined) { TreeBuilder.makeFunction( s"action_countLiteral_${aggregateFunc.children(0)}", Lists.newArrayList(), resultType) } else { - TreeBuilder - .makeFunction("action_count", childrenColumnarFuncNodeList.asJava, resultType) + if (aggregateExpression.filter.isDefined) { + val filterColumnarFuncNodeList = List(getColumnarFuncNode(aggregateExpression.filter.get)) + distIndex += 1 + // TODO(): rename this to countFilter? 
+ TreeBuilder + .makeFunction(s"action_countDistinct_${distIndex}", + (childrenColumnarFuncNodeList ::: filterColumnarFuncNodeList).asJava, + resultType) + } else { + TreeBuilder + .makeFunction("action_count", childrenColumnarFuncNodeList.asJava, resultType) + } + } case Final | PartialMerge => val childrenColumnarFuncNodeList = diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.cc b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.cc index 898afa495..ef2f7379c 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.cc +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.cc @@ -79,6 +79,10 @@ arrow::Status ActionBase::EvaluateCountLiteral(const int& len) { return arrow::Status::NotImplemented("ActionBase EvaluateCountLiteral is abstract."); } +arrow::Status ActionBase::EvaluateCountDistinct(const arrow::ArrayVector& in) { + return arrow::Status::NotImplemented("ActionBase EvaluateCountDistinct is abstract."); +} + arrow::Status ActionBase::Evaluate(int dest_group_id) { return arrow::Status::NotImplemented("ActionBase Evaluate is abstract."); } @@ -118,6 +122,8 @@ arrow::Status ActionBase::FinishAndReset(ArrayList* out) { uint64_t ActionBase::GetResultLength() { return 0; } +std::string ActionBase::getName() { return ""; } + //////////////// UniqueAction /////////////// template class UniqueAction : public ActionBase { @@ -438,6 +444,158 @@ class CountAction : public ActionBase { uint64_t length_ = 0; }; +//////////////// CountDistinctAction /////////////// +template +class CountDistinctAction : public ActionBase { + public: + CountDistinctAction(arrow::compute::ExecContext* ctx, int arg) + : ctx_(ctx), localGid_(arg) { +#ifdef DEBUG + std::cout << "Construct CountDistinctAction" << std::endl; +#endif + std::unique_ptr array_builder; + arrow::MakeBuilder(ctx_->memory_pool(), arrow::TypeTraits::type_singleton(), + &array_builder); + builder_.reset( + 
arrow::internal::checked_cast(array_builder.release())); + } + ~CountDistinctAction() { +#ifdef DEBUG + std::cout << "Destruct CountDistinctAction" << std::endl; +#endif + } + std::string getName() { return "CountDistinctAction"; } + arrow::Status Submit(ArrayList in_list, int max_group_id, + std::function* on_valid, + std::function* on_null) override { + // resize result data + if (cache_.size() <= max_group_id) { + cache_.resize(max_group_id + 1, 0); + length_ = cache_.size(); + } + + // prepare evaluate lambda + *on_valid = [this](int dest_group_id) { + cache_[dest_group_id] += 1; + return arrow::Status::OK(); + }; + + *on_null = [this]() { return arrow::Status::OK(); }; + return arrow::Status::OK(); + } + + arrow::Status GrowByFactor(int dest_group_id) { + int max_group_id; + if (cache_.size() < 128) { + max_group_id = 128; + } else { + max_group_id = cache_.size() * 2; + } + cache_.resize(max_group_id, 0); + return arrow::Status::OK(); + } + + arrow::Status EvaluateCountLiteral(const int& len) { + if (cache_.empty()) { + cache_.resize(1, 0); + length_ = 1; + } + cache_[0] += len; + return arrow::Status::OK(); + } + + arrow::Status EvaluateCountDistinct(const arrow::ArrayVector& in) { + if (cache_.empty()) { + cache_.resize(1, 0); + length_ = 1; + } + // at least two arrays, count attrs and gid + assert(in.size() > 1); + int gid = in.size() - 1; + std::shared_ptr typed_key_in = + std::dynamic_pointer_cast(in[gid]); + int length = in[0]->length(); + int count_non_null = 0; + int count_null = 0; + for (size_t id = 0; id < length; id++) { + if (typed_key_in->GetView(id) == 0) { + count_null++; + continue; + } + for (int colId = 0; colId < in.size() - 1; colId++) { + if (in[colId]->IsNull(id)) { + count_null++; + break; + } + } + } + count_non_null = length - count_null; + cache_[0] += count_non_null; + return arrow::Status::OK(); + } + + arrow::Status Evaluate(const arrow::ArrayVector& in) { + return arrow::Status::NotImplemented( + "CountDistinctAction 
Non-Groupby Evaluate is unsupported."); + } + + arrow::Status Evaluate(int dest_group_id) { + auto target_group_size = dest_group_id + 1; + if (cache_.size() <= target_group_size) GrowByFactor(target_group_size); + if (length_ < target_group_size) length_ = target_group_size; + cache_[dest_group_id] += 1; + return arrow::Status::OK(); + } + + arrow::Status EvaluateNull(int dest_group_id) { + auto target_group_size = dest_group_id + 1; + if (cache_.size() <= target_group_size) GrowByFactor(target_group_size); + if (length_ < target_group_size) length_ = target_group_size; + return arrow::Status::OK(); + } + + arrow::Status Finish(ArrayList* out) override { + std::shared_ptr arr_out; + builder_->Reset(); + auto length = GetResultLength(); + for (uint64_t i = 0; i < length; i++) { + builder_->Append(cache_[i]); + } + RETURN_NOT_OK(builder_->Finish(&arr_out)); + out->push_back(arr_out); + + return arrow::Status::OK(); + } + + uint64_t GetResultLength() { return length_; } + + arrow::Status Finish(uint64_t offset, uint64_t length, ArrayList* out) override { + std::shared_ptr arr_out; + builder_->Reset(); + auto res_length = (offset + length) > length_ ? 
(length_ - offset) : length; + for (uint64_t i = 0; i < res_length; i++) { + builder_->Append(cache_[offset + i]); + } + + RETURN_NOT_OK(builder_->Finish(&arr_out)); + out->push_back(arr_out); + return arrow::Status::OK(); + } + + private: + using ResArrayType = typename arrow::TypeTraits::ArrayType; + using ResBuilderType = typename arrow::TypeTraits::BuilderType; + // input + arrow::compute::ExecContext* ctx_; + // for debug only + int32_t localGid_ = -1; + // result + using CType = typename arrow::TypeTraits::CType; + std::vector cache_; + std::unique_ptr builder_; + uint64_t length_ = 0; +}; + //////////////// CountLiteralAction /////////////// template class CountLiteralAction : public ActionBase { @@ -5868,6 +6026,15 @@ arrow::Status MakeCountLiteralAction( return arrow::Status::OK(); } +arrow::Status MakeCountDistinctAction( + arrow::compute::ExecContext* ctx, int arg, + std::vector> res_type_list, + std::shared_ptr* out) { + auto action_ptr = std::make_shared>(ctx, arg); + *out = std::dynamic_pointer_cast(action_ptr); + return arrow::Status::OK(); +} + arrow::Status MakeMinAction(arrow::compute::ExecContext* ctx, std::shared_ptr type, std::vector> res_type_list, diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.h b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.h index ddf6b2a9f..7e025f37f 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.h +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.h @@ -54,6 +54,7 @@ class ActionBase { std::function* on_null); virtual arrow::Status EvaluateCountLiteral(const int& len); virtual arrow::Status Evaluate(const arrow::ArrayVector& in); + virtual arrow::Status EvaluateCountDistinct(const arrow::ArrayVector& in); virtual arrow::Status Evaluate(int dest_group_id); virtual arrow::Status Evaluate(int dest_group_id, void* data); virtual arrow::Status Evaluate(int dest_group_id, void* data1, void* data2); @@ -64,6 +65,7 @@ class 
ActionBase { virtual arrow::Status Finish(uint64_t offset, uint64_t length, ArrayList* out); virtual arrow::Status FinishAndReset(ArrayList* out); virtual uint64_t GetResultLength(); + virtual std::string getName(); }; arrow::Status MakeUniqueAction( @@ -80,6 +82,11 @@ arrow::Status MakeCountLiteralAction( std::vector> res_type_list, std::shared_ptr* out); +arrow::Status MakeCountDistinctAction( + arrow::compute::ExecContext* ctx, int arg, + std::vector> res_type_list, + std::shared_ptr* out); + arrow::Status MakeSumAction(arrow::compute::ExecContext* ctx, std::shared_ptr type, std::vector> res_type_list, diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/hash_aggregate_kernel.cc b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/hash_aggregate_kernel.cc index c09d74c72..c648d8d9d 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/hash_aggregate_kernel.cc +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/hash_aggregate_kernel.cc @@ -56,6 +56,9 @@ class HashAggregateKernel::Impl { std::vector> result_field_node_list, std::vector> result_expr_node_list) : ctx_(ctx), action_list_(action_list) { +#ifdef DEBUG + std::cout << "============ make hashagg kernel ============ " << std::endl; +#endif // if there is projection inside aggregate, we need to extract them into // projector_list for (auto node : input_field_list) { @@ -147,6 +150,10 @@ class HashAggregateKernel::Impl { // 1. 
create pre project std::shared_ptr pre_process_projector; if (!prepare_function_list_.empty()) { +#ifdef DEBUG + std::cout << "gandiva schema: " << arrow::schema(input_field_list_)->ToString() + << std::endl; +#endif auto pre_process_expr_list = GetGandivaKernel(prepare_function_list_); pre_process_projector = std::make_shared( ctx_, arrow::schema(input_field_list_), pre_process_expr_list); @@ -752,6 +759,11 @@ class HashAggregateKernel::Impl { result_id += 1; RETURN_NOT_OK( MakeFirstFinalAction(ctx_, action_input_type, res_type_list, &action)); + } else if (action_name.compare(0, 21, "action_countDistinct_") == 0) { + auto res_type_list = {result_field_list[result_id]}; + result_id += 1; + int arg = std::stol(action_name.substr(21)); + RETURN_NOT_OK(MakeCountDistinctAction(ctx_, arg, res_type_list, &action)); } else { return arrow::Status::NotImplemented(action_name, " is not implementetd."); } @@ -783,6 +795,9 @@ class HashAggregateKernel::Impl { pre_process_projector_(pre_process_projector), post_process_projector_(post_process_projector), action_impl_list_(action_impl_list) { +#ifdef DEBUG + std::cout << "using numberic hashagg res" << std::endl; +#endif batch_size_ = GetBatchSize(); aggr_hash_table_ = std::make_shared>(ctx->memory_pool()); } @@ -813,6 +828,8 @@ class HashAggregateKernel::Impl { // literal RETURN_NOT_OK(action->EvaluateCountLiteral(in[0]->length())); + } else if (action->getName() == "CountDistinctAction") { + RETURN_NOT_OK(action->EvaluateCountDistinct(cols)); } else { RETURN_NOT_OK(action->Evaluate(cols)); }