This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit 8e52711

support expression as join keys

Signed-off-by: Yuan Zhou <[email protected]>
zhouyuan committed Jan 18, 2021
1 parent c907202 · commit 8e52711

Showing 2 changed files with 122 additions and 18 deletions.
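For context, this is the class of query the change enables: a sort-merge join whose keys are expressions rather than bare attribute references. A minimal, hypothetical Spark sketch of such a query (table and column names are illustrative; this snippet is not part of the commit):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.upper

    object ExpressionJoinKeyExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .appName("expr-join-key")
          .master("local[*]")
          .getOrCreate()
        import spark.implicits._

        val left = Seq((1, "aa"), (2, "BB")).toDF("l_id", "l_code")
        val right = Seq((10, "AA"), (20, "bb")).toDF("r_id", "r_code")

        // The join key is an expression, upper(<col>), not a bare column.
        // Before this commit the columnar SMJ rejected such keys
        // ("join key with expression is not supported."); with this change
        // the key expression is pre-projected via Gandiva into a trailing
        // column and the join probes on that column instead.
        val joined = left.join(right, upper($"l_code") === upper($"r_code"))
        joined.show()
      }
    }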
@@ -315,7 +315,7 @@ case class ColumnarSortMergeJoinExec(
    //do not call prebuild so we could skip the c++ codegen
    //val triggerBuildSignature = getCodeGenSignature

-   try {
+   /*try {
      ColumnarSortMergeJoin.precheck(
        leftKeys,
        rightKeys,
@@ -332,7 +332,7 @@ case class ColumnarSortMergeJoinExec(
    } catch {
      case e: Throwable =>
        throw e
-   }
+   }*/

  /***********************************************************/
  def getCodeGenSignature: String =
core/src/main/scala/com/intel/oap/expression/ColumnarSortMergeJoin.scala
136 changes: 120 additions & 16 deletions
@@ -73,7 +73,11 @@ class ColumnarSortMergeJoin(
    prepareTime: SQLMetric,
    totaltime_sortmergejoin: SQLMetric,
    totalOutputNumRows: SQLMetric,
-   sparkConf: SparkConf)
+   sparkConf: SparkConf,
+   buildProjector: ColumnarProjection,
+   buildKeyProjectOrdinalList: List[Int],
+   streamProjector: ColumnarProjection,
+   streamKeyProjectOrdinalList: List[Int])
    extends Logging {
  ColumnarPluginConfig.getConf(sparkConf)
  var probe_iterator: BatchIterator = _
@@ -105,13 +109,28 @@ class ColumnarSortMergeJoin(
      }
      build_cb = realbuildIter.next()
      val beforeBuild = System.nanoTime()
-     val build_rb = ConverterUtils.createArrowRecordBatch(build_cb)
+     // handle projection
+     val projectedBuildKeyCols: List[ArrowWritableColumnVector] = if (buildProjector != null) {
+       val builderOrdinalList = buildProjector.getOrdinalList
+       val builderAttributes = buildProjector.output
+       val builderProjectCols = builderOrdinalList.map(i => {
+         build_cb.column(i).asInstanceOf[ArrowWritableColumnVector]
+       })
+       buildProjector.evaluate(build_cb.numRows, builderProjectCols.map(_.getValueVector()))
+     } else {
+       List[ArrowWritableColumnVector]()
+     }
+     val buildCols = (0 until build_cb.numCols).toList.map(i =>
+       build_cb.column(i).asInstanceOf[ArrowWritableColumnVector]) ::: projectedBuildKeyCols
+     val build_rb =
+       ConverterUtils.createArrowRecordBatch(build_cb.numRows, buildCols.map(_.getValueVector))
      (0 until build_cb.numCols).toList.foreach(i =>
        build_cb.column(i).asInstanceOf[ArrowWritableColumnVector].retain())
      inputBatchHolder += build_cb
      prober.evaluate(build_rb)
      prepareTime += NANOSECONDS.toMillis(System.nanoTime() - beforeBuild)
      ConverterUtils.releaseArrowRecordBatch(build_rb)
+     projectedBuildKeyCols.foreach(v => v.close)
    }
    if (build_cb != null) {
      build_cb = null
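The ordering in this hunk is the crux of the patch: projected key vectors are appended after the batch's original columns, matching the extra key Fields appended to build_input_field_list in init() further down, so the native probe finds each projected key at ordinal numCols + k. A toy sketch of that invariant using plain lists as stand-ins for Arrow vectors (not code from the commit):

    object KeyAppendOrdinals {
      def main(args: Array[String]): Unit = {
        val originalCols = List("col0", "col1", "col2") // the batch's own columns
        val projectedKeys = List("upper(col1)")         // Gandiva-projected key output
        // Same ::: append as buildCols in the diff above.
        val combined = originalCols ::: projectedKeys
        // Projected key k sits at ordinal originalCols.length + k.
        assert(combined(originalCols.length) == "upper(col1)")
        println(combined.zipWithIndex.mkString(", "))
      }
    }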
@@ -149,9 +168,31 @@ class ColumnarSortMergeJoin(
      last_cb = cb
      val beforeJoin = System.nanoTime()
-     val stream_rb: ArrowRecordBatch = ConverterUtils.createArrowRecordBatch(cb)
-     val output_rb = probe_iterator.process(stream_input_arrow_schema, stream_rb)
-
-     ConverterUtils.releaseArrowRecordBatch(stream_rb)
+     val output_rb = if (cb.numRows > 0) {
+       val projectedStreamKeyCols: List[ArrowWritableColumnVector] =
+         if (streamProjector != null) {
+           val streamOrdinalList = streamProjector.getOrdinalList
+           val streamAttributes = streamProjector.output
+           val streamProjectCols = streamOrdinalList.map(i => {
+             cb.column(i).asInstanceOf[ArrowWritableColumnVector]
+           })
+           streamProjector.evaluate(cb.numRows, streamProjectCols.map(_.getValueVector()))
+         } else {
+           List[ArrowWritableColumnVector]()
+         }
+       val streamCols = (0 until cb.numCols).toList.map(i =>
+         cb.column(i).asInstanceOf[ArrowWritableColumnVector]) ::: projectedStreamKeyCols
+       val stream_rb: ArrowRecordBatch =
+         ConverterUtils.createArrowRecordBatch(cb.numRows, streamCols.map(_.getValueVector))
+
+       val res = probe_iterator.process(stream_input_arrow_schema, stream_rb)
+
+       ConverterUtils.releaseArrowRecordBatch(stream_rb)
+       projectedStreamKeyCols.foreach(v => v.close)
+       res
+     } else {
+       null
+     }
      joinTime += NANOSECONDS.toMillis(System.nanoTime() - beforeJoin)
      if (output_rb == null) {
        val resultColumnVectors =
@@ -188,6 +229,8 @@ object ColumnarSortMergeJoin extends Logging {
  var output_arrow_schema: Schema = _
  var condition_probe_expr: ExpressionTree = _
  var prober: ExpressionEvaluator = _
+ var buildKeyProjectOrdinalList: List[Int] = _
+ var streamKeyProjectOrdinalList: List[Int] = _

  def init(
      leftKeys: Seq[Expression],
@@ -234,16 +277,16 @@ object ColumnarSortMergeJoin extends Logging {
    val lkeyFieldList: List[Field] = leftKeys.toList.zipWithIndex.map {
      case (expr, i) => {
        //TODO(): fix this workaround
-       if (expr.isInstanceOf[AttributeReference] && expr.asInstanceOf[AttributeReference].name == "none") {
-         return
-       }
+       //if (expr.isInstanceOf[AttributeReference] && expr.asInstanceOf[AttributeReference].name == "none") {
+       //  return
+       //}
        val (nativeNode, returnType) = ConverterUtils.getColumnarFuncNode(expr)
        if (s"${nativeNode.toProtobuf}".contains("none#")) {
          throw new UnsupportedOperationException(
            s"Unsupport to generate native expression from replaceable expression.")
        }
        if (s"${nativeNode.toProtobuf}".contains("fnNode")) {
-         throw new UnsupportedOperationException(s"join key with expression is not supported.")
+         //throw new UnsupportedOperationException(s"join key with expression is not supported.")
+         lkeyProjectOrdinalList += i
+         Field.nullable(s"${expr}", returnType)
        } else {
@@ -261,7 +304,7 @@ object ColumnarSortMergeJoin extends Logging {
      case (expr, i) => {
        val (nativeNode, returnType) = ConverterUtils.getColumnarFuncNode(expr)
        if (s"${nativeNode.toProtobuf}".contains("fnNode")) {
-         throw new UnsupportedOperationException(s"join key with expression is not supported.")
+         //throw new UnsupportedOperationException(s"join key with expression is not supported.")
+         rkeyProjectOrdinalList += i
+         Field.nullable(s"${expr}", returnType)
        } else {
@@ -289,15 +332,19 @@ object ColumnarSortMergeJoin extends Logging {
      case _ =>
        BuildLeft
    }
-   val (
+   var (
      build_key_field_list,
      stream_key_field_list,
      build_input_field_list,
      stream_input_field_list) = buildSide match {
      case BuildLeft =>
+       buildKeyProjectOrdinalList = lkeyProjectOrdinalList.toList
+       streamKeyProjectOrdinalList = rkeyProjectOrdinalList.toList
        (lkeyFieldList, rkeyFieldList, l_input_field_list, r_input_field_list)

      case BuildRight =>
+       buildKeyProjectOrdinalList = rkeyProjectOrdinalList.toList
+       streamKeyProjectOrdinalList = lkeyProjectOrdinalList.toList
        (rkeyFieldList, lkeyFieldList, r_input_field_list, l_input_field_list)

    }
@@ -362,6 +409,18 @@ object ColumnarSortMergeJoin extends Logging {
        (build_input_field_list, stream_output_field_list ::: build_output_field_list)
      }
    }
+   // we need to add projectedKeyOutput into input_field_list here
+   if (buildKeyProjectOrdinalList.nonEmpty) {
+     build_input_field_list =
+       build_input_field_list ::: buildKeyProjectOrdinalList.map(i => build_key_field_list(i))
+   }
+
+   if (streamKeyProjectOrdinalList.nonEmpty) {
+     stream_input_field_list =
+       stream_input_field_list ::: streamKeyProjectOrdinalList.map(i => stream_key_field_list(i))
+   }
+   build_input_arrow_schema = new Schema(build_input_field_list.asJava)
+   stream_input_arrow_schema = new Schema(stream_input_field_list.asJava)

    val conditionArrowSchema = new Schema(conditionInputFieldList.asJava)
    output_arrow_schema = new Schema(conditionOutputFieldList.asJava)
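The appended Fields in this hunk mirror, at schema level, the vectors appended to each batch at runtime. A minimal sketch of what such a schema looks like with the Arrow Java API (field names and types are illustrative, not from this commit):

    import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}
    import scala.collection.JavaConverters._

    object AppendedKeySchemaSketch {
      def main(args: Array[String]): Unit = {
        val originalFields = List(
          Field.nullable("l_id", new ArrowType.Int(32, true)),
          Field.nullable("l_code", new ArrowType.Utf8()))
        // The projected key's Field is named after the expression itself,
        // mirroring Field.nullable(s"${expr}", returnType) in the diff.
        val keyFields = List(Field.nullable("upper(l_code)", new ArrowType.Utf8()))
        val schema = new Schema((originalFields ::: keyFields).asJava)
        println(schema)
      }
    }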
@@ -438,6 +497,7 @@ object ColumnarSortMergeJoin extends Logging {
      totaltime_sortmergejoin: SQLMetric,
      numOutputRows: SQLMetric,
      sparkConf: SparkConf): Unit = synchronized {
+   logInfo("precheck")
    init(
      leftKeys,
      rightKeys,
@@ -467,7 +527,7 @@ object ColumnarSortMergeJoin extends Logging {
      totaltime_sortmergejoin: SQLMetric,
      numOutputRows: SQLMetric,
      sparkConf: SparkConf): String = synchronized {
-
+   logInfo("prebuild")
    init(
      leftKeys,
      rightKeys,
@@ -522,14 +582,55 @@ object ColumnarSortMergeJoin extends Logging {
      numOutputRows,
      sparkConf)

+   val buildSide: BuildSide = joinType match {
+     case LeftSemi =>
+       BuildRight
+     case LeftOuter =>
+       BuildRight
+     case LeftAnti =>
+       BuildRight
+     case j: ExistenceJoin =>
+       BuildRight
+     case LeftExistence(_) =>
+       BuildRight
+     case _ =>
+       BuildLeft
+   }
+
+   val (buildProjector, streamProjector) =
+     // create gandiva project to pre-process
+     buildSide match {
+       case BuildLeft =>
+         (
+           (if (buildKeyProjectOrdinalList.nonEmpty)
+              ColumnarProjection
+                .create(left.output, buildKeyProjectOrdinalList.map(i => leftKeys(i)))
+            else null),
+           (if (streamKeyProjectOrdinalList.nonEmpty)
+              ColumnarProjection
+                .create(right.output, streamKeyProjectOrdinalList.map(i => rightKeys(i)))
+            else null))
+
+       case BuildRight =>
+         (
+           (if (buildKeyProjectOrdinalList.nonEmpty)
+              ColumnarProjection
+                .create(right.output, buildKeyProjectOrdinalList.map(i => rightKeys(i)))
+            else null),
+           (if (streamKeyProjectOrdinalList.nonEmpty)
+              ColumnarProjection
+                .create(left.output, streamKeyProjectOrdinalList.map(i => leftKeys(i)))
+            else null))
+     }
+
    prober = new ExpressionEvaluator(listJars.toList.asJava)
    prober.build(
      build_input_arrow_schema,
      Lists.newArrayList(condition_probe_expr),
      output_arrow_schema,
      true)

-   columnarSortMergeJoin = new ColumnarSortMergeJoin(
+   new ColumnarSortMergeJoin(
      prober,
      stream_input_arrow_schema,
      output_arrow_schema,
@@ -545,8 +646,12 @@ object ColumnarSortMergeJoin extends Logging {
      prepareTime,
      totaltime_sortmergejoin,
      numOutputRows,
-     sparkConf)
-   columnarSortMergeJoin
+     sparkConf,
+     buildProjector,
+     buildKeyProjectOrdinalList,
+     streamProjector,
+     streamKeyProjectOrdinalList)
+
  }

  def close(): Unit = {
@@ -677,5 +782,4 @@ object ColumnarSortMergeJoin extends Logging {
        condition_expression_node_list,
        new ArrowType.Int(32, true) /*dummy ret type, won't be used*/ )
  }
-
}
