From e9dfc2db6a3e51906cb18784919890e843e54c5b Mon Sep 17 00:00:00 2001
From: Yuan
Date: Tue, 14 Jun 2022 23:38:26 +0800
Subject: [PATCH] [NSE-955] implement concat_ws (#963)

---
 .../expression/ColumnarConcatOperator.scala   | 107 +++++++++++++++++++
 .../ColumnarExpressionConverter.scala         |  10 ++
 2 files changed, 117 insertions(+)

diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarConcatOperator.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarConcatOperator.scala
index 25f0155df..74119065b 100644
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarConcatOperator.scala
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarConcatOperator.scala
@@ -29,6 +29,107 @@ import org.apache.spark.sql.types._
 
 import scala.collection.mutable.ListBuffer
 
+/**
+ * Columnar implementation of Spark's `concat_ws`.
+ *
+ * `exps.head` is the separator and the remaining expressions are the values to join. The
+ * result is built as a right-nested chain of three-argument "concat" function nodes:
+ * `concat(v1, sep, concat(v2, sep, ... concat(vN-1, sep, vN)))`.
+ *
+ * NOTE(review): Spark's `concat_ws` skips null inputs; whether the "concat" function used
+ * here does the same is not visible from this file -- confirm before relying on it.
+ */
+class ColumnarConcatWs(exps: Seq[Expression], original: Expression)
+  extends ConcatWs(exps: Seq[Expression])
+  with ColumnarExpression
+  with Logging {
+
+  buildCheck()
+
+  /**
+   * Rejects inputs this operator cannot handle.
+   *
+   * @throws UnsupportedOperationException if fewer than three expressions are given (a
+   *         separator plus at least two values) or if any expression is not a string.
+   */
+  def buildCheck(): Unit = {
+    // doColumnarCodeGen consumes the separator and two values up front; with fewer
+    // expressions it would fail with NoSuchElementException, so reject them here instead.
+    if (exps.size < 3) {
+      throw new UnsupportedOperationException(
+        s"ColumnarConcatWS requires a separator and at least two inputs, got ${exps.size}")
+    }
+    exps.foreach(expr =>
+      if (expr.dataType != StringType) {
+        throw new UnsupportedOperationException(
+          s"${expr.dataType} is not supported in ColumnarConcatWS")
+      })
+  }
+
+  override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
+    false
+  }
+
+  /**
+   * Builds the expression tree for this expression.
+   *
+   * @param args passed through unchanged to the children's `doColumnarCodeGen`.
+   * @return the top-level "concat" node and its UTF-8 result type.
+   */
+  override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
+    val iter: Iterator[Expression] = exps.iterator
+    val exp = iter.next() // separator
+    val exp1 = iter.next()
+    // iterFaster runs one element ahead of iter so that rightNode can tell when only the
+    // last expression is left.
+    val iterFaster: Iterator[Expression] = exps.iterator
+    iterFaster.next()
+    iterFaster.next()
+    iterFaster.next()
+
+    val (split_node, _): (TreeNode, ArrowType) =
+      exp.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
+    val (exp1_node, _): (TreeNode, ArrowType) =
+      exp1.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
+
+    val resultType = new ArrowType.Utf8()
+    val funcNode = TreeBuilder.makeFunction("concat",
+      Lists.newArrayList(exp1_node, split_node,
+        rightNode(args, exps, split_node, iter, iterFaster)), resultType)
+    // Report the type of the node actually built rather than the separator's type.
+    (funcNode, resultType)
+  }
+
+  /**
+   * Recursively builds the right-hand operand of a "concat" node.
+   *
+   * @param args       passed through unchanged to the children's `doColumnarCodeGen`.
+   * @param exps       all expressions, separator first; only `exps.last` is read here.
+   * @param split_node the already generated separator node, reused at every level.
+   * @param iter       positioned just before the next value to emit.
+   * @param iterFaster positioned one element ahead of `iter`.
+   * @return the node for the last expression, or a nested "concat" node.
+   */
+  def rightNode(args: java.lang.Object, exps: Seq[Expression], split_node: TreeNode,
+                iter: Iterator[Expression], iterFaster: Iterator[Expression]): TreeNode = {
+    if (!iterFaster.hasNext) {
+      // iterFaster is exhausted, so the only value left is the last expression.
+      val (exp_node, _): (TreeNode, ArrowType) =
+        exps.last.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
+      exp_node
+    } else {
+      val exp = iter.next()
+      iterFaster.next()
+      val (exp_node, _): (TreeNode, ArrowType) =
+        exp.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
+      val resultType = new ArrowType.Utf8()
+      TreeBuilder.makeFunction("concat",
+        Lists.newArrayList(exp_node, split_node,
+          rightNode(args, exps, split_node, iter, iterFaster)), resultType)
+    }
+  }
+}
+
 class ColumnarConcat(exps: Seq[Expression], original: Expression)
   extends Concat(exps: Seq[Expression])
   with ColumnarExpression
@@ -44,6 +145,10 @@ class ColumnarConcat(exps: Seq[Expression], original: Expression)
     })
   }
 
+  override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
+    false
+  }
+
   override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
     val iter: Iterator[Expression] = exps.iterator
     val exp = iter.next()
@@ -85,6 +190,8 @@ object ColumnarConcatOperator {
   def create(exps: Seq[Expression], original: Expression): Expression = original match {
     case c: Concat =>
       new ColumnarConcat(exps, original)
+    case cws: ConcatWs =>
+      new ColumnarConcatWs(exps, original)
     case other =>
       throw new UnsupportedOperationException(s"not currently supported: $other.")
   }
diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarExpressionConverter.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarExpressionConverter.scala
index 383d2fbe1..2ba26a4e4 100644
--- a/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarExpressionConverter.scala
+++ b/native-sql-engine/core/src/main/scala/com/intel/oap/expression/ColumnarExpressionConverter.scala
@@ -392,6 +392,16 @@ object ColumnarExpressionConverter extends Logging {
             convertBoundRefToAttrRef = convertBoundRefToAttrRef)
         }
         ColumnarConcatOperator.create(exps, expr)
+      case cws: ConcatWs =>
+        check_if_no_calculation = false
+        logInfo(s"${expr.getClass} ${expr} is supported, no_cal is $check_if_no_calculation.")
+        val exps = cws.children.map { expr =>
+          replaceWithColumnarExpression(
+            expr,
+            attributeSeq,
+            convertBoundRefToAttrRef = convertBoundRefToAttrRef)
+        }
+        ColumnarConcatOperator.create(exps, expr)
       case r: Round =>
         check_if_no_calculation = false
         logInfo(s"${expr.getClass} ${expr} is supported, no_cal is $check_if_no_calculation.")