bootstraps = Lists.newArrayList();
boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE);
if (authEnabled) {
- createSecretManager();
+ secretManager = new ShuffleSecretManager();
+ if (_recoveryPath != null) {
+ loadSecretsFromDb();
+ }
bootstraps.add(new AuthServerBootstrap(transportConf, secretManager));
}
@@ -215,13 +221,12 @@ protected void serviceInit(Configuration conf) throws Exception {
}
}
- private void createSecretManager() throws IOException {
- secretManager = new ShuffleSecretManager();
+ private void loadSecretsFromDb() throws IOException {
secretsFile = initRecoveryDb(SECRETS_RECOVERY_FILE_NAME);
// Make sure this is protected in case its not in the NM recovery dir
FileSystem fs = FileSystem.getLocal(_conf);
- fs.mkdirs(new Path(secretsFile.getPath()), new FsPermission((short)0700));
+ fs.mkdirs(new Path(secretsFile.getPath()), new FsPermission((short) 0700));
db = LevelDBProvider.initLevelDB(secretsFile, CURRENT_VERSION, mapper);
logger.info("Recovery location is: " + secretsFile.getPath());
@@ -338,10 +343,10 @@ public ByteBuffer getMetaData() {
}
/**
- * Set the recovery path for shuffle service recovery when NM is restarted. The method will be
- * overrode and called when Hadoop version is 2.5+ and NM recovery is enabled, otherwise we
- * have to manually call this to set our own recovery path.
+ * Set the recovery path for shuffle service recovery when the NM is restarted. This will be
+ * called by the NM if NM recovery is enabled.
*/
+ @Override
public void setRecoveryPath(Path recoveryPath) {
_recoveryPath = recoveryPath;
}
@@ -355,53 +360,44 @@ protected Path getRecoveryPath(String fileName) {
/**
* Figure out the recovery path and handle moving the DB if YARN NM recovery gets enabled
- * when it previously was not. If YARN NM recovery is enabled it uses that path, otherwise
- * it will uses a YARN local dir.
+ * and a DB created by an older version of the shuffle service still exists in the NM local dirs.
*/
protected File initRecoveryDb(String dbName) {
- if (_recoveryPath != null) {
- File recoveryFile = new File(_recoveryPath.toUri().getPath(), dbName);
- if (recoveryFile.exists()) {
- return recoveryFile;
- }
+ Preconditions.checkNotNull(_recoveryPath,
+ "recovery path should not be null if NM recovery is enabled");
+
+ File recoveryFile = new File(_recoveryPath.toUri().getPath(), dbName);
+ if (recoveryFile.exists()) {
+ return recoveryFile;
}
+
// db doesn't exist in recovery path go check local dirs for it
String[] localDirs = _conf.getTrimmedStrings("yarn.nodemanager.local-dirs");
for (String dir : localDirs) {
File f = new File(new Path(dir).toUri().getPath(), dbName);
if (f.exists()) {
- if (_recoveryPath == null) {
- // If NM recovery is not enabled, we should specify the recovery path using NM local
- // dirs, which is compatible with the old code.
- _recoveryPath = new Path(dir);
- return f;
- } else {
- // If the recovery path is set then either NM recovery is enabled or another recovery
- // DB has been initialized. If NM recovery is enabled and had set the recovery path
- // make sure to move all DBs to the recovery path from the old NM local dirs.
- // If another DB was initialized first just make sure all the DBs are in the same
- // location.
- Path newLoc = new Path(_recoveryPath, dbName);
- Path copyFrom = new Path(f.toURI());
- if (!newLoc.equals(copyFrom)) {
- logger.info("Moving " + copyFrom + " to: " + newLoc);
- try {
- // The move here needs to handle moving non-empty directories across NFS mounts
- FileSystem fs = FileSystem.getLocal(_conf);
- fs.rename(copyFrom, newLoc);
- } catch (Exception e) {
- // Fail to move recovery file to new path, just continue on with new DB location
- logger.error("Failed to move recovery file {} to the path {}",
- dbName, _recoveryPath.toString(), e);
- }
+ // The recovery path is always set at this point: either NM recovery is enabled or another
+ // recovery DB has already been initialized. Move any DB found in the old NM local dirs to
+ // the recovery path so that all DBs end up in the same location.
+ Path newLoc = new Path(_recoveryPath, dbName);
+ Path copyFrom = new Path(f.toURI());
+ if (!newLoc.equals(copyFrom)) {
+ logger.info("Moving " + copyFrom + " to: " + newLoc);
+ try {
+ // The move here needs to handle moving non-empty directories across NFS mounts
+ FileSystem fs = FileSystem.getLocal(_conf);
+ fs.rename(copyFrom, newLoc);
+ } catch (Exception e) {
+ // Failed to move the recovery file to the new path; just continue with the new DB location
+ logger.error("Failed to move recovery file {} to the path {}",
+ dbName, _recoveryPath.toString(), e);
}
- return new File(newLoc.toUri().getPath());
}
+ return new File(newLoc.toUri().getPath());
}
}
- if (_recoveryPath == null) {
- _recoveryPath = new Path(localDirs[0]);
- }
return new File(_recoveryPath.toUri().getPath(), dbName);
}
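For orientation, a minimal illustrative sketch of the call order the NodeManager follows when work-preserving recovery is enabled, which the Preconditions check above now relies on; the configuration flag and recovery path below are assumptions made up for the example:

    import org.apache.hadoop.fs.Path
    import org.apache.hadoop.yarn.conf.YarnConfiguration
    import org.apache.spark.network.yarn.YarnShuffleService

    val conf = new YarnConfiguration()
    conf.setBoolean("yarn.nodemanager.recovery.enabled", true)

    val service = new YarnShuffleService()
    // The NM injects the recovery directory before initializing the auxiliary service...
    service.setRecoveryPath(new Path("/var/nm-recovery/spark_shuffle"))
    // ...so serviceInit() can assume _recoveryPath is non-null and load/migrate the LevelDBs.
    service.init(conf)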
diff --git a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
index 045fec33a282a..fd1906d2e5ae9 100644
--- a/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
+++ b/common/sketch/src/main/java/org/apache/spark/util/sketch/CountMinSketchImpl.java
@@ -365,7 +365,7 @@ private void writeObject(ObjectOutputStream out) throws IOException {
this.writeTo(out);
}
- private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
+ private void readObject(ObjectInputStream in) throws IOException {
this.readFrom0(in);
}
}
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
index a77ba826fce29..4ae49d82efa29 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
@@ -73,5 +73,6 @@ public void memoryDebugFillEnabledInTest() {
Assert.assertEquals(
Platform.getByte(offheap.getBaseObject(), offheap.getBaseOffset()),
MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE);
+ MemoryAllocator.UNSAFE.free(offheap);
}
}
diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template
index aeb76c9b2f6ea..4c008a13607c2 100644
--- a/conf/metrics.properties.template
+++ b/conf/metrics.properties.template
@@ -118,6 +118,14 @@
# prefix EMPTY STRING Prefix to prepend to every metric's name
# protocol tcp Protocol ("tcp" or "udp") to use
+# org.apache.spark.metrics.sink.StatsdSink
+# Name: Default: Description:
+# host 127.0.0.1 Hostname or IP of StatsD server
+# port 8125 Port of StatsD server
+# period 10 Poll period
+# unit seconds Units of poll period
+# prefix EMPTY STRING Prefix to prepend to every metric's name
+
## Examples
# Enable JmxSink for all instances by class name
#*.sink.jmx.class=org.apache.spark.metrics.sink.JmxSink
@@ -125,6 +133,10 @@
# Enable ConsoleSink for all instances by class name
#*.sink.console.class=org.apache.spark.metrics.sink.ConsoleSink
+# Enable StatsdSink for all instances by class name
+#*.sink.statsd.class=org.apache.spark.metrics.sink.StatsdSink
+#*.sink.statsd.prefix=spark
+
# Polling period for the ConsoleSink
#*.sink.console.period=10
# Unit of the polling period for the ConsoleSink
diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template
index 1663019ee5758..f8c895f5303b9 100755
--- a/conf/spark-env.sh.template
+++ b/conf/spark-env.sh.template
@@ -52,6 +52,7 @@
# - SPARK_HISTORY_OPTS, to set config properties only for the history server (e.g. "-Dx=y")
# - SPARK_SHUFFLE_OPTS, to set config properties only for the external shuffle service (e.g. "-Dx=y")
# - SPARK_DAEMON_JAVA_OPTS, to set config properties for all daemons (e.g. "-Dx=y")
+# - SPARK_DAEMON_CLASSPATH, to set the classpath for all daemons
# - SPARK_PUBLIC_DNS, to set the public dns name of the master or workers
# Generic options for the daemons used in the standalone deploy mode
diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java
index 140c52fd12f94..3583856d88998 100644
--- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java
+++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java
@@ -139,6 +139,11 @@ public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) {
onEvent(blockUpdated);
}
+ @Override
+ public void onSpeculativeTaskSubmitted(SparkListenerSpeculativeTaskSubmitted speculativeTask) {
+ onEvent(speculativeTask);
+ }
+
@Override
public void onOtherEvent(SparkListenerEvent event) {
onEvent(event);
diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
index 48cf4b9455e4d..4099fb01f2f95 100644
--- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
+++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
@@ -111,8 +111,6 @@ public void freeArray(LongArray array) {
/**
* Allocate a memory block with at least `required` bytes.
*
- * Throws IOException if there is not enough memory.
- *
* @throws OutOfMemoryError
*/
protected MemoryBlock allocatePage(long required) {
diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
index 761ba9de659d5..44b60c1e4e8c8 100644
--- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
+++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -53,8 +53,8 @@
* retrieve the base object.
*
* This allows us to address 8192 pages. In on-heap mode, the maximum page size is limited by the
- * maximum size of a long[] array, allowing us to address 8192 * 2^32 * 8 bytes, which is
- * approximately 35 terabytes of memory.
+ * maximum size of a long[] array, allowing us to address 8192 * (2^31 - 1) * 8 bytes, which is
+ * approximately 140 terabytes of memory.
*/
public class TaskMemoryManager {
@@ -74,7 +74,8 @@ public class TaskMemoryManager {
* Maximum supported data page size (in bytes). In principle, the maximum addressable page size is
* (1L << OFFSET_BITS) bytes, which is 2+ petabytes. However, the on-heap allocator's
* maximum page size is limited by the maximum amount of data that can be stored in a long[]
- * array, which is (2^32 - 1) * 8 bytes (or 16 gigabytes). Therefore, we cap this at 16 gigabytes.
+ * array, which is (2^31 - 1) * 8 bytes (or about 17 gigabytes). Therefore, we cap this at 17
+ * gigabytes.
*/
public static final long MAXIMUM_PAGE_SIZE_BYTES = ((1L << 31) - 1) * 8L;
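As a quick sanity check of the updated figures, a standalone arithmetic sketch (illustrative Scala, not Spark code):

    // Max on-heap page: a long[] holds at most (2^31 - 1) elements of 8 bytes each.
    val maxPageBytes = ((1L << 31) - 1) * 8L        // 17,179,869,176 bytes, roughly 17 GB
    // 13 page-number bits give 8192 addressable pages per task.
    val maxAddressable = 8192L * maxPageBytes       // about 1.4e14 bytes, roughly 140 TB
    println(s"page cap: $maxPageBytes bytes, addressable: $maxAddressable bytes")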
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index da6c55d9b8ac3..b4f46306f2827 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -140,7 +140,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
* bytes written should be counted towards shuffle spill metrics rather than
* shuffle write metrics.
*/
- private void writeSortedFile(boolean isLastFile) throws IOException {
+ private void writeSortedFile(boolean isLastFile) {
final ShuffleWriteMetrics writeMetricsToUse;
@@ -325,7 +325,7 @@ public void cleanupResources() {
* array and grows the array if additional space is required. If the required space cannot be
* obtained, then the in-memory data will be spilled to disk.
*/
- private void growPointerArrayIfNecessary() throws IOException {
+ private void growPointerArrayIfNecessary() {
assert(inMemSorter != null);
if (!inMemSorter.hasSpaceForAnotherRecord()) {
long used = inMemSorter.getMemoryUsage();
@@ -406,19 +406,14 @@ public void insertRecord(Object recordBase, long recordOffset, int length, int p
* @throws IOException
*/
public SpillInfo[] closeAndGetSpills() throws IOException {
- try {
- if (inMemSorter != null) {
- // Do not count the final file towards the spill count.
- writeSortedFile(true);
- freeMemory();
- inMemSorter.free();
- inMemSorter = null;
- }
- return spills.toArray(new SpillInfo[spills.size()]);
- } catch (IOException e) {
- cleanupResources();
- throw e;
+ if (inMemSorter != null) {
+ // Do not count the final file towards the spill count.
+ writeSortedFile(true);
+ freeMemory();
+ inMemSorter.free();
+ inMemSorter = null;
}
+ return spills.toArray(new SpillInfo[spills.size()]);
}
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index c0ebe3cc9b792..e9c2a69c47cba 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -208,7 +208,7 @@ public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOEx
}
}
- private void open() throws IOException {
+ private void open() {
assert (sorter == null);
sorter = new ShuffleExternalSorter(
memoryManager,
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index e2059cec132d2..de4464080ef55 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -346,7 +346,7 @@ public void cleanupResources() {
* array and grows the array if additional space is required. If the required space cannot be
* obtained, then the in-memory data will be spilled to disk.
*/
- private void growPointerArrayIfNecessary() throws IOException {
+ private void growPointerArrayIfNecessary() {
assert(inMemSorter != null);
if (!inMemSorter.hasSpaceForAnotherRecord()) {
long used = inMemSorter.getMemoryUsage();
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
index 850f247b045cf..9399024f01783 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
@@ -79,7 +79,7 @@ public UnsafeSorterSpillWriter(
}
// Based on DataOutputStream.writeLong.
- private void writeLongToBuffer(long v, int offset) throws IOException {
+ private void writeLongToBuffer(long v, int offset) {
writeBuffer[offset + 0] = (byte)(v >>> 56);
writeBuffer[offset + 1] = (byte)(v >>> 48);
writeBuffer[offset + 2] = (byte)(v >>> 40);
@@ -91,7 +91,7 @@ private void writeLongToBuffer(long v, int offset) throws IOException {
}
// Based on DataOutputStream.writeInt.
- private void writeIntToBuffer(int v, int offset) throws IOException {
+ private void writeIntToBuffer(int v, int offset) {
writeBuffer[offset + 0] = (byte)(v >>> 24);
writeBuffer[offset + 1] = (byte)(v >>> 16);
writeBuffer[offset + 2] = (byte)(v >>> 8);
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
index 337631a6f9a34..7a5fb9a802354 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
@@ -373,8 +373,14 @@ private[spark] class ExecutorAllocationManager(
// If our target has not changed, do not send a message
// to the cluster manager and reset our exponential growth
if (delta == 0) {
- numExecutorsToAdd = 1
- return 0
+ // Check if there are any speculative tasks pending
+ if (listener.pendingTasks == 0 && listener.pendingSpeculativeTasks > 0) {
+ numExecutorsTarget =
+ math.max(math.min(maxNumExecutorsNeeded + 1, maxNumExecutors), minNumExecutors)
+ } else {
+ numExecutorsToAdd = 1
+ return 0
+ }
}
val addRequestAcknowledged = try {
@@ -440,6 +446,9 @@ private[spark] class ExecutorAllocationManager(
} else {
client.killExecutors(executorIdsToBeRemoved)
}
+ // [SPARK-21834] The killExecutors API reduces the target number of executors,
+ // so we need to restore the target to the desired value.
+ client.requestTotalExecutors(numExecutorsTarget, localityAwareTasks, hostToLocalTaskCount)
// reset the newExecutorTotal to the existing number of executors
newExecutorTotal = numExistingExecutors
if (testing || executorsRemoved.nonEmpty) {
@@ -588,17 +597,22 @@ private[spark] class ExecutorAllocationManager(
* A listener that notifies the given allocation manager of when to add and remove executors.
*
* This class is intentionally conservative in its assumptions about the relative ordering
- * and consistency of events returned by the listener. For simplicity, it does not account
- * for speculated tasks.
+ * and consistency of events returned by the listener.
*/
private class ExecutorAllocationListener extends SparkListener {
private val stageIdToNumTasks = new mutable.HashMap[Int, Int]
private val stageIdToTaskIndices = new mutable.HashMap[Int, mutable.HashSet[Int]]
private val executorIdToTaskIds = new mutable.HashMap[String, mutable.HashSet[Long]]
- // Number of tasks currently running on the cluster. Should be 0 when no stages are active.
+ // Number of tasks currently running on the cluster, including speculative tasks.
+ // Should be 0 when no stages are active.
private var numRunningTasks: Int = _
+ // Number of speculative tasks to be scheduled in each stage
+ private val stageIdToNumSpeculativeTasks = new mutable.HashMap[Int, Int]
+ // The speculative tasks started in each stage
+ private val stageIdToSpeculativeTaskIndices = new mutable.HashMap[Int, mutable.HashSet[Int]]
+
// stageId to tuple (the number of task with locality preferences, a map where each pair is a
// node and the number of tasks that would like to be scheduled on that node) map,
// maintain the executor placement hints for each stage Id used by resource framework to better
@@ -637,7 +651,9 @@ private[spark] class ExecutorAllocationManager(
val stageId = stageCompleted.stageInfo.stageId
allocationManager.synchronized {
stageIdToNumTasks -= stageId
+ stageIdToNumSpeculativeTasks -= stageId
stageIdToTaskIndices -= stageId
+ stageIdToSpeculativeTaskIndices -= stageId
stageIdToExecutorPlacementHints -= stageId
// Update the executor placement hints
@@ -645,7 +661,7 @@ private[spark] class ExecutorAllocationManager(
// If this is the last stage with pending tasks, mark the scheduler queue as empty
// This is needed in case the stage is aborted for any reason
- if (stageIdToNumTasks.isEmpty) {
+ if (stageIdToNumTasks.isEmpty && stageIdToNumSpeculativeTasks.isEmpty) {
allocationManager.onSchedulerQueueEmpty()
if (numRunningTasks != 0) {
logWarning("No stages are running, but numRunningTasks != 0")
@@ -671,7 +687,12 @@ private[spark] class ExecutorAllocationManager(
}
// If this is the last pending task, mark the scheduler queue as empty
- stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex
+ if (taskStart.taskInfo.speculative) {
+ stageIdToSpeculativeTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) +=
+ taskIndex
+ } else {
+ stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex
+ }
if (totalPendingTasks() == 0) {
allocationManager.onSchedulerQueueEmpty()
}
@@ -705,7 +726,11 @@ private[spark] class ExecutorAllocationManager(
if (totalPendingTasks() == 0) {
allocationManager.onSchedulerBacklogged()
}
- stageIdToTaskIndices.get(stageId).foreach { _.remove(taskIndex) }
+ if (taskEnd.taskInfo.speculative) {
+ stageIdToSpeculativeTaskIndices.get(stageId).foreach { _.remove(taskIndex) }
+ } else {
+ stageIdToTaskIndices.get(stageId).foreach { _.remove(taskIndex) }
+ }
}
}
}
@@ -726,18 +751,39 @@ private[spark] class ExecutorAllocationManager(
allocationManager.onExecutorRemoved(executorRemoved.executorId)
}
+ override def onSpeculativeTaskSubmitted(speculativeTask: SparkListenerSpeculativeTaskSubmitted)
+ : Unit = {
+ val stageId = speculativeTask.stageId
+
+ allocationManager.synchronized {
+ stageIdToNumSpeculativeTasks(stageId) =
+ stageIdToNumSpeculativeTasks.getOrElse(stageId, 0) + 1
+ allocationManager.onSchedulerBacklogged()
+ }
+ }
+
/**
* An estimate of the total number of pending tasks remaining for currently running stages. Does
* not account for tasks which may have failed and been resubmitted.
*
* Note: This is not thread-safe without the caller owning the `allocationManager` lock.
*/
- def totalPendingTasks(): Int = {
+ def pendingTasks(): Int = {
stageIdToNumTasks.map { case (stageId, numTasks) =>
numTasks - stageIdToTaskIndices.get(stageId).map(_.size).getOrElse(0)
}.sum
}
+ def pendingSpeculativeTasks(): Int = {
+ stageIdToNumSpeculativeTasks.map { case (stageId, numTasks) =>
+ numTasks - stageIdToSpeculativeTaskIndices.get(stageId).map(_.size).getOrElse(0)
+ }.sum
+ }
+
+ def totalPendingTasks(): Int = {
+ pendingTasks + pendingSpeculativeTasks
+ }
+
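A self-contained sketch (illustrative names and numbers, not the listener's actual fields) of how the two counters above combine into the total that the allocation manager sizes against:

    import scala.collection.mutable

    val numTasks     = mutable.HashMap(1 -> 10)                        // stage 1: 10 regular tasks submitted
    val startedTasks = mutable.HashMap(1 -> mutable.HashSet(0, 1, 2))  // 3 of them have started
    val numSpecTasks = mutable.HashMap(1 -> 2)                         // 2 speculative tasks submitted
    val startedSpec  = mutable.HashMap[Int, mutable.HashSet[Int]]()    // none started yet

    def pending: Int = numTasks.map { case (id, n) =>
      n - startedTasks.get(id).map(_.size).getOrElse(0) }.sum          // 7
    def pendingSpec: Int = numSpecTasks.map { case (id, n) =>
      n - startedSpec.get(id).map(_.size).getOrElse(0) }.sum           // 2

    assert(pending + pendingSpec == 9)  // totalPendingTasks() now includes speculative work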
/**
* The number of tasks currently running across all stages.
*/
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index 0899693988016..1034fdcae8e8c 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -89,6 +89,14 @@ trait FutureAction[T] extends Future[T] {
*/
override def value: Option[Try[T]]
+ // These two methods must be implemented in Scala 2.12, but won't be used by Spark
+
+ def transform[S](f: (Try[T]) => Try[S])(implicit executor: ExecutionContext): Future[S] =
+ throw new UnsupportedOperationException()
+
+ def transformWith[S](f: (Try[T]) => Future[S])(implicit executor: ExecutionContext): Future[S] =
+ throw new UnsupportedOperationException()
+
/**
* Blocks and returns the result of this job.
*/
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index 715cfdcc8f4ef..e61f943af49f2 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -597,7 +597,9 @@ private[spark] object SparkConf extends Logging {
DeprecatedConfig("spark.scheduler.executorTaskBlacklistTime", "2.1.0",
"Please use the new blacklisting options, spark.blacklist.*"),
DeprecatedConfig("spark.yarn.am.port", "2.0.0", "Not used any more"),
- DeprecatedConfig("spark.executor.port", "2.0.0", "Not used any more")
+ DeprecatedConfig("spark.executor.port", "2.0.0", "Not used any more"),
+ DeprecatedConfig("spark.shuffle.service.index.cache.entries", "2.3.0",
+ "Not used any more. Please use spark.shuffle.service.index.cache.size")
)
Map(configs.map { cfg => (cfg.key -> cfg) } : _*)
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
index f820401da2fc3..d6506231b8d74 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
@@ -56,9 +56,9 @@ private[spark] object JavaUtils {
val ui = underlying.iterator
var prev : Option[A] = None
- def hasNext: Boolean = ui.hasNext
+ override def hasNext: Boolean = ui.hasNext
- def next(): Entry[A, B] = {
+ override def next(): Entry[A, B] = {
val (k, v) = ui.next()
prev = Some(k)
new ju.Map.Entry[A, B] {
@@ -74,7 +74,7 @@ private[spark] object JavaUtils {
}
}
- def remove() {
+ override def remove() {
prev match {
case Some(k) =>
underlying match {
diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
index aaf8e7a1d7461..01e64b6972ae2 100644
--- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
@@ -35,6 +35,16 @@ import org.apache.spark.rdd.RDD
/** Utilities for serialization / deserialization between Python and Java, using Pickle. */
private[spark] object SerDeUtil extends Logging {
+ class ByteArrayConstructor extends net.razorvine.pickle.objects.ByteArrayConstructor {
+ override def construct(args: Array[Object]): Object = {
+ // Deal with an empty byte array pickled by Python 3.
+ if (args.length == 0) {
+ Array.emptyByteArray
+ } else {
+ super.construct(args)
+ }
+ }
+ }
// Unpickle array.array generated by Python 2.6
class ArrayConstructor extends net.razorvine.pickle.objects.ArrayConstructor {
// /* Description of types */
@@ -108,6 +118,10 @@ private[spark] object SerDeUtil extends Logging {
synchronized{
if (!initialized) {
Unpickler.registerConstructor("array", "array", new ArrayConstructor())
+ Unpickler.registerConstructor("__builtin__", "bytearray", new ByteArrayConstructor())
+ Unpickler.registerConstructor("builtins", "bytearray", new ByteArrayConstructor())
+ Unpickler.registerConstructor("__builtin__", "bytes", new ByteArrayConstructor())
+ Unpickler.registerConstructor("_codecs", "encode", new ByteArrayConstructor())
initialized = true
}
}
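A small standalone illustration of what the override handles (it has to live under the org.apache.spark namespace because SerDeUtil is private[spark]): a zero-argument call, which is how an empty Python 3 bytes/bytearray arrives, now yields an empty byte array instead of being passed to the parent constructor:

    import org.apache.spark.api.python.SerDeUtil

    val ctor = new SerDeUtil.ByteArrayConstructor()
    val empty = ctor.construct(Array.empty[Object]).asInstanceOf[Array[Byte]]
    assert(empty.isEmpty)  // empty Python 3 byte array round-trips to Array.emptyByteArray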
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 039df75ce74fd..67e993c7f02e2 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -30,7 +30,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.io.CompressionCodec
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage._
-import org.apache.spark.util.{ByteBufferInputStream, Utils}
+import org.apache.spark.util.Utils
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
/**
diff --git a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala
index 97f3803aafce4..51c3d9b158cbe 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala
@@ -18,15 +18,13 @@
package org.apache.spark.deploy
import java.io.File
-import java.nio.file.Files
-import scala.collection.mutable.HashMap
-
-import org.apache.commons.io.FileUtils
import org.apache.commons.lang3.StringUtils
import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.spark.util.MutableURLClassLoader
+import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.util.{MutableURLClassLoader, Utils}
private[deploy] object DependencyUtils {
@@ -51,41 +49,22 @@ private[deploy] object DependencyUtils {
SparkSubmitUtils.resolveMavenCoordinates(packages, ivySettings, exclusions = exclusions)
}
- def createTempDir(): File = {
- val targetDir = Files.createTempDirectory("tmp").toFile
- // scalastyle:off runtimeaddshutdownhook
- Runtime.getRuntime.addShutdownHook(new Thread() {
- override def run(): Unit = {
- FileUtils.deleteQuietly(targetDir)
- }
- })
- // scalastyle:on runtimeaddshutdownhook
- targetDir
- }
-
- def resolveAndDownloadJars(jars: String, userJar: String): String = {
- val targetDir = DependencyUtils.createTempDir()
- val hadoopConf = new Configuration()
- val sparkProperties = new HashMap[String, String]()
- val securityProperties = List("spark.ssl.fs.trustStore", "spark.ssl.trustStore",
- "spark.ssl.fs.trustStorePassword", "spark.ssl.trustStorePassword",
- "spark.ssl.fs.protocol", "spark.ssl.protocol")
-
- securityProperties.foreach { pName =>
- sys.props.get(pName).foreach { pValue =>
- sparkProperties.put(pName, pValue)
- }
- }
-
+ def resolveAndDownloadJars(
+ jars: String,
+ userJar: String,
+ sparkConf: SparkConf,
+ hadoopConf: Configuration,
+ secMgr: SecurityManager): String = {
+ val targetDir = Utils.createTempDir()
Option(jars)
.map {
- SparkSubmit.resolveGlobPaths(_, hadoopConf)
+ resolveGlobPaths(_, hadoopConf)
.split(",")
.filterNot(_.contains(userJar.split("/").last))
.mkString(",")
}
.filterNot(_ == "")
- .map(SparkSubmit.downloadFileList(_, targetDir, sparkProperties, hadoopConf))
+ .map(downloadFileList(_, targetDir, sparkConf, hadoopConf, secMgr))
.orNull
}
@@ -96,4 +75,73 @@ private[deploy] object DependencyUtils {
}
}
}
+
+ /**
+ * Download a list of remote files to temporary local files. If a file is local, its original
+ * path is returned.
+ *
+ * @param fileList A comma-separated list of file paths.
+ * @param targetDir A temporary directory into which the files will be downloaded.
+ * @param sparkConf Spark configuration.
+ * @param hadoopConf Hadoop configuration.
+ * @param secMgr Spark security manager.
+ * @return A comma-separated list of local file paths.
+ */
+ def downloadFileList(
+ fileList: String,
+ targetDir: File,
+ sparkConf: SparkConf,
+ hadoopConf: Configuration,
+ secMgr: SecurityManager): String = {
+ require(fileList != null, "fileList cannot be null.")
+ fileList.split(",")
+ .map(downloadFile(_, targetDir, sparkConf, hadoopConf, secMgr))
+ .mkString(",")
+ }
+
+ /**
+ * Download a file from a remote location to a local temporary directory. If the input path
+ * points to a local path, it is returned with no operation.
+ *
+ * @param path The path of the file to download.
+ * @param targetDir A temporary directory into which the file will be downloaded.
+ * @param sparkConf Spark configuration.
+ * @param hadoopConf Hadoop configuration.
+ * @param secMgr Spark security manager.
+ * @return Path to the local file.
+ */
+ def downloadFile(
+ path: String,
+ targetDir: File,
+ sparkConf: SparkConf,
+ hadoopConf: Configuration,
+ secMgr: SecurityManager): String = {
+ require(path != null, "path cannot be null.")
+ val uri = Utils.resolveURI(path)
+
+ uri.getScheme match {
+ case "file" | "local" => path
+ case _ =>
+ val fname = new Path(uri).getName()
+ val localFile = Utils.doFetchFile(uri.toString(), targetDir, fname, sparkConf, secMgr,
+ hadoopConf)
+ localFile.toURI().toString()
+ }
+ }
+
+ def resolveGlobPaths(paths: String, hadoopConf: Configuration): String = {
+ require(paths != null, "paths cannot be null.")
+ paths.split(",").map(_.trim).filter(_.nonEmpty).flatMap { path =>
+ val uri = Utils.resolveURI(path)
+ uri.getScheme match {
+ case "local" | "http" | "https" | "ftp" => Array(path)
+ case _ =>
+ val fs = FileSystem.get(uri, hadoopConf)
+ Option(fs.globStatus(new Path(uri))).map { status =>
+ status.filter(_.isFile).map(_.getPath.toUri.toString)
+ }.getOrElse(Array(path))
+ }
+ }.mkString(",")
+ }
+
}
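A hedged usage sketch of the relocated helpers (the paths are illustrative, and the caller must sit in org.apache.spark.deploy since the object is private[deploy]):

    import java.io.File
    import org.apache.hadoop.conf.Configuration
    import org.apache.spark.{SecurityManager, SparkConf}

    val sparkConf  = new SparkConf(false)
    val hadoopConf = new Configuration()
    val secMgr     = new SecurityManager(sparkConf)
    val targetDir  = new File("/tmp/spark-submit-staging")  // hypothetical staging dir

    // Remote URIs are fetched into targetDir; "file:" and "local:" URIs are returned as-is.
    val localJars = DependencyUtils.downloadFileList(
      "hdfs:///libs/dep.jar,file:/opt/app/local.jar", targetDir, sparkConf, hadoopConf, secMgr)

    // Globs are expanded through the Hadoop FileSystem; local/http/https/ftp paths pass through.
    val expanded = DependencyUtils.resolveGlobPaths("hdfs:///libs/*.jar", hadoopConf)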
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index 2a92ef99b9f37..53775db251bc6 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -17,7 +17,7 @@
package org.apache.spark.deploy
-import java.io.{File, IOException}
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, File, IOException}
import java.security.PrivilegedExceptionAction
import java.text.DateFormat
import java.util.{Arrays, Comparator, Date, Locale}
@@ -81,29 +81,7 @@ class SparkHadoopUtil extends Logging {
* configuration.
*/
def appendS3AndSparkHadoopConfigurations(conf: SparkConf, hadoopConf: Configuration): Unit = {
- // Note: this null check is around more than just access to the "conf" object to maintain
- // the behavior of the old implementation of this code, for backwards compatibility.
- if (conf != null) {
- // Explicitly check for S3 environment variables
- val keyId = System.getenv("AWS_ACCESS_KEY_ID")
- val accessKey = System.getenv("AWS_SECRET_ACCESS_KEY")
- if (keyId != null && accessKey != null) {
- hadoopConf.set("fs.s3.awsAccessKeyId", keyId)
- hadoopConf.set("fs.s3n.awsAccessKeyId", keyId)
- hadoopConf.set("fs.s3a.access.key", keyId)
- hadoopConf.set("fs.s3.awsSecretAccessKey", accessKey)
- hadoopConf.set("fs.s3n.awsSecretAccessKey", accessKey)
- hadoopConf.set("fs.s3a.secret.key", accessKey)
-
- val sessionToken = System.getenv("AWS_SESSION_TOKEN")
- if (sessionToken != null) {
- hadoopConf.set("fs.s3a.session.token", sessionToken)
- }
- }
- appendSparkHadoopConfigs(conf, hadoopConf)
- val bufferSize = conf.get("spark.buffer.size", "65536")
- hadoopConf.set("io.file.buffer.size", bufferSize)
- }
+ SparkHadoopUtil.appendS3AndSparkHadoopConfigurations(conf, hadoopConf)
}
/**
@@ -111,10 +89,7 @@ class SparkHadoopUtil extends Logging {
* configuration without the spark.hadoop. prefix.
*/
def appendSparkHadoopConfigs(conf: SparkConf, hadoopConf: Configuration): Unit = {
- // Copy any "spark.hadoop.foo=bar" spark properties into conf as "foo=bar"
- for ((key, value) <- conf.getAll if key.startsWith("spark.hadoop.")) {
- hadoopConf.set(key.substring("spark.hadoop.".length), value)
- }
+ SparkHadoopUtil.appendSparkHadoopConfigs(conf, hadoopConf)
}
/**
@@ -134,9 +109,7 @@ class SparkHadoopUtil extends Logging {
* subsystems.
*/
def newConfiguration(conf: SparkConf): Configuration = {
- val hadoopConf = new Configuration()
- appendS3AndSparkHadoopConfigurations(conf, hadoopConf)
- hadoopConf
+ SparkHadoopUtil.newConfiguration(conf)
}
/**
@@ -147,14 +120,18 @@ class SparkHadoopUtil extends Logging {
def isYarnMode(): Boolean = { false }
- def getCurrentUserCredentials(): Credentials = { null }
-
- def addCurrentUserCredentials(creds: Credentials) {}
-
def addSecretKeyToUserCredentials(key: String, secret: String) {}
def getSecretKeyFromUserCredentials(key: String): Array[Byte] = { null }
+ def getCurrentUserCredentials(): Credentials = {
+ UserGroupInformation.getCurrentUser().getCredentials()
+ }
+
+ def addCurrentUserCredentials(creds: Credentials): Unit = {
+ UserGroupInformation.getCurrentUser.addCredentials(creds)
+ }
+
def loginUserFromKeytab(principalName: String, keytabFilename: String): Unit = {
if (!new File(keytabFilename).exists()) {
throw new SparkException(s"Keytab file: ${keytabFilename} does not exist")
@@ -425,6 +402,21 @@ class SparkHadoopUtil extends Logging {
s"${if (status.isDirectory) "d" else "-"}$perm")
false
}
+
+ def serialize(creds: Credentials): Array[Byte] = {
+ val byteStream = new ByteArrayOutputStream
+ val dataStream = new DataOutputStream(byteStream)
+ creds.writeTokenStorageToStream(dataStream)
+ byteStream.toByteArray
+ }
+
+ def deserialize(tokenBytes: Array[Byte]): Credentials = {
+ val tokensBuf = new ByteArrayInputStream(tokenBytes)
+
+ val creds = new Credentials()
+ creds.readTokenStorageStream(new DataInputStream(tokensBuf))
+ creds
+ }
}
object SparkHadoopUtil {
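A minimal round-trip sketch for the new credential helpers added above (the key name and secret value are made up for illustration):

    import org.apache.hadoop.io.Text
    import org.apache.hadoop.security.Credentials
    import org.apache.spark.deploy.SparkHadoopUtil

    val creds = new Credentials()
    creds.addSecretKey(new Text("my.secret.key"), "s3cr3t".getBytes("UTF-8"))

    val bytes    = SparkHadoopUtil.get.serialize(creds)     // Credentials -> Array[Byte]
    val restored = SparkHadoopUtil.get.deserialize(bytes)   // Array[Byte] -> Credentials
    assert(new String(restored.getSecretKey(new Text("my.secret.key")), "UTF-8") == "s3cr3t")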
@@ -460,4 +452,50 @@ object SparkHadoopUtil {
hadoop
}
}
+
+ /**
+ * Returns a Configuration object with Spark configuration applied on top. Unlike
+ * the instance method, this will always return a Configuration instance, and not a
+ * cluster manager-specific type.
+ */
+ private[spark] def newConfiguration(conf: SparkConf): Configuration = {
+ val hadoopConf = new Configuration()
+ appendS3AndSparkHadoopConfigurations(conf, hadoopConf)
+ hadoopConf
+ }
+
+ private def appendS3AndSparkHadoopConfigurations(
+ conf: SparkConf,
+ hadoopConf: Configuration): Unit = {
+ // Note: this null check is around more than just access to the "conf" object to maintain
+ // the behavior of the old implementation of this code, for backwards compatibility.
+ if (conf != null) {
+ // Explicitly check for S3 environment variables
+ val keyId = System.getenv("AWS_ACCESS_KEY_ID")
+ val accessKey = System.getenv("AWS_SECRET_ACCESS_KEY")
+ if (keyId != null && accessKey != null) {
+ hadoopConf.set("fs.s3.awsAccessKeyId", keyId)
+ hadoopConf.set("fs.s3n.awsAccessKeyId", keyId)
+ hadoopConf.set("fs.s3a.access.key", keyId)
+ hadoopConf.set("fs.s3.awsSecretAccessKey", accessKey)
+ hadoopConf.set("fs.s3n.awsSecretAccessKey", accessKey)
+ hadoopConf.set("fs.s3a.secret.key", accessKey)
+
+ val sessionToken = System.getenv("AWS_SESSION_TOKEN")
+ if (sessionToken != null) {
+ hadoopConf.set("fs.s3a.session.token", sessionToken)
+ }
+ }
+ appendSparkHadoopConfigs(conf, hadoopConf)
+ val bufferSize = conf.get("spark.buffer.size", "65536")
+ hadoopConf.set("io.file.buffer.size", bufferSize)
+ }
+ }
+
+ private def appendSparkHadoopConfigs(conf: SparkConf, hadoopConf: Configuration): Unit = {
+ // Copy any "spark.hadoop.foo=bar" spark properties into conf as "foo=bar"
+ for ((key, value) <- conf.getAll if key.startsWith("spark.hadoop.")) {
+ hadoopConf.set(key.substring("spark.hadoop.".length), value)
+ }
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 1f916ebde0c32..a909cc80008a1 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -20,20 +20,18 @@ package org.apache.spark.deploy
import java.io._
import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException}
import java.net.URL
-import java.security.{KeyStore, PrivilegedExceptionAction}
-import java.security.cert.X509Certificate
+import java.security.PrivilegedExceptionAction
import java.text.ParseException
-import javax.net.ssl._
import scala.annotation.tailrec
import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
import scala.util.Properties
-import com.google.common.io.ByteStreams
import org.apache.commons.lang3.StringUtils
import org.apache.hadoop.conf.{Configuration => HadoopConfiguration}
-import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.Path
import org.apache.hadoop.security.UserGroupInformation
+import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.ivy.Ivy
import org.apache.ivy.core.LogOptions
import org.apache.ivy.core.module.descriptor._
@@ -49,6 +47,7 @@ import org.apache.ivy.plugins.resolver.{ChainResolver, FileSystemResolver, IBibl
import org.apache.spark._
import org.apache.spark.api.r.RUtils
import org.apache.spark.deploy.rest._
+import org.apache.spark.internal.Logging
import org.apache.spark.launcher.SparkLauncher
import org.apache.spark.util._
@@ -67,7 +66,9 @@ private[deploy] object SparkSubmitAction extends Enumeration {
* This program handles setting up the classpath with relevant Spark dependencies and provides
* a layer over the different cluster managers and deploy modes that Spark supports.
*/
-object SparkSubmit extends CommandLineUtils {
+object SparkSubmit extends CommandLineUtils with Logging {
+
+ import DependencyUtils._
// Cluster managers
private val YARN = 1
@@ -112,6 +113,10 @@ object SparkSubmit extends CommandLineUtils {
// scalastyle:on println
override def main(args: Array[String]): Unit = {
+ // Initialize logging if it hasn't been done yet. Keep track of whether logging needs to
+ // be reset before the application starts.
+ val uninitLog = initializeLogIfNecessary(true, silent = true)
+
val appArgs = new SparkSubmitArguments(args)
if (appArgs.verbose) {
// scalastyle:off println
@@ -119,7 +124,7 @@ object SparkSubmit extends CommandLineUtils {
// scalastyle:on println
}
appArgs.action match {
- case SparkSubmitAction.SUBMIT => submit(appArgs)
+ case SparkSubmitAction.SUBMIT => submit(appArgs, uninitLog)
case SparkSubmitAction.KILL => kill(appArgs)
case SparkSubmitAction.REQUEST_STATUS => requestStatus(appArgs)
}
@@ -152,7 +157,7 @@ object SparkSubmit extends CommandLineUtils {
* main class.
*/
@tailrec
- private def submit(args: SparkSubmitArguments): Unit = {
+ private def submit(args: SparkSubmitArguments, uninitLog: Boolean): Unit = {
val (childArgs, childClasspath, sysProps, childMainClass) = prepareSubmitEnvironment(args)
def doRunMain(): Unit = {
@@ -184,11 +189,16 @@ object SparkSubmit extends CommandLineUtils {
}
}
- // In standalone cluster mode, there are two submission gateways:
- // (1) The traditional RPC gateway using o.a.s.deploy.Client as a wrapper
- // (2) The new REST-based gateway introduced in Spark 1.3
- // The latter is the default behavior as of Spark 1.3, but Spark submit will fail over
- // to use the legacy gateway if the master endpoint turns out to be not a REST server.
+ // Let the main class re-initialize the logging system once it starts.
+ if (uninitLog) {
+ Logging.uninitialize()
+ }
+
+ // In standalone cluster mode, there are two submission gateways:
+ // (1) The traditional RPC gateway using o.a.s.deploy.Client as a wrapper
+ // (2) The new REST-based gateway introduced in Spark 1.3
+ // The latter is the default behavior as of Spark 1.3, but Spark submit will fail over
+ // to use the legacy gateway if the master endpoint turns out to be not a REST server.
if (args.isStandaloneCluster && args.useRest) {
try {
// scalastyle:off println
@@ -201,7 +211,7 @@ object SparkSubmit extends CommandLineUtils {
printWarning(s"Master endpoint ${args.master} was not a REST server. " +
"Falling back to legacy submission gateway instead.")
args.useRest = false
- submit(args)
+ submit(args, false)
}
// In all other modes, just run the main class as prepared
} else {
@@ -211,14 +221,20 @@ object SparkSubmit extends CommandLineUtils {
/**
* Prepare the environment for submitting an application.
- * This returns a 4-tuple:
- * (1) the arguments for the child process,
- * (2) a list of classpath entries for the child,
- * (3) a map of system properties, and
- * (4) the main class for the child
+ *
+ * @param args the parsed SparkSubmitArguments used for environment preparation.
+ * @param conf the Hadoop Configuration; this argument is only set in unit tests.
+ * @return a 4-tuple:
+ * (1) the arguments for the child process,
+ * (2) a list of classpath entries for the child,
+ * (3) a map of system properties, and
+ * (4) the main class for the child
+ *
* Exposed for testing.
*/
- private[deploy] def prepareSubmitEnvironment(args: SparkSubmitArguments)
+ private[deploy] def prepareSubmitEnvironment(
+ args: SparkSubmitArguments,
+ conf: Option[HadoopConfiguration] = None)
: (Seq[String], Seq[String], Map[String, String], String) = {
// Return values
val childArgs = new ArrayBuffer[String]()
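For context, a sketch of how a test might exercise the reworked method now that a Hadoop configuration can be injected (the arguments are made up, and the call has to come from within org.apache.spark.deploy because the method is private[deploy]):

    import org.apache.hadoop.conf.{Configuration => HadoopConfiguration}

    val submitArgs = new SparkSubmitArguments(Seq(
      "--master", "local", "--class", "my.Main", "/tmp/app.jar"))
    val (childArgs, childClasspath, sysProps, childMainClass) =
      SparkSubmit.prepareSubmitEnvironment(submitArgs, conf = Some(new HadoopConfiguration()))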
@@ -327,8 +343,10 @@ object SparkSubmit extends CommandLineUtils {
}
}
- val hadoopConf = new HadoopConfiguration()
- val targetDir = DependencyUtils.createTempDir()
+ val sparkConf = new SparkConf(false)
+ args.sparkProperties.foreach { case (k, v) => sparkConf.set(k, v) }
+ val hadoopConf = conf.getOrElse(SparkHadoopUtil.newConfiguration(sparkConf))
+ val targetDir = Utils.createTempDir()
// Resolve glob path for different resources.
args.jars = Option(args.jars).map(resolveGlobPaths(_, hadoopConf)).orNull
@@ -337,15 +355,22 @@ object SparkSubmit extends CommandLineUtils {
args.archives = Option(args.archives).map(resolveGlobPaths(_, hadoopConf)).orNull
// In client mode, download remote files.
+ var localPrimaryResource: String = null
+ var localJars: String = null
+ var localPyFiles: String = null
if (deployMode == CLIENT) {
- args.primaryResource = Option(args.primaryResource).map {
- downloadFile(_, targetDir, args.sparkProperties, hadoopConf)
+ // This security manager will not need an auth secret, but set a dummy value in case
+ // spark.authenticate is enabled; otherwise an exception is thrown.
+ sparkConf.set(SecurityManager.SPARK_AUTH_SECRET_CONF, "unused")
+ val secMgr = new SecurityManager(sparkConf)
+ localPrimaryResource = Option(args.primaryResource).map {
+ downloadFile(_, targetDir, sparkConf, hadoopConf, secMgr)
}.orNull
- args.jars = Option(args.jars).map {
- downloadFileList(_, targetDir, args.sparkProperties, hadoopConf)
+ localJars = Option(args.jars).map {
+ downloadFileList(_, targetDir, sparkConf, hadoopConf, secMgr)
}.orNull
- args.pyFiles = Option(args.pyFiles).map {
- downloadFileList(_, targetDir, args.sparkProperties, hadoopConf)
+ localPyFiles = Option(args.pyFiles).map {
+ downloadFileList(_, targetDir, sparkConf, hadoopConf, secMgr)
}.orNull
}
@@ -357,7 +382,7 @@ object SparkSubmit extends CommandLineUtils {
// If a python file is provided, add it to the child arguments and list of files to deploy.
// Usage: PythonAppRunner <main python file> <extra python files> [app arguments]
args.mainClass = "org.apache.spark.deploy.PythonRunner"
- args.childArgs = ArrayBuffer(args.primaryResource, args.pyFiles) ++ args.childArgs
+ args.childArgs = ArrayBuffer(localPrimaryResource, localPyFiles) ++ args.childArgs
if (clusterManager != YARN) {
// The YARN backend distributes the primary file differently, so don't merge it.
args.files = mergeFileLists(args.files, args.primaryResource)
@@ -367,8 +392,8 @@ object SparkSubmit extends CommandLineUtils {
// The YARN backend handles python files differently, so don't merge the lists.
args.files = mergeFileLists(args.files, args.pyFiles)
}
- if (args.pyFiles != null) {
- sysProps("spark.submit.pyFiles") = args.pyFiles
+ if (localPyFiles != null) {
+ sysProps("spark.submit.pyFiles") = localPyFiles
}
}
@@ -422,7 +447,7 @@ object SparkSubmit extends CommandLineUtils {
// If an R file is provided, add it to the child arguments and list of files to deploy.
// Usage: RRunner [app arguments]
args.mainClass = "org.apache.spark.deploy.RRunner"
- args.childArgs = ArrayBuffer(args.primaryResource) ++ args.childArgs
+ args.childArgs = ArrayBuffer(localPrimaryResource) ++ args.childArgs
args.files = mergeFileLists(args.files, args.primaryResource)
}
}
@@ -467,6 +492,7 @@ object SparkSubmit extends CommandLineUtils {
OptionAssigner(args.queue, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.queue"),
OptionAssigner(args.numExecutors, YARN, ALL_DEPLOY_MODES,
sysProp = "spark.executor.instances"),
+ OptionAssigner(args.pyFiles, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.pyFiles"),
OptionAssigner(args.jars, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.jars"),
OptionAssigner(args.files, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.files"),
OptionAssigner(args.archives, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.archives"),
@@ -494,15 +520,28 @@ object SparkSubmit extends CommandLineUtils {
sysProp = "spark.driver.cores"),
OptionAssigner(args.supervise.toString, STANDALONE | MESOS, CLUSTER,
sysProp = "spark.driver.supervise"),
- OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy")
+ OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy"),
+
+ // An internal option used only by spark-shell to add user jars to the repl's classloader.
+ // Previously it used "spark.jars" or "spark.yarn.dist.jars", which may now point to
+ // remote jars, so a new option is added to specify only local jars for spark-shell use.
+ OptionAssigner(localJars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.repl.local.jars")
)
// In client mode, launch the application main class directly
// In addition, add the main application jar and any added jars (if any) to the classpath
- // Also add the main application jar and any added jars to classpath in case YARN client
- // requires these jars.
- if (deployMode == CLIENT || isYarnCluster) {
+ if (deployMode == CLIENT) {
childMainClass = args.mainClass
+ if (localPrimaryResource != null && isUserJar(localPrimaryResource)) {
+ childClasspath += localPrimaryResource
+ }
+ if (localJars != null) { childClasspath ++= localJars.split(",") }
+ }
+ // Add the main application jar and any added jars to the classpath in case the YARN client
+ // needs these jars.
+ // This assumes both the primaryResource and the user jars are local jars; otherwise they
+ // will not be added to the classpath of the YARN client.
+ if (isYarnCluster) {
if (isUserJar(args.primaryResource)) {
childClasspath += args.primaryResource
}
@@ -560,26 +599,28 @@ object SparkSubmit extends CommandLineUtils {
if (args.isPython) {
sysProps.put("spark.yarn.isPython", "true")
}
-
- if (args.pyFiles != null) {
- sysProps("spark.submit.pyFiles") = args.pyFiles
- }
}
// assure a keytab is available from any place in a JVM
- if (clusterManager == YARN || clusterManager == LOCAL) {
+ if (clusterManager == YARN || clusterManager == LOCAL || clusterManager == MESOS) {
if (args.principal != null) {
- require(args.keytab != null, "Keytab must be specified when principal is specified")
- SparkHadoopUtil.get.loginUserFromKeytab(args.principal, args.keytab)
- // Add keytab and principal configurations in sysProps to make them available
- // for later use; e.g. in spark sql, the isolated class loader used to talk
- // to HiveMetastore will use these settings. They will be set as Java system
- // properties and then loaded by SparkConf
- sysProps.put("spark.yarn.keytab", args.keytab)
- sysProps.put("spark.yarn.principal", args.principal)
+ if (args.keytab != null) {
+ require(new File(args.keytab).exists(), s"Keytab file: ${args.keytab} does not exist")
+ // Add keytab and principal configurations in sysProps to make them available
+ // for later use; e.g. in spark sql, the isolated class loader used to talk
+ // to HiveMetastore will use these settings. They will be set as Java system
+ // properties and then loaded by SparkConf
+ sysProps.put("spark.yarn.keytab", args.keytab)
+ sysProps.put("spark.yarn.principal", args.principal)
+ UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab)
+ }
}
}
+ if (clusterManager == MESOS && UserGroupInformation.isSecurityEnabled) {
+ setRMPrincipal(sysProps)
+ }
+
// In yarn-cluster mode, use yarn.Client as a wrapper around the user class
if (isYarnCluster) {
childMainClass = "org.apache.spark.deploy.yarn.Client"
@@ -682,6 +723,18 @@ object SparkSubmit extends CommandLineUtils {
(childArgs, childClasspath, sysProps, childMainClass)
}
+ // [SPARK-20328]. HadoopRDD calls into a Hadoop library that fetches delegation tokens with
+ // renewer set to the YARN ResourceManager. Since YARN isn't configured in Mesos mode, we
+ // must trick it into thinking we're YARN.
+ private def setRMPrincipal(sysProps: HashMap[String, String]): Unit = {
+ val shortUserName = UserGroupInformation.getCurrentUser.getShortUserName
+ val key = s"spark.hadoop.${YarnConfiguration.RM_PRINCIPAL}"
+ // scalastyle:off println
+ printStream.println(s"Setting ${key} to ${shortUserName}")
+ // scalastyle:on println
+ sysProps.put(key, shortUserName)
+ }
+
/**
* Run the main method of the child class using the provided launch environment.
*
@@ -854,136 +907,6 @@ object SparkSubmit extends CommandLineUtils {
if (merged == "") null else merged
}
- /**
- * Download a list of remote files to temp local files. If the file is local, the original file
- * will be returned.
- * @param fileList A comma separated file list.
- * @param targetDir A temporary directory for which downloaded files
- * @param sparkProperties Spark properties
- * @return A comma separated local files list.
- */
- private[deploy] def downloadFileList(
- fileList: String,
- targetDir: File,
- sparkProperties: Map[String, String],
- hadoopConf: HadoopConfiguration): String = {
- require(fileList != null, "fileList cannot be null.")
- fileList.split(",")
- .map(downloadFile(_, targetDir, sparkProperties, hadoopConf))
- .mkString(",")
- }
-
- /**
- * Download a file from the remote to a local temporary directory. If the input path points to
- * a local path, returns it with no operation.
- * @param path A file path from where the files will be downloaded.
- * @param targetDir A temporary directory for which downloaded files
- * @param sparkProperties Spark properties
- * @return A comma separated local files list.
- */
- private[deploy] def downloadFile(
- path: String,
- targetDir: File,
- sparkProperties: Map[String, String],
- hadoopConf: HadoopConfiguration): String = {
- require(path != null, "path cannot be null.")
- val uri = Utils.resolveURI(path)
- uri.getScheme match {
- case "file" | "local" => path
- case "http" | "https" | "ftp" =>
- val uc = uri.toURL.openConnection()
- uc match {
- case https: HttpsURLConnection =>
- val trustStore = sparkProperties.get("spark.ssl.fs.trustStore")
- .orElse(sparkProperties.get("spark.ssl.trustStore"))
- val trustStorePwd = sparkProperties.get("spark.ssl.fs.trustStorePassword")
- .orElse(sparkProperties.get("spark.ssl.trustStorePassword"))
- .map(_.toCharArray)
- .orNull
- val protocol = sparkProperties.get("spark.ssl.fs.protocol")
- .orElse(sparkProperties.get("spark.ssl.protocol"))
- if (protocol.isEmpty) {
- printErrorAndExit("spark ssl protocol is required when enabling SSL connection.")
- }
-
- val trustStoreManagers = trustStore.map { t =>
- var input: InputStream = null
- try {
- input = new FileInputStream(new File(t))
- val ks = KeyStore.getInstance(KeyStore.getDefaultType)
- ks.load(input, trustStorePwd)
- val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm)
- tmf.init(ks)
- tmf.getTrustManagers
- } finally {
- if (input != null) {
- input.close()
- input = null
- }
- }
- }.getOrElse {
- Array({
- new X509TrustManager {
- override def getAcceptedIssuers: Array[X509Certificate] = null
- override def checkClientTrusted(
- x509Certificates: Array[X509Certificate], s: String) {}
- override def checkServerTrusted(
- x509Certificates: Array[X509Certificate], s: String) {}
- }: TrustManager
- })
- }
- val sslContext = SSLContext.getInstance(protocol.get)
- sslContext.init(null, trustStoreManagers, null)
- https.setSSLSocketFactory(sslContext.getSocketFactory)
- https.setHostnameVerifier(new HostnameVerifier {
- override def verify(s: String, sslSession: SSLSession): Boolean = false
- })
-
- case _ =>
- }
-
- uc.setConnectTimeout(60 * 1000)
- uc.setReadTimeout(60 * 1000)
- uc.connect()
- val in = uc.getInputStream
- val fileName = new Path(uri).getName
- val tempFile = new File(targetDir, fileName)
- val out = new FileOutputStream(tempFile)
- // scalastyle:off println
- printStream.println(s"Downloading ${uri.toString} to ${tempFile.getAbsolutePath}.")
- // scalastyle:on println
- try {
- ByteStreams.copy(in, out)
- } finally {
- in.close()
- out.close()
- }
- tempFile.toURI.toString
- case _ =>
- val fs = FileSystem.get(uri, hadoopConf)
- val tmpFile = new File(targetDir, new Path(uri).getName)
- // scalastyle:off println
- printStream.println(s"Downloading ${uri.toString} to ${tmpFile.getAbsolutePath}.")
- // scalastyle:on println
- fs.copyToLocalFile(new Path(uri), new Path(tmpFile.getAbsolutePath))
- tmpFile.toURI.toString
- }
- }
-
- private[deploy] def resolveGlobPaths(paths: String, hadoopConf: HadoopConfiguration): String = {
- require(paths != null, "paths cannot be null.")
- paths.split(",").map(_.trim).filter(_.nonEmpty).flatMap { path =>
- val uri = Utils.resolveURI(path)
- uri.getScheme match {
- case "local" | "http" | "https" | "ftp" => Array(path)
- case _ =>
- val fs = FileSystem.get(uri, hadoopConf)
- Option(fs.globStatus(new Path(uri))).map { status =>
- status.filter(_.isFile).map(_.getPath.toUri.toString)
- }.getOrElse(Array(path))
- }
- }.mkString(",")
- }
}
/** Provides utility functions to be used inside SparkSubmit. */
@@ -997,7 +920,7 @@ private[spark] object SparkSubmitUtils {
// We need to specify each component explicitly, otherwise we miss spark-streaming-kafka-0-8 and
// other spark-streaming utility components. Underscore is there to differentiate between
// spark-streaming_2.1x and spark-streaming-kafka-0-8-assembly_2.1x
- val IVY_DEFAULT_EXCLUDES = Seq("catalyst_", "core_", "graphx_", "launcher_", "mllib_",
+ val IVY_DEFAULT_EXCLUDES = Seq("catalyst_", "core_", "graphx_", "kvstore_", "launcher_", "mllib_",
"mllib-local_", "network-common_", "network-shuffle_", "repl_", "sketch_", "sql_", "streaming_",
"tags_", "unsafe_")
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
index 687fd2d3ffe64..20fe911f2d294 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
@@ -249,7 +249,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
val appSecManager = new SecurityManager(conf)
SparkUI.createHistoryUI(conf, replayBus, appSecManager, appInfo.name,
HistoryServer.getAttemptURI(appId, attempt.attemptId),
- attempt.startTime)
+ Some(attempt.lastUpdated), attempt.startTime)
// Do not call ui.bind() to avoid creating a new server for each application
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala
index 35621daf9c0d7..78b0e6b2cbf39 100644
--- a/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala
@@ -24,6 +24,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.security.Credentials
import org.apache.hadoop.security.token.{Token, TokenIdentifier}
+import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils
@@ -34,6 +35,7 @@ private[security] class HBaseDelegationTokenProvider
override def obtainDelegationTokens(
hadoopConf: Configuration,
+ sparkConf: SparkConf,
creds: Credentials): Option[Long] = {
try {
val mirror = universe.runtimeMirror(Utils.getContextOrSparkClassLoader)
diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala
index 01cbfe1ee6ae1..c134b7ebe38fa 100644
--- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala
@@ -55,6 +55,14 @@ private[spark] class HadoopDelegationTokenManager(
logDebug(s"Using the following delegation token providers: " +
s"${delegationTokenProviders.keys.mkString(", ")}.")
+ /** Construct a [[HadoopDelegationTokenManager]] for the default Hadoop filesystem */
+ def this(sparkConf: SparkConf, hadoopConf: Configuration) = {
+ this(
+ sparkConf,
+ hadoopConf,
+ hadoopConf => Set(FileSystem.get(hadoopConf).getHomeDirectory.getFileSystem(hadoopConf)))
+ }
+
private def getDelegationTokenProviders: Map[String, HadoopDelegationTokenProvider] = {
val providers = List(new HadoopFSDelegationTokenProvider(fileSystems),
new HiveDelegationTokenProvider,
@@ -108,7 +116,7 @@ private[spark] class HadoopDelegationTokenManager(
creds: Credentials): Long = {
delegationTokenProviders.values.flatMap { provider =>
if (provider.delegationTokensRequired(hadoopConf)) {
- provider.obtainDelegationTokens(hadoopConf, creds)
+ provider.obtainDelegationTokens(hadoopConf, sparkConf, creds)
} else {
logDebug(s"Service ${provider.serviceName} does not require a token." +
s" Check your configuration to see if security is disabled or not.")
diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala
index f162e7e58c53a..1ba245e84af4b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala
@@ -20,6 +20,8 @@ package org.apache.spark.deploy.security
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.security.Credentials
+import org.apache.spark.SparkConf
+
/**
* Hadoop delegation token provider.
*/
@@ -46,5 +48,6 @@ private[spark] trait HadoopDelegationTokenProvider {
*/
def obtainDelegationTokens(
hadoopConf: Configuration,
+ sparkConf: SparkConf,
creds: Credentials): Option[Long]
}
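
For reference, a minimal sketch (not part of this patch) of what a provider looks like against the widened trait signature. The `NoopDelegationTokenProvider` name is hypothetical, and since the trait is `private[spark]` the sketch assumes it lives under `org.apache.spark.deploy.security`:

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.security.Credentials

import org.apache.spark.SparkConf

// Hypothetical provider, only to illustrate the new signature: implementations now
// receive the SparkConf as well, so Spark-level settings (e.g. a configured principal)
// can drive token acquisition instead of being re-derived from Hadoop config.
private[spark] class NoopDelegationTokenProvider extends HadoopDelegationTokenProvider {
  override def serviceName: String = "noop"

  override def delegationTokensRequired(hadoopConf: Configuration): Boolean = false

  override def obtainDelegationTokens(
      hadoopConf: Configuration,
      sparkConf: SparkConf,
      creds: Credentials): Option[Long] = None
}
```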
diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala
index f0ac7f501ceb1..300773c58b183 100644
--- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala
@@ -26,8 +26,9 @@ import org.apache.hadoop.mapred.Master
import org.apache.hadoop.security.{Credentials, UserGroupInformation}
import org.apache.hadoop.security.token.delegation.AbstractDelegationTokenIdentifier
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
private[deploy] class HadoopFSDelegationTokenProvider(fileSystems: Configuration => Set[FileSystem])
extends HadoopDelegationTokenProvider with Logging {
@@ -41,21 +42,20 @@ private[deploy] class HadoopFSDelegationTokenProvider(fileSystems: Configuration
override def obtainDelegationTokens(
hadoopConf: Configuration,
+ sparkConf: SparkConf,
creds: Credentials): Option[Long] = {
val fsToGetTokens = fileSystems(hadoopConf)
- val newCreds = fetchDelegationTokens(
- getTokenRenewer(hadoopConf),
- fsToGetTokens)
+ val fetchCreds = fetchDelegationTokens(getTokenRenewer(hadoopConf), fsToGetTokens, creds)
// Get the token renewal interval if it is not set. It will only be called once.
if (tokenRenewalInterval == null) {
- tokenRenewalInterval = getTokenRenewalInterval(hadoopConf, fsToGetTokens)
+ tokenRenewalInterval = getTokenRenewalInterval(hadoopConf, sparkConf, fsToGetTokens)
}
// Get the time of next renewal.
val nextRenewalDate = tokenRenewalInterval.flatMap { interval =>
- val nextRenewalDates = newCreds.getAllTokens.asScala
+ val nextRenewalDates = fetchCreds.getAllTokens.asScala
.filter(_.decodeIdentifier().isInstanceOf[AbstractDelegationTokenIdentifier])
.map { token =>
val identifier = token
@@ -66,7 +66,6 @@ private[deploy] class HadoopFSDelegationTokenProvider(fileSystems: Configuration
if (nextRenewalDates.isEmpty) None else Some(nextRenewalDates.min)
}
- creds.addAll(newCreds)
nextRenewalDate
}
@@ -89,9 +88,8 @@ private[deploy] class HadoopFSDelegationTokenProvider(fileSystems: Configuration
private def fetchDelegationTokens(
renewer: String,
- filesystems: Set[FileSystem]): Credentials = {
-
- val creds = new Credentials()
+ filesystems: Set[FileSystem],
+ creds: Credentials): Credentials = {
filesystems.foreach { fs =>
logInfo("getting token for: " + fs)
@@ -103,25 +101,27 @@ private[deploy] class HadoopFSDelegationTokenProvider(fileSystems: Configuration
private def getTokenRenewalInterval(
hadoopConf: Configuration,
+ sparkConf: SparkConf,
filesystems: Set[FileSystem]): Option[Long] = {
// We cannot use the tokens generated with renewer yarn. Trying to renew
// those will fail with an access control issue. So create new tokens with the logged in
// user as renewer.
- val creds = fetchDelegationTokens(
- UserGroupInformation.getCurrentUser.getUserName,
- filesystems)
-
- val renewIntervals = creds.getAllTokens.asScala.filter {
- _.decodeIdentifier().isInstanceOf[AbstractDelegationTokenIdentifier]
- }.flatMap { token =>
- Try {
- val newExpiration = token.renew(hadoopConf)
- val identifier = token.decodeIdentifier().asInstanceOf[AbstractDelegationTokenIdentifier]
- val interval = newExpiration - identifier.getIssueDate
- logInfo(s"Renewal interval is $interval for token ${token.getKind.toString}")
- interval
- }.toOption
+ sparkConf.get(PRINCIPAL).flatMap { renewer =>
+ val creds = new Credentials()
+ fetchDelegationTokens(renewer, filesystems, creds)
+
+ val renewIntervals = creds.getAllTokens.asScala.filter {
+ _.decodeIdentifier().isInstanceOf[AbstractDelegationTokenIdentifier]
+ }.flatMap { token =>
+ Try {
+ val newExpiration = token.renew(hadoopConf)
+ val identifier = token.decodeIdentifier().asInstanceOf[AbstractDelegationTokenIdentifier]
+ val interval = newExpiration - identifier.getIssueDate
+ logInfo(s"Renewal interval is $interval for token ${token.getKind.toString}")
+ interval
+ }.toOption
+ }
+ if (renewIntervals.isEmpty) None else Some(renewIntervals.min)
}
- if (renewIntervals.isEmpty) None else Some(renewIntervals.min)
}
}
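
A small sketch of the configuration this change relies on: the renewal interval is now only probed when a principal is configured, because only then can Spark create tokens it is itself allowed to renew. The property names below are assumptions based on the `PRINCIPAL` config constant of this era, not something introduced by this patch:

```scala
import org.apache.spark.SparkConf

// Assumed keys behind the PRINCIPAL/keytab settings; shown only to illustrate when
// getTokenRenewalInterval returns Some(...) vs None.
val conf = new SparkConf()
  .set("spark.yarn.principal", "alice@EXAMPLE.COM")
  .set("spark.yarn.keytab", "/path/to/alice.keytab")
```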
diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala
index 53b9f898c6e7d..b31cc595ed83b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala
@@ -30,6 +30,7 @@ import org.apache.hadoop.io.Text
import org.apache.hadoop.security.{Credentials, UserGroupInformation}
import org.apache.hadoop.security.token.Token
+import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils
@@ -61,6 +62,7 @@ private[security] class HiveDelegationTokenProvider
override def obtainDelegationTokens(
hadoopConf: Configuration,
+ sparkConf: SparkConf,
creds: Credentials): Option[Long] = {
try {
val conf = hiveConf(hadoopConf)
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
index cd3e361530c18..c1671192e0c64 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala
@@ -22,7 +22,7 @@ import java.io.File
import org.apache.commons.lang3.StringUtils
import org.apache.spark.{SecurityManager, SparkConf}
-import org.apache.spark.deploy.{DependencyUtils, SparkSubmit}
+import org.apache.spark.deploy.{DependencyUtils, SparkHadoopUtil, SparkSubmit}
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils}
@@ -72,6 +72,10 @@ object DriverWrapper {
}
private def setupDependencies(loader: MutableURLClassLoader, userJar: String): Unit = {
+ val sparkConf = new SparkConf()
+ val secMgr = new SecurityManager(sparkConf)
+ val hadoopConf = SparkHadoopUtil.newConfiguration(sparkConf)
+
val Seq(packagesExclusions, packages, repositories, ivyRepoPath) =
Seq("spark.jars.excludes", "spark.jars.packages", "spark.jars.repositories", "spark.jars.ivy")
.map(sys.props.get(_).orNull)
@@ -86,7 +90,8 @@ object DriverWrapper {
jarsProp
}
}
- val localJars = DependencyUtils.resolveAndDownloadJars(jars, userJar)
+ val localJars = DependencyUtils.resolveAndDownloadJars(jars, userJar, sparkConf, hadoopConf,
+ secMgr)
DependencyUtils.addJarsToClassPath(localJars, loader)
}
}
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index c60d33b7066cd..42eb7e9547337 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -219,6 +219,11 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
SparkHadoopUtil.get.startCredentialUpdater(driverConf)
}
+ cfg.hadoopDelegationCreds.foreach { hadoopCreds =>
+ val creds = SparkHadoopUtil.get.deserialize(hadoopCreds)
+ SparkHadoopUtil.get.addCurrentUserCredentials(creds)
+ }
+
val env = SparkEnv.createExecutorEnv(
driverConf, executorId, hostname, cores, cfg.ioEncryptionKey, isLocal = false)
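
A round-trip sketch of the token hand-off this hunk completes, using the `SparkHadoopUtil` helpers that appear in this patch (the snippet assumes package access where `SparkHadoopUtil` is not public):

```scala
import org.apache.hadoop.security.{Credentials, UserGroupInformation}

import org.apache.spark.deploy.SparkHadoopUtil

// The driver serializes the current user's Credentials into the SparkAppConfig reply;
// each executor deserializes them and merges them into its own UGI before SparkEnv
// is created, so delegation tokens are available to Hadoop clients on the executor.
val creds: Credentials = UserGroupInformation.getCurrentUser.getCredentials
val serialized: Array[Byte] = SparkHadoopUtil.get.serialize(creds)
val restored: Credentials = SparkHadoopUtil.get.deserialize(serialized)
SparkHadoopUtil.get.addCurrentUserCredentials(restored)
```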
diff --git a/core/src/main/scala/org/apache/spark/internal/Logging.scala b/core/src/main/scala/org/apache/spark/internal/Logging.scala
index c7f2847731fcb..c0d709ad25f29 100644
--- a/core/src/main/scala/org/apache/spark/internal/Logging.scala
+++ b/core/src/main/scala/org/apache/spark/internal/Logging.scala
@@ -96,47 +96,59 @@ trait Logging {
}
protected def initializeLogIfNecessary(isInterpreter: Boolean): Unit = {
+ initializeLogIfNecessary(isInterpreter, silent = false)
+ }
+
+ protected def initializeLogIfNecessary(
+ isInterpreter: Boolean,
+ silent: Boolean = false): Boolean = {
if (!Logging.initialized) {
Logging.initLock.synchronized {
if (!Logging.initialized) {
- initializeLogging(isInterpreter)
+ initializeLogging(isInterpreter, silent)
+ return true
}
}
}
+ false
}
- private def initializeLogging(isInterpreter: Boolean): Unit = {
+ private def initializeLogging(isInterpreter: Boolean, silent: Boolean): Unit = {
// Don't use a logger in here, as this is itself occurring during initialization of a logger
// If Log4j 1.2 is being used, but is not initialized, load a default properties file
- val binderClass = StaticLoggerBinder.getSingleton.getLoggerFactoryClassStr
- // This distinguishes the log4j 1.2 binding, currently
- // org.slf4j.impl.Log4jLoggerFactory, from the log4j 2.0 binding, currently
- // org.apache.logging.slf4j.Log4jLoggerFactory
- val usingLog4j12 = "org.slf4j.impl.Log4jLoggerFactory".equals(binderClass)
- if (usingLog4j12) {
+ if (Logging.isLog4j12()) {
val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements
// scalastyle:off println
if (!log4j12Initialized) {
+ Logging.defaultSparkLog4jConfig = true
val defaultLogProps = "org/apache/spark/log4j-defaults.properties"
Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match {
case Some(url) =>
PropertyConfigurator.configure(url)
- System.err.println(s"Using Spark's default log4j profile: $defaultLogProps")
+ if (!silent) {
+ System.err.println(s"Using Spark's default log4j profile: $defaultLogProps")
+ }
case None =>
System.err.println(s"Spark was unable to load $defaultLogProps")
}
}
+ val rootLogger = LogManager.getRootLogger()
+ if (Logging.defaultRootLevel == null) {
+ Logging.defaultRootLevel = rootLogger.getLevel()
+ }
+
if (isInterpreter) {
// Use the repl's main class to define the default log level when running the shell,
// overriding the root logger's config if they're different.
- val rootLogger = LogManager.getRootLogger()
val replLogger = LogManager.getLogger(logName)
val replLevel = Option(replLogger.getLevel()).getOrElse(Level.WARN)
if (replLevel != rootLogger.getEffectiveLevel()) {
- System.err.printf("Setting default log level to \"%s\".\n", replLevel)
- System.err.println("To adjust logging level use sc.setLogLevel(newLevel). " +
- "For SparkR, use setLogLevel(newLevel).")
+ if (!silent) {
+ System.err.printf("Setting default log level to \"%s\".\n", replLevel)
+ System.err.println("To adjust logging level use sc.setLogLevel(newLevel). " +
+ "For SparkR, use setLogLevel(newLevel).")
+ }
rootLogger.setLevel(replLevel)
}
}
@@ -150,8 +162,11 @@ trait Logging {
}
}
-private object Logging {
+private[spark] object Logging {
@volatile private var initialized = false
+ @volatile private var defaultRootLevel: Level = null
+ @volatile private var defaultSparkLog4jConfig = false
+
val initLock = new Object()
try {
// We use reflection here to handle the case where users remove the
@@ -165,4 +180,29 @@ private object Logging {
} catch {
case e: ClassNotFoundException => // can't log anything yet so just fail silently
}
+
+ /**
+ * Marks the logging system as not initialized. This makes a best-effort attempt to reset the
+ * logging system to its initial state so that the next class to use logging triggers
+ * initialization again.
+ */
+ def uninitialize(): Unit = initLock.synchronized {
+ if (isLog4j12()) {
+ if (defaultSparkLog4jConfig) {
+ defaultSparkLog4jConfig = false
+ LogManager.resetConfiguration()
+ } else {
+ LogManager.getRootLogger().setLevel(defaultRootLevel)
+ }
+ }
+ this.initialized = false
+ }
+
+ private def isLog4j12(): Boolean = {
+ // This distinguishes the log4j 1.2 binding, currently
+ // org.slf4j.impl.Log4jLoggerFactory, from the log4j 2.0 binding, currently
+ // org.apache.logging.slf4j.Log4jLoggerFactory
+ val binderClass = StaticLoggerBinder.getSingleton.getLoggerFactoryClassStr
+ "org.slf4j.impl.Log4jLoggerFactory".equals(binderClass)
+ }
}
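
A sketch of how the new `silent` flag and `uninitialize()` can be combined (the helper name is illustrative; it must live under `org.apache.spark` because the `Logging` companion is `private[spark]`):

```scala
package org.apache.spark.internal

// Initialize logging quietly, run a block, then restore the previous log4j state only
// if this code path was the one that triggered initialization in the first place.
object QuietLogging extends Logging {
  def withQuietLogging[T](body: => T): T = {
    val initializedHere = initializeLogIfNecessary(isInterpreter = false, silent = true)
    try {
      body
    } finally {
      if (initializedHere) {
        Logging.uninitialize()
      }
    }
  }
}
```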
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 9965683ab404b..a0c3ebf69e0d2 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -87,7 +87,7 @@ package object config {
.intConf
.createOptional
- private[spark] val PY_FILES = ConfigBuilder("spark.submit.pyFiles")
+ private[spark] val PY_FILES = ConfigBuilder("spark.yarn.dist.pyFiles")
.internal()
.stringConf
.toSequence
@@ -324,6 +324,15 @@ package object config {
.booleanConf
.createWithDefault(false)
+ private[spark] val BUFFER_WRITE_CHUNK_SIZE =
+ ConfigBuilder("spark.buffer.write.chunkSize")
+ .internal()
+ .doc("The chunk size during writing out the bytes of ChunkedByteBuffer.")
+ .bytesConf(ByteUnit.BYTE)
+ .checkValue(_ <= Int.MaxValue, "The chunk size during writing out the bytes of" +
+ " ChunkedByteBuffer should not larger than Int.MaxValue.")
+ .createWithDefault(64 * 1024 * 1024)
+
private[spark] val CHECKPOINT_COMPRESS =
ConfigBuilder("spark.checkpoint.compress")
.doc("Whether to compress RDD checkpoints. Generally a good idea. Compression will use " +
diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
index 22e26799138ba..b1d07ab2c9199 100644
--- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
@@ -73,7 +73,8 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String)
val stagingDir: String = committer match {
// For FileOutputCommitter it has its own staging path called "work path".
- case f: FileOutputCommitter => Option(f.getWorkPath.toString).getOrElse(path)
+ case f: FileOutputCommitter =>
+ Option(f.getWorkPath).map(_.toString).getOrElse(path)
case _ => path
}
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/StatsdReporter.scala b/core/src/main/scala/org/apache/spark/metrics/sink/StatsdReporter.scala
new file mode 100644
index 0000000000000..ba75aa1c65cc6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/StatsdReporter.scala
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.metrics.sink
+
+import java.io.IOException
+import java.net.{DatagramPacket, DatagramSocket, InetSocketAddress}
+import java.nio.charset.StandardCharsets.UTF_8
+import java.util.SortedMap
+import java.util.concurrent.TimeUnit
+
+import scala.collection.JavaConverters._
+import scala.util.{Failure, Success, Try}
+
+import com.codahale.metrics._
+import org.apache.hadoop.net.NetUtils
+
+import org.apache.spark.internal.Logging
+
+/**
+ * @see StatsD metric types
+ */
+private[spark] object StatsdMetricType {
+ val COUNTER = "c"
+ val GAUGE = "g"
+ val TIMER = "ms"
+ val SET = "s"
+}
+
+private[spark] class StatsdReporter(
+ registry: MetricRegistry,
+ host: String = "127.0.0.1",
+ port: Int = 8125,
+ prefix: String = "",
+ filter: MetricFilter = MetricFilter.ALL,
+ rateUnit: TimeUnit = TimeUnit.SECONDS,
+ durationUnit: TimeUnit = TimeUnit.MILLISECONDS)
+ extends ScheduledReporter(registry, "statsd-reporter", filter, rateUnit, durationUnit)
+ with Logging {
+
+ import StatsdMetricType._
+
+ private val address = new InetSocketAddress(host, port)
+ private val whitespace = "[\\s]+".r
+
+ override def report(
+ gauges: SortedMap[String, Gauge[_]],
+ counters: SortedMap[String, Counter],
+ histograms: SortedMap[String, Histogram],
+ meters: SortedMap[String, Meter],
+ timers: SortedMap[String, Timer]): Unit =
+ Try(new DatagramSocket) match {
+ case Failure(ioe: IOException) => logWarning("StatsD datagram socket construction failed",
+ NetUtils.wrapException(host, port, NetUtils.getHostname(), 0, ioe))
+ case Failure(e) => logWarning("StatsD datagram socket construction failed", e)
+ case Success(s) =>
+ implicit val socket = s
+ val localAddress = Try(socket.getLocalAddress).map(_.getHostAddress).getOrElse(null)
+ val localPort = socket.getLocalPort
+ Try {
+ gauges.entrySet.asScala.foreach(e => reportGauge(e.getKey, e.getValue))
+ counters.entrySet.asScala.foreach(e => reportCounter(e.getKey, e.getValue))
+ histograms.entrySet.asScala.foreach(e => reportHistogram(e.getKey, e.getValue))
+ meters.entrySet.asScala.foreach(e => reportMetered(e.getKey, e.getValue))
+ timers.entrySet.asScala.foreach(e => reportTimer(e.getKey, e.getValue))
+ } recover {
+ case ioe: IOException =>
+ logDebug(s"Unable to send packets to StatsD", NetUtils.wrapException(
+ address.getHostString, address.getPort, localAddress, localPort, ioe))
+ case e: Throwable => logDebug(s"Unable to send packets to StatsD at '$host:$port'", e)
+ }
+ Try(socket.close()) recover {
+ case ioe: IOException =>
+ logDebug("Error when close socket to StatsD", NetUtils.wrapException(
+ address.getHostString, address.getPort, localAddress, localPort, ioe))
+ case e: Throwable => logDebug("Error when close socket to StatsD", e)
+ }
+ }
+
+ private def reportGauge(name: String, gauge: Gauge[_])(implicit socket: DatagramSocket): Unit =
+ formatAny(gauge.getValue).foreach(v => send(fullName(name), v, GAUGE))
+
+ private def reportCounter(name: String, counter: Counter)(implicit socket: DatagramSocket): Unit =
+ send(fullName(name), format(counter.getCount), COUNTER)
+
+ private def reportHistogram(name: String, histogram: Histogram)
+ (implicit socket: DatagramSocket): Unit = {
+ val snapshot = histogram.getSnapshot
+ send(fullName(name, "count"), format(histogram.getCount), GAUGE)
+ send(fullName(name, "max"), format(snapshot.getMax), TIMER)
+ send(fullName(name, "mean"), format(snapshot.getMean), TIMER)
+ send(fullName(name, "min"), format(snapshot.getMin), TIMER)
+ send(fullName(name, "stddev"), format(snapshot.getStdDev), TIMER)
+ send(fullName(name, "p50"), format(snapshot.getMedian), TIMER)
+ send(fullName(name, "p75"), format(snapshot.get75thPercentile), TIMER)
+ send(fullName(name, "p95"), format(snapshot.get95thPercentile), TIMER)
+ send(fullName(name, "p98"), format(snapshot.get98thPercentile), TIMER)
+ send(fullName(name, "p99"), format(snapshot.get99thPercentile), TIMER)
+ send(fullName(name, "p999"), format(snapshot.get999thPercentile), TIMER)
+ }
+
+ private def reportMetered(name: String, meter: Metered)(implicit socket: DatagramSocket): Unit = {
+ send(fullName(name, "count"), format(meter.getCount), GAUGE)
+ send(fullName(name, "m1_rate"), format(convertRate(meter.getOneMinuteRate)), TIMER)
+ send(fullName(name, "m5_rate"), format(convertRate(meter.getFiveMinuteRate)), TIMER)
+ send(fullName(name, "m15_rate"), format(convertRate(meter.getFifteenMinuteRate)), TIMER)
+ send(fullName(name, "mean_rate"), format(convertRate(meter.getMeanRate)), TIMER)
+ }
+
+ private def reportTimer(name: String, timer: Timer)(implicit socket: DatagramSocket): Unit = {
+ val snapshot = timer.getSnapshot
+ send(fullName(name, "max"), format(convertDuration(snapshot.getMax)), TIMER)
+ send(fullName(name, "mean"), format(convertDuration(snapshot.getMean)), TIMER)
+ send(fullName(name, "min"), format(convertDuration(snapshot.getMin)), TIMER)
+ send(fullName(name, "stddev"), format(convertDuration(snapshot.getStdDev)), TIMER)
+ send(fullName(name, "p50"), format(convertDuration(snapshot.getMedian)), TIMER)
+ send(fullName(name, "p75"), format(convertDuration(snapshot.get75thPercentile)), TIMER)
+ send(fullName(name, "p95"), format(convertDuration(snapshot.get95thPercentile)), TIMER)
+ send(fullName(name, "p98"), format(convertDuration(snapshot.get98thPercentile)), TIMER)
+ send(fullName(name, "p99"), format(convertDuration(snapshot.get99thPercentile)), TIMER)
+ send(fullName(name, "p999"), format(convertDuration(snapshot.get999thPercentile)), TIMER)
+
+ reportMetered(name, timer)
+ }
+
+ private def send(name: String, value: String, metricType: String)
+ (implicit socket: DatagramSocket): Unit = {
+ val bytes = sanitize(s"$name:$value|$metricType").getBytes(UTF_8)
+ val packet = new DatagramPacket(bytes, bytes.length, address)
+ socket.send(packet)
+ }
+
+ private def fullName(names: String*): String = MetricRegistry.name(prefix, names : _*)
+
+ private def sanitize(s: String): String = whitespace.replaceAllIn(s, "-")
+
+ private def format(v: Any): String = formatAny(v).getOrElse("")
+
+ private def formatAny(v: Any): Option[String] =
+ v match {
+ case f: Float => Some("%2.2f".format(f))
+ case d: Double => Some("%2.2f".format(d))
+ case b: BigDecimal => Some("%2.2f".format(b))
+ case n: Number => Some(v.toString)
+ case _ => None
+ }
+}
+
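
The reporter speaks the plain StatsD line protocol: one `<name>:<value>|<type>` datagram per metric, with whitespace replaced by `-`. A small sketch of the encoding performed by `send()`, using an illustrative metric name:

```scala
import java.nio.charset.StandardCharsets.UTF_8

val name = "myapp.jvm.heap.used"        // fullName(prefix, ...) output (illustrative)
val payload = s"$name:12345|g"          // "c" counter, "g" gauge, "ms" timer, "s" set
val bytes = payload.getBytes(UTF_8)     // wrapped in a single DatagramPacket per metric
```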
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/StatsdSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/StatsdSink.scala
new file mode 100644
index 0000000000000..859a2f6bcd456
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/StatsdSink.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.metrics.sink
+
+import java.util.Properties
+import java.util.concurrent.TimeUnit
+
+import com.codahale.metrics.MetricRegistry
+
+import org.apache.spark.SecurityManager
+import org.apache.spark.internal.Logging
+import org.apache.spark.metrics.MetricsSystem
+
+private[spark] object StatsdSink {
+ val STATSD_KEY_HOST = "host"
+ val STATSD_KEY_PORT = "port"
+ val STATSD_KEY_PERIOD = "period"
+ val STATSD_KEY_UNIT = "unit"
+ val STATSD_KEY_PREFIX = "prefix"
+
+ val STATSD_DEFAULT_HOST = "127.0.0.1"
+ val STATSD_DEFAULT_PORT = "8125"
+ val STATSD_DEFAULT_PERIOD = "10"
+ val STATSD_DEFAULT_UNIT = "SECONDS"
+ val STATSD_DEFAULT_PREFIX = ""
+}
+
+private[spark] class StatsdSink(
+ val property: Properties,
+ val registry: MetricRegistry,
+ securityMgr: SecurityManager)
+ extends Sink with Logging {
+ import StatsdSink._
+
+ val host = property.getProperty(STATSD_KEY_HOST, STATSD_DEFAULT_HOST)
+ val port = property.getProperty(STATSD_KEY_PORT, STATSD_DEFAULT_PORT).toInt
+
+ val pollPeriod = property.getProperty(STATSD_KEY_PERIOD, STATSD_DEFAULT_PERIOD).toInt
+ val pollUnit =
+ TimeUnit.valueOf(property.getProperty(STATSD_KEY_UNIT, STATSD_DEFAULT_UNIT).toUpperCase)
+
+ val prefix = property.getProperty(STATSD_KEY_PREFIX, STATSD_DEFAULT_PREFIX)
+
+ MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod)
+
+ val reporter = new StatsdReporter(registry, host, port, prefix)
+
+ override def start(): Unit = {
+ reporter.start(pollPeriod, pollUnit)
+ logInfo(s"StatsdSink started with prefix: '$prefix'")
+ }
+
+ override def stop(): Unit = {
+ reporter.stop()
+ logInfo("StatsdSink stopped.")
+ }
+
+ override def report(): Unit = reporter.report()
+}
+
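
One way to enable the sink, assuming the usual pass-through of `spark.metrics.conf.*` properties into the metrics system (the same keys can equally live in `metrics.properties`); the prefix value is illustrative:

```scala
import org.apache.spark.SparkConf

// All keys mirror the STATSD_KEY_* constants above; unset keys fall back to the
// STATSD_DEFAULT_* values.
val conf = new SparkConf()
  .set("spark.metrics.conf.*.sink.statsd.class", "org.apache.spark.metrics.sink.StatsdSink")
  .set("spark.metrics.conf.*.sink.statsd.host", "127.0.0.1")
  .set("spark.metrics.conf.*.sink.statsd.port", "8125")
  .set("spark.metrics.conf.*.sink.statsd.period", "10")
  .set("spark.metrics.conf.*.sink.statsd.unit", "seconds")
  .set("spark.metrics.conf.*.sink.statsd.prefix", "myapp")
```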
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
index 305fd9a6de10d..eb4cf94164fd4 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -25,7 +25,7 @@ import scala.reflect.ClassTag
import org.apache.spark.internal.Logging
import org.apache.spark.network.BlockDataManager
-import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.buffer.NioManagedBuffer
import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}
import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks, StreamHandle, UploadBlock}
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 5435f59ea0d28..8798dfc925362 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer
import scala.io.Codec
import scala.language.implicitConversions
import scala.reflect.{classTag, ClassTag}
+import scala.util.hashing
import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus
import org.apache.hadoop.io.{BytesWritable, NullWritable, Text}
@@ -448,7 +449,7 @@ abstract class RDD[T: ClassTag](
if (shuffle) {
/** Distributes elements evenly across output partitions, starting from a random partition. */
val distributePartition = (index: Int, items: Iterator[T]) => {
- var position = (new Random(index)).nextInt(numPartitions)
+ var position = (new Random(hashing.byteswap32(index))).nextInt(numPartitions)
items.map { t =>
// Note that the hash code of the key will just be the key itself. The HashPartitioner
// will mod it with the number of total partitions.
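
A quick illustration of why the seed is mixed first: small consecutive integers make poor `Random` seeds, so the first `nextInt` values are highly correlated across partitions and every partition tends to start from nearby output partitions. `byteswap32` scrambles the seed bits before use:

```scala
import scala.util.Random
import scala.util.hashing

val numPartitions = 10
// Starting positions with raw indexes vs. byteswapped indexes as seeds.
val naive = (0 until 5).map(i => new Random(i).nextInt(numPartitions))
val mixed = (0 until 5).map(i => new Random(hashing.byteswap32(i)).nextInt(numPartitions))
```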
diff --git a/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala
index ab72addb2466b..facbb830a60d8 100644
--- a/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala
@@ -50,6 +50,7 @@ import org.apache.spark.util.PeriodicCheckpointer
* {{{
* val (rdd1, rdd2, rdd3, ...) = ...
* val cp = new PeriodicRDDCheckpointer(2, sc)
+ * cp.update(rdd1)
* rdd1.count();
* // persisted: rdd1
* cp.update(rdd2)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 21bf9d013ebef..562dd1da4fe14 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -281,6 +281,13 @@ class DAGScheduler(
eventProcessLoop.post(TaskSetFailed(taskSet, reason, exception))
}
+ /**
+ * Called by the TaskSetManager when it decides a speculative task is needed.
+ */
+ def speculativeTaskSubmitted(task: Task[_]): Unit = {
+ eventProcessLoop.post(SpeculativeTaskSubmitted(task))
+ }
+
private[scheduler]
def getCacheLocs(rdd: RDD[_]): IndexedSeq[Seq[TaskLocation]] = cacheLocs.synchronized {
// Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times
@@ -812,6 +819,10 @@ class DAGScheduler(
listenerBus.post(SparkListenerTaskStart(task.stageId, stageAttemptId, taskInfo))
}
+ private[scheduler] def handleSpeculativeTaskSubmitted(task: Task[_]): Unit = {
+ listenerBus.post(SparkListenerSpeculativeTaskSubmitted(task.stageId))
+ }
+
private[scheduler] def handleTaskSetFailed(
taskSet: TaskSet,
reason: String,
@@ -1778,6 +1789,9 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
case BeginEvent(task, taskInfo) =>
dagScheduler.handleBeginEvent(task, taskInfo)
+ case SpeculativeTaskSubmitted(task) =>
+ dagScheduler.handleSpeculativeTaskSubmitted(task)
+
case GettingResultEvent(taskInfo) =>
dagScheduler.handleGetTaskResult(taskInfo)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index 3f8d5639a2b90..54ab8f8b3e1d8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -94,3 +94,7 @@ case class TaskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Thr
extends DAGSchedulerEvent
private[scheduler] case object ResubmitFailedStages extends DAGSchedulerEvent
+
+private[scheduler]
+case class SpeculativeTaskSubmitted(task: Task[_]) extends DAGSchedulerEvent
+
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
index 05f650fbf5df9..1b44d0aee3195 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler
import scala.collection.mutable.HashSet
-import org.apache.spark.{MapOutputTrackerMaster, ShuffleDependency, SparkEnv}
+import org.apache.spark.{MapOutputTrackerMaster, ShuffleDependency}
import org.apache.spark.rdd.RDD
import org.apache.spark.util.CallSite
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 59f89a82a1da8..b76e560669d59 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -52,6 +52,9 @@ case class SparkListenerTaskStart(stageId: Int, stageAttemptId: Int, taskInfo: T
@DeveloperApi
case class SparkListenerTaskGettingResult(taskInfo: TaskInfo) extends SparkListenerEvent
+@DeveloperApi
+case class SparkListenerSpeculativeTaskSubmitted(stageId: Int) extends SparkListenerEvent
+
@DeveloperApi
case class SparkListenerTaskEnd(
stageId: Int,
@@ -290,6 +293,11 @@ private[spark] trait SparkListenerInterface {
*/
def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit
+ /**
+ * Called when a speculative task is submitted.
+ */
+ def onSpeculativeTaskSubmitted(speculativeTask: SparkListenerSpeculativeTaskSubmitted): Unit
+
/**
* Called when other events like SQL-specific events are posted.
*/
@@ -354,5 +362,8 @@ abstract class SparkListener extends SparkListenerInterface {
override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = { }
+ override def onSpeculativeTaskSubmitted(
+ speculativeTask: SparkListenerSpeculativeTaskSubmitted): Unit = { }
+
override def onOtherEvent(event: SparkListenerEvent): Unit = { }
}
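
A minimal sketch of a listener that consumes the new event (the class name and counter are illustrative only); register it with `SparkContext#addSparkListener`:

```scala
import java.util.concurrent.atomic.AtomicLong

import org.apache.spark.scheduler.{SparkListener, SparkListenerSpeculativeTaskSubmitted}

// Counts speculative task submissions as they are posted by the DAGScheduler.
class SpeculationCounter extends SparkListener {
  val speculativeTasks = new AtomicLong(0)

  override def onSpeculativeTaskSubmitted(
      speculativeTask: SparkListenerSpeculativeTaskSubmitted): Unit = {
    speculativeTasks.incrementAndGet()
  }
}
```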
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
index 3b0d3b1b150fe..056c0cbded435 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
@@ -71,6 +71,8 @@ private[spark] trait SparkListenerBus
listener.onNodeUnblacklisted(nodeUnblacklisted)
case blockUpdated: SparkListenerBlockUpdated =>
listener.onBlockUpdated(blockUpdated)
+ case speculativeTaskSubmitted: SparkListenerSpeculativeTaskSubmitted =>
+ listener.onSpeculativeTaskSubmitted(speculativeTaskSubmitted)
case _ => listener.onOtherEvent(event)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 737b383631148..0c11806b3981b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -32,7 +32,6 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.config
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.scheduler.TaskLocality.TaskLocality
-import org.apache.spark.scheduler.local.LocalSchedulerBackend
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.{AccumulatorV2, ThreadUtils, Utils}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index b1dfc2944d2df..1146b3c8d5994 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -966,6 +966,7 @@ private[spark] class TaskSetManager(
"Marking task %d in stage %s (on %s) as speculatable because it ran more than %.0f ms"
.format(index, taskSet.id, info.host, threshold))
speculatableTasks += index
+ sched.dagScheduler.speculativeTaskSubmitted(tasks(index))
foundTasks = true
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index 6de59a3ea5c23..d2b7d181f6de8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -32,7 +32,8 @@ private[spark] object CoarseGrainedClusterMessages {
case class SparkAppConfig(
sparkProperties: Seq[(String, String)],
- ioEncryptionKey: Option[Array[Byte]])
+ ioEncryptionKey: Option[Array[Byte]],
+ hadoopDelegationCreds: Option[Array[Byte]])
extends CoarseGrainedClusterMessage
case object RetrieveLastAllocatedExecutorId extends CoarseGrainedClusterMessage
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index c6ccb3f06c780..e97c01f9b875e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -24,7 +24,11 @@ import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.concurrent.Future
+import org.apache.hadoop.security.UserGroupInformation
+
import org.apache.spark.{ExecutorAllocationClient, SparkEnv, SparkException, TaskState}
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.deploy.security.HadoopDelegationTokenManager
import org.apache.spark.internal.Logging
import org.apache.spark.rpc._
import org.apache.spark.scheduler._
@@ -42,8 +46,8 @@ import org.apache.spark.util.{RpcUtils, SerializableBuffer, ThreadUtils, Utils}
*/
private[spark]
class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: RpcEnv)
- extends ExecutorAllocationClient with SchedulerBackend with Logging
-{
+ extends ExecutorAllocationClient with SchedulerBackend with Logging {
+
// Use an atomic variable to track total number of cores in the cluster for simplicity and speed
protected val totalCoreCount = new AtomicInteger(0)
// Total number of executors that are currently registered
@@ -95,6 +99,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
// The num of current max ExecutorId used to re-register appMaster
@volatile protected var currentExecutorIdCounter = 0
+ // Hadoop token manager used by some subclasses (e.g. Mesos)
+ def hadoopDelegationTokenManager: Option[HadoopDelegationTokenManager] = None
+
+ // Hadoop delegation tokens to be sent to the executors.
+ val hadoopDelegationCreds: Option[Array[Byte]] = getHadoopDelegationCreds()
+
class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)])
extends ThreadSafeRpcEndpoint with Logging {
@@ -223,8 +233,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
context.reply(true)
case RetrieveSparkAppConfig(_) =>
- val reply = SparkAppConfig(sparkProperties,
- SparkEnv.get.securityManager.getIOEncryptionKey())
+ val reply = SparkAppConfig(
+ sparkProperties,
+ SparkEnv.get.securityManager.getIOEncryptionKey(),
+ hadoopDelegationCreds)
context.reply(reply)
}
@@ -675,6 +687,19 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
driverEndpoint.send(KillExecutorsOnHost(host))
true
}
+
+ protected def getHadoopDelegationCreds(): Option[Array[Byte]] = {
+ if (UserGroupInformation.isSecurityEnabled && hadoopDelegationTokenManager.isDefined) {
+ hadoopDelegationTokenManager.map { manager =>
+ val creds = UserGroupInformation.getCurrentUser.getCredentials
+ val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
+ manager.obtainDelegationTokens(hadoopConf, creds)
+ SparkHadoopUtil.get.serialize(creds)
+ }
+ } else {
+ None
+ }
+ }
}
private[spark] object CoarseGrainedSchedulerBackend {
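
A hypothetical subclass, sketched only to show how a cluster-manager backend opts in (it assumes a location under `org.apache.spark`, like the real backends): by returning a token manager, the base class obtains and serializes delegation tokens that then ride along in `SparkAppConfig` to every executor.

```scala
import org.apache.spark.SparkConf
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.deploy.security.HadoopDelegationTokenManager
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend

// Illustrative backend: overriding hadoopDelegationTokenManager is all that is needed
// for getHadoopDelegationCreds() to populate hadoopDelegationCreds at construction time.
private[spark] class TokenAwareSchedulerBackend(
    scheduler: TaskSchedulerImpl,
    rpcEnv: RpcEnv,
    sparkConf: SparkConf)
  extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) {

  override def hadoopDelegationTokenManager: Option[HadoopDelegationTokenManager] = {
    val hadoopConf = SparkHadoopUtil.get.newConfiguration(sparkConf)
    Some(new HadoopDelegationTokenManager(sparkConf, hadoopConf))
  }
}
```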
diff --git a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
index 78dabb42ac9d2..00621976b77f4 100644
--- a/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
+++ b/core/src/main/scala/org/apache/spark/security/CryptoStreamUtils.scala
@@ -16,7 +16,7 @@
*/
package org.apache.spark.security
-import java.io.{EOFException, InputStream, OutputStream}
+import java.io.{InputStream, OutputStream}
import java.nio.ByteBuffer
import java.nio.channels.{ReadableByteChannel, WritableByteChannel}
import java.util.Properties
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala
index 56028710ecc66..4a4ed954d689e 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala
@@ -47,6 +47,7 @@ private[v1] class AllStagesResource(ui: SparkUI) {
listener.stageIdToData.get((stageInfo.stageId, stageInfo.attemptId))
}
} yield {
+ stageUiData.lastUpdateTime = ui.lastUpdateTime
AllStagesResource.stageUiToStageData(status, stageInfo, stageUiData, includeDetails = false)
}
}
@@ -69,7 +70,8 @@ private[v1] object AllStagesResource {
}
val taskData = if (includeDetails) {
- Some(stageUiData.taskData.map { case (k, v) => k -> convertTaskData(v) } )
+ Some(stageUiData.taskData.map { case (k, v) =>
+ k -> convertTaskData(v, stageUiData.lastUpdateTime) })
} else {
None
}
@@ -136,13 +138,13 @@ private[v1] object AllStagesResource {
}
}
- def convertTaskData(uiData: TaskUIData): TaskData = {
+ def convertTaskData(uiData: TaskUIData, lastUpdateTime: Option[Long]): TaskData = {
new TaskData(
taskId = uiData.taskInfo.taskId,
index = uiData.taskInfo.index,
attempt = uiData.taskInfo.attemptNumber,
launchTime = new Date(uiData.taskInfo.launchTime),
- duration = uiData.taskDuration,
+ duration = uiData.taskDuration(lastUpdateTime),
executorId = uiData.taskInfo.executorId,
host = uiData.taskInfo.host,
status = uiData.taskInfo.status,
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala
index 3e6d2942d0fbb..f15073bccced2 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneStageResource.scala
@@ -35,6 +35,7 @@ private[v1] class OneStageResource(ui: SparkUI) {
def stageData(@PathParam("stageId") stageId: Int): Seq[StageData] = {
withStage(stageId) { stageAttempts =>
stageAttempts.map { stage =>
+ stage.ui.lastUpdateTime = ui.lastUpdateTime
AllStagesResource.stageUiToStageData(stage.status, stage.info, stage.ui,
includeDetails = true)
}
@@ -47,6 +48,7 @@ private[v1] class OneStageResource(ui: SparkUI) {
@PathParam("stageId") stageId: Int,
@PathParam("stageAttemptId") stageAttemptId: Int): StageData = {
withStageAttempt(stageId, stageAttemptId) { stage =>
+ stage.ui.lastUpdateTime = ui.lastUpdateTime
AllStagesResource.stageUiToStageData(stage.status, stage.info, stage.ui,
includeDetails = true)
}
@@ -81,7 +83,8 @@ private[v1] class OneStageResource(ui: SparkUI) {
@DefaultValue("20") @QueryParam("length") length: Int,
@DefaultValue("ID") @QueryParam("sortBy") sortBy: TaskSorting): Seq[TaskData] = {
withStageAttempt(stageId, stageAttemptId) { stage =>
- val tasks = stage.ui.taskData.values.map{AllStagesResource.convertTaskData}.toIndexedSeq
+ val tasks = stage.ui.taskData.values
+ .map{ AllStagesResource.convertTaskData(_, ui.lastUpdateTime)}.toIndexedSeq
.sorted(OneStageResource.ordering(sortBy))
tasks.slice(offset, offset + length)
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index b9a0b0633825d..d588ed456cb90 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -994,11 +994,16 @@ private[spark] class BlockManager(
logWarning(s"Putting block $blockId failed")
}
res
+ } catch {
+ // Since removeBlockInternal may throw an exception,
+ // log the exception first so the root cause is visible.
+ case NonFatal(e) =>
+ logWarning(s"Putting block $blockId failed due to exception $e.")
+ throw e
} finally {
// This cleanup is performed in a finally block rather than a `catch` to avoid having to
// catch and properly re-throw InterruptedException.
if (exceptionWasThrown) {
- logWarning(s"Putting block $blockId failed due to an exception")
// If an exception was thrown then it's possible that the code in `putBody` has already
// notified the master about the availability of this block, so we need to send an update
// to remove this block location.
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala
index 1ea0d378cbe87..3d3806126676c 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerManagedBuffer.scala
@@ -22,7 +22,6 @@ import java.nio.ByteBuffer
import java.util.concurrent.atomic.AtomicInteger
import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.util.io.ChunkedByteBuffer
/**
* This [[ManagedBuffer]] wraps a [[BlockData]] instance retrieved from the [[BlockManager]]
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index 95d70479ef017..3579acf8d83d9 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -21,21 +21,19 @@ import java.io._
import java.nio.ByteBuffer
import java.nio.channels.{Channels, ReadableByteChannel, WritableByteChannel}
import java.nio.channels.FileChannel.MapMode
-import java.nio.charset.StandardCharsets.UTF_8
import java.util.concurrent.ConcurrentHashMap
import scala.collection.mutable.ListBuffer
-import com.google.common.io.{ByteStreams, Closeables, Files}
+import com.google.common.io.Closeables
import io.netty.channel.{DefaultFileRegion, FileRegion}
import io.netty.util.AbstractReferenceCounted
import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.internal.Logging
-import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.security.CryptoStreamUtils
-import org.apache.spark.util.{ByteBufferInputStream, Utils}
+import org.apache.spark.util.Utils
import org.apache.spark.util.io.ChunkedByteBuffer
/**
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index 3ddaac78f0257..5ee04dad6ed4d 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -21,7 +21,6 @@ import java.net.{URI, URL}
import javax.servlet.DispatcherType
import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
-import scala.collection.mutable.ArrayBuffer
import scala.language.implicitConversions
import scala.xml.Node
diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
index 589f811145519..f3fcf2778d39e 100644
--- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
@@ -50,6 +50,7 @@ private[spark] class SparkUI private (
val operationGraphListener: RDDOperationGraphListener,
var appName: String,
val basePath: String,
+ val lastUpdateTime: Option[Long] = None,
val startTime: Long)
extends WebUI(securityManager, securityManager.getSSLOptions("ui"), SparkUI.getUIPort(conf),
conf, basePath, "SparkUI")
@@ -176,9 +177,11 @@ private[spark] object SparkUI {
securityManager: SecurityManager,
appName: String,
basePath: String,
+ lastUpdateTime: Option[Long],
startTime: Long): SparkUI = {
val sparkUI = create(
- None, conf, listenerBus, securityManager, appName, basePath, startTime = startTime)
+ None, conf, listenerBus, securityManager, appName, basePath,
+ lastUpdateTime = lastUpdateTime, startTime = startTime)
val listenerFactories = ServiceLoader.load(classOf[SparkHistoryListenerFactory],
Utils.getContextOrSparkClassLoader).asScala
@@ -204,6 +207,7 @@ private[spark] object SparkUI {
appName: String,
basePath: String = "",
jobProgressListener: Option[JobProgressListener] = None,
+ lastUpdateTime: Option[Long] = None,
startTime: Long): SparkUI = {
val _jobProgressListener: JobProgressListener = jobProgressListener.getOrElse {
@@ -226,6 +230,6 @@ private[spark] object SparkUI {
new SparkUI(sc, conf, securityManager, environmentListener, storageStatusListener,
executorsListener, _jobProgressListener, storageListener, operationGraphListener,
- appName, basePath, startTime)
+ appName, basePath, lastUpdateTime, startTime)
}
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 633e740b9c9bd..4d80308eb0a6d 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -299,6 +299,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
stageData.hasShuffleRead,
stageData.hasShuffleWrite,
stageData.hasBytesSpilled,
+ parent.lastUpdateTime,
currentTime,
pageSize = taskPageSize,
sortColumn = taskSortColumn,
@@ -863,6 +864,7 @@ private[ui] class TaskDataSource(
hasShuffleRead: Boolean,
hasShuffleWrite: Boolean,
hasBytesSpilled: Boolean,
+ lastUpdateTime: Option[Long],
currentTime: Long,
pageSize: Int,
sortColumn: String,
@@ -889,8 +891,9 @@ private[ui] class TaskDataSource(
private def taskRow(taskData: TaskUIData): TaskTableRowData = {
val info = taskData.taskInfo
val metrics = taskData.metrics
- val duration = taskData.taskDuration.getOrElse(1L)
- val formatDuration = taskData.taskDuration.map(d => UIUtils.formatDuration(d)).getOrElse("")
+ val duration = taskData.taskDuration(lastUpdateTime).getOrElse(1L)
+ val formatDuration =
+ taskData.taskDuration(lastUpdateTime).map(d => UIUtils.formatDuration(d)).getOrElse("")
val schedulerDelay = metrics.map(getSchedulerDelay(info, _, currentTime)).getOrElse(0L)
val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L)
val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L)
@@ -1154,6 +1157,7 @@ private[ui] class TaskPagedTable(
hasShuffleRead: Boolean,
hasShuffleWrite: Boolean,
hasBytesSpilled: Boolean,
+ lastUpdateTime: Option[Long],
currentTime: Long,
pageSize: Int,
sortColumn: String,
@@ -1179,6 +1183,7 @@ private[ui] class TaskPagedTable(
hasShuffleRead,
hasShuffleWrite,
hasBytesSpilled,
+ lastUpdateTime,
currentTime,
pageSize,
sortColumn,
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala
index 799d769626395..0787ea6625903 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala
@@ -30,6 +30,7 @@ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages"
val progressListener = parent.jobProgressListener
val operationGraphListener = parent.operationGraphListener
val executorsListener = parent.executorsListener
+ val lastUpdateTime = parent.lastUpdateTime
attachPage(new AllStagesPage(this))
attachPage(new StagePage(this))
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
index 9448baac096dc..d9c87f69d8a54 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
@@ -97,6 +97,7 @@ private[spark] object UIData {
var memoryBytesSpilled: Long = _
var diskBytesSpilled: Long = _
var isBlacklisted: Int = _
+ var lastUpdateTime: Option[Long] = None
var schedulingPool: String = ""
var description: Option[String] = None
@@ -133,9 +134,9 @@ private[spark] object UIData {
_metrics = metrics.map(TaskMetricsUIData.fromTaskMetrics)
}
- def taskDuration: Option[Long] = {
+ def taskDuration(lastUpdateTime: Option[Long] = None): Option[Long] = {
if (taskInfo.status == "RUNNING") {
- Some(_taskInfo.timeRunning(System.currentTimeMillis))
+ Some(_taskInfo.timeRunning(lastUpdateTime.getOrElse(System.currentTimeMillis)))
} else {
_metrics.map(_.executorRunTime)
}
diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala
index 43bfe0aacf35b..bb763248cd7e0 100644
--- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala
+++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala
@@ -26,7 +26,7 @@ import org.apache.commons.lang3.StringEscapeUtils
import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.StageInfo
-import org.apache.spark.storage.{RDDInfo, StorageLevel}
+import org.apache.spark.storage.StorageLevel
/**
* A representation of a generic cluster graph used for storing information on RDD operations.
diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
index 603c23abb6895..f4a736d6d439a 100644
--- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
+++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
@@ -23,12 +23,9 @@ import java.util.{ArrayList, Collections}
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicLong
-import scala.collection.JavaConverters._
-
import org.apache.spark.{InternalAccumulator, SparkContext, TaskContext}
import org.apache.spark.scheduler.AccumulableInfo
-
private[spark] case class AccumulatorMetadata(
id: Long,
name: Option[String],
diff --git a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala
index 50dc948e6c410..a938cb07724c7 100644
--- a/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala
+++ b/core/src/main/scala/org/apache/spark/util/ByteBufferInputStream.scala
@@ -20,8 +20,6 @@ package org.apache.spark.util
import java.io.InputStream
import java.nio.ByteBuffer
-import org.apache.spark.storage.StorageUtils
-
/**
* Reads data from a ByteBuffer.
*/
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index d661293e529f9..26ff00cf387c8 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -449,7 +449,7 @@ private[spark] object Utils extends Logging {
securityMgr: SecurityManager,
hadoopConf: Configuration,
timestamp: Long,
- useCache: Boolean) {
+ useCache: Boolean): File = {
val fileName = decodeFileNameInURI(new URI(url))
val targetFile = new File(targetDir, fileName)
val fetchCacheEnabled = conf.getBoolean("spark.files.useFetchCache", defaultValue = true)
@@ -498,6 +498,8 @@ private[spark] object Utils extends Logging {
if (isWindows) {
FileUtil.chmod(targetFile.getAbsolutePath, "u+r")
}
+
+ targetFile
}
/**
@@ -637,13 +639,13 @@ private[spark] object Utils extends Logging {
* Throws SparkException if the target file already exists and has different contents than
* the requested file.
*/
- private def doFetchFile(
+ def doFetchFile(
url: String,
targetDir: File,
filename: String,
conf: SparkConf,
securityMgr: SecurityManager,
- hadoopConf: Configuration) {
+ hadoopConf: Configuration): File = {
val targetFile = new File(targetDir, filename)
val uri = new URI(url)
val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false)
@@ -687,6 +689,8 @@ private[spark] object Utils extends Logging {
fetchHcfsFile(path, targetDir, fs, conf, hadoopConf, fileOverwrite,
filename = Some(filename))
}
+
+ targetFile
}
/**
@@ -1182,16 +1186,17 @@ private[spark] object Utils extends Logging {
val second = 1000
val minute = 60 * second
val hour = 60 * minute
+ val locale = Locale.US
ms match {
case t if t < second =>
- "%d ms".format(t)
+ "%d ms".formatLocal(locale, t)
case t if t < minute =>
- "%.1f s".format(t.toFloat / second)
+ "%.1f s".formatLocal(locale, t.toFloat / second)
case t if t < hour =>
- "%.1f m".format(t.toFloat / minute)
+ "%.1f m".formatLocal(locale, t.toFloat / minute)
case t =>
- "%.2f h".format(t.toFloat / hour)
+ "%.2f h".formatLocal(locale, t.toFloat / hour)
}
}
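A standalone sketch of the idea in the hunk above, assuming nothing beyond the standard library: pinning duration formatting to Locale.US keeps the decimal separator stable no matter what the JVM's default locale is (under Locale.FRANCE, for example, "%.1f".format(1.5f) renders as "1,5"). The values below are illustrative.

import java.util.Locale

object DurationFormatSketch {
  // Locale-pinned formatting so the output is parseable regardless of the JVM default.
  def seconds(ms: Long): String = "%.1f s".formatLocal(Locale.US, ms.toFloat / 1000)

  def main(args: Array[String]): Unit = {
    println(seconds(1500)) // "1.5 s"
  }
}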
@@ -2594,18 +2599,23 @@ private[spark] object Utils extends Logging {
}
/**
- * In YARN mode this method returns a union of the jar files pointed by "spark.jars" and the
- * "spark.yarn.dist.jars" properties, while in other modes it returns the jar files pointed by
- * only the "spark.jars" property.
+ * Return the jar files pointed by the "spark.jars" property. Spark internally will distribute
+ * these jars through file server. In the YARN mode, it will return an empty list, since YARN
+ * has its own mechanism to distribute jars.
*/
- def getUserJars(conf: SparkConf, isShell: Boolean = false): Seq[String] = {
+ def getUserJars(conf: SparkConf): Seq[String] = {
val sparkJars = conf.getOption("spark.jars")
- if (conf.get("spark.master") == "yarn" && isShell) {
- val yarnJars = conf.getOption("spark.yarn.dist.jars")
- unionFileLists(sparkJars, yarnJars).toSeq
- } else {
- sparkJars.map(_.split(",")).map(_.filter(_.nonEmpty)).toSeq.flatten
- }
+ sparkJars.map(_.split(",")).map(_.filter(_.nonEmpty)).toSeq.flatten
+ }
+
+ /**
+ * Return the local jar files which will be added to REPL's classpath. These jar files are
+ * specified by --jars (spark.jars) or --packages, remote jars will be downloaded to local by
+ * SparkSubmit at first.
+ */
+ def getLocalUserJarsForShell(conf: SparkConf): Seq[String] = {
+ val localJars = conf.getOption("spark.repl.local.jars")
+ localJars.map(_.split(",")).map(_.filter(_.nonEmpty)).toSeq.flatten
}
private[spark] val REDACTION_REPLACEMENT_TEXT = "*********(redacted)"
@@ -2623,9 +2633,12 @@ private[spark] object Utils extends Logging {
* Redact the sensitive information in the given string.
*/
def redact(conf: SparkConf, text: String): String = {
- if (text == null || text.isEmpty || !conf.contains(STRING_REDACTION_PATTERN)) return text
- val regex = conf.get(STRING_REDACTION_PATTERN).get
- regex.replaceAllIn(text, REDACTION_REPLACEMENT_TEXT)
+ if (text == null || text.isEmpty || conf == null || !conf.contains(STRING_REDACTION_PATTERN)) {
+ text
+ } else {
+ val regex = conf.get(STRING_REDACTION_PATTERN).get
+ regex.replaceAllIn(text, REDACTION_REPLACEMENT_TEXT)
+ }
}
private def redact(redactionPattern: Regex, kvs: Seq[(String, String)]): Seq[(String, String)] = {
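A hedged illustration of the split described in the new Scaladoc above, not the actual SparkSubmit flow: "spark.jars" feeds Spark's own file server for distribution, while the REPL classpath is built from the already-downloaded local copies recorded under "spark.repl.local.jars". The conf values and helper below are made up for the sketch.

import org.apache.spark.SparkConf

object UserJarsSketch {
  // Parse a comma-separated conf value the same way the methods above do.
  private def csv(conf: SparkConf, key: String): Seq[String] =
    conf.getOption(key).map(_.split(",").filter(_.nonEmpty).toSeq).getOrElse(Nil)

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false)
      .set("spark.jars", "s3a://bucket/dep1.jar,s3a://bucket/dep2.jar")
      .set("spark.repl.local.jars", "file:/tmp/dep1.jar,file:/tmp/dep2.jar")

    val distributedJars = csv(conf, "spark.jars")            // remote URIs, served by Spark
    val replClasspath   = csv(conf, "spark.repl.local.jars") // local copies for the shell
    println(distributedJars)
    println(replClasspath)
  }
}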
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index a08563562b874..6f5b5bb3652de 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -32,7 +32,6 @@ import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.internal.Logging
-import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.serializer.{DeserializationStream, Serializer, SerializerManager}
import org.apache.spark.storage.{BlockId, BlockManager}
import org.apache.spark.util.CompletionIterator
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
index f48bfd5c25f77..c28570fb24560 100644
--- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
@@ -24,6 +24,8 @@ import java.nio.channels.WritableByteChannel
import com.google.common.primitives.UnsignedBytes
import io.netty.buffer.{ByteBuf, Unpooled}
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.config
import org.apache.spark.network.util.ByteArrayWritableChannel
import org.apache.spark.storage.StorageUtils
@@ -40,6 +42,11 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
require(chunks != null, "chunks must not be null")
require(chunks.forall(_.position() == 0), "chunks' positions must be 0")
+ // Chunk size in bytes
+ private val bufferWriteChunkSize =
+ Option(SparkEnv.get).map(_.conf.get(config.BUFFER_WRITE_CHUNK_SIZE))
+ .getOrElse(config.BUFFER_WRITE_CHUNK_SIZE.defaultValue.get).toInt
+
private[this] var disposed: Boolean = false
/**
@@ -56,7 +63,9 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
*/
def writeFully(channel: WritableByteChannel): Unit = {
for (bytes <- getChunks()) {
- while (bytes.remaining > 0) {
+ while (bytes.remaining() > 0) {
+ val ioSize = Math.min(bytes.remaining(), bufferWriteChunkSize)
+ bytes.limit(bytes.position + ioSize)
channel.write(bytes)
}
}
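A standalone sketch of the bounded-write idea introduced above, not the Spark implementation: each call to the channel sees at most maxChunk bytes. As a design note, this sketch also restores the buffer's original limit after every write so remaining() stays accurate when a single buffer is larger than one chunk.

import java.nio.ByteBuffer
import java.nio.channels.WritableByteChannel

object ChunkedWriteSketch {
  def writeInChunks(buf: ByteBuffer, channel: WritableByteChannel, maxChunk: Int): Unit = {
    val originalLimit = buf.limit()
    while (buf.hasRemaining) {
      val ioSize = math.min(buf.remaining(), maxChunk)
      buf.limit(buf.position() + ioSize) // cap this write at ioSize bytes
      channel.write(buf)                 // advances position by however much was written
      buf.limit(originalLimit)           // restore so the next iteration sees the rest
    }
  }
}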
diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
index f53bc0b02bbfa..46b0516e36141 100644
--- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
+++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java
@@ -54,6 +54,7 @@ public void encodePageNumberAndOffsetOffHeap() {
final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, offset);
Assert.assertEquals(null, manager.getPage(encodedAddress));
Assert.assertEquals(offset, manager.getOffsetInPage(encodedAddress));
+ manager.freePage(dataPage, c);
}
@Test
diff --git a/core/src/test/resources/fairscheduler-with-valid-data.xml b/core/src/test/resources/fairscheduler-with-valid-data.xml
new file mode 100644
index 0000000000000..3d882331835ca
--- /dev/null
+++ b/core/src/test/resources/fairscheduler-with-valid-data.xml
@@ -0,0 +1,35 @@
+<?xml version="1.0"?>
+<!--
+  Licensed to the Apache Software Foundation (ASF) under one or more
+  contributor license agreements. See the NOTICE file distributed with
+  this work for additional information regarding copyright ownership.
+  The ASF licenses this file to You under the Apache License, Version 2.0
+  (the "License"); you may not use this file except in compliance with
+  the License. You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+-->
+
+<allocations>
+  <pool name="pool1">
+    <minShare>3</minShare>
+    <weight>1</weight>
+    <schedulingMode>FIFO</schedulingMode>
+  </pool>
+  <pool name="pool2">
+    <minShare>4</minShare>
+    <weight>2</weight>
+    <schedulingMode>FAIR</schedulingMode>
+  </pool>
+  <pool name="pool3">
+    <minShare>2</minShare>
+    <weight>3</weight>
+    <schedulingMode>FAIR</schedulingMode>
+  </pool>
+</allocations>
\ No newline at end of file
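A hedged driver-side sketch of how a file like the one above is wired in; the path and app name are illustrative, and the pool values match what PoolSuite verifies further down in this patch.

import org.apache.spark.{SparkConf, SparkContext}

object FairPoolsSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("fair-pools-sketch")
      .set("spark.scheduler.mode", "FAIR")
      .set("spark.scheduler.allocation.file", "/path/to/fairscheduler-with-valid-data.xml")
    val sc = new SparkContext(conf)
    // Jobs submitted from this thread are scheduled in "pool2" (minShare=4, weight=2, FAIR).
    sc.setLocalProperty("spark.scheduler.pool", "pool2")
    sc.parallelize(1 to 100).count()
    sc.stop()
  }
}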
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 2071f90cfeed5..066c16dd3012c 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark
import org.scalatest.Matchers
-import org.scalatest.concurrent.Timeouts._
+import org.scalatest.concurrent.TimeLimits._
import org.scalatest.time.{Millis, Span}
import org.apache.spark.security.EncryptionFunSuite
diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala
index 454b7e607a51b..be80d278fcea8 100644
--- a/core/src/test/scala/org/apache/spark/DriverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala
@@ -19,13 +19,13 @@ package org.apache.spark
import java.io.File
-import org.scalatest.concurrent.Timeouts
+import org.scalatest.concurrent.TimeLimits
import org.scalatest.prop.TableDrivenPropertyChecks._
import org.scalatest.time.SpanSugar._
import org.apache.spark.util.Utils
-class DriverSuite extends SparkFunSuite with Timeouts {
+class DriverSuite extends SparkFunSuite with TimeLimits {
ignore("driver should exit after finishing without cleanup (SPARK-530)") {
val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
index b9ce71a0c5254..7da4bae0ab7eb 100644
--- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
@@ -188,6 +188,40 @@ class ExecutorAllocationManagerSuite
assert(numExecutorsTarget(manager) === 10)
}
+ test("add executors when speculative tasks added") {
+ sc = createSparkContext(0, 10, 0)
+ val manager = sc.executorAllocationManager.get
+
+ // Verify that we're capped at number of tasks including the speculative ones in the stage
+ sc.listenerBus.postToAll(SparkListenerSpeculativeTaskSubmitted(1))
+ assert(numExecutorsTarget(manager) === 0)
+ assert(numExecutorsToAdd(manager) === 1)
+ assert(addExecutors(manager) === 1)
+ sc.listenerBus.postToAll(SparkListenerSpeculativeTaskSubmitted(1))
+ sc.listenerBus.postToAll(SparkListenerSpeculativeTaskSubmitted(1))
+ sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, 2)))
+ assert(numExecutorsTarget(manager) === 1)
+ assert(numExecutorsToAdd(manager) === 2)
+ assert(addExecutors(manager) === 2)
+ assert(numExecutorsTarget(manager) === 3)
+ assert(numExecutorsToAdd(manager) === 4)
+ assert(addExecutors(manager) === 2)
+ assert(numExecutorsTarget(manager) === 5)
+ assert(numExecutorsToAdd(manager) === 1)
+
+ // Verify that running a task doesn't affect the target
+ sc.listenerBus.postToAll(SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-1")))
+ assert(numExecutorsTarget(manager) === 5)
+ assert(addExecutors(manager) === 0)
+ assert(numExecutorsToAdd(manager) === 1)
+
+ // Verify that running a speculative task doesn't affect the target
+ sc.listenerBus.postToAll(SparkListenerTaskStart(1, 0, createTaskInfo(0, 0, "executor-2", true)))
+ assert(numExecutorsTarget(manager) === 5)
+ assert(addExecutors(manager) === 0)
+ assert(numExecutorsToAdd(manager) === 1)
+ }
+
test("cancel pending executors when no longer needed") {
sc = createSparkContext(0, 10, 0)
val manager = sc.executorAllocationManager.get
@@ -1031,10 +1065,15 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester {
taskLocalityPreferences = taskLocalityPreferences)
}
- private def createTaskInfo(taskId: Int, taskIndex: Int, executorId: String): TaskInfo = {
- new TaskInfo(taskId, taskIndex, 0, 0, executorId, "", TaskLocality.ANY, speculative = false)
+ private def createTaskInfo(
+ taskId: Int,
+ taskIndex: Int,
+ executorId: String,
+ speculative: Boolean = false): TaskInfo = {
+ new TaskInfo(taskId, taskIndex, 0, 0, executorId, "", TaskLocality.ANY, speculative)
}
+
/* ------------------------------------------------------- *
| Helper methods for accessing private methods and fields |
* ------------------------------------------------------- */
@@ -1061,6 +1100,7 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester {
private val _onExecutorBusy = PrivateMethod[Unit]('onExecutorBusy)
private val _localityAwareTasks = PrivateMethod[Int]('localityAwareTasks)
private val _hostToLocalTaskCount = PrivateMethod[Map[String, Int]]('hostToLocalTaskCount)
+ private val _onSpeculativeTaskSubmitted = PrivateMethod[Unit]('onSpeculativeTaskSubmitted)
private def numExecutorsToAdd(manager: ExecutorAllocationManager): Int = {
manager invokePrivate _numExecutorsToAdd()
@@ -1136,6 +1176,10 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester {
manager invokePrivate _onExecutorBusy(id)
}
+ private def onSpeculativeTaskSubmitted(manager: ExecutorAllocationManager, id: String) : Unit = {
+ manager invokePrivate _onSpeculativeTaskSubmitted(id)
+ }
+
private def localityAwareTasks(manager: ExecutorAllocationManager): Int = {
manager invokePrivate _localityAwareTasks()
}
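A hedged model, not ExecutorAllocationManager itself, of the ramp-up the new speculative-task test exercises: the step size doubles while full steps succeed and resets once the target is capped by the pending-plus-speculative task count. The function and its names are invented for illustration; the printed values mirror the assertions above.

object RampUpSketch {
  // Returns (newTarget, nextStep) given the current target, the current step size,
  // and the cap implied by pending + speculative tasks.
  def next(currentTarget: Int, step: Int, taskCap: Int): (Int, Int) = {
    val newTarget = math.min(currentTarget + step, taskCap)
    val added = newTarget - currentTarget
    val nextStep = if (added == step) step * 2 else 1
    (newTarget, nextStep)
  }

  def main(args: Array[String]): Unit = {
    // Targets go 1 -> 3 -> 5 while the step goes 2 -> 4 -> 1, as in the test.
    println(next(1, 2, 5)) // (3, 4)
    println(next(3, 4, 5)) // (5, 1)
  }
}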
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
index 890e93d764f90..0ed5f26863dad 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -600,6 +600,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
val fs = new DebugFilesystem()
fs.initialize(new URI("file:///"), new Configuration())
val file = File.createTempFile("SPARK19446", "temp")
+ file.deleteOnExit()
Files.write(Array.ofDim[Byte](1000), file)
val path = new Path("file:///" + file.getCanonicalPath)
val stream = fs.open(path)
diff --git a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala
index 09e21646ee744..bc3f58cf2a35d 100644
--- a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala
+++ b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark
-import org.scalatest.concurrent.Timeouts._
+import org.scalatest.concurrent.TimeLimits._
import org.scalatest.time.{Millis, Span}
class UnpersistSuite extends SparkFunSuite with LocalSparkContext {
diff --git a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala
index 5e0bf6d438dc8..32dd3ecc2f027 100644
--- a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala
@@ -137,9 +137,10 @@ class RPackageUtilsSuite
IvyTestUtils.withRepository(main, None, None) { repo =>
val jar = IvyTestUtils.packJar(new File(new URI(repo)), dep1, Nil,
useIvyLayout = false, withR = false, None)
- val jarFile = new JarFile(jar)
- assert(jarFile.getManifest == null, "jar file should have null manifest")
- assert(!RPackageUtils.checkManifestForR(jarFile), "null manifest should return false")
+ Utils.tryWithResource(new JarFile(jar)) { jarFile =>
+ assert(jarFile.getManifest == null, "jar file should have null manifest")
+ assert(!RPackageUtils.checkManifestForR(jarFile), "null manifest should return false")
+ }
}
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 08ba41f50a2b9..4d69ce844d2ea 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -29,9 +29,9 @@ import scala.io.Source
import com.google.common.io.ByteStreams
import org.apache.commons.io.FileUtils
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.{FileStatus, FSDataInputStream, Path}
import org.scalatest.{BeforeAndAfterEach, Matchers}
-import org.scalatest.concurrent.Timeouts
+import org.scalatest.concurrent.TimeLimits
import org.scalatest.time.SpanSugar._
import org.apache.spark._
@@ -97,7 +97,7 @@ class SparkSubmitSuite
with Matchers
with BeforeAndAfterEach
with ResetSystemProperties
- with Timeouts
+ with TimeLimits
with TestPrematureExit {
override def beforeEach() {
@@ -762,7 +762,7 @@ class SparkSubmitSuite
(Set(jar1.toURI.toString, jar2.toURI.toString))
sysProps("spark.yarn.dist.files").split(",").toSet should be
(Set(file1.toURI.toString, file2.toURI.toString))
- sysProps("spark.submit.pyFiles").split(",").toSet should be
+ sysProps("spark.yarn.dist.pyFiles").split(",").toSet should be
(Set(pyFile1.getAbsolutePath, pyFile2.getAbsolutePath))
sysProps("spark.yarn.dist.archives").split(",").toSet should be
(Set(archive1.toURI.toString, archive2.toURI.toString))
@@ -793,64 +793,65 @@ class SparkSubmitSuite
}
test("downloadFile - invalid url") {
+ val sparkConf = new SparkConf(false)
intercept[IOException] {
- SparkSubmit.downloadFile(
- "abc:/my/file", Utils.createTempDir(), mutable.Map.empty, new Configuration())
+ DependencyUtils.downloadFile(
+ "abc:/my/file", Utils.createTempDir(), sparkConf, new Configuration(),
+ new SecurityManager(sparkConf))
}
}
test("downloadFile - file doesn't exist") {
+ val sparkConf = new SparkConf(false)
val hadoopConf = new Configuration()
val tmpDir = Utils.createTempDir()
- // Set s3a implementation to local file system for testing.
- hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem")
- // Disable file system impl cache to make sure the test file system is picked up.
- hadoopConf.set("fs.s3a.impl.disable.cache", "true")
+ updateConfWithFakeS3Fs(hadoopConf)
intercept[FileNotFoundException] {
- SparkSubmit.downloadFile("s3a:/no/such/file", tmpDir, mutable.Map.empty, hadoopConf)
+ DependencyUtils.downloadFile("s3a:/no/such/file", tmpDir, sparkConf, hadoopConf,
+ new SecurityManager(sparkConf))
}
}
test("downloadFile does not download local file") {
+ val sparkConf = new SparkConf(false)
+ val secMgr = new SecurityManager(sparkConf)
// empty path is considered as local file.
val tmpDir = Files.createTempDirectory("tmp").toFile
- assert(SparkSubmit.downloadFile("", tmpDir, mutable.Map.empty, new Configuration()) === "")
- assert(SparkSubmit.downloadFile("/local/file", tmpDir, mutable.Map.empty,
- new Configuration()) === "/local/file")
+ assert(DependencyUtils.downloadFile("", tmpDir, sparkConf, new Configuration(), secMgr) === "")
+ assert(DependencyUtils.downloadFile("/local/file", tmpDir, sparkConf, new Configuration(),
+ secMgr) === "/local/file")
}
test("download one file to local") {
+ val sparkConf = new SparkConf(false)
val jarFile = File.createTempFile("test", ".jar")
jarFile.deleteOnExit()
val content = "hello, world"
FileUtils.write(jarFile, content)
val hadoopConf = new Configuration()
val tmpDir = Files.createTempDirectory("tmp").toFile
- // Set s3a implementation to local file system for testing.
- hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem")
- // Disable file system impl cache to make sure the test file system is picked up.
- hadoopConf.set("fs.s3a.impl.disable.cache", "true")
- val sourcePath = s"s3a://${jarFile.getAbsolutePath}"
- val outputPath =
- SparkSubmit.downloadFile(sourcePath, tmpDir, mutable.Map.empty, hadoopConf)
+ updateConfWithFakeS3Fs(hadoopConf)
+ val sourcePath = s"s3a://${jarFile.toURI.getPath}"
+ val outputPath = DependencyUtils.downloadFile(sourcePath, tmpDir, sparkConf, hadoopConf,
+ new SecurityManager(sparkConf))
checkDownloadedFile(sourcePath, outputPath)
deleteTempOutputFile(outputPath)
}
test("download list of files to local") {
+ val sparkConf = new SparkConf(false)
val jarFile = File.createTempFile("test", ".jar")
jarFile.deleteOnExit()
val content = "hello, world"
FileUtils.write(jarFile, content)
val hadoopConf = new Configuration()
val tmpDir = Files.createTempDirectory("tmp").toFile
- // Set s3a implementation to local file system for testing.
- hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem")
- // Disable file system impl cache to make sure the test file system is picked up.
- hadoopConf.set("fs.s3a.impl.disable.cache", "true")
- val sourcePaths = Seq("/local/file", s"s3a://${jarFile.getAbsolutePath}")
- val outputPaths = SparkSubmit.downloadFileList(
- sourcePaths.mkString(","), tmpDir, mutable.Map.empty, hadoopConf).split(",")
+ updateConfWithFakeS3Fs(hadoopConf)
+ val sourcePaths = Seq("/local/file", s"s3a://${jarFile.toURI.getPath}")
+ val outputPaths = DependencyUtils
+ .downloadFileList(sourcePaths.mkString(","), tmpDir, sparkConf, hadoopConf,
+ new SecurityManager(sparkConf))
+ .split(",")
assert(outputPaths.length === sourcePaths.length)
sourcePaths.zip(outputPaths).foreach { case (sourcePath, outputPath) =>
@@ -859,6 +860,43 @@ class SparkSubmitSuite
}
}
+ test("Avoid re-upload remote resources in yarn client mode") {
+ val hadoopConf = new Configuration()
+ updateConfWithFakeS3Fs(hadoopConf)
+
+ val tmpDir = Utils.createTempDir()
+ val file = File.createTempFile("tmpFile", "", tmpDir)
+ val pyFile = File.createTempFile("tmpPy", ".egg", tmpDir)
+ val mainResource = File.createTempFile("tmpPy", ".py", tmpDir)
+ val tmpJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir)
+ val tmpJarPath = s"s3a://${new File(tmpJar.toURI).getAbsolutePath}"
+
+ val args = Seq(
+ "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"),
+ "--name", "testApp",
+ "--master", "yarn",
+ "--deploy-mode", "client",
+ "--jars", tmpJarPath,
+ "--files", s"s3a://${file.getAbsolutePath}",
+ "--py-files", s"s3a://${pyFile.getAbsolutePath}",
+ s"s3a://$mainResource"
+ )
+
+ val appArgs = new SparkSubmitArguments(args)
+ val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf))._3
+
+ // All the resources should still be remote paths, so that YARN client will not upload again.
+ sysProps("spark.yarn.dist.jars") should be (tmpJarPath)
+ sysProps("spark.yarn.dist.files") should be (s"s3a://${file.getAbsolutePath}")
+ sysProps("spark.yarn.dist.pyFiles") should be (s"s3a://${pyFile.getAbsolutePath}")
+
+ // Local repl jars should be a local path.
+ sysProps("spark.repl.local.jars") should (startWith("file:"))
+
+ // local py files should not be a URI format.
+ sysProps("spark.submit.pyFiles") should (startWith("/"))
+ }
+
// NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly.
private def runSparkSubmit(args: Seq[String]): Unit = {
val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
@@ -898,6 +936,11 @@ class SparkSubmitSuite
Utils.deleteRecursively(tmpDir)
}
}
+
+ private def updateConfWithFakeS3Fs(conf: Configuration): Unit = {
+ conf.set("fs.s3a.impl", classOf[TestFileSystem].getCanonicalName)
+ conf.set("fs.s3a.impl.disable.cache", "true")
+ }
}
object JarCreationTest extends Logging {
@@ -963,8 +1006,31 @@ object UserClasspathFirstTest {
}
class TestFileSystem extends org.apache.hadoop.fs.LocalFileSystem {
- override def copyToLocalFile(src: Path, dst: Path): Unit = {
+ private def local(path: Path): Path = {
// Ignore the scheme for testing.
- super.copyToLocalFile(new Path(src.toUri.getPath), dst)
+ new Path(path.toUri.getPath)
}
+
+ private def toRemote(status: FileStatus): FileStatus = {
+ val path = s"s3a://${status.getPath.toUri.getPath}"
+ status.setPath(new Path(path))
+ status
+ }
+
+ override def isFile(path: Path): Boolean = super.isFile(local(path))
+
+ override def globStatus(pathPattern: Path): Array[FileStatus] = {
+ val newPath = new Path(pathPattern.toUri.getPath)
+ super.globStatus(newPath).map(toRemote)
+ }
+
+ override def listStatus(path: Path): Array[FileStatus] = {
+ super.listStatus(local(path)).map(toRemote)
+ }
+
+ override def copyToLocalFile(src: Path, dst: Path): Unit = {
+ super.copyToLocalFile(local(src), dst)
+ }
+
+ override def open(path: Path): FSDataInputStream = super.open(local(path))
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala
index 871c87415d35d..c175ed3fb6e3d 100644
--- a/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala
@@ -33,7 +33,7 @@ import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.Matchers
-import org.scalatest.mock.MockitoSugar
+import org.scalatest.mockito.MockitoSugar
import org.apache.spark.SparkFunSuite
import org.apache.spark.internal.Logging
diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
index 95acb9a54440f..18da8c18939ed 100644
--- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
@@ -39,7 +39,7 @@ import org.openqa.selenium.WebDriver
import org.openqa.selenium.htmlunit.HtmlUnitDriver
import org.scalatest.{BeforeAndAfter, Matchers}
import org.scalatest.concurrent.Eventually
-import org.scalatest.mock.MockitoSugar
+import org.scalatest.mockito.MockitoSugar
import org.scalatest.selenium.WebBrowser
import org.apache.spark._
diff --git a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala
index 5b05521e48f8a..eeffc36070b44 100644
--- a/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManagerSuite.scala
@@ -94,7 +94,7 @@ class HadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers {
val hiveCredentialProvider = new HiveDelegationTokenProvider()
val credentials = new Credentials()
- hiveCredentialProvider.obtainDelegationTokens(hadoopConf, credentials)
+ hiveCredentialProvider.obtainDelegationTokens(hadoopConf, sparkConf, credentials)
credentials.getAllTokens.size() should be (0)
}
@@ -105,7 +105,7 @@ class HadoopDelegationTokenManagerSuite extends SparkFunSuite with Matchers {
val hbaseTokenProvider = new HBaseDelegationTokenProvider()
val creds = new Credentials()
- hbaseTokenProvider.obtainDelegationTokens(hadoopConf, creds)
+ hbaseTokenProvider.obtainDelegationTokens(hadoopConf, sparkConf, creds)
creds.getAllTokens.size should be (0)
}
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index 601dde6c63284..884a2750e621d 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -33,7 +33,7 @@ import org.mockito.Mockito.{inOrder, verify, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.concurrent.Eventually
-import org.scalatest.mock.MockitoSugar
+import org.scalatest.mockito.MockitoSugar
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
diff --git a/core/src/test/scala/org/apache/spark/metrics/sink/StatsdSinkSuite.scala b/core/src/test/scala/org/apache/spark/metrics/sink/StatsdSinkSuite.scala
new file mode 100644
index 0000000000000..0e21a36071c42
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/metrics/sink/StatsdSinkSuite.scala
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.metrics.sink
+
+import java.net.{DatagramPacket, DatagramSocket}
+import java.nio.charset.StandardCharsets.UTF_8
+import java.util.Properties
+import java.util.concurrent.TimeUnit._
+
+import com.codahale.metrics._
+
+import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
+import org.apache.spark.metrics.sink.StatsdSink._
+
+class StatsdSinkSuite extends SparkFunSuite {
+ private val securityMgr = new SecurityManager(new SparkConf(false))
+ private val defaultProps = Map(
+ STATSD_KEY_PREFIX -> "spark",
+ STATSD_KEY_PERIOD -> "1",
+ STATSD_KEY_UNIT -> "seconds",
+ STATSD_KEY_HOST -> "127.0.0.1"
+ )
+ private val socketTimeout = 30000 // milliseconds
+ private val socketBufferSize = 8192
+
+ private def withSocketAndSink(testCode: (DatagramSocket, StatsdSink) => Any): Unit = {
+ val socket = new DatagramSocket
+ socket.setReceiveBufferSize(socketBufferSize)
+ socket.setSoTimeout(socketTimeout)
+ val props = new Properties
+ defaultProps.foreach(e => props.put(e._1, e._2))
+ props.put(STATSD_KEY_PORT, socket.getLocalPort.toString)
+ val registry = new MetricRegistry
+ val sink = new StatsdSink(props, registry, securityMgr)
+ try {
+ testCode(socket, sink)
+ } finally {
+ socket.close()
+ }
+ }
+
+ test("metrics StatsD sink with Counter") {
+ withSocketAndSink { (socket, sink) =>
+ val counter = new Counter
+ counter.inc(12)
+ sink.registry.register("counter", counter)
+ sink.report()
+
+ val p = new DatagramPacket(new Array[Byte](socketBufferSize), socketBufferSize)
+ socket.receive(p)
+
+ val result = new String(p.getData, 0, p.getLength, UTF_8)
+ assert(result === "spark.counter:12|c", "Counter metric received should match data sent")
+ }
+ }
+
+ test("metrics StatsD sink with Gauge") {
+ withSocketAndSink { (socket, sink) =>
+ val gauge = new Gauge[Double] {
+ override def getValue: Double = 1.23
+ }
+ sink.registry.register("gauge", gauge)
+ sink.report()
+
+ val p = new DatagramPacket(new Array[Byte](socketBufferSize), socketBufferSize)
+ socket.receive(p)
+
+ val result = new String(p.getData, 0, p.getLength, UTF_8)
+ assert(result === "spark.gauge:1.23|g", "Gauge metric received should match data sent")
+ }
+ }
+
+ test("metrics StatsD sink with Histogram") {
+ withSocketAndSink { (socket, sink) =>
+ val p = new DatagramPacket(new Array[Byte](socketBufferSize), socketBufferSize)
+ val histogram = new Histogram(new UniformReservoir)
+ histogram.update(10)
+ histogram.update(20)
+ histogram.update(30)
+ sink.registry.register("histogram", histogram)
+ sink.report()
+
+ val expectedResults = Set(
+ "spark.histogram.count:3|g",
+ "spark.histogram.max:30|ms",
+ "spark.histogram.mean:20.00|ms",
+ "spark.histogram.min:10|ms",
+ "spark.histogram.stddev:10.00|ms",
+ "spark.histogram.p50:20.00|ms",
+ "spark.histogram.p75:30.00|ms",
+ "spark.histogram.p95:30.00|ms",
+ "spark.histogram.p98:30.00|ms",
+ "spark.histogram.p99:30.00|ms",
+ "spark.histogram.p999:30.00|ms"
+ )
+
+ (1 to expectedResults.size).foreach { i =>
+ socket.receive(p)
+ val result = new String(p.getData, 0, p.getLength, UTF_8)
+ logInfo(s"Received histogram result $i: '$result'")
+ assert(expectedResults.contains(result),
+ "Histogram metric received should match data sent")
+ }
+ }
+ }
+
+ test("metrics StatsD sink with Timer") {
+ withSocketAndSink { (socket, sink) =>
+ val p = new DatagramPacket(new Array[Byte](socketBufferSize), socketBufferSize)
+ val timer = new Timer()
+ timer.update(1, SECONDS)
+ timer.update(2, SECONDS)
+ timer.update(3, SECONDS)
+ sink.registry.register("timer", timer)
+ sink.report()
+
+ val expectedResults = Set(
+ "spark.timer.max:3000.00|ms",
+ "spark.timer.mean:2000.00|ms",
+ "spark.timer.min:1000.00|ms",
+ "spark.timer.stddev:816.50|ms",
+ "spark.timer.p50:2000.00|ms",
+ "spark.timer.p75:3000.00|ms",
+ "spark.timer.p95:3000.00|ms",
+ "spark.timer.p98:3000.00|ms",
+ "spark.timer.p99:3000.00|ms",
+ "spark.timer.p999:3000.00|ms",
+ "spark.timer.count:3|g",
+ "spark.timer.m1_rate:0.00|ms",
+ "spark.timer.m5_rate:0.00|ms",
+ "spark.timer.m15_rate:0.00|ms"
+ )
+ // mean rate varies on each test run
+ val oneMoreResult = """spark.timer.mean_rate:\d+\.\d\d\|ms"""
+
+ (1 to (expectedResults.size + 1)).foreach { i =>
+ socket.receive(p)
+ val result = new String(p.getData, 0, p.getLength, UTF_8)
+ logInfo(s"Received timer result $i: '$result'")
+ assert(expectedResults.contains(result) || result.matches(oneMoreResult),
+ "Timer metric received should match data sent")
+ }
+ }
+ }
+}
+
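A minimal sketch of the wire format the suite above asserts on: each metric is a single UDP datagram of the form "<prefix>.<name>:<value>|<type>", with "c" for counters, "g" for gauges, and "ms" for timings. The host, port, and metric line below are illustrative, not taken from the sink's configuration.

import java.net.{DatagramPacket, DatagramSocket, InetSocketAddress}
import java.nio.charset.StandardCharsets.UTF_8

object StatsdLineSketch {
  def send(host: String, port: Int, line: String): Unit = {
    val socket = new DatagramSocket()
    try {
      val bytes = line.getBytes(UTF_8)
      socket.send(new DatagramPacket(bytes, bytes.length, new InetSocketAddress(host, port)))
    } finally {
      socket.close()
    }
  }

  def main(args: Array[String]): Unit = {
    send("127.0.0.1", 8125, "spark.counter:12|c") // same shape as the Counter test above
  }
}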
diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
index e4c133c9f2cdd..21138bd4a16ba 100644
--- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
+++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
@@ -28,8 +28,8 @@ import scala.util.{Failure, Success, Try}
import com.google.common.io.CharStreams
import org.mockito.Mockito._
-import org.scalatest.ShouldMatchers
-import org.scalatest.mock.MockitoSugar
+import org.scalatest.Matchers
+import org.scalatest.mockito.MockitoSugar
import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.network.{BlockDataManager, BlockTransferService}
@@ -38,7 +38,7 @@ import org.apache.spark.network.shuffle.BlockFetchingListener
import org.apache.spark.storage.{BlockId, ShuffleBlockId}
import org.apache.spark.util.ThreadUtils
-class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar with ShouldMatchers {
+class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar with Matchers {
test("security default off") {
val conf = new SparkConf()
.set("spark.app.id", "app-id")
diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala
index 98259300381eb..f7bc3725d7278 100644
--- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala
+++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.network.BlockDataManager
class NettyBlockTransferServiceSuite
extends SparkFunSuite
with BeforeAndAfterEach
- with ShouldMatchers {
+ with Matchers {
private var service0: NettyBlockTransferService = _
private var service1: NettyBlockTransferService = _
diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
index 8f639eef46f66..f4be8eaef7013 100644
--- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
@@ -24,13 +24,13 @@ import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration.Duration
import org.scalatest.BeforeAndAfterAll
-import org.scalatest.concurrent.Timeouts
+import org.scalatest.concurrent.TimeLimits
import org.scalatest.time.SpanSugar._
import org.apache.spark._
import org.apache.spark.util.ThreadUtils
-class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Timeouts {
+class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with TimeLimits {
@transient private var sc: SparkContext = _
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 386c0060f9c41..e994d724c462f 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -347,16 +347,18 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
val partitions = repartitioned.glom().collect()
// assert all elements are present
assert(repartitioned.collect().sortWith(_ > _).toSeq === input.toSeq.sortWith(_ > _).toSeq)
- // assert no bucket is overloaded
+ // assert no bucket is overloaded or empty
for (partition <- partitions) {
val avg = input.size / finalPartitions
val maxPossible = avg + initialPartitions
- assert(partition.length <= maxPossible)
+ assert(partition.length <= maxPossible)
+ assert(!partition.isEmpty)
}
}
testSplitPartitions(Array.fill(100)(1), 10, 20)
testSplitPartitions(Array.fill(10000)(1) ++ Array.fill(10000)(2), 20, 100)
+ testSplitPartitions(Array.fill(1000)(1), 250, 128)
}
test("coalesced RDDs") {
diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
index 777163709bbf5..f9481f875d439 100644
--- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.rpc.netty
-import org.scalatest.mock.MockitoSugar
+import org.scalatest.mockito.MockitoSugar
import org.apache.spark._
import org.apache.spark.network.client.TransportClient
diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
index 520d85a298922..a136d69b36d6c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
@@ -22,7 +22,7 @@ import org.mockito.Mockito.{never, verify, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.BeforeAndAfterEach
-import org.scalatest.mock.MockitoSugar
+import org.scalatest.mockito.MockitoSugar
import org.apache.spark._
import org.apache.spark.internal.config
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 3b5df657d45cf..703fc1b34c387 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -25,7 +25,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import scala.language.reflectiveCalls
import scala.util.control.NonFatal
-import org.scalatest.concurrent.Timeouts
+import org.scalatest.concurrent.TimeLimits
import org.scalatest.time.SpanSugar._
import org.apache.spark._
@@ -98,7 +98,7 @@ class MyRDD(
class DAGSchedulerSuiteDummyException extends Exception
-class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeouts {
+class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLimits {
import DAGSchedulerSuite._
diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala
index 32cdf16dd3318..a27dadcf49bfc 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.scheduler
import org.apache.hadoop.mapred.{FileOutputCommitter, TaskAttemptContext}
-import org.scalatest.concurrent.Timeouts
+import org.scalatest.concurrent.TimeLimits
import org.scalatest.time.{Seconds, Span}
import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite, TaskContext}
@@ -32,7 +32,7 @@ import org.apache.spark.util.Utils
class OutputCommitCoordinatorIntegrationSuite
extends SparkFunSuite
with LocalSparkContext
- with Timeouts {
+ with TimeLimits {
override def beforeAll(): Unit = {
super.beforeAll()
diff --git a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala
index 4901062a78553..5bd3955f5adbb 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.scheduler
+import java.io.FileNotFoundException
import java.util.Properties
import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite}
@@ -292,6 +293,49 @@ class PoolSuite extends SparkFunSuite with LocalSparkContext {
}
}
+ test("Fair Scheduler should build fair scheduler when " +
+ "valid spark.scheduler.allocation.file property is set") {
+ val xmlPath = getClass.getClassLoader.getResource("fairscheduler-with-valid-data.xml").getFile()
+ val conf = new SparkConf().set(SCHEDULER_ALLOCATION_FILE_PROPERTY, xmlPath)
+ sc = new SparkContext(LOCAL, APP_NAME, conf)
+
+ val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0)
+ val schedulableBuilder = new FairSchedulableBuilder(rootPool, sc.conf)
+ schedulableBuilder.buildPools()
+
+ verifyPool(rootPool, schedulableBuilder.DEFAULT_POOL_NAME, 0, 1, FIFO)
+ verifyPool(rootPool, "pool1", 3, 1, FIFO)
+ verifyPool(rootPool, "pool2", 4, 2, FAIR)
+ verifyPool(rootPool, "pool3", 2, 3, FAIR)
+ }
+
+ test("Fair Scheduler should use default file(fairscheduler.xml) if it exists in classpath " +
+ "and spark.scheduler.allocation.file property is not set") {
+ val conf = new SparkConf()
+ sc = new SparkContext(LOCAL, APP_NAME, conf)
+
+ val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0)
+ val schedulableBuilder = new FairSchedulableBuilder(rootPool, sc.conf)
+ schedulableBuilder.buildPools()
+
+ verifyPool(rootPool, schedulableBuilder.DEFAULT_POOL_NAME, 0, 1, FIFO)
+ verifyPool(rootPool, "1", 2, 1, FIFO)
+ verifyPool(rootPool, "2", 3, 1, FIFO)
+ verifyPool(rootPool, "3", 0, 1, FIFO)
+ }
+
+ test("Fair Scheduler should throw FileNotFoundException " +
+ "when invalid spark.scheduler.allocation.file property is set") {
+ val conf = new SparkConf().set(SCHEDULER_ALLOCATION_FILE_PROPERTY, "INVALID_FILE_PATH")
+ sc = new SparkContext(LOCAL, APP_NAME, conf)
+
+ val rootPool = new Pool("", SchedulingMode.FAIR, 0, 0)
+ val schedulableBuilder = new FairSchedulableBuilder(rootPool, sc.conf)
+ intercept[FileNotFoundException] {
+ schedulableBuilder.buildPools()
+ }
+ }
+
private def verifyPool(rootPool: Pool, poolName: String, expectedInitMinShare: Int,
expectedInitWeight: Int, expectedSchedulingMode: SchedulingMode): Unit = {
val selectedPool = rootPool.getSchedulableByName(poolName)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
index 88a68af6b647d..d17e3864854a8 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala
@@ -21,6 +21,7 @@ import java.io._
import java.net.URI
import java.util.concurrent.atomic.AtomicInteger
+import org.apache.hadoop.fs.Path
import org.json4s.jackson.JsonMethods._
import org.scalatest.BeforeAndAfter
@@ -84,24 +85,23 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp
val buffered = new ByteArrayOutputStream
val codec = new LZ4CompressionCodec(new SparkConf())
val compstream = codec.compressedOutputStream(buffered)
- val writer = new PrintWriter(compstream)
+ Utils.tryWithResource(new PrintWriter(compstream)) { writer =>
- val applicationStart = SparkListenerApplicationStart("AppStarts", None,
- 125L, "Mickey", None)
- val applicationEnd = SparkListenerApplicationEnd(1000L)
+ val applicationStart = SparkListenerApplicationStart("AppStarts", None,
+ 125L, "Mickey", None)
+ val applicationEnd = SparkListenerApplicationEnd(1000L)
- // scalastyle:off println
- writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationStart))))
- writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationEnd))))
- // scalastyle:on println
- writer.close()
+ // scalastyle:off println
+ writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationStart))))
+ writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationEnd))))
+ // scalastyle:on println
+ }
val logFilePath = Utils.getFilePath(testDir, "events.lz4.inprogress")
- val fstream = fileSystem.create(logFilePath)
val bytes = buffered.toByteArray
-
- fstream.write(bytes, 0, buffered.size)
- fstream.close
+ Utils.tryWithResource(fileSystem.create(logFilePath)) { fstream =>
+ fstream.write(bytes, 0, buffered.size)
+ }
// Read the compressed .inprogress file and verify only first event was parsed.
val conf = EventLoggingListenerSuite.getLoggingConf(logFilePath)
@@ -112,17 +112,19 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp
// Verify the replay returns the events given the input maybe truncated.
val logData = EventLoggingListener.openEventLog(logFilePath, fileSystem)
- val failingStream = new EarlyEOFInputStream(logData, buffered.size - 10)
- replayer.replay(failingStream, logFilePath.toString, true)
+ Utils.tryWithResource(new EarlyEOFInputStream(logData, buffered.size - 10)) { failingStream =>
+ replayer.replay(failingStream, logFilePath.toString, true)
- assert(eventMonster.loggedEvents.size === 1)
- assert(failingStream.didFail)
+ assert(eventMonster.loggedEvents.size === 1)
+ assert(failingStream.didFail)
+ }
// Verify the replay throws the EOF exception since the input may not be truncated.
val logData2 = EventLoggingListener.openEventLog(logFilePath, fileSystem)
- val failingStream2 = new EarlyEOFInputStream(logData2, buffered.size - 10)
- intercept[EOFException] {
- replayer.replay(failingStream2, logFilePath.toString, false)
+ Utils.tryWithResource(new EarlyEOFInputStream(logData2, buffered.size - 10)) { failingStream2 =>
+ intercept[EOFException] {
+ replayer.replay(failingStream2, logFilePath.toString, false)
+ }
}
}
@@ -151,7 +153,10 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp
* assumption that the event logging behavior is correct (tested in a separate suite).
*/
private def testApplicationReplay(codecName: Option[String] = None) {
- val logDirPath = Utils.getFilePath(testDir, "test-replay")
+ val logDir = new File(testDir.getAbsolutePath, "test-replay")
+ // Here, it creates `Path` from the URI instead of the absolute path for the explicit file
+ // scheme so that the string representation of this `Path` has leading file scheme correctly.
+ val logDirPath = new Path(logDir.toURI)
fileSystem.mkdirs(logDirPath)
val conf = EventLoggingListenerSuite.getLoggingConf(logDirPath, codecName)
@@ -221,12 +226,14 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter with LocalSp
def didFail: Boolean = countDown.get == 0
@throws[IOException]
- def read: Int = {
+ override def read(): Int = {
if (countDown.get == 0) {
throw new EOFException("Stream ended prematurely")
}
countDown.decrementAndGet()
- in.read
+ in.read()
}
+
+ override def close(): Unit = in.close()
}
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 992d3396d203f..a1d9085fa085d 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -54,7 +54,10 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
val rdd = new RDD[String](sc, List()) {
override def getPartitions = Array[Partition](StubPartition(0))
override def compute(split: Partition, context: TaskContext) = {
- context.addTaskCompletionListener(context => TaskContextSuite.completed = true)
+ context.addTaskCompletionListener(new TaskCompletionListener {
+ override def onTaskCompletion(context: TaskContext): Unit =
+ TaskContextSuite.completed = true
+ })
sys.error("failed")
}
}
@@ -95,9 +98,13 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
test("all TaskCompletionListeners should be called even if some fail") {
val context = TaskContext.empty()
val listener = mock(classOf[TaskCompletionListener])
- context.addTaskCompletionListener(_ => throw new Exception("blah"))
+ context.addTaskCompletionListener(new TaskCompletionListener {
+ override def onTaskCompletion(context: TaskContext): Unit = throw new Exception("blah")
+ })
context.addTaskCompletionListener(listener)
- context.addTaskCompletionListener(_ => throw new Exception("blah"))
+ context.addTaskCompletionListener(new TaskCompletionListener {
+ override def onTaskCompletion(context: TaskContext): Unit = throw new Exception("blah")
+ })
intercept[TaskCompletionListenerException] {
context.markTaskCompleted(None)
@@ -109,9 +116,15 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
test("all TaskFailureListeners should be called even if some fail") {
val context = TaskContext.empty()
val listener = mock(classOf[TaskFailureListener])
- context.addTaskFailureListener((_, _) => throw new Exception("exception in listener1"))
+ context.addTaskFailureListener(new TaskFailureListener {
+ override def onTaskFailure(context: TaskContext, error: Throwable): Unit =
+ throw new Exception("exception in listener1")
+ })
context.addTaskFailureListener(listener)
- context.addTaskFailureListener((_, _) => throw new Exception("exception in listener3"))
+ context.addTaskFailureListener(new TaskFailureListener {
+ override def onTaskFailure(context: TaskContext, error: Throwable): Unit =
+ throw new Exception("exception in listener3")
+ })
val e = intercept[TaskCompletionListenerException] {
context.markTaskFailed(new Exception("exception in task"))
@@ -232,7 +245,10 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
var invocations = 0
val context = TaskContext.empty()
context.markTaskCompleted(None)
- context.addTaskCompletionListener(_ => invocations += 1)
+ context.addTaskCompletionListener(new TaskCompletionListener {
+ override def onTaskCompletion(context: TaskContext): Unit =
+ invocations += 1
+ })
assert(invocations == 1)
context.markTaskCompleted(None)
assert(invocations == 1)
@@ -244,10 +260,12 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
val error = new RuntimeException
val context = TaskContext.empty()
context.markTaskFailed(error)
- context.addTaskFailureListener { (_, e) =>
- lastError = e
- invocations += 1
- }
+ context.addTaskFailureListener(new TaskFailureListener {
+ override def onTaskFailure(context: TaskContext, e: Throwable): Unit = {
+ lastError = e
+ invocations += 1
+ }
+ })
assert(lastError == error)
assert(invocations == 1)
context.markTaskFailed(error)
@@ -267,9 +285,15 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
test("all TaskCompletionListeners should be called even if some fail or a task") {
val context = TaskContext.empty()
val listener = mock(classOf[TaskCompletionListener])
- context.addTaskCompletionListener(_ => throw new Exception("exception in listener1"))
+ context.addTaskCompletionListener(new TaskCompletionListener {
+ override def onTaskCompletion(context: TaskContext): Unit =
+ throw new Exception("exception in listener1")
+ })
context.addTaskCompletionListener(listener)
- context.addTaskCompletionListener(_ => throw new Exception("exception in listener3"))
+ context.addTaskCompletionListener(new TaskCompletionListener {
+ override def onTaskCompletion(context: TaskContext): Unit =
+ throw new Exception("exception in listener3")
+ })
val e = intercept[TaskCompletionListenerException] {
context.markTaskCompleted(Some(new Exception("exception in task")))
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index ab67a393e2ac5..b8626bf777598 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable.HashMap
import org.mockito.Matchers.{anyInt, anyObject, anyString, eq => meq}
import org.mockito.Mockito.{atLeast, atMost, never, spy, times, verify, when}
import org.scalatest.BeforeAndAfterEach
-import org.scalatest.mock.MockitoSugar
+import org.scalatest.mockito.MockitoSugar
import org.apache.spark._
import org.apache.spark.internal.Logging
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 6f1663b210969..ae43f4cadc037 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -60,6 +60,10 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
exception: Option[Throwable]): Unit = {
taskScheduler.taskSetsFailed += taskSet.id
}
+
+ override def speculativeTaskSubmitted(task: Task[_]): Unit = {
+ taskScheduler.speculativeTasks += task.partitionId
+ }
}
// Get the rack for a given host
@@ -92,6 +96,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
val endedTasks = new mutable.HashMap[Long, TaskEndReason]
val finishedManagers = new ArrayBuffer[TaskSetManager]
val taskSetsFailed = new ArrayBuffer[String]
+ val speculativeTasks = new ArrayBuffer[Int]
val executors = new mutable.HashMap[String, String]
for ((execId, host) <- liveExecutors) {
@@ -139,6 +144,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
}
}
+
override def getRackForHost(value: String): Option[String] = FakeRackUtil.getRackForHost(value)
}
@@ -929,6 +935,8 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
// > 0ms, so advance the clock by 1ms here.
clock.advance(1)
assert(manager.checkSpeculatableTasks(0))
+ assert(sched.speculativeTasks.toSet === Set(3))
+
// Offer resource to start the speculative attempt for the running task
val taskOption5 = manager.resourceOffer("exec1", "host1", NO_PREF)
assert(taskOption5.isDefined)
@@ -1016,6 +1024,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
// > 0ms, so advance the clock by 1ms here.
clock.advance(1)
assert(manager.checkSpeculatableTasks(0))
+ assert(sched.speculativeTasks.toSet === Set(3, 4))
// Offer resource to start the speculative attempt for the running task
val taskOption5 = manager.resourceOffer("exec1", "host1", NO_PREF)
assert(taskOption5.isDefined)
diff --git a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
index 608052f5ed855..78f618f8a2163 100644
--- a/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/security/CryptoStreamUtilsSuite.scala
@@ -130,6 +130,7 @@ class CryptoStreamUtilsSuite extends SparkFunSuite {
val conf = createConf()
val key = createKey(conf)
val file = Files.createTempFile("crypto", ".test").toFile()
+ file.deleteOnExit()
val outStream = createCryptoOutputStream(new FileOutputStream(file), conf, key)
try {
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 8dd70fcb2fbd5..cfe89fde63f88 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -33,7 +33,7 @@ import org.mockito.{Matchers => mc}
import org.mockito.Mockito.{mock, times, verify, when}
import org.scalatest._
import org.scalatest.concurrent.Eventually._
-import org.scalatest.concurrent.Timeouts._
+import org.scalatest.concurrent.TimeLimits._
import org.apache.spark._
import org.apache.spark.broadcast.BroadcastManager
diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala
index 4253cc8ca4cd1..cbc903f17ad75 100644
--- a/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.storage
import org.mockito.Matchers
import org.mockito.Mockito._
-import org.scalatest.mock.MockitoSugar
+import org.scalatest.mockito.MockitoSugar
import org.apache.spark.SparkFunSuite
import org.apache.spark.memory.MemoryMode.ON_HEAP
diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala
index 6f7dddd4f760a..f4f8388f5f19f 100644
--- a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala
@@ -24,11 +24,11 @@ import scala.concurrent.duration._
import scala.language.postfixOps
import org.scalatest.concurrent.Eventually._
-import org.scalatest.concurrent.Timeouts
+import org.scalatest.concurrent.TimeLimits
import org.apache.spark.SparkFunSuite
-class EventLoopSuite extends SparkFunSuite with Timeouts {
+class EventLoopSuite extends SparkFunSuite with TimeLimits {
test("EventLoop") {
val buffer = new ConcurrentLinkedQueue[Int]
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index 4ce143f18bbf1..05d58d8e6099d 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -939,6 +939,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
// creating a very misbehaving process. It ignores SIGTERM and has been SIGSTOPed. On
// older versions of java, this will *not* terminate.
val file = File.createTempFile("temp-file-name", ".tmp")
+ file.deleteOnExit()
val cmd =
s"""
|#!/bin/bash
diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh
index 9bf2899e340ec..f4a7f25c2413f 100755
--- a/dev/create-release/release-build.sh
+++ b/dev/create-release/release-build.sh
@@ -80,8 +80,17 @@ NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads
BASE_DIR=$(pwd)
MVN="build/mvn --force"
-PUBLISH_PROFILES="-Pmesos -Pyarn -Phive -Phive-thriftserver"
-PUBLISH_PROFILES="$PUBLISH_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl"
+
+# Hive-specific profiles for some builds
+HIVE_PROFILES="-Phive -Phive-thriftserver"
+# Profiles for publishing snapshots and release to Maven Central
+PUBLISH_PROFILES="-Pmesos -Pyarn $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl"
+# Profiles for building binary releases
+BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Psparkr"
+# Scala 2.11 only profiles for some builds
+SCALA_2_11_PROFILES="-Pkafka-0-8"
+# Scala 2.12 only profiles for some builds
+SCALA_2_12_PROFILES="-Pscala-2.12"
rm -rf spark
git clone https://git-wip-us.apache.org/repos/asf/spark.git
@@ -235,10 +244,9 @@ if [[ "$1" == "package" ]]; then
# We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds
# share the same Zinc server.
- FLAGS="-Psparkr -Phive -Phive-thriftserver -Pyarn -Pmesos"
- make_binary_release "hadoop2.6" "-Phadoop-2.6 $FLAGS" "3035" "withr" &
- make_binary_release "hadoop2.7" "-Phadoop-2.7 $FLAGS" "3036" "withpip" &
- make_binary_release "without-hadoop" "-Psparkr -Phadoop-provided -Pyarn -Pmesos" "3038" &
+ make_binary_release "hadoop2.6" "-Phadoop-2.6 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3035" "withr" &
+ make_binary_release "hadoop2.7" "-Phadoop-2.7 $HIVE_PROFILES $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3036" "withpip" &
+ make_binary_release "without-hadoop" "-Phadoop-provided $SCALA_2_11_PROFILES $BASE_RELEASE_PROFILES" "3038" &
wait
rm -rf spark-$SPARK_VERSION-bin-*/
@@ -304,10 +312,10 @@ if [[ "$1" == "publish-snapshot" ]]; then
# Generate random point for Zinc
export ZINC_PORT=$(python -S -c "import random; print random.randrange(3030,4030)")
- $MVN -DzincPort=$ZINC_PORT --settings $tmp_settings -DskipTests $PUBLISH_PROFILES deploy
+ $MVN -DzincPort=$ZINC_PORT --settings $tmp_settings -DskipTests $SCALA_2_11_PROFILES $PUBLISH_PROFILES deploy
#./dev/change-scala-version.sh 2.12
- #$MVN -DzincPort=$ZINC_PORT -Pscala-2.12 --settings $tmp_settings \
- # -DskipTests $PUBLISH_PROFILES clean deploy
+ #$MVN -DzincPort=$ZINC_PORT --settings $tmp_settings \
+ # -DskipTests $SCALA_2_12_PROFILES $PUBLISH_PROFILES clean deploy
# Clean-up Zinc nailgun process
/usr/sbin/lsof -P |grep $ZINC_PORT | grep LISTEN | awk '{ print $2; }' | xargs kill
@@ -340,15 +348,17 @@ if [[ "$1" == "publish-release" ]]; then
# Generate random point for Zinc
export ZINC_PORT=$(python -S -c "import random; print random.randrange(3030,4030)")
- $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -DskipTests $PUBLISH_PROFILES clean install
+ $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -DskipTests $SCALA_2_11_PROFILES $PUBLISH_PROFILES clean install
#./dev/change-scala-version.sh 2.12
- #$MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -Pscala-2.12 \
- # -DskipTests $PUBLISH_PROFILES clean install
+ #$MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo \
+ # -DskipTests $SCALA_2_12_PROFILES $PUBLISH_PROFILES clean install
# Clean-up Zinc nailgun process
/usr/sbin/lsof -P |grep $ZINC_PORT | grep LISTEN | awk '{ print $2; }' | xargs kill
+ #./dev/change-scala-version.sh 2.11
+
pushd $tmp_repo/org/apache/spark
# Remove any extra files generated during install
diff --git a/dev/deps/spark-deps-hadoop-palantir b/dev/deps/spark-deps-hadoop-palantir
index f14edb7053f5e..b1438013421f4 100644
--- a/dev/deps/spark-deps-hadoop-palantir
+++ b/dev/deps/spark-deps-hadoop-palantir
@@ -33,8 +33,8 @@ breeze_2.11-0.13.2.jar
calcite-avatica-1.2.0-incubating.jar
calcite-core-1.2.0-incubating.jar
calcite-linq4j-1.2.0-incubating.jar
-chill-java-0.8.0.jar
-chill_2.11-0.8.0.jar
+chill-java-0.8.4.jar
+chill_2.11-0.8.4.jar
classmate-1.1.0.jar
commons-beanutils-1.9.3.jar
commons-beanutils-core-1.8.0.jar
@@ -213,7 +213,7 @@ scala-compiler-2.11.8.jar
scala-library-2.11.8.jar
scala-parser-combinators_2.11-1.0.4.jar
scala-reflect-2.11.8.jar
-scala-xml_2.11-1.0.2.jar
+scala-xml_2.11-1.0.5.jar
scalap-2.11.8.jar
shapeless_2.11-2.3.2.jar
slf4j-api-1.7.25.jar
@@ -227,7 +227,7 @@ stax-api-1.0-2.jar
stax-api-1.0.1.jar
stream-2.7.0.jar
stringtemplate-3.2.1.jar
-univocity-parsers-2.2.1.jar
+univocity-parsers-2.5.4.jar
validation-api-1.1.0.Final.jar
xbean-asm5-shaded-4.4.jar
xmlenc-0.52.jar
diff --git a/dev/mima b/dev/mima
index 5501589b7900a..fdb21f5007cf2 100755
--- a/dev/mima
+++ b/dev/mima
@@ -24,7 +24,7 @@ set -e
FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
cd "$FWDIR"
-SPARK_PROFILES="-Pmesos -Pyarn -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive"
+SPARK_PROFILES="-Pmesos -Pkafka-0-8 -Pyarn -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive"
TOOLS_CLASSPATH="$(build/sbt -DcopyDependencies=false "export tools/fullClasspath" | tail -n1)"
OLD_DEPS_CLASSPATH="$(build/sbt -DcopyDependencies=false $SPARK_PROFILES "export oldDeps/fullClasspath" | tail -n1)"
diff --git a/dev/scalastyle b/dev/scalastyle
index e7bf3c7a03af6..2e1336df9c1f3 100755
--- a/dev/scalastyle
+++ b/dev/scalastyle
@@ -24,6 +24,7 @@ ERRORS=$(echo -e "q\n" \
-Phadoop-cloud \
-Pkinesis-asl \
-Pmesos \
+ -Pkafka-0-8 \
-Pyarn \
-Phive \
-Phive-thriftserver \
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 81be243f4c6cc..9bbc9b9a84cb6 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -249,6 +249,12 @@ def __hash__(self):
"external/kafka-0-8",
"external/kafka-0-8-assembly",
],
+ build_profile_flags=[
+ "-Pkafka-0-8",
+ ],
+ environ={
+ "ENABLE_KAFKA_0_8_TESTS": "1"
+ },
sbt_test_goals=[
"streaming-kafka-0-8/test",
]
diff --git a/docs/README.md b/docs/README.md
index 0090dd071e15f..225bb1b2040de 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -9,19 +9,22 @@ documentation yourself. Why build it yourself? So that you have the docs that co
whichever version of Spark you currently have checked out of revision control.
## Prerequisites
-The Spark documentation build uses a number of tools to build HTML docs and API docs in Scala,
-Python and R.
+
+The Spark documentation build uses a number of tools to build HTML docs and API docs in Scala, Java,
+Python, R and SQL.
You need to have [Ruby](https://www.ruby-lang.org/en/documentation/installation/) and
[Python](https://docs.python.org/2/using/unix.html#getting-and-installing-the-latest-version-of-python)
installed. Also install the following libraries:
+
```sh
- $ sudo gem install jekyll jekyll-redirect-from pygments.rb
- $ sudo pip install Pygments
- # Following is needed only for generating API docs
- $ sudo pip install sphinx pypandoc
- $ sudo Rscript -e 'install.packages(c("knitr", "devtools", "roxygen2", "testthat", "rmarkdown"), repos="http://cran.stat.ucla.edu/")'
+$ sudo gem install jekyll jekyll-redirect-from pygments.rb
+$ sudo pip install Pygments
+# Following is needed only for generating API docs
+$ sudo pip install sphinx pypandoc mkdocs
+$ sudo Rscript -e 'install.packages(c("knitr", "devtools", "roxygen2", "testthat", "rmarkdown"), repos="http://cran.stat.ucla.edu/")'
```
+
(Note: If you are on a system with both Ruby 1.9 and Ruby 2.0 you may need to replace gem with gem2.0)
## Generating the Documentation HTML
@@ -32,42 +35,49 @@ the source code and be captured by revision control (currently git). This way th
includes the version of the documentation that is relevant regardless of which version or release
you have checked out or downloaded.
-In this directory you will find textfiles formatted using Markdown, with an ".md" suffix. You can
-read those text files directly if you want. Start with index.md.
+In this directory you will find text files formatted using Markdown, with an ".md" suffix. You can
+read those text files directly if you want. Start with `index.md`.
Execute `jekyll build` from the `docs/` directory to compile the site. Compiling the site with
-Jekyll will create a directory called `_site` containing index.html as well as the rest of the
+Jekyll will create a directory called `_site` containing `index.html` as well as the rest of the
compiled files.
- $ cd docs
- $ jekyll build
+```sh
+$ cd docs
+$ jekyll build
+```
You can modify the default Jekyll build as follows:
+
```sh
- # Skip generating API docs (which takes a while)
- $ SKIP_API=1 jekyll build
-
- # Serve content locally on port 4000
- $ jekyll serve --watch
-
- # Build the site with extra features used on the live page
- $ PRODUCTION=1 jekyll build
+# Skip generating API docs (which takes a while)
+$ SKIP_API=1 jekyll build
+
+# Serve content locally on port 4000
+$ jekyll serve --watch
+
+# Build the site with extra features used on the live page
+$ PRODUCTION=1 jekyll build
```
-## API Docs (Scaladoc, Sphinx, roxygen2)
+## API Docs (Scaladoc, Javadoc, Sphinx, roxygen2, MkDocs)
-You can build just the Spark scaladoc by running `build/sbt unidoc` from the SPARK_PROJECT_ROOT directory.
+You can build just the Spark scaladoc and javadoc by running `build/sbt unidoc` from the `SPARK_HOME` directory.
Similarly, you can build just the PySpark docs by running `make html` from the
-SPARK_PROJECT_ROOT/python/docs directory. Documentation is only generated for classes that are listed as
-public in `__init__.py`. The SparkR docs can be built by running SPARK_PROJECT_ROOT/R/create-docs.sh.
+`SPARK_HOME/python/docs` directory. Documentation is only generated for classes that are listed as
+public in `__init__.py`. The SparkR docs can be built by running `SPARK_HOME/R/create-docs.sh`, and
+the SQL docs can be built by running `SPARK_HOME/sql/create-docs.sh`
+after [building Spark](https://github.com/apache/spark#building-spark) first.
-When you run `jekyll` in the `docs` directory, it will also copy over the scaladoc for the various
+When you run `jekyll build` in the `docs` directory, it will also copy over the scaladoc and javadoc for the various
Spark subprojects into the `docs` directory (and then also into the `_site` directory). We use a
jekyll plugin to run `build/sbt unidoc` before building the site so if you haven't run it (recently) it
-may take some time as it generates all of the scaladoc. The jekyll plugin also generates the
-PySpark docs using [Sphinx](http://sphinx-doc.org/).
+may take some time as it generates all of the scaladoc and javadoc using [Unidoc](https://github.com/sbt/sbt-unidoc).
+The jekyll plugin also generates the PySpark docs using [Sphinx](http://sphinx-doc.org/), SparkR docs
+using [roxygen2](https://cran.r-project.org/web/packages/roxygen2/index.html) and SQL docs
+using [MkDocs](http://www.mkdocs.org/).
-NOTE: To skip the step of building and copying over the Scala, Python, R and SQL API docs, run `SKIP_API=1
-jekyll`. In addition, `SKIP_SCALADOC=1`, `SKIP_PYTHONDOC=1`, `SKIP_RDOC=1` and `SKIP_SQLDOC=1` can be used
-to skip a single step of the corresponding language.
+NOTE: To skip the step of building and copying over the Scala, Java, Python, R and SQL API docs, run `SKIP_API=1
+jekyll build`. In addition, `SKIP_SCALADOC=1`, `SKIP_PYTHONDOC=1`, `SKIP_RDOC=1` and `SKIP_SQLDOC=1` can be used
+to skip a single step of the corresponding language. `SKIP_SCALADOC` indicates skipping both the Scala and Java docs.
diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb
index 00366f803c2ad..4d0d043a349bb 100644
--- a/docs/_plugins/copy_api_dirs.rb
+++ b/docs/_plugins/copy_api_dirs.rb
@@ -20,7 +20,7 @@
if not (ENV['SKIP_API'] == '1')
if not (ENV['SKIP_SCALADOC'] == '1')
- # Build Scaladoc for Java/Scala
+ # Build Scaladoc for Scala and Javadoc for Java
puts "Moving to project root and building API docs."
curr_dir = pwd
diff --git a/docs/building-spark.md b/docs/building-spark.md
index 69d83023b2281..57baa503259c1 100644
--- a/docs/building-spark.md
+++ b/docs/building-spark.md
@@ -90,6 +90,15 @@ like ZooKeeper and Hadoop itself.
## Building with Mesos support
./build/mvn -Pmesos -DskipTests clean package
+
+## Building with Kafka 0.8 support
+
+Kafka 0.8 support must be explicitly enabled with the `kafka-0-8` profile.
+Note: Kafka 0.8 support is deprecated as of Spark 2.3.0.
+
+ ./build/mvn -Pkafka-0-8 -DskipTests clean package
+
+Kafka 0.10 support is still automatically built.
## Building submodules individually
@@ -111,7 +120,7 @@ should run continuous compilation (i.e. wait for changes). However, this has not
extensively. A couple of gotchas to note:
* it only scans the paths `src/main` and `src/test` (see
-[docs](http://scala-tools.org/mvnsites/maven-scala-plugin/usage_cc.html)), so it will only work
+[docs](http://davidb.github.io/scala-maven-plugin/example_cc.html)), so it will only work
from within certain submodules that have that structure.
* you'll typically need to run `mvn install` from the project root for compilation within
diff --git a/docs/configuration.md b/docs/configuration.md
index e7c0306920e08..6e9fe591b70a3 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -627,10 +627,10 @@ Apart from these, the following properties are also available, and may be useful
- spark.shuffle.service.index.cache.entries |
- 1024 |
+ spark.shuffle.service.index.cache.size |
+ 100m |
- Max number of entries to keep in the index cache of the shuffle service.
+ Cache entries limited to the specified memory footprint.
|
diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md
index 807944f20a78a..e6d881639a13b 100644
--- a/docs/job-scheduling.md
+++ b/docs/job-scheduling.md
@@ -235,7 +235,7 @@ properties:
of the cluster. By default, each pool's `minShare` is 0.
The pool properties can be set by creating an XML file, similar to `conf/fairscheduler.xml.template`,
-and setting a `spark.scheduler.allocation.file` property in your
+and either putting a file named `fairscheduler.xml` on the classpath, or setting `spark.scheduler.allocation.file` property in your
[SparkConf](configuration.html#spark-properties).
{% highlight scala %}
diff --git a/docs/ml-features.md b/docs/ml-features.md
index e19fba249fb2d..86a0e09997b8e 100644
--- a/docs/ml-features.md
+++ b/docs/ml-features.md
@@ -53,9 +53,9 @@ are calculated based on the mapped indices. This approach avoids the need to com
term-to-index map, which can be expensive for a large corpus, but it suffers from potential hash
collisions, where different raw features may become the same term after hashing. To reduce the
chance of collision, we can increase the target feature dimension, i.e. the number of buckets
-of the hash table. Since a simple modulo is used to transform the hash function to a column index,
-it is advisable to use a power of two as the feature dimension, otherwise the features will
-not be mapped evenly to the columns. The default feature dimension is `$2^{18} = 262,144$`.
+of the hash table. Since a simple modulo on the hashed value is used to determine the vector index,
+it is advisable to use a power of two as the feature dimension, otherwise the features will not
+be mapped evenly to the vector indices. The default feature dimension is `$2^{18} = 262,144$`.
An optional binary toggle parameter controls term frequency counts. When set to true all nonzero
frequency counts are set to 1. This is especially useful for discrete probabilistic models that
model binary, rather than integer, counts.
@@ -65,7 +65,7 @@ model binary, rather than integer, counts.
**IDF**: `IDF` is an `Estimator` which is fit on a dataset and produces an `IDFModel`. The
`IDFModel` takes feature vectors (generally created from `HashingTF` or `CountVectorizer`) and
-scales each column. Intuitively, it down-weights columns which appear frequently in a corpus.
+scales each feature. Intuitively, it down-weights features which appear frequently in a corpus.
**Note:** `spark.ml` doesn't provide tools for text segmentation.
We refer users to the [Stanford NLP Group](http://nlp.stanford.edu/) and
@@ -211,6 +211,89 @@ for more details on the API.
+## FeatureHasher
+
+Feature hashing projects a set of categorical or numerical features into a feature vector of
+specified dimension (typically substantially smaller than that of the original feature
+space). This is done using the [hashing trick](https://en.wikipedia.org/wiki/Feature_hashing)
+to map features to indices in the feature vector.
+
+The `FeatureHasher` transformer operates on multiple columns. Each column may contain either
+numeric or categorical features. Behavior and handling of column data types is as follows:
+
+- Numeric columns: For numeric features, the hash value of the column name is used to map the
+feature value to its index in the feature vector. Numeric features are never treated as
+categorical, even when they are integers. You must explicitly convert numeric columns containing
+categorical features to strings first.
+- String columns: For categorical features, the hash value of the string "column_name=value"
+is used to map to the vector index, with an indicator value of `1.0`. Thus, categorical features
+are "one-hot" encoded (similarly to using [OneHotEncoder](ml-features.html#onehotencoder) with
+`dropLast=false`).
+- Boolean columns: Boolean values are treated in the same way as string columns. That is,
+boolean features are represented as "column_name=true" or "column_name=false", with an indicator
+value of `1.0`.
+
+Null (missing) values are ignored (implicitly zero in the resulting feature vector).
+
+The hash function used here is also the [MurmurHash 3](https://en.wikipedia.org/wiki/MurmurHash)
+used in [HashingTF](ml-features.html#tf-idf). Since a simple modulo on the hashed value is used to
+determine the vector index, it is advisable to use a power of two as the numFeatures parameter;
+otherwise the features will not be mapped evenly to the vector indices.
+
+**Examples**
+
+Assume that we have a DataFrame with 4 input columns `real`, `bool`, `stringNum`, and `string`.
+These different data types as input will illustrate the behavior of the transform to produce a
+column of feature vectors.
+
+~~~~
+real| bool|stringNum|string
+----|-----|---------|------
+ 2.2| true| 1| foo
+ 3.3|false| 2| bar
+ 4.4|false| 3| baz
+ 5.5|false| 4| foo
+~~~~
+
+Then the output of `FeatureHasher.transform` on this DataFrame is:
+
+~~~~
+real|bool |stringNum|string|features
+----|-----|---------|------|-------------------------------------------------------
+2.2 |true |1 |foo |(262144,[51871, 63643,174475,253195],[1.0,1.0,2.2,1.0])
+3.3 |false|2 |bar |(262144,[6031, 80619,140467,174475],[1.0,1.0,1.0,3.3])
+4.4 |false|3 |baz |(262144,[24279,140467,174475,196810],[1.0,1.0,4.4,1.0])
+5.5 |false|4 |foo |(262144,[63643,140467,168512,174475],[1.0,1.0,1.0,5.5])
+~~~~
+
+The resulting feature vectors could then be passed to a learning algorithm.
+
+
+
+
+Refer to the [FeatureHasher Scala docs](api/scala/index.html#org.apache.spark.ml.feature.FeatureHasher)
+for more details on the API.
+
+{% include_example scala/org/apache/spark/examples/ml/FeatureHasherExample.scala %}
+
+
+
+
+Refer to the [FeatureHasher Java docs](api/java/org/apache/spark/ml/feature/FeatureHasher.html)
+for more details on the API.
+
+{% include_example java/org/apache/spark/examples/ml/JavaFeatureHasherExample.java %}
+
+
+
+
+Refer to the [FeatureHasher Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.FeatureHasher)
+for more details on the API.
+
+{% include_example python/ml/feature_hasher_example.py %}
+
+
+
# Feature Transformers
## Tokenizer
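
A minimal Scala sketch of the `FeatureHasher` usage documented above, assuming an existing `SparkSession` named `spark`; the column names mirror the example table:

```scala
import org.apache.spark.ml.feature.FeatureHasher

// Recreate the four-column example DataFrame from the table above.
val dataset = spark.createDataFrame(Seq(
  (2.2, true, "1", "foo"),
  (3.3, false, "2", "bar"),
  (4.4, false, "3", "baz"),
  (5.5, false, "4", "foo")
)).toDF("real", "bool", "stringNum", "string")

// Hash all four columns into a single vector column named "features".
val hasher = new FeatureHasher()
  .setInputCols("real", "bool", "stringNum", "string")
  .setOutputCol("features")

hasher.transform(dataset).show(false)
```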
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index 7aec6a40d4c64..f6288e7c32d97 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -105,6 +105,24 @@ MLlib is under active development.
The APIs marked `Experimental`/`DeveloperApi` may change in future releases,
and the migration guide below will explain all changes between releases.
+## From 2.2 to 2.3
+
+### Breaking changes
+
+There are no breaking changes.
+
+### Deprecations and changes of behavior
+
+**Deprecations**
+
+There are no deprecations.
+
+**Changes of behavior**
+
+* [SPARK-21027](https://issues.apache.org/jira/browse/SPARK-21027):
+ We now set the default parallelism used in `OneVsRest` to 1 (i.e. serial). In 2.2 and earlier versions,
+ the `OneVsRest` parallelism was that of Scala's default thread pool.
+
## From 2.1 to 2.2
### Breaking changes
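
As a sketch of the behavior change noted above, assuming the `parallelism` param that 2.3 exposes on `OneVsRest` and a training DataFrame named `train`, parallelism can be set explicitly rather than relying on the new serial default:

```scala
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}

// 2.3 defaults to parallelism = 1 (serial); set it explicitly to fit the
// per-class models using up to 4 threads.
val ovr = new OneVsRest()
  .setClassifier(new LogisticRegression().setMaxIter(10))
  .setParallelism(4)

val ovrModel = ovr.fit(train)
```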
diff --git a/docs/ml-tuning.md b/docs/ml-tuning.md
index e9123db29648e..64dc46cf0c0e7 100644
--- a/docs/ml-tuning.md
+++ b/docs/ml-tuning.md
@@ -55,6 +55,8 @@ for multiclass problems. The default metric used to choose the best `ParamMap` c
method in each of these evaluators.
To help construct the parameter grid, users can use the [`ParamGridBuilder`](api/scala/index.html#org.apache.spark.ml.tuning.ParamGridBuilder) utility.
+By default, sets of parameters from the parameter grid are evaluated in serial. Parameter evaluation can be done in parallel by setting `parallelism` with a value of 2 or more (a value of 1 will be serial) before running model selection with `CrossValidator` or `TrainValidationSplit` (NOTE: this is not yet supported in Python).
+The value of `parallelism` should be chosen carefully to maximize parallelism without exceeding cluster resources, and larger values may not always lead to improved performance. Generally speaking, a value up to 10 should be sufficient for most clusters.
# Cross-Validation
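
For illustration, a minimal Scala sketch of parallel model selection as described above, assuming an existing `pipeline` estimator and `paramGrid`:

```scala
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.CrossValidator

// Evaluate up to 2 parameter settings in parallel; a value of 1 keeps the
// previous serial behavior.
val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(3)
  .setParallelism(2)
```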
diff --git a/docs/mllib-evaluation-metrics.md b/docs/mllib-evaluation-metrics.md
index ac82f43cfb79d..7f277543d2e9a 100644
--- a/docs/mllib-evaluation-metrics.md
+++ b/docs/mllib-evaluation-metrics.md
@@ -549,7 +549,7 @@ variable from a number of independent variables.
Mean Absolute Error (MAE) |
- $MAE=\sum_{i=0}^{N-1} \left|\mathbf{y}_i - \hat{\mathbf{y}}_i\right|$ |
+ $MAE=\frac{1}{N}\sum_{i=0}^{N-1} \left|\mathbf{y}_i - \hat{\mathbf{y}}_i\right|$ |
Coefficient of Determination $(R^2)$ |
diff --git a/docs/monitoring.md b/docs/monitoring.md
index 3e577c5f36778..51084a25983ea 100644
--- a/docs/monitoring.md
+++ b/docs/monitoring.md
@@ -61,6 +61,10 @@ The history server can be configured as follows:
SPARK_DAEMON_JAVA_OPTS |
JVM options for the history server (default: none). |
+
+ SPARK_DAEMON_CLASSPATH |
+ Classpath for the history server (default: none). |
+
SPARK_PUBLIC_DNS |
@@ -451,6 +455,7 @@ Each instance can report to zero or more _sinks_. Sinks are contained in the
* `MetricsServlet`: Adds a servlet within the existing Spark UI to serve metrics data as JSON data.
* `GraphiteSink`: Sends metrics to a Graphite node.
* `Slf4jSink`: Sends metrics to slf4j as log entries.
+* `StatsdSink`: Sends metrics to a StatsD node.
Spark also supports a Ganglia sink which is not included in the default build due to
licensing restrictions:
diff --git a/docs/rdd-programming-guide.md b/docs/rdd-programming-guide.md
index 26025984da64c..29af159510e46 100644
--- a/docs/rdd-programming-guide.md
+++ b/docs/rdd-programming-guide.md
@@ -604,7 +604,7 @@ before the `reduce`, which would cause `lineLengths` to be saved in memory after
Spark's API relies heavily on passing functions in the driver program to run on the cluster.
There are two recommended ways to do this:
-* [Anonymous function syntax](http://docs.scala-lang.org/tutorials/tour/anonymous-function-syntax.html),
+* [Anonymous function syntax](http://docs.scala-lang.org/tour/basics.html#functions),
which can be used for short pieces of code.
* Static methods in a global singleton object. For example, you can define `object MyFunctions` and then
pass `MyFunctions.func1`, as follows:
diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md
index ae3855084a650..e0944bc9f5f86 100644
--- a/docs/running-on-mesos.md
+++ b/docs/running-on-mesos.md
@@ -33,7 +33,8 @@ To get started, follow the steps below to install Mesos and deploy Spark jobs vi
# Installing Mesos
Spark {{site.SPARK_VERSION}} is designed for use with Mesos {{site.MESOS_VERSION}} or newer and does not
-require any special patches of Mesos.
+require any special patches of Mesos. File and environment-based secrets support requires Mesos 1.3.0 or
+newer.
If you already have a Mesos cluster running, you can skip this Mesos installation step.
@@ -160,6 +161,8 @@ If you like to run the `MesosClusterDispatcher` with Marathon, you need to run t
The `MesosClusterDispatcher` also supports writing recovery state into Zookeeper. This will allow the `MesosClusterDispatcher` to be able to recover all submitted and running containers on relaunch. In order to enable this recovery mode, you can set SPARK_DAEMON_JAVA_OPTS in spark-env by configuring `spark.deploy.recoveryMode` and related spark.deploy.zookeeper.* configurations.
For more information about these configurations please refer to the configurations [doc](configurations.html#deploy).
+You can also specify any additional jars required by the `MesosClusterDispatcher` in the classpath by setting the environment variable SPARK_DAEMON_CLASSPATH in spark-env.
+
From the client, you can submit a job to Mesos cluster by running `spark-submit` and specifying the master URL
to the URL of the `MesosClusterDispatcher` (e.g: mesos://dispatcher:7077). You can view driver statuses on the
Spark cluster Web UI.
@@ -428,7 +431,8 @@ See the [configuration page](configuration.html) for information on Spark config
| spark.mesos.secret |
(none) |
- Set the secret with which Spark framework will use to authenticate with Mesos.
+ Set the secret that the Spark framework will use to authenticate with Mesos. Used, for example, when
+ authenticating with the registry.
|
@@ -480,6 +484,43 @@ See the [configuration page](configuration.html) for information on Spark config
+
+ spark.mesos.driver.secret.envkeys |
+ (none) |
+
+ A comma-separated list of environment variable names. If set, the contents of the secrets
+ referenced by spark.mesos.driver.secret.names or spark.mesos.driver.secret.values will be
+ set in the corresponding environment variables in the driver's process.
+ |
+
+
+ spark.mesos.driver.secret.filenames |
+ (none) |
+
+ A comma-separated list of file paths. If set, the contents of the secrets referenced by
+ spark.mesos.driver.secret.names or spark.mesos.driver.secret.values will be
+ written to the corresponding files. Paths are relative to the container's work
+ directory. Absolute paths must already exist. Consult the Mesos Secret
+ protobuf for more information.
+ |
+
+
+ spark.mesos.driver.secret.names |
+ (none) |
+
+ A comma-separated list of secret references. Consult the Mesos Secret
+ protobuf for more information.
+ |
+
+
+ spark.mesos.driver.secret.values |
+ (none) |
+
+ A comma-separated list of secret values. Consult the Mesos Secret
+ protobuf for more information.
+ |
+
+
spark.mesos.driverEnv.[EnvironmentVariableName] |
(none) |
@@ -537,6 +578,20 @@ See the [configuration page](configuration.html) for information on Spark config
for more details.
+
+ spark.mesos.network.labels |
+ (none) |
+
+ Pass network labels to CNI plugins. This is a comma-separated list
+ of key-value pairs, where each key-value pair has the format key:value.
+ Example:
+
+ key1:val1,key2:val2
+ See
+ the Mesos CNI docs
+ for more details.
+ |
+
spark.mesos.fetcherCache.enable |
false |
diff --git a/docs/security.md b/docs/security.md
index 9eda42888637f..1d004003f9a32 100644
--- a/docs/security.md
+++ b/docs/security.md
@@ -73,6 +73,9 @@ For long-running apps like Spark Streaming apps to be able to write to HDFS, it
### Standalone mode
The user needs to provide key-stores and configuration options for master and workers. They have to be set by attaching appropriate Java system properties in `SPARK_MASTER_OPTS` and in `SPARK_WORKER_OPTS` environment variables, or just in `SPARK_DAEMON_JAVA_OPTS`. In this mode, the user may allow the executors to use the SSL settings inherited from the worker which spawned that executor. It can be accomplished by setting `spark.ssl.useNodeLocalConf` to `true`. If that parameter is set, the settings provided by user on the client side, are not used by the executors.
+### Mesos mode
+Mesos 1.3.0 and newer supports `Secrets` primitives as both file-based and environment-based secrets. Spark allows the specification of file-based and environment-variable-based secrets with the `spark.mesos.driver.secret.filenames` and `spark.mesos.driver.secret.envkeys` properties, respectively. Depending on the secret store backend, secrets can be passed by reference or by value with the `spark.mesos.driver.secret.names` and `spark.mesos.driver.secret.values` configuration properties, respectively. Reference-type secrets are served by the secret store and referred to by name, for example `/mysecret`. Value-type secrets are passed on the command line and translated into their appropriate files or environment variables.
+
### Preparing the key-stores
Key-stores can be generated by `keytool` program. The reference documentation for this tool is
[here](https://docs.oracle.com/javase/7/docs/technotes/tools/solaris/keytool.html). The most basic
diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md
index 642575b46dd42..1095386c31ab8 100644
--- a/docs/spark-standalone.md
+++ b/docs/spark-standalone.md
@@ -149,6 +149,10 @@ You can optionally configure the cluster further by setting environment variable
SPARK_DAEMON_JAVA_OPTS |
JVM options for the Spark master and worker daemons themselves in the form "-Dx=y" (default: none). |
+
+ SPARK_DAEMON_CLASSPATH |
+ Classpath for the Spark master and worker daemons themselves (default: none). |
+
SPARK_PUBLIC_DNS |
The public DNS name of the Spark master and workers (default: none). |
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index ee231a934a3af..5db60cc996e75 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -733,8 +733,9 @@ SELECT * FROM parquetTable
Table partitioning is a common optimization approach used in systems like Hive. In a partitioned
table, data are usually stored in different directories, with partitioning column values encoded in
-the path of each partition directory. The Parquet data source is now able to discover and infer
-partitioning information automatically. For example, we can store all our previously used
+the path of each partition directory. All built-in file sources (including Text/CSV/JSON/ORC/Parquet)
+are able to discover and infer partitioning information automatically.
+For example, we can store all our previously used
population data into a partitioned table using the following directory structure, with two extra
columns, `gender` and `country` as partitioning columns:
@@ -924,13 +925,6 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession
flag tells Spark SQL to interpret INT96 data as a timestamp to provide compatibility with these systems.
-
- spark.sql.parquet.cacheMetadata |
- true |
-
- Turns on caching of Parquet schema metadata. Can speed up querying of static data.
- |
-
spark.sql.parquet.compression.codec |
snappy |
@@ -1334,7 +1328,14 @@ the following case-insensitive options:
The database column data types to use instead of the defaults, when creating the table. Data type information should be specified in the same format as CREATE TABLE columns syntax (e.g: "name CHAR(64), comments VARCHAR(1024)") . The specified types should be valid spark sql data types. This option applies only to writing.
|
-
+
+
+
+ customSchema |
+
+ The custom schema to use for reading data from JDBC connectors. For example, "id DECIMAL(38, 0), name STRING" . You can also specify partial fields, and the others use the default type mapping. For example, "id DECIMAL(38, 0)" . The column names should be identical to the corresponding column names of JDBC table. Users can specify the corresponding data types of Spark SQL instead of using the defaults. This option applies only to reading.
+ |
+
@@ -1549,6 +1550,10 @@ options.
# Migration Guide
+## Upgrading From Spark SQL 2.2 to 2.3
+
+ - Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the referenced columns only include the internal corrupt record column (named `_corrupt_record` by default). For example, `spark.read.schema(schema).json(file).filter($"_corrupt_record".isNotNull).count()` and `spark.read.schema(schema).json(file).select("_corrupt_record").show()`. Instead, you can cache or save the parsed results and then send the same query. For example, `val df = spark.read.schema(schema).json(file).cache()` and then `df.filter($"_corrupt_record".isNotNull).count()`.
+
## Upgrading From Spark SQL 2.1 to 2.2
- Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access.
@@ -1587,6 +1592,9 @@ options.
Note that this is different from the Hive behavior.
- As a result, `DROP TABLE` statements on those tables will not remove the data.
+ - `spark.sql.parquet.cacheMetadata` is no longer used.
+ See [SPARK-13664](https://issues.apache.org/jira/browse/SPARK-13664) for details.
+
## Upgrading From Spark SQL 1.5 to 1.6
- From Spark 1.6, by default the Thrift server runs in multi-session mode. Which means each JDBC/ODBC
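
A minimal Scala sketch of the 2.3 corrupt-record workaround described in the migration note above, assuming a user-defined `schema` that includes `_corrupt_record` and an input path `file`:

```scala
import spark.implicits._

// Cache (or save) the parsed result first, then query the corrupt-record column.
val df = spark.read.schema(schema).json(file).cache()
df.filter($"_corrupt_record".isNotNull).count()
```

And a sketch of the new JDBC `customSchema` read option, with a hypothetical `jdbcUrl` and table name:

```scala
val jdbcDF = spark.read
  .format("jdbc")
  .option("url", jdbcUrl)
  .option("dbtable", "people")
  .option("customSchema", "id DECIMAL(38, 0), name STRING")
  .load()
```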
diff --git a/docs/streaming-kafka-0-8-integration.md b/docs/streaming-kafka-0-8-integration.md
index 24a3e4cdbbd7d..9f0671da2ee31 100644
--- a/docs/streaming-kafka-0-8-integration.md
+++ b/docs/streaming-kafka-0-8-integration.md
@@ -2,6 +2,9 @@
layout: global
title: Spark Streaming + Kafka Integration Guide (Kafka broker version 0.8.2.1 or higher)
---
+
+**Note: Kafka 0.8 support is deprecated as of Spark 2.3.0.**
+
Here we explain how to configure Spark Streaming to receive data from Kafka. There are two approaches to this - the old approach using Receivers and Kafka's high-level API, and a new approach (introduced in Spark 1.3) without using Receivers. They have different programming models, performance characteristics, and semantics guarantees, so read on for more details. Both approaches are considered stable APIs as of the current version of Spark.
## Approach 1: Receiver-based Approach
@@ -28,8 +31,7 @@ Next, we discuss how to use this approach in your streaming application.
val kafkaStream = KafkaUtils.createStream(streamingContext,
[ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume])
- You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$)
- and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala).
+ You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$).
import org.apache.spark.streaming.kafka.*;
@@ -38,8 +40,7 @@ Next, we discuss how to use this approach in your streaming application.
KafkaUtils.createStream(streamingContext,
[ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume]);
- You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html)
- and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java).
+ You can also specify the key and value classes and their corresponding decoder classes using variations of `createStream`. See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html).
@@ -48,8 +49,7 @@ Next, we discuss how to use this approach in your streaming application.
kafkaStream = KafkaUtils.createStream(streamingContext, \
[ZK quorum], [consumer group id], [per-topic number of Kafka partitions to consume])
- By default, the Python API will decode Kafka data as UTF8 encoded strings. You can specify your custom decoding function to decode the byte arrays in Kafka records to any arbitrary data type. See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils)
- and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/kafka_wordcount.py).
+ By default, the Python API will decode Kafka data as UTF8 encoded strings. You can specify your custom decoding function to decode the byte arrays in Kafka records to any arbitrary data type. See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils).
@@ -71,7 +71,7 @@ Next, we discuss how to use this approach in your streaming application.
./bin/spark-submit --packages org.apache.spark:spark-streaming-kafka-0-8_{{site.SCALA_BINARY_VERSION}}:{{site.SPARK_VERSION_SHORT}} ...
Alternatively, you can also download the JAR of the Maven artifact `spark-streaming-kafka-0-8-assembly` from the
- [Maven repository](http://search.maven.org/#search|ga|1|a%3A%22spark-streaming-kafka-0-8-assembly_{{site.SCALA_BINARY_VERSION}}%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) and add it to `spark-submit` with `--jars`.
+ [Maven repository](https://search.maven.org/#search|ga|1|a%3A%22spark-streaming-kafka-0-8-assembly_{{site.SCALA_BINARY_VERSION}}%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) and add it to `spark-submit` with `--jars`.
## Approach 2: Direct Approach (No Receivers)
This new receiver-less "direct" approach has been introduced in Spark 1.3 to ensure stronger end-to-end guarantees. Instead of using receivers to receive data, this approach periodically queries Kafka for the latest offsets in each topic+partition, and accordingly defines the offset ranges to process in each batch. When the jobs to process the data are launched, Kafka's simple consumer API is used to read the defined ranges of offsets from Kafka (similar to read files from a file system). Note that this feature was introduced in Spark 1.3 for the Scala and Java API, in Spark 1.4 for the Python API.
@@ -105,8 +105,7 @@ Next, we discuss how to use this approach in your streaming application.
streamingContext, [map of Kafka parameters], [set of topics to consume])
You can also pass a `messageHandler` to `createDirectStream` to access `MessageAndMetadata` that contains metadata about the current message and transform it to any desired type.
- See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$)
- and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala).
+ See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$).
import org.apache.spark.streaming.kafka.*;
@@ -117,8 +116,7 @@ Next, we discuss how to use this approach in your streaming application.
[map of Kafka parameters], [set of topics to consume]);
You can also pass a `messageHandler` to `createDirectStream` to access `MessageAndMetadata` that contains metadata about the current message and transform it to any desired type.
- See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html)
- and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java).
+ See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html).
@@ -126,8 +124,7 @@ Next, we discuss how to use this approach in your streaming application.
directKafkaStream = KafkaUtils.createDirectStream(ssc, [topic], {"metadata.broker.list": brokers})
You can also pass a `messageHandler` to `createDirectStream` to access `KafkaMessageAndMetadata` that contains metadata about the current message and transform it to any desired type.
- By default, the Python API will decode Kafka data as UTF8 encoded strings. You can specify your custom decoding function to decode the byte arrays in Kafka records to any arbitrary data type. See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils)
- and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/direct_kafka_wordcount.py).
+ By default, the Python API will decode Kafka data as UTF8 encoded strings. You can specify your custom decoding function to decode the byte arrays in Kafka records to any arbitrary data type. See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils).
diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md
index a8f3667a49850..4aca391e4ba1a 100644
--- a/docs/streaming-kafka-integration.md
+++ b/docs/streaming-kafka-integration.md
@@ -3,10 +3,11 @@ layout: global
title: Spark Streaming + Kafka Integration Guide
---
-[Apache Kafka](http://kafka.apache.org/) is publish-subscribe messaging rethought as a distributed, partitioned, replicated commit log service. Please read the [Kafka documentation](http://kafka.apache.org/documentation.html) thoroughly before starting an integration using Spark.
+[Apache Kafka](https://kafka.apache.org/) is publish-subscribe messaging rethought as a distributed, partitioned, replicated commit log service. Please read the [Kafka documentation](https://kafka.apache.org/documentation.html) thoroughly before starting an integration using Spark.
-The Kafka project introduced a new consumer api between versions 0.8 and 0.10, so there are 2 separate corresponding Spark Streaming packages available. Please choose the correct package for your brokers and desired features; note that the 0.8 integration is compatible with later 0.9 and 0.10 brokers, but the 0.10 integration is not compatible with earlier brokers.
+The Kafka project introduced a new consumer API between versions 0.8 and 0.10, so there are 2 separate corresponding Spark Streaming packages available. Please choose the correct package for your brokers and desired features; note that the 0.8 integration is compatible with later 0.9 and 0.10 brokers, but the 0.10 integration is not compatible with earlier brokers.
+**Note: Kafka 0.8 support is deprecated as of Spark 2.3.0.**
| spark-streaming-kafka-0-8 | spark-streaming-kafka-0-10 |
@@ -16,9 +17,9 @@ The Kafka project introduced a new consumer api between versions 0.8 and 0.10, s
0.10.0 or higher |
- Api Stability |
+ API Maturity |
+ Deprecated |
Stable |
- Experimental |
Language Support |
@@ -41,7 +42,7 @@ The Kafka project introduced a new consumer api between versions 0.8 and 0.10, s
Yes |
- Offset Commit Api |
+ Offset Commit API |
No |
Yes |
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index fca0cf8ff05f2..bc200cd07ebd8 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -401,14 +401,14 @@ some of the common ones are as follows.
Source | Artifact |
- Kafka | spark-streaming-kafka-0-8_{{site.SCALA_BINARY_VERSION}} |
+ Kafka | spark-streaming-kafka-0-10_{{site.SCALA_BINARY_VERSION}} |
Flume | spark-streaming-flume_{{site.SCALA_BINARY_VERSION}} |
Kinesis
| spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}} [Amazon Software License] |
| |
For an up-to-date list, please refer to the
-[Maven repository](http://search.maven.org/#search%7Cga%7C1%7Cg%3A%22org.apache.spark%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22)
+[Maven repository](https://search.maven.org/#search%7Cga%7C1%7Cg%3A%22org.apache.spark%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22)
for the full list of supported sources and artifacts.
***
@@ -1899,7 +1899,7 @@ To run a Spark Streaming applications, you need to have the following.
if your application uses [advanced sources](#advanced-sources) (e.g. Kafka, Flume),
then you will have to package the extra artifact they link to, along with their dependencies,
in the JAR that is used to deploy the application. For example, an application using `KafkaUtils`
- will have to include `spark-streaming-kafka-0-8_{{site.SCALA_BINARY_VERSION}}` and all its
+ will have to include `spark-streaming-kafka-0-10_{{site.SCALA_BINARY_VERSION}}` and all its
transitive dependencies in the application JAR.
- *Configuring sufficient memory for the executors* - Since the received data must be stored in
diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md
index 8367f5a08c755..93bef8d5bb7e2 100644
--- a/docs/structured-streaming-programming-guide.md
+++ b/docs/structured-streaming-programming-guide.md
@@ -977,7 +977,7 @@ at the beginning of every trigger is the red line For example, when the engine
`(12:14, dog)`, it sets the watermark for the next trigger as `12:04`.
This watermark lets the engine maintain intermediate state for additional 10 minutes to allow late
data to be counted. For example, the data `(12:09, cat)` is out of order and late, and it falls in
-windows `12:05 - 12:15` and `12:10 - 12:20`. Since, it is still ahead of the watermark `12:04` in
+windows `12:00 - 12:10` and `12:05 - 12:15`. Since, it is still ahead of the watermark `12:04` in
the trigger, the engine still maintains the intermediate counts as state and correctly updates the
counts of the related windows. However, when the watermark is updated to `12:11`, the intermediate
state for window `(12:00 - 12:10)` is cleared, and all subsequent data (e.g. `(12:04, donkey)`)
@@ -1168,7 +1168,7 @@ returned through `Dataset.writeStream()`. You will have to specify one or more o
- *Query name:* Optionally, specify a unique name of the query for identification.
-- *Trigger interval:* Optionally, specify the trigger interval. If it is not specified, the system will check for availability of new data as soon as the previous processing has completed. If a trigger time is missed because the previous processing has not completed, then the system will attempt to trigger at the next trigger point, not immediately after the processing has completed.
+- *Trigger interval:* Optionally, specify the trigger interval. If it is not specified, the system will check for availability of new data as soon as the previous processing has completed. If a trigger time is missed because the previous processing has not completed, then the system will trigger processing immediately.
- *Checkpoint location:* For some output sinks where the end-to-end fault-tolerance can be guaranteed, specify the location where the system will write all the checkpoint information. This should be a directory in an HDFS-compatible fault-tolerant file system. The semantics of checkpointing is discussed in more detail in the next section.
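
To make the corrected window assignment above concrete, a minimal Scala sketch of a windowed count with a 10-minute watermark, assuming a streaming DataFrame `events` with `timestamp` and `word` columns:

```scala
import spark.implicits._
import org.apache.spark.sql.functions.window

// 10-minute windows sliding every 5 minutes; a late event at 12:09 lands in the
// 12:00 - 12:10 and 12:05 - 12:15 windows and is still counted while it stays
// ahead of the watermark.
val windowedCounts = events
  .withWatermark("timestamp", "10 minutes")
  .groupBy(window($"timestamp", "10 minutes", "5 minutes"), $"word")
  .count()
```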
diff --git a/examples/pom.xml b/examples/pom.xml
index e1b2e7bc38cc6..aca82abe601fb 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -86,7 +86,7 @@
org.apache.spark
- spark-streaming-kafka-0-8_${scala.binary.version}
+ spark-streaming-kafka-0-10_${scala.binary.version}
${project.version}
provided
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaFeatureHasherExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaFeatureHasherExample.java
new file mode 100644
index 0000000000000..9730d42e6db8d
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaFeatureHasherExample.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.SparkSession;
+
+// $example on$
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.spark.ml.feature.FeatureHasher;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+// $example off$
+
+public class JavaFeatureHasherExample {
+ public static void main(String[] args) {
+ SparkSession spark = SparkSession
+ .builder()
+ .appName("JavaFeatureHasherExample")
+ .getOrCreate();
+
+ // $example on$
+ List<Row> data = Arrays.asList(
+ RowFactory.create(2.2, true, "1", "foo"),
+ RowFactory.create(3.3, false, "2", "bar"),
+ RowFactory.create(4.4, false, "3", "baz"),
+ RowFactory.create(5.5, false, "4", "foo")
+ );
+ StructType schema = new StructType(new StructField[]{
+ new StructField("real", DataTypes.DoubleType, false, Metadata.empty()),
+ new StructField("bool", DataTypes.BooleanType, false, Metadata.empty()),
+ new StructField("stringNum", DataTypes.StringType, false, Metadata.empty()),
+ new StructField("string", DataTypes.StringType, false, Metadata.empty())
+ });
+ Dataset<Row> dataset = spark.createDataFrame(data, schema);
+
+ FeatureHasher hasher = new FeatureHasher()
+ .setInputCols(new String[]{"real", "bool", "stringNum", "string"})
+ .setOutputCol("features");
+
+ Dataset<Row> featurized = hasher.transform(dataset);
+
+ featurized.show(false);
+ // $example off$
+
+ spark.stop();
+ }
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java
index 975c65edc0ca6..d97327969ab26 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaCrossValidationExample.java
@@ -94,7 +94,9 @@ public static void main(String[] args) {
CrossValidator cv = new CrossValidator()
.setEstimator(pipeline)
.setEvaluator(new BinaryClassificationEvaluator())
- .setEstimatorParamMaps(paramGrid).setNumFolds(2); // Use 3+ in practice
+ .setEstimatorParamMaps(paramGrid)
+ .setNumFolds(2) // Use 3+ in practice
+ .setParallelism(2); // Evaluate up to 2 parameter settings in parallel
// Run cross-validation, and choose the best set of parameters.
CrossValidatorModel cvModel = cv.fit(training);
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java
index 9a4722b90cf1b..2ef8bea0b2a2b 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java
@@ -70,7 +70,8 @@ public static void main(String[] args) {
.setEstimator(lr)
.setEvaluator(new RegressionEvaluator())
.setEstimatorParamMaps(paramGrid)
- .setTrainRatio(0.8); // 80% for training and the remaining 20% for validation
+ .setTrainRatio(0.8) // 80% for training and the remaining 20% for validation
+ .setParallelism(2); // Evaluate up to 2 parameter settings in parallel
// Run train validation split, and choose the best set of parameters.
TrainValidationSplitModel model = trainValidationSplit.fit(training);
diff --git a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java
index 6b8e6554f1bb1..943e3d82f30ff 100644
--- a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java
+++ b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java
@@ -69,7 +69,7 @@ public static void main(String[] args) throws Exception {
FlatMapFunction<LineWithTimestamp, Event> linesToEvents =
new FlatMapFunction<LineWithTimestamp, Event>() {
@Override
- public Iterator<Event> call(LineWithTimestamp lineWithTimestamp) throws Exception {
+ public Iterator<Event> call(LineWithTimestamp lineWithTimestamp) {
ArrayList<Event> eventList = new ArrayList<>();
for (String word : lineWithTimestamp.getLine().split(" ")) {
eventList.add(new Event(word, lineWithTimestamp.getTimestamp()));
@@ -91,8 +91,7 @@ public Iterator call(LineWithTimestamp lineWithTimestamp) throws Exceptio
MapGroupsWithStateFunction<String, Event, SessionInfo, SessionUpdate> stateUpdateFunc =
new MapGroupsWithStateFunction<String, Event, SessionInfo, SessionUpdate>() {
@Override public SessionUpdate call(
- String sessionId, Iterator<Event> events, GroupState<SessionInfo> state)
- throws Exception {
+ String sessionId, Iterator<Event> events, GroupState<SessionInfo> state) {
// If timed out, then remove session and send final update
if (state.hasTimedOut()) {
SessionUpdate finalUpdate = new SessionUpdate(
@@ -138,7 +137,7 @@ public Iterator call(LineWithTimestamp lineWithTimestamp) throws Exceptio
Dataset<SessionUpdate> sessionUpdates = events
.groupByKey(
new MapFunction<Event, String>() {
- @Override public String call(Event event) throws Exception {
+ @Override public String call(Event event) {
return event.getSessionId();
}
}, Encoders.STRING())
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java
index 5e5ae6213d5d9..b6b163fa8b2cd 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java
@@ -26,11 +26,13 @@
import scala.Tuple2;
-import kafka.serializer.StringDecoder;
+import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.spark.SparkConf;
import org.apache.spark.streaming.api.java.*;
-import org.apache.spark.streaming.kafka.KafkaUtils;
+import org.apache.spark.streaming.kafka010.ConsumerStrategies;
+import org.apache.spark.streaming.kafka010.KafkaUtils;
+import org.apache.spark.streaming.kafka010.LocationStrategies;
import org.apache.spark.streaming.Durations;
/**
@@ -65,22 +67,17 @@ public static void main(String[] args) throws Exception {
JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, Durations.seconds(2));
Set<String> topicsSet = new HashSet<>(Arrays.asList(topics.split(",")));
- Map<String, String> kafkaParams = new HashMap<>();
+ Map<String, Object> kafkaParams = new HashMap<>();
kafkaParams.put("metadata.broker.list", brokers);
// Create direct kafka stream with brokers and topics
- JavaPairInputDStream<String, String> messages = KafkaUtils.createDirectStream(
+ JavaInputDStream<ConsumerRecord<String, String>> messages = KafkaUtils.createDirectStream(
jssc,
- String.class,
- String.class,
- StringDecoder.class,
- StringDecoder.class,
- kafkaParams,
- topicsSet
- );
+ LocationStrategies.PreferConsistent(),
+ ConsumerStrategies.Subscribe(topicsSet, kafkaParams));
// Get the lines, split them into words, count the words and print
- JavaDStream<String> lines = messages.map(Tuple2::_2);
+ JavaDStream<String> lines = messages.map(ConsumerRecord::value);
JavaDStream<String> words = lines.flatMap(x -> Arrays.asList(SPACE.split(x)).iterator());
JavaPairDStream<String, Integer> wordCounts = words.mapToPair(s -> new Tuple2<>(s, 1))
.reduceByKey((i1, i2) -> i1 + i2);
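
For reference, a minimal Scala sketch of the equivalent kafka-0-10 direct stream setup, assuming an existing `StreamingContext` named `ssc` plus `brokers` and `topics` values; the 0-10 integration takes new-consumer properties rather than `metadata.broker.list`:

```scala
import org.apache.kafka.common.serialization.StringDeserializer
import org.apache.spark.streaming.kafka010.ConsumerStrategies.Subscribe
import org.apache.spark.streaming.kafka010.KafkaUtils
import org.apache.spark.streaming.kafka010.LocationStrategies.PreferConsistent

// Consumer configuration for the new consumer API.
val kafkaParams = Map[String, Object](
  "bootstrap.servers" -> brokers,
  "key.deserializer" -> classOf[StringDeserializer],
  "value.deserializer" -> classOf[StringDeserializer],
  "group.id" -> "example-group")

// Direct stream over the given topics; each record exposes key() and value().
val stream = KafkaUtils.createDirectStream[String, String](
  ssc, PreferConsistent, Subscribe[String, String](topics, kafkaParams))
val lines = stream.map(_.value)
```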
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java
deleted file mode 100644
index ce5acdca92666..0000000000000
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java
+++ /dev/null
@@ -1,87 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.examples.streaming;
-
-import java.util.Arrays;
-import java.util.Map;
-import java.util.HashMap;
-import java.util.regex.Pattern;
-
-import scala.Tuple2;
-
-import org.apache.spark.SparkConf;
-import org.apache.spark.streaming.Duration;
-import org.apache.spark.streaming.api.java.JavaDStream;
-import org.apache.spark.streaming.api.java.JavaPairDStream;
-import org.apache.spark.streaming.api.java.JavaPairReceiverInputDStream;
-import org.apache.spark.streaming.api.java.JavaStreamingContext;
-import org.apache.spark.streaming.kafka.KafkaUtils;
-
-/**
- * Consumes messages from one or more topics in Kafka and does wordcount.
- *
- * Usage: JavaKafkaWordCount <zkQuorum> <group> <topics> <numThreads>
- *   <zkQuorum> is a list of one or more zookeeper servers that make quorum
- *   <group> is the name of kafka consumer group
- *   <topics> is a list of one or more kafka topics to consume from
- *   <numThreads> is the number of threads the kafka consumer should use
- *
- * To run this example:
- * `$ bin/run-example org.apache.spark.examples.streaming.JavaKafkaWordCount zoo01,zoo02, \
- * zoo03 my-consumer-group topic1,topic2 1`
- */
-
-public final class JavaKafkaWordCount {
- private static final Pattern SPACE = Pattern.compile(" ");
-
- private JavaKafkaWordCount() {
- }
-
- public static void main(String[] args) throws Exception {
- if (args.length < 4) {
- System.err.println("Usage: JavaKafkaWordCount ");
- System.exit(1);
- }
-
- StreamingExamples.setStreamingLogLevels();
- SparkConf sparkConf = new SparkConf().setAppName("JavaKafkaWordCount");
- // Create the context with 2 seconds batch size
- JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, new Duration(2000));
-
- int numThreads = Integer.parseInt(args[3]);
- Map<String, Integer> topicMap = new HashMap<>();
- String[] topics = args[2].split(",");
- for (String topic: topics) {
- topicMap.put(topic, numThreads);
- }
-
- JavaPairReceiverInputDStream<String, String> messages =
- KafkaUtils.createStream(jssc, args[0], args[1], topicMap);
-
- JavaDStream<String> lines = messages.map(Tuple2::_2);
-
- JavaDStream<String> words = lines.flatMap(x -> Arrays.asList(SPACE.split(x)).iterator());
-
- JavaPairDStream<String, Integer> wordCounts = words.mapToPair(s -> new Tuple2<>(s, 1))
- .reduceByKey((i1, i2) -> i1 + i2);
-
- wordCounts.print();
- jssc.start();
- jssc.awaitTermination();
- }
-}
diff --git a/examples/src/main/python/ml/feature_hasher_example.py b/examples/src/main/python/ml/feature_hasher_example.py
new file mode 100644
index 0000000000000..6cf9ecc396400
--- /dev/null
+++ b/examples/src/main/python/ml/feature_hasher_example.py
@@ -0,0 +1,46 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import print_function
+
+from pyspark.sql import SparkSession
+# $example on$
+from pyspark.ml.feature import FeatureHasher
+# $example off$
+
+if __name__ == "__main__":
+ spark = SparkSession\
+ .builder\
+ .appName("FeatureHasherExample")\
+ .getOrCreate()
+
+ # $example on$
+ dataset = spark.createDataFrame([
+ (2.2, True, "1", "foo"),
+ (3.3, False, "2", "bar"),
+ (4.4, False, "3", "baz"),
+ (5.5, False, "4", "foo")
+ ], ["real", "bool", "stringNum", "string"])
+
+ hasher = FeatureHasher(inputCols=["real", "bool", "stringNum", "string"],
+ outputCol="features")
+
+ featurized = hasher.transform(dataset)
+ featurized.show(truncate=False)
+ # $example off$
+
+ spark.stop()
diff --git a/examples/src/main/python/sql/datasource.py b/examples/src/main/python/sql/datasource.py
index 8777cca66bfe9..f86012ea382e8 100644
--- a/examples/src/main/python/sql/datasource.py
+++ b/examples/src/main/python/sql/datasource.py
@@ -177,6 +177,16 @@ def jdbc_dataset_example(spark):
.jdbc("jdbc:postgresql:dbserver", "schema.tablename",
properties={"user": "username", "password": "password"})
+ # Specifying dataframe column data types on read
+ jdbcDF3 = spark.read \
+ .format("jdbc") \
+ .option("url", "jdbc:postgresql:dbserver") \
+ .option("dbtable", "schema.tablename") \
+ .option("user", "username") \
+ .option("password", "password") \
+ .option("customSchema", "id DECIMAL(38, 0), name STRING") \
+ .load()
+
# Saving data to a JDBC source
jdbcDF.write \
.format("jdbc") \
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/FeatureHasherExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/FeatureHasherExample.scala
new file mode 100644
index 0000000000000..1aed10bfb2d38
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/FeatureHasherExample.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+// $example on$
+import org.apache.spark.ml.feature.FeatureHasher
+// $example off$
+import org.apache.spark.sql.SparkSession
+
+object FeatureHasherExample {
+ def main(args: Array[String]): Unit = {
+ val spark = SparkSession
+ .builder
+ .appName("FeatureHasherExample")
+ .getOrCreate()
+
+ // $example on$
+ val dataset = spark.createDataFrame(Seq(
+ (2.2, true, "1", "foo"),
+ (3.3, false, "2", "bar"),
+ (4.4, false, "3", "baz"),
+ (5.5, false, "4", "foo")
+ )).toDF("real", "bool", "stringNum", "string")
+
+ val hasher = new FeatureHasher()
+ .setInputCols("real", "bool", "stringNum", "string")
+ .setOutputCol("features")
+
+ val featurized = hasher.transform(dataset)
+ featurized.show(false)
+ // $example off$
+
+ spark.stop()
+ }
+}
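As a rough intuition for what the new FeatureHasher examples produce, here is a sketch of the hashing trick the transformer is built on. The index derivation, seed, and vector width below are illustrative assumptions, not the library's exact implementation or defaults.

    import scala.util.hashing.MurmurHash3

    object HashingTrickSketch {
      // Map each feature key to a column index in a fixed-width sparse vector; hash
      // collisions are accepted as the price of never building a feature dictionary.
      def hashIndex(featureKey: String, numFeatures: Int): Int = {
        val h = MurmurHash3.stringHash(featureKey, 42)       // illustrative seed
        ((h % numFeatures) + numFeatures) % numFeatures      // non-negative modulo
      }

      def main(args: Array[String]): Unit = {
        val numFeatures = 1 << 10  // small width for the sketch; the real default is far larger
        // Categorical inputs contribute a "column=value" key with weight 1.0, while
        // numeric inputs contribute the column name with the value itself as the weight.
        println(hashIndex("string=foo", numFeatures))
        println(hashIndex("real", numFeatures))
      }
    }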
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala
index c1ff9ef521706..87d96dd51eb94 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala
@@ -93,6 +93,7 @@ object ModelSelectionViaCrossValidationExample {
.setEvaluator(new BinaryClassificationEvaluator)
.setEstimatorParamMaps(paramGrid)
.setNumFolds(2) // Use 3+ in practice
+ .setParallelism(2) // Evaluate up to 2 parameter settings in parallel
// Run cross-validation, and choose the best set of parameters.
val cvModel = cv.fit(training)
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala
index 1cd2641f9a8d0..71e41e7298c73 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala
@@ -65,6 +65,8 @@ object ModelSelectionViaTrainValidationSplitExample {
.setEstimatorParamMaps(paramGrid)
// 80% of the data will be used for training and the remaining 20% for validation.
.setTrainRatio(0.8)
+ // Evaluate up to 2 parameter settings in parallel
+ .setParallelism(2)
// Run train validation split, and choose the best set of parameters.
val model = trainValidationSplit.fit(training)
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala
index 6ff03bdb22129..86b3dc4a84f58 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala
@@ -185,6 +185,10 @@ object SQLDataSourceExample {
connectionProperties.put("password", "password")
val jdbcDF2 = spark.read
.jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties)
+ // Specifying the custom data types of the read schema
+ connectionProperties.put("customSchema", "id DECIMAL(38, 0), name STRING")
+ val jdbcDF3 = spark.read
+ .jdbc("jdbc:postgresql:dbserver", "schema.tablename", connectionProperties)
// Saving data to a JDBC source
jdbcDF.write
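The same `customSchema` override can also be expressed through the option-based reader, mirroring the Python snippet above. A sketch assuming a reachable PostgreSQL instance; the URL, table and credentials are placeholders.

    import org.apache.spark.sql.SparkSession

    object JdbcCustomSchemaSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder.appName("JdbcCustomSchemaSketch").getOrCreate()
        val jdbcDF = spark.read
          .format("jdbc")
          .option("url", "jdbc:postgresql:dbserver")   // placeholder connection URL
          .option("dbtable", "schema.tablename")       // placeholder table
          .option("user", "username")
          .option("password", "password")
          // Only the listed columns are overridden; the rest keep their mapped types.
          .option("customSchema", "id DECIMAL(38, 0), name STRING")
          .load()
        jdbcDF.printSchema()
        spark.stop()
      }
    }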
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala
index 474b03aa24a5d..def06026bde96 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala
@@ -18,11 +18,9 @@
// scalastyle:off println
package org.apache.spark.examples.streaming
-import kafka.serializer.StringDecoder
-
import org.apache.spark.SparkConf
import org.apache.spark.streaming._
-import org.apache.spark.streaming.kafka._
+import org.apache.spark.streaming.kafka010._
/**
* Consumes messages from one or more topics in Kafka and does wordcount.
@@ -57,11 +55,13 @@ object DirectKafkaWordCount {
// Create direct kafka stream with brokers and topics
val topicsSet = topics.split(",").toSet
val kafkaParams = Map[String, String]("metadata.broker.list" -> brokers)
- val messages = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
- ssc, kafkaParams, topicsSet)
+ val messages = KafkaUtils.createDirectStream[String, String](
+ ssc,
+ LocationStrategies.PreferConsistent,
+ ConsumerStrategies.Subscribe[String, String](topicsSet, kafkaParams))
// Get the lines, split them into words, count the words and print
- val lines = messages.map(_._2)
+ val lines = messages.map(_.value)
val words = lines.flatMap(_.split(" "))
val wordCounts = words.map(x => (x, 1L)).reduceByKey(_ + _)
wordCounts.print()
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala
deleted file mode 100644
index e7f9bf36e35cf..0000000000000
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala
+++ /dev/null
@@ -1,105 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-// scalastyle:off println
-package org.apache.spark.examples.streaming
-
-import java.util.HashMap
-
-import org.apache.kafka.clients.producer.{KafkaProducer, ProducerConfig, ProducerRecord}
-
-import org.apache.spark.SparkConf
-import org.apache.spark.streaming._
-import org.apache.spark.streaming.kafka._
-
-/**
- * Consumes messages from one or more topics in Kafka and does wordcount.
- * Usage: KafkaWordCount <zkQuorum> <group> <topics> <numThreads>
- *   <zkQuorum> is a list of one or more zookeeper servers that make quorum
- *   <group> is the name of kafka consumer group
- *   <topics> is a list of one or more kafka topics to consume from
- *   <numThreads> is the number of threads the kafka consumer should use
- *
- * Example:
- * `$ bin/run-example \
- * org.apache.spark.examples.streaming.KafkaWordCount zoo01,zoo02,zoo03 \
- * my-consumer-group topic1,topic2 1`
- */
-object KafkaWordCount {
- def main(args: Array[String]) {
- if (args.length < 4) {
- System.err.println("Usage: KafkaWordCount ")
- System.exit(1)
- }
-
- StreamingExamples.setStreamingLogLevels()
-
- val Array(zkQuorum, group, topics, numThreads) = args
- val sparkConf = new SparkConf().setAppName("KafkaWordCount")
- val ssc = new StreamingContext(sparkConf, Seconds(2))
- ssc.checkpoint("checkpoint")
-
- val topicMap = topics.split(",").map((_, numThreads.toInt)).toMap
- val lines = KafkaUtils.createStream(ssc, zkQuorum, group, topicMap).map(_._2)
- val words = lines.flatMap(_.split(" "))
- val wordCounts = words.map(x => (x, 1L))
- .reduceByKeyAndWindow(_ + _, _ - _, Minutes(10), Seconds(2), 2)
- wordCounts.print()
-
- ssc.start()
- ssc.awaitTermination()
- }
-}
-
-// Produces some random words between 1 and 100.
-object KafkaWordCountProducer {
-
- def main(args: Array[String]) {
- if (args.length < 4) {
- System.err.println("Usage: KafkaWordCountProducer " +
- " ")
- System.exit(1)
- }
-
- val Array(brokers, topic, messagesPerSec, wordsPerMessage) = args
-
- // Zookeeper connection properties
- val props = new HashMap[String, Object]()
- props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, brokers)
- props.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG,
- "org.apache.kafka.common.serialization.StringSerializer")
- props.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG,
- "org.apache.kafka.common.serialization.StringSerializer")
-
- val producer = new KafkaProducer[String, String](props)
-
- // Send some messages
- while(true) {
- (1 to messagesPerSec.toInt).foreach { messageNum =>
- val str = (1 to wordsPerMessage.toInt).map(x => scala.util.Random.nextInt(10).toString)
- .mkString(" ")
-
- val message = new ProducerRecord[String, String](topic, null, str)
- producer.send(message)
- }
-
- Thread.sleep(1000)
- }
- }
-
-}
-// scalastyle:on println
diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml
index 0fa87a697454b..485b562dce990 100644
--- a/external/docker-integration-tests/pom.xml
+++ b/external/docker-integration-tests/pom.xml
@@ -80,6 +80,13 @@
 <type>test-jar</type>
 <scope>test</scope>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
 <groupId>org.apache.spark</groupId>
 <artifactId>spark-sql_${scala.binary.version}</artifactId>
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
index e14810a32edc6..7680ae3835132 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala
@@ -21,7 +21,8 @@ import java.sql.{Connection, Date, Timestamp}
import java.util.Properties
import java.math.BigDecimal
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.execution.{WholeStageCodegenExec, RowDataSourceScanExec}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.tags.DockerTest
@@ -71,10 +72,17 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo
""".stripMargin.replaceAll("\n", " ")).executeUpdate()
conn.commit()
- conn.prepareStatement("CREATE TABLE ts_with_timezone (id NUMBER(10), t TIMESTAMP WITH TIME ZONE)")
- .executeUpdate()
- conn.prepareStatement("INSERT INTO ts_with_timezone VALUES (1, to_timestamp_tz('1999-12-01 11:00:00 UTC','YYYY-MM-DD HH:MI:SS TZR'))")
- .executeUpdate()
+ conn.prepareStatement(
+ "CREATE TABLE ts_with_timezone (id NUMBER(10), t TIMESTAMP WITH TIME ZONE)").executeUpdate()
+ conn.prepareStatement(
+ "INSERT INTO ts_with_timezone VALUES " +
+ "(1, to_timestamp_tz('1999-12-01 11:00:00 UTC','YYYY-MM-DD HH:MI:SS TZR'))").executeUpdate()
+ conn.commit()
+
+ conn.prepareStatement(
+ "CREATE TABLE tableWithCustomSchema (id NUMBER, n1 NUMBER(1), n2 NUMBER(1))").executeUpdate()
+ conn.prepareStatement(
+ "INSERT INTO tableWithCustomSchema values(12312321321321312312312312123, 1, 0)").executeUpdate()
conn.commit()
sql(
@@ -103,7 +111,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo
}
- test("SPARK-16625 : Importing Oracle numeric types") {
+ test("SPARK-16625 : Importing Oracle numeric types") {
val df = sqlContext.read.jdbc(jdbcUrl, "numerics", new Properties);
val rows = df.collect()
assert(rows.size == 1)
@@ -255,12 +263,15 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo
val df = dfRead.filter(dfRead.col("date_type").lt(dt))
.filter(dfRead.col("timestamp_type").lt(ts))
- val metadata = df.queryExecution.sparkPlan.metadata
- // The "PushedFilters" part should be exist in Datafrome's
+ val parentPlan = df.queryExecution.executedPlan
+ assert(parentPlan.isInstanceOf[WholeStageCodegenExec])
+ val node = parentPlan.asInstanceOf[WholeStageCodegenExec]
+ val metadata = node.child.asInstanceOf[RowDataSourceScanExec].metadata
+ // The "PushedFilters" part should exist in Dataframe's
// physical plan and the existence of right literals in
// "PushedFilters" is used to prove that the predicates
// pushing down have been effective.
- assert(metadata.get("PushedFilters").ne(None))
+ assert(metadata.get("PushedFilters").isDefined)
assert(metadata("PushedFilters").contains(dt.toString))
assert(metadata("PushedFilters").contains(ts.toString))
@@ -268,4 +279,32 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo
assert(row.getDate(0).equals(dateVal))
assert(row.getTimestamp(1).equals(timestampVal))
}
+
+ test("SPARK-20427/SPARK-20921: read table use custom schema by jdbc api") {
+ // default will throw IllegalArgumentException
+ val e = intercept[org.apache.spark.SparkException] {
+ spark.read.jdbc(jdbcUrl, "tableWithCustomSchema", new Properties()).collect()
+ }
+ assert(e.getMessage.contains(
+ "requirement failed: Decimal precision 39 exceeds max precision 38"))
+
+ // custom schema can read data
+ val props = new Properties()
+ props.put("customSchema",
+ s"ID DECIMAL(${DecimalType.MAX_PRECISION}, 0), N1 INT, N2 BOOLEAN")
+ val dfRead = spark.read.jdbc(jdbcUrl, "tableWithCustomSchema", props)
+
+ val rows = dfRead.collect()
+ // verify the data type
+ val types = rows(0).toSeq.map(x => x.getClass.toString)
+ assert(types(0).equals("class java.math.BigDecimal"))
+ assert(types(1).equals("class java.lang.Integer"))
+ assert(types(2).equals("class java.lang.Boolean"))
+
+ // verify the value
+ val values = rows(0)
+ assert(values.getDecimal(0).equals(new java.math.BigDecimal("12312321321321312312312312123")))
+ assert(values.getInt(1).equals(1))
+ assert(values.getBoolean(2).equals(false))
+ }
}
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala
index 7c4f38e02fb2a..90ed7b1fba2f8 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala
@@ -112,9 +112,15 @@ private[kafka010] case class CachedKafkaConsumer private(
// we will move to the next available offset within `[offset, untilOffset)` and retry.
// If `failOnDataLoss` is `true`, the loop body will be executed only once.
var toFetchOffset = offset
- while (toFetchOffset != UNKNOWN_OFFSET) {
+ var consumerRecord: ConsumerRecord[Array[Byte], Array[Byte]] = null
+ // We want to break out of the while loop on a successful fetch to avoid using "return"
+ // which may cause a NonLocalReturnControl exception when this method is used as a function.
+ var isFetchComplete = false
+
+ while (toFetchOffset != UNKNOWN_OFFSET && !isFetchComplete) {
try {
- return fetchData(toFetchOffset, untilOffset, pollTimeoutMs, failOnDataLoss)
+ consumerRecord = fetchData(toFetchOffset, untilOffset, pollTimeoutMs, failOnDataLoss)
+ isFetchComplete = true
} catch {
case e: OffsetOutOfRangeException =>
// When there is some error thrown, it's better to use a new consumer to drop all cached
@@ -125,8 +131,13 @@ private[kafka010] case class CachedKafkaConsumer private(
toFetchOffset = getEarliestAvailableOffsetBetween(toFetchOffset, untilOffset)
}
}
- resetFetchedData()
- null
+
+ if (isFetchComplete) {
+ consumerRecord
+ } else {
+ resetFetchedData()
+ null
+ }
}
/**
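Background for the comment in the hunk above, as a standalone sketch rather than code from the patch: a `return` inside a Scala closure is compiled as a throw of `scala.runtime.NonLocalReturnControl` aimed at the enclosing method, which is why the fetch loop now sets a flag instead of returning.

    object NonLocalReturnSketch {
      // The `return Nil` below does not just end the lambda passed to `map`: it compiles
      // to a NonLocalReturnControl throw that unwinds out of `map` and terminates
      // `doubleUpTo`. If such a closure escapes the method call, or the control exception
      // is swallowed by an overly broad catch, the behavior becomes surprising.
      def doubleUpTo(xs: List[Int], limit: Int): List[Int] = {
        xs.map { x =>
          if (x > limit) return Nil   // non-local return
          x * 2
        }
      }

      def main(args: Array[String]): Unit = {
        println(doubleUpTo(List(1, 2, 3), 10))   // List(2, 4, 6)
        println(doubleUpTo(List(1, 20, 3), 10))  // Nil: the whole call was aborted
      }
    }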
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
index 7ac183776e20d..e9cff04ba5f2e 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
@@ -310,7 +310,7 @@ private[kafka010] class KafkaSource(
currentPartitionOffsets = Some(untilPartitionOffsets)
}
- sqlContext.internalCreateDataFrame(rdd, schema)
+ sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true)
}
/** Stop this source and free any resources it has allocated. */
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala
index 9a4a1cf32a480..0fa3287f36db8 100644
--- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala
@@ -21,14 +21,12 @@ import java.{ util => ju }
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicReference
-import scala.annotation.tailrec
import scala.collection.JavaConverters._
import scala.collection.mutable
import org.apache.kafka.clients.consumer._
-import org.apache.kafka.common.{ PartitionInfo, TopicPartition }
+import org.apache.kafka.common.TopicPartition
-import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{StreamingContext, Time}
diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala
index 9159051ba06e4..89ccbe219cecd 100644
--- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala
+++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala
@@ -23,6 +23,7 @@ import org.apache.spark.annotation.Experimental
* Represents the host and port info for a Kafka broker.
* Differs from the Kafka project's internal kafka.cluster.Broker, which contains a server ID.
*/
+@deprecated("Update to Kafka 0.10 integration", "2.3.0")
final class Broker private(
/** Broker's hostname */
val host: String,
@@ -49,6 +50,7 @@ final class Broker private(
* Companion object that provides methods to create instances of [[Broker]].
*/
@Experimental
+@deprecated("Update to Kafka 0.10 integration", "2.3.0")
object Broker {
def create(host: String, port: Int): Broker =
new Broker(host, port)
diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
index e0e44d4440272..570affab11853 100644
--- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
+++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
@@ -42,6 +42,7 @@ import org.apache.spark.annotation.DeveloperApi
* NOT zookeeper servers, specified in host1:port1,host2:port2 form
*/
@DeveloperApi
+@deprecated("Update to Kafka 0.10 integration", "2.3.0")
class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable {
import KafkaCluster.{Err, LeaderOffset, SimpleConsumerConfig}
@@ -376,6 +377,7 @@ class KafkaCluster(val kafkaParams: Map[String, String]) extends Serializable {
}
@DeveloperApi
+@deprecated("Update to Kafka 0.10 integration", "2.3.0")
object KafkaCluster {
type Err = ArrayBuffer[Throwable]
diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
index 78230725f322e..36082e93707b8 100644
--- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
+++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
@@ -41,6 +41,7 @@ import org.apache.spark.streaming.api.java._
import org.apache.spark.streaming.dstream.{DStream, InputDStream, ReceiverInputDStream}
import org.apache.spark.streaming.util.WriteAheadLogUtils
+@deprecated("Update to Kafka 0.10 integration", "2.3.0")
object KafkaUtils {
/**
* Create an input stream that pulls messages from Kafka Brokers.
diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala
index 10d364f987405..6dab5f950d4cd 100644
--- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala
+++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala
@@ -30,6 +30,7 @@ import kafka.common.TopicAndPartition
* }
* }}}
*/
+@deprecated("Update to Kafka 0.10 integration", "2.3.0")
trait HasOffsetRanges {
def offsetRanges: Array[OffsetRange]
}
@@ -42,6 +43,7 @@ trait HasOffsetRanges {
* @param fromOffset Inclusive starting offset
* @param untilOffset Exclusive ending offset
*/
+@deprecated("Update to Kafka 0.10 integration", "2.3.0")
final class OffsetRange private(
val topic: String,
val partition: Int,
@@ -80,6 +82,7 @@ final class OffsetRange private(
/**
* Companion object the provides methods to create instances of [[OffsetRange]].
*/
+@deprecated("Update to Kafka 0.10 integration", "2.3.0")
object OffsetRange {
def create(topic: String, partition: Int, fromOffset: Long, untilOffset: Long): OffsetRange =
new OffsetRange(topic, partition, fromOffset, untilOffset)
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala
index 8d56d4be9c42a..e26f4477d1d7d 100644
--- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala
@@ -30,7 +30,7 @@ import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester}
import org.scalatest.concurrent.Eventually
-import org.scalatest.mock.MockitoSugar
+import org.scalatest.mockito.MockitoSugar
import org.apache.spark.streaming.{Duration, TestSuiteBase}
import org.apache.spark.util.ManualClock
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala
index 1c130654f3f95..afa1a7f8ca663 100644
--- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala
@@ -17,13 +17,10 @@
package org.apache.spark.streaming.kinesis
-import java.lang.IllegalArgumentException
-
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream
import org.scalatest.BeforeAndAfterEach
-import org.scalatest.mock.MockitoSugar
+import org.scalatest.mockito.MockitoSugar
-import org.apache.spark.SparkFunSuite
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext, TestSuiteBase}
diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
index 3b14c8471e205..2fadda271ea28 100644
--- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
+++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
@@ -28,7 +28,7 @@ import org.mockito.Matchers._
import org.mockito.Matchers.{eq => meq}
import org.mockito.Mockito._
import org.scalatest.{BeforeAndAfter, Matchers}
-import org.scalatest.mock.MockitoSugar
+import org.scalatest.mockito.MockitoSugar
import org.apache.spark.streaming.{Duration, TestSuiteBase}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala
index fda501aa757d6..539b66f747cc9 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala
@@ -50,6 +50,7 @@ import org.apache.spark.util.PeriodicCheckpointer
* {{{
* val (graph1, graph2, graph3, ...) = ...
* val cp = new PeriodicGraphCheckpointer(2, sc)
+ * cp.updateGraph(graph1)
* graph1.vertices.count(); graph1.edges.count()
* // persisted: graph1
* cp.updateGraph(graph2)
diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
index 44028c58ac489..ce24400f557cd 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
@@ -144,6 +144,7 @@ List buildClassPath(String appClassPath) throws IOException {
if (prependClasses || isTesting) {
String scala = getScalaVersion();
 List<String> projects = Arrays.asList(
+ "common/kvstore",
"common/network-common",
"common/network-shuffle",
"common/network-yarn",
@@ -230,17 +231,17 @@ String getScalaVersion() {
return scala;
}
String sparkHome = getSparkHome();
- //File scala212 = new File(sparkHome, "launcher/target/scala-2.12");
+ File scala212 = new File(sparkHome, "launcher/target/scala-2.12");
File scala211 = new File(sparkHome, "launcher/target/scala-2.11");
- //checkState(!scala210.isDirectory() || !scala211.isDirectory(),
- // "Presence of build for multiple Scala versions detected.\n" +
- // "Either clean one of them or set SPARK_SCALA_VERSION in your environment.");
- //if (scala212.isDirectory()) {
- // return "2.12";
- //} else {
- checkState(scala211.isDirectory(), "Cannot find any build directories.");
- return "2.11";
- //}
+ checkState(!scala212.isDirectory() || !scala211.isDirectory(),
+ "Presence of build for multiple Scala versions detected.\n" +
+ "Either clean one of them or set SPARK_SCALA_VERSION in your environment.");
+ if (scala212.isDirectory()) {
+ return "2.12";
+ } else {
+ checkState(scala211.isDirectory(), "Cannot find any build directories.");
+ return "2.11";
+ }
}
String getSparkHome() {
diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java
index bf916406f1471..5391d4a50fe47 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java
@@ -156,9 +156,15 @@ synchronized void setAppId(String appId) {
* the exit code.
*/
void monitorChild() {
- while (childProc.isAlive()) {
+ Process proc = childProc;
+ if (proc == null) {
+ // Process may have already been disposed of, e.g. by calling kill().
+ return;
+ }
+
+ while (proc.isAlive()) {
try {
- childProc.waitFor();
+ proc.waitFor();
} catch (Exception e) {
LOG.log(Level.WARNING, "Exception waiting for child process to exit.", e);
}
@@ -173,15 +179,24 @@ void monitorChild() {
int ec;
try {
- ec = childProc.exitValue();
+ ec = proc.exitValue();
} catch (Exception e) {
LOG.log(Level.WARNING, "Exception getting child process exit code, assuming failure.", e);
ec = 1;
}
- // Only override the success state; leave other fail states alone.
- if (!state.isFinal() || (ec != 0 && state == State.FINISHED)) {
- state = State.LOST;
+ State newState = null;
+ if (ec != 0) {
+ // Override state with failure if the current state is not final, or is success.
+ if (!state.isFinal() || state == State.FINISHED) {
+ newState = State.FAILED;
+ }
+ } else if (!state.isFinal()) {
+ newState = State.LOST;
+ }
+
+ if (newState != null) {
+ state = newState;
fireEvent(false);
}
}
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java
index 137ef74843da5..32724acdc362c 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java
@@ -53,16 +53,19 @@ public List buildCommand(Map env)
case "org.apache.spark.deploy.master.Master":
javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS");
javaOptsKeys.add("SPARK_MASTER_OPTS");
+ extraClassPath = getenv("SPARK_DAEMON_CLASSPATH");
memKey = "SPARK_DAEMON_MEMORY";
break;
case "org.apache.spark.deploy.worker.Worker":
javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS");
javaOptsKeys.add("SPARK_WORKER_OPTS");
+ extraClassPath = getenv("SPARK_DAEMON_CLASSPATH");
memKey = "SPARK_DAEMON_MEMORY";
break;
case "org.apache.spark.deploy.history.HistoryServer":
javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS");
javaOptsKeys.add("SPARK_HISTORY_OPTS");
+ extraClassPath = getenv("SPARK_DAEMON_CLASSPATH");
memKey = "SPARK_DAEMON_MEMORY";
break;
case "org.apache.spark.executor.CoarseGrainedExecutorBackend":
@@ -77,11 +80,13 @@ public List buildCommand(Map env)
break;
case "org.apache.spark.deploy.mesos.MesosClusterDispatcher":
javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS");
+ extraClassPath = getenv("SPARK_DAEMON_CLASSPATH");
break;
case "org.apache.spark.deploy.ExternalShuffleService":
case "org.apache.spark.deploy.mesos.MesosExternalShuffleService":
javaOptsKeys.add("SPARK_DAEMON_JAVA_OPTS");
javaOptsKeys.add("SPARK_SHUFFLE_OPTS");
+ extraClassPath = getenv("SPARK_DAEMON_CLASSPATH");
memKey = "SPARK_DAEMON_MEMORY";
break;
default:
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java
index b83fe1b2d01cb..718a368a8e731 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java
@@ -605,7 +605,7 @@ private ProcessBuilder createBuilder() throws IOException {
}
// Visible for testing.
- String findSparkSubmit() throws IOException {
+ String findSparkSubmit() {
String script = isWindows() ? "spark-submit.cmd" : "spark-submit";
return join(File.separator, builder.getSparkHome(), "bin", script);
}
diff --git a/launcher/src/test/java/org/apache/spark/launcher/ChildProcAppHandleSuite.java b/launcher/src/test/java/org/apache/spark/launcher/ChildProcAppHandleSuite.java
index 602f55a50564d..9f59b41d52d44 100644
--- a/launcher/src/test/java/org/apache/spark/launcher/ChildProcAppHandleSuite.java
+++ b/launcher/src/test/java/org/apache/spark/launcher/ChildProcAppHandleSuite.java
@@ -46,7 +46,9 @@ public class ChildProcAppHandleSuite extends BaseSuite {
 private static final List<String> TEST_SCRIPT = Arrays.asList(
"#!/bin/sh",
"echo \"output\"",
- "echo \"error\" 1>&2");
+ "echo \"error\" 1>&2",
+ "while [ -n \"$1\" ]; do EC=$1; shift; done",
+ "exit $EC");
private static File TEST_SCRIPT_PATH;
@@ -112,6 +114,7 @@ public void testRedirectErrorToLog() throws Exception {
assumeFalse(isWindows());
Path err = Files.createTempFile("stderr", "txt");
+ err.toFile().deleteOnExit();
SparkAppHandle handle = (ChildProcAppHandle) new TestSparkLauncher()
.redirectError(err.toFile())
@@ -127,6 +130,7 @@ public void testRedirectOutputToLog() throws Exception {
assumeFalse(isWindows());
Path out = Files.createTempFile("stdout", "txt");
+ out.toFile().deleteOnExit();
SparkAppHandle handle = (ChildProcAppHandle) new TestSparkLauncher()
.redirectOutput(out.toFile())
@@ -143,6 +147,8 @@ public void testNoRedirectToLog() throws Exception {
Path out = Files.createTempFile("stdout", "txt");
Path err = Files.createTempFile("stderr", "txt");
+ out.toFile().deleteOnExit();
+ err.toFile().deleteOnExit();
ChildProcAppHandle handle = (ChildProcAppHandle) new TestSparkLauncher()
.redirectError(err.toFile())
@@ -157,9 +163,11 @@ public void testNoRedirectToLog() throws Exception {
@Test(expected = IllegalArgumentException.class)
public void testBadLogRedirect() throws Exception {
+ File out = Files.createTempFile("stdout", "txt").toFile();
+ out.deleteOnExit();
new SparkLauncher()
.redirectError()
- .redirectOutput(Files.createTempFile("stdout", "txt").toFile())
+ .redirectOutput(out)
.redirectToLog("foo")
.launch()
.waitFor();
@@ -167,16 +175,20 @@ public void testBadLogRedirect() throws Exception {
@Test(expected = IllegalArgumentException.class)
public void testRedirectErrorTwiceFails() throws Exception {
+ File err = Files.createTempFile("stderr", "txt").toFile();
+ err.deleteOnExit();
new SparkLauncher()
.redirectError()
- .redirectError(Files.createTempFile("stderr", "txt").toFile())
+ .redirectError(err)
.launch()
.waitFor();
}
@Test
public void testProcMonitorWithOutputRedirection() throws Exception {
+ assumeFalse(isWindows());
File err = Files.createTempFile("out", "txt").toFile();
+ err.deleteOnExit();
SparkAppHandle handle = new TestSparkLauncher()
.redirectError()
.redirectOutput(err)
@@ -187,6 +199,7 @@ public void testProcMonitorWithOutputRedirection() throws Exception {
@Test
public void testProcMonitorWithLogRedirection() throws Exception {
+ assumeFalse(isWindows());
SparkAppHandle handle = new TestSparkLauncher()
.redirectToLog(getClass().getName())
.startApplication();
@@ -194,6 +207,16 @@ public void testProcMonitorWithLogRedirection() throws Exception {
assertEquals(SparkAppHandle.State.LOST, handle.getState());
}
+ @Test
+ public void testFailedChildProc() throws Exception {
+ assumeFalse(isWindows());
+ SparkAppHandle handle = new TestSparkLauncher(1)
+ .redirectToLog(getClass().getName())
+ .startApplication();
+ waitFor(handle);
+ assertEquals(SparkAppHandle.State.FAILED, handle.getState());
+ }
+
private void waitFor(SparkAppHandle handle) throws Exception {
long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(10);
try {
@@ -212,7 +235,12 @@ private void waitFor(SparkAppHandle handle) throws Exception {
private static class TestSparkLauncher extends SparkLauncher {
TestSparkLauncher() {
+ this(0);
+ }
+
+ TestSparkLauncher(int ec) {
setAppResource("outputredirtest");
+ addAppArgs(String.valueOf(ec));
}
@Override
diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala
index 7fb9034d6501a..ace44165b1067 100644
--- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/MatricesSuite.scala
@@ -21,7 +21,7 @@ import java.util.Random
import breeze.linalg.{CSCMatrix, Matrix => BM}
import org.mockito.Mockito.when
-import org.scalatest.mock.MockitoSugar._
+import org.scalatest.mockito.MockitoSugar._
import scala.collection.mutable.{Map => MutableMap}
import org.apache.spark.ml.SparkMLFunSuite
diff --git a/mllib/pom.xml b/mllib/pom.xml
index c72a16a56e05c..925b5422a54cc 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -34,6 +34,10 @@
 <url>http://spark.apache.org/</url>
+ <dependency>
+ <groupId>org.scala-lang.modules</groupId>
+ <artifactId>scala-parser-combinators_${scala.binary.version}</artifactId>
+ </dependency>
 <groupId>org.apache.spark</groupId>
 <artifactId>spark-core_${scala.binary.version}</artifactId>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
index e7e0dae0b5a01..014ff07c21158 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
@@ -361,17 +361,42 @@ private[ann] trait TopologyModel extends Serializable {
* Forward propagation
*
* @param data input data
+ * @param includeLastLayer Include the last layer in the output. In
+ * MultilayerPerceptronClassifier, the last layer is always softmax;
+ * the last layer of outputs is needed for class predictions, but not
+ * for rawPrediction.
+ *
* @return array of outputs for each of the layers
*/
- def forward(data: BDM[Double]): Array[BDM[Double]]
+ def forward(data: BDM[Double], includeLastLayer: Boolean): Array[BDM[Double]]
/**
- * Prediction of the model
+ * Prediction of the model. See {@link ProbabilisticClassificationModel}
*
- * @param data input data
+ * @param features input features
* @return prediction
*/
- def predict(data: Vector): Vector
+ def predict(features: Vector): Vector
+
+ /**
+ * Raw prediction of the model. See {@link ProbabilisticClassificationModel}
+ *
+ * @param features input features
+ * @return raw prediction
+ *
+ * Note: This interface is only used for classification Model.
+ */
+ def predictRaw(features: Vector): Vector
+
+ /**
+ * Probability of the model. See {@link ProbabilisticClassificationModel}
+ *
+ * @param rawPrediction raw prediction vector
+ * @return probability
+ *
+ * Note: This interface is only used for classification Model.
+ */
+ def raw2ProbabilityInPlace(rawPrediction: Vector): Vector
/**
* Computes gradient for the network
@@ -463,7 +488,7 @@ private[ml] class FeedForwardModel private(
private var outputs: Array[BDM[Double]] = null
private var deltas: Array[BDM[Double]] = null
- override def forward(data: BDM[Double]): Array[BDM[Double]] = {
+ override def forward(data: BDM[Double], includeLastLayer: Boolean): Array[BDM[Double]] = {
// Initialize output arrays for all layers. Special treatment for InPlace
val currentBatchSize = data.cols
// TODO: allocate outputs as one big array and then create BDMs from it
@@ -481,7 +506,8 @@ private[ml] class FeedForwardModel private(
}
}
layerModels(0).eval(data, outputs(0))
- for (i <- 1 until layerModels.length) {
+ val end = if (includeLastLayer) layerModels.length else layerModels.length - 1
+ for (i <- 1 until end) {
layerModels(i).eval(outputs(i - 1), outputs(i))
}
outputs
@@ -492,7 +518,7 @@ private[ml] class FeedForwardModel private(
target: BDM[Double],
cumGradient: Vector,
realBatchSize: Int): Double = {
- val outputs = forward(data)
+ val outputs = forward(data, true)
val currentBatchSize = data.cols
// TODO: allocate deltas as one big array and then create BDMs from it
if (deltas == null || deltas(0).cols != currentBatchSize) {
@@ -527,9 +553,20 @@ private[ml] class FeedForwardModel private(
override def predict(data: Vector): Vector = {
val size = data.size
- val result = forward(new BDM[Double](size, 1, data.toArray))
+ val result = forward(new BDM[Double](size, 1, data.toArray), true)
Vectors.dense(result.last.toArray)
}
+
+ override def predictRaw(data: Vector): Vector = {
+ val result = forward(new BDM[Double](data.size, 1, data.toArray), false)
+ Vectors.dense(result(result.length - 2).toArray)
+ }
+
+ override def raw2ProbabilityInPlace(data: Vector): Vector = {
+ val dataMatrix = new BDM[Double](data.size, 1, data.toArray)
+ layerModels.last.eval(dataMatrix, dataMatrix)
+ data
+ }
}
/**
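To make the new split concrete, a sketch rather than the internal layer code: `predictRaw` returns the output of the layer before the final softmax, and `raw2ProbabilityInPlace` then applies that last layer, which for MultilayerPerceptronClassifier amounts to a softmax. The max-shift below is a common numerically stable formulation, assumed here for illustration.

    object SoftmaxSketch {
      // Shift by the max before exponentiating so large raw scores do not overflow.
      // Illustrative only; the real model evaluates its actual last LayerModel in place.
      def softmax(raw: Array[Double]): Array[Double] = {
        val shift = raw.max
        val exps = raw.map(v => math.exp(v - shift))
        val total = exps.sum
        exps.map(_ / total)
      }

      def main(args: Array[String]): Unit = {
        val probs = softmax(Array(1.0, 2.0, 3.0))
        println(probs.mkString(", "))  // roughly 0.09, 0.24, 0.67, summing to 1
      }
    }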
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala
index 25ce0282b1274..d26acf924c0a3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala
@@ -18,7 +18,6 @@
package org.apache.spark.ml
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup}
-import org.apache.spark.sql.DataFrame
/**
* ==ML attributes==
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index ade0960f87a0d..3da809ce5f77c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -32,7 +32,6 @@ import org.apache.spark.ml.tree.impl.GradientBoostedTrees
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
-import org.apache.spark.mllib.tree.loss.LogLoss
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
index 8d556deef2be8..1c97d77d38948 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -20,16 +20,16 @@ package org.apache.spark.ml.classification
import scala.collection.mutable
import breeze.linalg.{DenseVector => BDV}
-import breeze.optimize.{CachedDiffFunction, DiffFunction, OWLQN => BreezeOWLQN}
+import breeze.optimize.{CachedDiffFunction, OWLQN => BreezeOWLQN}
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
-import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg._
-import org.apache.spark.ml.linalg.BLAS._
+import org.apache.spark.ml.optim.aggregator.HingeAggregator
+import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
@@ -214,10 +214,20 @@ class LinearSVC @Since("2.2.0") (
}
val featuresStd = summarizer.variance.toArray.map(math.sqrt)
+ val getFeaturesStd = (j: Int) => featuresStd(j)
val regParamL2 = $(regParam)
val bcFeaturesStd = instances.context.broadcast(featuresStd)
- val costFun = new LinearSVCCostFun(instances, $(fitIntercept),
- $(standardization), bcFeaturesStd, regParamL2, $(aggregationDepth))
+ val regularization = if (regParamL2 != 0.0) {
+ val shouldApply = (idx: Int) => idx >= 0 && idx < numFeatures
+ Some(new L2Regularization(regParamL2, shouldApply,
+ if ($(standardization)) None else Some(getFeaturesStd)))
+ } else {
+ None
+ }
+
+ val getAggregatorFunc = new HingeAggregator(bcFeaturesStd, $(fitIntercept))(_)
+ val costFun = new RDDLossFunction(instances, getAggregatorFunc, regularization,
+ $(aggregationDepth))
def regParamL1Fun = (index: Int) => 0D
val optimizer = new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
@@ -372,189 +382,3 @@ object LinearSVCModel extends MLReadable[LinearSVCModel] {
}
}
}
-
-/**
- * LinearSVCCostFun implements Breeze's DiffFunction[T] for hinge loss function
- */
-private class LinearSVCCostFun(
- instances: RDD[Instance],
- fitIntercept: Boolean,
- standardization: Boolean,
- bcFeaturesStd: Broadcast[Array[Double]],
- regParamL2: Double,
- aggregationDepth: Int) extends DiffFunction[BDV[Double]] {
-
- override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
- val coeffs = Vectors.fromBreeze(coefficients)
- val bcCoeffs = instances.context.broadcast(coeffs)
- val featuresStd = bcFeaturesStd.value
- val numFeatures = featuresStd.length
-
- val svmAggregator = {
- val seqOp = (c: LinearSVCAggregator, instance: Instance) => c.add(instance)
- val combOp = (c1: LinearSVCAggregator, c2: LinearSVCAggregator) => c1.merge(c2)
-
- instances.treeAggregate(
- new LinearSVCAggregator(bcCoeffs, bcFeaturesStd, fitIntercept)
- )(seqOp, combOp, aggregationDepth)
- }
-
- val totalGradientArray = svmAggregator.gradient.toArray
- // regVal is the sum of coefficients squares excluding intercept for L2 regularization.
- val regVal = if (regParamL2 == 0.0) {
- 0.0
- } else {
- var sum = 0.0
- coeffs.foreachActive { case (index, value) =>
- // We do not apply regularization to the intercepts
- if (index != numFeatures) {
- // The following code will compute the loss of the regularization; also
- // the gradient of the regularization, and add back to totalGradientArray.
- sum += {
- if (standardization) {
- totalGradientArray(index) += regParamL2 * value
- value * value
- } else {
- if (featuresStd(index) != 0.0) {
- // If `standardization` is false, we still standardize the data
- // to improve the rate of convergence; as a result, we have to
- // perform this reverse standardization by penalizing each component
- // differently to get effectively the same objective function when
- // the training dataset is not standardized.
- val temp = value / (featuresStd(index) * featuresStd(index))
- totalGradientArray(index) += regParamL2 * temp
- value * temp
- } else {
- 0.0
- }
- }
- }
- }
- }
- 0.5 * regParamL2 * sum
- }
- bcCoeffs.destroy(blocking = false)
-
- (svmAggregator.loss + regVal, new BDV(totalGradientArray))
- }
-}
-
-/**
- * LinearSVCAggregator computes the gradient and loss for hinge loss function, as used
- * in binary classification for instances in sparse or dense vector in an online fashion.
- *
- * Two LinearSVCAggregator can be merged together to have a summary of loss and gradient of
- * the corresponding joint dataset.
- *
- * This class standardizes feature values during computation using bcFeaturesStd.
- *
- * @param bcCoefficients The coefficients corresponding to the features.
- * @param fitIntercept Whether to fit an intercept term.
- * @param bcFeaturesStd The standard deviation values of the features.
- */
-private class LinearSVCAggregator(
- bcCoefficients: Broadcast[Vector],
- bcFeaturesStd: Broadcast[Array[Double]],
- fitIntercept: Boolean) extends Serializable {
-
- private val numFeatures: Int = bcFeaturesStd.value.length
- private val numFeaturesPlusIntercept: Int = if (fitIntercept) numFeatures + 1 else numFeatures
- private var weightSum: Double = 0.0
- private var lossSum: Double = 0.0
- @transient private lazy val coefficientsArray = bcCoefficients.value match {
- case DenseVector(values) => values
- case _ => throw new IllegalArgumentException(s"coefficients only supports dense vector" +
- s" but got type ${bcCoefficients.value.getClass}.")
- }
- private lazy val gradientSumArray = new Array[Double](numFeaturesPlusIntercept)
-
- /**
- * Add a new training instance to this LinearSVCAggregator, and update the loss and gradient
- * of the objective function.
- *
- * @param instance The instance of data point to be added.
- * @return This LinearSVCAggregator object.
- */
- def add(instance: Instance): this.type = {
- instance match { case Instance(label, weight, features) =>
-
- if (weight == 0.0) return this
- val localFeaturesStd = bcFeaturesStd.value
- val localCoefficients = coefficientsArray
- val localGradientSumArray = gradientSumArray
-
- val dotProduct = {
- var sum = 0.0
- features.foreachActive { (index, value) =>
- if (localFeaturesStd(index) != 0.0 && value != 0.0) {
- sum += localCoefficients(index) * value / localFeaturesStd(index)
- }
- }
- if (fitIntercept) sum += localCoefficients(numFeaturesPlusIntercept - 1)
- sum
- }
- // Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) (f_w(x)))
- // Therefore the gradient is -(2y - 1)*x
- val labelScaled = 2 * label - 1.0
- val loss = if (1.0 > labelScaled * dotProduct) {
- weight * (1.0 - labelScaled * dotProduct)
- } else {
- 0.0
- }
-
- if (1.0 > labelScaled * dotProduct) {
- val gradientScale = -labelScaled * weight
- features.foreachActive { (index, value) =>
- if (localFeaturesStd(index) != 0.0 && value != 0.0) {
- localGradientSumArray(index) += value * gradientScale / localFeaturesStd(index)
- }
- }
- if (fitIntercept) {
- localGradientSumArray(localGradientSumArray.length - 1) += gradientScale
- }
- }
-
- lossSum += loss
- weightSum += weight
- this
- }
- }
-
- /**
- * Merge another LinearSVCAggregator, and update the loss and gradient
- * of the objective function.
- * (Note that it's in place merging; as a result, `this` object will be modified.)
- *
- * @param other The other LinearSVCAggregator to be merged.
- * @return This LinearSVCAggregator object.
- */
- def merge(other: LinearSVCAggregator): this.type = {
-
- if (other.weightSum != 0.0) {
- weightSum += other.weightSum
- lossSum += other.lossSum
-
- var i = 0
- val localThisGradientSumArray = this.gradientSumArray
- val localOtherGradientSumArray = other.gradientSumArray
- val len = localThisGradientSumArray.length
- while (i < len) {
- localThisGradientSumArray(i) += localOtherGradientSumArray(i)
- i += 1
- }
- }
- this
- }
-
- def loss: Double = if (weightSum != 0) lossSum / weightSum else 0.0
-
- def gradient: Vector = {
- if (weightSum != 0) {
- val result = Vectors.dense(gradientSumArray.clone())
- scal(1.0 / weightSum, result)
- result
- } else {
- Vectors.dense(new Array[Double](numFeaturesPlusIntercept))
- }
- }
-}
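The removed aggregator and its replacement (`HingeAggregator` wrapped in `RDDLossFunction`) compute the same per-instance quantity. A minimal sketch of that hinge loss and its gradient with respect to the margin, ignoring feature standardization and instance weights:

    object HingeSketch {
      // With labels in {0, 1}, rescale to {-1, +1} and apply max(0, 1 - y * margin).
      // The gradient w.r.t. the margin is -y inside the hinge and 0 outside it.
      def hingeLossAndGradient(label: Double, margin: Double): (Double, Double) = {
        val y = 2 * label - 1.0
        if (1.0 > y * margin) (1.0 - y * margin, -y) else (0.0, 0.0)
      }

      def main(args: Array[String]): Unit = {
        println(hingeLossAndGradient(1.0, 0.3))   // (0.7, -1.0): inside the margin
        println(hingeLossAndGradient(0.0, -2.0))  // (0.0, 0.0): confidently correct
      }
    }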
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 21957d94e2dc3..cbc8f4a2d8c27 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -22,7 +22,7 @@ import java.util.Locale
import scala.collection.mutable
import breeze.linalg.{DenseVector => BDV}
-import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, LBFGSB => BreezeLBFGSB, OWLQN => BreezeOWLQN}
+import breeze.optimize.{CachedDiffFunction, LBFGS => BreezeLBFGS, LBFGSB => BreezeLBFGSB, OWLQN => BreezeOWLQN}
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
@@ -35,7 +35,7 @@ import org.apache.spark.ml.optim.loss.{L2Regularization, RDDLossFunction}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
-import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
+import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics}
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
@@ -484,7 +484,7 @@ class LogisticRegression @Since("1.2.0") (
}
override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = {
- val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+ val handlePersistence = dataset.storageLevel == StorageLevel.NONE
train(dataset, handlePersistence)
}
@@ -882,21 +882,28 @@ class LogisticRegression @Since("1.2.0") (
val model = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector,
numClasses, isMultinomial))
- // TODO: implement summary model for multinomial case
- val m = if (!isMultinomial) {
- val (summaryModel, probabilityColName) = model.findSummaryModelAndProbabilityCol()
- val logRegSummary = new BinaryLogisticRegressionTrainingSummary(
+
+ val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
+ val logRegSummary = if (numClasses <= 2) {
+ new BinaryLogisticRegressionTrainingSummaryImpl(
summaryModel.transform(dataset),
probabilityColName,
+ predictionColName,
$(labelCol),
$(featuresCol),
objectiveHistory)
- model.setSummary(Some(logRegSummary))
} else {
- model
+ new LogisticRegressionTrainingSummaryImpl(
+ summaryModel.transform(dataset),
+ probabilityColName,
+ predictionColName,
+ $(labelCol),
+ $(featuresCol),
+ objectiveHistory)
}
- instr.logSuccess(m)
- m
+ model.setSummary(Some(logRegSummary))
+ instr.logSuccess(model)
+ model
}
@Since("1.4.0")
@@ -1010,8 +1017,8 @@ class LogisticRegressionModel private[spark] (
private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None
/**
- * Gets summary of model on training set. An exception is
- * thrown if `trainingSummary == None`.
+ * Gets summary of model on training set. An exception is thrown
+ * if `trainingSummary == None`.
*/
@Since("1.5.0")
def summary: LogisticRegressionTrainingSummary = trainingSummary.getOrElse {
@@ -1019,18 +1026,36 @@ class LogisticRegressionModel private[spark] (
}
/**
- * If the probability column is set returns the current model and probability column,
- * otherwise generates a new column and sets it as the probability column on a new copy
- * of the current model.
+ * Gets summary of model on training set. An exception is thrown
+ * if `trainingSummary == None` or it is a multiclass model.
+ */
+ @Since("2.3.0")
+ def binarySummary: BinaryLogisticRegressionTrainingSummary = summary match {
+ case b: BinaryLogisticRegressionTrainingSummary => b
+ case _ =>
+ throw new RuntimeException("Cannot create a binary summary for a non-binary model" +
+ s"(numClasses=${numClasses}), use summary instead.")
+ }
+
+ /**
+ * If the probability and prediction columns are set, this method returns the current model,
+ * otherwise it generates new columns for them and sets them as columns on a new copy of
+ * the current model.
*/
- private[classification] def findSummaryModelAndProbabilityCol():
- (LogisticRegressionModel, String) = {
- $(probabilityCol) match {
- case "" =>
- val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString
- (copy(ParamMap.empty).setProbabilityCol(probabilityColName), probabilityColName)
- case p => (this, p)
+ private[classification] def findSummaryModel():
+ (LogisticRegressionModel, String, String) = {
+ val model = if ($(probabilityCol).isEmpty && $(predictionCol).isEmpty) {
+ copy(ParamMap.empty)
+ .setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString)
+ .setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
+ } else if ($(probabilityCol).isEmpty) {
+ copy(ParamMap.empty).setProbabilityCol("probability_" + java.util.UUID.randomUUID.toString)
+ } else if ($(predictionCol).isEmpty) {
+ copy(ParamMap.empty).setPredictionCol("prediction_" + java.util.UUID.randomUUID.toString)
+ } else {
+ this
}
+ (model, model.getProbabilityCol, model.getPredictionCol)
}
private[classification]
@@ -1051,9 +1076,14 @@ class LogisticRegressionModel private[spark] (
@Since("2.0.0")
def evaluate(dataset: Dataset[_]): LogisticRegressionSummary = {
// Handle possible missing or invalid prediction columns
- val (summaryModel, probabilityColName) = findSummaryModelAndProbabilityCol()
- new BinaryLogisticRegressionSummary(summaryModel.transform(dataset),
- probabilityColName, $(labelCol), $(featuresCol))
+ val (summaryModel, probabilityColName, predictionColName) = findSummaryModel()
+ if (numClasses > 2) {
+ new LogisticRegressionSummaryImpl(summaryModel.transform(dataset),
+ probabilityColName, predictionColName, $(labelCol), $(featuresCol))
+ } else {
+ new BinaryLogisticRegressionSummaryImpl(summaryModel.transform(dataset),
+ probabilityColName, predictionColName, $(labelCol), $(featuresCol))
+ }
}
/**
@@ -1324,90 +1354,169 @@ private[ml] class MultiClassSummarizer extends Serializable {
}
/**
- * Abstraction for multinomial Logistic Regression Training results.
- * Currently, the training summary ignores the training weights except
- * for the objective trace.
- */
-sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary {
-
- /** objective function (scaled loss + regularization) at each iteration. */
- def objectiveHistory: Array[Double]
-
- /** Number of training iterations until termination */
- def totalIterations: Int = objectiveHistory.length
-
-}
-
-/**
- * Abstraction for Logistic Regression Results for a given model.
+ * :: Experimental ::
+ * Abstraction for logistic regression results for a given model.
+ *
+ * Currently, the summary ignores the instance weights.
*/
+@Experimental
sealed trait LogisticRegressionSummary extends Serializable {
/**
* Dataframe output by the model's `transform` method.
*/
+ @Since("1.5.0")
def predictions: DataFrame
/** Field in "predictions" which gives the probability of each class as a vector. */
+ @Since("1.5.0")
def probabilityCol: String
+ /** Field in "predictions" which gives the prediction of each class. */
+ @Since("2.3.0")
+ def predictionCol: String
+
/** Field in "predictions" which gives the true label of each instance (if available). */
+ @Since("1.5.0")
def labelCol: String
/** Field in "predictions" which gives the features of each instance as a vector. */
+ @Since("1.6.0")
def featuresCol: String
+ @transient private val multiclassMetrics = {
+ new MulticlassMetrics(
+ predictions.select(
+ col(predictionCol),
+ col(labelCol).cast(DoubleType))
+ .rdd.map { case Row(prediction: Double, label: Double) => (prediction, label) })
+ }
+
+ /**
+ * Returns the sequence of labels in ascending order. This order matches the order used
+ * in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.
+ *
+ * Note: In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the
+ * training set is missing a label, then all of the arrays over labels
+ * (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the
+ * expected numClasses.
+ */
+ @Since("2.3.0")
+ def labels: Array[Double] = multiclassMetrics.labels
+
+ /** Returns true positive rate for each label (category). */
+ @Since("2.3.0")
+ def truePositiveRateByLabel: Array[Double] = recallByLabel
+
+ /** Returns false positive rate for each label (category). */
+ @Since("2.3.0")
+ def falsePositiveRateByLabel: Array[Double] = {
+ multiclassMetrics.labels.map(label => multiclassMetrics.falsePositiveRate(label))
+ }
+
+ /** Returns precision for each label (category). */
+ @Since("2.3.0")
+ def precisionByLabel: Array[Double] = {
+ multiclassMetrics.labels.map(label => multiclassMetrics.precision(label))
+ }
+
+ /** Returns recall for each label (category). */
+ @Since("2.3.0")
+ def recallByLabel: Array[Double] = {
+ multiclassMetrics.labels.map(label => multiclassMetrics.recall(label))
+ }
+
+ /** Returns f-measure for each label (category). */
+ @Since("2.3.0")
+ def fMeasureByLabel(beta: Double): Array[Double] = {
+ multiclassMetrics.labels.map(label => multiclassMetrics.fMeasure(label, beta))
+ }
+
+ /** Returns f1-measure for each label (category). */
+ @Since("2.3.0")
+ def fMeasureByLabel: Array[Double] = fMeasureByLabel(1.0)
+
+ /**
+ * Returns accuracy.
+ * (the fraction of correctly classified instances
+ * out of the total number of instances.)
+ */
+ @Since("2.3.0")
+ def accuracy: Double = multiclassMetrics.accuracy
+
+ /**
+ * Returns weighted true positive rate.
+ * (equivalent to weighted recall, since the true positive rate is the recall)
+ */
+ @Since("2.3.0")
+ def weightedTruePositiveRate: Double = weightedRecall
+
+ /** Returns weighted false positive rate. */
+ @Since("2.3.0")
+ def weightedFalsePositiveRate: Double = multiclassMetrics.weightedFalsePositiveRate
+
+ /**
+ * Returns weighted averaged recall.
+ * (the recall of each label, weighted by the label's frequency in the data)
+ */
+ @Since("2.3.0")
+ def weightedRecall: Double = multiclassMetrics.weightedRecall
+
+ /** Returns weighted averaged precision. */
+ @Since("2.3.0")
+ def weightedPrecision: Double = multiclassMetrics.weightedPrecision
+
+ /** Returns weighted averaged f-measure. */
+ @Since("2.3.0")
+ def weightedFMeasure(beta: Double): Double = multiclassMetrics.weightedFMeasure(beta)
+
+ /** Returns weighted averaged f1-measure. */
+ @Since("2.3.0")
+ def weightedFMeasure: Double = multiclassMetrics.weightedFMeasure(1.0)
+
+ /**
+ * Convenience method for casting to a binary logistic regression summary.
+ * This method throws an exception if the summary is not a binary summary.
+ */
+ @Since("2.3.0")
+ def asBinary: BinaryLogisticRegressionSummary = this match {
+ case b: BinaryLogisticRegressionSummary => b
+ case _ =>
+ throw new RuntimeException("Cannot cast to a binary summary.")
+ }
}
/**
* :: Experimental ::
- * Logistic regression training results.
- *
- * @param predictions dataframe output by the model's `transform` method.
- * @param probabilityCol field in "predictions" which gives the probability of
- * each class as a vector.
- * @param labelCol field in "predictions" which gives the true label of each instance.
- * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
- * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
+ * Abstraction for multiclass logistic regression training results.
+ * Currently, the training summary ignores the training weights except
+ * for the objective trace.
*/
@Experimental
-@Since("1.5.0")
-class BinaryLogisticRegressionTrainingSummary private[classification] (
- predictions: DataFrame,
- probabilityCol: String,
- labelCol: String,
- featuresCol: String,
- @Since("1.5.0") val objectiveHistory: Array[Double])
- extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol, featuresCol)
- with LogisticRegressionTrainingSummary {
+sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary {
+
+ /** objective function (scaled loss + regularization) at each iteration. */
+ @Since("1.5.0")
+ def objectiveHistory: Array[Double]
+
+ /** Number of training iterations. */
+ @Since("1.5.0")
+ def totalIterations: Int = objectiveHistory.length
}
/**
* :: Experimental ::
- * Binary Logistic regression results for a given model.
+ * Abstraction for binary logistic regression results for a given model.
*
- * @param predictions dataframe output by the model's `transform` method.
- * @param probabilityCol field in "predictions" which gives the probability of
- * each class as a vector.
- * @param labelCol field in "predictions" which gives the true label of each instance.
- * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
+ * Currently, the summary ignores the instance weights.
*/
@Experimental
-@Since("1.5.0")
-class BinaryLogisticRegressionSummary private[classification] (
- @Since("1.5.0") @transient override val predictions: DataFrame,
- @Since("1.5.0") override val probabilityCol: String,
- @Since("1.5.0") override val labelCol: String,
- @Since("1.6.0") override val featuresCol: String) extends LogisticRegressionSummary {
-
+sealed trait BinaryLogisticRegressionSummary extends LogisticRegressionSummary {
private val sparkSession = predictions.sparkSession
import sparkSession.implicits._
- /**
- * Returns a BinaryClassificationMetrics object.
- */
// TODO: Allow the user to vary the number of bins using a setBins method in
// BinaryClassificationMetrics. For now the default is set to 100.
@transient private val binaryMetrics = new BinaryClassificationMetrics(
@@ -1484,3 +1593,99 @@ class BinaryLogisticRegressionSummary private[classification] (
binaryMetrics.recallByThreshold().toDF("threshold", "recall")
}
}
+
+/**
+ * :: Experimental ::
+ * Abstraction for binary logistic regression training results.
+ * Currently, the training summary ignores the training weights except
+ * for the objective trace.
+ */
+@Experimental
+sealed trait BinaryLogisticRegressionTrainingSummary extends BinaryLogisticRegressionSummary
+ with LogisticRegressionTrainingSummary
+
+/**
+ * Multiclass logistic regression training results.
+ *
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param probabilityCol field in "predictions" which gives the probability of
+ * each class as a vector.
+ * @param predictionCol field in "predictions" which gives the prediction for a data instance as a
+ * double.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
+ * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
+ */
+private class LogisticRegressionTrainingSummaryImpl(
+ predictions: DataFrame,
+ probabilityCol: String,
+ predictionCol: String,
+ labelCol: String,
+ featuresCol: String,
+ override val objectiveHistory: Array[Double])
+ extends LogisticRegressionSummaryImpl(
+ predictions, probabilityCol, predictionCol, labelCol, featuresCol)
+ with LogisticRegressionTrainingSummary
+
+/**
+ * Multiclass logistic regression results for a given model.
+ *
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param probabilityCol field in "predictions" which gives the probability of
+ * each class as a vector.
+ * @param predictionCol field in "predictions" which gives the prediction for a data instance as a
+ * double.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
+ */
+private class LogisticRegressionSummaryImpl(
+ @transient override val predictions: DataFrame,
+ override val probabilityCol: String,
+ override val predictionCol: String,
+ override val labelCol: String,
+ override val featuresCol: String)
+ extends LogisticRegressionSummary
+
+/**
+ * Binary logistic regression training results.
+ *
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param probabilityCol field in "predictions" which gives the probability of
+ * each class as a vector.
+ * @param predictionCol field in "predictions" which gives the prediction for a data instance as a
+ * double.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
+ * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
+ */
+private class BinaryLogisticRegressionTrainingSummaryImpl(
+ predictions: DataFrame,
+ probabilityCol: String,
+ predictionCol: String,
+ labelCol: String,
+ featuresCol: String,
+ override val objectiveHistory: Array[Double])
+ extends BinaryLogisticRegressionSummaryImpl(
+ predictions, probabilityCol, predictionCol, labelCol, featuresCol)
+ with BinaryLogisticRegressionTrainingSummary
+
+/**
+ * Binary logistic regression results for a given model.
+ *
+ * @param predictions dataframe output by the model's `transform` method.
+ * @param probabilityCol field in "predictions" which gives the probability of
+ * each class as a vector.
+ * @param predictionCol field in "predictions" which gives the prediction of
+ * each class as a double.
+ * @param labelCol field in "predictions" which gives the true label of each instance.
+ * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
+ */
+private class BinaryLogisticRegressionSummaryImpl(
+ predictions: DataFrame,
+ probabilityCol: String,
+ predictionCol: String,
+ labelCol: String,
+ featuresCol: String)
+ extends LogisticRegressionSummaryImpl(
+ predictions, probabilityCol, predictionCol, labelCol, featuresCol)
+ with BinaryLogisticRegressionSummary
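As a usage illustration of the summary API reworked in this file (not part of the patch; a minimal sketch assuming a SparkSession and a training DataFrame `training` with hypothetical "label"/"features" columns):

    import org.apache.spark.ml.classification.LogisticRegression

    val model = new LogisticRegression().setMaxIter(10).fit(training)

    // train() now always attaches a summary, binary or multinomial.
    val summary = model.summary
    println(summary.accuracy)
    println(summary.weightedFMeasure)
    summary.labels.zip(summary.precisionByLabel).foreach(println)

    // Binary-only view; binarySummary (or summary.asBinary) throws for a
    // multinomial model.
    if (model.numClasses == 2) {
      println(model.binarySummary.areaUnderROC)
    }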
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index ceba11edc93be..fd4c98f22132f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -22,7 +22,6 @@ import scala.collection.JavaConverters._
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
-import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector, Vectors}
@@ -32,7 +31,7 @@ import org.apache.spark.ml.util._
import org.apache.spark.sql.Dataset
/** Params for Multilayer Perceptron. */
-private[classification] trait MultilayerPerceptronParams extends PredictorParams
+private[classification] trait MultilayerPerceptronParams extends ProbabilisticClassifierParams
with HasSeed with HasMaxIter with HasTol with HasStepSize with HasSolver {
import MultilayerPerceptronClassifier._
@@ -143,7 +142,8 @@ private object LabelConverter {
@Since("1.5.0")
class MultilayerPerceptronClassifier @Since("1.5.0") (
@Since("1.5.0") override val uid: String)
- extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel]
+ extends ProbabilisticClassifier[Vector, MultilayerPerceptronClassifier,
+ MultilayerPerceptronClassificationModel]
with MultilayerPerceptronParams with DefaultParamsWritable {
@Since("1.5.0")
@@ -301,13 +301,13 @@ class MultilayerPerceptronClassificationModel private[ml] (
@Since("1.5.0") override val uid: String,
@Since("1.5.0") val layers: Array[Int],
@Since("2.0.0") val weights: Vector)
- extends PredictionModel[Vector, MultilayerPerceptronClassificationModel]
+ extends ProbabilisticClassificationModel[Vector, MultilayerPerceptronClassificationModel]
with Serializable with MLWritable {
@Since("1.6.0")
override val numFeatures: Int = layers.head
- private val mlpModel = FeedForwardTopology
+ private[ml] val mlpModel = FeedForwardTopology
.multiLayerPerceptron(layers, softmaxOnTop = true)
.model(weights)
@@ -335,6 +335,14 @@ class MultilayerPerceptronClassificationModel private[ml] (
@Since("2.0.0")
override def write: MLWriter =
new MultilayerPerceptronClassificationModel.MultilayerPerceptronClassificationModelWriter(this)
+
+ override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+ mlpModel.raw2ProbabilityInPlace(rawPrediction)
+ }
+
+ override protected def predictRaw(features: Vector): Vector = mlpModel.predictRaw(features)
+
+ override def numClasses: Int = layers.last
}
@Since("2.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index e5713599406e0..0293e03d47435 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -28,7 +28,6 @@ import org.apache.spark.ml.util._
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.{col, lit}
-import org.apache.spark.sql.types.DoubleType
/**
* Params for Naive Bayes Classifiers.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 05b8c3ab5456e..92a7742f6c865 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -17,10 +17,10 @@
package org.apache.spark.ml.classification
-import java.util.{List => JList}
import java.util.UUID
-import scala.collection.JavaConverters._
+import scala.concurrent.Future
+import scala.concurrent.duration.Duration
import scala.language.existentials
import org.apache.hadoop.fs.Path
@@ -34,12 +34,13 @@ import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
-import org.apache.spark.ml.param.shared.HasWeightCol
+import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.ThreadUtils
private[ml] trait ClassifierTypeTrait {
// scalastyle:off structural.type
@@ -164,7 +165,7 @@ final class OneVsRestModel private[ml] (
val newDataset = dataset.withColumn(accColName, initUDF())
// persist if underlying dataset is not persistent.
- val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+ val handlePersistence = dataset.storageLevel == StorageLevel.NONE
if (handlePersistence) {
newDataset.persist(StorageLevel.MEMORY_AND_DISK)
}
@@ -273,7 +274,7 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] {
@Since("1.4.0")
final class OneVsRest @Since("1.4.0") (
@Since("1.4.0") override val uid: String)
- extends Estimator[OneVsRestModel] with OneVsRestParams with MLWritable {
+ extends Estimator[OneVsRestModel] with OneVsRestParams with HasParallelism with MLWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("oneVsRest"))
@@ -296,6 +297,17 @@ final class OneVsRest @Since("1.4.0") (
@Since("1.5.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)
+ /**
+ * The implementation of parallel one vs. rest runs the classification for
+ * each class in a separate thread.
+ *
+ * @group expertSetParam
+ */
+ @Since("2.3.0")
+ def setParallelism(value: Int): this.type = {
+ set(parallelism, value)
+ }
+
/**
* Sets the value of param [[weightCol]].
*
@@ -318,7 +330,7 @@ final class OneVsRest @Since("1.4.0") (
transformSchema(dataset.schema)
val instr = Instrumentation.create(this, dataset)
- instr.logParams(labelCol, featuresCol, predictionCol)
+ instr.logParams(labelCol, featuresCol, predictionCol, parallelism)
instr.logNamedValue("classifier", $(classifier).getClass.getCanonicalName)
// determine number of classes either from metadata if provided, or via computation.
@@ -347,13 +359,15 @@ final class OneVsRest @Since("1.4.0") (
}
// persist if underlying dataset is not persistent.
- val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+ val handlePersistence = dataset.storageLevel == StorageLevel.NONE
if (handlePersistence) {
multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)
}
+ val executionContext = getExecutionContext
+
// create k columns, one for each binary classifier.
- val models = Range(0, numClasses).par.map { index =>
+ val modelFutures = Range(0, numClasses).map { index =>
// generate new label metadata for the binary problem.
val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
val labelColName = "mc2b$" + index
@@ -364,14 +378,18 @@ final class OneVsRest @Since("1.4.0") (
paramMap.put(classifier.labelCol -> labelColName)
paramMap.put(classifier.featuresCol -> getFeaturesCol)
paramMap.put(classifier.predictionCol -> getPredictionCol)
- if (weightColIsUsed) {
- val classifier_ = classifier.asInstanceOf[ClassifierType with HasWeightCol]
- paramMap.put(classifier_.weightCol -> getWeightCol)
- classifier_.fit(trainingDataset, paramMap)
- } else {
- classifier.fit(trainingDataset, paramMap)
- }
- }.toArray[ClassificationModel[_, _]]
+ Future {
+ if (weightColIsUsed) {
+ val classifier_ = classifier.asInstanceOf[ClassifierType with HasWeightCol]
+ paramMap.put(classifier_.weightCol -> getWeightCol)
+ classifier_.fit(trainingDataset, paramMap)
+ } else {
+ classifier.fit(trainingDataset, paramMap)
+ }
+ }(executionContext)
+ }
+ val models = modelFutures
+ .map(ThreadUtils.awaitResult(_, Duration.Inf)).toArray[ClassificationModel[_, _]]
instr.logNumFeatures(models.head.numFeatures)
if (handlePersistence) {
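As a usage illustration of the new parallelism knob (not part of the patch; a minimal sketch assuming a multiclass DataFrame `train` with hypothetical "label"/"features" columns):

    import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}

    // With parallelism > 1 the per-class binary classifiers are trained in
    // futures on a daemon thread pool; parallelism = 1 keeps serial training.
    val ovr = new OneVsRest()
      .setClassifier(new LogisticRegression().setMaxIter(10))
      .setParallelism(4)

    val ovrModel = ovr.fit(train)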
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index 5259ee419445f..f19ad7a5a6938 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -64,8 +64,8 @@ private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter w
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
- SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
- SchemaUtils.appendColumn(schema, $(probabilityCol), new VectorUDT)
+ val schemaWithPredictionCol = SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
+ SchemaUtils.appendColumn(schemaWithPredictionCol, $(probabilityCol), new VectorUDT)
}
}
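The change above matters because SchemaUtils.appendColumn returns a new StructType instead of mutating its argument, so the first appended column used to be silently dropped. A small stand-alone illustration of the same pattern with the public StructType API (not part of the patch):

    import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}

    val schema = new StructType()

    // Wrong: the returned schema is discarded, so "prediction" never appears.
    schema.add(StructField("prediction", IntegerType))
    println(schema.fieldNames.toList)                   // List()

    // Right: thread the returned schema through each call.
    val withBoth = schema
      .add(StructField("prediction", IntegerType))
      .add(StructField("probability", DoubleType))      // a VectorUDT in the real code
    println(withBoth.fieldNames.toList)                 // List(prediction, probability)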
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index e02b532ca8a93..f2af7fe082b41 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -304,7 +304,7 @@ class KMeans @Since("1.5.0") (
override def fit(dataset: Dataset[_]): KMeansModel = {
transformSchema(dataset.schema, logging = true)
- val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+ val handlePersistence = dataset.storageLevel == StorageLevel.NONE
val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
case Row(point: Vector) => OldVectors.fromML(point)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
new file mode 100644
index 0000000000000..d6ec5223237bb
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/ClusteringEvaluator.scala
@@ -0,0 +1,436 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.evaluation
+
+import org.apache.spark.SparkContext
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors, VectorUDT}
+import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
+import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.functions.{avg, col, udf}
+import org.apache.spark.sql.types.DoubleType
+
+/**
+ * :: Experimental ::
+ *
+ * Evaluator for clustering results.
+ * The metric computes the Silhouette measure
+ * using the squared Euclidean distance.
+ *
+ * The Silhouette is a measure for validating
+ * the consistency within clusters. It ranges
+ * between 1 and -1, where a value close to 1
+ * means that the points in a cluster are close
+ * to the other points in the same cluster and
+ * far from the points of the other clusters.
+ */
+@Experimental
+@Since("2.3.0")
+class ClusteringEvaluator @Since("2.3.0") (@Since("2.3.0") override val uid: String)
+ extends Evaluator with HasPredictionCol with HasFeaturesCol with DefaultParamsWritable {
+
+ @Since("2.3.0")
+ def this() = this(Identifiable.randomUID("cluEval"))
+
+ @Since("2.3.0")
+ override def copy(pMap: ParamMap): ClusteringEvaluator = this.defaultCopy(pMap)
+
+ @Since("2.3.0")
+ override def isLargerBetter: Boolean = true
+
+ /** @group setParam */
+ @Since("2.3.0")
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+ /** @group setParam */
+ @Since("2.3.0")
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+ /**
+ * param for metric name in evaluation
+ * (supports `"silhouette"` (default))
+ * @group param
+ */
+ @Since("2.3.0")
+ val metricName: Param[String] = {
+ val allowedParams = ParamValidators.inArray(Array("silhouette"))
+ new Param(
+ this, "metricName", "metric name in evaluation (silhouette)", allowedParams)
+ }
+
+ /** @group getParam */
+ @Since("2.3.0")
+ def getMetricName: String = $(metricName)
+
+ /** @group setParam */
+ @Since("2.3.0")
+ def setMetricName(value: String): this.type = set(metricName, value)
+
+ setDefault(metricName -> "silhouette")
+
+ @Since("2.3.0")
+ override def evaluate(dataset: Dataset[_]): Double = {
+ SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
+ SchemaUtils.checkNumericType(dataset.schema, $(predictionCol))
+
+ $(metricName) match {
+ case "silhouette" =>
+ SquaredEuclideanSilhouette.computeSilhouetteScore(
+ dataset, $(predictionCol), $(featuresCol)
+ )
+ }
+ }
+}
+
+
+@Since("2.3.0")
+object ClusteringEvaluator
+ extends DefaultParamsReadable[ClusteringEvaluator] {
+
+ @Since("2.3.0")
+ override def load(path: String): ClusteringEvaluator = super.load(path)
+
+}
+
+
+/**
+ * SquaredEuclideanSilhouette computes the average of the
+ * Silhouette over all the data of the dataset, which is
+ * a measure of how appropriately the data have been clustered.
+ *
+ * The Silhouette for each point `i` is defined as:
+ *
+ *
+ * $$
+ * s_{i} = \frac{b_{i}-a_{i}}{max\{a_{i},b_{i}\}}
+ * $$
+ *
+ *
+ * which can be rewritten as
+ *
+ *
+ * $$
+ * s_{i}= \begin{cases}
+ * 1-\frac{a_{i}}{b_{i}} & \text{if } a_{i} \leq b_{i} \\
+ * \frac{b_{i}}{a_{i}}-1 & \text{if } a_{i} \gt b_{i} \end{cases}
+ * $$
+ *
+ *
+ * where `$a_{i}$` is the average dissimilarity of `i` with all other data
+ * within the same cluster, `$b_{i}$` is the lowest average dissimilarity
+ * of `i` to any other cluster of which `i` is not a member.
+ * `$a_{i}$` can be interpreted as how well `i` is assigned to its cluster
+ * (the smaller the value, the better the assignment), while `$b_{i}$` is
+ * a measure of how well `i` has not been assigned to its "neighboring cluster",
+ * i.e. the nearest cluster to `i`.
+ *
+ * Unfortunately, the naive implementation of the algorithm requires computing
+ * the distance between each pair of points in the dataset. Since computing the
+ * distance takes `D` operations, where `D` is the number of dimensions of each
+ * point, the computational complexity of the algorithm is `O(N^2^*D)`, where
+ * `N` is the cardinality of the dataset. Of course this does not scale in `N`,
+ * which is the critical quantity in a Big Data context.
+ *
+ * The algorithm which is implemented in this object, instead, is an efficient
+ * and parallel implementation of the Silhouette using the squared Euclidean
+ * distance measure.
+ *
+ * With this assumption, the total distance of the point `X`
+ * to the points `$C_{i}$` belonging to the cluster `$\Gamma$` is:
+ *
+ *
+ * $$
+ * \sum\limits_{i=1}^N d(X, C_{i} ) =
+ * \sum\limits_{i=1}^N \Big( \sum\limits_{j=1}^D (x_{j}-c_{ij})^2 \Big)
+ * = \sum\limits_{i=1}^N \Big( \sum\limits_{j=1}^D x_{j}^2 +
+ * \sum\limits_{j=1}^D c_{ij}^2 -2\sum\limits_{j=1}^D x_{j}c_{ij} \Big)
+ * = \sum\limits_{i=1}^N \sum\limits_{j=1}^D x_{j}^2 +
+ * \sum\limits_{i=1}^N \sum\limits_{j=1}^D c_{ij}^2
+ * -2 \sum\limits_{i=1}^N \sum\limits_{j=1}^D x_{j}c_{ij}
+ * $$
+ *
+ *
+ * where `$x_{j}$` is the `j`-th dimension of the point `X` and
+ * `$c_{ij}$` is the `j`-th dimension of the `i`-th point in cluster `$\Gamma$`.
+ *
+ * Then, the first term of the equation can be rewritten as:
+ *
+ *
+ * $$
+ * \sum\limits_{i=1}^N \sum\limits_{j=1}^D x_{j}^2 = N \xi_{X} \text{ ,
+ * with } \xi_{X} = \sum\limits_{j=1}^D x_{j}^2
+ * $$
+ *
+ *
+ * where `$\xi_{X}$` is fixed for each point and it can be precomputed.
+ *
+ * Moreover, the second term is fixed for each cluster too,
+ * thus we can name it `$\Psi_{\Gamma}$`
+ *
+ *
+ * $$
+ * \sum\limits_{i=1}^N \sum\limits_{j=1}^D c_{ij}^2 =
+ * \sum\limits_{i=1}^N \xi_{C_{i}} = \Psi_{\Gamma}
+ * $$
+ *
+ *
+ * Last, the third element becomes
+ *
+ *
+ * $$
+ * \sum\limits_{i=1}^N \sum\limits_{j=1}^D x_{j}c_{ij} =
+ * \sum\limits_{j=1}^D \Big(\sum\limits_{i=1}^N c_{ij} \Big) x_{j}
+ * $$
+ *
+ *
+ * thus defining the vector
+ *
+ *
+ * $$
+ * Y_{\Gamma}:Y_{\Gamma j} = \sum\limits_{i=1}^N c_{ij} , j=0, ..., D
+ * $$
+ *
+ *
+ * which is fixed for each cluster `$\Gamma$`, we have
+ *
+ *
+ * $$
+ * \sum\limits_{j=1}^D \Big(\sum\limits_{i=1}^N c_{ij} \Big) x_{j} =
+ * \sum\limits_{j=1}^D Y_{\Gamma j} x_{j}
+ * $$
+ *
+ *
+ * In this way, the previous equation becomes
+ *
+ *
+ * $$
+ * N\xi_{X} + \Psi_{\Gamma} - 2 \sum\limits_{j=1}^D Y_{\Gamma j} x_{j}
+ * $$
+ *
+ *
+ * and the average distance of a point to a cluster can be computed as
+ *
+ *
+ * $$
+ * \frac{\sum\limits_{i=1}^N d(X, C_{i} )}{N} =
+ * \frac{N\xi_{X} + \Psi_{\Gamma} - 2 \sum\limits_{j=1}^D Y_{\Gamma j} x_{j}}{N} =
+ * \xi_{X} + \frac{\Psi_{\Gamma} }{N} - 2 \frac{\sum\limits_{j=1}^D Y_{\Gamma j} x_{j}}{N}
+ * $$
+ *
+ *
+ * Thus, it is enough to precompute: the constant `$\xi_{X}$` for each point `X`; the
+ * constants `$\Psi_{\Gamma}$`, `N` and the vector `$Y_{\Gamma}$` for
+ * each cluster `$\Gamma$`.
+ *
+ * In the implementation, the precomputed values for the clusters
+ * are distributed among the worker nodes via broadcast variables,
+ * because we can assume that the clusters are limited in number and,
+ * in any case, far fewer than the points.
+ *
+ * The main strengths of this algorithm are the low computational complexity
+ * and the intrinsic parallelism. The precomputed information for each point
+ * and for each cluster can be computed with a computational complexity
+ * which is `O(N/W)`, where `N` is the number of points in the dataset and
+ * `W` is the number of worker nodes. After that, every point can be
+ * analyzed independently of the others.
+ *
+ * For every point we need to compute the average distance to all the clusters.
+ * Since the formula above requires `O(D)` operations, this phase has a
+ * computational complexity which is `O(C*D*N/W)` where `C` is the number of
+ * clusters (which we assume quite low), `D` is the number of dimensions,
+ * `N` is the number of points in the dataset and `W` is the number
+ * of worker nodes.
+ */
+private[evaluation] object SquaredEuclideanSilhouette {
+
+ private[this] var kryoRegistrationPerformed: Boolean = false
+
+ /**
+ * This method registers the class
+ * [[org.apache.spark.ml.evaluation.SquaredEuclideanSilhouette.ClusterStats]]
+ * for kryo serialization.
+ *
+ * @param sc `SparkContext` to be used
+ */
+ def registerKryoClasses(sc: SparkContext): Unit = {
+ if (!kryoRegistrationPerformed) {
+ sc.getConf.registerKryoClasses(
+ Array(
+ classOf[SquaredEuclideanSilhouette.ClusterStats]
+ )
+ )
+ kryoRegistrationPerformed = true
+ }
+ }
+
+ case class ClusterStats(featureSum: Vector, squaredNormSum: Double, numOfPoints: Long)
+
+ /**
+ * The method takes the input dataset and computes the per-cluster
+ * aggregate values needed by the algorithm.
+ *
+ * @param df The DataFrame which contains the input data
+ * @param predictionCol The name of the column which contains the predicted cluster id
+ * for the point.
+ * @param featuresCol The name of the column which contains the feature vector of the point.
+ * @return A [[scala.collection.immutable.Map]] which associates each cluster id
+ * to a [[ClusterStats]] object (which contains the precomputed values `N`,
+ * `$\Psi_{\Gamma}$` and `$Y_{\Gamma}$` for a cluster).
+ */
+ def computeClusterStats(
+ df: DataFrame,
+ predictionCol: String,
+ featuresCol: String): Map[Double, ClusterStats] = {
+ val numFeatures = df.select(col(featuresCol)).first().getAs[Vector](0).size
+ val clustersStatsRDD = df.select(
+ col(predictionCol).cast(DoubleType), col(featuresCol), col("squaredNorm"))
+ .rdd
+ .map { row => (row.getDouble(0), (row.getAs[Vector](1), row.getDouble(2))) }
+ .aggregateByKey[(DenseVector, Double, Long)]((Vectors.zeros(numFeatures).toDense, 0.0, 0L))(
+ seqOp = {
+ case (
+ (featureSum: DenseVector, squaredNormSum: Double, numOfPoints: Long),
+ (features, squaredNorm)
+ ) =>
+ BLAS.axpy(1.0, features, featureSum)
+ (featureSum, squaredNormSum + squaredNorm, numOfPoints + 1)
+ },
+ combOp = {
+ case (
+ (featureSum1, squaredNormSum1, numOfPoints1),
+ (featureSum2, squaredNormSum2, numOfPoints2)
+ ) =>
+ BLAS.axpy(1.0, featureSum2, featureSum1)
+ (featureSum1, squaredNormSum1 + squaredNormSum2, numOfPoints1 + numOfPoints2)
+ }
+ )
+
+ clustersStatsRDD
+ .collectAsMap()
+ .mapValues {
+ case (featureSum: DenseVector, squaredNormSum: Double, numOfPoints: Long) =>
+ SquaredEuclideanSilhouette.ClusterStats(featureSum, squaredNormSum, numOfPoints)
+ }
+ .toMap
+ }
+
+ /**
+ * It computes the Silhouette coefficient for a point.
+ *
+ * @param broadcastedClustersMap A map of the precomputed values for each cluster.
+ * @param features The [[org.apache.spark.ml.linalg.Vector]] representing the current point.
+ * @param clusterId The id of the cluster the current point belongs to.
+ * @param squaredNorm The `$\xi_{X}$` (which is the squared norm) precomputed for the point.
+ * @return The Silhouette for the point.
+ */
+ def computeSilhouetteCoefficient(
+ broadcastedClustersMap: Broadcast[Map[Double, ClusterStats]],
+ features: Vector,
+ clusterId: Double,
+ squaredNorm: Double): Double = {
+
+ def compute(squaredNorm: Double, point: Vector, clusterStats: ClusterStats): Double = {
+ val pointDotClusterFeaturesSum = BLAS.dot(point, clusterStats.featureSum)
+
+ squaredNorm +
+ clusterStats.squaredNormSum / clusterStats.numOfPoints -
+ 2 * pointDotClusterFeaturesSum / clusterStats.numOfPoints
+ }
+
+ // Here we compute the average dissimilarity of the
+ // current point to any cluster of which the point
+ // is not a member.
+ // The cluster with the lowest average dissimilarity
+ // - i.e. the nearest cluster to the current point -
+ // is said to be the "neighboring cluster".
+ var neighboringClusterDissimilarity = Double.MaxValue
+ broadcastedClustersMap.value.keySet.foreach {
+ c =>
+ if (c != clusterId) {
+ val dissimilarity = compute(squaredNorm, features, broadcastedClustersMap.value(c))
+ if (dissimilarity < neighboringClusterDissimilarity) {
+ neighboringClusterDissimilarity = dissimilarity
+ }
+ }
+ }
+ val currentCluster = broadcastedClustersMap.value(clusterId)
+ // adjustment for excluding the node itself from
+ // the computation of the average dissimilarity
+ val currentClusterDissimilarity = if (currentCluster.numOfPoints == 1) {
+ 0
+ } else {
+ compute(squaredNorm, features, currentCluster) * currentCluster.numOfPoints /
+ (currentCluster.numOfPoints - 1)
+ }
+
+ (currentClusterDissimilarity compare neighboringClusterDissimilarity).signum match {
+ case -1 => 1 - (currentClusterDissimilarity / neighboringClusterDissimilarity)
+ case 1 => (neighboringClusterDissimilarity / currentClusterDissimilarity) - 1
+ case 0 => 0.0
+ }
+ }
+
+ /**
+ * Compute the mean Silhouette values of all samples.
+ *
+ * @param dataset The input dataset (previously clustered) on which to compute the Silhouette.
+ * @param predictionCol The name of the column which contains the predicted cluster id
+ * for the point.
+ * @param featuresCol The name of the column which contains the feature vector of the point.
+ * @return The average of the Silhouette values of the clustered data.
+ */
+ def computeSilhouetteScore(
+ dataset: Dataset[_],
+ predictionCol: String,
+ featuresCol: String): Double = {
+ SquaredEuclideanSilhouette.registerKryoClasses(dataset.sparkSession.sparkContext)
+
+ val squaredNormUDF = udf {
+ features: Vector => math.pow(Vectors.norm(features, 2.0), 2.0)
+ }
+ val dfWithSquaredNorm = dataset.withColumn("squaredNorm", squaredNormUDF(col(featuresCol)))
+
+ // compute aggregate values for clusters needed by the algorithm
+ val clustersStatsMap = SquaredEuclideanSilhouette
+ .computeClusterStats(dfWithSquaredNorm, predictionCol, featuresCol)
+
+ // Silhouette is reasonable only when the number of clusters is greater than 1
+ assert(clustersStatsMap.size > 1, "Number of clusters must be greater than one.")
+
+ val bClustersStatsMap = dataset.sparkSession.sparkContext.broadcast(clustersStatsMap)
+
+ val computeSilhouetteCoefficientUDF = udf {
+ computeSilhouetteCoefficient(bClustersStatsMap, _: Vector, _: Double, _: Double)
+ }
+
+ val silhouetteScore = dfWithSquaredNorm
+ .select(avg(
+ computeSilhouetteCoefficientUDF(
+ col(featuresCol), col(predictionCol).cast(DoubleType), col("squaredNorm"))
+ ))
+ .collect()(0)
+ .getDouble(0)
+
+ bClustersStatsMap.destroy()
+
+ silhouetteScore
+ }
+}
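As an end-to-end usage illustration of the new evaluator (not part of the patch; a minimal sketch assuming a DataFrame `dataset` with a hypothetical "features" vector column):

    import org.apache.spark.ml.clustering.KMeans
    import org.apache.spark.ml.evaluation.ClusteringEvaluator

    val predictions = new KMeans().setK(3).setSeed(1L).fit(dataset).transform(dataset)

    // Silhouette with squared Euclidean distance; larger is better.
    val evaluator = new ClusteringEvaluator()
      .setFeaturesCol("features")
      .setPredictionCol("prediction")
      .setMetricName("silhouette")

    println(s"Silhouette = ${evaluator.evaluate(predictions)}")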
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala
index d22bf164c313c..4615daed20fb1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/FeatureHasher.scala
@@ -53,9 +53,10 @@ import org.apache.spark.util.collection.OpenHashMap
*
* Null (missing) values are ignored (implicitly zero in the resulting feature vector).
*
- * Since a simple modulo is used to transform the hash function to a vector index,
- * it is advisable to use a power of two as the numFeatures parameter;
- * otherwise the features will not be mapped evenly to the vector indices.
+ * The hash function used here is also the MurmurHash 3 used in [[HashingTF]]. Since a simple modulo
+ * on the hashed value is used to determine the vector index, it is advisable to use a power of two
+ * as the numFeatures parameter; otherwise the features will not be mapped evenly to the vector
+ * indices.
*
* {{{
* val df = Seq(
@@ -64,17 +65,17 @@ import org.apache.spark.util.collection.OpenHashMap
* ).toDF("real", "bool", "stringNum", "string")
*
* val hasher = new FeatureHasher()
- * .setInputCols("real", "bool", "stringNum", "num")
+ * .setInputCols("real", "bool", "stringNum", "string")
* .setOutputCol("features")
*
- * hasher.transform(df).show()
+ * hasher.transform(df).show(false)
*
- * +----+-----+---------+------+--------------------+
- * |real| bool|stringNum|string| features|
- * +----+-----+---------+------+--------------------+
- * | 2.0| true| 1| foo|(262144,[51871,63...|
- * | 3.0|false| 2| bar|(262144,[6031,806...|
- * +----+-----+---------+------+--------------------+
+ * +----+-----+---------+------+------------------------------------------------------+
+ * |real|bool |stringNum|string|features |
+ * +----+-----+---------+------+------------------------------------------------------+
+ * |2.0 |true |1 |foo |(262144,[51871,63643,174475,253195],[1.0,1.0,2.0,1.0])|
+ * |3.0 |false|2 |bar |(262144,[6031,80619,140467,174475],[1.0,1.0,1.0,3.0]) |
+ * +----+-----+---------+------+------------------------------------------------------+
* }}}
*/
@Experimental
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
index 9e023b9dd469b..1f36eced3d08f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -133,23 +133,49 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String)
override def fit(dataset: Dataset[_]): ImputerModel = {
transformSchema(dataset.schema, logging = true)
val spark = dataset.sparkSession
- import spark.implicits._
- val surrogates = $(inputCols).map { inputCol =>
- val ic = col(inputCol)
- val filtered = dataset.select(ic.cast(DoubleType))
- .filter(ic.isNotNull && ic =!= $(missingValue) && !ic.isNaN)
- if(filtered.take(1).length == 0) {
- throw new SparkException(s"surrogate cannot be computed. " +
- s"All the values in $inputCol are Null, Nan or missingValue(${$(missingValue)})")
- }
- val surrogate = $(strategy) match {
- case Imputer.mean => filtered.select(avg(inputCol)).as[Double].first()
- case Imputer.median => filtered.stat.approxQuantile(inputCol, Array(0.5), 0.001).head
- }
- surrogate
+
+ val cols = $(inputCols).map { inputCol =>
+ when(col(inputCol).equalTo($(missingValue)), null)
+ .when(col(inputCol).isNaN, null)
+ .otherwise(col(inputCol))
+ .cast("double")
+ .as(inputCol)
+ }
+
+ val results = $(strategy) match {
+ case Imputer.mean =>
+ // Function avg will ignore null automatically.
+ // For a column only containing null, avg will return null.
+ val row = dataset.select(cols.map(avg): _*).head()
+ Array.range(0, $(inputCols).length).map { i =>
+ if (row.isNullAt(i)) {
+ Double.NaN
+ } else {
+ row.getDouble(i)
+ }
+ }
+
+ case Imputer.median =>
+ // Function approxQuantile will ignore null automatically.
+ // For a column only containing null, approxQuantile will return an empty array.
+ dataset.select(cols: _*).stat.approxQuantile($(inputCols), Array(0.5), 0.001)
+ .map { array =>
+ if (array.isEmpty) {
+ Double.NaN
+ } else {
+ array.head
+ }
+ }
+ }
+
+ val emptyCols = $(inputCols).zip(results).filter(_._2.isNaN).map(_._1)
+ if (emptyCols.nonEmpty) {
+ throw new SparkException(s"surrogate cannot be computed. " +
+ s"All the values in ${emptyCols.mkString(",")} are Null, Nan or " +
+ s"missingValue(${$(missingValue)})")
}
- val rows = spark.sparkContext.parallelize(Seq(Row.fromSeq(surrogates)))
+ val rows = spark.sparkContext.parallelize(Seq(Row.fromSeq(results)))
val schema = StructType($(inputCols).map(col => StructField(col, DoubleType, nullable = false)))
val surrogateDF = spark.createDataFrame(rows, schema)
copyValues(new ImputerModel(uid, surrogateDF).setParent(this))
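As a usage illustration of the estimator whose fit path is reworked above (not part of the patch; a minimal sketch assuming a DataFrame `df` with hypothetical numeric columns "a" and "b"):

    import org.apache.spark.ml.feature.Imputer

    // Surrogates for all input columns are now computed in a single pass:
    // one avg(...) per column for "mean", or one approxQuantile call over
    // all columns for "median".
    val imputer = new Imputer()
      .setInputCols(Array("a", "b"))
      .setOutputCols(Array("a_imputed", "b_imputed"))
      .setStrategy("median")

    imputer.fit(df).transform(df).show()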
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index d4c8e4b361959..f6095e26f435c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -229,7 +229,7 @@ class Word2VecModel private[ml] (
* Find "num" number of words closest in similarity to the given word, not
* including the word itself.
* @return a dataframe with columns "word" and "similarity" of the word and the cosine
- * similarities between the synonyms and the given word vector.
+ * similarities between the synonyms and the given word.
*/
@Since("1.5.0")
def findSynonyms(word: String, num: Int): DataFrame = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/package.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/package.scala
index d75a6dc9377ae..6ff970cc72dfd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/package.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/package.scala
@@ -18,7 +18,6 @@
package org.apache.spark.ml
import org.apache.spark.ml.feature.{HashingTF, IDF, IDFModel, VectorAssembler}
-import org.apache.spark.sql.DataFrame
/**
* == Feature transformers ==
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
index 32b0af72ba9bb..c5c9c8eb2bd29 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
@@ -18,7 +18,7 @@
package org.apache.spark.ml.optim
import org.apache.spark.internal.Logging
-import org.apache.spark.ml.feature.{Instance, OffsetInstance}
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg._
import org.apache.spark.rdd.RDD
@@ -440,7 +440,11 @@ private[ml] object WeightedLeastSquares {
/**
* Weighted population standard deviation of labels.
*/
- def bStd: Double = math.sqrt(bbSum / wSum - bBar * bBar)
+ def bStd: Double = {
+ // Prevent the variance from going negative due to numerical error.
+ val variance = math.max(bbSum / wSum - bBar * bBar, 0.0)
+ math.sqrt(variance)
+ }
/**
* Weighted mean of (label * features).
@@ -471,7 +475,8 @@ private[ml] object WeightedLeastSquares {
while (i < triK) {
val l = j - 2
val aw = aSum(l) / wSum
- std(l) = math.sqrt(aaValues(i) / wSum - aw * aw)
+ // Prevent the variance from going negative due to numerical error.
+ std(l) = math.sqrt(math.max(aaValues(i) / wSum - aw * aw, 0.0))
i += j
j += 1
}
@@ -489,7 +494,8 @@ private[ml] object WeightedLeastSquares {
while (i < triK) {
val l = j - 2
val aw = aSum(l) / wSum
- variance(l) = aaValues(i) / wSum - aw * aw
+ // Prevent the variance from going negative due to numerical error.
+ variance(l) = math.max(aaValues(i) / wSum - aw * aw, 0.0)
i += j
j += 1
}
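For context on the max(..., 0.0) clamps above: the one-pass weighted variance is the difference of two accumulated quantities, and when the true variance is close to zero, rounding can make that difference slightly negative, so it is clamped before taking the square root. In the notation of this file (a sketch, not a formal error bound):

    $$
    \mathrm{Var}(b) = \frac{\sum_i w_i b_i^2}{\sum_i w_i} - \bar{b}^{\,2},
    \qquad \bar{b} = \frac{\sum_i w_i b_i}{\sum_i w_i}
    $$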
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HingeAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HingeAggregator.scala
new file mode 100644
index 0000000000000..0300500a34ec0
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/HingeAggregator.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.optim.aggregator
+
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.linalg._
+
+/**
+ * HingeAggregator computes the gradient and loss for Hinge loss function as used in
+ * binary classification for instances in sparse or dense vector in an online fashion.
+ *
+ * Two HingeAggregators can be merged together to have a summary of loss and gradient of
+ * the corresponding joint dataset.
+ *
+ * This class standardizes feature values during computation using bcFeaturesStd.
+ *
+ * @param bcCoefficients The coefficients corresponding to the features.
+ * @param fitIntercept Whether to fit an intercept term.
+ * @param bcFeaturesStd The standard deviation values of the features.
+ */
+private[ml] class HingeAggregator(
+ bcFeaturesStd: Broadcast[Array[Double]],
+ fitIntercept: Boolean)(bcCoefficients: Broadcast[Vector])
+ extends DifferentiableLossAggregator[Instance, HingeAggregator] {
+
+ private val numFeatures: Int = bcFeaturesStd.value.length
+ private val numFeaturesPlusIntercept: Int = if (fitIntercept) numFeatures + 1 else numFeatures
+ @transient private lazy val coefficientsArray = bcCoefficients.value match {
+ case DenseVector(values) => values
+ case _ => throw new IllegalArgumentException(s"coefficients only supports dense vector" +
+ s" but got type ${bcCoefficients.value.getClass}.")
+ }
+ protected override val dim: Int = numFeaturesPlusIntercept
+
+ /**
+ * Add a new training instance to this HingeAggregator, and update the loss and gradient
+ * of the objective function.
+ *
+ * @param instance The instance of data point to be added.
+ * @return This HingeAggregator object.
+ */
+ def add(instance: Instance): this.type = {
+ instance match { case Instance(label, weight, features) =>
+ require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." +
+ s" Expecting $numFeatures but got ${features.size}.")
+ require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
+
+ if (weight == 0.0) return this
+ val localFeaturesStd = bcFeaturesStd.value
+ val localCoefficients = coefficientsArray
+ val localGradientSumArray = gradientSumArray
+
+ val dotProduct = {
+ var sum = 0.0
+ features.foreachActive { (index, value) =>
+ if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+ sum += localCoefficients(index) * value / localFeaturesStd(index)
+ }
+ }
+ if (fitIntercept) sum += localCoefficients(numFeaturesPlusIntercept - 1)
+ sum
+ }
+ // Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) (f_w(x)))
+ // Therefore the gradient is -(2y - 1)*x
+ val labelScaled = 2 * label - 1.0
+ val loss = if (1.0 > labelScaled * dotProduct) {
+ (1.0 - labelScaled * dotProduct) * weight
+ } else {
+ 0.0
+ }
+
+ if (1.0 > labelScaled * dotProduct) {
+ val gradientScale = -labelScaled * weight
+ features.foreachActive { (index, value) =>
+ if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+ localGradientSumArray(index) += value * gradientScale / localFeaturesStd(index)
+ }
+ }
+ if (fitIntercept) {
+ localGradientSumArray(localGradientSumArray.length - 1) += gradientScale
+ }
+ }
+
+ lossSum += loss
+ weightSum += weight
+ this
+ }
+ }
+}
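As a scalar restatement of the rule implemented by add() above (not part of the patch; hypothetical standalone code, with x already divided by the feature standard deviations and f standing for the margin w.x plus the optional intercept):

    // Hinge loss for a {0, 1} label y: L = weight * max(0, 1 - (2y - 1) * f).
    // The gradient w.r.t. w is -(2y - 1) * weight * x when the margin is
    // violated, and the zero vector otherwise.
    def hingeLossAndGradient(
        y: Double,
        weight: Double,
        f: Double,
        x: Array[Double]): (Double, Array[Double]) = {
      val labelScaled = 2 * y - 1.0
      if (labelScaled * f < 1.0) {
        val gradientScale = -labelScaled * weight
        ((1.0 - labelScaled * f) * weight, x.map(_ * gradientScale))
      } else {
        (0.0, Array.fill(x.length)(0.0))
      }
    }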
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregator.scala
index 66a52942e668c..272d36dd94ae8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregator.scala
@@ -270,11 +270,13 @@ private[ml] class LogisticAggregator(
val margins = new Array[Double](numClasses)
features.foreachActive { (index, value) =>
- val stdValue = value / localFeaturesStd(index)
- var j = 0
- while (j < numClasses) {
- margins(j) += localCoefficients(index * numClasses + j) * stdValue
- j += 1
+ if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+ val stdValue = value / localFeaturesStd(index)
+ var j = 0
+ while (j < numClasses) {
+ margins(j) += localCoefficients(index * numClasses + j) * stdValue
+ j += 1
+ }
}
}
var i = 0
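The new guard above avoids dividing by a zero standard deviation for constant features; without it the margins could be poisoned with Infinity/NaN. A tiny stand-alone illustration (not part of the patch):

    val value = 1.0
    val zeroStd = 0.0
    println(value / zeroStd)          // Infinity
    println(0.0 * (value / zeroStd))  // NaN: a zero coefficient no longer rescues it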
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala
index 7ac7c225e5acb..929374eda13a8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/loss/DifferentiableRegularization.scala
@@ -39,9 +39,13 @@ private[ml] trait DifferentiableRegularization[T] extends DiffFunction[T] {
*
* @param regParam The magnitude of the regularization.
* @param shouldApply A function (Int => Boolean) indicating whether a given index should have
- * regularization applied to it.
+ * regularization applied to it. Usually we don't apply regularization to
+ * the intercept.
* @param applyFeaturesStd Option for a function which maps coefficient index (column major) to the
- * feature standard deviation. If `None`, no standardization is applied.
+ * feature standard deviation. Since we always standardize the data during
+ * training, if `standardization` is false, we have to reverse the
+ * standardization by penalizing each component differently using this function.
+ * If `standardization` is true, this should be `None`.
*/
private[ml] class L2Regularization(
override val regParam: Double,
@@ -57,6 +61,11 @@ private[ml] class L2Regularization(
val coef = coefficients(j)
applyFeaturesStd match {
case Some(getStd) =>
+ // If `standardization` is false, we still standardize the data
+ // to improve the rate of convergence; as a result, we have to
+ // perform this reverse standardization by penalizing each component
+ // differently to get effectively the same objective function when
+ // the training dataset is not standardized.
val std = getStd(j)
if (std != 0.0) {
val temp = coef / (std * std)
@@ -66,6 +75,7 @@ private[ml] class L2Regularization(
0.0
}
case None =>
+ // If `standardization` is true, compute L2 regularization normally.
sum += coef * coef
gradient(j) = coef * regParam
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/loss/RDDLossFunction.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/loss/RDDLossFunction.scala
index 173041688128f..387f7c5b1ff33 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/loss/RDDLossFunction.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/loss/RDDLossFunction.scala
@@ -22,7 +22,6 @@ import breeze.linalg.{DenseVector => BDV}
import breeze.optimize.DiffFunction
import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
import org.apache.spark.rdd.RDD
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/HasParallelism.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/HasParallelism.scala
new file mode 100644
index 0000000000000..021d0b3e34166
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/HasParallelism.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param.shared
+
+import scala.concurrent.ExecutionContext
+
+import org.apache.spark.ml.param.{IntParam, Params, ParamValidators}
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * Trait to define a level of parallelism for algorithms that are able to use
+ * multithreaded execution, and provide a thread-pool based execution context.
+ */
+private[ml] trait HasParallelism extends Params {
+
+ /**
+ * The number of threads to use when running parallel algorithms.
+ * Default is 1 for serial execution.
+ *
+ * @group expertParam
+ */
+ val parallelism = new IntParam(this, "parallelism",
+ "the number of threads to use when running parallel algorithms", ParamValidators.gtEq(1))
+
+ setDefault(parallelism -> 1)
+
+ /** @group expertGetParam */
+ def getParallelism: Int = $(parallelism)
+
+ /**
+ * Create a new execution context with a thread-pool that has a maximum number of threads
+ * set to the value of [[parallelism]]. If this param is set to 1, a same-thread executor
+ * will be used to run tasks serially on the calling thread.
+ */
+ private[ml] def getExecutionContext: ExecutionContext = {
+ getParallelism match {
+ case 1 =>
+ ThreadUtils.sameThread
+ case n =>
+ ExecutionContext.fromExecutorService(ThreadUtils
+ .newDaemonCachedThreadPool(s"${this.getClass.getSimpleName}-thread-pool", n))
+ }
+ }
+}
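
A minimal usage sketch of `getExecutionContext`'s pattern, assuming code that lives under `org.apache.spark` (both the trait and `ThreadUtils` are Spark-internal); `runInParallel` and the pool name are illustrative, not part of the patch:

// Run the given tasks with bounded parallelism and block until all finish.
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration

import org.apache.spark.util.ThreadUtils

object ParallelismSketch {
  def runInParallel[A](parallelism: Int)(tasks: Seq[() => A]): Seq[A] = {
    val executionContext: ExecutionContext =
      if (parallelism == 1) {
        ThreadUtils.sameThread   // serial execution on the calling thread
      } else {
        ExecutionContext.fromExecutorService(
          ThreadUtils.newDaemonCachedThreadPool("parallelism-sketch", parallelism))
      }
    // Submit everything first, then wait for the results in order.
    val futures = tasks.map(task => Future(task())(executionContext))
    futures.map(ThreadUtils.awaitResult(_, Duration.Inf))
  }
}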
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 545e45e84e9ea..6061d9ca0a084 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -154,7 +154,7 @@ private[ml] trait HasVarianceCol extends Params {
}
/**
- * Trait for shared param threshold (default: 0.5).
+ * Trait for shared param threshold.
*/
private[ml] trait HasThreshold extends Params {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 0891994530f88..4b46c3831d75f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -109,10 +109,12 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
SchemaUtils.checkNumericType(schema, $(censorCol))
SchemaUtils.checkNumericType(schema, $(labelCol))
}
- if (hasQuantilesCol) {
+
+ val schemaWithQuantilesCol = if (hasQuantilesCol) {
SchemaUtils.appendColumn(schema, $(quantilesCol), new VectorUDT)
- }
- SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
+ } else schema
+
+ SchemaUtils.appendColumn(schemaWithQuantilesCol, $(predictionCol), DoubleType)
}
}
@@ -211,7 +213,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = {
transformSchema(dataset.schema, logging = true)
val instances = extractAFTPoints(dataset)
- val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+ val handlePersistence = dataset.storageLevel == StorageLevel.NONE
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
val featuresSummarizer = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index 529f66eadbcff..8faab52ea474b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -165,7 +165,7 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
transformSchema(dataset.schema, logging = true)
// Extract columns from data. If dataset is persisted, do not persist oldDataset.
val instances = extractWeightedLabeledPoints(dataset)
- val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+ val handlePersistence = dataset.storageLevel == StorageLevel.NONE
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
val instr = Instrumentation.create(this, dataset)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index ed431f550817e..b2a968118d1a9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -251,7 +251,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
return lrModel
}
- val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+ val handlePersistence = dataset.storageLevel == StorageLevel.NONE
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
val (featuresSummarizer, ySummarizer) = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
index 7e408b9dbd13a..cae41edb7aca8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
@@ -436,8 +436,9 @@ private[ml] object SummaryBuilderImpl extends Logging {
var i = 0
val len = currM2n.length
while (i < len) {
- realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
- (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator
+ // Clamp the variance at zero to guard against negative values caused by numerical error.
+ realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
+ (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
i += 1
}
}
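
The clamped expression is the usual weighted online-variance update; writing $M_{2,i}$ for `currM2n(i)`, $\Delta\mu_i$ for `deltaMean(i)`, $w_i$ for `weightSum(i)`, $W$ for `totalWeightSum` and $d$ for `denominator`, it computes

  \sigma_i^2 = \max\!\left( \frac{M_{2,i} + \Delta\mu_i^2 \, w_i \, (W - w_i) / W}{d},\ 0 \right).

For a truly constant column the numerator should cancel to zero exactly, but floating-point cancellation can leave a tiny negative residue, which the `math.max(..., 0.0)` clamp removes.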
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 2012d6ca8b5ea..ce2a3a2e40411 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -20,20 +20,23 @@ package org.apache.spark.ml.tuning
import java.util.{List => JList}
import scala.collection.JavaConverters._
+import scala.concurrent.Future
+import scala.concurrent.duration.Duration
-import com.github.fommil.netlib.F2jBLAS
import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
import org.apache.spark.annotation.Since
import org.apache.spark.internal.Logging
-import org.apache.spark.ml._
+import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.evaluation.Evaluator
-import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
+import org.apache.spark.ml.param.shared.HasParallelism
import org.apache.spark.ml.util._
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.ThreadUtils
/**
* Params for [[CrossValidator]] and [[CrossValidatorModel]].
@@ -64,13 +67,11 @@ private[ml] trait CrossValidatorParams extends ValidatorParams {
@Since("1.2.0")
class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
extends Estimator[CrossValidatorModel]
- with CrossValidatorParams with MLWritable with Logging {
+ with CrossValidatorParams with HasParallelism with MLWritable with Logging {
@Since("1.2.0")
def this() = this(Identifiable.randomUID("cv"))
- private val f2jBLAS = new F2jBLAS
-
/** @group setParam */
@Since("1.2.0")
def setEstimator(value: Estimator[_]): this.type = set(estimator, value)
@@ -91,6 +92,15 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
@Since("2.0.0")
def setSeed(value: Long): this.type = set(seed, value)
+ /**
+ * Set the maximum level of parallelism to evaluate models in parallel.
+ * Default is 1 for serial evaluation.
+ *
+ * @group expertSetParam
+ */
+ @Since("2.3.0")
+ def setParallelism(value: Int): this.type = set(parallelism, value)
+
@Since("2.0.0")
override def fit(dataset: Dataset[_]): CrossValidatorModel = {
val schema = dataset.schema
@@ -99,32 +109,49 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
val est = $(estimator)
val eval = $(evaluator)
val epm = $(estimatorParamMaps)
- val numModels = epm.length
- val metrics = new Array[Double](epm.length)
+
+ // Create execution context based on $(parallelism)
+ val executionContext = getExecutionContext
val instr = Instrumentation.create(this, dataset)
- instr.logParams(numFolds, seed)
+ instr.logParams(numFolds, seed, parallelism)
logTuningParams(instr)
+ // Compute metrics for each model over each split
val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
- splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
+ val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) =>
val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
- // multi-model training
logDebug(s"Train split $splitIndex with multiple sets of parameters.")
- val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
- trainingDataset.unpersist()
- var i = 0
- while (i < numModels) {
- // TODO: duplicate evaluator to take extra params from input
- val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
- logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
- metrics(i) += metric
- i += 1
+
+ // Fit models in a Future for training in parallel
+ val modelFutures = epm.map { paramMap =>
+ Future[Model[_]] {
+ val model = est.fit(trainingDataset, paramMap)
+ model.asInstanceOf[Model[_]]
+ } (executionContext)
+ }
+
+ // Unpersist training data only once all models have been trained
+ Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext)
+ .onComplete { _ => trainingDataset.unpersist() } (executionContext)
+
+ // Evaluate each model in a Future that will calculate a metric and allow the model to be cleaned up
+ val foldMetricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) =>
+ modelFuture.map { model =>
+ // TODO: duplicate evaluator to take extra params from input
+ val metric = eval.evaluate(model.transform(validationDataset, paramMap))
+ logDebug(s"Got metric $metric for model trained with $paramMap.")
+ metric
+ } (executionContext)
}
+
+ // Wait for metrics to be calculated before unpersisting validation dataset
+ val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
validationDataset.unpersist()
- }
- f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1)
+ foldMetrics
+ }.transpose.map(_.sum / $(numFolds)) // Calculate average metric over all splits
+
logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")
val (bestMetric, bestIndex) =
if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1)
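
For reference, a usage sketch of the new param; `trainingDF` and the grid values are assumptions for illustration, not part of the patch:

// Evaluate up to two models concurrently within each fold.
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

val lr = new LogisticRegression()
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.01, 0.1))
  .build()

val cv = new CrossValidator()
  .setEstimator(lr)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(3)
  .setParallelism(2)

// val cvModel = cv.fit(trainingDF)   // trainingDF: DataFrame with label/features columns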
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index db7c9d13d301a..16db0f5f12c77 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -20,6 +20,8 @@ package org.apache.spark.ml.tuning
import java.util.{List => JList}
import scala.collection.JavaConverters._
+import scala.concurrent.Future
+import scala.concurrent.duration.Duration
import scala.language.existentials
import org.apache.hadoop.fs.Path
@@ -30,9 +32,11 @@ import org.apache.spark.internal.Logging
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators}
+import org.apache.spark.ml.param.shared.HasParallelism
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.ThreadUtils
/**
* Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]].
@@ -62,7 +66,7 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams {
@Since("1.5.0")
class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: String)
extends Estimator[TrainValidationSplitModel]
- with TrainValidationSplitParams with MLWritable with Logging {
+ with TrainValidationSplitParams with HasParallelism with MLWritable with Logging {
@Since("1.5.0")
def this() = this(Identifiable.randomUID("tvs"))
@@ -87,6 +91,15 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
@Since("2.0.0")
def setSeed(value: Long): this.type = set(seed, value)
+ /**
+ * Set the maximum level of parallelism to evaluate models in parallel.
+ * Default is 1 for serial evaluation.
+ *
+ * @group expertSetParam
+ */
+ @Since("2.3.0")
+ def setParallelism(value: Int): this.type = set(parallelism, value)
+
@Since("2.0.0")
override def fit(dataset: Dataset[_]): TrainValidationSplitModel = {
val schema = dataset.schema
@@ -94,11 +107,12 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
val est = $(estimator)
val eval = $(evaluator)
val epm = $(estimatorParamMaps)
- val numModels = epm.length
- val metrics = new Array[Double](epm.length)
+
+ // Create execution context based on $(parallelism)
+ val executionContext = getExecutionContext
val instr = Instrumentation.create(this, dataset)
- instr.logParams(trainRatio, seed)
+ instr.logParams(trainRatio, seed, parallelism)
logTuningParams(instr)
val Array(trainingDataset, validationDataset) =
@@ -106,18 +120,33 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
trainingDataset.cache()
validationDataset.cache()
- // multi-model training
+ // Fit models in a Future for training in parallel
logDebug(s"Train split with multiple sets of parameters.")
- val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
- trainingDataset.unpersist()
- var i = 0
- while (i < numModels) {
- // TODO: duplicate evaluator to take extra params from input
- val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)))
- logDebug(s"Got metric $metric for model trained with ${epm(i)}.")
- metrics(i) += metric
- i += 1
+ val modelFutures = epm.map { paramMap =>
+ Future[Model[_]] {
+ val model = est.fit(trainingDataset, paramMap)
+ model.asInstanceOf[Model[_]]
+ } (executionContext)
}
+
+ // Unpersist training data only once all models have been trained
+ Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext)
+ .onComplete { _ => trainingDataset.unpersist() } (executionContext)
+
+ // Evaluate each model in a Future that will calculate a metric and allow the model to be cleaned up
+ val metricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) =>
+ modelFuture.map { model =>
+ // TODO: duplicate evaluator to take extra params from input
+ val metric = eval.evaluate(model.transform(validationDataset, paramMap))
+ logDebug(s"Got metric $metric for model trained with $paramMap.")
+ metric
+ } (executionContext)
+ }
+
+ // Wait for all metrics to be calculated
+ val metrics = metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
+
+ // Unpersist validation set once all metrics have been produced
validationDataset.unpersist()
logInfo(s"Train validation split metrics: ${metrics.toSeq}")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
index 9b7cd0427f5ed..2cfcf38eb4ca8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
@@ -98,16 +98,16 @@ class BinaryClassificationMetrics @Since("1.3.0") (
/**
* Returns the precision-recall curve, which is an RDD of (recall, precision),
- * NOT (precision, recall), with (0.0, 1.0) prepended to it.
+ * NOT (precision, recall), with (0.0, p) prepended to it, where p is the precision
+ * associated with the lowest recall on the curve.
* @see
* Precision and recall (Wikipedia)
*/
@Since("1.0.0")
def pr(): RDD[(Double, Double)] = {
val prCurve = createCurve(Recall, Precision)
- val sc = confusions.context
- val first = sc.makeRDD(Seq((0.0, 1.0)), 1)
- first.union(prCurve)
+ val (_, firstPrecision) = prCurve.first()
+ confusions.context.parallelize(Seq((0.0, firstPrecision)), 1).union(prCurve)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
index aaecfa8d45dc0..a01503f4b80a6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala
@@ -44,6 +44,11 @@ class PCA @Since("1.4.0") (@Since("1.4.0") val k: Int) {
require(k <= numFeatures,
s"source vector size $numFeatures must be no less than k=$k")
+ require(PCAUtil.memoryCost(k, numFeatures) < Int.MaxValue,
+ "The param k and numFeatures is too large for SVD computation. " +
+ "Try reducing the parameter k for PCA, or reduce the input feature " +
+ "vector dimension to make this tractable.")
+
val mat = new RowMatrix(sources)
val (pc, explainedVariance) = mat.computePrincipalComponentsAndExplainedVariance(k)
val densePC = pc match {
@@ -110,3 +115,17 @@ class PCAModel private[spark] (
}
}
}
+
+private[feature] object PCAUtil {
+
+ // This memory cost formula is from breeze code:
+ // https://github.com/scalanlp/breeze/blob/
+ // 6e541be066d547a097f5089165cd7c38c3ca276d/math/src/main/scala/breeze/linalg/
+ // functions/svd.scala#L87
+ def memoryCost(k: Int, numFeatures: Int): Long = {
+ 3L * math.min(k, numFeatures) * math.min(k, numFeatures)
+ + math.max(math.max(k, numFeatures), 4L * math.min(k, numFeatures)
+ * math.min(k, numFeatures) + 4L * math.min(k, numFeatures))
+ }
+
+}
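
A quick worked check of the guard above, with illustrative sizes: for k = 1000 and numFeatures = 100000 the estimate is 3e6 + max(1e5, 4e6 + 4e3) ≈ 7.0e6, comfortably under Int.MaxValue, so PCA proceeds; for k = 20000 and numFeatures = 30000 it is 1.2e9 + max(3e4, 1.6e9 + 8e4) ≈ 2.8e9, which exceeds Int.MaxValue ≈ 2.147e9, so the `require` fails fast instead of hitting the SVD workspace allocation later (JVM array sizes are bounded by Int.MaxValue).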
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
index 05eeff532f12e..21ec287e497d4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
@@ -263,9 +263,6 @@ object LBFGS extends Logging {
// broadcasted model is not needed anymore
bcW.destroy(blocking = false)
- // broadcasted model is not needed anymore
- bcW.destroy()
-
/**
* regVal is sum of weight squares if it's L2 updater;
* for other updater, the same logic is followed.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index 7dc0c459ec032..8121880cfb233 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -213,8 +213,9 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
var i = 0
val len = currM2n.length
while (i < len) {
- realVariance(i) = (currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
- (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator
+ // Clamp the variance at zero to guard against negative values caused by numerical error.
+ realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
+ (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
i += 1
}
}
diff --git a/mllib/src/test/resources/test-data/iris.libsvm b/mllib/src/test/resources/test-data/iris.libsvm
new file mode 100644
index 0000000000000..db959010255d0
--- /dev/null
+++ b/mllib/src/test/resources/test-data/iris.libsvm
@@ -0,0 +1,150 @@
+0.0 1:5.1 2:3.5 3:1.4 4:0.2
+0.0 1:4.9 2:3.0 3:1.4 4:0.2
+0.0 1:4.7 2:3.2 3:1.3 4:0.2
+0.0 1:4.6 2:3.1 3:1.5 4:0.2
+0.0 1:5.0 2:3.6 3:1.4 4:0.2
+0.0 1:5.4 2:3.9 3:1.7 4:0.4
+0.0 1:4.6 2:3.4 3:1.4 4:0.3
+0.0 1:5.0 2:3.4 3:1.5 4:0.2
+0.0 1:4.4 2:2.9 3:1.4 4:0.2
+0.0 1:4.9 2:3.1 3:1.5 4:0.1
+0.0 1:5.4 2:3.7 3:1.5 4:0.2
+0.0 1:4.8 2:3.4 3:1.6 4:0.2
+0.0 1:4.8 2:3.0 3:1.4 4:0.1
+0.0 1:4.3 2:3.0 3:1.1 4:0.1
+0.0 1:5.8 2:4.0 3:1.2 4:0.2
+0.0 1:5.7 2:4.4 3:1.5 4:0.4
+0.0 1:5.4 2:3.9 3:1.3 4:0.4
+0.0 1:5.1 2:3.5 3:1.4 4:0.3
+0.0 1:5.7 2:3.8 3:1.7 4:0.3
+0.0 1:5.1 2:3.8 3:1.5 4:0.3
+0.0 1:5.4 2:3.4 3:1.7 4:0.2
+0.0 1:5.1 2:3.7 3:1.5 4:0.4
+0.0 1:4.6 2:3.6 3:1.0 4:0.2
+0.0 1:5.1 2:3.3 3:1.7 4:0.5
+0.0 1:4.8 2:3.4 3:1.9 4:0.2
+0.0 1:5.0 2:3.0 3:1.6 4:0.2
+0.0 1:5.0 2:3.4 3:1.6 4:0.4
+0.0 1:5.2 2:3.5 3:1.5 4:0.2
+0.0 1:5.2 2:3.4 3:1.4 4:0.2
+0.0 1:4.7 2:3.2 3:1.6 4:0.2
+0.0 1:4.8 2:3.1 3:1.6 4:0.2
+0.0 1:5.4 2:3.4 3:1.5 4:0.4
+0.0 1:5.2 2:4.1 3:1.5 4:0.1
+0.0 1:5.5 2:4.2 3:1.4 4:0.2
+0.0 1:4.9 2:3.1 3:1.5 4:0.1
+0.0 1:5.0 2:3.2 3:1.2 4:0.2
+0.0 1:5.5 2:3.5 3:1.3 4:0.2
+0.0 1:4.9 2:3.1 3:1.5 4:0.1
+0.0 1:4.4 2:3.0 3:1.3 4:0.2
+0.0 1:5.1 2:3.4 3:1.5 4:0.2
+0.0 1:5.0 2:3.5 3:1.3 4:0.3
+0.0 1:4.5 2:2.3 3:1.3 4:0.3
+0.0 1:4.4 2:3.2 3:1.3 4:0.2
+0.0 1:5.0 2:3.5 3:1.6 4:0.6
+0.0 1:5.1 2:3.8 3:1.9 4:0.4
+0.0 1:4.8 2:3.0 3:1.4 4:0.3
+0.0 1:5.1 2:3.8 3:1.6 4:0.2
+0.0 1:4.6 2:3.2 3:1.4 4:0.2
+0.0 1:5.3 2:3.7 3:1.5 4:0.2
+0.0 1:5.0 2:3.3 3:1.4 4:0.2
+1.0 1:7.0 2:3.2 3:4.7 4:1.4
+1.0 1:6.4 2:3.2 3:4.5 4:1.5
+1.0 1:6.9 2:3.1 3:4.9 4:1.5
+1.0 1:5.5 2:2.3 3:4.0 4:1.3
+1.0 1:6.5 2:2.8 3:4.6 4:1.5
+1.0 1:5.7 2:2.8 3:4.5 4:1.3
+1.0 1:6.3 2:3.3 3:4.7 4:1.6
+1.0 1:4.9 2:2.4 3:3.3 4:1.0
+1.0 1:6.6 2:2.9 3:4.6 4:1.3
+1.0 1:5.2 2:2.7 3:3.9 4:1.4
+1.0 1:5.0 2:2.0 3:3.5 4:1.0
+1.0 1:5.9 2:3.0 3:4.2 4:1.5
+1.0 1:6.0 2:2.2 3:4.0 4:1.0
+1.0 1:6.1 2:2.9 3:4.7 4:1.4
+1.0 1:5.6 2:2.9 3:3.6 4:1.3
+1.0 1:6.7 2:3.1 3:4.4 4:1.4
+1.0 1:5.6 2:3.0 3:4.5 4:1.5
+1.0 1:5.8 2:2.7 3:4.1 4:1.0
+1.0 1:6.2 2:2.2 3:4.5 4:1.5
+1.0 1:5.6 2:2.5 3:3.9 4:1.1
+1.0 1:5.9 2:3.2 3:4.8 4:1.8
+1.0 1:6.1 2:2.8 3:4.0 4:1.3
+1.0 1:6.3 2:2.5 3:4.9 4:1.5
+1.0 1:6.1 2:2.8 3:4.7 4:1.2
+1.0 1:6.4 2:2.9 3:4.3 4:1.3
+1.0 1:6.6 2:3.0 3:4.4 4:1.4
+1.0 1:6.8 2:2.8 3:4.8 4:1.4
+1.0 1:6.7 2:3.0 3:5.0 4:1.7
+1.0 1:6.0 2:2.9 3:4.5 4:1.5
+1.0 1:5.7 2:2.6 3:3.5 4:1.0
+1.0 1:5.5 2:2.4 3:3.8 4:1.1
+1.0 1:5.5 2:2.4 3:3.7 4:1.0
+1.0 1:5.8 2:2.7 3:3.9 4:1.2
+1.0 1:6.0 2:2.7 3:5.1 4:1.6
+1.0 1:5.4 2:3.0 3:4.5 4:1.5
+1.0 1:6.0 2:3.4 3:4.5 4:1.6
+1.0 1:6.7 2:3.1 3:4.7 4:1.5
+1.0 1:6.3 2:2.3 3:4.4 4:1.3
+1.0 1:5.6 2:3.0 3:4.1 4:1.3
+1.0 1:5.5 2:2.5 3:4.0 4:1.3
+1.0 1:5.5 2:2.6 3:4.4 4:1.2
+1.0 1:6.1 2:3.0 3:4.6 4:1.4
+1.0 1:5.8 2:2.6 3:4.0 4:1.2
+1.0 1:5.0 2:2.3 3:3.3 4:1.0
+1.0 1:5.6 2:2.7 3:4.2 4:1.3
+1.0 1:5.7 2:3.0 3:4.2 4:1.2
+1.0 1:5.7 2:2.9 3:4.2 4:1.3
+1.0 1:6.2 2:2.9 3:4.3 4:1.3
+1.0 1:5.1 2:2.5 3:3.0 4:1.1
+1.0 1:5.7 2:2.8 3:4.1 4:1.3
+2.0 1:6.3 2:3.3 3:6.0 4:2.5
+2.0 1:5.8 2:2.7 3:5.1 4:1.9
+2.0 1:7.1 2:3.0 3:5.9 4:2.1
+2.0 1:6.3 2:2.9 3:5.6 4:1.8
+2.0 1:6.5 2:3.0 3:5.8 4:2.2
+2.0 1:7.6 2:3.0 3:6.6 4:2.1
+2.0 1:4.9 2:2.5 3:4.5 4:1.7
+2.0 1:7.3 2:2.9 3:6.3 4:1.8
+2.0 1:6.7 2:2.5 3:5.8 4:1.8
+2.0 1:7.2 2:3.6 3:6.1 4:2.5
+2.0 1:6.5 2:3.2 3:5.1 4:2.0
+2.0 1:6.4 2:2.7 3:5.3 4:1.9
+2.0 1:6.8 2:3.0 3:5.5 4:2.1
+2.0 1:5.7 2:2.5 3:5.0 4:2.0
+2.0 1:5.8 2:2.8 3:5.1 4:2.4
+2.0 1:6.4 2:3.2 3:5.3 4:2.3
+2.0 1:6.5 2:3.0 3:5.5 4:1.8
+2.0 1:7.7 2:3.8 3:6.7 4:2.2
+2.0 1:7.7 2:2.6 3:6.9 4:2.3
+2.0 1:6.0 2:2.2 3:5.0 4:1.5
+2.0 1:6.9 2:3.2 3:5.7 4:2.3
+2.0 1:5.6 2:2.8 3:4.9 4:2.0
+2.0 1:7.7 2:2.8 3:6.7 4:2.0
+2.0 1:6.3 2:2.7 3:4.9 4:1.8
+2.0 1:6.7 2:3.3 3:5.7 4:2.1
+2.0 1:7.2 2:3.2 3:6.0 4:1.8
+2.0 1:6.2 2:2.8 3:4.8 4:1.8
+2.0 1:6.1 2:3.0 3:4.9 4:1.8
+2.0 1:6.4 2:2.8 3:5.6 4:2.1
+2.0 1:7.2 2:3.0 3:5.8 4:1.6
+2.0 1:7.4 2:2.8 3:6.1 4:1.9
+2.0 1:7.9 2:3.8 3:6.4 4:2.0
+2.0 1:6.4 2:2.8 3:5.6 4:2.2
+2.0 1:6.3 2:2.8 3:5.1 4:1.5
+2.0 1:6.1 2:2.6 3:5.6 4:1.4
+2.0 1:7.7 2:3.0 3:6.1 4:2.3
+2.0 1:6.3 2:3.4 3:5.6 4:2.4
+2.0 1:6.4 2:3.1 3:5.5 4:1.8
+2.0 1:6.0 2:3.0 3:4.8 4:1.8
+2.0 1:6.9 2:3.1 3:5.4 4:2.1
+2.0 1:6.7 2:3.1 3:5.6 4:2.4
+2.0 1:6.9 2:3.1 3:5.1 4:2.3
+2.0 1:5.8 2:2.7 3:5.1 4:1.9
+2.0 1:6.8 2:3.2 3:5.9 4:2.3
+2.0 1:6.7 2:3.3 3:5.7 4:2.5
+2.0 1:6.7 2:3.0 3:5.2 4:2.3
+2.0 1:6.3 2:2.5 3:5.0 4:1.9
+2.0 1:6.5 2:3.0 3:5.2 4:2.0
+2.0 1:6.2 2:3.4 3:5.4 4:2.3
+2.0 1:5.9 2:3.0 3:5.1 4:1.8
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 4a7e4dd80f246..7848eae931a06 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -22,7 +22,7 @@ import scala.collection.JavaConverters._
import org.apache.hadoop.fs.Path
import org.mockito.Matchers.{any, eq => meq}
import org.mockito.Mockito.when
-import org.scalatest.mock.MockitoSugar.mock
+import org.scalatest.mockito.MockitoSugar.mock
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.Pipeline.SharedReadWrite
diff --git a/mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala
index f0c0183323c92..2f225645bdfc4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala
@@ -64,7 +64,7 @@ class GradientSuite extends SparkFunSuite with MLlibTestSparkContext {
}
private def computeLoss(input: BDM[Double], target: BDM[Double], model: TopologyModel): Double = {
- val outputs = model.forward(input)
+ val outputs = model.forward(input, true)
model.layerModels.last match {
case layerWithLoss: LossFunction =>
layerWithLoss.loss(outputs.last, target, new BDM[Double](target.rows, target.cols))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 918ab27e2730b..98c879ece62d6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -262,6 +262,9 @@ class DecisionTreeClassifierSuite
assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred,
"probability prediction mismatch")
}
+
+ ProbabilisticClassifierSuite.testPredictMethods[
+ Vector, DecisionTreeClassificationModel](newTree, newData)
}
test("training with 1-category categorical feature") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 1f79e0d4e6228..8000143d4d142 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -219,6 +219,9 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
resultsUsingPredict.zip(results.select(predictionCol).as[Double].collect()).foreach {
case (pred1, pred2) => assert(pred1 === pred2)
}
+
+ ProbabilisticClassifierSuite.testPredictMethods[
+ Vector, GBTClassificationModel](gbtModel, validationDataset)
}
test("GBT parameter stepSize should be in interval (0, 1]") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
index f2b00d0bae1d6..41a5d22dd6283 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
@@ -25,7 +25,8 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.classification.LinearSVCSuite._
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
-import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
+import org.apache.spark.ml.optim.aggregator.HingeAggregator
+import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -170,10 +171,10 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
assert(model2.intercept !== 0.0)
}
- test("sparse coefficients in SVCAggregator") {
+ test("sparse coefficients in HingeAggregator") {
val bcCoefficients = spark.sparkContext.broadcast(Vectors.sparse(2, Array(0), Array(1.0)))
val bcFeaturesStd = spark.sparkContext.broadcast(Array(1.0))
- val agg = new LinearSVCAggregator(bcCoefficients, bcFeaturesStd, true)
+ val agg = new HingeAggregator(bcFeaturesStd, true)(bcCoefficients)
val thrown = withClue("LinearSVCAggregator cannot handle sparse coefficients") {
intercept[IllegalArgumentException] {
agg.add(Instance(1.0, 1.0, Vectors.dense(1.0)))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 0570499e74516..14f550890d238 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -46,6 +46,7 @@ class LogisticRegressionSuite
@transient var smallMultinomialDataset: Dataset[_] = _
@transient var binaryDataset: Dataset[_] = _
@transient var multinomialDataset: Dataset[_] = _
+ @transient var multinomialDatasetWithZeroVar: Dataset[_] = _
private val eps: Double = 1e-5
override def beforeAll(): Unit = {
@@ -99,6 +100,23 @@ class LogisticRegressionSuite
df.cache()
df
}
+
+ multinomialDatasetWithZeroVar = {
+ val nPoints = 100
+ val coefficients = Array(
+ -0.57997, 0.912083, -0.371077,
+ -0.16624, -0.84355, -0.048509)
+
+ val xMean = Array(5.843, 3.0)
+ val xVariance = Array(0.6856, 0.0)
+
+ val testData = generateMultinomialLogisticInput(
+ coefficients, xMean, xVariance, addIntercept = true, nPoints, seed)
+
+ val df = sc.parallelize(testData, 4).toDF().withColumn("weight", lit(1.0))
+ df.cache()
+ df
+ }
}
/**
@@ -112,6 +130,11 @@ class LogisticRegressionSuite
multinomialDataset.rdd.map { case Row(label: Double, features: Vector, weight: Double) =>
label + "," + weight + "," + features.toArray.mkString(",")
}.repartition(1).saveAsTextFile("target/tmp/LogisticRegressionSuite/multinomialDataset")
+ multinomialDatasetWithZeroVar.rdd.map {
+ case Row(label: Double, features: Vector, weight: Double) =>
+ label + "," + weight + "," + features.toArray.mkString(",")
+ }.repartition(1)
+ .saveAsTextFile("target/tmp/LogisticRegressionSuite/multinomialDatasetWithZeroVar")
}
test("params") {
@@ -199,15 +222,64 @@ class LogisticRegressionSuite
}
}
- test("empty probabilityCol") {
- val lr = new LogisticRegression().setProbabilityCol("")
- val model = lr.fit(smallBinaryDataset)
- assert(model.hasSummary)
- // Validate that we re-insert a probability column for evaluation
- val fieldNames = model.summary.predictions.schema.fieldNames
- assert(smallBinaryDataset.schema.fieldNames.toSet.subsetOf(
- fieldNames.toSet))
- assert(fieldNames.exists(s => s.startsWith("probability_")))
+ test("empty probabilityCol or predictionCol") {
+ val lr = new LogisticRegression().setMaxIter(1)
+ val datasetFieldNames = smallBinaryDataset.schema.fieldNames.toSet
+ def checkSummarySchema(model: LogisticRegressionModel, columns: Seq[String]): Unit = {
+ val fieldNames = model.summary.predictions.schema.fieldNames
+ assert(model.hasSummary)
+ assert(datasetFieldNames.subsetOf(fieldNames.toSet))
+ columns.foreach { c => assert(fieldNames.exists(_.startsWith(c))) }
+ }
+ // check that the summary model adds the appropriate columns
+ Seq(("binomial", smallBinaryDataset), ("multinomial", smallMultinomialDataset)).foreach {
+ case (family, dataset) =>
+ lr.setFamily(family)
+ lr.setProbabilityCol("").setPredictionCol("prediction")
+ val modelNoProb = lr.fit(dataset)
+ checkSummarySchema(modelNoProb, Seq("probability_"))
+
+ lr.setProbabilityCol("probability").setPredictionCol("")
+ val modelNoPred = lr.fit(dataset)
+ checkSummarySchema(modelNoPred, Seq("prediction_"))
+
+ lr.setProbabilityCol("").setPredictionCol("")
+ val modelNoPredNoProb = lr.fit(dataset)
+ checkSummarySchema(modelNoPredNoProb, Seq("prediction_", "probability_"))
+ }
+ }
+
+ test("check summary types for binary and multiclass") {
+ val lr = new LogisticRegression()
+ .setFamily("binomial")
+ .setMaxIter(1)
+
+ val blorModel = lr.fit(smallBinaryDataset)
+ assert(blorModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
+ assert(blorModel.summary.asBinary.isInstanceOf[BinaryLogisticRegressionSummary])
+ assert(blorModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
+
+ val mlorModel = lr.setFamily("multinomial").fit(smallMultinomialDataset)
+ assert(mlorModel.summary.isInstanceOf[LogisticRegressionTrainingSummary])
+ withClue("cannot get binary summary for multiclass model") {
+ intercept[RuntimeException] {
+ mlorModel.binarySummary
+ }
+ }
+ withClue("cannot cast summary to binary summary multiclass model") {
+ intercept[RuntimeException] {
+ mlorModel.summary.asBinary
+ }
+ }
+
+ val mlorBinaryModel = lr.setFamily("multinomial").fit(smallBinaryDataset)
+ assert(mlorBinaryModel.summary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
+ assert(mlorBinaryModel.binarySummary.isInstanceOf[BinaryLogisticRegressionTrainingSummary])
+
+ val blorSummary = blorModel.evaluate(smallBinaryDataset)
+ val mlorSummary = mlorModel.evaluate(smallMultinomialDataset)
+ assert(blorSummary.isInstanceOf[BinaryLogisticRegressionSummary])
+ assert(mlorSummary.isInstanceOf[LogisticRegressionSummary])
}
test("setThreshold, getThreshold") {
@@ -430,6 +502,9 @@ class LogisticRegressionSuite
resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach {
case (pred1, pred2) => assert(pred1 === pred2)
}
+
+ ProbabilisticClassifierSuite.testPredictMethods[
+ Vector, LogisticRegressionModel](model, smallMultinomialDataset)
}
test("binary logistic regression: Predictor, Classifier methods") {
@@ -484,6 +559,9 @@ class LogisticRegressionSuite
resultsUsingPredict.zip(results.select("prediction").as[Double].collect()).foreach {
case (pred1, pred2) => assert(pred1 === pred2)
}
+
+ ProbabilisticClassifierSuite.testPredictMethods[
+ Vector, LogisticRegressionModel](model, smallBinaryDataset)
}
test("coefficients and intercept methods") {
@@ -1392,6 +1470,61 @@ class LogisticRegressionSuite
assert(model2.interceptVector.toArray.sum ~== 0.0 absTol eps)
}
+ test("multinomial logistic regression with zero variance (SPARK-21681)") {
+ val sqlContext = multinomialDatasetWithZeroVar.sqlContext
+ import sqlContext.implicits._
+ val mlr = new LogisticRegression().setFamily("multinomial").setFitIntercept(true)
+ .setElasticNetParam(0.0).setRegParam(0.0).setStandardization(true).setWeightCol("weight")
+
+ val model = mlr.fit(multinomialDatasetWithZeroVar)
+
+ /*
+ Use the following R code to load the data and train the model using glmnet package.
+
+ library("glmnet")
+ data <- read.csv("path", header=FALSE)
+ label = as.factor(data$V1)
+ w = data$V2
+ features = as.matrix(data.frame(data$V3, data$V4))
+ coefficients = coef(glmnet(features, label, weights=w, family="multinomial",
+ alpha = 0, lambda = 0))
+ coefficients
+ $`0`
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ 0.2658824
+ data.V3 0.1881871
+ data.V4 .
+
+ $`1`
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ 0.53604701
+ data.V3 -0.02412645
+ data.V4 .
+
+ $`2`
+ 3 x 1 sparse Matrix of class "dgCMatrix"
+ s0
+ -0.8019294
+ data.V3 -0.1640607
+ data.V4 .
+ */
+
+ val coefficientsR = new DenseMatrix(3, 2, Array(
+ 0.1881871, 0.0,
+ -0.02412645, 0.0,
+ -0.1640607, 0.0), isTransposed = true)
+ val interceptsR = Vectors.dense(0.2658824, 0.53604701, -0.8019294)
+
+ model.coefficientMatrix.colIter.foreach(v => assert(v.toArray.sum ~== 0.0 absTol eps))
+
+ assert(model.coefficientMatrix ~== coefficientsR relTol 0.05)
+ assert(model.coefficientMatrix.toArray.sum ~== 0.0 absTol eps)
+ assert(model.interceptVector ~== interceptsR relTol 0.05)
+ assert(model.interceptVector.toArray.sum ~== 0.0 absTol eps)
+ }
+
test("multinomial logistic regression with intercept without regularization with bound") {
// Bound constrained optimization with bound on one side.
val lowerBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(1.0))
@@ -2263,51 +2396,110 @@ class LogisticRegressionSuite
}
test("evaluate on test set") {
- // TODO: add for multiclass when model summary becomes available
// Evaluate on test set should be same as that of the transformed training data.
val lr = new LogisticRegression()
.setMaxIter(10)
.setRegParam(1.0)
.setThreshold(0.6)
- val model = lr.fit(smallBinaryDataset)
- val summary = model.summary.asInstanceOf[BinaryLogisticRegressionSummary]
-
- val sameSummary =
- model.evaluate(smallBinaryDataset).asInstanceOf[BinaryLogisticRegressionSummary]
- assert(summary.areaUnderROC === sameSummary.areaUnderROC)
- assert(summary.roc.collect() === sameSummary.roc.collect())
- assert(summary.pr.collect === sameSummary.pr.collect())
+ .setFamily("binomial")
+ val blorModel = lr.fit(smallBinaryDataset)
+ val blorSummary = blorModel.binarySummary
+
+ val sameBlorSummary =
+ blorModel.evaluate(smallBinaryDataset).asInstanceOf[BinaryLogisticRegressionSummary]
+ assert(blorSummary.areaUnderROC === sameBlorSummary.areaUnderROC)
+ assert(blorSummary.roc.collect() === sameBlorSummary.roc.collect())
+ assert(blorSummary.pr.collect === sameBlorSummary.pr.collect())
assert(
- summary.fMeasureByThreshold.collect() === sameSummary.fMeasureByThreshold.collect())
- assert(summary.recallByThreshold.collect() === sameSummary.recallByThreshold.collect())
+ blorSummary.fMeasureByThreshold.collect() === sameBlorSummary.fMeasureByThreshold.collect())
assert(
- summary.precisionByThreshold.collect() === sameSummary.precisionByThreshold.collect())
+ blorSummary.recallByThreshold.collect() === sameBlorSummary.recallByThreshold.collect())
+ assert(
+ blorSummary.precisionByThreshold.collect() === sameBlorSummary.precisionByThreshold.collect())
+ assert(blorSummary.labels === sameBlorSummary.labels)
+ assert(blorSummary.truePositiveRateByLabel === sameBlorSummary.truePositiveRateByLabel)
+ assert(blorSummary.falsePositiveRateByLabel === sameBlorSummary.falsePositiveRateByLabel)
+ assert(blorSummary.precisionByLabel === sameBlorSummary.precisionByLabel)
+ assert(blorSummary.recallByLabel === sameBlorSummary.recallByLabel)
+ assert(blorSummary.fMeasureByLabel === sameBlorSummary.fMeasureByLabel)
+ assert(blorSummary.accuracy === sameBlorSummary.accuracy)
+ assert(blorSummary.weightedTruePositiveRate === sameBlorSummary.weightedTruePositiveRate)
+ assert(blorSummary.weightedFalsePositiveRate === sameBlorSummary.weightedFalsePositiveRate)
+ assert(blorSummary.weightedRecall === sameBlorSummary.weightedRecall)
+ assert(blorSummary.weightedPrecision === sameBlorSummary.weightedPrecision)
+ assert(blorSummary.weightedFMeasure === sameBlorSummary.weightedFMeasure)
+
+ lr.setFamily("multinomial")
+ val mlorModel = lr.fit(smallMultinomialDataset)
+ val mlorSummary = mlorModel.summary
+
+ val mlorSameSummary = mlorModel.evaluate(smallMultinomialDataset)
+
+ assert(mlorSummary.truePositiveRateByLabel === mlorSameSummary.truePositiveRateByLabel)
+ assert(mlorSummary.falsePositiveRateByLabel === mlorSameSummary.falsePositiveRateByLabel)
+ assert(mlorSummary.precisionByLabel === mlorSameSummary.precisionByLabel)
+ assert(mlorSummary.recallByLabel === mlorSameSummary.recallByLabel)
+ assert(mlorSummary.fMeasureByLabel === mlorSameSummary.fMeasureByLabel)
+ assert(mlorSummary.accuracy === mlorSameSummary.accuracy)
+ assert(mlorSummary.weightedTruePositiveRate === mlorSameSummary.weightedTruePositiveRate)
+ assert(mlorSummary.weightedFalsePositiveRate === mlorSameSummary.weightedFalsePositiveRate)
+ assert(mlorSummary.weightedPrecision === mlorSameSummary.weightedPrecision)
+ assert(mlorSummary.weightedRecall === mlorSameSummary.weightedRecall)
+ assert(mlorSummary.weightedFMeasure === mlorSameSummary.weightedFMeasure)
}
test("evaluate with labels that are not doubles") {
// Evaluate a test set with Label that is a numeric type other than Double
- val lr = new LogisticRegression()
+ val blor = new LogisticRegression()
.setMaxIter(1)
.setRegParam(1.0)
- val model = lr.fit(smallBinaryDataset)
- val summary = model.evaluate(smallBinaryDataset).asInstanceOf[BinaryLogisticRegressionSummary]
+ .setFamily("binomial")
+ val blorModel = blor.fit(smallBinaryDataset)
+ val blorSummary = blorModel.evaluate(smallBinaryDataset)
+ .asInstanceOf[BinaryLogisticRegressionSummary]
- val longLabelData = smallBinaryDataset.select(col(model.getLabelCol).cast(LongType),
- col(model.getFeaturesCol))
- val longSummary = model.evaluate(longLabelData).asInstanceOf[BinaryLogisticRegressionSummary]
+ val blorLongLabelData = smallBinaryDataset.select(col(blorModel.getLabelCol).cast(LongType),
+ col(blorModel.getFeaturesCol))
+ val blorLongSummary = blorModel.evaluate(blorLongLabelData)
+ .asInstanceOf[BinaryLogisticRegressionSummary]
- assert(summary.areaUnderROC === longSummary.areaUnderROC)
+ assert(blorSummary.areaUnderROC === blorLongSummary.areaUnderROC)
+
+ val mlor = new LogisticRegression()
+ .setMaxIter(1)
+ .setRegParam(1.0)
+ .setFamily("multinomial")
+ val mlorModel = mlor.fit(smallMultinomialDataset)
+ val mlorSummary = mlorModel.evaluate(smallMultinomialDataset)
+
+ val mlorLongLabelData = smallMultinomialDataset.select(
+ col(mlorModel.getLabelCol).cast(LongType),
+ col(mlorModel.getFeaturesCol))
+ val mlorLongSummary = mlorModel.evaluate(mlorLongLabelData)
+
+ assert(mlorSummary.accuracy === mlorLongSummary.accuracy)
}
test("statistics on training data") {
// Test that loss is monotonically decreasing.
- val lr = new LogisticRegression()
+ val blor = new LogisticRegression()
.setMaxIter(10)
.setRegParam(1.0)
- .setThreshold(0.6)
- val model = lr.fit(smallBinaryDataset)
+ .setFamily("binomial")
+ val blorModel = blor.fit(smallBinaryDataset)
+ assert(
+ blorModel.summary
+ .objectiveHistory
+ .sliding(2)
+ .forall(x => x(0) >= x(1)))
+
+ val mlor = new LogisticRegression()
+ .setMaxIter(10)
+ .setRegParam(1.0)
+ .setFamily("multinomial")
+ val mlorModel = mlor.fit(smallMultinomialDataset)
assert(
- model.summary
+ mlorModel.summary
.objectiveHistory
.sliding(2)
.forall(x => x(0) >= x(1)))
@@ -2392,7 +2584,7 @@ class LogisticRegressionSuite
predictions3.zip(predictions4).foreach { case (Row(p1: Double), Row(p2: Double)) =>
assert(p1 === p2)
}
- // TODO: check that it converges in a single iteration when model summary is available
+ assert(model4.summary.totalIterations === 1)
}
test("binary logistic regression with all labels the same") {
@@ -2453,6 +2645,7 @@ class LogisticRegressionSuite
assert(prob === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, 1.0)))
assert(pred === 4.0)
}
+ assert(model.summary.totalIterations === 0)
// force the model to be trained with only one class
val constantZeroData = Seq(
@@ -2466,6 +2659,7 @@ class LogisticRegressionSuite
assert(prob === Vectors.dense(Array(1.0)))
assert(pred === 0.0)
}
+ assert(modelZeroLabel.summary.totalIterations > 0)
// ensure that the correct value is predicted when numClasses passed through metadata
val labelMeta = NominalAttribute.defaultAttr.withName("label").withNumValues(6).toMetadata()
@@ -2479,7 +2673,7 @@ class LogisticRegressionSuite
assert(prob === Vectors.dense(Array(0.0, 0.0, 0.0, 0.0, 1.0, 0.0)))
assert(pred === 4.0)
}
- // TODO: check num iters is zero when it become available in the model
+ require(modelWithMetadata.summary.totalIterations === 0)
}
test("compressed storage for constant label") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index 94d71fd532332..31c5317060d6b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Dataset, Row}
+import org.apache.spark.sql.functions._
class MultilayerPerceptronClassifierSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -83,6 +84,49 @@ class MultilayerPerceptronClassifierSuite
}
}
+ test("Predicted class probabilities: calibration on toy dataset") {
+ val layers = Array[Int](4, 5, 2)
+
+ val strongDataset = Seq(
+ (Vectors.dense(1, 2, 3, 4), 0d, Vectors.dense(1d, 0d)),
+ (Vectors.dense(4, 3, 2, 1), 1d, Vectors.dense(0d, 1d)),
+ (Vectors.dense(1, 1, 1, 1), 0d, Vectors.dense(.5, .5)),
+ (Vectors.dense(1, 1, 1, 1), 1d, Vectors.dense(.5, .5))
+ ).toDF("features", "label", "expectedProbability")
+ val trainer = new MultilayerPerceptronClassifier()
+ .setLayers(layers)
+ .setBlockSize(1)
+ .setSeed(123L)
+ .setMaxIter(100)
+ .setSolver("l-bfgs")
+ val model = trainer.fit(strongDataset)
+ val result = model.transform(strongDataset)
+ result.select("probability", "expectedProbability").collect().foreach {
+ case Row(p: Vector, e: Vector) =>
+ assert(p ~== e absTol 1e-3)
+ }
+ ProbabilisticClassifierSuite.testPredictMethods[
+ Vector, MultilayerPerceptronClassificationModel](model, strongDataset)
+ }
+
+ test("test model probability") {
+ val layers = Array[Int](2, 5, 2)
+ val trainer = new MultilayerPerceptronClassifier()
+ .setLayers(layers)
+ .setBlockSize(1)
+ .setSeed(123L)
+ .setMaxIter(100)
+ .setSolver("l-bfgs")
+ val model = trainer.fit(dataset)
+ model.setProbabilityCol("probability")
+ val result = model.transform(dataset)
+ val features2prob = udf { features: Vector => model.mlpModel.predict(features) }
+ result.select(features2prob(col("features")), col("probability")).collect().foreach {
+ case Row(p1: Vector, p2: Vector) =>
+ assert(p1 ~== p2 absTol 1e-3)
+ }
+ }
+
ignore("Test setWeights by training restart -- ignore palantir/spark") {
val dataFrame = Seq(
(Vectors.dense(0.0, 0.0), 0.0),
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 3a2be236f1257..9730dd68a3b27 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -160,6 +160,9 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
val featureAndProbabilities = model.transform(validationDataset)
.select("features", "probability")
validateProbabilities(featureAndProbabilities, model, "multinomial")
+
+ ProbabilisticClassifierSuite.testPredictMethods[
+ Vector, NaiveBayesModel](model, testDataset)
}
test("Naive Bayes with weighted samples") {
@@ -213,6 +216,9 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
val featureAndProbabilities = model.transform(validationDataset)
.select("features", "probability")
validateProbabilities(featureAndProbabilities, model, "bernoulli")
+
+ ProbabilisticClassifierSuite.testPredictMethods[
+ Vector, NaiveBayesModel](model, testDataset)
}
test("detect negative values") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 3638cda285ef2..2117c05fc629b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -25,12 +25,12 @@ import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTestingUtils}
+import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions._
@@ -98,7 +98,45 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
// bound how much error we allow compared to multinomial logistic regression.
val expectedMetrics = new MulticlassMetrics(results)
val ovaMetrics = new MulticlassMetrics(ovaResults)
- assert(expectedMetrics.confusionMatrix ~== ovaMetrics.confusionMatrix absTol 400)
+ assert(expectedMetrics.confusionMatrix.asML ~== ovaMetrics.confusionMatrix.asML absTol 400)
+ }
+
+ test("one-vs-rest: tuning parallelism does not change output") {
+ val ovaPar1 = new OneVsRest()
+ .setClassifier(new LogisticRegression)
+
+ val ovaModelPar1 = ovaPar1.fit(dataset)
+
+ val transformedDatasetPar1 = ovaModelPar1.transform(dataset)
+
+ val ovaResultsPar1 = transformedDatasetPar1.select("prediction", "label").rdd.map {
+ row => (row.getDouble(0), row.getDouble(1))
+ }
+
+ val ovaPar2 = new OneVsRest()
+ .setClassifier(new LogisticRegression)
+ .setParallelism(2)
+
+ val ovaModelPar2 = ovaPar2.fit(dataset)
+
+ val transformedDatasetPar2 = ovaModelPar2.transform(dataset)
+
+ val ovaResultsPar2 = transformedDatasetPar2.select("prediction", "label").rdd.map {
+ row => (row.getDouble(0), row.getDouble(1))
+ }
+
+ val metricsPar1 = new MulticlassMetrics(ovaResultsPar1)
+ val metricsPar2 = new MulticlassMetrics(ovaResultsPar2)
+ assert(metricsPar1.confusionMatrix == metricsPar2.confusionMatrix)
+
+ ovaModelPar1.models.zip(ovaModelPar2.models).foreach {
+ case (lrModel1: LogisticRegressionModel, lrModel2: LogisticRegressionModel) =>
+ assert(lrModel1.coefficients ~== lrModel2.coefficients relTol 1E-3)
+ assert(lrModel1.intercept ~== lrModel2.intercept relTol 1E-3)
+ case other =>
+ throw new AssertionError(s"Loaded OneVsRestModel expected model of type" +
+ s" LogisticRegressionModel but found ${other.getClass.getName}")
+ }
}
test("one-vs-rest: pass label metadata correctly during train") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
index 172c64aab9d3d..4ecd5a05365eb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
@@ -19,6 +19,9 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.sql.{Dataset, Row}
final class TestProbabilisticClassificationModel(
override val uid: String,
@@ -91,4 +94,61 @@ object ProbabilisticClassifierSuite {
"thresholds" -> Array(0.4, 0.6)
)
+ /**
+ * Helper for testing that a ProbabilisticClassificationModel computes
+ * the same predictions across all combinations of output columns
+ * (rawPrediction/probability/prediction) turned on/off. Makes sure the
+ * output column values match by comparing against the case with all 3 output
+ * columns turned on.
+ */
+ def testPredictMethods[
+ FeaturesType,
+ M <: ProbabilisticClassificationModel[FeaturesType, M]](
+ model: M, testData: Dataset[_]): Unit = {
+
+ val allColModel = model.copy(ParamMap.empty)
+ .setRawPredictionCol("rawPredictionAll")
+ .setProbabilityCol("probabilityAll")
+ .setPredictionCol("predictionAll")
+ val allColResult = allColModel.transform(testData)
+
+ for (rawPredictionCol <- Seq("", "rawPredictionSingle")) {
+ for (probabilityCol <- Seq("", "probabilitySingle")) {
+ for (predictionCol <- Seq("", "predictionSingle")) {
+ val newModel = model.copy(ParamMap.empty)
+ .setRawPredictionCol(rawPredictionCol)
+ .setProbabilityCol(probabilityCol)
+ .setPredictionCol(predictionCol)
+
+ val result = newModel.transform(allColResult)
+
+ import org.apache.spark.sql.functions._
+
+ val resultRawPredictionCol =
+ if (rawPredictionCol.isEmpty) col("rawPredictionAll") else col(rawPredictionCol)
+ val resultProbabilityCol =
+ if (probabilityCol.isEmpty) col("probabilityAll") else col(probabilityCol)
+ val resultPredictionCol =
+ if (predictionCol.isEmpty) col("predictionAll") else col(predictionCol)
+
+ result.select(
+ resultRawPredictionCol, col("rawPredictionAll"),
+ resultProbabilityCol, col("probabilityAll"),
+ resultPredictionCol, col("predictionAll")
+ ).collect().foreach {
+ case Row(
+ rawPredictionSingle: Vector, rawPredictionAll: Vector,
+ probabilitySingle: Vector, probabilityAll: Vector,
+ predictionSingle: Double, predictionAll: Double
+ ) => {
+ assert(rawPredictionSingle ~== rawPredictionAll relTol 1E-3)
+ assert(probabilitySingle ~== probabilityAll relTol 1E-3)
+ assert(predictionSingle ~== predictionAll relTol 1E-3)
+ }
+ }
+ }
+ }
+ }
+ }
+
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index ca2954d2f32c4..2cca2e6c04698 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -155,6 +155,8 @@ class RandomForestClassifierSuite
"probability prediction mismatch")
assert(probPred.toArray.sum ~== 1.0 relTol 1E-5)
}
+ ProbabilisticClassifierSuite.testPredictMethods[
+ Vector, RandomForestClassificationModel](model, df)
}
test("Fitting without numClasses in metadata") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala
new file mode 100644
index 0000000000000..e60ebbd7c852d
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/ClusteringEvaluatorSuite.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.evaluation
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, SparkSession}
+import org.apache.spark.sql.types.IntegerType
+
+
+class ClusteringEvaluatorSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+ import testImplicits._
+
+ test("params") {
+ ParamsSuite.checkParams(new ClusteringEvaluator)
+ }
+
+ test("read/write") {
+ val evaluator = new ClusteringEvaluator()
+ .setPredictionCol("myPrediction")
+ .setFeaturesCol("myLabel")
+ testDefaultReadWrite(evaluator)
+ }
+
+ /*
+ Use the following python code to load the data and evaluate it using scikit-learn package.
+
+ from sklearn import datasets
+ from sklearn.metrics import silhouette_score
+ iris = datasets.load_iris()
+ round(silhouette_score(iris.data, iris.target, metric='sqeuclidean'), 10)
+
+ 0.6564679231
+ */
+ test("squared euclidean Silhouette") {
+ val iris = ClusteringEvaluatorSuite.irisDataset(spark)
+ val evaluator = new ClusteringEvaluator()
+ .setFeaturesCol("features")
+ .setPredictionCol("label")
+
+ assert(evaluator.evaluate(iris) ~== 0.6564679231 relTol 1e-5)
+ }
+
+ test("number of clusters must be greater than one") {
+ val iris = ClusteringEvaluatorSuite.irisDataset(spark)
+ .where($"label" === 0.0)
+ val evaluator = new ClusteringEvaluator()
+ .setFeaturesCol("features")
+ .setPredictionCol("label")
+
+ val e = intercept[AssertionError]{
+ evaluator.evaluate(iris)
+ }
+ assert(e.getMessage.contains("Number of clusters must be greater than one"))
+ }
+
+}
+
+object ClusteringEvaluatorSuite {
+ def irisDataset(spark: SparkSession): DataFrame = {
+
+ val irisPath = Thread.currentThread()
+ .getContextClassLoader
+ .getResource("test-data/iris.libsvm")
+ .toString
+
+ spark.read.format("libsvm").load(irisPath)
+ }
+}
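For readers unfamiliar with the evaluator exercised above, here is a minimal usage sketch outside the suite; the KMeans pipeline, the dataset DataFrame, and the column names are illustrative assumptions, not part of this patch:

// Hypothetical sketch: scoring a fitted clustering with the new ClusteringEvaluator.
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.evaluation.ClusteringEvaluator

val predictions = new KMeans().setK(3).setSeed(1L).fit(dataset).transform(dataset)
val silhouette = new ClusteringEvaluator()
  .setFeaturesCol("features")       // vector column produced by the pipeline
  .setPredictionCol("prediction")   // cluster ids assigned by KMeans
  .evaluate(predictions)            // squared-Euclidean silhouette; higher is better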
diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HingeAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HingeAggregatorSuite.scala
new file mode 100644
index 0000000000000..61b48ffa10944
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/HingeAggregatorSuite.scala
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.optim.aggregator
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors}
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+class HingeAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ import DifferentiableLossAggregatorSuite.getClassificationSummarizers
+
+ @transient var instances: Array[Instance] = _
+ @transient var instancesConstantFeature: Array[Instance] = _
+ @transient var instancesConstantFeatureFiltered: Array[Instance] = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ instances = Array(
+ Instance(0.0, 0.1, Vectors.dense(1.0, 2.0)),
+ Instance(1.0, 0.5, Vectors.dense(1.5, 1.0)),
+ Instance(0.0, 0.3, Vectors.dense(4.0, 0.5))
+ )
+ instancesConstantFeature = Array(
+ Instance(0.0, 0.1, Vectors.dense(1.0, 2.0)),
+ Instance(1.0, 0.5, Vectors.dense(1.0, 1.0)),
+ Instance(1.0, 0.3, Vectors.dense(1.0, 0.5))
+ )
+ instancesConstantFeatureFiltered = Array(
+ Instance(0.0, 0.1, Vectors.dense(2.0)),
+ Instance(1.0, 0.5, Vectors.dense(1.0)),
+ Instance(2.0, 0.3, Vectors.dense(0.5))
+ )
+ }
+
+ /** Get summary statistics for some data and create a new HingeAggregator. */
+ private def getNewAggregator(
+ instances: Array[Instance],
+ coefficients: Vector,
+ fitIntercept: Boolean): HingeAggregator = {
+ val (featuresSummarizer, ySummarizer) =
+ DifferentiableLossAggregatorSuite.getClassificationSummarizers(instances)
+ val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
+ val bcFeaturesStd = spark.sparkContext.broadcast(featuresStd)
+ val bcCoefficients = spark.sparkContext.broadcast(coefficients)
+ new HingeAggregator(bcFeaturesStd, fitIntercept)(bcCoefficients)
+ }
+
+ test("aggregator add method input size") {
+ val coefArray = Array(1.0, 2.0)
+ val interceptArray = Array(2.0)
+ val agg = getNewAggregator(instances, Vectors.dense(coefArray ++ interceptArray),
+ fitIntercept = true)
+ withClue("HingeAggregator features dimension must match coefficients dimension") {
+ intercept[IllegalArgumentException] {
+ agg.add(Instance(1.0, 1.0, Vectors.dense(2.0)))
+ }
+ }
+ }
+
+ test("negative weight") {
+ val coefArray = Array(1.0, 2.0)
+ val interceptArray = Array(2.0)
+ val agg = getNewAggregator(instances, Vectors.dense(coefArray ++ interceptArray),
+ fitIntercept = true)
+ withClue("HingeAggregator does not support negative instance weights") {
+ intercept[IllegalArgumentException] {
+ agg.add(Instance(1.0, -1.0, Vectors.dense(2.0, 1.0)))
+ }
+ }
+ }
+
+ test("check sizes") {
+ val rng = new scala.util.Random
+ val numFeatures = instances.head.features.size
+ val coefWithIntercept = Vectors.dense(Array.fill(numFeatures + 1)(rng.nextDouble))
+ val coefWithoutIntercept = Vectors.dense(Array.fill(numFeatures)(rng.nextDouble))
+ val aggIntercept = getNewAggregator(instances, coefWithIntercept, fitIntercept = true)
+ val aggNoIntercept = getNewAggregator(instances, coefWithoutIntercept,
+ fitIntercept = false)
+ instances.foreach(aggIntercept.add)
+ instances.foreach(aggNoIntercept.add)
+
+ assert(aggIntercept.gradient.size === numFeatures + 1)
+ assert(aggNoIntercept.gradient.size === numFeatures)
+ }
+
+ test("check correctness") {
+ val coefArray = Array(1.0, 2.0)
+ val intercept = 1.0
+ val numFeatures = instances.head.features.size
+ val (featuresSummarizer, _) = getClassificationSummarizers(instances)
+ val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
+ val weightSum = instances.map(_.weight).sum
+
+ val agg = getNewAggregator(instances, Vectors.dense(coefArray ++ Array(intercept)),
+ fitIntercept = true)
+ instances.foreach(agg.add)
+
+ // compute the loss
+ val stdCoef = coefArray.indices.map(i => coefArray(i) / featuresStd(i)).toArray
+ val lossSum = instances.map { case Instance(l, w, f) =>
+ val margin = BLAS.dot(Vectors.dense(stdCoef), f) + intercept
+ val labelScaled = 2 * l - 1.0
+ if (1.0 > labelScaled * margin) {
+ (1.0 - labelScaled * margin) * w
+ } else {
+ 0.0
+ }
+ }.sum
+ val loss = lossSum / weightSum
+
+ // compute the gradients
+ val gradientCoef = new Array[Double](numFeatures)
+ var gradientIntercept = 0.0
+ instances.foreach { case Instance(l, w, f) =>
+ val margin = BLAS.dot(f, Vectors.dense(coefArray)) + intercept
+ if (1.0 > (2 * l - 1.0) * margin) {
+ gradientCoef.indices.foreach { i =>
+ gradientCoef(i) += f(i) * -(2 * l - 1.0) * w / featuresStd(i)
+ }
+ gradientIntercept += -(2 * l - 1.0) * w
+ }
+ }
+ val gradient = Vectors.dense((gradientCoef ++ Array(gradientIntercept)).map(_ / weightSum))
+
+ assert(loss ~== agg.loss relTol 0.01)
+ assert(gradient ~== agg.gradient relTol 0.01)
+ }
+
+ test("check with zero standard deviation") {
+ val binaryCoefArray = Array(1.0, 2.0)
+ val intercept = 1.0
+ val aggConstantFeatureBinary = getNewAggregator(instancesConstantFeature,
+ Vectors.dense(binaryCoefArray ++ Array(intercept)), fitIntercept = true)
+ instancesConstantFeature.foreach(aggConstantFeatureBinary.add)
+
+ val aggConstantFeatureBinaryFiltered = getNewAggregator(instancesConstantFeatureFiltered,
+ Vectors.dense(binaryCoefArray ++ Array(intercept)), fitIntercept = true)
+ instancesConstantFeatureFiltered.foreach(aggConstantFeatureBinaryFiltered.add)
+
+ // constant features should not affect gradient
+ assert(aggConstantFeatureBinary.gradient(0) === 0.0)
+ assert(aggConstantFeatureBinary.gradient(1) == aggConstantFeatureBinaryFiltered.gradient(0))
+ }
+
+}
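For reference, the quantity recomputed by hand in the "check correctness" test above is the weighted hinge loss on standardized features, with the {0, 1} labels mapped to {-1, +1}. A sketch of the loss the assertions compare against, where sigma_j is the per-feature standard deviation:

L(\beta, \beta_0) = \frac{1}{\sum_i w_i} \sum_i w_i \, \max\Big(0,\ 1 - y_i \Big(\sum_j \frac{\beta_j}{\sigma_j} x_{ij} + \beta_0\Big)\Big), \qquad y_i = 2\,l_i - 1

The gradient check accumulates the matching subgradient, scaling each coefficient component by 1 / sigma_j and leaving the intercept unscaled.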
diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregatorSuite.scala
index 2b29c67d859db..4c7913d5d2577 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/optim/aggregator/LogisticAggregatorSuite.scala
@@ -28,6 +28,7 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
@transient var instances: Array[Instance] = _
@transient var instancesConstantFeature: Array[Instance] = _
+ @transient var instancesConstantFeatureFiltered: Array[Instance] = _
override def beforeAll(): Unit = {
super.beforeAll()
@@ -41,6 +42,11 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
Instance(1.0, 0.5, Vectors.dense(1.0, 1.0)),
Instance(2.0, 0.3, Vectors.dense(1.0, 0.5))
)
+ instancesConstantFeatureFiltered = Array(
+ Instance(0.0, 0.1, Vectors.dense(2.0)),
+ Instance(1.0, 0.5, Vectors.dense(1.0)),
+ Instance(2.0, 0.3, Vectors.dense(0.5))
+ )
}
/** Get summary statistics for some data and create a new LogisticAggregator. */
@@ -211,8 +217,6 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
}.sum
val loss = lossSum / weightSum
-
-
// compute the gradients
val gradientCoef = new Array[Double](numFeatures)
var gradientIntercept = 0.0
@@ -233,21 +237,44 @@ class LogisticAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
val binaryInstances = instancesConstantFeature.map { instance =>
if (instance.label <= 1.0) instance else Instance(0.0, instance.weight, instance.features)
}
+ val binaryInstancesFiltered = instancesConstantFeatureFiltered.map { instance =>
+ if (instance.label <= 1.0) instance else Instance(0.0, instance.weight, instance.features)
+ }
val coefArray = Array(1.0, 2.0, -2.0, 3.0, 0.0, -1.0)
+ val coefArrayFiltered = Array(3.0, 0.0, -1.0)
val interceptArray = Array(4.0, 2.0, -3.0)
val aggConstantFeature = getNewAggregator(instancesConstantFeature,
Vectors.dense(coefArray ++ interceptArray), fitIntercept = true, isMultinomial = true)
- instances.foreach(aggConstantFeature.add)
+ val aggConstantFeatureFiltered = getNewAggregator(instancesConstantFeatureFiltered,
+ Vectors.dense(coefArrayFiltered ++ interceptArray), fitIntercept = true, isMultinomial = true)
+
+ instancesConstantFeature.foreach(aggConstantFeature.add)
+ instancesConstantFeatureFiltered.foreach(aggConstantFeatureFiltered.add)
+
// constant features should not affect gradient
- assert(aggConstantFeature.gradient(0) === 0.0)
+ def validateGradient(grad: Vector, gradFiltered: Vector, numCoefficientSets: Int): Unit = {
+ for (i <- 0 until numCoefficientSets) {
+ assert(grad(i) === 0.0)
+ assert(grad(numCoefficientSets + i) == gradFiltered(i))
+ }
+ }
+
+ validateGradient(aggConstantFeature.gradient, aggConstantFeatureFiltered.gradient, 3)
val binaryCoefArray = Array(1.0, 2.0)
+ val binaryCoefArrayFiltered = Array(2.0)
val intercept = 1.0
val aggConstantFeatureBinary = getNewAggregator(binaryInstances,
Vectors.dense(binaryCoefArray ++ Array(intercept)), fitIntercept = true,
isMultinomial = false)
- instances.foreach(aggConstantFeatureBinary.add)
+ val aggConstantFeatureBinaryFiltered = getNewAggregator(binaryInstancesFiltered,
+ Vectors.dense(binaryCoefArrayFiltered ++ Array(intercept)), fitIntercept = true,
+ isMultinomial = false)
+ binaryInstances.foreach(aggConstantFeatureBinary.add)
+ binaryInstancesFiltered.foreach(aggConstantFeatureBinaryFiltered.add)
+
// constant features should not affect gradient
- assert(aggConstantFeatureBinary.gradient(0) === 0.0)
+ validateGradient(aggConstantFeatureBinary.gradient,
+ aggConstantFeatureBinaryFiltered.gradient, 1)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index e7bd4eb9e0adf..f470dca7dbd0a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -715,7 +715,7 @@ class LinearRegressionSuite
assert(modelNoPredictionColFieldNames.exists(s => s.startsWith("prediction_")))
// Residuals in [[LinearRegressionResults]] should equal those manually computed
- val expectedResiduals = datasetWithDenseFeature.select("features", "label")
+ datasetWithDenseFeature.select("features", "label")
.rdd
.map { case Row(features: DenseVector, label: Double) =>
val prediction =
diff --git a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala
index dfb733ff6e761..1ea851ef2d676 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/stat/SummarizerSuite.scala
@@ -402,6 +402,24 @@ class SummarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(summarizer.count === 6)
}
+ test("summarizer buffer zero variance test (SPARK-21818)") {
+ val summarizer1 = new SummarizerBuffer()
+ .add(Vectors.dense(3.0), 0.7)
+ val summarizer2 = new SummarizerBuffer()
+ .add(Vectors.dense(3.0), 0.4)
+ val summarizer3 = new SummarizerBuffer()
+ .add(Vectors.dense(3.0), 0.5)
+ val summarizer4 = new SummarizerBuffer()
+ .add(Vectors.dense(3.0), 0.4)
+
+ val summarizer = summarizer1
+ .merge(summarizer2)
+ .merge(summarizer3)
+ .merge(summarizer4)
+
+ assert(summarizer.variance(0) >= 0.0)
+ }
+
test("summarizer buffer merging summarizer with empty summarizer") {
// If one of two is non-empty, this should return the non-empty summarizer.
// If both of them are empty, then just return the empty summarizer.
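A short note on the zero-variance regression test above (SPARK-21818): all four merged points are identical, so the true weighted variance is exactly zero, yet merging weighted moments in floating point can, through cancellation, produce a marginally negative value; the assertion pins the invariant that holds analytically:

\mathrm{Var}(x) = \frac{\sum_i w_i \, (x_i - \bar{x})^2}{\sum_i w_i} \ \ge\ 0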
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index dc6043ef19fe2..a8d4377cff2d1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -120,6 +120,33 @@ class CrossValidatorSuite
}
}
+ test("cross validation with parallel evaluation") {
+ val lr = new LogisticRegression
+ val lrParamMaps = new ParamGridBuilder()
+ .addGrid(lr.regParam, Array(0.001, 1000.0))
+ .addGrid(lr.maxIter, Array(0, 3))
+ .build()
+ val eval = new BinaryClassificationEvaluator
+ val cv = new CrossValidator()
+ .setEstimator(lr)
+ .setEstimatorParamMaps(lrParamMaps)
+ .setEvaluator(eval)
+ .setNumFolds(2)
+ .setParallelism(1)
+ val cvSerialModel = cv.fit(dataset)
+ cv.setParallelism(2)
+ val cvParallelModel = cv.fit(dataset)
+
+ val serialMetrics = cvSerialModel.avgMetrics.sorted
+ val parallelMetrics = cvParallelModel.avgMetrics.sorted
+ assert(serialMetrics === parallelMetrics)
+
+ val parentSerial = cvSerialModel.bestModel.parent.asInstanceOf[LogisticRegression]
+ val parentParallel = cvParallelModel.bestModel.parent.asInstanceOf[LogisticRegression]
+ assert(parentSerial.getRegParam === parentParallel.getRegParam)
+ assert(parentSerial.getMaxIter === parentParallel.getMaxIter)
+ }
+
test("read/write: CrossValidator with simple estimator") {
val lr = new LogisticRegression().setMaxIter(3)
val evaluator = new BinaryClassificationEvaluator()
@@ -187,14 +214,13 @@ class CrossValidatorSuite
cv2.getEstimator match {
case ova2: OneVsRest =>
assert(ova.uid === ova2.uid)
- val classifier = ova2.getClassifier
- classifier match {
+ ova2.getClassifier match {
case lr: LogisticRegression =>
assert(ova.getClassifier.asInstanceOf[LogisticRegression].getMaxIter
=== lr.getMaxIter)
- case _ =>
+ case other =>
throw new AssertionError(s"Loaded CrossValidator expected estimator of type" +
- s" LogisticREgression but found ${classifier.getClass.getName}")
+ s" LogisticRegression but found ${other.getClass.getName}")
}
case other =>
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index 7c97865e45202..74801733381c1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -36,9 +36,14 @@ class TrainValidationSplitSuite
import testImplicits._
- test("train validation with logistic regression") {
- val dataset = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2).toDF()
+ @transient var dataset: Dataset[_] = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ dataset = sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2).toDF()
+ }
+ test("train validation with logistic regression") {
val lr = new LogisticRegression
val lrParamMaps = new ParamGridBuilder()
.addGrid(lr.regParam, Array(0.001, 1000.0))
@@ -117,6 +122,32 @@ class TrainValidationSplitSuite
}
}
+ test("train validation with parallel evaluation") {
+ val lr = new LogisticRegression
+ val lrParamMaps = new ParamGridBuilder()
+ .addGrid(lr.regParam, Array(0.001, 1000.0))
+ .addGrid(lr.maxIter, Array(0, 3))
+ .build()
+ val eval = new BinaryClassificationEvaluator
+ val cv = new TrainValidationSplit()
+ .setEstimator(lr)
+ .setEstimatorParamMaps(lrParamMaps)
+ .setEvaluator(eval)
+ .setParallelism(1)
+ val cvSerialModel = cv.fit(dataset)
+ cv.setParallelism(2)
+ val cvParallelModel = cv.fit(dataset)
+
+ val serialMetrics = cvSerialModel.validationMetrics.sorted
+ val parallelMetrics = cvParallelModel.validationMetrics.sorted
+ assert(serialMetrics === parallelMetrics)
+
+ val parentSerial = cvSerialModel.bestModel.parent.asInstanceOf[LogisticRegression]
+ val parentParallel = cvParallelModel.bestModel.parent.asInstanceOf[LogisticRegression]
+ assert(parentSerial.getRegParam === parentParallel.getRegParam)
+ assert(parentSerial.getMaxIter === parentParallel.getMaxIter)
+ }
+
test("read/write: TrainValidationSplit") {
val lr = new LogisticRegression().setMaxIter(3)
val evaluator = new BinaryClassificationEvaluator()
@@ -173,14 +204,13 @@ class TrainValidationSplitSuite
tvs2.getEstimator match {
case ova2: OneVsRest =>
assert(ova.uid === ova2.uid)
- val classifier = ova2.getClassifier
- classifier match {
+ ova2.getClassifier match {
case lr: LogisticRegression =>
assert(ova.getClassifier.asInstanceOf[LogisticRegression].getMaxIter
=== lr.getMaxIter)
- case _ =>
+ case other =>
throw new AssertionError(s"Loaded TrainValidationSplit expected estimator of type" +
- s" LogisticREgression but found ${classifier.getClass.getName}")
+ s" LogisticRegression but found ${other.getClass.getName}")
}
case other =>
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
index 99d52fabc5309..a08917ac1ebed 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala
@@ -23,18 +23,16 @@ import org.apache.spark.mllib.util.TestingUtils._
class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
- private def areWithinEpsilon(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5
-
- private def pairsWithinEpsilon(x: ((Double, Double), (Double, Double))): Boolean =
- (x._1._1 ~= x._2._1 absTol 1E-5) && (x._1._2 ~= x._2._2 absTol 1E-5)
-
- private def assertSequencesMatch(left: Seq[Double], right: Seq[Double]): Unit = {
- assert(left.zip(right).forall(areWithinEpsilon))
+ private def assertSequencesMatch(actual: Seq[Double], expected: Seq[Double]): Unit = {
+ actual.zip(expected).foreach { case (a, e) => assert(a ~== e absTol 1.0e-5) }
}
- private def assertTupleSequencesMatch(left: Seq[(Double, Double)],
- right: Seq[(Double, Double)]): Unit = {
- assert(left.zip(right).forall(pairsWithinEpsilon))
+ private def assertTupleSequencesMatch(actual: Seq[(Double, Double)],
+ expected: Seq[(Double, Double)]): Unit = {
+ actual.zip(expected).foreach { case ((ax, ay), (ex, ey)) =>
+ assert(ax ~== ex absTol 1.0e-5)
+ assert(ay ~== ey absTol 1.0e-5)
+ }
}
private def validateMetrics(metrics: BinaryClassificationMetrics,
@@ -44,7 +42,7 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark
expectedFMeasures1: Seq[Double],
expectedFmeasures2: Seq[Double],
expectedPrecisions: Seq[Double],
- expectedRecalls: Seq[Double]) = {
+ expectedRecalls: Seq[Double]): Unit = {
assertSequencesMatch(metrics.thresholds().collect(), expectedThresholds)
assertTupleSequencesMatch(metrics.roc().collect(), expectedROCCurve)
@@ -111,7 +109,7 @@ class BinaryClassificationMetricsSuite extends SparkFunSuite with MLlibTestSpark
val fpr = Seq(1.0)
val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0))
val pr = recalls.zip(precisions)
- val prCurve = Seq((0.0, 1.0)) ++ pr
+ val prCurve = Seq((0.0, 0.0)) ++ pr
val f1 = pr.map {
case (0, 0) => 0.0
case (r, p) => 2.0 * (p * r) / (p + r)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala
index 2f90afdcee55e..8eab12416a698 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/PCASuite.scala
@@ -48,4 +48,10 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext {
}
assert(pca.explainedVariance ~== explainedVariance relTol 1e-8)
}
+
+ test("memory cost computation") {
+ assert(PCAUtil.memoryCost(10, 100) < Int.MaxValue)
+ // check overflowing
+ assert(PCAUtil.memoryCost(40000, 60000) > Int.MaxValue)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
index 6736e7d3db511..c8ac92eecf40b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
@@ -23,7 +23,7 @@ import scala.collection.mutable.{Map => MutableMap}
import breeze.linalg.{CSCMatrix, Matrix => BM}
import org.mockito.Mockito.when
-import org.scalatest.mock.MockitoSugar._
+import org.scalatest.mockito.MockitoSugar._
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.{linalg => newlinalg}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
index 797e84fcc7377..c6466bc918dd0 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala
@@ -270,4 +270,22 @@ class MultivariateOnlineSummarizerSuite extends SparkFunSuite {
assert(summarizer3.max ~== Vectors.dense(10.0, 0.0) absTol 1e-14)
assert(summarizer3.min ~== Vectors.dense(0.0, -10.0) absTol 1e-14)
}
+
+ test ("test zero variance (SPARK-21818)") {
+ val summarizer1 = (new MultivariateOnlineSummarizer)
+ .add(Vectors.dense(3.0), 0.7)
+ val summarizer2 = (new MultivariateOnlineSummarizer)
+ .add(Vectors.dense(3.0), 0.4)
+ val summarizer3 = (new MultivariateOnlineSummarizer)
+ .add(Vectors.dense(3.0), 0.5)
+ val summarizer4 = (new MultivariateOnlineSummarizer)
+ .add(Vectors.dense(3.0), 0.4)
+
+ val summarizer = summarizer1
+ .merge(summarizer2)
+ .merge(summarizer3)
+ .merge(summarizer4)
+
+ assert(summarizer.variance(0) >= 0.0)
+ }
}
diff --git a/pom.xml b/pom.xml
index b7f65cae3e96e..67764aeafbc85 100644
--- a/pom.xml
+++ b/pom.xml
@@ -104,8 +104,6 @@
examples
repl
launcher
- external/kafka-0-8
- external/kafka-0-8-assembly
external/kafka-0-10
external/kafka-0-10-assembly
external/kafka-0-10-sql
@@ -144,7 +142,7 @@
1.4.0
nohive
3.1.0
- 0.8.0
+ 0.8.4
2.4.0
2.0.8
3.2.2
@@ -932,6 +930,11 @@
scala-actors
${scala.version}
+
+ org.scala-lang.modules
+ scala-parser-combinators_${scala.binary.version}
+ 1.0.4
+
org.scala-lang
scalap
@@ -940,7 +943,7 @@
org.scalatest
scalatest_${scala.binary.version}
- 2.2.6
+ 3.0.3
test
@@ -952,7 +955,7 @@
org.scalacheck
scalacheck_${scala.binary.version}
- 1.12.5
+ 1.13.5
test
@@ -2298,13 +2301,6 @@
testCompile
-
- attach-scaladocs
- verify
-
- doc-jar
-
-
${scala.version}
@@ -2327,7 +2323,7 @@
${java.version}
-target
${java.version}
- -Xlint:all,-serial,-path
+ -Xlint:all,-serial,-path,-try
@@ -2734,7 +2730,7 @@
org.scalastyle
scalastyle-maven-plugin
- 0.9.0
+ 1.0.0
false
true
@@ -3014,6 +3010,14 @@
+
+ kafka-0-8
+
+ external/kafka-0-8
+ external/kafka-0-8-assembly
+
+
+
test-java-home
@@ -3029,12 +3033,10 @@
scala-2.11
-
-
com.google.guava
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala
index 6c8619e3c3c13..7e85de91c5d36 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/config.scala
@@ -56,18 +56,64 @@ package object config {
.stringConf
.createOptional
- private [spark] val DRIVER_LABELS =
+ private[spark] val DRIVER_LABELS =
ConfigBuilder("spark.mesos.driver.labels")
- .doc("Mesos labels to add to the driver. Labels are free-form key-value pairs. Key-value " +
+ .doc("Mesos labels to add to the driver. Labels are free-form key-value pairs. Key-value " +
"pairs should be separated by a colon, and commas used to list more than one." +
"Ex. key:value,key2:value2")
.stringConf
.createOptional
- private [spark] val DRIVER_FAILOVER_TIMEOUT =
+ private[spark] val SECRET_NAME =
+ ConfigBuilder("spark.mesos.driver.secret.names")
+ .doc("A comma-separated list of secret reference names. Consult the Mesos Secret protobuf " +
+ "for more information.")
+ .stringConf
+ .toSequence
+ .createOptional
+
+ private[spark] val SECRET_VALUE =
+ ConfigBuilder("spark.mesos.driver.secret.values")
+ .doc("A comma-separated list of secret values.")
+ .stringConf
+ .toSequence
+ .createOptional
+
+ private[spark] val SECRET_ENVKEY =
+ ConfigBuilder("spark.mesos.driver.secret.envkeys")
+ .doc("A comma-separated list of the environment variables to contain the secrets." +
+ "The environment variable will be set on the driver.")
+ .stringConf
+ .toSequence
+ .createOptional
+
+ private[spark] val SECRET_FILENAME =
+ ConfigBuilder("spark.mesos.driver.secret.filenames")
+ .doc("A comma-seperated list of file paths secret will be written to. Consult the Mesos " +
+ "Secret protobuf for more information.")
+ .stringConf
+ .toSequence
+ .createOptional
+
+ private[spark] val DRIVER_FAILOVER_TIMEOUT =
ConfigBuilder("spark.mesos.driver.failoverTimeout")
.doc("Amount of time in seconds that the master will wait to hear from the driver, " +
"during a temporary disconnection, before tearing down all the executors.")
.doubleConf
.createWithDefault(0.0)
+
+ private[spark] val NETWORK_NAME =
+ ConfigBuilder("spark.mesos.network.name")
+ .doc("Attach containers to the given named network. If this job is launched " +
+ "in cluster mode, also launch the driver in the given named network.")
+ .stringConf
+ .createOptional
+
+ private[spark] val NETWORK_LABELS =
+ ConfigBuilder("spark.mesos.network.labels")
+ .doc("Network labels to pass to CNI plugins. This is a comma-separated list " +
+ "of key-value pairs, where each key-value pair has the format key:value. " +
+ "Example: key1:val1,key2:val2")
+ .stringConf
+ .createOptional
}
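To make the new configuration surface concrete, here is a hypothetical SparkConf wiring together the secret and network settings declared above. The names, paths, and label values are illustrative (they mirror the test fixtures later in this patch); secret names and secret values are mutually exclusive, and each list must line up one-to-one with its envkeys/filenames counterpart:

// Hypothetical sketch only: reference secrets surfaced both as env vars and as files,
// plus a named CNI network with labels.
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.mesos.driver.secret.names", "/path/to/secret,/anothersecret")
  .set("spark.mesos.driver.secret.envkeys", "SECRET_ENV_KEY,PASSWORD")
  .set("spark.mesos.driver.secret.filenames", "/topsecret,/mypassword")
  .set("spark.mesos.network.name", "test-network-name")
  .set("spark.mesos.network.labels", "key1:val1,key2:val2")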
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
index 9ee9cb1e79306..ec533f91474f2 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
@@ -28,6 +28,7 @@ import org.apache.mesos.{Scheduler, SchedulerDriver}
import org.apache.mesos.Protos.{TaskState => MesosTaskState, _}
import org.apache.mesos.Protos.Environment.Variable
import org.apache.mesos.Protos.TaskStatus.Reason
+import org.apache.mesos.protobuf.ByteString
import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState}
import org.apache.spark.deploy.mesos.MesosDriverDescription
@@ -386,12 +387,46 @@ private[spark] class MesosClusterScheduler(
val env = desc.conf.getAllWithPrefix("spark.mesos.driverEnv.") ++ commandEnv
val envBuilder = Environment.newBuilder()
+
+ // add normal environment variables
env.foreach { case (k, v) =>
envBuilder.addVariables(Variable.newBuilder().setName(k).setValue(v))
}
+
+ // add secret environment variables
+ getSecretEnvVar(desc).foreach { variable =>
+ if (variable.getSecret.getReference.isInitialized) {
+ logInfo(s"Setting reference secret ${variable.getSecret.getReference.getName}" +
+ s"on file ${variable.getName}")
+ } else {
+ logInfo(s"Setting secret on environment variable name=${variable.getName}")
+ }
+ envBuilder.addVariables(variable)
+ }
+
envBuilder.build()
}
+ private def getSecretEnvVar(desc: MesosDriverDescription): List[Variable] = {
+ val secrets = getSecrets(desc)
+ val secretEnvKeys = desc.conf.get(config.SECRET_ENVKEY).getOrElse(Nil)
+ if (illegalSecretInput(secretEnvKeys, secrets)) {
+ throw new SparkException(
+ s"Need to give equal numbers of secrets and environment keys " +
+ s"for environment-based reference secrets got secrets $secrets, " +
+ s"and keys $secretEnvKeys")
+ }
+
+ secrets.zip(secretEnvKeys).map {
+ case (s, k) =>
+ Variable.newBuilder()
+ .setName(k)
+ .setType(Variable.Type.SECRET)
+ .setSecret(s)
+ .build
+ }.toList
+ }
+
private def getDriverUris(desc: MesosDriverDescription): List[CommandInfo.URI] = {
val confUris = List(conf.getOption("spark.mesos.uris"),
desc.conf.getOption("spark.mesos.uris"),
@@ -529,18 +564,104 @@ private[spark] class MesosClusterScheduler(
val appName = desc.conf.get("spark.app.name")
+ val driverLabels = MesosProtoUtils.mesosLabels(desc.conf.get(config.DRIVER_LABELS)
+ .getOrElse(""))
+
TaskInfo.newBuilder()
.setTaskId(taskId)
.setName(s"Driver for ${appName}")
.setSlaveId(offer.offer.getSlaveId)
.setCommand(buildDriverCommand(desc))
+ .setContainer(getContainerInfo(desc))
.addAllResources(cpuResourcesToUse.asJava)
.addAllResources(memResourcesToUse.asJava)
- .setLabels(MesosProtoUtils.mesosLabels(desc.conf.get(config.DRIVER_LABELS).getOrElse("")))
- .setContainer(MesosSchedulerBackendUtil.containerInfo(desc.conf))
+ .setLabels(driverLabels)
.build
}
+ private def getContainerInfo(desc: MesosDriverDescription): ContainerInfo.Builder = {
+ val containerInfo = MesosSchedulerBackendUtil.containerInfo(desc.conf)
+
+ getSecretVolume(desc).foreach { volume =>
+ if (volume.getSource.getSecret.getReference.isInitialized) {
+ logInfo(s"Setting reference secret ${volume.getSource.getSecret.getReference.getName}" +
+ s"on file ${volume.getContainerPath}")
+ } else {
+ logInfo(s"Setting secret on file name=${volume.getContainerPath}")
+ }
+ containerInfo.addVolumes(volume)
+ }
+
+ containerInfo
+ }
+
+
+ private def getSecrets(desc: MesosDriverDescription): Seq[Secret] = {
+ def createValueSecret(data: String): Secret = {
+ Secret.newBuilder()
+ .setType(Secret.Type.VALUE)
+ .setValue(Secret.Value.newBuilder().setData(ByteString.copyFrom(data.getBytes)))
+ .build()
+ }
+
+ def createReferenceSecret(name: String): Secret = {
+ Secret.newBuilder()
+ .setReference(Secret.Reference.newBuilder().setName(name))
+ .setType(Secret.Type.REFERENCE)
+ .build()
+ }
+
+ val referenceSecrets: Seq[Secret] =
+ desc.conf.get(config.SECRET_NAME).getOrElse(Nil).map(s => createReferenceSecret(s))
+
+ val valueSecrets: Seq[Secret] = {
+ desc.conf.get(config.SECRET_VALUE).getOrElse(Nil).map(s => createValueSecret(s))
+ }
+
+ if (valueSecrets.nonEmpty && referenceSecrets.nonEmpty) {
+ throw new SparkException("Cannot specify VALUE type secrets and REFERENCE types ones")
+ }
+
+ if (referenceSecrets.nonEmpty) referenceSecrets else valueSecrets
+ }
+
+ private def illegalSecretInput(dest: Seq[String], s: Seq[Secret]): Boolean = {
+    if (dest.isEmpty) { // no destination set (i.e., not using secrets of this type)
+ return false
+ }
+ if (dest.nonEmpty && s.nonEmpty) {
+ // make sure there is a destination for each secret of this type
+ if (dest.length != s.length) {
+ return true
+ }
+ }
+ false
+ }
+
+ private def getSecretVolume(desc: MesosDriverDescription): List[Volume] = {
+ val secrets = getSecrets(desc)
+ val secretPaths: Seq[String] =
+ desc.conf.get(config.SECRET_FILENAME).getOrElse(Nil)
+
+ if (illegalSecretInput(secretPaths, secrets)) {
+ throw new SparkException(
+ s"Need to give equal numbers of secrets and file paths for file-based " +
+ s"reference secrets got secrets $secrets, and paths $secretPaths")
+ }
+
+ secrets.zip(secretPaths).map {
+ case (s, p) =>
+ val source = Volume.Source.newBuilder()
+ .setType(Volume.Source.Type.SECRET)
+ .setSecret(s)
+ Volume.newBuilder()
+ .setContainerPath(p)
+ .setSource(source)
+ .setMode(Volume.Mode.RO)
+ .build
+ }.toList
+ }
+
/**
* This method takes all the possible candidates and attempt to schedule them with Mesos offers.
* Every time a new task is scheduled, the afterLaunchCallback is called to perform post scheduled
@@ -584,9 +705,14 @@ private[spark] class MesosClusterScheduler(
} catch {
case e: SparkException =>
afterLaunchCallback(submission.submissionId)
- finishedDrivers += new MesosClusterSubmissionState(submission, TaskID.newBuilder().
- setValue(submission.submissionId).build(), SlaveID.newBuilder().setValue("").
- build(), None, null, None, getDriverFrameworkID(submission))
+ finishedDrivers += new MesosClusterSubmissionState(
+ submission,
+ TaskID.newBuilder().setValue(submission.submissionId).build(),
+ SlaveID.newBuilder().setValue("").build(),
+ None,
+ null,
+ None,
+ getDriverFrameworkID(submission))
logError(s"Failed to launch the driver with id: ${submission.submissionId}, " +
s"cpu: $driverCpu, mem: $driverMem, reason: ${e.getMessage}")
}
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala
index e6b09572121d6..26699873145b4 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala
@@ -22,15 +22,15 @@ import java.util.{Collections, List => JList}
import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong}
import java.util.concurrent.locks.ReentrantLock
+import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _}
+import org.apache.mesos.SchedulerDriver
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.concurrent.Future
-import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _}
-import org.apache.mesos.SchedulerDriver
-
import org.apache.spark.{SecurityManager, SparkContext, SparkException, TaskState}
import org.apache.spark.deploy.mesos.config._
+import org.apache.spark.deploy.security.HadoopDelegationTokenManager
import org.apache.spark.internal.config
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient
@@ -55,8 +55,10 @@ private[spark] class MesosCoarseGrainedSchedulerBackend(
master: String,
securityManager: SecurityManager)
extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv)
- with org.apache.mesos.Scheduler
- with MesosSchedulerUtils {
+ with org.apache.mesos.Scheduler with MesosSchedulerUtils {
+
+ override def hadoopDelegationTokenManager: Option[HadoopDelegationTokenManager] =
+ Some(new HadoopDelegationTokenManager(sc.conf, sc.hadoopConfiguration))
// Blacklist a slave after this many failures
private val MAX_SLAVE_FAILURES = 2
@@ -668,7 +670,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend(
}
private def executorHostname(offer: Offer): String = {
- if (sc.conf.getOption("spark.mesos.network.name").isDefined) {
+ if (sc.conf.get(NETWORK_NAME).isDefined) {
// The agent's IP is not visible in a CNI container, so we bind to 0.0.0.0
"0.0.0.0"
} else {
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala
index fbcbc55099ec5..f29e541addf23 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala
@@ -21,6 +21,7 @@ import org.apache.mesos.Protos.{ContainerInfo, Image, NetworkInfo, Parameter, Vo
import org.apache.mesos.Protos.ContainerInfo.{DockerInfo, MesosInfo}
import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.deploy.mesos.config.{NETWORK_LABELS, NETWORK_NAME}
import org.apache.spark.internal.Logging
/**
@@ -121,7 +122,7 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging {
.toList
}
- def containerInfo(conf: SparkConf): ContainerInfo = {
+ def containerInfo(conf: SparkConf): ContainerInfo.Builder = {
val containerType = if (conf.contains("spark.mesos.executor.docker.image") &&
conf.get("spark.mesos.containerizer", "docker") == "docker") {
ContainerInfo.Type.DOCKER
@@ -148,8 +149,7 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging {
.getOrElse(List.empty)
if (containerType == ContainerInfo.Type.DOCKER) {
- containerInfo
- .setDocker(dockerInfo(image, forcePullImage, portMaps, params))
+ containerInfo.setDocker(dockerInfo(image, forcePullImage, portMaps, params))
} else {
containerInfo.setMesos(mesosInfo(image, forcePullImage))
}
@@ -161,12 +161,16 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging {
volumes.foreach(_.foreach(containerInfo.addVolumes(_)))
}
- conf.getOption("spark.mesos.network.name").map { name =>
- val info = NetworkInfo.newBuilder().setName(name).build()
+ conf.get(NETWORK_NAME).map { name =>
+ val networkLabels = MesosProtoUtils.mesosLabels(conf.get(NETWORK_LABELS).getOrElse(""))
+ val info = NetworkInfo.newBuilder()
+ .setName(name)
+ .setLabels(networkLabels)
+ .build()
containerInfo.addNetworkInfos(info)
}
- containerInfo.build()
+ containerInfo
}
private def dockerInfo(
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
index 7ec116c74b10f..6fcb30af8a733 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala
@@ -510,12 +510,20 @@ trait MesosSchedulerUtils extends Logging {
}
def mesosToTaskState(state: MesosTaskState): TaskState.TaskState = state match {
- case MesosTaskState.TASK_STAGING | MesosTaskState.TASK_STARTING => TaskState.LAUNCHING
- case MesosTaskState.TASK_RUNNING | MesosTaskState.TASK_KILLING => TaskState.RUNNING
+ case MesosTaskState.TASK_STAGING |
+ MesosTaskState.TASK_STARTING => TaskState.LAUNCHING
+ case MesosTaskState.TASK_RUNNING |
+ MesosTaskState.TASK_KILLING => TaskState.RUNNING
case MesosTaskState.TASK_FINISHED => TaskState.FINISHED
- case MesosTaskState.TASK_FAILED => TaskState.FAILED
+ case MesosTaskState.TASK_FAILED |
+ MesosTaskState.TASK_GONE |
+ MesosTaskState.TASK_GONE_BY_OPERATOR => TaskState.FAILED
case MesosTaskState.TASK_KILLED => TaskState.KILLED
- case MesosTaskState.TASK_LOST | MesosTaskState.TASK_ERROR => TaskState.LOST
+ case MesosTaskState.TASK_LOST |
+ MesosTaskState.TASK_ERROR |
+ MesosTaskState.TASK_DROPPED |
+ MesosTaskState.TASK_UNKNOWN |
+ MesosTaskState.TASK_UNREACHABLE => TaskState.LOST
}
def taskStateToMesos(state: TaskState.TaskState): MesosTaskState = state match {
diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala
index 0bb47906347d5..ff63e3f4ccfc3 100644
--- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala
+++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala
@@ -21,12 +21,13 @@ import java.util.{Collection, Collections, Date}
import scala.collection.JavaConverters._
-import org.apache.mesos.Protos.{TaskState => MesosTaskState, _}
+import org.apache.mesos.Protos.{Environment, Secret, TaskState => MesosTaskState, _}
import org.apache.mesos.Protos.Value.{Scalar, Type}
import org.apache.mesos.SchedulerDriver
+import org.apache.mesos.protobuf.ByteString
import org.mockito.{ArgumentCaptor, Matchers}
import org.mockito.Mockito._
-import org.scalatest.mock.MockitoSugar
+import org.scalatest.mockito.MockitoSugar
import org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.Command
@@ -222,7 +223,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi
assert(env.getOrElse("TEST_ENV", null) == "TEST_VAL")
}
- test("supports spark.mesos.network.name") {
+ test("supports spark.mesos.network.name and spark.mesos.network.labels") {
setScheduler()
val mem = 1000
@@ -233,7 +234,8 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi
command,
Map("spark.mesos.executor.home" -> "test",
"spark.app.name" -> "test",
- "spark.mesos.network.name" -> "test-network-name"),
+ "spark.mesos.network.name" -> "test-network-name",
+ "spark.mesos.network.labels" -> "key1:val1,key2:val2"),
"s1",
new Date()))
@@ -246,6 +248,10 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi
val networkInfos = launchedTasks.head.getContainer.getNetworkInfosList
assert(networkInfos.size == 1)
assert(networkInfos.get(0).getName == "test-network-name")
+ assert(networkInfos.get(0).getLabels.getLabels(0).getKey == "key1")
+ assert(networkInfos.get(0).getLabels.getLabels(0).getValue == "val1")
+ assert(networkInfos.get(0).getLabels.getLabels(1).getKey == "key2")
+ assert(networkInfos.get(0).getLabels.getLabels(1).getValue == "val2")
}
test("supports spark.mesos.driver.labels") {
@@ -333,4 +339,163 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi
verify(driver, times(1)).declineOffer(offerId, filter)
}
+
+ test("Creates an env-based reference secrets.") {
+ setScheduler()
+
+ val mem = 1000
+ val cpu = 1
+ val secretName = "/path/to/secret,/anothersecret"
+ val envKey = "SECRET_ENV_KEY,PASSWORD"
+ val driverDesc = new MesosDriverDescription(
+ "d1",
+ "jar",
+ mem,
+ cpu,
+ true,
+ command,
+ Map("spark.mesos.executor.home" -> "test",
+ "spark.app.name" -> "test",
+ "spark.mesos.driver.secret.names" -> secretName,
+ "spark.mesos.driver.secret.envkeys" -> envKey),
+ "s1",
+ new Date())
+ val response = scheduler.submitDriver(driverDesc)
+ assert(response.success)
+ val offer = Utils.createOffer("o1", "s1", mem, cpu)
+ scheduler.resourceOffers(driver, Collections.singletonList(offer))
+ val launchedTasks = Utils.verifyTaskLaunched(driver, "o1")
+ assert(launchedTasks.head
+ .getCommand
+ .getEnvironment
+      .getVariablesCount == 3) // SPARK_SUBMIT_OPS and the two secrets
+ val variableOne = launchedTasks.head.getCommand.getEnvironment
+ .getVariablesList.asScala.filter(_.getName == "SECRET_ENV_KEY").head
+ assert(variableOne.getSecret.isInitialized)
+ assert(variableOne.getSecret.getType == Secret.Type.REFERENCE)
+ assert(variableOne.getSecret.getReference.getName == "/path/to/secret")
+ assert(variableOne.getType == Environment.Variable.Type.SECRET)
+ val variableTwo = launchedTasks.head.getCommand.getEnvironment
+ .getVariablesList.asScala.filter(_.getName == "PASSWORD").head
+ assert(variableTwo.getSecret.isInitialized)
+ assert(variableTwo.getSecret.getType == Secret.Type.REFERENCE)
+ assert(variableTwo.getSecret.getReference.getName == "/anothersecret")
+ assert(variableTwo.getType == Environment.Variable.Type.SECRET)
+ }
+
+ test("Creates an env-based value secrets.") {
+ setScheduler()
+ val mem = 1000
+ val cpu = 1
+ val secretValues = "user,password"
+ val envKeys = "USER,PASSWORD"
+ val driverDesc = new MesosDriverDescription(
+ "d1",
+ "jar",
+ mem,
+ cpu,
+ true,
+ command,
+ Map("spark.mesos.executor.home" -> "test",
+ "spark.app.name" -> "test",
+ "spark.mesos.driver.secret.values" -> secretValues,
+ "spark.mesos.driver.secret.envkeys" -> envKeys),
+ "s1",
+ new Date())
+ val response = scheduler.submitDriver(driverDesc)
+ assert(response.success)
+ val offer = Utils.createOffer("o1", "s1", mem, cpu)
+ scheduler.resourceOffers(driver, Collections.singletonList(offer))
+ val launchedTasks = Utils.verifyTaskLaunched(driver, "o1")
+ assert(launchedTasks.head
+ .getCommand
+ .getEnvironment
+      .getVariablesCount == 3) // SPARK_SUBMIT_OPS and the two secrets
+ val variableOne = launchedTasks.head.getCommand.getEnvironment
+ .getVariablesList.asScala.filter(_.getName == "USER").head
+ assert(variableOne.getSecret.isInitialized)
+ assert(variableOne.getSecret.getType == Secret.Type.VALUE)
+ assert(variableOne.getSecret.getValue.getData == ByteString.copyFrom("user".getBytes))
+ assert(variableOne.getType == Environment.Variable.Type.SECRET)
+ val variableTwo = launchedTasks.head.getCommand.getEnvironment
+ .getVariablesList.asScala.filter(_.getName == "PASSWORD").head
+ assert(variableTwo.getSecret.isInitialized)
+ assert(variableTwo.getSecret.getType == Secret.Type.VALUE)
+ assert(variableTwo.getSecret.getValue.getData == ByteString.copyFrom("password".getBytes))
+ assert(variableTwo.getType == Environment.Variable.Type.SECRET)
+ }
+
+ test("Creates file-based reference secrets.") {
+ setScheduler()
+ val mem = 1000
+ val cpu = 1
+ val secretName = "/path/to/secret,/anothersecret"
+ val secretPath = "/topsecret,/mypassword"
+ val driverDesc = new MesosDriverDescription(
+ "d1",
+ "jar",
+ mem,
+ cpu,
+ true,
+ command,
+ Map("spark.mesos.executor.home" -> "test",
+ "spark.app.name" -> "test",
+ "spark.mesos.driver.secret.names" -> secretName,
+ "spark.mesos.driver.secret.filenames" -> secretPath),
+ "s1",
+ new Date())
+ val response = scheduler.submitDriver(driverDesc)
+ assert(response.success)
+ val offer = Utils.createOffer("o1", "s1", mem, cpu)
+ scheduler.resourceOffers(driver, Collections.singletonList(offer))
+ val launchedTasks = Utils.verifyTaskLaunched(driver, "o1")
+ val volumes = launchedTasks.head.getContainer.getVolumesList
+ assert(volumes.size() == 2)
+ val secretVolOne = volumes.get(0)
+ assert(secretVolOne.getContainerPath == "/topsecret")
+ assert(secretVolOne.getSource.getSecret.getType == Secret.Type.REFERENCE)
+ assert(secretVolOne.getSource.getSecret.getReference.getName == "/path/to/secret")
+ val secretVolTwo = volumes.get(1)
+ assert(secretVolTwo.getContainerPath == "/mypassword")
+ assert(secretVolTwo.getSource.getSecret.getType == Secret.Type.REFERENCE)
+ assert(secretVolTwo.getSource.getSecret.getReference.getName == "/anothersecret")
+ }
+
+ test("Creates a file-based value secrets.") {
+ setScheduler()
+ val mem = 1000
+ val cpu = 1
+ val secretValues = "user,password"
+ val secretPath = "/whoami,/mypassword"
+ val driverDesc = new MesosDriverDescription(
+ "d1",
+ "jar",
+ mem,
+ cpu,
+ true,
+ command,
+ Map("spark.mesos.executor.home" -> "test",
+ "spark.app.name" -> "test",
+ "spark.mesos.driver.secret.values" -> secretValues,
+ "spark.mesos.driver.secret.filenames" -> secretPath),
+ "s1",
+ new Date())
+ val response = scheduler.submitDriver(driverDesc)
+ assert(response.success)
+ val offer = Utils.createOffer("o1", "s1", mem, cpu)
+ scheduler.resourceOffers(driver, Collections.singletonList(offer))
+ val launchedTasks = Utils.verifyTaskLaunched(driver, "o1")
+ val volumes = launchedTasks.head.getContainer.getVolumesList
+ assert(volumes.size() == 2)
+ val secretVolOne = volumes.get(0)
+ assert(secretVolOne.getContainerPath == "/whoami")
+ assert(secretVolOne.getSource.getSecret.getType == Secret.Type.VALUE)
+ assert(secretVolOne.getSource.getSecret.getValue.getData ==
+ ByteString.copyFrom("user".getBytes))
+ val secretVolTwo = volumes.get(1)
+ assert(secretVolTwo.getContainerPath == "/mypassword")
+ assert(secretVolTwo.getSource.getSecret.getType == Secret.Type.VALUE)
+ assert(secretVolTwo.getSource.getSecret.getValue.getData ==
+ ByteString.copyFrom("password".getBytes))
+ }
}
diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala
index a8175e29bc9cf..f6bae01c3af59 100644
--- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala
+++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala
@@ -30,7 +30,7 @@ import org.mockito.Matchers._
import org.mockito.Mockito._
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.ScalaFutures
-import org.scalatest.mock.MockitoSugar
+import org.scalatest.mockito.MockitoSugar
import org.apache.spark.{LocalSparkContext, SecurityManager, SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.deploy.mesos.config._
@@ -568,9 +568,10 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite
assert(launchedTasks.head.getLabels.equals(taskLabels))
}
- test("mesos supports spark.mesos.network.name") {
+ test("mesos supports spark.mesos.network.name and spark.mesos.network.labels") {
setBackend(Map(
- "spark.mesos.network.name" -> "test-network-name"
+ "spark.mesos.network.name" -> "test-network-name",
+ "spark.mesos.network.labels" -> "key1:val1,key2:val2"
))
val (mem, cpu) = (backend.executorMemory(sc), 4)
@@ -582,6 +583,10 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite
val networkInfos = launchedTasks.head.getContainer.getNetworkInfosList
assert(networkInfos.size == 1)
assert(networkInfos.get(0).getName == "test-network-name")
+ assert(networkInfos.get(0).getLabels.getLabels(0).getKey == "key1")
+ assert(networkInfos.get(0).getLabels.getLabels(0).getValue == "val1")
+ assert(networkInfos.get(0).getLabels.getLabels(1).getKey == "key2")
+ assert(networkInfos.get(0).getLabels.getLabels(1).getValue == "val2")
}
test("supports spark.scheduler.minRegisteredResourcesRatio") {
diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala
index 4ee85b91830a9..2d2f90c63a309 100644
--- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala
+++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala
@@ -33,7 +33,7 @@ import org.apache.mesos.Protos.Value.Scalar
import org.mockito.{ArgumentCaptor, Matchers}
import org.mockito.Matchers._
import org.mockito.Mockito._
-import org.scalatest.mock.MockitoSugar
+import org.scalatest.mockito.MockitoSugar
import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.executor.MesosExecutorBackend
diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtilSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtilSuite.scala
index caf9d89fdd201..f49d7c29eda49 100644
--- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtilSuite.scala
+++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtilSuite.scala
@@ -17,9 +17,6 @@
package org.apache.spark.scheduler.cluster.mesos
-import org.scalatest._
-import org.scalatest.mock.MockitoSugar
-
import org.apache.spark.{SparkConf, SparkFunSuite}
class MesosSchedulerBackendUtilSuite extends SparkFunSuite {
diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala
index 5d4bf6d082c4c..7df738958f85c 100644
--- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala
+++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala
@@ -23,7 +23,7 @@ import scala.language.reflectiveCalls
import org.apache.mesos.Protos.{Resource, Value}
import org.mockito.Mockito._
import org.scalatest._
-import org.scalatest.mock.MockitoSugar
+import org.scalatest.mockito.MockitoSugar
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.internal.config._
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
index f73e7dc0bb567..7052fb347106b 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
@@ -551,8 +551,8 @@ private[yarn] class YarnAllocator(
updateInternalState()
}
} else {
- logInfo(("Skip launching executorRunnable as runnning Excecutors count: %d " +
- "reached target Executors count: %d.").format(
+ logInfo(("Skip launching executorRunnable as running executors count: %d " +
+ "reached target executors count: %d.").format(
numExecutorsRunning.get, targetNumExecutors))
}
}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
index 4fef4394bb3f0..3d9f99f57bed7 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
@@ -74,14 +74,6 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil {
jobCreds.mergeAll(UserGroupInformation.getCurrentUser().getCredentials())
}
- override def getCurrentUserCredentials(): Credentials = {
- UserGroupInformation.getCurrentUser().getCredentials()
- }
-
- override def addCurrentUserCredentials(creds: Credentials) {
- UserGroupInformation.getCurrentUser().addCredentials(creds)
- }
-
override def addSecretKeyToUserCredentials(key: String, secret: String) {
val creds = new Credentials()
creds.addSecretKey(new Text(key), secret.getBytes(UTF_8))
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
index b696e080ce62f..b091fec926c4c 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
@@ -31,7 +31,7 @@ import org.apache.hadoop.yarn.api.records.LocalResourceType
import org.apache.hadoop.yarn.api.records.LocalResourceVisibility
import org.apache.hadoop.yarn.util.ConverterUtils
import org.mockito.Mockito.when
-import org.scalatest.mock.MockitoSugar
+import org.scalatest.mockito.MockitoSugar
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.deploy.yarn.config._
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala
index 13472f2ece184..01db796096f26 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala
@@ -70,11 +70,18 @@ class YarnShuffleIntegrationSuite extends BaseYarnClusterSuite {
val finalState = runSpark(
false,
mainClassName(YarnExternalShuffleDriver.getClass),
- appArgs = Seq(result.getAbsolutePath(), registeredExecFile.getAbsolutePath),
+ appArgs = if (registeredExecFile != null) {
+ Seq(result.getAbsolutePath, registeredExecFile.getAbsolutePath)
+ } else {
+ Seq(result.getAbsolutePath)
+ },
extraConf = extraSparkConf()
)
checkResult(finalState, result)
- assert(YarnTestAccessor.getRegisteredExecutorFile(shuffleService).exists())
+
+ if (registeredExecFile != null) {
+ assert(YarnTestAccessor.getRegisteredExecutorFile(shuffleService).exists())
+ }
}
}
@@ -105,7 +112,7 @@ private object YarnExternalShuffleDriver extends Logging with Matchers {
val WAIT_TIMEOUT_MILLIS = 10000
def main(args: Array[String]): Unit = {
- if (args.length != 2) {
+ if (args.length > 2) {
// scalastyle:off println
System.err.println(
s"""
@@ -121,10 +128,16 @@ private object YarnExternalShuffleDriver extends Logging with Matchers {
.setAppName("External Shuffle Test"))
val conf = sc.getConf
val status = new File(args(0))
- val registeredExecFile = new File(args(1))
+ val registeredExecFile = if (args.length == 2) {
+ new File(args(1))
+ } else {
+ null
+ }
logInfo("shuffle service executor file = " + registeredExecFile)
var result = "failure"
- val execStateCopy = new File(registeredExecFile.getAbsolutePath + "_dup")
+ val execStateCopy = Option(registeredExecFile).map { file =>
+ new File(file.getAbsolutePath + "_dup")
+ }.orNull
try {
val data = sc.parallelize(0 until 100, 10).map { x => (x % 10) -> x }.reduceByKey{ _ + _ }.
collect().toSet
@@ -132,11 +145,15 @@ private object YarnExternalShuffleDriver extends Logging with Matchers {
data should be ((0 until 10).map{x => x -> (x * 10 + 450)}.toSet)
result = "success"
// only one process can open a leveldb file at a time, so we copy the files
- FileUtils.copyDirectory(registeredExecFile, execStateCopy)
- assert(!ShuffleTestAccessor.reloadRegisteredExecutors(execStateCopy).isEmpty)
+ if (registeredExecFile != null && execStateCopy != null) {
+ FileUtils.copyDirectory(registeredExecFile, execStateCopy)
+ assert(!ShuffleTestAccessor.reloadRegisteredExecutors(execStateCopy).isEmpty)
+ }
} finally {
sc.stop()
- FileUtils.deleteDirectory(execStateCopy)
+ if (execStateCopy != null) {
+ FileUtils.deleteDirectory(execStateCopy)
+ }
Files.write(result, status, StandardCharsets.UTF_8)
}
}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala
index a58784f59676a..268f4bd13f6c3 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala
@@ -44,6 +44,8 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd
private[yarn] var yarnConfig: YarnConfiguration = null
private[yarn] val SORT_MANAGER = "org.apache.spark.shuffle.sort.SortShuffleManager"
+ private var recoveryLocalDir: File = _
+
override def beforeEach(): Unit = {
super.beforeEach()
yarnConfig = new YarnConfiguration()
@@ -54,6 +56,8 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd
yarnConfig.setBoolean(YarnShuffleService.STOP_ON_FAILURE_KEY, true)
val localDir = Utils.createTempDir()
yarnConfig.set(YarnConfiguration.NM_LOCAL_DIRS, localDir.getAbsolutePath)
+
+ recoveryLocalDir = Utils.createTempDir()
}
var s1: YarnShuffleService = null
@@ -81,6 +85,7 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd
test("executor state kept across NM restart") {
s1 = new YarnShuffleService
+ s1.setRecoveryPath(new Path(recoveryLocalDir.toURI))
// set auth to true to test the secrets recovery
yarnConfig.setBoolean(SecurityManager.SPARK_AUTH_CONF, true)
s1.init(yarnConfig)
@@ -123,6 +128,7 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd
// now we pretend the shuffle service goes down, and comes back up
s1.stop()
s2 = new YarnShuffleService
+ s2.setRecoveryPath(new Path(recoveryLocalDir.toURI))
s2.init(yarnConfig)
s2.secretsFile should be (secretsFile)
s2.registeredExecutorFile should be (execStateFile)
@@ -140,6 +146,7 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd
// Act like the NM restarts one more time
s2.stop()
s3 = new YarnShuffleService
+ s3.setRecoveryPath(new Path(recoveryLocalDir.toURI))
s3.init(yarnConfig)
s3.registeredExecutorFile should be (execStateFile)
s3.secretsFile should be (secretsFile)
@@ -156,6 +163,7 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd
test("removed applications should not be in registered executor file") {
s1 = new YarnShuffleService
+ s1.setRecoveryPath(new Path(recoveryLocalDir.toURI))
yarnConfig.setBoolean(SecurityManager.SPARK_AUTH_CONF, false)
s1.init(yarnConfig)
val secretsFile = s1.secretsFile
@@ -190,6 +198,7 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd
test("shuffle service should be robust to corrupt registered executor file") {
s1 = new YarnShuffleService
+ s1.setRecoveryPath(new Path(recoveryLocalDir.toURI))
s1.init(yarnConfig)
val app1Id = ApplicationId.newInstance(0, 1)
val app1Data = makeAppInfo("user", app1Id)
@@ -215,6 +224,7 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd
out.close()
s2 = new YarnShuffleService
+ s2.setRecoveryPath(new Path(recoveryLocalDir.toURI))
s2.init(yarnConfig)
s2.registeredExecutorFile should be (execStateFile)
@@ -234,6 +244,7 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd
// another stop & restart should be fine though (eg., we recover from previous corruption)
s3 = new YarnShuffleService
+ s3.setRecoveryPath(new Path(recoveryLocalDir.toURI))
s3.init(yarnConfig)
s3.registeredExecutorFile should be (execStateFile)
val handler3 = s3.blockHandler
@@ -254,14 +265,6 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd
s1.init(yarnConfig)
s1._recoveryPath should be (recoveryPath)
s1.stop()
-
- // Test recovery path is set inside the shuffle service, this will be happened when NM
- // recovery is not enabled or there's no NM recovery (Hadoop 2.5-).
- s2 = new YarnShuffleService
- s2.init(yarnConfig)
- s2._recoveryPath should be
- (new Path(yarnConfig.getTrimmedStrings("yarn.nodemanager.local-dirs")(0)))
- s2.stop()
}
test("moving recovery file from NM local dir to recovery path") {
@@ -271,6 +274,7 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd
// Simulate s1 is running on old version of Hadoop in which recovery file is in the NM local
// dir.
s1 = new YarnShuffleService
+ s1.setRecoveryPath(new Path(yarnConfig.getTrimmedStrings(YarnConfiguration.NM_LOCAL_DIRS)(0)))
// set auth to true to test the secrets recovery
yarnConfig.setBoolean(SecurityManager.SPARK_AUTH_CONF, true)
s1.init(yarnConfig)
@@ -308,7 +312,7 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd
// Simulate s2 is running on Hadoop 2.5+ with NM recovery is enabled.
assert(execStateFile.exists())
- val recoveryPath = new Path(Utils.createTempDir().toURI)
+ val recoveryPath = new Path(recoveryLocalDir.toURI)
s2 = new YarnShuffleService
s2.setRecoveryPath(recoveryPath)
s2.init(yarnConfig)
@@ -347,10 +351,10 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd
// Set up a read-only local dir.
val roDir = Utils.createTempDir()
Files.setPosixFilePermissions(roDir.toPath(), EnumSet.of(OWNER_READ, OWNER_EXECUTE))
- yarnConfig.set(YarnConfiguration.NM_LOCAL_DIRS, roDir.getAbsolutePath())
// Try to start the shuffle service, it should fail.
val service = new YarnShuffleService()
+ service.setRecoveryPath(new Path(roDir.toURI))
try {
val error = intercept[ServiceStateException] {
@@ -369,4 +373,12 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd
new ApplicationInitializationContext(user, appId, secret)
}
+ test("recovery db should not be created if NM recovery is not enabled") {
+ s1 = new YarnShuffleService
+ s1.init(yarnConfig)
+ s1._recoveryPath should be (null)
+ s1.registeredExecutorFile should be (null)
+ s1.secretsFile should be (null)
+ }
+
}
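The suite changes above all reflect the new contract: the recovery path must be supplied via setRecoveryPath before init, exactly as the NM does when its recovery is enabled. A minimal sketch of that wiring, assuming a local temp dir in place of the NM-provided path:

  import java.nio.file.Files
  import org.apache.hadoop.fs.Path
  import org.apache.hadoop.yarn.conf.YarnConfiguration
  import org.apache.spark.network.yarn.YarnShuffleService

  val yarnConfig = new YarnConfiguration()
  val recoveryDir = Files.createTempDirectory("yarn-shuffle-recovery").toFile

  val service = new YarnShuffleService
  // Must be called before init(); the NM does this when NM recovery is enabled.
  service.setRecoveryPath(new Path(recoveryDir.toURI))
  service.init(yarnConfig)
  // Without setRecoveryPath, init() now creates no recovery DB at all
  // (_recoveryPath, registeredExecutorFile and secretsFile stay null).
  service.stop()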
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala
index 0a413b2c23de1..7fac57ff68abc 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler.cluster
import scala.language.reflectiveCalls
import org.mockito.Mockito.when
-import org.scalatest.mock.MockitoSugar
+import org.scalatest.mockito.MockitoSugar
import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite}
import org.apache.spark.scheduler.TaskSchedulerImpl
diff --git a/scalastyle-config.xml b/scalastyle-config.xml
index 0a4073b03957c..bd7f462b722cd 100644
--- a/scalastyle-config.xml
+++ b/scalastyle-config.xml
@@ -268,10 +268,7 @@ This file is divided into 3 sections:
-
- ^Override$
- override modifier should be used instead of @java.lang.Override.
-
+
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index fce81493795c8..1a75c7e504328 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -39,6 +39,10 @@
<dependency>
  <groupId>org.scala-lang</groupId>
  <artifactId>scala-reflect</artifactId>
</dependency>
+ <dependency>
+   <groupId>org.scala-lang.modules</groupId>
+   <artifactId>scala-parser-combinators_${scala.binary.version}</artifactId>
+ </dependency>
<dependency>
  <groupId>org.apache.spark</groupId>
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index c4ec4c31adb17..33bc79a92b9e7 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -64,7 +64,7 @@ singleDataType
: dataType EOF
;
-standaloneColTypeList
+singleTableSchema
: colTypeList EOF
;
@@ -81,6 +81,7 @@ statement
(PARTITIONED BY partitionColumnNames=identifierList)?
bucketSpec? locationSpec?
(COMMENT comment=STRING)?
+ (TBLPROPERTIES tableProps=tablePropertyList)?
(AS? query)? #createTable
| createTableHeader ('(' columns=colTypeList ')')?
(COMMENT comment=STRING)?
@@ -242,8 +243,10 @@ query
;
insertInto
- : INSERT OVERWRITE TABLE tableIdentifier (partitionSpec (IF NOT EXISTS)?)?
- | INSERT INTO TABLE? tableIdentifier partitionSpec?
+ : INSERT OVERWRITE TABLE tableIdentifier (partitionSpec (IF NOT EXISTS)?)? #insertOverwriteTable
+ | INSERT INTO TABLE? tableIdentifier partitionSpec? #insertIntoTable
+ | INSERT OVERWRITE LOCAL? DIRECTORY path=STRING rowFormat? createFileFormat? #insertOverwriteHiveDir
+ | INSERT OVERWRITE LOCAL? DIRECTORY (path=STRING)? tableProvider (OPTIONS options=tablePropertyList)? #insertOverwriteDir
;
partitionSpecLocation
@@ -267,7 +270,7 @@ describeFuncName
;
describeColName
- : identifier ('.' (identifier | STRING))*
+ : nameParts+=identifier ('.' nameParts+=identifier)*
;
ctes
@@ -744,6 +747,7 @@ nonReserved
| AND | CASE | CAST | DISTINCT | DIV | ELSE | END | FUNCTION | INTERVAL | MACRO | OR | STRATIFY | THEN
| UNBOUNDED | WHEN
| DATABASE | SELECT | FROM | WHERE | HAVING | TO | TABLE | WITH | NOT | CURRENT_DATE | CURRENT_TIMESTAMP
+ | DIRECTORY
;
SELECT: 'SELECT';
@@ -814,6 +818,7 @@ WITH: 'WITH';
VALUES: 'VALUES';
CREATE: 'CREATE';
TABLE: 'TABLE';
+DIRECTORY: 'DIRECTORY';
VIEW: 'VIEW';
REPLACE: 'REPLACE';
INSERT: 'INSERT';
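The grammar additions above introduce the DIRECTORY keyword, the two INSERT OVERWRITE ... DIRECTORY alternatives, and TBLPROPERTIES on the data-source CREATE TABLE rule. A hedged sketch of how the new syntax would be exercised (paths and table names are placeholders; the ROW FORMAT variant additionally requires Hive support):

  import org.apache.spark.sql.SparkSession

  val spark = SparkSession.builder().master("local[*]").appName("insert-dir-sketch").getOrCreate()
  spark.range(10).createOrReplaceTempView("t")

  // Write query results straight to a directory using a data-source format.
  spark.sql("INSERT OVERWRITE DIRECTORY '/tmp/out_parquet' USING parquet SELECT id FROM t")

  // TBLPROPERTIES is now accepted on the data-source CREATE TABLE ... AS SELECT path.
  spark.sql("CREATE TABLE t2 USING parquet TBLPROPERTIES ('owner' = 'etl') AS SELECT id FROM t")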
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 21363d3ba82c1..3ecc137c8cd7f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -118,6 +118,9 @@ object JavaTypeInference {
val (valueDataType, nullable) = inferDataType(valueType, seenTypeSet)
(MapType(keyDataType, valueDataType, nullable), true)
+ case other if other.isEnum =>
+ (StringType, true)
+
case other =>
if (seenTypeSet.contains(other)) {
throw new UnsupportedOperationException(
@@ -140,6 +143,7 @@ object JavaTypeInference {
def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
val beanInfo = Introspector.getBeanInfo(beanClass)
beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
+ .filterNot(_.getName == "declaringClass")
.filter(_.getReadMethod != null)
}
@@ -303,6 +307,14 @@ object JavaTypeInference {
keyData :: valueData :: Nil,
returnNullable = false)
+ case other if other.isEnum =>
+ StaticInvoke(
+ other,
+ ObjectType(other),
+ "valueOf",
+ Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false) :: Nil,
+ returnNullable = false)
+
case other =>
val properties = getJavaBeanReadableAndWritableProperties(other)
val setters = properties.map { p =>
@@ -429,6 +441,14 @@ object JavaTypeInference {
valueNullable = true
)
+ case other if other.isEnum =>
+ StaticInvoke(
+ classOf[UTF8String],
+ StringType,
+ "fromString",
+ Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false) :: Nil,
+ returnNullable = false)
+
case other =>
val properties = getJavaBeanReadableAndWritableProperties(other)
val nonNullOutput = CreateNamedStruct(properties.flatMap { p =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 70a3885d21531..db276fbc9d53a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -314,7 +314,7 @@ class Analyzer(
s"grouping columns (${groupByExprs.mkString(",")})")
}
case e @ Grouping(col: Expression) =>
- val idx = groupByExprs.indexOf(col)
+ val idx = groupByExprs.indexWhere(_.semanticEquals(col))
if (idx >= 0) {
Alias(Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)),
Literal(1)), ByteType), toPrettySQL(e))()
@@ -1286,8 +1286,10 @@ class Analyzer(
resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
case e @ Exists(sub, _, exprId) if !sub.resolved =>
resolveSubQuery(e, plans)(Exists(_, _, exprId))
- case In(value, Seq(l @ ListQuery(sub, _, exprId))) if value.resolved && !sub.resolved =>
- val expr = resolveSubQuery(l, plans)(ListQuery(_, _, exprId))
+ case In(value, Seq(l @ ListQuery(sub, _, exprId, _))) if value.resolved && !l.resolved =>
+ val expr = resolveSubQuery(l, plans)((plan, exprs) => {
+ ListQuery(plan, exprs, exprId, plan.output)
+ })
In(value, Seq(expr))
}
}
@@ -2254,7 +2256,10 @@ object CleanupAliases extends Rule[LogicalPlan] {
def trimNonTopLevelAliases(e: Expression): Expression = e match {
case a: Alias =>
- a.withNewChildren(trimAliases(a.child) :: Nil)
+ a.copy(child = trimAliases(a.child))(
+ exprId = a.exprId,
+ qualifier = a.qualifier,
+ explicitMetadata = Some(a.metadata))
case other => trimAliases(other)
}
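The switch from indexOf to indexWhere(_.semanticEquals(col)) matters because strict equality misses expressions that differ only cosmetically. A small illustration against the Catalyst API (internal classes, shown purely as a sketch):

  import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
  import org.apache.spark.sql.types.IntegerType

  val a = AttributeReference("a", IntegerType)()
  val renamed = a.withName("A")   // same exprId, cosmetically different name

  Seq[Expression](renamed).indexOf(a)                      // -1: strict equality misses it
  Seq[Expression](renamed).indexWhere(_.semanticEquals(a)) // 0: canonicalized forms match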
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 06d8350db9891..9ffe646b5e4ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -402,7 +402,7 @@ object TypeCoercion {
// Handle type casting required between value expression and subquery output
// in IN subquery.
- case i @ In(a, Seq(ListQuery(sub, children, exprId)))
+ case i @ In(a, Seq(ListQuery(sub, children, exprId, _)))
if !i.resolved && flattenExpr(a).length == sub.output.length =>
// LHS is the value expression of IN subquery.
val lhs = flattenExpr(a)
@@ -434,7 +434,8 @@ object TypeCoercion {
case _ => CreateStruct(castedLhs)
}
- In(newLhs, Seq(ListQuery(Project(castedRhs, sub), children, exprId)))
+ val newSub = Project(castedRhs, sub)
+ In(newLhs, Seq(ListQuery(newSub, children, exprId, newSub.output)))
} else {
i
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index 6ab4153bac70e..33ba0867a33e0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -146,6 +146,9 @@ object UnsupportedOperationChecker {
throwError("Commands like CreateTable*, AlterTable*, Show* are not supported with " +
"streaming DataFrames/Datasets")
+ case _: InsertIntoDir =>
+ throwError("InsertIntoDir is not supported with streaming DataFrames/Datasets")
+
// mapGroupsWithState and flatMapGroupsWithState
case m: FlatMapGroupsWithState if m.isStreaming =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index 6030d90ed99c3..0908d68d25649 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.catalog
+import java.lang.reflect.InvocationTargetException
import java.net.URI
import java.util.Locale
import java.util.concurrent.Callable
@@ -24,6 +25,7 @@ import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable
import scala.util.{Failure, Success, Try}
+import scala.util.control.NonFatal
import com.google.common.cache.{Cache, CacheBuilder}
import org.apache.hadoop.conf.Configuration
@@ -39,7 +41,9 @@ import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, View}
import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
object SessionCatalog {
val DEFAULT_DATABASE = "default"
@@ -1075,13 +1079,33 @@ class SessionCatalog(
// ----------------------------------------------------------------
/**
- * Construct a [[FunctionBuilder]] based on the provided class that represents a function.
+ * Constructs a [[FunctionBuilder]] based on the provided class that represents a function.
+ */
+ private def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = {
+ val clazz = Utils.classForName(functionClassName)
+ (input: Seq[Expression]) => makeFunctionExpression(name, clazz, input)
+ }
+
+ /**
+ * Constructs an [[Expression]] based on the provided class that represents a function.
*
* This performs reflection to decide what type of [[Expression]] to return in the builder.
*/
- protected def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = {
- // TODO: at least support UDAFs here
- throw new UnsupportedOperationException("Use sqlContext.udf.register(...) instead.")
+ protected def makeFunctionExpression(
+ name: String,
+ clazz: Class[_],
+ input: Seq[Expression]): Expression = {
+ val clsForUDAF =
+ Utils.classForName("org.apache.spark.sql.expressions.UserDefinedAggregateFunction")
+ if (clsForUDAF.isAssignableFrom(clazz)) {
+ val cls = Utils.classForName("org.apache.spark.sql.execution.aggregate.ScalaUDAF")
+ cls.getConstructor(classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int])
+ .newInstance(input, clazz.newInstance().asInstanceOf[Object], Int.box(1), Int.box(1))
+ .asInstanceOf[Expression]
+ } else {
+ throw new AnalysisException(s"No handler for UDAF '${clazz.getCanonicalName}'. " +
+ s"Use sparkSession.udf.register(...) instead.")
+ }
}
/**
@@ -1105,7 +1129,14 @@ class SessionCatalog(
}
val info = new ExpressionInfo(funcDefinition.className, func.database.orNull, func.funcName)
val builder =
- functionBuilder.getOrElse(makeFunctionBuilder(func.unquotedString, funcDefinition.className))
+ functionBuilder.getOrElse {
+ val className = funcDefinition.className
+ if (!Utils.classIsLoadable(className)) {
+ throw new AnalysisException(s"Can not load class '$className' when registering " +
+ s"the function '$func', please make sure it is on the classpath")
+ }
+ makeFunctionBuilder(func.unquotedString, className)
+ }
functionRegistry.registerFunction(func, info, builder)
}
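With makeFunctionExpression resolving UDAF classes reflectively and the classIsLoadable check giving a clear error, registering a persistent Scala UDAF by class name is expected to work end to end. A hedged sketch; com.example.MyAverage is a hypothetical UserDefinedAggregateFunction implementation assumed to be on the classpath:

  import org.apache.spark.sql.SparkSession

  val spark = SparkSession.builder().master("local[*]").getOrCreate()
  spark.range(10).createOrReplaceTempView("t")

  // Hypothetical UDAF class; if it is not loadable, the new check fails fast at registration time.
  spark.sql("CREATE FUNCTION my_average AS 'com.example.MyAverage'")
  spark.sql("SELECT my_average(id) FROM t").show()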
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
index 5a8c4e7610fff..1965144e81197 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -91,12 +91,14 @@ object CatalogStorageFormat {
*
* @param spec partition spec values indexed by column name
* @param storage storage format of the partition
- * @param parameters some parameters for the partition, for example, stats.
+ * @param parameters some parameters for the partition
+ * @param stats optional statistics (number of rows, total size, etc.)
*/
case class CatalogTablePartition(
spec: CatalogTypes.TablePartitionSpec,
storage: CatalogStorageFormat,
- parameters: Map[String, String] = Map.empty) {
+ parameters: Map[String, String] = Map.empty,
+ stats: Option[CatalogStatistics] = None) {
def toLinkedHashMap: mutable.LinkedHashMap[String, String] = {
val map = new mutable.LinkedHashMap[String, String]()
@@ -106,6 +108,7 @@ case class CatalogTablePartition(
if (parameters.nonEmpty) {
map.put("Partition Parameters", s"{${parameters.map(p => p._1 + "=" + p._2).mkString(", ")}}")
}
+ stats.foreach(s => map.put("Partition Statistics", s.simpleString))
map
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
index b77f93373e78d..7420b6b57d8e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
@@ -121,7 +121,12 @@ class AttributeSet private (val baseSet: Set[AttributeEquals])
// We must force toSeq to not be strict otherwise we end up with a [[Stream]] that captures all
// sorts of things in its closure.
- override def toSeq: Seq[Attribute] = baseSet.map(_.a).toArray.toSeq
+ override def toSeq: Seq[Attribute] = {
+ // We need to keep a deterministic output order for `baseSet` because this affects a variable
+ // order in generated code (e.g., `GenerateColumnAccessor`).
+ // See SPARK-18394 for details.
+ baseSet.map(_.a).toSeq.sortBy { a => (a.name, a.exprId.id) }
+ }
override def toString: String = "{" + baseSet.map(_.a).mkString(", ") + "}"
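A quick sketch of the deterministic ordering introduced above (Catalyst internals, for illustration only):

  import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet}
  import org.apache.spark.sql.types.IntegerType

  val b = AttributeReference("b", IntegerType)()
  val a = AttributeReference("a", IntegerType)()

  // Previously dependent on hash-set iteration order; now always sorted by (name, exprId).
  AttributeSet(Seq(b, a)).toSeq.map(_.name)   // Seq("a", "b") on every run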
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 74c4cddf2b47e..c058425b4bc36 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -635,3 +635,9 @@ abstract class TernaryExpression extends Expression {
}
}
}
+
+/**
+ * Common base trait for user-defined functions, including UDF/UDAF/UDTF of different languages
+ * and Hive function wrappers.
+ */
+trait UserDefinedExpression
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala
index ede0b1654bbd6..305ac90e245b8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
-import scala.collection.mutable
+import scala.collection.{mutable, GenTraversableOnce}
import scala.collection.mutable.ArrayBuffer
object ExpressionSet {
@@ -67,6 +67,12 @@ class ExpressionSet protected(
newSet
}
+ override def ++(elems: GenTraversableOnce[Expression]): ExpressionSet = {
+ val newSet = new ExpressionSet(baseSet.clone(), originals.clone())
+ elems.foreach(newSet.add)
+ newSet
+ }
+
override def -(elem: Expression): ExpressionSet = {
val newBaseSet = baseSet.clone().filterNot(_ == elem.canonicalized)
val newOriginals = originals.clone().filterNot(_.canonicalized == elem.canonicalized)
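The ++ override keeps bulk additions going through add, so the set still deduplicates by canonicalized form instead of falling back to the generic Set implementation. A sketch using the Catalyst DSL:

  import org.apache.spark.sql.catalyst.dsl.expressions._
  import org.apache.spark.sql.catalyst.expressions.{ExpressionSet, Literal}

  val a = 'a.int
  val set = ExpressionSet(Seq(a + 1))

  // "1 + a" is recognized as canonically equal to "a + 1", so only "a + 2" is new.
  (set ++ Seq(Literal(1) + a, a + 2)).size   // 2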
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 9df0e2e1415c0..527f1670c25e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -47,7 +47,7 @@ case class ScalaUDF(
udfName: Option[String] = None,
nullable: Boolean = true,
udfDeterministic: Boolean = true)
- extends Expression with ImplicitCastInputTypes with NonSQLExpression {
+ extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression {
override def deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 807765c1e00a1..437397187356c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -464,14 +464,13 @@ class CodegenContext {
/**
* Returns the specialized code to set a given value in a column vector for a given `DataType`.
*/
- def setValue(batch: String, row: String, dataType: DataType, ordinal: Int,
- value: String): String = {
+ def setValue(vector: String, rowId: String, dataType: DataType, value: String): String = {
val jt = javaType(dataType)
dataType match {
case _ if isPrimitiveType(jt) =>
- s"$batch.column($ordinal).put${primitiveTypeName(jt)}($row, $value);"
- case t: DecimalType => s"$batch.column($ordinal).putDecimal($row, $value, ${t.precision});"
- case t: StringType => s"$batch.column($ordinal).putByteArray($row, $value.getBytes());"
+ s"$vector.put${primitiveTypeName(jt)}($rowId, $value);"
+ case t: DecimalType => s"$vector.putDecimal($rowId, $value, ${t.precision});"
+ case t: StringType => s"$vector.putByteArray($rowId, $value.getBytes());"
case _ =>
throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType")
}
@@ -482,37 +481,36 @@ class CodegenContext {
* that could potentially be nullable.
*/
def updateColumn(
- batch: String,
- row: String,
+ vector: String,
+ rowId: String,
dataType: DataType,
- ordinal: Int,
ev: ExprCode,
nullable: Boolean): String = {
if (nullable) {
s"""
if (!${ev.isNull}) {
- ${setValue(batch, row, dataType, ordinal, ev.value)}
+ ${setValue(vector, rowId, dataType, ev.value)}
} else {
- $batch.column($ordinal).putNull($row);
+ $vector.putNull($rowId);
}
"""
} else {
- s"""${setValue(batch, row, dataType, ordinal, ev.value)};"""
+ s"""${setValue(vector, rowId, dataType, ev.value)};"""
}
}
/**
* Returns the specialized code to access a value from a column vector for a given `DataType`.
*/
- def getValue(batch: String, row: String, dataType: DataType, ordinal: Int): String = {
+ def getValue(vector: String, rowId: String, dataType: DataType): String = {
val jt = javaType(dataType)
dataType match {
case _ if isPrimitiveType(jt) =>
- s"$batch.column($ordinal).get${primitiveTypeName(jt)}($row)"
+ s"$vector.get${primitiveTypeName(jt)}($rowId)"
case t: DecimalType =>
- s"$batch.column($ordinal).getDecimal($row, ${t.precision}, ${t.scale})"
+ s"$vector.getDecimal($rowId, ${t.precision}, ${t.scale})"
case StringType =>
- s"$batch.column($ordinal).getUTF8String($row)"
+ s"$vector.getUTF8String($rowId)"
case _ =>
throw new IllegalArgumentException(s"cannot generate code for unsupported type: $dataType")
}
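The refactored helpers take the column-vector variable name directly instead of a (batch, ordinal) pair, so callers can target any column-vector variable in the generated code. A small sketch of the strings they now emit:

  import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
  import org.apache.spark.sql.types.IntegerType

  val ctx = new CodegenContext
  ctx.getValue("colVector", "rowId", IntegerType)        // "colVector.getInt(rowId)"
  ctx.setValue("colVector", "rowId", IntegerType, "v")   // "colVector.putInt(rowId, v);"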
@@ -596,6 +594,7 @@ class CodegenContext {
case array: ArrayType => genComp(array, c1, c2) + " == 0"
case struct: StructType => genComp(struct, c1, c2) + " == 0"
case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
+ case NullType => "false"
case _ =>
throw new IllegalArgumentException(
"cannot generate equality code for un-comparable type: " + dataType.simpleString)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index 17b605438d587..18b4fed597447 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.json._
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, FailFastMode, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, FailFastMode, GenericArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
@@ -362,9 +362,9 @@ case class JsonTuple(children: Seq[Expression])
@transient private lazy val fieldExpressions: Seq[Expression] = children.tail
// eagerly evaluate any foldable field names
- @transient private lazy val foldableFieldNames: IndexedSeq[String] = {
+ @transient private lazy val foldableFieldNames: IndexedSeq[Option[String]] = {
fieldExpressions.map {
- case expr if expr.foldable => expr.eval().asInstanceOf[UTF8String].toString
+ case expr if expr.foldable => Option(expr.eval()).map(_.asInstanceOf[UTF8String].toString)
case _ => null
}.toIndexedSeq
}
@@ -417,7 +417,7 @@ case class JsonTuple(children: Seq[Expression])
val fieldNames = if (constantFields == fieldExpressions.length) {
// typically the user will provide the field names as foldable expressions
// so we can use the cached copy
- foldableFieldNames
+ foldableFieldNames.map(_.orNull)
} else if (constantFields == 0) {
// none are foldable so all field names need to be evaluated from the input row
fieldExpressions.map(_.eval(input).asInstanceOf[UTF8String].toString)
@@ -426,7 +426,7 @@ case class JsonTuple(children: Seq[Expression])
// prefer the cached copy when available
foldableFieldNames.zip(fieldExpressions).map {
case (null, expr) => expr.eval(input).asInstanceOf[UTF8String].toString
- case (fieldName, _) => fieldName
+ case (fieldName, _) => fieldName.orNull
}
}
@@ -436,7 +436,8 @@ case class JsonTuple(children: Seq[Expression])
while (parser.nextToken() != JsonToken.END_OBJECT) {
if (parser.getCurrentToken == JsonToken.FIELD_NAME) {
// check to see if this field is desired in the output
- val idx = fieldNames.indexOf(parser.getCurrentName)
+ val jsonField = parser.getCurrentName
+ var idx = fieldNames.indexOf(jsonField)
if (idx >= 0) {
// it is, copy the child tree to the correct location in the output row
val output = new ByteArrayOutputStream()
@@ -447,7 +448,14 @@ case class JsonTuple(children: Seq[Expression])
generator => copyCurrentStructure(generator, parser)
}
- row(idx) = UTF8String.fromBytes(output.toByteArray)
+ val jsonValue = UTF8String.fromBytes(output.toByteArray)
+
+ // SPARK-21804: json_tuple returns null values within repeated columns
+ // except the first one, so we need to check the remaining fields.
+ do {
+ row(idx) = jsonValue
+ idx = fieldNames.indexOf(jsonField, idx + 1)
+ } while (idx >= 0)
}
}
}
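A sketch of the behaviour SPARK-21804 fixes: when a field name is requested more than once, every occurrence is now populated rather than only the first:

  import org.apache.spark.sql.SparkSession

  val spark = SparkSession.builder().master("local[*]").getOrCreate()

  spark.sql("""SELECT json_tuple('{"a":1, "b":2}', 'a', 'b', 'a')""").show()
  // before the fix: 1, 2, null      after the fix: 1, 2, 1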
@@ -596,7 +604,8 @@ case class JsonToStructs(
}
/**
- * Converts a [[StructType]] or [[ArrayType]] of [[StructType]]s to a json output string.
+ * Converts a [[StructType]], [[ArrayType]] of [[StructType]]s, [[MapType]]
+ * or [[ArrayType]] of [[MapType]]s to a json output string.
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
@@ -609,6 +618,14 @@ case class JsonToStructs(
{"time":"26/08/2015"}
> SELECT _FUNC_(array(named_struct('a', 1, 'b', 2)));
[{"a":1,"b":2}]
+ > SELECT _FUNC_(map('a', named_struct('b', 1)));
+ {"a":{"b":1}}
+ > SELECT _FUNC_(map(named_struct('a', 1),named_struct('b', 2)));
+ {"[1]":{"b":2}}
+ > SELECT _FUNC_(map('a', 1));
+ {"a":1}
+ > SELECT _FUNC_(array((map('a', 1))));
+ [{"a":1}]
""",
since = "2.2.0")
// scalastyle:on line.size.limit
@@ -640,6 +657,8 @@ case class StructsToJson(
lazy val rowSchema = child.dataType match {
case st: StructType => st
case ArrayType(st: StructType, _) => st
+ case mt: MapType => mt
+ case ArrayType(mt: MapType, _) => mt
}
// This converts rows to the JSON output according to the given schema.
@@ -661,6 +680,14 @@ case class StructsToJson(
(arr: Any) =>
gen.write(arr.asInstanceOf[ArrayData])
getAndReset()
+ case _: MapType =>
+ (map: Any) =>
+ gen.write(map.asInstanceOf[MapData])
+ getAndReset()
+ case ArrayType(_: MapType, _) =>
+ (arr: Any) =>
+ gen.write(arr.asInstanceOf[ArrayData])
+ getAndReset()
}
}
@@ -669,14 +696,25 @@ case class StructsToJson(
override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
case _: StructType | ArrayType(_: StructType, _) =>
try {
- JacksonUtils.verifySchema(rowSchema)
+ JacksonUtils.verifySchema(rowSchema.asInstanceOf[StructType])
+ TypeCheckResult.TypeCheckSuccess
+ } catch {
+ case e: UnsupportedOperationException =>
+ TypeCheckResult.TypeCheckFailure(e.getMessage)
+ }
+ case _: MapType | ArrayType(_: MapType, _) =>
+ // TODO: let `JacksonUtils.verifySchema` verify a `MapType`
+ try {
+ val st = StructType(StructField("a", rowSchema.asInstanceOf[MapType]) :: Nil)
+ JacksonUtils.verifySchema(st)
TypeCheckResult.TypeCheckSuccess
} catch {
case e: UnsupportedOperationException =>
TypeCheckResult.TypeCheckFailure(e.getMessage)
}
case _ => TypeCheckResult.TypeCheckFailure(
- s"Input type ${child.dataType.simpleString} must be a struct or array of structs.")
+ s"Input type ${child.dataType.simpleString} must be a struct, an array of structs, " +
+ "a map, or an array of maps.")
}
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
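With the extra MapType cases above, to_json accepts maps and arrays of maps as well as structs. A brief sketch:

  import org.apache.spark.sql.SparkSession
  import org.apache.spark.sql.functions.{lit, map, to_json}

  val spark = SparkSession.builder().master("local[*]").getOrCreate()

  spark.range(1).select(to_json(map(lit("a"), lit(1)))).show()   // {"a":1}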
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 7bf10f199f1c7..efcd45fad779c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -133,37 +133,55 @@ case class Not(child: Expression)
/**
* Evaluates to `true` if `list` contains `value`.
*/
+// scalastyle:off line.size.limit
@ExpressionDescription(
- usage = "expr1 _FUNC_(expr2, expr3, ...) - Returns true if `expr` equals to any valN.")
+ usage = "expr1 _FUNC_(expr2, expr3, ...) - Returns true if `expr1` equals any valN.",
+ arguments = """
+ Arguments:
+ * expr1, expr2, expr3, ... - the arguments must be of the same type.
+ """,
+ examples = """
+ Examples:
+ > SELECT 1 _FUNC_(1, 2, 3);
+ true
+ > SELECT 1 _FUNC_(2, 3, 4);
+ false
+ > SELECT named_struct('a', 1, 'b', 2) _FUNC_(named_struct('a', 1, 'b', 1), named_struct('a', 1, 'b', 3));
+ false
+ > SELECT named_struct('a', 1, 'b', 2) _FUNC_(named_struct('a', 1, 'b', 2), named_struct('a', 1, 'b', 3));
+ true
+ """)
+// scalastyle:on line.size.limit
case class In(value: Expression, list: Seq[Expression]) extends Predicate {
require(list != null, "list should not be null")
+
override def checkInputDataTypes(): TypeCheckResult = {
- list match {
- case ListQuery(sub, _, _) :: Nil =>
- val valExprs = value match {
- case cns: CreateNamedStruct => cns.valExprs
- case expr => Seq(expr)
- }
- if (valExprs.length != sub.output.length) {
- TypeCheckResult.TypeCheckFailure(
- s"""
- |The number of columns in the left hand side of an IN subquery does not match the
- |number of columns in the output of subquery.
- |#columns in left hand side: ${valExprs.length}.
- |#columns in right hand side: ${sub.output.length}.
- |Left side columns:
- |[${valExprs.map(_.sql).mkString(", ")}].
- |Right side columns:
- |[${sub.output.map(_.sql).mkString(", ")}].
- """.stripMargin)
- } else {
- val mismatchedColumns = valExprs.zip(sub.output).flatMap {
- case (l, r) if l.dataType != r.dataType =>
- s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})"
- case _ => None
+ val mismatchOpt = list.find(l => !DataType.equalsStructurally(l.dataType, value.dataType))
+ if (mismatchOpt.isDefined) {
+ list match {
+ case ListQuery(_, _, _, childOutputs) :: Nil =>
+ val valExprs = value match {
+ case cns: CreateNamedStruct => cns.valExprs
+ case expr => Seq(expr)
}
- if (mismatchedColumns.nonEmpty) {
+ if (valExprs.length != childOutputs.length) {
+ TypeCheckResult.TypeCheckFailure(
+ s"""
+ |The number of columns in the left hand side of an IN subquery does not match the
+ |number of columns in the output of subquery.
+ |#columns in left hand side: ${valExprs.length}.
+ |#columns in right hand side: ${childOutputs.length}.
+ |Left side columns:
+ |[${valExprs.map(_.sql).mkString(", ")}].
+ |Right side columns:
+ |[${childOutputs.map(_.sql).mkString(", ")}].""".stripMargin)
+ } else {
+ val mismatchedColumns = valExprs.zip(childOutputs).flatMap {
+ case (l, r) if l.dataType != r.dataType =>
+ s"(${l.sql}:${l.dataType.catalogString}, ${r.sql}:${r.dataType.catalogString})"
+ case _ => None
+ }
TypeCheckResult.TypeCheckFailure(
s"""
|The data type of one or more elements in the left hand side of an IN subquery
@@ -173,20 +191,14 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
|Left side:
|[${valExprs.map(_.dataType.catalogString).mkString(", ")}].
|Right side:
- |[${sub.output.map(_.dataType.catalogString).mkString(", ")}].
- """.stripMargin)
- } else {
- TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
+ |[${childOutputs.map(_.dataType.catalogString).mkString(", ")}].""".stripMargin)
}
- }
- case _ =>
- val mismatchOpt = list.find(l => l.dataType != value.dataType)
- if (mismatchOpt.isDefined) {
+ case _ =>
TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " +
s"${value.dataType} != ${mismatchOpt.get.dataType}")
- } else {
- TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
- }
+ }
+ } else {
+ TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
}
}
@@ -453,6 +465,16 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
abstract class BinaryComparison extends BinaryOperator with Predicate {
+ // Note that we need to give a superset of allowable input types since orderable types are not
+ // finitely enumerable. The allowable types are checked below by checkInputDataTypes.
+ override def inputType: AbstractDataType = AnyDataType
+
+ override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
+ case TypeCheckResult.TypeCheckSuccess =>
+ TypeUtils.checkForOrderingExpr(left.dataType, this.getClass.getSimpleName)
+ case failure => failure
+ }
+
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
if (ctx.isPrimitiveType(left.dataType)
&& left.dataType != BooleanType // java boolean doesn't support > or < operator
@@ -465,7 +487,7 @@ abstract class BinaryComparison extends BinaryOperator with Predicate {
}
}
- protected lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)
+ protected lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(left.dataType)
}
@@ -483,28 +505,30 @@ object Equality {
}
}
+// TODO: although map type is not orderable, technically map type should be able to be used
+// in equality comparison
@ExpressionDescription(
- usage = "expr1 _FUNC_ expr2 - Returns true if `expr1` equals `expr2`, or false otherwise.")
+ usage = "expr1 _FUNC_ expr2 - Returns true if `expr1` equals `expr2`, or false otherwise.",
+ arguments = """
+ Arguments:
+ * expr1, expr2 - the two expressions must be of the same type or can be cast to a common type,
+ and must be a type that can be used in equality comparison. Map type is not supported.
+ For complex types such as array/struct, the data types of fields must be orderable.
+ """,
+ examples = """
+ Examples:
+ > SELECT 2 _FUNC_ 2;
+ true
+ > SELECT 1 _FUNC_ '1';
+ true
+ > SELECT true _FUNC_ NULL;
+ NULL
+ > SELECT NULL _FUNC_ NULL;
+ NULL
+ """)
case class EqualTo(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {
- override def inputType: AbstractDataType = AnyDataType
-
- override def checkInputDataTypes(): TypeCheckResult = {
- super.checkInputDataTypes() match {
- case TypeCheckResult.TypeCheckSuccess =>
- // TODO: although map type is not orderable, technically map type should be able to be used
- // in equality comparison, remove this type check once we support it.
- if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) {
- TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualTo, but the actual " +
- s"input type is ${left.dataType.catalogString}.")
- } else {
- TypeCheckResult.TypeCheckSuccess
- }
- case failure => failure
- }
- }
-
override def symbol: String = "="
protected override def nullSafeEval(left: Any, right: Any): Any = ordering.equiv(left, right)
@@ -514,30 +538,32 @@ case class EqualTo(left: Expression, right: Expression)
}
}
+// TODO: although map type is not orderable, technically map type should be able to be used
+// in equality comparison
@ExpressionDescription(
usage = """
expr1 _FUNC_ expr2 - Returns the same result as the EQUAL(=) operator for non-null operands,
but returns true if both are null, false if one of them is null.
+ """,
+ arguments = """
+ Arguments:
+ * expr1, expr2 - the two expressions must be of the same type or can be cast to a common type,
+ and must be a type that can be used in equality comparison. Map type is not supported.
+ For complex types such as array/struct, the data types of fields must be orderable.
+ """,
+ examples = """
+ Examples:
+ > SELECT 2 _FUNC_ 2;
+ true
+ > SELECT 1 _FUNC_ '1';
+ true
+ > SELECT true _FUNC_ NULL;
+ false
+ > SELECT NULL _FUNC_ NULL;
+ true
""")
case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison {
- override def inputType: AbstractDataType = AnyDataType
-
- override def checkInputDataTypes(): TypeCheckResult = {
- super.checkInputDataTypes() match {
- case TypeCheckResult.TypeCheckSuccess =>
- // TODO: although map type is not orderable, technically map type should be able to be used
- // in equality comparison, remove this type check once we support it.
- if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) {
- TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualNullSafe, but the actual " +
- s"input type is ${left.dataType.catalogString}.")
- } else {
- TypeCheckResult.TypeCheckSuccess
- }
- case failure => failure
- }
- }
-
override def symbol: String = "<=>"
override def nullable: Boolean = false
@@ -565,48 +591,120 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
}
@ExpressionDescription(
- usage = "expr1 _FUNC_ expr2 - Returns true if `expr1` is less than `expr2`.")
+ usage = "expr1 _FUNC_ expr2 - Returns true if `expr1` is less than `expr2`.",
+ arguments = """
+ Arguments:
+ * expr1, expr2 - the two expressions must be of the same type or can be cast to a common type,
+ and must be a type that can be ordered. For example, map type is not orderable, so it
+ is not supported. For complex types such as array/struct, the data types of fields must
+ be orderable.
+ """,
+ examples = """
+ Examples:
+ > SELECT 1 _FUNC_ 2;
+ true
+ > SELECT 1.1 _FUNC_ '1';
+ false
+ > SELECT to_date('2009-07-30 04:17:52') _FUNC_ to_date('2009-07-30 04:17:52');
+ false
+ > SELECT to_date('2009-07-30 04:17:52') _FUNC_ to_date('2009-08-01 04:17:52');
+ true
+ > SELECT 1 _FUNC_ NULL;
+ NULL
+ """)
case class LessThan(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {
- override def inputType: AbstractDataType = TypeCollection.Ordered
-
override def symbol: String = "<"
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2)
}
@ExpressionDescription(
- usage = "expr1 _FUNC_ expr2 - Returns true if `expr1` is less than or equal to `expr2`.")
+ usage = "expr1 _FUNC_ expr2 - Returns true if `expr1` is less than or equal to `expr2`.",
+ arguments = """
+ Arguments:
+ * expr1, expr2 - the two expressions must be of the same type or can be cast to a common type,
+ and must be a type that can be ordered. For example, map type is not orderable, so it
+ is not supported. For complex types such as array/struct, the data types of fields must
+ be orderable.
+ """,
+ examples = """
+ Examples:
+ > SELECT 2 _FUNC_ 2;
+ true
+ > SELECT 1.0 _FUNC_ '1';
+ true
+ > SELECT to_date('2009-07-30 04:17:52') _FUNC_ to_date('2009-07-30 04:17:52');
+ true
+ > SELECT to_date('2009-07-30 04:17:52') _FUNC_ to_date('2009-08-01 04:17:52');
+ true
+ > SELECT 1 _FUNC_ NULL;
+ NULL
+ """)
case class LessThanOrEqual(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {
- override def inputType: AbstractDataType = TypeCollection.Ordered
-
override def symbol: String = "<="
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2)
}
@ExpressionDescription(
- usage = "expr1 _FUNC_ expr2 - Returns true if `expr1` is greater than `expr2`.")
+ usage = "expr1 _FUNC_ expr2 - Returns true if `expr1` is greater than `expr2`.",
+ arguments = """
+ Arguments:
+ * expr1, expr2 - the two expressions must be of the same type or can be cast to a common type,
+ and must be a type that can be ordered. For example, map type is not orderable, so it
+ is not supported. For complex types such as array/struct, the data types of fields must
+ be orderable.
+ """,
+ examples = """
+ Examples:
+ > SELECT 2 _FUNC_ 1;
+ true
+ > SELECT 2 _FUNC_ '1.1';
+ true
+ > SELECT to_date('2009-07-30 04:17:52') _FUNC_ to_date('2009-07-30 04:17:52');
+ false
+ > SELECT to_date('2009-07-30 04:17:52') _FUNC_ to_date('2009-08-01 04:17:52');
+ false
+ > SELECT 1 _FUNC_ NULL;
+ NULL
+ """)
case class GreaterThan(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {
- override def inputType: AbstractDataType = TypeCollection.Ordered
-
override def symbol: String = ">"
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2)
}
@ExpressionDescription(
- usage = "expr1 _FUNC_ expr2 - Returns true if `expr1` is greater than or equal to `expr2`.")
+ usage = "expr1 _FUNC_ expr2 - Returns true if `expr1` is greater than or equal to `expr2`.",
+ arguments = """
+ Arguments:
+ * expr1, expr2 - the two expressions must be of the same type or can be cast to a common type,
+ and must be a type that can be ordered. For example, map type is not orderable, so it
+ is not supported. For complex types such as array/struct, the data types of fields must
+ be orderable.
+ """,
+ examples = """
+ Examples:
+ > SELECT 2 _FUNC_ 1;
+ true
+ > SELECT 2.0 _FUNC_ '2.1';
+ false
+ > SELECT to_date('2009-07-30 04:17:52') _FUNC_ to_date('2009-07-30 04:17:52');
+ true
+ > SELECT to_date('2009-07-30 04:17:52') _FUNC_ to_date('2009-08-01 04:17:52');
+ false
+ > SELECT 1 _FUNC_ NULL;
+ NULL
+ """)
case class GreaterThanOrEqual(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {
- override def inputType: AbstractDataType = TypeCollection.Ordered
-
override def symbol: String = ">="
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2)
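Because inputType is now AnyDataType and orderability is enforced in checkInputDataTypes, binary comparisons accept any orderable type, including structs, while maps are still rejected. A quick sketch:

  import org.apache.spark.sql.SparkSession

  val spark = SparkSession.builder().master("local[*]").getOrCreate()

  spark.sql("SELECT named_struct('a', 1) < named_struct('a', 2)").show()   // true
  // Still rejected by the ordering check, since maps are not orderable:
  // spark.sql("SELECT map('k', 1) = map('k', 1)")   // AnalysisException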
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
index d7b493d521ddb..c6146042ef1a6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala
@@ -274,9 +274,15 @@ object ScalarSubquery {
case class ListQuery(
plan: LogicalPlan,
children: Seq[Expression] = Seq.empty,
- exprId: ExprId = NamedExpression.newExprId)
+ exprId: ExprId = NamedExpression.newExprId,
+ childOutputs: Seq[Attribute] = Seq.empty)
extends SubqueryExpression(plan, children, exprId) with Unevaluable {
- override def dataType: DataType = plan.schema.fields.head.dataType
+ override def dataType: DataType = if (childOutputs.length > 1) {
+ childOutputs.toStructType
+ } else {
+ childOutputs.head.dataType
+ }
+ override lazy val resolved: Boolean = childrenResolved && plan.resolved && childOutputs.nonEmpty
override def nullable: Boolean = false
override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan)
override def toString: String = s"list#${exprId.id} $conditionString"
@@ -284,7 +290,8 @@ case class ListQuery(
ListQuery(
plan.canonicalized,
children.map(_.canonicalized),
- ExprId(0))
+ ExprId(0),
+ childOutputs.map(_.canonicalized.asInstanceOf[Attribute]))
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
index 1fd680ab64b5a..652412b34478a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -64,6 +64,8 @@ private[sql] class JSONOptions(
parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true)
val allowBackslashEscapingAnyCharacter =
parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false)
+ private val allowUnquotedControlChars =
+ parameters.get("allowUnquotedControlChars").map(_.toBoolean).getOrElse(false)
val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName)
val parseMode: ParseMode =
parameters.get("mode").map(ParseMode.fromString).getOrElse(PermissiveMode)
@@ -92,5 +94,6 @@ private[sql] class JSONOptions(
factory.configure(JsonParser.Feature.ALLOW_NON_NUMERIC_NUMBERS, allowNonNumericNumbers)
factory.configure(JsonParser.Feature.ALLOW_BACKSLASH_ESCAPING_ANY_CHARACTER,
allowBackslashEscapingAnyCharacter)
+ factory.configure(JsonParser.Feature.ALLOW_UNQUOTED_CONTROL_CHARS, allowUnquotedControlChars)
}
}
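The new option simply forwards to Jackson's ALLOW_UNQUOTED_CONTROL_CHARS feature. A hedged usage sketch (the path is a placeholder):

  import org.apache.spark.sql.SparkSession

  val spark = SparkSession.builder().master("local[*]").getOrCreate()

  // Accept JSON strings containing unescaped control characters (e.g. raw tabs) instead of failing.
  val df = spark.read.option("allowUnquotedControlChars", "true").json("/path/to/data.json")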
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
index 1d302aea6fd16..eb06e4f304f0a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
@@ -26,8 +26,15 @@ import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, MapData}
import org.apache.spark.sql.types._
+/**
+ * `JacksonGenerator` can only be initialized with a `StructType` or a `MapType`.
+ * Once it is initialized with a `StructType`, it can be used to write out a struct or an array
+ * of structs. Once it is initialized with a `MapType`, it can be used to write out a map or an
+ * array of maps. An exception will be thrown when trying to write out a struct if it was
+ * initialized with a `MapType`, and vice versa.
+ */
private[sql] class JacksonGenerator(
- schema: StructType,
+ dataType: DataType,
writer: Writer,
options: JSONOptions) {
// A `ValueWriter` is responsible for writing a field of an `InternalRow` to appropriate
@@ -35,11 +42,34 @@ private[sql] class JacksonGenerator(
// we can directly access data in `ArrayData` without the help of `SpecificMutableRow`.
private type ValueWriter = (SpecializedGetters, Int) => Unit
+ // `JacksonGenerator` can only be initialized with a `StructType` or a `MapType`.
+ require(dataType.isInstanceOf[StructType] || dataType.isInstanceOf[MapType],
+ "JacksonGenerator only supports to be initialized with a StructType " +
+ s"or MapType but got ${dataType.simpleString}")
+
// `ValueWriter`s for all fields of the schema
- private val rootFieldWriters: Array[ValueWriter] = schema.map(_.dataType).map(makeWriter).toArray
+ private lazy val rootFieldWriters: Array[ValueWriter] = dataType match {
+ case st: StructType => st.map(_.dataType).map(makeWriter).toArray
+ case _ => throw new UnsupportedOperationException(
+ s"Initial type ${dataType.simpleString} must be a struct")
+ }
+
// `ValueWriter` for array data storing rows of the schema.
- private val arrElementWriter: ValueWriter = (arr: SpecializedGetters, i: Int) => {
- writeObject(writeFields(arr.getStruct(i, schema.length), schema, rootFieldWriters))
+ private lazy val arrElementWriter: ValueWriter = dataType match {
+ case st: StructType =>
+ (arr: SpecializedGetters, i: Int) => {
+ writeObject(writeFields(arr.getStruct(i, st.length), st, rootFieldWriters))
+ }
+ case mt: MapType =>
+ (arr: SpecializedGetters, i: Int) => {
+ writeObject(writeMapData(arr.getMap(i), mt, mapElementWriter))
+ }
+ }
+
+ private lazy val mapElementWriter: ValueWriter = dataType match {
+ case mt: MapType => makeWriter(mt.valueType)
+ case _ => throw new UnsupportedOperationException(
+ s"Initial type ${dataType.simpleString} must be a map")
}
private val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)
@@ -189,18 +219,37 @@ private[sql] class JacksonGenerator(
def flush(): Unit = gen.flush()
/**
- * Transforms a single `InternalRow` to JSON object using Jackson
+ * Transforms a single `InternalRow` to JSON object using Jackson.
+ * This call is validated by accessing `rootFieldWriters`.
*
* @param row The row to convert
*/
- def write(row: InternalRow): Unit = writeObject(writeFields(row, schema, rootFieldWriters))
+ def write(row: InternalRow): Unit = {
+ writeObject(writeFields(
+ fieldWriters = rootFieldWriters,
+ row = row,
+ schema = dataType.asInstanceOf[StructType]))
+ }
/**
- * Transforms multiple `InternalRow`s to JSON array using Jackson
+ * Transforms multiple `InternalRow`s or `MapData`s to JSON array using Jackson
*
- * @param array The array of rows to convert
+ * @param array The array of rows or maps to convert
*/
def write(array: ArrayData): Unit = writeArray(writeArrayData(array, arrElementWriter))
+ /**
+ * Transforms a single `MapData` to a JSON object using Jackson.
+ * This call is validated when `mapElementWriter` is accessed, i.e. the generator must have
+ * been initialized with a `MapType`.
+ *
+ * @param map The map to convert
+ */
+ def write(map: MapData): Unit = {
+ writeObject(writeMapData(
+ fieldWriter = mapElementWriter,
+ map = map,
+ mapType = dataType.asInstanceOf[MapType]))
+ }
+
def writeLineEnding(): Unit = gen.writeRaw('\n')
}
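
As a quick, hedged illustration of the widened contract, the public `to_json` function (which is backed by this generator) can now be fed a map column as well as a struct column. The sketch below assumes a `SparkSession` named `spark` and only works once this patch is applied; column names are illustrative.

    import org.apache.spark.sql.functions._

    // Generator initialized with a MapType under the hood: one JSON object per map.
    val mapJson = spark.range(1)
      .select(to_json(map(lit("test"), struct(lit(1).as("a")))).as("js"))
    // should produce a row like {"test":{"a":1}}

    // Generator initialized with a StructType, as before: one JSON object per struct.
    val structJson = spark.range(1)
      .select(to_json(struct(lit(1).as("a"))).as("js"))
    // should produce a row like {"a":1}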
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala
index 3b23c6cd2816f..134d16e981a15 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonUtils.scala
@@ -44,7 +44,9 @@ object JacksonUtils {
case at: ArrayType => verifyType(name, at.elementType)
- case mt: MapType => verifyType(name, mt.keyType)
+ // For MapType, the keys are basically rendered as strings (i.e. via `toString`) when
+ // generating JSON, so we only care whether the values are valid for JSON.
+ case mt: MapType => verifyType(name, mt.valueType)
case udt: UserDefinedType[_] => verifyType(name, udt.sqlType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index a51b385399d88..a602894efbcae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
/**
* Abstract class all optimizers should inherit of, contains the standard batches (extending
@@ -37,6 +38,12 @@ import org.apache.spark.sql.types._
abstract class Optimizer(sessionCatalog: SessionCatalog)
extends RuleExecutor[LogicalPlan] {
+ // Check for structural integrity of the plan in test mode. Currently we only check if a plan is
+ // still resolved after the execution of each rule.
+ override protected def isPlanIntegral(plan: LogicalPlan): Boolean = {
+ !Utils.isTesting || plan.resolved
+ }
+
protected def fixedPoint = FixedPoint(SQLConf.get.optimizerMaxIterations)
def batches: Seq[Batch] = {
@@ -79,11 +86,12 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
PushProjectionThroughUnion,
ReorderJoin,
EliminateOuterJoin,
+ InferFiltersFromConstraints,
+ BooleanSimplification,
PushPredicateThroughJoin,
PushDownPredicate,
LimitPushDown,
ColumnPruning,
- InferFiltersFromConstraints,
// Operator combine
CollapseRepartition,
CollapseProject,
@@ -380,21 +388,6 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper
result.asInstanceOf[A]
}
- /**
- * Splits the condition expression into small conditions by `And`, and partition them by
- * deterministic, and finally recombine them by `And`. It returns an expression containing
- * all deterministic expressions (the first field of the returned Tuple2) and an expression
- * containing all non-deterministic expressions (the second field of the returned Tuple2).
- */
- private def partitionByDeterministic(condition: Expression): (Expression, Expression) = {
- val andConditions = splitConjunctivePredicates(condition)
- andConditions.partition(_.deterministic) match {
- case (deterministic, nondeterministic) =>
- deterministic.reduceOption(And).getOrElse(Literal(true)) ->
- nondeterministic.reduceOption(And).getOrElse(Literal(true))
- }
- }
-
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// Push down deterministic projection through UNION ALL
@@ -738,8 +731,10 @@ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper {
case Filter(Literal(true, BooleanType), child) => child
// If the filter condition always evaluate to null or false,
// replace the input with an empty relation.
- case Filter(Literal(null, _), child) => LocalRelation(child.output, data = Seq.empty)
- case Filter(Literal(false, BooleanType), child) => LocalRelation(child.output, data = Seq.empty)
+ case Filter(Literal(null, _), child) =>
+ LocalRelation(child.output, data = Seq.empty, isStreaming = plan.isStreaming)
+ case Filter(Literal(false, BooleanType), child) =>
+ LocalRelation(child.output, data = Seq.empty, isStreaming = plan.isStreaming)
// If any deterministic condition is guaranteed to be true given the constraints on the child's
// output, remove the condition
case f @ Filter(fc, p: LogicalPlan) =>
@@ -1171,15 +1166,18 @@ object DecimalAggregates extends Rule[LogicalPlan] {
* Converts local operations (i.e. ones that don't require data exchange) on LocalRelation to
* another LocalRelation.
*
- * This is relatively simple as it currently handles only a single case: Project.
+ * This is relatively simple as it currently handles only two cases: Project and Limit.
*/
object ConvertToLocalRelation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case Project(projectList, LocalRelation(output, data))
+ case Project(projectList, LocalRelation(output, data, isStreaming))
if !projectList.exists(hasUnevaluableExpr) =>
val projection = new InterpretedProjection(projectList, output)
projection.initialize(0)
- LocalRelation(projectList.map(_.toAttribute), data.map(projection))
+ LocalRelation(projectList.map(_.toAttribute), data.map(projection), isStreaming)
+
+ case Limit(IntegerLiteral(limit), LocalRelation(output, data, isStreaming)) =>
+ LocalRelation(output, data.take(limit), isStreaming)
}
private def hasUnevaluableExpr(expr: Expression): Boolean = {
@@ -1204,7 +1202,7 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] {
*/
object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case Deduplicate(keys, child, streaming) if !streaming =>
+ case Deduplicate(keys, child) if !child.isStreaming =>
val keyExprIds = keys.map(_.exprId)
val aggCols = child.output.map { attr =>
if (keyExprIds.contains(attr.exprId)) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
index 987cd7434b459..cfffa6bc2bfdd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
@@ -38,7 +38,8 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper {
case _ => false
}
- private def empty(plan: LogicalPlan) = LocalRelation(plan.output, data = Seq.empty)
+ private def empty(plan: LogicalPlan) =
+ LocalRelation(plan.output, data = Seq.empty, isStreaming = plan.isStreaming)
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case p: Union if p.children.forall(isEmptyLocalRelation) =>
@@ -65,11 +66,15 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper {
case _: RepartitionByExpression => empty(p)
// An aggregate with non-empty group expression will return one output row per group when the
// input to the aggregate is not empty. If the input to the aggregate is empty then all groups
- // will be empty and thus the output will be empty.
+ // will be empty and thus the output will be empty. If we're working on batch data, we can
+ // then treat the aggregate as redundant.
+ //
+ // If the aggregate is over streaming data, we may need to update the state store even if no
+ // new rows are processed, so we can't eliminate the node.
//
// If the grouping expressions are empty, however, then the aggregate will always produce a
// single output row and thus we cannot propagate the EmptyRelation.
- case Aggregate(ge, _, _) if ge.nonEmpty => empty(p)
+ case Aggregate(ge, _, _) if ge.nonEmpty && !p.isStreaming => empty(p)
// Generators like Hive-style UDTF may return their records within `close`.
case Generate(_: Explode, _, _, _, _, _) => empty(p)
case _ => p
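
To make the streaming guard above concrete, here is a minimal sketch (assuming a `SparkSession` named `spark` and the built-in `rate` streaming source; filter and grouping expressions are illustrative). Even when a micro-batch contributes no rows to the aggregate's input, the node must still run so the state store can be maintained, which is why it cannot be replaced by an empty LocalRelation.

    import org.apache.spark.sql.functions._

    // A streaming aggregation whose input may well be empty in any given micro-batch.
    val counts = spark.readStream
      .format("rate").load()                         // columns: timestamp, value
      .where(col("value") % 1000 === 0)              // most batches match no rows
      .groupBy((col("value") % 10).as("bucket"))
      .count()
    // counts.isStreaming == true, so PropagateEmptyRelation leaves the Aggregate in place.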
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 9dbb6b14aaac3..64b28565eb27c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -49,6 +49,33 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
}
}
+ private def dedupJoin(joinPlan: LogicalPlan): LogicalPlan = joinPlan match {
+ // SPARK-21835: It is possible that the two sides of the join have conflicting attributes;
+ // the produced join then becomes unresolved and breaks structural integrity. We should
+ // de-duplicate conflicting attributes. We don't use a transformation here because we only
+ // care about the topmost join converted from a correlated predicate subquery.
+ case j @ Join(left, right, joinType @ (LeftSemi | LeftAnti | ExistenceJoin(_)), joinCond) =>
+ val duplicates = right.outputSet.intersect(left.outputSet)
+ if (duplicates.nonEmpty) {
+ val aliasMap = AttributeMap(duplicates.map { dup =>
+ dup -> Alias(dup, dup.toString)()
+ }.toSeq)
+ val aliasedExpressions = right.output.map { ref =>
+ aliasMap.getOrElse(ref, ref)
+ }
+ val newRight = Project(aliasedExpressions, right)
+ val newJoinCond = joinCond.map { condExpr =>
+ condExpr transform {
+ case a: Attribute => aliasMap.getOrElse(a, a).toAttribute
+ }
+ }
+ Join(left, newRight, joinType, newJoinCond)
+ } else {
+ j
+ }
+ case _ => joinPlan
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Filter(condition, child) =>
val (withSubquery, withoutSubquery) =
@@ -64,15 +91,18 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
withSubquery.foldLeft(newFilter) {
case (p, Exists(sub, conditions, _)) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
- Join(outerPlan, sub, LeftSemi, joinCond)
+ // Deduplicate conflicting attributes if any.
+ dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
case (p, Not(Exists(sub, conditions, _))) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
- Join(outerPlan, sub, LeftAnti, joinCond)
- case (p, In(value, Seq(ListQuery(sub, conditions, _)))) =>
+ // Deduplicate conflicting attributes if any.
+ dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond))
+ case (p, In(value, Seq(ListQuery(sub, conditions, _, _)))) =>
val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
- Join(outerPlan, sub, LeftSemi, joinCond)
- case (p, Not(In(value, Seq(ListQuery(sub, conditions, _))))) =>
+ // Deduplicate conflicting attributes if any.
+ dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
+ case (p, Not(In(value, Seq(ListQuery(sub, conditions, _, _))))) =>
// This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
// Construct the condition. A NULL in one of the conditions is regarded as a positive
// result; such a row will be filtered out by the Anti-Join operator.
@@ -93,7 +123,8 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
// will have the final conditions in the LEFT ANTI as
// (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2)
val pairs = (joinConds.map(c => Or(c, IsNull(c))) ++ conditions).reduceLeft(And)
- Join(outerPlan, sub, LeftAnti, Option(pairs))
+ // Deduplicate conflicting attributes if any.
+ dedupJoin(Join(outerPlan, sub, LeftAnti, Option(pairs)))
case (p, predicate) =>
val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p)
Project(p.output, Filter(newCond.get, inputPlan))
@@ -114,13 +145,16 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
e transformUp {
case Exists(sub, conditions, _) =>
val exists = AttributeReference("exists", BooleanType, nullable = false)()
- newPlan = Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
+ // Deduplicate conflicting attributes if any.
+ newPlan = dedupJoin(
+ Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And)))
exists
- case In(value, Seq(ListQuery(sub, conditions, _))) =>
+ case In(value, Seq(ListQuery(sub, conditions, _, _))) =>
val exists = AttributeReference("exists", BooleanType, nullable = false)()
val inConditions = getValueExpression(value).zip(sub.output).map(EqualTo.tupled)
val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
- newPlan = Join(newPlan, sub, ExistenceJoin(exists), newConditions)
+ // Deduplicate conflicting attributes if any.
+ newPlan = dedupJoin(Join(newPlan, sub, ExistenceJoin(exists), newConditions))
exists
}
}
@@ -227,9 +261,9 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
case Exists(sub, children, exprId) if children.nonEmpty =>
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
Exists(newPlan, newCond, exprId)
- case ListQuery(sub, _, exprId) =>
+ case ListQuery(sub, _, exprId, childOutputs) =>
val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
- ListQuery(newPlan, newCond, exprId)
+ ListQuery(newPlan, newCond, exprId, childOutputs)
}
}
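
For context, a rough sketch of the kind of query shape that can trigger the de-duplication above. The table and column names are illustrative; the point is that the outer query and the subquery can resolve to the same underlying relation, so the semi/anti/existence join produced by the rewrite would otherwise carry the same attribute ids on both sides.

    // Assuming a SparkSession `spark` and a registered view `t(id, flag)`.
    val rewritten = spark.sql(
      "SELECT id FROM t WHERE id IN (SELECT id FROM t WHERE flag)")
    // RewritePredicateSubquery turns the IN into a LeftSemi join; dedupJoin re-aliases the
    // right side whenever its output attributes collide with the left side's.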
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 0706e044c3286..891f61698f177 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -31,6 +31,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.catalog.CatalogStorageFormat
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
@@ -89,10 +90,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
visitSparkDataType(ctx.dataType)
}
- override def visitStandaloneColTypeList(ctx: StandaloneColTypeListContext): Seq[StructField] =
- withOrigin(ctx) {
- visitColTypeList(ctx.colTypeList)
- }
+ override def visitSingleTableSchema(ctx: SingleTableSchemaContext): StructType = {
+ withOrigin(ctx)(StructType(visitColTypeList(ctx.colTypeList)))
+ }
/* ********************************************************************************************
* Plan parsing
@@ -179,11 +179,64 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
}
/**
- * Add an INSERT INTO [TABLE]/INSERT OVERWRITE TABLE operation to the logical plan.
+ * Parameters used for writing a query to a table:
+ * (tableIdentifier, partitionKeys, exists).
+ */
+ type InsertTableParams = (TableIdentifier, Map[String, Option[String]], Boolean)
+
+ /**
+ * Parameters used for writing a query to a directory: (isLocal, CatalogStorageFormat, provider).
+ */
+ type InsertDirParams = (Boolean, CatalogStorageFormat, Option[String])
+
+ /**
+ * Add an
+ * {{{
+ * INSERT OVERWRITE TABLE tableIdentifier [partitionSpec [IF NOT EXISTS]]?
+ * INSERT INTO [TABLE] tableIdentifier [partitionSpec]
+ * INSERT OVERWRITE [LOCAL] DIRECTORY STRING [rowFormat] [createFileFormat]
+ * INSERT OVERWRITE [LOCAL] DIRECTORY [STRING] tableProvider [OPTIONS tablePropertyList]
+ * }}}
+ * operation to the logical plan.
*/
private def withInsertInto(
ctx: InsertIntoContext,
query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ ctx match {
+ case table: InsertIntoTableContext =>
+ val (tableIdent, partitionKeys, exists) = visitInsertIntoTable(table)
+ InsertIntoTable(UnresolvedRelation(tableIdent), partitionKeys, query, false, exists)
+ case table: InsertOverwriteTableContext =>
+ val (tableIdent, partitionKeys, exists) = visitInsertOverwriteTable(table)
+ InsertIntoTable(UnresolvedRelation(tableIdent), partitionKeys, query, true, exists)
+ case dir: InsertOverwriteDirContext =>
+ val (isLocal, storage, provider) = visitInsertOverwriteDir(dir)
+ InsertIntoDir(isLocal, storage, provider, query, overwrite = true)
+ case hiveDir: InsertOverwriteHiveDirContext =>
+ val (isLocal, storage, provider) = visitInsertOverwriteHiveDir(hiveDir)
+ InsertIntoDir(isLocal, storage, provider, query, overwrite = true)
+ case _ =>
+ throw new ParseException("Invalid InsertIntoContext", ctx)
+ }
+ }
+
+ /**
+ * Add an INSERT INTO TABLE operation to the logical plan.
+ */
+ override def visitInsertIntoTable(
+ ctx: InsertIntoTableContext): InsertTableParams = withOrigin(ctx) {
+ val tableIdent = visitTableIdentifier(ctx.tableIdentifier)
+ val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)
+
+ (tableIdent, partitionKeys, false)
+ }
+
+ /**
+ * Add an INSERT OVERWRITE TABLE operation to the logical plan.
+ */
+ override def visitInsertOverwriteTable(
+ ctx: InsertOverwriteTableContext): InsertTableParams = withOrigin(ctx) {
+ assert(ctx.OVERWRITE() != null)
val tableIdent = visitTableIdentifier(ctx.tableIdentifier)
val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)
@@ -193,12 +246,23 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
"partitions with value: " + dynamicPartitionKeys.keys.mkString("[", ",", "]"), ctx)
}
- InsertIntoTable(
- UnresolvedRelation(tableIdent),
- partitionKeys,
- query,
- ctx.OVERWRITE != null,
- ctx.EXISTS != null)
+ (tableIdent, partitionKeys, ctx.EXISTS() != null)
+ }
+
+ /**
+ * Write to a directory, returning an [[InsertIntoDir]] logical plan.
+ */
+ override def visitInsertOverwriteDir(
+ ctx: InsertOverwriteDirContext): InsertDirParams = withOrigin(ctx) {
+ throw new ParseException("INSERT OVERWRITE DIRECTORY is not supported", ctx)
+ }
+
+ /**
+ * Write to a directory, returning an [[InsertIntoDir]] logical plan.
+ */
+ override def visitInsertOverwriteHiveDir(
+ ctx: InsertOverwriteHiveDirContext): InsertDirParams = withOrigin(ctx) {
+ throw new ParseException("INSERT OVERWRITE DIRECTORY is not supported", ctx)
}
/**
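
A brief sketch of the surface forms the grammar above distinguishes. These are SQL strings only; the table names, path and options are illustrative, and whether the DIRECTORY variants actually succeed depends on overrides outside this file, since the catalyst-level visitors above reject them with a ParseException.

    val insertForms = Seq(
      "INSERT INTO TABLE logs PARTITION (dt='2017-09-01') SELECT * FROM staging",
      "INSERT OVERWRITE TABLE logs PARTITION (dt='2017-09-01') IF NOT EXISTS SELECT * FROM staging",
      "INSERT OVERWRITE LOCAL DIRECTORY '/tmp/out' STORED AS parquet SELECT * FROM staging",
      "INSERT OVERWRITE DIRECTORY '/tmp/out' USING parquet OPTIONS (compression 'snappy') SELECT * FROM staging"
    )
    // insertForms.foreach(spark.sql)  // assuming a SparkSession `spark` and existing tables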
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
index f8492dd8c882e..0d9ad218e48db 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
@@ -61,7 +61,7 @@ abstract class AbstractSqlParser extends ParserInterface with Logging {
* definitions which will preserve the correct Hive metadata.
*/
override def parseTableSchema(sqlText: String): StructType = parse(sqlText) { parser =>
- StructType(astBuilder.visitStandaloneColTypeList(parser.standaloneColTypeList()))
+ astBuilder.visitSingleTableSchema(parser.singleTableSchema())
}
/** Creates LogicalPlan for a given SQL string. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
index 1c986fbde7ada..d73d7e73f28d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
@@ -43,7 +43,10 @@ object LocalRelation {
}
}
-case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil)
+case class LocalRelation(output: Seq[Attribute],
+ data: Seq[InternalRow] = Nil,
+ // Indicates whether this relation has data from a streaming source.
+ override val isStreaming: Boolean = false)
extends LeafNode with analysis.MultiInstanceRelation {
// A local relation must have resolved output.
@@ -55,7 +58,7 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil)
* query.
*/
override final def newInstance(): this.type = {
- LocalRelation(output.map(_.newInstance()), data).asInstanceOf[this.type]
+ LocalRelation(output.map(_.newInstance()), data, isStreaming).asInstanceOf[this.type]
}
override protected def stringArgs: Iterator[Any] = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 9b440cd99f994..68aae720e026a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -47,7 +47,7 @@ abstract class LogicalPlan
*/
def analyzed: Boolean = _analyzed
- /** Returns true if this subtree contains any streaming data sources. */
+ /** Returns true if this subtree has data from a streaming data source. */
def isStreaming: Boolean = children.exists(_.isStreaming == true)
/**
@@ -297,7 +297,6 @@ abstract class UnaryNode extends LogicalPlan {
case expr: Expression if expr.semanticEquals(e) =>
a.toAttribute
})
- allConstraints += EqualNullSafe(e, a.toAttribute)
case _ => // Don't change.
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
index 8bffbd0c208cb..b0f611fd38dea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala
@@ -106,91 +106,48 @@ trait QueryPlanConstraints { self: LogicalPlan =>
* Infers an additional set of constraints from a given set of equality constraints.
* For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
* additional constraint of the form `b = 5`.
- *
- * [SPARK-17733] We explicitly prevent producing recursive constraints of the form `a = f(a, b)`
- * as they are often useless and can lead to a non-converging set of constraints.
*/
private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
- val constraintClasses = generateEquivalentConstraintClasses(constraints)
-
+ val aliasedConstraints = eliminateAliasedExpressionInConstraints(constraints)
var inferredConstraints = Set.empty[Expression]
- constraints.foreach {
+ aliasedConstraints.foreach {
case eq @ EqualTo(l: Attribute, r: Attribute) =>
- val candidateConstraints = constraints - eq
- inferredConstraints ++= candidateConstraints.map(_ transform {
- case a: Attribute if a.semanticEquals(l) &&
- !isRecursiveDeduction(r, constraintClasses) => r
- })
- inferredConstraints ++= candidateConstraints.map(_ transform {
- case a: Attribute if a.semanticEquals(r) &&
- !isRecursiveDeduction(l, constraintClasses) => l
- })
+ val candidateConstraints = aliasedConstraints - eq
+ inferredConstraints ++= replaceConstraints(candidateConstraints, l, r)
+ inferredConstraints ++= replaceConstraints(candidateConstraints, r, l)
case _ => // No inference
}
inferredConstraints -- constraints
}
/**
- * Generate a sequence of expression sets from constraints, where each set stores an equivalence
- * class of expressions. For example, Set(`a = b`, `b = c`, `e = f`) will generate the following
- * expression sets: (Set(a, b, c), Set(e, f)). This will be used to search all expressions equal
- * to an selected attribute.
+ * Replace the aliased expression in [[Alias]] with the alias name if both exist in constraints.
+ * This prevents non-converging inference; e.g. `Alias(b, f(a)), a = b` would infer
+ * `f(a) = f(f(a))` without eliminating aliased expressions.
+ * It also reduces the size of the constraint set without losing any information.
+ * When the inferred filters are pushed down through the operators that generate the alias,
+ * the alias names used in the filters are replaced by the aliased expressions.
*/
- private def generateEquivalentConstraintClasses(
- constraints: Set[Expression]): Seq[Set[Expression]] = {
- var constraintClasses = Seq.empty[Set[Expression]]
- constraints.foreach {
- case eq @ EqualTo(l: Attribute, r: Attribute) =>
- // Transform [[Alias]] to its child.
- val left = aliasMap.getOrElse(l, l)
- val right = aliasMap.getOrElse(r, r)
- // Get the expression set for an equivalence constraint class.
- val leftConstraintClass = getConstraintClass(left, constraintClasses)
- val rightConstraintClass = getConstraintClass(right, constraintClasses)
- if (leftConstraintClass.nonEmpty && rightConstraintClass.nonEmpty) {
- // Combine the two sets.
- constraintClasses = constraintClasses
- .diff(leftConstraintClass :: rightConstraintClass :: Nil) :+
- (leftConstraintClass ++ rightConstraintClass)
- } else if (leftConstraintClass.nonEmpty) { // && rightConstraintClass.isEmpty
- // Update equivalence class of `left` expression.
- constraintClasses = constraintClasses
- .diff(leftConstraintClass :: Nil) :+ (leftConstraintClass + right)
- } else if (rightConstraintClass.nonEmpty) { // && leftConstraintClass.isEmpty
- // Update equivalence class of `right` expression.
- constraintClasses = constraintClasses
- .diff(rightConstraintClass :: Nil) :+ (rightConstraintClass + left)
- } else { // leftConstraintClass.isEmpty && rightConstraintClass.isEmpty
- // Create new equivalence constraint class since neither expression presents
- // in any classes.
- constraintClasses = constraintClasses :+ Set(left, right)
- }
- case _ => // Skip
+ private def eliminateAliasedExpressionInConstraints(constraints: Set[Expression])
+ : Set[Expression] = {
+ val attributesInEqualTo = constraints.flatMap {
+ case EqualTo(l: Attribute, r: Attribute) => l :: r :: Nil
+ case _ => Nil
}
-
- constraintClasses
- }
-
- /**
- * Get all expressions equivalent to the selected expression.
- */
- private def getConstraintClass(
- expr: Expression,
- constraintClasses: Seq[Set[Expression]]): Set[Expression] =
- constraintClasses.find(_.contains(expr)).getOrElse(Set.empty[Expression])
-
- /**
- * Check whether replace by an [[Attribute]] will cause a recursive deduction. Generally it
- * has the form like: `a -> f(a, b)`, where `a` and `b` are expressions and `f` is a function.
- * Here we first get all expressions equal to `attr` and then check whether at least one of them
- * is a child of the referenced expression.
- */
- private def isRecursiveDeduction(
- attr: Attribute,
- constraintClasses: Seq[Set[Expression]]): Boolean = {
- val expr = aliasMap.getOrElse(attr, attr)
- getConstraintClass(expr, constraintClasses).exists { e =>
- expr.children.exists(_.semanticEquals(e))
+ var aliasedConstraints = constraints
+ attributesInEqualTo.foreach { a =>
+ if (aliasMap.contains(a)) {
+ val child = aliasMap(a)
+ aliasedConstraints = replaceConstraints(aliasedConstraints, child, a)
+ }
}
+ aliasedConstraints
}
+
+ private def replaceConstraints(
+ constraints: Set[Expression],
+ source: Expression,
+ destination: Attribute): Set[Expression] = constraints.map(_ transform {
+ case e: Expression if e.semanticEquals(source) => destination
+ })
}
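
To make the inference above concrete, here is a tiny self-contained model of the substitution step, using toy types rather than the catalyst classes: given the constraints `a = 5` and `a = b`, substituting `a -> b` in the remaining constraints yields the additional constraint `b = 5`.

    sealed trait Expr
    case class Attr(name: String) extends Expr
    case class Lit(value: Int) extends Expr
    case class EqTo(left: Expr, right: Expr) extends Expr

    // Toy counterpart of replaceConstraints: rewrite every occurrence of `source` into `dest`.
    def replace(constraints: Set[Expr], source: Expr, dest: Expr): Set[Expr] =
      constraints.map {
        case EqTo(l, r) =>
          EqTo(if (l == source) dest else l, if (r == source) dest else r)
        case e => if (e == source) dest else e
      }

    val constraints: Set[Expr] = Set(EqTo(Attr("a"), Lit(5)), EqTo(Attr("a"), Attr("b")))
    val candidates = constraints - EqTo(Attr("a"), Attr("b"))
    val inferred = replace(candidates, Attr("a"), Attr("b")) -- constraints
    // inferred == Set(EqTo(Attr("b"), Lit(5)))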
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 303014e0b8d31..f443cd5a69de3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
-import org.apache.spark.sql.catalyst.catalog.CatalogTable
+import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
@@ -359,6 +359,30 @@ case class InsertIntoTable(
override lazy val resolved: Boolean = false
}
+/**
+ * Insert query result into a directory.
+ *
+ * @param isLocal Indicates whether the specified directory is a local directory
+ * @param storage Info about the output file format, row format and serialization format
+ * @param provider Specifies what data source to use; only set for data source based writes.
+ * @param child The query to be executed
+ * @param overwrite If true, the existing directory will be overwritten
+ *
+ * Note that this plan is unresolved and has to be replaced by the concrete implementations
+ * during analysis.
+ */
+case class InsertIntoDir(
+ isLocal: Boolean,
+ storage: CatalogStorageFormat,
+ provider: Option[String],
+ child: LogicalPlan,
+ overwrite: Boolean = true)
+ extends UnaryNode {
+
+ override def output: Seq[Attribute] = Seq.empty
+ override lazy val resolved: Boolean = false
+}
+
/**
* A container for holding the view description(CatalogTable), and the output of the view. The
* child should be a logical plan parsed from the `CatalogTable.viewText`, should throw an error
@@ -429,9 +453,10 @@ case class Sort(
/** Factory for constructing new `Range` nodes. */
object Range {
- def apply(start: Long, end: Long, step: Long, numSlices: Option[Int]): Range = {
+ def apply(start: Long, end: Long, step: Long,
+ numSlices: Option[Int], isStreaming: Boolean = false): Range = {
val output = StructType(StructField("id", LongType, nullable = false) :: Nil).toAttributes
- new Range(start, end, step, numSlices, output)
+ new Range(start, end, step, numSlices, output, isStreaming)
}
def apply(start: Long, end: Long, step: Long, numSlices: Int): Range = {
Range(start, end, step, Some(numSlices))
@@ -443,7 +468,8 @@ case class Range(
end: Long,
step: Long,
numSlices: Option[Int],
- output: Seq[Attribute])
+ output: Seq[Attribute],
+ override val isStreaming: Boolean)
extends LeafNode with MultiInstanceRelation {
require(step != 0, s"step ($step) cannot be 0")
@@ -784,8 +810,7 @@ case class OneRowRelation() extends LeafNode {
/** A logical plan for `dropDuplicates`. */
case class Deduplicate(
keys: Seq[Attribute],
- child: LogicalPlan,
- streaming: Boolean) extends UnaryNode {
+ child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
index 85b368c862630..7e4b784033bfc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala
@@ -63,6 +63,13 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
/** Defines a sequence of rule batches, to be overridden by the implementation. */
protected def batches: Seq[Batch]
+ /**
+ * Defines a check function that checks for structural integrity of the plan after the execution
+ * of each rule. For example, we can check whether a plan is still resolved after each rule in
+ * `Optimizer`, so we can catch rules that return invalid plans. The check function returns
+ * `false` if the given plan doesn't pass the structural integrity check.
+ */
+ protected def isPlanIntegral(plan: TreeType): Boolean = true
/**
* Executes the batches of rules defined by the subclass. The batches are executed serially
@@ -94,6 +101,13 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
""".stripMargin)
}
+ // Run the structural integrity checker against the plan after each rule.
+ if (!isPlanIntegral(result)) {
+ val message = s"After applying rule ${rule.ruleName} in batch ${batch.name}, " +
+ "the structural integrity of the plan is broken."
+ throw new TreeNodeException(result, message, null)
+ }
+
result
}
iteration += 1
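
A minimal sketch of how the hook above is meant to be wired, with toy stand-ins rather than the real catalyst classes, to show where the integrity check sits relative to rule application.

    trait Plan { def resolved: Boolean }
    case class Node(resolved: Boolean) extends Plan

    abstract class MiniRuleExecutor[T <: Plan] {
      protected def rules: Seq[T => T]
      // Counterpart of isPlanIntegral: subclasses tighten this, e.g. only in test mode.
      protected def isPlanIntegral(plan: T): Boolean = true

      def execute(plan: T): T = rules.foldLeft(plan) { (current, rule) =>
        val next = rule(current)
        // Fail fast right after the offending rule, as the real executor does.
        require(isPlanIntegral(next), s"a rule broke structural integrity: $next")
        next
      }
    }

    object MiniOptimizer extends MiniRuleExecutor[Plan] {
      override protected def rules: Seq[Plan => Plan] =
        Seq((p: Plan) => p, (_: Plan) => Node(resolved = false))
      override protected def isPlanIntegral(plan: Plan): Boolean = plan.resolved
    }
    // MiniOptimizer.execute(Node(resolved = true)) throws on the second (broken) rule.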
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 45225779bffcb..1dcda49a3af6a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -65,6 +65,7 @@ object TypeUtils {
case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
+ case udt: UserDefinedType[_] => getInterpretedOrdering(udt.sqlType)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 8397ea917dee7..c779c468efeb8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -289,11 +289,6 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
- val PARQUET_CACHE_METADATA = buildConf("spark.sql.parquet.cacheMetadata")
- .doc("Turns on caching of Parquet schema metadata. Can speed up querying of static data.")
- .booleanConf
- .createWithDefault(true)
-
val PARQUET_COMPRESSION = buildConf("spark.sql.parquet.compression.codec")
.doc("Sets the compression codec use when writing Parquet files. Acceptable values include: " +
"uncompressed, snappy, gzip, lzo.")
@@ -332,6 +327,14 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val ORC_COMPRESSION = buildConf("spark.sql.orc.compression.codec")
+ .doc("Sets the compression codec used when writing ORC files. Acceptable values include: " +
+ "none, uncompressed, snappy, zlib, lzo.")
+ .stringConf
+ .transform(_.toLowerCase(Locale.ROOT))
+ .checkValues(Set("none", "uncompressed", "snappy", "zlib", "lzo"))
+ .createWithDefault("snappy")
+
val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown")
.doc("When true, enable filter pushdown for ORC files.")
.booleanConf
@@ -561,9 +564,9 @@ object SQLConf {
.intConf
.createWithDefault(100)
- val WHOLESTAGE_FALLBACK = buildConf("spark.sql.codegen.fallback")
+ val CODEGEN_FALLBACK = buildConf("spark.sql.codegen.fallback")
.internal()
- .doc("When true, whole stage codegen could be temporary disabled for the part of query that" +
+ .doc("When true, (whole stage) codegen could be temporarily disabled for the part of query that" +
" fail to compile generated code")
.booleanConf
.createWithDefault(true)
@@ -587,10 +590,10 @@ object SQLConf {
.doc("The maximum lines of a single Java function generated by whole-stage codegen. " +
"When the generated function exceeds this threshold, " +
"the whole-stage codegen is deactivated for this subtree of the current query plan. " +
- "The default value 2667 is the max length of byte code JIT supported " +
- "for a single function(8000) divided by 3.")
+ "The default value 4000 is the max length of byte code the JIT supports " +
+ "for a single function (8000) divided by 2.")
.intConf
- .createWithDefault(2667)
+ .createWithDefault(4000)
val FILES_MAX_PARTITION_BYTES = buildConf("spark.sql.files.maxPartitionBytes")
.doc("The maximum number of bytes to pack into a single partition when reading files.")
@@ -1008,9 +1011,9 @@ class SQLConf extends Serializable with Logging {
def useCompression: Boolean = getConf(COMPRESS_CACHED)
- def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION)
+ def orcCompressionCodec: String = getConf(ORC_COMPRESSION)
- def parquetCacheMetadata: Boolean = getConf(PARQUET_CACHE_METADATA)
+ def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION)
def parquetVectorizedReaderEnabled: Boolean = getConf(PARQUET_VECTORIZED_READER_ENABLED)
@@ -1053,7 +1056,7 @@ class SQLConf extends Serializable with Logging {
def wholeStageMaxNumFields: Int = getConf(WHOLESTAGE_MAX_NUM_FIELDS)
- def wholeStageFallback: Boolean = getConf(WHOLESTAGE_FALLBACK)
+ def codegenFallback: Boolean = getConf(CODEGEN_FALLBACK)
def maxCaseBranchesForCodegen: Int = getConf(MAX_CASES_BRANCHES)
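
A short usage sketch for the two knobs touched above (assuming a `SparkSession` named `spark`; the ORC codec values are the ones listed in the checkValues set, and the output path is illustrative):

    // Pick the ORC codec via the new conf (values: none, uncompressed, snappy, zlib, lzo).
    spark.conf.set("spark.sql.orc.compression.codec", "zlib")
    spark.range(10).write.mode("overwrite").format("orc").save("/tmp/orc-demo")

    // The renamed fallback flag: disable it to surface codegen compilation failures directly.
    spark.conf.set("spark.sql.codegen.fallback", "false")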
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index 1d54ff5825c2e..3041f44b116ea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -78,18 +78,6 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType])
private[sql] object TypeCollection {
- /**
- * Types that can be ordered/compared. In the long run we should probably make this a trait
- * that can be mixed into each data type, and perhaps create an `AbstractDataType`.
- */
- // TODO: Should we consolidate this with RowOrdering.isOrderable?
- val Ordered = TypeCollection(
- BooleanType,
- ByteType, ShortType, IntegerType, LongType,
- FloatType, DoubleType, DecimalType,
- TimestampType, DateType,
- StringType, BinaryType)
-
/**
* Types that include numeric types and interval type. They are only used in unary_minus,
* unary_positive, add and subtract operations.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 4e0613619add6..884e113537c93 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -505,7 +505,7 @@ class AnalysisErrorSuite extends AnalysisTest {
right,
joinType = Cross,
condition = Some('b === 'd))
- assertAnalysisError(plan2, "Cannot use map type in EqualTo" :: Nil)
+ assertAnalysisError(plan2, "EqualTo does not support ordering on type MapType" :: Nil)
}
test("PredicateSubQuery is used outside of a filter") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index e5fcd60b2d3da..e56a5d6368318 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
import java.util.TimeZone
-import org.scalatest.ShouldMatchers
+import org.scalatest.Matchers
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
-class AnalysisSuite extends AnalysisTest with ShouldMatchers {
+class AnalysisSuite extends AnalysisTest with Matchers {
import org.apache.spark.sql.catalyst.analysis.TestRelations._
test("union project *") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 30725773a37b1..36714bd631b0e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.types.{LongType, StringType, TypeCollection}
+import org.apache.spark.sql.types._
class ExpressionTypeCheckingSuite extends SparkFunSuite {
@@ -109,16 +109,17 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField))
assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField))
- assertError(EqualTo('mapField, 'mapField), "Cannot use map type in EqualTo")
- assertError(EqualNullSafe('mapField, 'mapField), "Cannot use map type in EqualNullSafe")
+ assertError(EqualTo('mapField, 'mapField), "EqualTo does not support ordering on type MapType")
+ assertError(EqualNullSafe('mapField, 'mapField),
+ "EqualNullSafe does not support ordering on type MapType")
assertError(LessThan('mapField, 'mapField),
- s"requires ${TypeCollection.Ordered.simpleString} type")
+ "LessThan does not support ordering on type MapType")
assertError(LessThanOrEqual('mapField, 'mapField),
- s"requires ${TypeCollection.Ordered.simpleString} type")
+ "LessThanOrEqual does not support ordering on type MapType")
assertError(GreaterThan('mapField, 'mapField),
- s"requires ${TypeCollection.Ordered.simpleString} type")
+ "GreaterThan does not support ordering on type MapType")
assertError(GreaterThanOrEqual('mapField, 'mapField),
- s"requires ${TypeCollection.Ordered.simpleString} type")
+ "GreaterThanOrEqual does not support ordering on type MapType")
assertError(If('intField, 'stringField, 'stringField),
"type of predicate expression in If should be boolean")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala
index d0fe815052256..9e99c8e11cdfe 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala
@@ -93,7 +93,7 @@ class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter {
val table = UnresolvedInlineTable(Seq("c1"),
Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType))))
val withTimeZone = ResolveTimeZone(conf).apply(table)
- val LocalRelation(output, data) = ResolveInlineTables(conf).apply(withTimeZone)
+ val LocalRelation(output, data, _) = ResolveInlineTables(conf).apply(withTimeZone)
val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType)
.withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long]
assert(output.map(_.dataType) == Seq(TimestampType))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
index f68d930f60523..4de75866e04a3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
@@ -368,18 +368,18 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
Aggregate(
Seq(attributeWithWatermark),
aggExprs("c"),
- Deduplicate(Seq(att), streamRelation, streaming = true)),
+ Deduplicate(Seq(att), streamRelation)),
outputMode = Append)
assertNotSupportedInStreamingPlan(
"Deduplicate - Deduplicate on streaming relation after aggregation",
- Deduplicate(Seq(att), Aggregate(Nil, aggExprs("c"), streamRelation), streaming = true),
+ Deduplicate(Seq(att), Aggregate(Nil, aggExprs("c"), streamRelation)),
outputMode = Complete,
expectedMsgs = Seq("dropDuplicates"))
assertSupportedInStreamingPlan(
"Deduplicate - Deduplicate on batch relation inside a streaming query",
- Deduplicate(Seq(att), batchRelation, streaming = false),
+ Deduplicate(Seq(att), batchRelation),
outputMode = Append
)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala
index 273f95f91ee50..b6e8b667a2400 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeSetSuite.scala
@@ -78,4 +78,44 @@ class AttributeSetSuite extends SparkFunSuite {
assert(aSet == aSet)
assert(aSet == AttributeSet(aUpper :: Nil))
}
+
+ test("SPARK-18394 keep a deterministic output order along with attribute names and exprIds") {
+ // Checks a simple case
+ val attrSeqA = {
+ val attr1 = AttributeReference("c1", IntegerType)(exprId = ExprId(1098))
+ val attr2 = AttributeReference("c2", IntegerType)(exprId = ExprId(107))
+ val attr3 = AttributeReference("c3", IntegerType)(exprId = ExprId(838))
+ val attrSetA = AttributeSet(attr1 :: attr2 :: attr3 :: Nil)
+
+ val attr4 = AttributeReference("c4", IntegerType)(exprId = ExprId(389))
+ val attr5 = AttributeReference("c5", IntegerType)(exprId = ExprId(89329))
+
+ val attrSetB = AttributeSet(attr4 :: attr5 :: Nil)
+ (attrSetA ++ attrSetB).toSeq.map(_.name)
+ }
+
+ val attrSeqB = {
+ val attr1 = AttributeReference("c1", IntegerType)(exprId = ExprId(392))
+ val attr2 = AttributeReference("c2", IntegerType)(exprId = ExprId(92))
+ val attr3 = AttributeReference("c3", IntegerType)(exprId = ExprId(87))
+ val attrSetA = AttributeSet(attr1 :: attr2 :: attr3 :: Nil)
+
+ val attr4 = AttributeReference("c4", IntegerType)(exprId = ExprId(9023920))
+ val attr5 = AttributeReference("c5", IntegerType)(exprId = ExprId(522))
+ val attrSetB = AttributeSet(attr4 :: attr5 :: Nil)
+
+ (attrSetA ++ attrSetB).toSeq.map(_.name)
+ }
+
+ assert(attrSeqA === attrSeqB)
+
+ // Checks the same column names having different exprIds
+ val attr1 = AttributeReference("c", IntegerType)(exprId = ExprId(1098))
+ val attr2 = AttributeReference("c", IntegerType)(exprId = ExprId(107))
+ val attrSetA = AttributeSet(attr1 :: attr2 :: Nil)
+ val attr3 = AttributeReference("c", IntegerType)(exprId = ExprId(389))
+ val attrSetB = AttributeSet(attr3 :: Nil)
+
+ assert((attrSetA ++ attrSetB).toSeq === attr2 :: attr3 :: attr1 :: Nil)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 0496d611ec3c7..b4c8eab19c5cc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -25,7 +25,7 @@ import org.scalatest.prop.GeneratorDrivenPropertyChecks
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
+import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
@@ -188,7 +188,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
expected: Any,
inputRow: InternalRow = EmptyRow): Unit = {
val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation())
- val optimizedPlan = SimpleTestOptimizer.execute(plan)
+ // We should analyze the plan first, otherwise we possibly optimize an unresolved plan.
+ val analyzedPlan = SimpleAnalyzer.execute(plan)
+ val optimizedPlan = SimpleTestOptimizer.execute(analyzedPlan)
checkEvaluationWithoutCodegen(optimizedPlan.expressions.head, expected, inputRow)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala
index d617ad540d5ff..a1000a0e80799 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala
@@ -210,4 +210,13 @@ class ExpressionSetSuite extends SparkFunSuite {
assert((initialSet - (aLower + 1)).size == 0)
}
+
+ test("add multiple elements to set") {
+ val initialSet = ExpressionSet(aUpper + 1 :: Nil)
+ val setToAddWithSameExpression = ExpressionSet(aUpper + 1 :: aUpper + 2 :: Nil)
+ val setToAddWithOutSameExpression = ExpressionSet(aUpper + 3 :: aUpper + 4 :: Nil)
+
+ assert((initialSet ++ setToAddWithSameExpression).size == 2)
+ assert((initialSet ++ setToAddWithOutSameExpression).size == 3)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
index f892e80204603..a0bbe02f92354 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
@@ -21,7 +21,8 @@ import java.util.Calendar
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, GenericArrayData, PermissiveMode}
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeTestUtils, DateTimeUtils, GenericArrayData, PermissiveMode}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -363,6 +364,26 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
InternalRow(UTF8String.fromString("b\nc")))
}
+ test("SPARK-21677: json_tuple throws NullPointException when column is null as string type") {
+ checkJsonTuple(
+ JsonTuple(Literal("""{"f1": 1, "f2": 2}""") ::
+ NonFoldableLiteral("f1") ::
+ NonFoldableLiteral("cast(NULL AS STRING)") ::
+ NonFoldableLiteral("f2") ::
+ Nil),
+ InternalRow(UTF8String.fromString("1"), null, UTF8String.fromString("2")))
+ }
+
+ test("SPARK-21804: json_tuple returns null values within repeated columns except the first one") {
+ checkJsonTuple(
+ JsonTuple(Literal("""{"f1": 1, "f2": 2}""") ::
+ NonFoldableLiteral("f1") ::
+ NonFoldableLiteral("cast(NULL AS STRING)") ::
+ NonFoldableLiteral("f1") ::
+ Nil),
+ InternalRow(UTF8String.fromString("1"), null, UTF8String.fromString("1")))
+ }
+
val gmtId = Option(DateTimeUtils.TimeZoneGMT.getID)
test("from_json") {
@@ -590,4 +611,73 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
"""{"t":"2015-12-31T16:00:00"}"""
)
}
+
+ test("SPARK-21513: to_json support map[string, struct] to json") {
+ val schema = MapType(StringType, StructType(StructField("a", IntegerType) :: Nil))
+ val input = Literal.create(ArrayBasedMapData(Map("test" -> InternalRow(1))), schema)
+ checkEvaluation(
+ StructsToJson(Map.empty, input),
+ """{"test":{"a":1}}"""
+ )
+ }
+
+ test("SPARK-21513: to_json support map[struct, struct] to json") {
+ val schema = MapType(StructType(StructField("a", IntegerType) :: Nil),
+ StructType(StructField("b", IntegerType) :: Nil))
+ val input = Literal.create(ArrayBasedMapData(Map(InternalRow(1) -> InternalRow(2))), schema)
+ checkEvaluation(
+ StructsToJson(Map.empty, input),
+ """{"[1]":{"b":2}}"""
+ )
+ }
+
+ test("SPARK-21513: to_json support map[string, integer] to json") {
+ val schema = MapType(StringType, IntegerType)
+ val input = Literal.create(ArrayBasedMapData(Map("a" -> 1)), schema)
+ checkEvaluation(
+ StructsToJson(Map.empty, input),
+ """{"a":1}"""
+ )
+ }
+
+ test("to_json - array with maps") {
+ val inputSchema = ArrayType(MapType(StringType, IntegerType))
+ val input = new GenericArrayData(ArrayBasedMapData(
+ Map("a" -> 1)) :: ArrayBasedMapData(Map("b" -> 2)) :: Nil)
+ val output = """[{"a":1},{"b":2}]"""
+ checkEvaluation(
+ StructsToJson(Map.empty, Literal.create(input, inputSchema), gmtId),
+ output)
+ }
+
+ test("to_json - array with single map") {
+ val inputSchema = ArrayType(MapType(StringType, IntegerType))
+ val input = new GenericArrayData(ArrayBasedMapData(Map("a" -> 1)) :: Nil)
+ val output = """[{"a":1}]"""
+ checkEvaluation(
+ StructsToJson(Map.empty, Literal.create(input, inputSchema), gmtId),
+ output)
+ }
+
+ test("to_json: verify MapType's value type instead of key type") {
+ // Keys in map are treated as strings when converting to JSON. The type doesn't matter at all.
+ val mapType1 = MapType(CalendarIntervalType, IntegerType)
+ val schema1 = StructType(StructField("a", mapType1) :: Nil)
+ val struct1 = Literal.create(null, schema1)
+ checkEvaluation(
+ StructsToJson(Map.empty, struct1, gmtId),
+ null
+ )
+
+ // The value type must be valid for converting to JSON.
+ val mapType2 = MapType(IntegerType, CalendarIntervalType)
+ val schema2 = StructType(StructField("a", mapType2) :: Nil)
+ val struct2 = Literal.create(null, schema2)
+ intercept[TreeNodeException[_]] {
+ checkEvaluation(
+ StructsToJson(Map.empty, struct2, gmtId),
+ null
+ )
+ }
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index ef510a95ef446..1438a88c19e0b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -17,12 +17,15 @@
package org.apache.spark.sql.catalyst.expressions
+import java.sql.{Date, Timestamp}
+
import scala.collection.immutable.HashSet
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.RandomDataGenerator
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._
@@ -120,7 +123,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
(null, false, null) ::
(null, null, null) :: Nil)
- test("IN") {
+ test("basic IN predicate test") {
checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq(Literal(1),
Literal(2))), null)
checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType),
@@ -148,19 +151,32 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true)
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false)
- val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
- LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
- primitiveTypes.foreach { t =>
- val dataGen = RandomDataGenerator.forType(t, nullable = true).get
+ }
+
+ test("IN with different types") {
+ def testWithRandomDataGeneration(dataType: DataType, nullable: Boolean): Unit = {
+ val maybeDataGen = RandomDataGenerator.forType(dataType, nullable = nullable)
+ // Actually we won't pass in unsupported data types; this is just a safety check.
+ val dataGen = maybeDataGen.getOrElse(
+ fail(s"Failed to create data generator for type $dataType"))
val inputData = Seq.fill(10) {
val value = dataGen.apply()
- value match {
+ def cleanData(value: Any) = value match {
case d: Double if d.isNaN => 0.0d
case f: Float if f.isNaN => 0.0f
case _ => value
}
+ value match {
+ case s: Seq[_] => s.map(cleanData(_))
+ case m: Map[_, _] =>
+ val pair = m.unzip
+ val newKeys = pair._1.map(cleanData(_))
+ val newValues = pair._2.map(cleanData(_))
+ newKeys.zip(newValues).toMap
+ case _ => cleanData(value)
+ }
}
- val input = inputData.map(NonFoldableLiteral.create(_, t))
+ val input = inputData.map(NonFoldableLiteral.create(_, dataType))
val expected = if (inputData(0) == null) {
null
} else if (inputData.slice(1, 10).contains(inputData(0))) {
@@ -172,6 +188,55 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
}
checkEvaluation(In(input(0), input.slice(1, 10)), expected)
}
+
+ val atomicTypes = DataTypeTestUtils.atomicTypes.filter { t =>
+ RandomDataGenerator.forType(t).isDefined && !t.isInstanceOf[DecimalType]
+ } ++ Seq(DecimalType.USER_DEFAULT)
+
+ val atomicArrayTypes = atomicTypes.map(ArrayType(_, containsNull = true))
+
+ // Basic types:
+ for (
+ dataType <- atomicTypes;
+ nullable <- Seq(true, false)) {
+ testWithRandomDataGeneration(dataType, nullable)
+ }
+
+ // Array types:
+ for (
+ arrayType <- atomicArrayTypes;
+ nullable <- Seq(true, false)
+ if RandomDataGenerator.forType(arrayType.elementType, arrayType.containsNull).isDefined) {
+ testWithRandomDataGeneration(arrayType, nullable)
+ }
+
+ // Struct types:
+ for (
+ colOneType <- atomicTypes;
+ colTwoType <- atomicTypes;
+ nullable <- Seq(true, false)) {
+ val structType = StructType(
+ StructField("a", colOneType) :: StructField("b", colTwoType) :: Nil)
+ testWithRandomDataGeneration(structType, nullable)
+ }
+
+ // Map types: not supported
+ for (
+ keyType <- atomicTypes;
+ valueType <- atomicTypes;
+ nullable <- Seq(true, false)) {
+ val mapType = MapType(keyType, valueType)
+ val e = intercept[Exception] {
+ testWithRandomDataGeneration(mapType, nullable)
+ }
+ if (e.getMessage.contains("Code generation of")) {
+ // If the `value` expression is null, the interpreted `eval` short-circuits without raising
+ // an error, so the failure surfaces from the codegen evaluation path instead.
+ assert(e.getMessage.contains("cannot generate equality code for un-comparable type"))
+ } else {
+ assert(e.getMessage.contains("Exception evaluating"))
+ }
+ }
}
test("INSET") {
@@ -215,14 +280,35 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
- private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d, false).map(Literal(_))
+ private case class MyStruct(a: Long, b: String)
+ private case class MyStruct2(a: MyStruct, b: Array[Int])
+ private val udt = new ExamplePointUDT
+
+ private val smallValues =
+ Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new Date(2000, 1, 1),
+ new Timestamp(1), "a", 1f, 1d, 0f, 0d, false, Array(1L, 2L))
+ .map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")),
+ Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))),
+ Literal.create(ArrayData.toArrayData(Array(1.0, 2.0)), udt))
private val largeValues =
- Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN, true).map(Literal(_))
+ Seq(2.toByte, 2.toShort, 2, 2L, Decimal(2), Array(2.toByte), new Date(2000, 1, 2),
+ new Timestamp(2), "b", 2f, 2d, Float.NaN, Double.NaN, true, Array(2L, 1L))
+ .map(Literal(_)) ++ Seq(Literal.create(MyStruct(2L, "b")),
+ Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 2))),
+ Literal.create(ArrayData.toArrayData(Array(1.0, 3.0)), udt))
private val equalValues1 =
- Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true).map(Literal(_))
+ Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new Date(2000, 1, 1),
+ new Timestamp(1), "a", 1f, 1d, Float.NaN, Double.NaN, true, Array(1L, 2L))
+ .map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")),
+ Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))),
+ Literal.create(ArrayData.toArrayData(Array(1.0, 2.0)), udt))
private val equalValues2 =
- Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true).map(Literal(_))
+ Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new Date(2000, 1, 1),
+ new Timestamp(1), "a", 1f, 1d, Float.NaN, Double.NaN, true, Array(1L, 2L))
+ .map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")),
+ Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))),
+ Literal.create(ArrayData.toArrayData(Array(1.0, 2.0)), udt))
test("BinaryComparison consistency check") {
DataTypeTestUtils.ordered.foreach { dt =>
@@ -285,11 +371,13 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
// Use -1 (default value for codegen) which can trigger some weird bugs, e.g. SPARK-14757
val normalInt = Literal(-1)
val nullInt = NonFoldableLiteral.create(null, IntegerType)
+ val nullNullType = Literal.create(null, NullType)
def nullTest(op: (Expression, Expression) => Expression): Unit = {
checkEvaluation(op(normalInt, nullInt), null)
checkEvaluation(op(nullInt, normalInt), null)
checkEvaluation(op(nullInt, nullInt), null)
+ checkEvaluation(op(nullNullType, nullNullType), null)
}
nullTest(LessThan)
@@ -301,6 +389,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(EqualNullSafe(normalInt, nullInt), false)
checkEvaluation(EqualNullSafe(nullInt, normalInt), false)
checkEvaluation(EqualNullSafe(nullInt, nullInt), true)
+ checkEvaluation(EqualNullSafe(nullNullType, nullNullType), true)
}
test("EqualTo on complex type") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JacksonGeneratorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JacksonGeneratorSuite.scala
new file mode 100644
index 0000000000000..9b27490ed0e35
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JacksonGeneratorSuite.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.json
+
+import java.io.CharArrayWriter
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.types._
+
+class JacksonGeneratorSuite extends SparkFunSuite {
+
+ val gmtId = DateTimeUtils.TimeZoneGMT.getID
+ val option = new JSONOptions(Map.empty, gmtId)
+
+ test("initial with StructType and write out a row") {
+ val dataType = StructType(StructField("a", IntegerType) :: Nil)
+ val input = InternalRow(1)
+ val writer = new CharArrayWriter()
+ val gen = new JacksonGenerator(dataType, writer, option)
+ gen.write(input)
+ gen.flush()
+ assert(writer.toString === """{"a":1}""")
+ }
+
+ test("initial with StructType and write out rows") {
+ val dataType = StructType(StructField("a", IntegerType) :: Nil)
+ val input = new GenericArrayData(InternalRow(1) :: InternalRow(2) :: Nil)
+ val writer = new CharArrayWriter()
+ val gen = new JacksonGenerator(dataType, writer, option)
+ gen.write(input)
+ gen.flush()
+ assert(writer.toString === """[{"a":1},{"a":2}]""")
+ }
+
+ test("initial with StructType and write out an array with single empty row") {
+ val dataType = StructType(StructField("a", IntegerType) :: Nil)
+ val input = new GenericArrayData(InternalRow(null) :: Nil)
+ val writer = new CharArrayWriter()
+ val gen = new JacksonGenerator(dataType, writer, option)
+ gen.write(input)
+ gen.flush()
+ assert(writer.toString === """[{}]""")
+ }
+
+ test("initial with StructType and write out an empty array") {
+ val dataType = StructType(StructField("a", IntegerType) :: Nil)
+ val input = new GenericArrayData(Nil)
+ val writer = new CharArrayWriter()
+ val gen = new JacksonGenerator(dataType, writer, option)
+ gen.write(input)
+ gen.flush()
+ assert(writer.toString === """[]""")
+ }
+
+ test("initial with Map and write out a map data") {
+ val dataType = MapType(StringType, IntegerType)
+ val input = ArrayBasedMapData(Map("a" -> 1))
+ val writer = new CharArrayWriter()
+ val gen = new JacksonGenerator(dataType, writer, option)
+ gen.write(input)
+ gen.flush()
+ assert(writer.toString === """{"a":1}""")
+ }
+
+ test("initial with Map and write out an array of maps") {
+ val dataType = MapType(StringType, IntegerType)
+ val input = new GenericArrayData(
+ ArrayBasedMapData(Map("a" -> 1)) :: ArrayBasedMapData(Map("b" -> 2)) :: Nil)
+ val writer = new CharArrayWriter()
+ val gen = new JacksonGenerator(dataType, writer, option)
+ gen.write(input)
+ gen.flush()
+ assert(writer.toString === """[{"a":1},{"b":2}]""")
+ }
+
+ test("error handling: initial with StructType but error calling write a map") {
+ val dataType = StructType(StructField("a", IntegerType) :: Nil)
+ val input = ArrayBasedMapData(Map("a" -> 1))
+ val writer = new CharArrayWriter()
+ val gen = new JacksonGenerator(dataType, writer, option)
+ intercept[UnsupportedOperationException] {
+ gen.write(input)
+ }
+ }
+
+ test("error handling: initial with MapType and write out a row") {
+ val dataType = MapType(StringType, IntegerType)
+ val input = InternalRow(1)
+ val writer = new CharArrayWriter()
+ val gen = new JacksonGenerator(dataType, writer, option)
+ intercept[UnsupportedOperationException] {
+ gen.write(input)
+ }
+ }
+
+}
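For context, JacksonGenerator is the low-level machinery behind the StructsToJson expression and the user-facing to_json function, which these tests exercise directly. A hedged usage sketch follows; it assumes a running SparkSession named `spark` and a Spark version whose to_json accepts map input (the capability the map tests above cover at the generator level), and it is an illustration rather than part of the patch.

import org.apache.spark.sql.functions.{lit, map, struct, to_json}

val df = spark.range(1).select(
  to_json(struct(lit(1).as("a"))).as("json_struct"),  // expected: {"a":1}
  to_json(map(lit("a"), lit(1))).as("json_map")       // expected: {"a":1}
)
df.show(false)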
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
index 587437e9aa81d..e7a5bcee420f5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
@@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.Rand
+import org.apache.spark.sql.catalyst.expressions.{Alias, Rand}
import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.MetadataBuilder
class CollapseProjectSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
@@ -119,4 +120,22 @@ class CollapseProjectSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+
+ test("preserve top-level alias metadata while collapsing projects") {
+ def hasMetadata(logicalPlan: LogicalPlan): Boolean = {
+ logicalPlan.asInstanceOf[Project].projectList.exists(_.metadata.contains("key"))
+ }
+
+ val metadata = new MetadataBuilder().putLong("key", 1).build()
+ val analyzed =
+ Project(Seq(Alias('a_with_metadata, "b")()),
+ Project(Seq(Alias('a, "a_with_metadata")(explicitMetadata = Some(metadata))),
+ testRelation.logicalPlan)).analyze
+ require(hasMetadata(analyzed))
+
+ val optimized = Optimize.execute(analyzed)
+ val projects = optimized.collect { case p: Project => p }
+ assert(projects.size === 1)
+ assert(hasMetadata(optimized))
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
index d2dd469e2d74f..5580f8604ec72 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala
@@ -151,9 +151,9 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
.join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
.analyze
val correctAnswer = t1
- .where(IsNotNull('a) && IsNotNull('b) && 'a <=> 'a && 'b <=> 'b &&'a === 'b)
+ .where(IsNotNull('a) && IsNotNull('b) &&'a === 'b)
.select('a, 'b.as('d)).as("t")
- .join(t2.where(IsNotNull('a) && 'a <=> 'a), Inner,
+ .join(t2.where(IsNotNull('a)), Inner,
Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
.analyze
val optimized = Optimize.execute(originalQuery)
@@ -176,17 +176,17 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
&& "t.int_col".attr === "t2.a".attr))
.analyze
val correctAnswer = t1
- .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
- && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a))
- && Coalesce(Seq('b, 'b)) <=> 'a && 'a === 'b && IsNotNull(Coalesce(Seq('a, 'b)))
- && 'a === Coalesce(Seq('a, 'b)) && Coalesce(Seq('a, 'b)) === 'b
- && IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b)))
- && 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b)))
+ .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) && IsNotNull(Coalesce(Seq('b, 'a)))
+ && IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b))) && IsNotNull(Coalesce(Seq('a, 'b)))
+ && 'a === 'b && 'a === Coalesce(Seq('a, 'a)) && 'a === Coalesce(Seq('a, 'b))
+ && 'a === Coalesce(Seq('b, 'a)) && 'b === Coalesce(Seq('a, 'b))
+ && 'b === Coalesce(Seq('b, 'a)) && 'b === Coalesce(Seq('b, 'b)))
.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col))
.select('int_col, 'd, 'a).as("t")
- .join(t2
- .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
- && 'a <=> Coalesce(Seq('a, 'a)) && 'a === Coalesce(Seq('a, 'a)) && 'a <=> 'a), Inner,
+ .join(
+ t2.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) &&
+ 'a === Coalesce(Seq('a, 'a))),
+ Inner,
Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr
&& "t.int_col".attr === "t2.a".attr))
.analyze
@@ -194,6 +194,30 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+ test("inner join with EqualTo expressions containing part of each other: don't generate " +
+ "constraints for recursive functions") {
+ val t1 = testRelation.subquery('t1)
+ val t2 = testRelation.subquery('t2)
+
+ // We should prevent `c = Coalesce(a, b)` and `a = Coalesce(b, c)` from recursively creating
+ // complicated constraints through the constraint inference procedure.
+ val originalQuery = t1
+ .select('a, 'b, 'c, Coalesce(Seq('b, 'c)).as('d), Coalesce(Seq('a, 'b)).as('e))
+ .where('a === 'd && 'c === 'e)
+ .join(t2, Inner, Some("t1.a".attr === "t2.a".attr && "t1.c".attr === "t2.c".attr))
+ .analyze
+ val correctAnswer = t1
+ .where(IsNotNull('a) && IsNotNull('c) && 'a === Coalesce(Seq('b, 'c)) &&
+ 'c === Coalesce(Seq('a, 'b)))
+ .select('a, 'b, 'c, Coalesce(Seq('b, 'c)).as('d), Coalesce(Seq('a, 'b)).as('e))
+ .join(t2.where(IsNotNull('a) && IsNotNull('c)),
+ Inner,
+ Some("t1.a".attr === "t2.a".attr && "t1.c".attr === "t2.c".attr))
+ .analyze
+ val optimized = Optimize.execute(originalQuery)
+ comparePlans(optimized, correctAnswer)
+ }
+
test("generate correct filters for alias that don't produce recursive constraints") {
val t1 = testRelation.subquery('t1)
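The recursion that this new test guards against is easier to see outside Catalyst. The plain-Scala sketch below is an illustration under simplified assumptions (a toy expression type, not the actual constraint-inference code): substituting `a = coalesce(b, c)` and `c = coalesce(a, b)` into each other keeps growing the constraint on every round, which is the kind of blow-up the inference procedure now avoids.

object ConstraintBlowupDemo extends App {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Fn(name: String, args: Seq[Expr]) extends Expr

  // a = coalesce(b, c) and c = coalesce(a, b), mirroring the aliases in the test query above.
  val bindings = Map(
    "a" -> Fn("coalesce", Seq(Attr("b"), Attr("c"))),
    "c" -> Fn("coalesce", Seq(Attr("a"), Attr("b"))))

  def substitute(e: Expr): Expr = e match {
    case Attr(n)     => bindings.getOrElse(n, e)
    case Fn(n, args) => Fn(n, args.map(substitute))
  }

  def size(e: Expr): Int = e match {
    case Attr(_)     => 1
    case Fn(_, args) => 1 + args.map(size).sum
  }

  var constraint: Expr = Attr("a")
  for (round <- 1 to 4) {
    constraint = substitute(constraint)
    println(s"round $round: constraint size = ${size(constraint)}")  // strictly increasing
  }
}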
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala
new file mode 100644
index 0000000000000..6e183d81b7265
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.analysis.{EmptyFunctionRegistry, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project}
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.internal.SQLConf
+
+
+class OptimizerStructuralIntegrityCheckerSuite extends PlanTest {
+
+ object OptimizeRuleBreakSI extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case Project(projectList, child) =>
+ val newAttr = UnresolvedAttribute("unresolvedAttr")
+ Project(projectList ++ Seq(newAttr), child)
+ }
+ }
+
+ object Optimize extends Optimizer(
+ new SessionCatalog(
+ new InMemoryCatalog,
+ EmptyFunctionRegistry,
+ new SQLConf())) {
+ val newBatch = Batch("OptimizeRuleBreakSI", Once, OptimizeRuleBreakSI)
+ override def batches: Seq[Batch] = Seq(newBatch) ++ super.batches
+ }
+
+ test("check for invalid plan after execution of rule") {
+ val analyzed = Project(Alias(Literal(10), "attr")() :: Nil, OneRowRelation()).analyze
+ assert(analyzed.resolved)
+ val message = intercept[TreeNodeException[LogicalPlan]] {
+ Optimize.execute(analyzed)
+ }.getMessage
+ val ruleName = OptimizeRuleBreakSI.ruleName
+ assert(message.contains(s"After applying rule $ruleName in batch OptimizeRuleBreakSI"))
+ assert(message.contains("the structural integrity of the plan is broken"))
+ }
+}
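Conceptually, the check exercised here re-validates the plan after every rule application and fails fast, naming the offending rule and batch, instead of letting a broken (for example, newly unresolved) plan flow into later batches. The plain-Scala sketch below is a simplified illustration of that pattern with assumed names, not the actual RuleExecutor/Optimizer code.

final case class CheckedExecutor[P](isIntegral: P => Boolean) {
  def applyRule(ruleName: String, batchName: String)(rule: P => P)(plan: P): P = {
    val result = rule(plan)
    if (!isIntegral(result)) {
      // Mirrors the message asserted in the test above.
      throw new IllegalStateException(
        s"After applying rule $ruleName in batch $batchName, " +
          "the structural integrity of the plan is broken.")
    }
    result
  }
}

In the suite above, the rule that appends an UnresolvedAttribute to the project list is exactly the kind of integrity violation such a check catches.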
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
index 2285be16938d6..bc1c48b99c295 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala
@@ -18,11 +18,13 @@
package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.StructType
class PropagateEmptyRelationSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
@@ -124,6 +126,48 @@ class PropagateEmptyRelationSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+ test("propagate empty streaming relation through multiple UnaryNode") {
+ val output = Seq('a.int)
+ val data = Seq(Row(1))
+ val schema = StructType.fromAttributes(output)
+ val converter = CatalystTypeConverters.createToCatalystConverter(schema)
+ val relation = LocalRelation(
+ output,
+ data.map(converter(_).asInstanceOf[InternalRow]),
+ isStreaming = true)
+
+ val query = relation
+ .where(false)
+ .select('a)
+ .where('a > 1)
+ .where('a != 200)
+ .orderBy('a.asc)
+
+ val optimized = Optimize.execute(query.analyze)
+ val correctAnswer = LocalRelation(output, isStreaming = true)
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("don't propagate empty streaming relation through agg") {
+ val output = Seq('a.int)
+ val data = Seq(Row(1))
+ val schema = StructType.fromAttributes(output)
+ val converter = CatalystTypeConverters.createToCatalystConverter(schema)
+ val relation = LocalRelation(
+ output,
+ data.map(converter(_).asInstanceOf[InternalRow]),
+ isStreaming = true)
+
+ val query = relation
+ .groupBy('a)('a)
+
+ val optimized = Optimize.execute(query.analyze)
+ val correctAnswer = query.analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
test("don't propagate non-empty local relation") {
val query = testRelation1
.where(true)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala
new file mode 100644
index 0000000000000..169b8737d808b
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{In, ListQuery}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+class PullupCorrelatedPredicatesSuite extends PlanTest {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("PullupCorrelatedPredicates", Once,
+ PullupCorrelatedPredicates) :: Nil
+ }
+
+ val testRelation = LocalRelation('a.int, 'b.double)
+ val testRelation2 = LocalRelation('c.int, 'd.double)
+
+ test("PullupCorrelatedPredicates should not produce unresolved plan") {
+ val correlatedSubquery =
+ testRelation2
+ .where('b < 'd)
+ .select('c)
+ val outerQuery =
+ testRelation
+ .where(In('a, Seq(ListQuery(correlatedSubquery))))
+ .select('a).analyze
+ assert(outerQuery.resolved)
+
+ val optimized = Optimize.execute(outerQuery)
+ assert(optimized.resolved)
+ }
+}
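For readers who think in SQL rather than in the Catalyst DSL, the plan built in this test corresponds roughly to a correlated IN subquery. The snippet below is an illustrative equivalent only (it assumes a SparkSession `spark` with temp views t1(a, b) and t2(c, d) registered), not part of the patch.

// PullupCorrelatedPredicates hoists the correlated predicate (t1.b < t2.d) out of the subquery
// so that later subquery rewriting can turn the IN subquery into a join.
spark.sql("SELECT a FROM t1 WHERE a IN (SELECT c FROM t2 WHERE t1.b < t2.d)")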
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
index e68423f85c92e..85988d2fb948c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
@@ -79,7 +79,7 @@ class ReplaceOperatorSuite extends PlanTest {
val input = LocalRelation('a.int, 'b.int)
val attrA = input.output(0)
val attrB = input.output(1)
- val query = Deduplicate(Seq(attrA), input, streaming = false) // dropDuplicates("a")
+ val query = Deduplicate(Seq(attrA), input) // dropDuplicates("a")
val optimized = Optimize.execute(query.analyze)
val correctAnswer =
@@ -95,9 +95,9 @@ class ReplaceOperatorSuite extends PlanTest {
}
test("don't replace streaming Deduplicate") {
- val input = LocalRelation('a.int, 'b.int)
+ val input = LocalRelation(Seq('a.int, 'b.int), isStreaming = true)
val attrA = input.output(0)
- val query = Deduplicate(Seq(attrA), input, streaming = true) // dropDuplicates("a")
+ val query = Deduplicate(Seq(attrA), input) // dropDuplicates("a")
val optimized = Optimize.execute(query.analyze)
comparePlans(optimized, query)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala
index 48aaec44885d4..6803fc307f919 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala
@@ -79,10 +79,12 @@ class TableSchemaParserSuite extends SparkFunSuite {
}
// Negative cases
- assertError("")
- assertError("a")
- assertError("a INT b long")
- assertError("a INT,, b long")
- assertError("a INT, b long,,")
- assertError("a INT, b long, c int,")
+ test("Negative cases") {
+ assertError("")
+ assertError("a")
+ assertError("a INT b long")
+ assertError("a INT,, b long")
+ assertError("a INT, b long,,")
+ assertError("a INT, b long, c int,")
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index a37e06d922642..866ff0d33cbb2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -134,8 +134,6 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest {
verifyConstraints(aliasedRelation.analyze.constraints,
ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "x") > 10,
IsNotNull(resolveColumn(aliasedRelation.analyze, "x")),
- resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"),
- resolveColumn(aliasedRelation.analyze, "z") <=> resolveColumn(aliasedRelation.analyze, "x"),
resolveColumn(aliasedRelation.analyze, "z") > 10,
IsNotNull(resolveColumn(aliasedRelation.analyze, "z")))))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
index cc86f1f6e2f48..cdf912df7c76a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
@@ -73,10 +73,8 @@ class LogicalPlanSuite extends SparkFunSuite {
test("isStreaming") {
val relation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
- val incrementalRelation = new LocalRelation(
- Seq(AttributeReference("a", IntegerType, nullable = true)())) {
- override def isStreaming(): Boolean = true
- }
+ val incrementalRelation = LocalRelation(
+ Seq(AttributeReference("a", IntegerType, nullable = true)()), isStreaming = true)
case class TestBinaryRelation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
override def output: Seq[Attribute] = left.output ++ right.output
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala
index c9d36910b0998..a67f54b263cc9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala
@@ -56,4 +56,21 @@ class RuleExecutorSuite extends SparkFunSuite {
}.getMessage
assert(message.contains("Max iterations (10) reached for batch fixedPoint"))
}
+
+ test("structural integrity checker") {
+ object WithSIChecker extends RuleExecutor[Expression] {
+ override protected def isPlanIntegral(expr: Expression): Boolean = expr match {
+ case IntegerLiteral(_) => true
+ case _ => false
+ }
+ val batches = Batch("once", Once, DecrementLiterals) :: Nil
+ }
+
+ assert(WithSIChecker.execute(Literal(10)) === Literal(9))
+
+ val message = intercept[TreeNodeException[LogicalPlan]] {
+ WithSIChecker.execute(Literal(10.1))
+ }.getMessage
+ assert(message.contains("the structural integrity of the plan is broken"))
+ }
}
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 9a3cacbe3825e..7ee002e465756 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -38,7 +38,7 @@
com.univocity
univocity-parsers
- 2.2.1
+ 2.5.4
jar
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
index e1e7742c93a9c..fb1fcd77b011f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
@@ -38,6 +38,7 @@
import org.apache.parquet.schema.Type;
import org.apache.spark.sql.catalyst.util.DateTimeUtils;
import org.apache.spark.sql.execution.vectorized.ColumnVector;
+import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.DecimalType;
@@ -141,9 +142,9 @@ private boolean next() throws IOException {
/**
* Reads `total` values from this columnReader into column.
*/
- void readBatch(int total, ColumnVector column) throws IOException {
+ void readBatch(int total, WritableColumnVector column) throws IOException {
int rowId = 0;
- ColumnVector dictionaryIds = null;
+ WritableColumnVector dictionaryIds = null;
if (dictionary != null) {
// SPARK-16334: We only maintain a single dictionary per row batch, so that it can be used to
// decode all previous dictionary encoded pages if we ever encounter a non-dictionary encoded
@@ -225,8 +226,11 @@ void readBatch(int total, ColumnVector column) throws IOException {
/**
* Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`.
*/
- private void decodeDictionaryIds(int rowId, int num, ColumnVector column,
- ColumnVector dictionaryIds) {
+ private void decodeDictionaryIds(
+ int rowId,
+ int num,
+ WritableColumnVector column,
+ ColumnVector dictionaryIds) {
switch (descriptor.getType()) {
case INT32:
if (column.dataType() == DataTypes.IntegerType ||
@@ -355,13 +359,13 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column,
* is guaranteed that num is smaller than the number of values left in the current page.
*/
- private void readBooleanBatch(int rowId, int num, ColumnVector column) throws IOException {
+ private void readBooleanBatch(int rowId, int num, WritableColumnVector column) {
assert(column.dataType() == DataTypes.BooleanType);
defColumn.readBooleans(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
}
- private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException {
+ private void readIntBatch(int rowId, int num, WritableColumnVector column) {
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType ||
@@ -379,7 +383,7 @@ private void readIntBatch(int rowId, int num, ColumnVector column) throws IOExce
}
}
- private void readLongBatch(int rowId, int num, ColumnVector column) throws IOException {
+ private void readLongBatch(int rowId, int num, WritableColumnVector column) {
// This is where we implement support for the valid type conversions.
if (column.dataType() == DataTypes.LongType ||
(column.dataType() == DataTypes.TimestampType &&
@@ -401,7 +405,7 @@ private void readLongBatch(int rowId, int num, ColumnVector column) throws IOExc
}
}
- private void readFloatBatch(int rowId, int num, ColumnVector column) throws IOException {
+ private void readFloatBatch(int rowId, int num, WritableColumnVector column) {
// This is where we implement support for the valid type conversions.
// TODO: support implicit cast to double?
if (column.dataType() == DataTypes.FloatType) {
@@ -412,7 +416,7 @@ private void readFloatBatch(int rowId, int num, ColumnVector column) throws IOEx
}
}
- private void readDoubleBatch(int rowId, int num, ColumnVector column) throws IOException {
+ private void readDoubleBatch(int rowId, int num, WritableColumnVector column) {
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
if (column.dataType() == DataTypes.DoubleType) {
@@ -423,7 +427,7 @@ private void readDoubleBatch(int rowId, int num, ColumnVector column) throws IOE
}
}
- private void readBinaryBatch(int rowId, int num, ColumnVector column) throws IOException {
+ private void readBinaryBatch(int rowId, int num, WritableColumnVector column) {
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
VectorizedValuesReader data = (VectorizedValuesReader) dataColumn;
@@ -444,8 +448,11 @@ private void readBinaryBatch(int rowId, int num, ColumnVector column) throws IOE
}
}
- private void readFixedLenByteArrayBatch(int rowId, int num,
- ColumnVector column, int arrayLen) throws IOException {
+ private void readFixedLenByteArrayBatch(
+ int rowId,
+ int num,
+ WritableColumnVector column,
+ int arrayLen) {
VectorizedValuesReader data = (VectorizedValuesReader) dataColumn;
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
@@ -480,7 +487,7 @@ private void readFixedLenByteArrayBatch(int rowId, int num,
}
}
- private void readPage() throws IOException {
+ private void readPage() {
DataPage page = pageReader.readPage();
// TODO: Why is this a visitor?
page.accept(new DataPage.Visitor() {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
index e8a19e5d0adac..248c92f1eb69e 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
@@ -31,6 +31,9 @@
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils;
import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
+import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
+import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector;
+import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
@@ -90,6 +93,8 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa
*/
private ColumnarBatch columnarBatch;
+ private WritableColumnVector[] columnVectors;
+
/**
* If true, this class returns batches instead of rows.
*/
@@ -172,20 +177,26 @@ public void initBatch(MemoryMode memMode, StructType partitionColumns,
}
}
- columnarBatch = ColumnarBatch.allocate(batchSchema, memMode);
+ int capacity = ColumnarBatch.DEFAULT_BATCH_SIZE;
+ if (memMode == MemoryMode.OFF_HEAP) {
+ columnVectors = OffHeapColumnVector.allocateColumns(capacity, batchSchema);
+ } else {
+ columnVectors = OnHeapColumnVector.allocateColumns(capacity, batchSchema);
+ }
+ columnarBatch = new ColumnarBatch(batchSchema, columnVectors, capacity);
if (partitionColumns != null) {
int partitionIdx = sparkSchema.fields().length;
for (int i = 0; i < partitionColumns.fields().length; i++) {
- ColumnVectorUtils.populate(columnarBatch.column(i + partitionIdx), partitionValues, i);
- columnarBatch.column(i + partitionIdx).setIsConstant();
+ ColumnVectorUtils.populate(columnVectors[i + partitionIdx], partitionValues, i);
+ columnVectors[i + partitionIdx].setIsConstant();
}
}
// Initialize missing columns with nulls.
for (int i = 0; i < missingColumns.length; i++) {
if (missingColumns[i]) {
- columnarBatch.column(i).putNulls(0, columnarBatch.capacity());
- columnarBatch.column(i).setIsConstant();
+ columnVectors[i].putNulls(0, columnarBatch.capacity());
+ columnVectors[i].setIsConstant();
}
}
}
@@ -226,7 +237,7 @@ public boolean nextBatch() throws IOException {
int num = (int) Math.min((long) columnarBatch.capacity(), totalCountLoadedSoFar - rowsReturned);
for (int i = 0; i < columnReaders.length; ++i) {
if (columnReaders[i] == null) continue;
- columnReaders[i].readBatch(num, columnarBatch.column(i));
+ columnReaders[i].readBatch(num, columnVectors[i]);
}
rowsReturned += num;
columnarBatch.setNumRows(num);
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
index 949d6533c64b7..c23929894e1e9 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
@@ -20,7 +20,7 @@
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
-import org.apache.spark.sql.execution.vectorized.ColumnVector;
+import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.unsafe.Platform;
import org.apache.parquet.column.values.ValuesReader;
@@ -61,7 +61,7 @@ public void skip() {
}
@Override
- public final void readBooleans(int total, ColumnVector c, int rowId) {
+ public final void readBooleans(int total, WritableColumnVector c, int rowId) {
// TODO: properly vectorize this
for (int i = 0; i < total; i++) {
c.putBoolean(rowId + i, readBoolean());
@@ -69,31 +69,31 @@ public final void readBooleans(int total, ColumnVector c, int rowId) {
}
@Override
- public final void readIntegers(int total, ColumnVector c, int rowId) {
+ public final void readIntegers(int total, WritableColumnVector c, int rowId) {
c.putIntsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
offset += 4 * total;
}
@Override
- public final void readLongs(int total, ColumnVector c, int rowId) {
+ public final void readLongs(int total, WritableColumnVector c, int rowId) {
c.putLongsLittleEndian(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
offset += 8 * total;
}
@Override
- public final void readFloats(int total, ColumnVector c, int rowId) {
+ public final void readFloats(int total, WritableColumnVector c, int rowId) {
c.putFloats(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
offset += 4 * total;
}
@Override
- public final void readDoubles(int total, ColumnVector c, int rowId) {
+ public final void readDoubles(int total, WritableColumnVector c, int rowId) {
c.putDoubles(rowId, total, buffer, offset - Platform.BYTE_ARRAY_OFFSET);
offset += 8 * total;
}
@Override
- public final void readBytes(int total, ColumnVector c, int rowId) {
+ public final void readBytes(int total, WritableColumnVector c, int rowId) {
for (int i = 0; i < total; i++) {
// Bytes are stored as a 4-byte little endian int. Just read the first byte.
// TODO: consider pushing this in ColumnVector by adding a readBytes with a stride.
@@ -164,7 +164,7 @@ public final double readDouble() {
}
@Override
- public final void readBinary(int total, ColumnVector v, int rowId) {
+ public final void readBinary(int total, WritableColumnVector v, int rowId) {
for (int i = 0; i < total; i++) {
int len = readInteger();
int start = offset;
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
index 31511f218edcd..cbb9a83601b73 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
@@ -25,7 +25,7 @@
import org.apache.parquet.io.ParquetDecodingException;
import org.apache.parquet.io.api.Binary;
-import org.apache.spark.sql.execution.vectorized.ColumnVector;
+import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import java.nio.ByteBuffer;
@@ -179,7 +179,11 @@ public int readInteger() {
* c[rowId] = null;
* }
*/
- public void readIntegers(int total, ColumnVector c, int rowId, int level,
+ public void readIntegers(
+ int total,
+ WritableColumnVector c,
+ int rowId,
+ int level,
VectorizedValuesReader data) {
int left = total;
while (left > 0) {
@@ -210,8 +214,12 @@ public void readIntegers(int total, ColumnVector c, int rowId, int level,
}
// TODO: can this code duplication be removed without a perf penalty?
- public void readBooleans(int total, ColumnVector c,
- int rowId, int level, VectorizedValuesReader data) {
+ public void readBooleans(
+ int total,
+ WritableColumnVector c,
+ int rowId,
+ int level,
+ VectorizedValuesReader data) {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -240,8 +248,12 @@ public void readBooleans(int total, ColumnVector c,
}
}
- public void readBytes(int total, ColumnVector c,
- int rowId, int level, VectorizedValuesReader data) {
+ public void readBytes(
+ int total,
+ WritableColumnVector c,
+ int rowId,
+ int level,
+ VectorizedValuesReader data) {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -270,8 +282,12 @@ public void readBytes(int total, ColumnVector c,
}
}
- public void readShorts(int total, ColumnVector c,
- int rowId, int level, VectorizedValuesReader data) {
+ public void readShorts(
+ int total,
+ WritableColumnVector c,
+ int rowId,
+ int level,
+ VectorizedValuesReader data) {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -302,8 +318,12 @@ public void readShorts(int total, ColumnVector c,
}
}
- public void readLongs(int total, ColumnVector c, int rowId, int level,
- VectorizedValuesReader data) {
+ public void readLongs(
+ int total,
+ WritableColumnVector c,
+ int rowId,
+ int level,
+ VectorizedValuesReader data) {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -332,8 +352,12 @@ public void readLongs(int total, ColumnVector c, int rowId, int level,
}
}
- public void readFloats(int total, ColumnVector c, int rowId, int level,
- VectorizedValuesReader data) {
+ public void readFloats(
+ int total,
+ WritableColumnVector c,
+ int rowId,
+ int level,
+ VectorizedValuesReader data) {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -362,8 +386,12 @@ public void readFloats(int total, ColumnVector c, int rowId, int level,
}
}
- public void readDoubles(int total, ColumnVector c, int rowId, int level,
- VectorizedValuesReader data) {
+ public void readDoubles(
+ int total,
+ WritableColumnVector c,
+ int rowId,
+ int level,
+ VectorizedValuesReader data) {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -392,8 +420,12 @@ public void readDoubles(int total, ColumnVector c, int rowId, int level,
}
}
- public void readBinarys(int total, ColumnVector c, int rowId, int level,
- VectorizedValuesReader data) {
+ public void readBinarys(
+ int total,
+ WritableColumnVector c,
+ int rowId,
+ int level,
+ VectorizedValuesReader data) {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -426,8 +458,13 @@ public void readBinarys(int total, ColumnVector c, int rowId, int level,
* Decoding for dictionary ids. The IDs are populated into `values` and the nullability is
* populated into `nulls`.
*/
- public void readIntegers(int total, ColumnVector values, ColumnVector nulls, int rowId, int level,
- VectorizedValuesReader data) {
+ public void readIntegers(
+ int total,
+ WritableColumnVector values,
+ WritableColumnVector nulls,
+ int rowId,
+ int level,
+ VectorizedValuesReader data) {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -461,7 +498,7 @@ public void readIntegers(int total, ColumnVector values, ColumnVector nulls, int
// IDs. This is different than the above APIs that decodes definitions levels along with values.
// Since this is only used to decode dictionary IDs, only decoding integers is supported.
@Override
- public void readIntegers(int total, ColumnVector c, int rowId) {
+ public void readIntegers(int total, WritableColumnVector c, int rowId) {
int left = total;
while (left > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -487,32 +524,32 @@ public byte readByte() {
}
@Override
- public void readBytes(int total, ColumnVector c, int rowId) {
+ public void readBytes(int total, WritableColumnVector c, int rowId) {
throw new UnsupportedOperationException("only readInts is valid.");
}
@Override
- public void readLongs(int total, ColumnVector c, int rowId) {
+ public void readLongs(int total, WritableColumnVector c, int rowId) {
throw new UnsupportedOperationException("only readInts is valid.");
}
@Override
- public void readBinary(int total, ColumnVector c, int rowId) {
+ public void readBinary(int total, WritableColumnVector c, int rowId) {
throw new UnsupportedOperationException("only readInts is valid.");
}
@Override
- public void readBooleans(int total, ColumnVector c, int rowId) {
+ public void readBooleans(int total, WritableColumnVector c, int rowId) {
throw new UnsupportedOperationException("only readInts is valid.");
}
@Override
- public void readFloats(int total, ColumnVector c, int rowId) {
+ public void readFloats(int total, WritableColumnVector c, int rowId) {
throw new UnsupportedOperationException("only readInts is valid.");
}
@Override
- public void readDoubles(int total, ColumnVector c, int rowId) {
+ public void readDoubles(int total, WritableColumnVector c, int rowId) {
throw new UnsupportedOperationException("only readInts is valid.");
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java
index 88418ca53fe1e..57d92ae27ece8 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.datasources.parquet;
-import org.apache.spark.sql.execution.vectorized.ColumnVector;
+import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.parquet.io.api.Binary;
@@ -37,11 +37,11 @@ public interface VectorizedValuesReader {
/*
* Reads `total` values into `c`, starting at `c[rowId]`.
*/
- void readBooleans(int total, ColumnVector c, int rowId);
- void readBytes(int total, ColumnVector c, int rowId);
- void readIntegers(int total, ColumnVector c, int rowId);
- void readLongs(int total, ColumnVector c, int rowId);
- void readFloats(int total, ColumnVector c, int rowId);
- void readDoubles(int total, ColumnVector c, int rowId);
- void readBinary(int total, ColumnVector c, int rowId);
+ void readBooleans(int total, WritableColumnVector c, int rowId);
+ void readBytes(int total, WritableColumnVector c, int rowId);
+ void readIntegers(int total, WritableColumnVector c, int rowId);
+ void readLongs(int total, WritableColumnVector c, int rowId);
+ void readFloats(int total, WritableColumnVector c, int rowId);
+ void readDoubles(int total, WritableColumnVector c, int rowId);
+ void readBinary(int total, WritableColumnVector c, int rowId);
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java
index 25a565d32638d..cb3ad4eab1f60 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java
@@ -21,7 +21,6 @@
import com.google.common.annotations.VisibleForTesting;
-import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.types.StructType;
import static org.apache.spark.sql.types.DataTypes.LongType;
@@ -41,6 +40,7 @@
*/
public class AggregateHashMap {
+ private OnHeapColumnVector[] columnVectors;
private ColumnarBatch batch;
private int[] buckets;
private int numBuckets;
@@ -62,7 +62,8 @@ public AggregateHashMap(StructType schema, int capacity, double loadFactor, int
this.maxSteps = maxSteps;
numBuckets = (int) (capacity / loadFactor);
- batch = ColumnarBatch.allocate(schema, MemoryMode.ON_HEAP, capacity);
+ columnVectors = OnHeapColumnVector.allocateColumns(capacity, schema);
+ batch = new ColumnarBatch(schema, columnVectors, capacity);
buckets = new int[numBuckets];
Arrays.fill(buckets, -1);
}
@@ -74,8 +75,8 @@ public AggregateHashMap(StructType schema) {
public ColumnarBatch.Row findOrInsert(long key) {
int idx = find(key);
if (idx != -1 && buckets[idx] == -1) {
- batch.column(0).putLong(numRows, key);
- batch.column(1).putLong(numRows, 0);
+ columnVectors[0].putLong(numRows, key);
+ columnVectors[1].putLong(numRows, 0);
buckets[idx] = numRows++;
}
return batch.getRow(buckets[idx]);
@@ -105,6 +106,6 @@ private long hash(long key) {
}
private boolean equals(int idx, long key1) {
- return batch.column(0).getLong(buckets[idx]) == key1;
+ return columnVectors[0].getLong(buckets[idx]) == key1;
}
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java
index 59d66c599c518..1f171049820b2 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java
@@ -21,7 +21,6 @@
import org.apache.arrow.vector.complex.*;
import org.apache.arrow.vector.holders.NullableVarCharHolder;
-import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.execution.arrow.ArrowUtils;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.UTF8String;
@@ -29,12 +28,13 @@
/**
* A column vector backed by Apache Arrow.
*/
-public final class ArrowColumnVector extends ReadOnlyColumnVector {
+public final class ArrowColumnVector extends ColumnVector {
private final ArrowVectorAccessor accessor;
- private final int valueCount;
+ private ArrowColumnVector[] childColumns;
private void ensureAccessible(int index) {
+ int valueCount = accessor.getValueCount();
if (index < 0 || index >= valueCount) {
throw new IndexOutOfBoundsException(
String.format("index: %d, valueCount: %d", index, valueCount));
@@ -42,12 +42,23 @@ private void ensureAccessible(int index) {
}
private void ensureAccessible(int index, int count) {
+ int valueCount = accessor.getValueCount();
if (index < 0 || index + count > valueCount) {
throw new IndexOutOfBoundsException(
String.format("index range: [%d, %d), valueCount: %d", index, index + count, valueCount));
}
}
+ @Override
+ public int numNulls() {
+ return accessor.getNullCount();
+ }
+
+ @Override
+ public boolean anyNullsSet() {
+ return numNulls() > 0;
+ }
+
@Override
public long nullsNativeAddress() {
throw new RuntimeException("Cannot get native address for arrow column");
@@ -274,9 +285,20 @@ public byte[] getBinary(int rowId) {
return accessor.getBinary(rowId);
}
+ /**
+ * Returns the data for the underlying array.
+ */
+ @Override
+ public ArrowColumnVector arrayData() { return childColumns[0]; }
+
+ /**
+ * Returns the ordinal's child data column.
+ */
+ @Override
+ public ArrowColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; }
+
public ArrowColumnVector(ValueVector vector) {
- super(vector.getValueCapacity(), ArrowUtils.fromArrowField(vector.getField()),
- MemoryMode.OFF_HEAP);
+ super(ArrowUtils.fromArrowField(vector.getField()));
if (vector instanceof NullableBitVector) {
accessor = new BooleanAccessor((NullableBitVector) vector);
@@ -302,7 +324,7 @@ public ArrowColumnVector(ValueVector vector) {
ListVector listVector = (ListVector) vector;
accessor = new ArrayAccessor(listVector);
- childColumns = new ColumnVector[1];
+ childColumns = new ArrowColumnVector[1];
childColumns[0] = new ArrowColumnVector(listVector.getDataVector());
resultArray = new ColumnVector.Array(childColumns[0]);
} else if (vector instanceof MapVector) {
@@ -317,9 +339,6 @@ public ArrowColumnVector(ValueVector vector) {
} else {
throw new UnsupportedOperationException();
}
- valueCount = accessor.getValueCount();
- numNulls = accessor.getNullCount();
- anyNullsSet = numNulls > 0;
}
private abstract static class ArrowVectorAccessor {
@@ -327,14 +346,9 @@ private abstract static class ArrowVectorAccessor {
private final ValueVector vector;
private final ValueVector.Accessor nulls;
- private final int valueCount;
- private final int nullCount;
-
ArrowVectorAccessor(ValueVector vector) {
this.vector = vector;
this.nulls = vector.getAccessor();
- this.valueCount = nulls.getValueCount();
- this.nullCount = nulls.getNullCount();
}
final boolean isNullAt(int rowId) {
@@ -342,11 +356,11 @@ final boolean isNullAt(int rowId) {
}
final int getValueCount() {
- return valueCount;
+ return nulls.getValueCount();
}
final int getNullCount() {
- return nullCount;
+ return nulls.getNullCount();
}
final void close() {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
index 77966382881b8..a69dd9718fe33 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
@@ -16,23 +16,16 @@
*/
package org.apache.spark.sql.execution.vectorized;
-import java.math.BigDecimal;
-import java.math.BigInteger;
-
-import com.google.common.annotations.VisibleForTesting;
-
-import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.catalyst.util.MapData;
-import org.apache.spark.sql.internal.SQLConf;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
/**
* This class represents a column of values and provides the main APIs to access the data
- * values. It supports all the types and contains get/put APIs as well as their batched versions.
+ * values. It supports all the types and contains get APIs as well as their batched versions.
* The batched versions are preferable whenever possible.
*
* To handle nested schemas, ColumnVector has two types: Arrays and Structs. In both cases these
@@ -40,34 +33,15 @@
* contains nullability, and in the case of Arrays, the lengths and offsets into the child column.
* Lengths and offsets are encoded identically to INTs.
* Maps are just a special case of a two field struct.
- * Strings are handled as an Array of ByteType.
- *
- * Capacity: The data stored is dense but the arrays are not fixed capacity. It is the
- * responsibility of the caller to call reserve() to ensure there is enough room before adding
- * elements. This means that the put() APIs do not check as in common cases (i.e. flat schemas),
- * the lengths are known up front.
*
* Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values
* in the current RowBatch.
*
- * A ColumnVector should be considered immutable once originally created. In other words, it is not
- * valid to call put APIs after reads until reset() is called.
+ * A ColumnVector should be considered immutable once originally created.
*
* ColumnVectors are intended to be reused.
*/
public abstract class ColumnVector implements AutoCloseable {
- /**
- * Allocates a column to store elements of `type` on or off heap.
- * Capacity is the initial capacity of the vector and it will grow as necessary. Capacity is
- * in number of elements, not number of bytes.
- */
- public static ColumnVector allocate(int capacity, DataType type, MemoryMode mode) {
- if (mode == MemoryMode.OFF_HEAP) {
- return new OffHeapColumnVector(capacity, type);
- } else {
- return new OnHeapColumnVector(capacity, type);
- }
- }
/**
* Holder object to return an array. This object is intended to be reused. Callers should
@@ -278,75 +252,22 @@ public Object get(int ordinal, DataType dataType) {
*/
public final DataType dataType() { return type; }
- /**
- * Resets this column for writing. The currently stored values are no longer accessible.
- */
- public void reset() {
- if (isConstant) return;
-
- if (childColumns != null) {
- for (ColumnVector c: childColumns) {
- c.reset();
- }
- }
- numNulls = 0;
- elementsAppended = 0;
- if (anyNullsSet) {
- putNotNulls(0, capacity);
- anyNullsSet = false;
- }
- }
-
/**
* Cleans up memory for this column. The column is not usable after this.
* TODO: this should probably have ref-counted semantics.
*/
public abstract void close();
- public void reserve(int requiredCapacity) {
- if (requiredCapacity > capacity) {
- int newCapacity = (int) Math.min(MAX_CAPACITY, requiredCapacity * 2L);
- if (requiredCapacity <= newCapacity) {
- try {
- reserveInternal(newCapacity);
- } catch (OutOfMemoryError outOfMemoryError) {
- throwUnsupportedException(requiredCapacity, outOfMemoryError);
- }
- } else {
- throwUnsupportedException(requiredCapacity, null);
- }
- }
- }
-
- private void throwUnsupportedException(int requiredCapacity, Throwable cause) {
- String message = "Cannot reserve additional contiguous bytes in the vectorized reader " +
- "(requested = " + requiredCapacity + " bytes). As a workaround, you can disable the " +
- "vectorized reader by setting " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() +
- " to false.";
-
- if (cause != null) {
- throw new RuntimeException(message, cause);
- } else {
- throw new RuntimeException(message);
- }
- }
-
- /**
- * Ensures that there is enough storage to store capacity elements. That is, the put() APIs
- * must work for all rowIds < capacity.
- */
- protected abstract void reserveInternal(int capacity);
-
/**
* Returns the number of nulls in this column.
*/
- public final int numNulls() { return numNulls; }
+ public abstract int numNulls();
/**
* Returns true if any of the nulls indicator are set for this column. This can be used
* as an optimization to prevent setting nulls.
*/
- public final boolean anyNullsSet() { return anyNullsSet; }
+ public abstract boolean anyNullsSet();
/**
* Returns the off heap ptr for the arrays backing the NULLs and values buffer. Only valid
@@ -355,33 +276,11 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) {
public abstract long nullsNativeAddress();
public abstract long valuesNativeAddress();
- /**
- * Sets the value at rowId to null/not null.
- */
- public abstract void putNotNull(int rowId);
- public abstract void putNull(int rowId);
-
- /**
- * Sets the values from [rowId, rowId + count) to null/not null.
- */
- public abstract void putNulls(int rowId, int count);
- public abstract void putNotNulls(int rowId, int count);
-
/**
* Returns whether the value at rowId is NULL.
*/
public abstract boolean isNullAt(int rowId);
- /**
- * Sets the value at rowId to `value`.
- */
- public abstract void putBoolean(int rowId, boolean value);
-
- /**
- * Sets values from [rowId, rowId + count) to value.
- */
- public abstract void putBooleans(int rowId, int count, boolean value);
-
/**
* Returns the value for rowId.
*/
@@ -392,21 +291,6 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) {
*/
public abstract boolean[] getBooleans(int rowId, int count);
- /**
- * Sets the value at rowId to `value`.
- */
- public abstract void putByte(int rowId, byte value);
-
- /**
- * Sets values from [rowId, rowId + count) to value.
- */
- public abstract void putBytes(int rowId, int count, byte value);
-
- /**
- * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count)
- */
- public abstract void putBytes(int rowId, int count, byte[] src, int srcIndex);
-
/**
* Returns the value for rowId.
*/
@@ -417,21 +301,6 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) {
*/
public abstract byte[] getBytes(int rowId, int count);
- /**
- * Sets the value at rowId to `value`.
- */
- public abstract void putShort(int rowId, short value);
-
- /**
- * Sets values from [rowId, rowId + count) to value.
- */
- public abstract void putShorts(int rowId, int count, short value);
-
- /**
- * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count)
- */
- public abstract void putShorts(int rowId, int count, short[] src, int srcIndex);
-
/**
* Returns the value for rowId.
*/
@@ -442,27 +311,6 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) {
*/
public abstract short[] getShorts(int rowId, int count);
- /**
- * Sets the value at rowId to `value`.
- */
- public abstract void putInt(int rowId, int value);
-
- /**
- * Sets values from [rowId, rowId + count) to value.
- */
- public abstract void putInts(int rowId, int count, int value);
-
- /**
- * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count)
- */
- public abstract void putInts(int rowId, int count, int[] src, int srcIndex);
-
- /**
- * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count])
- * The data in src must be 4-byte little endian ints.
- */
- public abstract void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex);
-
/**
* Returns the value for rowId.
*/
@@ -480,27 +328,6 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) {
*/
public abstract int getDictId(int rowId);
- /**
- * Sets the value at rowId to `value`.
- */
- public abstract void putLong(int rowId, long value);
-
- /**
- * Sets values from [rowId, rowId + count) to value.
- */
- public abstract void putLongs(int rowId, int count, long value);
-
- /**
- * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count)
- */
- public abstract void putLongs(int rowId, int count, long[] src, int srcIndex);
-
- /**
- * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count])
- * The data in src must be 8-byte little endian longs.
- */
- public abstract void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex);
-
/**
* Returns the value for rowId.
*/
@@ -511,27 +338,6 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) {
*/
public abstract long[] getLongs(int rowId, int count);
- /**
- * Sets the value at rowId to `value`.
- */
- public abstract void putFloat(int rowId, float value);
-
- /**
- * Sets values from [rowId, rowId + count) to value.
- */
- public abstract void putFloats(int rowId, int count, float value);
-
- /**
- * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count)
- */
- public abstract void putFloats(int rowId, int count, float[] src, int srcIndex);
-
- /**
- * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count])
- * The data in src must be ieee formatted floats.
- */
- public abstract void putFloats(int rowId, int count, byte[] src, int srcIndex);
-
/**
* Returns the value for rowId.
*/
@@ -542,27 +348,6 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) {
*/
public abstract float[] getFloats(int rowId, int count);
- /**
- * Sets the value at rowId to `value`.
- */
- public abstract void putDouble(int rowId, double value);
-
- /**
- * Sets values from [rowId, rowId + count) to value.
- */
- public abstract void putDoubles(int rowId, int count, double value);
-
- /**
- * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count)
- */
- public abstract void putDoubles(int rowId, int count, double[] src, int srcIndex);
-
- /**
- * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count])
- * The data in src must be ieee formatted doubles.
- */
- public abstract void putDoubles(int rowId, int count, byte[] src, int srcIndex);
-
/**
* Returns the value for rowId.
*/
@@ -573,11 +358,6 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) {
*/
public abstract double[] getDoubles(int rowId, int count);
- /**
- * Puts a byte array that already exists in this column.
- */
- public abstract void putArray(int rowId, int offset, int length);
-
/**
* Returns the length of the array at rowid.
*/
@@ -608,7 +388,7 @@ public ColumnarBatch.Row getStruct(int rowId, int size) {
/**
* Returns the array at rowid.
*/
- public final Array getArray(int rowId) {
+ public final ColumnVector.Array getArray(int rowId) {
resultArray.length = getArrayLength(rowId);
resultArray.offset = getArrayOffset(rowId);
return resultArray;
@@ -617,24 +397,7 @@ public final Array getArray(int rowId) {
/**
* Loads the data into array.byteArray.
*/
- public abstract void loadBytes(Array array);
-
- /**
- * Sets the value at rowId to `value`.
- */
- public abstract int putByteArray(int rowId, byte[] value, int offset, int count);
- public final int putByteArray(int rowId, byte[] value) {
- return putByteArray(rowId, value, 0, value.length);
- }
-
- /**
- * Returns the value for rowId.
- */
- private Array getByteArray(int rowId) {
- Array array = getArray(rowId);
- array.data.loadBytes(array);
- return array;
- }
+ public abstract void loadBytes(ColumnVector.Array array);
/**
* Returns the value for rowId.
@@ -646,354 +409,42 @@ public MapData getMap(int ordinal) {
/**
* Returns the decimal for rowId.
*/
- public Decimal getDecimal(int rowId, int precision, int scale) {
- if (precision <= Decimal.MAX_INT_DIGITS()) {
- return Decimal.createUnsafe(getInt(rowId), precision, scale);
- } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
- return Decimal.createUnsafe(getLong(rowId), precision, scale);
- } else {
- // TODO: best perf?
- byte[] bytes = getBinary(rowId);
- BigInteger bigInteger = new BigInteger(bytes);
- BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
- return Decimal.apply(javaDecimal, precision, scale);
- }
- }
-
-
- public void putDecimal(int rowId, Decimal value, int precision) {
- if (precision <= Decimal.MAX_INT_DIGITS()) {
- putInt(rowId, (int) value.toUnscaledLong());
- } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
- putLong(rowId, value.toUnscaledLong());
- } else {
- BigInteger bigInteger = value.toJavaBigDecimal().unscaledValue();
- putByteArray(rowId, bigInteger.toByteArray());
- }
- }
+ public abstract Decimal getDecimal(int rowId, int precision, int scale);
/**
* Returns the UTF8String for rowId.
*/
- public UTF8String getUTF8String(int rowId) {
- if (dictionary == null) {
- ColumnVector.Array a = getByteArray(rowId);
- return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length);
- } else {
- byte[] bytes = dictionary.decodeToBinary(dictionaryIds.getDictId(rowId));
- return UTF8String.fromBytes(bytes);
- }
- }
+ public abstract UTF8String getUTF8String(int rowId);
/**
* Returns the byte array for rowId.
*/
- public byte[] getBinary(int rowId) {
- if (dictionary == null) {
- ColumnVector.Array array = getByteArray(rowId);
- byte[] bytes = new byte[array.length];
- System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length);
- return bytes;
- } else {
- return dictionary.decodeToBinary(dictionaryIds.getDictId(rowId));
- }
- }
-
- /**
- * Append APIs. These APIs all behave similarly and will append data to the current vector. It
- * is not valid to mix the put and append APIs. The append APIs are slower and should only be
- * used if the sizes are not known up front.
- * In all these cases, the return value is the rowId for the first appended element.
- */
- public final int appendNull() {
- assert (!(dataType() instanceof StructType)); // Use appendStruct()
- reserve(elementsAppended + 1);
- putNull(elementsAppended);
- return elementsAppended++;
- }
-
- public final int appendNotNull() {
- reserve(elementsAppended + 1);
- putNotNull(elementsAppended);
- return elementsAppended++;
- }
-
- public final int appendNulls(int count) {
- assert (!(dataType() instanceof StructType));
- reserve(elementsAppended + count);
- int result = elementsAppended;
- putNulls(elementsAppended, count);
- elementsAppended += count;
- return result;
- }
-
- public final int appendNotNulls(int count) {
- assert (!(dataType() instanceof StructType));
- reserve(elementsAppended + count);
- int result = elementsAppended;
- putNotNulls(elementsAppended, count);
- elementsAppended += count;
- return result;
- }
-
- public final int appendBoolean(boolean v) {
- reserve(elementsAppended + 1);
- putBoolean(elementsAppended, v);
- return elementsAppended++;
- }
-
- public final int appendBooleans(int count, boolean v) {
- reserve(elementsAppended + count);
- int result = elementsAppended;
- putBooleans(elementsAppended, count, v);
- elementsAppended += count;
- return result;
- }
-
- public final int appendByte(byte v) {
- reserve(elementsAppended + 1);
- putByte(elementsAppended, v);
- return elementsAppended++;
- }
-
- public final int appendBytes(int count, byte v) {
- reserve(elementsAppended + count);
- int result = elementsAppended;
- putBytes(elementsAppended, count, v);
- elementsAppended += count;
- return result;
- }
-
- public final int appendBytes(int length, byte[] src, int offset) {
- reserve(elementsAppended + length);
- int result = elementsAppended;
- putBytes(elementsAppended, length, src, offset);
- elementsAppended += length;
- return result;
- }
-
- public final int appendShort(short v) {
- reserve(elementsAppended + 1);
- putShort(elementsAppended, v);
- return elementsAppended++;
- }
-
- public final int appendShorts(int count, short v) {
- reserve(elementsAppended + count);
- int result = elementsAppended;
- putShorts(elementsAppended, count, v);
- elementsAppended += count;
- return result;
- }
-
- public final int appendShorts(int length, short[] src, int offset) {
- reserve(elementsAppended + length);
- int result = elementsAppended;
- putShorts(elementsAppended, length, src, offset);
- elementsAppended += length;
- return result;
- }
-
- public final int appendInt(int v) {
- reserve(elementsAppended + 1);
- putInt(elementsAppended, v);
- return elementsAppended++;
- }
-
- public final int appendInts(int count, int v) {
- reserve(elementsAppended + count);
- int result = elementsAppended;
- putInts(elementsAppended, count, v);
- elementsAppended += count;
- return result;
- }
-
- public final int appendInts(int length, int[] src, int offset) {
- reserve(elementsAppended + length);
- int result = elementsAppended;
- putInts(elementsAppended, length, src, offset);
- elementsAppended += length;
- return result;
- }
-
- public final int appendLong(long v) {
- reserve(elementsAppended + 1);
- putLong(elementsAppended, v);
- return elementsAppended++;
- }
-
- public final int appendLongs(int count, long v) {
- reserve(elementsAppended + count);
- int result = elementsAppended;
- putLongs(elementsAppended, count, v);
- elementsAppended += count;
- return result;
- }
-
- public final int appendLongs(int length, long[] src, int offset) {
- reserve(elementsAppended + length);
- int result = elementsAppended;
- putLongs(elementsAppended, length, src, offset);
- elementsAppended += length;
- return result;
- }
-
- public final int appendFloat(float v) {
- reserve(elementsAppended + 1);
- putFloat(elementsAppended, v);
- return elementsAppended++;
- }
-
- public final int appendFloats(int count, float v) {
- reserve(elementsAppended + count);
- int result = elementsAppended;
- putFloats(elementsAppended, count, v);
- elementsAppended += count;
- return result;
- }
-
- public final int appendFloats(int length, float[] src, int offset) {
- reserve(elementsAppended + length);
- int result = elementsAppended;
- putFloats(elementsAppended, length, src, offset);
- elementsAppended += length;
- return result;
- }
-
- public final int appendDouble(double v) {
- reserve(elementsAppended + 1);
- putDouble(elementsAppended, v);
- return elementsAppended++;
- }
-
- public final int appendDoubles(int count, double v) {
- reserve(elementsAppended + count);
- int result = elementsAppended;
- putDoubles(elementsAppended, count, v);
- elementsAppended += count;
- return result;
- }
-
- public final int appendDoubles(int length, double[] src, int offset) {
- reserve(elementsAppended + length);
- int result = elementsAppended;
- putDoubles(elementsAppended, length, src, offset);
- elementsAppended += length;
- return result;
- }
-
- public final int appendByteArray(byte[] value, int offset, int length) {
- int copiedOffset = arrayData().appendBytes(length, value, offset);
- reserve(elementsAppended + 1);
- putArray(elementsAppended, copiedOffset, length);
- return elementsAppended++;
- }
-
- public final int appendArray(int length) {
- reserve(elementsAppended + 1);
- putArray(elementsAppended, arrayData().elementsAppended, length);
- return elementsAppended++;
- }
-
- /**
- * Appends a NULL struct. This *has* to be used for structs instead of appendNull() as this
- * recursively appends a NULL to its children.
- * We don't have this logic as the general appendNull implementation to optimize the more
- * common non-struct case.
- */
- public final int appendStruct(boolean isNull) {
- if (isNull) {
- appendNull();
- for (ColumnVector c: childColumns) {
- if (c.type instanceof StructType) {
- c.appendStruct(true);
- } else {
- c.appendNull();
- }
- }
- } else {
- appendNotNull();
- }
- return elementsAppended;
- }
+ public abstract byte[] getBinary(int rowId);
/**
* Returns the data for the underlying array.
*/
- public final ColumnVector arrayData() { return childColumns[0]; }
+ public abstract ColumnVector arrayData();
/**
* Returns the ordinal's child data column.
*/
- public final ColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; }
-
- /**
- * Returns the elements appended.
- */
- public final int getElementsAppended() { return elementsAppended; }
+ public abstract ColumnVector getChildColumn(int ordinal);
/**
* Returns true if this column is an array.
*/
public final boolean isArray() { return resultArray != null; }
- /**
- * Marks this column as being constant.
- */
- public final void setIsConstant() { isConstant = true; }
-
- /**
- * Maximum number of rows that can be stored in this column.
- */
- protected int capacity;
-
- /**
- * Upper limit for the maximum capacity for this column.
- */
- @VisibleForTesting
- protected int MAX_CAPACITY = Integer.MAX_VALUE;
-
/**
* Data type for this column.
*/
protected DataType type;
- /**
- * Number of nulls in this column. This is an optimization for the reader, to skip NULL checks.
- */
- protected int numNulls;
-
- /**
- * True if there is at least one NULL byte set. This is an optimization for the writer, to skip
- * having to clear NULL bits.
- */
- protected boolean anyNullsSet;
-
- /**
- * True if this column's values are fixed. This means the column values never change, even
- * across resets.
- */
- protected boolean isConstant;
-
- /**
- * Default size of each array length value. This grows as necessary.
- */
- protected static final int DEFAULT_ARRAY_LENGTH = 4;
-
- /**
- * Current write cursor (row index) when appending data.
- */
- protected int elementsAppended;
-
- /**
- * If this is a nested type (array or struct), the column for the child data.
- */
- protected ColumnVector[] childColumns;
-
/**
* Reusable Array holder for getArray().
*/
- protected Array resultArray;
+ protected ColumnVector.Array resultArray;
/**
* Reusable Struct holder for getStruct().
@@ -1012,32 +463,11 @@ public final int appendStruct(boolean isNull) {
*/
protected ColumnVector dictionaryIds;
- /**
- * Update the dictionary.
- */
- public void setDictionary(Dictionary dictionary) {
- this.dictionary = dictionary;
- }
-
/**
* Returns true if this column has a dictionary.
*/
public boolean hasDictionary() { return this.dictionary != null; }
- /**
- * Reserve a integer column for ids of dictionary.
- */
- public ColumnVector reserveDictionaryIds(int capacity) {
- if (dictionaryIds == null) {
- dictionaryIds = allocate(capacity, DataTypes.IntegerType,
- this instanceof OnHeapColumnVector ? MemoryMode.ON_HEAP : MemoryMode.OFF_HEAP);
- } else {
- dictionaryIds.reset();
- dictionaryIds.reserve(capacity);
- }
- return dictionaryIds;
- }
-
/**
* Returns the underlying integer column for ids of dictionary.
*/
@@ -1049,43 +479,7 @@ public ColumnVector getDictionaryIds() {
* Sets up the common state and also handles creating the child columns if this is a nested
* type.
*/
- protected ColumnVector(int capacity, DataType type, MemoryMode memMode) {
- this.capacity = capacity;
+ protected ColumnVector(DataType type) {
this.type = type;
-
- if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType
- || DecimalType.isByteArrayDecimalType(type)) {
- DataType childType;
- int childCapacity = capacity;
- if (type instanceof ArrayType) {
- childType = ((ArrayType)type).elementType();
- } else {
- childType = DataTypes.ByteType;
- childCapacity *= DEFAULT_ARRAY_LENGTH;
- }
- this.childColumns = new ColumnVector[1];
- this.childColumns[0] = ColumnVector.allocate(childCapacity, childType, memMode);
- this.resultArray = new Array(this.childColumns[0]);
- this.resultStruct = null;
- } else if (type instanceof StructType) {
- StructType st = (StructType)type;
- this.childColumns = new ColumnVector[st.fields().length];
- for (int i = 0; i < childColumns.length; ++i) {
- this.childColumns[i] = ColumnVector.allocate(capacity, st.fields()[i].dataType(), memMode);
- }
- this.resultArray = null;
- this.resultStruct = new ColumnarBatch.Row(this.childColumns);
- } else if (type instanceof CalendarIntervalType) {
- // Two columns. Months as int. Microseconds as Long.
- this.childColumns = new ColumnVector[2];
- this.childColumns[0] = ColumnVector.allocate(capacity, DataTypes.IntegerType, memMode);
- this.childColumns[1] = ColumnVector.allocate(capacity, DataTypes.LongType, memMode);
- this.resultArray = null;
- this.resultStruct = new ColumnarBatch.Row(this.childColumns);
- } else {
- this.childColumns = null;
- this.resultArray = null;
- this.resultStruct = null;
- }
}
}
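
For reference, a minimal sketch (not part of the patch) of what reading through the now read-only ColumnVector API looks like: access goes through the abstract getters such as anyNullsSet(), isNullAt() and getInt(), while every mutating call moves to WritableColumnVector further below. The helper name and the numRows parameter are illustrative.

import org.apache.spark.sql.execution.vectorized.ColumnVector;

final class ReadOnlyAccessSketch {
  // Sums the non-null integers in the first numRows rows of col.
  static long sumNonNullInts(ColumnVector col, int numRows) {
    long sum = 0;
    if (!col.anyNullsSet()) {
      // Fast path: no per-row null check when no null has ever been set.
      for (int i = 0; i < numRows; i++) {
        sum += col.getInt(i);
      }
    } else {
      for (int i = 0; i < numRows; i++) {
        if (!col.isNullAt(i)) {
          sum += col.getInt(i);
        }
      }
    }
    return sum;
  }
}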
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
index 900d7c431e723..adb859ed17757 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
@@ -40,7 +40,7 @@ public class ColumnVectorUtils {
/**
* Populates the entire `col` with `row[fieldIdx]`
*/
- public static void populate(ColumnVector col, InternalRow row, int fieldIdx) {
+ public static void populate(WritableColumnVector col, InternalRow row, int fieldIdx) {
int capacity = col.capacity;
DataType t = col.dataType();
@@ -115,7 +115,7 @@ public static Object toPrimitiveJavaArray(ColumnVector.Array array) {
}
}
- private static void appendValue(ColumnVector dst, DataType t, Object o) {
+ private static void appendValue(WritableColumnVector dst, DataType t, Object o) {
if (o == null) {
if (t instanceof CalendarIntervalType) {
dst.appendStruct(true);
@@ -165,7 +165,7 @@ private static void appendValue(ColumnVector dst, DataType t, Object o) {
}
}
- private static void appendValue(ColumnVector dst, DataType t, Row src, int fieldIdx) {
+ private static void appendValue(WritableColumnVector dst, DataType t, Row src, int fieldIdx) {
if (t instanceof ArrayType) {
ArrayType at = (ArrayType)t;
if (src.isNullAt(fieldIdx)) {
@@ -198,15 +198,23 @@ private static void appendValue(ColumnVector dst, DataType t, Row src, int field
*/
public static ColumnarBatch toBatch(
StructType schema, MemoryMode memMode, Iterator<Row> row) {
- ColumnarBatch batch = ColumnarBatch.allocate(schema, memMode);
+ int capacity = ColumnarBatch.DEFAULT_BATCH_SIZE;
+ WritableColumnVector[] columnVectors;
+ if (memMode == MemoryMode.OFF_HEAP) {
+ columnVectors = OffHeapColumnVector.allocateColumns(capacity, schema);
+ } else {
+ columnVectors = OnHeapColumnVector.allocateColumns(capacity, schema);
+ }
+
int n = 0;
while (row.hasNext()) {
Row r = row.next();
for (int i = 0; i < schema.fields().length; i++) {
- appendValue(batch.column(i), schema.fields()[i].dataType(), r, i);
+ appendValue(columnVectors[i], schema.fields()[i].dataType(), r, i);
}
n++;
}
+ ColumnarBatch batch = new ColumnarBatch(schema, columnVectors, capacity);
batch.setNumRows(n);
return batch;
}
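
The rewritten toBatch() above is the general pattern for constructing a batch after this change: allocate writable vectors for the schema, append the rows, then wrap the vectors in the new ColumnarBatch constructor. A minimal on-heap sketch, assuming a single IntegerType field (the schema and input values are made up for illustration):

import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

final class BatchAllocationSketch {
  static ColumnarBatch intBatch(int[] values) {
    StructType schema = new StructType().add("id", DataTypes.IntegerType);
    int capacity = ColumnarBatch.DEFAULT_BATCH_SIZE;
    // One writable vector per field; they grow via reserve() as values are appended.
    WritableColumnVector[] columns = OnHeapColumnVector.allocateColumns(capacity, schema);
    for (int v : values) {
      columns[0].appendInt(v);
    }
    // The batch no longer allocates columns itself; it wraps the ones handed to it.
    ColumnarBatch batch = new ColumnarBatch(schema, columns, capacity);
    batch.setNumRows(values.length);
    return batch;
  }
}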
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
index 34dc3af9b85c8..e782756a3e781 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
@@ -19,7 +19,6 @@
import java.math.BigDecimal;
import java.util.*;
-import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
@@ -44,8 +43,7 @@
* - Compaction: The batch and columns should be able to compact based on a selection vector.
*/
public final class ColumnarBatch {
- private static final int DEFAULT_BATCH_SIZE = 4 * 1024;
- private static MemoryMode DEFAULT_MEMORY_MODE = MemoryMode.ON_HEAP;
+ public static final int DEFAULT_BATCH_SIZE = 4 * 1024;
private final StructType schema;
private final int capacity;
@@ -64,18 +62,6 @@ public final class ColumnarBatch {
// Staging row returned from getRow.
final Row row;
- public static ColumnarBatch allocate(StructType schema, MemoryMode memMode) {
- return new ColumnarBatch(schema, DEFAULT_BATCH_SIZE, memMode);
- }
-
- public static ColumnarBatch allocate(StructType type) {
- return new ColumnarBatch(type, DEFAULT_BATCH_SIZE, DEFAULT_MEMORY_MODE);
- }
-
- public static ColumnarBatch allocate(StructType schema, MemoryMode memMode, int maxRows) {
- return new ColumnarBatch(schema, maxRows, memMode);
- }
-
/**
* Called to close all the columns in this batch. It is not valid to access the data after
* calling this. This must be called at the end to clean up memory allocations.
@@ -95,12 +81,19 @@ public static final class Row extends InternalRow {
private final ColumnarBatch parent;
private final int fixedLenRowSize;
private final ColumnVector[] columns;
+ private final WritableColumnVector[] writableColumns;
// Ctor used if this is a top level row.
private Row(ColumnarBatch parent) {
this.parent = parent;
this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(parent.numCols());
this.columns = parent.columns;
+ this.writableColumns = new WritableColumnVector[this.columns.length];
+ for (int i = 0; i < this.columns.length; i++) {
+ if (this.columns[i] instanceof WritableColumnVector) {
+ this.writableColumns[i] = (WritableColumnVector) this.columns[i];
+ }
+ }
}
// Ctor used if this is a struct.
@@ -108,6 +101,12 @@ protected Row(ColumnVector[] columns) {
this.parent = null;
this.fixedLenRowSize = UnsafeRow.calculateFixedPortionByteSize(columns.length);
this.columns = columns;
+ this.writableColumns = new WritableColumnVector[this.columns.length];
+ for (int i = 0; i < this.columns.length; i++) {
+ if (this.columns[i] instanceof WritableColumnVector) {
+ this.writableColumns[i] = (WritableColumnVector) this.columns[i];
+ }
+ }
}
/**
@@ -307,64 +306,69 @@ public void update(int ordinal, Object value) {
@Override
public void setNullAt(int ordinal) {
- assert (!columns[ordinal].isConstant);
- columns[ordinal].putNull(rowId);
+ getWritableColumn(ordinal).putNull(rowId);
}
@Override
public void setBoolean(int ordinal, boolean value) {
- assert (!columns[ordinal].isConstant);
- columns[ordinal].putNotNull(rowId);
- columns[ordinal].putBoolean(rowId, value);
+ WritableColumnVector column = getWritableColumn(ordinal);
+ column.putNotNull(rowId);
+ column.putBoolean(rowId, value);
}
@Override
public void setByte(int ordinal, byte value) {
- assert (!columns[ordinal].isConstant);
- columns[ordinal].putNotNull(rowId);
- columns[ordinal].putByte(rowId, value);
+ WritableColumnVector column = getWritableColumn(ordinal);
+ column.putNotNull(rowId);
+ column.putByte(rowId, value);
}
@Override
public void setShort(int ordinal, short value) {
- assert (!columns[ordinal].isConstant);
- columns[ordinal].putNotNull(rowId);
- columns[ordinal].putShort(rowId, value);
+ WritableColumnVector column = getWritableColumn(ordinal);
+ column.putNotNull(rowId);
+ column.putShort(rowId, value);
}
@Override
public void setInt(int ordinal, int value) {
- assert (!columns[ordinal].isConstant);
- columns[ordinal].putNotNull(rowId);
- columns[ordinal].putInt(rowId, value);
+ WritableColumnVector column = getWritableColumn(ordinal);
+ column.putNotNull(rowId);
+ column.putInt(rowId, value);
}
@Override
public void setLong(int ordinal, long value) {
- assert (!columns[ordinal].isConstant);
- columns[ordinal].putNotNull(rowId);
- columns[ordinal].putLong(rowId, value);
+ WritableColumnVector column = getWritableColumn(ordinal);
+ column.putNotNull(rowId);
+ column.putLong(rowId, value);
}
@Override
public void setFloat(int ordinal, float value) {
- assert (!columns[ordinal].isConstant);
- columns[ordinal].putNotNull(rowId);
- columns[ordinal].putFloat(rowId, value);
+ WritableColumnVector column = getWritableColumn(ordinal);
+ column.putNotNull(rowId);
+ column.putFloat(rowId, value);
}
@Override
public void setDouble(int ordinal, double value) {
- assert (!columns[ordinal].isConstant);
- columns[ordinal].putNotNull(rowId);
- columns[ordinal].putDouble(rowId, value);
+ WritableColumnVector column = getWritableColumn(ordinal);
+ column.putNotNull(rowId);
+ column.putDouble(rowId, value);
}
@Override
public void setDecimal(int ordinal, Decimal value, int precision) {
- assert (!columns[ordinal].isConstant);
- columns[ordinal].putNotNull(rowId);
- columns[ordinal].putDecimal(rowId, value, precision);
+ WritableColumnVector column = getWritableColumn(ordinal);
+ column.putNotNull(rowId);
+ column.putDecimal(rowId, value, precision);
+ }
+
+ private WritableColumnVector getWritableColumn(int ordinal) {
+ WritableColumnVector column = writableColumns[ordinal];
+ assert (!column.isConstant);
+ return column;
}
}
@@ -409,7 +413,9 @@ public void remove() {
*/
public void reset() {
for (int i = 0; i < numCols(); ++i) {
- columns[i].reset();
+ if (columns[i] instanceof WritableColumnVector) {
+ ((WritableColumnVector) columns[i]).reset();
+ }
}
if (this.numRowsFiltered > 0) {
Arrays.fill(filteredRows, false);
@@ -427,7 +433,7 @@ public void setNumRows(int numRows) {
this.numRows = numRows;
for (int ordinal : nullFilteredColumns) {
- if (columns[ordinal].numNulls != 0) {
+ if (columns[ordinal].numNulls() != 0) {
for (int rowId = 0; rowId < numRows; rowId++) {
if (!filteredRows[rowId] && columns[ordinal].isNullAt(rowId)) {
filteredRows[rowId] = true;
@@ -505,18 +511,12 @@ public void filterNullsInColumn(int ordinal) {
nullFilteredColumns.add(ordinal);
}
- private ColumnarBatch(StructType schema, int maxRows, MemoryMode memMode) {
+ public ColumnarBatch(StructType schema, ColumnVector[] columns, int capacity) {
this.schema = schema;
- this.capacity = maxRows;
- this.columns = new ColumnVector[schema.size()];
+ this.columns = columns;
+ this.capacity = capacity;
this.nullFilteredColumns = new HashSet<>();
- this.filteredRows = new boolean[maxRows];
-
- for (int i = 0; i < schema.fields().length; ++i) {
- StructField field = schema.fields()[i];
- columns[i] = ColumnVector.allocate(maxRows, field.dataType(), memMode);
- }
-
+ this.filteredRows = new boolean[capacity];
this.row = new Row(this);
}
}
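
Since reset() is now applied only to columns that are actually WritableColumnVector instances (see the cast above), a reusable batch is built over writable vectors and cleared between fills. A hypothetical reuse loop, with the consumer of the rows left out:

import java.util.List;
import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;

final class BatchReuseSketch {
  static void fillRepeatedly(ColumnarBatch batch, List<int[]> chunks) {
    for (int[] chunk : chunks) {
      // Clears the writable columns and the filtered-row bookkeeping.
      batch.reset();
      WritableColumnVector col = (WritableColumnVector) batch.column(0);
      for (int v : chunk) {
        col.appendInt(v);
      }
      batch.setNumRows(chunk.length);
      // ... hand the batch to the consumer here ...
    }
  }
}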
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
index 2d1f3da8e7463..35682756ed6c3 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
@@ -19,18 +19,39 @@
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
-import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.Platform;
/**
* Column data backed using offheap memory.
*/
-public final class OffHeapColumnVector extends ColumnVector {
+public final class OffHeapColumnVector extends WritableColumnVector {
private static final boolean bigEndianPlatform =
ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN);
+ /**
+ * Allocates columns to store elements of each field of the schema off heap.
+ * Capacity is the initial capacity of the vector and it will grow as necessary. Capacity is
+ * in number of elements, not number of bytes.
+ */
+ public static OffHeapColumnVector[] allocateColumns(int capacity, StructType schema) {
+ return allocateColumns(capacity, schema.fields());
+ }
+
+ /**
+ * Allocates columns to store elements of each field off heap.
+ * Capacity is the initial capacity of the vector and it will grow as necessary. Capacity is
+ * in number of elements, not number of bytes.
+ */
+ public static OffHeapColumnVector[] allocateColumns(int capacity, StructField[] fields) {
+ OffHeapColumnVector[] vectors = new OffHeapColumnVector[fields.length];
+ for (int i = 0; i < fields.length; i++) {
+ vectors[i] = new OffHeapColumnVector(capacity, fields[i].dataType());
+ }
+ return vectors;
+ }
+
// The data stored in these two allocations need to maintain binary compatible. We can
// directly pass this buffer to external components.
private long nulls;
@@ -40,8 +61,8 @@ public final class OffHeapColumnVector extends ColumnVector {
private long lengthData;
private long offsetData;
- protected OffHeapColumnVector(int capacity, DataType type) {
- super(capacity, type, MemoryMode.OFF_HEAP);
+ public OffHeapColumnVector(int capacity, DataType type) {
+ super(capacity, type);
nulls = 0;
data = 0;
@@ -519,4 +540,9 @@ protected void reserveInternal(int newCapacity) {
Platform.setMemory(nulls + oldCapacity, (byte)0, newCapacity - oldCapacity);
capacity = newCapacity;
}
+
+ @Override
+ protected OffHeapColumnVector reserveNewColumn(int capacity, DataType type) {
+ return new OffHeapColumnVector(capacity, type);
+ }
}
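
OffHeapColumnVector keeps its nulls and data buffers in native memory, so the close() inherited from ColumnVector is what releases them. A small allocate-use-close sketch, assuming a single LongType field (the field name is illustrative):

import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

final class OffHeapLifecycleSketch {
  static void useOffHeap() {
    StructType schema = new StructType().add("ts", DataTypes.LongType);
    OffHeapColumnVector[] columns = OffHeapColumnVector.allocateColumns(1024, schema);
    try {
      columns[0].appendLong(42L);
      // ... read back through the ColumnVector getters ...
    } finally {
      // Off-heap buffers are not garbage collected; free them explicitly.
      for (OffHeapColumnVector c : columns) {
        c.close();
      }
    }
  }
}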
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
index 506434364be48..96a452978cb35 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
@@ -20,7 +20,6 @@
import java.nio.ByteOrder;
import java.util.Arrays;
-import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.Platform;
@@ -28,11 +27,33 @@
* A column backed by an in memory JVM array. This stores the NULLs as a byte per value
* and a java array for the values.
*/
-public final class OnHeapColumnVector extends ColumnVector {
+public final class OnHeapColumnVector extends WritableColumnVector {
private static final boolean bigEndianPlatform =
ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN);
+ /**
+ * Allocates columns to store elements of each field of the schema on heap.
+ * Capacity is the initial capacity of the vector and it will grow as necessary. Capacity is
+ * in number of elements, not number of bytes.
+ */
+ public static OnHeapColumnVector[] allocateColumns(int capacity, StructType schema) {
+ return allocateColumns(capacity, schema.fields());
+ }
+
+ /**
+ * Allocates columns to store elements of each field on heap.
+ * Capacity is the initial capacity of the vector and it will grow as necessary. Capacity is
+ * in number of elements, not number of bytes.
+ */
+ public static OnHeapColumnVector[] allocateColumns(int capacity, StructField[] fields) {
+ OnHeapColumnVector[] vectors = new OnHeapColumnVector[fields.length];
+ for (int i = 0; i < fields.length; i++) {
+ vectors[i] = new OnHeapColumnVector(capacity, fields[i].dataType());
+ }
+ return vectors;
+ }
+
// The data stored in these arrays need to maintain binary compatible. We can
// directly pass this buffer to external components.
@@ -51,8 +72,9 @@ public final class OnHeapColumnVector extends ColumnVector {
private int[] arrayLengths;
private int[] arrayOffsets;
- protected OnHeapColumnVector(int capacity, DataType type) {
- super(capacity, type, MemoryMode.ON_HEAP);
+ public OnHeapColumnVector(int capacity, DataType type) {
+ super(capacity, type);
+
reserveInternal(capacity);
reset();
}
@@ -529,4 +551,9 @@ protected void reserveInternal(int newCapacity) {
capacity = newCapacity;
}
+
+ @Override
+ protected OnHeapColumnVector reserveNewColumn(int capacity, DataType type) {
+ return new OnHeapColumnVector(capacity, type);
+ }
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ReadOnlyColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ReadOnlyColumnVector.java
deleted file mode 100644
index e9f6e7c631fd4..0000000000000
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ReadOnlyColumnVector.java
+++ /dev/null
@@ -1,251 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.vectorized;
-
-import org.apache.spark.memory.MemoryMode;
-import org.apache.spark.sql.types.*;
-
-/**
- * An abstract class for read-only column vector.
- */
-public abstract class ReadOnlyColumnVector extends ColumnVector {
-
- protected ReadOnlyColumnVector(int capacity, DataType type, MemoryMode memMode) {
- super(capacity, DataTypes.NullType, memMode);
- this.type = type;
- isConstant = true;
- }
-
- //
- // APIs dealing with nulls
- //
-
- @Override
- public final void putNotNull(int rowId) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public final void putNull(int rowId) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public final void putNulls(int rowId, int count) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public final void putNotNulls(int rowId, int count) {
- throw new UnsupportedOperationException();
- }
-
- //
- // APIs dealing with Booleans
- //
-
- @Override
- public final void putBoolean(int rowId, boolean value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public final void putBooleans(int rowId, int count, boolean value) {
- throw new UnsupportedOperationException();
- }
-
- //
- // APIs dealing with Bytes
- //
-
- @Override
- public final void putByte(int rowId, byte value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public final void putBytes(int rowId, int count, byte value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public final void putBytes(int rowId, int count, byte[] src, int srcIndex) {
- throw new UnsupportedOperationException();
- }
-
- //
- // APIs dealing with Shorts
- //
-
- @Override
- public final void putShort(int rowId, short value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public final void putShorts(int rowId, int count, short value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public final void putShorts(int rowId, int count, short[] src, int srcIndex) {
- throw new UnsupportedOperationException();
- }
-
- //
- // APIs dealing with Ints
- //
-
- @Override
- public final void putInt(int rowId, int value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public final void putInts(int rowId, int count, int value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public final void putInts(int rowId, int count, int[] src, int srcIndex) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public final void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) {
- throw new UnsupportedOperationException();
- }
-
- //
- // APIs dealing with Longs
- //
-
- @Override
- public final void putLong(int rowId, long value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public final void putLongs(int rowId, int count, long value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public final void putLongs(int rowId, int count, long[] src, int srcIndex) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public final void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) {
- throw new UnsupportedOperationException();
- }
-
- //
- // APIs dealing with floats
- //
-
- @Override
- public final void putFloat(int rowId, float value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public final void putFloats(int rowId, int count, float value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public final void putFloats(int rowId, int count, float[] src, int srcIndex) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public final void putFloats(int rowId, int count, byte[] src, int srcIndex) {
- throw new UnsupportedOperationException();
- }
-
- //
- // APIs dealing with doubles
- //
-
- @Override
- public final void putDouble(int rowId, double value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public final void putDoubles(int rowId, int count, double value) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public final void putDoubles(int rowId, int count, double[] src, int srcIndex) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) {
- throw new UnsupportedOperationException();
- }
-
- //
- // APIs dealing with Arrays
- //
-
- @Override
- public final void putArray(int rowId, int offset, int length) {
- throw new UnsupportedOperationException();
- }
-
- //
- // APIs dealing with Byte Arrays
- //
-
- @Override
- public final int putByteArray(int rowId, byte[] value, int offset, int count) {
- throw new UnsupportedOperationException();
- }
-
- //
- // APIs dealing with Decimals
- //
-
- @Override
- public final void putDecimal(int rowId, Decimal value, int precision) {
- throw new UnsupportedOperationException();
- }
-
- //
- // Other APIs
- //
-
- @Override
- public final void setDictionary(Dictionary dictionary) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public final ColumnVector reserveDictionaryIds(int capacity) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- protected final void reserveInternal(int newCapacity) {
- throw new UnsupportedOperationException();
- }
-}
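
With ReadOnlyColumnVector deleted, read-only is no longer a runtime flag backed by put methods that throw; it is simply a ColumnVector that does not extend the WritableColumnVector introduced below. Call sites that need to write can mirror ColumnarBatch.Row.getWritableColumn with a guard like this hypothetical helper:

import org.apache.spark.sql.execution.vectorized.ColumnVector;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;

final class WritabilityGuardSketch {
  // Returns the vector as writable, or fails fast for read-only implementations.
  static WritableColumnVector asWritable(ColumnVector col) {
    if (col instanceof WritableColumnVector) {
      return (WritableColumnVector) col;
    }
    throw new UnsupportedOperationException(
        "Read-only column vector: " + col.getClass().getName());
  }
}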
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
new file mode 100644
index 0000000000000..b4f753c0bc2a3
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
@@ -0,0 +1,674 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.vectorized;
+
+import java.math.BigDecimal;
+import java.math.BigInteger;
+
+import com.google.common.annotations.VisibleForTesting;
+
+import org.apache.spark.sql.internal.SQLConf;
+import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * This class adds write APIs to ColumnVector.
+ * It supports all the types and contains put APIs as well as their batched versions.
+ * The batched versions are preferable whenever possible.
+ *
+ * Capacity: The data stored is dense but the arrays are not fixed capacity. It is the
+ * responsibility of the caller to call reserve() to ensure there is enough room before adding
+ * elements. This means that the put() APIs do not check, as in common cases (i.e. flat schemas)
+ * the lengths are known up front.
+ *
+ * A ColumnVector should be considered immutable once originally created. In other words, it is not
+ * valid to call put APIs after reads until reset() is called.
+ */
+public abstract class WritableColumnVector extends ColumnVector {
+
+ /**
+ * Resets this column for writing. The currently stored values are no longer accessible.
+ */
+ public void reset() {
+ if (isConstant) return;
+
+ if (childColumns != null) {
+ for (ColumnVector c: childColumns) {
+ ((WritableColumnVector) c).reset();
+ }
+ }
+ numNulls = 0;
+ elementsAppended = 0;
+ if (anyNullsSet) {
+ putNotNulls(0, capacity);
+ anyNullsSet = false;
+ }
+ }
+
+ public void reserve(int requiredCapacity) {
+ if (requiredCapacity > capacity) {
+ int newCapacity = (int) Math.min(MAX_CAPACITY, requiredCapacity * 2L);
+ if (requiredCapacity <= newCapacity) {
+ try {
+ reserveInternal(newCapacity);
+ } catch (OutOfMemoryError outOfMemoryError) {
+ throwUnsupportedException(requiredCapacity, outOfMemoryError);
+ }
+ } else {
+ throwUnsupportedException(requiredCapacity, null);
+ }
+ }
+ }
+
+ private void throwUnsupportedException(int requiredCapacity, Throwable cause) {
+ String message = "Cannot reserve additional contiguous bytes in the vectorized reader " +
+ "(requested = " + requiredCapacity + " bytes). As a workaround, you can disable the " +
+ "vectorized reader by setting " + SQLConf.PARQUET_VECTORIZED_READER_ENABLED().key() +
+ " to false.";
+ throw new RuntimeException(message, cause);
+ }
+
+ @Override
+ public int numNulls() { return numNulls; }
+
+ @Override
+ public boolean anyNullsSet() { return anyNullsSet; }
+
+ /**
+ * Ensures that there is enough storage to store capacity elements. That is, the put() APIs
+ * must work for all rowIds < capacity.
+ */
+ protected abstract void reserveInternal(int capacity);
+
+ /**
+ * Sets the value at rowId to null/not null.
+ */
+ public abstract void putNotNull(int rowId);
+ public abstract void putNull(int rowId);
+
+ /**
+ * Sets the values from [rowId, rowId + count) to null/not null.
+ */
+ public abstract void putNulls(int rowId, int count);
+ public abstract void putNotNulls(int rowId, int count);
+
+ /**
+ * Sets the value at rowId to `value`.
+ */
+ public abstract void putBoolean(int rowId, boolean value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to value.
+ */
+ public abstract void putBooleans(int rowId, int count, boolean value);
+
+ /**
+ * Sets the value at rowId to `value`.
+ */
+ public abstract void putByte(int rowId, byte value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to value.
+ */
+ public abstract void putBytes(int rowId, int count, byte value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count)
+ */
+ public abstract void putBytes(int rowId, int count, byte[] src, int srcIndex);
+
+ /**
+ * Sets the value at rowId to `value`.
+ */
+ public abstract void putShort(int rowId, short value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to value.
+ */
+ public abstract void putShorts(int rowId, int count, short value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count)
+ */
+ public abstract void putShorts(int rowId, int count, short[] src, int srcIndex);
+
+ /**
+ * Sets the value at rowId to `value`.
+ */
+ public abstract void putInt(int rowId, int value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to value.
+ */
+ public abstract void putInts(int rowId, int count, int value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count)
+ */
+ public abstract void putInts(int rowId, int count, int[] src, int srcIndex);
+
+ /**
+ * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count])
+ * The data in src must be 4-byte little endian ints.
+ */
+ public abstract void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex);
+
+ /**
+ * Sets the value at rowId to `value`.
+ */
+ public abstract void putLong(int rowId, long value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to value.
+ */
+ public abstract void putLongs(int rowId, int count, long value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count)
+ */
+ public abstract void putLongs(int rowId, int count, long[] src, int srcIndex);
+
+ /**
+ * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count])
+ * The data in src must be 8-byte little endian longs.
+ */
+ public abstract void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex);
+
+ /**
+ * Sets the value at rowId to `value`.
+ */
+ public abstract void putFloat(int rowId, float value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to value.
+ */
+ public abstract void putFloats(int rowId, int count, float value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count)
+ */
+ public abstract void putFloats(int rowId, int count, float[] src, int srcIndex);
+
+ /**
+ * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count])
+ * The data in src must be ieee formatted floats.
+ */
+ public abstract void putFloats(int rowId, int count, byte[] src, int srcIndex);
+
+ /**
+ * Sets the value at rowId to `value`.
+ */
+ public abstract void putDouble(int rowId, double value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to value.
+ */
+ public abstract void putDoubles(int rowId, int count, double value);
+
+ /**
+ * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count)
+ */
+ public abstract void putDoubles(int rowId, int count, double[] src, int srcIndex);
+
+ /**
+ * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count])
+ * The data in src must be ieee formatted doubles.
+ */
+ public abstract void putDoubles(int rowId, int count, byte[] src, int srcIndex);
+
+ /**
+ * Puts a byte array that already exists in this column.
+ */
+ public abstract void putArray(int rowId, int offset, int length);
+
+ /**
+ * Sets the value at rowId to `value`.
+ */
+ public abstract int putByteArray(int rowId, byte[] value, int offset, int count);
+ public final int putByteArray(int rowId, byte[] value) {
+ return putByteArray(rowId, value, 0, value.length);
+ }
+
+ /**
+ * Returns the value for rowId.
+ */
+ private ColumnVector.Array getByteArray(int rowId) {
+ ColumnVector.Array array = getArray(rowId);
+ array.data.loadBytes(array);
+ return array;
+ }
+
+ /**
+ * Returns the decimal for rowId.
+ */
+ @Override
+ public Decimal getDecimal(int rowId, int precision, int scale) {
+ if (precision <= Decimal.MAX_INT_DIGITS()) {
+ return Decimal.createUnsafe(getInt(rowId), precision, scale);
+ } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
+ return Decimal.createUnsafe(getLong(rowId), precision, scale);
+ } else {
+ // TODO: best perf?
+ byte[] bytes = getBinary(rowId);
+ BigInteger bigInteger = new BigInteger(bytes);
+ BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
+ return Decimal.apply(javaDecimal, precision, scale);
+ }
+ }
+
+ public void putDecimal(int rowId, Decimal value, int precision) {
+ if (precision <= Decimal.MAX_INT_DIGITS()) {
+ putInt(rowId, (int) value.toUnscaledLong());
+ } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
+ putLong(rowId, value.toUnscaledLong());
+ } else {
+ BigInteger bigInteger = value.toJavaBigDecimal().unscaledValue();
+ putByteArray(rowId, bigInteger.toByteArray());
+ }
+ }
+
+ /**
+ * Returns the UTF8String for rowId.
+ */
+ @Override
+ public UTF8String getUTF8String(int rowId) {
+ if (dictionary == null) {
+ ColumnVector.Array a = getByteArray(rowId);
+ return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length);
+ } else {
+ byte[] bytes = dictionary.decodeToBinary(dictionaryIds.getDictId(rowId));
+ return UTF8String.fromBytes(bytes);
+ }
+ }
+
+ /**
+ * Returns the byte array for rowId.
+ */
+ @Override
+ public byte[] getBinary(int rowId) {
+ if (dictionary == null) {
+ ColumnVector.Array array = getByteArray(rowId);
+ byte[] bytes = new byte[array.length];
+ System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length);
+ return bytes;
+ } else {
+ return dictionary.decodeToBinary(dictionaryIds.getDictId(rowId));
+ }
+ }
+
+ /**
+ * Append APIs. These APIs all behave similarly and will append data to the current vector. It
+ * is not valid to mix the put and append APIs. The append APIs are slower and should only be
+ * used if the sizes are not known up front.
+ * In all these cases, the return value is the rowId for the first appended element.
+ */
+ public final int appendNull() {
+ assert (!(dataType() instanceof StructType)); // Use appendStruct()
+ reserve(elementsAppended + 1);
+ putNull(elementsAppended);
+ return elementsAppended++;
+ }
+
+ public final int appendNotNull() {
+ reserve(elementsAppended + 1);
+ putNotNull(elementsAppended);
+ return elementsAppended++;
+ }
+
+ public final int appendNulls(int count) {
+ assert (!(dataType() instanceof StructType));
+ reserve(elementsAppended + count);
+ int result = elementsAppended;
+ putNulls(elementsAppended, count);
+ elementsAppended += count;
+ return result;
+ }
+
+ public final int appendNotNulls(int count) {
+ assert (!(dataType() instanceof StructType));
+ reserve(elementsAppended + count);
+ int result = elementsAppended;
+ putNotNulls(elementsAppended, count);
+ elementsAppended += count;
+ return result;
+ }
+
+ public final int appendBoolean(boolean v) {
+ reserve(elementsAppended + 1);
+ putBoolean(elementsAppended, v);
+ return elementsAppended++;
+ }
+
+ public final int appendBooleans(int count, boolean v) {
+ reserve(elementsAppended + count);
+ int result = elementsAppended;
+ putBooleans(elementsAppended, count, v);
+ elementsAppended += count;
+ return result;
+ }
+
+ public final int appendByte(byte v) {
+ reserve(elementsAppended + 1);
+ putByte(elementsAppended, v);
+ return elementsAppended++;
+ }
+
+ public final int appendBytes(int count, byte v) {
+ reserve(elementsAppended + count);
+ int result = elementsAppended;
+ putBytes(elementsAppended, count, v);
+ elementsAppended += count;
+ return result;
+ }
+
+ public final int appendBytes(int length, byte[] src, int offset) {
+ reserve(elementsAppended + length);
+ int result = elementsAppended;
+ putBytes(elementsAppended, length, src, offset);
+ elementsAppended += length;
+ return result;
+ }
+
+ public final int appendShort(short v) {
+ reserve(elementsAppended + 1);
+ putShort(elementsAppended, v);
+ return elementsAppended++;
+ }
+
+ public final int appendShorts(int count, short v) {
+ reserve(elementsAppended + count);
+ int result = elementsAppended;
+ putShorts(elementsAppended, count, v);
+ elementsAppended += count;
+ return result;
+ }
+
+ public final int appendShorts(int length, short[] src, int offset) {
+ reserve(elementsAppended + length);
+ int result = elementsAppended;
+ putShorts(elementsAppended, length, src, offset);
+ elementsAppended += length;
+ return result;
+ }
+
+ public final int appendInt(int v) {
+ reserve(elementsAppended + 1);
+ putInt(elementsAppended, v);
+ return elementsAppended++;
+ }
+
+ public final int appendInts(int count, int v) {
+ reserve(elementsAppended + count);
+ int result = elementsAppended;
+ putInts(elementsAppended, count, v);
+ elementsAppended += count;
+ return result;
+ }
+
+ public final int appendInts(int length, int[] src, int offset) {
+ reserve(elementsAppended + length);
+ int result = elementsAppended;
+ putInts(elementsAppended, length, src, offset);
+ elementsAppended += length;
+ return result;
+ }
+
+ public final int appendLong(long v) {
+ reserve(elementsAppended + 1);
+ putLong(elementsAppended, v);
+ return elementsAppended++;
+ }
+
+ public final int appendLongs(int count, long v) {
+ reserve(elementsAppended + count);
+ int result = elementsAppended;
+ putLongs(elementsAppended, count, v);
+ elementsAppended += count;
+ return result;
+ }
+
+ public final int appendLongs(int length, long[] src, int offset) {
+ reserve(elementsAppended + length);
+ int result = elementsAppended;
+ putLongs(elementsAppended, length, src, offset);
+ elementsAppended += length;
+ return result;
+ }
+
+ public final int appendFloat(float v) {
+ reserve(elementsAppended + 1);
+ putFloat(elementsAppended, v);
+ return elementsAppended++;
+ }
+
+ public final int appendFloats(int count, float v) {
+ reserve(elementsAppended + count);
+ int result = elementsAppended;
+ putFloats(elementsAppended, count, v);
+ elementsAppended += count;
+ return result;
+ }
+
+ public final int appendFloats(int length, float[] src, int offset) {
+ reserve(elementsAppended + length);
+ int result = elementsAppended;
+ putFloats(elementsAppended, length, src, offset);
+ elementsAppended += length;
+ return result;
+ }
+
+ public final int appendDouble(double v) {
+ reserve(elementsAppended + 1);
+ putDouble(elementsAppended, v);
+ return elementsAppended++;
+ }
+
+ public final int appendDoubles(int count, double v) {
+ reserve(elementsAppended + count);
+ int result = elementsAppended;
+ putDoubles(elementsAppended, count, v);
+ elementsAppended += count;
+ return result;
+ }
+
+ public final int appendDoubles(int length, double[] src, int offset) {
+ reserve(elementsAppended + length);
+ int result = elementsAppended;
+ putDoubles(elementsAppended, length, src, offset);
+ elementsAppended += length;
+ return result;
+ }
+
+ public final int appendByteArray(byte[] value, int offset, int length) {
+ int copiedOffset = arrayData().appendBytes(length, value, offset);
+ reserve(elementsAppended + 1);
+ putArray(elementsAppended, copiedOffset, length);
+ return elementsAppended++;
+ }
+
+ public final int appendArray(int length) {
+ reserve(elementsAppended + 1);
+ putArray(elementsAppended, arrayData().elementsAppended, length);
+ return elementsAppended++;
+ }
+
+ /**
+ * Appends a NULL struct. This *has* to be used for structs instead of appendNull() as this
+ * recursively appends a NULL to its children.
+   * This logic is kept out of the general appendNull() implementation to keep the more
+   * common non-struct case fast.
+ */
+ public final int appendStruct(boolean isNull) {
+ if (isNull) {
+ appendNull();
+ for (ColumnVector c: childColumns) {
+ if (c.type instanceof StructType) {
+ ((WritableColumnVector) c).appendStruct(true);
+ } else {
+ ((WritableColumnVector) c).appendNull();
+ }
+ }
+ } else {
+ appendNotNull();
+ }
+ return elementsAppended;
+ }
+
+ /**
+ * Returns the data for the underlying array.
+ */
+ @Override
+ public WritableColumnVector arrayData() { return childColumns[0]; }
+
+ /**
+ * Returns the ordinal's child data column.
+ */
+ @Override
+ public WritableColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; }
+
+ /**
+   * Returns the number of elements appended so far.
+ */
+ public final int getElementsAppended() { return elementsAppended; }
+
+ /**
+ * Marks this column as being constant.
+ */
+ public final void setIsConstant() { isConstant = true; }
+
+ /**
+ * Maximum number of rows that can be stored in this column.
+ */
+ protected int capacity;
+
+ /**
+ * Upper limit for the maximum capacity for this column.
+ */
+ @VisibleForTesting
+ protected int MAX_CAPACITY = Integer.MAX_VALUE;
+
+ /**
+ * Number of nulls in this column. This is an optimization for the reader, to skip NULL checks.
+ */
+ protected int numNulls;
+
+ /**
+ * True if there is at least one NULL byte set. This is an optimization for the writer, to skip
+ * having to clear NULL bits.
+ */
+ protected boolean anyNullsSet;
+
+ /**
+ * True if this column's values are fixed. This means the column values never change, even
+ * across resets.
+ */
+ protected boolean isConstant;
+
+ /**
+   * Default number of child elements reserved per array value; the child column grows as necessary.
+ */
+ protected static final int DEFAULT_ARRAY_LENGTH = 4;
+
+ /**
+ * Current write cursor (row index) when appending data.
+ */
+ protected int elementsAppended;
+
+ /**
+ * If this is a nested type (array or struct), the column for the child data.
+ */
+ protected WritableColumnVector[] childColumns;
+
+ /**
+ * Update the dictionary.
+ */
+ public void setDictionary(Dictionary dictionary) {
+ this.dictionary = dictionary;
+ }
+
+ /**
+   * Reserve an integer column for dictionary ids.
+ */
+ public WritableColumnVector reserveDictionaryIds(int capacity) {
+ WritableColumnVector dictionaryIds = (WritableColumnVector) this.dictionaryIds;
+ if (dictionaryIds == null) {
+ dictionaryIds = reserveNewColumn(capacity, DataTypes.IntegerType);
+ this.dictionaryIds = dictionaryIds;
+ } else {
+ dictionaryIds.reset();
+ dictionaryIds.reserve(capacity);
+ }
+ return dictionaryIds;
+ }
+
+ /**
+   * Returns the underlying integer column holding the dictionary ids.
+ */
+ @Override
+ public WritableColumnVector getDictionaryIds() {
+ return (WritableColumnVector) dictionaryIds;
+ }
+
+ /**
+ * Reserve a new column.
+ */
+ protected abstract WritableColumnVector reserveNewColumn(int capacity, DataType type);
+
+ /**
+ * Sets up the common state and also handles creating the child columns if this is a nested
+ * type.
+ */
+ protected WritableColumnVector(int capacity, DataType type) {
+ super(type);
+ this.capacity = capacity;
+
+ if (type instanceof ArrayType || type instanceof BinaryType || type instanceof StringType
+ || DecimalType.isByteArrayDecimalType(type)) {
+ DataType childType;
+ int childCapacity = capacity;
+ if (type instanceof ArrayType) {
+ childType = ((ArrayType)type).elementType();
+ } else {
+ childType = DataTypes.ByteType;
+ childCapacity *= DEFAULT_ARRAY_LENGTH;
+ }
+ this.childColumns = new WritableColumnVector[1];
+ this.childColumns[0] = reserveNewColumn(childCapacity, childType);
+ this.resultArray = new ColumnVector.Array(this.childColumns[0]);
+ this.resultStruct = null;
+ } else if (type instanceof StructType) {
+ StructType st = (StructType)type;
+ this.childColumns = new WritableColumnVector[st.fields().length];
+ for (int i = 0; i < childColumns.length; ++i) {
+ this.childColumns[i] = reserveNewColumn(capacity, st.fields()[i].dataType());
+ }
+ this.resultArray = null;
+ this.resultStruct = new ColumnarBatch.Row(this.childColumns);
+ } else if (type instanceof CalendarIntervalType) {
+ // Two columns. Months as int. Microseconds as Long.
+ this.childColumns = new WritableColumnVector[2];
+ this.childColumns[0] = reserveNewColumn(capacity, DataTypes.IntegerType);
+ this.childColumns[1] = reserveNewColumn(capacity, DataTypes.LongType);
+ this.resultArray = null;
+ this.resultStruct = new ColumnarBatch.Row(this.childColumns);
+ } else {
+ this.childColumns = null;
+ this.resultArray = null;
+ this.resultStruct = null;
+ }
+ }
+}
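For context, a minimal sketch of how the append APIs defined above are driven. This is hypothetical: it assumes a concrete subclass such as OnHeapColumnVector with a (capacity, type) constructor, and uses only methods declared in this class.

    import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
    import org.apache.spark.sql.types.DataTypes

    // Append-style writes: sizes need not be known up front; reserve() grows the storage.
    val col = new OnHeapColumnVector(4, DataTypes.IntegerType)  // assumed (capacity, type) constructor
    col.appendInt(1)
    col.appendNull()
    col.appendInts(3, 7)                   // appends the value 7 three times
    assert(col.getElementsAppended == 5)   // 1 value + 1 null + 3 values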
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 10b28ce812afc..c69acc413e87f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -27,7 +27,6 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
-import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.{DataSource, FailureSafeParser}
import org.apache.spark.sql.execution.datasources.csv._
@@ -313,6 +312,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* (e.g. 00012)
* `allowBackslashEscapingAnyCharacter` (default `false`): allows accepting quoting of all
* character using backslash quoting mechanism
+ * `allowUnquotedControlChars` (default `false`): whether to allow JSON strings to contain
+ * unquoted control characters (ASCII characters with value less than 32, including tab and
+ * line feed characters).
* `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
* during parsing.
*
@@ -407,10 +409,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
parsedOptions.columnNameOfCorruptRecord)
iter.flatMap(parser.parse)
}
-
- Dataset.ofRows(
- sparkSession,
- LogicalRDD(schema.toAttributes, parsed)(sparkSession))
+ sparkSession.internalCreateDataFrame(parsed, schema, isStreaming = jsonDataset.isStreaming)
}
/**
@@ -470,10 +469,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
parsedOptions.columnNameOfCorruptRecord)
iter.flatMap(parser.parse)
}
-
- Dataset.ofRows(
- sparkSession,
- LogicalRDD(schema.toAttributes, parsed)(sparkSession))
+ sparkSession.internalCreateDataFrame(parsed, schema, isStreaming = csvDataset.isStreaming)
}
/**
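A short usage sketch for the new JSON option documented above, assuming an existing SparkSession `spark`; the input path is hypothetical.

    // Tolerate raw control characters (e.g. tabs) inside JSON string values.
    val df = spark.read
      .option("allowUnquotedControlChars", "true")
      .json("/data/raw/events.json")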
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 877051a60e910..07347d2748544 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -371,14 +371,14 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
case (true, SaveMode.Overwrite) =>
// Get all input data source or hive relations of the query.
val srcRelations = df.logicalPlan.collect {
- case LogicalRelation(src: BaseRelation, _, _) => src
+ case LogicalRelation(src: BaseRelation, _, _, _) => src
case relation: HiveTableRelation => relation.tableMeta.identifier
}
val tableRelation = df.sparkSession.table(tableIdentWithDB).queryExecution.analyzed
EliminateSubqueryAliases(tableRelation) match {
// check if the table is a data source table (the relation is a BaseRelation).
- case LogicalRelation(dest: BaseRelation, _, _) if srcRelations.contains(dest) =>
+ case LogicalRelation(dest: BaseRelation, _, _, _) if srcRelations.contains(dest) =>
throw new AnalysisException(
s"Cannot overwrite table $tableName that is also being read from")
// check hive table relation when overwrite mode
@@ -517,9 +517,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
*
* You can set the following ORC-specific option(s) for writing ORC files:
*
- * - `compression` (default `snappy`): compression codec to use when saving to file. This can be
- * one of the known case-insensitive shorten names(`none`, `snappy`, `zlib`, and `lzo`).
- * This will override `orc.compress`.
+ * - `compression` (default is the value specified in `spark.sql.orc.compression.codec`):
+ * compression codec to use when saving to file. This can be one of the known case-insensitive
+ * shortened names (`none`, `snappy`, `zlib`, and `lzo`). This will override
+ * `orc.compress` and `spark.sql.orc.compression.codec`. If `orc.compress` is given,
+ * it overrides `spark.sql.orc.compression.codec`.
*
*
* @since 1.5.0
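A usage sketch for the ORC `compression` option described above, assuming an existing DataFrame `df`; the output path is hypothetical.

    // An explicit codec overrides orc.compress and the session-level default.
    df.write
      .option("compression", "zlib")
      .orc("/tmp/events_orc")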
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index a9887eb95279f..ab0c4126bcbdd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -569,7 +569,8 @@ class Dataset[T] private[sql](
logicalPlan.output,
internalRdd,
outputPartitioning,
- physicalPlan.outputOrdering
+ physicalPlan.outputOrdering,
+ isStreaming
)(sparkSession)).as[T]
}
@@ -1848,11 +1849,44 @@ class Dataset[T] private[sql](
Except(logicalPlan, other.logicalPlan)
}
+ /**
+ * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement),
+ * using a user-supplied seed.
+ *
+ * @param fraction Fraction of rows to generate, range [0.0, 1.0].
+ * @param seed Seed for sampling.
+ *
+ * @note This is NOT guaranteed to provide exactly the fraction of the count
+ * of the given [[Dataset]].
+ *
+ * @group typedrel
+ * @since 2.3.0
+ */
+ def sample(fraction: Double, seed: Long): Dataset[T] = {
+ sample(withReplacement = false, fraction = fraction, seed = seed)
+ }
+
+ /**
+ * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement),
+ * using a random seed.
+ *
+ * @param fraction Fraction of rows to generate, range [0.0, 1.0].
+ *
+ * @note This is NOT guaranteed to provide exactly the fraction of the count
+ * of the given [[Dataset]].
+ *
+ * @group typedrel
+ * @since 2.3.0
+ */
+ def sample(fraction: Double): Dataset[T] = {
+ sample(withReplacement = false, fraction = fraction)
+ }
+
/**
* Returns a new [[Dataset]] by sampling a fraction of rows, using a user-supplied seed.
*
* @param withReplacement Sample with replacement or not.
- * @param fraction Fraction of rows to generate.
+ * @param fraction Fraction of rows to generate, range [0.0, 1.0].
* @param seed Seed for sampling.
*
* @note This is NOT guaranteed to provide exactly the fraction of the count
@@ -1871,7 +1905,7 @@ class Dataset[T] private[sql](
* Returns a new [[Dataset]] by sampling a fraction of rows, using a random seed.
*
* @param withReplacement Sample with replacement or not.
- * @param fraction Fraction of rows to generate.
+ * @param fraction Fraction of rows to generate, range [0.0, 1.0].
*
* @note This is NOT guaranteed to provide exactly the fraction of the total count
* of the given [[Dataset]].
@@ -2201,7 +2235,7 @@ class Dataset[T] private[sql](
}
cols
}
- Deduplicate(groupCols, logicalPlan, isStreaming)
+ Deduplicate(groupCols, logicalPlan)
}
/**
@@ -2542,8 +2576,9 @@ class Dataset[T] private[sql](
* @group action
* @since 1.6.0
*/
- def foreachPartition(func: ForeachPartitionFunction[T]): Unit =
- foreachPartition(it => func.call(it.asJava))
+ def foreachPartition(func: ForeachPartitionFunction[T]): Unit = {
+ foreachPartition((it: Iterator[T]) => func.call(it.asJava))
+ }
/**
* Returns the first `n` rows in the Dataset.
@@ -2856,8 +2891,8 @@ class Dataset[T] private[sql](
*
* Global temporary view is cross-session. Its lifetime is the lifetime of the Spark application,
* i.e. it will be automatically dropped when the application terminates. It's tied to a system
- * preserved database `_global_temp`, and we must use the qualified name to refer a global temp
- * view, e.g. `SELECT * FROM _global_temp.view1`.
+ * preserved database `global_temp`, and we must use the qualified name to refer to a global temp
+ * view, e.g. `SELECT * FROM global_temp.view1`.
*
* @group basic
* @since 2.2.0
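A sketch of the qualified-name access pattern described above, assuming an existing DataFrame `df` and SparkSession `spark`; the view name is hypothetical.

    // Global temp views live in the system database `global_temp` for the whole application.
    df.createGlobalTempView("view1")
    spark.sql("SELECT * FROM global_temp.view1").show()
    // The view is also visible from other sessions of the same application.
    spark.newSession().sql("SELECT * FROM global_temp.view1").show()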
@@ -2961,7 +2996,7 @@ class Dataset[T] private[sql](
*/
def inputFiles: Array[String] = {
val files: Seq[String] = queryExecution.optimizedPlan.collect {
- case LogicalRelation(fsBasedRelation: FileRelation, _, _) =>
+ case LogicalRelation(fsBasedRelation: FileRelation, _, _, _) =>
fsBasedRelation.inputFiles
case fr: FileRelation =>
fr.inputFiles
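The two new sample overloads added above can be exercised as follows; a sketch assuming an existing Dataset `ds`. The returned size is only approximately fraction * count.

    val sampled      = ds.sample(0.1)        // without replacement, random seed
    val reproducible = ds.sample(0.1, 42L)   // without replacement, fixed seed for repeatable runs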
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 7fde6e9469e5e..af6018472cb03 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -420,8 +420,11 @@ class SQLContext private[sql](val sparkSession: SparkSession)
* converted to Catalyst rows.
*/
private[sql]
- def internalCreateDataFrame(catalystRows: RDD[InternalRow], schema: StructType) = {
- sparkSession.internalCreateDataFrame(catalystRows, schema)
+ def internalCreateDataFrame(
+ catalystRows: RDD[InternalRow],
+ schema: StructType,
+ isStreaming: Boolean = false) = {
+ sparkSession.internalCreateDataFrame(catalystRows, schema, isStreaming)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 0e46736d007c4..f2695d00a5373 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -559,20 +559,23 @@ class SparkSession private(
}
/**
- * Creates a `DataFrame` from an RDD[Row].
- * User can specify whether the input rows should be converted to Catalyst rows.
+ * Creates a `DataFrame` from an `RDD[InternalRow]`.
*/
private[sql] def internalCreateDataFrame(
catalystRows: RDD[InternalRow],
- schema: StructType): DataFrame = {
+ schema: StructType,
+ isStreaming: Boolean = false): DataFrame = {
// TODO: use MutableProjection when rowRDD is another DataFrame and the applied
// schema differs from the existing schema on any field data type.
- val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self)
+ val logicalPlan = LogicalRDD(
+ schema.toAttributes,
+ catalystRows,
+ isStreaming = isStreaming)(self)
Dataset.ofRows(self, logicalPlan)
}
/**
- * Creates a `DataFrame` from an RDD[Row].
+ * Creates a `DataFrame` from an `RDD[Row]`.
* User can specify whether the input rows should be converted to Catalyst rows.
*/
private[sql] def createDataFrame(
@@ -585,10 +588,9 @@ class SparkSession private(
val encoder = RowEncoder(schema)
rowRDD.map(encoder.toRow)
} else {
- rowRDD.map{r: Row => InternalRow.fromSeq(r.toSeq)}
+ rowRDD.map { r: Row => InternalRow.fromSeq(r.toSeq) }
}
- val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self)
- Dataset.ofRows(self, logicalPlan)
+ internalCreateDataFrame(catalystRows, schema)
}
@@ -733,13 +735,15 @@ class SparkSession private(
}
/**
- * Apply a schema defined by the schema to an RDD. It is only used by PySpark.
+ * Apply `schema` to an RDD.
+ *
+ * @note Used by PySpark only
*/
private[sql] def applySchemaToPythonRDD(
rdd: RDD[Array[Any]],
schema: StructType): DataFrame = {
val rowRdd = rdd.map(r => python.EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow])
- Dataset.ofRows(self, LogicalRDD(schema.toAttributes, rowRdd)(self))
+ internalCreateDataFrame(rowRdd, schema)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
index 74a47da2deef2..1afe83ea3539e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
@@ -33,6 +33,8 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
val inMemoryTableScan: InMemoryTableScanExec = null
+ def vectorTypes: Option[Seq[String]] = None
+
override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time"))
@@ -79,17 +81,19 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
val scanTimeTotalNs = ctx.freshName("scanTime")
ctx.addMutableState("long", scanTimeTotalNs, s"$scanTimeTotalNs = 0;")
- val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch"
+ val columnarBatchClz = classOf[ColumnarBatch].getName
val batch = ctx.freshName("batch")
ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;")
- val columnVectorClz = "org.apache.spark.sql.execution.vectorized.ColumnVector"
val idx = ctx.freshName("batchIdx")
ctx.addMutableState("int", idx, s"$idx = 0;")
val colVars = output.indices.map(i => ctx.freshName("colInstance" + i))
- val columnAssigns = colVars.zipWithIndex.map { case (name, i) =>
- ctx.addMutableState(columnVectorClz, name, s"$name = null;")
- s"$name = $batch.column($i);"
+ val columnVectorClzs = vectorTypes.getOrElse(
+ Seq.fill(colVars.size)(classOf[ColumnVector].getName))
+ val columnAssigns = colVars.zip(columnVectorClzs).zipWithIndex.map {
+ case ((name, columnVectorClz), i) =>
+ ctx.addMutableState(columnVectorClz, name, s"$name = null;")
+ s"$name = ($columnVectorClz) $batch.column($i);"
}
val nextBatch = ctx.freshName("nextBatch")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index a0229e73c3ac3..88c2bc721771c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -69,7 +69,7 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport {
* Shorthand for calling redactString() without specifying redacting rules
*/
private def redact(text: String): String = {
- Utils.redact(SparkSession.getActiveSession.get.sparkContext.conf, text)
+ Utils.redact(SparkSession.getActiveSession.map(_.sparkContext.conf).orNull, text)
}
}
@@ -174,6 +174,11 @@ case class FileSourceScanExec(
false
}
+ override def vectorTypes: Option[Seq[String]] =
+ relation.fileFormat.vectorTypes(
+ requiredSchema = requiredSchema,
+ partitionSchema = relation.partitionSchema)
+
@transient private lazy val selectedPartitions: Seq[PartitionDirectory] = {
val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L)
val startTime = System.nanoTime()
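The new vectorTypes hook lets a file format advertise the concrete ColumnVector classes it produces so generated code can cast to them instead of relying on virtual calls. A hedged sketch of such an override follows; the surrounding FileFormat implementation and the OnHeapColumnVector choice are assumptions, and it is shown as a plain function rather than an `override def`.

    import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
    import org.apache.spark.sql.types.StructType

    // Report one concrete vector class name per output column
    // (required data columns followed by partition columns).
    def vectorTypes(requiredSchema: StructType, partitionSchema: StructType): Option[Seq[String]] =
      Option(Seq.fill(requiredSchema.fields.length + partitionSchema.fields.length)(
        classOf[OnHeapColumnVector].getName))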
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index dcb918eeb9d10..f3555508185fe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -125,7 +125,8 @@ case class LogicalRDD(
output: Seq[Attribute],
rdd: RDD[InternalRow],
outputPartitioning: Partitioning = UnknownPartitioning(0),
- outputOrdering: Seq[SortOrder] = Nil)(session: SparkSession)
+ outputOrdering: Seq[SortOrder] = Nil,
+ override val isStreaming: Boolean = false)(session: SparkSession)
extends LeafNode with MultiInstanceRelation {
override protected final def otherCopyArgs: Seq[AnyRef] = session :: Nil
@@ -150,11 +151,12 @@ case class LogicalRDD(
output.map(rewrite),
rdd,
rewrittenPartitioning,
- rewrittenOrdering
+ rewrittenOrdering,
+ isStreaming
)(session).asInstanceOf[this.type]
}
- override protected def stringArgs: Iterator[Any] = Iterator(output)
+ override protected def stringArgs: Iterator[Any] = Iterator(output, isStreaming)
override def computeStats(): Statistics = Statistics(
// TODO: Instead of returning a default value here, find a way to return a meaningful size
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
index 301c4f02647d5..18f6f697bc857 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala
@@ -94,10 +94,10 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic
child transform {
case plan if plan eq relation =>
relation match {
- case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _) =>
+ case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, isStreaming) =>
val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l)
val partitionData = fsRelation.location.listFiles(Nil, Nil)
- LocalRelation(partAttrs, partitionData.map(_.values))
+ LocalRelation(partAttrs, partitionData.map(_.values), isStreaming)
case relation: HiveTableRelation =>
val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation)
@@ -130,7 +130,7 @@ case class OptimizeMetadataOnlyQuery(catalog: SessionCatalog) extends Rule[Logic
object PartitionedRelation {
def unapply(plan: LogicalPlan): Option[(AttributeSet, LogicalPlan)] = plan match {
- case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _)
+ case l @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _)
if fsRelation.partitionSchema.nonEmpty =>
val partAttrs = getPartitionAttrs(fsRelation.partitionSchema.map(_.name), l)
Some((AttributeSet(partAttrs), l))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index c7277c21cebb2..b263f100e6068 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -56,15 +56,17 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
protected def sparkContext = sqlContext.sparkContext
- // sqlContext will be null when we are being deserialized on the slaves. In this instance
- // the value of subexpressionEliminationEnabled will be set by the deserializer after the
- // constructor has run.
+ // sqlContext will be null when SparkPlan nodes are created without an active session.
+ // So far, this only happens in test cases.
val subexpressionEliminationEnabled: Boolean = if (sqlContext != null) {
sqlContext.conf.subexpressionEliminationEnabled
} else {
false
}
+ // whether we should fall back to interpretation when hitting compilation errors caused by codegen
+ private val codeGenFallBack = (sqlContext == null) || sqlContext.conf.codegenFallback
+
/** Overridden make copy also propagates sqlContext to copied plan. */
override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = {
SparkSession.setActiveSession(sqlContext.sparkSession)
@@ -370,8 +372,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
try {
GeneratePredicate.generate(expression, inputSchema)
} catch {
- case e @ (_: JaninoRuntimeException | _: CompileException)
- if sqlContext == null || sqlContext.conf.wholeStageFallback =>
+ case _ @ (_: JaninoRuntimeException | _: CompileException) if codeGenFallBack =>
genInterpretedPredicate(expression, inputSchema)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index 06b69625fb53e..2a2315896831c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -17,16 +17,18 @@
package org.apache.spark.sql.execution
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties
+
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.metric.SQLMetricInfo
-import org.apache.spark.util.Utils
/**
* :: DeveloperApi ::
* Stores information about a SQL SparkPlan.
*/
@DeveloperApi
+@JsonIgnoreProperties(Array("metadata")) // The metadata field was removed in Spark 2.3.
class SparkPlanInfo(
val nodeName: String,
val simpleString: String,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index d4414b6f78ca2..6de9ea0efd2c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.parser._
import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.command._
-import org.apache.spark.sql.execution.datasources.{CreateTable, _}
+import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.internal.{HiveSerDe, SQLConf, VariableSubstitution}
import org.apache.spark.sql.types.StructType
@@ -90,30 +90,40 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
}
/**
- * Create an [[AnalyzeTableCommand]] command or an [[AnalyzeColumnCommand]] command.
- * Example SQL for analyzing table :
+ * Create an [[AnalyzeTableCommand]], an [[AnalyzePartitionCommand]], or an
+ * [[AnalyzeColumnCommand]] command.
+ * Example SQL for analyzing a table or a set of partitions :
* {{{
- * ANALYZE TABLE table COMPUTE STATISTICS [NOSCAN];
+ * ANALYZE TABLE [db_name.]tablename [PARTITION (partcol1[=val1], partcol2[=val2], ...)]
+ * COMPUTE STATISTICS [NOSCAN];
* }}}
+ *
* Example SQL for analyzing columns :
* {{{
- * ANALYZE TABLE table COMPUTE STATISTICS FOR COLUMNS column1, column2;
+ * ANALYZE TABLE [db_name.]tablename COMPUTE STATISTICS FOR COLUMNS column1, column2;
* }}}
*/
override def visitAnalyze(ctx: AnalyzeContext): LogicalPlan = withOrigin(ctx) {
- if (ctx.partitionSpec != null) {
- logWarning(s"Partition specification is ignored: ${ctx.partitionSpec.getText}")
+ if (ctx.identifier != null &&
+ ctx.identifier.getText.toLowerCase(Locale.ROOT) != "noscan") {
+ throw new ParseException(s"Expected `NOSCAN` instead of `${ctx.identifier.getText}`", ctx)
}
- if (ctx.identifier != null) {
- if (ctx.identifier.getText.toLowerCase(Locale.ROOT) != "noscan") {
- throw new ParseException(s"Expected `NOSCAN` instead of `${ctx.identifier.getText}`", ctx)
+
+ val table = visitTableIdentifier(ctx.tableIdentifier)
+ if (ctx.identifierSeq() == null) {
+ if (ctx.partitionSpec != null) {
+ AnalyzePartitionCommand(table, visitPartitionSpec(ctx.partitionSpec),
+ noscan = ctx.identifier != null)
+ } else {
+ AnalyzeTableCommand(table, noscan = ctx.identifier != null)
}
- AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier))
- } else if (ctx.identifierSeq() == null) {
- AnalyzeTableCommand(visitTableIdentifier(ctx.tableIdentifier), noscan = false)
} else {
+ if (ctx.partitionSpec != null) {
+ logWarning("Partition specification is ignored when collecting column statistics: " +
+ ctx.partitionSpec.getText)
+ }
AnalyzeColumnCommand(
- visitTableIdentifier(ctx.tableIdentifier),
+ table,
visitIdentifierSeq(ctx.identifierSeq()))
}
}
@@ -320,10 +330,16 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
* Create a [[DescribeTableCommand]] logical plan.
*/
override def visitDescribeTable(ctx: DescribeTableContext): LogicalPlan = withOrigin(ctx) {
- // Describe column are not supported yet. Return null and let the parser decide
- // what to do with this (create an exception or pass it on to a different system).
+ val isExtended = ctx.EXTENDED != null || ctx.FORMATTED != null
if (ctx.describeColName != null) {
- null
+ if (ctx.partitionSpec != null) {
+ throw new ParseException("DESC TABLE COLUMN for a specific partition is not supported", ctx)
+ } else {
+ DescribeColumnCommand(
+ visitTableIdentifier(ctx.tableIdentifier),
+ ctx.describeColName.nameParts.asScala.map(_.getText),
+ isExtended)
+ }
} else {
val partitionSpec = if (ctx.partitionSpec != null) {
// According to the syntax, visitPartitionSpec returns `Map[String, Option[String]]`.
@@ -338,7 +354,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
DescribeTableCommand(
visitTableIdentifier(ctx.tableIdentifier),
partitionSpec,
- ctx.EXTENDED != null || ctx.FORMATTED != null)
+ isExtended)
}
}
@@ -375,7 +391,8 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
* ]
* [LOCATION path]
* [COMMENT table_comment]
- * [AS select_statement];
+ * [TBLPROPERTIES (property_name=property_value, ...)]
+ * [[AS] select_statement];
* }}}
*/
override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) {
@@ -390,6 +407,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
Option(ctx.partitionColumnNames)
.map(visitIdentifierList(_).toArray)
.getOrElse(Array.empty[String])
+ val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty)
val bucketSpec = Option(ctx.bucketSpec()).map(visitBucketSpec)
val location = Option(ctx.locationSpec).map(visitLocationSpec)
@@ -400,7 +418,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
"LOCATION and 'path' in OPTIONS are both used to indicate the custom table path, " +
"you can only specify one of them.", ctx)
}
- val customLocation = storage.locationUri.orElse(location.map(CatalogUtils.stringToURI(_)))
+ val customLocation = storage.locationUri.orElse(location.map(CatalogUtils.stringToURI))
val tableType = if (customLocation.isDefined) {
CatalogTableType.EXTERNAL
@@ -416,6 +434,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
provider = Some(provider),
partitionColumnNames = partitionColumnNames,
bucketSpec = bucketSpec,
+ properties = properties,
comment = Option(ctx.comment).map(string))
// Determine the storage mode.
@@ -1499,4 +1518,81 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
query: LogicalPlan): LogicalPlan = {
RepartitionByExpression(expressions, query, conf.numShufflePartitions)
}
+
+ /**
+ * Return the parameters for [[InsertIntoDir]] logical plan.
+ *
+ * Expected format:
+ * {{{
+ * INSERT OVERWRITE DIRECTORY
+ * [path]
+ * [OPTIONS table_property_list]
+ * select_statement;
+ * }}}
+ */
+ override def visitInsertOverwriteDir(
+ ctx: InsertOverwriteDirContext): InsertDirParams = withOrigin(ctx) {
+ if (ctx.LOCAL != null) {
+ throw new ParseException(
+ "LOCAL is not supported in INSERT OVERWRITE DIRECTORY to data source", ctx)
+ }
+
+ val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty)
+ var storage = DataSource.buildStorageFormatFromOptions(options)
+
+ val path = Option(ctx.path).map(string).getOrElse("")
+
+ if (!(path.isEmpty ^ storage.locationUri.isEmpty)) {
+ throw new ParseException(
+ "Directory path and 'path' in OPTIONS should be specified one, but not both", ctx)
+ }
+
+ if (!path.isEmpty) {
+ val customLocation = Some(CatalogUtils.stringToURI(path))
+ storage = storage.copy(locationUri = customLocation)
+ }
+
+ val provider = ctx.tableProvider.qualifiedName.getText
+
+ (false, storage, Some(provider))
+ }
+
+ /**
+ * Return the parameters for [[InsertIntoDir]] logical plan.
+ *
+ * Expected format:
+ * {{{
+ * INSERT OVERWRITE [LOCAL] DIRECTORY
+ * path
+ * [ROW FORMAT row_format]
+ * [STORED AS file_format]
+ * select_statement;
+ * }}}
+ */
+ override def visitInsertOverwriteHiveDir(
+ ctx: InsertOverwriteHiveDirContext): InsertDirParams = withOrigin(ctx) {
+ validateRowFormatFileFormat(ctx.rowFormat, ctx.createFileFormat, ctx)
+ val rowStorage = Option(ctx.rowFormat).map(visitRowFormat)
+ .getOrElse(CatalogStorageFormat.empty)
+ val fileStorage = Option(ctx.createFileFormat).map(visitCreateFileFormat)
+ .getOrElse(CatalogStorageFormat.empty)
+
+ val path = string(ctx.path)
+ // The path field is required
+ if (path.isEmpty) {
+ operationNotAllowed("INSERT OVERWRITE DIRECTORY must be accompanied by path", ctx)
+ }
+
+ val defaultStorage = HiveSerDe.getDefaultStorage(conf)
+
+ val storage = CatalogStorageFormat(
+ locationUri = Some(CatalogUtils.stringToURI(path)),
+ inputFormat = fileStorage.inputFormat.orElse(defaultStorage.inputFormat),
+ outputFormat = fileStorage.outputFormat.orElse(defaultStorage.outputFormat),
+ serde = rowStorage.serde.orElse(fileStorage.serde).orElse(defaultStorage.serde),
+ compressed = false,
+ properties = rowStorage.properties ++ fileStorage.properties)
+
+ (ctx.LOCAL != null, storage, Some(DDLUtils.HIVE_PROVIDER))
+ }
}
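The parser changes above surface several new SQL forms; a sketch of invoking them through spark.sql, assuming an existing SparkSession `spark`. Table, column, partition, and path names are hypothetical.

    // Partition-level statistics (AnalyzePartitionCommand).
    spark.sql("ANALYZE TABLE sales PARTITION (ds='2017-08-01') COMPUTE STATISTICS")

    // Column-level describe (DescribeColumnCommand).
    spark.sql("DESC EXTENDED sales amount")

    // Write query results straight to a directory using a data source provider.
    spark.sql("INSERT OVERWRITE DIRECTORY '/tmp/sales_out' USING parquet SELECT * FROM sales")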
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 2e8ce4541865d..6b16408e27840 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -63,29 +63,24 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
*/
object SpecialLimits extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case logical.ReturnAnswer(rootPlan) => rootPlan match {
- case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
- execution.TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
- case logical.Limit(
- IntegerLiteral(limit),
- logical.Project(projectList, logical.Sort(order, true, child))) =>
- execution.TakeOrderedAndProjectExec(
- limit, order, projectList, planLater(child)) :: Nil
- case logical.Limit(IntegerLiteral(limit), child) =>
- // Normally wrapping child with `LocalLimitExec` here is a no-op, because
- // `CollectLimitExec.executeCollect` will call `LocalLimitExec.executeTake`, which
- // calls `child.executeTake`. If child supports whole stage codegen, adding this
- // `LocalLimitExec` can stop the processing of whole stage codegen and trigger the
- // resource releasing work, after we consume `limit` rows.
- execution.CollectLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil
+ case ReturnAnswer(rootPlan) => rootPlan match {
+ case Limit(IntegerLiteral(limit), Sort(order, true, child)) =>
+ TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
+ case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) =>
+ TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
+ case Limit(IntegerLiteral(limit), child) =>
+ // With whole stage codegen, Spark releases resources only when all the output data of the
+ // query plan are consumed. It's possible that `CollectLimitExec` only consumes a little
+ // data from child plan and finishes the query without releasing resources. Here we wrap
+ // the child plan with `LocalLimitExec`, to stop the processing of whole stage codegen and
+ // trigger the resource releasing work, after we consume `limit` rows.
+ CollectLimitExec(limit, LocalLimitExec(limit, planLater(child))) :: Nil
case other => planLater(other) :: Nil
}
- case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
- execution.TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
- case logical.Limit(
- IntegerLiteral(limit), logical.Project(projectList, logical.Sort(order, true, child))) =>
- execution.TakeOrderedAndProjectExec(
- limit, order, projectList, planLater(child)) :: Nil
+ case Limit(IntegerLiteral(limit), Sort(order, true, child)) =>
+ TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
+ case Limit(IntegerLiteral(limit), Project(projectList, Sort(order, true, child))) =>
+ TakeOrderedAndProjectExec(limit, order, projectList, planLater(child)) :: Nil
case _ => Nil
}
}
@@ -226,12 +221,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
/**
- * Used to plan aggregation queries that are computed incrementally as part of a
+ * Used to plan streaming aggregation queries that are computed incrementally as part of a
* [[StreamingQuery]]. Currently this rule is injected into the planner
* on-demand, only when planning in a [[org.apache.spark.sql.execution.streaming.StreamExecution]]
*/
object StatefulAggregationStrategy extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case _ if !plan.isStreaming => Nil
+
case EventTimeWatermark(columnName, delay, child) =>
EventTimeWatermarkExec(columnName, delay, planLater(child)) :: Nil
@@ -253,7 +250,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
*/
object StreamingDeduplicationStrategy extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case Deduplicate(keys, child, true) =>
+ case Deduplicate(keys, child) if child.isStreaming =>
StreamingDeduplicateExec(keys, planLater(child)) :: Nil
case _ => Nil
@@ -415,7 +412,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.window.WindowExec(windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil
case logical.Sample(lb, ub, withReplacement, seed, child) =>
execution.SampleExec(lb, ub, withReplacement, seed, planLater(child)) :: Nil
- case logical.LocalRelation(output, data) =>
+ case logical.LocalRelation(output, data, _) =>
LocalTableScanExec(output, data) :: Nil
case logical.LocalLimit(IntegerLiteral(limit), child) =>
execution.LocalLimitExec(limit, planLater(child)) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index bacb7090a70ab..a41a7ca56a0a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -382,7 +382,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
try {
CodeGenerator.compile(cleanedSource)
} catch {
- case e: Exception if !Utils.isTesting && sqlContext.conf.wholeStageFallback =>
+ case _: Exception if !Utils.isTesting && sqlContext.conf.codegenFallback =>
// We should already saw the error message
logWarning(s"Whole-stage codegen disabled for this plan:\n $treeString")
return child.execute()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
index 0c40417db0837..13f79275cac41 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
@@ -76,6 +76,8 @@ class VectorizedHashMapGenerator(
}.mkString("\n").concat(";")
s"""
+ | private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector[] batchVectors;
+ | private org.apache.spark.sql.execution.vectorized.OnHeapColumnVector[] bufferVectors;
| private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch;
| private org.apache.spark.sql.execution.vectorized.ColumnarBatch aggregateBufferBatch;
| private int[] buckets;
@@ -89,14 +91,19 @@ class VectorizedHashMapGenerator(
| $generatedAggBufferSchema
|
| public $generatedClassName() {
- | batch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate(schema,
- | org.apache.spark.memory.MemoryMode.ON_HEAP, capacity);
- | // TODO: Possibly generate this projection in HashAggregate directly
- | aggregateBufferBatch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate(
- | aggregateBufferSchema, org.apache.spark.memory.MemoryMode.ON_HEAP, capacity);
- | for (int i = 0 ; i < aggregateBufferBatch.numCols(); i++) {
- | aggregateBufferBatch.setColumn(i, batch.column(i+${groupingKeys.length}));
+ | batchVectors = org.apache.spark.sql.execution.vectorized
+ | .OnHeapColumnVector.allocateColumns(capacity, schema);
+ | batch = new org.apache.spark.sql.execution.vectorized.ColumnarBatch(
+ | schema, batchVectors, capacity);
+ |
+ | bufferVectors = new org.apache.spark.sql.execution.vectorized
+ | .OnHeapColumnVector[aggregateBufferSchema.fields().length];
+ | for (int i = 0; i < aggregateBufferSchema.fields().length; i++) {
+ | bufferVectors[i] = batchVectors[i + ${groupingKeys.length}];
| }
+ | // TODO: Possibly generate this projection in HashAggregate directly
+ | aggregateBufferBatch = new org.apache.spark.sql.execution.vectorized.ColumnarBatch(
+ | aggregateBufferSchema, bufferVectors, capacity);
|
| buckets = new int[numBuckets];
| java.util.Arrays.fill(buckets, -1);
@@ -112,8 +119,8 @@ class VectorizedHashMapGenerator(
*
* {{{
* private boolean equals(int idx, long agg_key, long agg_key1) {
- * return batch.column(0).getLong(buckets[idx]) == agg_key &&
- * batch.column(1).getLong(buckets[idx]) == agg_key1;
+ * return batchVectors[0].getLong(buckets[idx]) == agg_key &&
+ * batchVectors[1].getLong(buckets[idx]) == agg_key1;
* }
* }}}
*/
@@ -121,8 +128,8 @@ class VectorizedHashMapGenerator(
def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = {
groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
- s"""(${ctx.genEqual(key.dataType, ctx.getValue("batch", "buckets[idx]",
- key.dataType, ordinal), key.name)})"""
+ s"""(${ctx.genEqual(key.dataType, ctx.getValue(s"batchVectors[$ordinal]", "buckets[idx]",
+ key.dataType), key.name)})"""
}.mkString(" && ")
}
@@ -150,9 +157,9 @@ class VectorizedHashMapGenerator(
* while (step < maxSteps) {
* // Return bucket index if it's either an empty slot or already contains the key
* if (buckets[idx] == -1) {
- * batch.column(0).putLong(numRows, agg_key);
- * batch.column(1).putLong(numRows, agg_key1);
- * batch.column(2).putLong(numRows, 0);
+ * batchVectors[0].putLong(numRows, agg_key);
+ * batchVectors[1].putLong(numRows, agg_key1);
+ * batchVectors[2].putLong(numRows, 0);
* buckets[idx] = numRows++;
* return batch.getRow(buckets[idx]);
* } else if (equals(idx, agg_key, agg_key1)) {
@@ -170,13 +177,13 @@ class VectorizedHashMapGenerator(
def genCodeToSetKeys(groupingKeys: Seq[Buffer]): Seq[String] = {
groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
- ctx.setValue("batch", "numRows", key.dataType, ordinal, key.name)
+ ctx.setValue(s"batchVectors[$ordinal]", "numRows", key.dataType, key.name)
}
}
def genCodeToSetAggBuffers(bufferValues: Seq[Buffer]): Seq[String] = {
bufferValues.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
- ctx.updateColumn("batch", "numRows", key.dataType, groupingKeys.length + ordinal,
+ ctx.updateColumn(s"batchVectors[${groupingKeys.length + ordinal}]", "numRows", key.dataType,
buffVars(ordinal), nullable = true)
}
}
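The generator now builds batches from pre-allocated vectors instead of ColumnarBatch.allocate. The equivalent construction in Scala, as a sketch mirroring the generated Java above; the schema, field names, and capacity are hypothetical.

    import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVector, OnHeapColumnVector}
    import org.apache.spark.sql.types.{LongType, StructType}

    val schema   = new StructType().add("key", LongType).add("count", LongType)
    val capacity = 1 << 16
    // Allocate the vectors first, then wrap them in a batch, as the generated code does.
    val vectors  = OnHeapColumnVector.allocateColumns(capacity, schema)
    val batch    = new ColumnarBatch(schema, vectors.asInstanceOf[Array[ColumnVector]], capacity)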
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
index 1dae5f6964e56..b6550bf3e4aac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
@@ -37,7 +37,7 @@ class TypedSumDouble[IN](val f: IN => Double) extends Aggregator[IN, Double, Dou
override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]()
// Java api support
- def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double])
+ def this(f: MapFunction[IN, java.lang.Double]) = this((x: IN) => f.call(x).asInstanceOf[Double])
def toColumnJava: TypedColumn[IN, java.lang.Double] = {
toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]]
@@ -55,7 +55,7 @@ class TypedSumLong[IN](val f: IN => Long) extends Aggregator[IN, Long, Long] {
override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]()
// Java api support
- def this(f: MapFunction[IN, java.lang.Long]) = this(x => f.call(x).asInstanceOf[Long])
+ def this(f: MapFunction[IN, java.lang.Long]) = this((x: IN) => f.call(x).asInstanceOf[Long])
def toColumnJava: TypedColumn[IN, java.lang.Long] = {
toColumn.asInstanceOf[TypedColumn[IN, java.lang.Long]]
@@ -75,7 +75,7 @@ class TypedCount[IN](val f: IN => Any) extends Aggregator[IN, Long, Long] {
override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]()
// Java api support
- def this(f: MapFunction[IN, Object]) = this(x => f.call(x))
+ def this(f: MapFunction[IN, Object]) = this((x: IN) => f.call(x).asInstanceOf[Any])
def toColumnJava: TypedColumn[IN, java.lang.Long] = {
toColumn.asInstanceOf[TypedColumn[IN, java.lang.Long]]
}
@@ -94,7 +94,7 @@ class TypedAverage[IN](val f: IN => Double) extends Aggregator[IN, (Double, Long
override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]()
// Java api support
- def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double])
+ def this(f: MapFunction[IN, java.lang.Double]) = this((x: IN) => f.call(x).asInstanceOf[Double])
def toColumnJava: TypedColumn[IN, java.lang.Double] = {
toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index ae5e2c6bece2a..fec1add18cbf2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -324,7 +324,11 @@ case class ScalaUDAF(
udaf: UserDefinedAggregateFunction,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
- extends ImperativeAggregate with NonSQLExpression with Logging with ImplicitCastInputTypes {
+ extends ImperativeAggregate
+ with NonSQLExpression
+ with Logging
+ with ImplicitCastInputTypes
+ with UserDefinedExpression {
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index fa45822311e15..561a067a2f81f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.arrow
import java.io.ByteArrayOutputStream
import java.nio.channels.Channels
+import scala.collection.JavaConverters._
+
import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector._
import org.apache.arrow.vector.file._
@@ -28,6 +30,7 @@ import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel
import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -35,7 +38,7 @@ import org.apache.spark.util.Utils
/**
* Store Arrow data in a form that can be serialized by Spark and served to a Python process.
*/
-private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Serializable {
+private[sql] class ArrowPayload private[sql] (payload: Array[Byte]) extends Serializable {
/**
* Convert the ArrowPayload to an ArrowRecordBatch.
@@ -50,6 +53,17 @@ private[sql] class ArrowPayload private[arrow] (payload: Array[Byte]) extends Se
def asPythonSerializable: Array[Byte] = payload
}
+/**
+ * Iterator interface to iterate over Arrow record batches and return rows
+ */
+private[sql] trait ArrowRowIterator extends Iterator[InternalRow] {
+
+ /**
+ * Return the schema loaded from the Arrow record batch being iterated over
+ */
+ def schema: StructType
+}
+
private[sql] object ArrowConverters {
/**
@@ -110,6 +124,66 @@ private[sql] object ArrowConverters {
}
}
+ /**
+   * Maps an Iterator of ArrowPayload to InternalRow. Returns an ArrowRowIterator that also
+   * exposes the schema read from the first batch of Arrow data.
+ */
+ private[sql] def fromPayloadIterator(
+ payloadIter: Iterator[ArrowPayload],
+ context: TaskContext): ArrowRowIterator = {
+ val allocator =
+ ArrowUtils.rootAllocator.newChildAllocator("fromPayloadIterator", 0, Long.MaxValue)
+
+ new ArrowRowIterator {
+ private var reader: ArrowFileReader = null
+ private var schemaRead = StructType(Seq.empty)
+ private var rowIter = if (payloadIter.hasNext) nextBatch() else Iterator.empty
+
+ context.addTaskCompletionListener { _ =>
+ closeReader()
+ allocator.close()
+ }
+
+ override def schema: StructType = schemaRead
+
+ override def hasNext: Boolean = rowIter.hasNext || {
+ closeReader()
+ if (payloadIter.hasNext) {
+ rowIter = nextBatch()
+ true
+ } else {
+ allocator.close()
+ false
+ }
+ }
+
+ override def next(): InternalRow = rowIter.next()
+
+ private def closeReader(): Unit = {
+ if (reader != null) {
+ reader.close()
+ reader = null
+ }
+ }
+
+ private def nextBatch(): Iterator[InternalRow] = {
+ val in = new ByteArrayReadableSeekableByteChannel(payloadIter.next().asPythonSerializable)
+ reader = new ArrowFileReader(in, allocator)
+ reader.loadNextBatch() // throws IOException
+ val root = reader.getVectorSchemaRoot // throws IOException
+ schemaRead = ArrowUtils.fromArrowSchema(root.getSchema)
+
+ val columns = root.getFieldVectors.asScala.map { vector =>
+ new ArrowColumnVector(vector).asInstanceOf[ColumnVector]
+ }.toArray
+
+ val batch = new ColumnarBatch(schemaRead, columns, root.getRowCount)
+ batch.setNumRows(root.getRowCount)
+ batch.rowIterator().asScala
+ }
+ }
+ }
+
/**
* Convert a byte array to an ArrowRecordBatch.
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 2151c339b9b87..e4e9372447f7c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -20,14 +20,13 @@ package org.apache.spark.sql.execution
import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration.Duration
-import org.apache.spark.{InterruptibleIterator, SparkException, TaskContext}
+import org.apache.spark.{InterruptibleIterator, TaskContext}
import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates
import org.apache.spark.sql.types.LongType
import org.apache.spark.util.ThreadUtils
import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 1d601374de135..c7ddec55682e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -166,12 +166,13 @@ case class InMemoryTableScanExec(
if (inMemoryPartitionPruningEnabled) {
cachedBatchIterator.filter { cachedBatch =>
if (!partitionFilter.eval(cachedBatch.stats)) {
- def statsString: String = schemaIndex.map {
- case (a, i) =>
+ logDebug {
+ val statsString = schemaIndex.map { case (a, i) =>
val value = cachedBatch.stats.get(i, a.dataType)
s"${a.name}: $value"
- }.mkString(", ")
- logInfo(s"Skipping partition based on stats $statsString")
+ }.mkString(", ")
+ s"Skipping partition based on stats $statsString"
+ }
false
} else {
true
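The change above defers building the potentially expensive stats string by passing it to logDebug by name, so it is only evaluated when DEBUG logging is enabled. A self-contained sketch of the same idiom in plain Scala (not Spark's Logging trait):

    object ByNameLogging {
      @volatile var debugEnabled: Boolean = false

      // `msg` is a by-name parameter: the string is built only when debug logging is enabled.
      def logDebug(msg: => String): Unit = if (debugEnabled) println(s"DEBUG $msg")

      def main(args: Array[String]): Unit = {
        val stats = Map("a" -> 1, "b" -> 2)
        logDebug {
          val statsString = stats.map { case (k, v) => s"$k: $v" }.mkString(", ")
          s"Skipping partition based on stats $statsString"
        }
      }
    }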
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala
new file mode 100644
index 0000000000000..5b54b2270b5ec
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.command
+
+import org.apache.spark.sql.{AnalysisException, Column, Row, SparkSession}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType}
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
+import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Literal}
+import org.apache.spark.sql.execution.datasources.PartitioningUtils
+
+/**
+ * Analyzes a given set of partitions to generate per-partition statistics, which will be used in
+ * query optimizations.
+ *
+ * When `partitionSpec` is empty, statistics for all partitions are collected and stored in
+ * Metastore.
+ *
+ * When `partitionSpec` mentions only some of the partition columns, all partitions with
+ * matching values for specified columns are processed.
+ *
+ * If `partitionSpec` mentions an unknown partition column, an `AnalysisException` is raised.
+ *
+ * By default, total number of rows and total size in bytes are calculated. When `noscan`
+ * is `true`, only total size in bytes is computed.
+ */
+case class AnalyzePartitionCommand(
+ tableIdent: TableIdentifier,
+ partitionSpec: Map[String, Option[String]],
+ noscan: Boolean = true) extends RunnableCommand {
+
+ private def getPartitionSpec(table: CatalogTable): Option[TablePartitionSpec] = {
+ val normalizedPartitionSpec =
+ PartitioningUtils.normalizePartitionSpec(partitionSpec, table.partitionColumnNames,
+ table.identifier.quotedString, conf.resolver)
+
+ // Report an error if partition columns in partition specification do not form
+ // a prefix of the list of partition columns defined in the table schema
+ val isNotSpecified =
+ table.partitionColumnNames.map(normalizedPartitionSpec.getOrElse(_, None).isEmpty)
+ if (isNotSpecified.init.zip(isNotSpecified.tail).contains((true, false))) {
+ val tableId = table.identifier
+ val schemaColumns = table.partitionColumnNames.mkString(",")
+ val specColumns = normalizedPartitionSpec.keys.mkString(",")
+ throw new AnalysisException("The list of partition columns with values " +
+ s"in partition specification for table '${tableId.table}' " +
+ s"in database '${tableId.database.get}' is not a prefix of the list of " +
+ "partition columns defined in the table schema. " +
+ s"Expected a prefix of [${schemaColumns}], but got [${specColumns}].")
+ }
+
+ val filteredSpec = normalizedPartitionSpec.filter(_._2.isDefined).mapValues(_.get)
+ if (filteredSpec.isEmpty) {
+ None
+ } else {
+ Some(filteredSpec)
+ }
+ }
+
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ val sessionState = sparkSession.sessionState
+ val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase)
+ val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db))
+ val tableMeta = sessionState.catalog.getTableMetadata(tableIdentWithDB)
+ if (tableMeta.tableType == CatalogTableType.VIEW) {
+ throw new AnalysisException("ANALYZE TABLE is not supported on views.")
+ }
+
+ val partitionValueSpec = getPartitionSpec(tableMeta)
+
+ val partitions = sessionState.catalog.listPartitions(tableMeta.identifier, partitionValueSpec)
+
+ if (partitions.isEmpty) {
+ if (partitionValueSpec.isDefined) {
+ throw new NoSuchPartitionException(db, tableIdent.table, partitionValueSpec.get)
+ } else {
+ // the user requested to analyze all partitions for a table which has no partitions
+ // return normally, since there is nothing to do
+ return Seq.empty[Row]
+ }
+ }
+
+ // Compute statistics for individual partitions
+ val rowCounts: Map[TablePartitionSpec, BigInt] =
+ if (noscan) {
+ Map.empty
+ } else {
+ calculateRowCountsPerPartition(sparkSession, tableMeta, partitionValueSpec)
+ }
+
+ // Update the metastore if newly computed statistics are different from those
+ // recorded in the metastore.
+ val newPartitions = partitions.flatMap { p =>
+ val newTotalSize = CommandUtils.calculateLocationSize(
+ sessionState, tableMeta.identifier, p.storage.locationUri)
+ val newRowCount = rowCounts.get(p.spec)
+ val newStats = CommandUtils.compareAndGetNewStats(tableMeta.stats, newTotalSize, newRowCount)
+ newStats.map(_ => p.copy(stats = newStats))
+ }
+
+ if (newPartitions.nonEmpty) {
+ sessionState.catalog.alterPartitions(tableMeta.identifier, newPartitions)
+ }
+
+ Seq.empty[Row]
+ }
+
+ private def calculateRowCountsPerPartition(
+ sparkSession: SparkSession,
+ tableMeta: CatalogTable,
+ partitionValueSpec: Option[TablePartitionSpec]): Map[TablePartitionSpec, BigInt] = {
+ val filter = if (partitionValueSpec.isDefined) {
+ val filters = partitionValueSpec.get.map {
+ case (columnName, value) => EqualTo(UnresolvedAttribute(columnName), Literal(value))
+ }
+ filters.reduce(And)
+ } else {
+ Literal.TrueLiteral
+ }
+
+ val tableDf = sparkSession.table(tableMeta.identifier)
+ val partitionColumns = tableMeta.partitionColumnNames.map(Column(_))
+
+ val df = tableDf.filter(Column(filter)).groupBy(partitionColumns: _*).count()
+
+ df.collect().map { r =>
+ val partitionColumnValues = partitionColumns.indices.map(r.get(_).toString)
+ val spec = tableMeta.partitionColumnNames.zip(partitionColumnValues).toMap
+ val count = BigInt(r.getLong(partitionColumns.size))
+ (spec, count)
+ }.toMap
+ }
+}
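
A minimal sketch of how AnalyzePartitionCommand is reached from SQL, assuming an existing SparkSession named `spark`; the table name, partition column and data are illustrative only:

    // Illustrative setup: a partitioned data source table with one partition of data.
    spark.sql("CREATE TABLE sales (amount INT, ds STRING) USING parquet PARTITIONED BY (ds)")
    spark.sql("INSERT INTO sales VALUES (1, '2017-09-01'), (2, '2017-09-01')")

    // Size-only statistics for a single partition (NOSCAN maps to noscan = true).
    spark.sql("ANALYZE TABLE sales PARTITION (ds = '2017-09-01') COMPUTE STATISTICS NOSCAN")

    // Row count and size for all partitions: a partition column listed without a value is a
    // partial spec, so every partition with some ds value is analyzed.
    spark.sql("ANALYZE TABLE sales PARTITION (ds) COMPUTE STATISTICS")
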
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
index cba147c35dd99..04715bd314d4d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeTableCommand.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.command
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTableType}
+import org.apache.spark.sql.catalyst.catalog.CatalogTableType
/**
@@ -37,31 +37,15 @@ case class AnalyzeTableCommand(
if (tableMeta.tableType == CatalogTableType.VIEW) {
throw new AnalysisException("ANALYZE TABLE is not supported on views.")
}
+
+ // Compute stats for the whole table
val newTotalSize = CommandUtils.calculateTotalSize(sessionState, tableMeta)
+ val newRowCount =
+ if (noscan) None else Some(BigInt(sparkSession.table(tableIdentWithDB).count()))
- val oldTotalSize = tableMeta.stats.map(_.sizeInBytes.toLong).getOrElse(-1L)
- val oldRowCount = tableMeta.stats.flatMap(_.rowCount.map(_.toLong)).getOrElse(-1L)
- var newStats: Option[CatalogStatistics] = None
- if (newTotalSize >= 0 && newTotalSize != oldTotalSize) {
- newStats = Some(CatalogStatistics(sizeInBytes = newTotalSize))
- }
- // We only set rowCount when noscan is false, because otherwise:
- // 1. when total size is not changed, we don't need to alter the table;
- // 2. when total size is changed, `oldRowCount` becomes invalid.
- // This is to make sure that we only record the right statistics.
- if (!noscan) {
- val newRowCount = sparkSession.table(tableIdentWithDB).count()
- if (newRowCount >= 0 && newRowCount != oldRowCount) {
- newStats = if (newStats.isDefined) {
- newStats.map(_.copy(rowCount = Some(BigInt(newRowCount))))
- } else {
- Some(CatalogStatistics(
- sizeInBytes = oldTotalSize, rowCount = Some(BigInt(newRowCount))))
- }
- }
- }
// Update the metastore if the above statistics of the table are different from those
// recorded in the metastore.
+ val newStats = CommandUtils.compareAndGetNewStats(tableMeta.stats, newTotalSize, newRowCount)
if (newStats.isDefined) {
sessionState.catalog.alterTableStats(tableIdentWithDB, newStats)
// Refresh the cached data source table in the catalog.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala
index de45be85220e9..b22958d59336c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala
@@ -26,7 +26,7 @@ import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable}
+import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTablePartition}
import org.apache.spark.sql.internal.SessionState
@@ -112,4 +112,29 @@ object CommandUtils extends Logging {
size
}
+ def compareAndGetNewStats(
+ oldStats: Option[CatalogStatistics],
+ newTotalSize: BigInt,
+ newRowCount: Option[BigInt]): Option[CatalogStatistics] = {
+ val oldTotalSize = oldStats.map(_.sizeInBytes.toLong).getOrElse(-1L)
+ val oldRowCount = oldStats.flatMap(_.rowCount.map(_.toLong)).getOrElse(-1L)
+ var newStats: Option[CatalogStatistics] = None
+ if (newTotalSize >= 0 && newTotalSize != oldTotalSize) {
+ newStats = Some(CatalogStatistics(sizeInBytes = newTotalSize))
+ }
+ // We only set rowCount when newRowCount is defined (i.e. noscan is false), because otherwise:
+ // 1. when total size is not changed, we don't need to alter the table;
+ // 2. when total size is changed, `oldRowCount` becomes invalid.
+ // This is to make sure that we only record the right statistics.
+ if (newRowCount.isDefined) {
+ if (newRowCount.get >= 0 && newRowCount.get != oldRowCount) {
+ newStats = if (newStats.isDefined) {
+ newStats.map(_.copy(rowCount = newRowCount))
+ } else {
+ Some(CatalogStatistics(sizeInBytes = oldTotalSize, rowCount = newRowCount))
+ }
+ }
+ }
+ newStats
+ }
}
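
A short sketch of compareAndGetNewStats in isolation, written as if inside org.apache.spark.sql.execution.command; the numbers are made up:

    import org.apache.spark.sql.catalyst.catalog.CatalogStatistics

    // Stats previously recorded in the metastore: 1000 bytes, unknown row count.
    val oldStats = Some(CatalogStatistics(sizeInBytes = BigInt(1000)))

    // Size changed and a fresh row count was computed (noscan = false), so the result carries
    // both fields and the caller issues an alterTableStats/alterPartitions call.
    val updated = CommandUtils.compareAndGetNewStats(
      oldStats, newTotalSize = BigInt(2048), newRowCount = Some(BigInt(42)))
    // updated == Some(CatalogStatistics(sizeInBytes = 2048, rowCount = Some(42)))

    // Nothing changed, so None is returned and the metastore is left untouched.
    val unchanged = CommandUtils.compareAndGetNewStats(
      oldStats, newTotalSize = BigInt(1000), newRowCount = None)
    // unchanged == None
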
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala
new file mode 100644
index 0000000000000..633de4c37af94
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.command
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.catalog._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.datasources._
+
+/**
+ * A command used to write the result of a query to a directory.
+ *
+ * The syntax of using this command in SQL is:
+ * {{{
+ * INSERT OVERWRITE DIRECTORY (path=STRING)?
+ * USING format OPTIONS ([option1_name "option1_value", option2_name "option2_value", ...])
+ * SELECT ...
+ * }}}
+ *
+ * @param storage storage format used to describe how the query result is stored.
+ * @param provider the data source type to be used
+ * @param query the logical plan representing the data to write
+ * @param overwrite whether to overwrite the existing directory
+ */
+case class InsertIntoDataSourceDirCommand(
+ storage: CatalogStorageFormat,
+ provider: String,
+ query: LogicalPlan,
+ overwrite: Boolean) extends RunnableCommand {
+
+ override def children: Seq[LogicalPlan] = Seq(query)
+
+ override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = {
+ assert(children.length == 1)
+ assert(storage.locationUri.nonEmpty, "Directory path is required")
+ assert(provider.nonEmpty, "Data source is required")
+
+ // Create the relation based on the input logical plan: `query`.
+ val pathOption = storage.locationUri.map("path" -> CatalogUtils.URIToString(_))
+
+ val dataSource = DataSource(
+ sparkSession,
+ className = provider,
+ options = storage.properties ++ pathOption,
+ catalogTable = None)
+
+ val isFileFormat = classOf[FileFormat].isAssignableFrom(dataSource.providingClass)
+ if (!isFileFormat) {
+ throw new SparkException(
+ "Only Data Sources providing FileFormat are supported: " + dataSource.providingClass)
+ }
+
+ val saveMode = if (overwrite) SaveMode.Overwrite else SaveMode.ErrorIfExists
+ try {
+ sparkSession.sessionState.executePlan(dataSource.planForWriting(saveMode, query))
+ dataSource.writeAndRead(saveMode, query)
+ } catch {
+ case ex: AnalysisException =>
+ logError(s"Failed to write to directory " + storage.locationUri.toString, ex)
+ throw ex
+ }
+
+ Seq.empty[Row]
+ }
+}
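
A minimal sketch of the SQL shape this command handles, assuming an existing SparkSession `spark`; the output path, options and source table are illustrative:

    // Overwrite a directory with the result of a query, written through the parquet
    // file-format data source (non-file-format providers are rejected above).
    spark.sql("""
      INSERT OVERWRITE DIRECTORY '/tmp/sales_export'
      USING parquet
      OPTIONS (compression 'snappy')
      SELECT amount, ds FROM sales
    """)
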
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
index 47952f2f227a3..792290bef0163 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
@@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.SQLExecution
case class CacheTableCommand(
tableIdent: TableIdentifier,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index dae160f1bbb18..162e1d5be2938 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -33,7 +33,11 @@ import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, Resolver}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
-import org.apache.spark.sql.execution.datasources.PartitioningUtils
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, PartitioningUtils}
+import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
+import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter
+import org.apache.spark.sql.internal.HiveSerDe
import org.apache.spark.sql.types._
import org.apache.spark.util.{SerializableConfiguration, ThreadUtils}
@@ -797,7 +801,11 @@ object DDLUtils {
val HIVE_PROVIDER = "hive"
def isHiveTable(table: CatalogTable): Boolean = {
- table.provider.isDefined && table.provider.get.toLowerCase(Locale.ROOT) == HIVE_PROVIDER
+ isHiveTable(table.provider)
+ }
+
+ def isHiveTable(provider: Option[String]): Boolean = {
+ provider.isDefined && provider.get.toLowerCase(Locale.ROOT) == HIVE_PROVIDER
}
def isDatasourceTable(table: CatalogTable): Boolean = {
@@ -848,4 +856,36 @@ object DDLUtils {
}
}
}
+
+ private[sql] def checkDataSchemaFieldNames(table: CatalogTable): Unit = {
+ table.provider.foreach {
+ _.toLowerCase(Locale.ROOT) match {
+ case HIVE_PROVIDER =>
+ val serde = table.storage.serde
+ if (serde == HiveSerDe.sourceToSerDe("orc").get.serde) {
+ OrcFileFormat.checkFieldNames(table.dataSchema)
+ } else if (serde == HiveSerDe.sourceToSerDe("parquet").get.serde ||
+ serde == Some("parquet.hive.serde.ParquetHiveSerDe")) {
+ ParquetSchemaConverter.checkFieldNames(table.dataSchema)
+ }
+ case "parquet" => ParquetSchemaConverter.checkFieldNames(table.dataSchema)
+ case "orc" => OrcFileFormat.checkFieldNames(table.dataSchema)
+ case _ =>
+ }
+ }
+ }
+
+ /**
+ * Throws an exception if outputPath tries to overwrite an input path.
+ */
+ def verifyNotReadPath(query: LogicalPlan, outputPath: Path) : Unit = {
+ val inputPaths = query.collect {
+ case LogicalRelation(r: HadoopFsRelation, _, _, _) => r.location.rootPaths
+ }.flatten
+
+ if (inputPaths.contains(outputPath)) {
+ throw new AnalysisException(
+ "Cannot overwrite a path that is also being read from.")
+ }
+ }
}
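
To illustrate the check that verifyNotReadPath centralizes, a hedged sketch (the path and view name are made up; `spark` is an existing SparkSession):

    // Reading from a location and overwriting that same location in one statement is
    // rejected with "Cannot overwrite a path that is also being read from."
    spark.read.parquet("/tmp/events").createOrReplaceTempView("events")
    spark.sql("INSERT OVERWRITE DIRECTORY '/tmp/events' USING parquet SELECT * FROM events")
    // => org.apache.spark.sql.AnalysisException
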
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index 694d517668a2c..8d95ca6921cf8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -29,13 +29,13 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException
+import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.catalog.CatalogTableType._
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.util.quoteIdentifier
-import org.apache.spark.sql.execution.datasources.{DataSource, FileFormat, PartitioningUtils}
+import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils}
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
@@ -201,13 +201,14 @@ case class AlterTableAddColumnsCommand(
// make sure any partition columns are at the end of the fields
val reorderedSchema = catalogTable.dataSchema ++ columns ++ catalogTable.partitionSchema
+ val newSchema = catalogTable.schema.copy(fields = reorderedSchema.toArray)
SchemaUtils.checkColumnNameDuplication(
reorderedSchema.map(_.name), "in the table definition of " + table.identifier,
conf.caseSensitiveAnalysis)
+ DDLUtils.checkDataSchemaFieldNames(catalogTable.copy(schema = newSchema))
- catalog.alterTableSchema(
- table, catalogTable.schema.copy(fields = reorderedSchema.toArray))
+ catalog.alterTableSchema(table, newSchema)
Seq.empty[Row]
}
@@ -630,6 +631,73 @@ case class DescribeTableCommand(
}
}
+/**
+ * A command to list the info for a column, including name, data type, comment and column stats.
+ *
+ * The syntax of using this command in SQL is:
+ * {{{
+ * DESCRIBE [EXTENDED|FORMATTED] table_name column_name;
+ * }}}
+ */
+case class DescribeColumnCommand(
+ table: TableIdentifier,
+ colNameParts: Seq[String],
+ isExtended: Boolean)
+ extends RunnableCommand {
+
+ override val output: Seq[Attribute] = {
+ Seq(
+ AttributeReference("info_name", StringType, nullable = false,
+ new MetadataBuilder().putString("comment", "name of the column info").build())(),
+ AttributeReference("info_value", StringType, nullable = false,
+ new MetadataBuilder().putString("comment", "value of the column info").build())()
+ )
+ }
+
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ val catalog = sparkSession.sessionState.catalog
+ val resolver = sparkSession.sessionState.conf.resolver
+ val relation = sparkSession.table(table).queryExecution.analyzed
+
+ val colName = UnresolvedAttribute(colNameParts).name
+ val field = {
+ relation.resolve(colNameParts, resolver).getOrElse {
+ throw new AnalysisException(s"Column $colName does not exist")
+ }
+ }
+ if (!field.isInstanceOf[Attribute]) {
+ // If the field is not an attribute after `resolve`, then it's a nested field.
+ throw new AnalysisException(
+ s"DESC TABLE COLUMN command does not support nested data types: $colName")
+ }
+
+ val catalogTable = catalog.getTempViewOrPermanentTableMetadata(table)
+ val colStats = catalogTable.stats.map(_.colStats).getOrElse(Map.empty)
+ val cs = colStats.get(field.name)
+
+ val comment = if (field.metadata.contains("comment")) {
+ Option(field.metadata.getString("comment"))
+ } else {
+ None
+ }
+
+ val buffer = ArrayBuffer[Row](
+ Row("col_name", field.name),
+ Row("data_type", field.dataType.catalogString),
+ Row("comment", comment.getOrElse("NULL"))
+ )
+ if (isExtended) {
+ // Show column stats when EXTENDED or FORMATTED is specified.
+ buffer += Row("min", cs.flatMap(_.min.map(_.toString)).getOrElse("NULL"))
+ buffer += Row("max", cs.flatMap(_.max.map(_.toString)).getOrElse("NULL"))
+ buffer += Row("num_nulls", cs.map(_.nullCount.toString).getOrElse("NULL"))
+ buffer += Row("distinct_count", cs.map(_.distinctCount.toString).getOrElse("NULL"))
+ buffer += Row("avg_col_len", cs.map(_.avgLen.toString).getOrElse("NULL"))
+ buffer += Row("max_col_len", cs.map(_.maxLen.toString).getOrElse("NULL"))
+ }
+ buffer
+ }
+}
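
A short usage sketch, assuming column statistics were previously gathered with ANALYZE TABLE ... FOR COLUMNS (table and column names are illustrative):

    // Basic form: prints col_name, data_type and comment for one column.
    spark.sql("DESC sales amount").show()

    // EXTENDED (or FORMATTED) additionally prints the stored column statistics:
    // min, max, num_nulls, distinct_count, avg_col_len, max_col_len.
    spark.sql("ANALYZE TABLE sales COMPUTE STATISTICS FOR COLUMNS amount")
    spark.sql("DESC EXTENDED sales amount").show()
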
/**
* A command for users to get tables in the given database.
@@ -740,8 +808,7 @@ case class ShowTablePropertiesCommand(table: TableIdentifier, propertyKey: Optio
}
/**
- * A command to list the column names for a table. This function creates a
- * [[ShowColumnsCommand]] logical plan.
+ * A command to list the column names for a table.
*
* The syntax of using this command in SQL is:
* {{{
@@ -780,8 +847,6 @@ case class ShowColumnsCommand(
* 1. If the command is called for a non partitioned table.
* 2. If the partition spec refers to the columns that are not defined as partitioning columns.
*
- * This function creates a [[ShowPartitionsCommand]] logical plan
- *
* The syntax of using this command in SQL is:
* {{{
* SHOW PARTITIONS [db_name.]table_name [PARTITION(partition_spec)]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 567ff49773f9b..b9502a95a7c08 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -455,7 +455,7 @@ case class DataSource(
val fileIndex = catalogTable.map(_.identifier).map { tableIdent =>
sparkSession.table(tableIdent).queryExecution.analyzed.collect {
- case LogicalRelation(t: HadoopFsRelation, _, _) => t.location
+ case LogicalRelation(t: HadoopFsRelation, _, _, _) => t.location
}.head
}
// For partitioned relation r, r.schema's column ordering can be different from the column
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 237017742770a..018f24e290b4b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -17,8 +17,11 @@
package org.apache.spark.sql.execution.datasources
+import java.util.Locale
import java.util.concurrent.Callable
+import org.apache.hadoop.fs.Path
+
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
@@ -29,7 +32,7 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
-import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoTable, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
@@ -130,18 +133,28 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast
override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case CreateTable(tableDesc, mode, None) if DDLUtils.isDatasourceTable(tableDesc) =>
+ DDLUtils.checkDataSchemaFieldNames(tableDesc)
CreateDataSourceTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore)
case CreateTable(tableDesc, mode, Some(query))
if query.resolved && DDLUtils.isDatasourceTable(tableDesc) =>
+ DDLUtils.checkDataSchemaFieldNames(tableDesc.copy(schema = query.schema))
CreateDataSourceTableAsSelectCommand(tableDesc, mode, query)
- case InsertIntoTable(l @ LogicalRelation(_: InsertableRelation, _, _),
+ case InsertIntoTable(l @ LogicalRelation(_: InsertableRelation, _, _, _),
parts, query, overwrite, false) if parts.isEmpty =>
InsertIntoDataSourceCommand(l, query, overwrite)
+ case InsertIntoDir(_, storage, provider, query, overwrite)
+ if provider.isDefined && provider.get.toLowerCase(Locale.ROOT) != DDLUtils.HIVE_PROVIDER =>
+
+ val outputPath = new Path(storage.locationUri.get)
+ if (overwrite) DDLUtils.verifyNotReadPath(query, outputPath)
+
+ InsertIntoDataSourceDirCommand(storage, provider.get, query, overwrite)
+
case i @ InsertIntoTable(
- l @ LogicalRelation(t: HadoopFsRelation, _, table), parts, query, overwrite, _) =>
+ l @ LogicalRelation(t: HadoopFsRelation, _, table, _), parts, query, overwrite, _) =>
// If the InsertIntoTable command is for a partitioned HadoopFsRelation and
// the user has specified static partitions, we add a Project operator on top of the query
// to include those constant column values in the query result.
@@ -176,15 +189,9 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast
}
val outputPath = t.location.rootPaths.head
- val inputPaths = actualQuery.collect {
- case LogicalRelation(r: HadoopFsRelation, _, _) => r.location.rootPaths
- }.flatten
+ if (overwrite) DDLUtils.verifyNotReadPath(actualQuery, outputPath)
val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append
- if (overwrite && inputPaths.contains(outputPath)) {
- throw new AnalysisException(
- "Cannot overwrite a path that is also being read from.")
- }
val partitionSchema = actualQuery.resolve(
t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver)
@@ -268,7 +275,7 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with
import DataSourceStrategy._
def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match {
- case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _)) =>
+ case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _, _)) =>
pruneFilterProjectRaw(
l,
projects,
@@ -276,21 +283,22 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with
(requestedColumns, allPredicates, _) =>
toCatalystRDD(l, requestedColumns, t.buildScan(requestedColumns, allPredicates))) :: Nil
- case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedFilteredScan, _, _)) =>
+ case PhysicalOperation(projects, filters,
+ l @ LogicalRelation(t: PrunedFilteredScan, _, _, _)) =>
pruneFilterProject(
l,
projects,
filters,
(a, f) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray, f))) :: Nil
- case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedScan, _, _)) =>
+ case PhysicalOperation(projects, filters, l @ LogicalRelation(t: PrunedScan, _, _, _)) =>
pruneFilterProject(
l,
projects,
filters,
(a, _) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray))) :: Nil
- case l @ LogicalRelation(baseRelation: TableScan, _, _) =>
+ case l @ LogicalRelation(baseRelation: TableScan, _, _, _) =>
RowDataSourceScanExec(
l.output,
l.output.indices,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
index d2adba2da9478..2068486713c2a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala
@@ -65,6 +65,16 @@ trait FileFormat {
false
}
+ /**
+ * Returns concrete column vector class names for each column to be used in a columnar batch
+ * if this format supports returning a columnar batch.
+ */
+ def vectorTypes(
+ requiredSchema: StructType,
+ partitionSchema: StructType): Option[Seq[String]] = {
+ None
+ }
+
/**
* Returns whether a file with `path` could be splitted or not.
*/
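
A hedged sketch of the new hook from a caller's point of view (the schemas are illustrative; ParquetFileFormat's override appears later in this patch):

    import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    val requiredSchema = StructType(Seq(StructField("amount", IntegerType)))
    val partitionSchema = StructType(Seq(StructField("ds", StringType)))

    // Parquet reports one OnHeapColumnVector class name per data and partition column;
    // formats that do not override vectorTypes keep the default None.
    val vectorClassNames = new ParquetFileFormat().vectorTypes(requiredSchema, partitionSchema)
    // => Some(Seq("org.apache.spark.sql.execution.vectorized.OnHeapColumnVector",
    //             "org.apache.spark.sql.execution.vectorized.OnHeapColumnVector"))
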
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index 17f7e0e601c0c..16b22717b8d92 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -52,7 +52,7 @@ import org.apache.spark.sql.execution.SparkPlan
object FileSourceStrategy extends Strategy with Logging {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case PhysicalOperation(projects, filters,
- l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table)) =>
+ l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _)) =>
// Filters on this relation fall into four categories based on where we can use them to avoid
// reading unneeded data:
// - partition keys only - used to prune directories to read
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
index 699f1bad9c4ed..17a61074d3b5c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
@@ -30,12 +30,14 @@ import org.apache.spark.util.Utils
case class LogicalRelation(
relation: BaseRelation,
output: Seq[AttributeReference],
- catalogTable: Option[CatalogTable])
+ catalogTable: Option[CatalogTable],
+ override val isStreaming: Boolean)
extends LeafNode with MultiInstanceRelation {
// Logical Relations are distinct if they have different output for the sake of transformations.
override def equals(other: Any): Boolean = other match {
- case l @ LogicalRelation(otherRelation, _, _) => relation == otherRelation && output == l.output
+ case l @ LogicalRelation(otherRelation, _, _, isStreaming) =>
+ relation == otherRelation && output == l.output && isStreaming == l.isStreaming
case _ => false
}
@@ -76,9 +78,9 @@ case class LogicalRelation(
}
object LogicalRelation {
- def apply(relation: BaseRelation): LogicalRelation =
- LogicalRelation(relation, relation.schema.toAttributes, None)
+ def apply(relation: BaseRelation, isStreaming: Boolean = false): LogicalRelation =
+ LogicalRelation(relation, relation.schema.toAttributes, None, isStreaming)
def apply(relation: BaseRelation, table: CatalogTable): LogicalRelation =
- LogicalRelation(relation, relation.schema.toAttributes, Some(table))
+ LogicalRelation(relation, relation.schema.toAttributes, Some(table), false)
}
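
A minimal sketch of how call sites adapt to the extra field, for example collecting the root paths of the non-streaming file relations in a plan (the helper name is hypothetical):

    import org.apache.hadoop.fs.Path
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}

    def batchInputPaths(plan: LogicalPlan): Seq[Path] = plan.collect {
      // The fourth field is the new isStreaming flag; `false` keeps only batch relations.
      case LogicalRelation(r: HadoopFsRelation, _, _, false) => r.location.rootPaths
    }.flatten
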
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
index f5df1848a38c4..3b830accb83f0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
@@ -36,6 +36,7 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
_,
_),
_,
+ _,
_))
if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined =>
// The attribute name of predicate could be different than the one in schema in case of
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SourceOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SourceOptions.scala
new file mode 100644
index 0000000000000..c98c0b2a756a1
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SourceOptions.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+
+/**
+ * Options for the data source.
+ */
+class SourceOptions(
+ @transient private val parameters: CaseInsensitiveMap[String])
+ extends Serializable {
+ import SourceOptions._
+
+ def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters))
+
+ // A flag to disable saving a data source table's metadata in a Hive-compatible way.
+ val skipHiveMetadata: Boolean = parameters
+ .get(SKIP_HIVE_METADATA).map(_.toBoolean).getOrElse(DEFAULT_SKIP_HIVE_METADATA)
+
+ // A flag to always respect the Spark schema restored from the table properties
+ val respectSparkSchema: Boolean = parameters
+ .get(RESPECT_SPARK_SCHEMA).map(_.toBoolean).getOrElse(DEFAULT_RESPECT_SPARK_SCHEMA)
+}
+
+
+object SourceOptions {
+
+ val SKIP_HIVE_METADATA = "skipHiveMetadata"
+ val DEFAULT_SKIP_HIVE_METADATA = false
+
+ val RESPECT_SPARK_SCHEMA = "respectSparkSchema"
+ val DEFAULT_RESPECT_SPARK_SCHEMA = false
+
+}
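
A quick sketch of how these options are read; keys are matched case-insensitively and both flags default to false:

    import org.apache.spark.sql.execution.datasources.SourceOptions

    val opts = new SourceOptions(Map("skipHiveMetadata" -> "true"))
    assert(opts.skipHiveMetadata)      // explicitly enabled
    assert(!opts.respectSparkSchema)   // falls back to DEFAULT_RESPECT_SPARK_SCHEMA
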
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index a99bdfee5d6e6..e20977a4ec79f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -109,6 +109,20 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
}
}
+ if (requiredSchema.length == 1 &&
+ requiredSchema.head.name == parsedOptions.columnNameOfCorruptRecord) {
+ throw new AnalysisException(
+ "Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the\n" +
+ "referenced columns only include the internal corrupt record column\n" +
+ s"(named _corrupt_record by default). For example:\n" +
+ "spark.read.schema(schema).csv(file).filter($\"_corrupt_record\".isNotNull).count()\n" +
+ "and spark.read.schema(schema).csv(file).select(\"_corrupt_record\").show().\n" +
+ "Instead, you can cache or save the parsed results and then send the same query.\n" +
+ "For example, val df = spark.read.schema(schema).csv(file).cache() and then\n" +
+ "df.filter($\"_corrupt_record\".isNotNull).count()."
+ )
+ }
+
(file: PartitionedFile) => {
val conf = broadcastedHadoopConf.value.value
val parser = new UnivocityParser(
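
A sketch of the workaround the new error message recommends, assuming an existing SparkSession `spark` and an illustrative file path:

    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    val schema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("_corrupt_record", StringType)))

    // Selecting only _corrupt_record straight from the raw files now fails; caching the
    // parsed result first keeps the corrupt-record column available for inspection.
    val df = spark.read.schema(schema).csv("/tmp/input.csv").cache()
    val badRows = df.filter(col("_corrupt_record").isNotNull).count()
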
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index 05b00058618a2..b4e5d169066d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -21,6 +21,7 @@ import java.sql.{Connection, DriverManager}
import java.util.{Locale, Properties}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.types.StructType
/**
* Options for the JDBC data source.
@@ -123,6 +124,8 @@ class JDBCOptions(
// TODO: to reuse the existing partition parameters for those partition specific options
val createTableOptions = parameters.getOrElse(JDBC_CREATE_TABLE_OPTIONS, "")
val createTableColumnTypes = parameters.get(JDBC_CREATE_TABLE_COLUMN_TYPES)
+ val customSchema = parameters.get(JDBC_CUSTOM_DATAFRAME_COLUMN_TYPES)
+
val batchSize = {
val size = parameters.getOrElse(JDBC_BATCH_INSERT_SIZE, "1000").toInt
require(size >= 1,
@@ -161,6 +164,7 @@ object JDBCOptions {
val JDBC_TRUNCATE = newOption("truncate")
val JDBC_CREATE_TABLE_OPTIONS = newOption("createTableOptions")
val JDBC_CREATE_TABLE_COLUMN_TYPES = newOption("createTableColumnTypes")
+ val JDBC_CUSTOM_DATAFRAME_COLUMN_TYPES = newOption("customSchema")
val JDBC_BATCH_INSERT_SIZE = newOption("batchsize")
val JDBC_TXN_ISOLATION_LEVEL = newOption("isolationLevel")
val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 3274be91d4817..05326210f3242 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -80,7 +80,7 @@ object JDBCRDD extends Logging {
* @return A Catalyst schema corresponding to columns in the given order.
*/
private def pruneSchema(schema: StructType, columns: Array[String]): StructType = {
- val fieldMap = Map(schema.fields.map(x => x.metadata.getString("name") -> x): _*)
+ val fieldMap = Map(schema.fields.map(x => x.name -> x): _*)
new StructType(columns.map(name => fieldMap(name)))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index 17405f550d25f..b23e5a7722004 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -111,7 +111,14 @@ private[sql] case class JDBCRelation(
override val needConversion: Boolean = false
- override val schema: StructType = JDBCRDD.resolveTable(jdbcOptions)
+ override val schema: StructType = {
+ val tableSchema = JDBCRDD.resolveTable(jdbcOptions)
+ jdbcOptions.customSchema match {
+ case Some(customSchema) => JdbcUtils.getCustomSchema(
+ tableSchema, customSchema, sparkSession.sessionState.conf.resolver)
+ case None => tableSchema
+ }
+ }
// Check if JDBCRDD.compileFilter can accept input filters
override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index bbe9024f13a44..71133666b3249 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -29,6 +29,7 @@ import org.apache.spark.executor.InputMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
@@ -300,13 +301,11 @@ object JdbcUtils extends Logging {
} else {
rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls
}
- val metadata = new MetadataBuilder()
- .putString("name", columnName)
- .putLong("scale", fieldScale)
+ val metadata = new MetadataBuilder().putLong("scale", fieldScale)
val columnType =
dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse(
getCatalystType(dataType, fieldSize, fieldScale, isSigned))
- fields(i) = StructField(columnName, columnType, nullable, metadata.build())
+ fields(i) = StructField(columnName, columnType, nullable)
i = i + 1
}
new StructType(fields)
@@ -767,6 +766,33 @@ object JdbcUtils extends Logging {
if (isCaseSensitive) userSchemaMap else CaseInsensitiveMap(userSchemaMap)
}
+ /**
+ * Parses the user-specified customSchema option value into a schema, and returns the table
+ * schema with the dataType of each matched column replaced by the dataType from the custom schema.
+ */
+ def getCustomSchema(
+ tableSchema: StructType,
+ customSchema: String,
+ nameEquality: Resolver): StructType = {
+ if (null != customSchema && customSchema.nonEmpty) {
+ val userSchema = CatalystSqlParser.parseTableSchema(customSchema)
+
+ SchemaUtils.checkColumnNameDuplication(
+ userSchema.map(_.name), "in the customSchema option value", nameEquality)
+
+ // This is resolved by name; use the custom field's dataType to replace the default dataType.
+ val newSchema = tableSchema.map { col =>
+ userSchema.find(f => nameEquality(f.name, col.name)) match {
+ case Some(c) => col.copy(dataType = c.dataType)
+ case None => col
+ }
+ }
+ StructType(newSchema)
+ } else {
+ tableSchema
+ }
+ }
+
/**
* Saves the RDD to the database in a single transaction.
*/
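
A hedged sketch of the new customSchema read option end to end (the URL, table and credentials are placeholders):

    val jdbcDF = spark.read
      .format("jdbc")
      .option("url", "jdbc:postgresql://dbhost:5432/shop")
      .option("dbtable", "public.orders")
      .option("user", "user")
      .option("password", "password")
      // Columns named here keep their name but take the specified type; duplicate names in the
      // option value are rejected by getCustomSchema, unmatched names are silently ignored.
      .option("customSchema", "id DECIMAL(38, 0), name STRING")
      .load()
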
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index 53d62d88b04c6..0862c746fffad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -113,6 +113,20 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
}
}
+ if (requiredSchema.length == 1 &&
+ requiredSchema.head.name == parsedOptions.columnNameOfCorruptRecord) {
+ throw new AnalysisException(
+ "Since Spark 2.3, the queries from raw JSON/CSV files are disallowed when the\n" +
+ "referenced columns only include the internal corrupt record column\n" +
+ s"(named _corrupt_record by default). For example:\n" +
+ "spark.read.schema(schema).json(file).filter($\"_corrupt_record\".isNotNull).count()\n" +
+ "and spark.read.schema(schema).json(file).select(\"_corrupt_record\").show().\n" +
+ "Instead, you can cache or save the parsed results and then send the same query.\n" +
+ "For example, val df = spark.read.schema(schema).json(file).cache() and then\n" +
+ "df.filter($\"_corrupt_record\".isNotNull).count()."
+ )
+ }
+
(file: PartitionedFile) => {
val parser = new JacksonParser(actualSchema, parsedOptions)
JsonDataSource(parsedOptions).readFile(
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
similarity index 53%
rename from sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index 157783abc8c2f..2eeb0065455f3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -15,30 +15,28 @@
* limitations under the License.
*/
-package org.apache.spark.sql.catalyst
+package org.apache.spark.sql.execution.datasources.orc
-import scala.util.control.NonFatal
+import org.apache.orc.TypeDescription
-import org.apache.spark.sql.{DataFrame, Dataset, QueryTest}
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.types.StructType
-
-abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton {
- protected def checkSQL(e: Expression, expectedSQL: String): Unit = {
- val actualSQL = e.sql
+private[sql] object OrcFileFormat {
+ private def checkFieldName(name: String): Unit = {
try {
- assert(actualSQL === expectedSQL)
+ TypeDescription.fromString(s"struct<$name:int>")
} catch {
- case cause: Throwable =>
- fail(
- s"""Wrong SQL generated for the following expression:
- |
- |${e.prettyName}
- |
- |$cause
- """.stripMargin)
+ case _: IllegalArgumentException =>
+ throw new AnalysisException(
+ s"""Column name "$name" contains invalid character(s).
+ |Please use alias to rename it.
+ """.stripMargin.split("\n").mkString(" ").trim)
}
}
+
+ def checkFieldNames(schema: StructType): StructType = {
+ schema.fieldNames.foreach(checkFieldName)
+ schema
+ }
}
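
A short sketch of the check itself, written as if inside the org.apache.spark.sql package since the object is private[sql]; the field names are illustrative:

    import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
    import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

    // A plain identifier parses as an ORC struct field and passes through unchanged.
    OrcFileFormat.checkFieldNames(StructType(Seq(StructField("amount", IntegerType))))

    // A name ORC's TypeDescription cannot parse (here it contains a comma and parentheses)
    // raises AnalysisException suggesting an alias.
    OrcFileFormat.checkFieldNames(StructType(Seq(StructField("max(amount,0)", IntegerType))))
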
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index 1d72d4de51364..3465661682951 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -51,6 +51,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser
import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -283,6 +284,13 @@ class ParquetFileFormat
schema.forall(_.dataType.isInstanceOf[AtomicType])
}
+ override def vectorTypes(
+ requiredSchema: StructType,
+ partitionSchema: StructType): Option[Seq[String]] = {
+ Option(Seq.fill(requiredSchema.fields.length + partitionSchema.fields.length)(
+ classOf[OnHeapColumnVector].getName))
+ }
+
override def isSplitable(
sparkSession: SparkSession,
options: Map[String, String],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
index 1a14f4756c67c..efda6b593ca62 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
@@ -556,7 +556,7 @@ private[parquet] class ParquetSchemaConverter(
}
}
-private[parquet] object ParquetSchemaConverter {
+private[sql] object ParquetSchemaConverter {
val SPARK_PARQUET_SCHEMA_NAME = "spark_schema"
val EMPTY_MESSAGE: MessageType =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 84acca242aa41..7a2c85e8e01f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -385,10 +385,10 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] wit
case relation: HiveTableRelation =>
val metadata = relation.tableMeta
preprocess(i, metadata.identifier.quotedString, metadata.partitionColumnNames)
- case LogicalRelation(h: HadoopFsRelation, _, catalogTable) =>
+ case LogicalRelation(h: HadoopFsRelation, _, catalogTable, _) =>
val tblName = catalogTable.map(_.identifier.quotedString).getOrElse("unknown")
preprocess(i, tblName, h.partitionSchema.map(_.name))
- case LogicalRelation(_: InsertableRelation, _, catalogTable) =>
+ case LogicalRelation(_: InsertableRelation, _, catalogTable, _) =>
val tblName = catalogTable.map(_.identifier.quotedString).getOrElse("unknown")
preprocess(i, tblName, Nil)
case _ => i
@@ -428,7 +428,7 @@ object PreReadCheck extends (LogicalPlan => Unit) {
private def checkNumInputFileBlockSources(e: Expression, operator: LogicalPlan): Int = {
operator match {
case _: HiveTableRelation => 1
- case _ @ LogicalRelation(_: HadoopFsRelation, _, _) => 1
+ case _ @ LogicalRelation(_: HadoopFsRelation, _, _, _) => 1
case _: LeafNode => 0
// UNION ALL has multiple children, but these children do not concurrently use InputFileBlock.
case u: Union =>
@@ -454,10 +454,10 @@ object PreWriteCheck extends (LogicalPlan => Unit) {
def apply(plan: LogicalPlan): Unit = {
plan.foreach {
- case InsertIntoTable(l @ LogicalRelation(relation, _, _), partition, query, _, _) =>
+ case InsertIntoTable(l @ LogicalRelation(relation, _, _, _), partition, query, _, _) =>
// Get all input data source relations of the query.
val srcRelations = query.collect {
- case LogicalRelation(src, _, _) => src
+ case LogicalRelation(src, _, _, _) => src
}
if (srcRelations.contains(relation)) {
failAnalysis("Cannot insert into table that is also being read from.")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
index deb2c24d0f16e..9fc4ffb651ec8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
@@ -75,7 +75,7 @@ import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan}
* For example, we have two stages with the following pre-shuffle partition size statistics:
* stage 1: [100 MB, 20 MB, 100 MB, 10MB, 30 MB]
* stage 2: [10 MB, 10 MB, 70 MB, 5 MB, 5 MB]
- * assuming the target input size is 128 MB, we will have three post-shuffle partitions,
+ * assuming the target input size is 128 MB, we will have four post-shuffle partitions,
* which are:
* - post-shuffle partition 0: pre-shuffle partition 0 (size 110 MB)
* - post-shuffle partition 1: pre-shuffle partition 1 (size 30 MB)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
index eebe6ad2e7944..0d06d83fb2f3c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
@@ -25,7 +25,7 @@ import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
-import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index bfa1e9d49a545..ab7bb8ab9d87a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution}
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.LongType
@@ -283,8 +283,8 @@ case class BroadcastHashJoinExec(
s"""
|boolean $conditionPassed = true;
|${eval.trim}
- |${ev.code}
|if ($matched != null) {
+ | ${ev.code}
| $conditionPassed = !${ev.isNull} && ${ev.value};
|}
""".stripMargin
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 2038cb9edb67d..1b6a28cde2931 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -23,7 +23,7 @@ import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
import com.esotericsoftware.kryo.io.{Input, Output}
import org.apache.spark.{SparkConf, SparkEnv, SparkException}
-import org.apache.spark.memory.{MemoryConsumer, MemoryMode, StaticMemoryManager, TaskMemoryManager}
+import org.apache.spark.memory.{MemoryConsumer, StaticMemoryManager, TaskMemoryManager}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index f1df41ca49c27..66e8031bb5191 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index c7ea119d848b3..2f958d1dccc06 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -57,14 +57,6 @@ trait BaseLimitExec extends UnaryExecNode with CodegenSupport {
val limit: Int
override def output: Seq[Attribute] = child.output
- // Do not enable whole stage codegen for a single limit.
- override def supportCodegen: Boolean = child match {
- case plan: CodegenSupport => plan.supportCodegen
- case _ => false
- }
-
- override def executeTake(n: Int): Array[InternalRow] = child.executeTake(math.min(n, limit))
-
protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
iter.take(limit)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
index 59d7e8dd6dffb..7ebbdb9846cce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.python
import org.apache.spark.api.python.PythonFunction
-import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable}
+import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable, UserDefinedExpression}
import org.apache.spark.sql.types.DataType
/**
@@ -29,7 +29,7 @@ case class PythonUDF(
func: PythonFunction,
dataType: DataType,
children: Seq[Expression])
- extends Expression with Unevaluable with NonSQLExpression {
+ extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression {
override def toString: String = s"$name(${children.mkString(", ")})"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala
index c9939ac1db746..17b6874a61648 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala
@@ -22,7 +22,6 @@ import java.net.URI
import org.apache.hadoop.fs.{FileStatus, Path}
import org.json4s.NoTypeHints
import org.json4s.jackson.Serialization
-import org.json4s.jackson.Serialization.{read, write}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
index 4b1b2520390ba..f17417343e289 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
@@ -171,7 +171,7 @@ class FileStreamSource(
className = fileFormatClassName,
options = optionsWithPartitionBasePath)
Dataset.ofRows(sparkSession, LogicalRelation(newDataSource.resolveRelation(
- checkFilesExist = false)))
+ checkFilesExist = false), isStreaming = true))
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala
index e76d4dc6125df..077a4778e34a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/RateSourceProvider.scala
@@ -200,7 +200,8 @@ class RateStreamSource(
s"rangeStart: $rangeStart, rangeEnd: $rangeEnd")
if (rangeStart == rangeEnd) {
- return sqlContext.internalCreateDataFrame(sqlContext.sparkContext.emptyRDD, schema)
+ return sqlContext.internalCreateDataFrame(
+ sqlContext.sparkContext.emptyRDD, schema, isStreaming = true)
}
val localStartTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(startSeconds)
@@ -211,7 +212,7 @@ class RateStreamSource(
val relative = math.round((v - rangeStart) * relativeMsPerValue)
InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), v)
}
- sqlContext.internalCreateDataFrame(rdd, schema)
+ sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true)
}
override def stop(): Unit = {}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 9bc114f138562..952e431fb19d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -17,9 +17,10 @@
package org.apache.spark.sql.execution.streaming
-import java.io.{InterruptedIOException, IOException}
+import java.io.{InterruptedIOException, IOException, UncheckedIOException}
+import java.nio.channels.ClosedByInterruptException
import java.util.UUID
-import java.util.concurrent.{CountDownLatch, TimeUnit}
+import java.util.concurrent.{CountDownLatch, ExecutionException, TimeUnit}
import java.util.concurrent.atomic.AtomicReference
import java.util.concurrent.locks.ReentrantLock
@@ -27,6 +28,7 @@ import scala.collection.mutable.{Map => MutableMap}
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
+import com.google.common.util.concurrent.UncheckedExecutionException
import org.apache.hadoop.fs.Path
import org.apache.spark.internal.Logging
@@ -164,7 +166,7 @@ class StreamExecution(
nextSourceId += 1
// We still need to use the previous `output` instead of `source.schema` as attributes in
// "df.logicalPlan" has already used attributes of the previous `output`.
- StreamingExecutionRelation(source, output)
+ StreamingExecutionRelation(source, output)(sparkSession)
})
}
sources = _logicalPlan.collect { case s: StreamingExecutionRelation => s.source }
@@ -335,7 +337,7 @@ class StreamExecution(
// `stop()` is already called. Let `finally` finish the cleanup.
}
} catch {
- case _: InterruptedException | _: InterruptedIOException if state.get == TERMINATED =>
+ case e if isInterruptedByStop(e) =>
// interrupted by stop()
updateStatusMessage("Stopped")
case e: IOException if e.getMessage != null
@@ -407,6 +409,32 @@ class StreamExecution(
}
}
+ private def isInterruptedByStop(e: Throwable): Boolean = {
+ if (state.get == TERMINATED) {
+ e match {
+ // InterruptedIOException - thrown when an I/O operation is interrupted
+ // ClosedByInterruptException - thrown when an I/O operation upon a channel is interrupted
+ case _: InterruptedException | _: InterruptedIOException | _: ClosedByInterruptException =>
+ true
+ // The cause of the following exceptions may be one of the above exceptions:
+ //
+ // UncheckedIOException - thrown by code that cannot throw a checked IOException, such as
+ //   BiFunction.apply
+ // ExecutionException - thrown when code running in a thread pool throws an exception
+ // UncheckedExecutionException - thrown by code that cannot throw a checked
+ //   ExecutionException, such as BiFunction.apply
+ case e2 @ (_: UncheckedIOException | _: ExecutionException | _: UncheckedExecutionException)
+ if e2.getCause != null =>
+ isInterruptedByStop(e2.getCause)
+ case _ =>
+ false
+ }
+ } else {
+ false
+ }
+ }
+
/**
* Populate the start offsets to start the execution at the current offsets stored in the sink
* (i.e. avoid reprocessing data that we have already processed). This function must be called
@@ -609,6 +637,9 @@ class StreamExecution(
if committedOffsets.get(source).map(_ != available).getOrElse(true) =>
val current = committedOffsets.get(source)
val batch = source.getBatch(current, available)
+ assert(batch.isStreaming,
+ s"DataFrame returned by getBatch from $source did not have isStreaming=true\n" +
+ s"${batch.queryExecution.logical}")
logDebug(s"Retrieving data from $source: $current -> $available")
Some(source -> batch)
case _ => None
@@ -628,7 +659,7 @@ class StreamExecution(
replacements ++= output.zip(newPlan.output)
newPlan
}.getOrElse {
- LocalRelation(output)
+ LocalRelation(output, isStreaming = true)
}
}
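
A standalone Scala sketch (outside the patch) of the cause-unwrapping idea behind the new isInterruptedByStop helper above, using only JDK exception types; the helper name is illustrative.

    import java.io.{InterruptedIOException, UncheckedIOException}
    import java.nio.channels.ClosedByInterruptException
    import java.util.concurrent.ExecutionException

    // Treat a throwable as an interrupt even when it arrives inside an
    // unchecked or execution wrapper, by recursing into getCause.
    def looksLikeInterrupt(e: Throwable): Boolean = e match {
      case _: InterruptedException | _: InterruptedIOException | _: ClosedByInterruptException =>
        true
      case wrapped @ (_: UncheckedIOException | _: ExecutionException) if wrapped.getCause != null =>
        looksLikeInterrupt(wrapped.getCause)
      case _ => false
    }

    // e.g. looksLikeInterrupt(new ExecutionException(new InterruptedException())) == true
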
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
index e8b00094add3a..ab716052c28ba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
@@ -18,9 +18,11 @@
package org.apache.spark.sql.execution.streaming
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LeafNode
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.execution.LeafExecNode
import org.apache.spark.sql.execution.datasources.DataSource
@@ -48,9 +50,21 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output:
* Used to link a streaming [[Source]] of data into a
* [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]].
*/
-case class StreamingExecutionRelation(source: Source, output: Seq[Attribute]) extends LeafNode {
+case class StreamingExecutionRelation(
+ source: Source,
+ output: Seq[Attribute])(session: SparkSession)
+ extends LeafNode {
+
override def isStreaming: Boolean = true
override def toString: String = source.toString
+
+ // There's no sensible value here. On the execution path, this relation will be
+ // swapped out with microbatches. But some dataframe operations (in particular explain) do lead
+ // to this node surviving analysis. So we satisfy the LeafNode contract with the session default
+ // value.
+ override def computeStats(): Statistics = Statistics(
+ sizeInBytes = BigInt(session.sessionState.conf.defaultSizeInBytes)
+ )
}
/**
@@ -65,7 +79,7 @@ case class StreamingRelationExec(sourceName: String, output: Seq[Attribute]) ext
}
object StreamingExecutionRelation {
- def apply(source: Source): StreamingExecutionRelation = {
- StreamingExecutionRelation(source, source.schema.toAttributes)
+ def apply(source: Source, session: SparkSession): StreamingExecutionRelation = {
+ StreamingExecutionRelation(source, source.schema.toAttributes)(session)
}
}
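
For orientation (not in the patch), a hedged sketch of how the new two-argument companion apply would be called, mirroring the MemoryStream change later in this diff; `source` and `spark` stand in for an existing Source and an active SparkSession.

    // Assuming source: Source and an active SparkSession named spark:
    val relation = StreamingExecutionRelation(source, spark)
    // The session is only consulted so computeStats() can report the configured default sizeInBytes.
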
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 587ae2bfb63fb..3041d4d703cb4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.streaming
import java.util.concurrent.atomic.AtomicInteger
import javax.annotation.concurrent.GuardedBy
+import scala.collection.JavaConverters._
+import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.util.control.NonFatal
@@ -27,13 +29,14 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, Statistics}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
+
object MemoryStream {
protected val currentBlockId = new AtomicInteger(0)
protected val memoryStreamId = new AtomicInteger(0)
@@ -44,13 +47,13 @@ object MemoryStream {
/**
* A [[Source]] that produces value stored in memory as they are added by the user. This [[Source]]
- * is primarily intended for use in unit tests as it can only replay data when the object is still
+ * is intended for use in unit tests as it can only replay data when the object is still
* available.
*/
case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
extends Source with Logging {
protected val encoder = encoderFor[A]
- protected val logicalPlan = StreamingExecutionRelation(this)
+ protected val logicalPlan = StreamingExecutionRelation(this, sqlContext.sparkSession)
protected val output = logicalPlan.output
/**
@@ -85,8 +88,9 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
}
def addData(data: TraversableOnce[A]): Offset = {
- import sqlContext.implicits._
- val ds = data.toVector.toDS()
+ val encoded = data.toVector.map(d => encoder.toRow(d).copy())
+ val plan = new LocalRelation(schema.toAttributes, encoded, isStreaming = true)
+ val ds = Dataset[A](sqlContext.sparkSession, plan)
logDebug(s"Adding ds: $ds")
this.synchronized {
currentOffset = currentOffset + 1
@@ -118,8 +122,8 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
batches.slice(sliceStart, sliceEnd)
}
- logDebug(
- s"MemoryBatch [$startOrdinal, $endOrdinal]: ${newBlocks.flatMap(_.collect()).mkString(", ")}")
+ logDebug(generateDebugString(newBlocks, startOrdinal, endOrdinal))
+
newBlocks
.map(_.toDF())
.reduceOption(_ union _)
@@ -128,6 +132,21 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
}
}
+ private def generateDebugString(
+ blocks: TraversableOnce[Dataset[A]],
+ startOrdinal: Int,
+ endOrdinal: Int): String = {
+ val originalUnsupportedCheck =
+ sqlContext.getConf("spark.sql.streaming.unsupportedOperationCheck")
+ try {
+ sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", "false")
+ s"MemoryBatch [$startOrdinal, $endOrdinal]: " +
+ s"${blocks.flatMap(_.collect()).mkString(", ")}"
+ } finally {
+ sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", originalUnsupportedCheck)
+ }
+ }
+
override def commit(end: Offset): Unit = synchronized {
def check(newOffset: LongOffset): Unit = {
val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt
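
A hedged test-side sketch (not in the patch) of driving the revised MemoryStream.addData above; it assumes an active SparkSession with an implicit SQLContext and encoders in scope.

    val input = MemoryStream[Int]             // rows are now encoded eagerly on addData
    val offset = input.addData(1, 2, 3)       // backed by a LocalRelation with isStreaming = true
    val batch = input.getBatch(None, offset)  // yields a streaming DataFrame for the new data
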
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala
index 8e63207959575..0b22cbc46e6bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala
@@ -29,8 +29,10 @@ import scala.util.{Failure, Success, Try}
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}
+import org.apache.spark.unsafe.types.UTF8String
object TextSocketSource {
@@ -126,17 +128,10 @@ class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, sqlCo
batches.slice(sliceStart, sliceEnd)
}
- import sqlContext.implicits._
- val rawBatch = sqlContext.createDataset(rawList)
-
- // Underlying MemoryStream has schema (String, Timestamp); strip out the timestamp
- // if requested.
- if (includeTimestamp) {
- rawBatch.toDF("value", "timestamp")
- } else {
- // Strip out timestamp
- rawBatch.select("_1").toDF("value")
- }
+ val rdd = sqlContext.sparkContext
+ .parallelize(rawList)
+ .map { case (v, ts) => InternalRow(UTF8String.fromString(v), ts.getTime) }
+ sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true)
}
override def commit(end: Offset): Unit = synchronized {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala
index e96fb9f7550a3..64c9d90edcab4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala
@@ -177,7 +177,7 @@ private[ui] class RunningExecutionTable(
showFailedJobs = true) {
override protected def header: Seq[String] =
- baseHeader ++ Seq("Running Jobs", "Succeeded Jobs", "Failed Jobs")
+ baseHeader ++ Seq("Running Job IDs", "Succeeded Job IDs", "Failed Job IDs")
}
private[ui] class CompletedExecutionTable(
@@ -195,7 +195,7 @@ private[ui] class CompletedExecutionTable(
showSucceededJobs = true,
showFailedJobs = false) {
- override protected def header: Seq[String] = baseHeader ++ Seq("Jobs")
+ override protected def header: Seq[String] = baseHeader ++ Seq("Job IDs")
}
private[ui] class FailedExecutionTable(
@@ -214,5 +214,5 @@ private[ui] class FailedExecutionTable(
showFailedJobs = true) {
override protected def header: Seq[String] =
- baseHeader ++ Seq("Succeeded Jobs", "Failed Jobs")
+ baseHeader ++ Seq("Succeeded Job IDs", "Failed Job IDs")
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
index b4a91230a0012..8c27af374febd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala
@@ -255,10 +255,8 @@ class SQLListener(conf: SparkConf) extends SparkListener with Logging {
// heartbeat reports
}
case None =>
- // TODO Now just set attemptId to 0. Should fix here when we can get the attempt
- // id from SparkListenerExecutorMetricsUpdate
stageMetrics.taskIdToMetricUpdates(taskId) = new SQLTaskMetrics(
- attemptId = 0, finished = finishTask, accumulatorUpdates)
+ finished = finishTask, accumulatorUpdates)
}
}
case None =>
@@ -478,10 +476,11 @@ private[ui] class SQLStageMetrics(
val stageAttemptId: Long,
val taskIdToMetricUpdates: mutable.HashMap[Long, SQLTaskMetrics] = mutable.HashMap.empty)
+
+// TODO Should add attemptId here when we can get it from SparkListenerExecutorMetricsUpdate
/**
* Store all accumulatorUpdates for a Spark task.
*/
private[ui] class SQLTaskMetrics(
- val attemptId: Long, // TODO not used yet
var finished: Boolean,
var accumulatorUpdates: Seq[(Long, Any)])
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 53b2552fa3b36..47324ed9f2fb8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3119,9 +3119,9 @@ object functions {
}
/**
- * (Scala-specific) Converts a column containing a `StructType` or `ArrayType` of `StructType`s
- * into a JSON string with the specified schema. Throws an exception, in the case of an
- * unsupported type.
+ * (Scala-specific) Converts a column containing a `StructType`, `ArrayType` of `StructType`s,
+ * a `MapType` or `ArrayType` of `MapType`s into a JSON string with the specified schema.
+ * Throws an exception in the case of an unsupported type.
*
* @param e a column containing a struct or array of the structs.
* @param options options to control how the struct column is converted into a json string.
@@ -3135,9 +3135,9 @@ object functions {
}
/**
- * (Java-specific) Converts a column containing a `StructType` or `ArrayType` of `StructType`s
- * into a JSON string with the specified schema. Throws an exception, in the case of an
- * unsupported type.
+ * (Java-specific) Converts a column containing a `StructType`, `ArrayType` of `StructType`s,
+ * a `MapType` or `ArrayType` of `MapType`s into a JSON string with the specified schema.
+ * Throws an exception in the case of an unsupported type.
*
* @param e a column containing a struct or array of the structs.
* @param options options to control how the struct column is converted into a json string.
@@ -3150,8 +3150,9 @@ object functions {
to_json(e, options.asScala.toMap)
/**
- * Converts a column containing a `StructType` or `ArrayType` of `StructType`s into a JSON string
- * with the specified schema. Throws an exception, in the case of an unsupported type.
+ * Converts a column containing a `StructType`, `ArrayType` of `StructType`s,
+ * a `MapType` or `ArrayType` of `MapType`s into a JSON string with the specified schema.
+ * Throws an exception in the case of an unsupported type.
*
* @param e a column containing a struct or array of the structs.
*
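
A hedged Scala sketch (outside the patch) of the MapType support now documented for to_json, consistent with the SQL test expectations later in this diff; it assumes an active SparkSession named spark.

    import org.apache.spark.sql.functions.to_json
    import spark.implicits._

    val df = Seq(Map("a" -> 1)).toDF("m")
    df.select(to_json($"m")).show()   // single row: {"a":1}
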
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index 70ddfa8e9b835..a42e28053a96a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -195,6 +195,9 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* (e.g. 00012)
* `allowBackslashEscapingAnyCharacter` (default `false`): allows accepting quoting of all
* character using backslash quoting mechanism
+ * `allowUnquotedControlChars` (default `false`): whether to allow JSON strings to contain
+ * unquoted control characters (ASCII characters with value less than 32, including tab and
+ * line feed characters).
* `mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt records
* during parsing.
*
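
A hedged usage sketch (not part of the patch) for the allowUnquotedControlChars option documented above; the schema and input path are placeholders, and spark is assumed to be an active SparkSession.

    import org.apache.spark.sql.types.StructType

    val jsonSchema = new StructType().add("value", "string")   // placeholder schema
    val stream = spark.readStream
      .format("json")
      .schema(jsonSchema)
      .option("allowUnquotedControlChars", "true")   // accept raw control characters inside string values
      .load("/path/to/json/dir")                     // placeholder input directory
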
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
index 3000c4233cfb3..cedc1dce4a703 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
@@ -55,6 +55,8 @@ class StateOperatorProgress private[sql](
("numRowsUpdated" -> JInt(numRowsUpdated)) ~
("memoryUsedBytes" -> JInt(memoryUsedBytes))
}
+
+ override def toString: String = prettyJson
}
/**
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 4ca3b6406a328..13b006fc48ac3 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -1283,6 +1283,76 @@ public void test() {
ds.collectAsList();
}
+ public enum MyEnum {
+ A("www.elgoog.com"),
+ B("www.google.com");
+
+ private String url;
+
+ MyEnum(String url) {
+ this.url = url;
+ }
+
+ public String getUrl() {
+ return url;
+ }
+
+ public void setUrl(String url) {
+ this.url = url;
+ }
+ }
+
+ public static class BeanWithEnum {
+ MyEnum enumField;
+ String regularField;
+
+ public String getRegularField() {
+ return regularField;
+ }
+
+ public void setRegularField(String regularField) {
+ this.regularField = regularField;
+ }
+
+ public MyEnum getEnumField() {
+ return enumField;
+ }
+
+ public void setEnumField(MyEnum field) {
+ this.enumField = field;
+ }
+
+ public BeanWithEnum(MyEnum enumField, String regularField) {
+ this.enumField = enumField;
+ this.regularField = regularField;
+ }
+
+ public BeanWithEnum() {
+ }
+
+ public String toString() {
+ return "BeanWithEnum(" + enumField + ", " + regularField + ")";
+ }
+
+ public boolean equals(Object other) {
+ if (other instanceof BeanWithEnum) {
+ BeanWithEnum beanWithEnum = (BeanWithEnum) other;
+ return beanWithEnum.regularField.equals(regularField)
+ && beanWithEnum.enumField.equals(enumField);
+ }
+ return false;
+ }
+ }
+
+ @Test
+ public void testBeanWithEnum() {
+ List<BeanWithEnum> data = Arrays.asList(new BeanWithEnum(MyEnum.A, "mira avenue"),
+ new BeanWithEnum(MyEnum.B, "flower boulevard"));
+ Encoder<BeanWithEnum> encoder = Encoders.bean(BeanWithEnum.class);
+ Dataset<BeanWithEnum> ds = spark.createDataset(data, encoder);
+ Assert.assertEquals(ds.collectAsList(), data);
+ }
+
public static class EmptyBean implements Serializable {}
@Test
diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe-part-after-analyze.sql b/sql/core/src/test/resources/sql-tests/inputs/describe-part-after-analyze.sql
new file mode 100644
index 0000000000000..f4239da906276
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/describe-part-after-analyze.sql
@@ -0,0 +1,34 @@
+CREATE TABLE t (key STRING, value STRING, ds STRING, hr INT) USING parquet
+ PARTITIONED BY (ds, hr);
+
+INSERT INTO TABLE t PARTITION (ds='2017-08-01', hr=10)
+VALUES ('k1', 100), ('k2', 200), ('k3', 300);
+
+INSERT INTO TABLE t PARTITION (ds='2017-08-01', hr=11)
+VALUES ('k1', 101), ('k2', 201), ('k3', 301), ('k4', 401);
+
+INSERT INTO TABLE t PARTITION (ds='2017-09-01', hr=5)
+VALUES ('k1', 102), ('k2', 202);
+
+DESC EXTENDED t PARTITION (ds='2017-08-01', hr=10);
+
+-- Collect stats for a single partition
+ANALYZE TABLE t PARTITION (ds='2017-08-01', hr=10) COMPUTE STATISTICS;
+
+DESC EXTENDED t PARTITION (ds='2017-08-01', hr=10);
+
+-- Collect stats for 2 partitions
+ANALYZE TABLE t PARTITION (ds='2017-08-01') COMPUTE STATISTICS;
+
+DESC EXTENDED t PARTITION (ds='2017-08-01', hr=10);
+DESC EXTENDED t PARTITION (ds='2017-08-01', hr=11);
+
+-- Collect stats for all partitions
+ANALYZE TABLE t PARTITION (ds, hr) COMPUTE STATISTICS;
+
+DESC EXTENDED t PARTITION (ds='2017-08-01', hr=10);
+DESC EXTENDED t PARTITION (ds='2017-08-01', hr=11);
+DESC EXTENDED t PARTITION (ds='2017-09-01', hr=5);
+
+-- DROP TEST TABLES/VIEWS
+DROP TABLE t;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe-table-column.sql b/sql/core/src/test/resources/sql-tests/inputs/describe-table-column.sql
new file mode 100644
index 0000000000000..a6ddcd999bf9b
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/describe-table-column.sql
@@ -0,0 +1,41 @@
+-- Test temp table
+CREATE TEMPORARY VIEW desc_col_temp_view (key int COMMENT 'column_comment') USING PARQUET;
+
+DESC desc_col_temp_view key;
+
+DESC EXTENDED desc_col_temp_view key;
+
+DESC FORMATTED desc_col_temp_view key;
+
+-- Describe a column with qualified name
+DESC FORMATTED desc_col_temp_view desc_col_temp_view.key;
+
+-- Describe a non-existent column
+DESC desc_col_temp_view key1;
+
+-- Test persistent table
+CREATE TABLE desc_col_table (key int COMMENT 'column_comment') USING PARQUET;
+
+ANALYZE TABLE desc_col_table COMPUTE STATISTICS FOR COLUMNS key;
+
+DESC desc_col_table key;
+
+DESC EXTENDED desc_col_table key;
+
+DESC FORMATTED desc_col_table key;
+
+-- Test complex columns
+CREATE TABLE desc_complex_col_table (`a.b` int, col struct<x:int, y:string>) USING PARQUET;
+
+DESC FORMATTED desc_complex_col_table `a.b`;
+
+DESC FORMATTED desc_complex_col_table col;
+
+-- Describe a nested column
+DESC FORMATTED desc_complex_col_table col.x;
+
+DROP VIEW desc_col_temp_view;
+
+DROP TABLE desc_col_table;
+
+DROP TABLE desc_complex_col_table;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe.sql b/sql/core/src/test/resources/sql-tests/inputs/describe.sql
index a222e11916cda..f26d5efec076c 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/describe.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/describe.sql
@@ -1,7 +1,8 @@
CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet
OPTIONS (a '1', b '2')
PARTITIONED BY (c, d) CLUSTERED BY (a) SORTED BY (b ASC) INTO 2 BUCKETS
- COMMENT 'table_comment';
+ COMMENT 'table_comment'
+ TBLPROPERTIES (t 'test');
CREATE TEMPORARY VIEW temp_v AS SELECT * FROM t;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql
index b3cc2cea51d43..fea069eac4d48 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql
@@ -4,6 +4,11 @@ describe function extended to_json;
select to_json(named_struct('a', 1, 'b', 2));
select to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy'));
select to_json(array(named_struct('a', 1, 'b', 2)));
+select to_json(map(named_struct('a', 1, 'b', 2), named_struct('a', 1, 'b', 2)));
+select to_json(map('a', named_struct('a', 1, 'b', 2)));
+select to_json(map('a', 1));
+select to_json(array(map('a',1)));
+select to_json(array(map('a',1), map('b',2)));
-- Check if errors handled
select to_json(named_struct('a', 1, 'b', 2), named_struct('mode', 'PERMISSIVE'));
select to_json(named_struct('a', 1, 'b', 2), map('mode', 1));
@@ -20,3 +25,9 @@ select from_json('{"a":1}', 'a InvalidType');
select from_json('{"a":1}', 'a INT', named_struct('mode', 'PERMISSIVE'));
select from_json('{"a":1}', 'a INT', map('mode', 1));
select from_json();
+-- json_tuple
+SELECT json_tuple('{"a" : 1, "b" : 2}', CAST(NULL AS STRING), 'b', CAST(NULL AS STRING), 'a');
+CREATE TEMPORARY VIEW jsonTable(jsonField, a) AS SELECT * FROM VALUES ('{"a": 1, "b": 2}', 'a');
+SELECT json_tuple(jsonField, 'b', CAST(NULL AS STRING), a) FROM jsonTable;
+-- Clean up
+DROP VIEW IF EXISTS jsonTable;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql
new file mode 100644
index 0000000000000..3b3d4ad64b3ec
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/predicate-functions.sql
@@ -0,0 +1,36 @@
+-- EqualTo
+select 1 = 1;
+select 1 = '1';
+select 1.0 = '1';
+
+-- GreaterThan
+select 1 > '1';
+select 2 > '1.0';
+select 2 > '2.0';
+select 2 > '2.2';
+select to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52');
+select to_date('2009-07-30 04:17:52') > '2009-07-30 04:17:52';
+
+-- GreaterThanOrEqual
+select 1 >= '1';
+select 2 >= '1.0';
+select 2 >= '2.0';
+select 2.0 >= '2.2';
+select to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52');
+select to_date('2009-07-30 04:17:52') >= '2009-07-30 04:17:52';
+
+-- LessThan
+select 1 < '1';
+select 2 < '1.0';
+select 2 < '2.0';
+select 2.0 < '2.2';
+select to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52');
+select to_date('2009-07-30 04:17:52') < '2009-07-30 04:17:52';
+
+-- LessThanOrEqual
+select 1 <= '1';
+select 2 <= '1.0';
+select 2 <= '2.0';
+select 2.0 <= '2.2';
+select to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52');
+select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52';
diff --git a/sql/core/src/test/resources/sql-tests/inputs/show_columns.sql b/sql/core/src/test/resources/sql-tests/inputs/show_columns.sql
index 1e02c2f045ea9..521018e94e501 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/show_columns.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/show_columns.sql
@@ -2,9 +2,9 @@ CREATE DATABASE showdb;
USE showdb;
-CREATE TABLE showcolumn1 (col1 int, `col 2` int) USING parquet;
+CREATE TABLE showcolumn1 (col1 int, `col 2` int) USING json;
CREATE TABLE showcolumn2 (price int, qty int, year int, month int) USING parquet partitioned by (year, month);
-CREATE TEMPORARY VIEW showColumn3 (col3 int, `col 4` int) USING parquet;
+CREATE TEMPORARY VIEW showColumn3 (col3 int, `col 4` int) USING json;
CREATE GLOBAL TEMP VIEW showColumn4 AS SELECT 1 as col1, 'abc' as `col 5`;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/udaf.sql b/sql/core/src/test/resources/sql-tests/inputs/udaf.sql
new file mode 100644
index 0000000000000..2183ba23afc38
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/udaf.sql
@@ -0,0 +1,13 @@
+CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES
+(1), (2), (3), (4)
+as t1(int_col1);
+
+CREATE FUNCTION myDoubleAvg AS 'test.org.apache.spark.sql.MyDoubleAvg';
+
+SELECT default.myDoubleAvg(int_col1) as my_avg from t1;
+
+SELECT default.myDoubleAvg(int_col1, 3) as my_avg from t1;
+
+CREATE FUNCTION udaf1 AS 'test.non.existent.udaf';
+
+SELECT default.udaf1(int_col1) as udaf1 from t1;
diff --git a/sql/core/src/test/resources/sql-tests/results/cross-join.sql.out b/sql/core/src/test/resources/sql-tests/results/cross-join.sql.out
index e75cc4448a1ea..3833c42bdfecf 100644
--- a/sql/core/src/test/resources/sql-tests/results/cross-join.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/cross-join.sql.out
@@ -128,6 +128,7 @@ two 2 two 2 one 1 two 2
two 2 two 2 three 3 two 2
two 2 two 2 two 2 two 2
+
-- !query 12
SELECT * FROM nt1 CROSS JOIN nt2 ON (nt1.k > nt2.k)
-- !query 12 schema
diff --git a/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out
new file mode 100644
index 0000000000000..43f73e3b22aa5
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/describe-part-after-analyze.sql.out
@@ -0,0 +1,244 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 15
+
+
+-- !query 0
+CREATE TABLE t (key STRING, value STRING, ds STRING, hr INT) USING parquet
+ PARTITIONED BY (ds, hr)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+INSERT INTO TABLE t PARTITION (ds='2017-08-01', hr=10)
+VALUES ('k1', 100), ('k2', 200), ('k3', 300)
+-- !query 1 schema
+struct<>
+-- !query 1 output
+
+
+
+-- !query 2
+INSERT INTO TABLE t PARTITION (ds='2017-08-01', hr=11)
+VALUES ('k1', 101), ('k2', 201), ('k3', 301), ('k4', 401)
+-- !query 2 schema
+struct<>
+-- !query 2 output
+
+
+
+-- !query 3
+INSERT INTO TABLE t PARTITION (ds='2017-09-01', hr=5)
+VALUES ('k1', 102), ('k2', 202)
+-- !query 3 schema
+struct<>
+-- !query 3 output
+
+
+
+-- !query 4
+DESC EXTENDED t PARTITION (ds='2017-08-01', hr=10)
+-- !query 4 schema
+struct
+-- !query 4 output
+key string
+value string
+ds string
+hr int
+# Partition Information
+# col_name data_type comment
+ds string
+hr int
+
+# Detailed Partition Information
+Database default
+Table t
+Partition Values [ds=2017-08-01, hr=10]
+Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10
+
+# Storage Information
+Location [not included in comparison]sql/core/spark-warehouse/t
+
+
+-- !query 5
+ANALYZE TABLE t PARTITION (ds='2017-08-01', hr=10) COMPUTE STATISTICS
+-- !query 5 schema
+struct<>
+-- !query 5 output
+
+
+
+-- !query 6
+DESC EXTENDED t PARTITION (ds='2017-08-01', hr=10)
+-- !query 6 schema
+struct
+-- !query 6 output
+key string
+value string
+ds string
+hr int
+# Partition Information
+# col_name data_type comment
+ds string
+hr int
+
+# Detailed Partition Information
+Database default
+Table t
+Partition Values [ds=2017-08-01, hr=10]
+Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10
+Partition Statistics 1027 bytes, 3 rows
+
+# Storage Information
+Location [not included in comparison]sql/core/spark-warehouse/t
+
+
+-- !query 7
+ANALYZE TABLE t PARTITION (ds='2017-08-01') COMPUTE STATISTICS
+-- !query 7 schema
+struct<>
+-- !query 7 output
+
+
+
+-- !query 8
+DESC EXTENDED t PARTITION (ds='2017-08-01', hr=10)
+-- !query 8 schema
+struct
+-- !query 8 output
+key string
+value string
+ds string
+hr int
+# Partition Information
+# col_name data_type comment
+ds string
+hr int
+
+# Detailed Partition Information
+Database default
+Table t
+Partition Values [ds=2017-08-01, hr=10]
+Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10
+Partition Statistics 1027 bytes, 3 rows
+
+# Storage Information
+Location [not included in comparison]sql/core/spark-warehouse/t
+
+
+-- !query 9
+DESC EXTENDED t PARTITION (ds='2017-08-01', hr=11)
+-- !query 9 schema
+struct
+-- !query 9 output
+key string
+value string
+ds string
+hr int
+# Partition Information
+# col_name data_type comment
+ds string
+hr int
+
+# Detailed Partition Information
+Database default
+Table t
+Partition Values [ds=2017-08-01, hr=11]
+Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11
+Partition Statistics 1040 bytes, 4 rows
+
+# Storage Information
+Location [not included in comparison]sql/core/spark-warehouse/t
+
+
+-- !query 10
+ANALYZE TABLE t PARTITION (ds, hr) COMPUTE STATISTICS
+-- !query 10 schema
+struct<>
+-- !query 10 output
+
+
+
+-- !query 11
+DESC EXTENDED t PARTITION (ds='2017-08-01', hr=10)
+-- !query 11 schema
+struct
+-- !query 11 output
+key string
+value string
+ds string
+hr int
+# Partition Information
+# col_name data_type comment
+ds string
+hr int
+
+# Detailed Partition Information
+Database default
+Table t
+Partition Values [ds=2017-08-01, hr=10]
+Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=10
+Partition Statistics 1027 bytes, 3 rows
+
+# Storage Information
+Location [not included in comparison]sql/core/spark-warehouse/t
+
+
+-- !query 12
+DESC EXTENDED t PARTITION (ds='2017-08-01', hr=11)
+-- !query 12 schema
+struct
+-- !query 12 output
+key string
+value string
+ds string
+hr int
+# Partition Information
+# col_name data_type comment
+ds string
+hr int
+
+# Detailed Partition Information
+Database default
+Table t
+Partition Values [ds=2017-08-01, hr=11]
+Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-08-01/hr=11
+Partition Statistics 1040 bytes, 4 rows
+
+# Storage Information
+Location [not included in comparison]sql/core/spark-warehouse/t
+
+
+-- !query 13
+DESC EXTENDED t PARTITION (ds='2017-09-01', hr=5)
+-- !query 13 schema
+struct
+-- !query 13 output
+key string
+value string
+ds string
+hr int
+# Partition Information
+# col_name data_type comment
+ds string
+hr int
+
+# Detailed Partition Information
+Database default
+Table t
+Partition Values [ds=2017-09-01, hr=5]
+Location [not included in comparison]sql/core/spark-warehouse/t/ds=2017-09-01/hr=5
+Partition Statistics 1014 bytes, 2 rows
+
+# Storage Information
+Location [not included in comparison]sql/core/spark-warehouse/t
+
+
+-- !query 14
+DROP TABLE t
+-- !query 14 schema
+struct<>
+-- !query 14 output
+
diff --git a/sql/core/src/test/resources/sql-tests/results/describe-table-column.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-table-column.sql.out
new file mode 100644
index 0000000000000..30d0a2dc5a3f7
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/describe-table-column.sql.out
@@ -0,0 +1,208 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 18
+
+
+-- !query 0
+CREATE TEMPORARY VIEW desc_col_temp_view (key int COMMENT 'column_comment') USING PARQUET
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+DESC desc_col_temp_view key
+-- !query 1 schema
+struct
+-- !query 1 output
+col_name key
+data_type int
+comment column_comment
+
+
+-- !query 2
+DESC EXTENDED desc_col_temp_view key
+-- !query 2 schema
+struct
+-- !query 2 output
+col_name key
+data_type int
+comment column_comment
+min NULL
+max NULL
+num_nulls NULL
+distinct_count NULL
+avg_col_len NULL
+max_col_len NULL
+
+
+-- !query 3
+DESC FORMATTED desc_col_temp_view key
+-- !query 3 schema
+struct
+-- !query 3 output
+col_name key
+data_type int
+comment column_comment
+min NULL
+max NULL
+num_nulls NULL
+distinct_count NULL
+avg_col_len NULL
+max_col_len NULL
+
+
+-- !query 4
+DESC FORMATTED desc_col_temp_view desc_col_temp_view.key
+-- !query 4 schema
+struct
+-- !query 4 output
+col_name key
+data_type int
+comment column_comment
+min NULL
+max NULL
+num_nulls NULL
+distinct_count NULL
+avg_col_len NULL
+max_col_len NULL
+
+
+-- !query 5
+DESC desc_col_temp_view key1
+-- !query 5 schema
+struct<>
+-- !query 5 output
+org.apache.spark.sql.AnalysisException
+Column key1 does not exist;
+
+
+-- !query 6
+CREATE TABLE desc_col_table (key int COMMENT 'column_comment') USING PARQUET
+-- !query 6 schema
+struct<>
+-- !query 6 output
+
+
+
+-- !query 7
+ANALYZE TABLE desc_col_table COMPUTE STATISTICS FOR COLUMNS key
+-- !query 7 schema
+struct<>
+-- !query 7 output
+
+
+
+-- !query 8
+DESC desc_col_table key
+-- !query 8 schema
+struct
+-- !query 8 output
+col_name key
+data_type int
+comment column_comment
+
+
+-- !query 9
+DESC EXTENDED desc_col_table key
+-- !query 9 schema
+struct
+-- !query 9 output
+col_name key
+data_type int
+comment column_comment
+min NULL
+max NULL
+num_nulls 0
+distinct_count 0
+avg_col_len 4
+max_col_len 4
+
+
+-- !query 10
+DESC FORMATTED desc_col_table key
+-- !query 10 schema
+struct
+-- !query 10 output
+col_name key
+data_type int
+comment column_comment
+min NULL
+max NULL
+num_nulls 0
+distinct_count 0
+avg_col_len 4
+max_col_len 4
+
+
+-- !query 11
+CREATE TABLE desc_complex_col_table (`a.b` int, col struct<x:int, y:string>) USING PARQUET
+-- !query 11 schema
+struct<>
+-- !query 11 output
+
+
+
+-- !query 12
+DESC FORMATTED desc_complex_col_table `a.b`
+-- !query 12 schema
+struct
+-- !query 12 output
+col_name a.b
+data_type int
+comment NULL
+min NULL
+max NULL
+num_nulls NULL
+distinct_count NULL
+avg_col_len NULL
+max_col_len NULL
+
+
+-- !query 13
+DESC FORMATTED desc_complex_col_table col
+-- !query 13 schema
+struct
+-- !query 13 output
+col_name col
+data_type struct<x:int,y:string>
+comment NULL
+min NULL
+max NULL
+num_nulls NULL
+distinct_count NULL
+avg_col_len NULL
+max_col_len NULL
+
+
+-- !query 14
+DESC FORMATTED desc_complex_col_table col.x
+-- !query 14 schema
+struct<>
+-- !query 14 output
+org.apache.spark.sql.AnalysisException
+DESC TABLE COLUMN command does not support nested data types: col.x;
+
+
+-- !query 15
+DROP VIEW desc_col_temp_view
+-- !query 15 schema
+struct<>
+-- !query 15 output
+
+
+
+-- !query 16
+DROP TABLE desc_col_table
+-- !query 16 schema
+struct<>
+-- !query 16 output
+
+
+
+-- !query 17
+DROP TABLE desc_complex_col_table
+-- !query 17 schema
+struct<>
+-- !query 17 output
+
diff --git a/sql/core/src/test/resources/sql-tests/results/describe.sql.out b/sql/core/src/test/resources/sql-tests/results/describe.sql.out
index b91f2c09f3cd4..8c908b7625056 100644
--- a/sql/core/src/test/resources/sql-tests/results/describe.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/describe.sql.out
@@ -7,6 +7,7 @@ CREATE TABLE t (a STRING, b INT, c STRING, d STRING) USING parquet
OPTIONS (a '1', b '2')
PARTITIONED BY (c, d) CLUSTERED BY (a) SORTED BY (b ASC) INTO 2 BUCKETS
COMMENT 'table_comment'
+ TBLPROPERTIES (t 'test')
-- !query 0 schema
struct<>
-- !query 0 output
@@ -129,7 +130,7 @@ Num Buckets 2
Bucket Columns [`a`]
Sort Columns [`b`]
Comment table_comment
-Table Properties [e=3]
+Table Properties [t=test, e=3]
Location [not included in comparison]sql/core/spark-warehouse/t
Storage Properties [a=1, b=2]
Partition Provider Catalog
@@ -161,7 +162,7 @@ Num Buckets 2
Bucket Columns [`a`]
Sort Columns [`b`]
Comment table_comment
-Table Properties [e=3]
+Table Properties [t=test, e=3]
Location [not included in comparison]sql/core/spark-warehouse/t
Storage Properties [a=1, b=2]
Partition Provider Catalog
@@ -201,6 +202,7 @@ Num Buckets 2
Bucket Columns [`a`]
Sort Columns [`b`]
Comment table_comment
+Table Properties [t=test]
Location [not included in comparison]sql/core/spark-warehouse/t
Storage Properties [a=1, b=2]
Partition Provider Catalog
@@ -239,6 +241,7 @@ Provider parquet
Num Buckets 2
Bucket Columns [`a`]
Sort Columns [`b`]
+Table Properties [t=test]
Location [not included in comparison]sql/core/spark-warehouse/t
Storage Properties [a=1, b=2]
Partition Provider Catalog
diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
index 22da20d9a9f4e..d9dc728a18e8d 100644
--- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 17
+-- Number of queries: 26
-- !query 0
@@ -26,6 +26,14 @@ Extended Usage:
{"time":"26/08/2015"}
> SELECT to_json(array(named_struct('a', 1, 'b', 2));
[{"a":1,"b":2}]
+ > SELECT to_json(map('a', named_struct('b', 1)));
+ {"a":{"b":1}}
+ > SELECT to_json(map(named_struct('a', 1),named_struct('b', 2)));
+ {"[1]":{"b":2}}
+ > SELECT to_json(map('a', 1));
+ {"a":1}
+ > SELECT to_json(array((map('a', 1))));
+ [{"a":1}]
Since: 2.2.0
@@ -58,47 +66,87 @@ struct
-- !query 5
-select to_json(named_struct('a', 1, 'b', 2), named_struct('mode', 'PERMISSIVE'))
+select to_json(map(named_struct('a', 1, 'b', 2), named_struct('a', 1, 'b', 2)))
-- !query 5 schema
-struct<>
+struct
-- !query 5 output
+{"[1,2]":{"a":1,"b":2}}
+
+
+-- !query 6
+select to_json(map('a', named_struct('a', 1, 'b', 2)))
+-- !query 6 schema
+struct
+-- !query 6 output
+{"a":{"a":1,"b":2}}
+
+
+-- !query 7
+select to_json(map('a', 1))
+-- !query 7 schema
+struct
+-- !query 7 output
+{"a":1}
+
+
+-- !query 8
+select to_json(array(map('a',1)))
+-- !query 8 schema
+struct
+-- !query 8 output
+[{"a":1}]
+
+
+-- !query 9
+select to_json(array(map('a',1), map('b',2)))
+-- !query 9 schema
+struct
+-- !query 9 output
+[{"a":1},{"b":2}]
+
+
+-- !query 10
+select to_json(named_struct('a', 1, 'b', 2), named_struct('mode', 'PERMISSIVE'))
+-- !query 10 schema
+struct<>
+-- !query 10 output
org.apache.spark.sql.AnalysisException
Must use a map() function for options;; line 1 pos 7
--- !query 6
+-- !query 11
select to_json(named_struct('a', 1, 'b', 2), map('mode', 1))
--- !query 6 schema
+-- !query 11 schema
struct<>
--- !query 6 output
+-- !query 11 output
org.apache.spark.sql.AnalysisException
A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7
--- !query 7
+-- !query 12
select to_json()
--- !query 7 schema
+-- !query 12 schema
struct<>
--- !query 7 output
+-- !query 12 output
org.apache.spark.sql.AnalysisException
Invalid number of arguments for function to_json; line 1 pos 7
--- !query 8
+-- !query 13
describe function from_json
--- !query 8 schema
+-- !query 13 schema
struct
--- !query 8 output
+-- !query 13 output
Class: org.apache.spark.sql.catalyst.expressions.JsonToStructs
Function: from_json
Usage: from_json(jsonStr, schema[, options]) - Returns a struct value with the given `jsonStr` and `schema`.
--- !query 9
+-- !query 14
describe function extended from_json
--- !query 9 schema
+-- !query 14 schema
struct
--- !query 9 output
+-- !query 14 output
Class: org.apache.spark.sql.catalyst.expressions.JsonToStructs
Extended Usage:
Examples:
@@ -113,36 +161,36 @@ Function: from_json
Usage: from_json(jsonStr, schema[, options]) - Returns a struct value with the given `jsonStr` and `schema`.
--- !query 10
+-- !query 15
select from_json('{"a":1}', 'a INT')
--- !query 10 schema
+-- !query 15 schema
struct>
--- !query 10 output
+-- !query 15 output
{"a":1}
--- !query 11
+-- !query 16
select from_json('{"time":"26/08/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy'))
--- !query 11 schema
+-- !query 16 schema
struct>
--- !query 11 output
+-- !query 16 output
{"time":2015-08-26 00:00:00.0}
--- !query 12
+-- !query 17
select from_json('{"a":1}', 1)
--- !query 12 schema
+-- !query 17 schema
struct<>
--- !query 12 output
+-- !query 17 output
org.apache.spark.sql.AnalysisException
Expected a string literal instead of 1;; line 1 pos 7
--- !query 13
+-- !query 18
select from_json('{"a":1}', 'a InvalidType')
--- !query 13 schema
+-- !query 18 schema
struct<>
--- !query 13 output
+-- !query 18 output
org.apache.spark.sql.AnalysisException
DataType invalidtype is not supported.(line 1, pos 2)
@@ -153,28 +201,60 @@ a InvalidType
; line 1 pos 7
--- !query 14
+-- !query 19
select from_json('{"a":1}', 'a INT', named_struct('mode', 'PERMISSIVE'))
--- !query 14 schema
+-- !query 19 schema
struct<>
--- !query 14 output
+-- !query 19 output
org.apache.spark.sql.AnalysisException
Must use a map() function for options;; line 1 pos 7
--- !query 15
+-- !query 20
select from_json('{"a":1}', 'a INT', map('mode', 1))
--- !query 15 schema
+-- !query 20 schema
struct<>
--- !query 15 output
+-- !query 20 output
org.apache.spark.sql.AnalysisException
A type of keys and values in map() must be string, but got MapType(StringType,IntegerType,false);; line 1 pos 7
--- !query 16
+-- !query 21
select from_json()
--- !query 16 schema
+-- !query 21 schema
struct<>
--- !query 16 output
+-- !query 21 output
org.apache.spark.sql.AnalysisException
Invalid number of arguments for function from_json; line 1 pos 7
+
+
+-- !query 22
+SELECT json_tuple('{"a" : 1, "b" : 2}', CAST(NULL AS STRING), 'b', CAST(NULL AS STRING), 'a')
+-- !query 22 schema
+struct
+-- !query 22 output
+NULL 2 NULL 1
+
+
+-- !query 23
+CREATE TEMPORARY VIEW jsonTable(jsonField, a) AS SELECT * FROM VALUES ('{"a": 1, "b": 2}', 'a')
+-- !query 23 schema
+struct<>
+-- !query 23 output
+
+
+
+-- !query 24
+SELECT json_tuple(jsonField, 'b', CAST(NULL AS STRING), a) FROM jsonTable
+-- !query 24 schema
+struct
+-- !query 24 output
+2 NULL 1
+
+
+-- !query 25
+DROP VIEW IF EXISTS jsonTable
+-- !query 25 schema
+struct<>
+-- !query 25 output
+
diff --git a/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out
new file mode 100644
index 0000000000000..8e7e04c8e1c4f
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/predicate-functions.sql.out
@@ -0,0 +1,218 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 27
+
+
+-- !query 0
+select 1 = 1
+-- !query 0 schema
+struct<(1 = 1):boolean>
+-- !query 0 output
+true
+
+
+-- !query 1
+select 1 = '1'
+-- !query 1 schema
+struct<(1 = CAST(1 AS INT)):boolean>
+-- !query 1 output
+true
+
+
+-- !query 2
+select 1.0 = '1'
+-- !query 2 schema
+struct<(1.0 = CAST(1 AS DECIMAL(2,1))):boolean>
+-- !query 2 output
+true
+
+
+-- !query 3
+select 1 > '1'
+-- !query 3 schema
+struct<(1 > CAST(1 AS INT)):boolean>
+-- !query 3 output
+false
+
+
+-- !query 4
+select 2 > '1.0'
+-- !query 4 schema
+struct<(2 > CAST(1.0 AS INT)):boolean>
+-- !query 4 output
+true
+
+
+-- !query 5
+select 2 > '2.0'
+-- !query 5 schema
+struct<(2 > CAST(2.0 AS INT)):boolean>
+-- !query 5 output
+false
+
+
+-- !query 6
+select 2 > '2.2'
+-- !query 6 schema
+struct<(2 > CAST(2.2 AS INT)):boolean>
+-- !query 6 output
+false
+
+
+-- !query 7
+select to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52')
+-- !query 7 schema
+struct<(to_date('2009-07-30 04:17:52') > to_date('2009-07-30 04:17:52')):boolean>
+-- !query 7 output
+false
+
+
+-- !query 8
+select to_date('2009-07-30 04:17:52') > '2009-07-30 04:17:52'
+-- !query 8 schema
+struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) > 2009-07-30 04:17:52):boolean>
+-- !query 8 output
+false
+
+
+-- !query 9
+select 1 >= '1'
+-- !query 9 schema
+struct<(1 >= CAST(1 AS INT)):boolean>
+-- !query 9 output
+true
+
+
+-- !query 10
+select 2 >= '1.0'
+-- !query 10 schema
+struct<(2 >= CAST(1.0 AS INT)):boolean>
+-- !query 10 output
+true
+
+
+-- !query 11
+select 2 >= '2.0'
+-- !query 11 schema
+struct<(2 >= CAST(2.0 AS INT)):boolean>
+-- !query 11 output
+true
+
+
+-- !query 12
+select 2.0 >= '2.2'
+-- !query 12 schema
+struct<(2.0 >= CAST(2.2 AS DECIMAL(2,1))):boolean>
+-- !query 12 output
+false
+
+
+-- !query 13
+select to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52')
+-- !query 13 schema
+struct<(to_date('2009-07-30 04:17:52') >= to_date('2009-07-30 04:17:52')):boolean>
+-- !query 13 output
+true
+
+
+-- !query 14
+select to_date('2009-07-30 04:17:52') >= '2009-07-30 04:17:52'
+-- !query 14 schema
+struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) >= 2009-07-30 04:17:52):boolean>
+-- !query 14 output
+false
+
+
+-- !query 15
+select 1 < '1'
+-- !query 15 schema
+struct<(1 < CAST(1 AS INT)):boolean>
+-- !query 15 output
+false
+
+
+-- !query 16
+select 2 < '1.0'
+-- !query 16 schema
+struct<(2 < CAST(1.0 AS INT)):boolean>
+-- !query 16 output
+false
+
+
+-- !query 17
+select 2 < '2.0'
+-- !query 17 schema
+struct<(2 < CAST(2.0 AS INT)):boolean>
+-- !query 17 output
+false
+
+
+-- !query 18
+select 2.0 < '2.2'
+-- !query 18 schema
+struct<(2.0 < CAST(2.2 AS DECIMAL(2,1))):boolean>
+-- !query 18 output
+true
+
+
+-- !query 19
+select to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52')
+-- !query 19 schema
+struct<(to_date('2009-07-30 04:17:52') < to_date('2009-07-30 04:17:52')):boolean>
+-- !query 19 output
+false
+
+
+-- !query 20
+select to_date('2009-07-30 04:17:52') < '2009-07-30 04:17:52'
+-- !query 20 schema
+struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) < 2009-07-30 04:17:52):boolean>
+-- !query 20 output
+true
+
+
+-- !query 21
+select 1 <= '1'
+-- !query 21 schema
+struct<(1 <= CAST(1 AS INT)):boolean>
+-- !query 21 output
+true
+
+
+-- !query 22
+select 2 <= '1.0'
+-- !query 22 schema
+struct<(2 <= CAST(1.0 AS INT)):boolean>
+-- !query 22 output
+false
+
+
+-- !query 23
+select 2 <= '2.0'
+-- !query 23 schema
+struct<(2 <= CAST(2.0 AS INT)):boolean>
+-- !query 23 output
+true
+
+
+-- !query 24
+select 2.0 <= '2.2'
+-- !query 24 schema
+struct<(2.0 <= CAST(2.2 AS DECIMAL(2,1))):boolean>
+-- !query 24 output
+true
+
+
+-- !query 25
+select to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52')
+-- !query 25 schema
+struct<(to_date('2009-07-30 04:17:52') <= to_date('2009-07-30 04:17:52')):boolean>
+-- !query 25 output
+true
+
+
+-- !query 26
+select to_date('2009-07-30 04:17:52') <= '2009-07-30 04:17:52'
+-- !query 26 schema
+struct<(CAST(to_date('2009-07-30 04:17:52') AS STRING) <= 2009-07-30 04:17:52):boolean>
+-- !query 26 output
+true
diff --git a/sql/core/src/test/resources/sql-tests/results/show_columns.sql.out b/sql/core/src/test/resources/sql-tests/results/show_columns.sql.out
index 05c3a083ee3b3..71d6e120e8943 100644
--- a/sql/core/src/test/resources/sql-tests/results/show_columns.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/show_columns.sql.out
@@ -19,7 +19,7 @@ struct<>
-- !query 2
-CREATE TABLE showcolumn1 (col1 int, `col 2` int) USING parquet
+CREATE TABLE showcolumn1 (col1 int, `col 2` int) USING json
-- !query 2 schema
struct<>
-- !query 2 output
@@ -35,7 +35,7 @@ struct<>
-- !query 4
-CREATE TEMPORARY VIEW showColumn3 (col3 int, `col 4` int) USING parquet
+CREATE TEMPORARY VIEW showColumn3 (col3 int, `col 4` int) USING json
-- !query 4 schema
struct<>
-- !query 4 output
diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out
index 9ea9d3c4c6f40..70aeb9373f3c7 100644
--- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/subq-input-typecheck.sql.out
@@ -80,8 +80,7 @@ number of columns in the output of subquery.
Left side columns:
[t1.`t1a`].
Right side columns:
-[t2.`t2a`, t2.`t2b`].
- ;
+[t2.`t2a`, t2.`t2b`].;
-- !query 6
@@ -102,5 +101,4 @@ number of columns in the output of subquery.
Left side columns:
[t1.`t1a`, t1.`t1b`].
Right side columns:
-[t2.`t2a`].
- ;
+[t2.`t2a`].;
diff --git a/sql/core/src/test/resources/sql-tests/results/udaf.sql.out b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out
new file mode 100644
index 0000000000000..4815a578b1029
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out
@@ -0,0 +1,54 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 6
+
+
+-- !query 0
+CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES
+(1), (2), (3), (4)
+as t1(int_col1)
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+CREATE FUNCTION myDoubleAvg AS 'test.org.apache.spark.sql.MyDoubleAvg'
+-- !query 1 schema
+struct<>
+-- !query 1 output
+
+
+
+-- !query 2
+SELECT default.myDoubleAvg(int_col1) as my_avg from t1
+-- !query 2 schema
+struct
+-- !query 2 output
+102.5
+
+
+-- !query 3
+SELECT default.myDoubleAvg(int_col1, 3) as my_avg from t1
+-- !query 3 schema
+struct<>
+-- !query 3 output
+java.lang.AssertionError
+assertion failed: Incorrect number of children
+
+
+-- !query 4
+CREATE FUNCTION udaf1 AS 'test.non.existent.udaf'
+-- !query 4 schema
+struct<>
+-- !query 4 output
+
+
+
+-- !query 5
+SELECT default.udaf1(int_col1) as udaf1 from t1
+-- !query 5 schema
+struct<>
+-- !query 5 output
+org.apache.spark.sql.AnalysisException
+Can not load class 'test.non.existent.udaf' when registering the function 'default.udaf1', please make sure it is on the classpath; line 1 pos 7
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index affe97120c8f6..8549eac58ee95 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -190,6 +190,22 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
)
}
+ test("SPARK-21980: References in grouping functions should be indexed with semanticEquals") {
+ checkAnswer(
+ courseSales.cube("course", "year")
+ .agg(grouping("CouRse"), grouping("year")),
+ Row("Java", 2012, 0, 0) ::
+ Row("Java", 2013, 0, 0) ::
+ Row("Java", null, 0, 1) ::
+ Row("dotNET", 2012, 0, 0) ::
+ Row("dotNET", 2013, 0, 0) ::
+ Row("dotNET", null, 0, 1) ::
+ Row(null, 2012, 1, 0) ::
+ Row(null, 2013, 1, 0) ::
+ Row(null, null, 1, 1) :: Nil
+ )
+ }
+
test("rollup overlapping columns") {
checkAnswer(
testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 0681b9cbeb1d8..50e475984f458 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -422,7 +422,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
v
}
withSQLConf(
- (SQLConf.WHOLESTAGE_FALLBACK.key, codegenFallback.toString),
+ (SQLConf.CODEGEN_FALLBACK.key, codegenFallback.toString),
(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString)) {
val df = spark.range(0, 4, 1, 4).withColumn("c", c)
val rows = df.collect()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index de0c14b1c880d..91e6e051963a3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2012,7 +2012,17 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val filter = (0 until N)
.foldLeft(lit(false))((e, index) => e.or(df.col(df.columns(index)) =!= "string"))
- df.filter(filter).count
+
+ withSQLConf(SQLConf.CODEGEN_FALLBACK.key -> "true") {
+ df.filter(filter).count()
+ }
+
+ withSQLConf(SQLConf.CODEGEN_FALLBACK.key -> "false") {
+ val e = intercept[SparkException] {
+ df.filter(filter).count()
+ }.getMessage
+ assert(e.contains("grows beyond 64 KB"))
+ }
}
test("SPARK-20897: cached self-join should not fail") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index f62f9e23db66d..edcdd77908d3a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -151,7 +151,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
test("foreachPartition") {
val ds = Seq(1, 2, 3).toDS()
val acc = sparkContext.longAccumulator
- ds.foreachPartition(_.foreach(acc.add(_)))
+ ds.foreachPartition((it: Iterator[Int]) => it.foreach(acc.add(_)))
assert(acc.value == 6)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 6245b2eff9fa1..5015f3709f131 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -364,7 +364,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
test("foreachPartition") {
val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
val acc = sparkContext.longAccumulator
- ds.foreachPartition(_.foreach(v => acc.add(v._2)))
+ ds.foreachPartition((it: Iterator[(String, Int)]) => it.foreach(v => acc.add(v._2)))
assert(acc.value == 6)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 86fe09bd977af..453052a8ce191 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
+import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer
import scala.language.existentials
@@ -26,6 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.StructType
class JoinSuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -767,4 +769,22 @@ class JoinSuite extends QueryTest with SharedSQLContext {
}
}
}
+
+ test("outer broadcast hash join should not throw NPE") {
+ withTempView("v1", "v2") {
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") {
+ Seq(2 -> 2).toDF("x", "y").createTempView("v1")
+
+ spark.createDataFrame(
+ Seq(Row(1, "a")).asJava,
+ new StructType().add("i", "int", nullable = false).add("j", "string", nullable = false)
+ ).createTempView("v2")
+
+ checkAnswer(
+ sql("select x, y, i, j from v1 left join v2 on x = i and y < length(j)"),
+ Row(2, 2, null, null)
+ )
+ }
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
index cf2d00fc94423..00d2acc4a1d8a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql
-import org.apache.spark.sql.functions.{from_json, struct, to_json}
+import org.apache.spark.sql.functions.{from_json, lit, map, struct, to_json}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -180,10 +180,26 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
test("to_json - array") {
val df = Seq(Tuple1(Tuple1(1) :: Nil)).toDF("a")
+ val df2 = Seq(Tuple1(Map("a" -> 1) :: Nil)).toDF("a")
checkAnswer(
df.select(to_json($"a")),
Row("""[{"_1":1}]""") :: Nil)
+ checkAnswer(
+ df2.select(to_json($"a")),
+ Row("""[{"a":1}]""") :: Nil)
+ }
+
+ test("to_json - map") {
+ val df1 = Seq(Map("a" -> Tuple1(1))).toDF("a")
+ val df2 = Seq(Map("a" -> 1)).toDF("a")
+
+ checkAnswer(
+ df1.select(to_json($"a")),
+ Row("""{"a":{"_1":1}}""") :: Nil)
+ checkAnswer(
+ df2.select(to_json($"a")),
+ Row("""{"a":1}""") :: Nil)
}
test("to_json with option") {
@@ -195,15 +211,33 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
Row("""{"_1":"26/08/2015 18:00"}""") :: Nil)
}
- test("to_json unsupported type") {
+ test("to_json - key types of map don't matter") {
+ // interval type is invalid for converting to JSON. However, the keys of a map are treated
+ // as strings, so their types don't matter.
val df = Seq(Tuple1(Tuple1("interval -3 month 7 hours"))).toDF("a")
- .select(struct($"a._1".cast(CalendarIntervalType).as("a")).as("c"))
+ .select(struct(map($"a._1".cast(CalendarIntervalType), lit("a")).as("col1")).as("c"))
+ checkAnswer(
+ df.select(to_json($"c")),
+ Row("""{"col1":{"interval -3 months 7 hours":"a"}}""") :: Nil)
+ }
+
+ test("to_json unsupported type") {
+ val baseDf = Seq(Tuple1(Tuple1("interval -3 month 7 hours"))).toDF("a")
+ val df = baseDf.select(struct($"a._1".cast(CalendarIntervalType).as("a")).as("c"))
val e = intercept[AnalysisException]{
// Unsupported type throws an exception
df.select(to_json($"c")).collect()
}
assert(e.getMessage.contains(
"Unable to convert column a of type calendarinterval to JSON."))
+
+ // interval type is invalid for converting to JSON. We can't use it as the value type of a map.
+ val df2 = baseDf
+ .select(struct(map(lit("a"), $"a._1".cast(CalendarIntervalType)).as("col1")).as("c"))
+ val e2 = intercept[AnalysisException] {
+ df2.select(to_json($"c")).collect()
+ }
+ assert(e2.getMessage.contains("Unable to convert column col1 of type calendarinterval to JSON"))
}
test("roundtrip in to_json and from_json - struct") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 923c6d8eb71fd..93a7777b70b46 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2663,4 +2663,18 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
// In unit test, Spark will fail the query if memory leak detected.
spark.range(100).groupBy("id").count().limit(1).collect()
}
+
+ test("SPARK-21652: rule confliction of InferFiltersFromConstraints and ConstantPropagation") {
+ withTempView("t1", "t2") {
+ Seq((1, 1)).toDF("col1", "col2").createOrReplaceTempView("t1")
+ Seq(1, 2).toDF("col").createOrReplaceTempView("t2")
+ val df = sql(
+ """
+ |SELECT *
+ |FROM t1, t2
+ |WHERE t1.col1 = 1 AND 1 = t1.col2 AND t1.col1 = t2.col AND t1.col2 = t2.col
+ """.stripMargin)
+ checkAnswer(df, Row(1, 1, 1))
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
index aa000bddf9c7e..e3901af4b9988 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.{fileToString, stringToFile}
-import org.apache.spark.sql.execution.command.DescribeTableCommand
+import org.apache.spark.sql.execution.command.{DescribeColumnCommand, DescribeTableCommand}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.StructType
@@ -214,11 +214,11 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext {
/** Executes a query and returns the result as (schema of the output, normalized output). */
private def getNormalizedResult(session: SparkSession, sql: String): (StructType, Seq[String]) = {
// Returns true if the plan is supposed to be sorted.
- def needSort(plan: LogicalPlan): Boolean = plan match {
+ def isSorted(plan: LogicalPlan): Boolean = plan match {
case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false
- case _: DescribeTableCommand => true
+ case _: DescribeTableCommand | _: DescribeColumnCommand => true
case PhysicalOperation(_, _, Sort(_, true, _)) => true
- case _ => plan.children.iterator.exists(needSort)
+ case _ => plan.children.iterator.exists(isSorted)
}
try {
@@ -233,7 +233,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext {
.replaceAll("Last Access.*", s"Last Access $notIncludedMsg"))
// If the output is not pre-sorted, sort it.
- if (needSort(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted)
+ if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted)
} catch {
case a: AnalysisException =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 274694b99541e..8673dc14f7597 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
+import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.test.SharedSQLContext
class SubquerySuite extends QueryTest with SharedSQLContext {
@@ -875,4 +876,78 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
assert(e.message.contains("cannot resolve '`a`' given input columns: [t.i, t.j]"))
}
}
+
+ test("SPARK-21835: Join in correlated subquery should be duplicateResolved: case 1") {
+ withTable("t1") {
+ withTempPath { path =>
+ Seq(1 -> "a").toDF("i", "j").write.parquet(path.getCanonicalPath)
+ sql(s"CREATE TABLE t1 USING parquet LOCATION '${path.toURI}'")
+
+ val sqlText =
+ """
+ |SELECT * FROM t1
+ |WHERE
+ |NOT EXISTS (SELECT * FROM t1)
+ """.stripMargin
+ val optimizedPlan = sql(sqlText).queryExecution.optimizedPlan
+ val join = optimizedPlan.collectFirst { case j: Join => j }.get
+ assert(join.duplicateResolved)
+ assert(optimizedPlan.resolved)
+ }
+ }
+ }
+
+ test("SPARK-21835: Join in correlated subquery should be duplicateResolved: case 2") {
+ withTable("t1", "t2", "t3") {
+ withTempPath { path =>
+ val data = Seq((1, 1, 1), (2, 0, 2))
+
+ data.toDF("t1a", "t1b", "t1c").write.parquet(path.getCanonicalPath + "/t1")
+ data.toDF("t2a", "t2b", "t2c").write.parquet(path.getCanonicalPath + "/t2")
+ data.toDF("t3a", "t3b", "t3c").write.parquet(path.getCanonicalPath + "/t3")
+
+ sql(s"CREATE TABLE t1 USING parquet LOCATION '${path.toURI}/t1'")
+ sql(s"CREATE TABLE t2 USING parquet LOCATION '${path.toURI}/t2'")
+ sql(s"CREATE TABLE t3 USING parquet LOCATION '${path.toURI}/t3'")
+
+ val sqlText =
+ s"""
+ |SELECT *
+ |FROM (SELECT *
+ | FROM t2
+ | WHERE t2c IN (SELECT t1c
+ | FROM t1
+ | WHERE t1a = t2a)
+ | UNION
+ | SELECT *
+ | FROM t3
+ | WHERE t3a IN (SELECT t2a
+ | FROM t2
+ | UNION ALL
+ | SELECT t1a
+ | FROM t1
+ | WHERE t1b > 0)) t4
+ |WHERE t4.t2b IN (SELECT Min(t3b)
+ | FROM t3
+ | WHERE t4.t2a = t3a)
+ """.stripMargin
+ val optimizedPlan = sql(sqlText).queryExecution.optimizedPlan
+ val joinNodes = optimizedPlan.collect { case j: Join => j }
+ joinNodes.foreach(j => assert(j.duplicateResolved))
+ assert(optimizedPlan.resolved)
+ }
+ }
+ }
+
+ test("SPARK-21835: Join in correlated subquery should be duplicateResolved: case 3") {
+ val sqlText =
+ """
+ |SELECT * FROM l, r WHERE l.a = r.c + 1 AND
+ |(EXISTS (SELECT * FROM r) OR l.a = r.c)
+ """.stripMargin
+ val optimizedPlan = sql(sqlText).queryExecution.optimizedPlan
+ val join = optimizedPlan.collectFirst { case j: Join => j }.get
+ assert(join.duplicateResolved)
+ assert(optimizedPlan.resolved)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index b096a6db8517f..a08433ba794d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -203,12 +203,12 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT
// Tests to make sure that all operators correctly convert types on the way out.
test("Local UDTs") {
- val df = Seq((1, new UDT.MyDenseVector(Array(0.1, 1.0)))).toDF("int", "vec")
- df.collect()(0).getAs[UDT.MyDenseVector](1)
- df.take(1)(0).getAs[UDT.MyDenseVector](1)
- df.limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[UDT.MyDenseVector](0)
- df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0)
- .getAs[UDT.MyDenseVector](0)
+ val vec = new UDT.MyDenseVector(Array(0.1, 1.0))
+ val df = Seq((1, vec)).toDF("int", "vec")
+ assert(vec === df.collect()(0).getAs[UDT.MyDenseVector](1))
+ assert(vec === df.take(1)(0).getAs[UDT.MyDenseVector](1))
+ checkAnswer(df.limit(1).groupBy('int).agg(first('vec)), Row(1, vec))
+ checkAnswer(df.orderBy('int).limit(1).groupBy('int).agg(first('vec)), Row(1, vec))
}
test("UDTs with JSON") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
index 06bce9a2400e7..f1b5e3be5b63f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
@@ -280,7 +280,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
Seq(Some(5), None).foreach { minNumPostShufflePartitions =>
val testNameNote = minNumPostShufflePartitions match {
- case Some(numPartitions) => "(minNumPostShufflePartitions: 3)"
+ case Some(numPartitions) => "(minNumPostShufflePartitions: " + numPartitions + ")"
case None => ""
}
@@ -377,7 +377,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
}
test(s"determining the number of reducers: complex query 1$testNameNote") {
- val test = { spark: SparkSession =>
+ val test: (SparkSession) => Unit = { spark: SparkSession =>
val df1 =
spark
.range(0, 1000, 1, numInputPartitions)
@@ -429,7 +429,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
}
test(s"determining the number of reducers: complex query 2$testNameNote") {
- val test = { spark: SparkSession =>
+ val test: (SparkSession) => Unit = { spark: SparkSession =>
val df1 =
spark
.range(0, 1000, 1, numInputPartitions)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala
index 58c310596ca6d..78c1e5dae566d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuerySuite.scala
@@ -42,14 +42,14 @@ class OptimizeMetadataOnlyQuerySuite extends QueryTest with SharedSQLContext {
private def assertMetadataOnlyQuery(df: DataFrame): Unit = {
val localRelations = df.queryExecution.optimizedPlan.collect {
- case l @ LocalRelation(_, _) => l
+ case l @ LocalRelation(_, _, _) => l
}
assert(localRelations.size == 1)
}
private def assertNotMetadataOnlyQuery(df: DataFrame): Unit = {
val localRelations = df.queryExecution.optimizedPlan.collect {
- case l @ LocalRelation(_, _) => l
+ case l @ LocalRelation(_, _, _) => l
}
assert(localRelations.size == 0)
}
@@ -117,4 +117,12 @@ class OptimizeMetadataOnlyQuerySuite extends QueryTest with SharedSQLContext {
"select partcol1, max(partcol2) from srcpart where partcol1 = 0 group by rollup (partcol1)",
"select partcol2 from (select partcol2 from srcpart where partcol1 = 0 union all " +
"select partcol2 from srcpart where partcol1 = 1) t group by partcol2")
+
+ test("SPARK-21884 Fix StackOverflowError on MetadataOnlyQuery") {
+ withTable("t_1000") {
+ sql("CREATE TABLE t_1000 (a INT, p INT) USING PARQUET PARTITIONED BY (p)")
+ (1 to 1000).foreach(p => sql(s"ALTER TABLE t_1000 ADD PARTITION (p=$p)"))
+ sql("SELECT COUNT(DISTINCT p) FROM t_1000").collect()
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala
new file mode 100644
index 0000000000000..c2e62b987e0cc
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLJsonProtocolSuite.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.json4s.jackson.JsonMethods.parse
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart
+import org.apache.spark.util.JsonProtocol
+
+class SQLJsonProtocolSuite extends SparkFunSuite {
+
+ test("SparkPlanGraph backward compatibility: metadata") {
+ val SQLExecutionStartJsonString =
+ """
+ |{
+ | "Event":"org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionStart",
+ | "executionId":0,
+ | "description":"test desc",
+ | "details":"test detail",
+ | "physicalPlanDescription":"test plan",
+ | "sparkPlanInfo": {
+ | "nodeName":"TestNode",
+ | "simpleString":"test string",
+ | "children":[],
+ | "metadata":{},
+ | "metrics":[]
+ | },
+ | "time":0
+ |}
+ """.stripMargin
+ val reconstructedEvent = JsonProtocol.sparkEventFromJson(parse(SQLExecutionStartJsonString))
+ val expectedEvent = SparkListenerSQLExecutionStart(0, "test desc", "test detail", "test plan",
+ new SparkPlanInfo("TestNode", "test string", Nil, Nil), 0)
+ assert(reconstructedEvent == expectedEvent)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala
index aecfd3062147c..5828f9783da42 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala
@@ -40,7 +40,7 @@ class SparkPlannerSuite extends SharedSQLContext {
case Union(children) =>
planned += 1
UnionExec(children.map(planLater)) :: planLater(NeverPlanned) :: Nil
- case LocalRelation(output, data) =>
+ case LocalRelation(output, data, _) =>
planned += 1
LocalTableScanExec(output, data) :: planLater(NeverPlanned) :: Nil
case NeverPlanned =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
index d238c76fbeeff..107a2f7109793 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
@@ -249,8 +249,34 @@ class SparkSqlParserSuite extends AnalysisTest {
assertEqual("describe table formatted t",
DescribeTableCommand(
TableIdentifier("t"), Map.empty, isExtended = true))
+ }
+
+ test("describe table column") {
+ assertEqual("DESCRIBE t col",
+ DescribeColumnCommand(
+ TableIdentifier("t"), Seq("col"), isExtended = false))
+ assertEqual("DESCRIBE t `abc.xyz`",
+ DescribeColumnCommand(
+ TableIdentifier("t"), Seq("abc.xyz"), isExtended = false))
+ assertEqual("DESCRIBE t abc.xyz",
+ DescribeColumnCommand(
+ TableIdentifier("t"), Seq("abc", "xyz"), isExtended = false))
+ assertEqual("DESCRIBE t `a.b`.`x.y`",
+ DescribeColumnCommand(
+ TableIdentifier("t"), Seq("a.b", "x.y"), isExtended = false))
+
+ assertEqual("DESCRIBE TABLE t col",
+ DescribeColumnCommand(
+ TableIdentifier("t"), Seq("col"), isExtended = false))
+ assertEqual("DESCRIBE TABLE EXTENDED t col",
+ DescribeColumnCommand(
+ TableIdentifier("t"), Seq("col"), isExtended = true))
+ assertEqual("DESCRIBE TABLE FORMATTED t col",
+ DescribeColumnCommand(
+ TableIdentifier("t"), Seq("col"), isExtended = true))
- intercept("explain describe tables x", "Unsupported SQL statement")
+ intercept("DESCRIBE TABLE t PARTITION (ds='1970-01-01') col",
+ "DESC TABLE COLUMN for a specific partition is not supported")
}
test("analyze table statistics") {
@@ -259,17 +285,33 @@ class SparkSqlParserSuite extends AnalysisTest {
assertEqual("analyze table t compute statistics noscan",
AnalyzeTableCommand(TableIdentifier("t"), noscan = true))
assertEqual("analyze table t partition (a) compute statistics nOscAn",
- AnalyzeTableCommand(TableIdentifier("t"), noscan = true))
+ AnalyzePartitionCommand(TableIdentifier("t"), Map("a" -> None), noscan = true))
- // Partitions specified - we currently parse them but don't do anything with it
+ // Partitions specified
assertEqual("ANALYZE TABLE t PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS",
- AnalyzeTableCommand(TableIdentifier("t"), noscan = false))
+ AnalyzePartitionCommand(TableIdentifier("t"), noscan = false,
+ partitionSpec = Map("ds" -> Some("2008-04-09"), "hr" -> Some("11"))))
assertEqual("ANALYZE TABLE t PARTITION(ds='2008-04-09', hr=11) COMPUTE STATISTICS noscan",
- AnalyzeTableCommand(TableIdentifier("t"), noscan = true))
+ AnalyzePartitionCommand(TableIdentifier("t"), noscan = true,
+ partitionSpec = Map("ds" -> Some("2008-04-09"), "hr" -> Some("11"))))
+ assertEqual("ANALYZE TABLE t PARTITION(ds='2008-04-09') COMPUTE STATISTICS noscan",
+ AnalyzePartitionCommand(TableIdentifier("t"), noscan = true,
+ partitionSpec = Map("ds" -> Some("2008-04-09"))))
+ assertEqual("ANALYZE TABLE t PARTITION(ds='2008-04-09', hr) COMPUTE STATISTICS",
+ AnalyzePartitionCommand(TableIdentifier("t"), noscan = false,
+ partitionSpec = Map("ds" -> Some("2008-04-09"), "hr" -> None)))
+ assertEqual("ANALYZE TABLE t PARTITION(ds='2008-04-09', hr) COMPUTE STATISTICS noscan",
+ AnalyzePartitionCommand(TableIdentifier("t"), noscan = true,
+ partitionSpec = Map("ds" -> Some("2008-04-09"), "hr" -> None)))
+ assertEqual("ANALYZE TABLE t PARTITION(ds, hr=11) COMPUTE STATISTICS noscan",
+ AnalyzePartitionCommand(TableIdentifier("t"), noscan = true,
+ partitionSpec = Map("ds" -> None, "hr" -> Some("11"))))
assertEqual("ANALYZE TABLE t PARTITION(ds, hr) COMPUTE STATISTICS",
- AnalyzeTableCommand(TableIdentifier("t"), noscan = false))
+ AnalyzePartitionCommand(TableIdentifier("t"), noscan = false,
+ partitionSpec = Map("ds" -> None, "hr" -> None)))
assertEqual("ANALYZE TABLE t PARTITION(ds, hr) COMPUTE STATISTICS noscan",
- AnalyzeTableCommand(TableIdentifier("t"), noscan = true))
+ AnalyzePartitionCommand(TableIdentifier("t"), noscan = true,
+ partitionSpec = Map("ds" -> None, "hr" -> None)))
intercept("analyze table t compute statistics xxxx",
"Expected `NOSCAN` instead of `xxxx`")
@@ -282,6 +324,11 @@ class SparkSqlParserSuite extends AnalysisTest {
assertEqual("ANALYZE TABLE t COMPUTE STATISTICS FOR COLUMNS key, value",
AnalyzeColumnCommand(TableIdentifier("t"), Seq("key", "value")))
+
+ // Partition specified - should be ignored
+ assertEqual("ANALYZE TABLE t PARTITION(ds='2017-06-10') " +
+ "COMPUTE STATISTICS FOR COLUMNS key, value",
+ AnalyzeColumnCommand(TableIdentifier("t"), Seq("key", "value")))
}
test("query organization") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
index 4893b52f240ec..30422b657742c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
@@ -29,8 +29,9 @@ import org.apache.arrow.vector.file.json.JsonFileReader
import org.apache.arrow.vector.util.Validator
import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType}
import org.apache.spark.util.Utils
@@ -1629,6 +1630,32 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll {
}
}
+ test("roundtrip payloads") {
+ val inputRows = (0 until 9).map { i =>
+ InternalRow(i)
+ } :+ InternalRow(null)
+
+ val schema = StructType(Seq(StructField("int", IntegerType, nullable = true)))
+
+ val ctx = TaskContext.empty()
+ val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, ctx)
+ val outputRowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx)
+
+ assert(schema.equals(outputRowIter.schema))
+
+ var count = 0
+ outputRowIter.zipWithIndex.foreach { case (row, i) =>
+ if (i != 9) {
+ assert(row.getInt(0) == i)
+ } else {
+ assert(row.isNullAt(0))
+ }
+ count += 1
+ }
+
+ assert(count == inputRows.length)
+ }
+
/** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */
private def collectAndValidate(df: DataFrame, json: String, file: String): Unit = {
// NOTE: coalesce to single partition because can only load 1 batch in validator
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala
index 46db41a8abad9..5a25d72308370 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution.benchmark
+import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.IntegerType
@@ -35,7 +36,9 @@ class JoinBenchmark extends BenchmarkBase {
val dim = broadcast(sparkSession.range(M).selectExpr("id as k", "cast(id as string) as v"))
runBenchmark("Join w long", N) {
- sparkSession.range(N).join(dim, (col("id") % M) === col("k")).count()
+ val df = sparkSession.range(N).join(dim, (col("id") % M) === col("k"))
+ assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined)
+ df.count()
}
/*
@@ -55,7 +58,9 @@ class JoinBenchmark extends BenchmarkBase {
val dim = broadcast(sparkSession.range(M).selectExpr("id as k", "cast(id as string) as v"))
runBenchmark("Join w long duplicated", N) {
val dim = broadcast(sparkSession.range(M).selectExpr("cast(id/10 as long) as k"))
- sparkSession.range(N).join(dim, (col("id") % M) === col("k")).count()
+ val df = sparkSession.range(N).join(dim, (col("id") % M) === col("k"))
+ assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined)
+ df.count()
}
/*
@@ -75,9 +80,11 @@ class JoinBenchmark extends BenchmarkBase {
.selectExpr("cast(id as int) as k1", "cast(id as int) as k2", "cast(id as string) as v"))
runBenchmark("Join w 2 ints", N) {
- sparkSession.range(N).join(dim2,
+ val df = sparkSession.range(N).join(dim2,
(col("id") % M).cast(IntegerType) === col("k1")
- && (col("id") % M).cast(IntegerType) === col("k2")).count()
+ && (col("id") % M).cast(IntegerType) === col("k2"))
+ assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined)
+ df.count()
}
/*
@@ -97,9 +104,10 @@ class JoinBenchmark extends BenchmarkBase {
.selectExpr("id as k1", "id as k2", "cast(id as string) as v"))
runBenchmark("Join w 2 longs", N) {
- sparkSession.range(N).join(dim3,
+ val df = sparkSession.range(N).join(dim3,
(col("id") % M) === col("k1") && (col("id") % M) === col("k2"))
- .count()
+ assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined)
+ df.count()
}
/*
@@ -119,9 +127,10 @@ class JoinBenchmark extends BenchmarkBase {
.selectExpr("cast(id/10 as long) as k1", "cast(id/10 as long) as k2"))
runBenchmark("Join w 2 longs duplicated", N) {
- sparkSession.range(N).join(dim4,
+ val df = sparkSession.range(N).join(dim4,
(col("id") bitwiseAND M) === col("k1") && (col("id") bitwiseAND M) === col("k2"))
- .count()
+ assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined)
+ df.count()
}
/*
@@ -138,7 +147,9 @@ class JoinBenchmark extends BenchmarkBase {
val M = 1 << 16
val dim = broadcast(sparkSession.range(M).selectExpr("id as k", "cast(id as string) as v"))
runBenchmark("outer join w long", N) {
- sparkSession.range(N).join(dim, (col("id") % M) === col("k"), "left").count()
+ val df = sparkSession.range(N).join(dim, (col("id") % M) === col("k"), "left")
+ assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined)
+ df.count()
}
/*
@@ -156,7 +167,9 @@ class JoinBenchmark extends BenchmarkBase {
val M = 1 << 16
val dim = broadcast(sparkSession.range(M).selectExpr("id as k", "cast(id as string) as v"))
runBenchmark("semi join w long", N) {
- sparkSession.range(N).join(dim, (col("id") % M) === col("k"), "leftsemi").count()
+ val df = sparkSession.range(N).join(dim, (col("id") % M) === col("k"), "leftsemi")
+ assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[BroadcastHashJoinExec]).isDefined)
+ df.count()
}
/*
@@ -174,7 +187,9 @@ class JoinBenchmark extends BenchmarkBase {
runBenchmark("merge join", N) {
val df1 = sparkSession.range(N).selectExpr(s"id * 2 as k1")
val df2 = sparkSession.range(N).selectExpr(s"id * 3 as k2")
- df1.join(df2, col("k1") === col("k2")).count()
+ val df = df1.join(df2, col("k1") === col("k2"))
+ assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[SortMergeJoinExec]).isDefined)
+ df.count()
}
/*
@@ -193,7 +208,9 @@ class JoinBenchmark extends BenchmarkBase {
.selectExpr(s"(id * 15485863) % ${N*10} as k1")
val df2 = sparkSession.range(N)
.selectExpr(s"(id * 15485867) % ${N*10} as k2")
- df1.join(df2, col("k1") === col("k2")).count()
+ val df = df1.join(df2, col("k1") === col("k2"))
+ assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[SortMergeJoinExec]).isDefined)
+ df.count()
}
/*
@@ -212,18 +229,19 @@ class JoinBenchmark extends BenchmarkBase {
sparkSession.conf.set("spark.sql.join.preferSortMergeJoin", "false")
runBenchmark("shuffle hash join", N) {
val df1 = sparkSession.range(N).selectExpr(s"id as k1")
- val df2 = sparkSession.range(N / 5).selectExpr(s"id * 3 as k2")
- df1.join(df2, col("k1") === col("k2")).count()
+ val df2 = sparkSession.range(N / 3).selectExpr(s"id * 3 as k2")
+ val df = df1.join(df2, col("k1") === col("k2"))
+ assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[ShuffledHashJoinExec]).isDefined)
+ df.count()
}
/*
- *Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5
- *Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ *Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Windows 7 6.1
+ *Intel64 Family 6 Model 94 Stepping 3, GenuineIntel
*shuffle hash join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
*-------------------------------------------------------------------------------------------
- *shuffle hash join codegen=false 1101 / 1391 3.8 262.6 1.0X
- *shuffle hash join codegen=true 528 / 578 7.9 125.8 2.1X
+ *shuffle hash join codegen=false 2005 / 2010 2.1 478.0 1.0X
+ *shuffle hash join codegen=true 1773 / 1792 2.4 422.7 1.1X
*/
}
-
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala
index d2d013682cd2d..99c6df7389205 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala
@@ -17,9 +17,8 @@
package org.apache.spark.sql.execution.benchmark
-import java.io.File
-
import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
@@ -31,9 +30,9 @@ import org.apache.spark.util.Benchmark
/**
* Benchmark to measure TPCDS query performance.
* To run this:
- * spark-submit --class <this class> --jars <spark sql test jar>
+ * spark-submit --class <this class> <spark sql test jar> --data-location <TPCDS data location>
*/
-object TPCDSQueryBenchmark {
+object TPCDSQueryBenchmark extends Logging {
val conf =
new SparkConf()
.setMaster("local[1]")
@@ -61,12 +60,10 @@ object TPCDSQueryBenchmark {
}
def tpcdsAll(dataLocation: String, queries: Seq[String]): Unit = {
- require(dataLocation.nonEmpty,
- "please modify the value of dataLocation to point to your local TPCDS data")
val tableSizes = setupTables(dataLocation)
queries.foreach { name =>
- val queryString = fileToString(new File(Thread.currentThread().getContextClassLoader
- .getResource(s"tpcds/$name.sql").getFile))
+ val queryString = resourceToString(s"tpcds/$name.sql",
+ classLoader = Thread.currentThread().getContextClassLoader)
// This is an indirect hack to estimate the size of each query's input by traversing the
// logical plan and adding up the sizes of all tables that appear in the plan. Note that this
@@ -94,11 +91,14 @@ object TPCDSQueryBenchmark {
benchmark.addCase(name) { i =>
spark.sql(queryString).collect()
}
+ logInfo(s"\n\n===== TPCDS QUERY BENCHMARK OUTPUT FOR $name =====\n")
benchmark.run()
+ logInfo(s"\n\n===== FINISHED $name =====\n")
}
}
def main(args: Array[String]): Unit = {
+ val benchmarkArgs = new TPCDSQueryBenchmarkArguments(args)
// List of all TPC-DS queries
val tpcdsQueries = Seq(
@@ -113,12 +113,20 @@ object TPCDSQueryBenchmark {
"q81", "q82", "q83", "q84", "q85", "q86", "q87", "q88", "q89", "q90",
"q91", "q92", "q93", "q94", "q95", "q96", "q97", "q98", "q99")
- // In order to run this benchmark, please follow the instructions at
- // https://github.com/databricks/spark-sql-perf/blob/master/README.md to generate the TPCDS data
- // locally (preferably with a scale factor of 5 for benchmarking). Thereafter, the value of
- // dataLocation below needs to be set to the location where the generated data is stored.
- val dataLocation = ""
+ // If `--query-filter` defined, filters the queries that this option selects
+ val queriesToRun = if (benchmarkArgs.queryFilter.nonEmpty) {
+ val queries = tpcdsQueries.filter { case queryName =>
+ benchmarkArgs.queryFilter.contains(queryName)
+ }
+ if (queries.isEmpty) {
+ throw new RuntimeException(
+ s"Empty queries to run. Bad query name filter: ${benchmarkArgs.queryFilter}")
+ }
+ queries
+ } else {
+ tpcdsQueries
+ }
- tpcdsAll(dataLocation, queries = tpcdsQueries)
+ tpcdsAll(benchmarkArgs.dataLocation, queries = queriesToRun)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmarkArguments.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmarkArguments.scala
new file mode 100644
index 0000000000000..184ffff94298a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmarkArguments.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.benchmark
+
+import java.util.Locale
+
+
+class TPCDSQueryBenchmarkArguments(val args: Array[String]) {
+ var dataLocation: String = null
+ var queryFilter: Set[String] = Set.empty
+
+ parseArgs(args.toList)
+ validateArguments()
+
+ private def optionMatch(optionName: String, s: String): Boolean = {
+ optionName == s.toLowerCase(Locale.ROOT)
+ }
+
+ private def parseArgs(inputArgs: List[String]): Unit = {
+ var args = inputArgs
+
+ while (args.nonEmpty) {
+ args match {
+ case optName :: value :: tail if optionMatch("--data-location", optName) =>
+ dataLocation = value
+ args = tail
+
+ case optName :: value :: tail if optionMatch("--query-filter", optName) =>
+ queryFilter = value.toLowerCase(Locale.ROOT).split(",").map(_.trim).toSet
+ args = tail
+
+ case _ =>
+ // scalastyle:off println
+ System.err.println("Unknown/unsupported param " + args)
+ // scalastyle:on println
+ printUsageAndExit(1)
+ }
+ }
+ }
+
+ private def printUsageAndExit(exitCode: Int): Unit = {
+ // scalastyle:off
+ System.err.println("""
+ |Usage: spark-submit --class <this class> <spark sql test jar> [Options]
+ |Options:
+ | --data-location Path to TPCDS data
+ | --query-filter Queries to filter, e.g., q3,q5,q13
+ |
+ |------------------------------------------------------------------------------------------------------------------
+ |In order to run this benchmark, please follow the instructions at
+ |https://github.com/databricks/spark-sql-perf/blob/master/README.md
+ |to generate the TPCDS data locally (preferably with a scale factor of 5 for benchmarking).
+ |Thereafter, the value of <TPCDS data location> needs to be set to the location where the generated data is stored.
+ """.stripMargin)
+ // scalastyle:on
+ System.exit(exitCode)
+ }
+
+ private def validateArguments(): Unit = {
+ if (dataLocation == null) {
+ // scalastyle:off println
+ System.err.println("Must specify a data location")
+ // scalastyle:on println
+ printUsageAndExit(-1)
+ }
+ }
+}
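(Editor's note, not part of the patch: a minimal sketch of how the argument parser added above behaves; the sample path and query names are illustrative only.)

  // Exercises TPCDSQueryBenchmarkArguments as defined in the new file above.
  val parsed = new TPCDSQueryBenchmarkArguments(
    Array("--data-location", "/tmp/tpcds-sf5", "--query-filter", "q3,Q5 , q13"))
  assert(parsed.dataLocation == "/tmp/tpcds-sf5")
  // parseArgs lower-cases the filter value, splits it on commas and trims each entry.
  assert(parsed.queryFilter == Set("q3", "q5", "q13"))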
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala
similarity index 59%
rename from sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala
index 5643c58d9f847..fa5172ca8a3e7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala
@@ -22,19 +22,26 @@ import java.util.Locale
import scala.reflect.{classTag, ClassTag}
+import org.apache.spark.sql.{AnalysisException, SaveMode}
import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.catalog._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans
+import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
+import org.apache.spark.sql.catalyst.expressions.JsonTuple
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.Project
+import org.apache.spark.sql.catalyst.plans.logical.{Generate, InsertIntoDir, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{Project, ScriptTransformation}
import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.execution.datasources.CreateTable
import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
-// TODO: merge this with DDLSuite (SPARK-14441)
-class DDLCommandSuite extends PlanTest {
+class DDLParserSuite extends PlanTest with SharedSQLContext {
private lazy val parser = new SparkSqlParser(new SQLConf)
private def assertUnsupported(sql: String, containsThesePhrases: Seq[String] = Seq()): Unit = {
@@ -56,6 +63,17 @@ class DDLCommandSuite extends PlanTest {
}
}
+ private def compareTransformQuery(sql: String, expected: LogicalPlan): Unit = {
+ val plan = parser.parsePlan(sql).asInstanceOf[ScriptTransformation].copy(ioschema = null)
+ comparePlans(plan, expected, checkAnalysis = false)
+ }
+
+ private def extractTableDesc(sql: String): (CatalogTable, Boolean) = {
+ parser.parsePlan(sql).collect {
+ case CreateTable(tableDesc, mode, _) => (tableDesc, mode == SaveMode.Ignore)
+ }.head
+ }
+
test("create database") {
val sql =
"""
@@ -456,6 +474,26 @@ class DDLCommandSuite extends PlanTest {
}
}
+ test("create table - with table properties") {
+ val sql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet TBLPROPERTIES('test' = 'test')"
+
+ val expectedTableDesc = CatalogTable(
+ identifier = TableIdentifier("my_tab"),
+ tableType = CatalogTableType.MANAGED,
+ storage = CatalogStorageFormat.empty,
+ schema = new StructType().add("a", IntegerType).add("b", StringType),
+ provider = Some("parquet"),
+ properties = Map("test" -> "test"))
+
+ parser.parsePlan(sql) match {
+ case CreateTable(tableDesc, _, None) =>
+ assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime))
+ case other =>
+ fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," +
+ s"got ${other.getClass.getName}: $sql")
+ }
+ }
+
test("create table - with location") {
val v1 = "CREATE TABLE my_tab(a INT, b STRING) USING parquet LOCATION '/tmp/file'"
@@ -487,6 +525,55 @@ class DDLCommandSuite extends PlanTest {
assert(e.message.contains("you can only specify one of them."))
}
+ test("insert overwrite directory") {
+ val v1 = "INSERT OVERWRITE DIRECTORY '/tmp/file' USING parquet SELECT 1 as a"
+ parser.parsePlan(v1) match {
+ case InsertIntoDir(_, storage, provider, query, overwrite) =>
+ assert(storage.locationUri.isDefined && storage.locationUri.get.toString == "/tmp/file")
+ case other =>
+ fail(s"Expected to parse ${classOf[InsertIntoDataSourceDirCommand].getClass.getName}" +
+ " from query," + s" got ${other.getClass.getName}: $v1")
+ }
+
+ val v2 = "INSERT OVERWRITE DIRECTORY USING parquet SELECT 1 as a"
+ val e2 = intercept[ParseException] {
+ parser.parsePlan(v2)
+ }
+ assert(e2.message.contains(
+ "Directory path and 'path' in OPTIONS should be specified one, but not both"))
+
+ val v3 =
+ """
+ | INSERT OVERWRITE DIRECTORY USING json
+ | OPTIONS ('path' '/tmp/file', a 1, b 0.1, c TRUE)
+ | SELECT 1 as a
+ """.stripMargin
+ parser.parsePlan(v3) match {
+ case InsertIntoDir(_, storage, provider, query, overwrite) =>
+ assert(storage.locationUri.isDefined && provider == Some("json"))
+ assert(storage.properties.get("a") == Some("1"))
+ assert(storage.properties.get("b") == Some("0.1"))
+ assert(storage.properties.get("c") == Some("true"))
+ assert(!storage.properties.contains("abc"))
+ assert(!storage.properties.contains("path"))
+ case other =>
+ fail(s"Expected to parse ${classOf[InsertIntoDataSourceDirCommand].getClass.getName}" +
+ " from query," + s"got ${other.getClass.getName}: $v1")
+ }
+
+ val v4 =
+ """
+ | INSERT OVERWRITE DIRECTORY '/tmp/file' USING json
+ | OPTIONS ('path' '/tmp/file', a 1, b 0.1, c TRUE)
+ | SELECT 1 as a
+ """.stripMargin
+ val e4 = intercept[ParseException] {
+ parser.parsePlan(v4)
+ }
+ assert(e4.message.contains(
+ "Directory path and 'path' in OPTIONS should be specified one, but not both"))
+ }
+
// ALTER TABLE table_name RENAME TO new_table_name;
// ALTER VIEW view_name RENAME TO new_view_name;
test("alter table/view: rename table/view") {
@@ -1046,4 +1133,553 @@ class DDLCommandSuite extends PlanTest {
s"got ${other.getClass.getName}: $sql")
}
}
+
+ test("Test CTAS #1") {
+ val s1 =
+ """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view
+ |COMMENT 'This is the staging page view table'
+ |STORED AS RCFILE
+ |LOCATION '/user/external/page_view'
+ |TBLPROPERTIES ('p1'='v1', 'p2'='v2')
+ |AS SELECT * FROM src""".stripMargin
+
+ val (desc, exists) = extractTableDesc(s1)
+ assert(exists)
+ assert(desc.identifier.database == Some("mydb"))
+ assert(desc.identifier.table == "page_view")
+ assert(desc.tableType == CatalogTableType.EXTERNAL)
+ assert(desc.storage.locationUri == Some(new URI("/user/external/page_view")))
+ assert(desc.schema.isEmpty) // will be populated later when the table is actually created
+ assert(desc.comment == Some("This is the staging page view table"))
+ // TODO will be SQLText
+ assert(desc.viewText.isEmpty)
+ assert(desc.viewDefaultDatabase.isEmpty)
+ assert(desc.viewQueryColumnNames.isEmpty)
+ assert(desc.partitionColumnNames.isEmpty)
+ assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat"))
+ assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat"))
+ assert(desc.storage.serde ==
+ Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe"))
+ assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2"))
+ }
+
+ test("Test CTAS #2") {
+ val s2 =
+ """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view
+ |COMMENT 'This is the staging page view table'
+ |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe'
+ | STORED AS
+ | INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat'
+ | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat'
+ |LOCATION '/user/external/page_view'
+ |TBLPROPERTIES ('p1'='v1', 'p2'='v2')
+ |AS SELECT * FROM src""".stripMargin
+
+ val (desc, exists) = extractTableDesc(s2)
+ assert(exists)
+ assert(desc.identifier.database == Some("mydb"))
+ assert(desc.identifier.table == "page_view")
+ assert(desc.tableType == CatalogTableType.EXTERNAL)
+ assert(desc.storage.locationUri == Some(new URI("/user/external/page_view")))
+ assert(desc.schema.isEmpty) // will be populated later when the table is actually created
+ // TODO will be SQLText
+ assert(desc.comment == Some("This is the staging page view table"))
+ assert(desc.viewText.isEmpty)
+ assert(desc.viewDefaultDatabase.isEmpty)
+ assert(desc.viewQueryColumnNames.isEmpty)
+ assert(desc.partitionColumnNames.isEmpty)
+ assert(desc.storage.properties == Map())
+ assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat"))
+ assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat"))
+ assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe"))
+ assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2"))
+ }
+
+ test("Test CTAS #3") {
+ val s3 = """CREATE TABLE page_view AS SELECT * FROM src"""
+ val (desc, exists) = extractTableDesc(s3)
+ assert(exists == false)
+ assert(desc.identifier.database == None)
+ assert(desc.identifier.table == "page_view")
+ assert(desc.tableType == CatalogTableType.MANAGED)
+ assert(desc.storage.locationUri == None)
+ assert(desc.schema.isEmpty)
+ assert(desc.viewText == None) // TODO will be SQLText
+ assert(desc.viewDefaultDatabase.isEmpty)
+ assert(desc.viewQueryColumnNames.isEmpty)
+ assert(desc.storage.properties == Map())
+ assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat"))
+ assert(desc.storage.outputFormat ==
+ Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat"))
+ assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"))
+ assert(desc.properties == Map())
+ }
+
+ test("Test CTAS #4") {
+ val s4 =
+ """CREATE TABLE page_view
+ |STORED BY 'storage.handler.class.name' AS SELECT * FROM src""".stripMargin
+ intercept[AnalysisException] {
+ extractTableDesc(s4)
+ }
+ }
+
+ test("Test CTAS #5") {
+ val s5 = """CREATE TABLE ctas2
+ | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe"
+ | WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2")
+ | STORED AS RCFile
+ | TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22")
+ | AS
+ | SELECT key, value
+ | FROM src
+ | ORDER BY key, value""".stripMargin
+ val (desc, exists) = extractTableDesc(s5)
+ assert(exists == false)
+ assert(desc.identifier.database == None)
+ assert(desc.identifier.table == "ctas2")
+ assert(desc.tableType == CatalogTableType.MANAGED)
+ assert(desc.storage.locationUri == None)
+ assert(desc.schema.isEmpty)
+ assert(desc.viewText == None) // TODO will be SQLText
+ assert(desc.viewDefaultDatabase.isEmpty)
+ assert(desc.viewQueryColumnNames.isEmpty)
+ assert(desc.storage.properties == Map(("serde_p1" -> "p1"), ("serde_p2" -> "p2")))
+ assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat"))
+ assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat"))
+ assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe"))
+ assert(desc.properties == Map(("tbl_p1" -> "p11"), ("tbl_p2" -> "p22")))
+ }
+
+ test("CTAS statement with a PARTITIONED BY clause is not allowed") {
+ assertUnsupported(s"CREATE TABLE ctas1 PARTITIONED BY (k int)" +
+ " AS SELECT key, value FROM (SELECT 1 as key, 2 as value) tmp")
+ }
+
+ test("CTAS statement with schema") {
+ assertUnsupported(s"CREATE TABLE ctas1 (age INT, name STRING) AS SELECT * FROM src")
+ assertUnsupported(s"CREATE TABLE ctas1 (age INT, name STRING) AS SELECT 1, 'hello'")
+ }
+
+ test("unsupported operations") {
+ intercept[ParseException] {
+ parser.parsePlan(
+ """
+ |CREATE TEMPORARY TABLE ctas2
+ |ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe"
+ |WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2")
+ |STORED AS RCFile
+ |TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22")
+ |AS SELECT key, value FROM src ORDER BY key, value
+ """.stripMargin)
+ }
+ intercept[ParseException] {
+ parser.parsePlan(
+ """
+ |CREATE TABLE user_info_bucketed(user_id BIGINT, firstname STRING, lastname STRING)
+ |CLUSTERED BY(user_id) INTO 256 BUCKETS
+ |AS SELECT key, value FROM src ORDER BY key, value
+ """.stripMargin)
+ }
+ intercept[ParseException] {
+ parser.parsePlan(
+ """
+ |CREATE TABLE user_info_bucketed(user_id BIGINT, firstname STRING, lastname STRING)
+ |SKEWED BY (key) ON (1,5,6)
+ |AS SELECT key, value FROM src ORDER BY key, value
+ """.stripMargin)
+ }
+ intercept[ParseException] {
+ parser.parsePlan(
+ """
+ |SELECT TRANSFORM (key, value) USING 'cat' AS (tKey, tValue)
+ |ROW FORMAT SERDE 'org.apache.hadoop.hive.contrib.serde2.TypedBytesSerDe'
+ |RECORDREADER 'org.apache.hadoop.hive.contrib.util.typedbytes.TypedBytesRecordReader'
+ |FROM testData
+ """.stripMargin)
+ }
+ }
+
+ test("Invalid interval term should throw AnalysisException") {
+ def assertError(sql: String, errorMessage: String): Unit = {
+ val e = intercept[AnalysisException] {
+ parser.parsePlan(sql)
+ }
+ assert(e.getMessage.contains(errorMessage))
+ }
+ assertError("select interval '42-32' year to month",
+ "month 32 outside range [0, 11]")
+ assertError("select interval '5 49:12:15' day to second",
+ "hour 49 outside range [0, 23]")
+ assertError("select interval '.1111111111' second",
+ "nanosecond 1111111111 outside range")
+ }
+
+ test("use native json_tuple instead of hive's UDTF in LATERAL VIEW") {
+ val analyzer = spark.sessionState.analyzer
+ val plan = analyzer.execute(parser.parsePlan(
+ """
+ |SELECT *
+ |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test
+ |LATERAL VIEW json_tuple(json, 'f1', 'f2') jt AS a, b
+ """.stripMargin))
+
+ assert(plan.children.head.asInstanceOf[Generate].generator.isInstanceOf[JsonTuple])
+ }
+
+ test("transform query spec") {
+ val p = ScriptTransformation(
+ Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")),
+ "func", Seq.empty, plans.table("e"), null)
+
+ compareTransformQuery("select transform(a, b) using 'func' from e where f < 10",
+ p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string)))
+ compareTransformQuery("map a, b using 'func' as c, d from e",
+ p.copy(output = Seq('c.string, 'd.string)))
+ compareTransformQuery("reduce a, b using 'func' as (c int, d decimal(10, 0)) from e",
+ p.copy(output = Seq('c.int, 'd.decimal(10, 0))))
+ }
+
+ test("use backticks in output of Script Transform") {
+ parser.parsePlan(
+ """SELECT `t`.`thing1`
+ |FROM (SELECT TRANSFORM (`parquet_t1`.`key`, `parquet_t1`.`value`)
+ |USING 'cat' AS (`thing1` int, `thing2` string) FROM `default`.`parquet_t1`) AS t
+ """.stripMargin)
+ }
+
+ test("use backticks in output of Generator") {
+ parser.parsePlan(
+ """
+ |SELECT `gentab2`.`gencol2`
+ |FROM `default`.`src`
+ |LATERAL VIEW explode(array(array(1, 2, 3))) `gentab1` AS `gencol1`
+ |LATERAL VIEW explode(`gentab1`.`gencol1`) `gentab2` AS `gencol2`
+ """.stripMargin)
+ }
+
+ test("use escaped backticks in output of Generator") {
+ parser.parsePlan(
+ """
+ |SELECT `gen``tab2`.`gen``col2`
+ |FROM `default`.`src`
+ |LATERAL VIEW explode(array(array(1, 2, 3))) `gen``tab1` AS `gen``col1`
+ |LATERAL VIEW explode(`gen``tab1`.`gen``col1`) `gen``tab2` AS `gen``col2`
+ """.stripMargin)
+ }
+
+ test("create table - basic") {
+ val query = "CREATE TABLE my_table (id int, name string)"
+ val (desc, allowExisting) = extractTableDesc(query)
+ assert(!allowExisting)
+ assert(desc.identifier.database.isEmpty)
+ assert(desc.identifier.table == "my_table")
+ assert(desc.tableType == CatalogTableType.MANAGED)
+ assert(desc.schema == new StructType().add("id", "int").add("name", "string"))
+ assert(desc.partitionColumnNames.isEmpty)
+ assert(desc.bucketSpec.isEmpty)
+ assert(desc.viewText.isEmpty)
+ assert(desc.viewDefaultDatabase.isEmpty)
+ assert(desc.viewQueryColumnNames.isEmpty)
+ assert(desc.storage.locationUri.isEmpty)
+ assert(desc.storage.inputFormat ==
+ Some("org.apache.hadoop.mapred.TextInputFormat"))
+ assert(desc.storage.outputFormat ==
+ Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat"))
+ assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"))
+ assert(desc.storage.properties.isEmpty)
+ assert(desc.properties.isEmpty)
+ assert(desc.comment.isEmpty)
+ }
+
+ test("create table - with database name") {
+ val query = "CREATE TABLE dbx.my_table (id int, name string)"
+ val (desc, _) = extractTableDesc(query)
+ assert(desc.identifier.database == Some("dbx"))
+ assert(desc.identifier.table == "my_table")
+ }
+
+ test("create table - temporary") {
+ val query = "CREATE TEMPORARY TABLE tab1 (id int, name string)"
+ val e = intercept[ParseException] { parser.parsePlan(query) }
+ assert(e.message.contains("CREATE TEMPORARY TABLE is not supported yet"))
+ }
+
+ test("create table - external") {
+ val query = "CREATE EXTERNAL TABLE tab1 (id int, name string) LOCATION '/path/to/nowhere'"
+ val (desc, _) = extractTableDesc(query)
+ assert(desc.tableType == CatalogTableType.EXTERNAL)
+ assert(desc.storage.locationUri == Some(new URI("/path/to/nowhere")))
+ }
+
+ test("create table - if not exists") {
+ val query = "CREATE TABLE IF NOT EXISTS tab1 (id int, name string)"
+ val (_, allowExisting) = extractTableDesc(query)
+ assert(allowExisting)
+ }
+
+ test("create table - comment") {
+ val query = "CREATE TABLE my_table (id int, name string) COMMENT 'its hot as hell below'"
+ val (desc, _) = extractTableDesc(query)
+ assert(desc.comment == Some("its hot as hell below"))
+ }
+
+ test("create table - partitioned columns") {
+ val query = "CREATE TABLE my_table (id int, name string) PARTITIONED BY (month int)"
+ val (desc, _) = extractTableDesc(query)
+ assert(desc.schema == new StructType()
+ .add("id", "int")
+ .add("name", "string")
+ .add("month", "int"))
+ assert(desc.partitionColumnNames == Seq("month"))
+ }
+
+ test("create table - clustered by") {
+ val numBuckets = 10
+ val bucketedColumn = "id"
+ val sortColumn = "id"
+ val baseQuery =
+ s"""
+ CREATE TABLE my_table (
+ $bucketedColumn int,
+ name string)
+ CLUSTERED BY($bucketedColumn)
+ """
+
+ val query1 = s"$baseQuery INTO $numBuckets BUCKETS"
+ val (desc1, _) = extractTableDesc(query1)
+ assert(desc1.bucketSpec.isDefined)
+ val bucketSpec1 = desc1.bucketSpec.get
+ assert(bucketSpec1.numBuckets == numBuckets)
+ assert(bucketSpec1.bucketColumnNames.head.equals(bucketedColumn))
+ assert(bucketSpec1.sortColumnNames.isEmpty)
+
+ val query2 = s"$baseQuery SORTED BY($sortColumn) INTO $numBuckets BUCKETS"
+ val (desc2, _) = extractTableDesc(query2)
+ assert(desc2.bucketSpec.isDefined)
+ val bucketSpec2 = desc2.bucketSpec.get
+ assert(bucketSpec2.numBuckets == numBuckets)
+ assert(bucketSpec2.bucketColumnNames.head.equals(bucketedColumn))
+ assert(bucketSpec2.sortColumnNames.head.equals(sortColumn))
+ }
+
+ test("create table - skewed by") {
+ val baseQuery = "CREATE TABLE my_table (id int, name string) SKEWED BY"
+ val query1 = s"$baseQuery(id) ON (1, 10, 100)"
+ val query2 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z'))"
+ val query3 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z')) STORED AS DIRECTORIES"
+ val e1 = intercept[ParseException] { parser.parsePlan(query1) }
+ val e2 = intercept[ParseException] { parser.parsePlan(query2) }
+ val e3 = intercept[ParseException] { parser.parsePlan(query3) }
+ assert(e1.getMessage.contains("Operation not allowed"))
+ assert(e2.getMessage.contains("Operation not allowed"))
+ assert(e3.getMessage.contains("Operation not allowed"))
+ }
+
+ test("create table - row format") {
+ val baseQuery = "CREATE TABLE my_table (id int, name string) ROW FORMAT"
+ val query1 = s"$baseQuery SERDE 'org.apache.poof.serde.Baff'"
+ val query2 = s"$baseQuery SERDE 'org.apache.poof.serde.Baff' WITH SERDEPROPERTIES ('k1'='v1')"
+ val query3 =
+ s"""
+ |$baseQuery DELIMITED FIELDS TERMINATED BY 'x' ESCAPED BY 'y'
+ |COLLECTION ITEMS TERMINATED BY 'a'
+ |MAP KEYS TERMINATED BY 'b'
+ |LINES TERMINATED BY '\n'
+ |NULL DEFINED AS 'c'
+ """.stripMargin
+ val (desc1, _) = extractTableDesc(query1)
+ val (desc2, _) = extractTableDesc(query2)
+ val (desc3, _) = extractTableDesc(query3)
+ assert(desc1.storage.serde == Some("org.apache.poof.serde.Baff"))
+ assert(desc1.storage.properties.isEmpty)
+ assert(desc2.storage.serde == Some("org.apache.poof.serde.Baff"))
+ assert(desc2.storage.properties == Map("k1" -> "v1"))
+ assert(desc3.storage.properties == Map(
+ "field.delim" -> "x",
+ "escape.delim" -> "y",
+ "serialization.format" -> "x",
+ "line.delim" -> "\n",
+ "colelction.delim" -> "a", // yes, it's a typo from Hive :)
+ "mapkey.delim" -> "b"))
+ }
+
+ test("create table - file format") {
+ val baseQuery = "CREATE TABLE my_table (id int, name string) STORED AS"
+ val query1 = s"$baseQuery INPUTFORMAT 'winput' OUTPUTFORMAT 'wowput'"
+ val query2 = s"$baseQuery ORC"
+ val (desc1, _) = extractTableDesc(query1)
+ val (desc2, _) = extractTableDesc(query2)
+ assert(desc1.storage.inputFormat == Some("winput"))
+ assert(desc1.storage.outputFormat == Some("wowput"))
+ assert(desc1.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"))
+ assert(desc2.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"))
+ assert(desc2.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat"))
+ assert(desc2.storage.serde == Some("org.apache.hadoop.hive.ql.io.orc.OrcSerde"))
+ }
+
+ test("create table - storage handler") {
+ val baseQuery = "CREATE TABLE my_table (id int, name string) STORED BY"
+ val query1 = s"$baseQuery 'org.papachi.StorageHandler'"
+ val query2 = s"$baseQuery 'org.mamachi.StorageHandler' WITH SERDEPROPERTIES ('k1'='v1')"
+ val e1 = intercept[ParseException] { parser.parsePlan(query1) }
+ val e2 = intercept[ParseException] { parser.parsePlan(query2) }
+ assert(e1.getMessage.contains("Operation not allowed"))
+ assert(e2.getMessage.contains("Operation not allowed"))
+ }
+
+ test("create table - properties") {
+ val query = "CREATE TABLE my_table (id int, name string) TBLPROPERTIES ('k1'='v1', 'k2'='v2')"
+ val (desc, _) = extractTableDesc(query)
+ assert(desc.properties == Map("k1" -> "v1", "k2" -> "v2"))
+ }
+
+ test("create table - everything!") {
+ val query =
+ """
+ |CREATE EXTERNAL TABLE IF NOT EXISTS dbx.my_table (id int, name string)
+ |COMMENT 'no comment'
+ |PARTITIONED BY (month int)
+ |ROW FORMAT SERDE 'org.apache.poof.serde.Baff' WITH SERDEPROPERTIES ('k1'='v1')
+ |STORED AS INPUTFORMAT 'winput' OUTPUTFORMAT 'wowput'
+ |LOCATION '/path/to/mercury'
+ |TBLPROPERTIES ('k1'='v1', 'k2'='v2')
+ """.stripMargin
+ val (desc, allowExisting) = extractTableDesc(query)
+ assert(allowExisting)
+ assert(desc.identifier.database == Some("dbx"))
+ assert(desc.identifier.table == "my_table")
+ assert(desc.tableType == CatalogTableType.EXTERNAL)
+ assert(desc.schema == new StructType()
+ .add("id", "int")
+ .add("name", "string")
+ .add("month", "int"))
+ assert(desc.partitionColumnNames == Seq("month"))
+ assert(desc.bucketSpec.isEmpty)
+ assert(desc.viewText.isEmpty)
+ assert(desc.viewDefaultDatabase.isEmpty)
+ assert(desc.viewQueryColumnNames.isEmpty)
+ assert(desc.storage.locationUri == Some(new URI("/path/to/mercury")))
+ assert(desc.storage.inputFormat == Some("winput"))
+ assert(desc.storage.outputFormat == Some("wowput"))
+ assert(desc.storage.serde == Some("org.apache.poof.serde.Baff"))
+ assert(desc.storage.properties == Map("k1" -> "v1"))
+ assert(desc.properties == Map("k1" -> "v1", "k2" -> "v2"))
+ assert(desc.comment == Some("no comment"))
+ }
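Editor's note: the extractTableDesc helper used throughout these tests is defined outside this excerpt. A hypothetical sketch of what such a helper could look like is shown below, assuming the parsed plan is a CreateTable node and that SaveMode.Ignore corresponds to IF NOT EXISTS; `parser` is the suite's own parser instance.

// Hypothetical helper sketch (the suite's real helper lives elsewhere in this file):
// parse the DDL, extract the CatalogTable, and treat SaveMode.Ignore as the
// IF NOT EXISTS flag.
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.execution.datasources.CreateTable

def extractTableDescSketch(sql: String): (CatalogTable, Boolean) = {
  parser.parsePlan(sql).collect {
    case CreateTable(tableDesc, mode, _) => (tableDesc, mode == SaveMode.Ignore)
  }.head
}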
+
+ test("create view -- basic") {
+ val v1 = "CREATE VIEW view1 AS SELECT * FROM tab1"
+ val command = parser.parsePlan(v1).asInstanceOf[CreateViewCommand]
+ assert(!command.allowExisting)
+ assert(command.name.database.isEmpty)
+ assert(command.name.table == "view1")
+ assert(command.originalText == Some("SELECT * FROM tab1"))
+ assert(command.userSpecifiedColumns.isEmpty)
+ }
+
+ test("create view - full") {
+ val v1 =
+ """
+ |CREATE OR REPLACE VIEW view1
+ |(col1, col3 COMMENT 'hello')
+ |COMMENT 'BLABLA'
+ |TBLPROPERTIES('prop1Key'="prop1Val")
+ |AS SELECT * FROM tab1
+ """.stripMargin
+ val command = parser.parsePlan(v1).asInstanceOf[CreateViewCommand]
+ assert(command.name.database.isEmpty)
+ assert(command.name.table == "view1")
+ assert(command.userSpecifiedColumns == Seq("col1" -> None, "col3" -> Some("hello")))
+ assert(command.originalText == Some("SELECT * FROM tab1"))
+ assert(command.properties == Map("prop1Key" -> "prop1Val"))
+ assert(command.comment == Some("BLABLA"))
+ }
+
+ test("create view -- partitioned view") {
+ val v1 = "CREATE VIEW view1 partitioned on (ds, hr) as select * from srcpart"
+ intercept[ParseException] {
+ parser.parsePlan(v1)
+ }
+ }
+
+ test("MSCK REPAIR table") {
+ val sql = "MSCK REPAIR TABLE tab1"
+ val parsed = parser.parsePlan(sql)
+ val expected = AlterTableRecoverPartitionsCommand(
+ TableIdentifier("tab1", None),
+ "MSCK REPAIR TABLE")
+ comparePlans(parsed, expected)
+ }
+
+ test("create table like") {
+ val v1 = "CREATE TABLE table1 LIKE table2"
+ val (target, source, location, exists) = parser.parsePlan(v1).collect {
+ case CreateTableLikeCommand(t, s, l, allowExisting) => (t, s, l, allowExisting)
+ }.head
+ assert(exists == false)
+ assert(target.database.isEmpty)
+ assert(target.table == "table1")
+ assert(source.database.isEmpty)
+ assert(source.table == "table2")
+ assert(location.isEmpty)
+
+ val v2 = "CREATE TABLE IF NOT EXISTS table1 LIKE table2"
+ val (target2, source2, location2, exists2) = parser.parsePlan(v2).collect {
+ case CreateTableLikeCommand(t, s, l, allowExisting) => (t, s, l, allowExisting)
+ }.head
+ assert(exists2)
+ assert(target2.database.isEmpty)
+ assert(target2.table == "table1")
+ assert(source2.database.isEmpty)
+ assert(source2.table == "table2")
+ assert(location2.isEmpty)
+
+ val v3 = "CREATE TABLE table1 LIKE table2 LOCATION '/spark/warehouse'"
+ val (target3, source3, location3, exists3) = parser.parsePlan(v3).collect {
+ case CreateTableLikeCommand(t, s, l, allowExisting) => (t, s, l, allowExisting)
+ }.head
+ assert(!exists3)
+ assert(target3.database.isEmpty)
+ assert(target3.table == "table1")
+ assert(source3.database.isEmpty)
+ assert(source3.table == "table2")
+ assert(location3 == Some("/spark/warehouse"))
+
+ val v4 = "CREATE TABLE IF NOT EXISTS table1 LIKE table2 LOCATION '/spark/warehouse'"
+ val (target4, source4, location4, exists4) = parser.parsePlan(v4).collect {
+ case CreateTableLikeCommand(t, s, l, allowExisting) => (t, s, l, allowExisting)
+ }.head
+ assert(exists4)
+ assert(target4.database.isEmpty)
+ assert(target4.table == "table1")
+ assert(source4.database.isEmpty)
+ assert(source4.table == "table2")
+ assert(location4 == Some("/spark/warehouse"))
+ }
+
+ test("load data") {
+ val v1 = "LOAD DATA INPATH 'path' INTO TABLE table1"
+ val (table, path, isLocal, isOverwrite, partition) = parser.parsePlan(v1).collect {
+ case LoadDataCommand(t, path, l, o, partition) => (t, path, l, o, partition)
+ }.head
+ assert(table.database.isEmpty)
+ assert(table.table == "table1")
+ assert(path == "path")
+ assert(!isLocal)
+ assert(!isOverwrite)
+ assert(partition.isEmpty)
+
+ val v2 = "LOAD DATA LOCAL INPATH 'path' OVERWRITE INTO TABLE table1 PARTITION(c='1', d='2')"
+ val (table2, path2, isLocal2, isOverwrite2, partition2) = parser.parsePlan(v2).collect {
+ case LoadDataCommand(t, path, l, o, partition) => (t, path, l, o, partition)
+ }.head
+ assert(table2.database.isEmpty)
+ assert(table2.table == "table1")
+ assert(path2 == "path")
+ assert(isLocal2)
+ assert(isOverwrite2)
+ assert(partition2.nonEmpty)
+ assert(partition2.get.apply("c") == "1" && partition2.get.apply("d") == "2")
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index 9332f773430e7..d19cfeef7d19f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -783,7 +783,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
val df = (1 to 2).map { i => (i, i.toString) }.toDF("age", "name")
df.write.insertInto("students")
spark.catalog.cacheTable("students")
- assume(spark.table("students").collect().toSeq == df.collect().toSeq, "bad test: wrong data")
+ checkAnswer(spark.table("students"), df)
assume(spark.catalog.isCached("students"), "bad test: table was not cached in the first place")
sql("ALTER TABLE students RENAME TO teachers")
sql("CREATE TABLE students (age INT, name STRING) USING parquet")
@@ -792,7 +792,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
assert(!spark.catalog.isCached("students"))
assert(spark.catalog.isCached("teachers"))
assert(spark.table("students").collect().isEmpty)
- assert(spark.table("teachers").collect().toSeq == df.collect().toSeq)
+ checkAnswer(spark.table("teachers"), df)
}
test("rename temporary table - destination table with database name") {
@@ -2357,18 +2357,9 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
}.getMessage
assert(e.contains("Found duplicate column(s)"))
} else {
- if (isUsingHiveMetastore) {
- // hive catalog will still complains that c1 is duplicate column name because hive
- // identifiers are case insensitive.
- val e = intercept[AnalysisException] {
- sql("ALTER TABLE t1 ADD COLUMNS (C1 string)")
- }.getMessage
- assert(e.contains("HiveException"))
- } else {
- sql("ALTER TABLE t1 ADD COLUMNS (C1 string)")
- assert(spark.table("t1").schema
- .equals(new StructType().add("c1", IntegerType).add("C1", StringType)))
- }
+ sql("ALTER TABLE t1 ADD COLUMNS (C1 string)")
+ assert(spark.table("t1").schema ==
+ new StructType().add("c1", IntegerType).add("C1", StringType))
}
}
}
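Editor's note: replacing the collected-sequence comparisons with checkAnswer makes these assertions insensitive to row ordering. A minimal sketch of the pattern, assuming a QueryTest-based suite with an active SparkSession named spark and testImplicits in scope:

// checkAnswer (from QueryTest) normalizes row order before comparing, so it is
// preferred over comparing collect().toSeq results directly.
val expected = (1 to 2).map(i => (i, i.toString)).toDF("age", "name")
checkAnswer(spark.table("teachers"), expected)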
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
index 8c15a28c42a66..8abfd7e2be3dd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
@@ -571,7 +571,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
if (buckets > 0) {
val bucketed = df.queryExecution.analyzed transform {
- case l @ LogicalRelation(r: HadoopFsRelation, _, _) =>
+ case l @ LogicalRelation(r: HadoopFsRelation, _, _, _) =>
l.copy(relation =
r.copy(bucketSpec =
Some(BucketSpec(numBuckets = buckets, "c1" :: Nil, Nil)))(r.sparkSession))
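Editor's note: the extra wildcard in these LogicalRelation patterns (here and in the Parquet suites below) reflects a new constructor field on LogicalRelation, presumably the isStreaming flag; that is an assumption, since the LogicalRelation change itself is not part of this excerpt. A condensed sketch of the updated extractor usage:

// Sketch: ignore the added field (assumed to be the new isStreaming flag) when
// only the underlying HadoopFsRelation matters.
val maybeFsRelation = df.queryExecution.analyzed.collectFirst {
  case LogicalRelation(r: HadoopFsRelation, _, _, _) => r
}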
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 243a55cffd47f..e439699605abb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -1195,4 +1195,54 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
.csv(Seq("10u12").toDS())
checkAnswer(results, Row(null))
}
+
+ test("SPARK-20978: Fill the malformed column when the number of tokens is less than schema") {
+ val df = spark.read
+ .schema("a string, b string, unparsed string")
+ .option("columnNameOfCorruptRecord", "unparsed")
+ .csv(Seq("a").toDS())
+ checkAnswer(df, Row("a", null, "a"))
+ }
+
+ test("SPARK-21610: Corrupt records are not handled properly when creating a dataframe " +
+ "from a file") {
+ val columnNameOfCorruptRecord = "_corrupt_record"
+ val schema = new StructType()
+ .add("a", IntegerType)
+ .add("b", TimestampType)
+ .add(columnNameOfCorruptRecord, StringType)
+ // negative cases
+ val msg = intercept[AnalysisException] {
+ spark
+ .read
+ .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+ .schema(schema)
+ .csv(testFile(valueMalformedFile))
+ .select(columnNameOfCorruptRecord)
+ .collect()
+ }.getMessage
+ assert(msg.contains("only include the internal corrupt record column"))
+ intercept[org.apache.spark.sql.catalyst.errors.TreeNodeException[_]] {
+ spark
+ .read
+ .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+ .schema(schema)
+ .csv(testFile(valueMalformedFile))
+ .filter($"_corrupt_record".isNotNull)
+ .count()
+ }
+ // workaround
+ val df = spark
+ .read
+ .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+ .schema(schema)
+ .csv(testFile(valueMalformedFile))
+ .cache()
+ assert(df.filter($"_corrupt_record".isNotNull).count() == 1)
+ assert(df.filter($"_corrupt_record".isNull).count() == 1)
+ checkAnswer(
+ df.select(columnNameOfCorruptRecord),
+ Row("0,2013-111-11 12:13:14") :: Row(null) :: Nil
+ )
+ }
}
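Editor's note: the SPARK-21610 test above documents both the restriction and its workaround: queries that reference only the internal corrupt record column are rejected or unreliable, while caching the DataFrame first makes them safe. A user-facing sketch of the workaround, with the file path as a placeholder and spark.implicits._ assumed in scope:

// Workaround sketch: cache the parsed DataFrame before touching _corrupt_record.
import org.apache.spark.sql.types._

val schema = new StructType()
  .add("a", IntegerType)
  .add("b", TimestampType)
  .add("_corrupt_record", StringType)
val df = spark.read
  .option("columnNameOfCorruptRecord", "_corrupt_record")
  .schema(schema)
  .csv("/path/to/file.csv") // placeholder path
  .cache()
val corruptRows = df.filter($"_corrupt_record".isNotNull)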
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala
new file mode 100644
index 0000000000000..7d277c1ffaffe
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.jdbc
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.parser.ParseException
+import org.apache.spark.sql.types._
+
+class JdbcUtilsSuite extends SparkFunSuite {
+
+ val tableSchema = StructType(Seq(
+ StructField("C1", StringType, false), StructField("C2", IntegerType, false)))
+ val caseSensitive = org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
+ val caseInsensitive = org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
+
+ test("Parse user specified column types") {
+ assert(JdbcUtils.getCustomSchema(tableSchema, null, caseInsensitive) === tableSchema)
+ assert(JdbcUtils.getCustomSchema(tableSchema, "", caseInsensitive) === tableSchema)
+
+ assert(JdbcUtils.getCustomSchema(tableSchema, "c1 DATE", caseInsensitive) ===
+ StructType(Seq(StructField("C1", DateType, false), StructField("C2", IntegerType, false))))
+ assert(JdbcUtils.getCustomSchema(tableSchema, "c1 DATE", caseSensitive) ===
+ StructType(Seq(StructField("C1", StringType, false), StructField("C2", IntegerType, false))))
+
+ assert(
+ JdbcUtils.getCustomSchema(tableSchema, "c1 DATE, C2 STRING", caseInsensitive) ===
+ StructType(Seq(StructField("C1", DateType, false), StructField("C2", StringType, false))))
+ assert(JdbcUtils.getCustomSchema(tableSchema, "c1 DATE, C2 STRING", caseSensitive) ===
+ StructType(Seq(StructField("C1", StringType, false), StructField("C2", StringType, false))))
+
+ // Throw AnalysisException
+ val duplicate = intercept[AnalysisException]{
+ JdbcUtils.getCustomSchema(tableSchema, "c1 DATE, c1 STRING", caseInsensitive) ===
+ StructType(Seq(StructField("c1", DateType, false), StructField("c1", StringType, false)))
+ }
+ assert(duplicate.getMessage.contains(
+ "Found duplicate column(s) in the customSchema option value"))
+
+ // Throw ParseException
+ val dataTypeNotSupported = intercept[ParseException]{
+ JdbcUtils.getCustomSchema(tableSchema, "c3 DATEE, C2 STRING", caseInsensitive) ===
+ StructType(Seq(StructField("c3", DateType, false), StructField("C2", StringType, false)))
+ }
+ assert(dataTypeNotSupported.getMessage.contains("DataType datee is not supported"))
+
+ val mismatchedInput = intercept[ParseException]{
+ JdbcUtils.getCustomSchema(tableSchema, "c3 DATE. C2 STRING", caseInsensitive) ===
+ StructType(Seq(StructField("c3", DateType, false), StructField("C2", StringType, false)))
+ }
+ assert(mismatchedInput.getMessage.contains("mismatched input '.' expecting"))
+ }
+}
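Editor's note: getCustomSchema backs the JDBC reader's customSchema option, which lets users override the column types inferred from the remote catalog. A usage sketch with placeholder connection details:

// Sketch: override the inferred types of selected columns when reading over JDBC.
// URL, table, and column names are placeholders.
val jdbcDF = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://dbhost:5432/mydb")
  .option("dbtable", "public.people")
  .option("customSchema", "id DECIMAL(38, 0), name STRING")
  .load()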
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala
index 6e2b4f0df595f..316c5183fddf1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala
@@ -72,6 +72,21 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext {
assert(df.first().getString(0) == "Reynold Xin")
}
+ test("allowUnquotedControlChars off") {
+ val str = """{"name": "a\u0001b"}"""
+ val df = spark.read.json(Seq(str).toDS())
+
+ assert(df.schema.head.name == "_corrupt_record")
+ }
+
+ test("allowUnquotedControlChars on") {
+ val str = """{"name": "a\u0001b"}"""
+ val df = spark.read.option("allowUnquotedControlChars", "true").json(Seq(str).toDS())
+
+ assert(df.schema.head.name == "name")
+ assert(df.first().getString(0) == "a\u0001b")
+ }
+
test("allowNumericLeadingZeros off") {
val str = """{"age": 0018}"""
val df = spark.read.json(Seq(str).toDS())
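Editor's note: the two new tests pin down the default and opt-in behavior of allowUnquotedControlChars, which presumably forwards to Jackson's ALLOW_UNQUOTED_CONTROL_CHARS parser feature (an assumption based on the option name). A reader-side sketch, with spark.implicits._ assumed in scope:

// Sketch: with the option off (default), a raw control character such as \u0001
// inside a string value sends the record to _corrupt_record; with it on, the
// value parses normally.
val input = Seq("{\"name\": \"a\u0001b\"}").toDS()
val strict = spark.read.json(input) // schema has only _corrupt_record
val lenient = spark.read.option("allowUnquotedControlChars", "true").json(input) // schema has name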
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 0008954e36bdd..8c8d41ebf115a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -2034,4 +2034,33 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
}
}
+
+ test("SPARK-21610: Corrupt records are not handled properly when creating a dataframe " +
+ "from a file") {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+ val data =
+ """{"field": 1}
+ |{"field": 2}
+ |{"field": "3"}""".stripMargin
+ Seq(data).toDF().repartition(1).write.text(path)
+ val schema = new StructType().add("field", ByteType).add("_corrupt_record", StringType)
+ // negative cases
+ val msg = intercept[AnalysisException] {
+ spark.read.schema(schema).json(path).select("_corrupt_record").collect()
+ }.getMessage
+ assert(msg.contains("only include the internal corrupt record column"))
+ intercept[catalyst.errors.TreeNodeException[_]] {
+ spark.read.schema(schema).json(path).filter($"_corrupt_record".isNotNull).count()
+ }
+ // workaround
+ val df = spark.read.schema(schema).json(path).cache()
+ assert(df.filter($"_corrupt_record".isNotNull).count() == 1)
+ assert(df.filter($"_corrupt_record".isNull).count() == 2)
+ checkAnswer(
+ df.select("_corrupt_record"),
+ Row(null) :: Row(null) :: Row("{\"field\": \"3\"}") :: Nil
+ )
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index d353bf39530fd..17628f4c830d6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -66,7 +66,8 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
var maybeRelation: Option[HadoopFsRelation] = None
val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect {
- case PhysicalOperation(_, filters, LogicalRelation(relation: HadoopFsRelation, _, _)) =>
+ case PhysicalOperation(_, filters,
+ LogicalRelation(relation: HadoopFsRelation, _, _, _)) =>
maybeRelation = Some(relation)
filters
}.flatten.reduceLeftOption(_ && _)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
index 0522207653339..3ee16b42ebb39 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
@@ -651,7 +651,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
val queryExecution = spark.read.parquet(dir.getCanonicalPath).queryExecution
queryExecution.analyzed.collectFirst {
case LogicalRelation(
- HadoopFsRelation(location: PartitioningAwareFileIndex, _, _, _, _, _), _, _) =>
+ HadoopFsRelation(location: PartitioningAwareFileIndex, _, _, _, _, _), _, _, _) =>
assert(location.partitionSpec() === PartitionSpec.emptySpec)
}.getOrElse {
fail(s"Expecting a matching HadoopFsRelation, but got:\n$queryExecution")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index fd793233b0bc1..0dc612ef735fa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -19,98 +19,20 @@ package org.apache.spark.sql.execution.metric
import java.io.File
-import scala.collection.mutable.HashMap
import scala.util.Random
import org.apache.spark.SparkFunSuite
-import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.execution.SparkPlanInfo
-import org.apache.spark.sql.execution.ui.SparkPlanGraph
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.{AccumulatorContext, JsonProtocol}
-class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
+class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with SharedSQLContext {
import testImplicits._
- /**
- * Call `df.collect()` and collect necessary metrics from execution data.
- *
- * @param df `DataFrame` to run
- * @param expectedNumOfJobs number of jobs that will run
- * @param expectedNodeIds the node ids of the metrics to collect from execution data.
- */
- private def getSparkPlanMetrics(
- df: DataFrame,
- expectedNumOfJobs: Int,
- expectedNodeIds: Set[Long],
- enableWholeStage: Boolean = false): Option[Map[Long, (String, Map[String, Any])]] = {
- val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet
- withSQLConf("spark.sql.codegen.wholeStage" -> enableWholeStage.toString) {
- df.collect()
- }
- sparkContext.listenerBus.waitUntilEmpty(10000)
- val executionIds =
- spark.sharedState.listener.executionIdToData.keySet.diff(previousExecutionIds)
- assert(executionIds.size === 1)
- val executionId = executionIds.head
- val jobs = spark.sharedState.listener.getExecution(executionId).get.jobs
- // Use "<=" because there is a race condition that we may miss some jobs
- // TODO Change it to "=" once we fix the race condition that missing the JobStarted event.
- assert(jobs.size <= expectedNumOfJobs)
- if (jobs.size == expectedNumOfJobs) {
- // If we can track all jobs, check the metric values
- val metricValues = spark.sharedState.listener.getExecutionMetrics(executionId)
- val metrics = SparkPlanGraph(SparkPlanInfo.fromSparkPlan(
- df.queryExecution.executedPlan)).allNodes.filter { node =>
- expectedNodeIds.contains(node.id)
- }.map { node =>
- val nodeMetrics = node.metrics.map { metric =>
- val metricValue = metricValues(metric.accumulatorId)
- (metric.name, metricValue)
- }.toMap
- (node.id, node.name -> nodeMetrics)
- }.toMap
- Some(metrics)
- } else {
- // TODO Remove this "else" once we fix the race condition that missing the JobStarted event.
- // Since we cannot track all jobs, the metric values could be wrong and we should not check
- // them.
- logWarning("Due to a race condition, we miss some jobs and cannot verify the metric values")
- None
- }
- }
-
- /**
- * Call `df.collect()` and verify if the collected metrics are same as "expectedMetrics".
- *
- * @param df `DataFrame` to run
- * @param expectedNumOfJobs number of jobs that will run
- * @param expectedMetrics the expected metrics. The format is
- * `nodeId -> (operatorName, metric name -> metric value)`.
- */
- private def testSparkPlanMetrics(
- df: DataFrame,
- expectedNumOfJobs: Int,
- expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = {
- val optActualMetrics = getSparkPlanMetrics(df, expectedNumOfJobs, expectedMetrics.keySet)
- optActualMetrics.map { actualMetrics =>
- assert(expectedMetrics.keySet === actualMetrics.keySet)
- for (nodeId <- expectedMetrics.keySet) {
- val (expectedNodeName, expectedMetricsMap) = expectedMetrics(nodeId)
- val (actualNodeName, actualMetricsMap) = actualMetrics(nodeId)
- assert(expectedNodeName === actualNodeName)
- for (metricName <- expectedMetricsMap.keySet) {
- assert(expectedMetricsMap(metricName).toString === actualMetricsMap(metricName))
- }
- }
- }
- }
-
/**
* Generates a `DataFrame` by filling randomly generated bytes for hash collision.
*/
@@ -570,75 +492,12 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
assert(res3 === (10L, 0L, 10L) :: (30L, 0L, 30L) :: (0L, 30L, 300L) :: (0L, 300L, 0L) :: Nil)
}
}
-}
-
-object InputOutputMetricsHelper {
- private class InputOutputMetricsListener extends SparkListener {
- private case class MetricsResult(
- var recordsRead: Long = 0L,
- var shuffleRecordsRead: Long = 0L,
- var sumMaxOutputRows: Long = 0L)
- private[this] val stageIdToMetricsResult = HashMap.empty[Int, MetricsResult]
-
- def reset(): Unit = {
- stageIdToMetricsResult.clear()
- }
-
- /**
- * Return a list of recorded metrics aggregated per stage.
- *
- * The list is sorted in the ascending order on the stageId.
- * For each recorded stage, the following tuple is returned:
- * - sum of inputMetrics.recordsRead for all the tasks in the stage
- * - sum of shuffleReadMetrics.recordsRead for all the tasks in the stage
- * - sum of the highest values of "number of output rows" metric for all the tasks in the stage
- */
- def getResults(): List[(Long, Long, Long)] = {
- stageIdToMetricsResult.keySet.toList.sorted.map { stageId =>
- val res = stageIdToMetricsResult(stageId)
- (res.recordsRead, res.shuffleRecordsRead, res.sumMaxOutputRows)
- }
- }
-
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
- val res = stageIdToMetricsResult.getOrElseUpdate(taskEnd.stageId, MetricsResult())
-
- res.recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead
- res.shuffleRecordsRead += taskEnd.taskMetrics.shuffleReadMetrics.recordsRead
-
- var maxOutputRows = 0L
- for (accum <- taskEnd.taskMetrics.externalAccums) {
- val info = accum.toInfo(Some(accum.value), None)
- if (info.name.toString.contains("number of output rows")) {
- info.update match {
- case Some(n: Number) =>
- if (n.longValue() > maxOutputRows) {
- maxOutputRows = n.longValue()
- }
- case _ => // Ignore.
- }
- }
- }
- res.sumMaxOutputRows += maxOutputRows
- }
+ test("writing data out metrics: parquet") {
+ testMetricsNonDynamicPartition("parquet", "t1")
}
- // Run df.collect() and return aggregated metrics for each stage.
- def run(df: DataFrame): List[(Long, Long, Long)] = {
- val spark = df.sparkSession
- val sparkContext = spark.sparkContext
- val listener = new InputOutputMetricsListener()
- sparkContext.addSparkListener(listener)
-
- try {
- sparkContext.listenerBus.waitUntilEmpty(5000)
- listener.reset()
- df.collect()
- sparkContext.listenerBus.waitUntilEmpty(5000)
- } finally {
- sparkContext.removeSparkListener(listener)
- }
- listener.getResults()
+ test("writing data out metrics with dynamic partition: parquet") {
+ testMetricsDynamicPartition("parquet", "parquet", "t1")
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala
new file mode 100644
index 0000000000000..3966e98c1ce06
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsTestUtils.scala
@@ -0,0 +1,270 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.metric
+
+import java.io.File
+
+import scala.collection.mutable.HashMap
+
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.execution.SparkPlanInfo
+import org.apache.spark.sql.execution.ui.SparkPlanGraph
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.util.Utils
+
+
+trait SQLMetricsTestUtils extends SQLTestUtils {
+
+ import testImplicits._
+
+ /**
+ * Get the execution metrics for the SQL execution and verify the metric values.
+ *
+ * @param metricsValues the expected metric values (numFiles, numPartitions, numOutputRows).
+ * @param func a function that, when run, triggers the write whose execution metrics are checked.
+ */
+ private def verifyWriteDataMetrics(metricsValues: Seq[Int])(func: => Unit): Unit = {
+ val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet
+ // Run the given function to trigger query execution.
+ func
+ spark.sparkContext.listenerBus.waitUntilEmpty(10000)
+ val executionIds =
+ spark.sharedState.listener.executionIdToData.keySet.diff(previousExecutionIds)
+ assert(executionIds.size == 1)
+ val executionId = executionIds.head
+
+ val executionData = spark.sharedState.listener.getExecution(executionId).get
+ val executedNode = executionData.physicalPlanGraph.nodes.head
+
+ val metricsNames = Seq(
+ "number of written files",
+ "number of dynamic part",
+ "number of output rows")
+
+ val metrics = spark.sharedState.listener.getExecutionMetrics(executionId)
+
+ metricsNames.zip(metricsValues).foreach { case (metricsName, expected) =>
+ val sqlMetric = executedNode.metrics.find(_.name == metricsName)
+ assert(sqlMetric.isDefined)
+ val accumulatorId = sqlMetric.get.accumulatorId
+ val metricValue = metrics(accumulatorId).replaceAll(",", "").toInt
+ assert(metricValue == expected)
+ }
+
+ val totalNumBytesMetric = executedNode.metrics.find(_.name == "bytes of written output").get
+ val totalNumBytes = metrics(totalNumBytesMetric.accumulatorId).replaceAll(",", "").toInt
+ assert(totalNumBytes > 0)
+ }
+
+ protected def testMetricsNonDynamicPartition(
+ dataFormat: String,
+ tableName: String): Unit = {
+ withTable(tableName) {
+ Seq((1, 2)).toDF("i", "j")
+ .write.format(dataFormat).mode("overwrite").saveAsTable(tableName)
+
+ val tableLocation =
+ new File(spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).location)
+
+ // 2 files, 100 rows, 0 dynamic partition.
+ verifyWriteDataMetrics(Seq(2, 0, 100)) {
+ (0 until 100).map(i => (i, i + 1)).toDF("i", "j").repartition(2)
+ .write.format(dataFormat).mode("overwrite").insertInto(tableName)
+ }
+ assert(Utils.recursiveList(tableLocation).count(_.getName.startsWith("part-")) == 2)
+ }
+ }
+
+ protected def testMetricsDynamicPartition(
+ provider: String,
+ dataFormat: String,
+ tableName: String): Unit = {
+ withTempPath { dir =>
+ spark.sql(
+ s"""
+ |CREATE TABLE $tableName(a int, b int)
+ |USING $provider
+ |PARTITIONED BY(a)
+ |LOCATION '${dir.toURI}'
+ """.stripMargin)
+ val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName))
+ assert(table.location == makeQualifiedPath(dir.getAbsolutePath))
+
+ val df = spark.range(start = 0, end = 40, step = 1, numPartitions = 1)
+ .selectExpr("id a", "id b")
+
+ // 40 files, 80 rows, 40 dynamic partitions.
+ verifyWriteDataMetrics(Seq(40, 40, 80)) {
+ df.union(df).repartition(2, $"a")
+ .write
+ .format(dataFormat)
+ .mode("overwrite")
+ .insertInto(tableName)
+ }
+ assert(Utils.recursiveList(dir).count(_.getName.startsWith("part-")) == 40)
+ }
+ }
+
+ /**
+ * Call `df.collect()` and collect necessary metrics from execution data.
+ *
+ * @param df `DataFrame` to run
+ * @param expectedNumOfJobs number of jobs that will run
+ * @param expectedNodeIds the node ids of the metrics to collect from execution data.
+ */
+ protected def getSparkPlanMetrics(
+ df: DataFrame,
+ expectedNumOfJobs: Int,
+ expectedNodeIds: Set[Long],
+ enableWholeStage: Boolean = false): Option[Map[Long, (String, Map[String, Any])]] = {
+ val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet
+ withSQLConf("spark.sql.codegen.wholeStage" -> enableWholeStage.toString) {
+ df.collect()
+ }
+ sparkContext.listenerBus.waitUntilEmpty(10000)
+ val executionIds =
+ spark.sharedState.listener.executionIdToData.keySet.diff(previousExecutionIds)
+ assert(executionIds.size === 1)
+ val executionId = executionIds.head
+ val jobs = spark.sharedState.listener.getExecution(executionId).get.jobs
+ // Use "<=" because there is a race condition in which we may miss some jobs.
+ // TODO: change it to "=" once we fix the race condition that drops the JobStarted event.
+ assert(jobs.size <= expectedNumOfJobs)
+ if (jobs.size == expectedNumOfJobs) {
+ // If we can track all jobs, check the metric values
+ val metricValues = spark.sharedState.listener.getExecutionMetrics(executionId)
+ val metrics = SparkPlanGraph(SparkPlanInfo.fromSparkPlan(
+ df.queryExecution.executedPlan)).allNodes.filter { node =>
+ expectedNodeIds.contains(node.id)
+ }.map { node =>
+ val nodeMetrics = node.metrics.map { metric =>
+ val metricValue = metricValues(metric.accumulatorId)
+ (metric.name, metricValue)
+ }.toMap
+ (node.id, node.name -> nodeMetrics)
+ }.toMap
+ Some(metrics)
+ } else {
+ // TODO: remove this "else" once we fix the race condition that drops the JobStarted event.
+ // Since we cannot track all jobs, the metric values could be wrong and we should not check
+ // them.
+ logWarning("Due to a race condition, we miss some jobs and cannot verify the metric values")
+ None
+ }
+ }
+
+ /**
+ * Call `df.collect()` and verify that the collected metrics are the same as "expectedMetrics".
+ *
+ * @param df `DataFrame` to run
+ * @param expectedNumOfJobs number of jobs that will run
+ * @param expectedMetrics the expected metrics. The format is
+ * `nodeId -> (operatorName, metric name -> metric value)`.
+ */
+ protected def testSparkPlanMetrics(
+ df: DataFrame,
+ expectedNumOfJobs: Int,
+ expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = {
+ val optActualMetrics = getSparkPlanMetrics(df, expectedNumOfJobs, expectedMetrics.keySet)
+ optActualMetrics.foreach { actualMetrics =>
+ assert(expectedMetrics.keySet === actualMetrics.keySet)
+ for (nodeId <- expectedMetrics.keySet) {
+ val (expectedNodeName, expectedMetricsMap) = expectedMetrics(nodeId)
+ val (actualNodeName, actualMetricsMap) = actualMetrics(nodeId)
+ assert(expectedNodeName === actualNodeName)
+ for (metricName <- expectedMetricsMap.keySet) {
+ assert(expectedMetricsMap(metricName).toString === actualMetricsMap(metricName))
+ }
+ }
+ }
+ }
+}
+
+
+object InputOutputMetricsHelper {
+ private class InputOutputMetricsListener extends SparkListener {
+ private case class MetricsResult(
+ var recordsRead: Long = 0L,
+ var shuffleRecordsRead: Long = 0L,
+ var sumMaxOutputRows: Long = 0L)
+
+ private[this] val stageIdToMetricsResult = HashMap.empty[Int, MetricsResult]
+
+ def reset(): Unit = {
+ stageIdToMetricsResult.clear()
+ }
+
+ /**
+ * Return a list of recorded metrics aggregated per stage.
+ *
+ * The list is sorted in ascending order of stageId.
+ * For each recorded stage, the following tuple is returned:
+ * - sum of inputMetrics.recordsRead for all the tasks in the stage
+ * - sum of shuffleReadMetrics.recordsRead for all the tasks in the stage
+ * - sum of the highest values of "number of output rows" metric for all the tasks in the stage
+ */
+ def getResults(): List[(Long, Long, Long)] = {
+ stageIdToMetricsResult.keySet.toList.sorted.map { stageId =>
+ val res = stageIdToMetricsResult(stageId)
+ (res.recordsRead, res.shuffleRecordsRead, res.sumMaxOutputRows)
+ }
+ }
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
+ val res = stageIdToMetricsResult.getOrElseUpdate(taskEnd.stageId, MetricsResult())
+
+ res.recordsRead += taskEnd.taskMetrics.inputMetrics.recordsRead
+ res.shuffleRecordsRead += taskEnd.taskMetrics.shuffleReadMetrics.recordsRead
+
+ var maxOutputRows = 0L
+ for (accum <- taskEnd.taskMetrics.externalAccums) {
+ val info = accum.toInfo(Some(accum.value), None)
+ if (info.name.toString.contains("number of output rows")) {
+ info.update match {
+ case Some(n: Number) =>
+ if (n.longValue() > maxOutputRows) {
+ maxOutputRows = n.longValue()
+ }
+ case _ => // Ignore.
+ }
+ }
+ }
+ res.sumMaxOutputRows += maxOutputRows
+ }
+ }
+
+ // Run df.collect() and return aggregated metrics for each stage.
+ def run(df: DataFrame): List[(Long, Long, Long)] = {
+ val spark = df.sparkSession
+ val sparkContext = spark.sparkContext
+ val listener = new InputOutputMetricsListener()
+ sparkContext.addSparkListener(listener)
+
+ try {
+ sparkContext.listenerBus.waitUntilEmpty(5000)
+ listener.reset()
+ df.collect()
+ sparkContext.listenerBus.waitUntilEmpty(5000)
+ } finally {
+ sparkContext.removeSparkListener(listener)
+ }
+ listener.getResults()
+ }
+}
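Editor's note: a short usage sketch of the new utilities, with suite and format names purely illustrative. Format-specific suites mix in SQLMetricsTestUtils (as SQLMetricsSuite now does above), and InputOutputMetricsHelper.run aggregates per-stage record counts for a DataFrame action.

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext

// Illustrative mixin, mirroring the SQLMetricsSuite changes in this patch:
class ParquetWriteMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with SharedSQLContext {
  test("writing data out metrics: parquet") {
    testMetricsNonDynamicPartition("parquet", "t1")
  }
}

// Per-stage (recordsRead, shuffleRecordsRead, sumMaxOutputRows) totals for some DataFrame df:
// val perStage: List[(Long, Long, Long)] = InputOutputMetricsHelper.run(df)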
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala
index 007554a83f548..519e3c01afe8a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable
import org.eclipse.jetty.util.ConcurrentHashSet
import org.scalatest.concurrent.Eventually
import org.scalatest.concurrent.PatienceConfiguration.Timeout
-import org.scalatest.concurrent.Timeouts._
+import org.scalatest.concurrent.TimeLimits._
import org.scalatest.time.SpanSugar._
import org.apache.spark.SparkFunSuite
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala
index 9ebf4d2835266..ec11549073650 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/TextSocketStreamSuite.scala
@@ -65,20 +65,22 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
while (source.getOffset.isEmpty) {
Thread.sleep(10)
}
- val offset1 = source.getOffset.get
- val batch1 = source.getBatch(None, offset1)
- assert(batch1.as[String].collect().toSeq === Seq("hello"))
+ withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") {
+ val offset1 = source.getOffset.get
+ val batch1 = source.getBatch(None, offset1)
+ assert(batch1.as[String].collect().toSeq === Seq("hello"))
+
+ serverThread.enqueue("world")
+ while (source.getOffset.get === offset1) {
+ Thread.sleep(10)
+ }
+ val offset2 = source.getOffset.get
+ val batch2 = source.getBatch(Some(offset1), offset2)
+ assert(batch2.as[String].collect().toSeq === Seq("world"))
- serverThread.enqueue("world")
- while (source.getOffset.get === offset1) {
- Thread.sleep(10)
+ val both = source.getBatch(None, offset2)
+ assert(both.as[String].collect().sorted.toSeq === Seq("hello", "world"))
}
- val offset2 = source.getOffset.get
- val batch2 = source.getBatch(Some(offset1), offset2)
- assert(batch2.as[String].collect().toSeq === Seq("world"))
-
- val both = source.getBatch(None, offset2)
- assert(both.as[String].collect().sorted.toSeq === Seq("hello", "world"))
// Try stopping the source to make sure this does not block forever.
source.stop()
@@ -104,22 +106,24 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
while (source.getOffset.isEmpty) {
Thread.sleep(10)
}
- val offset1 = source.getOffset.get
- val batch1 = source.getBatch(None, offset1)
- val batch1Seq = batch1.as[(String, Timestamp)].collect().toSeq
- assert(batch1Seq.map(_._1) === Seq("hello"))
- val batch1Stamp = batch1Seq(0)._2
-
- serverThread.enqueue("world")
- while (source.getOffset.get === offset1) {
- Thread.sleep(10)
+ withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") {
+ val offset1 = source.getOffset.get
+ val batch1 = source.getBatch(None, offset1)
+ val batch1Seq = batch1.as[(String, Timestamp)].collect().toSeq
+ assert(batch1Seq.map(_._1) === Seq("hello"))
+ val batch1Stamp = batch1Seq(0)._2
+
+ serverThread.enqueue("world")
+ while (source.getOffset.get === offset1) {
+ Thread.sleep(10)
+ }
+ val offset2 = source.getOffset.get
+ val batch2 = source.getBatch(Some(offset1), offset2)
+ val batch2Seq = batch2.as[(String, Timestamp)].collect().toSeq
+ assert(batch2Seq.map(_._1) === Seq("world"))
+ val batch2Stamp = batch2Seq(0)._2
+ assert(!batch2Stamp.before(batch1Stamp))
}
- val offset2 = source.getOffset.get
- val batch2 = source.getBatch(Some(offset1), offset2)
- val batch2Seq = batch2.as[(String, Timestamp)].collect().toSeq
- assert(batch2Seq.map(_._1) === Seq("world"))
- val batch2Stamp = batch2Seq(0)._2
- assert(!batch2Stamp.before(batch1Stamp))
// Try stopping the source to make sure this does not block forever.
source.stop()
@@ -184,12 +188,14 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
while (source.getOffset.isEmpty) {
Thread.sleep(10)
}
- val batch = source.getBatch(None, source.getOffset.get).as[String]
- batch.collect()
- val numRowsMetric =
- batch.queryExecution.executedPlan.collectLeaves().head.metrics.get("numOutputRows")
- assert(numRowsMetric.nonEmpty)
- assert(numRowsMetric.get.value === 1)
+ withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") {
+ val batch = source.getBatch(None, source.getOffset.get).as[String]
+ batch.collect()
+ val numRowsMetric =
+ batch.queryExecution.executedPlan.collectLeaves().head.metrics.get("numOutputRows")
+ assert(numRowsMetric.nonEmpty)
+ assert(numRowsMetric.get.value === 1)
+ }
source.stop()
source = null
}
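Editor's note: the common thread in these three hunks is wrapping direct getBatch access in withSQLConf to disable spark.sql.streaming.unsupportedOperationCheck, since the socket source's batches are collected here outside of a running StreamingQuery. A condensed sketch of the pattern, assuming a StreamTest/SQLTestUtils suite with a `source` that has already received data:

withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") {
  val offset = source.getOffset.get
  val batch = source.getBatch(None, offset)
  assert(batch.as[String].collect().toSeq === Seq("hello"))
}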
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala
index 67b3d98c1daed..1331f157363b0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala
@@ -24,7 +24,10 @@ import scala.util.Random
import org.apache.spark.memory.MemoryMode
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.vectorized.ColumnVector
-import org.apache.spark.sql.types.{BinaryType, IntegerType}
+import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector
+import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
+import org.apache.spark.sql.execution.vectorized.WritableColumnVector
+import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType}
import org.apache.spark.unsafe.Platform
import org.apache.spark.util.Benchmark
import org.apache.spark.util.collection.BitSet
@@ -34,6 +37,14 @@ import org.apache.spark.util.collection.BitSet
*/
object ColumnarBatchBenchmark {
+ def allocate(capacity: Int, dt: DataType, memMode: MemoryMode): WritableColumnVector = {
+ if (memMode == MemoryMode.OFF_HEAP) {
+ new OffHeapColumnVector(capacity, dt)
+ } else {
+ new OnHeapColumnVector(capacity, dt)
+ }
+ }
+
// This benchmark reads and writes an array of ints.
// TODO: there is a big (2x) penalty for a random access API for off heap.
// Note: carefully if modifying this code. It's hard to reason about the JIT.
@@ -140,7 +151,7 @@ object ColumnarBatchBenchmark {
// Access through the column API with on heap memory
val columnOnHeap = { i: Int =>
- val col = ColumnVector.allocate(count, IntegerType, MemoryMode.ON_HEAP)
+ val col = allocate(count, IntegerType, MemoryMode.ON_HEAP)
var sum = 0L
for (n <- 0L until iters) {
var i = 0
@@ -159,7 +170,7 @@ object ColumnarBatchBenchmark {
// Access through the column API with off heap memory
def columnOffHeap = { i: Int => {
- val col = ColumnVector.allocate(count, IntegerType, MemoryMode.OFF_HEAP)
+ val col = allocate(count, IntegerType, MemoryMode.OFF_HEAP)
var sum = 0L
for (n <- 0L until iters) {
var i = 0
@@ -178,7 +189,7 @@ object ColumnarBatchBenchmark {
// Access by directly getting the buffer backing the column.
val columnOffheapDirect = { i: Int =>
- val col = ColumnVector.allocate(count, IntegerType, MemoryMode.OFF_HEAP)
+ val col = allocate(count, IntegerType, MemoryMode.OFF_HEAP)
var sum = 0L
for (n <- 0L until iters) {
var addr = col.valuesNativeAddress()
@@ -244,7 +255,7 @@ object ColumnarBatchBenchmark {
// Adding values by appending, instead of putting.
val onHeapAppend = { i: Int =>
- val col = ColumnVector.allocate(count, IntegerType, MemoryMode.ON_HEAP)
+ val col = allocate(count, IntegerType, MemoryMode.ON_HEAP)
var sum = 0L
for (n <- 0L until iters) {
var i = 0
@@ -362,7 +373,7 @@ object ColumnarBatchBenchmark {
.map(_.getBytes(StandardCharsets.UTF_8)).toArray
def column(memoryMode: MemoryMode) = { i: Int =>
- val column = ColumnVector.allocate(count, BinaryType, memoryMode)
+ val column = allocate(count, BinaryType, memoryMode)
var sum = 0L
for (n <- 0L until iters) {
var i = 0
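Editor's note: both this benchmark and ColumnarBatchSuite below gain a local allocate helper, apparently because the ColumnVector.allocate factory was removed in favor of constructing OnHeapColumnVector/OffHeapColumnVector directly. A usage sketch of the helper defined above (close() is assumed to release the vector's memory):

// Usage sketch: choose the implementation from the memory mode, write and read
// a value, then release any off-heap memory.
val col = allocate(16, IntegerType, MemoryMode.OFF_HEAP)
col.putInt(0, 42)
assert(col.getInt(0) == 42)
col.close()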
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index c8461dcb9dfdb..ebf76613343ba 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -25,20 +25,32 @@ import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.Random
+import org.apache.arrow.vector.NullableIntVector
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.memory.MemoryMode
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.arrow.ArrowUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.types.CalendarInterval
class ColumnarBatchSuite extends SparkFunSuite {
+
+ def allocate(capacity: Int, dt: DataType, memMode: MemoryMode): WritableColumnVector = {
+ if (memMode == MemoryMode.OFF_HEAP) {
+ new OffHeapColumnVector(capacity, dt)
+ } else {
+ new OnHeapColumnVector(capacity, dt)
+ }
+ }
+
test("Null Apis") {
(MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => {
val reference = mutable.ArrayBuffer.empty[Boolean]
- val column = ColumnVector.allocate(1024, IntegerType, memMode)
+ val column = allocate(1024, IntegerType, memMode)
var idx = 0
assert(column.anyNullsSet() == false)
assert(column.numNulls() == 0)
@@ -109,7 +121,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
(MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => {
val reference = mutable.ArrayBuffer.empty[Byte]
- val column = ColumnVector.allocate(1024, ByteType, memMode)
+ val column = allocate(1024, ByteType, memMode)
var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).map(_.toByte).toArray
column.appendBytes(2, values, 0)
@@ -167,7 +179,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
val random = new Random(seed)
val reference = mutable.ArrayBuffer.empty[Short]
- val column = ColumnVector.allocate(1024, ShortType, memMode)
+ val column = allocate(1024, ShortType, memMode)
var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).map(_.toShort).toArray
column.appendShorts(2, values, 0)
@@ -247,7 +259,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
val random = new Random(seed)
val reference = mutable.ArrayBuffer.empty[Int]
- val column = ColumnVector.allocate(1024, IntegerType, memMode)
+ val column = allocate(1024, IntegerType, memMode)
var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).toArray
column.appendInts(2, values, 0)
@@ -332,7 +344,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
val random = new Random(seed)
val reference = mutable.ArrayBuffer.empty[Long]
- val column = ColumnVector.allocate(1024, LongType, memMode)
+ val column = allocate(1024, LongType, memMode)
var values = (10L :: 20L :: 30L :: 40L :: 50L :: Nil).toArray
column.appendLongs(2, values, 0)
@@ -419,7 +431,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
val random = new Random(seed)
val reference = mutable.ArrayBuffer.empty[Float]
- val column = ColumnVector.allocate(1024, FloatType, memMode)
+ val column = allocate(1024, FloatType, memMode)
var values = (.1f :: .2f :: .3f :: .4f :: .5f :: Nil).toArray
column.appendFloats(2, values, 0)
@@ -510,7 +522,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
val random = new Random(seed)
val reference = mutable.ArrayBuffer.empty[Double]
- val column = ColumnVector.allocate(1024, DoubleType, memMode)
+ val column = allocate(1024, DoubleType, memMode)
var values = (.1 :: .2 :: .3 :: .4 :: .5 :: Nil).toArray
column.appendDoubles(2, values, 0)
@@ -599,7 +611,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
(MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => {
val reference = mutable.ArrayBuffer.empty[String]
- val column = ColumnVector.allocate(6, BinaryType, memMode)
+ val column = allocate(6, BinaryType, memMode)
assert(column.arrayData().elementsAppended == 0)
val str = "string"
@@ -656,7 +668,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
test("Int Array") {
(MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => {
- val column = ColumnVector.allocate(10, new ArrayType(IntegerType, true), memMode)
+ val column = allocate(10, new ArrayType(IntegerType, true), memMode)
// Fill the underlying data with all the arrays back to back.
val data = column.arrayData();
@@ -714,43 +726,43 @@ class ColumnarBatchSuite extends SparkFunSuite {
(MemoryMode.ON_HEAP :: Nil).foreach { memMode => {
val len = 4
- val columnBool = ColumnVector.allocate(len, new ArrayType(BooleanType, false), memMode)
+ val columnBool = allocate(len, new ArrayType(BooleanType, false), memMode)
val boolArray = Array(false, true, false, true)
boolArray.zipWithIndex.map { case (v, i) => columnBool.arrayData.putBoolean(i, v) }
columnBool.putArray(0, 0, len)
assert(columnBool.getArray(0).toBooleanArray === boolArray)
- val columnByte = ColumnVector.allocate(len, new ArrayType(ByteType, false), memMode)
+ val columnByte = allocate(len, new ArrayType(ByteType, false), memMode)
val byteArray = Array[Byte](0, 1, 2, 3)
byteArray.zipWithIndex.map { case (v, i) => columnByte.arrayData.putByte(i, v) }
columnByte.putArray(0, 0, len)
assert(columnByte.getArray(0).toByteArray === byteArray)
- val columnShort = ColumnVector.allocate(len, new ArrayType(ShortType, false), memMode)
+ val columnShort = allocate(len, new ArrayType(ShortType, false), memMode)
val shortArray = Array[Short](0, 1, 2, 3)
shortArray.zipWithIndex.map { case (v, i) => columnShort.arrayData.putShort(i, v) }
columnShort.putArray(0, 0, len)
assert(columnShort.getArray(0).toShortArray === shortArray)
- val columnInt = ColumnVector.allocate(len, new ArrayType(IntegerType, false), memMode)
+ val columnInt = allocate(len, new ArrayType(IntegerType, false), memMode)
val intArray = Array(0, 1, 2, 3)
intArray.zipWithIndex.map { case (v, i) => columnInt.arrayData.putInt(i, v) }
columnInt.putArray(0, 0, len)
assert(columnInt.getArray(0).toIntArray === intArray)
- val columnLong = ColumnVector.allocate(len, new ArrayType(LongType, false), memMode)
+ val columnLong = allocate(len, new ArrayType(LongType, false), memMode)
val longArray = Array[Long](0, 1, 2, 3)
longArray.zipWithIndex.map { case (v, i) => columnLong.arrayData.putLong(i, v) }
columnLong.putArray(0, 0, len)
assert(columnLong.getArray(0).toLongArray === longArray)
- val columnFloat = ColumnVector.allocate(len, new ArrayType(FloatType, false), memMode)
+ val columnFloat = allocate(len, new ArrayType(FloatType, false), memMode)
val floatArray = Array(0.0F, 1.1F, 2.2F, 3.3F)
floatArray.zipWithIndex.map { case (v, i) => columnFloat.arrayData.putFloat(i, v) }
columnFloat.putArray(0, 0, len)
assert(columnFloat.getArray(0).toFloatArray === floatArray)
- val columnDouble = ColumnVector.allocate(len, new ArrayType(DoubleType, false), memMode)
+ val columnDouble = allocate(len, new ArrayType(DoubleType, false), memMode)
val doubleArray = Array(0.0, 1.1, 2.2, 3.3)
doubleArray.zipWithIndex.map { case (v, i) => columnDouble.arrayData.putDouble(i, v) }
columnDouble.putArray(0, 0, len)
@@ -761,7 +773,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
test("Struct Column") {
(MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => {
val schema = new StructType().add("int", IntegerType).add("double", DoubleType)
- val column = ColumnVector.allocate(1024, schema, memMode)
+ val column = allocate(1024, schema, memMode)
val c1 = column.getChildColumn(0)
val c2 = column.getChildColumn(1)
@@ -790,7 +802,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
test("Nest Array in Array.") {
(MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode =>
- val column = ColumnVector.allocate(10, new ArrayType(new ArrayType(IntegerType, true), true),
+ val column = allocate(10, new ArrayType(new ArrayType(IntegerType, true), true),
memMode)
val childColumn = column.arrayData()
val data = column.arrayData().arrayData()
@@ -823,7 +835,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
test("Nest Struct in Array.") {
(MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode =>
val schema = new StructType().add("int", IntegerType).add("long", LongType)
- val column = ColumnVector.allocate(10, new ArrayType(schema, true), memMode)
+ val column = allocate(10, new ArrayType(schema, true), memMode)
val data = column.arrayData()
val c0 = data.getChildColumn(0)
val c1 = data.getChildColumn(1)
@@ -853,7 +865,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
val schema = new StructType()
.add("int", IntegerType)
.add("array", new ArrayType(IntegerType, true))
- val column = ColumnVector.allocate(10, schema, memMode)
+ val column = allocate(10, schema, memMode)
val c0 = column.getChildColumn(0)
val c1 = column.getChildColumn(1)
c0.putInt(0, 0)
@@ -885,7 +897,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
val schema = new StructType()
.add("int", IntegerType)
.add("struct", subSchema)
- val column = ColumnVector.allocate(10, schema, memMode)
+ val column = allocate(10, schema, memMode)
val c0 = column.getChildColumn(0)
val c1 = column.getChildColumn(1)
c0.putInt(0, 0)
@@ -918,7 +930,11 @@ class ColumnarBatchSuite extends SparkFunSuite {
.add("intCol2", IntegerType)
.add("string", BinaryType)
- val batch = ColumnarBatch.allocate(schema, memMode)
+ val capacity = ColumnarBatch.DEFAULT_BATCH_SIZE
+ val columns = schema.fields.map { field =>
+ allocate(capacity, field.dataType, memMode)
+ }
+ val batch = new ColumnarBatch(schema, columns.toArray, ColumnarBatch.DEFAULT_BATCH_SIZE)
assert(batch.numCols() == 4)
assert(batch.numRows() == 0)
assert(batch.numValidRows() == 0)
@@ -926,10 +942,10 @@ class ColumnarBatchSuite extends SparkFunSuite {
assert(batch.rowIterator().hasNext == false)
// Add a row [1, 1.1, NULL]
- batch.column(0).putInt(0, 1)
- batch.column(1).putDouble(0, 1.1)
- batch.column(2).putNull(0)
- batch.column(3).putByteArray(0, "Hello".getBytes(StandardCharsets.UTF_8))
+ columns(0).putInt(0, 1)
+ columns(1).putDouble(0, 1.1)
+ columns(2).putNull(0)
+ columns(3).putByteArray(0, "Hello".getBytes(StandardCharsets.UTF_8))
batch.setNumRows(1)
// Verify the results of the row.
@@ -939,12 +955,12 @@ class ColumnarBatchSuite extends SparkFunSuite {
assert(batch.rowIterator().hasNext == true)
assert(batch.rowIterator().hasNext == true)
- assert(batch.column(0).getInt(0) == 1)
- assert(batch.column(0).isNullAt(0) == false)
- assert(batch.column(1).getDouble(0) == 1.1)
- assert(batch.column(1).isNullAt(0) == false)
- assert(batch.column(2).isNullAt(0) == true)
- assert(batch.column(3).getUTF8String(0).toString == "Hello")
+ assert(columns(0).getInt(0) == 1)
+ assert(columns(0).isNullAt(0) == false)
+ assert(columns(1).getDouble(0) == 1.1)
+ assert(columns(1).isNullAt(0) == false)
+ assert(columns(2).isNullAt(0) == true)
+ assert(columns(3).getUTF8String(0).toString == "Hello")
// Verify the iterator works correctly.
val it = batch.rowIterator()
@@ -955,7 +971,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
assert(row.getDouble(1) == 1.1)
assert(row.isNullAt(1) == false)
assert(row.isNullAt(2) == true)
- assert(batch.column(3).getUTF8String(0).toString == "Hello")
+ assert(columns(3).getUTF8String(0).toString == "Hello")
assert(it.hasNext == false)
assert(it.hasNext == false)
@@ -972,20 +988,20 @@ class ColumnarBatchSuite extends SparkFunSuite {
assert(batch.rowIterator().hasNext == false)
// Add rows [NULL, 2.2, 2, "abc"], [3, NULL, 3, ""], [4, 4.4, 4, "world]
- batch.column(0).putNull(0)
- batch.column(1).putDouble(0, 2.2)
- batch.column(2).putInt(0, 2)
- batch.column(3).putByteArray(0, "abc".getBytes(StandardCharsets.UTF_8))
-
- batch.column(0).putInt(1, 3)
- batch.column(1).putNull(1)
- batch.column(2).putInt(1, 3)
- batch.column(3).putByteArray(1, "".getBytes(StandardCharsets.UTF_8))
-
- batch.column(0).putInt(2, 4)
- batch.column(1).putDouble(2, 4.4)
- batch.column(2).putInt(2, 4)
- batch.column(3).putByteArray(2, "world".getBytes(StandardCharsets.UTF_8))
+ columns(0).putNull(0)
+ columns(1).putDouble(0, 2.2)
+ columns(2).putInt(0, 2)
+ columns(3).putByteArray(0, "abc".getBytes(StandardCharsets.UTF_8))
+
+ columns(0).putInt(1, 3)
+ columns(1).putNull(1)
+ columns(2).putInt(1, 3)
+ columns(3).putByteArray(1, "".getBytes(StandardCharsets.UTF_8))
+
+ columns(0).putInt(2, 4)
+ columns(1).putDouble(2, 4.4)
+ columns(2).putInt(2, 4)
+ columns(3).putByteArray(2, "world".getBytes(StandardCharsets.UTF_8))
batch.setNumRows(3)
def rowEquals(x: InternalRow, y: Row): Unit = {
@@ -1232,7 +1248,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
test("exceeding maximum capacity should throw an error") {
(MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode =>
- val column = ColumnVector.allocate(1, ByteType, memMode)
+ val column = allocate(1, ByteType, memMode)
column.MAX_CAPACITY = 15
column.appendBytes(5, 0.toByte)
// Successfully allocate twice the requested capacity
@@ -1248,4 +1264,51 @@ class ColumnarBatchSuite extends SparkFunSuite {
s"vectorized reader"))
}
}
+
+ test("create columnar batch from Arrow column vectors") {
+ val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue)
+ val vector1 = ArrowUtils.toArrowField("int1", IntegerType, nullable = true)
+ .createVector(allocator).asInstanceOf[NullableIntVector]
+ vector1.allocateNew()
+ val mutator1 = vector1.getMutator()
+ val vector2 = ArrowUtils.toArrowField("int2", IntegerType, nullable = true)
+ .createVector(allocator).asInstanceOf[NullableIntVector]
+ vector2.allocateNew()
+ val mutator2 = vector2.getMutator()
+
+ (0 until 10).foreach { i =>
+ mutator1.setSafe(i, i)
+ mutator2.setSafe(i + 1, i)
+ }
+ mutator1.setNull(10)
+ mutator1.setValueCount(11)
+ mutator2.setNull(0)
+ mutator2.setValueCount(11)
+
+ val columnVectors = Seq(new ArrowColumnVector(vector1), new ArrowColumnVector(vector2))
+
+ val schema = StructType(Seq(StructField("int1", IntegerType), StructField("int2", IntegerType)))
+ val batch = new ColumnarBatch(schema, columnVectors.toArray[ColumnVector], 11)
+ batch.setNumRows(11)
+
+ assert(batch.numCols() == 2)
+ assert(batch.numRows() == 11)
+
+ val rowIter = batch.rowIterator().asScala
+ rowIter.zipWithIndex.foreach { case (row, i) =>
+ if (i == 10) {
+ assert(row.isNullAt(0))
+ } else {
+ assert(row.getInt(0) == i)
+ }
+ if (i == 0) {
+ assert(row.isNullAt(1))
+ } else {
+ assert(row.getInt(1) == i - 1)
+ }
+ }
+
+ batch.close()
+ allocator.close()
+ }
}
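The hunks above all stop calling the removed-in-this-patch ColumnVector.allocate / ColumnarBatch.allocate factories and construct the vectors and the batch directly. A minimal sketch of the new pattern, assuming the suite-local allocate(capacity, dataType, memMode) helper this patch introduces elsewhere in the file:

    val schema = new StructType().add("i", IntegerType).add("d", DoubleType)
    val capacity = ColumnarBatch.DEFAULT_BATCH_SIZE
    // Allocate one writable vector per field, then hand all of them to the batch up front.
    val columns = schema.fields.map(f => allocate(capacity, f.dataType, MemoryMode.ON_HEAP))
    val batch = new ColumnarBatch(schema, columns.toArray, capacity)
    // Writes now go through the vectors rather than batch.column(i) ...
    columns(0).putInt(0, 7)
    columns(1).putDouble(0, 7.5)
    // ... and the row count is published on the batch afterwards.
    batch.setNumRows(1)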
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 8dc11d80c3063..689f4106824aa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -26,6 +26,7 @@ import org.scalatest.{BeforeAndAfter, PrivateMethodTester}
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.DataSourceScanExec
import org.apache.spark.sql.execution.command.ExplainCommand
@@ -247,7 +248,7 @@ class JDBCSuite extends SparkFunSuite
// Check whether the tables are fetched in the expected degree of parallelism
def checkNumPartitions(df: DataFrame, expectedNumPartitions: Int): Unit = {
val jdbcRelations = df.queryExecution.analyzed.collect {
- case LogicalRelation(r: JDBCRelation, _, _) => r
+ case LogicalRelation(r: JDBCRelation, _, _, _) => r
}
assert(jdbcRelations.length == 1)
assert(jdbcRelations.head.parts.length == expectedNumPartitions,
@@ -968,6 +969,34 @@ class JDBCSuite extends SparkFunSuite
assert(e2.contains("User specified schema not supported with `jdbc`"))
}
+ test("jdbc API support custom schema") {
+ val parts = Array[String]("THEID < 2", "THEID >= 2")
+ val customSchema = "NAME STRING, THEID INT"
+ val props = new Properties()
+ props.put("customSchema", customSchema)
+ val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, props)
+ assert(df.schema.size === 2)
+ assert(df.schema === CatalystSqlParser.parseTableSchema(customSchema))
+ assert(df.count() === 3)
+ }
+
+ test("jdbc API custom schema DDL-like strings.") {
+ withTempView("people_view") {
+ val customSchema = "NAME STRING, THEID INT"
+ sql(
+ s"""
+ |CREATE TEMPORARY VIEW people_view
+ |USING org.apache.spark.sql.jdbc
+ |OPTIONS (uRl '$url', DbTaBlE 'TEST.PEOPLE', User 'testUser', PassWord 'testPass',
+ |customSchema '$customSchema')
+ """.stripMargin.replaceAll("\n", " "))
+ val df = sql("select * from people_view")
+ assert(df.schema.length === 2)
+ assert(df.schema === CatalystSqlParser.parseTableSchema(customSchema))
+ assert(df.count() === 3)
+ }
+ }
+
test("SPARK-15648: teradataDialect StringType data mapping") {
val teradataDialect = JdbcDialects.get("jdbc:teradata://127.0.0.1/db")
assert(teradataDialect.getJDBCType(StringType).
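For reference, the customSchema option exercised by the two new tests above can also be passed through the generic reader API. A hedged sketch (the connection settings below are placeholders, not the suite's fixtures); the option only overrides the Spark-side types of the listed columns, it does not change which columns are fetched:

    val df = spark.read
      .format("jdbc")
      .option("url", "jdbc:h2:mem:testdb0")   // placeholder connection string
      .option("dbtable", "TEST.PEOPLE")
      .option("user", "testUser")
      .option("password", "testPass")
      .option("customSchema", "NAME STRING, THEID INT")
      .load()
    // df.schema is now NAME: string, THEID: int, i.e. it matches
    // CatalystSqlParser.parseTableSchema("NAME STRING, THEID INT")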
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index 5fb1068aaaf66..1985b1dc82879 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -468,7 +468,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
.option("createTableColumnTypes", "`name char(20)") // incorrectly quoted column
.jdbc(url1, "TEST.USERDBTYPETEST", properties)
}.getMessage()
- assert(msg.contains("extraneous input '`' expecting"))
+ assert(msg.contains("extraneous input"))
}
test("SPARK-10849: jdbc CreateTableColumnTypes duplicate columns") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
index fe9469b49e385..c45b507d2b489 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
@@ -327,7 +327,7 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext with Predic
val table = spark.table("oneToTenFiltered")
val relation = table.queryExecution.logical.collectFirst {
- case LogicalRelation(r, _, _) => r
+ case LogicalRelation(r, _, _, _) => r
}.get
assert(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
index 41abff2a5da25..875b74551addb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.sources
import java.io.File
+import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils
@@ -366,4 +367,63 @@ class InsertSuite extends DataSourceTest with SharedSQLContext {
Row(Array(1, 2), Array("a", "b")))
}
}
+
+ test("insert overwrite directory") {
+ withTempDir { dir =>
+ val path = dir.toURI.getPath
+
+ val v1 =
+ s"""
+ | INSERT OVERWRITE DIRECTORY '${path}'
+ | USING json
+ | OPTIONS (a 1, b 0.1, c TRUE)
+ | SELECT 1 as a, 'c' as b
+ """.stripMargin
+
+ spark.sql(v1)
+
+ checkAnswer(
+ spark.read.json(dir.getCanonicalPath),
+ sql("SELECT 1 as a, 'c' as b"))
+ }
+ }
+
+ test("insert overwrite directory with path in options") {
+ withTempDir { dir =>
+ val path = dir.toURI.getPath
+
+ val v1 =
+ s"""
+ | INSERT OVERWRITE DIRECTORY
+ | USING json
+ | OPTIONS ('path' '${path}')
+ | SELECT 1 as a, 'c' as b
+ """.stripMargin
+
+ spark.sql(v1)
+
+ checkAnswer(
+ spark.read.json(dir.getCanonicalPath),
+ sql("SELECT 1 as a, 'c' as b"))
+ }
+ }
+
+ test("insert overwrite directory to data source not providing FileFormat") {
+ withTempDir { dir =>
+ val path = dir.toURI.getPath
+
+ val v1 =
+ s"""
+ | INSERT OVERWRITE DIRECTORY '${path}'
+ | USING JDBC
+ | OPTIONS (a 1, b 0.1, c TRUE)
+ | SELECT 1 as a, 'c' as b
+ """.stripMargin
+ val e = intercept[SparkException] {
+ spark.sql(v1)
+ }.getMessage
+
+ assert(e.contains("Only Data Sources providing FileFormat are supported"))
+ }
+ }
}
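The new syntax should also come in a LOCAL flavour, mirroring the Hive-serde form of the statement (an assumption here, not covered by the tests above; the path is a placeholder). The HiveStrategies hunk later in this patch routes the Hive-serde variant to InsertIntoHiveDirCommand:

    spark.sql(
      """
        |INSERT OVERWRITE LOCAL DIRECTORY '/tmp/spark-insert-dir'
        |USING json
        |SELECT 1 AS a, 'c' AS b
      """.stripMargin)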
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala
index 3fd7a5be1da37..85da3f0e38468 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala
@@ -135,7 +135,7 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext {
private def getPathOption(tableName: String): Option[String] = {
spark.table(tableName).queryExecution.analyzed.collect {
- case LogicalRelation(r: TestOptionsRelation, _, _) => r.pathOption
+ case LogicalRelation(r: TestOptionsRelation, _, _, _) => r.pathOption
}.head
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
index a5cf40c3581c6..08db06b94904b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
@@ -127,7 +127,7 @@ class FileStreamSinkSuite extends StreamTest {
// Verify that MetadataLogFileIndex is being used and the correct partitioning schema has
// been inferred
val hadoopdFsRelations = outputDf.queryExecution.analyzed.collect {
- case LogicalRelation(baseRelation: HadoopFsRelation, _, _) => baseRelation
+ case LogicalRelation(baseRelation: HadoopFsRelation, _, _, _) => baseRelation
}
assert(hadoopdFsRelations.size === 1)
assert(hadoopdFsRelations.head.location.isInstanceOf[MetadataLogFileIndex])
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
index e2ec690d90e52..b6baaed1927e4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
@@ -1105,7 +1105,10 @@ class FileStreamSourceSuite extends FileStreamSourceTest {
def verify(startId: Option[Int], endId: Int, expected: String*): Unit = {
val start = startId.map(new FileStreamSourceOffset(_))
val end = FileStreamSourceOffset(endId)
- assert(fileSource.getBatch(start, end).as[String].collect().toSeq === expected)
+
+ withSQLConf("spark.sql.streaming.unsupportedOperationCheck" -> "false") {
+ assert(fileSource.getBatch(start, end).as[String].collect().toSeq === expected)
+ }
}
verify(startId = None, endId = 2, "keep1", "keep2", "keep3")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index 6f7b9d35a6bb3..9c901062d570a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -17,18 +17,21 @@
package org.apache.spark.sql.streaming
-import java.io.{File, InterruptedIOException, IOException}
-import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit}
+import java.io.{File, InterruptedIOException, IOException, UncheckedIOException}
+import java.nio.channels.ClosedByInterruptException
+import java.util.concurrent.{CountDownLatch, ExecutionException, TimeoutException, TimeUnit}
import scala.reflect.ClassTag
import scala.util.control.ControlThrowable
+import com.google.common.util.concurrent.UncheckedExecutionException
import org.apache.commons.io.FileUtils
import org.apache.hadoop.conf.Configuration
import org.apache.spark.SparkContext
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.plans.logical.Range
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.streaming._
@@ -73,6 +76,22 @@ class StreamSuite extends StreamTest {
CheckAnswer(Row(1, 1, "one"), Row(2, 2, "two"), Row(4, 4, "four")))
}
+
+ test("explain join") {
+ // Make a table and ensure it will be broadcast.
+ val smallTable = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word")
+
+ // Join the input stream with a table.
+ val inputData = MemoryStream[Int]
+ val joined = inputData.toDF().join(smallTable, smallTable("number") === $"value")
+
+ val outputStream = new java.io.ByteArrayOutputStream()
+ Console.withOut(outputStream) {
+ joined.explain()
+ }
+ assert(outputStream.toString.contains("StreamingRelation"))
+ }
+
test("SPARK-20432: union one stream with itself") {
val df = spark.readStream.format(classOf[FakeDefaultSource].getName).load().select("a")
val unioned = df.union(df)
@@ -334,7 +353,9 @@ class StreamSuite extends StreamTest {
override def stop(): Unit = {}
}
- val df = Dataset[Int](sqlContext.sparkSession, StreamingExecutionRelation(source))
+ val df = Dataset[Int](
+ sqlContext.sparkSession,
+ StreamingExecutionRelation(source, sqlContext.sparkSession))
testStream(df)(
// `ExpectFailure(isFatalError = true)` verifies two things:
// - Fatal errors can be propagated to `StreamingQuery.exception` and
@@ -690,6 +711,31 @@ class StreamSuite extends StreamTest {
}
}
}
+
+ for (e <- Seq(
+ new InterruptedException,
+ new InterruptedIOException,
+ new ClosedByInterruptException,
+ new UncheckedIOException("test", new ClosedByInterruptException),
+ new ExecutionException("test", new InterruptedException),
+ new UncheckedExecutionException("test", new InterruptedException))) {
+ test(s"view ${e.getClass.getSimpleName} as a normal query stop") {
+ ThrowingExceptionInCreateSource.createSourceLatch = new CountDownLatch(1)
+ ThrowingExceptionInCreateSource.exception = e
+ val query = spark
+ .readStream
+ .format(classOf[ThrowingExceptionInCreateSource].getName)
+ .load()
+ .writeStream
+ .format("console")
+ .start()
+ assert(ThrowingExceptionInCreateSource.createSourceLatch
+ .await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS),
+ "ThrowingExceptionInCreateSource.createSource wasn't called before timeout")
+ query.stop()
+ assert(query.exception.isEmpty)
+ }
+ }
}
abstract class FakeSource extends StreamSourceProvider {
@@ -728,7 +774,16 @@ class FakeDefaultSource extends FakeSource {
override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
val startOffset = start.map(_.asInstanceOf[LongOffset].offset).getOrElse(-1L) + 1
- spark.range(startOffset, end.asInstanceOf[LongOffset].offset + 1).toDF("a")
+ val ds = new Dataset[java.lang.Long](
+ spark.sparkSession,
+ Range(
+ startOffset,
+ end.asInstanceOf[LongOffset].offset + 1,
+ 1,
+ Some(spark.sparkSession.sparkContext.defaultParallelism),
+ isStreaming = true),
+ Encoders.LONG)
+ ds.toDF("a")
}
override def stop() {}
@@ -814,3 +869,32 @@ class TestStateStoreProvider extends StateStoreProvider {
override def getStore(version: Long): StateStore = null
}
+
+/** A fake source that throws `ThrowingExceptionInCreateSource.exception` in `createSource` */
+class ThrowingExceptionInCreateSource extends FakeSource {
+
+ override def createSource(
+ spark: SQLContext,
+ metadataPath: String,
+ schema: Option[StructType],
+ providerName: String,
+ parameters: Map[String, String]): Source = {
+ ThrowingExceptionInCreateSource.createSourceLatch.countDown()
+ try {
+ Thread.sleep(30000)
+ throw new TimeoutException("sleep was not interrupted in 30 seconds")
+ } catch {
+ case _: InterruptedException =>
+ throw ThrowingExceptionInCreateSource.exception
+ }
+ }
+}
+
+object ThrowingExceptionInCreateSource {
+ /**
+ * A latch to allow the user to wait until `ThrowingExceptionInCreateSource.createSource` is
+ * called.
+ */
+ @volatile var createSourceLatch: CountDownLatch = null
+ @volatile var exception: Exception = null
+}
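The parameterized tests added above assert that an interrupt is reported as a normal query stop even when it surfaces wrapped by IO or execution layers. A sketch of the unwrapping idea, not the actual StreamExecution code:

    import java.io.{InterruptedIOException, UncheckedIOException}
    import java.nio.channels.ClosedByInterruptException
    import java.util.concurrent.ExecutionException
    import com.google.common.util.concurrent.UncheckedExecutionException

    // Treat a direct interrupt, or an interrupt wrapped by IO/execution layers, as a stop.
    def isInterruptionException(e: Throwable): Boolean = e match {
      case _: InterruptedException | _: InterruptedIOException | _: ClosedByInterruptException =>
        true
      case uio: UncheckedIOException =>
        uio.getCause.isInstanceOf[ClosedByInterruptException]
      case ee @ (_: ExecutionException | _: UncheckedExecutionException) =>
        ee.getCause != null && isInterruptionException(ee.getCause)
      case _ =>
        false
    }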
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index b2c42eef88f6d..4f8764060d922 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -27,7 +27,7 @@ import scala.util.Random
import scala.util.control.NonFatal
import org.scalatest.{Assertions, BeforeAndAfterAll}
-import org.scalatest.concurrent.{Eventually, Timeouts}
+import org.scalatest.concurrent.{Eventually, Signaler, ThreadSignaler, TimeLimits}
import org.scalatest.concurrent.PatienceConfiguration.Timeout
import org.scalatest.exceptions.TestFailedDueToTimeoutException
import org.scalatest.time.Span
@@ -67,8 +67,9 @@ import org.apache.spark.util.{Clock, SystemClock, Utils}
* avoid hanging forever in the case of failures. However, individual suites can change this
* by overriding `streamingTimeout`.
*/
-trait StreamTest extends QueryTest with SharedSQLContext with Timeouts with BeforeAndAfterAll {
+trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with BeforeAndAfterAll {
+ implicit val defaultSignaler: Signaler = ThreadSignaler
override def afterAll(): Unit = {
super.afterAll()
StateStore.stop() // stop the state store maintenance thread and unload store providers
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index b6e82b621c8cb..e0979ce296c3a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -19,10 +19,12 @@ package org.apache.spark.sql.streaming
import java.util.{Locale, TimeZone}
+import org.scalatest.Assertions
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.SparkException
-import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.{AnalysisException, DataFrame}
+import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.streaming._
@@ -31,12 +33,14 @@ import org.apache.spark.sql.expressions.scalalang.typed
import org.apache.spark.sql.functions._
import org.apache.spark.sql.streaming.OutputMode._
import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.sql.types.StructType
object FailureSinglton {
var firstTime = true
}
-class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
+class StreamingAggregationSuite extends StateStoreMetricsTest
+ with BeforeAndAfterAll with Assertions {
override def afterAll(): Unit = {
super.afterAll()
@@ -356,4 +360,25 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte
CheckLastBatch((90L, 1), (100L, 1), (105L, 1))
)
}
+
+ test("SPARK-19690: do not convert batch aggregation in streaming query to streaming") {
+ val streamInput = MemoryStream[Int]
+ val batchDF = Seq(1, 2, 3, 4, 5)
+ .toDF("value")
+ .withColumn("parity", 'value % 2)
+ .groupBy('parity)
+ .agg(count("*") as 'joinValue)
+ val joinDF = streamInput
+ .toDF()
+ .join(batchDF, 'value === 'parity)
+
+ // make sure we're planning an aggregate in the first place
+ assert(batchDF.queryExecution.optimizedPlan match { case _: Aggregate => true })
+
+ testStream(joinDF, Append)(
+ AddData(streamInput, 0, 1, 2, 3),
+ CheckLastBatch((0, 0, 2), (1, 1, 3)),
+ AddData(streamInput, 0, 1, 2, 3),
+ CheckLastBatch((0, 0, 2), (1, 1, 3)))
+ }
}
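A small supplementary check for the SPARK-19690 scenario, not part of the patch and assuming the suite's implicits for toDF are in scope: the batch side of the stream-batch join keeps isStreaming = false on its logical plan, which is why its aggregate must stay a regular batch aggregate.

    val batchSide = Seq(1, 2, 3).toDF("value").groupBy("value").count()
    assert(!batchSide.queryExecution.optimizedPlan.isStreaming)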
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index 27ea6902fa1fd..3823e336d0b64 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -25,7 +25,7 @@ import org.scalactic.TolerantNumerics
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually._
import org.scalatest.concurrent.PatienceConfiguration.Timeout
-import org.scalatest.mock.MockitoSugar
+import org.scalatest.mockito.MockitoSugar
import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
@@ -647,10 +647,13 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
val source = new Source() {
override def schema: StructType = triggerDF.schema
override def getOffset: Option[Offset] = Some(LongOffset(0))
- override def getBatch(start: Option[Offset], end: Offset): DataFrame = triggerDF
+ override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
+ sqlContext.internalCreateDataFrame(
+ triggerDF.queryExecution.toRdd, triggerDF.schema, isStreaming = true)
+ }
override def stop(): Unit = {}
}
- StreamingExecutionRelation(source)
+ StreamingExecutionRelation(source, spark)
}
/** Returns the query progress at the end of the first trigger of streaming DF */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
index e8a6202b8adce..aa163d2211c38 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
@@ -88,7 +88,7 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider {
override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
import spark.implicits._
- Seq[Int]().toDS().toDF()
+ spark.internalCreateDataFrame(spark.sparkContext.emptyRDD, schema, isStreaming = true)
}
override def stop() {}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index e68db3b636bce..a14a1441a4313 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -247,7 +247,7 @@ private[sql] trait SQLTestUtils
protected def withDatabase(dbNames: String*)(f: => Unit): Unit = {
try f finally {
dbNames.foreach { name =>
- spark.sql(s"DROP DATABASE IF EXISTS $name")
+ spark.sql(s"DROP DATABASE IF EXISTS $name CASCADE")
}
spark.sql(s"USE $DEFAULT_DATABASE")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
index 5ec76a4f0ec90..cd8d0708d8a32 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
@@ -24,6 +24,7 @@ import org.scalatest.concurrent.Eventually
import org.apache.spark.{DebugFilesystem, SparkConf}
import org.apache.spark.sql.{SparkSession, SQLContext}
+import org.apache.spark.sql.internal.SQLConf
/**
* Helper trait for SQL test suites where all tests share a single [[TestSparkSession]].
@@ -31,7 +32,10 @@ import org.apache.spark.sql.{SparkSession, SQLContext}
trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventually {
protected def sparkConf = {
- new SparkConf().set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)
+ new SparkConf()
+ .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)
+ .set("spark.unsafe.exceptionOnMemoryLeak", "true")
+ .set(SQLConf.CODEGEN_FALLBACK.key, "false")
}
/**
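Suites that cannot yet live with the stricter shared defaults above can still override them locally. A minimal sketch (the suite name is illustrative) that opts back into codegen fallback:

    import org.apache.spark.SparkConf
    import org.apache.spark.sql.QueryTest
    import org.apache.spark.sql.internal.SQLConf
    import org.apache.spark.sql.test.SharedSQLContext

    class LenientCodegenSuite extends QueryTest with SharedSQLContext {
      // Re-enable codegen fallback just for this suite.
      override protected def sparkConf: SparkConf =
        super.sparkConf.set(SQLConf.CODEGEN_FALLBACK.key, "true")
    }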
diff --git a/sql/create-docs.sh b/sql/create-docs.sh
index 275e4c391a388..4353708d22f7b 100755
--- a/sql/create-docs.sh
+++ b/sql/create-docs.sh
@@ -33,10 +33,12 @@ if ! hash python 2>/dev/null; then
fi
if ! hash mkdocs 2>/dev/null; then
- echo "Missing mkdocs in your path, skipping SQL documentation generation."
- exit 0
+ echo "Missing mkdocs in your path, trying to install mkdocs for SQL documentation generation."
+ pip install mkdocs
fi
+pushd "$FWDIR" > /dev/null
+
# Now create the markdown file
rm -fr docs
mkdir docs
@@ -47,3 +49,5 @@ echo "Generating markdown files for SQL documentation."
echo "Generating HTML files for SQL documentation."
mkdocs build --clean
rm -fr docs
+
+popd
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala
index 1b17a9a56e5b9..ad1f5eb9ca3a7 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala
@@ -25,6 +25,7 @@ import scala.collection.JavaConverters._
import org.apache.commons.logging.Log
import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hadoop.hive.shims.Utils
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hive.service.{AbstractService, Service, ServiceException}
@@ -47,6 +48,7 @@ private[hive] class SparkSQLCLIService(hiveServer: HiveServer2, sqlContext: SQLC
setSuperField(this, "sessionManager", sparkSqlSessionManager)
addService(sparkSqlSessionManager)
var sparkServiceUGI: UserGroupInformation = null
+ var httpUGI: UserGroupInformation = null
if (UserGroupInformation.isSecurityEnabled) {
try {
@@ -57,6 +59,20 @@ private[hive] class SparkSQLCLIService(hiveServer: HiveServer2, sqlContext: SQLC
case e @ (_: IOException | _: LoginException) =>
throw new ServiceException("Unable to login to kerberos with given principal/keytab", e)
}
+
+ // Try creating spnego UGI if it is configured.
+ val principal = hiveConf.getVar(ConfVars.HIVE_SERVER2_SPNEGO_PRINCIPAL).trim
+ val keyTabFile = hiveConf.getVar(ConfVars.HIVE_SERVER2_SPNEGO_KEYTAB).trim
+ if (principal.nonEmpty && keyTabFile.nonEmpty) {
+ try {
+ httpUGI = HiveAuthFactory.loginFromSpnegoKeytabAndReturnUGI(hiveConf)
+ setSuperField(this, "httpUGI", httpUGI)
+ } catch {
+ case e: IOException =>
+ throw new ServiceException("Unable to login to spnego with given principal " +
+ s"$principal and keytab $keyTabFile: $e", e)
+ }
+ }
}
initCompositeService(hiveConf)
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 0a53aaca404e6..45791c69b4cb7 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -39,7 +39,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
private val originalLocale = Locale.getDefault
private val originalColumnBatchSize = TestHive.conf.columnBatchSize
private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning
- private val originalConvertMetastoreOrc = TestHive.conf.getConf(HiveUtils.CONVERT_METASTORE_ORC)
private val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled
private val originalSessionLocalTimeZone = TestHive.conf.sessionLocalTimeZone
@@ -58,9 +57,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, 5)
// Enable in-memory partition pruning for testing purposes
TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true)
- // Ensures that the plans generation use metastore relation and not OrcRelation
- // Was done because SqlBuilder does not work with plans having logical relation
- TestHive.setConf(HiveUtils.CONVERT_METASTORE_ORC, false)
// Ensures that cross joins are enabled so that we can test them
TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true)
// Fix session local timezone to America/Los_Angeles for those timezone sensitive tests
@@ -76,7 +72,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
Locale.setDefault(originalLocale)
TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
- TestHive.setConf(HiveUtils.CONVERT_METASTORE_ORC, originalConvertMetastoreOrc)
TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled)
TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, originalSessionLocalTimeZone)
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
index 32a5ab90e7f5e..f63b48053052f 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
@@ -610,7 +610,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte
|window w1 as (distribute by p_mfgr sort by p_name rows between 2 preceding and 2 following),
| w2 as (partition by p_mfgr order by p_name)
""".stripMargin, reset = false)
- */
+ */
/* p_name is not a numeric column. What is Hive's semantic?
createQueryTest("windowing.q -- 31. testWindowCrossReference",
diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml
index 4eb50d7708bba..497a252b0e930 100644
--- a/sql/hive/pom.xml
+++ b/sql/hive/pom.xml
@@ -172,6 +172,15 @@
org.apache.thrift
libfb303
+
+ org.apache.derby
+ derby
+
+
+ org.scala-lang
+ scala-compiler
+ test
+
org.scalacheck
scalacheck_${scala.binary.version}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
index e9d48f95aa905..96dc983b0bfc6 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.command.DDLUtils
-import org.apache.spark.sql.execution.datasources.PartitioningUtils
+import org.apache.spark.sql.execution.datasources.{PartitioningUtils, SourceOptions}
import org.apache.spark.sql.hive.client.HiveClient
import org.apache.spark.sql.internal.HiveSerDe
import org.apache.spark.sql.internal.StaticSQLConf._
@@ -114,7 +114,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
* should interpret these special data source properties and restore the original table metadata
* before returning it.
*/
- private def getRawTable(db: String, table: String): CatalogTable = withClient {
+ private[hive] def getRawTable(db: String, table: String): CatalogTable = withClient {
client.getTable(db, table)
}
@@ -260,6 +260,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
private def createDataSourceTable(table: CatalogTable, ignoreIfExists: Boolean): Unit = {
// data source table always have a provider, it's guaranteed by `DDLUtils.isDatasourceTable`.
val provider = table.provider.get
+ val options = new SourceOptions(table.storage.properties)
// To work around some hive metastore issues, e.g. not case-preserving, bad decimal type
// support, no column nullability, etc., we should do some extra works before saving table
@@ -325,11 +326,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
val qualifiedTableName = table.identifier.quotedString
val maybeSerde = HiveSerDe.sourceToSerDe(provider)
- val skipHiveMetadata = table.storage.properties
- .getOrElse("skipHiveMetadata", "false").toBoolean
val (hiveCompatibleTable, logMessage) = maybeSerde match {
- case _ if skipHiveMetadata =>
+ case _ if options.skipHiveMetadata =>
val message =
s"Persisting data source table $qualifiedTableName into Hive metastore in" +
"Spark SQL specific format, which is NOT compatible with Hive."
@@ -386,6 +385,12 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
* can be used as table properties later.
*/
private def tableMetaToTableProps(table: CatalogTable): mutable.Map[String, String] = {
+ tableMetaToTableProps(table, table.schema)
+ }
+
+ private def tableMetaToTableProps(
+ table: CatalogTable,
+ schema: StructType): mutable.Map[String, String] = {
val partitionColumns = table.partitionColumnNames
val bucketSpec = table.bucketSpec
@@ -397,7 +402,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
// property. In this case, we split the JSON string and store each part as a separate table
// property.
val threshold = conf.get(SCHEMA_STRING_LENGTH_THRESHOLD)
- val schemaJsonString = table.schema.json
+ val schemaJsonString = schema.json
// Split the JSON string.
val parts = schemaJsonString.grouped(threshold).toSeq
properties.put(DATASOURCE_SCHEMA_NUMPARTS, parts.size.toString)
@@ -507,7 +512,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
identifier = TableIdentifier(newName, Some(db)),
storage = storageWithNewPath)
- client.alterTable(oldName, newTable)
+ client.alterTable(db, oldName, newTable)
}
private def getLocationFromStorageProps(table: CatalogTable): Option[String] = {
@@ -615,20 +620,29 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
override def alterTableSchema(db: String, table: String, schema: StructType): Unit = withClient {
requireTableExists(db, table)
val rawTable = getRawTable(db, table)
- val withNewSchema = rawTable.copy(schema = schema)
- verifyColumnNames(withNewSchema)
// Add table metadata such as table schema, partition columns, etc. to table properties.
- val updatedTable = withNewSchema.copy(
- properties = withNewSchema.properties ++ tableMetaToTableProps(withNewSchema))
- try {
- client.alterTable(updatedTable)
- } catch {
- case NonFatal(e) =>
- val warningMessage =
- s"Could not alter schema of table ${rawTable.identifier.quotedString} in a Hive " +
- "compatible way. Updating Hive metastore in Spark SQL specific format."
- logWarning(warningMessage, e)
- client.alterTable(updatedTable.copy(schema = updatedTable.partitionSchema))
+ val updatedProperties = rawTable.properties ++ tableMetaToTableProps(rawTable, schema)
+ val withNewSchema = rawTable.copy(properties = updatedProperties, schema = schema)
+ verifyColumnNames(withNewSchema)
+
+ if (isDatasourceTable(rawTable)) {
+ // For data source tables, first try to write it with the schema set; if that does not work,
+ // try again with updated properties and the partition schema. This is a simplified version of
+ // what createDataSourceTable() does, and may leave the table in a state unreadable by Hive
+ // (for example, the schema does not match the data source schema, or does not match the
+ // storage descriptor).
+ try {
+ client.alterTable(withNewSchema)
+ } catch {
+ case NonFatal(e) =>
+ val warningMessage =
+ s"Could not alter schema of table ${rawTable.identifier.quotedString} in a Hive " +
+ "compatible way. Updating Hive metastore in Spark SQL specific format."
+ logWarning(warningMessage, e)
+ client.alterTable(withNewSchema.copy(schema = rawTable.partitionSchema))
+ }
+ } else {
+ client.alterTable(withNewSchema)
}
}
@@ -639,26 +653,17 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
requireTableExists(db, table)
val rawTable = getRawTable(db, table)
- // convert table statistics to properties so that we can persist them through hive client
- val statsProperties = new mutable.HashMap[String, String]()
- if (stats.isDefined) {
- statsProperties += STATISTICS_TOTAL_SIZE -> stats.get.sizeInBytes.toString()
- if (stats.get.rowCount.isDefined) {
- statsProperties += STATISTICS_NUM_ROWS -> stats.get.rowCount.get.toString()
- }
-
- // For datasource tables and hive serde tables created by spark 2.1 or higher,
- // the data schema is stored in the table properties.
- val schema = restoreTableMetadata(rawTable).schema
+ // For datasource tables and hive serde tables created by spark 2.1 or higher,
+ // the data schema is stored in the table properties.
+ val schema = restoreTableMetadata(rawTable).schema
- val colNameTypeMap: Map[String, DataType] =
- schema.fields.map(f => (f.name, f.dataType)).toMap
- stats.get.colStats.foreach { case (colName, colStat) =>
- colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) =>
- statsProperties += (columnStatKeyPropName(colName, k) -> v)
- }
+ // convert table statistics to properties so that we can persist them through hive client
+ var statsProperties =
+ if (stats.isDefined) {
+ statsToProperties(stats.get, schema)
+ } else {
+ new mutable.HashMap[String, String]()
}
- }
val oldTableNonStatsProps = rawTable.properties.filterNot(_._1.startsWith(STATISTICS_PREFIX))
val updatedTable = rawTable.copy(properties = oldTableNonStatsProps ++ statsProperties)
@@ -704,36 +709,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
val version: String = table.properties.getOrElse(CREATED_SPARK_VERSION, "2.2 or prior")
// Restore Spark's statistics from information in Metastore.
- val statsProps = table.properties.filterKeys(_.startsWith(STATISTICS_PREFIX))
-
- // Currently we have two sources of statistics: one from Hive and the other from Spark.
- // In our design, if Spark's statistics is available, we respect it over Hive's statistics.
- if (statsProps.nonEmpty) {
- val colStats = new mutable.HashMap[String, ColumnStat]
-
- // For each column, recover its column stats. Note that this is currently a O(n^2) operation,
- // but given the number of columns it usually not enormous, this is probably OK as a start.
- // If we want to map this a linear operation, we'd need a stronger contract between the
- // naming convention used for serialization.
- table.schema.foreach { field =>
- if (statsProps.contains(columnStatKeyPropName(field.name, ColumnStat.KEY_VERSION))) {
- // If "version" field is defined, then the column stat is defined.
- val keyPrefix = columnStatKeyPropName(field.name, "")
- val colStatMap = statsProps.filterKeys(_.startsWith(keyPrefix)).map { case (k, v) =>
- (k.drop(keyPrefix.length), v)
- }
-
- ColumnStat.fromMap(table.identifier.table, field, colStatMap).foreach {
- colStat => colStats += field.name -> colStat
- }
- }
- }
-
- table = table.copy(
- stats = Some(CatalogStatistics(
- sizeInBytes = BigInt(table.properties(STATISTICS_TOTAL_SIZE)),
- rowCount = table.properties.get(STATISTICS_NUM_ROWS).map(BigInt(_)),
- colStats = colStats.toMap)))
+ val restoredStats =
+ statsFromProperties(table.properties, table.identifier.table, table.schema)
+ if (restoredStats.isDefined) {
+ table = table.copy(stats = restoredStats)
}
// Get the original table properties as defined by the user.
@@ -757,6 +736,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
}
private def restoreHiveSerdeTable(table: CatalogTable): CatalogTable = {
+ val options = new SourceOptions(table.storage.properties)
val hiveTable = table.copy(
provider = Some(DDLUtils.HIVE_PROVIDER),
tracksPartitionsInCatalog = true)
@@ -768,7 +748,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
val partColumnNames = getPartitionColumnsFromTableProperties(table)
val reorderedSchema = reorderSchema(schema = schemaFromTableProps, partColumnNames)
- if (DataType.equalsIgnoreCaseAndNullability(reorderedSchema, table.schema)) {
+ if (DataType.equalsIgnoreCaseAndNullability(reorderedSchema, table.schema) ||
+ options.respectSparkSchema) {
hiveTable.copy(
schema = reorderedSchema,
partitionColumnNames = partColumnNames,
@@ -1037,17 +1018,92 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
currentFullPath
}
+ private def statsToProperties(
+ stats: CatalogStatistics,
+ schema: StructType): Map[String, String] = {
+
+ var statsProperties: Map[String, String] =
+ Map(STATISTICS_TOTAL_SIZE -> stats.sizeInBytes.toString())
+ if (stats.rowCount.isDefined) {
+ statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString()
+ }
+
+ val colNameTypeMap: Map[String, DataType] =
+ schema.fields.map(f => (f.name, f.dataType)).toMap
+ stats.colStats.foreach { case (colName, colStat) =>
+ colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) =>
+ statsProperties += (columnStatKeyPropName(colName, k) -> v)
+ }
+ }
+
+ statsProperties
+ }
+
+ private def statsFromProperties(
+ properties: Map[String, String],
+ table: String,
+ schema: StructType): Option[CatalogStatistics] = {
+
+ val statsProps = properties.filterKeys(_.startsWith(STATISTICS_PREFIX))
+ if (statsProps.isEmpty) {
+ None
+ } else {
+
+ val colStats = new mutable.HashMap[String, ColumnStat]
+
+ // For each column, recover its column stats. Note that this is currently an O(n^2) operation,
+ // but given that the number of columns is usually not enormous, this is probably OK as a start.
+ // If we want to make this a linear operation, we'd need a stronger contract for the
+ // naming convention used for serialization.
+ schema.foreach { field =>
+ if (statsProps.contains(columnStatKeyPropName(field.name, ColumnStat.KEY_VERSION))) {
+ // If "version" field is defined, then the column stat is defined.
+ val keyPrefix = columnStatKeyPropName(field.name, "")
+ val colStatMap = statsProps.filterKeys(_.startsWith(keyPrefix)).map { case (k, v) =>
+ (k.drop(keyPrefix.length), v)
+ }
+
+ ColumnStat.fromMap(table, field, colStatMap).foreach {
+ colStat => colStats += field.name -> colStat
+ }
+ }
+ }
+
+ Some(CatalogStatistics(
+ sizeInBytes = BigInt(statsProps(STATISTICS_TOTAL_SIZE)),
+ rowCount = statsProps.get(STATISTICS_NUM_ROWS).map(BigInt(_)),
+ colStats = colStats.toMap))
+ }
+ }
+
override def alterPartitions(
db: String,
table: String,
newParts: Seq[CatalogTablePartition]): Unit = withClient {
val lowerCasedParts = newParts.map(p => p.copy(spec = lowerCasePartitionSpec(p.spec)))
+
+ val rawTable = getRawTable(db, table)
+
+ // For datasource tables and hive serde tables created by spark 2.1 or higher,
+ // the data schema is stored in the table properties.
+ val schema = restoreTableMetadata(rawTable).schema
+
+ // convert partition statistics to properties so that we can persist them through hive api
+ val withStatsProps = lowerCasedParts.map(p => {
+ if (p.stats.isDefined) {
+ val statsProperties = statsToProperties(p.stats.get, schema)
+ p.copy(parameters = p.parameters ++ statsProperties)
+ } else {
+ p
+ }
+ })
+
// Note: Before altering table partitions in Hive, you *must* set the current database
// to the one that contains the table of interest. Otherwise you will end up with the
// most helpful error message ever: "Unable to alter partition. alter is not possible."
// See HIVE-2742 for more detail.
client.setCurrentDatabase(db)
- client.alterPartitions(db, table, lowerCasedParts)
+ client.alterPartitions(db, table, withStatsProps)
}
override def getPartition(
@@ -1055,7 +1111,34 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
table: String,
spec: TablePartitionSpec): CatalogTablePartition = withClient {
val part = client.getPartition(db, table, lowerCasePartitionSpec(spec))
- part.copy(spec = restorePartitionSpec(part.spec, getTable(db, table).partitionColumnNames))
+ restorePartitionMetadata(part, getTable(db, table))
+ }
+
+ /**
+ * Restores partition metadata from the partition properties.
+ *
+ * Reads partition-level statistics from partition properties, puts these
+ * into [[CatalogTablePartition#stats]] and removes these special entries
+ * from the partition properties.
+ */
+ private def restorePartitionMetadata(
+ partition: CatalogTablePartition,
+ table: CatalogTable): CatalogTablePartition = {
+ val restoredSpec = restorePartitionSpec(partition.spec, table.partitionColumnNames)
+
+ // Restore Spark's statistics from information in Metastore.
+ // Note: partition-level statistics were introduced in 2.3.
+ val restoredStats =
+ statsFromProperties(partition.parameters, table.identifier.table, table.schema)
+ if (restoredStats.isDefined) {
+ partition.copy(
+ spec = restoredSpec,
+ stats = restoredStats,
+ parameters = partition.parameters.filterNot {
+ case (key, _) => key.startsWith(SPARK_SQL_PREFIX) })
+ } else {
+ partition.copy(spec = restoredSpec)
+ }
}
/**
@@ -1066,7 +1149,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
table: String,
spec: TablePartitionSpec): Option[CatalogTablePartition] = withClient {
client.getPartitionOption(db, table, lowerCasePartitionSpec(spec)).map { part =>
- part.copy(spec = restorePartitionSpec(part.spec, getTable(db, table).partitionColumnNames))
+ restorePartitionMetadata(part, getTable(db, table))
}
}
@@ -1284,4 +1367,14 @@ object HiveExternalCatalog {
getColumnNamesByType(metadata.properties, "sort", "sorting columns"))
}
}
+
+ /**
+ * Detects a data source table. This checks both the table provider and the table properties,
+ * unlike DDLUtils which just checks the former.
+ */
+ private[spark] def isDatasourceTable(table: CatalogTable): Boolean = {
+ val provider = table.provider.orElse(table.properties.get(DATASOURCE_PROVIDER))
+ provider.isDefined && provider != Some(DDLUtils.HIVE_PROVIDER)
+ }
+
}
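For orientation, the round trip above flattens statistics into plain string table/partition properties on write, and statsFromProperties rebuilds CatalogStatistics from any key under the statistics prefix on read. The literal keys below are written from memory and should be treated as assumptions rather than part of the patch:

    // Example of the flattened form a table or partition might carry in the metastore.
    val exampleStatsProps = Map(
      "spark.sql.statistics.totalSize" -> "1024",           // STATISTICS_TOTAL_SIZE
      "spark.sql.statistics.numRows" -> "10",               // STATISTICS_NUM_ROWS
      "spark.sql.statistics.colStats.id.version" -> "1")    // columnStatKeyPropName("id", "version")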
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 8bab059ed5e84..f0f2c493498b3 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -73,7 +73,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
catalogProxy.getCachedTable(tableIdentifier) match {
case null => None // Cache miss
- case logical @ LogicalRelation(relation: HadoopFsRelation, _, _) =>
+ case logical @ LogicalRelation(relation: HadoopFsRelation, _, _, _) =>
val cachedRelationFileFormatClass = relation.fileFormat.getClass
expectedFileFormat match {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 0d0269f694300..b256ffc27b199 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -30,14 +30,12 @@ import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, Gener
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
-import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResourceLoader, GlobalTempViewManager, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DecimalType, DoubleType}
-import org.apache.spark.util.Utils
private[sql] class HiveSessionCatalog(
@@ -58,55 +56,52 @@ private[sql] class HiveSessionCatalog(
parser,
functionResourceLoader) {
- override def makeFunctionBuilder(funcName: String, className: String): FunctionBuilder = {
- makeFunctionBuilder(funcName, Utils.classForName(className))
- }
-
/**
- * Construct a [[FunctionBuilder]] based on the provided class that represents a function.
+ * Constructs an [[Expression]] based on the provided class that represents a function.
+ *
+ * This performs reflection to decide what type of [[Expression]] to return in the builder.
*/
- private def makeFunctionBuilder(name: String, clazz: Class[_]): FunctionBuilder = {
- // When we instantiate hive UDF wrapper class, we may throw exception if the input
- // expressions don't satisfy the hive UDF, such as type mismatch, input number
- // mismatch, etc. Here we catch the exception and throw AnalysisException instead.
- (children: Seq[Expression]) => {
+ override def makeFunctionExpression(
+ name: String,
+ clazz: Class[_],
+ input: Seq[Expression]): Expression = {
+
+ Try(super.makeFunctionExpression(name, clazz, input)).getOrElse {
+ var udfExpr: Option[Expression] = None
try {
+ // When we instantiate the Hive UDF wrapper class, an exception may be thrown if the input
+ // expressions don't satisfy the Hive UDF, e.g. type mismatch, wrong number of inputs, etc.
+ // Here we catch the exception and throw AnalysisException instead.
if (classOf[UDF].isAssignableFrom(clazz)) {
- val udf = HiveSimpleUDF(name, new HiveFunctionWrapper(clazz.getName), children)
- udf.dataType // Force it to check input data types.
- udf
+ udfExpr = Some(HiveSimpleUDF(name, new HiveFunctionWrapper(clazz.getName), input))
+ udfExpr.get.dataType // Force it to check input data types.
} else if (classOf[GenericUDF].isAssignableFrom(clazz)) {
- val udf = HiveGenericUDF(name, new HiveFunctionWrapper(clazz.getName), children)
- udf.dataType // Force it to check input data types.
- udf
+ udfExpr = Some(HiveGenericUDF(name, new HiveFunctionWrapper(clazz.getName), input))
+ udfExpr.get.dataType // Force it to check input data types.
} else if (classOf[AbstractGenericUDAFResolver].isAssignableFrom(clazz)) {
- val udaf = HiveUDAFFunction(name, new HiveFunctionWrapper(clazz.getName), children)
- udaf.dataType // Force it to check input data types.
- udaf
+ udfExpr = Some(HiveUDAFFunction(name, new HiveFunctionWrapper(clazz.getName), input))
+ udfExpr.get.dataType // Force it to check input data types.
} else if (classOf[UDAF].isAssignableFrom(clazz)) {
- val udaf = HiveUDAFFunction(
+ udfExpr = Some(HiveUDAFFunction(
name,
new HiveFunctionWrapper(clazz.getName),
- children,
- isUDAFBridgeRequired = true)
- udaf.dataType // Force it to check input data types.
- udaf
+ input,
+ isUDAFBridgeRequired = true))
+ udfExpr.get.dataType // Force it to check input data types.
} else if (classOf[GenericUDTF].isAssignableFrom(clazz)) {
- val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(clazz.getName), children)
- udtf.elementSchema // Force it to check input data types.
- udtf
- } else {
- throw new AnalysisException(s"No handler for Hive UDF '${clazz.getCanonicalName}'")
+ udfExpr = Some(HiveGenericUDTF(name, new HiveFunctionWrapper(clazz.getName), input))
+ udfExpr.get.asInstanceOf[HiveGenericUDTF].elementSchema // Force it to check data types.
}
} catch {
- case ae: AnalysisException =>
- throw ae
case NonFatal(e) =>
val analysisException =
- new AnalysisException(s"No handler for Hive UDF '${clazz.getCanonicalName}': $e")
+ new AnalysisException(s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}': $e")
analysisException.setStackTrace(e.getStackTrace)
throw analysisException
}
+ udfExpr.getOrElse {
+ throw new AnalysisException(s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}'")
+ }
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index ae1e7e72e8c3f..805b3171cdaab 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -26,7 +26,8 @@ import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning._
-import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan, ScriptTransformation}
+import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoTable, LogicalPlan,
+ ScriptTransformation}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils}
@@ -151,10 +152,19 @@ object HiveAnalysis extends Rule[LogicalPlan] {
InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite, ifPartitionNotExists)
case CreateTable(tableDesc, mode, None) if DDLUtils.isHiveTable(tableDesc) =>
+ DDLUtils.checkDataSchemaFieldNames(tableDesc)
CreateTableCommand(tableDesc, ignoreIfExists = mode == SaveMode.Ignore)
case CreateTable(tableDesc, mode, Some(query)) if DDLUtils.isHiveTable(tableDesc) =>
+ DDLUtils.checkDataSchemaFieldNames(tableDesc)
CreateHiveTableAsSelectCommand(tableDesc, query, mode)
+
+ case InsertIntoDir(isLocal, storage, provider, child, overwrite)
+ if DDLUtils.isHiveTable(provider) =>
+ val outputPath = new Path(storage.locationUri.get)
+ if (overwrite) DDLUtils.verifyNotReadPath(child, outputPath)
+
+ InsertIntoHiveDirCommand(isLocal, storage, child, overwrite)
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index f238b9a4f7f6f..cc8907a0bbc93 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -39,8 +39,10 @@ import org.apache.spark.internal.Logging
import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.CastSupport
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -65,7 +67,7 @@ class HadoopTableReader(
@transient private val tableDesc: TableDesc,
@transient private val sparkSession: SparkSession,
hadoopConf: Configuration)
- extends TableReader with Logging {
+ extends TableReader with CastSupport with Logging {
// Hadoop honors "mapreduce.job.maps" as hint,
// but will ignore when mapreduce.jobtracker.address is "local".
@@ -86,6 +88,8 @@ class HadoopTableReader(
private val _broadcastedHadoopConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
+ override def conf: SQLConf = sparkSession.sessionState.conf
+
override def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] =
makeRDDForTable(
hiveTable,
@@ -227,7 +231,7 @@ class HadoopTableReader(
def fillPartitionKeys(rawPartValues: Array[String], row: InternalRow): Unit = {
partitionKeyAttrs.foreach { case (attr, ordinal) =>
val partOrdinal = partitionKeys.indexOf(attr)
- row(ordinal) = Cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null)
+ row(ordinal) = cast(Literal(rawPartValues(partOrdinal)), attr.dataType).eval(null)
}
}
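This TableReader change (and the matching one in HiveTableScanExec below) swaps a bare Cast on partition values for CastSupport.cast, which carries the session-local time zone from SQLConf. A minimal sketch of the difference, assuming the catalyst Cast constructor of this era that takes an optional time-zone id; the zone string stands in for conf.sessionLocalTimeZone.

import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types.TimestampType

object TimeZoneAwareCastSketch {
  def main(args: Array[String]): Unit = {
    val sessionTz = "America/Los_Angeles" // stand-in for conf.sessionLocalTimeZone

    // Roughly what cast(Literal(rawPartValue), attr.dataType) expands to via CastSupport.
    val withTz = Cast(Literal("2017-08-01 00:00:00"), TimestampType, Some(sessionTz))
    println(withTz.eval(null)) // microseconds interpreted in the session time zone

    // The old form left the time zone unset, so timestamp partition values depended on
    // defaults outside the session configuration.
    val withoutTz = Cast(Literal("2017-08-01 00:00:00"), TimestampType)
    println(withoutTz)
  }
}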
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala
index 8cff0ca0963bd..ee3eb2ee8abe5 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala
@@ -90,10 +90,15 @@ private[hive] trait HiveClient {
def dropTable(dbName: String, tableName: String, ignoreIfNotExists: Boolean, purge: Boolean): Unit
/** Alter a table whose name matches the one specified in `table`, assuming it exists. */
- final def alterTable(table: CatalogTable): Unit = alterTable(table.identifier.table, table)
+ final def alterTable(table: CatalogTable): Unit = {
+ alterTable(table.database, table.identifier.table, table)
+ }
- /** Updates the given table with new metadata, optionally renaming the table. */
- def alterTable(tableName: String, table: CatalogTable): Unit
+ /**
+ * Updates the given table with new metadata, optionally renaming the table or
+ * moving it to a different database.
+ */
+ def alterTable(dbName: String, tableName: String, table: CatalogTable): Unit
/** Creates a new database with the given name. */
def createDatabase(database: CatalogDatabase, ignoreIfExists: Boolean): Unit
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
index 5e5c0a2a5078c..426db6a4e1c12 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
@@ -21,6 +21,7 @@ import java.io.{File, PrintStream}
import java.util.Locale
import scala.collection.JavaConverters._
+import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.conf.Configuration
@@ -49,6 +50,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
import org.apache.spark.sql.execution.QueryExecutionException
import org.apache.spark.sql.execution.command.DDLUtils
+import org.apache.spark.sql.hive.HiveExternalCatalog
import org.apache.spark.sql.hive.client.HiveClientImpl._
import org.apache.spark.sql.types._
import org.apache.spark.util.{CircularBuffer, Utils}
@@ -493,7 +495,10 @@ private[hive] class HiveClientImpl(
shim.dropTable(client, dbName, tableName, true, ignoreIfNotExists, purge)
}
- override def alterTable(tableName: String, table: CatalogTable): Unit = withHiveState {
+ override def alterTable(
+ dbName: String,
+ tableName: String,
+ table: CatalogTable): Unit = withHiveState {
// getTableOption removes all the Hive-specific properties. Here, we fill them back to ensure
// these properties are still available to the others that share the same Hive metastore.
// If users explicitly alter these Hive-specific properties through ALTER TABLE DDL, we respect
@@ -501,7 +506,7 @@ private[hive] class HiveClientImpl(
val hiveTable = toHiveTable(
table.copy(properties = table.ignoredProperties ++ table.properties), Some(userName))
// Do not use `table.qualifiedName` here because this may be a rename
- val qualifiedTableName = s"${table.database}.$tableName"
+ val qualifiedTableName = s"$dbName.$tableName"
shim.alterTable(client, qualifiedTableName, hiveTable)
}
@@ -844,7 +849,12 @@ private[hive] object HiveClientImpl {
throw new SparkException("Cannot recognize hive type string: " + hc.getType, e)
}
- val metadata = new MetadataBuilder().putString(HIVE_TYPE_STRING, hc.getType).build()
+ val metadata = if (hc.getType != columnType.catalogString) {
+ new MetadataBuilder().putString(HIVE_TYPE_STRING, hc.getType).build()
+ } else {
+ Metadata.empty
+ }
+
val field = StructField(
name = hc.getName,
dataType = columnType,
@@ -882,7 +892,7 @@ private[hive] object HiveClientImpl {
}
// after SPARK-19279, it is not allowed to create a hive table with an empty schema,
// so here we should not add a default col schema
- if (schema.isEmpty && DDLUtils.isDatasourceTable(table)) {
+ if (schema.isEmpty && HiveExternalCatalog.isDatasourceTable(table)) {
// This is a hack to preserve existing behavior. Before Spark 2.0, we do not
// set a default serde here (this was done in Hive), and so if the user provides
// an empty schema Hive would automatically populate the schema with a single
@@ -960,6 +970,7 @@ private[hive] object HiveClientImpl {
tpart.setTableName(ht.getTableName)
tpart.setValues(partValues.asJava)
tpart.setSd(storageDesc)
+ tpart.setParameters(mutable.Map(p.parameters.toSeq: _*).asJava)
new HivePartition(ht, tpart)
}
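The column-conversion change in HiveClientImpl above keeps the HIVE_TYPE_STRING column metadata only when the Hive type string differs from the Catalyst catalogString (e.g. char/varchar columns, which surface as StringType). A small self-contained sketch of that rule; the metadata key is written out literally here, whereas the real code uses a shared constant with the same value.

import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StringType, StructField}

object HiveTypeMetadataSketch {
  // The Catalyst type is fixed to StringType for this sketch, which is how Hive
  // char/varchar columns are represented; StringType.catalogString is "string".
  def buildField(name: String, hiveTypeString: String): StructField = {
    val metadata = if (hiveTypeString != StringType.catalogString) {
      new MetadataBuilder().putString("HIVE_TYPE_STRING", hiveTypeString).build()
    } else {
      Metadata.empty
    }
    StructField(name, StringType, nullable = true, metadata)
  }

  def main(args: Array[String]): Unit = {
    // A varchar column keeps its original Hive type string in the metadata ...
    println(buildField("name", "varchar(10)").metadata)
    // ... while a plain string column now carries empty metadata instead of a redundant entry.
    println(buildField("comment", "string").metadata)
  }
}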
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
index 896f24f2e223d..48d0b4a63e54a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
@@ -30,6 +30,7 @@ import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.CastSupport
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
@@ -37,6 +38,7 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive.client.HiveClientImpl
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BooleanType, DataType}
import org.apache.spark.util.Utils
@@ -53,11 +55,13 @@ case class HiveTableScanExec(
relation: HiveTableRelation,
partitionPruningPred: Seq[Expression])(
@transient private val sparkSession: SparkSession)
- extends LeafExecNode {
+ extends LeafExecNode with CastSupport {
require(partitionPruningPred.isEmpty || relation.isPartitioned,
"Partition pruning predicates only supported for partitioned tables.")
+ override def conf: SQLConf = sparkSession.sessionState.conf
+
override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
@@ -104,7 +108,7 @@ case class HiveTableScanExec(
hadoopConf)
private def castFromString(value: String, dataType: DataType) = {
- Cast(Literal(value), dataType).eval(null)
+ cast(Literal(value), dataType).eval(null)
}
private def addColumnMetadataToConf(hiveConf: Configuration): Unit = {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala
new file mode 100644
index 0000000000000..918c8be00d69d
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import scala.language.existentials
+
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.hive.common.FileUtils
+import org.apache.hadoop.hive.ql.plan.TableDesc
+import org.apache.hadoop.hive.serde.serdeConstants
+import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
+import org.apache.hadoop.mapred._
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.hive.client.HiveClientImpl
+
+/**
+ * Command for writing the results of `query` to the file system.
+ *
+ * The syntax of using this command in SQL is:
+ * {{{
+ * INSERT OVERWRITE [LOCAL] DIRECTORY
+ * path
+ * [ROW FORMAT row_format]
+ * [STORED AS file_format]
+ * SELECT ...
+ * }}}
+ *
+ * @param isLocal whether the path specified in `storage` is a local directory
+ * @param storage storage format used to describe how the query result is stored.
+ * @param query the logical plan representing the data to write
+ * @param overwrite whether to overwrite the existing directory
+ */
+case class InsertIntoHiveDirCommand(
+ isLocal: Boolean,
+ storage: CatalogStorageFormat,
+ query: LogicalPlan,
+ overwrite: Boolean) extends SaveAsHiveFile {
+
+ override def children: Seq[LogicalPlan] = query :: Nil
+
+ override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = {
+ assert(children.length == 1)
+ assert(storage.locationUri.nonEmpty)
+
+ val hiveTable = HiveClientImpl.toHiveTable(CatalogTable(
+ identifier = TableIdentifier(storage.locationUri.get.toString, Some("default")),
+ tableType = org.apache.spark.sql.catalyst.catalog.CatalogTableType.VIEW,
+ storage = storage,
+ schema = query.schema
+ ))
+ hiveTable.getMetadata.put(serdeConstants.SERIALIZATION_LIB,
+ storage.serde.getOrElse(classOf[LazySimpleSerDe].getName))
+
+ val tableDesc = new TableDesc(
+ hiveTable.getInputFormatClass,
+ hiveTable.getOutputFormatClass,
+ hiveTable.getMetadata
+ )
+
+ val hadoopConf = sparkSession.sessionState.newHadoopConf()
+ val jobConf = new JobConf(hadoopConf)
+
+ val targetPath = new Path(storage.locationUri.get)
+ val writeToPath =
+ if (isLocal) {
+ val localFileSystem = FileSystem.getLocal(jobConf)
+ localFileSystem.makeQualified(targetPath)
+ } else {
+ val qualifiedPath = FileUtils.makeQualified(targetPath, hadoopConf)
+ val dfs = qualifiedPath.getFileSystem(jobConf)
+ if (!dfs.exists(qualifiedPath)) {
+ dfs.mkdirs(qualifiedPath.getParent)
+ }
+ qualifiedPath
+ }
+
+ val tmpPath = getExternalTmpPath(sparkSession, hadoopConf, writeToPath)
+ val fileSinkConf = new org.apache.spark.sql.hive.HiveShim.ShimFileSinkDesc(
+ tmpPath.toString, tableDesc, false)
+
+ try {
+ saveAsHiveFile(
+ sparkSession = sparkSession,
+ plan = children.head,
+ hadoopConf = hadoopConf,
+ fileSinkConf = fileSinkConf,
+ outputLocation = tmpPath.toString)
+
+ val fs = writeToPath.getFileSystem(hadoopConf)
+ if (overwrite && fs.exists(writeToPath)) {
+ fs.listStatus(writeToPath).foreach { existFile =>
+ if (Option(existFile.getPath) != createdTempDir) fs.delete(existFile.getPath, true)
+ }
+ }
+
+ fs.listStatus(tmpPath).foreach {
+ tmpFile => fs.rename(tmpFile.getPath, writeToPath)
+ }
+ } catch {
+ case e: Throwable =>
+ throw new SparkException(
+ "Failed inserting overwrite directory " + storage.locationUri.get, e)
+ } finally {
+ deleteExternalTmpPath(hadoopConf)
+ }
+
+ Seq.empty[Row]
+ }
+}
+
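A usage sketch for the new command, assuming a Hive-enabled local build; the view name and output directory are placeholders.

import org.apache.spark.sql.SparkSession

object InsertIntoDirExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("insert-overwrite-dir-sketch")
      .enableHiveSupport()
      .getOrCreate()

    spark.range(10).createOrReplaceTempView("src")

    // Writes the query result straight to a local directory through the Hive write path
    // that InsertIntoHiveDirCommand implements.
    spark.sql(
      """INSERT OVERWRITE LOCAL DIRECTORY '/tmp/insert_dir_example'
        |ROW FORMAT DELIMITED FIELDS TERMINATED BY ','
        |SELECT id, id * 2 FROM src
      """.stripMargin)

    spark.stop()
  }
}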
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 46610f84dd822..e5b59ed7a1a6b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -17,32 +17,22 @@
package org.apache.spark.sql.hive.execution
-import java.io.{File, IOException}
-import java.net.URI
-import java.text.SimpleDateFormat
-import java.util.{Date, Locale, Random}
-
import scala.util.control.NonFatal
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.hadoop.hive.common.FileUtils
+import org.apache.hadoop.fs.Path
import org.apache.hadoop.hive.ql.ErrorMsg
-import org.apache.hadoop.hive.ql.exec.TaskRunner
import org.apache.hadoop.hive.ql.plan.TableDesc
import org.apache.spark.SparkException
-import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.command.{CommandUtils, DataWritingCommand}
-import org.apache.spark.sql.execution.datasources.FileFormatWriter
+import org.apache.spark.sql.execution.command.CommandUtils
import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc}
-import org.apache.spark.sql.hive.client.{HiveClientImpl, HiveVersion}
+import org.apache.spark.sql.hive.client.HiveClientImpl
/**
@@ -80,152 +70,10 @@ case class InsertIntoHiveTable(
partition: Map[String, Option[String]],
query: LogicalPlan,
overwrite: Boolean,
- ifPartitionNotExists: Boolean) extends DataWritingCommand {
+ ifPartitionNotExists: Boolean) extends SaveAsHiveFile {
override def children: Seq[LogicalPlan] = query :: Nil
- var createdTempDir: Option[Path] = None
-
- private def executionId: String = {
- val rand: Random = new Random
- val format = new SimpleDateFormat("yyyy-MM-dd_HH-mm-ss_SSS", Locale.US)
- "hive_" + format.format(new Date) + "_" + Math.abs(rand.nextLong)
- }
-
- private def getStagingDir(
- inputPath: Path,
- hadoopConf: Configuration,
- stagingDir: String): Path = {
- val inputPathUri: URI = inputPath.toUri
- val inputPathName: String = inputPathUri.getPath
- val fs: FileSystem = inputPath.getFileSystem(hadoopConf)
- var stagingPathName: String =
- if (inputPathName.indexOf(stagingDir) == -1) {
- new Path(inputPathName, stagingDir).toString
- } else {
- inputPathName.substring(0, inputPathName.indexOf(stagingDir) + stagingDir.length)
- }
-
- // SPARK-20594: This is a walk-around fix to resolve a Hive bug. Hive requires that the
- // staging directory needs to avoid being deleted when users set hive.exec.stagingdir
- // under the table directory.
- if (FileUtils.isSubDir(new Path(stagingPathName), inputPath, fs) &&
- !stagingPathName.stripPrefix(inputPathName).stripPrefix(File.separator).startsWith(".")) {
- logDebug(s"The staging dir '$stagingPathName' should be a child directory starts " +
- "with '.' to avoid being deleted if we set hive.exec.stagingdir under the table " +
- "directory.")
- stagingPathName = new Path(inputPathName, ".hive-staging").toString
- }
-
- val dir: Path =
- fs.makeQualified(
- new Path(stagingPathName + "_" + executionId + "-" + TaskRunner.getTaskRunnerID))
- logDebug("Created staging dir = " + dir + " for path = " + inputPath)
- try {
- if (!FileUtils.mkdir(fs, dir, true, hadoopConf)) {
- throw new IllegalStateException("Cannot create staging directory '" + dir.toString + "'")
- }
- createdTempDir = Some(dir)
- fs.deleteOnExit(dir)
- } catch {
- case e: IOException =>
- throw new RuntimeException(
- "Cannot create staging directory '" + dir.toString + "': " + e.getMessage, e)
- }
- dir
- }
-
- private def getExternalScratchDir(
- extURI: URI,
- hadoopConf: Configuration,
- stagingDir: String): Path = {
- getStagingDir(
- new Path(extURI.getScheme, extURI.getAuthority, extURI.getPath),
- hadoopConf,
- stagingDir)
- }
-
- def getExternalTmpPath(
- path: Path,
- hiveVersion: HiveVersion,
- hadoopConf: Configuration,
- stagingDir: String,
- scratchDir: String): Path = {
- import org.apache.spark.sql.hive.client.hive._
-
- // Before Hive 1.1, when inserting into a table, Hive will create the staging directory under
- // a common scratch directory. After the writing is finished, Hive will simply empty the table
- // directory and move the staging directory to it.
- // After Hive 1.1, Hive will create the staging directory under the table directory, and when
- // moving staging directory to table directory, Hive will still empty the table directory, but
- // will exclude the staging directory there.
- // We have to follow the Hive behavior here, to avoid troubles. For example, if we create
- // staging directory under the table director for Hive prior to 1.1, the staging directory will
- // be removed by Hive when Hive is trying to empty the table directory.
- val hiveVersionsUsingOldExternalTempPath: Set[HiveVersion] = Set(v12, v13, v14, v1_0)
- val hiveVersionsUsingNewExternalTempPath: Set[HiveVersion] = Set(v1_1, v1_2, v2_0, v2_1)
-
- // Ensure all the supported versions are considered here.
- assert(hiveVersionsUsingNewExternalTempPath ++ hiveVersionsUsingOldExternalTempPath ==
- allSupportedHiveVersions)
-
- if (hiveVersionsUsingOldExternalTempPath.contains(hiveVersion)) {
- oldVersionExternalTempPath(path, hadoopConf, scratchDir)
- } else if (hiveVersionsUsingNewExternalTempPath.contains(hiveVersion)) {
- newVersionExternalTempPath(path, hadoopConf, stagingDir)
- } else {
- throw new IllegalStateException("Unsupported hive version: " + hiveVersion.fullVersion)
- }
- }
-
- // Mostly copied from Context.java#getExternalTmpPath of Hive 0.13
- def oldVersionExternalTempPath(
- path: Path,
- hadoopConf: Configuration,
- scratchDir: String): Path = {
- val extURI: URI = path.toUri
- val scratchPath = new Path(scratchDir, executionId)
- var dirPath = new Path(
- extURI.getScheme,
- extURI.getAuthority,
- scratchPath.toUri.getPath + "-" + TaskRunner.getTaskRunnerID())
-
- try {
- val fs: FileSystem = dirPath.getFileSystem(hadoopConf)
- dirPath = new Path(fs.makeQualified(dirPath).toString())
-
- if (!FileUtils.mkdir(fs, dirPath, true, hadoopConf)) {
- throw new IllegalStateException("Cannot create staging directory: " + dirPath.toString)
- }
- createdTempDir = Some(dirPath)
- fs.deleteOnExit(dirPath)
- } catch {
- case e: IOException =>
- throw new RuntimeException("Cannot create staging directory: " + dirPath.toString, e)
- }
- dirPath
- }
-
- // Mostly copied from Context.java#getExternalTmpPath of Hive 1.2
- def newVersionExternalTempPath(
- path: Path,
- hadoopConf: Configuration,
- stagingDir: String): Path = {
- val extURI: URI = path.toUri
- if (extURI.getScheme == "viewfs") {
- getExtTmpPathRelTo(path.getParent, hadoopConf, stagingDir)
- } else {
- new Path(getExternalScratchDir(extURI, hadoopConf, stagingDir), "-ext-10000")
- }
- }
-
- def getExtTmpPathRelTo(
- path: Path,
- hadoopConf: Configuration,
- stagingDir: String): Path = {
- new Path(getStagingDir(path, hadoopConf, stagingDir), "-ext-10000") // Hive uses 10000
- }
-
/**
* Inserts all the rows in the table into Hive. Row objects are properly serialized with the
* `org.apache.hadoop.hive.serde2.SerDe` and the
@@ -234,12 +82,8 @@ case class InsertIntoHiveTable(
override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = {
assert(children.length == 1)
- val sessionState = sparkSession.sessionState
val externalCatalog = sparkSession.sharedState.externalCatalog
- val hiveVersion = externalCatalog.asInstanceOf[HiveExternalCatalog].client.version
- val hadoopConf = sessionState.newHadoopConf()
- val stagingDir = hadoopConf.get("hive.exec.stagingdir", ".hive-staging")
- val scratchDir = hadoopConf.get("hive.exec.scratchdir", "/tmp/hive")
+ val hadoopConf = sparkSession.sessionState.newHadoopConf()
val hiveQlTable = HiveClientImpl.toHiveTable(table)
// Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer
@@ -254,23 +98,8 @@ case class InsertIntoHiveTable(
hiveQlTable.getMetadata
)
val tableLocation = hiveQlTable.getDataLocation
- val tmpLocation =
- getExternalTmpPath(tableLocation, hiveVersion, hadoopConf, stagingDir, scratchDir)
+ val tmpLocation = getExternalTmpPath(sparkSession, hadoopConf, tableLocation)
val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false)
- val isCompressed = hadoopConf.get("hive.exec.compress.output", "false").toBoolean
-
- if (isCompressed) {
- // Please note that isCompressed, "mapreduce.output.fileoutputformat.compress",
- // "mapreduce.output.fileoutputformat.compress.codec", and
- // "mapreduce.output.fileoutputformat.compress.type"
- // have no impact on ORC because it uses table properties to store compression information.
- hadoopConf.set("mapreduce.output.fileoutputformat.compress", "true")
- fileSinkConf.setCompressed(true)
- fileSinkConf.setCompressCodec(hadoopConf
- .get("mapreduce.output.fileoutputformat.compress.codec"))
- fileSinkConf.setCompressType(hadoopConf
- .get("mapreduce.output.fileoutputformat.compress.type"))
- }
val numDynamicPartitions = partition.values.count(_.isEmpty)
val numStaticPartitions = partition.values.count(_.nonEmpty)
@@ -332,11 +161,6 @@ case class InsertIntoHiveTable(
case _ => // do nothing since table has no bucketing
}
- val committer = FileCommitProtocol.instantiate(
- sparkSession.sessionState.conf.fileCommitProtocolClass,
- jobId = java.util.UUID.randomUUID().toString,
- outputPath = tmpLocation.toString)
-
val partitionAttributes = partitionColumnNames.takeRight(numDynamicPartitions).map { name =>
query.resolve(name :: Nil, sparkSession.sessionState.analyzer.resolver).getOrElse {
throw new AnalysisException(
@@ -344,17 +168,13 @@ case class InsertIntoHiveTable(
}.asInstanceOf[Attribute]
}
- FileFormatWriter.write(
+ saveAsHiveFile(
sparkSession = sparkSession,
plan = children.head,
- fileFormat = new HiveFileFormat(fileSinkConf),
- committer = committer,
- outputSpec = FileFormatWriter.OutputSpec(tmpLocation.toString, Map.empty),
hadoopConf = hadoopConf,
- partitionColumns = partitionAttributes,
- bucketSpec = None,
- statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)),
- options = Map.empty)
+ fileSinkConf = fileSinkConf,
+ outputLocation = tmpLocation.toString,
+ partitionAttributes = partitionAttributes)
if (partition.nonEmpty) {
if (numDynamicPartitions > 0) {
@@ -422,18 +242,7 @@ case class InsertIntoHiveTable(
// Attempt to delete the staging directory and the inclusive files. If failed, the files are
// expected to be dropped at the normal termination of VM since deleteOnExit is used.
- try {
- createdTempDir.foreach { path =>
- val fs = path.getFileSystem(hadoopConf)
- if (fs.delete(path, true)) {
- // If we successfully delete the staging directory, remove it from FileSystem's cache.
- fs.cancelDeleteOnExit(path)
- }
- }
- } catch {
- case NonFatal(e) =>
- logWarning(s"Unable to delete staging directory: $stagingDir.\n" + e)
- }
+ deleteExternalTmpPath(hadoopConf)
// un-cache this table.
sparkSession.catalog.uncacheTable(table.identifier.quotedString)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala
new file mode 100644
index 0000000000000..2d74ef040ef5a
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala
@@ -0,0 +1,250 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import java.io.{File, IOException}
+import java.net.URI
+import java.text.SimpleDateFormat
+import java.util.{Date, Locale, Random}
+
+import scala.util.control.NonFatal
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.hive.common.FileUtils
+import org.apache.hadoop.hive.ql.exec.TaskRunner
+
+import org.apache.spark.internal.io.FileCommitProtocol
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.command.DataWritingCommand
+import org.apache.spark.sql.execution.datasources.FileFormatWriter
+import org.apache.spark.sql.hive.HiveExternalCatalog
+import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc}
+import org.apache.spark.sql.hive.client.HiveVersion
+
+// Base trait that the physical execution of all Hive insert statements extends.
+private[hive] trait SaveAsHiveFile extends DataWritingCommand {
+
+ var createdTempDir: Option[Path] = None
+
+ protected def saveAsHiveFile(
+ sparkSession: SparkSession,
+ plan: SparkPlan,
+ hadoopConf: Configuration,
+ fileSinkConf: FileSinkDesc,
+ outputLocation: String,
+ customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty,
+ partitionAttributes: Seq[Attribute] = Nil): Set[String] = {
+
+ val isCompressed = hadoopConf.get("hive.exec.compress.output", "false").toBoolean
+ if (isCompressed) {
+ // Please note that isCompressed, "mapreduce.output.fileoutputformat.compress",
+ // "mapreduce.output.fileoutputformat.compress.codec", and
+ // "mapreduce.output.fileoutputformat.compress.type"
+ // have no impact on ORC because it uses table properties to store compression information.
+ hadoopConf.set("mapreduce.output.fileoutputformat.compress", "true")
+ fileSinkConf.setCompressed(true)
+ fileSinkConf.setCompressCodec(hadoopConf
+ .get("mapreduce.output.fileoutputformat.compress.codec"))
+ fileSinkConf.setCompressType(hadoopConf
+ .get("mapreduce.output.fileoutputformat.compress.type"))
+ }
+
+ val committer = FileCommitProtocol.instantiate(
+ sparkSession.sessionState.conf.fileCommitProtocolClass,
+ jobId = java.util.UUID.randomUUID().toString,
+ outputPath = outputLocation)
+
+ FileFormatWriter.write(
+ sparkSession = sparkSession,
+ plan = plan,
+ fileFormat = new HiveFileFormat(fileSinkConf),
+ committer = committer,
+ outputSpec = FileFormatWriter.OutputSpec(outputLocation, customPartitionLocations),
+ hadoopConf = hadoopConf,
+ partitionColumns = partitionAttributes,
+ bucketSpec = None,
+ statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)),
+ options = Map.empty)
+ }
+
+ protected def getExternalTmpPath(
+ sparkSession: SparkSession,
+ hadoopConf: Configuration,
+ path: Path): Path = {
+ import org.apache.spark.sql.hive.client.hive._
+
+ // Before Hive 1.1, when inserting into a table, Hive would create the staging directory under
+ // a common scratch directory. After the write finished, Hive would simply empty the table
+ // directory and move the staging directory to it.
+ // Since Hive 1.1, Hive creates the staging directory under the table directory, and when
+ // moving the staging directory to the table directory it still empties the table directory,
+ // but excludes the staging directory from that cleanup.
+ // We have to follow the Hive behavior here to avoid trouble. For example, if we created the
+ // staging directory under the table directory for Hive versions prior to 1.1, it would be
+ // removed when Hive empties the table directory.
+ val hiveVersionsUsingOldExternalTempPath: Set[HiveVersion] = Set(v12, v13, v14, v1_0)
+ val hiveVersionsUsingNewExternalTempPath: Set[HiveVersion] = Set(v1_1, v1_2, v2_0, v2_1)
+
+ // Ensure all the supported versions are considered here.
+ assert(hiveVersionsUsingNewExternalTempPath ++ hiveVersionsUsingOldExternalTempPath ==
+ allSupportedHiveVersions)
+
+ val externalCatalog = sparkSession.sharedState.externalCatalog
+ val hiveVersion = externalCatalog.asInstanceOf[HiveExternalCatalog].client.version
+ val stagingDir = hadoopConf.get("hive.exec.stagingdir", ".hive-staging")
+ val scratchDir = hadoopConf.get("hive.exec.scratchdir", "/tmp/hive")
+
+ if (hiveVersionsUsingOldExternalTempPath.contains(hiveVersion)) {
+ oldVersionExternalTempPath(path, hadoopConf, scratchDir)
+ } else if (hiveVersionsUsingNewExternalTempPath.contains(hiveVersion)) {
+ newVersionExternalTempPath(path, hadoopConf, stagingDir)
+ } else {
+ throw new IllegalStateException("Unsupported hive version: " + hiveVersion.fullVersion)
+ }
+ }
+
+ protected def deleteExternalTmpPath(hadoopConf: Configuration): Unit = {
+ // Attempt to delete the staging directory and the files it contains. If this fails, the files
+ // are expected to be dropped at normal VM termination since deleteOnExit is used.
+ try {
+ createdTempDir.foreach { path =>
+ val fs = path.getFileSystem(hadoopConf)
+ if (fs.delete(path, true)) {
+ // If we successfully delete the staging directory, remove it from FileSystem's cache.
+ fs.cancelDeleteOnExit(path)
+ }
+ }
+ } catch {
+ case NonFatal(e) =>
+ val stagingDir = hadoopConf.get("hive.exec.stagingdir", ".hive-staging")
+ logWarning(s"Unable to delete staging directory: $stagingDir.\n" + e)
+ }
+ }
+
+ // Mostly copied from Context.java#getExternalTmpPath of Hive 0.13
+ private def oldVersionExternalTempPath(
+ path: Path,
+ hadoopConf: Configuration,
+ scratchDir: String): Path = {
+ val extURI: URI = path.toUri
+ val scratchPath = new Path(scratchDir, executionId)
+ var dirPath = new Path(
+ extURI.getScheme,
+ extURI.getAuthority,
+ scratchPath.toUri.getPath + "-" + TaskRunner.getTaskRunnerID())
+
+ try {
+ val fs: FileSystem = dirPath.getFileSystem(hadoopConf)
+ dirPath = new Path(fs.makeQualified(dirPath).toString())
+
+ if (!FileUtils.mkdir(fs, dirPath, true, hadoopConf)) {
+ throw new IllegalStateException("Cannot create staging directory: " + dirPath.toString)
+ }
+ createdTempDir = Some(dirPath)
+ fs.deleteOnExit(dirPath)
+ } catch {
+ case e: IOException =>
+ throw new RuntimeException("Cannot create staging directory: " + dirPath.toString, e)
+ }
+ dirPath
+ }
+
+ // Mostly copied from Context.java#getExternalTmpPath of Hive 1.2
+ private def newVersionExternalTempPath(
+ path: Path,
+ hadoopConf: Configuration,
+ stagingDir: String): Path = {
+ val extURI: URI = path.toUri
+ if (extURI.getScheme == "viewfs") {
+ getExtTmpPathRelTo(path.getParent, hadoopConf, stagingDir)
+ } else {
+ new Path(getExternalScratchDir(extURI, hadoopConf, stagingDir), "-ext-10000")
+ }
+ }
+
+ private def getExtTmpPathRelTo(
+ path: Path,
+ hadoopConf: Configuration,
+ stagingDir: String): Path = {
+ new Path(getStagingDir(path, hadoopConf, stagingDir), "-ext-10000") // Hive uses 10000
+ }
+
+ private def getExternalScratchDir(
+ extURI: URI,
+ hadoopConf: Configuration,
+ stagingDir: String): Path = {
+ getStagingDir(
+ new Path(extURI.getScheme, extURI.getAuthority, extURI.getPath),
+ hadoopConf,
+ stagingDir)
+ }
+
+ private def getStagingDir(
+ inputPath: Path,
+ hadoopConf: Configuration,
+ stagingDir: String): Path = {
+ val inputPathUri: URI = inputPath.toUri
+ val inputPathName: String = inputPathUri.getPath
+ val fs: FileSystem = inputPath.getFileSystem(hadoopConf)
+ var stagingPathName: String =
+ if (inputPathName.indexOf(stagingDir) == -1) {
+ new Path(inputPathName, stagingDir).toString
+ } else {
+ inputPathName.substring(0, inputPathName.indexOf(stagingDir) + stagingDir.length)
+ }
+
+ // SPARK-20594: This is a workaround for a Hive bug: the staging directory must not be
+ // deleted when users set hive.exec.stagingdir under the table directory.
+ if (FileUtils.isSubDir(new Path(stagingPathName), inputPath, fs) &&
+ !stagingPathName.stripPrefix(inputPathName).stripPrefix(File.separator).startsWith(".")) {
+ logDebug(s"The staging dir '$stagingPathName' should be a child directory starts " +
+ "with '.' to avoid being deleted if we set hive.exec.stagingdir under the table " +
+ "directory.")
+ stagingPathName = new Path(inputPathName, ".hive-staging").toString
+ }
+
+ val dir: Path =
+ fs.makeQualified(
+ new Path(stagingPathName + "_" + executionId + "-" + TaskRunner.getTaskRunnerID))
+ logDebug("Created staging dir = " + dir + " for path = " + inputPath)
+ try {
+ if (!FileUtils.mkdir(fs, dir, true, hadoopConf)) {
+ throw new IllegalStateException("Cannot create staging directory '" + dir.toString + "'")
+ }
+ createdTempDir = Some(dir)
+ fs.deleteOnExit(dir)
+ } catch {
+ case e: IOException =>
+ throw new RuntimeException(
+ "Cannot create staging directory '" + dir.toString + "': " + e.getMessage, e)
+ }
+ dir
+ }
+
+ private def executionId: String = {
+ val rand: Random = new Random
+ val format = new SimpleDateFormat("yyyy-MM-dd_HH-mm-ss_SSS", Locale.US)
+ "hive_" + format.format(new Date) + "_" + Math.abs(rand.nextLong)
+ }
+}
+
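The prefix handling inside getStagingDir is the subtle part: if the target path already sits under a hive.exec.stagingdir, the staging location is truncated back to that directory instead of being nested again. A pure-Scala sketch of just that branch, with the executionId suffix and the file-system qualification left out.

object StagingPathSketch {
  def stagingPathNameFor(inputPathName: String, stagingDir: String): String = {
    if (!inputPathName.contains(stagingDir)) {
      // Normal case: put the staging dir directly under the target directory.
      s"$inputPathName/$stagingDir"
    } else {
      // The target already lives under a staging dir: truncate back to it.
      inputPathName.substring(0, inputPathName.indexOf(stagingDir) + stagingDir.length)
    }
  }

  def main(args: Array[String]): Unit = {
    println(stagingPathNameFor("/user/hive/warehouse/t", ".hive-staging"))
    // -> /user/hive/warehouse/t/.hive-staging
    println(stagingPathNameFor("/user/hive/warehouse/t/.hive-staging_hive_x/-ext-10000", ".hive-staging"))
    // -> /user/hive/warehouse/t/.hive-staging
  }
}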
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index a83ad61b204ad..e9bdcf00b9346 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -42,7 +42,11 @@ import org.apache.spark.sql.types._
private[hive] case class HiveSimpleUDF(
name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
- extends Expression with HiveInspectors with CodegenFallback with Logging {
+ extends Expression
+ with HiveInspectors
+ with CodegenFallback
+ with Logging
+ with UserDefinedExpression {
override def deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic)
@@ -119,7 +123,11 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataTyp
private[hive] case class HiveGenericUDF(
name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
- extends Expression with HiveInspectors with CodegenFallback with Logging {
+ extends Expression
+ with HiveInspectors
+ with CodegenFallback
+ with Logging
+ with UserDefinedExpression {
override def nullable: Boolean = true
@@ -191,7 +199,7 @@ private[hive] case class HiveGenericUDTF(
name: String,
funcWrapper: HiveFunctionWrapper,
children: Seq[Expression])
- extends Generator with HiveInspectors with CodegenFallback {
+ extends Generator with HiveInspectors with CodegenFallback with UserDefinedExpression {
@transient
protected lazy val function: GenericUDTF = {
@@ -303,7 +311,9 @@ private[hive] case class HiveUDAFFunction(
isUDAFBridgeRequired: Boolean = false,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
- extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer] with HiveInspectors {
+ extends TypedImperativeAggregate[GenericUDAFEvaluator.AggregationBuffer]
+ with HiveInspectors
+ with UserDefinedExpression {
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
index 3a34ec55c8b07..4d92a67044373 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
@@ -34,7 +34,7 @@ import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit}
import org.apache.spark.TaskContext
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.datasources._
@@ -68,7 +68,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
- val orcOptions = new OrcOptions(options)
+ val orcOptions = new OrcOptions(options, sparkSession.sessionState.conf)
val configuration = job.getConfiguration
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala
index 043eb69818ba1..7f94c8c579026 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala
@@ -20,30 +20,34 @@ package org.apache.spark.sql.hive.orc
import java.util.Locale
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.internal.SQLConf
/**
* Options for the ORC data source.
*/
-private[orc] class OrcOptions(@transient private val parameters: CaseInsensitiveMap[String])
+private[orc] class OrcOptions(
+ @transient private val parameters: CaseInsensitiveMap[String],
+ @transient private val sqlConf: SQLConf)
extends Serializable {
import OrcOptions._
- def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters))
+ def this(parameters: Map[String, String], sqlConf: SQLConf) =
+ this(CaseInsensitiveMap(parameters), sqlConf)
/**
- * Compression codec to use. By default snappy compression.
+ * Compression codec to use.
* Acceptable values are defined in [[shortOrcCompressionCodecNames]].
*/
val compressionCodec: String = {
- // `orc.compress` is a ORC configuration. So, here we respect this as an option but
- // `compression` has higher precedence than `orc.compress`. It means if both are set,
- // we will use `compression`.
+ // `compression`, `orc.compress`, and `spark.sql.orc.compression.codec` are consulted
+ // in that order of precedence, from highest to lowest.
val orcCompressionConf = parameters.get(OrcRelation.ORC_COMPRESSION)
val codecName = parameters
.get("compression")
.orElse(orcCompressionConf)
- .getOrElse("snappy").toLowerCase(Locale.ROOT)
+ .getOrElse(sqlConf.orcCompressionCodec)
+ .toLowerCase(Locale.ROOT)
if (!shortOrcCompressionCodecNames.contains(codecName)) {
val availableCodecs = shortOrcCompressionCodecNames.keys.map(_.toLowerCase(Locale.ROOT))
throw new IllegalArgumentException(s"Codec [$codecName] " +
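At the user level the codec resolution now looks like this; a minimal sketch assuming a Hive-enabled build (the built-in "orc" source lives in the hive module at this point) and placeholder output paths.

import org.apache.spark.sql.SparkSession

object OrcCompressionPrecedenceSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("orc-codec-sketch")
      .enableHiveSupport()
      .getOrCreate()

    // Session-level fallback, consulted only when no per-write option is given.
    spark.conf.set("spark.sql.orc.compression.codec", "zlib")

    val df = spark.range(100).toDF("id")

    // No per-write option: the session conf above decides the codec (zlib here).
    df.write.mode("overwrite").orc("/tmp/orc_session_codec")

    // `compression` beats both `orc.compress` and the session conf, so this writes snappy.
    df.write.mode("overwrite")
      .option("orc.compress", "zlib")
      .option("compression", "snappy")
      .orc("/tmp/orc_option_codec")

    spark.stop()
  }
}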
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 9e15baa4b2b74..0f6a81b6f813b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -51,11 +51,13 @@ object TestHive
"TestSQLContext",
new SparkConf()
.set("spark.sql.test", "")
+ .set(SQLConf.CODEGEN_FALLBACK.key, "false")
.set("spark.sql.hive.metastore.barrierPrefixes",
"org.apache.spark.sql.hive.execution.PairSerDe")
.set("spark.sql.warehouse.dir", TestHiveContext.makeWarehouseDir().toURI.getPath)
// SPARK-8910
- .set("spark.ui.enabled", "false")))
+ .set("spark.ui.enabled", "false")
+ .set("spark.unsafe.exceptionOnMemoryLeak", "true")))
case class TestHiveVersion(hiveClient: HiveClient)
diff --git a/sql/hive/src/test/resources/avroDecimal/decimal.avro b/sql/hive/src/test/resources/avroDecimal/decimal.avro
new file mode 100755
index 0000000000000..6da423f78661f
Binary files /dev/null and b/sql/hive/src/test/resources/avroDecimal/decimal.avro differ
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala
index 90f90599d5bf4..d9cf1f361c1d6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionSQLBuilderSuite.scala
@@ -19,12 +19,29 @@ package org.apache.spark.sql.catalyst
import java.sql.Timestamp
+import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions.{If, Literal, SpecifiedWindowFrame, TimeAdd,
- TimeSub, WindowSpecDefinition}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.unsafe.types.CalendarInterval
-class ExpressionSQLBuilderSuite extends SQLBuilderTest {
+class ExpressionSQLBuilderSuite extends QueryTest with TestHiveSingleton {
+ protected def checkSQL(e: Expression, expectedSQL: String): Unit = {
+ val actualSQL = e.sql
+ try {
+ assert(actualSQL == expectedSQL)
+ } catch {
+ case cause: Throwable =>
+ fail(
+ s"""Wrong SQL generated for the following expression:
+ |
+ |${e.prettyName}
+ |
+ |$cause
+ """.stripMargin)
+ }
+ }
+
test("literal") {
checkSQL(Literal("foo"), "'foo'")
checkSQL(Literal("\"foo\""), "'\"foo\"'")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala
deleted file mode 100644
index bee470d8e1382..0000000000000
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala
+++ /dev/null
@@ -1,739 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive
-
-import java.net.URI
-import java.util.Locale
-
-import org.apache.spark.sql.{AnalysisException, SaveMode}
-import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType}
-import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.dsl.plans
-import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
-import org.apache.spark.sql.catalyst.expressions.JsonTuple
-import org.apache.spark.sql.catalyst.parser.ParseException
-import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, ScriptTransformation}
-import org.apache.spark.sql.execution.command._
-import org.apache.spark.sql.execution.datasources.CreateTable
-import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton}
-import org.apache.spark.sql.test.SQLTestUtils
-import org.apache.spark.sql.types.StructType
-
-class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingleton {
- val parser = TestHive.sessionState.sqlParser
-
- private def extractTableDesc(sql: String): (CatalogTable, Boolean) = {
- parser.parsePlan(sql).collect {
- case CreateTable(tableDesc, mode, _) => (tableDesc, mode == SaveMode.Ignore)
- }.head
- }
-
- private def assertUnsupported(sql: String): Unit = {
- val e = intercept[ParseException] {
- parser.parsePlan(sql)
- }
- assert(e.getMessage.toLowerCase(Locale.ROOT).contains("operation not allowed"))
- }
-
- private def analyzeCreateTable(sql: String): CatalogTable = {
- TestHive.sessionState.analyzer.execute(parser.parsePlan(sql)).collect {
- case CreateTableCommand(tableDesc, _) => tableDesc
- }.head
- }
-
- private def compareTransformQuery(sql: String, expected: LogicalPlan): Unit = {
- val plan = parser.parsePlan(sql).asInstanceOf[ScriptTransformation].copy(ioschema = null)
- comparePlans(plan, expected, checkAnalysis = false)
- }
-
- test("Test CTAS #1") {
- val s1 =
- """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view
- |COMMENT 'This is the staging page view table'
- |STORED AS RCFILE
- |LOCATION '/user/external/page_view'
- |TBLPROPERTIES ('p1'='v1', 'p2'='v2')
- |AS SELECT * FROM src""".stripMargin
-
- val (desc, exists) = extractTableDesc(s1)
- assert(exists)
- assert(desc.identifier.database == Some("mydb"))
- assert(desc.identifier.table == "page_view")
- assert(desc.tableType == CatalogTableType.EXTERNAL)
- assert(desc.storage.locationUri == Some(new URI("/user/external/page_view")))
- assert(desc.schema.isEmpty) // will be populated later when the table is actually created
- assert(desc.comment == Some("This is the staging page view table"))
- // TODO will be SQLText
- assert(desc.viewText.isEmpty)
- assert(desc.viewDefaultDatabase.isEmpty)
- assert(desc.viewQueryColumnNames.isEmpty)
- assert(desc.partitionColumnNames.isEmpty)
- assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat"))
- assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat"))
- assert(desc.storage.serde ==
- Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe"))
- assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2"))
- }
-
- test("Test CTAS #2") {
- val s2 =
- """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view
- |COMMENT 'This is the staging page view table'
- |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe'
- | STORED AS
- | INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat'
- | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat'
- |LOCATION '/user/external/page_view'
- |TBLPROPERTIES ('p1'='v1', 'p2'='v2')
- |AS SELECT * FROM src""".stripMargin
-
- val (desc, exists) = extractTableDesc(s2)
- assert(exists)
- assert(desc.identifier.database == Some("mydb"))
- assert(desc.identifier.table == "page_view")
- assert(desc.tableType == CatalogTableType.EXTERNAL)
- assert(desc.storage.locationUri == Some(new URI("/user/external/page_view")))
- assert(desc.schema.isEmpty) // will be populated later when the table is actually created
- // TODO will be SQLText
- assert(desc.comment == Some("This is the staging page view table"))
- assert(desc.viewText.isEmpty)
- assert(desc.viewDefaultDatabase.isEmpty)
- assert(desc.viewQueryColumnNames.isEmpty)
- assert(desc.partitionColumnNames.isEmpty)
- assert(desc.storage.properties == Map())
- assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat"))
- assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat"))
- assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe"))
- assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2"))
- }
-
- test("Test CTAS #3") {
- val s3 = """CREATE TABLE page_view AS SELECT * FROM src"""
- val (desc, exists) = extractTableDesc(s3)
- assert(exists == false)
- assert(desc.identifier.database == None)
- assert(desc.identifier.table == "page_view")
- assert(desc.tableType == CatalogTableType.MANAGED)
- assert(desc.storage.locationUri == None)
- assert(desc.schema.isEmpty)
- assert(desc.viewText == None) // TODO will be SQLText
- assert(desc.viewDefaultDatabase.isEmpty)
- assert(desc.viewQueryColumnNames.isEmpty)
- assert(desc.storage.properties == Map())
- assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat"))
- assert(desc.storage.outputFormat ==
- Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat"))
- assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"))
- assert(desc.properties == Map())
- }
-
- test("Test CTAS #4") {
- val s4 =
- """CREATE TABLE page_view
- |STORED BY 'storage.handler.class.name' AS SELECT * FROM src""".stripMargin
- intercept[AnalysisException] {
- extractTableDesc(s4)
- }
- }
-
- test("Test CTAS #5") {
- val s5 = """CREATE TABLE ctas2
- | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe"
- | WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2")
- | STORED AS RCFile
- | TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22")
- | AS
- | SELECT key, value
- | FROM src
- | ORDER BY key, value""".stripMargin
- val (desc, exists) = extractTableDesc(s5)
- assert(exists == false)
- assert(desc.identifier.database == None)
- assert(desc.identifier.table == "ctas2")
- assert(desc.tableType == CatalogTableType.MANAGED)
- assert(desc.storage.locationUri == None)
- assert(desc.schema.isEmpty)
- assert(desc.viewText == None) // TODO will be SQLText
- assert(desc.viewDefaultDatabase.isEmpty)
- assert(desc.viewQueryColumnNames.isEmpty)
- assert(desc.storage.properties == Map(("serde_p1" -> "p1"), ("serde_p2" -> "p2")))
- assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat"))
- assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat"))
- assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe"))
- assert(desc.properties == Map(("tbl_p1" -> "p11"), ("tbl_p2" -> "p22")))
- }
-
- test("CTAS statement with a PARTITIONED BY clause is not allowed") {
- assertUnsupported(s"CREATE TABLE ctas1 PARTITIONED BY (k int)" +
- " AS SELECT key, value FROM (SELECT 1 as key, 2 as value) tmp")
- }
-
- test("CTAS statement with schema") {
- assertUnsupported(s"CREATE TABLE ctas1 (age INT, name STRING) AS SELECT * FROM src")
- assertUnsupported(s"CREATE TABLE ctas1 (age INT, name STRING) AS SELECT 1, 'hello'")
- }
-
- test("unsupported operations") {
- intercept[ParseException] {
- parser.parsePlan(
- """
- |CREATE TEMPORARY TABLE ctas2
- |ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe"
- |WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2")
- |STORED AS RCFile
- |TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22")
- |AS SELECT key, value FROM src ORDER BY key, value
- """.stripMargin)
- }
- intercept[ParseException] {
- parser.parsePlan(
- """
- |CREATE TABLE user_info_bucketed(user_id BIGINT, firstname STRING, lastname STRING)
- |CLUSTERED BY(user_id) INTO 256 BUCKETS
- |AS SELECT key, value FROM src ORDER BY key, value
- """.stripMargin)
- }
- intercept[ParseException] {
- parser.parsePlan(
- """
- |CREATE TABLE user_info_bucketed(user_id BIGINT, firstname STRING, lastname STRING)
- |SKEWED BY (key) ON (1,5,6)
- |AS SELECT key, value FROM src ORDER BY key, value
- """.stripMargin)
- }
- intercept[ParseException] {
- parser.parsePlan(
- """
- |SELECT TRANSFORM (key, value) USING 'cat' AS (tKey, tValue)
- |ROW FORMAT SERDE 'org.apache.hadoop.hive.contrib.serde2.TypedBytesSerDe'
- |RECORDREADER 'org.apache.hadoop.hive.contrib.util.typedbytes.TypedBytesRecordReader'
- |FROM testData
- """.stripMargin)
- }
- }
-
- test("Invalid interval term should throw AnalysisException") {
- def assertError(sql: String, errorMessage: String): Unit = {
- val e = intercept[AnalysisException] {
- parser.parsePlan(sql)
- }
- assert(e.getMessage.contains(errorMessage))
- }
- assertError("select interval '42-32' year to month",
- "month 32 outside range [0, 11]")
- assertError("select interval '5 49:12:15' day to second",
- "hour 49 outside range [0, 23]")
- assertError("select interval '.1111111111' second",
- "nanosecond 1111111111 outside range")
- }
-
- test("use native json_tuple instead of hive's UDTF in LATERAL VIEW") {
- val analyzer = TestHive.sparkSession.sessionState.analyzer
- val plan = analyzer.execute(parser.parsePlan(
- """
- |SELECT *
- |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test
- |LATERAL VIEW json_tuple(json, 'f1', 'f2') jt AS a, b
- """.stripMargin))
-
- assert(plan.children.head.asInstanceOf[Generate].generator.isInstanceOf[JsonTuple])
- }
-
- test("transform query spec") {
- val p = ScriptTransformation(
- Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")),
- "func", Seq.empty, plans.table("e"), null)
-
- compareTransformQuery("select transform(a, b) using 'func' from e where f < 10",
- p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string)))
- compareTransformQuery("map a, b using 'func' as c, d from e",
- p.copy(output = Seq('c.string, 'd.string)))
- compareTransformQuery("reduce a, b using 'func' as (c int, d decimal(10, 0)) from e",
- p.copy(output = Seq('c.int, 'd.decimal(10, 0))))
- }
-
- test("use backticks in output of Script Transform") {
- parser.parsePlan(
- """SELECT `t`.`thing1`
- |FROM (SELECT TRANSFORM (`parquet_t1`.`key`, `parquet_t1`.`value`)
- |USING 'cat' AS (`thing1` int, `thing2` string) FROM `default`.`parquet_t1`) AS t
- """.stripMargin)
- }
-
- test("use backticks in output of Generator") {
- parser.parsePlan(
- """
- |SELECT `gentab2`.`gencol2`
- |FROM `default`.`src`
- |LATERAL VIEW explode(array(array(1, 2, 3))) `gentab1` AS `gencol1`
- |LATERAL VIEW explode(`gentab1`.`gencol1`) `gentab2` AS `gencol2`
- """.stripMargin)
- }
-
- test("use escaped backticks in output of Generator") {
- parser.parsePlan(
- """
- |SELECT `gen``tab2`.`gen``col2`
- |FROM `default`.`src`
- |LATERAL VIEW explode(array(array(1, 2, 3))) `gen``tab1` AS `gen``col1`
- |LATERAL VIEW explode(`gen``tab1`.`gen``col1`) `gen``tab2` AS `gen``col2`
- """.stripMargin)
- }
-
- test("create table - basic") {
- val query = "CREATE TABLE my_table (id int, name string)"
- val (desc, allowExisting) = extractTableDesc(query)
- assert(!allowExisting)
- assert(desc.identifier.database.isEmpty)
- assert(desc.identifier.table == "my_table")
- assert(desc.tableType == CatalogTableType.MANAGED)
- assert(desc.schema == new StructType().add("id", "int").add("name", "string"))
- assert(desc.partitionColumnNames.isEmpty)
- assert(desc.bucketSpec.isEmpty)
- assert(desc.viewText.isEmpty)
- assert(desc.viewDefaultDatabase.isEmpty)
- assert(desc.viewQueryColumnNames.isEmpty)
- assert(desc.storage.locationUri.isEmpty)
- assert(desc.storage.inputFormat ==
- Some("org.apache.hadoop.mapred.TextInputFormat"))
- assert(desc.storage.outputFormat ==
- Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat"))
- assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"))
- assert(desc.storage.properties.isEmpty)
- assert(desc.properties.isEmpty)
- assert(desc.comment.isEmpty)
- }
-
- test("create table - with database name") {
- val query = "CREATE TABLE dbx.my_table (id int, name string)"
- val (desc, _) = extractTableDesc(query)
- assert(desc.identifier.database == Some("dbx"))
- assert(desc.identifier.table == "my_table")
- }
-
- test("create table - temporary") {
- val query = "CREATE TEMPORARY TABLE tab1 (id int, name string)"
- val e = intercept[ParseException] { parser.parsePlan(query) }
- assert(e.message.contains("CREATE TEMPORARY TABLE is not supported yet"))
- }
-
- test("create table - external") {
- val query = "CREATE EXTERNAL TABLE tab1 (id int, name string) LOCATION '/path/to/nowhere'"
- val (desc, _) = extractTableDesc(query)
- assert(desc.tableType == CatalogTableType.EXTERNAL)
- assert(desc.storage.locationUri == Some(new URI("/path/to/nowhere")))
- }
-
- test("create table - if not exists") {
- val query = "CREATE TABLE IF NOT EXISTS tab1 (id int, name string)"
- val (_, allowExisting) = extractTableDesc(query)
- assert(allowExisting)
- }
-
- test("create table - comment") {
- val query = "CREATE TABLE my_table (id int, name string) COMMENT 'its hot as hell below'"
- val (desc, _) = extractTableDesc(query)
- assert(desc.comment == Some("its hot as hell below"))
- }
-
- test("create table - partitioned columns") {
- val query = "CREATE TABLE my_table (id int, name string) PARTITIONED BY (month int)"
- val (desc, _) = extractTableDesc(query)
- assert(desc.schema == new StructType()
- .add("id", "int")
- .add("name", "string")
- .add("month", "int"))
- assert(desc.partitionColumnNames == Seq("month"))
- }
-
- test("create table - clustered by") {
- val numBuckets = 10
- val bucketedColumn = "id"
- val sortColumn = "id"
- val baseQuery =
- s"""
- CREATE TABLE my_table (
- $bucketedColumn int,
- name string)
- CLUSTERED BY($bucketedColumn)
- """
-
- val query1 = s"$baseQuery INTO $numBuckets BUCKETS"
- val (desc1, _) = extractTableDesc(query1)
- assert(desc1.bucketSpec.isDefined)
- val bucketSpec1 = desc1.bucketSpec.get
- assert(bucketSpec1.numBuckets == numBuckets)
- assert(bucketSpec1.bucketColumnNames.head.equals(bucketedColumn))
- assert(bucketSpec1.sortColumnNames.isEmpty)
-
- val query2 = s"$baseQuery SORTED BY($sortColumn) INTO $numBuckets BUCKETS"
- val (desc2, _) = extractTableDesc(query2)
- assert(desc2.bucketSpec.isDefined)
- val bucketSpec2 = desc2.bucketSpec.get
- assert(bucketSpec2.numBuckets == numBuckets)
- assert(bucketSpec2.bucketColumnNames.head.equals(bucketedColumn))
- assert(bucketSpec2.sortColumnNames.head.equals(sortColumn))
- }
-
- test("create table - skewed by") {
- val baseQuery = "CREATE TABLE my_table (id int, name string) SKEWED BY"
- val query1 = s"$baseQuery(id) ON (1, 10, 100)"
- val query2 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z'))"
- val query3 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z')) STORED AS DIRECTORIES"
- val e1 = intercept[ParseException] { parser.parsePlan(query1) }
- val e2 = intercept[ParseException] { parser.parsePlan(query2) }
- val e3 = intercept[ParseException] { parser.parsePlan(query3) }
- assert(e1.getMessage.contains("Operation not allowed"))
- assert(e2.getMessage.contains("Operation not allowed"))
- assert(e3.getMessage.contains("Operation not allowed"))
- }
-
- test("create table - row format") {
- val baseQuery = "CREATE TABLE my_table (id int, name string) ROW FORMAT"
- val query1 = s"$baseQuery SERDE 'org.apache.poof.serde.Baff'"
- val query2 = s"$baseQuery SERDE 'org.apache.poof.serde.Baff' WITH SERDEPROPERTIES ('k1'='v1')"
- val query3 =
- s"""
- |$baseQuery DELIMITED FIELDS TERMINATED BY 'x' ESCAPED BY 'y'
- |COLLECTION ITEMS TERMINATED BY 'a'
- |MAP KEYS TERMINATED BY 'b'
- |LINES TERMINATED BY '\n'
- |NULL DEFINED AS 'c'
- """.stripMargin
- val (desc1, _) = extractTableDesc(query1)
- val (desc2, _) = extractTableDesc(query2)
- val (desc3, _) = extractTableDesc(query3)
- assert(desc1.storage.serde == Some("org.apache.poof.serde.Baff"))
- assert(desc1.storage.properties.isEmpty)
- assert(desc2.storage.serde == Some("org.apache.poof.serde.Baff"))
- assert(desc2.storage.properties == Map("k1" -> "v1"))
- assert(desc3.storage.properties == Map(
- "field.delim" -> "x",
- "escape.delim" -> "y",
- "serialization.format" -> "x",
- "line.delim" -> "\n",
- "colelction.delim" -> "a", // yes, it's a typo from Hive :)
- "mapkey.delim" -> "b"))
- }
-
- test("create table - file format") {
- val baseQuery = "CREATE TABLE my_table (id int, name string) STORED AS"
- val query1 = s"$baseQuery INPUTFORMAT 'winput' OUTPUTFORMAT 'wowput'"
- val query2 = s"$baseQuery ORC"
- val (desc1, _) = extractTableDesc(query1)
- val (desc2, _) = extractTableDesc(query2)
- assert(desc1.storage.inputFormat == Some("winput"))
- assert(desc1.storage.outputFormat == Some("wowput"))
- assert(desc1.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"))
- assert(desc2.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"))
- assert(desc2.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat"))
- assert(desc2.storage.serde == Some("org.apache.hadoop.hive.ql.io.orc.OrcSerde"))
- }
-
- test("create table - storage handler") {
- val baseQuery = "CREATE TABLE my_table (id int, name string) STORED BY"
- val query1 = s"$baseQuery 'org.papachi.StorageHandler'"
- val query2 = s"$baseQuery 'org.mamachi.StorageHandler' WITH SERDEPROPERTIES ('k1'='v1')"
- val e1 = intercept[ParseException] { parser.parsePlan(query1) }
- val e2 = intercept[ParseException] { parser.parsePlan(query2) }
- assert(e1.getMessage.contains("Operation not allowed"))
- assert(e2.getMessage.contains("Operation not allowed"))
- }
-
- test("create table - properties") {
- val query = "CREATE TABLE my_table (id int, name string) TBLPROPERTIES ('k1'='v1', 'k2'='v2')"
- val (desc, _) = extractTableDesc(query)
- assert(desc.properties == Map("k1" -> "v1", "k2" -> "v2"))
- }
-
- test("create table - everything!") {
- val query =
- """
- |CREATE EXTERNAL TABLE IF NOT EXISTS dbx.my_table (id int, name string)
- |COMMENT 'no comment'
- |PARTITIONED BY (month int)
- |ROW FORMAT SERDE 'org.apache.poof.serde.Baff' WITH SERDEPROPERTIES ('k1'='v1')
- |STORED AS INPUTFORMAT 'winput' OUTPUTFORMAT 'wowput'
- |LOCATION '/path/to/mercury'
- |TBLPROPERTIES ('k1'='v1', 'k2'='v2')
- """.stripMargin
- val (desc, allowExisting) = extractTableDesc(query)
- assert(allowExisting)
- assert(desc.identifier.database == Some("dbx"))
- assert(desc.identifier.table == "my_table")
- assert(desc.tableType == CatalogTableType.EXTERNAL)
- assert(desc.schema == new StructType()
- .add("id", "int")
- .add("name", "string")
- .add("month", "int"))
- assert(desc.partitionColumnNames == Seq("month"))
- assert(desc.bucketSpec.isEmpty)
- assert(desc.viewText.isEmpty)
- assert(desc.viewDefaultDatabase.isEmpty)
- assert(desc.viewQueryColumnNames.isEmpty)
- assert(desc.storage.locationUri == Some(new URI("/path/to/mercury")))
- assert(desc.storage.inputFormat == Some("winput"))
- assert(desc.storage.outputFormat == Some("wowput"))
- assert(desc.storage.serde == Some("org.apache.poof.serde.Baff"))
- assert(desc.storage.properties == Map("k1" -> "v1"))
- assert(desc.properties == Map("k1" -> "v1", "k2" -> "v2"))
- assert(desc.comment == Some("no comment"))
- }
-
- test("create view -- basic") {
- val v1 = "CREATE VIEW view1 AS SELECT * FROM tab1"
- val command = parser.parsePlan(v1).asInstanceOf[CreateViewCommand]
- assert(!command.allowExisting)
- assert(command.name.database.isEmpty)
- assert(command.name.table == "view1")
- assert(command.originalText == Some("SELECT * FROM tab1"))
- assert(command.userSpecifiedColumns.isEmpty)
- }
-
- test("create view - full") {
- val v1 =
- """
- |CREATE OR REPLACE VIEW view1
- |(col1, col3 COMMENT 'hello')
- |COMMENT 'BLABLA'
- |TBLPROPERTIES('prop1Key'="prop1Val")
- |AS SELECT * FROM tab1
- """.stripMargin
- val command = parser.parsePlan(v1).asInstanceOf[CreateViewCommand]
- assert(command.name.database.isEmpty)
- assert(command.name.table == "view1")
- assert(command.userSpecifiedColumns == Seq("col1" -> None, "col3" -> Some("hello")))
- assert(command.originalText == Some("SELECT * FROM tab1"))
- assert(command.properties == Map("prop1Key" -> "prop1Val"))
- assert(command.comment == Some("BLABLA"))
- }
-
- test("create view -- partitioned view") {
- val v1 = "CREATE VIEW view1 partitioned on (ds, hr) as select * from srcpart"
- intercept[ParseException] {
- parser.parsePlan(v1)
- }
- }
-
- test("MSCK REPAIR table") {
- val sql = "MSCK REPAIR TABLE tab1"
- val parsed = parser.parsePlan(sql)
- val expected = AlterTableRecoverPartitionsCommand(
- TableIdentifier("tab1", None),
- "MSCK REPAIR TABLE")
- comparePlans(parsed, expected)
- }
-
- test("create table like") {
- val v1 = "CREATE TABLE table1 LIKE table2"
- val (target, source, location, exists) = parser.parsePlan(v1).collect {
- case CreateTableLikeCommand(t, s, l, allowExisting) => (t, s, l, allowExisting)
- }.head
- assert(exists == false)
- assert(target.database.isEmpty)
- assert(target.table == "table1")
- assert(source.database.isEmpty)
- assert(source.table == "table2")
- assert(location.isEmpty)
-
- val v2 = "CREATE TABLE IF NOT EXISTS table1 LIKE table2"
- val (target2, source2, location2, exists2) = parser.parsePlan(v2).collect {
- case CreateTableLikeCommand(t, s, l, allowExisting) => (t, s, l, allowExisting)
- }.head
- assert(exists2)
- assert(target2.database.isEmpty)
- assert(target2.table == "table1")
- assert(source2.database.isEmpty)
- assert(source2.table == "table2")
- assert(location2.isEmpty)
-
- val v3 = "CREATE TABLE table1 LIKE table2 LOCATION '/spark/warehouse'"
- val (target3, source3, location3, exists3) = parser.parsePlan(v3).collect {
- case CreateTableLikeCommand(t, s, l, allowExisting) => (t, s, l, allowExisting)
- }.head
- assert(!exists3)
- assert(target3.database.isEmpty)
- assert(target3.table == "table1")
- assert(source3.database.isEmpty)
- assert(source3.table == "table2")
- assert(location3 == Some("/spark/warehouse"))
-
- val v4 = "CREATE TABLE IF NOT EXISTS table1 LIKE table2 LOCATION '/spark/warehouse'"
- val (target4, source4, location4, exists4) = parser.parsePlan(v4).collect {
- case CreateTableLikeCommand(t, s, l, allowExisting) => (t, s, l, allowExisting)
- }.head
- assert(exists4)
- assert(target4.database.isEmpty)
- assert(target4.table == "table1")
- assert(source4.database.isEmpty)
- assert(source4.table == "table2")
- assert(location4 == Some("/spark/warehouse"))
- }
-
- test("load data") {
- val v1 = "LOAD DATA INPATH 'path' INTO TABLE table1"
- val (table, path, isLocal, isOverwrite, partition) = parser.parsePlan(v1).collect {
- case LoadDataCommand(t, path, l, o, partition) => (t, path, l, o, partition)
- }.head
- assert(table.database.isEmpty)
- assert(table.table == "table1")
- assert(path == "path")
- assert(!isLocal)
- assert(!isOverwrite)
- assert(partition.isEmpty)
-
- val v2 = "LOAD DATA LOCAL INPATH 'path' OVERWRITE INTO TABLE table1 PARTITION(c='1', d='2')"
- val (table2, path2, isLocal2, isOverwrite2, partition2) = parser.parsePlan(v2).collect {
- case LoadDataCommand(t, path, l, o, partition) => (t, path, l, o, partition)
- }.head
- assert(table2.database.isEmpty)
- assert(table2.table == "table1")
- assert(path2 == "path")
- assert(isLocal2)
- assert(isOverwrite2)
- assert(partition2.nonEmpty)
- assert(partition2.get.apply("c") == "1" && partition2.get.apply("d") == "2")
- }
-
- test("Test the default fileformat for Hive-serde tables") {
- withSQLConf("hive.default.fileformat" -> "orc") {
- val (desc, exists) = extractTableDesc("CREATE TABLE IF NOT EXISTS fileformat_test (id int)")
- assert(exists)
- assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"))
- assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat"))
- assert(desc.storage.serde == Some("org.apache.hadoop.hive.ql.io.orc.OrcSerde"))
- }
-
- withSQLConf("hive.default.fileformat" -> "parquet") {
- val (desc, exists) = extractTableDesc("CREATE TABLE IF NOT EXISTS fileformat_test (id int)")
- assert(exists)
- val input = desc.storage.inputFormat
- val output = desc.storage.outputFormat
- val serde = desc.storage.serde
- assert(input == Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"))
- assert(output == Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat"))
- assert(serde == Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"))
- }
- }
-
- test("table name with schema") {
- // regression test for SPARK-11778
- spark.sql("create schema usrdb")
- spark.sql("create table usrdb.test(c int)")
- spark.read.table("usrdb.test")
- spark.sql("drop table usrdb.test")
- spark.sql("drop schema usrdb")
- }
-
- test("SPARK-15887: hive-site.xml should be loaded") {
- assert(hiveClient.getConf("hive.in.test", "") == "true")
- }
-
- test("create hive serde table with new syntax - basic") {
- val sql =
- """
- |CREATE TABLE t
- |(id int, name string COMMENT 'blabla')
- |USING hive
- |OPTIONS (fileFormat 'parquet', my_prop 1)
- |LOCATION '/tmp/file'
- |COMMENT 'BLABLA'
- """.stripMargin
-
- val table = analyzeCreateTable(sql)
- assert(table.schema == new StructType()
- .add("id", "int")
- .add("name", "string", nullable = true, comment = "blabla"))
- assert(table.provider == Some(DDLUtils.HIVE_PROVIDER))
- assert(table.storage.locationUri == Some(new URI("/tmp/file")))
- assert(table.storage.properties == Map("my_prop" -> "1"))
- assert(table.comment == Some("BLABLA"))
-
- assert(table.storage.inputFormat ==
- Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"))
- assert(table.storage.outputFormat ==
- Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat"))
- assert(table.storage.serde ==
- Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"))
- }
-
- test("create hive serde table with new syntax - with partition and bucketing") {
- val v1 = "CREATE TABLE t (c1 int, c2 int) USING hive PARTITIONED BY (c2)"
- val table = analyzeCreateTable(v1)
- assert(table.schema == new StructType().add("c1", "int").add("c2", "int"))
- assert(table.partitionColumnNames == Seq("c2"))
- // check the default formats
- assert(table.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"))
- assert(table.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat"))
- assert(table.storage.outputFormat ==
- Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat"))
-
- val v2 = "CREATE TABLE t (c1 int, c2 int) USING hive CLUSTERED BY (c2) INTO 4 BUCKETS"
- val e2 = intercept[AnalysisException](analyzeCreateTable(v2))
- assert(e2.message.contains("Creating bucketed Hive serde table is not supported yet"))
-
- val v3 =
- """
- |CREATE TABLE t (c1 int, c2 int) USING hive
- |PARTITIONED BY (c2)
- |CLUSTERED BY (c2) INTO 4 BUCKETS""".stripMargin
- val e3 = intercept[AnalysisException](analyzeCreateTable(v3))
- assert(e3.message.contains("Creating bucketed Hive serde table is not supported yet"))
- }
-
- test("create hive serde table with new syntax - Hive options error checking") {
- val v1 = "CREATE TABLE t (c1 int) USING hive OPTIONS (inputFormat 'abc')"
- val e1 = intercept[IllegalArgumentException](analyzeCreateTable(v1))
- assert(e1.getMessage.contains("Cannot specify only inputFormat or outputFormat"))
-
- val v2 = "CREATE TABLE t (c1 int) USING hive OPTIONS " +
- "(fileFormat 'x', inputFormat 'a', outputFormat 'b')"
- val e2 = intercept[IllegalArgumentException](analyzeCreateTable(v2))
- assert(e2.getMessage.contains(
- "Cannot specify fileFormat and inputFormat/outputFormat together"))
-
- val v3 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'parquet', serde 'a')"
- val e3 = intercept[IllegalArgumentException](analyzeCreateTable(v3))
- assert(e3.getMessage.contains("fileFormat 'parquet' already specifies a serde"))
-
- val v4 = "CREATE TABLE t (c1 int) USING hive OPTIONS (serde 'a', fieldDelim ' ')"
- val e4 = intercept[IllegalArgumentException](analyzeCreateTable(v4))
- assert(e4.getMessage.contains("Cannot specify delimiters with a custom serde"))
-
- val v5 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fieldDelim ' ')"
- val e5 = intercept[IllegalArgumentException](analyzeCreateTable(v5))
- assert(e5.getMessage.contains("Cannot specify delimiters without fileFormat"))
-
- val v6 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'parquet', fieldDelim ' ')"
- val e6 = intercept[IllegalArgumentException](analyzeCreateTable(v6))
- assert(e6.getMessage.contains(
- "Cannot specify delimiters as they are only compatible with fileFormat 'textfile'"))
-
- // The value of 'fileFormat' option is case-insensitive.
- val v7 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'TEXTFILE', lineDelim ',')"
- val e7 = intercept[IllegalArgumentException](analyzeCreateTable(v7))
- assert(e7.getMessage.contains("Hive data source only support newline '\\n' as line delimiter"))
-
- val v8 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'wrong')"
- val e8 = intercept[IllegalArgumentException](analyzeCreateTable(v8))
- assert(e8.getMessage.contains("invalid fileFormat: 'wrong'"))
- }
-}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala
deleted file mode 100644
index 3bd3d0d6db355..0000000000000
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala
+++ /dev/null
@@ -1,260 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive
-
-import java.net.URI
-
-import org.apache.hadoop.fs.Path
-import org.scalatest.BeforeAndAfterEach
-
-import org.apache.spark.sql.QueryTest
-import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
-import org.apache.spark.sql.hive.client.HiveClient
-import org.apache.spark.sql.hive.test.TestHiveSingleton
-import org.apache.spark.sql.test.SQLTestUtils
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.Utils
-
-
-class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest
- with SQLTestUtils with TestHiveSingleton with BeforeAndAfterEach {
-
- val tempDir = Utils.createTempDir().getCanonicalFile
- val tempDirUri = tempDir.toURI
- val tempDirStr = tempDir.getAbsolutePath
-
- override def beforeEach(): Unit = {
- sql("CREATE DATABASE test_db")
- for ((tbl, _) <- rawTablesAndExpectations) {
- hiveClient.createTable(tbl, ignoreIfExists = false)
- }
- }
-
- override def afterEach(): Unit = {
- Utils.deleteRecursively(tempDir)
- hiveClient.dropDatabase("test_db", ignoreIfNotExists = false, cascade = true)
- }
-
- private def getTableMetadata(tableName: String): CatalogTable = {
- spark.sharedState.externalCatalog.getTable("test_db", tableName)
- }
-
- private def defaultTableURI(tableName: String): URI = {
- spark.sessionState.catalog.defaultTablePath(TableIdentifier(tableName, Some("test_db")))
- }
-
- // Raw table metadata that are dumped from tables created by Spark 2.0. Note that, all spark
- // versions prior to 2.1 would generate almost same raw table metadata for a specific table.
- val simpleSchema = new StructType().add("i", "int")
- val partitionedSchema = new StructType().add("i", "int").add("j", "int")
-
- lazy val hiveTable = CatalogTable(
- identifier = TableIdentifier("tbl1", Some("test_db")),
- tableType = CatalogTableType.MANAGED,
- storage = CatalogStorageFormat.empty.copy(
- inputFormat = Some("org.apache.hadoop.mapred.TextInputFormat"),
- outputFormat = Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")),
- schema = simpleSchema)
-
- lazy val externalHiveTable = CatalogTable(
- identifier = TableIdentifier("tbl2", Some("test_db")),
- tableType = CatalogTableType.EXTERNAL,
- storage = CatalogStorageFormat.empty.copy(
- locationUri = Some(tempDirUri),
- inputFormat = Some("org.apache.hadoop.mapred.TextInputFormat"),
- outputFormat = Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")),
- schema = simpleSchema)
-
- lazy val partitionedHiveTable = CatalogTable(
- identifier = TableIdentifier("tbl3", Some("test_db")),
- tableType = CatalogTableType.MANAGED,
- storage = CatalogStorageFormat.empty.copy(
- inputFormat = Some("org.apache.hadoop.mapred.TextInputFormat"),
- outputFormat = Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")),
- schema = partitionedSchema,
- partitionColumnNames = Seq("j"))
-
-
- val simpleSchemaJson =
- """
- |{
- | "type": "struct",
- | "fields": [{
- | "name": "i",
- | "type": "integer",
- | "nullable": true,
- | "metadata": {}
- | }]
- |}
- """.stripMargin
-
- val partitionedSchemaJson =
- """
- |{
- | "type": "struct",
- | "fields": [{
- | "name": "i",
- | "type": "integer",
- | "nullable": true,
- | "metadata": {}
- | },
- | {
- | "name": "j",
- | "type": "integer",
- | "nullable": true,
- | "metadata": {}
- | }]
- |}
- """.stripMargin
-
- lazy val dataSourceTable = CatalogTable(
- identifier = TableIdentifier("tbl4", Some("test_db")),
- tableType = CatalogTableType.MANAGED,
- storage = CatalogStorageFormat.empty.copy(
- properties = Map("path" -> defaultTableURI("tbl4").toString)),
- schema = new StructType(),
- provider = Some("json"),
- properties = Map(
- "spark.sql.sources.provider" -> "json",
- "spark.sql.sources.schema.numParts" -> "1",
- "spark.sql.sources.schema.part.0" -> simpleSchemaJson))
-
- lazy val hiveCompatibleDataSourceTable = CatalogTable(
- identifier = TableIdentifier("tbl5", Some("test_db")),
- tableType = CatalogTableType.MANAGED,
- storage = CatalogStorageFormat.empty.copy(
- properties = Map("path" -> defaultTableURI("tbl5").toString)),
- schema = simpleSchema,
- provider = Some("parquet"),
- properties = Map(
- "spark.sql.sources.provider" -> "parquet",
- "spark.sql.sources.schema.numParts" -> "1",
- "spark.sql.sources.schema.part.0" -> simpleSchemaJson))
-
- lazy val partitionedDataSourceTable = CatalogTable(
- identifier = TableIdentifier("tbl6", Some("test_db")),
- tableType = CatalogTableType.MANAGED,
- storage = CatalogStorageFormat.empty.copy(
- properties = Map("path" -> defaultTableURI("tbl6").toString)),
- schema = new StructType(),
- provider = Some("json"),
- properties = Map(
- "spark.sql.sources.provider" -> "json",
- "spark.sql.sources.schema.numParts" -> "1",
- "spark.sql.sources.schema.part.0" -> partitionedSchemaJson,
- "spark.sql.sources.schema.numPartCols" -> "1",
- "spark.sql.sources.schema.partCol.0" -> "j"))
-
- lazy val externalDataSourceTable = CatalogTable(
- identifier = TableIdentifier("tbl7", Some("test_db")),
- tableType = CatalogTableType.EXTERNAL,
- storage = CatalogStorageFormat.empty.copy(
- locationUri = Some(new URI(defaultTableURI("tbl7") + "-__PLACEHOLDER__")),
- properties = Map("path" -> tempDirStr)),
- schema = new StructType(),
- provider = Some("json"),
- properties = Map(
- "spark.sql.sources.provider" -> "json",
- "spark.sql.sources.schema.numParts" -> "1",
- "spark.sql.sources.schema.part.0" -> simpleSchemaJson))
-
- lazy val hiveCompatibleExternalDataSourceTable = CatalogTable(
- identifier = TableIdentifier("tbl8", Some("test_db")),
- tableType = CatalogTableType.EXTERNAL,
- storage = CatalogStorageFormat.empty.copy(
- locationUri = Some(tempDirUri),
- properties = Map("path" -> tempDirStr)),
- schema = simpleSchema,
- properties = Map(
- "spark.sql.sources.provider" -> "parquet",
- "spark.sql.sources.schema.numParts" -> "1",
- "spark.sql.sources.schema.part.0" -> simpleSchemaJson))
-
- lazy val dataSourceTableWithoutSchema = CatalogTable(
- identifier = TableIdentifier("tbl9", Some("test_db")),
- tableType = CatalogTableType.EXTERNAL,
- storage = CatalogStorageFormat.empty.copy(
- locationUri = Some(new URI(defaultTableURI("tbl9") + "-__PLACEHOLDER__")),
- properties = Map("path" -> tempDirStr)),
- schema = new StructType(),
- provider = Some("json"),
- properties = Map("spark.sql.sources.provider" -> "json"))
-
- // A list of all raw tables we want to test, with their expected schema.
- lazy val rawTablesAndExpectations = Seq(
- hiveTable -> simpleSchema,
- externalHiveTable -> simpleSchema,
- partitionedHiveTable -> partitionedSchema,
- dataSourceTable -> simpleSchema,
- hiveCompatibleDataSourceTable -> simpleSchema,
- partitionedDataSourceTable -> partitionedSchema,
- externalDataSourceTable -> simpleSchema,
- hiveCompatibleExternalDataSourceTable -> simpleSchema,
- dataSourceTableWithoutSchema -> new StructType())
-
- test("make sure we can read table created by old version of Spark") {
- for ((tbl, expectedSchema) <- rawTablesAndExpectations) {
- val readBack = getTableMetadata(tbl.identifier.table)
- assert(readBack.schema.sameType(expectedSchema))
-
- if (tbl.tableType == CatalogTableType.EXTERNAL) {
- // trim the URI prefix
- val tableLocation = readBack.storage.locationUri.get.getPath
- val expectedLocation = tempDir.toURI.getPath.stripSuffix("/")
- assert(tableLocation == expectedLocation)
- }
- }
- }
-
- test("make sure we can alter table location created by old version of Spark") {
- withTempDir { dir =>
- for ((tbl, _) <- rawTablesAndExpectations if tbl.tableType == CatalogTableType.EXTERNAL) {
- val path = dir.toURI.toString.stripSuffix("/")
- sql(s"ALTER TABLE ${tbl.identifier} SET LOCATION '$path'")
-
- val readBack = getTableMetadata(tbl.identifier.table)
-
- // trim the URI prefix
- val actualTableLocation = readBack.storage.locationUri.get.getPath
- val expected = dir.toURI.getPath.stripSuffix("/")
- assert(actualTableLocation == expected)
- }
- }
- }
-
- test("make sure we can rename table created by old version of Spark") {
- for ((tbl, expectedSchema) <- rawTablesAndExpectations) {
- val newName = tbl.identifier.table + "_renamed"
- sql(s"ALTER TABLE ${tbl.identifier} RENAME TO $newName")
-
- val readBack = getTableMetadata(newName)
- assert(readBack.schema.sameType(expectedSchema))
-
- // trim the URI prefix
- val actualTableLocation = readBack.storage.locationUri.get.getPath
- val expectedLocation = if (tbl.tableType == CatalogTableType.EXTERNAL) {
- tempDir.toURI.getPath.stripSuffix("/")
- } else {
- // trim the URI prefix
- defaultTableURI(newName).getPath
- }
- assert(actualTableLocation == expectedLocation)
- }
- }
-}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala
new file mode 100644
index 0000000000000..2928a734a7e36
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import java.io.File
+import java.nio.file.Files
+
+import org.apache.spark.TestUtils
+import org.apache.spark.sql.{QueryTest, Row, SparkSession}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.catalog.CatalogTableType
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.util.Utils
+
+/**
+ * Test HiveExternalCatalog backward compatibility.
+ *
+ * Note that this test suite automatically downloads Spark binary packages of different
+ * versions to a local directory `/tmp/spark-test`. If a Spark folder for the expected version
+ * already exists under this directory, e.g. `/tmp/spark-test/spark-2.0.3`, the download for
+ * that Spark version is skipped.
+ */
+class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils {
+ private val wareHousePath = Utils.createTempDir(namePrefix = "warehouse")
+ private val tmpDataDir = Utils.createTempDir(namePrefix = "test-data")
+ private val sparkTestingDir = "/tmp/spark-test"
+ private val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
+
+ override def afterAll(): Unit = {
+ Utils.deleteRecursively(wareHousePath)
+ Utils.deleteRecursively(tmpDataDir)
+ super.afterAll()
+ }
+
+ private def downloadSpark(version: String): Unit = {
+ import scala.sys.process._
+
+ val url = s"https://d3kbcqa49mib13.cloudfront.net/spark-$version-bin-hadoop2.7.tgz"
+
+ Seq("wget", url, "-q", "-P", sparkTestingDir).!
+
+ val downloaded = new File(sparkTestingDir, s"spark-$version-bin-hadoop2.7.tgz").getCanonicalPath
+ val targetDir = new File(sparkTestingDir, s"spark-$version").getCanonicalPath
+
+ Seq("mkdir", targetDir).!
+
+ Seq("tar", "-xzf", downloaded, "-C", targetDir, "--strip-components=1").!
+
+ Seq("rm", downloaded).!
+ }
+
+ private def genDataDir(name: String): String = {
+ new File(tmpDataDir, name).getCanonicalPath
+ }
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+
+ val tempPyFile = File.createTempFile("test", ".py")
+ Files.write(tempPyFile.toPath,
+ s"""
+ |from pyspark.sql import SparkSession
+ |
+ |spark = SparkSession.builder.enableHiveSupport().getOrCreate()
+ |version_index = spark.conf.get("spark.sql.test.version.index", None)
+ |
+ |spark.sql("create table data_source_tbl_{} using json as select 1 i".format(version_index))
+ |
+ |spark.sql("create table hive_compatible_data_source_tbl_" + version_index + \\
+ | " using parquet as select 1 i")
+ |
+ |json_file = "${genDataDir("json_")}" + str(version_index)
+ |spark.range(1, 2).selectExpr("cast(id as int) as i").write.json(json_file)
+ |spark.sql("create table external_data_source_tbl_" + version_index + \\
+ | "(i int) using json options (path '{}')".format(json_file))
+ |
+ |parquet_file = "${genDataDir("parquet_")}" + str(version_index)
+ |spark.range(1, 2).selectExpr("cast(id as int) as i").write.parquet(parquet_file)
+ |spark.sql("create table hive_compatible_external_data_source_tbl_" + version_index + \\
+ | "(i int) using parquet options (path '{}')".format(parquet_file))
+ |
+ |json_file2 = "${genDataDir("json2_")}" + str(version_index)
+ |spark.range(1, 2).selectExpr("cast(id as int) as i").write.json(json_file2)
+ |spark.sql("create table external_table_without_schema_" + version_index + \\
+ | " using json options (path '{}')".format(json_file2))
+ |
+ |spark.sql("create view v_{} as select 1 i".format(version_index))
+ """.stripMargin.getBytes("utf8"))
+
+ PROCESS_TABLES.testingVersions.zipWithIndex.foreach { case (version, index) =>
+ val sparkHome = new File(sparkTestingDir, s"spark-$version")
+ if (!sparkHome.exists()) {
+ downloadSpark(version)
+ }
+
+ val args = Seq(
+ "--name", "prepare testing tables",
+ "--master", "local[2]",
+ "--conf", "spark.ui.enabled=false",
+ "--conf", "spark.master.rest.enabled=false",
+ "--conf", s"spark.sql.warehouse.dir=${wareHousePath.getCanonicalPath}",
+ "--conf", s"spark.sql.test.version.index=$index",
+ "--driver-java-options", s"-Dderby.system.home=${wareHousePath.getCanonicalPath}",
+ tempPyFile.getCanonicalPath)
+ runSparkSubmit(args, Some(sparkHome.getCanonicalPath))
+ }
+
+ tempPyFile.delete()
+ }
+
+ test("backward compatibility") {
+ val args = Seq(
+ "--class", PROCESS_TABLES.getClass.getName.stripSuffix("$"),
+ "--name", "HiveExternalCatalog backward compatibility test",
+ "--master", "local[2]",
+ "--conf", "spark.ui.enabled=false",
+ "--conf", "spark.master.rest.enabled=false",
+ "--conf", s"spark.sql.warehouse.dir=${wareHousePath.getCanonicalPath}",
+ "--driver-java-options", s"-Dderby.system.home=${wareHousePath.getCanonicalPath}",
+ unusedJar.toString)
+ runSparkSubmit(args)
+ }
+}
+
+object PROCESS_TABLES extends QueryTest with SQLTestUtils {
+ // Tests the latest version of every release line.
+ val testingVersions = Seq("2.0.2", "2.1.1", "2.2.0")
+
+ protected var spark: SparkSession = _
+
+ def main(args: Array[String]): Unit = {
+ val session = SparkSession.builder()
+ .enableHiveSupport()
+ .getOrCreate()
+ spark = session
+
+ testingVersions.indices.foreach { index =>
+ Seq(
+ s"data_source_tbl_$index",
+ s"hive_compatible_data_source_tbl_$index",
+ s"external_data_source_tbl_$index",
+ s"hive_compatible_external_data_source_tbl_$index",
+ s"external_table_without_schema_$index").foreach { tbl =>
+ val tableMeta = spark.sharedState.externalCatalog.getTable("default", tbl)
+
+ // make sure we can insert and query these tables.
+ session.sql(s"insert into $tbl select 2")
+ checkAnswer(session.sql(s"select * from $tbl"), Row(1) :: Row(2) :: Nil)
+ checkAnswer(session.sql(s"select i from $tbl where i > 1"), Row(2))
+
+ // make sure we can rename table.
+ val newName = tbl + "_renamed"
+ sql(s"ALTER TABLE $tbl RENAME TO $newName")
+ val readBack = spark.sharedState.externalCatalog.getTable("default", newName)
+
+ val actualTableLocation = readBack.storage.locationUri.get.getPath
+ val expectedLocation = if (tableMeta.tableType == CatalogTableType.EXTERNAL) {
+ tableMeta.storage.locationUri.get.getPath
+ } else {
+ spark.sessionState.catalog.defaultTablePath(TableIdentifier(newName, None)).getPath
+ }
+ assert(actualTableLocation == expectedLocation)
+
+ // make sure we can alter table location.
+ withTempDir { dir =>
+ val path = dir.toURI.toString.stripSuffix("/")
+ sql(s"ALTER TABLE ${tbl}_renamed SET LOCATION '$path'")
+ val readBack = spark.sharedState.externalCatalog.getTable("default", tbl + "_renamed")
+ val actualTableLocation = readBack.storage.locationUri.get.getPath
+ val expected = dir.toURI.getPath.stripSuffix("/")
+ assert(actualTableLocation == expected)
+ }
+ }
+
+ // test permanent view
+ checkAnswer(sql(s"select i from v_$index"), Row(1))
+ }
+ }
+}
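The scaladoc at the top of this new suite describes a download-and-cache convention for old Spark distributions: a binary package is only fetched when `/tmp/spark-test/spark-<version>` does not exist yet. As a quick aside (not part of the patch), that check reduces to a small standalone sketch; `fetch` below is a hypothetical stand-in for the suite's `downloadSpark`:

    import java.io.File

    object SparkTestingDirs {
      private val sparkTestingDir = "/tmp/spark-test"

      // Reuse a previously extracted distribution if present; otherwise fetch it,
      // mirroring the beforeAll logic of HiveExternalCatalogVersionsSuite above.
      def cachedSparkHome(version: String)(fetch: String => Unit): File = {
        val home = new File(sparkTestingDir, s"spark-$version")
        if (!home.exists()) {
          fetch(version)
        }
        home
      }
    }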
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
index 8140f883ee542..18137e7ea1d63 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils}
-import org.apache.spark.sql.types.{DecimalType, IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.types._
class HiveMetastoreCatalogSuite extends TestHiveSingleton with SQLTestUtils {
import spark.implicits._
@@ -67,6 +67,73 @@ class HiveMetastoreCatalogSuite extends TestHiveSingleton with SQLTestUtils {
assert(aliases.size == 1)
}
}
+
+ test("Validate catalog metadata for supported data types") {
+ withTable("t") {
+ sql(
+ """
+ |CREATE TABLE t (
+ |c1 boolean,
+ |c2 tinyint,
+ |c3 smallint,
+ |c4 short,
+ |c5 bigint,
+ |c6 long,
+ |c7 float,
+ |c8 double,
+ |c9 date,
+ |c10 timestamp,
+ |c11 string,
+ |c12 char(10),
+ |c13 varchar(10),
+ |c14 binary,
+ |c15 decimal,
+ |c16 decimal(10),
+ |c17 decimal(10,2),
+ |c18 array<double>,
+ |c19 array<string>,
+ |c20 array<char(10)>,
+ |c21 map<int, int>,
+ |c22 map<int, char(10)>,
+ |c23 struct<a: int, b: int>,
+ |c24 struct<c: varchar(10)>
+ |)
+ """.stripMargin)
+
+ val schema = hiveClient.getTable("default", "t").schema
+ val expectedSchema = new StructType()
+ .add("c1", "boolean")
+ .add("c2", "tinyint")
+ .add("c3", "smallint")
+ .add("c4", "short")
+ .add("c5", "bigint")
+ .add("c6", "long")
+ .add("c7", "float")
+ .add("c8", "double")
+ .add("c9", "date")
+ .add("c10", "timestamp")
+ .add("c11", "string")
+ .add("c12", "string", true,
+ new MetadataBuilder().putString(HIVE_TYPE_STRING, "char(10)").build())
+ .add("c13", "string", true,
+ new MetadataBuilder().putString(HIVE_TYPE_STRING, "varchar(10)").build())
+ .add("c14", "binary")
+ .add("c15", "decimal")
+ .add("c16", "decimal(10)")
+ .add("c17", "decimal(10,2)")
+ .add("c18", "array<double>")
+ .add("c19", "array<string>")
+ .add("c20", "array<string>", true,
+ new MetadataBuilder().putString(HIVE_TYPE_STRING, "array<char(10)>").build())
+ .add("c21", "map<int,int>")
+ .add("c22", "map<int,string>", true,
+ new MetadataBuilder().putString(HIVE_TYPE_STRING, "map<int,char(10)>").build())
+ .add("c23", "struct<a:int,b:int>")
+ .add("c24", "struct<c:string>", true,
+ new MetadataBuilder().putString(HIVE_TYPE_STRING, "struct<c:varchar(10)>").build())
+ assert(schema == expectedSchema)
+ }
+ }
}
class DataSourceWithHiveMetastoreCatalogSuite
@@ -180,5 +247,6 @@ class DataSourceWithHiveMetastoreCatalogSuite
}
}
}
+
}
}
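The expected schema in the new test above shows how Hive `char(10)`/`varchar(10)` columns (including char/varchar nested in arrays, maps, and structs) surface in the catalog: the Spark-facing type is `string`, and the raw Hive type is preserved under the `HIVE_TYPE_STRING` metadata key. A minimal helper sketch (not part of the patch) for reading that metadata back:

    import org.apache.spark.sql.types._

    // Return the raw Hive type recorded for a catalog column, falling back to the
    // Spark SQL type name when no HIVE_TYPE_STRING metadata is attached.
    def rawHiveType(field: StructField): String = {
      if (field.metadata.contains(HIVE_TYPE_STRING)) field.metadata.getString(HIVE_TYPE_STRING)
      else field.dataType.catalogString
    }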
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala
index 3d0e43cbbe037..f2d27671094d7 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala
@@ -71,7 +71,7 @@ class HiveSchemaInferenceSuite
name = field,
dataType = LongType,
nullable = true,
- metadata = new MetadataBuilder().putString(HIVE_TYPE_STRING, "bigint").build())
+ metadata = Metadata.empty)
}
// and all partition columns as ints
val partitionStructFields = partitionCols.map { field =>
@@ -80,7 +80,7 @@ class HiveSchemaInferenceSuite
name = field.toLowerCase,
dataType = IntegerType,
nullable = true,
- metadata = new MetadataBuilder().putString(HIVE_TYPE_STRING, "int").build())
+ metadata = Metadata.empty)
}
val schema = StructType(structFields ++ partitionStructFields)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
index 5d2257180a026..1193db0a257e6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
@@ -18,17 +18,11 @@
package org.apache.spark.sql.hive
import java.io.{BufferedWriter, File, FileWriter}
-import java.sql.Timestamp
-import java.util.Date
-import scala.collection.mutable.ArrayBuffer
import scala.tools.nsc.Properties
import org.apache.hadoop.fs.Path
import org.scalatest.{BeforeAndAfterEach, Matchers}
-import org.scalatest.concurrent.Timeouts
-import org.scalatest.exceptions.TestFailedDueToTimeoutException
-import org.scalatest.time.SpanSugar._
import org.apache.spark._
import org.apache.spark.internal.Logging
@@ -38,7 +32,6 @@ import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext}
-import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer
import org.apache.spark.sql.types.{DecimalType, StructType}
import org.apache.spark.util.{ResetSystemProperties, Utils}
@@ -46,11 +39,10 @@ import org.apache.spark.util.{ResetSystemProperties, Utils}
* This suite tests spark-submit with applications using HiveContext.
*/
class HiveSparkSubmitSuite
- extends SparkFunSuite
+ extends SparkSubmitTestUtils
with Matchers
with BeforeAndAfterEach
- with ResetSystemProperties
- with Timeouts {
+ with ResetSystemProperties {
// TODO: rewrite these or mark them as slow tests to be run sparingly
@@ -335,71 +327,6 @@ class HiveSparkSubmitSuite
unusedJar.toString)
runSparkSubmit(argsForShowTables)
}
-
- // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly.
- // This is copied from org.apache.spark.deploy.SparkSubmitSuite
- private def runSparkSubmit(args: Seq[String]): Unit = {
- val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
- val history = ArrayBuffer.empty[String]
- val sparkSubmit = if (Utils.isWindows) {
- // On Windows, `ProcessBuilder.directory` does not change the current working directory.
- new File("..\\..\\bin\\spark-submit.cmd").getAbsolutePath
- } else {
- "./bin/spark-submit"
- }
- val commands = Seq(sparkSubmit) ++ args
- val commandLine = commands.mkString("'", "' '", "'")
-
- val builder = new ProcessBuilder(commands: _*).directory(new File(sparkHome))
- val env = builder.environment()
- env.put("SPARK_TESTING", "1")
- env.put("SPARK_HOME", sparkHome)
-
- def captureOutput(source: String)(line: String): Unit = {
- // This test suite has some weird behaviors when executed on Jenkins:
- //
- // 1. Sometimes it gets extremely slow out of unknown reason on Jenkins. Here we add a
- // timestamp to provide more diagnosis information.
- // 2. Log lines are not correctly redirected to unit-tests.log as expected, so here we print
- // them out for debugging purposes.
- val logLine = s"${new Timestamp(new Date().getTime)} - $source> $line"
- // scalastyle:off println
- println(logLine)
- // scalastyle:on println
- history += logLine
- }
-
- val process = builder.start()
- new ProcessOutputCapturer(process.getInputStream, captureOutput("stdout")).start()
- new ProcessOutputCapturer(process.getErrorStream, captureOutput("stderr")).start()
-
- try {
- val exitCode = failAfter(300.seconds) { process.waitFor() }
- if (exitCode != 0) {
- // include logs in output. Note that logging is async and may not have completed
- // at the time this exception is raised
- Thread.sleep(1000)
- val historyLog = history.mkString("\n")
- fail {
- s"""spark-submit returned with exit code $exitCode.
- |Command line: $commandLine
- |
- |$historyLog
- """.stripMargin
- }
- }
- } catch {
- case to: TestFailedDueToTimeoutException =>
- val historyLog = history.mkString("\n")
- fail(s"Timeout of $commandLine" +
- s" See the log4j logs for more detail." +
- s"\n$historyLog", to)
- case t: Throwable => throw t
- } finally {
- // Ensure we still kill the process in case it timed out
- process.destroy()
- }
- }
}
object SetMetastoreURLTest extends Logging {
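The hunks above delete the private `runSparkSubmit` helper from `HiveSparkSubmitSuite`; the suite now extends the shared `SparkSubmitTestUtils` trait introduced later in this diff. A hedged sketch of how a suite consumes the trait after the refactoring (the application class and jar path are hypothetical):

    // Mix in the shared helper and submit an application, as HiveSparkSubmitSuite does above.
    class MySparkSubmitSuite extends SparkSubmitTestUtils {
      test("runs an application through spark-submit") {
        val args = Seq(
          "--class", "org.example.MyApp",   // hypothetical application class
          "--master", "local[2]",
          "--conf", "spark.ui.enabled=false",
          "path/to/app.jar")                // hypothetical application jar
        runSparkSubmit(args)                // pass Some(sparkHome) to target a downloaded Spark
      }
    }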
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala
similarity index 67%
rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala
index cc80f2e481cbf..aa5cae33f5cd9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala
@@ -34,7 +34,7 @@ case class TestData(key: Int, value: String)
case class ThreeCloumntable(key: Int, value: String, key1: String)
-class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter
+class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter
with SQLTestUtils {
import spark.implicits._
@@ -50,47 +50,53 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
}
test("insertInto() HiveTable") {
- sql("CREATE TABLE createAndInsertTest (key int, value string)")
-
- // Add some data.
- testData.write.mode(SaveMode.Append).insertInto("createAndInsertTest")
-
- // Make sure the table has also been updated.
- checkAnswer(
- sql("SELECT * FROM createAndInsertTest"),
- testData.collect().toSeq
- )
-
- // Add more data.
- testData.write.mode(SaveMode.Append).insertInto("createAndInsertTest")
-
- // Make sure the table has been updated.
- checkAnswer(
- sql("SELECT * FROM createAndInsertTest"),
- testData.toDF().collect().toSeq ++ testData.toDF().collect().toSeq
- )
-
- // Now overwrite.
- testData.write.mode(SaveMode.Overwrite).insertInto("createAndInsertTest")
-
- // Make sure the registered table has also been updated.
- checkAnswer(
- sql("SELECT * FROM createAndInsertTest"),
- testData.collect().toSeq
- )
+ withTable("createAndInsertTest") {
+ sql("CREATE TABLE createAndInsertTest (key int, value string)")
+
+ // Add some data.
+ testData.write.mode(SaveMode.Append).insertInto("createAndInsertTest")
+
+ // Make sure the table has also been updated.
+ checkAnswer(
+ sql("SELECT * FROM createAndInsertTest"),
+ testData.collect().toSeq
+ )
+
+ // Add more data.
+ testData.write.mode(SaveMode.Append).insertInto("createAndInsertTest")
+
+ // Make sure the table has been updated.
+ checkAnswer(
+ sql("SELECT * FROM createAndInsertTest"),
+ testData.toDF().collect().toSeq ++ testData.toDF().collect().toSeq
+ )
+
+ // Now overwrite.
+ testData.write.mode(SaveMode.Overwrite).insertInto("createAndInsertTest")
+
+ // Make sure the registered table has also been updated.
+ checkAnswer(
+ sql("SELECT * FROM createAndInsertTest"),
+ testData.collect().toSeq
+ )
+ }
}
test("Double create fails when allowExisting = false") {
- sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)")
-
- intercept[AnalysisException] {
+ withTable("doubleCreateAndInsertTest") {
sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)")
+
+ intercept[AnalysisException] {
+ sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)")
+ }
}
}
test("Double create does not fail when allowExisting = true") {
- sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)")
- sql("CREATE TABLE IF NOT EXISTS doubleCreateAndInsertTest (key int, value string)")
+ withTable("doubleCreateAndInsertTest") {
+ sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)")
+ sql("CREATE TABLE IF NOT EXISTS doubleCreateAndInsertTest (key int, value string)")
+ }
}
test("SPARK-4052: scala.collection.Map as value type of MapType") {
@@ -268,29 +274,33 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
test("Test partition mode = strict") {
withSQLConf(("hive.exec.dynamic.partition.mode", "strict")) {
- sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)")
- val data = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd"))
+ withTable("partitioned") {
+ sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)")
+ val data = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd"))
.toDF("id", "data", "part")
- intercept[SparkException] {
- data.write.insertInto("partitioned")
+ intercept[SparkException] {
+ data.write.insertInto("partitioned")
+ }
}
}
}
test("Detect table partitioning") {
withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) {
- sql("CREATE TABLE source (id bigint, data string, part string)")
- val data = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd")).toDF()
+ withTable("source", "partitioned") {
+ sql("CREATE TABLE source (id bigint, data string, part string)")
+ val data = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd")).toDF()
- data.write.insertInto("source")
- checkAnswer(sql("SELECT * FROM source"), data.collect().toSeq)
+ data.write.insertInto("source")
+ checkAnswer(sql("SELECT * FROM source"), data.collect().toSeq)
- sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)")
- // this will pick up the output partitioning from the table definition
- spark.table("source").write.insertInto("partitioned")
+ sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)")
+ // this will pick up the output partitioning from the table definition
+ spark.table("source").write.insertInto("partitioned")
- checkAnswer(sql("SELECT * FROM partitioned"), data.collect().toSeq)
+ checkAnswer(sql("SELECT * FROM partitioned"), data.collect().toSeq)
+ }
}
}
@@ -461,19 +471,23 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
testPartitionedTable("insertInto() should reject missing columns") {
tableName =>
- sql("CREATE TABLE t (a INT, b INT)")
+ withTable("t") {
+ sql("CREATE TABLE t (a INT, b INT)")
- intercept[AnalysisException] {
- spark.table("t").write.insertInto(tableName)
+ intercept[AnalysisException] {
+ spark.table("t").write.insertInto(tableName)
+ }
}
}
testPartitionedTable("insertInto() should reject extra columns") {
tableName =>
- sql("CREATE TABLE t (a INT, b INT, c INT, d INT, e INT)")
+ withTable("t") {
+ sql("CREATE TABLE t (a INT, b INT, c INT, d INT, e INT)")
- intercept[AnalysisException] {
- spark.table("t").write.insertInto(tableName)
+ intercept[AnalysisException] {
+ spark.table("t").write.insertInto(tableName)
+ }
}
}
@@ -534,4 +548,184 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef
}
}
}
+
+ test("insert overwrite to dir from hive metastore table") {
+ withTempDir { dir =>
+ val path = dir.toURI.getPath
+
+ sql(s"INSERT OVERWRITE LOCAL DIRECTORY '${path}' SELECT * FROM src where key < 10")
+
+ sql(
+ s"""
+ |INSERT OVERWRITE LOCAL DIRECTORY '${path}'
+ |STORED AS orc
+ |SELECT * FROM src where key < 10
+ """.stripMargin)
+
+ // use the ORC data source to check that the data written to the path is correct.
+ withTempView("orc_source") {
+ sql(
+ s"""
+ |CREATE TEMPORARY VIEW orc_source
+ |USING org.apache.spark.sql.hive.orc
+ |OPTIONS (
+ | PATH '${dir.getCanonicalPath}'
+ |)
+ """.stripMargin)
+
+ checkAnswer(
+ sql("select * from orc_source"),
+ sql("select * from src where key < 10"))
+ }
+ }
+ }
+
+ test("insert overwrite to local dir from temp table") {
+ withTempView("test_insert_table") {
+ spark.range(10).selectExpr("id", "id AS str").createOrReplaceTempView("test_insert_table")
+
+ withTempDir { dir =>
+ val path = dir.toURI.getPath
+
+ sql(
+ s"""
+ |INSERT OVERWRITE LOCAL DIRECTORY '${path}'
+ |ROW FORMAT DELIMITED FIELDS TERMINATED BY ','
+ |SELECT * FROM test_insert_table
+ """.stripMargin)
+
+ sql(
+ s"""
+ |INSERT OVERWRITE LOCAL DIRECTORY '${path}'
+ |STORED AS orc
+ |SELECT * FROM test_insert_table
+ """.stripMargin)
+
+ // use the ORC data source to check that the data written to the path is correct.
+ checkAnswer(
+ spark.read.orc(dir.getCanonicalPath),
+ sql("select * from test_insert_table"))
+ }
+ }
+ }
+
+ test("insert overwrite to dir from temp table") {
+ withTempView("test_insert_table") {
+ spark.range(10).selectExpr("id", "id AS str").createOrReplaceTempView("test_insert_table")
+
+ withTempDir { dir =>
+ val pathUri = dir.toURI
+
+ sql(
+ s"""
+ |INSERT OVERWRITE DIRECTORY '${pathUri}'
+ |ROW FORMAT DELIMITED FIELDS TERMINATED BY ','
+ |SELECT * FROM test_insert_table
+ """.stripMargin)
+
+ sql(
+ s"""
+ |INSERT OVERWRITE DIRECTORY '${pathUri}'
+ |STORED AS orc
+ |SELECT * FROM test_insert_table
+ """.stripMargin)
+
+ // use the ORC data source to check that the data written to the path is correct.
+ checkAnswer(
+ spark.read.orc(dir.getCanonicalPath),
+ sql("select * from test_insert_table"))
+ }
+ }
+ }
+
+ test("multi insert overwrite to dir") {
+ withTempView("test_insert_table") {
+ spark.range(10).selectExpr("id", "id AS str").createOrReplaceTempView("test_insert_table")
+
+ withTempDir { dir =>
+ val pathUri = dir.toURI
+
+ withTempDir { dir2 =>
+ val pathUri2 = dir2.toURI
+
+ sql(
+ s"""
+ |FROM test_insert_table
+ |INSERT OVERWRITE DIRECTORY '${pathUri}'
+ |STORED AS orc
+ |SELECT id
+ |INSERT OVERWRITE DIRECTORY '${pathUri2}'
+ |STORED AS orc
+ |SELECT *
+ """.stripMargin)
+
+ // use the ORC data source to check that the data written to the paths is correct.
+ checkAnswer(
+ spark.read.orc(dir.getCanonicalPath),
+ sql("select id from test_insert_table"))
+
+ checkAnswer(
+ spark.read.orc(dir2.getCanonicalPath),
+ sql("select * from test_insert_table"))
+ }
+ }
+ }
+ }
+
+ test("insert overwrite to dir to illegal path") {
+ withTempView("test_insert_table") {
+ spark.range(10).selectExpr("id", "id AS str").createOrReplaceTempView("test_insert_table")
+
+ val e = intercept[IllegalArgumentException] {
+ sql(
+ s"""
+ |INSERT OVERWRITE LOCAL DIRECTORY 'abc://a'
+ |ROW FORMAT DELIMITED FIELDS TERMINATED BY ','
+ |SELECT * FROM test_insert_table
+ """.stripMargin)
+ }.getMessage
+
+ assert(e.contains("Wrong FS: abc://a, expected: file:///"))
+ }
+ }
+
+ test("insert overwrite to dir with mixed syntax") {
+ withTempView("test_insert_table") {
+ spark.range(10).selectExpr("id", "id AS str").createOrReplaceTempView("test_insert_table")
+
+ val e = intercept[ParseException] {
+ sql(
+ s"""
+ |INSERT OVERWRITE DIRECTORY 'file://tmp'
+ |USING json
+ |ROW FORMAT DELIMITED FIELDS TERMINATED BY ','
+ |SELECT * FROM test_insert_table
+ """.stripMargin)
+ }.getMessage
+
+ assert(e.contains("mismatched input 'ROW'"))
+ }
+ }
+
+ test("insert overwrite to dir with multi inserts") {
+ withTempView("test_insert_table") {
+ spark.range(10).selectExpr("id", "id AS str").createOrReplaceTempView("test_insert_table")
+
+ val e = intercept[ParseException] {
+ sql(
+ s"""
+ |INSERT OVERWRITE DIRECTORY 'file://tmp2'
+ |USING json
+ |ROW FORMAT DELIMITED FIELDS TERMINATED BY ','
+ |SELECT * FROM test_insert_table
+ |INSERT OVERWRITE DIRECTORY 'file://tmp2'
+ |USING json
+ |ROW FORMAT DELIMITED FIELDS TERMINATED BY ','
+ |SELECT * FROM test_insert_table
+ """.stripMargin)
+ }.getMessage
+
+ assert(e.contains("mismatched input 'ROW'"))
+ }
+ }
}
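The new directory-insert tests above cover the Hive-serde form of `INSERT OVERWRITE [LOCAL] DIRECTORY` end-to-end and assert that mixing it with the data-source clause (`USING json`) fails at parse time. A short sketch of the shapes exercised, assuming a SparkSession named `spark` and the `test_insert_table` temp view from the tests (the output path is a placeholder):

    // Hive-serde form, verified end-to-end above (ROW FORMAT ... or STORED AS ...).
    spark.sql(
      """
        |INSERT OVERWRITE LOCAL DIRECTORY '/tmp/insert_out'
        |STORED AS orc
        |SELECT * FROM test_insert_table
      """.stripMargin)

    // Combining USING json with ROW FORMAT / STORED AS in one statement is rejected by the
    // parser with "mismatched input 'ROW'", as asserted in the last two tests above.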
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index e01198dd53178..29b0e6c8533ef 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -583,7 +583,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
Row(3) :: Row(4) :: Nil)
table("test_parquet_ctas").queryExecution.optimizedPlan match {
- case LogicalRelation(p: HadoopFsRelation, _, _) => // OK
+ case LogicalRelation(p: HadoopFsRelation, _, _, _) => // OK
case _ =>
fail(s"test_parquet_ctas should have be converted to ${classOf[HadoopFsRelation]}")
}
@@ -1354,31 +1354,4 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
sparkSession.sparkContext.conf.set(DEBUG_MODE, previousValue)
}
}
-
- test("SPARK-18464: support old table which doesn't store schema in table properties") {
- withTable("old") {
- withTempPath { path =>
- Seq(1 -> "a").toDF("i", "j").write.parquet(path.getAbsolutePath)
- val tableDesc = CatalogTable(
- identifier = TableIdentifier("old", Some("default")),
- tableType = CatalogTableType.EXTERNAL,
- storage = CatalogStorageFormat.empty.copy(
- properties = Map("path" -> path.getAbsolutePath)
- ),
- schema = new StructType(),
- provider = Some("parquet"),
- properties = Map(
- HiveExternalCatalog.DATASOURCE_PROVIDER -> "parquet"))
- hiveClient.createTable(tableDesc, ignoreIfExists = false)
-
- checkAnswer(spark.table("old"), Row(1, "a"))
- checkAnswer(sql("select * from old"), Row(1, "a"))
-
- val expectedSchema = StructType(Seq(
- StructField("i", IntegerType, nullable = true),
- StructField("j", StringType, nullable = true)))
- assert(table("old").schema === expectedSchema)
- }
- }
- }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala
index 4aea6d14efb0e..9060ce2e0eb4b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.hive
-import java.net.URI
-
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala
index 43b6bf5feeb60..b2dc401ce1efc 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.hive
import java.io.File
+import java.sql.Timestamp
import com.google.common.io.Files
import org.apache.hadoop.fs.FileSystem
@@ -68,4 +69,20 @@ class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingl
sql("DROP TABLE IF EXISTS createAndInsertTest")
}
}
+
+ test("SPARK-21739: Cast expression should initialize timezoneId") {
+ withTable("table_with_timestamp_partition") {
+ sql("CREATE TABLE table_with_timestamp_partition(value int) PARTITIONED BY (ts TIMESTAMP)")
+ sql("INSERT OVERWRITE TABLE table_with_timestamp_partition " +
+ "PARTITION (ts = '2010-01-01 00:00:00.000') VALUES (1)")
+
+ // test for Cast expression in TableReader
+ checkAnswer(sql("SELECT * FROM table_with_timestamp_partition"),
+ Seq(Row(1, Timestamp.valueOf("2010-01-01 00:00:00.000"))))
+
+ // test for Cast expression in HiveTableScanExec
+ checkAnswer(sql("SELECT value FROM table_with_timestamp_partition " +
+ "WHERE ts = '2010-01-01 00:00:00.000'"), Row(1))
+ }
+ }
}
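
The SPARK-21739 test above guards against `Cast` expressions on timestamp partition values being built without a timezone. A minimal sketch of the underlying requirement, using Spark's internal expression API (subject to change between versions, shown only for illustration, not part of this patch):

```scala
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.types.TimestampType

// A string-to-timestamp Cast is a timezone-aware expression: it only resolves once a
// timeZoneId is supplied (explicitly here, normally injected from the session config by
// the analyzer). Leaving it as None keeps the expression unresolved and evaluation fails.
val withTz = Cast(Literal("2010-01-01 00:00:00.000"), TimestampType, Some("UTC"))
assert(withTz.resolved)
```
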
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala
new file mode 100644
index 0000000000000..ede44df4afe11
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SparkSubmitTestUtils.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import java.io.File
+import java.sql.Timestamp
+import java.util.Date
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalatest.concurrent.TimeLimits
+import org.scalatest.exceptions.TestFailedDueToTimeoutException
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer
+import org.apache.spark.util.Utils
+
+trait SparkSubmitTestUtils extends SparkFunSuite with TimeLimits {
+
+ // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly.
+ // This is copied from org.apache.spark.deploy.SparkSubmitSuite
+ protected def runSparkSubmit(args: Seq[String], sparkHomeOpt: Option[String] = None): Unit = {
+ val sparkHome = sparkHomeOpt.getOrElse(
+ sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")))
+ val history = ArrayBuffer.empty[String]
+ val sparkSubmit = if (Utils.isWindows) {
+ // On Windows, `ProcessBuilder.directory` does not change the current working directory.
+ new File("..\\..\\bin\\spark-submit.cmd").getAbsolutePath
+ } else {
+ "./bin/spark-submit"
+ }
+ val commands = Seq(sparkSubmit) ++ args
+ val commandLine = commands.mkString("'", "' '", "'")
+
+ val builder = new ProcessBuilder(commands: _*).directory(new File(sparkHome))
+ val env = builder.environment()
+ env.put("SPARK_TESTING", "1")
+ env.put("SPARK_HOME", sparkHome)
+
+ def captureOutput(source: String)(line: String): Unit = {
+ // This test suite has some weird behaviors when executed on Jenkins:
+ //
+      // 1. Sometimes it gets extremely slow on Jenkins for unknown reasons. Here we add a
+      //    timestamp to provide more diagnostic information.
+ // 2. Log lines are not correctly redirected to unit-tests.log as expected, so here we print
+ // them out for debugging purposes.
+ val logLine = s"${new Timestamp(new Date().getTime)} - $source> $line"
+ // scalastyle:off println
+ println(logLine)
+ // scalastyle:on println
+ history += logLine
+ }
+
+ val process = builder.start()
+ new ProcessOutputCapturer(process.getInputStream, captureOutput("stdout")).start()
+ new ProcessOutputCapturer(process.getErrorStream, captureOutput("stderr")).start()
+
+ try {
+ val exitCode = failAfter(300.seconds) { process.waitFor() }
+ if (exitCode != 0) {
+ // include logs in output. Note that logging is async and may not have completed
+ // at the time this exception is raised
+ Thread.sleep(1000)
+ val historyLog = history.mkString("\n")
+ fail {
+ s"""spark-submit returned with exit code $exitCode.
+ |Command line: $commandLine
+ |
+ |$historyLog
+ """.stripMargin
+ }
+ }
+ } catch {
+ case to: TestFailedDueToTimeoutException =>
+ val historyLog = history.mkString("\n")
+ fail(s"Timeout of $commandLine" +
+ s" See the log4j logs for more detail." +
+ s"\n$historyLog", to)
+ case t: Throwable => throw t
+ } finally {
+ // Ensure we still kill the process in case it timed out
+ process.destroy()
+ }
+ }
+}
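
For context on how this new helper is meant to be consumed: a suite mixes in the trait and calls `runSparkSubmit` with ordinary spark-submit arguments; process creation, output capture and the 300-second timeout are handled inside. A hypothetical usage sketch (the application jar and main class below are placeholders, not part of this patch):

```scala
class MySparkSubmitSmokeSuite extends SparkSubmitTestUtils {
  test("spark-submit runs a small application") {
    // Relies on the spark.test.home system property set by the build, as in other suites.
    val args = Seq(
      "--class", "org.example.SmokeTestApp",   // placeholder main class
      "--master", "local[2]",
      "/tmp/smoke-test-app.jar")               // placeholder jar
    runSparkSubmit(args)  // fails the test on a non-zero exit code or a timeout
  }
}
```
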
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 71cf79c473b46..9ff9ecf7f3677 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -26,6 +26,7 @@ import org.apache.hadoop.hive.common.StatsSetupConst
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.NoSuchPartitionException
import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, HiveTableRelation}
import org.apache.spark.sql.catalyst.plans.logical.ColumnStat
import org.apache.spark.sql.catalyst.util.StringUtils
@@ -39,14 +40,7 @@ import org.apache.spark.sql.types._
class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleton {
- private def dropMetadata(schema: StructType): StructType = {
- val newFields = schema.fields.map { f =>
- StructField(f.name, f.dataType, f.nullable, Metadata.empty)
- }
- StructType(newFields)
- }
-
- test("Hive serde tables should fallback to HDFS for size estimation") {
+ test("Hive serde tables should fallback to HDFS for size estimation") {
withSQLConf(SQLConf.ENABLE_FALL_BACK_TO_HDFS_FOR_STATS.key -> "true") {
withTable("csv_table") {
withTempDir { tempDir =>
@@ -137,9 +131,9 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
// Verify that the schema stored in catalog is a dummy one used for
// data source tables. The actual schema is stored in table properties.
- val rawSchema = dropMetadata(hiveClient.getTable("default", table).schema)
- val expectedRawSchema = new StructType()
-        .add("col", "array<string>")
+ val rawSchema = hiveClient.getTable("default", table).schema
+ val metadata = new MetadataBuilder().putString("comment", "from deserializer").build()
+      val expectedRawSchema = new StructType().add("col", "array<string>", true, metadata)
assert(rawSchema == expectedRawSchema)
val actualSchema = spark.sharedState.externalCatalog.getTable("default", table).schema
@@ -160,14 +154,13 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
}
test("Analyze hive serde tables when schema is not same as schema in table properties") {
-
val table = "hive_serde"
withTable(table) {
sql(s"CREATE TABLE $table (C1 INT, C2 STRING, C3 DOUBLE)")
// Verify that the table schema stored in hive catalog is
// different than the schema stored in table properties.
- val rawSchema = dropMetadata(hiveClient.getTable("default", table).schema)
+ val rawSchema = hiveClient.getTable("default", table).schema
val expectedRawSchema = new StructType()
.add("c1", "int")
.add("c2", "string")
@@ -202,7 +195,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
sql(s"INSERT INTO TABLE $tableName PARTITION (ds='$ds') SELECT * FROM src")
}
- sql(s"ALTER TABLE $tableName SET LOCATION '$path'")
+ sql(s"ALTER TABLE $tableName SET LOCATION '${path.toURI}'")
sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS noscan")
@@ -221,7 +214,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
s"""
|CREATE TABLE $sourceTableName (key STRING, value STRING)
|PARTITIONED BY (ds STRING)
- |LOCATION '$path'
+ |LOCATION '${path.toURI}'
""".stripMargin)
val partitionDates = List("2010-01-01", "2010-01-02", "2010-01-03")
@@ -238,7 +231,7 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
s"""
|CREATE TABLE $tableName (key STRING, value STRING)
|PARTITIONED BY (ds STRING)
- |LOCATION '$path'
+ |LOCATION '${path.toURI}'
""".stripMargin)
// Register only one of the partitions found on disk
@@ -256,6 +249,259 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
}
}
+ test("analyze single partition") {
+ val tableName = "analyzeTable_part"
+
+ def queryStats(ds: String): CatalogStatistics = {
+ val partition =
+ spark.sessionState.catalog.getPartition(TableIdentifier(tableName), Map("ds" -> ds))
+ partition.stats.get
+ }
+
+ def createPartition(ds: String, query: String): Unit = {
+ sql(s"INSERT INTO TABLE $tableName PARTITION (ds='$ds') $query")
+ }
+
+ withTable(tableName) {
+ sql(s"CREATE TABLE $tableName (key STRING, value STRING) PARTITIONED BY (ds STRING)")
+
+ createPartition("2010-01-01", "SELECT '1', 'A' from src")
+ createPartition("2010-01-02", "SELECT '1', 'A' from src UNION ALL SELECT '1', 'A' from src")
+ createPartition("2010-01-03", "SELECT '1', 'A' from src")
+
+ sql(s"ANALYZE TABLE $tableName PARTITION (ds='2010-01-01') COMPUTE STATISTICS NOSCAN")
+
+ sql(s"ANALYZE TABLE $tableName PARTITION (ds='2010-01-02') COMPUTE STATISTICS NOSCAN")
+
+ assert(queryStats("2010-01-01").rowCount === None)
+ assert(queryStats("2010-01-01").sizeInBytes === 2000)
+
+ assert(queryStats("2010-01-02").rowCount === None)
+ assert(queryStats("2010-01-02").sizeInBytes === 2*2000)
+
+ sql(s"ANALYZE TABLE $tableName PARTITION (ds='2010-01-01') COMPUTE STATISTICS")
+
+ sql(s"ANALYZE TABLE $tableName PARTITION (ds='2010-01-02') COMPUTE STATISTICS")
+
+ assert(queryStats("2010-01-01").rowCount.get === 500)
+ assert(queryStats("2010-01-01").sizeInBytes === 2000)
+
+ assert(queryStats("2010-01-02").rowCount.get === 2*500)
+ assert(queryStats("2010-01-02").sizeInBytes === 2*2000)
+ }
+ }
+
+ test("analyze a set of partitions") {
+ val tableName = "analyzeTable_part"
+
+ def queryStats(ds: String, hr: String): Option[CatalogStatistics] = {
+ val tableId = TableIdentifier(tableName)
+ val partition =
+ spark.sessionState.catalog.getPartition(tableId, Map("ds" -> ds, "hr" -> hr))
+ partition.stats
+ }
+
+ def assertPartitionStats(
+ ds: String,
+ hr: String,
+ rowCount: Option[BigInt],
+ sizeInBytes: BigInt): Unit = {
+ val stats = queryStats(ds, hr).get
+ assert(stats.rowCount === rowCount)
+ assert(stats.sizeInBytes === sizeInBytes)
+ }
+
+ def createPartition(ds: String, hr: Int, query: String): Unit = {
+ sql(s"INSERT INTO TABLE $tableName PARTITION (ds='$ds', hr=$hr) $query")
+ }
+
+ withTable(tableName) {
+ sql(s"CREATE TABLE $tableName (key STRING, value STRING) PARTITIONED BY (ds STRING, hr INT)")
+
+ createPartition("2010-01-01", 10, "SELECT '1', 'A' from src")
+ createPartition("2010-01-01", 11, "SELECT '1', 'A' from src")
+ createPartition("2010-01-02", 10, "SELECT '1', 'A' from src")
+ createPartition("2010-01-02", 11,
+ "SELECT '1', 'A' from src UNION ALL SELECT '1', 'A' from src")
+
+ sql(s"ANALYZE TABLE $tableName PARTITION (ds='2010-01-01') COMPUTE STATISTICS NOSCAN")
+
+ assertPartitionStats("2010-01-01", "10", rowCount = None, sizeInBytes = 2000)
+ assertPartitionStats("2010-01-01", "11", rowCount = None, sizeInBytes = 2000)
+ assert(queryStats("2010-01-02", "10") === None)
+ assert(queryStats("2010-01-02", "11") === None)
+
+ sql(s"ANALYZE TABLE $tableName PARTITION (ds='2010-01-02') COMPUTE STATISTICS NOSCAN")
+
+ assertPartitionStats("2010-01-01", "10", rowCount = None, sizeInBytes = 2000)
+ assertPartitionStats("2010-01-01", "11", rowCount = None, sizeInBytes = 2000)
+ assertPartitionStats("2010-01-02", "10", rowCount = None, sizeInBytes = 2000)
+ assertPartitionStats("2010-01-02", "11", rowCount = None, sizeInBytes = 2*2000)
+
+ sql(s"ANALYZE TABLE $tableName PARTITION (ds='2010-01-01') COMPUTE STATISTICS")
+
+ assertPartitionStats("2010-01-01", "10", rowCount = Some(500), sizeInBytes = 2000)
+ assertPartitionStats("2010-01-01", "11", rowCount = Some(500), sizeInBytes = 2000)
+ assertPartitionStats("2010-01-02", "10", rowCount = None, sizeInBytes = 2000)
+ assertPartitionStats("2010-01-02", "11", rowCount = None, sizeInBytes = 2*2000)
+
+ sql(s"ANALYZE TABLE $tableName PARTITION (ds='2010-01-02') COMPUTE STATISTICS")
+
+ assertPartitionStats("2010-01-01", "10", rowCount = Some(500), sizeInBytes = 2000)
+ assertPartitionStats("2010-01-01", "11", rowCount = Some(500), sizeInBytes = 2000)
+ assertPartitionStats("2010-01-02", "10", rowCount = Some(500), sizeInBytes = 2000)
+ assertPartitionStats("2010-01-02", "11", rowCount = Some(2*500), sizeInBytes = 2*2000)
+ }
+ }
+
+ test("analyze all partitions") {
+ val tableName = "analyzeTable_part"
+
+ def assertPartitionStats(
+ ds: String,
+ hr: String,
+ rowCount: Option[BigInt],
+ sizeInBytes: BigInt): Unit = {
+ val stats = spark.sessionState.catalog.getPartition(TableIdentifier(tableName),
+ Map("ds" -> ds, "hr" -> hr)).stats.get
+ assert(stats.rowCount === rowCount)
+ assert(stats.sizeInBytes === sizeInBytes)
+ }
+
+ def createPartition(ds: String, hr: Int, query: String): Unit = {
+ sql(s"INSERT INTO TABLE $tableName PARTITION (ds='$ds', hr=$hr) $query")
+ }
+
+ withTable(tableName) {
+ sql(s"CREATE TABLE $tableName (key STRING, value STRING) PARTITIONED BY (ds STRING, hr INT)")
+
+ createPartition("2010-01-01", 10, "SELECT '1', 'A' from src")
+ createPartition("2010-01-01", 11, "SELECT '1', 'A' from src")
+ createPartition("2010-01-02", 10, "SELECT '1', 'A' from src")
+ createPartition("2010-01-02", 11,
+ "SELECT '1', 'A' from src UNION ALL SELECT '1', 'A' from src")
+
+ sql(s"ANALYZE TABLE $tableName PARTITION (ds, hr) COMPUTE STATISTICS NOSCAN")
+
+ assertPartitionStats("2010-01-01", "10", rowCount = None, sizeInBytes = 2000)
+ assertPartitionStats("2010-01-01", "11", rowCount = None, sizeInBytes = 2000)
+      assertPartitionStats("2010-01-02", "10", rowCount = None, sizeInBytes = 2000)
+ assertPartitionStats("2010-01-02", "11", rowCount = None, sizeInBytes = 2*2000)
+
+ sql(s"ANALYZE TABLE $tableName PARTITION (ds, hr) COMPUTE STATISTICS")
+
+ assertPartitionStats("2010-01-01", "10", rowCount = Some(500), sizeInBytes = 2000)
+ assertPartitionStats("2010-01-01", "11", rowCount = Some(500), sizeInBytes = 2000)
+      assertPartitionStats("2010-01-02", "10", rowCount = Some(500), sizeInBytes = 2000)
+ assertPartitionStats("2010-01-02", "11", rowCount = Some(2*500), sizeInBytes = 2*2000)
+ }
+ }
+
+ test("analyze partitions for an empty table") {
+ val tableName = "analyzeTable_part"
+
+ withTable(tableName) {
+ sql(s"CREATE TABLE $tableName (key STRING, value STRING) PARTITIONED BY (ds STRING)")
+
+ // make sure there is no exception
+ sql(s"ANALYZE TABLE $tableName PARTITION (ds) COMPUTE STATISTICS NOSCAN")
+
+ // make sure there is no exception
+ sql(s"ANALYZE TABLE $tableName PARTITION (ds) COMPUTE STATISTICS")
+ }
+ }
+
+ test("analyze partitions case sensitivity") {
+ val tableName = "analyzeTable_part"
+ withTable(tableName) {
+ sql(s"CREATE TABLE $tableName (key STRING, value STRING) PARTITIONED BY (ds STRING)")
+
+ sql(s"INSERT INTO TABLE $tableName PARTITION (ds='2010-01-01') SELECT * FROM src")
+
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+ sql(s"ANALYZE TABLE $tableName PARTITION (DS='2010-01-01') COMPUTE STATISTICS")
+ }
+
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+ val message = intercept[AnalysisException] {
+ sql(s"ANALYZE TABLE $tableName PARTITION (DS='2010-01-01') COMPUTE STATISTICS")
+ }.getMessage
+ assert(message.contains(
+ s"DS is not a valid partition column in table `default`.`${tableName.toLowerCase}`"))
+ }
+ }
+ }
+
+ test("analyze partial partition specifications") {
+
+ val tableName = "analyzeTable_part"
+
+ def assertAnalysisException(partitionSpec: String): Unit = {
+ val message = intercept[AnalysisException] {
+ sql(s"ANALYZE TABLE $tableName $partitionSpec COMPUTE STATISTICS")
+ }.getMessage
+ assert(message.contains("The list of partition columns with values " +
+ s"in partition specification for table '${tableName.toLowerCase}' in database 'default' " +
+ "is not a prefix of the list of partition columns defined in the table schema"))
+ }
+
+ withTable(tableName) {
+ sql(
+ s"""
+ |CREATE TABLE $tableName (key STRING, value STRING)
+ |PARTITIONED BY (a STRING, b INT, c STRING)
+ """.stripMargin)
+
+ sql(s"INSERT INTO TABLE $tableName PARTITION (a='a1', b=10, c='c1') SELECT * FROM src")
+
+ sql(s"ANALYZE TABLE $tableName PARTITION (a='a1') COMPUTE STATISTICS")
+ sql(s"ANALYZE TABLE $tableName PARTITION (a='a1', b=10) COMPUTE STATISTICS")
+ sql(s"ANALYZE TABLE $tableName PARTITION (A='a1', b=10) COMPUTE STATISTICS")
+ sql(s"ANALYZE TABLE $tableName PARTITION (b=10, a='a1') COMPUTE STATISTICS")
+ sql(s"ANALYZE TABLE $tableName PARTITION (b=10, A='a1') COMPUTE STATISTICS")
+
+ assertAnalysisException("PARTITION (b=10)")
+ assertAnalysisException("PARTITION (a, b=10)")
+ assertAnalysisException("PARTITION (b=10, c='c1')")
+ assertAnalysisException("PARTITION (a, b=10, c='c1')")
+ assertAnalysisException("PARTITION (c='c1')")
+ assertAnalysisException("PARTITION (a, b, c='c1')")
+ assertAnalysisException("PARTITION (a='a1', c='c1')")
+ assertAnalysisException("PARTITION (a='a1', b, c='c1')")
+ }
+ }
+
+ test("analyze non-existent partition") {
+
+ def assertAnalysisException(analyzeCommand: String, errorMessage: String): Unit = {
+ val message = intercept[AnalysisException] {
+ sql(analyzeCommand)
+ }.getMessage
+ assert(message.contains(errorMessage))
+ }
+
+ val tableName = "analyzeTable_part"
+ withTable(tableName) {
+ sql(s"CREATE TABLE $tableName (key STRING, value STRING) PARTITIONED BY (ds STRING)")
+
+ sql(s"INSERT INTO TABLE $tableName PARTITION (ds='2010-01-01') SELECT * FROM src")
+
+ assertAnalysisException(
+ s"ANALYZE TABLE $tableName PARTITION (hour=20) COMPUTE STATISTICS",
+ s"hour is not a valid partition column in table `default`.`${tableName.toLowerCase}`"
+ )
+
+ assertAnalysisException(
+ s"ANALYZE TABLE $tableName PARTITION (hour) COMPUTE STATISTICS",
+ s"hour is not a valid partition column in table `default`.`${tableName.toLowerCase}`"
+ )
+
+ intercept[NoSuchPartitionException] {
+ sql(s"ANALYZE TABLE $tableName PARTITION (ds='2011-02-30') COMPUTE STATISTICS")
+ }
+ }
+ }
+
test("test table-level statistics for hive tables created in HiveExternalCatalog") {
val textTable = "textTable"
withTable(textTable) {
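
The block of new tests above all exercises the same flow: `ANALYZE TABLE ... PARTITION (...) COMPUTE STATISTICS [NOSCAN]` computes statistics for the partitions matched by a full or partial partition spec, and the results are readable from each partition's `CatalogStatistics`. A condensed sketch of that round trip, with illustrative names:

```scala
import org.apache.spark.sql.catalyst.TableIdentifier

// NOSCAN fills only sizeInBytes; the full form also fills rowCount.
spark.sql("ANALYZE TABLE analyzeTable_part PARTITION (ds='2010-01-01') COMPUTE STATISTICS")

val stats = spark.sessionState.catalog
  .getPartition(TableIdentifier("analyzeTable_part"), Map("ds" -> "2010-01-01"))
  .stats
assert(stats.exists(_.rowCount.isDefined))
```
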
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/TestHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/TestHiveSuite.scala
index 193fa83dbad99..72f8e8ff7c688 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/TestHiveSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/TestHiveSuite.scala
@@ -42,4 +42,8 @@ class TestHiveSuite extends TestHiveSingleton with SQLTestUtils {
}
testHiveSparkSession.reset()
}
+
+ test("SPARK-15887: hive-site.xml should be loaded") {
+ assert(hiveClient.getConf("hive.in.test", "") == "true")
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala
index 986c6675cbb63..ed475a0261b0b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveVersionSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.hive.client
import org.apache.hadoop.conf.Configuration
+import org.scalactic.source.Position
import org.scalatest.Tag
import org.apache.spark.SparkFunSuite
@@ -40,7 +41,8 @@ private[client] abstract class HiveVersionSuite(version: String) extends SparkFu
override def suiteName: String = s"${super.suiteName}($version)"
- override protected def test(testName: String, testTags: Tag*)(testFun: => Unit): Unit = {
+ override protected def test(testName: String, testTags: Tag*)(testFun: => Any)
+ (implicit pos: Position): Unit = {
super.test(s"$version: $testName", testTags: _*)(testFun)
}
}
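
The signature change here tracks newer ScalaTest versions, where `test` takes the body as `=> Any` plus an implicit `org.scalactic.source.Position`; an override without the implicit parameter list no longer overrides anything. The same pattern, sketched for any suite that wants to prefix test names (the class name is hypothetical):

```scala
import org.scalactic.source.Position
import org.scalatest.Tag

import org.apache.spark.SparkFunSuite

abstract class PrefixedFunSuite(prefix: String) extends SparkFunSuite {
  // Delegate to the parent test(), keeping the tags and the caller's source position.
  override protected def test(testName: String, testTags: Tag*)(testFun: => Any)
      (implicit pos: Position): Unit = {
    super.test(s"$prefix: $testName", testTags: _*)(testFun)
  }
}
```
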
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
index 82fbdd645ebe0..1d9c8da996fea 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
@@ -21,7 +21,6 @@ import java.io.{ByteArrayOutputStream, File, PrintStream}
import java.net.URI
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.hive.common.StatsSetupConst
import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat
import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
import org.apache.hadoop.mapred.TextInputFormat
@@ -233,12 +232,49 @@ class VersionsSuite extends SparkFunSuite with Logging {
assert(client.getTable("default", "src").properties.contains("changed"))
}
- test(s"$version: alterTable(tableName: String, table: CatalogTable)") {
+ test(s"$version: alterTable(dbName: String, tableName: String, table: CatalogTable)") {
val newTable = client.getTable("default", "src").copy(properties = Map("changedAgain" -> ""))
- client.alterTable("src", newTable)
+ client.alterTable("default", "src", newTable)
assert(client.getTable("default", "src").properties.contains("changedAgain"))
}
+ test(s"$version: alterTable - rename") {
+ val newTable = client.getTable("default", "src")
+ .copy(identifier = TableIdentifier("tgt", database = Some("default")))
+ assert(!client.tableExists("default", "tgt"))
+
+ client.alterTable("default", "src", newTable)
+
+ assert(client.tableExists("default", "tgt"))
+ assert(!client.tableExists("default", "src"))
+ }
+
+ test(s"$version: alterTable - change database") {
+ val tempDB = CatalogDatabase(
+ "temporary", description = "test create", tempDatabasePath, Map())
+ client.createDatabase(tempDB, ignoreIfExists = true)
+
+ val newTable = client.getTable("default", "tgt")
+ .copy(identifier = TableIdentifier("tgt", database = Some("temporary")))
+ assert(!client.tableExists("temporary", "tgt"))
+
+ client.alterTable("default", "tgt", newTable)
+
+ assert(client.tableExists("temporary", "tgt"))
+ assert(!client.tableExists("default", "tgt"))
+ }
+
+ test(s"$version: alterTable - change database and table names") {
+ val newTable = client.getTable("temporary", "tgt")
+ .copy(identifier = TableIdentifier("src", database = Some("default")))
+ assert(!client.tableExists("default", "src"))
+
+ client.alterTable("temporary", "tgt", newTable)
+
+ assert(client.tableExists("default", "src"))
+ assert(!client.tableExists("temporary", "tgt"))
+ }
+
test(s"$version: listTables(database)") {
assert(client.listTables("default") === Seq("src", "temporary"))
}
@@ -697,6 +733,114 @@ class VersionsSuite extends SparkFunSuite with Logging {
assert(versionSpark.table("t1").collect() === Array(Row(2)))
}
}
+
+ test(s"$version: Decimal support of Avro Hive serde") {
+ val tableName = "tab1"
+ // TODO: add the other logical types. For details, see the link:
+ // https://avro.apache.org/docs/1.8.1/spec.html#Logical+Types
+ val avroSchema =
+ """{
+ | "name": "test_record",
+ | "type": "record",
+ | "fields": [ {
+ | "name": "f0",
+ | "type": [
+ | "null",
+ | {
+ | "precision": 38,
+ | "scale": 2,
+ | "type": "bytes",
+ | "logicalType": "decimal"
+ | }
+ | ]
+ | } ]
+ |}
+ """.stripMargin
+
+ Seq(true, false).foreach { isPartitioned =>
+ withTable(tableName) {
+ val partitionClause = if (isPartitioned) "PARTITIONED BY (ds STRING)" else ""
+ // Creates the (non-)partitioned Avro table
+ versionSpark.sql(
+ s"""
+ |CREATE TABLE $tableName
+ |$partitionClause
+ |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe'
+ |STORED AS
+ | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat'
+ | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat'
+ |TBLPROPERTIES ('avro.schema.literal' = '$avroSchema')
+ """.stripMargin
+ )
+
+ val errorMsg = "data type mismatch: cannot cast DecimalType(2,1) to BinaryType"
+
+ if (isPartitioned) {
+ val insertStmt = s"INSERT OVERWRITE TABLE $tableName partition (ds='a') SELECT 1.3"
+ if (version == "0.12" || version == "0.13") {
+ val e = intercept[AnalysisException](versionSpark.sql(insertStmt)).getMessage
+ assert(e.contains(errorMsg))
+ } else {
+ versionSpark.sql(insertStmt)
+ assert(versionSpark.table(tableName).collect() ===
+ versionSpark.sql("SELECT 1.30, 'a'").collect())
+ }
+ } else {
+ val insertStmt = s"INSERT OVERWRITE TABLE $tableName SELECT 1.3"
+ if (version == "0.12" || version == "0.13") {
+ val e = intercept[AnalysisException](versionSpark.sql(insertStmt)).getMessage
+ assert(e.contains(errorMsg))
+ } else {
+ versionSpark.sql(insertStmt)
+ assert(versionSpark.table(tableName).collect() ===
+ versionSpark.sql("SELECT 1.30").collect())
+ }
+ }
+ }
+ }
+ }
+
+ test(s"$version: read avro file containing decimal") {
+ val url = Thread.currentThread().getContextClassLoader.getResource("avroDecimal")
+ val location = new File(url.getFile)
+
+ val tableName = "tab1"
+ val avroSchema =
+ """{
+ | "name": "test_record",
+ | "type": "record",
+ | "fields": [ {
+ | "name": "f0",
+ | "type": [
+ | "null",
+ | {
+ | "precision": 38,
+ | "scale": 2,
+ | "type": "bytes",
+ | "logicalType": "decimal"
+ | }
+ | ]
+ | } ]
+ |}
+ """.stripMargin
+ withTable(tableName) {
+ versionSpark.sql(
+ s"""
+ |CREATE TABLE $tableName
+ |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe'
+ |WITH SERDEPROPERTIES ('respectSparkSchema' = 'true')
+ |STORED AS
+ | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat'
+ | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat'
+ |LOCATION '$location'
+ |TBLPROPERTIES ('avro.schema.literal' = '$avroSchema')
+ """.stripMargin
+ )
+ assert(versionSpark.table(tableName).collect() ===
+ versionSpark.sql("SELECT 1.30").collect())
+ }
+ }
+
// TODO: add more tests.
}
}
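
The `alterTable` change above makes the source database explicit, which is what the new rename and cross-database tests rely on: the table is looked up under the given database and name, then written back under whatever identifier the new `CatalogTable` carries. A hedged sketch of that call pattern, using the suite's `client` fixture:

```scala
import org.apache.spark.sql.catalyst.TableIdentifier

// Move default.src to temporary.tgt by altering it under a new identifier.
val moved = client.getTable("default", "src")
  .copy(identifier = TableIdentifier("tgt", database = Some("temporary")))
client.alterTable("default", "src", moved)
assert(client.tableExists("temporary", "tgt") && !client.tableExists("default", "src"))
```
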
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
index 4c2fea3eb68bc..ee64bc9f9ee04 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
@@ -1998,4 +1998,15 @@ class HiveDDLSuite
sq.stop()
}
}
+
+ test("table name with schema") {
+ // regression test for SPARK-11778
+ withDatabase("usrdb") {
+ spark.sql("create schema usrdb")
+ withTable("usrdb.test") {
+ spark.sql("create table usrdb.test(c int)")
+ spark.read.table("usrdb.test")
+ }
+ }
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 799abc1d0c42f..2ea51791d0f79 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -370,21 +370,23 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd
""".stripMargin)
test("SPARK-7270: consider dynamic partition when comparing table output") {
- sql(s"CREATE TABLE test_partition (a STRING) PARTITIONED BY (b BIGINT, c STRING)")
- sql(s"CREATE TABLE ptest (a STRING, b BIGINT, c STRING)")
+ withTable("test_partition", "ptest") {
+ sql(s"CREATE TABLE test_partition (a STRING) PARTITIONED BY (b BIGINT, c STRING)")
+ sql(s"CREATE TABLE ptest (a STRING, b BIGINT, c STRING)")
- val analyzedPlan = sql(
- """
+ val analyzedPlan = sql(
+ """
|INSERT OVERWRITE table test_partition PARTITION (b=1, c)
|SELECT 'a', 'c' from ptest
""".stripMargin).queryExecution.analyzed
- assertResult(false, "Incorrect cast detected\n" + analyzedPlan) {
+ assertResult(false, "Incorrect cast detected\n" + analyzedPlan) {
var hasCast = false
- analyzedPlan.collect {
- case p: Project => p.transformExpressionsUp { case c: Cast => hasCast = true; c }
+ analyzedPlan.collect {
+ case p: Project => p.transformExpressionsUp { case c: Cast => hasCast = true; c }
+ }
+ hasCast
}
- hasCast
}
}
@@ -435,13 +437,13 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd
test("transform with SerDe2") {
assume(TestUtils.testCommandAvailable("/bin/bash"))
+ withTable("small_src") {
+ sql("CREATE TABLE small_src(key INT, value STRING)")
+ sql("INSERT OVERWRITE TABLE small_src SELECT key, value FROM src LIMIT 10")
- sql("CREATE TABLE small_src(key INT, value STRING)")
- sql("INSERT OVERWRITE TABLE small_src SELECT key, value FROM src LIMIT 10")
-
- val expected = sql("SELECT key FROM small_src").collect().head
- val res = sql(
- """
+ val expected = sql("SELECT key FROM small_src").collect().head
+ val res = sql(
+ """
|SELECT TRANSFORM (key) ROW FORMAT SERDE
|'org.apache.hadoop.hive.serde2.avro.AvroSerDe'
|WITH SERDEPROPERTIES ('avro.schema.literal'='{"namespace":
@@ -453,7 +455,8 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd
|FROM small_src
""".stripMargin.replaceAll(System.lineSeparator(), " ")).collect().head
- assert(expected(0) === res(0))
+ assert(expected(0) === res(0))
+ }
}
createQueryTest("transform with SerDe3",
@@ -780,22 +783,26 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd
test("Exactly once semantics for DDL and command statements") {
val tableName = "test_exactly_once"
- val q0 = sql(s"CREATE TABLE $tableName(key INT, value STRING)")
+ withTable(tableName) {
+ val q0 = sql(s"CREATE TABLE $tableName(key INT, value STRING)")
- // If the table was not created, the following assertion would fail
- assert(Try(table(tableName)).isSuccess)
+ // If the table was not created, the following assertion would fail
+ assert(Try(table(tableName)).isSuccess)
- // If the CREATE TABLE command got executed again, the following assertion would fail
- assert(Try(q0.count()).isSuccess)
+ // If the CREATE TABLE command got executed again, the following assertion would fail
+ assert(Try(q0.count()).isSuccess)
+ }
}
test("SPARK-2263: Insert Map values") {
-    sql("CREATE TABLE m(value MAP<INT, STRING>)")
- sql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10")
- sql("SELECT * FROM m").collect().zip(sql("SELECT * FROM src LIMIT 10").collect()).foreach {
- case (Row(map: Map[_, _]), Row(key: Int, value: String)) =>
- assert(map.size === 1)
- assert(map.head === ((key, value)))
+ withTable("m") {
+      sql("CREATE TABLE m(value MAP<INT, STRING>)")
+ sql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10")
+ sql("SELECT * FROM m").collect().zip(sql("SELECT * FROM src LIMIT 10").collect()).foreach {
+ case (Row(map: Map[_, _]), Row(key: Int, value: String)) =>
+ assert(map.size === 1)
+ assert(map.head === ((key, value)))
+ }
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
index 7803ac39e508b..1c9f00141ae1d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala
@@ -17,15 +17,23 @@
package org.apache.spark.sql.hive.execution
+import java.net.URI
+
import org.scalatest.BeforeAndAfterAll
+import org.apache.spark.sql.{AnalysisException, SaveMode}
+import org.apache.spark.sql.catalyst.catalog.CatalogTable
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils}
+import org.apache.spark.sql.execution.datasources.CreateTable
import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper
import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.types.StructType
/**
* A set of tests that validates support for Hive SerDe.
*/
-class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll {
+class HiveSerDeSuite extends HiveComparisonTest with PlanTest with BeforeAndAfterAll {
override def beforeAll(): Unit = {
import TestHive._
import org.apache.hadoop.hive.serde2.RegexSerDe
@@ -60,4 +68,127 @@ class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll {
val serdeinsRes = InputOutputMetricsHelper.run(sql("select * from serdeins").toDF())
assert(serdeinsRes === (serdeinsCnt, 0L, serdeinsCnt) :: Nil)
}
+
+ private def extractTableDesc(sql: String): (CatalogTable, Boolean) = {
+ TestHive.sessionState.sqlParser.parsePlan(sql).collect {
+ case CreateTable(tableDesc, mode, _) => (tableDesc, mode == SaveMode.Ignore)
+ }.head
+ }
+
+ private def analyzeCreateTable(sql: String): CatalogTable = {
+ TestHive.sessionState.analyzer.execute(TestHive.sessionState.sqlParser.parsePlan(sql)).collect {
+ case CreateTableCommand(tableDesc, _) => tableDesc
+ }.head
+ }
+
+ test("Test the default fileformat for Hive-serde tables") {
+ withSQLConf("hive.default.fileformat" -> "orc") {
+ val (desc, exists) = extractTableDesc("CREATE TABLE IF NOT EXISTS fileformat_test (id int)")
+ assert(exists)
+ assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"))
+ assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat"))
+ assert(desc.storage.serde == Some("org.apache.hadoop.hive.ql.io.orc.OrcSerde"))
+ }
+
+ withSQLConf("hive.default.fileformat" -> "parquet") {
+ val (desc, exists) = extractTableDesc("CREATE TABLE IF NOT EXISTS fileformat_test (id int)")
+ assert(exists)
+ val input = desc.storage.inputFormat
+ val output = desc.storage.outputFormat
+ val serde = desc.storage.serde
+ assert(input == Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"))
+ assert(output == Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat"))
+ assert(serde == Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"))
+ }
+ }
+
+ test("create hive serde table with new syntax - basic") {
+ val sql =
+ """
+ |CREATE TABLE t
+ |(id int, name string COMMENT 'blabla')
+ |USING hive
+ |OPTIONS (fileFormat 'parquet', my_prop 1)
+ |LOCATION '/tmp/file'
+ |COMMENT 'BLABLA'
+ """.stripMargin
+
+ val table = analyzeCreateTable(sql)
+ assert(table.schema == new StructType()
+ .add("id", "int")
+ .add("name", "string", nullable = true, comment = "blabla"))
+ assert(table.provider == Some(DDLUtils.HIVE_PROVIDER))
+ assert(table.storage.locationUri == Some(new URI("/tmp/file")))
+ assert(table.storage.properties == Map("my_prop" -> "1"))
+ assert(table.comment == Some("BLABLA"))
+
+ assert(table.storage.inputFormat ==
+ Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"))
+ assert(table.storage.outputFormat ==
+ Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat"))
+ assert(table.storage.serde ==
+ Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"))
+ }
+
+ test("create hive serde table with new syntax - with partition and bucketing") {
+ val v1 = "CREATE TABLE t (c1 int, c2 int) USING hive PARTITIONED BY (c2)"
+ val table = analyzeCreateTable(v1)
+ assert(table.schema == new StructType().add("c1", "int").add("c2", "int"))
+ assert(table.partitionColumnNames == Seq("c2"))
+ // check the default formats
+ assert(table.storage.serde == Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"))
+ assert(table.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat"))
+ assert(table.storage.outputFormat ==
+ Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat"))
+
+ val v2 = "CREATE TABLE t (c1 int, c2 int) USING hive CLUSTERED BY (c2) INTO 4 BUCKETS"
+ val e2 = intercept[AnalysisException](analyzeCreateTable(v2))
+ assert(e2.message.contains("Creating bucketed Hive serde table is not supported yet"))
+
+ val v3 =
+ """
+ |CREATE TABLE t (c1 int, c2 int) USING hive
+ |PARTITIONED BY (c2)
+ |CLUSTERED BY (c2) INTO 4 BUCKETS""".stripMargin
+ val e3 = intercept[AnalysisException](analyzeCreateTable(v3))
+ assert(e3.message.contains("Creating bucketed Hive serde table is not supported yet"))
+ }
+
+ test("create hive serde table with new syntax - Hive options error checking") {
+ val v1 = "CREATE TABLE t (c1 int) USING hive OPTIONS (inputFormat 'abc')"
+ val e1 = intercept[IllegalArgumentException](analyzeCreateTable(v1))
+ assert(e1.getMessage.contains("Cannot specify only inputFormat or outputFormat"))
+
+ val v2 = "CREATE TABLE t (c1 int) USING hive OPTIONS " +
+ "(fileFormat 'x', inputFormat 'a', outputFormat 'b')"
+ val e2 = intercept[IllegalArgumentException](analyzeCreateTable(v2))
+ assert(e2.getMessage.contains(
+ "Cannot specify fileFormat and inputFormat/outputFormat together"))
+
+ val v3 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'parquet', serde 'a')"
+ val e3 = intercept[IllegalArgumentException](analyzeCreateTable(v3))
+ assert(e3.getMessage.contains("fileFormat 'parquet' already specifies a serde"))
+
+ val v4 = "CREATE TABLE t (c1 int) USING hive OPTIONS (serde 'a', fieldDelim ' ')"
+ val e4 = intercept[IllegalArgumentException](analyzeCreateTable(v4))
+ assert(e4.getMessage.contains("Cannot specify delimiters with a custom serde"))
+
+ val v5 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fieldDelim ' ')"
+ val e5 = intercept[IllegalArgumentException](analyzeCreateTable(v5))
+ assert(e5.getMessage.contains("Cannot specify delimiters without fileFormat"))
+
+ val v6 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'parquet', fieldDelim ' ')"
+ val e6 = intercept[IllegalArgumentException](analyzeCreateTable(v6))
+ assert(e6.getMessage.contains(
+ "Cannot specify delimiters as they are only compatible with fileFormat 'textfile'"))
+
+ // The value of 'fileFormat' option is case-insensitive.
+ val v7 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'TEXTFILE', lineDelim ',')"
+ val e7 = intercept[IllegalArgumentException](analyzeCreateTable(v7))
+ assert(e7.getMessage.contains("Hive data source only support newline '\\n' as line delimiter"))
+
+ val v8 = "CREATE TABLE t (c1 int) USING hive OPTIONS (fileFormat 'wrong')"
+ val e8 = intercept[IllegalArgumentException](analyzeCreateTable(v8))
+ assert(e8.getMessage.contains("invalid fileFormat: 'wrong'"))
+ }
}
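
The option-validation tests above encode the rules of the `CREATE TABLE ... USING hive` syntax: `fileFormat` selects the input/output format and serde together, `inputFormat`/`outputFormat` must be given as a pair, and delimiter options are accepted only with the textfile format (with '\n' as the only line delimiter). Two statements that satisfy those rules, as an illustrative sketch:

```scala
// Storage format chosen through a single fileFormat option.
spark.sql("CREATE TABLE t_parquet (id INT) USING hive OPTIONS (fileFormat 'parquet')")

// Field delimiters are only honoured together with the textfile format.
spark.sql(
  "CREATE TABLE t_text (id INT) USING hive " +
    "OPTIONS (fileFormat 'textfile', fieldDelim ',')")
```
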
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
index ae64cb3210b53..3f9bb8de42e09 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
@@ -81,14 +81,16 @@ class HiveTableScanSuite extends HiveComparisonTest with SQLTestUtils with TestH
}
test("Spark-4959 Attributes are case sensitive when using a select query from a projection") {
- sql("create table spark_4959 (col1 string)")
- sql("""insert into table spark_4959 select "hi" from src limit 1""")
- table("spark_4959").select(
- 'col1.as("CaseSensitiveColName"),
- 'col1.as("CaseSensitiveColName2")).createOrReplaceTempView("spark_4959_2")
-
- assert(sql("select CaseSensitiveColName from spark_4959_2").head() === Row("hi"))
- assert(sql("select casesensitivecolname from spark_4959_2").head() === Row("hi"))
+ withTable("spark_4959") {
+ sql("create table spark_4959 (col1 string)")
+ sql("""insert into table spark_4959 select "hi" from src limit 1""")
+ table("spark_4959").select(
+ 'col1.as("CaseSensitiveColName"),
+ 'col1.as("CaseSensitiveColName2")).createOrReplaceTempView("spark_4959_2")
+
+ assert(sql("select CaseSensitiveColName from spark_4959_2").head() === Row("hi"))
+ assert(sql("select casesensitivecolname from spark_4959_2").head() === Row("hi"))
+ }
}
private def checkNumScannedPartitions(stmt: String, expectedNumParts: Int): Unit = {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
index 479ca1e8def56..8986fb58c6460 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala
@@ -26,6 +26,7 @@ import org.apache.hadoop.hive.ql.util.JavaDataModel
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory}
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo
+import test.org.apache.spark.sql.MyDoubleAvg
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec
@@ -86,6 +87,18 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
))
}
+ test("call JAVA UDAF") {
+ withTempView("temp") {
+ withUserDefinedFunction("myDoubleAvg" -> false) {
+ spark.range(1, 10).toDF("value").createOrReplaceTempView("temp")
+ sql(s"CREATE FUNCTION myDoubleAvg AS '${classOf[MyDoubleAvg].getName}'")
+ checkAnswer(
+ spark.sql("SELECT default.myDoubleAvg(value) as my_avg from temp"),
+ Row(105.0))
+ }
+ }
+ }
+
test("non-deterministic children expressions of UDAF") {
withTempView("view1") {
spark.range(1).selectExpr("id as x", "id as y").createTempView("view1")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index cae338c0ab0ae..6198d4963df33 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -74,26 +74,28 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
}
test("hive struct udf") {
- sql(
- """
- |CREATE TABLE hiveUDFTestTable (
-        |  pair STRUCT<id: INT, value: INT>
- |)
- |PARTITIONED BY (partition STRING)
- |ROW FORMAT SERDE '%s'
- |STORED AS SEQUENCEFILE
- """.
- stripMargin.format(classOf[PairSerDe].getName))
-
- val location = Utils.getSparkClassLoader.getResource("data/files/testUDF").getFile
- sql(s"""
- ALTER TABLE hiveUDFTestTable
- ADD IF NOT EXISTS PARTITION(partition='testUDF')
- LOCATION '$location'""")
-
- sql(s"CREATE TEMPORARY FUNCTION testUDF AS '${classOf[PairUDF].getName}'")
- sql("SELECT testUDF(pair) FROM hiveUDFTestTable")
- sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF")
+ withTable("hiveUDFTestTable") {
+ sql(
+ """
+ |CREATE TABLE hiveUDFTestTable (
+          |  pair STRUCT<id: INT, value: INT>
+ |)
+ |PARTITIONED BY (partition STRING)
+ |ROW FORMAT SERDE '%s'
+ |STORED AS SEQUENCEFILE
+ """.
+ stripMargin.format(classOf[PairSerDe].getName))
+
+ val location = Utils.getSparkClassLoader.getResource("data/files/testUDF").getFile
+ sql(s"""
+ ALTER TABLE hiveUDFTestTable
+ ADD IF NOT EXISTS PARTITION(partition='testUDF')
+ LOCATION '$location'""")
+
+ sql(s"CREATE TEMPORARY FUNCTION testUDF AS '${classOf[PairUDF].getName}'")
+ sql("SELECT testUDF(pair) FROM hiveUDFTestTable")
+ sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF")
+ }
}
test("Max/Min on named_struct") {
@@ -404,59 +406,34 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
}
test("Hive UDFs with insufficient number of input arguments should trigger an analysis error") {
- Seq((1, 2)).toDF("a", "b").createOrReplaceTempView("testUDF")
+ withTempView("testUDF") {
+ Seq((1, 2)).toDF("a", "b").createOrReplaceTempView("testUDF")
+
+ def testErrorMsgForFunc(funcName: String, className: String): Unit = {
+ withUserDefinedFunction(funcName -> true) {
+ sql(s"CREATE TEMPORARY FUNCTION $funcName AS '$className'")
+ val message = intercept[AnalysisException] {
+ sql(s"SELECT $funcName() FROM testUDF")
+ }.getMessage
+ assert(message.contains(s"No handler for UDF/UDAF/UDTF '$className'"))
+ }
+ }
- {
// HiveSimpleUDF
- sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'")
- val message = intercept[AnalysisException] {
- sql("SELECT testUDFTwoListList() FROM testUDF")
- }.getMessage
- assert(message.contains("No handler for Hive UDF"))
- sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList")
- }
+ testErrorMsgForFunc("testUDFTwoListList", classOf[UDFTwoListList].getName)
- {
// HiveGenericUDF
- sql(s"CREATE TEMPORARY FUNCTION testUDFAnd AS '${classOf[GenericUDFOPAnd].getName}'")
- val message = intercept[AnalysisException] {
- sql("SELECT testUDFAnd() FROM testUDF")
- }.getMessage
- assert(message.contains("No handler for Hive UDF"))
- sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFAnd")
- }
+ testErrorMsgForFunc("testUDFAnd", classOf[GenericUDFOPAnd].getName)
- {
// Hive UDAF
- sql(s"CREATE TEMPORARY FUNCTION testUDAFPercentile AS '${classOf[UDAFPercentile].getName}'")
- val message = intercept[AnalysisException] {
- sql("SELECT testUDAFPercentile(a) FROM testUDF GROUP BY b")
- }.getMessage
- assert(message.contains("No handler for Hive UDF"))
- sql("DROP TEMPORARY FUNCTION IF EXISTS testUDAFPercentile")
- }
+ testErrorMsgForFunc("testUDAFPercentile", classOf[UDAFPercentile].getName)
- {
// AbstractGenericUDAFResolver
- sql(s"CREATE TEMPORARY FUNCTION testUDAFAverage AS '${classOf[GenericUDAFAverage].getName}'")
- val message = intercept[AnalysisException] {
- sql("SELECT testUDAFAverage() FROM testUDF GROUP BY b")
- }.getMessage
- assert(message.contains("No handler for Hive UDF"))
- sql("DROP TEMPORARY FUNCTION IF EXISTS testUDAFAverage")
- }
+ testErrorMsgForFunc("testUDAFAverage", classOf[GenericUDAFAverage].getName)
- {
- // Hive UDTF
- sql(s"CREATE TEMPORARY FUNCTION testUDTFExplode AS '${classOf[GenericUDTFExplode].getName}'")
- val message = intercept[AnalysisException] {
- sql("SELECT testUDTFExplode() FROM testUDF")
- }.getMessage
- assert(message.contains("No handler for Hive UDF"))
- sql("DROP TEMPORARY FUNCTION IF EXISTS testUDTFExplode")
+      // Hive UDTF
+ testErrorMsgForFunc("testUDTFExplode", classOf[GenericUDTFExplode].getName)
}
-
- spark.catalog.dropTempView("testUDF")
}
test("Hive UDF in group by") {
@@ -621,6 +598,46 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
}
}
}
+
+ test("UDTF") {
+ withUserDefinedFunction("udtf_count2" -> true) {
+ sql(s"ADD JAR ${hiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath}")
+ // The function source code can be found at:
+ // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF
+ sql(
+ """
+ |CREATE TEMPORARY FUNCTION udtf_count2
+ |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'
+ """.stripMargin)
+
+ checkAnswer(
+ sql("SELECT key, cc FROM src LATERAL VIEW udtf_count2(value) dd AS cc"),
+ Row(97, 500) :: Row(97, 500) :: Nil)
+
+ checkAnswer(
+ sql("SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"),
+ Row(3) :: Row(3) :: Nil)
+ }
+ }
+
+ test("permanent UDTF") {
+ withUserDefinedFunction("udtf_count_temp" -> false) {
+ sql(
+ s"""
+ |CREATE FUNCTION udtf_count_temp
+ |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'
+ |USING JAR '${hiveContext.getHiveFile("TestUDTF.jar").toURI}'
+ """.stripMargin)
+
+ checkAnswer(
+ sql("SELECT key, cc FROM src LATERAL VIEW udtf_count_temp(value) dd AS cc"),
+ Row(97, 500) :: Row(97, 500) :: Nil)
+
+ checkAnswer(
+ sql("SELECT udtf_count_temp(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"),
+ Row(3) :: Row(3) :: Nil)
+ }
+ }
}
class TestPair(x: Int, y: Int) extends Writable with Serializable {
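
The two UDTF tests moved into this suite (and deleted from SQLQuerySuite later in this patch) cover both registration paths: a temporary function whose jar is added separately with `ADD JAR`, and a permanent function whose jar is recorded in the `CREATE FUNCTION ... USING JAR` statement itself. A compressed sketch of the two statements, with a placeholder jar path:

```scala
// Temporary function: session-scoped; the jar has to be put on the classpath first.
spark.sql("ADD JAR /path/to/TestUDTF.jar")  // placeholder path
spark.sql(
  "CREATE TEMPORARY FUNCTION udtf_count2 " +
    "AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'")

// Permanent function: registered in the metastore together with its jar location.
spark.sql(
  "CREATE FUNCTION udtf_count_perm " +
    "AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' " +
    "USING JAR '/path/to/TestUDTF.jar'")
```
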
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala
new file mode 100644
index 0000000000000..5c248b9acd04f
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/Hive_2_1_DDLSuite.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import scala.language.existentials
+
+import org.apache.hadoop.conf.Configuration
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.launcher.SparkLauncher
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.catalog._
+import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils}
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.StaticSQLConf._
+import org.apache.spark.sql.types._
+import org.apache.spark.tags.ExtendedHiveTest
+import org.apache.spark.util.Utils
+
+/**
+ * A separate set of DDL tests that uses Hive 2.1 libraries, which behave a little differently
+ * from the built-in ones.
+ */
+@ExtendedHiveTest
+class Hive_2_1_DDLSuite extends SparkFunSuite with TestHiveSingleton with BeforeAndAfterEach
+ with BeforeAndAfterAll {
+
+ // Create a custom HiveExternalCatalog instance with the desired configuration. We cannot
+  // use SparkSession here since there's already an active one managed by the TestHive object.
+ private var catalog = {
+ val warehouse = Utils.createTempDir()
+ val metastore = Utils.createTempDir()
+ metastore.delete()
+ val sparkConf = new SparkConf()
+ .set(SparkLauncher.SPARK_MASTER, "local")
+ .set(WAREHOUSE_PATH.key, warehouse.toURI().toString())
+ .set(CATALOG_IMPLEMENTATION.key, "hive")
+ .set(HiveUtils.HIVE_METASTORE_VERSION.key, "2.1")
+ .set(HiveUtils.HIVE_METASTORE_JARS.key, "maven")
+
+ val hadoopConf = new Configuration()
+ hadoopConf.set("hive.metastore.warehouse.dir", warehouse.toURI().toString())
+ hadoopConf.set("javax.jdo.option.ConnectionURL",
+ s"jdbc:derby:;databaseName=${metastore.getAbsolutePath()};create=true")
+ // These options are needed since the defaults in Hive 2.1 cause exceptions with an
+ // empty metastore db.
+ hadoopConf.set("datanucleus.schema.autoCreateAll", "true")
+ hadoopConf.set("hive.metastore.schema.verification", "false")
+
+ new HiveExternalCatalog(sparkConf, hadoopConf)
+ }
+
+  override def afterEach(): Unit = {
+ catalog.listTables("default").foreach { t =>
+ catalog.dropTable("default", t, true, false)
+ }
+ spark.sessionState.catalog.reset()
+ }
+
+ override def afterAll(): Unit = {
+ catalog = null
+ }
+
+ test("SPARK-21617: ALTER TABLE for non-compatible DataSource tables") {
+ testAlterTable(
+ "t1",
+ "CREATE TABLE t1 (c1 int) USING json",
+ StructType(Array(StructField("c1", IntegerType), StructField("c2", IntegerType))),
+ hiveCompatible = false)
+ }
+
+ test("SPARK-21617: ALTER TABLE for Hive-compatible DataSource tables") {
+ testAlterTable(
+ "t1",
+ "CREATE TABLE t1 (c1 int) USING parquet",
+ StructType(Array(StructField("c1", IntegerType), StructField("c2", IntegerType))))
+ }
+
+ test("SPARK-21617: ALTER TABLE for Hive tables") {
+ testAlterTable(
+ "t1",
+ "CREATE TABLE t1 (c1 int) STORED AS parquet",
+ StructType(Array(StructField("c1", IntegerType), StructField("c2", IntegerType))))
+ }
+
+ test("SPARK-21617: ALTER TABLE with incompatible schema on Hive-compatible table") {
+ val exception = intercept[AnalysisException] {
+ testAlterTable(
+ "t1",
+ "CREATE TABLE t1 (c1 string) USING parquet",
+ StructType(Array(StructField("c2", IntegerType))))
+ }
+ assert(exception.getMessage().contains("types incompatible with the existing columns"))
+ }
+
+ private def testAlterTable(
+ tableName: String,
+ createTableStmt: String,
+ updatedSchema: StructType,
+ hiveCompatible: Boolean = true): Unit = {
+ spark.sql(createTableStmt)
+ val oldTable = spark.sessionState.catalog.externalCatalog.getTable("default", tableName)
+ catalog.createTable(oldTable, true)
+ catalog.alterTableSchema("default", tableName, updatedSchema)
+
+ val updatedTable = catalog.getTable("default", tableName)
+ assert(updatedTable.schema.fieldNames === updatedSchema.fieldNames)
+ }
+
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
index d535bef4cc787..cc592cf6ca629 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala
@@ -162,7 +162,12 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter {
}.head
assert(actualOutputColumns === expectedOutputColumns, "Output columns mismatch")
- assert(actualScannedColumns === expectedScannedColumns, "Scanned columns mismatch")
+
+ // Scanned columns in `HiveTableScanExec` are generated by the `pruneFilterProject` method
+ // in `SparkPlanner`. This method internally uses `AttributeSet.toSeq`, in which
+    // in `SparkPlanner`. This method internally uses `AttributeSet.toSeq`, which returns
+    // the columns sorted by name and expression id, hence the sorted comparison below.
+ "Scanned columns mismatch")
val actualPartitions = actualPartValues.map(_.asScala.mkString(",")).sorted
val expectedPartitions = expectedPartValues.map(_.mkString(",")).sorted
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala
index 24c038587d1d6..022cb7177339d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLMetricsSuite.scala
@@ -17,112 +17,10 @@
package org.apache.spark.sql.hive.execution
-import java.io.File
-
-import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.functions._
+import org.apache.spark.sql.execution.metric.SQLMetricsTestUtils
import org.apache.spark.sql.hive.test.TestHiveSingleton
-import org.apache.spark.sql.test.SQLTestUtils
-import org.apache.spark.util.Utils
-
-class SQLMetricsSuite extends SQLTestUtils with TestHiveSingleton {
- import spark.implicits._
-
- /**
- * Get execution metrics for the SQL execution and verify metrics values.
- *
- * @param metricsValues the expected metric values (numFiles, numPartitions, numOutputRows).
- * @param func the function can produce execution id after running.
- */
- private def verifyWriteDataMetrics(metricsValues: Seq[Int])(func: => Unit): Unit = {
- val previousExecutionIds = spark.sharedState.listener.executionIdToData.keySet
- // Run the given function to trigger query execution.
- func
- spark.sparkContext.listenerBus.waitUntilEmpty(10000)
- val executionIds =
- spark.sharedState.listener.executionIdToData.keySet.diff(previousExecutionIds)
- assert(executionIds.size == 1)
- val executionId = executionIds.head
-
- val executionData = spark.sharedState.listener.getExecution(executionId).get
- val executedNode = executionData.physicalPlanGraph.nodes.head
-
- val metricsNames = Seq(
- "number of written files",
- "number of dynamic part",
- "number of output rows")
-
- val metrics = spark.sharedState.listener.getExecutionMetrics(executionId)
-
- metricsNames.zip(metricsValues).foreach { case (metricsName, expected) =>
- val sqlMetric = executedNode.metrics.find(_.name == metricsName)
- assert(sqlMetric.isDefined)
- val accumulatorId = sqlMetric.get.accumulatorId
- val metricValue = metrics(accumulatorId).replaceAll(",", "").toInt
- assert(metricValue == expected)
- }
-
- val totalNumBytesMetric = executedNode.metrics.find(_.name == "bytes of written output").get
- val totalNumBytes = metrics(totalNumBytesMetric.accumulatorId).replaceAll(",", "").toInt
- assert(totalNumBytes > 0)
- }
-
- private def testMetricsNonDynamicPartition(
- dataFormat: String,
- tableName: String): Unit = {
- withTable(tableName) {
- Seq((1, 2)).toDF("i", "j")
- .write.format(dataFormat).mode("overwrite").saveAsTable(tableName)
-
- val tableLocation =
- new File(spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).location)
- // 2 files, 100 rows, 0 dynamic partition.
- verifyWriteDataMetrics(Seq(2, 0, 100)) {
- (0 until 100).map(i => (i, i + 1)).toDF("i", "j").repartition(2)
- .write.format(dataFormat).mode("overwrite").insertInto(tableName)
- }
- assert(Utils.recursiveList(tableLocation).count(_.getName.startsWith("part-")) == 2)
- }
- }
-
- private def testMetricsDynamicPartition(
- provider: String,
- dataFormat: String,
- tableName: String): Unit = {
- withTempPath { dir =>
- spark.sql(
- s"""
- |CREATE TABLE $tableName(a int, b int)
- |USING $provider
- |PARTITIONED BY(a)
- |LOCATION '${dir.toURI}'
- """.stripMargin)
- val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName))
- assert(table.location == makeQualifiedPath(dir.getAbsolutePath))
-
- val df = spark.range(start = 0, end = 40, step = 1, numPartitions = 1)
- .selectExpr("id a", "id b")
-
- // 40 files, 80 rows, 40 dynamic partitions.
- verifyWriteDataMetrics(Seq(40, 40, 80)) {
- df.union(df).repartition(2, $"a")
- .write
- .format(dataFormat)
- .mode("overwrite")
- .insertInto(tableName)
- }
- assert(Utils.recursiveList(dir).count(_.getName.startsWith("part-")) == 40)
- }
- }
-
- test("writing data out metrics: parquet") {
- testMetricsNonDynamicPartition("parquet", "t1")
- }
-
- test("writing data out metrics with dynamic partition: parquet") {
- testMetricsDynamicPartition("parquet", "parquet", "t1")
- }
+class SQLMetricsSuite extends SQLMetricsTestUtils with TestHiveSingleton {
test("writing data out metrics: hive") {
testMetricsNonDynamicPartition("hive", "t1")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 86b38d04a908f..175a6c6d69aa9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -28,7 +28,7 @@ import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.TestUtils
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry, NoSuchPartitionException}
+import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{CatalogTableType, CatalogUtils, HiveTableRelation}
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
@@ -98,46 +98,6 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
checkAnswer(query1, Row("x1_y1") :: Row("x2_y2") :: Nil)
}
- test("UDTF") {
- withUserDefinedFunction("udtf_count2" -> true) {
- sql(s"ADD JAR ${hiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath()}")
- // The function source code can be found at:
- // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF
- sql(
- """
- |CREATE TEMPORARY FUNCTION udtf_count2
- |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'
- """.stripMargin)
-
- checkAnswer(
- sql("SELECT key, cc FROM src LATERAL VIEW udtf_count2(value) dd AS cc"),
- Row(97, 500) :: Row(97, 500) :: Nil)
-
- checkAnswer(
- sql("SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"),
- Row(3) :: Row(3) :: Nil)
- }
- }
-
- test("permanent UDTF") {
- withUserDefinedFunction("udtf_count_temp" -> false) {
- sql(
- s"""
- |CREATE FUNCTION udtf_count_temp
- |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'
- |USING JAR '${hiveContext.getHiveFile("TestUDTF.jar").toURI}'
- """.stripMargin)
-
- checkAnswer(
- sql("SELECT key, cc FROM src LATERAL VIEW udtf_count_temp(value) dd AS cc"),
- Row(97, 500) :: Row(97, 500) :: Nil)
-
- checkAnswer(
- sql("SELECT udtf_count_temp(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"),
- Row(3) :: Row(3) :: Nil)
- }
- }
-
test("SPARK-6835: udtf in lateral view") {
val df = Seq((1, 1)).toDF("c1", "c2")
df.createOrReplaceTempView("table1")
@@ -176,49 +136,51 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
orders.toDF.createOrReplaceTempView("orders1")
orderUpdates.toDF.createOrReplaceTempView("orderupdates1")
- sql(
- """CREATE TABLE orders(
- | id INT,
- | make String,
- | type String,
- | price INT,
- | pdate String,
- | customer String,
- | city String)
- |PARTITIONED BY (state STRING, month INT)
- |STORED AS PARQUET
- """.stripMargin)
+ withTable("orders", "orderupdates") {
+ sql(
+ """CREATE TABLE orders(
+ | id INT,
+ | make String,
+ | type String,
+ | price INT,
+ | pdate String,
+ | customer String,
+ | city String)
+ |PARTITIONED BY (state STRING, month INT)
+ |STORED AS PARQUET
+ """.stripMargin)
- sql(
- """CREATE TABLE orderupdates(
- | id INT,
- | make String,
- | type String,
- | price INT,
- | pdate String,
- | customer String,
- | city String)
- |PARTITIONED BY (state STRING, month INT)
- |STORED AS PARQUET
- """.stripMargin)
+ sql(
+ """CREATE TABLE orderupdates(
+ | id INT,
+ | make String,
+ | type String,
+ | price INT,
+ | pdate String,
+ | customer String,
+ | city String)
+ |PARTITIONED BY (state STRING, month INT)
+ |STORED AS PARQUET
+ """.stripMargin)
- sql("set hive.exec.dynamic.partition.mode=nonstrict")
- sql("INSERT INTO TABLE orders PARTITION(state, month) SELECT * FROM orders1")
- sql("INSERT INTO TABLE orderupdates PARTITION(state, month) SELECT * FROM orderupdates1")
+ sql("set hive.exec.dynamic.partition.mode=nonstrict")
+ sql("INSERT INTO TABLE orders PARTITION(state, month) SELECT * FROM orders1")
+ sql("INSERT INTO TABLE orderupdates PARTITION(state, month) SELECT * FROM orderupdates1")
- checkAnswer(
- sql(
- """
- |select orders.state, orders.month
- |from orders
- |join (
- | select distinct orders.state,orders.month
- | from orders
- | join orderupdates
- | on orderupdates.id = orders.id) ao
- | on ao.state = orders.state and ao.month = orders.month
- """.stripMargin),
- (1 to 6).map(_ => Row("CA", 20151)))
+ checkAnswer(
+ sql(
+ """
+ |select orders.state, orders.month
+ |from orders
+ |join (
+ | select distinct orders.state,orders.month
+ | from orders
+ | join orderupdates
+ | on orderupdates.id = orders.id) ao
+ | on ao.state = orders.state and ao.month = orders.month
+ """.stripMargin),
+ (1 to 6).map(_ => Row("CA", 20151)))
+ }
}
test("show functions") {
@@ -389,21 +351,23 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
}
test("CTAS with WITH clause") {
+
val df = Seq((1, 1)).toDF("c1", "c2")
df.createOrReplaceTempView("table1")
-
- sql(
- """
- |CREATE TABLE with_table1 AS
- |WITH T AS (
- | SELECT *
- | FROM table1
- |)
- |SELECT *
- |FROM T
- """.stripMargin)
- val query = sql("SELECT * FROM with_table1")
- checkAnswer(query, Row(1, 1) :: Nil)
+ withTable("with_table1") {
+ sql(
+ """
+ |CREATE TABLE with_table1 AS
+ |WITH T AS (
+ | SELECT *
+ | FROM table1
+ |)
+ |SELECT *
+ |FROM T
+ """.stripMargin)
+ val query = sql("SELECT * FROM with_table1")
+ checkAnswer(query, Row(1, 1) :: Nil)
+ }
}
test("explode nested Field") {
@@ -451,7 +415,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
val catalogTable =
sessionState.catalog.getTableMetadata(TableIdentifier(tableName))
relation match {
- case LogicalRelation(r: HadoopFsRelation, _, _) =>
+ case LogicalRelation(r: HadoopFsRelation, _, _, _) =>
if (!isDataSourceTable) {
fail(
s"${classOf[HiveTableRelation].getCanonicalName} is expected, but found " +
@@ -604,86 +568,90 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
}
test("CTAS with serde") {
- sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value")
- sql(
- """CREATE TABLE ctas2
- | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe"
- | WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2")
- | STORED AS RCFile
- | TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22")
- | AS
- | SELECT key, value
- | FROM src
- | ORDER BY key, value""".stripMargin)
-
- val storageCtas2 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("ctas2")).storage
- assert(storageCtas2.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat"))
- assert(storageCtas2.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat"))
- assert(storageCtas2.serde == Some("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe"))
-
- sql(
- """CREATE TABLE ctas3
- | ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\012'
- | STORED AS textfile AS
- | SELECT key, value
- | FROM src
- | ORDER BY key, value""".stripMargin)
-
- // the table schema may like (key: integer, value: string)
- sql(
- """CREATE TABLE IF NOT EXISTS ctas4 AS
- | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin)
- // do nothing cause the table ctas4 already existed.
- sql(
- """CREATE TABLE IF NOT EXISTS ctas4 AS
- | SELECT key, value FROM src ORDER BY key, value""".stripMargin)
+ withTable("ctas1", "ctas2", "ctas3", "ctas4", "ctas5") {
+ sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value")
+ sql(
+ """CREATE TABLE ctas2
+ | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe"
+ | WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2")
+ | STORED AS RCFile
+ | TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22")
+ | AS
+ | SELECT key, value
+ | FROM src
+ | ORDER BY key, value""".stripMargin)
+
+ val storageCtas2 = spark.sessionState.catalog.
+ getTableMetadata(TableIdentifier("ctas2")).storage
+ assert(storageCtas2.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat"))
+ assert(storageCtas2.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat"))
+ assert(storageCtas2.serde == Some("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe"))
- checkAnswer(
- sql("SELECT k, value FROM ctas1 ORDER BY k, value"),
- sql("SELECT key, value FROM src ORDER BY key, value"))
- checkAnswer(
- sql("SELECT key, value FROM ctas2 ORDER BY key, value"),
sql(
- """
+ """CREATE TABLE ctas3
+ | ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' LINES TERMINATED BY '\012'
+ | STORED AS textfile AS
+ | SELECT key, value
+ | FROM src
+ | ORDER BY key, value""".stripMargin)
+
+ // the table schema may look like (key: integer, value: string)
+ sql(
+ """CREATE TABLE IF NOT EXISTS ctas4 AS
+ | SELECT 1 AS key, value FROM src LIMIT 1""".stripMargin)
+ // does nothing because the table ctas4 already exists
+ sql(
+ """CREATE TABLE IF NOT EXISTS ctas4 AS
+ | SELECT key, value FROM src ORDER BY key, value""".stripMargin)
+
+ checkAnswer(
+ sql("SELECT k, value FROM ctas1 ORDER BY k, value"),
+ sql("SELECT key, value FROM src ORDER BY key, value"))
+ checkAnswer(
+ sql("SELECT key, value FROM ctas2 ORDER BY key, value"),
+ sql(
+ """
SELECT key, value
FROM src
ORDER BY key, value"""))
- checkAnswer(
- sql("SELECT key, value FROM ctas3 ORDER BY key, value"),
- sql(
- """
+ checkAnswer(
+ sql("SELECT key, value FROM ctas3 ORDER BY key, value"),
+ sql(
+ """
SELECT key, value
FROM src
ORDER BY key, value"""))
- intercept[AnalysisException] {
- sql(
- """CREATE TABLE ctas4 AS
- | SELECT key, value FROM src ORDER BY key, value""".stripMargin)
- }
- checkAnswer(
- sql("SELECT key, value FROM ctas4 ORDER BY key, value"),
- sql("SELECT key, value FROM ctas4 LIMIT 1").collect().toSeq)
-
- sql(
- """CREATE TABLE ctas5
- | STORED AS parquet AS
- | SELECT key, value
- | FROM src
- | ORDER BY key, value""".stripMargin)
- val storageCtas5 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("ctas5")).storage
- assert(storageCtas5.inputFormat ==
- Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"))
- assert(storageCtas5.outputFormat ==
- Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat"))
- assert(storageCtas5.serde ==
- Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"))
-
-
- // use the Hive SerDe for parquet tables
- withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> "false") {
+ intercept[AnalysisException] {
+ sql(
+ """CREATE TABLE ctas4 AS
+ | SELECT key, value FROM src ORDER BY key, value""".stripMargin)
+ }
checkAnswer(
- sql("SELECT key, value FROM ctas5 ORDER BY key, value"),
- sql("SELECT key, value FROM src ORDER BY key, value"))
+ sql("SELECT key, value FROM ctas4 ORDER BY key, value"),
+ sql("SELECT key, value FROM ctas4 LIMIT 1").collect().toSeq)
+
+ sql(
+ """CREATE TABLE ctas5
+ | STORED AS parquet AS
+ | SELECT key, value
+ | FROM src
+ | ORDER BY key, value""".stripMargin)
+ val storageCtas5 = spark.sessionState.catalog.
+ getTableMetadata(TableIdentifier("ctas5")).storage
+ assert(storageCtas5.inputFormat ==
+ Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"))
+ assert(storageCtas5.outputFormat ==
+ Some("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat"))
+ assert(storageCtas5.serde ==
+ Some("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"))
+
+
+ // use the Hive SerDe for parquet tables
+ withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> "false") {
+ checkAnswer(
+ sql("SELECT key, value FROM ctas5 ORDER BY key, value"),
+ sql("SELECT key, value FROM src ORDER BY key, value"))
+ }
}
}
@@ -756,40 +724,46 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
}
test("double nested data") {
- sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil)
- .toDF().createOrReplaceTempView("nested")
- checkAnswer(
- sql("SELECT f1.f2.f3 FROM nested"),
- Row(1))
+ withTable("test_ctas_1234") {
+ sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil)
+ .toDF().createOrReplaceTempView("nested")
+ checkAnswer(
+ sql("SELECT f1.f2.f3 FROM nested"),
+ Row(1))
- sql("CREATE TABLE test_ctas_1234 AS SELECT * from nested")
- checkAnswer(
- sql("SELECT * FROM test_ctas_1234"),
- sql("SELECT * FROM nested").collect().toSeq)
+ sql("CREATE TABLE test_ctas_1234 AS SELECT * from nested")
+ checkAnswer(
+ sql("SELECT * FROM test_ctas_1234"),
+ sql("SELECT * FROM nested").collect().toSeq)
- intercept[AnalysisException] {
- sql("CREATE TABLE test_ctas_1234 AS SELECT * from notexists").collect()
+ intercept[AnalysisException] {
+ sql("CREATE TABLE test_ctas_1234 AS SELECT * from notexists").collect()
+ }
}
}
test("test CTAS") {
- sql("CREATE TABLE test_ctas_123 AS SELECT key, value FROM src")
- checkAnswer(
- sql("SELECT key, value FROM test_ctas_123 ORDER BY key"),
- sql("SELECT key, value FROM src ORDER BY key").collect().toSeq)
+ withTable("test_ctas_1234") {
+ sql("CREATE TABLE test_ctas_123 AS SELECT key, value FROM src")
+ checkAnswer(
+ sql("SELECT key, value FROM test_ctas_123 ORDER BY key"),
+ sql("SELECT key, value FROM src ORDER BY key").collect().toSeq)
+ }
}
test("SPARK-4825 save join to table") {
- val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).toDF()
- sql("CREATE TABLE test1 (key INT, value STRING)")
- testData.write.mode(SaveMode.Append).insertInto("test1")
- sql("CREATE TABLE test2 (key INT, value STRING)")
- testData.write.mode(SaveMode.Append).insertInto("test2")
- testData.write.mode(SaveMode.Append).insertInto("test2")
- sql("CREATE TABLE test AS SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key")
- checkAnswer(
- table("test"),
- sql("SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key").collect().toSeq)
+ withTable("test1", "test2", "test") {
+ val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).toDF()
+ sql("CREATE TABLE test1 (key INT, value STRING)")
+ testData.write.mode(SaveMode.Append).insertInto("test1")
+ sql("CREATE TABLE test2 (key INT, value STRING)")
+ testData.write.mode(SaveMode.Append).insertInto("test2")
+ testData.write.mode(SaveMode.Append).insertInto("test2")
+ sql("CREATE TABLE test AS SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key")
+ checkAnswer(
+ table("test"),
+ sql("SELECT COUNT(a.value) FROM test1 a JOIN test2 b ON a.key = b.key").collect().toSeq)
+ }
}
test("SPARK-3708 Backticks aren't handled correctly is aliases") {
@@ -1883,14 +1857,16 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
test("SPARK-17108: Fix BIGINT and INT comparison failure in spark sql") {
- sql("create table t1(a map>)")
- sql("select * from t1 where a[1] is not null")
+ withTable("t1", "t2", "t3") {
+ sql("create table t1(a map>)")
+ sql("select * from t1 where a[1] is not null")
- sql("create table t2(a map>)")
- sql("select * from t2 where a[1] is not null")
+ sql("create table t2(a map>)")
+ sql("select * from t2 where a[1] is not null")
- sql("create table t3(a map>)")
- sql("select * from t3 where a[1L] is not null")
+ sql("create table t3(a map>)")
+ sql("select * from t3 where a[1L] is not null")
+ }
}
test("SPARK-17796 Support wildcard character in filename for LOAD DATA LOCAL INPATH") {
@@ -2023,7 +1999,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
}
test("SPARK-21721: Clear FileSystem deleterOnExit cache if path is successfully removed") {
- withTable("test21721") {
+ val table = "test21721"
+ withTable(table) {
val deleteOnExitField = classOf[FileSystem].getDeclaredField("deleteOnExit")
deleteOnExitField.setAccessible(true)
@@ -2031,12 +2008,46 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
val setOfPath = deleteOnExitField.get(fs).asInstanceOf[Set[Path]]
val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).toDF()
- sql("CREATE TABLE test21721 (key INT, value STRING)")
+ sql(s"CREATE TABLE $table (key INT, value STRING)")
val pathSizeToDeleteOnExit = setOfPath.size()
- (0 to 10).foreach(_ => testData.write.mode(SaveMode.Append).insertInto("test1"))
+ (0 to 10).foreach(_ => testData.write.mode(SaveMode.Append).insertInto(table))
assert(setOfPath.size() == pathSizeToDeleteOnExit)
}
}
+
+ test("SPARK-21912 ORC/Parquet table should not create invalid column names") {
+ Seq(" ", ",", ";", "{", "}", "(", ")", "\n", "\t", "=").foreach { name =>
+ withTable("t21912") {
+ Seq("ORC", "PARQUET").foreach { source =>
+ val m = intercept[AnalysisException] {
+ sql(s"CREATE TABLE t21912(`col$name` INT) USING $source")
+ }.getMessage
+ assert(m.contains(s"contains invalid character(s)"))
+
+ val m2 = intercept[AnalysisException] {
+ sql(s"CREATE TABLE t21912 USING $source AS SELECT 1 `col$name`")
+ }.getMessage
+ assert(m2.contains(s"contains invalid character(s)"))
+
+ withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> "false") {
+ val m3 = intercept[AnalysisException] {
+ sql(s"CREATE TABLE t21912(`col$name` INT) USING hive OPTIONS (fileFormat '$source')")
+ }.getMessage
+ assert(m3.contains(s"contains invalid character(s)"))
+ }
+ }
+
+ // TODO: After SPARK-21929, we need to check ORC, too.
+ Seq("PARQUET").foreach { source =>
+ sql(s"CREATE TABLE t21912(`col` INT) USING $source")
+ val m = intercept[AnalysisException] {
+ sql(s"ALTER TABLE t21912 ADD COLUMNS(`col$name` INT)")
+ }.getMessage
+ assert(m.contains(s"contains invalid character(s)"))
+ }
+ }
+ }
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala
index 222c24927a763..de6f0d67f1734 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala
@@ -45,7 +45,7 @@ class OrcFilterSuite extends QueryTest with OrcTest {
var maybeRelation: Option[HadoopFsRelation] = None
val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect {
- case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _)) =>
+ case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _, _)) =>
maybeRelation = Some(orcRelation)
filters
}.flatten.reduceLeftOption(_ && _)
@@ -89,7 +89,7 @@ class OrcFilterSuite extends QueryTest with OrcTest {
var maybeRelation: Option[HadoopFsRelation] = None
val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect {
- case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _)) =>
+ case PhysicalOperation(_, filters, LogicalRelation(orcRelation: HadoopFsRelation, _, _, _)) =>
maybeRelation = Some(orcRelation)
filters
}.flatten.reduceLeftOption(_ && _)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
index 52fa401d32c18..781de6631f324 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
@@ -22,8 +22,8 @@ import java.io.File
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql.{QueryTest, Row}
-import org.apache.spark.sql.hive.HiveExternalCatalog
import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -149,7 +149,8 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA
}
test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") {
- assert(new OrcOptions(Map("Orc.Compress" -> "NONE")).compressionCodec == "NONE")
+ val conf = sqlContext.sessionState.conf
+ assert(new OrcOptions(Map("Orc.Compress" -> "NONE"), conf).compressionCodec == "NONE")
}
test("SPARK-19459/SPARK-18220: read char/varchar column written by Hive") {
@@ -194,6 +195,30 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA
Utils.deleteRecursively(location)
}
}
+
+ test("SPARK-21839: Add SQL config for ORC compression") {
+ val conf = sqlContext.sessionState.conf
+ // Test if the default of spark.sql.orc.compression.codec is snappy
+ assert(new OrcOptions(Map.empty[String, String], conf).compressionCodec == "SNAPPY")
+
+ // OrcOptions' parameters take precedence over the SQL configuration:
+ // `compression` -> `orc.compression` -> `spark.sql.orc.compression.codec`
+ withSQLConf(SQLConf.ORC_COMPRESSION.key -> "uncompressed") {
+ assert(new OrcOptions(Map.empty[String, String], conf).compressionCodec == "NONE")
+ val map1 = Map("orc.compress" -> "zlib")
+ val map2 = Map("orc.compress" -> "zlib", "compression" -> "lzo")
+ assert(new OrcOptions(map1, conf).compressionCodec == "ZLIB")
+ assert(new OrcOptions(map2, conf).compressionCodec == "LZO")
+ }
+
+ // Test all the valid options of spark.sql.orc.compression.codec
+ Seq("NONE", "UNCOMPRESSED", "SNAPPY", "ZLIB", "LZO").foreach { c =>
+ withSQLConf(SQLConf.ORC_COMPRESSION.key -> c) {
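+ // OrcOptions maps "UNCOMPRESSED" to ORC's "NONE" codec name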
+ val expected = if (c == "UNCOMPRESSED") "NONE" else c
+ assert(new OrcOptions(Map.empty[String, String], conf).compressionCodec == expected)
+ }
+ }
+ }
}
class OrcSourceSuite extends OrcSuite {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
index 0509260956069..c2288768707b1 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
@@ -285,7 +285,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
)
table("test_parquet_ctas").queryExecution.optimizedPlan match {
- case LogicalRelation(_: HadoopFsRelation, _, _) => // OK
+ case LogicalRelation(_: HadoopFsRelation, _, _, _) => // OK
case _ => fail(
"test_parquet_ctas should be converted to " +
s"${classOf[HadoopFsRelation ].getCanonicalName }")
@@ -370,7 +370,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
assertResult(2) {
analyzed.collect {
- case r @ LogicalRelation(_: HadoopFsRelation, _, _) => r
+ case r @ LogicalRelation(_: HadoopFsRelation, _, _, _) => r
}.size
}
}
@@ -379,7 +379,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
def collectHadoopFsRelation(df: DataFrame): HadoopFsRelation = {
val plan = df.queryExecution.analyzed
plan.collectFirst {
- case LogicalRelation(r: HadoopFsRelation, _, _) => r
+ case LogicalRelation(r: HadoopFsRelation, _, _, _) => r
}.getOrElse {
fail(s"Expecting a HadoopFsRelation 2, but got:\n$plan")
}
@@ -459,7 +459,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
// Converted test_parquet should be cached.
getCachedDataSourceTable(tableIdentifier) match {
case null => fail("Converted test_parquet should be cached in the cache.")
- case LogicalRelation(_: HadoopFsRelation, _, _) => // OK
+ case LogicalRelation(_: HadoopFsRelation, _, _, _) => // OK
case other =>
fail(
"The cached test_parquet should be a Parquet Relation. " +
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala
index ae44fd07ac558..0c4a64ccc513f 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConverters._
import scala.reflect.ClassTag
import org.apache.spark.api.java.JavaRDDLike
-import org.apache.spark.streaming.api.java.{JavaDStreamLike, JavaDStream, JavaStreamingContext}
+import org.apache.spark.streaming.api.java.{JavaDStream, JavaDStreamLike, JavaStreamingContext}
/** Exposes streaming test functionality in a Java-friendly way. */
trait JavaTestBase extends TestSuiteBase {
@@ -35,7 +35,7 @@ trait JavaTestBase extends TestSuiteBase {
def attachTestInputStream[T](
ssc: JavaStreamingContext,
data: JList[JList[T]],
- numPartitions: Int) = {
+ numPartitions: Int): JavaDStream[T] = {
val seqData = data.asScala.map(_.asScala)
implicit val cm: ClassTag[T] =
@@ -47,9 +47,9 @@ trait JavaTestBase extends TestSuiteBase {
/**
* Attach a provided stream to it's associated StreamingContext as a
* [[org.apache.spark.streaming.TestOutputStream]].
- **/
+ */
def attachTestOutputStream[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T, R]](
- dstream: JavaDStreamLike[T, This, R]) = {
+ dstream: JavaDStreamLike[T, This, R]): Unit = {
implicit val cm: ClassTag[T] =
implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
val ostream = new TestOutputStreamWithPartitions(dstream.dstream)
@@ -90,10 +90,10 @@ trait JavaTestBase extends TestSuiteBase {
}
object JavaTestUtils extends JavaTestBase {
- override def maxWaitTimeMillis = 20000
+ override def maxWaitTimeMillis: Int = 20000
}
object JavaCheckpointTestUtils extends JavaTestBase {
- override def actuallyWait = true
+ override def actuallyWait: Boolean = true
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
index 1b1e21f6e5bab..5fc626c1f78b8 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
@@ -24,8 +24,8 @@ import java.util.concurrent.Semaphore
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
+import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
import org.scalatest.concurrent.Eventually._
-import org.scalatest.concurrent.Timeouts
import org.scalatest.time.SpanSugar._
import org.apache.spark.SparkConf
@@ -36,7 +36,7 @@ import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._
import org.apache.spark.util.Utils
/** Testsuite for testing the network receiver behavior */
-class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable {
+class ReceiverSuite extends TestSuiteBase with TimeLimits with Serializable {
test("receiver life cycle") {
@@ -60,6 +60,8 @@ class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable {
// Verify that the receiver
intercept[Exception] {
+ // Necessary to make failAfter interrupt the blocking join() in ScalaTest 3.x
+ implicit val signaler: Signaler = ThreadSignaler
failAfter(200 millis) {
executingThread.join()
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index 20452f70911aa..623797081544c 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -27,8 +27,8 @@ import scala.collection.mutable.Queue
import org.apache.commons.io.FileUtils
import org.scalatest.{Assertions, BeforeAndAfter, PrivateMethodTester}
+import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits}
import org.scalatest.concurrent.Eventually._
-import org.scalatest.concurrent.Timeouts
import org.scalatest.exceptions.TestFailedDueToTimeoutException
import org.scalatest.time.SpanSugar._
@@ -42,7 +42,7 @@ import org.apache.spark.streaming.receiver.Receiver
import org.apache.spark.util.Utils
-class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeouts with Logging {
+class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with TimeLimits with Logging {
val master = "local[2]"
val appName = this.getClass.getSimpleName
@@ -406,6 +406,8 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo
// test whether awaitTermination() does not exit if not time is given
val exception = intercept[Exception] {
+ // Necessary to make failAfter interrupt awaitTermination() in ScalaTest 3.x
+ implicit val signaler: Signaler = ThreadSignaler
failAfter(1000 millis) {
ssc.awaitTermination()
throw new Exception("Did not wait for stop")
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala
index 4f41b9d0a0b3c..898da4445e464 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala
@@ -24,8 +24,9 @@ import scala.collection.mutable
import org.scalatest.BeforeAndAfter
import org.scalatest.Matchers._
+import org.scalatest.concurrent.{Signaler, ThreadSignaler}
import org.scalatest.concurrent.Eventually._
-import org.scalatest.concurrent.Timeouts._
+import org.scalatest.concurrent.TimeLimits._
import org.scalatest.time.SpanSugar._
import org.apache.spark.{SparkConf, SparkException, SparkFunSuite}
@@ -34,6 +35,7 @@ import org.apache.spark.util.ManualClock
class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter {
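+ // ThreadSignaler lets failAfter interrupt the thread under test when a time limit expires
+ // (ScalaTest 3.x replaces Timeouts with TimeLimits)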
+ implicit val defaultSignaler: Signaler = ThreadSignaler
private val blockIntervalMs = 10
private val conf = new SparkConf().set("spark.streaming.blockInterval", s"${blockIntervalMs}ms")
@volatile private var blockGenerator: BlockGenerator = null
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala
index 1d2bf35a6d458..8d81b582e4d30 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala
@@ -21,7 +21,7 @@ import org.mockito.Matchers.{eq => meq}
import org.mockito.Mockito._
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, PrivateMethodTester}
import org.scalatest.concurrent.Eventually.{eventually, timeout}
-import org.scalatest.mock.MockitoSugar
+import org.scalatest.mockito.MockitoSugar
import org.scalatest.time.SpanSugar._
import org.apache.spark.{ExecutorAllocationClient, SparkConf, SparkFunSuite}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
index ede15399f0e2f..4a2549fc0a96d 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
@@ -36,7 +36,7 @@ import org.mockito.Mockito._
import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach, PrivateMethodTester}
import org.scalatest.concurrent.Eventually
import org.scalatest.concurrent.Eventually._
-import org.scalatest.mock.MockitoSugar
+import org.scalatest.mockito.MockitoSugar
import org.apache.spark.{SparkConf, SparkException, SparkFunSuite}
import org.apache.spark.streaming.scheduler._
@@ -484,7 +484,7 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests(
// we make the write requests in separate threads so that we don't block the test thread
private def writeAsync(wal: WriteAheadLog, event: String, time: Long): Promise[Unit] = {
val p = Promise[Unit]()
- p.completeWith(Future {
+ p.completeWith(Future[Unit] {
val v = wal.write(event, time)
assert(v === walHandle)
}(walBatchingExecutionContext))
diff --git a/tools/pom.xml b/tools/pom.xml
index 7ba4dc9842f1b..37427e8da62d8 100644
--- a/tools/pom.xml
+++ b/tools/pom.xml
@@ -44,7 +44,7 @@
<groupId>org.clapper</groupId>
<artifactId>classutil_${scala.binary.version}</artifactId>
- <version>1.0.6</version>
+ <version>1.1.2</version>