From 9ba7168d75ba298fda74c38793f07a81f1590f62 Mon Sep 17 00:00:00 2001
From: Alex Barreto
Date: Tue, 10 Aug 2021 13:40:10 -0400
Subject: [PATCH] [SPARK-36464][CORE] Fix Underlying Size Variable Initialization in ChunkedByteBufferOutputStream for Writing Over 2GB Data
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What changes were proposed in this pull request?
The `size` method of `ChunkedByteBufferOutputStream` returns a `Long` value; however, the underlying `_size` variable is initialized as an `Int`. That causes an overflow and returns a negative size when more than 2GB of data is written into `ChunkedByteBufferOutputStream`.

This PR proposes to change the underlying `_size` variable from `Int` to `Long` at initialization.

### Why are the changes needed?
Because the `size` method of `ChunkedByteBufferOutputStream` incorrectly returns a negative value when more than 2GB of data is written.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Passed existing tests
```
build/sbt "core/testOnly *ChunkedByteBufferOutputStreamSuite"
```
Also added a new unit test
```
build/sbt "core/testOnly *ChunkedByteBufferOutputStreamSuite -- -z SPARK-36464"
```

Closes #33690 from kazuyukitanimura/SPARK-36464.

Authored-by: Kazuyuki Tanimura
Signed-off-by: Dongjoon Hyun
---
 .../spark/network/util/TransportConf.java | 4 +-
 .../network/shuffle/ExternalBlockHandler.java | 58 ----
 .../shuffle/NoOpMergedShuffleFileManager.java | 86 ++++++
 .../shuffle/RemoteBlockPushResolver.java | 45 ++--
 .../network/yarn/YarnShuffleService.java | 3 +-
 conf/spark-env.sh.template | 10 +-
 .../io/ChunkedByteBufferOutputStream.scala | 3 +-
 .../ChunkedByteBufferOutputStreamSuite.scala | 10 +
 python/pyspark/pandas/groupby.py | 28 +-
 python/pyspark/pandas/tests/test_expanding.py | 35 ++-
 python/pyspark/pandas/tests/test_rolling.py | 35 ++-
 python/pyspark/pandas/window.py | 249 +++++++++---------
 .../yarn/YarnShuffleServiceSuite.scala | 8 +-
 .../catalyst/analysis/AnsiTypeCoercion.scala | 5 +
 .../sql/catalyst/analysis/CheckAnalysis.scala | 25 +-
 .../sql/catalyst/analysis/TypeCoercion.scala | 5 +
 .../sql/catalyst/json/JacksonParser.scala | 13 +-
 .../plans/logical/basicLogicalOperators.scala | 2 +-
 .../catalyst/analysis/TypeCoercionSuite.scala | 47 ++++
 .../datasources/v2/orc/OrcScanBuilder.scala | 3 +-
 .../resources/sql-tests/inputs/interval.sql | 13 +
 .../sql-tests/results/ansi/interval.sql.out | 98 ++++++-
 .../sql-tests/results/interval.sql.out | 98 ++++++-
 .../timestampNTZ/timestamp-ansi.sql.out | 5 +-
 .../results/timestampNTZ/timestamp.sql.out | 5 +-
 .../org/apache/spark/sql/ExplainSuite.scala | 2 +-
 .../spark/sql/connector/AlterTableTests.scala | 11 +
 .../V2CommandsCaseSensitivitySuite.scala | 14 +-
 .../sql/execution/SQLViewTestSuite.scala | 37 +++
 29 files changed, 680 insertions(+), 277 deletions(-)
 create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/NoOpMergedShuffleFileManager.java

diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
index 8e7ecf500e..69b8b25454 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -378,13 +378,13 @@ public boolean useOldFetchProtocol() {
   * Class name of the implementation of MergedShuffleFileManager that merges the blocks
   *
pushed to it when push-based shuffle is enabled. By default, push-based shuffle is disabled at * a cluster level because this configuration is set to - * 'org.apache.spark.network.shuffle.ExternalBlockHandler$NoOpMergedShuffleFileManager'. + * 'org.apache.spark.network.shuffle.NoOpMergedShuffleFileManager'. * To turn on push-based shuffle at a cluster level, set the configuration to * 'org.apache.spark.network.shuffle.RemoteBlockPushResolver'. */ public String mergedShuffleFileManagerImpl() { return conf.get("spark.shuffle.server.mergedShuffleFileManagerImpl", - "org.apache.spark.network.shuffle.ExternalBlockHandler$NoOpMergedShuffleFileManager"); + "org.apache.spark.network.shuffle.NoOpMergedShuffleFileManager"); } /** diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java index 71741f2cba..1e413f6b2f 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java @@ -594,64 +594,6 @@ public ManagedBuffer next() { } } - /** - * Dummy implementation of merged shuffle file manager. Suitable for when push-based shuffle - * is not enabled. - * - * @since 3.1.0 - */ - public static class NoOpMergedShuffleFileManager implements MergedShuffleFileManager { - - // This constructor is needed because we use this constructor to instantiate an implementation - // of MergedShuffleFileManager using reflection. - // See YarnShuffleService#newMergedShuffleFileManagerInstance. - public NoOpMergedShuffleFileManager(TransportConf transportConf) {} - - @Override - public StreamCallbackWithID receiveBlockDataAsStream(PushBlockStream msg) { - throw new UnsupportedOperationException("Cannot handle shuffle block merge"); - } - - @Override - public MergeStatuses finalizeShuffleMerge(FinalizeShuffleMerge msg) throws IOException { - throw new UnsupportedOperationException("Cannot handle shuffle block merge"); - } - - @Override - public void registerExecutor(String appId, ExecutorShuffleInfo executorInfo) { - // No-Op. Do nothing. - } - - @Override - public void applicationRemoved(String appId, boolean cleanupLocalDirs) { - // No-Op. Do nothing. 
- } - - @Override - public ManagedBuffer getMergedBlockData( - String appId, - int shuffleId, - int shuffleMergeId, - int reduceId, - int chunkId) { - throw new UnsupportedOperationException("Cannot handle shuffle block merge"); - } - - @Override - public MergedBlockMeta getMergedBlockMeta( - String appId, - int shuffleId, - int shuffleMergeId, - int reduceId) { - throw new UnsupportedOperationException("Cannot handle shuffle block merge"); - } - - @Override - public String[] getMergedBlockDirs(String appId) { - throw new UnsupportedOperationException("Cannot handle shuffle block merge"); - } - } - @Override public void channelActive(TransportClient client) { metrics.activeConnections.inc(); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/NoOpMergedShuffleFileManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/NoOpMergedShuffleFileManager.java new file mode 100644 index 0000000000..f47bfc3077 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/NoOpMergedShuffleFileManager.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.IOException; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.StreamCallbackWithID; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.shuffle.protocol.FinalizeShuffleMerge; +import org.apache.spark.network.shuffle.protocol.MergeStatuses; +import org.apache.spark.network.shuffle.protocol.PushBlockStream; +import org.apache.spark.network.util.TransportConf; + +/** + * Dummy implementation of merged shuffle file manager. Suitable for when push-based shuffle + * is not enabled. + * + * @since 3.1.0 + */ +public class NoOpMergedShuffleFileManager implements MergedShuffleFileManager { + + // This constructor is needed because we use this constructor to instantiate an implementation + // of MergedShuffleFileManager using reflection. + // See YarnShuffleService#newMergedShuffleFileManagerInstance. + public NoOpMergedShuffleFileManager(TransportConf transportConf) {} + + @Override + public StreamCallbackWithID receiveBlockDataAsStream(PushBlockStream msg) { + throw new UnsupportedOperationException("Cannot handle shuffle block merge"); + } + + @Override + public MergeStatuses finalizeShuffleMerge(FinalizeShuffleMerge msg) throws IOException { + throw new UnsupportedOperationException("Cannot handle shuffle block merge"); + } + + @Override + public void registerExecutor(String appId, ExecutorShuffleInfo executorInfo) { + // No-Op. Do nothing. + } + + @Override + public void applicationRemoved(String appId, boolean cleanupLocalDirs) { + // No-Op. 
Do nothing. + } + + @Override + public ManagedBuffer getMergedBlockData( + String appId, + int shuffleId, + int shuffleMergeId, + int reduceId, + int chunkId) { + throw new UnsupportedOperationException("Cannot handle shuffle block merge"); + } + + @Override + public MergedBlockMeta getMergedBlockMeta( + String appId, + int shuffleId, + int shuffleMergeId, + int reduceId) { + throw new UnsupportedOperationException("Cannot handle shuffle block merge"); + } + + @Override + public String[] getMergedBlockDirs(String appId) { + throw new UnsupportedOperationException("Cannot handle shuffle block merge"); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java index 8578843f1a..84ecf3d18a 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java @@ -678,9 +678,7 @@ public String getID() { private void writeBuf(ByteBuffer buf) throws IOException { while (buf.hasRemaining()) { long updatedPos = partitionInfo.getDataFilePos() + length; - logger.debug("{} shuffleId {} shuffleMergeId {} reduceId {} current pos" - + " {} updated pos {}", partitionInfo.appId, partitionInfo.shuffleId, - partitionInfo.shuffleMergeId, partitionInfo.reduceId, + logger.debug("{} current pos {} updated pos {}", partitionInfo, partitionInfo.getDataFilePos(), updatedPos); length += partitionInfo.dataChannel.write(buf, updatedPos); } @@ -795,9 +793,7 @@ public void onData(String streamId, ByteBuffer buf) throws IOException { return; } abortIfNecessary(); - logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} onData writable", - partitionInfo.appId, partitionInfo.shuffleId, partitionInfo.shuffleMergeId, - partitionInfo.reduceId); + logger.trace("{} onData writable", partitionInfo); if (partitionInfo.getCurrentMapIndex() < 0) { partitionInfo.setCurrentMapIndex(mapIndex); } @@ -817,9 +813,7 @@ public void onData(String streamId, ByteBuffer buf) throws IOException { throw ioe; } } else { - logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} onData deferred", - partitionInfo.appId, partitionInfo.shuffleId, partitionInfo.shuffleMergeId, - partitionInfo.reduceId); + logger.trace("{} onData deferred", partitionInfo); // If we cannot write to disk, we buffer the current block chunk in memory so it could // potentially be written to disk later. We take our best effort without guarantee // that the block will be written to disk. If the block data is divided into multiple @@ -852,9 +846,7 @@ public void onData(String streamId, ByteBuffer buf) throws IOException { @Override public void onComplete(String streamId) throws IOException { synchronized (partitionInfo) { - logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} onComplete invoked", - partitionInfo.appId, partitionInfo.shuffleId, partitionInfo.shuffleMergeId, - partitionInfo.reduceId); + logger.trace("{} onComplete invoked", partitionInfo); // Initially when this request got to the server, the shuffle merge finalize request // was not received yet or this was the latest stage attempt (or latest shuffleMergeId) // generating shuffle output for the shuffle ID. 
By the time we finish reading this @@ -936,9 +928,7 @@ public void onFailure(String streamId, Throwable throwable) throws IOException { synchronized (partitionInfo) { if (!isStaleOrTooLate(appShuffleInfo.shuffles.get(partitionInfo.shuffleId), partitionInfo.shuffleMergeId, partitionInfo.reduceId)) { - logger.debug("{} shuffleId {} shuffleMergeId {} reduceId {}" - + " encountered failure", partitionInfo.appId, partitionInfo.shuffleId, - partitionInfo.shuffleMergeId, partitionInfo.reduceId); + logger.debug("{} encountered failure", partitionInfo); partitionInfo.setCurrentMapIndex(-1); } } @@ -1032,9 +1022,7 @@ public long getDataFilePos() { } public void setDataFilePos(long dataFilePos) { - logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} current pos {}" - + " update pos {}", appId, shuffleId, shuffleMergeId, reduceId, this.dataFilePos, - dataFilePos); + logger.trace("{} current pos {} update pos {}", this, this.dataFilePos, dataFilePos); this.dataFilePos = dataFilePos; } @@ -1043,9 +1031,7 @@ int getCurrentMapIndex() { } void setCurrentMapIndex(int mapIndex) { - logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} updated mapIndex {}" - + " current mapIndex {}", appId, shuffleId, shuffleMergeId, reduceId, - currentMapIndex, mapIndex); + logger.trace("{} mapIndex {} current mapIndex {}", this, currentMapIndex, mapIndex); this.currentMapIndex = mapIndex; } @@ -1054,8 +1040,7 @@ long getLastChunkOffset() { } void blockMerged(int mapIndex) { - logger.debug("{} shuffleId {} shuffleMergeId {} reduceId {} updated merging mapIndex {}", - appId, shuffleId, shuffleMergeId, reduceId, mapIndex); + logger.debug("{} updated merging mapIndex {}", this, mapIndex); mapTracker.add(mapIndex); chunkTracker.add(mapIndex); lastMergedMapIndex = mapIndex; @@ -1073,9 +1058,8 @@ void resetChunkTracker() { */ void updateChunkInfo(long chunkOffset, int mapIndex) throws IOException { try { - logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} index current {}" - + " updated {}", appId, shuffleId, shuffleMergeId, reduceId, - this.lastChunkOffset, chunkOffset); + logger.trace("{} index current {} updated {}", this, this.lastChunkOffset, + chunkOffset); if (indexMetaUpdateFailed) { indexFile.getChannel().position(indexFile.getPos()); } @@ -1103,8 +1087,7 @@ private void writeChunkTracker(int mapIndex) throws IOException { return; } chunkTracker.add(mapIndex); - logger.trace("{} shuffleId {} shuffleMergeId {} reduceId {} mapIndex {}" - + " write chunk to meta file", appId, shuffleId, shuffleMergeId, reduceId, mapIndex); + logger.trace("{} mapIndex {} write chunk to meta file", this, mapIndex); if (indexMetaUpdateFailed) { metaFile.getChannel().position(metaFile.getPos()); } @@ -1169,6 +1152,12 @@ void closeAllFilesAndDeleteIfNeeded(boolean delete) { } } + @Override + public String toString() { + return String.format("Application %s shuffleId %s shuffleMergeId %s reduceId %s", + appId, shuffleId, shuffleMergeId, reduceId); + } + @Override protected void finalize() throws Throwable { closeAllFilesAndDeleteIfNeeded(false); diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index cb6d5d0ca2..ac163692c4 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -43,6 +43,7 @@ import org.apache.hadoop.yarn.api.records.ContainerId; 
import org.apache.hadoop.yarn.server.api.*; import org.apache.spark.network.shuffle.MergedShuffleFileManager; +import org.apache.spark.network.shuffle.NoOpMergedShuffleFileManager; import org.apache.spark.network.util.LevelDBProvider; import org.iq80.leveldb.DB; import org.iq80.leveldb.DBIterator; @@ -284,7 +285,7 @@ static MergedShuffleFileManager newMergedShuffleFileManagerInstance(TransportCon return mergeManagerSubClazz.getConstructor(TransportConf.class).newInstance(conf); } catch (Exception e) { logger.error("Unable to create an instance of {}", mergeManagerImplClassName); - return new ExternalBlockHandler.NoOpMergedShuffleFileManager(conf); + return new NoOpMergedShuffleFileManager(conf); } } diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index 3c003f45ed..a2f1380692 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -32,14 +32,18 @@ # - SPARK_LOCAL_DIRS, storage directories to use on this node for shuffle and RDD data # - MESOS_NATIVE_JAVA_LIBRARY, to point to your libmesos.so if you use Mesos -# Options read in YARN client/cluster mode +# Options read in any mode # - SPARK_CONF_DIR, Alternate conf dir. (Default: ${SPARK_HOME}/conf) -# - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files -# - YARN_CONF_DIR, to point Spark towards YARN configuration files when you use YARN # - SPARK_EXECUTOR_CORES, Number of cores for the executors (Default: 1). # - SPARK_EXECUTOR_MEMORY, Memory per Executor (e.g. 1000M, 2G) (Default: 1G) # - SPARK_DRIVER_MEMORY, Memory for Driver (e.g. 1000M, 2G) (Default: 1G) +# Options read in any cluster manager using HDFS +# - HADOOP_CONF_DIR, to point Spark towards Hadoop configuration files + +# Options read in YARN client/cluster mode +# - YARN_CONF_DIR, to point Spark towards YARN configuration files when you use YARN + # Options for the daemons used in the standalone deploy mode # - SPARK_MASTER_HOST, to bind the master to a different IP address or hostname # - SPARK_MASTER_PORT / SPARK_MASTER_WEBUI_PORT, to use non-default ports for the master diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala index a625b32895..34d36655a6 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala @@ -48,7 +48,7 @@ private[spark] class ChunkedByteBufferOutputStream( * This can also never be 0. 
*/ private[this] var position = chunkSize - private[this] var _size = 0 + private[this] var _size = 0L private[this] var closed: Boolean = false def size: Long = _size @@ -120,4 +120,5 @@ private[spark] class ChunkedByteBufferOutputStream( new ChunkedByteBuffer(ret) } } + } diff --git a/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala index 8696174567..29443e275f 100644 --- a/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala @@ -119,4 +119,14 @@ class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { assert(arrays(1).toSeq === ref.slice(10, 20)) assert(arrays(2).toSeq === ref.slice(20, 30)) } + + test("SPARK-36464: size returns correct positive number even with over 2GB data") { + val ref = new Array[Byte](1024 * 1024 * 1024) + val o = new ChunkedByteBufferOutputStream(1024 * 1024, ByteBuffer.allocate) + o.write(ref) + o.write(ref) + o.close() + assert(o.size > 0L) // make sure it is not overflowing + assert(o.size == ref.length.toLong * 2) + } } diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py index 376592dee3..fdd5523c54 100644 --- a/python/pyspark/pandas/groupby.py +++ b/python/pyspark/pandas/groupby.py @@ -118,7 +118,7 @@ def __init__( groupkeys: List[Series], as_index: bool, dropna: bool, - column_labels_to_exlcude: Set[Label], + column_labels_to_exclude: Set[Label], agg_columns_selected: bool, agg_columns: List[Series], ): @@ -126,7 +126,7 @@ def __init__( self._groupkeys = groupkeys self._as_index = as_index self._dropna = dropna - self._column_labels_to_exlcude = column_labels_to_exlcude + self._column_labels_to_exclude = column_labels_to_exclude self._agg_columns_selected = agg_columns_selected self._agg_columns = agg_columns @@ -1168,7 +1168,7 @@ def apply(self, func: Callable, *args: Any, **kwargs: Any) -> Union[DataFrame, S agg_columns = [ psdf._psser_for(label) for label in psdf._internal.column_labels - if label not in self._column_labels_to_exlcude + if label not in self._column_labels_to_exclude ] psdf, groupkey_labels, groupkey_names = GroupBy._prepare_group_map_apply( @@ -1365,7 +1365,7 @@ def filter(self, func: Callable[[FrameLike], FrameLike]) -> FrameLike: agg_columns = [ psdf._psser_for(label) for label in psdf._internal.column_labels - if label not in self._column_labels_to_exlcude + if label not in self._column_labels_to_exclude ] data_schema = ( @@ -1883,7 +1883,7 @@ def _limit(self, n: int, asc: bool) -> FrameLike: agg_columns = [ psdf._psser_for(label) for label in psdf._internal.column_labels - if label not in self._column_labels_to_exlcude + if label not in self._column_labels_to_exclude ] psdf, groupkey_labels, _ = GroupBy._prepare_group_map_apply( @@ -2701,17 +2701,17 @@ def _build( ( psdf, new_by_series, - column_labels_to_exlcude, + column_labels_to_exclude, ) = GroupBy._resolve_grouping_from_diff_dataframes(psdf, by) else: new_by_series = GroupBy._resolve_grouping(psdf, by) - column_labels_to_exlcude = set() + column_labels_to_exclude = set() return DataFrameGroupBy( psdf, new_by_series, as_index=as_index, dropna=dropna, - column_labels_to_exlcude=column_labels_to_exlcude, + column_labels_to_exclude=column_labels_to_exclude, ) def __init__( @@ -2720,20 +2720,20 @@ def __init__( by: List[Series], as_index: bool, dropna: bool, - column_labels_to_exlcude: Set[Label], + 
column_labels_to_exclude: Set[Label], agg_columns: List[Label] = None, ): agg_columns_selected = agg_columns is not None if agg_columns_selected: for label in agg_columns: - if label in column_labels_to_exlcude: + if label in column_labels_to_exclude: raise KeyError(label) else: agg_columns = [ label for label in psdf._internal.column_labels if not any(label == key._column_label and key._psdf is psdf for key in by) - and label not in column_labels_to_exlcude + and label not in column_labels_to_exclude ] super().__init__( @@ -2741,7 +2741,7 @@ def __init__( groupkeys=by, as_index=as_index, dropna=dropna, - column_labels_to_exlcude=column_labels_to_exlcude, + column_labels_to_exclude=column_labels_to_exclude, agg_columns_selected=agg_columns_selected, agg_columns=[psdf[label] for label in agg_columns], ) @@ -2781,7 +2781,7 @@ def __getitem__(self, item: Any) -> GroupBy: self._groupkeys, as_index=self._as_index, dropna=self._dropna, - column_labels_to_exlcude=self._column_labels_to_exlcude, + column_labels_to_exclude=self._column_labels_to_exclude, agg_columns=item, ) @@ -2925,7 +2925,7 @@ def __init__(self, psser: Series, by: List[Series], as_index: bool = True, dropn groupkeys=by, as_index=True, dropna=dropna, - column_labels_to_exlcude=set(), + column_labels_to_exclude=set(), agg_columns_selected=True, agg_columns=[psser], ) diff --git a/python/pyspark/pandas/tests/test_expanding.py b/python/pyspark/pandas/tests/test_expanding.py index 2cd5e5284c..d52ccbacf7 100644 --- a/python/pyspark/pandas/tests/test_expanding.py +++ b/python/pyspark/pandas/tests/test_expanding.py @@ -146,10 +146,8 @@ def _test_groupby_expanding_func(self, f): pdf = pd.DataFrame({"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}) psdf = ps.from_pandas(pdf) + # The behavior of GroupBy.expanding is changed from pandas 1.3. if LooseVersion(pd.__version__) >= LooseVersion("1.3"): - # TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3 - pass - else: self.assert_eq( getattr(psdf.groupby(psdf.a).expanding(2), f)().sort_index(), getattr(pdf.groupby(pdf.a).expanding(2), f)().sort_index(), @@ -162,6 +160,19 @@ def _test_groupby_expanding_func(self, f): getattr(psdf.groupby(psdf.a + 1).expanding(2), f)().sort_index(), getattr(pdf.groupby(pdf.a + 1).expanding(2), f)().sort_index(), ) + else: + self.assert_eq( + getattr(psdf.groupby(psdf.a).expanding(2), f)().sort_index(), + getattr(pdf.groupby(pdf.a).expanding(2), f)().drop("a", axis=1).sort_index(), + ) + self.assert_eq( + getattr(psdf.groupby(psdf.a).expanding(2), f)().sum(), + getattr(pdf.groupby(pdf.a).expanding(2), f)().sum().drop("a"), + ) + self.assert_eq( + getattr(psdf.groupby(psdf.a + 1).expanding(2), f)().sort_index(), + getattr(pdf.groupby(pdf.a + 1).expanding(2), f)().drop("a", axis=1).sort_index(), + ) self.assert_eq( getattr(psdf.b.groupby(psdf.a).expanding(2), f)().sort_index(), @@ -181,10 +192,8 @@ def _test_groupby_expanding_func(self, f): pdf.columns = columns psdf.columns = columns + # The behavior of GroupBy.expanding is changed from pandas 1.3. 
if LooseVersion(pd.__version__) >= LooseVersion("1.3"): - # TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3 - pass - else: self.assert_eq( getattr(psdf.groupby(("a", "x")).expanding(2), f)().sort_index(), getattr(pdf.groupby(("a", "x")).expanding(2), f)().sort_index(), @@ -194,6 +203,20 @@ def _test_groupby_expanding_func(self, f): getattr(psdf.groupby([("a", "x"), ("a", "y")]).expanding(2), f)().sort_index(), getattr(pdf.groupby([("a", "x"), ("a", "y")]).expanding(2), f)().sort_index(), ) + else: + self.assert_eq( + getattr(psdf.groupby(("a", "x")).expanding(2), f)().sort_index(), + getattr(pdf.groupby(("a", "x")).expanding(2), f)() + .drop(("a", "x"), axis=1) + .sort_index(), + ) + + self.assert_eq( + getattr(psdf.groupby([("a", "x"), ("a", "y")]).expanding(2), f)().sort_index(), + getattr(pdf.groupby([("a", "x"), ("a", "y")]).expanding(2), f)() + .drop([("a", "x"), ("a", "y")], axis=1) + .sort_index(), + ) def test_groupby_expanding_count(self): # The behaviour of ExpandingGroupby.count are different between pandas>=1.0.0 and lower, diff --git a/python/pyspark/pandas/tests/test_rolling.py b/python/pyspark/pandas/tests/test_rolling.py index 7409d6988c..3c9563c705 100644 --- a/python/pyspark/pandas/tests/test_rolling.py +++ b/python/pyspark/pandas/tests/test_rolling.py @@ -112,10 +112,8 @@ def _test_groupby_rolling_func(self, f): pdf = pd.DataFrame({"a": [1.0, 2.0, 3.0, 2.0], "b": [4.0, 2.0, 3.0, 1.0]}) psdf = ps.from_pandas(pdf) + # The behavior of GroupBy.rolling is changed from pandas 1.3. if LooseVersion(pd.__version__) >= LooseVersion("1.3"): - # TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3 - pass - else: self.assert_eq( getattr(psdf.groupby(psdf.a).rolling(2), f)().sort_index(), getattr(pdf.groupby(pdf.a).rolling(2), f)().sort_index(), @@ -128,6 +126,19 @@ def _test_groupby_rolling_func(self, f): getattr(psdf.groupby(psdf.a + 1).rolling(2), f)().sort_index(), getattr(pdf.groupby(pdf.a + 1).rolling(2), f)().sort_index(), ) + else: + self.assert_eq( + getattr(psdf.groupby(psdf.a).rolling(2), f)().sort_index(), + getattr(pdf.groupby(pdf.a).rolling(2), f)().drop("a", axis=1).sort_index(), + ) + self.assert_eq( + getattr(psdf.groupby(psdf.a).rolling(2), f)().sum(), + getattr(pdf.groupby(pdf.a).rolling(2), f)().sum().drop("a"), + ) + self.assert_eq( + getattr(psdf.groupby(psdf.a + 1).rolling(2), f)().sort_index(), + getattr(pdf.groupby(pdf.a + 1).rolling(2), f)().drop("a", axis=1).sort_index(), + ) self.assert_eq( getattr(psdf.b.groupby(psdf.a).rolling(2), f)().sort_index(), @@ -147,10 +158,8 @@ def _test_groupby_rolling_func(self, f): pdf.columns = columns psdf.columns = columns + # The behavior of GroupBy.rolling is changed from pandas 1.3. 
if LooseVersion(pd.__version__) >= LooseVersion("1.3"): - # TODO(SPARK-36367): Fix the behavior to follow pandas >= 1.3 - pass - else: self.assert_eq( getattr(psdf.groupby(("a", "x")).rolling(2), f)().sort_index(), getattr(pdf.groupby(("a", "x")).rolling(2), f)().sort_index(), @@ -160,6 +169,20 @@ def _test_groupby_rolling_func(self, f): getattr(psdf.groupby([("a", "x"), ("a", "y")]).rolling(2), f)().sort_index(), getattr(pdf.groupby([("a", "x"), ("a", "y")]).rolling(2), f)().sort_index(), ) + else: + self.assert_eq( + getattr(psdf.groupby(("a", "x")).rolling(2), f)().sort_index(), + getattr(pdf.groupby(("a", "x")).rolling(2), f)() + .drop(("a", "x"), axis=1) + .sort_index(), + ) + + self.assert_eq( + getattr(psdf.groupby([("a", "x"), ("a", "y")]).rolling(2), f)().sort_index(), + getattr(pdf.groupby([("a", "x"), ("a", "y")]).rolling(2), f)() + .drop([("a", "x"), ("a", "y")], axis=1) + .sort_index(), + ) def test_groupby_rolling_count(self): self._test_groupby_rolling_func("count") diff --git a/python/pyspark/pandas/window.py b/python/pyspark/pandas/window.py index 0d656c229a..68d87fbfd1 100644 --- a/python/pyspark/pandas/window.py +++ b/python/pyspark/pandas/window.py @@ -36,7 +36,7 @@ # For running doctests and reference resolution in PyCharm. from pyspark import pandas as ps # noqa: F401 from pyspark.pandas._typing import FrameLike -from pyspark.pandas.groupby import GroupBy +from pyspark.pandas.groupby import GroupBy, DataFrameGroupBy from pyspark.pandas.internal import NATURAL_ORDER_COLUMN_NAME, SPARK_INDEX_NAME_FORMAT from pyspark.pandas.spark import functions as SF from pyspark.pandas.utils import scol_for @@ -706,10 +706,15 @@ def _apply_as_series_or_frame(self, func: Callable[[Column], Column]) -> FrameLi if groupby._agg_columns_selected: agg_columns = groupby._agg_columns else: + # pandas doesn't keep the groupkey as a column from 1.3 for DataFrameGroupBy + column_labels_to_exclude = groupby._column_labels_to_exclude.copy() + if isinstance(groupby, DataFrameGroupBy): + for groupkey in groupby._groupkeys: # type: ignore + column_labels_to_exclude.add(groupkey._internal.column_labels[0]) agg_columns = [ psdf._psser_for(label) for label in psdf._internal.column_labels - if label not in groupby._column_labels_to_exlcude + if label not in column_labels_to_exclude ] applied = [] @@ -777,19 +782,19 @@ def count(self) -> FrameLike: >>> df = ps.DataFrame({"A": s.to_numpy(), "B": s.to_numpy() ** 2}) >>> df.groupby(df.A).rolling(2).count().sort_index() # doctest: +NORMALIZE_WHITESPACE - A B + B A - 2 0 1.0 1.0 - 1 2.0 2.0 - 3 2 1.0 1.0 - 3 2.0 2.0 - 4 2.0 2.0 - 4 5 1.0 1.0 - 6 2.0 2.0 - 7 2.0 2.0 - 8 2.0 2.0 - 5 9 1.0 1.0 - 10 2.0 2.0 + 2 0 1.0 + 1 2.0 + 3 2 1.0 + 3 2.0 + 4 2.0 + 4 5 1.0 + 6 2.0 + 7 2.0 + 8 2.0 + 5 9 1.0 + 10 2.0 """ return super().count() @@ -831,19 +836,19 @@ def sum(self) -> FrameLike: >>> df = ps.DataFrame({"A": s.to_numpy(), "B": s.to_numpy() ** 2}) >>> df.groupby(df.A).rolling(2).sum().sort_index() # doctest: +NORMALIZE_WHITESPACE - A B + B A - 2 0 NaN NaN - 1 4.0 8.0 - 3 2 NaN NaN - 3 6.0 18.0 - 4 6.0 18.0 - 4 5 NaN NaN - 6 8.0 32.0 - 7 8.0 32.0 - 8 8.0 32.0 - 5 9 NaN NaN - 10 10.0 50.0 + 2 0 NaN + 1 8.0 + 3 2 NaN + 3 18.0 + 4 18.0 + 4 5 NaN + 6 32.0 + 7 32.0 + 8 32.0 + 5 9 NaN + 10 50.0 """ return super().sum() @@ -885,19 +890,19 @@ def min(self) -> FrameLike: >>> df = ps.DataFrame({"A": s.to_numpy(), "B": s.to_numpy() ** 2}) >>> df.groupby(df.A).rolling(2).min().sort_index() # doctest: +NORMALIZE_WHITESPACE - A B + B A - 2 0 NaN NaN - 1 2.0 4.0 - 3 2 NaN NaN - 3 3.0 
9.0 - 4 3.0 9.0 - 4 5 NaN NaN - 6 4.0 16.0 - 7 4.0 16.0 - 8 4.0 16.0 - 5 9 NaN NaN - 10 5.0 25.0 + 2 0 NaN + 1 4.0 + 3 2 NaN + 3 9.0 + 4 9.0 + 4 5 NaN + 6 16.0 + 7 16.0 + 8 16.0 + 5 9 NaN + 10 25.0 """ return super().min() @@ -939,19 +944,19 @@ def max(self) -> FrameLike: >>> df = ps.DataFrame({"A": s.to_numpy(), "B": s.to_numpy() ** 2}) >>> df.groupby(df.A).rolling(2).max().sort_index() # doctest: +NORMALIZE_WHITESPACE - A B + B A - 2 0 NaN NaN - 1 2.0 4.0 - 3 2 NaN NaN - 3 3.0 9.0 - 4 3.0 9.0 - 4 5 NaN NaN - 6 4.0 16.0 - 7 4.0 16.0 - 8 4.0 16.0 - 5 9 NaN NaN - 10 5.0 25.0 + 2 0 NaN + 1 4.0 + 3 2 NaN + 3 9.0 + 4 9.0 + 4 5 NaN + 6 16.0 + 7 16.0 + 8 16.0 + 5 9 NaN + 10 25.0 """ return super().max() @@ -993,19 +998,19 @@ def mean(self) -> FrameLike: >>> df = ps.DataFrame({"A": s.to_numpy(), "B": s.to_numpy() ** 2}) >>> df.groupby(df.A).rolling(2).mean().sort_index() # doctest: +NORMALIZE_WHITESPACE - A B + B A - 2 0 NaN NaN - 1 2.0 4.0 - 3 2 NaN NaN - 3 3.0 9.0 - 4 3.0 9.0 - 4 5 NaN NaN - 6 4.0 16.0 - 7 4.0 16.0 - 8 4.0 16.0 - 5 9 NaN NaN - 10 5.0 25.0 + 2 0 NaN + 1 4.0 + 3 2 NaN + 3 9.0 + 4 9.0 + 4 5 NaN + 6 16.0 + 7 16.0 + 8 16.0 + 5 9 NaN + 10 25.0 """ return super().mean() @@ -1478,19 +1483,19 @@ def count(self) -> FrameLike: >>> df = ps.DataFrame({"A": s.to_numpy(), "B": s.to_numpy() ** 2}) >>> df.groupby(df.A).expanding(2).count().sort_index() # doctest: +NORMALIZE_WHITESPACE - A B + B A - 2 0 NaN NaN - 1 2.0 2.0 - 3 2 NaN NaN - 3 2.0 2.0 - 4 3.0 3.0 - 4 5 NaN NaN - 6 2.0 2.0 - 7 3.0 3.0 - 8 4.0 4.0 - 5 9 NaN NaN - 10 2.0 2.0 + 2 0 NaN + 1 2.0 + 3 2 NaN + 3 2.0 + 4 3.0 + 4 5 NaN + 6 2.0 + 7 3.0 + 8 4.0 + 5 9 NaN + 10 2.0 """ return super().count() @@ -1532,19 +1537,19 @@ def sum(self) -> FrameLike: >>> df = ps.DataFrame({"A": s.to_numpy(), "B": s.to_numpy() ** 2}) >>> df.groupby(df.A).expanding(2).sum().sort_index() # doctest: +NORMALIZE_WHITESPACE - A B + B A - 2 0 NaN NaN - 1 4.0 8.0 - 3 2 NaN NaN - 3 6.0 18.0 - 4 9.0 27.0 - 4 5 NaN NaN - 6 8.0 32.0 - 7 12.0 48.0 - 8 16.0 64.0 - 5 9 NaN NaN - 10 10.0 50.0 + 2 0 NaN + 1 8.0 + 3 2 NaN + 3 18.0 + 4 27.0 + 4 5 NaN + 6 32.0 + 7 48.0 + 8 64.0 + 5 9 NaN + 10 50.0 """ return super().sum() @@ -1586,19 +1591,19 @@ def min(self) -> FrameLike: >>> df = ps.DataFrame({"A": s.to_numpy(), "B": s.to_numpy() ** 2}) >>> df.groupby(df.A).expanding(2).min().sort_index() # doctest: +NORMALIZE_WHITESPACE - A B + B A - 2 0 NaN NaN - 1 2.0 4.0 - 3 2 NaN NaN - 3 3.0 9.0 - 4 3.0 9.0 - 4 5 NaN NaN - 6 4.0 16.0 - 7 4.0 16.0 - 8 4.0 16.0 - 5 9 NaN NaN - 10 5.0 25.0 + 2 0 NaN + 1 4.0 + 3 2 NaN + 3 9.0 + 4 9.0 + 4 5 NaN + 6 16.0 + 7 16.0 + 8 16.0 + 5 9 NaN + 10 25.0 """ return super().min() @@ -1639,19 +1644,19 @@ def max(self) -> FrameLike: >>> df = ps.DataFrame({"A": s.to_numpy(), "B": s.to_numpy() ** 2}) >>> df.groupby(df.A).expanding(2).max().sort_index() # doctest: +NORMALIZE_WHITESPACE - A B + B A - 2 0 NaN NaN - 1 2.0 4.0 - 3 2 NaN NaN - 3 3.0 9.0 - 4 3.0 9.0 - 4 5 NaN NaN - 6 4.0 16.0 - 7 4.0 16.0 - 8 4.0 16.0 - 5 9 NaN NaN - 10 5.0 25.0 + 2 0 NaN + 1 4.0 + 3 2 NaN + 3 9.0 + 4 9.0 + 4 5 NaN + 6 16.0 + 7 16.0 + 8 16.0 + 5 9 NaN + 10 25.0 """ return super().max() @@ -1693,19 +1698,19 @@ def mean(self) -> FrameLike: >>> df = ps.DataFrame({"A": s.to_numpy(), "B": s.to_numpy() ** 2}) >>> df.groupby(df.A).expanding(2).mean().sort_index() # doctest: +NORMALIZE_WHITESPACE - A B + B A - 2 0 NaN NaN - 1 2.0 4.0 - 3 2 NaN NaN - 3 3.0 9.0 - 4 3.0 9.0 - 4 5 NaN NaN - 6 4.0 16.0 - 7 4.0 16.0 - 8 4.0 16.0 - 5 9 NaN NaN - 10 5.0 25.0 + 2 0 NaN + 1 4.0 + 3 2 NaN + 3 9.0 + 4 
9.0 + 4 5 NaN + 6 16.0 + 7 16.0 + 8 16.0 + 5 9 NaN + 10 25.0 """ return super().mean() diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala index fb4097304d..b2025aa349 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala @@ -43,7 +43,7 @@ import org.scalatest.matchers.should.Matchers._ import org.apache.spark.SecurityManager import org.apache.spark.SparkFunSuite import org.apache.spark.internal.config._ -import org.apache.spark.network.shuffle.{ExternalBlockHandler, RemoteBlockPushResolver, ShuffleTestAccessor} +import org.apache.spark.network.shuffle.{NoOpMergedShuffleFileManager, RemoteBlockPushResolver, ShuffleTestAccessor} import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.network.util.TransportConf import org.apache.spark.util.Utils @@ -434,9 +434,9 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd test("create default merged shuffle file manager instance") { val mockConf = mock(classOf[TransportConf]) when(mockConf.mergedShuffleFileManagerImpl).thenReturn( - "org.apache.spark.network.shuffle.ExternalBlockHandler$NoOpMergedShuffleFileManager") + "org.apache.spark.network.shuffle.NoOpMergedShuffleFileManager") val mergeMgr = YarnShuffleService.newMergedShuffleFileManagerInstance(mockConf) - assert(mergeMgr.isInstanceOf[ExternalBlockHandler.NoOpMergedShuffleFileManager]) + assert(mergeMgr.isInstanceOf[NoOpMergedShuffleFileManager]) } test("create remote block push resolver instance") { @@ -452,6 +452,6 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd when(mockConf.mergedShuffleFileManagerImpl).thenReturn( "org.apache.spark.network.shuffle.NotExistent") val mergeMgr = YarnShuffleService.newMergedShuffleFileManagerInstance(mockConf) - assert(mergeMgr.isInstanceOf[ExternalBlockHandler.NoOpMergedShuffleFileManager]) + assert(mergeMgr.isInstanceOf[NoOpMergedShuffleFileManager]) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index 457dc10028..f03296fdee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -120,6 +120,11 @@ object AnsiTypeCoercion extends TypeCoercionBase { case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) => Some(TimestampType) + case (t1: DayTimeIntervalType, t2: DayTimeIntervalType) => + Some(DayTimeIntervalType(t1.startField.min(t2.startField), t1.endField.max(t2.endField))) + case (t1: YearMonthIntervalType, t2: YearMonthIntervalType) => + Some(YearMonthIntervalType(t1.startField.min(t2.startField), t1.endField.max(t2.endField))) + case (t1, t2) => findTypeForComplex(t1, t2, findTightestCommonType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 9dc5db8205..d38327a3c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -933,26 +933,33 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { * Validates the options used for alter table commands after table and columns are resolved. */ private def checkAlterTableCommand(alter: AlterTableCommand): Unit = { - def checkColumnNotExists( - op: String, fieldNames: Seq[String], struct: StructType, r: Resolver): Unit = { - if (struct.findNestedField(fieldNames, includeCollections = true, r).isDefined) { + def checkColumnNotExists(op: String, fieldNames: Seq[String], struct: StructType): Unit = { + if (struct.findNestedField( + fieldNames, includeCollections = true, alter.conf.resolver).isDefined) { alter.failAnalysis(s"Cannot $op column, because ${fieldNames.quoted} " + s"already exists in ${struct.treeString}") } } + def checkColumnNameDuplication(colsToAdd: Seq[QualifiedColType]): Unit = { + SchemaUtils.checkColumnNameDuplication( + colsToAdd.map(_.name.quoted), + "in the user specified columns", + alter.conf.resolver) + } + alter match { case AddColumns(table: ResolvedTable, colsToAdd) => colsToAdd.foreach { colToAdd => - checkColumnNotExists("add", colToAdd.name, table.schema, alter.conf.resolver) + checkColumnNotExists("add", colToAdd.name, table.schema) } - SchemaUtils.checkColumnNameDuplication( - colsToAdd.map(_.name.quoted), - "in the user specified columns", - alter.conf.resolver) + checkColumnNameDuplication(colsToAdd) + + case ReplaceColumns(_: ResolvedTable, colsToAdd) => + checkColumnNameDuplication(colsToAdd) case RenameColumn(table: ResolvedTable, col: ResolvedFieldName, newName) => - checkColumnNotExists("rename", col.path :+ newName, table.schema, alter.conf.resolver) + checkColumnNotExists("rename", col.path :+ newName, table.schema) case a @ AlterColumn(table: ResolvedTable, col: ResolvedFieldName, _, _, _, _) => val fieldName = col.name.quoted diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 42c10e8a11..db6f499f2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -867,6 +867,11 @@ object TypeCoercion extends TypeCoercionBase { case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) => Some(TimestampType) + case (t1: DayTimeIntervalType, t2: DayTimeIntervalType) => + Some(DayTimeIntervalType(t1.startField.min(t2.startField), t1.endField.max(t2.endField))) + case (t1: YearMonthIntervalType, t2: YearMonthIntervalType) => + Some(YearMonthIntervalType(t1.startField.min(t2.startField), t1.endField.max(t2.endField))) + case (_: TimestampNTZType, _: DateType) | (_: DateType, _: TimestampNTZType) => Some(TimestampNTZType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index dfa746f7c7..2a90a99d4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -330,12 +330,13 @@ class JacksonParser( case udt: UserDefinedType[_] => makeConverter(udt.sqlType) - case _ => - (parser: JsonParser) => - // Here, we pass empty `PartialFunction` so that this case can be - // handled as a failed conversion. 
It will throw an exception as - // long as the value is not null. - parseJsonToken[AnyRef](parser, dataType)(PartialFunction.empty[JsonToken, AnyRef]) + case _: NullType => + (parser: JsonParser) => parseJsonToken[java.lang.Long](parser, dataType) { + case _ => null + } + + // We don't actually hit this exception though, we keep it for understandability + case _ => throw QueryExecutionErrors.unsupportedTypeError(dataType) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 4633a36c40..d9efdc122c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -599,7 +599,7 @@ object View { "spark.sql.hive.convertMetastoreOrc", "spark.sql.hive.convertInsertingPartitionedTable", "spark.sql.hive.convertMetastoreCtas" - ).contains(key)) + ).contains(key) || key.startsWith("spark.sql.catalog.")) for ((k, v) <- configs ++ retainedConfigs) { sqlConf.settings.put(k, v) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 602daf8dde..6a7d7ef988 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.analysis import java.sql.Timestamp +import java.time.{Duration, Period} import org.apache.spark.internal.config.Tests.IS_TESTING import org.apache.spark.sql.catalyst.analysis.TypeCoercion._ @@ -1604,6 +1605,52 @@ class TypeCoercionSuite extends AnalysisTest { ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2, 1L), IntegralDivide(Cast(2, LongType), 1L)) } + + test("SPARK-36431: Support TypeCoercion of ANSI intervals with different fields") { + DataTypeTestUtils.yearMonthIntervalTypes.foreach { ym1 => + DataTypeTestUtils.yearMonthIntervalTypes.foreach { ym2 => + val literal1 = Literal.create(Period.ofMonths(12), ym1) + val literal2 = Literal.create(Period.ofMonths(12), ym2) + val commonType = YearMonthIntervalType( + ym1.startField.min(ym2.startField), ym1.endField.max(ym2.endField)) + if (commonType == ym1 && commonType == ym2) { + ruleTest(TypeCoercion.ImplicitTypeCasts, EqualTo(literal1, literal2), + EqualTo(literal1, literal2)) + } else if (commonType == ym1) { + ruleTest(TypeCoercion.ImplicitTypeCasts, EqualTo(literal1, literal2), + EqualTo(literal1, Cast(literal2, commonType))) + } else if (commonType == ym2) { + ruleTest(TypeCoercion.ImplicitTypeCasts, EqualTo(literal1, literal2), + EqualTo(Cast(literal1, commonType), literal2)) + } else { + ruleTest(TypeCoercion.ImplicitTypeCasts, EqualTo(literal1, literal2), + EqualTo(Cast(literal1, commonType), Cast(literal2, commonType))) + } + } + } + + DataTypeTestUtils.dayTimeIntervalTypes.foreach { dt1 => + DataTypeTestUtils.dayTimeIntervalTypes.foreach { dt2 => + val literal1 = Literal.create(Duration.ofSeconds(1111), dt1) + val literal2 = Literal.create(Duration.ofSeconds(1111), dt2) + val commonType = DayTimeIntervalType( + dt1.startField.min(dt2.startField), dt1.endField.max(dt2.endField)) + if (commonType == dt1 && commonType == dt2) { + 
ruleTest(TypeCoercion.ImplicitTypeCasts, EqualTo(literal1, literal2), + EqualTo(literal1, literal2)) + } else if (commonType == dt1) { + ruleTest(TypeCoercion.ImplicitTypeCasts, EqualTo(literal1, literal2), + EqualTo(literal1, Cast(literal2, commonType))) + } else if (commonType == dt2) { + ruleTest(TypeCoercion.ImplicitTypeCasts, EqualTo(literal1, literal2), + EqualTo(Cast(literal1, commonType), literal2)) + } else { + ruleTest(TypeCoercion.ImplicitTypeCasts, EqualTo(literal1, literal2), + EqualTo(Cast(literal1, commonType), Cast(literal2, commonType))) + } + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala index a8c813a03e..dc59526bb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala @@ -53,7 +53,8 @@ case class OrcScanBuilder( override def pushFilters(filters: Array[Filter]): Array[Filter] = { if (sparkSession.sessionState.conf.orcFilterPushDown) { - val dataTypeMap = OrcFilters.getSearchableTypeMap(schema, SQLConf.get.caseSensitiveAnalysis) + val dataTypeMap = OrcFilters.getSearchableTypeMap( + readDataSchema(), SQLConf.get.caseSensitiveAnalysis) _pushedFilters = OrcFilters.convertibleFilters(dataTypeMap, filters).toArray } filters diff --git a/sql/core/src/test/resources/sql-tests/inputs/interval.sql b/sql/core/src/test/resources/sql-tests/inputs/interval.sql index 43d1e03fab..a16d152816 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/interval.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/interval.sql @@ -322,3 +322,16 @@ SELECT INTERVAL '153722867280' MINUTE; SELECT INTERVAL '-153722867280' MINUTE; SELECT INTERVAL '54.775807' SECOND; SELECT INTERVAL '-54.775807' SECOND; + +SELECT INTERVAL '1' DAY > INTERVAL '1' HOUR; +SELECT INTERVAL '1 02' DAY TO HOUR = INTERVAL '02:10:55' HOUR TO SECOND; +SELECT INTERVAL '1' YEAR < INTERVAL '1' MONTH; +SELECT INTERVAL '-1-1' YEAR TO MONTH = INTERVAL '-13' MONTH; +SELECT INTERVAL 1 MONTH > INTERVAL 20 DAYS; + +SELECT array(INTERVAL '1' YEAR, INTERVAL '1' MONTH); +SELECT array(INTERVAL '1' DAY, INTERVAL '01:01' HOUR TO MINUTE); +SELECT array(INTERVAL 1 MONTH, INTERVAL 20 DAYS); +SELECT coalesce(INTERVAL '1' YEAR, INTERVAL '1' MONTH); +SELECT coalesce(INTERVAL '1' DAY, INTERVAL '01:01' HOUR TO MINUTE); +SELECT coalesce(INTERVAL 1 MONTH, INTERVAL 20 DAYS); diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out index 9bf492ef35..9ba5da322c 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 200 +-- Number of queries: 211 -- !query @@ -818,10 +818,9 @@ struct> -- !query select map(1, interval 1 year, 2, interval 2 month) -- !query schema -struct<> +struct> -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'map(1, INTERVAL '1' YEAR, 2, INTERVAL '2' MONTH)' due to data type mismatch: The given values of function map should all be the same type, but they are [interval year, interval month]; line 1 pos 7 +{1:1-0,2:0-2} -- !query @@ -1985,3 +1984,94 @@ SELECT INTERVAL '-54.775807' SECOND struct -- !query output -0 
00:00:54.775807000 + + +-- !query +SELECT INTERVAL '1' DAY > INTERVAL '1' HOUR +-- !query schema +struct<(INTERVAL '1' DAY > INTERVAL '01' HOUR):boolean> +-- !query output +true + + +-- !query +SELECT INTERVAL '1 02' DAY TO HOUR = INTERVAL '02:10:55' HOUR TO SECOND +-- !query schema +struct<(INTERVAL '1 02' DAY TO HOUR = INTERVAL '02:10:55' HOUR TO SECOND):boolean> +-- !query output +false + + +-- !query +SELECT INTERVAL '1' YEAR < INTERVAL '1' MONTH +-- !query schema +struct<(INTERVAL '1' YEAR < INTERVAL '1' MONTH):boolean> +-- !query output +false + + +-- !query +SELECT INTERVAL '-1-1' YEAR TO MONTH = INTERVAL '-13' MONTH +-- !query schema +struct<(INTERVAL '-1-1' YEAR TO MONTH = INTERVAL '-13' MONTH):boolean> +-- !query output +true + + +-- !query +SELECT INTERVAL 1 MONTH > INTERVAL 20 DAYS +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +cannot resolve '(INTERVAL '1' MONTH > INTERVAL '20' DAY)' due to data type mismatch: differing types in '(INTERVAL '1' MONTH > INTERVAL '20' DAY)' (interval month and interval day).; line 1 pos 7 + + +-- !query +SELECT array(INTERVAL '1' YEAR, INTERVAL '1' MONTH) +-- !query schema +struct> +-- !query output +[1-0,0-1] + + +-- !query +SELECT array(INTERVAL '1' DAY, INTERVAL '01:01' HOUR TO MINUTE) +-- !query schema +struct> +-- !query output +[1 00:00:00.000000000,0 01:01:00.000000000] + + +-- !query +SELECT array(INTERVAL 1 MONTH, INTERVAL 20 DAYS) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +cannot resolve 'array(INTERVAL '1' MONTH, INTERVAL '20' DAY)' due to data type mismatch: input to function array should all be the same type, but it's [interval month, interval day]; line 1 pos 7 + + +-- !query +SELECT coalesce(INTERVAL '1' YEAR, INTERVAL '1' MONTH) +-- !query schema +struct +-- !query output +1-0 + + +-- !query +SELECT coalesce(INTERVAL '1' DAY, INTERVAL '01:01' HOUR TO MINUTE) +-- !query schema +struct +-- !query output +1 00:00:00.000000000 + + +-- !query +SELECT coalesce(INTERVAL 1 MONTH, INTERVAL 20 DAYS) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +cannot resolve 'coalesce(INTERVAL '1' MONTH, INTERVAL '20' DAY)' due to data type mismatch: input to function coalesce should all be the same type, but it's [interval month, interval day]; line 1 pos 7 diff --git a/sql/core/src/test/resources/sql-tests/results/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/interval.sql.out index 8780365f64..a15cc23672 100644 --- a/sql/core/src/test/resources/sql-tests/results/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/interval.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 200 +-- Number of queries: 211 -- !query @@ -817,10 +817,9 @@ struct> -- !query select map(1, interval 1 year, 2, interval 2 month) -- !query schema -struct<> +struct> -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve 'map(1, INTERVAL '1' YEAR, 2, INTERVAL '2' MONTH)' due to data type mismatch: The given values of function map should all be the same type, but they are [interval year, interval month]; line 1 pos 7 +{1:1-0,2:0-2} -- !query @@ -1984,3 +1983,94 @@ SELECT INTERVAL '-54.775807' SECOND struct -- !query output -0 00:00:54.775807000 + + +-- !query +SELECT INTERVAL '1' DAY > INTERVAL '1' HOUR +-- !query schema +struct<(INTERVAL '1' DAY > INTERVAL '01' HOUR):boolean> +-- !query output +true + + +-- !query +SELECT INTERVAL '1 02' DAY TO HOUR = INTERVAL 
'02:10:55' HOUR TO SECOND +-- !query schema +struct<(INTERVAL '1 02' DAY TO HOUR = INTERVAL '02:10:55' HOUR TO SECOND):boolean> +-- !query output +false + + +-- !query +SELECT INTERVAL '1' YEAR < INTERVAL '1' MONTH +-- !query schema +struct<(INTERVAL '1' YEAR < INTERVAL '1' MONTH):boolean> +-- !query output +false + + +-- !query +SELECT INTERVAL '-1-1' YEAR TO MONTH = INTERVAL '-13' MONTH +-- !query schema +struct<(INTERVAL '-1-1' YEAR TO MONTH = INTERVAL '-13' MONTH):boolean> +-- !query output +true + + +-- !query +SELECT INTERVAL 1 MONTH > INTERVAL 20 DAYS +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +cannot resolve '(INTERVAL '1' MONTH > INTERVAL '20' DAY)' due to data type mismatch: differing types in '(INTERVAL '1' MONTH > INTERVAL '20' DAY)' (interval month and interval day).; line 1 pos 7 + + +-- !query +SELECT array(INTERVAL '1' YEAR, INTERVAL '1' MONTH) +-- !query schema +struct> +-- !query output +[1-0,0-1] + + +-- !query +SELECT array(INTERVAL '1' DAY, INTERVAL '01:01' HOUR TO MINUTE) +-- !query schema +struct> +-- !query output +[1 00:00:00.000000000,0 01:01:00.000000000] + + +-- !query +SELECT array(INTERVAL 1 MONTH, INTERVAL 20 DAYS) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +cannot resolve 'array(INTERVAL '1' MONTH, INTERVAL '20' DAY)' due to data type mismatch: input to function array should all be the same type, but it's [interval month, interval day]; line 1 pos 7 + + +-- !query +SELECT coalesce(INTERVAL '1' YEAR, INTERVAL '1' MONTH) +-- !query schema +struct +-- !query output +1-0 + + +-- !query +SELECT coalesce(INTERVAL '1' DAY, INTERVAL '01:01' HOUR TO MINUTE) +-- !query schema +struct +-- !query output +1 00:00:00.000000000 + + +-- !query +SELECT coalesce(INTERVAL 1 MONTH, INTERVAL 20 DAYS) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +cannot resolve 'coalesce(INTERVAL '1' MONTH, INTERVAL '20' DAY)' due to data type mismatch: input to function coalesce should all be the same type, but it's [interval month, interval day]; line 1 pos 7 diff --git a/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out b/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out index fe83675617..fae7721542 100644 --- a/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp-ansi.sql.out @@ -661,9 +661,10 @@ You may get a different result due to the upgrading of Spark 3.0: Fail to recogn -- !query select from_json('{"t":"26/October/2015"}', 't Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')) -- !query schema -struct> +struct<> -- !query output -{"t":null} +java.lang.Exception +Unsupported type: timestamp_ntz -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp.sql.out b/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp.sql.out index b8a68005eb..c6de535807 100644 --- a/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/timestampNTZ/timestamp.sql.out @@ -642,9 +642,10 @@ You may get a different result due to the upgrading of Spark 3.0: Fail to recogn -- !query select from_json('{"t":"26/October/2015"}', 't Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')) -- !query schema -struct> +struct<> -- !query output -{"t":null} +java.lang.Exception +Unsupported type: timestamp_ntz -- 
!query diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala index e7fd139c73..396d227218 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala @@ -460,7 +460,7 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite "parquet" -> "|PushedFilters: \\[IsNotNull\\(value\\), GreaterThan\\(value,2\\)\\]", "orc" -> - "|PushedFilters: \\[.*\\(id\\), .*\\(value\\), .*\\(id,1\\), .*\\(value,2\\)\\]", + "|PushedFilters: \\[IsNotNull\\(value\\), GreaterThan\\(value,2\\)\\]", "csv" -> "|PushedFilters: \\[IsNotNull\\(value\\), GreaterThan\\(value,2\\)\\]", "json" -> diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala index 1bd45f50f5..1b0898fbc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala @@ -1175,4 +1175,15 @@ trait AlterTableTests extends SharedSparkSession { StructField("col3", IntegerType).withComment("c3")))) } } + + test("SPARK-36449: Replacing columns with duplicate name should not be allowed") { + val t = s"${catalogAndNamespace}table_name" + withTable(t) { + sql(s"CREATE TABLE $t (data string) USING $v2Format") + val e = intercept[AnalysisException] { + sql(s"ALTER TABLE $t REPLACE COLUMNS (data string, data1 string, data string)") + } + assert(e.message.contains("Found duplicate column(s) in the user specified columns: `data`")) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala index 0cc8d05361..f262cf152c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.connector import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, CreateTablePartitioningValidationSuite, ResolvedTable, TestRelation2, TestTable2, UnresolvedFieldName, UnresolvedFieldPosition} -import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterTableCommand, CreateTableAsSelect, DropColumns, LogicalPlan, QualifiedColType, RenameColumn, ReplaceTableAsSelect} +import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterTableCommand, CreateTableAsSelect, DropColumns, LogicalPlan, QualifiedColType, RenameColumn, ReplaceColumns, ReplaceTableAsSelect} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition @@ -316,8 +316,18 @@ class V2CommandsCaseSensitivitySuite extends SharedSparkSession with AnalysisTes } } + test("SPARK-36449: Replacing columns with duplicate name should not be allowed") { + alterTableTest( + ReplaceColumns( + table, + Seq(QualifiedColType(None, "f", LongType, true, None, None), + QualifiedColType(None, "F", LongType, true, None, None))), + Seq("Found duplicate column(s) in the user specified columns: `f`"), + expectErrorOnCaseSensitive = false) + } + private def alterTableTest( - alter: AlterTableCommand, + alter: => AlterTableCommand, error: 
Seq[String], expectErrorOnCaseSensitive: Boolean = true): Unit = { Seq(true, false).foreach { caseSensitive => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala index bc64f51ae8..8383d442c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewTestSuite.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.execution +import scala.collection.JavaConverters._ + import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.logical.Repartition +import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.internal.SQLConf._ import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} @@ -394,6 +397,20 @@ class GlobalTempViewTestSuite extends SQLViewTestSuite with SharedSparkSession { } } +class OneTableCatalog extends InMemoryCatalog { + override def loadTable(ident: Identifier): Table = { + if (ident.namespace.isEmpty && ident.name == "t") { + new InMemoryTable( + "t", + StructType.fromDDL("c1 INT"), + Array.empty, + Map.empty[String, String].asJava) + } else { + super.loadTable(ident) + } + } +} + class PersistedViewTestSuite extends SQLViewTestSuite with SharedSparkSession { private def db: String = "default" override protected def viewTypeString: String = "VIEW" @@ -493,4 +510,24 @@ class PersistedViewTestSuite extends SQLViewTestSuite with SharedSparkSession { } } } + + test("SPARK-36466: Table in unloaded catalog referenced by view should load correctly") { + val viewName = "v" + val tableInOtherCatalog = "cat.t" + try { + spark.conf.set("spark.sql.catalog.cat", classOf[OneTableCatalog].getName) + withTable(tableInOtherCatalog) { + withView(viewName) { + createView(viewName, s"SELECT count(*) AS cnt FROM $tableInOtherCatalog") + checkViewOutput(viewName, Seq(Row(0))) + spark.sessionState.catalogManager.reset() + checkViewOutput(viewName, Seq(Row(0))) + } + } + } finally { + spark.sessionState.catalog.reset() + spark.sessionState.catalogManager.reset() + spark.sessionState.conf.clear() + } + } }
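
To make the core SPARK-36464 change above concrete, here is a minimal, standalone Scala sketch (not part of the patch; the object name, chunk size, and chunk count are invented for illustration) showing why an `Int`-typed byte counter goes negative once more than 2GB has been written, while the `Long` counter introduced by this patch stays correct. It mirrors the new unit test in `ChunkedByteBufferOutputStreamSuite`, which writes two 1GB arrays and asserts that `size` remains positive.
```scala
// Standalone illustration only (hypothetical object, not Spark code):
// accumulating byte counts in an Int wraps past Int.MaxValue (~2GB),
// which is why ChunkedByteBufferOutputStream._size is now initialized as 0L.
object SizeOverflowSketch {
  def main(args: Array[String]): Unit = {
    val chunkSize = 1024 * 1024      // 1MB per simulated write
    val chunks    = 2 * 1024 + 1     // 2049 chunks => just over 2GB in total

    var intSize: Int   = 0           // mirrors the old `private[this] var _size = 0`
    var longSize: Long = 0L          // mirrors the fixed `private[this] var _size = 0L`

    var i = 0
    while (i < chunks) {
      intSize  += chunkSize          // silently overflows once the total exceeds Int.MaxValue
      longSize += chunkSize          // Int operand is widened to Long; stays correct
      i += 1
    }

    println(s"Int counter:  $intSize")   // prints -2146435072 (overflowed, negative)
    println(s"Long counter: $longSize")  // prints 2148532224 (correct, > 2GB)
  }
}
```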