diff --git a/dist/pom.xml b/dist/pom.xml
index 438ebbbcf8d..da3886e3c9a 100644
--- a/dist/pom.xml
+++ b/dist/pom.xml
@@ -41,6 +41,16 @@
         <artifactId>rapids-4-spark-shuffle_${scala.binary.version}</artifactId>
         <version>${project.version}</version>
     </dependency>
+    <dependency>
+        <groupId>com.nvidia</groupId>
+        <artifactId>rapids-4-spark-shims_${scala.binary.version}</artifactId>
+        <version>${project.version}</version>
+    </dependency>
+    <dependency>
+        <groupId>org.apache.spark</groupId>
+        <artifactId>spark-sql_${scala.binary.version}</artifactId>
+        <scope>provided</scope>
+    </dependency>
@@ -49,6 +59,9 @@
                 <groupId>org.apache.maven.plugins</groupId>
                 <artifactId>maven-shade-plugin</artifactId>
+
+
+
false
true
@@ -94,6 +107,30 @@
+            <plugin>
+                <groupId>net.alchim31.maven</groupId>
+                <artifactId>scala-maven-plugin</artifactId>
+                <executions>
+                    <execution>
+                        <id>update_config</id>
+                        <phase>verify</phase>
+                        <goals>
+                            <goal>run</goal>
+                        </goals>
+                    </execution>
+                </executions>
+                <configuration>
+                    <launchers>
+                        <launcher>
+                            <id>update_rapids_config</id>
+                            <mainClass>com.nvidia.spark.rapids.RapidsConf</mainClass>
+                            <args>
+                                <arg>${project.basedir}/../docs/configs.md</arg>
+                            </args>
+                        </launcher>
+                    </launchers>
+                </configuration>
+            </plugin>
             <plugin>
                 <groupId>org.apache.rat</groupId>
                 <artifactId>apache-rat-plugin</artifactId>
diff --git a/docs/get-started/getting-started.md b/docs/get-started/getting-started.md
index 29c533f2416..5a86ce710d3 100644
--- a/docs/get-started/getting-started.md
+++ b/docs/get-started/getting-started.md
@@ -417,11 +417,14 @@ With `nv_peer_mem`, IB/RoCE-based transfers can perform zero-copy transfers dire
2) Install [UCX 1.8.1](https://github.com/openucx/ucx/releases/tag/v1.8.1).
3) You will need to configure your spark job with extra settings for UCX (we are looking to
-simplify these settings in the near future):
+simplify these settings in the near future). Choose the version of the shuffle manager
+that matches your Spark version. Currently we support
+Spark 3.0 (com.nvidia.spark.rapids.spark30.RapidsShuffleManager) and
+Spark 3.1 (com.nvidia.spark.rapids.spark31.RapidsShuffleManager):
```shell
...
---conf spark.shuffle.manager=com.nvidia.spark.RapidsShuffleManager \
+--conf spark.shuffle.manager=com.nvidia.spark.rapids.spark30.RapidsShuffleManager \
--conf spark.shuffle.service.enabled=false \
--conf spark.rapids.shuffle.transport.enabled=true \
--conf spark.executorEnv.UCX_TLS=cuda_copy,cuda_ipc,rc,tcp \
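
As an alternative to hard-coding the `spark30`/`spark31` class name, the shim layer added in this patch can resolve the matching shuffle manager at runtime. A minimal sketch, assuming a local-mode session built on the driver; the builder settings are illustrative, and `ShimLoader.getSparkShims.getRapidsShuffleManagerClass` is the same call the test suites below use:

```scala
import com.nvidia.spark.rapids.ShimLoader
import org.apache.spark.sql.SparkSession

// Sketch: resolve the shuffle manager class that matches the running Spark version
// instead of spelling out the spark30/spark31 package by hand.
object RapidsShuffleSessionSketch {
  def build(): SparkSession = {
    val shuffleManagerClass = ShimLoader.getSparkShims.getRapidsShuffleManagerClass
    SparkSession.builder()
      .config("spark.shuffle.manager", shuffleManagerClass)
      .config("spark.shuffle.service.enabled", "false")
      .config("spark.rapids.shuffle.transport.enabled", "true")
      .getOrCreate()
  }
}
```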
diff --git a/integration_tests/pom.xml b/integration_tests/pom.xml
index 564584448e1..e271790525b 100644
--- a/integration_tests/pom.xml
+++ b/integration_tests/pom.xml
@@ -28,7 +28,28 @@
     <artifactId>rapids-4-spark-integration-tests_2.12</artifactId>
     <version>0.2.0-SNAPSHOT</version>
+
+    <properties>
+        <spark.test.version>3.0.0</spark.test.version>
+    </properties>
+
+    <profiles>
+        <profile>
+            <id>spark31tests</id>
+            <properties>
+                <spark.test.version>3.1.0-SNAPSHOT</spark.test.version>
+            </properties>
+        </profile>
+    </profiles>
 
     <dependencies>
+        <dependency>
+            <groupId>org.slf4j</groupId>
+            <artifactId>jul-to-slf4j</artifactId>
+        </dependency>
+        <dependency>
+            <groupId>org.slf4j</groupId>
+            <artifactId>jcl-over-slf4j</artifactId>
+        </dependency>
         <dependency>
             <groupId>org.scala-lang</groupId>
             <artifactId>scala-library</artifactId>
@@ -36,6 +57,7 @@
             <groupId>org.apache.spark</groupId>
             <artifactId>spark-sql_${scala.binary.version}</artifactId>
+            <version>${spark.test.version}</version>
         </dependency>
         <dependency>
             <groupId>org.scalatest</groupId>
diff --git a/integration_tests/src/test/scala/com/nvidia/spark/rapids/tests/mortgage/MortgageSparkSuite.scala b/integration_tests/src/test/scala/com/nvidia/spark/rapids/tests/mortgage/MortgageSparkSuite.scala
index d36d65c1295..e2fb9d4530b 100644
--- a/integration_tests/src/test/scala/com/nvidia/spark/rapids/tests/mortgage/MortgageSparkSuite.scala
+++ b/integration_tests/src/test/scala/com/nvidia/spark/rapids/tests/mortgage/MortgageSparkSuite.scala
@@ -16,7 +16,7 @@
package com.nvidia.spark.rapids.tests.mortgage
-import com.nvidia.spark.RapidsShuffleManager
+import com.nvidia.spark.rapids.ShimLoader
import org.scalatest.FunSuite
import org.apache.spark.sql.SparkSession
@@ -34,7 +34,7 @@ class MortgageSparkSuite extends FunSuite {
.config("spark.rapids.sql.test.enabled", false)
.config("spark.rapids.sql.incompatibleOps.enabled", true)
.config("spark.rapids.sql.hasNans", false)
- val rapidsShuffle = classOf[RapidsShuffleManager].getCanonicalName
+ val rapidsShuffle = ShimLoader.getSparkShims.getRapidsShuffleManagerClass
val prop = System.getProperty("rapids.shuffle.manager.override", "false")
if (prop.equalsIgnoreCase("true")) {
println("RAPIDS SHUFFLE MANAGER ACTIVE")
diff --git a/integration_tests/src/test/scala/com/nvidia/spark/rapids/tests/tpch/TpchLikeSparkSuite.scala b/integration_tests/src/test/scala/com/nvidia/spark/rapids/tests/tpch/TpchLikeSparkSuite.scala
index ebbf9fa1067..5355d7d2549 100644
--- a/integration_tests/src/test/scala/com/nvidia/spark/rapids/tests/tpch/TpchLikeSparkSuite.scala
+++ b/integration_tests/src/test/scala/com/nvidia/spark/rapids/tests/tpch/TpchLikeSparkSuite.scala
@@ -16,8 +16,8 @@
package com.nvidia.spark.rapids.tests.tpch
-import com.nvidia.spark.RapidsShuffleManager
import com.nvidia.spark.rapids.{ColumnarRdd, ExecutionPlanCaptureCallback}
+import com.nvidia.spark.rapids.ShimLoader
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.apache.spark.sql.{DataFrame, SparkSession}
@@ -44,7 +44,7 @@ class TpchLikeSparkSuite extends FunSuite with BeforeAndAfterAll {
.config("spark.rapids.sql.explain", true)
.config("spark.rapids.sql.incompatibleOps.enabled", true)
.config("spark.rapids.sql.hasNans", false)
- val rapidsShuffle = classOf[RapidsShuffleManager].getCanonicalName
+ val rapidsShuffle = ShimLoader.getSparkShims.getRapidsShuffleManagerClass
val prop = System.getProperty("rapids.shuffle.manager.override", "false")
if (prop.equalsIgnoreCase("true")) {
println("RAPIDS SHUFFLE MANAGER ACTIVE")
diff --git a/pom.xml b/pom.xml
index 2a9968fda18..e77777ed077 100644
--- a/pom.xml
+++ b/pom.xml
@@ -76,6 +76,7 @@
         <module>sql-plugin</module>
         <module>tests</module>
         <module>integration_tests</module>
+        <module>shims</module>
         <module>api_validation</module>
@@ -128,6 +129,9 @@
true
+        <profile>
+            <id>spark31tests</id>
+        </profile>
@@ -152,6 +156,7 @@
UTF-8
not qarun
false
+        <slf4j.version>1.7.30</slf4j.version>
@@ -168,6 +173,17 @@
             <classifier>${cuda.version}</classifier>
             <scope>provided</scope>
         </dependency>
+        <dependency>
+            <groupId>org.slf4j</groupId>
+            <artifactId>jul-to-slf4j</artifactId>
+            <version>${slf4j.version}</version>
+        </dependency>
+        <dependency>
+            <groupId>org.slf4j</groupId>
+            <artifactId>jcl-over-slf4j</artifactId>
+            <version>${slf4j.version}</version>
+        </dependency>
         <dependency>
             <groupId>org.scala-lang</groupId>
             <artifactId>scala-library</artifactId>
@@ -547,5 +563,15 @@
true
+        <repository>
+            <id>apache-snapshots-repo</id>
+            <url>https://repository.apache.org/content/repositories/snapshots/</url>
+            <releases>
+                <enabled>false</enabled>
+            </releases>
+            <snapshots>
+                <enabled>true</enabled>
+            </snapshots>
+        </repository>
diff --git a/shims/aggregator/pom.xml b/shims/aggregator/pom.xml
new file mode 100644
index 00000000000..062f4da1845
--- /dev/null
+++ b/shims/aggregator/pom.xml
@@ -0,0 +1,49 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!-- Copyright (c) 2020, NVIDIA CORPORATION. Licensed under the Apache License, Version 2.0. -->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+
+    <parent>
+        <groupId>com.nvidia</groupId>
+        <artifactId>rapids-4-spark-shims-aggregator_2.12</artifactId>
+        <version>0.2.0-SNAPSHOT</version>
+        <relativePath>../pom.xml</relativePath>
+    </parent>
+    <groupId>com.nvidia</groupId>
+    <artifactId>rapids-4-spark-shims_2.12</artifactId>
+    <packaging>jar</packaging>
+    <name>RAPIDS Accelerator for Apache Spark SQL Plugin Shim Aggregator</name>
+    <description>The RAPIDS SQL plugin for Apache Spark Shim Aggregator</description>
+    <version>0.2.0-SNAPSHOT</version>
+
+    <dependencies>
+        <dependency>
+            <groupId>com.nvidia</groupId>
+            <artifactId>rapids-4-spark-shims-spark31_${scala.binary.version}</artifactId>
+            <version>${project.version}</version>
+            <scope>compile</scope>
+        </dependency>
+        <dependency>
+            <groupId>com.nvidia</groupId>
+            <artifactId>rapids-4-spark-shims-spark30_${scala.binary.version}</artifactId>
+            <version>${project.version}</version>
+            <scope>compile</scope>
+        </dependency>
+    </dependencies>
+</project>
diff --git a/shims/pom.xml b/shims/pom.xml
new file mode 100644
index 00000000000..80333a2bb9b
--- /dev/null
+++ b/shims/pom.xml
@@ -0,0 +1,69 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!-- Copyright (c) 2020, NVIDIA CORPORATION. Licensed under the Apache License, Version 2.0. -->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+
+    <parent>
+        <groupId>com.nvidia</groupId>
+        <artifactId>rapids-4-spark-parent</artifactId>
+        <version>0.2.0-SNAPSHOT</version>
+        <relativePath>../pom.xml</relativePath>
+    </parent>
+    <groupId>com.nvidia</groupId>
+    <artifactId>rapids-4-spark-shims-aggregator_2.12</artifactId>
+    <packaging>pom</packaging>
+    <name>RAPIDS Accelerator for Apache Spark SQL Plugin Shims</name>
+    <description>The RAPIDS SQL plugin for Apache Spark Shims</description>
+    <version>0.2.0-SNAPSHOT</version>
+
+    <modules>
+        <module>spark30</module>
+        <module>spark31</module>
+        <module>aggregator</module>
+    </modules>
+
+    <dependencies>
+        <dependency>
+            <groupId>com.nvidia</groupId>
+            <artifactId>rapids-4-spark-sql_${scala.binary.version}</artifactId>
+            <version>${project.version}</version>
+        </dependency>
+        <dependency>
+            <groupId>ai.rapids</groupId>
+            <artifactId>cudf</artifactId>
+            <classifier>${cuda.version}</classifier>
+        </dependency>
+    </dependencies>
+
+    <build>
+        <plugins>
+            <plugin>
+                <groupId>net.alchim31.maven</groupId>
+                <artifactId>scala-maven-plugin</artifactId>
+            </plugin>
+            <plugin>
+                <groupId>org.apache.rat</groupId>
+                <artifactId>apache-rat-plugin</artifactId>
+                <configuration>
+                    <excludes>
+                        <exclude>**/src/main/resources/META-INF/services/*</exclude>
+                    </excludes>
+                </configuration>
+            </plugin>
+        </plugins>
+    </build>
+</project>
diff --git a/shims/spark30/pom.xml b/shims/spark30/pom.xml
new file mode 100644
index 00000000000..665e3b64fa0
--- /dev/null
+++ b/shims/spark30/pom.xml
@@ -0,0 +1,46 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!-- Copyright (c) 2020, NVIDIA CORPORATION. Licensed under the Apache License, Version 2.0. -->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+
+    <parent>
+        <groupId>com.nvidia</groupId>
+        <artifactId>rapids-4-spark-shims-aggregator_2.12</artifactId>
+        <version>0.2.0-SNAPSHOT</version>
+        <relativePath>../pom.xml</relativePath>
+    </parent>
+    <groupId>com.nvidia</groupId>
+    <artifactId>rapids-4-spark-shims-spark30_2.12</artifactId>
+    <name>RAPIDS Accelerator for Apache Spark SQL Plugin Spark 3.0 Shim</name>
+    <description>The RAPIDS SQL plugin for Apache Spark 3.0 Shim</description>
+    <version>0.2.0-SNAPSHOT</version>
+
+    <properties>
+        <spark30.version>3.0.0</spark30.version>
+    </properties>
+
+    <dependencies>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-sql_${scala.binary.version}</artifactId>
+            <version>${spark30.version}</version>
+        </dependency>
+    </dependencies>
+</project>
diff --git a/shims/spark30/src/main/resources/META-INF/services/com.nvidia.spark.rapids.SparkShimServiceProvider b/shims/spark30/src/main/resources/META-INF/services/com.nvidia.spark.rapids.SparkShimServiceProvider
new file mode 100644
index 00000000000..a7727cca514
--- /dev/null
+++ b/shims/spark30/src/main/resources/META-INF/services/com.nvidia.spark.rapids.SparkShimServiceProvider
@@ -0,0 +1 @@
+com.nvidia.spark.rapids.shims.spark30.Spark30ShimServiceProvider
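
This service file is how the plugin discovers the right shim at runtime: every shim module registers its `SparkShimServiceProvider` under the same service name, and a loader keeps the provider whose `matchesVersion` accepts the running Spark version. A rough sketch of that lookup, assuming only the provider API shown in this patch (`matchesVersion`/`buildShim`); the actual loader lives in the sql-plugin and may differ in detail:

```scala
import java.util.ServiceLoader

import scala.collection.JavaConverters._

import com.nvidia.spark.rapids.{SparkShims, SparkShimServiceProvider}

// Sketch of ServiceLoader-based shim discovery; error handling kept minimal.
object ShimDiscoverySketch {
  def shimsFor(sparkVersion: String): SparkShims = {
    val providers = ServiceLoader.load(classOf[SparkShimServiceProvider]).asScala
    providers.find(_.matchesVersion(sparkVersion))
      .map(_.buildShim)
      .getOrElse(throw new IllegalArgumentException(
        s"Could not find a SparkShimServiceProvider for Spark $sparkVersion"))
  }
}
```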
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala b/shims/spark30/src/main/scala/com/nvidia/spark/rapids/shims/spark30/GpuBroadcastHashJoinExec.scala
similarity index 95%
rename from sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala
rename to shims/spark30/src/main/scala/com/nvidia/spark/rapids/shims/spark30/GpuBroadcastHashJoinExec.scala
index ac444d16471..19e64cd0176 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHashJoinExec.scala
+++ b/shims/spark30/src/main/scala/com/nvidia/spark/rapids/shims/spark30/GpuBroadcastHashJoinExec.scala
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.rapids.execution
+package com.nvidia.spark.rapids.shims.spark30
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.GpuMetricNames._
@@ -28,8 +28,12 @@ import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExec, SerializeConcatHostBuffersDeserializeBatch}
import org.apache.spark.sql.vectorized.ColumnarBatch
+/**
+ * Spark 3.1 changed packages of BuildLeft, BuildRight, BuildSide
+ */
class GpuBroadcastHashJoinMeta(
join: BroadcastHashJoinExec,
conf: RapidsConf,
diff --git a/shims/spark30/src/main/scala/com/nvidia/spark/rapids/shims/spark30/GpuBroadcastNestedLoopJoinExec.scala b/shims/spark30/src/main/scala/com/nvidia/spark/rapids/shims/spark30/GpuBroadcastNestedLoopJoinExec.scala
new file mode 100644
index 00000000000..dd81a98f8ce
--- /dev/null
+++ b/shims/spark30/src/main/scala/com/nvidia/spark/rapids/shims/spark30/GpuBroadcastNestedLoopJoinExec.scala
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.shims.spark30
+
+import com.nvidia.spark.rapids.{GpuBuildLeft, GpuBuildRight, GpuBuildSide}
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, BuildLeft, BuildRight}
+import org.apache.spark.sql.rapids.execution._
+
+/**
+ * Spark 3.1 changed packages of BuildLeft, BuildRight, BuildSide
+ */
+case class GpuBroadcastNestedLoopJoinExec(
+ left: SparkPlan,
+ right: SparkPlan,
+ join: BroadcastNestedLoopJoinExec,
+ joinType: JoinType,
+ condition: Option[Expression],
+ targetSizeBytes: Long)
+ extends GpuBroadcastNestedLoopJoinExecBase(left, right, join, joinType, condition,
+ targetSizeBytes) {
+
+ def getGpuBuildSide: GpuBuildSide = {
+ GpuJoinUtils.getGpuBuildSide(join.buildSide)
+ }
+}
diff --git a/shims/spark30/src/main/scala/com/nvidia/spark/rapids/shims/spark30/GpuFirst.scala b/shims/spark30/src/main/scala/com/nvidia/spark/rapids/shims/spark30/GpuFirst.scala
new file mode 100644
index 00000000000..7de2090ed02
--- /dev/null
+++ b/shims/spark30/src/main/scala/com/nvidia/spark/rapids/shims/spark30/GpuFirst.scala
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.shims.spark30
+
+import com.nvidia.spark.rapids.GpuLiteral
+
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
+import org.apache.spark.sql.rapids.GpuFirstBase
+
+/**
+ * Parameters to GpuFirst changed in Spark 3.1
+ */
+case class GpuFirst(child: Expression, ignoreNullsExpr: Expression) extends GpuFirstBase(child) {
+ override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil
+
+ override val ignoreNulls: Boolean = ignoreNullsExpr match {
+ case l: Literal => l.value.asInstanceOf[Boolean]
+ case l: GpuLiteral => l.value.asInstanceOf[Boolean]
+ case _ => throw new IllegalArgumentException(
+ s"$this should only receive literals for ignoreNulls expression")
+ }
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val defaultCheck = super.checkInputDataTypes()
+ if (defaultCheck.isFailure) {
+ defaultCheck
+ } else if (!ignoreNullsExpr.foldable) {
+ TypeCheckFailure(s"The second argument of GpuFirst must be a boolean literal, but " +
+ s"got: ${ignoreNullsExpr.sql}")
+ } else {
+ TypeCheckSuccess
+ }
+ }
+}
+
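
For contrast: in the Spark 3.1 shim later in this patch, `ignoreNulls` becomes a plain `Boolean` constructor parameter instead of a literal expression. A small illustrative sketch of the difference (the `Literal(1)` child is just a placeholder, and normally only one shim is on the compile classpath):

```scala
import org.apache.spark.sql.catalyst.expressions.Literal

object GpuFirstShimSketch {
  // Spark 3.0 shim: the flag arrives as a foldable expression and is unpacked above.
  val first30 = com.nvidia.spark.rapids.shims.spark30.GpuFirst(Literal(1), Literal(false))

  // Spark 3.1 shim (defined later in this diff): the flag is already a Boolean.
  // val first31 = com.nvidia.spark.rapids.shims.spark31.GpuFirst(Literal(1), ignoreNulls = false)
}
```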
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuHashJoin.scala b/shims/spark30/src/main/scala/com/nvidia/spark/rapids/shims/spark30/GpuHashJoin.scala
similarity index 99%
rename from sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuHashJoin.scala
rename to shims/spark30/src/main/scala/com/nvidia/spark/rapids/shims/spark30/GpuHashJoin.scala
index f0aaec323d5..b6325be8b52 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuHashJoin.scala
+++ b/shims/spark30/src/main/scala/com/nvidia/spark/rapids/shims/spark30/GpuHashJoin.scala
@@ -13,9 +13,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package com.nvidia.spark.rapids
+package com.nvidia.spark.rapids.shims.spark30
import ai.rapids.cudf.{NvtxColor, Table}
+import com.nvidia.spark.rapids._
import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
diff --git a/shims/spark30/src/main/scala/com/nvidia/spark/rapids/shims/spark30/GpuLast.scala b/shims/spark30/src/main/scala/com/nvidia/spark/rapids/shims/spark30/GpuLast.scala
new file mode 100644
index 00000000000..d5df5f9c424
--- /dev/null
+++ b/shims/spark30/src/main/scala/com/nvidia/spark/rapids/shims/spark30/GpuLast.scala
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.shims.spark30
+
+import com.nvidia.spark.rapids.GpuLiteral
+
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
+import org.apache.spark.sql.rapids.GpuLastBase
+
+/**
+ * Parameters to GpuLast changed in Spark 3.1
+ */
+case class GpuLast(child: Expression, ignoreNullsExpr: Expression) extends GpuLastBase(child) {
+ override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil
+
+ override val ignoreNulls: Boolean = ignoreNullsExpr match {
+ case l: Literal => l.value.asInstanceOf[Boolean]
+ case l: GpuLiteral => l.value.asInstanceOf[Boolean]
+ case _ => throw new IllegalArgumentException(
+ s"$this should only receive literals for ignoreNulls expression")
+ }
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val defaultCheck = super.checkInputDataTypes()
+ if (defaultCheck.isFailure) {
+ defaultCheck
+ } else if (!ignoreNullsExpr.foldable) {
+ TypeCheckFailure(s"The second argument of GpuLast must be a boolean literal, but " +
+ s"got: ${ignoreNullsExpr.sql}")
+ } else {
+ TypeCheckSuccess
+ }
+ }
+}
+
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala b/shims/spark30/src/main/scala/com/nvidia/spark/rapids/shims/spark30/GpuShuffledHashJoinExec.scala
similarity index 92%
rename from sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala
rename to shims/spark30/src/main/scala/com/nvidia/spark/rapids/shims/spark30/GpuShuffledHashJoinExec.scala
index 7ae310bd40f..16aadf62933 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala
+++ b/shims/spark30/src/main/scala/com/nvidia/spark/rapids/shims/spark30/GpuShuffledHashJoinExec.scala
@@ -14,8 +14,9 @@
* limitations under the License.
*/
-package com.nvidia.spark.rapids
+package com.nvidia.spark.rapids.shims.spark30
+import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.GpuMetricNames._
import org.apache.spark.TaskContext
@@ -29,6 +30,19 @@ import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide, S
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.vectorized.ColumnarBatch
+object GpuJoinUtils {
+ def getGpuBuildSide(buildSide: BuildSide): GpuBuildSide = {
+ buildSide match {
+ case BuildRight => GpuBuildRight
+ case BuildLeft => GpuBuildLeft
+ case _ => throw new Exception("unknown buildSide Type")
+ }
+ }
+}
+
+/**
+ * Spark 3.1 changed packages of BuildLeft, BuildRight, BuildSide
+ */
class GpuShuffledHashJoinMeta(
join: ShuffledHashJoinExec,
conf: RapidsConf,
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortMergeJoinExec.scala b/shims/spark30/src/main/scala/com/nvidia/spark/rapids/shims/spark30/GpuSortMergeJoinExec.scala
similarity index 96%
rename from sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortMergeJoinExec.scala
rename to shims/spark30/src/main/scala/com/nvidia/spark/rapids/shims/spark30/GpuSortMergeJoinExec.scala
index af7e6070263..4b47007ee0e 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortMergeJoinExec.scala
+++ b/shims/spark30/src/main/scala/com/nvidia/spark/rapids/shims/spark30/GpuSortMergeJoinExec.scala
@@ -14,19 +14,23 @@
* limitations under the License.
*/
-package com.nvidia.spark.rapids
+package com.nvidia.spark.rapids.shims.spark30
+
+import com.nvidia.spark.rapids._
-import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}
import org.apache.spark.sql.execution.SortExec
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, SortMergeJoinExec}
+/**
+ * HashJoin changed in Spark 3.1, requiring a shim
+ */
class GpuSortMergeJoinMeta(
join: SortMergeJoinExec,
conf: RapidsConf,
parent: Option[RapidsMeta[_, _, _]],
rule: ConfKeysAndIncompat)
- extends SparkPlanMeta[SortMergeJoinExec](join, conf, parent, rule) with Logging {
+ extends SparkPlanMeta[SortMergeJoinExec](join, conf, parent, rule) {
val leftKeys: Seq[BaseExprMeta[_]] =
join.leftKeys.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
diff --git a/shims/spark30/src/main/scala/com/nvidia/spark/rapids/shims/spark30/Spark30ShimServiceProvider.scala b/shims/spark30/src/main/scala/com/nvidia/spark/rapids/shims/spark30/Spark30ShimServiceProvider.scala
new file mode 100644
index 00000000000..3d43a547aca
--- /dev/null
+++ b/shims/spark30/src/main/scala/com/nvidia/spark/rapids/shims/spark30/Spark30ShimServiceProvider.scala
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.shims.spark30
+
+import com.nvidia.spark.rapids.{SparkShims, SparkShimServiceProvider}
+
+class Spark30ShimServiceProvider extends SparkShimServiceProvider {
+
+ val SPARK30VERSIONNAME = "3.0.0"
+
+ def matchesVersion(version: String): Boolean = {
+ version == SPARK30VERSIONNAME
+ }
+
+ def buildShim: SparkShims = {
+ new Spark30Shims()
+ }
+}
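
Note that `matchesVersion` is an exact string comparison, so only Spark `3.0.0` selects this shim. A short REPL-style check of that behavior, illustrative only:

```scala
val provider = new com.nvidia.spark.rapids.shims.spark30.Spark30ShimServiceProvider
assert(provider.matchesVersion("3.0.0"))   // this shim serves exactly Spark 3.0.0
assert(!provider.matchesVersion("3.0.1"))  // other versions need their own provider entry
```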
diff --git a/shims/spark30/src/main/scala/com/nvidia/spark/rapids/shims/spark30/Spark30Shims.scala b/shims/spark30/src/main/scala/com/nvidia/spark/rapids/shims/spark30/Spark30Shims.scala
new file mode 100644
index 00000000000..23c37eb44d9
--- /dev/null
+++ b/shims/spark30/src/main/scala/com/nvidia/spark/rapids/shims/spark30/Spark30Shims.scala
@@ -0,0 +1,196 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.shims.spark30
+
+import java.time.ZoneId
+
+import com.nvidia.spark.rapids._
+import com.nvidia.spark.rapids.spark30.RapidsShuffleManager
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.datasources.HadoopFsRelation
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.rapids.GpuTimeSub
+import org.apache.spark.sql.rapids.execution.GpuBroadcastNestedLoopJoinExecBase
+import org.apache.spark.sql.rapids.shims.spark30._
+import org.apache.spark.sql.types._
+import org.apache.spark.storage.{BlockId, BlockManagerId}
+import org.apache.spark.unsafe.types.CalendarInterval
+
+class Spark30Shims extends SparkShims {
+
+ override def getScalaUDFAsExpression(
+ function: AnyRef,
+ dataType: DataType,
+ children: Seq[Expression],
+ inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil,
+ outputEncoder: Option[ExpressionEncoder[_]] = None,
+ udfName: Option[String] = None,
+ nullable: Boolean = true,
+ udfDeterministic: Boolean = true): Expression = {
+ // outputEncoder is only used in Spark 3.1+
+ ScalaUDF(function, dataType, children, inputEncoders, udfName, nullable, udfDeterministic)
+ }
+
+ override def getMapSizesByExecutorId(
+ shuffleId: Int,
+ startMapIndex: Int,
+ endMapIndex: Int,
+ startPartition: Int,
+ endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+ // startMapIndex and endMapIndex ignored as we don't support those for gpu shuffle.
+ SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, startPartition, endPartition)
+ }
+
+ override def getGpuBroadcastNestedLoopJoinShim(
+ left: SparkPlan,
+ right: SparkPlan,
+ join: BroadcastNestedLoopJoinExec,
+ joinType: JoinType,
+ condition: Option[Expression],
+ targetSizeBytes: Long): GpuBroadcastNestedLoopJoinExecBase = {
+ GpuBroadcastNestedLoopJoinExec(left, right, join, joinType, condition, targetSizeBytes)
+ }
+
+  override def isGpuHashJoin(plan: SparkPlan): Boolean = {
+    plan match {
+      case _: GpuHashJoin => true
+      case _ => false
+    }
+  }
+
+  override def isGpuBroadcastHashJoin(plan: SparkPlan): Boolean = {
+    plan match {
+      case _: GpuBroadcastHashJoinExec => true
+      case _ => false
+    }
+  }
+
+  override def isGpuShuffledHashJoin(plan: SparkPlan): Boolean = {
+    plan match {
+      case _: GpuShuffledHashJoinExec => true
+      case _ => false
+    }
+  }
+
+ override def getExecs: Seq[ExecRule[_ <: SparkPlan]] = {
+ Seq(
+ GpuOverrides.exec[FileSourceScanExec](
+ "Reading data from files, often from Hive tables",
+ (fsse, conf, p, r) => new SparkPlanMeta[FileSourceScanExec](fsse, conf, p, r) {
+ // partition filters and data filters are not run on the GPU
+ override val childExprs: Seq[ExprMeta[_]] = Seq.empty
+
+ override def tagPlanForGpu(): Unit = GpuFileSourceScanExec.tagSupport(this)
+
+ override def convertToGpu(): GpuExec = {
+ val newRelation = HadoopFsRelation(
+ wrapped.relation.location,
+ wrapped.relation.partitionSchema,
+ wrapped.relation.dataSchema,
+ wrapped.relation.bucketSpec,
+ GpuFileSourceScanExec.convertFileFormat(wrapped.relation.fileFormat),
+ wrapped.relation.options)(wrapped.relation.sparkSession)
+ GpuFileSourceScanExec(
+ newRelation,
+ wrapped.output,
+ wrapped.requiredSchema,
+ wrapped.partitionFilters,
+ wrapped.optionalBucketSet,
+ wrapped.dataFilters,
+ wrapped.tableIdentifier)
+ }
+ }),
+ GpuOverrides.exec[SortMergeJoinExec](
+ "Sort merge join, replacing with shuffled hash join",
+ (join, conf, p, r) => new GpuSortMergeJoinMeta(join, conf, p, r)),
+ GpuOverrides.exec[BroadcastHashJoinExec](
+ "Implementation of join using broadcast data",
+ (join, conf, p, r) => new GpuBroadcastHashJoinMeta(join, conf, p, r)),
+ GpuOverrides.exec[ShuffledHashJoinExec](
+ "Implementation of join using hashed shuffled data",
+ (join, conf, p, r) => new GpuShuffledHashJoinMeta(join, conf, p, r)),
+ )
+ }
+
+ override def getExprs: Seq[ExprRule[_ <: Expression]] = {
+ Seq(
+ GpuOverrides.expr[TimeSub](
+ "Subtracts interval from timestamp",
+ (a, conf, p, r) => new BinaryExprMeta[TimeSub](a, conf, p, r) {
+ override def tagExprForGpu(): Unit = {
+ a.interval match {
+ case Literal(intvl: CalendarInterval, DataTypes.CalendarIntervalType) =>
+ if (intvl.months != 0) {
+ willNotWorkOnGpu("interval months isn't supported")
+ }
+ case _ =>
+ willNotWorkOnGpu("only literals are supported for intervals")
+ }
+ if (ZoneId.of(a.timeZoneId.get).normalized() != GpuOverrides.UTC_TIMEZONE_ID) {
+ willNotWorkOnGpu("Only UTC zone id is supported")
+ }
+ }
+
+ override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = {
+ GpuTimeSub(lhs, rhs)
+ }
+ }
+ ),
+ GpuOverrides.expr[First](
+ "first aggregate operator",
+ (a, conf, p, r) => new ExprMeta[First](a, conf, p, r) {
+ val child: BaseExprMeta[_] = GpuOverrides.wrapExpr(a.child, conf, Some(this))
+ val ignoreNulls: BaseExprMeta[_] =
+ GpuOverrides.wrapExpr(a.ignoreNullsExpr, conf, Some(this))
+ override val childExprs: Seq[BaseExprMeta[_]] = Seq(child, ignoreNulls)
+
+ override def convertToGpu(): GpuExpression =
+ GpuFirst(child.convertToGpu(), ignoreNulls.convertToGpu())
+ }),
+ GpuOverrides.expr[Last](
+ "last aggregate operator",
+ (a, conf, p, r) => new ExprMeta[Last](a, conf, p, r) {
+ val child: BaseExprMeta[_] = GpuOverrides.wrapExpr(a.child, conf, Some(this))
+ val ignoreNulls: BaseExprMeta[_] =
+ GpuOverrides.wrapExpr(a.ignoreNullsExpr, conf, Some(this))
+ override val childExprs: Seq[BaseExprMeta[_]] = Seq(child, ignoreNulls)
+
+ override def convertToGpu(): GpuExpression =
+ GpuLast(child.convertToGpu(), ignoreNulls.convertToGpu())
+ }),
+ )
+ }
+
+ override def getBuildSide(join: HashJoin): GpuBuildSide = {
+ GpuJoinUtils.getGpuBuildSide(join.buildSide)
+ }
+
+ override def getBuildSide(join: BroadcastNestedLoopJoinExec): GpuBuildSide = {
+ GpuJoinUtils.getGpuBuildSide(join.buildSide)
+ }
+
+ override def getRapidsShuffleManagerClass: String = {
+ classOf[RapidsShuffleManager].getCanonicalName
+ }
+}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/RapidsShuffleManager.scala b/shims/spark30/src/main/scala/com/nvidia/spark/rapids/spark30/RapidsShuffleManager.scala
similarity index 87%
rename from sql-plugin/src/main/scala/com/nvidia/spark/RapidsShuffleManager.scala
rename to shims/spark30/src/main/scala/com/nvidia/spark/rapids/spark30/RapidsShuffleManager.scala
index db9e415b9e0..42094ea1487 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/RapidsShuffleManager.scala
+++ b/shims/spark30/src/main/scala/com/nvidia/spark/rapids/spark30/RapidsShuffleManager.scala
@@ -14,10 +14,10 @@
* limitations under the License.
*/
-package com.nvidia.spark
+package com.nvidia.spark.rapids.spark30
import org.apache.spark.SparkConf
-import org.apache.spark.sql.rapids.RapidsShuffleInternalManager
+import org.apache.spark.sql.rapids.shims.spark30.RapidsShuffleInternalManager
/** A shuffle manager optimized for the RAPIDS Plugin for Apache Spark. */
sealed class RapidsShuffleManager(
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileSourceScanExec.scala b/shims/spark30/src/main/scala/org/apache/spark/sql/rapids/shims/spark30/GpuFileSourceScanExec.scala
similarity index 89%
rename from sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileSourceScanExec.scala
rename to shims/spark30/src/main/scala/org/apache/spark/sql/rapids/shims/spark30/GpuFileSourceScanExec.scala
index e2c8c976851..90d135226e1 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileSourceScanExec.scala
+++ b/shims/spark30/src/main/scala/org/apache/spark/sql/rapids/shims/spark30/GpuFileSourceScanExec.scala
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.rapids
+package org.apache.spark.sql.rapids.shims.spark30
import java.util.concurrent.TimeUnit.NANOSECONDS
@@ -30,6 +30,7 @@ import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.rapids.GpuFileSourceScanExecBase
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.collection.BitSet
@@ -42,7 +43,7 @@ case class GpuFileSourceScanExec(
optionalBucketSet: Option[BitSet],
dataFilters: Seq[Expression],
override val tableIdentifier: Option[TableIdentifier])
- extends DataSourceScanExec with GpuExec {
+ extends DataSourceScanExec with GpuFileSourceScanExecBase with GpuExec {
override val nodeName: String = {
s"GpuScan $relation ${tableIdentifier.map(_.unquotedString).getOrElse("")}"
@@ -60,9 +61,16 @@ case class GpuFileSourceScanExec(
// that is the logicalRelation. We don't know what its used for exactly but haven't
// run into any issues in testing using the one we create here.
@transient val logicalRelation = LogicalRelation(relation)
- constructor.newInstance(relation, output, requiredSchema, partitionFilters,
- optionalBucketSet, dataFilters, tableIdentifier,
- logicalRelation).asInstanceOf[FileSourceScanExec]
+ try {
+ constructor.newInstance(relation, output, requiredSchema, partitionFilters,
+ optionalBucketSet, dataFilters, tableIdentifier,
+ logicalRelation).asInstanceOf[FileSourceScanExec]
+ } catch {
+ case il: IllegalArgumentException =>
+ // TODO - workaround until https://github.com/NVIDIA/spark-rapids/issues/354
+ constructor.newInstance(relation, output, requiredSchema, partitionFilters,
+ optionalBucketSet, None, dataFilters, tableIdentifier).asInstanceOf[FileSourceScanExec]
+ }
} else {
constructor.newInstance(relation, output, requiredSchema, partitionFilters,
optionalBucketSet, dataFilters, tableIdentifier).asInstanceOf[FileSourceScanExec]
diff --git a/shims/spark30/src/main/scala/org/apache/spark/sql/rapids/shims/spark30/RapidsShuffleInternalManager.scala b/shims/spark30/src/main/scala/org/apache/spark/sql/rapids/shims/spark30/RapidsShuffleInternalManager.scala
new file mode 100644
index 00000000000..ca6125954a8
--- /dev/null
+++ b/shims/spark30/src/main/scala/org/apache/spark/sql/rapids/shims/spark30/RapidsShuffleInternalManager.scala
@@ -0,0 +1,56 @@
+/*
+ * Copyright (c) 2019-2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids.shims.spark30
+
+import org.apache.spark.{SparkConf, TaskContext}
+import org.apache.spark.shuffle._
+import org.apache.spark.sql.rapids.RapidsShuffleInternalManagerBase
+
+/**
+ * A shuffle manager optimized for the RAPIDS Plugin for Apache Spark.
+ * @note This is an internal class to obtain access to the private
+ * `ShuffleManager` and `SortShuffleManager` classes. When configuring
+ * Apache Spark to use the RAPIDS shuffle manager,
+ * [[com.nvidia.spark.rapids.spark30.RapidsShuffleManager]] should be used as that is
+ * the public class.
+ */
+class RapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean)
+ extends RapidsShuffleInternalManagerBase(conf, isDriver) {
+
+ override def getReaderForRange[K, C](
+ handle: ShuffleHandle,
+ startMapIndex: Int,
+ endMapIndex: Int,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext,
+ metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+ // NOTE: This type of reader is not possible for gpu shuffle, as we'd need
+ // to use the optimization within our manager, and we don't.
+ wrapped.getReaderForRange(RapidsShuffleInternalManagerBase.unwrapHandle(handle),
+ startMapIndex, endMapIndex, startPartition, endPartition, context, metrics)
+ }
+
+ def getReader[K, C](
+ handle: ShuffleHandle,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext,
+ metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+ getReaderInternal(handle, 0, Int.MaxValue, startPartition, endPartition, context, metrics)
+ }
+}
diff --git a/shims/spark31/pom.xml b/shims/spark31/pom.xml
new file mode 100644
index 00000000000..3e4604f9bc3
--- /dev/null
+++ b/shims/spark31/pom.xml
@@ -0,0 +1,45 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!-- Copyright (c) 2020, NVIDIA CORPORATION. Licensed under the Apache License, Version 2.0. -->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+
+    <parent>
+        <groupId>com.nvidia</groupId>
+        <artifactId>rapids-4-spark-shims-aggregator_2.12</artifactId>
+        <version>0.2.0-SNAPSHOT</version>
+        <relativePath>../pom.xml</relativePath>
+    </parent>
+    <groupId>com.nvidia</groupId>
+    <artifactId>rapids-4-spark-shims-spark31_2.12</artifactId>
+    <name>RAPIDS Accelerator for Apache Spark SQL Plugin Spark 3.1 Shim</name>
+    <description>The RAPIDS SQL plugin for Apache Spark 3.1 Shim</description>
+    <version>0.2.0-SNAPSHOT</version>
+
+    <properties>
+        <spark31.version>3.1.0-SNAPSHOT</spark31.version>
+    </properties>
+
+    <dependencies>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-sql_${scala.binary.version}</artifactId>
+            <version>${spark31.version}</version>
+        </dependency>
+    </dependencies>
+</project>
diff --git a/shims/spark31/src/main/resources/META-INF/services/com.nvidia.spark.rapids.SparkShimServiceProvider b/shims/spark31/src/main/resources/META-INF/services/com.nvidia.spark.rapids.SparkShimServiceProvider
new file mode 100644
index 00000000000..0c854281b7d
--- /dev/null
+++ b/shims/spark31/src/main/resources/META-INF/services/com.nvidia.spark.rapids.SparkShimServiceProvider
@@ -0,0 +1 @@
+com.nvidia.spark.rapids.shims.spark31.Spark31ShimServiceProvider
diff --git a/shims/spark31/src/main/scala/com/nvidia/spark/rapids/shims/spark31/GpuBroadcastHashJoinExec.scala b/shims/spark31/src/main/scala/com/nvidia/spark/rapids/shims/spark31/GpuBroadcastHashJoinExec.scala
new file mode 100644
index 00000000000..bed75498f1a
--- /dev/null
+++ b/shims/spark31/src/main/scala/com/nvidia/spark/rapids/shims/spark31/GpuBroadcastHashJoinExec.scala
@@ -0,0 +1,151 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.shims.spark31
+
+import com.nvidia.spark.rapids._
+import com.nvidia.spark.rapids.GpuMetricNames._
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
+import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
+import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashedRelationBroadcastMode}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.rapids.execution.{GpuBroadcastExchangeExec, SerializeConcatHostBuffersDeserializeBatch}
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+/**
+ * Spark 3.1 changed packages of BuildLeft, BuildRight, BuildSide
+ */
+class GpuBroadcastHashJoinMeta(
+ join: BroadcastHashJoinExec,
+ conf: RapidsConf,
+ parent: Option[RapidsMeta[_, _, _]],
+ rule: ConfKeysAndIncompat)
+ extends SparkPlanMeta[BroadcastHashJoinExec](join, conf, parent, rule) {
+
+ val leftKeys: Seq[BaseExprMeta[_]] =
+ join.leftKeys.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
+ val rightKeys: Seq[BaseExprMeta[_]] =
+ join.rightKeys.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
+ val condition: Option[BaseExprMeta[_]] =
+ join.condition.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
+
+ override val childExprs: Seq[BaseExprMeta[_]] = leftKeys ++ rightKeys ++ condition
+
+ override def tagPlanForGpu(): Unit = {
+ GpuHashJoin.tagJoin(this, join.joinType, join.leftKeys, join.rightKeys, join.condition)
+
+ val buildSide = join.buildSide match {
+ case BuildLeft => childPlans(0)
+ case BuildRight => childPlans(1)
+ }
+
+ if (!buildSide.canThisBeReplaced) {
+ willNotWorkOnGpu("the broadcast for this join must be on the GPU too")
+ }
+
+ if (!canThisBeReplaced) {
+ buildSide.willNotWorkOnGpu("the BroadcastHashJoin this feeds is not on the GPU")
+ }
+ }
+
+ override def convertToGpu(): GpuExec = {
+ val left = childPlans(0).convertIfNeeded()
+ val right = childPlans(1).convertIfNeeded()
+ // The broadcast part of this must be a BroadcastExchangeExec
+ val buildSide = join.buildSide match {
+ case BuildLeft => left
+ case BuildRight => right
+ }
+ if (!buildSide.isInstanceOf[GpuBroadcastExchangeExec]) {
+ throw new IllegalStateException("the broadcast must be on the GPU too")
+ }
+ GpuBroadcastHashJoinExec(
+ leftKeys.map(_.convertToGpu()),
+ rightKeys.map(_.convertToGpu()),
+ join.joinType, join.buildSide,
+ condition.map(_.convertToGpu()),
+ left, right)
+ }
+}
+
+case class GpuBroadcastHashJoinExec(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ joinType: JoinType,
+ buildSide: BuildSide,
+ condition: Option[Expression],
+ left: SparkPlan,
+ right: SparkPlan) extends BinaryExecNode with GpuHashJoin {
+
+ override lazy val additionalMetrics: Map[String, SQLMetric] = Map(
+ "joinOutputRows" -> SQLMetrics.createMetric(sparkContext, "join output rows"),
+ "joinTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "join time"),
+ "filterTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "filter time"))
+
+ override def requiredChildDistribution: Seq[Distribution] = {
+ val mode = HashedRelationBroadcastMode(buildKeys)
+ buildSide match {
+ case BuildLeft =>
+ BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil
+ case BuildRight =>
+ UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil
+ }
+ }
+
+ def broadcastExchange: GpuBroadcastExchangeExec = buildPlan match {
+ case gpu: GpuBroadcastExchangeExec => gpu
+ case reused: ReusedExchangeExec => reused.child.asInstanceOf[GpuBroadcastExchangeExec]
+ }
+
+ override def doExecute(): RDD[InternalRow] =
+ throw new IllegalStateException("GpuBroadcastHashJoin does not support row-based processing")
+
+ override def doExecuteColumnar() : RDD[ColumnarBatch] = {
+ val numOutputRows = longMetric(NUM_OUTPUT_ROWS)
+ val numOutputBatches = longMetric(NUM_OUTPUT_BATCHES)
+ val totalTime = longMetric(TOTAL_TIME)
+ val joinTime = longMetric("joinTime")
+ val filterTime = longMetric("filterTime")
+ val joinOutputRows = longMetric("joinOutputRows")
+
+ val broadcastRelation = broadcastExchange
+ .executeColumnarBroadcast[SerializeConcatHostBuffersDeserializeBatch]()
+
+ val boundCondition = condition.map(GpuBindReferences.bindReference(_, output))
+
+ lazy val builtTable = {
+ // TODO clean up intermediate results...
+ val keys = GpuProjectExec.project(broadcastRelation.value.batch, gpuBuildKeys)
+ val combined = combine(keys, broadcastRelation.value.batch)
+ val ret = GpuColumnVector.from(combined)
+ // Don't warn for a leak, because we cannot control when we are done with this
+ (0 until ret.getNumberOfColumns).foreach(ret.getColumn(_).noWarnLeakExpected())
+ ret
+ }
+
+ val rdd = streamedPlan.executeColumnar()
+ rdd.mapPartitions(it =>
+ doJoin(builtTable, it, boundCondition, numOutputRows, joinOutputRows,
+ numOutputBatches, joinTime, filterTime, totalTime))
+ }
+}
diff --git a/shims/spark31/src/main/scala/com/nvidia/spark/rapids/shims/spark31/GpuBroadcastNestedLoopJoinExec.scala b/shims/spark31/src/main/scala/com/nvidia/spark/rapids/shims/spark31/GpuBroadcastNestedLoopJoinExec.scala
new file mode 100644
index 00000000000..7a46926d8f1
--- /dev/null
+++ b/shims/spark31/src/main/scala/com/nvidia/spark/rapids/shims/spark31/GpuBroadcastNestedLoopJoinExec.scala
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.shims.spark31
+
+import com.nvidia.spark.rapids.{GpuBuildLeft, GpuBuildRight, GpuBuildSide}
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
+import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec
+import org.apache.spark.sql.rapids.execution._
+
+/**
+ * Spark 3.1 changed packages of BuildLeft, BuildRight, BuildSide
+ */
+case class GpuBroadcastNestedLoopJoinExec(
+ left: SparkPlan,
+ right: SparkPlan,
+ join: BroadcastNestedLoopJoinExec,
+ joinType: JoinType,
+ condition: Option[Expression],
+ targetSizeBytes: Long)
+ extends GpuBroadcastNestedLoopJoinExecBase(left, right, join, joinType, condition,
+ targetSizeBytes) {
+
+ def getGpuBuildSide: GpuBuildSide = {
+ GpuJoinUtils.getGpuBuildSide(join.buildSide)
+ }
+}
diff --git a/shims/spark31/src/main/scala/com/nvidia/spark/rapids/shims/spark31/GpuFirst.scala b/shims/spark31/src/main/scala/com/nvidia/spark/rapids/shims/spark31/GpuFirst.scala
new file mode 100644
index 00000000000..f1c11e1673f
--- /dev/null
+++ b/shims/spark31/src/main/scala/com/nvidia/spark/rapids/shims/spark31/GpuFirst.scala
@@ -0,0 +1,40 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.shims.spark31
+
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckSuccess
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.aggregate.FirstLast
+import org.apache.spark.sql.rapids.GpuFirstBase
+
+/**
+ * Parameters to GpuFirst changed in Spark 3.1
+ */
+case class GpuFirst(child: Expression, ignoreNulls: Boolean) extends GpuFirstBase(child) {
+ override def children: Seq[Expression] = child :: Nil
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val defaultCheck = super.checkInputDataTypes()
+ if (defaultCheck.isFailure) {
+ defaultCheck
+ } else {
+ TypeCheckSuccess
+ }
+ }
+}
+
diff --git a/shims/spark31/src/main/scala/com/nvidia/spark/rapids/shims/spark31/GpuHashJoin.scala b/shims/spark31/src/main/scala/com/nvidia/spark/rapids/shims/spark31/GpuHashJoin.scala
new file mode 100644
index 00000000000..59db665f61c
--- /dev/null
+++ b/shims/spark31/src/main/scala/com/nvidia/spark/rapids/shims/spark31/GpuHashJoin.scala
@@ -0,0 +1,245 @@
+/*
+ * Copyright (c) 2019-2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.nvidia.spark.rapids.shims.spark31
+
+import ai.rapids.cudf.{NvtxColor, Table}
+import com.nvidia.spark.rapids._
+
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
+import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, JoinType, LeftAnti, LeftExistence, LeftOuter, LeftSemi, RightOuter}
+import org.apache.spark.sql.execution.joins.HashJoin
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
+
+object GpuHashJoin {
+ def tagJoin(
+ meta: RapidsMeta[_, _, _],
+ joinType: JoinType,
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ condition: Option[Expression]): Unit = joinType match {
+ case _: InnerLike =>
+ case FullOuter | RightOuter | LeftOuter | LeftSemi | LeftAnti =>
+ if (condition.isDefined) {
+ meta.willNotWorkOnGpu(s"$joinType joins currently do not support conditions")
+ }
+ case _ => meta.willNotWorkOnGpu(s"$joinType currently is not supported")
+ }
+}
+
+trait GpuHashJoin extends GpuExec with HashJoin {
+
+ override def output: Seq[Attribute] = {
+ joinType match {
+ case _: InnerLike =>
+ left.output ++ right.output
+ case LeftOuter =>
+ left.output ++ right.output.map(_.withNullability(true))
+ case RightOuter =>
+ left.output.map(_.withNullability(true)) ++ right.output
+ case j: ExistenceJoin =>
+ left.output :+ j.exists
+ case LeftExistence(_) =>
+ left.output
+ case FullOuter =>
+ left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
+ case x =>
+ throw new IllegalArgumentException(s"GpuHashJoin should not take $x as the JoinType")
+ }
+ }
+
+ protected lazy val (gpuBuildKeys, gpuStreamedKeys) = {
+ require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
+ "Join keys from two sides should have same types")
+ val lkeys = GpuBindReferences.bindGpuReferences(leftKeys, left.output)
+ val rkeys = GpuBindReferences.bindGpuReferences(rightKeys, right.output)
+ buildSide match {
+ case BuildLeft => (lkeys, rkeys)
+ case BuildRight => (rkeys, lkeys)
+ }
+ }
+
+ /**
+ * Place the columns in left and the columns in right into a single ColumnarBatch
+ */
+ def combine(left: ColumnarBatch, right: ColumnarBatch): ColumnarBatch = {
+ val l = GpuColumnVector.extractColumns(left)
+ val r = GpuColumnVector.extractColumns(right)
+ val c = l ++ r
+ new ColumnarBatch(c.asInstanceOf[Array[ColumnVector]], left.numRows())
+ }
+
+ // TODO eventually dedupe the keys
+ lazy val joinKeyIndices: Range = gpuBuildKeys.indices
+
+ val localBuildOutput: Seq[Attribute] = buildPlan.output
+ // The first columns are the ones we joined on and need to remove
+ lazy val joinIndices: Seq[Int] = joinType match {
+ case RightOuter =>
+ // The left table and right table are switched in the output
+ // because we don't support a right join, only left
+ val numRight = right.output.length
+ val numLeft = left.output.length
+ val joinLength = joinKeyIndices.length
+ def remap(index: Int): Int = {
+ if (index < numLeft) {
+ // part of the left table, but is on the right side of the tmp output
+ index + joinLength + numRight
+ } else {
+ // part of the right table, but is on the left side of the tmp output
+ index + joinLength - numLeft
+ }
+ }
+ output.indices.map (remap)
+ case _ =>
+ val joinLength = joinKeyIndices.length
+ output.indices.map (v => v + joinLength)
+ }
+
+ def doJoin(builtTable: Table,
+ stream: Iterator[ColumnarBatch],
+ boundCondition: Option[Expression],
+ numOutputRows: SQLMetric,
+ joinOutputRows: SQLMetric,
+ numOutputBatches: SQLMetric,
+ joinTime: SQLMetric,
+ filterTime: SQLMetric,
+ totalTime: SQLMetric): Iterator[ColumnarBatch] = {
+ new Iterator[ColumnarBatch] {
+ import scala.collection.JavaConverters._
+ var nextCb: Option[ColumnarBatch] = None
+ var first: Boolean = true
+
+ TaskContext.get().addTaskCompletionListener[Unit](_ => closeCb())
+
+ def closeCb(): Unit = {
+ nextCb.foreach(_.close())
+ nextCb = None
+ }
+
+ override def hasNext: Boolean = {
+ while (nextCb.isEmpty && (first || stream.hasNext)) {
+ if (stream.hasNext) {
+ val cb = stream.next()
+ val startTime = System.nanoTime()
+ nextCb = doJoin(builtTable, cb, boundCondition, joinOutputRows, numOutputRows,
+ numOutputBatches, joinTime, filterTime)
+ totalTime += (System.nanoTime() - startTime)
+ } else if (first) {
+ // We have to at least try one in some cases
+ val startTime = System.nanoTime()
+ val cb = GpuColumnVector.emptyBatch(streamedPlan.output.asJava)
+ nextCb = doJoin(builtTable, cb, boundCondition, joinOutputRows, numOutputRows,
+ numOutputBatches, joinTime, filterTime)
+ totalTime += (System.nanoTime() - startTime)
+ }
+ first = false
+ }
+ nextCb.isDefined
+ }
+
+ override def next(): ColumnarBatch = {
+ if (!hasNext) {
+ throw new NoSuchElementException()
+ }
+ val ret = nextCb.get
+ nextCb = None
+ ret
+ }
+ }
+ }
+
+ private[this] def doJoin(builtTable: Table,
+ streamedBatch: ColumnarBatch,
+ boundCondition: Option[Expression],
+ numOutputRows: SQLMetric,
+ numJoinOutputRows: SQLMetric,
+ numOutputBatches: SQLMetric,
+ joinTime: SQLMetric,
+ filterTime: SQLMetric): Option[ColumnarBatch] = {
+
+ val streamedTable = try {
+ val streamedKeysBatch = GpuProjectExec.project(streamedBatch, gpuStreamedKeys)
+ try {
+ val combined = combine(streamedKeysBatch, streamedBatch)
+ GpuColumnVector.from(combined)
+ } finally {
+ streamedKeysBatch.close()
+ }
+ } finally {
+ streamedBatch.close()
+ }
+
+ val nvtxRange = new NvtxWithMetrics("hash join", NvtxColor.ORANGE, joinTime)
+ val joined = try {
+ buildSide match {
+ case BuildLeft => doJoinLeftRight(builtTable, streamedTable)
+ case BuildRight => doJoinLeftRight(streamedTable, builtTable)
+ }
+ } finally {
+ streamedTable.close()
+ nvtxRange.close()
+ }
+
+ numJoinOutputRows += joined.numRows()
+
+ val tmp = if (boundCondition.isDefined) {
+ GpuFilter(joined, boundCondition.get, numOutputRows, numOutputBatches, filterTime)
+ } else {
+ numOutputRows += joined.numRows()
+ numOutputBatches += 1
+ joined
+ }
+ if (tmp.numRows() == 0) {
+ // Not sure if there is a better way to work around this
+ numOutputBatches.set(numOutputBatches.value - 1)
+ tmp.close()
+ None
+ } else {
+ Some(tmp)
+ }
+ }
+
+ private[this] def doJoinLeftRight(leftTable: Table, rightTable: Table): ColumnarBatch = {
+ val joinedTable = joinType match {
+ case LeftOuter => leftTable.onColumns(joinKeyIndices: _*)
+ .leftJoin(rightTable.onColumns(joinKeyIndices: _*), false)
+ case RightOuter => rightTable.onColumns(joinKeyIndices: _*)
+ .leftJoin(leftTable.onColumns(joinKeyIndices: _*), false)
+ case _: InnerLike => leftTable.onColumns(joinKeyIndices: _*)
+ .innerJoin(rightTable.onColumns(joinKeyIndices: _*), false)
+ case LeftSemi => leftTable.onColumns(joinKeyIndices: _*)
+ .leftSemiJoin(rightTable.onColumns(joinKeyIndices: _*), false)
+ case LeftAnti => leftTable.onColumns(joinKeyIndices: _*)
+ .leftAntiJoin(rightTable.onColumns(joinKeyIndices: _*), false)
+ case FullOuter => leftTable.onColumns(joinKeyIndices: _*)
+ .fullJoin(rightTable.onColumns(joinKeyIndices: _*), false)
+      case _ => throw new NotImplementedError(s"Join type ${joinType.getClass} is not currently" +
+          s" supported")
+ }
+ try {
+ val result = joinIndices.map(joinIndex =>
+ GpuColumnVector.from(joinedTable.getColumn(joinIndex).incRefCount()))
+ .toArray[ColumnVector]
+
+ new ColumnarBatch(result, joinedTable.getRowCount.toInt)
+ } finally {
+ joinedTable.close()
+ }
+ }
+}
diff --git a/shims/spark31/src/main/scala/com/nvidia/spark/rapids/shims/spark31/GpuLast.scala b/shims/spark31/src/main/scala/com/nvidia/spark/rapids/shims/spark31/GpuLast.scala
new file mode 100644
index 00000000000..03d281d0693
--- /dev/null
+++ b/shims/spark31/src/main/scala/com/nvidia/spark/rapids/shims/spark31/GpuLast.scala
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.shims.spark31
+
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckSuccess
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.rapids.GpuLastBase
+
+/**
+ * Parameters to GpuLast changed in Spark 3.1
+ */
+case class GpuLast(child: Expression, ignoreNulls: Boolean) extends GpuLastBase(child) {
+ override def children: Seq[Expression] = child :: Nil
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val defaultCheck = super.checkInputDataTypes()
+ if (defaultCheck.isFailure) {
+ defaultCheck
+ } else {
+ TypeCheckSuccess
+ }
+ }
+}
+
diff --git a/shims/spark31/src/main/scala/com/nvidia/spark/rapids/shims/spark31/GpuShuffledHashJoinExec.scala b/shims/spark31/src/main/scala/com/nvidia/spark/rapids/shims/spark31/GpuShuffledHashJoinExec.scala
new file mode 100644
index 00000000000..bff955ebfdf
--- /dev/null
+++ b/shims/spark31/src/main/scala/com/nvidia/spark/rapids/shims/spark31/GpuShuffledHashJoinExec.scala
@@ -0,0 +1,151 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.shims.spark31
+
+import ai.rapids.cudf.{NvtxColor, Table}
+import com.nvidia.spark.rapids._
+import com.nvidia.spark.rapids.GpuMetricNames._
+
+import org.apache.spark.TaskContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
+import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans.physical.{Distribution, HashClusteredDistribution}
+import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
+import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+object GpuJoinUtils {
+ def getGpuBuildSide(buildSide: BuildSide): GpuBuildSide = {
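+ // Map the catalyst BuildSide (which moved packages in Spark 3.1) to the plugin's
+ // version-agnostic GpuBuildSide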
+ buildSide match {
+ case BuildRight => GpuBuildRight
+ case BuildLeft => GpuBuildLeft
+ case _ => throw new IllegalArgumentException(s"unknown build side: $buildSide")
+ }
+ }
+}
+
+/**
+ * BuildLeft, BuildRight, and BuildSide moved packages in Spark 3.1, so this exec needs a shim
+ */
+class GpuShuffledHashJoinMeta(
+ join: ShuffledHashJoinExec,
+ conf: RapidsConf,
+ parent: Option[RapidsMeta[_, _, _]],
+ rule: ConfKeysAndIncompat)
+ extends SparkPlanMeta[ShuffledHashJoinExec](join, conf, parent, rule) {
+ val leftKeys: Seq[BaseExprMeta[_]] =
+ join.leftKeys.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
+ val rightKeys: Seq[BaseExprMeta[_]] =
+ join.rightKeys.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
+ val condition: Option[BaseExprMeta[_]] =
+ join.condition.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
+
+ override val childExprs: Seq[BaseExprMeta[_]] = leftKeys ++ rightKeys ++ condition
+
+ override def tagPlanForGpu(): Unit = {
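+ // Reuse the common hash join checks to decide whether this join can run on the GPU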
+ GpuHashJoin.tagJoin(this, join.joinType, join.leftKeys, join.rightKeys, join.condition)
+ }
+
+ override def convertToGpu(): GpuExec =
+ GpuShuffledHashJoinExec(
+ leftKeys.map(_.convertToGpu()),
+ rightKeys.map(_.convertToGpu()),
+ join.joinType,
+ join.buildSide,
+ condition.map(_.convertToGpu()),
+ childPlans(0).convertIfNeeded(),
+ childPlans(1).convertIfNeeded())
+}
+
+case class GpuShuffledHashJoinExec(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ joinType: JoinType,
+ buildSide: BuildSide,
+ condition: Option[Expression],
+ left: SparkPlan,
+ right: SparkPlan) extends BinaryExecNode with GpuHashJoin {
+
+ override lazy val additionalMetrics: Map[String, SQLMetric] = Map(
+ "buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "build side size"),
+ "buildTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "build time"),
+ "joinTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "join time"),
+ "joinOutputRows" -> SQLMetrics.createMetric(sparkContext, "join output rows"),
+ "filterTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "filter time"))
+
+ override def requiredChildDistribution: Seq[Distribution] =
+ HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ throw new UnsupportedOperationException(
+ "GpuShuffledHashJoin does not support the execute() code path.")
+ }
+
+ override def childrenCoalesceGoal: Seq[CoalesceGoal] = buildSide match {
+ case BuildLeft => Seq(RequireSingleBatch, null)
+ case BuildRight => Seq(null, RequireSingleBatch)
+ }
+
+ override def doExecuteColumnar() : RDD[ColumnarBatch] = {
+ val buildDataSize = longMetric("buildDataSize")
+ val numOutputRows = longMetric(NUM_OUTPUT_ROWS)
+ val numOutputBatches = longMetric(NUM_OUTPUT_BATCHES)
+ val totalTime = longMetric(TOTAL_TIME)
+ val buildTime = longMetric("buildTime")
+ val joinTime = longMetric("joinTime")
+ val filterTime = longMetric("filterTime")
+ val joinOutputRows = longMetric("joinOutputRows")
+
+ val boundCondition = condition.map(GpuBindReferences.bindReference(_, output))
+
+ streamedPlan.executeColumnar().zipPartitions(buildPlan.executeColumnar()) {
+ (streamIter, buildIter) => {
+ var combinedSize = 0
+ val startTime = System.nanoTime()
+ val buildBatch =
+ ConcatAndConsumeAll.getSingleBatchWithVerification(buildIter, localBuildOutput)
+ val keys = GpuProjectExec.project(buildBatch, gpuBuildKeys)
+ val builtTable = try {
+ // Combine does not inc any reference counting
+ val combined = combine(keys, buildBatch)
+ combinedSize =
+ GpuColumnVector.extractColumns(combined)
+ .map(_.getBase.getDeviceMemorySize).sum.toInt
+ GpuColumnVector.from(combined)
+ } finally {
+ keys.close()
+ buildBatch.close()
+ }
+
+ val delta = System.nanoTime() - startTime
+ buildTime += delta
+ totalTime += delta
+ buildDataSize += combinedSize
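+ // The built table is reused for every stream-side batch, so keep it alive until the task completes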
+ val context = TaskContext.get()
+ context.addTaskCompletionListener[Unit](_ => builtTable.close())
+
+ doJoin(builtTable, streamIter, boundCondition,
+ numOutputRows, joinOutputRows, numOutputBatches,
+ joinTime, filterTime, totalTime)
+ }
+ }
+ }
+}
diff --git a/shims/spark31/src/main/scala/com/nvidia/spark/rapids/shims/spark31/GpuSortMergeJoinExec.scala b/shims/spark31/src/main/scala/com/nvidia/spark/rapids/shims/spark31/GpuSortMergeJoinExec.scala
new file mode 100644
index 00000000000..237b08e36d2
--- /dev/null
+++ b/shims/spark31/src/main/scala/com/nvidia/spark/rapids/shims/spark31/GpuSortMergeJoinExec.scala
@@ -0,0 +1,110 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.shims.spark31
+
+import com.nvidia.spark.rapids._
+
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
+import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}
+import org.apache.spark.sql.execution.SortExec
+import org.apache.spark.sql.execution.joins.SortMergeJoinExec
+
+/**
+ * HashJoin changed in Spark 3.1, requiring a version-specific shim
+ */
+class GpuSortMergeJoinMeta(
+ join: SortMergeJoinExec,
+ conf: RapidsConf,
+ parent: Option[RapidsMeta[_, _, _]],
+ rule: ConfKeysAndIncompat)
+ extends SparkPlanMeta[SortMergeJoinExec](join, conf, parent, rule) {
+
+ val leftKeys: Seq[BaseExprMeta[_]] =
+ join.leftKeys.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
+ val rightKeys: Seq[BaseExprMeta[_]] =
+ join.rightKeys.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
+ val condition: Option[BaseExprMeta[_]] = join.condition.map(
+ GpuOverrides.wrapExpr(_, conf, Some(this)))
+
+ override val childExprs: Seq[BaseExprMeta[_]] = leftKeys ++ rightKeys ++ condition
+
+ override def tagPlanForGpu(): Unit = {
+ // Use conditions from Hash Join
+ GpuHashJoin.tagJoin(this, join.joinType, join.leftKeys, join.rightKeys, join.condition)
+
+ if (!conf.enableReplaceSortMergeJoin) {
+ willNotWorkOnGpu(s"Not replacing sort merge join with hash join, " +
+ s"see ${RapidsConf.ENABLE_REPLACE_SORTMERGEJOIN.key}")
+ }
+
+ // Make sure this is the last check: if this is a SortMergeJoin, the children can be SortExecs
+ // and we want to validate that they can run on the GPU and remove them before replacing this
+ // with a ShuffledHashJoin
+ if (canThisBeReplaced) {
+ childPlans.foreach { plan =>
+ if (plan.wrapped.isInstanceOf[SortExec]) {
+ if (!plan.canThisBeReplaced) {
+ willNotWorkOnGpu(s"can't replace sortMergeJoin because one of the SortExec's before " +
+ s"can't be replaced.")
+ } else {
+ plan.shouldBeRemoved("removing SortExec as part replacing sortMergeJoin with " +
+ s"shuffleHashJoin")
+ }
+ }
+ }
+ }
+ }
+
+ override def convertToGpu(): GpuExec = {
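+ // Pick a build side the replacement shuffled hash join supports, preferring the right side
+ // when both qualify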
+ val buildSide = if (canBuildRight(join.joinType)) {
+ BuildRight
+ } else if (canBuildLeft(join.joinType)) {
+ BuildLeft
+ } else {
+ throw new IllegalStateException(s"Cannot build either side for ${join.joinType} join")
+ }
+ GpuShuffledHashJoinExec(
+ leftKeys.map(_.convertToGpu()),
+ rightKeys.map(_.convertToGpu()),
+ join.joinType,
+ buildSide,
+ condition.map(_.convertToGpu()),
+ childPlans(0).convertIfNeeded(),
+ childPlans(1).convertIfNeeded())
+ }
+
+ /**
+ * Determine if this type of join supports using the right side of the join as the build side.
+ *
+ * These rules match those in Spark's ShuffleHashJoinExec.
+ */
+ private def canBuildRight(joinType: JoinType): Boolean = joinType match {
+ case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true
+ case _ => false
+ }
+
+ /**
+ * Determine if this type of join supports using the left side of the join as the build side.
+ *
+ * These rules match those in Spark's ShuffleHashJoinExec, with the addition of support for
+ * full outer joins.
+ */
+ private def canBuildLeft(joinType: JoinType): Boolean = joinType match {
+ case _: InnerLike | RightOuter | FullOuter => true
+ case _ => false
+ }
+}
diff --git a/shims/spark31/src/main/scala/com/nvidia/spark/rapids/shims/spark31/Spark31ShimServiceProvider.scala b/shims/spark31/src/main/scala/com/nvidia/spark/rapids/shims/spark31/Spark31ShimServiceProvider.scala
new file mode 100644
index 00000000000..2920e9e6dac
--- /dev/null
+++ b/shims/spark31/src/main/scala/com/nvidia/spark/rapids/shims/spark31/Spark31ShimServiceProvider.scala
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.shims.spark31
+
+import com.nvidia.spark.rapids.{SparkShims, SparkShimServiceProvider}
+
+class Spark31ShimServiceProvider extends SparkShimServiceProvider {
+
+ val SPARK31VERSIONNAMES = Seq("3.1.0-SNAPSHOT", "3.1.0")
+
+ def matchesVersion(version: String): Boolean = {
+ SPARK31VERSIONNAMES.contains(version)
+ }
+
+ def buildShim: SparkShims = {
+ new Spark31Shims()
+ }
+}
diff --git a/shims/spark31/src/main/scala/com/nvidia/spark/rapids/shims/spark31/Spark31Shims.scala b/shims/spark31/src/main/scala/com/nvidia/spark/rapids/shims/spark31/Spark31Shims.scala
new file mode 100644
index 00000000000..b484379bb23
--- /dev/null
+++ b/shims/spark31/src/main/scala/com/nvidia/spark/rapids/shims/spark31/Spark31Shims.scala
@@ -0,0 +1,183 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.shims.spark31
+
+import java.time.ZoneId
+
+import com.nvidia.spark.rapids._
+import com.nvidia.spark.rapids.spark31.RapidsShuffleManager
+import org.apache.spark.SparkEnv
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.datasources.HadoopFsRelation
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, HashJoin, SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
+import org.apache.spark.sql.rapids.GpuTimeSub
+import org.apache.spark.sql.rapids.execution.GpuBroadcastNestedLoopJoinExecBase
+import org.apache.spark.sql.rapids.shims.spark31._
+import org.apache.spark.unsafe.types.CalendarInterval
+import org.apache.spark.sql.types._
+import org.apache.spark.storage.{BlockId, BlockManagerId}
+
+class Spark31Shims extends SparkShims {
+
+ override def getScalaUDFAsExpression(
+ function: AnyRef,
+ dataType: DataType,
+ children: Seq[Expression],
+ inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil,
+ outputEncoder: Option[ExpressionEncoder[_]] = None,
+ udfName: Option[String] = None,
+ nullable: Boolean = true,
+ udfDeterministic: Boolean = true): Expression = {
+ ScalaUDF(function, dataType, children, inputEncoders, outputEncoder, udfName,
+ nullable, udfDeterministic)
+ }
+
+ override def getMapSizesByExecutorId(
+ shuffleId: Int,
+ startMapIndex: Int,
+ endMapIndex: Int,
+ startPartition: Int,
+ endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+ SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId,
+ startMapIndex, endMapIndex, startPartition, endPartition)
+ }
+
+ override def getGpuBroadcastNestedLoopJoinShim(
+ left: SparkPlan,
+ right: SparkPlan,
+ join: BroadcastNestedLoopJoinExec,
+ joinType: JoinType,
+ condition: Option[Expression],
+ targetSizeBytes: Long): GpuBroadcastNestedLoopJoinExecBase = {
+ GpuBroadcastNestedLoopJoinExec(left, right, join, joinType, condition, targetSizeBytes)
+ }
+
+ override def isGpuHashJoin(plan: SparkPlan): Boolean = {
+ plan match {
+ case _: GpuHashJoin => true
+ case _ => false
+ }
+ }
+
+ override def isGpuBroadcastHashJoin(plan: SparkPlan): Boolean = {
+ plan match {
+ case _: GpuBroadcastHashJoinExec => true
+ case _ => false
+ }
+ }
+
+ override def isGpuShuffledHashJoin(plan: SparkPlan): Boolean = {
+ plan match {
+ case _: GpuShuffledHashJoinExec => true
+ case _ => false
+ }
+ }
+
+ override def getExprs: Seq[ExprRule[_ <: Expression]] = {
+ Seq(
+ GpuOverrides.expr[TimeAdd](
+ "Subtracts interval from timestamp",
+ (a, conf, p, r) => new BinaryExprMeta[TimeAdd](a, conf, p, r) {
+ override def tagExprForGpu(): Unit = {
+ a.interval match {
+ case Literal(intvl: CalendarInterval, DataTypes.CalendarIntervalType) =>
+ if (intvl.months != 0) {
+ willNotWorkOnGpu("interval months isn't supported")
+ }
+ case _ =>
+ willNotWorkOnGpu("only literals are supported for intervals")
+ }
+ if (ZoneId.of(a.timeZoneId.get).normalized() != GpuOverrides.UTC_TIMEZONE_ID) {
+ willNotWorkOnGpu("Only UTC zone id is supported")
+ }
+ }
+
+ override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
+ GpuTimeSub(lhs, rhs)
+ }
+ ),
+ GpuOverrides.expr[First](
+ "first aggregate operator",
+ (a, conf, p, r) => new ExprMeta[First](a, conf, p, r) {
+ override def convertToGpu(): GpuExpression =
+ GpuFirst(childExprs(0).convertToGpu(), a.ignoreNulls)
+ }),
+ GpuOverrides.expr[Last](
+ "last aggregate operator",
+ (a, conf, p, r) => new ExprMeta[Last](a, conf, p, r) {
+ override def convertToGpu(): GpuExpression =
+ GpuLast(childExprs(0).convertToGpu(), a.ignoreNulls)
+ })
+ )
+ }
+
+ override def getExecs: Seq[ExecRule[_ <: SparkPlan]] = {
+ Seq(
+ GpuOverrides.exec[FileSourceScanExec](
+ "Reading data from files, often from Hive tables",
+ (fsse, conf, p, r) => new SparkPlanMeta[FileSourceScanExec](fsse, conf, p, r) {
+ // partition filters and data filters are not run on the GPU
+ override val childExprs: Seq[ExprMeta[_]] = Seq.empty
+
+ override def tagPlanForGpu(): Unit = GpuFileSourceScanExec.tagSupport(this)
+
+ override def convertToGpu(): GpuExec = {
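+ // Rebuild the relation with a GPU-aware file format before constructing the GPU scan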
+ val newRelation = HadoopFsRelation(
+ wrapped.relation.location,
+ wrapped.relation.partitionSchema,
+ wrapped.relation.dataSchema,
+ wrapped.relation.bucketSpec,
+ GpuFileSourceScanExec.convertFileFormat(wrapped.relation.fileFormat),
+ wrapped.relation.options)(wrapped.relation.sparkSession)
+ GpuFileSourceScanExec(
+ newRelation,
+ wrapped.output,
+ wrapped.requiredSchema,
+ wrapped.partitionFilters,
+ wrapped.optionalBucketSet,
+ wrapped.dataFilters,
+ wrapped.tableIdentifier)
+ }
+ }),
+ GpuOverrides.exec[SortMergeJoinExec](
+ "Sort merge join, replacing with shuffled hash join",
+ (join, conf, p, r) => new GpuSortMergeJoinMeta(join, conf, p, r)),
+ GpuOverrides.exec[BroadcastHashJoinExec](
+ "Implementation of join using broadcast data",
+ (join, conf, p, r) => new GpuBroadcastHashJoinMeta(join, conf, p, r)),
+ GpuOverrides.exec[ShuffledHashJoinExec](
+ "Implementation of join using hashed shuffled data",
+ (join, conf, p, r) => new GpuShuffledHashJoinMeta(join, conf, p, r))
+ )
+ }
+
+ override def getBuildSide(join: HashJoin): GpuBuildSide = {
+ GpuJoinUtils.getGpuBuildSide(join.buildSide)
+ }
+
+ override def getBuildSide(join: BroadcastNestedLoopJoinExec): GpuBuildSide = {
+ GpuJoinUtils.getGpuBuildSide(join.buildSide)
+ }
+
+ override def getRapidsShuffleManagerClass: String = {
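+ // The fully-qualified, version-specific class name users configure as spark.shuffle.manager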
+ classOf[RapidsShuffleManager].getCanonicalName
+ }
+}
diff --git a/shims/spark31/src/main/scala/com/nvidia/spark/rapids/spark31/RapidsShuffleManager.scala b/shims/spark31/src/main/scala/com/nvidia/spark/rapids/spark31/RapidsShuffleManager.scala
new file mode 100644
index 00000000000..cf1941f4e07
--- /dev/null
+++ b/shims/spark31/src/main/scala/com/nvidia/spark/rapids/spark31/RapidsShuffleManager.scala
@@ -0,0 +1,26 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.spark31
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.rapids.shims.spark31.RapidsShuffleInternalManager
+
+/** A shuffle manager optimized for the RAPIDS Plugin for Apache Spark. */
+sealed class RapidsShuffleManager(
+ conf: SparkConf,
+ isDriver: Boolean) extends RapidsShuffleInternalManager(conf, isDriver) {
+}
diff --git a/shims/spark31/src/main/scala/org/apache/spark/sql/rapids/shims/spark31/GpuFileSourceScanExec.scala b/shims/spark31/src/main/scala/org/apache/spark/sql/rapids/shims/spark31/GpuFileSourceScanExec.scala
new file mode 100644
index 00000000000..010dfd7f551
--- /dev/null
+++ b/shims/spark31/src/main/scala/org/apache/spark/sql/rapids/shims/spark31/GpuFileSourceScanExec.scala
@@ -0,0 +1,179 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids.shims.spark31
+
+import java.util.concurrent.TimeUnit.NANOSECONDS
+
+import com.nvidia.spark.rapids.{GpuExec, GpuReadCSVFileFormat, GpuReadOrcFileFormat, GpuReadParquetFileFormat, SparkPlanMeta}
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SortOrder}
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.execution.{DataSourceScanExec, ExplainUtils, FileSourceScanExec}
+import org.apache.spark.sql.execution.datasources.{FileFormat, HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
+import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
+import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.rapids.GpuFileSourceScanExecBase
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.collection.BitSet
+
+case class GpuFileSourceScanExec(
+ @transient relation: HadoopFsRelation,
+ output: Seq[Attribute],
+ requiredSchema: StructType,
+ partitionFilters: Seq[Expression],
+ optionalBucketSet: Option[BitSet],
+ dataFilters: Seq[Expression],
+ override val tableIdentifier: Option[TableIdentifier])
+ extends DataSourceScanExec with GpuFileSourceScanExecBase with GpuExec {
+
+ override val nodeName: String = {
+ s"GpuScan $relation ${tableIdentifier.map(_.unquotedString).getOrElse("")}"
+ }
+
+ private[this] val wrapped: FileSourceScanExec = {
+ val tclass = classOf[org.apache.spark.sql.execution.FileSourceScanExec]
+ val constructors = tclass.getConstructors()
+ if (constructors.size > 1) {
+ throw new IllegalStateException(
+ s"Expected 1 constructor for FileSourceScanExec, found ${constructors.size}")
+ }
+ val constructor = constructors(0)
+ val instance = if (constructor.getParameterCount() == 8) {
+ // Some distributions of Spark modified FileSourceScanExec to take an additional parameter
+ // that is the logicalRelation. We don't know exactly what it's used for, but we haven't
+ // run into any issues in testing with the one we create here.
+ @transient val logicalRelation = LogicalRelation(relation)
+ try {
+ constructor.newInstance(relation, output, requiredSchema, partitionFilters,
+ optionalBucketSet, dataFilters, tableIdentifier,
+ logicalRelation).asInstanceOf[FileSourceScanExec]
+ } catch {
+ case il: IllegalArgumentException =>
+ // TODO - workaround until https://github.com/NVIDIA/spark-rapids/issues/354
+ constructor.newInstance(relation, output, requiredSchema, partitionFilters,
+ optionalBucketSet, None, dataFilters, tableIdentifier).asInstanceOf[FileSourceScanExec]
+ }
+ } else {
+ constructor.newInstance(relation, output, requiredSchema, partitionFilters,
+ optionalBucketSet, dataFilters, tableIdentifier).asInstanceOf[FileSourceScanExec]
+ }
+ instance
+ }
+
+ override lazy val outputPartitioning: Partitioning = wrapped.outputPartitioning
+
+ override lazy val outputOrdering: Seq[SortOrder] = wrapped.outputOrdering
+
+ override lazy val metadata: Map[String, String] = wrapped.metadata
+
+ override lazy val metrics: Map[String, SQLMetric] = wrapped.metrics
+
+ override def verboseStringWithOperatorId(): String = {
+ val metadataStr = metadata.toSeq.sorted.filterNot {
+ case (_, value) if (value.isEmpty || value.equals("[]")) => true
+ case (key, _) if (key.equals("DataFilters") || key.equals("Format")) => true
+ case (_, _) => false
+ }.map {
+ case (key, _) if (key.equals("Location")) =>
+ val location = wrapped.relation.location
+ val numPaths = location.rootPaths.length
+ val abbreviatedLocation = if (numPaths <= 1) {
+ location.rootPaths.mkString("[", ", ", "]")
+ } else {
+ "[" + location.rootPaths.head + s", ... ${numPaths - 1} entries]"
+ }
+ s"$key: ${location.getClass.getSimpleName} ${redact(abbreviatedLocation)}"
+ case (key, value) => s"$key: ${redact(value)}"
+ }
+
+ s"""
+ |$formattedNodeName
+ |${ExplainUtils.generateFieldString("Output", output)}
+ |${metadataStr.mkString("\n")}
+ |""".stripMargin
+ }
+
+ override def inputRDDs(): Seq[RDD[InternalRow]] = {
+ wrapped.inputRDD :: Nil
+ }
+
+ override protected def doExecute(): RDD[InternalRow] =
+ throw new IllegalStateException(s"Row-based execution should not occur for $this")
+
+ override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
+ val numOutputRows = longMetric("numOutputRows")
+ val scanTime = longMetric("scanTime")
+ wrapped.inputRDD.asInstanceOf[RDD[ColumnarBatch]].mapPartitionsInternal { batches =>
+ new Iterator[ColumnarBatch] {
+
+ override def hasNext: Boolean = {
+ // The `FileScanRDD` returns an iterator which scans the file during the `hasNext` call.
+ val startNs = System.nanoTime()
+ val res = batches.hasNext
+ scanTime += NANOSECONDS.toMillis(System.nanoTime() - startNs)
+ res
+ }
+
+ override def next(): ColumnarBatch = {
+ val batch = batches.next()
+ numOutputRows += batch.numRows()
+ batch
+ }
+ }
+ }
+ }
+
+ override val nodeNamePrefix: String = "Gpu" + wrapped.nodeNamePrefix
+
+ override def doCanonicalize(): GpuFileSourceScanExec = {
+ val canonical = wrapped.doCanonicalize()
+ GpuFileSourceScanExec(
+ canonical.relation,
+ canonical.output,
+ canonical.requiredSchema,
+ canonical.partitionFilters,
+ canonical.optionalBucketSet,
+ canonical.dataFilters,
+ canonical.tableIdentifier)
+ }
+}
+
+object GpuFileSourceScanExec {
+ def tagSupport(meta: SparkPlanMeta[FileSourceScanExec]): Unit = {
+ meta.wrapped.relation.fileFormat match {
+ case _: CSVFileFormat => GpuReadCSVFileFormat.tagSupport(meta)
+ case _: OrcFileFormat => GpuReadOrcFileFormat.tagSupport(meta)
+ case _: ParquetFileFormat => GpuReadParquetFileFormat.tagSupport(meta)
+ case f =>
+ meta.willNotWorkOnGpu(s"unsupported file format: ${f.getClass.getCanonicalName}")
+ }
+ }
+
+ def convertFileFormat(format: FileFormat): FileFormat = {
+ format match {
+ case _: CSVFileFormat => new GpuReadCSVFileFormat
+ case _: OrcFileFormat => new GpuReadOrcFileFormat
+ case _: ParquetFileFormat => new GpuReadParquetFileFormat
+ case f =>
+ throw new IllegalArgumentException(s"${f.getClass.getCanonicalName} is not supported")
+ }
+ }
+}
diff --git a/shims/spark31/src/main/scala/org/apache/spark/sql/rapids/shims/spark31/RapidsShuffleInternalManager.scala b/shims/spark31/src/main/scala/org/apache/spark/sql/rapids/shims/spark31/RapidsShuffleInternalManager.scala
new file mode 100644
index 00000000000..179be3efc56
--- /dev/null
+++ b/shims/spark31/src/main/scala/org/apache/spark/sql/rapids/shims/spark31/RapidsShuffleInternalManager.scala
@@ -0,0 +1,44 @@
+/*
+ * Copyright (c) 2019-2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids.shims.spark31
+
+import org.apache.spark.{SparkConf, TaskContext}
+import org.apache.spark.shuffle._
+import org.apache.spark.sql.rapids.RapidsShuffleInternalManagerBase
+
+/**
+ * A shuffle manager optimized for the RAPIDS Plugin For Apache Spark.
+ * @note This is an internal class to obtain access to the private
+ * `ShuffleManager` and `SortShuffleManager` classes. When configuring
+ * Apache Spark to use the RAPIDS shuffle manager,
+ * [[com.nvidia.spark.rapids.spark31.RapidsShuffleManager]] should be used as that is
+ * the public class.
+ */
+class RapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean)
+ extends RapidsShuffleInternalManagerBase(conf, isDriver) {
+
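+ // Spark 3.1 folded getReaderForRange into getReader by adding the map index range parameters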
+ def getReader[K, C](
+ handle: ShuffleHandle,
+ startMapIndex: Int,
+ endMapIndex: Int,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext,
+ metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+ getReaderInternal(handle, startMapIndex, endMapIndex, startPartition, endPartition,
+ context, metrics)
+ }
+}
diff --git a/sql-plugin/pom.xml b/sql-plugin/pom.xml
index f7465d8efc8..d80721f110a 100644
--- a/sql-plugin/pom.xml
+++ b/sql-plugin/pom.xml
@@ -144,26 +144,6 @@
net.alchim31.maven
scala-maven-plugin
-
-
- update_config
- verify
-
- run
-
-
-
-
-
-
- update_rapids_config
- com.nvidia.spark.rapids.RapidsConf
-
- ${project.basedir}/../docs/configs.md
-
-
-
-
org.scalastyle
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index f652fbd79b0..a2fe90ad212 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -46,7 +46,7 @@ import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNes
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.rapids._
import org.apache.spark.sql.rapids.catalyst.expressions.GpuRand
-import org.apache.spark.sql.rapids.execution.{GpuBroadcastHashJoinMeta, GpuBroadcastMeta, GpuBroadcastNestedLoopJoinMeta}
+import org.apache.spark.sql.rapids.execution.{GpuBroadcastMeta, GpuBroadcastNestedLoopJoinMeta}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
@@ -458,7 +458,7 @@ object GpuOverrides {
.map(r => r.wrap(expr, conf, parent, r).asInstanceOf[BaseExprMeta[INPUT]])
.getOrElse(new RuleNotFoundExprMeta(expr, conf, parent))
- val expressions: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq(
+ val expressions: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = (Seq(
expr[Literal](
"holds a static value from the query",
(lit, conf, p, r) => new ExprMeta[Literal](lit, conf, p, r) {
@@ -715,27 +715,6 @@ object GpuOverrides {
GpuDateSub(lhs, rhs)
}
),
- expr[TimeSub](
- "Subtracts interval from timestamp",
- (a, conf, p, r) => new BinaryExprMeta[TimeSub](a, conf, p, r) {
- override def tagExprForGpu(): Unit = {
- a.interval match {
- case Literal(intvl: CalendarInterval, DataTypes.CalendarIntervalType) =>
- if (intvl.months != 0) {
- willNotWorkOnGpu("interval months isn't supported")
- }
- case _ =>
- willNotWorkOnGpu("only literals are supported for intervals")
- }
- if (ZoneId.of(a.timeZoneId.get).normalized() != UTC_TIMEZONE_ID) {
- willNotWorkOnGpu("Only UTC zone id is supported")
- }
- }
-
- override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
- GpuTimeSub(lhs, rhs)
- }
- ),
expr[NaNvl](
"evaluates to `left` iff left is not NaN, `right` otherwise.",
(a, conf, p, r) => new BinaryExprMeta[NaNvl](a, conf, p, r) {
@@ -1224,28 +1203,6 @@ object GpuOverrides {
}
override def convertToGpu(child: Expression): GpuExpression = GpuMin(child)
}),
- expr[First](
- "first aggregate operator",
- (a, conf, p, r) => new ExprMeta[First](a, conf, p, r) {
- val child: BaseExprMeta[_] = GpuOverrides.wrapExpr(a.child, conf, Some(this))
- val ignoreNulls: BaseExprMeta[_] =
- GpuOverrides.wrapExpr(a.ignoreNullsExpr, conf, Some(this))
- override val childExprs: Seq[BaseExprMeta[_]] = Seq(child, ignoreNulls)
-
- override def convertToGpu(): GpuExpression =
- GpuFirst(child.convertToGpu(), ignoreNulls.convertToGpu())
- }),
- expr[Last](
- "last aggregate operator",
- (a, conf, p, r) => new ExprMeta[Last](a, conf, p, r) {
- val child: BaseExprMeta[_] = GpuOverrides.wrapExpr(a.child, conf, Some(this))
- val ignoreNulls: BaseExprMeta[_] =
- GpuOverrides.wrapExpr(a.ignoreNullsExpr, conf, Some(this))
- override val childExprs: Seq[BaseExprMeta[_]] = Seq(child, ignoreNulls)
-
- override def convertToGpu(): GpuExpression =
- GpuLast(child.convertToGpu(), ignoreNulls.convertToGpu())
- }),
expr[Sum](
"sum aggregate operator",
(a, conf, p, r) => new AggExprMeta[Sum](a, conf, p, r) {
@@ -1476,7 +1433,8 @@ object GpuOverrides {
(a, conf, p, r) => new UnaryExprMeta[Length](a, conf, p, r) {
override def convertToGpu(child: Expression): GpuExpression = GpuLength(child)
})
- ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
+ ) ++ ShimLoader.getSparkShims.getExprs)
+ .map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
def wrapScan[INPUT <: Scan](
scan: INPUT,
@@ -1619,7 +1577,7 @@ object GpuOverrides {
.map(r => r.wrap(plan, conf, parent, r).asInstanceOf[SparkPlanMeta[INPUT]])
.getOrElse(new RuleNotFoundSparkPlanMeta(plan, conf, parent))
- val execs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = Seq(
+ val execs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = (Seq(
exec[GenerateExec] (
"The backend for operations that generate more output rows than input rows like explode.",
(gen, conf, p, r) => new GpuGenerateExecSparkPlanMeta(gen, conf, p, r)),
@@ -1664,32 +1622,6 @@ object GpuOverrides {
GpuDataWritingCommandExec(childDataWriteCmds.head.convertToGpu(),
childPlans.head.convertIfNeeded())
}),
- exec[FileSourceScanExec](
- "Reading data from files, often from Hive tables",
- (fsse, conf, p, r) => new SparkPlanMeta[FileSourceScanExec](fsse, conf, p, r) {
- // partition filters and data filters are not run on the GPU
- override val childExprs: Seq[ExprMeta[_]] = Seq.empty
-
- override def tagPlanForGpu(): Unit = GpuFileSourceScanExec.tagSupport(this)
-
- override def convertToGpu(): GpuExec = {
- val newRelation = HadoopFsRelation(
- wrapped.relation.location,
- wrapped.relation.partitionSchema,
- wrapped.relation.dataSchema,
- wrapped.relation.bucketSpec,
- GpuFileSourceScanExec.convertFileFormat(wrapped.relation.fileFormat),
- wrapped.relation.options)(wrapped.relation.sparkSession)
- GpuFileSourceScanExec(
- newRelation,
- wrapped.output,
- wrapped.requiredSchema,
- wrapped.partitionFilters,
- wrapped.optionalBucketSet,
- wrapped.dataFilters,
- wrapped.tableIdentifier)
- }
- }),
exec[LocalLimitExec](
"Per-partition limiting of results",
(localLimitExec, conf, p, r) =>
@@ -1725,12 +1657,6 @@ object GpuOverrides {
exec[BroadcastExchangeExec](
"The backend for broadcast exchange of data",
(exchange, conf, p, r) => new GpuBroadcastMeta(exchange, conf, p, r)),
- exec[BroadcastHashJoinExec](
- "Implementation of join using broadcast data",
- (join, conf, p, r) => new GpuBroadcastHashJoinMeta(join, conf, p, r)),
- exec[ShuffledHashJoinExec](
- "Implementation of join using hashed shuffled data",
- (join, conf, p, r) => new GpuShuffledHashJoinMeta(join, conf, p, r)),
exec[BroadcastNestedLoopJoinExec](
"Implementation of join using brute force",
(join, conf, p, r) => new GpuBroadcastNestedLoopJoinMeta(join, conf, p, r))
@@ -1751,9 +1677,6 @@ object GpuOverrides {
conf.gpuTargetBatchSizeBytes)
})
.disabledByDefault("large joins can cause out of memory errors"),
- exec[SortMergeJoinExec](
- "Sort merge join, replacing with shuffled hash join",
- (join, conf, p, r) => new GpuSortMergeJoinMeta(join, conf, p, r)),
exec[HashAggregateExec](
"The backend for hash based aggregations",
(agg, conf, p, r) => new GpuHashAggregateMeta(agg, conf, p, r)),
@@ -1771,7 +1694,8 @@ object GpuOverrides {
(windowOp, conf, p, r) =>
new GpuWindowExecMeta(windowOp, conf, p, r)
)
- ).map(r => (r.getClassFor.asSubclass(classOf[SparkPlan]), r)).toMap
+ ) ++ ShimLoader.getSparkShims.getExecs)
+ .map(r => (r.getClassFor.asSubclass(classOf[SparkPlan]), r)).toMap
}
case class GpuOverrides() extends Rule[SparkPlan] with Logging {
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala
index c29499671e9..a1e3ef405c5 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.command.ExecutedCommandExec
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExecBase
import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
-import org.apache.spark.sql.rapids.GpuFileSourceScanExec
+import org.apache.spark.sql.rapids.GpuFileSourceScanExecBase
/**
* Rules that run after the row to columnar and columnar to row transitions have been inserted.
@@ -174,7 +174,10 @@ class GpuTransitionOverrides extends Rule[SparkPlan] {
// intermediate nodes that have a specified sort order. This helps with the size of
// Parquet and Orc files
plan match {
- case _: GpuHashJoin | _: GpuHashAggregateExec =>
+ case s if ShimLoader.getSparkShims.isGpuHashJoin(s) =>
+ val sortOrder = getOptimizedSortOrder(plan)
+ GpuSortExec(sortOrder, false, plan, TargetSize(conf.gpuTargetBatchSizeBytes))
+ case _: GpuHashAggregateExec =>
val sortOrder = getOptimizedSortOrder(plan)
GpuSortExec(sortOrder, false, plan, TargetSize(conf.gpuTargetBatchSizeBytes))
case p =>
@@ -249,7 +252,7 @@ class GpuTransitionOverrides extends Rule[SparkPlan] {
val planOutput = plan.output.toSet
// avoid checking expressions of GpuFileSourceScanExec since all expressions are
// processed by driver and not run on GPU.
- if (!plan.isInstanceOf[GpuFileSourceScanExec]) {
+ if (!plan.isInstanceOf[GpuFileSourceScanExecBase]) {
plan.expressions.filter(_ match {
case a: Attribute => !planOutput.contains(a)
case _ => true
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
index 834ec51ed19..393307a05e3 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.DataWritingCommand
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight, ShuffledHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.types.{CalendarIntervalType, DataType, DataTypes, StringType}
trait ConfKeysAndIncompat {
@@ -420,9 +420,9 @@ abstract class SparkPlanMeta[INPUT <: SparkPlan](plan: INPUT,
private def findShuffleExchanges(): Seq[SparkPlanMeta[ShuffleExchangeExec]] = wrapped match {
case _: ShuffleExchangeExec =>
this.asInstanceOf[SparkPlanMeta[ShuffleExchangeExec]] :: Nil
- case bkj: BroadcastHashJoinExec => bkj.buildSide match {
- case BuildLeft => childPlans(1).findShuffleExchanges()
- case BuildRight => childPlans(0).findShuffleExchanges()
+ case bkj: BroadcastHashJoinExec => ShimLoader.getSparkShims.getBuildSide(bkj) match {
+ case GpuBuildLeft => childPlans(1).findShuffleExchanges()
+ case GpuBuildRight => childPlans(0).findShuffleExchanges()
}
case _ => childPlans.flatMap(_.findShuffleExchanges())
}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala
new file mode 100644
index 00000000000..878acb2a559
--- /dev/null
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala
@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids
+
+import java.util.ServiceLoader
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.{SPARK_BUILD_USER, SPARK_VERSION}
+import org.apache.spark.internal.Logging
+
+object ShimLoader extends Logging {
+
+ private val sparkVersion = getSparkVersion
+ logInfo(s"Loading shim for Spark version: $sparkVersion")
+
+ // This is not ideal, but we pass the version in here because otherwise loaders that match
+ // the same version (3.0.0 Apache and 3.0.0 Databricks) would need to know how to
+ // differentiate between them.
+ private val sparkShimLoaders = ServiceLoader.load(classOf[SparkShimServiceProvider])
+ .asScala.filter(_.matchesVersion(sparkVersion))
+ if (sparkShimLoaders.size > 1) {
+ throw new IllegalArgumentException(s"Multiple Spark Shim Loaders found: $sparkShimLoaders")
+ }
+ logInfo(s"Found shims: $sparkShimLoaders")
+ private val loader = sparkShimLoaders.headOption match {
+ case Some(loader) => loader
+ case None =>
+ throw new IllegalArgumentException(s"Could not find a Spark shim loader for version $sparkVersion")
+ }
+ private var sparkShims: SparkShims = null
+
+ def getSparkShims: SparkShims = {
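+ // Build the shim lazily on first use and cache it for subsequent calls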
+ if (sparkShims == null) {
+ sparkShims = loader.buildShim
+ }
+ sparkShims
+ }
+
+ def getSparkVersion: String = {
+ // hack for databricks, try to find something more reliable?
+ if (SPARK_BUILD_USER.equals("Databricks")) {
+ SPARK_VERSION + "-databricks"
+ } else {
+ SPARK_VERSION
+ }
+ }
+}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShimServiceProvider.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShimServiceProvider.scala
new file mode 100644
index 00000000000..e1429a1f706
--- /dev/null
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShimServiceProvider.scala
@@ -0,0 +1,25 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids
+
+/**
+ * A Spark version shim layer interface.
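+ * Implementations are loaded via java.util.ServiceLoader, so each shim module must register
+ * its provider in META-INF/services.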
+ */
+trait SparkShimServiceProvider {
+ def matchesVersion(version: String): Boolean
+ def buildShim: SparkShims
+}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
new file mode 100644
index 00000000000..e6a8aeaeb0b
--- /dev/null
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala
@@ -0,0 +1,71 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids
+
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.joins._
+import org.apache.spark.sql.rapids.execution.GpuBroadcastNestedLoopJoinExecBase
+import org.apache.spark.sql.types._
+import org.apache.spark.storage.{BlockId, BlockManagerId}
+
+/**
+ * Spark's BuildSide, BuildRight, and BuildLeft moved packages in Spark 3.1, so we define GPU
+ * versions of them that are agnostic to the Spark version.
+ */
+sealed abstract class GpuBuildSide
+
+case object GpuBuildRight extends GpuBuildSide
+
+case object GpuBuildLeft extends GpuBuildSide
+
+trait SparkShims {
+ def isGpuHashJoin(plan: SparkPlan): Boolean
+ def isGpuBroadcastHashJoin(plan: SparkPlan): Boolean
+ def isGpuShuffledHashJoin(plan: SparkPlan): Boolean
+ def getRapidsShuffleManagerClass: String
+ def getBuildSide(join: HashJoin): GpuBuildSide
+ def getBuildSide(join: BroadcastNestedLoopJoinExec): GpuBuildSide
+ def getExprs: Seq[ExprRule[_ <: Expression]]
+ def getExecs: Seq[ExecRule[_ <: SparkPlan]]
+ def getScalaUDFAsExpression(
+ function: AnyRef,
+ dataType: DataType,
+ children: Seq[Expression],
+ inputEncoders: Seq[Option[ExpressionEncoder[_]]] = Nil,
+ outputEncoder: Option[ExpressionEncoder[_]] = None,
+ udfName: Option[String] = None,
+ nullable: Boolean = true,
+ udfDeterministic: Boolean = true): Expression
+
+ def getGpuBroadcastNestedLoopJoinShim(
+ left: SparkPlan,
+ right: SparkPlan,
+ join: BroadcastNestedLoopJoinExec,
+ joinType: JoinType,
+ condition: Option[Expression],
+ targetSizeBytes: Long): GpuBroadcastNestedLoopJoinExecBase
+
+ def getMapSizesByExecutorId(
+ shuffleId: Int,
+ startMapIndex: Int,
+ endMapIndex: Int,
+ startPartition: Int,
+ endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
+}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala
index 422bd450e50..d65256b764d 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala
@@ -20,7 +20,6 @@ import ai.rapids.cudf
import com.nvidia.spark.rapids._
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExprId, ImplicitCastInputTypes, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, Complete, Final, Partial, PartialMerge}
import org.apache.spark.sql.catalyst.util.TypeUtils
@@ -429,8 +428,11 @@ case class GpuAverage(child: Expression) extends GpuDeclarativeAggregate {
* to check if the value was set (if we don't ignore nulls, valueSet is true, that's what we do
* here).
*/
-case class GpuFirst(child: Expression, ignoreNullsExpr: Expression)
- extends GpuDeclarativeAggregate with ImplicitCastInputTypes {
+abstract class GpuFirstBase(child: Expression)
+ extends GpuDeclarativeAggregate with ImplicitCastInputTypes with Serializable {
+
+ val ignoreNulls: Boolean
+
private lazy val cudfFirst = AttributeReference("cudf_first", child.dataType)()
private lazy val valueSet = AttributeReference("valueSet", BooleanType)()
@@ -458,31 +460,16 @@ case class GpuFirst(child: Expression, ignoreNullsExpr: Expression)
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, BooleanType)
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
- override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil
// First is not a deterministic function.
override lazy val deterministic: Boolean = false
- private def ignoreNulls: Boolean = ignoreNullsExpr match {
- case l: Literal => l.value.asInstanceOf[Boolean]
- case l: GpuLiteral => l.value.asInstanceOf[Boolean]
- case _ => throw new IllegalArgumentException(
- s"$this should only receive literals for ignoreNulls expression")
- }
- override def checkInputDataTypes(): TypeCheckResult = {
- val defaultCheck = super.checkInputDataTypes()
- if (defaultCheck.isFailure) {
- defaultCheck
- } else if (!ignoreNullsExpr.foldable) {
- TypeCheckFailure(s"The second argument of GpuFirst must be a boolean literal, but " +
- s"got: ${ignoreNullsExpr.sql}")
- } else {
- TypeCheckSuccess
- }
- }
override def toString: String = s"gpufirst($child)${if (ignoreNulls) " ignore nulls"}"
}
-case class GpuLast(child: Expression, ignoreNullsExpr: Expression)
- extends GpuDeclarativeAggregate with ImplicitCastInputTypes {
+abstract class GpuLastBase(child: Expression)
+ extends GpuDeclarativeAggregate with ImplicitCastInputTypes with Serializable {
+
+ val ignoreNulls: Boolean
+
private lazy val cudfLast = AttributeReference("cudf_last", child.dataType)()
private lazy val valueSet = AttributeReference("valueSet", BooleanType)()
@@ -509,25 +496,7 @@ case class GpuLast(child: Expression, ignoreNullsExpr: Expression)
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, BooleanType)
override def nullable: Boolean = true
override def dataType: DataType = child.dataType
- override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil
// Last is not a deterministic function.
override lazy val deterministic: Boolean = false
- private def ignoreNulls: Boolean = ignoreNullsExpr match {
- case l: Literal => l.value.asInstanceOf[Boolean]
- case l: GpuLiteral => l.value.asInstanceOf[Boolean]
- case _ => throw new IllegalArgumentException(
- s"$this should only receive literals for ignoreNulls expression")
- }
- override def checkInputDataTypes(): TypeCheckResult = {
- val defaultCheck = super.checkInputDataTypes()
- if (defaultCheck.isFailure) {
- defaultCheck
- } else if (!ignoreNullsExpr.foldable) {
- TypeCheckFailure(s"The second argument of GpuLast must be a boolean literal, but " +
- s"got: ${ignoreNullsExpr.sql}")
- } else {
- TypeCheckSuccess
- }
- }
override def toString: String = s"gpulast($child)${if (ignoreNulls) " ignore nulls"}"
}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuCartesianProductExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuCartesianProductExec.scala
index 6c49013aa5c..468f08d4209 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuCartesianProductExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuCartesianProductExec.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.rapids
import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
import ai.rapids.cudf.{JCudfSerialization, NvtxColor, NvtxRange}
-import com.nvidia.spark.rapids.{Arm, GpuBindReferences, GpuColumnarBatchSerializer, GpuColumnVector, GpuExec, GpuExpression, GpuSemaphore}
+import com.nvidia.spark.rapids.{Arm, GpuBindReferences, GpuBuildLeft, GpuColumnarBatchSerializer, GpuColumnVector, GpuExec, GpuExpression, GpuSemaphore}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import org.apache.spark.{Dependency, NarrowDependency, Partition, SparkContext, TaskContext}
@@ -28,9 +28,8 @@ import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.execution.{BinaryExecNode, ExplainUtils, SparkPlan}
-import org.apache.spark.sql.execution.joins.BuildLeft
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
-import org.apache.spark.sql.rapids.execution.GpuBroadcastNestedLoopJoinExec
+import org.apache.spark.sql.rapids.execution.GpuBroadcastNestedLoopJoinExecBase
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
import org.apache.spark.util.{CompletionIterator, Utils}
@@ -146,10 +145,10 @@ class GpuCartesianRDD(
// Ideally instead of looping through and recomputing rdd2 for
// each batch in rdd1 we would instead cache rdd2 in a way that
// it could spill to disk so we can avoid re-computation
- val ret = GpuBroadcastNestedLoopJoinExec.innerLikeJoin(
+ val ret = GpuBroadcastNestedLoopJoinExecBase.innerLikeJoin(
rdd2.iterator(currSplit.s2, context).map(i => i.getBatch),
table,
- BuildLeft,
+ GpuBuildLeft,
boundCondition,
joinTime,
joinOutputRows,
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala
index 418c2ce2f13..97c4886fa3d 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala
@@ -226,7 +226,7 @@ class GpuDynamicPartitionDataWriter(
*/
private lazy val partitionPathExpression: Expression = Concat(
description.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
- val partitionName = ScalaUDF(
+ val partitionName = ShimLoader.getSparkShims.getScalaUDFAsExpression(
ExternalCatalogUtils.getPartitionPathString _,
StringType,
Seq(Literal(c.name), Cast(c, StringType, Option(description.timeZoneId))))
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileSourceScanExecBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileSourceScanExecBase.scala
new file mode 100644
index 00000000000..429c5b352a6
--- /dev/null
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileSourceScanExecBase.scala
@@ -0,0 +1,22 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.rapids
+
+/**
+ * Base trait used for GpuFileSourceScanExec to use it in the Shim layer.
+ */
+trait GpuFileSourceScanExecBase
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala
index 678b34ad204..4037b16accd 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuShuffleEnv.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.rapids
import ai.rapids.cudf.{CudaMemInfo, Rmm}
-import com.nvidia.spark.RapidsShuffleManager
import com.nvidia.spark.rapids._
import org.apache.spark.SparkEnv
@@ -26,7 +25,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.util.Utils
class GpuShuffleEnv extends Logging {
- private val RAPIDS_SHUFFLE_CLASS = classOf[RapidsShuffleManager].getCanonicalName
+ private val RAPIDS_SHUFFLE_CLASS = ShimLoader.getSparkShims.getRapidsShuffleManagerClass
private var isRapidsShuffleManagerInitialized: Boolean = false
private val catalog = new RapidsBufferCatalog
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala
index 8ad97f61ad3..dc8a847a098 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManager.scala
@@ -64,7 +64,7 @@ class GpuShuffleBlockResolver(private val wrapped: ShuffleBlockResolver,
}
-object RapidsShuffleInternalManager extends Logging {
+object RapidsShuffleInternalManagerBase extends Logging {
def unwrapHandle(handle: ShuffleHandle): ShuffleHandle = handle match {
case gh: GpuShuffleHandle[_, _] => gh.wrapped
case other => other
@@ -182,17 +182,13 @@ class RapidsCachingWriter[K, V](
* @note This is an internal class to obtain access to the private
* `ShuffleManager` and `SortShuffleManager` classes. When configuring
* Apache Spark to use the RAPIDS shuffle manager,
- * [[com.nvidia.spark.RapidsShuffleManager]] should be used as that is
- * the public class.
*/
-class RapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean)
+abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, isDriver: Boolean)
extends ShuffleManager with Logging {
- import RapidsShuffleInternalManager._
-
private val rapidsConf = new RapidsConf(conf)
- private val wrapped = new SortShuffleManager(conf)
+ protected val wrapped = new SortShuffleManager(conf)
GpuShuffleEnv.setRapidsShuffleManagerInitialized(true, this.getClass.getCanonicalName)
logWarning("Rapids Shuffle Plugin Enabled")
@@ -295,7 +291,7 @@ class RapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean)
}
}
- override def getReaderForRange[K, C](
+ def getReaderInternal[K, C](
handle: ShuffleHandle,
startMapIndex: Int,
endMapIndex: Int,
@@ -303,18 +299,6 @@ class RapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean)
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
- // NOTE: This type of reader is not possible for gpu shuffle, as we'd need
- // to use the optimization within our manager, and we don't.
- wrapped.getReaderForRange(unwrapHandle(handle), startMapIndex, endMapIndex,
- startPartition, endPartition, context, metrics)
- }
-
- override def getReader[K, C](
- handle: ShuffleHandle,
- startPartition: Int,
- endPartition: Int,
- context: TaskContext,
- metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
handle match {
case gpu: GpuShuffleHandle[_, _] =>
logInfo(s"Asking map output tracker for dependency ${gpu.dependency}, " +
@@ -327,7 +311,8 @@ class RapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean)
val nvtxRange = new NvtxRange("getMapSizesByExecId", NvtxColor.CYAN)
val blocksByAddress = try {
- env.mapOutputTracker.getMapSizesByExecutorId(gpu.shuffleId, startPartition, endPartition)
+ ShimLoader.getSparkShims.getMapSizesByExecutorId(gpu.shuffleId,
+ startMapIndex, endMapIndex, startPartition, endPartition)
} finally {
nvtxRange.close()
}
@@ -340,7 +325,8 @@ class RapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean)
transport,
catalog)
case other => {
- wrapped.getReader(unwrapHandle(other), startPartition, endPartition, context, metrics)
+ val shuffleHandle = RapidsShuffleInternalManagerBase.unwrapHandle(other)
+ wrapped.getReader(shuffleHandle, startPartition, endPartition, context, metrics)
}
}
}
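With the reader logic consolidated into `getReaderInternal`, each supported Spark version only needs a thin subclass adapting its own `getReader`/`getReaderForRange` signatures. The sketch below shows what a Spark 3.0 shim could look like; the package and class names are assumptions, and the real shim modules define the actual classes:

```scala
package org.apache.spark.sql.rapids.shims.spark30

import org.apache.spark.{SparkConf, TaskContext}
import org.apache.spark.shuffle.{ShuffleHandle, ShuffleReader, ShuffleReadMetricsReporter}
import org.apache.spark.sql.rapids.RapidsShuffleInternalManagerBase

class RapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean)
    extends RapidsShuffleInternalManagerBase(conf, isDriver) {

  // Spark 3.0's getReader carries no map-index range, so request all map outputs.
  override def getReader[K, C](
      handle: ShuffleHandle,
      startPartition: Int,
      endPartition: Int,
      context: TaskContext,
      metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
    getReaderInternal(handle, 0, Int.MaxValue, startPartition, endPartition, context, metrics)
  }

  // Spark 3.0 still declares getReaderForRange, so forward the range as-is.
  override def getReaderForRange[K, C](
      handle: ShuffleHandle,
      startMapIndex: Int,
      endMapIndex: Int,
      startPartition: Int,
      endPartition: Int,
      context: TaskContext,
      metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
    getReaderInternal(handle, startMapIndex, endMapIndex, startPartition, endPartition,
      context, metrics)
  }
}
```

The public `com.nvidia.spark.rapids.spark30.RapidsShuffleManager` named in the docs would then be a trivial subclass of such an internal manager.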
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala
index 361dd3c3fe4..1d75563e97c 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.{Cross, ExistenceJoin, FullOuter, Inn
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, IdentityBroadcastMode, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
-import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, BuildLeft, BuildRight, BuildSide}
+import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.rapids.GpuNoColumnCrossJoin
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -52,9 +52,10 @@ class GpuBroadcastNestedLoopJoinMeta(
case _ => willNotWorkOnGpu(s"$join.joinType currently is not supported")
}
- val buildSide = join.buildSide match {
- case BuildLeft => childPlans.head
- case BuildRight => childPlans(1)
+ val gpuBuildSide = ShimLoader.getSparkShims.getBuildSide(join)
+ val buildSide = gpuBuildSide match {
+ case GpuBuildLeft => childPlans.head
+ case GpuBuildRight => childPlans(1)
}
if (!buildSide.canThisBeReplaced) {
@@ -71,26 +72,27 @@ class GpuBroadcastNestedLoopJoinMeta(
val left = childPlans.head.convertIfNeeded()
val right = childPlans(1).convertIfNeeded()
// The broadcast part of this must be a BroadcastExchangeExec
- val buildSide = join.buildSide match {
- case BuildLeft => left
- case BuildRight => right
- }
+ val gpuBuildSide = ShimLoader.getSparkShims.getBuildSide(join)
+ val buildSide = gpuBuildSide match {
+ case GpuBuildLeft => left
+ case GpuBuildRight => right
+ }
if (!buildSide.isInstanceOf[GpuBroadcastExchangeExec]) {
throw new IllegalStateException("the broadcast must be on the GPU too")
}
- GpuBroadcastNestedLoopJoinExec(
- left, right, join.buildSide,
+ ShimLoader.getSparkShims.getGpuBroadcastNestedLoopJoinShim(
+ left, right, join,
join.joinType,
condition.map(_.convertToGpu()),
conf.gpuTargetBatchSizeBytes)
}
}
-object GpuBroadcastNestedLoopJoinExec extends Arm {
+object GpuBroadcastNestedLoopJoinExecBase extends Arm {
def innerLikeJoin(
streamedIter: Iterator[ColumnarBatch],
builtTable: Table,
- buildSide: BuildSide,
+ buildSide: GpuBuildSide,
boundCondition: Option[GpuExpression],
joinTime: SQLMetric,
joinOutputRows: SQLMetric,
@@ -107,8 +109,8 @@ object GpuBroadcastNestedLoopJoinExec extends Arm {
withResource(new NvtxWithMetrics("join", NvtxColor.ORANGE, joinTime)) { _ =>
val joinedTable = withResource(streamTable) { tab =>
buildSide match {
- case BuildLeft => builtTable.crossJoin(tab)
- case BuildRight => tab.crossJoin(builtTable)
+ case GpuBuildLeft => builtTable.crossJoin(tab)
+ case GpuBuildRight => tab.crossJoin(builtTable)
}
}
withResource(joinedTable) { jt =>
@@ -129,14 +131,18 @@ object GpuBroadcastNestedLoopJoinExec extends Arm {
}
}
-case class GpuBroadcastNestedLoopJoinExec(
+abstract class GpuBroadcastNestedLoopJoinExecBase(
left: SparkPlan,
right: SparkPlan,
- buildSide: BuildSide,
+ join: BroadcastNestedLoopJoinExec,
joinType: JoinType,
condition: Option[Expression],
targetSizeBytes: Long) extends BinaryExecNode with GpuExec {
+ // Spark's BuildSide, BuildRight, and BuildLeft changed packages between Spark versions,
+ // so shims expose a GpuBuildSide that is agnostic to the Spark version.
+ def getGpuBuildSide: GpuBuildSide
+
override protected def doExecute(): RDD[InternalRow] =
throw new IllegalStateException("This should only be called from columnar")
@@ -148,9 +154,9 @@ case class GpuBroadcastNestedLoopJoinExec(
"filterTime" -> SQLMetrics.createNanoTimingMetric(sparkContext, "filter time"))
/** BuildRight means the right relation <=> the broadcast relation. */
- private val (streamed, broadcast) = buildSide match {
- case BuildRight => (left, right)
- case BuildLeft => (right, left)
+ private val (streamed, broadcast) = getGpuBuildSide match {
+ case GpuBuildRight => (left, right)
+ case GpuBuildLeft => (right, left)
}
def broadcastExchange: GpuBroadcastExchangeExec = broadcast match {
@@ -158,10 +164,10 @@ case class GpuBroadcastNestedLoopJoinExec(
case reused: ReusedExchangeExec => reused.child.asInstanceOf[GpuBroadcastExchangeExec]
}
- override def requiredChildDistribution: Seq[Distribution] = buildSide match {
- case BuildLeft =>
+ override def requiredChildDistribution: Seq[Distribution] = getGpuBuildSide match {
+ case GpuBuildLeft =>
BroadcastDistribution(IdentityBroadcastMode) :: UnspecifiedDistribution :: Nil
- case BuildRight =>
+ case GpuBuildRight =>
UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil
}
@@ -239,13 +245,14 @@ case class GpuBroadcastNestedLoopJoinExec(
streamed.executeColumnar().mapPartitions { streamedIter =>
joinType match {
- case _: InnerLike => GpuBroadcastNestedLoopJoinExec.innerLikeJoin(streamedIter,
- builtTable, buildSide, boundCondition,
+ case _: InnerLike => GpuBroadcastNestedLoopJoinExecBase.innerLikeJoin(streamedIter,
+ builtTable, getGpuBuildSide, boundCondition,
joinTime, joinOutputRows, numOutputRows, numOutputBatches, filterTime, totalTime)
- case _ => throw new IllegalArgumentException(s"$joinType + $buildSide is not supported" +
- s" and should be run on the CPU")
+ case _ => throw new IllegalArgumentException(s"$joinType + $getGpuBuildSide is not" +
+ " supported and should be run on the CPU")
}
}
}
}
-}
\ No newline at end of file
+}
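For context, the `GpuBuildSide` values used above boil down to a small sealed hierarchy plus a per-version translation of Spark's own build side. This is a sketch that mirrors the identifiers in the patch rather than the plugin's actual source; the Spark 3.0 import location matches the one this file stops using, while Spark 3.1 moves `BuildSide` into the catalyst optimizer package:

```scala
// Version-agnostic build side used throughout the plugin (sketch).
sealed trait GpuBuildSide
case object GpuBuildLeft extends GpuBuildSide
case object GpuBuildRight extends GpuBuildSide

// How a Spark 3.0 shim could implement SparkShims.getBuildSide (illustrative).
import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, BuildLeft, BuildRight}

object Spark30BuildSide {
  def getBuildSide(join: BroadcastNestedLoopJoinExec): GpuBuildSide = join.buildSide match {
    case BuildLeft  => GpuBuildLeft
    case BuildRight => GpuBuildRight
  }
}
```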
diff --git a/tests/pom.xml b/tests/pom.xml
index 62e56f603ec..729baa43440 100644
--- a/tests/pom.xml
+++ b/tests/pom.xml
@@ -30,7 +30,28 @@
RAPIDS plugin for Apache Spark integration tests
0.2.0-SNAPSHOT
+
+ 3.0.0
+
+
+
+ spark31tests
+
+ 3.1.0-SNAPSHOT
+
+
+
+
+
+ org.slf4j
+ jul-to-slf4j
+
+
+ org.slf4j
+ jcl-over-slf4j
+
+
org.scala-lang
scala-library
@@ -38,6 +59,7 @@
org.apache.spark
spark-sql_${scala.binary.version}
+ ${spark.test.version}
org.scalatest
@@ -56,6 +78,12 @@
${project.version}
test
+
+ com.nvidia
+ rapids-4-spark-shims_${scala.binary.version}
+ ${project.version}
+ test
+
org.mockito
mockito-core
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala
index bcdd032915b..f343927d055 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastHashJoinSuite.scala
@@ -17,8 +17,8 @@
package com.nvidia.spark.rapids
import org.apache.spark.SparkConf
+import org.apache.spark.sql.execution.joins.HashJoin
import org.apache.spark.sql.functions.broadcast
-import org.apache.spark.sql.rapids.execution.GpuBroadcastHashJoinExec
class BroadcastHashJoinSuite extends SparkQueryCompareTestSuite {
@@ -36,8 +36,12 @@ class BroadcastHashJoinSuite extends SparkQueryCompareTestSuite {
val plan = df5.queryExecution.executedPlan
- assert(plan.collect { case p: GpuBroadcastHashJoinExec => p }.size === 1)
- assert(plan.collect { case p: GpuShuffledHashJoinExec => p }.size === 1)
+ assert(plan.collect {
+ case p if ShimLoader.getSparkShims.isGpuBroadcastHashJoin(p) => p
+ }.size === 1)
+ assert(plan.collect {
+ case p if ShimLoader.getSparkShims.isGpuShuffledHashJoin(p) => p
+ }.size === 1)
}, conf)
}
@@ -52,13 +56,13 @@ class BroadcastHashJoinSuite extends SparkQueryCompareTestSuite {
val plan2 = spark.sql(s"SELECT /*+ $name(u) */ * FROM t JOIN u ON t.longs = u.longs")
.queryExecution.executedPlan
- val res1 = plan1.find(_.isInstanceOf[GpuBroadcastHashJoinExec])
- val res2 = plan2.find(_.isInstanceOf[GpuBroadcastHashJoinExec])
+ val res1 = plan1.find(ShimLoader.getSparkShims.isGpuBroadcastHashJoin(_))
+ val res2 = plan2.find(ShimLoader.getSparkShims.isGpuBroadcastHashJoin(_))
- assert(res1.get.asInstanceOf[GpuBroadcastHashJoinExec].buildSide.toString
- .equals("BuildLeft"))
- assert(res2.get.asInstanceOf[GpuBroadcastHashJoinExec].buildSide.toString
- .equals("BuildRight"))
+ assert(ShimLoader.getSparkShims.getBuildSide(res1.get.asInstanceOf[HashJoin]).toString ==
+ "GpuBuildLeft")
+ assert(ShimLoader.getSparkShims.getBuildSide(res2.get.asInstanceOf[HashJoin]).toString ==
+ "GpuBuildRight")
}
})
}
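The tests now detect the GPU join execs through `ShimLoader` because those exec classes are defined per shim. The shape of the shim methods the tests rely on could be as simple as the following (trait and class names are assumed for illustration):

```scala
import org.apache.spark.sql.execution.SparkPlan

// Sketch of the SparkShims checks used by the tests above.
trait JoinShimChecks {
  def isGpuBroadcastHashJoin(plan: SparkPlan): Boolean
  def isGpuShuffledHashJoin(plan: SparkPlan): Boolean
}

// e.g. a Spark 3.0 shim, assuming a shim-local GpuBroadcastHashJoinExec class:
// override def isGpuBroadcastHashJoin(plan: SparkPlan): Boolean = plan match {
//   case _: GpuBroadcastHashJoinExec => true
//   case _ => false
// }
```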
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/HashSortOptimizeSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/HashSortOptimizeSuite.scala
index 3547f94a0b9..287a2a1f8e2 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/HashSortOptimizeSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/HashSortOptimizeSuite.scala
@@ -21,7 +21,6 @@ import org.scalatest.FunSuite
import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.execution.{SortExec, SparkPlan}
-import org.apache.spark.sql.rapids.execution.GpuBroadcastHashJoinExec
/** Test plan modifications to add optimizing sorts after hash joins in the plan */
class HashSortOptimizeSuite extends FunSuite {
@@ -70,7 +69,7 @@ class HashSortOptimizeSuite extends FunSuite {
val df2 = buildDataFrame2(spark)
val rdf = df1.join(df2, df1("a") === df2("x"))
val plan = rdf.queryExecution.executedPlan
- val joinNode = plan.find(_.isInstanceOf[GpuBroadcastHashJoinExec])
+ val joinNode = plan.find(ShimLoader.getSparkShims.isGpuBroadcastHashJoin(_))
assert(joinNode.isDefined, "No broadcast join node found")
validateOptimizeSort(plan, joinNode.get)
})
@@ -83,7 +82,7 @@ class HashSortOptimizeSuite extends FunSuite {
val df2 = buildDataFrame2(spark)
val rdf = df1.join(df2, df1("a") === df2("x"))
val plan = rdf.queryExecution.executedPlan
- val joinNode = plan.find(_.isInstanceOf[GpuShuffledHashJoinExec])
+ val joinNode = plan.find(ShimLoader.getSparkShims.isGpuShuffledHashJoin(_))
assert(joinNode.isDefined, "No broadcast join node found")
validateOptimizeSort(plan, joinNode.get)
})