diff --git a/api_validation/src/main/scala/com/nvidia/spark/rapids/api/ApiValidation.scala b/api_validation/src/main/scala/com/nvidia/spark/rapids/api/ApiValidation.scala index 8af502c756c..942e534218e 100644 --- a/api_validation/src/main/scala/com/nvidia/spark/rapids/api/ApiValidation.scala +++ b/api_validation/src/main/scala/com/nvidia/spark/rapids/api/ApiValidation.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ import scala.reflect.api import scala.reflect.runtime.universe._ import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.shims.SparkShimImpl import org.apache.spark.internal.Logging @@ -70,7 +71,7 @@ object ApiValidation extends Logging { var printNewline = false val sparkToShimMap = Map("3.0.1" -> "spark301", "3.1.1" -> "spark311") - val sparkVersion = ShimLoader.getSparkShims.getSparkShimVersion.toString + val sparkVersion = SparkShimImpl.getSparkShimVersion.toString val shimVersion = sparkToShimMap(sparkVersion) gpuKeys.foreach { e => diff --git a/docs/configs.md b/docs/configs.md index 968a9468943..2ec75fa640e 100644 --- a/docs/configs.md +++ b/docs/configs.md @@ -39,7 +39,7 @@ Name | Description | Default Value spark.rapids.memory.gpu.maxAllocFraction|The fraction of total GPU memory that limits the maximum size of the RMM pool. The value must be greater than or equal to the setting for spark.rapids.memory.gpu.allocFraction. Note that this limit will be reduced by the reserve memory configured in spark.rapids.memory.gpu.reserve.|1.0 spark.rapids.memory.gpu.minAllocFraction|The fraction of total GPU memory that limits the minimum size of the RMM pool. The value must be less than or equal to the setting for spark.rapids.memory.gpu.allocFraction.|0.25 spark.rapids.memory.gpu.oomDumpDir|The path to a local directory where a heap dump will be created if the GPU encounters an unrecoverable out-of-memory (OOM) error. The filename will be of the form: "gpu-oom-<pid>.hprof" where <pid> is the process ID.|None -spark.rapids.memory.gpu.pool|Select the RMM pooling allocator to use. Valid values are "DEFAULT", "ARENA", "ASYNC", and "NONE". With "DEFAULT", the RMM pool allocator is used; with "ARENA", the RMM arena allocator is used; with "ASYNC", the new CUDA stream-ordered memory allocator in CUDA 11.2+ is used. If set to "NONE", pooling is disabled and RMM just passes through to CUDA memory allocation directly.|ASYNC +spark.rapids.memory.gpu.pool|Select the RMM pooling allocator to use. Valid values are "DEFAULT", "ARENA", "ASYNC", and "NONE". With "DEFAULT", the RMM pool allocator is used; with "ARENA", the RMM arena allocator is used; with "ASYNC", the new CUDA stream-ordered memory allocator in CUDA 11.2+ is used. If set to "NONE", pooling is disabled and RMM just passes through to CUDA memory allocation directly.|ARENA spark.rapids.memory.gpu.pooling.enabled|Should RMM act as a pooling allocator for GPU memory, or should it just pass through to CUDA memory allocation directly.
DEPRECATED: please use spark.rapids.memory.gpu.pool instead.|true spark.rapids.memory.gpu.reserve|The amount of GPU memory that should remain unallocated by RMM and left for system use such as memory needed for kernels and kernel launches.|671088640 spark.rapids.memory.gpu.unspill.enabled|When a spilled GPU buffer is needed again, should it be unspilled, or only copied back into GPU memory temporarily. Unspilling may be useful for GPU buffers that are needed frequently, for example, broadcast variables; however, it may also increase GPU memory usage.|false diff --git a/docs/dev/shims.md b/docs/dev/shims.md index f828865057a..97bf661c1fe 100644 --- a/docs/dev/shims.md +++ b/docs/dev/shims.md @@ -26,9 +26,9 @@ In the following we provide recipes for typical scenarios addressed by the Shim It's among the easiest issues to resolve. We define a method in the SparkShims trait covering a superset of parameters from all versions and call it ``` -ShimLoader.getSparkShims.methodWithDiscrepancies(p_1, ..., p_n) +SparkShimImpl.methodWithDiscrepancies(p_1, ..., p_n) ``` -instead of referencing it directly. Shim implementations are in charge of dispatching it further +instead of referencing it directly. Shim implementations (SparkShimImpl) are in charge of dispatching it further to the correct version-dependent methods. Moreover, unlike in the sections below, conflicts between versions are easily avoided by using different package or class names for conflicting Shim implementations. @@ -40,7 +40,7 @@ Upstream base classes we derive from might be incompatible in the sense that one requires us to implement/override the method `M` whereas the other prohibits it by marking the base implementation `final`; e.g., `org.apache.spark.sql.catalyst.trees.TreeNode` changes between Spark 3.1.x and Spark 3.2.x. So instead of deriving from such classes directly we -inject an intermediate trait e.g. `com.nvidia.spark.rapids.shims.v2.ShimExpression` that +inject an intermediate trait e.g. `com.nvidia.spark.rapids.shims.ShimExpression` that has varying source code depending on the Spark version we compile against to overcome this issue, as you can see by comparing TreeNode: 1. [ShimExpression for 3.0.x and 3.1.x](https://github.com/NVIDIA/spark-rapids/blob/main/sql-plugin/src/main/post320-treenode/scala/com/nvidia/spark/rapids/shims/v2/TreeNode.scala#L23) diff --git a/integration_tests/src/main/python/asserts.py b/integration_tests/src/main/python/asserts.py index c466fd3ccd2..fe15041e9ba 100644 --- a/integration_tests/src/main/python/asserts.py +++ b/integration_tests/src/main/python/asserts.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2021, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ # limitations under the License.
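To make the shims.md recipe above concrete: a minimal sketch of the pattern this diff migrates to, assuming hypothetical member names (`methodWithDiscrepancies` and its parameter types are placeholders, not the plugin's real API):

```scala
// Sketch of compile-time shim dispatch; the real SparkShims trait is much larger.
package com.nvidia.spark.rapids.shims

trait SparkShims {
  // Covers a superset of the parameters needed across all supported Spark versions.
  def methodWithDiscrepancies(p1: String, p2: Int): String
}

// Each version-specific source set (e.g. src/main/301/scala, src/main/311-nondb/scala)
// provides its own SparkShimImpl, and the build compiles exactly one of them, so
// callers write SparkShimImpl.methodWithDiscrepancies("col", 42) and bind statically
// instead of looking the shim up through ShimLoader at runtime.
object SparkShimImpl extends SparkShims {
  override def methodWithDiscrepancies(p1: String, p2: Int): String =
    s"$p1:$p2" // a 3.0.1 build would call the 3.0.1-shaped Spark API here
}
```

The intermediate-trait trick for incompatible base classes (`ShimExpression` above) relies on the same mechanism: the trait's source differs per source set, but its fully qualified name stays stable.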
from conftest import is_incompat, should_sort_on_spark, should_sort_locally, get_float_check, get_limit, spark_jvm -from datetime import date, datetime +from datetime import date, datetime, timedelta from decimal import Decimal import math from pyspark.sql import Row @@ -92,6 +92,9 @@ def _assert_equal(cpu, gpu, float_check, path): assert cpu == gpu, "GPU and CPU decimal values are different at {}".format(path) elif isinstance(cpu, bytearray): assert cpu == gpu, "GPU and CPU bytearray values are different at {}".format(path) + elif isinstance(cpu, timedelta): + # Used by interval type DayTimeInterval for Pyspark 3.3.0+ + assert cpu == gpu, "GPU and CPU timedelta values are different at {}".format(path) elif (cpu == None): assert cpu == gpu, "GPU and CPU are not both null at {}".format(path) else: diff --git a/integration_tests/src/main/python/cache_test.py b/integration_tests/src/main/python/cache_test.py index 0849ba0b5d0..66dd85bf75c 100644 --- a/integration_tests/src/main/python/cache_test.py +++ b/integration_tests/src/main/python/cache_test.py @@ -24,8 +24,7 @@ enable_vectorized_confs = [{"spark.sql.inMemoryColumnarStorage.enableVectorizedReader": "true"}, {"spark.sql.inMemoryColumnarStorage.enableVectorizedReader": "false"}] -# cache does not work with 128-bit decimals, see https://github.com/NVIDIA/spark-rapids/issues/4826 -_cache_decimal_gens = [decimal_gen_32bit, decimal_gen_64bit] +_cache_decimal_gens = [decimal_gen_32bit, decimal_gen_64bit, decimal_gen_128bit] _cache_single_array_gens_no_null = [ArrayGen(gen) for gen in all_basic_gens_no_null + _cache_decimal_gens] decimal_struct_gen= StructGen([['child0', sub_gen] for ind, sub_gen in enumerate(_cache_decimal_gens)]) diff --git a/integration_tests/src/main/python/data_gen.py b/integration_tests/src/main/python/data_gen.py index fa162dfc200..819d5b644d6 100644 --- a/integration_tests/src/main/python/data_gen.py +++ b/integration_tests/src/main/python/data_gen.py @@ -612,6 +612,33 @@ def make_null(): return None self._start(rand, make_null) +# DayTimeIntervalGen is for Spark 3.3.0+ +# DayTimeIntervalType(startField, endField): Represents a day-time interval which is made up of a contiguous subset of the following fields: +# SECOND, seconds within minutes and possibly fractions of a second [0..59.999999], +# MINUTE, minutes within hours [0..59], +# HOUR, hours within days [0..23], +# DAY, days in the range [0..106751991]. 
+# For more details: https://spark.apache.org/docs/latest/sql-ref-datatypes.html +# Note: the 106751991-day upper bound is not a typo: Spark stores a day-time interval as +# microseconds in a signed 64-bit long, and Long.MaxValue microseconds is about +# 106751991 days (roughly 292471 years, far past the year 9999) +class DayTimeIntervalGen(DataGen): + """Generate DayTimeIntervalType values""" + def __init__(self, max_days=None, nullable=True, special_cases=[timedelta(seconds=0)]): + super().__init__(DayTimeIntervalType(), nullable=nullable, special_cases=special_cases) + if max_days is None: + self._max_days = 106751991 + else: + self._max_days = max_days + def start(self, rand): + self._start(rand, + lambda: timedelta( + microseconds=rand.randint(0, 999999), + seconds=rand.randint(0, 59), + minutes=rand.randint(0, 59), + hours=rand.randint(0, 23), + days=rand.randint(0, self._max_days), + ) + ) + def skip_if_not_utc(): if (not is_tz_utc()): skip_unless_precommit_tests('The java system time zone is not set to UTC') diff --git a/integration_tests/src/main/python/date_time_test.py b/integration_tests/src/main/python/date_time_test.py index a5f04b30ea2..0b40afdbe57 100644 --- a/integration_tests/src/main/python/date_time_test.py +++ b/integration_tests/src/main/python/date_time_test.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2021, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,7 +18,7 @@ from datetime import date, datetime, timezone from marks import incompat, allow_non_gpu from pyspark.sql.types import * -from spark_session import with_spark_session, is_before_spark_311 +from spark_session import with_spark_session, is_before_spark_311, is_before_spark_330 import pyspark.sql.functions as f # We only support literal intervals for TimeSub @@ -41,6 +41,16 @@ def test_timeadd(data_gen): lambda spark: unary_op_df(spark, TimestampGen(start=datetime(5, 1, 1, tzinfo=timezone.utc), end=datetime(15, 1, 1, tzinfo=timezone.utc)), seed=1) .selectExpr("a + (interval {} days {} seconds)".format(days, seconds))) +@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') +def test_timeadd_daytime_column(): + gen_list = [ + # the generated timestamp column's maximum year is 1000 + ('t', TimestampGen(end=datetime(1000, 1, 1, tzinfo=timezone.utc))), + # the interval is capped at 8000 years' worth of days, so the sum cannot overflow the timestamp range + ('d', DayTimeIntervalGen(max_days=8000 * 365))] + assert_gpu_and_cpu_are_equal_collect( + lambda spark: gen_df(spark, gen_list).selectExpr("t + d", "t + INTERVAL '1 02:03:04' DAY TO SECOND")) + @pytest.mark.parametrize('data_gen', vals, ids=idfn) def test_dateaddinterval(data_gen): days, seconds = data_gen diff --git a/integration_tests/src/main/python/parquet_test.py b/integration_tests/src/main/python/parquet_test.py index eb52ac34568..aef67dcc46e 100644 --- a/integration_tests/src/main/python/parquet_test.py +++ b/integration_tests/src/main/python/parquet_test.py @@ -789,3 +789,21 @@ def test_parquet_read_field_id(spark_tmp_path): lambda spark: spark.read.schema(readSchema).parquet(data_path), 'FileSourceScanExec', {"spark.sql.parquet.fieldId.read.enabled": "true"}) # default is false + +@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') +def test_parquet_read_daytime_interval_cpu_file(spark_tmp_path): + data_path = spark_tmp_path + '/PARQUET_DATA' + gen_list = [('_c1', DayTimeIntervalGen())] + # write DayTimeInterval with CPU + with_cpu_session(lambda spark
: gen_df(spark, gen_list).coalesce(1).write.mode("overwrite").parquet(data_path)) + assert_gpu_and_cpu_are_equal_collect( + lambda spark: spark.read.parquet(data_path)) + +@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') +def test_parquet_read_daytime_interval_gpu_file(spark_tmp_path): + data_path = spark_tmp_path + '/PARQUET_DATA' + gen_list = [('_c1', DayTimeIntervalGen())] + # write DayTimeInterval with GPU + with_gpu_session(lambda spark: gen_df(spark, gen_list).coalesce(1).write.mode("overwrite").parquet(data_path)) + assert_gpu_and_cpu_are_equal_collect( + lambda spark: spark.read.parquet(data_path)) \ No newline at end of file diff --git a/integration_tests/src/main/python/parquet_write_test.py b/integration_tests/src/main/python/parquet_write_test.py index fd556830135..0e9a990652e 100644 --- a/integration_tests/src/main/python/parquet_write_test.py +++ b/integration_tests/src/main/python/parquet_write_test.py @@ -418,3 +418,14 @@ def test_parquet_write_field_id(spark_tmp_path): data_path, 'DataWritingCommandExec', conf = {"spark.sql.parquet.fieldId.write.enabled" : "true"}) # default is true + +@pytest.mark.order(1) # at the head of xdist worker queue if pytest-order is installed +@pytest.mark.skipif(is_before_spark_330(), reason='DayTimeInterval is not supported before Pyspark 3.3.0') +def test_write_daytime_interval(spark_tmp_path): + gen_list = [('_c1', DayTimeIntervalGen())] + data_path = spark_tmp_path + '/PARQUET_DATA' + assert_gpu_and_cpu_writes_are_equal_collect( + lambda spark, path: gen_df(spark, gen_list).coalesce(1).write.parquet(path), + lambda spark, path: spark.read.parquet(path), + data_path, + conf=writer_confs) diff --git a/pom.xml b/pom.xml index 92d5f676229..ca78e131f48 100644 --- a/pom.xml +++ b/pom.xml @@ -115,6 +115,7 @@ ${project.basedir}/src/main/301+-nondb/scala ${project.basedir}/src/main/301/scala + ${project.basedir}/src/main/301until304/scala ${project.basedir}/src/main/301until310-all/scala ${project.basedir}/src/main/301until310-nondb/scala ${project.basedir}/src/main/301until320-all/scala @@ -164,6 +165,7 @@ ${project.basedir}/src/main/301+-nondb/scala ${project.basedir}/src/main/302/scala + ${project.basedir}/src/main/301until304/scala ${project.basedir}/src/main/301until310-all/scala ${project.basedir}/src/main/301until310-nondb/scala ${project.basedir}/src/main/301until320-all/scala @@ -222,6 +224,7 @@ ${project.basedir}/src/main/301+-nondb/scala ${project.basedir}/src/main/303/scala + ${project.basedir}/src/main/301until304/scala ${project.basedir}/src/main/301until310-all/scala ${project.basedir}/src/main/301until310-nondb/scala ${project.basedir}/src/main/301until320-all/scala @@ -327,7 +330,7 @@ ${project.basedir}/src/main/301+-nondb/scala - ${project.basedir}/src/main/311/scala + ${project.basedir}/src/main/311-nondb/scala ${project.basedir}/src/main/301until320-all/scala ${project.basedir}/src/main/301until320-noncdh/scala ${project.basedir}/src/main/301until320-nondb/scala @@ -337,7 +340,6 @@ ${project.basedir}/src/main/311until320-all/scala ${project.basedir}/src/main/311until320-noncdh/scala ${project.basedir}/src/main/311until320-nondb/scala - ${project.basedir}/src/main/311until330-all/scala ${project.basedir}/src/main/pre320-treenode/scala @@ -464,7 +466,6 @@ ${project.basedir}/src/main/311+-all/scala ${project.basedir}/src/main/311until320-noncdh/scala ${project.basedir}/src/main/31xdb/scala - ${project.basedir}/src/main/311until330-all/scala
${project.basedir}/src/main/post320-treenode/scala @@ -509,7 +510,7 @@ ${project.basedir}/src/main/301+-nondb/scala - ${project.basedir}/src/main/312/scala + ${project.basedir}/src/main/312-nondb/scala ${project.basedir}/src/main/301until320-all/scala ${project.basedir}/src/main/301until320-noncdh/scala ${project.basedir}/src/main/301until320-nondb/scala @@ -519,7 +520,6 @@ ${project.basedir}/src/main/311until320-all/scala ${project.basedir}/src/main/311until320-noncdh/scala ${project.basedir}/src/main/311until320-nondb/scala - ${project.basedir}/src/main/311until330-all/scala ${project.basedir}/src/main/pre320-treenode/scala @@ -577,7 +577,6 @@ ${project.basedir}/src/main/311until320-all/scala ${project.basedir}/src/main/311until320-noncdh/scala ${project.basedir}/src/main/311until320-nondb/scala - ${project.basedir}/src/main/311until330-all/scala ${project.basedir}/src/main/pre320-treenode/scala @@ -629,7 +628,6 @@ ${project.basedir}/src/main/301until330-all/scala ${project.basedir}/src/main/311+-all/scala ${project.basedir}/src/main/311+-nondb/scala - ${project.basedir}/src/main/311until330-all/scala ${project.basedir}/src/main/320/scala ${project.basedir}/src/main/320+/scala ${project.basedir}/src/main/320until330-all/scala @@ -640,7 +638,7 @@ add-test-profile-src-320 add-test-source - none + generate-test-sources ${project.basedir}/src/test/320/scala @@ -693,7 +691,6 @@ ${project.basedir}/src/main/301until330-all/scala ${project.basedir}/src/main/311+-all/scala ${project.basedir}/src/main/311+-nondb/scala - ${project.basedir}/src/main/311until330-all/scala ${project.basedir}/src/main/320+/scala ${project.basedir}/src/main/320until330-all/scala ${project.basedir}/src/main/321+/scala @@ -704,7 +701,7 @@ add-test-profile-src-321 add-test-source - none + generate-test-sources ${project.basedir}/src/test/321/scala @@ -757,9 +754,9 @@ ${project.basedir}/src/main/301until330-all/scala ${project.basedir}/src/main/311+-all/scala ${project.basedir}/src/main/311+-nondb/scala - ${project.basedir}/src/main/311until330-all/scala ${project.basedir}/src/main/320+/scala ${project.basedir}/src/main/321+/scala + ${project.basedir}/src/main/322+/scala ${project.basedir}/src/main/320until330-all/scala ${project.basedir}/src/main/post320-treenode/scala @@ -768,7 +765,7 @@ add-test-profile-src-322 add-test-source - none + generate-test-sources ${project.basedir}/src/test/322/scala @@ -822,6 +819,7 @@ ${project.basedir}/src/main/311+-nondb/scala ${project.basedir}/src/main/320+/scala ${project.basedir}/src/main/321+/scala + ${project.basedir}/src/main/322+/scala ${project.basedir}/src/main/330+/scala ${project.basedir}/src/main/post320-treenode/scala @@ -830,7 +828,7 @@ add-test-profile-src-330 add-test-source - none + generate-test-sources ${project.basedir}/src/test/330/scala @@ -879,6 +877,8 @@ ${project.basedir}/src/main/301+-nondb/scala + ${project.basedir}/src/main/311-nondb/scala + ${project.basedir}/src/main/311cdh/scala ${project.basedir}/src/main/301until320-all/scala ${project.basedir}/src/main/301until320-nondb/scala ${project.basedir}/src/main/301until330-all/scala @@ -887,7 +887,6 @@ ${project.basedir}/src/main/311cdh/scala ${project.basedir}/src/main/311until320-all/scala ${project.basedir}/src/main/311until320-nondb/scala - ${project.basedir}/src/main/311until330-all/scala ${project.basedir}/src/main/pre320-treenode/scala diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCastMeta.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCastMeta.scala index 
5f9b226385e..52cfb45f095 100644 --- a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCastMeta.scala +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCastMeta.scala @@ -106,7 +106,7 @@ final class CastExprMeta[INPUT <: Cast]( // NOOP for anything prior to 3.2.0 case (_: StringType, dt:DecimalType) => // Spark 2.x: removed check for - // !ShimLoader.getSparkShims.isCastingStringToNegDecimalScaleSupported + // !SparkShimImpl.isCastingStringToNegDecimalScaleSupported // this dealt with handling a bug fix that is only in newer versions of Spark // (https://issues.apache.org/jira/browse/SPARK-37451) // Since we don't know what version of Spark 3 they will be using diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index 846240619ac..72f9593effa 100644 --- a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -23,7 +23,7 @@ import scala.reflect.ClassTag import scala.util.control.NonFatal import com.nvidia.spark.rapids.RapidsConf.{SUPPRESS_PLANNING_FAILURE, TEST_CONF} -import com.nvidia.spark.rapids.shims.v2._ +import com.nvidia.spark.rapids.shims._ import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, SparkSession} @@ -1397,7 +1397,7 @@ object GpuOverrides extends Logging { TypeSig.STRING)), (a, conf, p, r) => new UnixTimeExprMeta[ToUnixTimestamp](a, conf, p, r) { override def shouldFallbackOnAnsiTimestamp: Boolean = false - // ShimLoader.getSparkShims.shouldFallbackOnAnsiTimestamp + // SparkShimImpl.shouldFallbackOnAnsiTimestamp }), expr[UnixTimestamp]( "Returns the UNIX timestamp of current or specified time", @@ -1410,7 +1410,7 @@ object GpuOverrides extends Logging { TypeSig.STRING)), (a, conf, p, r) => new UnixTimeExprMeta[UnixTimestamp](a, conf, p, r) { override def shouldFallbackOnAnsiTimestamp: Boolean = false - // ShimLoader.getSparkShims.shouldFallbackOnAnsiTimestamp + // SparkShimImpl.shouldFallbackOnAnsiTimestamp }), expr[Hour]( @@ -2865,8 +2865,8 @@ object GpuOverrides extends Logging { TypeSig.ARRAY + TypeSig.DECIMAL_128).nested(), TypeSig.all), (sample, conf, p, r) => new GpuSampleExecMeta(sample, conf, p, r) {} ), - // ShimLoader.getSparkShims.aqeShuffleReaderExec, - // ShimLoader.getSparkShims.neverReplaceShowCurrentNamespaceCommand, + // SparkShimImpl.aqeShuffleReaderExec, + // SparkShimImpl.neverReplaceShowCurrentNamespaceCommand, neverReplaceExec[ExecutedCommandExec]("Table metadata operation") ).collect { case r if r != null => (r.getClassFor.asSubclass(classOf[SparkPlan]), r) }.toMap @@ -2955,7 +2955,7 @@ object GpuOverrides extends Logging { // case c2r: ColumnarToRowExec => prepareExplainOnly(c2r.child) case re: ReusedExchangeExec => prepareExplainOnly(re.child) // case aqe: AdaptiveSparkPlanExec => - // prepareExplainOnly(ShimLoader.getSparkShims.getAdaptiveInputPlan(aqe)) + // prepareExplainOnly(SparkShimImpl.getAdaptiveInputPlan(aqe)) case sub: SubqueryExec => prepareExplainOnly(sub.child) } planAfter diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala index 3126a241cfe..898df817841 100644 --- a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala +++ 
b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala @@ -79,7 +79,7 @@ object GpuParquetFileFormat { // they set when they get to 3.x. The default in 3.x is EXCEPTION which would be good // for us. /* - ShimLoader.getSparkShims.int96ParquetRebaseWrite(sqlConf) match { + SparkShimImpl.int96ParquetRebaseWrite(sqlConf) match { case "EXCEPTION" => case "CORRECTED" => case "LEGACY" => @@ -90,7 +90,7 @@ object GpuParquetFileFormat { meta.willNotWorkOnGpu(s"$other is not a supported rebase mode for int96") } - ShimLoader.getSparkShims.parquetRebaseWrite(sqlConf) match { + SparkShimImpl.parquetRebaseWrite(sqlConf) match { case "EXCEPTION" => //Good case "CORRECTED" => //Good case "LEGACY" => diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScanBase.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScanBase.scala index e0b2fa20dc9..f66efaa4a51 100644 --- a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScanBase.scala +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScanBase.scala @@ -97,31 +97,31 @@ object GpuParquetScanBase { // Spark 2.x doesn't support the rebase mode /* - sqlConf.get(ShimLoader.getSparkShims.int96ParquetRebaseReadKey) match { + sqlConf.get(SparkShimImpl.int96ParquetRebaseReadKey) match { case "EXCEPTION" => if (schemaMightNeedNestedRebase) { meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " + - s"${ShimLoader.getSparkShims.int96ParquetRebaseReadKey} is EXCEPTION") + s"${SparkShimImpl.int96ParquetRebaseReadKey} is EXCEPTION") } case "CORRECTED" => // Good case "LEGACY" => // really is EXCEPTION for us... if (schemaMightNeedNestedRebase) { meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " + - s"${ShimLoader.getSparkShims.int96ParquetRebaseReadKey} is LEGACY") + s"${SparkShimImpl.int96ParquetRebaseReadKey} is LEGACY") } case other => meta.willNotWorkOnGpu(s"$other is not a supported read rebase mode") } - sqlConf.get(ShimLoader.getSparkShims.parquetRebaseReadKey) match { + sqlConf.get(SparkShimImpl.parquetRebaseReadKey) match { case "EXCEPTION" => if (schemaMightNeedNestedRebase) { meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " + - s"${ShimLoader.getSparkShims.parquetRebaseReadKey} is EXCEPTION") + s"${SparkShimImpl.parquetRebaseReadKey} is EXCEPTION") } case "CORRECTED" => // Good case "LEGACY" => // really is EXCEPTION for us... 
if (schemaMightNeedNestedRebase) { meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " + - s"${ShimLoader.getSparkShims.parquetRebaseReadKey} is LEGACY") + s"${SparkShimImpl.parquetRebaseReadKey} is LEGACY") } case other => meta.willNotWorkOnGpu(s"$other is not a supported read rebase mode") diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadCSVFileFormat.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadCSVFileFormat.scala index 22723a9a22d..d14050cde95 100644 --- a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadCSVFileFormat.scala +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadCSVFileFormat.scala @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids -import com.nvidia.spark.rapids.shims.v2.GpuCSVScan +import com.nvidia.spark.rapids.shims.GpuCSVScan import org.apache.spark.sql.execution.FileSourceScanExec diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimGpuOverrides.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimGpuOverrides.scala index 9b8981e143c..71621d3fa76 100644 --- a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimGpuOverrides.scala +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimGpuOverrides.scala @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids -import com.nvidia.spark.rapids.shims.v2._ +import com.nvidia.spark.rapids.shims._ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala index 9fe89c2f9ed..8e44ba50b58 100644 --- a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala @@ -19,7 +19,7 @@ package com.nvidia.spark.rapids import java.io.{File, FileOutputStream} import java.time.ZoneId -import com.nvidia.spark.rapids.shims.v2.TypeSigUtil +import com.nvidia.spark.rapids.shims.TypeSigUtil import org.apache.spark.{SPARK_BUILD_USER, SPARK_VERSION} import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnaryExpression, WindowSpecDefinition} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuBroadcastHashJoinExecMeta.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/GpuBroadcastHashJoinExecMeta.scala similarity index 96% rename from spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuBroadcastHashJoinExecMeta.scala rename to spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/GpuBroadcastHashJoinExecMeta.scala index 8c19f858cff..ecc73fc27cd 100644 --- a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuBroadcastHashJoinExecMeta.scala +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/GpuBroadcastHashJoinExecMeta.scala @@ -14,10 +14,10 @@ * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.shims.v2._ +import com.nvidia.spark.rapids.shims._ import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.rapids.execution.{GpuHashJoin, JoinTypeChecks} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuCSVScan.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/GpuCSVScan.scala similarity index 99% rename from spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuCSVScan.scala rename to spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/GpuCSVScan.scala index 5285974d0a0..02d5bca1656 100644 --- a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuCSVScan.scala +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/GpuCSVScan.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import java.nio.charset.StandardCharsets diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuJoinUtils.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/GpuJoinUtils.scala similarity index 93% rename from spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuJoinUtils.scala rename to spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/GpuJoinUtils.scala index 032b3aa39b8..0223f2f05b9 100644 --- a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuJoinUtils.scala +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/GpuJoinUtils.scala @@ -13,9 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims -import com.nvidia.spark.rapids.shims.v2._ +import com.nvidia.spark.rapids.shims._ import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/GpuRegExpReplaceMeta.scala similarity index 98% rename from spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala rename to spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/GpuRegExpReplaceMeta.scala index 07c1976b9f2..df07b72fa39 100644 --- a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/GpuRegExpReplaceMeta.scala @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids._ diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuShuffledHashJoinExecMeta.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/GpuShuffledHashJoinExecMeta.scala similarity index 97% rename from spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuShuffledHashJoinExecMeta.scala rename to spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/GpuShuffledHashJoinExecMeta.scala index 888e8b4f8fb..d9a99440aad 100644 --- a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuShuffledHashJoinExecMeta.scala +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/GpuShuffledHashJoinExecMeta.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids._ diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuSortMergeJoinMeta.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/GpuSortMergeJoinMeta.scala similarity index 97% rename from spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuSortMergeJoinMeta.scala rename to spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/GpuSortMergeJoinMeta.scala index c19b1213798..427bc6a9c41 100644 --- a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/GpuSortMergeJoinMeta.scala +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/GpuSortMergeJoinMeta.scala @@ -14,10 +14,10 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.shims.v2._ +import com.nvidia.spark.rapids.shims._ import org.apache.spark.sql.execution.SortExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/OffsetWindowFunctionMeta.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/OffsetWindowFunctionMeta.scala similarity index 98% rename from spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/OffsetWindowFunctionMeta.scala rename to spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/OffsetWindowFunctionMeta.scala index a245ebebfdc..0265874d2a2 100644 --- a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/OffsetWindowFunctionMeta.scala +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/OffsetWindowFunctionMeta.scala @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{BaseExprMeta, DataFromReplacementRule, ExprMeta, GpuOverrides, RapidsConf, RapidsMeta} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/TreeNode.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/TreeNode.scala similarity index 97% rename from spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/TreeNode.scala rename to spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/TreeNode.scala index 7378c93aed6..95a13fddad6 100644 --- a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/TreeNode.scala +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/TreeNode.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, TernaryExpression, UnaryExpression} import org.apache.spark.sql.catalyst.plans.logical.Command diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/TypeSigUtil.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/TypeSigUtil.scala similarity index 98% rename from spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/TypeSigUtil.scala rename to spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/TypeSigUtil.scala index 39b7af5f6d9..fda779c5cc7 100644 --- a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/TypeSigUtil.scala +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/TypeSigUtil.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{TypeEnum, TypeSig} diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/gpuPandasMeta.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/gpuPandasMeta.scala similarity index 99% rename from spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/gpuPandasMeta.scala rename to spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/gpuPandasMeta.scala index 22f60110a35..7504a1ff9e4 100644 --- a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/gpuPandasMeta.scala +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/gpuPandasMeta.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids._ diff --git a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/gpuWindows.scala b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/gpuWindows.scala similarity index 99% rename from spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/gpuWindows.scala rename to spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/gpuWindows.scala index 4111942c863..1ca155bf0f2 100644 --- a/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/v2/gpuWindows.scala +++ b/spark2-sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/gpuWindows.scala @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import java.util.concurrent.TimeUnit diff --git a/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinMeta.scala b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinMeta.scala index 8793829485c..ab874a06c03 100644 --- a/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinMeta.scala +++ b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinMeta.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.rapids.execution import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.shims.v2._ +import com.nvidia.spark.rapids.shims._ import org.apache.spark.sql.catalyst.plans.{FullOuter, InnerLike, JoinType, LeftAnti, LeftExistence, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, BuildLeft, BuildRight, BuildSide} diff --git a/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala index a89ed80caf9..c65c5b62460 100644 --- a/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala +++ b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuHashJoin.scala @@ -16,7 +16,7 @@ package org.apache.spark.sql.rapids.execution import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.shims.v2._ +import com.nvidia.spark.rapids.shims._ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.{Cross, ExistenceJoin, FullOuter, Inner, InnerLike, JoinType, LeftAnti, LeftExistence, LeftOuter, LeftSemi, RightOuter} diff --git a/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringMeta.scala b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringMeta.scala index 71ab0d831ef..85814e5cb47 100644 --- a/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringMeta.scala +++ b/spark2-sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringMeta.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.rapids import scala.collection.mutable.ArrayBuffer import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.shims.v2.ShimExpression +import com.nvidia.spark.rapids.shims.ShimExpression import org.apache.spark.sql.catalyst.expressions.{Literal, RegExpExtract, RLike, StringSplit, SubstringIndex} import org.apache.spark.sql.types._ diff --git a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/AQEUtils.scala b/sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/AQEUtils.scala similarity index 92% rename from sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/AQEUtils.scala rename to sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/AQEUtils.scala index 2adb6d96c10..b001bc929c0 100644 --- a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/AQEUtils.scala +++ b/sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/AQEUtils.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.adaptive.{QueryStageExec, ShuffleQueryStageExec} diff --git a/sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuWindowInPandasExec.scala b/sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/GpuWindowInPandasExec.scala similarity index 95% rename from sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuWindowInPandasExec.scala rename to sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/GpuWindowInPandasExec.scala index fe84d685af0..19392fa5335 100644 --- a/sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuWindowInPandasExec.scala +++ b/sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/GpuWindowInPandasExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression, SortOrder} import org.apache.spark.sql.execution.SparkPlan diff --git a/sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/v2/ShimBroadcastExchangeLike.scala b/sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/ShimBroadcastExchangeLike.scala similarity index 93% rename from sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/v2/ShimBroadcastExchangeLike.scala rename to sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/ShimBroadcastExchangeLike.scala index 85eb1bd1221..7a36a7596a6 100644 --- a/sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/v2/ShimBroadcastExchangeLike.scala +++ b/sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/ShimBroadcastExchangeLike.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import scala.concurrent.Promise diff --git a/sql-plugin/src/main/301+-nondb/scala/org/apache/spark/sql/rapids/execution/python/shims/v2/GpuFlatMapGroupsInPandasExec.scala b/sql-plugin/src/main/301+-nondb/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuFlatMapGroupsInPandasExec.scala similarity index 96% rename from sql-plugin/src/main/301+-nondb/scala/org/apache/spark/sql/rapids/execution/python/shims/v2/GpuFlatMapGroupsInPandasExec.scala rename to sql-plugin/src/main/301+-nondb/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuFlatMapGroupsInPandasExec.scala index b54a59a9199..81342601d0b 100644 --- a/sql-plugin/src/main/301+-nondb/scala/org/apache/spark/sql/rapids/execution/python/shims/v2/GpuFlatMapGroupsInPandasExec.scala +++ b/sql-plugin/src/main/301+-nondb/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuFlatMapGroupsInPandasExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,11 +14,11 @@ * limitations under the License. */ -package org.apache.spark.sql.rapids.execution.python.shims.v2 +package org.apache.spark.sql.rapids.execution.python.shims import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.python.PythonWorkerSemaphore -import com.nvidia.spark.rapids.shims.v2.ShimUnaryExecNode +import com.nvidia.spark.rapids.shims.{ShimUnaryExecNode, SparkShimImpl} import org.apache.spark.TaskContext import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} @@ -96,7 +96,7 @@ case class GpuFlatMapGroupsInPandasExec( } override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(groupingAttributes.map(ShimLoader.getSparkShims.sortOrder(_, Ascending))) + Seq(groupingAttributes.map(SparkShimImpl.sortOrder(_, Ascending))) private val pandasFunction = func.asInstanceOf[GpuPythonUDF].func diff --git a/sql-plugin/src/main/301/scala/com/nvidia/spark/rapids/shims/spark301/RapidsShuffleInternalManager.scala b/sql-plugin/src/main/301/scala/com/nvidia/spark/rapids/shims/spark301/RapidsShuffleInternalManager.scala index 5747852f1ed..47808ea22d0 100644 --- a/sql-plugin/src/main/301/scala/com/nvidia/spark/rapids/shims/spark301/RapidsShuffleInternalManager.scala +++ b/sql-plugin/src/main/301/scala/com/nvidia/spark/rapids/shims/spark301/RapidsShuffleInternalManager.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/sql-plugin/src/main/301/scala/com/nvidia/spark/rapids/shims/spark301/SparkShimServiceProvider.scala b/sql-plugin/src/main/301/scala/com/nvidia/spark/rapids/shims/spark301/SparkShimServiceProvider.scala index 24668123c21..3d861c0c656 100644 --- a/sql-plugin/src/main/301/scala/com/nvidia/spark/rapids/shims/spark301/SparkShimServiceProvider.scala +++ b/sql-plugin/src/main/301/scala/com/nvidia/spark/rapids/shims/spark301/SparkShimServiceProvider.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -16,7 +16,7 @@ package com.nvidia.spark.rapids.shims.spark301 -import com.nvidia.spark.rapids.{SparkShims, SparkShimVersion} +import com.nvidia.spark.rapids.SparkShimVersion object SparkShimServiceProvider { val VERSION = SparkShimVersion(3, 0, 1) @@ -24,11 +24,9 @@ object SparkShimServiceProvider { } class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider { + override def getShimVersion: SparkShimVersion = SparkShimServiceProvider.VERSION + def matchesVersion(version: String): Boolean = { SparkShimServiceProvider.VERSIONNAMES.contains(version) } - - def buildShim: SparkShims = { - new Spark301Shims() - } } diff --git a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/spark301db/Spark301dbShims.scala b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/SparkShims.scala similarity index 67% rename from sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/spark301db/Spark301dbShims.scala rename to sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/SparkShims.scala index f77e2447ee2..36b07462d3e 100644 --- a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/spark301db/Spark301dbShims.scala +++ b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/SparkShims.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,12 +14,11 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.spark301db +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.shims.v2._ -class Spark301dbShims extends Spark30XdbShims with Spark30Xuntil33XShims { +object SparkShimImpl extends Spark30XdbShims with Spark30Xuntil33XShims { - override def getSparkShimVersion: ShimVersion = SparkShimServiceProvider.VERSION + override def getSparkShimVersion: ShimVersion = ShimLoader.getShimVersion } diff --git a/sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/v2/AQEUtils.scala b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/AQEUtils.scala similarity index 92% rename from sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/v2/AQEUtils.scala rename to sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/AQEUtils.scala index 2adb6d96c10..b001bc929c0 100644 --- a/sql-plugin/src/main/301+-nondb/scala/com/nvidia/spark/rapids/shims/v2/AQEUtils.scala +++ b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/AQEUtils.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.adaptive.{QueryStageExec, ShuffleQueryStageExec} diff --git a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuJoinUtils.scala b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/GpuJoinUtils.scala similarity index 96% rename from sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuJoinUtils.scala rename to sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/GpuJoinUtils.scala index c41fc5b578a..3bdb81676bf 100644 --- a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuJoinUtils.scala +++ b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/GpuJoinUtils.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{GpuBuildLeft, GpuBuildRight, GpuBuildSide} diff --git a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/GpuRegExpReplaceExec.scala similarity index 98% rename from sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala rename to sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/GpuRegExpReplaceExec.scala index ce13571e910..9c30127dc27 100644 --- a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala +++ b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/GpuRegExpReplaceExec.scala @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, RapidsConf, RapidsMeta, RegexReplaceMode, RegexUnsupportedException, TernaryExprMeta} diff --git a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRunningWindowExec.scala b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/GpuRunningWindowExec.scala similarity index 94% rename from sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRunningWindowExec.scala rename to sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/GpuRunningWindowExec.scala index ff8cedf6aeb..f12c8208ec2 100644 --- a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRunningWindowExec.scala +++ b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/GpuRunningWindowExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.databricks.sql.execution.window.RunningWindowFunctionExec import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuBaseWindowExecMeta, RapidsConf, RapidsMeta} diff --git a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuWindowInPandasExec.scala b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/GpuWindowInPandasExec.scala similarity index 97% rename from sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuWindowInPandasExec.scala rename to sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/GpuWindowInPandasExec.scala index 4c8bde846b7..e3085905c9c 100644 --- a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuWindowInPandasExec.scala +++ b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/GpuWindowInPandasExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{GpuBindReferences, GpuBoundReference, GpuProjectExec, GpuWindowExpression} diff --git a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/ShimBroadcastExchangeLike.scala b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/ShimBroadcastExchangeLike.scala similarity index 93% rename from sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/ShimBroadcastExchangeLike.scala rename to sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/ShimBroadcastExchangeLike.scala index 85eb1bd1221..7a36a7596a6 100644 --- a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/ShimBroadcastExchangeLike.scala +++ b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/ShimBroadcastExchangeLike.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import scala.concurrent.Promise diff --git a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/Spark30XdbShims.scala b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/Spark30XdbShims.scala similarity index 99% rename from sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/Spark30XdbShims.scala rename to sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/Spark30XdbShims.scala index abab4be1e81..cd3ed4b13fa 100644 --- a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/Spark30XdbShims.scala +++ b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/Spark30XdbShims.scala @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import java.net.URI import java.nio.ByteBuffer @@ -30,7 +30,7 @@ import org.apache.parquet.schema.MessageType import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging -import org.apache.spark.rapids.shims.v2.GpuShuffleExchangeExec +import org.apache.spark.rapids.shims.GpuShuffleExchangeExec import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} @@ -59,8 +59,8 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.rapids.{GpuAbs, GpuAverage, GpuFileSourceScanExec, GpuTimeSub} import org.apache.spark.sql.rapids.execution.{GpuShuffleExchangeExecBase, SerializeBatchDeserializeHostBuffer, SerializeConcatHostBuffersDeserializeBatch, TrampolineUtil} import org.apache.spark.sql.rapids.execution.python._ -import org.apache.spark.sql.rapids.execution.python.shims.v2._ -import org.apache.spark.sql.rapids.shims.v2.{GpuFileScanRDD, GpuSchemaUtils} +import org.apache.spark.sql.rapids.execution.python.shims._ +import org.apache.spark.sql.rapids.shims.{GpuFileScanRDD, GpuSchemaUtils} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ import org.apache.spark.storage.{BlockId, BlockManagerId} diff --git a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/Spark30XdbShimsBase.scala b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/Spark30XdbShimsBase.scala similarity index 99% rename from sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/Spark30XdbShimsBase.scala rename to sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/Spark30XdbShimsBase.scala index eacd204f980..7fd763b5268 100644 --- a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/Spark30XdbShimsBase.scala +++ b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/Spark30XdbShimsBase.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import scala.collection.mutable.ListBuffer diff --git a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/spark301db/SparkShimServiceProvider.scala b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/spark301db/SparkShimServiceProvider.scala index be7508ffcf2..de07110a8b7 100644 --- a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/spark301db/SparkShimServiceProvider.scala +++ b/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/spark301db/SparkShimServiceProvider.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -16,7 +16,7 @@ package com.nvidia.spark.rapids.shims.spark301db -import com.nvidia.spark.rapids.{DatabricksShimVersion, SparkShims} +import com.nvidia.spark.rapids.{DatabricksShimVersion, ShimVersion} object SparkShimServiceProvider { val VERSION = DatabricksShimVersion(3, 0, 1) @@ -25,11 +25,9 @@ object SparkShimServiceProvider { class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider { + override def getShimVersion: ShimVersion = SparkShimServiceProvider.VERSION + def matchesVersion(version: String): Boolean = { SparkShimServiceProvider.VERSIONNAMES.contains(version) } - - def buildShim: SparkShims = { - new Spark301dbShims() - } } diff --git a/sql-plugin/src/main/301db/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala b/sql-plugin/src/main/301db/scala/org/apache/spark/rapids/shims/GpuShuffleExchangeExec.scala similarity index 96% rename from sql-plugin/src/main/301db/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala rename to sql-plugin/src/main/301db/scala/org/apache/spark/rapids/shims/GpuShuffleExchangeExec.scala index b3ea5fbcafd..2f3fb8c0e4f 100644 --- a/sql-plugin/src/main/301db/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala +++ b/sql-plugin/src/main/301db/scala/org/apache/spark/rapids/shims/GpuShuffleExchangeExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.rapids.shims.v2 +package org.apache.spark.rapids.shims import com.nvidia.spark.rapids.GpuPartitioning diff --git a/sql-plugin/src/main/301db/scala/org/apache/spark/rapids/shims/v2/ShuffledBatchRDDUtil.scala b/sql-plugin/src/main/301db/scala/org/apache/spark/rapids/shims/ShuffledBatchRDDUtil.scala similarity index 96% rename from sql-plugin/src/main/301db/scala/org/apache/spark/rapids/shims/v2/ShuffledBatchRDDUtil.scala rename to sql-plugin/src/main/301db/scala/org/apache/spark/rapids/shims/ShuffledBatchRDDUtil.scala index 6e63fe5b97b..9e4863f6e9d 100644 --- a/sql-plugin/src/main/301db/scala/org/apache/spark/rapids/shims/v2/ShuffledBatchRDDUtil.scala +++ b/sql-plugin/src/main/301db/scala/org/apache/spark/rapids/shims/ShuffledBatchRDDUtil.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,9 +14,9 @@ * limitations under the License. 
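The `SparkShimServiceProvider` hunk above shows the new provider contract used throughout this change: providers no longer construct a `SparkShims` instance via `buildShim`; they only report the `ShimVersion` they match, while the shim logic itself moves into a statically compiled per-version `SparkShimImpl` object. A minimal sketch of the resulting provider shape, where `MySparkShimServiceProvider` and version 9.9.9 are hypothetical names mirroring the providers in this diff:
```
// A minimal sketch, not a real provider: the class name and the 9.9.9
// version are hypothetical, mirroring the spark30x providers in this diff.
import com.nvidia.spark.rapids.{ShimVersion, SparkShimVersion}

object MySparkShimServiceProvider {
  val VERSION = SparkShimVersion(9, 9, 9)
  val VERSIONNAMES = Seq(VERSION.toString)
}

class MySparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider {
  // Report the version instead of building a shim instance.
  override def getShimVersion: ShimVersion = MySparkShimServiceProvider.VERSION

  def matchesVersion(version: String): Boolean =
    MySparkShimServiceProvider.VERSIONNAMES.contains(version)
}
```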
*/ -package org.apache.spark.rapids.shims.v2 +package org.apache.spark.rapids.shims -import com.nvidia.spark.rapids.ShimLoader +import com.nvidia.spark.rapids.shims.SparkShimImpl import org.apache.spark.{MapOutputTrackerMaster, Partition, ShuffleDependency, SparkEnv, TaskContext} import org.apache.spark.shuffle.ShuffleReader @@ -57,7 +57,7 @@ object ShuffledBatchRDDUtil { dependency: ShuffleDependency[Int, ColumnarBatch, ColumnarBatch], sqlMetricsReporter: SQLShuffleReadMetricsReporter): (ShuffleReader[Nothing, Nothing], Long) = { - val shim = ShimLoader.getSparkShims + val shim = SparkShimImpl split.asInstanceOf[ShuffledBatchRDDPartition].spec match { case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) => val reader = SparkEnv.get.shuffleManager.getReader( diff --git a/sql-plugin/src/main/301db/scala/org/apache/spark/rapids/shims/v2/api/python/ShimBasePythonRunner.scala b/sql-plugin/src/main/301db/scala/org/apache/spark/rapids/shims/api/python/ShimBasePythonRunner.scala similarity index 94% rename from sql-plugin/src/main/301db/scala/org/apache/spark/rapids/shims/v2/api/python/ShimBasePythonRunner.scala rename to sql-plugin/src/main/301db/scala/org/apache/spark/rapids/shims/api/python/ShimBasePythonRunner.scala index 958e1f681f0..9a613fbb676 100644 --- a/sql-plugin/src/main/301db/scala/org/apache/spark/rapids/shims/v2/api/python/ShimBasePythonRunner.scala +++ b/sql-plugin/src/main/301db/scala/org/apache/spark/rapids/shims/api/python/ShimBasePythonRunner.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.apache.spark.rapids.shims.v2.api.python +package org.apache.spark.rapids.shims.api.python import java.io.DataInputStream import java.net.Socket diff --git a/sql-plugin/src/main/301db/scala/org/apache/spark/sql/rapids/execution/python/shims/v2/GpuFlatMapGroupsInPandasExec.scala b/sql-plugin/src/main/301db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuFlatMapGroupsInPandasExec.scala similarity index 96% rename from sql-plugin/src/main/301db/scala/org/apache/spark/sql/rapids/execution/python/shims/v2/GpuFlatMapGroupsInPandasExec.scala rename to sql-plugin/src/main/301db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuFlatMapGroupsInPandasExec.scala index b54a59a9199..81342601d0b 100644 --- a/sql-plugin/src/main/301db/scala/org/apache/spark/sql/rapids/execution/python/shims/v2/GpuFlatMapGroupsInPandasExec.scala +++ b/sql-plugin/src/main/301db/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuFlatMapGroupsInPandasExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,11 +14,11 @@ * limitations under the License. 
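The `ShuffledBatchRDDUtil` change above is the recurring call-site pattern in this diff: runtime lookups through `ShimLoader.getSparkShims` become direct references to the statically linked `SparkShimImpl` object. Roughly, using the `sortOrder` shim method that appears later in this diff, and assuming an in-scope expression `expr`:
```
// Sketch of the call-site migration; expr is an assumed in-scope
// org.apache.spark.sql.catalyst.expressions.Expression.
import org.apache.spark.sql.catalyst.expressions.Ascending

// Before: the shim was resolved reflectively at runtime.
val shim = ShimLoader.getSparkShims
val order = shim.sortOrder(expr, Ascending)

// After: the shim object is bound at compile time per Spark version.
import com.nvidia.spark.rapids.shims.SparkShimImpl
val order2 = SparkShimImpl.sortOrder(expr, Ascending)
```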
*/ -package org.apache.spark.sql.rapids.execution.python.shims.v2 +package org.apache.spark.sql.rapids.execution.python.shims import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.python.PythonWorkerSemaphore -import com.nvidia.spark.rapids.shims.v2.ShimUnaryExecNode +import com.nvidia.spark.rapids.shims.{ShimUnaryExecNode, SparkShimImpl} import org.apache.spark.TaskContext import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} @@ -96,7 +96,7 @@ case class GpuFlatMapGroupsInPandasExec( } override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(groupingAttributes.map(ShimLoader.getSparkShims.sortOrder(_, Ascending))) + Seq(groupingAttributes.map(SparkShimImpl.sortOrder(_, Ascending))) private val pandasFunction = func.asInstanceOf[GpuPythonUDF].func diff --git a/sql-plugin/src/main/31xdb/scala/org/apache/spark/sql/rapids/shims/v2/GpuFileScanRDD.scala b/sql-plugin/src/main/301db/scala/org/apache/spark/sql/rapids/shims/GpuFileScanRDD.scala similarity index 98% rename from sql-plugin/src/main/31xdb/scala/org/apache/spark/sql/rapids/shims/v2/GpuFileScanRDD.scala rename to sql-plugin/src/main/301db/scala/org/apache/spark/sql/rapids/shims/GpuFileScanRDD.scala index 73f16b21136..2cbf39dc1e1 100644 --- a/sql-plugin/src/main/31xdb/scala/org/apache/spark/sql/rapids/shims/v2/GpuFileScanRDD.scala +++ b/sql-plugin/src/main/301db/scala/org/apache/spark/sql/rapids/shims/GpuFileScanRDD.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.rapids.shims.v2 +package org.apache.spark.sql.rapids.shims import java.io.{FileNotFoundException, IOException} diff --git a/sql-plugin/src/main/301/scala/com/nvidia/spark/rapids/shims/spark301/Spark301Shims.scala b/sql-plugin/src/main/301until304/scala/com/nvidia/spark/rapids/shims/SparkShims.scala similarity index 81% rename from sql-plugin/src/main/301/scala/com/nvidia/spark/rapids/shims/spark301/Spark301Shims.scala rename to sql-plugin/src/main/301until304/scala/com/nvidia/spark/rapids/shims/SparkShims.scala index 1e71d426f7c..5613f378702 100644 --- a/sql-plugin/src/main/301/scala/com/nvidia/spark/rapids/shims/spark301/Spark301Shims.scala +++ b/sql-plugin/src/main/301until304/scala/com/nvidia/spark/rapids/shims/SparkShims.scala @@ -14,17 +14,16 @@ * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.spark301 +package com.nvidia.spark.rapids.shims -import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.shims.v2._ +import com.nvidia.spark.rapids.{ShimLoader, ShimVersion} import org.apache.parquet.schema.MessageType import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters -class Spark301Shims extends Spark30XShims with Spark30Xuntil33XShims { +object SparkShimImpl extends Spark30XShims with Spark30Xuntil33XShims { - override def getSparkShimVersion: ShimVersion = SparkShimServiceProvider.VERSION + override def getSparkShimVersion: ShimVersion = ShimLoader.getShimVersion override def getParquetFilters( schema: MessageType, diff --git a/sql-plugin/src/main/301until310-all/scala/com/nvidia/spark/rapids/shims/v2/GpuOrcScan.scala b/sql-plugin/src/main/301until310-all/scala/com/nvidia/spark/rapids/shims/GpuOrcScan.scala similarity index 96% rename from sql-plugin/src/main/301until310-all/scala/com/nvidia/spark/rapids/shims/v2/GpuOrcScan.scala rename to sql-plugin/src/main/301until310-all/scala/com/nvidia/spark/rapids/shims/GpuOrcScan.scala index a6e27c66a9e..25917a3e317 100644 --- a/sql-plugin/src/main/301until310-all/scala/com/nvidia/spark/rapids/shims/v2/GpuOrcScan.scala +++ b/sql-plugin/src/main/301until310-all/scala/com/nvidia/spark/rapids/shims/GpuOrcScan.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{GpuOrcScanBase, RapidsConf} import org.apache.hadoop.conf.Configuration diff --git a/sql-plugin/src/main/301until310-all/scala/com/nvidia/spark/rapids/shims/v2/GpuParquetScan.scala b/sql-plugin/src/main/301until310-all/scala/com/nvidia/spark/rapids/shims/GpuParquetScan.scala similarity index 96% rename from sql-plugin/src/main/301until310-all/scala/com/nvidia/spark/rapids/shims/v2/GpuParquetScan.scala rename to sql-plugin/src/main/301until310-all/scala/com/nvidia/spark/rapids/shims/GpuParquetScan.scala index 2841bfd217c..22ff476a7e3 100644 --- a/sql-plugin/src/main/301until310-all/scala/com/nvidia/spark/rapids/shims/v2/GpuParquetScan.scala +++ b/sql-plugin/src/main/301until310-all/scala/com/nvidia/spark/rapids/shims/GpuParquetScan.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. 
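With the rename in the hunk above, every version-range source tree (here `301until304`) defines the same fully qualified object, `com.nvidia.spark.rapids.shims.SparkShimImpl`, and the build wires in exactly one of them per Spark version. Shared code can therefore call version-dependent APIs directly. For example, the `getParquetFilters` method shown above could be invoked like this, a sketch that assumes `schema`, `lookupFileMeta`, and `dateTimeRebaseModeFromConf` are in scope and uses illustrative flag values:
```
import com.nvidia.spark.rapids.shims.SparkShimImpl

// Resolves to whichever SparkShimImpl the active source tree supplies,
// e.g. the 301until304 variant above. schema, lookupFileMeta, and
// dateTimeRebaseModeFromConf are assumed in scope; values are illustrative.
val parquetFilters = SparkShimImpl.getParquetFilters(
  schema = schema,
  pushDownDate = true,
  pushDownTimestamp = true,
  pushDownDecimal = true,
  pushDownStartWith = true,
  pushDownInFilterThreshold = 10,
  caseSensitive = false,
  lookupFileMeta = lookupFileMeta,
  dateTimeRebaseModeFromConf = dateTimeRebaseModeFromConf)
```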
*/ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{GpuParquetScanBase, RapidsConf} import org.apache.hadoop.conf.Configuration diff --git a/sql-plugin/src/main/301until310-all/scala/com/nvidia/spark/rapids/shims/v2/GpuRowBasedScalaUDF.scala b/sql-plugin/src/main/301until310-all/scala/com/nvidia/spark/rapids/shims/GpuRowBasedScalaUDF.scala similarity index 97% rename from sql-plugin/src/main/301until310-all/scala/com/nvidia/spark/rapids/shims/v2/GpuRowBasedScalaUDF.scala rename to sql-plugin/src/main/301until310-all/scala/com/nvidia/spark/rapids/shims/GpuRowBasedScalaUDF.scala index 79089fac024..2578019916f 100644 --- a/sql-plugin/src/main/301until310-all/scala/com/nvidia/spark/rapids/shims/v2/GpuRowBasedScalaUDF.scala +++ b/sql-plugin/src/main/301until310-all/scala/com/nvidia/spark/rapids/shims/GpuRowBasedScalaUDF.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{ExprChecks, ExprRule, GpuOverrides, GpuUserDefinedFunction, RepeatingParamCheck, TypeSig} diff --git a/sql-plugin/src/main/301until310-all/scala/com/nvidia/spark/rapids/shims/v2/OffsetWindowFunctionMeta.scala b/sql-plugin/src/main/301until310-all/scala/com/nvidia/spark/rapids/shims/OffsetWindowFunctionMeta.scala similarity index 96% rename from sql-plugin/src/main/301until310-all/scala/com/nvidia/spark/rapids/shims/v2/OffsetWindowFunctionMeta.scala rename to sql-plugin/src/main/301until310-all/scala/com/nvidia/spark/rapids/shims/OffsetWindowFunctionMeta.scala index 5200c96f65f..db75dabe32e 100644 --- a/sql-plugin/src/main/301until310-all/scala/com/nvidia/spark/rapids/shims/v2/OffsetWindowFunctionMeta.scala +++ b/sql-plugin/src/main/301until310-all/scala/com/nvidia/spark/rapids/shims/OffsetWindowFunctionMeta.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{BaseExprMeta, DataFromReplacementRule, ExprMeta, GpuOverrides, RapidsConf, RapidsMeta} diff --git a/sql-plugin/src/main/301until310-all/scala/org/apache/spark/sql/rapids/shims/v2/GpuSchemaUtils.scala b/sql-plugin/src/main/301until310-all/scala/org/apache/spark/sql/rapids/shims/GpuSchemaUtils.scala similarity index 91% rename from sql-plugin/src/main/301until310-all/scala/org/apache/spark/sql/rapids/shims/v2/GpuSchemaUtils.scala rename to sql-plugin/src/main/301until310-all/scala/org/apache/spark/sql/rapids/shims/GpuSchemaUtils.scala index a665799c488..48ba95be0c3 100644 --- a/sql-plugin/src/main/301until310-all/scala/org/apache/spark/sql/rapids/shims/v2/GpuSchemaUtils.scala +++ b/sql-plugin/src/main/301until310-all/scala/org/apache/spark/sql/rapids/shims/GpuSchemaUtils.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.apache.spark.sql.rapids.shims.v2 +package org.apache.spark.sql.rapids.shims import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.types.StructType diff --git a/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuJoinUtils.scala b/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/GpuJoinUtils.scala similarity index 96% rename from sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuJoinUtils.scala rename to sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/GpuJoinUtils.scala index 75e91d8937b..228b7de7872 100644 --- a/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuJoinUtils.scala +++ b/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/GpuJoinUtils.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{GpuBuildLeft, GpuBuildRight, GpuBuildSide} diff --git a/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala b/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/GpuRegExpReplaceMeta.scala similarity index 98% rename from sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala rename to sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/GpuRegExpReplaceMeta.scala index e3ea7beae21..630216597d8 100644 --- a/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceMeta.scala +++ b/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/GpuRegExpReplaceMeta.scala @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids._ diff --git a/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala b/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/Spark30XShims.scala similarity index 99% rename from sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala rename to sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/Spark30XShims.scala index b8537f70c3b..f9da8dc0887 100644 --- a/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala +++ b/sql-plugin/src/main/301until310-nondb/scala/com/nvidia/spark/rapids/shims/Spark30XShims.scala @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import java.nio.ByteBuffer @@ -24,7 +24,7 @@ import org.apache.arrow.vector.ValueVector import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging -import org.apache.spark.rapids.shims.v2.GpuShuffleExchangeExec +import org.apache.spark.rapids.shims.GpuShuffleExchangeExec import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.errors.attachTree @@ -41,7 +41,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.rapids.{GpuAbs, GpuAverage, GpuFileSourceScanExec, GpuTimeSub} import org.apache.spark.sql.rapids.execution.GpuShuffleExchangeExecBase import org.apache.spark.sql.rapids.execution.python._ -import org.apache.spark.sql.rapids.execution.python.shims.v2._ +import org.apache.spark.sql.rapids.execution.python.shims._ import org.apache.spark.sql.types._ import org.apache.spark.storage.{BlockId, BlockManagerId} import org.apache.spark.unsafe.types.CalendarInterval diff --git a/sql-plugin/src/main/301until310-nondb/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala b/sql-plugin/src/main/301until310-nondb/scala/org/apache/spark/rapids/shims/GpuShuffleExchangeExec.scala similarity index 96% rename from sql-plugin/src/main/301until310-nondb/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala rename to sql-plugin/src/main/301until310-nondb/scala/org/apache/spark/rapids/shims/GpuShuffleExchangeExec.scala index 60460525fa6..03203b3ecff 100644 --- a/sql-plugin/src/main/301until310-nondb/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala +++ b/sql-plugin/src/main/301until310-nondb/scala/org/apache/spark/rapids/shims/GpuShuffleExchangeExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.rapids.shims.v2 +package org.apache.spark.rapids.shims import com.nvidia.spark.rapids.GpuPartitioning diff --git a/sql-plugin/src/main/301until310-nondb/scala/org/apache/spark/rapids/shims/v2/ShuffledBatchRDDUtil.scala b/sql-plugin/src/main/301until310-nondb/scala/org/apache/spark/rapids/shims/ShuffledBatchRDDUtil.scala similarity index 96% rename from sql-plugin/src/main/301until310-nondb/scala/org/apache/spark/rapids/shims/v2/ShuffledBatchRDDUtil.scala rename to sql-plugin/src/main/301until310-nondb/scala/org/apache/spark/rapids/shims/ShuffledBatchRDDUtil.scala index 6a4df2340f9..3ae5a44b6dc 100644 --- a/sql-plugin/src/main/301until310-nondb/scala/org/apache/spark/rapids/shims/v2/ShuffledBatchRDDUtil.scala +++ b/sql-plugin/src/main/301until310-nondb/scala/org/apache/spark/rapids/shims/ShuffledBatchRDDUtil.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,9 +14,9 @@ * limitations under the License. 
*/ -package org.apache.spark.rapids.shims.v2 +package org.apache.spark.rapids.shims -import com.nvidia.spark.rapids.ShimLoader +import com.nvidia.spark.rapids.shims.SparkShimImpl import org.apache.spark.{MapOutputTrackerMaster, Partition, ShuffleDependency, SparkEnv, TaskContext} import org.apache.spark.shuffle.ShuffleReader @@ -57,7 +57,7 @@ object ShuffledBatchRDDUtil { dependency: ShuffleDependency[Int, ColumnarBatch, ColumnarBatch], sqlMetricsReporter: SQLShuffleReadMetricsReporter): (ShuffleReader[Nothing, Nothing], Long) = { - val shim = ShimLoader.getSparkShims + val shim = SparkShimImpl split.asInstanceOf[ShuffledBatchRDDPartition].spec match { case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) => val reader = SparkEnv.get.shuffleManager.getReader( diff --git a/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/AvoidAdaptiveTransitionToRow.scala b/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/AvoidAdaptiveTransitionToRow.scala similarity index 97% rename from sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/AvoidAdaptiveTransitionToRow.scala rename to sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/AvoidAdaptiveTransitionToRow.scala index 626f027ad4a..82b95ef4edc 100644 --- a/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/AvoidAdaptiveTransitionToRow.scala +++ b/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/AvoidAdaptiveTransitionToRow.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import java.lang.reflect.Method diff --git a/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/HashUtils.scala b/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/HashUtils.scala similarity index 91% rename from sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/HashUtils.scala rename to sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/HashUtils.scala index f2593bfa97f..9fa25da705c 100644 --- a/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/HashUtils.scala +++ b/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/HashUtils.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import ai.rapids.cudf import com.nvidia.spark.rapids.Arm diff --git a/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/OrcShims301until320Base.scala b/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/OrcShims301until320Base.scala similarity index 98% rename from sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/OrcShims301until320Base.scala rename to sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/OrcShims301until320Base.scala index b5bb20f6f4a..1c6decd35d5 100644 --- a/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/OrcShims301until320Base.scala +++ b/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/OrcShims301until320Base.scala @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import scala.collection.mutable.ArrayBuffer diff --git a/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/RapidsOrcScanMeta.scala b/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/RapidsOrcScanMeta.scala similarity index 97% rename from sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/RapidsOrcScanMeta.scala rename to sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/RapidsOrcScanMeta.scala index 23a823177d4..fe52d0dc48f 100644 --- a/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/RapidsOrcScanMeta.scala +++ b/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/RapidsOrcScanMeta.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuOrcScanBase, RapidsConf, RapidsMeta, ScanMeta} diff --git a/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/RapidsParquetScanMeta.scala b/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/RapidsParquetScanMeta.scala similarity index 97% rename from sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/RapidsParquetScanMeta.scala rename to sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/RapidsParquetScanMeta.scala index 5426f123614..2788d7d0de0 100644 --- a/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/RapidsParquetScanMeta.scala +++ b/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/RapidsParquetScanMeta.scala @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuParquetScanBase, RapidsConf, RapidsMeta, ScanMeta} diff --git a/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/ShimAQEShuffleReadExec.scala b/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/ShimAQEShuffleReadExec.scala similarity index 96% rename from sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/ShimAQEShuffleReadExec.scala rename to sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/ShimAQEShuffleReadExec.scala index 276f6eb8e08..fad7eb8c395 100644 --- a/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/ShimAQEShuffleReadExec.scala +++ b/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/ShimAQEShuffleReadExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids._ diff --git a/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/ShimDataSourceRDD.scala b/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/ShimDataSourceRDD.scala similarity index 92% rename from sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/ShimDataSourceRDD.scala rename to sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/ShimDataSourceRDD.scala index 407a9a6c47d..b2abae3693a 100644 --- a/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/ShimDataSourceRDD.scala +++ b/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/ShimDataSourceRDD.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import org.apache.spark.SparkContext import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory} diff --git a/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/TypeSigUtil.scala b/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/TypeSigUtil.scala similarity index 96% rename from sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/TypeSigUtil.scala rename to sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/TypeSigUtil.scala index baa4d60b756..1b4c6d20166 100644 --- a/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/TypeSigUtil.scala +++ b/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/TypeSigUtil.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{TypeEnum, TypeSig, TypeSigUtilBase} diff --git a/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/YearParseUtil.scala b/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/YearParseUtil.scala similarity index 90% rename from sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/YearParseUtil.scala rename to sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/YearParseUtil.scala index 50f746f4969..8eb908ee140 100644 --- a/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/YearParseUtil.scala +++ b/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/YearParseUtil.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{RapidsConf, RapidsMeta} diff --git a/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/gpuWindows.scala b/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/gpuWindows.scala similarity index 96% rename from sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/gpuWindows.scala rename to sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/gpuWindows.scala index 8a7e42273d0..4f75844f0e2 100644 --- a/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/v2/gpuWindows.scala +++ b/sql-plugin/src/main/301until320-all/scala/com/nvidia/spark/rapids/shims/gpuWindows.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuSpecifiedWindowFrameMetaBase, GpuWindowExpressionMetaBase, ParsedBoundary, RapidsConf, RapidsMeta} diff --git a/sql-plugin/src/main/301until320-all/scala/org/apache/spark/rapids/shims/v2/GpuShuffleBlockResolver.scala b/sql-plugin/src/main/301until320-all/scala/org/apache/spark/rapids/shims/GpuShuffleBlockResolver.scala similarity index 90% rename from sql-plugin/src/main/301until320-all/scala/org/apache/spark/rapids/shims/v2/GpuShuffleBlockResolver.scala rename to sql-plugin/src/main/301until320-all/scala/org/apache/spark/rapids/shims/GpuShuffleBlockResolver.scala index 48671f5171b..26c0c762506 100644 --- a/sql-plugin/src/main/301until320-all/scala/org/apache/spark/rapids/shims/v2/GpuShuffleBlockResolver.scala +++ b/sql-plugin/src/main/301until320-all/scala/org/apache/spark/rapids/shims/GpuShuffleBlockResolver.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.rapids.shims.v2 +package org.apache.spark.sql.rapids.shims import com.nvidia.spark.rapids.ShuffleBufferCatalog diff --git a/sql-plugin/src/main/301until320-all/scala/org/apache/spark/rapids/shims/v2/storage/ShimDiskBlockManager.scala b/sql-plugin/src/main/301until320-all/scala/org/apache/spark/rapids/shims/storage/ShimDiskBlockManager.scala similarity index 89% rename from sql-plugin/src/main/301until320-all/scala/org/apache/spark/rapids/shims/v2/storage/ShimDiskBlockManager.scala rename to sql-plugin/src/main/301until320-all/scala/org/apache/spark/rapids/shims/storage/ShimDiskBlockManager.scala index 65755932ee6..ef06e295fdd 100644 --- a/sql-plugin/src/main/301until320-all/scala/org/apache/spark/rapids/shims/v2/storage/ShimDiskBlockManager.scala +++ b/sql-plugin/src/main/301until320-all/scala/org/apache/spark/rapids/shims/storage/ShimDiskBlockManager.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.apache.spark.rapids.shims.v2.storage +package org.apache.spark.rapids.shims.storage import org.apache.spark.SparkConf import org.apache.spark.storage.DiskBlockManager diff --git a/sql-plugin/src/main/301until320-all/scala/org/apache/spark/sql/rapids/shims/v2/datetimeExpressions.scala b/sql-plugin/src/main/301until320-all/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala similarity index 92% rename from sql-plugin/src/main/301until320-all/scala/org/apache/spark/sql/rapids/shims/v2/datetimeExpressions.scala rename to sql-plugin/src/main/301until320-all/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala index 37f655ad54b..56047d3cdf9 100644 --- a/sql-plugin/src/main/301until320-all/scala/org/apache/spark/sql/rapids/shims/v2/datetimeExpressions.scala +++ b/sql-plugin/src/main/301until320-all/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.apache.spark.sql.rapids.shims.v2 +package org.apache.spark.sql.rapids.shims import ai.rapids.cudf.{ColumnVector, ColumnView, Scalar} diff --git a/sql-plugin/src/main/301until320-noncdh/scala/com/nvidia/spark/rapids/shims/v2/OrcShims.scala b/sql-plugin/src/main/301until320-noncdh/scala/com/nvidia/spark/rapids/shims/OrcShims.scala similarity index 96% rename from sql-plugin/src/main/301until320-noncdh/scala/com/nvidia/spark/rapids/shims/v2/OrcShims.scala rename to sql-plugin/src/main/301until320-noncdh/scala/com/nvidia/spark/rapids/shims/OrcShims.scala index dcac01eefe9..15d99c0464e 100644 --- a/sql-plugin/src/main/301until320-noncdh/scala/com/nvidia/spark/rapids/shims/v2/OrcShims.scala +++ b/sql-plugin/src/main/301until320-noncdh/scala/com/nvidia/spark/rapids/shims/OrcShims.scala @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.RapidsPluginImplicits._ import org.apache.orc.Reader diff --git a/sql-plugin/src/main/301until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark301until320Shims.scala b/sql-plugin/src/main/301until320-nondb/scala/com/nvidia/spark/rapids/shims/Spark301until320Shims.scala similarity index 99% rename from sql-plugin/src/main/301until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark301until320Shims.scala rename to sql-plugin/src/main/301until320-nondb/scala/com/nvidia/spark/rapids/shims/Spark301until320Shims.scala index d26b816bece..3e77552e5f5 100644 --- a/sql-plugin/src/main/301until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark301until320Shims.scala +++ b/sql-plugin/src/main/301until320-nondb/scala/com/nvidia/spark/rapids/shims/Spark301until320Shims.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import java.net.URI import java.nio.ByteBuffer @@ -51,7 +51,7 @@ import org.apache.spark.sql.execution.window.WindowExecBase import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.rapids._ import org.apache.spark.sql.rapids.execution.{GpuCustomShuffleReaderExec, SerializeBatchDeserializeHostBuffer, SerializeConcatHostBuffersDeserializeBatch} -import org.apache.spark.sql.rapids.shims.v2.GpuSchemaUtils +import org.apache.spark.sql.rapids.shims.GpuSchemaUtils import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ diff --git a/sql-plugin/src/main/301until320-nondb/scala/org/apache/spark/rapids/shims/v2/api/python/ShimBasePythonRunner.scala b/sql-plugin/src/main/301until320-nondb/scala/org/apache/spark/rapids/shims/api/python/ShimBasePythonRunner.scala similarity index 94% rename from sql-plugin/src/main/301until320-nondb/scala/org/apache/spark/rapids/shims/v2/api/python/ShimBasePythonRunner.scala rename to sql-plugin/src/main/301until320-nondb/scala/org/apache/spark/rapids/shims/api/python/ShimBasePythonRunner.scala index 958e1f681f0..9a613fbb676 100644 --- a/sql-plugin/src/main/301until320-nondb/scala/org/apache/spark/rapids/shims/v2/api/python/ShimBasePythonRunner.scala +++ b/sql-plugin/src/main/301until320-nondb/scala/org/apache/spark/rapids/shims/api/python/ShimBasePythonRunner.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package org.apache.spark.rapids.shims.v2.api.python +package org.apache.spark.rapids.shims.api.python import java.io.DataInputStream import java.net.Socket diff --git a/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/v2/AnsiCheckUtil.scala b/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/AnsiCheckUtil.scala similarity index 94% rename from sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/v2/AnsiCheckUtil.scala rename to sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/AnsiCheckUtil.scala index 8dca4837d85..50e9ed8f006 100644 --- a/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/v2/AnsiCheckUtil.scala +++ b/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/AnsiCheckUtil.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import ai.rapids.cudf.ColumnView diff --git a/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/v2/GpuHashPartitioning.scala b/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/GpuHashPartitioning.scala similarity index 97% rename from sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/v2/GpuHashPartitioning.scala rename to sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/GpuHashPartitioning.scala index 9cf5652d163..f259366966c 100644 --- a/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/v2/GpuHashPartitioning.scala +++ b/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/GpuHashPartitioning.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.GpuHashPartitioningBase diff --git a/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/v2/GpuRangePartitioning.scala b/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/GpuRangePartitioning.scala similarity index 98% rename from sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/v2/GpuRangePartitioning.scala rename to sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/GpuRangePartitioning.scala index c0d820cf12f..dfade901023 100644 --- a/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/v2/GpuRangePartitioning.scala +++ b/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/GpuRangePartitioning.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{GpuExpression, GpuPartitioning} diff --git a/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala b/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala new file mode 100644 index 00000000000..d7d4b86311d --- /dev/null +++ b/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.spark.rapids.shims + +import ai.rapids.cudf.DType +import com.nvidia.spark.rapids.GpuRowToColumnConverter.TypeConverter + +import org.apache.spark.sql.types.DataType + +object GpuTypeShims { + + /** + * Whether this Shim supports the given data type in the row-to-column converter + * @param otherType the data type that should be checked in the Shim + * @return true if the Shim supports otherType, false otherwise. + */ + def hasConverterForType(otherType: DataType): Boolean = false + + /** + * Get the TypeConverter for the data type in this Shim. + * Note: hasConverterForType should be called first. + * @param t the data type + * @param nullable whether the value is nullable + * @return the row-to-column converter for the data type + */ + def getConverterForType(t: DataType, nullable: Boolean): TypeConverter = { + throw new RuntimeException(s"No converter is found for type $t.") + } + + /** + * Get the cuDF type for the Spark data type + * @param t the Spark data type + * @return the cuDF type if the Shim supports it, null otherwise + */ + def toRapidsOrNull(t: DataType): DType = null +} diff --git a/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/v2/ParquetFieldIdShims.scala b/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/ParquetFieldIdShims.scala similarity index 97% rename from sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/v2/ParquetFieldIdShims.scala rename to sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/ParquetFieldIdShims.scala index 3830b291961..b584cec5055 100644 --- a/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/v2/ParquetFieldIdShims.scala +++ b/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/ParquetFieldIdShims.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.RapidsMeta import org.apache.hadoop.conf.Configuration diff --git a/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/v2/RapidsErrorUtils.scala b/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala similarity index 96% rename from sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/v2/RapidsErrorUtils.scala rename to sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala index 55857d0b98c..e3ee4ccd6be 100644 --- a/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/v2/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala @@ -14,7 +14,7 @@ * limitations under the License.
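The new `GpuTypeShims` stub added above gives the pre-3.3.0 source trees a no-op hook for shim-specific type support in the row-to-column converter: per its scaladoc, `hasConverterForType` must be consulted before `getConverterForType`, which otherwise throws, and `toRapidsOrNull` returns null for unsupported types. A rough caller-side sketch, where `commonConverterFor` is a hypothetical stand-in for the plugin's built-in converter lookup:
```
import com.nvidia.spark.rapids.GpuRowToColumnConverter.TypeConverter
import com.nvidia.spark.rapids.shims.GpuTypeShims
import org.apache.spark.sql.types.DataType

// Consult the shim first; fall back to the common converter lookup when
// the shim does not handle the type (always the case in this stub, where
// hasConverterForType is hard-coded to false).
def converterFor(t: DataType, nullable: Boolean): TypeConverter =
  if (GpuTypeShims.hasConverterForType(t)) {
    GpuTypeShims.getConverterForType(t, nullable)
  } else {
    commonConverterFor(t, nullable) // hypothetical built-in lookup
  }
```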
*/ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims object RapidsErrorUtils { def invalidArrayIndexError(index: Int, numElements: Int, diff --git a/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/v2/Spark30Xuntil33XShims.scala b/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/Spark30Xuntil33XShims.scala similarity index 90% rename from sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/v2/Spark30Xuntil33XShims.scala rename to sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/Spark30Xuntil33XShims.scala index 014ed1dccc2..44bafb21d83 100644 --- a/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/v2/Spark30Xuntil33XShims.scala +++ b/sql-plugin/src/main/301until330-all/scala/com/nvidia/spark/rapids/shims/Spark30Xuntil33XShims.scala @@ -14,12 +14,12 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids._ import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.json.rapids.shims.v2.Spark30Xuntil33XFileOptionsShims +import org.apache.spark.sql.catalyst.json.rapids.shims.Spark30Xuntil33XFileOptionsShims import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.v2._ diff --git a/sql-plugin/src/main/301until330-all/scala/org/apache/spark/sql/catalyst/json/rapids/shims/v2/Spark30Xuntil33XFileOptionsShims.scala b/sql-plugin/src/main/301until330-all/scala/org/apache/spark/sql/catalyst/json/rapids/shims/Spark30Xuntil33XFileOptionsShims.scala similarity index 95% rename from sql-plugin/src/main/301until330-all/scala/org/apache/spark/sql/catalyst/json/rapids/shims/v2/Spark30Xuntil33XFileOptionsShims.scala rename to sql-plugin/src/main/301until330-all/scala/org/apache/spark/sql/catalyst/json/rapids/shims/Spark30Xuntil33XFileOptionsShims.scala index 8789ceaf287..be904decb53 100644 --- a/sql-plugin/src/main/301until330-all/scala/org/apache/spark/sql/catalyst/json/rapids/shims/v2/Spark30Xuntil33XFileOptionsShims.scala +++ b/sql-plugin/src/main/301until330-all/scala/org/apache/spark/sql/catalyst/json/rapids/shims/Spark30Xuntil33XFileOptionsShims.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.json.rapids.shims.v2 +package org.apache.spark.sql.catalyst.json.rapids.shims import com.nvidia.spark.rapids.SparkShims diff --git a/sql-plugin/src/main/302/scala/com/nvidia/spark/rapids/shims/spark302/SparkShimServiceProvider.scala b/sql-plugin/src/main/302/scala/com/nvidia/spark/rapids/shims/spark302/SparkShimServiceProvider.scala index 8db3f9d460c..34c490395f8 100644 --- a/sql-plugin/src/main/302/scala/com/nvidia/spark/rapids/shims/spark302/SparkShimServiceProvider.scala +++ b/sql-plugin/src/main/302/scala/com/nvidia/spark/rapids/shims/spark302/SparkShimServiceProvider.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -16,7 +16,7 @@ package com.nvidia.spark.rapids.shims.spark302 -import com.nvidia.spark.rapids.{SparkShims, SparkShimVersion} +import com.nvidia.spark.rapids.SparkShimVersion object SparkShimServiceProvider { val VERSION = SparkShimVersion(3, 0, 2) @@ -24,11 +24,10 @@ object SparkShimServiceProvider { } class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider { + override def getShimVersion: SparkShimVersion = SparkShimServiceProvider.VERSION + def matchesVersion(version: String): Boolean = { SparkShimServiceProvider.VERSIONNAMES.contains(version) } - def buildShim: SparkShims = { - new Spark302Shims() - } } diff --git a/sql-plugin/src/main/303/scala/com/nvidia/spark/rapids/shims/spark303/Spark303Shims.scala b/sql-plugin/src/main/303/scala/com/nvidia/spark/rapids/shims/spark303/Spark303Shims.scala deleted file mode 100644 index 9656e655f74..00000000000 --- a/sql-plugin/src/main/303/scala/com/nvidia/spark/rapids/shims/spark303/Spark303Shims.scala +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.nvidia.spark.rapids.shims.spark303 - -import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.shims.v2._ -import org.apache.parquet.schema.MessageType - -import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters - -class Spark303Shims extends Spark30XShims with Spark30Xuntil33XShims { - - override def getSparkShimVersion: ShimVersion = SparkShimServiceProvider.VERSION - - override def getParquetFilters( - schema: MessageType, - pushDownDate: Boolean, - pushDownTimestamp: Boolean, - pushDownDecimal: Boolean, - pushDownStartWith: Boolean, - pushDownInFilterThreshold: Int, - caseSensitive: Boolean, - lookupFileMeta: String => String, - dateTimeRebaseModeFromConf: String): ParquetFilters = { - new ParquetFilters(schema, pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStartWith, - pushDownInFilterThreshold, caseSensitive) - } -} diff --git a/sql-plugin/src/main/303/scala/com/nvidia/spark/rapids/shims/spark303/SparkShimServiceProvider.scala b/sql-plugin/src/main/303/scala/com/nvidia/spark/rapids/shims/spark303/SparkShimServiceProvider.scala index 3781085b6f3..777cf7539d2 100644 --- a/sql-plugin/src/main/303/scala/com/nvidia/spark/rapids/shims/spark303/SparkShimServiceProvider.scala +++ b/sql-plugin/src/main/303/scala/com/nvidia/spark/rapids/shims/spark303/SparkShimServiceProvider.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -16,7 +16,7 @@ package com.nvidia.spark.rapids.shims.spark303 -import com.nvidia.spark.rapids.{SparkShims, SparkShimVersion} +import com.nvidia.spark.rapids.SparkShimVersion object SparkShimServiceProvider { val VERSION = SparkShimVersion(3, 0, 3) @@ -24,11 +24,9 @@ object SparkShimServiceProvider { } class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider { + override def getShimVersion: SparkShimVersion = SparkShimServiceProvider.VERSION + def matchesVersion(version: String): Boolean = { SparkShimServiceProvider.VERSIONNAMES.contains(version) } - - def buildShim: SparkShims = { - new Spark303Shims() - } } diff --git a/sql-plugin/src/main/304/scala/com/nvidia/spark/rapids/shims/spark304/Spark304Shims.scala b/sql-plugin/src/main/304/scala/com/nvidia/spark/rapids/SparkShims.scala similarity index 76% rename from sql-plugin/src/main/304/scala/com/nvidia/spark/rapids/shims/spark304/Spark304Shims.scala rename to sql-plugin/src/main/304/scala/com/nvidia/spark/rapids/SparkShims.scala index ac3e33cddb3..ca7002a6a0e 100644 --- a/sql-plugin/src/main/304/scala/com/nvidia/spark/rapids/shims/spark304/Spark304Shims.scala +++ b/sql-plugin/src/main/304/scala/com/nvidia/spark/rapids/SparkShims.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,18 +14,17 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.spark304 +package com.nvidia.spark.rapids.shims -import com.nvidia.spark.rapids.ShimVersion -import com.nvidia.spark.rapids.shims.v2._ +import com.nvidia.spark.rapids._ import org.apache.parquet.schema.MessageType import org.apache.spark.sql.execution.datasources.DataSourceUtils import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters -class Spark304Shims extends Spark30XShims with Spark30Xuntil33XShims { +object SparkShimImpl extends Spark30XShims with Spark30Xuntil33XShims { - override def getSparkShimVersion: ShimVersion = SparkShimServiceProvider.VERSION + override def getSparkShimVersion: ShimVersion = ShimLoader.getShimVersion override def getParquetFilters( schema: MessageType, @@ -38,7 +37,7 @@ class Spark304Shims extends Spark30XShims with Spark30Xuntil33XShims { lookupFileMeta: String => String, dateTimeRebaseModeFromConf: String): ParquetFilters = { val datetimeRebaseMode = DataSourceUtils - .datetimeRebaseMode(lookupFileMeta, dateTimeRebaseModeFromConf) + .datetimeRebaseMode(lookupFileMeta, dateTimeRebaseModeFromConf) new ParquetFilters(schema, pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStartWith, pushDownInFilterThreshold, caseSensitive, datetimeRebaseMode) } diff --git a/sql-plugin/src/main/304/scala/com/nvidia/spark/rapids/shims/spark304/RapidsShuffleInternalManager.scala b/sql-plugin/src/main/304/scala/com/nvidia/spark/rapids/shims/spark304/RapidsShuffleInternalManager.scala index cd094e87959..2fcbc75ab3a 100644 --- a/sql-plugin/src/main/304/scala/com/nvidia/spark/rapids/shims/spark304/RapidsShuffleInternalManager.scala +++ b/sql-plugin/src/main/304/scala/com/nvidia/spark/rapids/shims/spark304/RapidsShuffleInternalManager.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -51,9 +51,8 @@ class RapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean) } - class ProxyRapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean) - extends ProxyRapidsShuffleInternalManagerBase(conf, isDriver) with ShuffleManager { + extends ProxyRapidsShuffleInternalManagerBase(conf, isDriver) with ShuffleManager { override def getReader[K, C]( handle: ShuffleHandle, @@ -77,4 +76,4 @@ class ProxyRapidsShuffleInternalManager(conf: SparkConf, isDriver: Boolean) self.getReaderForRange(handle, startMapIndex, endMapIndex, startPartition, endPartition, context, metrics) } -} +} \ No newline at end of file diff --git a/sql-plugin/src/main/304/scala/com/nvidia/spark/rapids/shims/spark304/SparkShimServiceProvider.scala b/sql-plugin/src/main/304/scala/com/nvidia/spark/rapids/shims/spark304/SparkShimServiceProvider.scala index 4eb1d73614d..6a9ff15da3e 100644 --- a/sql-plugin/src/main/304/scala/com/nvidia/spark/rapids/shims/spark304/SparkShimServiceProvider.scala +++ b/sql-plugin/src/main/304/scala/com/nvidia/spark/rapids/shims/spark304/SparkShimServiceProvider.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids.shims.spark304 -import com.nvidia.spark.rapids.{SparkShims, SparkShimVersion} +import com.nvidia.spark.rapids.SparkShimVersion object SparkShimServiceProvider { val VERSION = SparkShimVersion(3, 0, 4) @@ -24,11 +24,9 @@ object SparkShimServiceProvider { } class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider { + override def getShimVersion: SparkShimVersion = SparkShimServiceProvider.VERSION + def matchesVersion(version: String): Boolean = { SparkShimServiceProvider.VERSIONNAMES.contains(version) } - - def buildShim: SparkShims = { - new Spark304Shims() - } } diff --git a/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/ParquetCachedBatchSerializer.scala b/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/ParquetCachedBatchSerializer.scala index 69455741f1e..b3ca7244694 100644 --- a/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/ParquetCachedBatchSerializer.scala +++ b/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/ParquetCachedBatchSerializer.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.columnar.{CachedBatch, CachedBatchSerializer} import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} -import org.apache.spark.sql.rapids.shims.v2.GpuInMemoryTableScanExec +import org.apache.spark.sql.rapids.shims.GpuInMemoryTableScanExec import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.storage.StorageLevel @@ -80,7 +80,7 @@ class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer { } private lazy val realSerializer: GpuCachedBatchSerializer = { - ShimLoader.newInstanceOf("com.nvidia.spark.rapids.shims.v2.ParquetCachedBatchSerializer") + ShimLoader.newInstanceOf("com.nvidia.spark.rapids.shims.ParquetCachedBatchSerializer") } /** diff --git a/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/v2/GpuJoinUtils.scala b/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/GpuJoinUtils.scala similarity index 96% rename from sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/v2/GpuJoinUtils.scala rename to sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/GpuJoinUtils.scala index c41fc5b578a..3bdb81676bf 100644 --- a/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/v2/GpuJoinUtils.scala +++ b/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/GpuJoinUtils.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{GpuBuildLeft, GpuBuildRight, GpuBuildSide} diff --git a/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/v2/GpuOrcScan.scala b/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/GpuOrcScan.scala similarity index 96% rename from sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/v2/GpuOrcScan.scala rename to sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/GpuOrcScan.scala index a59f873c45f..c4b541599a6 100644 --- a/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/v2/GpuOrcScan.scala +++ b/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/GpuOrcScan.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{GpuOrcScanBase, RapidsConf} import org.apache.hadoop.conf.Configuration diff --git a/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/v2/GpuParquetScan.scala b/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/GpuParquetScan.scala similarity index 96% rename from sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/v2/GpuParquetScan.scala rename to sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/GpuParquetScan.scala index e018ac7a035..09a7f3b1408 100644 --- a/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/v2/GpuParquetScan.scala +++ b/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/GpuParquetScan.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. 
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{GpuParquetScanBase, RapidsConf} import org.apache.hadoop.conf.Configuration diff --git a/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/v2/GpuRowBasedScalaUDF.scala b/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/GpuRowBasedScalaUDF.scala similarity index 98% rename from sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/v2/GpuRowBasedScalaUDF.scala rename to sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/GpuRowBasedScalaUDF.scala index f855fffbc2c..da42fc3ab66 100644 --- a/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/v2/GpuRowBasedScalaUDF.scala +++ b/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/GpuRowBasedScalaUDF.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{ExprChecks, ExprRule, GpuOverrides, GpuUserDefinedFunction, RepeatingParamCheck, TypeSig} diff --git a/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/v2/ParquetCachedBatchSerializer.scala b/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/ParquetCachedBatchSerializer.scala similarity index 98% rename from sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/v2/ParquetCachedBatchSerializer.scala rename to sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/ParquetCachedBatchSerializer.scala index 383fe646cfc..db8a28ef3e8 100644 --- a/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/v2/ParquetCachedBatchSerializer.scala +++ b/sql-plugin/src/main/311+-all/scala/com/nvidia/spark/rapids/shims/ParquetCachedBatchSerializer.scala @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import java.io.{InputStream, IOException} import java.lang.reflect.Method @@ -53,8 +53,8 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.columnar.CachedBatch -import org.apache.spark.sql.execution.datasources.parquet.{ParquetReadSupport, ParquetToSparkSchemaConverter, ParquetWriteSupport, SparkToParquetSchemaConverter, VectorizedColumnReader} -import org.apache.spark.sql.execution.datasources.parquet.rapids.shims.v2.{ParquetRecordMaterializer, ShimVectorizedColumnReader} +import org.apache.spark.sql.execution.datasources.parquet.{ParquetReadSupport, ParquetToSparkSchemaConverter, ParquetWriteSupport, VectorizedColumnReader} +import org.apache.spark.sql.execution.datasources.parquet.rapids.shims.{ParquetRecordMaterializer, ShimVectorizedColumnReader} import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, WritableColumnVector} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy @@ -889,13 +889,9 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi val inMemCacheSparkSchema = parquetToSparkSchemaConverter.convert(inMemCacheParquetSchema) val totalRowCount = parquetFileReader.getRowGroups.asScala.map(_.getRowCount).sum - val sparkToParquetSchemaConverter = new SparkToParquetSchemaConverter(hadoopConf) val inMemReqSparkSchema = StructType(selectedAttributes.toStructType.map { field => inMemCacheSparkSchema.fields(inMemCacheSparkSchema.fieldIndex(field.name)) }) - val inMemReqParquetSchema = sparkToParquetSchemaConverter.convert(inMemReqSparkSchema) - val columnsRequested: util.List[ColumnDescriptor] = inMemReqParquetSchema.getColumns - val reqSparkSchemaInCacheOrder = StructType(inMemCacheSparkSchema.filter(f => inMemReqSparkSchema.fields.exists(f0 => f0.name.equals(f.name)))) @@ -907,23 +903,26 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi index -> inMemReqSparkSchema.fields.indexOf(reqSparkSchemaInCacheOrder.fields(index)) }.toMap - val reqParquetSchemaInCacheOrder = - sparkToParquetSchemaConverter.convert(reqSparkSchemaInCacheOrder) + val reqParquetSchemaInCacheOrder = new org.apache.parquet.schema.MessageType( + inMemCacheParquetSchema.getName(), reqSparkSchemaInCacheOrder.fields.map { f => + inMemCacheParquetSchema.getFields().get(inMemCacheParquetSchema.getFieldIndex(f.name)) + }:_*) + val columnsRequested: util.List[ColumnDescriptor] = reqParquetSchemaInCacheOrder.getColumns // reset spark schema calculated from parquet schema hadoopConf.set(ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, inMemReqSparkSchema.json) hadoopConf.set(ParquetWriteSupport.SPARK_ROW_SCHEMA, inMemReqSparkSchema.json) val columnsInCache: util.List[ColumnDescriptor] = reqParquetSchemaInCacheOrder.getColumns val typesInCache: util.List[Type] = reqParquetSchemaInCacheOrder.asGroupType.getFields - val missingColumns = new Array[Boolean](inMemReqParquetSchema.getFieldCount) + val missingColumns = new Array[Boolean](reqParquetSchemaInCacheOrder.getFieldCount) // initialize missingColumns to cover the case where requested column isn't present in the // cache, which should never happen but just in case it does - val paths: util.List[Array[String]] = 
inMemReqParquetSchema.getPaths + val paths: util.List[Array[String]] = reqParquetSchemaInCacheOrder.getPaths - for (i <- 0 until inMemReqParquetSchema.getFieldCount) { - val t = inMemReqParquetSchema.getFields.get(i) + for (i <- 0 until reqParquetSchemaInCacheOrder.getFieldCount) { + val t = reqParquetSchemaInCacheOrder.getFields.get(i) if (!t.isPrimitive || t.isRepetition(Type.Repetition.REPEATED)) { throw new UnsupportedOperationException("Complex types not supported.") } @@ -1242,7 +1241,7 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi // at least a single block val stream = new ByteArrayOutputStream(ByteArrayOutputFile.BLOCK_SIZE) val outputFile: OutputFile = new ByteArrayOutputFile(stream) - conf.setConfString(ShimLoader.getSparkShims.parquetRebaseWriteKey, + conf.setConfString(SparkShimImpl.parquetRebaseWriteKey, LegacyBehaviorPolicy.CORRECTED.toString) val recordWriter = SQLConf.withExistingConf(conf) { parquetOutputFileFormat.getRecordWriter(outputFile, hadoopConf) @@ -1422,7 +1421,7 @@ protected class ParquetCachedBatchSerializer extends GpuCachedBatchSerializer wi hadoopConf.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, false) hadoopConf.setBoolean(SQLConf.CASE_SENSITIVE.key, false) - hadoopConf.set(ShimLoader.getSparkShims.parquetRebaseWriteKey, + hadoopConf.set(SparkShimImpl.parquetRebaseWriteKey, LegacyBehaviorPolicy.CORRECTED.toString) hadoopConf.set(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key, diff --git a/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ParquetMaterializer.scala b/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ParquetMaterializer.scala index 8154ded831d..ce0f3dab79c 100644 --- a/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ParquetMaterializer.scala +++ b/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/ParquetMaterializer.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.parquet.rapids.shims.v2 +package org.apache.spark.sql.execution.datasources.parquet.rapids.shims import java.time.ZoneId diff --git a/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/v2/GpuColumnarToRowTransitionExec.scala b/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/GpuColumnarToRowTransitionExec.scala similarity index 92% rename from sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/v2/GpuColumnarToRowTransitionExec.scala rename to sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/GpuColumnarToRowTransitionExec.scala index 43395116897..56d18dc3770 100644 --- a/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/v2/GpuColumnarToRowTransitionExec.scala +++ b/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/GpuColumnarToRowTransitionExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. 
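The one behavioral change in the ParquetCachedBatchSerializer hunks above is how the requested Parquet schema is computed: rather than round-tripping the requested Spark schema through SparkToParquetSchemaConverter, the requested fields are now projected directly out of the MessageType that was actually written to the cache, so the resulting ColumnDescriptors always match the cached data. A hedged standalone sketch of that projection (the helper name and JavaConverters usage are illustrative, not plugin code):

```
import scala.collection.JavaConverters._
import org.apache.parquet.schema.{MessageType, Type}

// Mirror of the reqParquetSchemaInCacheOrder hunk: keep the cache's own
// Type objects, filtered to the requested field names, in cache order.
def projectCachedSchema(cached: MessageType, requested: Set[String]): MessageType = {
  val kept: Seq[Type] =
    cached.getFields.asScala.filter(t => requested.contains(t.getName)).toSeq
  new MessageType(cached.getName, kept: _*)
}
```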
*/ -package org.apache.spark.sql.rapids.shims.v2 +package org.apache.spark.sql.rapids.shims import com.nvidia.spark.rapids.GpuColumnarToRowExecParent diff --git a/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/v2/GpuInMemoryTableScanExec.scala b/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/GpuInMemoryTableScanExec.scala similarity index 97% rename from sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/v2/GpuInMemoryTableScanExec.scala rename to sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/GpuInMemoryTableScanExec.scala index a4261581983..b7e54ced798 100644 --- a/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/v2/GpuInMemoryTableScanExec.scala +++ b/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/GpuInMemoryTableScanExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.apache.spark.sql.rapids.shims.v2 +package org.apache.spark.sql.rapids.shims import com.nvidia.spark.ParquetCachedBatchSerializer import com.nvidia.spark.rapids.{GpuExec, GpuMetric} diff --git a/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/v2/GpuSchemaUtils.scala b/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/GpuSchemaUtils.scala similarity index 91% rename from sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/v2/GpuSchemaUtils.scala rename to sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/GpuSchemaUtils.scala index f4203207455..587b7efbe40 100644 --- a/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/v2/GpuSchemaUtils.scala +++ b/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/GpuSchemaUtils.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.apache.spark.sql.rapids.shims.v2 +package org.apache.spark.sql.rapids.shims import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.types.StructType diff --git a/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/v2/HadoopFSUtilsShim.scala b/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/HadoopFSUtilsShim.scala similarity index 88% rename from sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/v2/HadoopFSUtilsShim.scala rename to sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/HadoopFSUtilsShim.scala index 7b0cd1fba3a..6952a68f0da 100644 --- a/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/v2/HadoopFSUtilsShim.scala +++ b/sql-plugin/src/main/311+-all/scala/org/apache/spark/sql/rapids/shims/HadoopFSUtilsShim.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.rapids.shims.v2 +package org.apache.spark.sql.rapids.shims import org.apache.spark.util.HadoopFSUtils diff --git a/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala b/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/GpuRegExpReplaceExec.scala similarity index 98% rename from sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala rename to sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/GpuRegExpReplaceExec.scala index c4ec686bf4f..caf92cbeedc 100644 --- a/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala +++ b/sql-plugin/src/main/311+-nondb/scala/com/nvidia/spark/rapids/shims/GpuRegExpReplaceExec.scala @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexReplaceMode, RegexUnsupportedException} diff --git a/sql-plugin/src/main/311+-nondb/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala b/sql-plugin/src/main/311+-nondb/scala/org/apache/spark/rapids/shims/GpuShuffleExchangeExec.scala similarity index 96% rename from sql-plugin/src/main/311+-nondb/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala rename to sql-plugin/src/main/311+-nondb/scala/org/apache/spark/rapids/shims/GpuShuffleExchangeExec.scala index 7528ba635b9..9ad02ce0ba2 100644 --- a/sql-plugin/src/main/311+-nondb/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala +++ b/sql-plugin/src/main/311+-nondb/scala/org/apache/spark/rapids/shims/GpuShuffleExchangeExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.apache.spark.rapids.shims.v2 +package org.apache.spark.rapids.shims import com.nvidia.spark.rapids.GpuPartitioning diff --git a/sql-plugin/src/main/311/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala b/sql-plugin/src/main/311-nondb/scala/com/nvidia/spark/rapids/shims/SparkShims.scala similarity index 84% rename from sql-plugin/src/main/311/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala rename to sql-plugin/src/main/311-nondb/scala/com/nvidia/spark/rapids/shims/SparkShims.scala index 343270588c2..df43190dc7f 100644 --- a/sql-plugin/src/main/311/scala/com/nvidia/spark/rapids/shims/spark311/Spark311Shims.scala +++ b/sql-plugin/src/main/311-nondb/scala/com/nvidia/spark/rapids/shims/SparkShims.scala @@ -14,17 +14,16 @@ * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.spark311 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.shims.v2._ import org.apache.parquet.schema.MessageType import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters -class Spark311Shims extends Spark31XShims with Spark30Xuntil33XShims { +object SparkShimImpl extends Spark31XShims with Spark30Xuntil33XShims { - override def getSparkShimVersion: ShimVersion = SparkShimServiceProvider.VERSION + override def getSparkShimVersion: ShimVersion = ShimLoader.getShimVersion override def hasCastFloatTimestampUpcast: Boolean = false diff --git a/sql-plugin/src/main/311/scala/com/nvidia/spark/rapids/shims/spark311/SparkShimServiceProvider.scala b/sql-plugin/src/main/311-nondb/scala/com/nvidia/spark/rapids/shims/spark311/SparkShimServiceProvider.scala similarity index 84% rename from sql-plugin/src/main/311/scala/com/nvidia/spark/rapids/shims/spark311/SparkShimServiceProvider.scala rename to sql-plugin/src/main/311-nondb/scala/com/nvidia/spark/rapids/shims/spark311/SparkShimServiceProvider.scala index 3e7e694564a..58f58c485ec 100644 --- a/sql-plugin/src/main/311/scala/com/nvidia/spark/rapids/shims/spark311/SparkShimServiceProvider.scala +++ b/sql-plugin/src/main/311-nondb/scala/com/nvidia/spark/rapids/shims/spark311/SparkShimServiceProvider.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids.shims.spark311 -import com.nvidia.spark.rapids.{SparkShims, SparkShimVersion} +import com.nvidia.spark.rapids.SparkShimVersion object SparkShimServiceProvider { val VERSION = SparkShimVersion(3, 1, 1) @@ -25,11 +25,9 @@ object SparkShimServiceProvider { class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider { + override def getShimVersion: SparkShimVersion = SparkShimServiceProvider.VERSION + def matchesVersion(version: String): Boolean = { SparkShimServiceProvider.VERSIONNAMES.contains(version) } - - def buildShim: SparkShims = { - new Spark311Shims() - } } diff --git a/sql-plugin/src/main/311/scala/com/nvidia/spark/rapids/spark311/RapidsShuffleManager.scala b/sql-plugin/src/main/311-nondb/scala/com/nvidia/spark/rapids/spark311/RapidsShuffleManager.scala similarity index 100% rename from sql-plugin/src/main/311/scala/com/nvidia/spark/rapids/spark311/RapidsShuffleManager.scala rename to sql-plugin/src/main/311-nondb/scala/com/nvidia/spark/rapids/spark311/RapidsShuffleManager.scala diff --git a/sql-plugin/src/main/311/scala/org/apache/spark/sql/rapids/shims/spark311/RapidsShuffleInternalManager.scala b/sql-plugin/src/main/311-nondb/scala/org/apache/spark/sql/rapids/shims/spark311/RapidsShuffleInternalManager.scala similarity index 100% rename from sql-plugin/src/main/311/scala/org/apache/spark/sql/rapids/shims/spark311/RapidsShuffleInternalManager.scala rename to sql-plugin/src/main/311-nondb/scala/org/apache/spark/sql/rapids/shims/spark311/RapidsShuffleInternalManager.scala diff --git a/sql-plugin/src/main/311cdh/scala/com/nvidia/spark/rapids/shims/v2/OrcShims.scala b/sql-plugin/src/main/311cdh/scala/com/nvidia/spark/rapids/shims/OrcShims.scala similarity index 95% rename from sql-plugin/src/main/311cdh/scala/com/nvidia/spark/rapids/shims/v2/OrcShims.scala rename to 
sql-plugin/src/main/311cdh/scala/com/nvidia/spark/rapids/shims/OrcShims.scala index ddc4534cb39..d8e273b45bb 100644 --- a/sql-plugin/src/main/311cdh/scala/com/nvidia/spark/rapids/shims/v2/OrcShims.scala +++ b/sql-plugin/src/main/311cdh/scala/com/nvidia/spark/rapids/shims/OrcShims.scala @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import org.apache.orc.Reader diff --git a/sql-plugin/src/main/311cdh/scala/com/nvidia/spark/rapids/shims/spark311cdh/SparkShimServiceProvider.scala b/sql-plugin/src/main/311cdh/scala/com/nvidia/spark/rapids/shims/spark311cdh/SparkShimServiceProvider.scala index 41a73ac8a2e..fcbd94d50cc 100644 --- a/sql-plugin/src/main/311cdh/scala/com/nvidia/spark/rapids/shims/spark311cdh/SparkShimServiceProvider.scala +++ b/sql-plugin/src/main/311cdh/scala/com/nvidia/spark/rapids/shims/spark311cdh/SparkShimServiceProvider.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids.shims.spark311cdh -import com.nvidia.spark.rapids.{ClouderaShimVersion, SparkShims} +import com.nvidia.spark.rapids.{ClouderaShimVersion, ShimVersion} object SparkShimServiceProvider { val VERSION = ClouderaShimVersion(3, 1, 1, "3.1.7270") @@ -26,11 +26,9 @@ object SparkShimServiceProvider { class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider { + override def getShimVersion: ShimVersion = SparkShimServiceProvider.VERSION + def matchesVersion(version: String): Boolean = { version.contains(SparkShimServiceProvider.CDH_BASE_VERSION) } - - def buildShim: SparkShims = { - new Spark311CDHShims() - } } diff --git a/sql-plugin/src/main/311cdh/scala/com/nvidia/spark/rapids/spark311cdh/RapidsShuffleManager.scala b/sql-plugin/src/main/311cdh/scala/com/nvidia/spark/rapids/spark311cdh/RapidsShuffleManager.scala index 6748239f37b..0e5b181d9c7 100644 --- a/sql-plugin/src/main/311cdh/scala/com/nvidia/spark/rapids/spark311cdh/RapidsShuffleManager.scala +++ b/sql-plugin/src/main/311cdh/scala/com/nvidia/spark/rapids/spark311cdh/RapidsShuffleManager.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -24,4 +24,3 @@ sealed class RapidsShuffleManager( conf: SparkConf, isDriver: Boolean) extends ProxyRapidsShuffleInternalManager(conf, isDriver) { } - diff --git a/sql-plugin/src/main/311cdh/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/v2/ShimVectorizedColumnReader.scala b/sql-plugin/src/main/311cdh/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/ShimVectorizedColumnReader.scala similarity index 97% rename from sql-plugin/src/main/311cdh/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/v2/ShimVectorizedColumnReader.scala rename to sql-plugin/src/main/311cdh/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/ShimVectorizedColumnReader.scala index f0dd9d81beb..b1d624406f4 100644 --- a/sql-plugin/src/main/311cdh/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/v2/ShimVectorizedColumnReader.scala +++ b/sql-plugin/src/main/311cdh/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/ShimVectorizedColumnReader.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.parquet.rapids.shims.v2 +package org.apache.spark.sql.execution.datasources.parquet.rapids.shims import java.time.ZoneId diff --git a/sql-plugin/src/main/311until320-all/scala/com/nvidia/spark/rapids/shims/v2/OffsetWindowFunctionMeta.scala b/sql-plugin/src/main/311until320-all/scala/com/nvidia/spark/rapids/shims/OffsetWindowFunctionMeta.scala similarity index 97% rename from sql-plugin/src/main/311until320-all/scala/com/nvidia/spark/rapids/shims/v2/OffsetWindowFunctionMeta.scala rename to sql-plugin/src/main/311until320-all/scala/com/nvidia/spark/rapids/shims/OffsetWindowFunctionMeta.scala index d10ae6dd6a7..e520cbcc32f 100644 --- a/sql-plugin/src/main/311until320-all/scala/com/nvidia/spark/rapids/shims/v2/OffsetWindowFunctionMeta.scala +++ b/sql-plugin/src/main/311until320-all/scala/com/nvidia/spark/rapids/shims/OffsetWindowFunctionMeta.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{BaseExprMeta, DataFromReplacementRule, ExprMeta, GpuOverrides, RapidsConf, RapidsMeta} diff --git a/sql-plugin/src/main/311until320-all/scala/org/apache/spark/rapids/shims/v2/ShuffledBatchRDDUtil.scala b/sql-plugin/src/main/311until320-all/scala/org/apache/spark/rapids/shims/ShuffledBatchRDDUtil.scala similarity index 96% rename from sql-plugin/src/main/311until320-all/scala/org/apache/spark/rapids/shims/v2/ShuffledBatchRDDUtil.scala rename to sql-plugin/src/main/311until320-all/scala/org/apache/spark/rapids/shims/ShuffledBatchRDDUtil.scala index df89343ff32..c0e049001aa 100644 --- a/sql-plugin/src/main/311until320-all/scala/org/apache/spark/rapids/shims/v2/ShuffledBatchRDDUtil.scala +++ b/sql-plugin/src/main/311until320-all/scala/org/apache/spark/rapids/shims/ShuffledBatchRDDUtil.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. 
+ * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,9 +14,9 @@ * limitations under the License. */ -package org.apache.spark.rapids.shims.v2 +package org.apache.spark.rapids.shims -import com.nvidia.spark.rapids.ShimLoader +import com.nvidia.spark.rapids.shims.SparkShimImpl import org.apache.spark.{MapOutputTrackerMaster, Partition, ShuffleDependency, SparkEnv, TaskContext} import org.apache.spark.shuffle.ShuffleReader @@ -57,7 +57,7 @@ object ShuffledBatchRDDUtil { dependency: ShuffleDependency[Int, ColumnarBatch, ColumnarBatch], sqlMetricsReporter: SQLShuffleReadMetricsReporter): (ShuffleReader[Nothing, Nothing], Long) = { - val shim = ShimLoader.getSparkShims + val shim = SparkShimImpl split.asInstanceOf[ShuffledBatchRDDPartition].spec match { case CoalescedPartitionSpec(startReducerIndex, endReducerIndex) => val reader = SparkEnv.get.shuffleManager.getReader( diff --git a/sql-plugin/src/main/311until320-noncdh/scala/org/apache/spark/rapids/shims/v2/sql/execution/datasources/parquet/rapids/ShimVectorizedColumnReader.scala b/sql-plugin/src/main/311until320-noncdh/scala/org/apache/spark/rapids/shims/sql/execution/datasources/parquet/rapids/ShimVectorizedColumnReader.scala similarity index 97% rename from sql-plugin/src/main/311until320-noncdh/scala/org/apache/spark/rapids/shims/v2/sql/execution/datasources/parquet/rapids/ShimVectorizedColumnReader.scala rename to sql-plugin/src/main/311until320-noncdh/scala/org/apache/spark/rapids/shims/sql/execution/datasources/parquet/rapids/ShimVectorizedColumnReader.scala index cd972dfa9f5..53d265c784e 100644 --- a/sql-plugin/src/main/311until320-noncdh/scala/org/apache/spark/rapids/shims/v2/sql/execution/datasources/parquet/rapids/ShimVectorizedColumnReader.scala +++ b/sql-plugin/src/main/311until320-noncdh/scala/org/apache/spark/rapids/shims/sql/execution/datasources/parquet/rapids/ShimVectorizedColumnReader.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.parquet.rapids.shims.v2 +package org.apache.spark.sql.execution.datasources.parquet.rapids.shims import java.time.ZoneId diff --git a/sql-plugin/src/main/311until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark31XShims.scala b/sql-plugin/src/main/311until320-nondb/scala/com/nvidia/spark/rapids/shims/Spark31XShims.scala similarity index 98% rename from sql-plugin/src/main/311until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark31XShims.scala rename to sql-plugin/src/main/311until320-nondb/scala/com/nvidia/spark/rapids/shims/Spark31XShims.scala index b4b735c11b8..c36ace53bc8 100644 --- a/sql-plugin/src/main/311until320-nondb/scala/com/nvidia/spark/rapids/shims/v2/Spark31XShims.scala +++ b/sql-plugin/src/main/311until320-nondb/scala/com/nvidia/spark/rapids/shims/Spark31XShims.scala @@ -14,7 +14,7 @@ * limitations under the License. 
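The ShuffledBatchRDDUtil hunk just above shows the call-site rewrite this patch applies throughout: the reflective `ShimLoader.getSparkShims` lookup becomes a direct reference to the statically linked, per-build `SparkShimImpl` object. In isolation, using the `sortOrder` shim method that appears later in this diff (the surrounding function and imports are illustrative context, not part of the patch):

```
import com.nvidia.spark.rapids.shims.SparkShimImpl
import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, SortOrder}

def ascendingOrder(expr: Expression): SortOrder = {
  // before this patch: ShimLoader.getSparkShims.sortOrder(expr, Ascending)
  SparkShimImpl.sortOrder(expr, Ascending)
}
```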
*/ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import java.nio.ByteBuffer @@ -25,7 +25,7 @@ import org.apache.arrow.vector.ValueVector import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging -import org.apache.spark.rapids.shims.v2.GpuShuffleExchangeExec +import org.apache.spark.rapids.shims.GpuShuffleExchangeExec import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -45,8 +45,8 @@ import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.rapids._ import org.apache.spark.sql.rapids.execution.GpuShuffleExchangeExecBase import org.apache.spark.sql.rapids.execution.python._ -import org.apache.spark.sql.rapids.execution.python.shims.v2._ -import org.apache.spark.sql.rapids.shims.v2.{GpuColumnarToRowTransitionExec, HadoopFSUtilsShim} +import org.apache.spark.sql.rapids.execution.python.shims._ +import org.apache.spark.sql.rapids.shims.{GpuColumnarToRowTransitionExec, HadoopFSUtilsShim} import org.apache.spark.sql.types._ import org.apache.spark.storage.{BlockId, BlockManagerId} @@ -430,7 +430,7 @@ abstract class Spark31XShims extends Spark301until320Shims with Logging { }), GpuOverrides.exec[InMemoryTableScanExec]( "Implementation of InMemoryTableScanExec to use GPU accelerated Caching", - ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.STRUCT + ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP).nested(), TypeSig.all), (scan, conf, p, r) => new InMemoryTableScanMeta(scan, conf, p, r)), GpuOverrides.exec[ArrowEvalPythonExec]( diff --git a/sql-plugin/src/main/302/scala/com/nvidia/spark/rapids/shims/spark302/Spark302Shims.scala b/sql-plugin/src/main/312-nondb/scala/com/nvidia/spark/rapids/shims/SparkShims.scala similarity index 83% rename from sql-plugin/src/main/302/scala/com/nvidia/spark/rapids/shims/spark302/Spark302Shims.scala rename to sql-plugin/src/main/312-nondb/scala/com/nvidia/spark/rapids/shims/SparkShims.scala index dccd459ae4c..8ac77702b23 100644 --- a/sql-plugin/src/main/302/scala/com/nvidia/spark/rapids/shims/spark302/Spark302Shims.scala +++ b/sql-plugin/src/main/312-nondb/scala/com/nvidia/spark/rapids/shims/SparkShims.scala @@ -14,17 +14,18 @@ * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.spark302 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.shims.v2._ import org.apache.parquet.schema.MessageType import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters -class Spark302Shims extends Spark30XShims with Spark30Xuntil33XShims { +object SparkShimImpl extends Spark31XShims with Spark30Xuntil33XShims { - override def getSparkShimVersion: ShimVersion = SparkShimServiceProvider.VERSION + override def getSparkShimVersion: ShimVersion = ShimLoader.getShimVersion + + override def hasCastFloatTimestampUpcast: Boolean = true override def getParquetFilters( schema: MessageType, diff --git a/sql-plugin/src/main/312/scala/com/nvidia/spark/rapids/shims/spark312/SparkShimServiceProvider.scala b/sql-plugin/src/main/312-nondb/scala/com/nvidia/spark/rapids/shims/spark312/SparkShimServiceProvider.scala similarity index 88% rename from sql-plugin/src/main/312/scala/com/nvidia/spark/rapids/shims/spark312/SparkShimServiceProvider.scala rename to sql-plugin/src/main/312-nondb/scala/com/nvidia/spark/rapids/shims/spark312/SparkShimServiceProvider.scala index c5a0e1ab3ca..1a969565b67 100644 --- a/sql-plugin/src/main/312/scala/com/nvidia/spark/rapids/shims/spark312/SparkShimServiceProvider.scala +++ b/sql-plugin/src/main/312-nondb/scala/com/nvidia/spark/rapids/shims/spark312/SparkShimServiceProvider.scala @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids.shims.spark312 -import com.nvidia.spark.rapids.{SparkShims, SparkShimVersion} +import com.nvidia.spark.rapids.SparkShimVersion object SparkShimServiceProvider { val VERSION = SparkShimVersion(3, 1, 2) @@ -25,11 +25,9 @@ object SparkShimServiceProvider { class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider { + override def getShimVersion: SparkShimVersion = SparkShimServiceProvider.VERSION + def matchesVersion(version: String): Boolean = { SparkShimServiceProvider.VERSIONNAMES.contains(version) } - - def buildShim: SparkShims = { - new Spark312Shims() - } } diff --git a/sql-plugin/src/main/312/scala/com/nvidia/spark/rapids/spark312/RapidsShuffleManager.scala b/sql-plugin/src/main/312-nondb/scala/com/nvidia/spark/rapids/spark312/RapidsShuffleManager.scala similarity index 100% rename from sql-plugin/src/main/312/scala/com/nvidia/spark/rapids/spark312/RapidsShuffleManager.scala rename to sql-plugin/src/main/312-nondb/scala/com/nvidia/spark/rapids/spark312/RapidsShuffleManager.scala diff --git a/sql-plugin/src/main/312/scala/org/apache/spark/sql/rapids/shims/spark312/RapidsShuffleInternalManager.scala b/sql-plugin/src/main/312-nondb/scala/org/apache/spark/sql/rapids/shims/spark312/RapidsShuffleInternalManager.scala similarity index 100% rename from sql-plugin/src/main/312/scala/org/apache/spark/sql/rapids/shims/spark312/RapidsShuffleInternalManager.scala rename to sql-plugin/src/main/312-nondb/scala/org/apache/spark/sql/rapids/shims/spark312/RapidsShuffleInternalManager.scala diff --git a/sql-plugin/src/main/312/scala/com/nvidia/spark/rapids/shims/spark312/Spark312Shims.scala b/sql-plugin/src/main/312/scala/com/nvidia/spark/rapids/shims/spark312/Spark312Shims.scala deleted file mode 100644 index 41f6bd44ca5..00000000000 --- a/sql-plugin/src/main/312/scala/com/nvidia/spark/rapids/shims/spark312/Spark312Shims.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.nvidia.spark.rapids.shims.spark312 - -import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.shims.v2._ -import org.apache.parquet.schema.MessageType - -import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters - -class Spark312Shims extends Spark31XShims with Spark30Xuntil33XShims { - - override def getSparkShimVersion: ShimVersion = SparkShimServiceProvider.VERSION - - override def hasCastFloatTimestampUpcast: Boolean = true - - override def getParquetFilters( - schema: MessageType, - pushDownDate: Boolean, - pushDownTimestamp: Boolean, - pushDownDecimal: Boolean, - pushDownStartWith: Boolean, - pushDownInFilterThreshold: Int, - caseSensitive: Boolean, - lookupFileMeta: String => String, - dateTimeRebaseModeFromConf: String): ParquetFilters = { - new ParquetFilters(schema, pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStartWith, - pushDownInFilterThreshold, caseSensitive) - } -} diff --git a/sql-plugin/src/main/312db/scala/com/nvidia/spark/rapids/shims/spark312db/Spark312dbShims.scala b/sql-plugin/src/main/312db/scala/com/nvidia/spark/rapids/shims/SparkShims.scala similarity index 86% rename from sql-plugin/src/main/312db/scala/com/nvidia/spark/rapids/shims/spark312db/Spark312dbShims.scala rename to sql-plugin/src/main/312db/scala/com/nvidia/spark/rapids/shims/SparkShims.scala index db0a12244a3..fa544e24dff 100644 --- a/sql-plugin/src/main/312db/scala/com/nvidia/spark/rapids/shims/spark312db/Spark312dbShims.scala +++ b/sql-plugin/src/main/312db/scala/com/nvidia/spark/rapids/shims/SparkShims.scala @@ -14,18 +14,17 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.spark312db +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.shims.v2._ import org.apache.parquet.schema.MessageType import org.apache.spark.sql.execution.datasources.DataSourceUtils import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters -class Spark312dbShims extends Spark31XdbShims with Spark30Xuntil33XShims { +object SparkShimImpl extends Spark31XdbShims with Spark30Xuntil33XShims { - override def getSparkShimVersion: ShimVersion = SparkShimServiceProvider.VERSION + override def getSparkShimVersion: ShimVersion = ShimLoader.getShimVersion override def getParquetFilters( schema: MessageType, diff --git a/sql-plugin/src/main/312db/scala/com/nvidia/spark/rapids/shims/spark312db/SparkShimServiceProvider.scala b/sql-plugin/src/main/312db/scala/com/nvidia/spark/rapids/shims/spark312db/SparkShimServiceProvider.scala index f1095ab593e..8c095fb9ef2 100644 --- a/sql-plugin/src/main/312db/scala/com/nvidia/spark/rapids/shims/spark312db/SparkShimServiceProvider.scala +++ b/sql-plugin/src/main/312db/scala/com/nvidia/spark/rapids/shims/spark312db/SparkShimServiceProvider.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids.shims.spark312db -import com.nvidia.spark.rapids.{DatabricksShimVersion, SparkShims} +import com.nvidia.spark.rapids.{DatabricksShimVersion, ShimVersion} import org.apache.spark.SparkEnv @@ -26,11 +26,9 @@ object SparkShimServiceProvider { class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider { + override def getShimVersion: ShimVersion = SparkShimServiceProvider.VERSION + def matchesVersion(version: String): Boolean = { SparkEnv.get.conf.get("spark.databricks.clusterUsageTags.sparkVersion", "").startsWith("9.1.") } - - def buildShim: SparkShims = { - new Spark312dbShims() - } } diff --git a/sql-plugin/src/main/313/scala/com/nvidia/spark/rapids/shims/spark313/Spark313Shims.scala b/sql-plugin/src/main/313/scala/com/nvidia/spark/rapids/shims/SparkShims.scala similarity index 86% rename from sql-plugin/src/main/313/scala/com/nvidia/spark/rapids/shims/spark313/Spark313Shims.scala rename to sql-plugin/src/main/313/scala/com/nvidia/spark/rapids/shims/SparkShims.scala index 6f5417829b0..ce884783e7b 100644 --- a/sql-plugin/src/main/313/scala/com/nvidia/spark/rapids/shims/spark313/Spark313Shims.scala +++ b/sql-plugin/src/main/313/scala/com/nvidia/spark/rapids/shims/SparkShims.scala @@ -14,18 +14,17 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.spark313 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.shims.v2._ import org.apache.parquet.schema.MessageType import org.apache.spark.sql.execution.datasources.DataSourceUtils import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters -class Spark313Shims extends Spark31XShims with Spark30Xuntil33XShims { +object SparkShimImpl extends Spark31XShims with Spark30Xuntil33XShims { - override def getSparkShimVersion: ShimVersion = SparkShimServiceProvider.VERSION + override def getSparkShimVersion: ShimVersion = ShimLoader.getShimVersion override def getParquetFilters( schema: MessageType, diff --git a/sql-plugin/src/main/313/scala/com/nvidia/spark/rapids/shims/spark313/SparkShimServiceProvider.scala b/sql-plugin/src/main/313/scala/com/nvidia/spark/rapids/shims/spark313/SparkShimServiceProvider.scala index aef842331d1..646d4b3d9d2 100644 --- a/sql-plugin/src/main/313/scala/com/nvidia/spark/rapids/shims/spark313/SparkShimServiceProvider.scala +++ b/sql-plugin/src/main/313/scala/com/nvidia/spark/rapids/shims/spark313/SparkShimServiceProvider.scala @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids.shims.spark313 -import com.nvidia.spark.rapids.{SparkShims, SparkShimVersion} +import com.nvidia.spark.rapids.SparkShimVersion object SparkShimServiceProvider { val VERSION = SparkShimVersion(3, 1, 3) @@ -25,11 +25,9 @@ object SparkShimServiceProvider { class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider { + override def getShimVersion: SparkShimVersion = SparkShimServiceProvider.VERSION + def matchesVersion(version: String): Boolean = { SparkShimServiceProvider.VERSIONNAMES.contains(version) } - - def buildShim: SparkShims = { - new Spark313Shims() - } } diff --git a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/AQEUtils.scala b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/AQEUtils.scala similarity index 93% rename from 
sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/AQEUtils.scala rename to sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/AQEUtils.scala index df2aee9268c..da45d303472 100644 --- a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/AQEUtils.scala +++ b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/AQEUtils.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.adaptive.{QueryStageExec, ShuffleQueryStageExec} diff --git a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/GpuRegExpReplaceExec.scala similarity index 98% rename from sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala rename to sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/GpuRegExpReplaceExec.scala index c4ec686bf4f..caf92cbeedc 100644 --- a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuRegExpReplaceExec.scala +++ b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/GpuRegExpReplaceExec.scala @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{CudfRegexTranspiler, DataFromReplacementRule, GpuExpression, GpuOverrides, QuaternaryExprMeta, RapidsConf, RapidsMeta, RegexReplaceMode, RegexUnsupportedException} diff --git a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRunningWindowExec.scala b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/GpuRunningWindowExec.scala similarity index 94% rename from sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRunningWindowExec.scala rename to sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/GpuRunningWindowExec.scala index ff8cedf6aeb..f12c8208ec2 100644 --- a/sql-plugin/src/main/301db/scala/com/nvidia/spark/rapids/shims/v2/GpuRunningWindowExec.scala +++ b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/GpuRunningWindowExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.databricks.sql.execution.window.RunningWindowFunctionExec import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuBaseWindowExecMeta, RapidsConf, RapidsMeta} diff --git a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuWindowInPandasExec.scala b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/GpuWindowInPandasExec.scala similarity index 99% rename from sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuWindowInPandasExec.scala rename to sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/GpuWindowInPandasExec.scala index bf49267f2dc..eeef0f8685e 100644 --- a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/GpuWindowInPandasExec.scala +++ b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/GpuWindowInPandasExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import scala.collection.mutable.ArrayBuffer diff --git a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/ShimBroadcastExchangeLike.scala b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/ShimBroadcastExchangeLike.scala similarity index 92% rename from sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/ShimBroadcastExchangeLike.scala rename to sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/ShimBroadcastExchangeLike.scala index 8ee4ac10084..99551e21bfb 100644 --- a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/ShimBroadcastExchangeLike.scala +++ b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/ShimBroadcastExchangeLike.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import scala.concurrent.Promise diff --git a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/Spark31XdbShims.scala b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/Spark31XdbShims.scala similarity index 99% rename from sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/Spark31XdbShims.scala rename to sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/Spark31XdbShims.scala index 61a24f69de9..037718424e8 100644 --- a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/Spark31XdbShims.scala +++ b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/Spark31XdbShims.scala @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import java.net.URI import java.nio.ByteBuffer @@ -30,7 +30,7 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.internal.Logging -import org.apache.spark.rapids.shims.v2.GpuShuffleExchangeExec +import org.apache.spark.rapids.shims.GpuShuffleExchangeExec import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} @@ -59,8 +59,8 @@ import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.rapids._ import org.apache.spark.sql.rapids.execution.{GpuShuffleExchangeExecBase, SerializeBatchDeserializeHostBuffer, SerializeConcatHostBuffersDeserializeBatch} import org.apache.spark.sql.rapids.execution.python._ -import org.apache.spark.sql.rapids.execution.python.shims.v2._ -import org.apache.spark.sql.rapids.shims.v2._ +import org.apache.spark.sql.rapids.execution.python.shims._ +import org.apache.spark.sql.rapids.shims._ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ import org.apache.spark.storage.{BlockId, BlockManagerId} diff --git a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/Spark31XdbShimsBase.scala b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/Spark31XdbShimsBase.scala similarity index 99% rename from sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/Spark31XdbShimsBase.scala rename to sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/Spark31XdbShimsBase.scala index 51459ace3ed..da10622630e 100644 --- a/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/v2/Spark31XdbShimsBase.scala +++ b/sql-plugin/src/main/31xdb/scala/com/nvidia/spark/rapids/shims/Spark31XdbShimsBase.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import scala.collection.mutable.ListBuffer diff --git a/sql-plugin/src/main/31xdb/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala b/sql-plugin/src/main/31xdb/scala/org/apache/spark/rapids/shims/GpuShuffleExchangeExec.scala similarity index 97% rename from sql-plugin/src/main/31xdb/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala rename to sql-plugin/src/main/31xdb/scala/org/apache/spark/rapids/shims/GpuShuffleExchangeExec.scala index 07bcefe93c7..b261e15823e 100644 --- a/sql-plugin/src/main/31xdb/scala/org/apache/spark/rapids/shims/v2/GpuShuffleExchangeExec.scala +++ b/sql-plugin/src/main/31xdb/scala/org/apache/spark/rapids/shims/GpuShuffleExchangeExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.rapids.shims.v2 +package org.apache.spark.rapids.shims import scala.concurrent.Future diff --git a/sql-plugin/src/main/31xdb/scala/org/apache/spark/rapids/shims/v2/api/python/ShimBasePythonRunner.scala b/sql-plugin/src/main/31xdb/scala/org/apache/spark/rapids/shims/api/python/ShimBasePythonRunner.scala similarity index 94% rename from sql-plugin/src/main/31xdb/scala/org/apache/spark/rapids/shims/v2/api/python/ShimBasePythonRunner.scala rename to sql-plugin/src/main/31xdb/scala/org/apache/spark/rapids/shims/api/python/ShimBasePythonRunner.scala index f89d105ca9e..b3fbdb985b0 100644 --- a/sql-plugin/src/main/31xdb/scala/org/apache/spark/rapids/shims/v2/api/python/ShimBasePythonRunner.scala +++ b/sql-plugin/src/main/31xdb/scala/org/apache/spark/rapids/shims/api/python/ShimBasePythonRunner.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.apache.spark.rapids.shims.v2.api.python +package org.apache.spark.rapids.shims.api.python import java.io.DataInputStream import java.net.Socket diff --git a/sql-plugin/src/main/31xdb/scala/org/apache/spark/sql/rapids/execution/python/shims/v2/GpuFlatMapGroupsInPandasExec.scala b/sql-plugin/src/main/31xdb/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuFlatMapGroupsInPandasExec.scala similarity index 97% rename from sql-plugin/src/main/31xdb/scala/org/apache/spark/sql/rapids/execution/python/shims/v2/GpuFlatMapGroupsInPandasExec.scala rename to sql-plugin/src/main/31xdb/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuFlatMapGroupsInPandasExec.scala index 615d9aa84ed..4c35249b012 100644 --- a/sql-plugin/src/main/31xdb/scala/org/apache/spark/sql/rapids/execution/python/shims/v2/GpuFlatMapGroupsInPandasExec.scala +++ b/sql-plugin/src/main/31xdb/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuFlatMapGroupsInPandasExec.scala @@ -14,10 +14,11 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.rapids.execution.python.shims.v2 +package org.apache.spark.sql.rapids.execution.python.shims import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.python.PythonWorkerSemaphore +import com.nvidia.spark.rapids.shims.SparkShimImpl import org.apache.spark.TaskContext import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} @@ -95,7 +96,7 @@ case class GpuFlatMapGroupsInPandasExec( } override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(groupingAttributes.map(ShimLoader.getSparkShims.sortOrder(_, Ascending))) + Seq(groupingAttributes.map(SparkShimImpl.sortOrder(_, Ascending))) private val pandasFunction = func.asInstanceOf[GpuPythonUDF].func diff --git a/sql-plugin/src/main/31xdb/scala/org/apache/spark/sql/rapids/execution/python/shims/v2/GpuGroupUDFArrowPythonRunner.scala b/sql-plugin/src/main/31xdb/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupUDFArrowPythonRunner.scala similarity index 98% rename from sql-plugin/src/main/31xdb/scala/org/apache/spark/sql/rapids/execution/python/shims/v2/GpuGroupUDFArrowPythonRunner.scala rename to sql-plugin/src/main/31xdb/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupUDFArrowPythonRunner.scala index 02c096ba7c2..58fca943cb9 100644 --- a/sql-plugin/src/main/31xdb/scala/org/apache/spark/sql/rapids/execution/python/shims/v2/GpuGroupUDFArrowPythonRunner.scala +++ b/sql-plugin/src/main/31xdb/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupUDFArrowPythonRunner.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with @@ -17,7 +17,7 @@ * limitations under the License. */ -package org.apache.spark.sql.rapids.execution.python.shims.v2 +package org.apache.spark.sql.rapids.execution.python.shims import java.io.DataOutputStream import java.net.Socket diff --git a/sql-plugin/src/main/301db/scala/org/apache/spark/sql/rapids/shims/v2/GpuFileScanRDD.scala b/sql-plugin/src/main/31xdb/scala/org/apache/spark/sql/rapids/shims/GpuFileScanRDD.scala similarity index 98% rename from sql-plugin/src/main/301db/scala/org/apache/spark/sql/rapids/shims/v2/GpuFileScanRDD.scala rename to sql-plugin/src/main/31xdb/scala/org/apache/spark/sql/rapids/shims/GpuFileScanRDD.scala index d26a7c94654..bcaf8efee25 100644 --- a/sql-plugin/src/main/301db/scala/org/apache/spark/sql/rapids/shims/v2/GpuFileScanRDD.scala +++ b/sql-plugin/src/main/31xdb/scala/org/apache/spark/sql/rapids/shims/GpuFileScanRDD.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.spark.sql.rapids.shims.v2 +package org.apache.spark.sql.rapids.shims import java.io.{FileNotFoundException, IOException} diff --git a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/HashUtils.scala b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/HashUtils.scala similarity index 94% rename from sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/HashUtils.scala rename to sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/HashUtils.scala index a001870f4c9..9590dd909da 100644 --- a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/HashUtils.scala +++ b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/HashUtils.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import ai.rapids.cudf import com.nvidia.spark.rapids.{Arm, ColumnCastUtil} diff --git a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/OffsetWindowFunctionMeta.scala b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/OffsetWindowFunctionMeta.scala similarity index 97% rename from sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/OffsetWindowFunctionMeta.scala rename to sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/OffsetWindowFunctionMeta.scala index 96de2d34326..7bca685c440 100644 --- a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/OffsetWindowFunctionMeta.scala +++ b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/OffsetWindowFunctionMeta.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{BaseExprMeta, DataFromReplacementRule, ExprMeta, GpuOverrides, RapidsConf, RapidsMeta} diff --git a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/OrcShims.scala b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/OrcShims.scala similarity index 99% rename from sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/OrcShims.scala rename to sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/OrcShims.scala index c1ad560089b..ba1df94bb15 100644 --- a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/OrcShims.scala +++ b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/OrcShims.scala @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import scala.collection.mutable.ArrayBuffer diff --git a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/RapidsCsvScanMeta.scala b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/RapidsCsvScanMeta.scala similarity index 97% rename from sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/RapidsCsvScanMeta.scala rename to sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/RapidsCsvScanMeta.scala index afdb3c53f01..1a50887652e 100644 --- a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/RapidsCsvScanMeta.scala +++ b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/RapidsCsvScanMeta.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuCSVScan, RapidsConf, RapidsMeta, ScanMeta} diff --git a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/RebaseShims.scala b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/RebaseShims.scala similarity index 97% rename from sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/RebaseShims.scala rename to sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/RebaseShims.scala index 867cc996281..fe9b5dd3590 100644 --- a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/RebaseShims.scala +++ b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/RebaseShims.scala @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import org.apache.spark.sql.internal.SQLConf diff --git a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/ShimAQEShuffleReadExec.scala b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/ShimAQEShuffleReadExec.scala similarity index 96% rename from sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/ShimAQEShuffleReadExec.scala rename to sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/ShimAQEShuffleReadExec.scala index 5b4d016a766..fbad04a8c82 100644 --- a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/ShimAQEShuffleReadExec.scala +++ b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/ShimAQEShuffleReadExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids._ diff --git a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/ShimDataSourceRDD.scala b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/ShimDataSourceRDD.scala similarity index 92% rename from sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/ShimDataSourceRDD.scala rename to sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/ShimDataSourceRDD.scala index a69dc595a2d..5108aaa65ac 100644 --- a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/ShimDataSourceRDD.scala +++ b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/ShimDataSourceRDD.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import org.apache.spark.SparkContext import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory} diff --git a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/Spark320PlusShims.scala b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala similarity index 99% rename from sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/Spark320PlusShims.scala rename to sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala index 3e5b4a43f82..bdf5b03460f 100644 --- a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/Spark320PlusShims.scala +++ b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/Spark320PlusShims.scala @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import java.net.URI import java.nio.ByteBuffer @@ -32,7 +32,7 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging -import org.apache.spark.rapids.shims.v2.GpuShuffleExchangeExec +import org.apache.spark.rapids.shims.GpuShuffleExchangeExec import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.Resolver @@ -62,8 +62,8 @@ import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.rapids.{GpuAbs, GpuAnsi, GpuAverage, GpuElementAt, GpuFileSourceScanExec, GpuGetArrayItem, GpuGetArrayItemMeta, GpuGetMapValue, GpuGetMapValueMeta} import org.apache.spark.sql.rapids.execution._ import org.apache.spark.sql.rapids.execution.python._ -import org.apache.spark.sql.rapids.execution.python.shims.v2.GpuFlatMapGroupsInPandasExecMeta -import org.apache.spark.sql.rapids.shims.v2._ +import org.apache.spark.sql.rapids.execution.python.shims.GpuFlatMapGroupsInPandasExecMeta +import org.apache.spark.sql.rapids.shims._ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ import org.apache.spark.storage.{BlockId, BlockManagerId} diff --git a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/TypeSigUtil.scala b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/TypeSigUtil.scala similarity index 96% rename from sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/TypeSigUtil.scala rename to sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/TypeSigUtil.scala index 77b4f12fbf4..7ff493ec9d3 100644 --- a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/TypeSigUtil.scala +++ b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/TypeSigUtil.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{TypeEnum, TypeSig, TypeSigUtilBase} diff --git a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/YearParseUtil.scala b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/YearParseUtil.scala similarity index 92% rename from sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/YearParseUtil.scala rename to sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/YearParseUtil.scala index 5d7def4fa2f..f5107091b8c 100644 --- a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/YearParseUtil.scala +++ b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/YearParseUtil.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{RapidsConf, RapidsMeta} diff --git a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/gpuWindows.scala b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/gpuWindows.scala similarity index 94% rename from sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/gpuWindows.scala rename to sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/gpuWindows.scala index b84bbd1c8ac..2a11109fb84 100644 --- a/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/v2/gpuWindows.scala +++ b/sql-plugin/src/main/320+/scala/com/nvidia/spark/rapids/shims/gpuWindows.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,12 +14,12 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuLiteral, GpuSpecifiedWindowFrameMetaBase, GpuWindowExpressionMetaBase, ParsedBoundary, RapidsConf, RapidsMeta} import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SpecifiedWindowFrame, WindowExpression} -import org.apache.spark.sql.rapids.shims.v2.Spark32XShimsUtils +import org.apache.spark.sql.rapids.shims.Spark32XShimsUtils import org.apache.spark.sql.types.{DataType, DayTimeIntervalType} class GpuSpecifiedWindowFrameMeta( diff --git a/sql-plugin/src/main/320+/scala/org/apache/spark/rapids/shims/v2/GpuShuffleBlockResolver.scala b/sql-plugin/src/main/320+/scala/org/apache/spark/rapids/shims/GpuShuffleBlockResolver.scala similarity index 94% rename from sql-plugin/src/main/320+/scala/org/apache/spark/rapids/shims/v2/GpuShuffleBlockResolver.scala rename to sql-plugin/src/main/320+/scala/org/apache/spark/rapids/shims/GpuShuffleBlockResolver.scala index 03d1973382f..7ffad0c14e1 100644 --- a/sql-plugin/src/main/320+/scala/org/apache/spark/rapids/shims/v2/GpuShuffleBlockResolver.scala +++ b/sql-plugin/src/main/320+/scala/org/apache/spark/rapids/shims/GpuShuffleBlockResolver.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.apache.spark.sql.rapids.shims.v2 +package org.apache.spark.sql.rapids.shims import com.nvidia.spark.rapids.ShuffleBufferCatalog diff --git a/sql-plugin/src/main/320+/scala/org/apache/spark/rapids/shims/v2/ShuffledBatchRDDUtil.scala b/sql-plugin/src/main/320+/scala/org/apache/spark/rapids/shims/ShuffledBatchRDDUtil.scala similarity index 96% rename from sql-plugin/src/main/320+/scala/org/apache/spark/rapids/shims/v2/ShuffledBatchRDDUtil.scala rename to sql-plugin/src/main/320+/scala/org/apache/spark/rapids/shims/ShuffledBatchRDDUtil.scala index 1e0d193b67b..5fde7ba8b26 100644 --- a/sql-plugin/src/main/320+/scala/org/apache/spark/rapids/shims/v2/ShuffledBatchRDDUtil.scala +++ b/sql-plugin/src/main/320+/scala/org/apache/spark/rapids/shims/ShuffledBatchRDDUtil.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,9 +14,9 @@ * limitations under the License. */ -package org.apache.spark.rapids.shims.v2 +package org.apache.spark.rapids.shims -import com.nvidia.spark.rapids.ShimLoader +import com.nvidia.spark.rapids.shims.SparkShimImpl import org.apache.spark.{MapOutputTrackerMaster, Partition, ShuffleDependency, SparkEnv, TaskContext} import org.apache.spark.shuffle.ShuffleReader @@ -60,7 +60,7 @@ object ShuffledBatchRDDUtil { dependency: ShuffleDependency[Int, ColumnarBatch, ColumnarBatch], sqlMetricsReporter: SQLShuffleReadMetricsReporter): (ShuffleReader[Nothing, Nothing], Long) = { - val shim = ShimLoader.getSparkShims + val shim = SparkShimImpl split.asInstanceOf[ShuffledBatchRDDPartition].spec match { case CoalescedPartitionSpec(startReducerIndex, endReducerIndex, _) => val reader = SparkEnv.get.shuffleManager.getReader( diff --git a/sql-plugin/src/main/320+/scala/org/apache/spark/rapids/shims/v2/api/python/ShimBasePythonRunner.scala b/sql-plugin/src/main/320+/scala/org/apache/spark/rapids/shims/api/python/ShimBasePythonRunner.scala similarity index 94% rename from sql-plugin/src/main/320+/scala/org/apache/spark/rapids/shims/v2/api/python/ShimBasePythonRunner.scala rename to sql-plugin/src/main/320+/scala/org/apache/spark/rapids/shims/api/python/ShimBasePythonRunner.scala index b86a398953e..661c6eaf8a0 100644 --- a/sql-plugin/src/main/320+/scala/org/apache/spark/rapids/shims/v2/api/python/ShimBasePythonRunner.scala +++ b/sql-plugin/src/main/320+/scala/org/apache/spark/rapids/shims/api/python/ShimBasePythonRunner.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.apache.spark.rapids.shims.v2.api.python +package org.apache.spark.rapids.shims.api.python import java.io.DataInputStream import java.net.Socket diff --git a/sql-plugin/src/main/320+/scala/org/apache/spark/rapids/shims/v2/storage/ShimDiskBlockManager.scala b/sql-plugin/src/main/320+/scala/org/apache/spark/rapids/shims/storage/ShimDiskBlockManager.scala similarity index 90% rename from sql-plugin/src/main/320+/scala/org/apache/spark/rapids/shims/v2/storage/ShimDiskBlockManager.scala rename to sql-plugin/src/main/320+/scala/org/apache/spark/rapids/shims/storage/ShimDiskBlockManager.scala index 65cce2ead49..96a3fd6f96a 100644 --- a/sql-plugin/src/main/320+/scala/org/apache/spark/rapids/shims/v2/storage/ShimDiskBlockManager.scala +++ b/sql-plugin/src/main/320+/scala/org/apache/spark/rapids/shims/storage/ShimDiskBlockManager.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package org.apache.spark.rapids.shims.v2.storage +package org.apache.spark.rapids.shims.storage import org.apache.spark.SparkConf import org.apache.spark.sql.rapids.execution.TrampolineUtil diff --git a/sql-plugin/src/main/320+/scala/org/apache/spark/sql/rapids/shims/v2/Spark32XShimsUtils.scala b/sql-plugin/src/main/320+/scala/org/apache/spark/sql/rapids/shims/Spark32XShimsUtils.scala similarity index 93% rename from sql-plugin/src/main/320+/scala/org/apache/spark/sql/rapids/shims/v2/Spark32XShimsUtils.scala rename to sql-plugin/src/main/320+/scala/org/apache/spark/sql/rapids/shims/Spark32XShimsUtils.scala index fc2debd081c..727ca88840e 100644 --- a/sql-plugin/src/main/320+/scala/org/apache/spark/sql/rapids/shims/v2/Spark32XShimsUtils.scala +++ b/sql-plugin/src/main/320+/scala/org/apache/spark/sql/rapids/shims/Spark32XShimsUtils.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.apache.spark.sql.rapids.shims.v2 +package org.apache.spark.sql.rapids.shims import org.apache.spark.sql.SparkSession import org.apache.spark.sql.types.{CalendarIntervalType, DataType, DateType, DayTimeIntervalType, IntegerType, TimestampNTZType, TimestampType, YearMonthIntervalType} diff --git a/sql-plugin/src/main/320+/scala/org/apache/spark/sql/rapids/shims/v2/datetimeExpressions.scala b/sql-plugin/src/main/320+/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala similarity index 55% rename from sql-plugin/src/main/320+/scala/org/apache/spark/sql/rapids/shims/v2/datetimeExpressions.scala rename to sql-plugin/src/main/320+/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala index e811cb7bfbf..096b3a85ab7 100644 --- a/sql-plugin/src/main/320+/scala/org/apache/spark/sql/rapids/shims/v2/datetimeExpressions.scala +++ b/sql-plugin/src/main/320+/scala/org/apache/spark/sql/rapids/shims/datetimeExpressions.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,14 +14,14 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.rapids.shims.v2 +package org.apache.spark.sql.rapids.shims import java.util.concurrent.TimeUnit -import ai.rapids.cudf.{ColumnVector, ColumnView, DType, Scalar} +import ai.rapids.cudf.{BinaryOp, BinaryOperable, ColumnVector, ColumnView, DType, Scalar} import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuScalar} import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.shims.v2.ShimBinaryExpression +import com.nvidia.spark.rapids.shims.ShimBinaryExpression import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, TimeZoneAwareExpression} import org.apache.spark.sql.types._ @@ -59,48 +59,59 @@ case class GpuTimeAdd(start: Expression, override def columnarEval(batch: ColumnarBatch): Any = { withResourceIfAllowed(left.columnarEval(batch)) { lhs => withResourceIfAllowed(right.columnarEval(batch)) { rhs => + // lhs is start, rhs is interval (lhs, rhs) match { - case (l: GpuColumnVector, intvlS: GpuScalar) => - val interval = intvlS.dataType match { + case (l: GpuColumnVector, intervalS: GpuScalar) => + // get long type interval + val interval = intervalS.dataType match { case CalendarIntervalType => // Scalar does not support 'CalendarInterval' now, so use // the Scala value instead. // Skip the null check because it will be detected by the following calls. - val intvl = intvlS.getValue.asInstanceOf[CalendarInterval] - if (intvl.months != 0) { + val calendarI = intervalS.getValue.asInstanceOf[CalendarInterval] + if (calendarI.months != 0) { throw new UnsupportedOperationException("Months aren't supported at the moment") } - intvl.days * microSecondsInOneDay + intvl.microseconds + calendarI.days * microSecondsInOneDay + calendarI.microseconds case _: DayTimeIntervalType => - // Scalar does not support 'DayTimeIntervalType' now, so use - // the Scala value instead. - intvlS.getValue.asInstanceOf[Long] + intervalS.getValue.asInstanceOf[Long] case _ => - throw new UnsupportedOperationException("GpuTimeAdd unsupported data type: " + - intvlS.dataType) + throw new UnsupportedOperationException( + "GpuTimeAdd unsupported data type: " + intervalS.dataType) } + // add interval if (interval != 0) { - withResource(Scalar.fromLong(interval)) { us_s => - withResource(l.getBase.bitCastTo(DType.INT64)) { us => - withResource(intervalMath(us_s, us)) { longResult => - GpuColumnVector.from(longResult.castTo(DType.TIMESTAMP_MICROSECONDS), - dataType) - } - } + withResource(Scalar.durationFromLong(DType.DURATION_MICROSECONDS, interval)) { d => + GpuColumnVector.from(timestampAddDuration(l.getBase, d), dataType) } } else { l.incRefCount() } + case (l: GpuColumnVector, r: GpuColumnVector) => + (l.dataType(), r.dataType) match { + case (_: TimestampType, _: DayTimeIntervalType) => + // DayTimeIntervalType is stored as long + // bitCastTo is similar to reinterpret_cast; it's fast, so its cost can be ignored.
+ withResource(r.getBase.bitCastTo(DType.DURATION_MICROSECONDS)) { duration => + GpuColumnVector.from(timestampAddDuration(l.getBase, duration), dataType) + } + case _ => + throw new UnsupportedOperationException( + "GpuTimeAdd takes column and interval as an argument only") + } case _ => - throw new UnsupportedOperationException("GpuTimeAdd takes column and interval as an " + - "argument only") + throw new UnsupportedOperationException( + "GpuTimeAdd takes column and interval as an argument only") } } } } - private def intervalMath(us_s: Scalar, us: ColumnView): ColumnVector = { - us.add(us_s) + private def timestampAddDuration(cv: ColumnView, duration: BinaryOperable): ColumnVector = { + // Do not use cv.add(duration), because it invokes BinaryOperable.implicitConversion, + // and currently BinaryOperable.implicitConversion returns Long. + // Directly specify the return type as TIMESTAMP_MICROSECONDS. + cv.binaryOp(BinaryOp.ADD, duration, DType.TIMESTAMP_MICROSECONDS) } } diff --git a/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/spark320/Spark320Shims.scala b/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/SparkShims.scala similarity index 89% rename from sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/spark320/Spark320Shims.scala rename to sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/SparkShims.scala index e49eff933a1..fa184143355 100644 --- a/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/spark320/Spark320Shims.scala +++ b/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/SparkShims.scala @@ -14,10 +14,9 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.spark320 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.shims.v2._ import org.apache.parquet.schema.MessageType import org.apache.spark.rdd.RDD @@ -28,8 +27,8 @@ import org.apache.spark.sql.execution.datasources.{DataSourceUtils, FilePartitio import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters import org.apache.spark.sql.types.StructType -class Spark320Shims extends Spark320PlusShims with Spark30Xuntil33XShims with RebaseShims { - override def getSparkShimVersion: ShimVersion = SparkShimServiceProvider.VERSION +object SparkShimImpl extends Spark320PlusShims with Spark30Xuntil33XShims { + override def getSparkShimVersion: ShimVersion = ShimLoader.getShimVersion override def getFileScanRDD( sparkSession: SparkSession, diff --git a/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/spark320/SparkShimServiceProvider.scala b/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/spark320/SparkShimServiceProvider.scala index 93d2a8c327d..4117e0b58db 100644 --- a/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/spark320/SparkShimServiceProvider.scala +++ b/sql-plugin/src/main/320/scala/com/nvidia/spark/rapids/shims/spark320/SparkShimServiceProvider.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
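For reference, a minimal standalone sketch of what the new timestampAddDuration helper computes, built only from cudf calls that appear in this patch (Scalar.durationFromLong and ColumnView.binaryOp with an explicit result type); the object and method names below are illustrative, not part of the plugin:

import ai.rapids.cudf.{BinaryOp, ColumnVector, DType, Scalar}

object TimestampAddSketch {
  // Adds a microsecond interval to a TIMESTAMP_MICROSECONDS column. The result
  // type is passed explicitly because BinaryOperable.implicitConversion would
  // otherwise resolve timestamp + duration to a plain INT64 column.
  def addMicros(ts: ColumnVector, micros: Long): ColumnVector = {
    val d = Scalar.durationFromLong(DType.DURATION_MICROSECONDS, micros)
    try {
      ts.binaryOp(BinaryOp.ADD, d, DType.TIMESTAMP_MICROSECONDS)
    } finally {
      d.close()
    }
  }
}

At the SQL level this is roughly the path taken by expressions such as ts + INTERVAL '1 12:30:15' DAY TO SECOND, where the interval arrives either as a CalendarInterval literal or as a DayTimeIntervalType column.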
@@ -16,7 +16,7 @@ package com.nvidia.spark.rapids.shims.spark320 -import com.nvidia.spark.rapids.{SparkShims, SparkShimVersion} +import com.nvidia.spark.rapids.SparkShimVersion object SparkShimServiceProvider { val VERSION = SparkShimVersion(3, 2, 0) @@ -25,11 +25,9 @@ object SparkShimServiceProvider { class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider { + override def getShimVersion: SparkShimVersion = SparkShimServiceProvider.VERSION + def matchesVersion(version: String): Boolean = { SparkShimServiceProvider.VERSIONNAMES.contains(version) } - - def buildShim: SparkShims = { - new Spark320Shims() - } } diff --git a/sql-plugin/src/main/320/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/v2/ShimVectorizedColumnReader.scala b/sql-plugin/src/main/320/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/ShimVectorizedColumnReader.scala similarity index 99% rename from sql-plugin/src/main/320/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/v2/ShimVectorizedColumnReader.scala rename to sql-plugin/src/main/320/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/ShimVectorizedColumnReader.scala index 490fc5d2749..2eca3e4485a 100644 --- a/sql-plugin/src/main/320/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/v2/ShimVectorizedColumnReader.scala +++ b/sql-plugin/src/main/320/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/ShimVectorizedColumnReader.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.parquet.rapids.shims.v2 +package org.apache.spark.sql.execution.datasources.parquet.rapids.shims import java.time.ZoneId diff --git a/sql-plugin/src/main/320until330-all/scala/com/nvidia/spark/rapids/shims/v2/RapidsOrcScanMeta.scala b/sql-plugin/src/main/320until330-all/scala/com/nvidia/spark/rapids/shims/RapidsOrcScanMeta.scala similarity index 97% rename from sql-plugin/src/main/320until330-all/scala/com/nvidia/spark/rapids/shims/v2/RapidsOrcScanMeta.scala rename to sql-plugin/src/main/320until330-all/scala/com/nvidia/spark/rapids/shims/RapidsOrcScanMeta.scala index ed4484f1b33..f4cce986029 100644 --- a/sql-plugin/src/main/320until330-all/scala/com/nvidia/spark/rapids/shims/v2/RapidsOrcScanMeta.scala +++ b/sql-plugin/src/main/320until330-all/scala/com/nvidia/spark/rapids/shims/RapidsOrcScanMeta.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuOrcScanBase, RapidsConf, RapidsMeta, ScanMeta} diff --git a/sql-plugin/src/main/320until330-all/scala/com/nvidia/spark/rapids/shims/v2/RapidsParquetScanMeta.scala b/sql-plugin/src/main/320until330-all/scala/com/nvidia/spark/rapids/shims/RapidsParquetScanMeta.scala similarity index 97% rename from sql-plugin/src/main/320until330-all/scala/com/nvidia/spark/rapids/shims/v2/RapidsParquetScanMeta.scala rename to sql-plugin/src/main/320until330-all/scala/com/nvidia/spark/rapids/shims/RapidsParquetScanMeta.scala index 9831a10ca54..07293e61c0d 100644 --- a/sql-plugin/src/main/320until330-all/scala/com/nvidia/spark/rapids/shims/v2/RapidsParquetScanMeta.scala +++ b/sql-plugin/src/main/320until330-all/scala/com/nvidia/spark/rapids/shims/RapidsParquetScanMeta.scala @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuParquetScanBase, RapidsConf, RapidsMeta, ScanMeta} diff --git a/sql-plugin/src/main/321+/scala/com/nvidia/spark/rapids/shims/v2/Spark321PlusShims.scala b/sql-plugin/src/main/321+/scala/com/nvidia/spark/rapids/shims/Spark321PlusShims.scala similarity index 97% rename from sql-plugin/src/main/321+/scala/com/nvidia/spark/rapids/shims/v2/Spark321PlusShims.scala rename to sql-plugin/src/main/321+/scala/com/nvidia/spark/rapids/shims/Spark321PlusShims.scala index 7d6f7a710f5..a423d2d85e2 100644 --- a/sql-plugin/src/main/321+/scala/com/nvidia/spark/rapids/shims/v2/Spark321PlusShims.scala +++ b/sql-plugin/src/main/321+/scala/com/nvidia/spark/rapids/shims/Spark321PlusShims.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import org.apache.parquet.schema.MessageType diff --git a/sql-plugin/src/main/321+/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/v2/ShimVectorizedColumnReader.scala b/sql-plugin/src/main/321+/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/ShimVectorizedColumnReader.scala similarity index 99% rename from sql-plugin/src/main/321+/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/v2/ShimVectorizedColumnReader.scala rename to sql-plugin/src/main/321+/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/ShimVectorizedColumnReader.scala index 121bf0bc4bb..dcfbd319744 100644 --- a/sql-plugin/src/main/321+/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/v2/ShimVectorizedColumnReader.scala +++ b/sql-plugin/src/main/321+/scala/org/apache/spark/sql/execution/datasources/parquet/rapids/shims/ShimVectorizedColumnReader.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.datasources.parquet.rapids.shims.v2 +package org.apache.spark.sql.execution.datasources.parquet.rapids.shims import java.time.ZoneId import java.util.TimeZone diff --git a/sql-plugin/src/main/321/scala/com/nvidia/spark/rapids/shims/spark321/Spark321Shims.scala b/sql-plugin/src/main/321/scala/com/nvidia/spark/rapids/shims/SparkShims.scala similarity index 84% rename from sql-plugin/src/main/321/scala/com/nvidia/spark/rapids/shims/spark321/Spark321Shims.scala rename to sql-plugin/src/main/321/scala/com/nvidia/spark/rapids/shims/SparkShims.scala index 6f577d9cdc2..255b58f1921 100644 --- a/sql-plugin/src/main/321/scala/com/nvidia/spark/rapids/shims/spark321/Spark321Shims.scala +++ b/sql-plugin/src/main/321/scala/com/nvidia/spark/rapids/shims/SparkShims.scala @@ -14,10 +14,9 @@ * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.spark321 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.shims.v2._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession @@ -26,8 +25,8 @@ import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile} import org.apache.spark.sql.types.StructType -class Spark321Shims extends Spark321PlusShims with Spark30Xuntil33XShims { - override def getSparkShimVersion: ShimVersion = SparkShimServiceProvider.VERSION +object SparkShimImpl extends Spark321PlusShims with Spark30Xuntil33XShims { + override def getSparkShimVersion: ShimVersion = ShimLoader.getShimVersion override def getFileScanRDD( sparkSession: SparkSession, diff --git a/sql-plugin/src/main/321/scala/com/nvidia/spark/rapids/shims/spark321/SparkShimServiceProvider.scala b/sql-plugin/src/main/321/scala/com/nvidia/spark/rapids/shims/spark321/SparkShimServiceProvider.scala index 51f34e5b11d..9cac4a8f366 100644 --- a/sql-plugin/src/main/321/scala/com/nvidia/spark/rapids/shims/spark321/SparkShimServiceProvider.scala +++ b/sql-plugin/src/main/321/scala/com/nvidia/spark/rapids/shims/spark321/SparkShimServiceProvider.scala @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids.shims.spark321 -import com.nvidia.spark.rapids.{SparkShims, SparkShimVersion} +import com.nvidia.spark.rapids.SparkShimVersion object SparkShimServiceProvider { val VERSION = SparkShimVersion(3, 2, 1) @@ -25,11 +25,9 @@ object SparkShimServiceProvider { class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider { + override def getShimVersion: SparkShimVersion = SparkShimServiceProvider.VERSION + def matchesVersion(version: String): Boolean = { SparkShimServiceProvider.VERSIONNAMES.contains(version) } - - def buildShim: SparkShims = { - new Spark321Shims() - } } diff --git a/sql-plugin/src/main/311cdh/scala/com/nvidia/spark/rapids/shims/spark311cdh/Spark311CDHShims.scala b/sql-plugin/src/main/322+/scala/com/nvidia/spark/rapids/shims/Spark322PlusShims.scala similarity index 71% rename from sql-plugin/src/main/311cdh/scala/com/nvidia/spark/rapids/shims/spark311cdh/Spark311CDHShims.scala rename to sql-plugin/src/main/322+/scala/com/nvidia/spark/rapids/shims/Spark322PlusShims.scala index ba527ecad93..8c7a55c7d96 100644 --- a/sql-plugin/src/main/311cdh/scala/com/nvidia/spark/rapids/shims/spark311cdh/Spark311CDHShims.scala +++ b/sql-plugin/src/main/322+/scala/com/nvidia/spark/rapids/shims/Spark322PlusShims.scala @@ -14,20 +14,18 @@ * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.spark311cdh +package com.nvidia.spark.rapids.shims -import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.shims.v2._ import org.apache.parquet.schema.MessageType +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters -class Spark311CDHShims extends Spark31XShims with Spark30Xuntil33XShims { - - override def getSparkShimVersion: ShimVersion = SparkShimServiceProvider.VERSION - - override def hasCastFloatTimestampUpcast: Boolean = false - +/** + * Shim base class that can be compiled with every supported Spark version 3.2.2+ + */ +trait Spark322PlusShims extends Spark320PlusShims with RebaseShims with Logging { override def getParquetFilters( schema: MessageType, pushDownDate: Boolean, @@ -38,7 +36,9 @@ class Spark311CDHShims extends Spark31XShims with Spark30Xuntil33XShims { caseSensitive: Boolean, lookupFileMeta: String => String, dateTimeRebaseModeFromConf: String): ParquetFilters = { + val datetimeRebaseMode = DataSourceUtils + .datetimeRebaseSpec(lookupFileMeta, dateTimeRebaseModeFromConf) new ParquetFilters(schema, pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStartWith, - pushDownInFilterThreshold, caseSensitive) + pushDownInFilterThreshold, caseSensitive, datetimeRebaseMode) } -} +} \ No newline at end of file diff --git a/sql-plugin/src/main/322/scala/com/nvidia/spark/rapids/shims/spark322/Spark322Shims.scala b/sql-plugin/src/main/322/scala/com/nvidia/spark/rapids/shims/SparkShims.scala similarity index 84% rename from sql-plugin/src/main/322/scala/com/nvidia/spark/rapids/shims/spark322/Spark322Shims.scala rename to sql-plugin/src/main/322/scala/com/nvidia/spark/rapids/shims/SparkShims.scala index eede090a3d5..255b58f1921 100644 --- a/sql-plugin/src/main/322/scala/com/nvidia/spark/rapids/shims/spark322/Spark322Shims.scala +++ b/sql-plugin/src/main/322/scala/com/nvidia/spark/rapids/shims/SparkShims.scala @@ -14,10 +14,9 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.spark322 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.shims.v2._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession @@ -26,8 +25,8 @@ import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile} import org.apache.spark.sql.types.StructType -class Spark322Shims extends Spark321PlusShims with Spark30Xuntil33XShims { - override def getSparkShimVersion: ShimVersion = SparkShimServiceProvider.VERSION +object SparkShimImpl extends Spark321PlusShims with Spark30Xuntil33XShims { + override def getSparkShimVersion: ShimVersion = ShimLoader.getShimVersion override def getFileScanRDD( sparkSession: SparkSession, diff --git a/sql-plugin/src/main/322/scala/com/nvidia/spark/rapids/shims/spark322/SparkShimServiceProvider.scala b/sql-plugin/src/main/322/scala/com/nvidia/spark/rapids/shims/spark322/SparkShimServiceProvider.scala index 3b979042d59..66d962b0fc6 100644 --- a/sql-plugin/src/main/322/scala/com/nvidia/spark/rapids/shims/spark322/SparkShimServiceProvider.scala +++ b/sql-plugin/src/main/322/scala/com/nvidia/spark/rapids/shims/spark322/SparkShimServiceProvider.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION.
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids.shims.spark322 -import com.nvidia.spark.rapids.{SparkShims, SparkShimVersion} +import com.nvidia.spark.rapids.SparkShimVersion object SparkShimServiceProvider { val VERSION = SparkShimVersion(3, 2, 2) @@ -25,11 +25,9 @@ object SparkShimServiceProvider { class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider { + override def getShimVersion: SparkShimVersion = SparkShimServiceProvider.VERSION + def matchesVersion(version: String): Boolean = { SparkShimServiceProvider.VERSIONNAMES.contains(version) } - - def buildShim: SparkShims = { - new Spark322Shims() - } } diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/AnsiCheckUtil.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/AnsiCheckUtil.scala similarity index 97% rename from sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/AnsiCheckUtil.scala rename to sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/AnsiCheckUtil.scala index 5d33bf42755..eb7ea2d551e 100644 --- a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/AnsiCheckUtil.scala +++ b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/AnsiCheckUtil.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import java.time.DateTimeException diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/GpuHashPartitioning.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuHashPartitioning.scala similarity index 97% rename from sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/GpuHashPartitioning.scala rename to sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuHashPartitioning.scala index f4b077225da..a8642abe1ea 100644 --- a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/GpuHashPartitioning.scala +++ b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuHashPartitioning.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.GpuHashPartitioningBase diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/GpuRangePartitioning.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuRangePartitioning.scala similarity index 98% rename from sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/GpuRangePartitioning.scala rename to sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuRangePartitioning.scala index a05ebf2e3d3..3cbaf0c158e 100644 --- a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/GpuRangePartitioning.scala +++ b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuRangePartitioning.scala @@ -14,7 +14,7 @@ * limitations under the License. 
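The spark320, spark321, and spark322 changes above all follow the same pattern: the buildShim factory is gone, each per-version SparkShimServiceProvider only reports its SparkShimVersion and matches version strings, and the shim behavior itself lives in the statically compiled SparkShimImpl object of the matching build. A rough sketch of the version resolution this enables; the VersionedProvider trait and resolveShimVersion helper are hypothetical stand-ins, since the actual ShimLoader wiring is not part of this diff:

import com.nvidia.spark.rapids.SparkShimVersion

// Hypothetical stand-in for the provider contract shown in the diffs above.
trait VersionedProvider {
  def getShimVersion: SparkShimVersion
  def matchesVersion(version: String): Boolean
}

// Hypothetical resolution step: pick the provider whose matchesVersion accepts
// the running Spark version and report its shim version.
def resolveShimVersion(
    providers: Seq[VersionedProvider],
    sparkVersion: String): SparkShimVersion =
  providers.find(_.matchesVersion(sparkVersion))
    .map(_.getShimVersion)
    .getOrElse(throw new IllegalStateException(s"no shim provider matches $sparkVersion"))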
*/ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{GpuExpression, GpuPartitioning} diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala new file mode 100644 index 00000000000..ddca4f2a9ca --- /dev/null +++ b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/GpuTypeShims.scala @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.nvidia.spark.rapids.shims + +import ai.rapids.cudf.DType +import com.nvidia.spark.rapids.GpuRowToColumnConverter.{LongConverter, NotNullLongConverter, TypeConverter} + +import org.apache.spark.sql.types.{DataType, DayTimeIntervalType} + +/** + * Spark stores ANSI YearMonthIntervalType as int32 and ANSI DayTimeIntervalType as int64 + * internally when computing. + * See the comments of YearMonthIntervalType; the following is copied from Spark: + * Internally, values of year-month intervals are stored in `Int` values as amount of months + * that are calculated by the formula: + * -/+ (12 * YEAR + MONTH) + * See the comments of DayTimeIntervalType; the following is copied from Spark: + * Internally, values of day-time intervals are stored in `Long` values as amount of time in terms + * of microseconds that are calculated by the formula: + * -/+ (24*60*60 * DAY + 60*60 * HOUR + 60 * MINUTE + SECOND) * 1000000 + * + * Spark also stores ANSI intervals as int32 and int64 in Parquet file: + * - year-month intervals as `INT32` + * - day-time intervals as `INT64` + * To load the values as intervals back, Spark puts the info about interval types + * to the extra key `org.apache.spark.sql.parquet.row.metadata`: + * $ java -jar parquet-tools-1.12.0.jar meta ./part-...-c000.snappy.parquet + * creator: parquet-mr version 1.12.1 (build 2a5c06c58fa987f85aa22170be14d927d5ff6e7d) + * extra: org.apache.spark.version = 3.3.0 + * extra: org.apache.spark.sql.parquet.row.metadata = + * {"type":"struct","fields":[..., + * {"name":"i","type":"interval year to month","nullable":false,"metadata":{}}]} + * file schema: spark_schema + * -------------------------------------------------------------------------------- + * ... + * i: REQUIRED INT32 R:0 D:0 + * + * For details see https://issues.apache.org/jira/browse/SPARK-36825 + */ +object GpuTypeShims { + + /** + * Whether the Shim supports the given data type for the row-to-column converter + * @param otherType the data type that should be checked in the Shim + * @return true if the Shim supports otherType, false otherwise.
+ */ + def hasConverterForType(otherType: DataType) : Boolean = { + otherType match { + case DayTimeIntervalType(_, _) => true + case _ => false + } + } + + /** + * Get the TypeConverter of the data type for this Shim + * Note should first calling hasConverterForType + * @param t the data type + * @param nullable is nullable + * @return the row to column convert for the data type + */ + def getConverterForType(t: DataType, nullable: Boolean): TypeConverter = { + (t, nullable) match { + case (DayTimeIntervalType(_, _), true) => LongConverter + case (DayTimeIntervalType(_, _), false) => NotNullLongConverter + case _ => throw new RuntimeException(s"No converter is found for type $t.") + } + } + + /** + * Get the cuDF type for the Spark data type + * @param t the Spark data type + * @return the cuDF type if the Shim supports + */ + def toRapidsOrNull(t: DataType): DType = { + t match { + case _: DayTimeIntervalType => + // use int64 as Spark does + DType.INT64 + case _ => + null + } + } +} diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/ParquetFieldIdShims.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/ParquetFieldIdShims.scala similarity index 97% rename from sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/ParquetFieldIdShims.scala rename to sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/ParquetFieldIdShims.scala index 8b1e1f2297e..491e1397f18 100644 --- a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/ParquetFieldIdShims.scala +++ b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/ParquetFieldIdShims.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.RapidsMeta import org.apache.hadoop.conf.Configuration diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/RapidsErrorUtils.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala similarity index 97% rename from sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/RapidsErrorUtils.scala rename to sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala index 1f904f4a7aa..76b0c40182b 100644 --- a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/RapidsErrorUtils.scala +++ b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/RapidsErrorUtils.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import org.apache.spark.sql.errors.QueryExecutionErrors diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/RapidsOrcScanMeta.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/RapidsOrcScanMeta.scala similarity index 97% rename from sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/RapidsOrcScanMeta.scala rename to sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/RapidsOrcScanMeta.scala index 68a6308f473..afa40fe1534 100644 --- a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/RapidsOrcScanMeta.scala +++ b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/RapidsOrcScanMeta.scala @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuOrcScanBase, RapidsConf, RapidsMeta, ScanMeta} diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/RapidsParquetScanMeta.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/RapidsParquetScanMeta.scala similarity index 97% rename from sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/RapidsParquetScanMeta.scala rename to sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/RapidsParquetScanMeta.scala index f7478370c8e..db34d6cad78 100644 --- a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/RapidsParquetScanMeta.scala +++ b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/RapidsParquetScanMeta.scala @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids.{DataFromReplacementRule, GpuParquetScanBase, RapidsConf, RapidsMeta, ScanMeta} diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark33XShims.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark33XShims.scala new file mode 100644 index 00000000000..3e9982ea258 --- /dev/null +++ b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/Spark33XShims.scala @@ -0,0 +1,278 @@ +/* + * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.shims + +import com.nvidia.spark.rapids._ +import org.apache.parquet.schema.MessageType + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Coalesce, DynamicPruningExpression, Expression, MetadataAttribute, TimeAdd} +import org.apache.spark.sql.catalyst.json.rapids.shims.Spark33XFileOptionsShims +import org.apache.spark.sql.execution.{BaseSubqueryExec, CoalesceExec, FileSourceScanExec, InSubqueryExec, ProjectExec, ReusedSubqueryExec, SparkPlan, SubqueryBroadcastExec} +import org.apache.spark.sql.execution.command.DataWritingCommandExec +import org.apache.spark.sql.execution.datasources.{DataSourceUtils, FilePartition, FileScanRDD, HadoopFsRelation, PartitionedFile} +import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.rapids.GpuFileSourceScanExec +import org.apache.spark.sql.rapids.shims.GpuTimeAdd +import org.apache.spark.sql.types.{CalendarIntervalType, DayTimeIntervalType, StructType} +import org.apache.spark.unsafe.types.CalendarInterval + +trait Spark33XShims extends Spark33XFileOptionsShims { + + /** + * For Spark 3.3+, optionally return null if the element does not exist.
+ */ + override def shouldFailOnElementNotExists(): Boolean = SQLConf.get.strictIndexOperator + + override def neverReplaceShowCurrentNamespaceCommand: ExecRule[_ <: SparkPlan] = null + + override def getFileScanRDD( + sparkSession: SparkSession, + readFunction: PartitionedFile => Iterator[InternalRow], + filePartitions: Seq[FilePartition], + readDataSchema: StructType, + metadataColumns: Seq[AttributeReference]): RDD[InternalRow] = { + new FileScanRDD(sparkSession, readFunction, filePartitions, readDataSchema, metadataColumns) + } + + override def getParquetFilters( + schema: MessageType, + pushDownDate: Boolean, + pushDownTimestamp: Boolean, + pushDownDecimal: Boolean, + pushDownStartWith: Boolean, + pushDownInFilterThreshold: Int, + caseSensitive: Boolean, + lookupFileMeta: String => String, + dateTimeRebaseModeFromConf: String): ParquetFilters = { + val datetimeRebaseMode = DataSourceUtils + .datetimeRebaseSpec(lookupFileMeta, dateTimeRebaseModeFromConf) + new ParquetFilters(schema, pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStartWith, + pushDownInFilterThreshold, caseSensitive, datetimeRebaseMode) + } + + override def tagFileSourceScanExec(meta: SparkPlanMeta[FileSourceScanExec]): Unit = { + if (meta.wrapped.expressions.exists(expr => expr match { + case MetadataAttribute(expr) => true + case _ => false + })) { + meta.willNotWorkOnGpu("hidden metadata columns are not supported on GPU") + } + super.tagFileSourceScanExec(meta) + } + + // 330+ supports DAYTIME interval types + override def getFileFormats: Map[FileFormatType, Map[FileFormatOp, FileFormatChecks]] = { + Map( + (ParquetFormatType, FileFormatChecks( + cudfRead = (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.STRUCT + + TypeSig.ARRAY + TypeSig.MAP + TypeSig.DAYTIME).nested(), + cudfWrite = (TypeSig.commonCudfTypes + TypeSig.DECIMAL_128 + TypeSig.STRUCT + + TypeSig.ARRAY + TypeSig.MAP + TypeSig.DAYTIME).nested(), + sparkSig = (TypeSig.cpuAtomics + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + + TypeSig.UDT + TypeSig.DAYTIME).nested()))) + } + + // 330+ supports DAYTIME interval types + override def getExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = { + val _gpuCommonTypes = TypeSig.commonCudfTypes + TypeSig.NULL + val map: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq( + GpuOverrides.expr[Coalesce]( + "Returns the first non-null argument if exists. Otherwise, null", + ExprChecks.projectOnly( + (_gpuCommonTypes + TypeSig.DECIMAL_128 + TypeSig.ARRAY + TypeSig.STRUCT + + TypeSig.DAYTIME).nested(), + TypeSig.all, + repeatingParamCheck = Some(RepeatingParamCheck("param", + (_gpuCommonTypes + TypeSig.DECIMAL_128 + TypeSig.ARRAY + TypeSig.STRUCT + + TypeSig.DAYTIME).nested(), + TypeSig.all))), + (a, conf, p, r) => new ExprMeta[Coalesce](a, conf, p, r) { + override def convertToGpu(): + GpuExpression = GpuCoalesce(childExprs.map(_.convertToGpu())) + }), + GpuOverrides.expr[AttributeReference]( + "References an input column", + ExprChecks.projectAndAst( + TypeSig.astTypes, + (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.MAP + TypeSig.ARRAY + + TypeSig.STRUCT + TypeSig.DECIMAL_128 + TypeSig.DAYTIME).nested(), + TypeSig.all), + (att, conf, p, r) => new BaseExprMeta[AttributeReference](att, conf, p, r) { + // This is the only NOOP operator. 
It goes away when things are bound + override def convertToGpu(): Expression = att + + // There are so many of these that we don't need to print them out, unless it + // will not work on the GPU + override def print(append: StringBuilder, depth: Int, all: Boolean): Unit = { + if (!this.canThisBeReplaced || cannotRunOnGpuBecauseOfSparkPlan) { + super.print(append, depth, all) + } + } + }), + GpuOverrides.expr[TimeAdd]( + "Adds interval to timestamp", + ExprChecks.binaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP, + ("start", TypeSig.TIMESTAMP, TypeSig.TIMESTAMP), + // interval support DAYTIME column or CALENDAR literal + ("interval", TypeSig.DAYTIME + TypeSig.lit(TypeEnum.CALENDAR) + .withPsNote(TypeEnum.CALENDAR, "month intervals are not supported"), + TypeSig.DAYTIME + TypeSig.CALENDAR)), + (timeAdd, conf, p, r) => new BinaryExprMeta[TimeAdd](timeAdd, conf, p, r) { + override def tagExprForGpu(): Unit = { + GpuOverrides.extractLit(timeAdd.interval).foreach { lit => + lit.dataType match { + case CalendarIntervalType => + val intvl = lit.value.asInstanceOf[CalendarInterval] + if (intvl.months != 0) { + willNotWorkOnGpu("interval months isn't supported") + } + case _: DayTimeIntervalType => // Supported + } + } + checkTimeZoneId(timeAdd.timeZoneId) + } + + override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = + GpuTimeAdd(lhs, rhs) + }) + ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap + super.getExprs ++ map + } + + // 330+ supports DAYTIME interval types + override def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = { + val _gpuCommonTypes = TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_64 + val map: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = Seq( + GpuOverrides.exec[CoalesceExec]( + "The backend for the dataframe coalesce method", + ExecChecks((_gpuCommonTypes + TypeSig.DECIMAL_128 + TypeSig.STRUCT + TypeSig.ARRAY + + TypeSig.MAP + TypeSig.DAYTIME).nested(), + TypeSig.all), + (coalesce, conf, parent, r) => new SparkPlanMeta[CoalesceExec](coalesce, conf, parent, r) { + override def convertToGpu(): GpuExec = + GpuCoalesceExec(coalesce.numPartitions, childPlans.head.convertIfNeeded()) + }), + GpuOverrides.exec[DataWritingCommandExec]( + "Writing data", + ExecChecks((TypeSig.commonCudfTypes + TypeSig.DECIMAL_128.withPsNote( + TypeEnum.DECIMAL, "128bit decimal only supported for Orc and Parquet") + + TypeSig.STRUCT.withPsNote(TypeEnum.STRUCT, "Only supported for Parquet") + + TypeSig.MAP.withPsNote(TypeEnum.MAP, "Only supported for Parquet") + + TypeSig.ARRAY.withPsNote(TypeEnum.ARRAY, "Only supported for Parquet") + + TypeSig.DAYTIME).nested(), + TypeSig.all), + (p, conf, parent, r) => new SparkPlanMeta[DataWritingCommandExec](p, conf, parent, r) { + override val childDataWriteCmds: scala.Seq[DataWritingCommandMeta[_]] = + Seq(GpuOverrides.wrapDataWriteCmds(p.cmd, conf, Some(this))) + + override def convertToGpu(): GpuExec = + GpuDataWritingCommandExec(childDataWriteCmds.head.convertToGpu(), + childPlans.head.convertIfNeeded()) + }), + // this is copied, only added TypeSig.DAYTIME check + GpuOverrides.exec[FileSourceScanExec]( + "Reading data from files, often from Hive tables", + ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP + + TypeSig.ARRAY + TypeSig.DECIMAL_128 + TypeSig.DAYTIME).nested(), + TypeSig.all), + (fsse, conf, p, r) => new SparkPlanMeta[FileSourceScanExec](fsse, conf, p, r) { + + // Replaces SubqueryBroadcastExec inside dynamic pruning filters with GPU 
counterpart + // if possible. Instead of regarding filters as childExprs of the current Meta, we create + // a new meta for SubqueryBroadcastExec. The reason is that the GPU replacement of + // FileSourceScan is independent of the replacement of the partitionFilters. It is + // possible that the FileSourceScan is on the CPU, while the dynamic partitionFilters + // are on the GPU. And vice versa. + private lazy val partitionFilters = { + val convertBroadcast = (bc: SubqueryBroadcastExec) => { + val meta = GpuOverrides.wrapAndTagPlan(bc, conf) + meta.tagForExplain() + meta.convertIfNeeded().asInstanceOf[BaseSubqueryExec] + } + wrapped.partitionFilters.map { filter => + filter.transformDown { + case dpe@DynamicPruningExpression(inSub: InSubqueryExec) => + inSub.plan match { + case bc: SubqueryBroadcastExec => + dpe.copy(inSub.copy(plan = convertBroadcast(bc))) + case reuse@ReusedSubqueryExec(bc: SubqueryBroadcastExec) => + dpe.copy(inSub.copy(plan = reuse.copy(convertBroadcast(bc)))) + case _ => + dpe + } + } + } + } + + // partition filters and data filters are not run on the GPU + override val childExprs: Seq[ExprMeta[_]] = Seq.empty + + override def tagPlanForGpu(): Unit = tagFileSourceScanExec(this) + + override def convertToCpu(): SparkPlan = { + wrapped.copy(partitionFilters = partitionFilters) + } + + override def convertToGpu(): GpuExec = { + val sparkSession = wrapped.relation.sparkSession + val options = wrapped.relation.options + + val location = replaceWithAlluxioPathIfNeeded( + conf, + wrapped.relation, + partitionFilters, + wrapped.dataFilters) + + val newRelation = HadoopFsRelation( + location, + wrapped.relation.partitionSchema, + wrapped.relation.dataSchema, + wrapped.relation.bucketSpec, + GpuFileSourceScanExec.convertFileFormat(wrapped.relation.fileFormat), + options)(sparkSession) + + GpuFileSourceScanExec( + newRelation, + wrapped.output, + wrapped.requiredSchema, + partitionFilters, + wrapped.optionalBucketSet, + wrapped.optionalNumCoalescedBuckets, + wrapped.dataFilters, + wrapped.tableIdentifier, + wrapped.disableBucketedScan)(conf) + } + }), + GpuOverrides.exec[ProjectExec]( + "The backend for most select, withColumn and dropColumn statements", + ExecChecks( + (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.MAP + + TypeSig.ARRAY + TypeSig.DECIMAL_128 + TypeSig.DAYTIME).nested(), + TypeSig.all), + (proj, conf, p, r) => new GpuProjectExecMeta(proj, conf, p, r)) + ).map(r => (r.getClassFor.asSubclass(classOf[SparkPlan]), r)).toMap + super.getExecs ++ map + } + +} + +// Fallback to the default definition of `deterministic` +trait GpuDeterministicFirstLastCollectShim extends Expression diff --git a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/Spark33XShims.scala b/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/Spark33XShims.scala deleted file mode 100644 index 1e3bd34597c..00000000000 --- a/sql-plugin/src/main/330+/scala/com/nvidia/spark/rapids/shims/v2/Spark33XShims.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.nvidia.spark.rapids.shims.v2 - -import com.nvidia.spark.rapids._ -import org.apache.parquet.schema.MessageType - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, MetadataAttribute} -import org.apache.spark.sql.catalyst.json.rapids.shims.v2.Spark33XFileOptionsShims -import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan} -import org.apache.spark.sql.execution.datasources.{DataSourceUtils, FilePartition, FileScanRDD, PartitionedFile} -import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.StructType - - -trait Spark33XShims extends Spark33XFileOptionsShims { - - /** - * For spark3.3+ optionally return null if element not exists. - */ - override def shouldFailOnElementNotExists(): Boolean = SQLConf.get.strictIndexOperator - - override def neverReplaceShowCurrentNamespaceCommand: ExecRule[_ <: SparkPlan] = null - - override def getFileScanRDD( - sparkSession: SparkSession, - readFunction: PartitionedFile => Iterator[InternalRow], - filePartitions: Seq[FilePartition], - readDataSchema: StructType, - metadataColumns: Seq[AttributeReference]): RDD[InternalRow] = { - new FileScanRDD(sparkSession, readFunction, filePartitions, readDataSchema, metadataColumns) - } - - override def getParquetFilters( - schema: MessageType, - pushDownDate: Boolean, - pushDownTimestamp: Boolean, - pushDownDecimal: Boolean, - pushDownStartWith: Boolean, - pushDownInFilterThreshold: Int, - caseSensitive: Boolean, - lookupFileMeta: String => String, - dateTimeRebaseModeFromConf: String): ParquetFilters = { - val datetimeRebaseMode = DataSourceUtils - .datetimeRebaseSpec(lookupFileMeta, dateTimeRebaseModeFromConf) - new ParquetFilters(schema, pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStartWith, - pushDownInFilterThreshold, caseSensitive, datetimeRebaseMode) - } - - override def tagFileSourceScanExec(meta: SparkPlanMeta[FileSourceScanExec]): Unit = { - if (meta.wrapped.expressions.exists(expr => expr match { - case MetadataAttribute(expr) => true - case _ => false - })) { - meta.willNotWorkOnGpu("hidden metadata columns are not supported on GPU") - } - super.tagFileSourceScanExec(meta) - } -} - -// Fallback to the default definition of `deterministic` -trait GpuDeterministicFirstLastCollectShim extends Expression diff --git a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/catalyst/json/rapids/shims/v2/Spark33XFileOptionsShims.scala b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/catalyst/json/rapids/shims/Spark33XFileOptionsShims.scala similarity index 90% rename from sql-plugin/src/main/330+/scala/org/apache/spark/sql/catalyst/json/rapids/shims/v2/Spark33XFileOptionsShims.scala rename to sql-plugin/src/main/330+/scala/org/apache/spark/sql/catalyst/json/rapids/shims/Spark33XFileOptionsShims.scala index 5c1636a994b..092f86a6504 100644 --- a/sql-plugin/src/main/330+/scala/org/apache/spark/sql/catalyst/json/rapids/shims/v2/Spark33XFileOptionsShims.scala +++ b/sql-plugin/src/main/330+/scala/org/apache/spark/sql/catalyst/json/rapids/shims/Spark33XFileOptionsShims.scala @@ -14,9 +14,9 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.catalyst.json.rapids.shims.v2 +package org.apache.spark.sql.catalyst.json.rapids.shims -import com.nvidia.spark.rapids.shims.v2.Spark321PlusShims +import com.nvidia.spark.rapids.shims.Spark321PlusShims import org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.catalyst.json.JSONOptions diff --git a/sql-plugin/src/main/330/scala/com/nvidia/spark/rapids/shims/spark330/Spark330Shims.scala b/sql-plugin/src/main/330/scala/com/nvidia/spark/rapids/SparkShims.scala similarity index 75% rename from sql-plugin/src/main/330/scala/com/nvidia/spark/rapids/shims/spark330/Spark330Shims.scala rename to sql-plugin/src/main/330/scala/com/nvidia/spark/rapids/SparkShims.scala index 9efa0d7c8b0..5dc59c10ea1 100644 --- a/sql-plugin/src/main/330/scala/com/nvidia/spark/rapids/shims/spark330/Spark330Shims.scala +++ b/sql-plugin/src/main/330/scala/com/nvidia/spark/rapids/SparkShims.scala @@ -14,11 +14,10 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.spark330 +package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.shims.v2._ -class Spark330Shims extends Spark33XShims { - override def getSparkShimVersion: ShimVersion = SparkShimServiceProvider.VERSION +object SparkShimImpl extends Spark33XShims { + override def getSparkShimVersion: ShimVersion = ShimLoader.getShimVersion } diff --git a/sql-plugin/src/main/330/scala/com/nvidia/spark/rapids/shims/spark330/SparkShimServiceProvider.scala b/sql-plugin/src/main/330/scala/com/nvidia/spark/rapids/shims/spark330/SparkShimServiceProvider.scala index e6e2e179e8c..547643aae83 100644 --- a/sql-plugin/src/main/330/scala/com/nvidia/spark/rapids/shims/spark330/SparkShimServiceProvider.scala +++ b/sql-plugin/src/main/330/scala/com/nvidia/spark/rapids/shims/spark330/SparkShimServiceProvider.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
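Note on the two shim hunks above: the per-version `Spark330Shims` class gives way to a compile-time `object SparkShimImpl`, and the service provider only reports its version. A hedged, caller-side sketch of what that buys (not part of this diff; it assumes the shim classes for the running Spark version are on the classpath):

```scala
// Hedged sketch: each shim "parallel world" compiles its own
// com.nvidia.spark.rapids.shims.SparkShimImpl, so version-dependent calls
// bind statically instead of going through ShimLoader.getSparkShims.
import com.nvidia.spark.rapids.shims.SparkShimImpl

object ShimVersionProbe {
  // Prints e.g. SparkShimVersion(3, 3, 0) when the spark330 classes are loaded.
  def probe(): Unit = println(SparkShimImpl.getSparkShimVersion)
}
```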
@@ -16,7 +16,7 @@ package com.nvidia.spark.rapids.shims.spark330 -import com.nvidia.spark.rapids.{SparkShims, SparkShimVersion} +import com.nvidia.spark.rapids.SparkShimVersion object SparkShimServiceProvider { val VERSION = SparkShimVersion(3, 3, 0) @@ -25,11 +25,9 @@ object SparkShimServiceProvider { class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider { + override def getShimVersion: SparkShimVersion = SparkShimServiceProvider.VERSION + def matchesVersion(version: String): Boolean = { SparkShimServiceProvider.VERSIONNAMES.contains(version) } - - def buildShim: SparkShims = { - new Spark330Shims() - } } diff --git a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java index 6c1fc8e945d..cf1405ccb89 100644 --- a/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java +++ b/sql-plugin/src/main/java/com/nvidia/spark/rapids/GpuColumnVector.java @@ -25,8 +25,9 @@ import ai.rapids.cudf.Scalar; import ai.rapids.cudf.Schema; import ai.rapids.cudf.Table; - +import com.nvidia.spark.rapids.shims.GpuTypeShims; import org.apache.arrow.memory.ReferenceManager; + import org.apache.spark.sql.catalyst.expressions.Attribute; import org.apache.spark.sql.types.*; import org.apache.spark.sql.vectorized.ColumnVector; @@ -460,6 +461,13 @@ public void releaseReferences() { } private static DType toRapidsOrNull(DataType type) { + DType ret = toRapidsOrNullCommon(type); + // Check for types that the shim supports + // e.g.: Spark 3.3.0 began supporting AnsiIntervalType to/from Parquet + return (ret != null) ? ret : GpuTypeShims.toRapidsOrNull(type); + } + + private static DType toRapidsOrNullCommon(DataType type) { if (type instanceof LongType) { return DType.INT64; } else if (type instanceof DoubleType) { diff --git a/sql-plugin/src/main/post320-treenode/scala/com/nvidia/spark/rapids/shims/v2/TreeNode.scala b/sql-plugin/src/main/post320-treenode/scala/com/nvidia/spark/rapids/shims/TreeNode.scala similarity index 96% rename from sql-plugin/src/main/post320-treenode/scala/com/nvidia/spark/rapids/shims/v2/TreeNode.scala rename to sql-plugin/src/main/post320-treenode/scala/com/nvidia/spark/rapids/shims/TreeNode.scala index 04d64b043d9..769833c2ac0 100644 --- a/sql-plugin/src/main/post320-treenode/scala/com/nvidia/spark/rapids/shims/v2/TreeNode.scala +++ b/sql-plugin/src/main/post320-treenode/scala/com/nvidia/spark/rapids/shims/TreeNode.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, TernaryExpression, UnaryExpression} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryCommand} diff --git a/sql-plugin/src/main/pre320-treenode/scala/com/nvidia/spark/rapids/shims/v2/TreeNode.scala b/sql-plugin/src/main/pre320-treenode/scala/com/nvidia/spark/rapids/shims/TreeNode.scala similarity index 94% rename from sql-plugin/src/main/pre320-treenode/scala/com/nvidia/spark/rapids/shims/v2/TreeNode.scala rename to sql-plugin/src/main/pre320-treenode/scala/com/nvidia/spark/rapids/shims/TreeNode.scala index 13e1d74d791..41ae3e51087 100644 --- a/sql-plugin/src/main/pre320-treenode/scala/com/nvidia/spark/rapids/shims/v2/TreeNode.scala +++ b/sql-plugin/src/main/pre320-treenode/scala/com/nvidia/spark/rapids/shims/TreeNode.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.nvidia.spark.rapids.shims.v2 +package com.nvidia.spark.rapids.shims import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, TernaryExpression, UnaryExpression} import org.apache.spark.sql.catalyst.plans.logical.Command diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/RebaseHelper.scala b/sql-plugin/src/main/scala/com/nvidia/spark/RebaseHelper.scala index f2d739c8eba..9411a692f25 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/RebaseHelper.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/RebaseHelper.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,8 @@ package com.nvidia.spark import ai.rapids.cudf.{ColumnVector, DType, Scalar} -import com.nvidia.spark.rapids.{Arm, ShimLoader} +import com.nvidia.spark.rapids.Arm +import com.nvidia.spark.rapids.shims.SparkShimImpl import org.apache.spark.sql.catalyst.util.RebaseDateTime import org.apache.spark.sql.rapids.execution.TrampolineUtil @@ -75,9 +76,9 @@ object RebaseHelper extends Arm { def newRebaseExceptionInRead(format: String): Exception = { val config = if (format == "Parquet") { - ShimLoader.getSparkShims.parquetRebaseReadKey + SparkShimImpl.parquetRebaseReadKey } else if (format == "Avro") { - ShimLoader.getSparkShims.avroRebaseReadKey + SparkShimImpl.avroRebaseReadKey } else { throw new IllegalStateException("unrecognized format " + format) } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/CostBasedOptimizer.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/CostBasedOptimizer.scala index 9d896fe5355..702cfda7d58 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/CostBasedOptimizer.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/CostBasedOptimizer.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
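The `newRebaseExceptionInRead` hunk above dispatches on the file format to pick the shim-provided SQL config key. A minimal standalone sketch of that dispatch; the key strings here are assumptions for illustration, not values taken from this diff:

```scala
object RebaseKeySketch {
  // Hypothetical stand-ins for SparkShimImpl.parquetRebaseReadKey and
  // SparkShimImpl.avroRebaseReadKey.
  def rebaseReadKeyFor(format: String): String = format match {
    case "Parquet" => "spark.sql.parquet.datetimeRebaseModeInRead" // assumed name
    case "Avro" => "spark.sql.avro.datetimeRebaseModeInRead" // assumed name
    case other => throw new IllegalStateException("unrecognized format " + other)
  }
}
```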
@@ -18,6 +18,8 @@ package com.nvidia.spark.rapids import scala.collection.mutable.ListBuffer +import com.nvidia.spark.rapids.shims.SparkShimImpl + import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, GetStructField, WindowFrame, WindowSpecDefinition} import org.apache.spark.sql.catalyst.plans.{JoinType, LeftAnti, LeftSemi} @@ -260,7 +262,7 @@ class CostBasedOptimizer extends Optimizer with Logging { private def isExchangeOp(plan: SparkPlanMeta[_]): Boolean = { // if the child query stage already executed on GPU then we need to keep the // next operator on GPU in these cases - ShimLoader.getSparkShims.isExchangeOp(plan) + SparkShimImpl.isExchangeOp(plan) } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuApproximatePercentile.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuApproximatePercentile.scala index 07f6e50cedc..f70068cfb39 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuApproximatePercentile.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuApproximatePercentile.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ package com.nvidia.spark.rapids import ai.rapids.cudf import ai.rapids.cudf.GroupByAggregation import com.nvidia.spark.rapids.GpuCast.doCast -import com.nvidia.spark.rapids.shims.v2.ShimExpression +import com.nvidia.spark.rapids.shims.ShimExpression import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchScanExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchScanExec.scala index 5845c975838..3f83ec05cff 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchScanExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBatchScanExec.scala @@ -22,6 +22,7 @@ import scala.collection.JavaConverters._ import ai.rapids.cudf import ai.rapids.cudf.{ColumnVector, DType, HostMemoryBuffer, Scalar, Schema, Table} +import com.nvidia.spark.rapids.shims.SparkShimImpl import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -230,7 +231,7 @@ object GpuCSVScan { if (!TypeChecks.areTimestampsSupported(parsedOptions.zoneId)) { meta.willNotWorkOnGpu("Only UTC zone id is supported") } - ShimLoader.getSparkShims.timestampFormatInRead(parsedOptions).foreach { tsFormat => + SparkShimImpl.timestampFormatInRead(parsedOptions).foreach { tsFormat => val parts = tsFormat.split("'T'", 2) if (parts.isEmpty) { meta.willNotWorkOnGpu(s"the timestamp format '$tsFormat' is not supported") diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBoundAttribute.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBoundAttribute.scala index 484b371cc99..44ba329053d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBoundAttribute.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBoundAttribute.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
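For the `timestampFormatInRead` check in the GpuBatchScanExec hunk above, splitting on the literal `'T'` separates the date and time patterns so each half can be validated on its own. A small worksheet-style illustration; the format string is an example, not one taken from this diff:

```scala
val tsFormat = "yyyy-MM-dd'T'HH:mm:ss" // example timestamp read format
val parts = tsFormat.split("'T'", 2)
// parts(0) == "yyyy-MM-dd" (date half), parts(1) == "HH:mm:ss" (time half)
```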
@@ -17,7 +17,7 @@ package com.nvidia.spark.rapids import ai.rapids.cudf.ast -import com.nvidia.spark.rapids.shims.v2.ShimExpression +import com.nvidia.spark.rapids.shims.ShimExpression import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSeq, Expression, ExprId, SortOrder} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBringBackToHost.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBringBackToHost.scala index 8777ea0a865..50788a95e47 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBringBackToHost.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBringBackToHost.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ package com.nvidia.spark.rapids import ai.rapids.cudf.{NvtxColor, NvtxRange} -import com.nvidia.spark.rapids.shims.v2.ShimUnaryExecNode +import com.nvidia.spark.rapids.shims.ShimUnaryExecNode import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBroadcastHashJoinExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBroadcastHashJoinExec.scala index 42ca17467c0..a456b92ff28 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBroadcastHashJoinExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuBroadcastHashJoinExec.scala @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids -import com.nvidia.spark.rapids.shims.v2.{GpuJoinUtils, ShimBinaryExecNode} +import com.nvidia.spark.rapids.shims.{GpuJoinUtils, ShimBinaryExecNode} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala index ad115d48501..84414232302 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCast.scala @@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DecimalUtils, DType, Scalar} import ai.rapids.cudf import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.shims.v2.{AnsiCheckUtil, YearParseUtil} +import com.nvidia.spark.rapids.shims.{AnsiCheckUtil, SparkShimImpl, YearParseUtil} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.{Cast, CastBase, Expression, NullIntolerant, TimeZoneAwareExpression} @@ -52,7 +52,7 @@ final class CastExprMeta[INPUT <: CastBase]( val fromType: DataType = cast.child.dataType val toType: DataType = toTypeOverride.getOrElse(cast.dataType) - val legacyCastToString: Boolean = ShimLoader.getSparkShims.getLegacyComplexTypeToString() + val legacyCastToString: Boolean = SparkShimImpl.getLegacyComplexTypeToString() override def tagExprForGpu(): Unit = recursiveTagExprForGpuCheck() @@ -116,7 +116,7 @@ final class CastExprMeta[INPUT <: CastBase]( case (_: StringType, _: DateType) => YearParseUtil.tagParseStringAsDate(conf, this) case (_: StringType, dt:DecimalType) => - if (dt.scale < 0 && !ShimLoader.getSparkShims.isCastingStringToNegDecimalScaleSupported) { + if (dt.scale < 0 && 
!SparkShimImpl.isCastingStringToNegDecimalScaleSupported) { willNotWorkOnGpu("RAPIDS doesn't support casting string to decimal for " + "negative scale decimal in this version of Spark because of SPARK-37451") } @@ -449,7 +449,7 @@ object GpuCast extends Arm { withResource(FloatUtils.infinityToNulls(inputWithNansToNull)) { inputWithoutNanAndInfinity => if (fromDataType == FloatType && - ShimLoader.getSparkShims.hasCastFloatTimestampUpcast) { + SparkShimImpl.hasCastFloatTimestampUpcast) { withResource(inputWithoutNanAndInfinity.castTo(DType.FLOAT64)) { doubles => withResource(doubles.mul(microsPerSec, DType.INT64)) { inputTimesMicrosCv => diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala index 26b1c0c4116..8e238060f6c 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuCoalesceBatches.scala @@ -20,7 +20,7 @@ import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{Cuda, NvtxColor, Table} import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.shims.v2.{ShimExpression, ShimUnaryExecNode} +import com.nvidia.spark.rapids.shims.{ShimExpression, ShimUnaryExecNode} import org.apache.spark.TaskContext import org.apache.spark.internal.Logging diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuColumnarToRowExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuColumnarToRowExec.scala index 7ffe9c3d14d..72bec24ddbb 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuColumnarToRowExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuColumnarToRowExec.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.Queue import ai.rapids.cudf.{HostColumnVector, NvtxColor, Table} import com.nvidia.spark.rapids.GpuColumnarToRowExecParent.makeIteratorFunc -import com.nvidia.spark.rapids.shims.v2.ShimUnaryExecNode +import com.nvidia.spark.rapids.shims.ShimUnaryExecNode import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDataSourceRDD.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDataSourceRDD.scala index 454d9d06ca6..d04d469cbb8 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDataSourceRDD.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDataSourceRDD.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids -import com.nvidia.spark.rapids.shims.v2.ShimDataSourceRDD +import com.nvidia.spark.rapids.shims.ShimDataSourceRDD import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, SparkException, TaskContext} import org.apache.spark.sql.catalyst.InternalRow diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDataWritingCommandExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDataWritingCommandExec.scala index 92a7c591e32..6906f3b2f03 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDataWritingCommandExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuDataWritingCommandExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. 
+ * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,7 +18,7 @@ package com.nvidia.spark.rapids import java.net.URI -import com.nvidia.spark.rapids.shims.v2.{ShimUnaryCommand, ShimUnaryExecNode} +import com.nvidia.spark.rapids.shims.{ShimUnaryCommand, ShimUnaryExecNode} import org.apache.hadoop.conf.Configuration import org.apache.spark.rdd.RDD diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala index a85949ed60c..fb4d12e8dff 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExec.scala @@ -19,6 +19,7 @@ package com.nvidia.spark.rapids import ai.rapids.cudf.NvtxColor import com.nvidia.spark.RebaseHelper.withResource import com.nvidia.spark.rapids.StorageTier.{DEVICE, DISK, GDS, HOST, StorageTier} +import com.nvidia.spark.rapids.shims.SparkShimImpl import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession @@ -210,7 +211,7 @@ object GpuExec { trait GpuExec extends SparkPlan with Arm { import GpuMetric._ def sparkSession: SparkSession = { - ShimLoader.getSparkShims.sessionFromPlan(this) + SparkShimImpl.sessionFromPlan(this) } /** @@ -316,7 +317,7 @@ trait GpuExec extends SparkPlan with Arm { // normalize that for equality testing, by assigning expr id from 0 incrementally. The // alias name doesn't matter and should be erased. val normalizedChild = QueryPlan.normalizeExpressions(a.child, allAttributes) - ShimLoader.getSparkShims.alias(normalizedChild, "")(ExprId(id), a.qualifier) + SparkShimImpl.alias(normalizedChild, "")(ExprId(id), a.qualifier) case a: GpuAlias => id += 1 // As the root of the expression, Alias will always take an arbitrary exprId, we need to diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpandExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpandExec.scala index 2e4730cc45c..2250d5c2f02 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpandExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpandExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
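The GpuExec hunk above normalizes aliases by reassigning expression ids from 0 upward so that two semantically equal plans canonicalize identically. A simplified sketch of the idea with stand-in types, not the plugin's `Alias`/`ExprId` classes:

```scala
// Stand-in for a named expression with an arbitrary id.
final case class FakeAlias(name: String, exprId: Long)

object CanonicalizeSketch {
  // Erase names and renumber ids from 0 so equality ignores the arbitrary
  // ids the aliases were created with.
  def canonicalize(aliases: Seq[FakeAlias]): Seq[FakeAlias] =
    aliases.zipWithIndex.map { case (_, i) => FakeAlias("", i.toLong) }
}
```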
@@ -20,7 +20,7 @@ import scala.collection.mutable import ai.rapids.cudf.NvtxColor import com.nvidia.spark.rapids.GpuMetric._ import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.shims.v2.ShimUnaryExecNode +import com.nvidia.spark.rapids.shims.ShimUnaryExecNode import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpressions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpressions.scala index f8e2c5ebe68..e0d1b2fb135 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpressions.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuExpressions.scala @@ -18,7 +18,7 @@ package com.nvidia.spark.rapids import ai.rapids.cudf.{ast, BinaryOp, BinaryOperable, ColumnVector, DType, Scalar, UnaryOp} import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.shims.v2.{ShimBinaryExpression, ShimExpression, ShimTernaryExpression, ShimUnaryExpression} +import com.nvidia.spark.rapids.shims.{ShimBinaryExpression, ShimExpression, ShimTernaryExpression, ShimUnaryExpression} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGenerateExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGenerateExec.scala index 906675e4640..a59ac484bf2 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGenerateExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGenerateExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ package com.nvidia.spark.rapids import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{ColumnVector, ContiguousTable, NvtxColor, Table} -import com.nvidia.spark.rapids.shims.v2.{ShimExpression, ShimUnaryExecNode} +import com.nvidia.spark.rapids.shims.{ShimExpression, ShimUnaryExecNode} import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuHashPartitioningBase.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuHashPartitioningBase.scala index 9ede14a8c35..d1a49c1d6f4 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuHashPartitioningBase.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuHashPartitioningBase.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
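The `ShimExpression`/`ShimUnaryExecNode` imports being renamed throughout these hunks are intermediate traits that insulate operators from base-class differences across Spark versions. A rough stand-in sketch of the pattern, with made-up types and method names:

```scala
// Stand-ins only; the real traits live under com.nvidia.spark.rapids.shims
// and wrap Spark's Expression/SparkPlan hierarchies.
trait VersionedBase {
  def withNewChildrenSketch(children: Seq[String]): VersionedBase
}

// A per-version source tree supplies this trait, so the operator below
// compiles unchanged against every supported Spark version.
trait ShimExpressionLike extends VersionedBase

class MyGpuOperator extends ShimExpressionLike {
  override def withNewChildrenSketch(children: Seq[String]): VersionedBase = this
}
```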
@@ -17,7 +17,7 @@ package com.nvidia.spark.rapids import ai.rapids.cudf.{DType, NvtxColor, NvtxRange} -import com.nvidia.spark.rapids.shims.v2.ShimExpression +import com.nvidia.spark.rapids.shims.ShimExpression import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.rapids.GpuMurmur3Hash diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuKryoRegistrator.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuKryoRegistrator.scala index 6c84a490e1f..3006d636138 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuKryoRegistrator.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuKryoRegistrator.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,11 +17,12 @@ package com.nvidia.spark.rapids import com.esotericsoftware.kryo.Kryo +import com.nvidia.spark.rapids.shims.SparkShimImpl import org.apache.spark.serializer.KryoRegistrator class GpuKryoRegistrator extends KryoRegistrator { override def registerClasses(kryo: Kryo): Unit = { - ShimLoader.getSparkShims.registerKryoClasses(kryo) + SparkShimImpl.registerKryoClasses(kryo) } } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScanBase.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScanBase.scala index 14afd362602..4087a9af8c4 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScanBase.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcScanBase.scala @@ -36,7 +36,7 @@ import com.google.protobuf.CodedOutputStream import com.nvidia.spark.rapids.GpuMetric._ import com.nvidia.spark.rapids.RapidsPluginImplicits._ import com.nvidia.spark.rapids.SchemaUtils._ -import com.nvidia.spark.rapids.shims.v2.OrcShims +import com.nvidia.spark.rapids.shims.OrcShims import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.common.io.DiskRangeList diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala index e25a1f08ffb..1ff52c7046d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala @@ -24,7 +24,7 @@ import scala.util.control.NonFatal import ai.rapids.cudf.DType import com.nvidia.spark.rapids.RapidsConf.{SUPPRESS_PLANNING_FAILURE, TEST_CONF} -import com.nvidia.spark.rapids.shims.v2.{AQEUtils, GpuHashPartitioning, GpuRangePartitioning, GpuSpecifiedWindowFrameMeta, GpuWindowExpressionMeta, OffsetWindowFunctionMeta} +import com.nvidia.spark.rapids.shims.{AQEUtils, GpuHashPartitioning, GpuRangePartitioning, GpuSpecifiedWindowFrameMeta, GpuWindowExpressionMeta, OffsetWindowFunctionMeta, SparkShimImpl} import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, SparkSession} @@ -60,7 +60,7 @@ import org.apache.spark.sql.rapids._ import org.apache.spark.sql.rapids.catalyst.expressions.GpuRand import org.apache.spark.sql.rapids.execution._ import org.apache.spark.sql.rapids.execution.python._ -import org.apache.spark.sql.rapids.shims.v2.GpuTimeAdd +import org.apache.spark.sql.rapids.shims.GpuTimeAdd import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -802,7 +802,7 @@ object 
GpuOverrides extends Logging { .map(r => r.wrap(expr, conf, parent, r).asInstanceOf[BaseExprMeta[INPUT]]) .getOrElse(new RuleNotFoundExprMeta(expr, conf, parent)) - lazy val fileFormats: Map[FileFormatType, Map[FileFormatOp, FileFormatChecks]] = Map( + lazy val basicFormats: Map[FileFormatType, Map[FileFormatOp, FileFormatChecks]] = Map( (CsvFormatType, FileFormatChecks( cudfRead = TypeSig.commonCudfTypes + TypeSig.DECIMAL_128, cudfWrite = TypeSig.none, @@ -828,6 +828,8 @@ object GpuOverrides extends Logging { sparkSig = (TypeSig.cpuAtomics + TypeSig.STRUCT + TypeSig.ARRAY + TypeSig.MAP + TypeSig.UDT).nested()))) + lazy val fileFormats = basicFormats ++ SparkShimImpl.getFileFormats + val commonExpressions: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq( expr[Literal]( "Holds a static value from the query", @@ -1720,7 +1722,7 @@ object GpuOverrides extends Logging { TypeSig.STRING)), (a, conf, p, r) => new UnixTimeExprMeta[ToUnixTimestamp](a, conf, p, r) { override def shouldFallbackOnAnsiTimestamp: Boolean = - ShimLoader.getSparkShims.shouldFallbackOnAnsiTimestamp + SparkShimImpl.shouldFallbackOnAnsiTimestamp override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = { if (conf.isImprovedTimestampOpsEnabled) { @@ -1742,7 +1744,7 @@ object GpuOverrides extends Logging { TypeSig.STRING)), (a, conf, p, r) => new UnixTimeExprMeta[UnixTimestamp](a, conf, p, r) { override def shouldFallbackOnAnsiTimestamp: Boolean = - ShimLoader.getSparkShims.shouldFallbackOnAnsiTimestamp + SparkShimImpl.shouldFallbackOnAnsiTimestamp override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = { if (conf.isImprovedTimestampOpsEnabled) { @@ -3218,7 +3220,7 @@ object GpuOverrides extends Logging { Seq(ParamCheck("input", TypeSig.DOUBLE, TypeSig.DOUBLE))), (a, conf, p, r) => new AggExprMeta[StddevPop](a, conf, p, r) { override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = { - val legacyStatisticalAggregate = ShimLoader.getSparkShims.getLegacyStatisticalAggregate + val legacyStatisticalAggregate = SparkShimImpl.getLegacyStatisticalAggregate GpuStddevPop(childExprs.head, !legacyStatisticalAggregate) } }), @@ -3230,7 +3232,7 @@ object GpuOverrides extends Logging { TypeSig.DOUBLE))), (a, conf, p, r) => new AggExprMeta[StddevSamp](a, conf, p, r) { override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = { - val legacyStatisticalAggregate = ShimLoader.getSparkShims.getLegacyStatisticalAggregate + val legacyStatisticalAggregate = SparkShimImpl.getLegacyStatisticalAggregate GpuStddevSamp(childExprs.head, !legacyStatisticalAggregate) } }), @@ -3241,7 +3243,7 @@ object GpuOverrides extends Logging { Seq(ParamCheck("input", TypeSig.DOUBLE, TypeSig.DOUBLE))), (a, conf, p, r) => new AggExprMeta[VariancePop](a, conf, p, r) { override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = { - val legacyStatisticalAggregate = ShimLoader.getSparkShims.getLegacyStatisticalAggregate + val legacyStatisticalAggregate = SparkShimImpl.getLegacyStatisticalAggregate GpuVariancePop(childExprs.head, !legacyStatisticalAggregate) } }), @@ -3252,7 +3254,7 @@ object GpuOverrides extends Logging { Seq(ParamCheck("input", TypeSig.DOUBLE, TypeSig.DOUBLE))), (a, conf, p, r) => new AggExprMeta[VarianceSamp](a, conf, p, r) { override def convertToGpu(childExprs: Seq[Expression]): GpuExpression = { - val legacyStatisticalAggregate = ShimLoader.getSparkShims.getLegacyStatisticalAggregate + val legacyStatisticalAggregate = SparkShimImpl.getLegacyStatisticalAggregate 
GpuVarianceSamp(childExprs.head, !legacyStatisticalAggregate) } }), @@ -3384,7 +3386,7 @@ object GpuOverrides extends Logging { // Shim expressions should be last to allow overrides with shim-specific versions val expressions: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = commonExpressions ++ TimeStamp.getExprs ++ GpuHiveOverrides.exprs ++ - ShimLoader.getSparkShims.getExprs + SparkShimImpl.getExprs def wrapScan[INPUT <: Scan]( scan: INPUT, @@ -3431,7 +3433,7 @@ object GpuOverrides extends Logging { })).map(r => (r.getClassFor.asSubclass(classOf[Scan]), r)).toMap val scans: Map[Class[_ <: Scan], ScanRule[_ <: Scan]] = - commonScans ++ ShimLoader.getSparkShims.getScans + commonScans ++ SparkShimImpl.getScans def wrapPart[INPUT <: Partitioning]( part: INPUT, @@ -3605,7 +3607,7 @@ object GpuOverrides extends Logging { takeExec.limit, so, projectList.map(_.convertToGpu().asInstanceOf[NamedExpression]), - ShimLoader.getSparkShims.getGpuShuffleExchangeExec( + SparkShimImpl.getGpuShuffleExchangeExec( GpuSinglePartitioning, GpuTopN( takeExec.limit, @@ -3806,7 +3808,7 @@ object GpuOverrides extends Logging { ExecChecks(TypeSig.all, TypeSig.all), (s, conf, p, r) => new GpuSubqueryBroadcastMeta(s, conf, p, r) ), - ShimLoader.getSparkShims.aqeShuffleReaderExec, + SparkShimImpl.aqeShuffleReaderExec, exec[FlatMapCoGroupsInPandasExec]( "The backend for CoGrouped Aggregation Pandas UDF, it runs on CPU itself now but supports" + " scheduling GPU resources for the Python process when enabled", @@ -3818,7 +3820,7 @@ object GpuOverrides extends Logging { neverReplaceExec[DescribeNamespaceExec]("Namespace metadata operation"), neverReplaceExec[DropNamespaceExec]("Namespace metadata operation"), neverReplaceExec[SetCatalogAndNamespaceExec]("Namespace metadata operation"), - ShimLoader.getSparkShims.neverReplaceShowCurrentNamespaceCommand, + SparkShimImpl.neverReplaceShowCurrentNamespaceCommand, neverReplaceExec[ShowNamespacesExec]("Namespace metadata operation"), neverReplaceExec[ExecutedCommandExec]("Table metadata operation"), neverReplaceExec[AlterTableExec]("Table metadata operation"), @@ -3838,7 +3840,7 @@ object GpuOverrides extends Logging { ).collect { case r if r != null => (r.getClassFor.asSubclass(classOf[SparkPlan]), r) }.toMap lazy val execs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = - commonExecs ++ ShimLoader.getSparkShims.getExecs + commonExecs ++ SparkShimImpl.getExecs def getTimeParserPolicy: TimeParserPolicy = { val policy = SQLConf.get.getConfString(SQLConf.LEGACY_TIME_PARSER_POLICY.key, "EXCEPTION") @@ -4003,7 +4005,7 @@ object GpuOverrides extends Logging { case c2r: ColumnarToRowExec => prepareExplainOnly(c2r.child) case re: ReusedExchangeExec => prepareExplainOnly(re.child) case aqe: AdaptiveSparkPlanExec => - prepareExplainOnly(ShimLoader.getSparkShims.getAdaptiveInputPlan(aqe)) + prepareExplainOnly(SparkShimImpl.getAdaptiveInputPlan(aqe)) case sub: SubqueryExec => prepareExplainOnly(sub.child) } planAfter diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala index 3885917edfe..754e9884923 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetFileFormat.scala @@ -18,7 +18,7 @@ package com.nvidia.spark.rapids import ai.rapids.cudf._ import com.nvidia.spark.RebaseHelper -import com.nvidia.spark.rapids.shims.v2.ParquetFieldIdShims +import 
com.nvidia.spark.rapids.shims.{ParquetFieldIdShims, SparkShimImpl} import org.apache.hadoop.mapreduce.{Job, OutputCommitter, TaskAttemptContext} import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat} import org.apache.parquet.hadoop.ParquetOutputFormat.JobSummaryLevel @@ -91,7 +91,7 @@ object GpuParquetFileFormat { } - ShimLoader.getSparkShims.int96ParquetRebaseWrite(sqlConf) match { + SparkShimImpl.int96ParquetRebaseWrite(sqlConf) match { case "EXCEPTION" => case "CORRECTED" => case "LEGACY" => @@ -102,7 +102,7 @@ object GpuParquetFileFormat { meta.willNotWorkOnGpu(s"$other is not a supported rebase mode for int96") } - ShimLoader.getSparkShims.parquetRebaseWrite(sqlConf) match { + SparkShimImpl.parquetRebaseWrite(sqlConf) match { case "EXCEPTION" => //Good case "CORRECTED" => //Good case "LEGACY" => @@ -161,15 +161,15 @@ class GpuParquetFileFormat extends ColumnarFileFormat with Logging { val outputTimestampType = sqlConf.parquetOutputTimestampType val dateTimeRebaseException = "EXCEPTION".equals( - sparkSession.sqlContext.getConf(ShimLoader.getSparkShims.parquetRebaseWriteKey)) + sparkSession.sqlContext.getConf(SparkShimImpl.parquetRebaseWriteKey)) // prior to spark 311 int96 don't check for rebase exception // https://github.com/apache/spark/blob/068465d016447ef0dbf7974b1a3f992040f4d64d/sql/core/src/ // main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala#L195 - val hasSeparateInt96RebaseConf = ShimLoader.getSparkShims.hasSeparateINT96RebaseConf + val hasSeparateInt96RebaseConf = SparkShimImpl.hasSeparateINT96RebaseConf val timestampRebaseException = outputTimestampType.equals(ParquetOutputTimestampType.INT96) && "EXCEPTION".equals(sparkSession.sqlContext - .getConf(ShimLoader.getSparkShims.int96ParquetRebaseWriteKey)) && + .getConf(SparkShimImpl.int96ParquetRebaseWriteKey)) && hasSeparateInt96RebaseConf || !outputTimestampType.equals(ParquetOutputTimestampType.INT96) && dateTimeRebaseException diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScanBase.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScanBase.scala index e629d4b0c95..a43ac6e19d2 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScanBase.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuParquetScanBase.scala @@ -33,7 +33,7 @@ import com.nvidia.spark.RebaseHelper import com.nvidia.spark.rapids.GpuMetric._ import com.nvidia.spark.rapids.ParquetPartitionReader.CopyRange import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.shims.v2.ParquetFieldIdShims +import com.nvidia.spark.rapids.shims.{ParquetFieldIdShims, SparkShimImpl} import org.apache.commons.io.output.{CountingOutputStream, NullOutputStream} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FSDataInputStream, Path} @@ -196,31 +196,31 @@ object GpuParquetScanBase { meta.willNotWorkOnGpu("GpuParquetScan does not support int96 timestamp conversion") } - sqlConf.get(ShimLoader.getSparkShims.int96ParquetRebaseReadKey) match { + sqlConf.get(SparkShimImpl.int96ParquetRebaseReadKey) match { case "EXCEPTION" => if (schemaMightNeedNestedRebase) { meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " + - s"${ShimLoader.getSparkShims.int96ParquetRebaseReadKey} is EXCEPTION") + s"${SparkShimImpl.int96ParquetRebaseReadKey} is EXCEPTION") } case "CORRECTED" => // Good case "LEGACY" => // really is EXCEPTION for us... 
if (schemaMightNeedNestedRebase) { meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " + - s"${ShimLoader.getSparkShims.int96ParquetRebaseReadKey} is LEGACY") + s"${SparkShimImpl.int96ParquetRebaseReadKey} is LEGACY") } case other => meta.willNotWorkOnGpu(s"$other is not a supported read rebase mode") } - sqlConf.get(ShimLoader.getSparkShims.parquetRebaseReadKey) match { + sqlConf.get(SparkShimImpl.parquetRebaseReadKey) match { case "EXCEPTION" => if (schemaMightNeedNestedRebase) { meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " + - s"${ShimLoader.getSparkShims.parquetRebaseReadKey} is EXCEPTION") + s"${SparkShimImpl.parquetRebaseReadKey} is EXCEPTION") } case "CORRECTED" => // Good case "LEGACY" => // really is EXCEPTION for us... if (schemaMightNeedNestedRebase) { meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " + - s"${ShimLoader.getSparkShims.parquetRebaseReadKey} is LEGACY") + s"${SparkShimImpl.parquetRebaseReadKey} is LEGACY") } case other => meta.willNotWorkOnGpu(s"$other is not a supported read rebase mode") @@ -313,9 +313,9 @@ private case class GpuParquetFileFilterHandler(@transient sqlConf: SQLConf) exte private val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal private val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringStartWith private val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold - private val rebaseMode = ShimLoader.getSparkShims.parquetRebaseRead(sqlConf) + private val rebaseMode = SparkShimImpl.parquetRebaseRead(sqlConf) private val isCorrectedRebase = "CORRECTED" == rebaseMode - val int96RebaseMode = ShimLoader.getSparkShims.int96ParquetRebaseRead(sqlConf) + val int96RebaseMode = SparkShimImpl.int96ParquetRebaseRead(sqlConf) private val isInt96CorrectedRebase = "CORRECTED" == int96RebaseMode @@ -344,7 +344,7 @@ private case class GpuParquetFileFilterHandler(@transient sqlConf: SQLConf) exte ParquetMetadataConverter.range(file.start, file.start + file.length)) val fileSchema = footer.getFileMetaData.getSchema val pushedFilters = if (enableParquetFilterPushDown) { - val parquetFilters = ShimLoader.getSparkShims.getParquetFilters(fileSchema, pushDownDate, + val parquetFilters = SparkShimImpl.getParquetFilters(fileSchema, pushDownDate, pushDownTimestamp, pushDownDecimal, pushDownStringStartWith, pushDownInFilterThreshold, isCaseSensitive, footer.getFileMetaData.getKeyValueMetaData.get, rebaseMode) filters.flatMap(parquetFilters.createFilter).reduceOption(FilterApi.and) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRangePartitioner.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRangePartitioner.scala index eb82cdf38af..2d89d7a69f3 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRangePartitioner.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRangePartitioner.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
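The GpuParquetScanBase hunks above handle the three read rebase modes: only "CORRECTED" is unconditionally safe on the GPU, while "EXCEPTION" and "LEGACY" must fall back when nested dates or timestamps might need rebasing. A condensed sketch mirroring that decision:

```scala
object RebaseModeSketch {
  // Returns a fallback reason, or None when the scan can stay on the GPU.
  def tagRebaseMode(mode: String, schemaMightNeedNestedRebase: Boolean): Option[String] =
    mode match {
      case "CORRECTED" => None // good
      case "EXCEPTION" | "LEGACY" if schemaMightNeedNestedRebase =>
        Some(s"nested timestamp and date values are not supported when mode is $mode")
      case "EXCEPTION" | "LEGACY" => None
      case other => Some(s"$other is not a supported read rebase mode")
    }
}
```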
@@ -19,7 +19,7 @@ package com.nvidia.spark.rapids import scala.collection.mutable.ArrayBuffer import scala.util.hashing.byteswap32 -import com.nvidia.spark.rapids.shims.v2.ShimExpression +import com.nvidia.spark.rapids.shims.ShimExpression import org.apache.spark.rdd.{PartitionPruningRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadCSVFileFormat.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadCSVFileFormat.scala index 2c7db4ac0ad..18fb7989e38 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadCSVFileFormat.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadCSVFileFormat.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package com.nvidia.spark.rapids +import com.nvidia.spark.rapids.shims.SparkShimImpl import org.apache.hadoop.conf.Configuration import org.apache.spark.sql.SparkSession @@ -68,7 +69,7 @@ object GpuReadCSVFileFormat { def tagSupport(meta: SparkPlanMeta[FileSourceScanExec]): Unit = { val fsse = meta.wrapped GpuCSVScan.tagSupport( - ShimLoader.getSparkShims.sessionFromPlan(fsse), + SparkShimImpl.sessionFromPlan(fsse), fsse.relation.dataSchema, fsse.output.toStructType, fsse.relation.options, diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadOrcFileFormat.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadOrcFileFormat.scala index 88dfd3740e1..889938a1d33 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadOrcFileFormat.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadOrcFileFormat.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package com.nvidia.spark.rapids +import com.nvidia.spark.rapids.shims.SparkShimImpl import org.apache.hadoop.conf.Configuration import org.apache.spark.sql.SparkSession @@ -63,7 +64,7 @@ object GpuReadOrcFileFormat { meta.willNotWorkOnGpu("mergeSchema and schema evolution is not supported yet") } GpuOrcScanBase.tagSupport( - ShimLoader.getSparkShims.sessionFromPlan(fsse), + SparkShimImpl.sessionFromPlan(fsse), fsse.requiredSchema, meta ) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadParquetFileFormat.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadParquetFileFormat.scala index b0343489d8d..24b7b746dfb 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadParquetFileFormat.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuReadParquetFileFormat.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -16,6 +16,7 @@ package com.nvidia.spark.rapids +import com.nvidia.spark.rapids.shims.SparkShimImpl import org.apache.hadoop.conf.Configuration import org.apache.spark.sql.SparkSession @@ -60,7 +61,7 @@ object GpuReadParquetFileFormat { def tagSupport(meta: SparkPlanMeta[FileSourceScanExec]): Unit = { val fsse = meta.wrapped GpuParquetScanBase.tagSupport( - ShimLoader.getSparkShims.sessionFromPlan(fsse), + SparkShimImpl.sessionFromPlan(fsse), fsse.requiredSchema, meta ) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRoundRobinPartitioning.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRoundRobinPartitioning.scala index a0a3d676708..2f25bf2ec6d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRoundRobinPartitioning.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRoundRobinPartitioning.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ import java.util.Random import ai.rapids.cudf.{NvtxColor, NvtxRange} import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.shims.v2.ShimExpression +import com.nvidia.spark.rapids.shims.ShimExpression import org.apache.spark.TaskContext import org.apache.spark.sql.types.{DataType, IntegerType} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRowToColumnarExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRowToColumnarExec.scala index 9d971b1d18d..b2c6a6eb227 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRowToColumnarExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuRowToColumnarExec.scala @@ -18,7 +18,7 @@ package com.nvidia.spark.rapids import ai.rapids.cudf.{NvtxColor, NvtxRange} import com.nvidia.spark.rapids.GpuColumnVector.GpuColumnarBatchBuilder -import com.nvidia.spark.rapids.shims.v2.ShimUnaryExecNode +import com.nvidia.spark.rapids.shims.{GpuTypeShims, ShimUnaryExecNode} import org.apache.spark.TaskContext import org.apache.spark.broadcast.Broadcast @@ -138,6 +138,9 @@ private[rapids] object GpuRowToColumnConverter { getConverterForType(v, vcn)) case (NullType, true) => NullConverter + // check special Shims types, such as DayTimeIntervalType + case (otherType, nullable) if GpuTypeShims.hasConverterForType(otherType) => + GpuTypeShims.getConverterForType(otherType, nullable) case (unknown, _) => throw new UnsupportedOperationException( s"Type $unknown not supported") } @@ -284,7 +287,7 @@ private[rapids] object GpuRowToColumnConverter { override def getNullSize: Double = 4 + VALIDITY } - private object LongConverter extends TypeConverter { + private[rapids] object LongConverter extends TypeConverter { override def append(row: SpecializedGetters, column: Int, builder: ai.rapids.cudf.HostColumnVector.ColumnBuilder): Double = { @@ -299,7 +302,7 @@ private[rapids] object GpuRowToColumnConverter { override def getNullSize: Double = 8 + VALIDITY } - private object NotNullLongConverter extends TypeConverter { + private[rapids] object NotNullLongConverter extends TypeConverter { override def append(row: SpecializedGetters, column: Int, builder: ai.rapids.cudf.HostColumnVector.ColumnBuilder): Double = { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffleCoalesceExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffleCoalesceExec.scala 
index c85d6fe1a60..3b4661027f7 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffleCoalesceExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffleCoalesceExec.scala @@ -20,7 +20,7 @@ import java.util import ai.rapids.cudf.{HostConcatResultUtil, HostMemoryBuffer, JCudfSerialization, NvtxColor, NvtxRange} import ai.rapids.cudf.JCudfSerialization.{HostConcatResult, SerializedTableHeader} -import com.nvidia.spark.rapids.shims.v2.ShimUnaryExecNode +import com.nvidia.spark.rapids.shims.ShimUnaryExecNode import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala index 8c33ede98bc..6d92bece75e 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuShuffledHashJoinExec.scala @@ -18,7 +18,7 @@ package com.nvidia.spark.rapids import ai.rapids.cudf.{HostConcatResultUtil, NvtxColor, NvtxRange} import ai.rapids.cudf.JCudfSerialization.HostConcatResult -import com.nvidia.spark.rapids.shims.v2.{GpuHashPartitioning, GpuJoinUtils, ShimBinaryExecNode} +import com.nvidia.spark.rapids.shims.{GpuHashPartitioning, GpuJoinUtils, ShimBinaryExecNode} import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSinglePartitioning.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSinglePartitioning.scala index db4d5ff237b..ce114ec6dea 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSinglePartitioning.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSinglePartitioning.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ package com.nvidia.spark.rapids -import com.nvidia.spark.rapids.shims.v2.ShimExpression +import com.nvidia.spark.rapids.shims.ShimExpression import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExec.scala index 465ed20f9de..f8352debad6 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSortExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
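The GpuRowToColumnConverter hunk a few files above adds a shim hook after the common type matches, so version-specific types such as DayTimeIntervalType can be claimed by the shim before the unsupported-type error fires. A simplified sketch of that lookup order, using stand-in functions rather than the plugin's converter classes:

```scala
object ConverterDispatchSketch {
  trait FakeConverter
  // Stand-ins for the common match arms and the GpuTypeShims hook.
  def commonConverter(dataType: String): Option[FakeConverter] = None
  def shimConverter(dataType: String): Option[FakeConverter] = None

  // Common types win; the shim gets a chance before the error fires.
  def converterFor(dataType: String): FakeConverter =
    commonConverter(dataType).orElse(shimConverter(dataType)).getOrElse(
      throw new UnsupportedOperationException(s"Type $dataType not supported"))
}
```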
@@ -22,7 +22,7 @@ import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{ColumnVector, ContiguousTable, NvtxColor, NvtxRange, Table} import com.nvidia.spark.rapids.GpuMetric._ -import com.nvidia.spark.rapids.shims.v2.ShimUnaryExecNode +import com.nvidia.spark.rapids.shims.ShimUnaryExecNode import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala index 54bf57e94fa..e6066e312f6 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuTransitionOverrides.scala @@ -18,6 +18,8 @@ package com.nvidia.spark.rapids import scala.annotation.tailrec +import com.nvidia.spark.rapids.shims.SparkShimImpl + import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeReference, Expression, InputFileBlockLength, InputFileBlockStart, InputFileName, SortOrder} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ @@ -60,7 +62,7 @@ class GpuTransitionOverrides extends Rule[SparkPlan] { } private def getColumnarToRowExec(plan: SparkPlan, exportColumnRdd: Boolean = false) = { - ShimLoader.getSparkShims.getGpuColumnarToRowTransition(plan, exportColumnRdd) + SparkShimImpl.getGpuColumnarToRowTransition(plan, exportColumnRdd) } /** Adds the appropriate coalesce after a shuffle depending on the type of shuffle configured */ @@ -92,7 +94,7 @@ class GpuTransitionOverrides extends Rule[SparkPlan] { case a: AdaptiveSparkPlanExec => // we hit this case when we have an adaptive plan wrapped in a write // to columnar file format on the GPU - val columnarAdaptivePlan = ShimLoader.getSparkShims.columnarAdaptivePlan(a, goal) + val columnarAdaptivePlan = SparkShimImpl.columnarAdaptivePlan(a, goal) optimizeAdaptiveTransitions(columnarAdaptivePlan, None) case _ => val preProcessing = child.getTagValue(GpuOverrides.preRowToColProjection) @@ -141,7 +143,7 @@ class GpuTransitionOverrides extends Rule[SparkPlan] { val plan = GpuTransitionOverrides.getNonQueryStagePlan(s) if (plan.supportsColumnar && plan.isInstanceOf[GpuExec]) { parent match { - case Some(x) if ShimLoader.getSparkShims.isCustomReaderExec(x) => + case Some(x) if SparkShimImpl.isCustomReaderExec(x) => // We can't insert a coalesce batches operator between a custom shuffle reader // and a shuffle query stage, so we instead insert it around the custom shuffle // reader later on, in the next top-level case clause. 
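The `GpuTransitionOverrides` hunks show the mechanical pattern this commit applies at every call site: the runtime lookup `ShimLoader.getSparkShims.<method>` becomes a static reference to the `SparkShimImpl` object that each per-Spark-version source set provides under the now-unversioned `com.nvidia.spark.rapids.shims` package. Side-by-side, with a hypothetical plan value:

```
import com.nvidia.spark.rapids.shims.SparkShimImpl
import org.apache.spark.sql.execution.SparkPlan

object TransitionSketch {
  def toRowTransition(plan: SparkPlan) = {
    // before: every call went through the loader at runtime
    //   ShimLoader.getSparkShims.getGpuColumnarToRowTransition(plan, exportColumnRdd = false)
    // after: a direct object reference, resolved per shim source set at compile time
    SparkShimImpl.getGpuColumnarToRowTransition(plan, exportColumnRdd = false)
  }
}
```

The observable behavior is unchanged; what changes is that the dispatch decision moves from a per-call lookup to the build's choice of which `SparkShimImpl` is on the classpath.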
@@ -338,13 +340,13 @@ class GpuTransitionOverrides extends Rule[SparkPlan] { if ((batchScan.scan.isInstanceOf[GpuParquetScanBase] || batchScan.scan.isInstanceOf[GpuOrcScanBase]) && (disableUntilInput || disableScanUntilInput(batchScan))) { - ShimLoader.getSparkShims.copyBatchScanExec(batchScan, true) + SparkShimImpl.copyBatchScanExec(batchScan, true) } else { batchScan } case fileSourceScan: GpuFileSourceScanExec => if ((disableUntilInput || disableScanUntilInput(fileSourceScan))) { - ShimLoader.getSparkShims.copyFileSourceScanExec(fileSourceScan, true) + SparkShimImpl.copyFileSourceScanExec(fileSourceScan, true) } else { fileSourceScan } @@ -453,7 +455,7 @@ class GpuTransitionOverrides extends Rule[SparkPlan] { val wrapped = GpuOverrides.wrapExpr(expr, rapidsConf, None) wrapped.tagForGpu() assert(wrapped.canThisBeReplaced) - ShimLoader.getSparkShims.sortOrder( + SparkShimImpl.sortOrder( wrapped.convertToGpu(), Ascending, Ascending.defaultNullOrdering) @@ -481,7 +483,7 @@ class GpuTransitionOverrides extends Rule[SparkPlan] { case _: BroadcastHashJoinExec | _: BroadcastNestedLoopJoinExec if isAdaptiveEnabled => // broadcasts are left on CPU for now when AQE is enabled - case p if ShimLoader.getSparkShims.isAqePlan(p) => + case p if SparkShimImpl.isAqePlan(p) => // we do not yet fully support GPU-acceleration when AQE is enabled, so we skip checking // the plan in this case - https://github.com/NVIDIA/spark-rapids/issues/5 case lts: LocalTableScanExec => @@ -499,7 +501,7 @@ class GpuTransitionOverrides extends Rule[SparkPlan] { case _: DropTableExec => case _: ExecutedCommandExec => () // Ignored case _: RDDScanExec => () // Ignored - case p if ShimLoader.getSparkShims.skipAssertIsOnTheGpu(p) => () // Ignored + case p if SparkShimImpl.skipAssertIsOnTheGpu(p) => () // Ignored case _ => if (!plan.supportsColumnar && // There are some python execs that are not columnar because of a little diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuUserDefinedFunction.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuUserDefinedFunction.scala index ea0efff23e2..80b603f730b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuUserDefinedFunction.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuUserDefinedFunction.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ package com.nvidia.spark.rapids import ai.rapids.cudf.{HostColumnVector, HostColumnVectorCore, NvtxColor, NvtxRange} import com.nvidia.spark.RapidsUDF import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.shims.v2.ShimExpression +import com.nvidia.spark.rapids.shims.ShimExpression import org.apache.spark.SparkException import org.apache.spark.internal.Logging diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExec.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExec.scala index 3fef0aac5ed..9ae57a85663 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExec.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf import ai.rapids.cudf.{AggregationOverWindow, DType, GroupByOptions, GroupByScanAggregation, NullPolicy, NvtxColor, ReplacePolicy, ReplacePolicyWithColumn, Scalar, ScanAggregation, ScanType, Table, WindowOptions} -import com.nvidia.spark.rapids.shims.v2.{GpuWindowUtil, ShimUnaryExecNode} +import com.nvidia.spark.rapids.shims.{GpuWindowUtil, ShimUnaryExecNode, SparkShimImpl} import org.apache.spark.TaskContext import org.apache.spark.internal.Logging @@ -315,8 +315,6 @@ object GpuWindowExec extends Arm { val windowDedupe = mutable.HashMap[Expression, Attribute]() val postProject = ArrayBuffer[NamedExpression]() - val shims = ShimLoader.getSparkShims - exprs.foreach { expr => if (hasGpuWindowFunction(expr)) { // First pass replace any operations that should be totally replaced. @@ -348,7 +346,7 @@ object GpuWindowExec extends Arm { extractAndSave(_, preProject, preDedupe)).toArray.toSeq val newOrderSpec = orderSpec.map { so => val newChild = extractAndSave(so.child, preProject, preDedupe) - shims.sortOrder(newChild, so.direction, so.nullOrdering) + SparkShimImpl.sortOrder(newChild, so.direction, so.nullOrdering) }.toArray.toSeq wsc.copy(partitionSpec = newPartitionSpec, orderSpec = newOrderSpec) } @@ -405,13 +403,11 @@ trait GpuWindowBaseExec extends ShimUnaryExecNode with GpuExec { } lazy val gpuPartitionOrdering: Seq[SortOrder] = { - val shims = ShimLoader.getSparkShims - gpuPartitionSpec.map(shims.sortOrder(_, Ascending)) + gpuPartitionSpec.map(SparkShimImpl.sortOrder(_, Ascending)) } lazy val cpuPartitionOrdering: Seq[SortOrder] = { - val shims = ShimLoader.getSparkShims - cpuPartitionSpec.map(shims.sortOrder(_, Ascending)) + cpuPartitionSpec.map(SparkShimImpl.sortOrder(_, Ascending)) } override def requiredChildOrdering: Seq[Seq[SortOrder]] = diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala index 547a6cc0d64..3f60ff9170d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuWindowExpression.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
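`GpuWindowExec` above also drops the cached `val shims = ShimLoader.getSparkShims` locals; with `SparkShimImpl` being a plain object there is nothing left to cache. `sortOrder` itself stays shimmed because the `SortOrder` constructor drifts across the supported Spark versions (the shape of its `sameOrderExpressions` argument changed). A sketch, assuming the Spark 3.2-style signature; this is an illustration, not the plugin's actual shim:

```
import org.apache.spark.sql.catalyst.expressions.{Expression, NullOrdering, SortDirection, SortOrder}

object SortOrderSketch {
  def sortOrder(child: Expression, direction: SortDirection,
      nullOrdering: NullOrdering): SortOrder =
    SortOrder(child, direction, nullOrdering, Seq.empty)

  // the two-argument form used above defaults the null ordering from the direction
  def sortOrder(child: Expression, direction: SortDirection): SortOrder =
    sortOrder(child, direction, direction.defaultNullOrdering)
}
```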
@@ -21,7 +21,7 @@ import java.util.concurrent.TimeUnit import ai.rapids.cudf import ai.rapids.cudf.{BinaryOp, ColumnVector, DType, GroupByScanAggregation, RollingAggregation, RollingAggregationOnColumn, Scalar, ScanAggregation} import com.nvidia.spark.rapids.GpuOverrides.wrapExpr -import com.nvidia.spark.rapids.shims.v2.{GpuWindowUtil, ShimExpression} +import com.nvidia.spark.rapids.shims.{GpuWindowUtil, ShimExpression} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.analysis.TypeCheckResult diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala index 8a59926d40e..a3b5adbbba1 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostColumnarToGpu.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import scala.collection.JavaConverters._ import scala.collection.mutable -import com.nvidia.spark.rapids.shims.v2.ShimUnaryExecNode +import com.nvidia.spark.rapids.shims.{ShimUnaryExecNode, SparkShimImpl} import org.apache.arrow.memory.ReferenceManager import org.apache.arrow.vector.ValueVector @@ -94,12 +94,12 @@ object HostColumnarToGpu extends Logging { } val nullCount = valVector.getNullCount() - val dataBuf = getBufferAndAddReference(ShimLoader.getSparkShims.getArrowDataBuf(valVector)) - val validity = getBufferAndAddReference(ShimLoader.getSparkShims.getArrowValidityBuf(valVector)) + val dataBuf = getBufferAndAddReference(SparkShimImpl.getArrowDataBuf(valVector)) + val validity = getBufferAndAddReference(SparkShimImpl.getArrowValidityBuf(valVector)) // this is a bit ugly, not all Arrow types have the offsets buffer var offsets: ByteBuffer = null try { - offsets = getBufferAndAddReference(ShimLoader.getSparkShims.getArrowOffsetsBuf(valVector)) + offsets = getBufferAndAddReference(SparkShimImpl.getArrowOffsetsBuf(valVector)) } catch { case _: UnsupportedOperationException => // swallow the exception and assume no offsets buffer diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala index 16ff8eb02f2..f59038ec9cd 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala @@ -26,6 +26,7 @@ import scala.util.Try import scala.util.matching.Regex import com.nvidia.spark.rapids.python.PythonWorkerSemaphore +import com.nvidia.spark.rapids.shims.SparkShimImpl import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext} @@ -386,7 +387,7 @@ object ExecutionPlanCaptureCallback { } private def didFallBack(plan: SparkPlan, fallbackCpuClass: String): Boolean = { - ShimLoader.getSparkShims.getSparkShimVersion.toString + SparkShimImpl.getSparkShimVersion.toString val executedPlan = ExecutionPlanCaptureCallback.extractExecutedPlan(Some(plan)) !executedPlan.getClass.getCanonicalName.equals("com.nvidia.spark.rapids.GpuExec") && PlanUtils.sameClass(executedPlan, fallbackCpuClass) || diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala index 0a8feae0e9f..9f86e85a3bc 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala @@ -408,7 +408,7 @@ object RapidsConf { 
"memory allocator in CUDA 11.2+ is used. If set to \"NONE\", pooling is disabled and RMM " + "just passes through to CUDA memory allocation directly.") .stringConf - .createWithDefault("ASYNC") + .createWithDefault("ARENA") val CONCURRENT_GPU_TASKS = conf("spark.rapids.sql.concurrentGpuTasks") .doc("Set the number of tasks that can execute concurrently per GPU. " + diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala index ff5d159c190..7b23d6a5b78 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsMeta.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,8 @@ import java.time.ZoneId import scala.collection.mutable +import com.nvidia.spark.rapids.shims.SparkShimImpl + import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BinaryExpression, ComplexTypeMergingExpression, Expression, QuaternaryExpression, String2TrimExpression, TernaryExpression, UnaryExpression, WindowExpression, WindowFunction} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, ImperativeAggregate, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.plans.physical.Partitioning @@ -858,7 +860,7 @@ object ExpressionContext { val parent = findParentPlanMeta(meta) assert(parent.isDefined, "It is expected that an aggregate function is a child of a SparkPlan") parent.get.wrapped match { - case agg: SparkPlan if ShimLoader.getSparkShims.isWindowFunctionExec(agg) => + case agg: SparkPlan if SparkShimImpl.isWindowFunctionExec(agg) => WindowAggExprContext case agg: BaseAggregateExec => if (agg.groupingExpressions.isEmpty) { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala index 28393fd44b2..710384ea4c6 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala @@ -76,6 +76,7 @@ object ShimLoader extends Logging { private val shimCommonURL = new URL(s"${shimRootURL.toString}spark3xx-common/") @volatile private var shimProviderClass: String = _ + @volatile private var shimProvider: SparkShimServiceProvider = _ @volatile private var sparkShims: SparkShims = _ @volatile private var shimURL: URL = _ @volatile private var pluginClassLoader: ClassLoader = _ @@ -310,6 +311,7 @@ object ShimLoader extends Logging { shimServiceProvider.matchesVersion(sparkVersion) }.map { case (inst, url) => shimURL = url + shimProvider = inst // this class will be loaded again by the real executor classloader inst.getClass.getName } @@ -331,11 +333,9 @@ object ShimLoader extends Logging { shimProviderClass } - def getSparkShims: SparkShims = { - if (sparkShims == null) { - sparkShims = newInstanceOf[SparkShimServiceProvider](findShimProvider()).buildShim - } - sparkShims + def getShimVersion: ShimVersion = { + initShimProviderIfNeeded() + shimProvider.getShimVersion } def getSparkVersion: String = { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SortUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SortUtils.scala index b9e405b23cf..98eb6b57b8f 100644 --- 
a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SortUtils.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SortUtils.scala @@ -20,6 +20,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{ColumnVector, NvtxColor, OrderByArg, Table} +import com.nvidia.spark.rapids.shims.SparkShimImpl import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, BoundReference, Expression, NullsFirst, NullsLast, SortOrder} import org.apache.spark.sql.rapids.execution.TrampolineUtil @@ -98,7 +99,7 @@ class GpuSorter( case Some(ref) => cudfOrdering += SortUtils.getOrder(so, ref.ordinal) // It is a bound GPU reference so we have to translate it to the CPU - cpuOrdering += ShimLoader.getSparkShims.sortOrder( + cpuOrdering += SparkShimImpl.sortOrder( BoundReference(ref.ordinal, ref.dataType, ref.nullable), so.direction, so.nullOrdering) case None => @@ -108,7 +109,7 @@ class GpuSorter( sortOrdersThatNeedsComputation += so // We already did the computation so instead of trying to translate // the computation back to the CPU too, just use the existing columns. - cpuOrdering += ShimLoader.getSparkShims.sortOrder( + cpuOrdering += SparkShimImpl.sortOrder( BoundReference(index, so.dataType, so.nullable), so.direction, so.nullOrdering) } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShimServiceProvider.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShimServiceProvider.scala index e1429a1f706..00b9ca5f41f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShimServiceProvider.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShimServiceProvider.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,6 @@ package com.nvidia.spark.rapids * A Spark version shim layer interface. 
*/ trait SparkShimServiceProvider { + def getShimVersion: ShimVersion def matchesVersion(version:String): Boolean - def buildShim: SparkShims } diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala index eba6798b38b..d3ca0766c26 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala @@ -115,6 +115,7 @@ trait SparkShims { exportColumnRdd: Boolean): GpuColumnarToRowExecParent def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] def getScans: Map[Class[_ <: Scan], ScanRule[_ <: Scan]] + def getFileFormats: Map[FileFormatType, Map[FileFormatOp, FileFormatChecks]] = Map() def getScalaUDFAsExpression( function: AnyRef, diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala index e9f27510628..e24321aa14b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/TypeChecks.scala @@ -20,7 +20,7 @@ import java.io.{File, FileOutputStream} import java.time.ZoneId import ai.rapids.cudf.DType -import com.nvidia.spark.rapids.shims.v2.TypeSigUtil +import com.nvidia.spark.rapids.shims.TypeSigUtil import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnaryExpression, WindowSpecDefinition} import org.apache.spark.sql.types._ diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/VersionUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/VersionUtils.scala index 08e78a50e40..729cdb9b01f 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/VersionUtils.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/VersionUtils.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
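The `SparkShimServiceProvider` change above is the structural core of the commit: providers no longer build a `SparkShims` facade (`buildShim` is gone); they only report which version they are (`getShimVersion`), and `ShimLoader` now retains the matched provider instance so its new `getShimVersion` can answer without constructing anything. A hypothetical provider under the new interface, assuming a single-version match rule (real providers typically match a list of version strings):

```
package com.nvidia.spark.rapids.shims.spark311

import com.nvidia.spark.rapids.{ShimVersion, SparkShimVersion}

class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider {
  override def getShimVersion: ShimVersion = SparkShimVersion(3, 1, 1)
  override def matchesVersion(version: String): Boolean = version.startsWith("3.1.1")
}
```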
@@ -16,6 +16,8 @@ package com.nvidia.spark.rapids +import com.nvidia.spark.rapids.shims.SparkShimImpl + object VersionUtils { lazy val isSpark301OrLater: Boolean = cmpSparkVersion(3, 0, 1) >= 0 @@ -25,23 +27,23 @@ object VersionUtils { lazy val isSpark320OrLater: Boolean = cmpSparkVersion(3, 2, 0) >= 0 lazy val isSpark: Boolean = { - ShimLoader.getSparkShims.getSparkShimVersion.isInstanceOf[SparkShimVersion] + SparkShimImpl.getSparkShimVersion.isInstanceOf[SparkShimVersion] } lazy val isDataBricks: Boolean = { - ShimLoader.getSparkShims.getSparkShimVersion.isInstanceOf[DatabricksShimVersion] + SparkShimImpl.getSparkShimVersion.isInstanceOf[DatabricksShimVersion] } lazy val isCloudera: Boolean = { - ShimLoader.getSparkShims.getSparkShimVersion.isInstanceOf[ClouderaShimVersion] + SparkShimImpl.getSparkShimVersion.isInstanceOf[ClouderaShimVersion] } lazy val isEMR: Boolean = { - ShimLoader.getSparkShims.getSparkShimVersion.isInstanceOf[EMRShimVersion] + SparkShimImpl.getSparkShimVersion.isInstanceOf[EMRShimVersion] } def cmpSparkVersion(major: Int, minor: Int, bugfix: Int): Int = { - val sparkShimVersion = ShimLoader.getSparkShims.getSparkShimVersion + val sparkShimVersion = SparkShimImpl.getSparkShimVersion val (sparkMajor, sparkMinor, sparkBugfix) = sparkShimVersion match { case SparkShimVersion(a, b, c) => (a, b, c) case DatabricksShimVersion(a, b, c, _) => (a, b, c) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala index 159be9a9097..00641748777 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/aggregate.scala @@ -25,7 +25,7 @@ import ai.rapids.cudf import ai.rapids.cudf.NvtxColor import com.nvidia.spark.rapids.GpuMetric._ import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.shims.v2.ShimUnaryExecNode +import com.nvidia.spark.rapids.shims.{ShimUnaryExecNode, SparkShimImpl} import org.apache.spark.TaskContext import org.apache.spark.internal.Logging @@ -423,9 +423,8 @@ class GpuHashAggregateIterator( "without grouping keys") } - val shims = ShimLoader.getSparkShims val groupingAttributes = groupingExpressions.map(_.toAttribute) - val ordering = groupingAttributes.map(shims.sortOrder(_, Ascending, NullsFirst)) + val ordering = groupingAttributes.map(SparkShimImpl.sortOrder(_, Ascending, NullsFirst)) val aggBufferAttributes = groupingAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) val sorter = new GpuSorter(ordering, aggBufferAttributes) @@ -1210,10 +1209,10 @@ object GpuTypedImperativeSupportedAggregateExecMeta { } converters.dequeue() match { case Left(converter) => - ShimLoader.getSparkShims.alias(converter.createExpression(ref), + SparkShimImpl.alias(converter.createExpression(ref), ref.name + "_converted")(NamedExpression.newExprId) case Right(converter) => - ShimLoader.getSparkShims.alias(converter.createExpression(ref), + SparkShimImpl.alias(converter.createExpression(ref), ref.name + "_converted")(NamedExpression.newExprId) } case retExpr => diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala index 73b719a8bec..482385b6d69 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala @@ -1,5 +1,5 @@ /* - 
* Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ import ai.rapids.cudf import ai.rapids.cudf._ import com.nvidia.spark.rapids.GpuMetric._ import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.shims.v2.{ShimSparkPlan, ShimUnaryExecNode} +import com.nvidia.spark.rapids.shims.{ShimSparkPlan, ShimUnaryExecNode, SparkShimImpl} import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} import org.apache.spark.internal.Logging @@ -567,13 +567,12 @@ case class GpuRangeExec( ) ++ semaphoreMetrics override def outputOrdering: Seq[SortOrder] = { - val shim = ShimLoader.getSparkShims val order = if (step > 0) { Ascending } else { Descending } - output.map(a => shim.sortOrder(a, order)) + output.map(a => SparkShimImpl.sortOrder(a, order)) } override def outputPartitioning: Partitioning = { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala index 0757228c9fb..36781bbe918 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/conditionalExpressions.scala @@ -18,7 +18,7 @@ package com.nvidia.spark.rapids import ai.rapids.cudf.{BinaryOp, ColumnVector, DType, NullPolicy, Scalar, ScanAggregation, ScanType, Table, UnaryOp} import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.shims.v2.ShimExpression +import com.nvidia.spark.rapids.shims.ShimExpression import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.{ComplexTypeMergingExpression, Expression} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/constraintExpressions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/constraintExpressions.scala index b646dc40f82..df8eee64629 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/constraintExpressions.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/constraintExpressions.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ package com.nvidia.spark.rapids import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.shims.v2.ShimUnaryExpression +import com.nvidia.spark.rapids.shims.ShimUnaryExpression import org.apache.spark.sql.catalyst.expressions.{Expression, TaggingExpression} import org.apache.spark.sql.vectorized.ColumnarBatch diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala index 4cd62ee9e04..2f3ef10a3a0 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/higherOrderFunctions.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
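For code that still needs the version at runtime, the `VersionUtils` rewrite above shows the consumption side: `getSparkShimVersion` yields a small ADT and everything else is pattern matching. A worksheet-style usage sketch, equivalent to `cmpSparkVersion(3, 2, 0) >= 0`:

```
import com.nvidia.spark.rapids.{DatabricksShimVersion, SparkShimVersion}
import com.nvidia.spark.rapids.shims.SparkShimImpl

object VersionCheckSketch {
  val onSpark320OrLater: Boolean = SparkShimImpl.getSparkShimVersion match {
    case SparkShimVersion(major, minor, _) => major > 3 || (major == 3 && minor >= 2)
    case DatabricksShimVersion(major, minor, _, _) => major > 3 || (major == 3 && minor >= 2)
    case _ => false // Cloudera/EMR variants would be matched the same way
  }
}
```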
@@ -20,7 +20,7 @@ import scala.collection.mutable import ai.rapids.cudf import ai.rapids.cudf.DType -import com.nvidia.spark.rapids.shims.v2.ShimExpression +import com.nvidia.spark.rapids.shims.ShimExpression import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSeq, Expression, ExprId, NamedExpression} import org.apache.spark.sql.types.{ArrayType, DataType, MapType, Metadata} diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/limit.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/limit.scala index b0f0aa093da..18b843e0451 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/limit.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/limit.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer import ai.rapids.cudf.{NvtxColor, Table} import com.nvidia.spark.rapids.GpuMetric._ import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.shims.v2.ShimUnaryExecNode +import com.nvidia.spark.rapids.shims.{ShimUnaryExecNode, SparkShimImpl} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -141,7 +141,7 @@ class GpuCollectLimitMeta( override def convertToGpu(): GpuExec = GpuGlobalLimitExec(collectLimit.limit, - ShimLoader.getSparkShims.getGpuShuffleExchangeExec( + SparkShimImpl.getGpuShuffleExchangeExec( GpuSinglePartitioning, GpuLocalLimitExec(collectLimit.limit, childPlans.head.convertIfNeeded()), SinglePartition)) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala index 9060c65474a..3352e8fd3e9 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala @@ -28,6 +28,7 @@ import scala.reflect.runtime.universe.TypeTag import ai.rapids.cudf.{ColumnVector, DType, HostColumnVector, Scalar} import ai.rapids.cudf.ast import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingArray +import com.nvidia.spark.rapids.shims.SparkShimImpl import org.json4s.JsonAST.{JField, JNull, JString} import org.apache.spark.internal.Logging @@ -664,7 +665,7 @@ case class GpuLiteral (value: Any, dataType: DataType) extends GpuLeafExpression } case (v: Decimal, _: DecimalType) => v + "BD" case (v: Int, DateType) => - val formatter = ShimLoader.getSparkShims.getDateFormatter() + val formatter = SparkShimImpl.getDateFormatter() s"DATE '${formatter.format(v)}'" case (v: Long, TimestampType) => val formatter = TimestampFormatter.getFractionFormatter( diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/namedExpressions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/namedExpressions.scala index 81071b7ce71..6e9124cd440 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/namedExpressions.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/namedExpressions.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
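In `literals.scala` above, rendering a date literal goes through a shimmed `DateFormatter` factory because the formatter's construction API changed between Spark releases. For Proleptic Gregorian dates the net effect is equivalent to this standalone sketch (an illustration of what the formatter computes, not the plugin's code path; Spark stores `DateType` internally as an `Int` day count since the epoch):

```
import java.time.LocalDate

object DateLiteralSketch {
  def dateLiteralSql(days: Int): String = s"DATE '${LocalDate.ofEpochDay(days)}'"
  // dateLiteralSql(0)     == "DATE '1970-01-01'"
  // dateLiteralSql(18993) == "DATE '2022-01-01'"
}
```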
@@ -21,6 +21,7 @@ import java.util.Objects import ai.rapids.cudf.ColumnVector import ai.rapids.cudf.ast import com.nvidia.spark.rapids.RapidsPluginImplicits._ +import com.nvidia.spark.rapids.shims.SparkShimImpl import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, ExprId, Generator, NamedExpression} @@ -86,7 +87,7 @@ case class GpuAlias(child: Expression, name: String)( } override def sql: String = { - if (ShimLoader.getSparkShims.hasAliasQuoteFix) { + if (SparkShimImpl.hasAliasQuoteFix) { val qualifierPrefix = if (qualifier.nonEmpty) qualifier.map(quoteIfNeeded).mkString(".") + "." else "" s"${child.sql} AS $qualifierPrefix${quoteIfNeeded(name)}" diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/nullExpressions.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/nullExpressions.scala index 3842a011f0f..5827c2dfa4b 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/nullExpressions.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/nullExpressions.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021, NVIDIA CORPORATION. + * Copyright (c) 2019-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ import scala.collection.mutable import ai.rapids.cudf.{ColumnVector, DType, Scalar} import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.shims.v2.ShimExpression +import com.nvidia.spark.rapids.shims.ShimExpression import org.apache.spark.sql.catalyst.expressions.{ComplexTypeMergingExpression, Expression, Predicate} import org.apache.spark.sql.types.DataType diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/expressions/rapids/Timestamp.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/expressions/rapids/Timestamp.scala index b5a1be4e299..4360bd87ef7 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/expressions/rapids/Timestamp.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/expressions/rapids/Timestamp.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
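`GpuAlias.sql` consults `hasAliasQuoteFix`, a capability flag: somewhere in the supported range Spark changed how `Alias.sql` quotes identifiers, and the GPU expression must render the same string as the CPU one on each release. The idiom in miniature, with hypothetical per-version answers:

```
trait AliasQuoteShim { def hasAliasQuoteFix: Boolean }

// assumed values for illustration; the real answers live in each shim's SparkShims impl
object PreFixShimSketch  extends AliasQuoteShim { val hasAliasQuoteFix = false }
object PostFixShimSketch extends AliasQuoteShim { val hasAliasQuoteFix = true }
```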
@@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.expressions.rapids -import com.nvidia.spark.rapids.{ExprChecks, ExprRule, GpuExpression, GpuOverrides, ShimLoader, TypeEnum, TypeSig} +import com.nvidia.spark.rapids.{ExprChecks, ExprRule, GpuExpression, GpuOverrides, TypeEnum, TypeSig} +import com.nvidia.spark.rapids.shims.SparkShimImpl import org.apache.spark.sql.catalyst.expressions.{Expression, GetTimestamp} import org.apache.spark.sql.rapids.{GpuGetTimestamp, UnixTimeExprMeta} @@ -39,7 +40,7 @@ object TimeStamp { TypeSig.STRING)), (a, conf, p, r) => new UnixTimeExprMeta[GetTimestamp](a, conf, p, r) { override def shouldFallbackOnAnsiTimestamp: Boolean = - ShimLoader.getSparkShims.shouldFallbackOnAnsiTimestamp + SparkShimImpl.shouldFallbackOnAnsiTimestamp override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = { GpuGetTimestamp(lhs, rhs, sparkFormat, strfFormat) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala index a6ebdb40ee1..e9a6cb490ab 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuJsonScan.scala @@ -24,6 +24,7 @@ import scala.collection.mutable.ListBuffer import ai.rapids.cudf import ai.rapids.cudf.{ColumnVector, DType, HostMemoryBuffer, Scalar, Schema, Table} import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.shims.SparkShimImpl import org.apache.hadoop.conf.Configuration import org.apache.spark.broadcast.Broadcast @@ -155,7 +156,7 @@ object GpuJsonScan { if (!TypeChecks.areTimestampsSupported(parsedOptions.zoneId)) { meta.willNotWorkOnGpu("Only UTC zone id is supported") } - ShimLoader.getSparkShims.timestampFormatInRead(parsedOptions).foreach { tsFormat => + SparkShimImpl.timestampFormatInRead(parsedOptions).foreach { tsFormat => val parts = tsFormat.split("'T'", 2) if (parts.isEmpty) { meta.willNotWorkOnGpu(s"the timestamp format '$tsFormat' is not supported") diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuReadJsonFileFormat.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuReadJsonFileFormat.scala index 7a369923290..c5f3bef3197 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuReadJsonFileFormat.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/catalyst/json/rapids/GpuReadJsonFileFormat.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.json.rapids import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.shims.SparkShimImpl import org.apache.hadoop.conf.Configuration import org.apache.spark.sql.SparkSession @@ -68,7 +69,7 @@ object GpuReadJsonFileFormat { def tagSupport(meta: SparkPlanMeta[FileSourceScanExec]): Unit = { val fsse = meta.wrapped GpuJsonScan.tagSupport( - ShimLoader.getSparkShims.sessionFromPlan(fsse), + SparkShimImpl.sessionFromPlan(fsse), fsse.relation.dataSchema, fsse.output.toStructType, fsse.relation.options, diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala index 3a96f9806e0..1e916ac1fc3 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/AggregateFunctions.scala @@ -19,7 +19,7 @@ package 
org.apache.spark.sql.rapids import ai.rapids.cudf import ai.rapids.cudf.{Aggregation128Utils, BinaryOp, ColumnVector, DType, GroupByAggregation, GroupByScanAggregation, NullPolicy, ReductionAggregation, ReplacePolicy, RollingAggregation, RollingAggregationOnColumn, Scalar, ScanAggregation} import com.nvidia.spark.rapids._ -import com.nvidia.spark.rapids.shims.v2.{GpuDeterministicFirstLastCollectShim, ShimExpression, ShimUnaryExpression} +import com.nvidia.spark.rapids.shims.{GpuDeterministicFirstLastCollectShim, ShimExpression, ShimUnaryExpression} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuCartesianProductExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuCartesianProductExec.scala index a90221addc6..ea428300bba 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuCartesianProductExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuCartesianProductExec.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import ai.rapids.cudf.{JCudfSerialization, NvtxColor, NvtxRange} import com.nvidia.spark.rapids.{Arm, GpuBindReferences, GpuBuildLeft, GpuColumnVector, GpuExec, GpuExpression, GpuMetric, GpuSemaphore, LazySpillableColumnarBatch, MetricsLevel, NoopMetric, SpillCallback} import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.shims.v2.ShimBinaryExecNode +import com.nvidia.spark.rapids.shims.ShimBinaryExecNode import org.apache.spark.{Dependency, NarrowDependency, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuCreateDataSourceTableAsSelectCommand.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuCreateDataSourceTableAsSelectCommand.scala index e54fa1b09b0..deddb62473a 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuCreateDataSourceTableAsSelectCommand.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuCreateDataSourceTableAsSelectCommand.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,7 +18,8 @@ package org.apache.spark.sql.rapids import java.net.URI -import com.nvidia.spark.rapids.{ColumnarFileFormat, GpuDataWritingCommand, ShimLoader} +import com.nvidia.spark.rapids.{ColumnarFileFormat, GpuDataWritingCommand} +import com.nvidia.spark.rapids.shims.SparkShimImpl import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog._ @@ -74,14 +75,14 @@ case class GpuCreateDataSourceTableAsSelectCommand( } val result = saveDataIntoTable( sparkSession, table, tableLocation, child, SaveMode.Overwrite, tableExists = false) - ShimLoader.getSparkShims.createTable(table, sessionState.catalog, tableLocation, result) + SparkShimImpl.createTable(table, sessionState.catalog, tableLocation, result) result match { case _: HadoopFsRelation if table.partitionColumnNames.nonEmpty && sparkSession.sqlContext.conf.manageFilesourcePartitions => // Need to recover partitions into the metastore so our saved data is visible. 
sessionState.executePlan( - ShimLoader.getSparkShims.v1RepairTableCommand(table.identifier)).toRdd + SparkShimImpl.v1RepairTableCommand(table.identifier)).toRdd case _ => } } diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuDataSource.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuDataSource.scala index a09a48f875b..b3629e76d7a 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuDataSource.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuDataSource.scala @@ -21,7 +21,8 @@ import java.util.{Locale, ServiceConfigurationError, ServiceLoader} import scala.collection.JavaConverters._ import scala.util.{Failure, Success, Try} -import com.nvidia.spark.rapids.{ColumnarFileFormat, GpuParquetFileFormat, ShimLoader} +import com.nvidia.spark.rapids.{ColumnarFileFormat, GpuParquetFileFormat} +import com.nvidia.spark.rapids.shims.SparkShimImpl import org.apache.commons.lang3.reflect.ConstructorUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -163,7 +164,7 @@ case class GpuDataSource( format.inferSchema( sparkSession, caseInsensitiveOptions - "path", - ShimLoader.getSparkShims.filesFromFileIndex(tempFileIndex)) + SparkShimImpl.filesFromFileIndex(tempFileIndex)) }.getOrElse { throw new AnalysisException( s"Unable to infer schema for $format. It must be specified manually.") @@ -234,7 +235,7 @@ case class GpuDataSource( format.inferSchema( sparkSession, caseInsensitiveOptions - "path", - ShimLoader.getSparkShims.filesFromFileIndex(fileCatalog)) + SparkShimImpl.filesFromFileIndex(fileCatalog)) }.getOrElse { throw new AnalysisException( s"Unable to infer schema for $format at ${fileCatalog.allFiles().mkString(",")}. " + @@ -285,17 +286,17 @@ case class GpuDataSource( relation match { case hs: HadoopFsRelation => - ShimLoader.getSparkShims.checkColumnNameDuplication( + SparkShimImpl.checkColumnNameDuplication( hs.dataSchema, "in the data schema", equality) - ShimLoader.getSparkShims.checkColumnNameDuplication( + SparkShimImpl.checkColumnNameDuplication( hs.partitionSchema, "in the partition schema", equality) DataSourceUtils.verifySchema(hs.fileFormat, hs.dataSchema) case _ => - ShimLoader.getSparkShims.checkColumnNameDuplication( + SparkShimImpl.checkColumnNameDuplication( relation.schema, "in the data schema", equality) @@ -636,7 +637,7 @@ object GpuDataSource extends Logging { val allPaths = globbedPaths ++ nonGlobPaths if (checkFilesExist) { val (filteredOut, filteredIn) = allPaths.partition { path => - ShimLoader.getSparkShims.shouldIgnorePath(path.getName) + SparkShimImpl.shouldIgnorePath(path.getName) } if (filteredIn.isEmpty) { logWarning( diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuDataSourceScanExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuDataSourceScanExec.scala index 4d25b7a4f13..52243c3596b 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuDataSourceScanExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuDataSourceScanExec.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
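`GpuDataSource` above leans on the shims for several filesystem-facing helpers; `shouldIgnorePath`, used when partitioning the input paths into filtered-out and kept, is the simplest. An illustrative body following Spark's usual hidden-file convention (an assumption; real shims may differ in the exact exclusion list):

```
object PathFilterSketch {
  // names starting with "_" or "." are conventionally metadata/hidden files
  def shouldIgnorePath(name: String): Boolean =
    name.startsWith("_") || name.startsWith(".")
}
```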
@@ -16,7 +16,8 @@ package org.apache.spark.sql.rapids -import com.nvidia.spark.rapids.{GpuExec, ShimLoader} +import com.nvidia.spark.rapids.GpuExec +import com.nvidia.spark.rapids.shims.SparkShimImpl import org.apache.commons.lang3.StringUtils import org.apache.hadoop.fs.Path @@ -41,7 +42,7 @@ trait GpuDataSourceScanExec extends LeafExecNode with GpuExec { // Metadata that describes more details of this scan. protected def metadata: Map[String, String] - protected val maxMetadataValueLength = ShimLoader.getSparkShims + protected val maxMetadataValueLength = SparkShimImpl .getFileSourceMaxMetadataValueLength(sparkSession.sessionState.conf) override def simpleString(maxFields: Int): String = { diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala index 03077947472..6a01362f91e 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatDataWriter.scala @@ -22,6 +22,7 @@ import ai.rapids.cudf.{ContiguousTable, OrderByArg, Table} import com.nvidia.spark.TimingUtils import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.RapidsPluginImplicits._ +import com.nvidia.spark.rapids.shims.SparkShimImpl import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.TaskAttemptContext @@ -233,7 +234,7 @@ class GpuDynamicPartitionDataWriter( */ private lazy val partitionPathExpression: Expression = Concat( description.partitionColumns.zipWithIndex.flatMap { case (c, i) => - val partitionName = ShimLoader.getSparkShims.getScalaUDFAsExpression( + val partitionName = SparkShimImpl.getScalaUDFAsExpression( ExternalCatalogUtils.getPartitionPathString _, StringType, Seq(Literal(c.name), Cast(c, StringType, Option(description.timeZoneId)))) diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatWriter.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatWriter.scala index 55d1cca9ea2..c513d201799 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatWriter.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileFormatWriter.scala @@ -21,6 +21,7 @@ import java.util.{Date, UUID} import ai.rapids.cudf.ColumnVector import com.nvidia.spark.TimingUtils import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.shims.SparkShimImpl import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ @@ -196,10 +197,9 @@ object GpuFileFormatWriter extends Logging { // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and // the physical plan may have different attribute ids due to optimizer removing some // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch. 
- val sparkShims = ShimLoader.getSparkShims val orderingExpr = GpuBindReferences.bindReferences( requiredOrdering - .map(attr => sparkShims.sortOrder(attr, Ascending)), outputSpec.outputColumns) + .map(attr => SparkShimImpl.sortOrder(attr, Ascending)), outputSpec.outputColumns) val sortType = if (RapidsConf.STABLE_SORT.get(plan.conf)) { FullSortSingleBatch } else { diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileSourceScanExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileSourceScanExec.scala index 8add673fdb8..0c00ae44c0a 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileSourceScanExec.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFileSourceScanExec.scala @@ -20,7 +20,8 @@ import java.util.concurrent.TimeUnit.NANOSECONDS import scala.collection.mutable.HashMap -import com.nvidia.spark.rapids.{GpuDataSourceRDD, GpuExec, GpuMetric, GpuOrcMultiFilePartitionReaderFactory, GpuParquetMultiFilePartitionReaderFactory, GpuReadCSVFileFormat, GpuReadFileFormatWithMetrics, GpuReadOrcFileFormat, GpuReadParquetFileFormat, RapidsConf, ShimLoader, SparkPlanMeta} +import com.nvidia.spark.rapids.{GpuDataSourceRDD, GpuExec, GpuMetric, GpuOrcMultiFilePartitionReaderFactory, GpuParquetMultiFilePartitionReaderFactory, GpuReadCSVFileFormat, GpuReadFileFormatWithMetrics, GpuReadOrcFileFormat, GpuReadParquetFileFormat, RapidsConf, SparkPlanMeta} +import com.nvidia.spark.rapids.shims.SparkShimImpl import org.apache.hadoop.fs.Path import org.apache.spark.rdd.RDD @@ -205,7 +206,7 @@ case class GpuFileSourceScanExec( // the RDD partition will not be sorted even if the relation has sort columns set // Current solution is to check if all the buckets have a single file in it - val filesPartNames = ShimLoader.getSparkShims.getPartitionFileNames(selectedPartitions) + val filesPartNames = SparkShimImpl.getPartitionFileNames(selectedPartitions) val bucketToFilesGrouping = filesPartNames.groupBy(file => BucketingUtils.getBucketId(file)) val singleFilePartitions = bucketToFilesGrouping.forall(p => p._2.length <= 1) @@ -353,7 +354,7 @@ case class GpuFileSourceScanExec( partitions: Seq[PartitionDirectory], static: Boolean): Unit = { val filesNum = partitions.map(_.files.size.toLong).sum - val filesSize = ShimLoader.getSparkShims.getPartitionFileStatusSize(partitions) + val filesSize = SparkShimImpl.getPartitionFileStatusSize(partitions) if (!static || !partitionFilters.exists(isDynamicPruningFilter)) { driverMetrics("numFiles") = filesNum driverMetrics("filesSize") = filesSize @@ -450,7 +451,7 @@ case class GpuFileSourceScanExec( logInfo(s"Planning with ${bucketSpec.numBuckets} buckets") val partitionedFiles = - ShimLoader.getSparkShims.getPartitionedFiles(selectedPartitions) + SparkShimImpl.getPartitionedFiles(selectedPartitions) val filesGroupedToBuckets = partitionedFiles.groupBy { f => BucketingUtils @@ -474,11 +475,11 @@ case class GpuFileSourceScanExec( val partitionedFiles = coalescedBuckets.get(bucketId).map { _.values.flatten.toArray }.getOrElse(Array.empty) - ShimLoader.getSparkShims.createFilePartition(bucketId, partitionedFiles) + SparkShimImpl.createFilePartition(bucketId, partitionedFiles) } }.getOrElse { Seq.tabulate(bucketSpec.numBuckets) { bucketId => - ShimLoader.getSparkShims.createFilePartition(bucketId, + SparkShimImpl.createFilePartition(bucketId, prunedFilesGroupedToBuckets.getOrElse(bucketId, Array.empty)) } } @@ -504,7 +505,7 @@ case class GpuFileSourceScanExec( logInfo(s"Planning scan with bin packing, max 
size: $maxSplitBytes bytes, " + s"open cost is considered as scanning $openCostInBytes bytes.") - val splitFiles = ShimLoader.getSparkShims + val splitFiles = SparkShimImpl .getPartitionSplitFiles(selectedPartitions, maxSplitBytes, relation) .sortBy(_.length)(implicitly[Ordering[Long]].reverse) @@ -521,7 +522,7 @@ case class GpuFileSourceScanExec( if (isPerFileReadEnabled) { logInfo("Using the original per file parquet reader") - ShimLoader.getSparkShims.getFileScanRDD(fsRelation.sparkSession, readFile.get, partitions, + SparkShimImpl.getFileScanRDD(fsRelation.sparkSession, readFile.get, partitions, requiredSchema) } else { // here we are making an optimization to read more then 1 file at a time on the CPU side diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuScalarSubquery.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuScalarSubquery.scala index d151dadac6e..be494bbbdbc 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuScalarSubquery.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuScalarSubquery.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ package org.apache.spark.sql.rapids import com.nvidia.spark.rapids.{GpuExpression, GpuScalar} -import com.nvidia.spark.rapids.shims.v2.ShimExpression +import com.nvidia.spark.rapids.shims.ShimExpression import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId} import org.apache.spark.sql.execution.{BaseSubqueryExec, ExecSubqueryExpression} diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/HashFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/HashFunctions.scala index 78c7b2a6695..26d999e9cdf 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/HashFunctions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/HashFunctions.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ package org.apache.spark.sql.rapids import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView} import com.nvidia.spark.rapids.{Arm, GpuColumnVector, GpuExpression, GpuProjectExec, GpuUnaryExpression} import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.shims.v2.{HashUtils, ShimExpression} +import com.nvidia.spark.rapids.shims.{HashUtils, ShimExpression} import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, NullIntolerant} import org.apache.spark.sql.types._ diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsDiskBlockManager.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsDiskBlockManager.scala index ec82c7662c5..3f1f65cd2b9 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsDiskBlockManager.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsDiskBlockManager.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
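The `GpuFileSourceScanExec` helpers above (`getPartitionedFiles`, `createFilePartition`, `getFileScanRDD`, and friends) are shimmed mostly to absorb signature drift between upstream Spark and vendor builds; against upstream 3.x the partition factory is nearly trivial. A sketch assuming upstream signatures (vendor shims would adapt any extra `PartitionedFile` fields here):

```
import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile}

object FilePartitionSketch {
  def createFilePartition(index: Int, files: Array[PartitionedFile]): FilePartition =
    FilePartition(index, files)
}
```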
@@ -19,7 +19,7 @@ package org.apache.spark.sql.rapids import java.io.File import org.apache.spark.SparkConf -import org.apache.spark.rapids.shims.v2.storage.ShimDiskBlockManager +import org.apache.spark.rapids.shims.storage.ShimDiskBlockManager import org.apache.spark.storage.BlockId /** Maps logical blocks to local disk locations. */ diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala index 1fbfe4850a1..d62442d6c6e 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/RapidsShuffleInternalManagerBase.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.rapids import ai.rapids.cudf.{NvtxColor, NvtxRange} import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.format.TableMeta +import com.nvidia.spark.rapids.shims.SparkShimImpl import com.nvidia.spark.rapids.shuffle.{RapidsShuffleRequestHandler, RapidsShuffleServer, RapidsShuffleTransport} import scala.collection.mutable.{ArrayBuffer, ListBuffer} @@ -29,7 +30,7 @@ import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle._ import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.rapids.shims.v2.GpuShuffleBlockResolver +import org.apache.spark.sql.rapids.shims.GpuShuffleBlockResolver import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.storage._ @@ -371,7 +372,7 @@ abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, val isDriver: B val nvtxRange = new NvtxRange("getMapSizesByExecId", NvtxColor.CYAN) val blocksByAddress = try { - ShimLoader.getSparkShims.getMapSizesByExecutorId(gpu.shuffleId, + SparkShimImpl.getMapSizesByExecutorId(gpu.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition) } finally { nvtxRange.close() @@ -451,9 +452,8 @@ abstract class ProxyRapidsShuffleInternalManagerBase( // touched in the plugin code after the shim initialization // is complete - lazy val self: ShuffleManager = - ShimLoader.newInternalShuffleManager(conf, isDriver) - .asInstanceOf[ShuffleManager] + lazy val self: ShuffleManager = ShimLoader.newInternalShuffleManager(conf, isDriver) + .asInstanceOf[ShuffleManager] // This function touches the lazy val `self` so we actually instantiate // the manager. This is called from both the driver and executor. 
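The `lazy val self` reflow in `ProxyRapidsShuffleInternalManagerBase` above is cosmetic, but the pattern it preserves is worth spelling out: the proxy must be constructible before the shim classloader exists, so the real shuffle manager is materialized lazily and only forced once shim initialization is complete. The core of the pattern, with a stand-in trait:

```
trait ManagerSketch { def stop(): Unit }

class ProxyManagerSketch(mk: () => ManagerSketch) extends ManagerSketch {
  // cheap, classloader-safe construction; the real manager is deferred
  private lazy val self: ManagerSketch = mk()
  // touching `self` is what actually instantiates the manager
  def initialize(): Unit = self
  // every other member just delegates
  override def stop(): Unit = self.stop()
}
```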
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala index 54005dcd465..01476d6ed18 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/arithmetic.scala @@ -22,7 +22,7 @@ import ai.rapids.cudf._ import ai.rapids.cudf.ast.BinaryOperator import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.RapidsPluginImplicits._ -import com.nvidia.spark.rapids.shims.v2.ShimExpression +import com.nvidia.spark.rapids.shims.{ShimExpression, SparkShimImpl} import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.{ComplexTypeMergingExpression, ExpectsInputTypes, Expression, NullIntolerant} @@ -662,7 +662,7 @@ object GpuDivModLike extends Arm { trait GpuDivModLike extends CudfBinaryArithmetic { lazy val failOnError: Boolean = - ShimLoader.getSparkShims.shouldFailDivByZero() + SparkShimImpl.shouldFailDivByZero() override def nullable: Boolean = true @@ -728,7 +728,7 @@ case class GpuDecimalDivide( left: Expression, right: Expression, dataType: DecimalType, - failOnError: Boolean = ShimLoader.getSparkShims.shouldFailDivByZero()) extends + failOnError: Boolean = SparkShimImpl.shouldFailDivByZero()) extends ShimExpression with GpuExpression { override def toString: String = s"($left / $right)" @@ -856,7 +856,7 @@ object GpuDecimalDivide { } case class GpuDivide(left: Expression, right: Expression, - failOnErrorOverride: Boolean = ShimLoader.getSparkShims.shouldFailDivByZero()) + failOnErrorOverride: Boolean = SparkShimImpl.shouldFailDivByZero()) extends GpuDivModLike { assert(!left.dataType.isInstanceOf[DecimalType], "DecimalType divides need to be handled by GpuDecimalDivide") @@ -876,7 +876,7 @@ case class GpuIntegralDivide(left: Expression, right: Expression) extends GpuDiv override def inputType: AbstractDataType = TypeCollection(IntegralType, DecimalType) lazy val failOnOverflow: Boolean = - ShimLoader.getSparkShims.shouldFailDivOverflow + SparkShimImpl.shouldFailDivOverflow override def checkDivideOverflow: Boolean = left.dataType match { case LongType if failOnOverflow => true diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/catalyst/expressions/GpuRandomExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/catalyst/expressions/GpuRandomExpressions.scala index b9e83d616e0..7d2b789ff1d 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/catalyst/expressions/GpuRandomExpressions.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/catalyst/expressions/GpuRandomExpressions.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2022, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
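In `arithmetic.scala` below, the shim answers behavioral questions, not just API ones: which arithmetic operations throw under ANSI mode differs by Spark release, so the `failOnError` defaults are delegated to `shouldFailDivByZero()` and `shouldFailDivOverflow`. A plausible newer-shim body (an assumption for illustration, not the plugin's exact code):

```
import org.apache.spark.sql.internal.SQLConf

object AnsiFlagSketch {
  // on releases where ANSI mode makes x / 0 throw, defer to the session setting
  def shouldFailDivByZero(): Boolean = SQLConf.get.ansiEnabled
}
```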
@@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.catalyst.expressions

 import ai.rapids.cudf.{DType, HostColumnVector}
 import com.nvidia.spark.rapids.{GpuColumnVector, GpuExpression, GpuLiteral}
-import com.nvidia.spark.rapids.shims.v2.ShimUnaryExpression
+import com.nvidia.spark.rapids.shims.ShimUnaryExpression

 import org.apache.spark.TaskContext
 import org.apache.spark.sql.AnalysisException
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala
index de9aa6fbf54..517a861e742 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala
@@ -27,7 +27,7 @@ import com.nvidia.spark.rapids.ArrayIndexUtils.firstIndexAndNumElementUnchecked
 import com.nvidia.spark.rapids.BoolUtils.isAllValidTrue
 import com.nvidia.spark.rapids.GpuExpressionsUtils.columnarEvalToColumn
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
-import com.nvidia.spark.rapids.shims.v2.{RapidsErrorUtils, ShimExpression}
+import com.nvidia.spark.rapids.shims.{RapidsErrorUtils, ShimExpression}

 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ImplicitCastInputTypes, RowOrdering, Sequence, TimeZoneAwareExpression}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeCreator.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeCreator.scala
index f97e222acff..ce4c673442c 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeCreator.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeCreator.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.rapids
 import ai.rapids.cudf.{ColumnVector, ColumnView, DType}
 import com.nvidia.spark.rapids.{Arm, GpuColumnVector, GpuExpression, GpuExpressionsUtils, GpuMapUtils}
 import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingSeq
-import com.nvidia.spark.rapids.shims.v2.ShimExpression
+import com.nvidia.spark.rapids.shims.ShimExpression

 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FUNC_ALIAS
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala
index eebc2cb406a..51f077c5ff9 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala
@@ -21,7 +21,7 @@ import com.nvidia.spark.rapids.{BinaryExprMeta, DataFromReplacementRule, GpuBina
 import com.nvidia.spark.rapids.ArrayIndexUtils.firstIndexAndNumElementUnchecked
 import com.nvidia.spark.rapids.BoolUtils.isAnyValidTrue
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
-import com.nvidia.spark.rapids.shims.v2.{RapidsErrorUtils, ShimUnaryExpression}
+import com.nvidia.spark.rapids.shims.{RapidsErrorUtils, ShimUnaryExpression}

 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala
index 6b0e611368e..744376d0c50 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/datetimeExpressions.scala
@@ -22,7 +22,7 @@ import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, Scalar}
 import com.nvidia.spark.rapids.{Arm, BinaryExprMeta, DataFromReplacementRule, DateUtils, GpuBinaryExpression, GpuColumnVector, GpuExpression, GpuScalar, GpuUnaryExpression, RapidsConf, RapidsMeta}
 import com.nvidia.spark.rapids.GpuOverrides.{extractStringLit, getTimeParserPolicy}
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
-import com.nvidia.spark.rapids.shims.v2.ShimBinaryExpression
+import com.nvidia.spark.rapids.shims.ShimBinaryExpression

 import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, ExpectsInputTypes, Expression, ImplicitCastInputTypes, NullIntolerant, TimeZoneAwareExpression}
 import org.apache.spark.sql.types._
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala
index f3e5737c0ca..5d3959bd0e6 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastExchangeExec.scala
@@ -29,7 +29,7 @@ import ai.rapids.cudf.JCudfSerialization.SerializedTableHeader
 import com.nvidia.spark.rapids._
 import com.nvidia.spark.rapids.GpuMetric._
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
-import com.nvidia.spark.rapids.shims.v2.{ShimBroadcastExchangeLike, ShimUnaryExecNode}
+import com.nvidia.spark.rapids.shims.{ShimBroadcastExchangeLike, ShimUnaryExecNode, SparkShimImpl}

 import org.apache.spark.SparkException
 import org.apache.spark.broadcast.Broadcast
@@ -363,7 +363,7 @@ abstract class GpuBroadcastExchangeExecBase(
         val d = data.collect()
         val emptyRelation: Option[Any] = if (d.isEmpty) {
-          ShimLoader.getSparkShims.tryTransformIfEmptyRelation(mode)
+          SparkShimImpl.tryTransformIfEmptyRelation(mode)
         } else {
           None
         }
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHelper.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHelper.scala
index 47b9c1cada2..988d855dbf7 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHelper.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastHelper.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,7 +16,8 @@
 package org.apache.spark.sql.rapids.execution

-import com.nvidia.spark.rapids.{GpuColumnVector, ShimLoader}
+import com.nvidia.spark.rapids.GpuColumnVector
+import com.nvidia.spark.rapids.shims.SparkShimImpl

 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.types.StructType
@@ -43,7 +44,7 @@ object GpuBroadcastHelper {
         val builtBatch = broadcastBatch.batch
         GpuColumnVector.incRefCounts(builtBatch)
         builtBatch
-      case v if ShimLoader.getSparkShims.isEmptyRelation(v) =>
+      case v if SparkShimImpl.isEmptyRelation(v) =>
         GpuColumnVector.emptyBatch(broadcastSchema)
       case t =>
         throw new IllegalStateException(s"Invalid broadcast batch received $t")
@@ -67,7 +68,7 @@ object GpuBroadcastHelper {
     broadcastRelation.value match {
       case broadcastBatch: SerializeConcatHostBuffersDeserializeBatch =>
         broadcastBatch.batch.numRows()
-      case v if ShimLoader.getSparkShims.isEmptyRelation(v) => 0
+      case v if SparkShimImpl.isEmptyRelation(v) => 0
       case t =>
         throw new IllegalStateException(s"Invalid broadcast batch received $t")
     }
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala
index b232f3d2440..cfc7c946425 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuBroadcastNestedLoopJoinExec.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.execution

 import ai.rapids.cudf.{ast, GatherMap, NvtxColor, OutOfBoundsPolicy, Table}
 import com.nvidia.spark.rapids._
-import com.nvidia.spark.rapids.shims.v2.{GpuJoinUtils, ShimBinaryExecNode}
+import com.nvidia.spark.rapids.shims.{GpuJoinUtils, ShimBinaryExecNode}

 import org.apache.spark.TaskContext
 import org.apache.spark.broadcast.Broadcast
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuCustomShuffleReaderExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuCustomShuffleReaderExec.scala
index 7107882d9ef..95d4da40b45 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuCustomShuffleReaderExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuCustomShuffleReaderExec.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,8 +15,8 @@
  */
 package org.apache.spark.sql.rapids.execution

-import com.nvidia.spark.rapids.{CoalesceGoal, GpuExec, GpuMetric, ShimLoader}
-import com.nvidia.spark.rapids.shims.v2.ShimUnaryExecNode
+import com.nvidia.spark.rapids.{CoalesceGoal, GpuExec, GpuMetric}
+import com.nvidia.spark.rapids.shims.{ShimUnaryExecNode, SparkShimImpl}

 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -120,7 +120,7 @@ case class GpuCustomShuffleReaderExec(
     if (cachedShuffleRDD == null) {
       cachedShuffleRDD = child match {
         case stage: ShuffleQueryStageExec =>
-          val shuffle = ShimLoader.getSparkShims.getGpuShuffleExchangeExec(stage)
+          val shuffle = SparkShimImpl.getGpuShuffleExchangeExec(stage)
           new ShuffledBatchRDD(
             shuffle.shuffleDependencyColumnar, shuffle.readMetrics ++ metrics,
             partitionSpecs.toArray)
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase.scala
index 0dcaaaaabed..2105a2e1139 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuShuffleExchangeExecBase.scala
@@ -21,7 +21,7 @@ import scala.concurrent.Future

 import com.nvidia.spark.rapids._
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
-import com.nvidia.spark.rapids.shims.v2.{GpuHashPartitioning, GpuRangePartitioning, ShimUnaryExecNode}
+import com.nvidia.spark.rapids.shims.{GpuHashPartitioning, GpuRangePartitioning, ShimUnaryExecNode, SparkShimImpl}

 import org.apache.spark.{MapOutputStatistics, ShuffleDependency}
 import org.apache.spark.rdd.RDD
@@ -82,7 +82,7 @@ class GpuShuffleMeta(
     shuffle.outputPartitioning match {
       case _: RoundRobinPartitioning
-          if ShimLoader.getSparkShims.sessionFromPlan(shuffle).sessionState.conf
+          if SparkShimImpl.sessionFromPlan(shuffle).sessionState.conf
             .sortBeforeRepartition =>
         val orderableTypes = GpuOverrides.pluginSupportedOrderableSig + TypeSig.DECIMAL_128
         shuffle.output.map(_.dataType)
@@ -107,7 +107,7 @@ class GpuShuffleMeta(
   }

   override def convertToGpu(): GpuExec =
-    ShimLoader.getSparkShims.getGpuShuffleExchangeExec(
+    SparkShimImpl.getGpuShuffleExchangeExec(
       childParts.head.convertToGpu(),
       childPlans.head.convertIfNeeded(),
       shuffle.outputPartitioning,
@@ -218,7 +218,7 @@ abstract class GpuShuffleExchangeExecBase(
   protected override def doExecute(): RDD[InternalRow] =
     throw new IllegalStateException(s"Row-based execution should not occur for $this")

-  override def doExecuteColumnar(): RDD[ColumnarBatch] = ShimLoader.getSparkShims
+  override def doExecuteColumnar(): RDD[ColumnarBatch] = SparkShimImpl
     .attachTreeIfSupported(this, "execute") {
       // Returns the same ShuffleRowRDD if this plan is used by multiple plans.
       if (cachedShuffleRDD == null) {
@@ -254,7 +254,7 @@ object GpuShuffleExchangeExecBase {
      * task when indeterminate tasks re-run.
      */
     val newRdd = if (isRoundRobin && SQLConf.get.sortBeforeRepartition) {
-      val shim = ShimLoader.getSparkShims
+      val shim = SparkShimImpl
       val boundReferences = outputAttributes.zipWithIndex.map { case (attr, index) =>
         shim.sortOrder(GpuBoundReference(index, attr.dataType,
           attr.nullable)(attr.exprId, attr.name), Ascending)
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubqueryBroadcastExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubqueryBroadcastExec.scala
index 181b4200e2b..d5cf8a58f45 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubqueryBroadcastExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/GpuSubqueryBroadcastExec.scala
@@ -20,9 +20,9 @@ import scala.collection.JavaConverters.asScalaIteratorConverter
 import scala.concurrent.{ExecutionContext, Future}
 import scala.concurrent.duration.Duration

-import com.nvidia.spark.rapids.{BaseExprMeta, DataFromReplacementRule, GpuColumnarToRowExecParent, GpuExec, GpuMetric, RapidsConf, RapidsMeta, ShimLoader, SparkPlanMeta, TargetSize}
+import com.nvidia.spark.rapids.{BaseExprMeta, DataFromReplacementRule, GpuColumnarToRowExecParent, GpuExec, GpuMetric, RapidsConf, RapidsMeta, SparkPlanMeta, TargetSize}
 import com.nvidia.spark.rapids.GpuMetric.{COLLECT_TIME, DESCRIPTION_COLLECT_TIME, ESSENTIAL_LEVEL}
-import com.nvidia.spark.rapids.shims.v2.ShimUnaryExecNode
+import com.nvidia.spark.rapids.shims.{ShimUnaryExecNode, SparkShimImpl}

 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
@@ -108,13 +108,13 @@ class GpuSubqueryBroadcastMeta(
       //   +- [GPU overrides of executed subquery...]
       //
       case a: AdaptiveSparkPlanExec =>
-        ShimLoader.getSparkShims.getAdaptiveInputPlan(a) match {
+        SparkShimImpl.getAdaptiveInputPlan(a) match {
           case ex: BroadcastExchangeExec =>
             val exMeta = new GpuBroadcastMeta(ex, conf, p, r)
             exMeta.tagForGpu()
             if (exMeta.canThisBeReplaced) {
               broadcastBuilder = () =>
-                ShimLoader.getSparkShims.columnarAdaptivePlan(
+                SparkShimImpl.columnarAdaptivePlan(
                   a, TargetSize(conf.gpuTargetBatchSizeBytes))
             } else {
               willNotWorkOnGpu("underlying BroadcastExchange can not run in the GPU.")
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/ShuffledBatchRDD.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/ShuffledBatchRDD.scala
index 96aaced10fd..17fcebc3fc7 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/ShuffledBatchRDD.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/ShuffledBatchRDD.scala
@@ -22,7 +22,7 @@ import java.util.Arrays
 import com.nvidia.spark.rapids.GpuMetric

 import org.apache.spark.{Dependency, Partition, Partitioner, ShuffleDependency, TaskContext}
-import org.apache.spark.rapids.shims.v2.ShuffledBatchRDDUtil
+import org.apache.spark.rapids.shims.ShuffledBatchRDDUtil
 import org.apache.spark.rdd.RDD
 import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.sql.execution.{CoalescedPartitioner, CoalescedPartitionSpec, ShufflePartitionSpec}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuAggregateInPandasExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuAggregateInPandasExec.scala
index 6ccc6a8a09d..a172db5f52c 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuAggregateInPandasExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuAggregateInPandasExec.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2021, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@ import ai.rapids.cudf
 import com.nvidia.spark.rapids._
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
 import com.nvidia.spark.rapids.python.PythonWorkerSemaphore
-import com.nvidia.spark.rapids.shims.v2.ShimUnaryExecNode
+import com.nvidia.spark.rapids.shims.{ShimUnaryExecNode, SparkShimImpl}

 import org.apache.spark.TaskContext
 import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
@@ -114,7 +114,7 @@ case class GpuAggregateInPandasExec(
   }

   override def requiredChildOrdering: Seq[Seq[SortOrder]] =
-    Seq(cpuGroupingExpressions.map(ShimLoader.getSparkShims.sortOrder(_, Ascending)))
+    Seq(cpuGroupingExpressions.map(SparkShimImpl.sortOrder(_, Ascending)))

   // One batch as input to keep the integrity for each group.
   // (This should be replaced by an iterator that can split batches on key boundaries eventually.
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowEvalPythonExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowEvalPythonExec.scala
index 005e4dab10b..0701dbfc637 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowEvalPythonExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuArrowEvalPythonExec.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2021, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION.
  *
  * Licensed to the Apache Software Foundation (ASF) under one or more
  * contributor license agreements.  See the NOTICE file distributed with
@@ -31,11 +31,11 @@ import com.nvidia.spark.rapids._
 import com.nvidia.spark.rapids.GpuMetric._
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
 import com.nvidia.spark.rapids.python.PythonWorkerSemaphore
-import com.nvidia.spark.rapids.shims.v2.ShimUnaryExecNode
+import com.nvidia.spark.rapids.shims.ShimUnaryExecNode

 import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.api.python._
-import org.apache.spark.rapids.shims.v2.api.python.ShimBasePythonRunner
+import org.apache.spark.rapids.shims.api.python.ShimBasePythonRunner
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapCoGroupsInPandasExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapCoGroupsInPandasExec.scala
index 0ff3bc55cf6..0390a6ca52c 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapCoGroupsInPandasExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuFlatMapCoGroupsInPandasExec.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2021, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@ package org.apache.spark.sql.rapids.execution.python

 import com.nvidia.spark.rapids._
 import com.nvidia.spark.rapids.python.PythonWorkerSemaphore
-import com.nvidia.spark.rapids.shims.v2.ShimBinaryExecNode
+import com.nvidia.spark.rapids.shims.{ShimBinaryExecNode, SparkShimImpl}

 import org.apache.spark.TaskContext
 import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
@@ -99,8 +99,8 @@ case class GpuFlatMapCoGroupsInPandasExec(
   }

   override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
-    leftGroup.map(ShimLoader.getSparkShims.sortOrder(_, Ascending)) ::
-      rightGroup.map(ShimLoader.getSparkShims.sortOrder(_, Ascending)) :: Nil
+    leftGroup.map(SparkShimImpl.sortOrder(_, Ascending)) ::
+      rightGroup.map(SparkShimImpl.sortOrder(_, Ascending)) :: Nil
   }

   override protected def doExecute(): RDD[InternalRow] = {
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuMapInPandasExec.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuMapInPandasExec.scala
index 8ef1de5721f..86ce99f3d5e 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuMapInPandasExec.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuMapInPandasExec.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2021, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,7 +19,7 @@ package org.apache.spark.sql.rapids.execution.python
 import ai.rapids.cudf
 import com.nvidia.spark.rapids._
 import com.nvidia.spark.rapids.python.PythonWorkerSemaphore
-import com.nvidia.spark.rapids.shims.v2.ShimUnaryExecNode
+import com.nvidia.spark.rapids.shims.ShimUnaryExecNode

 import org.apache.spark.TaskContext
 import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExecBase.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExecBase.scala
index e1a63cae3a4..ef2ca077b34 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExecBase.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/execution/python/GpuWindowInPandasExecBase.scala
@@ -25,7 +25,7 @@ import com.nvidia.spark.rapids._
 import com.nvidia.spark.rapids.GpuMetric._
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
 import com.nvidia.spark.rapids.python.PythonWorkerSemaphore
-import com.nvidia.spark.rapids.shims.v2.ShimUnaryExecNode
+import com.nvidia.spark.rapids.shims.{ShimUnaryExecNode, SparkShimImpl}

 import org.apache.spark.TaskContext
 import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
@@ -205,7 +205,7 @@ trait GpuWindowInPandasExecBase extends ShimUnaryExecNode with GpuExec {
   }

   override def requiredChildOrdering: Seq[Seq[SortOrder]] =
-    Seq(cpuPartitionSpec.map(ShimLoader.getSparkShims.sortOrder(_, Ascending)) ++ cpuOrderSpec)
+    Seq(cpuPartitionSpec.map(SparkShimImpl.sortOrder(_, Ascending)) ++ cpuOrderSpec)

   override def outputOrdering: Seq[SortOrder] = child.outputOrdering
diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
index 7fb73f48fb3..0084d613a78 100644
--- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
+++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/stringFunctions.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer
 import ai.rapids.cudf.{BinaryOp, ColumnVector, ColumnView, DType, PadSide, Scalar, Table}
 import com.nvidia.spark.rapids._
 import com.nvidia.spark.rapids.RapidsPluginImplicits._
-import com.nvidia.spark.rapids.shims.v2.ShimExpression
+import com.nvidia.spark.rapids.shims.ShimExpression

 import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ImplicitCastInputTypes, Literal, NullIntolerant, Predicate, RegExpExtract, RLike, StringSplit, StringToMap, SubstringIndex, TernaryExpression}
 import org.apache.spark.sql.types._
diff --git a/sql-plugin/src/test/320/scala/com/nvidia/spark/rapids/shims/spark320/Spark320ShimsSuite.scala b/sql-plugin/src/test/320/scala/com/nvidia/spark/rapids/shims/spark320/Spark320ShimsSuite.scala
index b8a3fceff84..1a41ce2e8b8 100644
--- a/sql-plugin/src/test/320/scala/com/nvidia/spark/rapids/shims/spark320/Spark320ShimsSuite.scala
+++ b/sql-plugin/src/test/320/scala/com/nvidia/spark/rapids/shims/spark320/Spark320ShimsSuite.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,15 +16,15 @@

 package com.nvidia.spark.rapids.shims.spark320;

-import com.nvidia.spark.rapids.{ShimLoader, SparkShims, SparkShimVersion, TypeSig}
+import com.nvidia.spark.rapids.{ShimLoader, SparkShimVersion, TypeSig}
+import com.nvidia.spark.rapids.shims.SparkShimImpl
 import org.scalatest.FunSuite

 import org.apache.spark.sql.types.{DayTimeIntervalType, YearMonthIntervalType}

 class Spark320ShimsSuite extends FunSuite {
-  val sparkShims: SparkShims = new SparkShimServiceProvider().buildShim
   test("spark shims version") {
-    assert(sparkShims.getSparkShimVersion === SparkShimVersion(3, 2, 0))
+    assert(SparkShimImpl.getSparkShimVersion === SparkShimVersion(3, 2, 0))
   }

   test("shuffle manager class") {
diff --git a/sql-plugin/src/test/321/scala/com/nvidia/spark/rapids/shims/spark321/Spark321ShimsSuite.scala b/sql-plugin/src/test/321/scala/com/nvidia/spark/rapids/shims/spark321/Spark321ShimsSuite.scala
index 64956a8eb49..f9b1265c2b2 100644
--- a/sql-plugin/src/test/321/scala/com/nvidia/spark/rapids/shims/spark321/Spark321ShimsSuite.scala
+++ b/sql-plugin/src/test/321/scala/com/nvidia/spark/rapids/shims/spark321/Spark321ShimsSuite.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,15 +16,15 @@

 package com.nvidia.spark.rapids.shims.spark321;

-import com.nvidia.spark.rapids.{ShimLoader, SparkShims, SparkShimVersion, TypeSig}
+import com.nvidia.spark.rapids.{ShimLoader, SparkShimVersion, TypeSig}
+import com.nvidia.spark.rapids.shims.SparkShimImpl
 import org.scalatest.FunSuite

 import org.apache.spark.sql.types.{DayTimeIntervalType, YearMonthIntervalType}

 class Spark321ShimsSuite extends FunSuite {
-  val sparkShims: SparkShims = new SparkShimServiceProvider().buildShim
   test("spark shims version") {
-    assert(sparkShims.getSparkShimVersion === SparkShimVersion(3, 2, 1))
+    assert(SparkShimImpl.getSparkShimVersion === SparkShimVersion(3, 2, 1))
   }

   test("shuffle manager class") {
diff --git a/sql-plugin/src/test/322/scala/com/nvidia/spark/rapids/shims/spark322/Spark322ShimsSuite.scala b/sql-plugin/src/test/322/scala/com/nvidia/spark/rapids/shims/spark322/Spark322ShimsSuite.scala
index 95295634f4f..26f0f0fdd45 100644
--- a/sql-plugin/src/test/322/scala/com/nvidia/spark/rapids/shims/spark322/Spark322ShimsSuite.scala
+++ b/sql-plugin/src/test/322/scala/com/nvidia/spark/rapids/shims/spark322/Spark322ShimsSuite.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,15 +16,15 @@

 package com.nvidia.spark.rapids.shims.spark322;

-import com.nvidia.spark.rapids.{ShimLoader, SparkShims, SparkShimVersion, TypeSig}
+import com.nvidia.spark.rapids.{ShimLoader, SparkShimVersion, TypeSig}
+import com.nvidia.spark.rapids.shims.SparkShimImpl
 import org.scalatest.FunSuite

 import org.apache.spark.sql.types.{DayTimeIntervalType, YearMonthIntervalType}

 class Spark322ShimsSuite extends FunSuite {
-  val sparkShims: SparkShims = new SparkShimServiceProvider().buildShim
   test("spark shims version") {
-    assert(sparkShims.getSparkShimVersion === SparkShimVersion(3, 2, 2))
+    assert(SparkShimImpl.getSparkShimVersion === SparkShimVersion(3, 2, 2))
   }

   test("shuffle manager class") {
diff --git a/sql-plugin/src/test/330/scala/com/nvidia/spark/rapids/shims/spark330/Spark330ShimsSuite.scala b/sql-plugin/src/test/330/scala/com/nvidia/spark/rapids/shims/spark330/Spark330ShimsSuite.scala
index bf77015e465..4691068b413 100644
--- a/sql-plugin/src/test/330/scala/com/nvidia/spark/rapids/shims/spark330/Spark330ShimsSuite.scala
+++ b/sql-plugin/src/test/330/scala/com/nvidia/spark/rapids/shims/spark330/Spark330ShimsSuite.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,15 +16,15 @@

 package com.nvidia.spark.rapids.shims.spark330;

-import com.nvidia.spark.rapids.{ShimLoader, SparkShims, SparkShimVersion, TypeSig}
+import com.nvidia.spark.rapids.{ShimLoader, SparkShimVersion, TypeSig}
+import com.nvidia.spark.rapids.shims.SparkShimImpl
 import org.scalatest.FunSuite

 import org.apache.spark.sql.types.{DayTimeIntervalType, YearMonthIntervalType}

 class Spark330ShimsSuite extends FunSuite {
-  val sparkShims: SparkShims = new SparkShimServiceProvider().buildShim
   test("spark shims version") {
-    assert(sparkShims.getSparkShimVersion === SparkShimVersion(3, 3, 0))
+    assert(SparkShimImpl.getSparkShimVersion === SparkShimVersion(3, 3, 0))
   }

   test("shuffle manager class") {
diff --git a/tests-spark310+/src/test/scala/com/nvidia/spark/rapids/shims/v2/Spark310ParquetWriterSuite.scala b/tests-spark310+/src/test/scala/com/nvidia/spark/rapids/shims/Spark310ParquetWriterSuite.scala
similarity index 99%
rename from tests-spark310+/src/test/scala/com/nvidia/spark/rapids/shims/v2/Spark310ParquetWriterSuite.scala
rename to tests-spark310+/src/test/scala/com/nvidia/spark/rapids/shims/Spark310ParquetWriterSuite.scala
index fda83439f54..636d4e0086c 100644
--- a/tests-spark310+/src/test/scala/com/nvidia/spark/rapids/shims/v2/Spark310ParquetWriterSuite.scala
+++ b/tests-spark310+/src/test/scala/com/nvidia/spark/rapids/shims/Spark310ParquetWriterSuite.scala
@@ -14,7 +14,7 @@
  * limitations under the License.
  */

-package com.nvidia.spark.rapids.shims.v2
+package com.nvidia.spark.rapids.shims

 import scala.collection.mutable
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala
index 78913447dae..b34f6f491d9 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/AdaptiveQueryExecSuite.scala
@@ -18,6 +18,8 @@ package com.nvidia.spark.rapids

 import java.io.File

+import com.nvidia.spark.rapids.shims.SparkShimImpl
+
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
@@ -88,7 +90,7 @@ class AdaptiveQueryExecSuite
   }

   private def findReusedExchange(plan: SparkPlan): Seq[ReusedExchangeExec] = {
-    collectWithSubqueries(plan)(ShimLoader.getSparkShims.reusedExchangeExecPfn)
+    collectWithSubqueries(plan)(SparkShimImpl.reusedExchangeExecPfn)
   }

   test("get row counts from executed shuffle query stages") {
@@ -350,7 +352,7 @@ class AdaptiveQueryExecSuite
       _.isInstanceOf[AdaptiveSparkPlanExec])
       .get.asInstanceOf[AdaptiveSparkPlanExec]

-    if (ShimLoader.getSparkShims.supportsColumnarAdaptivePlans) {
+    if (SparkShimImpl.supportsColumnarAdaptivePlans) {
       // we avoid the transition entirely with Spark 3.2+ due to the changes in SPARK-35881 to
       // support columnar adaptive plans
       assert(adaptiveSparkPlanExec
@@ -747,4 +749,4 @@ class AdaptiveQueryExecSuite
     spark.read.parquet(path).createOrReplaceTempView(name)
   }

-}
\ No newline at end of file
+}
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/AnsiCastOpSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/AnsiCastOpSuite.scala
index 491227b05bf..d0aeb1c6508 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/AnsiCastOpSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/AnsiCastOpSuite.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2021, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -21,6 +21,8 @@ import java.time.DateTimeException

 import scala.util.Random

+import com.nvidia.spark.rapids.shims.SparkShimImpl
+
 import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.sql.{DataFrame, Row, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.{Alias, AnsiCast, CastBase, NamedExpression}
@@ -789,7 +791,7 @@ class AnsiCastOpSuite extends GpuExpressionTestSuite {

   private def assertContainsAnsiCast(df: DataFrame, expected: Int = 1): DataFrame = {
-    val projections = ShimLoader.getSparkShims.findOperators(df.queryExecution.executedPlan, {
+    val projections = SparkShimImpl.findOperators(df.queryExecution.executedPlan, {
       case _: ProjectExec | _: GpuProjectExec => true
       case _ => false
     })
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastNestedLoopJoinSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastNestedLoopJoinSuite.scala
index c447bd58f9f..500031d97d6 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastNestedLoopJoinSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/BroadcastNestedLoopJoinSuite.scala
@@ -16,6 +16,8 @@

 package com.nvidia.spark.rapids

+import com.nvidia.spark.rapids.shims.SparkShimImpl
+
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.functions.broadcast
 import org.apache.spark.sql.internal.SQLConf
@@ -58,7 +60,7 @@ class BroadcastNestedLoopJoinSuite extends SparkQueryCompareTestSuite {
     val nljCount = PlanUtils.findOperators(plan,
       _.isInstanceOf[GpuBroadcastNestedLoopJoinExec])

-    ShimLoader.getSparkShims.getSparkShimVersion match {
+    SparkShimImpl.getSparkShimVersion match {
       case SparkShimVersion(3, 0, 0) =>
         // we didn't start supporting GPU exchanges with AQE until 3.0.1
         assert(nljCount.size === 0)
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/ParquetWriterSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/ParquetWriterSuite.scala
index 887d74f9386..4057e367cf1 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/ParquetWriterSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/ParquetWriterSuite.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2021, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@ package com.nvidia.spark.rapids
 import java.io.File
 import java.nio.charset.StandardCharsets

+import com.nvidia.spark.rapids.shims.SparkShimImpl
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}
 import org.apache.parquet.hadoop.ParquetFileReader
@@ -150,7 +151,7 @@ class ParquetWriterSuite extends SparkQueryCompareTestSuite {
     try {
       spark.sql("CREATE TABLE t(id STRING) USING PARQUET")
       val df = spark.sql("INSERT INTO TABLE t SELECT 'abc'")
-      val insert = ShimLoader.getSparkShims.findOperators(df.queryExecution.executedPlan,
+      val insert = SparkShimImpl.findOperators(df.queryExecution.executedPlan,
         _.isInstanceOf[GpuDataWritingCommandExec]).head
         .asInstanceOf[GpuDataWritingCommandExec]
       assert(insert.metrics.contains(BasicColumnarWriteJobStatsTracker.JOB_COMMIT_TIME))
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala
index 0f66d6da249..2279ea5c77e 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/RegularExpressionSuite.scala
@@ -15,6 +15,8 @@
  */
 package com.nvidia.spark.rapids

+import com.nvidia.spark.rapids.shims.SparkShimImpl
+
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.{DataFrame, SparkSession}
@@ -38,7 +40,7 @@ class RegularExpressionSuite extends SparkQueryCompareTestSuite {
     frame => {
       // this test is only valid in Spark 3.0.x because the expression is NullIntolerant
       // since Spark 3.1.0 and gets replaced with a null literal instead
-      val isValidTestForSparkVersion = ShimLoader.getSparkShims.getSparkShimVersion match {
+      val isValidTestForSparkVersion = SparkShimImpl.getSparkShimVersion match {
         case SparkShimVersion(major, minor, _) => major == 3 && minor == 0
         case DatabricksShimVersion(major, minor, _, _) => major == 3 && minor == 0
         case ClouderaShimVersion(major, minor, _, _) => major == 3 && minor == 0
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
index 6d4923f9a20..3226e4955e8 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/SparkQueryCompareTestSuite.scala
@@ -20,11 +20,11 @@ import java.nio.file.Files
 import java.sql.{Date, Timestamp}
 import java.util.{Locale, TimeZone}

+import com.nvidia.spark.rapids.shims.SparkShimImpl
+import org.scalatest.{Assertion, FunSuite}
 import scala.reflect.ClassTag
 import scala.util.{Failure, Try}

-import org.scalatest.{Assertion, FunSuite}
-
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, Row, SparkSession}
@@ -1839,7 +1839,7 @@ trait SparkQueryCompareTestSuite extends FunSuite with Arm {
     assume(!VersionUtils.isSpark311OrLater, "Spark version not before 3.1.1")

   def cmpSparkVersion(major: Int, minor: Int, bugfix: Int): Int = {
-    val sparkShimVersion = ShimLoader.getSparkShims.getSparkShimVersion
+    val sparkShimVersion = SparkShimImpl.getSparkShimVersion
     val (sparkMajor, sparkMinor, sparkBugfix) = sparkShimVersion match {
       case SparkShimVersion(a, b, c) => (a, b, c)
       case DatabricksShimVersion(a, b, c, _) => (a, b, c)
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala b/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala
index b87a944d1f0..1d12813a873 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/TestUtils.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2021, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,9 +16,9 @@
 package com.nvidia.spark.rapids

-import java.io.File
-
 import ai.rapids.cudf.{ColumnVector, DType, HostColumnVectorCore, Table}
+import com.nvidia.spark.rapids.shims.SparkShimImpl
+import java.io.File
 import org.scalatest.Assertions

 import org.apache.spark.SparkConf
@@ -63,7 +63,7 @@ object TestUtils extends Assertions with Arm {

   /** Recursively check if the predicate matches in the given plan */
   def findOperator(plan: SparkPlan, predicate: SparkPlan => Boolean): Option[SparkPlan] = {
-    ShimLoader.getSparkShims.findOperators(plan, predicate).headOption
+    SparkShimImpl.findOperators(plan, predicate).headOption
   }

   /** Return final executed plan */
diff --git a/tests/src/test/scala/org/apache/spark/sql/GpuSparkPlanSuite.scala b/tests/src/test/scala/org/apache/spark/sql/GpuSparkPlanSuite.scala
index b9569f5727e..71fce9197bc 100644
--- a/tests/src/test/scala/org/apache/spark/sql/GpuSparkPlanSuite.scala
+++ b/tests/src/test/scala/org/apache/spark/sql/GpuSparkPlanSuite.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2021, NVIDIA CORPORATION.
+ * Copyright (c) 2021-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,7 +16,8 @@

 package org.apache.spark.sql

-import com.nvidia.spark.rapids.{ShimLoader, SparkSessionHolder}
+import com.nvidia.spark.rapids.SparkSessionHolder
+import com.nvidia.spark.rapids.shims.SparkShimImpl
 import org.scalatest.FunSuite

 import org.apache.spark.SparkConf
@@ -31,7 +32,7 @@ class GpuSparkPlanSuite extends FunSuite {
       .set("spark.rapids.sql.enabled", "true")

     SparkSessionHolder.withSparkSession(conf, spark => {
-      val defaultSlice = ShimLoader.getSparkShims.leafNodeDefaultParallelism(spark)
+      val defaultSlice = SparkShimImpl.leafNodeDefaultParallelism(spark)
       val ds = new Dataset(spark, Range(0, 20, 1, None), Encoders.LONG)
       val partitions = ds.rdd.getNumPartitions
       assert(partitions == defaultSlice)
diff --git a/tools/pom.xml b/tools/pom.xml
index 25fe91d7f5f..e90a275070e 100644
--- a/tools/pom.xml
+++ b/tools/pom.xml
@@ -32,7 +32,7 @@
     <packaging>jar</packaging>

-    <spark311.version>3.1.1</spark311.version>
+    <spark311.version>3.1.4</spark311.version>
     ${spark311.version}
     ${spark311.version}
     spark311
diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/GpuScalaUDF.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/GpuScalaUDF.scala
index 22e651d29c6..bee2f73a3bf 100644
--- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/GpuScalaUDF.scala
+++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/GpuScalaUDF.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2020, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@ package com.nvidia.spark.udf

 import scala.util.control.NonFatal

-import com.nvidia.spark.rapids.shims.v2.ShimExpression
+import com.nvidia.spark.rapids.shims.ShimExpression

 import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala
index be246bbe56d..fa9a7192ca0 100644
--- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala
+++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/Instruction.scala
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2021, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2022, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@ package com.nvidia.spark.udf

 import java.nio.charset.Charset

-import com.nvidia.spark.rapids.shims.v2.ShimExpression
+import com.nvidia.spark.rapids.shims.ShimExpression
 import com.nvidia.spark.udf.CatalystExpressionBuilder.simplify
 import javassist.bytecode.{CodeIterator, Opcode}