From d852b72041151b7f0dcff35765eb3d52a944f3a8 Mon Sep 17 00:00:00 2001 From: Alex Barreto Date: Tue, 23 Nov 2021 11:37:48 -0500 Subject: [PATCH] [SPARK-37013][CORE][SQL][FOLLOWUP] Use the new error framework to throw error in `FormatString` ### What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/34313. The main change of this pr is change to use the new error framework to throw error when `pattern.contains("%0$")` is true. ### Why are the changes needed? Use the new error framework to throw error ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass the Jenkins or GitHub Action Closes #34454 from LuciferYang/SPARK-37013-FOLLOWUP. Authored-by: yangjie01 Signed-off-by: Wenchen Fan --- R/pkg/R/pairRDD.R | 2 +- .../util/kvstore/LevelDBIteratorSuite.java | 2 +- .../spark/util/kvstore/LevelDBSuite.java | 2 +- .../main/resources/error/error-classes.json | 3 + .../scala/org/apache/spark/util/Utils.scala | 5 + .../history/FsHistoryProviderSuite.scala | 2 +- .../deploy/history/HistoryServerSuite.scala | 4 +- .../spark/status/AppStatusStoreSuite.scala | 2 +- dev/lint-python | 35 +- docs/sql-ref-ansi-compliance.md | 40 +- examples/src/main/python/__init__.py | 16 + examples/src/main/python/als.py | 20 +- examples/src/main/python/avro_inputformat.py | 4 +- examples/src/main/python/ml/__init__,py | 16 + .../main/python/ml/chi_square_test_example.py | 5 + .../src/main/python/ml/correlation_example.py | 8 + examples/src/main/python/mllib/__init__.py | 16 + .../src/main/python/parquet_inputformat.py | 4 +- examples/src/main/python/sort.py | 4 +- examples/src/main/python/sql/__init__.py | 16 + .../src/main/python/sql/streaming/__init__,py | 16 + .../structured_network_wordcount_windowed.py | 2 +- .../src/main/python/streaming/__init__.py | 16 + .../streaming/network_wordjoinsentiments.py | 15 +- project/SparkBuild.scala | 1 + python/docs/source/conf.py | 2 +- .../getting_started/quickstart_ps.ipynb | 4 +- .../migration_guide/koalas_to_pyspark.rst | 5 +- python/docs/source/reference/pyspark.sql.rst | 2 +- .../pandas_on_spark/pandas_pyspark.rst | 4 +- .../user_guide/pandas_on_spark/types.rst | 2 +- python/pyspark/_typing.pyi | 4 +- python/pyspark/pandas/spark/accessors.py | 10 +- .../tests/data_type_ops/testing_utils.py | 19 +- python/pyspark/rdd.pyi | 107 ++-- python/pyspark/sql/dataframe.py | 20 +- python/pyspark/sql/pandas/conversion.py | 15 +- python/pyspark/sql/tests/test_arrow.py | 48 ++ python/pyspark/sql/tests/test_dataframe.py | 110 ++-- python/pyspark/streaming/dstream.pyi | 59 +- .../org/apache/spark/deploy/yarn/Client.scala | 6 +- .../catalyst/analysis/AnsiTypeCoercion.scala | 83 +-- .../catalyst/analysis/CTESubstitution.scala | 3 + .../catalyst/analysis/KeepLegacyOutputs.scala | 4 +- .../sql/catalyst/analysis/TypeCoercion.scala | 4 +- .../spark/sql/catalyst/dsl/package.scala | 3 +- .../expressions/mathExpressions.scala | 34 +- .../expressions/stringExpressions.scala | 10 +- .../plans/logical/basicLogicalOperators.scala | 2 + .../sql/catalyst/rules/RuleIdCollection.scala | 3 +- .../sql/catalyst/trees/TreePatterns.scala | 1 + .../sql/errors/QueryCompilationErrors.scala | 6 + .../analysis/AnsiTypeCoercionSuite.scala | 387 +------------- .../catalyst/analysis/TypeCoercionSuite.scala | 502 ++++++++++-------- .../expressions/MathExpressionsSuite.scala | 28 + ...eStoreBasicOperationsBenchmark-results.txt | 183 +++++++ .../analysis/ResolveSessionCatalog.scala | 16 +- .../spark/sql/execution/command/tables.scala 
| 2 +- .../datasources/orc/OrcFilters.scala | 13 +- .../v2/ShowTablePropertiesExec.scala | 6 +- .../sql-tests/results/ansi/date.sql.out | 25 +- .../sql-tests/results/ansi/interval.sql.out | 4 +- .../results/ansi/string-functions.sql.out | 54 +- .../resources/sql-tests/results/date.sql.out | 10 +- .../sql-tests/results/datetime-legacy.sql.out | 10 +- .../results/postgreSQL/numeric.sql.out | 38 +- .../results/postgreSQL/strings.sql.out | 20 +- .../sql-tests/results/postgreSQL/text.sql.out | 60 ++- .../StateStoreBasicOperationsBenchmark.scala | 370 +++++++++++++ .../command/ShowTblPropertiesSuiteBase.scala | 24 +- .../command/v1/ShowTblPropertiesSuite.scala | 14 +- .../datasources/orc/OrcFilterSuite.scala | 35 +- .../datasources/parquet/ParquetIOSuite.scala | 58 +- .../parquet/ParquetQuerySuite.scala | 42 +- .../StreamingSessionWindowSuite.scala | 2 +- 75 files changed, 1668 insertions(+), 1061 deletions(-) create mode 100644 examples/src/main/python/__init__.py create mode 100644 examples/src/main/python/ml/__init__,py create mode 100644 examples/src/main/python/mllib/__init__.py create mode 100644 examples/src/main/python/sql/__init__.py create mode 100644 examples/src/main/python/sql/streaming/__init__,py create mode 100644 examples/src/main/python/streaming/__init__.py create mode 100644 sql/core/benchmarks/StateStoreBasicOperationsBenchmark-results.txt create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBasicOperationsBenchmark.scala diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 41676be03e..5ebbef74b8 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -135,7 +135,7 @@ setMethod("values", #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10) #' makePairs <- lapply(rdd, function(x) { list(x, x) }) -#' collectRDD(mapValues(makePairs, function(x) { x * 2) }) +#' collectRDD(mapValues(makePairs, function(x) { x * 2 })) #' Output: list(list(1,2), list(2,4), list(3,6), ...) 
#'} #' @rdname mapValues diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBIteratorSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBIteratorSuite.java index ea814cbb41..ceab7714a2 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBIteratorSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBIteratorSuite.java @@ -41,7 +41,7 @@ public static void cleanup() throws Exception { @Override protected KVStore createStore() throws Exception { - assumeFalse(SystemUtils.IS_OS_MAC_OSX && System.getProperty("os.arch").equals("aarch64")); + assumeFalse(SystemUtils.IS_OS_MAC_OSX && SystemUtils.OS_ARCH.equals("aarch64")); dbpath = File.createTempFile("test.", ".ldb"); dbpath.delete(); db = new LevelDB(dbpath); diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java index 1134ec202c..ef92a6cbba 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java @@ -52,7 +52,7 @@ public void cleanup() throws Exception { @Before public void setup() throws Exception { - assumeFalse(SystemUtils.IS_OS_MAC_OSX && System.getProperty("os.arch").equals("aarch64")); + assumeFalse(SystemUtils.IS_OS_MAC_OSX && SystemUtils.OS_ARCH.equals("aarch64")); dbpath = File.createTempFile("test.", ".ldb"); dbpath.delete(); db = new LevelDB(dbpath); diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index a33655810d..3e0a9c3377 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -54,6 +54,9 @@ "IF_PARTITION_NOT_EXISTS_UNSUPPORTED" : { "message" : [ "Cannot write, IF NOT EXISTS is not supported for table: %s" ] }, + "ILLEGAL_SUBSTRING" : { + "message" : [ "%s cannot contain %s." ] + }, "INCOMPARABLE_PIVOT_COLUMN" : { "message" : [ "Invalid pivot column '%s'. Pivot columns must be comparable." ], "sqlState" : "42000" diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 0029bbd713..27496d687c 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1961,6 +1961,11 @@ private[spark] object Utils extends Logging { */ val isMac = SystemUtils.IS_OS_MAC_OSX + /** + * Whether the underlying operating system is Mac OS X and processor is Apple Silicon. + */ + val isMacOnAppleSilicon = SystemUtils.IS_OS_MAC_OSX && SystemUtils.OS_ARCH.equals("aarch64") + /** * Pattern for matching a Windows drive, which contains only a single alphabet character. 
*/ diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index c3d524eb88..b05b9de68d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -1652,7 +1652,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with Matchers with Logging { if (!inMemory) { // LevelDB doesn't support Apple Silicon yet - assume(!(Utils.isMac && System.getProperty("os.arch").equals("aarch64"))) + assume(!Utils.isMacOnAppleSilicon) conf.set(LOCAL_STORE_DIR, Utils.createTempDir().getAbsolutePath()) } conf.set(HYBRID_STORE_ENABLED, useHybridStore) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index a9892b18ce..0f4481ea30 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -87,7 +87,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers .set(EXECUTOR_PROCESS_TREE_METRICS_ENABLED, true) conf.setAll(extraConf) // Since LevelDB doesn't support Apple Silicon yet, fallback to in-memory provider - if (Utils.isMac && System.getProperty("os.arch").equals("aarch64")) { + if (Utils.isMacOnAppleSilicon) { conf.remove(LOCAL_STORE_DIR) } provider = new FsHistoryProvider(conf) @@ -389,7 +389,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers .set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) .remove(IS_TESTING) // Since LevelDB doesn't support Apple Silicon yet, fallback to in-memory provider - if (Utils.isMac && System.getProperty("os.arch").equals("aarch64")) { + if (Utils.isMacOnAppleSilicon) { myConf.remove(LOCAL_STORE_DIR) } val provider = new FsHistoryProvider(myConf) diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala index 0b32c7f8ba..d7fefb64ee 100644 --- a/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusStoreSuite.scala @@ -84,7 +84,7 @@ class AppStatusStoreSuite extends SparkFunSuite { return AppStatusStore.createLiveStore(conf) } // LevelDB doesn't support Apple Silicon yet - if (Utils.isMac && System.getProperty("os.arch").equals("aarch64") && disk) { + if (Utils.isMacOnAppleSilicon && disk) { return null } diff --git a/dev/lint-python b/dev/lint-python index 851edd6d34..9b7a139176 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -182,23 +182,40 @@ function mypy_data_test { fi } +function mypy_examples_test { + local MYPY_REPORT= + local MYPY_STATUS= -function mypy_test { - if ! hash "$MYPY_BUILD" 2> /dev/null; then - echo "The $MYPY_BUILD command was not found. Skipping for now." - return + echo "starting mypy examples test..." + + MYPY_REPORT=$( (MYPYPATH=python $MYPY_BUILD \ + --allow-untyped-defs \ + --config-file python/mypy.ini \ + --exclude "mllib/*" \ + examples/src/main/python/) 2>&1) + + MYPY_STATUS=$? + + if [ "$MYPY_STATUS" -ne 0 ]; then + echo "examples failed mypy checks:" + echo "$MYPY_REPORT" + echo "$MYPY_STATUS" + exit "$MYPY_STATUS" + else + echo "examples passed mypy checks." 
+ echo fi +} - _MYPY_VERSION=($($MYPY_BUILD --version)) - MYPY_VERSION="${_MYPY_VERSION[1]}" - EXPECTED_MYPY="$(satisfies_min_version $MYPY_VERSION $MINIMUM_MYPY)" - if [[ "$EXPECTED_MYPY" == "False" ]]; then - echo "The minimum mypy version needs to be $MINIMUM_MYPY. Your current version is $MYPY_VERSION. Skipping for now." +function mypy_test { + if ! hash "$MYPY_BUILD" 2> /dev/null; then + echo "The $MYPY_BUILD command was not found. Skipping for now." return fi mypy_annotation_test + mypy_examples_test mypy_data_test } diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index b93f8b7956..9ad7ad6211 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -221,7 +221,6 @@ This is a graphical depiction of the precedence list as a directed tree: The least common type from a set of types is the narrowest type reachable from the precedence list by all elements of the set of types. The least common type resolution is used to: -- Decide whether a function expecting a parameter of a type can be invoked using an argument of a narrower type. - Derive the argument type for functions which expect a shared argument type for multiple parameters, such as coalesce, least, or greatest. - Derive the operand types for operators such as arithmetic operations or comparisons. - Derive the result type for expressions such as the case expression. @@ -246,19 +245,40 @@ DOUBLE > SELECT (typeof(coalesce(1BD, 1F))); DOUBLE --- The substring function expects arguments of type INT for the start and length parameters. -> SELECT substring('hello', 1Y, 2); -he -> SELECT substring('hello', '1', 2); -he -> SELECT substring('hello', 1L, 2); -Error: Argument 2 requires an INT type. -> SELECT substring('hello', str, 2) FROM VALUES(CAST('1' AS STRING)) AS T(str); -Error: Argument 2 requires an INT type. ``` ### SQL Functions +#### Function invocation +Under ANSI mode(spark.sql.ansi.enabled=true), the function invocation of Spark SQL: +- In general, it follows the `Store assignment` rules as storing the input values as the declared parameter type of the SQL functions +- Special rules apply for string literals and untyped NULL. A NULL can be promoted to any other type, while a string literal can be promoted to any simple data type. +```sql +> SET spark.sql.ansi.enabled=true; +-- implicitly cast Int to String type +> SELECT concat('total number: ', 1); +total number: 1 +-- implicitly cast Timestamp to Date type +> select datediff(now(), current_date); +0 + +-- specialrule: implicitly cast String literal to Double type +> SELECT ceil('0.1'); +1 +-- specialrule: implicitly cast NULL to Date type +> SELECT year(null); +NULL + +> CREATE TABLE t(s string); +-- Can't store String column as Numeric types. +> SELECT ceil(s) from t; +Error in query: cannot resolve 'CEIL(spark_catalog.default.t.s)' due to data type mismatch +-- Can't store String column as Date type. +> select year(s) from t; +Error in query: cannot resolve 'year(spark_catalog.default.t.s)' due to data type mismatch +``` + +#### Functions with different behaviors The behavior of some SQL functions can be different under ANSI mode (`spark.sql.ansi.enabled=true`). - `size`: This function returns null for null input. 
- `element_at`: diff --git a/examples/src/main/python/__init__.py b/examples/src/main/python/__init__.py new file mode 100644 index 0000000000..cce3acad34 --- /dev/null +++ b/examples/src/main/python/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/examples/src/main/python/als.py b/examples/src/main/python/als.py index 511634fd8f..73af5e1a1f 100755 --- a/examples/src/main/python/als.py +++ b/examples/src/main/python/als.py @@ -77,26 +77,26 @@ def update(i, mat, ratings): (M, U, F, ITERATIONS, partitions)) R = matrix(rand(M, F)) * matrix(rand(U, F).T) - ms = matrix(rand(M, F)) - us = matrix(rand(U, F)) + ms: matrix = matrix(rand(M, F)) + us: matrix = matrix(rand(U, F)) Rb = sc.broadcast(R) msb = sc.broadcast(ms) usb = sc.broadcast(us) for i in range(ITERATIONS): - ms = sc.parallelize(range(M), partitions) \ - .map(lambda x: update(x, usb.value, Rb.value)) \ - .collect() + ms_ = sc.parallelize(range(M), partitions) \ + .map(lambda x: update(x, usb.value, Rb.value)) \ + .collect() # collect() returns a list, so array ends up being # a 3-d array, we take the first 2 dims for the matrix - ms = matrix(np.array(ms)[:, :, 0]) + ms = matrix(np.array(ms_)[:, :, 0]) msb = sc.broadcast(ms) - us = sc.parallelize(range(U), partitions) \ - .map(lambda x: update(x, msb.value, Rb.value.T)) \ - .collect() - us = matrix(np.array(us)[:, :, 0]) + us_ = sc.parallelize(range(U), partitions) \ + .map(lambda x: update(x, msb.value, Rb.value.T)) \ + .collect() + us = matrix(np.array(us_)[:, :, 0]) usb = sc.broadcast(us) error = rmse(R, ms, us) diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py index 49ab37e7b3..e303860b32 100644 --- a/examples/src/main/python/avro_inputformat.py +++ b/examples/src/main/python/avro_inputformat.py @@ -44,8 +44,10 @@ {u'favorite_color': u'red', u'name': u'Ben'} """ import sys +from typing import Any, Tuple from functools import reduce +from pyspark.rdd import RDD from pyspark.sql import SparkSession if __name__ == "__main__": @@ -75,7 +77,7 @@ schema_rdd = sc.textFile(sys.argv[2], 1).collect() conf = {"avro.schema.input.key": reduce(lambda x, y: x + y, schema_rdd)} - avro_rdd = sc.newAPIHadoopFile( + avro_rdd: RDD[Tuple[Any, None]] = sc.newAPIHadoopFile( path, "org.apache.avro.mapreduce.AvroKeyInputFormat", "org.apache.avro.mapred.AvroKey", diff --git a/examples/src/main/python/ml/__init__,py b/examples/src/main/python/ml/__init__,py new file mode 100644 index 0000000000..cce3acad34 --- /dev/null +++ b/examples/src/main/python/ml/__init__,py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/examples/src/main/python/ml/chi_square_test_example.py b/examples/src/main/python/ml/chi_square_test_example.py index bf15a03d9c..0360742faf 100644 --- a/examples/src/main/python/ml/chi_square_test_example.py +++ b/examples/src/main/python/ml/chi_square_test_example.py @@ -42,6 +42,11 @@ df = spark.createDataFrame(data, ["label", "features"]) r = ChiSquareTest.test(df, "features", "label").head() + + # $example off$ + assert r is not None + # $example on$ + print("pValues: " + str(r.pValues)) print("degreesOfFreedom: " + str(r.degreesOfFreedom)) print("statistics: " + str(r.statistics)) diff --git a/examples/src/main/python/ml/correlation_example.py b/examples/src/main/python/ml/correlation_example.py index 9006d54149..b15535a598 100644 --- a/examples/src/main/python/ml/correlation_example.py +++ b/examples/src/main/python/ml/correlation_example.py @@ -40,9 +40,17 @@ df = spark.createDataFrame(data, ["features"]) r1 = Correlation.corr(df, "features").head() + + # $example off$ + assert r1 is not None + # $example on$ print("Pearson correlation matrix:\n" + str(r1[0])) r2 = Correlation.corr(df, "features", "spearman").head() + + # $example off$ + assert r2 is not None + # $example on$ print("Spearman correlation matrix:\n" + str(r2[0])) # $example off$ diff --git a/examples/src/main/python/mllib/__init__.py b/examples/src/main/python/mllib/__init__.py new file mode 100644 index 0000000000..cce3acad34 --- /dev/null +++ b/examples/src/main/python/mllib/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py index ca8dd25e6d..380cb7bdef 100644 --- a/examples/src/main/python/parquet_inputformat.py +++ b/examples/src/main/python/parquet_inputformat.py @@ -30,7 +30,9 @@ <...more log output...> """ import sys +from typing import Any, Tuple +from pyspark.rdd import RDD from pyspark.sql import SparkSession if __name__ == "__main__": @@ -54,7 +56,7 @@ sc = spark.sparkContext - parquet_rdd = sc.newAPIHadoopFile( + parquet_rdd: RDD[Tuple[None, Any]] = sc.newAPIHadoopFile( path, 'org.apache.parquet.avro.AvroParquetInputFormat', 'java.lang.Void', diff --git a/examples/src/main/python/sort.py b/examples/src/main/python/sort.py index 9efb00a6f1..bd198817b0 100755 --- a/examples/src/main/python/sort.py +++ b/examples/src/main/python/sort.py @@ -16,7 +16,9 @@ # import sys +from typing import Tuple +from pyspark.rdd import RDD from pyspark.sql import SparkSession @@ -31,7 +33,7 @@ .getOrCreate() lines = spark.read.text(sys.argv[1]).rdd.map(lambda r: r[0]) - sortedCount = lines.flatMap(lambda x: x.split(' ')) \ + sortedCount: RDD[Tuple[int, int]] = lines.flatMap(lambda x: x.split(' ')) \ .map(lambda x: (int(x), 1)) \ .sortByKey() # This is just a demo on how to bring all the sorted data back to a single node. diff --git a/examples/src/main/python/sql/__init__.py b/examples/src/main/python/sql/__init__.py new file mode 100644 index 0000000000..cce3acad34 --- /dev/null +++ b/examples/src/main/python/sql/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/examples/src/main/python/sql/streaming/__init__,py b/examples/src/main/python/sql/streaming/__init__,py new file mode 100644 index 0000000000..cce3acad34 --- /dev/null +++ b/examples/src/main/python/sql/streaming/__init__,py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# diff --git a/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py b/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py index cc39d8afa6..4aa44955d9 100644 --- a/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py +++ b/examples/src/main/python/sql/streaming/structured_network_wordcount_windowed.py @@ -87,7 +87,7 @@ windowedCounts = words.groupBy( window(words.timestamp, windowDuration, slideDuration), words.word - ).count().orderBy('window') + ).count().orderBy('window') # type: ignore[arg-type] # Start running the query that prints the windowed word counts to the console query = windowedCounts\ diff --git a/examples/src/main/python/streaming/__init__.py b/examples/src/main/python/streaming/__init__.py new file mode 100644 index 0000000000..cce3acad34 --- /dev/null +++ b/examples/src/main/python/streaming/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/examples/src/main/python/streaming/network_wordjoinsentiments.py b/examples/src/main/python/streaming/network_wordjoinsentiments.py index 5b03546fb4..718442cf17 100644 --- a/examples/src/main/python/streaming/network_wordjoinsentiments.py +++ b/examples/src/main/python/streaming/network_wordjoinsentiments.py @@ -31,9 +31,10 @@ """ import sys +from typing import Tuple from pyspark import SparkContext -from pyspark.streaming import StreamingContext +from pyspark.streaming import DStream, StreamingContext def print_happiest_words(rdd): @@ -50,10 +51,17 @@ def print_happiest_words(rdd): sc = SparkContext(appName="PythonStreamingNetworkWordJoinSentiments") ssc = StreamingContext(sc, 5) + def line_to_tuple(line: str) -> Tuple[str, str]: + try: + k, v = line.split(" ") + return k, v + except ValueError: + return "", "" + # Read in the word-sentiment list and create a static RDD from it word_sentiments_file_path = "data/streaming/AFINN-111.txt" word_sentiments = ssc.sparkContext.textFile(word_sentiments_file_path) \ - .map(lambda line: tuple(line.split("\t"))) + .map(line_to_tuple) lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2])) @@ -64,7 +72,8 @@ def print_happiest_words(rdd): # Determine the words with the highest sentiment values by joining the streaming RDD # with the static RDD inside the transform() method and then multiplying # the frequency of the words by its sentiment value - happiest_words = word_counts.transform(lambda rdd: word_sentiments.join(rdd)) \ + happiest_words: DStream[Tuple[float, str]] = word_counts \ + .transform(lambda rdd: word_sentiments.join(rdd)) \ .map(lambda word_tuples: (word_tuples[0], float(word_tuples[1][0]) * word_tuples[1][1])) \ .map(lambda word_happiness: (word_happiness[1], word_happiness[0])) \ .transform(lambda rdd: rdd.sortByKey(False)) 
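The `network_wordjoinsentiments.py` hunk above replaces `tuple(line.split("\t"))` with an explicit `line_to_tuple` helper so mypy sees a fixed-arity pair instead of the variadic `Tuple[str, ...]` that `tuple()` produces. A minimal, Spark-free sketch of the same idea; the `parse_sentiment_line` name and the whitespace split are illustrative and not part of the patch (the patch itself splits on a single space):

```python
from typing import Tuple


def parse_sentiment_line(line: str) -> Tuple[str, str]:
    # A declared two-element return type lets mypy treat the result as a pair;
    # tuple(line.split("\t")) would be inferred as the variadic Tuple[str, ...].
    try:
        word, score = line.split(None, 1)  # assumption: two whitespace-separated fields
        return word, score
    except ValueError:
        return "", ""


print(parse_sentiment_line("abandon\t-2"))  # ('abandon', '-2')
```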
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 68878ecd30..f592ecd522 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -643,6 +643,7 @@ object KubernetesIntegrationTests { "-t", imageTag.value, "-p", s"$bindingsDir/python/Dockerfile", "-R", s"$bindingsDir/R/Dockerfile", + "-b", s"java_image_tag=${sys.env.getOrElse("JAVA_IMAGE_TAG", "8-jre-slim")}", "build" ) val ec = Process(cmd).! diff --git a/python/docs/source/conf.py b/python/docs/source/conf.py index 5cf164d36a..e1bc400639 100644 --- a/python/docs/source/conf.py +++ b/python/docs/source/conf.py @@ -398,7 +398,7 @@ #epub_use_index = True def setup(app): # The app.add_javascript() is deprecated. - getattr(app, "add_js_file", getattr(app, "add_javascript"))('copybutton.js') + getattr(app, "add_js_file", getattr(app, "add_javascript", None))('copybutton.js') # Skip sample endpoint link (not expected to resolve) linkcheck_ignore = [r'https://kinesis.us-east-1.amazonaws.com'] diff --git a/python/docs/source/getting_started/quickstart_ps.ipynb b/python/docs/source/getting_started/quickstart_ps.ipynb index 74d6724ef2..87796aec76 100644 --- a/python/docs/source/getting_started/quickstart_ps.ipynb +++ b/python/docs/source/getting_started/quickstart_ps.ipynb @@ -539,7 +539,7 @@ "metadata": {}, "outputs": [], "source": [ - "psdf = sdf.to_pandas_on_spark()" + "psdf = sdf.pandas_api()" ] }, { @@ -14486,4 +14486,4 @@ }, "nbformat": 4, "nbformat_minor": 1 -} +} \ No newline at end of file diff --git a/python/docs/source/migration_guide/koalas_to_pyspark.rst b/python/docs/source/migration_guide/koalas_to_pyspark.rst index 9102d1dbf1..24e2d95915 100644 --- a/python/docs/source/migration_guide/koalas_to_pyspark.rst +++ b/python/docs/source/migration_guide/koalas_to_pyspark.rst @@ -30,7 +30,10 @@ Migrating from Koalas to pandas API on Spark * ``DataFrame.koalas`` in Koalas DataFrame was renamed to ``DataFrame.pandas_on_spark`` in pandas-on-Spark DataFrame. ``DataFrame.koalas`` was kept for compatibility reason but deprecated as of Spark 3.2. ``DataFrame.koalas`` will be removed in the future releases. -* Monkey-patched ``DataFrame.to_koalas`` in PySpark DataFrame was renamed to ``DataFrame.to_pandas_on_spark`` in PySpark DataFrame. ``DataFrame.to_koalas`` was kept for compatibility reason but deprecated as of Spark 3.2. +* Monkey-patched ``DataFrame.to_koalas`` in PySpark DataFrame was renamed to ``DataFrame.pandas_api`` in PySpark DataFrame. ``DataFrame.to_koalas`` was kept for compatibility reason. ``DataFrame.to_koalas`` will be removed in the future releases. +* Monkey-patched ``DataFrame.to_pandas_on_spark`` in PySpark DataFrame was renamed to ``DataFrame.pandas_api`` in PySpark DataFrame. ``DataFrame.to_pandas_on_spark`` was kept for compatibility reason but deprecated as of Spark 3.3. + ``DataFrame.to_pandas_on_spark`` will be removed in the future releases. + * ``databricks.koalas.__version__`` was removed. ``pyspark.__version__`` should be used instead. 
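The renames described in the migration guide above reduce to a single call-site change for PySpark users. A minimal sketch of old versus new usage, assuming a local `SparkSession` and illustrative data:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
sdf = spark.createDataFrame([("a", 1), ("b", 2)], ["Col1", "Col2"])

# New name introduced by this patch:
psdf = sdf.pandas_api(index_col="Col1")

# Still accepted for compatibility, but emits a FutureWarning and will be removed later:
# psdf = sdf.to_pandas_on_spark(index_col="Col1")
```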
diff --git a/python/docs/source/reference/pyspark.sql.rst b/python/docs/source/reference/pyspark.sql.rst index 5b77da53d8..818814ca0a 100644 --- a/python/docs/source/reference/pyspark.sql.rst +++ b/python/docs/source/reference/pyspark.sql.rst @@ -223,7 +223,7 @@ DataFrame APIs DataFrame.write DataFrame.writeStream DataFrame.writeTo - DataFrame.to_pandas_on_spark + DataFrame.pandas_api DataFrameNaFunctions.drop DataFrameNaFunctions.fill DataFrameNaFunctions.replace diff --git a/python/docs/source/user_guide/pandas_on_spark/pandas_pyspark.rst b/python/docs/source/user_guide/pandas_on_spark/pandas_pyspark.rst index f4fc0dac0a..04d6617d38 100644 --- a/python/docs/source/user_guide/pandas_on_spark/pandas_pyspark.rst +++ b/python/docs/source/user_guide/pandas_on_spark/pandas_pyspark.rst @@ -107,7 +107,7 @@ Spark DataFrame can be a pandas-on-Spark DataFrame easily as below: .. code-block:: python - >>> sdf.to_pandas_on_spark() + >>> sdf.pandas_api() id 0 6 1 7 @@ -127,7 +127,7 @@ to use as an index when possible. >>> # Call Spark APIs ... sdf = sdf.filter("id > 5") >>> # Uses the explicit index to avoid to create default index. - ... sdf.to_pandas_on_spark(index_col='index') + ... sdf.pandas_api(index_col='index') id index 6 6 diff --git a/python/docs/source/user_guide/pandas_on_spark/types.rst b/python/docs/source/user_guide/pandas_on_spark/types.rst index 831967aad5..8e04efcd7f 100644 --- a/python/docs/source/user_guide/pandas_on_spark/types.rst +++ b/python/docs/source/user_guide/pandas_on_spark/types.rst @@ -44,7 +44,7 @@ The example below shows how data types are casted from PySpark DataFrame to pand DataFrame[tinyint: tinyint, decimal: decimal(10,0), float: float, double: double, integer: int, long: bigint, short: smallint, timestamp: timestamp, string: string, boolean: boolean, date: date] # 3. Convert PySpark DataFrame to pandas-on-Spark DataFrame - >>> psdf = sdf.to_pandas_on_spark() + >>> psdf = sdf.pandas_api() # 4. Check the pandas-on-Spark data types >>> psdf.dtypes diff --git a/python/pyspark/_typing.pyi b/python/pyspark/_typing.pyi index 637e4cb4fb..9a36c8945b 100644 --- a/python/pyspark/_typing.pyi +++ b/python/pyspark/_typing.pyi @@ -20,7 +20,7 @@ from typing import Callable, Iterable, Sized, TypeVar, Union from typing_extensions import Protocol F = TypeVar("F", bound=Callable) -T = TypeVar("T", covariant=True) +T_co = TypeVar("T_co", covariant=True) PrimitiveType = Union[bool, float, int, str] @@ -30,4 +30,4 @@ class SupportsIAdd(Protocol): class SupportsOrdering(Protocol): def __le__(self, other: SupportsOrdering) -> bool: ... -class SizedIterable(Protocol, Sized, Iterable[T]): ... +class SizedIterable(Protocol, Sized, Iterable[T_co]): ... 
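Switching these stubs from an invariant `T` to a covariant `T_co` matters because read-only containers such as `SizedIterable`, `RDD`, and `DStream` should be usable with a narrower element type wherever a wider one is expected. A self-contained sketch of the difference; the `ReadOnlyBox`/`Animal` names are purely illustrative:

```python
from typing import Generic, TypeVar


class Animal: ...
class Dog(Animal): ...


T_co = TypeVar("T_co", covariant=True)


class ReadOnlyBox(Generic[T_co]):
    # Only produces T_co and never consumes it, so covariance is sound here.
    def __init__(self, item: T_co) -> None:
        self._item = item

    def get(self) -> T_co:
        return self._item


def describe(box: ReadOnlyBox[Animal]) -> None:
    print(type(box.get()).__name__)


# Accepted because ReadOnlyBox[Dog] is a subtype of ReadOnlyBox[Animal];
# with an invariant TypeVar, mypy would reject this call.
describe(ReadOnlyBox(Dog()))
```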
diff --git a/python/pyspark/pandas/spark/accessors.py b/python/pyspark/pandas/spark/accessors.py index 0e91f4e568..e0d463980b 100644 --- a/python/pyspark/pandas/spark/accessors.py +++ b/python/pyspark/pandas/spark/accessors.py @@ -396,7 +396,7 @@ def frame(self, index_col: Optional[Union[str, List[str]]] = None) -> SparkDataF See Also -------- DataFrame.to_spark - DataFrame.to_pandas_on_spark + DataFrame.pandas_api DataFrame.spark.frame Examples @@ -440,7 +440,7 @@ def frame(self, index_col: Optional[Union[str, List[str]]] = None) -> SparkDataF >>> spark_df = df.to_spark(index_col="index") >>> spark_df = spark_df.filter("a == 2") - >>> spark_df.to_pandas_on_spark(index_col="index") # doctest: +NORMALIZE_WHITESPACE + >>> spark_df.pandas_api(index_col="index") # doctest: +NORMALIZE_WHITESPACE a b c index 1 2 5 8 @@ -460,7 +460,7 @@ def frame(self, index_col: Optional[Union[str, List[str]]] = None) -> SparkDataF Likewise, can be converted to back to pandas-on-Spark DataFrame. - >>> new_spark_df.to_pandas_on_spark( + >>> new_spark_df.pandas_api( ... index_col=["index_1", "index_2"]) # doctest: +NORMALIZE_WHITESPACE b c index_1 index_2 @@ -893,7 +893,7 @@ def apply( expensive in general. .. note:: it will lose column labels. This is a synonym of - ``func(psdf.to_spark(index_col)).to_pandas_on_spark(index_col)``. + ``func(psdf.to_spark(index_col)).pandas_api(index_col)``. Parameters ---------- @@ -941,7 +941,7 @@ def apply( "The output of the function [%s] should be of a " "pyspark.sql.DataFrame; however, got [%s]." % (func, type(output)) ) - return output.to_pandas_on_spark(index_col) + return output.pandas_api(index_col) def repartition(self, num_partitions: int) -> "ps.DataFrame": """ diff --git a/python/pyspark/pandas/tests/data_type_ops/testing_utils.py b/python/pyspark/pandas/tests/data_type_ops/testing_utils.py index 01bd494757..37eff6dc2a 100644 --- a/python/pyspark/pandas/tests/data_type_ops/testing_utils.py +++ b/python/pyspark/pandas/tests/data_type_ops/testing_utils.py @@ -49,14 +49,21 @@ def numeric_pdf(self): dtypes = [np.int32, int, np.float32, float] sers = [pd.Series([1, 2, 3], dtype=dtype) for dtype in dtypes] sers.append(pd.Series([decimal.Decimal(1), decimal.Decimal(2), decimal.Decimal(3)])) - sers.append(pd.Series([decimal.Decimal(1), decimal.Decimal(2), decimal.Decimal(np.nan)])) sers.append(pd.Series([1, 2, np.nan], dtype=float)) + # Skip decimal_nan test before v1.3.0, it not supported by pandas on spark yet. + if LooseVersion(pd.__version__) >= LooseVersion("1.3.0"): + sers.append( + pd.Series([decimal.Decimal(1), decimal.Decimal(2), decimal.Decimal(np.nan)]) + ) pdf = pd.concat(sers, axis=1) - pdf.columns = [dtype.__name__ for dtype in dtypes] + [ - "decimal", - "decimal_nan", - "float_nan", - ] + if LooseVersion(pd.__version__) >= LooseVersion("1.3.0"): + pdf.columns = [dtype.__name__ for dtype in dtypes] + [ + "decimal", + "float_nan", + "decimal_nan", + ] + else: + pdf.columns = [dtype.__name__ for dtype in dtypes] + ["decimal", "float_nan"] return pdf @property diff --git a/python/pyspark/rdd.pyi b/python/pyspark/rdd.pyi index 37ba4c2d2f..c4eddbf150 100644 --- a/python/pyspark/rdd.pyi +++ b/python/pyspark/rdd.pyi @@ -61,6 +61,7 @@ from pyspark.sql._typing import AtomicValue, RowLike from py4j.java_gateway import JavaObject # type: ignore[import] T = TypeVar("T") +T_co = TypeVar("T_co", covariant=True) U = TypeVar("U") K = TypeVar("K", bound=Hashable) V = TypeVar("V") @@ -96,7 +97,7 @@ class Partitioner: def __eq__(self, other: Any) -> bool: ... 
def __call__(self, k: Any) -> int: ... -class RDD(Generic[T]): +class RDD(Generic[T_co]): is_cached: bool is_checkpointed: bool ctx: pyspark.context.SparkContext @@ -111,44 +112,46 @@ class RDD(Generic[T]): def __getnewargs__(self) -> Any: ... @property def context(self) -> pyspark.context.SparkContext: ... - def cache(self) -> RDD[T]: ... - def persist(self, storageLevel: StorageLevel = ...) -> RDD[T]: ... - def unpersist(self, blocking: bool = ...) -> RDD[T]: ... + def cache(self) -> RDD[T_co]: ... + def persist(self, storageLevel: StorageLevel = ...) -> RDD[T_co]: ... + def unpersist(self, blocking: bool = ...) -> RDD[T_co]: ... def checkpoint(self) -> None: ... def isCheckpointed(self) -> bool: ... def localCheckpoint(self) -> None: ... def isLocallyCheckpointed(self) -> bool: ... def getCheckpointFile(self) -> Optional[str]: ... - def map(self, f: Callable[[T], U], preservesPartitioning: bool = ...) -> RDD[U]: ... + def map(self, f: Callable[[T_co], U], preservesPartitioning: bool = ...) -> RDD[U]: ... def flatMap( - self, f: Callable[[T], Iterable[U]], preservesPartitioning: bool = ... + self, f: Callable[[T_co], Iterable[U]], preservesPartitioning: bool = ... ) -> RDD[U]: ... def mapPartitions( - self, f: Callable[[Iterable[T]], Iterable[U]], preservesPartitioning: bool = ... + self, f: Callable[[Iterable[T_co]], Iterable[U]], preservesPartitioning: bool = ... ) -> RDD[U]: ... def mapPartitionsWithIndex( self, - f: Callable[[int, Iterable[T]], Iterable[U]], + f: Callable[[int, Iterable[T_co]], Iterable[U]], preservesPartitioning: bool = ..., ) -> RDD[U]: ... def mapPartitionsWithSplit( self, - f: Callable[[int, Iterable[T]], Iterable[U]], + f: Callable[[int, Iterable[T_co]], Iterable[U]], preservesPartitioning: bool = ..., ) -> RDD[U]: ... def getNumPartitions(self) -> int: ... - def filter(self, f: Callable[[T], bool]) -> RDD[T]: ... - def distinct(self, numPartitions: Optional[int] = ...) -> RDD[T]: ... + def filter(self, f: Callable[[T_co], bool]) -> RDD[T_co]: ... + def distinct(self, numPartitions: Optional[int] = ...) -> RDD[T_co]: ... def sample( self, withReplacement: bool, fraction: float, seed: Optional[int] = ... - ) -> RDD[T]: ... + ) -> RDD[T_co]: ... def randomSplit( self, weights: List[Union[int, float]], seed: Optional[int] = ... - ) -> List[RDD[T]]: ... - def takeSample(self, withReplacement: bool, num: int, seed: Optional[int] = ...) -> List[T]: ... - def union(self, other: RDD[U]) -> RDD[Union[T, U]]: ... - def intersection(self, other: RDD[T]) -> RDD[T]: ... - def __add__(self, other: RDD[T]) -> RDD[T]: ... + ) -> List[RDD[T_co]]: ... + def takeSample( + self, withReplacement: bool, num: int, seed: Optional[int] = ... + ) -> List[T_co]: ... + def union(self, other: RDD[U]) -> RDD[Union[T_co, U]]: ... + def intersection(self, other: RDD[T_co]) -> RDD[T_co]: ... + def __add__(self, other: RDD[T_co]) -> RDD[T_co]: ... @overload def repartitionAndSortWithinPartitions( self: RDD[Tuple[O, V]], @@ -195,55 +198,55 @@ class RDD(Generic[T]): keyfunc: Callable[[K], O], ) -> RDD[Tuple[K, V]]: ... def sortBy( - self: RDD[T], - keyfunc: Callable[[T], O], + self, + keyfunc: Callable[[T_co], O], ascending: bool = ..., numPartitions: Optional[int] = ..., - ) -> RDD[T]: ... - def glom(self) -> RDD[List[T]]: ... - def cartesian(self, other: RDD[U]) -> RDD[Tuple[T, U]]: ... + ) -> RDD[T_co]: ... + def glom(self) -> RDD[List[T_co]]: ... + def cartesian(self, other: RDD[U]) -> RDD[Tuple[T_co, U]]: ... 
def groupBy( self, - f: Callable[[T], K], + f: Callable[[T_co], K], numPartitions: Optional[int] = ..., partitionFunc: Callable[[K], int] = ..., - ) -> RDD[Tuple[K, Iterable[T]]]: ... + ) -> RDD[Tuple[K, Iterable[T_co]]]: ... def pipe( self, command: str, env: Optional[Dict[str, str]] = ..., checkCode: bool = ... ) -> RDD[str]: ... - def foreach(self, f: Callable[[T], None]) -> None: ... - def foreachPartition(self, f: Callable[[Iterable[T]], None]) -> None: ... - def collect(self) -> List[T]: ... + def foreach(self, f: Callable[[T_co], None]) -> None: ... + def foreachPartition(self, f: Callable[[Iterable[T_co]], None]) -> None: ... + def collect(self) -> List[T_co]: ... def collectWithJobGroup( self, groupId: str, description: str, interruptOnCancel: bool = ... - ) -> List[T]: ... - def reduce(self, f: Callable[[T, T], T]) -> T: ... - def treeReduce(self, f: Callable[[T, T], T], depth: int = ...) -> T: ... - def fold(self, zeroValue: T, op: Callable[[T, T], T]) -> T: ... + ) -> List[T_co]: ... + def reduce(self, f: Callable[[T_co, T_co], T_co]) -> T_co: ... + def treeReduce(self, f: Callable[[T_co, T_co], T_co], depth: int = ...) -> T_co: ... + def fold(self, zeroValue: T, op: Callable[[T_co, T_co], T_co]) -> T_co: ... def aggregate( - self, zeroValue: U, seqOp: Callable[[U, T], U], combOp: Callable[[U, U], U] + self, zeroValue: U, seqOp: Callable[[U, T_co], U], combOp: Callable[[U, U], U] ) -> U: ... def treeAggregate( self, zeroValue: U, - seqOp: Callable[[U, T], U], + seqOp: Callable[[U, T_co], U], combOp: Callable[[U, U], U], depth: int = ..., ) -> U: ... @overload def max(self: RDD[O]) -> O: ... @overload - def max(self, key: Callable[[T], O]) -> T: ... + def max(self, key: Callable[[T_co], O]) -> T_co: ... @overload def min(self: RDD[O]) -> O: ... @overload - def min(self, key: Callable[[T], O]) -> T: ... + def min(self, key: Callable[[T_co], O]) -> T_co: ... def sum(self: RDD[NumberOrArray]) -> NumberOrArray: ... def count(self) -> int: ... def stats(self: RDD[NumberOrArray]) -> StatCounter: ... def histogram( - self, buckets: Union[int, List[T], Tuple[T, ...]] - ) -> Tuple[List[T], List[int]]: ... + self, buckets: Union[int, List[T_co], Tuple[T_co, ...]] + ) -> Tuple[List[T_co], List[int]]: ... def mean(self: RDD[NumberOrArray]) -> NumberOrArray: ... def variance(self: RDD[NumberOrArray]) -> NumberOrArray: ... def stdev(self: RDD[NumberOrArray]) -> NumberOrArray: ... @@ -253,13 +256,13 @@ class RDD(Generic[T]): @overload def top(self: RDD[O], num: int) -> List[O]: ... @overload - def top(self: RDD[T], num: int, key: Callable[[T], O]) -> List[T]: ... + def top(self, num: int, key: Callable[[T_co], O]) -> List[T_co]: ... @overload def takeOrdered(self: RDD[O], num: int) -> List[O]: ... @overload - def takeOrdered(self: RDD[T], num: int, key: Callable[[T], O]) -> List[T]: ... - def take(self, num: int) -> List[T]: ... - def first(self) -> T: ... + def takeOrdered(self, num: int, key: Callable[[T_co], O]) -> List[T_co]: ... + def take(self, num: int) -> List[T_co]: ... + def first(self) -> T_co: ... def isEmpty(self) -> bool: ... def saveAsNewAPIHadoopDataset( self: RDD[Tuple[K, V]], @@ -408,15 +411,15 @@ class RDD(Generic[T]): other: RDD[Tuple[K, U]], numPartitions: Optional[int] = ..., ) -> RDD[Tuple[K, V]]: ... - def subtract(self: RDD[T], other: RDD[T], numPartitions: Optional[int] = ...) -> RDD[T]: ... - def keyBy(self: RDD[T], f: Callable[[T], K]) -> RDD[Tuple[K, T]]: ... - def repartition(self, numPartitions: int) -> RDD[T]: ... 
- def coalesce(self, numPartitions: int, shuffle: bool = ...) -> RDD[T]: ... - def zip(self, other: RDD[U]) -> RDD[Tuple[T, U]]: ... - def zipWithIndex(self) -> RDD[Tuple[T, int]]: ... - def zipWithUniqueId(self) -> RDD[Tuple[T, int]]: ... + def subtract(self, other: RDD[T_co], numPartitions: Optional[int] = ...) -> RDD[T_co]: ... + def keyBy(self, f: Callable[[T_co], K]) -> RDD[Tuple[K, T_co]]: ... + def repartition(self, numPartitions: int) -> RDD[T_co]: ... + def coalesce(self, numPartitions: int, shuffle: bool = ...) -> RDD[T_co]: ... + def zip(self, other: RDD[U]) -> RDD[Tuple[T_co, U]]: ... + def zipWithIndex(self) -> RDD[Tuple[T_co, int]]: ... + def zipWithUniqueId(self) -> RDD[Tuple[T_co, int]]: ... def name(self) -> str: ... - def setName(self, name: str) -> RDD[T]: ... + def setName(self, name: str) -> RDD[T_co]: ... def toDebugString(self) -> bytes: ... def getStorageLevel(self) -> StorageLevel: ... def lookup(self: RDD[Tuple[K, V]], key: K) -> List[V]: ... @@ -428,9 +431,9 @@ class RDD(Generic[T]): self: RDD[Union[float, int]], timeout: int, confidence: float = ... ) -> BoundedFloat: ... def countApproxDistinct(self, relativeSD: float = ...) -> int: ... - def toLocalIterator(self, prefetchPartitions: bool = ...) -> Iterator[T]: ... - def barrier(self: RDD[T]) -> RDDBarrier[T]: ... - def withResources(self: RDD[T], profile: ResourceProfile) -> RDD[T]: ... + def toLocalIterator(self, prefetchPartitions: bool = ...) -> Iterator[T_co]: ... + def barrier(self) -> RDDBarrier[T_co]: ... + def withResources(self, profile: ResourceProfile) -> RDD[T_co]: ... def getResourceProfile(self) -> Optional[ResourceProfile]: ... @overload def toDF( diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ac1cbf9099..337cad534f 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -3198,8 +3198,18 @@ def writeTo(self, table: str) -> DataFrameWriterV2: """ return DataFrameWriterV2(self, table) + # Keep to_pandas_on_spark for backward compatibility for now. def to_pandas_on_spark( self, index_col: Optional[Union[str, List[str]]] = None + ) -> "PandasOnSparkDataFrame": + warnings.warn( + "DataFrame.to_pandas_on_spark is deprecated. Use DataFrame.pandas_api instead.", + FutureWarning, + ) + return self.pandas_api(index_col) + + def pandas_api( + self, index_col: Optional[Union[str, List[str]]] = None ) -> "PandasOnSparkDataFrame": """ Converts the existing DataFrame into a pandas-on-Spark DataFrame. @@ -3230,7 +3240,7 @@ def to_pandas_on_spark( | c| 3| +----+----+ - >>> df.to_pandas_on_spark() # doctest: +SKIP + >>> df.pandas_api() # doctest: +SKIP Col1 Col2 0 a 1 1 b 2 @@ -3238,7 +3248,7 @@ def to_pandas_on_spark( We can specify the index columns. - >>> df.to_pandas_on_spark(index_col="Col1"): # doctest: +SKIP + >>> df.pandas_api(index_col="Col1"): # doctest: +SKIP Col2 Col1 a 1 @@ -3261,11 +3271,7 @@ def to_pandas_on_spark( def to_koalas( self, index_col: Optional[Union[str, List[str]]] = None ) -> "PandasOnSparkDataFrame": - warnings.warn( - "DataFrame.to_koalas is deprecated. 
Use DataFrame.to_pandas_on_spark instead.", - FutureWarning, - ) - return self.to_pandas_on_spark(index_col) + return self.pandas_api(index_col) def _to_scala_map(sc: SparkContext, jm: Dict) -> JavaObject: diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py index 6f55db290a..045115f755 100644 --- a/python/pyspark/sql/pandas/conversion.py +++ b/python/pyspark/sql/pandas/conversion.py @@ -173,7 +173,20 @@ def toPandas(self) -> "PandasDataFrameLike": pdf[field.name] = _convert_map_items_to_dict(pdf[field.name]) return pdf else: - return pd.DataFrame.from_records([], columns=self.columns) + corrected_panda_types = {} + for index, field in enumerate(self.schema): + panda_type = PandasConversionMixin._to_corrected_pandas_type( + field.dataType + ) + corrected_panda_types[tmp_column_names[index]] = ( + np.object0 if panda_type is None else panda_type + ) + + pdf = pd.DataFrame(columns=tmp_column_names).astype( + dtype=corrected_panda_types + ) + pdf.columns = self.columns + return pdf except Exception as e: # We might have to allow fallback here as well but multiple Spark jobs can # be executed. So, simply fail in this case for now. diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py index 0c690257c9..99705fbb72 100644 --- a/python/pyspark/sql/tests/test_arrow.py +++ b/python/pyspark/sql/tests/test_arrow.py @@ -42,6 +42,7 @@ StructField, ArrayType, NullType, + DayTimeIntervalType, ) from pyspark.testing.sqlutils import ( ReusedSQLTestCase, @@ -205,6 +206,53 @@ def test_toPandas_fallback_disabled(self): with self.assertRaisesRegex(Exception, "Unsupported type"): df.toPandas() + def test_toPandas_empty_df_arrow_enabled(self): + # SPARK-30537 test that toPandas() on an empty dataframe has the correct dtypes + # when arrow is enabled + from datetime import date + from decimal import Decimal + + schema = StructType( + [ + StructField("a", StringType(), True), + StructField("a", IntegerType(), True), + StructField("c", TimestampType(), True), + StructField("d", NullType(), True), + StructField("e", LongType(), True), + StructField("f", FloatType(), True), + StructField("g", DateType(), True), + StructField("h", BinaryType(), True), + StructField("i", DecimalType(38, 18), True), + StructField("k", TimestampNTZType(), True), + StructField("L", DayTimeIntervalType(0, 3), True), + ] + ) + df = self.spark.createDataFrame(self.spark.sparkContext.emptyRDD(), schema=schema) + non_empty_df = self.spark.createDataFrame( + [ + ( + "a", + 1, + datetime.datetime(1969, 1, 1, 1, 1, 1), + None, + 10, + 0.2, + date(1969, 1, 1), + bytearray(b"a"), + Decimal("2.0"), + datetime.datetime(1969, 1, 1, 1, 1, 1), + datetime.timedelta(microseconds=123), + ) + ], + schema=schema, + ) + + pdf, pdf_arrow = self._toPandas_arrow_toggle(df) + pdf_non_empty, pdf_arrow_non_empty = self._toPandas_arrow_toggle(non_empty_df) + assert_frame_equal(pdf, pdf_arrow) + self.assertTrue(pdf_arrow.dtypes.equals(pdf_arrow_non_empty.dtypes)) + self.assertTrue(pdf_arrow.dtypes.equals(pdf_non_empty.dtypes)) + def test_null_conversion(self): df_null = self.spark.createDataFrame( [tuple([None for _ in range(len(self.data_wo_null[0]))])] + self.data_wo_null diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 75301edc8d..1367fe79f0 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -800,11 +800,12 @@ def test_to_pandas_avoid_astype(self): @unittest.skipIf(not 
have_pandas, pandas_requirement_message) # type: ignore def test_to_pandas_from_empty_dataframe(self): - with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}): - # SPARK-29188 test that toPandas() on an empty dataframe has the correct dtypes - import numpy as np + # SPARK-29188 test that toPandas() on an empty dataframe has the correct dtypes + # SPARK-30537 test that toPandas() on an empty dataframe has the correct dtypes + # when arrow is enabled + import numpy as np - sql = """ + sql = """ SELECT CAST(1 AS TINYINT) AS tinyint, CAST(1 AS SMALLINT) AS smallint, CAST(1 AS INT) AS int, @@ -817,17 +818,21 @@ def test_to_pandas_from_empty_dataframe(self): CAST('2019-01-01' AS TIMESTAMP_NTZ) AS timestamp_ntz, INTERVAL '1563:04' MINUTE TO SECOND AS day_time_interval """ - dtypes_when_nonempty_df = self.spark.sql(sql).toPandas().dtypes - dtypes_when_empty_df = self.spark.sql(sql).filter("False").toPandas().dtypes - self.assertTrue(np.all(dtypes_when_empty_df == dtypes_when_nonempty_df)) + is_arrow_enabled = [True, False] + for value in is_arrow_enabled: + with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}): + dtypes_when_nonempty_df = self.spark.sql(sql).toPandas().dtypes + dtypes_when_empty_df = self.spark.sql(sql).filter("False").toPandas().dtypes + self.assertTrue(np.all(dtypes_when_empty_df == dtypes_when_nonempty_df)) @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore def test_to_pandas_from_null_dataframe(self): - with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}): - # SPARK-29188 test that toPandas() on a dataframe with only nulls has correct dtypes - import numpy as np + # SPARK-29188 test that toPandas() on a dataframe with only nulls has correct dtypes + # SPARK-30537 test that toPandas() on a dataframe with only nulls has correct dtypes + # using arrow + import numpy as np - sql = """ + sql = """ SELECT CAST(NULL AS TINYINT) AS tinyint, CAST(NULL AS SMALLINT) AS smallint, CAST(NULL AS INT) AS int, @@ -840,44 +845,51 @@ def test_to_pandas_from_null_dataframe(self): CAST(NULL AS TIMESTAMP_NTZ) AS timestamp_ntz, INTERVAL '1563:04' MINUTE TO SECOND AS day_time_interval """ - pdf = self.spark.sql(sql).toPandas() - types = pdf.dtypes - self.assertEqual(types[0], np.float64) - self.assertEqual(types[1], np.float64) - self.assertEqual(types[2], np.float64) - self.assertEqual(types[3], np.float64) - self.assertEqual(types[4], np.float32) - self.assertEqual(types[5], np.float64) - self.assertEqual(types[6], np.object) - self.assertEqual(types[7], np.object) - self.assertTrue(np.can_cast(np.datetime64, types[8])) - self.assertTrue(np.can_cast(np.datetime64, types[9])) - self.assertTrue(np.can_cast(np.timedelta64, types[10])) + is_arrow_enabled = [True, False] + for value in is_arrow_enabled: + with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}): + pdf = self.spark.sql(sql).toPandas() + types = pdf.dtypes + self.assertEqual(types[0], np.float64) + self.assertEqual(types[1], np.float64) + self.assertEqual(types[2], np.float64) + self.assertEqual(types[3], np.float64) + self.assertEqual(types[4], np.float32) + self.assertEqual(types[5], np.float64) + self.assertEqual(types[6], np.object) + self.assertEqual(types[7], np.object) + self.assertTrue(np.can_cast(np.datetime64, types[8])) + self.assertTrue(np.can_cast(np.datetime64, types[9])) + self.assertTrue(np.can_cast(np.timedelta64, types[10])) @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore def 
test_to_pandas_from_mixed_dataframe(self): - with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}): - # SPARK-29188 test that toPandas() on a dataframe with some nulls has correct dtypes - import numpy as np - - sql = """ - SELECT CAST(col1 AS TINYINT) AS tinyint, - CAST(col2 AS SMALLINT) AS smallint, - CAST(col3 AS INT) AS int, - CAST(col4 AS BIGINT) AS bigint, - CAST(col5 AS FLOAT) AS float, - CAST(col6 AS DOUBLE) AS double, - CAST(col7 AS BOOLEAN) AS boolean, - CAST(col8 AS STRING) AS string, - timestamp_seconds(col9) AS timestamp, - timestamp_seconds(col10) AS timestamp_ntz, - INTERVAL '1563:04' MINUTE TO SECOND AS day_time_interval - FROM VALUES (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1), - (NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL) - """ - pdf_with_some_nulls = self.spark.sql(sql).toPandas() - pdf_with_only_nulls = self.spark.sql(sql).filter("tinyint is null").toPandas() - self.assertTrue(np.all(pdf_with_only_nulls.dtypes == pdf_with_some_nulls.dtypes)) + # SPARK-29188 test that toPandas() on a dataframe with some nulls has correct dtypes + # SPARK-30537 test that toPandas() on a dataframe with some nulls has correct dtypes + # using arrow + import numpy as np + + sql = """ + SELECT CAST(col1 AS TINYINT) AS tinyint, + CAST(col2 AS SMALLINT) AS smallint, + CAST(col3 AS INT) AS int, + CAST(col4 AS BIGINT) AS bigint, + CAST(col5 AS FLOAT) AS float, + CAST(col6 AS DOUBLE) AS double, + CAST(col7 AS BOOLEAN) AS boolean, + CAST(col8 AS STRING) AS string, + timestamp_seconds(col9) AS timestamp, + timestamp_seconds(col10) AS timestamp_ntz, + INTERVAL '1563:04' MINUTE TO SECOND AS day_time_interval + FROM VALUES (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1), + (NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL) + """ + is_arrow_enabled = [True, False] + for value in is_arrow_enabled: + with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}): + pdf_with_some_nulls = self.spark.sql(sql).toPandas() + pdf_with_only_nulls = self.spark.sql(sql).filter("tinyint is null").toPandas() + self.assertTrue(np.all(pdf_with_only_nulls.dtypes == pdf_with_some_nulls.dtypes)) def test_create_dataframe_from_array_of_long(self): import array @@ -1106,13 +1118,13 @@ def test_df_show(self): not have_pandas or not have_pyarrow, cast(str, pandas_requirement_message or pyarrow_requirement_message), ) - def test_to_pandas_on_spark(self): + def test_pandas_api(self): import pandas as pd from pandas.testing import assert_frame_equal sdf = self.spark.createDataFrame([("a", 1), ("b", 2), ("c", 3)], ["Col1", "Col2"]) - psdf_from_sdf = sdf.to_pandas_on_spark() - psdf_from_sdf_with_index = sdf.to_pandas_on_spark(index_col="Col1") + psdf_from_sdf = sdf.pandas_api() + psdf_from_sdf_with_index = sdf.pandas_api(index_col="Col1") pdf = pd.DataFrame({"Col1": ["a", "b", "c"], "Col2": [1, 2, 3]}) pdf_with_index = pdf.set_index("Col1") diff --git a/python/pyspark/streaming/dstream.pyi b/python/pyspark/streaming/dstream.pyi index 2bd1396f26..c9f31b37f0 100644 --- a/python/pyspark/streaming/dstream.pyi +++ b/python/pyspark/streaming/dstream.pyi @@ -38,11 +38,12 @@ from py4j.java_gateway import JavaObject S = TypeVar("S") T = TypeVar("T") +T_co = TypeVar("T_co", covariant=True) U = TypeVar("U") K = TypeVar("K", bound=Hashable) V = TypeVar("V") -class DStream(Generic[T]): +class DStream(Generic[T_co]): is_cached: bool is_checkpointed: bool def __init__( @@ -53,24 +54,24 @@ class DStream(Generic[T]): ) -> None: ... def context(self) -> pyspark.streaming.context.StreamingContext: ... 
def count(self) -> DStream[int]: ... - def filter(self, f: Callable[[T], bool]) -> DStream[T]: ... + def filter(self, f: Callable[[T_co], bool]) -> DStream[T_co]: ... def flatMap( - self: DStream[T], - f: Callable[[T], Iterable[U]], + self: DStream[T_co], + f: Callable[[T_co], Iterable[U]], preservesPartitioning: bool = ..., ) -> DStream[U]: ... def map( - self: DStream[T], f: Callable[[T], U], preservesPartitioning: bool = ... + self: DStream[T_co], f: Callable[[T_co], U], preservesPartitioning: bool = ... ) -> DStream[U]: ... def mapPartitions( - self, f: Callable[[Iterable[T]], Iterable[U]], preservesPartitioning: bool = ... + self, f: Callable[[Iterable[T_co]], Iterable[U]], preservesPartitioning: bool = ... ) -> DStream[U]: ... def mapPartitionsWithIndex( self, - f: Callable[[int, Iterable[T]], Iterable[U]], + f: Callable[[int, Iterable[T_co]], Iterable[U]], preservesPartitioning: bool = ..., ) -> DStream[U]: ... - def reduce(self, func: Callable[[T, T], T]) -> DStream[T]: ... + def reduce(self, func: Callable[[T_co, T_co], T_co]) -> DStream[T_co]: ... def reduceByKey( self: DStream[Tuple[K, V]], func: Callable[[V, V], V], @@ -89,45 +90,45 @@ class DStream(Generic[T]): partitionFunc: Callable[[K], int] = ..., ) -> DStream[Tuple[K, V]]: ... @overload - def foreachRDD(self, func: Callable[[RDD[T]], None]) -> None: ... + def foreachRDD(self, func: Callable[[RDD[T_co]], None]) -> None: ... @overload - def foreachRDD(self, func: Callable[[datetime.datetime, RDD[T]], None]) -> None: ... + def foreachRDD(self, func: Callable[[datetime.datetime, RDD[T_co]], None]) -> None: ... def pprint(self, num: int = ...) -> None: ... def mapValues(self: DStream[Tuple[K, V]], f: Callable[[V], U]) -> DStream[Tuple[K, U]]: ... def flatMapValues( self: DStream[Tuple[K, V]], f: Callable[[V], Iterable[U]] ) -> DStream[Tuple[K, U]]: ... - def glom(self) -> DStream[List[T]]: ... - def cache(self) -> DStream[T]: ... - def persist(self, storageLevel: StorageLevel) -> DStream[T]: ... - def checkpoint(self, interval: int) -> DStream[T]: ... + def glom(self) -> DStream[List[T_co]]: ... + def cache(self) -> DStream[T_co]: ... + def persist(self, storageLevel: StorageLevel) -> DStream[T_co]: ... + def checkpoint(self, interval: int) -> DStream[T_co]: ... def groupByKey( self: DStream[Tuple[K, V]], numPartitions: Optional[int] = ... ) -> DStream[Tuple[K, Iterable[V]]]: ... - def countByValue(self) -> DStream[Tuple[T, int]]: ... + def countByValue(self) -> DStream[Tuple[T_co, int]]: ... def saveAsTextFiles(self, prefix: str, suffix: Optional[str] = ...) -> None: ... @overload - def transform(self, func: Callable[[RDD[T]], RDD[U]]) -> TransformedDStream[U]: ... + def transform(self, func: Callable[[RDD[T_co]], RDD[U]]) -> TransformedDStream[U]: ... @overload def transform( - self, func: Callable[[datetime.datetime, RDD[T]], RDD[U]] + self, func: Callable[[datetime.datetime, RDD[T_co]], RDD[U]] ) -> TransformedDStream[U]: ... @overload def transformWith( self, - func: Callable[[RDD[T], RDD[U]], RDD[V]], + func: Callable[[RDD[T_co], RDD[U]], RDD[V]], other: RDD[U], keepSerializer: bool = ..., ) -> DStream[V]: ... @overload def transformWith( self, - func: Callable[[datetime.datetime, RDD[T], RDD[U]], RDD[V]], + func: Callable[[datetime.datetime, RDD[T_co], RDD[U]], RDD[V]], other: RDD[U], keepSerializer: bool = ..., ) -> DStream[V]: ... - def repartition(self, numPartitions: int) -> DStream[T]: ... - def union(self, other: DStream[U]) -> DStream[Union[T, U]]: ... 
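The stub changes above and below switch `DStream` from an invariant `T` to a covariant `T_co`. As a minimal, self-contained sketch of what covariance buys at type-checking time (a toy class with made-up names, not Spark's actual `DStream`):

```python
from typing import Callable, Generic, Iterable, TypeVar

T_co = TypeVar("T_co", covariant=True)
U = TypeVar("U")

class Stream(Generic[T_co]):
    """Toy read-only container standing in for DStream."""
    def __init__(self, items: Iterable[T_co]) -> None:
        self._items = list(items)

    # T_co may appear as a Callable argument here: an argument of an argument
    # is a covariant position overall, which is why the DStream stubs can keep
    # map/filter/reduce signatures while making the class covariant.
    def map(self, f: Callable[[T_co], U]) -> "Stream[U]":
        return Stream(f(x) for x in self._items)

def first_item(s: Stream[object]) -> object:
    return s._items[0]

ints: Stream[int] = Stream([1, 2, 3])
# Accepted by a static type checker because Stream is covariant in its element
# type; with an invariant type variable this call would be flagged.
first_item(ints)
```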
+ def repartition(self, numPartitions: int) -> DStream[T_co]: ... + def union(self, other: DStream[U]) -> DStream[Union[T_co, U]]: ... def cogroup( self: DStream[Tuple[K, V]], other: DStream[Tuple[K, U]], @@ -155,22 +156,24 @@ class DStream(Generic[T]): ) -> DStream[Tuple[K, Tuple[Optional[V], Optional[U]]]]: ... def slice( self, begin: Union[datetime.datetime, int], end: Union[datetime.datetime, int] - ) -> List[RDD[T]]: ... - def window(self, windowDuration: int, slideDuration: Optional[int] = ...) -> DStream[T]: ... + ) -> List[RDD[T_co]]: ... + def window(self, windowDuration: int, slideDuration: Optional[int] = ...) -> DStream[T_co]: ... def reduceByWindow( self, - reduceFunc: Callable[[T, T], T], - invReduceFunc: Optional[Callable[[T, T], T]], + reduceFunc: Callable[[T_co, T_co], T_co], + invReduceFunc: Optional[Callable[[T_co, T_co], T_co]], windowDuration: int, slideDuration: int, - ) -> DStream[T]: ... - def countByWindow(self, windowDuration: int, slideDuration: int) -> DStream[Tuple[T, int]]: ... + ) -> DStream[T_co]: ... + def countByWindow( + self, windowDuration: int, slideDuration: int + ) -> DStream[Tuple[T_co, int]]: ... def countByValueAndWindow( self, windowDuration: int, slideDuration: int, numPartitions: Optional[int] = ..., - ) -> DStream[Tuple[T, int]]: ... + ) -> DStream[Tuple[T_co, int]]: ... def groupByKeyAndWindow( self: DStream[Tuple[K, V]], windowDuration: int, diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 4763115721..7787e2fc92 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -901,8 +901,8 @@ private[spark] class Client( sys.env.get("PYTHONHASHSEED").foreach(env.put("PYTHONHASHSEED", _)) } - sys.env.get(ENV_DIST_CLASSPATH).foreach { dcp => - env(ENV_DIST_CLASSPATH) = dcp + Seq(ENV_DIST_CLASSPATH, SPARK_TESTING).foreach { envVar => + sys.env.get(envVar).foreach(value => env(envVar) = value) } env @@ -1353,6 +1353,8 @@ private[spark] object Client extends Logging { // Subdirectory where Spark libraries will be placed. val LOCALIZED_LIB_DIR = "__spark_libs__" + val SPARK_TESTING = "SPARK_TESTING" + /** * Return the path to the given application's staging directory. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index e8bf2aeac1..debc13b953 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -159,6 +159,10 @@ object AnsiTypeCoercion extends TypeCoercionBase { // If the expected type equals the input type, no need to cast. case _ if expectedType.acceptsType(inType) => Some(inType) + // If input is a numeric type but not decimal, and we expect a decimal type, + // cast the input to decimal. + case (n: NumericType, DecimalType) => Some(DecimalType.forType(n)) + // Cast null type (usually from null literals) into target types // By default, the result type is `target.defaultConcreteType`. When the target type is // `TypeCollection`, there is another branch to find the "closet convertible data type" below. 
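On the `Client.scala` hunk above: the change generalizes the single-variable forwarding into a loop over several environment variables, copying each one into the container environment only when it is set on the submitting side. A rough Python sketch of the same pattern (the variable names here are illustrative; the real constants live in `Client.scala`):

```python
import os

container_env = {}
# Forward selected environment variables from the launcher to the containers,
# but only those that are actually set in the current environment.
for env_var in ("SPARK_DIST_CLASSPATH", "SPARK_TESTING"):
    value = os.environ.get(env_var)
    if value is not None:
        container_env[env_var] = value
```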
@@ -178,79 +182,17 @@ object AnsiTypeCoercion extends TypeCoercionBase { case (StringType, DecimalType) if isInputFoldable => Some(DecimalType.SYSTEM_DEFAULT) - // If input is a numeric type but not decimal, and we expect a decimal type, - // cast the input to decimal. - case (d: NumericType, DecimalType) => Some(DecimalType.forType(d)) - - case (n1: NumericType, n2: NumericType) => - val widerType = findWiderTypeForTwo(n1, n2) - widerType match { - // if the expected type is Float type, we should still return Float type. - case Some(DoubleType) if n1 != DoubleType && n2 == FloatType => Some(FloatType) - - case Some(dt) if dt == n2 => Some(dt) - - case _ => None + case (_, target: DataType) => + if (Cast.canANSIStoreAssign(inType, target)) { + Some(target) + } else { + None } - case (DateType, TimestampType) => Some(TimestampType) - case (DateType, AnyTimestampType) => Some(AnyTimestampType.defaultConcreteType) - // When we reach here, input type is not acceptable for any types in this type collection, - // first try to find the all the expected types we can implicitly cast: - // 1. if there is no convertible data types, return None; - // 2. if there is only one convertible data type, cast input as it; - // 3. otherwise if there are multiple convertible data types, find the closet convertible - // data type among them. If there is no such a data type, return None. + // try to find the first one we can implicitly cast. case (_, TypeCollection(types)) => - // Since Spark contains special objects like `NumericType` and `DecimalType`, which accepts - // multiple types and they are `AbstractDataType` instead of `DataType`, here we use the - // conversion result their representation. - val convertibleTypes = types.flatMap(implicitCast(inType, _, isInputFoldable)) - if (convertibleTypes.isEmpty) { - None - } else { - // find the closet convertible data type, which can be implicit cast to all other - // convertible types. - val closestConvertibleType = convertibleTypes.find { dt => - convertibleTypes.forall { target => - implicitCast(dt, target, isInputFoldable = false).isDefined - } - } - // If the closet convertible type is Float type and the convertible types contains Double - // type, simply return Double type as the closet convertible type to avoid potential - // precision loss on converting the Integral type as Float type. - if (closestConvertibleType.contains(FloatType) && convertibleTypes.contains(DoubleType)) { - Some(DoubleType) - } else { - closestConvertibleType - } - } - - // Implicit cast between array types. - // - // Compare the nullabilities of the from type and the to type, check whether the cast of - // the nullability is resolvable by the following rules: - // 1. If the nullability of the to type is true, the cast is always allowed; - // 2. If the nullabilities of both the from type and the to type are false, the cast is - // allowed. - // 3. Otherwise, the cast is not allowed - case (ArrayType(fromType, containsNullFrom), ArrayType(toType: DataType, containsNullTo)) - if Cast.resolvableNullability(containsNullFrom, containsNullTo) => - implicitCast(fromType, toType, isInputFoldable).map(ArrayType(_, containsNullTo)) - - // Implicit cast between Map types. - // Follows the same semantics of implicit casting between two array types. - // Refer to documentation above. 
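The hunk above replaces the closest-convertible-type search for `TypeCollection` with "take the first member of the collection that the input can be implicitly cast to" (the `headOption` branch). A toy model of that control flow only; `can_cast` below is a made-up stand-in, not Spark's actual `Cast.canANSIStoreAssign`:

```python
from typing import Callable, Optional, Sequence

def implicit_cast_to_collection(
    in_type: str,
    collection: Sequence[str],
    can_cast: Callable[[str, str], bool],
) -> Optional[str]:
    # Return the first target type in the collection the input converts to,
    # or None if none of them accept it.
    return next((t for t in collection if can_cast(in_type, t)), None)

# Toy cast rules for illustration only.
allowed = {("int", "long"), ("int", "double"), ("int", "string")}
can_cast = lambda src, dst: src == dst or (src, dst) in allowed

print(implicit_cast_to_collection("int", ["string", "double"], can_cast))  # -> "string"
print(implicit_cast_to_collection("date", ["double", "long"], can_cast))   # -> None
```

The practical consequence, as the new `headOption`-based branch suggests, is that the ordering of types inside a `TypeCollection` decides which implicit cast wins, rather than a search for the most precise common type.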
- case (MapType(fromKeyType, fromValueType, fn), MapType(toKeyType, toValueType, tn)) - if Cast.resolvableNullability(fn, tn) => - val newKeyType = implicitCast(fromKeyType, toKeyType, isInputFoldable) - val newValueType = implicitCast(fromValueType, toValueType, isInputFoldable) - if (newKeyType.isDefined && newValueType.isDefined) { - Some(MapType(newKeyType.get, newValueType.get, tn)) - } else { - None - } + types.flatMap(implicitCast(inType, _, isInputFoldable)).headOption case _ => None } @@ -348,6 +290,9 @@ object AnsiTypeCoercion extends TypeCoercionBase { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e + case d @ DateAdd(AnyTimestampType(), _) => d.copy(startDate = Cast(d.startDate, DateType)) + case d @ DateSub(AnyTimestampType(), _) => d.copy(startDate = Cast(d.startDate, DateType)) + case s @ SubtractTimestamps(DateType(), AnyTimestampType(), _, _) => s.copy(left = Cast(s.left, s.right.dataType)) case s @ SubtractTimestamps(AnyTimestampType(), DateType(), _, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala index ec3d957f92..2e2d415954 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala @@ -48,6 +48,9 @@ import org.apache.spark.sql.internal.SQLConf.{LEGACY_CTE_PRECEDENCE_POLICY, Lega */ object CTESubstitution extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { + if (!plan.containsPattern(UNRESOLVED_WITH)) { + return plan + } val isCommand = plan.find { case _: Command | _: ParsedStatement | _: InsertIntoDir => true case _ => false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/KeepLegacyOutputs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/KeepLegacyOutputs.scala index 98e26fc5ad..a40b96732b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/KeepLegacyOutputs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/KeepLegacyOutputs.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.plans.logical.{DescribeNamespace, LogicalPlan, ShowNamespaces, ShowTables} +import org.apache.spark.sql.catalyst.plans.logical.{DescribeNamespace, LogicalPlan, ShowNamespaces, ShowTableProperties, ShowTables} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND import org.apache.spark.sql.internal.SQLConf @@ -43,6 +43,8 @@ object KeepLegacyOutputs extends Rule[LogicalPlan] { assert(d.output.length == 2) d.copy(output = Seq(d.output.head.withName("database_description_item"), d.output.last.withName("database_description_value"))) + case s: ShowTableProperties if s.propertyKey.isDefined => + s.copy(output = Seq(s.output.last)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 90cbe565fe..506667461e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -1157,9 +1157,9 @@ object TypeCoercion extends TypeCoercionBase { override val transform: 
PartialFunction[Expression, Expression] = { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case d @ DateAdd(TimestampType(), _) => d.copy(startDate = Cast(d.startDate, DateType)) + case d @ DateAdd(AnyTimestampType(), _) => d.copy(startDate = Cast(d.startDate, DateType)) case d @ DateAdd(StringType(), _) => d.copy(startDate = Cast(d.startDate, DateType)) - case d @ DateSub(TimestampType(), _) => d.copy(startDate = Cast(d.startDate, DateType)) + case d @ DateSub(AnyTimestampType(), _) => d.copy(startDate = Cast(d.startDate, DateType)) case d @ DateSub(StringType(), _) => d.copy(startDate = Cast(d.startDate, DateType)) case s @ SubtractTimestamps(DateType(), AnyTimestampType(), _, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 9e83051313..1c95ec8d1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} -import java.time.{Duration, Instant, LocalDate, Period} +import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period} import scala.language.implicitConversions @@ -165,6 +165,7 @@ package object dsl { implicit def bigDecimalToLiteral(d: java.math.BigDecimal): Literal = Literal(d) implicit def decimalToLiteral(d: Decimal): Literal = Literal(d) implicit def timestampToLiteral(t: Timestamp): Literal = Literal(t) + implicit def timestampNTZToLiteral(l: LocalDateTime): Literal = Literal(l) implicit def instantToLiteral(i: Instant): Literal = Literal(i) implicit def binaryToLiteral(a: Array[Byte]): Literal = Literal(a) implicit def periodToLiteral(p: Period): Literal = Literal(p) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index e3541dc7ee..03f9da66ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1547,14 +1547,27 @@ case class BRound(child: Expression, scale: Expression) } object WidthBucket { - def computeBucketNumber(value: Double, min: Double, max: Double, numBucket: Long): jl.Long = { - if (numBucket <= 0 || numBucket == Long.MaxValue || jl.Double.isNaN(value) || min == max || - jl.Double.isNaN(min) || jl.Double.isInfinite(min) || - jl.Double.isNaN(max) || jl.Double.isInfinite(max)) { - return null + if (isNull(value, min, max, numBucket)) { + null + } else { + computeBucketNumberNotNull(value, min, max, numBucket) } + } + /** This function is called by generated Java code, so it needs to be public. */ + def isNull(value: Double, min: Double, max: Double, numBucket: Long): Boolean = { + numBucket <= 0 || + numBucket == Long.MaxValue || + jl.Double.isNaN(value) || + min == max || + jl.Double.isNaN(min) || jl.Double.isInfinite(min) || + jl.Double.isNaN(max) || jl.Double.isInfinite(max) + } + + /** This function is called by generated Java code, so it needs to be public. 
*/ + def computeBucketNumberNotNull( + value: Double, min: Double, max: Double, numBucket: Long): jl.Long = { val lower = Math.min(min, max) val upper = Math.max(min, max) @@ -1666,9 +1679,14 @@ case class WidthBucket( } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (input, min, max, numBucket) => - "org.apache.spark.sql.catalyst.expressions.WidthBucket" + - s".computeBucketNumber($input, $min, $max, $numBucket)") + nullSafeCodeGen(ctx, ev, (input, min, max, numBucket) => { + s"""${ev.isNull} = org.apache.spark.sql.catalyst.expressions.WidthBucket + | .isNull($input, $min, $max, $numBucket); + |if (!${ev.isNull}) { + | ${ev.value} = org.apache.spark.sql.catalyst.expressions.WidthBucket + | .computeBucketNumberNotNull($input, $min, $max, $numBucket); + |}""".stripMargin + }) } override def first: Expression = value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 337375de70..2b997da29b 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -1698,7 +1698,7 @@ case class ParseUrl(children: Seq[Expression], failOnError: Boolean = SQLConf.ge case class FormatString(children: Expression*) extends Expression with ImplicitCastInputTypes { require(children.nonEmpty, s"$prettyName() should take at least 1 argument") - require(checkArgumentIndexNotZero(children(0)), "Illegal format argument index = 0") + checkArgumentIndexNotZero(children(0)) override def foldable: Boolean = children.forall(_.foldable) @@ -1782,9 +1782,11 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC * Therefore, manually check that the pattern string not contains "%0$" to ensure consistent * behavior of Java 8, Java 11 and Java 17. 
*/ - private def checkArgumentIndexNotZero(expression: Expression): Boolean = expression match { - case StringLiteral(pattern) => !pattern.contains("%0$") - case _ => true + private def checkArgumentIndexNotZero(expression: Expression): Unit = expression match { + case StringLiteral(pattern) if pattern.contains("%0$") => + throw QueryCompilationErrors.illegalSubstringError( + "The argument_index of string format", "position 0$") + case _ => // do nothing } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index f1b954d6c7..e8a632d015 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -626,6 +626,8 @@ object View { case class UnresolvedWith( child: LogicalPlan, cteRelations: Seq[(String, SubqueryAlias)]) extends UnaryNode { + final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_WITH) + override def output: Seq[Attribute] = child.output override def simpleString(maxFields: Int): String = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index 9792545488..5ec303d97f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -162,12 +162,13 @@ object RuleIdCollection { // In the production code path, the following rules are run in CombinedTypeCoercionRule, and // hence we only need to add them for unit testing. 
"org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion$PromoteStringLiterals" :: + "org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion$DateTimeOperations" :: "org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion$GetDateFieldOperations" :: "org.apache.spark.sql.catalyst.analysis.DecimalPrecision" :: "org.apache.spark.sql.catalyst.analysis.TypeCoercion$BooleanEquality" :: + "org.apache.spark.sql.catalyst.analysis.TypeCoercion$DateTimeOperations" :: "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$CaseWhenCoercion" :: "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$ConcatCoercion" :: - "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$DateTimeOperations" :: "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$Division" :: "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$EltCoercion" :: "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$FunctionArgumentConversion" :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index 6c1b64dd0a..aad90ff695 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -111,6 +111,7 @@ object TreePattern extends Enumeration { val REPARTITION_OPERATION: Value = Value val UNION: Value = Value val UNRESOLVED_RELATION: Value = Value + val UNRESOLVED_WITH: Value = Value val TYPED_FILTER: Value = Value val WINDOW: Value = Value val WITH_WINDOW_DEFINITION: Value = Value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 207a9c3086..839a888990 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -66,6 +66,12 @@ object QueryCompilationErrors { messageParameters = Array(sizeLimit.toString)) } + def illegalSubstringError(subject: String, illegalContent: String): Throwable = { + new AnalysisException( + errorClass = "ILLEGAL_SUBSTRING", + messageParameters = Array(subject, illegalContent)) + } + def unorderablePivotColError(pivotCol: Expression): Throwable = { new AnalysisException( errorClass = "INCOMPARABLE_PIVOT_COLUMN", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala index ab8d9d9806..809cbb2ceb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala @@ -19,42 +19,34 @@ package org.apache.spark.sql.catalyst.analysis import java.sql.Timestamp -import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils -class AnsiTypeCoercionSuite extends 
AnalysisTest { +class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase { import TypeCoercionSuite._ - // When Utils.isTesting is true, RuleIdCollection adds individual type coercion rules. Otherwise, - // RuleIdCollection doesn't add them because they are called in a train inside - // CombinedTypeCoercionRule. - assert(Utils.isTesting, s"${IS_TESTING.key} is not set to true") - // scalastyle:off line.size.limit // The following table shows all implicit data type conversions that are not visible to the user. // +----------------------+----------+-----------+-------------+----------+------------+------------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+ // | Source Type\CAST TO | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType | NumericType | IntegralType | // +----------------------+----------+-----------+-------------+----------+------------+------------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+ - // | ByteType | ByteType | ShortType | IntegerType | LongType | DoubleType | DoubleType | Dec(10, 2) | X | X | X | X | X | X | X | X | X | X | DecimalType(3, 0) | ByteType | ByteType | - // | ShortType | X | ShortType | IntegerType | LongType | DoubleType | DoubleType | Dec(10, 2) | X | X | X | X | X | X | X | X | X | X | DecimalType(5, 0) | ShortType | ShortType | - // | IntegerType | X | X | IntegerType | LongType | DoubleType | DoubleType | X | X | X | X | X | X | X | X | X | X | X | DecimalType(10, 0) | IntegerType | IntegerType | - // | LongType | X | X | X | LongType | DoubleType | DoubleType | X | X | X | X | X | X | X | X | X | X | X | DecimalType(20, 0) | LongType | LongType | - // | FloatType | X | X | X | X | FloatType | DoubleType | X | X | X | X | X | X | X | X | X | X | X | DecimalType(30, 15) | DoubleType | X | - // | DoubleType | X | X | X | X | X | DoubleType | X | X | X | X | X | X | X | X | X | X | X | DecimalType(14, 7) | FloatType | X | - // | Dec(10, 2) | X | X | X | X | DoubleType | DoubleType | Dec(10, 2) | X | X | X | X | X | X | X | X | X | X | DecimalType(10, 2) | Dec(10, 2) | X | - // | BinaryType | X | X | X | X | X | X | X | BinaryType | X | X | X | X | X | X | X | X | X | X | X | X | - // | BooleanType | X | X | X | X | X | X | X | X | BooleanType | X | X | X | X | X | X | X | X | X | X | X | - // | StringType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | - // | DateType | X | X | X | X | X | X | X | X | X | X | DateType | TimestampType | X | X | X | X | X | X | X | X | - // | TimestampType | X | X | X | X | X | X | X | X | X | X | X | TimestampType | X | X | X | X | X | X | X | X | + // | ByteType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(3, 0) | ByteType | ByteType | + // | ShortType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(5, 0) | ShortType | ShortType | + // | IntegerType | ByteType | ShortType | IntegerType | 
LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(10, 0) | IntegerType | IntegerType | + // | LongType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(20, 0) | LongType | LongType | + // | DoubleType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(30, 15) | DoubleType | IntegerType | + // | FloatType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(14, 7) | FloatType | IntegerType | + // | Dec(10, 2) | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(10, 2) | Dec(10, 2) | IntegerType | + // | BinaryType | X | X | X | X | X | X | X | BinaryType | X | StringType | X | X | X | X | X | X | X | X | X | X | + // | BooleanType | X | X | X | X | X | X | X | X | BooleanType | StringType | X | X | X | X | X | X | X | X | X | X | + // | StringType | X | X | X | X | X | X | X | X | X | StringType | X | X | X | X | X | X | X | X | X | X | + // | DateType | X | X | X | X | X | X | X | X | X | StringType | DateType | TimestampType | X | X | X | X | X | X | X | X | + // | TimestampType | X | X | X | X | X | X | X | X | X | StringType | DateType | TimestampType | X | X | X | X | X | X | X | X | // | ArrayType | X | X | X | X | X | X | X | X | X | X | X | X | ArrayType* | X | X | X | X | X | X | X | // | MapType | X | X | X | X | X | X | X | X | X | X | X | X | X | MapType* | X | X | X | X | X | X | // | StructType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | StructType* | X | X | X | X | X | @@ -65,30 +57,10 @@ class AnsiTypeCoercionSuite extends AnalysisTest { // Note: ArrayType* is castable when the element type is castable according to the table. // Note: MapType* is castable when both the key type and the value type are castable according to the table. 
// scalastyle:on line.size.limit + override def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] = + AnsiTypeCoercion.implicitCast(e, expectedType) - private def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = { - // Check default value - val castDefault = AnsiTypeCoercion.implicitCast(default(from), to) - assert(DataType.equalsIgnoreCompatibleNullability( - castDefault.map(_.dataType).getOrElse(null), expected), - s"Failed to cast $from to $to") - - // Check null value - val castNull = AnsiTypeCoercion.implicitCast(createNull(from), to) - assert(DataType.equalsIgnoreCaseAndNullability( - castNull.map(_.dataType).getOrElse(null), expected), - s"Failed to cast $from to $to") - } - - private def shouldNotCast(from: DataType, to: AbstractDataType): Unit = { - // Check default value - val castDefault = AnsiTypeCoercion.implicitCast(default(from), to) - assert(castDefault.isEmpty, s"Should not be able to cast $from to $to, but got $castDefault") - - // Check null value - val castNull = AnsiTypeCoercion.implicitCast(createNull(from), to) - assert(castNull.isEmpty, s"Should not be able to cast $from to $to, but got $castNull") - } + override def dateTimeOperationsRule: TypeCoercionRule = AnsiTypeCoercion.DateTimeOperations private def shouldCastStringLiteral(to: AbstractDataType, expected: DataType): Unit = { val input = Literal("123") @@ -110,35 +82,6 @@ class AnsiTypeCoercionSuite extends AnalysisTest { assert(castResult.isEmpty, s"Should not be able to cast non-foldable String input to $to") } - private def default(dataType: DataType): Expression = dataType match { - case ArrayType(internalType: DataType, _) => - CreateArray(Seq(Literal.default(internalType))) - case MapType(keyDataType: DataType, valueDataType: DataType, _) => - CreateMap(Seq(Literal.default(keyDataType), Literal.default(valueDataType))) - case _ => Literal.default(dataType) - } - - private def createNull(dataType: DataType): Expression = dataType match { - case ArrayType(internalType: DataType, _) => - CreateArray(Seq(Literal.create(null, internalType))) - case MapType(keyDataType: DataType, valueDataType: DataType, _) => - CreateMap(Seq(Literal.create(null, keyDataType), Literal.create(null, valueDataType))) - case _ => Literal.create(null, dataType) - } - - // Check whether the type `checkedType` can be cast to all the types in `castableTypes`, - // but cannot be cast to the other types in `allTypes`. 
- private def checkTypeCasting(checkedType: DataType, castableTypes: Seq[DataType]): Unit = { - val nonCastableTypes = allTypes.filterNot(castableTypes.contains) - - castableTypes.foreach { tpe => - shouldCast(checkedType, tpe, tpe) - } - nonCastableTypes.foreach { tpe => - shouldNotCast(checkedType, tpe) - } - } - private def checkWidenType( widenFunc: (DataType, DataType) => Option[DataType], t1: DataType, @@ -156,81 +99,6 @@ class AnsiTypeCoercionSuite extends AnalysisTest { } } - test("implicit type cast - ByteType") { - val checkedType = ByteType - checkTypeCasting(checkedType, castableTypes = numericTypes) - shouldCast(checkedType, DecimalType, DecimalType.ByteDecimal) - shouldCast(checkedType, NumericType, checkedType) - shouldCast(checkedType, IntegralType, checkedType) - } - - test("implicit type cast - ShortType") { - val checkedType = ShortType - checkTypeCasting(checkedType, castableTypes = numericTypes.filterNot(_ == ByteType)) - shouldCast(checkedType, DecimalType, DecimalType.ShortDecimal) - shouldCast(checkedType, NumericType, checkedType) - shouldCast(checkedType, IntegralType, checkedType) - } - - test("implicit type cast - IntegerType") { - val checkedType = IntegerType - checkTypeCasting(checkedType, castableTypes = - Seq(IntegerType, LongType, FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT)) - shouldCast(IntegerType, DecimalType, DecimalType.IntDecimal) - shouldCast(checkedType, NumericType, checkedType) - shouldCast(checkedType, IntegralType, checkedType) - } - - test("implicit type cast - LongType") { - val checkedType = LongType - checkTypeCasting(checkedType, castableTypes = - Seq(LongType, FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT)) - shouldCast(checkedType, DecimalType, DecimalType.LongDecimal) - shouldCast(checkedType, NumericType, checkedType) - shouldCast(checkedType, IntegralType, checkedType) - } - - test("implicit type cast - FloatType") { - val checkedType = FloatType - checkTypeCasting(checkedType, castableTypes = Seq(FloatType, DoubleType)) - shouldCast(checkedType, DecimalType, DecimalType.FloatDecimal) - shouldCast(checkedType, NumericType, checkedType) - shouldNotCast(checkedType, IntegralType) - } - - test("implicit type cast - DoubleType") { - val checkedType = DoubleType - checkTypeCasting(checkedType, castableTypes = Seq(DoubleType)) - shouldCast(checkedType, DecimalType, DecimalType.DoubleDecimal) - shouldCast(checkedType, NumericType, checkedType) - shouldNotCast(checkedType, IntegralType) - } - - test("implicit type cast - DecimalType(10, 2)") { - val checkedType = DecimalType(10, 2) - checkTypeCasting(checkedType, castableTypes = fractionalTypes) - shouldCast(checkedType, DecimalType, checkedType) - shouldCast(checkedType, NumericType, checkedType) - shouldNotCast(checkedType, IntegralType) - } - - test("implicit type cast - BinaryType") { - val checkedType = BinaryType - checkTypeCasting(checkedType, castableTypes = Seq(checkedType)) - shouldNotCast(checkedType, DecimalType) - shouldNotCast(checkedType, NumericType) - shouldNotCast(checkedType, IntegralType) - } - - test("implicit type cast - BooleanType") { - val checkedType = BooleanType - checkTypeCasting(checkedType, castableTypes = Seq(checkedType)) - shouldNotCast(checkedType, DecimalType) - shouldNotCast(checkedType, NumericType) - shouldNotCast(checkedType, IntegralType) - shouldNotCast(checkedType, StringType) - } - test("implicit type cast - unfoldable StringType") { val nonCastableTypes = allTypes.filterNot(_ == StringType) nonCastableTypes.foreach { dt => @@ -251,23 
+119,6 @@ class AnsiTypeCoercionSuite extends AnalysisTest { shouldCastStringLiteral(NumericType, DoubleType) } - test("implicit type cast - DateType") { - val checkedType = DateType - checkTypeCasting(checkedType, castableTypes = Seq(checkedType, TimestampType)) - shouldNotCast(checkedType, DecimalType) - shouldNotCast(checkedType, NumericType) - shouldNotCast(checkedType, IntegralType) - shouldNotCast(checkedType, StringType) - } - - test("implicit type cast - TimestampType") { - val checkedType = TimestampType - checkTypeCasting(checkedType, castableTypes = Seq(checkedType)) - shouldNotCast(checkedType, DecimalType) - shouldNotCast(checkedType, NumericType) - shouldNotCast(checkedType, IntegralType) - } - test("implicit type cast - unfoldable ArrayType(StringType)") { val input = AttributeReference("a", ArrayType(StringType))() val nonCastableTypes = allTypes.filterNot(_ == StringType) @@ -278,55 +129,6 @@ class AnsiTypeCoercionSuite extends AnalysisTest { assert(AnsiTypeCoercion.implicitCast(input, NumericType).isEmpty) } - test("implicit type cast - foldable arrayType(StringType)") { - val input = Literal(Array("1")) - assert(AnsiTypeCoercion.implicitCast(input, ArrayType(StringType)) == Some(input)) - (numericTypes ++ datetimeTypes ++ Seq(BinaryType)).foreach { dt => - assert(AnsiTypeCoercion.implicitCast(input, ArrayType(dt)) == - Some(Cast(input, ArrayType(dt)))) - } - } - - test("implicit type cast between two Map types") { - val sourceType = MapType(IntegerType, IntegerType, true) - val castableTypes = - Seq(IntegerType, LongType, FloatType, DoubleType, DecimalType.SYSTEM_DEFAULT) - val targetTypes = castableTypes.map { t => - MapType(t, sourceType.valueType, valueContainsNull = true) - } - val nonCastableTargetTypes = allTypes.filterNot(castableTypes.contains(_)).map {t => - MapType(t, sourceType.valueType, valueContainsNull = true) - } - - // Tests that its possible to setup implicit casts between two map types when - // source map's key type is integer and the target map's key type are either Byte, Short, - // Long, Double, Float, Decimal(38, 18) or String. - targetTypes.foreach { targetType => - shouldCast(sourceType, targetType, targetType) - } - - // Tests that its not possible to setup implicit casts between two map types when - // source map's key type is integer and the target map's key type are either Binary, - // Boolean, Date, Timestamp, Array, Struct, CalendarIntervalType or NullType - nonCastableTargetTypes.foreach { targetType => - shouldNotCast(sourceType, targetType) - } - - // Tests that its not possible to cast from nullable map type to not nullable map type. 
- val targetNotNullableTypes = allTypes.filterNot(_ == IntegerType).map { t => - MapType(t, sourceType.valueType, valueContainsNull = false) - } - val sourceMapExprWithValueNull = - CreateMap(Seq(Literal.default(sourceType.keyType), - Literal.create(null, sourceType.valueType))) - targetNotNullableTypes.foreach { targetType => - val castDefault = - AnsiTypeCoercion.implicitCast(sourceMapExprWithValueNull, targetType) - assert(castDefault.isEmpty, - s"Should not be able to cast $sourceType to $targetType, but got $castDefault") - } - } - test("implicit type cast - StructType().add(\"a1\", StringType)") { val checkedType = new StructType().add("a1", StringType) checkTypeCasting(checkedType, castableTypes = Seq(checkedType)) @@ -345,64 +147,12 @@ class AnsiTypeCoercionSuite extends AnalysisTest { test("implicit type cast - CalendarIntervalType") { val checkedType = CalendarIntervalType - checkTypeCasting(checkedType, castableTypes = Seq(checkedType)) + checkTypeCasting(checkedType, castableTypes = Seq(checkedType, StringType)) shouldNotCast(checkedType, DecimalType) shouldNotCast(checkedType, NumericType) shouldNotCast(checkedType, IntegralType) } - test("eligible implicit type cast - TypeCollection") { - shouldCast(StringType, TypeCollection(StringType, BinaryType), StringType) - shouldCast(BinaryType, TypeCollection(StringType, BinaryType), BinaryType) - shouldCast(StringType, TypeCollection(BinaryType, StringType), StringType) - - shouldCast(IntegerType, TypeCollection(IntegerType, BinaryType), IntegerType) - shouldCast(IntegerType, TypeCollection(BinaryType, IntegerType), IntegerType) - shouldCast(BinaryType, TypeCollection(BinaryType, IntegerType), BinaryType) - shouldCast(BinaryType, TypeCollection(IntegerType, BinaryType), BinaryType) - - shouldCast(DecimalType.SYSTEM_DEFAULT, - TypeCollection(IntegerType, DecimalType), DecimalType.SYSTEM_DEFAULT) - shouldCast(DecimalType(10, 2), TypeCollection(IntegerType, DecimalType), DecimalType(10, 2)) - shouldCast(DecimalType(10, 2), TypeCollection(DecimalType, IntegerType), DecimalType(10, 2)) - - shouldCast( - ArrayType(StringType, false), - TypeCollection(ArrayType(StringType), StringType), - ArrayType(StringType, false)) - - shouldCast( - ArrayType(StringType, true), - TypeCollection(ArrayType(StringType), StringType), - ArrayType(StringType, true)) - - // When there are multiple convertible types in the `TypeCollection`, use the closest - // convertible data type among convertible types. - shouldCast(IntegerType, TypeCollection(BinaryType, FloatType, LongType), LongType) - shouldCast(ShortType, TypeCollection(BinaryType, LongType, IntegerType), IntegerType) - shouldCast(ShortType, TypeCollection(DateType, LongType, IntegerType, DoubleType), IntegerType) - // If the result is Float type and Double type is also among the convertible target types, - // use Double Type instead of Float type. 
- shouldCast(LongType, TypeCollection(FloatType, DoubleType, StringType), DoubleType) - } - - test("ineligible implicit type cast - TypeCollection") { - shouldNotCast(IntegerType, TypeCollection(StringType, BinaryType)) - shouldNotCast(IntegerType, TypeCollection(BinaryType, StringType)) - shouldNotCast(IntegerType, TypeCollection(DateType, TimestampType)) - shouldNotCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType)) - shouldNotCastStringInput(TypeCollection(NumericType, BinaryType)) - // When there are multiple convertible types in the `TypeCollection` and there is no such - // a data type that can be implicit cast to all the other convertible types in the collection. - Seq(TypeCollection(NumericType, BinaryType), - TypeCollection(NumericType, DecimalType, BinaryType), - TypeCollection(IntegerType, LongType, BooleanType), - TypeCollection(DateType, TimestampType, BooleanType)).foreach { typeCollection => - shouldNotCastStringLiteral(typeCollection) - shouldNotCast(NullType, typeCollection) - } - } - test("tightest common bound for types") { def widenTest(t1: DataType, t2: DataType, expected: Option[DataType]): Unit = checkWidenType(AnsiTypeCoercion.findTightestCommonType, t1, t2, expected) @@ -606,25 +356,6 @@ class AnsiTypeCoercionSuite extends AnalysisTest { None) } - private def ruleTest(rule: Rule[LogicalPlan], - initial: Expression, transformed: Expression): Unit = { - ruleTest(Seq(rule), initial, transformed) - } - - private def ruleTest( - rules: Seq[Rule[LogicalPlan]], - initial: Expression, - transformed: Expression): Unit = { - val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) - val analyzer = new RuleExecutor[LogicalPlan] { - override val batches = Seq(Batch("Resolution", FixedPoint(3), rules: _*)) - } - - comparePlans( - analyzer.execute(Project(Seq(Alias(initial, "a")()), testRelation)), - Project(Seq(Alias(transformed, "a")()), testRelation)) - } - test("cast NullType for expressions that implement ExpectsInputTypes") { ruleTest(AnsiTypeCoercion.ImplicitTypeCasts, AnyTypeUnaryExpression(Literal.create(null, NullType)), @@ -1000,90 +731,6 @@ class AnsiTypeCoercionSuite extends AnalysisTest { Literal.create(null, IntegerType), Literal.create(null, StringType)))) } - test("type coercion for Concat") { - val rule = AnsiTypeCoercion.ConcatCoercion - - ruleTest(rule, - Concat(Seq(Literal("ab"), Literal("cde"))), - Concat(Seq(Literal("ab"), Literal("cde")))) - ruleTest(rule, - Concat(Seq(Literal(null), Literal("abc"))), - Concat(Seq(Cast(Literal(null), StringType), Literal("abc")))) - ruleTest(rule, - Concat(Seq(Literal(1), Literal("234"))), - Concat(Seq(Literal(1), Literal("234")))) - ruleTest(rule, - Concat(Seq(Literal("1"), Literal("234".getBytes()))), - Concat(Seq(Literal("1"), Literal("234".getBytes())))) - ruleTest(rule, - Concat(Seq(Literal(1L), Literal(2.toByte), Literal(0.1))), - Concat(Seq(Literal(1L), Literal(2.toByte), Literal(0.1)))) - ruleTest(rule, - Concat(Seq(Literal(true), Literal(0.1f), Literal(3.toShort))), - Concat(Seq(Literal(true), Literal(0.1f), Literal(3.toShort)))) - ruleTest(rule, - Concat(Seq(Literal(1L), Literal(0.1))), - Concat(Seq(Literal(1L), Literal(0.1)))) - ruleTest(rule, - Concat(Seq(Literal(Decimal(10)))), - Concat(Seq(Literal(Decimal(10))))) - ruleTest(rule, - Concat(Seq(Literal(BigDecimal.valueOf(10)))), - Concat(Seq(Literal(BigDecimal.valueOf(10))))) - ruleTest(rule, - Concat(Seq(Literal(java.math.BigDecimal.valueOf(10)))), - Concat(Seq(Literal(java.math.BigDecimal.valueOf(10))))) - ruleTest(rule, - 
Concat(Seq(Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))), - Concat(Seq(Literal(new java.sql.Date(0)), Literal(new Timestamp(0))))) - - ruleTest(rule, - Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))), - Concat(Seq(Literal("123".getBytes), Literal("456".getBytes)))) - } - - test("type coercion for Elt") { - val rule = AnsiTypeCoercion.EltCoercion - - ruleTest(rule, - Elt(Seq(Literal(1), Literal("ab"), Literal("cde"))), - Elt(Seq(Literal(1), Literal("ab"), Literal("cde")))) - ruleTest(rule, - Elt(Seq(Literal(1.toShort), Literal("ab"), Literal("cde"))), - Elt(Seq(Cast(Literal(1.toShort), IntegerType), Literal("ab"), Literal("cde")))) - ruleTest(rule, - Elt(Seq(Literal(2), Literal(null), Literal("abc"))), - Elt(Seq(Literal(2), Cast(Literal(null), StringType), Literal("abc")))) - ruleTest(rule, - Elt(Seq(Literal(2), Literal(1), Literal("234"))), - Elt(Seq(Literal(2), Literal(1), Literal("234")))) - ruleTest(rule, - Elt(Seq(Literal(3), Literal(1L), Literal(2.toByte), Literal(0.1))), - Elt(Seq(Literal(3), Literal(1L), Literal(2.toByte), Literal(0.1)))) - ruleTest(rule, - Elt(Seq(Literal(2), Literal(true), Literal(0.1f), Literal(3.toShort))), - Elt(Seq(Literal(2), Literal(true), Literal(0.1f), Literal(3.toShort)))) - ruleTest(rule, - Elt(Seq(Literal(1), Literal(1L), Literal(0.1))), - Elt(Seq(Literal(1), Literal(1L), Literal(0.1)))) - ruleTest(rule, - Elt(Seq(Literal(1), Literal(Decimal(10)))), - Elt(Seq(Literal(1), Literal(Decimal(10))))) - ruleTest(rule, - Elt(Seq(Literal(1), Literal(BigDecimal.valueOf(10)))), - Elt(Seq(Literal(1), Literal(BigDecimal.valueOf(10))))) - ruleTest(rule, - Elt(Seq(Literal(1), Literal(java.math.BigDecimal.valueOf(10)))), - Elt(Seq(Literal(1), Literal(java.math.BigDecimal.valueOf(10))))) - ruleTest(rule, - Elt(Seq(Literal(2), Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))), - Elt(Seq(Literal(2), Literal(new java.sql.Date(0)), Literal(new Timestamp(0))))) - - ruleTest(rule, - Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))), - Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes)))) - } - private def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = { logical.output.zip(expectTypes).foreach { case (attr, dt) => assert(attr.dataType === dt) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 2dc669bbb9..8de84b3ae2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import java.sql.Timestamp -import java.time.{Duration, Period} +import java.time.{Duration, LocalDateTime, Period} import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.sql.catalyst.analysis.TypeCoercion._ @@ -31,7 +31,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class TypeCoercionSuite extends AnalysisTest { +abstract class TypeCoercionSuiteBase extends AnalysisTest { import TypeCoercionSuite._ // When Utils.isTesting is true, RuleIdCollection adds individual type coercion rules. Otherwise, @@ -39,59 +39,35 @@ class TypeCoercionSuite extends AnalysisTest { // CombinedTypeCoercionRule. 
assert(Utils.isTesting, s"${IS_TESTING.key} is not set to true") - // scalastyle:off line.size.limit - // The following table shows all implicit data type conversions that are not visible to the user. - // +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+ - // | Source Type\CAST TO | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType | NumericType | IntegralType | - // +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+ - // | ByteType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(3, 0) | ByteType | ByteType | - // | ShortType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(5, 0) | ShortType | ShortType | - // | IntegerType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(10, 0) | IntegerType | IntegerType | - // | LongType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(20, 0) | LongType | LongType | - // | DoubleType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(30, 15) | DoubleType | IntegerType | - // | FloatType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(14, 7) | FloatType | IntegerType | - // | Dec(10, 2) | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(10, 2) | Dec(10, 2) | IntegerType | - // | BinaryType | X | X | X | X | X | X | X | BinaryType | X | StringType | X | X | X | X | X | X | X | X | X | X | - // | BooleanType | X | X | X | X | X | X | X | X | BooleanType | StringType | X | X | X | X | X | X | X | X | X | X | - // | StringType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | X | StringType | DateType | TimestampType | X | X | X | X | X | DecimalType(38, 18) | DoubleType | X | - // | DateType | X | X | X | X | X | X | X | X | X | StringType | DateType | TimestampType | X | X | X | X | X | X | X | X | - // | TimestampType | X | X | X | X | X | X | X | X | X | StringType | DateType | TimestampType | X | X | X | X | X | X | X | X | - // | ArrayType | X | X | X | X | X | X | X | X | X | X | X | X | ArrayType* | X | X | X | X | X | X | X | - // | MapType | X | X | X | X | X | X | X | X | X | X | X | X | X | MapType* | X | X | X | X | X | X | - // | StructType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | StructType* | X | X | X | X | X | - // | 
NullType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType(38, 18) | DoubleType | IntegerType | - // | CalendarIntervalType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | CalendarIntervalType | X | X | X | - // +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+ - // Note: StructType* is castable when all the internal child types are castable according to the table. - // Note: ArrayType* is castable when the element type is castable according to the table. - // Note: MapType* is castable when both the key type and the value type are castable according to the table. - // scalastyle:on line.size.limit + protected def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] + + protected def dateTimeOperationsRule: TypeCoercionRule - private def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = { + protected def shouldCast(from: DataType, to: AbstractDataType, expected: DataType): Unit = { // Check default value - val castDefault = TypeCoercion.implicitCast(default(from), to) + val castDefault = implicitCast(default(from), to) assert(DataType.equalsIgnoreCompatibleNullability( castDefault.map(_.dataType).getOrElse(null), expected), s"Failed to cast $from to $to") // Check null value - val castNull = TypeCoercion.implicitCast(createNull(from), to) + val castNull = implicitCast(createNull(from), to) assert(DataType.equalsIgnoreCaseAndNullability( castNull.map(_.dataType).getOrElse(null), expected), s"Failed to cast $from to $to") } - private def shouldNotCast(from: DataType, to: AbstractDataType): Unit = { + protected def shouldNotCast(from: DataType, to: AbstractDataType): Unit = { // Check default value - val castDefault = TypeCoercion.implicitCast(default(from), to) + val castDefault = implicitCast(default(from), to) assert(castDefault.isEmpty, s"Should not be able to cast $from to $to, but got $castDefault") // Check null value - val castNull = TypeCoercion.implicitCast(createNull(from), to) + val castNull = implicitCast(createNull(from), to) assert(castNull.isEmpty, s"Should not be able to cast $from to $to, but got $castNull") } - private def default(dataType: DataType): Expression = dataType match { + protected def default(dataType: DataType): Expression = dataType match { case ArrayType(internalType: DataType, _) => CreateArray(Seq(Literal.default(internalType))) case MapType(keyDataType: DataType, valueDataType: DataType, _) => @@ -99,7 +75,7 @@ class TypeCoercionSuite extends AnalysisTest { case _ => Literal.default(dataType) } - private def createNull(dataType: DataType): Expression = dataType match { + protected def createNull(dataType: DataType): Expression = dataType match { case ArrayType(internalType: DataType, _) => CreateArray(Seq(Literal.create(null, internalType))) case MapType(keyDataType: DataType, valueDataType: DataType, _) => @@ -109,7 +85,7 @@ class TypeCoercionSuite extends AnalysisTest { // Check whether the type `checkedType` can be cast to all the types in `castableTypes`, // but cannot be cast to the other types in `allTypes`. 
- private def checkTypeCasting(checkedType: DataType, castableTypes: Seq[DataType]): Unit = { + protected def checkTypeCasting(checkedType: DataType, castableTypes: Seq[DataType]): Unit = { val nonCastableTypes = allTypes.filterNot(castableTypes.contains) castableTypes.foreach { tpe => @@ -120,21 +96,23 @@ class TypeCoercionSuite extends AnalysisTest { } } - private def checkWidenType( - widenFunc: (DataType, DataType) => Option[DataType], - t1: DataType, - t2: DataType, - expected: Option[DataType], - isSymmetric: Boolean = true): Unit = { - var found = widenFunc(t1, t2) - assert(found == expected, - s"Expected $expected as wider common type for $t1 and $t2, found $found") - // Test both directions to make sure the widening is symmetric. - if (isSymmetric) { - found = widenFunc(t2, t1) - assert(found == expected, - s"Expected $expected as wider common type for $t2 and $t1, found $found") + protected def ruleTest(rule: Rule[LogicalPlan], + initial: Expression, transformed: Expression): Unit = { + ruleTest(Seq(rule), initial, transformed) + } + + protected def ruleTest( + rules: Seq[Rule[LogicalPlan]], + initial: Expression, + transformed: Expression): Unit = { + val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) + val analyzer = new RuleExecutor[LogicalPlan] { + override val batches = Seq(Batch("Resolution", FixedPoint(3), rules: _*)) } + + comparePlans( + analyzer.execute(Project(Seq(Alias(initial, "a")()), testRelation)), + Project(Seq(Alias(transformed, "a")()), testRelation)) } test("implicit type cast - ByteType") { @@ -209,16 +187,6 @@ class TypeCoercionSuite extends AnalysisTest { shouldNotCast(checkedType, IntegralType) } - test("implicit type cast - StringType") { - val checkedType = StringType - val nonCastableTypes = - complexTypes ++ Seq(BooleanType, NullType, CalendarIntervalType) - checkTypeCasting(checkedType, castableTypes = allTypes.filterNot(nonCastableTypes.contains)) - shouldCast(checkedType, DecimalType, DecimalType.SYSTEM_DEFAULT) - shouldCast(checkedType, NumericType, NumericType.defaultConcreteType) - shouldNotCast(checkedType, IntegralType) - } - test("implicit type cast - DateType") { val checkedType = DateType checkTypeCasting(checkedType, castableTypes = Seq(checkedType, StringType, TimestampType)) @@ -235,20 +203,6 @@ class TypeCoercionSuite extends AnalysisTest { shouldNotCast(checkedType, IntegralType) } - test("implicit type cast - ArrayType(StringType)") { - val checkedType = ArrayType(StringType) - val nonCastableTypes = - complexTypes ++ Seq(BooleanType, NullType, CalendarIntervalType) - checkTypeCasting(checkedType, - castableTypes = allTypes.filterNot(nonCastableTypes.contains).map(ArrayType(_))) - nonCastableTypes.map(ArrayType(_)).foreach(shouldNotCast(checkedType, _)) - shouldNotCast(ArrayType(DoubleType, containsNull = false), - ArrayType(LongType, containsNull = false)) - shouldNotCast(checkedType, DecimalType) - shouldNotCast(checkedType, NumericType) - shouldNotCast(checkedType, IntegralType) - } - test("implicit type cast between two Map types") { val sourceType = MapType(IntegerType, IntegerType, true) val castableTypes = numericTypes ++ Seq(StringType).filter(!Cast.forceNullable(IntegerType, _)) @@ -288,30 +242,6 @@ class TypeCoercionSuite extends AnalysisTest { } } - test("implicit type cast - StructType().add(\"a1\", StringType)") { - val checkedType = new StructType().add("a1", StringType) - checkTypeCasting(checkedType, castableTypes = Seq(checkedType)) - shouldNotCast(checkedType, DecimalType) - 
shouldNotCast(checkedType, NumericType) - shouldNotCast(checkedType, IntegralType) - } - - test("implicit type cast - NullType") { - val checkedType = NullType - checkTypeCasting(checkedType, castableTypes = allTypes) - shouldCast(checkedType, DecimalType, DecimalType.SYSTEM_DEFAULT) - shouldCast(checkedType, NumericType, NumericType.defaultConcreteType) - shouldCast(checkedType, IntegralType, IntegralType.defaultConcreteType) - } - - test("implicit type cast - CalendarIntervalType") { - val checkedType = CalendarIntervalType - checkTypeCasting(checkedType, castableTypes = Seq(checkedType)) - shouldNotCast(checkedType, DecimalType) - shouldNotCast(checkedType, NumericType) - shouldNotCast(checkedType, IntegralType) - } - test("eligible implicit type cast - TypeCollection") { shouldCast(NullType, TypeCollection(StringType, BinaryType), StringType) @@ -333,8 +263,6 @@ class TypeCoercionSuite extends AnalysisTest { shouldCast(DecimalType(10, 2), TypeCollection(DecimalType, IntegerType), DecimalType(10, 2)) shouldCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType), DecimalType(10, 2)) - shouldCast(StringType, TypeCollection(NumericType, BinaryType), DoubleType) - shouldCast( ArrayType(StringType, false), TypeCollection(ArrayType(StringType), StringType), @@ -350,6 +278,249 @@ class TypeCoercionSuite extends AnalysisTest { shouldNotCast(IntegerType, TypeCollection(DateType, TimestampType)) } + test("type coercion for Concat") { + val rule = TypeCoercion.ConcatCoercion + + ruleTest(rule, + Concat(Seq(Literal("ab"), Literal("cde"))), + Concat(Seq(Literal("ab"), Literal("cde")))) + ruleTest(rule, + Concat(Seq(Literal(null), Literal("abc"))), + Concat(Seq(Cast(Literal(null), StringType), Literal("abc")))) + ruleTest(rule, + Concat(Seq(Literal(1), Literal("234"))), + Concat(Seq(Cast(Literal(1), StringType), Literal("234")))) + ruleTest(rule, + Concat(Seq(Literal("1"), Literal("234".getBytes()))), + Concat(Seq(Literal("1"), Cast(Literal("234".getBytes()), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(1L), Literal(2.toByte), Literal(0.1))), + Concat(Seq(Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType), + Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(true), Literal(0.1f), Literal(3.toShort))), + Concat(Seq(Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType), + Cast(Literal(3.toShort), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(1L), Literal(0.1))), + Concat(Seq(Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(Decimal(10)))), + Concat(Seq(Cast(Literal(Decimal(10)), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(BigDecimal.valueOf(10)))), + Concat(Seq(Cast(Literal(BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(java.math.BigDecimal.valueOf(10)))), + Concat(Seq(Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Concat(Seq(Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))), + Concat(Seq(Cast(Literal(new java.sql.Date(0)), StringType), + Cast(Literal(new Timestamp(0)), StringType)))) + + withSQLConf(SQLConf.CONCAT_BINARY_AS_STRING.key -> "true") { + ruleTest(rule, + Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))), + Concat(Seq(Cast(Literal("123".getBytes), StringType), + Cast(Literal("456".getBytes), StringType)))) + } + + withSQLConf(SQLConf.CONCAT_BINARY_AS_STRING.key -> "false") { + ruleTest(rule, + Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))), + 
Concat(Seq(Literal("123".getBytes), Literal("456".getBytes)))) + } + } + + test("type coercion for Elt") { + val rule = TypeCoercion.EltCoercion + + ruleTest(rule, + Elt(Seq(Literal(1), Literal("ab"), Literal("cde"))), + Elt(Seq(Literal(1), Literal("ab"), Literal("cde")))) + ruleTest(rule, + Elt(Seq(Literal(1.toShort), Literal("ab"), Literal("cde"))), + Elt(Seq(Cast(Literal(1.toShort), IntegerType), Literal("ab"), Literal("cde")))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(null), Literal("abc"))), + Elt(Seq(Literal(2), Cast(Literal(null), StringType), Literal("abc")))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(1), Literal("234"))), + Elt(Seq(Literal(2), Cast(Literal(1), StringType), Literal("234")))) + ruleTest(rule, + Elt(Seq(Literal(3), Literal(1L), Literal(2.toByte), Literal(0.1))), + Elt(Seq(Literal(3), Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType), + Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(true), Literal(0.1f), Literal(3.toShort))), + Elt(Seq(Literal(2), Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType), + Cast(Literal(3.toShort), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(1L), Literal(0.1))), + Elt(Seq(Literal(1), Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(Decimal(10)))), + Elt(Seq(Literal(1), Cast(Literal(Decimal(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(BigDecimal.valueOf(10)))), + Elt(Seq(Literal(1), Cast(Literal(BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(java.math.BigDecimal.valueOf(10)))), + Elt(Seq(Literal(1), Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))), + Elt(Seq(Literal(2), Cast(Literal(new java.sql.Date(0)), StringType), + Cast(Literal(new Timestamp(0)), StringType)))) + + withSQLConf(SQLConf.ELT_OUTPUT_AS_STRING.key -> "true") { + ruleTest(rule, + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))), + Elt(Seq(Literal(1), Cast(Literal("123".getBytes), StringType), + Cast(Literal("456".getBytes), StringType)))) + } + + withSQLConf(SQLConf.ELT_OUTPUT_AS_STRING.key -> "false") { + ruleTest(rule, + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))), + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes)))) + } + } + + test("Datetime operations") { + val rule = dateTimeOperationsRule + val dateLiteral = Literal(java.sql.Date.valueOf("2021-01-01")) + val timestampLiteral = Literal(Timestamp.valueOf("2021-01-01 00:00:00")) + val timestampNTZLiteral = Literal(LocalDateTime.parse("2021-01-01T00:00:00")) + val intLiteral = Literal(3) + Seq(timestampLiteral, timestampNTZLiteral).foreach { tsLiteral => + ruleTest(rule, + DateAdd(tsLiteral, intLiteral), + DateAdd(Cast(tsLiteral, DateType), intLiteral)) + ruleTest(rule, + DateSub(tsLiteral, intLiteral), + DateSub(Cast(tsLiteral, DateType), intLiteral)) + ruleTest(rule, + SubtractTimestamps(tsLiteral, dateLiteral), + SubtractTimestamps(tsLiteral, Cast(dateLiteral, tsLiteral.dataType))) + ruleTest(rule, + SubtractTimestamps(dateLiteral, tsLiteral), + SubtractTimestamps(Cast(dateLiteral, tsLiteral.dataType), tsLiteral)) + } + + ruleTest(rule, + SubtractTimestamps(timestampLiteral, timestampNTZLiteral), + SubtractTimestamps(Cast(timestampLiteral, TimestampNTZType), timestampNTZLiteral)) + ruleTest(rule, + 
SubtractTimestamps(timestampNTZLiteral, timestampLiteral), + SubtractTimestamps(timestampNTZLiteral, Cast(timestampLiteral, TimestampNTZType))) + } + +} + +class TypeCoercionSuite extends TypeCoercionSuiteBase { + import TypeCoercionSuite._ + + // scalastyle:off line.size.limit + // The following table shows all implicit data type conversions that are not visible to the user. + // +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+ + // | Source Type\CAST TO | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType | NumericType | IntegralType | + // +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+ + // | ByteType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(3, 0) | ByteType | ByteType | + // | ShortType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(5, 0) | ShortType | ShortType | + // | IntegerType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(10, 0) | IntegerType | IntegerType | + // | LongType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(20, 0) | LongType | LongType | + // | DoubleType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(30, 15) | DoubleType | IntegerType | + // | FloatType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(14, 7) | FloatType | IntegerType | + // | Dec(10, 2) | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | X | X | StringType | X | X | X | X | X | X | X | DecimalType(10, 2) | Dec(10, 2) | IntegerType | + // | BinaryType | X | X | X | X | X | X | X | BinaryType | X | StringType | X | X | X | X | X | X | X | X | X | X | + // | BooleanType | X | X | X | X | X | X | X | X | BooleanType | StringType | X | X | X | X | X | X | X | X | X | X | + // | StringType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | X | StringType | DateType | TimestampType | X | X | X | X | X | DecimalType(38, 18) | DoubleType | X | + // | DateType | X | X | X | X | X | X | X | X | X | StringType | DateType | TimestampType | X | X | X | X | X | X | X | X | + // | TimestampType | X | X | X | X | X | X | X | X | X | StringType | DateType | TimestampType | X | X | X | X | X | X | X | X | + // | ArrayType | X | X | X | X | X | X | X | X | X | X | X | X | ArrayType* | X | X | X | X | X | X | X | + // | MapType | X | X | X | X | X | X | X 
| X | X | X | X | X | X | MapType* | X | X | X | X | X | X | + // | StructType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | StructType* | X | X | X | X | X | + // | NullType | ByteType | ShortType | IntegerType | LongType | DoubleType | FloatType | Dec(10, 2) | BinaryType | BooleanType | StringType | DateType | TimestampType | ArrayType | MapType | StructType | NullType | CalendarIntervalType | DecimalType(38, 18) | DoubleType | IntegerType | + // | CalendarIntervalType | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | X | CalendarIntervalType | X | X | X | + // +----------------------+----------+-----------+-------------+----------+------------+-----------+------------+------------+-------------+------------+----------+---------------+------------+----------+-------------+----------+----------------------+---------------------+-------------+--------------+ + // Note: StructType* is castable when all the internal child types are castable according to the table. + // Note: ArrayType* is castable when the element type is castable according to the table. + // Note: MapType* is castable when both the key type and the value type are castable according to the table. + // scalastyle:on line.size.limit + override def implicitCast(e: Expression, expectedType: AbstractDataType): Option[Expression] = + TypeCoercion.implicitCast(e, expectedType) + + override def dateTimeOperationsRule: TypeCoercionRule = TypeCoercion.DateTimeOperations + + private def checkWidenType( + widenFunc: (DataType, DataType) => Option[DataType], + t1: DataType, + t2: DataType, + expected: Option[DataType], + isSymmetric: Boolean = true): Unit = { + var found = widenFunc(t1, t2) + assert(found == expected, + s"Expected $expected as wider common type for $t1 and $t2, found $found") + // Test both directions to make sure the widening is symmetric. 
+ if (isSymmetric) { + found = widenFunc(t2, t1) + assert(found == expected, + s"Expected $expected as wider common type for $t2 and $t1, found $found") + } + } + + test("implicit type cast - StringType") { + val checkedType = StringType + val nonCastableTypes = + complexTypes ++ Seq(BooleanType, NullType, CalendarIntervalType) + checkTypeCasting(checkedType, castableTypes = allTypes.filterNot(nonCastableTypes.contains)) + shouldCast(checkedType, DecimalType, DecimalType.SYSTEM_DEFAULT) + shouldCast(checkedType, NumericType, NumericType.defaultConcreteType) + shouldNotCast(checkedType, IntegralType) + } + + test("implicit type cast - ArrayType(StringType)") { + val checkedType = ArrayType(StringType) + val nonCastableTypes = + complexTypes ++ Seq(BooleanType, NullType, CalendarIntervalType) + checkTypeCasting(checkedType, + castableTypes = allTypes.filterNot(nonCastableTypes.contains).map(ArrayType(_))) + nonCastableTypes.map(ArrayType(_)).foreach(shouldNotCast(checkedType, _)) + shouldNotCast(ArrayType(DoubleType, containsNull = false), + ArrayType(LongType, containsNull = false)) + shouldNotCast(checkedType, DecimalType) + shouldNotCast(checkedType, NumericType) + shouldNotCast(checkedType, IntegralType) + } + + test("implicit type cast - StructType().add(\"a1\", StringType)") { + val checkedType = new StructType().add("a1", StringType) + checkTypeCasting(checkedType, castableTypes = Seq(checkedType)) + shouldNotCast(checkedType, DecimalType) + shouldNotCast(checkedType, NumericType) + shouldNotCast(checkedType, IntegralType) + } + + test("implicit type cast - NullType") { + val checkedType = NullType + checkTypeCasting(checkedType, castableTypes = allTypes) + shouldCast(checkedType, DecimalType, DecimalType.SYSTEM_DEFAULT) + shouldCast(checkedType, NumericType, NumericType.defaultConcreteType) + shouldCast(checkedType, IntegralType, IntegralType.defaultConcreteType) + } + + test("implicit type cast - CalendarIntervalType") { + val checkedType = CalendarIntervalType + checkTypeCasting(checkedType, castableTypes = Seq(checkedType)) + shouldNotCast(checkedType, DecimalType) + shouldNotCast(checkedType, NumericType) + shouldNotCast(checkedType, IntegralType) + } + + test("eligible implicit type cast - TypeCollection II") { + shouldCast(StringType, TypeCollection(NumericType, BinaryType), DoubleType) + } + test("tightest common bound for types") { def widenTest(t1: DataType, t2: DataType, expected: Option[DataType]): Unit = checkWidenType(TypeCoercion.findTightestCommonType, t1, t2, expected) @@ -717,25 +888,6 @@ class TypeCoercionSuite extends AnalysisTest { Some(new StructType().add("a", StringType))) } - private def ruleTest(rule: Rule[LogicalPlan], - initial: Expression, transformed: Expression): Unit = { - ruleTest(Seq(rule), initial, transformed) - } - - private def ruleTest( - rules: Seq[Rule[LogicalPlan]], - initial: Expression, - transformed: Expression): Unit = { - val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) - val analyzer = new RuleExecutor[LogicalPlan] { - override val batches = Seq(Batch("Resolution", FixedPoint(3), rules: _*)) - } - - comparePlans( - analyzer.execute(Project(Seq(Alias(initial, "a")()), testRelation)), - Project(Seq(Alias(transformed, "a")()), testRelation)) - } - test("cast NullType for expressions that implement ExpectsInputTypes") { ruleTest(TypeCoercion.ImplicitTypeCasts, AnyTypeUnaryExpression(Literal.create(null, NullType)), @@ -1110,114 +1262,6 @@ class TypeCoercionSuite extends AnalysisTest { Literal.create(null, 
IntegerType), Literal.create(null, StringType)))) } - test("type coercion for Concat") { - val rule = TypeCoercion.ConcatCoercion - - ruleTest(rule, - Concat(Seq(Literal("ab"), Literal("cde"))), - Concat(Seq(Literal("ab"), Literal("cde")))) - ruleTest(rule, - Concat(Seq(Literal(null), Literal("abc"))), - Concat(Seq(Cast(Literal(null), StringType), Literal("abc")))) - ruleTest(rule, - Concat(Seq(Literal(1), Literal("234"))), - Concat(Seq(Cast(Literal(1), StringType), Literal("234")))) - ruleTest(rule, - Concat(Seq(Literal("1"), Literal("234".getBytes()))), - Concat(Seq(Literal("1"), Cast(Literal("234".getBytes()), StringType)))) - ruleTest(rule, - Concat(Seq(Literal(1L), Literal(2.toByte), Literal(0.1))), - Concat(Seq(Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType), - Cast(Literal(0.1), StringType)))) - ruleTest(rule, - Concat(Seq(Literal(true), Literal(0.1f), Literal(3.toShort))), - Concat(Seq(Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType), - Cast(Literal(3.toShort), StringType)))) - ruleTest(rule, - Concat(Seq(Literal(1L), Literal(0.1))), - Concat(Seq(Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType)))) - ruleTest(rule, - Concat(Seq(Literal(Decimal(10)))), - Concat(Seq(Cast(Literal(Decimal(10)), StringType)))) - ruleTest(rule, - Concat(Seq(Literal(BigDecimal.valueOf(10)))), - Concat(Seq(Cast(Literal(BigDecimal.valueOf(10)), StringType)))) - ruleTest(rule, - Concat(Seq(Literal(java.math.BigDecimal.valueOf(10)))), - Concat(Seq(Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType)))) - ruleTest(rule, - Concat(Seq(Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))), - Concat(Seq(Cast(Literal(new java.sql.Date(0)), StringType), - Cast(Literal(new Timestamp(0)), StringType)))) - - withSQLConf(SQLConf.CONCAT_BINARY_AS_STRING.key -> "true") { - ruleTest(rule, - Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))), - Concat(Seq(Cast(Literal("123".getBytes), StringType), - Cast(Literal("456".getBytes), StringType)))) - } - - withSQLConf(SQLConf.CONCAT_BINARY_AS_STRING.key -> "false") { - ruleTest(rule, - Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))), - Concat(Seq(Literal("123".getBytes), Literal("456".getBytes)))) - } - } - - test("type coercion for Elt") { - val rule = TypeCoercion.EltCoercion - - ruleTest(rule, - Elt(Seq(Literal(1), Literal("ab"), Literal("cde"))), - Elt(Seq(Literal(1), Literal("ab"), Literal("cde")))) - ruleTest(rule, - Elt(Seq(Literal(1.toShort), Literal("ab"), Literal("cde"))), - Elt(Seq(Cast(Literal(1.toShort), IntegerType), Literal("ab"), Literal("cde")))) - ruleTest(rule, - Elt(Seq(Literal(2), Literal(null), Literal("abc"))), - Elt(Seq(Literal(2), Cast(Literal(null), StringType), Literal("abc")))) - ruleTest(rule, - Elt(Seq(Literal(2), Literal(1), Literal("234"))), - Elt(Seq(Literal(2), Cast(Literal(1), StringType), Literal("234")))) - ruleTest(rule, - Elt(Seq(Literal(3), Literal(1L), Literal(2.toByte), Literal(0.1))), - Elt(Seq(Literal(3), Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType), - Cast(Literal(0.1), StringType)))) - ruleTest(rule, - Elt(Seq(Literal(2), Literal(true), Literal(0.1f), Literal(3.toShort))), - Elt(Seq(Literal(2), Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType), - Cast(Literal(3.toShort), StringType)))) - ruleTest(rule, - Elt(Seq(Literal(1), Literal(1L), Literal(0.1))), - Elt(Seq(Literal(1), Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType)))) - ruleTest(rule, - Elt(Seq(Literal(1), Literal(Decimal(10)))), - 
Elt(Seq(Literal(1), Cast(Literal(Decimal(10)), StringType)))) - ruleTest(rule, - Elt(Seq(Literal(1), Literal(BigDecimal.valueOf(10)))), - Elt(Seq(Literal(1), Cast(Literal(BigDecimal.valueOf(10)), StringType)))) - ruleTest(rule, - Elt(Seq(Literal(1), Literal(java.math.BigDecimal.valueOf(10)))), - Elt(Seq(Literal(1), Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType)))) - ruleTest(rule, - Elt(Seq(Literal(2), Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))), - Elt(Seq(Literal(2), Cast(Literal(new java.sql.Date(0)), StringType), - Cast(Literal(new Timestamp(0)), StringType)))) - - withSQLConf(SQLConf.ELT_OUTPUT_AS_STRING.key -> "true") { - ruleTest(rule, - Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))), - Elt(Seq(Literal(1), Cast(Literal("123".getBytes), StringType), - Cast(Literal("456".getBytes), StringType)))) - } - - withSQLConf(SQLConf.ELT_OUTPUT_AS_STRING.key -> "false") { - ruleTest(rule, - Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))), - Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes)))) - } - } - test("BooleanEquality type cast") { val be = TypeCoercion.BooleanEquality // Use something more than a literal to avoid triggering the simplification rules. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index bd133e7578..ea0d619ad4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -760,4 +760,32 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(WidthBucket(Literal(v), Literal(s), Literal(e), Literal(n)), expected) } } + + test("SPARK-37388: width_bucket") { + val nullDouble = Literal.create(null, DoubleType) + val nullLong = Literal.create(null, LongType) + + checkEvaluation(WidthBucket(5.35, 0.024, 10.06, 5L), 3L) + checkEvaluation(WidthBucket(-2.1, 1.3, 3.4, 3L), 0L) + checkEvaluation(WidthBucket(8.1, 0.0, 5.7, 4L), 5L) + checkEvaluation(WidthBucket(-0.9, 5.2, 0.5, 2L), 3L) + checkEvaluation(WidthBucket(nullDouble, 0.024, 10.06, 5L), null) + checkEvaluation(WidthBucket(5.35, nullDouble, 10.06, 5L), null) + checkEvaluation(WidthBucket(5.35, 0.024, nullDouble, 5L), null) + checkEvaluation(WidthBucket(5.35, nullDouble, nullDouble, 5L), null) + checkEvaluation(WidthBucket(5.35, 0.024, 10.06, nullLong), null) + checkEvaluation(WidthBucket(nullDouble, nullDouble, nullDouble, nullLong), null) + checkEvaluation(WidthBucket(5.35, 0.024, 10.06, -5L), null) + checkEvaluation(WidthBucket(5.35, 0.024, 10.06, Long.MaxValue), null) + checkEvaluation(WidthBucket(Double.NaN, 0.024, 10.06, 5L), null) + checkEvaluation(WidthBucket(Double.NegativeInfinity, 0.024, 10.06, 5L), 0L) + checkEvaluation(WidthBucket(Double.PositiveInfinity, 0.024, 10.06, 5L), 6L) + checkEvaluation(WidthBucket(5.35, 0.024, 0.024, 5L), null) + checkEvaluation(WidthBucket(5.35, Double.NaN, 10.06, 5L), null) + checkEvaluation(WidthBucket(5.35, Double.NegativeInfinity, 10.06, 5L), null) + checkEvaluation(WidthBucket(5.35, Double.PositiveInfinity, 10.06, 5L), null) + checkEvaluation(WidthBucket(5.35, 0.024, Double.NaN, 5L), null) + checkEvaluation(WidthBucket(5.35, 0.024, Double.NegativeInfinity, 5L), null) + checkEvaluation(WidthBucket(5.35, 0.024, Double.PositiveInfinity, 5L), null) 
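+    // Summary of the edge cases exercised above: out-of-range values clamp to bucket 0 or
+    // numBuckets + 1 (for ascending and descending bounds alike); a null or NaN operand,
+    // non-finite or equal bounds, and a non-positive or overflowing bucket count yield null,
+    // while an infinite input value still clamps to 0 or numBuckets + 1.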
+ } } diff --git a/sql/core/benchmarks/StateStoreBasicOperationsBenchmark-results.txt b/sql/core/benchmarks/StateStoreBasicOperationsBenchmark-results.txt new file mode 100644 index 0000000000..1e6c85a126 --- /dev/null +++ b/sql/core/benchmarks/StateStoreBasicOperationsBenchmark-results.txt @@ -0,0 +1,183 @@ +================================================================================================ +put rows +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8175M CPU @ 2.50GHz +putting 10000 rows (10000 rows to overwrite - rate 100): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------------------- +In-memory 6 9 1 1.6 614.4 1.0X +RocksDB (trackTotalNumberOfRows: true) 51 55 1 0.2 5147.7 0.1X +RocksDB (trackTotalNumberOfRows: false) 13 16 1 0.8 1295.6 0.5X + +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8175M CPU @ 2.50GHz +putting 10000 rows (7500 rows to overwrite - rate 75): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------------- +In-memory 7 10 1 1.5 650.0 1.0X +RocksDB (trackTotalNumberOfRows: true) 48 52 1 0.2 4821.3 0.1X +RocksDB (trackTotalNumberOfRows: false) 13 16 1 0.8 1273.7 0.5X + +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8175M CPU @ 2.50GHz +putting 10000 rows (5000 rows to overwrite - rate 50): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------------- +In-memory 6 9 1 1.6 618.2 1.0X +RocksDB (trackTotalNumberOfRows: true) 44 48 1 0.2 4398.4 0.1X +RocksDB (trackTotalNumberOfRows: false) 13 16 1 0.8 1308.2 0.5X + +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8175M CPU @ 2.50GHz +putting 10000 rows (2500 rows to overwrite - rate 25): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------------- +In-memory 6 8 1 1.7 604.7 1.0X +RocksDB (trackTotalNumberOfRows: true) 41 44 1 0.2 4082.0 0.1X +RocksDB (trackTotalNumberOfRows: false) 13 16 1 0.8 1283.1 0.5X + +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8175M CPU @ 2.50GHz +putting 10000 rows (1000 rows to overwrite - rate 10): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------------- +In-memory 6 8 1 1.7 579.8 1.0X +RocksDB (trackTotalNumberOfRows: true) 39 42 1 0.3 3876.7 0.1X +RocksDB (trackTotalNumberOfRows: false) 13 16 1 0.7 1344.4 0.4X + +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8175M CPU @ 2.50GHz +putting 10000 rows (500 rows to overwrite - rate 5): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per 
Row(ns) Relative +----------------------------------------------------------------------------------------------------------------------------------- +In-memory 6 7 1 1.7 573.9 1.0X +RocksDB (trackTotalNumberOfRows: true) 37 40 1 0.3 3696.6 0.2X +RocksDB (trackTotalNumberOfRows: false) 12 14 1 0.8 1229.4 0.5X + +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8175M CPU @ 2.50GHz +putting 10000 rows (0 rows to overwrite - rate 0): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------------- +In-memory 5 7 1 1.9 533.6 1.0X +RocksDB (trackTotalNumberOfRows: true) 35 37 1 0.3 3492.5 0.2X +RocksDB (trackTotalNumberOfRows: false) 13 14 0 0.8 1264.3 0.4X + + +================================================================================================ +delete rows +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8175M CPU @ 2.50GHz +trying to delete 10000 rows from 10000 rows(10000 rows are non-existing - rate 100): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------------------------------------------- +In-memory 1 1 0 14.9 67.3 1.0X +RocksDB (trackTotalNumberOfRows: true) 35 36 1 0.3 3493.2 0.0X +RocksDB (trackTotalNumberOfRows: false) 11 13 0 0.9 1129.2 0.1X + +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8175M CPU @ 2.50GHz +trying to delete 10000 rows from 10000 rows(7500 rows are non-existing - rate 75): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +----------------------------------------------------------------------------------------------------------------------------------------------------------------- +In-memory 4 5 0 2.8 351.7 1.0X +RocksDB (trackTotalNumberOfRows: true) 38 41 1 0.3 3832.6 0.1X +RocksDB (trackTotalNumberOfRows: false) 11 13 0 0.9 1131.1 0.3X + +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8175M CPU @ 2.50GHz +trying to delete 10000 rows from 10000 rows(5000 rows are non-existing - rate 50): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +----------------------------------------------------------------------------------------------------------------------------------------------------------------- +In-memory 4 6 1 2.5 399.2 1.0X +RocksDB (trackTotalNumberOfRows: true) 42 45 1 0.2 4198.3 0.1X +RocksDB (trackTotalNumberOfRows: false) 11 13 0 0.9 1127.1 0.4X + +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8175M CPU @ 2.50GHz +trying to delete 10000 rows from 10000 rows(2500 rows are non-existing - rate 25): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +----------------------------------------------------------------------------------------------------------------------------------------------------------------- +In-memory 5 6 1 2.2 452.7 1.0X +RocksDB (trackTotalNumberOfRows: true) 45 48 1 0.2 4515.7 0.1X +RocksDB (trackTotalNumberOfRows: false) 11 13 0 0.9 1127.5 0.4X + +OpenJDK 64-Bit 
Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8175M CPU @ 2.50GHz +trying to delete 10000 rows from 10000 rows(1000 rows are non-existing - rate 10): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +----------------------------------------------------------------------------------------------------------------------------------------------------------------- +In-memory 5 7 1 2.1 476.3 1.0X +RocksDB (trackTotalNumberOfRows: true) 47 49 1 0.2 4683.2 0.1X +RocksDB (trackTotalNumberOfRows: false) 11 13 0 0.9 1136.9 0.4X + +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8175M CPU @ 2.50GHz +trying to delete 10000 rows from 10000 rows(500 rows are non-existing - rate 5): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +--------------------------------------------------------------------------------------------------------------------------------------------------------------- +In-memory 5 7 1 2.1 478.2 1.0X +RocksDB (trackTotalNumberOfRows: true) 47 50 1 0.2 4727.8 0.1X +RocksDB (trackTotalNumberOfRows: false) 11 13 0 0.9 1122.8 0.4X + +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8175M CPU @ 2.50GHz +trying to delete 10000 rows from 10000 rows(0 rows are non-existing - rate 0): Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------------------------------------- +In-memory 5 7 1 2.1 478.7 1.0X +RocksDB (trackTotalNumberOfRows: true) 48 51 1 0.2 4791.4 0.1X +RocksDB (trackTotalNumberOfRows: false) 11 13 0 0.9 1125.6 0.4X + + +================================================================================================ +evict rows +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8175M CPU @ 2.50GHz +evicting 10000 rows (maxTimestampToEvictInMillis: 9999) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------------------------------- +In-memory 5 5 0 2.2 459.1 1.0X +RocksDB (trackTotalNumberOfRows: true) 42 44 1 0.2 4199.9 0.1X +RocksDB (trackTotalNumberOfRows: false) 9 10 0 1.1 943.3 0.5X + +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8175M CPU @ 2.50GHz +evicting 7500 rows (maxTimestampToEvictInMillis: 7499) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------------------------------ +In-memory 4 5 0 2.5 398.0 1.0X +RocksDB (trackTotalNumberOfRows: true) 33 34 1 0.3 3280.3 0.1X +RocksDB (trackTotalNumberOfRows: false) 8 8 0 1.3 770.2 0.5X + +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8175M CPU @ 2.50GHz +evicting 5000 rows (maxTimestampToEvictInMillis: 4999) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
+------------------------------------------------------------------------------------------------------------------------------------------------------ +In-memory 4 4 0 2.8 359.9 1.0X +RocksDB (trackTotalNumberOfRows: true) 23 24 0 0.4 2293.8 0.2X +RocksDB (trackTotalNumberOfRows: false) 6 7 0 1.6 629.2 0.6X + +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8175M CPU @ 2.50GHz +evicting 2500 rows (maxTimestampToEvictInMillis: 2499) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------------------------------ +In-memory 3 4 0 3.1 321.8 1.0X +RocksDB (trackTotalNumberOfRows: true) 13 14 0 0.8 1310.8 0.2X +RocksDB (trackTotalNumberOfRows: false) 5 5 0 2.1 481.3 0.7X + +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8175M CPU @ 2.50GHz +evicting 1000 rows (maxTimestampToEvictInMillis: 999) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +----------------------------------------------------------------------------------------------------------------------------------------------------- +In-memory 3 4 0 3.4 291.6 1.0X +RocksDB (trackTotalNumberOfRows: true) 7 7 0 1.4 715.7 0.4X +RocksDB (trackTotalNumberOfRows: false) 4 4 0 2.5 394.2 0.7X + +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8175M CPU @ 2.50GHz +evicting 500 rows (maxTimestampToEvictInMillis: 499) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +---------------------------------------------------------------------------------------------------------------------------------------------------- +In-memory 3 3 0 3.4 295.8 1.0X +RocksDB (trackTotalNumberOfRows: true) 5 5 0 1.9 531.3 0.6X +RocksDB (trackTotalNumberOfRows: false) 4 4 0 2.7 366.8 0.8X + +OpenJDK 64-Bit Server VM 1.8.0_292-8u292-b10-0ubuntu1~18.04-b10 on Linux 5.4.0-1045-aws +Intel(R) Xeon(R) Platinum 8175M CPU @ 2.50GHz +evicting 0 rows (maxTimestampToEvictInMillis: -1) from 10000 rows: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------------------------- +In-memory 1 1 0 17.0 58.9 1.0X +RocksDB (trackTotalNumberOfRows: true) 3 3 0 3.0 336.7 0.2X +RocksDB (trackTotalNumberOfRows: false) 3 3 0 3.0 335.9 0.2X + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index e5be7f4906..c55bdcabef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -401,15 +401,13 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) throw QueryCompilationErrors.externalCatalogNotSupportShowViewsError(resolved) } - case s @ ShowTableProperties(ResolvedV1TableOrViewIdentifier(ident), propertyKey, output) => - val newOutput = - if (conf.getConf(SQLConf.LEGACY_KEEP_COMMAND_OUTPUT_SCHEMA) && propertyKey.isDefined) { - assert(output.length == 2) - output.tail - } else { - output - } - 
ShowTablePropertiesCommand(ident.asTableIdentifier, propertyKey, newOutput) + // If target is view, force use v1 command + case s @ ShowTableProperties(ResolvedViewIdentifier(ident), propertyKey, output) => + ShowTablePropertiesCommand(ident.asTableIdentifier, propertyKey, output) + + case s @ ShowTableProperties(ResolvedV1TableIdentifier(ident), propertyKey, output) + if conf.useV1Command => + ShowTablePropertiesCommand(ident.asTableIdentifier, propertyKey, output) case DescribeFunction(ResolvedFunc(identifier), extended) => DescribeFunctionCommand(identifier.asFunctionIdentifier, extended) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 6ab9ff1b67..2a6bad6366 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -896,7 +896,7 @@ case class ShowTablePropertiesCommand( } case None => catalogTable.properties.filterKeys(!_.startsWith(CatalogTable.VIEW_PREFIX)) - .map(p => Row(p._1, p._2)).toSeq + .toSeq.sortBy(_._1).map(p => Row(p._1, p._2)).toSeq } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 8e02fc3857..0d85a45dbf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.execution.datasources.orc -import java.time.{Duration, Instant, LocalDate, Period} +import java.sql.Timestamp +import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period} import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgument} @@ -25,7 +26,7 @@ import org.apache.hadoop.hive.ql.io.sarg.SearchArgument.Builder import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable -import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp} +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateTimeToMicros, localDateToDays, toJavaDate, toJavaTimestamp} import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf @@ -146,7 +147,7 @@ private[sql] object OrcFilters extends OrcFiltersBase { case FloatType | DoubleType => PredicateLeaf.Type.FLOAT case StringType => PredicateLeaf.Type.STRING case DateType => PredicateLeaf.Type.DATE - case TimestampType => PredicateLeaf.Type.TIMESTAMP + case TimestampType | TimestampNTZType => PredicateLeaf.Type.TIMESTAMP case _: DecimalType => PredicateLeaf.Type.DECIMAL case _ => throw QueryExecutionErrors.unsupportedOperationForDataTypeError(dataType) } @@ -168,6 +169,12 @@ private[sql] object OrcFilters extends OrcFiltersBase { toJavaDate(localDateToDays(value.asInstanceOf[LocalDate])) case _: TimestampType if value.isInstanceOf[Instant] => toJavaTimestamp(instantToMicros(value.asInstanceOf[Instant])) + case _: TimestampNTZType if value.isInstanceOf[LocalDateTime] => + val orcTimestamp = OrcUtils.toOrcNTZ(localDateTimeToMicros(value.asInstanceOf[LocalDateTime])) + // Hive meets OrcTimestamp will 
throw ClassNotFoundException, So convert it. + val timestamp = new Timestamp(orcTimestamp.getTime) + timestamp.setNanos(orcTimestamp.getNanos) + timestamp case _: YearMonthIntervalType => IntervalUtils.periodToMonths(value.asInstanceOf[Period]).longValue() case _: DayTimeIntervalType => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablePropertiesExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablePropertiesExec.scala index 3a7525994c..0d165d4dd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablePropertiesExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablePropertiesExec.scala @@ -39,7 +39,11 @@ case class ShowTablePropertiesExec( case Some(p) => val propValue = properties .getOrElse(p, s"Table ${catalogTable.name} does not have property: $p") - Seq(toCatalystRow(p, propValue)) + if (output.length == 1) { + Seq(toCatalystRow(propValue)) + } else { + Seq(toCatalystRow(p, propValue)) + } case None => properties.toSeq.sortBy(_._1).map(kv => toCatalystRow(kv._1, kv._2)).toSeq diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out index ac64bf6c05..b95c8dac9a 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/date.sql.out @@ -219,10 +219,9 @@ struct -- !query select next_day(timestamp_ltz"2015-07-23 12:12:12", "Mon") -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'next_day(TIMESTAMP '2015-07-23 12:12:12', 'Mon')' due to data type mismatch: argument 1 requires date type, however, 'TIMESTAMP '2015-07-23 12:12:12'' is of timestamp type.; line 1 pos 7 +2015-07-27 -- !query @@ -354,19 +353,17 @@ NULL -- !query select date_add(timestamp_ltz'2011-11-11 12:12:12', 1) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'date_add(TIMESTAMP '2011-11-11 12:12:12', 1)' due to data type mismatch: argument 1 requires date type, however, 'TIMESTAMP '2011-11-11 12:12:12'' is of timestamp type.; line 1 pos 7 +2011-11-12 -- !query select date_add(timestamp_ntz'2011-11-11 12:12:12', 1) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'date_add(TIMESTAMP_NTZ '2011-11-11 12:12:12', 1)' due to data type mismatch: argument 1 requires date type, however, 'TIMESTAMP_NTZ '2011-11-11 12:12:12'' is of timestamp_ntz type.; line 1 pos 7 +2011-11-12 -- !query @@ -464,19 +461,17 @@ NULL -- !query select date_sub(timestamp_ltz'2011-11-11 12:12:12', 1) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'date_sub(TIMESTAMP '2011-11-11 12:12:12', 1)' due to data type mismatch: argument 1 requires date type, however, 'TIMESTAMP '2011-11-11 12:12:12'' is of timestamp type.; line 1 pos 7 +2011-11-10 -- !query select date_sub(timestamp_ntz'2011-11-11 12:12:12', 1) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'date_sub(TIMESTAMP_NTZ '2011-11-11 12:12:12', 1)' due to data type mismatch: argument 1 requires date type, however, 'TIMESTAMP_NTZ '2011-11-11 12:12:12'' is of timestamp_ntz type.; line 1 pos 7 +2011-11-10 -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out 
b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out index 12c98ff138..e9c323254b 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out @@ -639,8 +639,8 @@ select make_interval(0, 0, 0, 0, 0, 0, 1234567890123456789) -- !query schema struct<> -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'make_interval(0, 0, 0, 0, 0, 0, 1234567890123456789L)' due to data type mismatch: argument 7 requires decimal(18,6) type, however, '1234567890123456789L' is of bigint type.; line 1 pos 7 +org.apache.spark.SparkArithmeticException +Decimal(expanded,1234567890123456789,20,0}) cannot be represented as Decimal(18, 6). If necessary set spark.sql.ansi.enabled to false to bypass this error. -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out index 879d89eb50..45d403859a 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out @@ -71,10 +71,9 @@ ab abcd ab NULL -- !query select left(null, -2) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'substring(NULL, 1, -2)' due to data type mismatch: argument 1 requires (string or binary) type, however, 'NULL' is of void type.; line 1 pos 7 +NULL -- !query @@ -89,19 +88,17 @@ invalid input syntax for type numeric: a. To return NULL instead, use 'try_cast' -- !query select right("abcd", 2), right("abcd", 5), right("abcd", '2'), right("abcd", null) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'substring('abcd', (- CAST('2' AS DOUBLE)), 2147483647)' due to data type mismatch: argument 2 requires int type, however, '(- CAST('2' AS DOUBLE))' is of double type.; line 1 pos 43 +cd abcd cd NULL -- !query select right(null, -2) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'substring(NULL, (- -2), 2147483647)' due to data type mismatch: argument 1 requires (string or binary) type, however, 'NULL' is of void type.; line 1 pos 7 +NULL -- !query @@ -109,8 +106,8 @@ select right("abcd", -2), right("abcd", 0), right("abcd", 'a') -- !query schema struct<> -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'substring('abcd', (- CAST('a' AS DOUBLE)), 2147483647)' due to data type mismatch: argument 2 requires int type, however, '(- CAST('a' AS DOUBLE))' is of double type.; line 1 pos 44 +java.lang.NumberFormatException +invalid input syntax for type numeric: a. To return NULL instead, use 'try_cast'. If necessary set spark.sql.ansi.enabled to false to bypass this error. 
-- !query @@ -308,28 +305,25 @@ trim -- !query SELECT btrim(encode(" xyz ", 'utf-8')) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'trim(encode(' xyz ', 'utf-8'))' due to data type mismatch: argument 1 requires string type, however, 'encode(' xyz ', 'utf-8')' is of binary type.; line 1 pos 7 +xyz -- !query SELECT btrim(encode('yxTomxx', 'utf-8'), encode('xyz', 'utf-8')) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'TRIM(BOTH encode('xyz', 'utf-8') FROM encode('yxTomxx', 'utf-8'))' due to data type mismatch: argument 1 requires string type, however, 'encode('yxTomxx', 'utf-8')' is of binary type. argument 2 requires string type, however, 'encode('xyz', 'utf-8')' is of binary type.; line 1 pos 7 +Tom -- !query SELECT btrim(encode('xxxbarxxx', 'utf-8'), encode('x', 'utf-8')) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'TRIM(BOTH encode('x', 'utf-8') FROM encode('xxxbarxxx', 'utf-8'))' due to data type mismatch: argument 1 requires string type, however, 'encode('xxxbarxxx', 'utf-8')' is of binary type. argument 2 requires string type, however, 'encode('x', 'utf-8')' is of binary type.; line 1 pos 7 +bar -- !query @@ -545,37 +539,33 @@ AABB -- !query SELECT lpad('abc', 5, x'57') -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'lpad('abc', 5, X'57')' due to data type mismatch: argument 3 requires string type, however, 'X'57'' is of binary type.; line 1 pos 7 +WWabc -- !query SELECT lpad(x'57', 5, 'abc') -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'lpad(X'57', 5, 'abc')' due to data type mismatch: argument 1 requires string type, however, 'X'57'' is of binary type.; line 1 pos 7 +abcaW -- !query SELECT rpad('abc', 5, x'57') -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'rpad('abc', 5, X'57')' due to data type mismatch: argument 3 requires string type, however, 'X'57'' is of binary type.; line 1 pos 7 +abcWW -- !query SELECT rpad(x'57', 5, 'abc') -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'rpad(X'57', 5, 'abc')' due to data type mismatch: argument 1 requires string type, however, 'X'57'' is of binary type.; line 1 pos 7 +Wabca -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/date.sql.out b/sql/core/src/test/resources/sql-tests/results/date.sql.out index 2eacb6cdce..5620289451 100644 --- a/sql/core/src/test/resources/sql-tests/results/date.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/date.sql.out @@ -349,10 +349,9 @@ struct -- !query select date_add(timestamp_ntz'2011-11-11 12:12:12', 1) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'date_add(TIMESTAMP_NTZ '2011-11-11 12:12:12', 1)' due to data type mismatch: argument 1 requires date type, however, 'TIMESTAMP_NTZ '2011-11-11 12:12:12'' is of timestamp_ntz type.; line 1 pos 7 +2011-11-12 -- !query @@ -458,10 +457,9 @@ struct -- !query select date_sub(timestamp_ntz'2011-11-11 12:12:12', 1) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'date_sub(TIMESTAMP_NTZ '2011-11-11 12:12:12', 1)' due to data type mismatch: argument 1 requires date type, however, 
'TIMESTAMP_NTZ '2011-11-11 12:12:12'' is of timestamp_ntz type.; line 1 pos 7 +2011-11-10 -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/datetime-legacy.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime-legacy.sql.out index 573ce3db9e..74480ab6cc 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime-legacy.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime-legacy.sql.out @@ -349,10 +349,9 @@ struct -- !query select date_add(timestamp_ntz'2011-11-11 12:12:12', 1) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'date_add(TIMESTAMP_NTZ '2011-11-11 12:12:12', 1)' due to data type mismatch: argument 1 requires date type, however, 'TIMESTAMP_NTZ '2011-11-11 12:12:12'' is of timestamp_ntz type.; line 1 pos 7 +2011-11-12 -- !query @@ -458,10 +457,9 @@ struct -- !query select date_sub(timestamp_ntz'2011-11-11 12:12:12', 1) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'date_sub(TIMESTAMP_NTZ '2011-11-11 12:12:12', 1)' due to data type mismatch: argument 1 requires date type, however, 'TIMESTAMP_NTZ '2011-11-11 12:12:12'' is of timestamp_ntz type.; line 1 pos 7 +2011-11-10 -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out index 8a4ee14201..bc13bb893b 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out @@ -4793,43 +4793,33 @@ Infinity -- !query select * from range(cast(0.0 as decimal(38, 18)), cast(4.0 as decimal(38, 18))) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -Table-valued function range with alternatives: - range(start: long, end: long, step: long, numSlices: integer) - range(start: long, end: long, step: long) - range(start: long, end: long) - range(end: long) -cannot be applied to (decimal(38,18), decimal(38,18)): Incompatible input data type. Expected: long; Found: decimal(38,18); line 1 pos 14 +0 +1 +2 +3 -- !query select * from range(cast(0.1 as decimal(38, 18)), cast(4.0 as decimal(38, 18)), cast(1.3 as decimal(38, 18))) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -Table-valued function range with alternatives: - range(start: long, end: long, step: long, numSlices: integer) - range(start: long, end: long, step: long) - range(start: long, end: long) - range(end: long) -cannot be applied to (decimal(38,18), decimal(38,18), decimal(38,18)): Incompatible input data type. Expected: long; Found: decimal(38,18); line 1 pos 14 +0 +1 +2 +3 -- !query select * from range(cast(4.0 as decimal(38, 18)), cast(-1.5 as decimal(38, 18)), cast(-2.2 as decimal(38, 18))) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -Table-valued function range with alternatives: - range(start: long, end: long, step: long, numSlices: integer) - range(start: long, end: long, step: long) - range(start: long, end: long) - range(end: long) -cannot be applied to (decimal(38,18), decimal(38,18), decimal(38,18)): Incompatible input data type. 
Expected: long; Found: decimal(38,18); line 1 pos 14 +0 +2 +4 -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/strings.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/strings.sql.out index 253a5e49b8..28904629df 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/strings.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/strings.sql.out @@ -977,37 +977,33 @@ struct -- !query SELECT trim(binary('\\000') from binary('\\000Tom\\000')) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'TRIM(BOTH CAST('\\000' AS BINARY) FROM CAST('\\000Tom\\000' AS BINARY))' due to data type mismatch: argument 1 requires string type, however, 'CAST('\\000Tom\\000' AS BINARY)' is of binary type. argument 2 requires string type, however, 'CAST('\\000' AS BINARY)' is of binary type.; line 1 pos 7 +Tom -- !query SELECT btrim(binary('\\000trim\\000'), binary('\\000')) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'TRIM(BOTH CAST('\\000' AS BINARY) FROM CAST('\\000trim\\000' AS BINARY))' due to data type mismatch: argument 1 requires string type, however, 'CAST('\\000trim\\000' AS BINARY)' is of binary type. argument 2 requires string type, however, 'CAST('\\000' AS BINARY)' is of binary type.; line 1 pos 7 +trim -- !query SELECT btrim(binary(''), binary('\\000')) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'TRIM(BOTH CAST('\\000' AS BINARY) FROM CAST('' AS BINARY))' due to data type mismatch: argument 1 requires string type, however, 'CAST('' AS BINARY)' is of binary type. argument 2 requires string type, however, 'CAST('\\000' AS BINARY)' is of binary type.; line 1 pos 7 + -- !query SELECT btrim(binary('\\000trim\\000'), binary('')) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'TRIM(BOTH CAST('' AS BINARY) FROM CAST('\\000trim\\000' AS BINARY))' due to data type mismatch: argument 1 requires string type, however, 'CAST('\\000trim\\000' AS BINARY)' is of binary type. argument 2 requires string type, however, 'CAST('' AS BINARY)' is of binary type.; line 1 pos 7 +\000trim\000 -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out index 25bcdb4b0c..56f50ec3a1 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/text.sql.out @@ -54,10 +54,9 @@ struct -- !query select length(42) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'length(42)' due to data type mismatch: argument 1 requires (string or binary) type, however, '42' is of int type.; line 1 pos 7 +2 -- !query @@ -65,8 +64,8 @@ select string('four: ') || 2+2 -- !query schema struct<> -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'concat(CAST('four: ' AS STRING), 2)' due to data type mismatch: input to function concat should have been string, binary or array, but it's [string, int]; line 1 pos 7 +java.lang.NumberFormatException +invalid input syntax for type numeric: four: 2. To return NULL instead, use 'try_cast'. If necessary set spark.sql.ansi.enabled to false to bypass this error. 
-- !query @@ -74,17 +73,16 @@ select 'four: ' || 2+2 -- !query schema struct<> -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'concat('four: ', 2)' due to data type mismatch: input to function concat should have been string, binary or array, but it's [string, int]; line 1 pos 7 +java.lang.NumberFormatException +invalid input syntax for type numeric: four: 2. To return NULL instead, use 'try_cast'. If necessary set spark.sql.ansi.enabled to false to bypass this error. -- !query select 3 || 4.0 -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'concat(3, 4.0BD)' due to data type mismatch: input to function concat should have been string, binary or array, but it's [int, decimal(2,1)]; line 1 pos 7 +34.0 -- !query @@ -101,10 +99,9 @@ one -- !query select concat(1,2,3,'hello',true, false, to_date('20100309','yyyyMMdd')) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'concat(1, 2, 3, 'hello', true, false, to_date('20100309', 'yyyyMMdd'))' due to data type mismatch: input to function concat should have been string, binary or array, but it's [int, int, int, string, boolean, boolean, date]; line 1 pos 7 +123hellotruefalse2010-03-09 -- !query @@ -118,37 +115,33 @@ one -- !query select concat_ws('#',1,2,3,'hello',true, false, to_date('20100309','yyyyMMdd')) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'concat_ws('#', 1, 2, 3, 'hello', true, false, to_date('20100309', 'yyyyMMdd'))' due to data type mismatch: argument 2 requires (array or string) type, however, '1' is of int type. argument 3 requires (array or string) type, however, '2' is of int type. argument 4 requires (array or string) type, however, '3' is of int type. argument 6 requires (array or string) type, however, 'true' is of boolean type. argument 7 requires (array or string) type, however, 'false' is of boolean type. argument 8 requires (array or string) type, however, 'to_date('20100309', 'yyyyMMdd')' is of date type.; line 1 pos 7 +1#x#x#hello#true#false#x-03-09 -- !query select concat_ws(',',10,20,null,30) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'concat_ws(',', 10, 20, NULL, 30)' due to data type mismatch: argument 2 requires (array or string) type, however, '10' is of int type. argument 3 requires (array or string) type, however, '20' is of int type. argument 4 requires (array or string) type, however, 'NULL' is of void type. argument 5 requires (array or string) type, however, '30' is of int type.; line 1 pos 7 +10,20,30 -- !query select concat_ws('',10,20,null,30) -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'concat_ws('', 10, 20, NULL, 30)' due to data type mismatch: argument 2 requires (array or string) type, however, '10' is of int type. argument 3 requires (array or string) type, however, '20' is of int type. argument 4 requires (array or string) type, however, 'NULL' is of void type. 
argument 5 requires (array or string) type, however, '30' is of int type.; line 1 pos 7 +102030 -- !query select concat_ws(NULL,10,20,null,30) is null -- !query schema -struct<> +struct<(concat_ws(NULL, 10, 20, NULL, 30) IS NULL):boolean> -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'concat_ws(CAST(NULL AS STRING), 10, 20, NULL, 30)' due to data type mismatch: argument 2 requires (array or string) type, however, '10' is of int type. argument 3 requires (array or string) type, however, '20' is of int type. argument 4 requires (array or string) type, however, 'NULL' is of void type. argument 5 requires (array or string) type, however, '30' is of int type.; line 1 pos 7 +true -- !query @@ -162,10 +155,19 @@ edcba -- !query select i, left('ahoj', i), right('ahoj', i) from range(-5, 6) t(i) order by i -- !query schema -struct<> +struct -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'substring('ahoj', 1, t.i)' due to data type mismatch: argument 3 requires int type, however, 't.i' is of bigint type.; line 1 pos 10 +-5 +-4 +-3 +-2 +-1 +0 +1 a j +2 ah oj +3 aho hoj +4 ahoj ahoj +5 ahoj ahoj -- !query @@ -277,7 +279,7 @@ select format_string('%0$s', 'Hello') struct<> -- !query output org.apache.spark.sql.AnalysisException -requirement failed: Illegal format argument index = 0; line 1 pos 7 +The argument_index of string format cannot contain position 0$.; line 1 pos 7 -- !query diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBasicOperationsBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBasicOperationsBenchmark.scala new file mode 100644 index 0000000000..a98c8d8a23 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBasicOperationsBenchmark.scala @@ -0,0 +1,370 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.benchmark + +import scala.util.Random + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.benchmark.Benchmark +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StateStore, StateStoreConf, StateStoreId, StateStoreProvider} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{IntegerType, StructField, StructType, TimestampType} +import org.apache.spark.util.Utils + +/** + * Synthetic benchmark for State Store basic operations. + * To run this benchmark: + * {{{ + * 1. without sbt: + * bin/spark-submit --class + * --jars , + * 2. build/sbt "sql/test:runMain " + * 3. 
generate result: + * SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain " + * Results will be written to "benchmarks/StateStoreBasicOperationsBenchmark-results.txt". + * }}} + */ +object StateStoreBasicOperationsBenchmark extends SqlBasedBenchmark { + + private val keySchema = StructType( + Seq(StructField("key1", IntegerType, true), StructField("key2", TimestampType, true))) + private val valueSchema = StructType(Seq(StructField("value", IntegerType, true))) + + private val keyProjection = UnsafeProjection.create(keySchema) + private val valueProjection = UnsafeProjection.create(valueSchema) + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + runPutBenchmark() + runDeleteBenchmark() + runEvictBenchmark() + } + + final def skip(benchmarkName: String)(func: => Any): Unit = { + output.foreach(_.write(s"$benchmarkName is skipped".getBytes)) + } + + private def runPutBenchmark(): Unit = { + def registerPutBenchmarkCase( + benchmark: Benchmark, + testName: String, + provider: StateStoreProvider, + version: Long, + rows: Seq[(UnsafeRow, UnsafeRow)]): Unit = { + benchmark.addTimerCase(testName) { timer => + val store = provider.getStore(version) + + timer.startTiming() + updateRows(store, rows) + timer.stopTiming() + + store.abort() + } + } + + runBenchmark("put rows") { + val numOfRows = Seq(10000) + val overwriteRates = Seq(100, 75, 50, 25, 10, 5, 0) + + numOfRows.foreach { numOfRow => + val testData = constructRandomizedTestData(numOfRow, + (1 to numOfRow).map(_ * 1000L).toList, 0) + + val inMemoryProvider = newHDFSBackedStateStoreProvider() + val rocksDBProvider = newRocksDBStateProvider() + val rocksDBWithNoTrackProvider = newRocksDBStateProvider(trackTotalNumberOfRows = false) + + val committedInMemoryVersion = loadInitialData(inMemoryProvider, testData) + val committedRocksDBVersion = loadInitialData(rocksDBProvider, testData) + val committedRocksDBWithNoTrackVersion = loadInitialData( + rocksDBWithNoTrackProvider, testData) + + overwriteRates.foreach { overwriteRate => + val numOfRowsToOverwrite = numOfRow * overwriteRate / 100 + + val numOfNewRows = numOfRow - numOfRowsToOverwrite + val newRows = if (numOfNewRows > 0) { + constructRandomizedTestData(numOfNewRows, + (1 to numOfNewRows).map(_ * 1000L).toList, 0) + } else { + Seq.empty[(UnsafeRow, UnsafeRow)] + } + val existingRows = if (numOfRowsToOverwrite > 0) { + Random.shuffle(testData).take(numOfRowsToOverwrite) + } else { + Seq.empty[(UnsafeRow, UnsafeRow)] + } + val rowsToPut = Random.shuffle(newRows ++ existingRows) + + val benchmark = new Benchmark(s"putting $numOfRow rows " + + s"($numOfRowsToOverwrite rows to overwrite - rate $overwriteRate)", + numOfRow, minNumIters = 10000, output = output) + + registerPutBenchmarkCase(benchmark, "In-memory", inMemoryProvider, + committedInMemoryVersion, rowsToPut) + registerPutBenchmarkCase(benchmark, "RocksDB (trackTotalNumberOfRows: true)", + rocksDBProvider, committedRocksDBVersion, rowsToPut) + registerPutBenchmarkCase(benchmark, "RocksDB (trackTotalNumberOfRows: false)", + rocksDBWithNoTrackProvider, committedRocksDBWithNoTrackVersion, rowsToPut) + + benchmark.run() + } + + inMemoryProvider.close() + rocksDBProvider.close() + rocksDBWithNoTrackProvider.close() + } + } + } + + private def runDeleteBenchmark(): Unit = { + def registerDeleteBenchmarkCase( + benchmark: Benchmark, + testName: String, + provider: StateStoreProvider, + version: Long, + keys: Seq[UnsafeRow]): Unit = { + benchmark.addTimerCase(testName) { timer => + val store = 
provider.getStore(version) + + timer.startTiming() + deleteRows(store, keys) + timer.stopTiming() + + store.abort() + } + } + + runBenchmark("delete rows") { + val numOfRows = Seq(10000) + val nonExistRates = Seq(100, 75, 50, 25, 10, 5, 0) + numOfRows.foreach { numOfRow => + val testData = constructRandomizedTestData(numOfRow, + (1 to numOfRow).map(_ * 1000L).toList, 0) + + val inMemoryProvider = newHDFSBackedStateStoreProvider() + val rocksDBProvider = newRocksDBStateProvider() + val rocksDBWithNoTrackProvider = newRocksDBStateProvider(trackTotalNumberOfRows = false) + + val committedInMemoryVersion = loadInitialData(inMemoryProvider, testData) + val committedRocksDBVersion = loadInitialData(rocksDBProvider, testData) + val committedRocksDBWithNoTrackVersion = loadInitialData( + rocksDBWithNoTrackProvider, testData) + + nonExistRates.foreach { nonExistRate => + val numOfRowsNonExist = numOfRow * nonExistRate / 100 + + val numOfExistingRows = numOfRow - numOfRowsNonExist + val nonExistingRows = if (numOfRowsNonExist > 0) { + constructRandomizedTestData(numOfRowsNonExist, + (numOfRow + 1 to numOfRow + numOfRowsNonExist).map(_ * 1000L).toList, 0) + } else { + Seq.empty[(UnsafeRow, UnsafeRow)] + } + val existingRows = if (numOfExistingRows > 0) { + Random.shuffle(testData).take(numOfExistingRows) + } else { + Seq.empty[(UnsafeRow, UnsafeRow)] + } + val keysToDelete = Random.shuffle(nonExistingRows ++ existingRows).map(_._1) + + val benchmark = new Benchmark(s"trying to delete $numOfRow rows " + + s"from $numOfRow rows" + + s"($numOfRowsNonExist rows are non-existing - rate $nonExistRate)", + numOfRow, minNumIters = 10000, output = output) + + registerDeleteBenchmarkCase(benchmark, "In-memory", inMemoryProvider, + committedInMemoryVersion, keysToDelete) + registerDeleteBenchmarkCase(benchmark, "RocksDB (trackTotalNumberOfRows: true)", + rocksDBProvider, committedRocksDBVersion, keysToDelete) + registerDeleteBenchmarkCase(benchmark, "RocksDB (trackTotalNumberOfRows: false)", + rocksDBWithNoTrackProvider, committedRocksDBWithNoTrackVersion, keysToDelete) + + benchmark.run() + } + + inMemoryProvider.close() + rocksDBProvider.close() + rocksDBWithNoTrackProvider.close() + } + } + } + + private def runEvictBenchmark(): Unit = { + def registerEvictBenchmarkCase( + benchmark: Benchmark, + testName: String, + provider: StateStoreProvider, + version: Long, + maxTimestampToEvictInMillis: Long, + expectedNumOfRows: Long): Unit = { + benchmark.addTimerCase(testName) { timer => + val store = provider.getStore(version) + + timer.startTiming() + evictAsFullScanAndRemove(store, maxTimestampToEvictInMillis, + expectedNumOfRows) + timer.stopTiming() + + store.abort() + } + } + + runBenchmark("evict rows") { + val numOfRows = Seq(10000) + val numOfEvictionRates = Seq(100, 75, 50, 25, 10, 5, 0) + + numOfRows.foreach { numOfRow => + val timestampsInMicros = (0L until numOfRow).map(ts => ts * 1000L).toList + + val testData = constructRandomizedTestData(numOfRow, timestampsInMicros, 0) + + val inMemoryProvider = newHDFSBackedStateStoreProvider() + val rocksDBProvider = newRocksDBStateProvider() + val rocksDBWithNoTrackProvider = newRocksDBStateProvider(trackTotalNumberOfRows = false) + + val committedInMemoryVersion = loadInitialData(inMemoryProvider, testData) + val committedRocksDBVersion = loadInitialData(rocksDBProvider, testData) + val committedRocksDBWithNoTrackVersion = loadInitialData( + rocksDBWithNoTrackProvider, testData) + + numOfEvictionRates.foreach { numOfEvictionRate => + val numOfRowsToEvict = 
numOfRow * numOfEvictionRate / 100 + val maxTimestampToEvictInMillis = timestampsInMicros + .take(numOfRow * numOfEvictionRate / 100) + .lastOption.map(_ / 1000).getOrElse(-1L) + + val benchmark = new Benchmark(s"evicting $numOfRowsToEvict rows " + + s"(maxTimestampToEvictInMillis: $maxTimestampToEvictInMillis) " + + s"from $numOfRow rows", + numOfRow, minNumIters = 10000, output = output) + + registerEvictBenchmarkCase(benchmark, "In-memory", inMemoryProvider, + committedInMemoryVersion, maxTimestampToEvictInMillis, numOfRowsToEvict) + + registerEvictBenchmarkCase(benchmark, "RocksDB (trackTotalNumberOfRows: true)", + rocksDBProvider, committedRocksDBVersion, maxTimestampToEvictInMillis, + numOfRowsToEvict) + + registerEvictBenchmarkCase(benchmark, "RocksDB (trackTotalNumberOfRows: false)", + rocksDBWithNoTrackProvider, committedRocksDBWithNoTrackVersion, + maxTimestampToEvictInMillis, numOfRowsToEvict) + + benchmark.run() + } + + inMemoryProvider.close() + rocksDBProvider.close() + rocksDBWithNoTrackProvider.close() + } + } + } + + private def getRows(store: StateStore, keys: Seq[UnsafeRow]): Seq[UnsafeRow] = { + keys.map(store.get) + } + + private def loadInitialData( + provider: StateStoreProvider, + data: Seq[(UnsafeRow, UnsafeRow)]): Long = { + val store = provider.getStore(0) + updateRows(store, data) + store.commit() + } + + private def updateRows( + store: StateStore, + rows: Seq[(UnsafeRow, UnsafeRow)]): Unit = { + rows.foreach { case (key, value) => + store.put(key, value) + } + } + + private def deleteRows( + store: StateStore, + rows: Seq[UnsafeRow]): Unit = { + rows.foreach { key => + store.remove(key) + } + } + + private def evictAsFullScanAndRemove( + store: StateStore, + maxTimestampToEvictMillis: Long, + expectedNumOfRows: Long): Unit = { + var removedRows: Long = 0 + store.iterator().foreach { r => + if (r.key.getLong(1) <= maxTimestampToEvictMillis * 1000L) { + store.remove(r.key) + removedRows += 1 + } + } + assert(removedRows == expectedNumOfRows, + s"expected: $expectedNumOfRows actual: $removedRows") + } + + // This prevents created keys to be in order, which may affect the performance on RocksDB. 
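+  // (Monotonically increasing keys tend to be the cheapest insertion pattern for an LSM-based
+  //  store such as RocksDB, so key1 below is randomized to avoid benchmarking only that best
+  //  case; key2 still carries the timestamp that the eviction benchmark scans on.)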
+ private def constructRandomizedTestData( + numRows: Int, + timestamps: List[Long], + minIdx: Int = 0): Seq[(UnsafeRow, UnsafeRow)] = { + assert(numRows >= timestamps.length) + assert(numRows % timestamps.length == 0) + + (1 to numRows).map { idx => + val keyRow = new GenericInternalRow(2) + keyRow.setInt(0, Random.nextInt(Int.MaxValue)) + keyRow.setLong(1, timestamps((minIdx + idx) % timestamps.length)) // microseconds + val valueRow = new GenericInternalRow(1) + valueRow.setInt(0, minIdx + idx) + + val keyUnsafeRow = keyProjection(keyRow).copy() + val valueUnsafeRow = valueProjection(valueRow).copy() + + (keyUnsafeRow, valueUnsafeRow) + } + } + + private def newHDFSBackedStateStoreProvider(): StateStoreProvider = { + val storeId = StateStoreId(newDir(), Random.nextInt(), 0) + val provider = new HDFSBackedStateStoreProvider() + val storeConf = new StateStoreConf(new SQLConf()) + provider.init( + storeId, keySchema, valueSchema, 0, + storeConf, new Configuration) + provider + } + + private def newRocksDBStateProvider( + trackTotalNumberOfRows: Boolean = true): StateStoreProvider = { + val storeId = StateStoreId(newDir(), Random.nextInt(), 0) + val provider = new RocksDBStateStoreProvider() + val sqlConf = new SQLConf() + sqlConf.setConfString("spark.sql.streaming.stateStore.rocksdb.trackTotalNumberOfRows", + trackTotalNumberOfRows.toString) + val storeConf = new StateStoreConf(sqlConf) + + provider.init( + storeId, keySchema, valueSchema, 0, + storeConf, new Configuration) + provider + } + + private def newDir(): String = Utils.createTempDir().toString +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowTblPropertiesSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowTblPropertiesSuiteBase.scala index 28f58d67f2..7f9e927cb0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowTblPropertiesSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/ShowTblPropertiesSuiteBase.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{StringType, StructType} /** @@ -51,7 +52,7 @@ trait ShowTblPropertiesSuiteBase extends QueryTest with DDLCommandTestUtils { Row("user", user)) assert(properties.schema === schema) - checkAnswer(properties.sort("key"), expected) + checkAnswer(properties, expected) } } @@ -87,4 +88,25 @@ trait ShowTblPropertiesSuiteBase extends QueryTest with DDLCommandTestUtils { assert(res.head.getString(1).contains(s"does not have property: $nonExistingKey")) } } + + test("KEEP THE LEGACY OUTPUT SCHEMA") { + Seq(true, false).foreach { keepLegacySchema => + withSQLConf(SQLConf.LEGACY_KEEP_COMMAND_OUTPUT_SCHEMA.key -> keepLegacySchema.toString) { + withNamespaceAndTable("ns1", "tbl") { tbl => + spark.sql(s"CREATE TABLE $tbl (id bigint, data string) $defaultUsing " + + "TBLPROPERTIES ('user'='spark', 'status'='new')") + + val properties = sql(s"SHOW TBLPROPERTIES $tbl ('status')") + val schema = properties.schema.fieldNames.toSeq + if (keepLegacySchema) { + assert(schema === Seq("value")) + checkAnswer(properties, Seq(Row("new"))) + } else { + assert(schema === Seq("key", "value")) + checkAnswer(properties, Seq(Row("status", "new"))) + } + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowTblPropertiesSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowTblPropertiesSuite.scala index 190b270122..67b9e21266 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowTblPropertiesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/ShowTblPropertiesSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.command.v1 import org.apache.spark.sql.Row import org.apache.spark.sql.execution.command -import org.apache.spark.sql.internal.SQLConf /** * This base suite contains unified tests for the `SHOW TBLPROPERTIES` command that checks V1 @@ -33,17 +32,6 @@ import org.apache.spark.sql.internal.SQLConf trait ShowTblPropertiesSuiteBase extends command.ShowTblPropertiesSuiteBase with command.TestsV1AndV2Commands { - test("SHOW TBLPROPERTIES WITH LEGACY_KEEP_COMMAND_OUTPUT_SCHEMA") { - withNamespaceAndTable("ns1", "tbl") { tbl => - spark.sql(s"CREATE TABLE $tbl (id bigint, data string) $defaultUsing " + - s"TBLPROPERTIES ('user'='andrew', 'status'='new')") - withSQLConf(SQLConf.LEGACY_KEEP_COMMAND_OUTPUT_SCHEMA.key -> "true") { - checkAnswer(sql(s"SHOW TBLPROPERTIES $tbl('user')"), Row("andrew")) - checkAnswer(sql(s"SHOW TBLPROPERTIES $tbl('status')"), Row("new")) - } - } - } - test("SHOW TBLPROPERTIES FOR VIEW") { val v = "testview" withView(v) { @@ -56,7 +44,7 @@ trait ShowTblPropertiesSuiteBase extends command.ShowTblPropertiesSuiteBase } } - test("SHOW TBLPROPERTIES FOR TEMPORARY IEW") { + test("SHOW TBLPROPERTIES FOR TEMPORARY VIEW") { val v = "testview" withView(v) { spark.sql(s"CREATE TEMPORARY VIEW $v AS SELECT 1 AS c1;") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index c53cc10314..a62ce9226a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.orc import java.math.MathContext import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} -import java.time.{Duration, Period} +import java.time.{Duration, LocalDateTime, Period} import scala.collection.JavaConverters._ @@ -327,6 +327,39 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { } } + test("SPARK-36357: filter pushdown - timestamp_ntz") { + val localDateTimes = Seq( + LocalDateTime.of(1000, 1, 1, 1, 2, 3, 456000000), + LocalDateTime.of(1582, 10, 1, 0, 11, 22, 456000000), + LocalDateTime.of(1900, 1, 1, 23, 59, 59, 456000000), + LocalDateTime.of(2020, 5, 25, 10, 11, 12, 456000000)) + withOrcFile(localDateTimes.map(Tuple1(_))) { path => + readFile(path) { implicit df => + checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate($"_1" === localDateTimes(0), PredicateLeaf.Operator.EQUALS) + checkFilterPredicate($"_1" <=> localDateTimes(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate($"_1" < localDateTimes(1), PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate($"_1" > localDateTimes(2), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate($"_1" <= localDateTimes(0), PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate($"_1" >= localDateTimes(3), PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(localDateTimes(0)) === $"_1", PredicateLeaf.Operator.EQUALS) + 
checkFilterPredicate( + Literal(localDateTimes(0)) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(localDateTimes(1)) > $"_1", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate( + Literal(localDateTimes(2)) < $"_1", + PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate( + Literal(localDateTimes(0)) >= $"_1", + PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(localDateTimes(3)) <= $"_1", PredicateLeaf.Operator.LESS_THAN) + } + } + } + test("filter pushdown - combinations with logical operators") { withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => checkFilterPredicate( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index d91d9196f6..f12e5af9d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -122,7 +122,63 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession test("SPARK-36182: TimestampNTZ") { val data = Seq("2021-01-01T00:00:00", "1970-07-15T01:02:03.456789") .map(ts => Tuple1(LocalDateTime.parse(ts))) - checkParquetFile(data) + withAllParquetReaders { + checkParquetFile(data) + } + } + + test("Read TimestampNTZ and TimestampLTZ for various logical TIMESTAMP types") { + val schema = MessageTypeParser.parseMessageType( + """message root { + | required int64 timestamp_ltz_millis_depr(TIMESTAMP_MILLIS); + | required int64 timestamp_ltz_micros_depr(TIMESTAMP_MICROS); + | required int64 timestamp_ltz_millis(TIMESTAMP(MILLIS,true)); + | required int64 timestamp_ltz_micros(TIMESTAMP(MICROS,true)); + | required int64 timestamp_ntz_millis(TIMESTAMP(MILLIS,false)); + | required int64 timestamp_ntz_micros(TIMESTAMP(MICROS,false)); + |} + """.stripMargin) + + for (dictEnabled <- Seq(true, false)) { + withTempDir { dir => + val tablePath = new Path(s"${dir.getCanonicalPath}/timestamps.parquet") + val numRecords = 100 + + val writer = createParquetWriter(schema, tablePath, dictionaryEnabled = dictEnabled) + (0 until numRecords).map { i => + val record = new SimpleGroup(schema) + for (group <- Seq(0, 2, 4)) { + record.add(group, 1000L) // millis + record.add(group + 1, 1000000L) // micros + } + writer.write(record) + } + writer.close + + withAllParquetReaders { + val df = spark.read.parquet(tablePath.toString) + assertResult(df.schema) { + StructType( + StructField("timestamp_ltz_millis_depr", TimestampType, nullable = true) :: + StructField("timestamp_ltz_micros_depr", TimestampType, nullable = true) :: + StructField("timestamp_ltz_millis", TimestampType, nullable = true) :: + StructField("timestamp_ltz_micros", TimestampType, nullable = true) :: + StructField("timestamp_ntz_millis", TimestampNTZType, nullable = true) :: + StructField("timestamp_ntz_micros", TimestampNTZType, nullable = true) :: + Nil + ) + } + + val exp = (0 until numRecords).map { _ => + val ltz_value = new java.sql.Timestamp(1000L) + val ntz_value = LocalDateTime.of(1970, 1, 1, 0, 0, 1) + (ltz_value, ltz_value, ltz_value, ltz_value, ntz_value, ntz_value) + }.toDF() + + checkAnswer(df, exp) + } + } + } } testStandardAndLegacyModes("fixed-length decimals") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index e6d5382b74..057de2abdb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -171,12 +171,14 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS (null), ("1965-01-01 10:11:12.123456")) .toDS().select($"value".cast("timestamp_ntz")) - checkAnswer(sql("select * from ts"), expected) + withAllParquetReaders { + checkAnswer(sql("select * from ts"), expected) + } } } test("SPARK-36182: can't read TimestampLTZ as TimestampNTZ") { - val data = (1 to 10).map { i => + val data = (1 to 1000).map { i => val ts = new java.sql.Timestamp(i) Row(ts) } @@ -184,16 +186,20 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS val providedSchema = StructType(Seq(StructField("time", TimestampNTZType, false))) Seq("INT96", "TIMESTAMP_MICROS", "TIMESTAMP_MILLIS").foreach { tsType => - withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> tsType) { - withTempPath { file => - val df = spark.createDataFrame(sparkContext.parallelize(data), actualSchema) - df.write.parquet(file.getCanonicalPath) - withAllParquetReaders { - val msg = intercept[SparkException] { - spark.read.schema(providedSchema).parquet(file.getCanonicalPath).collect() - }.getMessage - assert(msg.contains( - "Unable to create Parquet converter for data type \"timestamp_ntz\"")) + Seq(true, false).foreach { dictionaryEnabled => + withSQLConf( + SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> tsType, + ParquetOutputFormat.ENABLE_DICTIONARY -> dictionaryEnabled.toString) { + withTempPath { file => + val df = spark.createDataFrame(sparkContext.parallelize(data), actualSchema) + df.write.parquet(file.getCanonicalPath) + withAllParquetReaders { + val msg = intercept[SparkException] { + spark.read.schema(providedSchema).parquet(file.getCanonicalPath).collect() + }.getMessage + assert(msg.contains( + "Unable to create Parquet converter for data type \"timestamp_ntz\"")) + } } } } @@ -201,13 +207,13 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS } test("SPARK-36182: read TimestampNTZ as TimestampLTZ") { - val data = (1 to 10).map { i => + val data = (1 to 1000).map { i => // The second parameter is `nanoOfSecond`, while java.sql.Timestamp accepts milliseconds // as input. 
So here we multiple the `nanoOfSecond` by NANOS_PER_MILLIS - val ts = LocalDateTime.ofEpochSecond(0, i * 1000000, ZoneOffset.UTC) + val ts = LocalDateTime.ofEpochSecond(i / 1000, (i % 1000) * 1000000, ZoneOffset.UTC) Row(ts) } - val answer = (1 to 10).map { i => + val answer = (1 to 1000).map { i => val ts = new java.sql.Timestamp(i) Row(ts) } @@ -218,7 +224,11 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS val df = spark.createDataFrame(sparkContext.parallelize(data), actualSchema) df.write.parquet(file.getCanonicalPath) withAllParquetReaders { - checkAnswer(spark.read.schema(providedSchema).parquet(file.getCanonicalPath), answer) + Seq(true, false).foreach { dictionaryEnabled => + withSQLConf(ParquetOutputFormat.ENABLE_DICTIONARY -> dictionaryEnabled.toString) { + checkAnswer(spark.read.schema(providedSchema).parquet(file.getCanonicalPath), answer) + } + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala index ed20f2536c..e82b9df93d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSessionWindowSuite.scala @@ -51,7 +51,7 @@ class StreamingSessionWindowSuite extends StreamTest (SQLConf.STATE_STORE_PROVIDER_CLASS.key, value.stripSuffix("$")) } // RocksDB doesn't support Apple Silicon yet - if (Utils.isMac && System.getProperty("os.arch").equals("aarch64")) { + if (Utils.isMacOnAppleSilicon) { providerOptions = providerOptions .filterNot(_._2.contains(classOf[RocksDBStateStoreProvider].getSimpleName)) }
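The final hunk replaces the inline `os.arch` check with the new `Utils.isMacOnAppleSilicon` helper introduced in core's Utils.scala, which is not shown in this excerpt. Judging only from the condition it replaces, the helper is presumably equivalent to the sketch below, with commons-lang3's SystemUtils standing in for `Utils.isMac`; this is a hypothetical reconstruction, not the actual Utils.scala code:

{{{
  import org.apache.commons.lang3.SystemUtils

  // Hypothetical reconstruction from the condition removed above; the real helper added to
  // org.apache.spark.util.Utils by this patch may differ in detail.
  def isMacOnAppleSilicon: Boolean =
    SystemUtils.IS_OS_MAC_OSX && System.getProperty("os.arch").equals("aarch64")
}}}

Centralizing the check lets every suite that needs to skip RocksDB-backed cases on Apple Silicon share one predicate instead of repeating the system-property lookup.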