From 2ab523c1029996ac72107b82ef3ce506726965e0 Mon Sep 17 00:00:00 2001 From: Alex Barreto Date: Wed, 1 Dec 2021 19:07:58 -0500 Subject: [PATCH] [SPARK-37496][SQL] Migrate ReplaceTableAsSelectStatement to v2 command ### What changes were proposed in this pull request? This PR migrates `ReplaceTableAsSelectStatement` to the v2 command ### Why are the changes needed? Migrate to the standard V2 framework ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? existing tests Closes #34754 from huaxingao/replace_table. Authored-by: Huaxin Gao Signed-off-by: Huaxin Gao --- .../spark/resource/ResourceProfile.scala | 4 +- .../resource/ResourceProfileManager.scala | 8 +- .../spark/shuffle/FetchFailedException.scala | 2 +- .../spark/storage/FallbackStorage.scala | 16 +- .../org/apache/spark/util/SizeEstimator.scala | 8 +- dev/create-release/release-build.sh | 2 +- dev/deps/spark-deps-hadoop-2.7-hive-2.3 | 2 +- dev/deps/spark-deps-hadoop-3.2-hive-2.3 | 2 +- dev/github_jira_sync.py | 6 +- dev/lint-python | 66 +--- dev/requirements.txt | 3 +- dev/run-tests-jenkins.py | 3 - dev/run-tests.py | 23 +- dev/sparktestsupport/modules.py | 2 + dev/test-dependencies.sh | 8 +- dev/tox.ini | 38 +- docs/running-on-kubernetes.md | 175 ++++++++- docs/sql-data-sources-csv.md | 12 +- docs/sql-migration-guide.md | 2 + docs/sql-ref-ansi-compliance.md | 2 +- docs/sql-ref-syntax-ddl-alter-database.md | 54 ++- .../sql/kafka010/KafkaOffsetReaderAdmin.scala | 4 +- .../source/image/ImageFileFormatSuite.scala | 17 +- pom.xml | 7 +- python/docs/source/development/debugging.rst | 56 ++- .../docs/source/getting_started/install.rst | 4 +- .../migration_guide/pyspark_3.2_to_3.3.rst | 2 + .../source/reference/pyspark.pandas/frame.rst | 1 + .../reference/pyspark.pandas/indexing.rst | 7 + .../user_guide/pandas_on_spark/options.rst | 5 +- .../source/user_guide/sql/arrow_pandas.rst | 2 +- python/pyspark/__init__.py | 4 +- python/pyspark/__init__.pyi | 2 +- python/pyspark/accumulators.py | 4 +- python/pyspark/cloudpickle/__init__.py | 2 +- python/pyspark/cloudpickle/cloudpickle.py | 195 ++++++++-- .../pyspark/cloudpickle/cloudpickle_fast.py | 69 +++- python/pyspark/context.py | 16 +- python/pyspark/context.pyi | 1 + python/pyspark/ml/_typing.pyi | 5 +- python/pyspark/ml/common.py | 44 ++- python/pyspark/ml/tests/test_linalg.py | 4 +- python/pyspark/ml/tuning.py | 2 +- python/pyspark/mllib/_typing.pyi | 5 +- python/pyspark/mllib/common.py | 56 +-- python/pyspark/mllib/tests/test_algorithms.py | 6 +- python/pyspark/mllib/tests/test_linalg.py | 6 +- python/pyspark/pandas/__init__.py | 4 +- python/pyspark/pandas/config.py | 3 +- python/pyspark/pandas/data_type_ops/base.py | 4 + .../pyspark/pandas/data_type_ops/num_ops.py | 45 ++- .../pandas/data_type_ops/timedelta_ops.py | 28 ++ python/pyspark/pandas/frame.py | 189 ++++++++++ python/pyspark/pandas/indexes/__init__.py | 1 + python/pyspark/pandas/indexes/base.py | 14 +- python/pyspark/pandas/indexes/timedelta.py | 100 +++++ python/pyspark/pandas/internal.py | 11 + python/pyspark/pandas/missing/frame.py | 1 - python/pyspark/pandas/missing/indexes.py | 18 + python/pyspark/pandas/series.py | 14 +- python/pyspark/pandas/sql_formatter.py | 273 ++++++++++++++ python/pyspark/pandas/sql_processor.py | 32 +- .../tests/data_type_ops/test_num_ops.py | 87 +++-- .../pyspark/pandas/tests/indexes/test_base.py | 41 +- python/pyspark/pandas/tests/test_dataframe.py | 65 +++- python/pyspark/pandas/tests/test_series.py | 17 +- python/pyspark/pandas/tests/test_sql.py | 49 ++- 
python/pyspark/pandas/typedef/typehints.py | 6 + .../pyspark/pandas/usage_logging/__init__.py | 6 +- python/pyspark/profiler.py | 45 ++- python/pyspark/profiler.pyi | 17 +- python/pyspark/rdd.py | 16 +- python/pyspark/serializers.py | 168 +++++---- python/pyspark/shuffle.py | 8 +- python/pyspark/sql/dataframe.py | 12 +- python/pyspark/sql/pandas/utils.py | 2 +- python/pyspark/sql/session.py | 7 +- python/pyspark/sql/streaming.py | 4 +- .../sql/tests/test_pandas_udf_grouped_agg.py | 2 +- .../sql/tests/test_pandas_udf_scalar.py | 4 +- python/pyspark/sql/tests/test_session.py | 11 +- python/pyspark/sql/tests/test_udf_profiler.py | 109 ++++++ python/pyspark/sql/udf.py | 32 +- python/pyspark/tests/test_rdd.py | 4 +- python/pyspark/tests/test_serializers.py | 16 +- python/pyspark/tests/test_shuffle.py | 4 +- python/pyspark/tests/test_worker.py | 4 +- python/pyspark/worker.py | 6 +- python/setup.py | 2 +- .../ExecutorPodsPollingSnapshotSource.scala | 13 +- .../k8s/ExecutorPodsWatchSnapshotSource.scala | 14 +- .../main/dockerfiles/spark/Dockerfile.java17 | 1 + .../mesos/src/test/resources/log4j.properties | 27 ++ .../org/apache/spark/deploy/yarn/Client.scala | 13 +- .../spark/sql/catalyst/parser/SqlBase.g4 | 41 +- .../catalog/DelegatingCatalogExtension.java | 10 + .../sql/connector/catalog/TableCatalog.java | 5 +- .../sql/connector/read/HasPartitionKey.java | 52 +++ .../sql/catalyst/analysis/Analyzer.scala | 18 +- .../catalyst/analysis/AnsiTypeCoercion.scala | 7 +- .../catalyst/analysis/CTESubstitution.scala | 2 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 111 +++++- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../analysis/RelationTimeTravel.scala | 34 ++ .../catalyst/analysis/ResolveCatalogs.scala | 23 -- .../sql/catalyst/analysis/ResolveHints.scala | 4 +- .../catalyst/analysis/TimeTravelSpec.scala | 64 ++++ .../sql/catalyst/analysis/TypeCoercion.scala | 4 +- .../sql/catalyst/analysis/unresolved.scala | 12 +- .../catalog/ExternalCatalogUtils.scala | 27 ++ .../sql/catalyst/catalog/SessionCatalog.scala | 11 +- .../sql/catalyst/csv/CSVInferSchema.scala | 24 ++ .../spark/sql/catalyst/csv/CSVOptions.scala | 4 + .../sql/catalyst/csv/UnivocityGenerator.scala | 2 +- .../sql/catalyst/csv/UnivocityParser.scala | 4 +- .../sql/catalyst/expressions/Expression.scala | 11 + .../sql/catalyst/expressions/PythonUDF.scala | 20 +- .../expressions/namedExpressions.scala | 10 - .../expressions/stringExpressions.scala | 17 + .../sql/catalyst/parser/AstBuilder.scala | 58 +-- .../sql/catalyst/parser/ParseDriver.scala | 60 +++ .../catalyst/plans/logical/LogicalPlan.scala | 3 +- .../catalyst/plans/logical/statements.scala | 42 --- .../statsEstimation/UnionEstimation.scala | 8 +- .../catalyst/plans/logical/v2Commands.scala | 59 ++- .../sql/catalyst/rules/RuleIdCollection.scala | 1 + .../spark/sql/catalyst/trees/TreeNode.scala | 4 + .../sql/catalyst/util/DateTimeUtils.scala | 32 +- .../catalyst/util/TimestampFormatter.scala | 36 +- .../spark/sql/catalyst/util/package.scala | 1 + .../sql/connector/catalog/CatalogV2Util.scala | 16 +- .../connector/expressions/expressions.scala | 26 -- .../sql/errors/QueryCompilationErrors.scala | 15 +- .../sql/errors/QueryExecutionErrors.scala | 10 +- .../spark/sql/errors/QueryParsingErrors.scala | 18 +- .../apache/spark/sql/internal/SQLConf.scala | 9 + .../analysis/UnsupportedOperationsSuite.scala | 2 +- .../expressions/StringExpressionsSuite.scala | 9 + .../sql/catalyst/parser/DDLParserSuite.scala | 181 ++------- .../sql/catalyst/parser/PlanParserSuite.scala | 98 ++++- 
.../UnionEstimationSuite.scala | 24 +- .../catalyst/util/DateTimeUtilsSuite.scala | 12 + .../sql/connector/catalog/InMemoryTable.scala | 22 +- .../DataSourceReadBenchmark-jdk11-results.txt | 356 +++++++++--------- .../DataSourceReadBenchmark-results.txt | 356 +++++++++--------- .../parquet/VectorizedPlainValuesReader.java | 52 ++- .../vectorized/OffHeapColumnVector.java | 12 + .../vectorized/OnHeapColumnVector.java | 12 + .../vectorized/WritableColumnVector.java | 36 ++ .../apache/spark/sql/DataFrameWriter.scala | 31 +- .../apache/spark/sql/DataFrameWriterV2.scala | 34 +- .../org/apache/spark/sql/SparkSession.scala | 8 +- .../analysis/ReplaceCharWithVarchar.scala | 4 +- .../analysis/ResolveSessionCatalog.scala | 57 ++- .../spark/sql/execution/SQLExecution.scala | 8 +- .../spark/sql/execution/SparkSqlParser.scala | 49 +-- .../OptimizeSkewInRebalancePartitions.scala | 8 +- .../adaptive/ShufflePartitionsUtil.scala | 18 +- .../execution/datasources/DataSource.scala | 2 +- .../datasources/PartitioningUtils.scala | 2 +- .../sql/execution/datasources/rules.scala | 9 +- .../datasources/v2/CreateTableExec.scala | 7 +- .../datasources/v2/DataSourceV2Strategy.scala | 41 +- .../datasources/v2/DataSourceV2Utils.scala | 11 +- .../datasources/v2/DescribeColumnExec.scala | 4 +- .../execution/datasources/v2/FileScan.scala | 4 +- .../datasources/v2/V2SessionCatalog.scala | 22 ++ .../v2/WriteToDataSourceV2Exec.scala | 12 +- .../streaming/sources/ForeachBatchSink.scala | 10 +- .../state/RocksDBStateStoreProvider.scala | 2 +- .../sql/execution/ui/ExecutionPage.scala | 27 +- .../execution/ui/SQLAppStatusListener.scala | 6 +- .../sql/execution/ui/SQLAppStatusStore.scala | 1 + .../spark/sql/execution/ui/SQLListener.scala | 3 +- .../sql/streaming/DataStreamWriter.scala | 19 +- .../sql-functions/sql-expression-schema.md | 3 +- .../resources/sql-tests/inputs/comments.sql | 29 ++ .../sql-tests/inputs/string-functions.sql | 10 +- .../sql-tests/results/ansi/date.sql.out | 12 +- .../sql-tests/results/ansi/interval.sql.out | 6 +- .../results/ansi/string-functions.sql.out | 50 ++- .../sql-tests/results/comments.sql.out | 64 +++- .../results/postgreSQL/union.sql.out | 1 + .../results/string-functions.sql.out | 50 ++- .../spark/sql/CharVarcharTestSuite.scala | 11 - .../apache/spark/sql/CsvFunctionsSuite.scala | 11 + .../spark/sql/IntegratedUDFTestUtils.scala | 35 +- .../spark/sql/TPCDSQueryTestSuite.scala | 52 ++- .../sql/connector/DataSourceV2SQLSuite.scala | 124 ++++-- .../SupportsCatalogOptionsSuite.scala | 11 +- .../V2CommandsCaseSensitivitySuite.scala | 18 +- .../sql/execution/SQLExecutionSuite.scala | 45 ++- .../sql/execution/SQLJsonProtocolSuite.scala | 70 ++-- .../spark/sql/execution/SQLViewSuite.scala | 7 - .../sql/execution/SQLViewTestSuite.scala | 15 + .../ShufflePartitionsUtilSuite.scala | 36 +- .../sql/execution/SparkSqlParserSuite.scala | 9 + .../adaptive/AdaptiveQueryExecSuite.scala | 31 ++ .../benchmark/DataSourceReadBenchmark.scala | 21 +- ...AlterNamespaceSetLocationParserSuite.scala | 41 ++ .../AlterNamespaceSetLocationSuiteBase.scala | 83 ++++ .../command/CharVarcharDDLTestBase.scala | 10 + .../sql/execution/command/DDLSuite.scala | 30 +- .../command/PlanResolutionSuite.scala | 110 ++---- .../command/ShowNamespacesParserSuite.scala | 54 ++- .../command/ShowNamespacesSuiteBase.scala | 38 +- .../v1/AlterNamespaceSetLocationSuite.scala | 49 +++ .../command/v1/ShowNamespacesSuite.scala | 18 - .../v2/AlterNamespaceSetLocationSuite.scala | 34 ++ .../command/v2/DescribeNamespaceSuite.scala | 2 +- 
.../command/v2/ShowNamespacesSuite.scala | 16 - .../execution/datasources/csv/CSVSuite.scala | 217 ++++++++++- .../datasources/orc/OrcQuerySuite.scala | 6 +- .../datasources/orc/OrcSourceSuite.scala | 51 ++- .../execution/datasources/orc/OrcTest.scala | 2 +- .../parquet/ParquetEncodingSuite.scala | 7 +- .../datasources/parquet/ParquetIOSuite.scala | 2 +- .../ParquetPartitionDiscoverySuite.scala | 11 +- .../history/SQLEventFilterBuilderSuite.scala | 2 +- .../SQLLiveEntitiesEventFilterSuite.scala | 4 +- .../sources/ForeachBatchSinkSuite.scala | 76 ++++ .../execution/ui/AllExecutionsPageSuite.scala | 3 +- .../ui/MetricsAggregationBenchmark.scala | 3 +- .../ui/SQLAppStatusListenerSuite.scala | 32 +- .../vectorized/ColumnarBatchSuite.scala | 94 +++++ .../sql/streaming/StreamingQuerySuite.scala | 2 +- .../status/api/v1/sql/SqlResourceSuite.scala | 1 + .../spark/sql/hive/HiveExternalCatalog.scala | 35 +- .../sql/hive/client/HiveClientImpl.scala | 5 +- .../AlterNamespaceSetLocationSuite.scala | 41 ++ .../command/ShowNamespacesSuite.scala | 19 +- .../apache/spark/sql/hive/test/TestHive.scala | 2 +- 232 files changed, 5000 insertions(+), 1821 deletions(-) create mode 100644 python/pyspark/pandas/data_type_ops/timedelta_ops.py create mode 100644 python/pyspark/pandas/indexes/timedelta.py create mode 100644 python/pyspark/pandas/sql_formatter.py create mode 100644 python/pyspark/sql/tests/test_udf_profiler.py create mode 100644 resource-managers/mesos/src/test/resources/log4j.properties create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/HasPartitionKey.java create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationTimeTravel.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TimeTravelSpec.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetLocationParserSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetLocationSuiteBase.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterNamespaceSetLocationSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterNamespaceSetLocationSuite.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/AlterNamespaceSetLocationSuite.scala diff --git a/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala b/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala index 1ebd8bd89f..3398701950 100644 --- a/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala +++ b/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala @@ -463,14 +463,14 @@ object ResourceProfile extends Logging { case ResourceProfile.CORES => cores = execReq.amount.toInt case rName => - val nameToUse = resourceMappings.get(rName).getOrElse(rName) + val nameToUse = resourceMappings.getOrElse(rName, rName) customResources(nameToUse) = execReq } } customResources.toMap } else { defaultResources.customResources.map { case (rName, execReq) => - val nameToUse = resourceMappings.get(rName).getOrElse(rName) + val nameToUse = resourceMappings.getOrElse(rName, rName) (nameToUse, execReq) } } diff --git a/core/src/main/scala/org/apache/spark/resource/ResourceProfileManager.scala b/core/src/main/scala/org/apache/spark/resource/ResourceProfileManager.scala index d538f0bcc4..2858443c7c 100644 --- 
a/core/src/main/scala/org/apache/spark/resource/ResourceProfileManager.scala +++ b/core/src/main/scala/org/apache/spark/resource/ResourceProfileManager.scala @@ -57,8 +57,10 @@ private[spark] class ResourceProfileManager(sparkConf: SparkConf, private val notRunningUnitTests = !isTesting private val testExceptionThrown = sparkConf.get(RESOURCE_PROFILE_MANAGER_TESTING) - // If we use anything except the default profile, its only supported on YARN right now. - // Throw an exception if not supported. + /** + * If we use anything except the default profile, it's only supported on YARN and Kubernetes + * with dynamic allocation enabled. Throw an exception if not supported. + */ private[spark] def isSupported(rp: ResourceProfile): Boolean = { val isNotDefaultProfile = rp.id != ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID val notYarnOrK8sAndNotDefaultProfile = isNotDefaultProfile && !(isYarn || isK8s) @@ -103,7 +105,7 @@ private[spark] class ResourceProfileManager(sparkConf: SparkConf, def resourceProfileFromId(rpId: Int): ResourceProfile = { readLock.lock() try { - resourceProfileIdToResourceProfile.get(rpId).getOrElse( + resourceProfileIdToResourceProfile.getOrElse(rpId, throw new SparkException(s"ResourceProfileId $rpId not found!") ) } finally { diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala index 208c676a1c..626a237732 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala @@ -56,7 +56,7 @@ private[spark] class FetchFailedException( // which intercepts this exception (possibly wrapping it), the Executor can still tell there was // a fetch failure, and send the correct error msg back to the driver. We wrap with an Option // because the TaskContext is not defined in some test cases. 
- Option(TaskContext.get()).map(_.setFetchFailed(this)) + Option(TaskContext.get()).foreach(_.setFetchFailed(this)) def toTaskFailedReason: TaskFailedReason = FetchFailed( bmAddress, shuffleId, mapId, mapIndex, reduceId, Utils.exceptionString(this)) diff --git a/core/src/main/scala/org/apache/spark/storage/FallbackStorage.scala b/core/src/main/scala/org/apache/spark/storage/FallbackStorage.scala index 7613713322..d137099e73 100644 --- a/core/src/main/scala/org/apache/spark/storage/FallbackStorage.scala +++ b/core/src/main/scala/org/apache/spark/storage/FallbackStorage.scala @@ -31,6 +31,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{STORAGE_DECOMMISSION_FALLBACK_STORAGE_CLEANUP, STORAGE_DECOMMISSION_FALLBACK_STORAGE_PATH} import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.util.JavaUtils import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcTimeout} import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleBlockInfo} import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID @@ -60,15 +61,17 @@ private[storage] class FallbackStorage(conf: SparkConf) extends Logging { val indexFile = r.getIndexFile(shuffleId, mapId) if (indexFile.exists()) { + val hash = JavaUtils.nonNegativeHash(indexFile.getName) fallbackFileSystem.copyFromLocalFile( new Path(indexFile.getAbsolutePath), - new Path(fallbackPath, s"$appId/$shuffleId/${indexFile.getName}")) + new Path(fallbackPath, s"$appId/$shuffleId/$hash/${indexFile.getName}")) val dataFile = r.getDataFile(shuffleId, mapId) if (dataFile.exists()) { + val hash = JavaUtils.nonNegativeHash(dataFile.getName) fallbackFileSystem.copyFromLocalFile( new Path(dataFile.getAbsolutePath), - new Path(fallbackPath, s"$appId/$shuffleId/${dataFile.getName}")) + new Path(fallbackPath, s"$appId/$shuffleId/$hash/${dataFile.getName}")) } // Report block statuses @@ -86,7 +89,8 @@ private[storage] class FallbackStorage(conf: SparkConf) extends Logging { } def exists(shuffleId: Int, filename: String): Boolean = { - fallbackFileSystem.exists(new Path(fallbackPath, s"$appId/$shuffleId/$filename")) + val hash = JavaUtils.nonNegativeHash(filename) + fallbackFileSystem.exists(new Path(fallbackPath, s"$appId/$shuffleId/$hash/$filename")) } } @@ -168,7 +172,8 @@ private[spark] object FallbackStorage extends Logging { } val name = ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID).name - val indexFile = new Path(fallbackPath, s"$appId/$shuffleId/$name") + val hash = JavaUtils.nonNegativeHash(name) + val indexFile = new Path(fallbackPath, s"$appId/$shuffleId/$hash/$name") val start = startReduceId * 8L val end = endReduceId * 8L Utils.tryWithResource(fallbackFileSystem.open(indexFile)) { inputStream => @@ -178,7 +183,8 @@ private[spark] object FallbackStorage extends Logging { index.skip(end - (start + 8L)) val nextOffset = index.readLong() val name = ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID).name - val dataFile = new Path(fallbackPath, s"$appId/$shuffleId/$name") + val hash = JavaUtils.nonNegativeHash(name) + val dataFile = new Path(fallbackPath, s"$appId/$shuffleId/$hash/$name") val f = fallbackFileSystem.open(dataFile) val size = nextOffset - offset logDebug(s"To byte array $size") diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 85e1119569..9ec93077d0 100644 --- 
a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -33,11 +33,9 @@ import org.apache.spark.util.collection.OpenHashSet /** * A trait that allows a class to give [[SizeEstimator]] more accurate size estimation. - * When a class extends it, [[SizeEstimator]] will query the `estimatedSize` first. - * If `estimatedSize` does not return `None`, [[SizeEstimator]] will use the returned size - * as the size of the object. Otherwise, [[SizeEstimator]] will do the estimation work. - * The difference between a [[KnownSizeEstimation]] and - * [[org.apache.spark.util.collection.SizeTracker]] is that, a + * When a class extends it, [[SizeEstimator]] will query the `estimatedSize`, and use + * the returned size as the size of the object. The difference between a [[KnownSizeEstimation]] + * and [[org.apache.spark.util.collection.SizeTracker]] is that, a * [[org.apache.spark.util.collection.SizeTracker]] still uses [[SizeEstimator]] to * estimate the size. However, a [[KnownSizeEstimation]] can provide a better estimation without * using [[SizeEstimator]]. diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 96b8d4e3d9..44baeddb6f 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -192,7 +192,7 @@ SCALA_2_12_PROFILES="-Pscala-2.12" HIVE_PROFILES="-Phive -Phive-thriftserver" # Profiles for publishing snapshots and release to Maven Central # We use Apache Hive 2.3 for publishing -PUBLISH_PROFILES="$BASE_PROFILES $HIVE_PROFILES -Phive-2.3 -Pspark-ganglia-lgpl -Pkinesis-asl -Phadoop-cloud" +PUBLISH_PROFILES="$BASE_PROFILES $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl -Phadoop-cloud" # Profiles for building binary releases BASE_RELEASE_PROFILES="$BASE_PROFILES -Psparkr" diff --git a/dev/deps/spark-deps-hadoop-2.7-hive-2.3 b/dev/deps/spark-deps-hadoop-2.7-hive-2.3 index 063c4602d0..69e7bbd7ff 100644 --- a/dev/deps/spark-deps-hadoop-2.7-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-2.7-hive-2.3 @@ -35,7 +35,7 @@ cats-kernel_2.12/2.1.1//cats-kernel_2.12-2.1.1.jar chill-java/0.10.0//chill-java-0.10.0.jar chill_2.12/0.10.0//chill_2.12-0.10.0.jar commons-beanutils/1.9.4//commons-beanutils-1.9.4.jar -commons-cli/1.2//commons-cli-1.2.jar +commons-cli/1.5.0//commons-cli-1.5.0.jar commons-codec/1.15//commons-codec-1.15.jar commons-collections/3.2.2//commons-collections-3.2.2.jar commons-compiler/3.0.16//commons-compiler-3.0.16.jar diff --git a/dev/deps/spark-deps-hadoop-3.2-hive-2.3 b/dev/deps/spark-deps-hadoop-3.2-hive-2.3 index 9a3d35d942..93596c2164 100644 --- a/dev/deps/spark-deps-hadoop-3.2-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3.2-hive-2.3 @@ -37,7 +37,7 @@ breeze_2.12/1.2//breeze_2.12-1.2.jar cats-kernel_2.12/2.1.1//cats-kernel_2.12-2.1.1.jar chill-java/0.10.0//chill-java-0.10.0.jar chill_2.12/0.10.0//chill_2.12-0.10.0.jar -commons-cli/1.2//commons-cli-1.2.jar +commons-cli/1.5.0//commons-cli-1.5.0.jar commons-codec/1.15//commons-codec-1.15.jar commons-collections/3.2.2//commons-collections-3.2.2.jar commons-compiler/3.0.16//commons-compiler-3.0.16.jar diff --git a/dev/github_jira_sync.py b/dev/github_jira_sync.py index 27451bba90..bb30ef99e8 100755 --- a/dev/github_jira_sync.py +++ b/dev/github_jira_sync.py @@ -77,9 +77,9 @@ def get_jira_prs(): page_json = get_json(page) for pull in page_json: - jiras = re.findall(JIRA_PROJECT_NAME + "-[0-9]{4,5}", pull['title']) - for jira in jiras: - result = result + [(jira, pull)] + jira_issues = 
re.findall(JIRA_PROJECT_NAME + "-[0-9]{4,5}", pull['title']) + for jira_issue in jira_issues: + result = result + [(jira_issue, pull)] # Check if there is another page link_headers = list(filter(lambda k: k.startswith("Link"), page.headers)) diff --git a/dev/lint-python b/dev/lint-python index 9b7a139176..e60ba7be07 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -20,8 +20,6 @@ FLAKE8_BUILD="flake8" MINIMUM_FLAKE8="3.9.0" MINIMUM_MYPY="0.910" MYPY_BUILD="mypy" -PYCODESTYLE_BUILD="pycodestyle" -MINIMUM_PYCODESTYLE="2.7.0" PYTEST_BUILD="pytest" PYTHON_EXECUTABLE="${PYTHON_EXECUTABLE:-python3}" @@ -64,66 +62,6 @@ function compile_python_test { fi } -function pycodestyle_test { - local PYCODESTYLE_STATUS= - local PYCODESTYLE_REPORT= - local RUN_LOCAL_PYCODESTYLE= - local PYCODESTYLE_VERSION= - local EXPECTED_PYCODESTYLE= - local PYCODESTYLE_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pycodestyle-$MINIMUM_PYCODESTYLE.py" - local PYCODESTYLE_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/PyCQA/pycodestyle/$MINIMUM_PYCODESTYLE/pycodestyle.py" - - if [[ ! "$1" ]]; then - echo "No python files found! Something is very wrong -- exiting." - exit 1; - fi - - # check for locally installed pycodestyle & version - RUN_LOCAL_PYCODESTYLE="False" - if hash "$PYCODESTYLE_BUILD" 2> /dev/null; then - PYCODESTYLE_VERSION="$($PYCODESTYLE_BUILD --version)" - EXPECTED_PYCODESTYLE="$(satisfies_min_version $PYCODESTYLE_VERSION $MINIMUM_PYCODESTYLE)" - if [ "$EXPECTED_PYCODESTYLE" == "True" ]; then - RUN_LOCAL_PYCODESTYLE="True" - fi - fi - - # download the right version or run locally - if [ $RUN_LOCAL_PYCODESTYLE == "False" ]; then - # Get pycodestyle at runtime so that we don't rely on it being installed on the build server. - # See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 - # Updated to the latest official version of pep8. pep8 is formally renamed to pycodestyle. - echo "downloading pycodestyle from $PYCODESTYLE_SCRIPT_REMOTE_PATH..." - if [ ! -e "$PYCODESTYLE_SCRIPT_PATH" ]; then - curl --silent -o "$PYCODESTYLE_SCRIPT_PATH" "$PYCODESTYLE_SCRIPT_REMOTE_PATH" - local curl_status="$?" - - if [ "$curl_status" -ne 0 ]; then - echo "Failed to download pycodestyle.py from $PYCODESTYLE_SCRIPT_REMOTE_PATH" - exit "$curl_status" - fi - fi - - echo "starting pycodestyle test..." - PYCODESTYLE_REPORT=$( ("$PYTHON_EXECUTABLE" "$PYCODESTYLE_SCRIPT_PATH" --config=dev/tox.ini $1) 2>&1) - PYCODESTYLE_STATUS=$? - else - # we have the right version installed, so run locally - echo "starting pycodestyle test..." - PYCODESTYLE_REPORT=$( ($PYCODESTYLE_BUILD --config=dev/tox.ini $1) 2>&1) - PYCODESTYLE_STATUS=$? - fi - - if [ $PYCODESTYLE_STATUS -ne 0 ]; then - echo "pycodestyle checks failed:" - echo "$PYCODESTYLE_REPORT" - exit "$PYCODESTYLE_STATUS" - else - echo "pycodestyle checks passed." - echo - fi -} - function mypy_annotation_test { local MYPY_REPORT= @@ -292,12 +230,10 @@ SPARK_ROOT_DIR="$(dirname "${SCRIPT_DIR}")" pushd "$SPARK_ROOT_DIR" &> /dev/null -# skipping local ruby bundle directory from the search -PYTHON_SOURCE="$(find . 
-path ./docs/.local_ruby_bundle -prune -false -o -name "*.py")" +PYTHON_SOURCE="$(git ls-files '*.py')" compile_python_test "$PYTHON_SOURCE" black_test -pycodestyle_test "$PYTHON_SOURCE" flake8_test mypy_test diff --git a/dev/requirements.txt b/dev/requirements.txt index 273294a96a..3f8b2c48f6 100644 --- a/dev/requirements.txt +++ b/dev/requirements.txt @@ -19,7 +19,8 @@ coverage # Linter mypy -flake8 +git+https://github.com/typeddjango/pytest-mypy-plugins.git@b0020061f48e85743ee3335bd62a3a608d17c6bd +flake8==3.9.0 # Documentation (SQL) mkdocs diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index f24e702a8d..67d0972acc 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -174,9 +174,6 @@ def main(): os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.7" if "test-hadoop3.2" in ghprb_pull_title: os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop3.2" - # Switch the Hive profile based on the PR title: - if "test-hive2.3" in ghprb_pull_title: - os.environ["AMPLAB_JENKINS_BUILD_HIVE_PROFILE"] = "hive2.3" # Switch the Scala profile based on the PR title: if "test-scala2.13" in ghprb_pull_title: os.environ["AMPLAB_JENKINS_BUILD_SCALA_PROFILE"] = "scala2.13" diff --git a/dev/run-tests.py b/dev/run-tests.py index 55c65ed2d6..25df8f62ac 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -345,24 +345,6 @@ def get_hadoop_profiles(hadoop_version): sys.exit(int(os.environ.get("CURRENT_BLOCK", 255))) -def get_hive_profiles(hive_version): - """ - For the given Hive version tag, return a list of Maven/SBT profile flags for - building and testing against that Hive version. - """ - - sbt_maven_hive_profiles = { - "hive2.3": ["-Phive-2.3"], - } - - if hive_version in sbt_maven_hive_profiles: - return sbt_maven_hive_profiles[hive_version] - else: - print("[error] Could not find", hive_version, "in the list. 
Valid options", - " are", sbt_maven_hive_profiles.keys()) - sys.exit(int(os.environ.get("CURRENT_BLOCK", 255))) - - def build_spark_maven(extra_profiles): # Enable all of the profiles for the build: build_profiles = extra_profiles + modules.root.build_profile_flags @@ -616,7 +598,6 @@ def main(): build_tool = os.environ.get("AMPLAB_JENKINS_BUILD_TOOL", "sbt") scala_version = os.environ.get("AMPLAB_JENKINS_BUILD_SCALA_PROFILE") hadoop_version = os.environ.get("AMPLAB_JENKINS_BUILD_PROFILE", "hadoop3.2") - hive_version = os.environ.get("AMPLAB_JENKINS_BUILD_HIVE_PROFILE", "hive2.3") test_env = "amplab_jenkins" # add path for Python3 in Jenkins if we're calling from a Jenkins machine # TODO(sknapp): after all builds are ported to the ubuntu workers, change this to be: @@ -627,14 +608,12 @@ def main(): build_tool = "sbt" scala_version = os.environ.get("SCALA_PROFILE") hadoop_version = os.environ.get("HADOOP_PROFILE", "hadoop3.2") - hive_version = os.environ.get("HIVE_PROFILE", "hive2.3") if "GITHUB_ACTIONS" in os.environ: test_env = "github_actions" else: test_env = "local" - extra_profiles = get_hadoop_profiles(hadoop_version) + get_hive_profiles(hive_version) + \ - get_scala_profiles(scala_version) + extra_profiles = get_hadoop_profiles(hadoop_version) + get_scala_profiles(scala_version) print("[info] Using build tool", build_tool, "with profiles", *(extra_profiles + ["under environment", test_env])) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index d13be2e2fe..5dd3ab6169 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -464,6 +464,7 @@ def __hash__(self): "pyspark.sql.tests.test_streaming", "pyspark.sql.tests.test_types", "pyspark.sql.tests.test_udf", + "pyspark.sql.tests.test_udf_profiler", "pyspark.sql.tests.test_utils", ] ) @@ -606,6 +607,7 @@ def __hash__(self): "pyspark.pandas.namespace", "pyspark.pandas.numpy_compat", "pyspark.pandas.sql_processor", + "pyspark.pandas.sql_formatter", "pyspark.pandas.strings", "pyspark.pandas.utils", "pyspark.pandas.window", diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index 156a0d32ff..e23a0b682b 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -86,20 +86,18 @@ $MVN -q versions:set -DnewVersion=$TEMP_VERSION -DgenerateBackupPoms=false > /de for HADOOP_HIVE_PROFILE in "${HADOOP_HIVE_PROFILES[@]}"; do if [[ $HADOOP_HIVE_PROFILE == **hadoop-3.2-hive-2.3** ]]; then HADOOP_PROFILE=hadoop-3.2 - HIVE_PROFILE=hive-2.3 else HADOOP_PROFILE=hadoop-2.7 - HIVE_PROFILE=hive-2.3 fi echo "Performing Maven install for $HADOOP_HIVE_PROFILE" - $MVN $HADOOP_MODULE_PROFILES -P$HADOOP_PROFILE -P$HIVE_PROFILE jar:jar jar:test-jar install:install clean -q + $MVN $HADOOP_MODULE_PROFILES -P$HADOOP_PROFILE jar:jar jar:test-jar install:install clean -q echo "Performing Maven validate for $HADOOP_HIVE_PROFILE" - $MVN $HADOOP_MODULE_PROFILES -P$HADOOP_PROFILE -P$HIVE_PROFILE validate -q + $MVN $HADOOP_MODULE_PROFILES -P$HADOOP_PROFILE validate -q echo "Generating dependency manifest for $HADOOP_HIVE_PROFILE" mkdir -p dev/pr-deps - $MVN $HADOOP_MODULE_PROFILES -P$HADOOP_PROFILE -P$HIVE_PROFILE dependency:build-classpath -pl assembly -am \ + $MVN $HADOOP_MODULE_PROFILES -P$HADOOP_PROFILE dependency:build-classpath -pl assembly -am \ | grep "Dependencies classpath:" -A 1 \ | tail -n 1 | tr ":" "\n" | awk -F '/' '{ # For each dependency classpath, we fetch the last three parts split by "/": artifact id, version, and jar name. 
diff --git a/dev/tox.ini b/dev/tox.ini index bd69a3f9cb..48138f925d 100644 --- a/dev/tox.ini +++ b/dev/tox.ini @@ -13,14 +13,36 @@ # See the License for the specific language governing permissions and # limitations under the License. -[pycodestyle] -ignore=E203,E226,E241,E305,E402,E722,E731,E741,W503,W504,E501 -max-line-length=100 -exclude=*/target/*,python/pyspark/cloudpickle/*.py,shared.py,python/docs/source/conf.py,work/*/*.py,python/.eggs/*,dist/*,.git/*,dev/ansible-for-test-node/* [flake8] -select = E901,E999,F821,F822,F823,F401,F405,B006 -# Ignore F821 for plot documents in pandas API on Spark. -ignore = F821 -exclude = python/docs/build/html/*,*/target/*,python/pyspark/cloudpickle/*.py,shared.py*,python/docs/source/conf.py,work/*/*.py,python/.eggs/*,dist/*,.git/*,python/out,python/pyspark/sql/pandas/functions.pyi,python/pyspark/sql/column.pyi,python/pyspark/worker.pyi,python/pyspark/java_gateway.pyi,dev/ansible-for-test-node/roles/jenkins-worker/files/util_scripts/*.py +ignore = + E203, + E226, + E305, + E402, + E501, + E722, + E731, + E741, + F403, + F811, + F841, + W503, + W504, +per-file-ignores = + python/pyspark/ml/param/shared.py: F405, +exclude = + */target/*, + docs/.local_ruby_bundle/, + python/pyspark/cloudpickle/*.py, + python/docs/build/*, + python/docs/source/conf.py, + python/.eggs/*, + dist/*, + .git/*, + python/pyspark/sql/pandas/functions.pyi, + python/pyspark/sql/column.pyi, + python/pyspark/worker.pyi, + python/pyspark/java_gateway.pyi, + dev/ansible-for-test-node/*, max-line-length = 100 diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index d32861b5e5..58d6dda1d0 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -592,6 +592,7 @@ See the [configuration page](configuration.html) for information on Spark config IfNotPresent Container image pull policy used when pulling images within Kubernetes. + Valid values are Always, Never, and IfNotPresent. 2.3.0 @@ -779,6 +780,15 @@ See the [configuration page](configuration.html) for information on Spark config 2.3.0 + + spark.kubernetes.authenticate.executor.serviceAccountName + (value of spark.kubernetes.authenticate.driver.serviceAccountName) + + Service account that is used when running the executor pod. + If this parameter is not setup, the fallback logic will use the driver's service account. + + 3.1.0 + spark.kubernetes.authenticate.caCertFile (none) @@ -924,6 +934,14 @@ See the [configuration page](configuration.html) for information on Spark config 2.3.0 + + spark.kubernetes.executor.apiPollingInterval + 30s + + Interval between polls against the Kubernetes API server to inspect the state of executors. + + 2.4.0 + spark.kubernetes.driver.request.cores (none) @@ -1232,7 +1250,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.kubernetes.executor.checkAllContainers - false + false Specify whether executor pods should be check all containers (including sidecars) or only the executor container when determining the pod status. @@ -1240,7 +1258,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.kubernetes.submission.connectionTimeout - 10000 + 10000 Connection timeout in milliseconds for the kubernetes client to use for starting the driver. 
@@ -1248,7 +1266,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.kubernetes.submission.requestTimeout - 10000 + 10000 Request timeout in milliseconds for the kubernetes client to use for starting the driver. @@ -1256,7 +1274,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.kubernetes.driver.connectionTimeout - 10000 + 10000 Connection timeout in milliseconds for the kubernetes client in driver to use when requesting executors. @@ -1264,7 +1282,7 @@ See the [configuration page](configuration.html) for information on Spark config spark.kubernetes.driver.requestTimeout - 10000 + 10000 Request timeout in milliseconds for the kubernetes client in driver to use when requesting executors. @@ -1278,6 +1296,14 @@ See the [configuration page](configuration.html) for information on Spark config 3.0.0 + + spark.kubernetes.dynamicAllocation.deleteGracePeriod + 5s + + How long to wait for executors to shut down gracefully before a forceful kill. + + 3.0.0 + spark.kubernetes.file.upload.path (none) @@ -1322,6 +1348,145 @@ See the [configuration page](configuration.html) for information on Spark config 3.3.0 + + spark.kubernetes.configMap.maxSize + 1572864 + + Max size limit for a config map. + This is configurable as per limit on k8s server end. + + 3.1.0 + + + spark.kubernetes.executor.missingPodDetectDelta + 30s + + When a registered executor's POD is missing from the Kubernetes API server's polled + list of PODs then this delta time is taken as the accepted time difference between the + registration time and the time of the polling. After this time the POD is considered + missing from the cluster and the executor will be removed. + + 3.1.1 + + + spark.kubernetes.decommission.script + /opt/decom.sh + + The location of the script to use for graceful decommissioning. + + 3.2.0 + + + spark.kubernetes.driver.service.deleteOnTermination + true + + If true, driver service will be deleted on Spark application termination. If false, it will be cleaned up when the driver pod is deletion. + + 3.2.0 + + + spark.kubernetes.driver.ownPersistentVolumeClaim + false + + If true, driver pod becomes the owner of on-demand persistent volume claims instead of the executor pods + + 3.2.0 + + + spark.kubernetes.driver.reusePersistentVolumeClaim + false + + If true, driver pod tries to reuse driver-owned on-demand persistent volume claims + of the deleted executor pods if exists. This can be useful to reduce executor pod + creation delay by skipping persistent volume creations. Note that a pod in + `Terminating` pod status is not a deleted pod by definition and its resources + including persistent volume claims are not reusable yet. Spark will create new + persistent volume claims when there exists no reusable one. In other words, the total + number of persistent volume claims can be larger than the number of running executors + sometimes. This config requires spark.kubernetes.driver.ownPersistentVolumeClaim=true. + + 3.2.0 + + + spark.kubernetes.executor.disableConfigMap + false + + If true, disable ConfigMap creation for executors. + + 3.2.0 + + + spark.kubernetes.driver.pod.featureSteps + (none) + + Class names of an extra driver pod feature step implementing + `KubernetesFeatureConfigStep`. This is a developer API. Comma separated. + Runs after all of Spark internal feature steps. 
+ + 3.2.0 + + + spark.kubernetes.executor.pod.featureSteps + (none) + + Class names of an extra executor pod feature step implementing + `KubernetesFeatureConfigStep`. This is a developer API. Comma separated. + Runs after all of Spark internal feature steps. + + 3.2.0 + + + spark.kubernetes.allocation.maxPendingPods + Int.MaxValue + + Maximum number of pending PODs allowed during executor allocation for this + application. Those newly requested executors which are unknown by Kubernetes yet are + also counted into this limit as they will change into pending PODs by time. + This limit is independent from the resource profiles as it limits the sum of all + allocation for all the used resource profiles. + + 3.2.0 + + + spark.kubernetes.allocation.pods.allocator + direct + + Allocator to use for pods. Possible values are direct (the default) + and statefulset, or a full class name of a class implementing `AbstractPodsAllocator`. + Future version may add Job or replicaset. This is a developer API and may change + or be removed at anytime. + + 3.3.0 + + + spark.kubernetes.allocation.executor.timeout + 600s + + Time to wait before a newly created executor POD request, which does not reached + the POD pending state yet, considered timedout and will be deleted. + + 3.1.0 + + + spark.kubernetes.allocation.driver.readinessTimeout + 1s + + Time to wait for driver pod to get ready before creating executor pods. This wait + only happens on application start. If timeout happens, executor pods will still be + created. + + 3.1.3 + + + spark.kubernetes.executor.enablePollingWithResourceVersion + false + + If true, `resourceVersion` is set with `0` during invoking pod listing APIs + in order to allow API Server-side caching. This should be used carefully. + + 3.3.0 + + #### Pod template properties diff --git a/docs/sql-data-sources-csv.md b/docs/sql-data-sources-csv.md index 82cfa352a5..1dfe8568f9 100644 --- a/docs/sql-data-sources-csv.md +++ b/docs/sql-data-sources-csv.md @@ -9,9 +9,9 @@ license: | The ASF licenses this file to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 - + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -19,7 +19,7 @@ license: | limitations under the License. --- -Spark SQL provides `spark.read().csv("file_name")` to read a file or directory of files in CSV format into Spark DataFrame, and `dataframe.write().csv("path")` to write to a CSV file. Function `option()` can be used to customize the behavior of reading or writing, such as controlling behavior of the header, delimiter character, character set, and so on. +Spark SQL provides `spark.read().csv("file_name")` to read a file or directory of files in CSV format into Spark DataFrame, and `dataframe.write().csv("path")` to write to a CSV file. Function `option()` can be used to customize the behavior of reading or writing, such as controlling behavior of the header, delimiter character, character set, and so on.
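(Editor's aside on the CSV reader/writer paragraph above, not part of the patch: a minimal PySpark sketch of reading and writing CSV with `option()`. The file paths, delimiter, and header settings here are illustrative assumptions.)

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Read a CSV file, customizing behavior through option(): header row and delimiter.
df = (
    spark.read
    .option("header", True)
    .option("sep", ";")
    .csv("/tmp/people.csv")  # hypothetical input path
)

# Write the DataFrame back out as CSV, again via option().
df.write.option("header", True).csv("/tmp/people_out")  # hypothetical output path
```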
@@ -162,6 +162,12 @@ Data source options of CSV can be set via: Sets the string that indicates a timestamp format. Custom date formats follow the formats at Datetime Patterns. This applies to timestamp type. read/write + + timestampNTZFormat + yyyy-MM-dd'T'HH:mm:ss[.SSS] + Sets the string that indicates a timestamp without timezone format. Custom date formats follow the formats at Datetime Patterns. This applies to timestamp without timezone type, note that zone-offset and time-zone components are not supported when writing or reading this data type. + read/write + maxColumns 20480 diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index 12d9cd4fb1..c15f55dd98 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -133,6 +133,8 @@ license: | - In Spark 3.2, create/alter view will fail if the input query output columns contain auto-generated alias. This is necessary to make sure the query output column names are stable across different spark versions. To restore the behavior before Spark 3.2, set `spark.sql.legacy.allowAutoGeneratedAliasForView` to `true`. + - In Spark 3.2, date +/- interval with only day-time fields such as `date '2011-11-11' + interval 12 hours` returns timestamp. In Spark 3.1 and earlier, the same expression returns date. To restore the behavior before Spark 3.2, you can use `cast` to convert timestamp as date. + ## Upgrading from Spark SQL 3.0 to 3.1 - In Spark 3.1, statistical aggregation function includes `std`, `stddev`, `stddev_samp`, `variance`, `var_samp`, `skewness`, `kurtosis`, `covar_samp`, `corr` will return `NULL` instead of `Double.NaN` when `DivideByZero` occurs during expression evaluation, for example, when `stddev_samp` applied on a single element set. In Spark version 3.0 and earlier, it will return `Double.NaN` in such case. To restore the behavior before Spark 3.1, you can set `spark.sql.legacy.statisticalAggregate` to `true`. diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index 9ad7ad6211..3592f6be16 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -528,7 +528,7 @@ Below is a list of all the keywords in Spark SQL. |ROW|non-reserved|non-reserved|reserved| |ROWS|non-reserved|non-reserved|reserved| |SCHEMA|non-reserved|non-reserved|non-reserved| -|SCHEMAS|non-reserved|non-reserved|not a keyword| +|SCHEMAS|non-reserved|non-reserved|non-reserved| |SECOND|non-reserved|non-reserved|non-reserved| |SELECT|reserved|non-reserved|reserved| |SEMI|non-reserved|strict-non-reserved|non-reserved| diff --git a/docs/sql-ref-syntax-ddl-alter-database.md b/docs/sql-ref-syntax-ddl-alter-database.md index fbc454e25f..0ac0038236 100644 --- a/docs/sql-ref-syntax-ddl-alter-database.md +++ b/docs/sql-ref-syntax-ddl-alter-database.md @@ -21,25 +21,47 @@ license: | ### Description -You can alter metadata associated with a database by setting `DBPROPERTIES`. The specified property -values override any existing value with the same property name. Please note that the usage of -`SCHEMA` and `DATABASE` are interchangeable and one can be used in place of the other. An error message -is issued if the database is not found in the system. This command is mostly used to record the metadata -for a database and may be used for auditing purposes. +`ALTER DATABASE` statement changes the properties or location of a database. Please note that the usage of +`DATABASE`, `SCHEMA` and `NAMESPACE` are interchangeable and one can be used in place of the others. 
An error message +is issued if the database is not found in the system. -### Syntax +### ALTER PROPERTIES +`ALTER DATABASE SET DBPROPERTIES` statement changes the properties associated with a database. +The specified property values override any existing value with the same property name. +This command is mostly used to record the metadata for a database and may be used for auditing purposes. + +#### Syntax ```sql -ALTER { DATABASE | SCHEMA } database_name - SET DBPROPERTIES ( property_name = property_value [ , ... ] ) +ALTER { DATABASE | SCHEMA | NAMESPACE } database_name + SET { DBPROPERTIES | PROPERTIES } ( property_name = property_value [ , ... ] ) ``` -### Parameters +#### Parameters * **database_name** Specifies the name of the database to be altered. +### ALTER LOCATION +`ALTER DATABASE SET LOCATION` statement changes the default parent-directory where new tables will be added +for a database. Please note that it does not move the contents of the database's current directory to the newly +specified location or change the locations associated with any tables/partitions under the specified database +(available since Spark 3.0.0 with the Hive metastore version 3.0.0 and later). + +#### Syntax + +```sql +ALTER { DATABASE | SCHEMA | NAMESPACE } database_name + SET LOCATION 'new_location' +``` + +#### Parameters + +* **database_name** + + Specifies the name of the database to be altered. + ### Examples ```sql @@ -59,6 +81,20 @@ DESCRIBE DATABASE EXTENDED inventory; | Location| file:/temp/spark-warehouse/inventory.db| | Properties|((Edit-date,01/01/2001), (Edited-by,John))| +-------------------------+------------------------------------------+ + +-- Alters the database to set a new location. +ALTER DATABASE inventory SET LOCATION 'file:/temp/spark-warehouse/new_inventory.db'; + +-- Verify that a new location is set. 
+DESCRIBE DATABASE EXTENDED inventory; ++-------------------------+-------------------------------------------+ +|database_description_item| database_description_value| ++-------------------------+-------------------------------------------+ +| Database Name| inventory| +| Description| | +| Location|file:/temp/spark-warehouse/new_inventory.db| +| Properties| ((Edit-date,01/01/2001), (Edited-by,John))| ++-------------------------+-------------------------------------------+ ``` ### Related Statements diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderAdmin.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderAdmin.scala index 438f63c75b..c480fba121 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderAdmin.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReaderAdmin.scala @@ -387,11 +387,11 @@ private[kafka010] class KafkaOffsetReaderAdmin( // Calculate offset ranges val offsetRangesBase = untilPartitionOffsets.keySet.map { tp => - val fromOffset = fromPartitionOffsets.get(tp).getOrElse { + val fromOffset = fromPartitionOffsets.getOrElse(tp, // This should not happen since topicPartitions contains all partitions not in // fromPartitionOffsets throw new IllegalStateException(s"$tp doesn't have a from offset") - } + ) val untilOffset = untilPartitionOffsets(tp) KafkaOffsetRange(tp, fromOffset, untilOffset, None) }.toSeq diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/image/ImageFileFormatSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/image/ImageFileFormatSuite.scala index 7dca81ef40..0ec2747be6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/image/ImageFileFormatSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/image/ImageFileFormatSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.ml.source.image import java.net.URI import java.nio.file.Paths -import java.sql.Date import org.apache.spark.SparkFunSuite import org.apache.spark.ml.image.ImageSchema._ @@ -96,14 +95,14 @@ class ImageFileFormatSuite extends SparkFunSuite with MLlibTestSparkContext { .collect() assert(Set(result: _*) === Set( - Row("29.5.a_b_EGDP022204.jpg", "kittens", Date.valueOf("2018-01-01")), - Row("54893.jpg", "kittens", Date.valueOf("2018-02-01")), - Row("DP153539.jpg", "kittens", Date.valueOf("2018-02-01")), - Row("DP802813.jpg", "kittens", Date.valueOf("2018-02-01")), - Row("BGRA.png", "multichannel", Date.valueOf("2018-01-01")), - Row("BGRA_alpha_60.png", "multichannel", Date.valueOf("2018-01-01")), - Row("chr30.4.184.jpg", "multichannel", Date.valueOf("2018-02-01")), - Row("grayscale.jpg", "multichannel", Date.valueOf("2018-02-01")) + Row("29.5.a_b_EGDP022204.jpg", "kittens", "2018-01"), + Row("54893.jpg", "kittens", "2018-02"), + Row("DP153539.jpg", "kittens", "2018-02"), + Row("DP802813.jpg", "kittens", "2018-02"), + Row("BGRA.png", "multichannel", "2018-01"), + Row("BGRA_alpha_60.png", "multichannel", "2018-01"), + Row("chr30.4.184.jpg", "multichannel", "2018-02"), + Row("grayscale.jpg", "multichannel", "2018-02") )) } diff --git a/pom.xml b/pom.xml index 3899e2f3b9..87e3489f10 100644 --- a/pom.xml +++ b/pom.xml @@ -194,7 +194,7 @@ 2.50.0 1.8 1.1.0 - 1.2 + 1.5.0 1.60 1.6.0 - - hive-2.3 - - - yarn diff --git a/python/docs/source/development/debugging.rst b/python/docs/source/development/debugging.rst index 829919858f..1e6571da02 100644 --- 
a/python/docs/source/development/debugging.rst +++ b/python/docs/source/development/debugging.rst @@ -277,4 +277,58 @@ executor side, which can be enabled by setting ``spark.python.profile`` configur 12 0.000 0.000 0.001 0.000 context.py:506(f) ... -This feature is supported only with RDD APIs. +Python/Pandas UDF +~~~~~~~~~~~~~~~~~ + +To use this on Python/Pandas UDFs, PySpark provides remote `Python Profilers `_ for +Python/Pandas UDFs, which can be enabled by setting ``spark.python.profile`` configuration to ``true``. + +.. code-block:: bash + + pyspark --conf spark.python.profile=true + + +.. code-block:: python + + >>> from pyspark.sql.functions import pandas_udf + >>> df = spark.range(10) + >>> @pandas_udf("long") + ... def add1(x): + ... return x + 1 + ... + >>> added = df.select(add1("id")) + + >>> added.show() + +--------+ + |add1(id)| + +--------+ + ... + +--------+ + + >>> sc.show_profiles() + ============================================================ + Profile of UDF + ============================================================ + 2300 function calls (2270 primitive calls) in 0.006 seconds + + Ordered by: internal time, cumulative time + + ncalls tottime percall cumtime percall filename:lineno(function) + 10 0.001 0.000 0.005 0.001 series.py:5515(_arith_method) + 10 0.001 0.000 0.001 0.000 _ufunc_config.py:425(__init__) + 10 0.000 0.000 0.000 0.000 {built-in method _operator.add} + 10 0.000 0.000 0.002 0.000 series.py:315(__init__) + ... + +The UDF IDs can be seen in the query plan, for example, ``add1(...)#2L`` in ``ArrowEvalPython`` below. + +.. code-block:: python + + >>> added.explain() + == Physical Plan == + *(2) Project [pythonUDF0#11L AS add1(id)#3L] + +- ArrowEvalPython [add1(id#0L)#2L], [pythonUDF0#11L], 200 + +- *(1) Range (0, 10, step=1, splits=16) + + +This feature is not supported with registered UDFs. diff --git a/python/docs/source/getting_started/install.rst b/python/docs/source/getting_started/install.rst index 13c6f8f3a2..601b45d00a 100644 --- a/python/docs/source/getting_started/install.rst +++ b/python/docs/source/getting_started/install.rst @@ -154,11 +154,11 @@ Dependencies ============= ========================= ====================================== Package Minimum supported version Note ============= ========================= ====================================== -`pandas` 0.23.2 Optional for Spark SQL +`pandas` 1.0.5 Optional for Spark SQL `NumPy` 1.7 Required for MLlib DataFrame-based API `pyarrow` 1.0.0 Optional for Spark SQL `Py4J` 0.10.9.2 Required -`pandas` 0.23.2 Required for pandas API on Spark +`pandas` 1.0.5 Required for pandas API on Spark `pyarrow` 1.0.0 Required for pandas API on Spark `Numpy` 1.14 Required for pandas API on Spark ============= ========================= ====================================== diff --git a/python/docs/source/migration_guide/pyspark_3.2_to_3.3.rst b/python/docs/source/migration_guide/pyspark_3.2_to_3.3.rst index 060f24c8f4..f2701d4fb7 100644 --- a/python/docs/source/migration_guide/pyspark_3.2_to_3.3.rst +++ b/python/docs/source/migration_guide/pyspark_3.2_to_3.3.rst @@ -20,4 +20,6 @@ Upgrading from PySpark 3.2 to 3.3 ================================= +* In Spark 3.3, the ``pyspark.pandas.sql`` method follows [the standard Python string formatter](https://docs.python.org/3/library/string.html#format-string-syntax). To restore the previous behavior, set ``PYSPARK_PANDAS_SQL_LEGACY`` environment variable to ``1``. 
* In Spark 3.3, the ``drop`` method of pandas API on Spark DataFrame supports dropping rows by ``index``, and sets dropping by index instead of column by default. +* In Spark 3.3, PySpark upgrades Pandas version, the new minimum required version changes from 0.23.2 to 1.0.5. diff --git a/python/docs/source/reference/pyspark.pandas/frame.rst b/python/docs/source/reference/pyspark.pandas/frame.rst index bb84202f16..04bfe27c24 100644 --- a/python/docs/source/reference/pyspark.pandas/frame.rst +++ b/python/docs/source/reference/pyspark.pandas/frame.rst @@ -148,6 +148,7 @@ Computations / Descriptive Stats DataFrame.clip DataFrame.corr DataFrame.count + DataFrame.cov DataFrame.describe DataFrame.kurt DataFrame.kurtosis diff --git a/python/docs/source/reference/pyspark.pandas/indexing.rst b/python/docs/source/reference/pyspark.pandas/indexing.rst index 4168b6712b..87472308b8 100644 --- a/python/docs/source/reference/pyspark.pandas/indexing.rst +++ b/python/docs/source/reference/pyspark.pandas/indexing.rst @@ -336,6 +336,13 @@ DatatimeIndex DatetimeIndex +TimedeltaIndex +------------- +.. autosummary:: + :toctree: api/ + + TimedeltaIndex + Time/date components ~~~~~~~~~~~~~~~~~~~~ .. autosummary:: diff --git a/python/docs/source/user_guide/pandas_on_spark/options.rst b/python/docs/source/user_guide/pandas_on_spark/options.rst index 8f18f8ef8e..06a27ecbe8 100644 --- a/python/docs/source/user_guide/pandas_on_spark/options.rst +++ b/python/docs/source/user_guide/pandas_on_spark/options.rst @@ -286,7 +286,10 @@ compute.eager_check True 'compute.eager_check' sets whethe performs the validation beforehand, but it will cause a performance overhead. Otherwise, pandas-on-Spark skip the validation and will be slightly different - from pandas. Affected APIs: `Series.dot`. + from pandas. Affected APIs: `Series.dot`, + `Series.asof`, `FractionalExtensionOps.astype`, + `IntegralExtensionOps.astype`, `FractionalOps.astype`, + `DecimalOps.astype`. compute.isin_limit 80 'compute.isin_limit' sets the limit for filtering by 'Column.isin(list)'. If the length of the ‘list’ is above the limit, broadcast join is used instead for diff --git a/python/docs/source/user_guide/sql/arrow_pandas.rst b/python/docs/source/user_guide/sql/arrow_pandas.rst index 78d3e7ad84..20a9f935d5 100644 --- a/python/docs/source/user_guide/sql/arrow_pandas.rst +++ b/python/docs/source/user_guide/sql/arrow_pandas.rst @@ -387,7 +387,7 @@ working with timestamps in ``pandas_udf``\s to get the best performance, see Recommended Pandas and PyArrow Versions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -For usage with pyspark.sql, the minimum supported versions of Pandas is 0.23.2 and PyArrow is 1.0.0. +For usage with pyspark.sql, the minimum supported versions of Pandas is 1.0.5 and PyArrow is 1.0.0. Higher versions may be used, however, compatibility and data correctness can not be guaranteed and should be verified by the user. 
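(Editor's note: the next diff renames `PickleSerializer` to `CPickleSerializer` in `pyspark.serializers`. A minimal sketch of how caller code would look after the rename; the RDD and values are illustrative only.)

```python
from pyspark import SparkContext
from pyspark.serializers import CPickleSerializer  # formerly PickleSerializer

# Pass the renamed serializer explicitly; only the class name changes, not behavior.
sc = SparkContext("local[2]", "serializer-rename-sketch", serializer=CPickleSerializer())
rdd = sc.parallelize(range(10)).map(lambda x: x * 2)
print(rdd.collect())
sc.stop()
```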
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 70392fb1df..aab95aded0 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -57,7 +57,7 @@ from pyspark.storagelevel import StorageLevel from pyspark.accumulators import Accumulator, AccumulatorParam from pyspark.broadcast import Broadcast -from pyspark.serializers import MarshalSerializer, PickleSerializer +from pyspark.serializers import MarshalSerializer, CPickleSerializer from pyspark.taskcontext import TaskContext, BarrierTaskContext, BarrierTaskInfo from pyspark.profiler import Profiler, BasicProfiler from pyspark.version import __version__ @@ -136,7 +136,7 @@ def wrapper(self, *args, **kwargs): "Accumulator", "AccumulatorParam", "MarshalSerializer", - "PickleSerializer", + "CPickleSerializer", "StatusTracker", "SparkJobInfo", "SparkStageInfo", diff --git a/python/pyspark/__init__.pyi b/python/pyspark/__init__.pyi index 35df545ee6..fb045f2e5c 100644 --- a/python/pyspark/__init__.pyi +++ b/python/pyspark/__init__.pyi @@ -38,7 +38,7 @@ from pyspark.profiler import ( # noqa: F401 from pyspark.rdd import RDD as RDD, RDDBarrier as RDDBarrier # noqa: F401 from pyspark.serializers import ( # noqa: F401 MarshalSerializer as MarshalSerializer, - PickleSerializer as PickleSerializer, + CPickleSerializer as CPickleSerializer, ) from pyspark.status import ( # noqa: F401 SparkJobInfo as SparkJobInfo, diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index c43ebe417b..2ea2a4952e 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -20,13 +20,13 @@ import struct import socketserver as SocketServer import threading -from pyspark.serializers import read_int, PickleSerializer +from pyspark.serializers import read_int, CPickleSerializer __all__ = ["Accumulator", "AccumulatorParam"] -pickleSer = PickleSerializer() +pickleSer = CPickleSerializer() # Holds accumulators registered on the current machine, keyed by ID. This is then used to send # the local accumulator updates back to the driver program at the end of a task. diff --git a/python/pyspark/cloudpickle/__init__.py b/python/pyspark/cloudpickle/__init__.py index 56506d95fa..0ae79b5535 100644 --- a/python/pyspark/cloudpickle/__init__.py +++ b/python/pyspark/cloudpickle/__init__.py @@ -8,4 +8,4 @@ # expose their Pickler subclass at top-level under the "Pickler" name. Pickler = CloudPickler -__version__ = '1.6.0' +__version__ = '2.0.0' diff --git a/python/pyspark/cloudpickle/cloudpickle.py b/python/pyspark/cloudpickle/cloudpickle.py index 05d52afa0d..347b386958 100644 --- a/python/pyspark/cloudpickle/cloudpickle.py +++ b/python/pyspark/cloudpickle/cloudpickle.py @@ -55,6 +55,7 @@ import warnings from .compat import pickle +from collections import OrderedDict from typing import Generic, Union, Tuple, Callable from pickle import _getattribute from importlib._bootstrap import _find_spec @@ -87,8 +88,11 @@ def g(): # communication speed over compatibility: DEFAULT_PROTOCOL = pickle.HIGHEST_PROTOCOL +# Names of modules whose resources should be treated as dynamic. +_PICKLE_BY_VALUE_MODULES = set() + # Track the provenance of reconstructed dynamic classes to make it possible to -# recontruct instances from the matching singleton class definition when +# reconstruct instances from the matching singleton class definition when # appropriate and preserve the usual "isinstance" semantics of Python objects. 
_DYNAMIC_CLASS_TRACKER_BY_CLASS = weakref.WeakKeyDictionary() _DYNAMIC_CLASS_TRACKER_BY_ID = weakref.WeakValueDictionary() @@ -123,6 +127,77 @@ def _lookup_class_or_track(class_tracker_id, class_def): return class_def +def register_pickle_by_value(module): + """Register a module to make it functions and classes picklable by value. + + By default, functions and classes that are attributes of an importable + module are to be pickled by reference, that is relying on re-importing + the attribute from the module at load time. + + If `register_pickle_by_value(module)` is called, all its functions and + classes are subsequently to be pickled by value, meaning that they can + be loaded in Python processes where the module is not importable. + + This is especially useful when developing a module in a distributed + execution environment: restarting the client Python process with the new + source code is enough: there is no need to re-install the new version + of the module on all the worker nodes nor to restart the workers. + + Note: this feature is considered experimental. See the cloudpickle + README.md file for more details and limitations. + """ + if not isinstance(module, types.ModuleType): + raise ValueError( + f"Input should be a module object, got {str(module)} instead" + ) + # In the future, cloudpickle may need a way to access any module registered + # for pickling by value in order to introspect relative imports inside + # functions pickled by value. (see + # https://github.com/cloudpipe/cloudpickle/pull/417#issuecomment-873684633). + # This access can be ensured by checking that module is present in + # sys.modules at registering time and assuming that it will still be in + # there when accessed during pickling. Another alternative would be to + # store a weakref to the module. Even though cloudpickle does not implement + # this introspection yet, in order to avoid a possible breaking change + # later, we still enforce the presence of module inside sys.modules. + if module.__name__ not in sys.modules: + raise ValueError( + f"{module} was not imported correctly, have you used an " + f"`import` statement to access it?" + ) + _PICKLE_BY_VALUE_MODULES.add(module.__name__) + + +def unregister_pickle_by_value(module): + """Unregister that the input module should be pickled by value.""" + if not isinstance(module, types.ModuleType): + raise ValueError( + f"Input should be a module object, got {str(module)} instead" + ) + if module.__name__ not in _PICKLE_BY_VALUE_MODULES: + raise ValueError(f"{module} is not registered for pickle by value") + else: + _PICKLE_BY_VALUE_MODULES.remove(module.__name__) + + +def list_registry_pickle_by_value(): + return _PICKLE_BY_VALUE_MODULES.copy() + + +def _is_registered_pickle_by_value(module): + module_name = module.__name__ + if module_name in _PICKLE_BY_VALUE_MODULES: + return True + while True: + parent_name = module_name.rsplit(".", 1)[0] + if parent_name == module_name: + break + if parent_name in _PICKLE_BY_VALUE_MODULES: + return True + module_name = parent_name + return False + + def _whichmodule(obj, name): """Find the module an object belongs to. @@ -136,11 +211,14 @@ def _whichmodule(obj, name): # Workaround bug in old Python versions: prior to Python 3.7, # T.__module__ would always be set to "typing" even when the TypeVar T # would be defined in a different module. 
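A hedged usage sketch of the pickle-by-value registry introduced above (illustrative only, not part of the patch; `mypkg` and `mypkg.helper` are hypothetical stand-ins for a locally developed module that worker environments cannot import):

    import io
    import mypkg  # hypothetical module under local development

    from pyspark.cloudpickle.cloudpickle import (
        register_pickle_by_value,
        unregister_pickle_by_value,
    )
    from pyspark.cloudpickle.cloudpickle_fast import CloudPickler

    register_pickle_by_value(mypkg)        # mypkg members are now pickled by value
    buf = io.BytesIO()
    CloudPickler(buf).dump(mypkg.helper)   # helper's code travels inside the pickle
    unregister_pickle_by_value(mypkg)      # restore the default pickle-by-reference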
- # - # For such older Python versions, we ignore the __module__ attribute of - # TypeVar instances and instead exhaustively lookup those instances in - # all currently imported modules. - module_name = None + if name is not None and getattr(typing, name, None) is obj: + # Built-in TypeVar defined in typing such as AnyStr + return 'typing' + else: + # User defined or third-party TypeVar: __module__ attribute is + # irrelevant, thus trigger a exhaustive search for obj in all + # modules. + module_name = None else: module_name = getattr(obj, '__module__', None) @@ -166,18 +244,35 @@ def _whichmodule(obj, name): return None -def _is_importable(obj, name=None): - """Dispatcher utility to test the importability of various constructs.""" - if isinstance(obj, types.FunctionType): - return _lookup_module_and_qualname(obj, name=name) is not None - elif issubclass(type(obj), type): - return _lookup_module_and_qualname(obj, name=name) is not None +def _should_pickle_by_reference(obj, name=None): + """Test whether an function or a class should be pickled by reference + + Pickling by reference means by that the object (typically a function or a + class) is an attribute of a module that is assumed to be importable in the + target Python environment. Loading will therefore rely on importing the + module and then calling `getattr` on it to access the function or class. + + Pickling by reference is the only option to pickle functions and classes + in the standard library. In cloudpickle the alternative option is to + pickle by value (for instance for interactively or locally defined + functions and classes or for attributes of modules that have been + explicitly registered to be pickled by value. + """ + if isinstance(obj, types.FunctionType) or issubclass(type(obj), type): + module_and_name = _lookup_module_and_qualname(obj, name=name) + if module_and_name is None: + return False + module, name = module_and_name + return not _is_registered_pickle_by_value(module) + elif isinstance(obj, types.ModuleType): # We assume that sys.modules is primarily used as a cache mechanism for # the Python import machinery. Checking if a module has been added in - # is sys.modules therefore a cheap and simple heuristic to tell us whether - # we can assume that a given module could be imported by name in - # another Python process. + # is sys.modules therefore a cheap and simple heuristic to tell us + # whether we can assume that a given module could be imported by name + # in another Python process. + if _is_registered_pickle_by_value(obj): + return False return obj.__name__ in sys.modules else: raise TypeError( @@ -233,10 +328,13 @@ def _extract_code_globals(co): out_names = _extract_code_globals_cache.get(co) if out_names is None: names = co.co_names - out_names = {names[oparg] for _, oparg in _walk_global_ops(co)} + # We use a dict with None values instead of a set to get a + # deterministic order (assuming Python 3.6+) and avoid introducing + # non-deterministic pickle bytes as a results. + out_names = {names[oparg]: None for _, oparg in _walk_global_ops(co)} # Declaring a function inside another one using the "def ..." - # syntax generates a constant code object corresonding to the one + # syntax generates a constant code object corresponding to the one # of the nested function's As the nested function may itself need # global variables, we need to introspect its code, extract its # globals, (look for code object in it's co_consts attribute..) 
and @@ -244,7 +342,7 @@ def _extract_code_globals(co): if co.co_consts: for const in co.co_consts: if isinstance(const, types.CodeType): - out_names |= _extract_code_globals(const) + out_names.update(_extract_code_globals(const)) _extract_code_globals_cache[co] = out_names @@ -452,15 +550,31 @@ def _extract_class_dict(cls): if sys.version_info[:2] < (3, 7): # pragma: no branch def _is_parametrized_type_hint(obj): - # This is very cheap but might generate false positives. + # This is very cheap but might generate false positives. So try to + # narrow it down is good as possible. + type_module = getattr(type(obj), '__module__', None) + from_typing_extensions = type_module == 'typing_extensions' + from_typing = type_module == 'typing' + # general typing Constructs is_typing = getattr(obj, '__origin__', None) is not None # typing_extensions.Literal - is_litteral = getattr(obj, '__values__', None) is not None + is_literal = ( + (getattr(obj, '__values__', None) is not None) + and from_typing_extensions + ) # typing_extensions.Final - is_final = getattr(obj, '__type__', None) is not None + is_final = ( + (getattr(obj, '__type__', None) is not None) + and from_typing_extensions + ) + + # typing.ClassVar + is_classvar = ( + (getattr(obj, '__type__', None) is not None) and from_typing + ) # typing.Union/Tuple for old Python 3.5 is_union = getattr(obj, '__union_params__', None) is not None @@ -469,8 +583,8 @@ def _is_parametrized_type_hint(obj): getattr(obj, '__result__', None) is not None and getattr(obj, '__args__', None) is not None ) - return any((is_typing, is_litteral, is_final, is_union, is_tuple, - is_callable)) + return any((is_typing, is_literal, is_final, is_classvar, is_union, + is_tuple, is_callable)) def _create_parametrized_type_hint(origin, args): return origin[args] @@ -557,8 +671,11 @@ def _rebuild_tornado_coroutine(func): loads = pickle.loads -# hack for __import__ not working as desired def subimport(name): + # We cannot do simply: `return __import__(name)`: Indeed, if ``name`` is + # the name of a submodule, __import__ will return the top-level root module + # of this submodule. For instance, __import__('os.path') returns the `os` + # module. __import__(name) return sys.modules[name] @@ -699,7 +816,7 @@ def _make_skel_func(code, cell_count, base_globals=None): """ # This function is deprecated and should be removed in cloudpickle 1.7 warnings.warn( - "A pickle file created using an old (<=1.4.1) version of cloudpicke " + "A pickle file created using an old (<=1.4.1) version of cloudpickle " "is currently being loaded. This is not supported by cloudpickle and " "will break in cloudpickle 1.7", category=UserWarning ) @@ -813,10 +930,15 @@ def _decompose_typevar(obj): def _typevar_reduce(obj): - # TypeVar instances have no __qualname__ hence we pass the name explicitly. 
+ # TypeVar instances require the module information hence why we + # are not using the _should_pickle_by_reference directly module_and_name = _lookup_module_and_qualname(obj, name=obj.__name__) + if module_and_name is None: return (_make_typevar, _decompose_typevar(obj)) + elif _is_registered_pickle_by_value(module_and_name[0]): + return (_make_typevar, _decompose_typevar(obj)) + return (getattr, module_and_name) @@ -830,13 +952,22 @@ def _get_bases(typ): return getattr(typ, bases_attr) -def _make_dict_keys(obj): - return dict.fromkeys(obj).keys() +def _make_dict_keys(obj, is_ordered=False): + if is_ordered: + return OrderedDict.fromkeys(obj).keys() + else: + return dict.fromkeys(obj).keys() -def _make_dict_values(obj): - return {i: _ for i, _ in enumerate(obj)}.values() +def _make_dict_values(obj, is_ordered=False): + if is_ordered: + return OrderedDict((i, _) for i, _ in enumerate(obj)).values() + else: + return {i: _ for i, _ in enumerate(obj)}.values() -def _make_dict_items(obj): - return obj.items() +def _make_dict_items(obj, is_ordered=False): + if is_ordered: + return OrderedDict(obj).items() + else: + return obj.items() diff --git a/python/pyspark/cloudpickle/cloudpickle_fast.py b/python/pyspark/cloudpickle/cloudpickle_fast.py index fa8da0f635..6db059eb85 100644 --- a/python/pyspark/cloudpickle/cloudpickle_fast.py +++ b/python/pyspark/cloudpickle/cloudpickle_fast.py @@ -6,7 +6,7 @@ is only available for Python versions 3.8+, a lot of backward-compatibility code is also removed. -Note that the C Pickler sublassing API is CPython-specific. Therefore, some +Note that the C Pickler subclassing API is CPython-specific. Therefore, some guards present in cloudpickle.py that were written to handle PyPy specificities are not present in cloudpickle_fast.py """ @@ -23,12 +23,12 @@ import typing from enum import Enum -from collections import ChainMap +from collections import ChainMap, OrderedDict from .compat import pickle, Pickler from .cloudpickle import ( _extract_code_globals, _BUILTIN_TYPE_NAMES, DEFAULT_PROTOCOL, - _find_imported_submodules, _get_cell_contents, _is_importable, + _find_imported_submodules, _get_cell_contents, _should_pickle_by_reference, _builtin_type, _get_or_create_tracker_id, _make_skeleton_class, _make_skeleton_enum, _extract_class_dict, dynamic_subimport, subimport, _typevar_reduce, _get_bases, _make_cell, _make_empty_cell, CellType, @@ -180,7 +180,7 @@ def _class_getstate(obj): clsdict.pop('__weakref__', None) if issubclass(type(obj), abc.ABCMeta): - # If obj is an instance of an ABCMeta subclass, dont pickle the + # If obj is an instance of an ABCMeta subclass, don't pickle the # cache/negative caches populated during isinstance/issubclass # checks, but pickle the list of registered subclasses of obj. clsdict.pop('_abc_cache', None) @@ -244,7 +244,19 @@ def _enum_getstate(obj): def _code_reduce(obj): """codeobject reducer""" - if hasattr(obj, "co_posonlyargcount"): # pragma: no branch + if hasattr(obj, "co_linetable"): # pragma: no branch + # Python 3.10 and later: obj.co_lnotab is deprecated and constructor + # expects obj.co_linetable instead. 
+ args = ( + obj.co_argcount, obj.co_posonlyargcount, + obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize, + obj.co_flags, obj.co_code, obj.co_consts, obj.co_names, + obj.co_varnames, obj.co_filename, obj.co_name, + obj.co_firstlineno, obj.co_linetable, obj.co_freevars, + obj.co_cellvars + ) + elif hasattr(obj, "co_posonlyargcount"): + # Backward compat for 3.9 and older args = ( obj.co_argcount, obj.co_posonlyargcount, obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize, @@ -254,6 +266,7 @@ def _code_reduce(obj): obj.co_cellvars ) else: + # Backward compat for even older versions of Python args = ( obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code, obj.co_consts, @@ -339,11 +352,16 @@ def _memoryview_reduce(obj): def _module_reduce(obj): - if _is_importable(obj): + if _should_pickle_by_reference(obj): return subimport, (obj.__name__,) else: - obj.__dict__.pop('__builtins__', None) - return dynamic_subimport, (obj.__name__, vars(obj)) + # Some external libraries can populate the "__builtins__" entry of a + # module's `__dict__` with unpicklable objects (see #316). For that + # reason, we do not attempt to pickle the "__builtins__" entry, and + # restore a default value for it at unpickling time. + state = obj.__dict__.copy() + state.pop('__builtins__', None) + return dynamic_subimport, (obj.__name__, state) def _method_reduce(obj): @@ -396,7 +414,7 @@ def _class_reduce(obj): return type, (NotImplemented,) elif obj in _BUILTIN_TYPE_NAMES: return _builtin_type, (_BUILTIN_TYPE_NAMES[obj],) - elif not _is_importable(obj): + elif not _should_pickle_by_reference(obj): return _dynamic_class_reduce(obj) return NotImplemented @@ -419,6 +437,24 @@ def _dict_items_reduce(obj): return _make_dict_items, (dict(obj), ) +def _odict_keys_reduce(obj): + # Safer not to ship the full dict as sending the rest might + # be unintended and could potentially cause leaking of + # sensitive information + return _make_dict_keys, (list(obj), True) + + +def _odict_values_reduce(obj): + # Safer not to ship the full dict as sending the rest might + # be unintended and could potentially cause leaking of + # sensitive information + return _make_dict_values, (list(obj), True) + + +def _odict_items_reduce(obj): + return _make_dict_items, (dict(obj), True) + + # COLLECTIONS OF OBJECTS STATE SETTERS # ------------------------------------ # state setters are called at unpickling time, once the object is created and @@ -426,7 +462,7 @@ def _dict_items_reduce(obj): def _function_setstate(obj, state): - """Update the state of a dynaamic function. + """Update the state of a dynamic function. As __closure__ and __globals__ are readonly attributes of a function, we cannot rely on the native setstate routine of pickle.load_build, that calls @@ -495,6 +531,9 @@ class CloudPickler(Pickler): _dispatch_table[_collections_abc.dict_keys] = _dict_keys_reduce _dispatch_table[_collections_abc.dict_values] = _dict_values_reduce _dispatch_table[_collections_abc.dict_items] = _dict_items_reduce + _dispatch_table[type(OrderedDict().keys())] = _odict_keys_reduce + _dispatch_table[type(OrderedDict().values())] = _odict_values_reduce + _dispatch_table[type(OrderedDict().items())] = _odict_items_reduce dispatch_table = ChainMap(_dispatch_table, copyreg.dispatch_table) @@ -520,7 +559,7 @@ def _function_reduce(self, obj): As opposed to cloudpickle.py, There no special handling for builtin pypy functions because cloudpickle_fast is CPython-specific. 
""" - if _is_importable(obj): + if _should_pickle_by_reference(obj): return NotImplemented else: return self._dynamic_function_reduce(obj) @@ -579,7 +618,7 @@ def dump(self, obj): # `dispatch` attribute. Earlier versions of the protocol 5 CloudPickler # used `CloudPickler.dispatch` as a class-level attribute storing all # reducers implemented by cloudpickle, but the attribute name was not a - # great choice given the meaning of `Cloudpickler.dispatch` when + # great choice given the meaning of `CloudPickler.dispatch` when # `CloudPickler` extends the pure-python pickler. dispatch = dispatch_table @@ -653,7 +692,7 @@ def reducer_override(self, obj): return self._function_reduce(obj) else: # fallback to save_global, including the Pickler's - # distpatch_table + # dispatch_table return NotImplemented else: @@ -724,7 +763,7 @@ def save_global(self, obj, name=None, pack=struct.pack): ) elif name is not None: Pickler.save_global(self, obj, name=name) - elif not _is_importable(obj, name=name): + elif not _should_pickle_by_reference(obj, name=name): self._save_reduce_pickle5(*_dynamic_class_reduce(obj), obj=obj) else: Pickler.save_global(self, obj, name=name) @@ -736,7 +775,7 @@ def save_function(self, obj, name=None): Determines what kind of function obj is (e.g. lambda, defined at interactive prompt, etc) and handles the pickling appropriately. """ - if _is_importable(obj, name=name): + if _should_pickle_by_reference(obj, name=name): return Pickler.save_global(self, obj, name=name) elif PYPY and isinstance(obj.__code__, builtin_code_type): return self.save_pypy_builtin_func(obj) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 2c789947af..336024fff8 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -35,7 +35,7 @@ from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway, local_connect_and_auth from pyspark.serializers import ( - PickleSerializer, + CPickleSerializer, BatchedSerializer, UTF8Deserializer, PairDeserializer, @@ -49,7 +49,7 @@ from pyspark.taskcontext import TaskContext from pyspark.traceback_utils import CallSite, first_spark_call from pyspark.status import StatusTracker -from pyspark.profiler import ProfilerCollector, BasicProfiler +from pyspark.profiler import ProfilerCollector, BasicProfiler, UDFBasicProfiler __all__ = ["SparkContext"] @@ -105,6 +105,9 @@ class SparkContext(object): profiler_cls : type, optional A class of custom Profiler used to do profiling (default is :class:`pyspark.profiler.BasicProfiler`). + udf_profiler_cls : type, optional + A class of custom Profiler used to do udf profiling + (default is :class:`pyspark.profiler.UDFBasicProfiler`). Notes ----- @@ -142,11 +145,12 @@ def __init__( pyFiles=None, environment=None, batchSize=0, - serializer=PickleSerializer(), + serializer=CPickleSerializer(), conf=None, gateway=None, jsc=None, profiler_cls=BasicProfiler, + udf_profiler_cls=UDFBasicProfiler, ): if conf is None or conf.get("spark.executor.allowSparkContext", "false").lower() != "true": # In order to prevent SparkContext from being created in executors. @@ -172,6 +176,7 @@ def __init__( conf, jsc, profiler_cls, + udf_profiler_cls, ) except: # If an error occurs, clean up in order to allow future SparkContext creation: @@ -190,6 +195,7 @@ def _do_init( conf, jsc, profiler_cls, + udf_profiler_cls, ): self.environment = environment or {} # java gateway must have been launched at this point. 
@@ -319,7 +325,7 @@ def _do_init( # profiling stats collected for each PythonRDD if self._conf.get("spark.python.profile", "false") == "true": dump_path = self._conf.get("spark.python.profile.dump", None) - self.profiler_collector = ProfilerCollector(profiler_cls, dump_path) + self.profiler_collector = ProfilerCollector(profiler_cls, udf_profiler_cls, dump_path) else: self.profiler_collector = None @@ -814,7 +820,7 @@ def sequenceFile( and value Writable classes 2. Serialization is attempted via Pickle pickling 3. If this fails, the fallback is to call 'toString' on each key and value - 4. :class:`PickleSerializer` is used to deserialize pickled objects on the Python side + 4. :class:`CPickleSerializer` is used to deserialize pickled objects on the Python side Parameters ---------- diff --git a/python/pyspark/context.pyi b/python/pyspark/context.pyi index 640a69cad0..f1350aaec9 100644 --- a/python/pyspark/context.pyi +++ b/python/pyspark/context.pyi @@ -62,6 +62,7 @@ class SparkContext: gateway: Optional[JavaGateway] = ..., jsc: Optional[JavaObject] = ..., profiler_cls: type = ..., + udf_profiler_cls: type = ..., ) -> None: ... def __getnewargs__(self) -> NoReturn: ... def __enter__(self) -> SparkContext: ... diff --git a/python/pyspark/ml/_typing.pyi b/python/pyspark/ml/_typing.pyi index 40531d1c48..b51aa9634f 100644 --- a/python/pyspark/ml/_typing.pyi +++ b/python/pyspark/ml/_typing.pyi @@ -23,6 +23,7 @@ import pyspark.ml.base import pyspark.ml.param import pyspark.ml.util import pyspark.ml.wrapper +from py4j.java_gateway import JavaObject ParamMap = Dict[pyspark.ml.param.Param, Any] PipelineStage = Union[pyspark.ml.base.Estimator, pyspark.ml.base.Transformer] @@ -31,7 +32,9 @@ T = TypeVar("T") P = TypeVar("P", bound=pyspark.ml.param.Params) M = TypeVar("M", bound=pyspark.ml.base.Transformer) JM = TypeVar("JM", bound=pyspark.ml.wrapper.JavaTransformer) +C = TypeVar("C", bound=type) +JavaObjectOrPickleDump = Union[JavaObject, bytearray, bytes] BinaryClassificationEvaluatorMetricType = Union[Literal["areaUnderROC"], Literal["areaUnderPR"]] RegressionEvaluatorMetricType = Union[ Literal["rmse"], Literal["mse"], Literal["r2"], Literal["mae"], Literal["var"] @@ -64,7 +67,7 @@ MultilabelClassificationEvaluatorMetricType = Union[ Literal["microRecall"], Literal["microF1Measure"], ] -ClusteringEvaluatorMetricType = Union[Literal["silhouette"]] +ClusteringEvaluatorMetricType = Literal["silhouette"] RankingEvaluatorMetricType = Union[ Literal["meanAveragePrecision"], Literal["meanAveragePrecisionAtK"], diff --git a/python/pyspark/ml/common.py b/python/pyspark/ml/common.py index b43b2ad9b1..61b20f131d 100644 --- a/python/pyspark/ml/common.py +++ b/python/pyspark/ml/common.py @@ -15,13 +15,19 @@ # limitations under the License. 
# +from typing import Any, Callable, TYPE_CHECKING + +if TYPE_CHECKING: + from pyspark.ml._typing import C, JavaObjectOrPickleDump + import py4j.protocol from py4j.protocol import Py4JJavaError from py4j.java_gateway import JavaObject from py4j.java_collections import JavaArray, JavaList +import pyspark.context from pyspark import RDD, SparkContext -from pyspark.serializers import PickleSerializer, AutoBatchedSerializer +from pyspark.serializers import CPickleSerializer, AutoBatchedSerializer from pyspark.sql import DataFrame, SparkSession # Hack for support float('inf') in Py4j @@ -34,7 +40,7 @@ } -def _new_smart_decode(obj): +def _new_smart_decode(obj: Any) -> str: if isinstance(obj, float): s = str(obj) return _float_str_mapping.get(s, s) @@ -53,24 +59,24 @@ def _new_smart_decode(obj): # this will call the ML version of pythonToJava() -def _to_java_object_rdd(rdd): +def _to_java_object_rdd(rdd: RDD) -> JavaObject: """Return an JavaRDD of Object by unpickling It will convert each Python object into Java object by Pickle, whenever the RDD is serialized in batch or not. """ - rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer())) - return rdd.ctx._jvm.org.apache.spark.ml.python.MLSerDe.pythonToJava(rdd._jrdd, True) + rdd = rdd._reserialize(AutoBatchedSerializer(CPickleSerializer())) # type: ignore[attr-defined] + return rdd.ctx._jvm.org.apache.spark.ml.python.MLSerDe.pythonToJava(rdd._jrdd, True) # type: ignore[attr-defined] -def _py2java(sc, obj): +def _py2java(sc: SparkContext, obj: Any) -> JavaObject: """Convert Python object into Java""" if isinstance(obj, RDD): obj = _to_java_object_rdd(obj) elif isinstance(obj, DataFrame): obj = obj._jdf elif isinstance(obj, SparkContext): - obj = obj._jsc + obj = obj._jsc # type: ignore[attr-defined] elif isinstance(obj, list): obj = [_py2java(sc, x) for x in obj] elif isinstance(obj, JavaObject): @@ -78,12 +84,12 @@ def _py2java(sc, obj): elif isinstance(obj, (int, float, bool, bytes, str)): pass else: - data = bytearray(PickleSerializer().dumps(obj)) - obj = sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(data) + data = bytearray(CPickleSerializer().dumps(obj)) + obj = sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(data) # type: ignore[attr-defined] return obj -def _java2py(sc, r, encoding="bytes"): +def _java2py(sc: SparkContext, r: "JavaObjectOrPickleDump", encoding: str = "bytes") -> Any: if isinstance(r, JavaObject): clsName = r.getClass().getSimpleName() # convert RDD into JavaRDD @@ -92,32 +98,34 @@ def _java2py(sc, r, encoding="bytes"): clsName = "JavaRDD" if clsName == "JavaRDD": - jrdd = sc._jvm.org.apache.spark.ml.python.MLSerDe.javaToPython(r) + jrdd = sc._jvm.org.apache.spark.ml.python.MLSerDe.javaToPython(r) # type: ignore[attr-defined] return RDD(jrdd, sc) if clsName == "Dataset": return DataFrame(r, SparkSession(sc)._wrapped) if clsName in _picklable_classes: - r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r) + r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r) # type: ignore[attr-defined] elif isinstance(r, (JavaArray, JavaList)): try: - r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r) + r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r) # type: ignore[attr-defined] except Py4JJavaError: pass # not picklable if isinstance(r, (bytearray, bytes)): - r = PickleSerializer().loads(bytes(r), encoding=encoding) + r = CPickleSerializer().loads(bytes(r), encoding=encoding) return r -def callJavaFunc(sc, func, *args): +def callJavaFunc( + sc: pyspark.context.SparkContext, func: Callable[..., 
"JavaObjectOrPickleDump"], *args: Any +) -> "JavaObjectOrPickleDump": """Call Java Function""" - args = [_py2java(sc, a) for a in args] - return _java2py(sc, func(*args)) + java_args = [_py2java(sc, a) for a in args] + return _java2py(sc, func(*java_args)) -def inherit_doc(cls): +def inherit_doc(cls: "C") -> "C": """ A decorator that makes a class inherit documentation from its parents. """ diff --git a/python/pyspark/ml/tests/test_linalg.py b/python/pyspark/ml/tests/test_linalg.py index 5db6c048bf..dfdd32e98e 100644 --- a/python/pyspark/ml/tests/test_linalg.py +++ b/python/pyspark/ml/tests/test_linalg.py @@ -20,7 +20,7 @@ from numpy import arange, array, array_equal, inf, ones, tile, zeros -from pyspark.serializers import PickleSerializer +from pyspark.serializers import CPickleSerializer from pyspark.ml.linalg import ( DenseMatrix, DenseVector, @@ -37,7 +37,7 @@ class VectorTests(MLlibTestCase): def _test_serialize(self, v): - ser = PickleSerializer() + ser = CPickleSerializer() self.assertEqual(v, ser.loads(ser.dumps(v))) jvec = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(v))) nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvec))) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 3fd96eb9e2..2a2148323a 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -404,7 +404,7 @@ def validateParams(instance): estimatorParamMaps = instance.getEstimatorParamMaps() paramErr = ( "Validator save requires all Params in estimatorParamMaps to apply to " - f"its Estimator, An extraneous Param was found: " + "its Estimator, An extraneous Param was found: " ) for paramMap in estimatorParamMaps: for param in paramMap: diff --git a/python/pyspark/mllib/_typing.pyi b/python/pyspark/mllib/_typing.pyi index 22469e869a..51a98cb0b0 100644 --- a/python/pyspark/mllib/_typing.pyi +++ b/python/pyspark/mllib/_typing.pyi @@ -16,8 +16,11 @@ # specific language governing permissions and limitations # under the License. -from typing import List, Tuple, Union +from typing import List, Tuple, TypeVar, Union from pyspark.mllib.linalg import Vector from numpy import ndarray # noqa: F401 +from py4j.java_gateway import JavaObject VectorLike = Union[ndarray, Vector, List[float], Tuple[float, ...]] +C = TypeVar("C", bound=type) +JavaObjectOrPickleDump = Union[JavaObject, bytearray, bytes] diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py index 1d8098ffb1..5f109be2a1 100644 --- a/python/pyspark/mllib/common.py +++ b/python/pyspark/mllib/common.py @@ -15,13 +15,19 @@ # limitations under the License. 
# +from typing import Any, Callable, TYPE_CHECKING + +if TYPE_CHECKING: + from pyspark.mllib._typing import C, JavaObjectOrPickleDump + import py4j.protocol from py4j.protocol import Py4JJavaError from py4j.java_gateway import JavaObject from py4j.java_collections import JavaArray, JavaList +import pyspark.context from pyspark import RDD, SparkContext -from pyspark.serializers import PickleSerializer, AutoBatchedSerializer +from pyspark.serializers import CPickleSerializer, AutoBatchedSerializer from pyspark.sql import DataFrame, SparkSession # Hack for support float('inf') in Py4j @@ -34,7 +40,7 @@ } -def _new_smart_decode(obj): +def _new_smart_decode(obj: Any) -> str: if isinstance(obj, float): s = str(obj) return _float_str_mapping.get(s, s) @@ -55,24 +61,24 @@ def _new_smart_decode(obj): # this will call the MLlib version of pythonToJava() -def _to_java_object_rdd(rdd): +def _to_java_object_rdd(rdd: RDD) -> JavaObject: """Return a JavaRDD of Object by unpickling It will convert each Python object into Java object by Pickle, whenever the RDD is serialized in batch or not. """ - rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer())) - return rdd.ctx._jvm.org.apache.spark.mllib.api.python.SerDe.pythonToJava(rdd._jrdd, True) + rdd = rdd._reserialize(AutoBatchedSerializer(CPickleSerializer())) # type: ignore[attr-defined] + return rdd.ctx._jvm.org.apache.spark.mllib.api.python.SerDe.pythonToJava(rdd._jrdd, True) # type: ignore[attr-defined] -def _py2java(sc, obj): +def _py2java(sc: SparkContext, obj: Any) -> JavaObject: """Convert Python object into Java""" if isinstance(obj, RDD): obj = _to_java_object_rdd(obj) elif isinstance(obj, DataFrame): obj = obj._jdf elif isinstance(obj, SparkContext): - obj = obj._jsc + obj = obj._jsc # type: ignore[attr-defined] elif isinstance(obj, list): obj = [_py2java(sc, x) for x in obj] elif isinstance(obj, JavaObject): @@ -80,12 +86,12 @@ def _py2java(sc, obj): elif isinstance(obj, (int, float, bool, bytes, str)): pass else: - data = bytearray(PickleSerializer().dumps(obj)) - obj = sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(data) + data = bytearray(CPickleSerializer().dumps(obj)) + obj = sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(data) # type: ignore[attr-defined] return obj -def _java2py(sc, r, encoding="bytes"): +def _java2py(sc: SparkContext, r: "JavaObjectOrPickleDump", encoding: str = "bytes") -> Any: if isinstance(r, JavaObject): clsName = r.getClass().getSimpleName() # convert RDD into JavaRDD @@ -94,35 +100,37 @@ def _java2py(sc, r, encoding="bytes"): clsName = "JavaRDD" if clsName == "JavaRDD": - jrdd = sc._jvm.org.apache.spark.mllib.api.python.SerDe.javaToPython(r) + jrdd = sc._jvm.org.apache.spark.mllib.api.python.SerDe.javaToPython(r) # type: ignore[attr-defined] return RDD(jrdd, sc) if clsName == "Dataset": return DataFrame(r, SparkSession(sc)._wrapped) if clsName in _picklable_classes: - r = sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(r) + r = sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(r) # type: ignore[attr-defined] elif isinstance(r, (JavaArray, JavaList)): try: - r = sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(r) + r = sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(r) # type: ignore[attr-defined] except Py4JJavaError: pass # not pickable if isinstance(r, (bytearray, bytes)): - r = PickleSerializer().loads(bytes(r), encoding=encoding) + r = CPickleSerializer().loads(bytes(r), encoding=encoding) return r -def callJavaFunc(sc, func, *args): +def callJavaFunc( + sc: 
pyspark.context.SparkContext, func: Callable[..., "JavaObjectOrPickleDump"], *args: Any +) -> "JavaObjectOrPickleDump": """Call Java Function""" - args = [_py2java(sc, a) for a in args] - return _java2py(sc, func(*args)) + java_args = [_py2java(sc, a) for a in args] + return _java2py(sc, func(*java_args)) -def callMLlibFunc(name, *args): +def callMLlibFunc(name: str, *args: Any) -> "JavaObjectOrPickleDump": """Call API in PythonMLLibAPI""" sc = SparkContext.getOrCreate() - api = getattr(sc._jvm.PythonMLLibAPI(), name) + api = getattr(sc._jvm.PythonMLLibAPI(), name) # type: ignore[attr-defined] return callJavaFunc(sc, api, *args) @@ -131,19 +139,19 @@ class JavaModelWrapper(object): Wrapper for the model in JVM """ - def __init__(self, java_model): + def __init__(self, java_model: JavaObject): self._sc = SparkContext.getOrCreate() self._java_model = java_model - def __del__(self): - self._sc._gateway.detach(self._java_model) + def __del__(self) -> None: + self._sc._gateway.detach(self._java_model) # type: ignore[attr-defined] - def call(self, name, *a): + def call(self, name: str, *a: Any) -> "JavaObjectOrPickleDump": """Call method of java_model""" return callJavaFunc(self._sc, getattr(self._java_model, name), *a) -def inherit_doc(cls): +def inherit_doc(cls: "C") -> "C": """ A decorator that makes a class inherit documentation from its parents. """ diff --git a/python/pyspark/mllib/tests/test_algorithms.py b/python/pyspark/mllib/tests/test_algorithms.py index 6927b75e3d..fd9f348f31 100644 --- a/python/pyspark/mllib/tests/test_algorithms.py +++ b/python/pyspark/mllib/tests/test_algorithms.py @@ -26,7 +26,7 @@ from pyspark.mllib.fpm import FPGrowth from pyspark.mllib.recommendation import Rating from pyspark.mllib.regression import LabeledPoint -from pyspark.serializers import PickleSerializer +from pyspark.serializers import CPickleSerializer from pyspark.testing.mllibutils import MLlibTestCase @@ -303,7 +303,7 @@ def test_regression(self): class ALSTests(MLlibTestCase): def test_als_ratings_serialize(self): - ser = PickleSerializer() + ser = CPickleSerializer() r = Rating(7, 1123, 3.14) jr = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(r))) nr = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jr))) @@ -312,7 +312,7 @@ def test_als_ratings_serialize(self): self.assertAlmostEqual(r.rating, nr.rating, 2) def test_als_ratings_id_long_error(self): - ser = PickleSerializer() + ser = CPickleSerializer() r = Rating(1205640308657491975, 50233468418, 1.0) # rating user id exceeds max int value, should fail when pickled self.assertRaises( diff --git a/python/pyspark/mllib/tests/test_linalg.py b/python/pyspark/mllib/tests/test_linalg.py index e43482dc41..d60396b633 100644 --- a/python/pyspark/mllib/tests/test_linalg.py +++ b/python/pyspark/mllib/tests/test_linalg.py @@ -21,7 +21,7 @@ from numpy import array, array_equal, zeros, arange, tile, ones, inf import pyspark.ml.linalg as newlinalg -from pyspark.serializers import PickleSerializer +from pyspark.serializers import CPickleSerializer from pyspark.mllib.linalg import ( # type: ignore[attr-defined] Vector, SparseVector, @@ -43,7 +43,7 @@ class VectorTests(MLlibTestCase): def _test_serialize(self, v): - ser = PickleSerializer() + ser = CPickleSerializer() self.assertEqual(v, ser.loads(ser.dumps(v))) jvec = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(v))) nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvec))) @@ -512,7 
+512,7 @@ class SciPyTests(MLlibTestCase): def test_serialize(self): from scipy.sparse import lil_matrix - ser = PickleSerializer() + ser = CPickleSerializer() lil = lil_matrix((4, 1)) lil[1, 0] = 1 lil[3, 0] = 2 diff --git a/python/pyspark/pandas/__init__.py b/python/pyspark/pandas/__init__.py index ea8a9ea639..dc025223df 100644 --- a/python/pyspark/pandas/__init__.py +++ b/python/pyspark/pandas/__init__.py @@ -61,6 +61,7 @@ from pyspark.pandas.indexes.datetimes import DatetimeIndex from pyspark.pandas.indexes.multi import MultiIndex from pyspark.pandas.indexes.numeric import Float64Index, Int64Index +from pyspark.pandas.indexes.timedelta import TimedeltaIndex from pyspark.pandas.series import Series from pyspark.pandas.groupby import NamedAgg @@ -79,6 +80,7 @@ "Float64Index", "CategoricalIndex", "DatetimeIndex", + "TimedeltaIndex", "sql", "range", "concat", @@ -144,4 +146,4 @@ def _auto_patch_pandas() -> None: # Import after the usage logger is attached. from pyspark.pandas.config import get_option, options, option_context, reset_option, set_option from pyspark.pandas.namespace import * # F405 -from pyspark.pandas.sql_processor import sql +from pyspark.pandas.sql_formatter import sql diff --git a/python/pyspark/pandas/config.py b/python/pyspark/pandas/config.py index a6689c8fde..8e5c808109 100644 --- a/python/pyspark/pandas/config.py +++ b/python/pyspark/pandas/config.py @@ -201,7 +201,8 @@ def validate(self, v: Any) -> None: "of validation. If 'compute.eager_check' is set to True, pandas-on-Spark performs the " "validation beforehand, but it will cause a performance overhead. Otherwise, " "pandas-on-Spark skip the validation and will be slightly different from pandas. " - "Affected APIs: `Series.dot`." + "Affected APIs: `Series.dot`, `Series.asof`, `FractionalExtensionOps.astype`, " + "`IntegralExtensionOps.astype`, `FractionalOps.astype`, `DecimalOps.astype`." 
), default=True, types=bool, diff --git a/python/pyspark/pandas/data_type_ops/base.py b/python/pyspark/pandas/data_type_ops/base.py index f2612390b9..4354aeba87 100644 --- a/python/pyspark/pandas/data_type_ops/base.py +++ b/python/pyspark/pandas/data_type_ops/base.py @@ -31,6 +31,7 @@ BooleanType, DataType, DateType, + DayTimeIntervalType, DecimalType, FractionalType, IntegralType, @@ -234,6 +235,7 @@ def __new__(cls, dtype: Dtype, spark_type: DataType) -> "DataTypeOps": IntegralOps, ) from pyspark.pandas.data_type_ops.string_ops import StringOps, StringExtensionOps + from pyspark.pandas.data_type_ops.timedelta_ops import TimedeltaOps from pyspark.pandas.data_type_ops.udt_ops import UDTOps if isinstance(dtype, CategoricalDtype): @@ -271,6 +273,8 @@ def __new__(cls, dtype: Dtype, spark_type: DataType) -> "DataTypeOps": return object.__new__(DatetimeNTZOps) elif isinstance(spark_type, DateType): return object.__new__(DateOps) + elif isinstance(spark_type, DayTimeIntervalType): + return object.__new__(TimedeltaOps) elif isinstance(spark_type, BinaryType): return object.__new__(BinaryOps) elif isinstance(spark_type, ArrayType): diff --git a/python/pyspark/pandas/data_type_ops/num_ops.py b/python/pyspark/pandas/data_type_ops/num_ops.py index e08d6e9abb..f9e068fd2c 100644 --- a/python/pyspark/pandas/data_type_ops/num_ops.py +++ b/python/pyspark/pandas/data_type_ops/num_ops.py @@ -24,6 +24,7 @@ from pyspark.pandas._typing import Dtype, IndexOpsLike, SeriesOrIndex from pyspark.pandas.base import column_op, IndexOpsMixin, numpy_column_op +from pyspark.pandas.config import get_option from pyspark.pandas.data_type_ops.base import ( DataTypeOps, is_valid_operand_for_numeric_arithmetic, @@ -388,7 +389,7 @@ def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> Ind dtype, spark_type = pandas_on_spark_type(dtype) if is_integer_dtype(dtype) and not isinstance(dtype, extension_dtypes): - if index_ops.hasnans: + if get_option("compute.eager_check") and index_ops.hasnans: raise ValueError( "Cannot convert %s with missing values to integer" % self.pretty_name ) @@ -449,7 +450,7 @@ def nan_to_null(self, index_ops: IndexOpsLike) -> IndexOpsLike: def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike: dtype, spark_type = pandas_on_spark_type(dtype) if is_integer_dtype(dtype) and not isinstance(dtype, extension_dtypes): - if index_ops.hasnans: + if get_option("compute.eager_check") and index_ops.hasnans: raise ValueError( "Cannot convert %s with missing values to integer" % self.pretty_name ) @@ -490,15 +491,17 @@ def restore(self, col: pd.Series) -> pd.Series: def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike: dtype, spark_type = pandas_on_spark_type(dtype) - - if is_integer_dtype(dtype) and not isinstance(dtype, extension_dtypes): - if index_ops.hasnans: - raise ValueError( - "Cannot convert %s with missing values to integer" % self.pretty_name - ) - elif is_bool_dtype(dtype) and not isinstance(dtype, extension_dtypes): - if index_ops.hasnans: - raise ValueError("Cannot convert %s with missing values to bool" % self.pretty_name) + if get_option("compute.eager_check"): + if is_integer_dtype(dtype) and not isinstance(dtype, extension_dtypes): + if index_ops.hasnans: + raise ValueError( + "Cannot convert %s with missing values to integer" % self.pretty_name + ) + elif is_bool_dtype(dtype) and not isinstance(dtype, extension_dtypes): + if index_ops.hasnans: + raise ValueError( + "Cannot convert %s with missing values to 
bool" % self.pretty_name + ) return _non_fractional_astype(index_ops, dtype, spark_type) @@ -517,15 +520,17 @@ def restore(self, col: pd.Series) -> pd.Series: def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike: dtype, spark_type = pandas_on_spark_type(dtype) - - if is_integer_dtype(dtype) and not isinstance(dtype, extension_dtypes): - if index_ops.hasnans: - raise ValueError( - "Cannot convert %s with missing values to integer" % self.pretty_name - ) - elif is_bool_dtype(dtype) and not isinstance(dtype, extension_dtypes): - if index_ops.hasnans: - raise ValueError("Cannot convert %s with missing values to bool" % self.pretty_name) + if get_option("compute.eager_check"): + if is_integer_dtype(dtype) and not isinstance(dtype, extension_dtypes): + if index_ops.hasnans: + raise ValueError( + "Cannot convert %s with missing values to integer" % self.pretty_name + ) + elif is_bool_dtype(dtype) and not isinstance(dtype, extension_dtypes): + if index_ops.hasnans: + raise ValueError( + "Cannot convert %s with missing values to bool" % self.pretty_name + ) if isinstance(dtype, CategoricalDtype): return _as_categorical_type(index_ops, dtype, spark_type) diff --git a/python/pyspark/pandas/data_type_ops/timedelta_ops.py b/python/pyspark/pandas/data_type_ops/timedelta_ops.py new file mode 100644 index 0000000000..8460aafcc2 --- /dev/null +++ b/python/pyspark/pandas/data_type_ops/timedelta_ops.py @@ -0,0 +1,28 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.pandas.data_type_ops.base import DataTypeOps + + +class TimedeltaOps(DataTypeOps): + """ + The class for binary operations of pandas-on-Spark objects with spark type: DayTimeIntervalType. + """ + + @property + def pretty_name(self) -> str: + return "timedelta" diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index 38ac9af9c1..edfb62ef28 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -77,6 +77,7 @@ StringType, StructField, StructType, + DecimalType, ) from pyspark.sql.window import Window @@ -8258,6 +8259,194 @@ def update(self, other: "DataFrame", join: str = "left", overwrite: bool = True) internal = self._internal.with_new_sdf(sdf, data_fields=data_fields) self._update_internal_frame(internal, requires_same_anchor=False) + # TODO: ddof should be implemented. + def cov(self, min_periods: Optional[int] = None) -> "DataFrame": + """ + Compute pairwise covariance of columns, excluding NA/null values. + + Compute the pairwise covariance among the series of a DataFrame. + The returned data frame is the `covariance matrix + `__ of the columns + of the DataFrame. + + Both NA and null values are automatically excluded from the + calculation. (See the note below about bias from missing values.) 
+ A threshold can be set for the minimum number of + observations for each value created. Comparisons with observations + below this threshold will be returned as ``NaN``. + + This method is generally used for the analysis of time series data to + understand the relationship between different measures + across time. + + .. versionadded:: 3.3.0 + + Parameters + ---------- + min_periods : int, optional + Minimum number of observations required per pair of columns + to have a valid result. + + Returns + ------- + DataFrame + The covariance matrix of the series of the DataFrame. + + See Also + -------- + Series.cov : Compute covariance with another Series. + + Examples + -------- + >>> df = ps.DataFrame([(1, 2), (0, 3), (2, 0), (1, 1)], + ... columns=['dogs', 'cats']) + >>> df.cov() + dogs cats + dogs 0.666667 -1.000000 + cats -1.000000 1.666667 + + >>> np.random.seed(42) + >>> df = ps.DataFrame(np.random.randn(1000, 5), + ... columns=['a', 'b', 'c', 'd', 'e']) + >>> df.cov() + a b c d e + a 0.998438 -0.020161 0.059277 -0.008943 0.014144 + b -0.020161 1.059352 -0.008543 -0.024738 0.009826 + c 0.059277 -0.008543 1.010670 -0.001486 -0.000271 + d -0.008943 -0.024738 -0.001486 0.921297 -0.013692 + e 0.014144 0.009826 -0.000271 -0.013692 0.977795 + + **Minimum number of periods** + + This method also supports an optional ``min_periods`` keyword + that specifies the required minimum number of non-NA observations for + each column pair in order to have a valid result: + + >>> np.random.seed(42) + >>> df = pd.DataFrame(np.random.randn(20, 3), + ... columns=['a', 'b', 'c']) + >>> df.loc[df.index[:5], 'a'] = np.nan + >>> df.loc[df.index[5:10], 'b'] = np.nan + >>> sdf = ps.from_pandas(df) + >>> sdf.cov(min_periods=12) + a b c + a 0.316741 NaN -0.150812 + b NaN 1.248003 0.191417 + c -0.150812 0.191417 0.895202 + """ + min_periods = 1 if min_periods is None else min_periods + + # Only compute covariance for Boolean and Numeric except Decimal + psdf = self[ + [ + col + for col in self.columns + if isinstance(self[col].spark.data_type, BooleanType) + or ( + isinstance(self[col].spark.data_type, NumericType) + and not isinstance(self[col].spark.data_type, DecimalType) + ) + ] + ] + + num_cols = len(psdf.columns) + cov = np.zeros([num_cols, num_cols]) + + if num_cols == 0: + return DataFrame() + + if len(psdf) < min_periods: + cov.fill(np.nan) + return DataFrame(cov, columns=psdf.columns, index=psdf.columns) + + data_cols = psdf._internal.data_spark_column_names + cov_scols = [] + count_not_null_scols = [] + + # Count number of null row between two columns + # Example: + # a b c + # 0 1 1 1 + # 1 NaN 2 2 + # 2 3 NaN 3 + # 3 4 4 4 + # + # a b c + # a count(a, a) count(a, b) count(a, c) + # b count(b, b) count(b, c) + # c count(c, c) + # + # count_not_null_scols = + # [F.count(a, a), F.count(a, b), F.count(a, c), F.count(b, b), F.count(b, c), F.count(c, c)] + for r in range(0, num_cols): + for c in range(r, num_cols): + count_not_null_scols.append( + F.count( + F.when(F.col(data_cols[r]).isNotNull() & F.col(data_cols[c]).isNotNull(), 1) + ) + ) + + count_not_null = ( + psdf._internal.spark_frame.replace(float("nan"), None) + .select(*count_not_null_scols) + .head(1)[0] + ) + + # Calculate covariance between two columns + # Example: + # with min_periods = 3 + # a b c + # 0 1 1 1 + # 1 NaN 2 2 + # 2 3 NaN 3 + # 3 4 4 4 + # + # a b c + # a cov(a, a) None cov(a, c) + # b cov(b, b) cov(b, c) + # c cov(c, c) + # + # cov_scols = [F.cov(a, a), None, F.cov(a, c), F.cov(b, b), F.cov(b, c), F.cov(c, c)] + step = 0 + for r 
in range(0, num_cols): + step += r + for c in range(r, num_cols): + cov_scols.append( + F.covar_samp( + F.col(data_cols[r]).cast("double"), F.col(data_cols[c]).cast("double") + ) + if count_not_null[r * num_cols + c - step] >= min_periods + else F.lit(None) + ) + + pair_cov = psdf._internal.spark_frame.select(*cov_scols).head(1)[0] + + # Convert from row to 2D array + # Example: + # pair_cov = [cov(a, a), None, cov(a, c), cov(b, b), cov(b, c), cov(c, c)] + # + # cov = + # + # a b c + # a cov(a, a) None cov(a, c) + # b cov(b, b) cov(b, c) + # c cov(c, c) + step = 0 + for r in range(0, num_cols): + step += r + for c in range(r, num_cols): + cov[r][c] = pair_cov[r * num_cols + c - step] + + # Copy values + # Example: + # cov = + # a b c + # a cov(a, a) None cov(a, c) + # b None cov(b, b) cov(b, c) + # c cov(a, c) cov(b, c) cov(c, c) + cov = cov + cov.T - np.diag(np.diag(cov)) + return DataFrame(cov, columns=psdf.columns, index=psdf.columns) + def sample( self, n: Optional[int] = None, diff --git a/python/pyspark/pandas/indexes/__init__.py b/python/pyspark/pandas/indexes/__init__.py index cd2adbaf9f..7fde6ffaf6 100644 --- a/python/pyspark/pandas/indexes/__init__.py +++ b/python/pyspark/pandas/indexes/__init__.py @@ -18,3 +18,4 @@ from pyspark.pandas.indexes.datetimes import DatetimeIndex # noqa: F401 from pyspark.pandas.indexes.multi import MultiIndex # noqa: F401 from pyspark.pandas.indexes.numeric import Float64Index, Int64Index # noqa: F401 +from pyspark.pandas.indexes.timedelta import TimedeltaIndex # noqa: F401 diff --git a/python/pyspark/pandas/indexes/base.py b/python/pyspark/pandas/indexes/base.py index ecad216530..ee65f623da 100644 --- a/python/pyspark/pandas/indexes/base.py +++ b/python/pyspark/pandas/indexes/base.py @@ -37,7 +37,13 @@ from pandas._libs import lib from pyspark.sql import functions as F, Column -from pyspark.sql.types import FractionalType, IntegralType, TimestampType, TimestampNTZType +from pyspark.sql.types import ( + DayTimeIntervalType, + FractionalType, + IntegralType, + TimestampType, + TimestampNTZType, +) from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm. from pyspark.pandas._typing import Dtype, Label, Name, Scalar @@ -178,6 +184,7 @@ def _new_instance(anchor: DataFrame) -> "Index": from pyspark.pandas.indexes.datetimes import DatetimeIndex from pyspark.pandas.indexes.multi import MultiIndex from pyspark.pandas.indexes.numeric import Float64Index, Int64Index + from pyspark.pandas.indexes.timedelta import TimedeltaIndex instance: Index if anchor._internal.index_level > 1: @@ -197,6 +204,11 @@ def _new_instance(anchor: DataFrame) -> "Index": (TimestampType, TimestampNTZType), ): instance = object.__new__(DatetimeIndex) + elif isinstance( + anchor._internal.spark_type_for(anchor._internal.index_spark_columns[0]), + DayTimeIntervalType, + ): + instance = object.__new__(TimedeltaIndex) else: instance = object.__new__(Index) diff --git a/python/pyspark/pandas/indexes/timedelta.py b/python/pyspark/pandas/indexes/timedelta.py new file mode 100644 index 0000000000..5f5e58e95a --- /dev/null +++ b/python/pyspark/pandas/indexes/timedelta.py @@ -0,0 +1,100 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import cast, no_type_check, Any +from functools import partial + +import pandas as pd +from pandas.api.types import is_hashable + +from pyspark import pandas as ps +from pyspark._globals import _NoValue +from pyspark.pandas.indexes.base import Index +from pyspark.pandas.missing.indexes import MissingPandasLikeTimedeltaIndex +from pyspark.pandas.series import Series + + +class TimedeltaIndex(Index): + """ + Immutable ndarray-like of timedelta64 data, represented internally as int64, and + which can be boxed to timedelta objects. + + Parameters + ---------- + data : array-like (1-dimensional), optional + Optional timedelta-like data to construct index with. + unit : unit of the arg (D,h,m,s,ms,us,ns) denote the unit, optional + Which is an integer/float number. + freq : str or pandas offset object, optional + One of pandas date offset strings or corresponding objects. The string + 'infer' can be passed in order to set the frequency of the index as the + inferred frequency upon creation. + copy : bool + Make a copy of input ndarray. + name : object + Name to be stored in the index. + + See Also + -------- + Index : The base pandas Index type. + + Examples + -------- + >>> from datetime import timedelta + >>> ps.TimedeltaIndex([timedelta(1), timedelta(microseconds=2)]) + TimedeltaIndex(['1 days 00:00:00', '0 days 00:00:00.000002'], dtype='timedelta64[ns]', freq=None) + """ + + @no_type_check + def __new__( + cls, + data=None, + unit=None, + freq=_NoValue, + closed=None, + dtype=None, + copy=False, + name=None, + ) -> "TimedeltaIndex": + if not is_hashable(name): + raise TypeError("Index.name must be a hashable type") + + if isinstance(data, (Series, Index)): + # TODO(SPARK-37512): Support TimedeltaIndex creation given a timedelta Series/Index + raise NotImplementedError("Create a TimedeltaIndex from Index/Series is not supported") + + kwargs = dict( + data=data, + unit=unit, + closed=closed, + dtype=dtype, + copy=copy, + name=name, + ) + if freq is not _NoValue: + kwargs["freq"] = freq + + return cast(TimedeltaIndex, ps.from_pandas(pd.TimedeltaIndex(**kwargs))) + + def __getattr__(self, item: str) -> Any: + if hasattr(MissingPandasLikeTimedeltaIndex, item): + property_or_func = getattr(MissingPandasLikeTimedeltaIndex, item) + if isinstance(property_or_func, property): + return property_or_func.fget(self) + else: + return partial(property_or_func, self) + + raise AttributeError("'TimedeltaIndex' object has no attribute '{}'".format(item)) diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py index 5cb21a7872..e5786f52fe 100644 --- a/python/pyspark/pandas/internal.py +++ b/python/pyspark/pandas/internal.py @@ -1524,6 +1524,17 @@ def prepare_pandas_frame( >>> data_fields [InternalField(dtype=datetime64[ns],struct_field=StructField(dt,TimestampNTZType,false)), InternalField(dtype=object,struct_field=StructField(dt_obj,TimestampNTZType,false))] + + >>> pdf = pd.DataFrame({ + ... "td": [datetime.timedelta(0)], "td_obj": [datetime.timedelta(0)] + ... }) + >>> pdf.td_obj = pdf.td_obj.astype("object") + >>> _, _, _, _, data_fields = ( + ... 
InternalFrame.prepare_pandas_frame(pdf) + ... ) + >>> data_fields # doctest: +NORMALIZE_WHITESPACE + [InternalField(dtype=timedelta64[ns],struct_field=StructField(td,DayTimeIntervalType(0,3),false)), + InternalField(dtype=object,struct_field=StructField(td_obj,DayTimeIntervalType(0,3),false))] """ pdf = pdf.copy() diff --git a/python/pyspark/pandas/missing/frame.py b/python/pyspark/pandas/missing/frame.py index aabc0e042e..d822c14192 100644 --- a/python/pyspark/pandas/missing/frame.py +++ b/python/pyspark/pandas/missing/frame.py @@ -39,7 +39,6 @@ class _MissingPandasLikeDataFrame(object): compare = _unsupported_function("compare") convert_dtypes = _unsupported_function("convert_dtypes") corrwith = _unsupported_function("corrwith") - cov = _unsupported_function("cov") ewm = _unsupported_function("ewm") infer_objects = _unsupported_function("infer_objects") interpolate = _unsupported_function("interpolate") diff --git a/python/pyspark/pandas/missing/indexes.py b/python/pyspark/pandas/missing/indexes.py index 4170aa70f7..f1ddccf563 100644 --- a/python/pyspark/pandas/missing/indexes.py +++ b/python/pyspark/pandas/missing/indexes.py @@ -100,6 +100,24 @@ class MissingPandasLikeDatetimeIndex(MissingPandasLikeIndex): std = _unsupported_function("std", cls="DatetimeIndex") +class MissingPandasLikeTimedeltaIndex(MissingPandasLikeIndex): + + # Properties + days = _unsupported_property("days", cls="TimedeltaIndex") + seconds = _unsupported_property("seconds", cls="TimedeltaIndex") + microseconds = _unsupported_property("microseconds", cls="TimedeltaIndex") + nanoseconds = _unsupported_property("nanoseconds", cls="TimedeltaIndex") + components = _unsupported_property("components", cls="TimedeltaIndex") + inferred_freq = _unsupported_property("inferred_freq", cls="TimedeltaIndex") + + # Functions + to_pytimedelta = _unsupported_function("to_pytimedelta", cls="TimedeltaIndex") + round = _unsupported_function("round", cls="TimedeltaIndex") + floor = _unsupported_function("floor", cls="TimedeltaIndex") + ceil = _unsupported_function("ceil", cls="TimedeltaIndex") + mean = _unsupported_function("mean", cls="TimedeltaIndex") + + class MissingPandasLikeMultiIndex(object): # Functions diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py index f6ec5e943a..4a459b6c33 100644 --- a/python/pyspark/pandas/series.py +++ b/python/pyspark/pandas/series.py @@ -5160,7 +5160,7 @@ def asof(self, where: Union[Any, List]) -> Union[Scalar, "Series"]: If there is no good value, NaN is returned. .. note:: This API is dependent on :meth:`Index.is_monotonic_increasing` - which can be expensive. + which is expensive. Parameters ---------- @@ -5179,7 +5179,9 @@ def asof(self, where: Union[Any, List]) -> Union[Scalar, "Series"]: Notes ----- - Indices are assumed to be sorted. Raises if this is not the case. + Indices are assumed to be sorted. Raises if this is not the case and config + 'compute.eager_check' is True. If 'compute.eager_check' is False pandas-on-Spark just + proceeds and performs by ignoring the indeces's order Examples -------- @@ -5210,13 +5212,19 @@ def asof(self, where: Union[Any, List]) -> Union[Scalar, "Series"]: >>> s.asof(30) 2.0 + + >>> s = ps.Series([1, 2, np.nan, 4], index=[10, 30, 20, 40]) + >>> with ps.option_context("compute.eager_check", False): + ... s.asof(20) + ... 
+ 1.0 """ should_return_series = True if isinstance(self.index, ps.MultiIndex): raise ValueError("asof is not supported for a MultiIndex") if isinstance(where, (ps.Index, ps.Series, DataFrame)): raise ValueError("where cannot be an Index, Series or a DataFrame") - if not self.index.is_monotonic_increasing: + if get_option("compute.eager_check") and not self.index.is_monotonic_increasing: raise ValueError("asof requires a sorted index") if not is_list_like(where): should_return_series = False diff --git a/python/pyspark/pandas/sql_formatter.py b/python/pyspark/pandas/sql_formatter.py new file mode 100644 index 0000000000..685ee25cc6 --- /dev/null +++ b/python/pyspark/pandas/sql_formatter.py @@ -0,0 +1,273 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import string +from typing import Any, Optional, Union, List, Sequence, Mapping, Tuple +import uuid +import warnings + +import pandas as pd + +from pyspark.pandas.internal import InternalFrame +from pyspark.pandas.namespace import _get_index_map +from pyspark.sql.functions import lit +from pyspark import pandas as ps +from pyspark.sql import SparkSession +from pyspark.pandas.utils import default_session +from pyspark.pandas.frame import DataFrame +from pyspark.pandas.series import Series + + +__all__ = ["sql"] + + +# This is not used in this file. It's for legacy sql_processor. +_CAPTURE_SCOPES = 3 + + +def sql( + query: str, + index_col: Optional[Union[str, List[str]]] = None, + **kwargs: Any, +) -> DataFrame: + """ + Execute a SQL query and return the result as a pandas-on-Spark DataFrame. + + This function acts as a standard Python string formatter with understanding + the following variable types: + + * pandas-on-Spark DataFrame + * pandas-on-Spark Series + * pandas DataFrame + * pandas Series + * string + + Parameters + ---------- + query : str + the SQL query + index_col : str or list of str, optional + Column names to be used in Spark to represent pandas-on-Spark's index. The index name + in pandas-on-Spark is ignored. By default, the index is always lost. + + .. note:: If you want to preserve the index, explicitly use :func:`DataFrame.reset_index`, + and pass it to the sql statement with `index_col` parameter. + + For example, + + >>> psdf = ps.DataFrame({"A": [1, 2, 3], "B":[4, 5, 6]}, index=['a', 'b', 'c']) + >>> new_psdf = psdf.reset_index() + >>> ps.sql("SELECT * FROM {new_psdf}", index_col="index", new_psdf=new_psdf) + ... # doctest: +NORMALIZE_WHITESPACE + A B + index + a 1 4 + b 2 5 + c 3 6 + + For MultiIndex, + + >>> psdf = ps.DataFrame( + ... {"A": [1, 2, 3], "B": [4, 5, 6]}, + ... index=pd.MultiIndex.from_tuples( + ... [("a", "b"), ("c", "d"), ("e", "f")], names=["index1", "index2"] + ... ), + ... 
) + >>> new_psdf = psdf.reset_index() + >>> ps.sql("SELECT * FROM {new_psdf}", index_col=["index1", "index2"], new_psdf=new_psdf) + ... # doctest: +NORMALIZE_WHITESPACE + A B + index1 index2 + a b 1 4 + c d 2 5 + e f 3 6 + + Also note that the index name(s) should be matched to the existing name. + kwargs + other variables that the user want to set that can be referenced in the query + + Returns + ------- + pandas-on-Spark DataFrame + + Examples + -------- + + Calling a built-in SQL function. + + >>> ps.sql("SELECT * FROM range(10) where id > 7") + id + 0 8 + 1 9 + + >>> ps.sql("SELECT * FROM range(10) WHERE id > {bound1} AND id < {bound2}", bound1=7, bound2=9) + id + 0 8 + + >>> mydf = ps.range(10) + >>> x = tuple(range(4)) + >>> ps.sql("SELECT {ser} FROM {mydf} WHERE id IN {x}", ser=mydf.id, mydf=mydf, x=x) + id + 0 0 + 1 1 + 2 2 + 3 3 + + Mixing pandas-on-Spark and pandas DataFrames in a join operation. Note that the index is + dropped. + + >>> ps.sql(''' + ... SELECT m1.a, m2.b + ... FROM {table1} m1 INNER JOIN {table2} m2 + ... ON m1.key = m2.key + ... ORDER BY m1.a, m2.b''', + ... table1=ps.DataFrame({"a": [1,2], "key": ["a", "b"]}), + ... table2=pd.DataFrame({"b": [3,4,5], "key": ["a", "b", "b"]})) + a b + 0 1 3 + 1 2 4 + 2 2 5 + + Also, it is possible to query using Series. + + >>> psdf = ps.DataFrame({"A": [1, 2, 3], "B":[4, 5, 6]}, index=['a', 'b', 'c']) + >>> ps.sql("SELECT {mydf.A} FROM {mydf}", mydf=psdf) + A + 0 1 + 1 2 + 2 3 + """ + if os.environ.get("PYSPARK_PANDAS_SQL_LEGACY") == "1": + from pyspark.pandas import sql_processor + + warnings.warn( + "Deprecated in 3.3.0, and the legacy behavior " + "will be removed in the future releases.", + FutureWarning, + ) + return sql_processor.sql(query, index_col=index_col, **kwargs) + + session = default_session() + formatter = SQLStringFormatter(session) + try: + sdf = session.sql(formatter.format(query, **kwargs)) + finally: + formatter.clear() + + index_spark_columns, index_names = _get_index_map(sdf, index_col) + + return DataFrame( + InternalFrame( + spark_frame=sdf, index_spark_columns=index_spark_columns, index_names=index_names + ) + ) + + +class SQLStringFormatter(string.Formatter): + """ + A standard ``string.Formatter`` in Python that can understand pandas-on-Spark instances + with basic Python objects. This object has to be clear after the use for single SQL + query; cannot be reused across multiple SQL queries without cleaning. + """ + + def __init__(self, session: SparkSession) -> None: + self._session: SparkSession = session + self._temp_views: List[Tuple[DataFrame, str]] = [] + self._ref_sers: List[Tuple[Series, str]] = [] + + def vformat(self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> str: + ret = super(SQLStringFormatter, self).vformat(format_string, args, kwargs) + + for ref, n in self._ref_sers: + if not any((ref is v for v in df._pssers.values()) for df, _ in self._temp_views): + # If referred DataFrame does not hold the given Series, raise an error. + raise ValueError("The series in {%s} does not refer any dataframe specified." % n) + return ret + + def get_field(self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> Any: + obj, first = super(SQLStringFormatter, self).get_field(field_name, args, kwargs) + return self._convert_value(obj, field_name), first + + def _convert_value(self, val: Any, name: str) -> Optional[str]: + """ + Converts the given value into a SQL string. 
+ """ + if isinstance(val, pd.Series): + # Return the column name from pandas Series directly. + return ps.from_pandas(val).to_frame()._to_spark().columns[0] + elif isinstance(val, Series): + # Return the column name of pandas-on-Spark Series iff its DataFrame was + # referred. The check will be done in `vformat` after we parse all. + self._ref_sers.append((val, name)) + return val.to_frame()._to_spark().columns[0] + elif isinstance(val, (DataFrame, pd.DataFrame)): + df_name = "_pandas_api_%s" % str(uuid.uuid4()).replace("-", "") + + if isinstance(val, pd.DataFrame): + # Don't store temp view for plain pandas instances + # because it is unable to know which pandas DataFrame + # holds which Series. + val = ps.from_pandas(val) + else: + for df, n in self._temp_views: + if df is val: + return n + self._temp_views.append((val, df_name)) + + val._to_spark().createOrReplaceTempView(df_name) + return df_name + elif isinstance(val, str): + return lit(val)._jc.expr().sql() # for escaped characters. + else: + return val + + def clear(self) -> None: + for _, n in self._temp_views: + self._session.catalog.dropTempView(n) + self._temp_views = [] + self._ref_sers = [] + + +def _test() -> None: + import os + import doctest + import sys + from pyspark.sql import SparkSession + import pyspark.pandas.sql_formatter + + os.chdir(os.environ["SPARK_HOME"]) + + globs = pyspark.pandas.sql_formatter.__dict__.copy() + globs["ps"] = pyspark.pandas + spark = ( + SparkSession.builder.master("local[4]") + .appName("pyspark.pandas.sql_processor tests") + .getOrCreate() + ) + (failure_count, test_count) = doctest.testmod( + pyspark.pandas.sql_formatter, + globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE, + ) + spark.stop() + if failure_count: + sys.exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/pandas/sql_processor.py b/python/pyspark/pandas/sql_processor.py index afdaa101d6..8126d1e10a 100644 --- a/python/pyspark/pandas/sql_processor.py +++ b/python/pyspark/pandas/sql_processor.py @@ -77,9 +77,13 @@ def sql( For example, + >>> from pyspark.pandas import sql_processor + >>> # we will call 'sql_processor' directly in doctests so decrease one level. + >>> sql_processor._CAPTURE_SCOPES = 2 + >>> sql = sql_processor.sql >>> psdf = ps.DataFrame({"A": [1, 2, 3], "B":[4, 5, 6]}, index=['a', 'b', 'c']) >>> psdf_reset_index = psdf.reset_index() - >>> ps.sql("SELECT * FROM {psdf_reset_index}", index_col="index") + >>> sql("SELECT * FROM {psdf_reset_index}", index_col="index") ... # doctest: +NORMALIZE_WHITESPACE A B index @@ -96,7 +100,7 @@ def sql( ... ), ... ) >>> psdf_reset_index = psdf.reset_index() - >>> ps.sql("SELECT * FROM {psdf_reset_index}", index_col=["index1", "index2"]) + >>> sql("SELECT * FROM {psdf_reset_index}", index_col=["index1", "index2"]) ... # doctest: +NORMALIZE_WHITESPACE A B index1 index2 @@ -122,7 +126,7 @@ def sql( Calling a built-in SQL function. 
- >>> ps.sql("select * from range(10) where id > 7") + >>> sql("select * from range(10) where id > 7") id 0 8 1 9 @@ -130,7 +134,7 @@ def sql( A query can also reference a local variable or parameter by wrapping them in curly braces: >>> bound1 = 7 - >>> ps.sql("select * from range(10) where id > {bound1} and id < {bound2}", bound2=9) + >>> sql("select * from range(10) where id > {bound1} and id < {bound2}", bound2=9) id 0 8 @@ -139,7 +143,7 @@ def sql( >>> mydf = ps.range(10) >>> x = range(4) - >>> ps.sql("SELECT * from {mydf} WHERE id IN {x}") + >>> sql("SELECT * from {mydf} WHERE id IN {x}") id 0 0 1 1 @@ -150,7 +154,7 @@ def sql( >>> def statement(): ... mydf2 = ps.DataFrame({"x": range(2)}) - ... return ps.sql("SELECT * from {mydf2}") + ... return sql("SELECT * from {mydf2}") >>> statement() x 0 0 @@ -159,7 +163,7 @@ def sql( Mixing pandas-on-Spark and pandas DataFrames in a join operation. Note that the index is dropped. - >>> ps.sql(''' + >>> sql(''' ... SELECT m1.a, m2.b ... FROM {table1} m1 INNER JOIN {table2} m2 ... ON m1.key = m2.key @@ -174,7 +178,7 @@ def sql( Also, it is possible to query using Series. >>> myser = ps.Series({'a': [1.0, 2.0, 3.0], 'b': [15.0, 30.0, 45.0]}) - >>> ps.sql("SELECT * from {myser}") + >>> sql("SELECT * from {myser}") 0 0 [1.0, 2.0, 3.0] 1 [15.0, 30.0, 45.0] @@ -195,7 +199,7 @@ def sql( return SQLProcessor(_dict, query, default_session()).execute(index_col) -_CAPTURE_SCOPES = 2 +_CAPTURE_SCOPES = 3 def _get_local_scope() -> Dict[str, Any]: @@ -272,19 +276,23 @@ def execute(self, index_col: Optional[Union[str, List[str]]]) -> DataFrame: Returns a DataFrame for which the SQL statement has been executed by the underlying SQL engine. + >>> from pyspark.pandas import sql_processor + >>> # we will call 'sql_processor' directly in doctests so decrease one level. 
+ >>> sql_processor._CAPTURE_SCOPES = 2 + >>> sql = sql_processor.sql >>> str0 = 'abc' - >>> ps.sql("select {str0}") + >>> sql("select {str0}") abc 0 abc >>> str1 = 'abc"abc' >>> str2 = "abc'abc" - >>> ps.sql("select {str0}, {str1}, {str2}") + >>> sql("select {str0}, {str1}, {str2}") abc abc"abc abc'abc 0 abc abc"abc abc'abc >>> strs = ['a', 'b'] - >>> ps.sql("select 'a' in {strs} as cond1, 'c' in {strs} as cond2") + >>> sql("select 'a' in {strs} as cond1, 'c' in {strs} as cond2") cond1 cond2 0 True False """ diff --git a/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py b/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py index f4b36f969a..77fc93c0eb 100644 --- a/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py +++ b/python/pyspark/pandas/tests/data_type_ops/test_num_ops.py @@ -369,6 +369,23 @@ def test_astype(self): psser = ps.from_pandas(pser) self.assert_eq(pser.astype(pd.BooleanDtype()), psser.astype(pd.BooleanDtype())) + def test_astype_eager_check(self): + psser = self.psdf["float_nan"] + with ps.option_context("compute.eager_check", True), self.assertRaisesRegex( + ValueError, "Cannot convert" + ): + psser.astype(int) + with ps.option_context("compute.eager_check", False): + psser.astype(int) + + psser = self.psdf["decimal_nan"] + with ps.option_context("compute.eager_check", True), self.assertRaisesRegex( + ValueError, "Cannot convert" + ): + psser.astype(int) + with ps.option_context("compute.eager_check", False): + psser.astype(int) + def test_neg(self): pdf, psdf = self.pdf, self.psdf for col in self.numeric_df_cols: @@ -475,21 +492,26 @@ def test_astype(self): for pser, psser in self.intergral_extension_pser_psser_pairs: self.assert_eq(pser.astype(float), psser.astype(float)) self.assert_eq(pser.astype(np.float32), psser.astype(np.float32)) - self.assertRaisesRegex( - ValueError, - "Cannot convert integrals with missing values to bool", - lambda: psser.astype(bool), - ) - self.assertRaisesRegex( - ValueError, - "Cannot convert integrals with missing values to integer", - lambda: psser.astype(int), - ) - self.assertRaisesRegex( - ValueError, - "Cannot convert integrals with missing values to integer", - lambda: psser.astype(np.int32), - ) + with ps.option_context("compute.eager_check", True): + self.assertRaisesRegex( + ValueError, + "Cannot convert integrals with missing values to bool", + lambda: psser.astype(bool), + ) + self.assertRaisesRegex( + ValueError, + "Cannot convert integrals with missing values to integer", + lambda: psser.astype(int), + ) + self.assertRaisesRegex( + ValueError, + "Cannot convert integrals with missing values to integer", + lambda: psser.astype(np.int32), + ) + with ps.option_context("compute.eager_check", False): + psser.astype(bool) + psser.astype(int) + psser.astype(np.int32) def test_neg(self): for pser, psser in self.intergral_extension_pser_psser_pairs: @@ -607,21 +629,26 @@ def test_astype(self): for pser, psser in self.fractional_extension_pser_psser_pairs: self.assert_eq(pser.astype(float), psser.astype(float)) self.assert_eq(pser.astype(np.float32), psser.astype(np.float32)) - self.assertRaisesRegex( - ValueError, - "Cannot convert fractions with missing values to bool", - lambda: psser.astype(bool), - ) - self.assertRaisesRegex( - ValueError, - "Cannot convert fractions with missing values to integer", - lambda: psser.astype(int), - ) - self.assertRaisesRegex( - ValueError, - "Cannot convert fractions with missing values to integer", - lambda: psser.astype(np.int32), - ) + with 
ps.option_context("compute.eager_check", True): + self.assertRaisesRegex( + ValueError, + "Cannot convert fractions with missing values to bool", + lambda: psser.astype(bool), + ) + self.assertRaisesRegex( + ValueError, + "Cannot convert fractions with missing values to integer", + lambda: psser.astype(int), + ) + self.assertRaisesRegex( + ValueError, + "Cannot convert fractions with missing values to integer", + lambda: psser.astype(np.int32), + ) + with ps.option_context("compute.eager_check", False): + psser.astype(bool) + psser.astype(int) + psser.astype(np.int32) def test_neg(self): # pandas raises "TypeError: bad operand type for unary -: 'FloatingArray'" diff --git a/python/pyspark/pandas/tests/indexes/test_base.py b/python/pyspark/pandas/tests/indexes/test_base.py index e7e5216d3b..173a2bf8b0 100644 --- a/python/pyspark/pandas/tests/indexes/test_base.py +++ b/python/pyspark/pandas/tests/indexes/test_base.py @@ -18,7 +18,7 @@ import inspect import unittest from distutils.version import LooseVersion -from datetime import datetime +from datetime import datetime, timedelta import numpy as np import pandas as pd @@ -29,6 +29,7 @@ MissingPandasLikeDatetimeIndex, MissingPandasLikeIndex, MissingPandasLikeMultiIndex, + MissingPandasLikeTimedeltaIndex, ) from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils, SPARK_CONF_ARROW_ENABLED @@ -456,6 +457,7 @@ def test_missing(self): "b": [4, 5, 6], "c": pd.date_range("2011-01-01", freq="D", periods=3), "d": pd.Categorical(["a", "b", "c"]), + "e": [timedelta(1), timedelta(2), timedelta(3)], } ) @@ -522,6 +524,27 @@ def test_missing(self): ): getattr(psdf.set_index("c").index, name)() + # TimedeltaIndex functions + missing_functions = inspect.getmembers(MissingPandasLikeTimedeltaIndex, inspect.isfunction) + unsupported_functions = [ + name for (name, type_) in missing_functions if type_.__name__ == "unsupported_function" + ] + for name in unsupported_functions: + with self.assertRaisesRegex( + PandasNotImplementedError, + "method.*Index.*{}.*not implemented( yet\\.|\\. .+)".format(name), + ): + getattr(psdf.set_index("e").index, name)() + + deprecated_functions = [ + name for (name, type_) in missing_functions if type_.__name__ == "deprecated_function" + ] + for name in deprecated_functions: + with self.assertRaisesRegex( + PandasNotImplementedError, "method.*Index.*{}.*is deprecated".format(name) + ): + getattr(psdf.set_index("e").index, name)() + # Index properties missing_properties = inspect.getmembers( MissingPandasLikeIndex, lambda o: isinstance(o, property) @@ -592,6 +615,22 @@ def test_missing(self): ): getattr(psdf.set_index("c").index, name) + # TimedeltaIndex properties + missing_properties = inspect.getmembers( + MissingPandasLikeDatetimeIndex, lambda o: isinstance(o, property) + ) + unsupported_properties = [ + name + for (name, type_) in missing_properties + if type_.fget.__name__ == "unsupported_property" + ] + for name in unsupported_properties: + with self.assertRaisesRegex( + PandasNotImplementedError, + "property.*Index.*{}.*not implemented( yet\\.|\\. 
.+)".format(name), + ): + getattr(psdf.set_index("c").index, name) + def test_index_has_duplicates(self): indexes = [("a", "b", "c"), ("a", "a", "c"), (1, 3, 3), (1, 2, 3)] names = [None, "ks", "ks", None] diff --git a/python/pyspark/pandas/tests/test_dataframe.py b/python/pyspark/pandas/tests/test_dataframe.py index 701052ed2c..ae8fcaef89 100644 --- a/python/pyspark/pandas/tests/test_dataframe.py +++ b/python/pyspark/pandas/tests/test_dataframe.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # - +import decimal from datetime import datetime from distutils.version import LooseVersion import inspect @@ -6025,6 +6025,69 @@ def test_multi_index_dtypes(self): ) self.assert_eq(psmidx.dtypes, expected) + def test_cov(self): + # SPARK-36396: Implement DataFrame.cov + + # int + pdf = pd.DataFrame([(1, 2), (0, 3), (2, 0), (1, 1)], columns=["a", "b"]) + psdf = ps.from_pandas(pdf) + self.assert_eq(pdf.cov(), psdf.cov(), almost=True) + self.assert_eq(pdf.cov(min_periods=4), psdf.cov(min_periods=4), almost=True) + self.assert_eq(pdf.cov(min_periods=5), psdf.cov(min_periods=5)) + + # bool + pdf = pd.DataFrame( + { + "a": [1, np.nan, 3, 4], + "b": [True, False, False, True], + "c": [True, True, False, True], + } + ) + psdf = ps.from_pandas(pdf) + self.assert_eq(pdf.cov(), psdf.cov(), almost=True) + self.assert_eq(pdf.cov(min_periods=4), psdf.cov(min_periods=4), almost=True) + self.assert_eq(pdf.cov(min_periods=5), psdf.cov(min_periods=5)) + + # extension dtype + numeric_dtypes = ["Int8", "Int16", "Int32", "Int64", "Float32", "Float64", "float"] + boolean_dtypes = ["boolean", "bool"] + + sers = [pd.Series([1, 2, 3, None], dtype=dtype) for dtype in numeric_dtypes] + sers += [pd.Series([True, False, True, None], dtype=dtype) for dtype in boolean_dtypes] + sers.append(pd.Series([decimal.Decimal(1), decimal.Decimal(2), decimal.Decimal(3), None])) + + pdf = pd.concat(sers, axis=1) + pdf.columns = [dtype for dtype in numeric_dtypes + boolean_dtypes] + ["decimal"] + psdf = ps.from_pandas(pdf) + + self.assert_eq(pdf.cov(), psdf.cov(), almost=True) + self.assert_eq(pdf.cov(min_periods=3), psdf.cov(min_periods=3), almost=True) + self.assert_eq(pdf.cov(min_periods=4), psdf.cov(min_periods=4)) + + # string column + pdf = pd.DataFrame( + [(1, 2, "a", 1), (0, 3, "b", 1), (2, 0, "c", 9), (1, 1, "d", 1)], + columns=["a", "b", "c", "d"], + ) + psdf = ps.from_pandas(pdf) + self.assert_eq(pdf.cov(), psdf.cov(), almost=True) + self.assert_eq(pdf.cov(min_periods=4), psdf.cov(min_periods=4), almost=True) + self.assert_eq(pdf.cov(min_periods=5), psdf.cov(min_periods=5)) + + # nan + np.random.seed(42) + pdf = pd.DataFrame(np.random.randn(20, 3), columns=["a", "b", "c"]) + pdf.loc[pdf.index[:5], "a"] = np.nan + pdf.loc[pdf.index[5:10], "b"] = np.nan + psdf = ps.from_pandas(pdf) + self.assert_eq(pdf.cov(min_periods=11), psdf.cov(min_periods=11), almost=True) + self.assert_eq(pdf.cov(min_periods=10), psdf.cov(min_periods=10), almost=True) + + # return empty DataFrame + pdf = pd.DataFrame([("1", "2"), ("0", "3"), ("2", "0"), ("1", "1")], columns=["a", "b"]) + psdf = ps.from_pandas(pdf) + self.assert_eq(pdf.cov(), psdf.cov()) + if __name__ == "__main__": from pyspark.pandas.tests.test_dataframe import * # noqa: F401 diff --git a/python/pyspark/pandas/tests/test_series.py b/python/pyspark/pandas/tests/test_series.py index 72677d18e4..51c26ad830 100644 --- a/python/pyspark/pandas/tests/test_series.py +++ b/python/pyspark/pandas/tests/test_series.py @@ -2115,6 
+2115,19 @@ def test_asof(self): self.assert_eq(psser.asof("2014-01-02"), pser.asof("2014-01-02")) self.assert_eq(repr(psser.asof("1999-01-02")), repr(pser.asof("1999-01-02"))) + # SPARK-37482: Skip check monotonic increasing for Series.asof with 'compute.eager_check' + pser = pd.Series([1, 2, np.nan, 4], index=[10, 30, 20, 40]) + psser = ps.from_pandas(pser) + + with ps.option_context("compute.eager_check", False): + self.assert_eq(psser.asof(20), 1.0) + + pser = pd.Series([1, 2, np.nan, 4], index=[40, 30, 20, 10]) + psser = ps.from_pandas(pser) + + with ps.option_context("compute.eager_check", False): + self.assert_eq(psser.asof(20), 4.0) + def test_squeeze(self): # Single value pser = pd.Series([90]) @@ -2232,7 +2245,9 @@ def test_mad(self): pser.index = pmidx psser = ps.from_pandas(pser) - self.assert_eq(pser.mad(), psser.mad()) + # Mark almost as True to avoid precision issue like: + # "21.555555555555554 != 21.555555555555557" + self.assert_eq(pser.mad(), psser.mad(), almost=True) def test_to_frame(self): pser = pd.Series(["a", "b", "c"]) diff --git a/python/pyspark/pandas/tests/test_sql.py b/python/pyspark/pandas/tests/test_sql.py index 306ea166cf..ca0dd99a32 100644 --- a/python/pyspark/pandas/tests/test_sql.py +++ b/python/pyspark/pandas/tests/test_sql.py @@ -23,20 +23,22 @@ class SQLTest(PandasOnSparkTestCase, SQLTestUtils): def test_error_variable_not_exist(self): - msg = "The key variable_foo in the SQL statement was not found.*" - with self.assertRaisesRegex(ValueError, msg): + with self.assertRaisesRegex(KeyError, "variable_foo"): ps.sql("select * from {variable_foo}") def test_error_unsupported_type(self): - msg = "Unsupported variable type dict: {'a': 1}" - with self.assertRaisesRegex(ValueError, msg): - some_dict = {"a": 1} + with self.assertRaisesRegex(KeyError, "some_dict"): ps.sql("select * from {some_dict}") def test_error_bad_sql(self): with self.assertRaises(ParseException): ps.sql("this is not valid sql") + def test_series_not_referred(self): + psdf = ps.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + with self.assertRaisesRegex(ValueError, "The series in {ser}"): + ps.sql("SELECT {ser} FROM range(10)", ser=psdf.A) + def test_sql_with_index_col(self): import pandas as pd @@ -45,7 +47,11 @@ def test_sql_with_index_col(self): {"A": [1, 2, 3], "B": [4, 5, 6]}, index=pd.Index(["a", "b", "c"], name="index") ) psdf_reset_index = psdf.reset_index() - actual = ps.sql("select * from {psdf_reset_index} where A > 1", index_col="index") + actual = ps.sql( + "select * from {psdf_reset_index} where A > 1", + index_col="index", + psdf_reset_index=psdf_reset_index, + ) expected = psdf.iloc[[1, 2]] self.assert_eq(actual, expected) @@ -58,11 +64,40 @@ def test_sql_with_index_col(self): ) psdf_reset_index = psdf.reset_index() actual = ps.sql( - "select * from {psdf_reset_index} where A > 1", index_col=["index1", "index2"] + "select * from {psdf_reset_index} where A > 1", + index_col=["index1", "index2"], + psdf_reset_index=psdf_reset_index, ) expected = psdf.iloc[[1, 2]] self.assert_eq(actual, expected) + def test_sql_with_pandas_objects(self): + import pandas as pd + + pdf = pd.DataFrame({"a": [1, 2, 3, 4]}) + self.assert_eq(ps.sql("SELECT {col} + 1 as a FROM {tbl}", col=pdf.a, tbl=pdf), pdf + 1) + + def test_sql_with_python_objects(self): + self.assert_eq( + ps.sql("SELECT {col} as a FROM range(1)", col="lit"), ps.DataFrame({"a": ["lit"]}) + ) + self.assert_eq( + ps.sql("SELECT id FROM range(10) WHERE id IN {pred}", col="lit", pred=(1, 2, 3)), + ps.DataFrame({"id": [1, 2, 3]}), + ) + + 
def test_sql_with_pandas_on_spark_objects(self): + psdf = ps.DataFrame({"a": [1, 2, 3, 4]}) + + self.assert_eq(ps.sql("SELECT {col} FROM {tbl}", col=psdf.a, tbl=psdf), psdf) + self.assert_eq(ps.sql("SELECT {tbl.a} FROM {tbl}", tbl=psdf), psdf) + + psdf = ps.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + self.assert_eq( + ps.sql("SELECT {col}, {col2} FROM {tbl}", col=psdf.A, col2=psdf.B, tbl=psdf), psdf + ) + self.assert_eq(ps.sql("SELECT {tbl.A}, {tbl.B} FROM {tbl}", tbl=psdf), psdf) + if __name__ == "__main__": import unittest diff --git a/python/pyspark/pandas/typedef/typehints.py b/python/pyspark/pandas/typedef/typehints.py index 14a1056729..6620ffcfe5 100644 --- a/python/pyspark/pandas/typedef/typehints.py +++ b/python/pyspark/pandas/typedef/typehints.py @@ -217,6 +217,10 @@ def as_spark_type( elif tpe in (datetime.datetime, np.datetime64, "datetime64[ns]", "M"): return types.TimestampNTZType() if prefer_timestamp_ntz else types.TimestampType() + # DayTimeIntervalType + elif tpe in (datetime.timedelta, np.timedelta64, "timedelta64[ns]"): + return types.DayTimeIntervalType() + # categorical types elif isinstance(tpe, CategoricalDtype) or (isinstance(tpe, str) and type == "category"): return types.LongType() @@ -330,6 +334,8 @@ def pandas_on_spark_type(tpe: Union[str, type, Dtype]) -> Tuple[Dtype, types.Dat (dtype('O'), DateType) >>> pandas_on_spark_type(datetime.datetime) (dtype('>> pandas_on_spark_type(datetime.timedelta) + (dtype('>> pandas_on_spark_type(List[bool]) (dtype('O'), ArrayType(BooleanType,true)) """ diff --git a/python/pyspark/pandas/usage_logging/__init__.py b/python/pyspark/pandas/usage_logging/__init__.py index ebd23ac637..b350faf6b9 100644 --- a/python/pyspark/pandas/usage_logging/__init__.py +++ b/python/pyspark/pandas/usage_logging/__init__.py @@ -25,7 +25,7 @@ import pandas as pd -from pyspark.pandas import config, namespace, sql_processor +from pyspark.pandas import config, namespace, sql_formatter from pyspark.pandas.accessors import PandasOnSparkFrameMethods from pyspark.pandas.frame import DataFrame from pyspark.pandas.datetimes import DatetimeMethods @@ -113,8 +113,8 @@ def attach(logger_module: Union[str, ModuleType]) -> None: except ImportError: pass - sql_processor._CAPTURE_SCOPES = 3 - modules.append(sql_processor) + sql_formatter._CAPTURE_SCOPES = 4 + modules.append(sql_formatter) # Modules for target_module in modules: diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py index 778b23fe2b..3d3a65a46c 100644 --- a/python/pyspark/profiler.py +++ b/python/pyspark/profiler.py @@ -27,12 +27,13 @@ class ProfilerCollector(object): """ This class keeps track of different profilers on a per - stage basis. Also this is used to create new profilers for - the different stages. + stage/UDF basis. Also this is used to create new profilers for + the different stages/UDFs. 
""" - def __init__(self, profiler_cls, dump_path=None): + def __init__(self, profiler_cls, udf_profiler_cls, dump_path=None): self.profiler_cls = profiler_cls + self.udf_profiler_cls = udf_profiler_cls self.profile_dump_path = dump_path self.profilers = [] @@ -40,8 +41,12 @@ def new_profiler(self, ctx): """Create a new profiler using class `profiler_cls`""" return self.profiler_cls(ctx) + def new_udf_profiler(self, ctx): + """Create a new profiler using class `udf_profiler_cls`""" + return self.udf_profiler_cls(ctx) + def add_profiler(self, id, profiler): - """Add a profiler for RDD `id`""" + """Add a profiler for RDD/UDF `id`""" if not self.profilers: if self.profile_dump_path: atexit.register(self.dump_profiles, self.profile_dump_path) @@ -106,7 +111,7 @@ class Profiler(object): def __init__(self, ctx): pass - def profile(self, func): + def profile(self, func, *args, **kwargs): """Do profiling on the function `func`""" raise NotImplementedError @@ -160,10 +165,10 @@ def __init__(self, ctx): # partitions of a stage self._accumulator = ctx.accumulator(None, PStatsParam) - def profile(self, func): + def profile(self, func, *args, **kwargs): """Runs and profiles the method to_profile passed in. A profile object is returned.""" pr = cProfile.Profile() - pr.runcall(func) + ret = pr.runcall(func, *args, **kwargs) st = pstats.Stats(pr) st.stream = None # make it picklable st.strip_dirs() @@ -171,10 +176,36 @@ def profile(self, func): # Adds a new profile to the existing accumulated value self._accumulator.add(st) + return ret + def stats(self): return self._accumulator.value +class UDFBasicProfiler(BasicProfiler): + """ + UDFBasicProfiler is the profiler for Python/Pandas UDFs. + """ + + def show(self, id): + """Print the profile stats to stdout, id is the PythonUDF id""" + stats = self.stats() + if stats: + print("=" * 60) + print("Profile of UDF" % id) + print("=" * 60) + stats.sort_stats("time", "cumulative").print_stats() + + def dump(self, id, path): + """Dump the profile into path, id is the PythonUDF id""" + if not os.path.exists(path): + os.makedirs(path) + stats = self.stats() + if stats: + p = os.path.join(path, "udf_%d.pstats" % id) + stats.dump_stats(p) + + if __name__ == "__main__": import doctest diff --git a/python/pyspark/profiler.pyi b/python/pyspark/profiler.pyi index d6a216b7f2..85aa6a2480 100644 --- a/python/pyspark/profiler.pyi +++ b/python/pyspark/profiler.pyi @@ -25,17 +25,24 @@ from pyspark.context import SparkContext class ProfilerCollector: profiler_cls: Type[Profiler] + udf_profiler_cls: Type[Profiler] profile_dump_path: Optional[str] profilers: List[Tuple[int, Profiler, bool]] - def __init__(self, profiler_cls: Type[Profiler], dump_path: Optional[str] = ...) -> None: ... + def __init__( + self, + profiler_cls: Type[Profiler], + udf_profiler_cls: Type[Profiler], + dump_path: Optional[str] = ..., + ) -> None: ... def new_profiler(self, ctx: SparkContext) -> Profiler: ... + def new_udf_profiler(self, ctx: SparkContext) -> Profiler: ... def add_profiler(self, id: int, profiler: Profiler) -> None: ... def dump_profiles(self, path: str) -> None: ... def show_profiles(self) -> None: ... class Profiler: def __init__(self, ctx: SparkContext) -> None: ... - def profile(self, func: Callable[[], Any]) -> None: ... + def profile(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: ... def stats(self) -> pstats.Stats: ... def show(self, id: int) -> None: ... def dump(self, id: int, path: str) -> None: ... 
@@ -50,5 +57,9 @@ class PStatsParam(AccumulatorParam): class BasicProfiler(Profiler): def __init__(self, ctx: SparkContext) -> None: ... - def profile(self, func: Callable[[], Any]) -> None: ... + def profile(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: ... def stats(self) -> pstats.Stats: ... + +class UDFBasicProfiler(BasicProfiler): + def show(self, id: int) -> None: ... + def dump(self, id: int, path: str) -> None: ... diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index b997932c80..2452d69237 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -39,7 +39,7 @@ CartesianDeserializer, CloudPickleSerializer, PairDeserializer, - PickleSerializer, + CPickleSerializer, pack_long, read_int, write_int, @@ -259,7 +259,7 @@ class RDD(object): operated on in parallel. """ - def __init__(self, jrdd, ctx, jrdd_deserializer=AutoBatchedSerializer(PickleSerializer())): + def __init__(self, jrdd, ctx, jrdd_deserializer=AutoBatchedSerializer(CPickleSerializer())): self._jrdd = jrdd self.is_cached = False self.is_checkpointed = False @@ -270,7 +270,7 @@ def __init__(self, jrdd, ctx, jrdd_deserializer=AutoBatchedSerializer(PickleSeri self.partitioner = None def _pickled(self): - return self._reserialize(AutoBatchedSerializer(PickleSerializer())) + return self._reserialize(AutoBatchedSerializer(CPickleSerializer())) def id(self): """ @@ -1841,7 +1841,7 @@ def saveAsSequenceFile(self, path, compressionCodecClass=None): def saveAsPickleFile(self, path, batchSize=10): """ Save this RDD as a SequenceFile of serialized objects. The serializer - used is :class:`pyspark.serializers.PickleSerializer`, default batch size + used is :class:`pyspark.serializers.CPickleSerializer`, default batch size is 10. Examples @@ -1854,9 +1854,9 @@ def saveAsPickleFile(self, path, batchSize=10): ['1', '2', 'rdd', 'spark'] """ if batchSize == 0: - ser = AutoBatchedSerializer(PickleSerializer()) + ser = AutoBatchedSerializer(CPickleSerializer()) else: - ser = BatchedSerializer(PickleSerializer(), batchSize) + ser = BatchedSerializer(CPickleSerializer(), batchSize) self._reserialize(ser)._jrdd.saveAsObjectFile(path) def saveAsTextFile(self, path, compressionCodecClass=None): @@ -2520,7 +2520,7 @@ def coalesce(self, numPartitions, shuffle=False): # Decrease the batch size in order to distribute evenly the elements across output # partitions. Otherwise, repartition will possibly produce highly skewed partitions. batchSize = min(10, self.ctx._batchSize or 1024) - ser = BatchedSerializer(PickleSerializer(), batchSize) + ser = BatchedSerializer(CPickleSerializer(), batchSize) selfCopy = self._reserialize(ser) jrdd_deserializer = selfCopy._jrdd_deserializer jrdd = selfCopy._jrdd.coalesce(numPartitions, shuffle) @@ -2551,7 +2551,7 @@ def get_batch_size(ser): return 1 # not batched def batch_as(rdd, batchSize): - return rdd._reserialize(BatchedSerializer(PickleSerializer(), batchSize)) + return rdd._reserialize(BatchedSerializer(CPickleSerializer(), batchSize)) my_batch = get_batch_size(self._jrdd_deserializer) other_batch = get_batch_size(other._jrdd_deserializer) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 06eed0a3bc..766ea64d90 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -19,7 +19,7 @@ PySpark supports custom serializers for transferring data; this can improve performance. 
-By default, PySpark uses :class:`PickleSerializer` to serialize objects using Python's +By default, PySpark uses :class:`CloudPickleSerializer` to serialize objects using Python's `cPickle` serializer, which can serialize nearly any Python object. Other serializers, like :class:`MarshalSerializer`, support fewer datatypes but can be faster. @@ -69,7 +69,13 @@ from pyspark.util import print_exec # type: ignore -__all__ = ["PickleSerializer", "MarshalSerializer", "UTF8Deserializer"] +__all__ = [ + "PickleSerializer", + "CPickleSerializer", + "CloudPickleSerializer", + "MarshalSerializer", + "UTF8Deserializer", +] class SpecialLengths(object): @@ -344,78 +350,81 @@ def dumps(self, obj): return obj -# Hack namedtuple, make it picklable - -__cls = {} # type: ignore - - -def _restore(name, fields, value): - """Restore an object of namedtuple""" - k = (name, fields) - cls = __cls.get(k) - if cls is None: - cls = collections.namedtuple(name, fields) - __cls[k] = cls - return cls(*value) - - -def _hack_namedtuple(cls): - """Make class generated by namedtuple picklable""" - name = cls.__name__ - fields = cls._fields - - def __reduce__(self): - return (_restore, (name, fields, tuple(self))) - - cls.__reduce__ = __reduce__ - cls._is_namedtuple_ = True - return cls - - -def _hijack_namedtuple(): - """Hack namedtuple() to make it picklable""" - # hijack only one time - if hasattr(collections.namedtuple, "__hijack"): - return - - global _old_namedtuple # or it will put in closure - global _old_namedtuple_kwdefaults # or it will put in closure too - - def _copy_func(f): - return types.FunctionType( - f.__code__, f.__globals__, f.__name__, f.__defaults__, f.__closure__ - ) - - _old_namedtuple = _copy_func(collections.namedtuple) - _old_namedtuple_kwdefaults = collections.namedtuple.__kwdefaults__ - - def namedtuple(*args, **kwargs): - for k, v in _old_namedtuple_kwdefaults.items(): - kwargs[k] = kwargs.get(k, v) - cls = _old_namedtuple(*args, **kwargs) - return _hack_namedtuple(cls) - - # replace namedtuple with the new one - collections.namedtuple.__globals__["_old_namedtuple_kwdefaults"] = _old_namedtuple_kwdefaults - collections.namedtuple.__globals__["_old_namedtuple"] = _old_namedtuple - collections.namedtuple.__globals__["_hack_namedtuple"] = _hack_namedtuple - collections.namedtuple.__code__ = namedtuple.__code__ - collections.namedtuple.__hijack = 1 - - # hack the cls already generated by namedtuple. - # Those created in other modules can be pickled as normal, - # so only hack those in __main__ module - for n, o in sys.modules["__main__"].__dict__.items(): - if ( - type(o) is type - and o.__base__ is tuple - and hasattr(o, "_fields") - and "__reduce__" not in o.__dict__ - ): - _hack_namedtuple(o) # hack inplace - +if sys.version_info < (3, 8): + # Hack namedtuple, make it picklable. + # For Python 3.8+, we use CPickle-based cloudpickle. + # For Python 3.7 and below, we use legacy build-in CPickle which + # requires namedtuple hack. + # The whole hack here should be removed once we drop Python 3.7. 
+ + __cls = {} # type: ignore + + def _restore(name, fields, value): + """Restore an object of namedtuple""" + k = (name, fields) + cls = __cls.get(k) + if cls is None: + cls = collections.namedtuple(name, fields) + __cls[k] = cls + return cls(*value) + + def _hack_namedtuple(cls): + """Make class generated by namedtuple picklable""" + name = cls.__name__ + fields = cls._fields + + def __reduce__(self): + return (_restore, (name, fields, tuple(self))) + + cls.__reduce__ = __reduce__ + cls._is_namedtuple_ = True + return cls + + def _hijack_namedtuple(): + """Hack namedtuple() to make it picklable""" + # hijack only one time + if hasattr(collections.namedtuple, "__hijack"): + return -_hijack_namedtuple() + global _old_namedtuple # or it will put in closure + global _old_namedtuple_kwdefaults # or it will put in closure too + + def _copy_func(f): + return types.FunctionType( + f.__code__, f.__globals__, f.__name__, f.__defaults__, f.__closure__ + ) + + _old_namedtuple = _copy_func(collections.namedtuple) + _old_namedtuple_kwdefaults = collections.namedtuple.__kwdefaults__ + + def namedtuple(*args, **kwargs): + for k, v in _old_namedtuple_kwdefaults.items(): + kwargs[k] = kwargs.get(k, v) + cls = _old_namedtuple(*args, **kwargs) + return _hack_namedtuple(cls) + + # replace namedtuple with the new one + collections.namedtuple.__globals__[ + "_old_namedtuple_kwdefaults" + ] = _old_namedtuple_kwdefaults + collections.namedtuple.__globals__["_old_namedtuple"] = _old_namedtuple + collections.namedtuple.__globals__["_hack_namedtuple"] = _hack_namedtuple + collections.namedtuple.__code__ = namedtuple.__code__ + collections.namedtuple.__hijack = 1 + + # hack the cls already generated by namedtuple. + # Those created in other modules can be pickled as normal, + # so only hack those in __main__ module + for n, o in sys.modules["__main__"].__dict__.items(): + if ( + type(o) is type + and o.__base__ is tuple + and hasattr(o, "_fields") + and "__reduce__" not in o.__dict__ + ): + _hack_namedtuple(o) # hack inplace + + _hijack_namedtuple() class PickleSerializer(FramedSerializer): @@ -436,7 +445,7 @@ def loads(self, obj, encoding="bytes"): return pickle.loads(obj, encoding=encoding) -class CloudPickleSerializer(PickleSerializer): +class CloudPickleSerializer(FramedSerializer): def dumps(self, obj): try: return cloudpickle.dumps(obj, pickle_protocol) @@ -451,6 +460,15 @@ def dumps(self, obj): print_exec(sys.stderr) raise pickle.PicklingError(msg) + def loads(self, obj, encoding="bytes"): + return cloudpickle.loads(obj, encoding=encoding) + + +if sys.version_info < (3, 8): + CPickleSerializer = PickleSerializer +else: + CPickleSerializer = CloudPickleSerializer + class MarshalSerializer(FramedSerializer): @@ -459,7 +477,7 @@ class MarshalSerializer(FramedSerializer): http://docs.python.org/2/library/marshal.html - This serializer is faster than PickleSerializer but supports fewer datatypes. + This serializer is faster than CloudPickleSerializer but supports fewer datatypes. 
""" def dumps(self, obj): diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 9dbe314d29..bd455667f3 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -28,7 +28,7 @@ import heapq from pyspark.serializers import ( BatchedSerializer, - PickleSerializer, + CPickleSerializer, FlattenedValuesSerializer, CompressedSerializer, AutoBatchedSerializer, @@ -140,8 +140,8 @@ def items(self): def _compressed_serializer(self, serializer=None): - # always use PickleSerializer to simplify implementation - ser = PickleSerializer() + # always use CPickleSerializer to simplify implementation + ser = CPickleSerializer() return AutoBatchedSerializer(CompressedSerializer(ser)) @@ -609,7 +609,7 @@ def _open_file(self): os.makedirs(d) p = os.path.join(d, str(id(self))) self._file = open(p, "w+b", 65536) - self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024) + self._ser = BatchedSerializer(CompressedSerializer(CPickleSerializer()), 1024) os.unlink(p) def __del__(self): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 337cad534f..160e7c3841 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -47,7 +47,7 @@ _load_from_socket, _local_iterator_from_socket, ) -from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer +from pyspark.serializers import BatchedSerializer, CPickleSerializer, UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column @@ -121,7 +121,7 @@ def rdd(self) -> "RDD[Row]": """Returns the content as an :class:`pyspark.RDD` of :class:`Row`.""" if self._lazy_rdd is None: jrdd = self._jdf.javaToPython() - self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) + self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(CPickleSerializer())) return self._lazy_rdd @property # type: ignore[misc] @@ -592,7 +592,7 @@ def _repr_html_(self) -> Optional[str]: max_num_rows, self.sql_ctx._conf.replEagerEvalTruncate(), # type: ignore[attr-defined] ) - rows = list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) + rows = list(_load_from_socket(sock_info, BatchedSerializer(CPickleSerializer()))) head = rows[0] row_data = rows[1:] has_more_data = len(row_data) > max_num_rows @@ -769,7 +769,7 @@ def collect(self) -> List[Row]: """ with SCCallSiteSync(self._sc) as css: sock_info = self._jdf.collectToPython() - return list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) + return list(_load_from_socket(sock_info, BatchedSerializer(CPickleSerializer()))) def toLocalIterator(self, prefetchPartitions: bool = False) -> Iterator[Row]: """ @@ -792,7 +792,7 @@ def toLocalIterator(self, prefetchPartitions: bool = False) -> Iterator[Row]: """ with SCCallSiteSync(self._sc) as css: sock_info = self._jdf.toPythonIterator(prefetchPartitions) - return _local_iterator_from_socket(sock_info, BatchedSerializer(PickleSerializer())) + return _local_iterator_from_socket(sock_info, BatchedSerializer(CPickleSerializer())) def limit(self, num: int) -> "DataFrame": """Limits the result count to the number specified. 
@@ -837,7 +837,7 @@ def tail(self, num: int) -> List[Row]: """ with SCCallSiteSync(self._sc): sock_info = self._jdf.tailToPython(num) - return list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) + return list(_load_from_socket(sock_info, BatchedSerializer(CPickleSerializer()))) def foreach(self, f: Callable[[Row], None]) -> None: """Applies the ``f`` function to all :class:`Row` of this :class:`DataFrame`. diff --git a/python/pyspark/sql/pandas/utils.py b/python/pyspark/sql/pandas/utils.py index cc0db017c3..bc6202f854 100644 --- a/python/pyspark/sql/pandas/utils.py +++ b/python/pyspark/sql/pandas/utils.py @@ -19,7 +19,7 @@ def require_minimum_pandas_version() -> None: """Raise ImportError if minimum version of Pandas is not installed""" # TODO(HyukjinKwon): Relocate and deduplicate the version specification. - minimum_pandas_version = "0.23.2" + minimum_pandas_version = "1.0.5" from distutils.version import LooseVersion diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 9275541987..f94b9c2115 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -291,7 +291,7 @@ def __init__( self, sparkContext: SparkContext, jsparkSession: Optional[JavaObject] = None, - options: Optional[Dict[str, Any]] = None, + options: Optional[Dict[str, Any]] = {}, ): from pyspark.sql.context import SQLContext @@ -305,10 +305,7 @@ def __init__( ): jsparkSession = self._jvm.SparkSession.getDefaultSession().get() else: - jsparkSession = self._jvm.SparkSession(self._jsc.sc()) - if options is not None: - for key, value in options.items(): - jsparkSession.sharedState().conf().set(key, value) + jsparkSession = self._jvm.SparkSession(self._jsc.sc(), options) self._jsparkSession = jsparkSession self._jwrapped = self._jsparkSession.sqlContext() self._wrapped = SQLContext(self._sc, self, self._jwrapped) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 53a098ce49..74593d0700 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -1200,7 +1200,7 @@ def foreach(self, f: Union[Callable[[Row], None], "SupportsProcess"]) -> "DataSt """ from pyspark.rdd import _wrap_function # type: ignore[attr-defined] - from pyspark.serializers import PickleSerializer, AutoBatchedSerializer + from pyspark.serializers import CPickleSerializer, AutoBatchedSerializer from pyspark.taskcontext import TaskContext if callable(f): @@ -1268,7 +1268,7 @@ def func_with_open_process_close(partition_id: Any, iterator: Iterator) -> Itera func = func_with_open_process_close # type: ignore[assignment] - serializer = AutoBatchedSerializer(PickleSerializer()) + serializer = AutoBatchedSerializer(CPickleSerializer()) wrapped_func = _wrap_function(self._spark._sc, func, serializer, serializer) jForeachWriter = self._spark._sc._jvm.org.apache.spark.sql.execution.python.PythonForeachWriter( # type: ignore[attr-defined] wrapped_func, self._df._jdf.schema() diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py index 3927d75d10..e67190fa58 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py +++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py @@ -193,7 +193,7 @@ def mean_and_std_udf(v): with self.assertRaisesRegex(NotImplementedError, "not supported"): @pandas_udf(ArrayType(TimestampType()), PandasUDFType.GROUPED_AGG) - def mean_and_std_udf(v): + def mean_and_std_udf(v): # noqa: F811 return {v.mean(): v.std()} def test_alias(self): 
diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py index 830db43f7f..bee9cff525 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py +++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py @@ -697,8 +697,8 @@ def scalar_check_data(idx, date, date_copy): return pd.Series(msgs) def iter_check_data(it): - for idx, date, date_copy in it: - yield scalar_check_data(idx, date, date_copy) + for idx, test_date, date_copy in it: + yield scalar_check_data(idx, test_date, date_copy) pandas_scalar_check_data = pandas_udf(scalar_check_data, StringType()) pandas_iter_check_data = pandas_udf( diff --git a/python/pyspark/sql/tests/test_session.py b/python/pyspark/sql/tests/test_session.py index eb23b68ccf..06771fac89 100644 --- a/python/pyspark/sql/tests/test_session.py +++ b/python/pyspark/sql/tests/test_session.py @@ -289,18 +289,23 @@ def test_another_spark_session(self): if session2 is not None: session2.stop() - def test_create_spark_context_first_and_copy_options_to_sharedState(self): + def test_create_spark_context_with_initial_session_options(self): sc = None session = None try: conf = SparkConf().set("key1", "value1") sc = SparkContext("local[4]", "SessionBuilderTests", conf=conf) session = ( - SparkSession.builder.config("key2", "value2").enableHiveSupport().getOrCreate() + SparkSession.builder.config("spark.sql.codegen.comments", "true") + .enableHiveSupport() + .getOrCreate() ) self.assertEqual(session._jsparkSession.sharedState().conf().get("key1"), "value1") - self.assertEqual(session._jsparkSession.sharedState().conf().get("key2"), "value2") + self.assertEqual( + session._jsparkSession.sharedState().conf().get("spark.sql.codegen.comments"), + "true", + ) self.assertEqual( session._jsparkSession.sharedState().conf().get("spark.sql.catalogImplementation"), "hive", diff --git a/python/pyspark/sql/tests/test_udf_profiler.py b/python/pyspark/sql/tests/test_udf_profiler.py new file mode 100644 index 0000000000..27d9458509 --- /dev/null +++ b/python/pyspark/sql/tests/test_udf_profiler.py @@ -0,0 +1,109 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import tempfile +import unittest +import os +import sys +from io import StringIO + +from pyspark import SparkConf, SparkContext +from pyspark.sql import SparkSession +from pyspark.sql.functions import udf +from pyspark.profiler import UDFBasicProfiler + + +class UDFProfilerTests(unittest.TestCase): + def setUp(self): + self._old_sys_path = list(sys.path) + class_name = self.__class__.__name__ + conf = SparkConf().set("spark.python.profile", "true") + self.sc = SparkContext("local[4]", class_name, conf=conf) + self.spark = SparkSession.builder._sparkContext(self.sc).getOrCreate() + + def tearDown(self): + self.spark.stop() + sys.path = self._old_sys_path + + def test_udf_profiler(self): + self.do_computation() + + profilers = self.sc.profiler_collector.profilers + self.assertEqual(3, len(profilers)) + + old_stdout = sys.stdout + try: + sys.stdout = io = StringIO() + self.sc.show_profiles() + finally: + sys.stdout = old_stdout + + d = tempfile.gettempdir() + self.sc.dump_profiles(d) + + for i, udf_name in enumerate(["add1", "add2", "add1"]): + id, profiler, _ = profilers[i] + with self.subTest(id=id, udf_name=udf_name): + stats = profiler.stats() + self.assertTrue(stats is not None) + width, stat_list = stats.get_print_list([]) + func_names = [func_name for fname, n, func_name in stat_list] + self.assertTrue(udf_name in func_names) + + self.assertTrue(udf_name in io.getvalue()) + self.assertTrue("udf_%d.pstats" % id in os.listdir(d)) + + def test_custom_udf_profiler(self): + class TestCustomProfiler(UDFBasicProfiler): + def show(self, id): + self.result = "Custom formatting" + + self.sc.profiler_collector.udf_profiler_cls = TestCustomProfiler + + self.do_computation() + + profilers = self.sc.profiler_collector.profilers + self.assertEqual(3, len(profilers)) + _, profiler, _ = profilers[0] + self.assertTrue(isinstance(profiler, TestCustomProfiler)) + + self.sc.show_profiles() + self.assertEqual("Custom formatting", profiler.result) + + def do_computation(self): + @udf + def add1(x): + return x + 1 + + @udf + def add2(x): + return x + 2 + + df = self.spark.range(10) + df.select(add1("id"), add2("id"), add1("id")).collect() + + +if __name__ == "__main__": + from pyspark.sql.tests.test_udf_profiler import * # noqa: F401 + + try: + import xmlrunner # type: ignore + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 886451a7cc..0b47f8796a 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -18,12 +18,14 @@ User-defined function related classes and functions """ import functools +import inspect import sys from typing import Callable, Any, TYPE_CHECKING, Optional, cast, Union from py4j.java_gateway import JavaObject # type: ignore[import] from pyspark import SparkContext +from pyspark.profiler import Profiler from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType # type: ignore[attr-defined] from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import ( @@ -209,16 +211,16 @@ def _judf(self) -> JavaObject: # This is unlikely, doesn't affect correctness, # and should have a minimal performance impact. 
if self._judf_placeholder is None: - self._judf_placeholder = self._create_judf() + self._judf_placeholder = self._create_judf(self.func) return self._judf_placeholder - def _create_judf(self) -> JavaObject: + def _create_judf(self, func: Callable[..., Any]) -> JavaObject: from pyspark.sql import SparkSession spark = SparkSession.builder.getOrCreate() sc = spark.sparkContext - wrapped_func = _wrap_function(sc, self.func, self.returnType) + wrapped_func = _wrap_function(sc, func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( # type: ignore[attr-defined] self._name, wrapped_func, jdt, self.evalType, self.deterministic @@ -226,9 +228,29 @@ def _create_judf(self) -> JavaObject: return judf def __call__(self, *cols: "ColumnOrName") -> Column: - judf = self._judf sc = SparkContext._active_spark_context # type: ignore[attr-defined] - return Column(judf.apply(_to_seq(sc, cols, _to_java_column))) + + profiler: Optional[Profiler] = None + if sc.profiler_collector: + f = self.func + profiler = sc.profiler_collector.new_udf_profiler(sc) + + @functools.wraps(f) + def func(*args: Any, **kwargs: Any) -> Any: + assert profiler is not None + return profiler.profile(f, *args, **kwargs) + + func.__signature__ = inspect.signature(f) # type: ignore[attr-defined] + + judf = self._create_judf(func) + else: + judf = self._judf + + jPythonUDF = judf.apply(_to_seq(sc, cols, _to_java_column)) + if profiler is not None: + id = jPythonUDF.expr().resultId().id() + sc.profiler_collector.add_profiler(id, profiler) + return Column(jPythonUDF) # This function is for improving the online help system in the interactive interpreter. # For example, the built-in help / pydoc.help. 
It wraps the UDF with the docstring and diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py index a2c9ce90e9..51c1149080 100644 --- a/python/pyspark/tests/test_rdd.py +++ b/python/pyspark/tests/test_rdd.py @@ -29,7 +29,7 @@ from pyspark.serializers import ( CloudPickleSerializer, BatchedSerializer, - PickleSerializer, + CPickleSerializer, MarshalSerializer, UTF8Deserializer, NoOpSerializer, @@ -446,7 +446,7 @@ def test_zip_with_different_serializers(self): a = self.sc.parallelize(range(5)) b = self.sc.parallelize(range(100, 105)) self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) - a = a._reserialize(BatchedSerializer(PickleSerializer(), 2)) + a = a._reserialize(BatchedSerializer(CPickleSerializer(), 2)) b = b._reserialize(MarshalSerializer()) self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)]) # regression test for SPARK-4841 diff --git a/python/pyspark/tests/test_serializers.py b/python/pyspark/tests/test_serializers.py index 3a9e14dd16..019f5279bc 100644 --- a/python/pyspark/tests/test_serializers.py +++ b/python/pyspark/tests/test_serializers.py @@ -29,7 +29,7 @@ PairDeserializer, FlattenedValuesSerializer, CartesianDeserializer, - PickleSerializer, + CPickleSerializer, UTF8Deserializer, MarshalSerializer, ) @@ -46,15 +46,13 @@ class SerializationTestCase(unittest.TestCase): def test_namedtuple(self): from collections import namedtuple - from pickle import dumps, loads + from pyspark.cloudpickle import dumps, loads P = namedtuple("P", "x y") p1 = P(1, 3) p2 = loads(dumps(p1, 2)) self.assertEqual(p1, p2) - from pyspark.cloudpickle import dumps - P2 = loads(dumps(P)) p3 = P2(1, 3) self.assertEqual(p1, p3) @@ -132,7 +130,7 @@ def foo(): ser.dumps(foo) def test_compressed_serializer(self): - ser = CompressedSerializer(PickleSerializer()) + ser = CompressedSerializer(CPickleSerializer()) from io import BytesIO as StringIO io = StringIO() @@ -147,15 +145,15 @@ def test_compressed_serializer(self): def test_hash_serializer(self): hash(NoOpSerializer()) hash(UTF8Deserializer()) - hash(PickleSerializer()) + hash(CPickleSerializer()) hash(MarshalSerializer()) hash(AutoSerializer()) - hash(BatchedSerializer(PickleSerializer())) + hash(BatchedSerializer(CPickleSerializer())) hash(AutoBatchedSerializer(MarshalSerializer())) hash(PairDeserializer(NoOpSerializer(), UTF8Deserializer())) hash(CartesianDeserializer(NoOpSerializer(), UTF8Deserializer())) - hash(CompressedSerializer(PickleSerializer())) - hash(FlattenedValuesSerializer(PickleSerializer())) + hash(CompressedSerializer(CPickleSerializer())) + hash(FlattenedValuesSerializer(CPickleSerializer())) @unittest.skipIf(not have_scipy, "SciPy not installed") diff --git a/python/pyspark/tests/test_shuffle.py b/python/pyspark/tests/test_shuffle.py index 805a47dd1e..5d69b67fc3 100644 --- a/python/pyspark/tests/test_shuffle.py +++ b/python/pyspark/tests/test_shuffle.py @@ -19,7 +19,7 @@ from py4j.protocol import Py4JJavaError -from pyspark import shuffle, PickleSerializer, SparkConf, SparkContext +from pyspark import shuffle, CPickleSerializer, SparkConf, SparkContext from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter @@ -80,7 +80,7 @@ def gen_gs(N, step=1): self.assertEqual(k, len(vs)) self.assertEqual(list(range(k)), list(vs)) - ser = PickleSerializer() + ser = CPickleSerializer() l = ser.loads(ser.dumps(list(gen_gs(50002, 30000)))) for k, vs in l: self.assertEqual(k, len(vs)) diff --git a/python/pyspark/tests/test_worker.py 
b/python/pyspark/tests/test_worker.py index 95fd094bfc..64e7b7d6a1 100644 --- a/python/pyspark/tests/test_worker.py +++ b/python/pyspark/tests/test_worker.py @@ -24,7 +24,7 @@ has_resource_module = True try: - import resource # noqa: F401 + import resource except ImportError: has_resource_module = False @@ -210,8 +210,6 @@ def test_memory_limit(self): rdd = self.sc.parallelize(range(1), 1) def getrlimit(): - import resource - return resource.getrlimit(resource.RLIMIT_AS) actual = rdd.map(lambda _: getrlimit()).collect() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index c2200b20fe..1935e27d66 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -50,7 +50,7 @@ read_int, SpecialLengths, UTF8Deserializer, - PickleSerializer, + CPickleSerializer, BatchedSerializer, ) from pyspark.sql.pandas.serializers import ( @@ -63,7 +63,7 @@ from pyspark.util import fail_on_stopiteration, try_simplify_traceback # type: ignore from pyspark import shuffle -pickleSer = PickleSerializer() +pickleSer = CPickleSerializer() utf8_deserializer = UTF8Deserializer() @@ -367,7 +367,7 @@ def read_udfs(pickleSer, infile, eval_type): timezone, safecheck, assign_cols_by_name, df_for_struct ) else: - ser = BatchedSerializer(PickleSerializer(), 100) + ser = BatchedSerializer(CPickleSerializer(), 100) num_udfs = read_int(infile) diff --git a/python/setup.py b/python/setup.py index 4507a2686e..174995d4ae 100644 --- a/python/setup.py +++ b/python/setup.py @@ -111,7 +111,7 @@ def _supports_symlinks(): # For Arrow, you should also check ./pom.xml and ensure there are no breaking changes in the # binary format protocol with the Java version, see ARROW_HOME/format/* for specifications. # Also don't forget to update python/docs/source/getting_started/install.rst. -_minimum_pandas_version = "0.23.2" +_minimum_pandas_version = "1.0.5" _minimum_pyarrow_version = "1.0.0" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala index 9e3d23c006..192b5993ef 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsPollingSnapshotSource.scala @@ -24,12 +24,21 @@ import io.fabric8.kubernetes.client.KubernetesClient import scala.collection.JavaConverters._ import org.apache.spark.SparkConf +import org.apache.spark.annotation.{DeveloperApi, Since, Stable} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.Logging import org.apache.spark.util.{ThreadUtils, Utils} -private[spark] class ExecutorPodsPollingSnapshotSource( +/** + * :: DeveloperApi :: + * + * A class used for polling K8s executor pods by ExternalClusterManagers. 
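The test and worker updates above track the rename of `PickleSerializer` to `CPickleSerializer`; the behaviour is unchanged, only the name differs. A small round-trip sketch under that assumption (Spark 3.3+ naming):

```python
from pyspark.serializers import BatchedSerializer, CPickleSerializer

# CPickleSerializer is the renamed PickleSerializer; dumps/loads work as before.
ser = CPickleSerializer()
data = [("a", 1), ("b", 2)]
assert ser.loads(ser.dumps(data)) == data

# It still composes with the batching wrappers used throughout the RDD code paths.
batched = BatchedSerializer(CPickleSerializer(), 2)
assert batched.serializer is not None
```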
+ * @since 3.1.3 + */ +@Stable +@DeveloperApi +class ExecutorPodsPollingSnapshotSource( conf: SparkConf, kubernetesClient: KubernetesClient, snapshotsStore: ExecutorPodsSnapshotsStore, @@ -39,6 +48,7 @@ private[spark] class ExecutorPodsPollingSnapshotSource( private var pollingFuture: Future[_] = _ + @Since("3.1.3") def start(applicationId: String): Unit = { require(pollingFuture == null, "Cannot start polling more than once.") logDebug(s"Starting to check for executor pod state every $pollingInterval ms.") @@ -46,6 +56,7 @@ private[spark] class ExecutorPodsPollingSnapshotSource( new PollRunnable(applicationId), pollingInterval, pollingInterval, TimeUnit.MILLISECONDS) } + @Since("3.1.3") def stop(): Unit = { if (pollingFuture != null) { pollingFuture.cancel(true) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala index 762878cbac..06d942eb5b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala @@ -22,16 +22,27 @@ import io.fabric8.kubernetes.api.model.Pod import io.fabric8.kubernetes.client.{KubernetesClient, Watcher, WatcherException} import io.fabric8.kubernetes.client.Watcher.Action +import org.apache.spark.annotation.{DeveloperApi, Since, Stable} import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.Logging import org.apache.spark.util.Utils -private[spark] class ExecutorPodsWatchSnapshotSource( +/** + * :: DeveloperApi :: + * + * A class used for watching K8s executor pods by ExternalClusterManagers. + * + * @since 3.1.3 + */ +@Stable +@DeveloperApi +class ExecutorPodsWatchSnapshotSource( snapshotsStore: ExecutorPodsSnapshotsStore, kubernetesClient: KubernetesClient) extends Logging { private var watchConnection: Closeable = _ + @Since("3.1.3") def start(applicationId: String): Unit = { require(watchConnection == null, "Cannot start the watcher twice.") logDebug(s"Starting watch for pods with labels $SPARK_APP_ID_LABEL=$applicationId," + @@ -42,6 +53,7 @@ private[spark] class ExecutorPodsWatchSnapshotSource( .watch(new ExecutorPodsWatcher()) } + @Since("3.1.3") def stop(): Unit = { if (watchConnection != null) { Utils.tryLogNonFatalError { diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile.java17 b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile.java17 index f9ab64e94a..96dd6c996b 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile.java17 +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile.java17 @@ -51,6 +51,7 @@ COPY kubernetes/tests /opt/spark/tests COPY data /opt/spark/data ENV SPARK_HOME /opt/spark +ENV JAVA_HOME /usr/lib/jvm/java-17-openjdk-amd64/ WORKDIR /opt/spark/work-dir RUN chmod g+w /opt/spark/work-dir diff --git a/resource-managers/mesos/src/test/resources/log4j.properties b/resource-managers/mesos/src/test/resources/log4j.properties new file mode 100644 index 0000000000..9ec68901ee --- /dev/null +++ b/resource-managers/mesos/src/test/resources/log4j.properties @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file target/unit-tests.log +log4j.rootCategory=DEBUG, file +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=true +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.sparkproject.jetty=WARN diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 7787e2fc92..e6136fc54f 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -169,7 +169,6 @@ private[spark] class Client( def submitApplication(): ApplicationId = { ResourceRequestHelper.validateResources(sparkConf) - var appId: ApplicationId = null try { launcherBackend.connect() yarnClient.init(hadoopConf) @@ -181,7 +180,7 @@ private[spark] class Client( // Get a new application from our RM val newApp = yarnClient.createApplication() val newAppResponse = newApp.getNewApplicationResponse() - appId = newAppResponse.getApplicationId() + this.appId = newAppResponse.getApplicationId() // The app staging dir based on the STAGING_DIR configuration if configured // otherwise based on the users home directory. 
@@ -207,8 +206,7 @@ private[spark] class Client( yarnClient.submitApplication(appContext) launcherBackend.setAppId(appId.toString) reportLauncherState(SparkAppHandle.State.SUBMITTED) - - appId + this.appId } catch { case e: Throwable => if (stagingDirPath != null) { @@ -915,7 +913,6 @@ private[spark] class Client( private def createContainerLaunchContext(newAppResponse: GetNewApplicationResponse) : ContainerLaunchContext = { logInfo("Setting up container launch context for our AM") - val appId = newAppResponse.getApplicationId val pySparkArchives = if (sparkConf.get(IS_PYTHON_APP)) { findPySparkArchives() @@ -971,7 +968,7 @@ private[spark] class Client( if (isClusterMode) { sparkConf.get(DRIVER_JAVA_OPTIONS).foreach { opts => javaOpts ++= Utils.splitCommandString(opts) - .map(Utils.substituteAppId(_, appId.toString)) + .map(Utils.substituteAppId(_, this.appId.toString)) .map(YarnSparkHadoopUtil.escapeForShell) } val libraryPaths = Seq(sparkConf.get(DRIVER_LIBRARY_PATH), @@ -996,7 +993,7 @@ private[spark] class Client( throw new SparkException(msg) } javaOpts ++= Utils.splitCommandString(opts) - .map(Utils.substituteAppId(_, appId.toString)) + .map(Utils.substituteAppId(_, this.appId.toString)) .map(YarnSparkHadoopUtil.escapeForShell) } sparkConf.get(AM_LIBRARY_PATH).foreach { paths => @@ -1269,7 +1266,7 @@ private[spark] class Client( * throw an appropriate SparkException. */ def run(): Unit = { - this.appId = submitApplication() + submitApplication() if (!launcherBackend.isConnected() && fireAndForget) { val report = getApplicationReport(appId) val state = report.getYarnApplicationState diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 426f529c93..6511489dc4 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -36,6 +36,11 @@ grammar SqlBase; } @lexer::members { + /** + * When true, parser should throw ParseExcetion for unclosed bracketed comment. + */ + public boolean has_unclosed_bracketed_comment = false; + /** * Verify whether current token is a valid decimal token (which contains dot). * Returns true if the character that follows the token is not a digit or letter or underscore. @@ -73,6 +78,16 @@ grammar SqlBase; return false; } } + + /** + * This method will be called when the character stream ends and try to find out the + * unclosed bracketed comment. + * If the method be called, it means the end of the entire character stream match, + * and we set the flag and fail later. + */ + public void markUnclosedComment() { + has_unclosed_bracketed_comment = true; + } } singleStatement @@ -107,7 +122,7 @@ statement : query #statementDefault | ctes? dmlStatementNoWith #dmlStatement | USE multipartIdentifier #use - | USE NAMESPACE multipartIdentifier #useNamespace + | USE namespace multipartIdentifier #useNamespace | SET CATALOG (identifier | STRING) #setCatalog | CREATE namespace (IF NOT EXISTS)? multipartIdentifier (commentSpec | @@ -119,7 +134,7 @@ statement SET locationSpec #setNamespaceLocation | DROP namespace (IF EXISTS)? multipartIdentifier (RESTRICT | CASCADE)? #dropNamespace - | SHOW (DATABASES | NAMESPACES) ((FROM | IN) multipartIdentifier)? + | SHOW namespaces ((FROM | IN) multipartIdentifier)? (LIKE? pattern=STRING)? #showNamespaces | createTableHeader ('(' colTypeList ')')? tableProvider? 
createTableClauses @@ -213,7 +228,7 @@ statement | SHOW identifier? FUNCTIONS (LIKE? (multipartIdentifier | pattern=STRING))? #showFunctions | SHOW CREATE TABLE multipartIdentifier (AS SERDE)? #showCreateTable - | SHOW CURRENT NAMESPACE #showCurrentNamespace + | SHOW CURRENT namespace #showCurrentNamespace | SHOW CATALOGS (LIKE? pattern=STRING)? #showCatalogs | (DESC | DESCRIBE) FUNCTION EXTENDED? describeFuncName #describeFunction | (DESC | DESCRIBE) namespace EXTENDED? @@ -242,7 +257,7 @@ statement | SET TIME ZONE timezone=(STRING | LOCAL) #setTimeZone | SET TIME ZONE .*? #setTimeZone | SET configKey EQ configValue #setQuotedConfiguration - | SET configKey (EQ .*?)? #setQuotedConfiguration + | SET configKey (EQ .*?)? #setConfiguration | SET .*? EQ configValue #setQuotedConfiguration | SET .*? #setConfiguration | RESET configKey #resetQuotedConfiguration @@ -367,6 +382,12 @@ namespace | SCHEMA ; +namespaces + : NAMESPACES + | DATABASES + | SCHEMAS + ; + describeFuncName : qualifiedName | STRING @@ -601,7 +622,7 @@ fromClause temporalClause : FOR? (SYSTEM_VERSION | VERSION) AS OF version=(INTEGER_VALUE | STRING) - | FOR? (SYSTEM_TIME | TIMESTAMP) AS OF timestamp=STRING + | FOR? (SYSTEM_TIME | TIMESTAMP) AS OF timestamp=valueExpression ; aggregationClause @@ -716,7 +737,8 @@ identifierComment ; relationPrimary - : multipartIdentifier temporalClause? sample? tableAlias #tableName + : multipartIdentifier temporalClause? + sample? tableAlias #tableName | '(' query ')' sample? tableAlias #aliasedQuery | '(' relation ')' sample? tableAlias #aliasedRelation | inlineTable #inlineTableDefault2 @@ -1214,6 +1236,7 @@ ansiNonReserved | ROW | ROWS | SCHEMA + | SCHEMAS | SECOND | SEMI | SEPARATED @@ -1485,6 +1508,7 @@ nonReserved | ROW | ROWS | SCHEMA + | SCHEMAS | SECOND | SELECT | SEPARATED @@ -1612,7 +1636,7 @@ CURRENT_USER: 'CURRENT_USER'; DAY: 'DAY'; DATA: 'DATA'; DATABASE: 'DATABASE'; -DATABASES: 'DATABASES' | 'SCHEMAS'; +DATABASES: 'DATABASES'; DBPROPERTIES: 'DBPROPERTIES'; DEFINED: 'DEFINED'; DELETE: 'DELETE'; @@ -1758,6 +1782,7 @@ ROW: 'ROW'; ROWS: 'ROWS'; SECOND: 'SECOND'; SCHEMA: 'SCHEMA'; +SCHEMAS: 'SCHEMAS'; SELECT: 'SELECT'; SEMI: 'SEMI'; SEPARATED: 'SEPARATED'; @@ -1926,7 +1951,7 @@ SIMPLE_COMMENT ; BRACKETED_COMMENT - : '/*' {!isHint()}? (BRACKETED_COMMENT|.)*? '*/' -> channel(HIDDEN) + : '/*' {!isHint()}? ( BRACKETED_COMMENT | . )*? 
('*/' | {markUnclosedComment();} EOF) -> channel(HIDDEN) ; WS diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java index 34f07b12b3..359bc0017b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java @@ -68,6 +68,16 @@ public Table loadTable(Identifier ident) throws NoSuchTableException { return asTableCatalog().loadTable(ident); } + @Override + public Table loadTable(Identifier ident, long timestamp) throws NoSuchTableException { + return asTableCatalog().loadTable(ident, timestamp); + } + + @Override + public Table loadTable(Identifier ident, String version) throws NoSuchTableException { + return asTableCatalog().loadTable(ident, version); + } + @Override public void invalidateTable(Identifier ident) { asTableCatalog().invalidateTable(ident); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java index d7a45f643c..9336c2a1ca 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException; +import org.apache.spark.sql.errors.QueryCompilationErrors; import org.apache.spark.sql.types.StructType; import java.util.Map; @@ -106,7 +107,7 @@ public interface TableCatalog extends CatalogPlugin { * @throws NoSuchTableException If the table doesn't exist or is a view */ default Table loadTable(Identifier ident, String version) throws NoSuchTableException { - throw new UnsupportedOperationException("Load table with version is not supported."); + throw QueryCompilationErrors.tableNotSupportTimeTravelError(ident); } /** @@ -121,7 +122,7 @@ default Table loadTable(Identifier ident, String version) throws NoSuchTableExce * @throws NoSuchTableException If the table doesn't exist or is a view */ default Table loadTable(Identifier ident, long timestamp) throws NoSuchTableException { - throw new UnsupportedOperationException("Load table with timestamp is not supported."); + throw QueryCompilationErrors.tableNotSupportTimeTravelError(ident); } /** diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/HasPartitionKey.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/HasPartitionKey.java new file mode 100644 index 0000000000..777693938c --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/HasPartitionKey.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.read; + +import org.apache.spark.sql.catalyst.InternalRow; + +/** + * A mix-in for input partitions whose records are clustered on the same set of partition keys + * (provided via {@link SupportsReportPartitioning}, see below). Data sources can opt-in to + * implement this interface for the partitions they report to Spark, which will use the + * information to avoid data shuffling in certain scenarios, such as join, aggregate, etc. Note + * that Spark requires ALL input partitions to implement this interface, otherwise it can't take + * advantage of it. + *
<p>
+ * This interface should be used in combination with {@link SupportsReportPartitioning}, which + * allows data sources to report distribution and ordering spec to Spark. In particular, Spark + * expects data sources to report + * {@link org.apache.spark.sql.connector.distributions.ClusteredDistribution} whenever its input + * partitions implement this interface. Spark derives partition keys spec (e.g., column names, + * transforms) from the distribution, and partition values from the input partitions. + *
<p>
+ * It is implementor's responsibility to ensure that when an input partition implements this + * interface, its records all have the same value for the partition keys. Spark doesn't check + * this property. + * + * @see org.apache.spark.sql.connector.read.SupportsReportPartitioning + * @see org.apache.spark.sql.connector.read.partitioning.Partitioning + */ +public interface HasPartitionKey extends InputPartition { + /** + * Returns the value of the partition key(s) associated to this partition. An input partition + * implementing this interface needs to ensure that all its records have the same value for the + * partition keys. Note that the value is after partition transform has been applied, if there + * is any. + */ + InternalRow partitionKey(); +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 4186a5b640..4f833907b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1063,6 +1063,10 @@ class Analyzer(override val catalogManager: CatalogManager) case u: UnresolvedRelation => lookupRelation(u).map(resolveViews).getOrElse(u) + case RelationTimeTravel(u: UnresolvedRelation, timestamp, version) + if timestamp.forall(_.resolved) => + lookupRelation(u, TimeTravelSpec.create(timestamp, version, conf)).getOrElse(u) + case u @ UnresolvedTable(identifier, cmd, relationTypeMismatchHint) => lookupTableOrView(identifier).map { case v: ResolvedView => @@ -1093,7 +1097,8 @@ class Analyzer(override val catalogManager: CatalogManager) private def lookupTempView( identifier: Seq[String], - isStreaming: Boolean = false): Option[LogicalPlan] = { + isStreaming: Boolean = false, + isTimeTravel: Boolean = false): Option[LogicalPlan] = { // We are resolving a view and this name is not a temp view when that view was created. We // return None earlier here. if (isResolvingView && !isReferredTempViewName(identifier)) return None @@ -1107,6 +1112,9 @@ class Analyzer(override val catalogManager: CatalogManager) if (isStreaming && tmpView.nonEmpty && !tmpView.get.isStreaming) { throw QueryCompilationErrors.readNonStreamingTempViewError(identifier.quoted) } + if (isTimeTravel && tmpView.nonEmpty) { + throw QueryCompilationErrors.viewNotSupportTimeTravelError(identifier) + } tmpView } @@ -1175,8 +1183,10 @@ class Analyzer(override val catalogManager: CatalogManager) * Resolves relations to v1 relation if it's a v1 table from the session catalog, or to v2 * relation. This is for resolving DML commands and SELECT queries. 
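The analyzer changes above resolve `RelationTimeTravel` by passing a `TimeTravelSpec` down to the catalog lookup, and explicitly reject time travel over temporary views. A hedged SQL-level sketch of that behaviour; `cat.db.events` is a hypothetical table whose catalog implements the versioned `loadTable` overloads:

```python
from pyspark.sql import SparkSession
from pyspark.sql.utils import AnalysisException

spark = SparkSession.builder.getOrCreate()

# Reading an older snapshot of a v2 table; the catalog must support
# loadTable(ident, version) / loadTable(ident, timestamp).
spark.sql("SELECT * FROM cat.db.events VERSION AS OF 1").show()
spark.sql("SELECT * FROM cat.db.events TIMESTAMP AS OF '2021-11-01 00:00:00'").show()

# Temporary views cannot be time travelled; analysis fails instead.
spark.range(10).createOrReplaceTempView("tmp_events")
try:
    spark.sql("SELECT * FROM tmp_events VERSION AS OF 1").show()
except AnalysisException as e:
    print(e)
```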
*/ - private def lookupRelation(u: UnresolvedRelation): Option[LogicalPlan] = { - lookupTempView(u.multipartIdentifier, u.isStreaming).orElse { + private def lookupRelation( + u: UnresolvedRelation, + timeTravelSpec: Option[TimeTravelSpec] = None): Option[LogicalPlan] = { + lookupTempView(u.multipartIdentifier, u.isStreaming, timeTravelSpec.isDefined).orElse { expandIdentifier(u.multipartIdentifier) match { case CatalogAndIdentifier(catalog, ident) => val key = catalog.name +: ident.namespace :+ ident.name @@ -1186,7 +1196,7 @@ class Analyzer(override val catalogManager: CatalogManager) newRelation.copyTagsFrom(multi) newRelation }).orElse { - val table = CatalogV2Util.loadTable(catalog, ident, u.timeTravelSpec) + val table = CatalogV2Util.loadTable(catalog, ident, timeTravelSpec) val loaded = createRelation(catalog, ident, table, u.options, u.isStreaming) loaded.foreach(AnalysisContext.get.relationCache.update(key, _)) loaded diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index debc13b953..267c2cc886 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -75,7 +75,7 @@ import org.apache.spark.sql.types._ object AnsiTypeCoercion extends TypeCoercionBase { override def typeCoercionRules: List[Rule[LogicalPlan]] = WidenSetOperationTypes :: - CombinedTypeCoercionRule( + new AnsiCombinedTypeCoercionRule( InConversion :: PromoteStringLiterals :: DecimalPrecision :: @@ -304,4 +304,9 @@ object AnsiTypeCoercion extends TypeCoercionBase { s.copy(left = newLeft, right = newRight) } } + + // This is for generating a new rule id, so that we can run both default and Ansi + // type coercion rules against one logical plan. 
+ class AnsiCombinedTypeCoercionRule(rules: Seq[TypeCoercionRule]) extends + CombinedTypeCoercionRule(rules) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala index 2e2d415954..605d32d362 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala @@ -227,7 +227,7 @@ object CTESubstitution extends Rule[LogicalPlan] { alwaysInline: Boolean, cteRelations: Seq[(String, CTERelationDef)]): LogicalPlan = plan.resolveOperatorsUpWithPruning(_.containsAnyPattern(UNRESOLVED_RELATION, PLAN_EXPRESSION)) { - case u @ UnresolvedRelation(Seq(table), _, _, _) => + case u @ UnresolvedRelation(Seq(table), _, _) => cteRelations.find(r => plan.conf.resolver(r._1, table)).map { case (_, d) => if (alwaysInline) { d.child diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 5bf37a2944..491d52588f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, StringUtils, TypeUtils} import org.apache.spark.sql.connector.catalog.{LookupCatalog, SupportsPartitionManagement} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} @@ -47,6 +48,8 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { */ val extendedCheckRules: Seq[LogicalPlan => Unit] = Nil + val DATA_TYPE_MISMATCH_ERROR = TreeNodeTag[Boolean]("dataTypeMismatchError") + protected def failAnalysis(msg: String): Nothing = { throw new AnalysisException(msg) } @@ -165,14 +168,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { } } - val exprs = operator match { - // `groupingExpressions` may rely on `aggregateExpressions`, due to the GROUP BY alias - // feature. We should check errors in `aggregateExpressions` first. 
- case a: Aggregate => a.aggregateExpressions ++ a.groupingExpressions - case _ => operator.expressions - } - - exprs.foreach(_.foreachUp { + getAllExpressions(operator).foreach(_.foreachUp { case a: Attribute if !a.resolved => val missingCol = a.sql val candidates = operator.inputSet.toSeq.map(_.qualifiedName) @@ -189,8 +185,10 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { case e: Expression if e.checkInputDataTypes().isFailure => e.checkInputDataTypes() match { case TypeCheckResult.TypeCheckFailure(message) => + e.setTagValue(DATA_TYPE_MISMATCH_ERROR, true) e.failAnalysis( - s"cannot resolve '${e.sql}' due to data type mismatch: $message") + s"cannot resolve '${e.sql}' due to data type mismatch: $message" + + extraHintForAnsiTypeCoercionExpression(operator)) } case c: Cast if !c.resolved => @@ -424,27 +422,20 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { |the ${ordinalNumber(ti + 1)} table has ${child.output.length} columns """.stripMargin.replace("\n", " ").trim()) } - val isUnion = operator.isInstanceOf[Union] - val dataTypesAreCompatibleFn = if (isUnion) { - (dt1: DataType, dt2: DataType) => - !DataType.equalsStructurally(dt1, dt2, true) - } else { - // SPARK-18058: we shall not care about the nullability of columns - (dt1: DataType, dt2: DataType) => - TypeCoercion.findWiderTypeForTwo(dt1.asNullable, dt2.asNullable).isEmpty - } + val dataTypesAreCompatibleFn = getDataTypesAreCompatibleFn(operator) // Check if the data types match. dataTypes(child).zip(ref).zipWithIndex.foreach { case ((dt1, dt2), ci) => // SPARK-18058: we shall not care about the nullability of columns if (dataTypesAreCompatibleFn(dt1, dt2)) { - failAnalysis( + val errorMessage = s""" |${operator.nodeName} can only be performed on tables with the compatible |column types. The ${ordinalNumber(ci)} column of the |${ordinalNumber(ti + 1)} table is ${dt1.catalogString} type which is not |compatible with ${dt2.catalogString} at same column of first table - """.stripMargin.replace("\n", " ").trim()) + """.stripMargin.replace("\n", " ").trim() + failAnalysis(errorMessage + extraHintForAnsiTypeCoercionPlan(operator)) } } } @@ -593,6 +584,86 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { plan.setAnalyzed() } + private def getAllExpressions(plan: LogicalPlan): Seq[Expression] = { + plan match { + // `groupingExpressions` may rely on `aggregateExpressions`, due to the GROUP BY alias + // feature. We should check errors in `aggregateExpressions` first. + case a: Aggregate => a.aggregateExpressions ++ a.groupingExpressions + case _ => plan.expressions + } + } + + private def getDataTypesAreCompatibleFn(plan: LogicalPlan): (DataType, DataType) => Boolean = { + val isUnion = plan.isInstanceOf[Union] + if (isUnion) { + (dt1: DataType, dt2: DataType) => + !DataType.equalsStructurally(dt1, dt2, true) + } else { + // SPARK-18058: we shall not care about the nullability of columns + (dt1: DataType, dt2: DataType) => + TypeCoercion.findWiderTypeForTwo(dt1.asNullable, dt2.asNullable).isEmpty + } + } + + private def getDefaultTypeCoercionPlan(plan: LogicalPlan): LogicalPlan = + TypeCoercion.typeCoercionRules.foldLeft(plan) { case (p, rule) => rule(p) } + + private def extraHintMessage(issueFixedIfAnsiOff: Boolean): String = { + if (issueFixedIfAnsiOff) { + "\nTo fix the error, you might need to add explicit type casts. If necessary set " + + s"${SQLConf.ANSI_ENABLED.key} to false to bypass this error." 
+ } else { + "" + } + } + + private def extraHintForAnsiTypeCoercionExpression(plan: LogicalPlan): String = { + if (!SQLConf.get.ansiEnabled) { + "" + } else { + val nonAnsiPlan = getDefaultTypeCoercionPlan(plan) + var issueFixedIfAnsiOff = true + getAllExpressions(nonAnsiPlan).foreach(_.foreachUp { + case e: Expression if e.getTagValue(DATA_TYPE_MISMATCH_ERROR).contains(true) && + e.checkInputDataTypes().isFailure => + e.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckFailure(_) => + issueFixedIfAnsiOff = false + } + + case _ => + }) + extraHintMessage(issueFixedIfAnsiOff) + } + } + + private def extraHintForAnsiTypeCoercionPlan(plan: LogicalPlan): String = { + if (!SQLConf.get.ansiEnabled) { + "" + } else { + val nonAnsiPlan = getDefaultTypeCoercionPlan(plan) + var issueFixedIfAnsiOff = true + nonAnsiPlan match { + case _: Union | _: SetOperation if nonAnsiPlan.children.length > 1 => + def dataTypes(plan: LogicalPlan): Seq[DataType] = plan.output.map(_.dataType) + + val ref = dataTypes(nonAnsiPlan.children.head) + val dataTypesAreCompatibleFn = getDataTypesAreCompatibleFn(nonAnsiPlan) + nonAnsiPlan.children.tail.zipWithIndex.foreach { case (child, ti) => + // Check if the data types match. + dataTypes(child).zip(ref).zipWithIndex.foreach { case ((dt1, dt2), ci) => + if (dataTypesAreCompatibleFn(dt1, dt2)) { + issueFixedIfAnsiOff = false + } + } + } + + case _ => + } + extraHintMessage(issueFixedIfAnsiOff) + } + } + /** * Validates subquery expressions in the plan. Upon failure, returns an user facing error. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 06684600e3..b2788f8573 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -455,6 +455,7 @@ object FunctionRegistry { expression[Ascii]("ascii"), expression[Chr]("char", true), expression[Chr]("chr"), + expression[Contains]("contains"), expression[Base64]("base64"), expression[BitLength]("bit_length"), expression[Length]("char_length", true), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationTimeTravel.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationTimeTravel.scala new file mode 100644 index 0000000000..f278ab2867 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationTimeTravel.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan} + +/** + * A logical node used to time travel the child relation to the given `timestamp` or `version`. + * The `child` must support time travel, e.g. a v2 source, and cannot be a view, subquery or stream. + * The timestamp expression cannot refer to any columns and cannot contain subqueries. + */ +case class RelationTimeTravel( + relation: LogicalPlan, + timestamp: Option[Expression], + version: Option[String]) extends LeafNode { + override def output: Seq[Attribute] = Nil + override lazy val resolved: Boolean = false +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala index efc1ab2cd0..d7c6301b6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala @@ -37,17 +37,6 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case UnresolvedDBObjectName(CatalogAndIdentifier(catalog, identifier), _) => ResolvedDBObjectName(catalog, identifier.namespace :+ identifier.name()) - case c @ CreateTableStatement( - NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _, _) => - CreateV2Table( - catalog.asTableCatalog, - tbl.asIdentifier, - c.tableSchema, - // convert the bucket spec and add it as a transform - c.partitioning ++ c.bucketSpec.map(_.asTransform), - convertTableProperties(c), - ignoreIfExists = c.ifNotExists) - case c @ CreateTableAsSelectStatement( NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _, _, _) => CreateTableAsSelect( @@ -70,18 +59,6 @@ class ResolveCatalogs(val catalogManager: CatalogManager) c.partitioning ++ c.bucketSpec.map(_.asTransform), convertTableProperties(c), orCreate = c.orCreate) - - case c @ ReplaceTableAsSelectStatement( - NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _, _) => - ReplaceTableAsSelect( - catalog.asTableCatalog, - tbl.asIdentifier, - // convert the bucket spec and add it as a transform - c.partitioning ++ c.bucketSpec.map(_.asTransform), - c.asSelect, - convertTableProperties(c), - writeOptions = c.writeOptions, - orCreate = c.orCreate) } object NonSessionCatalogAndTable { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index 10d8d391f4..27f2a5f416 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -106,7 +106,7 @@ object ResolveHints { val newNode = CurrentOrigin.withOrigin(plan.origin) { plan match { - case ResolvedHint(u @ UnresolvedRelation(ident, _, _, _), hint) + case ResolvedHint(u @ UnresolvedRelation(ident, _, _), hint) if matchedIdentifierInHint(ident) => ResolvedHint(u, createHintInfo(hintName).merge(hint, hintErrorHandler)) @@ -114,7 +114,7 @@ object ResolveHints { if matchedIdentifierInHint(extractIdentifier(r)) => ResolvedHint(r, createHintInfo(hintName).merge(hint, hintErrorHandler)) - case UnresolvedRelation(ident, _, _, _) if matchedIdentifierInHint(ident) => + case UnresolvedRelation(ident, _, _) if 
matchedIdentifierInHint(ident) => ResolvedHint(plan, createHintInfo(hintName)) case r: SubqueryAlias if matchedIdentifierInHint(extractIdentifier(r)) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TimeTravelSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TimeTravelSpec.scala new file mode 100644 index 0000000000..cbb6e8bb06 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TimeTravelSpec.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.{AnsiCast, Cast, Expression, RuntimeReplaceable, SubqueryExpression, Unevaluable} +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.TimestampType + +sealed trait TimeTravelSpec + +case class AsOfTimestamp(timestamp: Long) extends TimeTravelSpec +case class AsOfVersion(version: String) extends TimeTravelSpec + +object TimeTravelSpec { + def create( + timestamp: Option[Expression], + version: Option[String], + conf: SQLConf) : Option[TimeTravelSpec] = { + if (timestamp.nonEmpty && version.nonEmpty) { + throw QueryCompilationErrors.invalidTimeTravelSpecError() + } else if (timestamp.nonEmpty) { + val ts = timestamp.get + assert(ts.resolved && ts.references.isEmpty && !SubqueryExpression.hasSubquery(ts)) + if (!AnsiCast.canCast(ts.dataType, TimestampType)) { + throw QueryCompilationErrors.invalidTimestampExprForTimeTravel(ts) + } + val tsToEval = ts.transform { + case r: RuntimeReplaceable => r.child + case _: Unevaluable => + throw QueryCompilationErrors.invalidTimestampExprForTimeTravel(ts) + case e if !e.deterministic => + throw QueryCompilationErrors.invalidTimestampExprForTimeTravel(ts) + } + val tz = Some(conf.sessionLocalTimeZone) + // Set `ansiEnabled` to false, so that it can return null for invalid input and we can provide + // better error message. 
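`TimeTravelSpec.create` above accepts any timestamp expression that is deterministic, references no columns, contains no subqueries, and casts to TIMESTAMP, evaluating it with ANSI casting disabled so that invalid input produces a clearer error. A sketch of what this permits at the SQL level, reusing the hypothetical `cat.db.events` table from the earlier example:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# The temporal clause now takes a value expression, not only a string literal.
spark.sql(
    "SELECT * FROM cat.db.events "
    "TIMESTAMP AS OF current_timestamp() - INTERVAL 12 HOURS"
).show()

# Column references, subqueries and non-deterministic expressions such as
# rand() are rejected during analysis.
```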
+ val value = Cast(tsToEval, TimestampType, tz, ansiEnabled = false).eval() + if (value == null) { + throw QueryCompilationErrors.invalidTimestampExprForTimeTravel(ts) + } + Some(AsOfTimestamp(value.asInstanceOf[Long])) + } else if (version.nonEmpty) { + Some(AsOfVersion(version.get)) + } else { + None + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 506667461e..82fba93761 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -170,7 +170,7 @@ abstract class TypeCoercionBase { * Type coercion rule that combines multiple type coercion rules and applies them in a single tree * traversal. */ - case class CombinedTypeCoercionRule(rules: Seq[TypeCoercionRule]) extends TypeCoercionRule { + class CombinedTypeCoercionRule(rules: Seq[TypeCoercionRule]) extends TypeCoercionRule { override def transform: PartialFunction[Expression, Expression] = { val transforms = rules.map(_.transform) Function.unlift { e: Expression => @@ -795,7 +795,7 @@ object TypeCoercion extends TypeCoercionBase { override def typeCoercionRules: List[Rule[LogicalPlan]] = WidenSetOperationTypes :: - CombinedTypeCoercionRule( + new CombinedTypeCoercionRule( InConversion :: PromoteStrings :: DecimalPrecision :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index af6837a11c..0785336589 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, UnaryNode} import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.connector.expressions.TimeTravelSpec import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.types.{DataType, Metadata, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -46,8 +45,7 @@ class UnresolvedException(function: String) case class UnresolvedRelation( multipartIdentifier: Seq[String], options: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty(), - override val isStreaming: Boolean = false, - timeTravelSpec: Option[TimeTravelSpec] = None) + override val isStreaming: Boolean = false) extends LeafNode with NamedRelation { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -67,13 +65,9 @@ object UnresolvedRelation { def apply( tableIdentifier: TableIdentifier, extraOptions: CaseInsensitiveStringMap, - isStreaming: Boolean, - timeTravelSpec: Option[TimeTravelSpec]): UnresolvedRelation = { + isStreaming: Boolean): UnresolvedRelation = { UnresolvedRelation( - tableIdentifier.database.toSeq :+ tableIdentifier.table, - extraOptions, - isStreaming, - timeTravelSpec) + tableIdentifier.database.toSeq :+ tableIdentifier.table, extraOptions, isStreaming) } def apply(tableIdentifier: TableIdentifier): UnresolvedRelation = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala index a88b509a96..67c57ec278 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.catalog import java.net.URI +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.util.Shell @@ -258,6 +259,32 @@ object CatalogUtils { new Path(str).toUri } + def makeQualifiedDBObjectPath( + locationUri: URI, + warehousePath: String, + hadoopConf: Configuration): URI = { + if (locationUri.isAbsolute) { + locationUri + } else { + val fullPath = new Path(warehousePath, CatalogUtils.URIToString(locationUri)) + makeQualifiedPath(fullPath.toUri, hadoopConf) + } + } + + def makeQualifiedDBObjectPath( + warehouse: String, + location: String, + hadoopConf: Configuration): String = { + val nsPath = makeQualifiedDBObjectPath(stringToURI(location), warehouse, hadoopConf) + URIToString(nsPath) + } + + def makeQualifiedPath(path: URI, hadoopConf: Configuration): URI = { + val hadoopPath = new Path(path) + val fs = hadoopPath.getFileSystem(hadoopConf) + fs.makeQualified(hadoopPath).toUri + } + private def normalizeColumnName( tableName: String, tableCols: Seq[String], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index f529b13ff5..60f68fb8be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -210,9 +210,7 @@ class SessionCatalog( * FileSystem is changed. 
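The `makeQualifiedDBObjectPath` helpers above centralize how database and namespace locations are qualified: relative paths are resolved against the warehouse directory and then qualified by the current FileSystem. A small sketch of the observable behaviour; the database name and relative location are illustrative:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# A relative LOCATION is resolved under spark.sql.warehouse.dir and stored as a
# fully qualified URI (scheme and authority included) in the catalog.
spark.sql("CREATE DATABASE IF NOT EXISTS demo_db LOCATION 'demo_db_dir'")
spark.sql("DESCRIBE DATABASE demo_db").show(truncate=False)
```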
*/ private def makeQualifiedPath(path: URI): URI = { - val hadoopPath = new Path(path) - val fs = hadoopPath.getFileSystem(hadoopConf) - fs.makeQualified(hadoopPath).toUri + CatalogUtils.makeQualifiedPath(path, hadoopConf) } private def requireDbExists(db: String): Unit = { @@ -254,12 +252,7 @@ class SessionCatalog( } private def makeQualifiedDBPath(locationUri: URI): URI = { - if (locationUri.isAbsolute) { - locationUri - } else { - val fullPath = new Path(conf.warehousePath, CatalogUtils.URIToString(locationUri)) - makeQualifiedPath(fullPath.toUri) - } + CatalogUtils.makeQualifiedDBObjectPath(locationUri, conf.warehousePath, hadoopConf) } def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala index 696d25f8ed..b4ec1645ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT import org.apache.spark.sql.catalyst.util.TimestampFormatter import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ class CSVInferSchema(val options: CSVOptions) extends Serializable { @@ -38,6 +39,13 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable { legacyFormat = FAST_DATE_FORMAT, isParsing = true) + private val timestampNTZFormatter = TimestampFormatter( + options.timestampNTZFormatInRead, + options.zoneId, + legacyFormat = FAST_DATE_FORMAT, + isParsing = true, + forTimestampNTZ = true) + private val decimalParser = if (options.locale == Locale.US) { // Special handling the default locale for backward compatibility s: String => new java.math.BigDecimal(s) @@ -109,6 +117,7 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable { case LongType => tryParseLong(field) case _: DecimalType => tryParseDecimal(field) case DoubleType => tryParseDouble(field) + case TimestampNTZType => tryParseTimestampNTZ(field) case TimestampType => tryParseTimestamp(field) case BooleanType => tryParseBoolean(field) case StringType => StringType @@ -160,6 +169,17 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable { private def tryParseDouble(field: String): DataType = { if ((allCatch opt field.toDouble).isDefined || isInfOrNan(field)) { DoubleType + } else { + tryParseTimestampNTZ(field) + } + } + + private def tryParseTimestampNTZ(field: String): DataType = { + // We can only parse the value as TimestampNTZType if it does not have zone-offset or + // time-zone component and can be parsed with the timestamp formatter. + // Otherwise, it is likely to be a timestamp with timezone. 
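With the inference branch above, zone-less values can be inferred as the session timestamp type (`spark.sql.timestampType`), and CSV gains a dedicated `timestampNTZFormat` read/write option. A hedged reading sketch; the path and format are illustrative:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# With TIMESTAMP_NTZ as the session timestamp type, CSV values without a zone
# offset are inferred as timestamp_ntz rather than the default timestamp (LTZ).
spark.conf.set("spark.sql.timestampType", "TIMESTAMP_NTZ")

df = (
    spark.read
    .option("header", True)
    .option("inferSchema", True)
    # New option; when unset it falls back to yyyy-MM-dd'T'HH:mm:ss[.SSS].
    .option("timestampNTZFormat", "yyyy-MM-dd HH:mm:ss")
    .csv("/tmp/events.csv")  # hypothetical input file
)
df.printSchema()
```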
+ if ((allCatch opt timestampNTZFormatter.parseWithoutTimeZone(field, true)).isDefined) { + SQLConf.get.timestampType } else { tryParseTimestamp(field) } @@ -225,6 +245,10 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable { } else { Some(DecimalType(range + scale, scale)) } + + case (TimestampNTZType, TimestampType) | (TimestampType, TimestampNTZType) => + Some(TimestampType) + case _ => None } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala index 79624b9a60..2a404b14bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVOptions.scala @@ -164,6 +164,10 @@ class CSVOptions( s"${DateFormatter.defaultPattern}'T'HH:mm:ss[.SSS][XXX]" }) + val timestampNTZFormatInRead: Option[String] = parameters.get("timestampNTZFormat") + val timestampNTZFormatInWrite: String = parameters.getOrElse("timestampNTZFormat", + s"${DateFormatter.defaultPattern}'T'HH:mm:ss[.SSS]") + val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false) val maxColumns = getInt("maxColumns", 20480) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala index 8a04e4ca56..10cccd5711 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala @@ -49,7 +49,7 @@ class UnivocityGenerator( legacyFormat = FAST_DATE_FORMAT, isParsing = false) private val timestampNTZFormatter = TimestampFormatter( - options.timestampFormatInWrite, + options.timestampNTZFormatInWrite, options.zoneId, legacyFormat = FAST_DATE_FORMAT, isParsing = false, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index cd5621bbb7..eb827aea73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -94,7 +94,7 @@ class UnivocityParser( legacyFormat = FAST_DATE_FORMAT, isParsing = true) private lazy val timestampNTZFormatter = TimestampFormatter( - options.timestampFormatInRead, + options.timestampNTZFormatInRead, options.zoneId, legacyFormat = FAST_DATE_FORMAT, isParsing = true, @@ -204,7 +204,7 @@ class UnivocityParser( case _: TimestampNTZType => (d: String) => nullSafeDatum(d, name, nullable, options) { datum => - timestampNTZFormatter.parseWithoutTimeZone(datum) + timestampNTZFormatter.parseWithoutTimeZone(datum, true) } case _: DateType => (d: String) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 221f5ae736..3b501d686c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -292,6 +292,16 @@ abstract class Expression extends TreeNode[Expression] { override def simpleStringWithNodeId(): String = { throw QueryExecutionErrors.simpleStringWithNodeIdUnsupportedError(nodeName) } + + 
protected def typeSuffix = + if (resolved) { + dataType match { + case LongType => "L" + case _ => "" + } + } else { + "" + } } @@ -387,6 +397,7 @@ trait NonSQLExpression extends Expression { transform { case a: Attribute => new PrettyAttribute(a) case a: Alias => PrettyAttribute(a.sql, a.dataType) + case p: PythonUDF => PrettyPythonUDF(p.name, p.dataType, p.children) }.toString } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala index 80e2352869..6b9017a01d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala @@ -62,7 +62,7 @@ case class PythonUDF( override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) - override def toString: String = s"$name(${children.mkString(", ")})" + override def toString: String = s"$name(${children.mkString(", ")})#${resultId.id}$typeSuffix" final override val nodePatterns: Seq[TreePattern] = Seq(PYTHON_UDF) @@ -80,3 +80,21 @@ case class PythonUDF( override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): PythonUDF = copy(children = newChildren) } + +/** + * A place holder used when printing expressions without debugging information such as the + * result id. + */ +case class PrettyPythonUDF( + name: String, + dataType: DataType, + children: Seq[Expression]) + extends Expression with Unevaluable with NonSQLExpression { + + override def toString: String = s"$name(${children.mkString(", ")})" + + override def nullable: Boolean = true + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): PrettyPythonUDF = copy(children = newChildren) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 71f193e510..5cc81244c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -101,16 +101,6 @@ trait NamedExpression extends Expression { /** Returns a copy of this expression with a new `exprId`. */ def newInstance(): NamedExpression - - protected def typeSuffix = - if (resolved) { - dataType match { - case LongType => "L" - case _ => "" - } - } else { - "" - } } abstract class Attribute extends LeafExpression with NamedExpression with NullIntolerant { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 2b997da29b..959c834846 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -465,6 +465,23 @@ abstract class StringPredicate extends BinaryExpression /** * A function that returns true if the string `left` contains the string `right`. */ +@ExpressionDescription( + usage = """ + _FUNC_(expr1, expr2) - Returns a boolean value if expr2 is found inside expr1. + Returns NULL if either input expression is NULL. 
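The `contains` string function described above is exposed in SQL starting with 3.3.0; a quick usage sketch mirroring the expression description:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

spark.sql("SELECT contains('Spark SQL', 'Spark')").show()   # true
spark.sql("SELECT contains('Spark SQL', 'SPARK')").show()   # false, comparison is case-sensitive
spark.sql("SELECT contains('Spark SQL', NULL)").show()      # NULL when either input is NULL
```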
+ """, + examples = """ + Examples: + > SELECT _FUNC_('Spark SQL', 'Spark'); + true + > SELECT _FUNC_('Spark SQL', 'SPARK'); + false + > SELECT _FUNC_('Spark SQL', null); + NULL + """, + since = "3.3.0", + group = "string_funcs" +) case class Contains(left: Expression, right: Expression) extends StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 754f92b8e0..f7d96f85e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, Inte import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, getZoneId, stringToDate, stringToTimestamp, stringToTimestampWithoutTimeZone} import org.apache.spark.sql.connector.catalog.{SupportsNamespaces, TableCatalog} import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition -import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, TimeTravelSpec, Transform, YearsTransform} +import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, Expression => V2Expression, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform} import org.apache.spark.sql.errors.QueryParsingErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -1257,18 +1257,30 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg */ override def visitTableName(ctx: TableNameContext): LogicalPlan = withOrigin(ctx) { val tableId = visitMultipartIdentifier(ctx.multipartIdentifier) - val timeTravel = if (ctx.temporalClause != null) { - val v = ctx.temporalClause.version - val version = - if (ctx.temporalClause.INTEGER_VALUE != null) Some(v.getText) else Option(v).map(string) - TimeTravelSpec.create(Option(ctx.temporalClause.timestamp).map(string), version) + val relation = UnresolvedRelation(tableId) + val table = mayApplyAliasPlan( + ctx.tableAlias, relation.optionalMap(ctx.temporalClause)(withTimeTravel)) + table.optionalMap(ctx.sample)(withSample) + } + + private def withTimeTravel( + ctx: TemporalClauseContext, plan: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val v = ctx.version + val version = if (ctx.INTEGER_VALUE != null) { + Some(v.getText) } else { - None + Option(v).map(string) } - - val table = mayApplyAliasPlan(ctx.tableAlias, - UnresolvedRelation(tableId, timeTravelSpec = timeTravel)) - table.optionalMap(ctx.sample)(withSample) + val timestamp = Option(ctx.timestamp).map(expression) + if (timestamp.exists(_.references.nonEmpty)) { + throw QueryParsingErrors.invalidTimeTravelSpec( + "timestamp expression cannot refer to any columns", ctx.timestamp) + } + if (timestamp.exists(e => SubqueryExpression.hasSubquery(e))) { + throw QueryParsingErrors.invalidTimeTravelSpec( + "timestamp expression cannot contain subqueries", ctx.timestamp) + } + RelationTimeTravel(plan, timestamp, version) } /** @@ 
-3140,10 +3152,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg * Create a [[ShowNamespaces]] command. */ override def visitShowNamespaces(ctx: ShowNamespacesContext): LogicalPlan = withOrigin(ctx) { - if (ctx.DATABASES != null && ctx.multipartIdentifier != null) { - throw QueryParsingErrors.fromOrInNotAllowedInShowDatabasesError(ctx) - } - val multiPart = Option(ctx.multipartIdentifier).map(visitMultipartIdentifier) ShowNamespaces( UnresolvedNamespace(multiPart.getOrElse(Seq.empty[String])), @@ -3402,7 +3410,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } /** - * Create a table, returning a [[CreateTableStatement]] logical plan. + * Create a table, returning a [[CreateTable]] or [[CreateTableAsSelectStatement]] logical plan. * * Expected format: * {{{ @@ -3469,14 +3477,18 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg case _ => // Note: table schema includes both the table columns list and the partition columns // with data type. + val tableSpec = TableSpec(bucketSpec, properties, provider, options, location, comment, + serdeInfo, external) val schema = StructType(columns ++ partCols) - CreateTableStatement(table, schema, partitioning, bucketSpec, properties, provider, - options, location, comment, serdeInfo, external = external, ifNotExists = ifNotExists) + CreateTable( + UnresolvedDBObjectName(table, isNamespace = false), + schema, partitioning, tableSpec, ignoreIfExists = ifNotExists) } } /** - * Replace a table, returning a [[ReplaceTableStatement]] logical plan. + * Replace a table, returning a [[ReplaceTableStatement]] or [[ReplaceTableAsSelect]] + * logical plan. * * Expected format: * {{{ @@ -3542,9 +3554,11 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg ctx) case Some(query) => - ReplaceTableAsSelectStatement(table, query, partitioning, bucketSpec, properties, - provider, options, location, comment, writeOptions = Map.empty, serdeInfo, - orCreate = orCreate) + val tableSpec = TableSpec(bucketSpec, properties, provider, options, location, comment, + serdeInfo, false) + ReplaceTableAsSelect( + UnresolvedDBObjectName(table, isNamespace = false), + partitioning, query, tableSpec, writeOptions = Map.empty, orCreate = orCreate) case _ => // Note: table schema includes both the table columns list and the partition columns diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index f53c0d36f2..1057c78f3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -101,6 +101,7 @@ abstract class AbstractSqlParser extends ParserInterface with SQLConfHelper with val tokenStream = new CommonTokenStream(lexer) val parser = new SqlBaseParser(tokenStream) parser.addParseListener(PostProcessor) + parser.addParseListener(UnclosedCommentProcessor(command, tokenStream)) parser.removeErrorListeners() parser.addErrorListener(ParseErrorListener) parser.legacy_setops_precedence_enabled = conf.setOpsPrecedenceEnforced @@ -313,3 +314,62 @@ case object PostProcessor extends SqlBaseBaseListener { parent.addChild(new TerminalNodeImpl(f(newToken))) } } + +/** + * The post-processor checks the unclosed bracketed comment. 
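A minimal sketch of the user-visible effect of the listener defined below (a `spark` session is assumed; the message text matches `unclosedBracketedCommentError` added later in this patch):

    // Now raises an explicit error at parse time:
    spark.sql("SELECT 1 /* this bracketed comment is never closed")
    // org.apache.spark.sql.catalyst.parser.ParseException: Unclosed bracketed comment(...)

    // SET statements are exempt, since a value such as a glob path may legitimately contain '/*':
    spark.sql("SET myPath=/a/*")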
+ */ +case class UnclosedCommentProcessor( + command: String, tokenStream: CommonTokenStream) extends SqlBaseBaseListener { + + override def exitSingleDataType(ctx: SqlBaseParser.SingleDataTypeContext): Unit = { + checkUnclosedComment(tokenStream, command) + } + + override def exitSingleExpression(ctx: SqlBaseParser.SingleExpressionContext): Unit = { + checkUnclosedComment(tokenStream, command) + } + + override def exitSingleTableIdentifier(ctx: SqlBaseParser.SingleTableIdentifierContext): Unit = { + checkUnclosedComment(tokenStream, command) + } + + override def exitSingleFunctionIdentifier( + ctx: SqlBaseParser.SingleFunctionIdentifierContext): Unit = { + checkUnclosedComment(tokenStream, command) + } + + override def exitSingleMultipartIdentifier( + ctx: SqlBaseParser.SingleMultipartIdentifierContext): Unit = { + checkUnclosedComment(tokenStream, command) + } + + override def exitSingleTableSchema(ctx: SqlBaseParser.SingleTableSchemaContext): Unit = { + checkUnclosedComment(tokenStream, command) + } + + override def exitQuery(ctx: SqlBaseParser.QueryContext): Unit = { + checkUnclosedComment(tokenStream, command) + } + + override def exitSingleStatement(ctx: SqlBaseParser.SingleStatementContext): Unit = { + // SET command uses a wildcard to match anything, and we shouldn't parse the comments, e.g. + // `SET myPath =/a/*`. + if (!ctx.statement().isInstanceOf[SqlBaseParser.SetConfigurationContext]) { + checkUnclosedComment(tokenStream, command) + } + } + + /** check `has_unclosed_bracketed_comment` to find out the unclosed bracketed comment. */ + private def checkUnclosedComment(tokenStream: CommonTokenStream, command: String) = { + assert(tokenStream.getTokenSource.isInstanceOf[SqlBaseLexer]) + val lexer = tokenStream.getTokenSource.asInstanceOf[SqlBaseLexer] + if (lexer.has_unclosed_bracketed_comment) { + // The last token is 'EOF' and the penultimate is unclosed bracketed comment + val failedToken = tokenStream.get(tokenStream.size() - 2) + assert(failedToken.getType() == SqlBaseParser.BRACKETED_COMMENT) + val position = Origin(Option(failedToken.getLine), Option(failedToken.getCharPositionInLine)) + throw QueryParsingErrors.unclosedBracketedCommentError(command, position) + } + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 7c31a00918..4aa7bf1c4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -41,7 +41,8 @@ abstract class LogicalPlan def metadataOutput: Seq[Attribute] = children.flatMap(_.metadataOutput) /** Returns true if this subtree has data from a streaming data source. 
*/ - def isStreaming: Boolean = children.exists(_.isStreaming) + def isStreaming: Boolean = _isStreaming + private[this] lazy val _isStreaming = children.exists(_.isStreaming) override def verboseStringWithSuffix(maxFields: Int): String = { super.verboseString(maxFields) + statsCache.map(", " + _.toString).getOrElse("") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala index ccc4e190ba..70c6f15290 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala @@ -123,25 +123,6 @@ object SerdeInfo { } } -/** - * A CREATE TABLE command, as parsed from SQL. - * - * This is a metadata-only command and is not used to write data to the created table. - */ -case class CreateTableStatement( - tableName: Seq[String], - tableSchema: StructType, - partitioning: Seq[Transform], - bucketSpec: Option[BucketSpec], - properties: Map[String, String], - provider: Option[String], - options: Map[String, String], - location: Option[String], - comment: Option[String], - serde: Option[SerdeInfo], - external: Boolean, - ifNotExists: Boolean) extends LeafParsedStatement - /** * A CREATE TABLE AS SELECT command, as parsed from SQL. */ @@ -184,29 +165,6 @@ case class ReplaceTableStatement( serde: Option[SerdeInfo], orCreate: Boolean) extends LeafParsedStatement -/** - * A REPLACE TABLE AS SELECT command, as parsed from SQL. - */ -case class ReplaceTableAsSelectStatement( - tableName: Seq[String], - asSelect: LogicalPlan, - partitioning: Seq[Transform], - bucketSpec: Option[BucketSpec], - properties: Map[String, String], - provider: Option[String], - options: Map[String, String], - location: Option[String], - comment: Option[String], - writeOptions: Map[String, String], - serde: Option[SerdeInfo], - orCreate: Boolean) extends UnaryParsedStatement { - - override def child: LogicalPlan = asSelect - override protected def withNewChildInternal( - newChild: LogicalPlan): ReplaceTableAsSelectStatement = copy(asSelect = newChild) -} - - /** * Column data as parsed by ALTER TABLE ... (ADD|REPLACE) COLUMNS. 
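The `isStreaming` change above caches the recursive child scan in a lazy val, so repeated calls on the same node no longer re-walk the subtree. A toy illustration of the pattern (not Spark's `LogicalPlan`, just the memoization idiom):

    final case class Node(children: Seq[Node], source: Boolean = false) {
      // Computed once on first access; later calls return the cached result.
      lazy val isStreaming: Boolean = source || children.exists(_.isStreaming)
    }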
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/UnionEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/UnionEstimation.scala index 5ae8d69f4b..091955b6b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/UnionEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/UnionEstimation.scala @@ -50,13 +50,19 @@ object UnionEstimation { case TimestampType => (a: Any, b: Any) => TimestampType.ordering.lt(a.asInstanceOf[TimestampType.InternalType], b.asInstanceOf[TimestampType.InternalType]) + case TimestampNTZType => (a: Any, b: Any) => + TimestampNTZType.ordering.lt(a.asInstanceOf[TimestampNTZType.InternalType], + b.asInstanceOf[TimestampNTZType.InternalType]) + case i: AnsiIntervalType => (a: Any, b: Any) => + i.ordering.lt(a.asInstanceOf[i.InternalType], b.asInstanceOf[i.InternalType]) case _ => throw new IllegalStateException(s"Unsupported data type: ${dt.catalogString}") } private def isTypeSupported(dt: DataType): Boolean = dt match { case ByteType | IntegerType | ShortType | FloatType | LongType | - DoubleType | DateType | _: DecimalType | TimestampType => true + DoubleType | DateType | _: DecimalType | TimestampType | TimestampNTZType | + _: AnsiIntervalType => true case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 4ed5d87aaf..d9e5dfe16b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, FieldName, NamedRelation, PartitionSpec, UnresolvedException} +import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, FieldName, NamedRelation, PartitionSpec, ResolvedDBObjectName, UnresolvedException} +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, FunctionResource} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.catalog.FunctionResource import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, Unevaluable} import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema import org.apache.spark.sql.catalyst.trees.BinaryLike @@ -193,13 +193,24 @@ trait V2CreateTablePlan extends LogicalPlan { /** * Create a new table with a v2 catalog. 
*/ -case class CreateV2Table( - catalog: TableCatalog, - tableName: Identifier, +case class CreateTable( + name: LogicalPlan, tableSchema: StructType, partitioning: Seq[Transform], - properties: Map[String, String], - ignoreIfExists: Boolean) extends LeafCommand with V2CreateTablePlan { + tableSpec: TableSpec, + ignoreIfExists: Boolean) extends UnaryCommand with V2CreateTablePlan { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper + + override def child: LogicalPlan = name + + override def tableName: Identifier = { + assert(child.resolved) + child.asInstanceOf[ResolvedDBObjectName].nameParts.asIdentifier + } + + override protected def withNewChildInternal(newChild: LogicalPlan): V2CreateTablePlan = + copy(name = newChild) + override def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan = { this.copy(partitioning = rewritten) } @@ -262,16 +273,17 @@ case class ReplaceTable( * If the table does not exist, and orCreate is false, then an exception will be thrown. */ case class ReplaceTableAsSelect( - catalog: TableCatalog, - tableName: Identifier, + name: LogicalPlan, partitioning: Seq[Transform], query: LogicalPlan, - properties: Map[String, String], + tableSpec: TableSpec, writeOptions: Map[String, String], - orCreate: Boolean) extends UnaryCommand with V2CreateTablePlan { + orCreate: Boolean) extends BinaryCommand with V2CreateTablePlan { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper override def tableSchema: StructType = query.schema - override def child: LogicalPlan = query + override def left: LogicalPlan = name + override def right: LogicalPlan = query override lazy val resolved: Boolean = childrenResolved && { // the table schema is created from the query schema, so the only resolution needed is to check @@ -280,12 +292,19 @@ case class ReplaceTableAsSelect( references.map(_.fieldNames).forall(query.schema.findNestedField(_).isDefined) } + override def tableName: Identifier = { + assert(name.resolved) + name.asInstanceOf[ResolvedDBObjectName].nameParts.asIdentifier + } + + override protected def withNewChildrenInternal( + newLeft: LogicalPlan, + newRight: LogicalPlan): LogicalPlan = + copy(name = newLeft, query = newRight) + override def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan = { this.copy(partitioning = rewritten) } - - override protected def withNewChildInternal(newChild: LogicalPlan): ReplaceTableAsSelect = - copy(query = newChild) } /** @@ -1090,3 +1109,13 @@ case class DropIndex( override protected def withNewChildInternal(newChild: LogicalPlan): DropIndex = copy(table = newChild) } + +case class TableSpec( + bucketSpec: Option[BucketSpec], + properties: Map[String, String], + provider: Option[String], + options: Map[String, String], + location: Option[String], + comment: Option[String], + serde: Option[SerdeInfo], + external: Boolean) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index 5ec303d97f..4face49462 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -76,6 +76,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowFrame" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowOrder" :: 
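With `CreateTable` and `ReplaceTableAsSelect` now carrying an `UnresolvedDBObjectName` child plus a `TableSpec`, the parser output for an RTAS statement looks roughly as sketched below (identifier, provider and comment are illustrative; the SELECT body is abbreviated to an unresolved projection):

    // REPLACE TABLE testcat.ns.t USING parquet COMMENT 'demo' AS SELECT * FROM src
    val spec = TableSpec(
      bucketSpec = None,
      properties = Map.empty,
      provider = Some("parquet"),
      options = Map.empty,
      location = None,
      comment = Some("demo"),
      serde = None,
      external = false)
    ReplaceTableAsSelect(
      UnresolvedDBObjectName(Seq("testcat", "ns", "t"), isNamespace = false),
      partitioning = Nil,
      query = Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("src"))),
      tableSpec = spec,
      writeOptions = Map.empty,
      orCreate = false)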
"org.apache.spark.sql.catalyst.analysis.Analyzer$WindowsSubstitution" :: + "org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion$AnsiCombinedTypeCoercionRule" :: "org.apache.spark.sql.catalyst.analysis.ApplyCharTypePadding" :: "org.apache.spark.sql.catalyst.analysis.DeduplicateRelations" :: "org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases" :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index f3f6744720..3d62cf2b83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection._ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.logical.TableSpec import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning} import org.apache.spark.sql.catalyst.rules.RuleId import org.apache.spark.sql.catalyst.rules.RuleIdCollection @@ -819,6 +820,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre redactMapString(map.asCaseSensitiveMap().asScala, maxFields) case map: Map[_, _] => redactMapString(map, maxFields) + case t: TableSpec => + t.copy(properties = Utils.redact(t.properties).toMap, + options = Utils.redact(t.options).toMap) :: Nil case table: CatalogTable => table.storage.serde match { case Some(serde) => table.identifier :: serde :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 3d9598cd0c..ebe5153099 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -442,17 +442,22 @@ object DateTimeUtils { /** * Trims and parses a given UTF8 string to a corresponding [[Long]] value which representing the - * number of microseconds since the epoch. The result is independent of time zones, - * which means that zone ID in the input string will be ignored. + * number of microseconds since the epoch. The result will be independent of time zones. + * + * If the input string contains a component associated with time zone, the method will return + * `None` if `failOnError` is set to `true`. If `failOnError` is set to `false`, the method + * will simply discard the time zone component. Enable the check to detect situations like parsing + * a timestamp with time zone as TimestampNTZType. + * * The return type is [[Option]] in order to distinguish between 0L and null. Please * refer to `parseTimestampString` for the allowed formats. */ - def stringToTimestampWithoutTimeZone(s: UTF8String): Option[Long] = { + def stringToTimestampWithoutTimeZone(s: UTF8String, failOnError: Boolean): Option[Long] = { try { - val (segments, _, justTime) = parseTimestampString(s) - // If the input string can't be parsed as a timestamp, or it contains only the time part of a - // timestamp and we can't determine its date, return None. 
- if (segments.isEmpty || justTime) { + val (segments, zoneIdOpt, justTime) = parseTimestampString(s) + // If the input string can't be parsed as a timestamp without time zone, or it contains only + // the time part of a timestamp and we can't determine its date, return None. + if (segments.isEmpty || justTime || failOnError && zoneIdOpt.isDefined) { return None } val nanoseconds = MICROSECONDS.toNanos(segments(6)) @@ -465,8 +470,19 @@ object DateTimeUtils { } } + /** + * Trims and parses a given UTF8 string to a corresponding [[Long]] value which representing the + * number of microseconds since the epoch. The result is independent of time zones. Zone id + * component will be discarded and ignored. + * The return type is [[Option]] in order to distinguish between 0L and null. Please + * refer to `parseTimestampString` for the allowed formats. + */ + def stringToTimestampWithoutTimeZone(s: UTF8String): Option[Long] = { + stringToTimestampWithoutTimeZone(s, false) + } + def stringToTimestampWithoutTimeZoneAnsi(s: UTF8String): Long = { - stringToTimestampWithoutTimeZone(s).getOrElse { + stringToTimestampWithoutTimeZone(s, false).getOrElse { throw QueryExecutionErrors.cannotCastToDateTimeError(s, TimestampNTZType) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala index 8a9104ae9e..21fd0860ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala @@ -31,9 +31,10 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.catalyst.util.LegacyDateFormats.{LegacyDateFormat, LENIENT_SIMPLE_DATE_FORMAT} import org.apache.spark.sql.catalyst.util.RebaseDateTime._ +import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy._ -import org.apache.spark.sql.types.Decimal +import org.apache.spark.sql.types.{Decimal, TimestampNTZType} import org.apache.spark.unsafe.types.UTF8String sealed trait TimestampFormatter extends Serializable { @@ -55,6 +56,7 @@ sealed trait TimestampFormatter extends Serializable { * Parses a timestamp in a string and converts it to microseconds since Unix Epoch in local time. * * @param s - string with timestamp to parse + * @param failOnError - indicates strict parsing of timezone * @return microseconds since epoch. * @throws ParseException can be thrown by legacy parser * @throws DateTimeParseException can be thrown by new parser @@ -66,10 +68,23 @@ sealed trait TimestampFormatter extends Serializable { @throws(classOf[DateTimeParseException]) @throws(classOf[DateTimeException]) @throws(classOf[IllegalStateException]) - def parseWithoutTimeZone(s: String): Long = + def parseWithoutTimeZone(s: String, failOnError: Boolean): Long = throw new IllegalStateException( - s"The method `parseWithoutTimeZone(s: String)` should be implemented in the formatter " + - "of timestamp without time zone") + s"The method `parseWithoutTimeZone(s: String, failOnError: Boolean)` should be " + + "implemented in the formatter of timestamp without time zone") + + /** + * Parses a timestamp in a string and converts it to microseconds since Unix Epoch in local time. + * Zone-id and zone-offset components are ignored. 
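The documented behaviour of the new `failOnError` flag, in a short sketch (mirrors the DateTimeUtilsSuite case added later in this patch):

    import org.apache.spark.unsafe.types.UTF8String
    import org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToTimestampWithoutTimeZone

    val withZone = UTF8String.fromString("2021-11-22 10:54:27 +08:00")
    stringToTimestampWithoutTimeZone(withZone, false) // Some(...): zone offset discarded, local wall-clock kept
    stringToTimestampWithoutTimeZone(withZone, true)  // None: strict mode rejects inputs that carry a zone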
+ */ + @throws(classOf[ParseException]) + @throws(classOf[DateTimeParseException]) + @throws(classOf[DateTimeException]) + @throws(classOf[IllegalStateException]) + final def parseWithoutTimeZone(s: String): Long = + // This is implemented to adhere to the original behaviour of `parseWithoutTimeZone` where we + // did not fail if timestamp contained zone-id or zone-offset component and instead ignored it. + parseWithoutTimeZone(s, false) def format(us: Long): String def format(ts: Timestamp): String @@ -118,9 +133,12 @@ class Iso8601TimestampFormatter( } catch checkParsedDiff(s, legacyFormatter.parse) } - override def parseWithoutTimeZone(s: String): Long = { + override def parseWithoutTimeZone(s: String, failOnError: Boolean): Long = { try { val parsed = formatter.parse(s) + if (failOnError && parsed.query(TemporalQueries.zone()) != null) { + throw QueryExecutionErrors.cannotParseStringAsDataTypeError(pattern, s, TimestampNTZType) + } val localDate = toLocalDate(parsed) val localTime = toLocalTime(parsed) DateTimeUtils.localDateTimeToMicros(LocalDateTime.of(localDate, localTime)) @@ -186,9 +204,13 @@ class DefaultTimestampFormatter( } catch checkParsedDiff(s, legacyFormatter.parse) } - override def parseWithoutTimeZone(s: String): Long = { + override def parseWithoutTimeZone(s: String, failOnError: Boolean): Long = { try { - DateTimeUtils.stringToTimestampWithoutTimeZoneAnsi(UTF8String.fromString(s)) + val utf8Value = UTF8String.fromString(s) + DateTimeUtils.stringToTimestampWithoutTimeZone(utf8Value, failOnError).getOrElse { + throw QueryExecutionErrors.cannotParseStringAsDataTypeError( + TimestampFormatter.defaultPattern(), s, TimestampNTZType) + } } catch checkParsedDiff(s, legacyFormatter.parse) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 33fe48d44d..e26f397bb0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -139,6 +139,7 @@ package object util extends Logging { PrettyAttribute(r.mkString(r.exprsReplaced.map(toPrettySQL)), r.dataType) case c: CastBase if !c.getTagValue(Cast.USER_SPECIFIED_CAST).getOrElse(false) => PrettyAttribute(usePrettyExpression(c.child).sql, c.dataType) + case p: PythonUDF => PrettyPythonUDF(p.name, p.dataType, p.children) } def quoteIdentifier(name: String): String = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index c010b6d59a..44e57f22ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -22,10 +22,9 @@ import java.util.Collections import scala.collection.JavaConverters._ -import org.apache.spark.sql.catalyst.analysis.{NamedRelation, NoSuchDatabaseException, NoSuchNamespaceException, NoSuchTableException} -import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelectStatement, CreateTableStatement, ReplaceTableAsSelectStatement, ReplaceTableStatement, SerdeInfo} +import org.apache.spark.sql.catalyst.analysis.{AsOfTimestamp, AsOfVersion, NamedRelation, NoSuchDatabaseException, NoSuchNamespaceException, NoSuchTableException, TimeTravelSpec} +import 
org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelectStatement, ReplaceTableStatement, SerdeInfo, TableSpec} import org.apache.spark.sql.connector.catalog.TableChange._ -import org.apache.spark.sql.connector.expressions.{AsOfTimestamp, AsOfVersion, TimeTravelSpec} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.types.{ArrayType, MapType, StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -306,11 +305,6 @@ private[sql] object CatalogV2Util { catalog.name().equalsIgnoreCase(CatalogManager.SESSION_CATALOG_NAME) } - def convertTableProperties(c: CreateTableStatement): Map[String, String] = { - convertTableProperties( - c.properties, c.options, c.serde, c.location, c.comment, c.provider, c.external) - } - def convertTableProperties(c: CreateTableAsSelectStatement): Map[String, String] = { convertTableProperties( c.properties, c.options, c.serde, c.location, c.comment, c.provider, c.external) @@ -320,8 +314,10 @@ private[sql] object CatalogV2Util { convertTableProperties(r.properties, r.options, r.serde, r.location, r.comment, r.provider) } - def convertTableProperties(r: ReplaceTableAsSelectStatement): Map[String, String] = { - convertTableProperties(r.properties, r.options, r.serde, r.location, r.comment, r.provider) + def convertTableProperties(t: TableSpec): Map[String, String] = { + val props = convertTableProperties( + t.properties, t.options, t.serde, t.location, t.comment, t.provider, t.external) + withDefaultOwnership(props) } private def convertTableProperties( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala index 32fe208871..2863d94d19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala @@ -19,11 +19,7 @@ package org.apache.spark.sql.connector.expressions import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, IntegerType, StringType} -import org.apache.spark.unsafe.types.UTF8String /** * Helper methods for working with the logical expressions API. 
@@ -361,25 +357,3 @@ private[sql] object SortValue { None } } - -private[sql] sealed trait TimeTravelSpec - -private[sql] case class AsOfTimestamp(timestamp: Long) extends TimeTravelSpec -private[sql] case class AsOfVersion(version: String) extends TimeTravelSpec - -private[sql] object TimeTravelSpec { - def create(timestamp: Option[String], version: Option[String]) : Option[TimeTravelSpec] = { - if (timestamp.nonEmpty && version.nonEmpty) { - throw QueryCompilationErrors.invalidTimeTravelSpecError() - } else if (timestamp.nonEmpty) { - val ts = DateTimeUtils.stringToTimestampAnsi( - UTF8String.fromString(timestamp.get), - DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone)) - Some(AsOfTimestamp(ts)) - } else if (version.nonEmpty) { - Some(AsOfVersion(version.get)) - } else { - None - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 839a888990..624ee88d3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -2379,6 +2379,19 @@ object QueryCompilationErrors { } def invalidTimeTravelSpecError(): Throwable = { - new AnalysisException("Cannot specify both version and timestamp when scanning the table.") + new AnalysisException( + "Cannot specify both version and timestamp when time travelling the table.") + } + + def invalidTimestampExprForTimeTravel(expr: Expression): Throwable = { + new AnalysisException(s"${expr.sql} is not a valid timestamp expression for time travel.") + } + + def viewNotSupportTimeTravelError(viewName: Seq[String]): Throwable = { + new AnalysisException(viewName.quoted + " is a view which does not support time travel.") + } + + def tableNotSupportTimeTravelError(tableName: Identifier): UnsupportedOperationException = { + new UnsupportedOperationException(s"Table $tableName does not support time travel.") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index ba3dd52435..ef809b78d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -948,7 +948,7 @@ object QueryExecutionErrors { def cannotMergeDecimalTypesWithIncompatibleScaleError( leftScale: Int, rightScale: Int): Throwable = { new SparkException("Failed to merge decimal types with incompatible " + - s"scala $leftScale and $rightScale") + s"scale $leftScale and $rightScale") } def cannotMergeIncompatibleDataTypesError(left: DataType, right: DataType): Throwable = { @@ -1034,6 +1034,13 @@ object QueryExecutionErrors { s"[$token] as target spark data type [$dataType].") } + def cannotParseStringAsDataTypeError(pattern: String, value: String, dataType: DataType) + : Throwable = { + new RuntimeException( + s"Cannot parse field value ${value} for pattern ${pattern} " + + s"as target spark data type [$dataType].") + } + def failToParseEmptyStringForDataTypeError(dataType: DataType): Throwable = { new RuntimeException( s"Failed to parse an empty string for data type ${dataType.catalogString}") @@ -1894,4 +1901,3 @@ object QueryExecutionErrors { new RuntimeException("Unable to convert timestamp of Orc to data type 'timestamp_ntz'") } } - diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala index 090f73d192..70678ec18a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala @@ -248,10 +248,6 @@ object QueryParsingErrors { new ParseException("Either PROPERTIES or DBPROPERTIES is allowed.", ctx) } - def fromOrInNotAllowedInShowDatabasesError(ctx: ShowNamespacesContext): Throwable = { - new ParseException(s"FROM/IN operator is not allowed in SHOW DATABASES", ctx) - } - def cannotCleanReservedTablePropertyError( property: String, ctx: ParserRuleContext, msg: String): Throwable = { new ParseException(s"$property is a reserved table property, $msg.", ctx) @@ -328,7 +324,7 @@ object QueryParsingErrors { new ParseException(errorClass = "DUPLICATE_KEY", messageParameters = Array(key), ctx) } - def unexpectedFomatForSetConfigurationError(ctx: SetConfigurationContext): Throwable = { + def unexpectedFomatForSetConfigurationError(ctx: ParserRuleContext): Throwable = { new ParseException( s""" |Expected format is 'SET', 'SET key', or 'SET key=value'. If you want to include @@ -338,13 +334,13 @@ object QueryParsingErrors { } def invalidPropertyKeyForSetQuotedConfigurationError( - keyCandidate: String, valueStr: String, ctx: SetQuotedConfigurationContext): Throwable = { + keyCandidate: String, valueStr: String, ctx: ParserRuleContext): Throwable = { new ParseException(s"'$keyCandidate' is an invalid property key, please " + s"use quotes, e.g. SET `$keyCandidate`=`$valueStr`", ctx) } def invalidPropertyValueForSetQuotedConfigurationError( - valueCandidate: String, keyStr: String, ctx: SetQuotedConfigurationContext): Throwable = { + valueCandidate: String, keyStr: String, ctx: ParserRuleContext): Throwable = { new ParseException(s"'$valueCandidate' is an invalid property value, please " + s"use quotes, e.g. 
SET `$keyStr`=`$valueCandidate`", ctx) } @@ -425,4 +421,12 @@ object QueryParsingErrors { new ParseException( s"Specifying a database in CREATE TEMPORARY FUNCTION is not allowed: '$databaseName'", ctx) } + + def unclosedBracketedCommentError(command: String, position: Origin): Throwable = { + new ParseException(Some(command), "Unclosed bracketed comment", position, position) + } + + def invalidTimeTravelSpec(reason: String, ctx: ParserRuleContext): Throwable = { + new ParseException(s"Invalid time travel spec: $reason.", ctx) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index c7535c3ceb..f2cab8a6b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -662,6 +662,15 @@ object SQLConf { .booleanConf .createWithDefault(true) + val ADAPTIVE_REBALANCE_PARTITIONS_SMALL_PARTITION_FACTOR = + buildConf("spark.sql.adaptive.rebalancePartitionsSmallPartitionFactor") + .doc(s"A partition will be merged during splitting if its size is small than this factor " + + s"multiply ${ADVISORY_PARTITION_SIZE_IN_BYTES.key}.") + .version("3.3.0") + .doubleConf + .checkValue(v => v > 0 && v < 1, "the factor must be in (0, 1)") + .createWithDefault(0.2) + val ADAPTIVE_FORCE_OPTIMIZE_SKEWED_JOIN = buildConf("spark.sql.adaptive.forceOptimizeSkewedJoin") .doc("When true, force enable OptimizeSkewedJoin even if it introduces extra shuffle.") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 96262f5afb..d30bcd5af5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -502,7 +502,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { } // stream-stream inner join doesn't emit late rows, whereas outer joins could - Seq((Inner, false), (LeftOuter, true), (RightOuter, true)).map { + Seq((Inner, false), (LeftOuter, true), (RightOuter, true)).foreach { case (joinType, expectFailure) => assertPassOnGlobalWatermarkLimit( s"single $joinType join in Append mode", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 823ce77489..443a94b2ee 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -1019,4 +1019,13 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } } + + test("SPARK-37508: Support contains string expression") { + checkEvaluation(Contains(Literal("aa"), Literal.create(null, StringType)), null) + checkEvaluation(Contains(Literal.create(null, StringType), Literal("aa")), null) + checkEvaluation(Contains(Literal("Spark SQL"), Literal("Spark")), true) + checkEvaluation(Contains(Literal("Spark SQL"), Literal("SPARK")), false) + checkEvaluation(Contains(Literal("Spark SQL"), Literal("SQL")), true) + checkEvaluation(Contains(Literal("Spark SQL"), Literal("k S")), true) + } } diff --git 
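The new `spark.sql.adaptive.rebalancePartitionsSmallPartitionFactor` conf introduced above is a plain double-valued SQL conf restricted to (0, 1); a minimal usage sketch (a `spark` session is assumed):

    // Merge a partition produced while splitting during rebalance if it is smaller than
    // factor * spark.sql.adaptive.advisoryPartitionSizeInBytes; the default factor is 0.2.
    spark.conf.set("spark.sql.adaptive.rebalancePartitionsSmallPartitionFactor", "0.5")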
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 62b611a128..182f028b5c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.parser -import java.time.DateTimeException -import java.util import java.util.Locale import org.apache.spark.sql.AnalysisException @@ -27,10 +25,9 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions.{EqualTo, Hex, Literal} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition.{after, first} -import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, TimeTravelSpec, Transform, YearsTransform} +import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType, TimestampType} -import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class DDLParserSuite extends AnalysisTest { @@ -720,12 +717,12 @@ class DDLParserSuite extends AnalysisTest { val parsedPlan = parsePlan(sqlStatement) val newTableToken = sqlStatement.split(" ")(0).trim.toUpperCase(Locale.ROOT) parsedPlan match { - case create: CreateTableStatement if newTableToken == "CREATE" => - assert(create.ifNotExists == expectedIfNotExists) + case create: CreateTable if newTableToken == "CREATE" => + assert(create.ignoreIfExists == expectedIfNotExists) case ctas: CreateTableAsSelectStatement if newTableToken == "CREATE" => assert(ctas.ifNotExists == expectedIfNotExists) case replace: ReplaceTableStatement if newTableToken == "REPLACE" => - case replace: ReplaceTableAsSelectStatement if newTableToken == "REPLACE" => + case replace: ReplaceTableAsSelect if newTableToken == "REPLACE" => case other => fail("First token in statement does not match the expected parsed plan; CREATE TABLE" + " should create a CreateTableStatement, and REPLACE TABLE should create a" + @@ -1832,23 +1829,6 @@ class DDLParserSuite extends AnalysisTest { UnresolvedNamespace(Seq("a", "b", "c")), Map("b" -> "b"))) } - test("set namespace location") { - comparePlans( - parsePlan("ALTER DATABASE a.b.c SET LOCATION '/home/user/db'"), - SetNamespaceLocation( - UnresolvedNamespace(Seq("a", "b", "c")), "/home/user/db")) - - comparePlans( - parsePlan("ALTER SCHEMA a.b.c SET LOCATION '/home/user/db'"), - SetNamespaceLocation( - UnresolvedNamespace(Seq("a", "b", "c")), "/home/user/db")) - - comparePlans( - parsePlan("ALTER NAMESPACE a.b.c SET LOCATION '/home/user/db'"), - SetNamespaceLocation( - UnresolvedNamespace(Seq("a", "b", "c")), "/home/user/db")) - } - test("analyze table statistics") { comparePlans(parsePlan("analyze table a.b.c compute statistics"), AnalyzeTable( @@ -2305,19 +2285,19 @@ class DDLParserSuite extends AnalysisTest { private object TableSpec { def apply(plan: LogicalPlan): TableSpec = { plan match { - case create: CreateTableStatement => + 
case create: CreateTable => TableSpec( - create.tableName, + create.name.asInstanceOf[UnresolvedDBObjectName].nameParts, Some(create.tableSchema), create.partitioning, - create.bucketSpec, - create.properties, - create.provider, - create.options, - create.location, - create.comment, - create.serde, - create.external) + create.tableSpec.bucketSpec, + create.tableSpec.properties, + create.tableSpec.provider, + create.tableSpec.options, + create.tableSpec.location, + create.tableSpec.comment, + create.tableSpec.serde, + create.tableSpec.external) case replace: ReplaceTableStatement => TableSpec( replace.tableName, @@ -2343,18 +2323,18 @@ class DDLParserSuite extends AnalysisTest { ctas.comment, ctas.serde, ctas.external) - case rtas: ReplaceTableAsSelectStatement => + case rtas: ReplaceTableAsSelect => TableSpec( - rtas.tableName, - Some(rtas.asSelect).filter(_.resolved).map(_.schema), + rtas.name.asInstanceOf[UnresolvedDBObjectName].nameParts, + Some(rtas.query).filter(_.resolved).map(_.schema), rtas.partitioning, - rtas.bucketSpec, - rtas.properties, - rtas.provider, - rtas.options, - rtas.location, - rtas.comment, - rtas.serde) + rtas.tableSpec.bucketSpec, + rtas.tableSpec.properties, + rtas.tableSpec.provider, + rtas.tableSpec.options, + rtas.tableSpec.location, + rtas.tableSpec.comment, + rtas.tableSpec.serde) case other => fail(s"Expected to parse Create, CTAS, Replace, or RTAS plan" + s" from query, got ${other.getClass.getName}.") @@ -2428,117 +2408,4 @@ class DDLParserSuite extends AnalysisTest { comparePlans(parsePlan(timestampTypeSql), insertPartitionPlan(timestamp)) comparePlans(parsePlan(binaryTypeSql), insertPartitionPlan(binaryStr)) } - - test("as of syntax") { - val properties = new util.HashMap[String, String] - var timeTravel = TimeTravelSpec.create(None, Some("Snapshot123456789")) - comparePlans( - parsePlan("SELECT * FROM a.b.c VERSION AS OF 'Snapshot123456789'"), - Project(Seq(UnresolvedStar(None)), - UnresolvedRelation( - Seq("a", "b", "c"), - new CaseInsensitiveStringMap(properties), - timeTravelSpec = timeTravel))) - - comparePlans( - parsePlan("SELECT * FROM a.b.c FOR VERSION AS OF 'Snapshot123456789'"), - Project(Seq(UnresolvedStar(None)), - UnresolvedRelation( - Seq("a", "b", "c"), - new CaseInsensitiveStringMap(properties), - timeTravelSpec = timeTravel))) - - timeTravel = TimeTravelSpec.create(None, Some("123456789")) - comparePlans( - parsePlan("SELECT * FROM a.b.c FOR SYSTEM_VERSION AS OF 123456789"), - Project(Seq(UnresolvedStar(None)), - UnresolvedRelation( - Seq("a", "b", "c"), - new CaseInsensitiveStringMap(properties), - timeTravelSpec = timeTravel))) - - comparePlans( - parsePlan("SELECT * FROM a.b.c SYSTEM_VERSION AS OF 123456789"), - Project(Seq(UnresolvedStar(None)), - UnresolvedRelation( - Seq("a", "b", "c"), - new CaseInsensitiveStringMap(properties), - timeTravelSpec = timeTravel))) - - timeTravel = TimeTravelSpec.create(Some("2019-01-29 00:37:58"), None) - comparePlans( - parsePlan("SELECT * FROM a.b.c TIMESTAMP AS OF '2019-01-29 00:37:58'"), - Project(Seq(UnresolvedStar(None)), - UnresolvedRelation( - Seq("a", "b", "c"), - new CaseInsensitiveStringMap(properties), - timeTravelSpec = timeTravel))) - - comparePlans( - parsePlan("SELECT * FROM a.b.c FOR TIMESTAMP AS OF '2019-01-29 00:37:58'"), - Project(Seq(UnresolvedStar(None)), - UnresolvedRelation( - Seq("a", "b", "c"), - new CaseInsensitiveStringMap(properties), - timeTravelSpec = timeTravel))) - - comparePlans( - parsePlan("SELECT * FROM a.b.c FOR SYSTEM_TIME AS OF '2019-01-29 00:37:58'"), - 
Project(Seq(UnresolvedStar(None)), - UnresolvedRelation( - Seq("a", "b", "c"), - new CaseInsensitiveStringMap(properties), - timeTravelSpec = timeTravel))) - - comparePlans( - parsePlan("SELECT * FROM a.b.c SYSTEM_TIME AS OF '2019-01-29 00:37:58'"), - Project(Seq(UnresolvedStar(None)), - UnresolvedRelation( - Seq("a", "b", "c"), - new CaseInsensitiveStringMap(properties), - timeTravelSpec = timeTravel))) - - timeTravel = TimeTravelSpec.create(Some("2019-01-29"), None) - comparePlans( - parsePlan("SELECT * FROM a.b.c TIMESTAMP AS OF '2019-01-29'"), - Project(Seq(UnresolvedStar(None)), - UnresolvedRelation( - Seq("a", "b", "c"), - new CaseInsensitiveStringMap(properties), - timeTravelSpec = timeTravel))) - - comparePlans( - parsePlan("SELECT * FROM a.b.c FOR TIMESTAMP AS OF '2019-01-29'"), - Project(Seq(UnresolvedStar(None)), - UnresolvedRelation( - Seq("a", "b", "c"), - new CaseInsensitiveStringMap(properties), - timeTravelSpec = timeTravel))) - - comparePlans( - parsePlan("SELECT * FROM a.b.c FOR SYSTEM_TIME AS OF '2019-01-29'"), - Project(Seq(UnresolvedStar(None)), - UnresolvedRelation( - Seq("a", "b", "c"), - new CaseInsensitiveStringMap(properties), - timeTravelSpec = timeTravel))) - - comparePlans( - parsePlan("SELECT * FROM a.b.c SYSTEM_TIME AS OF '2019-01-29'"), - Project(Seq(UnresolvedStar(None)), - UnresolvedRelation( - Seq("a", "b", "c"), - new CaseInsensitiveStringMap(properties), - timeTravelSpec = timeTravel))) - - val e1 = intercept[DateTimeException] { - parsePlan("SELECT * FROM a.b.c TIMESTAMP AS OF '2019-01-11111'") - }.getMessage - assert(e1.contains("Cannot cast 2019-01-11111 to TimestampType.")) - - val e2 = intercept[AnalysisException] { - timeTravel = TimeTravelSpec.create(Some("2019-01-29 00:37:58"), Some("123456789")) - }.getMessage - assert(e2.contains("Cannot specify both version and timestamp when scanning the table.")) - } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index ebafb9db1b..76c620d44d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction} +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, RelationTimeTravel, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedInlineTable, UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -160,6 +160,56 @@ class PlanParserSuite extends AnalysisTest { """.stripMargin, plan) } + test("nested bracketed comment case seven") { + val plan = OneRowRelation().select(Literal(1).as("a")) + assertEqual( + """ + |/*abc*/ + |select 1 as a + |/* + | + |2 as b + |/*abc */ + |, 3 as c + | + |/**/ + |*/ + """.stripMargin, plan) + } + + test("unclosed bracketed comment one") { + val query = """ + |/*abc*/ + |select 1 as a + |/* + | + 
|2 as b + |/*abc */ + |, 3 as c + | + |/**/ + |""".stripMargin + val e = intercept[ParseException](parsePlan(query)) + assert(e.getMessage.contains(s"Unclosed bracketed comment")) + } + + test("unclosed bracketed comment two") { + val query = """ + |/*abc*/ + |select 1 as a + |/* + | + |2 as b + |/*abc */ + |, 3 as c + | + |/**/ + |select 4 as d + |""".stripMargin + val e = intercept[ParseException](parsePlan(query)) + assert(e.getMessage.contains(s"Unclosed bracketed comment")) + } + test("case insensitive") { val plan = table("a").select(star()) assertEqual("sELEct * FroM a", plan) @@ -1203,4 +1253,50 @@ class PlanParserSuite extends AnalysisTest { """.stripMargin, "TRANSFORM with serde is only supported in hive mode") } + + + test("as of syntax") { + def testVersion(version: String, plan: LogicalPlan): Unit = { + Seq("VERSION", "SYSTEM_VERSION").foreach { keyword => + comparePlans(parsePlan(s"SELECT * FROM a.b.c $keyword AS OF $version"), plan) + comparePlans(parsePlan(s"SELECT * FROM a.b.c FOR $keyword AS OF $version"), plan) + } + } + + testVersion("'Snapshot123456789'", Project(Seq(UnresolvedStar(None)), + RelationTimeTravel( + UnresolvedRelation(Seq("a", "b", "c")), + None, + Some("Snapshot123456789")))) + + testVersion("123456789", Project(Seq(UnresolvedStar(None)), + RelationTimeTravel( + UnresolvedRelation(Seq("a", "b", "c")), + None, + Some("123456789")))) + + def testTimestamp(timestamp: String, plan: LogicalPlan): Unit = { + Seq("TIMESTAMP", "SYSTEM_TIME").foreach { keyword => + comparePlans(parsePlan(s"SELECT * FROM a.b.c $keyword AS OF $timestamp"), plan) + comparePlans(parsePlan(s"SELECT * FROM a.b.c FOR $keyword AS OF $timestamp"), plan) + } + } + + testTimestamp("'2019-01-29 00:37:58'", Project(Seq(UnresolvedStar(None)), + RelationTimeTravel( + UnresolvedRelation(Seq("a", "b", "c")), + Some(Literal("2019-01-29 00:37:58")), + None))) + + testTimestamp("current_date()", Project(Seq(UnresolvedStar(None)), + RelationTimeTravel( + UnresolvedRelation(Seq("a", "b", "c")), + Some(UnresolvedFunction(Seq("current_date"), Nil, isDistinct = false)), + None))) + + intercept("SELECT * FROM a.b.c TIMESTAMP AS OF col", + "timestamp expression cannot refer to any columns") + intercept("SELECT * FROM a.b.c TIMESTAMP AS OF (select 1)", + "timestamp expression cannot contain subqueries") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/UnionEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/UnionEstimationSuite.scala index 12b7de694b..e7041a7136 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/UnionEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/UnionEstimationSuite.scala @@ -57,6 +57,9 @@ class UnionEstimationSuite extends StatsEstimationTestBase { val attrDecimal = AttributeReference("cdecimal", DecimalType(5, 4))() val attrDate = AttributeReference("cdate", DateType)() val attrTimestamp = AttributeReference("ctimestamp", TimestampType)() + val attrTimestampNTZ = AttributeReference("ctimestamp_ntz", TimestampNTZType)() + val attrYMInterval = AttributeReference("cyminterval", YearMonthIntervalType())() + val attrDTInterval = AttributeReference("cdtinterval", DayTimeIntervalType())() val s1 = 1.toShort val s2 = 4.toShort @@ -84,7 +87,10 @@ class UnionEstimationSuite extends StatsEstimationTestBase { attrFloat -> ColumnStat(min = Some(1.1f), max = Some(4.1f)), attrDecimal -> ColumnStat(min = Some(Decimal(13.5)), 
max = Some(Decimal(19.5))), attrDate -> ColumnStat(min = Some(1), max = Some(4)), - attrTimestamp -> ColumnStat(min = Some(1L), max = Some(4L)))) + attrTimestamp -> ColumnStat(min = Some(1L), max = Some(4L)), + attrTimestampNTZ -> ColumnStat(min = Some(1L), max = Some(4L)), + attrYMInterval -> ColumnStat(min = Some(2), max = Some(5)), + attrDTInterval -> ColumnStat(min = Some(2L), max = Some(5L)))) val s3 = 2.toShort val s4 = 6.toShort @@ -118,7 +124,16 @@ class UnionEstimationSuite extends StatsEstimationTestBase { AttributeReference("cdate1", DateType)() -> ColumnStat(min = Some(3), max = Some(6)), AttributeReference("ctimestamp1", TimestampType)() -> ColumnStat( min = Some(3L), - max = Some(6L)))) + max = Some(6L)), + AttributeReference("ctimestamp_ntz1", TimestampNTZType)() -> ColumnStat( + min = Some(3L), + max = Some(6L)), + AttributeReference("cymtimestamp1", YearMonthIntervalType())() -> ColumnStat( + min = Some(4), + max = Some(8)), + AttributeReference("cdttimestamp1", DayTimeIntervalType())() -> ColumnStat( + min = Some(4L), + max = Some(8L)))) val child1 = StatsTestPlan( outputList = columnInfo.keys.toSeq.sortWith(_.exprId.id < _.exprId.id), @@ -147,7 +162,10 @@ class UnionEstimationSuite extends StatsEstimationTestBase { attrFloat -> ColumnStat(min = Some(1.1f), max = Some(6.1f)), attrDecimal -> ColumnStat(min = Some(Decimal(13.5)), max = Some(Decimal(19.9))), attrDate -> ColumnStat(min = Some(1), max = Some(6)), - attrTimestamp -> ColumnStat(min = Some(1L), max = Some(6L))))) + attrTimestamp -> ColumnStat(min = Some(1L), max = Some(6L)), + attrTimestampNTZ -> ColumnStat(min = Some(1L), max = Some(6L)), + attrYMInterval -> ColumnStat(min = Some(2), max = Some(8)), + attrDTInterval -> ColumnStat(min = Some(2L), max = Some(8L))))) assert(union.stats === expectedStats) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 69bb6c141a..422a6cdeda 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -357,6 +357,18 @@ class DateTimeUtilsSuite extends SparkFunSuite with Matchers with SQLHelper { checkStringToTimestamp("2021-01-01T12:30:4294967297+4294967297:30", None) } + test("SPARK-37326: stringToTimestampWithoutTimeZone with failOnError") { + assert( + stringToTimestampWithoutTimeZone( + UTF8String.fromString("2021-11-22 10:54:27 +08:00"), false) == + Some(DateTimeUtils.localDateTimeToMicros(LocalDateTime.of(2021, 11, 22, 10, 54, 27)))) + + assert( + stringToTimestampWithoutTimeZone( + UTF8String.fromString("2021-11-22 10:54:27 +08:00"), true) == + None) + } + test("SPARK-15379: special invalid date string") { // Test stringToDate assert(toDate("2015-02-29 00:00:00").isEmpty) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala index 3ebeacfa82..fad6fe5fbe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -173,8 +173,8 @@ class InMemoryTable( partitionSchema: StructType, from: Seq[Any], to: Seq[Any]): Boolean = { - val rows = dataMap.remove(from).getOrElse(new BufferedRows(from.mkString("/"))) - val newRows = new 
BufferedRows(to.mkString("/")) + val rows = dataMap.remove(from).getOrElse(new BufferedRows(from)) + val newRows = new BufferedRows(to) rows.rows.foreach { r => val newRow = new GenericInternalRow(r.numFields) for (i <- 0 until r.numFields) newRow.update(i, r.get(i, schema(i).dataType)) @@ -197,7 +197,7 @@ class InMemoryTable( protected def createPartitionKey(key: Seq[Any]): Unit = dataMap.synchronized { if (!dataMap.contains(key)) { - val emptyRows = new BufferedRows(key.toArray.mkString("/")) + val emptyRows = new BufferedRows(key) val rows = if (key.length == schema.length) { emptyRows.withRow(InternalRow.fromSeq(key)) } else emptyRows @@ -215,7 +215,7 @@ class InMemoryTable( val key = getKey(row) dataMap += dataMap.get(key) .map(key -> _.withRow(row)) - .getOrElse(key -> new BufferedRows(key.toArray.mkString("/")).withRow(row)) + .getOrElse(key -> new BufferedRows(key).withRow(row)) addPartitionKey(key) }) this @@ -290,7 +290,7 @@ class InMemoryTable( case In(attrName, values) if attrName == partitioning.head.name => val matchingKeys = values.map(_.toString).toSet data = data.filter(partition => { - val key = partition.asInstanceOf[BufferedRows].key + val key = partition.asInstanceOf[BufferedRows].keyString matchingKeys.contains(key) }) @@ -508,8 +508,8 @@ object InMemoryTable { } } -class BufferedRows( - val key: String = "") extends WriterCommitMessage with InputPartition with Serializable { +class BufferedRows(val key: Seq[Any] = Seq.empty) extends WriterCommitMessage + with InputPartition with HasPartitionKey with Serializable { val rows = new mutable.ArrayBuffer[InternalRow]() def withRow(row: InternalRow): BufferedRows = { @@ -517,6 +517,12 @@ class BufferedRows( this } + def keyString(): String = key.toArray.mkString("/") + + override def partitionKey(): InternalRow = { + InternalRow.fromSeq(key) + } + def clear(): Unit = rows.clear() } @@ -538,7 +544,7 @@ private class BufferedRowsReader( private def addMetadata(row: InternalRow): InternalRow = { val metadataRow = new GenericInternalRow(metadataColumnNames.map { case "index" => index - case "_partition" => UTF8String.fromString(partition.key) + case "_partition" => UTF8String.fromString(partition.keyString) }.toArray) new JoinedRow(row, metadataRow) } diff --git a/sql/core/benchmarks/DataSourceReadBenchmark-jdk11-results.txt b/sql/core/benchmarks/DataSourceReadBenchmark-jdk11-results.txt index 6578d5664c..c4cffd67b1 100644 --- a/sql/core/benchmarks/DataSourceReadBenchmark-jdk11-results.txt +++ b/sql/core/benchmarks/DataSourceReadBenchmark-jdk11-results.txt @@ -2,251 +2,269 @@ SQL Single Numeric Column Scan ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz -SQL Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +SQL Single BOOLEAN Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 13405 13422 24 1.2 852.3 1.0X -SQL Json 10723 10788 92 1.5 681.7 1.3X -SQL Parquet Vectorized 164 217 50 95.9 10.4 81.8X -SQL Parquet MR 2349 2440 129 6.7 149.3 5.7X -SQL ORC Vectorized 312 346 23 50.4 19.8 43.0X -SQL ORC MR 1610 1659 69 9.8 102.4 8.3X +SQL CSV 9999 10058 83 1.6 635.7 
1.0X +SQL Json 8857 8883 37 1.8 563.1 1.1X +SQL Parquet Vectorized 132 157 16 119.0 8.4 75.7X +SQL Parquet MR 1987 1997 14 7.9 126.3 5.0X +SQL ORC Vectorized 186 227 34 84.3 11.9 53.6X +SQL ORC MR 1559 1602 62 10.1 99.1 6.4X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +Parquet Reader Single BOOLEAN Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------- +ParquetReader Vectorized 110 117 9 143.0 7.0 1.0X +ParquetReader Vectorized -> Row 57 59 3 276.2 3.6 1.9X -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +SQL Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +SQL CSV 12897 12916 28 1.2 819.9 1.0X +SQL Json 9739 9770 44 1.6 619.2 1.3X +SQL Parquet Vectorized 226 237 14 69.7 14.3 57.2X +SQL Parquet MR 2124 2127 4 7.4 135.1 6.1X +SQL ORC Vectorized 213 250 39 73.9 13.5 60.6X +SQL ORC MR 1535 1548 19 10.2 97.6 8.4X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Parquet Reader Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 187 209 20 84.3 11.9 1.0X -ParquetReader Vectorized -> Row 89 95 5 177.6 5.6 2.1X +ParquetReader Vectorized 259 269 15 60.6 16.5 1.0X +ParquetReader Vectorized -> Row 168 184 33 93.9 10.7 1.5X -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single SMALLINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 14214 14549 474 1.1 903.7 1.0X -SQL Json 11866 11934 95 1.3 754.4 1.2X -SQL Parquet Vectorized 294 342 53 53.6 18.7 48.4X -SQL Parquet MR 2929 3004 107 5.4 186.2 4.9X -SQL ORC Vectorized 312 328 15 50.4 19.8 45.5X -SQL ORC MR 2037 2097 84 7.7 129.5 7.0X - -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +SQL CSV 12765 12774 13 1.2 811.6 1.0X +SQL Json 10144 10158 21 1.6 644.9 1.3X +SQL Parquet Vectorized 168 208 34 93.7 10.7 76.1X +SQL Parquet MR 2443 2458 21 6.4 155.3 5.2X +SQL ORC Vectorized 300 313 16 52.4 19.1 42.5X +SQL ORC MR 1736 1780 62 9.1 110.4 7.4X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Parquet Reader Single SMALLINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 249 266 18 63.1 15.8 1.0X -ParquetReader Vectorized -> Row 192 247 36 82.1 12.2 1.3X +ParquetReader Vectorized 229 239 9 68.6 14.6 1.0X 
+ParquetReader Vectorized -> Row 224 265 26 70.2 14.3 1.0X -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single INT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 15502 15817 446 1.0 985.6 1.0X -SQL Json 12638 12646 11 1.2 803.5 1.2X -SQL Parquet Vectorized 193 256 44 81.7 12.2 80.5X -SQL Parquet MR 2943 2953 14 5.3 187.1 5.3X -SQL ORC Vectorized 324 370 34 48.5 20.6 47.8X -SQL ORC MR 2110 2163 75 7.5 134.1 7.3X - -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +SQL CSV 14055 14060 6 1.1 893.6 1.0X +SQL Json 10692 10738 64 1.5 679.8 1.3X +SQL Parquet Vectorized 167 223 34 94.0 10.6 84.0X +SQL Parquet MR 2416 2482 94 6.5 153.6 5.8X +SQL ORC Vectorized 329 344 12 47.8 20.9 42.7X +SQL ORC MR 1773 1789 23 8.9 112.7 7.9X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Parquet Reader Single INT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 276 287 14 57.0 17.6 1.0X -ParquetReader Vectorized -> Row 309 320 9 50.9 19.6 0.9X +ParquetReader Vectorized 232 239 9 67.9 14.7 1.0X +ParquetReader Vectorized -> Row 262 295 23 60.1 16.6 0.9X -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single BIGINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 20156 20694 761 0.8 1281.5 1.0X -SQL Json 15228 15380 214 1.0 968.2 1.3X -SQL Parquet Vectorized 325 346 20 48.4 20.7 62.0X -SQL Parquet MR 3144 3228 118 5.0 199.9 6.4X -SQL ORC Vectorized 516 526 7 30.5 32.8 39.0X -SQL ORC MR 2353 2367 19 6.7 149.6 8.6X - -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +SQL CSV 18964 18975 17 0.8 1205.7 1.0X +SQL Json 13173 13189 23 1.2 837.5 1.4X +SQL Parquet Vectorized 278 290 11 56.6 17.7 68.2X +SQL Parquet MR 2565 2589 34 6.1 163.1 7.4X +SQL ORC Vectorized 432 481 48 36.4 27.5 43.9X +SQL ORC MR 2052 2061 12 7.7 130.5 9.2X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Parquet Reader Single BIGINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 372 396 24 42.3 23.6 1.0X -ParquetReader Vectorized -> Row 437 462 25 36.0 27.8 0.9X +ParquetReader Vectorized 296 321 29 53.2 18.8 1.0X +ParquetReader Vectorized -> Row 329 335 7 47.7 20.9 0.9X -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 
8272CL CPU @ 2.60GHz SQL Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 17413 17599 263 0.9 1107.1 1.0X -SQL Json 14416 14453 53 1.1 916.5 1.2X -SQL Parquet Vectorized 181 225 35 86.8 11.5 96.1X -SQL Parquet MR 2940 2996 78 5.3 186.9 5.9X -SQL ORC Vectorized 470 494 29 33.5 29.9 37.1X -SQL ORC MR 2351 2379 39 6.7 149.5 7.4X - -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +SQL CSV 15092 15095 5 1.0 959.5 1.0X +SQL Json 12166 12169 5 1.3 773.5 1.2X +SQL Parquet Vectorized 161 198 27 97.4 10.3 93.5X +SQL Parquet MR 2407 2412 6 6.5 153.0 6.3X +SQL ORC Vectorized 476 509 30 33.1 30.2 31.7X +SQL ORC MR 1978 1981 5 8.0 125.7 7.6X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Parquet Reader Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 268 282 14 58.7 17.0 1.0X -ParquetReader Vectorized -> Row 298 321 18 52.8 18.9 0.9X +ParquetReader Vectorized 256 261 9 61.4 16.3 1.0X +ParquetReader Vectorized -> Row 210 257 22 74.7 13.4 1.2X -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 21666 21697 43 0.7 1377.5 1.0X -SQL Json 18307 18363 79 0.9 1163.9 1.2X -SQL Parquet Vectorized 310 337 22 50.7 19.7 69.9X -SQL Parquet MR 3089 3103 19 5.1 196.4 7.0X -SQL ORC Vectorized 589 617 31 26.7 37.5 36.8X -SQL ORC MR 2307 2377 98 6.8 146.7 9.4X - -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +SQL CSV 19785 19786 1 0.8 1257.9 1.0X +SQL Json 16339 16340 1 1.0 1038.8 1.2X +SQL Parquet Vectorized 284 302 19 55.4 18.1 69.7X +SQL Parquet MR 2570 2576 8 6.1 163.4 7.7X +SQL ORC Vectorized 473 519 32 33.3 30.0 41.9X +SQL ORC MR 2136 2142 9 7.4 135.8 9.3X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Parquet Reader Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 400 415 18 39.3 25.4 1.0X -ParquetReader Vectorized -> Row 393 406 11 40.1 25.0 1.0X +ParquetReader Vectorized 298 351 32 52.8 18.9 1.0X +ParquetReader Vectorized -> Row 370 375 9 42.5 23.5 0.8X ================================================================================================ Int and String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Int and String Scan: Best 
Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 17703 17719 22 0.6 1688.3 1.0X -SQL Json 13095 13168 103 0.8 1248.9 1.4X -SQL Parquet Vectorized 2253 2266 19 4.7 214.8 7.9X -SQL Parquet MR 4913 4977 91 2.1 468.5 3.6X -SQL ORC Vectorized 2457 2467 14 4.3 234.3 7.2X -SQL ORC MR 4433 4464 44 2.4 422.8 4.0X +SQL CSV 13811 13824 18 0.8 1317.1 1.0X +SQL Json 11546 11589 61 0.9 1101.1 1.2X +SQL Parquet Vectorized 2143 2164 30 4.9 204.4 6.4X +SQL Parquet MR 4369 4386 24 2.4 416.7 3.2X +SQL ORC Vectorized 2289 2294 8 4.6 218.3 6.0X +SQL ORC MR 3770 3847 109 2.8 359.5 3.7X ================================================================================================ Repeated String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Repeated String: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 9741 9804 89 1.1 929.0 1.0X -SQL Json 8230 8401 241 1.3 784.9 1.2X -SQL Parquet Vectorized 618 650 31 17.0 58.9 15.8X -SQL Parquet MR 2258 2311 75 4.6 215.4 4.3X -SQL ORC Vectorized 608 629 15 17.3 58.0 16.0X -SQL ORC MR 2466 2479 18 4.3 235.2 4.0X +SQL CSV 7344 7377 47 1.4 700.3 1.0X +SQL Json 7117 7153 51 1.5 678.7 1.0X +SQL Parquet Vectorized 598 618 18 17.5 57.0 12.3X +SQL Parquet MR 1955 1969 20 5.4 186.5 3.8X +SQL ORC Vectorized 559 565 8 18.8 53.3 13.1X +SQL ORC MR 1923 1932 13 5.5 183.4 3.8X ================================================================================================ Partitioned Table Scan ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Partitioned Table: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Data column - CSV 24195 24573 534 0.7 1538.3 1.0X -Data column - Json 14746 14883 194 1.1 937.5 1.6X -Data column - Parquet Vectorized 352 385 34 44.7 22.4 68.7X -Data column - Parquet MR 3674 3694 27 4.3 233.6 6.6X -Data column - ORC Vectorized 480 505 26 32.8 30.5 50.4X -Data column - ORC MR 2913 3004 128 5.4 185.2 8.3X -Partition column - CSV 7527 7544 23 2.1 478.6 3.2X -Partition column - Json 11955 12051 135 1.3 760.1 2.0X -Partition column - Parquet Vectorized 65 92 29 242.5 4.1 373.0X -Partition column - Parquet MR 1614 1628 21 9.7 102.6 15.0X -Partition column - ORC Vectorized 71 99 29 220.1 4.5 338.5X -Partition column - ORC MR 1761 1769 11 8.9 112.0 13.7X -Both columns - CSV 24077 24127 70 0.7 1530.8 1.0X -Both columns - Json 15286 15479 273 1.0 971.9 1.6X -Both columns - Parquet Vectorized 376 412 40 41.9 23.9 64.4X -Both columns - Parquet MR 3808 3826 26 4.1 242.1 6.4X -Both columns - ORC Vectorized 560 604 42 28.1 35.6 43.2X -Both columns - ORC MR 3046 3080 49 5.2 193.7 7.9X +Data column - CSV 19266 
19281 21 0.8 1224.9 1.0X +Data column - Json 13119 13126 10 1.2 834.1 1.5X +Data column - Parquet Vectorized 305 334 27 51.6 19.4 63.2X +Data column - Parquet MR 2978 3022 63 5.3 189.3 6.5X +Data column - ORC Vectorized 446 480 32 35.3 28.3 43.2X +Data column - ORC MR 2451 2469 24 6.4 155.9 7.9X +Partition column - CSV 6640 6641 1 2.4 422.2 2.9X +Partition column - Json 10485 10512 37 1.5 666.6 1.8X +Partition column - Parquet Vectorized 65 88 24 241.2 4.1 295.4X +Partition column - Parquet MR 1403 1434 44 11.2 89.2 13.7X +Partition column - ORC Vectorized 62 86 21 253.8 3.9 310.9X +Partition column - ORC MR 1523 1525 3 10.3 96.8 12.6X +Both columns - CSV 19347 19354 10 0.8 1230.0 1.0X +Both columns - Json 13788 13793 6 1.1 876.6 1.4X +Both columns - Parquet Vectorized 346 414 70 45.5 22.0 55.7X +Both columns - Parquet MR 3022 3032 14 5.2 192.1 6.4X +Both columns - ORC Vectorized 479 519 28 32.9 30.4 40.2X +Both columns - ORC MR 2539 2540 1 6.2 161.4 7.6X ================================================================================================ String with Nulls Scan ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz String with Nulls Scan (0.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 11805 12021 306 0.9 1125.8 1.0X -SQL Json 12051 12105 77 0.9 1149.3 1.0X -SQL Parquet Vectorized 1474 1545 100 7.1 140.6 8.0X -SQL Parquet MR 4488 4492 4 2.3 428.1 2.6X -ParquetReader Vectorized 1140 1140 1 9.2 108.7 10.4X -SQL ORC Vectorized 1164 1178 20 9.0 111.0 10.1X -SQL ORC MR 3745 3817 102 2.8 357.1 3.2X - -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +SQL CSV 9158 9163 8 1.1 873.3 1.0X +SQL Json 10429 10448 27 1.0 994.6 0.9X +SQL Parquet Vectorized 1363 1660 420 7.7 130.0 6.7X +SQL Parquet MR 3894 3898 5 2.7 371.4 2.4X +ParquetReader Vectorized 1021 1031 14 10.3 97.4 9.0X +SQL ORC Vectorized 1168 1191 33 9.0 111.4 7.8X +SQL ORC MR 3267 3287 28 3.2 311.6 2.8X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz String with Nulls Scan (50.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 9814 9837 33 1.1 936.0 1.0X -SQL Json 9317 9445 182 1.1 888.5 1.1X -SQL Parquet Vectorized 1117 1155 52 9.4 106.6 8.8X -SQL Parquet MR 3463 3538 106 3.0 330.3 2.8X -ParquetReader Vectorized 1033 1039 8 10.1 98.6 9.5X -SQL ORC Vectorized 1307 1353 65 8.0 124.7 7.5X -SQL ORC MR 3644 3690 65 2.9 347.5 2.7X - -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +SQL CSV 7570 7577 11 1.4 721.9 1.0X +SQL Json 8085 8096 14 1.3 771.1 0.9X +SQL Parquet Vectorized 1097 1101 5 9.6 104.7 6.9X +SQL Parquet MR 2999 3014 21 3.5 286.0 2.5X +ParquetReader Vectorized 1052 1064 18 10.0 100.3 7.2X +SQL ORC Vectorized 1286 2162 1239 8.2 122.6 5.9X +SQL ORC MR 3053 3123 100 3.4 291.1 2.5X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 
8272CL CPU @ 2.60GHz String with Nulls Scan (95.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 8145 8270 176 1.3 776.8 1.0X -SQL Json 5714 5764 71 1.8 544.9 1.4X -SQL Parquet Vectorized 235 264 15 44.6 22.4 34.7X -SQL Parquet MR 2398 2412 19 4.4 228.7 3.4X -ParquetReader Vectorized 248 262 11 42.3 23.6 32.9X -SQL ORC Vectorized 430 462 37 24.4 41.0 18.9X -SQL ORC MR 1983 1993 14 5.3 189.1 4.1X +SQL CSV 6211 6214 3 1.7 592.4 1.0X +SQL Json 4977 4994 24 2.1 474.6 1.2X +SQL Parquet Vectorized 260 272 10 40.3 24.8 23.9X +SQL Parquet MR 1981 1985 5 5.3 188.9 3.1X +ParquetReader Vectorized 268 276 11 39.1 25.6 23.2X +SQL ORC Vectorized 428 457 35 24.5 40.8 14.5X +SQL ORC MR 1696 1705 12 6.2 161.8 3.7X ================================================================================================ Single Column Scan From Wide Columns ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Column Scan from 10 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 2448 2461 18 0.4 2334.3 1.0X -SQL Json 3332 3370 53 0.3 3177.6 0.7X -SQL Parquet Vectorized 51 87 25 20.7 48.2 48.4X -SQL Parquet MR 239 278 35 4.4 227.5 10.3X -SQL ORC Vectorized 60 82 19 17.5 57.3 40.8X -SQL ORC MR 197 219 26 5.3 188.3 12.4X - -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +SQL CSV 2067 2093 36 0.5 1971.6 1.0X +SQL Json 3047 5663 NaN 0.3 2906.0 0.7X +SQL Parquet Vectorized 50 73 21 20.9 47.7 41.3X +SQL Parquet MR 205 224 28 5.1 195.3 10.1X +SQL ORC Vectorized 60 79 23 17.4 57.5 34.3X +SQL ORC MR 173 196 25 6.1 165.1 11.9X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Column Scan from 50 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 6034 6061 39 0.2 5754.0 1.0X -SQL Json 12232 12315 118 0.1 11665.4 0.5X -SQL Parquet Vectorized 73 120 30 14.4 69.6 82.6X -SQL Parquet MR 316 368 44 3.3 301.1 19.1X -SQL ORC Vectorized 76 122 36 13.7 72.9 79.0X -SQL ORC MR 206 261 47 5.1 196.5 29.3X - -OpenJDK 64-Bit Server VM 11.0.10+9-LTS on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz +SQL CSV 4841 4844 5 0.2 4616.4 1.0X +SQL Json 11721 11745 34 0.1 11177.9 0.4X +SQL Parquet Vectorized 67 101 27 15.7 63.8 72.4X +SQL Parquet MR 225 247 27 4.7 214.2 21.5X +SQL ORC Vectorized 75 99 26 13.9 71.7 64.4X +SQL ORC MR 192 219 26 5.5 183.4 25.2X + +OpenJDK 64-Bit Server VM 11.0.13+8-LTS on Linux 5.11.0-1021-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Column Scan from 100 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 10307 10309 4 0.1 9829.0 1.0X -SQL Json 23412 23539 180 0.0 22327.7 0.4X 
-SQL Parquet Vectorized 105 151 23 10.0 99.9 98.4X -SQL Parquet MR 295 325 29 3.6 281.5 34.9X -SQL ORC Vectorized 85 112 31 12.4 81.0 121.4X -SQL ORC MR 212 255 66 4.9 202.3 48.6X +SQL CSV 8410 8414 5 0.1 8020.8 1.0X +SQL Json 22537 22923 547 0.0 21492.8 0.4X +SQL Parquet Vectorized 101 141 32 10.4 96.2 83.4X +SQL Parquet MR 262 289 45 4.0 249.9 32.1X +SQL ORC Vectorized 90 113 32 11.7 85.4 93.9X +SQL ORC MR 210 232 36 5.0 200.3 40.0X diff --git a/sql/core/benchmarks/DataSourceReadBenchmark-results.txt b/sql/core/benchmarks/DataSourceReadBenchmark-results.txt index fe083703ae..65db1afc51 100644 --- a/sql/core/benchmarks/DataSourceReadBenchmark-results.txt +++ b/sql/core/benchmarks/DataSourceReadBenchmark-results.txt @@ -2,251 +2,269 @@ SQL Single Numeric Column Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz -SQL Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +SQL Single BOOLEAN Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 15943 15956 18 1.0 1013.6 1.0X -SQL Json 9109 9158 70 1.7 579.1 1.8X -SQL Parquet Vectorized 168 191 16 93.8 10.7 95.1X -SQL Parquet MR 1938 1950 17 8.1 123.2 8.2X -SQL ORC Vectorized 191 199 6 82.2 12.2 83.3X -SQL ORC MR 1523 1537 20 10.3 96.8 10.5X +SQL CSV 11497 11744 349 1.4 731.0 1.0X +SQL Json 7073 7099 37 2.2 449.7 1.6X +SQL Parquet Vectorized 105 126 17 149.9 6.7 109.6X +SQL Parquet MR 1647 1648 2 9.6 104.7 7.0X +SQL ORC Vectorized 157 167 5 100.0 10.0 73.1X +SQL ORC MR 1466 1485 27 10.7 93.2 7.8X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +Parquet Reader Single BOOLEAN Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------- +ParquetReader Vectorized 114 123 8 137.8 7.3 1.0X +ParquetReader Vectorized -> Row 42 44 1 372.1 2.7 2.7X -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz +SQL Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------ +SQL CSV 15825 15961 193 1.0 1006.1 1.0X +SQL Json 7966 8054 125 2.0 506.5 2.0X +SQL Parquet Vectorized 136 148 9 115.4 8.7 116.1X +SQL Parquet MR 1814 1825 15 8.7 115.4 8.7X +SQL ORC Vectorized 138 147 6 114.4 8.7 115.1X +SQL ORC MR 1299 1382 117 12.1 82.6 12.2X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Parquet Reader Single TINYINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 203 206 3 77.5 12.9 1.0X -ParquetReader Vectorized -> Row 
97 100 2 161.6 6.2 2.1X +ParquetReader Vectorized 179 185 9 88.0 11.4 1.0X +ParquetReader Vectorized -> Row 91 101 3 172.6 5.8 2.0X -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single SMALLINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 17062 17089 38 0.9 1084.8 1.0X -SQL Json 9718 9720 3 1.6 617.9 1.8X -SQL Parquet Vectorized 326 333 7 48.2 20.7 52.3X -SQL Parquet MR 2305 2329 34 6.8 146.6 7.4X -SQL ORC Vectorized 201 205 3 78.2 12.8 84.8X -SQL ORC MR 1795 1796 0 8.8 114.1 9.5X - -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +SQL CSV 15449 16211 1077 1.0 982.2 1.0X +SQL Json 7955 8292 476 2.0 505.8 1.9X +SQL Parquet Vectorized 195 211 8 80.7 12.4 79.2X +SQL Parquet MR 1866 1890 33 8.4 118.7 8.3X +SQL ORC Vectorized 163 173 8 96.6 10.4 94.9X +SQL ORC MR 1550 1555 8 10.1 98.5 10.0X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Parquet Reader Single SMALLINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------- -ParquetReader Vectorized 333 339 7 47.2 21.2 1.0X -ParquetReader Vectorized -> Row 283 285 3 55.7 18.0 1.2X +ParquetReader Vectorized 299 302 4 52.5 19.0 1.0X +ParquetReader Vectorized -> Row 264 280 14 59.6 16.8 1.1X -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single INT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 18722 18809 123 0.8 1190.3 1.0X -SQL Json 10192 10249 80 1.5 648.0 1.8X -SQL Parquet Vectorized 155 162 8 101.6 9.8 120.9X -SQL Parquet MR 2348 2360 16 6.7 149.3 8.0X -SQL ORC Vectorized 265 275 7 59.3 16.9 70.5X -SQL ORC MR 1892 1938 65 8.3 120.3 9.9X - -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +SQL CSV 16640 16834 273 0.9 1058.0 1.0X +SQL Json 8859 8862 3 1.8 563.3 1.9X +SQL Parquet Vectorized 144 155 8 109.0 9.2 115.3X +SQL Parquet MR 1960 2023 89 8.0 124.6 8.5X +SQL ORC Vectorized 218 233 11 72.3 13.8 76.5X +SQL ORC MR 1440 1442 3 10.9 91.6 11.6X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Parquet Reader Single INT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 243 251 7 64.8 15.4 1.0X -ParquetReader Vectorized -> Row 222 229 5 70.9 14.1 1.1X +ParquetReader Vectorized 224 241 13 70.2 14.2 1.0X +ParquetReader Vectorized -> Row 214 221 10 73.6 13.6 1.0X -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 
64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single BIGINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 24299 24358 84 0.6 1544.9 1.0X -SQL Json 13349 13429 114 1.2 848.7 1.8X -SQL Parquet Vectorized 215 241 59 73.3 13.6 113.2X -SQL Parquet MR 2508 2508 0 6.3 159.4 9.7X -SQL ORC Vectorized 323 330 6 48.7 20.5 75.2X -SQL ORC MR 1993 2009 22 7.9 126.7 12.2X - -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +SQL CSV 22998 23324 461 0.7 1462.2 1.0X +SQL Json 12165 12179 20 1.3 773.4 1.9X +SQL Parquet Vectorized 237 265 69 66.3 15.1 96.9X +SQL Parquet MR 2199 2199 0 7.2 139.8 10.5X +SQL ORC Vectorized 303 311 10 51.9 19.3 76.0X +SQL ORC MR 1750 1763 18 9.0 111.3 13.1X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Parquet Reader Single BIGINT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 310 351 74 50.8 19.7 1.0X -ParquetReader Vectorized -> Row 281 297 8 55.9 17.9 1.1X +ParquetReader Vectorized 331 368 80 47.6 21.0 1.0X +ParquetReader Vectorized -> Row 314 318 6 50.0 20.0 1.1X -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 19745 19811 93 0.8 1255.4 1.0X -SQL Json 12523 12760 335 1.3 796.2 1.6X -SQL Parquet Vectorized 153 160 6 102.9 9.7 129.2X -SQL Parquet MR 2325 2338 18 6.8 147.8 8.5X -SQL ORC Vectorized 389 401 8 40.5 24.7 50.8X -SQL ORC MR 2009 2009 1 7.8 127.7 9.8X - -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +SQL CSV 17442 18560 1581 0.9 1108.9 1.0X +SQL Json 10833 11056 315 1.5 688.8 1.6X +SQL Parquet Vectorized 150 162 10 105.0 9.5 116.5X +SQL Parquet MR 1804 1922 167 8.7 114.7 9.7X +SQL ORC Vectorized 317 336 20 49.6 20.2 55.0X +SQL ORC MR 1550 1648 139 10.1 98.5 11.3X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Parquet Reader Single FLOAT Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 240 244 4 65.5 15.3 1.0X -ParquetReader Vectorized -> Row 223 230 6 70.5 14.2 1.1X +ParquetReader Vectorized 240 263 11 65.7 15.2 1.0X +ParquetReader Vectorized -> Row 224 235 15 70.4 14.2 1.1X -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz SQL Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
------------------------------------------------------------------------------------------------------------------------ -SQL CSV 27223 27293 99 0.6 1730.8 1.0X -SQL Json 18601 18646 63 0.8 1182.6 1.5X -SQL Parquet Vectorized 247 251 3 63.8 15.7 110.4X -SQL Parquet MR 2724 2773 69 5.8 173.2 10.0X -SQL ORC Vectorized 474 484 10 33.2 30.1 57.4X -SQL ORC MR 2342 2368 37 6.7 148.9 11.6X - -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +SQL CSV 22438 23472 1462 0.7 1426.5 1.0X +SQL Json 15839 15888 70 1.0 1007.0 1.4X +SQL Parquet Vectorized 215 229 12 73.3 13.6 104.6X +SQL Parquet MR 1928 2061 188 8.2 122.6 11.6X +SQL ORC Vectorized 393 421 17 40.0 25.0 57.0X +SQL ORC MR 1799 1814 22 8.7 114.4 12.5X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Parquet Reader Single DOUBLE Column Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -ParquetReader Vectorized 326 335 13 48.3 20.7 1.0X -ParquetReader Vectorized -> Row 358 365 7 44.0 22.7 0.9X +ParquetReader Vectorized 310 316 9 50.7 19.7 1.0X +ParquetReader Vectorized -> Row 289 302 20 54.3 18.4 1.1X ================================================================================================ Int and String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Int and String Scan: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 18706 18716 15 0.6 1783.9 1.0X -SQL Json 12665 12762 138 0.8 1207.8 1.5X -SQL Parquet Vectorized 2408 2419 15 4.4 229.6 7.8X -SQL Parquet MR 4599 4620 30 2.3 438.6 4.1X -SQL ORC Vectorized 2397 2400 3 4.4 228.6 7.8X -SQL ORC MR 4267 4288 30 2.5 406.9 4.4X +SQL CSV 15669 15869 283 0.7 1494.3 1.0X +SQL Json 10126 10559 613 1.0 965.7 1.5X +SQL Parquet Vectorized 2056 2064 11 5.1 196.0 7.6X +SQL Parquet MR 3918 3927 13 2.7 373.6 4.0X +SQL ORC Vectorized 1786 1887 143 5.9 170.3 8.8X +SQL ORC MR 3521 3555 48 3.0 335.8 4.4X ================================================================================================ Repeated String Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Repeated String: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 10822 10838 23 1.0 1032.0 1.0X -SQL Json 7459 7488 41 1.4 711.4 1.5X -SQL Parquet Vectorized 875 895 26 12.0 83.5 12.4X -SQL Parquet MR 1976 2002 37 5.3 188.4 5.5X -SQL ORC Vectorized 533 539 8 19.7 50.9 20.3X -SQL ORC MR 2191 2194 5 4.8 208.9 4.9X +SQL CSV 8659 8948 409 1.2 825.8 1.0X +SQL Json 6410 6536 177 1.6 611.3 1.4X +SQL Parquet Vectorized 655 709 47 16.0 62.4 
13.2X +SQL Parquet MR 1528 1531 3 6.9 145.7 5.7X +SQL ORC Vectorized 388 416 24 27.0 37.0 22.3X +SQL ORC MR 1599 1700 142 6.6 152.5 5.4X ================================================================================================ Partitioned Table Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Partitioned Table: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Data column - CSV 31196 31449 359 0.5 1983.4 1.0X -Data column - Json 16118 16855 1041 1.0 1024.8 1.9X -Data column - Parquet Vectorized 243 251 9 64.8 15.4 128.4X -Data column - Parquet MR 4213 4288 106 3.7 267.8 7.4X -Data column - ORC Vectorized 335 341 4 46.9 21.3 93.1X -Data column - ORC MR 3119 3146 38 5.0 198.3 10.0X -Partition column - CSV 9616 9915 423 1.6 611.3 3.2X -Partition column - Json 14136 14164 39 1.1 898.8 2.2X -Partition column - Parquet Vectorized 64 70 6 243.9 4.1 483.8X -Partition column - Parquet MR 1954 1980 38 8.1 124.2 16.0X -Partition column - ORC Vectorized 67 74 8 233.4 4.3 462.9X -Partition column - ORC MR 2461 2479 26 6.4 156.4 12.7X -Both columns - CSV 30327 30666 479 0.5 1928.2 1.0X -Both columns - Json 18656 18789 188 0.8 1186.1 1.7X -Both columns - Parquet Vectorized 291 297 7 54.0 18.5 107.2X -Both columns - Parquet MR 4430 4443 19 3.6 281.6 7.0X -Both columns - ORC Vectorized 403 411 11 39.0 25.6 77.4X -Both columns - ORC MR 3580 3584 5 4.4 227.6 8.7X +Data column - CSV 21094 21357 372 0.7 1341.1 1.0X +Data column - Json 11163 11434 383 1.4 709.7 1.9X +Data column - Parquet Vectorized 225 238 13 69.9 14.3 93.7X +Data column - Parquet MR 2218 2342 175 7.1 141.0 9.5X +Data column - ORC Vectorized 276 300 20 56.9 17.6 76.4X +Data column - ORC MR 1851 1863 17 8.5 117.7 11.4X +Partition column - CSV 5834 6119 403 2.7 370.9 3.6X +Partition column - Json 9746 9754 11 1.6 619.6 2.2X +Partition column - Parquet Vectorized 57 61 2 273.9 3.7 367.4X +Partition column - Parquet MR 1164 1167 5 13.5 74.0 18.1X +Partition column - ORC Vectorized 60 64 3 261.3 3.8 350.4X +Partition column - ORC MR 1298 1304 8 12.1 82.5 16.2X +Both columns - CSV 22632 22636 4 0.7 1438.9 0.9X +Both columns - Json 12568 12587 26 1.3 799.1 1.7X +Both columns - Parquet Vectorized 283 288 7 55.5 18.0 74.4X +Both columns - Parquet MR 2547 2553 8 6.2 161.9 8.3X +Both columns - ORC Vectorized 343 346 4 45.8 21.8 61.5X +Both columns - ORC MR 2177 2178 2 7.2 138.4 9.7X ================================================================================================ String with Nulls Scan ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz String with Nulls Scan (0.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 15606 15614 11 0.7 1488.3 1.0X -SQL Json 15406 15451 63 0.7 1469.3 1.0X -SQL Parquet 
Vectorized 1555 1573 25 6.7 148.3 10.0X -SQL Parquet MR 5369 5377 11 2.0 512.0 2.9X -ParquetReader Vectorized 1145 1150 7 9.2 109.2 13.6X -SQL ORC Vectorized 1023 1027 6 10.2 97.6 15.3X -SQL ORC MR 4421 4542 172 2.4 421.6 3.5X - -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +SQL CSV 11364 11364 0 0.9 1083.7 1.0X +SQL Json 10555 10562 9 1.0 1006.6 1.1X +SQL Parquet Vectorized 1299 1309 13 8.1 123.9 8.7X +SQL Parquet MR 3350 3351 1 3.1 319.5 3.4X +ParquetReader Vectorized 983 987 5 10.7 93.8 11.6X +SQL ORC Vectorized 912 913 1 11.5 87.0 12.5X +SQL ORC MR 3056 3059 5 3.4 291.4 3.7X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz String with Nulls Scan (50.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 11096 11159 90 0.9 1058.2 1.0X -SQL Json 10797 11304 717 1.0 1029.7 1.0X -SQL Parquet Vectorized 1218 1230 16 8.6 116.2 9.1X -SQL Parquet MR 3778 3806 40 2.8 360.3 2.9X -ParquetReader Vectorized 1108 1118 14 9.5 105.7 10.0X -SQL ORC Vectorized 1361 1371 13 7.7 129.8 8.2X -SQL ORC MR 4186 4196 14 2.5 399.2 2.7X - -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +SQL CSV 8651 8654 5 1.2 825.0 1.0X +SQL Json 7791 7794 4 1.3 743.0 1.1X +SQL Parquet Vectorized 1045 1055 15 10.0 99.7 8.3X +SQL Parquet MR 2516 2519 3 4.2 240.0 3.4X +ParquetReader Vectorized 927 933 6 11.3 88.4 9.3X +SQL ORC Vectorized 1285 1286 2 8.2 122.5 6.7X +SQL ORC MR 3013 3013 0 3.5 287.4 2.9X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz String with Nulls Scan (95.0%): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 8803 8866 90 1.2 839.5 1.0X -SQL Json 7220 7249 42 1.5 688.5 1.2X -SQL Parquet Vectorized 258 265 7 40.6 24.6 34.1X -SQL Parquet MR 2760 2761 0 3.8 263.2 3.2X -ParquetReader Vectorized 277 283 5 37.8 26.4 31.7X -SQL ORC Vectorized 514 522 6 20.4 49.1 17.1X -SQL ORC MR 2523 2591 96 4.2 240.6 3.5X +SQL CSV 6272 6288 23 1.7 598.1 1.0X +SQL Json 4469 4469 0 2.3 426.2 1.4X +SQL Parquet Vectorized 231 235 7 45.4 22.0 27.2X +SQL Parquet MR 1673 1674 2 6.3 159.5 3.7X +ParquetReader Vectorized 243 244 3 43.1 23.2 25.8X +SQL ORC Vectorized 471 472 2 22.2 45.0 13.3X +SQL ORC MR 1606 1618 17 6.5 153.2 3.9X ================================================================================================ Single Column Scan From Wide Columns ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Column Scan from 10 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 3022 3032 14 0.3 2881.9 1.0X -SQL Json 4047 4051 5 0.3 3859.5 0.7X -SQL Parquet Vectorized 50 54 6 20.8 48.1 59.9X -SQL Parquet MR 299 301 2 3.5 285.0 10.1X -SQL ORC Vectorized 59 
63 11 17.9 55.9 51.6X -SQL ORC MR 255 259 5 4.1 243.4 11.8X - -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +SQL CSV 2171 2173 2 0.5 2070.8 1.0X +SQL Json 2266 2278 17 0.5 2161.3 1.0X +SQL Parquet Vectorized 51 55 7 20.4 49.0 42.2X +SQL Parquet MR 190 192 2 5.5 180.9 11.4X +SQL ORC Vectorized 57 61 8 18.4 54.2 38.2X +SQL ORC MR 161 164 2 6.5 153.8 13.5X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Column Scan from 50 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 7250 7252 3 0.1 6914.4 1.0X -SQL Json 15641 15718 109 0.1 14916.8 0.5X -SQL Parquet Vectorized 66 72 8 15.9 62.9 110.0X -SQL Parquet MR 320 323 3 3.3 305.0 22.7X -SQL ORC Vectorized 72 77 11 14.6 68.6 100.9X -SQL ORC MR 269 273 5 3.9 256.8 26.9X - -OpenJDK 64-Bit Server VM 1.8.0_282-b08 on Linux 5.4.0-1043-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +SQL CSV 5200 5211 15 0.2 4959.5 1.0X +SQL Json 8312 8318 8 0.1 7927.1 0.6X +SQL Parquet Vectorized 67 73 10 15.7 63.9 77.6X +SQL Parquet MR 210 214 4 5.0 200.4 24.8X +SQL ORC Vectorized 70 77 16 15.0 66.7 74.3X +SQL ORC MR 182 184 2 5.8 173.6 28.6X + +OpenJDK 64-Bit Server VM 1.8.0_312-b07 on Linux 5.11.0-1020-azure +Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Single Column Scan from 100 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -SQL CSV 10962 11340 535 0.1 10454.1 1.0X -SQL Json 24951 25755 1137 0.0 23795.0 0.4X -SQL Parquet Vectorized 84 93 6 12.4 80.5 129.9X -SQL Parquet MR 280 296 14 3.7 266.8 39.2X -SQL ORC Vectorized 70 76 6 15.0 66.6 156.9X -SQL ORC MR 231 242 13 4.5 220.1 47.5X +SQL CSV 9030 9032 2 0.1 8611.8 1.0X +SQL Json 15429 15462 46 0.1 14714.5 0.6X +SQL Parquet Vectorized 91 97 8 11.5 87.2 98.8X +SQL Parquet MR 235 239 3 4.5 224.2 38.4X +SQL ORC Vectorized 80 84 9 13.1 76.4 112.8X +SQL ORC MR 192 201 7 5.5 183.4 47.0X diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index 39591be3b4..0eb5d65a4a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -53,19 +53,52 @@ public void skip() { throw new UnsupportedOperationException(); } + private void updateCurrentByte() { + try { + currentByte = (byte) in.read(); + } catch (IOException e) { + throw new ParquetDecodingException("Failed to read a byte", e); + } + } + @Override public final void readBooleans(int total, WritableColumnVector c, int rowId) { - // TODO: properly vectorize this - for (int i = 0; i < total; i++) { - c.putBoolean(rowId + i, readBoolean()); + int i = 0; + if (bitOffset > 0) { + i = Math.min(8 - bitOffset, total); + c.putBooleans(rowId, i, currentByte, bitOffset); + bitOffset = (bitOffset + i) & 7; + } + for (; i + 7 < total; i += 8) { + updateCurrentByte(); + c.putBooleans(rowId + i, currentByte); + } + if (i < total) { + updateCurrentByte(); + bitOffset = total - i; + 
c.putBooleans(rowId + i, bitOffset, currentByte, 0); } } @Override public final void skipBooleans(int total) { - // TODO: properly vectorize this - for (int i = 0; i < total; i++) { - readBoolean(); + int i = 0; + if (bitOffset > 0) { + i = Math.min(8 - bitOffset, total); + bitOffset = (bitOffset + i) & 7; + } + if (i + 7 < total) { + int numBytesToSkip = (total - i) / 8; + try { + in.skipFully(numBytesToSkip); + } catch (IOException e) { + throw new ParquetDecodingException("Failed to skip bytes", e); + } + i += numBytesToSkip * 8; + } + if (i < total) { + updateCurrentByte(); + bitOffset = total - i; } } @@ -276,13 +309,8 @@ public void skipShorts(int total) { @Override public final boolean readBoolean() { - // TODO: vectorize decoding and keep boolean[] instead of currentByte if (bitOffset == 0) { - try { - currentByte = (byte) in.read(); - } catch (IOException e) { - throw new ParquetDecodingException("Failed to read a byte", e); - } + updateCurrentByte(); } boolean v = (currentByte & (1 << bitOffset)) != 0; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index f7c9dc55f7..bbe96819a6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -152,6 +152,18 @@ public void putBooleans(int rowId, int count, boolean value) { } } + @Override + public void putBooleans(int rowId, byte src) { + Platform.putByte(null, data + rowId, (byte)(src & 1)); + Platform.putByte(null, data + rowId + 1, (byte)(src >>> 1 & 1)); + Platform.putByte(null, data + rowId + 2, (byte)(src >>> 2 & 1)); + Platform.putByte(null, data + rowId + 3, (byte)(src >>> 3 & 1)); + Platform.putByte(null, data + rowId + 4, (byte)(src >>> 4 & 1)); + Platform.putByte(null, data + rowId + 5, (byte)(src >>> 5 & 1)); + Platform.putByte(null, data + rowId + 6, (byte)(src >>> 6 & 1)); + Platform.putByte(null, data + rowId + 7, (byte)(src >>> 7 & 1)); + } + @Override public boolean getBoolean(int rowId) { return Platform.getByte(null, data + rowId) == 1; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 3fb96d872c..833a93f2a2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -147,6 +147,18 @@ public void putBooleans(int rowId, int count, boolean value) { } } + @Override + public void putBooleans(int rowId, byte src) { + byteData[rowId] = (byte)(src & 1); + byteData[rowId + 1] = (byte)(src >>> 1 & 1); + byteData[rowId + 2] = (byte)(src >>> 2 & 1); + byteData[rowId + 3] = (byte)(src >>> 3 & 1); + byteData[rowId + 4] = (byte)(src >>> 4 & 1); + byteData[rowId + 5] = (byte)(src >>> 5 & 1); + byteData[rowId + 6] = (byte)(src >>> 6 & 1); + byteData[rowId + 7] = (byte)(src >>> 7 & 1); + } + @Override public boolean getBoolean(int rowId) { return byteData[rowId] == 1; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 8f7dcf2374..5e01c37279 100644 --- 
a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -46,6 +46,7 @@ * WritableColumnVector are intended to be reused. */ public abstract class WritableColumnVector extends ColumnVector { + private final byte[] byte8 = new byte[8]; /** * Resets this column for writing. The currently stored values are no longer accessible. @@ -201,6 +202,29 @@ public WritableColumnVector reserveDictionaryIds(int capacity) { */ public abstract void putBooleans(int rowId, int count, boolean value); + /** + * Sets bits from [src[srcIndex], src[srcIndex + count]) to [rowId, rowId + count) + * src must contain bit-packed 8 booleans in the byte. + */ + public void putBooleans(int rowId, int count, byte src, int srcIndex) { + assert ((srcIndex + count) <= 8); + byte8[0] = (byte)(src & 1); + byte8[1] = (byte)(src >>> 1 & 1); + byte8[2] = (byte)(src >>> 2 & 1); + byte8[3] = (byte)(src >>> 3 & 1); + byte8[4] = (byte)(src >>> 4 & 1); + byte8[5] = (byte)(src >>> 5 & 1); + byte8[6] = (byte)(src >>> 6 & 1); + byte8[7] = (byte)(src >>> 7 & 1); + putBytes(rowId, count, byte8, srcIndex); + } + + /** + * Sets bits from [src[0], src[7]] to [rowId, rowId + 7] + * src must contain bit-packed 8 booleans in the byte. + */ + public abstract void putBooleans(int rowId, byte src); + /** * Sets `value` to the value at rowId. */ @@ -470,6 +494,18 @@ public final int appendBooleans(int count, boolean v) { return result; } + /** + * Append bits from [src[offset], src[offset + count]) + * src must contain bit-packed 8 booleans in the byte. + */ + public final int appendBooleans(int count, byte src, int offset) { + reserve(elementsAppended + count); + int result = elementsAppended; + putBooleans(elementsAppended, count, src, offset); + elementsAppended += count; + return result; + } + public final int appendByte(byte v) { reserve(elementsAppended + 1); putByte(elementsAppended, v); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 71d06576bc..d5d814d177 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -23,10 +23,10 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, NoSuchTableException, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, NoSuchTableException, UnresolvedDBObjectName, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, CreateTableAsSelectStatement, InsertIntoStatement, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelectStatement} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, CreateTableAsSelectStatement, InsertIntoStatement, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, TableSpec} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.catalog.{CatalogPlugin, CatalogV2Implicits, CatalogV2Util, Identifier, SupportsCatalogOptions, Table, TableCatalog, TableProvider, 
V1Table} import org.apache.spark.sql.connector.catalog.TableCapability._ @@ -586,19 +586,22 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { AppendData.byName(v2Relation, df.logicalPlan, extraOptions.toMap) case (SaveMode.Overwrite, _) => - ReplaceTableAsSelectStatement( - nameParts, - df.queryExecution.analyzed, + val tableSpec = TableSpec( + bucketSpec = None, + properties = Map.empty, + provider = Some(source), + options = Map.empty, + location = extraOptions.get("path"), + comment = extraOptions.get(TableCatalog.PROP_COMMENT), + serde = None, + external = false) + ReplaceTableAsSelect( + UnresolvedDBObjectName(nameParts, isNamespace = false), partitioningAsV2, - None, - Map.empty, - Some(source), - Map.empty, - extraOptions.get("path"), - extraOptions.get(TableCatalog.PROP_COMMENT), - extraOptions.toMap, - None, - orCreate = true) // Create the table if it doesn't exist + df.queryExecution.analyzed, + tableSpec, + writeOptions = Map.empty, + orCreate = true) // Create the table if it doesn't exist case (other, _) => // We have a potential race condition here in AppendMode, if the table suddenly gets diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala index bff7ee4323..b99195de13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala @@ -21,9 +21,9 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException, UnresolvedDBObjectName, UnresolvedRelation} import org.apache.spark.sql.catalyst.expressions.{Attribute, Bucket, Days, Hours, Literal, Months, Years} -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelectStatement, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelectStatement} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelectStatement, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect, TableSpec} import org.apache.spark.sql.connector.expressions.{LogicalExpressions, NamedReference, Transform} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.IntegerType @@ -195,20 +195,22 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T]) } private def internalReplace(orCreate: Boolean): Unit = { - runCommand( - ReplaceTableAsSelectStatement( - tableName, - logicalPlan, - partitioning.getOrElse(Seq.empty), - None, - properties.toMap, - provider, - Map.empty, - None, - None, - options.toMap, - None, - orCreate = orCreate)) + val tableSpec = TableSpec( + bucketSpec = None, + properties = properties.toMap, + provider = provider, + options = Map.empty, + location = None, + comment = None, + serde = None, + external = false) + runCommand(ReplaceTableAsSelect( + UnresolvedDBObjectName(tableName, isNamespace = false), + partitioning.getOrElse(Seq.empty), + logicalPlan, + tableSpec, + writeOptions = options.toMap, + orCreate = orCreate)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 63812b873b..df110aa269 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -97,13 +97,17 @@ class SparkSession private( * since that would cause every new session to reinvoke Spark Session Extensions on the currently * running extensions. */ - private[sql] def this(sc: SparkContext) = { + private[sql] def this( + sc: SparkContext, + initialSessionOptions: java.util.HashMap[String, String]) = { this(sc, None, None, SparkSession.applyExtensions( sc.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS).getOrElse(Seq.empty), - new SparkSessionExtensions), Map.empty) + new SparkSessionExtensions), initialSessionOptions.asScala.toMap) } + private[sql] def this(sc: SparkContext) = this(sc, new java.util.HashMap[String, String]()) + private[sql] val sessionUUID: String = UUID.randomUUID.toString sparkContext.assertNotStopped() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ReplaceCharWithVarchar.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ReplaceCharWithVarchar.scala index 7404a30fed..3f9eb5c808 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ReplaceCharWithVarchar.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ReplaceCharWithVarchar.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.catalog.CatalogTable -import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, CreateV2Table, LogicalPlan, ReplaceColumns, ReplaceTable} +import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, CreateTable, LogicalPlan, ReplaceColumns, ReplaceTable} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.execution.command.{AlterTableAddColumnsCommand, AlterTableChangeColumnCommand, CreateDataSourceTableCommand, CreateTableCommand} @@ -31,7 +31,7 @@ object ReplaceCharWithVarchar extends Rule[LogicalPlan] { plan.resolveOperators { // V2 commands - case cmd: CreateV2Table => + case cmd: CreateTable => cmd.copy(tableSchema = replaceCharWithVarcharInSchema(cmd.tableSchema)) case cmd: ReplaceTable => cmd.copy(tableSchema = replaceCharWithVarcharInSchema(cmd.tableSchema)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index c55bdcabef..6f41497ddb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, Ca import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource} +import org.apache.spark.sql.execution.datasources.{CreateTable => CreateTableV1, DataSource} import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType} @@ -111,7 +111,7 @@ class ResolveSessionCatalog(val 
catalogManager: CatalogManager) case SetNamespaceProperties(DatabaseInSessionCatalog(db), properties) => AlterDatabasePropertiesCommand(db, properties) - case SetNamespaceLocation(DatabaseInSessionCatalog(db), location) => + case SetNamespaceLocation(DatabaseInSessionCatalog(db), location) if conf.useV1Command => AlterDatabaseSetLocationCommand(db, location) case RenameTable(ResolvedV1TableOrViewIdentifier(oldName), newName, isView) => @@ -143,25 +143,24 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) // For CREATE TABLE [AS SELECT], we should use the v1 command if the catalog is resolved to the // session catalog and the table provider is not v2. - case c @ CreateTableStatement( - SessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _, _) => + case c @ CreateTable(ResolvedDBObjectName(catalog, name), _, _, _, _) => val (storageFormat, provider) = getStorageFormatAndProvider( - c.provider, c.options, c.location, c.serde, ctas = false) - if (!isV2Provider(provider)) { - val tableDesc = buildCatalogTable(tbl.asTableIdentifier, c.tableSchema, - c.partitioning, c.bucketSpec, c.properties, provider, c.location, - c.comment, storageFormat, c.external) - val mode = if (c.ifNotExists) SaveMode.Ignore else SaveMode.ErrorIfExists - CreateTable(tableDesc, mode, None) + c.tableSpec.provider, + c.tableSpec.options, + c.tableSpec.location, + c.tableSpec.serde, + ctas = false) + if (isSessionCatalog(catalog) && !isV2Provider(provider)) { + val tableDesc = buildCatalogTable(name.asTableIdentifier, c.tableSchema, + c.partitioning, c.tableSpec.bucketSpec, c.tableSpec.properties, provider, + c.tableSpec.location, c.tableSpec.comment, storageFormat, + c.tableSpec.external) + val mode = if (c.ignoreIfExists) SaveMode.Ignore else SaveMode.ErrorIfExists + CreateTableV1(tableDesc, mode, None) } else { - CreateV2Table( - catalog.asTableCatalog, - tbl.asIdentifier, - c.tableSchema, - // convert the bucket spec and add it as a transform - c.partitioning ++ c.bucketSpec.map(_.asTransform), - convertTableProperties(c), - ignoreIfExists = c.ifNotExists) + val newTableSpec = c.tableSpec.copy(bucketSpec = None) + c.copy(partitioning = c.partitioning ++ c.tableSpec.bucketSpec.map(_.asTransform), + tableSpec = newTableSpec) } case c @ CreateTableAsSelectStatement( @@ -173,7 +172,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) c.partitioning, c.bucketSpec, c.properties, provider, c.location, c.comment, storageFormat, c.external) val mode = if (c.ifNotExists) SaveMode.Ignore else SaveMode.ErrorIfExists - CreateTable(tableDesc, mode, Some(c.asSelect)) + CreateTableV1(tableDesc, mode, Some(c.asSelect)) } else { CreateTableAsSelect( catalog.asTableCatalog, @@ -210,21 +209,15 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) orCreate = c.orCreate) } - case c @ ReplaceTableAsSelectStatement( - SessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _, _, _) => - val provider = c.provider.getOrElse(conf.defaultDataSourceName) + case c @ ReplaceTableAsSelect(ResolvedDBObjectName(catalog, _), _, _, _, _, _) + if isSessionCatalog(catalog) => + val provider = c.tableSpec.provider.getOrElse(conf.defaultDataSourceName) if (!isV2Provider(provider)) { throw QueryCompilationErrors.replaceTableAsSelectOnlySupportedWithV2TableError } else { - ReplaceTableAsSelect( - catalog.asTableCatalog, - tbl.asIdentifier, - // convert the bucket spec and add it as a transform - c.partitioning ++ c.bucketSpec.map(_.asTransform), - c.asSelect, - convertTableProperties(c), - 
writeOptions = c.writeOptions, - orCreate = c.orCreate) + val newTableSpec = c.tableSpec.copy(bucketSpec = None) + c.copy(partitioning = c.partitioning ++ c.tableSpec.bucketSpec.map(_.asTransform), + tableSpec = newTableSpec) } case DropTable(ResolvedV1TableIdentifier(ident), ifExists, purge) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index c62670b227..748f75b186 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -87,6 +87,11 @@ object SQLExecution { val planDescriptionMode = ExplainMode.fromString(sparkSession.sessionState.conf.uiExplainMode) + val globalConfigs = sparkSession.sharedState.conf.getAll.toMap + val modifiedConfigs = sparkSession.sessionState.conf.getAllConfs + .filterNot(kv => globalConfigs.get(kv._1).contains(kv._2)) + val redactedConfigs = sparkSession.sessionState.conf.redactOptions(modifiedConfigs) + withSQLConfPropagated(sparkSession) { var ex: Option[Throwable] = None val startTime = System.nanoTime() @@ -99,7 +104,8 @@ object SQLExecution { // `queryExecution.executedPlan` triggers query planning. If it fails, the exception // will be caught and reported in the `SparkListenerSQLExecutionEnd` sparkPlanInfo = SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan), - time = System.currentTimeMillis())) + time = System.currentTimeMillis(), + redactedConfigs)) body } catch { case e: Throwable => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index b63306be6b..80739e5387 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -73,24 +73,38 @@ class SparkSqlAstBuilder extends AstBuilder { * character in the raw string. 
*/ override def visitSetConfiguration(ctx: SetConfigurationContext): LogicalPlan = withOrigin(ctx) { - remainder(ctx.SET.getSymbol).trim match { - case configKeyValueDef(key, value) => - SetCommand(Some(key -> Option(value.trim))) - case configKeyDef(key) => - SetCommand(Some(key -> None)) - case s if s == "-v" => - SetCommand(Some("-v" -> None)) - case s if s.isEmpty => - SetCommand(None) - case _ => throw QueryParsingErrors.unexpectedFomatForSetConfigurationError(ctx) + if (ctx.configKey() != null) { + val keyStr = ctx.configKey().getText + if (ctx.EQ() != null) { + remainder(ctx.EQ().getSymbol).trim match { + case configValueDef(valueStr) => SetCommand(Some(keyStr -> Option(valueStr))) + case other => throw QueryParsingErrors.invalidPropertyValueForSetQuotedConfigurationError( + other, keyStr, ctx) + } + } else { + SetCommand(Some(keyStr -> None)) + } + } else { + remainder(ctx.SET.getSymbol).trim match { + case configKeyValueDef(key, value) => + SetCommand(Some(key -> Option(value.trim))) + case configKeyDef(key) => + SetCommand(Some(key -> None)) + case s if s == "-v" => + SetCommand(Some("-v" -> None)) + case s if s.isEmpty => + SetCommand(None) + case _ => throw QueryParsingErrors.unexpectedFomatForSetConfigurationError(ctx) + } } } override def visitSetQuotedConfiguration( ctx: SetQuotedConfigurationContext): LogicalPlan = withOrigin(ctx) { - if (ctx.configValue() != null && ctx.configKey() != null) { + assert(ctx.configValue() != null) + if (ctx.configKey() != null) { SetCommand(Some(ctx.configKey().getText -> Option(ctx.configValue().getText))) - } else if (ctx.configValue() != null) { + } else { val valueStr = ctx.configValue().getText val keyCandidate = interval(ctx.SET().getSymbol, ctx.EQ().getSymbol).trim keyCandidate match { @@ -98,17 +112,6 @@ class SparkSqlAstBuilder extends AstBuilder { case _ => throw QueryParsingErrors.invalidPropertyKeyForSetQuotedConfigurationError( keyCandidate, valueStr, ctx) } - } else { - val keyStr = ctx.configKey().getText - if (ctx.EQ() != null) { - remainder(ctx.EQ().getSymbol).trim match { - case configValueDef(valueStr) => SetCommand(Some(keyStr -> Option(valueStr))) - case other => throw QueryParsingErrors.invalidPropertyValueForSetQuotedConfigurationError( - other, keyStr, ctx) - } - } else { - SetCommand(Some(keyStr -> None)) - } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala index 19177ed65a..b34ab3e380 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeSkewInRebalancePartitions.scala @@ -41,18 +41,20 @@ object OptimizeSkewInRebalancePartitions extends AQEShuffleReadRule { /** * Splits the skewed partition based on the map size and the target partition size - * after split. Create a list of `PartialMapperPartitionSpec` for skewed partition and + * after split. Create a list of `PartialReducerPartitionSpec` for skewed partition and * create `CoalescedPartition` for normal partition. 
*/ private def optimizeSkewedPartitions( shuffleId: Int, bytesByPartitionId: Array[Long], targetSize: Long): Seq[ShufflePartitionSpec] = { + val smallPartitionFactor = + conf.getConf(SQLConf.ADAPTIVE_REBALANCE_PARTITIONS_SMALL_PARTITION_FACTOR) bytesByPartitionId.indices.flatMap { reduceIndex => val bytes = bytesByPartitionId(reduceIndex) if (bytes > targetSize) { - val newPartitionSpec = - ShufflePartitionsUtil.createSkewPartitionSpecs(shuffleId, reduceIndex, targetSize) + val newPartitionSpec = ShufflePartitionsUtil.createSkewPartitionSpecs( + shuffleId, reduceIndex, targetSize, smallPartitionFactor) if (newPartitionSpec.isEmpty) { CoalescedPartitionSpec(reduceIndex, reduceIndex + 1, bytes) :: Nil } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala index 3609548f37..0251f80378 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/ShufflePartitionsUtil.scala @@ -316,7 +316,10 @@ object ShufflePartitionsUtil extends Logging { * start of a partition. */ // Visible for testing - private[sql] def splitSizeListByTargetSize(sizes: Seq[Long], targetSize: Long): Array[Int] = { + private[sql] def splitSizeListByTargetSize( + sizes: Seq[Long], + targetSize: Long, + smallPartitionFactor: Double): Array[Int] = { val partitionStartIndices = ArrayBuffer[Int]() partitionStartIndices += 0 var i = 0 @@ -329,8 +332,8 @@ object ShufflePartitionsUtil extends Logging { // the previous partition. val shouldMergePartitions = lastPartitionSize > -1 && ((currentPartitionSize + lastPartitionSize) < targetSize * MERGED_PARTITION_FACTOR || - (currentPartitionSize < targetSize * SMALL_PARTITION_FACTOR || - lastPartitionSize < targetSize * SMALL_PARTITION_FACTOR)) + (currentPartitionSize < targetSize * smallPartitionFactor || + lastPartitionSize < targetSize * smallPartitionFactor)) if (shouldMergePartitions) { // We decide to merge the current partition into the previous one, so the start index of // the current partition should be removed. @@ -371,15 +374,18 @@ object ShufflePartitionsUtil extends Logging { /** * Splits the skewed partition based on the map size and the target partition size - * after split, and create a list of `PartialMapperPartitionSpec`. Returns None if can't split. + * after split, and create a list of `PartialReducerPartitionSpec`. Returns None if can't split. 
*/ def createSkewPartitionSpecs( shuffleId: Int, reducerId: Int, - targetSize: Long): Option[Seq[PartialReducerPartitionSpec]] = { + targetSize: Long, + smallPartitionFactor: Double = SMALL_PARTITION_FACTOR) + : Option[Seq[PartialReducerPartitionSpec]] = { val mapPartitionSizes = getMapSizesForReduceId(shuffleId, reducerId) if (mapPartitionSizes.exists(_ < 0)) return None - val mapStartIndices = splitSizeListByTargetSize(mapPartitionSizes, targetSize) + val mapStartIndices = splitSizeListByTargetSize( + mapPartitionSizes, targetSize, smallPartitionFactor) if (mapStartIndices.length > 1) { Some(mapStartIndices.indices.map { i => val startMapIndex = mapStartIndices(i) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 80ab07b159..a7e505ebd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -128,7 +128,7 @@ case class DataSource( .getOrElse(true) } - bucketSpec.map { bucket => + bucketSpec.foreach { bucket => SchemaUtils.checkColumnNameDuplication( bucket.bucketColumnNames, "in the bucket definition", equality) SchemaUtils.checkColumnNameDuplication( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index 0f14a2a94e..88543bd19b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -135,7 +135,7 @@ object PartitioningUtils extends SQLConfHelper{ Map.empty[String, String] } - val dateFormatter = DateFormatter() + val dateFormatter = DateFormatter(DateFormatter.defaultPattern) val timestampFormatter = TimestampFormatter( timestampPartitionPattern, zoneId, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 0e8efb6297..327d92672d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.expressions.{FieldReference, RewritableTransform} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.command.DDLUtils +import org.apache.spark.sql.execution.datasources.{CreateTable => CreateTableV1} import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.sources.InsertableRelation import org.apache.spark.sql.types.{AtomicType, StructType} @@ -81,7 +82,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi // bucketing information is specified, as we can't infer bucketing from data files currently. // Since the runtime inferred partition columns could be different from what user specified, // we fail the query if the partitioning information is specified. 
- case c @ CreateTable(tableDesc, _, None) if tableDesc.schema.isEmpty => + case c @ CreateTableV1(tableDesc, _, None) if tableDesc.schema.isEmpty => if (tableDesc.bucketSpec.isDefined) { failAnalysis("Cannot specify bucketing information if the table schema is not specified " + "when creating and will be inferred at runtime") @@ -96,7 +97,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi // When we append data to an existing table, check if the given provider, partition columns, // bucket spec, etc. match the existing table, and adjust the columns order of the given query // if necessary. - case c @ CreateTable(tableDesc, SaveMode.Append, Some(query)) + case c @ CreateTableV1(tableDesc, SaveMode.Append, Some(query)) if query.resolved && catalog.tableExists(tableDesc.identifier) => // This is guaranteed by the parser and `DataFrameWriter` assert(tableDesc.provider.isDefined) @@ -189,7 +190,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi // * partition columns' type must be AtomicType. // * sort columns' type must be orderable. // * reorder table schema or output of query plan, to put partition columns at the end. - case c @ CreateTable(tableDesc, _, query) if query.forall(_.resolved) => + case c @ CreateTableV1(tableDesc, _, query) if query.forall(_.resolved) => if (query.isDefined) { assert(tableDesc.schema.isEmpty, "Schema may not be specified in a Create Table As Select (CTAS) statement") @@ -433,7 +434,7 @@ object PreprocessTableInsertion extends Rule[LogicalPlan] { object HiveOnlyCheck extends (LogicalPlan => Unit) { def apply(plan: LogicalPlan): Unit = { plan.foreach { - case CreateTable(tableDesc, _, _) if DDLUtils.isHiveTable(tableDesc) => + case CreateTableV1(tableDesc, _, _) if DDLUtils.isHiveTable(tableDesc) => throw QueryCompilationErrors.ddlWithoutHiveSupportEnabledError( "CREATE Hive TABLE (AS SELECT)") case i: InsertIntoDir if DDLUtils.isHiveTable(i.provider) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala index be7331b0d7..abc6bc60d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala @@ -22,7 +22,8 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog} +import org.apache.spark.sql.catalyst.plans.logical.TableSpec +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, TableCatalog} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.StructType @@ -32,10 +33,12 @@ case class CreateTableExec( identifier: Identifier, tableSchema: StructType, partitioning: Seq[Transform], - tableProperties: Map[String, String], + tableSpec: TableSpec, ignoreIfExists: Boolean) extends LeafV2CommandExec { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + val tableProperties = CatalogV2Util.convertTableProperties(tableSpec) + override protected def run(): Seq[InternalRow] = { if (!catalog.tableExists(identifier)) { try { diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index cbfeaa4f5d..8a82f36f4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import org.apache.spark.sql.{SparkSession, Strategy} import org.apache.spark.sql.catalyst.analysis.{ResolvedDBObjectName, ResolvedNamespace, ResolvedPartitionSpec, ResolvedTable} +import org.apache.spark.sql.catalyst.catalog.CatalogUtils import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning, EmptyRow, Expression, Literal, NamedExpression, PredicateHelper, SubqueryExpression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation @@ -38,6 +39,7 @@ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors import org.apache.spark.sql.execution.{FilterExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumn, PushableColumnBase} import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} +import org.apache.spark.sql.internal.StaticSQLConf.WAREHOUSE_PATH import org.apache.spark.sql.sources.{BaseRelation, TableScan} import org.apache.spark.sql.types.{BooleanType, StringType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -92,6 +94,11 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat session.sharedState.cacheManager.uncacheQuery(session, v2Relation, cascade = true) } + private def makeQualifiedDBObjectPath(location: String): String = { + CatalogUtils.makeQualifiedDBObjectPath(session.sharedState.conf.get(WAREHOUSE_PATH), + location, session.sharedState.hadoopConf) + } + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(project, filters, DataSourceV2ScanRelation( _, V1ScanWrapper(scan, pushed, pushedDownOperators), output)) => @@ -156,9 +163,11 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat } WriteToDataSourceV2Exec(writer, invalidateCacheFunc, planLater(query), customMetrics) :: Nil - case CreateV2Table(catalog, ident, schema, parts, props, ifNotExists) => - val propsWithOwner = CatalogV2Util.withDefaultOwnership(props) - CreateTableExec(catalog, ident, schema, parts, propsWithOwner, ifNotExists) :: Nil + case CreateTable(ResolvedDBObjectName(catalog, ident), schema, partitioning, + tableSpec, ifNotExists) => + val qualifiedLocation = tableSpec.location.map(makeQualifiedDBObjectPath(_)) + CreateTableExec(catalog.asTableCatalog, ident.asIdentifier, schema, + partitioning, tableSpec.copy(location = qualifiedLocation), ifNotExists) :: Nil case CreateTableAsSelect(catalog, ident, parts, query, props, options, ifNotExists) => val propsWithOwner = CatalogV2Util.withDefaultOwnership(props) @@ -176,7 +185,10 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat RefreshTableExec(r.catalog, r.identifier, recacheTable(r)) :: Nil case ReplaceTable(catalog, ident, schema, parts, props, orCreate) => - val propsWithOwner = CatalogV2Util.withDefaultOwnership(props) + val newProps = 
props.get(TableCatalog.PROP_LOCATION).map { loc => + props + (TableCatalog.PROP_LOCATION -> makeQualifiedDBObjectPath(loc)) + }.getOrElse(props) + val propsWithOwner = CatalogV2Util.withDefaultOwnership(newProps) catalog match { case staging: StagingTableCatalog => AtomicReplaceTableExec( @@ -188,29 +200,29 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat invalidateCache) :: Nil } - case ReplaceTableAsSelect(catalog, ident, parts, query, props, options, orCreate) => - val propsWithOwner = CatalogV2Util.withDefaultOwnership(props) + case ReplaceTableAsSelect(ResolvedDBObjectName(catalog, ident), + parts, query, tableSpec, options, orCreate) => val writeOptions = new CaseInsensitiveStringMap(options.asJava) catalog match { case staging: StagingTableCatalog => AtomicReplaceTableAsSelectExec( staging, - ident, + ident.asIdentifier, parts, query, planLater(query), - propsWithOwner, + tableSpec, writeOptions, orCreate = orCreate, invalidateCache) :: Nil case _ => ReplaceTableAsSelectExec( - catalog, - ident, + catalog.asTableCatalog, + ident.asIdentifier, parts, query, planLater(query), - propsWithOwner, + tableSpec, writeOptions, orCreate = orCreate, invalidateCache) :: Nil @@ -314,7 +326,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat AlterNamespaceSetPropertiesExec( catalog.asNamespaceCatalog, ns, - Map(SupportsNamespaces.PROP_LOCATION -> location)) :: Nil + Map(SupportsNamespaces.PROP_LOCATION -> makeQualifiedDBObjectPath(location))) :: Nil case CommentOnNamespace(ResolvedNamespace(catalog, ns), comment) => AlterNamespaceSetPropertiesExec( @@ -323,7 +335,10 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat Map(SupportsNamespaces.PROP_COMMENT -> comment)) :: Nil case CreateNamespace(ResolvedDBObjectName(catalog, name), ifNotExists, properties) => - CreateNamespaceExec(catalog.asNamespaceCatalog, name, ifNotExists, properties) :: Nil + val finalProperties = properties.get(SupportsNamespaces.PROP_LOCATION).map { loc => + properties + (SupportsNamespaces.PROP_LOCATION -> makeQualifiedDBObjectPath(loc)) + }.getOrElse(properties) + CreateNamespaceExec(catalog.asNamespaceCatalog, name, ifNotExists, finalProperties) :: Nil case DropNamespace(ResolvedNamespace(catalog, ns), ifExists, cascade) => DropNamespaceExec(catalog, ns, ifExists, cascade) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index eeb12c5052..f69a2a4588 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -25,10 +25,11 @@ import com.fasterxml.jackson.databind.ObjectMapper import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.catalyst.analysis.TimeTravelSpec +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SessionConfigSupport, SupportsCatalogOptions, SupportsRead, Table, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability.BATCH_READ -import org.apache.spark.sql.connector.expressions.TimeTravelSpec import org.apache.spark.sql.errors.QueryExecutionErrors import 
org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType @@ -100,8 +101,8 @@ private[sql] object DataSourceV2Utils extends Logging { source: String, paths: String*): Option[DataFrame] = { val catalogManager = sparkSession.sessionState.catalogManager - val sessionOptions = DataSourceV2Utils.extractSessionConfigs( - source = provider, conf = sparkSession.sessionState.conf) + val conf = sparkSession.sessionState.conf + val sessionOptions = DataSourceV2Utils.extractSessionConfigs(provider, conf) val optionsWithPath = getOptionsWithPaths(extraOptions, paths: _*) @@ -123,8 +124,8 @@ private[sql] object DataSourceV2Utils extends Logging { val timestamp = hasCatalog.extractTimeTravelTimestamp(dsOptions) val timeTravelVersion = if (version.isPresent) Some(version.get) else None - val timeTravelTimestamp = if (timestamp.isPresent) Some(timestamp.get) else None - val timeTravel = TimeTravelSpec.create(timeTravelTimestamp, timeTravelVersion) + val timeTravelTimestamp = if (timestamp.isPresent) Some(Literal(timestamp.get)) else None + val timeTravel = TimeTravelSpec.create(timeTravelTimestamp, timeTravelVersion, conf) (CatalogV2Util.loadTable(catalog, ident, timeTravel).get, Some(catalog), Some(ident)) case _ => // TODO: Non-catalog paths for DSV2 are currently not well defined. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeColumnExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeColumnExec.scala index f7d79a1259..3be9b5c547 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeColumnExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DescribeColumnExec.scala @@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.util.CharVarcharUtils case class DescribeColumnExec( override val output: Seq[Attribute], @@ -37,7 +38,8 @@ case class DescribeColumnExec( } rows += toCatalystRow("col_name", column.name) - rows += toCatalystRow("data_type", column.dataType.catalogString) + rows += toCatalystRow("data_type", + CharVarcharUtils.getRawType(column.metadata).getOrElse(column.dataType).catalogString) rows += toCatalystRow("comment", comment) // TODO: The extended description (isExtended = true) can be added here. 
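For illustration, a minimal sketch (not part of the patch) of how the DescribeColumnExec change above surfaces to users: DESCRIBE COLUMN on a v2 table now reports the declared char/varchar type recovered from the column metadata rather than the underlying string type. It assumes a v2 catalog registered under the name testcat; the namespace, table, column and provider names are placeholders.

import org.apache.spark.sql.SparkSession

object DescribeCharVarcharSketch {
  // Assumes `spark` already has a v2 catalog configured as `testcat`.
  def run(spark: SparkSession): Unit = {
    spark.sql("CREATE TABLE testcat.ns.t(v VARCHAR(3), c CHAR(5)) USING foo")
    // Expected data_type values with this patch: 'varchar(3)' and 'char(5)',
    // recovered via CharVarcharUtils.getRawType from the column metadata.
    spark.sql("DESCRIBE testcat.ns.t v").show(truncate = false)
    spark.sql("DESCRIBE testcat.ns.t c").show(truncate = false)
  }
}
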
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala index 8b0328cabc..21503fda53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala @@ -136,10 +136,10 @@ trait FileScan extends Scan val partitionAttributes = fileIndex.partitionSchema.toAttributes val attributeMap = partitionAttributes.map(a => normalizeName(a.name) -> a).toMap val readPartitionAttributes = readPartitionSchema.map { readField => - attributeMap.get(normalizeName(readField.name)).getOrElse { + attributeMap.getOrElse(normalizeName(readField.name), throw QueryCompilationErrors.cannotFindPartitionColumnInPartitionSchemaError( readField, fileIndex.partitionSchema) - } + ) } lazy val partitionValueProject = GenerateUnsafeProjection.generate(readPartitionAttributes, partitionAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala index 5bfeac9f9e..b3b890e88c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala @@ -71,6 +71,28 @@ class V2SessionCatalog(catalog: SessionCatalog) V1Table(catalogTable) } + override def loadTable(ident: Identifier, timestamp: Long): Table = { + failTimeTravel(ident, loadTable(ident)) + } + + override def loadTable(ident: Identifier, version: String): Table = { + failTimeTravel(ident, loadTable(ident)) + } + + private def failTimeTravel(ident: Identifier, t: Table): Table = { + t match { + case V1Table(catalogTable) => + if (catalogTable.tableType == CatalogTableType.VIEW) { + throw QueryCompilationErrors.viewNotSupportTimeTravelError( + ident.namespace() :+ ident.name()) + } else { + throw QueryCompilationErrors.tableNotSupportTimeTravelError(ident) + } + + case _ => throw QueryCompilationErrors.tableNotSupportTimeTravelError(ident) + } + } + override def invalidateTable(ident: Identifier): Unit = { catalog.refreshTable(ident.asTableIdentifier) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index add698f990..c61ef56eaf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -28,9 +28,9 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TableSpec, UnaryNode} import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.connector.catalog.{Identifier, StagedTable, StagingTableCatalog, SupportsWrite, Table, TableCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagedTable, StagingTableCatalog, SupportsWrite, Table, TableCatalog} import 
org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.metric.CustomMetric import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, LogicalWriteInfoImpl, PhysicalWriteInfoImpl, V1Write, Write, WriterCommitMessage} @@ -147,11 +147,13 @@ case class ReplaceTableAsSelectExec( partitioning: Seq[Transform], plan: LogicalPlan, query: SparkPlan, - properties: Map[String, String], + tableSpec: TableSpec, writeOptions: CaseInsensitiveStringMap, orCreate: Boolean, invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends TableWriteExecHelper { + val properties = CatalogV2Util.convertTableProperties(tableSpec) + override protected def run(): Seq[InternalRow] = { // Note that this operation is potentially unsafe, but these are the strict semantics of // RTAS if the catalog does not support atomic operations. @@ -196,11 +198,13 @@ case class AtomicReplaceTableAsSelectExec( partitioning: Seq[Transform], plan: LogicalPlan, query: SparkPlan, - properties: Map[String, String], + tableSpec: TableSpec, writeOptions: CaseInsensitiveStringMap, orCreate: Boolean, invalidateCache: (TableCatalog, Table, Identifier) => Unit) extends TableWriteExecHelper { + val properties = CatalogV2Util.convertTableProperties(tableSpec) + override protected def run(): Seq[InternalRow] = { val schema = CharVarcharUtils.getRawSchema(query.schema, conf).asNullable if (catalog.tableExists(ident)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala index 60c66d863a..0893875aff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala @@ -26,19 +26,15 @@ class ForeachBatchSink[T](batchWriter: (Dataset[T], Long) => Unit, encoder: Expr extends Sink { override def addBatch(batchId: Long, data: DataFrame): Unit = { - val resolvedEncoder = encoder.resolveAndBind( - data.logicalPlan.output, - data.sparkSession.sessionState.analyzer) - val fromRow = resolvedEncoder.createDeserializer() - val rdd = data.queryExecution.toRdd.map[T](fromRow)(encoder.clsTag) - val ds = data.sparkSession.createDataset(rdd)(encoder) + val rdd = data.queryExecution.toRdd + implicit val enc = encoder + val ds = data.sparkSession.internalCreateDataFrame(rdd, data.schema).as[T] batchWriter(ds, batchId) } override def toString(): String = "ForeachBatchSink" } - /** * Interface that is meant to be extended by Python classes via Py4J. 
* Py4J allows Python classes to implement Java interfaces so that the JVM can call back diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index a2b33c2ba3..c88e6ae3f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -116,7 +116,7 @@ private[sql] class RocksDBStateStoreProvider rocksDBMetrics.nativeOpsHistograms.get(typ).map(_.count).getOrElse(0) } def nativeOpsMetrics(typ: String): Long = { - rocksDBMetrics.nativeOpsMetrics.get(typ).getOrElse(0) + rocksDBMetrics.nativeOpsMetrics.getOrElse(typ, 0) } val stateStoreCustomMetrics = Map[StateStoreCustomMetric, Long]( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala index b15c70a7eb..b8575b052b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -81,7 +81,8 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging summary ++ planVisualization(request, metrics, graph) ++ - physicalPlanDescription(executionUIData.physicalPlanDescription) + physicalPlanDescription(executionUIData.physicalPlanDescription) ++ + modifiedConfigs(executionUIData.modifiedConfigs) }.getOrElse {

<div>No information to display for query {executionId}</div>
} @@ -145,4 +146,28 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging
} + + private def modifiedConfigs(modifiedConfigs: Map[String, String]): Seq[Node] = { + val configs = UIUtils.listingTable( + propertyHeader, + propertyRow, + modifiedConfigs.toSeq.sorted, + fixedWidth = true + ) + +
+ + + SQL Properties + + +
+
+ } + + private def propertyHeader = Seq("Name", "Value") + private def propertyRow(kv: (String, String)) = {kv._1}{kv._2} } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala index e7ab4a184b..d892dbdc23 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListener.scala @@ -93,6 +93,7 @@ class SQLAppStatusListener( executionData.description = sqlStoreData.description executionData.details = sqlStoreData.details executionData.physicalPlanDescription = sqlStoreData.physicalPlanDescription + executionData.modifiedConfigs = sqlStoreData.modifiedConfigs executionData.metrics = sqlStoreData.metrics executionData.submissionTime = sqlStoreData.submissionTime executionData.completionTime = sqlStoreData.completionTime @@ -336,7 +337,7 @@ class SQLAppStatusListener( private def onExecutionStart(event: SparkListenerSQLExecutionStart): Unit = { val SparkListenerSQLExecutionStart(executionId, description, details, - physicalPlanDescription, sparkPlanInfo, time) = event + physicalPlanDescription, sparkPlanInfo, time, modifiedConfigs) = event val planGraph = SparkPlanGraph(sparkPlanInfo) val sqlPlanMetrics = planGraph.allNodes.flatMap { node => @@ -353,6 +354,7 @@ class SQLAppStatusListener( exec.description = description exec.details = details exec.physicalPlanDescription = physicalPlanDescription + exec.modifiedConfigs = modifiedConfigs exec.metrics = sqlPlanMetrics exec.submissionTime = time update(exec) @@ -479,6 +481,7 @@ private class LiveExecutionData(val executionId: Long) extends LiveEntity { var description: String = null var details: String = null var physicalPlanDescription: String = null + var modifiedConfigs: Map[String, String] = _ var metrics = Seq[SQLPlanMetric]() var submissionTime = -1L var completionTime: Option[Date] = None @@ -499,6 +502,7 @@ private class LiveExecutionData(val executionId: Long) extends LiveEntity { description, details, physicalPlanDescription, + modifiedConfigs, metrics, submissionTime, completionTime, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala index a90f37a80d..7c3315e3d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLAppStatusStore.scala @@ -86,6 +86,7 @@ class SQLExecutionUIData( val description: String, val details: String, val physicalPlanDescription: String, + val modifiedConfigs: Map[String, String], val metrics: Seq[SQLPlanMetric], val submissionTime: Long, val completionTime: Option[Date], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index 6a6a71c46f..26805e135b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -47,7 +47,8 @@ case class SparkListenerSQLExecutionStart( details: String, physicalPlanDescription: String, sparkPlanInfo: SparkPlanInfo, - time: Long) + time: Long, + modifiedConfigs: Map[String, String] = Map.empty) extends SparkListenerEvent @DeveloperApi diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 10ce9d3aaf..2d3c89874f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -27,8 +27,9 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.Evolving import org.apache.spark.api.java.function.VoidFunction2 import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedDBObjectName import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.plans.logical.CreateTableStatement +import org.apache.spark.sql.catalyst.plans.logical.{CreateTable, TableSpec} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.catalog.{Identifier, SupportsWrite, Table, TableCatalog, TableProvider, V1Table, V2TableWithV1Fallback} @@ -288,10 +289,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * Note, currently the new table creation by this API doesn't fully cover the V2 table. * TODO (SPARK-33638): Full support of v2 table creation */ - val cmd = CreateTableStatement( - originalMultipartIdentifier, - df.schema.asNullable, - partitioningColumns.getOrElse(Nil).asTransforms.toSeq, + val tableProperties = TableSpec( None, Map.empty[String, String], Some(source), @@ -299,8 +297,15 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { extraOptions.get("path"), None, None, - external = false, - ifNotExists = false) + false) + val cmd = CreateTable( + UnresolvedDBObjectName( + originalMultipartIdentifier, + isNamespace = false), + df.schema.asNullable, + partitioningColumns.getOrElse(Nil).asTransforms.toSeq, + tableProperties, + ignoreIfExists = false) Dataset.ofRows(df.sparkSession, cmd) } diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 821f566c63..6a4d615924 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -1,6 +1,6 @@ ## Summary - - Number of queries: 367 + - Number of queries: 368 - Number of expressions that missing example: 12 - Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint ## Schema of Built-in Functions @@ -75,6 +75,7 @@ | org.apache.spark.sql.catalyst.expressions.Coalesce | coalesce | SELECT coalesce(NULL, 1, NULL) | struct | | org.apache.spark.sql.catalyst.expressions.Concat | concat | SELECT concat('Spark', 'SQL') | struct | | org.apache.spark.sql.catalyst.expressions.ConcatWs | concat_ws | SELECT concat_ws(' ', 'Spark', 'SQL') | struct | +| org.apache.spark.sql.catalyst.expressions.Contains | contains | SELECT contains('Spark SQL', 'Spark') | struct | | org.apache.spark.sql.catalyst.expressions.Conv | conv | SELECT conv('100', 2, 10) | struct | | org.apache.spark.sql.catalyst.expressions.Cos | cos | SELECT cos(0) | struct | | org.apache.spark.sql.catalyst.expressions.Cosh | cosh | SELECT cosh(0) | struct | diff --git a/sql/core/src/test/resources/sql-tests/inputs/comments.sql b/sql/core/src/test/resources/sql-tests/inputs/comments.sql index 19f11de22d..da5e57a942 100644 --- 
a/sql/core/src/test/resources/sql-tests/inputs/comments.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/comments.sql @@ -88,3 +88,32 @@ Other information of first level. /*/**/*/ SELECT 'selected content' AS tenth; --QUERY-DELIMITER-END + +-- the first case of unclosed bracketed comment +--QUERY-DELIMITER-START +/*abc*/ +select 1 as a +/* + +2 as b +/*abc*/ +, 3 as c + +/**/ +; +--QUERY-DELIMITER-END + +-- the second case of unclosed bracketed comment +--QUERY-DELIMITER-START +/*abc*/ +select 1 as a +/* + +2 as b +/*abc*/ +, 3 as c + +/**/ +select 4 as d +; +--QUERY-DELIMITER-END diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index 61a5a318dd..f2710848b8 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -104,4 +104,12 @@ select decode(1, 1, 'Southlake'); select decode(2, 1, 'Southlake'); select decode(2, 1, 'Southlake', 2, 'San Francisco', 3, 'New Jersey', 4, 'Seattle', 'Non domestic'); select decode(6, 1, 'Southlake', 2, 'San Francisco', 3, 'New Jersey', 4, 'Seattle', 'Non domestic'); -select decode(6, 1, 'Southlake', 2, 'San Francisco', 3, 'New Jersey', 4, 'Seattle'); \ No newline at end of file +select decode(6, 1, 'Southlake', 2, 'San Francisco', 3, 'New Jersey', 4, 'Seattle'); + +-- contains +SELECT CONTAINS(null, 'Spark'); +SELECT CONTAINS('Spark SQL', null); +SELECT CONTAINS(null, null); +SELECT CONTAINS('Spark SQL', 'Spark'); +SELECT CONTAINS('Spark SQL', 'SQL'); +SELECT CONTAINS('Spark SQL', 'SPARK'); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out index b95c8dac9a..c3c09778a2 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out @@ -230,7 +230,8 @@ select next_day(timestamp_ntz"2015-07-23 12:12:12", "Mon") struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'next_day(TIMESTAMP_NTZ '2015-07-23 12:12:12', 'Mon')' due to data type mismatch: argument 1 requires date type, however, 'TIMESTAMP_NTZ '2015-07-23 12:12:12'' is of timestamp_ntz type.; line 1 pos 7 +cannot resolve 'next_day(TIMESTAMP_NTZ '2015-07-23 12:12:12', 'Mon')' due to data type mismatch: argument 1 requires date type, however, 'TIMESTAMP_NTZ '2015-07-23 12:12:12'' is of timestamp_ntz type. +To fix the error, you might need to add explicit type casts. If necessary set spark.sql.ansi.enabled to false to bypass this error.; line 1 pos 7 -- !query @@ -498,7 +499,8 @@ select date_add(date_str, 1) from date_view struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'date_add(date_view.date_str, 1)' due to data type mismatch: argument 1 requires date type, however, 'date_view.date_str' is of string type.; line 1 pos 7 +cannot resolve 'date_add(date_view.date_str, 1)' due to data type mismatch: argument 1 requires date type, however, 'date_view.date_str' is of string type. +To fix the error, you might need to add explicit type casts. 
If necessary set spark.sql.ansi.enabled to false to bypass this error.; line 1 pos 7 -- !query @@ -507,7 +509,8 @@ select date_sub(date_str, 1) from date_view struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'date_sub(date_view.date_str, 1)' due to data type mismatch: argument 1 requires date type, however, 'date_view.date_str' is of string type.; line 1 pos 7 +cannot resolve 'date_sub(date_view.date_str, 1)' due to data type mismatch: argument 1 requires date type, however, 'date_view.date_str' is of string type. +To fix the error, you might need to add explicit type casts. If necessary set spark.sql.ansi.enabled to false to bypass this error.; line 1 pos 7 -- !query @@ -589,7 +592,8 @@ select date_str - date '2001-09-28' from date_view struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve '(date_view.date_str - DATE '2001-09-28')' due to data type mismatch: argument 1 requires date type, however, 'date_view.date_str' is of string type.; line 1 pos 7 +cannot resolve '(date_view.date_str - DATE '2001-09-28')' due to data type mismatch: argument 1 requires date type, however, 'date_view.date_str' is of string type. +To fix the error, you might need to add explicit type casts. If necessary set spark.sql.ansi.enabled to false to bypass this error.; line 1 pos 7 -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out index e9c323254b..230393f02a 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out @@ -1533,7 +1533,8 @@ select str - interval '4 22:12' day to minute from interval_view struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'interval_view.str + (- INTERVAL '4 22:12' DAY TO MINUTE)' due to data type mismatch: argument 1 requires (timestamp or timestamp without time zone) type, however, 'interval_view.str' is of string type.; line 1 pos 7 +cannot resolve 'interval_view.str + (- INTERVAL '4 22:12' DAY TO MINUTE)' due to data type mismatch: argument 1 requires (timestamp or timestamp without time zone) type, however, 'interval_view.str' is of string type. +To fix the error, you might need to add explicit type casts. If necessary set spark.sql.ansi.enabled to false to bypass this error.; line 1 pos 7 -- !query @@ -1542,7 +1543,8 @@ select str + interval '4 22:12' day to minute from interval_view struct<> -- !query output org.apache.spark.sql.AnalysisException -cannot resolve 'interval_view.str + INTERVAL '4 22:12' DAY TO MINUTE' due to data type mismatch: argument 1 requires (timestamp or timestamp without time zone) type, however, 'interval_view.str' is of string type.; line 1 pos 7 +cannot resolve 'interval_view.str + INTERVAL '4 22:12' DAY TO MINUTE' due to data type mismatch: argument 1 requires (timestamp or timestamp without time zone) type, however, 'interval_view.str' is of string type. +To fix the error, you might need to add explicit type casts. 
If necessary set spark.sql.ansi.enabled to false to bypass this error.; line 1 pos 7 -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out index 45d403859a..a81a34b7c6 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 78 +-- Number of queries: 84 -- !query @@ -632,3 +632,51 @@ select decode(6, 1, 'Southlake', 2, 'San Francisco', 3, 'New Jersey', 4, 'Seattl struct -- !query output NULL + + +-- !query +SELECT CONTAINS(null, 'Spark') +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT CONTAINS('Spark SQL', null) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT CONTAINS(null, null) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT CONTAINS('Spark SQL', 'Spark') +-- !query schema +struct +-- !query output +true + + +-- !query +SELECT CONTAINS('Spark SQL', 'SQL') +-- !query schema +struct +-- !query output +true + + +-- !query +SELECT CONTAINS('Spark SQL', 'SPARK') +-- !query schema +struct +-- !query output +false diff --git a/sql/core/src/test/resources/sql-tests/results/comments.sql.out b/sql/core/src/test/resources/sql-tests/results/comments.sql.out index fd58a33595..da9dbd5fa3 100644 --- a/sql/core/src/test/resources/sql-tests/results/comments.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/comments.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 10 +-- Number of queries: 12 -- !query @@ -119,3 +119,65 @@ SELECT 'selected content' AS tenth struct -- !query output selected content + + +-- !query +/*abc*/ +select 1 as a +/* + +2 as b +/*abc*/ +, 3 as c + +/**/ +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException + +Unclosed bracketed comment(line 3, pos 0) + +== SQL == +/*abc*/ +select 1 as a +/* +^^^ + +2 as b +/*abc*/ +, 3 as c + +/**/ + + +-- !query +/*abc*/ +select 1 as a +/* + +2 as b +/*abc*/ +, 3 as c + +/**/ +select 4 as d +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException + +Unclosed bracketed comment(line 3, pos 0) + +== SQL == +/*abc*/ +select 1 as a +/* +^^^ + +2 as b +/*abc*/ +, 3 as c + +/**/ +select 4 as d diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/union.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/union.sql.out index 13f3fe064a..84dcf3aca7 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/union.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/union.sql.out @@ -686,6 +686,7 @@ struct<> -- !query output org.apache.spark.sql.AnalysisException Union can only be performed on tables with the compatible column types. The first column of the second table is string type which is not compatible with decimal(38,18) at same column of first table +To fix the error, you might need to add explicit type casts. If necessary set spark.sql.ansi.enabled to false to bypass this error. 
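As a small, hedged illustration of the extra hint appended to these ANSI type-mismatch messages, the sketch below (not part of the patch) shows the two remedies the message suggests; the view and column names mirror the golden files above and are only placeholders.

import org.apache.spark.sql.SparkSession

object AnsiCastHintSketch {
  def run(spark: SparkSession): Unit = {
    spark.sql("CREATE OR REPLACE TEMP VIEW date_view AS SELECT '2021-12-01' AS date_str")
    // Remedy 1: add the explicit cast the message recommends.
    spark.sql("SELECT date_add(CAST(date_str AS DATE), 1) FROM date_view").show()
    // Remedy 2: fall back to the legacy implicit casts by disabling ANSI mode.
    spark.conf.set("spark.sql.ansi.enabled", "false")
    spark.sql("SELECT date_add(date_str, 1) FROM date_view").show()
  }
}
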
-- !query diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 9249f94acd..d452df8bc5 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 78 +-- Number of queries: 84 -- !query @@ -628,3 +628,51 @@ select decode(6, 1, 'Southlake', 2, 'San Francisco', 3, 'New Jersey', 4, 'Seattl struct -- !query output NULL + + +-- !query +SELECT CONTAINS(null, 'Spark') +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT CONTAINS('Spark SQL', null) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT CONTAINS(null, null) +-- !query schema +struct +-- !query output +NULL + + +-- !query +SELECT CONTAINS('Spark SQL', 'Spark') +-- !query schema +struct +-- !query output +true + + +-- !query +SELECT CONTAINS('Spark SQL', 'SQL') +-- !query schema +struct +-- !query output +true + + +-- !query +SELECT CONTAINS('Spark SQL', 'SPARK') +-- !query schema +struct +-- !query output +false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala index 7be54d49a9..f2df9af9ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala @@ -843,17 +843,6 @@ class FileSourceCharVarcharTestSuite extends CharVarcharTestSuite with SharedSpa } } - // TODO(SPARK-33875): Move these tests to super after DESCRIBE COLUMN v2 implemented - test("SPARK-33892: DESCRIBE COLUMN w/ char/varchar") { - withTable("t") { - sql(s"CREATE TABLE t(v VARCHAR(3), c CHAR(5)) USING $format") - checkAnswer(sql("desc t v").selectExpr("info_value").where("info_value like '%char%'"), - Row("varchar(3)")) - checkAnswer(sql("desc t c").selectExpr("info_value").where("info_value like '%char%'"), - Row("char(5)")) - } - } - // TODO(SPARK-33898): Move these tests to super after SHOW CREATE TABLE for v2 implemented test("SPARK-33892: SHOW CREATE TABLE w/ char/varchar") { withTable("t") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index c87314386f..2808652f29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -368,4 +368,15 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { .selectExpr("value.a") checkAnswer(fromCsvDF, Row(localDT)) } + + test("SPARK-37326: Handle incorrectly formatted timestamp_ntz values in from_csv") { + val fromCsvDF = Seq("2021-08-12T15:16:23.000+11:00").toDF("csv") + .select( + from_csv( + $"csv", + StructType(StructField("a", TimestampNTZType) :: Nil), + Map.empty[String, String]) as "value") + .selectExpr("value.a") + checkAnswer(fromCsvDF, Row(null)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala index a090eba430..76b3324e3e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala @@ -27,11 +27,11 @@ import org.scalatest.Assertions._ 
import org.apache.spark.TestUtils import org.apache.spark.api.python.{PythonBroadcast, PythonEvalType, PythonFunction, PythonUtils} import org.apache.spark.broadcast.Broadcast -import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExprId, PythonUDF} import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.expressions.SparkUserDefinedFunction -import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.types.{DataType, StringType} /** * This object targets to integrate various UDF test cases so that Scalar UDF, Python UDF and @@ -218,6 +218,29 @@ object IntegratedUDFTestUtils extends SQLHelper { val prettyName: String } + class PythonUDFWithoutId( + name: String, + func: PythonFunction, + dataType: DataType, + children: Seq[Expression], + evalType: Int, + udfDeterministic: Boolean, + resultId: ExprId) + extends PythonUDF(name, func, dataType, children, evalType, udfDeterministic, resultId) { + + def this(pudf: PythonUDF) = { + this(pudf.name, pudf.func, pudf.dataType, pudf.children, + pudf.evalType, pudf.udfDeterministic, pudf.resultId) + } + + override def toString: String = s"$name(${children.mkString(", ")})" + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): PythonUDFWithoutId = { + new PythonUDFWithoutId(super.withNewChildrenInternal(newChildren)) + } + } + /** * A Python UDF that takes one column, casts into string, executes the Python native function, * and casts back to the type of input column. @@ -253,7 +276,9 @@ object IntegratedUDFTestUtils extends SQLHelper { val expr = e.head assert(expr.resolved, "column should be resolved to use the same type " + "as input. Try df(name) or df.col(name)") - Cast(super.builder(Cast(expr, StringType) :: Nil), expr.dataType) + val pythonUDF = new PythonUDFWithoutId( + super.builder(Cast(expr, StringType) :: Nil).asInstanceOf[PythonUDF]) + Cast(pythonUDF, expr.dataType) } } @@ -297,7 +322,9 @@ object IntegratedUDFTestUtils extends SQLHelper { val expr = e.head assert(expr.resolved, "column should be resolved to use the same type " + "as input. 
Try df(name) or df.col(name)") - Cast(super.builder(Cast(expr, StringType) :: Nil), expr.dataType) + val pythonUDF = new PythonUDFWithoutId( + super.builder(Cast(expr, StringType) :: Nil).asInstanceOf[PythonUDF]) + Cast(pythonUDF, expr.dataType) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQueryTestSuite.scala index d776915f3c..9ac5fb6d03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQueryTestSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql import java.io.File import java.nio.file.{Files, Paths} +import scala.collection.JavaConverters._ + import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.catalyst.util.{fileToString, resourceToString, stringToFile} import org.apache.spark.sql.internal.SQLConf @@ -100,9 +102,9 @@ class TPCDSQueryTestSuite extends QueryTest with TPCDSBase with SQLQueryTestHelp private def runQuery( query: String, goldenFile: File, - conf: Seq[(String, String)], - needSort: Boolean): Unit = { - withSQLConf(conf: _*) { + conf: Map[String, String]): Unit = { + val shouldSortResults = sortMergeJoinConf != conf // Sort for other joins + withSQLConf(conf.toSeq: _*) { try { val (schema, output) = handleExceptions(getNormalizedResult(spark, query)) val queryString = query.trim @@ -139,7 +141,7 @@ class TPCDSQueryTestSuite extends QueryTest with TPCDSBase with SQLQueryTestHelp assertResult(expectedSchema, s"Schema did not match\n$queryString") { schema } - if (needSort) { + if (shouldSortResults) { val expectSorted = expectedOutput.split("\n").sorted.map(_.trim) .mkString("\n").replaceAll("\\s+$", "") val outputSorted = output.sorted.map(_.trim).mkString("\n").replaceAll("\\s+$", "") @@ -171,8 +173,26 @@ class TPCDSQueryTestSuite extends QueryTest with TPCDSBase with SQLQueryTestHelp SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", "spark.sql.join.forceApplyShuffledHashJoin" -> "true") - val joinConfSet: Set[Map[String, String]] = - Set(sortMergeJoinConf, broadcastHashJoinConf, shuffledHashJoinConf); + val allJoinConfCombinations = Seq( + sortMergeJoinConf, broadcastHashJoinConf, shuffledHashJoinConf) + + val joinConfs: Seq[Map[String, String]] = if (regenerateGoldenFiles) { + require( + !sys.env.contains("SPARK_TPCDS_JOIN_CONF"), + "'SPARK_TPCDS_JOIN_CONF' cannot be set together with 'SPARK_GENERATE_GOLDEN_FILES'") + Seq(sortMergeJoinConf) + } else { + sys.env.get("SPARK_TPCDS_JOIN_CONF").map { s => + val p = new java.util.Properties() + p.load(new java.io.StringReader(s)) + Seq(p.asScala.toMap) + }.getOrElse(allJoinConfCombinations) + } + + assert(joinConfs.nonEmpty) + joinConfs.foreach(conf => require( + allJoinConfCombinations.contains(conf), + s"Join configurations [$conf] should be one of $allJoinConfCombinations")) if (tpcdsDataPath.nonEmpty) { tpcdsQueries.foreach { name => @@ -180,13 +200,9 @@ class TPCDSQueryTestSuite extends QueryTest with TPCDSBase with SQLQueryTestHelp classLoader = Thread.currentThread().getContextClassLoader) test(name) { val goldenFile = new File(s"$baseResourcePath/v1_4", s"$name.sql.out") - System.gc() // Workaround for GitHub Actions memory limitation, see also SPARK-37368 - runQuery(queryString, goldenFile, joinConfSet.head.toSeq, false) - if (!regenerateGoldenFiles) { - joinConfSet.tail.foreach { conf => - System.gc() // SPARK-37368 - runQuery(queryString, goldenFile, conf.toSeq, true) - } + joinConfs.foreach { conf => + 
System.gc() // Workaround for GitHub Actions memory limitation, see also SPARK-37368 + runQuery(queryString, goldenFile, conf) } } } @@ -196,13 +212,9 @@ class TPCDSQueryTestSuite extends QueryTest with TPCDSBase with SQLQueryTestHelp classLoader = Thread.currentThread().getContextClassLoader) test(s"$name-v2.7") { val goldenFile = new File(s"$baseResourcePath/v2_7", s"$name.sql.out") - System.gc() // SPARK-37368 - runQuery(queryString, goldenFile, joinConfSet.head.toSeq, false) - if (!regenerateGoldenFiles) { - joinConfSet.tail.foreach { conf => - System.gc() // SPARK-37368 - runQuery(queryString, goldenFile, conf.toSeq, true) - } + joinConfs.foreach { conf => + System.gc() // SPARK-37368 + runQuery(queryString, goldenFile, conf) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 6cbbf680bc..949abfeefc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -126,7 +126,7 @@ class DataSourceV2SQLSuite " PARTITIONED BY (id)" + " TBLPROPERTIES ('bar'='baz')" + " COMMENT 'this is a test table'" + - " LOCATION '/tmp/testcat/table_name'") + " LOCATION 'file:/tmp/testcat/table_name'") val descriptionDf = spark.sql("DESCRIBE TABLE EXTENDED testcat.table_name") assert(descriptionDf.schema.map(field => (field.name, field.dataType)) === Seq( @@ -149,7 +149,7 @@ class DataSourceV2SQLSuite Array("# Detailed Table Information", "", ""), Array("Name", "testcat.table_name", ""), Array("Comment", "this is a test table", ""), - Array("Location", "/tmp/testcat/table_name", ""), + Array("Location", "file:/tmp/testcat/table_name", ""), Array("Provider", "foo", ""), Array(TableCatalog.PROP_OWNER.capitalize, defaultUser, ""), Array("Table Properties", "[bar=baz]", ""))) @@ -1093,6 +1093,24 @@ class DataSourceV2SQLSuite } } + test("SPARK-37456: Location in CreateNamespace should be qualified") { + withNamespace("testcat.ns1.ns2") { + val e = intercept[IllegalArgumentException] { + sql("CREATE NAMESPACE testcat.ns1.ns2 LOCATION ''") + } + assert(e.getMessage.contains("Can not create a Path from an empty string")) + + sql("CREATE NAMESPACE testcat.ns1.ns2 LOCATION '/tmp/ns_test'") + val descriptionDf = sql("DESCRIBE NAMESPACE EXTENDED testcat.ns1.ns2") + assert(descriptionDf.collect() === Seq( + Row("Namespace Name", "ns2"), + Row(SupportsNamespaces.PROP_LOCATION.capitalize, "file:/tmp/ns_test"), + Row(SupportsNamespaces.PROP_OWNER.capitalize, defaultUser), + Row("Properties", "")) + ) + } + } + test("create/replace/alter table - reserved properties") { import TableCatalog._ withSQLConf((SQLConf.LEGACY_PROPERTY_NON_RESERVED.key, "false")) { @@ -1161,8 +1179,9 @@ class DataSourceV2SQLSuite s" ('path'='bar', 'Path'='noop')") val tableCatalog = catalog("testcat").asTableCatalog val identifier = Identifier.of(Array(), "reservedTest") - assert(tableCatalog.loadTable(identifier).properties() - .get(TableCatalog.PROP_LOCATION) == "foo", + val location = tableCatalog.loadTable(identifier).properties() + .get(TableCatalog.PROP_LOCATION) + assert(location.startsWith("file:") && location.endsWith("foo"), "path as a table property should not have side effects") assert(tableCatalog.loadTable(identifier).properties().get("path") == "bar", "path as a table property should not have side effects") @@ -1246,7 +1265,7 @@ class DataSourceV2SQLSuite 
assert(descriptionDf.collect() === Seq( Row("Namespace Name", "ns2"), Row(SupportsNamespaces.PROP_COMMENT.capitalize, "test namespace"), - Row(SupportsNamespaces.PROP_LOCATION.capitalize, "/tmp/ns_test"), + Row(SupportsNamespaces.PROP_LOCATION.capitalize, "file:/tmp/ns_test"), Row(SupportsNamespaces.PROP_OWNER.capitalize, defaultUser), Row("Properties", "((a,b), (b,a), (c,c))")) ) @@ -1294,7 +1313,7 @@ class DataSourceV2SQLSuite assert(descriptionDf.collect() === Seq( Row("Namespace Name", "ns2"), Row(SupportsNamespaces.PROP_COMMENT.capitalize, "test namespace"), - Row(SupportsNamespaces.PROP_LOCATION.capitalize, "/tmp/ns_test_2"), + Row(SupportsNamespaces.PROP_LOCATION.capitalize, "file:/tmp/ns_test_2"), Row(SupportsNamespaces.PROP_OWNER.capitalize, defaultUser), Row("Properties", "")) ) @@ -1994,7 +2013,7 @@ class DataSourceV2SQLSuite |COMMENT 'This is a comment' |TBLPROPERTIES ('prop1' = '1', 'prop2' = '2', 'prop3' = 3, 'prop4' = 4) |PARTITIONED BY (a) - |LOCATION '/tmp' + |LOCATION 'file:/tmp' """.stripMargin) val showDDL = getShowCreateDDL(s"SHOW CREATE TABLE $t") assert(showDDL === Array( @@ -2011,7 +2030,7 @@ class DataSourceV2SQLSuite "'via' = '2')", "PARTITIONED BY (a)", "COMMENT 'This is a comment'", - "LOCATION '/tmp'", + "LOCATION 'file:/tmp'", "TBLPROPERTIES(", "'prop1' = '1',", "'prop2' = '2',", @@ -2893,8 +2912,40 @@ class DataSourceV2SQLSuite } } - test("Mock time travel test") { + test("Check HasPartitionKey from InMemoryPartitionTable") { + val t = "testpart.tbl" + withTable(t) { + sql(s"CREATE TABLE $t (id string) USING foo PARTITIONED BY (key int)") + val table = catalog("testpart").asTableCatalog + .loadTable(Identifier.of(Array(), "tbl")) + .asInstanceOf[InMemoryPartitionTable] + + sql(s"INSERT INTO $t VALUES ('a', 1), ('b', 2), ('c', 3)") + var partKeys = table.data.map(_.partitionKey().getInt(0)) + assert(partKeys.length == 3) + assert(partKeys.toSet == Set(1, 2, 3)) + + sql(s"ALTER TABLE $t DROP PARTITION (key=3)") + partKeys = table.data.map(_.partitionKey().getInt(0)) + assert(partKeys.length == 2) + assert(partKeys.toSet == Set(1, 2)) + + sql(s"ALTER TABLE $t ADD PARTITION (key=4)") + partKeys = table.data.map(_.partitionKey().getInt(0)) + assert(partKeys.length == 3) + assert(partKeys.toSet == Set(1, 2, 4)) + + sql(s"INSERT INTO $t VALUES ('c', 3), ('e', 5)") + partKeys = table.data.map(_.partitionKey().getInt(0)) + assert(partKeys.length == 5) + assert(partKeys.toSet == Set(1, 2, 3, 4, 5)) + } + } + + test("time travel") { sql("use testcat") + // The testing in-memory table simply append the version/timestamp to the table name when + // looking up tables. 
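    // A minimal sketch of the new syntax, assuming a catalog whose tables support time travel
    // (the in-memory test catalog used here resolves snapshots by name, as noted above):
    //   spark.sql("SELECT * FROM t VERSION AS OF 'Snapshot123456789'")
    //   spark.sql("SELECT * FROM t TIMESTAMP AS OF '2019-01-29 00:37:58'")
    // The TIMESTAMP AS OF argument must be a deterministic expression that can be evaluated
    // to a timestamp during analysis (a literal, make_date(...), to_timestamp(...)); the
    // error cases asserted below cover expressions that TimeTravelSpec rejects.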
val t1 = "testcat.tSnapshot123456789" val t2 = "testcat.t2345678910" withTable(t1, t2) { @@ -2910,26 +2961,13 @@ class DataSourceV2SQLSuite === Array(Row(1), Row(2))) assert(sql("SELECT * FROM t VERSION AS OF 2345678910").collect === Array(Row(3), Row(4))) - assert(sql("SELECT * FROM t FOR VERSION AS OF 'Snapshot123456789'").collect - === Array(Row(1), Row(2))) - assert(sql("SELECT * FROM t FOR VERSION AS OF 2345678910").collect - === Array(Row(3), Row(4))) - - assert(sql("SELECT * FROM t FOR SYSTEM_VERSION AS OF 'Snapshot123456789'").collect - === Array(Row(1), Row(2))) - assert(sql("SELECT * FROM t FOR SYSTEM_VERSION AS OF 2345678910").collect - === Array(Row(3), Row(4))) - assert(sql("SELECT * FROM t SYSTEM_VERSION AS OF 'Snapshot123456789'").collect - === Array(Row(1), Row(2))) - assert(sql("SELECT * FROM t SYSTEM_VERSION AS OF 2345678910").collect - === Array(Row(3), Row(4))) } val ts1 = DateTimeUtils.stringToTimestampAnsi( UTF8String.fromString("2019-01-29 00:37:58"), DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone)) val ts2 = DateTimeUtils.stringToTimestampAnsi( - UTF8String.fromString("2021-01-29 00:37:58"), + UTF8String.fromString("2021-01-29 00:00:00"), DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone)) val t3 = s"testcat.t$ts1" val t4 = s"testcat.t$ts2" @@ -2945,21 +2983,37 @@ class DataSourceV2SQLSuite assert(sql("SELECT * FROM t TIMESTAMP AS OF '2019-01-29 00:37:58'").collect === Array(Row(5), Row(6))) - assert(sql("SELECT * FROM t TIMESTAMP AS OF '2021-01-29 00:37:58'").collect + assert(sql("SELECT * FROM t TIMESTAMP AS OF '2021-01-29 00:00:00'").collect === Array(Row(7), Row(8))) - assert(sql("SELECT * FROM t FOR TIMESTAMP AS OF '2019-01-29 00:37:58'").collect - === Array(Row(5), Row(6))) - assert(sql("SELECT * FROM t FOR TIMESTAMP AS OF '2021-01-29 00:37:58'").collect + assert(sql("SELECT * FROM t TIMESTAMP AS OF make_date(2021, 1, 29)").collect === Array(Row(7), Row(8))) - - assert(sql("SELECT * FROM t FOR SYSTEM_TIME AS OF '2019-01-29 00:37:58'").collect - === Array(Row(5), Row(6))) - assert(sql("SELECT * FROM t FOR SYSTEM_TIME AS OF '2021-01-29 00:37:58'").collect - === Array(Row(7), Row(8))) - assert(sql("SELECT * FROM t SYSTEM_TIME AS OF '2019-01-29 00:37:58'").collect - === Array(Row(5), Row(6))) - assert(sql("SELECT * FROM t SYSTEM_TIME AS OF '2021-01-29 00:37:58'").collect + assert(sql("SELECT * FROM t TIMESTAMP AS OF to_timestamp('2021-01-29 00:00:00')").collect === Array(Row(7), Row(8))) + + val e1 = intercept[AnalysisException]( + sql("SELECT * FROM t TIMESTAMP AS OF INTERVAL 1 DAY").collect() + ) + assert(e1.message.contains("is not a valid timestamp expression for time travel")) + + val e2 = intercept[AnalysisException]( + sql("SELECT * FROM t TIMESTAMP AS OF 'abc'").collect() + ) + assert(e2.message.contains("is not a valid timestamp expression for time travel")) + + val e3 = intercept[AnalysisException]( + sql("SELECT * FROM t TIMESTAMP AS OF current_user()").collect() + ) + assert(e3.message.contains("is not a valid timestamp expression for time travel")) + + val e4 = intercept[AnalysisException]( + sql("SELECT * FROM t TIMESTAMP AS OF CAST(rand() AS STRING)").collect() + ) + assert(e4.message.contains("is not a valid timestamp expression for time travel")) + + val e5 = intercept[AnalysisException]( + sql("SELECT * FROM t TIMESTAMP AS OF abs(true)").collect() + ) + assert(e5.message.contains("cannot resolve 'abs(true)' due to data type mismatch")) } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala index 3840dd3afa..9cb524c2c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala @@ -34,7 +34,6 @@ import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAM import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{LongType, StructType} @@ -276,7 +275,9 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with } } - test("mock time travel test") { + test("time travel") { + // The testing in-memory table simply append the version/timestamp to the table name when + // looking up tables. val t1 = s"$catalogName.tSnapshot123456789" val t2 = s"$catalogName.t2345678910" withTable(t1, t2) { @@ -298,10 +299,10 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with val ts1 = DateTimeUtils.stringToTimestampAnsi( UTF8String.fromString("2019-01-29 00:37:58"), - DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone)) + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) val ts2 = DateTimeUtils.stringToTimestampAnsi( UTF8String.fromString("2021-01-29 00:37:58"), - DateTimeUtils.getZoneId(SQLConf.get.sessionLocalTimeZone)) + DateTimeUtils.getZoneId(conf.sessionLocalTimeZone)) val t3 = s"$catalogName.t$ts1" val t4 = s"$catalogName.t$ts2" withTable(t3, t4) { @@ -328,7 +329,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with timestamp = Some("2019-01-29 00:37:58")) } assert(e.getMessage - .contains("Cannot specify both version and timestamp when scanning the table.")) + .contains("Cannot specify both version and timestamp when time travelling the table.")) } private def checkV2Identifiers( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala index f262cf152c..8e8eb85063 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.connector -import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, CreateTablePartitioningValidationSuite, ResolvedTable, TestRelation2, TestTable2, UnresolvedFieldName, UnresolvedFieldPosition} -import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterTableCommand, CreateTableAsSelect, DropColumns, LogicalPlan, QualifiedColType, RenameColumn, ReplaceColumns, ReplaceTableAsSelect} +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, CreateTablePartitioningValidationSuite, ResolvedTable, TestRelation2, TestTable2, UnresolvedDBObjectName, UnresolvedFieldName, UnresolvedFieldPosition} +import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterTableCommand, CreateTableAsSelect, DropColumns, 
LogicalPlan, QualifiedColType, RenameColumn, ReplaceColumns, ReplaceTableAsSelect, TableSpec} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition @@ -93,12 +93,13 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes Seq(true, false).foreach { caseSensitive => withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { Seq("ID", "iD").foreach { ref => + val tableSpec = TableSpec(None, Map.empty, None, Map.empty, + None, None, None, false) val plan = ReplaceTableAsSelect( - catalog, - Identifier.of(Array(), "table_name"), + UnresolvedDBObjectName(Array("table_name"), isNamespace = false), Expressions.identity(ref) :: Nil, TestRelation2, - Map.empty, + tableSpec, Map.empty, orCreate = true) @@ -116,12 +117,13 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes Seq(true, false).foreach { caseSensitive => withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { Seq("POINT.X", "point.X", "poInt.x", "poInt.X").foreach { ref => + val tableSpec = TableSpec(None, Map.empty, None, Map.empty, + None, None, None, false) val plan = ReplaceTableAsSelect( - catalog, - Identifier.of(Array(), "table_name"), + UnresolvedDBObjectName(Array("table_name"), isNamespace = false), Expressions.bucket(4, ref) :: Nil, TestRelation2, - Map.empty, + tableSpec, Map.empty, orCreate = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala index 81e692076b..740c10f17b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala @@ -17,17 +17,21 @@ package org.apache.spark.sql.execution +import java.util.Locale import java.util.concurrent.Executors +import java.util.concurrent.atomic.AtomicInteger import scala.collection.parallel.immutable.ParRange import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} -import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} +import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart import org.apache.spark.sql.types._ import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.Utils.REDACTION_REPLACEMENT_TEXT class SQLExecutionSuite extends SparkFunSuite { @@ -157,6 +161,45 @@ class SQLExecutionSuite extends SparkFunSuite { } } } + + test("SPARK-34735: Add modified configs for SQL execution in UI") { + val spark = SparkSession.builder() + .master("local[*]") + .appName("test") + .config("k1", "v1") + .getOrCreate() + + try { + val index = new AtomicInteger(0) + spark.sparkContext.addSparkListener(new SparkListener { + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case start: SparkListenerSQLExecutionStart => + if (index.get() == 0 && hasProject(start)) { + assert(!start.modifiedConfigs.contains("k1")) + index.incrementAndGet() + } else if (index.get() == 1 && hasProject(start)) { + assert(start.modifiedConfigs.contains("k2")) + assert(start.modifiedConfigs("k2") == "v2") + assert(start.modifiedConfigs.contains("redaction.password")) + 
assert(start.modifiedConfigs("redaction.password") == REDACTION_REPLACEMENT_TEXT) + index.incrementAndGet() + } + case _ => + } + + private def hasProject(start: SparkListenerSQLExecutionStart): Boolean = + start.physicalPlanDescription.toLowerCase(Locale.ROOT).contains("project") + }) + spark.sql("SELECT 1").collect() + spark.sql("SET k2 = v2") + spark.sql("SET redaction.password = 123") + spark.sql("SELECT 1").collect() + spark.sparkContext.listenerBus.waitUntilEmpty() + assert(index.get() == 2) + } finally { + spark.stop() + } + } } object SQLExecutionSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala index 08789e63fa..55f1713422 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkFunSuite +import org.apache.spark.scheduler.SparkListenerEvent import org.apache.spark.sql.LocalSparkSession import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart} import org.apache.spark.sql.test.TestSparkSession @@ -28,28 +29,46 @@ import org.apache.spark.util.JsonProtocol class SQLJsonProtocolSuite extends SparkFunSuite with LocalSparkSession { test("SparkPlanGraph backward compatibility: metadata") { - val SQLExecutionStartJsonString = - """ - |{ - | "Event":"org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart", - | "executionId":0, - | "description":"test desc", - | "details":"test detail", - | "physicalPlanDescription":"test plan", - | "sparkPlanInfo": { - | "nodeName":"TestNode", - | "simpleString":"test string", - | "children":[], - | "metadata":{}, - | "metrics":[] - | }, - | "time":0 - |} + Seq(true, false).foreach { newExecutionStartEvent => + val event = if (newExecutionStartEvent) { + "org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart" + } else { + "org.apache.spark.sql.execution.OldVersionSQLExecutionStart" + } + val SQLExecutionStartJsonString = + s""" + |{ + | "Event":"$event", + | "executionId":0, + | "description":"test desc", + | "details":"test detail", + | "physicalPlanDescription":"test plan", + | "sparkPlanInfo": { + | "nodeName":"TestNode", + | "simpleString":"test string", + | "children":[], + | "metadata":{}, + | "metrics":[] + | }, + | "time":0, + | "modifiedConfigs": { + | "k1":"v1" + | } + |} """.stripMargin - val reconstructedEvent = JsonProtocol.sparkEventFromJson(parse(SQLExecutionStartJsonString)) - val expectedEvent = SparkListenerSQLExecutionStart(0, "test desc", "test detail", "test plan", - new SparkPlanInfo("TestNode", "test string", Nil, Map(), Nil), 0) - assert(reconstructedEvent == expectedEvent) + + val reconstructedEvent = JsonProtocol.sparkEventFromJson(parse(SQLExecutionStartJsonString)) + if (newExecutionStartEvent) { + val expectedEvent = SparkListenerSQLExecutionStart(0, "test desc", "test detail", + "test plan", new SparkPlanInfo("TestNode", "test string", Nil, Map(), Nil), 0, + Map("k1" -> "v1")) + assert(reconstructedEvent == expectedEvent) + } else { + val expectedOldEvent = OldVersionSQLExecutionStart(0, "test desc", "test detail", + "test plan", new SparkPlanInfo("TestNode", "test string", Nil, Map(), Nil), 0) + assert(reconstructedEvent == expectedOldEvent) + } + } } 
test("SparkListenerSQLExecutionEnd backward compatibility") { @@ -77,3 +96,12 @@ class SQLJsonProtocolSuite extends SparkFunSuite with LocalSparkSession { assert(readBack == event) } } + +private case class OldVersionSQLExecutionStart( + executionId: Long, + description: String, + details: String, + physicalPlanDescription: String, + sparkPlanInfo: SparkPlanInfo, + time: Long) + extends SparkListenerEvent diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index cc465227d6..1861d9cf04 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.internal.SQLConf._ import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} @@ -220,12 +219,6 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } } - private def assertNoSuchTable(query: String): Unit = { - intercept[NoSuchTableException] { - sql(query) - } - } - private def assertAnalysisError(query: String, message: String): Unit = { val e = intercept[AnalysisException](sql(query)) assert(e.message.contains(message)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala index 091442da4d..cf87a5c7f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala @@ -378,6 +378,21 @@ abstract class SQLViewTestSuite extends QueryTest with SQLTestUtils { } } } + + test("SPARK-37219: time travel is unsupported") { + val viewName = createView("testView", "SELECT 1 col") + withView(viewName) { + val e1 = intercept[AnalysisException]( + sql(s"SELECT * FROM $viewName VERSION AS OF 1").collect() + ) + assert(e1.message.contains(s"$viewName is a view which does not support time travel")) + + val e2 = intercept[AnalysisException]( + sql(s"SELECT * FROM $viewName TIMESTAMP AS OF '2000-10-10'").collect() + ) + assert(e2.message.contains(s"$viewName is a view which does not support time travel")) + } + } } abstract class TempViewTestSuite extends SQLViewTestSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala index 9f70c8aeca..99856650fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ShufflePartitionsUtilSuite.scala @@ -703,27 +703,55 @@ class ShufflePartitionsUtilSuite extends SparkFunSuite with LocalSparkContext { test("splitSizeListByTargetSize") { val targetSize = 100 + val smallPartitionFactor1 = ShufflePartitionsUtil.SMALL_PARTITION_FACTOR // merge the small partitions at the beginning/end val sizeList1 = Seq[Long](15, 90, 15, 15, 15, 90, 15) - assert(ShufflePartitionsUtil.splitSizeListByTargetSize(sizeList1, targetSize).toSeq == + assert(ShufflePartitionsUtil.splitSizeListByTargetSize( + sizeList1, 
targetSize, smallPartitionFactor1).toSeq == Seq(0, 2, 5)) // merge the small partitions in the middle val sizeList2 = Seq[Long](30, 15, 90, 10, 90, 15, 30) - assert(ShufflePartitionsUtil.splitSizeListByTargetSize(sizeList2, targetSize).toSeq == + assert(ShufflePartitionsUtil.splitSizeListByTargetSize( + sizeList2, targetSize, smallPartitionFactor1).toSeq == Seq(0, 2, 4, 5)) // merge small partitions if the partition itself is smaller than // targetSize * SMALL_PARTITION_FACTOR val sizeList3 = Seq[Long](15, 1000, 15, 1000) - assert(ShufflePartitionsUtil.splitSizeListByTargetSize(sizeList3, targetSize).toSeq == + assert(ShufflePartitionsUtil.splitSizeListByTargetSize( + sizeList3, targetSize, smallPartitionFactor1).toSeq == Seq(0, 3)) // merge small partitions if the combined size is smaller than // targetSize * MERGED_PARTITION_FACTOR val sizeList4 = Seq[Long](35, 75, 90, 20, 35, 25, 35) - assert(ShufflePartitionsUtil.splitSizeListByTargetSize(sizeList4, targetSize).toSeq == + assert(ShufflePartitionsUtil.splitSizeListByTargetSize( + sizeList4, targetSize, smallPartitionFactor1).toSeq == Seq(0, 2, 3)) + + val smallPartitionFactor2 = 0.5 + // merge last two partition if their size is not bigger than smallPartitionFactor * target + val sizeList5 = Seq[Long](50, 50, 40, 5) + assert(ShufflePartitionsUtil.splitSizeListByTargetSize( + sizeList5, targetSize, smallPartitionFactor2).toSeq == + Seq(0)) + + val sizeList6 = Seq[Long](40, 5, 50, 45) + assert(ShufflePartitionsUtil.splitSizeListByTargetSize( + sizeList6, targetSize, smallPartitionFactor2).toSeq == + Seq(0)) + + // do not merge + val sizeList7 = Seq[Long](50, 50, 10, 40, 5) + assert(ShufflePartitionsUtil.splitSizeListByTargetSize( + sizeList7, targetSize, smallPartitionFactor2).toSeq == + Seq(0, 2)) + + val sizeList8 = Seq[Long](10, 40, 5, 50, 50) + assert(ShufflePartitionsUtil.splitSizeListByTargetSize( + sizeList8, targetSize, smallPartitionFactor2).toSeq == + Seq(0, 3)) } test("SPARK-35923: Coalesce empty partition with mixed CoalescedPartitionSpec and" + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index e7d630d1ab..ba6dd170d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.internal.config.ConfigEntry import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedGenerator, UnresolvedHaving, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Concat, GreaterThan, Literal, NullsFirst, SortOrder, UnresolvedWindowExpression, UnspecifiedFrame, WindowSpecDefinition, WindowSpecReference} +import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.connector.catalog.TableCatalog import org.apache.spark.sql.execution.command._ @@ -73,6 +74,14 @@ class SparkSqlParserSuite extends AnalysisTest { } } + test("SET with comment") { + assertEqual(s"SET my_path = /a/b/*", SetCommand(Some("my_path" -> Some("/a/b/*")))) + val e1 = intercept[ParseException](parser.parsePlan("SET k=`v` /*")) + assert(e1.getMessage.contains(s"Unclosed bracketed comment")) + val e2 = 
intercept[ParseException](parser.parsePlan("SET `k`=`v` /*")) + assert(e2.getMessage.contains(s"Unclosed bracketed comment")) + } + test("Report Error for invalid usage of SET command") { assertEqual("SET", SetCommand(None)) assertEqual("SET -v", SetCommand(Some("-v", None))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 096b80b359..02f3863c9a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -2223,6 +2223,37 @@ class AdaptiveQueryExecSuite } } } + + test("SPARK-37357: Add small partition factor for rebalance partitions") { + withTempView("v") { + withSQLConf( + SQLConf.ADAPTIVE_OPTIMIZE_SKEWS_IN_REBALANCE_PARTITIONS_ENABLED.key -> "true", + SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + spark.sparkContext.parallelize( + (1 to 8).map(i => TestData(if (i > 2) 2 else i, i.toString)), 3) + .toDF("c1", "c2").createOrReplaceTempView("v") + + def checkAQEShuffleReadExists(query: String, exists: Boolean): Unit = { + val (_, adaptive) = runAdaptiveAndVerifyResult(query) + assert( + collect(adaptive) { + case read: AQEShuffleReadExec => read + }.nonEmpty == exists) + } + + withSQLConf(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "200") { + withSQLConf(SQLConf.ADAPTIVE_REBALANCE_PARTITIONS_SMALL_PARTITION_FACTOR.key -> "0.5") { + // block size: [88, 97, 97] + checkAQEShuffleReadExists("SELECT /*+ REBALANCE(c1) */ * FROM v", false) + } + withSQLConf(SQLConf.ADAPTIVE_REBALANCE_PARTITIONS_SMALL_PARTITION_FACTOR.key -> "0.2") { + // block size: [88, 97, 97] + checkAQEShuffleReadExists("SELECT /*+ REBALANCE(c1) */ * FROM v", true) + } + } + } + } + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala index 0fc43c7052..0e9e9a7060 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala @@ -119,31 +119,36 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { prepareTable(dir, spark.sql(s"SELECT CAST(value as ${dataType.sql}) id FROM t1")) + val query = dataType match { + case BooleanType => "sum(cast(id as bigint))" + case _ => "sum(id)" + } + sqlBenchmark.addCase("SQL CSV") { _ => - spark.sql("select sum(id) from csvTable").noop() + spark.sql(s"select $query from csvTable").noop() } sqlBenchmark.addCase("SQL Json") { _ => - spark.sql("select sum(id) from jsonTable").noop() + spark.sql(s"select $query from jsonTable").noop() } sqlBenchmark.addCase("SQL Parquet Vectorized") { _ => - spark.sql("select sum(id) from parquetTable").noop() + spark.sql(s"select $query from parquetTable").noop() } sqlBenchmark.addCase("SQL Parquet MR") { _ => withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - spark.sql("select sum(id) from parquetTable").noop() + spark.sql(s"select $query from parquetTable").noop() } } sqlBenchmark.addCase("SQL ORC Vectorized") { _ => - spark.sql("SELECT sum(id) FROM orcTable").noop() + spark.sql(s"SELECT $query FROM orcTable").noop() } sqlBenchmark.addCase("SQL ORC MR") { _ => withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { - 
spark.sql("SELECT sum(id) FROM orcTable").noop() + spark.sql(s"SELECT $query FROM orcTable").noop() } } @@ -157,6 +162,7 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { var longSum = 0L var doubleSum = 0.0 val aggregateValue: (ColumnVector, Int) => Unit = dataType match { + case BooleanType => (col: ColumnVector, i: Int) => if (col.getBoolean(i)) longSum += 1L case ByteType => (col: ColumnVector, i: Int) => longSum += col.getByte(i) case ShortType => (col: ColumnVector, i: Int) => longSum += col.getShort(i) case IntegerType => (col: ColumnVector, i: Int) => longSum += col.getInt(i) @@ -191,6 +197,7 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { var longSum = 0L var doubleSum = 0.0 val aggregateValue: (InternalRow) => Unit = dataType match { + case BooleanType => (col: InternalRow) => if (col.getBoolean(0)) longSum += 1L case ByteType => (col: InternalRow) => longSum += col.getByte(0) case ShortType => (col: InternalRow) => longSum += col.getShort(0) case IntegerType => (col: InternalRow) => longSum += col.getInt(0) @@ -542,7 +549,7 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { runBenchmark("SQL Single Numeric Column Scan") { - Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { + Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { dataType => numericScanBenchmark(1024 * 1024 * 15, dataType) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetLocationParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetLocationParserSuite.scala new file mode 100644 index 0000000000..bc1ffb93fe --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetLocationParserSuite.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedNamespace} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan +import org.apache.spark.sql.catalyst.plans.logical.SetNamespaceLocation + +class AlterNamespaceSetLocationParserSuite extends AnalysisTest { + test("set namespace location") { + comparePlans( + parsePlan("ALTER DATABASE a.b.c SET LOCATION '/home/user/db'"), + SetNamespaceLocation( + UnresolvedNamespace(Seq("a", "b", "c")), "/home/user/db")) + + comparePlans( + parsePlan("ALTER SCHEMA a.b.c SET LOCATION '/home/user/db'"), + SetNamespaceLocation( + UnresolvedNamespace(Seq("a", "b", "c")), "/home/user/db")) + + comparePlans( + parsePlan("ALTER NAMESPACE a.b.c SET LOCATION '/home/user/db'"), + SetNamespaceLocation( + UnresolvedNamespace(Seq("a", "b", "c")), "/home/user/db")) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetLocationSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetLocationSuiteBase.scala new file mode 100644 index 0000000000..25bae01821 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterNamespaceSetLocationSuiteBase.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.connector.catalog.SupportsNamespaces + +/** + * This base suite contains unified tests for the `ALTER NAMESPACE ... SET LOCATION` command that + * check V1 and V2 table catalogs. The tests that cannot run for all supported catalogs are located + * in more specific test suites: + * + * - V2 table catalog tests: + * `org.apache.spark.sql.execution.command.v2.AlterNamespaceSetLocationSuite` + * - V1 table catalog tests: + * `org.apache.spark.sql.execution.command.v1.AlterNamespaceSetLocationSuiteBase` + * - V1 In-Memory catalog: + * `org.apache.spark.sql.execution.command.v1.AlterNamespaceSetLocationSuite` + * - V1 Hive External catalog: + * `org.apache.spark.sql.hive.execution.command.AlterNamespaceSetLocationSuite` + */ +trait AlterNamespaceSetLocationSuiteBase extends QueryTest with DDLCommandTestUtils { + override val command = "ALTER NAMESPACE ... 
SET LOCATION" + + protected def namespace: String + + protected def notFoundMsgPrefix: String + + test("Empty location string") { + val ns = s"$catalog.$namespace" + withNamespace(ns) { + sql(s"CREATE NAMESPACE $ns") + val message = intercept[IllegalArgumentException] { + sql(s"ALTER NAMESPACE $ns SET LOCATION ''") + }.getMessage + assert(message.contains("Can not create a Path from an empty string")) + } + } + + test("Namespace does not exist") { + val ns = "not_exist" + val message = intercept[AnalysisException] { + sql(s"ALTER DATABASE $catalog.$ns SET LOCATION 'loc'") + }.getMessage + assert(message.contains(s"$notFoundMsgPrefix '$ns' not found")) + } + + // Hive catalog does not support "ALTER NAMESPACE ... SET LOCATION", thus + // this is called from non-Hive v1 and v2 tests. + protected def runBasicTest(): Unit = { + val ns = s"$catalog.$namespace" + withNamespace(ns) { + sql(s"CREATE NAMESPACE IF NOT EXISTS $ns COMMENT " + + "'test namespace' LOCATION '/tmp/loc_test_1'") + sql(s"ALTER NAMESPACE $ns SET LOCATION '/tmp/loc_test_2'") + assert(getLocation(ns).contains("file:/tmp/loc_test_2")) + } + } + + protected def getLocation(namespace: String): String = { + val locationRow = sql(s"DESCRIBE NAMESPACE EXTENDED $namespace") + .toDF("key", "value") + .where(s"key like '${SupportsNamespaces.PROP_LOCATION.capitalize}%'") + .collect() + assert(locationRow.length == 1) + locationRow(0).getString(1) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala index 2aef62988f..0713e9be3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CharVarcharDDLTestBase.scala @@ -150,6 +150,16 @@ trait CharVarcharDDLTestBase extends QueryTest with SQLTestUtils { } } } + + test("SPARK-33892: DESCRIBE COLUMN w/ char/varchar") { + withTable("t") { + sql(s"CREATE TABLE t(v VARCHAR(3), c CHAR(5)) USING $format") + checkAnswer(sql("desc t v").selectExpr("info_value").where("info_value like '%char%'"), + Row("varchar(3)")) + checkAnswer(sql("desc t c").selectExpr("info_value").where("info_value like '%char%'"), + Row("char(5)")) + } + } } class FileSourceCharVarcharDDLTestSuite extends CharVarcharDDLTestBase with SharedSparkSession { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index b86250f093..1cea49884b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.{SparkException, SparkFiles} import org.apache.spark.internal.config import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.{FunctionIdentifier, QualifiedTableName, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, NoSuchDatabaseException, NoSuchFunctionException, TableFunctionRegistry, TempTableAlreadyExistsException} +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, NoSuchFunctionException, TableFunctionRegistry, TempTableAlreadyExistsException} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import 
org.apache.spark.sql.connector.catalog.SupportsNamespaces.PROP_OWNER @@ -375,6 +375,11 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { val path = new Path(tmpDir.getCanonicalPath).toUri databaseNames.foreach { dbName => try { + val e = intercept[IllegalArgumentException] { + sql(s"CREATE DATABASE $dbName Location ''") + } + assert(e.getMessage.contains("Can not create a Path from an empty string")) + val dbNameWithoutBackTicks = cleanIdentifier(dbName) sql(s"CREATE DATABASE $dbName Location '$path'") val db1 = catalog.getDatabaseMetadata(dbNameWithoutBackTicks) @@ -776,29 +781,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { Row("Comment", "") :: Row("Location", CatalogUtils.URIToString(location)) :: Row("Properties", "((a,a), (b,b), (c,c), (d,d))") :: Nil) - - withTempDir { tmpDir => - if (isUsingHiveMetastore) { - val e1 = intercept[AnalysisException] { - sql(s"ALTER DATABASE $dbName SET LOCATION '${tmpDir.toURI}'") - } - assert(e1.getMessage.contains("does not support altering database location")) - } else { - sql(s"ALTER DATABASE $dbName SET LOCATION '${tmpDir.toURI}'") - val uriInCatalog = catalog.getDatabaseMetadata(dbNameWithoutBackTicks).locationUri - assert("file" === uriInCatalog.getScheme) - assert(new Path(tmpDir.getPath).toUri.getPath === uriInCatalog.getPath) - } - - intercept[NoSuchDatabaseException] { - sql(s"ALTER DATABASE `db-not-exist` SET LOCATION '${tmpDir.toURI}'") - } - - val e3 = intercept[IllegalArgumentException] { - sql(s"ALTER DATABASE $dbName SET LOCATION ''") - } - assert(e3.getMessage.contains("Can not create a Path from an empty string")) - } } finally { catalog.reset() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 85ba14fc7a..a6b979a3fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -26,17 +26,17 @@ import org.mockito.invocation.InvocationOnMock import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AnalysisTest, Analyzer, EmptyFunctionRegistry, NoSuchTableException, ResolvedFieldName, ResolvedTable, ResolveSessionCatalog, UnresolvedAttribute, UnresolvedRelation, UnresolvedSubqueryColumnAliases, UnresolvedTable} +import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AnalysisTest, Analyzer, EmptyFunctionRegistry, NoSuchTableException, ResolvedDBObjectName, ResolvedFieldName, ResolvedTable, ResolveSessionCatalog, UnresolvedAttribute, UnresolvedRelation, UnresolvedSubqueryColumnAliases, UnresolvedTable} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.{AnsiCast, AttributeReference, EqualTo, Expression, InSubquery, IntegerLiteral, ListQuery, Literal, StringLiteral} import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} -import org.apache.spark.sql.catalyst.plans.logical.{AlterColumn, AnalysisOnlyCommand, AppendData, Assignment, CreateTableAsSelect, CreateTableStatement, CreateV2Table, DeleteAction, DeleteFromTable, DescribeRelation, 
DropTable, InsertAction, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, Project, SetTableLocation, SetTableProperties, ShowTableProperties, SubqueryAlias, UnsetTableProperties, UpdateAction, UpdateTable} +import org.apache.spark.sql.catalyst.plans.logical.{AlterColumn, AnalysisOnlyCommand, AppendData, Assignment, CreateTable, CreateTableAsSelect, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, Project, SetTableLocation, SetTableProperties, ShowTableProperties, SubqueryAlias, UnsetTableProperties, UpdateAction, UpdateTable} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.FakeV2Provider import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Identifier, Table, TableCapability, TableCatalog, V1Table} import org.apache.spark.sql.connector.expressions.Transform -import org.apache.spark.sql.execution.datasources.CreateTable +import org.apache.spark.sql.execution.datasources.{CreateTable => CreateTableV1} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.sources.SimpleScanSource @@ -210,7 +210,7 @@ class PlanResolutionSuite extends AnalysisTest { private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { parseAndResolve(sql).collect { - case CreateTable(tableDesc, mode, _) => (tableDesc, mode == SaveMode.Ignore) + case CreateTableV1(tableDesc, mode, _) => (tableDesc, mode == SaveMode.Ignore) }.head } @@ -240,7 +240,7 @@ class PlanResolutionSuite extends AnalysisTest { ) parseAndResolve(query) match { - case CreateTable(tableDesc, _, None) => + case CreateTableV1(tableDesc, _, None) => assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) case other => fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + @@ -282,7 +282,7 @@ class PlanResolutionSuite extends AnalysisTest { ) parseAndResolve(query) match { - case CreateTable(tableDesc, _, None) => + case CreateTableV1(tableDesc, _, None) => assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) case other => fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + @@ -302,7 +302,7 @@ class PlanResolutionSuite extends AnalysisTest { comment = Some("abc")) parseAndResolve(sql) match { - case CreateTable(tableDesc, _, None) => + case CreateTableV1(tableDesc, _, None) => assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) case other => fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + @@ -322,7 +322,7 @@ class PlanResolutionSuite extends AnalysisTest { properties = Map("test" -> "test")) parseAndResolve(sql) match { - case CreateTable(tableDesc, _, None) => + case CreateTableV1(tableDesc, _, None) => assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) case other => fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + @@ -341,7 +341,7 @@ class PlanResolutionSuite extends AnalysisTest { provider = Some("parquet")) parseAndResolve(v1) match { - case CreateTable(tableDesc, _, None) => + case CreateTableV1(tableDesc, _, None) => assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) case other => fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + @@ -372,7 +372,7 @@ class 
PlanResolutionSuite extends AnalysisTest { provider = Some("parquet")) parseAndResolve(sql) match { - case CreateTable(tableDesc, _, None) => + case CreateTableV1(tableDesc, _, None) => assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) case other => fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + @@ -398,7 +398,7 @@ class PlanResolutionSuite extends AnalysisTest { ) parseAndResolve(sql) match { - case CreateTable(tableDesc, _, None) => + case CreateTableV1(tableDesc, _, None) => assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) case other => fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + @@ -471,29 +471,20 @@ class PlanResolutionSuite extends AnalysisTest { |OPTIONS (path 's3://bucket/path/to/data', other 20) """.stripMargin - val expectedProperties = Map( - "p1" -> "v1", - "p2" -> "v2", - "option.other" -> "20", - "provider" -> "parquet", - "location" -> "s3://bucket/path/to/data", - "comment" -> "table comment", - "other" -> "20") - parseAndResolve(sql) match { - case create: CreateV2Table => - assert(create.catalog.name == "testcat") - assert(create.tableName == Identifier.of(Array("mydb"), "table_name")) + case create: CreateTable => + assert(create.name.asInstanceOf[ResolvedDBObjectName].catalog.name == "testcat") + assert(create.name.asInstanceOf[ResolvedDBObjectName].nameParts.mkString(".") == + "mydb.table_name") assert(create.tableSchema == new StructType() .add("id", LongType) .add("description", StringType) .add("point", new StructType().add("x", DoubleType).add("y", DoubleType))) assert(create.partitioning.isEmpty) - assert(create.properties == expectedProperties) assert(create.ignoreIfExists) case other => - fail(s"Expected to parse ${classOf[CreateV2Table].getName} from query," + + fail(s"Expected to parse ${classOf[CreateTable].getName} from query," + s"got ${other.getClass.getName}: $sql") } } @@ -511,29 +502,20 @@ class PlanResolutionSuite extends AnalysisTest { |OPTIONS (path 's3://bucket/path/to/data', other 20) """.stripMargin - val expectedProperties = Map( - "p1" -> "v1", - "p2" -> "v2", - "option.other" -> "20", - "provider" -> "parquet", - "location" -> "s3://bucket/path/to/data", - "comment" -> "table comment", - "other" -> "20") - parseAndResolve(sql, withDefault = true) match { - case create: CreateV2Table => - assert(create.catalog.name == "testcat") - assert(create.tableName == Identifier.of(Array("mydb"), "table_name")) + case create: CreateTable => + assert(create.name.asInstanceOf[ResolvedDBObjectName].catalog.name == "testcat") + assert(create.name.asInstanceOf[ResolvedDBObjectName].nameParts.mkString(".") == + "mydb.table_name") assert(create.tableSchema == new StructType() .add("id", LongType) .add("description", StringType) .add("point", new StructType().add("x", DoubleType).add("y", DoubleType))) assert(create.partitioning.isEmpty) - assert(create.properties == expectedProperties) assert(create.ignoreIfExists) case other => - fail(s"Expected to parse ${classOf[CreateV2Table].getName} from query," + + fail(s"Expected to parse ${classOf[CreateTable].getName} from query," + s"got ${other.getClass.getName}: $sql") } } @@ -551,27 +533,21 @@ class PlanResolutionSuite extends AnalysisTest { |TBLPROPERTIES ('p1'='v1', 'p2'='v2') """.stripMargin - val expectedProperties = Map( - "p1" -> "v1", - "p2" -> "v2", - "provider" -> v2Format, - "location" -> "/user/external/page_view", - "comment" -> "This is the staging page view 
table") - parseAndResolve(sql) match { - case create: CreateV2Table => - assert(create.catalog.name == CatalogManager.SESSION_CATALOG_NAME) - assert(create.tableName == Identifier.of(Array("mydb"), "page_view")) + case create: CreateTable => + assert(create.name.asInstanceOf[ResolvedDBObjectName].catalog.name == + CatalogManager.SESSION_CATALOG_NAME) + assert(create.name.asInstanceOf[ResolvedDBObjectName].nameParts.mkString(".") == + "mydb.page_view") assert(create.tableSchema == new StructType() .add("id", LongType) .add("description", StringType) .add("point", new StructType().add("x", DoubleType).add("y", DoubleType))) assert(create.partitioning.isEmpty) - assert(create.properties == expectedProperties) assert(create.ignoreIfExists) case other => - fail(s"Expected to parse ${classOf[CreateV2Table].getName} from query," + + fail(s"Expected to parse ${classOf[CreateTable].getName} from query," + s"got ${other.getClass.getName}: $sql") } } @@ -1684,9 +1660,9 @@ class PlanResolutionSuite extends AnalysisTest { */ def normalizePlan(plan: LogicalPlan): LogicalPlan = { plan match { - case CreateTable(tableDesc, mode, query) => + case CreateTableV1(tableDesc, mode, query) => val newTableDesc = tableDesc.copy(createTime = -1L) - CreateTable(newTableDesc, mode, query) + CreateTableV1(newTableDesc, mode, query) case _ => plan // Don't transform } } @@ -1707,8 +1683,8 @@ class PlanResolutionSuite extends AnalysisTest { partitionColumnNames: Seq[String] = Seq.empty, comment: Option[String] = None, mode: SaveMode = SaveMode.ErrorIfExists, - query: Option[LogicalPlan] = None): CreateTable = { - CreateTable( + query: Option[LogicalPlan] = None): CreateTableV1 = { + CreateTableV1( CatalogTable( identifier = TableIdentifier(table, database), tableType = tableType, @@ -1790,7 +1766,7 @@ class PlanResolutionSuite extends AnalysisTest { allSources.foreach { s => val query = s"CREATE TABLE my_tab STORED AS $s" parseAndResolve(query) match { - case ct: CreateTable => + case ct: CreateTableV1 => val hiveSerde = HiveSerDe.sourceToSerDe(s) assert(hiveSerde.isDefined) assert(ct.tableDesc.storage.serde == @@ -1809,14 +1785,14 @@ class PlanResolutionSuite extends AnalysisTest { // No conflicting serdes here, OK parseAndResolve(query1) match { - case parsed1: CreateTable => + case parsed1: CreateTableV1 => assert(parsed1.tableDesc.storage.serde == Some("anything")) assert(parsed1.tableDesc.storage.inputFormat == Some("inputfmt")) assert(parsed1.tableDesc.storage.outputFormat == Some("outputfmt")) } parseAndResolve(query2) match { - case parsed2: CreateTable => + case parsed2: CreateTableV1 => assert(parsed2.tableDesc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) assert(parsed2.tableDesc.storage.inputFormat == Some("inputfmt")) @@ -1832,7 +1808,7 @@ class PlanResolutionSuite extends AnalysisTest { val query = s"CREATE TABLE my_tab ROW FORMAT SERDE 'anything' STORED AS $s" if (supportedSources.contains(s)) { parseAndResolve(query) match { - case ct: CreateTable => + case ct: CreateTableV1 => val hiveSerde = HiveSerDe.sourceToSerDe(s) assert(hiveSerde.isDefined) assert(ct.tableDesc.storage.serde == Some("anything")) @@ -1853,7 +1829,7 @@ class PlanResolutionSuite extends AnalysisTest { val query = s"CREATE TABLE my_tab ROW FORMAT DELIMITED FIELDS TERMINATED BY ' ' STORED AS $s" if (supportedSources.contains(s)) { parseAndResolve(query) match { - case ct: CreateTable => + case ct: CreateTableV1 => val hiveSerde = HiveSerDe.sourceToSerDe(s) assert(hiveSerde.isDefined) 
assert(ct.tableDesc.storage.serde == hiveSerde.get.serde @@ -1870,14 +1846,14 @@ class PlanResolutionSuite extends AnalysisTest { test("create hive external table") { val withoutLoc = "CREATE EXTERNAL TABLE my_tab STORED AS parquet" parseAndResolve(withoutLoc) match { - case ct: CreateTable => + case ct: CreateTableV1 => assert(ct.tableDesc.tableType == CatalogTableType.EXTERNAL) assert(ct.tableDesc.storage.locationUri.isEmpty) } val withLoc = "CREATE EXTERNAL TABLE my_tab STORED AS parquet LOCATION '/something/anything'" parseAndResolve(withLoc) match { - case ct: CreateTable => + case ct: CreateTableV1 => assert(ct.tableDesc.tableType == CatalogTableType.EXTERNAL) assert(ct.tableDesc.storage.locationUri == Some(new URI("/something/anything"))) } @@ -1897,7 +1873,7 @@ class PlanResolutionSuite extends AnalysisTest { test("create hive table - location implies external") { val query = "CREATE TABLE my_tab STORED AS parquet LOCATION '/something/anything'" parseAndResolve(query) match { - case ct: CreateTable => + case ct: CreateTableV1 => assert(ct.tableDesc.tableType == CatalogTableType.EXTERNAL) assert(ct.tableDesc.storage.locationUri == Some(new URI("/something/anything"))) } @@ -2261,14 +2237,6 @@ class PlanResolutionSuite extends AnalysisTest { assert(e2.getMessage.contains("Operation not allowed")) } - test("create table - properties") { - val query = "CREATE TABLE my_table (id int, name string) TBLPROPERTIES ('k1'='v1', 'k2'='v2')" - parsePlan(query) match { - case state: CreateTableStatement => - assert(state.properties == Map("k1" -> "v1", "k2" -> "v2")) - } - } - test("create table(hive) - everything!") { val query = """ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowNamespacesParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowNamespacesParserSuite.scala index c9e5d33fea..7c810671c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowNamespacesParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowNamespacesParserSuite.scala @@ -19,52 +19,48 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedNamespace} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan -import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.ShowNamespaces import org.apache.spark.sql.test.SharedSparkSession class ShowNamespacesParserSuite extends AnalysisTest with SharedSparkSession { - test("all namespaces") { - Seq("SHOW NAMESPACES", "SHOW DATABASES").foreach { sqlCmd => + private val keywords = Seq("NAMESPACES", "DATABASES", "SCHEMAS") + + test("show namespaces in the current catalog") { + keywords.foreach { keyword => comparePlans( - parsePlan(sqlCmd), + parsePlan(s"SHOW $keyword"), ShowNamespaces(UnresolvedNamespace(Seq.empty[String]), None)) } } - test("basic pattern") { - Seq( - "SHOW DATABASES LIKE 'defau*'", - "SHOW NAMESPACES LIKE 'defau*'").foreach { sqlCmd => + test("show namespaces with a pattern") { + keywords.foreach { keyword => comparePlans( - parsePlan(sqlCmd), + parsePlan(s"SHOW $keyword LIKE 'defau*'"), + ShowNamespaces(UnresolvedNamespace(Seq.empty[String]), Some("defau*"))) + // LIKE can be omitted. 
+ comparePlans( + parsePlan(s"SHOW $keyword 'defau*'"), ShowNamespaces(UnresolvedNamespace(Seq.empty[String]), Some("defau*"))) - } - } - - test("FROM/IN operator is not allowed by SHOW DATABASES") { - Seq( - "SHOW DATABASES FROM testcat.ns1.ns2", - "SHOW DATABASES IN testcat.ns1.ns2").foreach { sqlCmd => - val errMsg = intercept[ParseException] { - parsePlan(sqlCmd) - }.getMessage - assert(errMsg.contains("FROM/IN operator is not allowed in SHOW DATABASES")) } } test("show namespaces in/from a namespace") { - comparePlans( - parsePlan("SHOW NAMESPACES FROM testcat.ns1.ns2"), - ShowNamespaces(UnresolvedNamespace(Seq("testcat", "ns1", "ns2")), None)) - comparePlans( - parsePlan("SHOW NAMESPACES IN testcat.ns1.ns2"), - ShowNamespaces(UnresolvedNamespace(Seq("testcat", "ns1", "ns2")), None)) + keywords.foreach { keyword => + comparePlans( + parsePlan(s"SHOW $keyword FROM testcat.ns1.ns2"), + ShowNamespaces(UnresolvedNamespace(Seq("testcat", "ns1", "ns2")), None)) + comparePlans( + parsePlan(s"SHOW $keyword IN testcat.ns1.ns2"), + ShowNamespaces(UnresolvedNamespace(Seq("testcat", "ns1", "ns2")), None)) + } } test("namespaces by a pattern from another namespace") { - comparePlans( - parsePlan("SHOW NAMESPACES IN testcat.ns1 LIKE '*pattern*'"), - ShowNamespaces(UnresolvedNamespace(Seq("testcat", "ns1")), Some("*pattern*"))) + keywords.foreach { keyword => + comparePlans( + parsePlan(s"SHOW $keyword IN testcat.ns1 LIKE '*pattern*'"), + ShowNamespaces(UnresolvedNamespace(Seq("testcat", "ns1")), Some("*pattern*"))) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowNamespacesSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowNamespacesSuiteBase.scala index 1b37444b14..b3693845c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowNamespacesSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowNamespacesSuiteBase.scala @@ -41,6 +41,7 @@ trait ShowNamespacesSuiteBase extends QueryTest with DDLCommandTestUtils { } protected def builtinTopNamespaces: Seq[String] = Seq.empty + protected def isCasePreserving: Boolean = true test("default namespace") { withSQLConf(SQLConf.DEFAULT_CATALOG.key -> catalog) { @@ -51,7 +52,7 @@ trait ShowNamespacesSuiteBase extends QueryTest with DDLCommandTestUtils { test("at the top level") { withNamespace(s"$catalog.ns1", s"$catalog.ns2") { - sql(s"CREATE DATABASE $catalog.ns1") + sql(s"CREATE NAMESPACE $catalog.ns1") sql(s"CREATE NAMESPACE $catalog.ns2") runShowNamespacesSql( @@ -64,24 +65,12 @@ trait ShowNamespacesSuiteBase extends QueryTest with DDLCommandTestUtils { withNamespace(s"$catalog.ns1", s"$catalog.ns2") { sql(s"CREATE NAMESPACE $catalog.ns1") sql(s"CREATE NAMESPACE $catalog.ns2") - Seq( - s"SHOW NAMESPACES IN $catalog LIKE 'ns2'", - s"SHOW NAMESPACES IN $catalog 'ns2'", - s"SHOW NAMESPACES FROM $catalog LIKE 'ns2'", - s"SHOW NAMESPACES FROM $catalog 'ns2'").foreach { sqlCmd => - withClue(sqlCmd) { - runShowNamespacesSql(sqlCmd, Seq("ns2")) - } - } + runShowNamespacesSql(s"SHOW NAMESPACES IN $catalog LIKE 'ns2'", Seq("ns2")) } } test("does not match to any namespace") { - Seq( - "SHOW DATABASES LIKE 'non-existentdb'", - "SHOW NAMESPACES 'non-existentdb'").foreach { sqlCmd => - runShowNamespacesSql(sqlCmd, Seq.empty) - } + runShowNamespacesSql("SHOW NAMESPACES LIKE 'non-existentdb'", Seq.empty) } test("show root namespaces with the default catalog") { @@ -134,4 +123,23 @@ trait ShowNamespacesSuiteBase extends QueryTest with 
DDLCommandTestUtils { assert(sql("SHOW NAMESPACES").schema.fieldNames.toSeq == Seq("databaseName")) } } + + test("case sensitivity of the pattern string") { + Seq(true, false).foreach { caseSensitive => + withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { + withNamespace(s"$catalog.AAA", s"$catalog.bbb") { + sql(s"CREATE NAMESPACE $catalog.AAA") + sql(s"CREATE NAMESPACE $catalog.bbb") + // TODO: The v1 in-memory catalog should be case preserving as well. + val casePreserving = isCasePreserving && (catalogVersion == "V2" || caseSensitive) + val expected = if (casePreserving) "AAA" else "aaa" + runShowNamespacesSql( + s"SHOW NAMESPACES IN $catalog", + Seq(expected, "bbb") ++ builtinTopNamespaces) + runShowNamespacesSql(s"SHOW NAMESPACES IN $catalog LIKE 'AAA'", Seq(expected)) + runShowNamespacesSql(s"SHOW NAMESPACES IN $catalog LIKE 'aaa'", Seq(expected)) + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterNamespaceSetLocationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterNamespaceSetLocationSuite.scala new file mode 100644 index 0000000000..5e0b79570b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/AlterNamespaceSetLocationSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command.v1 + +import org.apache.spark.sql.execution.command + +/** + * This base suite contains unified tests for the `ALTER NAMESPACE ... SET LOCATION` command that + * checks V1 table catalogs. The tests that cannot run for all V1 catalogs are located in more + * specific test suites: + * + * - V1 In-Memory catalog: + * `org.apache.spark.sql.execution.command.v1.AlterNamespaceSetLocationSuite` + * - V1 Hive External catalog: + * `org.apache.spark.sql.hive.execution.command.AlterNamespaceSetLocationSuite` + */ +trait AlterNamespaceSetLocationSuiteBase extends command.AlterNamespaceSetLocationSuiteBase + with command.TestsV1AndV2Commands { + override def namespace: String = "db" + override def notFoundMsgPrefix: String = "Database" +} + +/** + * The class contains tests for the `ALTER NAMESPACE ... SET LOCATION` command to + * check V1 In-Memory table catalog. 
+ */ +class AlterNamespaceSetLocationSuite extends AlterNamespaceSetLocationSuiteBase + with CommandSuiteBase { + override def commandVersion: String = super[AlterNamespaceSetLocationSuiteBase].commandVersion + + test("basic test") { + runBasicTest() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowNamespacesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowNamespacesSuite.scala index 54c5d22464..a1b32e42ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowNamespacesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowNamespacesSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.command.v1 import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.command -import org.apache.spark.sql.internal.SQLConf /** * This base suite contains unified tests for the `SHOW NAMESPACES` and `SHOW DATABASES` commands @@ -42,21 +41,4 @@ trait ShowNamespacesSuiteBase extends command.ShowNamespacesSuiteBase { class ShowNamespacesSuite extends ShowNamespacesSuiteBase with CommandSuiteBase { override def commandVersion: String = "V2" // There is only V2 variant of SHOW NAMESPACES. - - test("case sensitivity") { - Seq(true, false).foreach { caseSensitive => - withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { - withNamespace(s"$catalog.AAA", s"$catalog.bbb") { - sql(s"CREATE NAMESPACE $catalog.AAA") - sql(s"CREATE NAMESPACE $catalog.bbb") - val expected = if (caseSensitive) "AAA" else "aaa" - runShowNamespacesSql( - s"SHOW NAMESPACES IN $catalog", - Seq(expected, "bbb") ++ builtinTopNamespaces) - runShowNamespacesSql(s"SHOW NAMESPACES IN $catalog LIKE 'AAA'", Seq(expected)) - runShowNamespacesSql(s"SHOW NAMESPACES IN $catalog LIKE 'aaa'", Seq(expected)) - } - } - } - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterNamespaceSetLocationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterNamespaceSetLocationSuite.scala new file mode 100644 index 0000000000..3da5f04ec1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterNamespaceSetLocationSuite.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command.v2 + +import org.apache.spark.sql.execution.command + +/** + * The class contains tests for the `ALTER NAMESPACE ... SET LOCATION` command to check V2 table + * catalogs. 
+ */ +class AlterNamespaceSetLocationSuite extends command.AlterNamespaceSetLocationSuiteBase + with CommandSuiteBase { + override def namespace: String = "ns1.ns2" + override def notFoundMsgPrefix: String = "Namespace" + + test("basic test") { + runBasicTest() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DescribeNamespaceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DescribeNamespaceSuite.scala index a98c6a486a..7d6835f09b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DescribeNamespaceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/DescribeNamespaceSuite.scala @@ -43,7 +43,7 @@ class DescribeNamespaceSuite extends command.DescribeNamespaceSuiteBase with Com assert(description === Seq( Row("Namespace Name", "ns2"), Row(SupportsNamespaces.PROP_COMMENT.capitalize, "test namespace"), - Row(SupportsNamespaces.PROP_LOCATION.capitalize, "/tmp/ns_test"), + Row(SupportsNamespaces.PROP_LOCATION.capitalize, "file:/tmp/ns_test"), Row(SupportsNamespaces.PROP_OWNER.capitalize, Utils.getCurrentUserName())) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala index bafb6608c8..ded657edc6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowNamespacesSuite.scala @@ -53,20 +53,4 @@ class ShowNamespacesSuite extends command.ShowNamespacesSuiteBase with CommandSu }.getMessage assert(errMsg.contains("does not support namespaces")) } - - test("case sensitivity") { - Seq(true, false).foreach { caseSensitive => - withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { - withNamespace(s"$catalog.AAA", s"$catalog.bbb") { - sql(s"CREATE NAMESPACE $catalog.AAA") - sql(s"CREATE NAMESPACE $catalog.bbb") - runShowNamespacesSql( - s"SHOW NAMESPACES IN $catalog", - Seq("AAA", "bbb") ++ builtinTopNamespaces) - runShowNamespacesSql(s"SHOW NAMESPACES IN $catalog LIKE 'AAA'", Seq("AAA")) - runShowNamespacesSql(s"SHOW NAMESPACES IN $catalog LIKE 'aaa'", Seq("AAA")) - } - } - } - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 466c1bea9d..2b45ee5da2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1012,6 +1012,211 @@ abstract class CSVSuite } } + test("SPARK-37326: Use different pattern to write and infer TIMESTAMP_NTZ values") { + withTempDir { dir => + val path = s"${dir.getCanonicalPath}/csv" + + val exp = spark.sql("select timestamp_ntz'2020-12-12 12:12:12' as col0") + exp.write + .format("csv") + .option("header", "true") + .option("timestampNTZFormat", "yyyy-MM-dd HH:mm:ss.SSSSSS") + .save(path) + + withSQLConf(SQLConf.TIMESTAMP_TYPE.key -> SQLConf.TimestampTypes.TIMESTAMP_NTZ.toString) { + val res = spark.read + .format("csv") + .option("inferSchema", "true") + .option("header", "true") + .option("timestampNTZFormat", "yyyy-MM-dd HH:mm:ss.SSSSSS") + .load(path) + + assert(res.dtypes === exp.dtypes) + checkAnswer(res, exp) + } + } + } + + test("SPARK-37326: Use different pattern to write and infer 
TIMESTAMP_LTZ values") { + withTempDir { dir => + val path = s"${dir.getCanonicalPath}/csv" + + val exp = spark.sql("select timestamp_ltz'2020-12-12 12:12:12' as col0") + exp.write + .format("csv") + .option("header", "true") + .option("timestampFormat", "yyyy-MM-dd HH:mm:ss.SSSSSS") + .save(path) + + withSQLConf(SQLConf.TIMESTAMP_TYPE.key -> SQLConf.TimestampTypes.TIMESTAMP_LTZ.toString) { + val res = spark.read + .format("csv") + .option("inferSchema", "true") + .option("header", "true") + .option("timestampFormat", "yyyy-MM-dd HH:mm:ss.SSSSSS") + .load(path) + + assert(res.dtypes === exp.dtypes) + checkAnswer(res, exp) + } + } + } + + test("SPARK-37326: Roundtrip in reading and writing TIMESTAMP_NTZ values with custom schema") { + withTempDir { dir => + val path = s"${dir.getCanonicalPath}/csv" + + val exp = spark.sql(""" + select + timestamp_ntz'2020-12-12 12:12:12' as col1, + timestamp_ltz'2020-12-12 12:12:12' as col2 + """) + + exp.write.format("csv").option("header", "true").save(path) + + val res = spark.read + .format("csv") + .schema("col1 TIMESTAMP_NTZ, col2 TIMESTAMP_LTZ") + .option("header", "true") + .load(path) + + checkAnswer(res, exp) + } + } + + test("SPARK-37326: Timestamp type inference for a column with TIMESTAMP_NTZ values") { + withTempDir { dir => + val path = s"${dir.getCanonicalPath}/csv" + + val exp = spark.sql(""" + select timestamp_ntz'2020-12-12 12:12:12' as col0 union all + select timestamp_ntz'2020-12-12 12:12:12' as col0 + """) + + exp.write.format("csv").option("header", "true").save(path) + + val timestampTypes = Seq( + SQLConf.TimestampTypes.TIMESTAMP_NTZ.toString, + SQLConf.TimestampTypes.TIMESTAMP_LTZ.toString) + + for (timestampType <- timestampTypes) { + withSQLConf(SQLConf.TIMESTAMP_TYPE.key -> timestampType) { + val res = spark.read + .format("csv") + .option("inferSchema", "true") + .option("header", "true") + .load(path) + + if (timestampType == SQLConf.TimestampTypes.TIMESTAMP_NTZ.toString) { + checkAnswer(res, exp) + } else { + checkAnswer( + res, + spark.sql(""" + select timestamp_ltz'2020-12-12 12:12:12' as col0 union all + select timestamp_ltz'2020-12-12 12:12:12' as col0 + """) + ) + } + } + } + } + } + + test("SPARK-37326: Timestamp type inference for a mix of TIMESTAMP_NTZ and TIMESTAMP_LTZ") { + withTempDir { dir => + val path = s"${dir.getCanonicalPath}/csv" + + Seq( + "col0", + "2020-12-12T12:12:12.000", + "2020-12-12T17:12:12.000Z", + "2020-12-12T17:12:12.000+05:00", + "2020-12-12T12:12:12.000" + ).toDF("data") + .coalesce(1) + .write.text(path) + + for (policy <- Seq("exception", "corrected", "legacy")) { + withSQLConf(SQLConf.LEGACY_TIME_PARSER_POLICY.key -> policy) { + val res = spark.read.format("csv") + .option("inferSchema", "true") + .option("header", "true") + .load(path) + + if (policy == "legacy") { + // Timestamps without timezone are parsed as strings, so the col0 type would be + // StringType which is similar to reading without schema inference. 
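            // In other words (a hedged extra check, assuming the CSV header keeps the column
            // name col0): under the legacy policy the inferred type falls back to StringType.
            assert(res.schema("col0").dataType == org.apache.spark.sql.types.StringType)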
+ val exp = spark.read.format("csv").option("header", "true").load(path) + checkAnswer(res, exp) + } else { + val exp = spark.sql(""" + select timestamp_ltz'2020-12-12T12:12:12.000' as col0 union all + select timestamp_ltz'2020-12-12T17:12:12.000Z' as col0 union all + select timestamp_ltz'2020-12-12T17:12:12.000+05:00' as col0 union all + select timestamp_ltz'2020-12-12T12:12:12.000' as col0 + """) + checkAnswer(res, exp) + } + } + } + } + } + + test("SPARK-37326: Malformed records when reading TIMESTAMP_LTZ as TIMESTAMP_NTZ") { + withTempDir { dir => + val path = s"${dir.getCanonicalPath}/csv" + + Seq( + "2020-12-12T12:12:12.000", + "2020-12-12T17:12:12.000Z", + "2020-12-12T17:12:12.000+05:00", + "2020-12-12T12:12:12.000" + ).toDF("data") + .coalesce(1) + .write.text(path) + + for (timestampNTZFormat <- Seq(None, Some("yyyy-MM-dd'T'HH:mm:ss[.SSS]"))) { + val reader = spark.read.format("csv").schema("col0 TIMESTAMP_NTZ") + val res = timestampNTZFormat match { + case Some(format) => reader.option("timestampNTZFormat", format).load(path) + case None => reader.load(path) + } + + checkAnswer( + res, + Seq( + Row(LocalDateTime.of(2020, 12, 12, 12, 12, 12)), + Row(null), + Row(null), + Row(LocalDateTime.of(2020, 12, 12, 12, 12, 12)) + ) + ) + } + } + } + + test("SPARK-37326: Fail to write TIMESTAMP_NTZ if timestampNTZFormat contains zone offset") { + val patterns = Seq( + "yyyy-MM-dd HH:mm:ss XXX", + "yyyy-MM-dd HH:mm:ss Z", + "yyyy-MM-dd HH:mm:ss z") + + val exp = spark.sql("select timestamp_ntz'2020-12-12 12:12:12' as col0") + for (pattern <- patterns) { + withTempDir { dir => + val path = s"${dir.getCanonicalPath}/csv" + val err = intercept[SparkException] { + exp.write.format("csv").option("timestampNTZFormat", pattern).save(path) + } + assert( + err.getCause.getMessage.contains("Unsupported field: OffsetSeconds") || + err.getCause.getMessage.contains("Unable to extract value") || + err.getCause.getMessage.contains("Unable to extract ZoneId")) + } + } + } + test("Write dates correctly with dateFormat option") { val customSchema = new StructType(Array(StructField("date", DateType, true))) withTempDir { dir => @@ -2489,10 +2694,6 @@ abstract class CSVSuite } test("SPARK-36536: use casting when datetime pattern is not set") { - def isLegacy: Boolean = { - spark.conf.get(SQLConf.LEGACY_TIME_PARSER_POLICY).toUpperCase(Locale.ROOT) == - SQLConf.LegacyBehaviorPolicy.LEGACY.toString - } withSQLConf( SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true", SQLConf.SESSION_LOCAL_TIMEZONE.key -> DateTimeTestUtils.UTC.getId) { @@ -2511,13 +2712,13 @@ abstract class CSVSuite readback, Seq( Row(LocalDate.of(2021, 1, 1), Instant.parse("2021-01-01T00:00:00Z"), - if (isLegacy) null else LocalDateTime.of(2021, 1, 1, 0, 0, 0)), + LocalDateTime.of(2021, 1, 1, 0, 0, 0)), Row(LocalDate.of(2021, 1, 1), Instant.parse("2021-01-01T00:00:00Z"), - if (isLegacy) null else LocalDateTime.of(2021, 1, 1, 0, 0, 0)), + LocalDateTime.of(2021, 1, 1, 0, 0, 0)), Row(LocalDate.of(2021, 2, 1), Instant.parse("2021-03-02T00:00:00Z"), - if (isLegacy) null else LocalDateTime.of(2021, 10, 1, 0, 0, 0)), + LocalDateTime.of(2021, 10, 1, 0, 0, 0)), Row(LocalDate.of(2021, 8, 18), Instant.parse("2021-08-18T21:44:30Z"), - if (isLegacy) null else LocalDateTime.of(2021, 8, 18, 21, 44, 30, 123000000)))) + LocalDateTime.of(2021, 8, 18, 21, 44, 30, 123000000)))) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala index 2d6978a810..8bc92f8d57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala @@ -776,7 +776,7 @@ abstract class OrcQuerySuite extends OrcQueryTest with SharedSparkSession { } :+ (null, null) withOrcFile(data) { file => - withAllOrcReaders { + withAllNativeOrcReaders { checkAnswer(spark.read.orc(file), data.toDF().collect()) } } @@ -799,7 +799,7 @@ abstract class OrcQuerySuite extends OrcQueryTest with SharedSparkSession { withTempPath { file => val df = spark.createDataFrame(sparkContext.parallelize(data), actualSchema) df.write.orc(file.getCanonicalPath) - withAllOrcReaders { + withAllNativeOrcReaders { val msg = intercept[SparkException] { spark.read.schema(providedSchema).orc(file.getCanonicalPath).collect() }.getMessage @@ -825,7 +825,7 @@ abstract class OrcQuerySuite extends OrcQueryTest with SharedSparkSession { withTempPath { file => val df = spark.createDataFrame(sparkContext.parallelize(data), actualSchema) df.write.orc(file.getCanonicalPath) - withAllOrcReaders { + withAllNativeOrcReaders { checkAnswer(spark.read.schema(providedSchema).orc(file.getCanonicalPath), answer) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 8ffccd9679..8953fbb372 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -485,12 +485,10 @@ abstract class OrcSuite } test("SPARK-31238: compatibility with Spark 2.4 in reading dates") { - Seq(false, true).foreach { vectorized => - withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { - checkAnswer( - readResourceOrcFile("test-data/before_1582_date_v2_4.snappy.orc"), - Row(java.sql.Date.valueOf("1200-01-01"))) - } + withAllNativeOrcReaders { + checkAnswer( + readResourceOrcFile("test-data/before_1582_date_v2_4.snappy.orc"), + Row(java.sql.Date.valueOf("1200-01-01"))) } } @@ -502,23 +500,19 @@ abstract class OrcSuite .write .orc(path) - Seq(false, true).foreach { vectorized => - withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { - checkAnswer( - spark.read.orc(path), - Seq(Row(Date.valueOf("1001-01-01")), Row(Date.valueOf("1582-10-15")))) - } + withAllNativeOrcReaders { + checkAnswer( + spark.read.orc(path), + Seq(Row(Date.valueOf("1001-01-01")), Row(Date.valueOf("1582-10-15")))) } } } test("SPARK-31284: compatibility with Spark 2.4 in reading timestamps") { - Seq(false, true).foreach { vectorized => - withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { - checkAnswer( - readResourceOrcFile("test-data/before_1582_ts_v2_4.snappy.orc"), - Row(java.sql.Timestamp.valueOf("1001-01-01 01:02:03.123456"))) - } + withAllNativeOrcReaders { + checkAnswer( + readResourceOrcFile("test-data/before_1582_ts_v2_4.snappy.orc"), + Row(java.sql.Timestamp.valueOf("1001-01-01 01:02:03.123456"))) } } @@ -530,14 +524,12 @@ abstract class OrcSuite .write .orc(path) - Seq(false, true).foreach { vectorized => - withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { - checkAnswer( - spark.read.orc(path), - Seq( - 
Row(java.sql.Timestamp.valueOf("1001-01-01 01:02:03.123456")), - Row(java.sql.Timestamp.valueOf("1582-10-15 11:12:13.654321")))) - } + withAllNativeOrcReaders { + checkAnswer( + spark.read.orc(path), + Seq( + Row(java.sql.Timestamp.valueOf("1001-01-01 01:02:03.123456")), + Row(java.sql.Timestamp.valueOf("1582-10-15 11:12:13.654321")))) } } } @@ -809,11 +801,12 @@ abstract class OrcSourceSuite extends OrcSuite with SharedSparkSession { } } - Seq(true, false).foreach { vecReaderEnabled => + withAllNativeOrcReaders { Seq(true, false).foreach { vecReaderNestedColEnabled => + val vecReaderEnabled = SQLConf.get.orcVectorizedReaderEnabled test("SPARK-36931: Support reading and writing ANSI intervals (" + - s"${SQLConf.ORC_VECTORIZED_READER_ENABLED.key}=$vecReaderEnabled, " + - s"${SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key}=$vecReaderNestedColEnabled)") { + s"${SQLConf.ORC_VECTORIZED_READER_ENABLED.key}=$vecReaderEnabled, " + + s"${SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key}=$vecReaderNestedColEnabled)") { withSQLConf( SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala index cd87374e85..96932de327 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala @@ -143,7 +143,7 @@ abstract class OrcTest extends QueryTest with FileBasedDataSourceTest with Befor spark.read.orc(file.getAbsolutePath) } - def withAllOrcReaders(code: => Unit): Unit = { + def withAllNativeOrcReaders(code: => Unit): Unit = { // test the row-based reader withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false")(code) // test the vectorized reader diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala index 2317a4d00e..79b8c9e2c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala @@ -29,14 +29,15 @@ import org.apache.spark.sql.test.SharedSparkSession class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSparkSession { import testImplicits._ - val ROW = ((1).toByte, 2, 3L, "abc", Period.of(1, 1, 0), Duration.ofMillis(100)) + val ROW = ((1).toByte, 2, 3L, "abc", Period.of(1, 1, 0), Duration.ofMillis(100), true) val NULL_ROW = ( null.asInstanceOf[java.lang.Byte], null.asInstanceOf[Integer], null.asInstanceOf[java.lang.Long], null.asInstanceOf[String], null.asInstanceOf[Period], - null.asInstanceOf[Duration]) + null.asInstanceOf[Duration], + null.asInstanceOf[java.lang.Boolean]) test("All Types Dictionary") { (1 :: 1000 :: Nil).foreach { n => { @@ -59,6 +60,7 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSparkSess assert(batch.column(3).getUTF8String(i).toString == "abc") assert(batch.column(4).getInt(i) == 13) assert(batch.column(5).getLong(i) == 100000) + assert(batch.column(6).getBoolean(i) == true) i += 1 } reader.close() @@ -88,6 +90,7 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSparkSess assert(batch.column(3).isNullAt(i)) assert(batch.column(4).isNullAt(i)) 
assert(batch.column(5).isNullAt(i)) + assert(batch.column(6).isNullAt(i)) i += 1 } reader.close() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index f12e5af9d4..0966319f53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -145,7 +145,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession val numRecords = 100 val writer = createParquetWriter(schema, tablePath, dictionaryEnabled = dictEnabled) - (0 until numRecords).map { i => + (0 until numRecords).foreach { i => val record = new SimpleGroup(schema) for (group <- Seq(0, 2, 4)) { record.add(group, 1000L) // millis diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index a3aa74d9fc..bf37421331 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -1065,14 +1065,15 @@ abstract class ParquetPartitionDiscoverySuite } } - test("SPARK-23436: invalid Dates should be inferred as String in partition inference") { + test( + "SPARK-23436, SPARK-36861: invalid Dates should be inferred as String in partition inference") { withTempPath { path => - val data = Seq(("1", "2018-41", "2018-01-01-04", "test")) - .toDF("id", "date_month", "date_hour", "data") + val data = Seq(("1", "2018-41", "2018-01-01-04", "2021-01-01T00", "test")) + .toDF("id", "date_month", "date_hour", "date_t_hour", "data") - data.write.partitionBy("date_month", "date_hour").parquet(path.getAbsolutePath) + data.write.partitionBy("date_month", "date_hour", "date_t_hour").parquet(path.getAbsolutePath) val input = spark.read.parquet(path.getAbsolutePath).select("id", - "date_month", "date_hour", "data") + "date_month", "date_hour", "date_t_hour", "data") assert(input.schema.sameType(input.schema)) checkAnswer(input, data) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/history/SQLEventFilterBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/history/SQLEventFilterBuilderSuite.scala index 5f3d750e8f..090c149886 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/history/SQLEventFilterBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/history/SQLEventFilterBuilderSuite.scala @@ -58,7 +58,7 @@ class SQLEventFilterBuilderSuite extends SparkFunSuite { // Start SQL Execution listener.onOtherEvent(SparkListenerSQLExecutionStart(1, "desc1", "details1", "plan", - new SparkPlanInfo("node", "str", Seq.empty, Map.empty, Seq.empty), time)) + new SparkPlanInfo("node", "str", Seq.empty, Map.empty, Seq.empty), time, Map.empty)) time += 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/history/SQLLiveEntitiesEventFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/history/SQLLiveEntitiesEventFilterSuite.scala index 46fdaba413..724df8ebe8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/history/SQLLiveEntitiesEventFilterSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/history/SQLLiveEntitiesEventFilterSuite.scala @@ -42,7 +42,7 @@ class SQLLiveEntitiesEventFilterSuite extends SparkFunSuite { // Verifying with finished SQL execution 1 assert(Some(false) === acceptFn(SparkListenerSQLExecutionStart(1, "description1", "details1", - "plan", null, 0))) + "plan", null, 0, Map.empty))) assert(Some(false) === acceptFn(SparkListenerSQLExecutionEnd(1, 0))) assert(Some(false) === acceptFn(SparkListenerSQLAdaptiveExecutionUpdate(1, "plan", null))) assert(Some(false) === acceptFn(SparkListenerDriverAccumUpdates(1, Seq.empty))) @@ -89,7 +89,7 @@ class SQLLiveEntitiesEventFilterSuite extends SparkFunSuite { // Verifying with live SQL execution 2 assert(Some(true) === acceptFn(SparkListenerSQLExecutionStart(2, "description2", "details2", - "plan", null, 0))) + "plan", null, 0, Map.empty))) assert(Some(true) === acceptFn(SparkListenerSQLExecutionEnd(2, 0))) assert(Some(true) === acceptFn(SparkListenerSQLAdaptiveExecutionUpdate(2, "plan", null))) assert(Some(true) === acceptFn(SparkListenerDriverAccumUpdates(2, Seq.empty))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala index 3e9ccb0f70..a0bd0fb582 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSinkSuite.scala @@ -21,6 +21,8 @@ import scala.collection.mutable import scala.language.implicitConversions import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.execution.SerializeFromObjectExec import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming._ @@ -43,6 +45,21 @@ class ForeachBatchSinkSuite extends StreamTest { check(in = 5, 6, 7)(out = 7, 8, 9)) } + test("foreachBatch with non-stateful query - untyped Dataset") { + val mem = MemoryStream[Int] + val ds = mem.toDF.selectExpr("value + 1 as value") + + val tester = new ForeachBatchTester[Row](mem)(RowEncoder.apply(ds.schema)) + val writer = (df: DataFrame, batchId: Long) => + tester.record(batchId, df.selectExpr("value + 1")) + + import tester._ + testWriter(ds, writer)( + // out = in + 2 (i.e. 1 in query, 1 in writer) + check(in = 1, 2, 3)(out = Row(3), Row(4), Row(5)), + check(in = 5, 6, 7)(out = Row(7), Row(8), Row(9))) + } + test("foreachBatch with stateful query in update mode") { val mem = MemoryStream[Int] val ds = mem.toDF() @@ -79,6 +96,35 @@ class ForeachBatchSinkSuite extends StreamTest { check(in = 2)(out = (0, 2L), (1, 1L))) } + test("foreachBatch with batch specific operations") { + val mem = MemoryStream[Int] + val ds = mem.toDS.map(_ + 1) + + val tester = new ForeachBatchTester[Int](mem) + val writer: (Dataset[Int], Long) => Unit = { case (df, batchId) => + df.persist() + + val newDF = df + .map(_ + 1) + .repartition(1) + .sort(Column("value").desc) + tester.record(batchId, newDF) + + // just run another simple query against cached DF to confirm they don't conflict each other + val curValues = df.collect() + val newValues = df.map(_ + 2).collect() + assert(curValues.map(_ + 2) === newValues) + + df.unpersist() + } + + import tester._ + testWriter(ds, writer)( + // out = in + 2 (i.e. 
1 in query, 1 in writer), with sorted + check(in = 1, 2, 3)(out = 5, 4, 3), + check(in = 5, 6, 7)(out = 9, 8, 7)) + } + test("foreachBatchSink does not affect metric generation") { val mem = MemoryStream[Int] val ds = mem.toDS.map(_ + 1) @@ -109,6 +155,36 @@ class ForeachBatchSinkSuite extends StreamTest { assert(ex3.getMessage.contains("'foreachBatch' does not support partitioning")) } + test("foreachBatch should not introduce object serialization") { + def assertPlan[T](stream: MemoryStream[Int], ds: Dataset[T]): Unit = { + var planAsserted = false + + val writer: (Dataset[T], Long) => Unit = { case (df, _) => + assert(df.queryExecution.executedPlan.find { p => + p.isInstanceOf[SerializeFromObjectExec] + }.isEmpty, "Untyped Dataset should not introduce serialization on object!") + planAsserted = true + } + + stream.addData(1, 2, 3, 4, 5) + + val query = ds.writeStream.trigger(Trigger.Once()).foreachBatch(writer).start() + query.awaitTermination() + + assert(planAsserted, "ForeachBatch writer should be called!") + } + + // typed + val mem = MemoryStream[Int] + val ds = mem.toDS.map(_ + 1) + assertPlan(mem, ds) + + // untyped + val mem2 = MemoryStream[Int] + val dsUntyped = mem2.toDF().selectExpr("value + 1 as value") + assertPlan(mem2, dsUntyped) + } + // ============== Helper classes and methods ================= private class ForeachBatchTester[T: Encoder](memoryStream: MemoryStream[Int]) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/AllExecutionsPageSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/AllExecutionsPageSuite.scala index 24b8a973ad..1f5cbb0e19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/AllExecutionsPageSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/AllExecutionsPageSuite.scala @@ -112,7 +112,8 @@ class AllExecutionsPageSuite extends SharedSparkSession with BeforeAndAfter { "test", df.queryExecution.toString, SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - System.currentTimeMillis())) + System.currentTimeMillis(), + Map.empty)) listener.onOtherEvent(SparkListenerSQLExecutionEnd( executionId, System.currentTimeMillis())) listener.onJobStart(SparkListenerJobStart( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/MetricsAggregationBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/MetricsAggregationBenchmark.scala index 533d98da24..aa3988ae37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/MetricsAggregationBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/MetricsAggregationBenchmark.scala @@ -79,7 +79,8 @@ object MetricsAggregationBenchmark extends BenchmarkBase { getClass().getName(), getClass().getName(), planInfo, - System.currentTimeMillis()) + System.currentTimeMillis(), + Map.empty) val executionEnd = SparkListenerSQLExecutionEnd(executionId, System.currentTimeMillis()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala index e776a4ac23..61230641de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala @@ -198,7 +198,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils "test", df.queryExecution.toString, 
SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - System.currentTimeMillis())) + System.currentTimeMillis(), + Map.empty)) listener.onJobStart(SparkListenerJobStart( jobId = 0, @@ -345,7 +346,7 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils val listener = new SparkListener { override def onOtherEvent(event: SparkListenerEvent): Unit = { event match { - case SparkListenerSQLExecutionStart(_, _, _, planDescription, _, _) => + case SparkListenerSQLExecutionStart(_, _, _, planDescription, _, _, _) => assert(expected.forall(planDescription.contains)) checkDone = true case _ => // ignore other events @@ -387,7 +388,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils "test", df.queryExecution.toString, SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - System.currentTimeMillis())) + System.currentTimeMillis(), + Map.empty)) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), @@ -416,7 +418,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils "test", df.queryExecution.toString, SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - System.currentTimeMillis())) + System.currentTimeMillis(), + Map.empty)) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), @@ -456,7 +459,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils "test", df.queryExecution.toString, SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - System.currentTimeMillis())) + System.currentTimeMillis(), + Map.empty)) listener.onJobStart(SparkListenerJobStart( jobId = 0, time = System.currentTimeMillis(), @@ -485,7 +489,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils "test", df.queryExecution.toString, SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - System.currentTimeMillis())) + System.currentTimeMillis(), + Map.empty)) listener.onOtherEvent(SparkListenerSQLExecutionEnd( executionId, System.currentTimeMillis())) listener.onJobStart(SparkListenerJobStart( @@ -515,7 +520,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils "test", df.queryExecution.toString, SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - System.currentTimeMillis())) + System.currentTimeMillis(), + Map.empty)) var stageId = 0 def twoStageJob(jobId: Int): Unit = { @@ -654,7 +660,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils "test", df.queryExecution.toString, SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - time)) + time, + Map.empty)) time += 1 listener.onOtherEvent(SparkListenerSQLExecutionStart( 2, @@ -662,7 +669,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils "test", df.queryExecution.toString, SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - time)) + time, + Map.empty)) // Stop execution 2 before execution 1 time += 1 @@ -678,7 +686,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils "test", df.queryExecution.toString, SparkPlanInfo.fromSparkPlan(df.queryExecution.executedPlan), - time)) + time, + Map.empty)) assert(statusStore.executionsCount === 2) assert(statusStore.execution(2) === None) } @@ -713,7 +722,8 @@ class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTestUtils "test", df.queryExecution.toString, oldPlan, - System.currentTimeMillis())) + System.currentTimeMillis(), + Map.empty)) 
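    // As of this patch, SparkListenerSQLExecutionStart carries one extra trailing
    // Map[String, String] argument (assumed to be the session configs captured for the
    // execution; empty in these tests). A minimal hedged construction mirroring the
    // positional arguments used in the updated calls above:
    import org.apache.spark.sql.execution.SparkPlanInfo
    import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart
    val sampleStart = SparkListenerSQLExecutionStart(
      1, "test", "test", "plan",
      new SparkPlanInfo("node", "str", Seq.empty, Map.empty, Seq.empty),
      System.currentTimeMillis(), Map.empty)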
listener.onJobStart(SparkListenerJobStart( jobId = 0, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 0477b41942..738f2281c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -25,9 +25,11 @@ import java.util.NoSuchElementException import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.language.implicitConversions import scala.util.Random import org.apache.arrow.vector.IntVector +import org.apache.parquet.bytes.ByteBufferInputStream import org.apache.spark.SparkFunSuite import org.apache.spark.memory.MemoryMode @@ -36,6 +38,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.util.{ArrayBasedMapBuilder, DateTimeUtils, GenericArrayData, MapData} import org.apache.spark.sql.execution.RowToColumnConverter +import org.apache.spark.sql.execution.datasources.parquet.VectorizedPlainValuesReader import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnarBatchRow, ColumnVector} @@ -130,6 +133,97 @@ class ColumnarBatchSuite extends SparkFunSuite { } } + testVector("Boolean APIs", 1024, BooleanType) { + column => + val reference = mutable.ArrayBuffer.empty[Boolean] + + var values = Array(true, false, true, false, false) + var bits = values.foldRight(0)((b, i) => i << 1 | (if (b) 1 else 0)).toByte + column.appendBooleans(2, bits, 0) + reference ++= values.slice(0, 2) + + column.appendBooleans(3, bits, 2) + reference ++= values.slice(2, 5) + + column.appendBooleans(6, true) + reference ++= Array.fill(6)(true) + + column.appendBoolean(false) + reference += false + + var idx = column.elementsAppended + + values = Array(true, true, false, true, false, true, false, true) + bits = values.foldRight(0)((b, i) => i << 1 | (if (b) 1 else 0)).toByte + column.putBooleans(idx, 2, bits, 0) + reference ++= values.slice(0, 2) + idx += 2 + + column.putBooleans(idx, 3, bits, 2) + reference ++= values.slice(2, 5) + idx += 3 + + column.putBooleans(idx, bits) + reference ++= values + idx += 8 + + column.putBoolean(idx, false) + reference += false + idx += 1 + + column.putBooleans(idx, 3, true) + reference ++= Array.fill(3)(true) + idx += 3 + + implicit def intToByte(i: Int): Byte = i.toByte + val buf = ByteBuffer.wrap(Array(0x33, 0x5A, 0xA5, 0xCC, 0x0F, 0xF0, 0xEE, 0x77, 0x88)) + val reader = new VectorizedPlainValuesReader() + reader.initFromPage(0, ByteBufferInputStream.wrap(buf)) + + reader.skipBooleans(1) // bit index 0 + + column.putBoolean(idx, reader.readBoolean) // bit index 1 + reference += true + idx += 1 + + column.putBoolean(idx, reader.readBoolean) // bit index 2 + reference += false + idx += 1 + + reader.skipBooleans(5) // bit index [3, 7] + + column.putBoolean(idx, reader.readBoolean) // bit index 8 + reference += false + idx += 1 + + reader.skipBooleans(8) // bit index [9, 16] + reader.skipBooleans(0) // no-op + + column.putBoolean(idx, reader.readBoolean) // bit index 17 + reference += false + idx += 1 + + reader.skipBooleans(16) // bit index [18, 33] + + reader.readBooleans(4, column, idx) // bit index [34, 37] + reference ++= Array(true, true, false, false) + idx 
+= 4 + + reader.readBooleans(11, column, idx) // bit index [38, 48] + reference ++= Array(false, false, false, false, false, false, true, true, true, true, false) + idx += 11 + + reader.skipBooleans(7) // bit index [49, 55] + + reader.readBooleans(9, column, idx) // bit index [56, 64] + reference ++= Array(true, true, true, false, true, true, true, false, false) + idx += 9 + + reference.zipWithIndex.foreach { v => + assert(v._1 == column.getBoolean(v._2), "VectorType=" + column.getClass.getSimpleName) + } + } + testVector("Byte APIs", 1024, ByteType) { column => val reference = mutable.ArrayBuffer.empty[Byte] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 21a0b24cb4..54bed5c966 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -1229,7 +1229,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi q.processAllAvailable() q } finally { - spark.streams.active.map(_.stop()) + spark.streams.active.foreach(_.stop()) } } diff --git a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceSuite.scala index dbc33c47fe..baa04ada8b 100644 --- a/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/status/api/v1/sql/SqlResourceSuite.scala @@ -85,6 +85,7 @@ object SqlResourceSuite { description = DESCRIPTION, details = "", physicalPlanDescription = PLAN_DESCRIPTION, + Map.empty, metrics = metrics, submissionTime = 1586768888233L, completionTime = Some(new Date(1586768888999L)), diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 179b424fef..5fccce2678 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils} import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{PartitioningUtils, SourceOptions} import org.apache.spark.sql.hive.client.HiveClient @@ -436,8 +436,17 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val properties = new mutable.HashMap[String, String] properties.put(CREATED_SPARK_VERSION, table.createVersion) + // This is for backward compatibility to Spark 2 to read tables with char/varchar created by + // Spark 3.1. At read side, we will restore a table schema from its properties. So, we need to + // clear the `varchar(n)` and `char(n)` and replace them with `string` as Spark 2 does not have + // a type mapping for them in `DataType.nameToType`. + // See `restoreHiveSerdeTable` for example. 
+ val newSchema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema) CatalogTable.splitLargeTableProp( - DATASOURCE_SCHEMA, schema.json, properties.put, conf.get(SCHEMA_STRING_LENGTH_THRESHOLD)) + DATASOURCE_SCHEMA, + newSchema.json, + properties.put, + conf.get(SCHEMA_STRING_LENGTH_THRESHOLD)) if (partitionColumns.nonEmpty) { properties.put(DATASOURCE_SCHEMA_NUMPARTCOLS, partitionColumns.length.toString) @@ -742,8 +751,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat case None if table.tableType == VIEW => // If this is a view created by Spark 2.2 or higher versions, we should restore its schema // from table properties. - CatalogTable.readLargeTableProp(table.properties, DATASOURCE_SCHEMA).foreach { schemaJson => - table = table.copy(schema = DataType.fromJson(schemaJson).asInstanceOf[StructType]) + getSchemaFromTableProperties(table.properties).foreach { schemaFromTableProps => + table = table.copy(schema = schemaFromTableProps) } // No provider in table properties, which means this is a Hive serde table. @@ -793,9 +802,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // If this is a Hive serde table created by Spark 2.1 or higher versions, we should restore its // schema from table properties. - val schemaJson = CatalogTable.readLargeTableProp(table.properties, DATASOURCE_SCHEMA) - if (schemaJson.isDefined) { - val schemaFromTableProps = DataType.fromJson(schemaJson.get).asInstanceOf[StructType] + val maybeSchemaFromTableProps = getSchemaFromTableProperties(table.properties) + if (maybeSchemaFromTableProps.isDefined) { + val schemaFromTableProps = maybeSchemaFromTableProps.get val partColumnNames = getPartitionColumnsFromTableProperties(table) val reorderedSchema = reorderSchema(schema = schemaFromTableProps, partColumnNames) @@ -821,6 +830,14 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } + private def getSchemaFromTableProperties( + tableProperties: Map[String, String]): Option[StructType] = { + CatalogTable.readLargeTableProp(tableProperties, DATASOURCE_SCHEMA).map { schemaJson => + val parsed = DataType.fromJson(schemaJson).asInstanceOf[StructType] + CharVarcharUtils.getRawSchema(parsed) + } + } + private def restoreDataSourceTable(table: CatalogTable, provider: String): CatalogTable = { // Internally we store the table location in storage properties with key "path" for data // source tables. 
Here we set the table location to `locationUri` field and filter out the @@ -835,8 +852,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat storageWithLocation.properties.filterKeys(!HIVE_GENERATED_STORAGE_PROPERTIES(_)).toMap) val partitionProvider = table.properties.get(TABLE_PARTITION_PROVIDER) - val schemaFromTableProps = CatalogTable.readLargeTableProp(table.properties, DATASOURCE_SCHEMA) - .map(json => DataType.fromJson(json).asInstanceOf[StructType]).getOrElse(new StructType()) + val schemaFromTableProps = + getSchemaFromTableProperties(table.properties).getOrElse(new StructType()) val partColumnNames = getPartitionColumnsFromTableProperties(table) val reorderedSchema = reorderSchema(schema = schemaFromTableProps, partColumnNames) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index a07ec165f3..7637c3c7a3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -202,11 +202,12 @@ private[hive] class HiveClientImpl( private def getHive(conf: HiveConf): Hive = { try { - Hive.getWithoutRegisterFns(conf) + classOf[Hive].getMethod("getWithoutRegisterFns", classOf[HiveConf]) + .invoke(null, conf).asInstanceOf[Hive] } catch { // SPARK-37069: not all Hive versions have the above method (e.g., Hive 2.3.9 has it but // 2.3.8 don't), therefore here we fallback when encountering the exception. - case _: NoSuchMethodError => + case _: NoSuchMethodException => Hive.get(conf) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/AlterNamespaceSetLocationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/AlterNamespaceSetLocationSuite.scala new file mode 100644 index 0000000000..49f650fb1c --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/AlterNamespaceSetLocationSuite.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution.command + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.command.v1 + +/** + * The class contains tests for the `ALTER NAMESPACE ... SET LOCATION` command to check + * V1 Hive external table catalog. 
+ */ +class AlterNamespaceSetLocationSuite extends v1.AlterNamespaceSetLocationSuiteBase + with CommandSuiteBase { + override def commandVersion: String = super[AlterNamespaceSetLocationSuiteBase].commandVersion + + test("Hive catalog not supported") { + val ns = s"$catalog.$namespace" + withNamespace(ns) { + sql(s"CREATE NAMESPACE $ns") + val e = intercept[AnalysisException] { + sql(s"ALTER DATABASE $ns SET LOCATION 'loc'") + } + assert(e.getMessage.contains("does not support altering database location")) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/ShowNamespacesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/ShowNamespacesSuite.scala index 015001fa4f..2f7303c42c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/ShowNamespacesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/ShowNamespacesSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.hive.execution.command import org.apache.spark.sql.execution.command.v1 -import org.apache.spark.sql.internal.SQLConf /** * The class contains tests for the `SHOW NAMESPACES` and `SHOW DATABASES` commands to check @@ -26,22 +25,8 @@ import org.apache.spark.sql.internal.SQLConf */ class ShowNamespacesSuite extends v1.ShowNamespacesSuiteBase with CommandSuiteBase { override def commandVersion: String = "V2" // There is only V2 variant of SHOW NAMESPACES. - - test("case sensitivity") { - Seq(true, false).foreach { caseSensitive => - withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive.toString) { - withNamespace(s"$catalog.AAA", s"$catalog.bbb") { - sql(s"CREATE NAMESPACE $catalog.AAA") - sql(s"CREATE NAMESPACE $catalog.bbb") - runShowNamespacesSql( - s"SHOW NAMESPACES IN $catalog", - Seq("aaa", "bbb") ++ builtinTopNamespaces) - runShowNamespacesSql(s"SHOW NAMESPACES IN $catalog LIKE 'AAA'", Seq("aaa")) - runShowNamespacesSql(s"SHOW NAMESPACES IN $catalog LIKE 'aaa'", Seq("aaa")) - } - } - } - } + // Hive Catalog is not case preserving and always lower-case the namespace name when storing it. + override def isCasePreserving: Boolean = false test("hive client calls") { withNamespace(s"$catalog.ns1", s"$catalog.ns2") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala index 10cc55bd28..3769de07d8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -599,7 +599,7 @@ private[hive] class TestHiveQueryExecution( override lazy val analyzed: LogicalPlan = sparkSession.withActive { // Make sure any test tables referenced are loaded. val referencedTables = logical.collect { - case UnresolvedRelation(ident, _, _, _) => + case UnresolvedRelation(ident, _, _) => if (ident.length > 1 && ident.head.equalsIgnoreCase(CatalogManager.SESSION_CATALOG_NAME)) { ident.tail.asTableIdentifier } else ident.asTableIdentifier
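// A small self-contained sketch of the identifier handling above: when a multi-part
// relation name is qualified with the session catalog ("spark_catalog", i.e.
// CatalogManager.SESSION_CATALOG_NAME), the leading part is dropped before the rest is
// turned into a table identifier. The helper below is illustrative only and stands in
// for the implicit asTableIdentifier conversion used in TestHive.
def stripSessionCatalog(nameParts: Seq[String]): Seq[String] = {
  val sessionCatalog = "spark_catalog"
  if (nameParts.length > 1 && nameParts.head.equalsIgnoreCase(sessionCatalog)) {
    nameParts.tail
  } else {
    nameParts
  }
}

// e.g. Seq("spark_catalog", "default", "src") becomes Seq("default", "src")
assert(stripSessionCatalog(Seq("spark_catalog", "default", "src")) == Seq("default", "src"))
assert(stripSessionCatalog(Seq("default", "src")) == Seq("default", "src"))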