Support Spark 3.1.3
harveyyue committed Sep 5, 2024
1 parent 1d718c5 commit eff610e
Showing 30 changed files with 285 additions and 41 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/tpcds.yml
@@ -12,6 +12,13 @@ jobs:
       sparkver: spark303
       sparkurl: https://archive.apache.org/dist/spark/spark-3.0.3/spark-3.0.3-bin-hadoop2.7.tgz
 
+  test-spark313:
+    name: Test Spark313
+    uses: ./.github/workflows/tpcds-reusable.yml
+    with:
+      sparkver: spark313
+      sparkurl: https://archive.apache.org/dist/spark/spark-3.1.3/spark-3.1.3-bin-hadoop2.7.tgz
+
   test-spark320:
     name: Test Spark320
     uses: ./.github/workflows/tpcds-reusable.yml
1 change: 1 addition & 0 deletions .gitignore
@@ -33,6 +33,7 @@ lib/
 *.orig
 .*.swp
 .*.swo
+*.cache
 
 # macOS
 .DS_Store
14 changes: 14 additions & 0 deletions pom.xml
@@ -274,6 +274,20 @@
       </properties>
     </profile>
 
+    <profile>
+      <id>spark313</id>
+      <properties>
+        <shimName>spark313</shimName>
+        <shimPkg>spark3</shimPkg>
+        <javaVersion>1.8</javaVersion>
+        <scalaVersion>2.12</scalaVersion>
+        <scalaLongVersion>2.12.15</scalaLongVersion>
+        <scalaTestVersion>3.2.9</scalaTestVersion>
+        <scalafmtVersion>3.0.0</scalafmtVersion>
+        <sparkVersion>3.1.3</sparkVersion>
+      </properties>
+    </profile>
+
     <profile>
       <id>spark320</id>
       <properties>
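For anyone building from source: a Maven profile like the one added above is normally selected with the -P flag, e.g. mvn clean package -Pspark313. This is a sketch of the usual invocation; the repository's actual packaging goals and any additional required flags may differ.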
@@ -70,10 +70,10 @@ object InterceptedValidateSparkPlan extends Logging {
     }
   }
 
-  @enableIf(Seq("spark303", "spark320").contains(System.getProperty("blaze.shim")))
+  @enableIf(Seq("spark303", "spark313", "spark320").contains(System.getProperty("blaze.shim")))
   def validate(plan: SparkPlan): Unit = {
     throw new UnsupportedOperationException(
-      "validate is not supported in spark 3.0.3 or spark 3.2.0")
+      "validate is not supported in spark 3.0.3, 3.1.3 or 3.2.0")
   }
 
   @enableIf(Seq("spark324", "spark333", "spark351").contains(System.getProperty("blaze.shim")))
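The @enableIf guards in this hunk are the core of the shim mechanism: each predicate is evaluated at compile time against the blaze.shim system property supplied to the build, and the annotated definition is compiled in only when the predicate holds, which is how several same-signature variants coexist in one source file. A minimal self-contained sketch of the pattern, assuming the ThoughtWorks enableIf.scala macro annotation that these call sites suggest:

import com.thoughtworks.enableIf

class VersionedShims {
  // Kept only when compiling with -Dblaze.shim=spark313; otherwise the
  // macro removes the definition before the class is typechecked.
  @enableIf(Seq("spark313").contains(System.getProperty("blaze.shim")))
  def shimVersion: String = "spark313"

  // Same name, disjoint predicate: exactly one copy survives any given
  // build, so the apparent duplicate definitions never clash.
  @enableIf(Seq("spark303", "spark320").contains(System.getProperty("blaze.shim")))
  def shimVersion: String = "spark303-or-spark320"
}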
@@ -116,6 +116,8 @@ class ShimsImpl extends Shims with Logging {
 
   @enableIf(Seq("spark303").contains(System.getProperty("blaze.shim")))
   override def shimVersion: String = "spark303"
+  @enableIf(Seq("spark313").contains(System.getProperty("blaze.shim")))
+  override def shimVersion: String = "spark313"
   @enableIf(Seq("spark320").contains(System.getProperty("blaze.shim")))
   override def shimVersion: String = "spark320"
   @enableIf(Seq("spark324").contains(System.getProperty("blaze.shim")))
@@ -378,7 +380,7 @@ class ShimsImpl extends Shims with Logging {
     MapStatus.apply(SparkEnv.get.blockManager.shuffleServerId, partitionLengths, mapId)
   }
 
-  @enableIf(Seq("spark303").contains(System.getProperty("blaze.shim")))
+  @enableIf(Seq("spark303", "spark313").contains(System.getProperty("blaze.shim")))
   override def commit(
       dep: ShuffleDependency[_, _, _],
       shuffleBlockResolver: IndexShuffleBlockResolver,
@@ -507,7 +509,7 @@ class ShimsImpl extends Shims with Logging {
     exec.isInstanceOf[AQEShuffleReadExec]
   }
 
-  @enableIf(Seq("spark303").contains(System.getProperty("blaze.shim")))
+  @enableIf(Seq("spark303", "spark313").contains(System.getProperty("blaze.shim")))
   private def isAQEShuffleRead(exec: SparkPlan): Boolean = {
     import org.apache.spark.sql.execution.adaptive.CustomShuffleReaderExec
     exec.isInstanceOf[CustomShuffleReaderExec]
@@ -605,6 +607,85 @@ class ShimsImpl extends Shims with Logging {
     }
   }
 
+  @enableIf(Seq("spark313").contains(System.getProperty("blaze.shim")))
+  private def executeNativeAQEShuffleReader(exec: SparkPlan): NativeRDD = {
+    import org.apache.spark.sql.execution.adaptive.CustomShuffleReaderExec
+
+    exec match {
+      case CustomShuffleReaderExec(child, _) if isNative(child) =>
+        val shuffledRDD = exec.execute().asInstanceOf[ShuffledRowRDD]
+        val shuffleHandle = shuffledRDD.dependency.shuffleHandle
+
+        val inputRDD = executeNative(child)
+        val nativeShuffle = getUnderlyingNativePlan(child).asInstanceOf[NativeShuffleExchangeExec]
+        val nativeSchema: pb.Schema = nativeShuffle.nativeSchema
+        val metrics = MetricNode(Map(), inputRDD.metrics :: Nil)
+
+        new NativeRDD(
+          shuffledRDD.sparkContext,
+          metrics,
+          shuffledRDD.partitions,
+          new OneToOneDependency(shuffledRDD) :: Nil,
+          true,
+          (partition, taskContext) => {
+
+            // use reflection to get partitionSpec because ShuffledRowRDDPartition is private
+            val sqlMetricsReporter = taskContext.taskMetrics().createTempShuffleReadMetrics()
+            val spec = FieldUtils
+              .readDeclaredField(partition, "spec", true)
+              .asInstanceOf[ShufflePartitionSpec]
+            val reader = spec match {
+              case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) =>
+                SparkEnv.get.shuffleManager.getReader(
+                  shuffleHandle,
+                  startReducerIndex,
+                  endReducerIndex,
+                  taskContext,
+                  sqlMetricsReporter)
+
+              case PartialReducerPartitionSpec(reducerIndex, startMapIndex, endMapIndex, _) =>
+                SparkEnv.get.shuffleManager.getReader(
+                  shuffleHandle,
+                  startMapIndex,
+                  endMapIndex,
+                  reducerIndex,
+                  reducerIndex + 1,
+                  taskContext,
+                  sqlMetricsReporter)
+
+              case PartialMapperPartitionSpec(mapIndex, startReducerIndex, endReducerIndex) =>
+                SparkEnv.get.shuffleManager.getReader(
+                  shuffleHandle,
+                  mapIndex,
+                  mapIndex + 1,
+                  startReducerIndex,
+                  endReducerIndex,
+                  taskContext,
+                  sqlMetricsReporter)
+            }
+
+            // store fetch iterator in jni resource before native compute
+            val jniResourceId = s"NativeShuffleReadExec:${UUID.randomUUID().toString}"
+            JniBridge.resourcesMap.put(
+              jniResourceId,
+              () => {
+                reader.asInstanceOf[BlazeBlockStoreShuffleReader[_, _]].readIpc()
+              })
+
+            pb.PhysicalPlanNode
+              .newBuilder()
+              .setIpcReader(
+                pb.IpcReaderExecNode
+                  .newBuilder()
+                  .setSchema(nativeSchema)
+                  .setNumPartitions(shuffledRDD.getNumPartitions)
+                  .setIpcProviderResourceId(jniResourceId)
+                  .build())
+              .build()
+          })
+    }
+  }
+
   @enableIf(Seq("spark303").contains(System.getProperty("blaze.shim")))
   private def executeNativeAQEShuffleReader(exec: SparkPlan): NativeRDD = {
     import org.apache.spark.sql.execution.adaptive.CustomShuffleReaderExec
@@ -700,7 +781,7 @@ class ShimsImpl extends Shims with Logging {
   override def getSqlContext(sparkPlan: SparkPlan): SQLContext =
     sparkPlan.session.sqlContext
 
-  @enableIf(Seq("spark303").contains(System.getProperty("blaze.shim")))
+  @enableIf(Seq("spark303", "spark313").contains(System.getProperty("blaze.shim")))
   override def getSqlContext(sparkPlan: SparkPlan): SQLContext = sparkPlan.sqlContext
 
   override def createNativeExprWrapper(
@@ -711,7 +792,7 @@ class ShimsImpl extends Shims with Logging {
   }
 
   @enableIf(
-    Seq("spark303", "spark320", "spark324", "spark333").contains(
+    Seq("spark303", "spark313", "spark320", "spark324", "spark333").contains(
       System.getProperty("blaze.shim")))
   private def convertPromotePrecision(
       e: Expression,
@@ -751,7 +832,9 @@ class ShimsImpl extends Shims with Logging {
     }
   }
 
-  @enableIf(Seq("spark303", "spark320", "spark324").contains(System.getProperty("blaze.shim")))
+  @enableIf(
+    Seq("spark303", "spark313", "spark320", "spark324").contains(
+      System.getProperty("blaze.shim")))
   private def convertBloomFilterAgg(agg: AggregateFunction): Option[pb.PhysicalAggExprNode] = None
 
   @enableIf(Seq("spark333", "spark351").contains(System.getProperty("blaze.shim")))
@@ -775,7 +858,9 @@ class ShimsImpl extends Shims with Logging {
     }
   }
 
-  @enableIf(Seq("spark303", "spark320", "spark324").contains(System.getProperty("blaze.shim")))
+  @enableIf(
+    Seq("spark303", "spark313", "spark320", "spark324").contains(
+      System.getProperty("blaze.shim")))
   private def convertBloomFilterMightContain(
       e: Expression,
       isPruningExpr: Boolean,
@@ -792,7 +877,7 @@ case class ForceNativeExecutionWrapper(override val child: SparkPlan)
   override def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     copy(child = newChild)
 
-  @enableIf(Seq("spark303").contains(System.getProperty("blaze.shim")))
+  @enableIf(Seq("spark303", "spark313").contains(System.getProperty("blaze.shim")))
   override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
     copy(child = newChildren.head)
 }
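One detail worth noting in the spark313 AQE shuffle reader earlier in this file's diff: ShuffledRowRDDPartition is private to Spark, so its spec field cannot be referenced directly and is read reflectively through commons-lang3. A standalone sketch of the same mechanism (Box and its field are hypothetical stand-ins, not part of the codebase):

import org.apache.commons.lang3.reflect.FieldUtils

// A field that is invisible at the call site, like ShuffledRowRDDPartition.spec.
final class Box { private val spec: String = "CoalescedPartitionSpec(0, 2)" }

object ReflectionDemo extends App {
  // forceAccess = true lifts the access restriction, exactly as in the
  // FieldUtils.readDeclaredField(partition, "spec", true) call above.
  val value = FieldUtils.readDeclaredField(new Box, "spec", true).asInstanceOf[String]
  println(value) // prints: CoalescedPartitionSpec(0, 2)
}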
@@ -27,7 +27,7 @@ case class ConvertToNativeExec(override val child: SparkPlan) extends ConvertToN
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     copy(child = newChild)
 
-  @enableIf(Seq("spark303").contains(System.getProperty("blaze.shim")))
+  @enableIf(Seq("spark303", "spark313").contains(System.getProperty("blaze.shim")))
   override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
     copy(child = newChildren.head)
 }
@@ -48,7 +48,7 @@ case class NativeAggExec(
     with BaseAggregateExec {
 
   @enableIf(
-    Seq("spark320", "spark324", "spark333", "spark351").contains(
+    Seq("spark313", "spark320", "spark324", "spark333", "spark351").contains(
       System.getProperty("blaze.shim")))
   override val requiredChildDistributionExpressions: Option[Seq[Expression]] =
     theRequiredChildDistributionExpressions
@@ -79,7 +79,7 @@ case class NativeAggExec(
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     copy(child = newChild)
 
-  @enableIf(Seq("spark303").contains(System.getProperty("blaze.shim")))
+  @enableIf(Seq("spark303", "spark313").contains(System.getProperty("blaze.shim")))
   override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
     copy(child = newChildren.head)
 }
@@ -48,7 +48,7 @@ case class NativeBroadcastExchangeExec(mode: BroadcastMode, override val child:
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     copy(child = newChild)
 
-  @enableIf(Seq("spark303").contains(System.getProperty("blaze.shim")))
+  @enableIf(Seq("spark303", "spark313").contains(System.getProperty("blaze.shim")))
   override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
     copy(child = newChildren.head)
 }
@@ -33,7 +33,7 @@ case class NativeExpandExec(
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     copy(child = newChild)
 
-  @enableIf(Seq("spark303").contains(System.getProperty("blaze.shim")))
+  @enableIf(Seq("spark303", "spark313").contains(System.getProperty("blaze.shim")))
   override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
     copy(child = newChildren.head)
 }
@@ -29,7 +29,7 @@ case class NativeFilterExec(condition: Expression, override val child: SparkPlan
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     copy(child = newChild)
 
-  @enableIf(Seq("spark303").contains(System.getProperty("blaze.shim")))
+  @enableIf(Seq("spark303", "spark313").contains(System.getProperty("blaze.shim")))
   override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
     copy(child = newChildren.head)
 }
@@ -35,7 +35,7 @@ case class NativeGenerateExec(
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     copy(child = newChild)
 
-  @enableIf(Seq("spark303").contains(System.getProperty("blaze.shim")))
+  @enableIf(Seq("spark303", "spark313").contains(System.getProperty("blaze.shim")))
   override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
     copy(child = newChildren.head)
 }
@@ -28,7 +28,7 @@ case class NativeGlobalLimitExec(limit: Long, override val child: SparkPlan)
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     copy(child = newChild)
 
-  @enableIf(Seq("spark303").contains(System.getProperty("blaze.shim")))
+  @enableIf(Seq("spark303", "spark313").contains(System.getProperty("blaze.shim")))
   override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
     copy(child = newChildren.head)
 }
@@ -28,7 +28,7 @@ case class NativeLocalLimitExec(limit: Long, override val child: SparkPlan)
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     copy(child = newChild)
 
-  @enableIf(Seq("spark303").contains(System.getProperty("blaze.shim")))
+  @enableIf(Seq("spark303", "spark313").contains(System.getProperty("blaze.shim")))
   override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
     copy(child = newChildren.head)
 }
@@ -32,7 +32,7 @@ case class NativeParquetInsertIntoHiveTableExec(
     extends NativeParquetInsertIntoHiveTableBase(cmd, child) {
 
   @enableIf(
-    Seq("spark303", "spark320", "spark324", "spark333").contains(
+    Seq("spark303", "spark313", "spark320", "spark324", "spark333").contains(
       System.getProperty("blaze.shim")))
   override protected def getInsertIntoHiveTableCommand(
       table: CatalogTable,
@@ -77,12 +77,12 @@ case class NativeParquetInsertIntoHiveTableExec(
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     copy(child = newChild)
 
-  @enableIf(Seq("spark303").contains(System.getProperty("blaze.shim")))
+  @enableIf(Seq("spark303", "spark313").contains(System.getProperty("blaze.shim")))
   override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
     copy(child = newChildren.head)
 
   @enableIf(
-    Seq("spark303", "spark320", "spark324", "spark333").contains(
+    Seq("spark303", "spark313", "spark320", "spark324", "spark333").contains(
       System.getProperty("blaze.shim")))
   class BlazeInsertIntoHiveTable303(
       table: CatalogTable,
@@ -134,6 +134,44 @@ case class NativeParquetInsertIntoHiveTableExec(
     }
   }
 
+  @enableIf(Seq("spark313").contains(System.getProperty("blaze.shim")))
+  override def basicWriteJobStatsTracker(hadoopConf: org.apache.hadoop.conf.Configuration) = {
+    import org.apache.spark.sql.catalyst.InternalRow
+    import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker
+    import org.apache.spark.sql.execution.datasources.BasicWriteTaskStats
+    import org.apache.spark.sql.execution.datasources.BasicWriteTaskStatsTracker
+    import org.apache.spark.sql.execution.datasources.WriteTaskStats
+    import org.apache.spark.sql.execution.datasources.WriteTaskStatsTracker
+    import org.apache.spark.util.SerializableConfiguration
+
+    import scala.collection.mutable
+
+    val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
+    new BasicWriteJobStatsTracker(serializableHadoopConf, metrics) {
+      override def newTaskInstance(): WriteTaskStatsTracker = {
+        new BasicWriteTaskStatsTracker(serializableHadoopConf.value) {
+          private[this] val partitions: mutable.ArrayBuffer[InternalRow] =
+            mutable.ArrayBuffer.empty
+
+          override def newPartition(partitionValues: InternalRow): Unit = {
+            partitions.append(partitionValues)
+          }
+
+          override def newRow(_row: InternalRow): Unit = {}
+
+          override def getFinalStats(): WriteTaskStats = {
+            val outputFileStat = ParquetSinkTaskContext.get.processedOutputFiles.remove()
+            BasicWriteTaskStats(
+              partitions = partitions,
+              numFiles = 1,
+              numBytes = outputFileStat.numBytes,
+              numRows = outputFileStat.numRows)
+          }
+        }
+      }
+    }
+  }
+
   @enableIf(Seq("spark303").contains(System.getProperty("blaze.shim")))
   override def basicWriteJobStatsTracker(hadoopConf: org.apache.hadoop.conf.Configuration) = {
     import org.apache.spark.sql.catalyst.InternalRow
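A note on the tracker override above: the native Parquet sink writes files outside Spark's row-by-row write path, so newRow is deliberately a no-op and the final counts are taken from the ParquetSinkTaskContext queue that the native writer fills, with numFiles = 1 because each task emits a single output file. The spark313 copy appears to be separate from the spark303 one below it because Spark 3.1's BasicWriteTaskStats records the partition values themselves (partitions = partitions), unlike the 3.0 variant of the same class.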
@@ -36,7 +36,7 @@ case class NativeParquetSinkExec(
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     copy(child = newChild)
 
-  @enableIf(Seq("spark303").contains(System.getProperty("blaze.shim")))
+  @enableIf(Seq("spark303", "spark313").contains(System.getProperty("blaze.shim")))
   override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
     copy(child = newChildren.head)
 }
@@ -34,7 +34,7 @@ case class NativePartialTakeOrderedExec(
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     copy(child = newChild)
 
-  @enableIf(Seq("spark303").contains(System.getProperty("blaze.shim")))
+  @enableIf(Seq("spark303", "spark313").contains(System.getProperty("blaze.shim")))
   override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan =
     copy(child = newChildren.head)
 }