From 30bd935795a9cf2d1a33f8022f99a80d793f675d Mon Sep 17 00:00:00 2001 From: Rui Mo Date: Mon, 17 May 2021 16:21:37 +0800 Subject: [PATCH] refine the build check for aggregate --- .../execution/ColumnarHashAggregateExec.scala | 158 ++++++++++++++++-- .../expression/ColumnarHashAggregation.scala | 17 +- .../codegen/arrow_compute/ext/actions_impl.cc | 1 + 3 files changed, 153 insertions(+), 23 deletions(-) diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarHashAggregateExec.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarHashAggregateExec.scala index 01d07138c..0831bf276 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarHashAggregateExec.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarHashAggregateExec.scala @@ -378,29 +378,155 @@ case class ColumnarHashAggregateExec( } ColumnarProjection.buildCheck(child.output, groupingExpressions) ColumnarProjection.buildCheck(child.output, resultExpressions) - // check aggregate expressions - checkAggregate(aggregateExpressions) + // check the supported types and modes for different aggregate functions + checkTypeAndAggrFunction(aggregateExpressions, aggregateAttributes) } - def checkAggregate(aggregateExpressions: Seq[AggregateExpression]): Unit = { - for (expr <- aggregateExpressions) { - val mode = expr.mode - val aggregateFunction = expr.aggregateFunction - aggregateFunction match { - case Average(_) | Sum(_) | Count(_) | Max(_) | Min(_) => - case StddevSamp(_, _) => + // This method checks the supported types and modes for different aggregate functions. + def checkTypeAndAggrFunction(aggregateExpressions: Seq[AggregateExpression], + aggregateAttributeList: Seq[Attribute]): Unit = { + var res_index = 0 + for (expIdx <- aggregateExpressions.indices) { + val exp: AggregateExpression = aggregateExpressions(expIdx) + val mode = exp.mode + val aggregateFunc = exp.aggregateFunction + aggregateFunc match { + case Average(_) => + val supportedTypes = List(ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DateType, BooleanType) mode match { - case Partial | Final => + case Partial => { + val avg = aggregateFunc.asInstanceOf[Average] + val aggBufferAttr = avg.inputAggBufferAttributes + for (index <- aggBufferAttr.indices) { + val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index)) + if (supportedTypes.indexOf(attr.dataType) == -1 && + !attr.dataType.isInstanceOf[DecimalType]) { + throw new UnsupportedOperationException( + s"${attr.dataType} is not supported in Columnar Average") + } + } + res_index += 2 + } + case PartialMerge => res_index += 1 + case Final => res_index += 1 + case other => + throw new UnsupportedOperationException( + s"${other} is not supported in Columnar Average") + } + case Sum(_) => + val supportedTypes = List(ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DateType, BooleanType) + mode match { + case Partial | PartialMerge => { + val sum = aggregateFunc.asInstanceOf[Sum] + val aggBufferAttr = sum.inputAggBufferAttributes + if (aggBufferAttr.size == 2) { + // decimal sum check sum.resultType + val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head) + if (supportedTypes.indexOf(attr.dataType) == -1 && + !attr.dataType.isInstanceOf[DecimalType]) { + throw new UnsupportedOperationException( + s"${attr.dataType} is not supported in Columnar Sum") + } + res_index += 2 + } else { + val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head) + if (supportedTypes.indexOf(attr.dataType) == -1 && + !attr.dataType.isInstanceOf[DecimalType]) { + throw new UnsupportedOperationException( + s"${attr.dataType} is not supported in Columnar Sum") + } + res_index += 1 + } + } + case Final => res_index += 1 + case other => + throw new UnsupportedOperationException( + s"${other} is not supported in Columnar Sum") + } + case Count(_) => + mode match { + case Partial | PartialMerge | Final => { + res_index += 1 + } + case other => + throw new UnsupportedOperationException( + s"${other} is not supported in Columnar Count") + } + case Max(_) => + val supportedTypes = List(ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DateType, BooleanType, StringType) + mode match { + case Partial => { + val max = aggregateFunc.asInstanceOf[Max] + val aggBufferAttr = max.inputAggBufferAttributes + val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head) + if (supportedTypes.indexOf(attr.dataType) == -1 && + !attr.dataType.isInstanceOf[DecimalType]) { + throw new UnsupportedOperationException( + s"${attr.dataType} is not supported in Columnar Max") + } + res_index += 1 + } + case PartialMerge | Final => res_index += 1 case other => throw new UnsupportedOperationException(s"not currently supported: $other.") } + case Min(_) => + val supportedTypes = List(ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, DateType, BooleanType, StringType) + mode match { + case Partial => { + val min = aggregateFunc.asInstanceOf[Min] + val aggBufferAttr = min.inputAggBufferAttributes + val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head) + if (supportedTypes.indexOf(attr.dataType) == -1 && + !attr.dataType.isInstanceOf[DecimalType]) { + throw new UnsupportedOperationException( + s"${attr.dataType} is not supported in Columnar Min") + } + res_index += 1 + } + case PartialMerge | Final => res_index += 1 + case other => + throw new UnsupportedOperationException( + s"${other} is not supported in Columnar Min") + } + case StddevSamp(_,_) => + mode match { + case Partial => { + val supportedTypes = List(ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, BooleanType) + val stddevSamp = aggregateFunc.asInstanceOf[StddevSamp] + val aggBufferAttr = stddevSamp.inputAggBufferAttributes + for (index <- aggBufferAttr.indices) { + val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index)) + if (supportedTypes.indexOf(attr.dataType) == -1 && + !attr.dataType.isInstanceOf[DecimalType]) { + throw new UnsupportedOperationException( + s"${attr.dataType} is not supported in Columnar StddevSampPartial") + } + } + res_index += 3 + } + case Final => { + val supportedTypes = List(ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType) + val attr = aggregateAttributeList(res_index) + if (supportedTypes.indexOf(attr.dataType) == -1) { + throw new UnsupportedOperationException( + s"${attr.dataType} is not supported in Columnar StddevSampFinal") + } + res_index += 1 + } + case other => + throw new UnsupportedOperationException( + s"${other} is not supported in Columnar StddevSamp") + } case other => - throw new UnsupportedOperationException(s"not currently supported: $other.") - } - mode match { - case Partial | PartialMerge | Final => - case other => - throw new UnsupportedOperationException(s"not currently supported: $other.") + throw new UnsupportedOperationException( + s"${other} is not supported in ColumnarAggregation") } } } diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarHashAggregation.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarHashAggregation.scala index 29e3a3842..c484fef54 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarHashAggregation.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarHashAggregation.scala @@ -128,20 +128,23 @@ class ColumnarHashAggregation( case Sum(_) => mode match { case Partial => - val childrenColumnarFuncNodeList = - aggregateFunc.children.toList.map(expr => getColumnarFuncNode(expr)) - TreeBuilder.makeFunction("action_sum_partial", childrenColumnarFuncNodeList.asJava, resultType) + val childrenColumnarFuncNodeList = + aggregateFunc.children.toList.map(expr => getColumnarFuncNode(expr)) + TreeBuilder.makeFunction( + "action_sum_partial", + childrenColumnarFuncNodeList.asJava, resultType) case Final | PartialMerge => - val childrenColumnarFuncNodeList = - List(inputAttrQueue.dequeue).map(attr => getColumnarFuncNode(attr)) + val childrenColumnarFuncNodeList = + List(inputAttrQueue.dequeue).map(attr => getColumnarFuncNode(attr)) //FIXME(): decimal adds isEmpty column val sum = aggregateFunc.asInstanceOf[Sum] val attrBuf = sum.inputAggBufferAttributes if (attrBuf.size == 2) { inputAttrQueue.dequeue } - - TreeBuilder.makeFunction("action_sum", childrenColumnarFuncNodeList.asJava, resultType) + TreeBuilder.makeFunction( + "action_sum", + childrenColumnarFuncNodeList.asJava, resultType) case other => throw new UnsupportedOperationException(s"not currently supported: $other.") } diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.cc b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.cc index 65359b23c..8c36e179e 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.cc +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.cc @@ -1444,6 +1444,7 @@ class MaxAction> } } } + return arrow::Status::OK(); } arrow::Status Evaluate(int dest_group_id, void* data) {