
Commit 276050f
add fallback to SMJ and refine
rui-mo committed Jan 25, 2021
1 parent 92715ee commit 276050f
Showing 25 changed files with 328 additions and 182 deletions.
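
Note: the hunks below follow a single pattern. Each Columnar*Exec operator gains, or tightens, a buildCheck() routine that runs when the operator is constructed: it probes every input attribute with ConverterUtils.checkIfTypeSupported and pre-converts condition, key, and projection expressions with ColumnarExpressionConverter, so anything the native engine cannot handle surfaces as an UnsupportedOperationException. That exception is presumably caught by the plugin's guard/planner layer, which then keeps the vanilla row-based operator — the "fallback" of the commit title, extended here to sort-merge join. A minimal sketch of that flow follows; the rule and the tryColumnar function are illustrative assumptions, not the plugin's actual guard code — only the buildCheck/exception contract comes from this commit.

    import org.apache.spark.sql.catalyst.rules.Rule
    import org.apache.spark.sql.execution.SparkPlan

    // Hypothetical guard: tryColumnar is assumed to build the columnar counterpart of
    // a row-based operator; constructing it runs buildCheck() in the class body, so
    // unsupported types or expressions surface here as UnsupportedOperationException.
    case class FallbackOnUnsupported(tryColumnar: SparkPlan => SparkPlan)
        extends Rule[SparkPlan] {
      override def apply(plan: SparkPlan): SparkPlan =
        try {
          tryColumnar(plan)
        } catch {
          case _: UnsupportedOperationException =>
            plan // keep the original row-based operator
        }
    }
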
@@ -27,8 +27,8 @@ import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
import org.apache.spark.sql.types.{DecimalType, MapType, StructType}
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.ExecutorManager
import org.apache.spark.sql.util.StructTypeFWD
import org.apache.spark.{SparkConf, TaskContext}
@@ -60,6 +60,31 @@ case class ColumnarConditionProjectExec(
"numInputBatches" -> SQLMetrics.createMetric(sparkContext, "input_batches"),
"processTime" -> SQLMetrics.createTimingMetric(sparkContext, "totaltime_condproject"))

buildCheck(condition, projectList, child.output)

def buildCheck(condExpr: Expression, projectList: Seq[Expression],
originalInputAttributes: Seq[Attribute]): Unit = {
// check datatype
originalInputAttributes.toList.foreach(attr => {
try {
ConverterUtils.checkIfTypeSupported(attr.dataType)
} catch {
case e : UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in ColumnarConditionProjector.")
}
})
// check expr
if (condExpr != null) {
ColumnarExpressionConverter.replaceWithColumnarExpression(condExpr)
}
if (projectList != null) {
for (expr <- projectList) {
ColumnarExpressionConverter.replaceWithColumnarExpression(expr)
}
}
}

def isNullIntolerant(expr: Expression): Boolean = expr match {
case e: NullIntolerant => e.children.forall(isNullIntolerant)
case _ => false
@@ -237,10 +262,12 @@ case class ColumnarUnionExec(children: Seq[SparkPlan]) extends SparkPlan {
def buildCheck(): Unit = {
for (child <- children) {
for (schema <- child.schema) {
if (schema.dataType.isInstanceOf[MapType] ||
schema.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(
s"${schema.dataType} is not supported in ColumnarUnionExec")
try {
ConverterUtils.checkIfTypeSupported(schema.dataType)
} catch {
case e: UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${schema.dataType} is not supported in ColumnarUnionExec")
}
}
}
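
Note: the per-operator type blacklists (the removed MapType/DecimalType branches) are replaced by a shared probe: ConverterUtils.checkIfTypeSupported throws UnsupportedOperationException for types the native engine cannot represent, and each buildCheck rethrows with an operator-specific message. A small standalone probe, assuming only the ConverterUtils used in this diff, makes that contract visible; whether a given type such as MapType is actually rejected depends on the plugin version, so the expected output is an assumption.

    import com.intel.oap.expression.ConverterUtils
    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.types.{IntegerType, MapType, StringType}

    object TypeSupportProbe {
      // Mirrors the new buildCheck loops: a type is "supported" when the probe does not throw.
      def isSupported(attr: AttributeReference): Boolean =
        try { ConverterUtils.checkIfTypeSupported(attr.dataType); true }
        catch { case _: UnsupportedOperationException => false }

      def main(args: Array[String]): Unit = {
        val intCol = AttributeReference("i", IntegerType)()
        val mapCol = AttributeReference("m", MapType(StringType, IntegerType))()
        println(s"IntegerType supported: ${isSupported(intCol)}") // expected true
        println(s"MapType supported:     ${isSupported(mapCol)}") // likely false (assumption)
      }
    }
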
@@ -22,10 +22,10 @@ import java.nio.ByteBuffer
import java.util.concurrent.TimeUnit._

import com.intel.oap.vectorized._
import com.intel.oap.{ColumnarGuardRule, ColumnarPluginConfig}
import com.intel.oap.ColumnarPluginConfig
import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.util.{ExecutorManager, UserAddedJarUtils, Utils}
import org.apache.spark.util.{UserAddedJarUtils, Utils, ExecutorManager}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -39,7 +39,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.expressions.BindReferences._
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}

import scala.collection.mutable.ListBuffer
import org.apache.arrow.vector.ipc.message.ArrowFieldNode
@@ -54,12 +54,10 @@ import io.netty.buffer.ByteBuf
import com.google.common.collect.Lists
import com.intel.oap.expression._
import com.intel.oap.vectorized.ExpressionEvaluator
import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide, HashJoin}
import org.apache.spark.sql.types.{BinaryType, ByteType, DecimalType, NullType, StructField, StructType, TimestampType}
import org.apache.spark.sql.types.{StructField, StructType}

/**
* Performs a hash join of two child relations by first shuffling the data using the join keys.
@@ -106,21 +104,35 @@ case class ColumnarBroadcastHashJoinExec(
if (conditionExpr != null) {
ColumnarExpressionConverter.replaceWithColumnarExpression(conditionExpr)
}
// build check for res types
val streamInputAttributes: List[Attribute] = streamedPlan.output.toList
val unsupportedTypes = List(NullType, TimestampType, BinaryType, ByteType)
streamInputAttributes.foreach(attr => {
if (unsupportedTypes.indexOf(attr.dataType) != -1 ||
attr.dataType.isInstanceOf[DecimalType])
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in ColumnarBroadcastHashJoinExec.")
})
// build check types
for (attr <- streamedPlan.output) {
try {
ConverterUtils.checkIfTypeSupported(attr.dataType)
} catch {
case e: UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in ColumnarBroadcastHashJoinExec.")
}
}
for (attr <- buildPlan.output) {
try {
ConverterUtils.checkIfTypeSupported(attr.dataType)
} catch {
case e: UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in ColumnarBroadcastHashJoinExec.")
}
}
// build check for expr
for (expr <- buildKeyExprs) {
ColumnarExpressionConverter.replaceWithColumnarExpression(expr)
if (buildKeyExprs != null) {
for (expr <- buildKeyExprs) {
ColumnarExpressionConverter.replaceWithColumnarExpression(expr)
}
}
for (expr <- streamedKeyExprs) {
ColumnarExpressionConverter.replaceWithColumnarExpression(expr)
if (streamedKeyExprs != null) {
for (expr <- streamedKeyExprs) {
ColumnarExpressionConverter.replaceWithColumnarExpression(expr)
}
}
}

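
Note: for the broadcast hash join the check now covers both sides (streamedPlan and buildPlan), and the key-expression pre-conversion is guarded against null key lists before calling ColumnarExpressionConverter.replaceWithColumnarExpression, which is expected to throw when an expression has no columnar counterpart. The same shape recurs in the shuffled-hash and sort-merge join hunks below; a hedged consolidation of it — the helper name is invented, the calls are the ones this commit uses — could look like the following.

    import com.intel.oap.expression.{ColumnarExpressionConverter, ConverterUtils}
    import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}

    // Hypothetical helper capturing the per-join-side validation added in this commit.
    object JoinBuildCheck {
      def check(attrs: Seq[Attribute], keyExprs: Seq[Expression], opName: String): Unit = {
        attrs.foreach { attr =>
          try {
            ConverterUtils.checkIfTypeSupported(attr.dataType)
          } catch {
            case _: UnsupportedOperationException =>
              throw new UnsupportedOperationException(
                s"${attr.dataType} is not supported in $opName.")
          }
        }
        // Defensive null guard, matching the diff; pre-conversion throws if a key
        // expression cannot be replaced with a columnar expression.
        if (keyExprs != null) {
          keyExprs.foreach(expr =>
            ColumnarExpressionConverter.replaceWithColumnarExpression(expr))
        }
      }
    }
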
@@ -55,7 +55,18 @@ case class ColumnarExpandExec(

def buildCheck(): Unit = {
// build check for projection
projections.foreach(proj => ColumnarProjection.buildCheck(originalInputAttributes, proj))
projections.foreach(proj =>
ColumnarProjection.buildCheck(originalInputAttributes, proj))
//check type
for (attr <- originalInputAttributes) {
try {
ConverterUtils.checkIfTypeSupported(attr.dataType)
} catch {
case e : UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in ColumnarExpandExec.")
}
}
}

protected override def doExecute(): RDD[InternalRow] =
@@ -102,6 +102,8 @@ case class ColumnarHashAggregateExec(
numOutputBatches.set(0)
numInputBatches.set(0)

buildCheck()

val (listJars, signature): (Seq[String], String) =
if (ColumnarPluginConfig
.getConf(sparkConf)
@@ -140,16 +142,53 @@
(List(), "")
}
} else {
(List(), "")
}
listJars.foreach(jar => logInfo(s"Uploaded ${jar}"))

def buildCheck(): Unit = {
// check datatype
for (attr <- child.output) {
try {
ColumnarAggregation.buildCheck(groupingExpressions, child.output,
aggregateExpressions, resultExpressions)
ConverterUtils.checkIfTypeSupported(attr.dataType)
} catch {
case e: UnsupportedOperationException =>
throw e
case e : UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in ColumnarAggregation")
}
(List(), "")
}
listJars.foreach(jar => logInfo(s"Uploaded ${jar}"))
// check project
for (expr <- aggregateExpressions) {
val internalExpressionList = expr.aggregateFunction.children
ColumnarProjection.buildCheck(child.output, internalExpressionList)
}
ColumnarProjection.buildCheck(child.output, groupingExpressions)
ColumnarProjection.buildCheck(child.output, resultExpressions)
// check aggregate expressions
checkAggregate(aggregateExpressions)
}

def checkAggregate(aggregateExpressions: Seq[AggregateExpression]): Unit = {
for (expr <- aggregateExpressions) {
val mode = expr.mode
val aggregateFunction = expr.aggregateFunction
aggregateFunction match {
case Average(_) | Sum(_) | Count(_) | Max(_) | Min(_) =>
case StddevSamp(_) => mode match {
case Partial | Final =>
case other =>
throw new UnsupportedOperationException(s"not currently supported: $other.")
}
case other =>
throw new UnsupportedOperationException(s"not currently supported: $other.")
}
mode match {
case Partial | PartialMerge | Final =>
case other =>
throw new UnsupportedOperationException(s"not currently supported: $other.")
}
}
}

override def doExecuteColumnar(): RDD[ColumnarBatch] = {
child.executeColumnar().mapPartitionsWithIndex { (partIndex, iter) =>
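
Note: checkAggregate restricts which aggregates stay columnar: Average, Sum, Count, Max and Min are accepted in Partial, PartialMerge and Final modes, StddevSamp only in Partial or Final, and everything else (including Complete-mode aggregation) throws and triggers fallback. Below is a standalone restatement of those rules as a predicate over Spark's aggregate classes; it mirrors the commit's logic but is not part of the plugin itself.

    import org.apache.spark.sql.catalyst.expressions.aggregate._

    object AggregateSupport {
      // True when the new ColumnarHashAggregateExec.checkAggregate would accept the expression.
      def isSupported(expr: AggregateExpression): Boolean = {
        val functionOk = expr.aggregateFunction match {
          case _: Average | _: Sum | _: Count | _: Max | _: Min => true
          case _: StddevSamp => expr.mode == Partial || expr.mode == Final
          case _ => false
        }
        val modeOk = expr.mode match {
          case Partial | PartialMerge | Final => true
          case _ => false // e.g. Complete mode falls back to row-based aggregation
        }
        functionOk && modeOk
      }
    }
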
@@ -23,7 +23,7 @@ import com.intel.oap.vectorized._
import com.intel.oap.ColumnarPluginConfig
import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.util.{ExecutorManager, UserAddedJarUtils, Utils}
import org.apache.spark.util.{UserAddedJarUtils, Utils, ExecutorManager}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
import scala.collection.JavaConverters._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}

import scala.collection.mutable.ListBuffer
import org.apache.arrow.vector.ipc.message.ArrowFieldNode
@@ -52,7 +52,6 @@ import com.intel.oap.vectorized.ExpressionEvaluator
import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide, HashJoin}
import org.apache.spark.sql.types.DecimalType

/**
* Performs a hash join of two child relations by first shuffling the data using the join keys.
@@ -94,19 +93,40 @@ case class ColumnarShuffledHashJoinExec(
}

def buildCheck(): Unit = {
// build check for condition
val conditionExpr: Expression = condition.orNull
if (conditionExpr != null) {
ColumnarExpressionConverter.replaceWithColumnarExpression(conditionExpr)
}
// build check types
for (attr <- streamedPlan.output) {
try {
ConverterUtils.checkIfTypeSupported(attr.dataType)
} catch {
case e: UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in ColumnarShuffledHashJoinExec.")
}
}
for (attr <- buildPlan.output) {
if (attr.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in ColumnarShuffledHashJoin.")
try {
ConverterUtils.checkIfTypeSupported(attr.dataType)
} catch {
case e: UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in ColumnarShuffledHashJoinExec.")
}
CodeGeneration.getResultType(attr.dataType)
}
for (attr <- streamedPlan.output) {
if (attr.dataType.isInstanceOf[DecimalType]) {
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in ColumnarShuffledHashJoin.")
// build check for expr
if (buildKeyExprs != null) {
for (expr <- buildKeyExprs) {
ColumnarExpressionConverter.replaceWithColumnarExpression(expr)
}
}
if (streamedKeyExprs != null) {
for (expr <- streamedKeyExprs) {
ColumnarExpressionConverter.replaceWithColumnarExpression(expr)
}
CodeGeneration.getResultType(attr.dataType)
}
}

@@ -86,7 +86,24 @@ case class ColumnarSortExec(
val numOutputRows = longMetric("numOutputRows")
val numOutputBatches = longMetric("numOutputBatches")

ColumnarSorter.buildCheck(output)
buildCheck()

def buildCheck(): Unit = {
// check types
for (attr <- output) {
try {
ConverterUtils.checkIfTypeSupported(attr.dataType)
} catch {
case e : UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in ColumnarSorter.")
}
}
// check expr
sortOrder.toList.map(expr => {
ColumnarExpressionConverter.replaceWithColumnarExpression(expr.child)
})
}

/***************** WSCG related function ******************/
override def inputRDDs(): Seq[RDD[ColumnarBatch]] = child match {
@@ -86,6 +86,8 @@ case class ColumnarSortMergeJoinExec(
val totaltime_sortmegejoin = longMetric("totaltime_sortmergejoin")
val resultSchema = this.schema

buildCheck()

override def supportsColumnar = true

override protected def doExecute(): RDD[InternalRow] = {
@@ -334,6 +336,44 @@
throw e
}*/

def buildCheck(): Unit = {
// build check for condition
val conditionExpr: Expression = condition.orNull
if (conditionExpr != null) {
ColumnarExpressionConverter.replaceWithColumnarExpression(conditionExpr)
}
// build check types
for (attr <- left.output) {
try {
ConverterUtils.checkIfTypeSupported(attr.dataType)
} catch {
case e: UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in ColumnarSortMergeJoinExec.")
}
}
for (attr <- right.output) {
try {
ConverterUtils.checkIfTypeSupported(attr.dataType)
} catch {
case e: UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${attr.dataType} is not supported in ColumnarSortMergeJoinExec.")
}
}
// build check for expr
if (leftKeys != null) {
for (expr <- leftKeys) {
ColumnarExpressionConverter.replaceWithColumnarExpression(expr)
}
}
if (rightKeys != null) {
for (expr <- rightKeys) {
ColumnarExpressionConverter.replaceWithColumnarExpression(expr)
}
}
}

/***********************************************************/
def getCodeGenSignature: String =
if (resultSchema.size > 0) {
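
Note: this is the "fallback to SMJ" piece of the commit title. ColumnarSortMergeJoinExec now runs buildCheck() in its class body, so simply constructing it validates the join keys, the optional condition, and both children's output types. A planner-side sketch of how that enables falling back to Spark's own SortMergeJoinExec follows; it assumes the columnar operator's constructor mirrors SortMergeJoinExec's parameter list, which should be treated as an assumption rather than the plugin's exact signature.

    import com.intel.oap.execution.ColumnarSortMergeJoinExec // package path assumed from the plugin's layout
    import org.apache.spark.sql.execution.SparkPlan
    import org.apache.spark.sql.execution.joins.SortMergeJoinExec

    object SmjFallback {
      // Try the columnar sort-merge join; constructing it runs buildCheck(), and any
      // UnsupportedOperationException keeps the vanilla row-based operator.
      def replaceSortMergeJoin(plan: SortMergeJoinExec): SparkPlan =
        try {
          ColumnarSortMergeJoinExec( // assumed constructor, mirroring SortMergeJoinExec
            plan.leftKeys, plan.rightKeys, plan.joinType, plan.condition,
            plan.left, plan.right, plan.isSkewJoin)
        } catch {
          case _: UnsupportedOperationException => plan // fall back to SortMergeJoinExec
        }
    }
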
