Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit

Permalink
implement concat_ws
Browse files Browse the repository at this point in the history
Signed-off-by: Yuan Zhou <[email protected]>
  • Loading branch information
zhouyuan committed Jun 12, 2022
1 parent 129c9f3 commit 8601dc8
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,67 @@ import org.apache.spark.sql.types._

import scala.collection.mutable.ListBuffer

class ColumnarConcatWs(exps: Seq[Expression], original: Expression)
extends ConcatWs(exps: Seq[Expression])
with ColumnarExpression
with Logging {

buildCheck()

def buildCheck(): Unit = {
exps.foreach(expr =>
if (expr.dataType != StringType) {
throw new UnsupportedOperationException(
s"${expr.dataType} is not supported in ColumnarConcatWS")
})
}

override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
false
}

override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
val iter: Iterator[Expression] = exps.iterator
val exp = iter.next()
val iterFaster: Iterator[Expression] = exps.iterator
iterFaster.next()
iterFaster.next()

val (exp_node, expType): (TreeNode, ArrowType) =
exp.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)

val resultType = new ArrowType.Utf8()
//concat_ws is null senstive
val funcNode = TreeBuilder.makeFunction("concat",
Lists.newArrayList(exp_node, rightNode(args, exp, exps, iter, iterFaster)), resultType)
(funcNode, expType)
}

def rightNode(args: java.lang.Object, head: Expression, exps: Seq[Expression],
iter: Iterator[Expression], iterFaster: Iterator[Expression]): TreeNode = {
if (!iterFaster.hasNext) {
// When iter reaches the last but one expression
val (exp_node, expType): (TreeNode, ArrowType) =
exps.last.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
val (head_node, headType): (TreeNode, ArrowType) =
head.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
val resultType = new ArrowType.Utf8()
val funcNode = TreeBuilder.makeFunction("concat",
Lists.newArrayList(head_node, exp_node), resultType)
funcNode
} else {
val exp = iter.next()
iterFaster.next()
val (exp_node, expType): (TreeNode, ArrowType) =
exp.asInstanceOf[ColumnarExpression].doColumnarCodeGen(args)
val resultType = new ArrowType.Utf8()
val funcNode = TreeBuilder.makeFunction("concat",
Lists.newArrayList(exp_node, rightNode(args, head, exps, iter, iterFaster)), resultType)
funcNode
}
}
}

class ColumnarConcat(exps: Seq[Expression], original: Expression)
extends Concat(exps: Seq[Expression])
with ColumnarExpression
Expand All @@ -44,6 +105,10 @@ class ColumnarConcat(exps: Seq[Expression], original: Expression)
})
}

override def supportColumnarCodegen(args: java.lang.Object): Boolean = {
false
}

override def doColumnarCodeGen(args: java.lang.Object): (TreeNode, ArrowType) = {
val iter: Iterator[Expression] = exps.iterator
val exp = iter.next()
Expand Down Expand Up @@ -85,6 +150,8 @@ object ColumnarConcatOperator {
def create(exps: Seq[Expression], original: Expression): Expression = original match {
case c: Concat =>
new ColumnarConcat(exps, original)
case cws: ConcatWs =>
new ColumnarConcatWs(exps, original)
case other =>
throw new UnsupportedOperationException(s"not currently supported: $other.")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,16 @@ object ColumnarExpressionConverter extends Logging {
convertBoundRefToAttrRef = convertBoundRefToAttrRef)
}
ColumnarConcatOperator.create(exps, expr)
case cws: ConcatWs =>
check_if_no_calculation = false
logInfo(s"${expr.getClass} ${expr} is supported, no_cal is $check_if_no_calculation.")
val exps = cws.children.map { expr =>
replaceWithColumnarExpression(
expr,
attributeSeq,
convertBoundRefToAttrRef = convertBoundRefToAttrRef)
}
ColumnarConcatOperator.create(exps, expr)
case r: Round =>
check_if_no_calculation = false
logInfo(s"${expr.getClass} ${expr} is supported, no_cal is $check_if_no_calculation.")
Expand Down

0 comments on commit 8601dc8

Please sign in to comment.