Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit

Permalink
refine the build check for aggregate
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo committed May 17, 2021
1 parent 02ecd3e commit 30bd935
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -378,29 +378,155 @@ case class ColumnarHashAggregateExec(
}
ColumnarProjection.buildCheck(child.output, groupingExpressions)
ColumnarProjection.buildCheck(child.output, resultExpressions)
// check aggregate expressions
checkAggregate(aggregateExpressions)
// check the supported types and modes for different aggregate functions
checkTypeAndAggrFunction(aggregateExpressions, aggregateAttributes)
}

def checkAggregate(aggregateExpressions: Seq[AggregateExpression]): Unit = {
for (expr <- aggregateExpressions) {
val mode = expr.mode
val aggregateFunction = expr.aggregateFunction
aggregateFunction match {
case Average(_) | Sum(_) | Count(_) | Max(_) | Min(_) =>
case StddevSamp(_, _) =>
// This method checks the supported types and modes for different aggregate functions.
def checkTypeAndAggrFunction(aggregateExpressions: Seq[AggregateExpression],
aggregateAttributeList: Seq[Attribute]): Unit = {
var res_index = 0
for (expIdx <- aggregateExpressions.indices) {
val exp: AggregateExpression = aggregateExpressions(expIdx)
val mode = exp.mode
val aggregateFunc = exp.aggregateFunction
aggregateFunc match {
case Average(_) =>
val supportedTypes = List(ByteType, ShortType, IntegerType, LongType,
FloatType, DoubleType, DateType, BooleanType)
mode match {
case Partial | Final =>
case Partial => {
val avg = aggregateFunc.asInstanceOf[Average]
val aggBufferAttr = avg.inputAggBufferAttributes
for (index <- aggBufferAttr.indices) {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index))
if (supportedTypes.indexOf(attr.dataType) == -1 &&
!attr.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in Columnar Average")
}
}
res_index += 2
}
case PartialMerge => res_index += 1
case Final => res_index += 1
case other =>
throw new UnsupportedOperationException(
s"${other} is not supported in Columnar Average")
}
case Sum(_) =>
val supportedTypes = List(ByteType, ShortType, IntegerType, LongType,
FloatType, DoubleType, DateType, BooleanType)
mode match {
case Partial | PartialMerge => {
val sum = aggregateFunc.asInstanceOf[Sum]
val aggBufferAttr = sum.inputAggBufferAttributes
if (aggBufferAttr.size == 2) {
// decimal sum check sum.resultType
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head)
if (supportedTypes.indexOf(attr.dataType) == -1 &&
!attr.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in Columnar Sum")
}
res_index += 2
} else {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head)
if (supportedTypes.indexOf(attr.dataType) == -1 &&
!attr.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in Columnar Sum")
}
res_index += 1
}
}
case Final => res_index += 1
case other =>
throw new UnsupportedOperationException(
s"${other} is not supported in Columnar Sum")
}
case Count(_) =>
mode match {
case Partial | PartialMerge | Final => {
res_index += 1
}
case other =>
throw new UnsupportedOperationException(
s"${other} is not supported in Columnar Count")
}
case Max(_) =>
val supportedTypes = List(ByteType, ShortType, IntegerType, LongType,
FloatType, DoubleType, DateType, BooleanType, StringType)
mode match {
case Partial => {
val max = aggregateFunc.asInstanceOf[Max]
val aggBufferAttr = max.inputAggBufferAttributes
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head)
if (supportedTypes.indexOf(attr.dataType) == -1 &&
!attr.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in Columnar Max")
}
res_index += 1
}
case PartialMerge | Final => res_index += 1
case other =>
throw new UnsupportedOperationException(s"not currently supported: $other.")
}
case Min(_) =>
val supportedTypes = List(ByteType, ShortType, IntegerType, LongType,
FloatType, DoubleType, DateType, BooleanType, StringType)
mode match {
case Partial => {
val min = aggregateFunc.asInstanceOf[Min]
val aggBufferAttr = min.inputAggBufferAttributes
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr.head)
if (supportedTypes.indexOf(attr.dataType) == -1 &&
!attr.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in Columnar Min")
}
res_index += 1
}
case PartialMerge | Final => res_index += 1
case other =>
throw new UnsupportedOperationException(
s"${other} is not supported in Columnar Min")
}
case StddevSamp(_,_) =>
mode match {
case Partial => {
val supportedTypes = List(ByteType, ShortType, IntegerType, LongType,
FloatType, DoubleType, BooleanType)
val stddevSamp = aggregateFunc.asInstanceOf[StddevSamp]
val aggBufferAttr = stddevSamp.inputAggBufferAttributes
for (index <- aggBufferAttr.indices) {
val attr = ConverterUtils.getAttrFromExpr(aggBufferAttr(index))
if (supportedTypes.indexOf(attr.dataType) == -1 &&
!attr.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in Columnar StddevSampPartial")
}
}
res_index += 3
}
case Final => {
val supportedTypes = List(ByteType, ShortType, IntegerType, LongType,
FloatType, DoubleType)
val attr = aggregateAttributeList(res_index)
if (supportedTypes.indexOf(attr.dataType) == -1) {
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in Columnar StddevSampFinal")
}
res_index += 1
}
case other =>
throw new UnsupportedOperationException(
s"${other} is not supported in Columnar StddevSamp")
}
case other =>
throw new UnsupportedOperationException(s"not currently supported: $other.")
}
mode match {
case Partial | PartialMerge | Final =>
case other =>
throw new UnsupportedOperationException(s"not currently supported: $other.")
throw new UnsupportedOperationException(
s"${other} is not supported in ColumnarAggregation")
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,20 +128,23 @@ class ColumnarHashAggregation(
case Sum(_) =>
mode match {
case Partial =>
val childrenColumnarFuncNodeList =
aggregateFunc.children.toList.map(expr => getColumnarFuncNode(expr))
TreeBuilder.makeFunction("action_sum_partial", childrenColumnarFuncNodeList.asJava, resultType)
val childrenColumnarFuncNodeList =
aggregateFunc.children.toList.map(expr => getColumnarFuncNode(expr))
TreeBuilder.makeFunction(
"action_sum_partial",
childrenColumnarFuncNodeList.asJava, resultType)
case Final | PartialMerge =>
val childrenColumnarFuncNodeList =
List(inputAttrQueue.dequeue).map(attr => getColumnarFuncNode(attr))
val childrenColumnarFuncNodeList =
List(inputAttrQueue.dequeue).map(attr => getColumnarFuncNode(attr))
//FIXME(): decimal adds isEmpty column
val sum = aggregateFunc.asInstanceOf[Sum]
val attrBuf = sum.inputAggBufferAttributes
if (attrBuf.size == 2) {
inputAttrQueue.dequeue
}

TreeBuilder.makeFunction("action_sum", childrenColumnarFuncNodeList.asJava, resultType)
TreeBuilder.makeFunction(
"action_sum",
childrenColumnarFuncNodeList.asJava, resultType)
case other =>
throw new UnsupportedOperationException(s"not currently supported: $other.")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1444,6 +1444,7 @@ class MaxAction<DataType, CType, precompile::enable_if_string_like<DataType>>
}
}
}
return arrow::Status::OK();
}

arrow::Status Evaluate(int dest_group_id, void* data) {
Expand Down

0 comments on commit 30bd935

Please sign in to comment.